### I ran this code on a High-RAM (80GB) A100 GPU runtime on Google Colab

In [None]:
# Uninstall the listed Hugging Face-related packages first so the pinned
# versions below are installed fresh.
%pip -q uninstall -y transformers tokenizers huggingface-hub peft trl optimum sentence-transformers || true

# Pinned transformers release plus accelerate (>=0.29,<1) and safetensors.
%pip -q install "transformers==4.45.2" "accelerate>=0.29,<1" "safetensors"

%pip -q install -U "huggingface_hub>=0.23"

# PyTorch 2.5.1 wheels from the CUDA 12.4 index, then flash-attn without build
# isolation (presumably so its build sees the torch installed above - TODO confirm).
%pip -q install "torch==2.5.1" "torchvision==0.20.1" "torchaudio==2.5.1" --index-url https://download.pytorch.org/whl/cu124
%pip -q install "flash-attn==2.7.4.post1" --no-build-isolation

# Package that provides the `evo` imports used in the next cell.
%pip -q install -U evo-model

# Report any dependency conflicts left after the installs above.
!python -m pip check

In [2]:
import torch, gc, difflib
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from evo.scoring import prepare_batch
from evo import Evo
import ast  # For parsing label strings
from datetime import datetime
from tqdm import tqdm

from google.colab import drive
import sys, os
import importlib
import pandas as pd

def free_gpu(*objs):
    """Run the Python garbage collector and release cached CUDA memory.

    Args:
        *objs: Accepted for call-site convenience but not used; callers must
            drop their own references before calling for memory to be freed.
    """
    gc.collect()

    cuda = torch.cuda
    if cuda.is_available():
        # Hand cached allocator blocks back to the driver, then reclaim
        # memory held for CUDA IPC.
        cuda.empty_cache()
        cuda.ipc_collect()

    print("Freed Python GC and CUDA caches.")

Mount Data

In [None]:
# Mount Google Drive so the dataset CSVs and checkpoints below are reachable.
drive.mount('/content/drive/')

# All project files live under this Drive folder.
BASE = "/content/drive/MyDrive/EvoTagger-files"
CSV_PATH_TRAIN = f"{BASE}/ecoli_train_12k.csv"
CSV_PATH_VAL = f"{BASE}/ecoli_val_750.csv"
CSV_PATH_TEST = f"{BASE}/ecoli_test_5k.csv"
# Directory the training cell writes checkpoints and training_params.txt into.
CHECKPOINT_PATH = f"{BASE}/checkpoints"

# Make modules stored in the project folder importable (added only once).
if BASE not in sys.path:
    sys.path.append(BASE)

Data Preparation

In [None]:
# Upper bounds on how many rows of each split are used.
NUM_OF_TRAIN_SEQ = 12000
NUM_OF_VAL_SEQ = 750
NUM_OF_TEST_SEQ = 5000

print(f"Loading data from {CSV_PATH_TRAIN} and {CSV_PATH_VAL}...")
# `nrows` stops parsing after the requested number of rows and returns an
# independent DataFrame (not a slice of a larger one), so columns can be added
# to it later without a SettingWithCopyWarning.
train_df = pd.read_csv(CSV_PATH_TRAIN, nrows=NUM_OF_TRAIN_SEQ)
val_df = pd.read_csv(CSV_PATH_VAL, nrows=NUM_OF_VAL_SEQ)
print(f"✓ Loaded {len(train_df)}, {len(val_df)} sequences")


# ============================================================================
# Dataset and Collate Function
# ============================================================================

class DNADataset(Dataset):
    """
    Map-style dataset pairing each noisy DNA string with its stored labels.

    Labels are taken as given (already computed); nothing is regenerated here.
    The clean sequences are kept on the instance but are not part of an item.
    """

    def __init__(self, noisy_list, clean_list, fine_labels_list, coarse_labels_list):
        """
        Args:
            noisy_list: Noisy DNA sequences.
            clean_list: Clean DNA sequences (same order as ``noisy_list``).
            fine_labels_list: Per-sequence fine-grained label lists.
            coarse_labels_list: Per-sequence coarse-grained label lists.
        """
        self.noisy, self.clean = noisy_list, clean_list
        self.fine_labels, self.coarse_labels = fine_labels_list, coarse_labels_list

    def __len__(self):
        # One item per noisy sequence.
        return len(self.noisy)

    def __getitem__(self, i):
        """
        Returns:
            A ``(noisy_seq, fine_labels, coarse_labels)`` tuple: the raw
            sequence string followed by two ``torch.long`` label tensors.
        """
        sequence = self.noisy[i]
        fine = torch.tensor(self.fine_labels[i], dtype=torch.long)
        coarse = torch.tensor(self.coarse_labels[i], dtype=torch.long)
        return sequence, fine, coarse


def create_collate_fn(tokenizer, device):
    """
    Build a DataLoader ``collate_fn`` bound to ``tokenizer`` and ``device``.

    The tokenizer is used through ``tokenizer.tokenize(seq)`` (not ``__call__``)
    and must expose ``pad_id``.

    Args:
        tokenizer: Object with ``tokenize(str) -> sequence of ids`` and ``pad_id``.
        device: Device the returned batch tensors are placed on.

    Returns:
        ``collate_fn(batch)`` taking a list of
        ``(noisy_str, fine_labels, coarse_labels)`` items and returning
        ``(input_ids, fine_pad, coarse_pad)``, each a ``(B, max_len)``
        ``torch.long`` tensor on ``device``. Inputs are padded with
        ``tokenizer.pad_id``; labels are padded with ``-100``.

    Raises:
        ValueError: From ``collate_fn`` when a label tensor's length differs
            from the number of tokens of its sequence.
    """
    pad_val = -100  # ignore_index for CE

    def collate_fn(batch):
        # batch: list of (noisy_str, fine_labels, coarse_labels)
        tokenized = [tokenizer.tokenize(item[0]) for item in batch]
        seq_lengths = [len(tokens) for tokens in tokenized]
        max_len = max(seq_lengths)
        n_rows = len(batch)

        # Assemble on the CPU and move each padded batch tensor to `device`
        # once, instead of issuing a small host-to-device copy per row.
        input_ids = torch.full((n_rows, max_len), tokenizer.pad_id, dtype=torch.long)
        fine_pad = torch.full((n_rows, max_len), pad_val, dtype=torch.long)
        coarse_pad = torch.full((n_rows, max_len), pad_val, dtype=torch.long)

        for row, (tokens, length, item) in enumerate(zip(tokenized, seq_lengths, batch)):
            fine, coarse = item[1], item[2]
            if len(fine) != length or len(coarse) != length:
                # Without this check a length-1 label tensor would be silently
                # broadcast across the whole row.
                raise ValueError(
                    f"Label length mismatch in batch item {row}: "
                    f"{length} tokens, {len(fine)} fine labels, {len(coarse)} coarse labels"
                )
            input_ids[row, :length] = torch.tensor(tokens, dtype=torch.long)
            fine_pad[row, :length] = fine.cpu()
            coarse_pad[row, :length] = coarse.cpu()

        return input_ids.to(device), fine_pad.to(device), coarse_pad.to(device)

    return collate_fn

Model

In [7]:
# Fine-grained edit labels: KEEP is 0, then one REPLACE_<base> id per base (1-4).
FINE = {
    "KEEP": 0,
    **{f"REPLACE_{base}": label_id for label_id, base in enumerate("ACGT", start=1)},
}
# Fine label id -> coarse label id (KEEP=0, any REPLACE=1).
COARSE_MAP = {label_id: int(label_id != 0) for label_id in FINE.values()}

class CustomEmbedding(nn.Module):
    """Identity stand-in for an unembedding layer: ``unembed`` is a no-op."""

    def unembed(self, u):
        """Return ``u`` exactly as received."""
        return u


class EvoFeaturizer(nn.Module):
    """
    Adapter that turns an Evo backbone into a per-position feature extractor.

    The backbone's ``unembed`` attribute is replaced with a passthrough so its
    forward pass yields hidden states rather than logits. The backbone is run
    without gradient tracking.
    """

    def __init__(self, evo_backbone):
        super().__init__()
        self.evo = evo_backbone

        # Swap the output projection for an identity so hidden states come out.
        # NOTE: this mutates the backbone object passed in.
        self.evo.unembed = CustomEmbedding()

    def forward(self, input_ids):
        """Return float32 hidden states for ``input_ids`` (documented as (B, L, D=4096))."""
        with torch.no_grad():
            # The backbone returns a pair; only the first element is used.
            hidden_states, _ = self.evo(input_ids)

            # Cast to float32 (the backbone output is documented as bfloat16)
            # so the downstream LSTM receives a compatible dtype.
            features = hidden_states.to(torch.float32)

        return features

class DNATagger(nn.Module):
    """
    Per-position tagger for DNA error correction with a two-level label space.

    For every position it produces:
    - fine logits over ``num_fine`` classes (KEEP, REPLACE_A, REPLACE_C,
      REPLACE_G, REPLACE_T by default), and
    - coarse logits over 2 categories (KEEP, REPLACE), derived from the fine
      logits without any extra parameters.
    """

    def __init__(self, D, num_fine=5):
        super().__init__()

        # LayerNorm -> BiLSTM -> LayerNorm -> Linear.
        # The BiLSTM uses hidden size D//2 per direction, so its concatenated
        # output has width 2*(D//2).
        self.input_norm = nn.LayerNorm(D)
        self.ctx = nn.LSTM(D, D // 2, num_layers=1, bidirectional=True, batch_first=True)
        self.ctx_norm = nn.LayerNorm(D)
        self.head_fine = nn.Linear(D, num_fine)

        # Coarse category id -> fine label ids belonging to it.
        self.groups = {
            0: [0],           # KEEP
            1: [1, 2, 3, 4],  # REPLACE_{A,C,G,T}
        }

    def forward(self, H):
        """
        Args:
            H: Input features, (B, L, D).

        Returns:
            ``(fine_logits, coarse_logits)`` with shapes (B, L, num_fine) and (B, L, 2).
        """
        H = self.input_norm(H)
        context, _ = self.ctx(H)
        context = self.ctx_norm(context)
        fine_logits = self.head_fine(context)

        # A coarse logit is the log-sum-exp of the fine logits in its group.
        per_group = []
        for gid in (0, 1):
            members = torch.tensor(self.groups[gid], device=H.device)
            selected = fine_logits.index_select(-1, members)      # (B, L, group_size)
            per_group.append(torch.logsumexp(selected, dim=-1, keepdim=True))  # (B, L, 1)

        coarse_logits = torch.cat(per_group, dim=-1)              # (B, L, 2)
        return fine_logits, coarse_logits

Training

In [None]:
# ============================================================================
# Training Script
# ============================================================================
# Run garbage collection and clear cached CUDA memory before the training run below.
free_gpu()


def _tagging_losses(fine_logits, coarse_logits, fine_y, coarse_y, coarse_class_weights):
    """Return ``(mutation_loss, coarse_loss)`` for one batch as 0-dim tensors.

    ``mutation_loss`` is the cross-entropy over positions whose fine label is
    neither KEEP (0) nor padding (-100). ``coarse_loss`` is the class-weighted
    cross-entropy over all non-padding positions.
    """
    mutation_mask = (fine_y != 0) & (fine_y != -100)
    if mutation_mask.any():
        mutation_loss = F.cross_entropy(fine_logits[mutation_mask], fine_y[mutation_mask])
    else:
        # Must stay a tensor: callers call `.item()` on it, and a plain int 0
        # would raise AttributeError on batches without mutations.
        mutation_loss = fine_logits.new_zeros(())

    coarse_loss = F.cross_entropy(
        coarse_logits.reshape(-1, coarse_logits.size(-1)),
        coarse_y.reshape(-1),
        weight=coarse_class_weights,  # Sqrt dampening
        ignore_index=-100,
    )
    return mutation_loss, coarse_loss


def main():
    """Train ``DNATagger`` on frozen Evo features and checkpoint every epoch.

    Reads the module-level ``train_df`` / ``val_df`` DataFrames (adding parsed
    label columns to them) and writes ``training_params.txt``,
    ``model_epoch.pt`` (overwritten each epoch) and ``model_best.pt`` (best
    validation mutation accuracy) under ``CHECKPOINT_PATH``. The final tagger
    weights are also saved to ``simple_model.pt`` in the working directory.
    """
    from collections import Counter

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}\n")

    # Load Evo model
    print("Loading Evo model...")
    evo_model = Evo('evo-1.5-8k-base')

    tokenizer = evo_model.tokenizer
    model_backbone = evo_model.model.to(device)
    model_backbone.eval()

    # Create models
    # NOTE: D=4096 must match the hidden size of the Evo backbone loaded above.
    print("Creating models...")
    featurizer = EvoFeaturizer(model_backbone).to(device)
    tagger = DNATagger(D=4096, num_fine=5).to(device)  # 5 classes: KEEP + 4 REPLACE types
    featurizer.eval()  # Evo stays frozen; only the tagger is optimized

    # Show mutation statistics
    train_mutations = train_df['num_mutations'].sum()
    val_mutations = val_df['num_mutations'].sum()
    train_mut_mean = train_df['num_mutations'].mean()
    val_mut_mean = val_df['num_mutations'].mean()
    print(f"\nMutation statistics:")
    print(f"  Train mutations: {train_mutations} ({train_mut_mean:.1f} per seq)")
    print(f"  Val mutations: {val_mutations} ({val_mut_mean:.1f} per seq)")
    if val_mut_mean > 0:
        # Report an actual train/val ratio (the mean alone is not a multiplier).
        print(f"  💡 Training data has {train_mut_mean / val_mut_mean:.1f}x the mutations per sequence of the validation data")

    # Parse pre-computed labels from CSV (they're stored as strings)
    print("\nParsing labels...")
    train_df['fine_labels_parsed'] = train_df['fine_labels'].apply(ast.literal_eval)
    train_df['coarse_labels_parsed'] = train_df['coarse_labels'].apply(ast.literal_eval)
    val_df['fine_labels_parsed'] = val_df['fine_labels'].apply(ast.literal_eval)
    val_df['coarse_labels_parsed'] = val_df['coarse_labels'].apply(ast.literal_eval)
    print("✓ Labels parsed")

    # Calculate class weights based on TRAINING data only.
    # Count per sequence instead of flattening every label into one big list.
    print("\nAnalyzing training class distribution...")
    coarse_label_counts = Counter()
    for seq_coarse_labels in train_df['coarse_labels_parsed']:
        coarse_label_counts.update(seq_coarse_labels)
    total_labels = sum(coarse_label_counts.values())

    # Inverse frequency weights for COARSE-grained labels (2 classes)
    coarse_class_weights = torch.zeros(2)
    for label in range(2):
        count = coarse_label_counts.get(label, 1)  # avoid dividing by zero for an absent class
        coarse_class_weights[label] = total_labels / (2 * count)

    # Apply sqrt dampening for stable gradients with extreme imbalance
    coarse_class_weights = torch.sqrt(coarse_class_weights).to(device)

    # No split needed - separate train/val datasets were loaded
    print(f"✓ Train={len(train_df)}, Val={len(val_df)}\n")

    train_noisy = train_df['noisy'].tolist()
    train_clean = train_df['clean'].tolist()
    train_fine = train_df['fine_labels_parsed'].tolist()
    train_coarse = train_df['coarse_labels_parsed'].tolist()

    val_noisy = val_df['noisy'].tolist()
    val_clean = val_df['clean'].tolist()
    val_fine = val_df['fine_labels_parsed'].tolist()
    val_coarse = val_df['coarse_labels_parsed'].tolist()

    print(f"✓ Split: Train={len(train_noisy)}, Val={len(val_noisy)}\n")

    # Create datasets (using pre-computed labels!)
    train_dataset = DNADataset(train_noisy, train_clean, train_fine, train_coarse)
    val_dataset = DNADataset(val_noisy, val_clean, val_fine, val_coarse)

    batch_size = 16
    num_epochs = 15
    collate_fn = create_collate_fn(tokenizer, device)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, collate_fn=collate_fn)

    optimizer = torch.optim.AdamW(tagger.parameters(), lr=1e-4, weight_decay=0.01)

    print('Save to Drive')
    # The checkpoint directory may not exist yet on a fresh Drive folder.
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)

    # Save training parameters to txt file
    with open(f'{CHECKPOINT_PATH}/training_params.txt', 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("TRAINING PARAMETERS\n")
        f.write("="*80 + "\n\n")
        f.write(f"Model: DNATagger with EvoFeaturizer\n")
        f.write(f"Evo Model: evo-1.5-8k-base\n")
        f.write(f"Device: {device}\n\n")
        f.write(f"Dataset:\n")
        f.write(f"  Train sequences: {len(train_df)}\n")
        f.write(f"  Val sequences: {len(val_df)}\n")
        f.write(f"  Train mutations: {train_mutations}\n")
        f.write(f"  Val mutations: {val_mutations}\n")
        f.write(f"  Avg mutations/seq (train): {train_mut_mean:.1f}\n\n")
        f.write(f"Training:\n")
        f.write(f"  Epochs: {num_epochs}\n")
        f.write(f"  Batch size: {batch_size}\n")
        f.write(f"  Learning rate: {optimizer.param_groups[0]['lr']:.6f}\n")
        f.write(f"  Optimizer: AdamW (weight_decay=0.01)\n")
        f.write(f"Class Weights:\n")
        f.write(f"  Coarse (2 classes): {coarse_class_weights.cpu().tolist()}\n")
        f.write("="*80 + "\n")

    print("="*80)
    print("STARTING TRAINING")
    print("="*80)
    print(f"Epochs: {num_epochs}")
    print(f"Batch size: {batch_size}")
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Initial LR: {optimizer.param_groups[0]['lr']:.6f}")
    print(f"✓ Saving to: {CHECKPOINT_PATH}/")
    print("="*80 + "\n")

    best_val_acc = 0.0

    for epoch in range(num_epochs):
        print(f"\n{'='*80}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"{'='*80}")

        # ---- Train ----
        tagger.train()
        train_loss = 0.0
        train_mutation_loss = 0.0
        train_coarse_loss = 0.0

        train_pbar = tqdm(train_loader, desc='Training', leave=False)
        for input_ids, fine_y, coarse_y in train_pbar:
            # Features from frozen Evo (no gradients needed)
            with torch.no_grad():
                H = featurizer(input_ids)

            fine_logits, coarse_logits = tagger(H)

            mutation_loss, coarse_loss = _tagging_losses(
                fine_logits, coarse_logits, fine_y, coarse_y, coarse_class_weights
            )

            # Combined loss
            loss = 0.5 * mutation_loss + 0.5 * coarse_loss

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track losses
            train_loss += loss.item()
            train_mutation_loss += mutation_loss.item()
            train_coarse_loss += coarse_loss.item()

            train_pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        train_loss /= len(train_loader)
        train_mutation_loss /= len(train_loader)
        train_coarse_loss /= len(train_loader)

        # ---- Validate ----
        tagger.eval()
        val_loss = 0.0
        val_mutation_loss = 0.0
        val_coarse_loss = 0.0
        val_correct = 0
        val_total = 0

        # Mutation-specific metrics (non-KEEP labels only)
        mutation_correct = 0
        mutation_total = 0
        per_class_correct = [0] * 5
        per_class_total = [0] * 5

        val_pbar = tqdm(val_loader, desc='Validation', leave=False)
        with torch.no_grad():
            for input_ids, fine_y, coarse_y in val_pbar:
                H = featurizer(input_ids)
                fine_logits, coarse_logits = tagger(H)

                mutation_loss, coarse_loss = _tagging_losses(
                    fine_logits, coarse_logits, fine_y, coarse_y, coarse_class_weights
                )

                # Combined loss
                loss = 0.5 * mutation_loss + 0.5 * coarse_loss
                val_loss += loss.item()
                val_mutation_loss += mutation_loss.item()
                val_coarse_loss += coarse_loss.item()

                # Overall accuracy over non-padding positions
                fine_pred = fine_logits.argmax(dim=-1)
                valid = (fine_y != -100)
                hits = (fine_pred == fine_y) & valid
                val_correct += hits.sum().item()
                val_total += valid.sum().item()

                # Mutation-specific accuracy: KEEP = class 0, mutations = classes 1-4
                is_mutation = valid & (fine_y != 0)
                mutation_correct += (hits & is_mutation).sum().item()
                mutation_total += is_mutation.sum().item()

                # Per-class accuracy (padding, -100, never equals a class id)
                for class_idx in range(5):
                    in_class = (fine_y == class_idx)
                    per_class_correct[class_idx] += (hits & in_class).sum().item()
                    per_class_total[class_idx] += in_class.sum().item()

                val_pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        val_loss /= len(val_loader)
        val_mutation_loss /= len(val_loader)
        val_coarse_loss /= len(val_loader)
        val_acc = val_correct / val_total if val_total > 0 else 0

        mutation_acc = mutation_correct / mutation_total if mutation_total > 0 else 0

        per_class_acc = [
            (hit / total) if total > 0 else 0.0
            for hit, total in zip(per_class_correct, per_class_total)
        ]

        # Track best model (based on mutation accuracy, not overall!)
        is_best = mutation_acc > best_val_acc
        if is_best:
            best_val_acc = mutation_acc

        # Print detailed epoch summary
        print(f"\nTRAIN:")
        print(f"  Total Loss:  {train_loss:.4f}")
        print(f"  Mutation Loss:   {train_mutation_loss:.4f}")
        print(f"  Coarse Loss: {train_coarse_loss:.4f}")
        print(f"\nVALIDATION:")
        print(f"  Total Loss:  {val_loss:.4f}")
        print(f"  Mutation Loss:   {val_mutation_loss:.4f}")
        print(f"  Coarse Loss: {val_coarse_loss:.4f}")
        print(f"  Overall Accuracy:  {val_acc:.4f} ({val_correct}/{val_total})")
        print(f"  ⭐ MUTATION Accuracy: {mutation_acc:.4f} ({mutation_correct}/{mutation_total}) {'🌟 NEW BEST!' if is_best else ''}")
        print(f"  Best Mutation Acc:  {best_val_acc:.4f}")

        # Show per-class breakdown (classes with >0 samples)
        print(f"\n  Per-Class Accuracy:")
        class_names = {
            0: 'KEEP',
            1: 'REPLACE_A', 2: 'REPLACE_C', 3: 'REPLACE_G', 4: 'REPLACE_T'
        }
        for i in range(5):
            if per_class_total[i] > 0:
                name = class_names.get(i, f'Class_{i}')
                acc = per_class_acc[i]
                count = per_class_total[i]
                print(f"    {name:<12}: {acc:.3f} ({per_class_correct[i]}/{count})")

        # Save model checkpoint (same file name every epoch, so it is overwritten)
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': tagger.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'mutation_acc': mutation_acc,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc,
            'per_class_acc': per_class_acc,
        }
        checkpoint_path = f'{CHECKPOINT_PATH}/model_epoch.pt'
        torch.save(checkpoint, checkpoint_path)

        # Also save best model separately
        if is_best:
            torch.save(checkpoint, f'{CHECKPOINT_PATH}/model_best.pt')
            print(f"\n  💾 Saved best model (mutation_acc={mutation_acc:.4f})")

    # Training complete - show summary
    print("\n" + "="*80)
    print("TRAINING COMPLETE!")
    print("="*80)
    print(f"Best Mutation Accuracy: {best_val_acc:.4f}")
    print(f"Final Overall Accuracy: {val_acc:.4f}")
    print(f"Final Mutation Accuracy: {mutation_acc:.4f}")
    print(f"Final Train Loss: {train_loss:.4f}")
    print(f"Final Val Loss: {val_loss:.4f}")
    print("="*80 + "\n")

    # Save model (bare state_dict, in the current working directory)
    print("Saving model...")
    torch.save(tagger.state_dict(), 'simple_model.pt')
    print("✓ Saved to simple_model.pt (local)")

    print("\n" + "="*80)
    print("Next step: Run simple_inference.py to test the model!")
    print("="*80)

# Run the training pipeline when this cell is executed.
main()

Test

In [None]:
def load_model(checkpoint_path, device):
    """Load the Evo featurizer and a trained tagger from ``checkpoint_path``.

    Accepts either a full training checkpoint (a dict holding
    ``'model_state_dict'`` plus metadata, as written each epoch by the training
    loop) or a bare tagger ``state_dict`` (as written to ``simple_model.pt``).

    Args:
        checkpoint_path: Path to the saved checkpoint file.
        device: Device to place the models on and to map the checkpoint to.

    Returns:
        ``(featurizer, tagger, tokenizer)`` with both modules in eval mode.
    """
    print(f"Loading model from {checkpoint_path}...")

    # Load Evo
    print("Loading Evo model...")
    evo_model = Evo('evo-1.5-8k-base')
    tokenizer = evo_model.tokenizer
    model_backbone = evo_model.model.to(device)
    model_backbone.eval()

    # Create models
    featurizer = EvoFeaturizer(model_backbone).to(device)
    tagger = DNATagger(D=4096, num_fine=5).to(device)
    featurizer.eval()

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    has_metadata = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
    tagger.load_state_dict(checkpoint['model_state_dict'] if has_metadata else checkpoint)
    tagger.eval()

    if has_metadata:
        # These metrics were measured on the validation set during training.
        print(f"✓ Model loaded from epoch {checkpoint['epoch']}")
        print(f"  Validation mutation acc: {checkpoint['mutation_acc']:.4f}")
        print(f"  Validation overall acc: {checkpoint['val_acc']:.4f}\n")
    else:
        print("✓ Model loaded from a bare state_dict (no training metadata)\n")

    return featurizer, tagger, tokenizer

def predict_sequence(sequence, featurizer, tagger, tokenizer, device):
    """
    Run the tagger on one sequence and return its per-position predictions.

    Args:
        sequence: DNA sequence string.
        featurizer: Callable mapping input ids to features.
        tagger: Callable mapping features to ``(fine_logits, coarse_logits)``.
        tokenizer: Object exposing ``tokenize(sequence)``.
        device: Device on which the input id tensor is created.

    Returns:
        fine_predictions: Python list with the arg-max fine label per position.
        fine_probs: CPU tensor of softmax probabilities, one row per position.
    """
    # Batch of one: wrap the token ids in an outer list.
    token_ids = tokenizer.tokenize(sequence)
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        fine_logits, _ = tagger(featurizer(input_ids))

    # Drop the batch dimension before scoring.
    position_logits = fine_logits[0]
    probabilities = F.softmax(position_logits, dim=-1)
    labels = position_logits.argmax(dim=-1)

    return labels.cpu().tolist(), probabilities.cpu()

def test_from_csv(csv_path, model_path, device, num_samples=10):
    """
    Test model on sequences from CSV file and print per-sequence and summary accuracy.

    Rows are drawn at random (unseeded) without replacement; if the CSV has
    fewer than ``num_samples`` rows, every row is used.

    Args:
        csv_path: Path to CSV with test data (columns used: ``noisy``,
            ``clean``, ``fine_labels``, ``num_mutations``)
        model_path: Path to model checkpoint
        device: CUDA or CPU
        num_samples: Number of sequences to test
    """
    import random

    print("="*80)
    print("DNA ERROR CORRECTION - TEST SCRIPT")
    print("="*80 + "\n")

    # Load model
    featurizer, tagger, tokenizer = load_model(model_path, device)

    # Load test data
    print(f"Loading test data from {csv_path}...")
    df = pd.read_csv(csv_path)
    print(f"✓ Loaded {len(df)} sequences\n")

    # Test on random samples
    test_indices = random.sample(range(len(df)), min(num_samples, len(df)))

    label_names = {
        0: 'KEEP',
        1: 'REPLACE_A',
        2: 'REPLACE_C',
        3: 'REPLACE_G',
        4: 'REPLACE_T'
    }

    total_correct = 0
    total_positions = 0
    mutation_correct = 0
    mutation_total = 0
    per_class_correct = [0] * 5  # Per-class accuracy tracking
    per_class_total = [0] * 5

    print("="*80)
    # Report how many sequences are actually evaluated, not how many were requested.
    print(f"TESTING {len(test_indices)} SEQUENCES")
    print("="*80 + "\n")

    for idx in test_indices:
        row = df.iloc[idx]
        noisy_seq = row['noisy']
        clean_seq = row['clean']
        true_fine = ast.literal_eval(row['fine_labels'])

        print(f"Sequence {idx}:")
        print(f"  Length: {len(noisy_seq)} bases")
        print(f"  Mutations: {row['num_mutations']}")

        # Get predictions
        pred_fine, pred_probs = predict_sequence(noisy_seq, featurizer, tagger, tokenizer, device)

        # Find mutation positions
        mutation_positions = [i for i, label in enumerate(true_fine) if label != 0]

        if mutation_positions:
            print(f"\n  Mutation Details:")
            for pos in mutation_positions[:5]:  # Show first 5
                if pos < len(pred_fine):
                    true_label = true_fine[pos]
                    pred_label = pred_fine[pos]
                    confidence = pred_probs[pos, pred_label].item()

                    mark = "✓" if pred_label == true_label else "✗"
                    print(f"    Position {pos}: {noisy_seq[pos]} → {clean_seq[pos]}")
                    print(f"      True:  {label_names[true_label]}")
                    print(f"      Pred:  {label_names[pred_label]} ({confidence:.2%} conf) {mark}")

            if len(mutation_positions) > 5:
                print(f"    ... and {len(mutation_positions) - 5} more mutations")

        # Overall, per-class and mutation-specific accuracy in one pass.
        # zip() stops at the shorter of the two label sequences.
        seq_len = min(len(true_fine), len(pred_fine))
        seq_correct = 0
        for true_label, pred_label in zip(true_fine, pred_fine):
            hit = pred_label == true_label
            seq_correct += hit

            per_class_total[true_label] += 1
            per_class_correct[true_label] += hit

            if true_label != 0:  # Non-KEEP (mutation)
                mutation_total += 1
                mutation_correct += hit

        total_correct += seq_correct
        total_positions += seq_len

        # Guard against an empty sequence (would otherwise divide by zero).
        seq_pct = 100 * seq_correct / seq_len if seq_len else 0.0
        print(f"\n  Accuracy: {seq_correct}/{seq_len} = {seq_pct:.1f}%")
        print("  " + "-"*76 + "\n")

    # Overall statistics
    print("="*80)
    print("TEST SUMMARY")
    print("="*80)
    if total_positions > 0:
        print(f"  Overall Accuracy:  {100*total_correct/total_positions:.2f}% ({total_correct}/{total_positions})")
    else:
        print(f"  Overall Accuracy:  N/A (no positions evaluated)")
    if mutation_total > 0:
        print(f"  Mutation Accuracy: {100*mutation_correct/mutation_total:.2f}% ({mutation_correct}/{mutation_total})")
    else:
        print(f"  Mutation Accuracy: N/A (no mutations in test set)")

    # Per-class accuracy breakdown
    print(f"\n  Per-Class Accuracy:")
    for class_idx in range(5):
        if per_class_total[class_idx] > 0:
            name = label_names[class_idx]
            acc = per_class_correct[class_idx] / per_class_total[class_idx]
            count = per_class_total[class_idx]
            print(f"    {name:<12}: {acc:.3f} ({per_class_correct[class_idx]}/{count})")

    print("="*80 + "\n")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

# The training loop writes model_epoch.pt under CHECKPOINT_PATH (not directly
# under BASE), so load it from there.
model_path_to_test = f'{CHECKPOINT_PATH}/model_epoch.pt'

test_from_csv(CSV_PATH_TEST, model_path_to_test, device, NUM_OF_TEST_SEQ)