In [53]:
from preprocess import load_pdb, ALPHABET
import torch
import torch.nn as nn
import py3Dmol
from prody import *
from sidechainnet.structure.PdbBuilder import PdbBuilder

In [39]:
# Load one example through the project's `load_pdb` helper. Later cells read
# the keys 'binder_coords', 'binder_seq', 'binder_atypes', 'target_coords',
# 'target_seq', 'target_atypes' and 'target_surface' from the returned entry.
# NOTE(review): the path is relative to the notebook's working directory.
entry = load_pdb('../data/1nca_imgt.pdb')

In [56]:
class AAEmbedding(nn.Module):
    """Fixed (non-learned) physicochemical featurisation of amino-acid tokens.

    Every symbol of ``ALPHABET`` is mapped to six raw descriptors, in this
    column order: hydropathy, volume / 100, charge, polarity flag, H-bond
    acceptor flag, H-bond donor flag. ``'#'`` maps to all zeros.

    The three continuous descriptors are then expanded with Gaussian radial
    basis functions and the three binary flags are squashed with a sigmoid,
    which yields ``dim()`` features per position.

    The lookup table is stored as a non-persistent buffer: it follows the
    module through ``.to(device)`` / ``.cuda()`` / ``.double()`` but is not
    written to ``state_dict()``, so existing checkpoints keep loading.

    Raises:
        KeyError: at construction time if ``ALPHABET`` contains a symbol that
            is missing from any of the property tables below. Only the charge
            table also covers the extra letters B, J, O, U, X and Z.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): values appear to match the Kyte-Doolittle hydropathy
        # scale -- confirm the intended source.
        self.hydropathy = {
            '#': 0,
            "I": 4.5, "V": 4.2, "L": 3.8, "F": 2.8, "C": 2.5,
            "M": 1.9, "A": 1.8, "W": -0.9, "G": -0.4, "T": -0.7,
            "S": -0.8, "Y": -1.3, "P": -1.6, "H": -3.2, "N": -3.5,
            "D": -3.5, "Q": -3.5, "E": -3.5, "K": -3.9, "R": -4.5,
        }
        # Residue volumes; divided by 100 when the table is built so that the
        # values fall inside the [0, 2.2] RBF range used in `transform`
        # (units presumably cubic angstroms -- TODO confirm).
        self.volume = {
            '#': 0,
            "G": 60.1, "A": 88.6, "S": 89.0, "C": 108.5, "D": 111.1,
            "P": 112.7, "N": 114.1, "T": 116.1, "E": 138.4, "V": 140.0,
            "Q": 143.8, "H": 153.2, "M": 162.9, "I": 166.7, "L": 166.7,
            "K": 168.6, "R": 173.4, "F": 189.9, "Y": 193.6, "W": 227.8,
        }
        # The neutral defaults are listed second and therefore override the
        # charged entries on any shared key; the two key sets are disjoint.
        self.charge = {
            **{'R': 1, 'K': 1, 'D': -1, 'E': -1, 'H': 0.1},
            **{x: 0 for x in 'ABCFGIJLMNOPQSTUVWXYZ#'},
        }
        self.polarity = {
            **{x: 1 for x in 'RNDQEHKSTY'},
            **{x: 0 for x in "ACGILMFPWV#"},
        }
        self.acceptor = {
            **{x: 1 for x in 'DENQHSTY'},
            **{x: 0 for x in "RKWACGILMFPV#"},
        }
        self.donor = {
            **{x: 1 for x in 'RKWNQHSTY'},
            **{x: 0 for x in "DEACGILMFPV#"},
        }

        # [len(ALPHABET), 6] lookup table; row i describes ALPHABET[i].
        table = torch.tensor([
            [self.hydropathy[aa], self.volume[aa] / 100, self.charge[aa],
             self.polarity[aa], self.acceptor[aa], self.donor[aa]]
            for aa in ALPHABET
        ])
        # Registered as a buffer (instead of a plain attribute) so that it is
        # moved/cast together with the module. persistent=False keeps the
        # state_dict layout identical to the previous plain-attribute version.
        self.register_buffer('embedding', table, persistent=False)

    def to_rbf(self, D, D_min, D_max, stride):
        """Expand scalar values into Gaussian radial-basis features.

        Args:
            D: tensor of scalar values, shape [B, N].
            D_min, D_max: the centres are evenly spaced over the closed
                interval [D_min, D_max].
            stride: Gaussian width; also sets the number of centres,
                K = int((D_max - D_min) / stride) (truncating).

        Returns:
            Tensor of shape [B, N, K] holding exp(-((D - mu_k) / stride) ** 2).
        """
        D_count = int((D_max - D_min) / stride)
        # Build the centres on D's device so the subtraction below also works
        # when the module lives on an accelerator.
        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
        D_mu = D_mu.view(1, 1, -1)  # [1, 1, K]
        D_expand = torch.unsqueeze(D, -1)  # [B, N, 1]
        return torch.exp(-((D_expand - D_mu) / stride) ** 2)

    def transform(self, aa_vecs):
        """Turn raw descriptors [B, N, 6] into the final features [B, N, dim()].

        Columns 0-2 (hydropathy, scaled volume, charge) are RBF-expanded into
        90, 22 and 8 channels respectively. Columns 3-5 (binary flags) are
        passed through sigmoid(6 * v - 3), i.e. 0 -> ~0.047 and 1 -> ~0.953.
        """
        return torch.cat([
            self.to_rbf(aa_vecs[:, :, 0], -4.5, 4.5, 0.1),
            self.to_rbf(aa_vecs[:, :, 1], 0, 2.2, 0.1),
            self.to_rbf(aa_vecs[:, :, 2], -1.0, 1.0, 0.25),
            torch.sigmoid(aa_vecs[:, :, 3:] * 6 - 3),
        ], dim=-1)

    def dim(self):
        """Return the feature size produced by `transform`.

        90 + 22 + 8 RBF channels plus 3 sigmoid flags; must be kept in sync
        with the ranges and strides hard-coded in `transform`.
        """
        return 90 + 22 + 8 + 3

    def forward(self, x, raw=False):
        """Embed integer residue indices.

        Args:
            x: integer tensor of indices into ``ALPHABET``, expected [B, N].
            raw: if True, return the six raw descriptors per position instead
                of the expanded features.

        Returns:
            [B, N, 6] when ``raw`` is True, otherwise [B, N, dim()].
        """
        B, N = x.size(0), x.size(1)
        aa_vecs = self.embedding[x.view(-1)].view(B, N, -1)
        if raw:
            # Skip the RBF expansion entirely when it would be discarded.
            return aa_vecs
        return self.transform(aa_vecs)

    def soft_forward(self, x):
        """Embed a soft (e.g. probabilistic) residue assignment.

        Args:
            x: floating-point tensor [B, N, len(ALPHABET)] of per-symbol
                weights; the raw descriptors are the weighted sum of the rows
                of the lookup table.

        Returns:
            Expanded features of shape [B, N, dim()].
        """
        B, N = x.size(0), x.size(1)
        aa_vecs = torch.matmul(x.reshape(B * N, -1), self.embedding).view(B, N, -1)
        return self.transform(aa_vecs)

In [58]:
def view_cdr3(entry):
    """Render the binder stored in `entry` as a cartoon in a py3Dmol viewer.

    Reads `entry['binder_coords']` and `entry['binder_seq']`, flattens the
    coordinates to one row per atom, converts them to a PDB string with
    `PdbBuilder`, and shows the structure in a 400x300 viewer. Returns None.

    NOTE(review): the reshape assumes 14 coordinate rows per residue, i.e.
    `binder_coords` shaped [L, 14, 3] -- confirm against `load_pdb`.
    """
    binder_coords = entry['binder_coords']
    n_residues = binder_coords.shape[0]
    # Collapse the per-residue axis: one (x, y, z) row per atom slot.
    flat_coords = binder_coords.reshape(n_residues * 14, 3)

    builder = PdbBuilder(entry['binder_seq'], flat_coords)
    pdb_string = builder.get_pdb_string()

    viewer = py3Dmol.view(width=400, height=300)
    viewer.addModelsAsFrames(pdb_string)
    # Model index -1 targets the most recently added model.
    viewer.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})
    viewer.zoomTo()
    viewer.show()

In [68]:
# Binder tensors, each given a leading batch dimension of size 1.
# The sequence is encoded as the index of every residue letter in ALPHABET.
true_X = torch.unsqueeze(entry['binder_coords'], 0)
true_S = torch.unsqueeze(torch.tensor(list(map(ALPHABET.index, entry['binder_seq']))), 0)
true_A = torch.unsqueeze(entry['binder_atypes'], 0)

# Target tensors, batched and encoded the same way as the binder.
target_X = torch.unsqueeze(entry['target_coords'], 0)
target_S = torch.unsqueeze(torch.tensor(list(map(ALPHABET.index, entry['target_seq']))), 0)
target_A = torch.unsqueeze(entry['target_atypes'], 0)

# Surface field of the target, batched like everything else.
target_surface = torch.unsqueeze(entry['target_surface'], 0)

In [79]:
# Sanity check: featurise the target sequence and display the feature vector
# of the first residue of the first (only) batch element.
embedding = AAEmbedding()
embedding(target_S)[0, 0]

tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+