In [None]:
!pip install rdkit torch torch_geometric seaborn

In [None]:
import os
import csv
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import time
import traceback
import matplotlib.pyplot as plt
from rdkit import Chem
import math
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GATConv
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
from torch.optim.lr_scheduler import MultiplicativeLR
import seaborn as sns
import warnings
from sklearn.manifold import TSNE
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Set seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module, NumPy's global RNG (used for the t-SNE edge
    subsampling), and torch's CPU/CUDA generators (model init, dropout,
    RandomLinkSplit). Python's ``random`` matters too: PyG's ``negative_sampling``
    draws candidate edges through it, so leaving it unseeded makes the sampled
    negatives, and therefore the training curves, differ from run to run.

    NOTE(review): some CUDA scatter kernels used by GATConv are nondeterministic,
    so GPU runs may still differ slightly even with all seeds fixed.

    Args:
        seed: integer seed applied to all generators.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()

In [None]:
# Load and validate SMILES
# Single place to change when the data lives somewhere else.
DDI_DATA_DIR = "/content/drive/MyDrive/VRSEC/III Year/3 - 2/Mini/Teju/Databases/DDI"

smiles_df = pd.read_csv(os.path.join(DDI_DATA_DIR, "SMILES_Dataset.csv"))
smiles_df = smiles_df.dropna(subset=["SMILES"]).reset_index(drop=True)

# RDKit returns None for SMILES it cannot parse. str() guards against cells pandas did not
# read as strings, which MolFromSmiles would not accept.
smiles_df['IsValid'] = smiles_df['SMILES'].apply(lambda s: Chem.MolFromSmiles(str(s)) is not None)
valid_smiles = smiles_df[smiles_df['IsValid']].drop(columns=['IsValid'])
valid_smiles.to_csv('ValidSmiles.csv', index=False)

# Load interaction datasets (both normalised to 'src'/'dst' drug-ID columns)
biosnap = (
    pd.read_csv(os.path.join(DDI_DATA_DIR, "ChCh-Miner_durgbank-chem-chem.tsv"),
                sep='\t', header=None, names=['src', 'dst'])
    .dropna()
    .reset_index(drop=True)
)

drugbank = (
    pd.read_csv(os.path.join(DDI_DATA_DIR, "DrugBankDDI.csv"))
    .rename(columns={"Drug1": "src", "Drug2": "dst"})
    .dropna()
    .reset_index(drop=True)
)

# Drug IDs with a parseable SMILES; interactions are later restricted to these.
allowed = set(valid_smiles["DrugBank ID"])

In [None]:
# Efficient filtering: keep only interactions whose two drugs both have a valid SMILES.
biosnap = biosnap[biosnap['src'].isin(allowed) & biosnap['dst'].isin(allowed)].reset_index(drop=True)
drugbank = drugbank[drugbank['src'].isin(allowed) & drugbank['dst'].isin(allowed)].reset_index(drop=True)

# Combine graphs into a single undirected edge list.
# - Only the endpoint columns are kept, so every value in combined_graph is a drug ID.
# - The graph is used as undirected downstream (edges are mirrored into both directions and
#   split with RandomLinkSplit(is_undirected=True)). (A, B) and (B, A) are therefore the same
#   interaction. Each pair is put in canonical (min, max) order before de-duplicating;
#   otherwise both copies survive and can end up in different train/val/test splits.
# - Self-pairs (A, A) are dropped: a drug paired with itself is not a drug-drug interaction.
raw_edges = pd.concat([biosnap[['src', 'dst']], drugbank[['src', 'dst']]], ignore_index=True)
canonical_edges = pd.DataFrame(np.sort(raw_edges.to_numpy(dtype=str), axis=1), columns=['src', 'dst'])
combined_graph = (
    canonical_edges[canonical_edges['src'] != canonical_edges['dst']]
    .drop_duplicates()
    .reset_index(drop=True)
)
print(f"Combined graph: {len(combined_graph)} unique undirected interactions "
      f"({len(raw_edges) - len(combined_graph)} duplicate, reversed or self-pair rows removed)")
combined_graph.to_csv("CombinedGraph.csv", index=False)

In [None]:
# Graph Analysis: quick structural sanity check of the combined interaction graph.
G = nx.from_pandas_edgelist(combined_graph, source='src', target='dst')

# Boolean flags, computed with networkx's own helpers instead of manual scans.
has_isolated_nodes = nx.number_of_isolates(G) > 0
has_self_loops = nx.number_of_selfloops(G) > 0

print("Graph Summary")
print(f"Nodes: {G.number_of_nodes()} | Edges: {G.number_of_edges()}")
print(f"Isolated Nodes: {has_isolated_nodes}")
print(f"Self-loops: {has_self_loops}")

In [None]:
# Convert to PyTorch Geometric format
def prepare_pyg_data(features, graph):
    """Build an undirected PyG ``Data`` object from node features and an edge list.

    Nodes are the sorted unique drug IDs found in ``graph['src']`` and ``graph['dst']``:
    node ``i`` is the i-th ID in that order. Every edge is emitted in both directions.

    Args:
        features: node features, either a DataFrame or an array-like.
            - If it is a DataFrame with a 'DrugBank ID' (or 'DrugID') column, its rows are
              re-ordered to the node order above. Drugs in the graph but missing from
              ``features`` get all-zero rows. Drugs not in the graph are dropped.
            - Otherwise, rows are assumed to already be in node order.
            Known identifier/text columns are dropped, and only numeric columns are kept.
        graph: DataFrame with 'src' and 'dst' drug-ID columns. Any other column is ignored.

    Returns:
        torch_geometric ``Data`` with:
            - ``x``: float32, one row per node, NaNs replaced by 0.
            - ``edge_index``: long, shape (2, 2 * len(graph)).

    Raises:
        ValueError: if the number of feature rows differs from the number of graph nodes.
            Positional features would otherwise be attached to the wrong drugs, or
            edge_index would point past the end of ``x``.
    """
    id_columns = ['DrugBank ID', 'DrugID']
    non_feature_columns = ['DrugBank ID', 'SMILES', 'Unnamed: 0', 'Description', 'DrugID']

    # Node set comes from the endpoint columns only, sorted by np.unique.
    unique_ids = np.unique(graph[['src', 'dst']].values)
    id_map = {name: i for i, name in enumerate(unique_ids)}

    src = graph['src'].map(id_map).values
    dst = graph['dst'].map(id_map).values
    # Mirror every edge so message passing treats the graph as undirected.
    edge_index = torch.tensor(
        np.vstack((np.concatenate([src, dst]), np.concatenate([dst, src]))),
        dtype=torch.long,
    )

    if not isinstance(features, pd.DataFrame):
        features = pd.DataFrame(features)

    # Align feature rows with node indices when the drug IDs are available.
    id_col = next((col for col in id_columns if col in features.columns), None)
    if id_col is not None:
        features = (
            features.drop_duplicates(subset=id_col)
            .set_index(id_col)
            .reindex(unique_ids)
        )

    features = features.drop(columns=non_feature_columns, errors='ignore')
    features = features.select_dtypes(include=[np.number])

    x_np = np.nan_to_num(features.to_numpy(dtype=np.float32))
    if x_np.shape[0] != len(unique_ids):
        raise ValueError(
            f"prepare_pyg_data: {x_np.shape[0]} feature rows but {len(unique_ids)} graph nodes; "
            "pass a DataFrame with a 'DrugBank ID' column, or rows already in sorted node-ID order."
        )
    x = torch.tensor(x_np, dtype=torch.float32)

    return Data(x=x, edge_index=edge_index)

In [None]:
# GNN Model
class GATNet(nn.Module):
    """Two-layer GAT encoder with an inner-product link decoder.

    Args:
        in_channels: size of each input node feature vector.
        hidden: per-head output size of the first GATConv, and the final embedding size.
        heads: number of (concatenated) attention heads in the first GATConv.
        dropout: dropout probability applied between the two layers while training.
        final_relu: apply ReLU to the final embeddings. Defaults to False. With
            non-negative embeddings every dot product is >= 0, so the decoder could
            never emit a negative logit (probability < 0.5) for a non-interacting pair.
            Pass True to reproduce the earlier architecture.
    """

    def __init__(self, in_channels, hidden, heads=2, dropout=0.3, final_relu=False):
        super().__init__()
        self.dropout = dropout
        self.final_relu = final_relu
        self.conv1 = GATConv(in_channels, hidden, heads=heads)
        self.bn1 = nn.BatchNorm1d(hidden * heads)  # heads are concatenated
        self.conv2 = GATConv(hidden * heads, hidden)
        self.bn2 = nn.BatchNorm1d(hidden)

    def encode(self, x, edge_index):
        """Return one ``hidden``-dim embedding per node."""
        x = F.relu(self.bn1(self.conv1(x, edge_index)))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.bn2(self.conv2(x, edge_index))
        # No activation by default: signed embeddings let dot products go negative.
        return F.relu(x) if self.final_relu else x

    def decode(self, z, edge_index):
        """Return one raw logit per column of ``edge_index``: the dot product of its endpoints."""
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

In [None]:
# Training & Evaluation
def train(model, optimizer, scheduler, criterion, data, edge_label_index, edge_label):
    """Run one full-batch optimisation step and advance the LR scheduler.

    Args:
        model: exposes ``encode(x, edge_index)`` and ``decode(z, edge_label_index)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per call.
        criterion: loss applied to (logits, edge_label).
        data: graph providing ``x`` and the message-passing ``edge_index``.
        edge_label_index: node pairs to score.
        edge_label: target label for each scored pair.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    node_embeddings = model.encode(data.x, data.edge_index)
    logits = model.decode(node_embeddings, edge_label_index).view(-1)
    step_loss = criterion(logits, edge_label)
    step_loss.backward()
    optimizer.step()
    scheduler.step()
    return step_loss.item()

@torch.no_grad()
def evaluate(model, data):
    """Score ``data.edge_label_index`` in eval mode, without gradients.

    Returns:
        A tuple ``(roc_auc, probabilities, labels)``:
            - ``probabilities``: sigmoid of the decoder logits, as a NumPy array.
            - ``labels``: ``data.edge_label``, as a NumPy array.
    """
    model.eval()
    node_embeddings = model.encode(data.x, data.edge_index)
    probabilities = model.decode(node_embeddings, data.edge_label_index).sigmoid().cpu()
    true_labels = data.edge_label.cpu()
    return roc_auc_score(true_labels, probabilities), probabilities.numpy(), true_labels.numpy()

In [None]:
def tsne_visualize_per_model(model, data, model_name, sample_size=1000, perplexity=30):
    """Plot a 2-D t-SNE of edge embeddings for ``data.edge_label_index``, coloured by label.

    Each edge embedding is the mean of its two endpoint node embeddings. Positives
    (label 1) are green; everything else is red. The figure is saved to
    ``tsne_{model_name}.png``, shown, then closed so repeated calls do not accumulate figures.

    Args:
        model: exposes ``encode(x, edge_index)``. It is switched to eval mode.
        data: graph with ``x``, ``edge_index``, ``edge_label_index`` and ``edge_label``.
        model_name: used in the plot title and the output filename.
        sample_size: maximum number of edges to embed, sampled without replacement
            from NumPy's global RNG.
        perplexity: t-SNE perplexity. It is clamped below the number of sampled edges
            because scikit-learn rejects perplexity >= n_samples.
    """
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x, data.edge_index)
    edge_index = data.edge_label_index
    labels = data.edge_label.cpu().numpy()

    total_edges = edge_index.size(1)
    n_samples = min(sample_size, total_edges)
    if n_samples < 2:
        print(f" Skipping t-SNE for {model_name}: need at least 2 labelled edges, got {n_samples}.")
        return

    sampled_idx = np.random.choice(total_edges, n_samples, replace=False)
    sampled_edge_index = edge_index[:, torch.as_tensor(sampled_idx, device=edge_index.device)]
    sampled_labels = labels[sampled_idx]

    # Mean-pool the two node embeddings per edge
    edge_embeddings = ((z[sampled_edge_index[0]] + z[sampled_edge_index[1]]) / 2).cpu().numpy()

    tsne = TSNE(n_components=2, random_state=42, perplexity=min(perplexity, n_samples - 1))
    tsne_result = tsne.fit_transform(edge_embeddings)

    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ['green' if l == 1 else 'red' for l in sampled_labels]
    ax.scatter(tsne_result[:, 0], tsne_result[:, 1], c=colors, alpha=0.7)
    ax.set_title(f"t-SNE of GAT Interaction Embeddings: {model_name}")
    ax.set_xlabel("TSNE-1")
    ax.set_ylabel("TSNE-2")
    ax.legend(handles=[
        plt.Line2D([0], [0], marker='o', color='w', label='Positive', markerfacecolor='green', markersize=8),
        plt.Line2D([0], [0], marker='o', color='w', label='Negative', markerfacecolor='red', markersize=8)
    ])
    fig.tight_layout()
    fig.savefig(f"tsne_{model_name}.png")
    plt.show()
    plt.close(fig)


In [None]:
# Embedding models: name -> (embedding table, interaction graph to train on).
EMBEDDINGS_DIR = "/content/drive/MyDrive/VRSEC/III Year/3 - 2/Mini/Teju/Databases/ImmI Embeddings/DDI Embeddings"
Embedding_models = {
    name: (pd.read_csv(f"{EMBEDDINGS_DIR}/{relative_path}"), combined_graph)
    for name, relative_path in [
        ('T5', "T5/T5_SMILES_Embeddings.csv"),
        ('SBERT', "SBERT_Embeddings.csv"),
    ]
}

In [None]:
# Learning-rate sweep: for every embedding model, build the graph and split once, then train
# one GATNet per learning rate and append the metrics to a CSV.
# Uses GATNet, prepare_pyg_data, train, evaluate, tsne_visualize_per_model, set_seed and
# Embedding_models from the cells above.
lr_list = [0.01, 0.001, 0.0001, 0.0002, 0.0003, 0.00001]
results_pr = {}

# Define the CSV file path (rows are appended, so re-running the sweep adds new rows)
csv_file_path = "GATNet_Results.csv"

# Identifier / text columns that must never be used as numeric features.
NON_FEATURE_COLUMNS = ['DrugBank ID', 'SMILES', 'Unnamed: 0', 'Description', 'DrugID']
# Columns that may hold the drug ID used to align embedding rows with graph nodes.
ID_COLUMNS = ['DrugBank ID', 'DrugID']
HIDDEN_DIM = 256
NUM_EPOCHS = 24      # epochs 1..24
LR_DECAY = 0.96      # multiplicative learning-rate decay per epoch
# Cap on negatives sampled per epoch. NOTE(review): all training edges are positives, so when
# there are many more of them than this cap, the loss is heavily skewed toward positives.
MAX_NEG_SAMPLES = 500
VERBOSE = False      # True -> per-epoch progress lines and full tracebacks
TARGET_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prepare CSV header if the file doesn't exist
if not os.path.exists(csv_file_path):
    with open(csv_file_path, 'w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Model", "Learning Rate", "Best Validation AUC", "Final Test AUC", "PR AUC"])

for model_name, (features, graph) in Embedding_models.items():
    print(f"\n Starting training for: {model_name}")

    # --- Step 1: feature cleanup and alignment (independent of the learning rate) ---
    try:
        # prepare_pyg_data numbers nodes by the sorted unique drug IDs of the edge list, so
        # row i of the feature matrix must be the embedding of node_ids[i].
        node_ids = np.unique(graph[['src', 'dst']].values)
        features_df = pd.DataFrame(features)
        id_col = next((col for col in ID_COLUMNS if col in features_df.columns), None)
        if id_col is not None:
            n_missing = len(set(node_ids) - set(features_df[id_col]))
            if n_missing:
                print(f" {model_name}: {n_missing} graph drugs have no embedding; using zero features for them.")
            features_df = (
                features_df.drop_duplicates(subset=id_col)
                .set_index(id_col)
                .reindex(node_ids)
            )
        features_df = features_df.drop(columns=NON_FEATURE_COLUMNS, errors='ignore')
        features_df = features_df.select_dtypes(include=[np.number])

        if features_df.shape[0] == 0 or features_df.shape[1] == 0:
            print(f" Skipping model {model_name} → Feature matrix is empty after cleaning.")
            continue
        if features_df.shape[0] != len(node_ids):
            print(f" Skipping model {model_name} → {features_df.shape[0]} feature rows vs "
                  f"{len(node_ids)} graph nodes and no ID column to align them.")
            continue

        features_array = np.nan_to_num(features_df.to_numpy(dtype=np.float32))
    except Exception as e:
        print(f" Skipping model {model_name} due to feature error: {e}")
        continue

    # --- Step 2: graph preparation and split, done once so every LR sees the same split ---
    try:
        set_seed()
        data = prepare_pyg_data(features_array, graph)
        transform = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)
        train_data, val_data, test_data = (split.to(TARGET_DEVICE) for split in transform(data))
        device = train_data.x.device
    except Exception as e:
        print(f" Skipping model {model_name} due to graph preparation error: {e}")
        if VERBOSE:
            traceback.print_exc()
        continue

    # --- Step 3: one training run per learning rate ---
    for selected_lr in lr_list:
        print(f"\n Training with Learning Rate: {selected_lr}")

        try:
            # Same initial weights and sampling stream for every LR -> a fair comparison.
            set_seed()
            model = GATNet(data.num_features, HIDDEN_DIM).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=selected_lr)
            scheduler = MultiplicativeLR(optimizer, lr_lambda=lambda epoch: LR_DECAY)
            criterion = nn.BCEWithLogitsLoss()

            best_val_auc = 0
            final_test_auc = 0
            best_scores, best_labels, best_state = None, None, None

            for epoch in range(1, NUM_EPOCHS + 1):
                torch.cuda.empty_cache()

                try:
                    # Fresh negatives each epoch, drawn from pairs not in the training graph.
                    num_neg = min(MAX_NEG_SAMPLES, train_data.edge_label_index.size(1))
                    neg_edge = negative_sampling(
                        edge_index=train_data.edge_index,
                        num_nodes=train_data.num_nodes,
                        num_neg_samples=num_neg,
                        method='sparse'
                    )
                    if neg_edge.size(1) == 0:
                        continue

                    # Positives keep their labels (all ones); negatives are labelled 0.
                    edge_idx = torch.cat([train_data.edge_label_index, neg_edge], dim=1).to(device)
                    labels = torch.cat([
                        train_data.edge_label,
                        torch.zeros(neg_edge.size(1), device=device)
                    ])
                    assert edge_idx.device == device and labels.device == device and train_data.x.device == device

                    loss = train(model, optimizer, scheduler, criterion, train_data, edge_idx, labels)
                    if math.isnan(loss):
                        print(f" NaN loss at epoch {epoch}, skipping...")
                        break

                    val_auc, _, _ = evaluate(model, val_data)
                    test_auc, scores, test_labels = evaluate(model, test_data)
                    if VERBOSE:
                        print(f"Epoch {epoch:03d} | loss={loss:.4f} val_auc={val_auc:.4f} "
                              f"test_auc={test_auc:.4f}", flush=True)

                    # Model selection on validation AUC; test metrics recorded at that epoch.
                    if val_auc > best_val_auc:
                        best_val_auc = val_auc
                        final_test_auc = test_auc
                        best_scores = scores
                        best_labels = test_labels
                        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

                except Exception as e:
                    print(f" Error inside epoch {epoch} for {model_name} with LR {selected_lr}: {type(e).__name__}: {e}")
                    if VERBOSE:
                        traceback.print_exc()
                    break

            # Restore the selected weights so the t-SNE plot shows the model behind the metrics.
            if best_state is not None:
                model.load_state_dict(best_state)

            pr_auc = 0
            if best_scores is not None and len(np.unique(best_labels)) > 1:  # PR curve needs both classes
                try:
                    precision, recall, _ = precision_recall_curve(best_labels, best_scores)
                    pr_auc = auc(recall, precision)
                except Exception as e:
                    print(f" Error calculating PR AUC for {model_name} with LR {selected_lr}: {e}")
                    pr_auc = 0
                # Plotting is best-effort: a t-SNE failure must not zero out the PR AUC.
                try:
                    tsne_visualize_per_model(model, test_data, f"{model_name}_lr{selected_lr}")
                except Exception as e:
                    print(f" Error during t-SNE for {model_name} with LR {selected_lr}: {e}")
            else:
                print(f" Not enough classes for PR AUC calculation for {model_name} with LR {selected_lr}, setting PR AUC to 0.")
                pr_auc = 0

            # Store results in the CSV file
            with open(csv_file_path, 'a', newline='') as csvfile:
                csv_writer = csv.writer(csvfile)
                csv_writer.writerow([model_name, selected_lr, best_val_auc, final_test_auc, pr_auc])

            print(f" Results for {model_name} (LR={selected_lr}): Best Val AUC={best_val_auc:.4f}, Final Test AUC={final_test_auc:.4f}, PR AUC={pr_auc:.4f}")

        except Exception as e:
            print(f" Skipping training for {model_name} with LR {selected_lr} due to error: {e}")
            if VERBOSE:
                traceback.print_exc()
            continue

print("\n All models and learning rates processed. Results saved to:", csv_file_path)


In [None]:
# Test AUROC per (Model, Learning Rate) as a wide table, one column per learning rate.
df_results = pd.read_csv("GATNet_Results.csv")

df_results_dropped = df_results.drop(columns=["Best Validation AUC", "PR AUC"])

# The results CSV is appended on every sweep, so a (Model, Learning Rate) pair can appear
# more than once. pivot() raises on duplicates, so keep the most recent row for each pair.
df_results_dropped = df_results_dropped.drop_duplicates(subset=['Model', 'Learning Rate'], keep='last')

df_pivot = df_results_dropped.pivot(index='Model', columns='Learning Rate', values='Final Test AUC')

# Turn 'Model' back into a column and drop the leftover 'Learning Rate' axis name.
AUC = df_pivot.reset_index().rename_axis(None, axis=1)

AUC.to_csv("GATNet_AUC.csv", index=False)
AUC

In [None]:
# PR AUC per (Model, Learning Rate) as a wide table, one column per learning rate.
df_results = pd.read_csv("GATNet_Results.csv")

df_results_dropped = df_results.drop(columns=["Best Validation AUC", "Final Test AUC"])

# The results CSV is appended on every sweep. Keep the most recent row per pair; the
# default keep='first' would silently report the oldest run instead.
df_results_dropped = df_results_dropped.drop_duplicates(subset=['Model', 'Learning Rate'], keep='last')

df_pivot = df_results_dropped.pivot(index='Model', columns='Learning Rate', values='PR AUC')

# Turn 'Model' back into a column and drop the leftover 'Learning Rate' axis name.
PR = df_pivot.reset_index().rename_axis(None, axis=1)

PR.to_csv("GATNet_PR.csv", index=False)
PR

In [None]:
# AUROC at learning rate 1e-05, one bar per embedding model.
AUC = pd.read_csv("GATNet_AUC.csv")
# Drop the row labelled 2 if it exists. With only the two models in Embedding_models there
# is no row 2, and a plain drop([2]) raises KeyError.
# NOTE(review): which model sits at row 2 depends on the pivot's alphabetical ordering;
# filtering by Model name would be more robust.
AUC = AUC.drop(index=[2], errors='ignore')

fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x='Model', y='1e-05', data=AUC, palette="crest", ax=ax)  # '1e-05' = LR column
ax.set_title("Combined Graph - AUROC")
ax.set_ylabel("AUROC")
ax.set_ylim(0.7, 1)
ax.tick_params(axis='x', rotation=90)
fig.tight_layout()
plt.show()

In [None]:
# Reload the PR AUC table and drop the row labelled 2 if it exists. With only the two
# models in Embedding_models there is no row 2, and a plain drop([2]) raises KeyError.
# NOTE(review): which model sits at row 2 depends on the pivot's alphabetical ordering.
PR = pd.read_csv("GATNet_PR.csv")
PR = PR.drop(index=[2], errors='ignore')
PR

In [None]:
# AUPR at learning rate 1e-05, one bar per embedding model.
# The figure is created before plotting. Creating it afterwards (as before) left the bars
# on a default-sized figure and showed an extra empty 10x6 figure.
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x='Model', y='1e-05', data=PR, palette="crest", ax=ax)  # '1e-05' = LR column
ax.set_title('Combined Graph - AUPR')
ax.set(ylabel='AUPR')
ax.set_ylim(0.7, 1)
ax.tick_params(axis='x', rotation=90)
fig.tight_layout()
plt.show()

In [None]:
# PR AUC vs learning rate, one line per model.
# Long format: one row per (Model, Learning Rate) pair. The LR is a string, so it is
# plotted as a categorical axis in column order.
df_pr_melted = (
    PR.melt(id_vars='Model', var_name='Learning Rate', value_name='PR AUC')
    .astype({'Learning Rate': str})
)

fig, ax = plt.subplots(figsize=(12, 7))
sns.lineplot(data=df_pr_melted, x='Learning Rate', y='PR AUC', hue='Model', marker='o', ax=ax)

ax.set_title("PR AUC across Different Learning Rates for Each Model")
ax.set_xlabel("Learning Rate")
ax.set_ylabel("PR AUC")
ax.tick_params(axis='x', labelrotation=45)
ax.legend(title="Model")
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
plt.show()