In [None]:
import torch
import numpy as np
import os
import networkx as nx
from torch_geometric.utils import from_networkx
from torch_geometric.data import HeteroData
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, Linear
from torch_geometric.nn import GATv2Conv
from sklearn.metrics import roc_auc_score
import pickle
import pandas as pd
from sklearn.model_selection import KFold

# Switch the working directory to the project root. Relative paths used later in
# this notebook (e.g. the 'results/covid/...' output files) are resolved against it.
# NOTE(review): "PATH_TO_PROJECT" appears to be a placeholder — set it to a real
# directory before running, otherwise this call raises FileNotFoundError.
os.chdir("PATH_TO_PROJECT")

In [None]:
# Load the pre-processed dataset bundle (a mapping of named arrays / tables).
# SECURITY: pickle.load can execute arbitrary code embedded in the file, so only
# open bundles that come from a trusted source.
# The bundle is bound to `dataset` (not `data`) because `data` is the name used
# for the HeteroData graph object built further down; sharing one name for both
# made re-running this cell silently clobber the graph.
with open('PATH_TO_PROCESSED_DATASET', 'rb') as f:
    dataset = pickle.load(f)

df_selected = dataset["df_selected"]
df_feature = dataset["df_feature"]
# Drop the last two feature columns of the TCR feature matrix.
# NOTE(review): the meaning of those two columns is not visible here — confirm.
tcr_features = dataset["tcr_features"][:, :-2]
tcr_similarity = dataset["tcr_similarity"]
epitope_features = dataset["epitope_features"]
epitope_similarity = dataset["epitope_similarity"]
associations = dataset["associations"]

# Node inventories; their lengths define how many nodes of each type are created.
TCR_list = df_feature["TCR_list"]
epitope_list = df_feature["Epitope"]
num_tcr = len(TCR_list)
num_epitope = len(epitope_list)

In [None]:
def add_top_k_similarity_edges(graph, similarity, node_prefix, edge_type, k):
    """Connect every node to the ``k`` highest-scoring entries of its similarity row.

    Args:
        graph: NetworkX graph that receives the edges (modified in place).
        similarity: 2-D array indexable as ``similarity[i, j]``; row ``i`` holds the
            scores of node ``i`` against every node of the same type.
        node_prefix: Node-name prefix; nodes are named ``f'{node_prefix}_{index}'``.
        edge_type: Value stored in the ``edge_type`` attribute of each added edge.
        k: Number of top-scoring columns linked per row. ``k <= 0`` adds nothing.

    Because ``graph`` is undirected, an edge proposed from both endpoints is stored
    once; the later ``add_edge`` call overwrites the stored ``weight``.
    """
    if k <= 0:
        # Guard: ``[-0:]`` would select *every* column rather than none.
        return
    for i, row in enumerate(similarity):
        # argsort is ascending, so the last k positions are the k largest scores.
        # NOTE(review): if a row's largest score is the node's own entry this adds a
        # self-loop — confirm that is intended.
        for j in np.argsort(row)[-k:]:
            graph.add_edge(f'{node_prefix}_{i}', f'{node_prefix}_{j}',
                           edge_type=edge_type, weight=similarity[i, j])


# Initialize a NetworkX graph
G = nx.Graph()

# Add TCR nodes with features. TCR nodes are inserted before epitope nodes, so
# G.edges() reports TCR-Epitope edges with the TCR endpoint first.
for i in range(num_tcr):
    G.add_node(f'TCR_{i}', type='TCR', x=tcr_features[i])

# Add Epitope nodes with features
for i in range(num_epitope):
    G.add_node(f'Epitope_{i}', type='Epitope', x=epitope_features[i])

# Number of most-similar neighbours linked per node (used for both node types).
N = 10
add_top_k_similarity_edges(G, tcr_similarity, 'TCR', 'TCR-TCR', N)
add_top_k_similarity_edges(G, epitope_similarity, 'Epitope', 'Epitope-Epitope', N)

# Add TCR <-> Epitope edges based on associations. Only the first half of the TCR
# rows is scanned; the second half of df_selected is later used as negative pairs.
for i in range(num_tcr // 2):
    for j in range(num_epitope):
        if associations[i, j] == 1:  # If associated, add an edge
            G.add_edge(f'TCR_{i}', f'Epitope_{j}', edge_type='TCR-Epitope')

print(nx.is_connected(G))

In [None]:
def _node_index(node_name):
    """Return the integer suffix of a node name such as ``'TCR_12'`` -> ``12``."""
    return int(node_name.split('_')[1])


def _edge_index_for(graph, edge_type, src_prefix=None, symmetric=False):
    """Collect the edges of one ``edge_type`` from ``graph`` as a ``(2, E)`` long tensor.

    Args:
        graph: NetworkX graph whose edges carry an ``edge_type`` attribute and whose
            node names end in ``_<integer index>``.
        edge_type: Only edges whose ``edge_type`` attribute equals this are exported.
        src_prefix: If given, each pair is oriented so that the endpoint whose name
            starts with this prefix is the source (row 0). An undirected graph does
            not guarantee endpoint order, so bipartite edges are oriented explicitly.
        symmetric: If True, emit every edge in both directions (de-duplicated, so a
            self-loop appears once). ``edge_index`` is directed in PyTorch Geometric:
            messages flow from row 0 to row 1, so an undirected edge needs both
            orientations to be treated as undirected.

    Returns:
        ``torch.long`` tensor of shape ``(2, E)``; ``(2, 0)`` when nothing matches.
    """
    pairs = []
    for u, v, attrs in graph.edges(data=True):
        if attrs['edge_type'] != edge_type:
            continue
        if src_prefix is not None and not u.startswith(src_prefix):
            u, v = v, u
        pairs.append((_node_index(u), _node_index(v)))

    # reshape(-1, 2) keeps the result 2-D even when `pairs` is empty.
    index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    if symmetric and index.numel() > 0:
        both_ways = torch.cat([index, index.flip(0)], dim=1)
        index = torch.unique(both_ways, dim=1)
    return index


# Initialize HeteroData
data = HeteroData()

# Node feature matrices, one row per node, in the same order as the node indices.
data['TCR'].x = torch.tensor(tcr_features, dtype=torch.float)
data['Epitope'].x = torch.tensor(epitope_features, dtype=torch.float)

# Similarity graphs are undirected, so both directions are exported for message passing.
data['TCR', 'connected_to', 'TCR'].edge_index = _edge_index_for(
    G, 'TCR-TCR', symmetric=True)
data['Epitope', 'connected_to', 'Epitope'].edge_index = _edge_index_for(
    G, 'Epitope-Epitope', symmetric=True)

# Association edges stay directed: row 0 = TCR index, row 1 = epitope index.
data['TCR', 'associated_with', 'Epitope'].edge_index = _edge_index_for(
    G, 'TCR-Epitope', src_prefix='TCR_')

In [None]:
class HeteroGAT(torch.nn.Module):
    """Two-layer, single-head GATv2 encoder per node type plus a linear edge scorer.

    TCR nodes and epitope nodes are embedded independently, each over its own
    same-type ``connected_to`` graph. A candidate (TCR, epitope) pair is scored by
    concatenating the two node embeddings and passing them through one linear layer.

    Args:
        in_channels_tcr: Width of the TCR node feature vectors.
        in_channels_epitope: Width of the epitope node feature vectors.
        hidden_channels: Output width of the first GATv2 layer of each encoder.
        out_channels: Output width of the second GATv2 layer of each encoder.
    """

    def __init__(self, in_channels_tcr, in_channels_epitope, hidden_channels, out_channels):
        super().__init__()
        gat_options = dict(heads=1, add_self_loops=False)

        self.tcr_conv1 = GATv2Conv(in_channels_tcr, hidden_channels, **gat_options)
        self.tcr_conv2 = GATv2Conv(hidden_channels, out_channels, **gat_options)
        self.epitope_conv1 = GATv2Conv(in_channels_epitope, hidden_channels, **gat_options)
        self.epitope_conv2 = GATv2Conv(hidden_channels, out_channels, **gat_options)

        # Scores the concatenation [tcr_embedding, epitope_embedding].
        self.edge_classifier = torch.nn.Linear(2 * out_channels, 1)

    def forward(self, data, edge_index):
        """Return one raw logit per candidate pair.

        Args:
            data: Heterogeneous graph exposing ``data['TCR'].x``, ``data['Epitope'].x``
                and the ``connected_to`` edge indices for both node types.
            edge_index: Two-row index; row 0 selects TCR nodes, row 1 epitope nodes.

        Returns:
            1-D tensor of logits, one per column of ``edge_index``.
        """
        tcr_graph = data['TCR', 'connected_to', 'TCR'].edge_index
        epitope_graph = data['Epitope', 'connected_to', 'Epitope'].edge_index

        tcr_hidden = F.relu(self.tcr_conv1(data['TCR'].x, tcr_graph))
        tcr_embedding = self.tcr_conv2(tcr_hidden, tcr_graph)

        epitope_hidden = F.relu(self.epitope_conv1(data['Epitope'].x, epitope_graph))
        epitope_embedding = self.epitope_conv2(epitope_hidden, epitope_graph)

        tcr_ids, epitope_ids = edge_index
        pair_features = torch.cat(
            [tcr_embedding[tcr_ids], epitope_embedding[epitope_ids]], dim=1)
        return self.edge_classifier(pair_features).squeeze(-1)

In [None]:
class ComplexHeteroGAT(torch.nn.Module):
    """Three-layer, four-head GATv2 encoder per node type plus a linear edge scorer.

    TCR and epitope nodes are embedded independently over their own same-type
    ``connected_to`` graphs. Each encoder is
    ``conv1 -> ReLU -> dropout -> conv2 -> ReLU (+ skip) -> dropout -> conv3``.
    A candidate (TCR, epitope) pair is scored from the concatenated embeddings.

    Args:
        in_channels_tcr: Width of the TCR node feature vectors.
        in_channels_epitope: Width of the epitope node feature vectors.
        hidden_channels: Per-head width of the first two GATv2 layers.
        out_channels: Per-head width of the last GATv2 layer.
    """

    def __init__(self, in_channels_tcr, in_channels_epitope, hidden_channels, out_channels):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=0.3)
        # With heads=4 every layer outputs 4 * <per-head width> features, which is
        # why the following layer's input width is multiplied by 4.
        self.tcr_conv1 = GATv2Conv(in_channels_tcr, hidden_channels, heads=4, add_self_loops=False)
        self.tcr_conv2 = GATv2Conv(hidden_channels * 4, hidden_channels, heads=4, add_self_loops=False)
        self.tcr_conv3 = GATv2Conv(hidden_channels * 4, out_channels, heads=4, add_self_loops=False)

        self.epitope_conv1 = GATv2Conv(in_channels_epitope, hidden_channels, heads=4, add_self_loops=False)
        self.epitope_conv2 = GATv2Conv(hidden_channels * 4, hidden_channels, heads=4, add_self_loops=False)
        self.epitope_conv3 = GATv2Conv(hidden_channels * 4, out_channels, heads=4, add_self_loops=False)

        # Scores the concatenation [tcr_embedding, epitope_embedding].
        self.edge_classifier = torch.nn.Linear(2 * out_channels * 4, 1)

    def _encode(self, x, graph_edges, conv1, conv2, conv3):
        """Run one node type through its three-layer encoder and return the embeddings."""
        hidden = self.dropout(F.relu(conv1(x, graph_edges)))
        # Residual connection: add the *input* of conv2 to its activated output.
        # (Previously the conv2 output was added to its own ReLU, which is not a
        # skip connection.) Shapes match: conv2 maps 4*hidden -> 4*hidden.
        hidden = self.dropout(F.relu(conv2(hidden, graph_edges)) + hidden)
        return conv3(hidden, graph_edges)

    def forward(self, data, edge_index):
        """Return one raw logit per candidate pair.

        Args:
            data: Heterogeneous graph exposing ``data['TCR'].x``, ``data['Epitope'].x``
                and the ``connected_to`` edge indices for both node types.
            edge_index: Two-row index; row 0 selects TCR nodes, row 1 epitope nodes.

        Returns:
            1-D tensor of logits, one per column of ``edge_index``.
        """
        tcr_x = self._encode(
            data['TCR'].x, data['TCR', 'connected_to', 'TCR'].edge_index,
            self.tcr_conv1, self.tcr_conv2, self.tcr_conv3)
        epitope_x = self._encode(
            data['Epitope'].x, data['Epitope', 'connected_to', 'Epitope'].edge_index,
            self.epitope_conv1, self.epitope_conv2, self.epitope_conv3)

        # Gather the embeddings of both endpoints of every candidate pair.
        src_nodes, dst_nodes = edge_index
        edge_features = torch.cat((tcr_x[src_nodes], epitope_x[dst_nodes]), dim=1)

        return self.edge_classifier(edge_features).squeeze(-1)


In [None]:
# Positive pairs used for training and evaluation: the TCR -> Epitope association
# edges stored on the graph (row 0 = TCR node index, row 1 = epitope node index).
edge_index = data['TCR', 'associated_with', 'Epitope'].edge_index

def generate_negative_edge_index(df_selected, epitope_list):
    """Build the (TCR, epitope) index pairs taken from the second half of ``df_selected``.

    For every row position ``i`` in the second half of ``df_selected`` (from
    ``len // 2`` to the end) one pair ``(i, position of that row's "Epitope" value
    in epitope_list)`` is emitted. The row position itself is used as the source index.
    NOTE(review): this presumes row positions of ``df_selected`` coincide with TCR
    node indices — confirm against how the dataset was assembled.

    Args:
        df_selected: DataFrame with an ``"Epitope"`` column.
        epitope_list: Iterable of epitope identifiers (list, tuple, pandas Series, ...)
            whose order defines the epitope node indices. If an identifier occurs
            more than once, its first position is used.

    Returns:
        ``torch.long`` tensor of shape ``(2, n)``: row 0 holds row positions,
        row 1 holds epitope indices.

    Raises:
        ValueError: If a row's epitope does not occur in ``epitope_list``.
    """
    # One pass to build the lookup instead of a linear ``.index`` scan per row.
    # ``setdefault`` keeps the first position, matching ``list.index``. Iterating
    # also makes this work for a pandas Series, where ``.index`` is not a method.
    first_position = {}
    for position, epitope in enumerate(epitope_list):
        first_position.setdefault(epitope, position)

    n_rows = df_selected.shape[0]
    start = n_rows // 2
    edge_start = list(range(start, n_rows))
    try:
        edge_end = [first_position[epitope]
                    for epitope in df_selected["Epitope"].iloc[start:]]
    except KeyError as exc:
        raise ValueError(f"{exc.args[0]!r} is not in epitope_list") from exc
    return torch.tensor([edge_start, edge_end], dtype=torch.long)

negative_edge_index = generate_negative_edge_index(df_selected, epitope_list)

# Define indices for positive and negative edges
num_pos_edges = edge_index.size(1)
num_neg_edges = negative_edge_index.size(1)

# Experiment configuration. Fixed seeds make the fold assignment, the weight
# initialisation and dropout repeatable across Restart & Run All.
SEED = 42
NUM_EPOCHS = 100
LEARNING_RATE = 0.01
HIDDEN_CHANNELS = 32
OUT_CHANNELS = 16

torch.manual_seed(SEED)

# Perform k-fold cross-validation. Positives and negatives are split separately
# and paired fold by fold, so every fold contains both classes.
k_folds = 10
kf_pos = KFold(n_splits=k_folds, shuffle=True, random_state=SEED)
kf_neg = KFold(n_splits=k_folds, shuffle=True, random_state=SEED + 1)

auc_scores = []
test_losses = []
all_test_targets = []
all_pred_probs = []

fold_splits = zip(kf_pos.split(range(num_pos_edges)), kf_neg.split(range(num_neg_edges)))
for (train_pos_idx, test_pos_idx), (train_neg_idx, test_neg_idx) in fold_splits:
    train_pos_edge_index = edge_index[:, train_pos_idx]
    test_pos_edge_index = edge_index[:, test_pos_idx]

    train_neg_edge_index = negative_edge_index[:, train_neg_idx]
    test_neg_edge_index = negative_edge_index[:, test_neg_idx]

    # Positives first, then negatives; the targets below follow the same order.
    train_edge_index = torch.cat([train_pos_edge_index, train_neg_edge_index], dim=1)
    test_edge_index = torch.cat([test_pos_edge_index, test_neg_edge_index], dim=1)

    train_target = torch.cat([
        torch.ones(train_pos_edge_index.size(1)),
        torch.zeros(train_neg_edge_index.size(1))
    ], dim=0)

    test_target = torch.cat([
        torch.ones(test_pos_edge_index.size(1)),
        torch.zeros(test_neg_edge_index.size(1))
    ], dim=0)

    # Fresh model per fold. Input widths are read from the graph itself so they
    # cannot drift from the feature matrices (the epitope width was hard-coded).
    model = ComplexHeteroGAT(
        data['TCR'].x.size(1),
        data['Epitope'].x.size(1),
        HIDDEN_CHANNELS,
        OUT_CHANNELS,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train the model (full-batch: one optimisation step per epoch).
    model.train()
    for epoch in range(NUM_EPOCHS):
        optimizer.zero_grad()
        edge_logits = model(data, train_edge_index)
        loss = F.binary_cross_entropy_with_logits(edge_logits, train_target.to(edge_logits.device))
        loss.backward()
        optimizer.step()

    # Evaluate the model. eval() disables dropout; no_grad() avoids building an
    # autograd graph that is never used.
    model.eval()
    with torch.no_grad():
        test_output = model(data, test_edge_index)
        test_loss = F.binary_cross_entropy_with_logits(test_output, test_target.to(test_output.device))
        pred_probs = torch.sigmoid(test_output).cpu().numpy()
    test_losses.append(test_loss.item())

    # Calculate AUC
    auc_score = roc_auc_score(test_target.cpu().numpy(), pred_probs)
    auc_scores.append(auc_score)

    all_pred_probs.append(pred_probs)
    all_test_targets.append(test_target.cpu().numpy())

average_auc = sum(auc_scores) / len(auc_scores)

In [None]:
# Pool the held-out labels and predicted probabilities of all folds (fold order kept).
concatenated_test_targets = np.concatenate(all_test_targets)
concatenated_pred_probs = np.concatenate(all_pred_probs)

# Create the output directory first so that a missing folder does not throw away
# the results of the whole cross-validation run.
os.makedirs('results/covid', exist_ok=True)
np.save('results/covid/covid_targets.npy', concatenated_test_targets)
np.save('results/covid/covid_pred_probs.npy', concatenated_pred_probs)