In [1]:
import gc
import logging
from functools import partial

import lightning as L
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.utils.data
import wandb
from hydra.utils import instantiate
from matplotlib import pyplot as plt
from minlora import LoRAParametrization
from minlora.model import add_lora_by_name
from omegaconf import DictConfig, OmegaConf

from get_model.config.config import *
from get_model.dataset.zarr_dataset import (
    InferenceRegionDataset,
    InferenceRegionMotifDataset,
    RegionDataset,
    RegionMotifDataset,
    get_gencode_obj,
)
from get_model.model.model import *
from get_model.model.modules import *
from get_model.run import LitModel, run_shared
from get_model.utils import (
    extract_state_dict,
    load_checkpoint,
    load_state_dict,
    recursive_detach,
    recursive_numpy,
    rename_state_dict,
)




In [2]:
from get_model.run_region import *

In [3]:
from get_model.config.config import export_config, load_config_from_yaml
#from get_model.run_region import run_zarr as run

In [29]:
cfg = load_config_from_yaml("/gpfs/home/asun/jin_lab/get/debug/interpret_gex_finetune_config.yaml")
cfg.finetune.checkpoint = cfg.finetune.resume_ckpt

In [30]:
model = RegionLitModel(cfg)


Load ckpt from /gpfs/home/asun/jin_lab/get/3_aggr_m23_finetune/output/finetune_aggr_multiome/training_from_null_m23_L6-IT_Astro_no_chr_split_binary_atac/checkpoints/best.ckpt
Load state_dict by model_key = state_dict


In [22]:
for name, param in model.model.named_parameters():
    print(f"{name}: {param.shape}, mean={param.data.mean():.4f}, std={param.data.std():.4f}")

cls_token: torch.Size([1, 1, 768]), mean=-0.0084, std=0.0649
region_embed.embed.weight: torch.Size([768, 283]), mean=-0.0008, std=0.0686
region_embed.embed.bias: torch.Size([768]), mean=0.0052, std=0.0515
encoder.blocks.0.norm1.weight: torch.Size([768]), mean=0.6558, std=0.1529
encoder.blocks.0.norm1.bias: torch.Size([768]), mean=0.0231, std=0.3234
encoder.blocks.0.attn.q_bias: torch.Size([768]), mean=0.0057, std=0.1651
encoder.blocks.0.attn.v_bias: torch.Size([768]), mean=-0.0008, std=0.0363
encoder.blocks.0.attn.qkv.weight: torch.Size([2304, 768]), mean=0.0001, std=0.0265
encoder.blocks.0.attn.proj.weight: torch.Size([768, 768]), mean=0.0001, std=0.0254
encoder.blocks.0.attn.proj.bias: torch.Size([768]), mean=0.0067, std=0.0544
encoder.blocks.0.norm2.weight: torch.Size([768]), mean=0.9401, std=0.6096
encoder.blocks.0.norm2.bias: torch.Size([768]), mean=-0.0101, std=0.0633
encoder.blocks.0.mlp.fc1.weight: torch.Size([3072, 768]), mean=-0.0004, std=0.0364
encoder.blocks.0.mlp.fc1.bias:

In [13]:
trainable = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.model.parameters() if not p.requires_grad)
print(f"Trainable params: {trainable:,}")
print(f"Frozen params: {frozen:,}")

Trainable params: 85,267,202
Frozen params: 0


In [None]:
model = RegionLitModel(cfg)
logging.debug(OmegaConf.to_yaml(cfg))
dm = RegionDataModule(cfg)

In [14]:
pretrain_cfg = load_config_from_yaml("/gpfs/home/asun/jin_lab/get/3_aggr_m23_finetune/gex_batac_config.yaml")
pretrain_model = RegionLitModel(cfg)


Load ckpt from /gpfs/home/asun/jin_lab/get/3_aggr_m23_finetune/checkpoint-best.pth
Load state_dict by model_key = model


In [15]:
for name, param in pretrain_model.model.named_parameters():
    print(f"{name}: {param.shape}, mean={param.data.mean():.4f}, std={param.data.std():.4f}")

cls_token: torch.Size([1, 1, 768]), mean=-0.0087, std=0.0610
region_embed.embed.weight: torch.Size([768, 283]), mean=0.0010, std=0.0705
region_embed.embed.bias: torch.Size([768]), mean=0.0068, std=0.0521
encoder.blocks.0.norm1.weight: torch.Size([768]), mean=0.6721, std=0.1583
encoder.blocks.0.norm1.bias: torch.Size([768]), mean=0.0234, std=0.3330
encoder.blocks.0.attn.q_bias: torch.Size([768]), mean=0.0054, std=0.1649
encoder.blocks.0.attn.v_bias: torch.Size([768]), mean=-0.0008, std=0.0364
encoder.blocks.0.attn.qkv.weight: torch.Size([2304, 768]), mean=0.0001, std=0.0285
encoder.blocks.0.attn.proj.weight: torch.Size([768, 768]), mean=0.0001, std=0.0290
encoder.blocks.0.attn.proj.bias: torch.Size([768]), mean=0.0076, std=0.0518
encoder.blocks.0.norm2.weight: torch.Size([768]), mean=0.9471, std=0.6040
encoder.blocks.0.norm2.bias: torch.Size([768]), mean=-0.0100, std=0.0643
encoder.blocks.0.mlp.fc1.weight: torch.Size([3072, 768]), mean=-0.0004, std=0.0379
encoder.blocks.0.mlp.fc1.bias: 

In [21]:
diffs = {}
for (name1, param1), (name2, param2) in zip(model.named_parameters(), pretrain_model.named_parameters()):
    assert name1 == name2, f"Layer mismatch: {name1} vs {name2}"
    if not torch.allclose(param1, param2, atol=1e-6):
        diffs[name1] = (param1.detach().cpu(), param2.detach().cpu())

print(f"Found {len(diffs)} layers with differences:")
for name, (curr, pre) in diffs.items():
    print(f"{name}: mean diff = {(curr - pre).abs().mean():.6f}")

Found 163 layers with differences:
model.cls_token: mean diff = 0.009539
model.region_embed.embed.weight: mean diff = 0.008947
model.region_embed.embed.bias: mean diff = 0.010705
model.encoder.blocks.0.norm1.weight: mean diff = 0.018320
model.encoder.blocks.0.norm1.bias: mean diff = 0.013663
model.encoder.blocks.0.attn.q_bias: mean diff = 0.003213
model.encoder.blocks.0.attn.v_bias: mean diff = 0.000879
model.encoder.blocks.0.attn.qkv.weight: mean diff = 0.004645
model.encoder.blocks.0.attn.proj.weight: mean diff = 0.005826
model.encoder.blocks.0.attn.proj.bias: mean diff = 0.007022
model.encoder.blocks.0.norm2.weight: mean diff = 0.009941
model.encoder.blocks.0.norm2.bias: mean diff = 0.008237
model.encoder.blocks.0.mlp.fc1.weight: mean diff = 0.006207
model.encoder.blocks.0.mlp.fc1.bias: mean diff = 0.005597
model.encoder.blocks.0.mlp.fc2.weight: mean diff = 0.006096
model.encoder.blocks.0.mlp.fc2.bias: mean diff = 0.003280
model.encoder.blocks.1.norm1.weight: mean diff = 0.010385
mo

In [23]:
def scramble_weights(model, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Replace with random values of the same shape
            param.data = torch.randn_like(param.data)
    print("Weights scrambled!")

In [24]:
scramble_weights(model)

Weights scrambled!


In [25]:
for name, param in model.model.named_parameters():
    print(f"{name}: {param.shape}, mean={param.data.mean():.4f}, std={param.data.std():.4f}")

cls_token: torch.Size([1, 1, 768]), mean=-0.0399, std=0.9752
region_embed.embed.weight: torch.Size([768, 283]), mean=-0.0020, std=1.0020
region_embed.embed.bias: torch.Size([768]), mean=-0.0280, std=0.9830
encoder.blocks.0.norm1.weight: torch.Size([768]), mean=-0.0419, std=1.0177
encoder.blocks.0.norm1.bias: torch.Size([768]), mean=0.0102, std=0.9588
encoder.blocks.0.attn.q_bias: torch.Size([768]), mean=-0.0578, std=1.0110
encoder.blocks.0.attn.v_bias: torch.Size([768]), mean=0.0172, std=1.0471
encoder.blocks.0.attn.qkv.weight: torch.Size([2304, 768]), mean=-0.0006, std=0.9999
encoder.blocks.0.attn.proj.weight: torch.Size([768, 768]), mean=0.0011, std=0.9995
encoder.blocks.0.attn.proj.bias: torch.Size([768]), mean=0.0100, std=0.9812
encoder.blocks.0.norm2.weight: torch.Size([768]), mean=-0.0178, std=0.9605
encoder.blocks.0.norm2.bias: torch.Size([768]), mean=-0.0198, std=0.9767
encoder.blocks.0.mlp.fc1.weight: torch.Size([3072, 768]), mean=-0.0003, std=1.0003
encoder.blocks.0.mlp.fc1.b

In [26]:
zarr_path = cfg.dataset.zarr_path

In [27]:
zarr_path

'/gpfs/home/asun/jin_lab/get/3_aggr_m23_finetune/aggr_multiome_m23.zarr'

In [29]:
import zarr

In [28]:
# try split by comma
if isinstance(zarr_path, str):
    zarr_path = zarr_path.split(",")
self_zarr_path = zarr_path


In [30]:
# Get available celltypes from zarr paths by checking atpm group
self_available_celltypes = []
for zarr_path in self_zarr_path:
    zarr_root = zarr.open(zarr_path, mode='r')
    if 'atpm' in zarr_root:
        self_available_celltypes.extend(list(zarr_root['atpm'].keys()))
self_available_celltypes = list(set(self_available_celltypes))

In [31]:
self_available_celltypes

['L2',
 'L5 IT',
 'Pvalb',
 'L5 ET',
 'L6b',
 'VLMC',
 'L5',
 'Astro',
 'Micro-PVM',
 'OPC',
 'L6 IT',
 'Sst',
 'L6 CT',
 'Oligo',
 'Vip',
 'Endo',
 'Peri',
 'Lamp5',
 'Meis2']

In [32]:
celltypes = cfg.dataset.celltypes

In [34]:
if celltypes is not None:
    requested_celltypes = celltypes.split(",")
    self_celltypes = [ct for ct in requested_celltypes if ct in self_available_celltypes]
    if len(self_celltypes) < len(requested_celltypes):
        missing = set(requested_celltypes) - set(self_celltypes)
        print(f"Some requested celltypes were not found in the zarr files: {missing}")
else:
    self_celltypes = self_available_celltypes

Some requested celltypes were not found in the zarr files: {'L5/6 NP', 'L2/3 IT'}


In [8]:
dm = RegionZarrDataModule(cfg)

In [26]:
len(cfg.task.layer_names)

0

In [11]:
def show_batch(batch, name="batch"):
    print(f"{name} type: {type(batch)}")
    
    if isinstance(batch, dict):
        for k, v in batch.items():
            if hasattr(v, "shape"):
                print(f"  {k}: shape={tuple(v.shape)}, dtype={v.dtype}")
            else:
                print(f"  {k}: {type(v)}")
                
    elif isinstance(batch, (list, tuple)):
        for i, v in enumerate(batch):
            if hasattr(v, "shape"):
                print(f"  [{i}]: shape={tuple(v.shape)}, dtype={v.dtype}")
            else:
                print(f"  [{i}]: {type(v)} {v}")
                
    else:
        print("Unknown batch format:", batch)

In [9]:
dm

<get_model.run_region.RegionZarrDataModule at 0x7f2c55f42d20>

In [22]:
dm.cfg.task.test_mode

'interpret'

In [21]:
dm.setup("predict")
batch = next(iter(dm.predict_dataloader()))
print("Batch keys:", batch.keys())
show_batch(batch)

  0%|          | 0/2 [00:00<?, ?it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 50%|█████     | 1/2 [00:04<00:04,  4.97s/it]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


100%|██████████| 2/2 [00:10<00:00,  5.09s/it]


Batch keys: dict_keys(['region_motif', 'mask', 'gene_name', 'tss_peak', 'chromosome', 'peak_coord', 'all_tss_peak', 'strand', 'exp_label', 'celltype'])
batch type: <class 'dict'>
  region_motif: shape=(8, 200, 283), dtype=torch.float32
  mask: shape=(8, 200, 2), dtype=torch.int8
  gene_name: <class 'list'>
  tss_peak: shape=(8,), dtype=torch.int64
  chromosome: <class 'list'>
  peak_coord: shape=(8, 200, 2), dtype=torch.int64
  all_tss_peak: shape=(8, 200), dtype=torch.int64
  strand: shape=(8,), dtype=torch.int64
  exp_label: shape=(8, 200, 2), dtype=torch.float32
  celltype: <class 'list'>


In [32]:
layer_names = ['region_embed', 'encoder.blocks.0', 'encoder.blocks.1', 'encoder.blocks.2', 'encoder.blocks.3', 'encoder.blocks.4', 
                        'encoder.blocks.5', 'encoder.blocks.6', 'encoder.blocks.7', 'encoder.blocks.8', 'encoder.blocks.9', 'encoder.blocks.10',
                        'encoder.blocks.11', 'encoder.norm']

for layer_name in layer_names:
    layer = model.model.get_submodule(layer_name)
    print(layer)

RegionEmbed(
  (embed): Linear(in_features=283, out_features=768, bias=True)
)
Block(
  (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (attn): Attention(
    (qkv): Linear(in_features=768, out_features=2304, bias=False)
    (attn_drop): Dropout(p=0, inplace=False)
    (proj): Linear(in_features=768, out_features=768, bias=True)
    (proj_drop): Dropout(p=0, inplace=False)
  )
  (drop_path): Identity()
  (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (drop1): Dropout(p=0, inplace=False)
    (drop2): Dropout(p=0, inplace=False)
  )
)
Block(
  (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (attn): Attention(
    (qkv): Linear(in_features=768, out_features=2304, bias=False)
    (attn_drop): Dropout(p=0, inplace=False)
    (proj): Linear(in_features=768, out_

In [10]:
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
print("Batch keys:", batch.keys())
for k, v in batch.items():
    print(k, v.shape, v.dtype)

  0%|          | 0/15 [00:00<?, ?it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 13%|█▎        | 2/15 [00:00<00:02,  4.54it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']
Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 27%|██▋       | 4/15 [00:00<00:02,  4.83it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']
Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 40%|████      | 6/15 [00:01<00:01,  4.96it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']
Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 53%|█████▎    | 8/15 [00:01<00:01,  5.05it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']
Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 60%|██████    | 9/15 [00:01<00:01,  5.06it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 73%|███████▎  | 11/15 [00:02<00:00,  5.10it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']
Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


 87%|████████▋ | 13/15 [00:02<00:00,  5.05it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']
Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


100%|██████████| 15/15 [00:03<00:00,  4.95it/s]


Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


  0%|          | 0/2 [00:00<?, ?it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']


100%|██████████| 2/2 [00:00<00:00,  4.43it/s]


Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX']
Batch keys: dict_keys(['region_motif', 'mask', 'atpm', 'chromosome', 'peak_coord', 'exp_label', 'celltype'])
region_motif torch.Size([8, 200, 283]) torch.float32
mask torch.Size([8, 200, 2]) torch.int8
atpm torch.Size([8, 200, 1]) torch.float32


AttributeError: 'list' object has no attribute 'shape'

In [12]:
import numpy as np

# Load a single array from a .npy file
array = np.load('/gpfs/home/asun/jin_lab/get/web/get_figure_code/lentiMPRA/k562_count_10/k562_count_10.tss.npy')

# Check the contents
print(array.shape)
print(array.dtype)

(203553, 2)
bool


In [13]:
print(array[:5])

[[ True False]
 [False False]
 [ True  True]
 [ True False]
 [False False]]


In [14]:
data = np.load("/gpfs/home/asun/jin_lab/get/web/get_figure_code/lentiMPRA/k562_count_10/k562_count_10.watac.npz")
print(data.files)

['indices', 'indptr', 'format', 'shape', 'data']


In [18]:
print(data["data"].shape)

(2197871,)
