In [1]:
import torch
import numpy as np
import os
from pathlib import Path

import argparse
import ast
import copy
import gc
import hashlib
import json
import logging
import os
import pickle
import sys
import time
import warnings
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

from datasets import Dataset
from torch.utils.data import IterableDataset, DataLoader

from tqdm import tqdm

from methylgpt.model.methyl_datasets import create_dataloader
from methylgpt.model.methyl_model import MethylGPTModel
from methylgpt.model.methyl_vocab import MethylVocab
from methylgpt.model.methyl_loss import masked_mse_loss
# from scgpt.tokenizer import tokenize_and_pad_batch

from methylgpt.utils.plot_embeddings import plot_umap_categorical, plot_umap_numerical
from methylgpt.utils.logging import *
# from methylgpt.common_setup import *

try:
    from flash_attn.flash_attention import FlashMHA
    flash_attn_available = True
except ImportError:
    import warnings
    warnings.warn("flash_attn is not installed")
    flash_attn_available = False


os.environ['CUDA_LAUNCH_BLOCKING']="1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
PARQUET_DIR="/Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled"
MODEL_PATH_DIR="/Users/chan/Projects/MethylGPT/pretrained_models/dev_pretraining_test-dataset_CpGs_type3-preprocessing_False-Sep26-10-27"
MODEL_DIR=MODEL_PATH_DIR+"/model_epoch10.pt"
CPG_LIST_DIR="/Users/chan/Projects/MethylGPT/data/probe_ids_type3.csv"

# load from config file
with open(MODEL_PATH_DIR+"/args.json", "r") as f:
    config = json.load(f)
print(config)

# update config dict
config["load_model"] = True
config["model_file"] = MODEL_DIR
config["mask_ratio"] = 0
config["probe_id_dir"] = CPG_LIST_DIR

# number of highly variable CpG sites
n_hvg = config["n_hvg"]

per_seq_batch_sample = False
DSBN = True  # Domain-spec batchnorm
explicit_zero_prob = False  # whether explicit bernoulli for zeros

{'seed': 42, 'input_type': 'CpGs_type3', 'parquet_dir': '/Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled', 'probe_id_dir': '/Users/chan/Projects/MethylGPT/data/probe_ids_type3.csv', 'qced_data_table': '', 'compiled_data_dir': '', 'valid_ratio': 0.1, 'n_hvg': 49156, 'max_fi': 500000, 'do_train': True, 'pretrained_file': None, 'mask_ratio': 0.3, 'GEPC': True, 'dab_weight': 1.0, 'pretraining_dataset_name': 'CpGs_type3', 'epochs': 100, 'ecs_thres': 0.0, 'lr': 0.001, 'batch_size': 32, 'layer_size': 64, 'nlayers': 6, 'nhead': 4, 'dropout': 0.1, 'schedule_ratio': 0.9, 'save_eval_interval': 10, 'log_interval': 1000, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'pad_token': '<pad>', 'special_tokens': ['<pad>', '<cls>', '<eoc>'], 'mask_value': -1, 'pad_value': -2, 'explicit_zero_prob': False, 'max_seq_len': 49157, 'per_seq_batch_sample': False}


In [5]:
parquet_dirs = [
    os.path.join(PARQUET_DIR, f) for f in os.listdir(PARQUET_DIR)
]
valid_dataloader = create_dataloader([parquet_dirs[0]], config["batch_size"])

In [10]:
parquet_dirs[:2]

['/Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_5cdc5c7c5f88447f9aa6ec6cd34a502d.parquet',
 '/Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_9bde46d2f9b37db4d7fe47a27b3a6d5c.parquet']

In [11]:
valid_dataloader = create_dataloader(parquet_dirs[:2], config["batch_size"], num_workers=1)

In [6]:
# Load a single parquet file to inspect for testing CustomDataset
import pandas as pd
file_path = "/Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_5cdc5c7c5f88447f9aa6ec6cd34a502d.parquet"
df_chunk = pd.read_parquet(file_path)

In [39]:
print(df_chunk)
print(df_chunk.iloc[0, 1].shape)
print(df_chunk.columns)

              id                                               data
0        N-iG7_2  [0.9750000238418579, 0.953000009059906, 0.9890...
1     TWPID11773  [0.8019999861717224, 0.6110000014305115, 0.400...
2     TWPID11962  [0.7860000133514404, 0.5329999923706055, 0.354...
3     TWPID12092  [0.8410000205039978, 0.5699999928474426, 0.405...
4      TWPID1213  [0.7990000247955322, 0.5329999923706055, 0.241...
...          ...                                                ...
4995  GSM2645847  [0.8058344125747681, 0.6622937917709351, 0.470...
4996  GSM2650817  [0.8754733204841614, 0.7926859855651855, 0.009...
4997  GSM2651336  [0.8631733655929565, 0.8639513850212097, 0.858...
4998  GSM2667452  [0.42404037714004517, 0.14574101567268372, 0.0...
4999  GSM2667479  [0.909545361995697, 0.731552243232727, 0.99380...

[5000 rows x 2 columns]
(49156,)
Index(['id', 'data'], dtype='object')


In [6]:
# Create dataloaders for all parquet directories and count samples
total_samples = 0
for parquet_dir in parquet_dirs:
    dataloader = create_dataloader([parquet_dir], config["batch_size"])
    dir_samples = 0
    for batch in dataloader:
        batch_size = len(batch['id'])
        dir_samples += batch_size
    total_samples += dir_samples
    print(f"Samples in {parquet_dir}: {dir_samples}")

print(f"Total samples across all parquet directories: {total_samples}")


Samples in /Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_5cdc5c7c5f88447f9aa6ec6cd34a502d.parquet: 5000
Samples in /Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_9bde46d2f9b37db4d7fe47a27b3a6d5c.parquet: 5000
Samples in /Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_0e3be7e1d9866f1ad39cb2621beeb7d8.parquet: 5000
Samples in /Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_1e5c40614c99d1fffafc8a8e184dd0a5.parquet: 5000
Samples in /Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_35d97d06e939d48677d7a01c1363c5f3.parquet: 5000
Samples in /Users/chan/Projects/MethylGPT/data/processed_dataset/processed_type3_parquet_shuffled/block_csv_data_8e83eae85e66eaa60ce37a2763b7db3d.parquet: 5000
Samples in /Users/chan/Projects/MethylGP

In [12]:
batch = next(iter(valid_dataloader))
print(batch)
for key, value in batch.items():
    print(f"{key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else len(value)}")

{'id': ['N-iG7_2', 'TWPID11773', 'TWPID11962', 'TWPID12092', 'TWPID1213', 'TWPID15259', 'TWPID15670', 'TWPID16808', 'TWPID2116', 'TWPID2178', 'TWPID4017', 'TWPID4413', 'TWPID5259', 'TWPID6029', 'TWPID10935', 'TWPID6865', 'TWPID6993', 'TWPID7305', 'TWPID8455', 'X3999431053_R05C01', 'X101231030036_R06C01', 'X9647450010_R02C02', 'X9647450011_R03C01', 'X9647450011_R05C01', 'X9647455151_R03C01', 'X9647450156_R03C02', 'X9647450168_R04C02', 'ENCSR312XVJ', 'ENCSR371REA', 'ENCSR408IKV', 'GSM2675612', 'GSM2675624'], 'data': tensor([[0.9750, 0.9530, 0.9890,  ..., 0.0130, 0.0550, 0.0190],
        [0.8020, 0.6110, 0.4000,  ..., 0.1890, 0.2840, 0.0480],
        [0.7860, 0.5330, 0.3540,  ..., 0.1320, 0.2430, 0.1040],
        ...,
        [0.9060, 0.8710, 0.6440,  ..., 0.0230, 0.0220, 0.0090],
        [0.9230, 0.8730, 0.8820,  ..., 0.0500, 0.0500, 0.0300],
        [0.9230, 0.8940, 0.8530,  ..., 0.0690, 0.0430, 0.0500]],
       dtype=torch.float64)}
id: <class 'list'>, shape: 32
data: <class 'torch.Ten

In [11]:
iterator = iter(valid_dataloader)
first_batch = next(iterator)
second_batch = next(iterator)
print(second_batch)
for key, value in second_batch.items():
    print(f"{key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else len(value)}")

{'id': ['GSM2675629', 'GSM2675643', 'GSM2675762', 'GSM2675782', 'GSM2675802', 'GSM2675866', 'GSM2675924', 'GSM2675961', 'GSM2675967', 'GSM2675991', 'GSM2676014', 'GSM2676045', 'GSM2676074', 'GSM2676391', 'GSM2676512', 'GSM2676545', 'GSM2676604', 'GSM2676607', 'GSM2676618', 'GSM2676656', 'GSM2676717', 'GSM2676719', 'GSM2724244', 'GSM2724372', 'GSM2724380', 'GSM2724313', 'GSM2738260', 'GSM2738185', 'GSM2746848', 'GSM2746852', 'GSM2752620', 'GSM2752648'], 'data': tensor([[0.9040, 0.8550, 0.8560,  ..., 0.0430, 0.0240, 0.0380],
        [0.8660, 0.8900, 0.8760,  ..., 0.0320, 0.0170, 0.0300],
        [0.9230, 0.9230, 0.6440,  ..., 0.0300, 0.0320, 0.0620],
        ...,
        [0.8480, 0.9100, 0.9610,  ..., 0.0760, 0.0800, 0.2090],
        [0.9420, 0.7480, 0.4000,  ..., 0.0620, 0.0640, 0.0100],
        [0.8920, 0.7040, 0.4280,  ..., 0.0440, 0.0750, 0.0190]],
       dtype=torch.float64)}
id: <class 'list'>, shape: 32
data: <class 'torch.Tensor'>, shape: torch.Size([32, 49156])


In [9]:
# attempt to autodetect device
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

methyl_vocab = MethylVocab(
    config["probe_id_dir"], config["pad_token"], config["special_tokens"], save_dir=None
)
# Usage Example
index = methyl_vocab["cg00735876"]           # probe -> index
probe = methyl_vocab.lookup_tokens([index])  # index -> probe

model = MethylGPTModel(config, methyl_vocab)

try:
    model.load_state_dict(torch.load(MODEL_DIR, map_location="cpu"))
    print(f"Loading all model params from {MODEL_DIR}\n")
except Exception as e:
    print(f"Failed to load full model: {e}\n")

    # only load params that are in the model and match the size
    model_dict = model.state_dict()
    pretrained_dict = torch.load(MODEL_DIR, map_location="cpu")
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    print("Loading params:")
    for k, v in pretrained_dict.items():
        print(k, v.shape)
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
print("\n")

model.eval()  # Switch to evaluation mode (turns off dropout, etc.)
model.to(device)
# model.half()

for name, param in model.named_parameters():
    print(name, param.shape, param.device)

using device: mps




Failed to load full model: Error(s) in loading state_dict for MethylGPTModel:
	Missing key(s) in state_dict: "transformer_encoder.layers.0.self_attn.in_proj_weight", "transformer_encoder.layers.0.self_attn.in_proj_bias", "transformer_encoder.layers.1.self_attn.in_proj_weight", "transformer_encoder.layers.1.self_attn.in_proj_bias", "transformer_encoder.layers.2.self_attn.in_proj_weight", "transformer_encoder.layers.2.self_attn.in_proj_bias", "transformer_encoder.layers.3.self_attn.in_proj_weight", "transformer_encoder.layers.3.self_attn.in_proj_bias", "transformer_encoder.layers.4.self_attn.in_proj_weight", "transformer_encoder.layers.4.self_attn.in_proj_bias", "transformer_encoder.layers.5.self_attn.in_proj_weight", "transformer_encoder.layers.5.self_attn.in_proj_bias". 
	Unexpected key(s) in state_dict: "transformer_encoder.layers.0.self_attn.Wqkv.weight", "transformer_encoder.layers.0.self_attn.Wqkv.bias", "transformer_encoder.layers.1.self_attn.Wqkv.weight", "transformer_encoder.lay

In [12]:
print(config["probe_id_dir"])
print(config["pad_token"])
print(config["special_tokens"])
print(len(methyl_vocab))

/Users/chan/Projects/MethylGPT/data/probe_ids_type3.csv
<pad>
['<pad>', '<cls>', '<eoc>']
49159


In [11]:
print(methyl_vocab.CpG_list, len(methyl_vocab.CpG_list))
print(methyl_vocab.CpG_ids)

['cg00000109', 'cg00000292', 'cg00002033', 'cg00002426', 'cg00002719', 'cg00003298', 'cg00003994', 'cg00004105', 'cg00005619', 'cg00006081', 'cg00006414', 'cg00006815', 'cg00007076', 'cg00007981', 'cg00008033', 'cg00008493', 'cg00008629', 'cg00008671', 'cg00008713', 'cg00008800', 'cg00008945', 'cg00009088', 'cg00009196', 'cg00009407', 'cg00010078', 'cg00010445', 'cg00010672', 'cg00011200', 'cg00011459', 'cg00011891', 'cg00012199', 'cg00012386', 'cg00013618', 'cg00014085', 'cg00014837', 'cg00015770', 'cg00016255', 'cg00016522', 'cg00016783', 'cg00016968', 'cg00017489', 'cg00017842', 'cg00017931', 'cg00018198', 'cg00018261', 'cg00019275', 'cg00020052', 'cg00020474', 'cg00021275', 'cg00021527', 'cg00022866', 'cg00024396', 'cg00024404', 'cg00024471', 'cg00024516', 'cg00024812', 'cg00025138', 'cg00025991', 'cg00026033', 'cg00027083', 'cg00027674', 'cg00028013', 'cg00028935', 'cg00029246', 'cg00029931', 'cg00030047', 'cg00030296', 'cg00030432', 'cg00031162', 'cg00031235', 'cg00031346', 'cg00

In [26]:
def generate_cell_embeddings(
        model, data_loader, device, vocab, max_seq_len, config, mask_value, pad_value, pad_token
):
    """
    Generate cell embeddings using the provided model and data loader.

    Args:
        model (torch.nn.Module): The model to generate embeddings.
        data_loader (torch.utils.data.DataLoader): DataLoader containing the dataset.
        device (torch.device): The device to run the model on.
        vocab (dict): Vocabulary dictionary.
        max_seq_len (int): Maximum sequence length.
        config (object): Configuration object containing mask_ratio.
        mask_value (float): Value used for masking.
        pad_value (float): Value used for padding.
        pad_token (str): Token used for padding.

    Returns:
        np.ndarray: Array of cell embeddings.

    Raises:
        RuntimeError: If there's an error during embedding generation.
    """
    logger = logging.getLogger(__name__)
    cell_embs = []
    cell_ids = []
    
    logger.info("Generating embedding...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader, desc="Processing batches")):
            if i==100:
                break  # 3200 (5000 in total) samples for visualization
            batch_data = model.prepare_data(batch)

            # (32, 49157)
            input_gene_ids = batch_data["gene_ids"].to(device)  # [[1, 3, 4, ..., 49156, 49157, 49158], ...]
            input_values = batch_data["values"].to(device)
            # input_values = batch_data["values"].to(device).half()
            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token]).to(device)

            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                MVC=config["GEPC"],  # Gene Expression Prediction from CpG
                ECS=config["ecs_thres"] > 0,
            )
            output_values = output_dict["cell_emb"].cpu().numpy()   # (32, 64)
            cell_embs.append(output_values)
            cell_ids.append(batch["id"])
            logger.debug(f"Batch embedding shape: {output_values.shape}")

    cell_emb = np.concatenate(cell_embs, axis=0)  # (3200, 64)
    cell_list = np.concatenate(cell_ids, axis=0)  # (3200,)
    logger.info(f"Validset embedding shape: {cell_emb.shape}")
    return cell_emb, cell_list

    #except Exception as e:
    #    logger.error(f"Error generating cell embeddings: {str(e)}")
    #    raise RuntimeError("Failed to generate cell embeddings") from e


valid_cell_emb, valid_cell_list = generate_cell_embeddings(
    model, 
    valid_dataloader, 
    device,
    methyl_vocab,
    config["max_seq_len"],
    config,
    config["mask_value"],
    config["pad_value"],
    config["pad_token"],
)

SAVE_DIR = Path('Embeddings')
SAVE_DIR.mkdir(parents=True, exist_ok=True)
print(f"save to {SAVE_DIR}")
valid_emb_path = SAVE_DIR / "cell_emb_.pt"
with open(valid_emb_path, "wb") as file:
    pickle.dump({"cell_emb": valid_cell_emb, "cell_list": valid_cell_list}, file)


Processing batches: 0it [22:48, ?it/s]


KeyboardInterrupt: 