# HyenaDNA Binary Classifier for DNA Breakpoint Detection

This notebook implements a complete pipeline to fine-tune HyenaDNA for classifying DNA sequences as positive or negative breakpoints, similar to FusionAI.

## Overview

**What this notebook does:**
1. **Prepares data** from FusionAI-format CSV files (positive/negative breakpoints)
2. **Merges sequences** (5' + 3' = 20kb input)
3. **Fine-tunes HyenaDNA** with a binary classification head
4. **Makes predictions** on test data

**Key Features:**
- HyenaDNA checkpoints handle sequences up to 1M bp (this notebook uses 20kb with the 32k checkpoint)
- Uses Hugging Face Trainer for optimization
- Automatic mixed precision (BF16, enabled when the GPU supports it) for faster training
- Comprehensive metrics (Accuracy, Precision, Recall, F1, AUC)

---

## 📦 Installation & Setup

First, let's install all required dependencies and check our compute environment.

In [None]:
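# Install required packages (uncomment on a fresh environment)
# !pip install -q torch transformers accelerate pandas scikit-learn safetensors tqdm matplotlib seaborn

print("✅ Setup cell finished (uncomment the pip line above if packages are missing)")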

In [None]:
# Check GPU availability and setup
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: No GPU detected! Training will be very slow.")
    print("   Consider using a GPU runtime (Runtime → Change runtime type → GPU)")

## 📚 Import Libraries

Import all necessary libraries for data processing, model training, and evaluation.

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer
from safetensors.torch import load_model, save_model
from typing import Dict, List, Tuple
import os
from tqdm.auto import tqdm
import warnings

warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
print("✅ Libraries imported successfully!")

## ⚙️ Configuration

Set all hyperparameters and file paths here. **Update the CSV paths to match your data!**

### Model Selection Guide:
- `hyenadna-tiny-1k-seqlen` - Max 1kb (❌ Too short for 20kb!)
- `hyenadna-small-32k-seqlen` - Max 32kb (✅ **Recommended for 20kb**)
- `hyenadna-medium-160k-seqlen` - Max 160kb (Better performance, more memory)
- `hyenadna-large-1m-seqlen` - Max 1M (Best performance, requires A100 80GB)

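The guide above amounts to picking the smallest checkpoint whose context window covers the input. A minimal sketch of that rule follows. The numeric limits are nominal values read off the checkpoint names (an assumption, not values queried from the model configs), and `pick_checkpoint` is a hypothetical helper, not part of any library.

```python
# Nominal context windows inferred from the checkpoint names (assumed values).
CONTEXT_LIMITS = {
    "hyenadna-tiny-1k-seqlen": 1_024,
    "hyenadna-small-32k-seqlen": 32_768,
    "hyenadna-medium-160k-seqlen": 160_000,
    "hyenadna-large-1m-seqlen": 1_000_000,
}


def pick_checkpoint(sequence_length: int) -> str:
    """Return the smallest listed checkpoint that fits `sequence_length` tokens."""
    for name, limit in sorted(CONTEXT_LIMITS.items(), key=lambda item: item[1]):
        if sequence_length <= limit:
            return name
    raise ValueError(f"No listed checkpoint supports {sequence_length} tokens")


print(pick_checkpoint(20_480))
```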
In [None]:
# ============ FILE PATHS - UPDATE THESE! ============
POSITIVE_CSV = "datasets/fusion_gene_positive_bp_information_with_class_for_modeling.txt"  # Path to positive breakpoint CSV
NEGATIVE_CSV = "datasets/fusion_gene_negative_bp_information_with_class_for_modeling.txt"  # Path to negative breakpoint CSV
TEST_CSV = "datasets/fusion_gene_positive_bp_information_with_class_for_testing.txt"          # Path to test data CSV
COLUMNS = ["Hgene","Hchr","Hbp","Hstrand","Tgene","Tchr","Tbp","Tstrand","5'-gene sequence (10Kb)","3'-gene sequence (10Kb)"] # CSV Columns names
OUTPUT_DIR = "./hyenadna_breakpoint_model" # Where to save the model

# ============ MODEL CONFIGURATION ============
MODEL_NAME = "LongSafari/hyenadna-small-32k-seqlen-hf"  # Recommended for 20kb sequences
MAX_LENGTH = 20480  # Maximum sequence length (20kb)

# ============ TRAINING HYPERPARAMETERS ============
NUM_EPOCHS = 4
BATCH_SIZE = 8          # Adjust based on GPU memory (2 for T4, 4 for V100, 8 for A100)
LEARNING_RATE = 1e-5    # Learning rate
WARMUP_STEPS = 1000      # Learning rate warmup
WEIGHT_DECAY = 0.01     # L2 regularization
VAL_SPLIT = 0.2         # Validation split (20%)

# ============ DEVICE SETUP ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Model: {MODEL_NAME}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training epochs: {NUM_EPOCHS}")

## 📊 Step 1: Data Preparation Class

This class handles loading and preparing the FusionAI CSV files.

**Input CSV Format (tab-separated):**
- `Hgene`, `Hchr`, `Hbp`, `Hstrand` - Head gene information
- `Tgene`, `Tchr`, `Tbp`, `Tstrand` - Tail gene information
- `5'-gene sequence (10Kb)` - 5' sequence (10,000 bp)
- `3'-gene sequence (10Kb)` - 3' sequence (10,000 bp)

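**What it does:**
1. Loads the positive and negative files (tab-separated, no header row)
2. Merges 5' + 3' sequences → 20kb sequences
3. Adds labels (1=positive, 0=negative)
4. Adds extra positive sequences from a FASTA file, if that file exists
5. Optionally doubles the dataset with reverse-complement copies (`augment=True`)
6. Combines and shuffles data
7. Spot-checks the first 20 sequences for invalid characters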

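As a small illustration of the merge-and-label step, the sketch below parses one headerless, tab-separated row and concatenates the two sequence columns. The gene names, coordinates, and 4 bp sequences are made-up toy values, and `merge_rows` is a hypothetical helper; the notebook itself does this with pandas.

```python
import csv
import io

# Column order taken from the notebook's COLUMNS list.
COLUMNS = ["Hgene", "Hchr", "Hbp", "Hstrand", "Tgene", "Tchr", "Tbp", "Tstrand",
           "5'-gene sequence (10Kb)", "3'-gene sequence (10Kb)"]

# One hypothetical row; real rows carry 10,000 bp in each sequence column.
raw = "GENE_A\tchr1\t1000\t+\tGENE_B\tchr2\t2000\t-\tACGT\tGGCC\n"


def merge_rows(text: str, label: int) -> list:
    """Concatenate the 5' and 3' columns of each row and attach a class label."""
    rows = []
    for fields in csv.reader(io.StringIO(text), delimiter="\t"):
        record = dict(zip(COLUMNS, fields))
        sequence = record["5'-gene sequence (10Kb)"] + record["3'-gene sequence (10Kb)"]
        rows.append({"sequence": sequence, "label": label})
    return rows


merged = merge_rows(raw, label=1)
print(merged)
```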
In [None]:
class DataPreparator:
    """
    Prepares FusionAI CSV data for HyenaDNA training with Data Augmentation
    """
    COLUMNS = ["Hgene","Hchr","Hbp","Hstrand","Tgene","Tchr","Tbp","Tstrand","5'-gene sequence (10Kb)","3'-gene sequence (10Kb)"]

    def __init__(self, positive_csv_path: str, negative_csv_path: str, extra_fasta_path: str = "datasets/chimeras_43466.fa"):
        self.positive_csv_path = positive_csv_path
        self.negative_csv_path = negative_csv_path
        self.extra_fasta_path = extra_fasta_path
        # Translation table for efficient reverse complement calculation
        self.trans_table = str.maketrans("ATCGN", "TAGCN")

    def _get_reverse_complement(self, sequence: str) -> str:
        """
        Returns the reverse complement of a DNA sequence.
        Fast implementation using string translation.
        """
        # 1. Translate (A->T, C->G, etc.)
        # 2. Reverse string ([::-1])
        return sequence.upper().translate(self.trans_table)[::-1]

    def load_and_prepare_data(self, augment: bool = False) -> pd.DataFrame:
        print("Loading positive breakpoint data...")
        positive_df = pd.read_csv(self.positive_csv_path, header=None, names=self.COLUMNS, sep='\t')

        print("Loading negative breakpoint data...")
        negative_df = pd.read_csv(self.negative_csv_path, header=None, names=self.COLUMNS, sep='\t')

        # Merge 5' and 3' sequences
        print("Merging 5' and 3' gene sequences...")
        positive_df['sequence'] = positive_df["5'-gene sequence (10Kb)"] + positive_df["3'-gene sequence (10Kb)"]
        positive_df['label'] = 1

        negative_df['sequence'] = negative_df["5'-gene sequence (10Kb)"] + negative_df["3'-gene sequence (10Kb)"]
        negative_df['label'] = 0

        # Reduce to sequence + label
        positive_prepared = positive_df[['sequence', 'label']]
        negative_prepared = negative_df[['sequence', 'label']]

        # Load additional positive sequences from FASTA
        print(f"Loading extra positive samples from: {self.extra_fasta_path}")
        if os.path.exists(self.extra_fasta_path):
            extra_sequences = self._load_fasta_multiline(self.extra_fasta_path)
            print(f"  Loaded {len(extra_sequences)} FASTA sequences")
            extra_positive_df = pd.DataFrame({
                "sequence": extra_sequences,
                "label": 1
            })
        else:
            print("  ⚠️ Warning: Extra FASTA file not found. Skipping.")
            extra_positive_df = pd.DataFrame(columns=['sequence', 'label'])

        # Merge all datasets
        print("Merging datasets (original positives + FASTA positives + negatives)...")
        combined_df = pd.concat([positive_prepared, extra_positive_df, negative_prepared], ignore_index=True)

        # --- AUGMENTATION LOGIC START ---
        if augment:
            print(f"\nApplying Reverse Complement Augmentation...")
            print(f"  Original count: {len(combined_df)}")
            
            # Create augmented copy
            augmented_df = combined_df.copy()
            # Apply reverse complement to the sequence column
            augmented_df['sequence'] = augmented_df['sequence'].apply(self._get_reverse_complement)
            
            # Concatenate original + augmented
            combined_df = pd.concat([combined_df, augmented_df], ignore_index=True)
            print(f"  Augmented count: {len(combined_df)} (Doubled dataset)")
        # --- AUGMENTATION LOGIC END ---

        # Shuffle
        print("Shuffling dataset...")
        combined_df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True)

        # Validate
        print("Validating sequences...")
        self._validate_sequences(combined_df)

        print(f"\nDataset prepared:")
        print(f"  Total samples: {len(combined_df)}")
        print(f"  Positive samples: {(combined_df['label'] == 1).sum()}")
        print(f"  Negative samples: {(combined_df['label'] == 0).sum()}")

        return combined_df

    def _load_fasta_multiline(self, fasta_path: str):
        """Loads sequences from multi-line FASTA."""
        sequences = []
        current_seq = []
        with open(fasta_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line: continue
                if line.startswith(">"):
                    if current_seq:
                        sequences.append("".join(current_seq))
                        current_seq = []
                else:
                    current_seq.append(line)
        if current_seq:
            sequences.append("".join(current_seq))
        return sequences

    def _validate_sequences(self, df: pd.DataFrame):
        valid_bases = set('ATCGN')
        # Quick check on a few samples
        for idx, seq in enumerate(df['sequence'].head(20)):
            if not set(seq.upper()).issubset(valid_bases):
                print(f"Warning: Invalid characters in sequence {idx}")
                break

print("✅ DataPreparator class updated with Reverse Complement augmentation")

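The reverse-complement augmentation in `DataPreparator` can be checked in isolation. The sketch below reuses the same translation table as the class; the standalone function name `reverse_complement` is chosen here for illustration. Note that characters outside `ATCGN` pass through `str.translate` unchanged.

```python
# Same mapping as DataPreparator.trans_table: A<->T, C<->G, N stays N.
TRANS_TABLE = str.maketrans("ATCGN", "TAGCN")


def reverse_complement(sequence: str) -> str:
    """Complement each base, then reverse the string."""
    return sequence.upper().translate(TRANS_TABLE)[::-1]


example = "ATGCN"
print(reverse_complement(example))                      # NGCAT
print(reverse_complement(reverse_complement(example)))  # ATGCN
```

Applying the function twice returns the original sequence, which is a quick sanity check for the mapping.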
## 🔄 Step 2: PyTorch Dataset Class

Custom PyTorch Dataset for DNA sequences. This handles:
- Tokenizing DNA sequences (character-level: one token per base, A, T, C, G, or N)
- Padding/truncating to max length
- Converting to PyTorch tensors

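To make these three steps concrete, here is a toy character-level encoder. The vocabulary IDs, `PAD_ID`, and the right-side padding are illustrative assumptions; the real HyenaDNA tokenizer defines its own IDs, special tokens, and padding side.

```python
PAD_ID = 0  # hypothetical padding ID, not the real tokenizer's
VOCAB = {"A": 1, "C": 2, "G": 3, "T": 4, "N": 5}  # hypothetical IDs


def encode(sequence: str, max_length: int):
    """Map bases to IDs, truncate to max_length, pad, and build a mask."""
    ids = [VOCAB.get(base, VOCAB["N"]) for base in sequence.upper()][:max_length]
    mask = [1] * len(ids)            # 1 marks a real base
    padding = max_length - len(ids)  # 0 when the sequence was truncated
    return ids + [PAD_ID] * padding, mask + [0] * padding


ids, mask = encode("acgtn", max_length=8)
print(ids)   # [1, 2, 3, 4, 5, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```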
In [None]:
class DNABreakpointDataset(Dataset):
    """PyTorch Dataset for DNA sequences"""
    
    def __init__(self, sequences: List[str], labels: List[int], tokenizer, max_length: int = 20480):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx].upper()
        label = self.labels[idx]
        
        # Tokenize the sequence
        encoding = self.tokenizer(
            sequence,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].squeeze(0)
        # HyenaDNATokenizer does not return attention_mask, so we create it:
        if 'attention_mask' in encoding:
            attention_mask = encoding['attention_mask'].squeeze(0)
        else:
            pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 4
            attention_mask = (input_ids != pad_token_id).long()
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long)
        }

print("✅ DNABreakpointDataset class defined (attention_mask patch applied)")

## 🧬 Step 3: HyenaDNA Binary Classifier Model

This defines the model architecture:

**Architecture:**
1. **HyenaDNA Backbone** (pretrained) - Extracts features from DNA sequences
2. **Mean + Max Pooling** - Averages and max-pools the hidden states over the sequence dimension, then concatenates both (2 × hidden_size features)
3. **Classification Head:**
   - Linear layer (2 × hidden_size → 256) + LayerNorm + ReLU + Dropout (0.2)
   - Linear layer (256 → 2) for binary classification

**Output:** Logits for 2 classes (negative=0, positive=1), clamped to [-10, 10]. Training uses cross-entropy loss with label smoothing of 0.1.

In [None]:
class HyenaDNAClassifier(nn.Module):
    """
    HyenaDNA with Enhanced Pooling (Mean + Max) for Breakpoint Detection
    """
    
    def __init__(self, model_name: str = "LongSafari/hyenadna-small-32k-seqlen-hf", num_labels: int = 2):
        super(HyenaDNAClassifier, self).__init__()
        
        print(f"Loading HyenaDNA model: {model_name}")
        
        # Load the pretrained HyenaDNA model
        self.hyenadna = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32
        )
        
        # Get the hidden size from the model config
        self.hidden_size = self.hyenadna.config.d_model
        
        # --- IMPROVEMENT START ---
        # We are now using Mean + Max pooling, so the input dimension doubles
        classifier_input_dim = self.hidden_size * 2
        
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 256), 
            nn.LayerNorm(256),              
            nn.ReLU(),
            nn.Dropout(0.2),                   
            nn.Linear(256, num_labels)         
        )
        # --- IMPROVEMENT END ---
        
        print(f"Model loaded. Hidden size: {self.hidden_size}")
        print(f"Classifier input dimension: {classifier_input_dim} (Mean + Max Pooling)")
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get HyenaDNA embeddings
        outputs = self.hyenadna(input_ids)
        
        # Extract last hidden state
        if hasattr(outputs, 'last_hidden_state'):
            sequence_output = outputs.last_hidden_state
        else:
            sequence_output = outputs[0]
            
        # --- IMPROVEMENT START: ENHANCED POOLING ---
        # 1. Mean Pooling (General context)
        # Shape: [batch_size, hidden_size]
        mean_pool = torch.mean(sequence_output, dim=1)
        
        # 2. Max Pooling (Specific feature detection - crucial for breakpoints)
        # torch.max returns (values, indices), we only need values
        # Shape: [batch_size, hidden_size]
        max_pool, _ = torch.max(sequence_output, dim=1)
        
        # 3. Concatenate
        # Shape: [batch_size, hidden_size * 2]
        pooled_output = torch.cat((mean_pool, max_pool), dim=1)
        # --- IMPROVEMENT END ---
        
        # Classification
        logits = self.classifier(pooled_output)
        logits = torch.clamp(logits, min=-10, max=10)
        
        loss = None
        if labels is not None:
            # Weighted loss can be added here if handling class imbalance
            loss_fct = nn.CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fct(logits, labels)
        
        return {
            'loss': loss,
            'logits': logits
        }

print("✅ HyenaDNAClassifier updated with Mean+Max pooling.")

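The `forward` method above accepts `attention_mask` but pools over every position, so padding tokens contribute to both the mean and the max. One possible refinement is to mask padding before pooling. The sketch below shows the arithmetic in NumPy so it runs without a GPU or a model download; `masked_mean_max_pool` is a hypothetical helper, not part of the notebook, and the same masking can be written with tensor operations in PyTorch.

```python
import numpy as np


def masked_mean_max_pool(hidden: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Pool over real tokens only.

    hidden: [batch, seq_len, hidden_size]; mask: [batch, seq_len], 1 = real token.
    Returns [batch, 2 * hidden_size] (mean features followed by max features).
    """
    mask3 = mask[:, :, None].astype(hidden.dtype)
    counts = np.maximum(mask3.sum(axis=1), 1.0)        # avoid division by zero
    mean_pool = (hidden * mask3).sum(axis=1) / counts
    # Padding positions are set to -inf so they can never win the max.
    max_pool = np.where(mask3 > 0, hidden, -np.inf).max(axis=1)
    return np.concatenate([mean_pool, max_pool], axis=1)


hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last row is padding
mask = np.array([[1, 1, 0]])
pooled = masked_mean_max_pool(hidden, mask)
print(pooled)  # mean = [2, 3], max = [3, 4]
```

With the mask applied, the padded position (value 100) affects neither statistic. A sequence consisting only of padding would yield `-inf` in the max features, so such rows should be filtered out beforehand.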
## 📈 Step 4: Metrics Function

Computes evaluation metrics:
- **Accuracy**: Overall correctness
- **Precision**: Of predicted positives, how many are truly positive?
- **Recall**: Of actual positives, how many did we find?
- **F1 Score**: Harmonic mean of precision and recall
- **AUC**: Area under ROC curve (discrimination ability)

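For intuition, the same metrics can be computed by hand from the confusion counts. The sketch below uses only the standard library and five made-up samples; `binary_metrics` and `pairwise_auc` are hypothetical helpers for illustration, while the notebook itself relies on scikit-learn. AUC is computed here as the fraction of positive/negative pairs in which the positive sample scores higher, with ties counted as half.

```python
def binary_metrics(labels, predictions):
    """Accuracy, precision, recall, and F1 for the positive class (label 1)."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / len(pairs), "precision": precision,
            "recall": recall, "f1": f1}


def pairwise_auc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1, 1, 1, 0, 0]                       # toy ground truth
scores = [0.9, 0.8, 0.4, 0.6, 0.2]             # toy positive-class probabilities
predictions = [1 if s >= 0.5 else 0 for s in scores]

metrics = binary_metrics(labels, predictions)
auc = pairwise_auc(labels, scores)
print(metrics, auc)
```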
In [None]:
def compute_metrics(eval_pred):
    """Compute classification metrics"""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    
    # For AUC, we need probabilities
    probs = torch.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1)[:, 1].numpy()
    probs = np.nan_to_num(probs, nan=0.5)

    if len(np.unique(labels)) > 1:
        auc = roc_auc_score(labels, probs)
    else:
        auc = 0.0
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    }

print("✅ Metrics function defined")

## 🎯 Step 5: Training Pipeline Class

Complete training pipeline that handles:
- Dataset preparation and splitting
- Model initialization
- Training with Hugging Face Trainer
- Model saving and checkpointing
- Predictions on test data

In [None]:
class HyenaDNATrainingPipeline:
    """Complete training pipeline for HyenaDNA breakpoint classifier"""
    
    def __init__(
        self,
        model_name: str,
        output_dir: str,
        max_length: int
    ):
        self.model_name = model_name
        self.output_dir = output_dir
        self.max_length = max_length
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        print(f"Using device: {self.device}")
        
        # Load tokenizer
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        
        # Initialize model
        self.model = None
    
    def prepare_datasets(
        self,
        train_df: pd.DataFrame,
        val_split: float = 0.2
    ) -> Tuple[DNABreakpointDataset, DNABreakpointDataset]:
        """Prepare train and validation datasets"""
        
        # Split data
        train_sequences, val_sequences, train_labels, val_labels = train_test_split(
            train_df['sequence'].tolist(),
            train_df['label'].tolist(),
            test_size=val_split,
            random_state=42,
            stratify=train_df['label']
        )
        
        print(f"\nDataset split:")
        print(f"  Training samples: {len(train_sequences)}")
        print(f"  Validation samples: {len(val_sequences)}")
        
        # Create datasets
        train_dataset = DNABreakpointDataset(
            train_sequences, train_labels, self.tokenizer, self.max_length
        )
        val_dataset = DNABreakpointDataset(
            val_sequences, val_labels, self.tokenizer, self.max_length
        )
        
        return train_dataset, val_dataset
    
    def train(
        self,
        train_dataset: DNABreakpointDataset,
        val_dataset: DNABreakpointDataset,
        num_epochs: int,
        batch_size: int,
        learning_rate: float,
        warmup_steps: int,
        weight_decay: float
    ):
        """Train the model"""
        
        # Initialize model
        self.model = HyenaDNAClassifier(self.model_name, num_labels=2)
        self.model.to(self.device)

        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    
        # Training arguments
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,        
            max_grad_norm=1.0,             
            lr_scheduler_type="cosine",
            bf16=use_bf16,
            fp16=False, 
            logging_dir=f'{self.output_dir}/logs',
            logging_steps=50,
            eval_strategy="steps",
            eval_steps=500,
            save_strategy="steps",
            save_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            dataloader_num_workers=4,
            remove_unused_columns=False,
            report_to="none",
            save_safetensors=False,
        )
        
        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics
        )
        
        # Train
        print("\n" + "="*50)
        print("Starting training...")
        print("="*50 + "\n")
        
        trainer.train()
        
        # Save the final model
        print("\nSaving final model...")
        final_dir = f"{self.output_dir}/final_model"
        # save_model writes a single file and does not create missing directories
        os.makedirs(final_dir, exist_ok=True)
        save_model(self.model, f"{final_dir}/model.safetensors")
        self.tokenizer.save_pretrained(final_dir)
        
        print(f"Model saved to {final_dir}")
        
        return trainer
    
    def predict(
        self,
        test_csv_path: str,
        model_path: str = None,
        output_path: str = "predictions.csv",
        batch_size: int = 8
    ):
        """Make predictions on test data"""
        
        if model_path is None:
            model_path = f"{self.output_dir}/final_model"
        
        
        if self.model is None:
            print(f"\nLoading model from {model_path}...")
            self.model = HyenaDNAClassifier(self.model_name, num_labels=2)
            load_model(self.model, f"{model_path}/model.safetensors")
            self.model.to(self.device)
            # Load the tokenizer saved alongside the weights (not the global OUTPUT_DIR)
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True
            )
        self.model.eval()
        
        # Load test data
        print(f"Loading test data from {test_csv_path}...")
        test_df = pd.read_csv(test_csv_path, header=None, names=DataPreparator.COLUMNS, sep='\t')
        
        # Prepare sequences
        test_df['sequence'] = test_df["5'-gene sequence (10Kb)"] + test_df["3'-gene sequence (10Kb)"]
        
        # Create dataset (dummy labels)
        test_dataset = DNABreakpointDataset(
            test_df['sequence'].tolist(),
            [0] * len(test_df),  # Dummy labels
            self.tokenizer,
            self.max_length
        )
        
        # Create dataloader
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4
        )
        
        # Make predictions
        print("Making predictions...")
        all_predictions = []
        all_probabilities = []
        
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Predicting"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                
                outputs = self.model(input_ids)
                logits = outputs['logits']
                
                # Get probabilities
                probs = torch.softmax(logits, dim=-1)
                predictions = torch.argmax(probs, dim=-1)
                
                all_predictions.extend(predictions.cpu().numpy())
                all_probabilities.extend(probs[:, 1].cpu().numpy())  # Probability of positive class
        
        # Add predictions to dataframe
        test_df['predicted_label'] = all_predictions
        test_df['breakpoint_probability'] = all_probabilities
        test_df['prediction'] = test_df['predicted_label'].map({0: 'Negative', 1: 'Positive'})
        
        # Save results
        print(f"\nSaving predictions to {output_path}...")
        test_df.to_csv(output_path, index=False)
        
        print("\nPrediction Summary:")
        print(f"  Total samples: {len(test_df)}")
        print(f"  Predicted positive: {(test_df['predicted_label'] == 1).sum()}")
        print(f"  Predicted negative: {(test_df['predicted_label'] == 0).sum()}")
        print(f"  Mean positive probability: {test_df['breakpoint_probability'].mean():.4f}")
        
        return test_df

print("✅ HyenaDNATrainingPipeline class defined")

---
# 🚀 EXECUTION STARTS HERE

Now we'll run the complete pipeline step by step.

---

## 📂 Step 6: Load and Prepare Data

Load the positive and negative CSV files, merge sequences, and prepare the dataset.

In [None]:
print("="*70)
print("STEP 1: DATA PREPARATION")
print("="*70)

# Initialize data preparator
data_prep = DataPreparator(POSITIVE_CSV, NEGATIVE_CSV)

# Load and prepare data
train_df = data_prep.load_and_prepare_data()

# Display first few rows
print("\nFirst 3 samples:")
display(train_df.head(3))

print("\n✅ Data preparation complete!")

## 🏗️ Step 7: Initialize Training Pipeline

Create the training pipeline with the specified model and configuration.

In [None]:
print("="*70)
print("STEP 2: INITIALIZE TRAINING PIPELINE")
print("="*70)

pipeline = HyenaDNATrainingPipeline(
    model_name=MODEL_NAME,
    output_dir=OUTPUT_DIR,
    max_length=MAX_LENGTH
)

print("\n✅ Pipeline initialized!")

## 📊 Step 8: Prepare Training and Validation Datasets

Split the data into training and validation sets (80/20 split by default).

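One caveat applies if `load_and_prepare_data(augment=True)` is used: the reverse-complement copies are created before this split, so a sequence and its reverse complement can end up on opposite sides of it and inflate validation scores. A possible alternative is to split first and augment only the training portion. The sketch below shows the idea with the standard library on made-up 4 bp sequences; `split_then_augment` is a hypothetical helper, and unlike the notebook's `train_test_split` call it does not stratify by label.

```python
import random

TRANS_TABLE = str.maketrans("ATCGN", "TAGCN")


def reverse_complement(sequence: str) -> str:
    return sequence.upper().translate(TRANS_TABLE)[::-1]


def split_then_augment(samples, val_fraction=0.2, seed=42):
    """samples: list of (sequence, label). Returns (train, val).

    The validation set is carved out first; reverse-complement copies are
    then added to the training set only.
    """
    shuffled = samples[:]
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    val, train = shuffled[:n_val], shuffled[n_val:]
    augmented = [(reverse_complement(seq), label) for seq, label in train]
    return train + augmented, val


samples = [("AAAC", 1), ("CCGA", 1), ("GGTA", 0), ("TTAC", 0), ("ACGG", 1)]
train, val = split_then_augment(samples)
print(len(train), len(val))  # 8 1
```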
In [None]:
print("="*70)
print("STEP 3: PREPARE TRAINING AND VALIDATION DATASETS")
print("="*70)

train_dataset, val_dataset = pipeline.prepare_datasets(train_df, val_split=VAL_SPLIT)

print(f"\nTraining samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print("\n✅ Datasets prepared!")

## 🏋️ Step 9: Train the Model

This is the main training step. It will:
- Load the pretrained HyenaDNA model
- Add the classification head
- Fine-tune on your data
- Evaluate on validation set periodically
- Save checkpoints and reload the one with the lowest validation loss at the end

**Expected time:** 30 minutes to several hours depending on:
- Dataset size
- GPU type (T4 vs A100)
- Number of epochs
- Model size

**Progress indicators:**
- Training loss every 50 steps
- Validation metrics every 500 steps
- Best model automatically saved

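Before launching a long run, it can help to work out how many optimizer steps the configuration implies, since `WARMUP_STEPS = 1000` and `eval_steps=500` in `TrainingArguments` only make sense relative to the total. The sketch below assumes a single device and no gradient accumulation; the training-set size of 8,000 is a made-up number, and `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(n_train, batch_size, num_epochs, warmup_steps, eval_steps):
    """Derive step counts from dataset size and hyperparameters."""
    steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch is kept
    total_steps = steps_per_epoch * num_epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_fraction": warmup_steps / total_steps,
        "num_evaluations": total_steps // eval_steps,
    }


plan = training_schedule(n_train=8_000, batch_size=8, num_epochs=4,
                         warmup_steps=1_000, eval_steps=500)
print(plan)
```

If `warmup_fraction` comes out near or above 1.0, the learning rate never leaves warmup, and if `num_evaluations` is 0, no validation metrics are logged during training. In either case, lower `WARMUP_STEPS` or `eval_steps` for small datasets.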
In [None]:
print("="*70)
print("STEP 4: TRAIN THE MODEL")
print("="*70)
print(f"\nTraining configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Warmup steps: {WARMUP_STEPS}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"\nThis may take a while... ☕")
print()

trainer = pipeline.train(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    weight_decay=WEIGHT_DECAY
)

print("\n✅ Training complete!")

## 📊 Step 10: View Training Results

Display the final training metrics and model information.

In [None]:
# Get training history
history = trainer.state.log_history

# Extract validation metrics
val_metrics = [entry for entry in history if 'eval_accuracy' in entry]

if val_metrics:
    # Last logged evaluation; this is not necessarily the best checkpoint
    final_metrics = val_metrics[-1]
    print("\n" + "="*70)
    print("FINAL VALIDATION METRICS")
    print("="*70)
    print(f"Accuracy:  {final_metrics.get('eval_accuracy', 0):.4f}")
    print(f"Precision: {final_metrics.get('eval_precision', 0):.4f}")
    print(f"Recall:    {final_metrics.get('eval_recall', 0):.4f}")
    print(f"F1 Score:  {final_metrics.get('eval_f1', 0):.4f}")
    print(f"AUC:       {final_metrics.get('eval_auc', 0):.4f}")
    print("="*70)

print(f"\n✅ Model saved to: {OUTPUT_DIR}/final_model/")

## 🔮 Step 11: Make Predictions on Test Data

Load the test CSV file and make predictions using the trained model.

**Output:** CSV file with:
- All original columns
- `predicted_label` (0 or 1)
- `breakpoint_probability` (0.0 to 1.0)
- `prediction` ("Negative" or "Positive")

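The three added columns all derive from the two logits the model emits per sequence. The sketch below reproduces that mapping for a single sequence with the standard library; the logit values are made up and `classify` is a hypothetical helper mirroring the softmax and argmax steps in `predict`.

```python
import math

LABEL_NAMES = {0: "Negative", 1: "Positive"}


def classify(logits):
    """logits: [negative_logit, positive_logit] for one sequence."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    positive_probability = exps[1] / sum(exps)
    predicted_label = 1 if logits[1] > logits[0] else 0
    return predicted_label, positive_probability, LABEL_NAMES[predicted_label]


print(classify([2.0, 0.0]))   # label 0, probability about 0.119
print(classify([-1.0, 1.0]))  # label 1, probability about 0.881
```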
In [None]:
print("="*70)
print("STEP 5: MAKE PREDICTIONS ON TEST DATA")
print("="*70)

predictions_df = pipeline.predict(
    test_csv_path=TEST_CSV,
    output_path="breakpoint_predictions.csv",
    batch_size=BATCH_SIZE
)

print("\n✅ Predictions complete!")

## 📋 Step 12: View Prediction Results

Display sample predictions and summary statistics.

In [None]:
print("\n" + "="*70)
print("PREDICTION SUMMARY")
print("="*70)

# Summary statistics
print(f"\nTotal test samples: {len(predictions_df)}")
print(f"Predicted Positive: {(predictions_df['predicted_label'] == 1).sum()} ({(predictions_df['predicted_label'] == 1).sum() / len(predictions_df) * 100:.1f}%)")
print(f"Predicted Negative: {(predictions_df['predicted_label'] == 0).sum()} ({(predictions_df['predicted_label'] == 0).sum() / len(predictions_df) * 100:.1f}%)")
print(f"\nAverage breakpoint probability: {predictions_df['breakpoint_probability'].mean():.4f}")
print(f"Min probability: {predictions_df['breakpoint_probability'].min():.4f}")
print(f"Max probability: {predictions_df['breakpoint_probability'].max():.4f}")

# Show sample predictions
print("\n" + "="*70)
print("SAMPLE PREDICTIONS (first 5 rows)")
print("="*70)
display(predictions_df[['predicted_label', 'breakpoint_probability', 'prediction']].head())

# Probability distribution
print("\n" + "="*70)
print("PROBABILITY DISTRIBUTION")
print("="*70)
print(predictions_df['breakpoint_probability'].describe())

## 📊 (Optional) Step 13: Visualize Results

Create visualizations of the prediction results.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)

# Create subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. Prediction distribution
predictions_df['prediction'].value_counts().plot(kind='bar', ax=axes[0], color=['#FF6B6B', '#4ECDC4'])
axes[0].set_title('Prediction Distribution', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Prediction')
axes[0].set_ylabel('Count')
axes[0].tick_params(axis='x', rotation=0)

# 2. Probability distribution
axes[1].hist(predictions_df['breakpoint_probability'], bins=30, edgecolor='black', alpha=0.7, color='#95E1D3')
axes[1].set_title('Breakpoint Probability Distribution', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Probability')
axes[1].set_ylabel('Frequency')
axes[1].axvline(x=0.5, color='red', linestyle='--', linewidth=2, label='Threshold (0.5)')
axes[1].legend()

# 3. Box plot by prediction
predictions_df.boxplot(column='breakpoint_probability', by='prediction', ax=axes[2])
axes[2].set_title('Probability by Prediction', fontsize=14, fontweight='bold')
axes[2].set_xlabel('Prediction')
axes[2].set_ylabel('Breakpoint Probability')
plt.suptitle('')  # Remove default title

plt.tight_layout()
plt.savefig('prediction_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✅ Visualization saved as 'prediction_analysis.png'")

## 💾 Step 14: Save/Download Results

Download the results if you're using Google Colab.

In [None]:
# For Google Colab: Uncomment these lines to download files
# from google.colab import files
# 
# # Download predictions
# files.download('breakpoint_predictions.csv')
# 
# # Download visualization (if created)
# files.download('prediction_analysis.png')
# 
# # Optionally download the model
# !zip -r hyenadna_model.zip hyenadna_breakpoint_model/final_model/
# files.download('hyenadna_model.zip')

print("\nFiles ready for download:")
print("  - breakpoint_predictions.csv")
print("  - prediction_analysis.png (if visualization was run)")
print(f"  - {OUTPUT_DIR}/final_model/ (trained model)")

---
# ✅ PIPELINE COMPLETE!

## What You've Accomplished:

1. ✅ Loaded and prepared FusionAI format data
2. ✅ Merged 5' and 3' sequences into 20kb inputs
3. ✅ Fine-tuned HyenaDNA for binary classification
4. ✅ Evaluated model performance on validation data
5. ✅ Made predictions on test data
6. ✅ Saved results and visualizations

## Output Files:

- **`breakpoint_predictions.csv`** - Test predictions with probabilities
- **`hyenadna_breakpoint_model/final_model/`** - Trained model weights
- **`prediction_analysis.png`** - Visualization (if created)

## Next Steps:

1. **Analyze results** - Check prediction quality and metrics
2. **Adjust hyperparameters** - Try different learning rates, batch sizes, or epochs
3. **Use the model** - Apply to new data using the saved model
4. **Experiment** - Try different HyenaDNA model sizes for better performance

---

## 🔄 To Use the Trained Model Later:

```python
# Load the pipeline
pipeline = HyenaDNATrainingPipeline(
    model_name=MODEL_NAME,
    output_dir=OUTPUT_DIR,
    max_length=MAX_LENGTH
)

# Make predictions on new data
new_predictions = pipeline.predict(
    test_csv_path="new_test_data.csv",
    model_path="./hyenadna_breakpoint_model/final_model",
    output_path="new_predictions.csv"
)
```

---

**Questions or issues?** Check the comments in each cell or refer to the README documentation.

**Happy predicting! 🧬🚀**