In [2]:
#CUDA and GPU Availability Check
import torch

print("=" * 70)
print("CUDA/GPU CONFIGURATION")
print("=" * 70)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

if torch.cuda.is_available():
    print(f"\n✅ GPU ENABLED!")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Count: {torch.cuda.device_count()}")
    device = torch.device("cuda")
    print(f"Device: {device}")
else:
    print("\n❌ CUDA not available")
    device = torch.device("cpu")
    print(f"Device: {device}")

print("=" * 70 + "\n")

        
     



CUDA/GPU CONFIGURATION
PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1

✅ GPU ENABLED!
GPU Name: NVIDIA GeForce RTX 4090
GPU Count: 1
Device: cuda



In [3]:
import os
import torch
import numpy as np
import sklearn.metrics
import random
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, set_seed
from torch.utils.data import Dataset
from genomic_benchmarks.data_check import list_datasets
from genomic_benchmarks.loc2seq import download_dataset
from pathlib import Path

class GenomicDataset(Dataset):
    def __init__(self, sequences, labels, tokenizer, max_length=512):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            sequence,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

def load_single_dataset(dataset_name, split):
    dataset_path = download_dataset(dataset_name)
    sequences = []
    labels = []
    split_path = Path(dataset_path) / split
    class_dirs = sorted([d for d in split_path.iterdir() if d.is_dir()])
    for label_idx, class_dir in enumerate(class_dirs):
        for seq_file in class_dir.glob('*.txt'):
            with open(seq_file, 'r') as f:
                sequences.append(f.read().strip())
                labels.append(label_idx)
    return sequences, labels

def load_merged_genomic_data(split='train'):
    seq_p, lab_p = load_single_dataset("human_nontata_promoters", split)
    seq_e, lab_e = load_single_dataset("human_enhancers_cohn", split)
    
    num_p = len(seq_p) // 2
    num_e = len(seq_e) // 2
    
    combined_seq = seq_p[:num_p] + seq_e[:num_e]
    combined_lab = lab_p[:num_p] + lab_e[:num_e]
    
    combined = list(zip(combined_seq, combined_lab))
    random.shuffle(combined)
    sequences, labels = zip(*combined)
    
    return list(sequences), list(labels)

def calculate_metrics(predictions, labels):
    valid_mask = labels != -100
    valid_predictions = predictions[valid_mask]
    valid_labels = labels[valid_mask]
    return {
        "accuracy": sklearn.metrics.accuracy_score(valid_labels, valid_predictions),
        "f1": sklearn.metrics.f1_score(valid_labels, valid_predictions, average="macro", zero_division=0),
        "matthews_correlation": sklearn.metrics.matthews_corrcoef(valid_labels, valid_predictions),
        "precision": sklearn.metrics.precision_score(valid_labels, valid_predictions, average="macro", zero_division=0),
        "recall": sklearn.metrics.recall_score(valid_labels, valid_predictions, average="macro", zero_division=0),
    }

def preprocess_logits(logits, _):
    if isinstance(logits, tuple):
        logits = logits[0]
    if logits.ndim == 3:
        logits = logits.reshape(-1, logits.shape[-1])
    return torch.argmax(logits, dim=-1)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return calculate_metrics(predictions, labels)

def train_merged_omni_dna(
    model_name="zehui127/Omni-DNA-116M",
    output_dir="./omni_dna_merged_classifier",
    seed=42,
    learning_rate=5e-6,
    batch_size=5,
    num_epochs=3,
    max_length=512
):
    set_seed(seed)
    
    train_sequences, train_labels = load_merged_genomic_data(split='train')
    test_sequences, test_labels = load_merged_genomic_data(split='test')
    
    num_classes = len(set(train_labels))
    
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.model_max_length = max_length
    
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_classes,
        trust_remote_code=True
    ).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    model._tied_weights_keys = ["word_embeddings.weight", "model.transformer.wte.weight"]
    
    train_dataset = GenomicDataset(train_sequences, train_labels, tokenizer, max_length)
    test_dataset = GenomicDataset(test_sequences, test_labels, tokenizer, max_length)
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        num_train_epochs=num_epochs,
        eval_strategy="epoch",
        save_strategy="epoch",
        max_grad_norm=1.0,
        metric_for_best_model="matthews_correlation",
        greater_is_better=True,
        save_total_limit=2,
        load_best_model_at_end=True,
        save_safetensors=True,
        logging_steps=10,
        report_to="none"
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits,
    )
    
    trainer.train()
    trainer.save_model(output_dir) 
    tokenizer.save_pretrained(output_dir)
    
    test_metrics = trainer.evaluate(eval_dataset=test_dataset)
    
    print("\n" + "="*50)
    print("FINAL MERGED TEST RESULTS")
    print("="*50)
    for metric, value in test_metrics.items():
        print(f"{metric}: {value:.4f}")
    print("="*50)
    
    return trainer, model, test_metrics

if __name__ == "__main__":
    train_merged_omni_dna()

Some weights of OLMoForSequenceCLS were not initialized from the model checkpoint at zehui127/Omni-DNA-116M and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


num_labels: 2


Epoch,Training Loss,Validation Loss,Accuracy,F1,Matthews Correlation,Precision,Recall
1,0.3479,0.082577,0.981855,0.88657,0.789708,0.976259,0.827363
2,0.0001,0.110632,0.984357,0.904407,0.821377,0.981752,0.850107
3,0.0,0.153865,0.985233,0.912331,0.832605,0.971927,0.867234


There were missing keys in the checkpoint model loaded: ['model.transformer.wte.weight'].



FINAL MERGED TEST RESULTS
eval_loss: 0.1539
eval_accuracy: 0.9852
eval_f1: 0.9123
eval_matthews_correlation: 0.8326
eval_precision: 0.9719
eval_recall: 0.8672
eval_runtime: 48.2255
eval_samples_per_second: 165.7010
eval_steps_per_second: 16.5890
epoch: 3.0000
