In [10]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy.stats   import pearsonr, spearmanr

# ---- Optimisation settings ----
batch_size = 8
learning_rate = 3e-5
epochs = 5
dropout = 0.2

# ---- Embedding widths ----
n_embd_1 = 768                   # in_features of Head.query (applied to `mol`)
n_embd_2 = 960                   # in_features of Head.key / Head.value (applied to `prot`)
emb_dim = n_embd_1 + n_embd_2    # width handed to the MLP by Model_2

# ---- Attention layout / output size ----
n_head = 8
head_size = n_embd_1 // n_head
n_layer = 4                      # not referenced by the code visible in this cell
num_classes = 1

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def evaluate(y_true, y_pred):
    """Score predictions against ground truth with four regression metrics.

    Args:
        y_true: Sequence of reference values.
        y_pred: Sequence of predicted values, aligned with ``y_true``.

    Returns:
        Tuple ``(mae, rmse, pearson_corr, spearman_corr)``. The p-values
        reported by the two correlation tests are discarded.
    """
    # Error-magnitude metrics.
    abs_err = mean_absolute_error(y_true, y_pred)
    root_sq_err = np.sqrt(mean_squared_error(y_true, y_pred))

    # Correlation metrics: element 0 of each result is the coefficient,
    # element 1 is the p-value, which is not used here.
    linear_corr = pearsonr(y_true, y_pred)[0]
    rank_corr = spearmanr(y_true, y_pred)[0]

    return abs_err, root_sq_err, linear_corr, rank_corr


class MolProtDataset(Dataset):
    """Dataset of (molecule embedding, protein embedding, Ki target) triples.

    Each row of ``dataframe`` must provide the columns ``MoleculeIdx``,
    ``ProteinIdx`` and ``Ki``. The two index columns are used to look up
    entries of ``mol_emb`` and ``prot_emb`` respectively.
    """

    def __init__(self, dataframe, mol_emb, prot_emb):
        # Keep references only; nothing is copied or converted up front.
        self.df = dataframe
        self.mol_emb = mol_emb
        self.prot_emb = prot_emb

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(mol, prot, target)`` float32 tensors for positional row ``idx``."""
        record = self.df.iloc[idx]
        mol_key, prot_key = int(record['MoleculeIdx']), int(record['ProteinIdx'])
        ki_value = float(record['Ki'])

        as_f32 = dict(dtype=torch.float32)
        return (
            torch.tensor(self.mol_emb[mol_key], **as_f32),
            torch.tensor(self.prot_emb[prot_key], **as_f32),
            torch.tensor(ki_value, **as_f32),
        )


def pad_collate(batch):
    """Collate variable-length samples into zero-padded batch tensors.

    Args:
        batch: Iterable of ``(mol, prot, target)`` tuples where ``mol`` and
            ``prot`` are 2-D tensors whose first dimension may differ
            between samples.

    Returns:
        ``(mols_batch, prots_batch, targets_batch)``. Molecules and proteins
        are each padded along dim 0 with zero rows up to the longest item in
        the batch and stacked; targets are stacked as they are.

    NOTE(review): no lengths or padding mask are returned, so consumers
    cannot tell padded rows from real ones.
    """
    mols, prots, targets = zip(*batch)

    def _stack_padded(seqs):
        # Append zero rows to every sequence until all reach the longest length.
        longest = max(seq.shape[0] for seq in seqs)
        padded = []
        for seq in seqs:
            filler = seq.new_zeros(longest - seq.shape[0], seq.shape[1])
            padded.append(torch.cat((seq, filler), dim=0))
        return torch.stack(padded)

    return _stack_padded(mols), _stack_padded(prots), torch.stack(targets)

class Head(nn.Module):
    """One cross-attention head: queries come from ``mol``, keys/values from ``prot``.

    ``mol`` must have last dimension ``n_embd_1`` and ``prot`` last dimension
    ``n_embd_2``; the result has last dimension ``head_size``.
    """

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections into the shared head space.
        self.key = nn.Linear(n_embd_2, head_size, bias=False)
        self.query = nn.Linear(n_embd_1, head_size, bias=False)
        self.value = nn.Linear(n_embd_2, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mol, prot):
        """Attend from every ``mol`` position over all ``prot`` positions."""
        queries, keys, values = self.query(mol), self.key(prot), self.value(prot)

        # Scaled dot-product scores, normalised over the prot axis.
        scale = keys.size(-1) ** 0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        attn = self.dropout(scores.softmax(dim=-1))

        return torch.matmul(attn, values)


class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run side by side, then projected to ``n_embd_1``."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_modules = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        # Maps the concatenated head outputs back to the width of `mol`.
        self.proj = nn.Linear(num_heads * head_size, n_embd_1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mol, prot):
        """Concatenate every head's output on the last axis, project, apply dropout."""
        per_head = [head(mol, prot) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class MLP(nn.Module):
    """Two residual fully-connected stages followed by a linear output layer.

    Each stage applies Linear -> Dropout -> GELU and adds the stage input
    back onto the result. The output layer maps ``emb_dim`` features to
    ``num_classes`` values.
    """

    def __init__(self, emb_dim, num_classes, dropout=0.2):
        super().__init__()
        self.desc_skip_connection = True  # stored only; not read inside this class

        # Stage 1
        self.fc1 = nn.Linear(emb_dim, emb_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()  # attribute is named "relu" but holds a GELU

        # Stage 2
        self.fc2 = nn.Linear(emb_dim, emb_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()

        # Output layer
        self.final = nn.Linear(emb_dim, num_classes)

    def forward(self, inter_emb):
        """Return the ``(..., num_classes)`` output for ``(..., emb_dim)`` input."""
        stage1 = self.relu1(self.dropout1(self.fc1(inter_emb)))
        hidden = stage1 + inter_emb          # first skip connection

        stage2 = self.relu2(self.dropout2(self.fc2(hidden)))
        return self.final(stage2 + hidden)   # second skip connection feeds the output layer

class Model_2(nn.Module):
    """Cross-attention regressor over a molecule/protein embedding pair.

    Pipeline: ``mol`` attends over ``prot`` (with a residual connection and
    LayerNorm), both sequences are mean-pooled over dim 1, the pooled vectors
    are concatenated and passed to the MLP, which emits one value per sample.
    """

    # Added under the square root of the RMSE loss. d/dx sqrt(x) is infinite
    # at x == 0, so an exactly-zero MSE would otherwise yield NaN gradients.
    loss_eps = 1e-8

    def __init__(self):
        super().__init__()
        self.ca = MultiHeadAttention(n_head, head_size)
        self.mlp = MLP(emb_dim, num_classes)
        self.ln_f = nn.LayerNorm(n_embd_1)
        # NOTE(review): the second positional argument of nn.LayerNorm is
        # `eps`, so this builds a LayerNorm with eps=1 rather than a
        # regression layer. It is never called in forward(); it is kept only
        # so the module's state_dict keys stay unchanged.
        self.reg_head = nn.LayerNorm(n_embd_1, 1)

    def forward(self, mol, prot, target=None):
        """Predict one value per sample and optionally compute the RMSE loss.

        Args:
            mol: Tensor whose last dimension is ``n_embd_1``; dim 1 is pooled.
            prot: Tensor whose last dimension is ``n_embd_2``; dim 1 is pooled.
            target: Optional tensor compared against the prediction.

        Returns:
            ``(loss, prediction)``. ``loss`` is ``None`` when ``target`` is
            ``None``, otherwise ``sqrt(MSE + loss_eps)``.

        NOTE(review): pooling is a plain mean over dim 1, so any padded rows
        in ``mol`` / ``prot`` contribute to the average.
        """
        inter = self.ca(mol, prot) + mol     # residual connection
        inter = self.ln_f(inter)

        mol_pooled = inter.mean(dim=1)
        prot_pooled = prot.mean(dim=1)

        x = torch.cat([mol_pooled, prot_pooled], dim=1)
        prediction = self.mlp(x).squeeze(-1)

        loss = None
        if target is not None:
            mse = nn.functional.mse_loss(prediction, target)
            loss = torch.sqrt(mse + self.loss_eps)
        return loss, prediction

# ==================== Load embeddings ====================
# NOTE(security): allow_pickle=True makes np.load unpickle object arrays, which
# can execute arbitrary code. Only point these paths at trusted files.
molecule_embeddings = np.load('/content/molecules_all_embeddings.npy',allow_pickle=True)
protein_embeddings = np.load('/content/proteins_all_embeddings.npy',allow_pickle=True)

# Single source of truth for the number of cross-validation folds; it drives
# both the CSV file names (train_1..train_N) and the number of models.
n_folds = 5


def _collect_predictions(model, loader):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns two parallel lists ``(targets, predictions)`` of per-sample
    values. The model is called without a target, so no loss is computed.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for mol, prot, target in loader:
            mol, prot = mol.to(device), prot.to(device)
            _, preds = model(mol, prot)
            y_true.extend(target.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    return y_true, y_pred


# ==================== Build per-fold data loaders ====================
fold_loaders = []
for fold in range(1, n_folds + 1):
    train_df = pd.read_csv(f'/content/train_{fold}.csv')
    test_df  = pd.read_csv(f'/content/test_{fold}.csv')

    train_ds = MolProtDataset(train_df, molecule_embeddings, protein_embeddings)
    test_ds  = MolProtDataset(test_df,  molecule_embeddings, protein_embeddings)

    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=True, collate_fn=pad_collate)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size,
                              shuffle=False, collate_fn=pad_collate)

    fold_loaders.append((train_loader, test_loader))

# Instantiate one model & optimizer per fold
fold_models     = [Model_2().to(device) for _ in range(len(fold_loaders))]
fold_optimizers = [torch.optim.AdamW(m.parameters(), lr=learning_rate)
                   for m in fold_models]

# Epochs are the outer loop: every fold's model advances by one epoch before
# the fold-averaged metrics for that epoch are printed.
for epoch in range(1, epochs+1):
    # per-epoch containers, one entry per fold
    fold_train_losses, fold_val_losses = [], []
    fold_train_maes,   fold_val_maes   = [], []
    fold_train_ps,     fold_val_ps     = [], []
    fold_train_ss,     fold_val_ss     = [], []

    # Inner fold loop
    for (train_loader, test_loader), model, opt in zip(fold_loaders,
                                                      fold_models,
                                                      fold_optimizers):
        # ---- TRAIN ----
        model.train()
        # Size-weighted mean of the per-batch training loss. It is not part
        # of the printed report, which uses the metrics from evaluate().
        train_loss = 0.0
        for mol, prot, target in train_loader:
            mol, prot, target = mol.to(device), prot.to(device), target.to(device)
            opt.zero_grad()
            loss, _ = model(mol, prot, target)
            loss.backward()
            opt.step()
            train_loss += loss.item() * mol.size(0)
        train_loss /= len(train_loader.dataset)

        # ---- TRAIN METRICS (second pass over the training data, eval mode) ----
        y_t_train, y_p_train = _collect_predictions(model, train_loader)
        t_mae, t_rmse, t_p, t_s = evaluate(y_t_train, y_p_train)

        # ---- VALID (the fold's test_{fold}.csv split) ----
        y_t_val, y_p_val = _collect_predictions(model, test_loader)
        v_mae, v_rmse, v_p, v_s = evaluate(y_t_val, y_p_val)

        # ---- COLLECT for this fold ----
        fold_train_losses.append(t_rmse)
        fold_train_maes.append(t_mae)
        fold_train_ps.append(t_p)
        fold_train_ss.append(t_s)

        fold_val_losses.append(v_rmse)
        fold_val_maes.append(v_mae)
        fold_val_ps.append(v_p)
        fold_val_ss.append(v_s)

    # ---- AVERAGE OVER FOLDS (mean, with std in parentheses) ----
    print(f"\nEpoch {epoch}/{epochs}")
    print(f" TRAIN  — RMSE: {np.mean(fold_train_losses):.4f} ( {np.std(fold_train_losses):.4f}), "
          f"MAE: {np.mean(fold_train_maes):.4f} ( {np.std(fold_train_maes):.4f})")
    print(f"          Pearson: {np.mean(fold_train_ps):.4f} ( {np.std(fold_train_ps):.4f}), "
          f"Spearman: {np.mean(fold_train_ss):.4f} ( {np.std(fold_train_ss):.4f})")
    print(f" VALID  — RMSE: {np.mean(fold_val_losses):.4f} ( {np.std(fold_val_losses):.4f}), "
          f"MAE: {np.mean(fold_val_maes):.4f} ( {np.std(fold_val_maes):.4f})")
    print(f"          Pearson: {np.mean(fold_val_ps):.4f} ( {np.std(fold_val_ps):.4f}), "
          f"Spearman: {np.mean(fold_val_ss):.4f} ( {np.std(fold_val_ss):.4f})\n")



Epoch 1/5
 TRAIN  — RMSE: 0.9947 ( 0.0591), MAE: 0.8436 ( 0.0575)
          Pearson: 0.5146 ( 0.0261), Spearman: 0.4825 ( 0.0256)
 VALID  — RMSE: 1.0088 ( 0.0635), MAE: 0.8453 ( 0.0593)
          Pearson: 0.3995 ( 0.0255), Spearman: 0.3772 ( 0.0373)


Epoch 2/5
 TRAIN  — RMSE: 1.0019 ( 0.0907), MAE: 0.8564 ( 0.0877)
          Pearson: 0.5845 ( 0.0108), Spearman: 0.5464 ( 0.0041)
 VALID  — RMSE: 1.0745 ( 0.0964), MAE: 0.9048 ( 0.0917)
          Pearson: 0.4054 ( 0.0536), Spearman: 0.3747 ( 0.0688)


Epoch 3/5
 TRAIN  — RMSE: 0.9662 ( 0.0371), MAE: 0.8247 ( 0.0336)
          Pearson: 0.6134 ( 0.0146), Spearman: 0.5769 ( 0.0142)
 VALID  — RMSE: 1.0327 ( 0.0539), MAE: 0.8661 ( 0.0566)
          Pearson: 0.4108 ( 0.0609), Spearman: 0.3801 ( 0.0676)


Epoch 4/5
 TRAIN  — RMSE: 0.9927 ( 0.1351), MAE: 0.8428 ( 0.1202)
          Pearson: 0.5991 ( 0.0677), Spearman: 0.5678 ( 0.0606)
 VALID  — RMSE: 1.0723 ( 0.1321), MAE: 0.8947 ( 0.1196)
          Pearson: 0.3879 ( 0.0882), Spearman: 0.3536 ( 0