In [1]:
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Cloud AI Alphagenome Fine tuning Notebook
<table align="left">
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Falphagenome%2Fcloudai_alphagenome_finetune.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
</table>

# AlphaGenome Finetuning Tutorial using Google Cloud Platform

This notebook demonstrates how to finetune an AlphaGenome model on custom genomic tracks. It adds Google Cloud Platform utilities to Google DeepMind's
<a href="https://github.com/google-deepmind/alphagenome_research/blob/main/colabs/finetune.ipynb">AlphaGenome Finetuning Tutorial</a>.

**What you'll learn:**

-   How to define custom track metadata for finetuning
-   How to set up the data pipeline for training
-   How to initialize and configure the model with new output heads
-   How to run the training loop with JAX/Haiku
-   How to use the finetuned model for inference
-   How to save 'finetuned trained checkpoints' and 'prediction vs ground truth plot' in Google Cloud Storage

## Get started
Goal: To fine tune Google DeepMind's AlphaGenome model.

Process: The notebook uses four inputs and produces two outputs. The outputs are all stored in your Google Cloud Project.

Terms and conditions (T&Cs): Agree to T&Cs of the respective organization/websites to download the inputs detailed in the below section.

Disclaimer:
- This is an experimental release.
- Check frequently for updated content.

Inputs and outputs:

* Inputs:
    * bigwig files - Can be downloaded from ENCODE portal.
    * fasta files - Can be downloaded from https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_46/GRCh38.p14.genome.fa.gz
    * sequences_human.bed - Can be downloaded from https://github.com/calico/borzoi/raw/5c9358222b5026abb733ed5fb84f3f6c77239b37/data/sequences_human.bed.gz.
    * initial model weights - Can be downloaded from Huggingface
* Outputs:
    * The fine-tuned weights (`.tar.gz` archive of the checkpoint).
    * The predicted vs ground truth plot (.png).

Steps:
* The ***extremely*** important step: carefully set/modify the variables.
* Check the prerequisites.
* Prepare the inputs.
* Follow the notebook and run the cells.
* Utility functions - download the weights after the pipeline is completed.


## 0: Prerequisites

- Install AlphaGenome Research and Google Cloud Platform packages.
- (*) Choose either H100 or A100 specific vm notebook runtime.
- (**) Save Huggingface credentials in Google Cloud Secret manager
- (***) Install 0.9.0 jax libraries

Notebook launch:
- Launch the notebook in Google Cloud Colab Enterprise.
- Use a custom runtime(*) to run the notebook.

PS:
- If a `pip install` cell fails in the notebook, run the same command from the command line instead.
- Restart the session after installing the packages (pay attention to the instructions printed after you run the pip install cell).

(*):
- Request quota for H100 or A100 GPUs
- Create a reservation
- Create a template using the above reservation
- Create a runtime using the above template

(**):
You will be downloading weights from Huggingface.
Ensure that:
- You create a token that has 'Read access to contents of all public gated repos you can access' (under Finegrained control)
- You accept the T&C of the [model](https://huggingface.co/google/alphagenome-fold-0).


(***):
Upgrading to 0.9.0 will require multiple runtime restarts.

In [None]:
from IPython.display import clear_output

# Install the AlphaGenome Research package directly from its GitHub repository.
# PIP_NO_BINARY=pyBigWig asks pip to build pyBigWig from source rather than
# installing a prebuilt wheel.
! PIP_NO_BINARY=pyBigWig pip install git+https://github.com/google-deepmind/alphagenome_research.git
# Clear the (long) pip log from the cell output.
# NOTE(review): this also hides any install error or restart hint printed by
# pip; comment this line out when troubleshooting the installation.
clear_output()

In [None]:
import jax
# Print the installed jax / jaxlib versions.
# The following cells reinstall jax 0.9.0 when a different version is
# reported here.
import jaxlib

print(f"{jax.__version__=}")
print(f"{jaxlib.__version__=}")

In [None]:
# Uninstall the currently installed JAX packages.
# The check below fires for ANY version other than exactly "0.9.0"
# (older or newer), not only for versions below 0.9.0.
# Requires `jax` to have been imported by the version-check cell above.
# Restart the runtime/kernel after uninstalling.
# Resume from the next cell after the kernel restart.
if jax.__version__ != "0.9.0":
    print(f"Unistalling {jax.__version__}")
    ! pip uninstall -y jax jaxlib jax_cuda12_plugin

In [None]:
# Install the pinned JAX version (0.9.0) with the CUDA 12 extra.
# Run only once, after uninstalling the previous jax packages.
# Restart the runtime/kernel afterwards.
# Resume from the next cell after the kernel restart.
# NOTE(review): `cuda12_pip` is an older name for the CUDA extra and the `-f`
# find-links page is the legacy wheel index; confirm both are still honoured
# for jax 0.9.0 (the current JAX install guide uses `jax[cuda12]`).
!pip install --upgrade jax[cuda12_pip]==0.9.0 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# Install the CUDA 12 plugin at the same pinned version (0.9.0) as jax.
# Run only once, after the jax upgrade cell above.
# Restart the runtime/kernel afterwards.
# Resume from the next cell after the kernel restart.
!pip install jax_cuda12_plugin==0.9.0

## 1: Imports

Import the necessary libraries:

-   `alphagenome_research.finetuning` contains the finetuning utilities
-   `alphagenome_research.model` provides the model architecture and metadata
    handling
-   `alphagenome.data` provides genomic data utilities

In [None]:
import dataclasses
import os
import pprint
import subprocess
import time

import haiku as hk
import huggingface_hub
import numpy as np
import optax
import orbax.checkpoint as ocp
import pandas as pd
from alphagenome.data import fold_intervals, genome
from alphagenome.visualization import plot_components
from alphagenome_research.finetuning import dataset as dataset_lib
from alphagenome_research.finetuning import finetune
from alphagenome_research.model import dna_model
from alphagenome_research.model.metadata import metadata as metadata_lib
from etils import epath
from google.cloud import secretmanager
from huggingface_hub import login
from jax import errors
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

## 2: Environment Setup

First, we configure TensorFlow to avoid GPU conflicts since we only use it for
data loading (JAX handles the actual training).

In [None]:
import tensorflow as tf

# TensorFlow is only used for data loading, so hide every local accelerator
# from it and leave the devices to JAX.
for _accelerator_type in ("GPU", "TPU"):
    tf.config.set_visible_devices([], _accelerator_type)

In [None]:
import jax
# Re-check the jax / jaxlib versions after the kernel restarts above.
# The install cells pin jax to 0.9.0, so that is the version expected here.
import jaxlib

print(f"{jax.__version__=}")
print(f"{jaxlib.__version__=}")

In [None]:
# Make sure the notebook runs on a Colab Enterprise runtime that has GPUs
# attached.


def gpu_info() -> None:
    """Print which backend JAX selected and, for GPU, the local device count.

    Nothing is returned. A `JaxRuntimeError` raised while querying the
    backend is reported on stdout instead of being propagated.
    """
    try:
        backend = jax.default_backend()
        # Guard clause: anything other than the GPU backend is just reported.
        if backend != "gpu":
            print(f"JAX default backend is {backend}, not GPU.")
            return
        num_gpus = jax.local_device_count()
        print(f"JAX is using GPU backend with {num_gpus} GPU(s).")
    except errors.JaxRuntimeError as e:
        print(f"JAX runtime error occurred while detecting devices: {e}")


gpu_info()

## 3: Notebook variables setup

### Initialize Google Cloud Platform variables

- `PROJECT_ID`: Google Cloud Project - a string variable.
-  `BUCKET_NAME`: Google Cloud Storage bucket - 'gs://<your bucket name>'

### Initialize local file directories (No need to change)

- `LOCAL_TAR_DIR`: Where the finetuned weights are packaged before being pushed to Google Cloud Storage for later reference.
- `LOCAL_BIGWIG_DIR`: Where the bigwig files are downloaded to.
- `LOCAL_FASTA_DIR`: Where the fasta files are downloaded to.
- `LOCAL_HUMAN_SEQ_DIR`: Where the human sequence files are downloaded to.
- `SAVE_CHECKPOINT_DIR`: Where the finetune checkpoint is stored.

### Model Configuration
Define the key hyperparameters for finetuning:

-   `LEARNING_RATE`: Controls the step size during optimization
-   `MODEL_VERSION`: Which pretrained fold to use (FOLD_0 through FOLD_3)
-   `NUM_TRAIN_STEPS`: Number of training steps for which we optimize the model.
-   `SEQUENCE_LENGTH`: Length of input DNA sequences (1M bp = 2^20). Training
    requires at least 2**17.
-   `BATCH_SIZE`: Number of samples per device.
-   `ORGANISM`: Target organism for predictions. Hard-coded to human for now.

### Google Cloud Project Setup (outside this notebook's scope)
To get started using Google Cloud Platform, you must have an existing Google Cloud project. Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
# --- Google Cloud Platform variables ---
# BUCKET_NAME is used verbatim as the prefix of the gsutil destinations built
# below, so it must include the scheme, e.g. "gs://my-bucket".
PROJECT_ID = "<your project>"  # @param {type:'string'}
BUCKET_NAME = "<your bucket>"  # @param {type: 'string'}

# ---- Local file directories ----
# Scratch locations on the runtime's local disk.
LOCAL_TAR_DIR = "/tmp/tar_outputs"
LOCAL_BIGWIG_DIR = "/tmp/bigwig"
LOCAL_FASTA_DIR = "/tmp/fasta"
LOCAL_HUMAN_SEQ_DIR = "/tmp/example_regions_path"
SAVE_CHECKPOINT_DIR = "/tmp/checkpoint"

# --- AlphaGenome finetuning Model params---
LEARNING_RATE = 5e-4
NUM_TRAIN_STEPS = 1000
MODEL_VERSION = dna_model.ModelVersion.FOLD_0
SEQUENCE_LENGTH = int(2**20)
BATCH_SIZE = 1  # Per device
ORGANISM = dna_model.Organism.HOMO_SAPIENS

# --- derived variables ---
from datetime import datetime

# Archive name: ag-ft-<batch size>-<train steps>-<YYYYmmdd-HHMMSS>.tar.gz,
# with any underscore replaced by a hyphen. The timestamp is taken when this
# cell runs, so re-running the cell produces a new name.
GCS_AGFT_NAME = f"ag-ft-{BATCH_SIZE}-{NUM_TRAIN_STEPS}"
GCS_AGFT_NAME = f"{GCS_AGFT_NAME}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
GCS_AGFT_NAME = f"{GCS_AGFT_NAME}.tar.gz"
GCS_AGFT_NAME = GCS_AGFT_NAME.replace("_", "-")
# Destination of the checkpoint archive in the bucket.
GCS_AGFT_PATH = f"{BUCKET_NAME}/finetune/{GCS_AGFT_NAME}"
# The plot shares the archive's base name, with a .png extension.
GCS_AGFT_PLOT_NAME = GCS_AGFT_NAME.replace(".tar.gz", ".png")
GCS_AGFT_PLOT_PATH = f"{BUCKET_NAME}/finetune/{GCS_AGFT_PLOT_NAME}"

## 4: Inputs preparation

### Input - Fasta files
* Download from the [GENCODE](https://www.gencodegenes.org/human/release_46.html) portal directly.
* unzip it
* Create an index file

In [None]:
# Check if the file does NOT exist
if not os.path.exists(LOCAL_FASTA_DIR):
    ! mkdir -p $LOCAL_FASTA_DIR
    ! echo wget -P "$LOCAL_FASTA_DIR" https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_46/GRCh38.p14.genome.fa.gz
    ! wget -P "$LOCAL_FASTA_DIR" https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_46/GRCh38.p14.genome.fa.gz

    # unzip, create an index file, .tar.gz, and upload to gcs
    ! echo gunzip $LOCAL_FASTA_DIR/GRCh38.p14.genome.fa.gz
    ! gunzip $LOCAL_FASTA_DIR/GRCh38.p14.genome.fa.gz

    # install samtools to create the inded file
    ! echo apt install samtools
    ! apt install samtools

    # create the index file
    # the tool creates GRCh38.p14.genome.fa.fai and stores in the same dir.
    ! echo samtools faidx $LOCAL_FASTA_DIR/GRCh38.p14.genome.fa
    ! samtools faidx $LOCAL_FASTA_DIR/GRCh38.p14.genome.fa
    print("All set to use fasta files in the training.")
else:
    print("Going to use already prepared fasta files.")

### Input - sequences_human.bed
* Download from [GitHub Calico](https://github.com/calico/borzoi/raw/5c9358222b5026abb733ed5fb84f3f6c77239b37/data/sequences_human.bed.gz) portal directly.
* unzip it


In [None]:
if not os.path.exists(LOCAL_HUMAN_SEQ_DIR):
    # create dir
    ! echo mkdir $LOCAL_HUMAN_SEQ_DIR

    # download the file
    ! echo wget -P "$LOCAL_HUMAN_SEQ_DIR" https://github.com/calico/borzoi/raw/5c9358222b5026abb733ed5fb84f3f6c77239b37/data/sequences_human.bed.gz
    ! wget -P "$LOCAL_HUMAN_SEQ_DIR" https://github.com/calico/borzoi/raw/5c9358222b5026abb733ed5fb84f3f6c77239b37/data/sequences_human.bed.gz

    # unzip the file
    ! echo gunzip $LOCAL_HUMAN_SEQ_DIR/sequences_human.bed.gz
    ! gunzip $LOCAL_HUMAN_SEQ_DIR/sequences_human.bed.gz
    print("All set to use human sequence file in the training.")
else:
    print("Going to use already prepared human sequence files.")

### Input - BigWig files
* Download from ENCODE portal directly.

In [None]:
if not os.path.exists(LOCAL_BIGWIG_DIR):
    # create the temp dire
    ! mkdir -p $LOCAL_BIGWIG_DIR

    # downloads the big wig files
    ! pushd $LOCAL_BIGWIG_DIR && curl \
        -C - \
        -Z -O https://storage.googleapis.com/alphagenome/reference/encode/hg38/ENCFF018EZY.bigWig \
        -O https://storage.googleapis.com/alphagenome/reference/encode/hg38/ENCFF904TSK.bigWig \
        -O https://storage.googleapis.com/alphagenome/reference/encode/hg38/ENCFF218CLQ.bigWig && popd
    print("All set to use the bigwig files in the training.")
else:
    print("Going to use already prepared bigwig files.")

## 5: Track Metadata

Define which genomic tracks to finetune on. Each track requires:

-   `name`: Human-readable name.
-   `output_type`: Output type of assay (e.g., `RNA_SEQ`, `DNASE`, `CHIP_TF`,
    `ATAC`). One of `dna_model.OutputType`.
-   `strand`: Strand orientation (`+`, `-`, or `.` for unstranded)
-   `nonzero_mean`: Optional mean of non-zero values (used for normalization).
-   `file_path`: Path to the BigWig file containing the track data.

In [None]:
# One record per track to finetune on; `file_path` points at the BigWig file
# inside LOCAL_BIGWIG_DIR.
TRACK_METADATA = pd.DataFrame(
    data=[
        {
            "output_type": "RNA_SEQ",
            "name": "UBERON:0000948 total RNA-seq",
            "strand": "+",
            "file_path": f"{LOCAL_BIGWIG_DIR}/ENCFF018EZY.bigWig",
        },
        {
            "output_type": "RNA_SEQ",
            "name": "UBERON:0000948 total RNA-seq",
            "strand": "-",
            "file_path": f"{LOCAL_BIGWIG_DIR}/ENCFF904TSK.bigWig",
        },
        {
            "output_type": "DNASE",
            "name": "EFO:0005337 DNase-seq",
            "strand": ".",
            "file_path": f"{LOCAL_BIGWIG_DIR}/ENCFF218CLQ.bigWig",
        },
    ],
    # Pin the column order explicitly.
    columns=["output_type", "name", "strand", "file_path"],
)
TRACK_METADATA

### Build Output Metadata

Convert the track DataFrame into an `AlphaGenomeOutputMetadata` object that
configures the model's output heads.

In [None]:
def build_output_metadata(
    track_metadata: pd.DataFrame,
) -> metadata_lib.AlphaGenomeOutputMetadata:
    """Groups the track table by output type into an AlphaGenomeOutputMetadata.

    Args:
      track_metadata: A pandas DataFrame with one row per track. It must
        contain the columns 'output_type', 'name', 'strand' and 'file_path'.
        Each 'output_type' value must be the name of a `dna_model.OutputType`
        member.

    Returns:
      An AlphaGenomeOutputMetadata built with one keyword argument per output
      type present in the table (the lower-cased enum member name), whose
      value is the sub-DataFrame of tracks with that output type.

    Raises:
      ValueError: If a required column is missing, or an 'output_type' value
        is not a `dna_model.OutputType` member name.
    """
    required_cols = {"file_path", "name", "output_type", "strand"}
    has_all_columns = required_cols.issubset(track_metadata.columns)
    if not has_all_columns:
        raise ValueError(
            f"track_metadata must have columns {required_cols}. Missing: {required_cols - set(track_metadata.columns)}."
        )

    tracks_by_type = {}
    for output_type, df_group in track_metadata.groupby("output_type"):
        # Resolve the group key (a string) to its enum member.
        try:
            resolved_type = dna_model.OutputType[str(output_type)]
        except KeyError as e:
            raise ValueError(f"Unknown output_type: {output_type}") from e
        keyword = resolved_type.name.lower()
        tracks_by_type[keyword] = df_group

    return metadata_lib.AlphaGenomeOutputMetadata(**tracks_by_type)


# Per-organism output metadata. Only the human entry is populated; later
# cells look it up by organism and pass the whole dict to
# `finetune.get_forward_fn`.
output_metadata = {
    dna_model.Organism.HOMO_SAPIENS: build_output_metadata(TRACK_METADATA)
}

## 6: Data Pipeline

Set up the training data iterator. This loads genomic intervals and
corresponding track values from the specified BigWig files.

In [None]:
# Collect the dataset arguments first so each one can be annotated.
dataset_kwargs = dict(
    # BATCH_SIZE is per device; scale it by the number of local devices.
    batch_size=BATCH_SIZE * jax.local_device_count(),
    sequence_length=SEQUENCE_LENGTH,
    output_metadata=output_metadata[ORGANISM],
    organism=ORGANISM,
    model_version=MODEL_VERSION,
    # Iterate over the training subset of the fold intervals.
    subset=fold_intervals.Subset.TRAIN,
    # Files produced by the input-preparation cells.
    fasta_path=f"{LOCAL_FASTA_DIR}/GRCh38.p14.genome.fa",
    example_regions_path=f"{LOCAL_HUMAN_SEQ_DIR}/sequences_human.bed",
)
ds_iter = finetune.get_dataset_iterator(**dataset_kwargs)

In [None]:
# Pull one batch from the iterator and print the shape of every array in it
# as a sanity check. This batch is consumed from `ds_iter`; it is reused
# later as the example input for initializing the new heads.
batch = next(ds_iter)
pprint.pprint(jax.tree.map(np.shape, batch))

## 7: Model Initialization

Load the pretrained AlphaGenome checkpoint and initialize new output heads for
the finetuning tracks.

You will be downloading weights from Huggingface.
Ensure that:
- You create a token that has 'Read access to contents of all public gated repos you can access' (under Finegrained control)
- You accept the T&C of the [model](https://huggingface.co/google/alphagenome-fold-0).

PS: A "401 Unauthorized" error in the cell below may occur if the Terms and Conditions (T&C) have not been accepted.

In [None]:
# setup Huggingface credential and download the base model weights


def setup_huggingface_auth(secret_id="HUGGINGFACE_API_TOKEN", version_id="latest"):
    """Fetches the Hugging Face token stored in Google Cloud Secret Manager.

    The project is taken from the GOOGLE_CLOUD_PROJECT environment variable,
    falling back to `gcloud config get-value project`. This function only
    retrieves the token; logging in with it is left to the caller.

    Args:
      secret_id: Name of the Secret Manager secret holding the token.
      version_id: Secret version to read.

    Returns:
      The token string with surrounding whitespace stripped, or None when the
      project cannot be determined, the secret is empty, or any error occurs
      (the reason is printed).
    """
    try:
        project_id = os.environ.get("GOOGLE_CLOUD_PROJECT")
        print(f"{project_id=}")
        if not project_id:
            try:
                project_id = (
                    subprocess.check_output(
                        ["gcloud", "config", "get-value", "project"]
                    )
                    .decode("utf-8")
                    .strip()
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                # gcloud exited with an error, or is not installed at all.
                project_id = ""
            # gcloud exits successfully with empty output when no project is
            # configured; treat that as "unknown" rather than building a
            # malformed `projects//secrets/...` resource name.
            # "(unset)" is checked defensively in case it lands on stdout.
            if not project_id or project_id == "(unset)":
                print("Could not automatically determine GCP Project ID.")
                return None

        client = secretmanager.SecretManagerServiceClient()
        name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
        response = client.access_secret_version(name=name)
        hf_token = response.payload.data.decode("UTF-8").strip()

        if not hf_token:
            print(f"Secret {secret_id} is empty.")
            return None

        print("Hugging Face token retrieved from Secret Manager.")
        return hf_token

    except Exception as e:
        # Deliberately best-effort: report the problem and let the caller
        # decide what to do without a token.
        print(f"Error setting up Hugging Face auth: {e}")
        return None


# Retrieve the token; `hf_token` is None when retrieval failed.
hf_token = setup_huggingface_auth()

# The three login options below are all attempted when a token is available;
# a failure in one is printed and does not stop the others.
if hf_token:
    # Option 1: Log in using huggingface-cli
    # This makes the token available for CLI commands and many libraries.
    # NOTE(review): the token is passed as a command-line argument here.
    try:
        # Use subprocess to handle the interactive nature of login
        process = subprocess.Popen(
            ["huggingface-cli", "login", "--token", hf_token],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = process.communicate()
        if process.returncode == 0:
            print("Hugging Face CLI login successful.")
        else:
            print(f"Hugging Face CLI login failed: {stderr.decode()}")
        print(stdout.decode())

    except FileNotFoundError:
        print("huggingface-cli not found. Make sure huggingface_hub is installed.")
    except Exception as e:
        print(f"Error during huggingface-cli login: {e}")

    # Option 2: Set as environment variable (useful for some tools)
    os.environ["HF_TOKEN"] = hf_token
    print("HF_TOKEN environment variable set.")

    # Option 3: Programmatic login with huggingface_hub

    try:
        login(token=hf_token)
        print("huggingface_hub programmatic login successful.")
    except Exception as e:
        print(f"huggingface_hub login error: {e}")

# Download the base checkpoint. This runs even when `hf_token` is None, in
# which case no login was attempted above and access to the gated repository
# is expected to be refused.
# The repo id is derived from the model version, e.g. FOLD_0 ->
# "google/alphagenome-fold-0".
repo = f"google/alphagenome-{MODEL_VERSION.name.lower().replace('_', '-')}"
checkpoint_path = huggingface_hub.snapshot_download(repo_id=repo)
# Restore the pretrained (params, state) pair from the downloaded snapshot.
checkpointer = ocp.StandardCheckpointer()
params_base, state_base = checkpointer.restore(checkpoint_path)

### Set Up Device Mesh for Data Parallelism

In [None]:
# Build a one-dimensional mesh over all local devices with a single axis
# named "data".
num_devices = jax.local_device_count()
devices = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices, axis_names=("data",))
# Partition spec that shards along the "data" mesh axis (used for batches).
data_sharding = P("data")
# Empty partition spec: the value is replicated (used for params and state).
replicated_sharding = P()

### Initialize New Output Heads

Create the forward function configured for our finetuning tracks and initialize
the new head parameters.

In [None]:
# Build the forward function for the finetuning tracks. Note that it receives
# the whole per-organism `output_metadata` dict.
forward_fn = finetune.get_forward_fn(output_metadata)
with jax.set_mesh(mesh):
    # Reuse the batch pulled in the shape-validation cell as the example input.
    batch = jax.device_put(batch, data_sharding)
    # jit-compile `init` with a replicated RNG key and a data-sharded batch;
    # the resulting params/state are replicated.
    params_ft, state_ft = jax.jit(
        forward_fn.init,
        in_shardings=(replicated_sharding, data_sharding),
        out_shardings=replicated_sharding,
    )(jax.random.PRNGKey(0), batch)

### Merge Pretrained Trunk with New Heads

Perform weight surgery: keep the pretrained trunk parameters and replace the
head parameters with the newly initialized ones.

In [None]:
def _is_head_module(module_name, *_):
    """Haiku filter predicate: True when the module name contains "head"."""
    return "head" in module_name


# Freshly initialized head parameters for the finetuning tracks.
params_ft_head = hk.data_structures.filter(_is_head_module, params_ft)
# Pretrained parameters with every head module dropped.
params_base_no_head = hk.data_structures.filter(
    lambda *entry: not _is_head_module(*entry), params_base
)
# Weight surgery: pretrained trunk plus the new heads.
params = hk.data_structures.merge(params_base_no_head, params_ft_head)
# The model state is taken unchanged from the pretrained checkpoint.
state = state_base

# Clip gradients to a global norm of 0.5, then apply Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(LEARNING_RATE),
)
opt_state = optimizer.init(params)

# The training loop calls this as
# train_step(params, state, opt_state, batch) -> (params, state, opt_state, scalars).
# Only the batch is data-sharded; every other input and all four outputs are
# replicated.
train_step = jax.jit(
    finetune.get_train_step(forward_fn.apply, optimizer),
    in_shardings=(replicated_sharding,) * 3 + (data_sharding,),
    out_shardings=(replicated_sharding,) * 4,
)

## 8: Training Loop

Set up checkpointing and run the finetuning training loop.

### Configure Checkpoint Directory

In [None]:
# Give every run its own timestamped sub-directory under SAVE_CHECKPOINT_DIR.
# `datetime` is imported in the notebook-variables cell.
path_suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = epath.Path(SAVE_CHECKPOINT_DIR) / path_suffix
checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Keep a plain string from here on; `save` joins it with os.path.join.
checkpoint_dir = str(checkpoint_dir)
print(f"We will be saving the trained checkpoint at {checkpoint_dir}")

In [None]:
# Checkpointer used by `save` below. This rebinds the `checkpointer` name
# that was used earlier to restore the pretrained weights.
checkpointer = ocp.StandardCheckpointer()


def save(weights, idx):
    """Writes `weights` to `<checkpoint_dir>/checkpoint_<idx>` and waits.

    Args:
      weights: The object handed to the module-level `checkpointer` to save.
      idx: Integer used in the directory name, zero-padded to five digits.

    Returns:
      The path the checkpoint was written to.
    """
    ckpt_path = os.path.join(checkpoint_dir, f"checkpoint_{idx:05d}")
    print(f"Saving checkpoint to {ckpt_path}")
    checkpointer.save(ckpt_path, weights)
    # Block until the save has finished before reporting success.
    checkpointer.wait_until_finished()
    print(f"Saved checkpoint to {ckpt_path}")
    return ckpt_path

### Run Training and Save chkpt

In [None]:
# Finetuning loop: at most NUM_TRAIN_STEPS optimizer steps, stopping early if
# the dataset iterator runs out.
# `loss` holds the loss of every completed step; `times` holds the wall-clock
# time of every completed step.
loss, times = [], []
for step in range(NUM_TRAIN_STEPS):
    start_time = time.time()
    try:
        batch = next(ds_iter)
    except StopIteration:
        print("Dataset exhausted")
        break
    with jax.set_mesh(mesh):
        batch = jax.device_put(batch, data_sharding)
        params, state, opt_state, scalars = train_step(params, state, opt_state, batch)
    loss.append(scalars["loss"])
    times.append(time.time() - start_time)
    # Report every 10 steps starting at step 1, so times[1:] is never empty.
    if step % 10 == 1:
        print("loss", step, loss[-1], f"SPS: {1./np.mean(times[1:]):.4f}")

# The first step is left out of the timing statistics, as in the in-loop
# report above.
steady_state_times = times[1:]
if steady_state_times:
    print(f"Total Training time: {np.sum(steady_state_times):.4f} seconds")
    print(
        f"Average Training time per step: {np.mean(steady_state_times):.4f} seconds"
    )
else:
    # Avoids taking the mean of an empty slice (nan plus a RuntimeWarning).
    print("Fewer than two steps completed; no steady-state timing to report.")

# Label the checkpoint with the number of steps that actually completed.
# `step + 1` over-counts by one when the loop ends on StopIteration, and
# `step` is undefined when NUM_TRAIN_STEPS is 0. After a full run this equals
# NUM_TRAIN_STEPS, exactly as before.
completed_steps = len(times)
ckpt_path = save((params, state), completed_steps)

### Upload the chkpt to Google Cloud Storage

For future reference and usage, save a copy in the Google Cloud Storage

In [None]:
# Make sure the local staging directory for the archive exists.
# The archive is then uploaded to GCS_AGFT_PATH.
os.makedirs(LOCAL_TAR_DIR, exist_ok=True)

# tar the files
# Each command is echoed first so the expanded command line is visible in the
# cell output.
!echo tar -czvf $LOCAL_TAR_DIR/$GCS_AGFT_NAME $ckpt_path
!tar -czvf $LOCAL_TAR_DIR/$GCS_AGFT_NAME $ckpt_path

# list the tar file
!echo ls -l $LOCAL_TAR_DIR/$GCS_AGFT_NAME
!ls -l $LOCAL_TAR_DIR/$GCS_AGFT_NAME

# upload the tar file to gcs bucket
!echo gsutil cp $LOCAL_TAR_DIR/$GCS_AGFT_NAME $GCS_AGFT_PATH
!gsutil cp $LOCAL_TAR_DIR/$GCS_AGFT_NAME $GCS_AGFT_PATH

# list the uploaded tar file in the gcs bucket
!echo gsutil ls $GCS_AGFT_PATH
!gsutil ls $GCS_AGFT_PATH

## 9: Inference with Finetuned Model

Load the finetuned checkpoint into a `DnaModel` for inference and compare
predictions against ground truth.

In [None]:
# Load default organism settings but overwrite with fine-tuned output metadata.
default_settings_human = dna_model.default_organism_settings()[
    dna_model.Organism.HOMO_SAPIENS
]
# Copy of the default human settings whose `metadata` field is replaced by the
# finetuning tracks' metadata.
settings_human_finetune = dataclasses.replace(
    default_settings_human,
    metadata=output_metadata[dna_model.Organism.HOMO_SAPIENS],
)
# Build the inference model from the checkpoint written by the training cell.
model = dna_model.create(
    ckpt_path,
    organism_settings={dna_model.Organism.HOMO_SAPIENS: settings_human_finetune},
)

### Select Test Interval

In [None]:
# Start from a 1,500 bp region on chr21 and resize it to SEQUENCE_LENGTH,
# the input length used for training.
interval = genome.Interval(chromosome="chr21", start=46125238, end=46126738).resize(
    SEQUENCE_LENGTH
)

In [None]:
# Predict the RNA-seq output for the test interval with the finetuned model.
# No ontology terms are passed (ontology_terms=None).
preds = model.predict_interval(
    interval,
    requested_outputs=[dna_model.OutputType.RNA_SEQ],
    ontology_terms=None,
)
preds

### Load Ground Truth Tracks

In [None]:
# Extract the ground-truth track values for the same interval, using the
# finetuning tracks' metadata.
true_tracks = dataset_lib.MultiTrackExtractor(
    output_metadata[ORGANISM], sequence_length=SEQUENCE_LENGTH
).extract(interval)

### Visualize Predictions vs Ground Truth

In [None]:
def compact_dict(**kwargs):
    """Returns the keyword arguments as a dict, leaving out None values.

    Only `None` is dropped; other falsy values (0, "", empty containers) are
    kept. Keys keep the order in which they were passed.
    """
    kept = {}
    for key, value in kwargs.items():
        if value is None:
            continue
        kept[key] = value
    return kept


def plot(*, interval, predictions, targets=None):
    """Plots predicted RNA-seq tracks, optionally overlaid with ground truth.

    Args:
      interval: Interval to plot around; it is resized to 2**11 bp for the
        figure.
      predictions: Object whose `rna_seq` attribute holds the predicted
        tracks.
      targets: Optional mapping with an "rna_seq" array of ground-truth
        values. When given, it is drawn in red over the black predictions.

    Returns:
      The figure produced by `plot_components.plot`.
    """
    colors = {"pred": "black"}
    true_track = None
    if targets is not None:
        colors["true"] = "red"
        # Reuse the prediction's track object as a container and swap in the
        # ground-truth values cast to float32.
        true_track = dataclasses.replace(
            predictions.rna_seq,
            values=targets["rna_seq"].astype(np.float32),
        )

    overlay = plot_components.OverlaidTracks(
        tdata=compact_dict(pred=predictions.rna_seq, true=true_track),
        colors=colors,
        shared_y_scale=True,
    )
    fig = plot_components.plot(
        [overlay],
        interval=interval.resize(int(2**11)),
    )
    return fig


# Draw predictions (black) against ground truth (red) for the test interval
# and keep the figure so the next cell can save it.
pred_vs_groundtruth_fig = plot(
    predictions=preds, interval=interval, targets=true_tracks
)

### Save the plot and upload to Google Cloud Storage for later reference

In [None]:
# Save the figure to a file
# The file goes next to the checkpoint archive in LOCAL_TAR_DIR, which the
# checkpoint-upload cell creates.
file_path = (
    f"{LOCAL_TAR_DIR}/{GCS_AGFT_PLOT_NAME}"  # Choose your desired filename and format
)
pred_vs_groundtruth_fig.savefig(file_path)
print(f"Plot saved to {file_path}")

# upload the image file to gcs bucket
# Each command is echoed first so the expanded command line is visible.
!echo gsutil cp $file_path $GCS_AGFT_PLOT_PATH
!gsutil cp $file_path $GCS_AGFT_PLOT_PATH

# list the uploaded plot file in the gcs bucket
!echo gsutil ls $GCS_AGFT_PLOT_PATH
!gsutil ls $GCS_AGFT_PLOT_PATH