In [1]:
import ast
import copy
import pickle
import random
import sys

from sklearn.model_selection import KFold
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import networkx as nx
from torch_geometric.utils import from_networkx
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, precision_recall_curve
from dotenv import find_dotenv, load_dotenv

# Load environment variables from the .env file located by find_dotenv().
load_dotenv(dotenv_path=find_dotenv())

# Make the project's source directory importable, then pull in its modules.
sys.path.append('../src/null-effect-net')
import utils
import models
import dataset
import train_utils


In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

train_utils.set_seed(42)

# Gene reference table (approved symbol <-> Ensembl gene ID columns are used below).
id_map_df = pd.read_csv('../data/id_mappings/gene_ref.tsv', sep='\t')

# Per-gene embedding table. pickle can execute arbitrary code on load, so only
# read this file from a trusted source.
with open('../data/embeddings.pkl', 'rb') as pkl_file:
    node_features_df = pickle.load(pkl_file)

# PINNACLE + SubCell + ESM combined with `+` (concatenation if the cells hold
# lists — NOTE(review): confirm element type). This column is dropped again
# before training.
node_features_df['Concat Embedding'] = (
    node_features_df['PINNACLE Embedding']
    + node_features_df['SubCell Embedding']
    + node_features_df['ESM Embedding']
)

# Perturbation-screen splits and the expression reference.
train_df = pd.read_csv('../data/perturbation_screens/e_distance/train.csv', index_col=0)
test_df = pd.read_csv('../data/perturbation_screens/e_distance/test.csv', index_col=0)

active_nodes_df = pd.read_csv('../data/expression_reference/expression_reference.csv', index_col=0)

G = nx.read_edgelist('../data/networks/global_ppi_edgelist.txt')

# Lookups between approved gene symbols and Ensembl gene IDs (last row wins on duplicates).
ensembl_to_node = {
    ensembl_id: symbol
    for ensembl_id, symbol in zip(id_map_df['Ensembl gene ID'], id_map_df['Approved symbol'])
}
node_to_ensembl = {
    symbol: ensembl_id
    for symbol, ensembl_id in zip(id_map_df['Approved symbol'], id_map_df['Ensembl gene ID'])
}

# Relabel symbol-named nodes to Ensembl IDs; nodes absent from the mapping keep their label.
G = nx.relabel_nodes(G, node_to_ensembl)

# Keep only nodes that have a row in the feature table.
valid_nodes = set(node_features_df['Ensembl ID'])
G.remove_nodes_from(set(G.nodes) - valid_nodes)

# Integer position of each remaining node, following G.nodes() order.
node_to_idx = {node: position for position, node in enumerate(G.nodes())}

# Restrict both screens to targets that survived the filtering above.
train_df = train_df.loc[train_df['Target'].isin(G.nodes())]
test_df = test_df.loc[test_df['Target'].isin(G.nodes())]

# Convert to a PyG Data object (edge_index format).
data = from_networkx(G)

In [3]:
# Number of training epochs for each run below.
num_epochs = 20

In [4]:
# NOTE(review): recomputes the same PINNACLE + SubCell + ESM combination as the
# 'Concat Embedding' column built earlier; both columns are dropped in the next cell.
node_features_df['Concatenated Embedding'] = (
    node_features_df['PINNACLE Embedding']
    + node_features_df['SubCell Embedding']
    + node_features_df['ESM Embedding']
)

## GNN Attention

In [None]:
# Drop the concatenated-embedding columns; the rest of the notebook only uses the
# per-modality embeddings. Reassigning with errors='ignore' (instead of an
# in-place drop) keeps this cell safe to re-run.
node_features_df = node_features_df.drop(
    columns=['Concat Embedding', 'Concatenated Embedding'], errors='ignore'
)
node_features_df.head()

In [6]:
def pad_vector(v, max_len):
    """Return ``v`` as a float32 array right-padded with zeros to ``max_len``.

    Args:
        v: 1-D sequence of numbers (list or array).
        max_len: Target length; must be >= ``len(v)``.

    Returns:
        A new ``np.ndarray`` of shape ``(max_len,)`` and dtype float32.

    Raises:
        ValueError: if ``v`` is longer than ``max_len``.
    """
    as_float32 = np.asarray(v, dtype=np.float32)
    return np.pad(as_float32, (0, max_len - len(v)))

def make_padded_set(row, max_len):
    """Stack one node's three modality embeddings into a zero-padded set.

    Slot order is ESM, SubCell, PINNACLE; each slot is padded to ``max_len``
    with ``pad_vector``. Returns a float32 array of shape (3, max_len).
    """
    modality_columns = ('ESM Embedding', 'SubCell Embedding', 'PINNACLE Embedding')
    return np.stack([pad_vector(row[col], max_len) for col in modality_columns])

# Lengths of every modality vector of every node (row-major: ESM, SubCell,
# PINNACLE per node); the longest one sets the common padded width.
all_lengths = [
    len(vec)
    for esm_vec, subcell_vec, pinnacle_vec in zip(
        node_features_df['ESM Embedding'],
        node_features_df['SubCell Embedding'],
        node_features_df['PINNACLE Embedding'],
    )
    for vec in (esm_vec, subcell_vec, pinnacle_vec)
]
max_len = max(all_lengths)

# One (3, max_len) padded set per node.
node_features_df['Set Embedding'] = node_features_df.apply(
    make_padded_set, axis=1, args=(max_len,)
)

# Rebuild the PyG graph so this cell attaches features to a fresh object.
data = from_networkx(G)

# Index the per-node set embeddings by Ensembl ID once. Previously set_index ran
# inside the loop, copying the whole frame for every node.
set_embedding_by_id = node_features_df.set_index('Ensembl ID')['Set Embedding']

# Rows follow G.nodes() order, the same order used to build node_to_idx.
set_tensor = torch.stack([
    torch.tensor(set_embedding_by_id.loc[node_id])
    for node_id in G.nodes()
])  # shape: (num_nodes, 3, max_len)

# A modality slot counts as present when any entry of its padded vector is
# non-zero (all-zero slots are treated as padding).
mask_tensor = (set_tensor != 0).any(dim=-1).to(torch.float32)  # shape: (num_nodes, 3)

data.set_features = set_tensor.to(torch.float32)
data.set_mask = mask_tensor

In [7]:
# Build datasets
train_dataset = dataset.GNNDataset(
    train_df,
    active_nodes_df,
    node_features_df,
    node_to_idx,
    device=device
)
val_dataset = dataset.GNNDataset(
    test_df,
    active_nodes_df,
    node_features_df,
    node_to_idx,
    device=device
)

# Build dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=dataset.collate_function_gnn)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=dataset.collate_function_gnn)

# Per-slot feature width of the padded set embeddings (every slot is padded to
# max_len). Positional .iloc[0] avoids depending on the frame's index labels.
input_dim = len(node_features_df['Set Embedding'].iloc[0][0])

In [None]:
# First hyperparameter configuration.
model = models.GNNAttentionClassifier(
    input_dim,
    pool_hidden_dim=512,
    pool_out_dim=128,
    gcn_hidden_dim=128,
    gcn_out_dim=32,
    neg_weight=3.0,
    only_active=False,
    return_attn=False
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

# Best-epoch bookkeeping. The best_* values start as None so they are always
# (re)defined by this cell, even if no epoch improves on the initial AUC.
best_val_auc = 0.0
best_epoch = None
best_precision = best_recall = best_f1 = None
best_state = None

for epoch in range(1, num_epochs + 1):
    print(f"\n----- Epoch {epoch}/{num_epochs} -----")
    train_metrics = train_utils.train_one_epoch_gnn(model, data, train_loader, optimizer, device)
    val_metrics, y_true_val, y_pred_val, y_prob_val = train_utils.evaluate_gnn(model, data, val_loader, device)

    # Checkpoint whenever validation AUC improves (a NaN AUC never compares greater).
    if val_metrics['auc'] > best_val_auc:
        best_val_auc = val_metrics['auc']
        best_epoch = epoch
        best_precision = val_metrics['precision']
        best_recall = val_metrics['recall']
        best_f1 = val_metrics['f1']
        best_state = copy.deepcopy(model.state_dict())
        train_utils.save_model(model, '../models', experiment='attention_analysis')
        print(f"New best model saved with AUC: {best_val_auc:.4f}")

# The loop leaves `model` at the last epoch; roll it back to the best-AUC weights
# so anything evaluated next matches the checkpoint that was saved.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights from epoch {best_epoch} (AUC {best_val_auc:.4f})")
        print(f"New best model saved with AUC: {best_val_auc:.4f}")

In [None]:
# Second hyperparameter configuration (the one analysed below).
model = models.GNNAttentionClassifier(
    input_dim,
    pool_hidden_dim=256,
    pool_out_dim=128,
    gcn_hidden_dim=256,
    gcn_out_dim=128,
    neg_weight=3.0,
    only_active=False
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

# Best-epoch bookkeeping. Resetting best_* here prevents stale values from the
# previous run surviving if no epoch of this run improves on the initial AUC.
best_val_auc = 0.0
best_epoch = None
best_precision = best_recall = best_f1 = None
best_state = None

for epoch in range(1, num_epochs + 1):
    print(f"\n----- Epoch {epoch}/{num_epochs} -----")
    train_metrics = train_utils.train_one_epoch_gnn(model, data, train_loader, optimizer, device)
    val_metrics, y_true_val, y_pred_val, y_prob_val = train_utils.evaluate_gnn(model, data, val_loader, device)

    # Checkpoint whenever validation AUC improves (a NaN AUC never compares greater).
    if val_metrics['auc'] > best_val_auc:
        best_val_auc = val_metrics['auc']
        best_epoch = epoch
        best_precision = val_metrics['precision']
        best_recall = val_metrics['recall']
        best_f1 = val_metrics['f1']
        best_state = copy.deepcopy(model.state_dict())
        train_utils.save_model(model, '../models', experiment='attention_analysis_second_hyperparameters')
        print(f"New best model saved with AUC: {best_val_auc:.4f}")

# The loop leaves `model` at the last epoch; roll it back to the best-AUC weights
# so the attention extraction that follows uses the checkpointed model.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights from epoch {best_epoch} (AUC {best_val_auc:.4f})")

In [None]:
# Evaluate on the validation loader again; the fifth return value (attn_data)
# carries the attention outputs unpacked in the next cell.
(
    metrics,
    y_true,
    y_pred,
    y_prob,
    attn_data,
) = train_utils.evaluate_gnn_with_attention(model, data, val_loader, device)

In [50]:
# Pull the three attention outputs out of attn_data. Presumed shapes (TODO confirm):
# attn_weights (N, set_len), query_nodes (N,), set_masks (N, set_len).
attn_weights, query_nodes, set_masks = (
    attn_data[key] for key in ("attn_weights", "query_nodes", "set_masks")
)

In [51]:
# Keep one attention row per graph node so row i lines up with node index i
# (the index -> Ensembl ID mapping built from node_to_idx relies on this).
# Replaces the hard-coded 9305. NOTE(review): assumes the returned rows begin
# with one full pass over all nodes — confirm against evaluate_gnn_with_attention.
attn_weights = attn_weights[:len(node_to_idx)]

In [52]:
# Keep one mask row per graph node, aligned with node_to_idx positions.
# Replaces the hard-coded 9305. NOTE(review): assumes the same row layout as
# attn_weights — confirm against evaluate_gnn_with_attention.
set_masks = set_masks[:len(node_to_idx)]

In [None]:
# Reload a saved checkpoint of the second hyperparameter configuration.
# NOTE(review): the attention tensors analysed in this notebook were extracted
# from the in-memory `model` by an earlier cell, before this reload; re-run that
# extraction if the analysis should reflect this checkpoint.
# TODO: the checkpoint filename is timestamped and hard-coded; update it after retraining.
input_dim = len(node_features_df['Set Embedding'].iloc[0][0])
checkpoint_path = '../models/attention_analysis_second_hyperparameters/GNNAttentionClassifier/checkpoint_05_03-22_23.pt'
model = train_utils.load_model(
    models.GNNAttentionClassifier,
    input_dim=input_dim,
    pool_hidden_dim=256,
    pool_out_dim=128,
    gcn_hidden_dim=256,
    gcn_out_dim=128,
    neg_weight=3.0,
    only_active=False,
    path=checkpoint_path,
)
model

In [None]:
# One row per node, one column per modality. Column names follow the slot order
# used in make_padded_set (ESM, SubCell, PINNACLE); assumes attention slot j maps
# to set slot j — TODO confirm in the model.
# detach().cpu() makes this work for GPU tensors and tensors that require grad.
df = pd.DataFrame(attn_weights.detach().cpu().numpy(), columns=["ESM", "SubCell", "PINNACLE"])

# standard_scale=1 min-max scales each column (modality) before clustering rows.
clustergrid = sns.clustermap(df, metric="euclidean", method="ward", cmap="viridis", standard_scale=1)

# plt.title() after clustermap would land on one of its sub-axes; title the figure instead.
clustergrid.fig.suptitle("Node Attention Clustermap", y=1.02)
clustergrid.fig.savefig(
    '../figures/attention_analysis/node_attention_clustermap.svg', dpi=400, bbox_inches='tight'
)
plt.show()

In [None]:
# Preview the per-node attention table instead of rendering the whole frame.
df.head()

In [None]:
# Invert node_to_idx so each attention row (indexed by node position) can be
# labelled with its Ensembl ID and gene symbol; unmapped entries become NaN.
idx_to_node = {idx: node for node, idx in node_to_idx.items()}
df['Ensembl ID'] = df.index.to_series().map(idx_to_node)
df['Symbol'] = df['Ensembl ID'].map(ensembl_to_node)
df

In [77]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.cluster import DBSCAN

In [None]:
# Standardize each modality's attention column, then group nodes with k-means
# (5 clusters, fixed random_state for reproducible labels).
scaler = StandardScaler()
kmeans = KMeans(n_clusters=5, random_state=42)

df_scaled = scaler.fit_transform(df[['ESM', 'SubCell', 'PINNACLE']])
df['KMeans_Cluster'] = kmeans.fit_predict(df_scaled)

df

In [None]:
# Rows assigned to k-means cluster 4 (label numbering is specific to this fit).
df.loc[df['KMeans_Cluster'].eq(4)]

In [None]:
# Print the gene symbol of every node in cluster 4, one per line.
for symbol in df.loc[df['KMeans_Cluster'].eq(4), 'Symbol'].tolist():
    print(symbol)