In [4]:
import os, sys
import scanpy as sc
import importlib
import numpy as np

sys.path.append('../..')
sys.path.append('..')

import warnings
#warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', category=UserWarning)



# Vocab

scGPT vocab

In [5]:
from scgenehyena.scgpt_gene_tokenizer import GeneVocab

In [7]:
h5ad_files = [f"../data/{f}" for f in os.listdir("../data/") if f.endswith("_binned.h5ad") and f.startswith("SRX")]
print(f"Number of h5ad files: {len(h5ad_files)}")
print(h5ad_files)

# use protein codeing genes only
adata = sc.read(h5ad_files[0])
adata


Number of h5ad files: 2
['../data/SRX9777399_binned.h5ad', '../data/SRX9856815_binned.h5ad']


AnnData object with n_obs Ã— n_vars = 2852 Ã— 36601
    obs: 'gene_count', 'umi_count', 'SRX_accession', 'cell_type'
    var: 'gene_ids', 'feature_types'
    uns: 'log1p'
    obsm: 'bin_edges'
    layers: 'X_binned', 'X_log1p', 'X_normed', 'ambiguous', 'spliced', 'spliced_binned', 'spliced_log1p', 'spliced_normed', 'unspliced', 'unspliced_binned', 'unspliced_log1p', 'unspliced_normed'

In [8]:
genes = adata.var_names

vocab = GeneVocab(
    gene_list_or_vocab=genes.tolist(),
    specials=['<pad>', '<cls>', '<eoc>'],
    special_first=True,
    default_token='<pad>',

)

In [9]:
print(f"If '<pad>' is in vocab: {'<pad>' in vocab}")
print(vocab.get_itos()[36602])

print(vocab['<pad>'])
print(vocab(['OR4F5', '<pad>']))
gene2id = vocab.get_stoi()
gene2id['ZYG11B']

If '<pad>' is in vocab: True
hsa-mir-1253
0
[27476, 0]


36599

## Cell Index

In [10]:
from scgenehyena.utils import build_cell_index
cell_type_map, cell_index = build_cell_index(h5ad_files, 'cell_type')

print(len(cell_type_map.keys()))
(cell_type_map[9995])
print(cell_index)
cell_index[0]


9996
[('../data/SRX9777399_binned.h5ad', 0), ('../data/SRX9777399_binned.h5ad', 1), ('../data/SRX9777399_binned.h5ad', 2), ('../data/SRX9777399_binned.h5ad', 3), ('../data/SRX9777399_binned.h5ad', 4), ('../data/SRX9777399_binned.h5ad', 5), ('../data/SRX9777399_binned.h5ad', 6), ('../data/SRX9777399_binned.h5ad', 7), ('../data/SRX9777399_binned.h5ad', 8), ('../data/SRX9777399_binned.h5ad', 9), ('../data/SRX9777399_binned.h5ad', 10), ('../data/SRX9777399_binned.h5ad', 11), ('../data/SRX9777399_binned.h5ad', 12), ('../data/SRX9777399_binned.h5ad', 13), ('../data/SRX9777399_binned.h5ad', 14), ('../data/SRX9777399_binned.h5ad', 15), ('../data/SRX9777399_binned.h5ad', 16), ('../data/SRX9777399_binned.h5ad', 17), ('../data/SRX9777399_binned.h5ad', 18), ('../data/SRX9777399_binned.h5ad', 19), ('../data/SRX9777399_binned.h5ad', 20), ('../data/SRX9777399_binned.h5ad', 21), ('../data/SRX9777399_binned.h5ad', 22), ('../data/SRX9777399_binned.h5ad', 23), ('../data/SRX9777399_binned.h5ad', 24), ('..

('../data/SRX9777399_binned.h5ad', 0)

## Data Loader

In [11]:
from scgenehyena.data_loader import VeloCellDataset
from scgenehyena.data_sampler import StratifiedCellSampler
from scgenehyena.tokenizer import VeloTokenizer
from scgenehyena.data_collator import VeloDataCollator, TokenizeAndCollate
from torch.utils.data import DataLoader

In [None]:
# dataset
dataset = VeloCellDataset(
    cell_index,
    cell_type_map,
    t_key='X_binned',
    s_key='spliced_binned',
    u_key='unspliced_binned',
    batch_key='SRX_accession',
)

# sampler
sampler = StratifiedCellSampler(
    cell_types=list(cell_type_map.values()),
    samples_per_epoch=50, # use a larger number in real run
)

# tokenizer
tokenizer = VeloTokenizer(
    genes,
    vocab,
    max_len=1000,
    pad_token="<pad>",
    pad_value=2,
    append_cls=True,
    include_zero_gene=False,
    cls_token="<cls>",
)

# data collator
collator = VeloDataCollator(
    pad_value=-2,
    mlm_probability=0.3,
    mask_value=-1,
    keep_first_n_tokens=1,
    keys2mask=("values_t", "values_s", "values_u"),
    use_attention_mask=True,
)

tokenize_and_collate = TokenizeAndCollate(
    tokenizer=tokenizer,
    collator=collator,
)

  adata.obs[self.batch_key][cell_idx]


{'t': tensor([0, 0, 0,  ..., 0, 0, 0]),
 's': tensor([0, 0, 0,  ..., 0, 0, 0]),
 'u': tensor([0, 0, 0,  ..., 0, 0, 0]),
 'cell_type': 'CD4-positive, alpha-beta T cell',
 'batch': 'SRX9777399'}

In [None]:
# dataloader
dataloader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=4,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
    collate_fn=tokenize_and_collate,
    drop_last=True, 
)

for batch in dataloader:
    break

batch

{'gene_ids': tensor([[    1, 28298, 32389,  ...,     0,     0,     0],
         [    1, 32459, 30423,  ...,     0,     0,     0],
         [    1, 32459, 30423,  ...,     0,     0,     0],
         [    1, 16257, 30022,  ...,     0,     0,     0]]),
 'values_t': tensor([[0., 9., 2.,  ..., 2., 2., 2.],
         [0., 4., 9.,  ..., 2., 2., 2.],
         [0., 4., 9.,  ..., 2., 2., 2.],
         [0., 9., 1.,  ..., 2., 2., 2.]]),
 'values_s': tensor([[0., 3., 3.,  ..., 2., 2., 2.],
         [0., 3., 7.,  ..., 2., 2., 2.],
         [0., 3., 7.,  ..., 2., 2., 2.],
         [0., 5., 6.,  ..., 2., 2., 2.]]),
 'values_u': tensor([[0., 0., 0.,  ..., 2., 2., 2.],
         [0., 0., 0.,  ..., 2., 2., 2.],
         [0., 0., 0.,  ..., 2., 2., 2.],
         [0., 0., 0.,  ..., 2., 2., 2.]]),
 'expr_mask': tensor([[False,  True,  True,  ...,  True,  True,  True],
         [False,  True,  True,  ...,  True,  True,  True],
         [False,  True,  True,  ...,  True,  True,  True],
         [False,  True,  T

In [None]:
# parallel
dataloader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=4,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    collate_fn=tokenize_and_collate,
    drop_last=True, 
)

for batch in dataloader:
    break

batch

  adata.obs[self.batch_key][cell_idx]
  adata.obs[self.batch_key][cell_idx]
  adata.obs[self.batch_key][cell_idx]
  adata.obs[self.batch_key][cell_idx]
  adata.obs[self.batch_key][cell_idx]


{'gene_ids': tensor([[    1, 23064, 19668,  ...,     0,     0,     0],
         [    1, 32680, 33575,  ..., 25400, 30445, 28473],
         [    1, 19502, 35145,  ...,     0,     0,     0],
         [    1, 32459, 15364,  ...,     0,     0,     0]]),
 'values_t': tensor([[ 0.,  1., 10.,  ...,  2.,  2.,  2.],
         [ 0., 11.,  9.,  ...,  7., 13., 11.],
         [ 0.,  6.,  5.,  ...,  2.,  2.,  2.],
         [ 0.,  6.,  6.,  ...,  2.,  2.,  2.]]),
 'values_s': tensor([[ 0.,  8.,  7.,  ...,  2.,  2.,  2.],
         [ 0., 10.,  1.,  ...,  9., 13., 11.],
         [ 0., 10.,  9.,  ...,  2.,  2.,  2.],
         [ 0.,  8.,  2.,  ...,  2.,  2.,  2.]]),
 'values_u': tensor([[ 0.,  0.,  0.,  ...,  2.,  2.,  2.],
         [ 0.,  3., 12.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  2.,  2.,  2.],
         [ 0.,  0.,  0.,  ...,  2.,  2.,  2.]]),
 'expr_mask': tensor([[False,  True,  True,  ...,  True,  True,  True],
         [False,  True,  True,  ...,  True,  True,  True],
         [Fal

In [25]:
print(batch.keys())
print(batch['gene_ids'].shape)
print(batch['values_s'].shape)
print(batch['values_u'].shape)
print(batch['expr_mask'])
print(len(batch['cell_type']))
print(batch['cell_type'])
print(len(batch['batch']))
print(batch['batch'])

print(batch['values_u'].sum(axis=1))
print(batch['cell_type'])


print(f'\nbatch data:\n{batch}')

    

dict_keys(['gene_ids', 'values_t', 'values_s', 'values_u', 'expr_mask', 'cell_type', 'batch', 'attention_mask', 'masked_t', 'masked_s', 'masked_u', 'mlm_mask'])
torch.Size([4, 1000])
torch.Size([4, 1000])
torch.Size([4, 1000])
tensor([[False,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True]])
4
['endothelial cell', 'CD4-positive, alpha-beta T cell', 'plasmacytoid dendritic cell', 'plasmacytoid dendritic cell']
4
['SRX9856815', 'SRX9856815', 'SRX9777399', 'SRX9777399']
tensor([1430., 1236., 1417., 1489.])
['endothelial cell', 'CD4-positive, alpha-beta T cell', 'plasmacytoid dendritic cell', 'plasmacytoid dendritic cell']

batch data:
{'gene_ids': tensor([[    1, 23064, 19668,  ...,     0,     0,     0],
        [    1, 32680, 33575,  ..., 25400, 30445, 28473],
        [    1, 19502, 35145,  ...,     0,     0,     0],
      

In [23]:
batch['masked_s']

tensor([[ 0., -1., -1.,  ...,  2.,  2.,  2.],
        [ 0., 10.,  1.,  ...,  9., 13., 11.],
        [ 0., 10.,  9.,  ...,  2., -1., -1.],
        [ 0.,  8.,  2.,  ...,  2.,  2., -1.]])

In [24]:
batch['masked_u']

tensor([[ 0., -1., -1.,  ...,  2.,  2.,  2.],
        [ 0.,  3., 12.,  ...,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  ...,  2., -1., -1.],
        [ 0.,  0.,  0.,  ...,  2.,  2., -1.]])