### Vision Transformer
##### This notebook contains all modules for training and evaluating the ViT model.

In [None]:
import os
import torch
from bconformer import embed, utils
import numpy as np
import math
import sys
import gc
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
from Bio.PDB import PDBParser, is_aa, Polypeptide
from Bio.PDB.NeighborSearch import NeighborSearch
from Bio.PDB.Selection import unfold_entities 
from bconformer.embed import Alphabet
from typing import List
from tqdm import tqdm
from functools import partial
from torch.nn.init import trunc_normal_
from sklearn.metrics import (
    roc_auc_score, average_precision_score, matthews_corrcoef,
    f1_score, precision_score, recall_score, accuracy_score,
    brier_score_loss, log_loss
)
from typing import Iterable, Optional
from timm.layers import DropPath
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from ptflops import get_model_complexity_info

In [None]:
# Three-letter residue name -> one-letter amino-acid code.
# Besides the 20 standard residues the table also carries SEC/PYL and the
# ambiguity/unknown codes ASX, GLX and UNK.
three_to_one_dict = dict(
    ALA='A',
    CYS='C',
    ASP='D',
    GLU='E',
    PHE='F',
    GLY='G',
    HIS='H',
    ILE='I',
    LYS='K',
    LEU='L',
    MET='M',
    ASN='N',
    PRO='P',
    GLN='Q',
    ARG='R',
    SER='S',
    THR='T',
    VAL='V',
    TRP='W',
    TYR='Y',
    SEC='U',
    PYL='O',
    ASX='B',
    GLX='Z',
    UNK='X',
)

In [None]:
# Input locations; replace the placeholders with real directories before running.
fasta_files = "..." # directory containing training (or evaluating) fastas
pdb_files = "..." # directory containing training (or evaluating) pdbs

# Count the files of each kind; the trailing tuple is displayed by the notebook.
num_fasta = sum(1 for f in os.listdir(fasta_files) if f.endswith('.fasta'))
num_pdb = sum(1 for f in os.listdir(pdb_files) if f.endswith('.pdb'))
num_fasta, num_pdb

### 1. Data

In [4]:
def parse_chains_from_fasta_name(fasta_name):
    """Split a FASTA file name into its antigen and antibody chain tokens.

    The name (with ``.fasta`` removed) is split on ``_``. Tokens between the
    ``ag`` and ``ab`` markers are the antigen chains; tokens after ``ab`` are
    the antibody chains.

    Returns:
        tuple[list[str], list[str]]: ``(antigen_chains, antibody_chains)``.

    Raises:
        ValueError: if the ``ag`` or ``ab`` marker is absent from the name.
    """
    tokens = fasta_name.replace('.fasta', '').split('_')
    ag_pos, ab_pos = tokens.index('ag'), tokens.index('ab')
    return tokens[ag_pos + 1:ab_pos], tokens[ab_pos + 1:]

In [5]:
def get_atoms(chains):
    """Return a flat list of all atoms in ``chains`` whose element is not 'H'."""
    heavy_atoms = []
    for chain in chains:
        # 'A' asks unfold_entities for the atom level of the hierarchy.
        for atom in unfold_entities(chain, 'A'):
            if atom.element == 'H':
                continue
            heavy_atoms.append(atom)
    return heavy_atoms

In [6]:
def get_epitope_labels(antigen_chain_objs, antibody_chain_objs, cutoff=4):
    """Label each antigen amino-acid residue as epitope (1) or non-epitope (0).

    A residue is an epitope residue when at least one of its atoms has a
    non-hydrogen antibody atom within ``cutoff`` of it.

    Args:
        antigen_chain_objs: iterable of antigen chain objects; residues are
            visited chain by chain in ``chain.get_residues()`` order.
        antibody_chain_objs: iterable of antibody chain objects.
        cutoff: search radius passed to ``NeighborSearch.search``; defaults to
            the previously hard-coded value of 4.

    Returns:
        torch.Tensor: 1-D ``torch.long`` tensor with one 0/1 entry per antigen
        residue that passes ``is_aa``.
    """
    antibody_atoms = get_atoms(antibody_chain_objs)
    # NeighborSearch is not expected to accept an empty atom list; with no
    # antibody atoms nothing can be in contact, so every label is 0.
    ns = NeighborSearch(antibody_atoms) if antibody_atoms else None

    labels = []
    for chain in antigen_chain_objs:
        for res in chain.get_residues():
            if not is_aa(res):
                continue
            # any() stops at the first atom that has an antibody neighbour.
            in_contact = ns is not None and any(
                ns.search(atom.coord, cutoff) for atom in res
            )
            labels.append(1 if in_contact else 0)
    return torch.tensor(labels, dtype=torch.long)

In [7]:
def esm_embed_sequences(sequences, model, alphabet, device, repr_layer=33):
    """Embed each sequence with the language model and concatenate the results.

    Every sequence is run through ``model`` on its own (batch of one). The
    first and last token positions (BOS/EOS) are stripped, and the per-residue
    embeddings of all sequences are concatenated along dimension 0.

    Args:
        sequences: iterable of amino-acid strings.
        model: callable invoked as
            ``model(tokens, repr_layers=[...], return_contacts=False)`` and
            returning a dict with a ``"representations"`` entry.
        alphabet: object exposing ``get_batch_converter()``.
        device: device the token tensor is moved to before the forward pass.
        repr_layer: index of the layer whose representations are extracted;
            defaults to the previously hard-coded 33.

    Returns:
        torch.Tensor: CPU tensor holding the stacked per-residue embeddings.
        ``torch.cat`` raises if ``sequences`` is empty.
    """
    # The converter does not depend on the sequence, so build it only once.
    batch_converter = alphabet.get_batch_converter()

    embeddings = []
    for seq in sequences:
        _, _, batch_tokens = batch_converter([("protein", seq)])
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
        token_embeddings = results["representations"][repr_layer]
        # Remove BOS and EOS tokens
        embeddings.append(token_embeddings[0, 1:-1].cpu())
    return torch.cat(embeddings, dim=0)

In [8]:
class EpitopeDataset(Dataset):
    """Dataset of antigen/antibody complexes for per-residue epitope labelling.

    Each item corresponds to one ``.fasta`` file in ``fasta_dir``. The file
    name is parsed for the antigen/antibody chain lists, and the matching
    ``.pdb`` file is looked up in ``pdb_dir``. The structure's first model
    supplies the antigen sequence(s), their embeddings and the contact labels.

    NOTE(review): chains in the PDB model are assigned by *sorted chain id*.
    The first ``len(antigen_chains)`` ids are taken as antigen and the rest as
    antibody; the chain names parsed from the FASTA file name are only used
    for their counts. Confirm this matches how the PDB files were prepared.
    """

    # Chain totals over all FASTA files; overwritten whenever an instance is built.
    TOTAL_ANTIGEN_CHAINS = 0
    TOTAL_ANTIBODY_CHAINS = 0

    def __init__(self, fasta_dir, pdb_dir, esm_model, esm_alphabet, device):
        """Index the FASTA/PDB directories and count chains from the file names.

        Args:
            fasta_dir: directory scanned for ``*.fasta`` files.
            pdb_dir: directory scanned for ``*.pdb`` files.
            esm_model: model forwarded to ``esm_embed_sequences``.
            esm_alphabet: alphabet forwarded to ``esm_embed_sequences``.
            device: device forwarded to ``esm_embed_sequences``.
        """
        self.fasta_dir = fasta_dir
        self.pdb_dir = pdb_dir
        self.esm_model = esm_model
        self.esm_alphabet = esm_alphabet
        self.device = device

        self.fasta_files = sorted(f for f in os.listdir(fasta_dir) if f.endswith('.fasta'))
        self.pdb_files = sorted(f for f in os.listdir(pdb_dir) if f.endswith('.pdb'))

        # idx -> number of antigen residues, filled in lazily by __getitem__.
        self.antigen_len_cache = {}

        total_ag = 0
        total_ab = 0
        for fasta_file in self.fasta_files:
            ag_chains, ab_chains = parse_chains_from_fasta_name(fasta_file)
            total_ag += len(ag_chains)
            total_ab += len(ab_chains)

        EpitopeDataset.TOTAL_ANTIGEN_CHAINS = total_ag
        EpitopeDataset.TOTAL_ANTIBODY_CHAINS = total_ab

    def __len__(self):
        """Return the number of FASTA files, i.e. the number of complexes."""
        return len(self.fasta_files)

    def _find_pdb_path(self, fasta_name):
        """Return the path of the PDB file that belongs to ``fasta_name``.

        A PDB whose stem equals the FASTA stem wins. Otherwise the first PDB
        (in sorted order) whose name contains the FASTA stem is used.

        Raises:
            ValueError: if no PDB file name contains the FASTA stem.
        """
        fasta_id = os.path.splitext(fasta_name)[0]
        candidates = [f for f in self.pdb_files if fasta_id in f]
        if not candidates:
            raise ValueError(f"No matching pdb file found for {fasta_name}")
        for candidate in candidates:
            # Prefer the exact stem so "x_ab_H" never resolves to "x_ab_HL.pdb".
            if os.path.splitext(candidate)[0] == fasta_id:
                return os.path.join(self.pdb_dir, candidate)
        return os.path.join(self.pdb_dir, candidates[0])

    def __getitem__(self, idx):
        """Build the embedding, labels and mask for complex number ``idx``.

        Returns:
            dict with keys ``'embedding'`` (output of ``esm_embed_sequences``
            for the antigen chains), ``'labels'`` (output of
            ``get_epitope_labels``), ``'mask'`` (all-True bool tensor of the
            same length as the labels) and ``'antigen_length'`` (number of
            antigen residues passing ``is_aa``).

        Raises:
            ValueError: if no PDB file matches, or if the number of chains in
                the PDB model differs from the number encoded in the FASTA
                file name.
        """
        fasta_name = self.fasta_files[idx]
        matched_pdb_file = self._find_pdb_path(fasta_name)
        antigen_chains, antibody_chains = parse_chains_from_fasta_name(fasta_name)

        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", matched_pdb_file)
        model = structure[0]

        sorted_chain_ids = sorted(chain.id for chain in model)
        expected_chains = len(antigen_chains) + len(antibody_chains)
        # Explicit raise instead of assert so the check survives `python -O`.
        if len(sorted_chain_ids) != expected_chains:
            raise ValueError(
                f"{matched_pdb_file} has {len(sorted_chain_ids)} chains but "
                f"{fasta_name} encodes {expected_chains}"
            )

        n_antigen = len(antigen_chains)
        antigen_chains_objs = [model[c] for c in sorted_chain_ids[:n_antigen]]
        antibody_chains_objs = [model[c] for c in sorted_chain_ids[n_antigen:]]

        # Antigen sequence: one string per chain; unknown residue names map to 'X'.
        antigen_sequences = [
            "".join(
                three_to_one_dict.get(residue.get_resname(), 'X')
                for residue in chain.get_residues()
                if is_aa(residue)
            )
            for chain in antigen_chains_objs
        ]

        # Antigen length: every is_aa residue contributes exactly one letter above.
        if idx not in self.antigen_len_cache:
            self.antigen_len_cache[idx] = sum(len(seq) for seq in antigen_sequences)
        antigen_length = self.antigen_len_cache[idx]

        embedding = esm_embed_sequences(
            antigen_sequences, self.esm_model, self.esm_alphabet, self.device
        )

        labels = get_epitope_labels(antigen_chains_objs, antibody_chains_objs)
        mask = torch.ones(labels.shape[0], dtype=torch.bool)

        return {
            'embedding': embedding,
            'labels': labels,
            'mask': mask,
            'antigen_length': antigen_length
        }

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): `esm` is not among the imports visible in this notebook —
# confirm whether an `import esm` is missing or `bconformer.embed` was meant.
esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Inference only: move to the device and switch to eval mode in one chain.
esm_model = esm_model.to(device).eval()

dataset = EpitopeDataset(
    fasta_dir=fasta_files,
    pdb_dir=pdb_files,
    esm_model=esm_model,
    esm_alphabet=esm_alphabet,
    device=device,
)

# The totals are class attributes populated by the constructor call above.
print("Number of antigen chains:", EpitopeDataset.TOTAL_ANTIGEN_CHAINS)
print("Number of antibody chains:", EpitopeDataset.TOTAL_ANTIBODY_CHAINS)

In [10]:
max_seq_len = 1024

def collate_fn_padding(batch, max_len=None):
    """Pad every item of ``batch`` to a common length and stack into tensors.

    Args:
        batch: list of dicts with ``'embedding'`` (2-D, length first),
            ``'labels'`` (1-D) and ``'mask'`` (1-D) tensors, as produced by
            the dataset.
        max_len: target length; defaults to the module-level ``max_seq_len``.

    Returns:
        dict with ``"embedding"`` (zero-padded), ``"labels"`` (padded with
        -100), ``"mask"`` (padded with 0/False) and ``"attention_mask"``
        (float tensor: 1 for real positions, 0 for padding), each stacked
        along a new leading batch dimension.

    Raises:
        RuntimeError: if no item of the batch fits into ``max_len``.
    """
    if max_len is None:
        max_len = max_seq_len

    batch_embeddings = []
    batch_labels = []
    batch_masks = []
    attn_masks = []

    for item in batch:
        L = item['embedding'].shape[0]
        pad_len = max_len - L
        if pad_len < 0:
            # Items longer than max_len are dropped from the batch, not truncated.
            continue

        batch_embeddings.append(F.pad(item['embedding'], (0, 0, 0, pad_len), value=0))
        batch_labels.append(F.pad(item['labels'], (0, pad_len), value=-100))
        # Concatenate zeros of the mask's own dtype instead of padding, so a
        # bool mask stays bool.
        batch_masks.append(torch.cat([item['mask'], item['mask'].new_zeros(pad_len)]))
        attn_masks.append(torch.cat([torch.ones(L), torch.zeros(pad_len)]))

    if not batch_embeddings:
        # torch.stack on an empty list fails with an unhelpful message.
        raise RuntimeError(
            f"collate_fn_padding: all {len(batch)} item(s) in the batch exceed "
            f"max_len={max_len}; nothing left to stack"
        )

    return {
        "embedding": torch.stack(batch_embeddings),
        "labels": torch.stack(batch_labels),
        "mask": torch.stack(batch_masks),
        "attention_mask": torch.stack(attn_masks)
    }


In [None]:
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn_padding,
)

# Sanity check: inspect the first batch only, then stop iterating.
for i, batch in enumerate(dataloader):
    # Expected shapes: embedding [B, max_len, 1280]; the other three [B, max_len].
    embedding, labels, mask, attention_mask = (
        batch[key] for key in ("embedding", "labels", "mask", "attention_mask")
    )

    print(f"Embedding shape: {embedding.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Mask shape: {mask.shape}")
    print(f"Attention mask shape: {attention_mask.shape}")

    # Dump the first sequence of the batch in full.
    print("\n=== A sequence sample ===")
    print(f"Embedding shape: {embedding[0].shape}")  # [max_len, 1280]
    print(f"Embedding:\n{embedding[0]}")
    print(f"Labels:\n{labels[0]}")
    print(f"Mask:\n{mask[0]}")

    break

### 2. Model

In [12]:
class PatchEmbedding1D(nn.Module):
    """Cut a sequence into non-overlapping patches and project each linearly.

    Args:
        in_dim: feature size of each input position.
        patch_size: number of consecutive positions merged into one patch.
        embed_dim: output feature size of each patch.
        seq_len: nominal sequence length; only used to expose
            ``num_patches`` (= ``seq_len // patch_size``) to callers.
    """

    def __init__(self, in_dim=1280, patch_size=1, embed_dim=1536, seq_len=1024):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = seq_len // patch_size
        self.proj = nn.Linear(in_dim * patch_size, embed_dim)

    def forward(self, x):
        """Map ``x`` of shape [B, L, C] to [B, L // patch_size, embed_dim].

        The number of patches is taken from the input itself, so sequences
        whose length differs from the constructor's ``seq_len`` also work.
        Trailing positions that do not fill a whole patch are dropped by
        ``unfold``.
        """
        # unfold appends the window as the LAST dimension:
        # [B, L, C] -> [B, n_patches, C, patch_size]
        x = x.unfold(dimension=1, size=self.patch_size, step=self.patch_size)
        # Flatten each patch (channel-major, as before): [B, n_patches, C * patch_size]
        x = x.flatten(start_dim=2)
        return self.proj(x)  # [B, n_patches, embed_dim]

In [13]:
class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are omitted (or otherwise falsy). Both dropout steps share a single
    ``nn.Dropout`` module with probability ``drop``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the two-layer perceptron to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

In [14]:
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        dim: feature size of the input and of the output.
        num_heads: number of heads; each head works on ``dim // num_heads``
            features.
        qkv_bias: whether the fused QKV projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=6, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Scores are scaled by 1 / sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of ``x`` ([B, N, C]); returns [B, N, C]."""
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # [B, N, 3C] -> [3, B, heads, N, head_dim], then split off q / k / v.
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        weights = self.attn_drop(torch.softmax(scores, dim=-1))

        # Merge the heads back into the channel dimension.
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

In [15]:
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        dim: token feature size.
        num_heads: number of attention heads.
        mlp_ratio: hidden size of the MLP as a multiple of ``dim``.
        qkv_bias: whether the attention QKV projection has a bias.
        drop: dropout probability for the attention output projection and the MLP.
        attn_drop: dropout probability for the attention weights.
        drop_path: stochastic-depth rate applied to both residual branches;
            0 (the default) disables it.
    """

    def __init__(self, dim, num_heads=6, mlp_ratio=2., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, mlp_hidden_dim, drop=drop)
        # Honour the drop_path argument; it used to be accepted but ignored.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply attention then MLP to ``x``; the output has the same shape."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

In [16]:
class ViT_1D(nn.Module):
    """1-D Vision-Transformer-style encoder with a per-token classification head.

    The input is patch-embedded, prefixed with a learnable class token and
    given a learnable positional embedding. It then passes through ``depth``
    encoder layers and a final LayerNorm. The class token is dropped before
    the linear head, so one prediction is emitted per patch.
    """

    def __init__(self, in_dim=1280, seq_len=1024, patch_size=1, embed_dim=1536,
                 depth=9, num_heads=12, num_classes=2, mlp_ratio=2.):
        super().__init__()
        # One position per patch plus one for the class token.
        num_positions = seq_len // patch_size + 1

        self.patch_embed = PatchEmbedding1D(in_dim, patch_size, embed_dim, seq_len)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        self.pos_drop = nn.Dropout(p=0.)

        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio))
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for pos_embed, cls_token and head weight; zero head bias."""
        for tensor in (self.pos_embed, self.cls_token, self.head.weight):
            nn.init.trunc_normal_(tensor, std=.02)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def forward(self, x):
        """Return per-patch class logits of shape [B, num_classes, num_patches]."""
        tokens = self.patch_embed(x)  # [B, num_patches, embed_dim]
        cls = self.cls_token.expand(x.shape[0], -1, -1)  # [B, 1, embed_dim]
        tokens = torch.cat((cls, tokens), dim=1)  # [B, 1 + num_patches, embed_dim]
        tokens = self.pos_drop(tokens + self.pos_embed)

        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)  # [B, 1 + num_patches, embed_dim]

        # Skip position 0 (the class token) and classify every remaining token.
        logits = self.head(tokens[:, 1:])  # [B, num_patches, num_classes]
        return logits.permute(0, 2, 1)     # [B, num_classes, num_patches]


### 3. Train and Evaluate

In [17]:
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler=None, max_norm: float = 0,
                    model_ema: Optional[object] = None, mixup_fn=None,
                    set_training_mode=True):
    """Run one training epoch and return the averaged meters.

    Args:
        model: network called as ``model(samples)``.
        criterion: loss callable invoked as ``criterion(output, targets, mask)``.
        data_loader: iterable of batches with ``'embedding'``, ``'labels'``
            and ``'mask'`` entries.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        epoch: epoch index, used only in the log header.
        loss_scaler: optional ``GradScaler``-like object (``scale``,
            ``unscale_``, ``step``, ``update``); when ``None`` a plain
            backward/step is performed.
        max_norm: gradient-norm clipping threshold; values <= 0 disable clipping.
        model_ema: optional object whose ``update(model)`` is called per step.
        mixup_fn: accepted for signature compatibility; unused here.
        set_training_mode: forwarded to ``model.train``.

    Returns:
        dict mapping each meter name (``'loss'``, ``'lr'``) to its global average.
    """
    model.train(set_training_mode)
    if hasattr(criterion, 'train'):
        criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f"Epoch: [{epoch}]"
    print_freq = 10
    clip_gradients = max_norm is not None and max_norm > 0

    for batch in metric_logger.log_every(data_loader, print_freq, header):
        samples = batch['embedding'].to(device)
        targets = batch['labels'].to(device)
        mask = batch['mask'].to(device).bool()

        # The forward pass stays in full precision. The earlier
        # autocast(dtype=torch.float32) request is not a supported CUDA
        # autocast dtype, so it amounted to autocast being disabled; say so
        # explicitly.
        with torch.cuda.amp.autocast(enabled=False):
            output = model(samples)  # [B, num_classes, L]
            loss = criterion(output, targets, mask)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        optimizer.zero_grad()
        if loss_scaler is not None:
            # Use the scaler that was passed in, not a notebook-level global.
            loss_scaler.scale(loss).backward()
            if clip_gradients:
                # Gradients must be unscaled before their norm is measured.
                loss_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            loss_scaler.step(optimizer)
            loss_scaler.update()
        else:
            loss.backward()
            if clip_gradients:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        # synchronize() raises on machines without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

In [18]:
def _safe_metric(metric_fn, *args, **kwargs):
    """Return ``metric_fn(*args, **kwargs)``, or NaN if the metric cannot be computed."""
    try:
        return metric_fn(*args, **kwargs)
    except Exception:
        # Deliberate best effort (e.g. only one class present); unlike a bare
        # `except:` this still lets KeyboardInterrupt/SystemExit propagate.
        return float('nan')


@torch.no_grad()
def evaluate(data_loader, model, device, threshold=0.3):
    """Evaluate ``model`` on ``data_loader`` and print/return residue-level metrics.

    Args:
        data_loader: iterable of batches with ``'embedding'`` ([B, L, C]),
            ``'labels'`` ([B, L]) and ``'mask'`` ([B, L]) entries.
        model: network returning logits of shape [B, num_classes, L].
        device: device the batch tensors are moved to.
        threshold: class-1 probability above which a position is predicted
            positive.

    Returns:
        tuple ``(results, sample_ious)``. ``results`` maps metric names to
        values rounded to 4 decimals; metrics that cannot be computed are NaN.
        ``sample_ious`` lists the IoU of every sample (0.0 when a sample has
        neither predicted nor true positives).

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
    """
    model.eval()
    true_positives = 0
    union_positives = 0

    all_probs = []
    all_preds = []
    all_targets = []
    sample_ious = []

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Eval:"
    print_freq = 10
    metric_logger.add_meter("mean_iou", utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    for batch in metric_logger.log_every(data_loader, print_freq, header):
        # The embedding is fed as [B, L, C], exactly as in training. The former
        # transpose(1, 2) handed the model a scrambled [B, C, L] tensor.
        samples = batch['embedding'].to(device)
        targets = batch['labels'].to(device)
        mask = batch['mask'].to(device).bool()

        with torch.cuda.amp.autocast():
            output = model(samples)

        # Softmax in float32 regardless of the dtype autocast produced.
        probs = torch.softmax(output.float(), dim=1)[:, 1, :]
        preds = (probs > threshold).long()
        preds = preds.masked_fill(~mask, 0)

        # Per-sample intersection/union of positives, reduced on-device in one go.
        pred_pos = (preds == 1) & mask
        true_pos = (targets == 1) & mask
        tp_per_sample = (pred_pos & true_pos).sum(dim=1).tolist()
        union_per_sample = (pred_pos | true_pos).sum(dim=1).tolist()
        for tp_i, union_i in zip(tp_per_sample, union_per_sample):
            sample_ious.append(tp_i / union_i if union_i > 0 else 0.0)

        true_positives += sum(tp_per_sample)
        union_positives += sum(union_per_sample)

        all_probs.append(probs[mask].cpu())
        all_preds.append(preds[mask].cpu())
        all_targets.append(targets[mask].cpu())

        if sample_ious:
            metric_logger.update(mean_iou=sum(sample_ious) / len(sample_ious))

    metric_logger.synchronize_between_processes()

    if not all_probs:
        raise RuntimeError("evaluate(): data_loader yielded no batches")

    all_probs = torch.cat(all_probs).numpy()
    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    # Dataset-level IoU: positives pooled over all samples.
    agiou = true_positives / union_positives if union_positives > 0 else 0.0

    auc = _safe_metric(roc_auc_score, all_targets, all_probs)
    pr_auc = _safe_metric(average_precision_score, all_targets, all_probs)
    pcc = _safe_metric(lambda: np.corrcoef(all_probs, all_targets)[0, 1])
    brier = _safe_metric(brier_score_loss, all_targets, all_probs)
    bce = _safe_metric(log_loss, all_targets, all_probs, labels=[0, 1])

    results = {
        "AgIoU": round(agiou, 4),
        "Precision": round(precision_score(all_targets, all_preds, zero_division=0), 4),
        "Recall": round(recall_score(all_targets, all_preds, zero_division=0), 4),
        "F1": round(f1_score(all_targets, all_preds, zero_division=0), 4),
        "MCC": round(matthews_corrcoef(all_targets, all_preds), 4),
        "Accuracy": round(accuracy_score(all_targets, all_preds), 4),
        "AUC": round(auc, 4),
        "PR-AUC": round(pr_auc, 4),
        "PCC": round(pcc, 4),
        "Brier": round(brier, 4),
        "BCE": round(bce, 4)
    }

    print("\nEvaluation Results:")
    for k, v in results.items():
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

    return results, sample_ious

In [19]:
def sequence_loss(pred, target, mask):
    """Mean cross-entropy over the positions selected by ``mask``.

    Args:
        pred: logits of shape [B, C, L].
        target: class indices of shape [B, L]; entries equal to -100 (the
            default ``ignore_index`` of ``F.cross_entropy``) contribute a
            loss of 0.
        mask: bool tensor of shape [B, L]; only True positions are averaged.

    Returns:
        Scalar loss tensor. When ``mask`` selects no position, a zero that is
        still attached to ``pred``'s graph is returned instead of NaN, so a
        fully masked batch does not abort training.
    """
    B, C, L = pred.shape
    flat_pred = pred.transpose(1, 2).reshape(-1, C)   # [B*L, C]
    flat_target = target.reshape(-1)                  # [B*L]
    flat_mask = mask.reshape(-1)                      # [B*L], bool

    per_position = F.cross_entropy(flat_pred, flat_target, reduction='none')  # [B*L]
    valid = per_position[flat_mask]  # only valid positions
    if valid.numel() == 0:
        # The mean of an empty tensor is NaN; return a graph-connected zero.
        return pred.sum() * 0.0
    return valid.mean()

In [20]:
def criterion(output, target, mask):
    """Loss callable handed to the training loop; delegates to ``sequence_loss``."""
    loss = sequence_loss(pred=output, target=target, mask=mask)
    return loss

### 3.1 Train

#### Model

In [21]:
import contextlib

# ViT-9
model = ViT_1D(depth=9)
# ViT-12
# model = ViT_1D(depth=12)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Complexity is measured for a single sample of shape (seq_len, feature_dim).
input_shape = (1024, 1280)
# Pin CUDA device 0 only when CUDA exists; entering torch.cuda.device(0)
# unconditionally fails on CPU-only machines even though `device` falls back to CPU.
complexity_ctx = torch.cuda.device(0) if torch.cuda.is_available() else contextlib.nullcontext()
with complexity_ctx:
    macs, params = get_model_complexity_info(model, input_shape, as_strings=True,
                                             print_per_layer_stat=False, verbose=False)

print(f"Params: {params}")
print(f"MACs: {macs}")

Params: 173.57 M
MACs: 176.32 GMac


In [22]:
# training on pretrained models
# model_name = "..." # model name
# model_path = os.path.join("...", model_name) # directory + model name
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = ViT_1D(depth=9)  # must match the architecture of the checkpoint being loaded
# state = torch.load(model_path, map_location=device, weights_only=False)
# model.load_state_dict(state['model_state_dict'])
# model.to(device)

#### Training with checkpoints

In [None]:
# Optimisation setup.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler()
threshold = 0.3
epochs = 150
all_metrics = []

save_dir = "..." # directory saving models
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train_stats = train_one_epoch(
        model, criterion, dataloader, optimizer, device, epoch, loss_scaler=scaler
    )
    # NOTE(review): evaluation runs on the same `dataloader` used for training,
    # so `val_stats` are training-set metrics — confirm this is intended.
    val_stats, _ = evaluate(dataloader, model, device, threshold)

    all_metrics.append(val_stats)

    print(f"Train loss: {train_stats['loss']:.4f}\n")

    # Checkpoints are only written for epochs 50..150 (1-based, inclusive).
    if epoch + 1 < 50 or epoch + 1 > 150:
        continue

    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'all_metrics': all_metrics,
    }
    save_path = os.path.join(save_dir, f"model_epoch{epoch+1}_AgIoU{val_stats['AgIoU']:.4f}.pth")
    torch.save(checkpoint, save_path)

In [None]:
# save training process (metrics at each epoch) to a csv.
# Persist the per-epoch evaluation metrics: one row per epoch, with a
# 1-based "Epoch" column placed in front.
df = pd.DataFrame(all_metrics)
df.insert(loc=0, column="Epoch", value=range(1, len(df) + 1))

df.to_csv("....csv", index=False)
print("Successfully saved.")

### 3.2 Evaluate

In [None]:
# model_name = "..." # model name
# model_path = os.path.join("...", model_name) # checkpoints directory + model name
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = ViT_1D(depth=9)  # must match the architecture of the checkpoint being loaded
# state = torch.load(model_path, map_location=device, weights_only=False)
# model.load_state_dict(state['model_state_dict'])
# model.to(device)
# thresholds = np.linspace(0.28, 0.32, 40)
# collected_metrics = {}

# for threshold in thresholds:
#     with torch.no_grad():
#         metrics, _ = evaluate(dataloader, model, device, threshold)
#         for k, v in metrics.items():
#             collected_metrics.setdefault(k, []).append(v)

# # metrics (mean ± std)
# results_summary = {}
# for k, v_list in collected_metrics.items():
#     v_array = np.array(v_list)
#     mean = np.mean(v_array)
#     std = np.std(v_array)
#     results_summary[k] = f"{mean:.3f} ± {std:.3f}"

# for k, v in results_summary.items():
#     print(f"{k}: {v}")