# Generating Molecular Graphs from SMILES Representations

In [None]:
import numpy as np
import pandas as pd

from torch_geometric.data import Data
import torch

from rdkit import Chem
from rdkit.Chem import AllChem


In [None]:
# Load the merged GDSC table, discard rows lacking a SMILES string, and map
# each drug name to its SMILES (when a drug name repeats, the last row wins).
df = pd.read_csv("dataset/GDSC_SMILES_merged.csv").dropna(subset=["SMILES"])
drug_smiles = {name: smi for name, smi in zip(df["DRUG_NAME"], df["SMILES"])}

In [None]:
# Display the drug-name -> SMILES mapping built from the merged GDSC table.
drug_smiles

In [None]:
def smiles_to_mol(smiles):
    """Parse a SMILES string into an RDKit molecule and kekulize it.

    Returns:
        The kekulized ``Mol``, or ``None`` when RDKit cannot parse ``smiles``.

    ``Chem.Kekulize`` is called with its default arguments, which (per RDKit
    docs) leave the aromatic flags set, so ``atom_features``' aromaticity bit
    is still populated.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None
    Chem.Kekulize(molecule)
    return molecule

In [None]:
def build_atom_vocab(smiles_dict):
    """Return the sorted list of element symbols found across all SMILES.

    Args:
        smiles_dict: mapping whose values are SMILES strings (keys are ignored).

    SMILES that RDKit cannot parse are skipped.
    """
    symbols = set()
    for smiles in smiles_dict.values():
        molecule = Chem.MolFromSmiles(smiles)
        if not molecule:
            continue
        symbols.update(atom.GetSymbol() for atom in molecule.GetAtoms())
    return sorted(symbols)

In [None]:
# Sorted element symbols present in any parseable drug SMILES; this list is
# used as ATOM_LIST, the element vocabulary for one-hot node features.
atom_vocab = build_atom_vocab(drug_smiles)
atom_vocab

In [None]:
from rdkit import Chem

# Define vocabularies for the one-hot atom features built in atom_features().
# A value not listed here encodes as an all-zero block (one_hot_encoding has
# no "unknown" slot), e.g. an atom with degree 6.
# Earlier hard-coded element list, kept for reference; the vocabulary is now
# derived from the dataset so it covers exactly the elements present:
# ATOM_LIST = ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Br', 'I', 'H', 'B', 'Si', 'Na', 'K', 'Li', 'Mg', 'Ca', 'Fe', 'Zn', 'Se', 'Cu']
ATOM_LIST = atom_vocab
DEGREE_LIST = [0, 1, 2, 3, 4, 5]          # 6 slots
NUM_H_LIST = [0, 1, 2, 3, 4]              # 5 slots
VALENCE_LIST = [0, 1, 2, 3, 4, 5, 6]      # 7 slots

def one_hot_encoding(x, allowable_set):
    """Return a list of 0/1 ints marking which entries of ``allowable_set`` equal ``x``.

    If ``x`` is not in ``allowable_set`` the result is all zeros.

    >>> one_hot_encoding("N", ["C", "N", "O"])
    [0, 1, 0]
    """
    encoding = []
    for candidate in allowable_set:
        encoding.append(1 if x == candidate else 0)
    return encoding

def atom_features(atom):
    """Return the node-feature vector (list of 0/1 ints) for one RDKit atom.

    The vector concatenates one-hot encodings of the element symbol
    (ATOM_LIST), degree (DEGREE_LIST), total hydrogen count (NUM_H_LIST) and
    implicit valence (VALENCE_LIST), followed by a single aromaticity bit.
    Its length is ``len(ATOM_LIST) + 19``. Values missing from a vocabulary
    encode as all zeros.

    NOTE(review): newer RDKit releases deprecate ``GetImplicitValence``;
    confirm against the installed version.
    """
    encoded_fields = (
        (atom.GetSymbol(), ATOM_LIST),
        (atom.GetDegree(), DEGREE_LIST),
        (atom.GetTotalNumHs(), NUM_H_LIST),
        (atom.GetImplicitValence(), VALENCE_LIST),
    )
    features = []
    for value, vocabulary in encoded_fields:
        features.extend(one_hot_encoding(value, vocabulary))
    features.append(int(atom.GetIsAromatic()))
    return features

def mol_to_graph_data_obj(mol):
    """Convert an RDKit molecule into a PyG ``Data`` graph.

    Returns:
        ``Data`` with
        * ``x``: float tensor, one row of ``atom_features`` per atom;
        * ``edge_index``: long tensor of shape ``(2, 2 * num_bonds)``. Each bond
          is stored in both directions so the graph is treated as undirected.

    Molecules without bonds (e.g. a single ion) get an empty ``(2, 0)``
    ``edge_index``. Building it from an empty list would give shape ``(0,)``,
    which PyG cannot batch or propagate over.
    """
    x = torch.tensor([atom_features(atom) for atom in mol.GetAtoms()], dtype=torch.float)

    directed_edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        directed_edges.append((i, j))
        directed_edges.append((j, i))  # undirected graph

    if directed_edges:
        edge_index = torch.tensor(directed_edges, dtype=torch.long).t().contiguous()
    else:
        # .t() on an empty 1-D tensor would leave shape (0,); PyG needs (2, 0).
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)


In [None]:
from torch_geometric.data import InMemoryDataset

class DrugGraphDataset(InMemoryDataset):
    """In-memory PyG dataset holding one molecular graph per drug.

    Every entry of ``smiles_dict`` whose SMILES RDKit can parse becomes a graph
    (via ``smiles_to_mol`` and ``mol_to_graph_data_obj``) tagged with a
    ``drug_name`` attribute. Unparseable SMILES are skipped silently, so dataset
    positions do not necessarily match positions in ``smiles_dict``.

    Args:
        smiles_dict: mapping of drug name -> SMILES string.
    """

    def __init__(self, smiles_dict):
        self.smiles_dict = smiles_dict
        super().__init__('.', transform=None, pre_transform=None)
        self.data, self.slices = self._process()

    def _process(self):
        """Build all drug graphs and return the ``(data, slices)`` pair from ``collate``.

        NOTE(review): PyG's base ``Dataset`` has an internal ``_process`` hook
        of its own, and this override shadows it. Confirm that the installed PyG
        version does not call it from ``__init__``. It appears to be invoked
        only when a ``process`` method is defined.
        """
        graphs = []
        for drug_name, smiles in self.smiles_dict.items():
            molecule = smiles_to_mol(smiles)
            if not molecule:
                continue
            graph = mol_to_graph_data_obj(molecule)
            graph.drug_name = drug_name
            graphs.append(graph)
        return self.collate(graphs)


In [None]:
# Build one PyG graph per drug with a parseable SMILES and print the first one.
dataset = DrugGraphDataset(drug_smiles)
print(dataset[0])

In [None]:
# Parse a single drug's SMILES so the notebook can display the molecule.
mol2 = smiles_to_mol(drug_smiles["Camptothecin"])
mol2

In [None]:
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx

# Draw the connectivity of the first drug graph (atoms as nodes, bonds as edges).
# Each bond is stored in both directions in edge_index; to_undirected=True
# collapses each pair into a single networkx edge.
G = to_networkx(dataset[0], to_undirected=True)
plt.figure(figsize=(6,6))
nx.draw(G, node_size=50)
plt.show()


In [None]:
# x: (num_atoms, num_atom_features); edge_index: (2, num_directed_edges).
print(dataset[0].x.shape)
print(dataset[0].edge_index.shape)

In [None]:
# Per-atom feature length, and number of directed edges in the first graph.
print(len(dataset[0].x[0]))
print(len(dataset[0].edge_index[0]))

In [None]:
# Spot-check another drug graph.
print(dataset[100])

# Preparing the Dataset

In [None]:
# Load merged GDSC dataset with drug, cell line, IC50, and SMILES.
# This time the first CSV column is used as the index; rows without SMILES are dropped.
df = pd.read_csv("dataset/GDSC_SMILES_merged.csv", index_col=0)
df = df.dropna(subset=["SMILES"])

# Load the GSVA score matrix. Columns are cell lines: later cells clean the
# column labels and look profiles up with gsva_df[cell_line], so each column
# is one cell line's feature vector (rows presumably = pathways / gene sets).
gsva_df = pd.read_csv("dataset/ccle_gsva_scores.csv", index_col=0)


In [None]:
df

In [None]:
# Raw GSVA column labels (cell lines); this is recomputed from cleaned labels below.
cell_lines_available = set(gsva_df.columns)

In [None]:
# Squash LN_IC50 into (0, 1) with a logistic curve:
#     IC50_NORMALIZED = sigmoid(0.1 * LN_IC50) = 1 / (1 + exp(-0.1 * LN_IC50))
# This is algebraically identical to 1 / (1 + 1 / exp(LN_IC50) ** 0.1). It is
# written in this form because it never materializes exp(LN_IC50), which
# overflows for |LN_IC50| > ~709 and raises RuntimeWarnings.
df["LN_IC50"] = df["LN_IC50"].astype(float)
df["IC50_NORMALIZED"] = 1 / (1 + np.exp(-0.1 * df["LN_IC50"]))

In [None]:
df.info()

In [None]:
df.describe()

In [None]:
df

In [None]:
# Distinct GDSC cell-line names before any normalization.
cell_lines_obs = set(df["CELL_LINE_NAME"].unique())

In [None]:
# Compare the naming conventions of the two sources before matching them.
# NOTE: the second label says "index", but it prints GSVA *column* labels.
print("GDSC df cell lines:", df["CELL_LINE_NAME"].unique()[:5])
print("GSVA df index:", gsva_df.columns[:5])

In [None]:
# Normalize GDSC: trim whitespace and upper-case the cell-line names.
df["CELL_LINE_NAME"] = df["CELL_LINE_NAME"].str.strip().str.upper()
cell_lines_obs = set(df["CELL_LINE_NAME"].unique())

# Normalize GSVA column labels the same way, then keep only the text before the
# first "_" (labels presumably look like NAME_TISSUE — confirm against the file).
# The second .str.upper() is redundant after the line above but harmless.
gsva_df.columns = gsva_df.columns.str.strip().str.upper()
cell_lines_available = sorted(set(gsva_df.columns.str.split("_").str[0].str.upper()))

In [None]:
# Cell lines present in both sources after whitespace/case normalization.
common_cell_lines = cell_lines_obs.intersection(cell_lines_available)
print("Now common cell lines:", len(common_cell_lines))

In [None]:
len(cell_lines_obs)

In [None]:
cell_lines_obs

In [None]:
# cell_lines_available is a sorted list at this point, so slicing works.
cell_lines_available[:10]

In [None]:
import re

def clean_name(name):
    """Return ``name`` with every non-ASCII-alphanumeric character removed, upper-cased.

    This gives GDSC and CCLE cell-line names a common key format.

    >>> clean_name("NCI-H460")
    'NCIH460'
    """
    kept_chars = re.findall(r'[A-Za-z0-9]', name)
    return ''.join(kept_chars).upper()

In [None]:
# Strip punctuation/whitespace (e.g. "NCI-H460" -> "NCIH460") so both sources
# share one key format before intersecting.
cell_lines_available = {clean_name(name) for name in cell_lines_available}
cell_lines_obs = {clean_name(name) for name in cell_lines_obs}

In [None]:
common_cell_lines = cell_lines_obs.intersection(cell_lines_available)
print("Common cell lines:", len(common_cell_lines))

In [None]:
# Add the cleaned cell-line key and keep only rows whose cell line also has a
# GSVA profile. CELL_LINE_NAME_CLEAN is the key later used to index gsva_df.
df["CELL_LINE_NAME_CLEAN"] = df["CELL_LINE_NAME"].apply(clean_name)
df = df[df["CELL_LINE_NAME_CLEAN"].isin(common_cell_lines)]

In [None]:
df

In [None]:
# Number of distinct cleaned GSVA cell-line keys.
len(cell_lines_available)

In [None]:
gsva_df

In [None]:
# Relabel GSVA columns with the same key format as CELL_LINE_NAME_CLEAN (prefix
# before the first "_", upper-cased, non-alphanumerics removed), so that
# gsva_df[cell_line] lookups match. Different labels can collapse to the same
# key, hence the duplicate check.
gsva_df.columns = gsva_df.columns.str.split("_").str[0].str.upper().to_series().apply(clean_name)
gsva_df.columns.has_duplicates

In [None]:
# Keep only the first column for each duplicated cell-line key, so that
# gsva_df[cell_line] returns a single vector rather than a DataFrame.
gsva_df = gsva_df.loc[:, ~gsva_df.columns.duplicated()]
print(gsva_df.columns.has_duplicates)
gsva_df

# Building the PyTorch Dataset

In [None]:
from torch.utils.data import Dataset
from torch_geometric.data import Data

class DrugResponseDataset(Dataset):
    """Samples of (drug graph, cell-line GSVA vector, normalized IC50).

    Each row of ``dataframe`` is one drug/cell-line measurement:
    * the drug's SMILES (``drug_smiles[DRUG_NAME]``) becomes a PyG graph via
      ``smiles_to_mol`` / ``mol_to_graph_data_obj``;
    * the cell-line features are the ``gsva_df`` column named by
      ``CELL_LINE_NAME_CLEAN``;
    * the target is ``IC50_NORMALIZED``, returned as a 1-element float tensor.

    The same drug appears in many rows, so each drug's graph is built once and
    cached. Every sample gets a clone, so the per-sample ``idx`` attribute never
    leaks into the cache.

    Args:
        dataframe: rows with ``DRUG_NAME``, ``CELL_LINE_NAME_CLEAN`` and
            ``IC50_NORMALIZED`` columns. Its index is reset.
        gsva_df: GSVA scores with one column per cleaned cell-line name.
        drug_smiles: mapping drug name -> SMILES string.
        atom_vocab: stored only. NOTE(review): node features come from the
            module-level ``ATOM_LIST`` used by ``atom_features``, not from this
            argument.

    Raises (from ``__getitem__``):
        ValueError: unknown drug, invalid SMILES, or cell line missing from GSVA.
    """

    def __init__(self, dataframe, gsva_df, drug_smiles, atom_vocab):
        self.df = dataframe.reset_index(drop=True)
        self.gsva_df = gsva_df
        self.drug_smiles = drug_smiles
        self.atom_vocab = atom_vocab
        # drug name -> template graph, filled the first time each drug is requested.
        self._graph_cache = {}

    def __len__(self):
        return len(self.df)

    def _drug_graph(self, drug):
        """Return the cached template graph for ``drug``, building it on first use."""
        graph = self._graph_cache.get(drug)
        if graph is None:
            if drug not in self.drug_smiles:
                raise ValueError(f"Drug '{drug}' not found in SMILES dictionary")

            smi = self.drug_smiles[drug]
            mol = smiles_to_mol(smi)

            if mol is None:
                raise ValueError(f"Invalid SMILES for {drug}: {smi}")

            graph = mol_to_graph_data_obj(mol)
            self._graph_cache[drug] = graph
        return graph

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        drug = row["DRUG_NAME"]
        cell_line = row["CELL_LINE_NAME_CLEAN"]
        y = row["IC50_NORMALIZED"]

        # Clone so that setting the per-sample attribute does not mutate the cached template.
        graph = self._drug_graph(drug).clone()
        graph.idx = idx

        if cell_line not in self.gsva_df.columns:
            raise ValueError(f"Cell line '{cell_line}' not in GSVA")

        gsva_vec = torch.tensor(self.gsva_df[cell_line].values, dtype=torch.float)
        return graph, gsva_vec, torch.tensor([y], dtype=torch.float)


In [None]:
from torch_geometric.loader import DataLoader as PyGLoader
from torch.utils.data import DataLoader as TorchLoader
from torch_geometric.data import Batch

# Custom collate: merge graphs into one PyG Batch and stack the dense tensors.
def collate_fn(batch):
    """Collate a list of ``(graph, gsva_vec, target)`` samples.

    Returns:
        A tuple ``(Batch, gsva_matrix, targets)``. The graphs are merged with
        ``Batch.from_data_list``. The GSVA vectors and the 1-element target
        tensors are stacked along a new leading batch dimension.
    """
    graph_list, gsva_list, target_list = zip(*batch)
    merged_graphs = Batch.from_data_list(graph_list)
    gsva_matrix = torch.stack(gsva_list)
    targets = torch.stack(target_list)
    return merged_graphs, gsva_matrix, targets

In [None]:
from sklearn.model_selection import train_test_split

# Split by cell line, not by row, so no cell line appears in more than one split.
# 70% of cell lines go to train. The remaining 30% is split 1/3 validation and
# 2/3 test, i.e. about 10% / 20% of all cell lines. The fixed random_state
# makes the split reproducible.
cell_lines = df["CELL_LINE_NAME_CLEAN"].unique()
train_cl, temp_cl = train_test_split(cell_lines, test_size=0.3, random_state=42)
val_cl, test_cl = train_test_split(temp_cl, test_size=2/3, random_state=42)

In [None]:
# Rows for each cell-line split.
train_df = df[df["CELL_LINE_NAME_CLEAN"].isin(train_cl)]
val_df   = df[df["CELL_LINE_NAME_CLEAN"].isin(val_cl)]
test_df  = df[df["CELL_LINE_NAME_CLEAN"].isin(test_cl)]

In [None]:
# Samples are built lazily in DrugResponseDataset.__getitem__.
train_dataset = DrugResponseDataset(train_df, gsva_df, drug_smiles, atom_vocab)
val_dataset   = DrugResponseDataset(val_df, gsva_df, drug_smiles, atom_vocab)
test_dataset  = DrugResponseDataset(test_df, gsva_df, drug_smiles, atom_vocab)

In [None]:
# Plain torch DataLoaders with collate_fn, which batches the PyG graphs and the
# dense tensors separately. Only the training loader shuffles.
train_loader = TorchLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader   = TorchLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader  = TorchLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:
# train_dataset = DrugResponseDataset(df, gsva_df, drug_smiles, atom_vocab)
# train_loader = TorchLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Building the Model Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_max_pool
from torch_geometric.data import Data

from tqdm.notebook import tqdm

from sklearn.metrics import root_mean_squared_error
from scipy.stats import pearsonr

from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

In [None]:
class DrugGINEncoder(nn.Module):
    """Encode a batch of molecular graphs into fixed-size drug embeddings.

    Architecture: two GINConv layers, each wrapping a Linear-ReLU-Linear MLP
    and followed by a ReLU; global max pooling over each graph's nodes; then a
    final linear projection to ``out_dim``.

    Args:
        input_dim: width of the node-feature matrix ``x``.
        hidden_dim: width of both GIN layers.
        out_dim: size of the returned per-graph embedding.
    """

    def __init__(self, input_dim, hidden_dim=64, out_dim=128):
        super().__init__()
        # Attribute names and creation order are kept as-is: they fix the
        # state_dict keys used by saved checkpoints and the order in which
        # parameters are randomly initialized.
        self.mlp1 = self._make_mlp(input_dim, hidden_dim)
        self.gin1 = GINConv(self.mlp1)
        self.mlp2 = self._make_mlp(hidden_dim, hidden_dim)
        self.gin2 = GINConv(self.mlp2)
        self.fc = nn.Linear(hidden_dim, out_dim)

    @staticmethod
    def _make_mlp(in_features, width):
        """Build the Linear-ReLU-Linear MLP used inside one GIN layer."""
        return nn.Sequential(
            nn.Linear(in_features, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, x, edge_index, batch):
        """Return one ``out_dim`` embedding per graph.

        Args:
            x: node features for every atom in the batch.
            edge_index: PyG connectivity tensor.
            batch: node-to-graph assignment vector used by the pooling step.
        """
        for conv in (self.gin1, self.gin2):
            x = F.relu(conv(x, edge_index))
        pooled = global_max_pool(x, batch)
        return self.fc(pooled)

In [None]:
class CellLineEncoder(nn.Module):
    """Map a cell line's GSVA score vector to a 128-dim embedding.

    MLP: ``input_dim -> 512 -> 1024 -> 128``, with ReLU after the first two
    layers and dropout (p=0.2) before the output layer.

    Args:
        input_dim: length of the GSVA vector. The default of 658 presumably
            matches the number of gene sets in the GSVA matrix; confirm.
    """

    def __init__(self, input_dim=658): # gsva_vector.shape = (658,)
        super().__init__()
        # Registration order (fc1, fc2, dropout, fc3) fixes state_dict keys and
        # the parameter-initialization order.
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.dropout = nn.Dropout(0.2)
        self.fc3 = nn.Linear(1024, 128)

    def forward(self, x):
        """Return the 128-dim embedding(s) for GSVA input ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)


In [None]:
# Inspect a drug graph: x.shape[1] is the node-feature width the model's drug_in_dim must match.
dataset[0]

In [None]:
class GPDRPModel(nn.Module):
    """Predict normalized IC50 from a drug graph and a cell-line GSVA profile.

    The drug graph is encoded by ``DrugGINEncoder`` (128-dim with its default
    ``out_dim``) and the GSVA vector by ``CellLineEncoder`` (128-dim). The two
    embeddings are concatenated into 256 features and passed through an MLP
    head that ends in one output unit.

    Args:
        drug_in_dim: node-feature width of the drug graphs.
        gsva_dim: length of the GSVA vector.
    """

    def __init__(self, drug_in_dim=30, gsva_dim=658):
        super().__init__()
        self.drug_encoder = DrugGINEncoder(drug_in_dim)
        self.cell_encoder = CellLineEncoder(gsva_dim)
        # 256 = 128 (drug embedding) + 128 (cell-line embedding).
        head_layers = [
            nn.Linear(256, 1024),
            nn.ReLU(),
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # final IC50 prediction
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, data, gsva):
        """Return predictions of shape ``(num_graphs, 1)``.

        Args:
            data: PyG batch exposing ``x``, ``edge_index`` and ``batch``.
            gsva: GSVA tensor, presumably ``(num_graphs, gsva_dim)`` as stacked
                by ``collate_fn``.
        """
        fused = torch.cat(
            [
                self.drug_encoder(data.x, data.edge_index, data.batch),
                self.cell_encoder(gsva),
            ],
            dim=1,
        )
        return self.fc(fused)

In [None]:
# Use CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Best validation RMSE seen so far, and where the corresponding weights are saved.
best_rmse = float("inf")
best_model_path = "models/gpdrp_best_model.pt"

# Training

In [None]:
import os

max_epochs = 50

# Size the model from the data instead of hard-coding 30 / 658:
# - node-feature width is read from a real drug graph produced by the dataset;
# - the GSVA vector is gsva_df[cell_line].values, so its length is the number
#   of rows in gsva_df.
drug_in_dim = train_dataset[0][0].num_node_features
gsva_dim = gsva_df.shape[0]

model = GPDRPModel(drug_in_dim=drug_in_dim, gsva_dim=gsva_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=1e-6)
loss_fn = nn.MSELoss()
train_rmses = []
val_rmses = []

# The model above is freshly initialized, so the best score starts over too.
# Otherwise a re-run of this cell would be compared against a previous run's
# best RMSE and might never save a checkpoint.
best_rmse = float("inf")

# torch.save does not create missing parent directories (e.g. "models/").
os.makedirs(os.path.dirname(best_model_path) or ".", exist_ok=True)

for epoch in range(1, max_epochs + 1):
    print(f"\n🌟 Epoch {epoch}")

    # Training phase
    model.train()
    total_loss = 0
    preds, truths = [], []

    train_bar = tqdm(train_loader, desc="Training", leave=False)
    for i, (graph_batch, gsva_batch, ic50_batch) in enumerate(train_bar):
        graph_batch = graph_batch.to(device)
        gsva_batch = gsva_batch.to(device)
        ic50_batch = ic50_batch.to(device)

        pred = model(graph_batch, gsva_batch)
        loss = loss_fn(pred, ic50_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Batch-mean MSE times batch size = summed squared error, so avg_loss is a per-sample mean.
        total_loss += loss.item() * len(ic50_batch)

        preds.extend(pred.detach().cpu().numpy().flatten())
        truths.extend(ic50_batch.detach().cpu().numpy().flatten())

        # Every 50 batches, show running metrics over all training samples seen this epoch.
        if (i + 1) % 50 == 0:
            rmse = root_mean_squared_error(truths, preds)
            pcc = pearsonr(truths, preds)[0]
            train_bar.set_postfix({
                "Batch": i + 1,
                "Loss": loss.item(),
                "RMSE": f"{rmse:.4f}",
                "PCC": f"{pcc:.4f}"
            })

    # Train metrics use predictions made while the weights were still changing during the epoch.
    avg_loss = total_loss / len(train_loader.dataset)
    train_rmse = root_mean_squared_error(truths, preds)
    train_pcc = pearsonr(truths, preds)[0]
    print(f"✅ Epoch {epoch}: MSE = {avg_loss:.4f}, RMSE = {train_rmse:.4f}, PCC = {train_pcc:.4f}")

    # 🔍 Validation Phase
    model.eval()
    val_preds, val_truths = [], []

    val_bar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for graph_batch, gsva_batch, ic50_batch in val_bar:
            graph_batch = graph_batch.to(device)
            gsva_batch = gsva_batch.to(device)
            ic50_batch = ic50_batch.to(device)

            pred = model(graph_batch, gsva_batch)
            val_preds.extend(pred.cpu().numpy().flatten())
            val_truths.extend(ic50_batch.cpu().numpy().flatten())

    val_rmse = root_mean_squared_error(val_truths, val_preds)
    val_pcc = pearsonr(val_truths, val_preds)[0]

    # The LR read after scheduler.step() is the rate for the *next* epoch.
    scheduler.step()
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"📉 Val RMSE = {val_rmse:.4f} | LR = {current_lr:.1e}")

    # Checkpoint whenever validation RMSE improves.
    if val_rmse < best_rmse:
        best_rmse = val_rmse
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Saved new best model (RMSE: {val_rmse:.4f})")

    train_rmses.append(train_rmse)
    val_rmses.append(val_rmse)


In [None]:
import matplotlib.pyplot as plt

# Per-epoch RMSE curves collected by the training loop (x-axis is the 0-based epoch index).
plt.plot(train_rmses, label="Train RMSE")
plt.plot(val_rmses, label="Validation RMSE")
plt.xlabel("Epoch")
plt.ylabel("RMSE")
plt.title("Training vs Validation RMSE")
plt.legend()
plt.grid(True)
plt.show()

# Testing and Evaluation

In [None]:
# Rebuild the architecture with input sizes taken from the data, then restore
# the best validation checkpoint.
best_model = GPDRPModel(
    drug_in_dim=train_dataset[0][0].num_node_features,  # atom-feature width
    gsva_dim=gsva_df.shape[0],  # length of gsva_df[cell_line].values
).to(device)
# map_location: weights saved on a GPU still load on a CPU-only machine.
# weights_only=True: the checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
best_model.load_state_dict(
    torch.load(best_model_path, map_location=device, weights_only=True)
)
best_model.eval()

In [None]:
# Collect predictions and targets over the whole test split without tracking gradients.
test_preds, test_truths = [], []

with torch.no_grad():
    for graph_batch, gsva_batch, ic50_batch in tqdm(test_loader, desc="Testing"):
        graph_batch = graph_batch.to(device)
        gsva_batch = gsva_batch.to(device)
        ic50_batch = ic50_batch.to(device)

        pred = best_model(graph_batch, gsva_batch)
        # Flatten the (batch, 1) outputs/targets into flat lists of floats.
        test_preds.extend(pred.cpu().numpy().flatten())
        test_truths.extend(ic50_batch.cpu().numpy().flatten())


In [None]:
# Score the best checkpoint on the held-out test cell lines (normalized IC50 scale).
test_rmse = root_mean_squared_error(test_truths, test_preds)
test_pcc = pearsonr(test_truths, test_preds)[0]

print(f"🧪 Test RMSE: {test_rmse:.4f}")
print(f"🧪 Test PCC:  {test_pcc:.4f}")


In [None]:
import matplotlib.pyplot as plt

# Predicted vs. true values on the test split. NOTE: both axes are
# IC50_NORMALIZED (a logistic squashing of LN_IC50 into (0, 1)), not raw IC50,
# which is why the reference diagonal spans [0, 1].
plt.figure(figsize=(6,6))
plt.scatter(test_truths, test_preds, alpha=0.3)
plt.xlabel("True IC50")
plt.ylabel("Predicted IC50")
plt.title(f"Predicted vs True IC50 (PCC={test_pcc:.2f}, RMSE={test_rmse:.2f})")
plt.plot([0, 1], [0, 1], '--', color='gray')  # line y=x
plt.grid(True)
plt.tight_layout()
plt.show()