In [4]:
import torch
from dataset import read_data
from genechat_model import GeneChatModel, DNABERTBartDecoder, DNABERTT5Decoder, DNABERTGRUDecoder

device = "cuda" if torch.cuda.is_available() else "cpu"

train_data, test_data, _, _ = read_data()


def _load_pretrained(model_cls, weights_path):
    """Build one model, restore its checkpoint and put it in eval mode.

    Args:
        model_cls: Model class to instantiate. Every architecture is built
            with the same chunking / encoder-freezing keyword arguments.
        weights_path: Path of the saved state dict for that architecture.

    Returns:
        The model moved to ``device`` with the checkpoint loaded, in
        evaluation mode.
    """
    model = model_cls(
        gene_chunk_nt=512,
        gene_chunk_overlap=0,
        freeze_gene_encoder=True,
    ).to(device)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only load checkpoints you produced yourself; switch to
    # weights_only=True if the files hold plain tensor state dicts.
    state = torch.load(weights_path, map_location=device, weights_only=False)
    model.load_state_dict(state)
    # Assuming these are torch.nn.Module subclasses, a freshly built model is
    # in training mode; switch to eval so every later generate() call runs
    # in inference mode.
    model.eval()
    return model


# Load all 4 models (module-level names kept for use in later cells).
model_genechat = _load_pretrained(GeneChatModel, "model_weights/genechat_best.pt")
model_bart = _load_pretrained(DNABERTBartDecoder, "model_weights/bart_best.pt")
model_t5 = _load_pretrained(DNABERTT5Decoder, "model_weights/t5_best.pt")
model_gru = _load_pretrained(DNABERTGRUDecoder, "model_weights/gru_best.pt")

# Name -> model mapping iterated by the evaluation cells.
models = {
    'genechat': model_genechat,
    'bart': model_bart,
    't5': model_t5,
    'gru': model_gru,
}

print(f"Loaded {len(models)} models: {list(models.keys())}")
print(f"Test dataset size: {len(test_data)}")

# Quick qualitative check: generate from every model for one test example.
QUICK_TEST_INDEX = 10

# Clamp the index so a test split with fewer than 11 examples still works.
example = test_data[min(QUICK_TEST_INDEX, len(test_data) - 1)]
dna = example["dna"]
target = example["target"]
print("\n=== Ground Truth ===")
print(target[:200])  # truncate long references for readability

for model_name, model in models.items():
    # Generation should not run in training mode or track gradients.
    model.eval()
    with torch.no_grad():
        gen = model.generate(
            dna=dna,
            max_new_tokens=80,
            device=device,
            temperature=0.8,
            top_k=50,
        )
    print(f"\n=== {model_name.upper()} Output ===")
    print(gen[:200])

  from .autonotebook import tqdm as notebook_tqdm
A new version of the following files was downloaded from https://huggingface.co/zhihan1996/DNA_bert_6:
- configuration_bert.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/zhihan1996/DNA_bert_6:
- dnabert_layer.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


KeyboardInterrupt: 

In [3]:
!pip install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting transformers>=4.30.0 (from -r requirements.txt (line 2))
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scikit-learn>=1.3.0 (from -r requirements.txt (line 3))
  Downloading scikit_learn-1.7.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers>=4.30.0->-r requirements.txt (line 2))
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers>=4.30.0->-r requirements.txt (line 2))
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers>=4.30.0->-r requirements.txt (line 2))
  Downloading safetensors-0.

In [None]:
from pathlib import Path

import nltk

# Store the corpus under the current user's home directory instead of a
# hard-coded /home/<user> path, so the cell also works for other accounts.
NLTK_DATA_DIR = str(Path.home() / "nltk_data")

# WordNet is fetched here for the METEOR metric imported later in the notebook.
nltk.download('wordnet', download_dir=NLTK_DATA_DIR)

# Register the directory only once: re-running this cell must not keep
# appending duplicate entries to NLTK's search path.
if NLTK_DATA_DIR not in nltk.data.path:
    nltk.data.path.append(NLTK_DATA_DIR)

[nltk_data] Downloading package wordnet to /home/jovyan/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
import random

NUM_SAMPLES = 5
SAMPLE_SEED = 0  # change to inspect a different, still reproducible, set of examples

# compute_metrics is defined in the metrics cell further down the notebook.
# Fail fast with a clear message instead of after a slow generation call.
if "compute_metrics" not in globals():
    raise NameError(
        "compute_metrics is not defined yet - run the metrics/evaluation cell "
        "before this one (or move this cell below it)."
    )

# A dedicated RNG keeps the selection reproducible without touching the
# global `random` state; never ask for more samples than the test set holds.
sample_rng = random.Random(SAMPLE_SEED)
random_indices = sample_rng.sample(
    range(len(test_data)), min(NUM_SAMPLES, len(test_data))
)

print("="*80)
print("SAMPLE PREDICTIONS - ALL MODELS")
print("="*80)

for idx in random_indices:
    ex = test_data[idx]
    dna = ex["dna"]
    ref = ex["target"]

    print(f"\n{'='*80}")
    print(f"Sample Index: {idx}")
    print(f"{'='*80}")
    print(f"\nGround Truth:\n{ref[:300]}")
    print()

    for model_name, model in models.items():
        # Generation should not run in training mode or track gradients.
        model.eval()
        with torch.no_grad():
            pred = model.generate(
                dna=dna,
                max_new_tokens=80,
                device=device,
                temperature=0.7,
                top_k=50,
            )

        # Metrics are computed on the full prediction; only the printout is truncated.
        m = compute_metrics(ref, pred)

        print(f"\n--- {model_name.upper()} ---")
        print(f"Prediction: {pred[:300]}")
        print(f"Metrics: BLEU-1={m['bleu1']:.3f}, BLEU-4={m['bleu4']:.3f}, METEOR={m['meteor']:.3f}, ROUGE-1={m['rouge1']:.3f}")

print(f"\n{'='*80}")
print("SAMPLE PREDICTIONS COMPLETE")
print(f"{'='*80}")

In [None]:
import random, numpy as np
from sacrebleu import BLEU
from rouge_score import rouge_scorer
from nltk.translate.meteor_score import meteor_score
import matplotlib.pyplot as plt
from tqdm import tqdm

# One BLEU scorer per maximum n-gram order (1 through 4), all with
# effective_order enabled.
bleu1, bleu2, bleu3, bleu4 = (
    BLEU(max_ngram_order=order, effective_order=True) for order in range(1, 5)
)

# Shared ROUGE scorer reporting ROUGE-1 and ROUGE-L, with stemming turned on.
rouge_variants = ['rouge1', 'rougeL']
rouge = rouge_scorer.RougeScorer(rouge_variants, use_stemmer=True)

def compute_metrics(ref, pred):
    """Score a single prediction against its reference text.

    Args:
        ref: Reference (ground-truth) string.
        pred: Generated string to evaluate.

    Returns:
        Dict with keys ``bleu1``..``bleu4`` (sentence-level BLEU score
        divided by 100), ``meteor`` (computed on whitespace-split tokens),
        and ``rouge1`` / ``rougeL`` (F-measures).
    """
    scores = {}

    # BLEU-1..4 from the module-level scorers, one per n-gram order.
    for order, scorer in enumerate((bleu1, bleu2, bleu3, bleu4), start=1):
        scores[f"bleu{order}"] = scorer.sentence_score(pred, [ref]).score / 100

    # METEOR expects pre-tokenised input; plain whitespace splitting is used.
    scores["meteor"] = meteor_score([ref.split()], pred.split())

    # ROUGE: keep only the F-measure of each variant.
    rouge_result = rouge.score(ref, pred)
    scores["rouge1"] = rouge_result["rouge1"].fmeasure
    scores["rougeL"] = rouge_result["rougeL"].fmeasure

    return scores


# Metric names returned by compute_metrics, in reporting order.
METRIC_KEYS = ("bleu1", "bleu2", "bleu3", "bleu4", "meteor", "rouge1", "rougeL")

# Re-applied before each model so the sampled generations (temperature/top-k)
# are repeatable across runs. Assumes generate() draws from torch's global
# RNG - TODO confirm against the model code.
EVAL_SEED = 0

# Storage for ALL models: model name -> metric name -> list of per-example scores.
all_model_metrics = {}

# Loop through each model
for model_name, model in models.items():
    print(f"\n{'='*80}")
    print(f"Evaluating {model_name.upper()} model...")
    print(f"{'='*80}")

    model.eval()
    torch.manual_seed(EVAL_SEED)

    # Pre-create every key so an empty test set still yields the full schema.
    per_example_scores = {key: [] for key in METRIC_KEYS}

    # Inference only: disabling autograd avoids tracking gradients for every
    # generated example over the whole test set.
    with torch.no_grad():
        for ex in tqdm(test_data, desc=f"Evaluating {model_name}"):
            dna = ex["dna"]
            ref = ex["target"]

            pred = model.generate(
                dna=dna,
                max_new_tokens=80,
                device=device,
                temperature=0.7,
                top_k=50,
            )

            m = compute_metrics(ref, pred)
            for key in METRIC_KEYS:
                per_example_scores[key].append(m[key])

    # Store all metrics for this model
    all_model_metrics[model_name] = per_example_scores

    print(f"Completed evaluation for {model_name}!")

print(f"\n{'='*80}")
print("ALL MODELS EVALUATED!")
print(f"{'='*80}")

In [None]:
# ==== COMPUTE AVERAGES FOR ALL MODELS ====
banner = "=" * 80
print("\n" + banner)
print("AVERAGE METRICS COMPARISON - ALL MODELS")
print(banner + "\n")

# (key in the per-example score dict, label used in the report), in print order.
metric_labels = [
    ('bleu1', 'BLEU-1'),
    ('bleu2', 'BLEU-2'),
    ('bleu3', 'BLEU-3'),
    ('bleu4', 'BLEU-4'),
    ('meteor', 'METEOR'),
    ('rouge1', 'ROUGE-1'),
    ('rougeL', 'ROUGE-L'),
]

# model name -> {report label -> mean score over all evaluated examples}
results_summary = {}

for model_name, metrics in all_model_metrics.items():
    averages = {label: np.mean(metrics[key]) for key, label in metric_labels}
    results_summary[model_name] = averages

    print(f"=== {model_name.upper()} ===")
    for label, value in averages.items():
        tag = label + ":"
        # Pad "LABEL:" to a fixed 10-character column before the number.
        print(f"{tag:<10}{value:.4f}")
    print()

# ==== CREATE COMPARISON TABLE ====
import pandas as pd
# Transpose so each row is a model and each column a metric.
df = pd.DataFrame(results_summary).T
print("\n" + banner)
print("METRICS TABLE")
print(banner)
print(df.to_string())
print()

# ==== BAR CHART COMPARISON ====
# One panel per metric on a 2x4 grid; each panel has one bar per model.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
fig.suptitle('Metrics Comparison Across All Models', fontsize=16, fontweight='bold')

metrics_names = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4', 'METEOR', 'ROUGE-1', 'ROUGE-L']
model_names = list(all_model_metrics.keys())

for idx, metric_name in enumerate(metrics_names):
    ax = axes[idx // 4, idx % 4]
    values = [results_summary[m][metric_name] for m in model_names]
    bars = ax.bar(model_names, values, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'][:len(model_names)])
    ax.set_title(metric_name, fontweight='bold')
    ax.set_ylabel('Score')

    # Leave 20% headroom for the value labels. Fall back to a unit range when
    # there are no models or every score is 0/NaN, so the axis never gets a
    # degenerate (0, 0) limit.
    y_top = max(values, default=0) * 1.2
    ax.set_ylim(0, y_top if y_top > 0 else 1.0)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.3f}',
                ha='center', va='bottom', fontsize=9)

# Hide the last subplot (we have 7 metrics, 8 subplots)
axes[1, 3].axis('off')

plt.tight_layout()
plt.show()

# ==== DISTRIBUTION HISTOGRAMS FOR EACH MODEL ====
# (score key, panel title, bar colour, (row, col) on the 3x3 grid).
# ROUGE-L sits alone in the last row, centred.
HISTOGRAM_PANELS = [
    ('bleu1', "BLEU-1", 'steelblue', (0, 0)),
    ('bleu2', "BLEU-2", 'steelblue', (0, 1)),
    ('bleu3', "BLEU-3", 'steelblue', (0, 2)),
    ('bleu4', "BLEU-4", 'steelblue', (1, 0)),
    ('meteor', "METEOR", 'green', (1, 1)),
    ('rouge1', "ROUGE-1", 'orange', (1, 2)),
    ('rougeL', "ROUGE-L", 'orange', (2, 1)),
]

for model_name, metrics in all_model_metrics.items():
    fig, axes = plt.subplots(3, 3, figsize=(18, 12))
    fig.suptitle(f'{model_name.upper()} - Metric Distributions', fontsize=16, fontweight='bold')

    for key, title, colour, (row, col) in HISTOGRAM_PANELS:
        ax = axes[row, col]
        scores = metrics[key]
        mean_score = np.mean(scores)  # computed once, used for line and legend

        ax.hist(scores, bins=30, color=colour, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel("Score")
        ax.set_ylabel("Count")
        # Dashed red line marks the mean of the per-example scores.
        ax.axvline(mean_score, color='red', linestyle='--', label=f'Mean: {mean_score:.3f}')
        ax.legend()

    # Hide empty subplots
    axes[2, 0].axis('off')
    axes[2, 2].axis('off')

    plt.tight_layout()
    plt.show()
    # Release the figure so one figure per model does not accumulate in memory.
    plt.close(fig)

print("\n" + "="*80)
print("EVALUATION COMPLETE!")
print("="*80)