In [1]:
import os
import json
import copy
import pickle
import sys
import torch
sys.path.insert(0,'/orcd/home/002/swans/relE_paper/relE_binder_design/TERMinator_sscore')
from terminator.utils.model.default_hparams import DEFAULT_MODEL_HPARAMS, DEFAULT_TRAIN_HPARAMS
from terminator.models.TERMinator import TERMinator

# import importlib
# importlib.reload(TERMinator)

dev = 'cpu'


def load_hparams(model_path, default_hparams, output_name):
    """Load a JSON hyperparameter file and back-fill missing keys with defaults.

    Args:
        model_path: Directory that contains the hyperparameter file.
        default_hparams: Mapping of hyperparameter name -> default value. Any
            key absent from the loaded file is copied from this mapping.
        output_name: File name of the JSON file inside ``model_path``.

    Returns:
        The hyperparameters read from disk with every missing default filled in.
        Values already present in the file are never overwritten.
    """
    print("loading params")
    hparams_path = os.path.join(model_path, output_name)
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(hparams_path, 'r') as hparams_file:
        hparams = json.load(hparams_file)
    for hparam in default_hparams:
        if hparam not in hparams:
            hparams[hparam] = default_hparams[hparam]
    return hparams

# Read the saved model / training hyperparameters of the checkpoint directory
model_dir = '/orcd/pool/008/swans/terminator_stuff/priority/230310_finetuneCOORDinator_sscore_mpnodeupdate'
model_hparams = load_hparams(model_dir, DEFAULT_MODEL_HPARAMS, 'model_hparams.json')
run_hparams = load_hparams(model_dir, DEFAULT_TRAIN_HPARAMS, 'run_hparams.json')

# backwards compatibility: older hparam files may lack these keys, so each one
# is only filled in when it is absent (existing values are left untouched)
model_hparams.setdefault("cov_features", False)
model_hparams.setdefault("term_use_mpnn", False)
model_hparams.setdefault("matches", "resnet")
model_hparams.setdefault('struct2seq_linear', False)
model_hparams.setdefault('energies_gvp', False)
model_hparams.setdefault('num_sing_stats', 0)
model_hparams.setdefault('num_pair_stats', 0)
model_hparams.setdefault('contact_idx', False)
model_hparams.setdefault('fe_dropout', 0.1)
model_hparams.setdefault('fe_max_len', 1000)
model_hparams.setdefault('cie_dropout', 0.1)

# The ensemble count is owned by the run hparams (1 when unspecified) and is
# always mirrored into the model hparams.
model_hparams['num_ensembles'] = run_hparams.setdefault('num_ensembles', 1)

# When flex is not configured at all, switch it off and blank its type together.
if "use_flex" not in model_hparams:
    model_hparams.update(use_flex=False, flex_type="")

print(model_hparams)

# Initialize the model
# Built from the merged hyperparameters on the device chosen in `dev`.
terminator = TERMinator(hparams=model_hparams, device=dev)

# Load weights from the best checkpoint during training
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. map_location remaps the stored tensors onto `dev`.
best_checkpoint_state = torch.load(os.path.join(model_dir, 'net_best_checkpoint.pt'), map_location=dev)
# The checkpoint is a dict; the weights live under the 'state_dict' key.
best_checkpoint = best_checkpoint_state['state_dict']
terminator.load_state_dict(best_checkpoint)
terminator.to(dev)
# Switch to evaluation mode (presumably disables dropout etc. — TERMinator
# appears to be a torch.nn.Module; confirm).
terminator.eval()
# Turns autograd off globally for every later cell, not just this one.
torch.set_grad_enabled(False)

  from pkg_resources import packaging  # type: ignore[attr-defined]
  from .autonotebook import tqdm as notebook_tqdm


loading params
loading params
{'cov_features': 'all_raw', 'cov_compress': 'project', 'term_use_mpnn': True, 'matches': 'transformer', 'energies_use_mpnn': True, 'energies_full_graph': True, 'contact_idx': True, 'energies_encoder_layers': 3, 'energies_hidden_dim': 128, 'resnet_linear': True, 'matches_linear': True, 'transformer_linear': True, 'term_mpnn_linear': True, 'use_terms': False, 'res_embed_linear': True, 'sscore_from_embedding': 'node', 'sscore_module': 'aggnode', 'model': 'multichain', 'term_hidden_dim': 32, 'flex_hidden_dim': 1, 'flex_type': '', 'gradient_checkpointing': True, 'num_pair_stats': 28, 'num_sing_stats': 0, 'resnet_blocks': 4, 'term_layers': 4, 'flex_layers': 4, 'term_heads': 4, 'conv_filter': 3, 'matches_layers': 4, 'matches_num_heads': 4, 'k_neighbors': 30, 'k_cutoff': None, 'cie_dropout': 0.1, 'cie_scaling': 500, 'cie_offset': 0, 'transformer_dropout': 0.1, 'energies_protein_features': 'full', 'energies_augment_eps': 0, 'energies_dropout': 0.1, 'energies_output

<torch.autograd.grad_mode.set_grad_enabled at 0x2aaab6368070>

# Pick one from complex features

In [2]:
with open('02_scoreSeeds_4/packaged_binder_complex_data_0.pkl', 'rb') as file:
    complex_data = pickle.load(file)
complex_data.keys()

dict_keys(['msas', 'features', 'ppoe', 'seq_lens', 'focuses', 'contact_idxs', 'src_key_mask', 'term_lens', 'X', 'x_mask', 'seqs', 'ids', 'chain_idx', 'gvp_data', 'sortcery_seqs', 'sortcery_nrgs', 'sscore', 'flex', 'flex_mask', 'chain_lens', 'res_info', 'binder_chain_id'])

In [3]:
complex_data['ids']

['seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9353-311-12_seed-4FXE-ARG81-relax-noHyd_E_B_E7_19558_264_255_connector-659-302-7_seed-4FXE-ARG81-relax-noHyd_E_B_E27_18950_417_380']

In [4]:
complex_data['seq_lens']

tensor([117])

In [5]:
etab, E_idx, sscore = terminator.forward(complex_data,119) # 2 more to match... 
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

graph_features.py,V[0] tensor(-166.5430)
graph_features.py,E[0] tensor(20691.4746)
graph_features.py,norm_nodes,V[0] tensor(-2.0591)
graph_features.py,norm_edges,E[0] tensor(2654.4138)
graph_features.py,E_idx tensor(196757)
s2s.py,features tensor(2654.4138)
s2s.py,h_E tensor(3585.4736)
s2s.py,h_E tensor(43841.7500)
s2s.py,h_E tensor(17.0627)
s2s.py,h_E 2 tensor(17.0627)
s2s.py,h_E 3 tensor(-7095.2451)
etab:  torch.Size([1, 117, 30, 400])
E_idx: torch.Size([1, 117, 30])
sscore:  torch.Size([1, 117])


In [6]:
etab[0].sum()

tensor(-7095.2451)

In [7]:
sscore.sum()

tensor(-41.4683)

In [8]:
sscore

tensor([[ 0.5538,  0.4566, -0.2898,  0.5131, -1.7901, -1.0002,  0.2060, -0.4437,
         -1.4581, -1.4870, -0.9295, -1.4690, -1.6663, -0.9656, -0.8693, -0.1027,
          0.0870,  0.3496, -1.0537, -1.2963, -1.3552, -0.9164, -2.0327, -2.3308,
         -1.6737, -1.8763, -1.7431, -1.6423, -1.6521, -1.1342, -1.3505, -1.3356,
         -0.1254,  0.2404,  2.3544,  0.9284,  2.0889,  1.7822,  0.3518,  0.2076,
          1.0170,  0.0795, -0.0313,  1.1173,  0.5626,  0.1428,  0.4099,  1.4403,
         -0.8369, -0.7473, -1.2389, -0.7886,  0.0877,  0.2057,  0.9700,  0.8374,
          1.6231,  1.2166,  0.3529, -0.8243, -1.3245, -1.6623, -1.2803, -1.2738,
         -0.9465,  0.0219,  0.8731,  0.4879,  0.5666,  1.8051, -0.7221, -0.9014,
         -1.6169, -1.4536, -1.1342, -1.1532, -1.1610,  1.5692,  0.7064,  1.0152,
          2.1404,  1.6110,  1.4763, -1.3345, -1.2082, -1.3633, -1.5493, -1.8237,
         -2.0297, -1.6370, -1.8856, -2.0903, -2.1478, -1.7402, -1.9713, -2.1510,
         -2.2693, -1.7532, -

# Pick one from seed features

In [5]:
with open('02_scoreSeeds_2/packaged_binder_complex_data_0.pkl', 'rb') as file:
    seed_data = pickle.load(file)
seed_data.keys()

dict_keys(['msas', 'features', 'ppoe', 'seq_lens', 'focuses', 'contact_idxs', 'src_key_mask', 'term_lens', 'X', 'x_mask', 'seqs', 'ids', 'chain_idx', 'gvp_data', 'sortcery_seqs', 'sortcery_nrgs', 'sscore', 'flex', 'flex_mask', 'chain_lens', 'res_info', 'binder_chain_id'])

In [10]:
seed_data['ids']

['4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9353-311-12_seed-4FXE-ARG81-relax-noHyd_E_B_E7_19558_264_255_connector-659-302-7_seed-4FXE-ARG81-relax-noHyd_E_B_E27_18950_417_380',
 '4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd_E_B_E41_12262_133_404_connector-13667-708-8_seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9218-393-12_seed-4FXE-ARG81-relax-noHyd_E_B_E5_18353_195_247',
 '4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd_E_B_E41_12262_133_404_connector-14590-287-8_seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9218-393-12_seed-4FXE-ARG81-relax-noHyd_E_B_E5_18353_195_247']

In [11]:
seed_data['seq_lens']

tensor([117, 119, 119])

In [12]:
seed_data.keys()

dict_keys(['msas', 'features', 'ppoe', 'seq_lens', 'focuses', 'contact_idxs', 'src_key_mask', 'term_lens', 'X', 'x_mask', 'seqs', 'ids', 'chain_idx', 'gvp_data', 'sortcery_seqs', 'sortcery_nrgs', 'sscore', 'flex', 'flex_mask', 'chain_lens', 'res_info', 'binder_chain_id'])

In [13]:
seed_data['X'].shape

torch.Size([3, 119, 4, 3])

In [14]:
etab, E_idx, sscore = terminator.forward(seed_data,119)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

graph_features.py,V[0] tensor(239.2742)
graph_features.py,E[0] tensor(20728.9219)
graph_features.py,norm_nodes,V[0] tensor(252.9310)
graph_features.py,norm_edges,E[0] tensor(2516.5020)
graph_features.py,E_idx tensor(201227)
s2s.py,features tensor(2516.5020)
s2s.py,h_E tensor(3080.1035)
s2s.py,h_E tensor(37908.4375)
s2s.py,h_E tensor(61.8003)
s2s.py,h_E 2 tensor(61.8003)
s2s.py,h_E 3 tensor(-14266.7539)
etab:  torch.Size([3, 119, 30, 400])
E_idx: torch.Size([3, 119, 30])
sscore:  torch.Size([3, 119])


In [15]:
etab[0].sum()

tensor(-14266.7539)

In [16]:
etab[0][:,:117].sum()

tensor(-14266.7539)

In [17]:
sscore[0].sum()

tensor(-27.7327)

In [18]:
sscore[0]

tensor([ 1.5394,  1.1915,  0.8377, -0.1450, -0.9447, -0.7021, -0.1820,  0.0126,
        -1.3416, -1.3933, -0.5332, -1.5383, -1.3804, -1.1283, -0.3320, -0.3744,
         0.0053, -0.2845, -0.7209, -1.2913, -1.4419, -1.4303, -1.7831, -1.8484,
        -2.1231, -1.6720, -1.9486, -1.8690, -2.0439, -1.8172, -1.9183, -1.8367,
        -1.1584,  0.1013,  0.5450,  0.1707,  0.6443,  0.6067, -0.6841,  0.7227,
         0.7128,  0.5847,  0.6119,  0.4468,  1.4508,  0.7115,  1.0253,  0.9272,
        -0.3332, -0.8500, -0.9390, -0.3843, -0.3582,  0.4714,  1.0243,  0.2834,
         1.3332,  0.8829,  0.2884, -0.9039, -1.2832, -1.1332, -1.4346, -0.8863,
        -0.3604,  0.3048,  0.3585,  0.2698,  0.4793,  0.6587, -0.8517, -0.9693,
        -1.9786, -0.9806, -0.9232, -0.5740,  0.1119,  0.3913,  0.9269,  0.9262,
         1.6362,  2.1416,  0.8424,  0.6694, -0.5234, -0.7421, -1.0333, -1.2932,
        -1.1789, -1.1357, -1.7448, -1.8785, -1.7027, -1.5194, -1.7861, -1.5157,
        -1.3950, -1.5974, -0.8301, -1.12

# Reshape the seed input so it is like the complex input

In [6]:
seed_data['X'][0].shape

torch.Size([119, 4, 3])

In [7]:
seed_data['X'].shape

torch.Size([3, 119, 4, 3])

In [8]:
seed_data['ids']

['4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9353-311-12_seed-4FXE-ARG81-relax-noHyd_E_B_E7_19558_264_255_connector-659-302-7_seed-4FXE-ARG81-relax-noHyd_E_B_E27_18950_417_380',
 '4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd_E_B_E41_12262_133_404_connector-13667-708-8_seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9218-393-12_seed-4FXE-ARG81-relax-noHyd_E_B_E5_18353_195_247',
 '4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd_E_B_E41_12262_133_404_connector-14590-287-8_seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9218-393-12_seed-4FXE-ARG81-relax-noHyd_E_B_E5_18353_195_247']

In [9]:
[seed_data['ids'][0]]

['4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9353-311-12_seed-4FXE-ARG81-relax-noHyd_E_B_E7_19558_264_255_connector-659-302-7_seed-4FXE-ARG81-relax-noHyd_E_B_E27_18950_417_380']

In [10]:
def repack_features(x, idx):
    """Extract sample ``idx`` from a batched feature dict as a batch of size one.

    Args:
        x: Mapping of feature name -> batched value. Values may be ``None``,
            a per-sample list, or an indexable batch (e.g. a tensor) whose
            leading dimension is the batch.
        idx: Position of the sample to keep.

    Returns:
        dict: Same keys as ``x``. ``None`` stays ``None``, an empty list stays
        empty, a list becomes the one-element list ``[value[idx]]``, and any
        other value is indexed with ``[[idx]]`` so the batch dimension is kept.
    """
    feat = dict()
    for k, value in x.items():
        # Identity / isinstance checks avoid `==` on tensors, which only works
        # through the comparison fallback and is ambiguous for arrays.
        if value is None:
            feat[k] = None
        elif isinstance(value, list):
            feat[k] = [value[idx]] if value else []
        else:
            # Indexing with a list keeps the leading (batch) dimension.
            feat[k] = value[[idx]]
    return feat

seed_data_0 = repack_features(seed_data,0)#,117)
# seed_data_0

In [11]:
seed_data['x_mask'][0,-2:]

tensor([0., 0.])

In [12]:
seed_data['X'][0,-2:]

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])

In [13]:
seed_data['X'].shape

torch.Size([3, 119, 4, 3])

In [14]:
seed_data['X'][[0]].shape

torch.Size([1, 119, 4, 3])

In [15]:
seed_data_0['X'].shape

torch.Size([1, 119, 4, 3])

In [25]:
# Work on a private copy so the padded seed_data_0 stays available for comparison.
seed_data_0_manual = copy.deepcopy(seed_data_0)
# Cut every listed entry down to its first 117 positions along dim 1
# (117 is the seq_lens value reported for this sample).
seed_data_0_manual.update({
    name: seed_data_0_manual[name][:, :117]
    for name in (
        'X', 'msas', 'ppoe', 'focuses', 'contact_idxs',
        'src_key_mask', 'term_lens', 'x_mask', 'seqs', 'chain_idx',
    )
})
seed_data_0_manual.keys()

dict_keys(['msas', 'features', 'ppoe', 'seq_lens', 'focuses', 'contact_idxs', 'src_key_mask', 'term_lens', 'X', 'x_mask', 'seqs', 'ids', 'chain_idx', 'gvp_data', 'sortcery_seqs', 'sortcery_nrgs', 'sscore', 'flex', 'flex_mask', 'chain_lens', 'res_info', 'binder_chain_id'])

In [17]:
seed_data_0_manual['seq_lens']

tensor([117])

In [18]:
len(seed_data_0_manual['X'][0])

117

In [19]:
len(seed_data_0_manual['res_info'][0])

117

In [20]:
seed_data['X'][0].sum()

tensor(-11417.0859)

In [21]:
seed_data_0['X'][0].sum()

tensor(-11417.0859)

In [22]:
etab, E_idx, sscore = terminator.forward(seed_data,117)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

dihedrals,dX tensor(-6.3530)
dihedrals,U tensor(-23.3408)
dihedrals,D tensor(-0.8880)
graph_features.py,V[0] tensor(239.2742)
graph_features.py,V[0] tensor(231.0600)
graph_features.py,E[0] tensor(20728.9219)
graph_features.py,E[0] tensor(20213.9570)
graph_features.py,norm_nodes,V[0] tensor(252.9310)
graph_features.py,norm_nodes,V[0] tensor(245.9031)
graph_features.py,norm_edges,E[0] tensor(2516.5020)
graph_features.py,norm_edges,E[0] tensor(2336.1362)
graph_features.py,E_idx tensor(201227)
s2s.py,features tensor(2516.5020)
s2s.py,h_E tensor(3080.1035)
s2s.py,h_E tensor(37908.4375)
s2s.py,h_E tensor(61.8003)
s2s.py,h_E 2 tensor(61.8003)
s2s.py,h_E 3 tensor(-14266.7539)
etab:  torch.Size([3, 119, 30, 400])
E_idx: torch.Size([3, 119, 30])
sscore:  torch.Size([3, 119])


In [23]:
etab, E_idx, sscore = terminator.forward(seed_data_0,117)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

dihedrals,dX tensor(-6.3530)
dihedrals,U tensor(-23.3408)
dihedrals,D tensor(-55.0380)
graph_features.py,V[0] tensor(-162.2719)
graph_features.py,V[0] tensor(-167.8247)
graph_features.py,E[0] tensor(20899.9551)
graph_features.py,E[0] tensor(20384.9883)
graph_features.py,norm_nodes,V[0] tensor(4.7151)
graph_features.py,norm_nodes,V[0] tensor(-1.5670)
graph_features.py,norm_edges,E[0] tensor(2620.5737)
graph_features.py,norm_edges,E[0] tensor(2440.2107)
graph_features.py,E_idx tensor(201227)
s2s.py,features tensor(2620.5737)
s2s.py,h_E tensor(3499.0522)
s2s.py,h_E tensor(43833.2969)
s2s.py,h_E tensor(2.8068)
s2s.py,h_E 2 tensor(2.8068)
s2s.py,h_E 3 tensor(-7018.6328)
etab:  torch.Size([1, 119, 30, 400])
E_idx: torch.Size([1, 119, 30])
sscore:  torch.Size([1, 119])


In [24]:
etab, E_idx, sscore = terminator.forward(seed_data_0_manual,117)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

dihedrals,dX tensor(-27.3240)
dihedrals,U tensor(-24.0776)
dihedrals,D tensor(-53.7111)
graph_features.py,V[0] tensor(-166.5430)
graph_features.py,V[0] tensor(-166.2110)
graph_features.py,E[0] tensor(20691.4746)
graph_features.py,E[0] tensor(20224.8848)
graph_features.py,norm_nodes,V[0] tensor(-2.0591)
graph_features.py,norm_nodes,V[0] tensor(-6.3345)
graph_features.py,norm_edges,E[0] tensor(2654.4138)
graph_features.py,norm_edges,E[0] tensor(2483.7185)
graph_features.py,E_idx tensor(196757)
s2s.py,features tensor(2654.4138)
s2s.py,h_E tensor(3585.4736)
s2s.py,h_E tensor(43841.7500)
s2s.py,h_E tensor(17.0627)
s2s.py,h_E 2 tensor(17.0627)
s2s.py,h_E 3 tensor(-7095.2451)
etab:  torch.Size([1, 117, 30, 400])
E_idx: torch.Size([1, 117, 30])
sscore:  torch.Size([1, 117])


In [18]:
E_idx[:,:-2].sum()

tensor(196757)

In [19]:
etab.sum()

tensor(-7018.6328)

In [20]:
sscore.sum()

tensor(-43.9246)

In [21]:
sscore[:,:-2].sum()

tensor(-43.8393)

In [22]:
sscore.shape

torch.Size([1, 119])

In [23]:
etab, E_idx, sscore = terminator.forward(seed_data_0_manual,117)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

graph_features.py,V[0] tensor(-166.5430)
graph_features.py,V[0] tensor(-166.2110)
graph_features.py,E[0] tensor(20691.4746)
graph_features.py,E[0] tensor(20224.8848)
graph_features.py,norm_nodes,V[0] tensor(-2.0591)
graph_features.py,norm_nodes,V[0] tensor(-6.3345)
graph_features.py,norm_edges,E[0] tensor(2654.4138)
graph_features.py,norm_edges,E[0] tensor(2483.7185)
graph_features.py,E_idx tensor(196757)
s2s.py,features tensor(2654.4138)
s2s.py,h_E tensor(3585.4736)
s2s.py,h_E tensor(43841.7500)
s2s.py,h_E tensor(17.0627)
s2s.py,h_E 2 tensor(17.0627)
s2s.py,h_E 3 tensor(-7095.2451)
etab:  torch.Size([1, 117, 30, 400])
E_idx: torch.Size([1, 117, 30])
sscore:  torch.Size([1, 117])


In [24]:
etab[0].sum()

tensor(-7095.2451)

In [35]:
sscore.sum()

tensor(-41.4683)

In [36]:
sscore.shape

torch.Size([1, 117])

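Okay, so the difference comes from batching: the same structure scores differently inside the padded 3-sample batch than when it is run alone at its true length of 117.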

# Reshape to 2 equal length complexes

In [48]:
def repack_features(x):
    """Drop the first sample of a batched feature dict, keeping the rest.

    Args:
        x: Mapping of feature name -> batched value. Values may be ``None``,
            a per-sample list, or a sliceable batch (e.g. a tensor) whose
            leading dimension is the batch.

    Returns:
        dict: Same keys as ``x``. ``None`` stays ``None``; every other value is
        sliced with ``[1:]``. Lists are sliced the same way as tensors so that
        per-sample lists keep one entry per remaining sample rather than being
        wrapped into a single nested element.
    """
    feat = dict()
    for k, value in x.items():
        # `is None` avoids an `==` comparison against tensor values.
        if value is None:
            feat[k] = None
        else:
            # Covers lists (including empty ones) and tensors alike.
            feat[k] = value[1:]
    return feat

seed_data_12 = repack_features(seed_data)#,117)
# seed_data_0

In [49]:
etab, E_idx, sscore = terminator.forward(seed_data_12,119)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

graph_features.py,V[0] tensor(-167.1116)
graph_features.py,V[0] tensor(-166.9531)
graph_features.py,E[0] tensor(21104.3047)
graph_features.py,E[0] tensor(20639.7148)
graph_features.py,norm_nodes,V[0] tensor(-5.9293)
graph_features.py,norm_nodes,V[0] tensor(-10.3853)
graph_features.py,norm_edges,E[0] tensor(2971.6599)
graph_features.py,norm_edges,E[0] tensor(2752.7664)
graph_features.py,E_idx tensor(203097)
s2s.py,features tensor(2971.6599)
s2s.py,h_E tensor(3685.5835)
s2s.py,h_E tensor(44640.2422)
s2s.py,h_E tensor(-13.0487)
s2s.py,h_E 2 tensor(-13.0487)
s2s.py,h_E 3 tensor(-7049.0918)
etab:  torch.Size([2, 119, 30, 400])
E_idx: torch.Size([2, 119, 30])
sscore:  torch.Size([2, 119])


In [50]:
sscore[0].sum()

tensor(-30.0201)

In [51]:
def repack_features(x, idx=1):
    """Extract one sample from a batched feature dict as a batch of size one.

    Args:
        x: Mapping of feature name -> batched value. Values may be ``None``,
            a per-sample list, or an indexable batch (e.g. a tensor or array)
            whose leading dimension is the batch.
        idx: Position of the sample to keep. Defaults to 1, the index this
            cell originally hard-coded.

    Returns:
        dict: Same keys as ``x``. ``None`` stays ``None``, an empty list stays
        empty, a list becomes ``[value[idx]]``, and any other value is indexed
        with ``[[idx]]`` so the batch dimension is kept.
    """
    feat = dict()
    for k, value in x.items():
        # Identity / isinstance checks instead of `== None` / `== []`: those
        # comparisons are element-wise (and then ambiguous) on array values.
        if value is None:
            feat[k] = None
        elif isinstance(value, list):
            feat[k] = [value[idx]] if value else []
        else:
            # Indexing with a list keeps the leading (batch) dimension.
            feat[k] = value[[idx]]
    return feat

seed_data_1 = repack_features(seed_data)#,117)

In [53]:
etab, E_idx, sscore = terminator.forward(seed_data_1,119)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

graph_features.py,V[0] tensor(-167.1116)
graph_features.py,V[0] tensor(-166.9531)
graph_features.py,E[0] tensor(21104.3047)
graph_features.py,E[0] tensor(20639.7148)
graph_features.py,norm_nodes,V[0] tensor(-5.9293)
graph_features.py,norm_nodes,V[0] tensor(-10.3853)
graph_features.py,norm_edges,E[0] tensor(2971.6599)
graph_features.py,norm_edges,E[0] tensor(2752.7664)
graph_features.py,E_idx tensor(203097)
s2s.py,features tensor(2971.6599)
s2s.py,h_E tensor(3685.5835)
s2s.py,h_E tensor(44640.2422)
s2s.py,h_E tensor(-13.0487)
s2s.py,h_E 2 tensor(-13.0487)
s2s.py,h_E 3 tensor(-7049.0918)
etab:  torch.Size([1, 119, 30, 400])
E_idx: torch.Size([1, 119, 30])
sscore:  torch.Size([1, 119])


In [54]:
sscore[0].sum()

tensor(-30.0201)

# Are the inputs really the same?

In [148]:
def compare_out(a, b):
    """Print, for every key of ``a``, whether ``a[k]`` and ``b[k]`` are equal.

    Tensors are compared with ``torch.equal``; all other values with ``==``.
    If a comparison raises (missing key in ``b``, mismatched types, ambiguous
    truth value, ...), the error is reported and a guarded fallback comparison
    against ``b[k]`` is attempted.

    Args:
        a: Mapping whose keys drive the comparison.
        b: Mapping to compare against.

    Returns:
        None. Results are only printed.
    """
    for k in a:
        try:
            if isinstance(a[k], torch.Tensor):
                if not torch.equal(a[k], b[k]):
                    print(k, 'is not equal')
                else:
                    print(k, 'is equal')
            else:
                if not a[k] == b[k]:
                    print(k, 'is not equal!!!')
                else:
                    print(k, 'is equal')
        except Exception as E:
            print('woops', k, E)
            # Fall back to a plain `==` against the matching entry b[k] (not
            # the whole dict b). The fallback is guarded because it can fail
            # for the same reason the primary comparison did.
            try:
                same = k in b and bool(a[k] == b[k])
            except Exception:
                same = False
            if same:
                print('but they are the same')
compare_out(complex_data,seed_data_0_manual)

msas is equal
features is equal
ppoe is equal
seq_lens is equal
focuses is equal
contact_idxs is equal
src_key_mask is equal
term_lens is equal
X is equal
x_mask is equal
seqs is equal
ids is not equal!!!
chain_idx is equal
gvp_data is not equal!!!
sortcery_seqs is equal
sortcery_nrgs is equal
sscore is equal
flex is equal
flex_mask is equal
chain_lens is equal
res_info is equal
binder_chain_id is equal


In [149]:
complex_data['src_key_mask']

tensor([[False]])

In [175]:
seed_data_0_manual['src_key_mask']

tensor([[False]])

In [176]:
complex_data['x_mask']

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [177]:
seed_data_0['x_mask']

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.]])

In [180]:
torch.equal(complex_data['X'],seed_data_0_manual['X'])

True

In [146]:
complex_data

{'msas': tensor([[[0]]]),
 'features': tensor([[[[0.]]]]),
 'ppoe': tensor([[[[0.]]]]),
 'seq_lens': tensor([117]),
 'focuses': tensor([[[0.]]], dtype=torch.float64),
 'contact_idxs': tensor([[[0.]]], dtype=torch.float64),
 'src_key_mask': tensor([[False]]),
 'term_lens': tensor([[1, 0]]),
 'X': tensor([[[[ -8.7290,  -3.6990,  18.7810],
           [ -9.3140,  -4.6870,  17.8840],
           [-10.5750,  -5.3110,  18.4730],
           [-10.7480,  -5.3630,  19.6920]],
 
          [[-11.4510,  -5.7860,  17.5930],
           [-12.6890,  -6.4290,  18.0000],
           [-12.3770,  -7.8010,  18.5760],
           [-11.3670,  -8.4160,  18.2290]],
 
          [[-13.2370,  -8.2860,  19.4620],
           [-13.1010,  -9.6460,  19.9590],
           [-13.3660, -10.6370,  18.8310],
           [-14.1670, -10.3710,  17.9380]],
 
          ...,
 
          [[ -6.9930, -19.4860,  13.3170],
           [ -6.4550, -20.7780,  12.9310],
           [ -7.0820, -21.9090,  13.8140],
           [ -7.4180, -22.9660,  

# Try to fix with batching

In [138]:
with open('02_scoreSeeds_2/packaged_binder_complex_data_0.pkl', 'rb') as file:
    seed_data = pickle.load(file)
seed_data.keys()

dict_keys(['msas', 'features', 'ppoe', 'seq_lens', 'focuses', 'contact_idxs', 'src_key_mask', 'term_lens', 'X', 'x_mask', 'seqs', 'ids', 'chain_idx', 'gvp_data', 'sortcery_seqs', 'sortcery_nrgs', 'sscore', 'flex', 'flex_mask', 'chain_lens', 'res_info', 'binder_chain_id'])

In [139]:
seed_data['ids']

['4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9353-311-12_seed-4FXE-ARG81-relax-noHyd_E_B_E7_19558_264_255_connector-659-302-7_seed-4FXE-ARG81-relax-noHyd_E_B_E27_18950_417_380',
 '4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd_E_B_E41_12262_133_404_connector-13667-708-8_seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9218-393-12_seed-4FXE-ARG81-relax-noHyd_E_B_E5_18353_195_247',
 '4FXE-ARG81-relax-noHyd_E__seed-4FXE-ARG81-relax-noHyd_E_B_E41_12262_133_404_connector-14590-287-8_seed-4FXE-ARG81-relax-noHyd-48-66__0_connector-9218-393-12_seed-4FXE-ARG81-relax-noHyd_E_B_E5_18353_195_247']

In [140]:
seed_data['seq_lens']

tensor([117, 119, 119])

In [141]:
# suggested fix 1
seed_data['scatter_idx'] = torch.arange(len(seed_data['seq_lens']))
etab, E_idx, sscore = terminator.forward(seed_data,119)
print('etab: ',etab.shape)
print('E_idx:',E_idx.shape)
print('sscore: ',sscore.shape)

etab:  torch.Size([3, 119, 30, 400])
E_idx: torch.Size([3, 119, 30])
sscore:  torch.Size([3, 119])


In [142]:
etab[0].sum()

tensor(-14266.7539)

In [143]:
sscore[0].sum()

tensor(-27.7327)

In [144]:
sscore[0]

tensor([ 1.5394,  1.1915,  0.8377, -0.1450, -0.9447, -0.7021, -0.1820,  0.0126,
        -1.3416, -1.3933, -0.5332, -1.5383, -1.3804, -1.1283, -0.3320, -0.3744,
         0.0053, -0.2845, -0.7209, -1.2913, -1.4419, -1.4303, -1.7831, -1.8484,
        -2.1231, -1.6720, -1.9486, -1.8690, -2.0439, -1.8172, -1.9183, -1.8367,
        -1.1584,  0.1013,  0.5450,  0.1707,  0.6443,  0.6067, -0.6841,  0.7227,
         0.7128,  0.5847,  0.6119,  0.4468,  1.4508,  0.7115,  1.0253,  0.9272,
        -0.3332, -0.8500, -0.9390, -0.3843, -0.3582,  0.4714,  1.0243,  0.2834,
         1.3332,  0.8829,  0.2884, -0.9039, -1.2832, -1.1332, -1.4346, -0.8863,
        -0.3604,  0.3048,  0.3585,  0.2698,  0.4793,  0.6587, -0.8517, -0.9693,
        -1.9786, -0.9806, -0.9232, -0.5740,  0.1119,  0.3913,  0.9269,  0.9262,
         1.6362,  2.1416,  0.8424,  0.6694, -0.5234, -0.7421, -1.0333, -1.2932,
        -1.1789, -1.1357, -1.7448, -1.8785, -1.7027, -1.5194, -1.7861, -1.5157,
        -1.3950, -1.5974, -0.8301, -1.12