In [1]:
import os
# The code below resolves `utils/` (import) and `data/...` (paths) relative to the
# project root, one level above this notebook. Only step up when `utils/` is not
# already visible, so re-running this cell does not keep walking up the tree.
if not os.path.isdir("utils"):
    os.chdir('..')

from dataclasses import dataclass, field

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import torch

from utils.datasetBence import get_loader, get_positional_encoding_vector


In [2]:
# Dataset location and the single sample batch used for the exploratory slices below.
dataRoot = "data/vcc_data"
sampleBatch = "Flex_1_01"

# Full training AnnData (rows = cells/observations, columns = genes).
tr_adata_path = dataRoot + "/adata_Training.h5ad"
adata = ad.read_h5ad(tr_adata_path)

In [3]:
# Gene names (headerless CSV, flattened to a 1-D array) and the validation
# perturbation counts. Paths are built from `dataRoot`, so the data location
# is configured in exactly one place.
gene_names = pd.read_csv(f"{dataRoot}/gene_names.csv", header=None).to_numpy().flatten()
pert_counts = pd.read_csv(f"{dataRoot}/pert_counts_Validation.csv")

In [4]:
@dataclass
class Config:
    """Run-level settings for this notebook (batching, loader workers, sizes).

    Every setting carries a type annotation so ``@dataclass`` registers it as a
    field. Unannotated names would be plain class attributes, which cannot be
    overridden through the constructor.
    """

    batch_size: int = 64  # Genes processed at once
    num_workers: int = 15
    num_samples: int = 100
    target_gene_dim: int = 128


@dataclass
class ModelConfig:
    """Model architecture hyper-parameters."""

    embed_dim: int = 128
    num_heads: int = 4
    # default_factory gives each instance its own list. A bare list literal would
    # be shared by every instance, and @dataclass rejects it as a mutable default.
    mlp_hidden_dims: list[int] = field(default_factory=lambda: [256, 128])


cfg = Config()
model_cfg = ModelConfig()

In [5]:
# Snapshot the untouched counts before the normalisation cell transforms `adata`.
# scanpy's log1p records itself in `adata.uns["log1p"]`. Refuse to snapshot an
# already-transformed matrix, so that running this cell out of order cannot
# silently replace the raw counts with log-normalised values.
assert "log1p" not in adata.uns, (
    "adata is already log-transformed; reload it from disk before snapshotting raw counts"
)
adata_raw = adata.copy()

In [6]:
# Library-size normalisation followed by log1p. Always start from the raw snapshot
# so that re-running this cell does not normalise/log-transform already-processed
# values a second time.
# NOTE(review): this briefly holds an extra full copy of the matrix in memory.
adata = adata_raw.copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [7]:
# Flag highly variable genes on the log-normalised matrix from the previous cell.
# inplace=True (scanpy's default) writes the flags into `adata.var` instead of
# returning them.
sc.pp.highly_variable_genes(adata, inplace=True)

In [8]:
# Peek at the first 100 cells of the sample batch (normalised values). Slice the
# sparse matrix *before* densifying, so the whole batch is not materialised just
# to keep 100 rows.
tmp = adata[adata.obs.batch == sampleBatch].X[:100].toarray()
tmp

array([[0.        , 0.        , 0.69651216, ..., 0.        , 0.        ,
        0.69651216],
       [0.        , 1.1825377 , 0.        , ..., 0.        , 0.        ,
        1.1825377 ],
       [0.        , 0.5420716 , 0.        , ..., 0.        , 0.89164174,
        0.5420716 ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 1.4580032 ,
        0.        ],
       [0.        , 1.6402287 , 0.        , ..., 0.        , 1.3273135 ,
        0.        ],
       [0.        , 1.7353313 , 0.        , ..., 0.        , 0.        ,
        0.        ]], shape=(100, 18080), dtype=float32)

In [46]:
# NOTE(review): `maskidx` was not defined in any cell of this notebook; it leaked
# from a deleted or out-of-order cell, so Restart & Run All failed here. It is
# rebuilt from the highly-variable-gene flags computed above. This presumably
# matches the original definition — TODO confirm.
maskidx = adata.var["highly_variable"].to_numpy()

# Densify every cell restricted to the selected genes (large: n_cells x n_selected).
tmp = adata[:, maskidx].X.toarray()
tmp.shape

(221273, 1635)

In [47]:
# Number of columns in `tmp`, i.e. how many genes the `maskidx` selection kept.
tmp.shape[1]

1635

In [22]:
# Raw (un-normalised) counts for the sample batch, densified: one row per cell in
# that batch, one column per gene.
batch_info = adata_raw[adata_raw.obs["batch"] == sampleBatch].X.toarray()

In [26]:
# Per-gene standard deviation across the batch's cells (population std, ddof=0).
stds = np.std(batch_info, axis=0)

In [29]:
# Select the genes with the largest raw-count standard deviation in this batch.
# k is clamped to the number of genes because torch.topk raises when k exceeds
# the size of the dimension.
N_TOP_GENES = 1000
top_values, top_idx = torch.topk(torch.from_numpy(stds), min(N_TOP_GENES, stds.shape[0]))

In [40]:
# One positional-encoding vector per selected gene index. The second argument
# (128) is presumably the encoding dimension; the transposed stack below has shape
# (128, n_genes). TODO confirm against get_positional_encoding_vector.
positions = [
    get_positional_encoding_vector(gene_index, 128)
    for gene_index in top_idx.detach().numpy()
]

In [43]:
# Stack the per-gene vectors into an (n_genes, dim) array and show the shape of
# its transpose, (dim, n_genes). `.T` behaves the same as np.permute_dims for 2-D
# arrays but also works on NumPy < 2.0.
positions = np.asarray(positions)
positions.T.shape

(128, 1000)

In [48]:
# Raw counts restricted to the top-std genes. Convert the torch index tensor
# explicitly, as the positional-encoding cell does, instead of relying on NumPy's
# implicit coercion of a foreign tensor type.
batch_info[:, top_idx.detach().numpy()].shape

(4339, 1000)

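# Pseudo pipeline

1. Get the indices of the genes with the highest standard deviation.
2. Encode those indices as positional vectors.
3. Encode the top highly variable genes.
4. Encode their standard deviations.
5. Combine everything into a single state vector.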