# We will look at two model checkpoints: the pooled model and the TF prediction model

In [1]:
# First the pooled model: it should be identical, so let's see if our default Evals code works.

# Make the project's eval helpers importable (hard-coded cluster path).
import sys
sys.path.append('/data1/lesliec/sarthak/caduceus/evals')
from evals_utils_joint import Evals

# Checkpoint of the pooled model.
ckpt_path = '/data1/lesliec/sarthak/caduceus/outputs/2025-07-09/12-35-55-535137/checkpoints/01-val_loss=0.27462.ckpt'
# NOTE(review): device=2 is presumably a CUDA device index and load_data=False presumably
# skips eager data loading -- confirm against evals_utils_joint.Evals.
evals = Evals(ckpt_path,load_data=False, device=2)

JointMaskingDecoder: d_model=1024, d_output1=5, d_output2=1, upsample=4
JointMaskingEncoder: d_model=1024, d_input1=6, d_input2=2, joint=False, kernel_size=15, combine=True, acc_type=continuous


In [2]:
# Number of examples in the dataset attached to the pooled-model Evals object.
len(evals.dataset)

1937

In [3]:
# Peek at the first example: each half of the pair holds (at least) two tensors;
# print their shapes in the order a[0], a[1], b[0], b[1].
a, b = evals.dataset[0]
print(*(part[i].shape for part in (a, b) for i in (0, 1)))

torch.Size([6, 524288]) torch.Size([2, 524288]) torch.Size([524288, 6]) torch.Size([524288, 2])


In [5]:
# Just one cell type, so it makes sense. Let's see the outputs of each of the stages.
# Take the first example and turn its two model inputs into a batch of one on the
# evaluation device.
(seq, acc), (seq_unmask, acc_unmask) = evals.dataset[0]
x, y = (t.unsqueeze(0).to(evals.device) for t in (seq, acc))


In [None]:
import torch

# Walk through encoder -> backbone -> decoder by hand, printing the shape after each stage.
# The decoder outputs get their own names (seq_pred / acc_pred) so the masked inputs
# `seq` / `acc` from the previous cell are not overwritten, and the backbone's second
# return value is bound to a named throwaway instead of `_` (IPython's last-output variable).
with torch.no_grad():
    # The encoder also returns a dict of intermediate tensors that the decoder consumes below.
    x1, intermediates = evals.encoder(x, y)
    print(x1.shape)
    for name, tensor in intermediates.items():
        print(name, tensor.shape)
    x1, _backbone_extra = evals.backbone(x1)
    print(x1.shape)
    seq_pred, acc_pred = evals.decoder(x1, intermediates)
    print(seq_pred.shape, acc_pred.shape)

# Yeah, it looks good! The backbone clearly still runs on the downsampled sequence,
# and the decoder outputs are back at the input length.


torch.Size([1, 1024, 131072])
bin_size_1 torch.Size([1, 512, 524288])
bin_size_2 torch.Size([1, 1024, 262144])
torch.Size([1, 131072, 1024])
torch.Size([1, 524288, 5]) torch.Size([1, 524288, 1])


# Now do the same for the TF prediction model

In [6]:
# We can't reuse the base Evals class from above for this checkpoint. Would it still be fine to load the model with it?
# No: this model has a separate decoder (and encoder), so a standalone Evals class is defined below.
import os
import sys
sys.path.append('/data1/lesliec/sarthak/caduceus/')
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import argparse
from src.dataloaders.datasets.general_dataset import GeneralDataset
from src.models.sequence.dna_embedding import DNAEmbeddingModelCaduceus
from src.tasks.decoders import EnformerDecoder
from src.tasks.encoders import JointCNN
from caduceus.configuration_caduceus import CaduceusConfig
import yaml
from omegaconf import OmegaConf
import os
import itertools
import inspect
import zarr
from numcodecs import Blosc
from scipy.stats import spearmanr, pearsonr
from torch.utils.data import DataLoader
import pickle

# Register the custom OmegaConf resolvers ('eval', 'div_up') before any config is resolved.
# Each resolver is checked and registered on its own, so an already-registered 'eval'
# (e.g. when this cell is re-run, or when another import registered it first) no longer
# prevents 'div_up' from being registered, and unrelated registration errors are no
# longer silently swallowed.
for _resolver_name, _resolver_fn in (
    # NOTE: 'eval' executes arbitrary Python found in the config; only resolve trusted configs.
    ('eval', eval),
    # Ceiling division: smallest integer >= x / y for positive integers.
    ('div_up', lambda x, y: (x + y - 1) // y),
):
    if OmegaConf.has_resolver(_resolver_name):
        print(f"Resolver '{_resolver_name}' already exists, skipping registration.")
    else:
        OmegaConf.register_new_resolver(_resolver_name, _resolver_fn)

class Evals():
    """Load a trained checkpoint and run encoder -> backbone -> decoder on single examples.

    The training config is read from ``.hydra/config.yaml`` located two directory
    levels above ``ckpt_path``. From that config this class builds a ``JointCNN``
    encoder, a ``DNAEmbeddingModelCaduceus`` backbone and an ``EnformerDecoder``,
    loads their weights from the checkpoint with ``strict=True``, moves them to
    the chosen device and switches them to eval mode.

    Attributes set by ``__init__``: ``cfg`` (resolved config as plain containers),
    ``device``, ``split``, ``dataset``, ``dataset_args`` (only when the dataset is
    built here), ``encoder``, ``backbone``, ``decoder``.
    """

    def __init__(self,
                 ckpt_path,
                 dataset=None,
                 split = 'test',
                 device = None,
                 load_data=False,
                 **dataset_overrides #Don't pass None into overrides unless you intentionally want it to be None! Pass in items only that you need
                 ) -> None:
        """Build the dataset (unless one is given) and the three model stages.

        Args:
            ckpt_path: Path to the checkpoint file. The Hydra config is looked up at
                ``dirname(dirname(ckpt_path))/.hydra/config.yaml``.
            dataset: Optional ready-made dataset. When ``None``, a ``GeneralDataset``
                is built from the ``dataset`` section of the config.
            split: Stored as ``self.split`` and passed to ``GeneralDataset`` as ``split``.
            device: Anything accepted by ``torch.device``. When ``None``, CUDA is used
                if available, otherwise CPU.
            load_data: Passed to ``GeneralDataset`` as ``load_in``.
            **dataset_overrides: Extra ``GeneralDataset`` keyword arguments that override
                the config values. Names that are not parameters of
                ``GeneralDataset.__init__`` are skipped with a printed warning.
                Ignored when ``dataset`` is given.
        """
        # Load the training config that sits next to the checkpoint directory.
        model_cfg_path = os.path.join(os.path.dirname(os.path.dirname(ckpt_path)), '.hydra', 'config.yaml')
        # NOTE: yaml FullLoader and torch.load (pickle) should only be used on trusted run outputs.
        # The context manager makes sure the config file handle is closed again.
        with open(model_cfg_path, 'r') as cfg_file:
            cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
        cfg = OmegaConf.create(cfg)
        self.cfg = OmegaConf.to_container(cfg, resolve=True)

        state_dict = torch.load(ckpt_path, map_location='cpu')
        if device is not None:
            #if we are given a device, we will use that device
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.split = split

        #now set up dataset
        if dataset is None:
            # dataset_args is the same dict object as self.cfg['dataset'], so the
            # filtering and overrides below are reflected in self.cfg as well.
            dataset_args = self.cfg['dataset']
            sig = inspect.signature(GeneralDataset.__init__)
            sig = {k: v for k, v in sig.parameters.items() if k != 'self'}
            # Drop config entries GeneralDataset.__init__ does not accept. The keys are
            # collected first because a dict cannot be resized while iterating over it.
            for k in [k for k in dataset_args if k not in sig]:
                del dataset_args[k]
            dataset_args['split'] = split
            dataset_args['evaluating'] = True #this tells it to not do things like random shifting and rc aug, still does random masking tho, can get og sequence easily
            dataset_args['load_in'] = load_data

            for k, v in dataset_overrides.items():
                if k in sig:
                    dataset_args[k] = v
                    print(f"Overriding {k} with {v}")
                else:
                    print(f"Warning: {k} not in dataset args, skipping")

            self.dataset_args = dataset_args
            self.dataset = GeneralDataset(**dataset_args)
        else:
            self.dataset = dataset

        # Strip the "model." prefix (if present) in place.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            state_dict["state_dict"], "model."
        )
        model_state_dict = state_dict["state_dict"]
        # need to remove torchmetrics. to remove keys, need to convert to list first
        for key in list(model_state_dict.keys()):
            if "torchmetrics" in key:
                model_state_dict.pop(key)
        # the state_dict keys slightly mismatch from Lightning..., so we fix it here.
        # key[10:] drops the first 10 characters of each key -- presumably a wrapper
        # prefix such as "decoder.0." / "encoder.0." (TODO confirm). Decoder keys are
        # extracted first, encoder keys from what remains, and the rest is the backbone.
        decoder_state_dict = {}
        for key in list(model_state_dict.keys()):
            if "decoder" in key:
                decoder_state_dict[key[10:]] = model_state_dict.pop(key)
        encoder_state_dict = {}
        for key in list(model_state_dict.keys()):
            if "encoder" in key:
                encoder_state_dict[key[10:]] = model_state_dict.pop(key)

        # '_target_' is not passed on to CaduceusConfig.
        cfg['model']['config'].pop('_target_')
        caduceus_cfg = CaduceusConfig(**cfg['model']['config'])

        self.backbone = DNAEmbeddingModelCaduceus(config=caduceus_cfg)
        self.backbone.load_state_dict(model_state_dict, strict=True)

        # '_name_' is not passed on to the decoder/encoder constructors; d_model is
        # taken from the backbone config.
        del self.cfg['decoder']['_name_']
        self.cfg['decoder']['d_model'] = self.cfg['model']['config']['d_model']
        self.decoder = EnformerDecoder(**self.cfg['decoder']) #could do with instantiating, but that is rather complex
        self.decoder.load_state_dict(decoder_state_dict, strict=True)

        del self.cfg['encoder']['_name_']
        self.cfg['encoder']['d_model'] = self.cfg['model']['config']['d_model']
        self.encoder = JointCNN(**self.cfg['encoder'])
        self.encoder.load_state_dict(encoder_state_dict, strict=True)

        self.encoder.to(self.device).eval()
        self.backbone.to(self.device).eval()
        self.decoder.to(self.device).eval()

    def __call__(self, idx=None, data=None):
        """Run the full model on one example (or on caller-supplied data) without gradients.

        Args:
            idx: Index into ``self.dataset``; used only when ``data`` is ``None``.
            data: Optional ``((x, y), targets)`` pair in the same layout the dataset
                returns. If ``x`` is 2-D, a leading batch dimension is added to both
                ``x`` and ``y``. The targets are not used.

        Returns:
            Whatever ``self.decoder`` returns for the backbone output.

        Raises:
            ValueError: If neither ``idx`` nor ``data`` is provided.
        """
        if data is None:
            if idx is None:
                raise ValueError("Evals.__call__ needs either `idx` or `data`")
            # Only the model inputs are needed; the targets are ignored.
            (seq, acc), _targets = self.dataset[idx]

            x = seq.unsqueeze(0)
            y = acc.unsqueeze(0)
        else:
            (x, y), _targets = data

            if x.dim() == 2:
                x = x.unsqueeze(0) #add batch dim
                y = y.unsqueeze(0) #add batch dim

        x,y = x.to(self.device), y.to(self.device)

        with torch.no_grad():
            # Encoder and backbone each return a pair; only the first element is used.
            x1,_ = self.encoder(x,y)
            x1,_ = self.backbone(x1)
            x1 = self.decoder(x1)

        return x1

# TF prediction checkpoint, loaded with the Evals class defined just above.
# NOTE(review): this rebinds `ckpt_path` and `evals`, which earlier in the notebook
# referred to the pooled-model checkpoint and its evaluator.
ckpt_path = '/data1/lesliec/sarthak/caduceus/outputs/2025-07-07/15-25-26-980519/checkpoints/05-val_loss=0.75187.ckpt'
evals = Evals(ckpt_path,load_data=False, device=2)

JointMaskingEncoder: d_model=256, d_input1=6, d_input2=2, joint=False, kernel_size=15, combine=True, acc_type=continuous


In [7]:
len(evals.dataset) # same length as for the pooled model above, so again just one cell type

1937

In [8]:
evals.decoder # show the decoder's module tree; it clearly seems correct, so let's try it

EnformerDecoder(
  (final_pointwise): Sequential(
    (0): Rearrange('b n d -> b d n')
    (1): Sequential(
      (0): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): GELU()
      (2): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
    )
    (2): Rearrange('b d n -> b n d')
    (3): Dropout(p=0.05, inplace=False)
    (4): GELU()
  )
  (output_transform): Linear(in_features=512, out_features=162, bias=True)
  (pool): AvgPool1d(kernel_size=(1,), stride=(1,), padding=(0,))
  (softplus): Softplus(beta=1, threshold=20)
)

In [10]:

# Run the full encoder -> backbone -> decoder pipeline on dataset example 0 (Evals.__call__).
out = evals(0)

In [None]:
out.shape # one example, with 162 values in the last dimension (the decoder's out_features) -- looks about right! Good to see!!

torch.Size([1, 524288, 162])

In [12]:
# NOTE(review): stray scratch cell -- it only evaluates the literal 3 and can be removed.
3

3