# HyenaDNA Binary Classifier for DNA Breakpoint Detection

This notebook implements a complete pipeline to fine-tune HyenaDNA for classifying DNA sequences as positive or negative breakpoints, similar to FusionAI.

## Overview

**What this notebook does:**
1. **Prepares data** from FusionAI-format CSV files (positive/negative breakpoints)
2. **Merges sequences** (5' + 3' = 20kb input)
3. **Fine-tunes HyenaDNA** with a binary classification head
4. **Makes predictions** on test data

**Key Features:**
- Handles sequences up to 1M bp (we'll use 20kb)
- Uses Hugging Face Trainer for optimization
- Automatic mixed precision (FP16) for faster training
- Comprehensive metrics (Accuracy, Precision, Recall, F1, AUC)

---

## üì¶ Installation & Setup

First, let's install all required dependencies and check our compute environment.

In [2]:
# Install required packages
# !pip install -q torch transformers accelerate pandas scikit-learn

print("‚úÖ Installation complete!")

‚úÖ Installation complete!


In [1]:
# Check GPU availability and setup
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("‚ö†Ô∏è WARNING: No GPU detected! Training will be very slow.")
    print("   Consider using a GPU runtime (Runtime ‚Üí Change runtime type ‚Üí GPU)")

PyTorch version: 2.9.1+cu128
CUDA available: True
GPU Device: NVIDIA A100 80GB PCIe
GPU Memory: 84.93 GB




## üìö Import Libraries

Import all necessary libraries for data processing, model training, and evaluation.

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer
from safetensors.torch import load_model, save_model
from typing import Dict, List, Tuple
import os
from tqdm.auto import tqdm
import warnings

warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
print("‚úÖ Libraries imported successfully!")

  from .autonotebook import tqdm as notebook_tqdm


‚úÖ Libraries imported successfully!


## ‚öôÔ∏è Configuration

Set all hyperparameters and file paths here. **Update the CSV paths to match your data!**

### Model Selection Guide:
- `hyenadna-tiny-1k-seqlen` - Max 1kb (‚ùå Too short for 20kb!)
- `hyenadna-small-32k-seqlen` - Max 32kb (‚úÖ **Recommended for 20kb**)
- `hyenadna-medium-160k-seqlen` - Max 160kb (Better performance, more memory)
- `hyenadna-large-1m-seqlen` - Max 1M (Best performance, requires A100 80GB)

In [3]:
# ============ FILE PATHS - UPDATE THESE! ============
POSITIVE_CSV = "datasets/fusion_gene_positive_bp_information_with_class_for_modeling.txt"  # Path to positive breakpoint CSV
NEGATIVE_CSV = "datasets/fusion_gene_negative_bp_information_with_class_for_modeling.txt"  # Path to negative breakpoint CSV
TEST_CSV = "datasets/fusion_gene_positive_bp_information_with_class_for_testing.txt"          # Path to test data CSV
COLUMNS = ["Hgene","Hchr","Hbp","Hstrand","Tgene","Tchr","Tbp","Tstrand","5'-gene sequence (10Kb)","3'-gene sequence (10Kb)"] # CSV Columns names
OUTPUT_DIR = "./hyenadna_breakpoint_model_160K" # Where to save the model

# ============ MODEL CONFIGURATION ============
MODEL_NAME = "LongSafari/hyenadna-medium-160k-seqlen-hf"  # Recommended for 20kb sequences
MAX_LENGTH = 20480  # Maximum sequence length (20kb)

# ============ TRAINING HYPERPARAMETERS ============
NUM_EPOCHS = 3
BATCH_SIZE = 8          # Adjust based on GPU memory (2 for T4, 4 for V100, 8 for A100)
LEARNING_RATE = 2e-5    # Learning rate
WARMUP_STEPS = 1000      # Learning rate warmup
WEIGHT_DECAY = 0.01     # L2 regularization
VAL_SPLIT = 0.2         # Validation split (20%)

# ============ DEVICE SETUP ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Model: {MODEL_NAME}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training epochs: {NUM_EPOCHS}")

Using device: cuda
Model: LongSafari/hyenadna-medium-160k-seqlen-hf
Batch size: 8
Training epochs: 3


## üìä Step 1: Data Preparation Class

This class handles loading and preparing the FusionAI CSV files.

**Input CSV Format (tab-separated):**
- `Hgene`, `Hchr`, `Hbp`, `Hstrand` - Head gene information
- `Tgene`, `Tchr`, `Tbp`, `Tstrand` - Tail gene information
- `5'-gene sequence (10Kb)` - 5' sequence (10,000 bp)
- `3'-gene sequence (10Kb)` - 3' sequence (10,000 bp)

**What it does:**
1. Loads positive and negative CSV files
2. Merges 5' + 3' sequences ‚Üí 20kb sequences
3. Adds labels (1=positive, 0=negative)
4. Combines and shuffles data
5. Validates sequences

In [4]:
class DataPreparator:
    """Prepares FusionAI CSV data for HyenaDNA training"""
    COLUMNS = ["Hgene","Hchr","Hbp","Hstrand","Tgene","Tchr","Tbp","Tstrand","5'-gene sequence (10Kb)","3'-gene sequence (10Kb)"]
    def __init__(self, positive_csv_path: str, negative_csv_path: str):
        self.positive_csv_path = positive_csv_path
        self.negative_csv_path = negative_csv_path

    def load_and_prepare_data(self) -> pd.DataFrame:
        """
        Load positive and negative CSV files, merge sequences, and shuffle
        
        Returns:
            DataFrame with columns: sequence (20kb), label (0/1)
        """
        print("Loading positive breakpoint data...")
        positive_df = pd.read_csv(self.positive_csv_path, header=None, names=COLUMNS, sep='\t')
        
        print("Loading negative breakpoint data...")
        negative_df = pd.read_csv(self.negative_csv_path, header=None, names=COLUMNS, sep='\t')
        
        # Merge 5' and 3' sequences to create 20kb sequences
        print("Merging 5' and 3' gene sequences...")
        positive_df['sequence'] = positive_df["5'-gene sequence (10Kb)"] + positive_df["3'-gene sequence (10Kb)"]
        positive_df['label'] = 1  # Positive breakpoint
        
        negative_df['sequence'] = negative_df["5'-gene sequence (10Kb)"] + negative_df["3'-gene sequence (10Kb)"]
        negative_df['label'] = 0  # Negative breakpoint
        
        # Keep only sequence and label columns
        positive_prepared = positive_df[['sequence', 'label']]
        negative_prepared = negative_df[['sequence', 'label']]
        
        # Merge datasets
        print("Merging positive and negative datasets...")
        combined_df = pd.concat([positive_prepared, negative_prepared], ignore_index=True)
        
        # Shuffle the dataset
        print("Shuffling dataset...")
        combined_df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True)
        
        # Validate sequences
        print("Validating sequences...")
        self._validate_sequences(combined_df)
        
        print(f"\nDataset prepared:")
        print(f"  Total samples: {len(combined_df)}")
        print(f"  Positive samples: {(combined_df['label'] == 1).sum()}")
        print(f"  Negative samples: {(combined_df['label'] == 0).sum()}")
        print(f"  Average sequence length: {combined_df['sequence'].str.len().mean():.0f} bp")
        
        return combined_df
    
    def _validate_sequences(self, df: pd.DataFrame):
        """Validate DNA sequences"""
        valid_bases = set('ATCGN')
        
        for idx, seq in enumerate(df['sequence'].head(100)):  # Check first 100
            if not set(seq.upper()).issubset(valid_bases):
                invalid_chars = set(seq.upper()) - valid_bases
                print(f"Warning: Invalid characters found in sequence {idx}: {invalid_chars}")
        
        # Check for empty sequences
        empty_sequences = df['sequence'].str.len() == 0
        if empty_sequences.any():
            print(f"Warning: {empty_sequences.sum()} empty sequences found")
        
        # Check sequence length distribution
        seq_lengths = df['sequence'].str.len()
        print(f"  Sequence length range: {seq_lengths.min()} - {seq_lengths.max()} bp")

print("‚úÖ DataPreparator class defined")

‚úÖ DataPreparator class defined


## üîÑ Step 2: PyTorch Dataset Class

Custom PyTorch Dataset for DNA sequences. This handles:
- Tokenizing DNA sequences (character-level: A, T, C, G)
- Padding/truncating to max length
- Converting to PyTorch tensors

In [5]:
class DNABreakpointDataset(Dataset):
    """PyTorch Dataset for DNA sequences"""
    
    def __init__(self, sequences: List[str], labels: List[int], tokenizer, max_length: int = 20480):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx].upper()
        label = self.labels[idx]
        
        # Tokenize the sequence
        # HyenaDNA uses character-level tokenization (A, T, C, G)
        encoding = self.tokenizer(
            sequence,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            # 'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

print("‚úÖ DNABreakpointDataset class defined")

‚úÖ DNABreakpointDataset class defined


## üß¨ Step 3: HyenaDNA Binary Classifier Model

This defines the model architecture:

**Architecture:**
1. **HyenaDNA Backbone** (pretrained) - Extracts features from DNA sequences
2. **Global Average Pooling** - Aggregates sequence representation
3. **Classification Head:**
   - Linear layer (hidden_size ‚Üí 512) + ReLU + Dropout
   - Linear layer (512 ‚Üí 128) + ReLU + Dropout
   - Linear layer (128 ‚Üí 2) for binary classification

**Output:** Logits for 2 classes (negative=0, positive=1)

In [6]:
class HyenaDNAClassifier(nn.Module):
    """HyenaDNA with binary classification head"""
    
    def __init__(self, model_name: str = "LongSafari/hyenadna-tiny-1k-seqlen", num_labels: int = 2):
        super(HyenaDNAClassifier, self).__init__()
        
        print(f"Loading HyenaDNA model: {model_name}")
        
        # Load the pretrained HyenaDNA model
        self.hyenadna = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32
        )
        
        # Get the hidden size from the model config
        self.hidden_size = self.hyenadna.config.d_model
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 256),  
            nn.LayerNorm(256),              
            nn.ReLU(),
            nn.Dropout(0.2),                   
            nn.Linear(256, num_labels)         
        )
        
        print(f"Model loaded. Hidden size: {self.hidden_size}")
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get HyenaDNA embeddings
        outputs = self.hyenadna(input_ids)
        
        # Use the last hidden state
        # Take the mean over the sequence length dimension
        if hasattr(outputs, 'last_hidden_state'):
            sequence_output = outputs.last_hidden_state
        else:
            sequence_output = outputs[0]
        
        # Global average pooling
        pooled_output = torch.mean(sequence_output, dim=1)
        
        # Classification
        logits = self.classifier(pooled_output)
        logits = torch.clamp(logits, min=-10, max=10)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fct(logits, labels)
        
        return {
            'loss': loss,
            'logits': logits
        }

print("‚úÖ HyenaDNAClassifier class defined")

‚úÖ HyenaDNAClassifier class defined


## üìà Step 4: Metrics Function

Computes evaluation metrics:
- **Accuracy**: Overall correctness
- **Precision**: Of predicted positives, how many are truly positive?
- **Recall**: Of actual positives, how many did we find?
- **F1 Score**: Harmonic mean of precision and recall
- **AUC**: Area under ROC curve (discrimination ability)

In [7]:
def compute_metrics(eval_pred):
    """Compute classification metrics"""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    
    # For AUC, we need probabilities
    probs = torch.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1)[:, 1].numpy()
    probs = np.nan_to_num(probs, nan=0.5)

    if len(np.unique(labels)) > 1:
        auc = roc_auc_score(labels, probs)
    else:
        auc = 0.0
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    }

print("‚úÖ Metrics function defined")

‚úÖ Metrics function defined


## üéØ Step 5: Training Pipeline Class

Complete training pipeline that handles:
- Dataset preparation and splitting
- Model initialization
- Training with Hugging Face Trainer
- Model saving and checkpointing
- Predictions on test data

In [None]:
class HyenaDNATrainingPipeline:
    """Complete training pipeline for HyenaDNA breakpoint classifier"""
    
    def __init__(
        self,
        model_name: str,
        output_dir: str,
        max_length: int
    ):
        self.model_name = model_name
        self.output_dir = output_dir
        self.max_length = max_length
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        print(f"Using device: {self.device}")
        
        # Load tokenizer
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        
        # Initialize model
        self.model = None
    
    def prepare_datasets(
        self,
        train_df: pd.DataFrame,
        val_split: float = 0.2
    ) -> Tuple[DNABreakpointDataset, DNABreakpointDataset]:
        """Prepare train and validation datasets"""
        
        # Split data
        train_sequences, val_sequences, train_labels, val_labels = train_test_split(
            train_df['sequence'].tolist(),
            train_df['label'].tolist(),
            test_size=val_split,
            random_state=42,
            stratify=train_df['label']
        )
        
        print(f"\nDataset split:")
        print(f"  Training samples: {len(train_sequences)}")
        print(f"  Validation samples: {len(val_sequences)}")
        
        # Create datasets
        train_dataset = DNABreakpointDataset(
            train_sequences, train_labels, self.tokenizer, self.max_length
        )
        val_dataset = DNABreakpointDataset(
            val_sequences, val_labels, self.tokenizer, self.max_length
        )
        
        return train_dataset, val_dataset
    
    def train(
        self,
        train_dataset: DNABreakpointDataset,
        val_dataset: DNABreakpointDataset,
        num_epochs: int,
        batch_size: int,
        learning_rate: float,
        warmup_steps: int,
        weight_decay: float
    ):
        """Train the model"""
        
        # Initialize model
        self.model = HyenaDNAClassifier(self.model_name, num_labels=2)
        self.model.to(self.device)

        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    
        # Training arguments
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,        
            max_grad_norm=1.0,             
            lr_scheduler_type="cosine",
            bf16=use_bf16,
            fp16=False, 
            logging_dir=f'{self.output_dir}/logs',
            logging_steps=50,
            eval_strategy="steps",
            eval_steps=500,
            save_strategy="steps",
            save_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            dataloader_num_workers=4,
            remove_unused_columns=False,
            report_to="none",
            save_safetensors=False,
        )
        
        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics
        )
        
        # Train
        print("\n" + "="*50)
        print("Starting training...")
        print("="*50 + "\n")
        
        trainer.train()
        
        # Save the final model
        print("\nSaving final model...")
        save_model(self.model, f"{self.output_dir}/final_model/model.safetensors")
        self.tokenizer.save_pretrained(f"{self.output_dir}/final_model")
        
        print(f"Model saved to {self.output_dir}/final_model")
        
        return trainer
    
    def predict(
        self,
        test_csv_path: str,
        model_path: str = None,
        output_path: str = "predictions.csv",
        batch_size: int = 8
    ):
        """Make predictions on test data"""
        
        if model_path is None:
            model_path = f"{self.output_dir}/final_model"
        
        
        if self.model == None:
            print(f"\nLoading model from {model_path}...")
            self.model = HyenaDNAClassifier(self.model_name, num_labels=2)
            load_model(self.model, f"{model_path}/model.safetensors")
            self.model.to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(f"{OUTPUT_DIR}/final_model",
            trust_remote_code=True)
        self.model.eval()
        
        # Load test data
        print(f"Loading test data from {test_csv_path}...")
        test_df = pd.read_csv(test_csv_path, header=None, names=COLUMNS, sep='\t')
        
        # Prepare sequences
        test_df['sequence'] = test_df["5'-gene sequence (10Kb)"] + test_df["3'-gene sequence (10Kb)"]
        
        # Create dataset (dummy labels)
        test_dataset = DNABreakpointDataset(
            test_df['sequence'].tolist(),
            [0] * len(test_df),  # Dummy labels
            self.tokenizer,
            self.max_length
        )
        
        # Create dataloader
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4
        )
        
        # Make predictions
        print("Making predictions...")
        all_predictions = []
        all_probabilities = []
        
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Predicting"):
                input_ids = batch['input_ids'].to(self.device)
                # attention_mask = batch['attention_mask'].to(self.device)
                
                outputs = self.model(input_ids)
                logits = outputs['logits']
                
                # Get probabilities
                probs = torch.softmax(logits, dim=-1)
                predictions = torch.argmax(probs, dim=-1)
                
                all_predictions.extend(predictions.cpu().numpy())
                all_probabilities.extend(probs[:, 1].cpu().numpy())  # Probability of positive class
        
        # Add predictions to dataframe
        test_df['predicted_label'] = all_predictions
        test_df['breakpoint_probability'] = all_probabilities
        test_df['prediction'] = test_df['predicted_label'].map({0: 'Negative', 1: 'Positive'})
        
        # Save results
        print(f"\nSaving predictions to {output_path}...")
        test_df.to_csv(output_path, index=False)
        
        print("\nPrediction Summary:")
        print(f"  Total samples: {len(test_df)}")
        print(f"  Predicted positive: {(test_df['predicted_label'] == 1).sum()}")
        print(f"  Predicted negative: {(test_df['predicted_label'] == 0).sum()}")
        print(f"  Mean positive probability: {test_df['breakpoint_probability'].mean():.4f}")
        
        return test_df

print("‚úÖ HyenaDNATrainingPipeline class defined")

‚úÖ HyenaDNATrainingPipeline class defined


---
# üöÄ EXECUTION STARTS HERE

Now we'll run the complete pipeline step by step.

---

## üìÇ Step 6: Load and Prepare Data

Load the positive and negative CSV files, merge sequences, and prepare the dataset.

In [11]:
print("="*70)
print("STEP 1: DATA PREPARATION")
print("="*70)

# Initialize data preparator
data_prep = DataPreparator(POSITIVE_CSV, NEGATIVE_CSV)

# Load and prepare data
train_df = data_prep.load_and_prepare_data()

# Display first few rows
print("\nFirst 3 samples:")
display(train_df.head(3))

print("\n‚úÖ Data preparation complete!")

STEP 1: DATA PREPARATION
Loading positive breakpoint data...


Loading negative breakpoint data...
Merging 5' and 3' gene sequences...
Merging positive and negative datasets...
Shuffling dataset...
Validating sequences...
  Sequence length range: 20000 - 20000 bp

Dataset prepared:
  Total samples: 50745
  Positive samples: 30745
  Negative samples: 20000
  Average sequence length: 20000 bp

First 3 samples:


Unnamed: 0,sequence,label
0,GCTGGGATTACAGGTGCCCACCACCATGCCTGGCTAATTTTTGTAT...,1
1,CCTCAGCCCTCCCATACAATTCTCCCAATGATAAGTGTGAGAACAC...,0
2,CCTGCACTCAAGCTATCCCCCCACCTCAGCCTCCCAAAGAGCTGGG...,1



‚úÖ Data preparation complete!


## üèóÔ∏è Step 7: Initialize Training Pipeline

Create the training pipeline with the specified model and configuration.

In [12]:
print("="*70)
print("STEP 2: INITIALIZE TRAINING PIPELINE")
print("="*70)

pipeline = HyenaDNATrainingPipeline(
    model_name=MODEL_NAME,
    output_dir=OUTPUT_DIR,
    max_length=MAX_LENGTH
)

print("\n‚úÖ Pipeline initialized!")

STEP 2: INITIALIZE TRAINING PIPELINE
Using device: cuda
Loading tokenizer...


A new version of the following files was downloaded from https://huggingface.co/LongSafari/hyenadna-medium-160k-seqlen-hf:
- tokenization_hyena.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.



‚úÖ Pipeline initialized!


## üìä Step 8: Prepare Training and Validation Datasets

Split the data into training and validation sets (80/20 split by default).

In [13]:
print("="*70)
print("STEP 3: PREPARE TRAINING AND VALIDATION DATASETS")
print("="*70)

train_dataset, val_dataset = pipeline.prepare_datasets(train_df, val_split=VAL_SPLIT)

print(f"\nTraining samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print("\n‚úÖ Datasets prepared!")

STEP 3: PREPARE TRAINING AND VALIDATION DATASETS

Dataset split:
  Training samples: 40596
  Validation samples: 10149

Training samples: 40596
Validation samples: 10149

‚úÖ Datasets prepared!


## üèãÔ∏è Step 9: Train the Model

This is the main training step. It will:
- Load the pretrained HyenaDNA model
- Add the classification head
- Fine-tune on your data
- Evaluate on validation set periodically
- Save the best model based on F1 score

**Expected time:** 30 minutes to several hours depending on:
- Dataset size
- GPU type (T4 vs A100)
- Number of epochs
- Model size

**Progress indicators:**
- Training loss per batch
- Validation metrics every 200 steps
- Best model automatically saved

In [14]:
print("="*70)
print("STEP 4: TRAIN THE MODEL")
print("="*70)
print(f"\nTraining configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Warmup steps: {WARMUP_STEPS}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"\nThis may take a while... ‚òï")
print()

trainer = pipeline.train(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    weight_decay=WEIGHT_DECAY
)

print("\n‚úÖ Training complete!")

STEP 4: TRAIN THE MODEL

Training configuration:
  Epochs: 3
  Batch size: 8
  Learning rate: 2e-05
  Warmup steps: 1000
  Weight decay: 0.01

This may take a while... ‚òï

Loading HyenaDNA model: LongSafari/hyenadna-medium-160k-seqlen-hf


A new version of the following files was downloaded from https://huggingface.co/LongSafari/hyenadna-medium-160k-seqlen-hf:
- configuration_hyena.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
`torch_dtype` is deprecated! Use `dtype` instead!
A new version of the following files was downloaded from https://huggingface.co/LongSafari/hyenadna-medium-160k-seqlen-hf:
- modeling_hyena.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Model loaded. Hidden size: 256

Starting training...



Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Auc
500,0.5432,0.536436,0.766282,0.76372,0.889413,0.821788,0.830335
1000,0.5336,0.534729,0.772096,0.772753,0.883721,0.82452,0.850227
1500,0.5479,0.499599,0.799685,0.804889,0.883558,0.842391,0.868315
2000,0.512,0.498878,0.800966,0.782383,0.930233,0.849926,0.878564
2500,0.4955,0.50819,0.804316,0.813055,0.879167,0.84482,0.875451
3000,0.5486,0.52013,0.785003,0.757163,0.949748,0.842591,0.881688
3500,0.49,0.494452,0.798798,0.838359,0.827452,0.83287,0.873166
4000,0.5329,0.47629,0.815647,0.84951,0.845503,0.847502,0.887787
4500,0.49,0.497519,0.804611,0.790354,0.922101,0.85116,0.885358
5000,0.4572,0.481054,0.81742,0.813852,0.905838,0.857385,0.890923



Saving final model...


SafetensorError: Error while serializing: I/O error: No such file or directory (os error 2)

## üìä Step 10: View Training Results

Display the final training metrics and model information.

In [None]:
# Get training history
history = trainer.state.log_history

# Extract validation metrics
val_metrics = [entry for entry in history if 'eval_accuracy' in entry]

if val_metrics:
    best_metrics = val_metrics[-1]
    print("\n" + "="*70)
    print("FINAL VALIDATION METRICS")
    print("="*70)
    print(f"Accuracy:  {best_metrics.get('eval_accuracy', 0):.4f}")
    print(f"Precision: {best_metrics.get('eval_precision', 0):.4f}")
    print(f"Recall:    {best_metrics.get('eval_recall', 0):.4f}")
    print(f"F1 Score:  {best_metrics.get('eval_f1', 0):.4f}")
    print(f"AUC:       {best_metrics.get('eval_auc', 0):.4f}")
    print("="*70)

print(f"\n‚úÖ Model saved to: {OUTPUT_DIR}/final_model/")

NameError: name 'trainer' is not defined

## üîÆ Step 11: Make Predictions on Test Data

Load the test CSV file and make predictions using the trained model.

**Output:** CSV file with:
- All original columns
- `predicted_label` (0 or 1)
- `breakpoint_probability` (0.0 to 1.0)
- `prediction` ("Negative" or "Positive")

In [None]:
print("="*70)
print("STEP 5: MAKE PREDICTIONS ON TEST DATA")
print("="*70)

predictions_df = pipeline.predict(
    test_csv_path=TEST_CSV,
    output_path="breakpoint_predictions.csv",
    batch_size=BATCH_SIZE
)

print("\n‚úÖ Predictions complete!")

## üìã Step 12: View Prediction Results

Display sample predictions and summary statistics.

In [None]:
print("\n" + "="*70)
print("PREDICTION SUMMARY")
print("="*70)

# Summary statistics
print(f"\nTotal test samples: {len(predictions_df)}")
print(f"Predicted Positive: {(predictions_df['predicted_label'] == 1).sum()} ({(predictions_df['predicted_label'] == 1).sum() / len(predictions_df) * 100:.1f}%)")
print(f"Predicted Negative: {(predictions_df['predicted_label'] == 0).sum()} ({(predictions_df['predicted_label'] == 0).sum() / len(predictions_df) * 100:.1f}%)")
print(f"\nAverage breakpoint probability: {predictions_df['breakpoint_probability'].mean():.4f}")
print(f"Min probability: {predictions_df['breakpoint_probability'].min():.4f}")
print(f"Max probability: {predictions_df['breakpoint_probability'].max():.4f}")

# Show sample predictions
print("\n" + "="*70)
print("SAMPLE PREDICTIONS (first 5 rows)")
print("="*70)
display(predictions_df[['predicted_label', 'breakpoint_probability', 'prediction']].head())

# Probability distribution
print("\n" + "="*70)
print("PROBABILITY DISTRIBUTION")
print("="*70)
print(predictions_df['breakpoint_probability'].describe())

## üìä (Optional) Step 13: Visualize Results

Create visualizations of the prediction results.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)

# Create subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. Prediction distribution
predictions_df['prediction'].value_counts().plot(kind='bar', ax=axes[0], color=['#FF6B6B', '#4ECDC4'])
axes[0].set_title('Prediction Distribution', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Prediction')
axes[0].set_ylabel('Count')
axes[0].tick_params(axis='x', rotation=0)

# 2. Probability distribution
axes[1].hist(predictions_df['breakpoint_probability'], bins=30, edgecolor='black', alpha=0.7, color='#95E1D3')
axes[1].set_title('Breakpoint Probability Distribution', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Probability')
axes[1].set_ylabel('Frequency')
axes[1].axvline(x=0.5, color='red', linestyle='--', linewidth=2, label='Threshold (0.5)')
axes[1].legend()

# 3. Box plot by prediction
predictions_df.boxplot(column='breakpoint_probability', by='prediction', ax=axes[2])
axes[2].set_title('Probability by Prediction', fontsize=14, fontweight='bold')
axes[2].set_xlabel('Prediction')
axes[2].set_ylabel('Breakpoint Probability')
plt.suptitle('')  # Remove default title

plt.tight_layout()
plt.savefig('prediction_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n‚úÖ Visualization saved as 'prediction_analysis.png'")

## üíæ Step 14: Save/Download Results

Download the results if you're using Google Colab.

In [None]:
# For Google Colab: Uncomment these lines to download files
# from google.colab import files
# 
# # Download predictions
# files.download('breakpoint_predictions.csv')
# 
# # Download visualization (if created)
# files.download('prediction_analysis.png')
# 
# # Optionally download the model
# !zip -r hyenadna_model.zip hyenadna_breakpoint_model/final_model/
# files.download('hyenadna_model.zip')

print("\nFiles ready for download:")
print("  - breakpoint_predictions.csv")
print("  - prediction_analysis.png (if visualization was run)")
print(f"  - {OUTPUT_DIR}/final_model/ (trained model)")

---
# ‚úÖ PIPELINE COMPLETE!

## What You've Accomplished:

1. ‚úÖ Loaded and prepared FusionAI format data
2. ‚úÖ Merged 5' and 3' sequences into 20kb inputs
3. ‚úÖ Fine-tuned HyenaDNA for binary classification
4. ‚úÖ Evaluated model performance on validation data
5. ‚úÖ Made predictions on test data
6. ‚úÖ Saved results and visualizations

## Output Files:

- **`breakpoint_predictions.csv`** - Test predictions with probabilities
- **`hyenadna_breakpoint_model/final_model/`** - Trained model weights
- **`prediction_analysis.png`** - Visualization (if created)

## Next Steps:

1. **Analyze results** - Check prediction quality and metrics
2. **Adjust hyperparameters** - Try different learning rates, batch sizes, or epochs
3. **Use the model** - Apply to new data using the saved model
4. **Experiment** - Try different HyenaDNA model sizes for better performance

---

## üîÑ To Use the Trained Model Later:

```python
# Load the pipeline
pipeline = HyenaDNATrainingPipeline(
    model_name=MODEL_NAME,
    output_dir=OUTPUT_DIR,
    max_length=MAX_LENGTH
)

# Make predictions on new data
new_predictions = pipeline.predict(
    test_csv_path="new_test_data.csv",
    model_path="./hyenadna_breakpoint_model/final_model",
    output_path="new_predictions.csv"
)
```

---

**Questions or issues?** Check the comments in each cell or refer to the README documentation.

**Happy predicting! üß¨üöÄ**