In [2]:
import os, sys
import scanpy as sc
import importlib
import numpy as np

import scgenehyena.data_loader

# Data Preprocessing

Offline, preprocess individual h5ad files separately.

# Load data

In [2]:
h5ad_files = [f"./data/{f}" for f in os.listdir("./data/") if f.endswith("_ct.h5ad")]
print(f"Number of h5ad files: {len(h5ad_files)}")
print(h5ad_files)

# use protein codeing genes only
adata = sc.read(h5ad_files[0])
adata


Number of h5ad files: 2
['./data/SRX9856815_ct.h5ad', './data/SRX9777399_ct.h5ad']


AnnData object with n_obs × n_vars = 7144 × 36601
    obs: 'gene_count', 'umi_count', 'SRX_accession', 'cell_type'
    var: 'gene_ids', 'feature_types'
    layers: 'ambiguous', 'spliced', 'unspliced'

# Vocab

In [7]:
from scgpt.tokenizer.gene_tokenizer import GeneVocab

In [None]:
genes = adata.var_names

vocab = GeneVocab(
    gene_list_or_vocab=genes.tolist(),
    specials=['<pad>', '<cls>', '<eoc>'],
    special_first=True,
    default_token='<pad>',

)

In [7]:
'<pad>' in vocab

True

In [25]:
vocab.get_itos()[36602]

'hsa-mir-1253'

In [None]:
print(vocab['<pad>'])
print(vocab(['OR4F5', '<pad>']))
gene2id = vocab.get_stoi()

# Dataloader (h5ad)

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from scgenehyena.utils import get_toy_data
import scgenehyena.data_loader
import scanpy as sc
import importlib


In [None]:
# create dataset

importlib.reload(scgenehyena.data_loader)

dataset = scgenehyena.data_loader.StratifiedVeloDataset(
    h5ad_files=['./data/SRX9777399_ct.h5ad', './data/SRX9856815_ct.h5ad'], 
    genes=genes,
    cell_type_key="cell_type", 
    batch_key='SRX_accession',
    samples_per_epoch=40,
    vocab=vocab,
    pad_token="<pad>",
    pad_value="-2",
    append_cls=True,
    include_zero_gene=True,
    cls_token="<cls>",
    return_pt=True,
)


In [None]:
# create dataloader

dataloader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=0,           # Reduce workers for CPU-only
    persistent_workers=False,  # Disable for CPU mode
    pin_memory=False          # No GPU
)

In [None]:

for batch in dataloader:
    break
    
    
batch

{'gene_ids': tensor([[    1, 20848,  5528,  ..., 20951, 20470,  8277],
         [    1, 16283,  4687,  ..., 24136, 17050,  9770],
         [    1,   394,  2750,  ..., 29010,    37, 30178],
         ...,
         [    1,  3947,  6289,  ..., 34876,  4856, 26558],
         [    1,  1777, 14098,  ..., 18436, 20250, 27048],
         [    1,  4645, 16748,  ..., 33100,  4231, 13187]]),
 'values_s': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'values_u': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'expr_mask': tensor([[False, False, False,  ..., False, False, False]

In [8]:
batch['gene_ids'].shape

torch.Size([10, 2000])

In [13]:
vocab.get_itos()[5280]

'AC090227.3'

# Data Collate

# ScGeneHyena model

In [None]:

from scgenehyena.model import ScGeneHyena

from tests.test_scgenehyena import TestScGeneHyena

In [None]:
import torch
print(torch.__version__)

from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau 

2.3.0


## Create model

In [None]:
model = ScGeneHyena(
    ntoken=3000,
    d_model=64,
    l_max=32768,
    nlayers=2,
    d_hid=32,
    vocab=vocab,
    domain_spec_batchnorm='dsbn',
)
model

## Example input data: 4 cells, 20k genes with 15% masking

In [None]:
torch.manual_seed(3)
masked_input, mask = get_toy_data(num_genes=20000, dim=4, mask_ratio=0.15, seed=3)

print(masked_input.shape)
masked_input

torch.Size([4, 20000, 1])


tensor([[[-0.0000],
         [ 0.3599],
         [-0.7820],
         ...,
         [-0.0405],
         [ 0.0000],
         [-0.0401]],

        [[ 1.2336],
         [ 1.2809],
         [ 0.0000],
         ...,
         [ 0.0000],
         [ 2.3811],
         [ 0.1270]],

        [[ 0.4681],
         [-0.9387],
         [ 0.9764],
         ...,
         [-0.5254],
         [ 0.0000],
         [-0.1891]],

        [[-0.6914],
         [-2.0339],
         [-0.0146],
         ...,
         [ 0.0000],
         [-0.0000],
         [-0.0000]]])

## Forward pass

In [None]:
reconstruction, cell_state = model(masked_input)

In [None]:
reconstruction

tensor([[[0.4858],
         [0.6848],
         [0.3534],
         ...,
         [0.3795],
         [0.3743],
         [0.5372]],

        [[0.6517],
         [0.7773],
         [0.4156],
         ...,
         [0.3889],
         [0.5282],
         [0.6416]],

        [[0.5674],
         [0.5576],
         [0.5840],
         ...,
         [0.3375],
         [0.4753],
         [0.5265]],

        [[0.4289],
         [0.5664],
         [0.4138],
         ...,
         [0.4240],
         [0.4464],
         [0.5575]]], grad_fn=<SigmoidBackward0>)

In [None]:
cell_state

tensor([[ 7.6225, -1.2154,  1.2303,  1.6074],
        [ 7.6227, -1.2150,  1.2443,  1.6110],
        [ 7.6069, -1.2048,  1.2510,  1.5950],
        [ 7.6321, -1.2098,  1.2517,  1.6118]], grad_fn=<AddmmBackward0>)

# to remove

In [None]:
labels = ['asdf', 'ewqr', 'yrj', 'usf', 'zsf', 'hadf']

In [None]:
uniq = sorted({str(x) for x in labels})
l2i = {l:i for i, l in enumerate(uniq)}
i2l = {i:l for l, i in l2i.items()}
ids = np.asarray([l2i[str(x)] for x in labels], dtype=np.int64)

uniq

['asdf', 'ewqr', 'hadf', 'usf', 'yrj', 'zsf']

In [None]:
l2i

{'asdf': 0, 'ewqr': 1, 'hadf': 2, 'usf': 3, 'yrj': 4, 'zsf': 5}

In [None]:
i2l

{0: 'asdf', 1: 'ewqr', 2: 'hadf', 3: 'usf', 4: 'yrj', 5: 'zsf'}

In [None]:
ids

array([0, 1, 4, 3, 5, 2])

In [None]:
labels

['asdf', 'ewqr', 'yrj', 'usf', 'zsf', 'hadf']