In [1]:
import os
import re
import warnings
from glob import glob

import lightning as Lit
import matplotlib
import matplotlib.pyplot  # explicit: matplotlib.pyplot.close() is called below
import numpy as np
import torch
from parse import parse
from torch.utils.data import DataLoader

from utils.models import dmsEVE#lightningEVE,
from utils.trainer import lightningdmsEVE#lightningEVE,
from utils_tbox.utils_tbox import read_pklz

# Suppress the specific warning about NVML initialization.
# `message` is a regular expression matched against the start of the warning
# text, so only warnings beginning with "Can't initialize NVML" are hidden.
warnings.filterwarnings("ignore", message="Can't initialize NVML")


def read_fun(fname):
    """Best-effort wrapper around ``read_pklz``.

    Parameters
    ----------
    fname : path of the file handed to ``read_pklz``.

    Returns
    -------
    Whatever ``read_pklz(fname)`` returns, or ``None`` if reading raised an
    ordinary exception (missing file, corrupt archive, ...).
    """
    try:
        out = read_pklz(fname)
    except Exception:
        # Only ordinary errors are turned into None. A bare `except:` would
        # also swallow KeyboardInterrupt/SystemExit, making it impossible to
        # interrupt a long loop over many files from the notebook.
        out = None
    return out

# Find the last checkpoint
def get_last_checkpoint(retrain,log_dir,exp_name):
    """Return the checkpoint file of the most recent run of an experiment.

    Parameters
    ----------
    retrain : if true, no lookup is done and ``None`` is returned.
    log_dir : root folder of the logs.
    exp_name : experiment path relative to ``log_dir``; runs are expected in
        sub-folders named ``version_<int>``.

    Returns
    -------
    Path of a ``*.ckpt`` file inside ``<latest run>/checkpoints``, or ``None``
    when ``retrain`` is true, no ``version_<int>`` folder exists, or the
    latest run has no checkpoint. ``last.ckpt`` is preferred when present;
    otherwise the most recently modified checkpoint is returned.
    """
    if retrain: # If retrain is true, skip the lookup and return None
        return None

    # Keep only folders named exactly "version_<int>". Anything else matched
    # by the glob (e.g. "version_backup") is skipped instead of crashing the
    # numeric sort.
    versioned_runs = []
    for run_dir in glob(os.path.join(log_dir, exp_name, "version*")):
        match = re.fullmatch(r"version_(\d+)", os.path.basename(run_dir))
        if match is not None:
            versioned_runs.append((int(match.group(1)), run_dir))
    if not versioned_runs:
        return None

    # Numeric comparison, so that version_10 ranks after version_2.
    latest_run = max(versioned_runs)[1]
    checkpoints = glob(os.path.join(latest_run, "checkpoints", "*.ckpt"))
    if not checkpoints:
        return None

    # glob returns files in arbitrary order, so taking the first element does
    # not reliably give the last checkpoint. Pick it explicitly.
    for checkpoint in checkpoints:
        if os.path.basename(checkpoint) == "last.ckpt":
            return checkpoint
    # The path is used as a tie-breaker to keep the choice deterministic.
    return max(checkpoints, key=lambda path: (os.path.getmtime(path), path))
def load_checkpoint(fname):
    """Rebuild the Lightning module stored in checkpoint file ``fname``.

    The checkpoint is read on CPU. A ``dmsEVE`` network is created from the
    stored ``hyper_parameters["model_parameters"]``, filled with the stored
    weights, and wrapped in a ``lightningdmsEVE`` that receives the full
    hyper-parameter dict.

    Parameters
    ----------
    fname : path of the ``.ckpt`` file.

    Returns
    -------
    The ``lightningdmsEVE`` instance wrapping the restored network.
    """
    from collections import OrderedDict

    print("Loading checkpoint=",fname)
    checkpoint = torch.load(fname, map_location=torch.device('cpu'),weights_only=True)
    hyper_parameters = checkpoint["hyper_parameters"]

    network = dmsEVE(hyper_parameters["model_parameters"])

    # Drop the first dotted component of every parameter name so the keys
    # match the bare network.
    # NOTE(review): presumably that component is the attribute under which the
    # Lightning wrapper holds the network - confirm.
    stripped_weights = OrderedDict(
        (name.partition(".")[2], tensor)
        for name, tensor in checkpoint["state_dict"].items()
    )
    network.load_state_dict(stripped_weights)

    return lightningdmsEVE(model=network, hparams=hyper_parameters)

# Load checkpoint

In [2]:
# Lightning log folder
log_dir = "../lightning_logs_cp2c9"

# Select fold to reproduce (out of 5)
fold_idx = 0

# path to the experiment in the log folder
exp_name = "matVAEMOG1/6/fold{}".format(fold_idx)

chkpt_fname = get_last_checkpoint(False,log_dir,exp_name)
# get_last_checkpoint returns None when no run or checkpoint exists. Stop here
# with a readable message instead of an opaque error from inside torch.load.
if chkpt_fname is None:
    raise FileNotFoundError(
        "No checkpoint found under {}".format(os.path.join(log_dir, exp_name))
    )
model = load_checkpoint( chkpt_fname )
# NOTE(review): presumably closes a figure opened while the model is built - confirm
matplotlib.pyplot.close()

protein_name = model.hparams["model_parameters"]["protein_name"]
print("Protein name=",protein_name)

L = model.hparams["model_parameters"]["L"]

# Sanity run with random data: a batch of 2 inputs of shape (L, 20),
# without tracking gradients
with torch.no_grad():
    logits, latent_output=model.model(torch.randn(2,L,20))
logits

Loading checkpoint= ../lightning_logs_cp2c9/matVAEMOG1/6/fold0/version_0/checkpoints/last.ckpt




Protein name= CP2C9_HUMAN


No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda-12.2'


tensor([[[ 2.2044, -0.2780, -5.1872,  ...,  3.9678, -1.1283, -1.0135],
         [ 2.5616,  0.9089, -4.2803,  ...,  4.8001, -0.3317, -1.4157],
         [ 5.4316,  1.0941, -4.6579,  ...,  2.8694,  0.5187, -0.9395],
         ...,
         [ 0.6366, -1.0008, -0.8636,  ...,  2.4807, -1.5415, -0.5702],
         [-0.0988, -1.2598,  0.0924,  ...,  0.8259, -1.1127, -0.7352],
         [ 1.7639,  0.3939, -4.0544,  ...,  3.4743,  0.0687,  1.4525]],

        [[ 2.2052, -0.2779, -5.1857,  ...,  3.9712, -1.1285, -1.0133],
         [ 2.5626,  0.9094, -4.2811,  ...,  4.8032, -0.3344, -1.4147],
         [ 5.4316,  1.0911, -4.6561,  ...,  2.8678,  0.5195, -0.9377],
         ...,
         [ 0.6306, -0.9992, -0.8631,  ...,  2.4786, -1.5436, -0.5702],
         [-0.1025, -1.2608,  0.0933,  ...,  0.8172, -1.1070, -0.7368],
         [ 1.7658,  0.3951, -4.0531,  ...,  3.4753,  0.0695,  1.4510]]])

# Load datasets

In [3]:
data_dir = "../data/preprocessed"
all_files = glob(os.path.join(data_dir,"{}.pklz".format(protein_name)))

# Report a missing preprocessed file explicitly instead of failing with a bare
# IndexError on the empty glob result.
if not all_files:
    raise FileNotFoundError(
        "No preprocessed file found for protein {!r} in {}".format(protein_name, data_dir)
    )
fname = all_files[0]

# The file holds a pair: the MSA dataset and the DMS datasets of the protein
msa_dataset, dms_datasets = read_pklz(fname)

# Reproduce data split, select fold and create dataloaders

In [4]:
from sklearn.model_selection import KFold
from utils.data import prepare_dms_dataloaders#, prepare_msa_dataloaders

# Cross-validation settings; the fixed seed makes the shuffled split repeatable
n_folds=5
random_state = 12345

# DataLoader settings
batch_size=8
num_workers=torch.get_num_threads()
pin_memory=False
verbose=0

# Only the protein of the loaded checkpoint is used; no separate test proteins
train_DMS_prot_datasets = {protein_name:dms_datasets}
test_DMS_prot_datasets = {}

# Seeded random split: for every protein and each of its DMS datasets, the list
# of index-array pairs produced by KFold.split, one pair per fold
train_val_idxes = {
    prot: {
        ds_name: list(
            KFold(n_folds, shuffle=True, random_state=random_state)
            .split(np.arange(len(prot_datasets[ds_name])))
        )
        for ds_name in prot_datasets.keys()
    }
    for prot, prot_datasets in train_DMS_prot_datasets.items()
}

# Build the dataloaders for the selected fold
train_DMS_prot_dataloaders, test_DMS_prot_dataloaders = prepare_dms_dataloaders(
    train_DMS_prot_datasets,
    test_DMS_prot_datasets,
    num_workers,
    pin_memory,
    batch_size,
    fold_idx=fold_idx,
    train_val_idxes=train_val_idxes,
    verbose=verbose,
)

In [5]:
# Run the validation loop of the restored model on the dataloaders of the
# selected fold; the returned list of metric dicts is displayed as cell output.
Lit.Trainer().validate(model=model, dataloaders=test_DMS_prot_dataloaders)

/home/anthon@ad.cmm.se/projects/EVE/pyenv/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/anthon@ad.cmm.se/projects/EVE/pyenv/lib/python ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Validation: |                                                                                                 …

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                      Validate metric                                               DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 binary_accuracy/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv                     0.48744112253189087
binary_accuracy/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv/d                     0.5125588774681091
                           msELBO
binary_accuracy/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv/l                     0.5125588774681091
                           atent
 binary_accuracy/dms_CP2C9_HUMAN_Amorosi_2021_activity.csv                       0.4808787703514099
binary_accuracy/dms_CP2C9_HUMAN_Amorosi_2021_activity.csv/dm                     0.5256305932998657
                           sELBO
binary_accuracy/dms_CP2C9_HUMAN_Amorosi_2021_activity.csv/la 

[{'r2/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': -4.55610466003418,
  'spearmanr/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': 0.4879545271396637,
  'binary_auroc/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': 0.7410267981544365,
  'binary_auprc/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': 0.7130352854728699,
  'binary_precision/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': 0.0,
  'binary_recall/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': 0.0,
  'binary_f1_score/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': 0.0,
  'binary_accuracy/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv': 0.48744112253189087,
  'r2/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv/latent': -3.8618874549865723,
  'spearmanr/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv/latent': 0.0780782401561737,
  'binary_auroc/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv/latent': 0.5450836348033232,
  'binary_auprc/dms_CP2C9_HUMAN_Amorosi_2021_abundance.csv/latent': 0.544508695602417,
  'binary_precision/dms_CP2C9_HUMAN_Amorosi_2021_abundance.c