# Finetune 2kb Borzoi CNN model into CREsted-style scalar model

For a better tutorial, look at the CREsted documentation's Borzoi finetuning guide.

In [1]:
import datetime
import os
import tempfile
import zipfile
from pprint import pprint

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata as ad
import tensorflow as tf
import keras
import wandb
import crested

2025-03-20 14:04:15.631637: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-20 14:04:15.669256: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Report how many GPU devices TensorFlow detects before any training is started.
visible_gpus = tf.config.list_physical_devices("GPU")
print("Num GPUs Available: ", len(visible_gpus))

Num GPUs Available:  1


Download the data for the notebooks from the dedicated Zenodo link of the CREsted paper, then set `DATA_DIR` below to the folder where you stored it.

In [3]:
# Root folder of the downloaded notebook data (relative to the notebook's working directory).
# The `data/` and `models/` sub-paths used later in this notebook are resolved against it.
DATA_DIR = "../../../crested_data/Figure_5/"

In [4]:
# Authenticate with Weights & Biases; both trainers below are created with logger="wandb".
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkemp[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Data paths
# Built with os.path.join (like the genome paths below) so that DATA_DIR works
# with or without a trailing slash.
adata_file = os.path.join(DATA_DIR, "data", "mouse_biccn_data_full.h5ad")  # see data/README.md, data from crested.get_dataset("mouse_cortex_bigwig_cut_sites")
adata_filtered_file = os.path.join(DATA_DIR, "data", "mouse_biccn_data_specific.h5ad")
folds_file = os.path.join(DATA_DIR, "data", "consensus_peaks_biccn_borzoifolds.tsv")
# Genome paths
resources_dir = "../../../mouse/biccn/" # CHANGE TO OWN
genome_file = os.path.join(resources_dir, "mm10.fa")
chromsizes_file = os.path.join(resources_dir, "mm10.chrom.sizes")


In [6]:
# Build a Genome object from the FASTA and chromosome-sizes files and register it with CREsted.
# NOTE(review): presumably registration lets later CREsted calls that are not given a genome
# explicitly fall back on this one — confirm against the CREsted documentation.
genome = crested.Genome(genome_file, chromsizes_file)
crested.register_genome(genome)

2025-03-20T14:04:30.761938+0100 INFO Genome mm10 registered.


## Read in ATAC data

In [8]:
# Load the ATAC AnnData from disk if it exists; otherwise rebuild it from the bigwig files.
# NOTE(review): the load branch skips the split / resize / normalize steps below, so the file
# is presumably already preprocessed — confirm. The rebuilt AnnData is not written back to
# adata_file, so the else-branch is repeated whenever the file is missing.
if os.path.exists(adata_file):
    adata = ad.read_h5ad(adata_file)
else:
    # Fetch the dataset's bigwig folder and regions file
    atac_dir, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")
    adata = crested.import_bigwigs(
        bigwigs_folder=atac_dir,
        regions_file=regions_file,
        target_region_width=1000,  # optionally, use a different width than the consensus regions file (500bp) for the .X values calculation
        target="count",  # what we will be predicting; other options per the CREsted docs: "mean", "max", "logcount"
    )
    # Chromosome-based split: chr8/chr10 for validation, chr9/chr18 for testing
    crested.pp.train_val_test_split(
        adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
    )
    # Resize regions to 2048 bp, matching the seq_len of the model defined below
    crested.pp.change_regions_width(
        adata,
        2048,
        chromsizes_file=chromsizes_file,
    )
    crested.pp.normalize_peaks(
        adata, top_k_percent=0.03
    )  # The top_k_percent parameter can be tuned based on potential bias towards cell types. If some weights are overcompensating too much, consider increasing the top_k_percent. Default is 0.01


## Model definition 

In [9]:
# Load in default Borzoi architecture, with shrunk input size
# (num_classes presumably matches the track heads of the original Borzoi checkpoint — confirm)
base_model_architecture = crested.tl.zoo.borzoi(
    seq_len=2048, target_length=64, num_classes=(7611, 2608)
)
# Load in original Borzoi weights
model_file, _ = crested.get_model("Borzoi_mouse_rep0")
# Put weights into base architecture: the weights file is extracted from the model
# archive into a temporary directory that is removed again once the weights are loaded.
# (zipfile and tempfile must be imported in the import cell at the top of the notebook.)
with zipfile.ZipFile(model_file) as model_archive, tempfile.TemporaryDirectory() as tmpdir:
    model_weights_path = model_archive.extract("model.weights.h5", tmpdir)
    base_model_architecture.load_weights(model_weights_path)

# Replace track head by flatten+dense to predict single vector of scalars per region
## Get last layer at end of conv tower
current = base_model_architecture.get_layer("tower_conv_6_pool").output
## Flatten and add new layer with one output unit per observation in adata
current = keras.layers.Flatten()(current)
current = keras.layers.Dense(
    adata.n_obs, activation="softplus", name="dense_out"
)(current)
## Turn into model
model_architecture = keras.Model(
    inputs=base_model_architecture.inputs, outputs=current, name="Borzoi_scalar"
)
# summary() prints the table itself and returns None, so it must not be wrapped in print()
model_architecture.summary()



2025-03-05 14:40:00.245082: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78783 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:26:00.0, compute capability: 9.0


None


## First-round finetuning
### Data

In [10]:
# Data module for the first finetuning round, built on the full region set.
# The genome is passed explicitly (as for the second-round data module) rather than
# relying on the globally registered genome.
datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome=genome,
    batch_size=32,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)

### TaskConfig

The TaskConfig object specifies the optimizer, loss function, and metrics to use in training (we call this our 'task').  
Default configurations are available for common tasks such as 'topic_classification' and 'peak_regression',
which you can load using the `crested.tl.default_configs` function. Here we build a custom configuration by hand instead.  

In [11]:
# First-round task definition: loss, tracked metrics and optimizer.
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]
# Learning rate: 5e-5 for first round finetuning, 1e-5 for second round
optimizer = keras.optimizers.Adam(learning_rate=5e-5)

# Bundle the three pieces into the TaskConfig handed to the trainer, and show it
config = crested.tl.TaskConfig(optimizer=optimizer, loss=loss, metrics=metrics)
print(config)

TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x153fb466c690>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x153fb4610990>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])


### Training


In [12]:
# Assemble the first-round trainer from the model, data module and task config;
# the run is logged to Weights & Biases under the given project and run name.
trainer = crested.tl.Crested(
    model=model_architecture,
    data=datamodule,
    config=config,
    logger="wandb",
    project_name="biccn_borzoi_atac",
    run_name="borzoi_cnn_ft_consensus",
)

In [13]:
# train the model (first finetuning round, 20 epochs)
# NOTE(review): the saved output of this cell was cut off by Jupyter's IOPub data rate limit,
# so the per-epoch progress shown below is incomplete.
trainer.fit(epochs=20)

None
2025-03-05T14:40:26.194510+0100 INFO Loading sequences into memory...


100%|██████████| 440993/440993 [00:07<00:00, 62394.34it/s]


2025-03-05T14:40:33.995615+0100 INFO Loading sequences into memory...


100%|██████████| 56064/56064 [00:00<00:00, 87146.49it/s]


Epoch 1/20


I0000 00:00:1741182038.056037 1960684 service.cc:145] XLA service 0x153f54002d00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741182038.079674 1960684 service.cc:153]   StreamExecutor device (0): NVIDIA H100 80GB HBM3, Compute Capability 9.0
2025-03-05 14:40:41.922943: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-03-05 14:40:47.005423: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8902
I0000 00:00:1741182060.148324 1960684 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[34m[1mwandb[0m: [32m[41mERROR[0m Unable to log learning rate.


[1m27563/27563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m440s[0m 15ms/step - concordance_correlation_coefficient: 0.7034 - cosine_similarity: 0.8531 - loss: -0.5235 - mean_absolute_error: 2.7558 - mean_squared_error: 32.7529 - pearson_correlation: 0.7943 - pearson_correlation_log: 0.5943 - zero_penalty_metric: 134.6950 - val_concordance_correlation_coefficient: 0.8338 - val_cosine_similarity: 0.8803 - val_loss: -0.6188 - val_mean_absolute_error: 2.4226 - val_mean_squared_error: 23.7686 - val_pearson_correlation: 0.8651 - val_pearson_correlation_log: 0.6572 - val_zero_penalty_metric: 136.5396 - learning_rate: 5.0000e-05
Epoch 2/20
[1m27563/27563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m400s[0m 15ms/step - concordance_correlation_coefficient: 0.8265 - cosine_similarity: 0.8860 - loss: -0.6401 - mean_absolute_error: 2.3355 - mean_squared_error: 22.1289 - pearson_correlation: 0.8629 - pearson_correlation_log: 0.6594 - zero_penalty_metric: 134.2990 - val_concordance_correlati

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



## Second-round further finetuning
### Load in model

In [7]:
# Load the first-round finetuned model as the starting point for the second round.
# compile=False: only architecture and weights are restored; optimizer, loss and metrics
# come from the second-round TaskConfig defined below.
# The path is built with os.path.join so DATA_DIR works with or without a trailing slash.
# NOTE(review): this notebook does not show the step that exports the first-round model
# to this location — confirm where the file comes from.
model_architecture = keras.models.load_model(
    os.path.join(DATA_DIR, "models", "borzoi_cnn_ft_consensus.keras"), compile=False
)

2025-03-20 14:04:42.522546: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78790 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:e4:00.0, compute capability: 9.0


### Subset regions to keep specific ones

In [15]:
# Obtain the subset of specific regions used for the second finetuning round.
if os.path.exists(adata_filtered_file):
    # Pre-filtered file available: load it, then apply the same chromosome split and
    # 2048 bp region width as used for the first-round data.
    adata_ct = ad.read_h5ad(adata_filtered_file)
    crested.pp.train_val_test_split(adata_ct, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"])
    # NOTE(review): unlike the first-round call, no chromsizes_file is passed here;
    # presumably the registered genome is used instead — confirm.
    crested.pp.change_regions_width(adata_ct, 2048)
else:
    # Otherwise filter a copy of the first-round AnnData, leaving `adata` itself untouched.
    adata_ct = adata.copy()
    crested.pp.filter_regions_on_specificity(
        adata_ct, gini_std_threshold=1.0
    )

### Data

In [16]:
# Data module for the second finetuning round, built on the specific-region subset.
datamodule_ct = crested.tl.data.AnnDataModule(
    adata_ct,
    genome=genome,
    always_reverse_complement=True,  # the default; doubles the effective size of the training dataset
    max_stochastic_shift=3,  # optional shift augmentation
    batch_size=32,  # reduce if you run into OOM errors
)

### TaskConfig

The TaskConfig object specifies the optimizer, loss function, and metrics to use in training (we call this our 'task').  
Default configurations are available for common tasks such as 'topic_classification' and 'peak_regression',
which you can load using the `crested.tl.default_configs` function. Here we build a custom configuration by hand instead, with a learning rate of 1e-5 for this second round.  

In [17]:
# Second-round task definition: loss, tracked metrics and optimizer.
loss_ct = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics_ct = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]
# Learning rate: 5e-5 for first round finetuning, 1e-5 for second round
optimizer_ct = keras.optimizers.Adam(learning_rate=1e-5)

# Bundle the three pieces into the TaskConfig handed to the trainer, and show it
config_ct = crested.tl.TaskConfig(optimizer=optimizer_ct, loss=loss_ct, metrics=metrics_ct)
print(config_ct)

TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x152a1370eb50>, loss=<crested.tl.losses._cosinemse_log.CosineMSELogLoss object at 0x152a136252d0>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])


### Training


In [18]:
# Assemble the second-round trainer: the reloaded model, the specific-region data module
# and the second-round task config, logged to Weights & Biases under its own run name.
trainer = crested.tl.Crested(
    model=model_architecture,
    data=datamodule_ct,
    config=config_ct,
    logger="wandb",
    project_name="biccn_borzoi_atac",
    run_name="borzoi_cnn_ft_consensus_ft_specific",
)

In [20]:
# train the model (second finetuning round on the specific regions, 5 epochs)
trainer.fit(epochs=5)

None
2025-03-05T16:21:27.884205+0100 INFO Loading sequences into memory...


100%|██████████| 73326/73326 [00:20<00:00, 3636.61it/s] 

2025-03-05T16:21:48.321690+0100 INFO Loading sequences into memory...



100%|██████████| 9951/9951 [00:01<00:00, 5460.12it/s]


Epoch 1/5
[1m4583/4583[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m131s[0m 20ms/step - concordance_correlation_coefficient: 0.6344 - cosine_similarity: 0.8775 - loss: -0.6540 - mean_absolute_error: 1.6552 - mean_squared_error: 14.8305 - pearson_correlation: 0.7789 - pearson_correlation_log: 0.6289 - zero_penalty_metric: 319.7846 - val_concordance_correlation_coefficient: 0.6192 - val_cosine_similarity: 0.8614 - val_loss: -0.6110 - val_mean_absolute_error: 1.7039 - val_mean_squared_error: 14.8640 - val_pearson_correlation: 0.7414 - val_pearson_correlation_log: 0.6063 - val_zero_penalty_metric: 323.4451 - learning_rate: 1.0000e-05
Epoch 2/5
[1m4583/4583[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 15ms/step - concordance_correlation_coefficient: 0.6756 - cosine_similarity: 0.8886 - loss: -0.6822 - mean_absolute_error: 1.5894 - mean_squared_error: 13.5486 - pearson_correlation: 0.7978 - pearson_correlation_log: 0.6380 - zero_penalty_metric: 317.6757 - val_concordance_corre

0,1
batch/batch_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch/concordance_correlation_coefficient,▁▂▂▃▃▃▃▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇████████
batch/cosine_similarity,▁▁▂▂▂▂▂▃▄▄▄▄▄▄▄▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████
batch/loss,██▇▇▇▇▆▆▄▅▅▅▅▅▅▅▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/mean_absolute_error,█▇▇▇▆▆▆▆▅▅▅▅▅▅▅▅▃▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/mean_squared_error,██▇▇▇▆▆▆▅▅▅▄▄▄▄▄▂▃▄▃▃▃▃▃▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/pearson_correlation,▁▁▁▂▂▂▂▂▄▄▄▄▄▄▄▄▆▆▅▅▆▆▆▆▆▇▇▇▇▇▇▇████████
batch/pearson_correlation_log,▁▂▃▃▃▃▃▃▆▅▅▅▅▅▅▅▆▇▆▆▆▆▆▆▇█▇▇▇▇▇▇████████
batch/zero_penalty_metric,█▆▆▆▆▆▆▆▃▅▅▆▅▅▅▅▆▃▃▄▄▄▄▄▂▂▃▃▃▃▃▃▁▂▁▁▁▁▁▁
epoch/concordance_correlation_coefficient,▁▁▄▄▅▅▇▇██

0,1
batch/batch_step,22940.0
batch/concordance_correlation_coefficient,0.72043
batch/cosine_similarity,0.9053
batch/loss,-0.7224
batch/mean_absolute_error,1.49273
batch/mean_squared_error,12.01441
batch/pearson_correlation,0.82386
batch/pearson_correlation_log,0.64823
batch/zero_penalty_metric,309.91177
epoch/concordance_correlation_coefficient,0.72039


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
dropped chunk 404 Client Error: Not Found for url: https://api.wandb.ai/files/cas-blaauw/biccn_borzoi_atac/2eppc4j1/file_stream
NoneType: None
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
dropped chunk 404 Client Error: Not Found for url: https://api.wandb.ai/files/cas-blaauw/biccn_borzoi_atac/2eppc4j1/file_stream
NoneType: None
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
dropped chunk 404 Client Error: Not Found for url: https://api.wandb.ai/files/cas-blaauw/biccn_borzoi_atac/2eppc4j1/file_stream
NoneType: None
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
dropped chunk 404 Client Error: Not Found for url: https://api.wandb.ai/files/cas-blaauw/biccn_borzoi_atac/2eppc4j1/file_stream
NoneType: None
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
dropped chunk 404 Client Error: Not Found for url: https://api.