### Conformer
##### This notebook contains all modules for training and evaluating the Conformer model.

In [None]:
import os
from bconformer import utils
import numpy as np
import math
import sys
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F


from torch.utils.data import Dataset, DataLoader
from Bio.PDB import PDBParser, is_aa
from Bio.PDB.NeighborSearch import NeighborSearch
from Bio.PDB.Selection import unfold_entities
from functools import partial
from torch.nn.init import trunc_normal_
from sklearn.metrics import (
    roc_auc_score, average_precision_score, matthews_corrcoef,
    f1_score, precision_score, recall_score, accuracy_score,
    brier_score_loss, log_loss
)
from typing import Iterable, Optional
from timm.layers import DropPath
from ptflops import get_model_complexity_info

In [None]:
# Three-letter residue name -> one-letter amino-acid code.
# Covers the 20 standard residues plus SEC (U), PYL (O) and the ambiguity
# codes ASX (B), GLX (Z) and UNK (X).
three_to_one_dict = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
    'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
    'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    'SEC': 'U', 'PYL': 'O', 'ASX': 'B', 'GLX': 'Z', 'UNK': 'X'
}

In [31]:
# Input locations. The "..." values are placeholders: point them at real
# directories before running the cells below.
fasta_files = "..." # directory containing training (or evaluating) fastas
pdb_files = "..." # directory containing training (or evaluating) pdbs

# Sanity check: count the *.fasta and *.pdb files found in each directory.
num_fasta = len([f for f in os.listdir(fasta_files) if f.endswith('.fasta')])
num_pdb = len([f for f in os.listdir(pdb_files) if f.endswith('.pdb')])
# Bare expression so the notebook displays the pair of counts.
num_fasta, num_pdb

(1080, 1080)

### 1. Data

In [32]:
def parse_chains_from_fasta_name(fasta_name):
    """Extract antigen and antibody chain IDs from a FASTA file name.

    The name is split on ``_``; the tokens between ``ag`` and ``ab`` are the
    antigen chains and the tokens after ``ab`` are the antibody chains, e.g.
    ``1abc_ag_A_B_ab_H_L.fasta`` -> ``(['A', 'B'], ['H', 'L'])``.

    Args:
        fasta_name: File name, with or without the ``.fasta`` extension.

    Returns:
        Tuple ``(antigen_chains, antibody_chains)`` of lists of strings.

    Raises:
        ValueError: If the ``ag`` or ``ab`` tag is missing, or ``ab`` appears
            before ``ag`` (which would otherwise yield silently wrong lists).
    """
    suffix = '.fasta'
    # Strip only a trailing extension, never an occurrence inside the name.
    base = fasta_name[:-len(suffix)] if fasta_name.endswith(suffix) else fasta_name
    parts = base.split('_')

    try:
        ag_idx = parts.index('ag')
        ab_idx = parts.index('ab')
    except ValueError:
        raise ValueError(
            f"Cannot parse chains from '{fasta_name}': expected both 'ag' and 'ab' tags"
        ) from None

    if ab_idx < ag_idx:
        raise ValueError(
            f"Cannot parse chains from '{fasta_name}': 'ag' tag must precede 'ab' tag"
        )

    antigen_chains = parts[ag_idx + 1:ab_idx]
    antibody_chains = parts[ab_idx + 1:]
    return antigen_chains, antibody_chains

In [33]:
def get_atoms(chains):
    """Return all atoms of the given chains whose element is not 'H'.

    Each chain is expanded to atom level with ``unfold_entities(chain, 'A')``;
    atoms are returned in chain order, then in unfolding order.
    """
    heavy_atoms = []
    for chain in chains:
        for atom in unfold_entities(chain, 'A'):
            if atom.element == 'H':
                continue
            heavy_atoms.append(atom)
    return heavy_atoms

In [34]:
def get_epitope_labels(antigen_chain_objs, antibody_chain_objs, cutoff=4):
    """Label each antigen amino-acid residue as epitope (1) or not (0).

    A residue is labelled 1 when at least one of its atoms lies within
    ``cutoff`` of any non-hydrogen antibody atom.

    Args:
        antigen_chain_objs: Iterable of antigen chain objects.
        antibody_chain_objs: Iterable of antibody chain objects.
        cutoff: Contact radius passed to ``NeighborSearch.search``, in the
            units of the structure coordinates (presumably Angstrom for PDB
            files). Defaults to 4.

    Returns:
        1-D ``torch.long`` tensor with one entry per antigen residue that
        passes ``is_aa``, in chain order then residue order.
    """
    antibody_atoms = get_atoms(antibody_chain_objs)
    ns = NeighborSearch(antibody_atoms)

    # NOTE(review): antibody hydrogens are filtered out by get_atoms, but
    # antigen hydrogens (if present in the file) still take part in the search.
    labels = []
    for chain in antigen_chain_objs:
        for res in chain.get_residues():
            if not is_aa(res):
                continue
            # any() stops at the first atom that has an antibody neighbour.
            in_contact = any(ns.search(atom.coord, cutoff) for atom in res)
            labels.append(1 if in_contact else 0)
    return torch.tensor(labels, dtype=torch.long)

In [35]:
def esm_embed_sequences(sequences, model, alphabet, device, repr_layer=33):
    """Embed sequences one at a time and concatenate the per-residue vectors.

    Args:
        sequences: Iterable of amino-acid strings (one per chain).
        model: Callable invoked as
            ``model(tokens, repr_layers=[...], return_contacts=False)`` that
            returns a dict with a ``"representations"`` mapping.
        alphabet: Object providing ``get_batch_converter()``.
        device: Device the token tensor is moved to before the forward pass.
        repr_layer: Index of the representation layer to extract (default 33).

    Returns:
        Tensor on CPU with all per-residue embeddings concatenated along
        dim 0. The first and last token of each sequence are dropped.

    Raises:
        RuntimeError: From ``torch.cat`` if ``sequences`` is empty.
    """
    # The converter does not depend on the sequence, so build it once.
    batch_converter = alphabet.get_batch_converter()

    embeddings = []
    for seq in sequences:
        _, _, batch_tokens = batch_converter([("protein", seq)])
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
        token_embeddings = results["representations"][repr_layer]
        # Remove BOS and EOS tokens
        embeddings.append(token_embeddings[0, 1:-1].cpu())
    return torch.cat(embeddings, dim=0)

In [36]:
class EpitopeDataset(Dataset):
    """Antigen/antibody complexes as per-residue embeddings and epitope labels.

    Each item pairs one ``*.fasta`` file name (used only for its chain
    annotation, see ``parse_chains_from_fasta_name``) with a ``*.pdb`` file.
    The PDB structure supplies the antigen sequence and the contact-based
    labels.

    Args:
        fasta_dir: Directory scanned for ``*.fasta`` files.
        pdb_dir: Directory scanned for ``*.pdb`` files.
        esm_model: Model handed to ``esm_embed_sequences``.
        esm_alphabet: Alphabet handed to ``esm_embed_sequences``.
        device: Device handed to ``esm_embed_sequences``.
    """

    # Chain totals over all FASTA file names; overwritten every time a
    # dataset instance is constructed.
    TOTAL_ANTIGEN_CHAINS = 0
    TOTAL_ANTIBODY_CHAINS = 0

    def __init__(self, fasta_dir, pdb_dir, esm_model, esm_alphabet, device):
        self.fasta_dir = fasta_dir
        self.pdb_dir = pdb_dir
        self.esm_model = esm_model
        self.esm_alphabet = esm_alphabet
        self.device = device

        self.fasta_files = sorted(f for f in os.listdir(fasta_dir) if f.endswith('.fasta'))
        self.pdb_files = sorted(f for f in os.listdir(pdb_dir) if f.endswith('.pdb'))

        # idx -> number of antigen residues, filled lazily by __getitem__.
        self.antigen_len_cache = {}

        total_ag = 0
        total_ab = 0
        for fasta_file in self.fasta_files:
            ag_chains, ab_chains = parse_chains_from_fasta_name(fasta_file)
            total_ag += len(ag_chains)
            total_ab += len(ab_chains)

        EpitopeDataset.TOTAL_ANTIGEN_CHAINS = total_ag
        EpitopeDataset.TOTAL_ANTIBODY_CHAINS = total_ab

    def __len__(self):
        """Return the number of FASTA files, i.e. the number of samples."""
        return len(self.fasta_files)

    def _find_pdb_file(self, fasta_name):
        """Return the path of the PDB file that belongs to ``fasta_name``.

        An exact ``<stem>.pdb`` match wins. Otherwise the first PDB file whose
        name contains the FASTA stem is used, so a stem that is a prefix of
        another complex's name cannot shadow its own exact match.

        Raises:
            ValueError: If no PDB file matches.
        """
        fasta_id = os.path.splitext(fasta_name)[0]

        exact_name = fasta_id + '.pdb'
        if exact_name in self.pdb_files:
            return os.path.join(self.pdb_dir, exact_name)

        for pdb_name in self.pdb_files:
            if fasta_id in pdb_name:
                return os.path.join(self.pdb_dir, pdb_name)

        raise ValueError(f"No matching pdb file found for {fasta_name}")

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            Dict with ``'embedding'`` (output of ``esm_embed_sequences`` for
            the antigen chains), ``'labels'`` (long tensor of 0/1 epitope
            labels), ``'mask'`` (all-True bool tensor, same length as labels)
            and ``'antigen_length'`` (int residue count).

        Raises:
            ValueError: If no PDB file matches the FASTA file.
            AssertionError: If the chain count in the structure differs from
                the count encoded in the FASTA file name.
        """
        fasta_name = self.fasta_files[idx]
        matched_pdb_file = self._find_pdb_file(fasta_name)

        antigen_chains, antibody_chains = parse_chains_from_fasta_name(fasta_name)

        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", matched_pdb_file)
        model = structure[0]

        sorted_chain_ids = sorted(chain.id for chain in model)
        expected_chains = len(antigen_chains) + len(antibody_chains)
        assert len(sorted_chain_ids) == expected_chains, (
            f"{matched_pdb_file}: found {len(sorted_chain_ids)} chains, "
            f"expected {expected_chains} from {fasta_name}"
        )

        # NOTE(review): chains are assigned by position in the sorted chain-ID
        # list (antigen first), not by the IDs in the file name; this presumes
        # the PDB files were written with that ordering - confirm.
        antigen_chain_ids = sorted_chain_ids[:len(antigen_chains)]
        antibody_chain_ids = sorted_chain_ids[len(antigen_chains):]

        antigen_chains_objs = [model[c] for c in antigen_chain_ids]
        antibody_chains_objs = [model[c] for c in antibody_chain_ids]

        # Antigen sequence: one letter per amino-acid residue, 'X' for any
        # residue name missing from three_to_one_dict.
        antigen_sequences = [
            "".join(
                three_to_one_dict.get(residue.get_resname(), 'X')
                for residue in chain.get_residues()
                if is_aa(residue)
            )
            for chain in antigen_chains_objs
        ]

        # Antigen length: every counted residue contributes exactly one
        # letter, so the total sequence length equals the residue count.
        antigen_length = self.antigen_len_cache.get(idx)
        if antigen_length is None:
            antigen_length = sum(len(seq) for seq in antigen_sequences)
            self.antigen_len_cache[idx] = antigen_length

        embedding = esm_embed_sequences(
            antigen_sequences, self.esm_model, self.esm_alphabet, self.device
        )

        labels = get_epitope_labels(antigen_chains_objs, antibody_chains_objs)
        mask = torch.ones(labels.shape[0], dtype=torch.bool)

        return {
            'embedding': embedding,
            'labels': labels,
            'mask': mask,
            'antigen_length': antigen_length
        }

In [37]:
# Load the ESM-2 model and alphabet, move the model to the available device
# and switch it to eval mode.
# NOTE(review): `esm` is not among the imports visible at the top of this
# notebook; it must be imported (e.g. `import esm`) before this cell runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model = esm_model.to(device)
esm_model.eval()

# Building the dataset only lists the two directories and parses file names;
# structures are parsed and embedded lazily in __getitem__.
dataset = EpitopeDataset(fasta_files, pdb_files, esm_model, esm_alphabet, device)

# Chain totals are stored as class attributes by EpitopeDataset.__init__.
print("Number of antigen chains:", EpitopeDataset.TOTAL_ANTIGEN_CHAINS)
print("Number of antibody chains:", EpitopeDataset.TOTAL_ANTIBODY_CHAINS)

Number of antigen chains: 1338
Number of antibody chains: 2160


In [38]:
max_seq_len = 1024


def collate_fn_padding(batch):
    """Pad every sample to ``max_seq_len`` residues and stack into a batch.

    Samples longer than ``max_seq_len`` are dropped from the batch. Padding
    uses 0 for embeddings, -100 for labels and False for the mask; the
    attention mask is a float tensor with 1 on real positions and 0 on padding.

    Returns:
        Dict with keys ``"embedding"``, ``"labels"``, ``"mask"`` and
        ``"attention_mask"``, each stacked along a new leading batch dim.
    """
    # Pair each sample with its padding amount and keep only those that fit.
    sized = [(sample, max_seq_len - sample['embedding'].shape[0]) for sample in batch]
    kept = [(sample, n_pad) for sample, n_pad in sized if n_pad >= 0]

    padded_embeddings = [
        F.pad(sample['embedding'], (0, 0, 0, n_pad), value=0) for sample, n_pad in kept
    ]
    padded_labels = [
        F.pad(sample['labels'], (0, n_pad), value=-100) for sample, n_pad in kept
    ]
    padded_masks = [
        F.pad(sample['mask'], (0, n_pad), value=0) for sample, n_pad in kept
    ]
    attention_masks = [
        torch.cat([torch.ones(sample['embedding'].shape[0]), torch.zeros(n_pad)])
        for sample, n_pad in kept
    ]

    return {
        "embedding": torch.stack(padded_embeddings),
        "labels": torch.stack(padded_labels),
        "mask": torch.stack(padded_masks),
        "attention_mask": torch.stack(attention_masks)
    }


In [39]:
# Smoke test of the data pipeline: build a shuffled loader, fetch a single
# batch, print the tensor shapes and the first sample, then stop.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_padding)

for i, batch in enumerate(dataloader):
    embedding = batch["embedding"]        # shape: [B, max_len, 1280]
    labels = batch["labels"]              # shape: [B, max_len]
    mask = batch["mask"]                  # shape: [B, max_len]
    attention_mask = batch["attention_mask"]  # shape: [B, max_len]

    print(f"Embedding shape: {embedding.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Mask shape: {mask.shape}")
    print(f"Attention mask shape: {attention_mask.shape}")


    print("\n=== A sequence sample ===")
    print(f"Embedding shape: {embedding[0].shape}")  # [max_len, 1280]
    print(f"Embedding:\n{embedding[0]}")
    print(f"Labels:\n{labels[0]}")
    print(f"Mask:\n{mask[0]}")

    # Only the first batch is inspected.
    break

Embedding shape: torch.Size([8, 1024, 1280])
Labels shape: torch.Size([8, 1024])
Mask shape: torch.Size([8, 1024])
Attention mask shape: torch.Size([8, 1024])

=== A sequence sample ===
Embedding shape: torch.Size([1024, 1280])
Embedding:
tensor([[ 0.2433, -0.2871,  0.0368,  ...,  0.1387, -0.1156, -0.1013],
        [ 0.0637,  0.0753,  0.0817,  ...,  0.0151,  0.0629,  0.1518],
        [-0.0257,  0.0330,  0.0512,  ...,  0.0281,  0.3724,  0.1346],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])
Labels:
tensor([   0,    0,    0,  ..., -100, -100, -100])
Mask:
tensor([ True,  True,  True,  ..., False, False, False])


### 2. Model

In [40]:
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    left unset. One dropout module is shared by both dropout positions.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features

        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

In [41]:
class Attention(nn.Module):
    """Multi-head self-attention over a ``[B, N, C]`` token tensor.

    Queries, keys and values come from a single fused linear projection.
    Scores are scaled by ``qk_scale`` when given, else by ``head_dim ** -0.5``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # A manually supplied qk_scale overrides the default 1/sqrt(head_dim).
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # [B, N, 3C] -> [3, B, heads, N, head_dim]
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

In [42]:
class Block(nn.Module):
    """Pre-norm Transformer block: self-attention then MLP, each residual.

    Both residual branches share one ``drop_path`` module, which is
    ``DropPath`` when ``drop_path > 0`` and an identity otherwise.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on the residual branches; identity when disabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)

In [43]:
class ConvBlock(nn.Module):
    """1-D bottleneck residual block (1x1 -> 3 -> 1x1 convolutions).

    The bottleneck width is ``outplanes // 4``. With ``res_conv=True`` the
    shortcut goes through its own 1x1 convolution and norm; otherwise the
    input is added unchanged.
    """

    def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
                 norm_layer=partial(nn.BatchNorm1d, eps=1e-6), drop_block=None, drop_path=None):
        super(ConvBlock, self).__init__()

        bottleneck = outplanes // 4

        # 1x1 reduction
        self.conv1 = nn.Conv1d(inplanes, bottleneck, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = norm_layer(bottleneck)
        self.act1 = act_layer(inplace=True)

        # kernel-3 convolution (carries the stride)
        self.conv2 = nn.Conv1d(bottleneck, bottleneck, kernel_size=3, stride=stride, groups=groups,
                               padding=1, bias=False)
        self.bn2 = norm_layer(bottleneck)
        self.act2 = act_layer(inplace=True)

        # 1x1 expansion
        self.conv3 = nn.Conv1d(bottleneck, outplanes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = norm_layer(outplanes)
        self.act3 = act_layer(inplace=True)

        if res_conv:
            self.residual_conv = nn.Conv1d(inplanes, outplanes, kernel_size=1, stride=stride,
                                           padding=0, bias=False)
            self.residual_bn = norm_layer(outplanes)

        self.res_conv = res_conv
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        """Zero the scale of the last norm layer."""
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, x_t=None, return_x_2=True):
        """Run the block.

        Args:
            x: Input feature map.
            x_t: Optional tensor added to the bottleneck features right before
                the kernel-3 convolution.
            return_x_2: When True, also return the activated bottleneck
                features produced after the second activation.
        """
        shortcut = x

        out = self.bn1(self.conv1(x))
        if self.drop_block is not None:
            out = self.drop_block(out)
        out = self.act1(out)

        if x_t is not None:
            out = out + x_t
        out = self.bn2(self.conv2(out))
        if self.drop_block is not None:
            out = self.drop_block(out)
        mid = self.act2(out)

        out = self.bn3(self.conv3(mid))
        if self.drop_block is not None:
            out = self.drop_block(out)

        if self.drop_path is not None:
            out = self.drop_path(out)

        if self.res_conv:
            shortcut = self.residual_bn(self.residual_conv(shortcut))

        out += shortcut
        out = self.act3(out)

        return (out, mid) if return_x_2 else out

In [44]:
class FCUDown(nn.Module):
    """CNN feature maps -> Transformer patch embeddings.

    Projects the channels with a 1x1 convolution, average-pools along the
    sequence by ``dw_stride``, applies LayerNorm and the activation, and
    prepends the first token of ``x_t``.
    """

    def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super(FCUDown, self).__init__()
        self.dw_stride = dw_stride

        self.conv_project = nn.Conv1d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
        self.sample_pooling = nn.AvgPool1d(kernel_size=dw_stride, stride=dw_stride)

        self.ln = norm_layer(outplanes)
        self.act = act_layer()

    def forward(self, x, x_t):
        pooled = self.sample_pooling(self.conv_project(x))   # [B, C_embed, L']
        tokens = self.act(self.ln(pooled.transpose(1, 2)))   # [B, L', C_embed]

        leading_token = x_t[:, 0].unsqueeze(1)               # [B, 1, C_embed]
        return torch.cat((leading_token, tokens), dim=1)     # [B, L'+1, C_embed]

In [45]:
class FCUUp(nn.Module):
    """Transformer patch embeddings -> CNN feature maps.

    Drops the first token, projects the channels with a 1x1 convolution plus
    norm and activation, and linearly interpolates to the requested length.
    """

    def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
                 norm_layer=partial(nn.BatchNorm1d, eps=1e-6),):
        super(FCUUp, self).__init__()

        self.up_stride = up_stride
        self.conv_project = nn.Conv1d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
        self.bn = norm_layer(outplanes)
        self.act = act_layer()

    def forward(self, x, target_length):
        # Unpacking requires x to be exactly [batch, seq_len, embed_dim].
        _batch, _seq_len, _embed_dim = x.shape

        # Drop the first token and move channels first: [B, C, L-1].
        patch_tokens = x[:, 1:].transpose(1, 2)
        projected = self.conv_project(patch_tokens)
        projected = self.act(self.bn(projected))

        return F.interpolate(projected, size=target_length, mode='linear', align_corners=False)

In [46]:
class Med_ConvBlock(nn.Module):
    """Shape-preserving 1-D bottleneck block with an identity shortcut.

    Input and output both have ``inplanes`` channels; the bottleneck width is
    ``inplanes // 4`` and every convolution uses stride 1.
    """

    def __init__(self, inplanes, act_layer=nn.ReLU, groups=1, norm_layer=partial(nn.BatchNorm1d, eps=1e-6),
                 drop_block=None, drop_path=None):

        super(Med_ConvBlock, self).__init__()

        bottleneck = inplanes // 4

        # 1x1 reduction
        self.conv1 = nn.Conv1d(inplanes, bottleneck, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = norm_layer(bottleneck)
        self.act1 = act_layer(inplace=True)

        # kernel-3 convolution
        self.conv2 = nn.Conv1d(bottleneck, bottleneck, kernel_size=3, stride=1, groups=groups,
                               padding=1, bias=False)
        self.bn2 = norm_layer(bottleneck)
        self.act2 = act_layer(inplace=True)

        # 1x1 expansion back to the input width
        self.conv3 = nn.Conv1d(bottleneck, inplanes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = norm_layer(inplanes)
        self.act3 = act_layer(inplace=True)

        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        """Zero the scale of the last norm layer."""
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        shortcut = x

        out = self.bn1(self.conv1(x))
        if self.drop_block is not None:
            out = self.drop_block(out)
        out = self.act1(out)

        out = self.bn2(self.conv2(out))
        if self.drop_block is not None:
            out = self.drop_block(out)
        out = self.act2(out)

        out = self.bn3(self.conv3(out))
        if self.drop_block is not None:
            out = self.drop_block(out)

        if self.drop_path is not None:
            out = self.drop_path(out)

        out += shortcut
        return self.act3(out)

In [47]:
class ConvTransBlock(nn.Module):
    """
    Dual-branch stage for 1D sequence data: a convolutional block and a
    Transformer block that exchange features through FCUDown / FCUUp.
    """

    def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 last_fusion=False, num_med_block=0, groups=1):

        super(ConvTransBlock, self).__init__()
        expansion = 4
        bottleneck = outplanes // expansion
        batch_norm = partial(nn.BatchNorm1d, eps=1e-6)

        self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, stride=stride,
                                   res_conv=res_conv, groups=groups, norm_layer=batch_norm)

        # The fusion block gets a projected shortcut only in the last stage.
        self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, stride=1,
                                      res_conv=True if last_fusion else False,
                                      groups=groups, norm_layer=batch_norm)

        if num_med_block > 0:
            self.med_block = nn.ModuleList([
                Med_ConvBlock(inplanes=outplanes, groups=groups, norm_layer=batch_norm)
                for _ in range(num_med_block)
            ])

        # CNN bottleneck features -> tokens
        self.squeeze_block = FCUDown(inplanes=bottleneck, outplanes=embed_dim, dw_stride=dw_stride)

        # tokens -> CNN bottleneck features
        self.expand_block = FCUUp(inplanes=embed_dim, outplanes=bottleneck, up_stride=dw_stride,
                                  norm_layer=batch_norm)

        self.trans_block = Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate)

        self.dw_stride = dw_stride
        self.embed_dim = embed_dim
        self.num_med_block = num_med_block
        self.last_fusion = last_fusion

    def forward(self, x, x_t):
        """Advance both branches by one stage.

        Args:
            x: Convolutional feature map, ``[B, C, L]``.
            x_t: Token sequence of the Transformer branch.

        Returns:
            Tuple ``(x, x_t)`` with the updated feature map and tokens.
        """
        conv_out, conv_mid = self.cnn_block(x)

        # Inject the CNN bottleneck features into the Transformer branch.
        conv_tokens = self.squeeze_block(conv_mid, x_t)
        x_t = self.trans_block(conv_tokens + x_t)

        if self.num_med_block > 0:
            for med in self.med_block:
                conv_out = med(conv_out)

        # Feed the Transformer tokens back into the CNN branch.
        token_map = self.expand_block(x_t, target_length=conv_out.shape[-1])
        conv_out = self.fusion_block(conv_out, token_map, return_x_2=False)

        return conv_out, x_t

In [48]:
class Conformer(nn.Module):
    """Dual-branch (CNN + Transformer) per-position classifier for sequences.

    The input ``[B, in_chans, L]`` passes through a convolutional stem and then
    ``depth`` stages in which a convolutional branch and a Transformer branch
    exchange features. Each branch has its own per-position head; the two
    logit maps are blended with ``fusion_alpha``.

    Args:
        patch_size: Unused; kept for signature compatibility.
        in_chans: Number of input channels.
        num_classes: Number of output classes per position.
        base_channel: Base width of the convolutional stages.
        channel_ratio: Multiplier applied to ``base_channel`` for stage 1.
        num_med_block: Number of extra ``Med_ConvBlock`` modules per stage.
        embed_dim: Token width of the Transformer branch.
        depth: Total number of stages; must be divisible by 3.
        num_heads, mlp_ratio, qkv_bias, qk_scale: Transformer block settings.
        drop_rate, attn_drop_rate: Dropout rates of the Transformer blocks.
        drop_path_rate: Maximum stochastic-depth rate, ramped linearly from 0
            across the stages.
        fusion_alpha: Weight of the convolutional head in the final blend.
    """

    def __init__(self, patch_size=16, in_chans=1280, num_classes=2, base_channel=320, channel_ratio=2,
                 num_med_block=0, embed_dim=1536, depth=12, num_heads=12, mlp_ratio=2., qkv_bias=False,
                 qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., fusion_alpha=0.5):
        super().__init__()
        assert depth % 3 == 0
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.fusion_alpha = fusion_alpha

        # Channel widths of the three groups of stages. Computed up front so
        # the convolutional head can be sized from the real stage-3 width
        # (2560 with the default base_channel / channel_ratio).
        stage_1_channel = int(base_channel * channel_ratio)
        stage_2_channel = stage_1_channel * 2
        stage_3_channel = stage_2_channel * 2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # classifier heads
        self.trans_norm = nn.LayerNorm(embed_dim)
        self.trans_token_head = nn.Linear(embed_dim, num_classes)
        self.conv_cls_head = nn.Conv1d(stage_3_channel, num_classes, kernel_size=1)

        # stem (all strides are 1, so the sequence length is preserved)
        self.conv1 = nn.Conv1d(in_chans, 320, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(320)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)

        # stage 1
        trans_dw_stride = 1
        self.conv_1 = ConvBlock(inplanes=320, outplanes=stage_1_channel, res_conv=True, stride=1)
        self.trans_patch_conv = nn.Conv1d(320, embed_dim, kernel_size=5, stride=1, padding=2)
        self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                             qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                             attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0])

        # first group: stages 2 .. depth // 3, width stage_1_channel
        init_stage = 2
        fin_stage = depth // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module(f'conv_trans_{i}',
                ConvTransBlock(
                    stage_1_channel, stage_1_channel, False, 1, dw_stride=trans_dw_stride,
                    embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i - 1],
                    num_med_block=num_med_block
                )
            )

        # second group, width stage_2_channel; its first stage changes the
        # channel count and therefore needs a projected shortcut
        for i in range(fin_stage, fin_stage + depth // 3):
            in_channel = stage_1_channel if i == fin_stage else stage_2_channel
            self.add_module(f'conv_trans_{i}',
                ConvTransBlock(
                    in_channel, stage_2_channel, i == fin_stage, 1, dw_stride=trans_dw_stride,
                    embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i - 1],
                    num_med_block=num_med_block
                )
            )

        # third group, width stage_3_channel; the very last stage
        # (i == depth) uses the last-fusion variant
        for i in range(fin_stage + depth // 3, fin_stage + 2 * (depth // 3)):
            in_channel = stage_2_channel if i == fin_stage + depth // 3 else stage_3_channel
            self.add_module(f'conv_trans_{i}',
                ConvTransBlock(
                    in_channel, stage_3_channel, i == fin_stage + depth // 3, 1, dw_stride=trans_dw_stride,
                    embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i - 1],
                    num_med_block=num_med_block,
                    last_fusion=(i == depth)
                )
            )

        # One past the index of the last conv_trans_* stage.
        self.fin_stage = fin_stage + 2 * (depth // 3)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule according to its type (used via ``apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.)
        elif isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {'cls_token'}

    def forward(self, x):
        """Compute per-position class logits.

        Args:
            x: Tensor of shape ``[B, in_chans, L]``.

        Returns:
            Tensor of shape ``[B, num_classes, L]``.
        """
        B, _, L = x.shape
        cls_tokens = self.cls_token.expand(B, -1, -1)

        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))     # [B, 320, L]
        x = self.conv_1(x_base, return_x_2=False)                     # [B, C, L]

        x_t = self.trans_patch_conv(x_base).transpose(1, 2)          # [B, L, embed_dim]
        x_t = torch.cat([cls_tokens, x_t], dim=1)                     # [B, L+1, embed_dim]
        x_t = self.trans_1(x_t)

        for i in range(2, self.fin_stage):
            x, x_t = getattr(self, f'conv_trans_{i}')(x, x_t)

        conv_cls = self.conv_cls_head(x)                             # [B, num_classes, L]
        x_t = self.trans_norm(x_t)
        # Skip the class token so the token logits align with the positions.
        tran_cls = self.trans_token_head(x_t[:, 1:, :]).transpose(1, 2)  # [B, num_classes, L]

        # linear fused classifier
        final_cls = self.fusion_alpha * conv_cls + (1 - self.fusion_alpha) * tran_cls

        return final_cls

### 3. Train and Evaluate

In [49]:
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler=None, max_norm: float = 0,
                    model_ema: Optional[object] = None, mixup_fn=None,
                    set_training_mode=True):
    """Train ``model`` for one pass over ``data_loader``.

    Each batch must be a dict with ``'embedding'``, ``'labels'`` and ``'mask'``
    entries. The embedding is transposed on dims 1 and 2 before the forward
    pass, and the loss is ``sequence_loss`` over the masked positions.

    Args:
        model: Network to train.
        criterion: Only switched to train mode when it has a ``train``
            attribute; the loss itself always comes from ``sequence_loss``.
        data_loader: Iterable yielding batch dicts.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Epoch index, used only in the log header.
        loss_scaler: Currently unused (see note in the body).
        max_norm: Currently unused; no gradient clipping is applied.
        model_ema: Optional object whose ``update(model)`` is called per step.
        mixup_fn: Currently unused.
        set_training_mode: Passed to ``model.train``.

    Returns:
        Dict mapping each meter name to its global average.
    """
    model.train(set_training_mode)
    if hasattr(criterion, 'train'):
        criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f"Epoch: [{epoch}]"
    print_freq = 10

    for batch in metric_logger.log_every(data_loader, print_freq, header):
        samples = batch['embedding'].to(device)
        targets = batch['labels'].to(device)
        mask = batch['mask'].to(device).bool()
        samples = samples.transpose(1, 2)  # [B, L, C] -> [B, C, L]

        with torch.cuda.amp.autocast():
            output = model(samples)  # [B, num_classes, L]
            loss = sequence_loss(output, targets, mask)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            # Report the reason before aborting instead of exiting silently.
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        # NOTE(review): `scaler` is read from module scope and must be defined
        # (e.g. a torch.cuda.amp.GradScaler) before this function is called;
        # the `loss_scaler` argument is not used here.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Synchronising is only valid when CUDA is present; the notebook
        # falls back to CPU when it is not.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

In [50]:
def _safe_metric(name, metric_fn, *args, **kwargs):
    """Return ``metric_fn(*args, **kwargs)``, or NaN if it raises.

    Only ``Exception`` subclasses are caught, so ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate. The failure is printed, not hidden.
    """
    try:
        return metric_fn(*args, **kwargs)
    except Exception as exc:
        print(f"Warning: could not compute {name}: {exc}")
        return float('nan')


@torch.no_grad()
def evaluate(data_loader, model, device, threshold=0.3):
    """Evaluate ``model`` on ``data_loader`` and print a metric summary.

    Positions are predicted positive when the softmax probability of class 1
    exceeds ``threshold``. Masked-out positions are excluded from every metric.

    Args:
        data_loader: Iterable yielding dicts with ``'embedding'``, ``'labels'``
            and ``'mask'`` entries.
        model: Network to evaluate; switched to eval mode.
        device: Device the batch tensors are moved to.
        threshold: Probability cut-off for the positive class.

    Returns:
        Tuple ``(results, sample_ious)``: ``results`` maps metric names to
        values rounded to 4 decimals (NaN when a metric cannot be computed);
        ``sample_ious`` lists the IoU of each sample (0.0 when the union of
        predicted and true positives is empty).
    """
    model.eval()
    true_positives = 0
    union_positives = 0

    all_probs = []
    all_preds = []
    all_targets = []
    sample_ious = []

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Eval:"
    print_freq = 10
    metric_logger.add_meter("mean_iou", utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    for batch in metric_logger.log_every(data_loader, print_freq, header):
        samples = batch['embedding'].to(device)
        targets = batch['labels'].to(device)
        mask = batch['mask'].to(device).bool()
        samples = samples.transpose(1, 2)  # [B, L, C] -> [B, C, L]

        with torch.cuda.amp.autocast():
            output = model(samples)
            probs = torch.softmax(output, dim=1)[:, 1, :]
            preds = (probs > threshold).long()
            preds = preds.masked_fill(~mask, 0)

            # Per-sample IoU between predicted and true positive positions.
            for i in range(samples.shape[0]):
                pred_i = preds[i]
                target_i = targets[i]
                mask_i = mask[i]

                tp_i = ((pred_i == 1) & (target_i == 1) & mask_i).sum().item()
                union_i = (((pred_i == 1) | (target_i == 1)) & mask_i).sum().item()
                iou_i = tp_i / union_i if union_i > 0 else 0.0
                sample_ious.append(iou_i)

            # Pooled counts for the dataset-level IoU.
            tp = ((preds == 1) & (targets == 1) & mask).sum().item()
            union = (((preds == 1) | (targets == 1)) & mask).sum().item()
            true_positives += tp
            union_positives += union

            all_probs.append(probs[mask].cpu())
            all_preds.append(preds[mask].cpu())
            all_targets.append(targets[mask].cpu())

            mean_iou_so_far = sum(sample_ious) / len(sample_ious)
            metric_logger.update(mean_iou=mean_iou_so_far)

    metric_logger.synchronize_between_processes()

    all_probs = torch.cat(all_probs).numpy()
    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    agiou = true_positives / union_positives if union_positives > 0 else 0.0

    # Probability-based metrics can fail (e.g. only one class present in the
    # targets); such metrics are reported as NaN instead of aborting.
    auc = _safe_metric("AUC", roc_auc_score, all_targets, all_probs)
    pr_auc = _safe_metric("PR-AUC", average_precision_score, all_targets, all_probs)
    pcc = _safe_metric("PCC", lambda: np.corrcoef(all_probs, all_targets)[0, 1])
    brier = _safe_metric("Brier", brier_score_loss, all_targets, all_probs)
    bce = _safe_metric("BCE", log_loss, all_targets, all_probs, labels=[0, 1])

    results = {
        "AgIoU": round(agiou, 4),
        "Precision": round(precision_score(all_targets, all_preds, zero_division=0), 4),
        "Recall": round(recall_score(all_targets, all_preds, zero_division=0), 4),
        "F1": round(f1_score(all_targets, all_preds, zero_division=0), 4),
        "MCC": round(matthews_corrcoef(all_targets, all_preds), 4),
        "Accuracy": round(accuracy_score(all_targets, all_preds), 4),
        "AUC": round(auc, 4),
        "PR-AUC": round(pr_auc, 4),
        "PCC": round(pcc, 4),
        "Brier": round(brier, 4),
        "BCE": round(bce, 4)
    }

    print("\nEvaluation Results:")
    for k, v in results.items():
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

    return results, sample_ious

In [51]:
def sequence_loss(pred, target, mask):
    """
    Masked per-position cross-entropy, averaged over the selected positions.

    pred: [B, C, L]  class logits
    target: [B, L]   class indices
    mask: [B, L]     (bool) True marks positions that count towards the loss;
                     non-bool masks are interpreted as "non-zero == True"

    Returns a scalar tensor holding the mean cross-entropy over the positions
    where mask is True. If the mask selects nothing, a zero that is still
    attached to the autograd graph is returned instead of the NaN that the
    mean of an empty tensor would produce.
    """
    _, C, _ = pred.shape  # also rejects inputs that are not 3-dimensional
    logits = pred.transpose(1, 2).reshape(-1, C)    # [B*L, C]
    labels = target.reshape(-1)                     # [B*L]
    # Coerce to bool so an integer 0/1 mask is used as a mask rather than
    # being treated as a list of indices by the indexing below.
    valid = mask.reshape(-1).bool()                 # [B*L], bool

    per_position = F.cross_entropy(logits, labels, reduction='none')  # [B*L]
    selected = per_position[valid]  # only valid positions

    if selected.numel() == 0:
        # mean() over zero elements is NaN; sum() over zero elements is 0 and
        # keeps a grad_fn, so backward() still works and yields zero gradients.
        return selected.sum()
    return selected.mean()

In [52]:
def criterion(output, target, mask):
    """Loss callable handed to the training loop; forwards to sequence_loss unchanged."""
    loss = sequence_loss(output, target, mask)
    return loss

### 3.1 Train

#### Model

In [53]:
# De novo training: build the model on the available device, then report its
# parameter count and MACs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BConformer, Conformer-12
model = Conformer(in_chans=1280, num_classes=2, depth=12).to(device)
# Conformer-9
# model = Conformer(in_chans=1280, num_classes=2, depth=9).to(device)
# Conformer-6
# model = Conformer(in_chans=1280, num_classes=2, depth=6).to(device)

# Complexity is measured for one input of shape (1280, 1024); the first entry
# matches in_chans, the second is presumably the sequence length - TODO confirm.
macs, params = get_model_complexity_info(
    model,
    (1280, 1024),
    as_strings=True,
    verbose=False,
)
print(f"Params: {params}, MACs: {macs}")

Conformer(
  302.24 M, 99.999% Params, 309.79 GMac, 99.956% MACs, 
  (trans_norm): LayerNorm(3.07 k, 0.001% Params, 1.57 MMac, 0.001% MACs, (1536,), eps=1e-05, elementwise_affine=True)
  (trans_token_head): Linear(3.07 k, 0.001% Params, 3.15 MMac, 0.001% MACs, in_features=1536, out_features=2, bias=True)
  (conv_cls_head): Conv1d(5.12 k, 0.002% Params, 5.24 MMac, 0.002% MACs, 2560, 2, kernel_size=(1,), stride=(1,))
  (conv1): Conv1d(2.87 M, 0.949% Params, 2.94 GMac, 0.947% MACs, 1280, 320, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
  (bn1): BatchNorm1d(640, 0.000% Params, 655.36 KMac, 0.000% MACs, 320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(0, 0.000% Params, 327.68 KMac, 0.000% MACs, inplace=True)
  (maxpool): MaxPool1d(0, 0.000% Params, 327.68 KMac, 0.000% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
  (conv_1): ConvBlock(
    438.4 k, 0.145% Params, 449.9 MMac, 0.145% MACs, 
    (conv1): Conv1d(51.2 k, 0.

In [None]:
# training on pretrained models
# model_name = "..." # model name, e.g. model_epoch104_AgIoU0.6534.pth
# model_path = os.path.join("...", model_name) # directory + model name
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = Conformer(in_chans=1280, num_classes=2)
# state = torch.load(model_path, map_location=device, weights_only=False)
# model.load_state_dict(state['model_state_dict'])
# model.to(device)

#### Training with checkpoints

In [None]:
# Training run: optimise `model` for `epochs` epochs, evaluate after every
# epoch, and write a checkpoint for every epoch from `save_from_epoch` onwards.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler()
threshold = 0.3  # forwarded to evaluate(); presumably the probability cut-off for a positive prediction
epochs = 150
# First (1-based) epoch that is checkpointed. The upper bound is the loop
# itself, so changing `epochs` never silently disables checkpointing.
save_from_epoch = 50
all_metrics = []  # evaluation metrics, one entry per completed epoch

save_dir = "..." # directory saving models
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    epoch_num = epoch + 1  # 1-based number used for logging and file names
    print(f"Epoch {epoch_num}/{epochs}")
    train_stats = train_one_epoch(model, criterion, dataloader, optimizer, device, epoch, scaler)
    # NOTE(review): evaluation uses the same `dataloader` that was just trained
    # on - confirm whether a separate validation loader is intended here.
    val_stats, _ = evaluate(dataloader, model, device, threshold)

    all_metrics.append(val_stats)

    print(f"Train loss: {train_stats['loss']:.4f}\n")

    if epoch_num >= save_from_epoch:
        save_path = os.path.join(save_dir, f"model_epoch{epoch_num}_AgIoU{val_stats['AgIoU']:.4f}.pth")
        # Everything needed to resume training, plus the metric history so far.
        torch.save({
            'epoch': epoch_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'all_metrics': all_metrics,
        }, save_path)

In [None]:
# save training process (metrics at each epoch) to a csv.
df = pd.DataFrame(all_metrics)
df.insert(0, "Epoch", range(1, len(df) + 1))

df.to_csv("....csv", index=False)
print("Successfully saved.")

### 3.2 Evaluate

In [None]:
# model_name = "..." # model name, e.g. model_epoch96_AgIoU0.6326.pth
# model_path = os.path.join("...", model_name) # checkpoints directory + model name
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = Conformer(in_chans=1280, num_classes=2)
# state = torch.load(model_path, map_location=device, weights_only=False)
# model.load_state_dict(state['model_state_dict'])
# model.to(device)
# thresholds = np.linspace(0.28, 0.32, 40)
# collected_metrics = {}

# for threshold in thresholds:
#     with torch.no_grad():
#         metrics, _ = evaluate_get_sample_iou(dataloader, model, device, threshold)
#         for k, v in metrics.items():
#             collected_metrics.setdefault(k, []).append(v)

# # metrics (mean ± std)
# results_summary = {}
# for k, v_list in collected_metrics.items():
#     v_array = np.array(v_list)
#     mean = np.mean(v_array)
#     std = np.std(v_array)
#     results_summary[k] = f"{mean:.3f} ± {std:.3f}"

# for k, v in results_summary.items():
#     print(f"{k}: {v}")