In [1]:
import os
import sys
import importlib
import numpy as np
import torch as pt
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob
from matplotlib import rcParams

import src as sp
import runtime as rt
import mdtraj_utils as mdu
from theme import colors

# font parameters: every figure in this notebook uses 14 pt Arial (sans-serif family)
rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
    'font.size': 14,
})

In [2]:
# parameters
# run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slower) on machines without CUDA
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

# model parameters: uncomment exactly one save_path to select the checkpoint
# r6
save_path = "model/save/s_v6_4_2022-09-16_11-51"  # virtual Cb & partial
#save_path = "model/save/s_v6_5_2022-09-16_11-52"  # virtual Cb, partial & noise

# r7
#save_path = "model/save/s_v7_0_2023-04-25"  # partial chain
#save_path = "model/save/s_v7_1_2023-04-25"  # partial secondary structure
#save_path = "model/save/s_v7_2_2023-04-25"  # partial chain high coverage

# create model from the weights file "model.pt" inside save_path, on the selected device
model = rt.SequenceModel(save_path, "model.pt", device=device)

# create confidence mapping
#conf = rt.ConfidenceMap("results/{}_cdf.csv".format(os.path.basename(save_path)))

In [3]:
# parameters
# PDB identifiers of the systems whose MD trajectories are processed below
pdbids = [
    "1JTG", "1CLV", "1Z0K", "1AK4", "1R6Q",
    "1D6R", "2I25", "3F1P", "1R0R", "1E96",
    "1GPW", "1RKE", "1FLE", "2O3B", "3SGQ",
    "1ZHH", "1CGI", "2UUY", "2HQS", "2OOB",
]
# MD run identifiers processed for every PDB entry
# NOTE(review): presumably unbound/bound receptor/ligand and complex — TODO confirm
mdids = ["uR", "uL", "bR", "bL", "C"]

# data connector rooted at the "database" directory; used to load/unload trajectories
dc = mdu.data.DataConnector("database")

In [4]:
# parameters
n_skip = 100  # keep every n_skip-th frame of each trajectory

# make sure the output directory exists; pt.save does not create it
os.makedirs("results", exist_ok=True)

# process trajectories
# NOTE(review): `results` is not populated in this cell; kept for later cells
results = {}
for pdbid in pdbids:
    # for each md
    for mdid in mdids:
        # skip (pdbid, mdid) pairs whose results were already computed
        # NOTE(review): ':' in file names is invalid on Windows
        pt_filepath = "results/{}:{}.pt".format(pdbid, mdid)
        if os.path.exists(pt_filepath):
            continue

        # debug print
        print(pdbid, mdid)

        # load trajectory
        dc.load_trajectory(pdbid, mdid)
        try:
            # convert to structure
            structure = rt.traj_to_struct(dc[pdbid][mdid]['traj'])

            # skip frames: keep every n_skip-th entry along axis 1 of 'xyz'
            structure['xyz'] = structure['xyz'][:, ::n_skip]

            # apply model on full structure
            _, P, y = model(structure)

            # save y (as a new leading row) stacked with P; write to a temporary
            # file first and rename, so an interrupted save never leaves a
            # truncated file that the existence check above would treat as done
            tmp_filepath = pt_filepath + ".tmp"
            pt.save(pt.cat([y.unsqueeze(0), P], dim=0), tmp_filepath)
            os.replace(tmp_filepath, pt_filepath)
        finally:
            # unload data even if conversion or inference failed
            dc.unload_md(pdbid, mdid)

    # unload data
    dc.unload_pdb(pdbid)