In [1]:
import torch
import os
from pathlib import Path


import argparse
import ast
import copy
import gc
import hashlib
import json
import logging
import os
import pickle
import sys
import time
import warnings
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

from datasets import Dataset
from torch.utils.data import IterableDataset, DataLoader


from tqdm import tqdm

from methylgpt.model.methyl_datasets import create_dataloader
from methylgpt.model.methyl_model import MethylGPTModel
from methylgpt.model.methyl_vocab import MethylVocab
from methylgpt.model.methyl_loss import masked_mse_loss
from scgpt.tokenizer import tokenize_and_pad_batch

from methylgpt.utils.plot_embeddings import plot_umap_categorical, plot_umap_numerical
from methylgpt.utils.logging import *
from methylgpt.common_setup import *

# Probe for the optional flash-attention package. The notebook still runs
# without it; the flag records whether the import succeeded.
try:
    from flash_attn.flash_attention import FlashMHA
except ImportError:
    warnings.warn("flash_attn is not installed")
    flash_attn_available = False
else:
    flash_attn_available = True


# CUDA debugging switches, exported as environment variables for this process.
for _cuda_debug_var in ('CUDA_LAUNCH_BLOCKING', 'TORCH_USE_CUDA_DSA'):
    os.environ[_cuda_debug_var] = "1"



  from .autonotebook import tqdm as notebook_tqdm


  warn(


  self.hub = sentry_sdk.Hub(client)


In [2]:
# Directory (relative to the working directory) where embeddings are written.
SAVE_DIR = Path('Embeddings')
SAVE_DIR.mkdir(parents=True, exist_ok=True)
print(f"save to {SAVE_DIR}")

# Input locations. Note that MODEL_DIR and CPG_LIST_DIR point at files
# (a .pt checkpoint and a .csv) despite their "_DIR" suffix.
PARQUET_DIR = "/home/A.Y/project/MethylGPT_clean/data/pretraining/processed_type3_parquet_shuffled"
MODEL_PATH_DIR = "/home/A.Y/project/MethylGPT_clean/pretrained_models/dev_pretraining_test-dataset_CpGs_type3-preprocessing_False-Sep26-10-27"
MODEL_DIR = f"{MODEL_PATH_DIR}/model_epoch10.pt"
CPG_LIST_DIR = "/home/A.Y/project/MethylGPT_clean/data/pretraining/probe_ids_type3.csv"

# Start from the arguments stored alongside the checkpoint.
with open(Path(f"{MODEL_PATH_DIR}/args.json"), "r") as args_file:
    config = json.load(args_file)

print(config)

# Override the stored arguments for this notebook: load the checkpoint above,
# disable masking, and point the probe list at the absolute CSV path.
config.update(
    {
        "load_model": True,
        "batch_size": 32,
        "model_file": MODEL_DIR,
        "mask_ratio": 0,
        "probe_id_dir": CPG_LIST_DIR,
    }
)

# Tokens and sentinel values (these repeat the values printed from args.json).
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]

mask_ratio = config["mask_ratio"]
mask_value, pad_value = -1, -2

# Number of highly variable CpG sites; the sequence is one token longer
# (presumably for the <cls> token -- TODO confirm).
n_hvg = config["n_hvg"]
max_seq_len = n_hvg + 1

per_seq_batch_sample = False
DSBN = True  # Domain-spec batchnorm
explicit_zero_prob = False  # whether explicit bernoulli for zeros





save to Embeddings
{'seed': 42, 'input_type': 'CpGs_type3', 'parquet_dir': '../data/pretraining/processed_type3_parquet_shuffled', 'probe_id_dir': '../data/pretraining/probe_ids_type3.csv', 'qced_data_table': '../data/pretraining/QCed_samples_type3.csv', 'compiled_data_dir': '/home/A.Y/project/MethylGPT_clean/data/pretraining/compiled_metadata.csv', 'valid_ratio': 0.1, 'n_hvg': 49156, 'max_fi': 500000, 'do_train': True, 'pretrained_file': None, 'mask_ratio': 0.3, 'GEPC': True, 'dab_weight': 1.0, 'pretraining_dataset_name': 'CpGs_type3', 'epochs': 100, 'ecs_thres': 0.0, 'lr': 0.001, 'batch_size': 32, 'layer_size': 64, 'nlayers': 6, 'nhead': 4, 'dropout': 0.1, 'schedule_ratio': 0.9, 'save_eval_interval': 10, 'log_interval': 1000, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'pad_token': '<pad>', 'special_tokens': ['<pad>', '<cls>', '<eoc>'], 'mask_value': -1, 'pad_value': -2, 'explicit_zero_prob': False, 'max_seq_len': 49157, 'per_seq_batch_sample': False}


In [3]:
# os.listdir returns entries in arbitrary order, so the paths are sorted to
# make "the first shard" the same file on every run and every machine.
parquet_dirs = sorted(
    os.path.join(PARQUET_DIR, entry) for entry in os.listdir(PARQUET_DIR)
)

# Fail with a clear message instead of an IndexError on parquet_dirs[0].
if not parquet_dirs:
    raise FileNotFoundError(f"No files found in {PARQUET_DIR}")

# Only the first (sorted) shard is fed to the dataloader in this notebook.
valid_dataloader = create_dataloader([parquet_dirs[0]], config["batch_size"])



In [4]:
# Build the vocabulary and an uninitialised model from the loaded config.
methyl_vocab = MethylVocab(config["probe_id_dir"], config["pad_token"], config["special_tokens"], save_dir=None)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MethylGPTModel(config, methyl_vocab)

# Read the checkpoint once; it is reused by the fallback path below.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
pretrained_dict = torch.load(MODEL_DIR, map_location="cpu")

try:
    model.load_state_dict(pretrained_dict)
    print(f"Loading all model params from {MODEL_DIR}")
except RuntimeError:
    # A strict load raises RuntimeError on missing/unexpected keys or shape
    # mismatches. In that case only load params that are in the model and
    # match the size. Other errors (missing file, Ctrl-C, ...) now propagate
    # instead of being swallowed.
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    for k, v in pretrained_dict.items():
        print(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

model.eval()  # Switch to evaluation mode (turns off dropout, etc.)
model.to(device)
model.half()  # parameters become float16, as the dtype listing below shows

for name, param in model.named_parameters():
    print(name, param.dtype)
    

Loading all model params from /home/A.Y/project/MethylGPT_clean/pretrained_models/dev_pretraining_test-dataset_CpGs_type3-preprocessing_False-Sep26-10-27/model_epoch10.pt


encoder.embedding.weight torch.float16
encoder.enc_norm.weight torch.float16
encoder.enc_norm.bias torch.float16
value_encoder.linear1.weight torch.float16
value_encoder.linear1.bias torch.float16
value_encoder.linear2.weight torch.float16
value_encoder.linear2.bias torch.float16
value_encoder.norm.weight torch.float16
value_encoder.norm.bias torch.float16
transformer_encoder.layers.0.self_attn.Wqkv.weight torch.float16
transformer_encoder.layers.0.self_attn.Wqkv.bias torch.float16
transformer_encoder.layers.0.self_attn.out_proj.weight torch.float16
transformer_encoder.layers.0.self_attn.out_proj.bias torch.float16
transformer_encoder.layers.0.linear1.weight torch.float16
transformer_encoder.layers.0.linear1.bias torch.float16
transformer_encoder.layers.0.linear2.weight torch.float16
transformer_encoder.layers.0.linear2.bias torch.float16
transformer_encoder.layers.0.norm1.weight torch.float16
transformer_encoder.layers.0.norm1.bias torch.float16
transformer_encoder.layers.0.norm2.weig

In [5]:

def generate_cell_embeddings(
    model,
    data_loader,
    device,
    vocab,
    max_seq_len,
    config,
    mask_value,
    pad_value,
    pad_token,
    max_batches: Optional[int] = 100,
):
    """
    Run the model over a data loader and collect the ``"cell_emb"`` outputs.

    Args:
        model: Callable model that also exposes ``prepare_data(batch)``. The
            prepared batch must provide ``"gene_ids"`` and ``"values"`` tensors,
            and calling the model must return a mapping with a ``"cell_emb"``
            tensor.
        data_loader: Iterable of batches; each batch is indexed with ``"id"``.
        device (torch.device): Device the input tensors are moved to.
        vocab: Mapping indexed with ``pad_token`` to obtain the padding id.
        max_seq_len (int): Unused; kept so existing callers keep working.
        config: Mapping; ``config["GEPC"]`` and ``config["ecs_thres"]`` are
            read to set the ``MVC`` and ``ECS`` flags of the forward pass.
        mask_value (float): Unused; kept so existing callers keep working.
        pad_value (float): Unused; kept so existing callers keep working.
        pad_token (str): Token whose vocabulary id marks padded positions.
        max_batches (Optional[int]): Maximum number of batches to process.
            Defaults to 100, the cap that used to be hard-coded in the loop.
            Pass ``None`` to consume the whole data loader.

    Returns:
        Tuple[np.ndarray, np.ndarray]: ``(cell_emb, cell_list)`` -- the
        embeddings concatenated along axis 0 and the matching concatenated
        ``batch["id"]`` values.

    Raises:
        ValueError: If no batch was processed (empty loader or
            ``max_batches`` of 0), so there is nothing to concatenate.
    """
    logger = logging.getLogger(__name__)
    cell_embs = []
    cell_ids = []

    logger.info("Generating embedding...")

    pad_id = vocab[pad_token]  # loop-invariant lookup

    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader, desc="Processing batches")):
            # Stop once the requested number of batches has been processed.
            if max_batches is not None and i >= max_batches:
                break

            batch_data = model.prepare_data(batch)

            input_gene_ids = batch_data["gene_ids"].to(device)
            # Values are cast to float16 -- presumably because the model was
            # converted with .half() by the caller; verify if that changes.
            input_values = batch_data["values"].to(device).half()

            # The mask is derived from a tensor already on `device`, so no
            # further transfer is needed.
            src_key_padding_mask = input_gene_ids.eq(pad_id)
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                MVC=config["GEPC"],
                ECS=config["ecs_thres"] > 0,
            )
            output_values = output_dict["cell_emb"].cpu().numpy()
            cell_embs.append(output_values)
            cell_ids.append(batch["id"])

            logger.debug(f"Batch embedding shape: {output_values.shape}")

    if not cell_embs:
        raise ValueError(
            "No batches were processed: the data loader is empty or max_batches is 0"
        )

    cell_emb = np.concatenate(cell_embs, axis=0)
    cell_list = np.concatenate(cell_ids, axis=0)
    logger.info(f"Validset embedding shape: {cell_emb.shape}")
    return cell_emb, cell_list


# Embed the batches served by the validation dataloader.
valid_cell_emb, valid_cell_list = generate_cell_embeddings(
    model=model,
    data_loader=valid_dataloader,
    device=device,
    vocab=methyl_vocab,
    max_seq_len=max_seq_len,
    config=config,
    mask_value=mask_value,
    pad_value=pad_value,
    pad_token=pad_token,
)

# Persist embeddings and their ids together. Despite the ".pt" suffix the file
# is written with pickle, so it must be read back with pickle.load.
valid_emb_path = SAVE_DIR / "cell_emb.pt"
embedding_payload = {"cell_emb": valid_cell_emb, "cell_list": valid_cell_list}
with valid_emb_path.open("wb") as out_handle:
    pickle.dump(embedding_payload, out_handle)



Processing batches: 0it [00:00, ?it/s]

Too many dataloader workers: 24 (max is dataset.n_shards=1). Stopping 23 dataloader workers.


Processing batches: 1it [00:14, 14.02s/it]

Processing batches: 2it [00:20,  9.36s/it]

Processing batches: 3it [00:26,  7.87s/it]

Processing batches: 4it [00:32,  7.16s/it]

Processing batches: 5it [00:38,  6.77s/it]

Processing batches: 6it [00:43,  6.27s/it]

Processing batches: 7it [00:49,  6.22s/it]

Processing batches: 8it [00:55,  6.19s/it]

Processing batches: 9it [01:02,  6.17s/it]

Processing batches: 10it [01:08,  6.16s/it]

Processing batches: 11it [01:14,  6.15s/it]

Processing batches: 12it [01:20,  6.14s/it]

Processing batches: 13it [01:26,  6.14s/it]

Processing batches: 14it [01:32,  6.14s/it]

Processing batches: 15it [01:38,  6.14s/it]