In [None]:
# Colab only: expose Google Drive under /content/drive so the data paths used below resolve.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# -*- coding: utf-8 -*-
"""
DTA-TRAIN.ipynb

This script implements a deep learning model to predict protein binding sites
based on structural and sequence-based information.

Architecture Overview:
1.  Inputs: Three matrices - AlphaFold distance map, a PLM-derived matrix (A*A^T),
    and an AlphaFold PAE matrix.
2.  Feature Extraction: Each input is processed by its own initial 2D convolution block;
    the encoded PAE features are then added to both the distance and PLM streams.
3.  Core Interaction Block (x3): The model contains three repeating blocks. Each block consists of:
    - 2D convolutions for the distance and PLM feature streams.
    - Cross-attention between the distance map and PLM streams, plus residual connections.
4.  Attention Layers: After the core blocks, axial attention is applied separately to the
    distance and PLM feature maps.
5.  Prediction Head:
    - The resulting feature maps are concatenated.
    - A projection layer processes the combined features.
    - Row-wise pooling generates per-residue features.
    - Two separate heads produce the final outputs, both as raw logits
      (apply a sigmoid to obtain probabilities):
        a) Per-residue binding logits (one per padded position), trained against 0/1 labels.
        b) A single protein-level binding logit, trained against a 0/1 label.
6.  Loss Function: A combined loss is used, incorporating Focal Loss and Dice Loss for the
    per-residue prediction and BCE loss for the global prediction.
"""

import os
import json
import glob
from typing import List, Tuple
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import logging

# Set up basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


# -----------------------------
# 1. Data Loading and Preprocessing
# -----------------------------
class CLAPEDataset(Dataset):
    """
    Dataset to load AlphaFold distance maps, PLM-derived matrices, PAE matrices,
    and corresponding binding site labels.

    Each item is a 6-tuple ``(dist, plm, pae, residue_labels, global_label, name)``:

    - ``dist``, ``plm``, ``pae``: float tensors of shape ``[1, max_len, max_len]``,
      zero-padded or truncated independently along both axes.
    - ``residue_labels``: float tensor ``[max_len]`` of 0/1 binding labels.
    - ``global_label``: scalar float tensor, 1.0 if any residue within the loaded
      (pre-padding) length is labelled as binding. It is computed before truncation
      to ``max_len``, so it can be 1.0 while every visible residue label is 0.
    - ``name``: the sample basename, e.g. ``'O00141-F1-model_v4'``.

    When a sample's files are missing or unreadable, an all-zero item is returned
    with ``name`` suffixed by ``"_error"`` instead of raising, so the DataLoader keeps
    running; consumers should drop such samples by that suffix.
    """
    def __init__(self, dist_dir: str, plm_dir: str, pae_dir: str,
                 label_csv: str, file_list: List[str] = None, max_len: int = 1200):
        """
        Args:
            dist_dir: directory with ``AF-<basename>*_distances.csv`` files.
            plm_dir: directory with ``<basename>*M_transformed.txt`` files (whitespace-separated).
            pae_dir: directory with ``<uniprot_id>*_pae.txt`` files holding PAE JSON.
            label_csv: CSV with ``uniprot_id``, ``start``, ``end`` columns
                (1-based, inclusive residue ranges).
            file_list: explicit basenames; if empty/None they are discovered from ``dist_dir``.
            max_len: side length every matrix / vector is padded or truncated to.

        Raises:
            FileNotFoundError: if ``label_csv`` does not exist.
        """
        self.dist_dir = dist_dir
        self.plm_dir = plm_dir
        self.pae_dir = pae_dir
        self.max_len = max_len

        # --- Load and process binding site labels ---
        try:
            self.label_df = pd.read_csv(label_csv)
            self.labels = self._group_labels(self.label_df)
            logging.info(f"Successfully loaded and processed labels from {label_csv}")
        except FileNotFoundError:
            logging.error(f"Label file not found at {label_csv}")
            raise

        # --- Collect all available protein identifiers (basenames) ---
        if file_list:
            self.basenames = file_list
        else:
            # Escape the directory so glob metacharacters in the path are taken literally;
            # sort so sample order does not depend on filesystem listing order.
            dist_files = sorted(glob.glob(os.path.join(glob.escape(dist_dir), "*_distances.csv")))
            # Extracts 'O00141-F1-model_v4' from 'AF-O00141-F1-model_v4_distances.csv'
            self.basenames = [os.path.basename(p).split('_distances.csv')[0].replace("AF-", "") for p in dist_files]

        logging.info(f"Dataset initialized with {len(self.basenames)} samples.")

    def _group_labels(self, df: pd.DataFrame) -> dict:
        """
        Groups binding site ranges by uniprot_id.
        Rows with a missing or unparseable start/end are skipped (with a warning).
        Example output: {'O00141': [(104, 112), (127, 127)], ...}
        """
        grouped = {}
        for uid, sub_df in df.groupby("uniprot_id"):
            ranges = []
            for _, row in sub_df.iterrows():
                if pd.notna(row["start"]) and pd.notna(row["end"]):
                    try:
                        # Convert to int and ensure start <= end
                        s, e = int(row["start"]), int(row["end"])
                        ranges.append((min(s, e), max(s, e)))
                    except (TypeError, ValueError):
                        logging.warning(f"Could not parse start/end for {uid}: {row['start']}, {row['end']}")
            if ranges:
                grouped[uid] = ranges
        return grouped

    def _make_label_vector(self, uniprot_id: str, n: int) -> np.ndarray:
        """
        Creates a binary float32 vector of length ``n`` indicating binding sites.
        Residues in binding regions are marked as 1, others as 0; ranges are clipped
        to ``[0, n)`` and ranges lying entirely outside it are ignored.
        """
        label_vec = np.zeros(n, dtype=np.float32)
        if uniprot_id not in self.labels:
            return label_vec

        for (start, end) in self.labels[uniprot_id]:
            # Convert 1-based start/end from CSV to 0-based index
            s_idx, e_idx = max(0, start - 1), min(n - 1, end - 1)
            if s_idx <= e_idx:
                label_vec[s_idx : e_idx + 1] = 1
        return label_vec

    def __len__(self):
        return len(self.basenames)

    @staticmethod
    def _find_file(directory: str, prefix: str, suffix_pattern: str):
        """Return the best file matching ``prefix + suffix_pattern`` in ``directory``, or None.

        Matches are sorted because glob order is filesystem-dependent. Matches where
        the character right after ``prefix`` is not alphanumeric are preferred, so a
        short accession such as 'A0A024' does not pick up 'A0A024R161...' when a file
        for the exact accession exists; otherwise the first sorted match is used.
        """
        pattern = os.path.join(glob.escape(directory), glob.escape(prefix) + suffix_pattern)
        matches = sorted(glob.glob(pattern))
        if not matches:
            return None
        on_boundary = [p for p in matches
                       if not os.path.basename(p)[len(prefix):][:1].isalnum()]
        return (on_boundary or matches)[0]

    def _error_sample(self, base: str):
        """All-zero placeholder item; the ``"_error"`` name suffix marks it for filtering."""
        shape = (1, self.max_len, self.max_len)
        return (torch.zeros(shape),
                torch.zeros(shape),
                torch.zeros(shape),
                torch.zeros(self.max_len),
                torch.tensor(0.0), base + "_error")

    def _read_matrix(self, path: str, delimiter=",") -> np.ndarray:
        """Read a 2-D float32 matrix from ``path``; return None on any failure.

        ``delimiter`` selects the parser:
        ``'json'`` reads the ``predicted_aligned_error`` field of a PAE JSON object
        (optionally wrapped in a list, in which case the first element is used);
        ``','`` uses ``np.genfromtxt`` (empty cells become 0);
        ``' '`` splits on any run of whitespace;
        anything else is passed to ``np.loadtxt`` as the delimiter.
        Results that are not 2-D are rejected. NaN/inf entries are replaced by 0.
        """
        if not os.path.exists(path):
            logging.error(f"File not found at path: {path}")
            return None

        try:
            if delimiter == 'json':
                with open(path, 'r') as f:
                    data = json.load(f)
                # Handle PAE JSON structure
                if isinstance(data, list) and data:
                    data = data[0]
                arr = np.array(data.get("predicted_aligned_error", []), dtype=np.float32)
            elif delimiter == ',':
                # Use genfromtxt for CSVs as it is robust to missing/empty values
                arr = np.genfromtxt(path, delimiter=delimiter, dtype=np.float32, filling_values=0.0)
            else:
                # delimiter=None makes loadtxt split on any whitespace, so doubled or
                # leading spaces do not produce empty (unparseable) fields.
                arr = np.loadtxt(path, delimiter=None if delimiter == " " else delimiter,
                                 dtype=np.float32)
        except Exception as e:
            logging.error(f"Failed to read or parse matrix {path}: {e}")
            return None

        # A missing JSON key or a one-line text file yields a 0-D/1-D array, which the
        # [:n, :n] slicing in __getitem__ cannot handle.
        if arr.ndim != 2:
            logging.error(f"Expected a 2-D matrix in {path}, got shape {arr.shape}")
            return None

        # Replace any NaN/inf values with 0
        return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

    def _pad_or_truncate(self, arr: np.ndarray) -> np.ndarray:
        """Crop or zero-pad every axis of a 1-D or 2-D array to ``self.max_len``.

        Each axis is handled independently, so a non-square matrix still comes out
        ``(max_len, max_len)`` and batches can be stacked. ``None`` and arrays of
        other ranks are returned unchanged.
        """
        if arr is None or arr.ndim not in (1, 2):
            return arr
        cropped = arr[tuple(slice(0, self.max_len) for _ in range(arr.ndim))]
        pad_width = [(0, self.max_len - size) for size in cropped.shape]
        if any(after > 0 for _, after in pad_width):
            return np.pad(cropped, pad_width, mode='constant', constant_values=0)
        return cropped

    def __getitem__(self, idx: int):
        base = self.basenames[idx]
        uniprot_id = base.split("-")[0]

        # --- Locate the three input files ---
        dist_path = self._find_file(self.dist_dir, f"AF-{base}", "*_distances.csv")
        plm_path = self._find_file(self.plm_dir, base, "*M_transformed.txt")
        pae_path = self._find_file(self.pae_dir, uniprot_id, "*_pae.txt")

        if dist_path is None or plm_path is None or pae_path is None:
            logging.warning(f"Skipping sample {base} because at least one file was not found.")
            return self._error_sample(base)

        # --- Load data ---
        dist = self._read_matrix(dist_path, delimiter=",")
        plm = self._read_matrix(plm_path, delimiter=" ")
        pae = self._read_matrix(pae_path, delimiter="json")

        if dist is None or plm is None or pae is None:
            logging.error(f"Skipping sample {base} due to a file reading error.")
            # Return a dummy sample to avoid crashing the DataLoader
            return self._error_sample(base)

        # --- Unify matrix sizes: crop all three to the largest common square ---
        mats = (dist, plm, pae)
        n = min(min(m.shape) for m in mats)
        if any(m.shape != (n, n) for m in mats):
            logging.warning(f"Matrix size mismatch for {base}. dist:{dist.shape}, plm:{plm.shape}, pae:{pae.shape}. Truncating to smallest dimension.")
            dist, plm, pae = (m[:n, :n] for m in mats)

        # --- Create label vector (global label uses the full loaded length) ---
        label_vec = self._make_label_vector(uniprot_id, n)
        global_label = float(label_vec.sum() > 0)

        # --- Pad all matrices and labels to max_len ---
        dist_padded = self._pad_or_truncate(dist)
        plm_padded = self._pad_or_truncate(plm)
        pae_padded = self._pad_or_truncate(pae)
        label_vec_padded = self._pad_or_truncate(label_vec)

        # --- Convert to Tensors and add channel dimension ---
        dist_tensor = torch.from_numpy(dist_padded).float().unsqueeze(0)
        plm_tensor = torch.from_numpy(plm_padded).float().unsqueeze(0)
        pae_tensor = torch.from_numpy(pae_padded).float().unsqueeze(0)
        label_tensor = torch.from_numpy(label_vec_padded).float()

        return (
            dist_tensor,      # [1, max_len, max_len]
            plm_tensor,       # [1, max_len, max_len]
            pae_tensor,       # [1, max_len, max_len]
            label_tensor,     # [max_len]
            torch.tensor(global_label, dtype=torch.float32),
            base
        )


# -----------------------------
# 2. Model Building Blocks
# -----------------------------
class ConvBlock(nn.Module):
    """Two stacked (Conv2d -> BatchNorm2d -> ReLU) stages.

    The convolutions are created without bias because each one is immediately
    followed by BatchNorm, whose learned shift takes that role. With the default
    ``kernel_size=3`` and ``padding=1`` the spatial size is preserved.
    """
    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=kernel_size, padding=padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class CrossAttention2D(nn.Module):
    """Cross-attention in which every position of ``query`` attends over every position of ``key_value``.

    Both (B, C, H, W) maps are flattened to token sequences of length H*W with C
    features. The query and key/value sides use separate LayerNorms, then go
    through multi-head attention (dropout 0.1 on the attention weights). The
    result is reshaped back to the query's (B, C, H, W). No residual is added
    inside this module.

    NOTE(review): attention is over all H*W positions, so each head builds an
    (H*W) x (H*W) attention matrix. At H = W = 1024 that is ~1e12 entries, far
    beyond GPU memory. Consider axial or windowed cross-attention for
    full-length inputs.
    """
    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True, dropout=0.1)
        self.norm_q = nn.LayerNorm(channels)
        self.norm_kv = nn.LayerNorm(channels)

    def forward(self, query, key_value):
        batch, chans, height, width = query.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position
        q_tokens = self.norm_q(query.flatten(start_dim=2).transpose(1, 2))
        kv_tokens = self.norm_kv(key_value.flatten(start_dim=2).transpose(1, 2))
        attended = self.mha(q_tokens, kv_tokens, kv_tokens)[0]
        # (B, H*W, C) -> (B, C, H, W)
        return attended.transpose(1, 2).reshape(batch, chans, height, width)

class AxialAttention(nn.Module):
    """Self-attention along the rows and along the columns of a 2D feature map, summed.

    For an input of shape (B, C, H, W), each of the B*H rows is attended as a
    sequence of W tokens, and each of the B*W columns as a sequence of H tokens.
    Both passes share one LayerNorm (``self.norm``) but use separate attention
    modules. The output is ``row_result + col_result`` in (B, C, H, W); no
    residual connection to the input is added.
    """
    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.row_attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.col_attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(channels)

    def _attend(self, seq, attn):
        """LayerNorm a batch of (N, L, C) sequences, then self-attend with ``attn``."""
        normed = self.norm(seq)
        return attn(normed, normed, normed)[0]

    def forward(self, x):
        b, c, h, w = x.shape
        channels_last = x.permute(0, 2, 3, 1)  # (B, H, W, C)
        # Rows: B*H sequences of length W
        rows = self._attend(channels_last.reshape(b * h, w, c), self.row_attn)
        # Columns: B*W sequences of length H
        cols = self._attend(channels_last.transpose(1, 2).reshape(b * w, h, c), self.col_attn)
        rows = rows.reshape(b, h, w, c)
        cols = cols.reshape(b, w, h, c).transpose(1, 2)  # back to (B, H, W, C)
        return (rows + cols).permute(0, 3, 1, 2)


# -----------------------------
# 3. Main Model Architecture
# -----------------------------
class BindingAffinityNet(nn.Module):
    """Two-stream (distance map / PLM matrix) network with PAE side information.

    ``forward(dist, plm, pae)`` takes three single-channel maps (B, 1, L, L) and
    returns ``(residue_logits, global_logit)`` of shapes (B, L) and (B,). Both are
    raw logits; no sigmoid is applied.

    Args:
        base_channels: channel width of every convolution / attention stage.
        num_heads: attention heads (``nn.MultiheadAttention`` requires it to
            divide ``base_channels``).
        K: stored as ``self.K``; not used by any layer in this class.
    """
    def __init__(self, base_channels: int = 16, num_heads: int = 4, K: int = 5):
        super().__init__()
        self.K = K  # kept for configuration bookkeeping; forward() does not read it

        # --- One encoder per input map ---
        self.conv_dist = ConvBlock(1, base_channels)
        self.conv_plm = ConvBlock(1, base_channels)
        self.conv_pae = ConvBlock(1, base_channels)

        # --- Three interaction blocks: per-stream convs + bidirectional cross-attention ---
        self.repeats = nn.ModuleList(
            nn.ModuleDict({
                'conv_d': ConvBlock(base_channels, base_channels),
                'conv_p': ConvBlock(base_channels, base_channels),
                'cross_dp': CrossAttention2D(base_channels, num_heads),
                'cross_pd': CrossAttention2D(base_channels, num_heads),
            })
            for _ in range(3)
        )

        # --- Per-stream axial self-attention after the interaction blocks ---
        self.axial_dist = AxialAttention(base_channels, num_heads)
        self.axial_plm = AxialAttention(base_channels, num_heads)

        # --- Prediction heads over the concatenated (dist, plm) streams ---
        fused_ch = 2 * base_channels
        self.proj = nn.Conv2d(fused_ch, fused_ch, kernel_size=1)
        self.head_residue = nn.Sequential(
            nn.Conv1d(fused_ch, base_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(base_channels, 1, kernel_size=1),
        )
        self.head_global = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(fused_ch, base_channels),
            nn.ReLU(),
            nn.Linear(base_channels, 1),
        )

    def forward(self, dist, plm, pae):
        # Encode each map; PAE features are added to both streams once, up front.
        d = self.conv_dist(dist)
        p = self.conv_plm(plm)
        pae_feats = self.conv_pae(pae)
        d = d + pae_feats
        p = p + pae_feats

        for blk in self.repeats:
            d_conv = blk['conv_d'](d)
            p_conv = blk['conv_p'](p)
            # Each stream queries the other (both use the post-conv features).
            d_cross = blk['cross_dp'](d_conv, p_conv)
            p_cross = blk['cross_pd'](p_conv, d_conv)
            # conv output + cross-attention, plus the block's input as residual
            d = d_conv + d_cross + d
            p = p_conv + p_cross + p

        fused = torch.cat((self.axial_dist(d), self.axial_plm(p)), dim=1)
        fused = F.relu(self.proj(fused))

        # Per-residue: average over columns -> (B, C, L) -> (B, 1, L) -> (B, L)
        residue_logits = self.head_residue(fused.mean(dim=3)).squeeze(1)
        # Global: pooled over the whole map -> (B, 1) -> (B,)
        global_logit = self.head_global(fused).squeeze(1)
        return residue_logits, global_logit


# -----------------------------
# 4. Loss Functions
# -----------------------------
class FocalLoss(nn.Module):
    """Binary focal loss computed from logits.

    Element-wise: ``alpha_t * (1 - p_t) ** gamma * BCE``, where ``BCE`` is the
    per-element binary cross-entropy with logits and ``p_t = exp(-BCE)``. For 0/1
    targets, ``p_t`` is the predicted probability of the true class.
    ``alpha_t = alpha`` for positive targets and ``1 - alpha`` for negative ones.

    Args:
        alpha: weight applied to positive targets (negatives get ``1 - alpha``).
        gamma: focusing exponent; ``gamma=0`` reduces to alpha-weighted BCE.
        reduction: ``'mean'``, ``'sum'``, or ``'none'`` (returns the element-wise
            loss with the same shape as ``logits``).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """
    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        # Previously any unrecognised value silently fell through to 'sum'.
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, targets):
        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        pt = torch.exp(-bce_loss)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_loss = alpha_t * (1 - pt) ** self.gamma * bce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

class DiceLoss(nn.Module):
    """Soft Dice loss computed from logits and averaged over the batch.

    Sums run over dim 1 (one Dice score per row of a (B, L) input):
        dice = (2 * sum(p * y) + smooth) / (sum(p) + sum(y) + smooth),  p = sigmoid(logits)
    The returned loss is ``1 - mean(dice)``. ``smooth`` keeps the ratio well-defined
    when a sample has no positive targets.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        p = logits.sigmoid()
        overlap = torch.sum(p * targets, dim=1)
        denom = torch.sum(p, dim=1) + torch.sum(targets, dim=1)
        eps = self.smooth
        return 1.0 - torch.mean((2.0 * overlap + eps) / (denom + eps))

class CombinedLoss(nn.Module):
    """Weighted sum of residue-level Focal + Dice losses and a global BCE loss.

    total = w_focal * focal(residue) + w_dice * dice(residue) + w_global * BCE(global)

    ``forward`` returns ``(total_loss, stats)``, where ``stats`` maps
    'loss', 'focal', 'dice', 'global_bce' to Python floats for logging.
    """
    def __init__(self, w_focal=0.7, w_dice=0.3, w_global=0.5):
        super().__init__()
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()
        self.global_loss_fn = nn.BCEWithLogitsLoss()
        self.w_focal, self.w_dice, self.w_global = w_focal, w_dice, w_global

    def forward(self, residue_logits, global_logits, residue_targets, global_targets):
        focal = self.focal_loss(residue_logits, residue_targets)
        dice = self.dice_loss(residue_logits, residue_targets)
        global_loss = self.global_loss_fn(global_logits, global_targets)

        residue_term = self.w_focal * focal + self.w_dice * dice
        total_loss = residue_term + self.w_global * global_loss

        named = (('loss', total_loss), ('focal', focal), ('dice', dice), ('global_bce', global_loss))
        stats = {key: value.item() for key, value in named}
        return total_loss, stats


# -----------------------------
# 5. Training and Evaluation
# -----------------------------
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean loss over batches that were used.

    Samples whose name ends in ``"_error"`` are removed from each batch before the
    forward pass; these are the placeholders the dataset emits for missing or
    unreadable files. A batch left empty is skipped. Batches with a non-finite
    loss are skipped without an optimizer step.

    Args:
        model: module called as ``model(dist, plm, pae)`` returning
            ``(residue_logits, global_logits)``.
        dataloader: sized iterable of
            ``(dist, plm, pae, residue_labels, global_labels, names)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``(residue_logits, global_logits, residue_labels, global_labels)``
            returning ``(loss_tensor, stats_dict)``.
        device: device the batch tensors are moved to.

    Returns:
        Mean loss over the batches that produced an optimizer step, or ``nan`` if
        no batch did (e.g. an empty dataloader).
    """
    model.train()
    total_loss = 0.0
    n_updates = 0
    n_batches = len(dataloader)
    for i, batch in enumerate(dataloader):
        dist, plm, pae, res_labels, glob_labels, names = batch

        # Error placeholders are full-size all-zero tensors, so nelement() cannot spot
        # them; training on them would teach "zero input => no binding". Filter by name.
        keep = torch.tensor([not str(name).endswith("_error") for name in names], dtype=torch.bool)
        if not keep.any():
            logging.warning(f"Batch {i} contains only error samples. Skipping.")
            continue
        if not keep.all():
            dist, plm, pae = dist[keep], plm[keep], pae[keep]
            res_labels, glob_labels = res_labels[keep], glob_labels[keep]

        # Move batch to device
        dist, plm, pae = dist.to(device), plm.to(device), pae.to(device)
        res_labels, glob_labels = res_labels.to(device), glob_labels.to(device)

        optimizer.zero_grad()
        residue_logits, global_logits = model(dist, plm, pae)
        loss, stats = loss_fn(residue_logits, global_logits, res_labels, glob_labels)

        # isfinite also catches +/-inf, which would corrupt the weights just like NaN.
        if not torch.isfinite(loss):
            logging.warning(f"Non-finite loss ({loss.item()}) detected at batch {i}. Skipping update.")
            continue

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_updates += 1
        if i % 20 == 0:
            logging.info(f"Batch {i}/{n_batches} - {stats}")

    if n_updates == 0:
        logging.warning("No batch produced a training update this epoch.")
        return float('nan')
    # Average only over batches that contributed, not over skipped ones.
    return total_loss / n_updates


def verify_labelling(dataset: CLAPEDataset, uniprot_id: str, protein_length: int):
    """
    Print a sanity check of the label vector built for one protein.

    Shows the raw (1-based) ranges stored in ``dataset.labels`` for ``uniprot_id``.
    It then shows which positions ``dataset._make_label_vector`` marks as 1 for a
    protein of ``protein_length`` residues, also as 1-based indices. Prints a
    notice and returns early if the protein has no ranges.
    """
    rule = "=" * 50
    print(f"\n{rule}")
    print(f"--- Verifying Labels for UniProt ID: {uniprot_id} ---")

    ranges = dataset.labels.get(uniprot_id)
    if not ranges:
        print(f"No binding site entries found for {uniprot_id} in the CSV.")
        return
    print(f"Found {len(ranges)} binding site range(s) in CSV (1-based): {ranges}")

    vec = dataset._make_label_vector(uniprot_id, protein_length)
    positive = np.where(vec == 1)[0]  # 0-based indices of positive residues

    report = (
        f"Protein Length Used: {protein_length}",
        f"Generated Label Vector has {len(positive)} '1's.",
        "Indices labeled as binding sites (1-based):",
        (positive + 1).tolist(),
    )
    for line in report:
        print(line)
    print(f"{rule}\n")


# -----------------------------
# 6. Main Execution Block
# -----------------------------
if __name__ == '__main__':
    # --- Configuration ---
    # NOTE: Adjust this path to your Google Drive folder structure
    DRIVE_PATH = '/content/drive/MyDrive/CLAPE-RESULTS'

    # Fail fast if Drive is not mounted or the folder is missing
    if not os.path.exists(DRIVE_PATH):
        raise FileNotFoundError(f"The base directory {DRIVE_PATH} was not found. "
                              "Please mount your Google Drive or update the path.")

    DIST_DIR = os.path.join(DRIVE_PATH, 'af-distancemaps')
    PLM_DIR = os.path.join(DRIVE_PATH, 'transformed_matrices')
    PAE_DIR = os.path.join(DRIVE_PATH, 'PAE-MATRICES')
    LABEL_CSV = os.path.join(DRIVE_PATH, 'binding_sites_uniprot.csv')

    BATCH_SIZE = 2
    # NOTE(review): CrossAttention2D attends over all MAX_LEN * MAX_LEN positions at
    # once, so its attention memory grows as MAX_LEN**4; at 1024 this will not fit on a GPU.
    MAX_LEN = 1024 # Max sequence length for padding/truncating
    LEARNING_RATE = 1e-4
    EPOCHS = 10

    # --- Data Setup ---
    logging.info("Setting up dataset...")
    dataset = CLAPEDataset(
        dist_dir=DIST_DIR,
        plm_dir=PLM_DIR,
        pae_dir=PAE_DIR,
        label_csv=LABEL_CSV,
        max_len=MAX_LEN
    )
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

    # --- VERIFY LABELLING ---
    # You can change 'O00141' to any UniProt ID from your dataset
    # and provide its approximate length to check the labels.
    verify_labelling(dataset, uniprot_id='O00141', protein_length=360)
    verify_labelling(dataset, uniprot_id='O14920', protein_length=470)


    # --- Model Setup ---
    logging.info("Setting up model...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = BindingAffinityNet(base_channels=16, num_heads=4, K=5).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = CombinedLoss()

    logging.info(f"Starting training on {device} for {EPOCHS} epochs.")

    # --- Training Loop ---
    for epoch in range(EPOCHS):
        avg_loss = train_one_epoch(model, dataloader, optimizer, loss_fn, device)
        logging.info(f"--- Epoch {epoch+1}/{EPOCHS} | Average Loss: {avg_loss:.4f} ---")

        # TODO: Add a validation loop here
        # TODO: Save model checkpoints

    logging.info("Training finished.")

    # --- Example of making a prediction ---
    logging.info("Running a sample prediction...")
    model.eval()
    with torch.no_grad():
        # Get a single batch for inference
        try:
            sample_batch = next(iter(dataloader))
        except StopIteration:
            sample_batch = None
            logging.warning("Could not get a sample batch for prediction, the dataloader is empty.")

        if sample_batch is not None:
            dist, plm, pae, res_labels, _, names = sample_batch
            # Error placeholders are full-size zero tensors (nelement() > 0 is always
            # true), so pick a real sample via the "_error" name suffix instead.
            valid = [j for j, name in enumerate(names) if not name.endswith("_error")]
            if not valid:
                logging.warning("Sample batch contains only error samples; skipping prediction.")
            else:
                j = valid[0]
                res_logits, glob_logits = model(dist[j:j + 1].to(device),
                                                plm[j:j + 1].to(device),
                                                pae[j:j + 1].to(device))

                # Get probabilities
                res_probs = torch.sigmoid(res_logits)
                glob_probs = torch.sigmoid(glob_logits)

                print(f"\nSample Prediction for protein: {names[j]}")
                print(f"Global binding probability: {glob_probs[0].item():.3f}")
                # Indices below are 0-based positions in the padded vector
                # (verify_labelling above prints 1-based residue numbers).
                high_prob_indices = (res_probs[0] > 0.5).nonzero(as_tuple=True)[0]
                print(f"Residues with >50% binding probability (0-based): {high_prob_indices.cpu().numpy().tolist()}")
                print(f"Actual binding sites (0-based, first few): {(res_labels[j] > 0.5).nonzero(as_tuple=True)[0].numpy().tolist()[:15]}")




--- Verifying Labels for UniProt ID: O00141 ---
Found 2 binding site range(s) in CSV (1-based): [(104, 112), (127, 127)]
Protein Length Used: 360
Generated Label Vector has 10 '1's.
Indices labeled as binding sites (1-based):
[104, 105, 106, 107, 108, 109, 110, 111, 112, 127]


--- Verifying Labels for UniProt ID: O14920 ---
Found 2 binding site range(s) in CSV (1-based): [(21, 29), (44, 44)]
Protein Length Used: 470
Generated Label Vector has 10 '1's.
Indices labeled as binding sites (1-based):
[21, 22, 23, 24, 25, 26, 27, 28, 29, 44]



