In [1]:
import os
import re
import warnings
from glob import glob

import lightning as Lit
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from parse import parse
from torch.utils.data import DataLoader

from utils.models import dmsEVE#lightningEVE,
from utils.trainer import lightningdmsEVE#lightningEVE,
from utils_tbox.utils_tbox import read_pklz

# Suppress the specific warning about NVML initialization
warnings.filterwarnings("ignore", message="Can't initialize NVML")


def read_fun(fname):
    """Best-effort wrapper around ``read_pklz``.

    Parameters
    ----------
    fname : str
        Path handed to ``read_pklz`` unchanged.

    Returns
    -------
    object or None
        Whatever ``read_pklz`` returns, or None when reading raised an
        ordinary exception (a warning describing the failure is emitted).
    """
    try:
        return read_pklz(fname)
    except Exception as err:
        # Only ordinary errors are swallowed: KeyboardInterrupt / SystemExit
        # must still be able to stop a long-running notebook cell.
        warnings.warn("Could not read {}: {!r}".format(fname, err))
        return None

# Find the last checkpoint
def get_last_checkpoint(retrain,log_dir,exp_name):
    """Locate the checkpoint of the most recent run of an experiment.

    Searches ``<log_dir>/<exp_name>`` for directories named ``version_<N>``,
    selects the one with the largest ``N`` and picks a ``*.ckpt`` file from its
    ``checkpoints`` sub-folder: ``last.ckpt`` when it exists, otherwise the most
    recently modified checkpoint.

    Parameters
    ----------
    retrain : bool
        When true, no lookup is done and None is returned.
    log_dir : str
        Root folder of the logs.
    exp_name : str
        Path of the experiment relative to ``log_dir``.

    Returns
    -------
    str or None
        Path of the selected checkpoint, or None when ``retrain`` is true, no
        ``version_<N>`` directory exists, or the latest one has no checkpoint.
    """
    if retrain:
        return None

    # Keep only entries named exactly "version_<int>": the "version*" glob also
    # matches things like "version_0_backup", which cannot be ranked by number.
    numbered_runs = []
    for run_dir in glob(os.path.join(log_dir, exp_name, "version*")):
        match = re.fullmatch(r"version_(\d+)", os.path.basename(run_dir))
        if match is not None:
            numbered_runs.append((int(match.group(1)), run_dir))
    if not numbered_runs:
        return None

    # Numeric comparison, so that version_10 ranks above version_2.
    _, latest_run = max(numbered_runs)
    checkpoints = glob(os.path.join(latest_run, "checkpoints", "*.ckpt"))
    if not checkpoints:
        return None

    # glob order is arbitrary, so choose deterministically.
    for candidate in checkpoints:
        if os.path.basename(candidate) == "last.ckpt":
            return candidate
    return max(checkpoints, key=os.path.getmtime)
def load_checkpoint(fname):
    """Rebuild a ``lightningdmsEVE`` module from a saved checkpoint file.

    The checkpoint is loaded on CPU, a ``dmsEVE`` network is created from the
    stored ``hyper_parameters["model_parameters"]``, the stored weights are
    loaded into it, and the network is wrapped in ``lightningdmsEVE`` together
    with the stored hyper-parameters.

    Parameters
    ----------
    fname : str
        Path of the ``.ckpt`` file.

    Returns
    -------
    lightningdmsEVE
        The wrapper holding the restored network.

    Raises
    ------
    ValueError
        If ``fname`` is None, e.g. because no checkpoint was found.
    """
    from collections import OrderedDict

    if fname is None:
        # Fail early with a clear message instead of an obscure torch.load error.
        raise ValueError("No checkpoint to load: fname is None (was a checkpoint found?)")

    print("Loading checkpoint=",fname)
    checkpoint = torch.load(fname, map_location=torch.device('cpu'),weights_only=True)
    hparams = checkpoint["hyper_parameters"]

    model=dmsEVE(hparams["model_parameters"])

    # Drop the first dotted component of every key so the names match the bare
    # network (presumably the attribute prefix added by the Lightning wrapper,
    # e.g. "model." -- confirm against utils.trainer).
    stripped_state = OrderedDict(
        (key.partition(".")[2], value)
        for key, value in checkpoint["state_dict"].items()
    )
    model.load_state_dict(stripped_state)

    return lightningdmsEVE(model=model, hparams=hparams)

# Load checkpoint

In [7]:
# Lightning log folder
log_dir = "../lightning_logs_cp2c9"

# Select fold to reproduce (out of 5)
fold_idx = 0

# path to the experiment in the log folder
exp_name = "matVAEVAMP5/6/fold{}".format(fold_idx)

# Restore the model from the latest run of that experiment
chkpt_fname = get_last_checkpoint(False,log_dir,exp_name)
if chkpt_fname is None:
    # Stop here with a clear message rather than failing inside the loader
    raise FileNotFoundError("No checkpoint found under {}".format(os.path.join(log_dir, exp_name)))
model = load_checkpoint( chkpt_fname )
# Close the current figure (presumably one opened while building the model -- confirm).
# Uses the explicitly imported pyplot: `import matplotlib` alone does not load it.
plt.close()

protein_name = model.hparams["model_parameters"]["protein_name"]
print("Protein name=",protein_name)

L = model.hparams["model_parameters"]["L"]

# Sanity run with random data
with torch.no_grad():
    logits, latent_output=model.model(torch.randn(2,L,20))
logits

Loading checkpoint= ../lightning_logs_cp2c9/matVAEVAMP5/6/fold0/version_0/checkpoints/last.ckpt
Protein name= CP2C9_HUMAN


tensor([[[ 3.5711,  3.0583, -2.9666,  ...,  5.2661, -0.1504,  0.1755],
         [-1.6088,  0.7565, -2.1933,  ...,  5.4635, -1.7634,  2.2224],
         [ 4.7438,  2.1137, -0.5398,  ...,  3.3383, -0.8061, -0.3769],
         ...,
         [ 0.4912, -1.4003, -1.6558,  ...,  2.3844, -3.6393, -1.2396],
         [-1.3261,  0.8542, -1.0990,  ..., -0.1457, -2.9468,  1.1663],
         [ 1.2296, -2.0336, -0.6887,  ...,  2.9915, -2.2282, -1.0011]],

        [[ 3.5710,  3.0587, -2.9673,  ...,  5.2657, -0.1500,  0.1758],
         [-1.6089,  0.7578, -2.1931,  ...,  5.4622, -1.7671,  2.2234],
         [ 4.7449,  2.1135, -0.5395,  ...,  3.3386, -0.8057, -0.3772],
         ...,
         [ 0.4914, -1.4009, -1.6562,  ...,  2.3840, -3.6388, -1.2388],
         [-1.3266,  0.8512, -1.0983,  ..., -0.1447, -2.9508,  1.1654],
         [ 1.2300, -2.0343, -0.6892,  ...,  2.9909, -2.2281, -0.9987]]])

# Load datasets

In [None]:
data_dir = "../data/preprocessed"
all_files = glob(os.path.join(data_dir,"{}.pklz".format(protein_name)))
if not all_files:
    # Explicit error instead of a bare IndexError on all_files[0]
    raise FileNotFoundError("No preprocessed file for {} in {}".format(protein_name, data_dir))

fname = all_files[0]

# The file holds a pair: the MSA dataset and the DMS datasets of the protein.
# NOTE(review): read_pklz presumably unpickles the file -- only load trusted data.
msa_dataset, dms_datasets = read_pklz(fname)

# Reproduce data split, select fold and create dataloaders

In [None]:
from sklearn.model_selection import KFold
from utils.data import prepare_dms_dataloaders#, prepare_msa_dataloaders

# Only the protein of the loaded checkpoint is used; no separate test proteins.
train_DMS_prot_datasets = {protein_name: dms_datasets}
test_DMS_prot_datasets = {}

# Split / loader settings
random_state = 12345
n_folds = 5
num_workers = torch.get_num_threads()
pin_memory = False
verbose = 0
batch_size = 8

# Seeded random split: for every protein and every one of its datasets, the list
# of (train_idx, val_idx) pairs produced by a shuffled K-fold over the sample indices.
train_val_idxes = {
    prot: {
        ds_name: list(
            KFold(n_folds, shuffle=True, random_state=random_state).split(
                np.arange(len(prot_datasets[ds_name]))
            )
        )
        for ds_name in prot_datasets.keys()
    }
    for prot, prot_datasets in train_DMS_prot_datasets.items()
}

# Build the dataloaders for the selected fold
train_DMS_prot_dataloaders, test_DMS_prot_dataloaders = prepare_dms_dataloaders(
    train_DMS_prot_datasets,
    test_DMS_prot_datasets,
    num_workers,
    pin_memory,
    batch_size,
    fold_idx=fold_idx,
    train_val_idxes=train_val_idxes,
    verbose=verbose,
)

In [None]:
# Run one validation pass of the restored model with a default Trainer
Lit.Trainer().validate(model=model, dataloaders=test_DMS_prot_dataloaders)

# The prototypes

In [15]:
def get_aa_dict():
    """Map each of the 20 amino-acid letters to its position in the alphabet string."""
    letters = "ACDEFGHIKLMNPQRSTVWY"
    return dict(zip(letters, range(len(letters))))
aa_dict = get_aa_dict()

In [16]:
# Display the letter -> index mapping built by get_aa_dict()
aa_dict

{'A': 0,
 'C': 1,
 'D': 2,
 'E': 3,
 'F': 4,
 'G': 5,
 'H': 6,
 'I': 7,
 'K': 8,
 'L': 9,
 'M': 10,
 'N': 11,
 'P': 12,
 'Q': 13,
 'R': 14,
 'S': 15,
 'T': 16,
 'V': 17,
 'W': 18,
 'Y': 19}

In [17]:
# Raw prototype tensor stored on the model's prior
prototypes_continuous = model.model.Prior.prototypes
# Softmax over the last axis
prototypes_softmax = prototypes_continuous.softmax(dim=-1)
# Index of the largest entry along the last axis
prototypes_index = prototypes_continuous.argmax(dim=-1)
print("The shape:", prototypes_index.shape)

tensor([[ 4, 14,  4,  ...,  7,  8, 19],
        [ 2,  0,  0,  ...,  7, 15,  1],
        [16,  0,  4,  ...,  6,  8,  6],
        [15,  8, 13,  ..., 12, 19, 19],
        [ 2,  1, 14,  ..., 12, 19,  8]])

In [19]:
# Each number is the argmax over the last axis of the prototypes, i.e. the index
# of a letter in aa_dict (presumably one row per prototype and one column per
# sequence position -- confirm)
prototypes_index

tensor([[ 4, 14,  4,  ...,  7,  8, 19],
        [ 2,  0,  0,  ...,  7, 15,  1],
        [16,  0,  4,  ...,  6,  8,  6],
        [15,  8, 13,  ..., 12, 19, 19],
        [ 2,  1, 14,  ..., 12, 19,  8]])