In [22]:
from tqdm import tqdm
import json
import torch

In [36]:
RES_ATOM14 = [
    [''] * 14,
    ['N', 'CA', 'C', 'O', 'CB', '',    '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'NE',  'CZ',  'NH1', 'NH2', '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'ND2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'OD2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'SG',  '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'NE2', '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'OE2', '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', '',   '',    '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'ND1', 'CD2', 'CE1', 'NE2', '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'CE',  'NZ',  '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'SD',  'CE',  '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'OG',  '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  'OH',  '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '',    '',    '',    '',    '',    '',    ''],
]

In [37]:
ALPHABET = ['#', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

In [38]:
RESTYPE_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL"
}

In [39]:
ATOM_TYPES = [
    '', 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]

In [40]:
jsonl_file = 'data/sample.json'
cdr_type = '3'
L_target = 20

In [52]:
def full_square_dist(X, Y, XA, YA, contact=False, remove_diag=False):
    B, N, M, L = X.size(0), X.size(1), Y.size(1), Y.size(2)
    print(B, N, M, L)
    X = X.view(B, N * L, 3)
    Y = Y.view(B, M * L, 3)
    dxy = X.unsqueeze(2) - Y.unsqueeze(1)  # [B, NL, 1, 3] - [B, 1, ML, 3]
    D = torch.sum(dxy ** 2, dim=-1)
    D = D.view(B, N, L, M, L)
    D = D.transpose(2, 3).reshape(B, N, M, L*L)

    xmask = XA.clamp(max=1).float().view(B, N * L)
    ymask = YA.clamp(max=1).float().view(B, M * L)
    mask = xmask.unsqueeze(2) * ymask.unsqueeze(1)  # [B, NL, 1] x [B, 1, ML]
    mask = mask.view(B, N, L, M, L)
    mask = mask.transpose(2, 3).reshape(B, N, M, L*L)
    if remove_diag:
        mask = mask * (1 - torch.eye(N)[None,:,:,None]).to(mask)

    if contact:
        D = D + 1e6 * (1 - mask)
        return D.amin(dim=-1), mask.amax(dim=-1)
    else:
        return D, mask

In [53]:
data = []
with open(jsonl_file) as f:
    all_lines = f.readlines()
    for line in tqdm(all_lines):
        entry = json.loads(line)
        assert len(entry['antibody_coords']) == len(entry['antibody_seq'])
        assert len(entry['antigen_coords']) == len(entry['antigen_seq'])
        if entry['antibody_cdr'].count(cdr_type) <= 4:
            continue

        # paratope region
        surface = torch.tensor(
                [i for i,v in enumerate(entry['antibody_cdr']) if v in cdr_type]
        )
        entry['binder_surface'] = surface

        entry['binder_seq'] = ''.join([entry['antibody_seq'][i] for i in surface.tolist()])
        entry['binder_coords'] = torch.tensor(entry['antibody_coords'])[surface]
        entry['binder_atypes'] = torch.tensor(
                [[ATOM_TYPES.index(a) for a in RES_ATOM14[ALPHABET.index(s)]] for s in entry['binder_seq']]
        )
        mask = (entry['binder_coords'].norm(dim=-1) > 1e-6).long()
        entry['binder_atypes'] *= mask

        # Create target
        entry['target_seq'] = entry['antigen_seq']
        entry['target_coords'] = torch.tensor(entry['antigen_coords'])
        entry['target_atypes'] = torch.tensor(
                [[ATOM_TYPES.index(a) for a in RES_ATOM14[ALPHABET.index(s)]] for s in entry['target_seq']]
        )
        mask = (entry['target_coords'].norm(dim=-1) > 1e-6).long()
        entry['target_atypes'] *= mask

        # Find target surface
        dist, _ = full_square_dist(
                entry['target_coords'][None,...], 
                entry['binder_coords'][None,...], 
                entry['target_atypes'][None,...], 
                entry['binder_atypes'][None,...], 
                contact=True
        )
        K = min(len(dist[0]), L_target)
        epitope = dist[0].amin(dim=-1).topk(k=K, largest=False).indices
        entry['target_surface'] = torch.sort(epitope).values

        if len(entry['binder_coords']) > 4 and len(entry['target_coords']) > 4 and entry['antibody_cdr'].count('001') <= 1:
            data.append(entry)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10.58it/s]

1 389 13 14





In [51]:
data[0]['binder_seq']

'ARGEDNFGSLSDY'