### Config file 

In [None]:
import os

# Data paths
# NOTE: keep forward slashes in these literals. Non-raw Windows-style strings
# such as "train\task3..." turn "\t" into a TAB and "\v" into a vertical tab,
# which yields invalid file names (OSError: [Errno 22] Invalid argument).
DATA_DIR = "data"
TRAIN_PARQUET = os.path.join(DATA_DIR, "train/task3_variant_prediction/data/train.parquet")
VAL_PARQUET = os.path.join(DATA_DIR, "train/task3_variant_prediction/data/val.parquet")
TEST_PARQUET = os.path.join(DATA_DIR, "train/task3_variant_prediction/data/test.parquet")

# Embeddings directory (one .pt file per split)
EMB_DIR = "benchmark/task3_variant_prediction/HyenaDNA/embeddings_hyenadna"
TRAIN_EMB = os.path.join(EMB_DIR, "train_embeddings.pt")
VAL_EMB = os.path.join(EMB_DIR, "val_embeddings.pt")
TEST_EMB = os.path.join(EMB_DIR, "test_embeddings.pt")

# HyenaDNA model
HYENADNA_MODEL = "LongSafari/hyenadna-large-1m-seqlen-hf"  # or the medium-450k-seqlen checkpoint
DNA_SEQ_LEN = 601  # max token length used when tokenizing DNA sequences

# Embedding batch sizes
DNA_BATCH = 128

# Model hyperparameters
PROJ_DIM = 512
FUSION_HIDDEN = [512, 256]
DROPOUT = 0.3

# Training hyperparameters
LR = 1e-4
EPOCHS = 30
PATIENCE = 5
BATCH_SIZE = 512
WEIGHT_DECAY = 1e-4
SEED = 42


### Dataset 

In [7]:
import torch
from torch.utils.data import Dataset


class HyenaDNADataset(Dataset):
    """Dataset over precomputed HyenaDNA embeddings stored in a ``.pt`` file.

    The file must contain a mapping with the keys ``"dna_ref"``, ``"dna_alt"``
    and ``"label"``; each value is indexed per sample.
    """

    def __init__(self, pt_file):
        """Load the saved mapping from ``pt_file`` and keep its three fields."""
        payload = torch.load(pt_file)
        self.dna_ref, self.dna_alt, self.labels = (
            payload[key] for key in ("dna_ref", "dna_alt", "label")
        )

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(ref_embedding, alt_embedding, label)`` triple at ``idx``."""
        ref_vec = self.dna_ref[idx]
        alt_vec = self.dna_alt[idx]
        target = self.labels[idx]
        return ref_vec, alt_vec, target

### Model 

In [8]:
import torch
import torch.nn as nn


class ModalityProjector(nn.Module):
    """Projector for DNA embeddings: ``[ref, alt, diff] -> proj_dim``.

    The reference embedding, the alternate embedding and their difference
    (``alt - ref``) are concatenated on the last axis and passed through
    Linear -> LayerNorm -> GELU -> Dropout.
    """

    def __init__(self, emb_dim, proj_dim, dropout):
        super().__init__()
        # Three stacked views of the embedding feed the linear layer.
        stacked_dim = 3 * emb_dim
        self.net = nn.Sequential(
            nn.Linear(stacked_dim, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(self, ref, alt):
        """Project the concatenation of ``ref``, ``alt`` and ``alt - ref``."""
        features = torch.cat((ref, alt, alt - ref), dim=-1)
        return self.net(features)


class DNAClassifier(nn.Module):
    """MLP classifier that uses only the DNA embeddings from HyenaDNA.

    A ``ModalityProjector`` maps the (ref, alt) pair to ``proj_dim`` features,
    which then pass through one Linear -> ReLU -> Dropout stage per entry of
    ``hidden_dims`` and a final single-unit linear layer.
    """

    def __init__(self, dna_dim, proj_dim, hidden_dims, dropout):
        super().__init__()
        self.dna_proj = ModalityProjector(dna_dim, proj_dim, dropout)

        # Layer widths, starting from the projector output.
        widths = [proj_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]

        # Final layer emits one logit per sample.
        stages.append(nn.Linear(widths[-1], 1))
        self.classifier = nn.Sequential(*stages)

    def forward(self, dna_ref, dna_alt):
        """Return one logit per sample (the trailing size-1 axis is squeezed)."""
        projected = self.dna_proj(dna_ref, dna_alt)
        logits = self.classifier(projected)
        return logits.squeeze(dim=-1)

### Precompute Embeddings

In [9]:
import os
import argparse
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm


# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


def embed_dna_hyenadna(seqs, tokenizer, model, batch_size, max_length=601):
    """Embed DNA sequences with HyenaDNA, keeping the middle-position hidden state.

    Sequences are tokenized in chunks of ``batch_size`` (padded/truncated to
    ``max_length``), run through ``model`` without gradients, and the last
    hidden layer at index ``seq_len // 2`` is collected for every sequence.

    Args:
        seqs: Sliceable collection of sequence strings.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` tensor.
        model: Model whose output exposes ``hidden_states``; switched to eval mode.
        batch_size: Number of sequences per forward pass.
        max_length: Tokenizer padding/truncation length.

    Returns:
        A float CPU tensor with one row per input sequence.
    """
    # NOTE(review): the middle index is taken over the padded length; confirm
    # this lines up with the variant position for the tokenizer's padding side.
    model.eval()
    chunks = []
    batch_starts = range(0, len(seqs), batch_size)
    with torch.no_grad():
        for start in tqdm(batch_starts, desc="Embedding DNA"):
            encoded = tokenizer(
                seqs[start : start + batch_size],
                return_tensors="pt",
                max_length=max_length,
                padding="max_length",
                truncation=True,
            )
            token_ids = encoded["input_ids"].to(DEVICE)

            # Last hidden layer, [B, L, H]; keep only the middle position.
            hidden = model(token_ids, output_hidden_states=True).hidden_states[-1]
            middle = hidden.size(1) // 2
            chunks.append(hidden[:, middle, :].float().cpu())

    return torch.cat(chunks, dim=0)


def process_split(parquet_path, out_path, tokenizer, model, batch_size):
    """Embed one split (train/val/test) and save the result to ``out_path``.

    Reads ``parquet_path``, resolves binary labels, embeds the ``ref_seq`` and
    ``alt_seq`` columns with ``embed_dna_hyenadna`` and stores a dict with the
    keys ``"dna_ref"``, ``"dna_alt"`` and ``"label"`` via ``torch.save``.

    Args:
        parquet_path: Parquet file with ``ref_seq``/``alt_seq`` columns and
            either a ``label`` or a ``ClinicalSignificance`` column.
        out_path: Destination ``.pt`` file; its parent directory is created.
        tokenizer: Tokenizer forwarded to ``embed_dna_hyenadna``.
        model: Model forwarded to ``embed_dna_hyenadna``.
        batch_size: Embedding batch size.

    Raises:
        ValueError: If the sequence columns or every label column is missing.
    """
    df = pd.read_parquet(parquet_path)
    print(f"\nProcessing {parquet_path} ({len(df)} rows)")

    # Check required sequence columns
    if "ref_seq" not in df.columns or "alt_seq" not in df.columns:
        raise ValueError(f"Missing required columns in {parquet_path}")

    # Resolve labels first: the ClinicalSignificance branch drops rows, so the
    # sequences must be read from the filtered frame to stay aligned with labels.
    if "label" not in df.columns:
        if "ClinicalSignificance" not in df.columns:
            raise ValueError(f"No label column found in {parquet_path}")
        keep = ["Pathogenic", "Benign"]
        df = df[df["ClinicalSignificance"].isin(keep)].copy()
        label_map = {"Pathogenic": 1, "Benign": 0}
        df["label"] = df["ClinicalSignificance"].map(label_map)
    labels = torch.tensor(df["label"].values, dtype=torch.long)

    dna_ref = df["ref_seq"].astype(str).tolist()
    dna_alt = df["alt_seq"].astype(str).tolist()

    # Embed DNA sequences
    print("Embedding ref sequences...")
    dna_ref_emb = embed_dna_hyenadna(dna_ref, tokenizer, model, batch_size, DNA_SEQ_LEN)

    print("Embedding alt sequences...")
    dna_alt_emb = embed_dna_hyenadna(dna_alt, tokenizer, model, batch_size, DNA_SEQ_LEN)

    # Save embeddings; a bare file name has no parent directory to create.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(
        {
            "dna_ref": dna_ref_emb,
            "dna_alt": dna_alt_emb,
            "label": labels,
        },
        out_path,
    )

    print(f"Saved to {out_path}")
    print(f"  DNA embedding dim: {dna_ref_emb.shape[1]}")
    print(f"  Number of samples: {len(labels)}")

def main():
    """Load HyenaDNA and precompute embeddings for the selected split(s).

    Arguments are parsed from an empty list, so the defaults (``--split all``,
    ``--dna_batch DNA_BATCH``) always apply when run from the notebook. Splits
    whose parquet file does not exist are skipped with a warning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", type=str, choices=["train", "val", "test", "all"], default="all",
                       help="Which split to process")
    parser.add_argument("--dna_batch", type=int, default=DNA_BATCH,
                       help="Batch size for DNA embedding")
    args = parser.parse_args([])

    print(f"Device: {DEVICE}")
    print(f"Loading HyenaDNA model: {HYENADNA_MODEL}")

    tokenizer = AutoTokenizer.from_pretrained(HYENADNA_MODEL, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        HYENADNA_MODEL,
        # bfloat16 on GPU; full precision on CPU (the original conditional
        # selected bfloat16 in both branches).
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
        trust_remote_code=True
    ).to(DEVICE)

    # Split name -> (input parquet, output embedding file)
    split_paths = {
        "train": (TRAIN_PARQUET, TRAIN_EMB),
        "val": (VAL_PARQUET, VAL_EMB),
        "test": (TEST_PARQUET, TEST_EMB),
    }
    if args.split == "all":
        splits_to_process = list(split_paths.values())
    else:
        splits_to_process = [split_paths[args.split]]

    processed = 0
    for parquet_path, emb_path in splits_to_process:
        if not os.path.exists(parquet_path):
            print(f"Warning: {parquet_path} not found, skipping...")
            continue
        process_split(parquet_path, emb_path, tokenizer, model, args.dna_batch)
        processed += 1

    if processed == 0:
        # Make an all-skipped run visible instead of ending on a bare "Done!".
        print("Warning: no split was processed; check the parquet paths in the config.")

    # Drop the model and release cached GPU memory before training.
    del model
    torch.cuda.empty_cache()
    print("\nDone!")


if __name__ == "__main__":
    main()

Device: cuda
Loading HyenaDNA model: LongSafari/hyenadna-large-1m-seqlen-hf

Done!


### Training Script

In [10]:
import os
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import shutil
from datetime import datetime
from torchinfo import summary
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torchmetrics
from torchmetrics.classification import (
    BinaryAUROC, BinaryAccuracy, BinaryMatthewsCorrCoef, MulticlassF1Score, MulticlassAccuracy,
    BinaryConfusionMatrix, BinaryPrecision, BinaryRecall, BinarySpecificity
)


def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and sets cuDNN to deterministic mode with
    benchmarking turned off.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def plot_confusion_matrix(cm_tensor, epoch, stage):
    """Render a confusion-matrix tensor as an annotated heatmap.

    Args:
        cm_tensor: Tensor of counts; moved to CPU and converted to NumPy.
        epoch: Epoch number shown in the title.
        stage: Stage name (e.g. train/val/test) shown in the title.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    class_names = ['Benign', 'Pathogenic']
    counts = cm_tensor.cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        counts,
        annot=True,
        fmt='g',
        cmap='Blues',
        ax=ax,
        xticklabels=class_names,
        yticklabels=class_names,
    )
    ax.set(
        xlabel='Predicted labels',
        ylabel='True labels',
        title=f'Confusion Matrix - {stage} - Epoch {epoch}',
    )
    plt.tight_layout()
    return fig


def run_epoch(model, loader, criterion, device, metrics_collection, cm_metric, optimizer=None, writer=None, epoch=0, stage="train"):
    """Run one pass over ``loader`` and return the loss and metric values.

    The pass is a training pass when ``optimizer`` is given (gradients are
    computed, clipped to norm 1.0 and applied) and an evaluation pass otherwise.

    Args:
        model: Module called as ``model(dna_ref, dna_alt)`` returning logits.
        loader: DataLoader yielding ``(dna_ref, dna_alt, label)`` batches.
        criterion: Loss called as ``criterion(logits, label.float())``.
        device: Device the batches are moved to.
        metrics_collection: torchmetrics ``MetricCollection``; reset at the start.
        cm_metric: Confusion-matrix metric; reset at the start.
        optimizer: Optimizer for training, or ``None`` for evaluation.
        writer: Optional TensorBoard writer for scalars and the confusion matrix.
        epoch: Step index used for logging and the progress-bar label.
        stage: Tag prefix for logging (``"train"``, ``"val"``, ``"test"``).

    Returns:
        Dict mapping each metric name to a float, plus ``"loss"`` (the
        sample-weighted mean loss over ``loader.dataset``).
    """
    is_train = optimizer is not None
    if is_train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    metrics_collection.reset()
    cm_metric.reset()

    # AUROC is a ranking metric and needs continuous scores. The shared
    # collection is updated with thresholded 0/1 predictions, so every
    # BinaryAUROC entry gets a probability-fed counterpart whose value
    # replaces the collection's result below.
    prob_aurocs = {
        name: BinaryAUROC().to(device)
        for name, metric in metrics_collection.items()
        if isinstance(metric, BinaryAUROC)
    }

    pbar = tqdm(loader, desc=f"{stage.capitalize()} Epoch {epoch}", leave=False)

    for dna_ref, dna_alt, label in pbar:
        dna_ref = dna_ref.to(device)
        dna_alt = dna_alt.to(device)
        label = label.to(device).float()

        with torch.set_grad_enabled(is_train):
            logits = model(dna_ref, dna_alt)
            loss = criterion(logits, label)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * len(label)

        probs = torch.sigmoid(logits).detach()
        preds = (probs > 0.5).int()
        targets = label.int()
        metrics_collection.update(preds, targets)
        cm_metric.update(preds, targets)
        for auroc in prob_aurocs.values():
            auroc.update(probs, targets)
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    avg_loss = total_loss / len(loader.dataset)
    results = {k: v.item() for k, v in metrics_collection.compute().items()}
    for name, auroc in prob_aurocs.items():
        results[name] = auroc.compute().item()
    results['loss'] = avg_loss

    if writer is not None:
        for name, value in results.items():
            writer.add_scalar(f"{stage}/{name}", value, epoch)

        cm_tensor = cm_metric.compute()
        fig = plot_confusion_matrix(cm_tensor, epoch, stage)
        writer.add_figure(f"ConfusionMatrix/{stage}", fig, epoch)
        plt.close(fig)  # avoid accumulating open figures across epochs

    return results


def train(args):
    """Train ``DNAClassifier`` on precomputed embeddings and evaluate on test.

    Trains with AdamW + ``ReduceLROnPlateau`` and early stopping on validation
    loss, keeps the best checkpoint, evaluates it on the test split, and writes
    ``config.json``, ``model_summary.txt``, ``best_model.pt`` and
    ``results.json`` under ``experiments_hyenadna/<exp_name>``.

    Args:
        args: Namespace with ``seed``, ``exp_name``, ``log_dir``, ``lr``,
            ``epochs``, ``batch_size``, ``patience``, ``dropout``, ``proj_dim``,
            ``fusion_hidden`` and ``weight_decay``. ``exp_name`` and ``log_dir``
            are filled in when ``None``. An optional ``num_workers`` attribute
            sets the DataLoader worker count (default 0).

    Returns:
        Dict of test metrics as produced by ``run_epoch``.
    """
    seed_everything(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create experiment directory
    if args.exp_name is None:
        args.exp_name = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    if args.log_dir is None:
        args.log_dir = os.path.join("runs_hyenadna", args.exp_name)

    exp_dir = os.path.join("experiments_hyenadna", args.exp_name)
    os.makedirs(exp_dir, exist_ok=True)

    # Print configuration
    print("=" * 70)
    print("TRAINING CONFIGURATION:")
    print("=" * 70)
    print(f"  Experiment Name: {args.exp_name}")
    print(f"  Device: {device}")
    print(f"  Learning Rate: {args.lr}")
    print(f"  Epochs: {args.epochs}")
    print(f"  Batch Size: {args.batch_size}")
    print(f"  Patience: {args.patience}")
    print(f"  Dropout: {args.dropout}")
    print(f"  Weight Decay: {args.weight_decay}")
    print(f"  Seed: {args.seed}")
    print(f"  Proj Dim: {args.proj_dim}")
    print(f"  Fusion Hidden: {args.fusion_hidden}")
    print("=" * 70)

    # Save config
    config_snapshot = {
        "exp_name": args.exp_name,
        "timestamp": datetime.now().isoformat(),
        "lr": args.lr,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "patience": args.patience,
        "dropout": args.dropout,
        "weight_decay": args.weight_decay,
        "seed": args.seed,
        "proj_dim": args.proj_dim,
        "fusion_hidden": args.fusion_hidden,
    }
    with open(os.path.join(exp_dir, "config.json"), "w") as f:
        json.dump(config_snapshot, f, indent=2)

    # Load datasets before opening the TensorBoard writer so that a missing
    # embedding file does not leave an empty run directory or an open writer.
    train_ds = HyenaDNADataset(TRAIN_EMB)
    val_ds = HyenaDNADataset(VAL_EMB)
    test_ds = HyenaDNADataset(TEST_EMB)

    # The datasets index in-memory tensors, so worker processes add no value;
    # on spawn-based platforms (Windows) workers also cannot import a Dataset
    # class defined inside a notebook. Default to loading in the main process.
    loader_args = {
        'batch_size': args.batch_size,
        'num_workers': getattr(args, "num_workers", 0),
        'pin_memory': device == "cuda",
    }
    train_loader = DataLoader(train_ds, shuffle=True, **loader_args)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_args)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_args)

    # Build model
    dna_dim = train_ds.dna_ref.shape[1]
    model = DNAClassifier(
        dna_dim=dna_dim,
        proj_dim=args.proj_dim,
        hidden_dims=args.fusion_hidden,
        dropout=args.dropout
    ).to(device)

    print("\n" + "="*30 + " MODEL SUMMARY " + "="*30)
    input_data_shapes = [
        (args.batch_size, dna_dim),  # dna_ref
        (args.batch_size, dna_dim),  # dna_alt
    ]

    model_stats = summary(
        model,
        input_size=input_data_shapes,
        col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
        device=device,
        verbose=0
    )

    print(model_stats)

    summary_path = os.path.join(exp_dir, "model_summary.txt")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(str(model_stats))
    print(f"--> Model summary saved to {summary_path}")
    print("="*75 + "\n")

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    metrics = torchmetrics.MetricCollection({
        'auc': BinaryAUROC(),
        'acc': BinaryAccuracy(),
        'mcc': BinaryMatthewsCorrCoef(),
        'balanced_acc': MulticlassAccuracy(num_classes=2, average='macro'),
        'f1_macro': MulticlassF1Score(num_classes=2, average='macro'),
        'precision': BinaryPrecision(),
        'recall': BinaryRecall(),
        'specificity': BinarySpecificity()
    }).to(device)

    cm_metric = BinaryConfusionMatrix().to(device)

    best_val_loss = float("inf")
    patience_counter = 0
    epochs_trained = 0
    checkpoint_saved = False
    save_path = os.path.join(exp_dir, "best_model.pt")

    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        for epoch in range(1, args.epochs + 1):
            epochs_trained = epoch
            train_res = run_epoch(model, train_loader, criterion, device, metrics, cm_metric, optimizer, writer, epoch, "train")
            val_res = run_epoch(model, val_loader, criterion, device, metrics, cm_metric, None, writer, epoch, "val")

            print(f"[{epoch}] Train Loss: {train_res['loss']:.4f} | Val Loss: {val_res['loss']:.4f} | Train Acc: {train_res['acc']:.4f} | Val Acc: {val_res['acc']:.4f}")

            scheduler.step(val_res['loss'])

            if val_res['loss'] < best_val_loss:
                best_val_loss = val_res['loss']
                patience_counter = 0
                torch.save(model.state_dict(), save_path)
                checkpoint_saved = True
                print(f"--> Saved best model checkpoint to {save_path}")
            else:
                patience_counter += 1
                if patience_counter >= args.patience:
                    print("Early stopping triggered.")
                    break

        print("\n--- Testing with Best Model ---")
        if checkpoint_saved:
            model.load_state_dict(torch.load(save_path, map_location=device))
        else:
            # Never load a checkpoint this run did not write: with a reused
            # exp_name the file on disk may belong to an earlier experiment.
            print("Warning: no checkpoint was saved in this run; testing the current weights.")
        test_res = run_epoch(model, test_loader, criterion, device, metrics, cm_metric, None, writer, args.epochs, "test")
        print(f"[TEST] Loss: {test_res['loss']:.4f} | AUC: {test_res['auc']:.4f} | MCC: {test_res['mcc']:.4f} | Acc: {test_res['acc']:.4f} | Spec: {test_res['specificity']:.4f}")
        print(f"[TEST] Balanced Acc: {test_res['balanced_acc']:.4f} | F1_macro: {test_res['f1_macro']:.4f} | Precision: {test_res['precision']:.4f} | Recall: {test_res['recall']:.4f}")

        # Log hparams to TensorBoard
        hparams = {
            "lr": args.lr,
            "dropout": args.dropout,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "proj_dim": args.proj_dim,
            "fusion_hidden": str(args.fusion_hidden),
            "patience": args.patience,
        }
        metrics_dict = {
            "test_auc": test_res['auc'],
            "test_acc": test_res['acc'],
            "test_mcc": test_res['mcc'],
            "test_balanced_acc": test_res['balanced_acc'],
            "test_f1_macro": test_res['f1_macro'],
            "test_precision": test_res['precision'],
            "test_recall": test_res['recall'],
            "test_specificity": test_res['specificity'],
            "test_loss": test_res['loss'],
            "best_val_loss": best_val_loss,
        }
        writer.add_hparams(hparams, metrics_dict)

        # Save results
        final_results = {
            "exp_name": args.exp_name,
            "timestamp": datetime.now().isoformat(),
            "best_val_loss": float(best_val_loss),
            "test_results": {k: float(v) for k, v in test_res.items()},
            "epochs_trained": epochs_trained,
            "hparams": hparams,
        }
        with open(os.path.join(exp_dir, "results.json"), "w") as f:
            json.dump(final_results, f, indent=2)
    finally:
        # Flush and release the event file even when training fails midway.
        writer.close()

    print(f"\n✓ Experiment saved to: {exp_dir}")
    print(f"  - Config: {os.path.join(exp_dir, 'config.json')}")
    print(f"  - Results: {os.path.join(exp_dir, 'results.json')}")
    print(f"  - Model: {save_path}")

    return test_res


if __name__ == "__main__":
    import argparse

    # (flag, add_argument keyword arguments); defaults come from the config cell.
    cli_options = [
        ("--exp_name", {"type": str, "default": "experiment_1"}),
        ("--lr", {"type": float, "default": LR}),
        ("--epochs", {"type": int, "default": EPOCHS}),
        ("--batch_size", {"type": int, "default": BATCH_SIZE}),
        ("--patience", {"type": int, "default": PATIENCE}),
        ("--dropout", {"type": float, "default": DROPOUT}),
        ("--seed", {"type": int, "default": SEED}),
        ("--proj_dim", {"type": int, "default": PROJ_DIM}),
        ("--fusion_hidden", {"type": int, "nargs": '+', "default": FUSION_HIDDEN}),
        ("--weight_decay", {"type": float, "default": WEIGHT_DECAY}),
        ("--log_dir", {"type": str, "default": None}),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    # An empty argv keeps the notebook kernel's own command line out of parsing.
    args = parser.parse_args([])

    train(args)

TRAINING CONFIGURATION:
  Experiment Name: experiment_1
  Device: cuda
  Learning Rate: 0.0001
  Epochs: 30
  Batch Size: 512
  Patience: 5
  Dropout: 0.3
  Seed: 42
  Proj Dim: 512
  Fusion Hidden: [512, 256]


OSError: [Errno 22] Invalid argument: 'benchmark\task3_variant_prediction\\HyenaDNA\\embeddings_hyenadna\\train_embeddings.pt'