In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold
import random

# Read the preprocessed UMI count table from disk.
df_raw = pd.read_csv('GSM2230757_human1_umifm_counts_preprocessed.csv')

# Feature matrix: every column from positional index 2 (the third column) onward.
# NOTE(review): if the features are meant to start at the fourth column, this
# slice should be `3:` instead -- confirm against the CSV layout.
data = df = df_raw.iloc[:, 2:]

num_Cells, num_Features = data.shape
print(num_Cells, num_Features)

# Class labels: taken from the column at positional index 1 and mapped to
# consecutive integer codes by the label encoder.
label_encoder = LabelEncoder()
labels_raw = df_raw.iloc[:, 1]
labels_encoded = label_encoder.fit_transform(labels_raw)
labels = torch.tensor(labels_encoded, dtype=torch.long)

# Pick the compute device: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)
# Aggregation Function
class RW_Aggregator(nn.Module):
    """Turn the node embeddings visited on a random walk into class scores.

    The nodes of a walk are split into two groups according to the type of
    the edge that leaves them ('homophily' or anything else, treated as
    heterophily).  Each group is mean-pooled, pushed through a shared
    two-layer encoder (fc1 -> ReLU -> dropout -> fc2), projected by a
    group-specific linear layer, and the two projections are summed and
    passed through a final linear layer followed by ReLU.

    Args:
        in_features: Width of one node embedding; must equal
            ``embeddings.shape[1]`` of the tensor given to ``forward``.
        out_features: Size of the output vector (one score per class in the
            way this file uses it).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Shared encoder applied to both pooled groups.
        self.fc1 = nn.Linear(in_features, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, out_features)
        # Group-specific projections and the layer applied to their sum.
        self.fc_homophily = nn.Linear(out_features, out_features)
        self.fc_heterophily = nn.Linear(out_features, out_features)
        self.fc_combined = nn.Linear(out_features, out_features)

    @staticmethod
    def _mean_or_zeros(vectors, embeddings):
        """Mean-pool `vectors`, or return a zero vector when the list is empty."""
        if vectors:
            return torch.mean(torch.stack(vectors), 0)
        # Allocate the fallback next to `embeddings` (same device and dtype)
        # instead of relying on a module-level `device` global.
        return embeddings.new_zeros(embeddings.shape[1])

    def _encode(self, pooled):
        """Apply the shared fc1 -> ReLU -> dropout -> fc2 encoder."""
        return self.fc2(self.dropout(self.relu(self.fc1(pooled))))

    def forward(self, embeddings, walks, edge_types):
        """Aggregate one walk.

        Args:
            embeddings: 2-D tensor indexed by node id along dim 0.
            walks: Sequence of node ids forming a single walk (the caller in
                this file passes one walk, despite the plural name).
            edge_types: Edge type of each step of the walk.  Entry ``k`` is
                paired with ``walks[k]``; because ``zip`` stops at the
                shorter sequence, nodes without a matching entry (normally
                the last node of the walk) are not aggregated.

        Returns:
            1-D tensor of length ``out_features`` with non-negative entries.
        """
        homophilic_embeddings = []
        heterophilic_embeddings = []
        for node, edge_type in zip(walks, edge_types):
            if edge_type == 'homophily':
                homophilic_embeddings.append(embeddings[node])
            else:
                heterophilic_embeddings.append(embeddings[node])

        aggregated_homophily = self._mean_or_zeros(homophilic_embeddings, embeddings)
        aggregated_heterophily = self._mean_or_zeros(heterophilic_embeddings, embeddings)

        # Homophily branch first, then heterophily, so dropout masks are drawn
        # in the same order as before.
        out_homophily = self.fc_homophily(self._encode(aggregated_homophily))
        out_heterophily = self.fc_heterophily(self._encode(aggregated_heterophily))

        # Sum the homophilic and heterophilic results
        combined_output = out_homophily + out_heterophily
        return torch.relu(self.fc_combined(combined_output))
# Random walk biased by embedding similarity; labels are looked up through get_label
def embedding_based_random_walk(G, start, get_label, node_embeddings, length=40):
    """Walk over `G` from `start`, preferring neighbours with similar embeddings.

    At every step the next node is sampled among the neighbours of the
    current node with weight ``max(cosine_similarity, 0)`` between the two
    embeddings.  Cosine similarity lies in [-1, 1], so raw values cannot be
    used as sampling weights: negative entries are clamped to zero, and when
    no neighbour has a positive weight the step is sampled uniformly.

    Args:
        G: Graph-like object exposing ``neighbors(node)``.
        start: Node at which the walk begins.
        get_label: Callable mapping a node to its label; used only to tag
            each traversed edge.
        node_embeddings: Indexable by node, yielding 1-D tensors.
        length: Maximum number of nodes in the walk (including `start`).

    Returns:
        ``(walk, edge_types)``: the list of visited nodes and, for each step
        taken, 'homophily' if both endpoints share a label and 'heterophily'
        otherwise.  ``len(edge_types) == len(walk) - 1``.  The walk stops
        early at a node without neighbours.
    """
    walk = [start]
    edge_types = []

    for _ in range(length - 1):
        current = walk[-1]
        neighbors = list(G.neighbors(current))
        if not neighbors:
            break

        # Cosine similarity between the current node and each neighbour
        current_embedding = node_embeddings[current].unsqueeze(0)
        neighbor_embeddings = torch.stack([node_embeddings[neighbor] for neighbor in neighbors])
        similarities = torch.cosine_similarity(current_embedding, neighbor_embeddings, dim=1)

        # random.choices needs non-negative weights with a positive total; it
        # normalises them itself, so no explicit division is required.
        weights = similarities.detach().clamp(min=0).cpu().tolist()
        if not sum(weights) > 0:  # also true when the sum is NaN
            weights = None  # uniform choice among the neighbours

        next_node = random.choices(neighbors, weights=weights)[0]

        # Tag the traversed edge by comparing the labels of its endpoints
        edge_type = 'homophily' if get_label(current) == get_label(next_node) else 'heterophily'
        edge_types.append(edge_type)
        walk.append(next_node)

    return walk, edge_types
# 10-fold cross-validation: for each fold, build a similarity graph over the
# training cells, train the embeddings and the aggregator on random walks,
# then attach the held-out cells to the graph and classify them.
kf = KFold(n_splits=10, shuffle=True, random_state=42)
accuracy_scores = []

# One output unit per class known to the encoder.  Sizing the output layer
# from the labels present in a single training fold breaks as soon as a rare
# class is missing from that fold (targets would exceed the output width).
num_classes = len(label_encoder.classes_)

# Two cells are connected when their cosine similarity exceeds this value.
threshold = 0.4

for train_index, test_index in kf.split(np.arange(len(data))):
    # Split the data into train and test
    train_data, test_data = data.iloc[train_index], data.iloc[test_index]
    train_labels_encoded, test_labels_encoded = labels_encoded[train_index], labels_encoded[test_index]
    num_Train = len(train_data)
    num_Test = len(test_data)

    # Thresholded cosine-similarity adjacency of the training cells (no self-loops)
    similarity_matrix_train = cosine_similarity(train_data)
    adjacency_matrix_train = (similarity_matrix_train > threshold).astype(int)
    np.fill_diagonal(adjacency_matrix_train, 0)

    # Create the training graph; node ids are positions within the training fold
    G_train = nx.Graph()
    for i in range(num_Train):
        G_train.add_node(i, features=train_data.iloc[i, :])
    # The adjacency is symmetric, so the strict upper triangle lists every
    # undirected edge exactly once (in the same row-major order as a nested loop).
    edge_rows, edge_cols = np.nonzero(np.triu(adjacency_matrix_train, k=1))
    G_train.add_edges_from(zip(edge_rows.tolist(), edge_cols.tolist()))

    # Randomly initialised, trainable embeddings indexed by graph node id:
    # rows [0, num_Train) belong to training nodes and rows
    # [num_Train, num_Train + num_Test) to the test nodes added further below.
    embeddings = torch.randn(num_Cells, num_Features, device=device, requires_grad=True)

    # Initialize Aggregator and Optimizer
    aggregator = RW_Aggregator(in_features=num_Features, out_features=num_classes).to(device)
    optimizer = optim.Adam([embeddings] + list(aggregator.parameters()), lr=0.01)
    loss_function = nn.CrossEntropyLoss()

    train_targets = torch.tensor(train_labels_encoded, dtype=torch.long, device=device)
    train_nodes = list(G_train.nodes)
    # random.sample raises if asked for more items than exist, so cap the batch
    batch_size = min(100, len(train_nodes))

    def get_train_label(node_id):
        """Label lookup used while walking the training-only graph."""
        return train_labels_encoded[node_id]

    # Training Loop
    aggregator.train()
    for epoch in range(150):
        nodes = random.sample(train_nodes, k=batch_size)
        optimizer.zero_grad()
        losses = []

        for node in nodes:
            walk, edge_types = embedding_based_random_walk(G_train, node, get_train_label, embeddings)
            aggregated_embedding = aggregator(embeddings, walk, edge_types)
            output = aggregated_embedding.unsqueeze(0)
            losses.append(loss_function(output, train_targets[node].unsqueeze(0)))

        total_loss = torch.mean(torch.stack(losses))
        total_loss.backward()
        optimizer.step()

    ### TEST PHASE: ADD UNSEEN NODES TO THE TRAINING GRAPH ###

    # Compute similarity between test nodes and training nodes
    similarity_matrix_test_train = cosine_similarity(test_data, train_data)

    # Add each test node after the training ids and connect it to every
    # sufficiently similar training node (no test-test edges are created).
    for i in range(num_Test):
        test_node_id = num_Train + i
        G_train.add_node(test_node_id, features=test_data.iloc[i, :])
        similar_train_nodes = np.flatnonzero(similarity_matrix_test_train[i] > threshold).tolist()
        G_train.add_edges_from((test_node_id, j) for j in similar_train_nodes)

    def get_label(node_id):
        """Label lookup for the extended graph (training and test nodes)."""
        if node_id >= num_Train:
            return test_labels_encoded[node_id - num_Train]
        return train_labels_encoded[node_id]

    # NOTE(review): edge types along test walks are derived from the true test
    # labels through get_label, so information about the label being predicted
    # reaches the aggregator.  Behaviour is kept as-is here; consider deriving
    # edge types without test labels before reporting these accuracies.

    # Model Evaluation on Test Set (after adding unseen nodes to the training graph)
    total_correct_test = 0
    test_labels = torch.tensor(test_labels_encoded, dtype=torch.long, device=device)
    aggregator.eval()  # disable dropout so predictions are not randomly perturbed
    with torch.no_grad():  # no gradients are needed for inference
        for i in range(num_Test):
            test_node_id = num_Train + i  # ID of the test node in the extended graph
            walk, edge_types = embedding_based_random_walk(G_train, test_node_id, get_label, embeddings, length=40)
            aggregated_embedding = aggregator(embeddings, walk, edge_types)
            output = torch.argmax(aggregated_embedding)
            if output.item() == test_labels[i].item():
                total_correct_test += 1

    accuracy_test = total_correct_test / num_Test
    accuracy_scores.append(accuracy_test)
    print(f"Fold Test Accuracy: {accuracy_test}")

# After all folds, calculate the average test accuracy
average_accuracy = sum(accuracy_scores) / len(accuracy_scores)
print("Average Test Accuracy:", average_accuracy)