In [2]:
import torch
import numpy as np
import json
from prody import *
from sidechainnet.utils.measure import *

In [3]:
RES_ATOM14 = [
    [''] * 14,
    ['N', 'CA', 'C', 'O', 'CB', '',    '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'NE',  'CZ',  'NH1', 'NH2', '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'ND2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'OD2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'SG',  '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'NE2', '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'OE2', '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', '',   '',    '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'ND1', 'CD2', 'CE1', 'NE2', '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'CE',  'NZ',  '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'SD',  'CE',  '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'OG',  '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  'OH',  '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '',    '',    '',    '',    '',    '',    ''],
]

ALPHABET = ['#', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

ATOM_TYPES = [
    '', 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]

In [4]:
def tocdr(resseq):
    if 27 <= resseq <= 38:
        return '1'
    elif 56 <= resseq <= 65:
        return '2'
    elif 105 <= resseq <= 117:
        return '3'
    else:
        return '0'

In [5]:
hchain = parsePDB('data/1nca_imgt.pdb', model=1, chain='H')
hchain = hchain.select('not water').copy()

_, hcoords, hseq, _, _ = get_seq_coords_and_angles(hchain)
hcdr = ''.join([tocdr(res.getResnum()) for res in hchain.iterResidues()])
hcoords = hcoords.reshape((len(hseq), 14, 3)) # reshaped to (length of chain, number of types atoms (14), 3d coords)

In [6]:
achain = parsePDB('data/1nca_imgt.pdb', model=1, chain='N')
achain = achain.select('not water').copy()

_, acoords, aseq, _, _ = get_seq_coords_and_angles(achain)
acoords = acoords.reshape((len(aseq), 14, 3)) # reshaped to (length of chain, number of types atoms (14), 3d coords)

In [7]:
complex_1nca = {
    'pdb': '1cna',
    'antibody_seq': hseq, # heavy chain sequence
    'antibody_cdr': hcdr, # imgt cdr numbering
    'antibody_coords': hcoords, # heavy chain coordinates
    'antigen_seq': aseq, # antigen sequence
    'antigen_coords': acoords # antigen coordinates
}
entry = {}

In [8]:
# residue indices of the CDR3 region
surface = torch.tensor([i for i,v in enumerate(complex_1nca['antibody_cdr']) if v in '3'])
entry['binder_surface'] = surface

# FASTA sequence of the CDR3 region
entry['binder_seq'] = ''.join([complex_1nca['antibody_seq'][i] for i in surface.tolist()])

# coordinates of the cdr3 region
entry['binder_coords'] = torch.tensor(complex_1nca['antibody_coords'])[surface]

# convert to indices of atom types of the CDR3 region
entry['binder_atypes'] = torch.tensor(
                [[ATOM_TYPES.index(a) for a in RES_ATOM14[ALPHABET.index(s)]] for s in entry['binder_seq']]
)

# binary representation of the atom types (is atom rep. by a 1, no atom = 0)
mask = (entry['binder_coords'].norm(dim=-1) > 1e-6).long()

# redundancy check
entry['binder_atypes'] *= mask

In [9]:
# same things for the antigen target
entry['target_seq'] = complex_1nca['antigen_seq']
entry['target_coords'] = torch.tensor(complex_1nca['antigen_coords'])
entry['target_atypes'] = torch.tensor(
        [[ATOM_TYPES.index(a) for a in RES_ATOM14[ALPHABET.index(s)]] for s in entry['target_seq']]
)
mask = (entry['target_coords'].norm(dim=-1) > 1e-6).long()
entry['target_atypes'] *= mask

In [48]:
X = entry['target_coords'][None,...]
Y = entry['binder_coords'][None,...]
XA = entry['target_atypes'][None,...]
YA = entry['binder_atypes'][None,...]

# number of batches, length of antigen, length of cdr3, number of types atoms (14)
B, N, M, L = X.size(0), X.size(1), Y.size(1), Y.size(2)

X = X.view(B, N * L, 3) # flatten 2nd dimension (1, length of antigen * number of types atoms (14), coords)
Y = Y.view(B, M * L, 3) # flatten 2nd dimension (1, length of cdr3 * number of types atoms (14), coords)

# element wise distance between cdr3 and antigen residues
dxy = X.unsqueeze(2) - Y.unsqueeze(1)  # [B, NL, 1, 3] - [B, 1, ML, 3]

# sum of element wise squared distance
D = torch.sum(dxy ** 2, dim=-1)

# distance between each atom of each antigen residue to each atom of the cdr3
D = D.view(B, N, L, M, L)

# transposed and added
D = D.transpose(2, 3).reshape(B, N, M, L*L)

xmask = XA.clamp(max=1).float().view(B, N * L)
ymask = YA.clamp(max=1).float().view(B, M * L)
mask = xmask.unsqueeze(2) * ymask.unsqueeze(1)  # [B, NL, 1] x [B, 1, ML]
mask = mask.view(B, N, L, M, L)
mask = mask.transpose(2, 3).reshape(B, N, M, L*L)

D = D + 1e6 * (1 - mask)
dist = D.amin(dim=-1)

In [58]:
K = min(len(dist[0]), 20)
epitope = dist[0].amin(dim=-1).topk(k=K, largest=False).indices

# epitope of the antigen
entry['target_surface'] = torch.sort(epitope).values

In [60]:
entry

{'binder_surface': tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108]),
 'binder_seq': 'ARGEDNFGSLSDY',
 'binder_coords': tensor([[[ 58.3790,  34.3380, 105.0240],
          [ 57.4110,  34.1380, 106.0880],
          [ 58.2880,  33.8790, 107.2830],
          [ 59.4130,  33.3850, 107.1210],
          [ 56.5710,  32.8980, 105.9230],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000]],
 
         [[ 57.8190,  34.2930, 108.4550],
          [ 58.5060,  34.0420, 109.7220],
          [ 58.0590,  32.6460, 110.0510],
          [ 56.9070,  32.2850, 109.8150],
          [ 58.0270,  34.9380, 110.8430],
          [ 58.5700,  34.7540, 112.2580],
          [ 57