In [1]:
import os
import sys
import importlib
import numpy as np
import torch as pt
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob
from scipy import signal
from matplotlib import rcParams

import src as sp
import runtime as rt
from theme import colors

# global matplotlib font setup: 14 pt sans-serif text, with Arial as the sans-serif face
rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
    'font.size': 14,
})

In [2]:
# parameters
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs (more slowly) on a machine without CUDA.
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

# model parameters
# Exactly one `save_path` must be active; the commented entries are the other
# saved runs this notebook can be pointed at.
# r6
#save_path = "model/save/s_v6_4_2022-09-16_11-51"  # virtual Cb & partial
#save_path = "model/save/s_v6_5_2022-09-16_11-52"  # virtual Cb, partial & noise

# r7
save_path = "model/save/s_v7_0_2023-04-25"  # partial chain
#save_path = "model/save/s_v7_1_2023-04-25"  # partial chain and noise
#save_path = "model/save/s_v7_2_2023-04-25"  # partial chain high coverage
#save_path = "model/save/s_v7_3_2023-04-25"  # partial chain and noise and high coverage

# create models
# NOTE(review): presumably loads the weights file "model.pt" found under
# `save_path` onto `device` -- confirm against runtime.SequenceModel.
model = rt.SequenceModel(save_path, "model.pt", device=device)

In [3]:
# input files: the structures database and the list of subunit ids to keep
dataset_filepath = "datasets/pdb_structures_16384_v4.h5"
sids_selection_filepath = "datasets/subunits_train_set.txt"

# read the listed ids and keep only the part before the first underscore
sids_listed = np.genfromtxt(sids_selection_filepath, dtype=np.dtype('U'))
sids_sel = np.array([sid.split('_')[0] for sid in sids_listed])

# open the dataset
dataset = rt.Dataset(dataset_filepath)

# selection limits come from the data configuration stored with the model
data_config = model.module.config_data

# start from the sid selection, then tighten it in place with every other criterion
selection_mask = sp.select_by_sid(dataset, sids_sel)
extra_criteria = [
    sp.select_by_max_ba(dataset, data_config['max_ba']),    # max assembly count
    (dataset.sizes[:,0] <= data_config['max_size']),        # max size
    (dataset.sizes[:,1] >= data_config['min_num_res']),     # min size
]
for criterion in extra_criteria:
    selection_mask &= criterion

# restrict the dataset to entries passing every criterion
dataset.m &= selection_mask

# number of structures left after selection
len(dataset)

87137

In [4]:
# parameters
N = 1024*4  # requested number of structures to sample
seed = 0    # fixed seed so every run draws the same structures

# Sampling is without replacement, so never request more structures than the
# current selection holds (np.random.choice would raise in that case).
num_samples = min(N, len(dataset))
rng = np.random.default_rng(seed)
sample_ids = rng.choice(len(dataset), size=num_samples, replace=False)

# sample predictions: collect the per-structure model outputs `p` and `y`
# (used downstream as predictions and one-hot style labels respectively)
p_l, y_l = [], []
for i in tqdm(sample_ids):
    # load structure
    _, structure = dataset[i]
    # the model input needs a 'chain_name' entry: the chain ids as strings
    structure['chain_name'] = np.array([str(cid) for cid in structure['cid']])

    # apply model
    _, p, y = model(structure)

    # store results
    p_l.append(p)
    y_l.append(y)

 10%|█         | 418/4096 [01:35<14:03,  4.36it/s]


KeyboardInterrupt: 

In [None]:
# parameters
num_bins = 200  # number of equal-width confidence bins on [0, 1]

# get predictions: stack the per-structure outputs into two 2-D arrays
P = pt.cat(p_l).numpy()
Y = pt.cat(y_l).numpy()

# filter out non-amino acids (rows whose label vector sums to zero)
is_labelled = (np.sum(Y, axis=1) > 0.0)
P = P[is_labelled]
Y = Y[is_labelled]

# index of the true class of every remaining row
ids_y_max = np.argmax(Y, axis=1)

# For each class i and each confidence bin, compute the fraction of rows whose
# predicted value for class i falls in that bin AND whose true class is i:
#   C[i, b] = count(true class == i, P[:, i] in bin b) / count(P[:, i] in bin b)
C = []
for i in range(P.shape[1]):
    is_class = (ids_y_max == i)

    h_all, bin_edges = np.histogram(P[:, i], bins=num_bins, range=(0.0, 1.0))
    h_correct, _ = np.histogram(P[is_class, i], bins=num_bins, range=(0.0, 1.0))

    # Bins that received no prediction give 0/0 = NaN. The NaN is kept on
    # purpose (it marks "no data" in the saved table); only the warning that
    # numpy would print for it is silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        C.append(h_correct / h_all)

# bin centers; the edges are identical for every class, so compute them once
x = 0.5*(bin_edges[1:] + bin_edges[:-1])

# pack results
C = np.array(C)

# save the table: first row = bin centers, then one row per class.
# The file keeps its historical "_cdf" suffix because it is loaded by that
# name later on, although the content is a per-bin fraction and not a CDF.
os.makedirs("results", exist_ok=True)
np.savetxt("results/{}_cdf.csv".format(os.path.basename(save_path)), np.vstack([x, C]), delimiter=",")


In [None]:
# load the confidence mapping from the table saved for the active model
model_tag = os.path.basename(save_path)
conf = rt.ConfidenceMap("results/{}_cdf.csv".format(model_tag))

# curves to draw (one row of conf.C each) on an evenly spaced [0, 1] axis
# NOTE(review): conf.C is presumably the filtered/remapped table -- confirm in runtime.ConfidenceMap
C = conf.C
x = np.linspace(0.0, 1.0, C.shape[1])

# one tab20 color per curve
num_curves = C.shape[0]
line_colors = plt.cm.tab20(np.linspace(0.0, 1.0, num_curves))

# draw every curve, labelled with the matching residue name
fig, ax = plt.subplots(figsize=(4.5,4))
for k in range(num_curves):
    ax.plot(x, C[k], '-', label=model.module.std_resnames[k], color=line_colors[k])
ax.legend(loc='upper left', ncol=2, prop={'size': 10})
ax.set_ylim(0.0, 1.0)
ax.set_xlim(0.0, 1.0)
ax.set_xlabel('prediction confidence')
ax.set_ylabel('correct prediction probability')
fig.tight_layout()

# export as vector graphics, then display
fig.savefig("graphs/confidence_mapping_{}.svg".format(model_tag))
plt.show()