In [20]:
import pandas as pd
import anndata
import torch
from decima.core import DecimaResult

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
# Tab-separated variant table; the recorded output shows the columns
# chrom, pos, ref, alt, gene and rsid.
variants_path = "variants.tsv"
variant_df = pd.read_csv(variants_path, sep="\t")
variant_df

Unnamed: 0,chrom,pos,ref,alt,gene,rsid
0,chr1,1000018,G,A,ISG15,rs146254088
1,chr1,1002308,T,C,ISG15,rs2489000
2,chr1,109727471,A,C,GSTM3,rs11101994
3,chr1,109728286,T,G,GSTM3,rs4540683
4,chr1,109728807,T,G,GSTM3,rs4970775


In [3]:
# Load the Decima result bundle; in the recorded run this pulled the
# `decima_metadata` artifact through wandb.
result = DecimaResult.load()
# Keep a short handle on the bundled AnnData: later cells pass it to
# process_variants / VariantDataset and read per-task metadata from `ad.obs`.
ad = result.anndata

[34m[1mwandb[0m: Currently logged in as: [33mcelikm5[0m ([33mcelikm5-genentech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact decima_metadata:latest, 628.05MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5 (1247.3MB/s)


In [5]:
from decima.utils.variant import process_variants

In [6]:
# Value handed to process_variants as `min_from_end`.
# NOTE(review): presumably the minimum distance (bp) a variant must keep from
# either end of the gene's input interval -- confirm against process_variants.
MIN_DIST_FROM_END = 5000

# Annotate the variants against `ad`; the recorded output gains the columns
# start, end, strand, gene_mask_start, rel_pos, ref_tx, alt_tx and tss_dist.
# The result is written back to `variant_df`, so re-running this cell passes
# the already-processed frame in again.
variant_df = process_variants(variant_df, ad, min_from_end=MIN_DIST_FROM_END)
variant_df

dropped 0 variants because the gene was not found in ad.var
dropped 0 variants because the variant did not fit in the interval


Unnamed: 0,chrom,pos,ref,alt,gene,rsid,start,end,strand,gene_mask_start,rel_pos,ref_tx,alt_tx,tss_dist
0,chr1,1000018,G,A,ISG15,rs146254088,837298,1361586,+,163840,162720,G,A,-1120
1,chr1,1002308,T,C,ISG15,rs2489000,837298,1361586,+,163840,165010,T,C,1170
2,chr1,109727471,A,C,GSTM3,rs11101994,109380590,109904878,-,163840,177407,T,G,13567
3,chr1,109728286,T,G,GSTM3,rs4540683,109380590,109904878,-,163840,176592,A,C,12752
4,chr1,109728807,T,G,GSTM3,rs4970775,109380590,109904878,-,163840,176071,A,C,12231


In [21]:
from decima.data.dataset import VariantDataset

# Wrap the processed variant table and the Decima AnnData in a dataset object
# that the model's predict_on_dataset call consumes below.
dataset = VariantDataset(variant_df, ad=ad)

In [22]:
# Sanity check: look at the first dataset item. The recorded output is a 2-D
# tensor with five rows.
# NOTE(review): presumably four one-hot nucleotide channels plus a gene-mask
# channel -- confirm against VariantDataset.
dataset[0]

tensor([[0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [26]:
# Filter expression for the tasks of interest: blood CD8+ alpha-beta T cells
# whose disease label is "healthy" or "NA".
cd8_blood_query = (
    'tissue == "blood" '
    'and disease in ["healthy", "NA"] '
    'and cell_type == "CD8-positive, alpha-beta T cell"'
)

# Tasks matching the filter (38 in the recorded run).
relevant_tasks = result.query_cells(cd8_blood_query)
len(relevant_tasks)

38

In [27]:
from grelu.transforms.prediction_transforms import Aggregate

In [25]:
from decima.hub import load_decima_model

# Load the Decima model onto the selected device; in the recorded run this
# downloaded the `decima_rep0` and `human_state_dict_fold0` artifacts via wandb.
model = load_decima_model(device=device)

[34m[1mwandb[0m: Downloading large artifact decima_rep0:latest, 2155.88MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:4.8 (444.8MB/s)
[34m[1mwandb[0m: Downloading large artifact human_state_dict_fold0:latest, 709.30MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.8 (399.4MB/s)


In [38]:
# Build a prediction transform restricted to `relevant_tasks`. No aggregation
# function is passed, so the library defaults apply.
# NOTE(review): confirm whether this only subsets tasks or also aggregates.
agg_transform = Aggregate(tasks=relevant_tasks, model=model)
# Attach the transform to the model; with it in place the recorded prediction
# below has one column per entry of `relevant_tasks`.
model.add_transform(agg_transform)

In [39]:
# Score every variant in `dataset` with the transformed model. The raw output
# gets its own name so that `preds` only ever refers to the AnnData built
# below: if that construction fails or the cell is interrupted, later cells
# (`preds.obs`, `preds.var`) are not left pointing at a bare prediction matrix.
pred_matrix = model.predict_on_dataset(dataset, devices=device, batch_size=8, num_workers=16)

# Package the predictions with their annotations: one row per variant (indexed
# by rsid) and one column per selected task (metadata taken from `ad.obs`).
preds = anndata.AnnData(X=pred_matrix, obs=variant_df.set_index("rsid"), var=ad.obs.loc[relevant_tasks])
preds.shape

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

(5, 38)

In [40]:
# Show the AnnData summary (dimensions plus obs/var column names).
preds

AnnData object with n_obs × n_vars = 5 × 38
    obs: 'chrom', 'pos', 'ref', 'alt', 'gene', 'start', 'end', 'strand', 'gene_mask_start', 'rel_pos', 'ref_tx', 'alt_tx', 'tss_dist'
    var: 'cell_type', 'tissue', 'organ', 'disease', 'study', 'dataset', 'region', 'subregion', 'celltype_coarse', 'n_cells', 'total_counts', 'n_genes', 'size_factor', 'train_pearson', 'val_pearson', 'test_pearson'

In [41]:
# Per-variant annotations, indexed by rsid.
preds.obs

Unnamed: 0_level_0,chrom,pos,ref,alt,gene,start,end,strand,gene_mask_start,rel_pos,ref_tx,alt_tx,tss_dist
rsid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
rs146254088,chr1,1000018,G,A,ISG15,837298,1361586,+,163840,162720,G,A,-1120
rs2489000,chr1,1002308,T,C,ISG15,837298,1361586,+,163840,165010,T,C,1170
rs11101994,chr1,109727471,A,C,GSTM3,109380590,109904878,-,163840,177407,T,G,13567
rs4540683,chr1,109728286,T,G,GSTM3,109380590,109904878,-,163840,176592,A,C,12752
rs4970775,chr1,109728807,T,G,GSTM3,109380590,109904878,-,163840,176071,A,C,12231


In [42]:
# First few rows of the per-task metadata copied from `ad.obs`.
preds.var.head()

Unnamed: 0,cell_type,tissue,organ,disease,study,dataset,region,subregion,celltype_coarse,n_cells,total_counts,n_genes,size_factor,train_pearson,val_pearson,test_pearson
agg_843,"CD8-positive, alpha-beta T cell",blood,blood,,GSE128243,scimilarity,,,,5597,10483704.0,13250,28325.40866,0.941655,0.808412,0.813604
agg_844,"CD8-positive, alpha-beta T cell",blood,blood,,GSE132950,scimilarity,,,,1334,7219685.0,13394,32566.675556,0.964465,0.847792,0.84342
agg_845,"CD8-positive, alpha-beta T cell",blood,blood,,GSE135325,scimilarity,,,,261,612912.0,10905,29777.260442,0.936802,0.808541,0.801103
agg_847,"CD8-positive, alpha-beta T cell",blood,blood,,GSE149356,scimilarity,,,,2054,7534250.0,13387,33923.162424,0.951058,0.818623,0.830627
agg_849,"CD8-positive, alpha-beta T cell",blood,blood,,GSE151310,scimilarity,,,,12443,29209159.0,13809,29807.25923,0.965129,0.829639,0.820742


In [43]:
# Shape of the prediction matrix: (number of variants, number of selected tasks).
preds.X.shape

(5, 38)