# Fine-tuning Borzoi to create a Decima model

In [1]:
import scanpy as sc
import pandas as pd
import bioframe as bf
import os

In [2]:
# Directory used for the files written by this notebook; also passed to
# `decima finetune --outdir` further down.
outdir = "."
# Processed AnnData (.h5ad) and hdf5 (.h5) outputs, both placed inside `outdir`.
ad_file_path, h5_file_path = (
    os.path.join(outdir, file_name) for file_name in ("data.h5ad", "data.h5")
)

## 1. Load input anndata file

The input anndata file must have shape (pseudobulks × genes): one row per pseudobulk and one column per gene.

In [3]:
# Read the example pseudobulk dataset; point this path at your own .h5ad file.
ad = sc.read("data/test_data.h5ad")
# Show the AnnData summary as the cell output.
ad

AnnData object with n_obs × n_vars = 50 × 1000
    obs: 'cell_type', 'tissue', 'disease', 'study'
    var: 'chrom', 'start', 'end', 'strand'

`.obs` should be a dataframe with a unique index per pseudobulk. You can also include other columns with metadata about the pseudobulks, e.g. cell type, tissue, disease, study, number of cells, total counts.

In [4]:
ad.obs.head()

Unnamed: 0,cell_type,tissue,disease,study
pseudobulk_0,ct_0,t_0,d_0,st_0
pseudobulk_1,ct_0,t_0,d_1,st_0
pseudobulk_2,ct_0,t_0,d_2,st_1
pseudobulk_3,ct_0,t_0,d_0,st_1
pseudobulk_4,ct_0,t_0,d_1,st_2


`.var` should be a dataframe with a unique index per gene. The index can be the gene name or Ensembl ID, as long as it is unique. Other essential columns are: chrom, start, end and strand (the gene coordinates).

You can also include other columns with metadata about the genes, e.g. Ensembl ID, type of gene.

In [5]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand
gene_0,chr1,28648600,28648730,+
gene_1,chr19,39341773,39341945,-
gene_2,chr1,78004346,78004554,-
gene_3,chr8,143290399,143290621,-
gene_4,chr16,1971655,1971896,-


`.X` should contain the total counts per gene and pseudobulk. These should be non-negative integers.

In [6]:
ad.X[:5, :5]

array([[ 0, 36, 82,  0, 53],
       [29, 84,  0, 33, 27],
       [12, 33, 24, 60, 57],
       [32,  0, 51, 77, 42],
       [37,  2,  0,  0, 80]])

## 2. Normalize and log transform data

We first transform the counts to log(CPM+1) values. CPM = Counts Per Million.

In [7]:
# Scale every pseudobulk so that its counts sum to one million (CPM). `ad` is updated in place.
sc.pp.normalize_total(ad, target_sum=1e6)
# Apply log(x + 1) to the CPM values. `ad` is updated in place, so re-running this cell
# on the same `ad` would transform the data a second time; reload `ad` before re-running.
sc.pp.log1p(ad) 

In [8]:
ad.X[:5, :5]

array([[0.       , 6.921574 , 7.7442207, 0.       , 7.3080306],
       [6.6934667, 7.756176 , 0.       , 6.822528 , 6.6220994],
       [5.8283887, 6.838115 , 6.5200634, 7.4354696, 7.3842077],
       [6.832712 , 0.       , 7.2984004, 7.7101517, 7.104389 ],
       [6.996557 , 4.0946727, 0.       , 0.       , 7.767174 ]],
      dtype=float32)

## 3. Create intervals surrounding genes

Decima is trained on 524,288 bp sequences surrounding each gene. Therefore, we have to take the given gene coordinates and extend them to create intervals of this length.

In [9]:
from decima.data.preprocess import var_to_intervals

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand
gene_0,chr1,28648600,28648730,+
gene_1,chr19,39341773,39341945,-
gene_2,chr1,78004346,78004554,-
gene_3,chr8,143290399,143290621,-
gene_4,chr16,1971655,1971896,-


First, we copy the start and end columns to `gene_start` and `gene_end`. We also create a new column `gene_length`. 

In [11]:
# Preserve the original gene coordinates under `gene_start` / `gene_end`.
# Run this before the interval step, which overwrites `start` and `end`.
for coord in ("start", "end"):
    ad.var[f"gene_{coord}"] = ad.var[coord].tolist()
# Gene length in bp, computed from the preserved coordinates.
ad.var["gene_length"] = ad.var["gene_end"].sub(ad.var["gene_start"])

In [12]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length
gene_0,chr1,28648600,28648730,+,28648600,28648730,130
gene_1,chr19,39341773,39341945,-,39341773,39341945,172
gene_2,chr1,78004346,78004554,-,78004346,78004554,208
gene_3,chr8,143290399,143290621,-,143290399,143290621,222
gene_4,chr16,1971655,1971896,-,1971655,1971896,241


Now, we extend the gene coordinates to create enclosing intervals:

In [13]:
# Replace `start`/`end` in `ad.var` with the coordinates of a fixed-length interval
# around each gene; the original gene coordinates stay in `gene_start`/`gene_end`.
# NOTE(review): `chr_end_pad` presumably keeps intervals this many bp away from the
# chromosome ends — confirm against the decima documentation.
ad = var_to_intervals(ad, chr_end_pad = 10000, genome="hg38") 
# Replace the genome name if your gene coordinates are not on hg38

The interval size is 524288 bases. Of these, 163840 will be upstream of the gene start and 360448 will be downstream of the gene start.
2 intervals extended beyond the chromosome start and have been shifted
0 intervals extended beyond the chromosome end and have been shifted
0 intervals did not extend far enough upstream of the TSS and have been dropped


In [14]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end
gene_0,chr1,28484760,29009048,+,28648600,28648730,130,163840,163970
gene_1,chr19,38981497,39505785,-,39341773,39341945,172,163840,164012
gene_2,chr1,77644106,78168394,-,78004346,78004554,208,163840,164048
gene_3,chr8,142930173,143454461,-,143290399,143290621,222,163840,164062
gene_4,chr16,1611448,2135736,-,1971655,1971896,241,163840,164081


You see that the columns `start` and `end` now contain the start and end coordinates for the 524,288 bp intervals.

## 3. Split genes into training, validation and test sets

We load the coordinates of the genomic regions used to train Borzoi:

In [15]:
# Headerless, tab-separated BED file listing the Borzoi training regions and
# the fold each one belongs to.
# replace human with mouse for mm10 splits
splits_file = 'https://raw.githubusercontent.com/calico/borzoi/main/data/sequences_human.bed.gz'
bed_columns = ['chrom', 'start', 'end', 'fold']
splits = pd.read_csv(splits_file, sep='\t', header=None, names=bed_columns)
splits.head()

Unnamed: 0,chrom,start,end,fold
0,chr4,82524421,82721029,fold0
1,chr13,18604798,18801406,fold0
2,chr2,189923408,190120016,fold0
3,chr10,59875743,60072351,fold0
4,chr1,117109467,117306075,fold0


Now, we overlap our gene intervals with these regions:

In [16]:
# Move the gene identifiers out of the index so they survive the overlap join.
genes_bed = ad.var.reset_index(names="gene")
# Left join: every gene interval is kept and paired with each Borzoi region it overlaps.
overlaps = bf.overlap(genes_bed, splits, how='left')
# Reduce to one row per distinct (gene, fold) pair, stored as strings.
overlaps = overlaps.loc[:, ['gene', 'fold_']].drop_duplicates().astype(str)
overlaps.head()

Unnamed: 0,gene,fold_
0,gene_0,fold5
15,gene_1,fold0
30,gene_2,fold0
45,gene_3,fold4
60,gene_4,fold0


Based on the overlap, we divide our gene intervals into training, validation and test sets.

In [17]:
# fold3 overlaps define the test genes and fold4 overlaps the validation genes.
# A gene can land in both lists when its interval overlaps regions from both folds.
fold_labels = overlaps.fold_
test_genes = overlaps.loc[fold_labels == 'fold3', 'gene'].tolist()
val_genes = overlaps.loc[fold_labels == 'fold4', 'gene'].tolist()
# Every remaining gene is used for training.
held_out_genes = set(test_genes) | set(val_genes)
train_genes = set(overlaps.gene) - held_out_genes

And add this information back to `ad.var`.

In [18]:
# Start with every gene labelled "test", then overwrite the validation and
# training genes. The order matters: a gene present in both `test_genes` and
# `val_genes` ends up labelled "val".
ad.var["dataset"] = "test"
for split_name, split_genes in (("val", val_genes), ("train", train_genes)):
    ad.var.loc[ad.var.index.isin(split_genes), "dataset"] = split_name



In [19]:
ad.var.head()

Unnamed: 0,chrom,start,end,strand,gene_start,gene_end,gene_length,gene_mask_start,gene_mask_end,dataset
gene_0,chr1,28484760,29009048,+,28648600,28648730,130,163840,163970,train
gene_1,chr19,38981497,39505785,-,39341773,39341945,172,163840,164012,train
gene_2,chr1,77644106,78168394,-,78004346,78004554,208,163840,164048,train
gene_3,chr8,142930173,143454461,-,143290399,143290621,222,163840,164062,val
gene_4,chr16,1611448,2135736,-,1971655,1971896,241,163840,164081,train


In [20]:
ad.var.dataset.value_counts()

dataset
train    824
test      98
val       78
Name: count, dtype: int64

We have now divided the 1000 genes in our dataset into separate sets to be used for training, validation and testing.

## 4. Save processed anndata

We will save the processed anndata file containing these intervals and data splits.

In [21]:
# Save the processed AnnData (log-CPM values, intervals and train/val/test labels);
# this file is later passed to `decima finetune` as `--matrix-file`.
ad.write_h5ad(ad_file_path)

## 5. Create an hdf5 file

To train Decima, we need to extract the genomic sequences for all the intervals and convert them to one-hot encoded format. We save these one-hot encoded inputs to an hdf5 file.

In [22]:
from decima.data.write_hdf5 import write_hdf5

In [23]:
# Write the training inputs for every interval in `ad` to `h5_file_path`.
# NOTE(review): `pad=5000` appears to widen each stored interval by 2 x 5000 bp (the
# logged arrays are 534288 wide versus the 524288 bp interval) and equals the
# `shift` (max-seq-shift) value set below — confirm whether the two must stay in sync.
write_hdf5(file=h5_file_path, ad=ad, pad=5000, genome="hg38") 
# Change the genome name if your intervals are not on hg38

Writing metadata
Writing task indices
Writing genes array of shape: (1000, 2)
Writing labels array of shape: (1000, 50, 1)
Making gene masks
Writing mask array of shape: (1000, 534288)
Encoding sequences
Writing sequence array of shape: (1000, 534288)
Done!


## 6. Set training parameters

In [24]:
# --- Optimisation ---
lr = 5e-5             # learning rate (default: 0.001)
total_weight = 1e-4   # total weight parameter for the loss function
grad = 5              # gradient accumulation steps
epochs = 15           # number of epochs (default: 1)

# --- Data loading ---
bs = 4                # batch size (default: 4)
shift = 5000          # max-seq-shift (default: 5000)
workers = 16          # number of workers (default: 16)

# --- Logging ---
logger = "wandb"      # change to "csv" to save logs locally

## 7. Generate training commands

In [25]:
# Options that are identical for every run, taken from the paths and training
# parameters defined above.
shared_options = (
    f"--matrix-file {ad_file_path} --h5-file {h5_file_path} "
    f"--outdir {outdir} --learning-rate {lr} "
    f"--loss-total-weight {total_weight} --gradient-accumulation {grad} "
    f"--batch-size {bs} --max-seq-shift {shift} "
    f"--epochs {epochs} --logger {logger} --num-workers {workers}"
)

# One `decima finetune` command per model index 0-3; the same integer is used
# as the run-name suffix, the `--model` value and the `--device` value.
cmds = [
    f"decima finetune --name finetune_test_{replicate} "
    f"--model {replicate} --device {replicate} {shared_options}"
    for replicate in range(4)
]

In [26]:
# Print each generated command on its own line so it can be copied into a terminal.
for cmd in cmds:
    print(cmd)

decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumul

## 8. Test: run one of the fine-tuning commands

In [28]:
import wandb

# The W&B server defaults to the instance this tutorial was written against.
# Set the WANDB_BASE_URL environment variable (e.g. to https://api.wandb.ai) to
# log in to a different server without editing the notebook.
wandb_host = os.environ.get("WANDB_BASE_URL", "https://genentech.wandb.io")
# anonymous="never" asks W&B not to fall back to an anonymous account.
wandb.login(host=wandb_host, anonymous="never")

[34m[1mwandb[0m: Currently logged in as: [33mlal-avantika[0m ([33mgrelu[0m) to [32mhttps://genentech.wandb.io[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
!decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16

decima - INFO - Data paths: matrix_file=./data.h5ad, h5_file=./data.h5
decima - INFO - Reading anndata
decima - INFO - Making dataset objects
decima - INFO - train_params: {'name': 'finetune_test_3', 'batch_size': 4, 'num_workers': 16, 'devices': 3, 'logger': 'wandb', 'save_dir': '.', 'max_epochs': 15, 'lr': 5e-05, 'total_weight': 0.0001, 'accumulate_grad_batches': 5, 'loss': 'poisson_multinomial', 'clip': 0.0, 'save_top_k': 1, 'pin_memory': True}
decima - INFO - model_params: {'n_tasks': 50, 'init_borzoi': True, 'replicate': '3'}
decima - INFO - Initializing model
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 3
[34m[1mwandb[0m: Currently logged in as: [33manony-mouse-891169334544049289[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact human_state_dict_fold3:latest, 709.30MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.2 (583.7MB/s