In [1]:
from Bio.PDB import PDBParser, PDBIO
from Bio.PDB.Structure import Structure as BStructure
from Bio.PDB.Model import Model as BModel
from Bio.PDB.Chain import Chain as BChain
from Bio.PDB.Residue import Residue as BResidue
from Bio.PDB.Atom import Atom as BAtom

In [2]:
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
from multiprocessing import Pool
from tqdm import tqdm

from utils import *
from dataset import *

Random seed set as 42
Random seed set as 42


# egnn

In [3]:
import egnn_clean as eg
import torch

# Dummy parameters
batch_size = 8
n_nodes = 4
n_feat = 1
x_dim = 3

# Dummy variables h, x and fully connected edges
h = torch.ones(batch_size * n_nodes, n_feat)
x = torch.ones(batch_size * n_nodes, x_dim)
edges, edge_attr = eg.get_edges_batch(n_nodes, batch_size)

# Initialize EGNN
egnn = eg.EGNN(in_node_nf=n_feat, hidden_nf=32, out_node_nf=1, in_edge_nf=1)

# Run EGNN
h, x = egnn(h, x, edges, edge_attr)

In [4]:
h.shape, x.shape

(torch.Size([32, 1]), torch.Size([32, 3]))

# eghn

In [None]:
from eghn import *



# dict data EDA

In [5]:
data = pickle.load(open("./data/data.json", "rb"))
len(data)

5374

In [6]:
data[0]

{'pdb': '7t17',
 'Hchain': 'J',
 'Lchain': 'K',
 'Achain': 'C',
 'Hseq': ['QVQLQESGPGLVKPSQTLSLTCAVSGGSISSGDSYWSWIRQHPGKGLEWIGSIYYSGSTYYNPSLKSRVTIPIDTSKNQFSLKLSSVTAADTAVYYCARHVGDLRVNDAFDIWGQGTMVTVSS'],
 'Lseq': ['QSVLTQPPSVSAAPGQKVTISCSGSSSNIGNNFVSWYQRLPGTPPKLLIYDSDKRPSGIPDRFSGSKSGTSATLGITGLQTGDEGDYYCGTWDRSLSVVVFGGGTKLTVL'],
 'Aseq': ['IRCIGVSNRDFVEGMSGGTWVDVVLEHGGCVTVMAQDKPTVDIELVTTTVSNMAEVRSYCYEASISDMASDSRCPTQGEAYLDKQSDTQYVCKRTLVDRGWGNGCGLFGKGSLVTCAKFACSKKMTGKSIQPENLEYRIMLSVHGSQHSGMIVNDTGHETDENRAKVEITPNSPRAEATLGGFGSLGLDCEPRTGLDFSDLYYLTMNNKHWLVHKEWFHDIPLPWHAGADTGTPHWNNKEALVEFKDAHAKRQTVVVLGSQEGAVHTALAGALEAEMDGAKGRLSSGHLKCRLKMDKLRLKGVSYSLCTAAFTFTKIPAETLHGTVTVEVQYAGTDGPCKVPAQMAVDMQTLTPVGRLITANPVITESTENSKMMLELDPPFGDSYIVIGVG'],
 'L1': 'SSNIGNNF',
 'L2': 'DSD',
 'L3': 'GTWDRSLSVVV',
 'H1': 'GGSISSGDSY',
 'H2': 'IYYSGST',
 'H3': 'ARHVGDLRVNDAFDI',
 'Hpos': [array([[ 57.69 , -27.721, 213.695],
         [ 56.713, -27.938, 214.755],
         [ 56.848, -26.882, 215.848],
         [ 57.955, -26.4

In [7]:
data[0].keys()

dict_keys(['pdb', 'Hchain', 'Lchain', 'Achain', 'Hseq', 'Lseq', 'Aseq', 'L1', 'L2', 'L3', 'H1', 'H2', 'H3', 'Hpos', 'Lpos', 'Apos'])

In [8]:
len(data[0]["Aseq"][0]), len(data[0]["Apos"][0])

(392, 392)

In [9]:
len(data[0]["Hseq"][0]), len(data[0]["Hpos"])

(123, 123)

In [10]:
def get_mask(seq, substrs, is_Hchain=True):
    # seq: [Hseq]
    # substrs: [H1, H2, H3]
    # return [mask_H1, mask_H2, mask_H3]
    mask = [0] * len(seq)
    m = 1
    for substr in substrs:
        start = seq.find(substr)
        end = start+len(substr)
        for idx in range(len(mask)):
            if idx>=start and idx<end:
                mask[idx] = m
        m += 1
    
    return mask

In [11]:
data[0]["Hseq"]

['QVQLQESGPGLVKPSQTLSLTCAVSGGSISSGDSYWSWIRQHPGKGLEWIGSIYYSGSTYYNPSLKSRVTIPIDTSKNQFSLKLSSVTAADTAVYYCARHVGDLRVNDAFDIWGQGTMVTVSS']

In [12]:
i = 0
Hmask = get_mask(seq=data[i]["Hseq"][0], substrs=[data[i]["H1"], data[i]["H2"], data[i]["H3"]])
# mask

In [13]:
"".join([data[0]["Hseq"][0][idx] if Hmask[idx]==1 else "" for idx in range(len(Hmask))])==data[0]["H1"], data[0]["H1"]

(True, 'GGSISSGDSY')

In [14]:
"".join([data[0]["Hseq"][0][idx] if Hmask[idx]==2 else "" for idx in range(len(Hmask))])==data[0]["H2"], data[0]["H2"]

(True, 'IYYSGST')

In [15]:
"".join([data[0]["Hseq"][0][idx] if Hmask[idx]==3 else "" for idx in range(len(Hmask))])==data[0]["H3"], data[0]["H3"]

(True, 'ARHVGDLRVNDAFDI')

In [16]:
len(data[0]["Hpos"]), data[0]["Hpos"][0]

(123,
 array([[ 57.69 , -27.721, 213.695],
        [ 56.713, -27.938, 214.755],
        [ 56.848, -26.882, 215.848],
        [ 57.955, -26.464, 216.188]], dtype=float32))

In [17]:
data[0]["Hpos"][0:3]

[array([[ 57.69 , -27.721, 213.695],
        [ 56.713, -27.938, 214.755],
        [ 56.848, -26.882, 215.848],
        [ 57.955, -26.464, 216.188]], dtype=float32),
 array([[ 55.708, -26.454, 216.394],
        [ 55.725, -25.461, 217.46 ],
        [ 56.083, -26.137, 218.776],
        [ 55.473, -27.138, 219.173]], dtype=float32),
 array([[ 57.085, -25.589, 219.46 ],
        [ 57.561, -26.15 , 220.715],
        [ 57.666, -25.045, 221.754],
        [ 58.003, -23.902, 221.438]], dtype=float32)]

In [18]:
a = [[1,2,3], [0,0,0]]
b = [[4,5,6], [6,5,4]]

a + [[0,0,0]] + b

[[1, 2, 3], [0, 0, 0], [0, 0, 0], [4, 5, 6], [6, 5, 4]]

In [19]:
a = ["a", "b"]
b = ["c", "d"]

a + ["zzzz"] + b

['a', 'b', 'zzzz', 'c', 'd']

In [20]:
Hpos = [data[i]["Hpos"][idx] if Hmask[idx]==1 else np.zeros((4,3)) for idx in range(len(Hmask))]
Hpos

[array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 array([[0

In [21]:
data[0]["Hpos"][0].shape

(4, 3)

In [22]:
len(data[0]["Lseq"][0]), len(data[0]["Lpos"])

(110, 110)

In [23]:
data[0]["Lpos"][0]

array([[ 37.846, -27.322, 244.186],
       [ 36.957, -28.308, 243.585],
       [ 37.441, -28.704, 242.194],
       [ 36.769, -29.451, 241.482]], dtype=float32)

In [24]:
len(data[0]["Aseq"][0]), len(data[0]["Apos"])

(392, 1)

In [25]:
len(data[0]["Apos"][0]), data[0]["Apos"][0]

(392,
 [array([[ 81.858,  -7.199, 197.021],
         [ 81.735,  -6.449, 198.265],
         [ 80.284,  -6.421, 198.728],
         [ 79.74 ,  -5.362, 199.036]], dtype=float32),
  array([[ 79.66 ,  -7.596, 198.767],
         [ 78.278,  -7.837, 199.157],
         [ 77.269,  -7.526, 198.055],
         [ 76.075,  -7.745, 198.261]], dtype=float32),
  array([[ 77.695,  -7.022, 196.897],
         [ 76.793,  -6.774, 195.785],
         [ 77.323,  -7.296, 194.458],
         [ 76.538,  -7.469, 193.52 ]], dtype=float32),
  array([[ 78.63 ,  -7.552, 194.352],
         [ 79.196,  -8.074, 193.112],
         [ 78.916,  -9.554, 192.915],
         [ 78.921, -10.028, 191.772]], dtype=float32),
  array([[ 78.667, -10.296, 193.99 ],
         [ 78.411, -11.718, 193.885],
         [ 76.953, -12.063, 193.664],
         [ 76.593, -13.242, 193.593]], dtype=float32),
  array([[ 76.106, -11.047, 193.552],
         [ 74.68 , -11.249, 193.323],
         [ 74.461, -11.614, 191.862],
         [ 75.008, -10.971, 190.958

In [26]:
Apos = [np.zeros((4,3)).astype(np.float32) if idx%2==1 else ele for idx,ele in enumerate(data[0]["Apos"])]

In [27]:
x = [np.ones((4,3)), np.ones((4,3)), np.ones((4,3)), np.ones((4,3))]
[np.zeros((4,3)).astype(np.float32) if idx%2==1 else ele for idx,ele in enumerate(x)]

[array([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float32),
 array([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]),
 array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float32)]

In [28]:
data[0]["Lpos"][0]

array([[ 37.846, -27.322, 244.186],
       [ 36.957, -28.308, 243.585],
       [ 37.441, -28.704, 242.194],
       [ 36.769, -29.451, 241.482]], dtype=float32)

# function "process" demo

In [5]:
dict_data = pickle.load(open("./data/data.json", "rb"))
len(dict_data)

5374

In [6]:
dict_data[0].keys()

dict_keys(['pdb', 'Hchain', 'Lchain', 'Achain', 'Hseq', 'Lseq', 'Aseq', 'L1', 'L2', 'L3', 'H1', 'H2', 'H3', 'Hpos', 'Lpos', 'Apos'])

In [7]:
keys = [dict_data[i]["pdb"] for i in range(len(dict_data))]

def countElement(sample_list, element):
    return sample_list.count(element)

non_repeated = 0
for i in range(len(keys)):
    element = keys[i]
    t = countElement(keys, element)
    print('{} has occurred {} times'.format(element, t))
    
    if t==1:
        non_repeated += 1
non_repeated

7t17 has occurred 2 times
7t17 has occurred 2 times
4o51 has occurred 4 times
5w08 has occurred 6 times
7lg6 has occurred 6 times
7t6s has occurred 1 times
5mi0 has occurred 1 times
2ny3 has occurred 1 times
4yo0 has occurred 2 times
7ugq has occurred 6 times
3qa3 has occurred 4 times
6x1t has occurred 6 times
1nca has occurred 1 times
4ypg has occurred 2 times
7mxl has occurred 3 times
6hga has occurred 1 times
6wj1 has occurred 3 times
3dvn has occurred 2 times
6mph has occurred 9 times
2vxs has occurred 4 times
6oor has occurred 1 times
7v4w has occurred 1 times
5bk2 has occurred 2 times
7bh8 has occurred 2 times
5hys has occurred 4 times
6x87 has occurred 7 times
7rgp has occurred 1 times
6o2b has occurred 4 times
2ny0 has occurred 1 times
7rlw has occurred 2 times
6vi0 has occurred 3 times
7um3 has occurred 2 times
6myy has occurred 3 times
7t4r has occurred 4 times
6xox has occurred 1 times
6bdz has occurred 1 times
6qb3 has occurred 1 times
7dc8 has occurred 2 times
7yqz has occ

7ws0 has occurred 3 times
7re7 has occurred 1 times
7tl0 has occurred 6 times
6pds has occurred 2 times
2ck0 has occurred 1 times
5w6d has occurred 2 times
4xpa has occurred 1 times
1n5y has occurred 1 times
4m48 has occurred 1 times
5ea0 has occurred 1 times
4f15 has occurred 4 times
2oqj has occurred 4 times
4qww has occurred 2 times
2itc has occurred 1 times
7f4f has occurred 1 times
6fy1 has occurred 2 times
7e9b has occurred 1 times
7lbf has occurred 2 times
3hi6 has occurred 2 times
4okv has occurred 2 times
7rly has occurred 3 times
6jhr has occurred 1 times
5czx has occurred 2 times
7umm has occurred 3 times
7wyb has occurred 1 times
5ggs has occurred 2 times
6ban has occurred 4 times
5t5n has occurred 5 times
3ogc has occurred 1 times
5ig7 has occurred 4 times
4kxz has occurred 4 times
6ubi has occurred 2 times
3cxd has occurred 1 times
5anm has occurred 3 times
1u93 has occurred 1 times
4oii has occurred 2 times
1v7n has occurred 4 times
6vmk has occurred 16 times
7lja has oc

7m7e has occurred 1 times
7eo0 has occurred 1 times
6hjp has occurred 3 times
5l0q has occurred 2 times
2b1a has occurred 1 times
6fxn has occurred 6 times
4lqf has occurred 1 times
6pa0 has occurred 1 times
4qnp has occurred 2 times
7ru5 has occurred 2 times
6pe8 has occurred 2 times
5xbm has occurred 2 times
4tnv has occurred 10 times
6b0s has occurred 1 times
6utk has occurred 2 times
3bsz has occurred 2 times
5fgb has occurred 2 times
7ugn has occurred 6 times
7ucy has occurred 2 times
6x1t has occurred 6 times
6mu8 has occurred 2 times
7vad has occurred 1 times
6iea has occurred 1 times
7vih has occurred 1 times
6ppg has occurred 2 times
6vn1 has occurred 3 times
1oak has occurred 1 times
1pkq has occurred 2 times
4yc2 has occurred 2 times
6k68 has occurred 4 times
7sjo has occurred 3 times
1v7m has occurred 2 times
3nig has occurred 2 times
4xpf has occurred 1 times
3csy has occurred 4 times
5tlk has occurred 4 times
3ehb has occurred 1 times
4ojf has occurred 1 times
1i9r has oc

1660

In [8]:
train, val, test = process(data=dict_data)

Lcdr disalignment: 7dc8 at position 37
Lcdr disalignment: 7v05 at position 183
Hcdr disalignment: 3csy at position 229
Lcdr disalignment: 7cec at position 253
Lcdr disalignment: 7cn2 at position 285
Hcdr disalignment: 2znx at position 357
Lcdr disalignment: 2znx at position 357
Lcdr disalignment: 7tuf at position 400
Lcdr disalignment: 6mph at position 475
Hcdr disalignment: 7t9b at position 484
Lcdr disalignment: 5c7x at position 492
Lcdr disalignment: 4lvn at position 556
Hcdr disalignment: 4r4n at position 598
Lcdr disalignment: 4tnv at position 603
Lcdr disalignment: 3mac at position 671
Hcdr disalignment: 2znx at position 788
Lcdr disalignment: 2znx at position 788
Hcdr disalignment: 6vmk at position 798
Lcdr disalignment: 6vmk at position 798
Lcdr disalignment: 7cn2 at position 918
Hcdr disalignment: 3u4e at position 962
Hcdr disalignment: 7t74 at position 1020
Lcdr disalignment: 7cn2 at position 1066
Hcdr disalignment: 6t9e at position 1069
Lcdr disalignment: 6t9e at position 10

In [48]:
def get_mask(seq, substrs):
    # seq: [Hseq]
    # substrs: [H1, H2, H3]
    # return [mask_H1, mask_H2, mask_H3]
    mask = [0] * len(seq)
    m = 1
    span = []
    for substr in substrs:
        start = seq.find(substr)
        end = start+len(substr)
        span.append((start, end))
        for idx in range(len(mask)):
            if idx>=start and idx<end:
                mask[idx] = m
        m += 1
    
    return mask, span

for i in range(len(dict_data)):
    if dict_data[i]["pdb"]=="7dc8":
        break

Hseq = dict_data[i]["Hseq"][0]
Lseq = dict_data[i]["Lseq"][0]

H1 = dict_data[i]["H1"]
H2 = dict_data[i]["H2"]
H3 = dict_data[i]["H3"]
L1 = dict_data[i]["L1"]
L2 = dict_data[i]["L2"]
L3 = dict_data[i]["L3"]

lm = get_mask(seq=Lseq, substrs=[L1, L2, L3])

print("Lseq ")
print(Lseq)
print()
print("L1, L2, L3")
print(L1, L2, L3)

"".join(list(map(str, lm[0])))

Lseq 
QSALTQPPSASGSPGQTVTISCTGTSTDVGDYAYVSWYQQHPGKAPKLMIYYVSKKPDGVPDRFSGSKSGNTASLTVSGLQAEDEADYFCSLRSPGPYPLFGGGTKLTVLGQPKAAPSVTLFPPSSEELQANKATLVCLISDFYPGAVTVAWKADSSPVKAGVETTTPSKQSNNKYAASSYLSLTPEQWKSHRSYSCQVTHEGSTVEKTVAPTECS

L1, L2, L3
STDVGDYAY YVS SLRSPGPYPL


'000000000000000000000000011111111222000000000000000000000000000000000000000000000000000000333333333300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'

In [49]:
len("QSALTQPPSASGSPGQTVTISCTGT"), len("0000000000000000000000000")

(25, 25)

In [18]:
for i in range(len(dict_data)):
    if dict_data[i]["H1"]=="GFTFSSYD" and dict_data[i]["H2"]=="ISSGGSYS" and dict_data[i]["H3"]=="ARQGDYAWFAY" and \
       dict_data[i]["L1"]=="QTIGTW" and dict_data[i]["L2"]=="AAT" and dict_data[i]["L3"]=="QQFYSTPFT":
        print(i, dict_data[i]["pdb"], dict_data[i]["Hchain"], dict_data[i]["Lchain"], dict_data[i]["Achain"])
    if "/".join([dict_data[i]["H1"], dict_data[i]["H2"], dict_data[i]["H3"]]) == 'GFTFSSYD/ISSGGSYS/ARQGDYAWFAY/QTIGTW/AAT/QQFYSTPFT':
        print(i, dict_data[i]["pdb"], dict_data[i]["Hchain"], dict_data[i]["Lchain"], dict_data[i]["Achain"])

2154 6tys C D B
3740 6tys F G E
4842 6tys H L A


In [10]:
train[0].keys()

dict_keys(['X', 'S', 'mask'])

In [11]:
train[0]["mask"]

[[0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  2,
  2,
  2,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,

In [12]:
train[0]["S"]

'GFTFSSYD/ISSGGSYS/ARQGDYAWFAY/QTIGTW/AAT/QQFYSTPFT'

In [13]:
train[0]["X"]

[[array([[155.544, 215.725, 168.924],
         [156.595, 215.841, 167.886],
         [157.611, 214.694, 167.795],
         [158.506, 214.707, 166.948]], dtype=float32),
  array([[157.477, 213.71 , 168.642],
         [158.381, 212.581, 168.667],
         [157.612, 211.439, 169.248],
         [156.508, 211.643, 169.756]], dtype=float32),
  array([[158.146, 210.235, 169.191],
         [157.377, 209.186, 169.789],
         [157.718, 209.046, 171.237],
         [158.807, 208.639, 171.635]], dtype=float32),
  array([[156.719, 209.334, 172.038],
         [156.797, 209.365, 173.477],
         [157.249, 208.032, 173.982],
         [158.147, 207.93 , 174.814]], dtype=float32),
  array([[156.658, 206.995, 173.433],
         [156.86 , 205.633, 173.868],
         [158.278, 205.088, 173.635],
         [158.587, 203.954, 174.014]], dtype=float32),
  array([[159.111, 205.835, 172.898],
         [160.477, 205.419, 172.615],
         [161.548, 206.009, 173.548],
         [162.717, 205.641, 173.42 ]], dt