In [90]:
import numpy as np, pandas as pd

from datasets import  load_from_disk
import anndata 

import scanpy as sc
import os

import pickle

from geneformer import TranscriptomeTokenizer

In [83]:
# Locations used throughout the notebook. Each can be overridden with an
# environment variable so the notebook runs on other machines; the defaults
# are the original local paths.

# Directory holding the pickled dictionaries read by this notebook
# (token_dictionary.pkl, gene_name_id_dict.pkl).
GeneFormer_PATH = os.environ.get(
    'GENEFORMER_PATH',
    '/home/syyang/Projects/geneformer/Geneformer/geneformer',
)

# Root directory under which the tokenized datasets are written.
GeneFormer_Results_PATH = os.environ.get(
    'GENEFORMER_RESULTS_PATH',
    '/home/syyang/GitRepo/_former/geneformer_results',
)

In [44]:
## Load Geneformer's token vocabulary and its gene-name -> gene-id lookup
# NOTE: pickle.load can execute arbitrary code; only point these at trusted files.

# token_dictionary.pkl -> two-column table: token, index
vocabulary_file = os.path.join(GeneFormer_PATH, 'token_dictionary.pkl')
with open(vocabulary_file, 'rb') as vocab_fh:
    token_to_index = pickle.load(vocab_fh)
geneformer_vocabulary = pd.DataFrame(token_to_index.items(), columns=['token', 'index'])

# gene_name_id_dict.pkl -> two-column table: gene_name, gene_id
gene_name_file = os.path.join(GeneFormer_PATH, 'gene_name_id_dict.pkl')
with open(gene_name_file, 'rb') as gene_name_fh:
    name_to_id = pickle.load(gene_name_fh)
gene_name_id_dict = pd.DataFrame(name_to_id.items(), columns=['gene_name', 'gene_id'])

In [45]:
# Peek at the first rows and the overall size of the vocabulary table.
(geneformer_vocabulary.head(), geneformer_vocabulary.shape)

(             token  index
 0            <pad>      0
 1           <mask>      1
 2  ENSG00000000003      2
 3  ENSG00000000005      3
 4  ENSG00000000419      4,
 (25426, 2))

In [46]:
# Peek at the first rows and the overall size of the gene-name -> gene-id table.
(gene_name_id_dict.head(), gene_name_id_dict.shape)

(  gene_name          gene_id
 0     MT-TF  ENSG00000210049
 1   MT-RNR1  ENSG00000211459
 2     MT-TV  ENSG00000210077
 3   MT-RNR2  ENSG00000210082
 4    MT-TL1  ENSG00000209082,
 (40248, 2))

## Load adipose data

Reference: Geneformer example notebook on tokenizing scRNA-seq data — https://huggingface.co/ctheodoris/Geneformer/blob/main/examples/tokenizing_scRNAseq_data.ipynb

In [81]:
# One-off preprocessing, kept commented out for provenance: it writes the same
# `adata_with_infer_multiresleiden_toR2.h5ad` file that is read at the bottom of this cell.
# Steps: read the source AnnData, keep the genes whose name appears in `gene_name_id_dict`,
# attach the matching gene ids as `ensembl_id`, and store per-cell total counts as `n_counts`.

# UCE_filtered_DIR = '/home/syyang/GitRepo/cs294/finalproject/data/UCEgenes_anndata'

# adipose_sn_file = 'adata_with_infer_multiresleiden_toR_Xastypeinteger_proc.h5ad'
# adipose_sn_ad = sc.read_h5ad(os.path.join(UCE_filtered_DIR,  adipose_sn_file) )

# # only keep the genes that are in geneformer vocabulary
# adipose_sn_ad2 = adipose_sn_ad[:, adipose_sn_ad.var.index.isin(gene_name_id_dict.gene_name)]

# adipose_sn_ad2.var = adipose_sn_ad2.var.reset_index().merge(gene_name_id_dict, left_on='index', right_on='gene_name', how='left').set_index('index')
# adipose_sn_ad2.var['ensembl_id'] = adipose_sn_ad2.var['gene_id'].copy()  # prepare gene column 'ensembl_id' for geneformer
# adipose_sn_ad2.obs['n_counts'] = adipose_sn_ad2.X.sum(axis=1)  # prepare cell column 'n_counts' for geneformer

# adipose_sn_ad2.write_h5ad('/home/syyang/GitRepo/_former/data/adata_with_infer_multiresleiden_toR2.h5ad')



# The only columns Geneformer requires are:
#   - `ensembl_id` in adata.var
#   - `n_counts` in adata.obs


# Directory and file name of the preprocessed AnnData. The directory is also passed
# as `data_directory` to the tokenizer later in the notebook.
adipose_data_PATH = '/home/syyang/GitRepo/_former/data/'
adipose_data_file = 'adata_with_infer_multiresleiden_toR2.h5ad'

adipose_sn_ad2 = sc.read_h5ad(os.path.join(adipose_data_PATH,  adipose_data_file) )



In [82]:
# Show the AnnData summary to confirm `ensembl_id` (var) and `n_counts` (obs) are present.
adipose_sn_ad2

AnnData object with n_obs × n_vars = 71200 × 15008
    obs: 'sample', '_scvi_batch', '_scvi_labels', 'leiden_scVI', '_scvi_raw_norm_scaling', 'leiden_scVI_res0.6', 'leiden_scVI_res0.7', 'leiden_scVI_res0.8', 'leiden_scVI_res0.9', 'seurat_clusters', 'n_genes', 'n_counts'
    var: 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches', 'n_cells', 'gene_name', 'gene_id', 'ensembl_id'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'hvg', 'leiden', 'log1p', 'neighbors', 'seurat_clusters_colors', 'umap'
    obsm: 'X_scVI', 'X_umap'
    layers: 'counts', 'log_10k_norm'
    obsp: 'connectivities', 'distances'

### Use Geneformer to tokenize cells [Data Preparation]

In [87]:
# Carry the `seurat_clusters` cell annotation into the tokenized dataset under the
# name `cell_type`; tokenization runs with 16 worker processes.
tk = TranscriptomeTokenizer(
    custom_attr_name_dict={"seurat_clusters": "cell_type"},
    nproc=16,
)

In [89]:
# Tokenize the .h5ad input found in `adipose_data_PATH`; the result is saved as
# 'output_prefix.dataset' inside <GeneFormer_Results_PATH>/adipose.
tk.tokenize_data(
    data_directory=adipose_data_PATH,
    output_directory=os.path.join(GeneFormer_Results_PATH, 'adipose'),
    output_prefix='output_prefix',
    file_format="h5ad",
)

Tokenizing /home/syyang/GitRepo/_former/data/adata_with_infer_multiresleiden_toR2.h5ad


  for i in adata.var["ensembl_id"][coding_miRNA_loc]
  coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]


/home/syyang/GitRepo/_former/data/adata_with_infer_multiresleiden_toR2.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Creating dataset.


Map (num_proc=16): 100%|██████████| 71200/71200 [00:40<00:00, 1745.45 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 71200/71200 [00:00<00:00, 342946.57 examples/s]


In [92]:
# Reload the tokenized dataset written by `tk.tokenize_data`. The path is rebuilt from
# `GeneFormer_Results_PATH` (same output directory and prefix as the tokenizer call)
# so it cannot drift from the location the tokenizer wrote to.
data = load_from_disk(
    os.path.join(GeneFormer_Results_PATH, 'adipose', 'output_prefix.dataset')
)

In [95]:
# Column schema of the tokenized dataset.
data.features

{'input_ids': Sequence(feature=Value(dtype='int16', id=None), length=-1, id=None),
 'cell_type': Value(dtype='int64', id=None),
 'length': Value(dtype='int64', id=None)}

In [110]:
# Print the underlying Arrow table (schema plus a truncated preview of the columns).
print(data.data)
# data.data['input_ids'] holds, for each cell, gene tokens ordered by their expression rank in that cell. The rank was calculated by normalizing against the median value stored in Geneformer's `gene_median_dictionary.pkl` file
#                  i.e.: 1). The total length of 'input_ids' is the number of cells.
#                        2). Each element in 'input_ids' is a list of integer gene tokens, ordered by their expression rank in the cell
#                            (presumably indices from `token_dictionary.pkl` rather than raw Ensembl ids — TODO confirm).
#                        3). The maximum length of one element in 'input_ids' is 2048 (max_seq_length).

# Length of the second cell's token list, followed by the token lists of the first and second cells.
len(data.data['input_ids'][1]), data.data['input_ids'][0], data.data['input_ids'][1]

MemoryMappedTable
input_ids: list<item: int16>
  child 0, item: int16
cell_type: int64
length: int64
----
input_ids: [[[8617,11077,6773,703,18001,...,10536,7790,2414,1988,9655],[16111,5234,703,6991,3430,...,16227,2514,3411,6124,6319],...,[1351,8583,12653,9468,6872,...,5588,2849,750,19421,2727],[9468,4969,7476,17098,6872,...,7490,6484,1598,2336,494]],[[8457,17098,7476,6991,4285,...,9993,997,3593,10598,13915],[8579,12848,9118,13170,17817,...,1415,14412,8172,4212,12769],...,[7455,10062,14751,15313,13210,...,3721,12179,24839,9251,17929],[8579,17817,16632,4099,3953,...,6401,18049,9750,16610,8841]],...,[[11699,4279,2858,14893,2906,...,16585,1480,7073,114,12417],[9009,9468,16220,1351,5396,...,8778,8785,4015,11010,11431],...,[15279,7296,1630,17846,7410,...,538,17749,12728,4011,8917],[7296,15279,11646,596,4278,...,5551,3991,16659,5016,8917]],[[15279,10352,7296,24537,6621,...,17528,5553,3991,11354,6895],[4457,13210,885,1307,16941,...,8783,347,9792,5531,2072],...,[6814,11025,10630,7296,11162,...,

(2048,
 <pyarrow.ListScalar: [8617, 11077, 6773, 703, 18001, 16111, 13418, 3779, 1584, 1057, 7658, 448, 6509, 12508, 17717, 4406, 13019, 6991, 10265, 13350, 555, 7197, 618, 11414, 14739, 7086, 9881, 5126, 13284, 2092, 9468, 2893, 3325, 14703, 6073, 14454, 14669, 772, 12632, 6899, 18508, 2028, 2746, 2867, 4291, 12516, 40, 5053, 7863, 9600, 3216, 6872, 1188, 236, 4588, 15935, 4395, 672, 7967, 2859, 10460, 848, 8601, 1247, 3348, 1600, 11731, 15412, 1266, 7139, 7241, 879, 12381, 1207, 1075, 16853, 6739, 8956, 13447, 16825, 10115, 11691, 12792, 10010, 1964, 9337, 13468, 15376, 10337, 7378, 14479, 12446, 14072, 7180, 1483, 8689, 1358, 13210, 8424, 10089, 384, 5297, 3321, 75, 9015, 10886, 17114, 6981, 9365, 559, 8847, 8173, 7740, 12774, 6668, 403, 13226, 6701, 4279, 8526, 3651, 7555, 10097, 7475, 9916, 12431, 20549, 5411, 7088, 1325, 6949, 2549, 7887, 4138, 11434, 489, 17072, 785, 1958, 6179, 6400, 11767, 7477, 3439, 24133, 865, 3411, 3819, 3459, 640, 3968, 826, 6043, 6071, 9743, 1694, 6971, 

In [99]:
# Number of tokenized cells; should equal n_obs of the input AnnData.
data.num_rows

71200