In [3]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [142]:
import json, time, os, sys, glob
import shutil
import warnings
import numpy as np
import pandas as pd
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
import subprocess
from tqdm import tqdm
from omegaconf import OmegaConf
import pytorch_lightning as pl
from terrace.batch import DataLoader
from typing import Sequence

from fireprot_dataset import *
from training.model_utils import featurize, loss_smoothed, loss_nll, get_std_opt, ProteinMPNN
from training.utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
from kaggle_dataset import KaggleTrainDataset
from transfer_model import *
from train import *
from combo_dataset import *

In [6]:
# Load the experiment configuration. The relative path is resolved against the
# kernel's working directory. Later cells read cfg.platform.pdb_dir and pass
# cfg to ComboDataset and TransferModel.
cfg = OmegaConf.load("config.yaml")

In [145]:
# Build the combined dataset for the "val" split.
# NOTE(review): ComboDataset arrives through a star import (presumably
# combo_dataset) — confirm which module provides it.
combo = ComboDataset(cfg, "val")

In [147]:
# Inspect the second element of item 10. The output below shows a dict with
# 'seq_chain_A' and 'coords_chain_A' entries.
# NOTE(review): this displays full per-residue coordinate lists; consider
# showing only the keys or a short slice to keep the notebook readable.
combo[10][1]

{'seq_chain_A': 'SMVKIYAPASIGNVSVGFDVLGAAVSPIDGTLLGDCVSVTAAERFSLHNEGRFVSKLPDDPKQNIVYQCWERFCQEMGKEIPVAMVLEKNMPIGSGLGSSACSVVAGLMAMNEFCGQPLDKVTLLGMMGELEGRVSGSIHFDNVAPCYLGGMQLILEQEGYISQDVPGFSDWLWVMAYPGIKVSTAEARAILPAQYRRQDCITHGRNLAGFIHACHTQQPDLAAKMMKDVIAEPYRTQLLPGFAAARQAAQDIGALACGISGSGPTLFAVCNDQATAQRMAGWLQNHYLQNDEGFVHICRLDTAGARLLG',
 'coords_chain_A': {'N_chain_A': [[34.39699935913086,
    75.13400268554688,
    60.108001708984375],
   [35.4900016784668, 74.37300109863281, 63.236000061035156],
   [35.43899917602539, 71.01499938964844, 64.4229965209961],
   [34.604000091552734, 68.93199920654297, 67.0270004272461],
   [34.96500015258789, 65.92900085449219, 68.87899780273438],
   [33.595001220703125, 63.52000045776367, 71.1760025024414],
   [34.564998626708984, 60.53499984741211, 72.3479995727539],
   [33.520999908447266, 57.779998779296875, 74.01499938964844],
   [34.49800109863281, 55.60100173950195, 76.80400085449219],
   [34.57699966430664, 51.98099899291992, 76.73500061035156],
   [33.37

In [33]:
@cache(lambda cfg, p: "")
def train_clusters_cached(cfg, params):
    """Return ``build_training_clusters(params, False)`` wrapped by ``cache``.

    Args:
        cfg: Not used inside the body; it is only forwarded to the decorator.
            Presumably ``cache`` uses it to locate its storage — confirm
            against the decorator's definition.
        params: Forwarded unchanged to ``build_training_clusters``.

    Returns:
        Whatever ``build_training_clusters`` returns; the caller in this
        notebook unpacks it into three values (train, valid, test).

    Note:
        The lambda given to ``cache`` ignores both arguments and always
        returns ``""``. If it is a cache-key function (TODO confirm), a cached
        result would be reused even when ``params`` changes.
    """
    clusters = build_training_clusters(params, False)
    return clusters

# Root directory of the PDB training data, taken from the loaded config.
data_path = cfg.platform.pdb_dir

# Settings passed to the cluster builder and to PDB_dataset.
params = {
    "LIST": f"{data_path}/list.csv",
    "VAL": f"{data_path}/valid_clusters.txt",
    "TEST": f"{data_path}/test_clusters.txt",
    "DIR": f"{data_path}",
    "DATCUT": "2030-Jan-01",
    # resolution cutoff for PDBs (3.5)
    "RESCUT": 3.5,
    # min seq.id. to detect homo chains
    "HOMO": 0.70,
}

# The cluster builder yields three mappings; `test` is bound but not used by
# any cell visible in this notebook.
train, valid, test = train_clusters_cached(cfg, params)

# Wrap the train and validation clusters in PDB_dataset, in that order.
train_set, valid_set = (
    PDB_dataset(list(clusters.keys()), loader_pdb, clusters, params)
    for clusters in (train, valid)
)

In [106]:
# Use the fully qualified torch loader on purpose: the import cell binds the
# bare name `DataLoader` twice (torch.utils.data, then terrace.batch), so the
# bare name refers to the terrace version.
train_loader = torch.utils.data.DataLoader(train_set)
# Collect PDB entries from the loader; num_units=500 is forwarded to get_pdbs
# (presumably a cap on how many entries to gather — confirm in training.utils).
train_pdbs = get_pdbs(train_loader, num_units=500)

  3%|▎         | 680/23349 [03:06<1:43:22,  3.65it/s] 


In [108]:
# Length limit handed to StructureDataset as max_length.
max_protein_length = 10000
# Wrap the collected PDB entries; truncate=None is passed through as-is.
# NOTE(review): StructureDataset is imported from both training.utils and
# protein_mpnn_utils; the later import wins — confirm that is the intended one.
dataset_train = StructureDataset(train_pdbs, truncate=None, max_length=max_protein_length)

In [127]:
# Featurize a one-element batch (the first training entry) on the CPU.
# S and chain_M are reused by the loss cell; X, mask, residue_idx and
# chain_encoding_all are reused by the direct model call.
X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize([dataset_train[0]], 'cpu')

In [136]:
# Instantiate the transfer model from the config and run it on the first
# training entry. Only the second returned value (log_probs) is kept.
# NOTE(review): the meaning of the `[]` and `False` arguments is not visible
# here — see TransferModel's forward signature.
t_model = TransferModel(cfg)
_, log_probs = t_model(dataset_train[0], [], False)

In [141]:
# Smoothed loss of the transfer model's log-probabilities against S, using
# chain_M from the featurize cell; only the second returned value is kept.
# NOTE(review): loss_smoothed is imported from both training.model_utils and
# protein_mpnn_utils; the later import is the one bound here.
_, loss_av_smoothed = loss_smoothed(S, log_probs, chain_M)
# Bare expression so the notebook displays the value.
loss_av_smoothed

tensor(3.0908, grad_fn=<DivBackward0>)

In [130]:
# Call the ProteinMPNN model directly on the featurized tensors.
# NOTE(review): `model` is only assigned in a cell further down the file
# (model = get_protein_mpnn()), so a top-to-bottom run on a fresh kernel raises
# NameError here; that cell needs to run first.
model(X, S, mask, chain_M, residue_idx, chain_encoding_all, None)

([tensor([[[ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000],
           [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000],
           [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000],
           ...,
           [ 0.0541,  0.0136,  0.0954,  ..., -0.0786,  0.2617, -0.0942],
           [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000],
           [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000]]]),
  tensor([[[ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
           [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
           [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
           ...,
           [ 0.1444, -0.1354,  0.0470,  ..., -0.1241,  0.1332, -0.2393],
           [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
           [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000]]]),
  tensor([[[ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0000, -0.0000],
           [ 

In [117]:
# Create the model used by the direct forward-call cell that appears earlier in
# the file. NOTE(review): get_protein_mpnn comes from a star import — confirm
# its source module.
model = get_protein_mpnn()

In [44]:
# NOTE(review): this rebinds train_pdbs and passes the dataset itself with no
# num_units, whereas the earlier cell passes a torch DataLoader with
# num_units=500. Running top to bottom replaces that earlier result; confirm
# whether this cell is a leftover experiment.
train_pdbs = get_pdbs(train_set)

  0%|          | 11/23349 [00:01<40:54,  9.51it/s]
