# Train Decima by fine-tuning only the Borzoi head

In [None]:
import anndata
import os, sys
import argparse
import wandb

# Make the modules in the decima source directory importable; the two imports
# below (`read_hdf5`, `lightning`) are resolved through this path entry.
# NOTE(review): `append` puts the directory at the END of sys.path, so an
# installed top-level package that is also called `lightning` would take
# precedence over the local module -- confirm the environment has no such
# package, or that the local module is the one actually imported.
sys.path.append('/code/decima/src/decima/')
from read_hdf5 import HDF5Dataset
from lightning import LightningModel

# Log in to the W&B server at the given host before any run is created.
wandb.login(host="https://genentech.wandb.io")

## Paths

In [None]:
# Directory holding the inputs for this training run; it is also handed to the
# trainer as its `save_dir`.
save_dir = "/gstore/data/resbioai/grelu/decima/20240823/"

# Both input files live directly inside `save_dir`:
#   aggregated.h5ad -- read with anndata.read_h5ad
#   data.h5         -- handed to HDF5Dataset
matrix_file, h5_file = (
    os.path.join(save_dir, filename)
    for filename in ("aggregated.h5ad", "data.h5")
)

## Load data

In [None]:
# Read the AnnData object stored at `matrix_file`. It is passed to both
# datasets, and later supplies the number of tasks (`ad.shape[0]`) and the
# "disease_pairs" entry of `ad.uns` used in the training parameters.
ad = anndata.read_h5ad(matrix_file)

## Make PyTorch datasets

In [None]:
# Training split, built with a fixed seed.
# NOTE(review): presumably `max_seq_shift=5000` with `augment_mode="random"`
# enables random sequence-shift augmentation -- confirm against HDF5Dataset.
train_dataset = HDF5Dataset(
    h5_file=h5_file,
    ad=ad,
    key="train",
    max_seq_shift=5000,
    augment_mode="random",
    seed=0,
)

# Validation split: no sequence shift; augmentation mode and seed are left at
# the HDF5Dataset defaults.
val_dataset = HDF5Dataset(
    h5_file=h5_file,
    ad=ad,
    key="val",
    max_seq_shift=0,
)

## Parameters

In [None]:
# Value passed to the trainer as `total_weight`.
weight = 1e-4

# Replicate index. It is defined once and used for both the run name and
# `model_params["replicate"]` so the two cannot drift apart.
replicate = 0
name = f"decima_v20240823_rep{replicate}_head_only"

# Settings consumed by LightningModel for optimisation, logging and checkpointing.
train_params = {
    "optimizer": "adam",
    "batch_size": 4,
    "num_workers": 16,
    "devices": 1,
    "logger": "wandb",
    "save_dir": save_dir,
    "max_epochs": 15,
    "lr": 3e-5,
    "total_weight": weight,
    # NOTE(review): together with batch_size=4 this presumably gives an
    # effective batch of 20 samples per optimizer step -- confirm in LightningModel.
    "accumulate_grad_batches": 5,
    "loss": "poisson_multinomial",
    "pairs": ad.uns["disease_pairs"].values,
}

# Settings describing the network itself.
model_params = {
    # One task per row of the AnnData matrix.
    "n_tasks": ad.shape[0],
    "replicate": replicate,
    "init_borzoi": True,
}

model = LightningModel(model_params=model_params, train_params=train_params)

## Freeze Borzoi weights

In [None]:
# Turn off gradients for every parameter of the embedding sub-module, so that
# only the remaining parameters of the model receive gradient updates.
# NOTE(review): this does not switch the module to eval mode; if it contains
# BatchNorm or Dropout layers they will still behave as in training -- confirm
# that this is intended.
embedding_module = model.model.embedding
for param in embedding_module.parameters():
    param.requires_grad = False

## Train

In [None]:
from contextlib import closing

# Train inside context managers so that cleanup also happens when training
# raises or is interrupted:
#   * the W&B run is finished when its `with` block exits (on an exception the
#     run is reported as failed rather than left open);
#   * `closing(...)` calls `.close()` on each dataset.
# `dir=name` is a relative path, so W&B files are written relative to the
# current working directory.
with wandb.init(project="decima", dir=name, name=name) as run:
    with closing(train_dataset), closing(val_dataset):
        model.train_on_dataset(train_dataset, val_dataset)