In [1]:
from glob import glob 
import os 
from utils_tbox.utils_tbox import read_pklz
import lightning as Lit
from parse import parse
import numpy as np

from utils.trainer import lightningdmsEVE#lightningEVE,
from utils.models import dmsEVE#lightningEVE,
import torch
from torch.utils.data import DataLoader

def read_fun(fname):
    """Best-effort wrapper around ``read_pklz``.

    Parameters
    ----------
    fname :
        Path handed unchanged to ``read_pklz``.

    Returns
    -------
    object or None
        Whatever ``read_pklz(fname)`` returns, or ``None`` when reading
        raised an ordinary exception.
    """
    try:
        return read_pklz(fname)
    except Exception:
        # Deliberate best-effort: an unreadable file yields None.  Catching
        # Exception instead of using a bare ``except:`` lets KeyboardInterrupt
        # and SystemExit propagate, so interrupting the kernel really stops
        # the run instead of silently producing None.
        return None

# Find the last checkpoint
def get_last_checkpoint(retrain,log_dir,exp_name):
    """Return the newest checkpoint of the latest run of an experiment.

    Looks for ``<log_dir>/<exp_name>/version_<N>`` directories, selects the
    one with the largest integer ``N`` and returns the most recently modified
    ``*.ckpt`` file found in its ``checkpoints`` sub-folder.

    Parameters
    ----------
    retrain : bool
        When truthy no lookup is performed and ``None`` is returned.
    log_dir : str
        Root log folder.
    exp_name : str
        Path of the experiment relative to ``log_dir``.

    Returns
    -------
    str or None
        Path of the selected checkpoint, or ``None`` when ``retrain`` is
        truthy, when no ``version_<N>`` directory exists, or when the latest
        run contains no checkpoint file.
    """
    if retrain:
        return None

    prefix = "version_"
    numbered_runs = []
    for run_dir in glob(os.path.join(log_dir, exp_name, "version*")):
        run_name = os.path.basename(run_dir)
        suffix = run_name[len(prefix):]
        # The "version*" pattern also matches names such as "version_backup";
        # skip anything that is not exactly "version_<digits>" instead of
        # crashing while extracting the run number.
        if run_name.startswith(prefix) and suffix.isascii() and suffix.isdigit():
            numbered_runs.append((int(suffix), run_dir))
    if not numbered_runs:
        return None

    # Numeric comparison, so that version_10 ranks after version_2.
    latest_run = max(numbered_runs)[1]
    checkpoints = glob(os.path.join(latest_run, "checkpoints", "*.ckpt"))
    if not checkpoints:
        return None

    # glob returns files in arbitrary order, so taking the first entry is not
    # "the last checkpoint".  Pick the most recently modified file; the path
    # itself breaks ties to keep the choice deterministic.
    return max(checkpoints, key=lambda path: (os.path.getmtime(path), path))
def load_checkpoint(fname):
    """Rebuild the Lightning-wrapped ``dmsEVE`` model stored in a checkpoint.

    Parameters
    ----------
    fname : str
        Path of a checkpoint file holding the keys ``"hyper_parameters"``
        (with a ``"model_parameters"`` entry) and ``"state_dict"``.

    Returns
    -------
    lightningdmsEVE
        Wrapper built around a ``dmsEVE`` instance whose weights were loaded
        from the checkpoint, with tensors mapped to the CPU.

    Raises
    ------
    FileNotFoundError
        If ``fname`` is ``None`` (for instance when ``get_last_checkpoint``
        found nothing to load).
    """
    from collections import OrderedDict

    if fname is None:
        # Fail with a readable message instead of an obscure error from
        # deep inside torch.load.
        raise FileNotFoundError(
            "No checkpoint file given (fname is None); "
            "check the log folder and the experiment name."
        )

    print("Loading checkpoint=",fname)
    # NOTE(review): torch.load relies on pickle and may execute code stored in
    # the file -- only open checkpoints that come from a trusted source.
    checkpoint = torch.load(fname, map_location=torch.device('cpu'))
    hyper_parameters = checkpoint["hyper_parameters"]

    model = dmsEVE(hyper_parameters["model_parameters"])

    # Drop the first dotted component of every parameter name so that the keys
    # match the bare dmsEVE module.  Presumably that component is the
    # attribute name under which the Lightning wrapper stores the network
    # (TODO confirm against lightningdmsEVE).
    weights = OrderedDict(
        (key.partition(".")[2], value)
        for key, value in checkpoint["state_dict"].items()
    )
    model.load_state_dict(weights)

    return lightningdmsEVE(model=model, hparams=hyper_parameters)

# Load checkpoint

In [2]:
# Lightning log folder
log_dir = "../lightning_logs_complete"

# Select fold to reproduce (out of 5)
fold_idx = 0

# path to the experiment in the log folder
exp_name = "indivlatenttodmssimple/11/fold{}".format(fold_idx)

chkpt_fname = get_last_checkpoint(False,log_dir,exp_name)
# Fail fast with a readable message when the experiment folder holds no
# checkpoint, rather than failing later inside the loader.
assert chkpt_fname is not None, \
    "No checkpoint found under {}".format(os.path.join(log_dir, exp_name))
model = load_checkpoint( chkpt_fname )

# The protein name and the size L are read back from the hyper-parameters
# stored in the checkpoint.
protein_name = model.hparams["model_parameters"]["protein_name"]
print("Protein name=",protein_name)

L = model.hparams["model_parameters"]["L"]

# Sanity run with random data
# A freshly constructed module is in training mode; switch to evaluation mode
# so that layers such as dropout or batch-norm (if the network has any) behave
# as at inference time and no running statistics are altered by random input.
model.eval()
with torch.no_grad():
    # Random batch of 2 inputs of shape (L, 20); presumably 20 is the
    # amino-acid alphabet size (TODO confirm).
    logits, latent_output=model.model(torch.randn(2,L,20))
logits

Loading checkpoint= ../lightning_logs_complete/indivlatenttodmssimple/11/fold0/version_4/checkpoints/epoch=9089-step=99990.ckpt
Protein name= MK01_HUMAN


tensor([[[-1.6425e-01, -1.0845e-01,  3.9820e-01,  ..., -4.5161e-01,
           4.7603e-01, -3.3526e-01],
         [-1.3580e+00, -7.7614e-01,  1.3165e-01,  ..., -8.3608e-01,
          -1.6409e-01, -2.7871e-01],
         [-4.2374e-01,  6.0560e-01, -4.4876e-01,  ..., -1.3471e-01,
           2.2159e-01,  4.0547e-01],
         ...,
         [ 6.3283e-02,  4.7008e-01,  4.6019e-01,  ...,  4.2702e-01,
           3.8697e-01, -1.1613e+00],
         [-4.3329e-01, -6.4604e-02,  6.9149e-01,  ...,  6.1604e-01,
           2.6715e-02, -2.3357e-01],
         [-5.8444e-01, -5.8697e-01, -3.2510e-01,  ..., -3.1421e-01,
          -1.0505e+00,  4.2110e-01]],

        [[-3.2535e-01, -1.9387e-01,  3.9797e-01,  ...,  4.2012e-04,
           6.9392e-01, -2.2490e-01],
         [-7.4271e-01, -6.7622e-01, -1.4302e-01,  ..., -7.4481e-01,
          -1.0207e-01, -2.5053e-01],
         [-3.8598e-01,  6.5250e-01, -6.5422e-01,  ..., -2.9394e-01,
          -4.5420e-02,  3.6937e-01],
         ...,
         [-1.5160e-01,  4

# Load datasets

In [3]:
data_dir = "../data/preprocessed"
all_files = glob(os.path.join(data_dir,"{}.pklz".format(protein_name)))

# Give a descriptive error instead of a bare IndexError when the preprocessed
# file for this protein is missing.
if not all_files:
    raise FileNotFoundError(
        "No preprocessed file '{}.pklz' found in '{}'".format(protein_name, data_dir)
    )
fname = all_files[0]

# The file content is unpacked as a pair: the MSA dataset and the DMS datasets.
msa_dataset, dms_datasets = read_pklz(fname)

# Reproduce data split, select fold and create dataloaders

In [4]:
from utils.data import prepare_dms_dataloaders#, prepare_msa_dataloaders
from sklearn.model_selection import KFold

# Only the DMS datasets of the selected protein are split into folds; the
# dictionary of separate test datasets is left empty.
train_DMS_prot_datasets = {protein_name: dms_datasets}
test_DMS_prot_datasets = dict()

# Split and loader settings
random_state, n_folds = 12345, 5
batch_size = 8
num_workers = torch.get_num_threads()
pin_memory = False
verbose = 0

# Seeded random split: every dataset of every protein gets its own shuffled
# n_folds-fold partition of its sample indices, stored as a list of
# (train indices, validation indices) pairs.
train_val_idxes = {
    prot: {
        name: list(
            KFold(n_folds, shuffle=True, random_state=random_state)
            .split(np.arange(len(dataset)))
        )
        for name, dataset in per_protein.items()
    }
    for prot, per_protein in train_DMS_prot_datasets.items()
}

# Build the dataloaders for the selected fold.
(train_DMS_prot_dataloaders, test_DMS_prot_dataloaders) = prepare_dms_dataloaders(
    train_DMS_prot_datasets,
    test_DMS_prot_datasets,
    num_workers,
    pin_memory,
    batch_size,
    fold_idx=fold_idx,
    train_val_idxes=train_val_idxes,
    verbose=verbose,
)

In [5]:
# Run Lightning's validation loop for the restored model on the second set of
# loaders returned by prepare_dms_dataloaders; the returned list of metric
# dictionaries is displayed as the cell output.
evaluation_trainer = Lit.Trainer()
evaluation_trainer.validate(model=model, dataloaders=test_DMS_prot_dataloaders)

/home/anthon@ad.cmm.se/virtualenvs/pyfactorial/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/anthon@ad.cmm.se/virtualenvs/pyfactorial/lib/p ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Validation DataLoader 0: 100%|████████████████| 171/171 [00:09<00:00, 18.94it/s]


[{'r2/dms_MK01_HUMAN_Brenan_2016.csv': -50.52971649169922,
  'spearmanr/dms_MK01_HUMAN_Brenan_2016.csv': 0.11204680055379868,
  'binary_auroc/dms_MK01_HUMAN_Brenan_2016.csv': 0.5451215591331724,
  'binary_auprc/dms_MK01_HUMAN_Brenan_2016.csv': 0.5268720984458923,
  'binary_precision/dms_MK01_HUMAN_Brenan_2016.csv': 0.0,
  'binary_recall/dms_MK01_HUMAN_Brenan_2016.csv': 0.0,
  'binary_f1_score/dms_MK01_HUMAN_Brenan_2016.csv': 0.0,
  'binary_accuracy/dms_MK01_HUMAN_Brenan_2016.csv': 0.5132158398628235,
  'r2/dms_MK01_HUMAN_Brenan_2016.csv/latent': 0.7769248485565186,
  'spearmanr/dms_MK01_HUMAN_Brenan_2016.csv/latent': 0.8767576217651367,
  'binary_auroc/dms_MK01_HUMAN_Brenan_2016.csv/latent': 0.9697887738786501,
  'binary_auprc/dms_MK01_HUMAN_Brenan_2016.csv/latent': 0.9718227982521057,
  'binary_precision/dms_MK01_HUMAN_Brenan_2016.csv/latent': 0.0,
  'binary_recall/dms_MK01_HUMAN_Brenan_2016.csv/latent': 0.0,
  'binary_f1_score/dms_MK01_HUMAN_Brenan_2016.csv/latent': 0.0,
  'binary_ac