In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy.stats   import pearsonr, spearmanr


batch_size = 8
learning_rate = 3e-5
n_embd = 768
n_head = 8
n_layer = 4
dropout = 0.2
epochs = 3
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def evaluate(y_true, y_pred):
    """Compute and print MAE, RMSE, Pearson & Spearman correlations."""
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    pearson_corr, _ = pearsonr(y_true, y_pred)
    spearman_corr, _ = spearmanr(y_true, y_pred)

    return mae, rmse, pearson_corr, spearman_corr


class MolProtDataset(Dataset):
    def __init__(self, dataframe, mol_emb, prot_emb):
        self.df = dataframe
        self.mol_emb = mol_emb
        self.prot_emb = prot_emb

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        mol_id = int(row['MoleculeIdx'])
        prot_id = int(row['ProteinIdx'])
        target = float(row['Ki'])
        mol = torch.tensor(self.mol_emb[mol_id], dtype=torch.float32)
        prot = torch.tensor(self.prot_emb[prot_id], dtype=torch.float32)
        return mol, prot, torch.tensor(target, dtype=torch.float32)


def pad_collate(batch):
    mols, prots, targets = zip(*batch)

    mol_lengths = [m.shape[0] for m in mols]
    mol_max_len = max(mol_lengths)
    mols_padded = [torch.cat([m, m.new_zeros(mol_max_len - m.shape[0], m.shape[1])], dim=0) for m in mols]
    mols_batch = torch.stack(mols_padded)

    prot_lengths = [p.shape[0] for p in prots]
    prot_max_len = max(prot_lengths)
    prots_padded = [torch.cat([p, p.new_zeros(prot_max_len - p.shape[0], p.shape[1])], dim=0) for p in prots]
    prots_batch = torch.stack(prots_padded)

    targets_batch = torch.stack(targets)

    return mols_batch, prots_batch, targets_batch

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        k, q, v = self.key(x), self.query(x), self.value(x)
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)

        padding_mask = (x.abs().sum(dim=-1) > 1e-5).float()
        attn_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)
        wei = wei.masked_fill(attn_mask == 0, -1e6)

        wei = torch.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        return wei @ v

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class Model_1(nn.Module):
    def __init__(self):
        super().__init__()
        self.protein_proj = nn.Linear(960, n_embd)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, n_embd))
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.reg_head = nn.Linear(n_embd, 1)

    def forward(self, mol, prot, target=None):
        prot_emb = self.protein_proj(prot)
        cls_emb = self.cls_token.expand(mol.size(0), -1, -1)
        x = torch.cat([cls_emb, mol, prot_emb], dim=1)
        x = self.blocks(x)
        x = self.ln_f(x)
        cls_out = x[:, 0]
        prediction = self.reg_head(cls_out).squeeze(-1)
        loss = None
        if target is not None:
            loss = torch.sqrt(nn.functional.mse_loss(prediction, target))
        return loss, prediction

# ------------- Load embeddings -------------------
molecule_embeddings = np.load('/content/molecules_all_embeddings.npy',allow_pickle=True)
protein_embeddings = np.load('/content/proteins_all_embeddings.npy',allow_pickle=True)


fold_loaders = []
for fold in range(1, 6):
    train_df = pd.read_csv(f'/content/train_{fold}.csv')
    test_df  = pd.read_csv(f'/content/test_{fold}.csv')

    train_ds = MolProtDataset(train_df, molecule_embeddings, protein_embeddings)
    test_ds  = MolProtDataset(test_df,  molecule_embeddings, protein_embeddings)

    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=True, collate_fn=pad_collate)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size,
                              shuffle=False, collate_fn=pad_collate)

    fold_loaders.append((train_loader, test_loader))

fold_models     = [Model_1().to(device) for _ in range(5)]
fold_optimizers = [torch.optim.AdamW(m.parameters(), lr=learning_rate)
                   for m in fold_models]

for epoch in range(1, epochs+1):

    fold_train_losses, fold_val_losses = [], []
    fold_train_maes,   fold_val_maes   = [], []
    fold_train_ps,     fold_val_ps     = [], []
    fold_train_ss,     fold_val_ss     = [], []

    for (train_loader, test_loader), model, opt in zip(fold_loaders,
                                                      fold_models,
                                                      fold_optimizers):
        # ---- TRAIN ----
        model.train()
        train_loss = 0.0
        for mol, prot, target in train_loader:
            mol, prot, target = mol.to(device), prot.to(device), target.to(device)
            opt.zero_grad()
            loss, _ = model(mol, prot, target)
            loss.backward()
            opt.step()
            train_loss += loss.item() * mol.size(0)
        train_loss /= len(train_loader.dataset)

        # ---- TRAIN METRICS ----
        model.eval()
        y_t_train, y_p_train = [], []
        with torch.no_grad():
            for mol, prot, target in train_loader:
                mol, prot, target = mol.to(device), prot.to(device), target.to(device)
                _, preds = model(mol, prot, target)
                y_t_train.extend(target.cpu().numpy())
                y_p_train.extend(preds.cpu().numpy())
        t_mae, t_rmse, t_p, t_s = evaluate(y_t_train, y_p_train)

        # ---- VALID & METRICS ----
        y_t_val, y_p_val = [], []
        with torch.no_grad():
            for mol, prot, target in test_loader:
                mol, prot, target = mol.to(device), prot.to(device), target.to(device)
                _, preds = model(mol, prot, target)
                y_t_val.extend(target.cpu().numpy())
                y_p_val.extend(preds.cpu().numpy())
        v_mae, v_rmse, v_p, v_s = evaluate(y_t_val, y_p_val)

        fold_train_losses.append(t_rmse)
        fold_train_maes.append(t_mae)
        fold_train_ps.append(t_p)
        fold_train_ss.append(t_s)

        fold_val_losses.append(v_rmse)
        fold_val_maes.append(v_mae)
        fold_val_ps.append(v_p)
        fold_val_ss.append(v_s)

    # ---- AVERAGE ACROSS FOLDS ----
    print(f"\nEpoch {epoch}/{epochs}")
    print(f" TRAIN  — RMSE: {np.mean(fold_train_losses):.4f} ( {np.std(fold_train_losses):.4f}), "
          f"MAE: {np.mean(fold_train_maes):.4f} ( {np.std(fold_train_maes):.4f})")
    print(f"          Pearson: {np.mean(fold_train_ps):.4f} ( {np.std(fold_train_ps):.4f}), "
          f"Spearman: {np.mean(fold_train_ss):.4f} ( {np.std(fold_train_ss):.4f})")
    print(f" VALID  — RMSE: {np.mean(fold_val_losses):.4f} ( {np.std(fold_val_losses):.4f}), "
          f"MAE: {np.mean(fold_val_maes):.4f} ( {np.std(fold_val_maes):.4f})")
    print(f"          Pearson: {np.mean(fold_val_ps):.4f} ( {np.std(fold_val_ps):.4f}), "
          f"Spearman: {np.mean(fold_val_ss):.4f} ( {np.std(fold_val_ss):.4f})\n")



Epoch 1/3
 TRAIN  — RMSE: 0.7672 ( 0.0230), MAE: 0.6258 ( 0.0310)
          Pearson: 0.6530 ( 0.0060), Spearman: 0.6039 ( 0.0128)
 VALID  — RMSE: 0.8886 ( 0.0315), MAE: 0.7207 ( 0.0408)
          Pearson: 0.4562 ( 0.0368), Spearman: 0.4147 ( 0.0357)


Epoch 2/3
 TRAIN  — RMSE: 0.7074 ( 0.0352), MAE: 0.5635 ( 0.0413)
          Pearson: 0.7058 ( 0.0064), Spearman: 0.6631 ( 0.0075)
 VALID  — RMSE: 0.8735 ( 0.0225), MAE: 0.6943 ( 0.0315)
          Pearson: 0.4826 ( 0.0328), Spearman: 0.4232 ( 0.0521)


Epoch 3/3
 TRAIN  — RMSE: 0.7071 ( 0.0519), MAE: 0.5717 ( 0.0596)
          Pearson: 0.7414 ( 0.0040), Spearman: 0.7017 ( 0.0044)
 VALID  — RMSE: 0.8971 ( 0.0628), MAE: 0.7153 ( 0.0446)
          Pearson: 0.4957 ( 0.0496), Spearman: 0.4439 ( 0.0638)

