# Fine-tuning Borzoi to create a Decima model

In [1]:
import glob
import anndata
import scanpy as sc
import pandas as pd
import bioframe as bf
import os

In [2]:
inputdir = "./data"
outdir = "./example"
ad_file_path = os.path.join(inputdir, "data.h5ad")
h5_file_path = os.path.join(outdir, "data.h5")

## 1. Load input anndata file

The input anndata file must have the shape (pseudobulks x genes), i.e. one row per pseudobulk and one column per gene.

In [3]:
ad = sc.read(ad_file_path)
ad

AnnData object with n_obs × n_vars = 50 × 931
    obs: 'cell_type', 'tissue', 'disease', 'study'
    var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset'
    uns: 'log1p'

`.obs` should be a dataframe with a unique index per pseudobulk. You can also include other columns with metadata about the pseudobulks, e.g. cell type, tissue, disease, study, number of cells, total counts. 

Note that the original Decima model does NOT separate pseudobulks by sample, i.e. different samples from the same cell type, tissue, disease and study were merged. We also recommend filtering out pseudobulks with few cells or low read count. 
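As a minimal sketch of the filtering recommendation above, assuming the pseudobulk metadata carries hypothetical `n_cells` and `total_counts` columns (they are not present in the example data) and using arbitrary illustrative thresholds:

```python
import pandas as pd

# Toy pseudobulk metadata; `n_cells` and `total_counts` are hypothetical columns.
obs = pd.DataFrame(
    {
        "cell_type": ["ct_0", "ct_0", "ct_1"],
        "n_cells": [250, 12, 900],
        "total_counts": [2_000_000, 40_000, 8_500_000],
    },
    index=["pseudobulk_0", "pseudobulk_1", "pseudobulk_2"],
)

# Illustrative thresholds only; choose values appropriate for your data.
min_cells = 50
min_counts = 100_000

# Keep pseudobulks that pass both thresholds.
keep = (obs["n_cells"] >= min_cells) & (obs["total_counts"] >= min_counts)
filtered_obs = obs[keep]
print(filtered_obs.index.tolist())
```

With a real AnnData object, the same boolean mask could be used to subset the rows, e.g. `ad = ad[keep.values].copy()`.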

In [4]:
ad.obs.head()

Unnamed: 0,cell_type,tissue,disease,study
pseudobulk_0,ct_0,t_0,d_0,st_0
pseudobulk_1,ct_0,t_0,d_1,st_0
pseudobulk_2,ct_0,t_0,d_2,st_1
pseudobulk_3,ct_0,t_0,d_0,st_1
pseudobulk_4,ct_0,t_0,d_1,st_2


`.var` should be a dataframe with a unique index per gene. The index can be the gene name or Ensembl ID, as long as it is unique. Other essential columns are: chrom, start, end and strand (the gene coordinates).

You can also include other columns with metadata about the genes, e.g. Ensembl ID, type of gene.

In [5]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end,dataset
gene_0,chr1,26846360,27370648,+,27010200,27534488,524288,163840,524288,train
gene_1,chr19,40619897,41144185,-,40456057,40980345,524288,163840,524288,train
gene_2,chr1,79282506,79806794,-,79118666,79642954,524288,163840,524288,train
gene_3,chr8,144568573,145092861,-,144404733,144929021,524288,163840,524288,val
gene_4,chr16,3249848,3774136,-,3086008,3610296,524288,163840,524288,train
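The requirements on `.var` can be checked before going further. The helper below is a hypothetical sketch (`check_var` is not part of Decima) that tests for a unique index, the four essential columns, and plausible coordinates, using the first two rows of the table above as toy input:

```python
import pandas as pd

REQUIRED_COLUMNS = ["chrom", "start", "end", "strand"]

def check_var(var: pd.DataFrame) -> list:
    """Return a list of problems found in a gene metadata table (hypothetical helper)."""
    problems = []
    if not var.index.is_unique:
        problems.append("index is not unique")
    missing = [c for c in REQUIRED_COLUMNS if c not in var.columns]
    if missing:
        problems.append(f"missing columns: {missing}")
        return problems
    if not (var["start"] < var["end"]).all():
        problems.append("some genes have start >= end")
    if not var["strand"].isin(["+", "-"]).all():
        problems.append("strand must be '+' or '-'")
    return problems

var = pd.DataFrame(
    {
        "chrom": ["chr1", "chr19"],
        "start": [26846360, 40619897],
        "end": [27370648, 41144185],
        "strand": ["+", "-"],
    },
    index=["gene_0", "gene_1"],
)
print(check_var(var))
```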


`.X` should contain the total counts per gene and pseudobulk. These should be non-negative integers.

In [6]:
ad.X[:5, :5]

array([[0.       , 7.2824097, 7.2824097, 0.       , 7.2824097],
       [7.3014727, 7.3014727, 0.       , 7.3014727, 7.3014727],
       [7.2867765, 7.2867765, 7.2867765, 7.2867765, 7.2867765],
       [7.283863 , 0.       , 7.283863 , 7.283863 , 7.283863 ],
       [7.3239307, 7.3239307, 0.       , 0.       , 7.3239307]],
      dtype=float32)
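A quick way to test whether a matrix holds raw counts (non-negative integers) is sketched below on toy arrays. The helper name `looks_like_raw_counts` is an assumption made for illustration.

```python
import numpy as np

def looks_like_raw_counts(x) -> bool:
    """True if every value is non-negative and has no fractional part."""
    x = np.asarray(x)
    return bool(np.all(x >= 0) and np.all(np.mod(x, 1) == 0))

raw = np.array([[0, 12, 3], [7, 0, 150]])
logged = np.log1p(raw)  # log-transformed values are no longer integers

print(looks_like_raw_counts(raw))
print(looks_like_raw_counts(logged))
```

If `.X` is stored as a SciPy sparse matrix, the same check could be applied to its `.data` attribute, since the implicit zeros already satisfy the requirement.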

## 2. Normalize and log transform data

We first transform the counts to log(CPM+1) values. CPM = Counts Per Million.

In [7]:
sc.pp.normalize_total(ad, target_sum=1e6)
sc.pp.log1p(ad)
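For reference, the arithmetic behind these two calls can be written out in plain NumPy. This is a sketch on a toy count matrix, assuming the default settings of both functions (every pseudobulk is scaled to a total of one million, followed by a natural-log `log1p`):

```python
import numpy as np

# Toy counts: 2 pseudobulks x 2 genes.
counts = np.array([[1.0, 3.0], [0.0, 10.0]])

# Counts per million: scale each pseudobulk (row) so that it sums to 1e6.
library_size = counts.sum(axis=1, keepdims=True)
cpm = counts / library_size * 1e6

# log(CPM + 1), using the natural logarithm.
log_cpm = np.log1p(cpm)
print(log_cpm)
```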



In [8]:
ad.X[:5, :5]

array([[0.       , 7.2867765, 7.2867765, 0.       , 7.2867765],
       [7.305924 , 7.305924 , 0.       , 7.305924 , 7.305924 ],
       [7.2911625, 7.2911625, 7.2911625, 7.2911625, 7.2911625],
       [7.2896986, 0.       , 7.2896986, 7.2896986, 7.2896986],
       [7.3284836, 7.3284836, 0.       , 0.       , 7.3284836]],
      dtype=float32)

## 3. Create intervals surrounding genes

Decima is trained on 524,288 bp sequences surrounding each gene. We therefore take the given gene coordinates and extend them to create intervals of this length.

In [9]:
from decima.data.preprocess import var_to_intervals



In [10]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end,dataset
gene_0,chr1,26846360,27370648,+,27010200,27534488,524288,163840,524288,train
gene_1,chr19,40619897,41144185,-,40456057,40980345,524288,163840,524288,train
gene_2,chr1,79282506,79806794,-,79118666,79642954,524288,163840,524288,train
gene_3,chr8,144568573,145092861,-,144404733,144929021,524288,163840,524288,val
gene_4,chr16,3249848,3774136,-,3086008,3610296,524288,163840,524288,train


First, we copy the start and end columns to `gene_start` and `gene_end`. We also create a new column `gene_length`. 

In [11]:
ad.var["gene_start"] = ad.var.start.tolist()
ad.var["gene_end"] = ad.var.end.tolist()
ad.var["gene_length"] = ad.var["gene_end"] - ad.var["gene_start"]

In [12]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end,dataset
gene_0,chr1,26846360,27370648,+,26846360,27370648,524288,163840,524288,train
gene_1,chr19,40619897,41144185,-,40619897,41144185,524288,163840,524288,train
gene_2,chr1,79282506,79806794,-,79282506,79806794,524288,163840,524288,train
gene_3,chr8,144568573,145092861,-,144568573,145092861,524288,163840,524288,val
gene_4,chr16,3249848,3774136,-,3249848,3774136,524288,163840,524288,train


Now, we extend the gene coordinates to create enclosing intervals:

In [13]:
ad = var_to_intervals(ad, chr_end_pad=10000, genome="hg38")
# Replace genome name if necessary

The interval size is 524288 bases. Of these, 163840 will be upstream of the gene start and 360448 will be downstream of the gene start.
3 intervals extended beyond the chromosome start and have been shifted
2 intervals extended beyond the chromosome end and have been shifted
5 intervals did not extend far enough upstream of the TSS and have been dropped


In [14]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end,dataset
gene_0,chr1,26682520,27206808,+,26846360,27370648,524288,163840,524288,train
gene_1,chr19,40783737,41308025,-,40619897,41144185,524288,163840,524288,train
gene_2,chr1,79446346,79970634,-,79282506,79806794,524288,163840,524288,train
gene_4,chr16,3413688,3937976,-,3249848,3774136,524288,163840,524288,train
gene_5,chr10,22987161,23511449,+,23151001,23675289,524288,163840,524288,train


The columns `start` and `end` now contain the start and end coordinates of the 524,288 bp intervals, while `gene_start` and `gene_end` keep the original gene coordinates.
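The coordinates in the table can be reproduced with simple strand-aware arithmetic. The function below is a minimal sketch, not the implementation of `var_to_intervals`: it places 163,840 bp upstream of the gene start (taking the gene end as the start for genes on the minus strand) and ignores the shifting and dropping of intervals near chromosome ends. The name `interval_for_gene` is hypothetical.

```python
INTERVAL_LEN = 524_288  # total interval size in bp
UPSTREAM = 163_840      # bp placed upstream of the gene start

def interval_for_gene(gene_start: int, gene_end: int, strand: str) -> tuple:
    """Return (start, end) of the interval enclosing a gene (sketch only)."""
    if strand == "+":
        # Upstream means lower coordinates on the plus strand.
        start = gene_start - UPSTREAM
        return start, start + INTERVAL_LEN
    # On the minus strand the gene begins at gene_end; upstream means higher coordinates.
    end = gene_end + UPSTREAM
    return end - INTERVAL_LEN, end

print(interval_for_gene(26846360, 27370648, "+"))  # gene_0 -> (26682520, 27206808)
print(interval_for_gene(40619897, 41144185, "-"))  # gene_1 -> (40783737, 41308025)
```

Both results match the `start` and `end` values shown for `gene_0` and `gene_1` in the table above.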

## 3. Split genes into training, validation and test sets

We load the coordinates of the genomic regions used to train Borzoi:

In [15]:
splits_file = "https://raw.githubusercontent.com/calico/borzoi/main/data/sequences_human.bed.gz"
# replace human with mouse for mm10 splits
splits = pd.read_table(splits_file, header=None, names=["chrom", "start", "end", "fold"])
splits.head()

Unnamed: 0,chrom,start,end,fold
0,chr4,82524421,82721029,fold0
1,chr13,18604798,18801406,fold0
2,chr2,189923408,190120016,fold0
3,chr10,59875743,60072351,fold0
4,chr1,117109467,117306075,fold0


Now, we overlap our gene intervals with these regions:

In [16]:
overlaps = bf.overlap(ad.var.reset_index(names="gene"), splits, how="left")
overlaps = overlaps[["gene", "fold_"]].drop_duplicates().astype(str)
overlaps.head()

Unnamed: 0,gene,fold_
0,gene_0,fold5
15,gene_1,fold0
30,gene_2,fold0
44,gene_4,fold0
45,gene_4,fold2


Based on the overlap, we divide our gene intervals into training, validation and test sets.

In [17]:
test_genes = overlaps.gene[overlaps.fold_ == "fold3"].tolist()
val_genes = overlaps.gene[overlaps.fold_ == "fold4"].tolist()
train_genes = set(overlaps.gene).difference(set(test_genes).union(val_genes))

And add this information back to `ad.var`.

In [18]:
ad.var["dataset"] = "test"
ad.var.loc[ad.var.index.isin(val_genes), "dataset"] = "val"
ad.var.loc[ad.var.index.isin(train_genes), "dataset"] = "train"
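Because one gene interval can overlap regions from several folds, it helps to see how the logic above resolves such cases. The sketch below replays the same steps on a toy overlap table with made-up gene names: a gene that touches `fold3` or `fold4` is never used for training, and a gene that touches both ends up in the validation set.

```python
import pandas as pd

# Toy overlap table; gene names are made up for illustration.
overlaps = pd.DataFrame(
    {
        "gene": ["g_a", "g_b", "g_b", "g_c", "g_c", "g_d"],
        "fold_": ["fold0", "fold0", "fold3", "fold3", "fold4", "fold4"],
    }
)

test_genes = overlaps.gene[overlaps.fold_ == "fold3"].tolist()
val_genes = overlaps.gene[overlaps.fold_ == "fold4"].tolist()
train_genes = set(overlaps.gene).difference(set(test_genes).union(val_genes))

var = pd.DataFrame(index=["g_a", "g_b", "g_c", "g_d"])
var["dataset"] = "test"  # default label
var.loc[var.index.isin(val_genes), "dataset"] = "val"      # overrides "test"
var.loc[var.index.isin(train_genes), "dataset"] = "train"  # genes touching neither fold
print(var)
```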



In [19]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end,dataset
gene_0,chr1,26682520,27206808,+,26846360,27370648,524288,163840,524288,train
gene_1,chr19,40783737,41308025,-,40619897,41144185,524288,163840,524288,train
gene_2,chr1,79446346,79970634,-,79282506,79806794,524288,163840,524288,train
gene_4,chr16,3413688,3937976,-,3249848,3774136,524288,163840,524288,train
gene_5,chr10,22987161,23511449,+,23151001,23675289,524288,163840,524288,train


In [20]:
ad.var.dataset.value_counts()

dataset
train    769
test      86
val       71
Name: count, dtype: int64

We have now divided the 926 genes remaining in our dataset into separate sets to be used for training, validation and testing.

## 4. Save processed anndata

We will save the processed anndata file containing these intervals and data splits.

In [21]:
ad.write_h5ad(ad_file_path)

## 5. Create an hdf5 file

To train Decima, we need to extract the genomic sequences for all the intervals and convert them to one-hot encoded format. We save these one-hot encoded inputs to an hdf5 file.

In [22]:
from decima.data.write_hdf5 import write_hdf5

In [23]:
write_hdf5(file=h5_file_path, ad=ad, pad=5000, genome="hg38")
# Change genome name if necessary

Writing metadata
Writing task indices
Writing genes array of shape: (926, 2)
Writing labels array of shape: (926, 50, 1)
Making gene masks
Writing mask array of shape: (926, 534288)
Encoding sequences
Writing sequence array of shape: (926, 534288)
Done!
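The sequence length of 534,288 in the log is consistent with the 524,288 bp interval plus `pad=5000` on each side. The sketch below shows this arithmetic and a minimal one-hot encoder for a DNA string. It is an illustration only: the helper `one_hot_encode` is hypothetical, and the storage layout actually used by `write_hdf5` may differ.

```python
import numpy as np

BASES = "ACGT"

def one_hot_encode(seq: str) -> np.ndarray:
    """Encode a DNA string as a (4, length) array; unknown bases such as N stay all-zero."""
    idx = np.array([BASES.find(b) for b in seq.upper()], dtype=np.int64)
    encoded = np.zeros((4, len(seq)), dtype=np.float32)
    known = idx >= 0  # str.find returns -1 for characters outside ACGT
    encoded[idx[known], np.nonzero(known)[0]] = 1.0
    return encoded

interval_len = 524_288
pad = 5_000
padded_len = interval_len + 2 * pad  # padding on both sides of the interval

print(padded_len)
print(one_hot_encode("ACGTN"))
```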


## 6. Set training parameters

In [24]:
# Learning rate default=0.001
lr = 5e-5
# Total weight parameter for the loss function
total_weight = 1e-4
# Gradient accumulation steps
grad = 5
# batch-size. default=4
bs = 4
# max-seq-shift. default=5000
shift = 5000
# Number of epochs. Default 1
epochs = 15

# logger
logger = "wandb"  # Change to csv to save logs locally

# Number of workers default=16
workers = 16

## 7. Generate training commands

In [25]:
cmds = []

for model in range(4):
    name = f"finetune_test_{model}"
    device = model

    cmd = (
        f"decima finetune --name {name} "
        + f"--model {model} --device {device} "
        + f"--matrix-file {ad_file_path} --h5-file {h5_file_path} "
        + f"--outdir {outdir} --learning-rate {lr} "
        + f"--loss-total-weight {total_weight} --gradient-accumulation {grad} "
        + f"--batch-size {bs} --max-seq-shift {shift} "
        + f"--epochs {epochs} --logger {logger} --num-workers {workers}"
    )
    cmds.append(cmd)
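An alternative way to assemble the same command is to build it as a list of arguments and join it with `shlex.join`, which quotes any argument containing spaces or shell metacharacters. This is a sketch that uses only the flags shown above; the helper name `build_finetune_cmd` is hypothetical and the paths are the ones defined at the top of this tutorial.

```python
import shlex

ad_file_path = "./data/data.h5ad"
h5_file_path = "./example/data.h5"
outdir = "./example"

def build_finetune_cmd(model: int) -> str:
    """Build one `decima finetune` command line for a given replicate index."""
    args = [
        "decima", "finetune",
        "--name", f"finetune_test_{model}",
        "--model", str(model),
        "--device", str(model),
        "--matrix-file", ad_file_path,
        "--h5-file", h5_file_path,
        "--outdir", outdir,
        "--learning-rate", str(5e-5),
        "--loss-total-weight", str(1e-4),
        "--gradient-accumulation", str(5),
        "--batch-size", str(4),
        "--max-seq-shift", str(5000),
        "--epochs", str(15),
        "--logger", "wandb",
        "--num-workers", str(16),
    ]
    # shlex.join quotes arguments where needed (e.g. paths with spaces).
    return shlex.join(args)

print(build_finetune_cmd(0))
```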

In [26]:
for cmd in cmds:
    print(cmd)

decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --

Here, we train the model for only 1 epoch, with a batch size of 1, so that the tutorial runs quickly. Use more epochs when training on your own data.

In [27]:
! CUDA_VISIBLE_DEVICES=0 decima finetune \
--name finetune_test_0 \
--model 0 \
--device 0 \
--matrix-file {ad_file_path} \
--h5-file {h5_file_path} \
--outdir {outdir} \
--learning-rate {lr} \
--loss-total-weight {total_weight} \
--gradient-accumulation {grad} \
--batch-size 1 \
--max-seq-shift {shift} \
--epochs 1 \
--logger {logger} \
--num-workers {workers}

decima - INFO - Data paths: matrix_file=./data/data.h5ad, h5_file=./example/data.h5
decima - INFO - Reading anndata
decima - INFO - Making dataset objects
decima - INFO - train_params: {'batch_size': 1, 'num_workers': 16, 'devices': 0, 'logger': 'wandb', 'save_dir': './example', 'max_epochs': 1, 'lr': 5e-05, 'total_weight': 0.0001, 'accumulate_grad_batches': 5, 'loss': 'poisson_multinomial', 'clip': 0.0, 'save_top_k': 1, 'pin_memory': True}
decima - INFO - model_params: {'n_tasks': 50, 'init_borzoi': True, 'replicate': '0'}
decima - INFO - Initializing model
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
wandb: Currently logged in as: anony-mouse-591272909468377997 to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:01.6 (439.0MB/s)
d

In [28]:
# Uncomment if necessary
# import wandb
# wandb.login(host="https://genentech.wandb.io", anonymous="never", relogin=True)

## 8. Make and evaluate predictions using trained models

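The commands generated above train four model replicates. In this tutorial we ran only the first, for a single epoch, so below we pass its checkpoint twice to illustrate how multiple checkpoints are supplied. We can now use the trained model to predict gene expression: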

In [29]:
checkpoint = glob.glob(os.path.join(outdir, "lightning_logs/*/checkpoints/*.ckpt"))[0]
print(checkpoint)

./example/lightning_logs/bj42z19b/checkpoints/epoch=0-step=154.ckpt
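The `step=154` in the checkpoint name is consistent with the settings of the short training run. The arithmetic below is a sketch under two assumptions that the logs do not confirm directly: that each of the 769 training genes is seen once per epoch, and that the step counter counts optimizer updates rather than batches.

```python
import math

n_train_genes = 769          # from ad.var.dataset.value_counts()
batch_size = 1               # --batch-size in the short run
accumulate_grad_batches = 5  # --gradient-accumulation

batches_per_epoch = math.ceil(n_train_genes / batch_size)
# One optimizer update per `accumulate_grad_batches` batches (the last group may be partial).
optimizer_steps = math.ceil(batches_per_epoch / accumulate_grad_batches)
# Number of genes contributing to each optimizer update.
effective_batch_size = batch_size * accumulate_grad_batches

print(batches_per_epoch, optimizer_steps, effective_batch_size)
```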


In [30]:
# comma-separated list of model checkpoints
checkpoint_list = ",".join([checkpoint, checkpoint])
checkpoint_list

'./example/lightning_logs/bj42z19b/checkpoints/epoch=0-step=154.ckpt,./example/lightning_logs/bj42z19b/checkpoints/epoch=0-step=154.ckpt'
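When several replicates have been trained, their checkpoints can be collected with the same glob pattern instead of repeating a single file. The sketch below demonstrates this in a temporary directory with made-up run names; `collect_checkpoints` is a hypothetical helper.

```python
import glob
import os
import tempfile

def collect_checkpoints(outdir: str) -> list:
    """Return all checkpoint paths under outdir/lightning_logs, sorted for reproducibility."""
    pattern = os.path.join(outdir, "lightning_logs", "*", "checkpoints", "*.ckpt")
    return sorted(glob.glob(pattern))

# Demonstrate on a temporary directory that mimics two training runs.
with tempfile.TemporaryDirectory() as tmp:
    for run in ["run_a", "run_b"]:  # made-up run names
        ckpt_dir = os.path.join(tmp, "lightning_logs", run, "checkpoints")
        os.makedirs(ckpt_dir)
        open(os.path.join(ckpt_dir, "epoch=0-step=154.ckpt"), "w").close()
    ckpts = collect_checkpoints(tmp)
    checkpoint_list = ",".join(ckpts)

print(len(ckpts))
```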

In [31]:
! CUDA_VISIBLE_DEVICES=0 decima predict-genes \
--output example/test_preds.h5ad \
--model {checkpoint_list} \
--metadata {ad_file_path} \
--device 0 \
--batch-size 8 \
--num-workers 32 \
--max_seq_shift 0 \
--genome hg38 \
--save-replicates

decima - INFO - Using device: cuda:0 and genome: hg38 for prediction.
decima - INFO - Making predictions
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
wandb: Currently logged in as: anony-mouse-591272909468377997 to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:00.5 (1298.4MB/s)
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:00.5 (1349.1MB/s)
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
[34m[1mwa

We can open the output h5ad file to see the individual predictions and metrics.

In [32]:
ad_out = anndata.read_h5ad("example/test_preds.h5ad")

In [33]:
ad_out

AnnData object with n_obs × n_vars = 50 × 926
    obs: 'cell_type', 'tissue', 'disease', 'study', 'size_factor', 'train_pearson', 'val_pearson', 'test_pearson'
    var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset', 'pearson', 'size_factor_pearson'
    layers: 'preds', 'preds_finetune_test_0'

Layers named `preds_<model name>` (here `.layers['preds_finetune_test_0']`) contain the predictions made by the individual models, whereas `.layers['preds']` contains the average predictions. You will see that performance metrics have been added to both `.obs` and `.var`.

In [34]:
ad_out.obs.head()

Unnamed: 0,cell_type,tissue,disease,study,size_factor,train_pearson,val_pearson,test_pearson
pseudobulk_0,ct_0,t_0,d_0,st_0,4976.871582,-0.01608,0.038915,0.077958
pseudobulk_1,ct_0,t_0,d_1,st_0,4887.680664,-0.026668,0.149208,-0.025352
pseudobulk_2,ct_0,t_0,d_2,st_1,4950.70459,-0.044364,0.008472,0.115734
pseudobulk_3,ct_0,t_0,d_0,st_1,4949.709961,0.021054,-0.091807,0.048898
pseudobulk_4,ct_0,t_0,d_1,st_2,4792.810547,-0.046708,-0.071214,-0.104898


In [35]:
ad_out.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end,dataset,pearson,size_factor_pearson
gene_0,chr1,26682520,27206808,+,26846360,27370648,524288,163840,524288,train,0.30628,-0.059291
gene_1,chr19,40783737,41308025,-,40619897,41144185,524288,163840,524288,train,0.014492,-0.035897
gene_2,chr1,79446346,79970634,-,79282506,79806794,524288,163840,524288,train,0.182172,0.226918
gene_4,chr16,3413688,3937976,-,3249848,3774136,524288,163840,524288,train,0.098095,-0.032441
gene_5,chr10,22987161,23511449,+,23151001,23675289,524288,163840,524288,train,0.016748,-0.059998
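To make the two kinds of metric concrete, the sketch below averages two toy replicate predictions and computes Pearson correlations along each axis: per gene (across pseudobulks) and per pseudobulk (across genes). Whether Decima computes its `pearson` columns in exactly this way is an assumption; the data here are random and the helper `pearson_along` is hypothetical.

```python
import numpy as np

def pearson_along(a: np.ndarray, b: np.ndarray, axis: int) -> np.ndarray:
    """Pearson correlation between a and b along the given axis."""
    a = a - a.mean(axis=axis, keepdims=True)
    b = b - b.mean(axis=axis, keepdims=True)
    num = (a * b).sum(axis=axis)
    den = np.sqrt((a ** 2).sum(axis=axis) * (b ** 2).sum(axis=axis))
    return num / den

rng = np.random.default_rng(0)
truth = rng.normal(size=(50, 20))  # toy matrix: pseudobulks x genes
preds_rep0 = truth + rng.normal(scale=0.5, size=truth.shape)
preds_rep1 = truth + rng.normal(scale=0.5, size=truth.shape)

# Average the replicate predictions.
preds = np.mean([preds_rep0, preds_rep1], axis=0)

per_gene = pearson_along(truth, preds, axis=0)        # one value per gene
per_pseudobulk = pearson_along(truth, preds, axis=1)  # one value per pseudobulk
print(per_gene.shape, per_pseudobulk.shape)
```

In the real output, the per-pseudobulk values in `.obs` are reported separately for the training, validation and test genes, which would correspond to restricting the columns to each split before computing the correlation.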
