In [1]:
import glob, os, subprocess
from IPython.display import IFrame

In [5]:
# get working directory
print(f"Current working directory: {os.getcwd()}")
# change to parent directory
os.chdir("..")
print(f"Changed working directory to: {os.getcwd()}")

Current working directory: /Users/library/Documents/ML_project/basenji_genomics_ai/experiments/reproduction/notebooks
Changed working directory to: /Users/library/Documents/ML_project/basenji_genomics_ai/experiments/reproduction
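
Re-running the cell above moves one directory higher each time, which can silently break the relative paths used later. A minimal sketch of a re-runnable alternative follows; the helper name `enter_project_root` and the assumption that the notebook folder is called `notebooks` are illustrative, not part of the project.

```python
import os
from pathlib import Path

def enter_project_root(notebook_dir_name="notebooks"):
    """Move to the parent directory only while still inside the notebooks folder."""
    cwd = Path.cwd()
    if cwd.name == notebook_dir_name:
        os.chdir(cwd.parent)
    return Path.cwd()

# Safe to run repeatedly: the second call leaves the directory unchanged.
print(f"Working directory: {enter_project_root()}")
```

Guarding on the folder name keeps the cell safe to re-run, at the cost of hard-coding that name.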


## Precursors

To train a model, you first need to convert your sequences and targets into the input TFRecord format. Check out my tutorials for how to do that; they're linked from the [main page](../README.md).

For this tutorial, grab a small example dataset, *heart_l131k*, stored as TFRecords with three heart-related targets (aorta, artery, and pulmonic valve). The cell below downloads and unpacks it if no TFRecord files are found under data/heart_l131k/tfrecords.

In [None]:
if len(glob.glob('data/heart_l131k/tfrecords/*.tfr')) == 0:
    subprocess.call('curl -o data/heart_l131k.tgz https://storage.googleapis.com/basenji_tutorial_data/heart_l131k.tgz', shell=True)
    subprocess.call('tar -xzvf data/heart_l131k.tgz', shell=True)
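
`subprocess.call` ignores a failed download, so a broken archive would only surface later. A sketch of a stricter variant of the download step follows; the function name `fetch_dataset` and its injectable `run` parameter are hypothetical, added here so the logic can be exercised without network access.

```python
import glob
import subprocess

DATA_URL = "https://storage.googleapis.com/basenji_tutorial_data/heart_l131k.tgz"

def fetch_dataset(url=DATA_URL, archive="data/heart_l131k.tgz",
                  pattern="data/heart_l131k/tfrecords/*.tfr", run=subprocess.run):
    """Download and unpack the archive unless TFRecord files already exist."""
    if glob.glob(pattern):
        return False  # data already present, nothing to do
    # --fail makes curl exit non-zero on HTTP errors; check=True raises on that.
    run(["curl", "--fail", "-o", archive, url], check=True)
    # Extracts into the current working directory, as in the cell above.
    run(["tar", "-xzf", archive], check=True)
    return True
```

Calling `fetch_dataset()` from the project root should mirror the cell above, except that a failed `curl` or `tar` raises `subprocess.CalledProcessError` instead of passing unnoticed.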

## Train

Next, you need to decide what sort of architecture to use. This grammar probably needs work; my goal was to let hyperparameter searches write their parameters to file so that I could run parallel training jobs to explore the hyperparameter space. I included an example set of parameters that will work well with this data in models/params_small.json.

Then, run [basenji_train.py](https://github.com/calico/basenji/blob/master/bin/basenji_train.py) to train a model. The program will offer training feedback via stdout and write the model output files to the directory given by the *-o* option.

The most relevant options here are:

| Option/Argument | Value | Note |
|:---|:---|:---|
| -o | models/heart | Directory to save training logs and model checkpoints. |
| params_file | models/params_small.json | JSON-specified parameters to set up the model architecture and optimization. |
| data_dir | data/heart_l131k | Data directory containing the input and output datasets as generated by [basenji_data.py](https://github.com/calico/basenji/blob/master/bin/basenji_data.py) |
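
The options in the table can also be assembled as an argument list, which avoids shell quoting problems when the paths come from variables. This is a sketch only: `build_train_command` is a hypothetical helper, and it assumes basenji_train.py is reachable from the working directory, as in the command used in this tutorial.

```python
import subprocess
import sys

def build_train_command(out_dir, params_file, data_dir, script="basenji_train.py"):
    """Assemble the argument list for the training command shown in this tutorial."""
    return [sys.executable, script, "-o", out_dir, params_file, data_dir]

cmd = build_train_command("models/heart", "models/params_small.json", "data/heart_l131k")
print(" ".join(cmd))

# To launch training (may take several hours), uncomment:
# subprocess.run(cmd, check=True)
```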

If you want to train, uncomment the command in the following cell and run it. Depending on your hardware, it may require several hours.

In [None]:
# Training is commented out because it takes too long;
# a trained model is downloaded in the Test section below.
# ! python basenji_train.py -o models/heart models/params_small.json data/heart_l131k

If the MLflow server will not start because port 8080 is already in use, free the port from a terminal and start the server again:

# Find what's using port 8080
lsof -ti:8080

# Kill the process
kill -9 $(lsof -ti:8080)

# Then try starting MLflow again
mlflow server --host 127.0.0.1 --port 8080
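
Before killing anything on port 8080, it can help to check from Python whether the port is actually occupied. The sketch below uses only the standard library; `port_in_use` is an illustrative helper name, and the check assumes the server listens on 127.0.0.1 as in the `mlflow server` command above.

```python
import socket

def port_in_use(port, host="127.0.0.1"):
    """Return True if something is already listening on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1.0)
        # connect_ex returns 0 when the connection succeeds, i.e. a listener exists.
        return sock.connect_ex((host, port)) == 0

if port_in_use(8080):
    print("Port 8080 is busy; stop the existing process before starting MLflow.")
else:
    print("Port 8080 is free.")
```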

## Test

Alternatively, you can just download a trained model.

In [7]:
if not os.path.isdir('models/heart'):
    os.makedirs('models/heart')
if not os.path.isfile('models/heart/model_best.h5'):
    subprocess.call('curl -o models/heart/model_best.h5 https://storage.googleapis.com/basenji_tutorial_data/model_best.h5', shell=True)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1157k  100 1157k    0     0   442k      0  0:00:02  0:00:02 --:--:--  444k


models/heart/model_best.h5 is now the saved model file to provide to other programs.

To further benchmark the accuracy (e.g. computing significant "peak" accuracy), use [basenji_test.py](https://github.com/calico/basenji/blob/master/bin/basenji_test.py).

The most relevant options here are:

| Option/Argument | Value | Note |
|:---|:---|:---|
| --ai | 0,1,2 | Make accuracy scatter plots for targets 0, 1, and 2. |
| -o | output/heart_test | Output directory. |
| --rc | | Average the forward and reverse complement to form an ensemble predictor. |
| --shifts | 1,0,-1 | Average sequence shifts to form an ensemble predictor. |
| params_file | models/params_small.json | JSON-specified parameters to set up the model architecture and optimization. |
| model_file | models/heart/model_best.h5 | Trained saved model parameters. |
| data_dir | data/heart_l131k | Data directory containing the test input and output datasets as generated by [basenji_data.py](https://github.com/calico/basenji/blob/master/bin/basenji_data.py) |
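
To make the --rc and --shifts rows concrete, here is a toy NumPy sketch of ensemble averaging over reverse complements and small shifts. It is not Basenji's implementation: the function names, the A, C, G, T column order, and the 0.25 padding value are assumptions. It also does not realign predictions after shifting, which is only reasonable when the shifts are small relative to the output resolution.

```python
import numpy as np

def reverse_complement(seq_1hot):
    """Reverse-complement a (length, 4) one-hot array with columns ordered A, C, G, T."""
    return seq_1hot[::-1, ::-1]

def shift_sequence(seq_1hot, shift, pad_value=0.25):
    """Shift right (shift > 0) or left (shift < 0), filling the gap with pad_value."""
    if shift == 0:
        return seq_1hot.copy()
    shifted = np.full_like(seq_1hot, pad_value, dtype=float)
    if shift > 0:
        shifted[shift:] = seq_1hot[:-shift]
    else:
        shifted[:shift] = seq_1hot[-shift:]
    return shifted

def ensemble_predict(predict_fn, seq_1hot, shifts=(0,), use_rc=False):
    """Average predict_fn over the requested shifts and, optionally, both strands."""
    preds = []
    for shift in shifts:
        shifted = shift_sequence(seq_1hot, shift)
        preds.append(predict_fn(shifted))
        if use_rc:
            # Predict on the reverse complement, then flip the output along the
            # position axis so it lines up with the forward-strand prediction.
            preds.append(predict_fn(reverse_complement(shifted))[::-1])
    return np.mean(preds, axis=0)

# Toy "model": reports the A channel at each position.
toy_model = lambda x: x[:, [0]]
seq = np.eye(4)  # one-hot encoding of the sequence ACGT
print(ensemble_predict(toy_model, seq, shifts=(1, 0, -1), use_rc=True).ravel())
```

With a real model, `predict_fn` would wrap the model's prediction call for a single sequence.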

In [None]:
! python basenji_test.py --ai 0,1,2 -o output/heart_test --rc --shifts "1,0,-1" models/params_small.json models/heart/model_best.h5 data/heart_l131k

2025-09-29 06:53:47.104635: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 sequence (InputLayer)       [(None, 131072, 4)]          0         []                            
                                                                                                  
 stochastic_reverse_complem  ((None, 131072, 4),          0         ['sequence[0][0]']            
 ent (StochasticReverseComp   ())                                                                 
 lement)                                                                     

In [17]:
# test with mlflow tracking
! python mlflow_test_tracking.py

2025/09/29 11:28:47 INFO mlflow.tracking.fluent: Experiment with name 'basenji_genomics' does not exist. Creating a new experiment.
Running basenji_test.py...
Logged metrics from output/heart_test/acc.txt
Testing completed! Run ID: 9cbf68286aa54dedb173d006c1cadc86
Test results logged to MLflow
Artifacts: output/heart_test
2025/09/29 11:29:03 INFO mlflow.tracking._tracking_service.client: 🏃 View run basenji_heart_testing at: http://127.0.0.1:8080/#/experiments/567337574845174630/runs/9cbf68286aa54dedb173d006c1cadc86.
2025/09/29 11:29:03 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:8080/#/experiments/567337574845174630.
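
The contents of mlflow_test_tracking.py are not shown in this notebook. The sketch below is one guess at its core step: reading acc.txt and logging each value to MLflow. The function names and the metric naming scheme are hypothetical; the experiment name, run name, and tracking URI are taken from the log output above.

```python
import csv

def metrics_from_acc(path):
    """Turn each row of acc.txt into pearsonr/r2 metrics keyed by target description."""
    metrics = {}
    with open(path, newline="") as handle:
        for row in csv.DictReader(handle, delimiter="\t"):
            name = row["description"]
            metrics[f"pearsonr_{name}"] = float(row["pearsonr"])
            metrics[f"r2_{name}"] = float(row["r2"])
    return metrics

def log_test_run(acc_path, artifact_dir, tracking_uri="http://127.0.0.1:8080"):
    """Log metrics and the output directory to MLflow (requires the mlflow package)."""
    import mlflow  # imported lazily so the parsing helper works without mlflow

    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment("basenji_genomics")
    with mlflow.start_run(run_name="basenji_heart_testing") as run:
        mlflow.log_metrics(metrics_from_acc(acc_path))
        mlflow.log_artifacts(artifact_dir)
        return run.info.run_id
```

With the MLflow server running, `log_test_run("output/heart_test/acc.txt", "output/heart_test")` would record one run and return its ID.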


View the MLflow results at http://127.0.0.1:8080.

*output/heart_test/acc.txt* is a table specifying the Pearson correlation and R2 for each dataset.
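
The two metrics measure different things. The r2 values in this tutorial's table are lower than the squared Pearson values (for target 0, 0.51173² ≈ 0.262 versus 0.19405). That is consistent with r2 being a coefficient of determination rather than a squared correlation, but treat that reading as an assumption. A small NumPy sketch of both quantities on made-up numbers:

```python
import numpy as np

def pearson_r(y_true, y_pred):
    """Pearson correlation: insensitive to the scale and offset of the predictions."""
    return float(np.corrcoef(y_true, y_pred)[0, 1])

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Toy values, not taken from the model: perfectly correlated but on the wrong scale.
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = 2.0 * y_true
print(pearson_r(y_true, y_pred), r_squared(y_true, y_pred))
```

The toy predictions have a perfect Pearson correlation yet a negative R2, because R2 also penalizes errors in scale.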

In [13]:
! cat output/heart_test/acc.txt

index	pearsonr	r2	identifier	description
0	0.51173	0.19405	CNhs11760	aorta
1	0.64497	0.39054	CNhs12843	artery
2	0.50629	0.20107	CNhs12856	pulmonic_valve


The directories *pr*, *roc*, *violin*, and *scatter* in *output/heart_test* contain plots for the targets indexed by 0, 1, and 2, as specified by the --ai option above.

E.g.

In [None]:
IFrame('output/heart_test/pr/t0.pdf', width=600, height=500)
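
To see which plots are available before opening one, a short sketch can list the PDFs in each plot directory. `list_target_plots` is an illustrative helper name; the directory names come from the text above, and the `*.pdf` pattern is assumed from the `t0.pdf` example.

```python
import glob
import os

def list_target_plots(out_dir="output/heart_test",
                      kinds=("pr", "roc", "violin", "scatter")):
    """Map each plot type to the sorted PDF files found in its subdirectory."""
    return {kind: sorted(glob.glob(os.path.join(out_dir, kind, "*.pdf")))
            for kind in kinds}

for kind, files in list_target_plots().items():
    print(f"{kind}: {len(files)} file(s)")
```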