# Finetune and Run Zero-Shot Prediction with Evo 2

This tutorial shows how to finetune Evo2 with the BioNeMo Framework so that it is robust to both the fp8 and bf16 datatypes. It then demonstrates using the finetuned model for zero-shot prediction of gene variant effects. After finetuning, the model should perform well on downstream sequence scoring tasks.

### Motivation
Evo2 is a foundation AI model trained on 9.3 trillion DNA base pairs that can predict variant effects without prior task-specific training.

The [public HuggingFace Evo2 1b model checkpoint](https://huggingface.co/arcinstitute/evo2_1b_base) was trained with the fp8 datatype and is sensitive to it: zero-shot inference without `--fp8` can produce near-random AUCs.
To infer or score new data with this checkpoint, you therefore need fp8 enabled. Without fp8, the output that you get from scoring sequences with
sensitive checkpoints may not be biologically meaningful.

Note: This issue does not occur with the larger 7b and 40b parameter models. Those versions of the model can be used directly as the checkpoint for the zero-shot prediction section of this example. The finetuning section of this tutorial can also be skipped if you are on FP8-compatible hardware.
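
Whether the finetuning section can be skipped depends on the GPU. The following is a minimal sketch for checking this, under the assumption that FP8 support starts at CUDA compute capability 8.9 (Ada) and 9.0 (Hopper). The helper names `supports_fp8` and `detect_fp8_support` are hypothetical and not part of BioNeMo.

```python
def supports_fp8(compute_capability):
    """Return True if a (major, minor) CUDA compute capability is 8.9 or newer."""
    # Assumption: FP8 tensor cores are available from compute capability 8.9 onward.
    return tuple(compute_capability) >= (8, 9)


def detect_fp8_support():
    """Check the first visible GPU; return False when torch or CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    # get_device_capability returns a (major, minor) tuple for the given device
    return supports_fp8(torch.cuda.get_device_capability(0))


print("FP8-capable GPU detected:", detect_fp8_support())
```

If this prints `False`, the bf16 path described in this tutorial (finetuning the 1b checkpoint, or using the 7b model) is likely the relevant one.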

### Requirements
* GPU: An NVIDIA GPU with approximately 45 GB of memory
* Container: `nvcr.io/nvidia/clara/bionemo-framework:2.7`
* Storage: 50 GB of persistent storage for the checkpoints and datasets

In [None]:
import os
from bionemo.core.utils.subprocess_utils import run_subprocess_safely
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tensorboard.backend.event_processing.event_accumulator as event_accumulator

In [None]:
DATA_DIR = os.environ.get('DATA_DIR', os.path.abspath("datasets") ) ### change this line if you want to use a different DATA_DIR path
RESULTS_DIR = os.environ.get('RESULTS_DIR', os.path.abspath("results") ) ### change this line if you want to use a different RESULTS_DIR path

### Finetuning EVO2

#### Preprocess and Configure Training Data
Evo2 uses Megatron-style datasets behind the scenes, with advanced support for randomly indexing into documents and
packing documents together into batches at scale. The file formats backing these datasets are not standard biological formats such as fasta for representing genomes.

To start, we will need to preprocess the fasta files into the required data format for downstream handling.

This should be done by:
1. Acquiring fasta files
2. Writing a config script defining how you want the processed files to be generated from the fasta files. This is where
  you specify top level train/validation/test splitting decisions.
3. Calling the actual `preprocess_evo2` script to generate the results.

The next 4 cells go through this process on a set of 3 smaller human chromosomes. At least 3 fasta records need to be present,
one each for the train, validation, and test splits. To use your own dataset, set the `FASTA_DATA_PATH` variable to the path of your dataset.
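
Before preprocessing your own data, it can help to confirm that the FASTA file contains enough records for the three splits. The sketch below is one way to do that with the standard library only. The helpers `count_fasta_records` and `check_enough_records` are hypothetical names introduced for illustration.

```python
import gzip
from pathlib import Path


def count_fasta_records(fasta_path):
    """Count header lines (those starting with '>') in a plain or gzipped FASTA file."""
    path = Path(fasta_path)
    opener = gzip.open if path.suffix == ".gz" else open
    with opener(path, "rt") as handle:
        return sum(1 for line in handle if line.startswith(">"))


def check_enough_records(fasta_path, minimum=3):
    """Raise ValueError if the FASTA has fewer records than there are splits."""
    n_records = count_fasta_records(fasta_path)
    if n_records < minimum:
        raise ValueError(f"{fasta_path} has {n_records} record(s); need at least {minimum}.")
    return n_records
```

Once the FASTA file exists, calling `check_enough_records` on its path gives an early failure instead of a later error during preprocessing.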

In [None]:
FINETUNE_DATA_DIR=os.path.join(DATA_DIR, "finetune_evo2")
FASTA_PATH = os.path.join(FINETUNE_DATA_DIR,"chr20_21_22.fa")

if not os.path.exists(FASTA_PATH):
    !wget -P {FINETUNE_DATA_DIR} https://hgdownload.soe.ucsc.edu/goldenpath/hg38/chromosomes/chr20.fa.gz
    !wget -P {FINETUNE_DATA_DIR} https://hgdownload.soe.ucsc.edu/goldenpath/hg38/chromosomes/chr21.fa.gz
    !wget -P {FINETUNE_DATA_DIR} https://hgdownload.soe.ucsc.edu/goldenpath/hg38/chromosomes/chr22.fa.gz
    !zcat {FINETUNE_DATA_DIR}/chr20.fa.gz > {FINETUNE_DATA_DIR}/chr20.fa
    !zcat {FINETUNE_DATA_DIR}/chr21.fa.gz > {FINETUNE_DATA_DIR}/chr21.fa
    !zcat {FINETUNE_DATA_DIR}/chr22.fa.gz > {FINETUNE_DATA_DIR}/chr22.fa
    !cat {FINETUNE_DATA_DIR}/chr20.fa {FINETUNE_DATA_DIR}/chr21.fa {FINETUNE_DATA_DIR}/chr22.fa > {FASTA_PATH}

In [None]:
FASTA_DATA_PATH = FASTA_PATH ### change this line if you are using your own dataset
PREPROCESSED_DATA_DIR = os.path.join(FINETUNE_DATA_DIR, "preprocessed_data") ### folder path to store the preprocessed data

output_yaml = f"""
- datapaths: ["{FASTA_DATA_PATH}"]
  output_dir: "{PREPROCESSED_DATA_DIR}"
  output_prefix: chr20_21_22_uint8_distinct
  train_split: 0.9
  valid_split: 0.05
  test_split: 0.05
  overwrite: True
  embed_reverse_complement: true
  random_reverse_complement: 0.0
  random_lineage_dropout: 0.0
  include_sequence_id: false
  transcribe: "back_transcribe"
  force_uppercase: false
  indexed_dataset_dtype: "uint8"
  tokenizer_type: "Byte-Level"
  vocab_file: null
  vocab_size: null
  merges_file: null
  pretrained_tokenizer_model: null
  special_tokens: null
  fast_hf_tokenizer: true
  append_eod: true
  enforce_sample_length: null
  ftfy: false
  workers: 1
  preproc_concurrency: 100000
  chunksize: 25
  drop_empty_sequences: true
  nnn_filter: false  # If you split your fasta on NNN (in human these are contigs), then you should set this to true.
  seed: 12342  # Not relevant because we are not using random reverse complement or lineage dropout.
"""
with open("preprocess_config.yaml", "w") as f:
    print(output_yaml, file=f)

In [None]:
%%bash
preprocess_evo2 --config preprocess_config.yaml

In [None]:
!ls -lh {PREPROCESSED_DATA_DIR}

After preprocessing, there should be a collection of bin/idx files created in the preprocessed_data directory.
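
The `ls` output can also be checked programmatically. The sketch below assumes that each split is written as `<prefix>_<split>.bin` and `<prefix>_<split>.idx`, which is inferred from the dataset prefixes used in the training configuration of this tutorial rather than from the `preprocess_evo2` documentation. The helper `missing_split_files` is a hypothetical name.

```python
from pathlib import Path


def missing_split_files(data_dir, prefix, splits=("train", "val", "test")):
    """Return the names of expected .bin/.idx files that are absent from data_dir."""
    missing = []
    for split in splits:
        for ext in (".bin", ".idx"):
            # Assumed naming scheme: <prefix>_<split>.bin and <prefix>_<split>.idx
            candidate = Path(data_dir) / f"{prefix}_{split}{ext}"
            if not candidate.is_file():
                missing.append(candidate.name)
    return missing
```

For this tutorial the prefix would be `chr20_21_22_uint8_distinct_byte-level`, and an empty list means every expected file was found.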

##### Configure the Training Dataset
To configure your training dataset, we create a yaml file specifying the paths for the training data we downloaded and preprocessed above.

In [None]:
output_pfx = os.path.join(PREPROCESSED_DATA_DIR,"chr20_21_22_uint8_distinct_byte-level")
output_yaml = f"""
- dataset_prefix: {output_pfx}_train
  dataset_split: train
  dataset_weight: 1.0
- dataset_prefix: {output_pfx}_val
  dataset_split: validation
  dataset_weight: 1.0
- dataset_prefix: {output_pfx}_test
  dataset_split: test
  dataset_weight: 1.0
"""
with open("training_data_config.yaml", "w") as f:
    print(output_yaml, file=f)

#### Specify or convert initial checkpoint
The main difference between pre-training and fine-tuning is whether or not you decide to start training the model with
weights from a prior training run. For this tutorial we want to finetune a `1b` checkpoint from HuggingFace that is known
(at the time of this writing) to be sensitive to GPU architecture so that it will work with your architecture. 

The following step will use a BioNeMo Framework
script to download and convert a savanna format evo2 checkpoint from HuggingFace, and output that into a NeMo2
format checkpoint directory that can be used as the starting point for a fine-tuning run.

This conversion script can also be used for the 7b and 40b models by changing `MODEL_SIZE`.

In [None]:
MODEL_SIZE = "1b"
CHECKPOINT_PATH = Path(f"nemo2_evo2_{MODEL_SIZE}_8k")

if not CHECKPOINT_PATH.exists() or not any(CHECKPOINT_PATH.iterdir()):
    !evo2_convert_to_nemo2 \
      --model-path hf://arcinstitute/savanna_evo2_{MODEL_SIZE}_base \
      --model-size {MODEL_SIZE} --output-dir nemo2_evo2_{MODEL_SIZE}_8k
else:
    print("Checkpoint directory is not empty. Skipping command.")

CHECKPOINT_PATH=str(CHECKPOINT_PATH)

#### Setting up WANDB for experiment tracking
Now we are almost ready to start training our model. This example can use [Weights & Biases](https://wandb.ai/site) to manage our experiment tracking. The following steps set up your WANDB integration. This is not necessary to run the finetuning example and can be skipped. 

To use WANDB, set your WANDB_API_KEY in your environment variables or in the code below.
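
The logging flags passed to training depend only on whether an API key, project, and entity are set. As a sketch of that logic in isolation, the hypothetical helper `build_wandb_args` below returns the extra `--wandb-project` and `--wandb-entity` flags used in this tutorial, or an empty string when no key is available. It does not call the `wandb` library.

```python
def build_wandb_args(project, entity="", api_key=""):
    """Return extra training flags, or an empty string when no API key is set."""
    if not api_key:
        return ""
    # The leading space lets the result be appended directly to a command string
    args = f" --wandb-project {project}"
    if entity:
        args += f" --wandb-entity {entity}"
    return args


print(repr(build_wandb_args("evo2-finetuning")))
print(repr(build_wandb_args("evo2-finetuning", entity="my-team", api_key="dummy-key")))
```

Returning an empty string rather than leaving the variable undefined means the result can always be concatenated onto the training command.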

In [None]:
WANDB_API_KEY = os.environ.get('WANDB_API_KEY', '')
WANDB_PROJECT = os.environ.get('WANDB_PROJECT', 'evo2-finetuning')
WANDB_ENTITY = os.environ.get('WANDB_ENTITY', '')
try:
    import wandb
    if WANDB_API_KEY:
        wandb.login(key=WANDB_API_KEY)
        print("W&B login successful.")
        
        WANDB_ARGS = f" --wandb-project {WANDB_PROJECT}"
        WANDB_ARGS += f" --wandb-entity {WANDB_ENTITY}" if WANDB_ENTITY else ""
    else:
        print("WANDB_API_KEY not set; skipping W&B login.")
except Exception as e:
    print("W&B login skipped:", e)


#### Run Finetuning
Evo2 training and fine-tuning follow the same set of steps, so we use the same train_evo2 command.

The main difference is that the `--ckpt-dir` argument points to a pre-existing checkpoint from an earlier training run.

Note: If you have multiple GPUs with less
memory, or you are having trouble with CUDA OOM at the training step below, try reducing the `--micro-batch-size` and/or
increasing the number of `--devices [int]` to match your setup and also setting `--tensor-parallel-size [int]` to
the number of devices. This should split up most of the model evenly between your devices, which will require much less memory. 

When we train the 1b model, we typically have the micro batch size set to 8, and run without model parallelism on available devices to achieve the largest possible global batch size.
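
To see how these settings combine, the sketch below computes the global batch size under the usual Megatron-style accounting: micro batch size times gradient accumulation steps times data-parallel size. This is an assumption about how `train_evo2` interprets `--grad-acc-batches` and the parallelism flags, and `global_batch_size` is a hypothetical helper.

```python
def global_batch_size(micro_batch_size, grad_acc_batches, devices, num_nodes=1,
                      tensor_parallel=1, pipeline_parallel=1, context_parallel=1):
    """Sequences consumed per optimizer step under Megatron-style accounting."""
    model_parallel = tensor_parallel * pipeline_parallel * context_parallel
    world_size = devices * num_nodes
    if world_size % model_parallel != 0:
        raise ValueError("World size must be divisible by the model-parallel size.")
    # GPUs left over after model parallelism each process a different micro batch
    data_parallel = world_size // model_parallel
    return micro_batch_size * grad_acc_batches * data_parallel


SEQ_LENGTH = 8192

# Single-GPU settings used in the training cell of this notebook
demo = global_batch_size(micro_batch_size=2, grad_acc_batches=4, devices=1)
print(demo, demo * SEQ_LENGTH)  # 8 sequences, 65536 tokens per step

# 8 GPUs, micro batch size 8, no model parallelism
large = global_batch_size(micro_batch_size=8, grad_acc_batches=1, devices=8)
print(large, large * SEQ_LENGTH)  # 64 sequences, 524288 tokens per step
```

Under this accounting, setting `--tensor-parallel-size` equal to the number of devices reduces memory per GPU but also reduces the data-parallel size to 1, so each step sees fewer sequences.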

The training will need to be run for more than 100 steps on 8 GPUs to get training loss on the 1b checkpoint to the 1.08 range.

Modify the next cell to set the `MAX_STEPS` for training.

In [None]:
MAX_STEPS: int = 100
FINETUNE_RESULTS_DIR = os.path.join(RESULTS_DIR, "finetuning_demo_results")

VAL_CHECK_INTERVAL = min(int(MAX_STEPS // 2), 50)
WARMUP_STEPS = min(MAX_STEPS, 100)

In [None]:
train_cmd = f"""train_evo2 \
    -d training_data_config.yaml \
    --dataset-dir {PREPROCESSED_DATA_DIR} \
    --result-dir {FINETUNE_RESULTS_DIR} \
    --experiment-name evo2 \
    --model-size 1b \
    --devices 1 \
    --num-nodes 1 \
    --seq-length 8192 \
    --micro-batch-size 2 \
    --lr 0.000015 \
    --min-lr 0.0000149 \
    --warmup-steps {WARMUP_STEPS} \
    --grad-acc-batches 4 \
    --max-steps {MAX_STEPS} \
    --ckpt-dir {CHECKPOINT_PATH} \
    --clip-grad 250 \
    --wd 0.001 \
    --attention-dropout 0.01 \
    --hidden-dropout 0.01 \
    --val-check-interval {VAL_CHECK_INTERVAL} \
    --activation-checkpoint-recompute-num-layers 5 \
    --create-tensorboard-logger \
    --ckpt-async-save"""

try:
    train_cmd += WANDB_ARGS
except NameError:
    # WANDB_ARGS is only defined when the W&B login above succeeded
    print("Training without WANDB enabled")

print(f"Running command: {train_cmd}")

result = run_subprocess_safely(train_cmd)

#### Visualize EVO2 Training using Tensorboard
These visualizations can also be shown using WANDB and can be skipped if you are using WANDB for experiment tracking. The following function plots the tensorboard dataframe that was created during finetuning.

The generated figures will show various training metrics per step:
* `reduced_train_loss` captures the training loss. On larger runs you want to see the loss drop to about 1.08 consistently
  for the 1b checkpoint.
* `lr` shows the learning rate schedule for training. Typically we do a linear warmup schedule followed by a cosine decay.
  This small notebook tutorial just goes through the initial warmup period.
* `grad_norm` shows the gradient norm of the full model. As the model fits the data better you should see this value drop
  down below 1.0 consistently. 
* `val_loss` shows the same kind of loss shown in `reduced_train_loss` but for a held-out set of validation samples. If you
  ever train the model a very long time and see this start to go up while the training loss continues to drop that's a sign
  of over-fitting. We have not yet seen this happen. Small fluctuations up and down are expected during training.
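
To make the shape of the `lr` curve concrete, here is a minimal sketch of a linear warmup followed by cosine decay. It is a generic formulation and not necessarily the exact scheduler used by `train_evo2`. The function `warmup_cosine_lr` is hypothetical, and the default values mirror the `--lr`, `--min-lr`, and `--warmup-steps` settings used in this notebook.

```python
import math


def warmup_cosine_lr(step, max_lr=1.5e-5, min_lr=1.49e-5, warmup_steps=100, max_steps=1000):
    """Learning rate at a 0-indexed step: linear warmup, then cosine decay to min_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp that reaches max_lr on the last warmup step
        return max_lr * (step + 1) / warmup_steps
    decay_steps = max(1, max_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 49, 99, 550, 1000):
    print(step, f"{warmup_cosine_lr(step):.3e}")
```

With `MAX_STEPS = 100` and 100 warmup steps, every step of the demo run falls in the linear ramp, which is why the plotted `lr` only increases.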

In [None]:
# Function to extract data from TensorBoard event files and convert to DataFrame
def tensorboard_to_dataframe(event_file):
    """Given a TensorBoard event file, return a pandas DataFrame with the training metrics."""
    # Load the event file
    ea = event_accumulator.EventAccumulator(
        event_file,
        size_guidance={
            event_accumulator.SCALARS: 0,  # 0 means load all
        },
    )
    ea.Reload()

    # Get list of all available tags
    tags = ea.Tags()["scalars"]

    # First, find the union of all steps
    all_steps = set()
    for tag in tags:
        events = ea.Scalars(tag)
        steps = [event.step for event in events]
        all_steps.update(steps)

    # Sort steps for proper ordering
    all_steps = sorted(all_steps)

    # Initialize the dataframe with steps
    df = pd.DataFrame({"step": all_steps})

    # Add each metric as a column
    for tag in tags:
        events = ea.Scalars(tag)
        # Create a dictionary mapping steps to values
        step_to_value = {event.step: event.value for event in events}
        # Add the values to the dataframe, using NaN for missing steps
        df[tag] = df["step"].map(step_to_value)

    return df


# Example of creating a multi-metric plot with seaborn
def plot_multiple_training_metrics(df, metrics_to_plot, figsize=(15, 10)):
    """Given a pandas DataFrame with the training metrics, plot the metrics."""
    n = len(metrics_to_plot)
    fig, axes = plt.subplots(n, 1, figsize=figsize, sharex=True)

    if n == 1:  # Handle the case of a single plot
        axes = [axes]

    sns.set_style("whitegrid")

    for i, metric in enumerate(metrics_to_plot):
        if metric in df.columns:
            sns.lineplot(x="step", y=metric, data=df, ax=axes[i], linewidth=2.5, errorbar="sd")
            axes[i].set_title(metric, fontsize=14)
            axes[i].set_ylabel("Value", fontsize=12)
    axes[-1].set_xlabel("Steps", fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# Get the TensorBoard event file for the training run
log_dirs = !find {FINETUNE_RESULTS_DIR}/evo2/dev -name "events.out.tfevents*"
tf_event_file = log_dirs[0]

# Extract data from your event file
df = tensorboard_to_dataframe(tf_event_file)
# Plot the metrics of interest; modify this list once you see which tags are available
plot_multiple_training_metrics(df, ["reduced_train_loss", "lr", "grad_norm", "val_loss"])

#### Full Training Example
On a small number of devices, or with the small demo fasta we provided in this tutorial, it's possible that training does not reach the
1.08 loss level needed for good downstream accuracy from this checkpoint. You can try increasing the `MAX_STEPS` parameter in the training cell,
or running on a larger cluster with more GPUs.

The following bash command can be used in the Lepton Batch Jobs feature to run multinode distributed training and reach the needed 1.08 `reduced_train_loss`. This may mean you need to finetune on the entire small demo fasta that is provided.

Run the following bash script with the built-in PyTorch template, the `nvcr.io/nvidia/clara/bionemo-framework:2.7` image, and your desired number of nodes, with your storage mounted for the results directory.

Make sure to set environment variables to the appropriate paths for your datasets and results so they are persistently stored. 

In [None]:
# Printing the necessary environment variables
print("DATASET_CONFIG_PATH=training_data_config.yaml")
print(f"PREPROCESSED_DATA_DIR={PREPROCESSED_DATA_DIR}")

In [None]:
%%bash
### Copy this into a batch job to run multinode distributed training

torchrun \
/workspace/bionemo2/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py \
-d {DATASET_CONFIG_PATH} \
--dataset-dir {PREPROCESSED_DATA_DIR} \
--result-dir {FINETUNE_RESULTS_DIR} \
--experiment-name evo2 \
--model-size 1b \
--devices {PET_NPROC_PER_NODE} \
--num-nodes {PET_NNODES} \
--seq-length 8192 \
--micro-batch-size 2 \
--lr 0.000015 \
--min-lr 0.0000149 \
--warmup-steps 100 \
--grad-acc-batches 1 \
--max-steps 15000 \
--ckpt-dir nemo2_evo2_1b_8k \
--clip-grad 250 \
--wd 0.001 \
--attention-dropout 0.01 \
--hidden-dropout 0.01 \
--val-check-interval 100 \
--create-tensorboard-logger \
--ckpt-async-save \
--tensor-parallel-size=1 \
--context-parallel-size=1 \
--pipeline-model-parallel-size=1 \
--wandb-project={WANDB_PROJECT} \
--wandb-entity={WANDB_ENTITY}


Once this is done, set the `final_ckpt_path` variable to the path of your selected checkpoint for use in prediction. This can be done using the cell below. 
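
If you prefer plain Python over a shell call, the sketch below selects the most recently modified checkpoint directory whose name ends in `-last`. The `-last` suffix is taken from the cell in this tutorial, and the helper `latest_checkpoint` is a hypothetical name.

```python
from pathlib import Path


def latest_checkpoint(checkpoints_dir, suffix="-last"):
    """Return the newest directory in checkpoints_dir whose name ends with suffix."""
    candidates = [p for p in Path(checkpoints_dir).glob(f"*{suffix}") if p.is_dir()]
    if not candidates:
        raise FileNotFoundError(f"No '*{suffix}' checkpoint found in {checkpoints_dir}")
    # Pick by modification time so a resumed run's newest checkpoint wins
    return str(max(candidates, key=lambda p: p.stat().st_mtime))
```

Raising `FileNotFoundError` makes a missing checkpoint visible immediately, instead of surfacing later as a failed prediction command.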

In [None]:
final_ckpt_paths = !ls -d {FINETUNE_RESULTS_DIR}/evo2/checkpoints/*-last
final_ckpt_path = final_ckpt_paths[-1]
final_ckpt_path

## Zero-shot prediction of BRCA1 variant effects with fine-tuned Evo 2

Now that we have fine-tuned our Evo2 model, let's demonstrate its capabilities by performing zero-shot prediction of BRCA1 variant effects. This section reproduces the analysis from The Arc Institute's BRCA1 tutorial, but using our fine-tuned checkpoint.

*Note - this section is based on The Arc Institute's notebook [here](https://github.com/ArcInstitute/evo2/blob/main/notebooks/brca1/brca1_zero_shot_vep.ipynb), adapted to use our fine-tuned BioNeMo 2 implementation of Evo2.*

Evo2 can predict variant effects without prior task-specific training. The human *BRCA1* gene encodes a protein that repairs damaged DNA ([Moynahan et al., 1999](https://www.cell.com/molecular-cell/fulltext/S1097-2765%2800%2980202-6)). Certain variants of this gene have been associated with an increased risk of breast and ovarian cancers ([Miki et al., 1994](https://www.science.org/doi/10.1126/science.7545954?url_ver=Z39.88-2003&rfr_id=ori:rid:crossref.org&rfr_dat=cr_pub%20%200pubmed)).

Using our fine-tuned Evo 2, we can predict whether a particular single nucleotide variant (SNV) of the *BRCA1* gene is likely to be harmful to the protein's function, and thus potentially increase the risk of cancer for the patient with the genetic variant.


In [None]:
%pip install biopython openpyxl

import glob
import gzip
import json
import math

import torch
from Bio import SeqIO
from sklearn.metrics import auc, roc_auc_score, roc_curve


#### Load BRCA1 Dataset and Reference Genome

We start by loading a dataset from [Findlay et al. (2018)](https://www.nature.com/articles/s41586-018-0461-z), which contains experimentally measured function scores of 3,893 *BRCA1* SNVs. These function scores reflect the extent by which the genetic variant has disrupted the protein's function, with lower scores indicating greater disruption. In this dataset, the SNVs are classified into three categories based on their function scores: `LOF` (loss-of-function), `INT` (intermediate), and `FUNC` (functional).

The functions below download the Excel and fasta files from the Arc Institute Evo 2 repository, load the variant table into a Pandas DataFrame, and load the chromosome 17 reference sequence as a string.

In [None]:
def download_data(data_dir="brca1", commit_hash="3819474bee6c24938016614411f1fa025e542bbe"):
    """Download required data files if they don't exist locally.

    Parameters:
    -----------
    data_dir : str
        Directory to store downloaded files
    commit_hash : str
        GitHub commit hash for data version
    """
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    excel_path = os.path.join(data_dir, "41586_2018_461_MOESM3_ESM.xlsx")
    genome_path = os.path.join(data_dir, "GRCh37.p13_chr17.fna.gz")

    if not os.path.exists(excel_path):
        os.system(
            f"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/41586_2018_461_MOESM3_ESM.xlsx -O {excel_path}"
        )

    if not os.path.exists(genome_path):
        os.system(
            f"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/GRCh37.p13_chr17.fna.gz -O {genome_path}"
        )

    return excel_path, genome_path


def load_genome_sequence(genome_path):
    """Load genome sequence from FASTA file.

    Parameters:
    -----------
    genome_path : str
        Path to the genome FASTA file

    Returns:
    --------
    str
        Genome sequence string
    """
    with gzip.open(genome_path, "rt") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            return str(record.seq)

    raise ValueError("Failed to parse genome sequence")


def load_brca1_data(excel_path):
    """Load and preprocess BRCA1 data from Excel file.

    Parameters:
    -----------
    excel_path : str
        Path to the Excel file

    Returns:
    --------
    pandas.DataFrame
        Processed BRCA1 dataframe
    """
    # Load the dataframe
    brca1_df = pd.read_excel(excel_path, header=2)

    # Select and rename columns
    brca1_df = brca1_df[
        [
            "chromosome",
            "position (hg19)",
            "reference",
            "alt",
            "function.score.mean",
            "func.class",
        ]
    ]

    brca1_df.rename(
        columns={
            "chromosome": "chrom",
            "position (hg19)": "pos",
            "reference": "ref",
            "alt": "alt",
            "function.score.mean": "score",
            "func.class": "class",
        },
        inplace=True,
    )

    # Convert to two-class system
    brca1_df["class"] = brca1_df["class"].replace(["FUNC", "INT"], "FUNC/INT")

    return brca1_df

To make things run faster, we'll just look at a balanced sample of our data. If you want to run on the full dataset, set `"disable": True` in the `SAMPLE_CONFIG`.

In [None]:
# Configuration parameters
BRACA1_DATA_DIR = os.path.join(DATA_DIR, "brca1")
SAMPLE_CONFIG = {"sample_frac": 0.05, "balanced": True, "disable": False, "random_state": 42}

# 1. Download the necessary data files if not present
excel_path, genome_path = download_data(BRACA1_DATA_DIR)
seq_chr17 = load_genome_sequence(genome_path)

# 2. Load and preprocess BRCA1 data
brca1_df = load_brca1_data(excel_path)


### Data Sampling and Sequence Preparation

We group the `FUNC` and `INT` classes of SNVs together into a single category (`FUNC/INT`). To make things run faster, we'll look at a balanced sample of our data. We then build functions to parse the reference and variant sequences of an 8,192-bp window around the genomic position of each SNV, using the reference sequence of human chromosome 17 where *BRCA1* is located.
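
The windowing arithmetic is easier to follow on a toy sequence. The sketch below applies the same idea with an 8-bp window instead of 8,192 bp: take half a window on each side of the 1-indexed SNV position, then substitute the alternate base at the matching offset. The function `snv_window` and the toy sequence are hypothetical and used only for illustration.

```python
def snv_window(sequence, pos, alt, window_size=8):
    """Return (ref_window, var_window) around a 1-indexed SNV position."""
    p = pos - 1  # convert to a 0-indexed position
    start = max(0, p - window_size // 2)
    end = min(len(sequence), p + window_size // 2)
    ref_window = sequence[start:end]
    offset = p - start  # index of the SNV inside the window
    var_window = ref_window[:offset] + alt + ref_window[offset + 1:]
    return ref_window, var_window


toy_sequence = "ACGTACGTACGTACGT"
ref_window, var_window = snv_window(toy_sequence, pos=9, alt="T")
print(ref_window)  # ACGTACGT
print(var_window)  # ACGTTCGT
```

The reference and variant windows differ at exactly one position, which is what lets the difference in their likelihoods be attributed to the SNV.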

In [None]:
def sample_data(df, sample_frac=1.0, balanced=True, disable=False, random_state=42):
    """Sample dataframe, optionally with balanced classes.

    Parameters:
    -----------
    df : pandas.DataFrame
        Input dataframe
    sample_frac : float
        Fraction of data to sample
    balanced : bool
        Whether to balance classes
    disable : bool
        Whether to disable sampling
    random_state : int
        Random seed for reproducibility

    Returns:
    --------
    pandas.DataFrame
        Sampled dataframe
    """
    if disable:
        return df

    if balanced:
        # Number of rows to draw from each class, based on the size of the minority (LOF) class
        num_rows_minor_class = math.ceil(len(df[df["class"] == "LOF"]) * sample_frac)
        return (
            pd.concat(
                [
                    df[df["class"] == "LOF"].sample(n=num_rows_minor_class, random_state=random_state),
                    df[df["class"] == "FUNC/INT"].sample(n=num_rows_minor_class, random_state=random_state),
                ]
            )
            .sample(frac=1.0, random_state=random_state)
            .reset_index(drop=True)
        )
    else:
        # Calculate the number of rows to sample
        return df.sample(frac=sample_frac, random_state=random_state).reset_index(drop=True)


In [None]:
FASTA_OUTPUT_DIR = os.path.join(RESULTS_DIR, "brca1_fasta_files")

brca1_df = sample_data(
    brca1_df,
    sample_frac=SAMPLE_CONFIG["sample_frac"],
    balanced=SAMPLE_CONFIG["balanced"],
    disable=SAMPLE_CONFIG["disable"],
    random_state=SAMPLE_CONFIG["random_state"],
)

print(brca1_df.head(5))

Next, we'll write these to local .fasta files so we can use them for prediction below.

In [None]:
def parse_sequences(pos, ref, alt, seq_chr17, window_size=8192):
    """Parse reference and variant sequences from the reference genome sequence.

    Parameters:
    -----------
    pos : int
        Position (1-indexed)
    ref : str
        Reference base
    alt : str
        Alternate base
    seq_chr17 : str
        Full chromosome 17 sequence
    window_size : int
        Size of the sequence window to extract

    Returns:
    --------
    tuple
        (reference_sequence, variant_sequence)
    """
    p = pos - 1  # Convert to 0-indexed position
    full_seq = seq_chr17

    ref_seq_start = max(0, p - window_size // 2)
    ref_seq_end = min(len(full_seq), p + window_size // 2)
    ref_seq = seq_chr17[ref_seq_start:ref_seq_end]
    snv_pos_in_ref = min(window_size // 2, p)
    var_seq = ref_seq[:snv_pos_in_ref] + alt + ref_seq[snv_pos_in_ref + 1 :]

    # Sanity checks
    assert len(var_seq) == len(ref_seq)
    assert ref_seq[snv_pos_in_ref] == ref
    assert var_seq[snv_pos_in_ref] == alt

    return ref_seq, var_seq


def generate_fasta_files(df, seq_chr17, output_dir="brca1_fasta_files", window_size=8192):
    """Generate FASTA files for reference and variant sequences.

    Parameters:
    -----------
    df : pandas.DataFrame
        Dataframe with variant information
    seq_chr17 : str
        Chromosome 17 sequence
    output_dir : str
        Output directory for FASTA files
    window_size : int
        Size of sequence window

    Returns:
    --------
    pandas.DataFrame
        Dataframe with added columns for FASTA names
    """
    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Paths for output files
    ref_fasta_path = output_dir / "brca1_reference_sequences.fasta"
    var_fasta_path = output_dir / "brca1_variant_sequences.fasta"

    # Track unique sequences
    ref_sequences = set()
    var_sequences = set()
    ref_seq_to_name = {}

    # Store unique sequences with metadata for writing
    ref_entries = []
    var_entries = []
    ref_names = []
    var_names = []

    # Collect unique reference and variant sequences
    for idx, row in df.iterrows():
        ref_seq, var_seq = parse_sequences(row["pos"], row["ref"], row["alt"], seq_chr17, window_size)

        # Add to sets to ensure uniqueness
        if ref_seq not in ref_sequences:
            ref_sequences.add(ref_seq)
            ref_name = f"BRCA1_ref_pos_{row['pos']}_{row['ref']}_class_{row['class']}"

            ref_entries.append(f">{ref_name}\n{ref_seq}\n")
            ref_names.append(ref_name)
            ref_seq_to_name[ref_seq] = ref_name
        else:
            ref_name = ref_seq_to_name[ref_seq]
            ref_names.append(ref_name)

        if var_seq not in var_sequences:
            var_sequences.add(var_seq)
            var_name = f"BRCA1_var_pos_{row['pos']}_{row['ref']}to{row['alt']}_class_{row['class']}"

            var_entries.append(f">{var_name}\n{var_seq}\n")
            var_names.append(var_name)
        else:
            assert False, "Duplicate variant sequence"

    # Write unique sequences to FASTA files
    with open(ref_fasta_path, "w") as f:
        f.writelines(ref_entries)

    with open(var_fasta_path, "w") as f:
        f.writelines(var_entries)

    # Add FASTA names to dataframe
    df_with_names = df.copy()
    df_with_names["ref_fasta_name"] = ref_names
    df_with_names["var_fasta_name"] = var_names

    print(f"Total unique reference sequences: {len(ref_sequences)}")
    print(f"Total unique variant sequences: {len(var_sequences)}")

    return df_with_names


In [None]:
# Generate FASTA files for reference and variant sequences
brca1_df = generate_fasta_files(brca1_df, seq_chr17, output_dir=FASTA_OUTPUT_DIR)

#### Specifying Checkpoint Path
If the desired 1.08 training loss has not been reached, this example can also be run using the existing Evo2 weights. It can also be run using the 7b model, which will have better performance and works on GPUs that do not support FP8.

The following section specifies the checkpoint path we will use. When `USE_FINETUNED_CHECKPOINT` is `False`, it downloads and converts the base model for the selected `MODEL_SIZE` if it does not already exist. Modify the following cell as desired to specify the correct checkpoint path.


In [None]:
USE_FINETUNED_CHECKPOINT = True
MODEL_SIZE = "1b"

In [None]:
# Use our fine-tuned checkpoint instead of the original
if USE_FINETUNED_CHECKPOINT:
    checkpoint_path = final_ckpt_path
else:
    checkpoint_path = Path(f"nemo2_evo2_{MODEL_SIZE}_8k")
    if not checkpoint_path.exists() or not any(checkpoint_path.iterdir()):
        !evo2_convert_to_nemo2 --model-path hf://arcinstitute/savanna_evo2_{MODEL_SIZE}_base --model-size {MODEL_SIZE} --output-dir nemo2_evo2_{MODEL_SIZE}_8k
    else:
        print(f"Using existing checkpoint at {checkpoint_path}")

    checkpoint_path=str(checkpoint_path)

#### Score Sequences with Fine-tuned Evo 2

Now we'll score the likelihoods of the reference and variant sequences of each SNV using our fine-tuned Evo 2 model. This demonstrates the practical application of our fine-tuned checkpoint for downstream biological tasks.

In [None]:
USE_FP8=False  ### Set to True if using a GPU that supports FP8

In [None]:
# Define output directories for prediction results
output_dir = Path(FASTA_OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)

# Save reference and variant sequences to FASTA
ref_fasta_path = output_dir / "brca1_reference_sequences.fasta"
var_fasta_path = output_dir / "brca1_variant_sequences.fasta"

predict_ref_dir = output_dir / "reference_predictions"
predict_var_dir = output_dir / "variant_predictions"
predict_ref_dir.mkdir(parents=True, exist_ok=True)
predict_var_dir.mkdir(parents=True, exist_ok=True)

# Update predict commands to use our fine-tuned checkpoint
# MODEL_SIZE must match the checkpoint being loaded ("1b" for the fine-tuned checkpoint)
predict_ref_command = (
    f"predict_evo2 --fasta {ref_fasta_path} --ckpt-dir {checkpoint_path} "
    f"--output-dir {predict_ref_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 "
    f"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs "
)

predict_var_command = (
    f"predict_evo2 --fasta {var_fasta_path} --ckpt-dir {checkpoint_path} "
    f"--output-dir {predict_var_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 "
    f"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs "
)
if USE_FP8:
    predict_ref_command += "--fp8"
    predict_var_command += "--fp8"

print(f"Using fine-tuned checkpoint: {checkpoint_path}")
print(f"Reference prediction command: {predict_ref_command}")
print(f"Variant prediction command: {predict_var_command}")


Score reference sequences:


In [None]:
%%capture
print(f"Running command: {predict_ref_command}")
result = run_subprocess_safely(predict_ref_command)

In [None]:
assert result["returncode"] == 0, result

Score variant sequences:


In [None]:
%%capture

print(f"Running command: {predict_var_command}")
result = run_subprocess_safely(predict_var_command)

In [None]:
assert result["returncode"] == 0, result

### Calculate Delta Scores and Evaluate Performance

We calculate the change in likelihoods for each variant relative to the likelihood of their respective wild-type sequence. This delta likelihood should be predictive of how disruptive the SNV is to the protein's function: the lower the delta, the more likely that the SNV is disruptive.
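
As a worked illustration of how the delta score turns into a classification metric, the sketch below computes deltas for four made-up variants and scores them with a dependency-free AUROC. The log-likelihood values are hypothetical, and the helper `auroc` is a simple pairwise implementation used here in place of a library routine such as scikit-learn's `roc_auc_score`.

```python
def auroc(scores, labels):
    """AUROC as the probability that a positive outscores a negative (ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("Both classes must be present.")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical log-likelihoods for four variants and their reference windows
ref_log_probs = [-1.00, -1.00, -1.00, -1.00]
var_log_probs = [-1.30, -1.02, -1.25, -1.01]
is_lof = [1, 0, 1, 0]  # 1 = loss-of-function, 0 = FUNC/INT

deltas = [var - ref for var, ref in zip(var_log_probs, ref_log_probs)]
# Lower delta means more disruptive, so negate it to score the LOF class
lof_scores = [-d for d in deltas]
print(auroc(lof_scores, is_lof))  # 1.0
```

An AUROC near 0.5 would indicate that the deltas carry no information about the class, which is the near-random behavior described in the Motivation section.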


In [None]:
# Find and load prediction files
ref_pred_files = glob.glob(os.path.join(predict_ref_dir, "predictions__rank_*.pt"))
var_pred_files = glob.glob(os.path.join(predict_var_dir, "predictions__rank_*.pt"))

# Load sequence ID maps (maps sequence ID -> prediction index)
with open(os.path.join(predict_ref_dir, "seq_idx_map.json"), "r") as f:
    ref_seq_idx_map = json.load(f)
with open(os.path.join(predict_var_dir, "seq_idx_map.json"), "r") as f:
    var_seq_idx_map = json.load(f)

# Load predictions
ref_preds = torch.load(ref_pred_files[0])
var_preds = torch.load(var_pred_files[0])


Calculate the delta score:

In [None]:
# Calculate change in likelihoods
ref_log_probs = []
var_log_probs = []
for _, row in brca1_df.iterrows():
    ref_name = row["ref_fasta_name"]
    var_name = row["var_fasta_name"]
    ref_log_probs.append(ref_preds["log_probs_seqs"][ref_seq_idx_map[ref_name]].item())
    var_log_probs.append(var_preds["log_probs_seqs"][var_seq_idx_map[var_name]].item())

brca1_df["ref_log_probs"] = ref_log_probs
brca1_df["var_log_probs"] = var_log_probs
# A disruptive variant should be less likely than its reference, so (variant - reference) is expected to be negative.
brca1_df["evo2_delta_score"] = brca1_df["var_log_probs"] - brca1_df["ref_log_probs"]
brca1_df.head()


In [None]:
def plot_strip_with_means(df, x_col="evo2_delta_score", class_col="class"):
    """Creates a strip plot with jittered points and median indicators for each class using Seaborn.

    Parameters:
    - df (pd.DataFrame): The input DataFrame containing data.
    - x_col (str): The column name representing the x-axis values (e.g., evo2_delta_score).
    - class_col (str): The column name representing the class labels.

    Returns:
    - None. The strip plot with median indicators is drawn on the current matplotlib figure.
    """
    # NVIDIA theme colors
    NVIDIA_GREEN = "#76B900"
    BACKGROUND_COLOR = "#F8F8F8"
    GRID_COLOR = "#DDDDDD"
    FONT_COLOR = "#333333"

    # Determine order of classes (if not already specified)
    unique_classes = sorted(df[class_col].unique())

    # Set up the plot with NVIDIA theme
    plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)
    plt.style.use("default")  # Reset to default to avoid any pre-existing style

    # Create strip plot
    p = sns.stripplot(
        data=df,
        x=x_col,
        y=class_col,
        hue=class_col,
        order=unique_classes,
        palette=[NVIDIA_GREEN, "red"],
        size=6,
        jitter=0.3,
        alpha=0.6,
    )

    # Add median indicators using boxplot
    sns.boxplot(
        showmeans=True,
        meanline=True,
        meanprops={"visible": False},
        medianprops={"color": "black", "ls": "-", "lw": 2},
        whiskerprops={"visible": False},
        zorder=10,
        x=x_col,
        y=class_col,
        data=df,
        order=unique_classes,
        showfliers=False,
        showbox=False,
        showcaps=False,
        ax=p,
    )

    # Customize plot appearance
    plt.title(
        "Distribution of Delta Likelihood Scores\nComparing Fine-tuned Evo 2 likelihood scores for different BRCA1 SNV classes",
        color=FONT_COLOR,
        fontsize=12,
        loc="left",
    )
    plt.xlabel("Delta Likelihood Score, Fine-tuned Evo 2", color=FONT_COLOR)
    plt.ylabel("BRCA1 SNV Class", color=FONT_COLOR)

    # Customize grid and tick colors
    plt.grid(color=GRID_COLOR, axis="x", linestyle="--", linewidth=0.5)
    plt.tick_params(colors=FONT_COLOR)

    # Set background color
    plt.gca().set_facecolor(BACKGROUND_COLOR)
    plt.gcf().set_facecolor(BACKGROUND_COLOR)

    plt.tight_layout()

plot_strip_with_means(brca1_df, x_col="evo2_delta_score", class_col="class")


We can also calculate the area under the receiver operating characteristic curve (AUROC) of this zero-shot prediction method. Note that the results are close to random unless you use one of the following configurations:

* `--fp8` on an FP8-enabled GPU with either the 1b or 7b model. The 40b model likely works as well.
* The 7b model without `--fp8`. It appears to be the only model size that works well without FP8, so on an older device the 7b model should produce robust results. In that case, change `MODEL_SIZE` earlier in this tutorial and rerun.
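
Whether `--fp8` is an option depends on the GPU. As a rough sketch, and assuming that FP8 requires a CUDA compute capability of at least 8.9 (Ada Lovelace) or 9.0 (Hopper), a check could look like the following. The helpers `supports_fp8` and `describe` are hypothetical names and are not part of BioNeMo; the container itself may apply stricter requirements.

```python
def supports_fp8(major, minor):
    """Return True if a CUDA compute capability is at least 8.9.

    Assumption: FP8 needs Ada Lovelace (8.9), Hopper (9.0), or newer.
    """
    return (major, minor) >= (8, 9)


def describe(major, minor):
    """Suggest which of the configurations listed above applies."""
    if supports_fp8(major, minor):
        return "FP8 capable: --fp8 can be used"
    return "No FP8: use a checkpoint that works without --fp8"


# Example capabilities: A100 reports 8.0, H100 reports 9.0.
print(describe(8, 0))
print(describe(9, 0))
```

Inside the notebook, `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, so `supports_fp8(*torch.cuda.get_device_capability())` would apply the same check to the current device.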

In [None]:
# Calculate AUROC of zero-shot predictions
# LOF (loss of function) is the positive class. LOF variants are expected to have
# more negative delta scores, so the score is negated before computing the AUROC.
y_true = brca1_df["class"] == "LOF"
auroc = roc_auc_score(y_true, -brca1_df["evo2_delta_score"])
print(f"Zero-shot prediction AUROC with fine-tuned model: {auroc:.3f}")
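
To make the sign convention in the AUROC calculation concrete, here is a minimal sketch on made-up data. It relies on the standard interpretation of AUROC as the fraction of (positive, negative) pairs in which the positive example receives the higher score. The function `pairwise_auroc` and the toy values are assumptions for illustration; the notebook itself uses `roc_auc_score`.

```python
def pairwise_auroc(labels, scores):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half. The cost is O(n_pos * n_neg), so this is only
    meant for small illustrative inputs.
    """
    pos = [s for y, s in zip(labels, scores) if y]
    neg = [s for y, s in zip(labels, scores) if not y]
    if not pos or not neg:
        raise ValueError("Need at least one positive and one negative example.")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Made-up delta scores: LOF variants tend to be more negative.
is_lof = [True, True, False, False, False]
delta = [-0.60, -0.05, -0.10, 0.02, 0.05]

# Negate so that larger values indicate the positive (LOF) class.
auroc_negated = pairwise_auroc(is_lof, [-d for d in delta])
# Without negation the ranking is inverted, giving 1 minus the value above.
auroc_raw = pairwise_auroc(is_lof, delta)
print(auroc_negated, auroc_raw)
```

Forgetting the negation does not lose information, but it flips the score around 0.5, which can make a good predictor look worse than random.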

In [None]:
def plot_roc_curve(df):
    """Plots an ROC curve using Seaborn with a light NVIDIA-themed design.

    The function assumes:
    - `class` column as the true labels (binary, 'LOF' = 1, else 0).
    - `evo2_delta_score` as the prediction score.

    Parameters:
    - df (pd.DataFrame): DataFrame containing `class` and `evo2_delta_score`.

    Returns:
    - None. The ROC curve is drawn on the current matplotlib figure.
    """
    # NVIDIA theme colors
    NVIDIA_GREEN = "#76B900"
    BACKGROUND_COLOR = "#F8F8F8"
    GRID_COLOR = "#DDDDDD"
    FONT_COLOR = "#333333"

    # Validate required columns
    if "class" not in df.columns or "evo2_delta_score" not in df.columns:
        raise ValueError("DataFrame must contain 'class' and 'evo2_delta_score' columns.")

    # Convert 'class' to binary labels: Assume 'LOF' = 1, anything else = 0
    y_true = (df["class"] == "LOF").astype(int)

    # Compute ROC curve
    fpr, tpr, _ = roc_curve(y_true, -df["evo2_delta_score"])  # Negate so that higher scores indicate LOF
    roc_auc = auc(fpr, tpr)

    # Set up the plot with NVIDIA theme
    plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)
    plt.style.use("default")  # Reset to default to avoid any pre-existing style

    # Plot ROC curve
    plt.plot(fpr, tpr, color=NVIDIA_GREEN, lw=3, label=f"ROC curve (AUROC = {roc_auc:.3f})")

    # Plot diagonal reference line for random guessing
    plt.plot([0, 1], [0, 1], color="gray", lw=2, linestyle="--")

    # Customize plot appearance
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate", color=FONT_COLOR, fontsize=12)
    plt.ylabel("True Positive Rate", color=FONT_COLOR, fontsize=12)
    plt.title(
        "Fine-tuned Model ROC Curve\nEvaluating the discriminative performance of fine-tuned Evo 2 predictions",
        color=FONT_COLOR,
        fontsize=16,
        loc="left",
    )

    # Customize grid and tick colors
    plt.grid(color=GRID_COLOR, linestyle="--", linewidth=0.5)
    plt.tick_params(colors=FONT_COLOR)

    # Set background color
    plt.gca().set_facecolor(BACKGROUND_COLOR)

    # Add legend
    plt.legend(loc="lower right", frameon=True, facecolor=BACKGROUND_COLOR, edgecolor=GRID_COLOR)

plot_roc_curve(brca1_df)

### Summary of Performance with Full Sample

The above analysis may have been performed on a subset of the available data.

For comparison, the table below presents the AUROC scores for different model sizes evaluated on the full dataset (100% sample fraction). Your fine-tuned BF16 model should reach a similar AUROC once its training loss reaches 1.08.

| Model Type | Dataset Sample | AUROC |
|------------|---------------|--------|
| Original Evo2 1B | 100% | 0.74 |
| Original Evo2 7B | 100% | 0.87 |
