# Predicting Gene Expression with Decima

Decima allows prediction of gene expression at the cell type level, and this tutorial demonstrates how to leverage the prediction API for both genes in the training data and custom genes.

### Precomputed Predictions

Predictions for all genes in the training data are precomputed for each model replicate, stored in the metadata `h5ad` object, and exposed through the `DecimaResult` class. Its `predicted_expression_matrix` method returns the predicted gene expression averaged across the replicates.

In [1]:
from decima import DecimaResult

result = DecimaResult.load()
result.predicted_expression_matrix()

wandb: Currently logged in as: mhcelik (mhcw) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Downloading large artifact metadata:latest, 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:1.6 (1995.7MB/s)


Unnamed: 0,STRADA,ETV4,USP25,ZSWIM5,C21orf58,MIR497HG,CFAP74,GSE1,LPP,CLK1,...,STRIP2,TNFRSF1A,RBM14-RBM4,C1orf21,LINC00639,NPDC1,ZNF425,COL5A1,BRD3,EVI5L
agg_0,2.973438,1.845565,4.592531,5.099802,1.774879,0.356812,2.590836,4.629774,4.897171,3.326940,...,2.836060,0.297015,1.883849,4.293593,1.463565,3.183534,2.340202,2.374942,2.911916,3.230072
agg_1,2.954213,1.896726,4.688557,5.510440,1.666929,0.352725,2.292625,4.459535,4.915286,3.192858,...,3.125704,0.242543,1.908177,4.439424,1.236739,3.494824,2.425672,2.054568,2.713408,3.491463
agg_2,2.938851,2.197247,4.861410,5.617520,1.773381,0.380867,2.394917,4.415038,4.836399,3.390717,...,3.082098,0.263285,2.006456,4.383455,1.208590,4.013819,2.408381,2.297343,2.892222,3.695785
agg_3,3.045972,2.138573,4.863791,5.273604,1.760097,0.463555,2.391702,3.940975,4.857763,3.410926,...,2.882890,0.290327,1.922963,4.550189,1.430520,3.693118,2.297103,2.121887,2.626117,3.223912
agg_4,3.025518,2.019096,4.602948,5.257001,1.755338,0.382190,2.432810,4.392480,4.959488,3.250500,...,3.082296,0.258540,2.038277,4.464807,1.249043,3.665800,2.400820,2.255862,2.925619,3.471005
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
agg_9533,2.333562,0.633322,4.675825,2.793023,0.752030,0.692083,0.503531,4.327948,6.903193,3.695593,...,0.549795,2.270181,1.563218,4.395422,0.550088,1.330252,1.044471,3.759369,2.491346,1.872717
agg_9535,0.835037,0.358773,1.964896,0.307449,0.337240,0.834196,0.093885,1.853794,3.700790,4.467302,...,0.176885,1.370898,1.022708,3.400267,0.052162,1.908870,0.253417,1.448111,1.622033,1.064292
agg_9536,3.008039,1.209324,4.798392,3.931870,1.401328,1.638555,0.969720,4.779201,6.631931,4.127797,...,1.174298,1.870530,2.506874,5.151776,0.967644,1.809947,2.205356,4.244005,2.974467,2.659873
agg_9537,1.241936,0.455059,2.919995,0.571672,0.486448,1.175586,0.145397,2.412148,4.759118,4.913945,...,0.371035,1.361073,1.668085,4.005738,0.078611,1.571750,0.508187,2.067150,2.323764,1.429850


To access the predicted expression matrix of a specific model, use the `model_name` parameter. In this example, we obtain the predicted gene expression for the first model replicate.

In [2]:
result.predicted_expression_matrix(model_name="v1_rep0")

Unnamed: 0,STRADA,ETV4,USP25,ZSWIM5,C21orf58,MIR497HG,CFAP74,GSE1,LPP,CLK1,...,STRIP2,TNFRSF1A,RBM14-RBM4,C1orf21,LINC00639,NPDC1,ZNF425,COL5A1,BRD3,EVI5L
agg_0,2.932758,2.020476,4.795636,5.142194,1.721651,0.403188,3.207011,4.647735,4.823714,3.241457,...,1.804592,0.392356,1.923160,4.256888,1.837033,3.382362,2.358378,2.471895,2.808940,3.232095
agg_1,2.816624,1.938934,4.917684,5.402302,1.319163,0.393221,2.809423,4.842717,4.939364,2.980938,...,1.753353,0.307677,2.096857,4.221365,1.679307,3.917950,2.459049,2.356942,2.696189,3.447138
agg_2,2.742950,2.394729,4.805934,5.102471,1.731927,0.452663,2.778222,4.538105,4.443571,3.041557,...,1.804225,0.364422,2.145009,4.004273,1.471316,4.249442,2.204638,2.687080,2.793549,3.760900
agg_3,2.804515,2.358393,4.676318,5.462350,1.543962,0.483929,2.695394,4.411092,4.355472,3.220853,...,1.663766,0.429275,2.065255,4.045637,1.837513,3.879115,2.242636,2.191440,2.614539,3.287658
agg_4,2.815917,2.098659,4.594968,4.648893,1.517963,0.426446,3.173781,4.829787,4.389949,3.011941,...,2.119307,0.320393,2.064416,3.868720,1.499770,4.144956,2.074824,2.375941,2.745383,3.339186
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
agg_9533,2.510139,0.811121,5.367220,2.944120,0.668071,0.898573,0.737285,4.750776,7.324188,3.585325,...,0.258857,1.437364,1.913985,4.471622,0.823436,1.486110,1.288418,3.211422,2.712792,1.665431
agg_9535,1.245692,0.397959,2.112123,0.537773,0.231708,0.898190,0.114970,2.234345,4.391264,4.418188,...,0.132106,0.980370,1.672495,3.558780,0.072055,2.029817,0.418454,1.192156,1.853873,1.099853
agg_9536,2.808824,1.503026,4.660236,3.781434,1.292014,1.926411,1.174456,4.833678,6.598332,4.183246,...,0.833774,1.562745,2.589504,5.134478,1.171923,1.833690,2.468382,4.154689,2.939265,2.304627
agg_9537,1.579961,0.518564,2.777915,0.865460,0.373625,1.167617,0.173764,2.826529,4.751182,4.572947,...,0.219681,0.967018,2.106584,4.237208,0.087220,1.733176,0.613084,1.758543,2.333462,1.247311


The same call works for the second model replicate:

In [3]:
result.predicted_expression_matrix(model_name="v1_rep1")

Unnamed: 0,STRADA,ETV4,USP25,ZSWIM5,C21orf58,MIR497HG,CFAP74,GSE1,LPP,CLK1,...,STRIP2,TNFRSF1A,RBM14-RBM4,C1orf21,LINC00639,NPDC1,ZNF425,COL5A1,BRD3,EVI5L
agg_0,2.892922,1.268846,4.257362,5.062403,1.708644,0.509606,2.129140,5.102046,3.987716,3.229576,...,3.279328,0.496723,2.626515,3.294918,1.754605,3.018026,2.396543,1.611294,2.942932,3.334204
agg_1,2.812950,1.338056,4.490562,5.533390,1.759490,0.507474,1.839018,4.959751,4.505932,3.428190,...,3.428484,0.463658,2.532722,4.212840,1.994184,3.312843,2.755963,1.515888,2.539843,3.585655
agg_2,2.971495,1.536687,4.427551,5.517713,1.655871,0.504365,1.900368,4.990572,4.717015,3.675014,...,3.258069,0.498031,2.830813,3.899091,1.731513,3.000879,2.582952,1.570875,2.792023,3.834818
agg_3,3.346772,1.558132,4.612154,4.838746,1.690119,0.695964,2.075206,3.900921,4.603170,3.397104,...,3.024333,0.464141,2.634533,4.260577,1.873491,3.353543,2.285160,1.407212,2.382554,3.494318
agg_4,3.088684,1.522742,4.351163,4.806933,1.668947,0.553315,1.892262,4.495190,4.887375,3.282181,...,3.286668,0.475796,2.699739,4.285741,1.986780,3.199118,2.432456,1.463589,2.867734,3.607585
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
agg_9533,2.760382,0.603564,4.133419,2.195248,1.040137,0.813743,0.439841,4.018094,6.189613,3.845854,...,0.460953,2.616969,2.051131,3.720593,0.373556,1.720908,0.799354,2.626200,2.094780,2.103047
agg_9535,1.192427,0.355300,1.690948,0.294813,0.529883,1.071292,0.109688,1.716427,2.887877,4.564609,...,0.117188,1.749936,1.504746,2.753326,0.032030,2.539737,0.283777,1.469010,1.537342,1.323113
agg_9536,3.369931,0.922407,4.246971,3.432078,1.621070,1.825625,0.817042,4.820827,6.642282,4.045234,...,0.992599,1.805059,3.023350,4.650090,0.711800,2.415006,1.855187,3.172574,2.799354,2.962450
agg_9537,1.533311,0.472664,2.386347,0.445858,0.655848,1.248309,0.157859,1.858983,3.740174,5.334047,...,0.221859,1.986270,2.305407,3.003855,0.045920,1.986692,0.455313,1.756422,2.054446,1.544962


The predictions are stored as layers of the underlying AnnData object:

In [4]:
result.anndata.layers

Layers with keys: preds, v1_rep0, v1_rep1, v1_rep2, v1_rep3
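
The layer names above suggest how the averaged matrix relates to the replicates. As a minimal sketch, assuming the averaged prediction is the element-wise mean of the replicate layers (this is not checked against Decima's source), the averaging can be illustrated with toy numbers. The matrices and the `average_layers` helper are hypothetical and not part of Decima:

```python
from statistics import fmean


def average_layers(layers):
    """Element-wise mean of equally shaped 2D matrices (lists of rows)."""
    matrices = list(layers.values())
    n_rows, n_cols = len(matrices[0]), len(matrices[0][0])
    return [
        [fmean(m[i][j] for m in matrices) for j in range(n_cols)]
        for i in range(n_rows)
    ]


# Toy replicate layers: 2 cell-type aggregates (rows) x 2 genes (columns).
# The numbers are made up for illustration only.
toy_layers = {
    "v1_rep0": [[1.0, 2.0], [3.0, 4.0]],
    "v1_rep1": [[2.0, 2.0], [3.0, 6.0]],
    "v1_rep2": [[3.0, 2.0], [3.0, 8.0]],
    "v1_rep3": [[2.0, 2.0], [3.0, 6.0]],
}

ensemble = average_layers(toy_layers)
print(ensemble)
```

With real data, the same comparison could be made between the replicate layers and the matrix returned by `predicted_expression_matrix()`.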

### CLI API

To recompute predictions instead of using the precomputed scores, use the Decima command-line interface (CLI). The `decima predict-genes` command accepts the following options:

- `--genes`: a comma-separated list of gene names (such as "STRADA,ETV4,USP25"). If no genes are provided, expression is predicted for all genes.
- `--model`: the prediction model (for instance, "ensemble" or a specific replicate like "0").
- `--save-replicates`: save the predictions of each replicate.
- `-o`: the output file path for the predictions, in `.h5ad` format.

In [5]:
! decima predict-genes --genes "STRADA,ETV4,USP25" --model ensemble --save-replicates -o example/predict_genes/predictions.h5ad 

decima - INFO - Using device: cuda and genome: hg38 for prediction.
wandb: Currently logged in as: mhcelik (mhcw) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Downloading large artifact rep0:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.5 (1433.8MB/s)
wandb: Downloading large artifact rep1:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.5 (1374.8MB/s)
wandb: Downloading large artifact rep2:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.7 (1016.5MB/s)
wandb: Downloading large artifact rep3:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.5 (1371.2MB/s)
wandb: Downloading large artifact metadata:latest, 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.


After running this command, you can load the resulting predictions in Python using the `DecimaResult` class and access the predicted expression matrix as:

In [6]:
result = DecimaResult.load("example/predict_genes/predictions.h5ad")
result.predicted_expression_matrix()

Unnamed: 0,STRADA,ETV4,USP25
agg_0,2.973423,1.845874,4.592660
agg_1,2.954493,1.897248,4.688962
agg_2,2.938927,2.197576,4.861925
agg_3,3.046052,2.138910,4.864000
agg_4,3.025393,2.019534,4.602815
...,...,...,...
agg_9533,2.334128,0.633340,4.675878
agg_9535,0.835045,0.358672,1.964816
agg_9536,3.008091,1.209339,4.798233
agg_9537,1.241775,0.455031,2.919700
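
The recomputed values are close to the precomputed ones shown at the start of this tutorial, but not identical. The tutorial does not state why; small numerical differences between runs are a plausible explanation, which is an assumption here. The sketch below compares the `agg_0` row of both matrices using values copied from the outputs in this tutorial. The tolerance is an arbitrary illustrative choice, not a Decima setting:

```python
import math

genes = ["STRADA", "ETV4", "USP25"]
# agg_0 values copied from the precomputed matrix (cell 1).
precomputed = [2.973438, 1.845565, 4.592531]
# agg_0 values copied from the freshly generated predictions (cell 6).
recomputed = [2.973423, 1.845874, 4.592660]

abs_diffs = {g: abs(p - r) for g, p, r in zip(genes, precomputed, recomputed)}
max_diff = max(abs_diffs.values())

# Arbitrary tolerance chosen for this illustration only.
TOLERANCE = 1e-3
all_close = all(
    math.isclose(p, r, abs_tol=TOLERANCE)
    for p, r in zip(precomputed, recomputed)
)
print(f"max |diff| = {max_diff:.6f}, within tolerance: {all_close}")
```

A check like this can serve as a quick sanity test after rerunning predictions on different hardware.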


Or, for a specific replicate (note the `preds_` prefix in `model_name` for this file):

In [7]:
result.predicted_expression_matrix(model_name="preds_v1_rep0")

Unnamed: 0,STRADA,ETV4,USP25
agg_0,2.932055,2.022040,4.795146
agg_1,2.816373,1.940856,4.917033
agg_2,2.742780,2.396595,4.806089
agg_3,2.804382,2.360139,4.676453
agg_4,2.815311,2.100366,4.594069
...,...,...,...
agg_9533,2.510625,0.811365,5.366706
agg_9535,1.245751,0.397869,2.111997
agg_9536,2.808623,1.503265,4.659208
agg_9537,1.580037,0.518710,2.777559
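
When the CLI is driven from a script, the command shown in this section can be assembled as an argument list. This is a sketch that uses only the flags demonstrated above; `build_predict_genes_cmd` is a hypothetical helper, not part of Decima, and it always passes `--model` explicitly because the CLI default is not stated here:

```python
import shlex


def build_predict_genes_cmd(output, genes=None, model="ensemble", save_replicates=False):
    """Assemble the argument list for `decima predict-genes`."""
    cmd = ["decima", "predict-genes"]
    if genes:
        # Leaving out --genes requests predictions for all genes.
        cmd += ["--genes", ",".join(genes)]
    cmd += ["--model", model]
    if save_replicates:
        cmd.append("--save-replicates")
    cmd += ["-o", output]
    return cmd


cmd = build_predict_genes_cmd(
    "example/predict_genes/predictions.h5ad",
    genes=["STRADA", "ETV4", "USP25"],
    save_replicates=True,
)
print(shlex.join(cmd))
# To execute it (requires Decima to be installed):
#   import subprocess; subprocess.run(cmd, check=True)
```

Passing a list to `subprocess.run` avoids shell quoting issues with gene names or paths.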


### Python API

The same functionality is available through the Python API, allowing you to perform gene expression prediction programmatically. You can specify the genes, model, and other options directly in your Python code using the provided classes and functions.

In [8]:
from decima.tools.inference import predict_gene_expression

ad = predict_gene_expression(
    genes=["STRADA", "ETV4", "USP25"],
    model="ensemble",
)

wandb: Downloading large artifact rep0:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.5 (1404.0MB/s)
wandb: Downloading large artifact rep1:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.6 (1277.2MB/s)
wandb: Downloading large artifact rep2:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.6 (1157.4MB/s)
wandb: Downloading large artifact rep3:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.5 (1321.9MB/s)
wandb: Downloading large artifact metadata:latest, 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:1.6 (1932.7MB/s)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU avai

Predicting: |          | 0/? [00:00<?, ?it/s]



### Developer API

Under the hood, the Decima prediction API uses the `GeneDataset` and `SeqDataset` PyTorch `Dataset` classes to prepare the data for prediction. These classes provide a flexible way to handle different types of input, including custom genes and DNA sequences. Internally, the datasets represent sequences with one-hot encoding and apply a gene mask that marks which positions belong to the gene region.
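
To make the encoding concrete, here is a minimal sketch of one-hot encoding with an added gene-mask channel. The channel order (A, C, G, T), the five-row layout, and the inclusive mask bounds are assumptions made for illustration; the tensors that Decima builds internally may be laid out differently. `one_hot_with_mask` is a hypothetical helper:

```python
BASES = "ACGT"  # channel order is an assumption for this sketch


def one_hot_with_mask(seq, mask_start, mask_end):
    """Return 5 rows: one-hot A, C, G, T channels plus a gene-mask channel.

    mask_start and mask_end are 0-based and inclusive here (an assumption).
    Bases other than A/C/G/T (for example N) are encoded as all zeros.
    """
    seq = seq.upper()
    channels = [[1.0 if base == b else 0.0 for base in seq] for b in BASES]
    mask = [1.0 if mask_start <= i <= mask_end else 0.0 for i in range(len(seq))]
    return channels + [mask]


encoded = one_hot_with_mask("ACGTNA", mask_start=2, mask_end=4)
for name, row in zip(list(BASES) + ["mask"], encoded):
    print(f"{name:>4}: {row}")
```

The mask row is 1 only at positions 2 to 4, and the `N` at position 4 has zeros in all four base channels.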

For example, you can create a `GeneDataset` object to predict expression for genes in your metadata. The `predict_on_dataset` method returns a dictionary containing the predicted expression values and any warnings raised during prediction.

In [9]:
from pprint import pprint
from decima.data.dataset import GeneDataset
from decima.hub import load_decima_model

model = load_decima_model("rep0", device=0)
ds = GeneDataset(genes=["STRADA", "ETV4", "USP25"])

preds = model.predict_on_dataset(ds, devices=0)
pprint(preds)

wandb: Downloading large artifact rep0:latest, 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.5 (1333.5MB/s)
wandb: Downloading large artifact metadata:latest, 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:1.6 (2010.7MB/s)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

{'expression': array([[2.9320552 , 2.8163733 , 2.7427804 , ..., 2.808623  , 1.5800365 ,
        2.0138068 ],
       [2.0220397 , 1.9408565 , 2.3965948 , ..., 1.5032653 , 0.51870984,
        0.80854166],
       [4.7951455 , 4.9170327 , 4.806089  , ..., 4.659208  , 2.7775593 ,
        3.1479335 ]], dtype=float32),
                      'allele_mismatch_with_reference_genome': tensor(0)})}
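
The `expression` array printed above has one row per gene and one column per cell-type aggregate, which is the transpose of the matrix returned by `predicted_expression_matrix`. The sketch below relabels the first few printed values. It assumes that rows follow the order of the genes passed to `GeneDataset` and that columns start at `agg_0`; this agrees with the `preds_v1_rep0` values shown earlier, but it is an inference from the outputs rather than documented behavior:

```python
genes = ["STRADA", "ETV4", "USP25"]
aggregates = ["agg_0", "agg_1", "agg_2"]

# First three values of each row of preds["expression"] as printed above.
expression = [
    [2.9320552, 2.8163733, 2.7427804],
    [2.0220397, 1.9408565, 2.3965948],
    [4.7951455, 4.9170327, 4.8060890],
]

# Rows are genes and columns are aggregates, so transpose to get one
# record per aggregate, matching the layout of predicted_expression_matrix().
by_aggregate = {
    agg: dict(zip(genes, column))
    for agg, column in zip(aggregates, zip(*expression))
}
print(by_aggregate["agg_0"])
```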




### Predicting Expression for Custom Genes

If you have custom genes, you can create a `SeqDataset` object to predict expression for those genes. 

To do this, prepare a FASTA file where:
  - Each sequence is exactly the Decima context size (524,288 bases).
  - The FASTA header for each sequence must include the gene name and the gene mask coordinates, using the format:
        `>gene_name|gene_mask_start=X|gene_mask_end=Y`
    where `X` and `Y` specify the start and end positions (0-based, inclusive) of the gene region within the sequence. The gene mask indicates which region of the sequence corresponds to the gene for which expression will be predicted.
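
The two rules above can be checked before handing a file to Decima. The following sketch parses a header in the stated format and validates the sequence length and mask bounds. `parse_header` and `check_record` are hypothetical helpers written for this tutorial, not Decima functions, and the placeholder sequence of `N` bases stands in for real genomic DNA:

```python
DECIMA_CONTEXT_SIZE = 524_288  # context size stated above


def parse_header(header):
    """Split '>name|gene_mask_start=X|gene_mask_end=Y' into its fields."""
    name, *fields = header.lstrip(">").strip().split("|")
    values = dict(field.split("=", 1) for field in fields)
    return name, int(values["gene_mask_start"]), int(values["gene_mask_end"])


def check_record(header, seq):
    """Raise ValueError if a record breaks the rules listed above."""
    name, start, end = parse_header(header)
    if len(seq) != DECIMA_CONTEXT_SIZE:
        raise ValueError(
            f"{name}: expected {DECIMA_CONTEXT_SIZE} bases, got {len(seq)}"
        )
    if not 0 <= start <= end < len(seq):
        raise ValueError(f"{name}: gene mask {start}-{end} is out of range")
    return name, start, end


# Placeholder sequence of the right length; a real record holds genomic DNA.
record = check_record(
    ">CD68|gene_mask_start=163840|gene_mask_end=166460",
    "N" * DECIMA_CONTEXT_SIZE,
)
print(record)
```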

For example, `seqs.fasta` contains the following records (only the first 200 characters of each line are shown):

In [10]:
! cat ../tests/data/seqs.fasta | cut -c 1-200

>CD68|gene_mask_start=163840|gene_mask_end=166460
CTCTGCAGAGAGCGAGGACGGTGTGTCTGCCAGCGCCTTTGACTTCACTGTCTCCAACTTTGTGGACAACCTGTATGGCTACCCGGAAGGCAAGGATGTGCTTCGGGAGACCATCAAGTTTATGTACACAGACTGGGCCGACCGGGACAATGGCGAAATGCGCCGCAAAACCCTGCTGGCGCTCTTTACTGACCACCAAT
>SPI1|gene_mask_start=163840|gene_mask_end=187556
TGCCACTTTTAGATATGTTCATGGGTGCAGATACGGCTTTATTTATTTGAGACAGAGTTTCACTCTTGTTGCCCAGGGTGGAGTGCAGTGGTGCGATCTCAGCTCACTGCAGCCTTCGCCTCCCGGGTTGAAGCGATTCTTCTGCCTCAACCTCGAGTAGCTGGGATTATAGGCACCTGCCAGCATGCCTGGCTAATTTT


In [11]:
from decima.data.dataset import SeqDataset


ds = SeqDataset.from_fasta("../tests/data/seqs.fasta")

preds = model.predict_on_dataset(ds, devices=0)
pprint(preds["expression"])

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

array([[0.22032928, 0.21946219, 0.230885  , ..., 1.0364848 , 1.2134479 ,
        1.4428278 ],
       [0.2652336 , 0.13350041, 0.14896178, ..., 0.44437772, 0.26560807,
        0.30437455]], dtype=float32)


See the documentation of `SeqDataset` for more details. A `SeqDataset` can also be created with `SeqDataset.from_dataframe` from a pandas DataFrame with the columns `seq`, `gene_mask_start`, `gene_mask_end`, and `gene_name`, or with `SeqDataset.from_one_hot` from a one-hot encoded tensor.
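
The DataFrame columns carry the same information as the FASTA header format described earlier. As a rough sketch of that correspondence, the code below builds one row with the four columns and renders it as a FASTA record. The gene name `my_gene`, the repeated `ACGT` sequence, the mask coordinates, and both helper functions are hypothetical. The exact signature of `SeqDataset.from_dataframe` is not shown in this tutorial, so the Decima call itself is left out:

```python
CONTEXT_SIZE = 524_288  # Decima context size stated in this tutorial
COLUMNS = ("seq", "gene_mask_start", "gene_mask_end", "gene_name")


def make_row(gene_name, seq, gene_mask_start, gene_mask_end):
    """Build one row holding the four columns named above."""
    return {
        "seq": seq,
        "gene_mask_start": gene_mask_start,
        "gene_mask_end": gene_mask_end,
        "gene_name": gene_name,
    }


def to_fasta(rows):
    """Render rows as FASTA text using the header format described earlier."""
    return "".join(
        f">{r['gene_name']}|gene_mask_start={r['gene_mask_start']}"
        f"|gene_mask_end={r['gene_mask_end']}\n{r['seq']}\n"
        for r in rows
    )


# Made-up sequence of exactly CONTEXT_SIZE bases, for illustration only.
rows = [make_row("my_gene", "ACGT" * (CONTEXT_SIZE // 4), 163_840, 170_000)]
# pandas.DataFrame(rows) would give a frame with exactly these four columns.
fasta_text = to_fasta(rows)
print(fasta_text.splitlines()[0])
```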