# nano-DyMEAN

## Description
Personal modifications of the dyMEAN framework tailored for Nanobody analysis.
We start from the dyMEAN paper's implementation and modify its scripts so that they no longer require an explicit light chain.
The only inputs to dyMEAN are the JSON split files (`train.json`, `valid.json` and `test.json`) and the `all_structure` folder, which contains all PDB structures.

In [1]:

########### Import your packages below ##########

import os
import requests
import numpy as np
import pickle
import re
import random
import sys
import datetime
import json
import argparse
from random import random
import random

from tqdm import tqdm
from math import cos, pi, log, exp, nan
from copy import copy, deepcopy
from typing import Dict, List, Tuple

from scipy.spatial.transform import Rotation


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_std, scatter_softmax, scatter_sum
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter



from Bio.PDB import PDBParser, PDBIO
from Bio.PDB.Structure import Structure as BStructure
from Bio.PDB.Model import Model as BModel
from Bio.PDB.Chain import Chain as BChain
from Bio.PDB.Residue import Residue as BResidue
from Bio.PDB.Atom import Atom as BAtom


In [2]:

SEED = 12


class Config:
    """Lightweight attribute bag: each keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

# Training / model hyper-parameters (one attribute per keyword).
config = Config(
    backbone_only=False,
    batch_size=1,
    bind_dist_cutoff=6.6,
    cdr=['H3'],
    embed_dim=64,
    final_lr=0.0001,
    fix_channel_weights=False,
    gpus=[0],  # multi-GPU alternative: gpus=[0, 1]
    grad_clip=1.0,
    hidden_size=128,
    iter_round=3,
    k_neighbors=9,
    local_rank=-1,
    lr=0.001,
    max_epoch=200,
    model_type='dyMEAN',
    n_layers=3,
    no_memory=False,
    no_pred_edge_dist=False,
    num_workers=4,
    paratope='H3',
    patience=1000,
    save_dir='./all_data/RAbD/models_single_cdr_design',
    save_topk=10,
    seq_warmup=0,
    shuffle=True,
    struct_only=False,
    train_set='./all_data/RAbD/train.json',
    valid_set='./all_data/RAbD/valid.json',
    warmup=0,
)



In [3]:
LEVELS = ['TRACE', 'DEBUG', 'INFO', 'WARN', 'ERROR']
LEVELS_MAP = None


def init_map():
    """Build LEVELS_MAP, mapping each level name to its position in LEVELS."""
    global LEVELS_MAP
    LEVELS_MAP = {name: prio for prio, name in enumerate(LEVELS)}


def get_prio(level):
    """Return the numeric priority of a level name (case-insensitive).

    LEVELS_MAP is built lazily on first use. An unknown name raises KeyError.
    """
    if LEVELS_MAP is None:
        init_map()
    return LEVELS_MAP[level.upper()]


def print_log(s, level='INFO', end='\n', no_prefix=False):
    """Print ``s`` when ``level`` is at least the threshold from the LOG env var.

    The threshold defaults to 'INFO' when LOG is unset. Unless ``no_prefix`` is
    set, the message is prefixed with 'YYYY-mm-dd HH:MM:SS::LEVEL::'. stdout is
    flushed after each emitted message.
    """
    threshold = get_prio(os.getenv('LOG', 'INFO'))
    if get_prio(level) < threshold:
        return
    if no_prefix:
        prefix = ''
    else:
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        prefix = stamp + f'::{level.upper()}::'
    print(prefix, s, sep='', end=end)
    sys.stdout.flush()

def setup_seed(seed):
    """Seed all RNGs used in this notebook and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device).

    Args:
        seed: integer seed applied to every generator.
    """
    # The top-of-file imports bind the name ``random`` twice (first the function,
    # then the module). Import locally so this helper does not depend on
    # whichever binding is current.
    import random as py_random

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    np.random.seed(seed)
    py_random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning chooses kernels by timing, and that choice can differ
    # between runs. Turn it off so the deterministic setting is not undermined.
    torch.backends.cudnn.benchmark = False


# Seed Python, NumPy and PyTorch RNGs with the global SEED as soon as this cell runs.
setup_seed(SEED)



In [4]:


'''
Four parts:
1. basic variables
2. benchmark definitions and configs for data processing
3. definitions for antibody numbering system
4. optional dependencies for pipelines
'''

# 1. basic variables
PROJ_DIR = './'
RENUMBER = os.path.join(PROJ_DIR, 'utils', 'renumber.py')
# FoldX
FOLDX_BIN = './foldx5/foldx_20231231'
# DockQ
# IMPORTANT: change this to the path of your DockQ project
DOCKQ_DIR = './DockQ'
# cache directory, created when this cell runs if it does not exist yet
CACHE_DIR = os.path.join(PROJ_DIR, '__cache__')
# exist_ok=True avoids the check-then-create race when several processes or
# workers import this at the same time. It also keeps re-running the cell harmless.
os.makedirs(CACHE_DIR, exist_ok=True)


# 2. configs related to data process
# Antigen types recognised in this section's data-processing configuration.
AG_TYPES = ['protein', 'peptide']
# PDB ids of the RAbD benchmark complexes (presumably used to select the RAbD split — confirm at usage sites).
RAbD_PDB = ['1a14', '1a2y', '1fe8', '1ic7', '1iqd', '1n8z', '1ncb', '1osp', '1uj3', '1w72', '2adf', '2b2x', '2cmr', '2dd8', '2ghw', '2vxt', '2xqy', '2xwt', '2ypv', '3bn9', '3cx5', '3ffd', '3h3b', '3hi6', '3k2u', '3l95', '3mxw', '3nid', '3o2d', '3rkd', '3s35', '3uzq', '3w9e', '4cmh', '4dtg', '4dvr', '4etq', '4ffv', '4fqj', '4g6j', '4g6m', '4h8w', '4ki5', '4lvn', '4ot1', '4qci', '4xnq', '4ydk', '5b8c', '5bv7', '5d93', '5d96', '5en2', '5f9o', '5ggs', '5hi4', '5j13', '5l6y', '5mes', '5nuz']
# PDB ids of the IgFold test set (presumably used for structure-prediction evaluation — confirm at usage sites).
IGFOLD_TEST_PDB = ['7rdm', '6xsw', '7e9b', '7cj2', '7o33', '7ora', '7mzh', '7mzj', '7ahu', '7s13', '7mzk', '7e72', '7n3c', '7n3f', '7rdl', '7mzg', '7n4j', '7r9d', '7e3o', '7rah', '7or9', '7oo2', '6xsn', '7arn', '7n3e', '7o30', '7kf0', '7lfa', '7nx3', '7keo', '7mzf', '7o31', '7e5o', '7daa', '7s11', '7aj6', '7m2i', '7kf1', '7jkm', '7n4i', '6xm2', '7doh', '7o2z', '7kba', '7s3m', '7mzm', '7rks', '7n3h', '7lyw', '7rco', '7bg1', '7coe', '7n3g', '7kkz', '7kyo', '7s4s', '7rnj', '7bbg', '7l7e', '7n0u', '6xlz', '7mf7', '6xp6', '7lfb', '7kn3', '7rdk', '7s0b', '7kez', '7n3d', '7o4y']
# PDB ids of the SKEMPI antibody complexes (presumably for affinity-related evaluation — confirm at usage sites).
SKEMPI_PDB = ['1ahw', '1dvf', '1vfb', '2vis', '2vir', '1kiq', '1kip', '1kir', '2jel', '1nca', '1dqj', '1jrh', '1nmb', '3hfm', '1yy9', '4gxu', '3lzf', '1n8z', '3g6d', '1xgu', '1xgp', '1xgq', '1xgr', '1xgt', '3n85', '4i77', '3l5x', '4jpk', '1bj1', '1cz8', '1mhp', '2b2x', '1mlc', '3bdy', '3be1', '2ny7', '3idx', '2nyy', '2nz9', '3ngb', '2bdn', '3w2d', '4krl', '4kro', '4krp', '4nm8', '4u6h', '4zs6', '5c6t', '5dwu', '3se8', '3se9', '1yqv']
CONTACT_DIST = 6.6  # 6.6 A between one pair of atoms means the two residues are interacting


# 3. antibody numbering, [start, end] of residue id, both start & end are included
# 3.1 IMGT numbering definition
class IMGT:
    """IMGT numbering scheme.

    Region attributes are ``(start, end)`` residue-number ranges with both ends
    included. ``Hconserve`` / ``Lconserve`` map residue numbers to the residue
    names expected there.
    """
    # heavy chain
    HFR1 = (1, 26)
    HFR2 = (39, 55)
    HFR3 = (66, 104)
    HFR4 = (118, 129)

    H1 = (27, 38)
    H2 = (56, 65)
    H3 = (105, 117)

    # light chain
    LFR1 = (1, 26)
    LFR2 = (39, 55)
    LFR3 = (66, 104)
    LFR4 = (118, 129)

    L1 = (27, 38)
    L2 = (56, 65)
    L3 = (105, 117)

    Hconserve = {
        23: ['CYS'],
        41: ['TRP'],
        104: ['CYS']
    }

    Lconserve = {
        23: ['CYS'],
        41: ['TRP'],
        104: ['CYS']
    }

    @classmethod
    def renumber(cls, pdb, out_pdb):
        """Renumber ``pdb`` into ``out_pdb`` with the IMGT scheme via the RENUMBER script.

        Args:
            pdb: path of the input PDB file.
            out_pdb: path where the renumbered PDB is written.

        Returns:
            The script's exit code: 0 on success, non-zero on failure.
        """
        import subprocess
        import sys

        # Pass an argument list with no shell, so paths containing spaces or
        # shell metacharacters reach the script intact. Use the running
        # interpreter rather than whatever ``python`` is on PATH.
        cmd = [sys.executable, str(RENUMBER), str(pdb), str(out_pdb), 'imgt', '0']
        return subprocess.run(cmd).returncode

# 3.2 Chothia numbering definition
class Chothia:
    """Chothia numbering scheme.

    Region attributes are ``(start, end)`` residue-number ranges with both ends
    included. ``Hconserve`` / ``Lconserve`` map residue numbers to the residue
    names expected there.
    """
    # heavy chain
    HFR1 = (1, 25)
    HFR2 = (33, 51)
    HFR3 = (57, 94)
    HFR4 = (103, 113)

    H1 = (26, 32)
    H2 = (52, 56)
    H3 = (95, 102)

    # light chain
    LFR1 = (1, 23)
    LFR2 = (35, 49)
    LFR3 = (57, 88)
    LFR4 = (98, 107)

    L1 = (24, 34)
    L2 = (50, 56)
    L3 = (89, 97)

    Hconserve = {
        92: ['CYS']
    }

    Lconserve = {
        88: ['CYS']
    }

    @classmethod
    def renumber(cls, pdb, out_pdb):
        """Renumber ``pdb`` into ``out_pdb`` with the Chothia scheme via the RENUMBER script.

        Args:
            pdb: path of the input PDB file.
            out_pdb: path where the renumbered PDB is written.

        Returns:
            The script's exit code: 0 on success, non-zero on failure.
        """
        import subprocess
        import sys

        # Pass an argument list with no shell, so paths containing spaces or
        # shell metacharacters reach the script intact. Use the running
        # interpreter rather than whatever ``python`` is on PATH.
        cmd = [sys.executable, str(RENUMBER), str(pdb), str(out_pdb), 'chothia', '0']
        return subprocess.run(cmd).returncode


# (Optional) 4. dependencies for pipelines
# 4.1 structure prediction
IGFOLD_DIR, IGFOLD_CKPTS = './IgFold', None  # IGFOLD_CKPTS=None -> use IgFold's default checkpoints
# 4.2 docking
HDOCK_DIR = './HDOCKlite-v1.1'
# 4.3 CDR generation models (baselines, each in its own checkout)
MEAN_DIR = './MEAN'
Rosetta_DIR = './rosetta/rosetta.binary.linux.release-315/main/source/bin'
DiffAb_DIR = './diffab'
DiffAb_DIR = './diffab'

In [5]:




class AminoAcid:
    """One vocabulary entry.

    Holds the one-letter symbol, the three-letter abbreviation, the sidechain
    atom names and the entry's index in the vocabulary.
    """

    def __init__(self, symbol: str, abrv: str, sidechain: List[str], idx=0):
        self.symbol, self.abrv = symbol, abrv
        self.idx = idx
        self.sidechain = sidechain

    def __str__(self):
        fields = (self.idx, self.symbol, self.abrv, self.sidechain)
        return ' '.join(str(field) for field in fields)


class AminoAcidVocab:
    """Residue-level and atom-level vocabulary.

    Residue entries are the 20 standard amino acids followed by 5 special
    tokens: PAD, MASK, and the begin-of-antigen / heavy-chain / light-chain
    markers. Each amino acid carries its sidechain atom names, chi-angle atom
    quadruples and intra-sidechain bond lists. Two small atom vocabularies are
    also built: elements (idx2atom) and positions (idx2atom_pos).
    """

    MAX_ATOM_NUMBER = 14   # 4 backbone atoms + up to 10 sidechain atoms

    def __init__(self):
        self.backbone_atoms = ['N', 'CA', 'C', 'O']
        self.PAD, self.MASK = '#', '*'
        self.BOA, self.BOH, self.BOL = '&', '+', '-' # begin of antigen, heavy chain, light chain
        specials = [# special added
                (self.PAD, 'PAD'), (self.MASK, 'MASK'), # mask for masked / unknown residue
                (self.BOA, '<X>'), (self.BOH, '<H>'), (self.BOL, '<L>')
            ]
        aas = [
                ('G', 'GLY'), ('A', 'ALA'), ('V', 'VAL'), ('L', 'LEU'),
                ('I', 'ILE'), ('F', 'PHE'), ('W', 'TRP'), ('Y', 'TYR'),
                ('D', 'ASP'), ('H', 'HIS'), ('N', 'ASN'), ('E', 'GLU'),
                ('K', 'LYS'), ('Q', 'GLN'), ('M', 'MET'), ('R', 'ARG'),
                ('S', 'SER'), ('T', 'THR'), ('C', 'CYS'), ('P', 'PRO') # 20 aa
                # ('U', 'SEC') # 21 aa for eukaryote
            ]

        # max number of sidechain atoms: 10 (TRP)
        # padding / mask tokens of the atom-element vocabulary
        self.atom_pad, self.atom_mask = 'p', 'm'
        # mask / backbone / padding tokens of the atom-position vocabulary
        self.atom_pos_mask, self.atom_pos_bb, self.atom_pos_pad = 'm', 'b', 'p'
        sidechain_map = {
            'G': [],   # -H
            'A': ['CB'],  # -CH3
            'V': ['CB', 'CG1', 'CG2'],  # -CH-(CH3)2
            'L': ['CB', 'CG', 'CD1', 'CD2'],  # -CH2-CH(CH3)2
            'I': ['CB', 'CG1', 'CG2', 'CD1'], # -CH(CH3)-CH2-CH3
            'F': ['CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],  # -CH2-C6H5
            'W': ['CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],  # -CH2-C8NH6
            'Y': ['CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH'],  # -CH2-C6H4-OH
            'D': ['CB', 'CG', 'OD1', 'OD2'],  # -CH2-COOH
            'H': ['CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2'],  # -CH2-C3H3N2
            'N': ['CB', 'CG', 'OD1', 'ND2'],  # -CH2-CONH2
            'E': ['CB', 'CG', 'CD', 'OE1', 'OE2'],  # -(CH2)2-COOH
            'K': ['CB', 'CG', 'CD', 'CE', 'NZ'],  # -(CH2)4-NH2
            'Q': ['CB', 'CG', 'CD', 'OE1', 'NE2'],  # -(CH2)2-CONH2
            'M': ['CB', 'CG', 'SD', 'CE'],  # -(CH2)2-S-CH3
            'R': ['CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],  # -(CH2)3-NHC(NH)NH2
            'S': ['CB', 'OG'],  # -CH2-OH
            'T': ['CB', 'OG1', 'CG2'],  # -CH(CH3)-OH
            'C': ['CB', 'SG'],  # -CH2-SH
            'P': ['CB', 'CG', 'CD'],  # -C3H6
        }

        # For each residue (three-letter code): the four atom names defining each chi dihedral.
        self.chi_angles_atoms = {
            "ALA": [],
            # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
            "ARG": [
                ["N", "CA", "CB", "CG"],
                ["CA", "CB", "CG", "CD"],
                ["CB", "CG", "CD", "NE"],
                ["CG", "CD", "NE", "CZ"],
            ],
            "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
            "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
            "CYS": [["N", "CA", "CB", "SG"]],
            "GLN": [
                ["N", "CA", "CB", "CG"],
                ["CA", "CB", "CG", "CD"],
                ["CB", "CG", "CD", "OE1"],
            ],
            "GLU": [
                ["N", "CA", "CB", "CG"],
                ["CA", "CB", "CG", "CD"],
                ["CB", "CG", "CD", "OE1"],
            ],
            "GLY": [],
            "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
            "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
            "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
            "LYS": [
                ["N", "CA", "CB", "CG"],
                ["CA", "CB", "CG", "CD"],
                ["CB", "CG", "CD", "CE"],
                ["CG", "CD", "CE", "NZ"],
            ],
            "MET": [
                ["N", "CA", "CB", "CG"],
                ["CA", "CB", "CG", "SD"],
                ["CB", "CG", "SD", "CE"],
            ],
            "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
            "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
            "SER": [["N", "CA", "CB", "OG"]],
            "THR": [["N", "CA", "CB", "OG1"]],
            "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
            "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
            "VAL": [["N", "CA", "CB", "CG1"]],
        }

        # For each residue: parent atom -> bonded child atoms, starting from CA.
        # Rings are listed as trees only; e.g. PHE has CE1->CZ but not CE2->CZ.
        self.sidechain_bonds = {
            "ALA": { "CA": ["CB"] },
            "GLY": {},
            "VAL": {
                "CA": ["CB"],
                "CB": ["CG1", "CG2"]
            },
            "LEU": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD2", "CD1"]
            },
            "ILE": {
                "CA": ["CB"],
                "CB": ["CG1", "CG2"],
                "CG1": ["CD1"]
            },
            "MET": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["SD"],
                "SD": ["CE"],
            },
            "PHE": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD1", "CD2"],
                "CD1": ["CE1"],
                "CD2": ["CE2"],
                "CE1": ["CZ"]
            },
            "TRP": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD1", "CD2"],
                "CD1": ["NE1"],
                "CD2": ["CE2", "CE3"],
                "CE2": ["CZ2"],
                "CZ2": ["CH2"],
                "CE3": ["CZ3"]
            },
            "PRO": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD"]
            },
            "SER": {
                "CA": ["CB"],
                "CB": ["OG"]
            },
            "THR": {
                "CA": ["CB"],
                "CB": ["OG1", "CG2"]
            },
            "CYS": {
                "CA": ["CB"],
                "CB": ["SG"]
            },
            "TYR": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD1", "CD2"],
                "CD1": ["CE1"],
                "CD2": ["CE2"],
                "CE1": ["CZ"],
                "CZ": ["OH"]
            },
            "ASN": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["OD1", "ND2"]
            },
            "GLN": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD"],
                "CD": ["OE1", "NE2"]
            },
            "ASP": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["OD1", "OD2"]
            },
            "GLU": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD"],
                "CD": ["OE1", "OE2"]
            },
            "LYS": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD"],
                "CD": ["CE"],
                "CE": ["NZ"]
            },
            "ARG": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["CD"],
                "CD": ["NE"],
                "NE": ["CZ"],
                "CZ": ["NH1", "NH2"]
            },
            "HIS": {
                "CA": ["CB"],
                "CB": ["CG"],
                "CG": ["ND1", "CD2"],
                "ND1": ["CE1"],
                "CD2": ["NE2"]
            }
        }


        # Amino acids take indices 0-19 and specials 20-24. special_mask is 1 for specials.
        _all = aas + specials
        self.amino_acids = [AminoAcid(symbol, abrv, sidechain_map.get(symbol, [])) for symbol, abrv in _all]
        self.symbol2idx, self.abrv2idx = {}, {}
        for i, aa in enumerate(self.amino_acids):
            self.symbol2idx[aa.symbol] = i
            self.abrv2idx[aa.abrv] = i
            aa.idx = i
        self.special_mask = [0 for _ in aas] + [1 for _ in specials]

        # atom level vocab
        # Elements use the first character of an atom name. Positions use the second
        # character of a sidechain atom name (CB -> 'B', CG1 -> 'G'); backbone atoms use 'b'.
        self.idx2atom = [self.atom_pad, self.atom_mask, 'C', 'N', 'O', 'S']
        self.idx2atom_pos = [self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_bb, 'B', 'G', 'D', 'E', 'Z', 'H']
        self.atom2idx, self.atom_pos2idx = {}, {}
        for i, atom in enumerate(self.idx2atom):
            self.atom2idx[atom] = i
        for i, atom_pos in enumerate(self.idx2atom_pos):
            self.atom_pos2idx[atom_pos] = i
    
    def abrv_to_symbol(self, abrv):
        """Three-letter abbreviation (case-insensitive) -> one-letter symbol, or None if unknown."""
        idx = self.abrv_to_idx(abrv)
        return None if idx is None else self.amino_acids[idx].symbol

    def symbol_to_abrv(self, symbol):
        """One-letter symbol (case-insensitive) -> three-letter abbreviation, or None if unknown."""
        idx = self.symbol_to_idx(symbol)
        return None if idx is None else self.amino_acids[idx].abrv

    def abrv_to_idx(self, abrv):
        """Vocabulary index of a three-letter abbreviation (upper-cased first), or None."""
        abrv = abrv.upper()
        return self.abrv2idx.get(abrv, None)

    def symbol_to_idx(self, symbol):
        """Vocabulary index of a one-letter symbol (upper-cased first), or None."""
        symbol = symbol.upper()
        return self.symbol2idx.get(symbol, None)
    
    def idx_to_symbol(self, idx):
        return self.amino_acids[idx].symbol

    def idx_to_abrv(self, idx):
        return self.amino_acids[idx].abrv

    def get_pad_idx(self):
        return self.symbol_to_idx(self.PAD)

    def get_mask_idx(self):
        return self.symbol_to_idx(self.MASK)
    
    def get_special_mask(self):
        """Copy of the per-index flags: 0 for amino acids, 1 for special tokens."""
        return copy(self.special_mask)

    def get_atom_type_mat(self):
        """Return one row of MAX_ATOM_NUMBER atom-element indices per vocabulary entry.

        The PAD row is all atom-pad. Other special tokens are all atom-mask.
        Amino acids list the element of each backbone atom, then each sidechain
        atom (first character of the atom name); remaining slots are atom-pad.
        """
        atom_pad = self.get_atom_pad_idx()
        mat = []
        for i, aa in enumerate(self.amino_acids):
            atoms = [atom_pad for _ in range(self.MAX_ATOM_NUMBER)]
            if aa.symbol == self.PAD:
                pass
            elif self.special_mask[i] == 1:  # specials
                atom_mask = self.get_atom_mask_idx()
                atoms = [atom_mask for _ in range(self.MAX_ATOM_NUMBER)]
            else:
                for aidx, atom in enumerate(self.backbone_atoms + aa.sidechain):
                    atoms[aidx] = self.atom_to_idx(atom[0])
            mat.append(atoms)
        return mat

    def get_atom_pos_mat(self):
        """Return one row of MAX_ATOM_NUMBER atom-position indices per vocabulary entry.

        The PAD row is all pos-pad. Other special tokens are all pos-mask.
        Amino acids get the backbone marker for their 4 backbone atoms, then the
        second character of each sidechain atom name; remaining slots are pos-pad.
        """
        atom_pos_pad = self.get_atom_pos_pad_idx()
        mat = []
        for i, aa in enumerate(self.amino_acids):
            aps = [atom_pos_pad for _ in range(self.MAX_ATOM_NUMBER)]
            if aa.symbol == self.PAD:
                pass
            elif self.special_mask[i] == 1:
                atom_pos_mask = self.get_atom_pos_mask_idx()
                aps = [atom_pos_mask for _ in range(self.MAX_ATOM_NUMBER)]
            else:
                aidx = 0
                for _ in self.backbone_atoms:
                    aps[aidx] = self.atom_pos_to_idx(self.atom_pos_bb)
                    aidx += 1
                for atom in aa.sidechain:
                    aps[aidx] = self.atom_pos_to_idx(atom[1])
                    aidx += 1
            mat.append(aps)
        return mat

    def get_sidechain_info(self, symbol):
        """Copy of the sidechain atom names for a one-letter symbol.

        An unknown symbol raises TypeError, because the looked-up index is None.
        """
        idx = self.symbol_to_idx(symbol)
        return copy(self.amino_acids[idx].sidechain)
    
    def get_sidechain_geometry(self, symbol):
        """Return ``(chi_angles_atoms, sidechain_bonds)`` for a one-letter symbol.

        The chi list is a shallow copy; the bonds dict is the shared original.
        Special tokens have no entry and raise KeyError.
        """
        abrv = self.symbol_to_abrv(symbol)
        chi_angles_atoms = copy(self.chi_angles_atoms[abrv])
        sidechain_bonds = self.sidechain_bonds[abrv]
        return (chi_angles_atoms, sidechain_bonds)
    
    def get_atom_pad_idx(self):
        return self.atom2idx[self.atom_pad]
    
    def get_atom_mask_idx(self):
        return self.atom2idx[self.atom_mask]
    
    def get_atom_pos_pad_idx(self):
        return self.atom_pos2idx[self.atom_pos_pad]

    def get_atom_pos_mask_idx(self):
        return self.atom_pos2idx[self.atom_pos_mask]
    
    def idx_to_atom(self, idx):
        return self.idx2atom[idx]

    def atom_to_idx(self, atom):
        return self.atom2idx[atom]

    def idx_to_atom_pos(self, idx):
        return self.idx2atom_pos[idx]
    
    def atom_pos_to_idx(self, atom_pos):
        return self.atom_pos2idx[atom_pos]

    def get_num_atom_type(self):
        return len(self.idx2atom)
    
    def get_num_atom_pos(self):
        return len(self.idx2atom_pos)

    def get_num_amino_acid_type(self):
        """Number of non-special entries (20 amino acids in this vocabulary)."""
        return len(self.special_mask) - sum(self.special_mask)

    def __len__(self):
        """Number of distinct symbols, i.e. amino acids plus special tokens."""
        return len(self.symbol2idx)


# Shared module-level vocabulary. Residue below uses it for symbol, abbreviation and sidechain lookups.
VOCAB = AminoAcidVocab()


def format_aa_abrv(abrv):  # special cases
    """Normalise non-standard residue abbreviations.

    Selenomethionine ('MSE') is substituted with 'MET'. Every other
    abbreviation is returned unchanged.
    """
    return 'MET' if abrv == 'MSE' else abrv


class Residue:
    """A single residue.

    Holds the one-letter symbol, a map from atom name to coordinate, the
    sidechain atom names taken from VOCAB, and the PDB id tuple
    ``(residue_number, insert_code)``.
    """

    def __init__(self, symbol: str, coordinate: Dict, _id: Tuple):
        self.symbol = symbol
        self.coordinate = coordinate
        self.sidechain = VOCAB.get_sidechain_info(symbol)
        self.id = _id  # (residue_number, insert_code)

    def get_symbol(self):
        return self.symbol

    def get_coord(self, atom_name):
        """Shallow copy of the coordinate of ``atom_name`` (KeyError if absent)."""
        return copy(self.coordinate[atom_name])

    def get_coord_map(self) -> Dict[str, List]:
        """Deep copy of the whole atom-name -> coordinate map."""
        return deepcopy(self.coordinate)

    def get_backbone_coord_map(self) -> Dict[str, List]:
        """Backbone atoms only, in stored order (values are not copied)."""
        return {
            name: xyz
            for name, xyz in self.coordinate.items()
            if name in VOCAB.backbone_atoms
        }

    def get_sidechain_coord_map(self) -> Dict[str, List]:
        """Sidechain atoms that are present, in sidechain order (values are not copied)."""
        return {
            name: self.coordinate[name]
            for name in self.sidechain
            if name in self.coordinate
        }

    def get_atom_names(self):
        return list(self.coordinate)

    def get_id(self):
        return self.id

    def set_symbol(self, symbol):
        assert VOCAB.symbol_to_abrv(symbol) is not None, f'{symbol} is not an amino acid'
        self.symbol = symbol

    def set_coord(self, coord):
        self.coordinate = deepcopy(coord)

    def dist_to(self, residue):  # measured by nearest atoms
        """Smallest atom-to-atom Euclidean distance to ``residue``; NaN if either has no atoms."""
        mine = np.array(list(self.get_coord_map().values()))
        theirs = np.array(list(residue.get_coord_map().values()))
        if not len(mine) or not len(theirs):
            return nan
        pairwise = np.linalg.norm(mine[:, None, :] - theirs[None, :, :], axis=-1)
        return np.min(pairwise)

    def to_bio(self):
        """Build a Biopython ``Residue`` holding one ``Atom`` per stored coordinate."""
        bio_residue = BResidue((' ', self.id[0], self.id[1]), VOCAB.symbol_to_abrv(self.symbol), '    ')
        for serial, (name, xyz) in enumerate(self.coordinate.items()):
            bio_residue.add(BAtom(
                name=name,
                coord=np.array(xyz, dtype=np.float32),
                bfactor=0,
                occupancy=1.0,
                altloc=' ',
                fullname=(' ' + name).ljust(4),
                serial_number=serial,
                element=name[0],  # single-letter elements only (FE, MG, ... not handled)
            ))
        return bio_residue

    def __iter__(self):
        # Iterate over a snapshot of (atom_name, coordinate) pairs.
        return iter(list(self.coordinate.items()))


class Peptide:
    """An ordered chain of residues plus its one-letter sequence string.

    The setters keep ``seq`` in sync with ``residues``.
    """

    def __init__(self, _id, residues: List[Residue]):
        self.residues = residues
        self.seq = ''.join(res.get_symbol() for res in residues)
        self.id = _id

    def set_id(self, _id):
        self.id = _id

    def get_id(self):
        return self.id

    def get_seq(self):
        return self.seq

    def get_span(self, i, j):  # [i, j)
        """Deep-copied sub-peptide over ``[i, j)``, clamped to the sequence; None if empty."""
        start, stop = max(i, 0), min(j, len(self.seq))
        if stop <= start:
            return None
        return Peptide(self.id, deepcopy(self.residues[start:stop]))

    def get_residue(self, i):
        return deepcopy(self.residues[i])

    def get_ca_pos(self, i):
        return copy(self.residues[i].get_coord('CA'))

    def get_cb_pos(self, i):
        return copy(self.residues[i].get_coord('CB'))

    def set_residue_coord(self, i, coord):
        self.residues[i].set_coord(coord)

    def set_residue_translation(self, i, vec):
        """Shift every atom of residue ``i`` by ``vec`` (coordinates become plain lists)."""
        shifted = {
            atom: [a + b for a, b in zip(xyz, vec)]
            for atom, xyz in self.residues[i].get_coord_map().items()
        }
        self.set_residue_coord(i, shifted)

    def set_residue_symbol(self, i, symbol):
        self.residues[i].set_symbol(symbol)
        self.seq = self.seq[:i] + symbol + self.seq[i + 1:]

    def set_residue(self, i, symbol, coord):
        self.set_residue_symbol(i, symbol)
        self.set_residue_coord(i, coord)

    def to_bio(self):
        """Build a Biopython ``Chain`` with this peptide's id and residues."""
        bio_chain = BChain(id=self.id)
        for res in self.residues:
            bio_chain.add(res.to_bio())
        return bio_chain

    def __iter__(self):
        return iter(self.residues)

    def __len__(self):
        return len(self.seq)

    def __str__(self):
        return self.seq


class Protein:
    """A set of peptide chains keyed by chain id, plus the structure's PDB id."""

    def __init__(self, pdb_id, peptides):
        self.pdb_id = pdb_id
        self.peptides = peptides  # chain id -> Peptide

    @classmethod
    def from_pdb(cls, pdb_path):
        """Parse a PDB file into a Protein.

        The id is the header ``idcode`` upper-cased. When that is empty, the file
        name up to its first '.' is used, followed by '(filename)'. Hetero
        residues (water, ligands, ...) are skipped. A chain that contains a
        residue name unknown to VOCAB, or that has no residues left, is dropped
        entirely. Hydrogen atoms are discarded.
        """
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure('anonym', pdb_path)
        pdb_id = structure.header['idcode'].upper().strip()
        if pdb_id == '':
            # deduce from file name
            pdb_id = os.path.split(pdb_path)[1].split('.')[0] + '(filename)'

        peptides = {}
        for chain in structure.get_chains():
            _id = chain.get_id()
            residues = []
            has_non_residue = False
            for residue in chain:
                abrv = residue.get_resname()
                hetero_flag, res_number, insert_code = residue.get_id()
                if hetero_flag != ' ':
                    continue   # residue from glucose or water
                symbol = VOCAB.abrv_to_symbol(abrv)
                if symbol is None:
                    has_non_residue = True
                    break
                # filter Hs because not all data include them
                atoms = { atom.get_id(): atom.get_coord() for atom in residue if atom.element != 'H' }
                residues.append(Residue(
                    symbol, atoms, (res_number, insert_code)
                ))
            if has_non_residue or len(residues) == 0:  # not a peptide
                continue
            peptides[_id] = Peptide(_id, residues)
        return cls(pdb_id, peptides)

    def get_id(self):
        return self.pdb_id

    def num_chains(self):
        return len(self.peptides)

    def get_chain(self, name):
        """Deep copy of chain ``name``, or None if absent."""
        if name in self.peptides:
            return deepcopy(self.peptides[name])
        else:
            return None

    def get_chain_names(self):
        return list(self.peptides.keys())

    def to_bio(self):
        """Build a Biopython ``Structure`` with a single model containing every chain."""
        structure = BStructure(id=self.pdb_id)
        model = BModel(id=0)
        for name in self.peptides:
            model.add(self.peptides[name].to_bio())
        structure.add(model)
        return structure

    def to_pdb(self, path, atoms=None):
        """Write the protein to ``path`` in PDB format.

        When ``atoms`` is given, only those atom names are written. The
        filtering is applied to a deep copy, so ``self`` is unchanged.
        """
        if atoms is None:
            bio_structure = self.to_bio()
        else:
            prot = deepcopy(self)
            for _, chain in prot:
                for residue in chain:
                    coordinate = {}
                    for atom in atoms:
                        if atom in residue.coordinate:
                            coordinate[atom] = residue.coordinate[atom]
                    residue.coordinate = coordinate
            bio_structure = prot.to_bio()
        io = PDBIO()
        io.set_structure(bio_structure)
        io.save(path)

    def __iter__(self):
        return iter([(c, self.peptides[c]) for c in self.peptides])

    def __eq__(self, other):
        """Sequence-level equality.

        Two proteins are equal when they have exactly the same chain names and
        every chain has the same sequence. Coordinates and ``pdb_id`` are
        ignored. Comparing with a non-Protein raises TypeError.

        The comparison is symmetric: a protein whose chains are a strict subset
        of another's is no longer considered equal to it.
        """
        if not isinstance(other, Protein):
            raise TypeError('Cannot compare other type to Protein')
        if self.peptides.keys() != other.peptides.keys():
            return False
        return all(
            self.peptides[name].seq == other.peptides[name].seq
            for name in self.peptides
        )

    def __str__(self):
        res = self.pdb_id + '\n'
        for seq_name in self.peptides:
            res += f'\t{seq_name}: {self.peptides[seq_name]}\n'
        return res


class AgAbComplex:
    """Antigen / antibody (nanobody) complex with CDR positions and epitope.

    Attributes:
        antigen: Protein holding the antigen chains.
        antibody: Protein holding the heavy chain. Unless ``skip_validity_check``
            is set, the chain is trimmed to the span covering its framework and
            CDR residues.
        cdr_pos: dict ``'CDR-H1'/'CDR-H2'/'CDR-H3' -> (start, end)``. Positions
            are inclusive and relative to the trimmed heavy chain. The dict is
            None when ``skip_validity_check`` is set.
        epitope: list of ``(residue, chain_name, index)`` antigen residues
            closest to CDR-H3, or None until computed.

    Only the heavy chain is used. ``from_pdb`` builds the antibody from the
    heavy chain alone, so ``get_light_chain()`` normally returns None.
    """

    num_interface_residues = 48  # from PNAS (view as epitope)

    def __init__(self, antigen: Protein, antibody: Protein, heavy_chain: str, light_chain: str,
                 numbering: str='imgt', skip_epitope_cal=False, skip_validity_check=False) -> None:
        self.heavy_chain = heavy_chain
        self.light_chain = light_chain  # bookkeeping only; the light chain is not loaded
        self.numbering = numbering

        self.antigen = antigen
        if skip_validity_check:
            self.antibody, self.cdr_pos = antibody, None
        else:
            self.antibody, self.cdr_pos = self._extract_antibody_info(antibody, numbering)
        self.pdb_id = antigen.get_id()

        if skip_epitope_cal:
            self.epitope = None
        else:
            self.epitope = self._cal_epitope()

    @classmethod
    def from_pdb(cls, pdb_path: str, heavy_chain: str, light_chain: str, antigen_chains: List[str],
                 numbering: str='imgt', skip_epitope_cal=False, skip_validity_check=False):
        """Build a complex from a PDB file.

        Only ``heavy_chain`` is read as the antibody (nanobody setting).
        ``light_chain`` is stored on the result but not loaded from the file.

        Raises:
            AssertionError: if an antigen chain is missing from the file, or if
                the heavy chain fails validation in ``__init__``.
        """
        protein = Protein.from_pdb(pdb_path)
        pdb_id = protein.get_id()
        # nanobody setting: the light chain is intentionally not loaded
        ab_peptides = {heavy_chain: protein.get_chain(heavy_chain)}
        ag_peptides = {}
        for chain in antigen_chains:
            peptide = protein.get_chain(chain)  # one deep copy per chain
            assert peptide is not None, f'Antigen chain {chain} has something wrong!'
            ag_peptides[chain] = peptide

        antigen = Protein(pdb_id, ag_peptides)
        antibody = Protein(pdb_id, ab_peptides)
        return cls(antigen, antibody, heavy_chain, light_chain, numbering, skip_epitope_cal, skip_validity_check)

    def _extract_antibody_info(self, antibody: Protein, numbering: str):
        """Trim the heavy chain and locate its CDRs with a numbering scheme.

        Each residue number is mapped to a region type: '0' framework,
        '1'/'2'/'3' CDR1/2/3, and '-' outside the scheme's ranges. Residues at
        conserved positions are checked against the scheme's expected amino
        acids.

        Returns:
            (antibody, cdr_pos): a new Protein holding the trimmed heavy chain,
            and a dict ``'CDR-H{1,2,3}' -> (start, end)``. Positions are
            inclusive and relative to the trimmed chain.

        Raises:
            NotImplementedError: for schemes other than IMGT / Chothia.
            AssertionError: if the chain is missing, a conserved residue does
                not match, or no framework / CDR residue is found.
        """
        numbering = numbering.lower()
        if numbering == 'imgt':
            _scheme = IMGT
        elif numbering == 'chothia':
            _scheme = Chothia
        else:
            raise NotImplementedError(f'Numbering scheme {numbering} not implemented')

        # get cdr/frame denotes
        h_type_mapping, l_type_mapping = {}, {}  # - for non-Fv region, 0 for framework, 1/2/3 for cdr1/2/3

        for lo, hi in [_scheme.HFR1, _scheme.HFR2, _scheme.HFR3, _scheme.HFR4]:
            for i in range(lo, hi + 1):
                h_type_mapping[i] = '0'
        for cdr, (lo, hi) in zip(['1', '2', '3'], [_scheme.H1, _scheme.H2, _scheme.H3]):
            for i in range(lo, hi + 1):
                h_type_mapping[i] = cdr
        h_conserved = _scheme.Hconserve

        for lo, hi in [_scheme.LFR1, _scheme.LFR2, _scheme.LFR3, _scheme.LFR4]:
            for i in range(lo, hi + 1):
                l_type_mapping[i] = '0'
        for cdr, (lo, hi) in zip(['1', '2', '3'], [_scheme.L1, _scheme.L2, _scheme.L3]):
            for i in range(lo, hi + 1):
                l_type_mapping[i] = cdr
        l_conserved = _scheme.Lconserve

        # get variable domain and cdr positions (nanobody setting: heavy chain only)
        selected_peptides, cdr_pos = {}, {}
        for c, chain_name in zip(['H'], [self.heavy_chain]):
            chain = antibody.get_chain(chain_name)
            # Note: possbly two chains are different segments of a same chain
            assert chain is not None, f'Chain {chain_name} not found in the antibody'
            type_mapping = h_type_mapping if c == 'H' else l_type_mapping
            conserved = h_conserved if c == 'H' else l_conserved
            res_type = ''
            for i in range(len(chain)):
                residue = chain.get_residue(i)
                residue_number = residue.get_id()[0]
                if residue_number in type_mapping:
                    res_type += type_mapping[residue_number]
                    if residue_number in conserved:
                        hit, symbol = False, residue.get_symbol()
                        for conserved_residue in conserved[residue_number]:
                            if symbol == VOCAB.abrv_to_symbol(conserved_residue):
                                hit = True
                                break
                        assert hit, f'Not {conserved[residue_number]} at {residue_number}'
                else:
                    res_type += '-'
            # Raise AssertionError (skippable by callers that catch it) instead of
            # the ValueError that res_type.index('0') would raise below.
            assert '0' in res_type, \
                f'No framework residue in chain {chain_name} of {antibody.pdb_id}, residue type: {res_type}'
            start, end = res_type.index('0'), res_type.rindex('0')
            for cdr in ['1', '2', '3']:
                cdr_start, cdr_end = res_type.find(cdr), res_type.rfind(cdr)
                assert cdr_start != -1, f'cdr {c}{cdr} not found, residue type: {res_type}'
                start, end = min(start, cdr_start), max(end, cdr_end)
                cdr_pos[f'CDR-{c}{cdr}'] = (cdr_start, cdr_end)
            # make CDR positions relative to the trimmed chain
            for cdr in ['1', '2', '3']:
                cdr = f'CDR-{c}{cdr}'
                cdr_start, cdr_end = cdr_pos[cdr]
                cdr_pos[cdr] = (cdr_start - start, cdr_end - start)
            chain = chain.get_span(start, end + 1)  # the length may exceed 130 because of inserted amino acids
            chain.set_id(chain_name)
            selected_peptides[chain_name] = chain

        antibody = Protein(antibody.get_id(), selected_peptides)

        return antibody, cdr_pos

    def _cal_epitope(self):
        """Select the ``num_interface_residues`` antigen residues closest to CDR-H3.

        Atoms are compared slot by slot: the same position in the padded
        [MAX_ATOM_NUMBER, 3] layout, with padded slots pushed out by +1e6.
        Each antigen residue is scored by its minimum distance to any CDR-H3
        residue.

        Returns:
            unordered list of ``(residue, chain_name, index_in_chain)``.

        Raises:
            AssertionError: if CDR-H3 is unavailable or the antigen is empty.
        """
        ag_rids, ag_xs, ab_xs = [], [], []
        ag_mask, ab_mask = [], []
        cdrh3 = self.get_cdr('H3')
        assert cdrh3 is not None, 'CDR-H3 not available, cannot calculate epitope'
        for _type, protein in zip(['ag', 'ab'], [self.antigen, [('A', cdrh3)]]):
            is_ag = _type == 'ag'
            rids = []
            if is_ag:
                xs, masks = ag_xs, ag_mask
            else:
                xs, masks = ab_xs, ab_mask
            for chain_name, chain in protein:
                for i, residue in enumerate(chain):
                    bb_coord = residue.get_backbone_coord_map()
                    sc_coord = residue.get_sidechain_coord_map()
                    coord = {}
                    coord.update(bb_coord)
                    coord.update(sc_coord)
                    num_pad = VOCAB.MAX_ATOM_NUMBER - len(coord)
                    x = [coord[key] for key in coord] + [[0, 0, 0] for _ in range(num_pad)]
                    mask = [1 for _ in coord] + [0 for _ in range(num_pad)]
                    rids.append((chain_name, i))
                    xs.append(x)
                    masks.append(mask)
            if is_ag:
                ag_rids = rids
        assert len(ag_xs) != 0, 'No antigen structure!'
        # calculate distance
        ag_xs, ab_xs = np.array(ag_xs), np.array(ab_xs)  # [Nag/ab, M, 3], M == MAX_ATOM_NUM
        ag_mask, ab_mask = np.array(ag_mask).astype('bool'), np.array(ab_mask).astype('bool')  # [Nag/ab, M]
        dist = np.linalg.norm(ag_xs[:, None] - ab_xs[None, :], axis=-1)  # [Nag, Nab, M]
        dist = dist + np.logical_not(ag_mask[:, None] * ab_mask[None, :]) * 1e6  # [Nag, Nab, M]
        min_dists = np.min(np.min(dist, axis=-1), axis=-1)  # [ag_len]
        topk = min(len(min_dists), self.num_interface_residues)
        ind = np.argpartition(-min_dists, -topk)[-topk:]
        epitope = []
        for idx in ind:
            chain_name, i = ag_rids[idx]
            residue = self.antigen.peptides[chain_name].get_residue(i)
            epitope.append((residue, chain_name, i))
        return epitope

    def get_id(self) -> str:
        """Return the PDB id of the antibody."""
        return self.antibody.pdb_id

    def get_antigen(self) -> Protein:
        """Return a deep copy of the antigen Protein."""
        return deepcopy(self.antigen)

    def get_epitope(self, cdrh3_pos=None) -> List[Tuple[Residue, str, int]]:
        """Return (a deep copy of) the epitope.

        If ``cdrh3_pos`` ``(start, end)`` is given, the epitope is recomputed
        for that CDR-H3 span without changing the cached one. Otherwise the
        cached epitope is returned, computing it on first use.
        """
        if cdrh3_pos is not None:
            backup = self.cdr_pos
            self.cdr_pos = {'CDR-H3': [cdrh3_pos[0], cdrh3_pos[1]]}
            epitope = self._cal_epitope()
            self.cdr_pos = backup
            return deepcopy(epitope)
        if self.epitope is None:
            self.epitope = self._cal_epitope()
        return deepcopy(self.epitope)

    def get_interacting_residues(self, dist_cutoff=5) -> List[Tuple[Residue, str, int]]:
        """Return the antigen residues closest to the whole antibody by CA distance.

        Picks up to ``num_interface_residues`` antigen residues with the smallest
        minimum CA-CA distance to any antibody residue. Residues without a CA
        are skipped. The output format matches ``get_epitope``:
        ``(residue, chain_name, index)``, unordered.

        Note: ``dist_cutoff`` is currently unused; it is kept for interface
        compatibility.

        Raises:
            AssertionError: if the antigen or antibody has no CA coordinates.
        """
        ag_rids, ag_xs, ab_xs = [], [], []
        for chain_name in self.antigen.get_chain_names():
            chain = self.antigen.get_chain(chain_name)
            for i in range(len(chain)):
                try:
                    x = chain.get_ca_pos(i)
                except KeyError:  # CA position is missing
                    continue
                ag_rids.append((chain_name, i))
                ag_xs.append(x)
        for chain_name in self.antibody.get_chain_names():
            chain = self.antibody.get_chain(chain_name)
            for i in range(len(chain)):
                try:
                    x = chain.get_ca_pos(i)
                except KeyError:
                    continue
                ab_xs.append(x)
        assert len(ag_xs) != 0, 'No antigen structure!'
        assert len(ab_xs) != 0, 'No antibody structure!'
        # calculate distance
        ag_xs, ab_xs = np.array(ag_xs), np.array(ab_xs)
        dist = np.linalg.norm(ag_xs[:, None, :] - ab_xs[None, :, :], axis=-1)
        min_dists = np.min(dist, axis=1)  # [ag_len]
        topk = min(len(min_dists), self.num_interface_residues)
        ind = np.argpartition(-min_dists, -topk)[-topk:]
        epitope = []
        for idx in ind:
            chain_name, i = ag_rids[idx]
            residue = self.antigen.peptides[chain_name].get_residue(i)
            epitope.append((residue, chain_name, i))
        return epitope  # previously a bare `return` discarded the result

    def get_heavy_chain(self) -> Peptide:
        """Return a copy of the heavy chain."""
        return self.antibody.get_chain(self.heavy_chain)

    def get_light_chain(self) -> Peptide:
        """Return a copy of the light chain, or None when it was not loaded (nanobodies)."""
        return self.antibody.get_chain(self.light_chain)

    def get_framework(self, fr):  # H/L + FR + 1/2/3/4
        """Return the framework segment ``fr`` (e.g. 'HFR1') as a Peptide span."""
        seg_id = int(fr[-1])
        chain = self.get_heavy_chain() if fr[0] == 'H' else self.get_light_chain()
        begin, end = -1, -1
        if seg_id == 1:
            begin, end = 0, self.get_cdr_pos(fr[0] + str(seg_id))[0]
        elif seg_id == 4:
            begin, end = self.get_cdr_pos(fr[0] + '3')[-1] + 1, len(chain)
        else:
            begin = self.get_cdr_pos(fr[0] + str(seg_id - 1))[-1] + 1
            end = self.get_cdr_pos(fr[0] + str(seg_id))[0]
        return chain.get_span(begin, end)

    def get_cdr_pos(self, cdr='H3'):  # H/L + 1/2/3, return [begin, end] position
        """Return the inclusive ``(begin, end)`` of CDR ``cdr``, or None if unknown."""
        cdr = f'CDR-{cdr}'.upper()
        # cdr_pos is None when the complex was built with skip_validity_check
        if self.cdr_pos is None or cdr not in self.cdr_pos:
            return None
        return self.cdr_pos[cdr]

    def get_cdr(self, cdr='H3'):
        """Return CDR ``cdr`` as a Peptide span, or None if its position is unknown."""
        cdr = cdr.upper()
        pos = self.get_cdr_pos(cdr)
        if pos is None:
            return None
        chain = self.get_heavy_chain() if 'H' in cdr else self.get_light_chain()
        return chain.get_span(pos[0], pos[1] + 1)

    def to_pdb(self, path, atoms=None):
        """Write antigen and antibody chains into a single PDB file at ``path``."""
        peptides = {}
        for name in self.antigen.get_chain_names():
            peptides[name] = self.antigen.get_chain(name)
        for name in self.antibody.get_chain_names():
            peptides[name] = self.antibody.get_chain(name)
        protein = Protein(self.get_id(), peptides)
        protein.to_pdb(path, atoms)

    def __str__(self):
        pdb_info = f'PDB ID: {self.pdb_id}'
        heavy = self.get_heavy_chain()
        light = self.get_light_chain()
        antibody_info = f'Antibody H-{self.heavy_chain} ({len(heavy)})'
        if light is not None:  # absent for nanobodies; len(None) used to crash here
            antibody_info += f', L-{self.light_chain} ({len(light)})'
        antigen_info = f'Antigen Chains: {[(ag, len(self.antigen.get_chain(ag))) for ag in self.antigen.get_chain_names()]}'
        cdr_info = f'CDRs: \n'
        for name in (self.cdr_pos or {}):
            chain = heavy if 'H' in name else light
            start, end = self.cdr_pos[name]
            cdr_info += f'\t{name}: [{start}, {end}], {chain.seq[start:end + 1]}\n'
        epitope_info = f'Epitope: \n'
        residue_map = {}
        for _, chain_name, i in self.get_epitope():
            if chain_name not in residue_map:
                residue_map[chain_name] = []
            residue_map[chain_name].append(i)
        for chain_name in residue_map:
            epitope_info += f'\t{chain_name}: {sorted(residue_map[chain_name])}\n'

        sep = '\n' + '=' * 20 + '\n'
        return sep + pdb_info + '\n' + antibody_info + '\n' + cdr_info + '\n' + antigen_info + '\n' + epitope_info + sep


def merge_to_one_chain(protein: Protein):
    """Concatenate all chains of ``protein`` into a single chain named 'A'.

    Chains are visited in sorted name order. Each residue's ``id`` is reset to
    ``(running_index, ' ')``. The chains come from ``get_chain`` (deep copies),
    so ``protein`` itself is not modified.
    """
    merged = []
    for name in sorted(protein.get_chain_names()):
        for residue in protein.get_chain(name).residues:
            residue.id = (len(merged), ' ')
            merged.append(residue)
    return Protein(protein.get_id(), {'A': Peptide('A', merged)})


def fetch_from_pdb(identifier, timeout=30):
    """Download RCSB entry metadata plus the PDB-format coordinate file.

    Args:
        identifier: PDB id, e.g. '1FBI'.
        timeout: seconds to wait on each HTTP request. Without it, ``requests``
            can block indefinitely.

    Returns:
        The entry metadata (parsed JSON dict) with an extra ``'pdb'`` key
        holding the PDB file text. Returns None if either request does not
        return HTTP 200. Entries without a .pdb download now give None instead
        of storing the error page as coordinates.

    Raises:
        requests.exceptions.RequestException: on network errors / timeouts.
    """
    meta = requests.get(f'https://data.rcsb.org/rest/v1/core/entry/{identifier}', timeout=timeout)
    if meta.status_code != 200:
        return None
    pdb_file = requests.get(f'https://files.rcsb.org/download/{identifier}.pdb', timeout=timeout)
    if pdb_file.status_code != 200:
        return None
    data = meta.json()
    data['pdb'] = pdb_file.text
    return data


# Shared amino-acid vocabulary used throughout this file: symbol/index
# conversion, special begin tokens such as BOA / BOH, and the atom layout
# (backbone_atoms, MAX_ATOM_NUMBER). It must be defined before
# ConserveTemplateGenerator, whose construct_template default argument reads
# VOCAB.MAX_ATOM_NUMBER at definition time.
VOCAB = AminoAcidVocab()


In [6]:



# from https://github.com/charnley/rmsd/blob/master/rmsd/calculate_rmsd.py
def kabsch_rotation(P, Q):
    """
    Optimal rotation that superimposes point set ``P`` onto ``Q`` (Kabsch algorithm).

    Both inputs must already be centred on their centroids; the caller does
    that translation step. The returned matrix ``U`` is applied on the right:
    ``P @ U`` approximates ``Q``. ``U`` is always a proper rotation: when the
    SVD solution would be a reflection, the last singular direction is flipped.
    See http://en.wikipedia.org/wiki/Kabsch_algorithm

    Parameters
    ----------
    P : array
        (N, D) matrix of N paired points in D dimensions.
    Q : array
        (N, D) matrix of N paired points in D dimensions.

    Returns
    -------
    U : ndarray
        Rotation matrix (D, D).
    """
    covariance = np.transpose(P) @ Q  # (D, D)
    left, _, right = np.linalg.svd(covariance)
    # a negative determinant product means the unconstrained optimum is a reflection
    if np.linalg.det(left) * np.linalg.det(right) < 0.0:
        left[:, -1] = -left[:, -1]
    return left @ right


# have been validated with kabsch from RefineGNN
def kabsch(a, b):
    """Find the rigid transform that best maps point set ``a`` onto ``b``.

    a, b: paired point sets of shape [N, 3] (anything ``np.array`` accepts).

    Returns:
        ``(a_aligned, rotation, t)`` with ``a_aligned = a @ rotation + t``.
    """
    src, dst = np.array(a), np.array(b)
    src_centroid = src.mean(axis=0)
    dst_centroid = dst.mean(axis=0)
    rotation = kabsch_rotation(src - src_centroid, dst - dst_centroid)
    # translation that maps the rotated source centroid onto the target centroid
    t = dst_centroid - src_centroid @ rotation
    return src @ rotation + t, rotation, t
    

# a: [N, 3], b: [N, 3]
def compute_rmsd(a, b, aligned=False):  # amino acids level rmsd
    """Root-mean-square deviation between paired point sets ``a`` and ``b``.

    Args:
        a, b: array-like of shape [N, 3]. Lists are accepted; they previously
            crashed on ``a.shape`` or on list subtraction.
        aligned: if True, ``a`` is assumed to be already superimposed on ``b``.
            Otherwise ``a`` is first rigidly aligned onto ``b`` with ``kabsch``.

    Returns:
        The RMSD as a Python float.
    """
    a, b = np.asarray(a), np.asarray(b)
    a_aligned = a if aligned else kabsch(a, b)[0]
    sq_dist = np.sum((a_aligned - b) ** 2, axis=-1)
    return float(np.sqrt(sq_dist.sum() / a.shape[0]))


def kabsch_torch(A, B, requires_grad=False):
    """Rigidly align point cloud ``A`` onto ``B`` with the Kabsch algorithm.

    See: https://en.wikipedia.org/wiki/Kabsch_algorithm. Registration is done
    in the zero-centred frame and then translated back. Works for any
    dimension D, including 2-D and 3-D. The ``requires_grad`` branch
    previously hard-coded D = 3.

    Args:
        A: (N, D) tensor, point cloud to align (source).
        B: (N, D) tensor, reference point cloud (target), rows paired with ``A``.
        requires_grad: if True, the covariance matrix is perturbed on its
            diagonal, up to 10 times, while its SVD has tiny or near-repeated
            singular values. This looks for a numerically stable solution.

    Returns:
        (A_aligned, R, t): ``A_aligned = (R @ A.T).T + t``, where ``R`` is the
        (D, D) proper rotation (det +1) and ``t`` the (D,) translation.

    Raises:
        RuntimeError: if ``requires_grad`` and the SVD stays degenerate after
            10 perturbations.

    Example:
        >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]])
        >>> R0 = torch.tensor([[0., -1.], [1., 0.]])
        >>> B = A @ R0.T + torch.tensor([3., 3.])
        >>> A_aligned, R, t = kabsch_torch(A, B)
        >>> bool(torch.allclose(A_aligned, B, atol=1e-5))
        True
    """
    a_mean = A.mean(axis=0)
    b_mean = B.mean(axis=0)
    A_c = A - a_mean
    B_c = B - b_mean
    # Covariance matrix, [D, D]
    H = A_c.T.mm(B_c)
    dim = H.shape[0]
    if requires_grad:  # try more times to find a stable solution
        assert not torch.isnan(H).any()
        eye = torch.eye(dim, device=H.device, dtype=H.dtype)
        U, S, Vt = torch.linalg.svd(H)
        num_it = 0
        # Degenerate when a singular value is ~0 or two are (nearly) equal.
        # The added identity keeps the diagonal of the pairwise-gap matrix away from 0.
        while torch.min(S) < 1e-3 or torch.min(torch.abs((S**2).view(1, dim) - (S**2).view(dim, 1) + eye)) < 1e-2:
            H = H + torch.rand(dim, dim, device=H.device, dtype=H.dtype) * eye
            U, S, Vt = torch.linalg.svd(H)
            num_it += 1

            if num_it > 10:
                raise RuntimeError('SVD consistently numerically unstable! Exitting ... ')
    else:
        U, S, Vt = torch.linalg.svd(H)
    V = Vt.T
    # flip the last singular direction when the optimum would be a reflection
    d = (torch.linalg.det(U) * torch.linalg.det(V)) < 0.0
    if d:
        SS = torch.diag(torch.tensor([1. for _ in range(len(U) - 1)] + [-1.], device=U.device, dtype=U.dtype))
        U = U @ SS
    # Rotation matrix
    R = V.mm(U.T)
    # Translation vector
    t = b_mean[None, :] - R.mm(a_mean[None, :].T).T
    t = (t.T).squeeze()
    return R.mm(A.T).T + t, R, t


def batch_kabsch_torch(A, B):
    '''
    Batched Kabsch alignment of each point cloud in ``A`` onto ``B``.

    A: [B, N, D] source point clouds
    B: [B, N, D] target point clouds (rows paired with A)

    Returns (A_aligned [B, N, D], R [B, D, D], t [B, 1, D]), where
    A_aligned = (R @ A^T)^T + t and every R is a proper rotation (det +1).
    '''
    a_mean = A.mean(dim=1, keepdim=True)
    b_mean = B.mean(dim=1, keepdim=True)
    A_c = A - a_mean
    B_c = B - b_mean
    # Covariance matrix
    H = torch.bmm(A_c.transpose(1, 2), B_c)  # [B, D, D]
    U, S, Vt = torch.linalg.svd(H)  # [B, D, D]
    V = Vt.transpose(1, 2)
    dim = H.shape[-1]
    # Per-sample reflection fix. The correction matrices must be [D, D]; they
    # were previously sized by len(U), which is the batch size, so any batch
    # size other than 3 made bmm fail.
    d = ((torch.linalg.det(U) * torch.linalg.det(V)) < 0.0).long()  # [B]
    nSS = torch.eye(dim, device=U.device, dtype=U.dtype)
    flip = torch.ones(dim, device=U.device, dtype=U.dtype)
    flip[-1] = -1.
    SS = torch.diag(flip)
    bSS = torch.stack([nSS, SS], dim=0)[d]  # [B, D, D]
    U = torch.bmm(U, bSS)
    # Rotation matrix
    R = torch.bmm(V, U.transpose(1, 2))  # [B, D, D]
    # Translation vector
    t = b_mean - torch.bmm(R, a_mean.transpose(1, 2)).transpose(1, 2)
    A_aligned = torch.bmm(R, A.transpose(1, 2)).transpose(1, 2) + t
    return A_aligned, R, t

In [7]:


def singleton(cls):
    """Class decorator: replace ``cls`` with a factory that always returns one instance.

    The instance is created on the first call, using that call's arguments.
    Arguments of later calls are ignored.
    """
    cache = {}

    def get_instance(*args, **kwargs):
        if cls in cache:
            return cache[cls]
        cache[cls] = cls(*args, **kwargs)
        return cache[cls]
    return get_instance

In [8]:


@singleton
class ConserveTemplateGenerator:
    """Build initial antibody coordinates from a conserved-residue template.

    ``template_map`` is loaded from a JSON file (default ``data/template.json``)
    and indexed as ``template_map[chain_type][str(residue_number)]``, giving
    the backbone coordinates (N, CA, C, O) of that position. ``chain_type`` is
    'H' or 'L'. The class is wrapped by ``singleton``, so ``template_map`` is
    shared by every caller and must never be mutated.
    """

    def __init__(self, json_path=None):
        if json_path is None:
            folder = 'data/'
            json_path = os.path.join(folder, 'template.json')
        with open(json_path, 'r') as fin:
            self.template_map = json.load(fin)

    def _chain_template(self, cplx: AgAbComplex, poses, n_channel, heavy=True):
        """Per-residue template coordinates for the heavy (or light) chain of ``cplx``.

        Residues whose number appears in ``poses`` get the template backbone,
        padded to ``n_channel`` atoms with copies of CA. The other residues are
        filled by linear interpolation between template hits and by linear
        extrapolation past the first / last hit.

        Returns:
            (X, hit_index): one [n_channel, 3] coordinate block per residue, and
            the chain indices of the residues that had a template entry.

        Raises:
            ValueError: if no residue of the chain has a template entry.
        """
        chain = cplx.get_heavy_chain() if heavy else cplx.get_light_chain()
        chain_name = 'H' if heavy else 'L'
        hit_map = { pos: False for pos in poses }
        X, hit_index = [], []
        for i, residue in enumerate(chain):
            pos, _ = residue.get_id()
            pos = str(pos)
            if pos in hit_map:
                # Copy first. Extending the stored list would permanently grow the
                # shared template entry and break later calls with another n_channel.
                coord = list(self.template_map[chain_name][pos])  # N, CA, C, O
                ca, num_sc = coord[1], n_channel - len(coord)
                coord.extend([ca for _ in range(num_sc)])
                hit_index.append(i)
                coord = np.array(coord)
            else:
                coord = [[0, 0, 0] for _ in range(n_channel)]
            X.append(coord)
        if not hit_index:
            raise ValueError(f'No residue of chain {chain_name} matches the conserved template')
        # uniform distribution between residues and extension at two ends
        for left_i, right_i in zip(hit_index[:-1], hit_index[1:]):
            left, right = X[left_i], X[right_i]
            span, index_span = right - left, right_i - left_i
            span = span / index_span
            for i in range(left_i + 1, right_i):
                X[i] = X[i - 1] + span
        # start and end
        if hit_index[0] != 0:
            left_i = hit_index[0]
            span = X[left_i] - X[left_i + 1]
            for i in reversed(range(0, left_i)):
                X[i] = X[i + 1] + span
        if hit_index[-1] != len(X) - 1:
            right_i = hit_index[-1]
            span = X[right_i] - X[right_i - 1]
            for i in range(right_i + 1, len(X)):
                X[i] = X[i - 1] + span
        return X, hit_index

    def construct_template(self, cplx: AgAbComplex, n_channel=VOCAB.MAX_ATOM_NUMBER, align=True):
        """Return heavy-chain template coordinates for ``cplx``, shape [N, n_channel, 3].

        If ``align`` is True, the template is rigidly superimposed (Kabsch) onto
        the true backbone atoms of the residues that have template entries.
        """
        hc, hc_hit = self._chain_template(cplx, self.template_map['H'], n_channel, heavy=True)
        template = np.array(hc)  # [N, n_channel, 3]
        if align:
            # align (will be dropped in the future)
            chain = cplx.get_heavy_chain()
            true_X_bb, temp_X_bb = [], []
            for i in hc_hit:
                # Pair residue i's true backbone with residue i's own template row.
                # zip(hit, temp) would pair hit[k] with row k, i.e. the wrong residue.
                residue_temp = hc[i]
                bb = chain.get_residue(i).get_backbone_coord_map()
                for ai, atom in enumerate(VOCAB.backbone_atoms):
                    if atom not in bb:
                        continue
                    true_X_bb.append(bb[atom])
                    temp_X_bb.append(residue_temp[ai])
            true_X_bb, temp_X_bb = np.array(true_X_bb), np.array(temp_X_bb)
            _, Q, t = kabsch(temp_X_bb, true_X_bb)
            template = np.dot(template, Q) + t
        return template


In [9]:




def _generate_chain_data(residues, start):
    """Convert residues into padded per-residue arrays for one chain.

    A "global node" carrying the begin token ``start`` is prepended. Its
    coordinates are set to the mean of all other node coordinates, or left at
    the origin when ``residues`` is empty. Within each residue, atom slots
    without coordinates are filled with the CA coordinate (or the first
    available atom if CA is missing) and are 0 in ``xloss_mask``.

    Args:
        residues: sequence of Residue objects.
        start: begin-of-chain symbol for the global node (e.g. ``VOCAB.BOA`` for
            the antigen, ``VOCAB.BOH`` for the heavy chain). With
            ``VOCAB.BOA``, all ``residue_pos`` entries are zeroed.

    Returns:
        dict with
            'X': np.ndarray [n + 1, VOCAB.MAX_ATOM_NUMBER, 3] coordinates,
            'S': list of n + 1 token indices,
            'residue_pos': list of n + 1 residue numbers (0 for the global node),
            'xloss_mask': list of n + 1 per-atom 0/1 masks (1 = coordinate present).

    Raises:
        ValueError: if a residue has no atom coordinates at all. Previously such
            a residue silently reused the previous residue's anchor, or raised
            NameError when it was the first residue.
    """
    backbone_atoms = VOCAB.backbone_atoms
    n_atom = VOCAB.MAX_ATOM_NUMBER
    # Coords, Sequence, residue positions, mask for loss calculation (exclude missing coordinates)
    # global node first; its coordinates are set to the chain center below
    X = [[[0, 0, 0] for _ in range(n_atom)]]
    S = [VOCAB.symbol_to_idx(start)]
    res_pos = [0]
    xloss_mask = [[0 for _ in range(n_atom)]]
    # other nodes
    for residue in residues:
        residue_xloss_mask = [0 for _ in range(n_atom)]
        bb_atom_coord = residue.get_backbone_coord_map()
        sc_atom_coord = residue.get_sidechain_coord_map()
        if 'CA' in bb_atom_coord:
            ca_x = bb_atom_coord['CA']
        else:
            # fall back to the first backbone atom, then the first sidechain atom
            ca_x = None
            for coord_map in (bb_atom_coord, sc_atom_coord):
                atom = next(iter(coord_map), None)
                if atom is not None:
                    ca_x = coord_map[atom]
                    print_log(f'no ca, use {atom}', level='DEBUG')
                    break
            if ca_x is None:
                raise ValueError(f'Residue {residue.get_id()} has no atom coordinates')
        x = [ca_x for _ in range(n_atom)]

        # backbone atoms occupy the first len(backbone_atoms) slots, sidechain the rest
        for i, atom in enumerate(backbone_atoms):
            if atom in bb_atom_coord:
                x[i] = bb_atom_coord[atom]
                residue_xloss_mask[i] = 1
        for i, atom in enumerate(residue.sidechain, start=len(backbone_atoms)):
            if atom in sc_atom_coord:
                x[i] = sc_atom_coord[atom]
                residue_xloss_mask[i] = 1

        X.append(x)
        S.append(VOCAB.symbol_to_idx(residue.get_symbol()))
        res_pos.append(residue.get_id()[0])
        xloss_mask.append(residue_xloss_mask)
    X = np.array(X)
    if len(X) > 1:  # empty chain: keep the global node at the origin instead of NaN
        X[0] = np.mean(X[1:].reshape(-1, 3), axis=0)  # set center
    if start == VOCAB.BOA:  # epitope does not have position encoding
        res_pos = [0 for _ in res_pos]
    data = {'X': X, 'S': S, 'residue_pos': res_pos, 'xloss_mask': xloss_mask}
    return data


# use this class to splice the dataset and maintain only one part of it in RAM
# Antibody-Antigen Complex dataset
class E2EDataset(torch.utils.data.Dataset):
    def __init__(self, file_path, save_dir=None, cdr=None, paratope='H3', full_antigen=False, num_entry_per_file=-1, random=False):
        '''
        Build the dataset, reusing preprocessed part files from ``save_dir`` when present.

        file_path: path to the dataset (one JSON object per line, parsed by ``preprocess``)
        save_dir: directory to save the processed data. Defaults to
                  ``<parent>/<name>_processed``, where <name> is the basename of
                  file_path up to its first '.'. <parent> is the directory
                  containing file_path, or file_path itself if it is a directory.
        cdr: which cdr to generate (L1/2/3, H1/2/3) (can be list), None for all including framework
        paratope: which cdr to use as paratope (L1/2/3, H1/2/3) (can be list)
        full_antigen: whether to use the full antigen information
        num_entry_per_file: number of entries in a single file. -1 to save all data into one file
                            (In-memory dataset)
        random: if True, shuffle the within-part access order every time a part
                is loaded (this parameter shadows the module-level ``random`` here)
        '''
        super().__init__()
        self.cdr = cdr
        self.paratope = paratope
        self.full_antigen = full_antigen
        if save_dir is None:
            if not os.path.isdir(file_path):
                save_dir = os.path.split(file_path)[0]
            else:
                save_dir = file_path
            prefix = os.path.split(file_path)[1]
            if '.' in prefix:
                prefix = prefix.split('.')[0]
            save_dir = os.path.join(save_dir, f'{prefix}_processed')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        metainfo_file = os.path.join(save_dir, '_metainfo')
        self.data: List[AgAbComplex] = []  # list of ABComplex

        # try loading preprocessed files
        # NOTE: an existing _metainfo is trusted as-is; it is not checked against
        # file_path, so delete save_dir to force reprocessing after input changes.
        need_process = False
        try:
            with open(metainfo_file, 'r') as fin:
                metainfo = json.load(fin)
                self.num_entry = metainfo['num_entry']
                self.file_names = metainfo['file_names']
                self.file_num_entries = metainfo['file_num_entries']
        except FileNotFoundError:
            print_log('No meta-info file found, start processing', level='INFO')
            need_process = True
        except Exception as e:
            # corrupted / incomplete meta-info: rebuild everything
            print_log(f'Faild to load file {metainfo_file}, error: {e}', level='WARN')
            need_process = True

        if need_process:
            # preprocess
            # preprocess() -> _save_part() appends to these two lists
            self.file_names, self.file_num_entries = [], []
            self.preprocess(file_path, save_dir, num_entry_per_file)
            self.num_entry = sum(self.file_num_entries)

            metainfo = {
                'num_entry': self.num_entry,
                'file_names': self.file_names,
                'file_num_entries': self.file_num_entries
            }
            with open(metainfo_file, 'w') as fout:
                json.dump(metainfo, fout)

        # must be set before _load_part(), which reads self.random
        self.random = random
        # NOTE(review): raises IndexError when no entry could be parsed (no part files)
        self.cur_file_idx, self.cur_idx_range = 0, (0, self.file_num_entries[0])  # left close, right open
        self._load_part()

        # user defined variables
        # identity mapping from dataset index to global entry index (used by __getitem__)
        self.idx_mapping = [i for i in range(self.num_entry)]
        self.mode = '101'  # H/L/Antigen, 1 for include, 0 for exclude

    def _save_part(self, save_dir, num_entry):
        """Pickle the leading entries of ``self.data`` into a new part file.

        The first ``num_entry`` entries (all of them when ``num_entry == -1``)
        are written to ``<save_dir>/part_<k>.pkl``, where k is the number of
        parts already written. The absolute path and entry count are recorded,
        and the saved entries are then dropped from ``self.data``.
        """
        part_path = os.path.join(save_dir, f'part_{len(self.file_names)}.pkl')
        print_log(f'Saving {part_path} ...')
        part_path = os.path.abspath(part_path)
        n_saved = len(self.data) if num_entry == -1 else min(num_entry, len(self.data))
        with open(part_path, 'wb') as fout:
            pickle.dump(self.data[:n_saved], fout)
        self.file_names.append(part_path)
        self.file_num_entries.append(n_saved)
        self.data = self.data[n_saved:]

    def _load_part(self):
        """Load part ``self.cur_file_idx`` into ``self.data`` and reset the access order.

        The access order is a shuffled permutation when ``self.random`` is set,
        otherwise the identity. The part files are pickles written by
        ``_save_part``; only load caches you produced yourself, since unpickling
        can execute arbitrary code.
        """
        part_path = self.file_names[self.cur_file_idx]
        print_log(f'Loading preprocessed file {part_path}, {self.cur_file_idx + 1}/{len(self.file_names)}')
        with open(part_path, 'rb') as fin:
            del self.data  # release the previous part before unpickling the next
            self.data = pickle.load(fin)
        self.access_idx = list(range(len(self.data)))
        if self.random:
            np.random.shuffle(self.access_idx)

    def _check_load_part(self, idx):
        if idx < self.cur_idx_range[0]:
            while idx < self.cur_idx_range[0]:
                end = self.cur_idx_range[0]
                self.cur_file_idx -= 1
                start = end - self.file_num_entries[self.cur_file_idx]
                self.cur_idx_range = (start, end)
            self._load_part()
        elif idx >= self.cur_idx_range[1]:
            while idx >= self.cur_idx_range[1]:
                start = self.cur_idx_range[1]
                self.cur_file_idx += 1
                end = start + self.file_num_entries[self.cur_file_idx]
                self.cur_idx_range = (start, end)
            self._load_part()
        idx = self.access_idx[idx - self.cur_idx_range[0]]
        return idx
     
    def __len__(self):
        return self.num_entry

    ########### load data from file_path and add to self.data ##########
    def preprocess(self, file_path, save_dir, num_entry_per_file):
        '''
        Parse every JSON line of ``file_path`` into an AgAbComplex appended to self.data.

        Each non-blank line is a JSON object with keys 'pdb_data_path',
        'heavy_chain' and 'antigen_chains'. 'light_chain' is optional (nanobody
        data may omit it; None is passed instead). 'pdb' is optional and only
        used in log messages.

        Whenever self.data reaches num_entry_per_file entries (if > 0), those
        entries are flushed to disk with self._save_part(save_dir, num_entry_per_file).
        Any remainder is flushed at the end. Entries whose parsing raises
        AssertionError are logged and skipped.
        '''
        with open(file_path, 'r') as fin:
            # skip blank lines, which would make json.loads fail
            lines = [line for line in fin.read().strip().split('\n') if line.strip()]
        for line in tqdm(lines):
            item = json.loads(line)
            try:
                cplx = AgAbComplex.from_pdb(
                    item['pdb_data_path'], item['heavy_chain'], item.get('light_chain'),
                    item['antigen_chains'])
            except AssertionError as e:
                print_log(e, level='ERROR')
                # 'pdb' may be absent; a KeyError here would abort the whole run
                print_log(f'parse {item.get("pdb", item["pdb_data_path"])} pdb failed, skip', level='ERROR')
                continue

            self.data.append(cplx)
            if num_entry_per_file > 0 and len(self.data) >= num_entry_per_file:
                self._save_part(save_dir, num_entry_per_file)
        if len(self.data):
            self._save_part(save_dir, num_entry_per_file)

    ########## override get item ##########
    def __getitem__(self, idx):
        '''
        Build the model input for one complex: antigen (or epitope) residues
        followed by the heavy chain. The light chain is intentionally not
        included (nanobody setting), so only CDR names starting with 'H' can be
        used in self.cdr / self.paratope.

        an example of the returned data
        {
            'X': [n, n_channel, 3],
            'S': [n],
            'cmask': [n],
            'smask': [n],
            'paratope_mask': [n],
            'xloss_mask': [n, n_channel],
            'template': [n, n_channel, 3]
        }
        (other keys produced by _generate_chain_data, presumably including the
        'residue_pos' read by collate_fn, are passed through as well)

        :raises ValueError: if self.cdr or self.paratope names a CDR that is not
            on the heavy chain (no light-chain positions exist in this sample)
        '''
        idx = self._check_load_part(self.idx_mapping[idx])
        item = self.data[idx]

        # antigen residues: the whole antigen, or only the epitope
        if self.full_antigen:
            ag = item.get_antigen()
            ag_residues = []
            for chain_name in ag.get_chain_names():
                chain = ag.get_chain(chain_name)
                ag_residues.extend(chain.get_residue(i) for i in range(len(chain)))
        else:
            ag_residues = [residue for residue, _chain, _i in item.get_epitope()]
        ag_data = _generate_chain_data(ag_residues, VOCAB.BOA)

        # heavy chain only; no light chain for nanobodies
        hc = item.get_heavy_chain()
        hc_residues = [hc.get_residue(i) for i in range(len(hc))]
        hc_data = _generate_chain_data(hc_residues, VOCAB.BOH)

        data = {key: np.concatenate([ag_data[key], hc_data[key]], axis=0)
                for key in hc_data}

        n_ag, n_hc = len(ag_data['S']), len(hc_data['S'])
        # the first heavy-chain position is its global node, so CDR residue
        # positions start right after it
        hc_offset = n_ag + 1

        def _cdr_mask(cdrs):
            '''1 on every position covered by the given heavy-chain CDR(s), else 0.'''
            mask = [0] * (n_ag + n_hc)
            for cdr in ([cdrs] if isinstance(cdrs, str) else cdrs):
                if cdr[0] != 'H':
                    raise ValueError(
                        f'CDR {cdr} is not on the heavy chain; light-chain data '
                        'is not included in this nanobody dataset')
                cdr_range = item.get_cdr_pos(cdr)
                for pos in range(hc_offset + cdr_range[0], hc_offset + cdr_range[1] + 1):
                    mask[pos] = 1
            return mask

        # smask (sequence) and cmask (coordinates): 0 for fixed, 1 for generate.
        # Coordinates of the antigen and of the heavy-chain global node stay fixed.
        cmask = [0] * n_ag + [0] + [1] * (n_hc - 1)
        smask = cmask if self.cdr is None else _cdr_mask(self.cdr)
        data['cmask'], data['smask'] = cmask, smask
        data['paratope_mask'] = _cdr_mask(self.paratope)

        data['template'] = ConserveTemplateGenerator().construct_template(item, align=False)
        return data

    @classmethod
    def collate_fn(cls, batch):
        '''
        Merge a list of samples into one flat batch.

        Every listed field is converted to a tensor of a fixed dtype and the
        samples are concatenated along dim 0. ``lengths`` records how many
        entries of 'S' each sample contributed, so the batch can be split again.
        '''
        field_dtypes = {
            'X': torch.float,
            'S': torch.long,
            'smask': torch.bool,
            'cmask': torch.bool,
            'paratope_mask': torch.bool,
            'residue_pos': torch.long,
            'template': torch.float,
            'xloss_mask': torch.bool,
        }
        res = {
            field: torch.cat([torch.tensor(sample[field], dtype=dtype) for sample in batch], dim=0)
            for field, dtype in field_dtypes.items()
        }
        res['lengths'] = torch.tensor([len(sample['S']) for sample in batch], dtype=torch.long)
        return res


In [10]:

class TrainConfig:
    '''
    Plain container for trainer hyper-parameters.

    Every named argument becomes an attribute, and so does every extra keyword
    argument (more can be attached later with add_parameter). save_topk=-1
    keeps all checkpoints; grad_clip=None disables gradient clipping.
    '''

    def __init__(self, save_dir, lr, max_epoch, warmup=0,
                 metric_min_better=True, patience=3,
                 grad_clip=None, save_topk=-1,  # -1 for save all
                 **kwargs):
        settings = {
            'save_dir': save_dir,
            'lr': lr,
            'max_epoch': max_epoch,
            'warmup': warmup,
            'metric_min_better': metric_min_better,
            'patience': patience,
            'grad_clip': grad_clip,
            'save_topk': save_topk,
        }
        settings.update(kwargs)
        vars(self).update(settings)

    def add_parameter(self, **kwargs):
        '''Attach (or overwrite) arbitrary attributes.'''
        vars(self).update(kwargs)

    def __str__(self):
        return f'{self.__class__}: {self.__dict__}'


class Trainer:
    '''
    Generic epoch-based training loop.

    Handles optimisation (optional gradient clipping and LR scheduling),
    validation after every epoch, early stopping through a patience counter,
    retention of the top-k checkpoints and TensorBoard logging. Override
    get_optimizer / get_scheduler / train_step / valid_step to customise.

    ``config`` is read for: save_dir, lr, max_epoch, patience, grad_clip,
    save_topk and metric_min_better (e.g. a TrainConfig instance).
    '''

    def __init__(self, model, train_loader, valid_loader, config):
        self.model = model
        self.config = config
        self.optimizer = self.get_optimizer()
        sched_config = self.get_scheduler(self.optimizer)
        if sched_config is None:
            sched_config = {
                'scheduler': None,
                'frequency': None
            }
        self.scheduler = sched_config['scheduler']
        self.sched_freq = sched_config['frequency']
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # distributed training: -1 = not distributed (set for real in train())
        self.local_rank = -1

        # log: every run writes into a fresh save_dir/version_{k} directory
        self.version = self._get_version()
        self.config.save_dir = os.path.join(self.config.save_dir, f'version_{self.version}')
        self.model_dir = os.path.join(self.config.save_dir, 'checkpoint')
        self.writer = None  # initialize right before training (main process only)
        self.writer_buffer = {}  # validation scalars averaged once per epoch

        # training process recording
        self.global_step = 0
        self.valid_global_step = 0
        self.epoch = 0
        self.last_valid_metric = None
        self.topk_ckpt_map = []  # (metric, path) pairs; smaller index means better ckpt
        self.patience = self.config.patience

    @classmethod
    def to_device(cls, data, device):
        '''Recursively move tensors (anything with .to) inside dicts/lists/tuples to device.

        Dicts are updated in place; lists/tuples are rebuilt with the same type.
        '''
        if isinstance(data, dict):
            for key in data:
                data[key] = cls.to_device(data[key], device)
        elif isinstance(data, (list, tuple)):
            res = [cls.to_device(item, device) for item in data]
            data = type(data)(res)
        elif hasattr(data, 'to'):
            data = data.to(device)
        return data

    def _is_main_proc(self):
        '''True when not distributed (-1) or on local rank 0.'''
        return self.local_rank == 0 or self.local_rank == -1

    def _get_version(self):
        '''Return 1 + the largest existing ``version_<n>`` index under save_dir (0 if none).'''
        version, pattern = -1, r'version_(\d+)'
        if os.path.exists(self.config.save_dir):
            for fname in os.listdir(self.config.save_dir):
                ver = re.findall(pattern, fname)
                if len(ver):
                    version = max(int(ver[0]), version)
        return version + 1

    def _train_epoch(self, device):
        '''One pass over train_loader: forward, backward, optional clipping, optimizer step.'''
        if self.train_loader.sampler is not None and self.local_rank != -1:  # distributed
            self.train_loader.sampler.set_epoch(self.epoch)
        t_iter = tqdm(self.train_loader) if self._is_main_proc() else self.train_loader
        for batch in t_iter:
            batch = self.to_device(batch, device)
            loss = self.train_step(batch, self.global_step)
            self.optimizer.zero_grad()
            loss.backward()
            if self.config.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
            self.optimizer.step()
            if hasattr(t_iter, 'set_postfix'):
                t_iter.set_postfix(loss=loss.item(), version=self.version)
            self.global_step += 1
            if self.sched_freq == 'batch':
                self.scheduler.step()
        if self.sched_freq == 'epoch':
            self.scheduler.step()

    def _valid_epoch(self, device):
        '''
        Evaluate on valid_loader and update early-stopping state.

        The mean validation metric is compared with the PREVIOUS epoch's value
        (not the best so far). If it is better, patience is reset and (main
        process only) a checkpoint is saved and registered in the top-k map;
        otherwise patience is decremented. Buffered validation scalars are then
        written to TensorBoard, indexed by epoch.
        '''
        metric_arr = []
        self.model.eval()
        with torch.no_grad():
            t_iter = tqdm(self.valid_loader) if self._is_main_proc() else self.valid_loader
            for batch in t_iter:
                batch = self.to_device(batch, device)
                metric = self.valid_step(batch, self.valid_global_step)
                metric_arr.append(metric.cpu().item())
                self.valid_global_step += 1
        self.model.train()
        # judge
        valid_metric = np.mean(metric_arr)
        if self._metric_better(valid_metric):
            self.patience = self.config.patience
            if self._is_main_proc():
                save_path = os.path.join(self.model_dir, f'epoch{self.epoch}_step{self.global_step}.ckpt')
                # on rank 0 the model is wrapped in DistributedDataParallel: save the inner module
                module_to_save = self.model.module if self.local_rank == 0 else self.model
                torch.save(module_to_save, save_path)
                self._maintain_topk_checkpoint(valid_metric, save_path)
        else:
            self.patience -= 1
        self.last_valid_metric = valid_metric
        # write valid_metric
        for name in self.writer_buffer:
            value = np.mean(self.writer_buffer[name])
            self.log(name, value, self.epoch)
        self.writer_buffer = {}

    def _metric_better(self, new):
        '''Whether ``new`` beats the last validation metric (always True on the first call).'''
        old = self.last_valid_metric
        if old is None:
            return True
        if self.config.metric_min_better:
            return new < old
        return old < new

    def _maintain_topk_checkpoint(self, valid_metric, ckpt_path):
        '''
        Insert (valid_metric, ckpt_path) into the best-first checkpoint list,
        delete checkpoint files beyond save_topk (if save_topk > 0) and rewrite
        model_dir/topk_map.txt.
        '''
        topk = self.config.save_topk
        if self.config.metric_min_better:
            better = lambda a, b: a < b
        else:
            better = lambda a, b: a > b
        insert_pos = len(self.topk_ckpt_map)
        for i, (metric, _) in enumerate(self.topk_ckpt_map):
            if better(valid_metric, metric):
                insert_pos = i
                break
        self.topk_ckpt_map.insert(insert_pos, (valid_metric, ckpt_path))

        # maintain topk
        if topk > 0:
            while len(self.topk_ckpt_map) > topk:
                last_ckpt_path = self.topk_ckpt_map[-1][1]
                os.remove(last_ckpt_path)
                self.topk_ckpt_map.pop()

        # save map
        topk_map_path = os.path.join(self.model_dir, 'topk_map.txt')
        with open(topk_map_path, 'w') as fout:
            for metric, path in self.topk_ckpt_map:
                fout.write(f'{metric}: {path}\n')

    def train(self, device_ids, local_rank):
        '''
        Run up to config.max_epoch epochs, stopping early when patience reaches 0.

        :param device_ids: GPU ids; device_ids[0] == -1 means CPU (when not distributed)
        :param local_rank: -1 for single-process training, otherwise the DDP rank
        '''
        # set local rank
        self.local_rank = local_rank
        # init writer
        if self._is_main_proc():
            self.writer = SummaryWriter(self.config.save_dir)
            os.makedirs(self.model_dir, exist_ok=True)
            with open(os.path.join(self.config.save_dir, 'namespace.json'), 'w') as fout:
                json.dump(self.config.__dict__, fout, indent=2)
        # main device
        main_device_id = local_rank if local_rank != -1 else device_ids[0]
        device = torch.device('cpu' if main_device_id == -1 else f'cuda:{main_device_id}')
        self.model.to(device)
        if local_rank != -1:
            print_log(f'Using data parallel, local rank {local_rank}, all {device_ids}')
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[local_rank], output_device=local_rank
            )
        else:
            print_log(f'training on {device_ids}')
        try:
            for _ in range(self.config.max_epoch):
                if self._is_main_proc():
                    print_log(f'epoch{self.epoch} starts')
                self._train_epoch(device)
                if self._is_main_proc():
                    print_log('validating ...')
                self._valid_epoch(device)
                self.epoch += 1
                if self.patience <= 0:
                    break
        finally:
            # flush pending TensorBoard events even if training stops with an error
            if self.writer is not None:
                self.writer.close()

    def log(self, name, value, step, val=False):
        '''Write a scalar to TensorBoard (main process only); val=True buffers it for the epoch mean.'''
        if self._is_main_proc():
            if isinstance(value, torch.Tensor):
                value = value.cpu().item()
            if val:
                if name not in self.writer_buffer:
                    self.writer_buffer[name] = []
                self.writer_buffer[name].append(value)
            else:
                self.writer.add_scalar(name, value, step)

    ########## Overload these functions below ##########
    # define optimizer
    def get_optimizer(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
        return optimizer

    # scheduler example: linear. Return None if no scheduler is needed.
    def get_scheduler(self, optimizer):
        lam = lambda epoch: 1 / (epoch + 1)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lam)
        return {
            'scheduler': scheduler,
            'frequency': 'epoch' # or batch
        }

    # train step, note that batch should be dict/list/tuple/instance. Objects with .to(device) attribute will be automatically moved to the same device as the model
    def train_step(self, batch, batch_idx):
        '''Forward one training batch; the model is expected to return the scalar loss.'''
        loss = self.model(batch)
        self.log('Loss/train', loss, batch_idx)
        return loss

    # validation step
    def valid_step(self, batch, batch_idx):
        '''Forward one validation batch; the loss is buffered and averaged per epoch.'''
        loss = self.model(batch)
        self.log('Loss/validation', loss, batch_idx, val=True)
        return loss


In [11]:

class AMEGNN(nn.Module):
    '''
    Stack of adaptive multi-channel equivariant layers (AM_E_GCL) on one graph.

    Node features go through linear_in, ``n_layers`` message-passing layers plus
    a final out_layer, then linear_out. Coordinates are updated by every layer.
    '''

    def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_channel, channel_nf,
                 radial_nf, in_edge_nf=0, act_fn=nn.SiLU(), n_layers=4,
                 residual=True, dropout=0.1, dense=False):
        '''
        :param in_node_nf: Number of features for 'h' at the input
        :param hidden_nf: Number of hidden features
        :param out_node_nf: Number of features for 'h' at the output
        :param n_channel: Number of channels of coordinates
        :param channel_nf: size of each channel's attribute embedding
        :param radial_nf: size of the radial edge feature
        :param in_edge_nf: Number of features for the edge features
        :param act_fn: Non-linearity
        :param n_layers: Number of message-passing layers before out_layer
        :param residual: Use residual connections, we recommend not changing this one
        :param dropout: probability of dropout
        :param dense: if dense, then context states will be concatenated for all layers,
                      coordination will be averaged
        '''
        super().__init__()
        self.hidden_nf = hidden_nf
        self.n_layers = n_layers

        # NOTE: submodules are created in the original order so that parameter
        # initialisation and state_dict keys are unchanged.
        self.dropout = nn.Dropout(dropout)
        self.linear_in = nn.Linear(in_node_nf, hidden_nf)

        self.dense = dense
        # dense mode concatenates the states of all n_layers + 1 layers
        readout_in = hidden_nf * (n_layers + 1) if dense else hidden_nf
        self.linear_out = nn.Linear(readout_in, out_node_nf)

        for layer_id in range(n_layers):
            self.add_module(f'gcl_{layer_id}', AM_E_GCL(
                hidden_nf, hidden_nf, hidden_nf, n_channel, channel_nf, radial_nf,
                edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, dropout=dropout
            ))
        # the final layer uses AM_E_GCL's default dropout, as in the original
        self.out_layer = AM_E_GCL(
            hidden_nf, hidden_nf, hidden_nf, n_channel, channel_nf,
            radial_nf, edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual
        )

    def forward(self, h, x, edges, channel_attr, channel_weights, ctx_edge_attr=None):
        '''Return (node outputs after linear_out, coordinates).'''
        h = self.dropout(self.linear_in(h))

        layers = [self._modules[f'gcl_{layer_id}'] for layer_id in range(self.n_layers)]
        layers.append(self.out_layer)

        states, coords = [], []
        for gcl in layers:
            h, x = gcl(h, edges, x, channel_attr, channel_weights, edge_attr=ctx_edge_attr)
            states.append(h)
            coords.append(x)

        if self.dense:
            h = torch.cat(states, dim=-1)
            x = torch.stack(coords).mean(dim=0)
        return self.linear_out(self.dropout(h)), x

'''
Below is the implementation of the adaptive multi-channel message passing mechanism
'''

@singleton
class RollerPooling(nn.Module):
    '''
    Adaptive average pooling for the adaptive scaler

    Precomputes one [n_channel, n_channel] averaging matrix per possible target
    size; forward() selects the matrix for each edge's target size and applies
    it to that edge's n_channel values with a batched matrix multiply.

    NOTE(review): the class is wrapped by ``singleton`` (defined elsewhere);
    presumably the first instance is reused, so a different ``n_channel`` passed
    later may be ignored — confirm.
    '''
    def __init__(self, n_channel) -> None:
        super().__init__()
        self.n_channel = n_channel
        with torch.no_grad():
            pool_matrix = []
            ones = torch.ones((n_channel, n_channel), dtype=torch.float)
            for i in range(n_channel):
                # i start from 0 instead of 1 !!! (less readable but higher implementation efficiency)
                # pool_matrix[i] serves target size i + 1 (forward indexes with target_size - 1)
                window_size = n_channel - i
                # band matrix: row r has ones in columns r .. r + window_size - 1 (clipped at n_channel)
                mat = torch.triu(ones) - torch.triu(ones, diagonal=window_size)
                pool_matrix.append(mat / window_size)
            # plain attribute (not a registered buffer); moved to the input's device/dtype in forward()
            self.pool_matrix = torch.stack(pool_matrix)  # [n_channel, n_channel, n_channel]

    def forward(self, hidden, target_size):
        '''
        :param hidden: [n_edges, n_channel]
        :param target_size: [n_edges], integer target sizes; expected in 1..n_channel
                            (a 0 would index -1 and silently pick the last matrix)
        :return: [n_edges, n_channel, 1]
        '''
        pool_mat = self.pool_matrix.to(hidden.device).type(hidden.dtype)
        pool_mat = pool_mat[target_size - 1]  # [n_edges, n_channel, n_channel]
        hidden = hidden.unsqueeze(-1)  # [n_edges, n_channel, 1]
        return torch.bmm(pool_mat, hidden)  # [n_edges, n_channel, 1]


class AM_E_GCL(nn.Module):
    '''
    Adaptive Multi-Channel E(n) Equivariant Convolutional Layer

    One forward pass:
      1. builds channel-aware radial edge features (coord2radial),
      2. computes edge messages from both end-point states (edge_model),
      3. updates every coordinate channel from the messages (coord_model),
      4. updates the node states from the aggregated messages (node_model).
    Messages are aggregated onto ``row`` of edge_index.
    '''

    def __init__(self, input_nf, output_nf, hidden_nf, n_channel, channel_nf, radial_nf,
                 edges_in_d=0, node_attr_d=0, act_fn=nn.SiLU(), residual=True, attention=False,
                 normalize=False, coords_agg='mean', tanh=False, dropout=0.1):
        '''
        :param input_nf: size of the input node state h
        :param output_nf: size of the output node state (must equal input_nf when residual)
        :param hidden_nf: size of edge messages / MLP hidden layers
        :param n_channel: number of coordinate channels per node
        :param channel_nf: size of each channel's attribute embedding
        :param radial_nf: size of the radial edge feature
        :param edges_in_d: size of optional extra edge features
        :param node_attr_d: size of optional extra node attributes
        :param attention: gate edge messages with a sigmoid attention weight
        :param normalize: stored only; not used by this layer
        :param coords_agg: 'mean' or 'sum' aggregation of coordinate updates
        :param tanh: squash per-channel coordinate weights with tanh
        :param dropout: dropout probability
        '''
        super(AM_E_GCL, self).__init__()

        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8

        self.dropout = nn.Dropout(dropout)

        input_edge = input_nf * 2  # source and target states are concatenated
        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + radial_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)
        self.radial_linear = nn.Linear(channel_nf ** 2, radial_nf)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + node_attr_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # Last coordinate layer: tiny xavier gain so initial coordinate updates are small.
        # It is created before the first coord_mlp Linear (original order, same init).
        layer = nn.Linear(hidden_nf, n_channel, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = [nn.Linear(hidden_nf, hidden_nf), act_fn, layer]
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr):
        '''
        :param source: [n_edge, input_size]
        :param target: [n_edge, input_size]
        :param radial: [n_edge, d, d] or already flat [n_edge, d_flat]
        :param edge_attr: [n_edge, edge_dim] or None
        :return: [n_edge, hidden_size] edge messages
        '''
        # flatten(1) instead of reshape(n, -1): also valid for an empty edge set
        radial = radial.flatten(1)  # [n_edge, d ^ 2]

        if edge_attr is None:  # Unused.
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        out = self.dropout(out)

        if self.attention:
            out = out * self.att_mlp(out)
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        '''
        :param x: [bs * n_node, input_size]
        :param edge_index: list of [n_edge], [n_edge]
        :param edge_attr: [n_edge, hidden_size], refers to message from i to j
        :param node_attr: [bs * n_node, node_dim] or None
        :return: (updated node states [bs * n_node, output_size], MLP input)
        '''
        row, _col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))  # [bs * n_node, hidden_size]
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)  # [bs * n_node, input_size + hidden_size]
        out = self.node_mlp(agg)  # [bs * n_node, output_size]
        out = self.dropout(out)
        if self.residual:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat, channel_weights):
        '''
        coord: [bs * n_node, n_channel, d]
        edge_index: list of [n_edge], [n_edge]
        coord_diff: [n_edge, n_channel, d]
        edge_feat: [n_edge, hidden_size]
        channel_weights: [N, n_channel]

        :raises ValueError: if coords_agg is neither 'sum' nor 'mean'
        '''
        # validate before doing any work (the old message used '%' without a
        # placeholder, so it raised TypeError instead of the intended error)
        if self.coords_agg == 'sum':
            aggregate = unsorted_segment_sum
        elif self.coords_agg == 'mean':
            aggregate = unsorted_segment_mean
        else:
            raise ValueError(f'Wrong coords_agg parameter: {self.coords_agg}')

        row, _col = edge_index

        # first pooling, then element-wise multiply
        n_channel = channel_weights.shape[-1]
        edge_feat = self.coord_mlp(edge_feat)  # [n_edge, n_channel]
        channel_sum = (channel_weights != 0).long().sum(-1)  # [N], number of non-zero channels
        pooled_edge_feat = RollerPooling(n_channel)(edge_feat, channel_sum[row])  # [n_edge, n_channel, 1]
        trans = coord_diff * pooled_edge_feat  # [n_edge, n_channel, d]

        agg = aggregate(trans, row, num_segments=coord.size(0))  # [bs * n_node, n_channel, d]
        return coord + agg

    def forward(self, h, edge_index, coord, channel_attr, channel_weights,
                edge_attr=None, node_attr=None):
        '''
        h: [bs * n_node, hidden_size]
        edge_index: list of [n_row] and [n_col] where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1))
        coord: [bs * n_node, n_channel, d]
        channel_attr: [bs * n_node, n_channel, channel_nf]
        channel_weights: [bs * n_node, n_channel]
        :return: (updated h, updated coord)
        '''
        row, col = edge_index

        radial, coord_diff = coord2radial(edge_index, coord, channel_attr, channel_weights, self.radial_linear)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)  # [n_edge, hidden_size]
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat, channel_weights)  # [bs * n_node, n_channel, d]
        h, _ = self.node_model(h, edge_index, edge_feat, node_attr)

        return h, coord


def unsorted_segment_sum(data, segment_ids, num_segments):
    '''
    Sum the rows of ``data`` that share a segment id.

    :param data: [n_edge, *dimensions]
    :param segment_ids: [n_edge], output row for each data row
    :param num_segments: number of output rows ([bs * n_node] in this file)
    :return: [num_segments, *dimensions]; segments receiving no rows are 0
    '''
    feature_shape = tuple(data.shape[1:])
    # broadcast ids over the trailing feature dims so scatter_add_ can use them as an index
    index = segment_ids[(Ellipsis,) + (None,) * len(feature_shape)].expand(-1, *feature_shape)
    out = data.new_full((num_segments,) + feature_shape, 0)
    return out.scatter_add_(0, index, data)


def unsorted_segment_mean(data, segment_ids, num_segments):
    '''
    Average the rows of ``data`` that share a segment id.

    :param data: [n_edge, *dimensions]
    :param segment_ids: [n_edge], output row for each data row
    :param num_segments: number of output rows ([bs * n_node] in this file)
    :return: [num_segments, *dimensions]; segments receiving no rows are 0
    '''
    feature_shape = tuple(data.shape[1:])
    index = segment_ids[(Ellipsis,) + (None,) * len(feature_shape)].expand(-1, *feature_shape)
    out_shape = (num_segments,) + feature_shape
    totals = data.new_full(out_shape, 0).scatter_add_(0, index, data)
    counts = data.new_full(out_shape, 0).scatter_add_(0, index, torch.ones_like(data))
    # empty segments have a count of 0; clamping yields 0 / 1 = 0 instead of NaN
    return totals / counts.clamp(min=1)


CONSTANT = 1  # added to the radial-feature norm in coord2radial, so the post-norm denominator is >= 1
NUM_SEG = 1  # if you do not have enough memory or you have large attr_size, increase this parameter

def coord2radial(edge_index, coord, attr, channel_weights, linear_map, num_seg=None):
    '''
    Compute channel-aware radial edge features and relative coordinates.

    :param edge_index: tuple([n_edge], [n_edge]) which is tuple of (row, col)
    :param coord: [N, n_channel, d]
    :param attr: [N, n_channel, attr_size], attribute embedding of each channel
    :param channel_weights: [N, n_channel], weights of different channels (0 marks an unused channel)
    :param linear_map: nn.Linear, map features to d_out (in_features must be attr_size ** 2)
    :param num_seg: split row/col into this many segments to reduce memory cost;
                    defaults to the module-level NUM_SEG
    :return: (radials [n_edge, d_out], coord_diff [n_edge, n_channel, d]);
             messages are passed from col to row. An empty edge set yields
             tensors with a leading dimension of 0.
    '''
    row, col = edge_index
    if num_seg is None:
        num_seg = NUM_SEG

    def _segment_radial(seg_row, seg_col):
        '''Radial features [n_seg_edge, d_out] for one slice of the edges.'''
        coord_msg = torch.norm(
            coord[seg_row].unsqueeze(2) - coord[seg_col].unsqueeze(1),  # [n_edge, n_channel, n_channel, d]
            dim=-1, keepdim=False)  # [n_edge, n_channel, n_channel]
        # weight each channel pair by the product of the two channel weights
        coord_msg = coord_msg * torch.bmm(
            channel_weights[seg_row].unsqueeze(2),
            channel_weights[seg_col].unsqueeze(1)
            )  # [n_edge, n_channel, n_channel]
        radial = torch.bmm(
            attr[seg_row].transpose(-1, -2),  # [n_edge, attr_size, n_channel]
            coord_msg)  # [n_edge, attr_size, n_channel]
        radial = torch.bmm(radial, attr[seg_col])  # [n_edge, attr_size, attr_size]
        # flatten(1) instead of reshape(n, -1): also valid for an empty segment
        radial = radial.flatten(1)  # [n_edge, attr_size * attr_size]
        radial_norm = torch.norm(radial, dim=-1, keepdim=True) + CONSTANT  # post norm
        return linear_map(radial) / radial_norm  # [n_edge, d_out]

    n_edge = len(row)
    seg_size = max(1, (n_edge + num_seg - 1) // num_seg)
    # With no edges, still run one empty segment so the result is [0, d_out]
    # (torch.cat on an empty list would raise).
    seg_starts = range(0, n_edge, seg_size) if n_edge > 0 else range(1)
    radials = torch.cat(
        [_segment_radial(row[s:s + seg_size], col[s:s + seg_size]) for s in seg_starts],
        dim=0)  # [N_edge, d_out]

    # generate coord_diff by first averaging the src channels, then subtracting from dst
    # message passed from col to row
    channel_mask = (channel_weights != 0).long()  # [N, n_channel]
    channel_sum = channel_mask.sum(-1)  # [N]
    pooled_col_coord = (coord[col] * channel_mask[col].unsqueeze(-1)).sum(1)  # [n_edge, d]
    pooled_col_coord = pooled_col_coord / channel_sum[col].unsqueeze(-1)  # [n_edge, d], denominator cannot be 0 since no pad node exists
    coord_diff = coord[row] - pooled_col_coord.unsqueeze(1)  # [n_edge, n_channel, d]

    return radials, coord_diff

In [12]:

class AMEncoder(nn.Module):
    '''
    Two-graph adaptive multi-channel encoder.

    Every layer runs one AM_E_GCL on the context graph (ctx_edges over all
    nodes) and one on a separate interface graph (inter_edges over the nodes
    selected by inter_mask). The interface graph keeps its own coordinates
    (inter_x). Node states are copied native -> shadow before each interface
    layer and shadow -> native after it. A final out_layer runs on the context
    graph only.
    '''

    def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_channel, channel_nf,
                 radial_nf, in_edge_nf=0, act_fn=nn.SiLU(), n_layers=4,
                 residual=True, dropout=0.1, dense=False):
        super().__init__()
        '''
        :param in_node_nf: Number of features for 'h' at the input
        :param hidden_nf: Number of hidden features
        :param out_node_nf: Number of features for 'h' at the output
        :param n_channel: Number of channels of coordinates
        :param in_edge_nf: Number of features for the edge features
        :param act_fn: Non-linearity
        :param n_layers: Number of layer for the EGNN
        :param residual: Use residual connections, we recommend not changing this one
        :param dropout: probability of dropout
        :param dense: if dense, then context states will be concatenated for all layers,
                      coordination will be averaged
        '''
        # NOTE: the string above is not a real docstring (it follows super().__init__()).
        self.hidden_nf = hidden_nf
        self.n_layers = n_layers

        self.dropout = nn.Dropout(dropout)

        self.linear_in = nn.Linear(in_node_nf, self.hidden_nf)

        self.dense = dense
        if dense:
            # concatenation of the n_layers loop states plus the out_layer state
            self.linear_out = nn.Linear(self.hidden_nf * (n_layers + 1), out_node_nf)
        else:
            self.linear_out = nn.Linear(self.hidden_nf, out_node_nf)

        # submodules are registered interleaved: ctx_gcl_0, inter_gcl_0, ctx_gcl_1, ...
        for i in range(0, n_layers):
            self.add_module(f'ctx_gcl_{i}', AM_E_GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf, radial_nf,
                edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, dropout=dropout
            ))
            self.add_module(f'inter_gcl_{i}', AM_E_GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf, radial_nf,
                edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, dropout=dropout
            ))
        # final context-graph layer; uses AM_E_GCL's default dropout (not `dropout`)
        self.out_layer = AM_E_GCL(
            self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf,
            radial_nf, edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual
        )
    
    def forward(self, h, x, ctx_edges, inter_mask, inter_x, inter_edges, update_mask, inter_update_mask, channel_attr, channel_weights,
                ctx_edge_attr=None):
        '''
        :param h: node features of all nodes (input to linear_in)
        :param x: coordinates of all nodes, [N, n_channel, d] as in AM_E_GCL
        :param ctx_edges: (row, col) edges of the context graph
        :param inter_mask: selects the nodes that also live in the interface graph
                           (boolean or index mask — TODO confirm)
        :param inter_x: coordinates of the interface-graph copies
        :param inter_edges: (row, col) edges indexing into the inter_mask-selected nodes
        :param update_mask: selects native nodes (in h) whose state is copied to the shadow
        :param inter_update_mask: selects the matching shadow nodes (in inter_h)
        :param channel_attr: [N, n_channel, channel_nf]
        :param channel_weights: [N, n_channel]
        :param ctx_edge_attr: optional edge features for the context graph only
        :return: (h after linear_out, x, inter_x)
        NOTE(review): with dense=True and n_layers == 0, torch.stack(inter_coords) gets an empty list.
        '''
        h = self.linear_in(h)
        h = self.dropout(h)
        # shadow copies of the interface nodes and their channel information
        inter_h = h[inter_mask]
        inter_channel_attr = channel_attr[inter_mask]
        inter_channel_weights = channel_weights[inter_mask]

        ctx_states, ctx_coords, inter_coords = [], [], []
        for i in range(0, self.n_layers):
            h, x = self._modules[f'ctx_gcl_{i}'](
                h, ctx_edges, x, channel_attr, channel_weights,
                edge_attr=ctx_edge_attr)
            # synchronization of the shadow paratope (native -> shadow)
            # (clone before the in-place write, presumably so autograd's saved tensors are not modified)
            inter_h = inter_h.clone()
            inter_h[inter_update_mask] = h[update_mask]
            inter_h, inter_x = self._modules[f'inter_gcl_{i}'](
                inter_h, inter_edges, inter_x, inter_channel_attr, inter_channel_weights
            )
            # synchronization of the shadow paratope (shadow -> native)
            h = h.clone()
            h[inter_mask] = inter_h
            ctx_states.append(h)
            ctx_coords.append(x)
            inter_coords.append(inter_x)

        # the final layer updates the context graph only; inter_x is not touched
        h, x = self.out_layer(
            h, ctx_edges, x, channel_attr, channel_weights,
            edge_attr=ctx_edge_attr)
        ctx_states.append(h)
        ctx_coords.append(x)
        if self.dense:
            h = torch.cat(ctx_states, dim=-1)
            x = torch.mean(torch.stack(ctx_coords), dim=0)
            inter_x = torch.mean(torch.stack(inter_coords), dim=0)
        h = self.dropout(h)
        h = self.linear_out(h)
        return h, x, inter_x

In [13]:


def sequential_and(*tensors):
    '''Element-wise logical AND of all given tensors, folded left to right.'''
    combined, remaining = tensors[0], tensors[1:]
    for mask in remaining:
        combined = torch.logical_and(combined, mask)
    return combined


def sequential_or(*tensors):
    '''Element-wise logical OR of all given tensors, folded left to right.'''
    combined, remaining = tensors[0], tensors[1:]
    for mask in remaining:
        combined = torch.logical_or(combined, mask)
    return combined


def graph_to_batch(tensor, batch_id, padding_value=0, mask_is_pad=True):
    '''
    Convert a flat per-node tensor into a padded [bs, max_n, ...] batch.

    :param tensor: [N, D1, D2, ...]; rows must be ordered by graph (all nodes of
                   graph 0, then graph 1, ...) because they are written into the
                   batch in row-major order of the data mask
    :param batch_id: [N], non-negative graph index of each node
    :param padding_value: value written into padded positions
    :param mask_is_pad: 1 in the mask indicates padding if set to True
    :return: (batch [bs, max_n, D1, ...] with tensor's dtype, mask [bs, max_n] bool)
    '''
    lengths = torch.bincount(batch_id)  # [bs], number of nodes per graph
    bs, max_n = lengths.shape[0], int(lengths.max())
    # torch.full keeps tensor's dtype; ones(...) * padding_value promoted bool
    # inputs (and int inputs with a float padding value) to another dtype
    batch = torch.full((bs, max_n, *tensor.shape[1:]), padding_value,
                       dtype=tensor.dtype, device=tensor.device)
    # generate pad mask: 1 for pad and 0 for data (mark position `length`, then cumsum)
    pad_mask = torch.zeros((bs, max_n + 1), dtype=torch.long, device=tensor.device)
    pad_mask[(torch.arange(bs, device=tensor.device), lengths)] = 1
    pad_mask = (torch.cumsum(pad_mask, dim=-1)[:, :-1]).bool()
    data_mask = torch.logical_not(pad_mask)
    # fill data
    batch[data_mask] = tensor
    mask = pad_mask if mask_is_pad else data_mask
    return batch, mask


def _knn_edges(X, AP, src_dst, atom_pos_pad_idx, k_neighbors, batch_info, given_dist=None):
    '''
    :param X: [N, n_channel, 3], coordinates
    :param AP: [N, n_channel], atom position with pad type need to be ignored
    :param src_dst: [Ef, 2], full possible edges represented in (src, dst)
    :param given_dist: [Ef], given distance of edges
    '''
    offsets, batch_id, max_n, gni2lni = batch_info

    BIGINT = 1e10  # assign a large distance to invalid edges
    N = X.shape[0]
    if given_dist is None:
        dist = X[src_dst]  # [Ef, 2, n_channel, 3]
        dist = dist[:, 0].unsqueeze(2) - dist[:, 1].unsqueeze(1)  # [Ef, n_channel, n_channel, 3]
        dist = torch.norm(dist, dim=-1)  # [Ef, n_channel, n_channel]
        pos_pad = AP[src_dst] == atom_pos_pad_idx # [Ef, 2, n_channel]
        pos_pad = torch.logical_or(pos_pad[:, 0].unsqueeze(2), pos_pad[:, 1].unsqueeze(1))  # [Ef, n_channel, n_channel]
        dist = dist + pos_pad * BIGINT  # [Ef, n_channel, n_channel]
        del pos_pad  # release memory
        dist = torch.min(dist.reshape(dist.shape[0], -1), dim=1)[0]  # [Ef]
    else:
        dist = given_dist
    src_dst = src_dst.transpose(0, 1)  # [2, Ef]

    dist_mat = torch.ones(N, max_n, device=dist.device, dtype=dist.dtype) * BIGINT  # [N, max_n]
    dist_mat[(src_dst[0], gni2lni[src_dst[1]])] = dist
    del dist
    dist_neighbors, dst = torch.topk(dist_mat, k_neighbors, dim=-1, largest=False)  # [N, topk]

    src = torch.arange(0, N, device=dst.device).unsqueeze(-1).repeat(1, k_neighbors)
    src, dst = src.flatten(), dst.flatten()
    dist_neighbors = dist_neighbors.flatten()
    is_valid = dist_neighbors < BIGINT
    src = src.masked_select(is_valid)
    dst = dst.masked_select(is_valid)

    dst = dst + offsets[batch_id[src]]  # mapping from local to global node index

    edges = torch.stack([src, dst])  # message passed from dst to src
    return edges  # [2, E]


class EdgeConstructor:
    '''
    Build context (intra-segment) and interaction (inter-segment) edges for a batch
    of complexes stored as one flat node list.

    Nodes are grouped into graphs by ``batch_id`` and into segments by ``segment_ids``;
    nodes whose residue type is ``boa_idx``/``boh_idx``/``bol_idx`` are global nodes.
    While ``construct_edges`` runs, per-batch tensors are cached on the instance
    (see ``_prepare``); they are always cleared before it returns or raises.
    '''
    def __init__(self, boa_idx, boh_idx, bol_idx, atom_pos_pad_idx, ag_seg_id) -> None:
        self.boa_idx, self.boh_idx, self.bol_idx = boa_idx, boh_idx, bol_idx
        self.atom_pos_pad_idx = atom_pos_pad_idx
        self.ag_seg_id = ag_seg_id

        # buffer
        self._reset_buffer()

    def _reset_buffer(self):
        # per-batch cache written by _prepare and read by the _construct_* helpers
        self.row = None
        self.col = None
        self.row_global = None
        self.col_global = None
        self.row_seg = None
        self.col_seg = None
        self.offsets = None
        self.max_n = None
        self.gni2lni = None
        self.not_global_edges = None

    def get_batch_edges(self, batch_id):
        '''
        Enumerate every ordered pair of distinct nodes belonging to the same graph.

        Offsets are a cumulative sum of graph sizes, so the nodes of each graph must be
        contiguous and sorted by ``batch_id``.

        :param batch_id: [N], graph index of each node
        :return: ((row, col), (offsets, max_n, gni2lni)); row/col are global node indices
            of all candidate edges, offsets[b] is the global index of graph b's first
            node, max_n the largest graph size, gni2lni maps global -> in-graph index
        '''
        # construct tensors to map between global / local node index
        lengths = scatter_sum(torch.ones_like(batch_id), batch_id)  # [bs]
        N, max_n = batch_id.shape[0], torch.max(lengths)
        offsets = F.pad(torch.cumsum(lengths, dim=0)[:-1], pad=(1, 0), value=0)  # [bs]
        # global node index to local index. lni2gni can be implemented as lni + offsets[batch_id]
        gni = torch.arange(N, device=batch_id.device)
        gni2lni = gni - offsets[batch_id]  # [N]

        # all possible edges (within the same graph)
        # row i ends up with ones in local columns [0, size of its graph)
        same_bid = torch.zeros(N, max_n, device=batch_id.device)
        same_bid[(gni, lengths[batch_id] - 1)] = 1
        same_bid = 1 - torch.cumsum(same_bid, dim=-1)
        # shift right and pad 1 to the left
        same_bid = F.pad(same_bid[:, :-1], pad=(1, 0), value=1)
        same_bid[(gni, gni2lni)] = 0  # delete self loop
        row, col = torch.nonzero(same_bid).T  # [2, n_edge_all]
        col = col + offsets[batch_id[row]]  # mapping from local to global node index
        return (row, col), (offsets, max_n, gni2lni)

    def _prepare(self, S, batch_id, segment_ids) -> None:
        '''Compute all candidate edges and their global/segment flags and cache them.'''
        (row, col), (offsets, max_n, gni2lni) = self.get_batch_edges(batch_id)

        # not global edges: neither endpoint is a global node
        is_global = sequential_or(S == self.boa_idx, S == self.boh_idx, S == self.bol_idx)  # [N]
        row_global, col_global = is_global[row], is_global[col]
        not_global_edges = torch.logical_not(torch.logical_or(row_global, col_global))

        # segment ids
        row_seg, col_seg = segment_ids[row], segment_ids[col]

        # add to buffer
        self.row, self.col = row, col
        self.offsets, self.max_n, self.gni2lni = offsets, max_n, gni2lni
        self.row_global, self.col_global = row_global, col_global
        self.not_global_edges = not_global_edges
        self.row_seg, self.col_seg = row_seg, col_seg

    def _construct_inner_edges(self, X, batch_id, k_neighbors, atom_pos):
        '''kNN edges between non-global nodes of the same segment.'''
        row, col = self.row, self.col
        # all possible ctx edges: same seg, not global
        select_edges = torch.logical_and(self.row_seg == self.col_seg, self.not_global_edges)
        ctx_all_row, ctx_all_col = row[select_edges], col[select_edges]
        # ctx edges
        inner_edges = _knn_edges(
            X, atom_pos, torch.stack([ctx_all_row, ctx_all_col]).T,
            self.atom_pos_pad_idx, k_neighbors,
            (self.offsets, batch_id, self.max_n, self.gni2lni))
        return inner_edges

    def _construct_outer_edges(self, X, batch_id, k_neighbors, atom_pos):
        '''kNN edges between non-global nodes of different segments.'''
        row, col = self.row, self.col
        # all possible inter edges: not same seg, not global
        select_edges = torch.logical_and(self.row_seg != self.col_seg, self.not_global_edges)
        inter_all_row, inter_all_col = row[select_edges], col[select_edges]
        outer_edges = _knn_edges(
            X, atom_pos, torch.stack([inter_all_row, inter_all_col]).T,
            self.atom_pos_pad_idx, k_neighbors,
            (self.offsets, batch_id, self.max_n, self.gni2lni))
        return outer_edges

    def _construct_global_edges(self):
        '''Edges touching global nodes: (global <-> node of its segment), (global <-> global).'''
        row, col = self.row, self.col
        # edges between global and normal nodes
        select_edges = torch.logical_and(self.row_seg == self.col_seg, torch.logical_not(self.not_global_edges))
        global_normal = torch.stack([row[select_edges], col[select_edges]])  # [2, nE]
        # edges between global and global nodes
        select_edges = torch.logical_and(self.row_global, self.col_global)  # self-loop has been deleted
        global_global = torch.stack([row[select_edges], col[select_edges]])  # [2, nE]
        return global_normal, global_global

    def _construct_seq_edges(self):
        '''Edges between index-adjacent non-global nodes outside the antigen segment.'''
        row, col = self.row, self.col
        # add additional edge to neighbors in 1D sequence (except epitope)
        select_edges = sequential_and(
            torch.logical_or((row - col) == 1, (row - col) == -1),  # adjacent in the graph
            self.not_global_edges,  # not global edges (also ensure the edges are in the same segment)
            self.row_seg != self.ag_seg_id  # not epitope
        )
        seq_adj = torch.stack([row[select_edges], col[select_edges]])  # [2, nE]
        return seq_adj

    @torch.no_grad()
    def construct_edges(self, X, S, batch_id, k_neighbors, atom_pos, segment_ids):
        '''
        Memory efficient with complexity of O(Nn) where n is the largest number of nodes in the batch

        :return: (ctx_edges [2, E_ctx], inter_edges [2, E_inter]) in global node indices
        '''
        try:
            # prepare inputs
            self._prepare(S, batch_id, segment_ids)

            # edges within chains
            inner_edges = self._construct_inner_edges(X, batch_id, k_neighbors, atom_pos)
            # edges between global nodes and normal/global nodes
            global_normal, global_global = self._construct_global_edges()
            # edges on the 1D sequence
            seq_edges = self._construct_seq_edges()

            # construct context edges
            ctx_edges = torch.cat([inner_edges, global_normal, global_global, seq_edges], dim=1)  # [2, E]

            # construct interaction edges
            inter_edges = self._construct_outer_edges(X, batch_id, k_neighbors, atom_pos)
        finally:
            # drop the cached per-batch tensors even when construction fails
            self._reset_buffer()
        return ctx_edges, inter_edges


class GMEdgeConstructor(EdgeConstructor):
    '''
    Edge constructor for graph matching (kNN internel edges and all bipartite edges)

    The two sides are "antigen segment" vs "everything else", so different non-antigen
    segments count as the same side. Inner edges are kNN within a side; interaction
    edges are every non-global cross-side pair, without kNN filtering.
    '''
    def _side_flags(self):
        # (row endpoint in antigen, col endpoint in antigen) for every cached candidate edge
        return self.row_seg == self.ag_seg_id, self.col_seg == self.ag_seg_id

    def _construct_inner_edges(self, X, batch_id, k_neighbors, atom_pos):
        row_in_ag, col_in_ag = self._side_flags()
        # candidates: both endpoints on the same side and neither is a global node
        keep = torch.logical_and(row_in_ag == col_in_ag, self.not_global_edges)
        candidates = torch.stack([self.row[keep], self.col[keep]]).T  # [Ef, 2]
        return _knn_edges(
            X, atom_pos, candidates,
            self.atom_pos_pad_idx, k_neighbors,
            (self.offsets, batch_id, self.max_n, self.gni2lni))

    def _construct_global_edges(self):
        row, col = self.row, self.col
        # global <-> node of the same segment
        touches_global = torch.logical_not(self.not_global_edges)
        keep = torch.logical_and(self.row_seg == self.col_seg, touches_global)
        global_normal = torch.stack([row[keep], col[keep]])  # [2, nE]
        # global <-> global, restricted to globals on the same side
        # (intra-antigen or intra-antibody); self-loops were removed earlier
        row_in_ag, col_in_ag = self._side_flags()
        keep = sequential_and(self.row_global, self.col_global, row_in_ag == col_in_ag)
        global_global = torch.stack([row[keep], col[keep]])  # [2, nE]
        return global_normal, global_global

    def _construct_outer_edges(self, X, batch_id, k_neighbors, atom_pos):
        row_in_ag, col_in_ag = self._side_flags()
        # every non-global pair with one endpoint in the antigen and one outside it
        keep = torch.logical_and(row_in_ag != col_in_ag, self.not_global_edges)
        return torch.stack([self.row[keep], self.col[keep]])  # [2, E]


class SinusoidalPositionEmbedding(nn.Module):
    """
    Sin-Cos Positional Embedding

    Maps N positions to [N, output_dim] features; for frequency f_j = 10000^(-2j/output_dim)
    the output interleaves sin(p * f_j) and cos(p * f_j).
    """
    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim

    def forward(self, position_ids):
        dim = self.output_dim
        exponents = torch.arange(dim // 2, device=position_ids.device, dtype=torch.float)
        inv_freq = torch.pow(10000.0, -2 * exponents / dim)  # [dim // 2]
        angles = torch.einsum('bn,d->bnd', position_ids.unsqueeze(0), inv_freq)  # [1, N, dim // 2]
        sin_cos = torch.stack([angles.sin(), angles.cos()], dim=-1)  # [1, N, dim // 2, 2]
        return sin_cos.reshape(-1, dim)

# embedding of amino acids. (default: concat residue embedding and atom embedding to one vector)
class AminoAcidEmbedding(nn.Module):
    '''
    [residue embedding + position embedding, mean(atom embeddings + atom position embeddings)]

    The atom part is averaged over the non-padding atom channels of each residue and
    concatenated after the residue part.
    '''
    def __init__(self, num_res_type, num_atom_type, num_atom_pos, res_embed_size, atom_embed_size,
                 atom_pad_id=VOCAB.get_atom_pad_idx(), relative_position=True, max_position=192):  # max position (with IMGT numbering)
        '''
        :param atom_pad_id: atom-TYPE id marking padding channels (default is the atom-type pad index)
        :param relative_position: sinusoidal residue-position encoding if True, otherwise a
            learned absolute table of size ``max_position``
        '''
        super().__init__()
        self.residue_embedding = nn.Embedding(num_res_type, res_embed_size)
        if relative_position:
            self.res_pos_embedding = SinusoidalPositionEmbedding(res_embed_size)  # relative positional encoding
        else:
            self.res_pos_embedding = nn.Embedding(max_position, res_embed_size)  # absolute position encoding
        self.atom_embedding = nn.Embedding(num_atom_type, atom_embed_size)
        self.atom_pos_embedding = nn.Embedding(num_atom_pos, atom_embed_size)
        self.atom_pad_id = atom_pad_id
        self.eps = 1e-10  # for mean of atom embedding (some residues have no atom at all)

    def forward(self, S, RP, A, AP):
        '''
        :param S: [N], residue types
        :param RP: [N], residue positions
        :param A: [N, n_channel], atom types
        :param AP: [N, n_channel], atom positions
        :return: [N, res_embed_size + atom_embed_size]
        '''
        res_embed = self.residue_embedding(S) + self.res_pos_embedding(RP)  # [N, res_embed_size]
        atom_embed = self.atom_embedding(A) + self.atom_pos_embedding(AP)   # [N, n_channel, atom_embed_size]
        # atom_pad_id belongs to the atom-TYPE vocabulary, so padding is detected on A;
        # comparing it with the atom-position tensor AP is only correct when both
        # vocabularies happen to share the same pad index
        atom_not_pad = (A != self.atom_pad_id)  # [N, n_channel]
        denom = torch.sum(atom_not_pad, dim=-1, keepdim=True) + self.eps
        atom_embed = torch.sum(atom_embed * atom_not_pad.unsqueeze(-1), dim=1) / denom  # [N, atom_embed_size]
        return torch.cat([res_embed, atom_embed], dim=-1)  # [N, res_embed_size + atom_embed_size]


class AminoAcidFeature(nn.Module):
    '''
    Residue / atom featurizer for complexes stored as one flat node list.

    Provides residue embeddings, lookup tables from residue type to per-channel atom
    types and atom positions, sidechain geometry tables (unless ``backbone_only``) and
    edge construction. Global nodes are residues of type BOA / BOH / BOL; the
    position/segment helpers assume every segment begins with its global node.
    '''
    def __init__(self, embed_size, relative_position=True, edge_constructor=EdgeConstructor, backbone_only=False) -> None:
        super().__init__()

        self.backbone_only = backbone_only

        # number of classes
        self.num_aa_type = len(VOCAB)
        self.num_atom_type = VOCAB.get_num_atom_type()
        self.num_atom_pos = VOCAB.get_num_atom_pos()

        # atom-level special tokens
        self.atom_mask_idx = VOCAB.get_atom_mask_idx()
        self.atom_pad_idx = VOCAB.get_atom_pad_idx()
        self.atom_pos_mask_idx = VOCAB.get_atom_pos_mask_idx()
        self.atom_pos_pad_idx = VOCAB.get_atom_pos_pad_idx()

        # embedding (residue and atom parts use the same size here)
        self.aa_embedding = AminoAcidEmbedding(
            self.num_aa_type, self.num_atom_type, self.num_atom_pos,
            embed_size, embed_size, self.atom_pad_idx, relative_position)

        # global nodes and mask nodes
        self.boa_idx = VOCAB.symbol_to_idx(VOCAB.BOA)
        self.boh_idx = VOCAB.symbol_to_idx(VOCAB.BOH)
        self.bol_idx = VOCAB.symbol_to_idx(VOCAB.BOL)
        self.mask_idx = VOCAB.get_mask_idx()

        # segment ids
        self.ag_seg_id, self.hc_seg_id, self.lc_seg_id = 1, 2, 3

        # atoms encoding: one row per residue type, one column per atom channel
        residue_atom_type, residue_atom_pos = [], []
        backbone = [VOCAB.atom_to_idx(atom[0]) for atom in VOCAB.backbone_atoms]
        n_channel = VOCAB.MAX_ATOM_NUMBER if not backbone_only else 4
        special_mask = VOCAB.get_special_mask()
        for i in range(len(VOCAB)):
            if i == self.boa_idx or i == self.boh_idx or i == self.bol_idx or i == self.mask_idx:
                # global nodes and the mask token: every channel is a mask atom
                residue_atom_type.append([self.atom_mask_idx for _ in range(n_channel)])
                residue_atom_pos.append([self.atom_pos_mask_idx for _ in range(n_channel)])
            elif special_mask[i] == 1:
                # other special token (pad)
                residue_atom_type.append([self.atom_pad_idx for _ in range(n_channel)])
                residue_atom_pos.append([self.atom_pos_pad_idx for _ in range(n_channel)])
            else:
                # normal amino acids: backbone channels, then sidechain atoms, then padding
                atom_type = backbone
                atom_pos = [VOCAB.atom_pos_to_idx(VOCAB.atom_pos_bb) for _ in backbone]
                if not backbone_only:
                    sidechain_atoms = VOCAB.get_sidechain_info(VOCAB.idx_to_symbol(i))
                    atom_type = atom_type + [VOCAB.atom_to_idx(atom[0]) for atom in sidechain_atoms]
                    atom_pos = atom_pos + [VOCAB.atom_pos_to_idx(atom[1]) for atom in sidechain_atoms]
                num_pad = n_channel - len(atom_type)
                residue_atom_type.append(atom_type + [self.atom_pad_idx for _ in range(num_pad)])
                residue_atom_pos.append(atom_pos + [self.atom_pos_pad_idx for _ in range(num_pad)])

        # mapping from residue to atom types and positions
        self.residue_atom_type = nn.parameter.Parameter(
            torch.tensor(residue_atom_type, dtype=torch.long),
            requires_grad=False)
        self.residue_atom_pos = nn.parameter.Parameter(
            torch.tensor(residue_atom_pos, dtype=torch.long),
            requires_grad=False)

        # sidechain geometry
        if not backbone_only:
            sc_bonds, sc_bonds_mask = [], []
            sc_chi_atoms, sc_chi_atoms_mask = [], []
            for i in range(len(VOCAB)):
                if special_mask[i] == 1:
                    sc_bonds.append([])
                    sc_chi_atoms.append([])
                else:
                    symbol = VOCAB.idx_to_symbol(i)
                    atom_type = VOCAB.backbone_atoms + VOCAB.get_sidechain_info(symbol)
                    # atom -> channel index (same channel order as residue_atom_type)
                    atom2channel = {atom: channel for channel, atom in enumerate(atom_type)}
                    chi_atoms, bond_atoms = VOCAB.get_sidechain_geometry(symbol)
                    sc_chi_atoms.append(
                        [[atom2channel[atom] for atom in atoms] for atoms in chi_atoms]
                    )
                    bonds = []
                    for src_atom in bond_atoms:
                        for dst_atom in bond_atoms[src_atom]:
                            bonds.append((atom2channel[src_atom], atom2channel[dst_atom]))
                    sc_bonds.append(bonds)
            # pad every residue type to the same number of chis / bonds (pad entries are -1, masked out)
            max_num_chis = max([len(chis) for chis in sc_chi_atoms])
            max_num_bonds = max([len(bonds) for bonds in sc_bonds])
            for i in range(len(VOCAB)):
                num_chis, num_bonds = len(sc_chi_atoms[i]), len(sc_bonds[i])
                num_pad_chis, num_pad_bonds = max_num_chis - num_chis, max_num_bonds - num_bonds
                sc_chi_atoms_mask.append(
                    [1 for _ in range(num_chis)] + [0 for _ in range(num_pad_chis)]
                )
                sc_bonds_mask.append(
                    [1 for _ in range(num_bonds)] + [0 for _ in range(num_pad_bonds)]
                )
                sc_chi_atoms[i].extend([[-1, -1, -1, -1] for _ in range(num_pad_chis)])
                sc_bonds[i].extend([(-1, -1) for _ in range(num_pad_bonds)])

            # mapping residues to their sidechain chi angle atoms and bonds
            self.sidechain_chi_angle_atoms = nn.parameter.Parameter(
                torch.tensor(sc_chi_atoms, dtype=torch.long),
                requires_grad=False)
            self.sidechain_chi_mask = nn.parameter.Parameter(
                torch.tensor(sc_chi_atoms_mask, dtype=torch.bool),
                requires_grad=False
            )
            self.sidechain_bonds = nn.parameter.Parameter(
                torch.tensor(sc_bonds, dtype=torch.long),
                requires_grad=False
            )
            self.sidechain_bonds_mask = nn.parameter.Parameter(
                torch.tensor(sc_bonds_mask, dtype=torch.bool),
                requires_grad=False
            )

        # edge constructor
        self.edge_constructor = edge_constructor(self.boa_idx, self.boh_idx, self.bol_idx, self.atom_pos_pad_idx, self.ag_seg_id)

    def _is_global(self, S):
        '''Boolean [N] mask of global nodes (BOA / BOH / BOL residue types).'''
        return sequential_or(S == self.boa_idx, S == self.boh_idx, S == self.bol_idx)  # [N]

    def _construct_residue_pos(self, S):
        '''
        Residue positions restarting at every global node: global node is 1, the first
        residue after it is 2, and so on. Assumes S[0] is a global node.
        '''
        glbl_node_mask = self._is_global(S)
        glbl_node_idx = torch.nonzero(glbl_node_mask).flatten()  # [num_global]
        # at each global node, subtract the previous segment's length so the running sum resets to 1
        shift = F.pad(glbl_node_idx[:-1] - glbl_node_idx[1:] + 1, (1, 0), value=1)  # [num_global]
        residue_pos = torch.ones_like(S)
        residue_pos[glbl_node_mask] = shift
        residue_pos = torch.cumsum(residue_pos, dim=0)
        return residue_pos

    def _construct_segment_ids(self, S):
        '''
        Segment id of every node: 1/2/3 for antigen/heavy chain/light chain, taken from
        the most recent global node (nodes before the first global node get 0).
        '''
        glbl_node_mask = self._is_global(S)
        glbl_nodes = S[glbl_node_mask]
        boa_mask, boh_mask, bol_mask = (glbl_nodes == self.boa_idx), (glbl_nodes == self.boh_idx), (glbl_nodes == self.bol_idx)
        glbl_nodes[boa_mask], glbl_nodes[boh_mask], glbl_nodes[bol_mask] = self.ag_seg_id, self.hc_seg_id, self.lc_seg_id
        # store the delta from the previous segment id at each global node; cumsum restores the ids
        segment_ids = torch.zeros_like(S)
        segment_ids[glbl_node_mask] = glbl_nodes - F.pad(glbl_nodes[:-1], (1, 0), value=0)
        segment_ids = torch.cumsum(segment_ids, dim=0)
        return segment_ids

    def _construct_atom_type(self, S):
        '''[N, n_channel] atom types for residue types S.'''
        return self.residue_atom_type[S]

    def _construct_atom_pos(self, S):
        '''[N, n_channel] atom position ids for residue types S.'''
        return self.residue_atom_pos[S]

    @torch.no_grad()
    def get_sidechain_chi_angles_atoms(self, S):
        '''Chi-angle atom channels [N, max_num_chis, 4] and validity mask [N, max_num_chis].'''
        chi_angles_atoms = self.sidechain_chi_angle_atoms[S]  # [N, max_num_chis, 4]
        chi_mask = self.sidechain_chi_mask[S]  # [N, max_num_chis]
        return chi_angles_atoms, chi_mask

    @torch.no_grad()
    def get_sidechain_bonds(self, S):
        '''Sidechain bond channel pairs [N, max_num_bond, 2] and validity mask [N, max_num_bond].'''
        bonds = self.sidechain_bonds[S]  # [N, max_num_bond, 2]
        bond_mask = self.sidechain_bonds_mask[S]
        return bonds, bond_mask

    def update_globel_coordinates(self, X, S, atom_pos=None):
        '''
        Return a copy of X where each global node sits (on every channel) at the mean of
        the non-padding atoms of the residues that follow it up to the next global node.

        :param X: [N, n_channel, 3] coordinates
        :param S: [N] residue types
        :param atom_pos: optional [N, n_channel] atom position ids (derived from S if None)
        '''
        X = X.clone()

        if atom_pos is None:  # [N, n_channel]
            atom_pos = self._construct_atom_pos(S)

        glbl_node_mask = self._is_global(S)
        chain_id = glbl_node_mask.long()
        chain_id = torch.cumsum(chain_id, dim=0)  # [N]
        chain_id[glbl_node_mask] = 0    # set global nodes to 0
        chain_id = chain_id.unsqueeze(-1).repeat(1, atom_pos.shape[-1])  # [N, n_channel]

        not_global = torch.logical_not(glbl_node_mask)
        not_pad = (atom_pos != self.atom_pos_pad_idx)[not_global]
        flatten_coord = X[not_global][not_pad]  # [N_atom, 3]
        flatten_chain_id = chain_id[not_global][not_pad]

        global_x = scatter_mean(
            src=flatten_coord, index=flatten_chain_id,
            dim=0, dim_size=glbl_node_mask.sum() + 1)  # because index start from 1
        X[glbl_node_mask] = global_x[1:].unsqueeze(1)

        return X

    def embedding(self, S, residue_pos=None, atom_type=None, atom_pos=None):
        '''
        :param S: [N], residue types
        :return: (H, (residue_pos, atom_type, atom_pos)) with any missing input derived from S
        '''
        if residue_pos is None:  # Residue positions in the chain
            residue_pos = self._construct_residue_pos(S)  # [N]

        if atom_type is None:  # Atom types in each residue
            atom_type = self.residue_atom_type[S]  # [N, n_channel]

        if atom_pos is None:   # Atom position in each residue
            atom_pos = self.residue_atom_pos[S]     # [N, n_channel]

        H = self.aa_embedding(S, residue_pos, atom_type, atom_pos)
        return H, (residue_pos, atom_type, atom_pos)

    @torch.no_grad()
    def construct_edges(self, X, S, batch_id, k_neighbors, atom_pos=None, segment_ids=None):
        '''Return (ctx_edges, inter_edges) from the configured edge constructor.'''
        # prepare inputs
        if atom_pos is None:  # Atom position in each residue (pad need to be ignored)
            atom_pos = self.residue_atom_pos[S]

        if segment_ids is None:
            segment_ids = self._construct_segment_ids(S)

        ctx_edges, inter_edges = self.edge_constructor.construct_edges(
            X, S, batch_id, k_neighbors, atom_pos, segment_ids)

        return ctx_edges, inter_edges

    def forward(self, X, S, batch_id, k_neighbors):
        '''Embed residues and build edges: returns (H, (ctx_edges, inter_edges)).'''
        H, (_, _, atom_pos) = self.embedding(S)
        ctx_edges, inter_edges = self.construct_edges(
            X, S, batch_id, k_neighbors, atom_pos=atom_pos)
        return H, (ctx_edges, inter_edges)


class SeparatedAminoAcidFeature(AminoAcidFeature):
    '''
    Separate embeddings of atoms and residues

    ``forward`` returns residue-level embeddings (plus positional encoding) and, next to
    them, per-channel atom embeddings with per-residue-type atom weights instead of
    pooling atoms into the residue vector.
    '''
    def __init__(self, embed_size, atom_embed_size, relative_position=True, edge_constructor=EdgeConstructor, fix_atom_weights=False, backbone_only=False) -> None:
        super().__init__(embed_size, relative_position=relative_position, edge_constructor=edge_constructor, backbone_only=backbone_only)
        # padding channels of each residue type never carry weight
        pad_channels = self.residue_atom_type == self.atom_pad_idx  # [num_aa_type, n_channel]
        self.register_buffer('atom_weights_mask', pad_channels)
        self.fix_atom_weights = fix_atom_weights
        init_like = torch.ones_like if fix_atom_weights else torch.randn_like
        init_weights = init_like(self.residue_atom_type, dtype=torch.float)
        init_weights[pad_channels] = 0
        self.atom_weight = nn.parameter.Parameter(init_weights, requires_grad=not fix_atom_weights)
        self.zero_atom_weight = nn.parameter.Parameter(torch.zeros_like(init_weights), requires_grad=False)

        # override the parent's embedding so atoms get their own embedding size
        self.aa_embedding = AminoAcidEmbedding(
            self.num_aa_type, self.num_atom_type, self.num_atom_pos,
            embed_size, atom_embed_size, self.atom_pad_idx, relative_position)

    def get_atom_weights(self, residue_types):
        '''Per-channel atom weights for each residue type; L2-normalized per type unless fixed.'''
        # [num_aa_classes, max_atom_number(n_channel)], padding channels read from the zero tensor
        table = torch.where(self.atom_weights_mask, self.zero_atom_weight, self.atom_weight)
        if self.fix_atom_weights:
            return table[residue_types]
        return F.normalize(table, dim=-1)[residue_types]

    def forward(self, X, S, batch_id, k_neighbors, residue_pos=None, smooth_prob=None, smooth_mask=None):
        '''
        :param smooth_prob: optional [M, num_aa_type] residue-type distribution; the rows
            selected by ``smooth_mask`` get the expected residue embedding instead
        :return: (H, (ctx_edges, inter_edges), (atom_embedding, atom_weights))
        '''
        if residue_pos is None:
            residue_pos = self._construct_residue_pos(S)  # [N]
        atom_type = self.residue_atom_type[S]  # [N, n_channel]
        atom_pos = self.residue_atom_pos[S]     # [N, n_channel]
        emb = self.aa_embedding

        # residue embedding
        pos_embedding = emb.res_pos_embedding(residue_pos)
        H = emb.residue_embedding(S)
        if smooth_prob is not None:
            all_types = torch.arange(smooth_prob.shape[-1], device=S.device, dtype=S.dtype)
            type_table = emb.residue_embedding(all_types)  # [num_aa_type, embed_size]
            H[smooth_mask] = smooth_prob.mm(type_table)
        H = H + pos_embedding

        # atom embedding
        atom_embedding = emb.atom_embedding(atom_type) + emb.atom_pos_embedding(atom_pos)
        atom_weights = self.get_atom_weights(S)

        edges = self.construct_edges(X, S, batch_id, k_neighbors, atom_pos=atom_pos)
        ctx_edges, inter_edges = edges
        return H, (ctx_edges, inter_edges), (atom_embedding, atom_weights)


class ProteinFeature:
    '''
    Geometric features and structure losses (coordinates, bond lengths, angles) for
    residue coordinates laid out as [N, n_channel, 3] with backbone channels N, CA, C, O first.
    '''
    def __init__(self, backbone_only=False):
        self.backbone_only = backbone_only

    def _cal_sidechain_bond_lengths(self, S, X, aa_feature: AminoAcidFeature):
        '''Lengths [Nbonds] of all valid sidechain bonds of residue types S.'''
        bonds, bonds_mask = aa_feature.get_sidechain_bonds(S)
        n = torch.nonzero(bonds_mask)[:, 0]  # [Nbonds]
        src, dst = bonds[bonds_mask].T
        src_X, dst_X = X[(n, src)], X[(n, dst)]  # [Nbonds, 3]
        bond_lengths = torch.norm(dst_X - src_X, dim=-1)
        return bond_lengths

    def _cal_sidechain_chis(self, S, X, aa_feature: AminoAcidFeature):
        '''Cosine [Nchis] of every valid sidechain chi angle (sign of the angle is not kept).'''
        chi_atoms, chi_mask = aa_feature.get_sidechain_chi_angles_atoms(S)
        n = torch.nonzero(chi_mask)[:, 0]  # [Nchis]
        a0, a1, a2, a3 = chi_atoms[chi_mask].T  # [Nchis]
        x0, x1, x2, x3 = X[(n, a0)], X[(n, a1)], X[(n, a2)], X[(n, a3)]  # [Nchis, 3]
        u_0, u_1, u_2 = (x1 - x0), (x2 - x1), (x3 - x2)  # [Nchis, 3]
        # normals of the two planes; dim=-1 is explicit because without it torch.cross
        # uses the FIRST size-3 dim, i.e. dim 0 whenever Nchis == 3
        n_1 = F.normalize(torch.cross(u_0, u_1, dim=-1), dim=-1)  # [Nchis, 3]
        n_2 = F.normalize(torch.cross(u_1, u_2, dim=-1), dim=-1)  # [Nchis, 3]
        cosChi = (n_1 * n_2).sum(-1)  # [Nchis]
        eps = 1e-7
        cosChi = torch.clamp(cosChi, -1 + eps, 1 - eps)
        return cosChi

    def _cal_backbone_bond_lengths(self, X, seg_id):
        '''Backbone bond lengths: N-CA, CA-C, C=O per residue plus C-N across same-chain neighbours.'''
        # loss of backbone (...N-CA-C(O)-N...) bond length
        # N-CA, CA-C, C=O
        bl1 = torch.norm(X[:, 1:4] - X[:, :3], dim=-1)  # [N, 3], (N-CA), (CA-C), (C=O)
        # C-N
        bl2 = torch.norm(X[1:, 0] - X[:-1, 2], dim=-1)  # [N-1]
        same_chain_mask = seg_id[1:] == seg_id[:-1]
        bl2 = bl2[same_chain_mask]
        bl = torch.cat([bl1.flatten(), bl2], dim=0)
        return bl

    def _cal_angles(self, X, seg_id):
        '''
        Cosines of backbone dihedrals (cosD) and bond angles (cosA), restricted to
        atoms that lie in the same segment.

        :param X: [N, n_channel, 3], channels 0..3 are N, CA, C, O
        :param seg_id: [N] segment id per residue
        '''
        ori_X = X
        X = X[:, :3].reshape(-1, 3)  # [N * 3, 3], N, CA, C
        U = F.normalize(X[1:] - X[:-1], dim=-1)  # [N * 3 - 1, 3]

        # 1. dihedral angles
        u_2, u_1, u_0 = U[:-2], U[1:-1], U[2:]   # [N * 3 - 3, 3]
        # backbone normals (explicit dim: N * 3 - 3 == 3 would otherwise cross along dim 0)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # angle between normals
        eps = 1e-7
        cosD = (n_2 * n_1).sum(-1)  # [N * 3 - 3]
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        # D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # one segment id per backbone atom, in the same residue-major order as X:
        # [s0, s0, s0, s1, s1, s1, ...] (seg_id.repeat(1, 3) would tile instead of interleave)
        seg_id_atom = seg_id.repeat_interleave(3)  # [N * 3]
        same_chain_mask = sequential_and(
            seg_id_atom[:-3] == seg_id_atom[1:-2],
            seg_id_atom[1:-2] == seg_id_atom[2:-1],
            seg_id_atom[2:-1] == seg_id_atom[3:]
        )  # [N * 3 - 3]
        # D = D[same_chain_mask]
        cosD = cosD[same_chain_mask]

        # 2. bond angles (C_{n-1}-N, N-CA), (N-CA, CA-C), (CA-C, C=O), (CA-C, C-N_{n+1}), (O=C, C-Nn)
        u_0, u_1 = U[:-1], U[1:]  # [N*3 - 2, 3]
        cosA1 = ((-u_0) * u_1).sum(-1)  # [N*3 - 2], (C_{n-1}-N, N-CA), (N-CA, CA-C), (CA-C, C-N_{n+1})
        same_chain_mask = sequential_and(
            seg_id_atom[:-2] == seg_id_atom[1:-1],
            seg_id_atom[1:-1] == seg_id_atom[2:]
        )
        cosA1 = cosA1[same_chain_mask]  # [N*3 - 2 * num_chain]
        u_co = F.normalize(ori_X[:, 3] - ori_X[:, 2], dim=-1)  # [N, 3], C=O
        u_cca = -U[1::3]  # [N, 3], C-CA
        u_cn = U[2::3]  # [N-1, 3], C-N_{n+1}
        cosA2 = (u_co * u_cca).sum(-1)  # [N], (C=O, C-CA)
        cosA3 = (u_co[:-1] * u_cn).sum(-1)  # [N-1], (C=O, C-N_{n+1})
        same_chain_mask = (seg_id[:-1] == seg_id[1:])  # [N-1]
        cosA3 = cosA3[same_chain_mask]
        cosA = torch.cat([cosA1, cosA2, cosA3], dim=-1)
        cosA = torch.clamp(cosA, -1 + eps, 1 - eps)

        return cosD, cosA

    def coord_loss(self, pred_X, true_X, batch_id, atom_mask, reference=None):
        '''
        Per graph, transform true_X with the rotation/translation returned by
        ``kabsch_torch`` on the masked backbone atoms (against pred_X, or ``reference``
        when given), then compute the smooth-L1 coordinate loss over ``atom_mask``.

        :return: (xloss, per-residue backbone RMSD [N], list of detached (R, t) per graph)
        '''
        pred_bb, true_bb = pred_X[:, :4], true_X[:, :4]
        bb_mask = atom_mask[:, :4]
        true_X = true_X.clone()
        ops = []

        align_obj = pred_bb if reference is None else reference[:, :4]

        for i in range(torch.max(batch_id) + 1):
            is_cur_graph = batch_id == i
            cur_bb_mask = bb_mask[is_cur_graph]
            _, R, t = kabsch_torch(
                true_bb[is_cur_graph][cur_bb_mask],
                align_obj[is_cur_graph][cur_bb_mask],
                requires_grad=True)
            true_X[is_cur_graph] = torch.matmul(true_X[is_cur_graph], R.T) + t
            ops.append((R.detach(), t.detach()))

        xloss = F.smooth_l1_loss(
            pred_X[atom_mask], true_X[atom_mask],
            reduction='sum') / atom_mask.sum()  # atom-level loss
        bb_rmsd = torch.sqrt(((pred_X[:, :4] - true_X[:, :4]) ** 2).sum(-1).mean(-1))  # [N]
        return xloss, bb_rmsd, ops

    def structure_loss(self, pred_X, true_X, S, cmask, batch_id, xloss_mask, aa_feature, full_profile=False, reference=None):
        '''
        Coordinate loss plus bond-length violation losses on the residues selected by cmask.

        :return: (loss, details, bb_rmsd, ops); details is
            (xloss, bond_loss, bond_angle_loss, angle_loss, sc_bond_loss, sc_chi_loss)
            when full_profile else (xloss, bond_loss, sc_bond_loss)
        '''
        atom_pos = aa_feature._construct_atom_pos(S)[cmask]
        seg_id = aa_feature._construct_segment_ids(S)[cmask]
        atom_mask = atom_pos != aa_feature.atom_pos_pad_idx
        atom_mask = torch.logical_and(atom_mask, xloss_mask[cmask])

        pred_X, true_X, batch_id = pred_X[cmask], true_X[cmask], batch_id[cmask]

        # loss of absolute coordinates
        xloss, bb_rmsd, ops = self.coord_loss(pred_X, true_X, batch_id, atom_mask, reference)

        # loss of backbone (...N-CA-C(O)-N...) bond length
        true_bl = self._cal_backbone_bond_lengths(true_X, seg_id)
        pred_bl = self._cal_backbone_bond_lengths(pred_X, seg_id)
        bond_loss = F.smooth_l1_loss(pred_bl, true_bl)

        # loss of backbone dihedral angles
        if full_profile:
            true_cosD, true_cosA = self._cal_angles(true_X, seg_id)
            pred_cosD, pred_cosA = self._cal_angles(pred_X, seg_id)
            angle_loss = F.smooth_l1_loss(pred_cosD, true_cosD)
            bond_angle_loss = F.smooth_l1_loss(pred_cosA, true_cosA)

        S = S[cmask]
        if self.backbone_only:
            sc_bond_loss, sc_chi_loss = 0, 0
        else:
            # loss of sidechain bonds
            true_sc_bl = self._cal_sidechain_bond_lengths(S, true_X, aa_feature)
            pred_sc_bl = self._cal_sidechain_bond_lengths(S, pred_X, aa_feature)
            sc_bond_loss = F.smooth_l1_loss(pred_sc_bl, true_sc_bl)

            # loss of sidechain chis
            if full_profile:
                true_sc_chi = self._cal_sidechain_chis(S, true_X, aa_feature)
                pred_sc_chi = self._cal_sidechain_chis(S, pred_X, aa_feature)
                sc_chi_loss = F.smooth_l1_loss(pred_sc_chi, true_sc_chi)

        # exerting constraints on bond lengths only is sufficient
        violation_loss = bond_loss + sc_bond_loss
        loss = xloss + violation_loss

        if full_profile:
            details = (xloss, bond_loss, bond_angle_loss, angle_loss, sc_bond_loss, sc_chi_loss)
        else:
            details = (xloss, bond_loss, sc_bond_loss)

        return loss, details, bb_rmsd, ops


class SeperatedCoordNormalizer(nn.Module):
    '''Coordinate normaliser that centres antigen and antibody independently.

    ``centering`` shifts every residue by the center of its own part (antigen
    or antibody) and caches those centers together with the part masks;
    ``normalize`` / ``unnormalize`` apply the fixed affine map
    ``(X - mean) / std`` with mean 0 and std 10; ``uncentering`` re-adds the
    cached centers and ``clear_cache`` drops them.
    '''

    def __init__(self) -> None:
        super().__init__()
        # constant statistics, registered as non-trainable parameters
        self.mean = nn.parameter.Parameter(torch.tensor(0), requires_grad=False)
        self.std = nn.parameter.Parameter(torch.tensor(10), requires_grad=False)
        self.boa_idx = VOCAB.symbol_to_idx(VOCAB.BOA)

    def normalize(self, X):
        '''Map coordinates to the normalised space: (X - mean) / std.'''
        return (X - self.mean) / self.std

    def unnormalize(self, X):
        '''Inverse of ``normalize``: X * std + mean.'''
        return X * self.std + self.mean

    def _residue_centers(self, X, batch_id, is_ag, is_ab):
        '''Build an [N, 3] tensor holding, per residue, the cached center of its part.'''
        centers = torch.zeros(X.shape[0], X.shape[-1], dtype=X.dtype, device=X.device)
        centers[is_ag] = self.ag_centers[batch_id[is_ag]]
        centers[is_ab] = self.ab_centers[batch_id[is_ab]]
        return centers

    def centering(self, X, S, batch_id, aa_feature: AminoAcidFeature):
        '''Subtract the antigen center from antigen residues and the antibody center from the rest.

        Per-sample centers are the first channel of the ``boa_idx`` (antigen)
        and ``boh_idx`` global nodes, after ``update_globel_coordinates`` is
        run with all ``bol_idx`` nodes excluded. Centers and masks are cached
        on the instance for a later ``uncentering`` call.
        '''
        seg_ids = aa_feature._construct_segment_ids(S)
        keep = S != aa_feature.bol_idx
        kept_S = S[keep]
        kept_X = aa_feature.update_globel_coordinates(X[keep], kept_S)
        self.ag_centers = kept_X[kept_S == aa_feature.boa_idx][:, 0]
        self.ab_centers = kept_X[kept_S == aa_feature.boh_idx][:, 0]

        ag_mask = seg_ids == aa_feature.ag_seg_id
        ab_mask = torch.logical_not(ag_mask)

        shifted = X - self._residue_centers(X, batch_id, ag_mask, ab_mask).unsqueeze(1)
        self.is_ag, self.is_ab = ag_mask, ab_mask
        return shifted

    def uncentering(self, X, batch_id, _type=1):
        '''Add the cached centers back to ``X``; ``_type`` selects the layout of ``X``.

        0: [N, 3]; 1: [N, n_channel, 3];
        2: [2, bs, K, 3] with X[0] antigen and X[1] antibody;
        3: [2, Ef, 3] with X[0] antigen and X[1] antibody;
        4: [N, n_channel, 3] where every row is shifted by the antigen center.
        '''
        if _type in (0, 1):
            expanded = X.unsqueeze(1) if _type == 0 else X
            centers = self._residue_centers(expanded, batch_id, self.is_ag, self.is_ab)
            restored = expanded + centers.unsqueeze(1)
            return restored.squeeze(1) if _type == 0 else restored
        if _type == 2:
            stacked = torch.stack([self.ag_centers, self.ab_centers], dim=0)  # [2, bs, 3]
            return X + stacked.unsqueeze(-2)
        if _type == 3:
            stacked = torch.stack([self.ag_centers[batch_id], self.ab_centers[batch_id]], dim=0)
            return X + stacked
        if _type == 4:
            return X + self.ag_centers[batch_id].unsqueeze(1)
        raise NotImplementedError(f'uncentering for type {_type} not implemented')

    def clear_cache(self):
        '''Forget the centers and masks cached by ``centering``.'''
        self.ag_centers = self.ab_centers = None
        self.is_ag = self.is_ab = None

In [14]:
########### load your train / valid set ###########
# Log the run configuration whenever at least one GPU is configured.
if config.gpus:
    print_log(config)
    print_log(f'CDR type: {config.cdr}')
    print_log(f'Paratope: {config.paratope}')
    print_log('structure only' if config.struct_only else 'sequence & structure codesign')

# Builds (or loads the cached, preprocessed) training set from the JSON summary.
train_set = E2EDataset(config.train_set, cdr=config.cdr, paratope=config.paratope)

########## set your collate_fn ##########
collate_fn = train_set.collate_fn

2023-10-10 01:15:10::INFO::<__main__.Config object at 0x151f2ab7d0d0>
2023-10-10 01:15:10::INFO::CDR type: ['H3']
2023-10-10 01:15:10::INFO::Paratope: H3
2023-10-10 01:15:10::INFO::sequence & structure codesign
till here
2023-10-10 01:15:10::INFO::No meta-info file found, start processing


 32%|███▏      | 832/2638 [03:54<11:07,  2.70it/s]

2023-10-10 01:19:05::ERROR::Antigen chain K has something wrong!
2023-10-10 01:19:05::ERROR::parse 6ulf pdb failed, skip
2023-10-10 01:19:05::ERROR::Antigen chain K has something wrong!
2023-10-10 01:19:05::ERROR::parse 6vln pdb failed, skip


 38%|███▊      | 1013/2638 [04:30<03:46,  7.18it/s]

2023-10-10 01:19:41::ERROR::Antigen chain G has something wrong!
2023-10-10 01:19:41::ERROR::parse 4cc8 pdb failed, skip


 89%|████████▉ | 2355/2638 [10:55<01:41,  2.79it/s]

2023-10-10 01:26:06::ERROR::Antigen chain E has something wrong!
2023-10-10 01:26:06::ERROR::parse 6qd8 pdb failed, skip


 93%|█████████▎| 2461/2638 [11:26<00:21,  8.32it/s]

2023-10-10 01:26:39::ERROR::Antigen chain E has something wrong!
2023-10-10 01:26:39::ERROR::parse 6qd7 pdb failed, skip


100%|██████████| 2638/2638 [12:20<00:00,  3.56it/s]

2023-10-10 01:27:31::INFO::Saving ./all_data/RAbD/train_processed/part_0.pkl ...





2023-10-10 01:28:23::INFO::Loading preprocessed file /home/dagaa/Projects/dyMEAN/all_data/RAbD/train_processed/part_0.pkl, 1/1


In [15]:
# Fetch one processed sample for inspection (index via the Dataset protocol).
example = train_set[10]

In [16]:
# Show the JSON file the training set was built from.
config.train_set

'./all_data/RAbD/train.json'

In [17]:
# Display the fetched sample (a dict of arrays).
example

{'X': array([[[194.68936648, -64.90005667, 161.78872256],
         [194.68936648, -64.90005667, 161.78872256],
         [194.68936648, -64.90005667, 161.78872256],
         ...,
         [194.68936648, -64.90005667, 161.78872256],
         [194.68936648, -64.90005667, 161.78872256],
         [194.68936648, -64.90005667, 161.78872256]],
 
        [[187.2230072 , -82.16100311, 154.39599609],
         [188.30700684, -81.2990036 , 153.94099426],
         [188.4940033 , -81.40899658, 152.42500305],
         ...,
         [188.30700684, -81.2990036 , 153.94099426],
         [188.30700684, -81.2990036 , 153.94099426],
         [188.30700684, -81.2990036 , 153.94099426]],
 
        [[185.14500427, -64.29699707, 159.91000366],
         [185.47200012, -62.92200089, 159.57000732],
         [184.27099609, -62.        , 159.71800232],
         ...,
         [185.47200012, -62.92200089, 159.57000732],
         [185.47200012, -62.92200089, 159.57000732],
         [185.47200012, -62.92200089, 159.5700

In [18]:
# List the fields provided for each sample.
example.keys()

dict_keys(['X', 'S', 'residue_pos', 'xloss_mask', 'cmask', 'smask', 'paratope_mask', 'template'])

In [19]:
# Print the shape of every array field; cmask is reported by its length.
for field in ('X', 'S', 'residue_pos', 'xloss_mask'):
    print(example[field].shape)
print(len(example['cmask']))
print(example['template'].shape)

(174, 14, 3)
(174,)
(174,)
(174, 14)
174
(124, 14, 3)


In [20]:

########### load your train / valid set ###########
# Log the run configuration whenever at least one GPU is configured.
if config.gpus:
    print_log(config)
    print_log(f'CDR type: {config.cdr}')
    print_log(f'Paratope: {config.paratope}')
    print_log('structure only' if config.struct_only else 'sequence & structure codesign')

train_set = E2EDataset(config.train_set, cdr=config.cdr, paratope=config.paratope)
valid_set = E2EDataset(config.valid_set, cdr=config.cdr, paratope=config.paratope)

########## set your collate_fn ##########
collate_fn = train_set.collate_fn

########## define your model/trainer/trainconfig #########
# Re-wrap the notebook Config into the trainer's TrainConfig.
config = TrainConfig(**vars(config))

if config.model_type == 'dyMEAN':


    class dyMEANTrainer(Trainer):
        '''Trainer for dyMEANModel.

        The learning rate decays exponentially per batch from ``config.lr`` to
        ``config.final_lr`` over ``max_epoch * step_per_epoch`` steps. During
        training, the ratio of ground-truth sequence context handed to the
        model follows a cosine schedule from 0.9 down to 0.
        '''

        ########## Override start ##########

        def __init__(self, model, train_loader, valid_loader, config):
            self.global_step = 0
            self.epoch = 0
            # total number of optimisation steps in the run
            self.max_step = config.max_epoch * config.step_per_epoch
            # per-step log decay: exp(log_alpha * max_step) == final_lr / lr
            self.log_alpha = log(config.final_lr / config.lr) / self.max_step
            super().__init__(model, train_loader, valid_loader, config)

        def get_optimizer(self):
            '''Adam over all model parameters with the configured base learning rate.'''
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
            return optimizer

        def get_scheduler(self, optimizer):
            '''Per-batch LambdaLR whose multiplier is alpha^(step + 1), alpha = exp(log_alpha).'''
            log_alpha = self.log_alpha
            lr_lambda = lambda step: exp(log_alpha * (step + 1))  # equal to alpha^{step + 1}
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
            return {
                'scheduler': scheduler,
                'frequency': 'batch'
            }

        def train_step(self, batch, batch_idx):
            batch['context_ratio'] = self.get_context_ratio()
            return self.share_step(batch, batch_idx, val=False)

        def valid_step(self, batch, batch_idx):
            # no ground-truth sequence context is revealed during validation
            batch['context_ratio'] = 0
            return self.share_step(batch, batch_idx, val=True)

        ########## Override end ##########

        def get_context_ratio(self):
            '''Cosine-annealed context ratio: 0.9 at step 0, reaching 0 at ``max_step``.'''
            step = self.global_step
            ratio = 0.5 * (cos(step / self.max_step * pi) + 1) * 0.9  # scale to [0, 0.9]
            return ratio

        def share_step(self, batch, batch_idx, val=False):
            '''Run the model on ``batch``, log every loss component and return the total loss.'''
            loss, seq_detail, structure_detail, dock_detail, pdev_detail = self.model(**batch)
            snll, aar = seq_detail
            struct_loss, xloss, bond_loss, sc_bond_loss = structure_detail
            dock_loss, interface_loss, ed_loss, r_ed_losses = dock_detail
            pdev_loss, prmsd_loss = pdev_detail

            log_type = 'Validation' if val else 'Train'

            scalars = [
                ('Overall/Loss', loss),
                ('Seq/SNLL', snll),
                ('Seq/AAR', aar),
                ('Struct/StructLoss', struct_loss),
                ('Struct/XLoss', xloss),
                ('Struct/BondLoss', bond_loss),
                ('Struct/SidechainBondLoss', sc_bond_loss),
                ('Dock/DockLoss', dock_loss),
                ('Dock/SPLoss', interface_loss),
                ('Dock/EDLoss', ed_loss),
            ]
            for name, value in scalars:
                self.log(f'{name}/{log_type}', value, batch_idx, val)
            # per-round edge-distance losses
            for i, l in enumerate(r_ed_losses):
                self.log(f'Dock/edloss{i}/{log_type}', l, batch_idx, val)

            if pdev_loss is not None:
                self.log(f'PDev/PDevLoss/{log_type}', pdev_loss, batch_idx, val)
                self.log(f'PDev/PRMSDLoss/{log_type}', prmsd_loss, batch_idx, val)

            if not val:
                # get_last_lr() returns one lr per param group; config.lr is a plain
                # float and must not be indexed.
                if self.scheduler is None:
                    lr = self.config.lr
                else:
                    lr = self.scheduler.get_last_lr()[0]
                self.log('lr', lr, batch_idx, val)
                self.log('context_ratio', batch['context_ratio'], batch_idx, val)
            return loss
    class dyMEANModel(nn.Module):
        '''Iterative antibody sequence/structure co-design model (dyMEAN).

        A forward pass runs ``iter_round`` rounds of message passing. Each round
        re-predicts the coordinates selected by ``cmask``, the residue types
        selected by ``smask`` (unless ``struct_only``), and a separate copy of
        the paratope coordinates (the "interface") that starts around the
        antigen center.
        '''

        def __init__(self, embed_size, hidden_size, n_channel, num_classes,
                    mask_id=VOCAB.get_mask_idx(), k_neighbors=9, bind_dist_cutoff=6,
                    n_layers=3, iter_round=3, dropout=0.1, struct_only=False,
                    backbone_only=False, fix_channel_weights=False, pred_edge_dist=True,
                    keep_memory=True, cdr_type='H3', paratope='H3', relative_position=False) -> None:
            '''
            :param embed_size: residue embedding size; atom embeddings use ``embed_size // 4``
            :param hidden_size: hidden size of the GNNs and MLP heads
            :param n_channel: atoms per residue (overridden to 4 when ``backbone_only``)
            :param num_classes: number of residue types scored by the sequence head
            :param mask_id: residue-type index written into ``smask`` positions before decoding
            :param k_neighbors: K used when building KNN edges
            :param bind_dist_cutoff: stored as an attribute; not read by this class
            :param iter_round: number of refinement rounds per forward pass
            :param struct_only: predict structure only; a predicted-RMSD head replaces the sequence head
            :param pred_edge_dist: predict interface edge distances and use them to pick interface KNN edges
            :param keep_memory: feed the hidden states of the previous round into the next one
            '''
            super().__init__()
            self.mask_id = mask_id
            self.num_classes = num_classes
            self.bind_dist_cutoff = bind_dist_cutoff
            self.k_neighbors = k_neighbors
            self.round = iter_round
            self.struct_only = struct_only

            # options
            self.backbone_only = backbone_only
            self.fix_channel_weights = fix_channel_weights
            self.pred_edge_dist = pred_edge_dist
            self.keep_memory = keep_memory
            if self.backbone_only:
                n_channel = 4
            self.cdr_type = cdr_type
            self.paratope = paratope

            atom_embed_size = embed_size // 4
            self.aa_feature = SeparatedAminoAcidFeature(
                embed_size, atom_embed_size,
                relative_position=relative_position,
                edge_constructor=GMEdgeConstructor,
                fix_atom_weights=fix_channel_weights,
                backbone_only=backbone_only
            )
            self.protein_feature = ProteinFeature(backbone_only=backbone_only)
            if keep_memory:
                self.memory_ffn = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(hidden_size, hidden_size),
                    nn.SiLU(),
                    nn.Linear(hidden_size, embed_size)
                )
            if self.pred_edge_dist:  # use predicted dist for KNN-graph at the interface
                if self.keep_memory:  # this ffn acts on the memory
                    self.edge_H_ffn = nn.Sequential(
                        nn.SiLU(),
                        nn.Linear(hidden_size, hidden_size),
                        nn.SiLU(),
                        nn.Linear(hidden_size, hidden_size)
                    )
                self.edge_dist_ffn = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(2 * hidden_size, hidden_size),
                    nn.SiLU(),
                    nn.Linear(hidden_size, 1)
                )
                # this GNN encodes the initial hidden states for initial edge distance prediction
                self.init_gnn = AMEGNN(
                    embed_size, hidden_size, hidden_size, n_channel,
                    channel_nf=atom_embed_size, radial_nf=hidden_size,
                    in_edge_nf=0, n_layers=n_layers, residual=True,
                    dropout=dropout, dense=False)
            if not struct_only:
                self.ffn_residue = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(hidden_size, hidden_size),
                    nn.SiLU(),
                    nn.Linear(hidden_size, self.num_classes)
                )
            else:
                self.prmsd_ffn = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(hidden_size, hidden_size),
                    nn.SiLU(),
                    nn.Linear(hidden_size, 1)
                )
            self.gnn = AMEncoder(
                embed_size, hidden_size, hidden_size, n_channel,
                channel_nf=atom_embed_size, radial_nf=hidden_size,
                in_edge_nf=0, n_layers=n_layers, residual=True,
                dropout=dropout, dense=False)

            self.normalizer = SeperatedCoordNormalizer()

            # training related cache, filled by _prepare_batch_constants
            self.batch_constants = {}

        def init_mask(self, X, S, cmask, smask, template):
            '''Mask designed residue types and put the template into the designed coordinates.

            NOTE: ``X`` and ``S`` are modified in place (and also returned).
            '''
            if not self.struct_only:
                S[smask] = self.mask_id
            X[cmask] = template
            return X, S

        def message_passing(self, X, S, residue_pos, interface_X, paratope_mask, batch_id, t, memory_H=None, smooth_prob=None, smooth_mask=None):
            '''One refinement round over the full complex and the local interface graph.

            :param interface_X: current prediction of the paratope coordinates
            :param t: index of the current round (not used in the body)
            :param memory_H: hidden states of the previous round, or None in the first round
            :param smooth_prob: soft residue-type distribution for ``smooth_mask`` positions (forwarded to ``aa_feature``)
            :returns: pred_logits [N, num_classes] (None if struct_only), pred_X [N, n_channel, 3],
                      interface_X [Ncdr, n_channel, 3], H [N, hidden_size],
                      p_edge_dist [Ei] (None when pred_edge_dist is off)
            '''
            # embeddings
            H_0, (ctx_edges, inter_edges), (atom_embeddings, atom_weights) = self.aa_feature(X, S, batch_id, self.k_neighbors, residue_pos, smooth_prob=smooth_prob, smooth_mask=smooth_mask)

            if not self.keep_memory:
                memory_H = None

            if memory_H is not None:
                H_0 = H_0 + self.memory_ffn(memory_H)

            if self.pred_edge_dist:
                if memory_H is not None:
                    edge_H = self.edge_H_ffn(memory_H)
                else:
                    # replace the MLP with gnn for initial edge distance prediction
                    edge_H, dumb_X = self.init_gnn(H_0, X, ctx_edges,
                                        channel_attr=atom_embeddings,
                                        channel_weights=atom_weights)
                    X = X + dumb_X * 0  # to cheat the autograd check

            # update coordinates of the global nodes
            X = self.aa_feature.update_globel_coordinates(X, S)

            # prepare local complex
            local_mask = self.batch_constants['local_mask']
            local_is_ab = self.batch_constants['local_is_ab']
            local_batch_id = self.batch_constants['local_batch_id']
            local_X = X[local_mask].clone()
            # prepare local complex edges
            local_ctx_edges = self.batch_constants['local_ctx_edges']  # [2, Ec]
            local_inter_edges = self.batch_constants['local_inter_edges']  # [2, Ei]
            atom_pos = self.aa_feature._construct_atom_pos(S[local_mask])
            offsets, max_n, gni2lni = self.batch_constants['local_edge_infos']
            # for context edges, use edges in the native paratope
            local_ctx_edges = _knn_edges(
                local_X, atom_pos, local_ctx_edges.T,
                self.aa_feature.atom_pos_pad_idx, self.k_neighbors,
                (offsets, local_batch_id, max_n, gni2lni))
            # for interactive edges, use edges derived from the predicted distance
            local_X[local_is_ab] = interface_X
            if self.pred_edge_dist:
                local_H = edge_H[local_mask]
                src_H, dst_H = local_H[local_inter_edges[0]], local_H[local_inter_edges[1]]
                p_edge_dist = self.edge_dist_ffn(torch.cat([src_H, dst_H], dim=-1)) +\
                            self.edge_dist_ffn(torch.cat([dst_H, src_H], dim=-1))  # perm-invariant
                # squeeze only the feature dim so a single edge stays a 1-D [1] tensor
                p_edge_dist = p_edge_dist.squeeze(-1)
            else:
                p_edge_dist = None
            local_inter_edges = _knn_edges(
                local_X, atom_pos, local_inter_edges.T,
                self.aa_feature.atom_pos_pad_idx, self.k_neighbors,
                (offsets, local_batch_id, max_n, gni2lni), given_dist=p_edge_dist)
            local_edges = torch.cat([local_ctx_edges, local_inter_edges], dim=1)

            # message passing
            H, pred_X, pred_local_X = self.gnn(H_0, X, ctx_edges,
                                            local_mask, local_X, local_edges,
                                            paratope_mask, local_is_ab,
                                            channel_attr=atom_embeddings,
                                            channel_weights=atom_weights)
            interface_X = pred_local_X[local_is_ab]
            pred_logits = None if self.struct_only else self.ffn_residue(H)

            return pred_logits, pred_X, interface_X, H, p_edge_dist  # [N, num_classes], [N, n_channel, 3], [Ncdr, n_channel, 3], [N, hidden_size]

        @torch.no_grad()
        def init_interface(self, X, S, paratope_mask, batch_id, init_noise=None):
            '''Initial paratope coordinates: the antigen global-node position of each sample plus noise.

            The CA channel (index 1) gets the full noise; every other channel gets
            the CA noise plus 1/10 of its own noise, so atoms start clustered around CA.
            '''
            ag_centers = X[S == self.aa_feature.boa_idx][:, 0]  # [bs, 3]
            init_local_X = torch.zeros_like(X[paratope_mask])
            init_local_X = init_local_X + ag_centers[batch_id[paratope_mask]].unsqueeze(1)
            noise = torch.randn_like(init_local_X) if init_noise is None else init_noise
            ca_noise = noise[:, 1]
            noise = noise / 10  + ca_noise.unsqueeze(1) # scale other atoms
            noise[:, 1] = ca_noise
            init_local_X = init_local_X + noise
            return init_local_X

        @torch.no_grad()
        def _prepare_batch_constants(self, S, paratope_mask, lengths):
            '''Cache per-batch index tensors (batch ids, segment ids, local interface masks/edges) in ``self.batch_constants``.'''
            # item index of every residue, derived from the per-item lengths
            batch_id = torch.zeros_like(S)  # [N]
            batch_id[torch.cumsum(lengths, dim=0)[:-1]] = 1
            batch_id.cumsum_(dim=0)  # [N], item idx in the batch
            self.batch_constants['batch_id'] = batch_id
            self.batch_constants['batch_size'] = torch.max(batch_id) + 1

            segment_ids = self.aa_feature._construct_segment_ids(S)
            self.batch_constants['segment_ids'] = segment_ids

            # interface related: local complex = paratope + antigen residues (without the antigen global node)
            is_ag = segment_ids == self.aa_feature.ag_seg_id
            not_ag_global = S != self.aa_feature.boa_idx
            local_mask = torch.logical_or(
                paratope_mask, torch.logical_and(is_ag, not_ag_global)
            )
            local_segment_ids = segment_ids[local_mask]
            local_is_ab = local_segment_ids != self.aa_feature.ag_seg_id
            local_batch_id = batch_id[local_mask]
            self.batch_constants['is_ag'] = is_ag
            self.batch_constants['local_mask'] = local_mask
            self.batch_constants['local_is_ab'] = local_is_ab
            self.batch_constants['local_batch_id'] = local_batch_id
            self.batch_constants['local_segment_ids'] = local_segment_ids
            # interface local edges, split into same-segment (context) and cross-segment (inter) edges
            (row, col), (offsets, max_n, gni2lni) = self.aa_feature.edge_constructor.get_batch_edges(local_batch_id)
            row_segment_ids, col_segment_ids = local_segment_ids[row], local_segment_ids[col]
            is_ctx = row_segment_ids == col_segment_ids
            is_inter = torch.logical_not(is_ctx)

            self.batch_constants['local_ctx_edges'] = torch.stack([row[is_ctx], col[is_ctx]])  # [2, Ec]
            self.batch_constants['local_inter_edges'] = torch.stack([row[is_inter], col[is_inter]])  # [2, Ei]
            self.batch_constants['local_edge_infos'] = (offsets, max_n, gni2lni)

            interface_batch_id = batch_id[paratope_mask]
            self.batch_constants['interface_batch_id'] = interface_batch_id

        def _clean_batch_constants(self):
            self.batch_constants = {}

        @torch.no_grad()
        def _get_inter_edge_dist(self, X, S):
            '''Minimum inter-atomic distance of every local cross-segment edge.

            Padding atom positions are excluded by adding 1e10 to their pairwise
            distances before taking the minimum.

            :returns: [Ef] tensor of distances, in the units of ``X``
            '''
            local_mask = self.batch_constants['local_mask']
            atom_pos = self.aa_feature._construct_atom_pos(S[local_mask])
            src_dst = self.batch_constants['local_inter_edges'].T
            dist = X[local_mask][src_dst]  # [Ef, 2, n_channel, 3]
            dist = dist[:, 0].unsqueeze(2) - dist[:, 1].unsqueeze(1)  # [Ef, n_channel, n_channel, 3]
            dist = torch.norm(dist, dim=-1)  # [Ef, n_channel, n_channel]
            pos_pad = atom_pos[src_dst] == self.aa_feature.atom_pos_pad_idx # [Ef, 2, n_channel]
            pos_pad = torch.logical_or(pos_pad[:, 0].unsqueeze(2), pos_pad[:, 1].unsqueeze(1))  # [Ef, n_channel, n_channel]
            dist = dist + pos_pad * 1e10  # [Ef, n_channel, n_channel]
            dist = torch.min(dist.reshape(dist.shape[0], -1), dim=1)[0]  # [Ef]
            return dist

        def _forward(self, X, S, cmask, smask, paratope_mask, residue_pos, template, lengths, init_noise=None):
            '''Shared generation loop of ``forward`` and ``sample``.

            :returns: H (last round), S (argmax residue types filled in on the last
                      round unless struct_only), list of (logits, smask) per round,
                      pred_X (un-normalised), list of interface coordinates
                      (initial + one per round, un-normalised), list of predicted
                      edge distances per round, prmsd (None unless struct_only)
            '''
            batch_id = self.batch_constants['batch_id']

            # mask sequence and initialize coordinates with template
            X, S = self.init_mask(X, S, cmask, smask, template)

            # normalize
            X = self.normalizer.centering(X, S, batch_id, self.aa_feature)
            X = self.normalizer.normalize(X)

            # update center
            X = self.aa_feature.update_globel_coordinates(X, S)

            # prepare initial interface
            interface_X = self.init_interface(X, S, paratope_mask, batch_id, init_noise)

            # sequence and structure loss
            r_pred_S_logits, pred_S_dist = [], None
            r_interface_X = [interface_X.clone()]  # init
            r_edge_dist = []
            memory_H = None
            # message passing
            for t in range(self.round):
                pred_S_logits, pred_X, interface_X, H, edge_dist = self.message_passing(X, S, residue_pos, interface_X, paratope_mask, batch_id, t, memory_H, pred_S_dist, smask)
                memory_H = H
                r_interface_X.append(interface_X.clone())
                r_pred_S_logits.append((pred_S_logits, smask))
                r_edge_dist.append(edge_dist)
                # 1. update X
                X = X.clone()
                X[cmask] = pred_X[cmask]
                X = self.aa_feature.update_globel_coordinates(X, S)

                if not self.struct_only:
                    # 2. update S: soft distribution in intermediate rounds, argmax in the last one
                    S = S.clone()
                    if t == self.round - 1:
                        S[smask] = torch.argmax(pred_S_logits[smask], dim=-1)
                    else:
                        pred_S_dist = torch.softmax(pred_S_logits[smask], dim=-1)

            interface_batch_id = self.batch_constants['interface_batch_id']

            if self.struct_only:
                # predicted rmsd; squeeze only the feature dim so one residue stays [1]
                prmsd = self.prmsd_ffn(H[cmask]).squeeze(-1)  # [N_ab]
            else:
                prmsd = None

            # uncentering and unnormalize
            pred_X = self.normalizer.unnormalize(pred_X)
            pred_X = self.normalizer.uncentering(pred_X, batch_id)
            for i, interface_X in enumerate(r_interface_X):
                interface_X = self.normalizer.unnormalize(interface_X)
                interface_X = self.normalizer.uncentering(interface_X, interface_batch_id, _type=4)
                r_interface_X[i] = interface_X
            self.normalizer.clear_cache()

            return H, S, r_pred_S_logits, pred_X, r_interface_X,  r_edge_dist, prmsd

        def forward(self, X, S, cmask, smask, paratope_mask, residue_pos, template, lengths, xloss_mask, context_ratio=0):
            '''
            :param X: [N, n_channel, 3], Cartesian coordinates
            :param context_ratio: float, rate of context provided in masked sequence, should be [0, 1) and anneal to 0 in training
            :returns: loss, (snll, aar), (struct_loss, *struct_loss_details),
                      (dock_loss, interface_loss, ed_loss, r_ed_losses), (pdev_loss, prmsd_loss)
            NOTE: ``X`` and ``S`` are modified in place by ``init_mask``.
            '''
            if self.backbone_only:
                X, template = X[:, :4], template[:, :4]  # backbone
                xloss_mask = xloss_mask[:, :4]
            # clone ground truth coordinates, sequence
            true_X, true_S = X.clone(), S.clone()

            # prepare constants
            self._prepare_batch_constants(S, paratope_mask, lengths)
            batch_id = self.batch_constants['batch_id']

            # provide some ground truth for annealing sequence training
            if context_ratio > 0:
                not_ctx_mask = torch.rand_like(smask, dtype=torch.float) >= context_ratio
                smask = torch.logical_and(smask, not_ctx_mask)

            # get results
            H, pred_S, r_pred_S_logits, pred_X, r_interface_X, r_edge_dist, prmsd = self._forward(X, S, cmask, smask, paratope_mask, residue_pos, template, lengths)

            # sequence negative log likelihood, averaged over all masked positions of all rounds
            snll, total = 0, 0
            if not self.struct_only:
                for logits, mask in r_pred_S_logits:
                    snll = snll + F.cross_entropy(logits[mask], true_S[mask], reduction='sum')
                    total = total + mask.sum()
                snll = snll / total

            # structure loss
            struct_loss, struct_loss_details, bb_rmsd, ops = self.protein_feature.structure_loss(pred_X, true_X, true_S, cmask, batch_id, xloss_mask, self.aa_feature)

            # docking loss
            gt_interface_X = true_X[paratope_mask]
            # 1. interface loss (shadow paratope)
            interface_atom_pos = self.aa_feature._construct_atom_pos(true_S[paratope_mask])
            interface_atom_mask = interface_atom_pos != self.aa_feature.atom_pos_pad_idx
            interface_loss = F.smooth_l1_loss(
                r_interface_X[-1][interface_atom_mask],
                gt_interface_X[interface_atom_mask])
            # 2. edge dist loss (ground truth measured in the normalised space)
            if self.pred_edge_dist:
                gt_edge_dist = self._get_inter_edge_dist(self.normalizer.normalize(true_X), true_S)
                ed_loss, r_ed_losses = 0, []
                for edge_dist in r_edge_dist:
                    r_ed_loss = F.smooth_l1_loss(edge_dist, gt_edge_dist)
                    ed_loss = ed_loss + r_ed_loss
                    r_ed_losses.append(r_ed_loss)
            else:
                r_ed_losses = [0] * self.round
                ed_loss = 0
            dock_loss = interface_loss + ed_loss

            if self.struct_only:
                # predicted rmsd
                prmsd_loss = F.smooth_l1_loss(prmsd, bb_rmsd)
                pdev_loss = prmsd_loss
            else:
                pdev_loss, prmsd_loss = None, None

            # comprehensive loss
            loss = snll + struct_loss + dock_loss + (0 if pdev_loss is None else pdev_loss)

            self._clean_batch_constants()

            # AAR
            with torch.no_grad():
                aa_hit = pred_S[smask] == true_S[smask]
                aar = aa_hit.long().sum() / aa_hit.shape[0]

            return loss, (snll, aar), (struct_loss, *struct_loss_details), (dock_loss, interface_loss, ed_loss, r_ed_losses), (pdev_loss, prmsd_loss)

        def sample(self, X, S, cmask, smask, paratope_mask, residue_pos, template, lengths, init_noise=None, return_hidden=False):
            '''Generate designs, keeping per item the try with the lowest metric.

            The metric is the mean NLL of the argmax residue types over ``smask``
            (sequence design, 1 try) or the mean predicted RMSD over the paratope
            (``struct_only``, 10 tries); lower is better. After each improving
            try, that item's antibody is rigidly aligned (Kabsch on the
            paratope backbone) onto the predicted interface.

            :returns: gen_X, gen_S, best metric per item [batch_size],
                      and H of the last try when ``return_hidden``
            NOTE: ``X`` and ``S`` are modified in place by ``init_mask``.
            '''
            if self.backbone_only:
                X, template = X[:, :4], template[:, :4]  # backbone
            gen_X, gen_S = X.clone(), S.clone()

            # prepare constants
            self._prepare_batch_constants(S, paratope_mask, lengths)

            batch_id = self.batch_constants['batch_id']
            batch_size = self.batch_constants['batch_size']
            segment_ids = self.batch_constants['segment_ids']
            interface_batch_id = self.batch_constants['interface_batch_id']
            is_ab = segment_ids != self.aa_feature.ag_seg_id
            s_batch_id = batch_id[smask]

            best_metric = torch.ones(batch_size, dtype=torch.float, device=X.device) * 1e10
            interface_cmask = paratope_mask[cmask]

            n_tries = 10 if self.struct_only else 1
            for _ in range(n_tries):

                # generate
                H, pred_S, r_pred_S_logits, pred_X, r_interface_X, _, prmsd = self._forward(X, S, cmask, smask, paratope_mask, residue_pos, template, lengths, init_noise)

                # PPL or PRMSD
                if not self.struct_only:
                    S_logits = r_pred_S_logits[-1][0][smask]
                    S_probs = torch.max(torch.softmax(S_logits, dim=-1), dim=-1)[0]
                    nlls = -torch.log(S_probs)
                    metric = scatter_mean(nlls, s_batch_id)  # [batch_size]
                else:
                    metric = scatter_mean(prmsd[interface_cmask], interface_batch_id)  # [batch_size]

                update = metric < best_metric
                cupdate = cmask & update[batch_id]
                supdate = smask & update[batch_id]
                # update metric history
                best_metric[update] = metric[update]

                # 1. set generated part
                gen_X[cupdate] = pred_X[cupdate]
                if not self.struct_only:
                    gen_S[supdate] = pred_S[supdate]

                interface_X = r_interface_X[-1]
                # 2. align by cdr
                for b in range(batch_size):
                    if not update[b]:
                        continue
                    # 1. align CDRH3
                    is_cur_graph = batch_id == b
                    cdrh3_cur_graph = torch.logical_and(is_cur_graph, paratope_mask)
                    ori_cdr = gen_X[cdrh3_cur_graph][:, :4]  # backbone
                    pred_cdr = interface_X[interface_batch_id == b][:, :4]
                    _, R, t = kabsch_torch(ori_cdr.reshape(-1, 3), pred_cdr.reshape(-1, 3))

                    # 2. transform antibody
                    is_cur_ab = is_cur_graph & is_ab
                    ab_X = torch.matmul(gen_X[is_cur_ab], R.T) + t
                    gen_X[is_cur_ab] = ab_X

            self._clean_batch_constants()

            # report the metric of the kept (best) design, not of the last try
            if return_hidden:
                return gen_X, gen_S, best_metric, H
            return gen_X, gen_S, best_metric
    # Build the dyMEAN model from the training configuration; boolean "no_*"
    # switches in the config are inverted into the model's positive flags.
    model = dyMEANModel(
        config.embed_dim,
        config.hidden_size,
        VOCAB.MAX_ATOM_NUMBER,
        VOCAB.get_num_amino_acid_type(),
        mask_id=VOCAB.get_mask_idx(),
        k_neighbors=config.k_neighbors,
        bind_dist_cutoff=config.bind_dist_cutoff,
        n_layers=config.n_layers,
        struct_only=config.struct_only,
        iter_round=config.iter_round,
        backbone_only=config.backbone_only,
        fix_channel_weights=config.fix_channel_weights,
        pred_edge_dist=not config.no_pred_edge_dist,
        keep_memory=not config.no_memory,
        cdr_type=config.cdr,
        paratope=config.paratope,
    )
    
elif config.model_type == 'dyMEANOpt':
    # from trainer import dyMEANOptTrainer


    class dyMEANOptTrainer(Trainer):

        ########## Override start ##########

        def __init__(self, model, train_loader, valid_loader, config):
            # counters and schedule constants are set before delegating to Trainer.__init__
            self.global_step, self.epoch = 0, 0
            total_steps = config.max_epoch * config.step_per_epoch
            self.max_step = total_steps
            # per-step log decay: exp(log_alpha * total_steps) == final_lr / lr
            self.log_alpha = log(config.final_lr / config.lr) / total_steps
            self.seq_warmup = config.seq_warmup
            super().__init__(model, train_loader, valid_loader, config)

        def get_optimizer(self):
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
            return optimizer

        def get_scheduler(self, optimizer):
            log_alpha = self.log_alpha
            lr_lambda = lambda step: exp(log_alpha * (step + 1))  # equal to alpha^{step}
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
            return {
                'scheduler': scheduler,
                'frequency': 'batch'
            }

        def train_step(self, batch, batch_idx):
            # batch['seq_alpha'] = min((self.epoch + 1) / (self.seq_warmup + 1), 1) # linear
            batch['seq_alpha'] = 1.0 - 1.0 * self.epoch / self.config.max_epoch
            return self.share_step(batch, batch_idx, val=False)

        def valid_step(self, batch, batch_idx):
            batch['seq_alpha'] = 1
            return self.share_step(batch, batch_idx, val=True)

        ########## Override end ##########

        def get_context_ratio(self):
            ratio = random() * 0.9
            return ratio

        def share_step(self, batch, batch_idx, val=False):
            del batch['paratope_mask']
            del batch['template']
            batch['context_ratio'] = self.get_context_ratio()
            loss, seq_detail, structure_detail, pdev_detail = self.model(**batch)
            snll, aar = seq_detail
            struct_loss, xloss, bond_loss, sc_bond_loss = structure_detail
            pdev_loss, prmsd_loss = pdev_detail

            log_type = 'Validation' if val else 'Train'

            self.log(f'Overall/Loss/{log_type}', loss, batch_idx, val)

            self.log(f'Seq/SNLL/{log_type}', snll, batch_idx, val)
            self.log(f'Seq/AAR/{log_type}', aar, batch_idx, val)

            self.log(f'Struct/StructLoss/{log_type}', struct_loss, batch_idx, val)
            self.log(f'Struct/XLoss/{log_type}', xloss, batch_idx, val)
            self.log(f'Struct/BondLoss/{log_type}', bond_loss, batch_idx, val)
            self.log(f'Struct/SidechainBondLoss/{log_type}', sc_bond_loss, batch_idx, val)

            if pdev_loss is not None:
                self.log(f'PDev/PDevLoss/{log_type}', pdev_loss, batch_idx, val)
                self.log(f'PDev/PRMSDLoss/{log_type}', prmsd_loss, batch_idx, val)

            if not val:
                lr = self.config.lr if self.scheduler is None else self.scheduler.get_last_lr()
                lr = lr[0]
                self.log('lr', lr, batch_idx, val)
                self.log('context_ratio', batch['context_ratio'], batch_idx, val)
                self.log('seq_alpha', batch['seq_alpha'], batch_idx, val)
            return loss
    
    class dyMEANOptModel(nn.Module):
        '''
        Masked 1D & 3D language model.

        Adds noise to the ground-truth 3D coordinates of the nodes selected by ``cmask`` and
        replaces the residue types selected by ``smask`` with the mask token. It then refines
        both iteratively for ``iter_round`` rounds of AMEGNN message passing. With
        ``struct_only`` no sequence head is built; a predicted-RMSD head is trained instead.
        '''
        def __init__(self, embed_size, hidden_size, n_channel, num_classes,
                    mask_id=VOCAB.get_mask_idx(), k_neighbors=9, bind_dist_cutoff=6,
                    n_layers=3, iter_round=3, dropout=0.1, struct_only=False,
                    fix_atom_weights=False, cdr_type='H3', relative_position=False) -> None:
            super().__init__()
            self.mask_id = mask_id
            self.num_classes = num_classes
            self.bind_dist_cutoff = bind_dist_cutoff
            self.k_neighbors = k_neighbors
            self.round = iter_round
            self.cdr_type = cdr_type  # only to indicate the usage of the model

            atom_embed_size = embed_size // 4
            self.aa_feature = SeparatedAminoAcidFeature(
                embed_size, atom_embed_size,
                relative_position=relative_position,
                edge_constructor=EdgeConstructor,
                fix_atom_weights=fix_atom_weights)
            self.protein_feature = ProteinFeature()

            # projects the previous round's hidden states back into the embedding space
            self.memory_ffn = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.Linear(hidden_size, embed_size)
            )
            self.struct_only = struct_only
            if not struct_only:
                # per-node residue-type classifier
                self.ffn_residue = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(hidden_size, hidden_size),
                    nn.SiLU(),
                    nn.Linear(hidden_size, self.num_classes)
                )
            else:
                # per-node predicted RMSD head (structure-only mode)
                self.prmsd_ffn = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(hidden_size, hidden_size),
                    nn.SiLU(),
                    nn.Linear(hidden_size, 1)
                )
            self.gnn = AMEGNN(
                embed_size, hidden_size, hidden_size, n_channel,
                channel_nf=atom_embed_size, radial_nf=hidden_size,
                in_edge_nf=0, n_layers=n_layers, residual=True,
                dropout=dropout, dense=False)

            # training related cache
            self.start_seq_training = False
            self.batch_constants = {}

        def init_mask(self, X, S, cmask, smask, init_noise):
            '''
            Mask the designed residue types and perturb the designed coordinates.

            :param X: coordinates; only rows selected by cmask are perturbed
            :param S: residue types; positions in smask are set to mask_id (unless struct_only)
            :param init_noise: noise added to X[cmask]; standard normal noise is drawn when None
            :return: (X, S). The caller's tensors are never modified in place.
            '''
            if not self.struct_only:
                # clone first: writing into the caller's S used to overwrite e.g. batch['S']
                S = S.clone()
                S[smask] = self.mask_id
            coords = X[cmask]
            noise = torch.randn_like(coords) if init_noise is None else init_noise
            X = X.clone()
            X[cmask] = coords + noise
            return X, S

        def message_passing(self, X, S, residue_pos, batch_id, t, memory_H=None, smooth_prob=None, smooth_mask=None):
            '''
            One round of feature construction + AMEGNN update.

            :param memory_H: hidden states from the previous round, added to the node embeddings
            :param smooth_prob / smooth_mask: soft residue distribution for masked nodes,
                                              forwarded to the amino-acid feature module
            :return: (pred_logits or None if struct_only, pred_X, H)
            '''
            # embeddings
            H_0, (ctx_edges, inter_edges), (atom_embeddings, atom_weights) = self.aa_feature(X, S, batch_id, self.k_neighbors, residue_pos, smooth_prob=smooth_prob, smooth_mask=smooth_mask)
            # keep only interface edges whose closest atom pair is within bind_dist_cutoff
            inter_edges = self._get_binding_edges(X, S, inter_edges)
            edges = torch.cat([ctx_edges, inter_edges], dim=1)

            if memory_H is not None:
                H_0 = H_0 + self.memory_ffn(memory_H)

            # update coordination of the global node
            X = self.aa_feature.update_globel_coordinates(X, S)

            H, pred_X = self.gnn(H_0, X, edges,
                                channel_attr=atom_embeddings,
                                channel_weights=atom_weights)


            pred_logits = None if self.struct_only else self.ffn_residue(H)

            return pred_logits, pred_X, H # [N, num_classes], [N, n_channel, 3], [N, hidden_size]

        @torch.no_grad()
        def _prepare_batch_constants(self, S, lengths):
            '''Cache per-node batch ids, batch size, segment ids and the antigen mask.'''
            # generate batch id: put a 1 at the first node of every item after the first,
            # then a cumulative sum turns it into the item index of each node
            batch_id = torch.zeros_like(S)  # [N]
            batch_id[torch.cumsum(lengths, dim=0)[:-1]] = 1
            batch_id.cumsum_(dim=0)  # [N], item idx in the batch
            self.batch_constants['batch_id'] = batch_id
            self.batch_constants['batch_size'] = torch.max(batch_id) + 1

            segment_ids = self.aa_feature._construct_segment_ids(S)
            self.batch_constants['segment_ids'] = segment_ids

            # interface related
            is_ag = segment_ids == self.aa_feature.ag_seg_id
            self.batch_constants['is_ag'] = is_ag

        @torch.no_grad()
        def _get_binding_edges(self, X, S, inter_edges):
            '''Filter inter edges to those whose minimum atom-atom distance <= bind_dist_cutoff.'''
            atom_pos = self.aa_feature._construct_atom_pos(S)
            src_dst = inter_edges.T
            dist = X[src_dst]  # [Ef, 2, n_channel, 3]
            dist = dist[:, 0].unsqueeze(2) - dist[:, 1].unsqueeze(1)  # [Ef, n_channel, n_channel, 3]
            dist = torch.norm(dist, dim=-1)  # [Ef, n_channel, n_channel]
            pos_pad = atom_pos[src_dst] == self.aa_feature.atom_pos_pad_idx # [Ef, 2, n_channel]
            pos_pad = torch.logical_or(pos_pad[:, 0].unsqueeze(2), pos_pad[:, 1].unsqueeze(1))  # [Ef, n_channel, n_channel]
            # push padded atom pairs out of range so they never count as the closest pair
            dist = dist + pos_pad * 1e10  # [Ef, n_channel, n_channel]
            dist = torch.min(dist.reshape(dist.shape[0], -1), dim=1)[0]  # [Ef]
            is_binding = dist <= self.bind_dist_cutoff
            return src_dst[is_binding].T

        def _clean_batch_constants(self):
            self.batch_constants = {}

        def _forward(self, X, S, cmask, smask, residue_pos, init_noise=None):
            '''
            Mask/noise the inputs and run self.round refinement rounds.

            :return: (H, S, r_pred_S_logits, pred_X, prmsd). r_pred_S_logits is a list of
                     (logits, smask) per round; prmsd is None unless struct_only.
            '''
            batch_id = self.batch_constants['batch_id']

            # mask sequence and add noise to ground truth coordinates
            X, S = self.init_mask(X, S, cmask, smask, init_noise)

            # update center
            X = self.aa_feature.update_globel_coordinates(X, S)

            # sequence and structure loss
            r_pred_S_logits, pred_S_dist = [], None
            memory_H = None
            # message passing
            for t in range(self.round):
                pred_S_logits, pred_X, H = self.message_passing(X, S, residue_pos, batch_id, t, memory_H, pred_S_dist, smask)
                r_pred_S_logits.append((pred_S_logits, smask))
                memory_H = H
                # 1. update X
                X = X.clone()
                X[cmask] = pred_X[cmask]
                X = self.aa_feature.update_globel_coordinates(X, S)

                if not self.struct_only:
                    # 2. update S: soft distribution in intermediate rounds, hard argmax at the end
                    S = S.clone()
                    if t == self.round - 1:
                        S[smask] = torch.argmax(pred_S_logits[smask], dim=-1)
                    else:
                        pred_S_dist = torch.softmax(pred_S_logits[smask], dim=-1)

            if self.struct_only:
                # predicted rmsd; squeeze only the head dim so a single node stays 1-D
                prmsd = self.prmsd_ffn(H[cmask]).squeeze(-1)  # [N_ab]
            else:
                prmsd = None

            return H, S, r_pred_S_logits, pred_X, prmsd

        def forward(self, X, S, cmask, smask, residue_pos, lengths, xloss_mask, context_ratio=0, seq_alpha=1):
            '''
            Compute the training loss for one batch.

            :param X: ground-truth coordinates; rows selected by cmask are noised and predicted
            :param S: ground-truth residue types; positions in smask are masked and predicted
            :param cmask: boolean node mask of coordinates to generate
            :param smask: boolean node mask of residue types to generate
            :param residue_pos: residue position indices, forwarded to the amino-acid feature module
            :param lengths: number of nodes of each batch item, used to build per-node batch ids
            :param xloss_mask: mask forwarded to ProteinFeature.structure_loss
            :param context_ratio: float in [0, 1). Each position of smask is revealed as context
                                  (left unmasked and excluded from SNLL/AAR) with this probability.
            :param seq_alpha: float, weight of the SNLL term in the total loss
            :return: (loss, (snll, aar), (struct_loss, *struct_loss_details), (pdev_loss, prmsd_loss));
                     pdev_loss and prmsd_loss are None unless struct_only
            '''
            # clone ground truth coordinates, sequence
            true_X, true_S = X.clone(), S.clone()

            # prepare constants
            self._prepare_batch_constants(S, lengths)
            batch_id = self.batch_constants['batch_id']

            # provide some ground truth for annealing sequence training
            if context_ratio > 0:
                not_ctx_mask = torch.rand_like(smask, dtype=torch.float) >= context_ratio
                smask = torch.logical_and(smask, not_ctx_mask)

            # get results
            H, pred_S, r_pred_S_logits, pred_X, prmsd = self._forward(X, S, cmask, smask, residue_pos)

            # sequence negative log likelihood
            snll, total = 0, 0
            if not self.struct_only:
                for logits, mask in r_pred_S_logits:
                    snll = snll + F.cross_entropy(logits[mask], true_S[mask], reduction='sum')
                    total = total + mask.sum()
                # context_ratio can reveal every designed position; clamp avoids 0 / 0 = nan
                snll = snll / total.clamp(min=1)

            # coordination loss
            struct_loss, struct_loss_details, bb_rmsd, _ = self.protein_feature.structure_loss(pred_X, true_X, true_S, cmask, batch_id, xloss_mask, self.aa_feature)

            if self.struct_only:
                # predicted rmsd
                prmsd_loss = F.smooth_l1_loss(prmsd, bb_rmsd)
                pdev_loss = prmsd_loss
            else:
                pdev_loss, prmsd_loss = None, None

            # comprehensive loss
            loss = seq_alpha * snll + struct_loss + (0 if pdev_loss is None else pdev_loss)

            self._clean_batch_constants()

            # AAR
            with torch.no_grad():
                aa_hit = pred_S[smask] == true_S[smask]
                # an empty smask (see context_ratio) would otherwise give 0 / 0 = nan
                aar = aa_hit.long().sum() / max(aa_hit.shape[0], 1)

            return loss, (snll, aar), (struct_loss, *struct_loss_details), (pdev_loss, prmsd_loss)

        def sample(self, X, S, cmask, smask, residue_pos, lengths, return_hidden=False, init_noise=None):
            '''
            Generate coordinates for cmask and residue types for smask.

            The residue types are sampled from the final-round softmax (multinomial). The
            per-item mean negative log-probability of the sampled types is returned as ``ppl``
            (zeros when struct_only).

            :return: (gen_X, gen_S, ppl) or (gen_X, gen_S, ppl, H) when return_hidden
            '''
            gen_X, gen_S = X.clone(), S.clone()

            # prepare constants
            self._prepare_batch_constants(S, lengths)

            batch_id = self.batch_constants['batch_id']
            batch_size = self.batch_constants['batch_size']
            s_batch_id = batch_id[smask]

            # generate
            H, pred_S, r_pred_S_logits, pred_X, _ = self._forward(X, S, cmask, smask, residue_pos, init_noise)

            # PPL
            if not self.struct_only:
                S_logits = r_pred_S_logits[-1][0][smask]
                S_dists = torch.softmax(S_logits, dim=-1)
                # squeeze only the sample dim so a single masked residue stays 1-D
                pred_S[smask] = torch.multinomial(S_dists, num_samples=1).squeeze(-1)
                S_probs = S_dists[torch.arange(s_batch_id.shape[0], device=S_dists.device), pred_S[smask]]
                nlls = -torch.log(S_probs)
                ppl = scatter_mean(nlls, s_batch_id)  # [batch_size]
            else:
                ppl = torch.zeros(batch_size, device=pred_S.device)

            # 1. set generated part
            gen_X[cmask] = pred_X[cmask]
            if not self.struct_only:
                gen_S[smask] = pred_S[smask]

            self._clean_batch_constants()

            if return_hidden:
                return gen_X, gen_S, ppl, H
            return gen_X, gen_S, ppl

        def optimize_sample(self, X, S, cmask, smask, residue_pos, lengths, predictor, opt_steps=10, init_noise=None):
            '''
            Optimize the initial coordinate noise with gradient descent against a property predictor.

            Each step samples a design, mean-pools the hidden states per item and scores them with
            ``predictor.inference``. It then back-propagates the score plus a KL(N(mean, std^2) ||
            N(0, 1)) penalty on the noise. Per item, the design with the lowest predicted metric
            seen so far is kept. The ``init_noise`` argument is not used; the noise is drawn here.

            :return: (final_X, final_S, best_metric)
            '''
            self._prepare_batch_constants(S, lengths)
            batch_id = self.batch_constants['batch_id']
            batch_size = self.batch_constants['batch_size']
            # batch id for every scalar of the flattened noise tensor
            noise_batch_id = batch_id[cmask].unsqueeze(1).repeat(1, X.shape[1] * X.shape[2]).flatten()

            final_X, final_S = X.clone(), S.clone()
            best_metric = torch.ones(batch_size, dtype=torch.float, device=X.device) * 1e10

            all_noise = torch.randn_like(X, requires_grad=False)
            init_noise = torch.randn_like(X[cmask], requires_grad=True)
            optimizer = torch.optim.Adam([init_noise], lr=1.0)
            optimizer.zero_grad()

            for i in range(opt_steps):
                all_noise = all_noise.detach()
                X, S, cmask, smask, residue_pos, lengths = X.clone(), S.clone(), cmask.clone(), smask.clone(), residue_pos.clone(), lengths.clone()
                all_noise[cmask] = init_noise
                gen_X, gen_S, _, H = self.sample(X, S, cmask, smask, residue_pos, lengths, return_hidden=True, init_noise=all_noise[cmask])
                h = scatter_mean(H, batch_id, dim=0)
                pmetric = predictor.inference(h)

                # use KL to regularize noise
                mean = scatter_mean(init_noise.flatten(), noise_batch_id)  # [bs]
                std = scatter_std(init_noise.flatten(), noise_batch_id)
                kl = -0.5 * (1 + 2 * torch.log(std) - std ** 2 - mean ** 2)

                (pmetric + kl).sum().backward()
                pmetric = pmetric.detach()
                optimizer.step()
                optimizer.zero_grad()

                with torch.no_grad():
                    update = pmetric < best_metric
                    cupdate = cmask & update[batch_id]
                    supdate = smask & update[batch_id]
                    # update pmetric best history
                    best_metric[update] = pmetric[update]

                    final_X[cupdate] = gen_X[cupdate].detach()
                    if not self.struct_only:
                        final_S[supdate] = gen_S[supdate].detach()

            return final_X, final_S, best_metric
    # Build the dyMEANOpt model. Forward config.iter_round explicitly, as the dyMEAN branch
    # does; previously it was dropped and the class default was always used.
    # NOTE: dyMEANOptModel takes no paratope / backbone_only / pred_edge_dist / keep_memory
    # arguments, so those config fields have no effect for this model type.
    model = dyMEANOptModel(config.embed_dim, config.hidden_size, VOCAB.MAX_ATOM_NUMBER,
                VOCAB.get_num_amino_acid_type(), VOCAB.get_mask_idx(),
                config.k_neighbors, bind_dist_cutoff=config.bind_dist_cutoff,
                n_layers=config.n_layers, struct_only=config.struct_only,
                iter_round=config.iter_round,
                fix_atom_weights=config.fix_channel_weights, cdr_type=config.cdr)

else:
    # NotImplemented is a comparison sentinel, not an exception: calling it raises TypeError
    raise NotImplementedError(f'model {config.model_type} not implemented')

# number of optimizer steps per epoch (ceil division over the global batch size)
step_per_epoch = (len(train_set) + config.batch_size - 1) // config.batch_size
# The Config class defined at the top of the notebook only stores keyword arguments in its
# __dict__ and has no add_parameter method, so set the attribute directly.
config.step_per_epoch = step_per_epoch

if len(config.gpus) > 1:
    # One process per GPU. torchrun exports LOCAL_RANK for each process; the previous
    # hard-coded -1 made set_device a no-op and disabled DDP while still sharding the data.
    config.local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(config.local_rank)
    torch.distributed.init_process_group(backend='nccl', world_size=len(config.gpus))
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=config.shuffle)
    # config.batch_size is the global batch size; split it evenly across GPUs
    config.batch_size = int(config.batch_size / len(config.gpus))
    if config.local_rank == 0:
        print_log(f'Batch size on a single GPU: {config.batch_size}')
else:
    config.local_rank = -1
    train_sampler = None

# only the main process (rank 0, or the single non-distributed process) prints
if config.local_rank in (0, -1):
    print_log(f'step per epoch: {step_per_epoch}')

train_loader = DataLoader(train_set, batch_size=config.batch_size,
                            num_workers=config.num_workers,
                            # a sampler and shuffle=True are mutually exclusive in DataLoader
                            shuffle=(config.shuffle and train_sampler is None),
                            sampler=train_sampler,
                            collate_fn=collate_fn)
valid_loader = DataLoader(valid_set, batch_size=config.batch_size,
                            num_workers=config.num_workers,
                            collate_fn=collate_fn)

if config.model_type == 'dyMEAN':
    trainer = dyMEANTrainer(model, train_loader, valid_loader, config)
elif config.model_type == 'dyMEANOpt':
    # Use the dyMEANOptTrainer defined in this notebook. Importing it from an external
    # `trainer` module would shadow that definition, or fail if no such module exists.
    trainer = dyMEANOptTrainer(model, train_loader, valid_loader, config)
else:
    raise NotImplementedError(f'model {config.model_type} not implemented')

trainer.train(config.gpus, config.local_rank)



2023-10-10 02:02:10::INFO::<__main__.Config object at 0x151f2ab7d0d0>
2023-10-10 02:02:10::INFO::CDR type: ['H3']
2023-10-10 02:02:10::INFO::Paratope: H3
2023-10-10 02:02:10::INFO::sequence & structure codesign
till here
2023-10-10 02:02:10::INFO::Loading preprocessed file /home/dagaa/Projects/dyMEAN/all_data/RAbD/train_processed/part_0.pkl, 1/1
2023-10-10 02:02:46::INFO::No meta-info file found, start processing


 14%|█▍        | 48/339 [00:25<02:37,  1.85it/s]


KeyboardInterrupt: 

In [None]:
# !GPU=0,1 bash scripts/train/train.sh scripts/train/configs/single_cdr_design.json