In [1]:
import os
# Move to the parent directory so the relative paths used below
# ("data/vcc_data/...", "utils.dataset_highvars", "v4_submissionmx.npy") resolve.
# Presumably the notebook lives one level below the project root — confirm.
# The target is remembered on first run, so re-executing this cell returns to
# the same directory instead of climbing one level higher each time (and also
# undoes any later os.chdir done by other cells).
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.pardir)
os.chdir(PROJECT_ROOT)


In [2]:

from dataclasses import dataclass, field

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import torch

from utils.dataset_highvars import get_loader


In [3]:

@dataclass
class Config:
    """Data-loading settings used to build the dataset and DataLoader.

    Fields are annotated so @dataclass actually registers them: without
    annotations they were plain class attributes and e.g. Config(batch_size=32)
    raised TypeError.
    """

    batch_size: int = 64  # Genes processed at once
    num_workers: int = 15  # DataLoader worker processes
    num_samples: int = 700  # forwarded to get_loader; meaning defined there
    target_gene_dim: int = 128  # forwarded to get_loader; meaning defined there


@dataclass
class ModelConfig:
    """Model hyper-parameters (annotated so they are real dataclass fields)."""

    embed_dim: int = 128
    num_heads: int = 4
    # A list default must come from a factory; otherwise every instance would
    # share (and could mutate) a single class-level list.
    mlp_hidden_dims: list[int] = field(default_factory=lambda: [256, 128])


cfg = Config()
model_cfg = ModelConfig()

In [4]:
# Build the dataset with the project helper, then wrap it for batched iteration.
dataset, gene_dim, maskidx = get_loader(cfg.num_samples, cfg.target_gene_dim)
testLoader = torch.utils.data.DataLoader(
    dataset,
    batch_size=cfg.batch_size,
    # NOTE(review): shuffling an evaluation loader scrambles batch order relative
    # to any ordered metadata (e.g. pert_counts) — confirm this is intended.
    shuffle=True,
    num_workers=cfg.num_workers,
)

In [5]:
# Gene names from a header-less CSV, flattened to a 1-D array
# (presumably a single column — confirm), plus per-perturbation cell counts.
gene_names = (
    pd.read_csv("data/vcc_data/gene_names.csv", header=None)
    .to_numpy()
    .flatten()
)
pert_counts = pd.read_csv("data/vcc_data/pert_counts_Validation.csv")

In [6]:
# Load the saved prediction matrix. Its rows must line up with the cells
# enumerated from pert_counts when the AnnData is built later in the notebook.
# (scanpy is already imported as `sc` in the imports cell; no re-import here.)
matrix = np.load('v4_submissionmx.npy')


In [7]:
# Value range of the raw predictions (negatives are clipped to zero later).
np.max(matrix), np.min(matrix)

(np.float32(6.193662), np.float32(-0.10654889))

In [27]:
# Clip negative predictions to zero (the raw matrix has a small negative minimum,
# ≈ -0.107). np.maximum does this directly in numpy, with the same float32 result
# as the previous torch.relu -> .detach().numpy() round-trip. Unlike the old
# two-cell version, this cell is idempotent: the old second cell crashed on
# re-run because `tmp` was already a numpy array without .detach().
tmp = np.maximum(matrix, 0)

In [30]:
# One obs row per predicted cell: each perturbation name is repeated n_cells
# times. Assumes the prediction matrix rows follow pert_counts order — confirm.
pert_names = pert_counts["target_gene"].to_numpy()
cell_counts = pert_counts["n_cells"].to_numpy()

adata = ad.AnnData(
    X=tmp,
    obs=pd.DataFrame(
        {"target_gene": pert_names.repeat(cell_counts)},
        index=np.arange(cell_counts.sum()).astype(str),
    ),
    var=pd.DataFrame(index=gene_names),
)

In [34]:
tr_adata_path = "data/vcc_data/adata_Training.h5ad"
# Read in the training anndata (source of the non-targeting control cells).
tr_adata = ad.read_h5ad(tr_adata_path)

# Filter for non-targeting. `.copy()` materialises the subset so the in-place
# scanpy calls below modify a real AnnData instead of a view (anndata would
# otherwise convert the view implicitly and emit an ImplicitModificationWarning).
ntc_adata = tr_adata[tr_adata.obs["target_gene"] == "non-targeting"].copy()

# Normalise and log1p the controls in place; presumably this matches how the
# predictions were produced — confirm.
sc.pp.normalize_total(ntc_adata)
sc.pp.log1p(ntc_adata)

# Append the non-targeting controls to the prediction anndata if they're missing
# (the membership check also makes re-running this cell a no-op).
if "non-targeting" not in adata.obs["target_gene"].unique():
    # Index.equals compares length and order safely; an elementwise `==` on
    # mismatched lengths can raise or silently broadcast instead of failing here.
    assert adata.var_names.equals(ntc_adata.var_names), (
        "Gene-Names are out of order or unequal"
    )
    adata = ad.concat([adata, ntc_adata])


In [35]:
# Inspect per-cell metadata: predicted cells (string index "0", "1", ...) followed
# by the appended non-targeting control cells (original barcodes).
adata.obs

Unnamed: 0,target_gene
0,SH3BP4
1,SH3BP4
2,SH3BP4
3,SH3BP4
4,SH3BP4
...,...
TTTCGCGCAACCTGTTATTCGGTT-Flex_3_16,non-targeting
TTTCGCGCATTCGGTTATTCGGTT-Flex_3_16,non-targeting
TTTGCTGAGTAACTTCATTCGGTT-Flex_3_16,non-targeting
TTTGGACGTGGTGCAGATTCGGTT-Flex_3_16,non-targeting


In [36]:
# Value range of the combined matrix, computed on adata.X directly: scipy sparse
# matrices (and dense arrays) implement .max()/.min(), and sparse .min() includes
# implicit zeros, so the result matches the dense version without allocating a
# dense copy of the entire matrix twice via .toarray().
adata.X.max(), adata.X.min()

(np.float32(6.890225), np.float32(0.0))

In [37]:
# Subset view containing only the appended non-targeting control cells.
adata[adata.obs["target_gene"].eq("non-targeting")]

View of AnnData object with n_obs × n_vars = 38176 × 18080
    obs: 'target_gene'

In [38]:
# Persist predictions + controls for the cell-eval prep step.
adata.write_h5ad(filename="data/vcc_data/submission.h5ad")

In [39]:
os.chdir('data/vcc_data')
!cell-eval prep -i submission.h5ad --genes gene_names.csv

INFO:cell_eval._cli._prep:Reading input anndata
INFO:cell_eval._cli._prep:Reading gene list
INFO:cell_eval._cli._prep:Preparing anndata
INFO:cell_eval._cli._prep:Using 32-bit float encoding
INFO:cell_eval._cli._prep:Setting data to sparse if not already
INFO:cell_eval._cli._prep:Simplifying obs dataframe
INFO:cell_eval._cli._prep:Simplifying var dataframe
INFO:cell_eval._cli._prep:Creating final minimal AnnData object
INFO:cell_eval._cli._prep:Applying normlog transformation if required
INFO:cell_eval.utils:Data appears to be log1p normalized (decimals detected, range [0.00, 6.89])
INFO:cell_eval._evaluator:Input is found to be log-normalized already - skipping transformation.
INFO:cell_eval._cli._prep:Writing h5ad output to /tmp/tmpwhrbpbtk/pred.h5ad
INFO:cell_eval._cli._prep:Zstd compressing /tmp/tmpwhrbpbtk/pred.h5ad
/tmp/tmpwhrbpbtk/pred.h5ad : 41.43%   (  9.90 GiB =>   4.10 GiB, /tmp/tmpwhrbpbtk/pred.h5ad.zst) 
INFO:cell_eval._cli._prep:Packing files into submission.prep.vcc
INFO: