In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import json
import glob
from typing import List, Tuple
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np


# -----------------------------
# Data utilities
# -----------------------------
class CLAPEDataset(Dataset):
    def __init__(self, dist_dir: str, plm_dir: str, pae_dir: str,
                 label_csv: str, file_list: List[str] = None):
        self.dist_dir = dist_dir
        self.plm_dir = plm_dir
        self.pae_dir = pae_dir

        # --- Load and index labels ---
        df = pd.read_csv(label_csv, sep="\t" if label_csv.endswith(".tsv") else ",")
        self.labels = self._group_labels(df)

        # collect basenames (same as before)
        dist_files = glob.glob(os.path.join(dist_dir, "*.csv"))
        basenames = [os.path.basename(p).split("_distances")[0].replace("AF-", "") for p in dist_files]
        self.basenames = file_list or basenames

    def _group_labels(self, df: pd.DataFrame) -> dict:
        """
        Converts binding_sites_uniprot.csv into {uniprot_id: [(start,end),...]} dictionary.
        Skips any rows with missing or invalid start/end values.
        """
        grouped = {}
        for uid, sub in df.groupby("uniprot_id"):
            ranges = []
            for _, row in sub.iterrows():
                # Skip rows with missing or non-numeric positions
                if pd.isna(row["start"]) or pd.isna(row["end"]):
                    continue
                try:
                    s, e = int(row["start"]), int(row["end"])
                except Exception:
                    continue
                if e < s:
                    s, e = e, s  # just in case of inverted ranges
                ranges.append((s, e))
            if ranges:
                grouped[uid] = ranges
        return grouped


    def _make_label_vector(self, uniprot_id: str, n: int) -> np.ndarray:
        """
        Create a binary vector of length n, with 1s where residues are binding sites.
        """
        label_vec = np.zeros(n, dtype=np.float32)
        if uniprot_id not in self.labels:
            return label_vec
        for (s, e) in self.labels[uniprot_id]:
            s = max(s - 1, 0)  # convert 1-based → 0-based indexing
            e = min(e - 1, n - 1)
            label_vec[s:e+1] = 1
        return label_vec

    def __len__(self):
        return len(self.basenames)

    def _read_distance(self, path: str) -> np.ndarray:
        # genfromtxt handles missing or non-numeric entries gracefully
        arr = np.genfromtxt(path, delimiter=",", filling_values=np.nan)
        # Replace NaN or inf with 0
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        return arr
    def _read_plm(self, path: str) -> np.ndarray:
        return np.loadtxt(path)

    def _read_pae(self, path: str) -> np.ndarray:
        """
        Reads Alphafold PAE JSON files that may be:
        - a dict with key "predicted_aligned_error", or
        - a list of dicts each containing that key.
        Returns a numeric NxN numpy array.
        """
        with open(path, "r") as f:
            txt = f.read()

        try:
            obj = json.loads(txt)
        except Exception:
            # fallback if slightly malformed JSON (e.g., single quotes)
            obj = eval(txt)

        # handle both cases
        if isinstance(obj, list):
            # take the first dict if it's a list of dicts
            if len(obj) > 0 and isinstance(obj[0], dict) and "predicted_aligned_error" in obj[0]:
                pae = obj[0]["predicted_aligned_error"]
            else:
                raise ValueError(f"Unexpected PAE list structure in {path}")
        elif isinstance(obj, dict):
            pae = obj.get("predicted_aligned_error", None)
            if pae is None:
                raise ValueError(f"No 'predicted_aligned_error' key found in dict for {path}")
        else:
            raise ValueError(f"Unexpected PAE object type: {type(obj)} in {path}")

        arr = np.array(pae, dtype=float)
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        return arr

    def __getitem__(self, idx: int):
        base = self.basenames[idx]
        uniprot_id = base.split("-")[0]
        # --- Distance map ---
        dist_path = glob.glob(os.path.join(self.dist_dir, f"AF-{base}*distances.csv")) \
                    or glob.glob(os.path.join(self.dist_dir, f"{base}*distances.csv"))
        if not dist_path:
            raise FileNotFoundError(f"No distance map found for base {base} in {self.dist_dir}")
        dist_path = dist_path[0]

        # --- PLM matrix ---
        plm_path = (
            glob.glob(os.path.join(self.plm_dir, f"{base}*_M_transformed.txt"))
            or glob.glob(os.path.join(self.plm_dir, f"{base.replace('AF-', '')}*_M_transformed.txt"))
            or glob.glob(os.path.join(self.plm_dir, f"{uniprot_id}-F1-model_v4_M_transformed.txt"))  # <-- fixed
        )

        if not plm_path:
            raise FileNotFoundError(f"No PLM file found for base {base} (uniprot_id={uniprot_id}) in {self.plm_dir}")
        plm_path = plm_path[0]
        # --- PAE matrix ---
        pae_path = glob.glob(os.path.join(self.pae_dir, f"{base.split('-')[0]}*_pae.txt")) \
                  or glob.glob(os.path.join(self.pae_dir, f"{base.replace('AF-', '')}*_pae.txt")) \
                  or glob.glob(os.path.join(self.pae_dir, f"*{uniprot_id}*_pae.txt"))
        if not pae_path:
            raise FileNotFoundError(f"No PAE file found for base {base} in {self.pae_dir}")
        pae_path = pae_path[0]

        # Load data
        dist = self._read_distance(dist_path)
        plm = self._read_plm(plm_path)
        pae = self._read_pae(pae_path)

        # checks and normalization
        n = dist.shape[0]
        # Ensure all matrices are square and same size
        for name, arr in zip(["dist", "plm", "pae"], [dist, plm, pae]):
            if arr.shape[0] != arr.shape[1]:
                raise ValueError(f"{name} matrix not square for {base}")
            if arr.shape[0] != n:
                # Handle off-by-one or small mismatches
                diff = n - arr.shape[0]
                if abs(diff) <= 2:
                    if diff > 0:
                        # pad with zeros
                        arr = np.pad(arr, ((0, diff), (0, diff)), constant_values=0.0)
                    else:
                        # trim extra rows/cols
                        arr = arr[:n, :n]
                else:
                    raise ValueError(
                        f"{name} matrix shape mismatch for {base}: expected {n}x{n}, got {arr.shape}"
                    )
            # update after potential padding/trimming
            if name == "plm":
                plm = arr
            elif name == "pae":
                pae = arr
            elif name == "dist":
                dist = arr


        dist = (dist - dist.mean()) / (dist.std() + 1e-6)
        plm = (plm - plm.mean()) / (plm.std() + 1e-6)
        pae = (pae - pae.mean()) / (pae.std() + 1e-6)

        n = dist.shape[0]
        label_vec = self._make_label_vector(uniprot_id, n)
        global_label = float(label_vec.sum() > 0)
        max_len = 1200
        # --- Pad matrices to fixed max_len ---
        # --- Pad label vector to same max_len ---
        if len(label_vec) < max_len:
            pad = max_len - len(label_vec)
            label_vec = np.pad(label_vec, (0, pad), constant_values=0)
        elif len(label_vec) > max_len:
            label_vec = label_vec[:max_len]

          # pick a value that fits your largest protein
        def pad_to(t, size=max_len):
            n = t.shape[-1]
            if n < size:
                pad = size - n
                return F.pad(t, (0, pad, 0, pad), value=0.0)
            elif n > size:
                return t[:, :size, :size]
            else:
                return t
        label_vec = torch.tensor(label_vec, dtype=torch.float32)


        dist = pad_to(torch.tensor(dist, dtype=torch.float32).unsqueeze(0))
        plm  = pad_to(torch.tensor(plm,  dtype=torch.float32).unsqueeze(0))
        pae  = pad_to(torch.tensor(pae,  dtype=torch.float32).unsqueeze(0))


        return (
            dist,                      # [1, max_len, max_len]
            plm,                       # [1, max_len, max_len]
            pae,                       # [1, max_len, max_len]
            label_vec,  # [max_len] residue-level label
            torch.tensor(global_label, dtype=torch.float32),  # scalar global label
            base
        )



# -----------------------------
# Model building blocks
# -----------------------------

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class CrossAttention2D(nn.Module):
    """
    Cross attention between two 2D maps. We flatten spatial dims to sequence and use MultiheadAttention.
    Input shapes: (B, C, H, W) and (B, C, H, W) with same H=W=n.
    """

    def __init__(self, channels: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm_q = nn.LayerNorm(channels)
        self.norm_kv = nn.LayerNorm(channels)
        self.proj = nn.Sequential(nn.Linear(channels, channels), nn.ReLU())

    def forward(self, x_q, x_kv):
        # x: (B, C, n, n) -> (B, n*n, C)
        B, C, H, W = x_q.shape
        seq_len = H * W
        q = x_q.view(B, C, seq_len).permute(0, 2, 1)
        kv = x_kv.view(B, C, seq_len).permute(0, 2, 1)

        q = self.norm_q(q)
        kv = self.norm_kv(kv)

        attn_out, _ = self.mha(q, kv, kv)
        out = self.proj(attn_out)
        out = out.permute(0, 2, 1).view(B, C, H, W)
        return out


class AxialAttention(nn.Module):
    """
    Applies attention along rows then along columns (axial attention).
    Input shape: (B, C, n, n)
    """

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.row_attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.col_attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(channels)
        self.ln2 = nn.LayerNorm(channels)

    def forward(self, x):
        B, C, H, W = x.shape
        # row-wise: treat each row as sequence of length W, with H separate sequences -> combine batch and H
        x_r = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        x_r = x_r.view(B * H, W, C)
        x_r_ln = self.ln1(x_r)
        out_r, _ = self.row_attn(x_r_ln, x_r_ln, x_r_ln)
        out_r = out_r.view(B, H, W, C).permute(0, 3, 1, 2)

        # column-wise: treat each column as sequence of length H
        x_c = x.permute(0, 3, 2, 1).contiguous()  # (B, W, H, C)
        x_c = x_c.view(B * W, H, C)
        x_c_ln = self.ln2(x_c)
        out_c, _ = self.col_attn(x_c_ln, x_c_ln, x_c_ln)
        out_c = out_c.view(B, W, H, C).permute(0, 3, 2, 1)

        # combine
        return out_r + out_c


# -----------------------------
# Full architecture
# -----------------------------
class BindingAffinityNet(nn.Module):
    def __init__(self, base_channels: int = 32, num_heads: int = 8, K: int = 3):
        """
        K is the hyperparameter requested by the user. It will be stored in the model
        so downstream code can use it for predicting number of binding regions later.
        """
        super().__init__()
        self.K = K

        # initial conv encoders for distance and plm-derived matrix
        self.conv_dist = ConvBlock(1, base_channels)
        self.conv_plm = ConvBlock(1, base_channels)
        self.conv_pae = ConvBlock(1, base_channels)

        # three repeated blocks: convs then cross-attention both ways
        self.repeats = nn.ModuleList()
        for _ in range(3):
            block = nn.ModuleDict({
                'conv_d': ConvBlock(base_channels, base_channels),
                'conv_p': ConvBlock(base_channels, base_channels),
                'conv_pa': ConvBlock(base_channels, base_channels),
                'cross_dp': CrossAttention2D(base_channels, num_heads),  # dist attends to plm
                'cross_pd': CrossAttention2D(base_channels, num_heads),  # plm attends to dist
            })
            self.repeats.append(block)

        # axial attention for final representations
        self.axial_dist = AxialAttention(base_channels, num_heads)
        self.axial_plm = AxialAttention(base_channels, num_heads)

        # projection and classification head
        self.proj = nn.Sequential(
            nn.Conv2d(base_channels * 2, base_channels * 2, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm2d(base_channels * 2),
        )

        # row-wise pooling (we'll pool across columns to get per-residue vectors of length n)
        self.head_mask = nn.Sequential(
            nn.Conv1d(base_channels * 2, base_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(base_channels, 1, kernel_size=1),
        )

        # second output: a global binary label (0/1) for protein-level property
        self.head_global = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(base_channels * 2, base_channels),
            nn.ReLU(),
            nn.Linear(base_channels, 1),
        )

    def forward(self, dist, plm, pae):
        # dist/plm/pae: (B, 1, n, n)
        d = self.conv_dist(dist)
        p = self.conv_plm(plm)
        pa = self.conv_pae(pae)

        # three repeats
        for block in self.repeats:
            d = block['conv_d'](d)
            p = block['conv_p'](p)
            pa = block['conv_pa'](pa)

            # cross attention between dist and plm (both directions)
            d_att = block['cross_dp'](d, p)
            p_att = block['cross_pd'](p, d)

            # residual style
            d = d + d_att
            p = p + p_att
            # optionally mix pae as auxiliary skip
            pa = pa + 0.5 * (d_att + p_att)

        # axial attention on dist and plm
        d_ax = self.axial_dist(d)
        p_ax = self.axial_plm(p)

        # concat and project
        concat = torch.cat([d_ax, p_ax], dim=1)
        x = self.proj(concat)

        # row-wise pooling: for each residue i (row i) we pool across columns
        # x shape (B, C, n, n) -> permute to (B, C, n, n) then treat columns as sequence for conv1d
        B, C, n, _ = x.shape
        row_vectors = x.mean(dim=3)  # average across columns -> (B, C, n)

        mask_logits = self.head_mask(row_vectors)  # expects (B, C, n) -> output (B, 1, n)
        mask_logits = mask_logits.squeeze(1)  # (B, n)

        global_logit = self.head_global(x)  # (B, 1)

        return mask_logits, global_logit.squeeze(1)


# -----------------------------
# Losses
# -----------------------------
class FocalLoss(nn.Module):
    # Binary focal loss for logits
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits: (B, n) or (B,)  targets: same shape 0/1
        prob = torch.sigmoid(logits)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_weight = alpha_factor * ((1 - p_t) ** self.gamma)
        bce = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction='none')
        loss = focal_weight * bce
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        prob = torch.sigmoid(logits)
        num = 2 * (prob * targets).sum(dim=-1)
        den = prob.sum(dim=-1) + targets.sum(dim=-1)
        loss = 1 - (num + self.eps) / (den + self.eps)
        return loss.mean()


# combined loss wrapper
class CombinedLoss(nn.Module):
    def __init__(self, weight_mask=1.0, weight_global=1.0, alpha=0.25, gamma=2.0):
        super().__init__()
        self.focal = FocalLoss(alpha=alpha, gamma=gamma)
        self.dice = DiceLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.w_mask = weight_mask
        self.w_global = weight_global

    def forward(self, mask_logits, mask_targets, global_logits, global_targets):
        # mask_targets: (B, n)  global_targets: (B,)  both 0/1
        focal_loss = self.focal(mask_logits, mask_targets)
        dice_loss = self.dice(mask_logits, mask_targets)
        mask_loss = 0.7 * focal_loss + 0.3 * dice_loss  # semantic + focal blended

        global_loss = self.bce(global_logits, global_targets.float())

        return self.w_mask * mask_loss + self.w_global * global_loss, {
            'mask_focal': focal_loss.item(),
            'mask_dice': dice_loss.item(),
            'global_bce': global_loss.item(),
        }


# -----------------------------
# Example training step (skeleton)
# -----------------------------

def training_step(model: nn.Module, batch, optimizer, loss_fn, device='cpu'):
    model.train()
    dist, plm, pae, mask_targets, global_targets, names = batch

    dist = dist.to(device)
    plm = plm.to(device)
    pae = pae.to(device)
    mask_targets = mask_targets.to(device)
    global_targets = global_targets.to(device)

    optimizer.zero_grad()
    mask_logits, global_logits = model(dist, plm, pae)
    loss, stats = loss_fn(mask_logits, mask_targets, global_logits, global_targets)
    loss.backward()
    optimizer.step()

    return loss.item(), stats



# -----------------------------
# How to instantiate and run
# -----------------------------


if __name__ == '__main__':
    base_path = '/content/drive/MyDrive/CLAPE-RESULTS'
    ds = CLAPEDataset(
        dist_dir=os.path.join(base_path, 'af-distancemaps'),
        plm_dir=os.path.join(base_path, 'transformed_matrices'),
        pae_dir=os.path.join(base_path, 'PAE-MATRICES'),
        label_csv=os.path.join(base_path, 'binding_sites_uniprot.csv')
    )
    dl = DataLoader(ds, batch_size=2, shuffle=True, num_workers=0)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = BindingAffinityNet(base_channels=32, num_heads=8, K=5).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = CombinedLoss()

    for i, batch in enumerate(dl):
        loss, stats = training_step(model, batch, opt, loss_fn, device)
        print(f"iter {i} loss={loss:.4f}", stats)
        if i >= 5:
            break

TEST

In [3]:
base_path = '/content/drive/MyDrive/CLAPE-RESULTS'

ds = CLAPEDataset(
    dist_dir=os.path.join(base_path, 'af-distancemaps'),
    plm_dir=os.path.join(base_path, 'transformed_matrices'),
    pae_dir=os.path.join(base_path, 'PAE-MATRICES'),
    label_csv=os.path.join(base_path, 'binding_sites_uniprot.csv'),
    file_list=['P16234-F1-model_v4']

sample = ds[0]
for x in sample:
    if isinstance(x, torch.Tensor):
        print(x.shape)
    else:
        print(x)


torch.Size([1, 1200, 1200])
torch.Size([1, 1200, 1200])
torch.Size([1, 1200, 1200])
torch.Size([1200])
torch.Size([])
P16234-F1-model_v4


In [4]:
dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)
batch = next(iter(dl))
for b in batch:
    if isinstance(b, torch.Tensor):
        print(b.shape)
    else:
        print(b)


torch.Size([1, 1, 1200, 1200])
torch.Size([1, 1, 1200, 1200])
torch.Size([1, 1, 1200, 1200])
torch.Size([1, 1200])
torch.Size([1])
('P16234-F1-model_v4',)


NameError: name 'dist' is not defined

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BindingAffinityNet(base_channels=8, num_heads=2, K=3).to(device)
dist, plm, pae, mask, global_label, name = [b.to(device) if isinstance(b, torch.Tensor) else b for b in batch]
print(dist.shape, plm.shape, pae.shape)
print(torch.isnan(dist).any(), torch.isinf(dist).any())
print(torch.isnan(plm).any(), torch.isinf(plm).any())
print(torch.isnan(pae).any(), torch.isinf(pae).any())
print(dist.device,next(model.parameters()).device)

dist_small = dist[:, :, :20, :20]
plm_small = plm[:, :, :20, :20]
pae_small = pae[:, :, :20, :20]

with torch.no_grad():
    mask_logits, global_logits = model(dist_small, plm_small, pae_small)



torch.Size([1, 1, 1200, 1200]) torch.Size([1, 1, 1200, 1200]) torch.Size([1, 1, 1200, 1200])
tensor(False) tensor(False)
tensor(False) tensor(False)
tensor(False) tensor(False)
cpu cpu


In [13]:
import importlib
import glob
importlib.reload(glob)


<module 'glob' from '/usr/lib/python3.12/glob.py'>

In [17]:
# -----------------------------
# Minimal test training loop
# -----------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instantiate model and loss
model = BindingAffinityNet(base_channels=8, num_heads=2, K=3).to(device)
loss_fn = CombinedLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Use your small batch (dist_small etc.)
dist_small = dist[:, :, :20, :20]
plm_small = plm[:, :, :20, :20]
pae_small = pae[:, :, :20, :20]
n_small = dist_small.shape[-1]
mask_small = mask[:, :n_small]        # ensure same size
global_small = global_label

dist_small = dist_small.to(device)
plm_small = plm_small.to(device)
pae_small = pae_small.to(device)
mask_small = mask_small.to(device)
global_small = global_small.to(device)


# Run a few steps
for step in range(5):
    model.train()
    optimizer.zero_grad()
    mask_logits, global_logits = model(dist_small, plm_small, pae_small)
    loss, stats = loss_fn(mask_logits, mask_small, global_logits, global_small)
    loss.backward()
    optimizer.step()

    print(f"Step {step} | Loss: {loss.item():.4f} | "
          f"mask_focal: {stats['mask_focal']:.4f}, "
          f"mask_dice: {stats['mask_dice']:.4f}, "
          f"global_bce: {stats['global_bce']:.4f}")


Step 0 | Loss: 1.0479 | mask_focal: 0.0908, mask_dice: 1.0000, global_bce: 0.6843
Step 1 | Loss: 1.0416 | mask_focal: 0.0844, mask_dice: 1.0000, global_bce: 0.6825
Step 2 | Loss: 1.0376 | mask_focal: 0.0814, mask_dice: 1.0000, global_bce: 0.6806
Step 3 | Loss: 1.0345 | mask_focal: 0.0796, mask_dice: 1.0000, global_bce: 0.6788
Step 4 | Loss: 1.0326 | mask_focal: 0.0796, mask_dice: 1.0000, global_bce: 0.6769
