In [1]:
import torch
from dataset import read_data
from genechat_model import GeneChatModel

device = "cuda" if torch.cuda.is_available() else "cpu"

train_data, test_data, _, _ = read_data()

model = GeneChatModel(
    gene_chunk_nt=512,
    gene_chunk_overlap=0,
    freeze_gene_encoder=True,
).to(device)

state = torch.load("genechat_checkpoints/model_best.pt", map_location=device)
model.load_state_dict(state)

example1 = test_data[10]
example2 = test_data[200]

for ex_i, ex in enumerate([example1, example2], start=1):
    dna = ex["dna"]
    target = ex["target"]
    print(f"\n=== Example {ex_i} Ground Truth ===")
    print(target)

    gen = model.generate(
        dna=dna,
        max_new_tokens=80,
        device=device,
        temperature=0.8,
        top_k=50,
    )
    print(f"\n=== Example {ex_i} Model Output ===")
    print(gen)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  state = torch.load("genechat_checkpoints/model_best.pt", map_location=device)



=== Example 1 Ground Truth ===
This genomic region was validated as an active enhancer by the ChIP-STARR-seq massively parallel reporter assay in naive human embryonic stem cells. This enhancer is marked by the H3K27ac and H3K4me1 histone modifications.

=== Example 1 Model Output ===
of insulin secretion; and negative regulation of signal transduction. Predicted to be located in cytoplasm. Predicted to be active in plasma membrane. Is expressed in intestinal bulb. Orthologous to human SLC17A1 (solute carrier family 17 member 1). Human ortholog(

=== Example 2 Ground Truth ===
This genomic sequence was predicted to be a transcriptional regulatory region based on chromatin state analysis from the ENCODE (ENCyclopedia Of DNA Elements) project. It was validated as a functional repressive element by the Sharpr-MPRA technique (Systematic high-resolution activation and repression profiling with reporter tiling using massively parallel reporter assays) in K562 erythroleukemia cells (group: K

In [5]:
# model.eval()
# !pip install sacrebleu rouge-score nltk matplotlib tqdm

In [6]:
import nltk
nltk.download('wordnet', download_dir='/home/jovyan/nltk_data')
nltk.data.path.append('/home/jovyan/nltk_data')

[nltk_data] Downloading package wordnet to /home/jovyan/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [32]:
NUM_SAMPLES = 10
random_indices = random.sample(range(len(test_data)), NUM_SAMPLES)
metrics_list = []

for idx in random_indices:
    ex = test_data[idx]
    dna = ex["dna"]
    ref = ex["target"]

    pred = model.generate(
        dna=dna,
        max_new_tokens=80,
        device=device,
        temperature=0.7,
        top_k=50,
    )

    m = compute_metrics(ref, pred)
    metrics_list.append(m)

    print(f"\n--- Sample idx {idx} ---")
    print("Ground Truth:", ref[:200].replace("\n"," ") + "...")
    print("Prediction:  ", pred[:200].replace("\n"," ") + "...")
    print("Metrics:", m)

# ---------- Aggregate ----------
avg_bleu   = np.mean([m["bleu"] for m in metrics_list])
avg_meteor = np.mean([m["meteor"] for m in metrics_list])
avg_rouge  = np.mean([m["rougeL"] for m in metrics_list])


--- Sample idx 1242 ---
Ground Truth: This gene encodes a homeobox-containing protein that belongs to the NK-2 homeobox family. This protein is a vertebrate homolog of Drosophila homeobox-containing protein called 'tinman', which has been...
Prediction:   be involved in protein methylation. Predicted to be located in cytoplasm and nucleus. Human ortholog(s) of this gene implicated in methylmalonic aciduria. Orthologous to human HMTM1 (methylmalonyl-CoA...
Metrics: {'bleu': 1.3772734577189594, 'meteor': 0.11492947386638977, 'rougeL': 0.13333333333333333}

--- Sample idx 197 ---
Ground Truth: This gene encodes a member of the DEAD box protein family. These proteins are characterized by the conserved motif Asp-Glu-Ala-Asp (DEAD) and are putative RNA helicases. They are implicated in a numbe...
Prediction:   in nucleus. Orthologous to human NFKBIE (NFKB inhibitor epsilon). Orthologous to human NFKBIE1 (NFKB inhibitor epsilon 1). Orthologous to human KNMT1 (NFKB inhibitor epsilon 1). Ortho

In [7]:
import random, numpy as np
from sacrebleu import BLEU
from rouge_score import rouge_scorer
from nltk.translate.meteor_score import meteor_score
import matplotlib.pyplot as plt
from tqdm import tqdm

# BLEU scorers for BLEU-1..4
bleu1 = BLEU(max_ngram_order=1, effective_order=True)
bleu2 = BLEU(max_ngram_order=2, effective_order=True)
bleu3 = BLEU(max_ngram_order=3, effective_order=True)
bleu4 = BLEU(max_ngram_order=4, effective_order=True)

# ROUGE scorer
rouge = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

def compute_metrics(ref, pred):
    """Compute BLEU-n, METEOR, ROUGE-1, ROUGE-L."""

    # BLEU scores
    b1 = bleu1.sentence_score(pred, [ref]).score / 100
    b2 = bleu2.sentence_score(pred, [ref]).score / 100
    b3 = bleu3.sentence_score(pred, [ref]).score / 100
    b4 = bleu4.sentence_score(pred, [ref]).score / 100

    # METEOR
    ref_tok = ref.split()
    pred_tok = pred.split()
    meteor = meteor_score([ref_tok], pred_tok)

    # ROUGE
    r = rouge.score(ref, pred)
    rouge1 = r["rouge1"].fmeasure
    rougeL = r["rougeL"].fmeasure

    return {
        "bleu1": b1,
        "bleu2": b2,
        "bleu3": b3,
        "bleu4": b4,
        "meteor": meteor,
        "rouge1": rouge1,
        "rougeL": rougeL,
    }


# Storage lists
bleu1_scores, bleu2_scores, bleu3_scores, bleu4_scores = [], [], [], []
meteor_scores = []
rouge1_scores, rougeL_scores = [], []


# Loop through dataset
for ex in tqdm(test_data):
    dna = ex["dna"]
    ref = ex["target"]

    pred = model.generate(
        dna=dna,
        max_new_tokens=80,
        device=device,
        temperature=0.7,
        top_k=50,
    )

    m = compute_metrics(ref, pred)

    bleu1_scores.append(m["bleu1"])
    bleu2_scores.append(m["bleu2"])
    bleu3_scores.append(m["bleu3"])
    bleu4_scores.append(m["bleu4"])

    meteor_scores.append(m["meteor"])

    rouge1_scores.append(m["rouge1"])
    rougeL_scores.append(m["rougeL"])




100%|██████████| 4106/4106 [48:34<00:00,  1.41it/s] 


In [None]:
avg_bleu = np.mean(bleu_scores)
avg_chrf = np.mean(chrf_scores)
avg_rouge = np.mean(rouge_scores)

print("\n===== AVERAGE METRICS =====")
print(f"BLEU:    {avg_bleu:.4f}")
print(f"chrF:    {avg_chrf:.4f}")
print(f"ROUGE-L: {avg_rouge:.4f}")


plt.figure(figsize=(14,4))

plt.subplot(1,3,1)
plt.hist(bleu_scores, bins=30, color='skyblue')
plt.title("BLEU Score Distribution")

plt.subplot(1,3,2)
plt.hist(chrf_scores, bins=30, color='lightgreen')
plt.title("chrF Score Distribution")

plt.subplot(1,3,3)
plt.hist(rouge_scores, bins=30, color='salmon')
plt.title("ROUGE-L Score Distribution")

plt.tight_layout()
plt.show()