<a href="https://colab.research.google.com/github/aidanmck650/Carbapenemase-ProteinMPNN-Investigation/blob/main/CNN_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install captum

In [None]:
import h5py
import numpy as np
import dask.array as da
from sklearn.model_selection import GroupShuffleSplit
from sklearn.model_selection import StratifiedGroupKFold, StratifiedShuffleSplit
from torch.utils.data import TensorDataset, DataLoader, Subset
import torch
import torch.nn as nn
import torch.optim as optim


In [None]:
import numpy as np, h5py, dask.array as da

data_path = '/content/drive/MyDrive/organised_data_new.h5'
ambler_min, ambler_max = 29, 287  # inclusive Ambler-numbering window to keep


def _first_occurrence(amb, candidate_idx):
    """Keep only the first residue (in sequence order) for each Ambler base.

    Used for runs without an ``ambler_ins_rank`` dataset, where residues sharing an
    Ambler base cannot be told apart; the first one encountered is kept.

    Args:
        amb: per-residue Ambler indices of the run.
        candidate_idx: increasing residue positions (into ``amb``) to filter.

    Returns:
        Integer array: the subsequence of ``candidate_idx`` with one entry per distinct
        ``int(amb[i])``.
    """
    if candidate_idx.size == 0:
        return candidate_idx.astype(int)
    # np.unique reports the position of each value's first occurrence; sorting those
    # positions restores sequence order.
    _, first_pos = np.unique(amb[candidate_idx].astype(int), return_index=True)
    return candidate_idx[np.sort(first_pos)].astype(int)


def _crop_run(grp):
    """Load one run, restricted to the Ambler window.

    Returns:
        (X_crop, amb_cols): ``grp['probs']`` (n_frames, 20, n_res) with only the kept residue
        columns, and the Ambler base of each kept column. Runs that have
        ``ambler_ins_rank`` keep canonical residues only (rank 0); otherwise the first
        residue seen for each base is kept.
    """
    probs_ds = grp['probs']
    _, _, n_res = probs_ds.shape
    probs = da.from_array(probs_ds, chunks=(1000, 20, n_res))

    amb = grp['ambler_idx'][:]
    in_window = (amb >= ambler_min) & (amb <= ambler_max)

    if 'ambler_ins_rank' in grp:
        keep_idx = np.flatnonzero(in_window & (grp['ambler_ins_rank'][:] == 0))
    else:
        keep_idx = _first_occurrence(amb, np.flatnonzero(in_window))

    amb_cols = amb[keep_idx].astype(int)
    X_crop = probs[:, :, keep_idx].compute()
    return X_crop, amb_cols


# Per-run holders: frames (n_frames, 20, L_i), their Ambler bases (L_i,), labels, enzyme ids.
samples_X, samples_A, samples_y, samples_grp = [], [], [], []

with h5py.File(data_path, 'r') as f:
    enzyme_names = list(f.keys())

    for enzyme_id, enzyme in enumerate(enzyme_names):
        label = int(f[enzyme].attrs['carbapenemase'])

        for run in f[enzyme].keys():
            X_crop, amb_cols = _crop_run(f[enzyme][run])
            n_frames = X_crop.shape[0]
            samples_X.append(X_crop)
            samples_A.append(amb_cols)
            samples_y.append(np.full(n_frames, label, dtype=np.int8))
            samples_grp.append(np.full(n_frames, enzyme_id, dtype=np.int32))

if not samples_X:
    raise ValueError(f"No runs found in {data_path}")

# Align columns across runs: keep the Ambler bases present in every run.
# np.unique / np.intersect1d both return sorted unique values.
A_master = np.unique(samples_A[0])
for A_i in samples_A[1:]:
    A_master = np.intersect1d(A_master, A_i)
if A_master.size == 0:
    raise ValueError("No Ambler base within the window is shared by every run")

# Fill one preallocated array instead of building a second full-size list and then
# concatenating it. This keeps peak memory at ~2x the dataset instead of ~3x.
n_total = sum(X_i.shape[0] for X_i in samples_X)
out_dtype = np.result_type(*{X_i.dtype for X_i in samples_X})
X = np.empty((n_total, samples_X[0].shape[1], A_master.size), dtype=out_dtype)

start = 0
for X_i, A_i in zip(samples_X, samples_A):
    col_of_base = {int(a): j for j, a in enumerate(A_i)}
    col_idx = np.array([col_of_base[int(a)] for a in A_master], dtype=int)
    stop = start + X_i.shape[0]
    X[start:stop] = X_i[:, :, col_idx]
    start = stop

# Final arrays
A = A_master                                   # (L,) Ambler bases for columns
y = np.concatenate(samples_y)
groups = np.concatenate(samples_grp)
assert X.shape[0] == y.shape[0] == groups.shape[0]

print("Final dataset:", X.shape, y.shape, groups.shape)


Final dataset: (120000, 20, 252) (120000,) (120000,)


In [None]:
# Tensor views of the aligned arrays: float32 inputs and int64 targets
# (the int64 dtype is what nn.CrossEntropyLoss expects for class indices).
X_t = torch.from_numpy(X).to(torch.float32)
y_t = torch.from_numpy(y).to(torch.int64)
dataset = TensorDataset(X_t, y_t)

In [None]:
# Pick the GPU when one is visible, then seed torch's RNGs (CPU, and CUDA if used).
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
torch.manual_seed(42)
if _use_cuda:
    torch.cuda.manual_seed_all(42)

class SmallCNN(nn.Module):
    """Three-stage 1-D CNN mapping (batch, n_channels, length) inputs to class logits.

    Each stage is Conv1d -> BatchNorm1d -> ReLU -> pooling: max-pool by 2 after the first two
    stages and a global average pool after the third. The head is a two-layer MLP with
    heavy (p=0.7) dropout.
    """

    def __init__(self, n_channels=20, n_classes=2):
        super().__init__()

        def stage(c_in, c_out, k, pool):
            # padding=k//2 keeps the length unchanged for these odd kernel sizes
            return [
                nn.Conv1d(c_in, c_out, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
                pool,
            ]

        self.features = nn.Sequential(
            *stage(n_channels, 32, 7, nn.MaxPool1d(2)),
            *stage(32, 64, 5, nn.MaxPool1d(2)),
            *stage(64, 128, 3, nn.AdaptiveAvgPool1d(1)),  # -> (B, 128, 1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                                  # -> (B, 128)
            nn.Dropout(0.7),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

In [None]:
import torch
import torch.nn as nn

seq_len = X.shape[-1]  # number of aligned Ambler columns per frame


class CNNTransformerSimple(nn.Module):
    """Conv1d front-end, one Transformer encoder layer, then a pooled MLP classifier.

    Args:
        n_channels: input channels per position (20 per-residue probabilities in this notebook).
        d_model: width of the per-position features fed to the Transformer.
        nhead: number of attention heads (must divide ``d_model``).
        n_classes: number of output logits.
        seq_len: maximum sequence length; sizes the learned positional-embedding table.
            The default is bound when this cell runs, from the loaded ``X``.

    Input ``x`` is (batch, n_channels, L) with 1 <= L <= seq_len; output is (batch, n_classes).
    """

    def __init__(self, n_channels=20, d_model=64, nhead=4, n_classes=2, seq_len=seq_len):
        super().__init__()
        # Project n_channels -> d_model features per position; padding keeps the length.
        self.conv_block = nn.Sequential(
            nn.Conv1d(n_channels, d_model, kernel_size=7, padding=3),
            nn.BatchNorm1d(d_model),
            nn.ReLU(inplace=True),
            nn.Conv1d(d_model, d_model, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model),
            nn.ReLU(inplace=True),
        )
        # Learnable positional embeddings, one row per position up to seq_len.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))

        # Single-layer transformer encoder operating on (B, L, d_model).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 2,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)

        # Global average pooling over positions + MLP head.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),  # (B, d_model, L) -> (B, d_model, 1)
            nn.Flatten(),             # -> (B, d_model)
            nn.Dropout(0.3),
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(d_model // 2, n_classes),
        )

    def forward(self, x):
        h = self.conv_block(x)                       # (B, d_model, L)
        length = h.size(-1)
        max_len = self.pos_embed.size(1)
        if length > max_len:
            raise ValueError(
                f"sequence length {length} exceeds the model's seq_len {max_len}")
        # Slice the table to L positions. Adding the whole table would silently broadcast
        # an L == 1 input to seq_len positions, and would fail for any other shorter L.
        h = h.permute(0, 2, 1) + self.pos_embed[:, :length]   # (B, L, d_model)
        h = self.transformer(h)                      # (B, L, d_model)
        # Back to (B, d_model, L) for the pooled classifier.
        return self.classifier(h.permute(0, 2, 1))   # (B, n_classes)


In [None]:
# Leave-one-enzyme-out evaluation: for each enzyme, train a fresh model on every other
# enzyme's frames and report accuracy on the held-out enzyme's frames.
# NOTE: every frame of an enzyme carries that enzyme's label, so each test fold is single-class.
BATCH = 256
EPOCHS = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loso_acc = {}  # enzyme name -> held-out accuracy
for test_eid, test_name in enumerate(enzyme_names):
    print(f"→ Testing on enzyme {test_eid}: {test_name}")
    # split frame indices by enzyme
    train_idx = np.flatnonzero(groups != test_eid)
    test_idx = np.flatnonzero(groups == test_eid)
    if test_idx.size == 0:
        print(f"Skipping {test_name}: it has no frames")
        continue

    train_loader = DataLoader(Subset(dataset, train_idx),
                              batch_size=BATCH, shuffle=True)
    test_loader = DataLoader(Subset(dataset, test_idx),
                             batch_size=BATCH)

    # Same seed per fold so folds differ only in their data.
    torch.manual_seed(0)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(0)

    model = CNNTransformerSimple().to(device)
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    for epoch in range(1, EPOCHS + 1):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            opt.step()

    # final test eval on the held-out enzyme
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            preds = model(xb).argmax(1)
            correct += (preds == yb).sum().item()
            total += xb.size(0)

    loso_acc[test_name] = correct / total
    print(f"Test on {test_name}  Acc: {correct/total:.4f}")

NameError: name 'enzyme_names' is not defined

In [None]:
from torch.utils.data import DataLoader

# Fit a single model on every frame (no held-out split). The attribution cell later in
# the notebook uses this `model`, so the accuracy printed here is training accuracy only.
full_loader = DataLoader(dataset, batch_size=256, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Re-seed right before building the model so initial weights and batch order repeat.
torch.manual_seed(0)
if device.type == "cuda":
    torch.cuda.manual_seed_all(0)

model = CNNTransformerSimple().to(device)
lossf = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    model.train()
    epoch_loss, epoch_hits, epoch_seen = 0, 0, 0
    for xb, yb in full_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        opt.zero_grad()
        logits = model(xb)
        loss = lossf(logits, yb)
        loss.backward()
        opt.step()

        n_batch = xb.size(0)
        epoch_loss += loss.item() * n_batch
        epoch_hits += (logits.argmax(1) == yb).sum().item()
        epoch_seen += n_batch

    mean_loss = epoch_loss / epoch_seen
    mean_acc = epoch_hits / epoch_seen
    print(f"Epoch {epoch:02d} | Loss: {mean_loss:.4f} | Acc: {mean_acc:.4f}")

Epoch 01 | Loss: 0.0251 | Acc: 0.9925
Epoch 02 | Loss: 0.0003 | Acc: 0.9999
Epoch 03 | Loss: 0.0001 | Acc: 1.0000
Epoch 04 | Loss: 0.0000 | Acc: 1.0000
Epoch 05 | Loss: 0.0000 | Acc: 1.0000
Epoch 06 | Loss: 0.0000 | Acc: 1.0000
Epoch 07 | Loss: 0.0000 | Acc: 1.0000
Epoch 08 | Loss: 0.0000 | Acc: 1.0000
Epoch 09 | Loss: 0.0000 | Acc: 1.0000
Epoch 10 | Loss: 0.0000 | Acc: 1.0000


In [None]:
from captum.attr import (
    Saliency,
    IntegratedGradients,
    NoiseTunnel,
    DeepLift,
    LayerGradCam
)

In [None]:
model.eval()

# Seed what this cell samples, so results do not depend on earlier kernel state:
# - which frames are explained (loader shuffle, via a dedicated generator);
# - SmoothGrad noise, presumably drawn from torch's global RNG (hence manual_seed).
ATTR_SEED = 0
torch.manual_seed(ATTR_SEED)
loader = DataLoader(TensorDataset(X_t, y_t), batch_size=256, shuffle=True,
                    generator=torch.Generator().manual_seed(ATTR_SEED))

# Captum attributions (model is assumed to be CNNTransformerSimple: conv_block is used below)
saliency = Saliency(model)
ig = IntegratedGradients(model)
nt = NoiseTunnel(saliency)
deeplift = DeepLift(model)

# Grad-CAM on the second Conv1d of the conv front-end (conv_block[3])
target_layer = model.conv_block[3]
gradcam = LayerGradCam(model, target_layer)

# Baseline for IG/DeepLIFT: a uniform distribution over input channels at every position
n_channels, n_positions = X_t.shape[1], X_t.shape[2]
baseline = torch.full((1, n_channels, n_positions), 1.0 / n_channels, device=device)

# Every method explains the class-1 logit, for all sampled frames regardless of their label.
methods = {
    'saliency': lambda inp: saliency.attribute(inp, target=1),
    'smooth':   lambda inp: nt.attribute(inp, target=1,
                        nt_type='smoothgrad', stdevs=0.02, nt_samples=25),
    'ig':       lambda inp: ig.attribute(inp, baselines=baseline,
                        target=1, n_steps=100),
    'deeplift': lambda inp: deeplift.attribute(inp, baselines=baseline, target=1),
    'gradcam':  lambda inp: gradcam.attribute(inp, target=1)
}

scores = {name: [] for name in methods}
max_frames = 500


def _frames(data_loader, limit):
    """Yield at most `limit` single frames, each (1, C, L), on `device` with grad enabled."""
    produced = 0
    for X_batch, _ in data_loader:
        for frame in X_batch.to(device):
            if produced >= limit:
                return
            produced += 1
            yield frame.unsqueeze(0).requires_grad_(True)


for inp in _frames(loader, max_frames):
    for name, fn in methods.items():
        attr = fn(inp)
        if name == 'gradcam':
            # Per Captum's docs LayerGradCam sums over the layer's channels, so attr
            # should already be (1, 1, L); mean over dim 1 just drops that axis.
            # Grad-CAM is kept signed (no abs), unlike the other methods.
            res = attr.mean(dim=1).squeeze(0)
        else:
            # (1, C, L): total absolute attribution per position, summed over channels
            res = attr.abs().sum(dim=1).squeeze(0)
        scores[name].append(res.detach().cpu())

# Average over frames, min-max normalise, and report the top Ambler positions
top_k = 10
top_residues = {}  # method name -> top Ambler bases
for name, arr in scores.items():
    if not arr:
        print(f"\n{name.upper()}: no frames were attributed")
        continue
    avg_imp = torch.stack(arr).mean(dim=0)                 # (L,)
    # epsilon guards against a perfectly flat importance map
    norm_imp = (avg_imp - avg_imp.min()) / (avg_imp.max() - avg_imp.min() + 1e-8)
    vals, idxs = torch.topk(norm_imp, k=min(top_k, norm_imp.numel()))
    top_residues[name] = A[idxs.cpu().numpy()].tolist()
    print(f"\n{name.upper()} Top Ambler residues:", top_residues[name])
    print(f"{name.upper()} Scores: {vals.tolist()}")


SALIENCY Top Ambler residues (0‑based): [238, 239, 241, 240, 237, 236, 242, 235, 243, 234]
SALIENCY Scores: [1.0, 0.9371141195297241, 0.811948835849762, 0.7939862608909607, 0.6581747531890869, 0.6459794640541077, 0.6450086832046509, 0.6336749196052551, 0.554413914680481, 0.5061284899711609]

SMOOTH Top Ambler residues (0‑based): [238, 239, 241, 240, 237, 242, 236, 235, 243, 234]
SMOOTH Scores: [1.0, 0.9396723508834839, 0.8151755332946777, 0.7927833199501038, 0.6589153409004211, 0.6503751277923584, 0.6456523537635803, 0.6320083141326904, 0.557887852191925, 0.5089046359062195]

IG Top Ambler residues (0‑based): [241, 238, 240, 236, 239, 244, 242, 235, 60, 237]
IG Scores: [1.0, 0.9314240217208862, 0.9175805449485779, 0.5742488503456116, 0.41367408633232117, 0.363450288772583, 0.36319872736930847, 0.2897765338420868, 0.2771315276622772, 0.2668672800064087]

DEEPLIFT Top Ambler residues (0‑based): [240, 241, 238, 236, 239, 242, 244, 235, 243, 237]
DEEPLIFT Scores: [1.0, 0.9810715317726135,

-----------------------