# Fine-tune a 2 kb full Borzoi model into a CREsted-style scalar model

For a more detailed, step-by-step tutorial, see the Borzoi fine-tuning guide in the CREsted documentation.

In [1]:
import os
import zipfile
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata as ad
import tensorflow as tf
import keras
import wandb
import crested

2025-03-20 14:06:28.735970: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-20 14:06:28.773300: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Sanity check: confirm TensorFlow can see a GPU before training.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  1


Download the data for these notebooks from the Zenodo record accompanying the CREsted paper, then point `DATA_DIR` below at the downloaded `Figure_5` folder.

In [3]:
# Root of the Figure 5 data downloaded from the CREsted paper's Zenodo record.
# It can be overridden without editing the notebook via the CRESTED_DATA_DIR
# environment variable. os.path.join(..., "") guarantees a trailing separator,
# which paths built by string concatenation from DATA_DIR rely on.
DATA_DIR = os.path.join(
    os.environ.get("CRESTED_DATA_DIR", "../../../crested_data/Figure_5/"), ""
)

In [3]:
# Authenticate with Weights & Biases; both trainers below log with logger="wandb".
# Set the WANDB_NOTEBOOK_NAME environment variable to silence the notebook-name warning.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcas-blaauw[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Data paths, all under DATA_DIR. They are built with os.path.join so they stay
# valid even if DATA_DIR is set without a trailing slash.
# See data/README.md; data from crested.get_dataset("mouse_cortex_bigwig_cut_sites")
adata_file = os.path.join(DATA_DIR, "data", "mouse_biccn_data_full.h5ad")
# Pre-saved filtered AnnData loaded for the second training round
adata_filtered_file = os.path.join(DATA_DIR, "data", "mouse_biccn_data_specific.h5ad")
# Borzoi fold assignment per consensus peak (columns: chr, start, end, region, fold)
folds_file = os.path.join(DATA_DIR, "data", "consensus_peaks_biccn_borzoifolds.tsv")
# Genome paths
resources_dir = "../../../mouse/biccn/"  # CHANGE TO OWN
genome_file = os.path.join(resources_dir, "mm10.fa")
chromsizes_file = os.path.join(resources_dir, "mm10.chrom.sizes")

In [6]:
# Build the mm10 genome object (it is also passed explicitly to the data modules below)
# and register it with CREsted, presumably so later crested calls can find it without
# it being passed explicitly — verify against the CREsted docs.
genome = crested.Genome(genome_file, chromsizes_file)
crested.register_genome(genome)

2025-03-20T14:07:50.479641+0100 INFO Genome mm10 registered.


In [7]:
# Experiment switches for the first training round.
# frozen: if True, every pretrained Borzoi layer is set non-trainable, so only the
#         newly added dense output head is trained.
frozen: bool = False
# borsplit: if True, use the Borzoi fold assignments from folds_file for
#           train/val/test instead of the chromosome-based split.
# NOTE(review): the run names and the pretrained model loaded for round two all
# contain "borsplit" — confirm this default matches the runs you intend to reproduce.
borsplit: bool = False

## First round training

### Read in ATAC data

In [6]:
# Load the consensus-peak ATAC AnnData. If it does not exist yet, build it from the
# bigWigs, preprocess it, and cache it to adata_file so later runs skip the import.
if os.path.exists(adata_file):
    adata = ad.read_h5ad(adata_file)
else:
    atac_dir, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")
    adata = crested.import_bigwigs(
        bigwigs_folder=atac_dir,
        regions_file=regions_file,
        target_region_width=1000,
        target="count",
    )
    crested.pp.train_val_test_split(
        adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
    )
    # Resize regions to the 2048 bp Borzoi input length.
    crested.pp.change_regions_width(
        adata,
        2048,
        chromsizes_file=chromsizes_file,
    )
    # The top_k_percent parameter can be tuned based on potential bias towards cell types.
    # If some weights are overcompensating too much, consider increasing top_k_percent.
    # Default is 0.01.
    crested.pp.normalize_peaks(adata, top_k_percent=0.03)
    # Cache the processed object so re-running the notebook skips the bigWig import.
    os.makedirs(os.path.dirname(adata_file) or ".", exist_ok=True)
    adata.write_h5ad(adata_file)


2025-03-06T10:33:47.807612+0100 INFO Extracting values from 19 bigWig files...


AnnData object with n_obs × n_vars = 19 × 546993
    obs: 'file_path'
    var: 'chr', 'start', 'end'

### Training split

#### Use the standard chromosome split or the Borzoi folds

In [None]:
# Assign train/val/test splits: CREsted's chromosome-based split, or Borzoi's folds.
if not borsplit:
    crested.pp.train_val_test_split(
        adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
    )
    print(adata.var['split'].value_counts(dropna=False))
else:
    region_width = 2048  # must match the width used in change_regions_width
    # Borzoi folds held out for validation/testing; every other fold is used for training.
    val_test_folds = {'fold4': 'val', 'fold3': 'test'}

    # Remove stale fold/split columns so the join below cannot clash with them.
    if 'fold' in adata.var.columns:
        adata.var = adata.var.drop(['fold'], axis=1)
    if 'split' in adata.var.columns:
        print("Warning: dropping previous split information and adding Borzoi split info")
        adata.var = adata.var.drop(['split'], axis=1)

    # Read the folds file and re-center each region to region_width, so its
    # "chr:start-end" name matches adata.var_names.
    folds = pd.read_csv(
        folds_file, sep='\t', names=['chr', 'start', 'end', 'region', 'fold'], usecols=['region', 'fold']
    )
    folds_split = folds['region'].str.extract(r"(?P<chr>chr.+):(?P<start>\d+)-(?P<end>\d+)")
    center = (folds_split['start'].astype(int) + folds_split['end'].astype(int)) // 2
    new_start = center - region_width // 2
    new_end = center + region_width // 2
    folds['region'] = folds_split['chr'] + ':' + new_start.astype(str) + '-' + new_end.astype(str)
    folds = folds.set_index('region')

    # Drop every copy of a duplicated region (on the edge of folds, so its fold is ambiguous).
    folds = folds[~folds.index.duplicated(keep=False)]

    # Drop regions not in any fold
    in_folds = adata.var_names.isin(folds.index)
    print(f"Dropping {(~in_folds).sum()} regions because they are not in any fold.")
    adata = adata[:, in_folds].copy()

    # Add fold data to var
    adata.var = adata.var.join(folds)

    # Turn fold data into split data
    fold_mapping = {fold: 'train' for fold in folds['fold'].unique()}
    fold_mapping.update(val_test_folds)
    adata.var['split'] = adata.var['fold'].map(fold_mapping)
    # Every kept region is in some fold, so none may be left without a split.
    assert adata.var['split'].notna().all(), "some regions did not receive a split"

    # Check result
    print(adata.var['split'].value_counts(dropna=False))

### Model definition 

In [9]:
# Build a 2048 bp Borzoi trunk, load the pretrained mouse Borzoi weights into it, and
# replace the track head with a dense layer that predicts one scalar per observation.

# Load in default Borzoi architecture with input shrunk to 2048 bp. num_classes are the
# output sizes of the two original Borzoi heads and must match the checkpoint.
base_model_architecture = crested.tl.zoo.borzoi(seq_len=2048, target_length=64, num_classes=(7611, 2608))
# Load in original Borzoi weights (a zip archive containing model.weights.h5)
model_file, _ = crested.get_model("Borzoi_mouse_rep0")
# Extract the weights into a temporary directory that is removed afterwards
with zipfile.ZipFile(model_file) as model_archive, tempfile.TemporaryDirectory() as tmpdir:
    model_weights_path = model_archive.extract('model.weights.h5', tmpdir)
    base_model_architecture.load_weights(model_weights_path)

# To train frozen Borzoi: make every pretrained layer non-trainable, so that only the
# dense head created below is updated.
if frozen:
    for layer in base_model_architecture.layers:
        layer.trainable = False

# Replace the track head with flatten+dense to predict a single vector of scalars per region
## Take the output of the last layer before the original head
current = base_model_architecture.get_layer("final_conv_activation").output
## Flatten and add a dense layer with one softplus output per adata observation
current = keras.layers.Flatten()(current)
current = keras.layers.Dense(
    adata.n_obs, activation='softplus', name="dense_out"
)(current)
## Turn into model
model_architecture = keras.Model(inputs=base_model_architecture.inputs, outputs=current, name='Borzoi_scalar')
# summary() prints the table itself and returns None; print(...) would add a stray "None".
model_architecture.summary()

2025-03-06 13:40:47.383422: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78783 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:c6:00.0, compute capability: 9.0


None


### TaskConfig

In [None]:
# Optimization setup for the first round.
# Learning rate: 5e-5 for first-round finetuning, 1e-5 for the second round or when the
# Borzoi trunk is frozen. It follows the `frozen` switch so the two cannot disagree.
learning_rate = 1e-5 if frozen else 5e-5
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(config)

### Data module

In [None]:
# Wrap the AnnData and genome in a CREsted data module that feeds sequences to the trainer.
datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome=genome,
    batch_size=32,  # lower this if you encounter OOM errors
    always_reverse_complement=True,  # augmentation with reverse-complemented sequences
    max_stochastic_shift=3,  # optional augmentation: random sequence shifts
)

### Training

In [None]:
# Derive the W&B run name from the experiment switches so it cannot drift from the
# configuration actually being trained.
run_name = (
    "borzoi"
    + ("_borsplit" if borsplit else "")
    + ("_frozen" if frozen else "")
    + "_ft_consensus"
)

# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="biccn_borzoi_atac",
    run_name=run_name,
    logger="wandb",
)

In [None]:
# First round: finetune on the full consensus-peak set.
first_round_epochs = 10
trainer.fit(epochs=first_round_epochs)

## Second round training
### Read in specificity-filtered ATAC data

In [None]:
# The pre-saved filtered file loaded below is equivalent (up to small run-to-run
# differences) to applying specificity filtering to the first-round adata:
# crested.pp.filter_regions_on_specificity(
#     adata, gini_std_threshold=1.0
# )

In [6]:
# Second round starts from the pre-saved, filtered AnnData instead of the full set.
adata = ad.read_h5ad(filename=adata_filtered_file)

In [7]:
# Re-process data: resize regions to the 2048 bp Borzoi input. chromsizes_file is passed
# explicitly, as in the first round, presumably so resized regions are kept within
# chromosome bounds — verify against the CREsted docs.
crested.pp.change_regions_width(
    adata,
    2048,
    chromsizes_file=chromsizes_file,
)

In [8]:
# Replace the split with Borzoi fold info. This is applied regardless of the `borsplit`
# switch: judging by its file name, the pretrained model loaded next
# ("borzoi_borsplit_ft_consensus") was trained with these folds, so they must be reused
# to keep its held-out regions held out.
region_width = 2048  # must match the width used in change_regions_width above
# Borzoi folds held out for validation/testing; every other fold is used for training.
val_test_folds = {'fold4': 'val', 'fold3': 'test'}

# Remove stale fold/split columns (if present) so the join below cannot clash with them.
for stale_col in ('fold', 'split'):
    if stale_col in adata.var.columns:
        adata.var = adata.var.drop([stale_col], axis=1)

# Read the folds file and re-center each region to region_width, so its
# "chr:start-end" name matches adata.var_names.
folds = pd.read_csv(
    folds_file, sep='\t', names=['chr', 'start', 'end', 'region', 'fold'], usecols=['region', 'fold']
)
folds_split = folds['region'].str.extract(r"(?P<chr>chr.+):(?P<start>\d+)-(?P<end>\d+)")
center = (folds_split['start'].astype(int) + folds_split['end'].astype(int)) // 2
new_start = center - region_width // 2
new_end = center + region_width // 2
folds['region'] = folds_split['chr'] + ':' + new_start.astype(str) + '-' + new_end.astype(str)
folds = folds.set_index('region')
print(f"adata length: {adata.n_vars}")
print(f"folds length: {len(folds)}")
print(f"adata in folds length: {adata.var_names.isin(folds.index).sum()}")
print(f"folds in adata length: {folds.index.isin(adata.var_names).sum()}")

# Drop every copy of a duplicated region (on the edge of folds, so its fold is ambiguous).
folds = folds[~folds.index.duplicated(keep=False)]

# Drop regions not in any fold
in_folds = adata.var_names.isin(folds.index)
print(f"Dropping {(~in_folds).sum()} regions because they are not in any fold.")
adata = adata[:, in_folds].copy()

# Add fold data to var
adata.var = adata.var.join(folds)

# Turn fold data into split data
fold_mapping = {fold: 'train' for fold in folds['fold'].unique()}
fold_mapping.update(val_test_folds)
adata.var['split'] = adata.var['fold'].map(fold_mapping)
# Every kept region is in some fold, so none may be left without a split.
assert adata.var['split'].notna().all(), "some regions did not receive a split"

# Check result
adata.var['split'].value_counts(dropna=False)

adata length: 91475
folds length: 543877
adata in folds length: 91013
folds in adata length: 91013
Dropping 462 regions because they are not in any fold.


split
train    67153
val      12974
test     10886
Name: count, dtype: int64

### Load in model

In [8]:
# Load the first-round model downloaded with the data in DATA_DIR. compile=False because
# the optimizer, loss and metrics come from the TaskConfig defined next. The path is
# built with os.path.join so it works even if DATA_DIR lacks a trailing slash.
pretrained_model_file = os.path.join(DATA_DIR, "models", "borzoi_borsplit_ft_consensus.keras")
model_architecture = keras.models.load_model(pretrained_model_file, compile=False)

2025-03-20 14:08:14.674561: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78790 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:e4:00.0, compute capability: 9.0


### TaskConfig

In [10]:
# Second-round optimization setup: same loss and metrics as round one, but a smaller
# learning rate (5e-5 for first round finetuning, 1e-5 for second round or frozen).
second_round_lr = 1e-5
optimizer = keras.optimizers.Adam(learning_rate=second_round_lr)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

config = crested.tl.TaskConfig(optimizer=optimizer, loss=loss, metrics=metrics)
print(config)

TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x1509c1b3ca50>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x1509cb9cee10>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])


### Data module

In [11]:
# Data module for the second round, built from the filtered AnnData.
datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome=genome,
    batch_size=32,  # lower this if you encounter OOM errors
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
    max_stochastic_shift=3,  # optional augmentation
)

### Training

In [13]:
# Second-round trainer: same W&B project as round one, logged as a separate run.
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    logger="wandb",
    project_name="biccn_borzoi_atac",
    run_name="borzoi_borsplit_ft_consensus_ft_specific",  # Remember to change!
)

In [14]:
# Second round: continue finetuning on the filtered regions.
second_round_epochs = 30
trainer.fit(epochs=second_round_epochs)

None
2025-03-06T13:42:21.893388+0100 INFO Loading sequences into memory...


100%|██████████| 67153/67153 [00:54<00:00, 1232.42it/s]


2025-03-06T13:43:16.751253+0100 INFO Loading sequences into memory...


100%|██████████| 12974/12974 [00:04<00:00, 3190.76it/s]


Epoch 1/30


I0000 00:00:1741265009.359234  378410 service.cc:145] XLA service 0x15093c010a70 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741265009.379141  378410 service.cc:153]   StreamExecutor device (0): NVIDIA H100 80GB HBM3, Compute Capability 9.0
2025-03-06 13:43:35.604812: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-03-06 13:43:44.691680: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8902

I0000 00:00:1741265075.481889  378410 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[34m[1mwandb[0m: [32m[41mERROR[0m Unable to log learning rate.


[1m4198/4198[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m169s[0m 23ms/step - concordance_correlation_coefficient: 0.2283 - cosine_similarity: 0.7186 - loss: -0.2981 - mean_absolute_error: 2.2493 - mean_squared_error: 26.7955 - pearson_correlation: 0.5089 - pearson_correlation_log: 0.5275 - zero_penalty_metric: 339.9114 - val_concordance_correlation_coefficient: 0.3568 - val_cosine_similarity: 0.7695 - val_loss: -0.4136 - val_mean_absolute_error: 2.1054 - val_mean_squared_error: 23.1069 - val_pearson_correlation: 0.6280 - val_pearson_correlation_log: 0.5342 - val_zero_penalty_metric: 311.6286 - learning_rate: 1.0000e-05
Epoch 2/30
[1m4198/4198[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 12ms/step - concordance_correlation_coefficient: 0.3683 - cosine_similarity: 0.7732 - loss: -0.4295 - mean_absolute_error: 2.0794 - mean_squared_error: 23.4079 - pearson_correlation: 0.6356 - pearson_correlation_log: 0.5384 - zero_penalty_metric: 321.4921 - val_concordance_correlation_co

0,1
batch/batch_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch/concordance_correlation_coefficient,▁▃▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████████████
batch/cosine_similarity,▁▃▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████
batch/loss,█▆▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/mean_absolute_error,█▆▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁
batch/mean_squared_error,█▆▄▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/pearson_correlation,▁▄▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█████████████████
batch/pearson_correlation_log,▂▃▃▃▄▃▄▄▅▁▅▅▅▆▆▆▆▆▆▆▆▇▆▇▆▇▇▇▇█▇█████████
batch/zero_penalty_metric,█▆▄▃▃▃▃▃▃▁▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch/concordance_correlation_coefficient,▁▁▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█████████████████████

0,1
batch/batch_step,125990.0
batch/concordance_correlation_coefficient,0.46259
batch/cosine_similarity,0.82001
batch/loss,-0.52494
batch/mean_absolute_error,1.93689
batch/mean_squared_error,20.3576
batch/pearson_correlation,0.69555
batch/pearson_correlation_log,0.56386
batch/zero_penalty_metric,314.23633
epoch/concordance_correlation_coefficient,0.46262
