In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("/home/yufan")  # parent of `perturbgene` directory

import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm  # https://discuss.pytorch.org/t/error-while-multiprocessing-in-dataloader/46845/9
from braceexpand import braceexpand
from perturbgene.data_utils import GeneTokenizer, read_h5ad_file
from perturbgene.data_utils.data_collators import collate_fn_wrapper
from perturbgene.model import  GeneBertModel
from perturbgene.inference_utils import get_inference_config, get_gene_embedding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch  # only needed for the CUDA availability check below (the embeddings are torch tensors anyway)

# --- model initialization ---
# Prefer the first GPU, but fall back to CPU so "Restart & Run All" still works on a
# machine without CUDA instead of crashing inside `model.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("CUDA not available, running the model on CPU")

model_checkpt_path = "/storage_bizon/bizon_filesystem/output_2gpu_mlm_015gene_050pheno_vastai_10000_mlm_bins10/checkpoint-305000"
# Ask the model to return hidden states; presumably these are what get_gene_embedding
# reads — TODO confirm against perturbgene.inference_utils.
model = GeneBertModel.from_pretrained(model_checkpt_path, output_hidden_states=True)
model.eval()  # inference only (e.g. disables dropout)
model.to(device)

# --- tokenizer loading and wrapping ---
# Load the tokenizer first, so that we can get the config from it.
# It is expected one directory above the checkpoint, next to the checkpoint folders.
expected_tokenizer_path = os.path.join(
    os.path.dirname(model_checkpt_path),
    "tokenizer.pkl",
)

if os.path.isfile(expected_tokenizer_path):
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load tokenizer pickles produced by your own training runs.
    with open(expected_tokenizer_path, "rb") as f:
        tokenizer = pickle.load(f)
else:
    print("Saved tokenizer not found, creating tokenizer with common parameters")
    # NOTE(review): these fallback values must match how the checkpoint was trained.
    # The checkpoint directory name mentions "bins10" while bin_edges=[0.1] is a
    # single edge — confirm before trusting embeddings produced on this path.
    tokenizer = GeneTokenizer(get_inference_config(  # change these parameters
        bin_edges=[0.1],
        pretrained_model_path=model_checkpt_path,  # needs to be a path that exists
        max_length=130,
        num_top_genes=128,
    ))

# The stored vocab_path is relative to the project root; make it absolute.
# os.path.join returns an already-absolute vocab_path unchanged, so this is safe to repeat.
tokenizer.config.vocab_path = os.path.join("/home/yufan/", tokenizer.config.vocab_path)  # rel path -> abs path
config = tokenizer.config

data_collator = collate_fn_wrapper(tokenizer)

In [3]:
# Validation cells (Tabula Sapiens), read with the same top-gene count the tokenizer config uses.
# NOTE(review): relative path — assumes the .h5ad file is in the notebook's working directory.
validation_h5ad_path = "Tabula_Sapiens_ranked_47.h5ad"
validation_data = read_h5ad_file(validation_h5ad_path, config.num_top_genes)

In [4]:
# Embed a single gene in a single validation cell and inspect the resulting shape.
cell_index = 1000
gene_symbol = "MT-TP"
cell = validation_data[cell_index]
emb = get_gene_embedding(cell, model, tokenizer, data_collator, gene_symbol)
emb.shape

torch.Size([1, 1024])

In [5]:
# Gene symbols of this cell's features, indexed by Ensembl ID.
gene_symbols = cell.var["feature_name"]
gene_symbols

ensemblid
ENSG00000223972        DDX11L1
ENSG00000227232         WASH7P
ENSG00000278267      MIR6859-1
ENSG00000243485    MIR1302-2HG
ENSG00000284332      MIR1302-2
                      ...     
ENSG00000198695         MT-ND6
ENSG00000210194          MT-TE
ENSG00000198727         MT-CYB
ENSG00000210195          MT-TT
ENSG00000210196          MT-TP
Name: feature_name, Length: 58604, dtype: category
Categories (58604, object): ['5S_rRNA_ENSG00000276861', '5S_rRNA_ENSG00000277411', '5S_rRNA_ENSG00000277488', '5S_rRNA_ENSG00000285609', ..., 'hsa-mir-1253', 'hsa-mir-423', 'snoZ196_ENSG00000281780', 'yR211F11.2']