In [None]:
import json
import math
import os
import random
from collections import defaultdict
from itertools import combinations
from typing import Any

import editdistance
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from datasets import Dataset, load_dataset
from llms.dna_translator.gpt import DNATranslatorGPT
from pandas import DataFrame
from schemas.train_params import TrainParams
from tqdm import tqdm

# Configurations & Limits

In [None]:
# Seed shared by `random`, NumPy and the model wrapper.
SEED = 42

# Base name for the output folders where metrics and checkpoints are written.
CHECKPOINT_NAME = "GeneFormer"

# Training loop: at most MAX_EPOCHS global epochs; stop once PATIENCE consecutive
# epochs fail to beat the best mean similarity by more than MIN_DELTA.
MAX_EPOCHS = 30
PATIENCE = 3
MIN_DELTA = 0.0001

# Rows are kept only if len(DNA) + len(protein target) is below this (characters).
DATA_MAX_LENGTH = 1000

# Translation classes by max frame similarity: >= TRIVIAL -> "trivial",
# >= AMBIGUOUS -> "ambiguous", otherwise "genomic".
TRIVIAL_THRESHOLD = 0.8
AMBIGUOUS_THRESHOLD = 0.65

# Sampling caps for the similarity checks (pair counts, not similarity thresholds).
MAX_PAIRS_PER_SPECIES = 500
MAX_INTERSPECIES_PAIRS = 50_000

In [None]:
# Expose only GPUs 0 and 1 to this process.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in the
# kernel; the project import at the top may already load torch — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Seed the stdlib generator first, then NumPy's legacy global generator.
for _seed_fn in (random.seed, np.random.seed):
	_seed_fn(SEED)


# Dataset

In [None]:
# Load the "train" split of GustavoHCruz/DNA_Coding_Regions via datasets.load_dataset.
dataset = load_dataset(
	"GustavoHCruz/DNA_Coding_Regions",
	split="train",
)
assert isinstance(dataset, Dataset)

# Work with a pandas DataFrame from here on.
dataset = dataset.to_pandas()

In [None]:
def join_proteins(protein_list) -> str:
	"""Concatenate the ``"sequence"`` strings of a record's protein entries.

	Entries are joined with no separator, in iteration order.

	Args:
		protein_list: iterable of dict-like protein entries, or a missing value.
			``None`` and float NaN (how pandas represents a missing cell) are
			treated as "no proteins".

	Returns:
		The concatenated string; ``""`` when there is nothing to join. Entries
		that are not dicts, or whose ``"sequence"`` is absent or not a string,
		are skipped.
	"""
	# Missing cells would otherwise raise TypeError when iterated.
	if protein_list is None or isinstance(protein_list, float):
		return ""

	seqs = []
	for p in protein_list:
		# A None / malformed entry would otherwise raise AttributeError on .get.
		if not isinstance(p, dict):
			continue
		seq = p.get("sequence", "")
		if isinstance(seq, str):
			seqs.append(seq)
	return "".join(seqs)

In [None]:
assert isinstance(dataset, DataFrame)

# One target string per record: all protein sequences concatenated.
dataset["target"] = dataset["proteins"].apply(join_proteins).astype(str)

# Keep rows whose DNA + target length (characters) stays under the context budget.
mask = dataset["sequence"].str.len().add(dataset["target"].str.len()).lt(DATA_MAX_LENGTH)
original_df = dataset.loc[mask, ["sequence", "target", "organism"]]

original_df

In [None]:
print(f"Dataset Length: {len(original_df)}")

# Removing Invalid Sequences

In [None]:
# Drop sequences shorter than 2 characters.
# .copy() detaches the filtered frame, so the column assignments made in later cells
# (max_frame_similarity, translation_class) don't raise SettingWithCopyWarning.
# NOTE(review): a 2-base sequence still contains no full codon; consider ge(3).
original_df = original_df[original_df["sequence"].str.len().ge(2)].copy()

In [None]:
print("Dataset Length: " + str(len(original_df)))

# How Many Sequences Are Trivially Translatable?

In [None]:
# Standard genetic code. Each codon position is enumerated in T, C, A, G order, and
# _AMINO_ACIDS lists the one-letter residue for the 64 codons in that same order.
# "*" marks the stop codons TAA, TAG and TGA.
_BASES = "TCAG"
_AMINO_ACIDS = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
_CODONS = [first + second + third for first in _BASES for second in _BASES for third in _BASES]
CODON_TABLE = dict(zip(_CODONS, _AMINO_ACIDS))

def translate_frame(seq, frame):
	"""Translate ``seq`` codon by codon, starting at offset ``frame``.

	Only complete codons are read, so 1-2 trailing bases are ignored. Codons missing
	from CODON_TABLE (e.g. containing N, or lowercase) become "X". Stop codons are
	kept as "*".
	"""
	starts = range(frame, len(seq) - 2, 3)
	return "".join(CODON_TABLE.get(seq[i:i + 3], "X") for i in starts)

def protein_segments(protein):
	"""Split a translated sequence at stop codons ("*"), dropping empty pieces."""
	return list(filter(None, protein.split("*")))

def similarity(a, b):
	"""Normalised Levenshtein similarity: 1 - distance / length of the longer input.

	Returns 0.0 when either input is empty.
	"""
	longest = max(len(a), len(b))
	if min(len(a), len(b)) == 0:
		return 0.0
	return 1.0 - editdistance.eval(a, b) / longest

In [None]:
def _best_frame_similarity(dna: str, target: str) -> float:
	"""Highest ``similarity`` between ``target`` and any stop-free segment of a frame.

	Only the three forward frames (offsets 0, 1, 2) of ``dna`` are translated; the
	reverse complement is not considered. Returns 0.0 if no frame yields a segment.
	"""
	best = 0.0
	for frame in (0, 1, 2):
		for seg in protein_segments(translate_frame(dna, frame)):
			best = max(best, similarity(seg, target))
	return best


# Iterate the two columns directly; DataFrame.iterrows() builds a Series per row.
max_similarities = [
	_best_frame_similarity(dna.upper(), target)
	for dna, target in tqdm(
		zip(original_df["sequence"], original_df["target"]),
		total=len(original_df),
	)
]

# assign() returns a new frame, so this cannot raise SettingWithCopyWarning even when
# original_df is still a filtered slice of `dataset`.
original_df = original_df.assign(max_frame_similarity=max_similarities)

In [None]:
def classify(sim):
	"""Map a max-frame similarity score to its translation class label.

	"trivial" if sim >= TRIVIAL_THRESHOLD, else "ambiguous" if
	sim >= AMBIGUOUS_THRESHOLD, else "genomic".
	"""
	for threshold, label in (
		(TRIVIAL_THRESHOLD, "trivial"),
		(AMBIGUOUS_THRESHOLD, "ambiguous"),
	):
		if sim >= threshold:
			return label
	return "genomic"

original_df["translation_class"] = original_df["max_frame_similarity"].apply(classify)

In [None]:
# Row count per translation class.
print("Distribution:", original_df["translation_class"].value_counts(), sep="\n")

In [None]:
# value_counts(normalize=True) returns fractions; * 100 turns them into percentages
# of rows (these are shares, not percentiles — the old "Percentils" label was wrong).
print("Percentages:")
print(original_df["translation_class"].value_counts(normalize=True) * 100)

In [None]:
# Distribution of the best frame-translation similarity, with the class cut-offs
# labelled so the figure can be read without the surrounding code.
fig, ax = plt.subplots()
ax.hist(original_df["max_frame_similarity"], bins=50)
ax.axvline(TRIVIAL_THRESHOLD, color="yellow", label=f"trivial >= {TRIVIAL_THRESHOLD}")
ax.axvline(AMBIGUOUS_THRESHOLD, color="red", label=f"ambiguous >= {AMBIGUOUS_THRESHOLD}")
ax.set_xlabel("Max frame similarity")
ax.set_ylabel("Count")
ax.set_title("Frame-based translation similarity")
ax.legend()
plt.show()

In [None]:
# One frame per translation class, in the order the names are unpacked.
genomic, ambiguous, trivial = (
	original_df[original_df["translation_class"] == class_name]
	for class_name in ("genomic", "ambiguous", "trivial")
)

In [None]:
n_total = len(original_df)
n_other = len(ambiguous) + len(trivial)
print(f"Genomics: {len(genomic)} ({len(genomic) / n_total * 100:.2f}%)")
print(f"Other (trivial + ambiguous): {n_other} ({n_other / n_total * 100:.2f}%)")

# Data Structure & Analysis

In [None]:
# Genomic rows plus a case/whitespace-normalised organism name for species counts.
# NOTE(review): the similarity checks and split_by_organism group on the raw
# "organism" column, so spelling variants of one species count separately there.
stats = genomic.assign(
	organism_norm=genomic["organism"].astype(str).str.strip().str.lower()
)

print(f"Total of Sequences: {len(stats)}")
print(f"Total of Species: {stats['organism_norm'].nunique()}")

In [None]:
def _mean_length(seqs):
	"""Mean character length of a Series of strings."""
	return seqs.str.len().mean()

# Per normalised species: number of sequences and their mean length.
species_stats = stats.groupby("organism_norm").agg(
	num_sequences=("sequence", "count"),
	avg_seq_length=("sequence", _mean_length),
).reset_index()

In [None]:
print("Number of sequences per species:", species_stats["num_sequences"].describe(), sep="\n")

In [None]:
print("Average sequence size by species:", species_stats["avg_seq_length"].describe(), sep="\n")

In [None]:
# Row of the species with the highest sequence count (first one on ties).
print("Species with MORE sequences:", species_stats.loc[species_stats["num_sequences"].idxmax()], sep="\n")

# Similarity Check

In [None]:
def dna_similarity(a, b):
	"""Normalised Levenshtein similarity between two DNA strings.

	Identical to ``similarity`` (defined with the translation helpers): returns
	1 - distance / length of the longer input, or 0.0 if either is empty. Kept as a
	named alias for readability, but delegating so the two metrics cannot drift apart.
	"""
	return similarity(a, b)

In [None]:
def _sample_pairs(items, k):
	"""Return up to ``k`` distinct unordered pairs (by position) drawn from ``items``.

	When there are at most 4*k pairs, all pairs are enumerated and shuffled exactly
	as before, so the drawn pairs and the use of the global ``random`` state are
	unchanged for those inputs. Larger inputs are sampled by rejection instead of
	materialising all n*(n-1)/2 pairs, which is O(n^2) in time and memory.
	NOTE: on the large-input path the global random stream is consumed differently,
	which also shifts later random draws (e.g. the organism split shuffle).
	"""
	n = len(items)
	total = n * (n - 1) // 2
	if total <= 4 * k:
		pairs = list(combinations(items, 2))
		random.shuffle(pairs)
		return pairs[:k]

	# total > 4k, so each draw is new with probability >= 3/4; the loop ends quickly.
	seen = set()
	pairs = []
	while len(pairs) < k:
		i, j = sorted(random.sample(range(n), 2))
		if (i, j) not in seen:
			seen.add((i, j))
			pairs.append((items[i], items[j]))
	return pairs


intra_results = []

for organism, group in tqdm(genomic.groupby("organism"), desc="Intra-species"):
	sequences = group["sequence"].tolist()

	if len(sequences) < 2:
		continue

	for a, b in _sample_pairs(sequences, MAX_PAIRS_PER_SPECIES):
		sim = dna_similarity(a, b)
		intra_results.append({
			"organism": organism,
			"similarity": sim
		})

In [None]:
# Explicit columns keep the frame well-formed (so ["similarity"] doesn't raise KeyError)
# even if no species had two or more sequences and intra_results is empty.
intra_df = pd.DataFrame(intra_results, columns=["organism", "similarity"])

print("INTRA-SPECIES")
print(intra_df["similarity"].describe())

In [None]:
organisms = genomic["organism"].unique().tolist()
inter_results = []

# Index sequences by organism once (row order within each organism is preserved), instead
# of scanning the whole frame twice for every organism pair. Organisms that groupby drops
# (missing values) map to [] below, matching the old empty boolean filter.
sequences_by_organism = {
	org: group.tolist() for org, group in genomic.groupby("organism")["sequence"]
}

pairs_done = 0

# Walk the pair generator lazily; materialising all C(n, 2) organism pairs costs O(n^2)
# memory even though at most MAX_INTERSPECIES_PAIRS are used. The random.choice calls are
# unchanged, so the sampled pairs are identical to before.
# NOTE(review): combinations() is lexicographic, so when the cap is hit the sample is
# dominated by the first organisms in `organisms` — consider random pair sampling.
for org_a, org_b in tqdm(
	combinations(organisms, 2),
	total=math.comb(len(organisms), 2),
	desc="Inter-species",
):
	if pairs_done >= MAX_INTERSPECIES_PAIRS:
		break

	seqs_a = sequences_by_organism.get(org_a, [])
	seqs_b = sequences_by_organism.get(org_b, [])

	if not seqs_a or not seqs_b:
		continue

	a = random.choice(seqs_a)
	b = random.choice(seqs_b)

	sim = dna_similarity(a, b)
	inter_results.append({
		"organism_a": org_a,
		"organism_b": org_b,
		"similarity": sim
	})

	pairs_done += 1

In [None]:
# Explicit columns keep ["similarity"] valid even when inter_results is empty
# (e.g. fewer than two organisms).
inter_df = pd.DataFrame(inter_results, columns=["organism_a", "organism_b", "similarity"])

print("INTER-SPECIES")
print(inter_df["similarity"].describe())

In [None]:
# Overlaid, density-normalised histograms so the two samples are comparable
# despite their different sizes.
fig, ax = plt.subplots()
for sim_df, label in ((intra_df, "Intra-species"), (inter_df, "Inter-species")):
	ax.hist(sim_df["similarity"], bins=50, alpha=0.6, label=label, density=True)

ax.set_xlabel("Genomic similarity")
ax.set_ylabel("Density")
ax.legend()
ax.set_title("Intra vs Inter species similarity distribution")
plt.show()


# Train / Test / Validation Split by Organism

In [None]:
def split_by_organism(
	df,
	train_ratio=0.85,
	test_ratio=0.10,
	val_ratio=0.05,
	seed=None
):
	"""Split ``df`` into train/test/val frames that share no organism.

	The unique values of ``df["organism"]`` are shuffled and cut by the given ratios,
	and every row follows its organism. Split sizes are therefore proportional to
	organism counts, not row counts. A summary of both is printed.

	Args:
		df: DataFrame with an "organism" column.
		train_ratio, test_ratio, val_ratio: fractions of organisms per split; must
			sum to 1. Train and test counts are floored; val takes the remainder.
		seed: optional seed for a private ``random.Random``. ``None`` (the default)
			shuffles with the module-level ``random`` generator, as before.

	Returns:
		(train_df, test_df, val_df), each with a fresh RangeIndex.

	Raises:
		AssertionError: if the ratios do not sum to 1 (within 1e-6).

	NOTE(review): grouping uses the raw organism string, so case/whitespace variants
	of one species could land in different splits.
	"""
	assert abs(train_ratio + test_ratio + val_ratio - 1.0) < 1e-6

	organisms = df["organism"].unique().tolist()
	# The global generator's state depends on every earlier cell that drew from it;
	# passing a seed makes the split independent of that history.
	rng = random if seed is None else random.Random(seed)
	rng.shuffle(organisms)

	n_orgs = len(organisms)

	n_train = int(n_orgs * train_ratio)
	n_test = int(n_orgs * test_ratio)

	train_orgs = set(organisms[:n_train])
	test_orgs = set(organisms[n_train:n_train + n_test])
	val_orgs = set(organisms[n_train + n_test:])

	train_df = df[df["organism"].isin(train_orgs)].reset_index(drop=True)
	test_df = df[df["organism"].isin(test_orgs)].reset_index(drop=True)
	val_df = df[df["organism"].isin(val_orgs)].reset_index(drop=True)

	print("=" * 60)
	print(f"Total organisms: {n_orgs}")
	print()
	print("Organism split:")
	print(f"  Train: {len(train_orgs)} organisms")
	print(f"  Test : {len(test_orgs)} organisms")
	print(f"  Val  : {len(val_orgs)} organisms")
	print()
	print("Sequence split:")
	print(f"  Train: {len(train_df)} sequences")
	print(f"  Test : {len(test_df)} sequences")
	print(f"  Val  : {len(val_df)} sequences")
	print()
	print("Approximate ratios (by sequence):")
	total_seq = len(df)
	# Avoid ZeroDivisionError on an empty frame (every ratio then prints as 0.000).
	denom = total_seq or 1
	print(f"  Train: {len(train_df)/denom:.3f}")
	print(f"  Test : {len(test_df)/denom:.3f}")
	print(f"  Val  : {len(val_df)/denom:.3f}")
	print("=" * 60)

	return train_df, test_df, val_df

In [None]:
# Organism-disjoint split of the genomic-class rows, using the function's default ratios.
train_df, test_df, val_df = split_by_organism(df=genomic)

# Training the Generator

In [None]:
# Model wrapper initialised from the local base GPT-2 checkpoint, seeded with SEED.
llm = DNATranslatorGPT(checkpoint="./storage/models/base/gpt2", seed=SEED)

In [None]:
# One dict per row (column name -> value) for each split.
train_data, test_data, val_data = (
	split.to_dict(orient="records") for split in (train_df, test_df, val_df)
)

In [None]:
def _build_examples(records, desc):
	"""Turn split records into model inputs with ``llm.build_input``, showing progress.

	Each record must provide "sequence", "organism" and "target" keys, which are
	passed as dna_sequence / organism / protein_sequence.
	"""
	return [
		llm.build_input(
			dna_sequence=record["sequence"],
			organism=record["organism"],
			protein_sequence=record["target"]
		)
		for record in tqdm(records, desc=desc)
	]

# One helper call per split instead of three copy-pasted loops.
train_dataset = _build_examples(train_data, "Train Data")
test_dataset = _build_examples(test_data, "Test Data")
eval_dataset = _build_examples(val_data, "Eval Data")

In [None]:
# DNA-string length of every built example, per split.
train_lengths, eval_lengths, test_lengths = (
	[len(example["dna_sequence"]) for example in split]
	for split in (train_dataset, eval_dataset, test_dataset)
)

In [None]:
sns.set_theme(style="whitegrid", palette="muted", font_scale=1.2)

# Overlaid histograms (with KDE) of DNA length for the three splits.
fig, ax = plt.subplots(figsize=(10, 6))
for lengths, colour, split_label in (
	(train_lengths, "skyblue", "Train"),
	(eval_lengths, "green", "Eval"),
	(test_lengths, "salmon", "Test"),
):
	sns.histplot(lengths, kde=True, bins=40, color=colour, label=split_label, ax=ax)

ax.set_title("Sequence Length Distribution", fontsize=16, weight="bold")
ax.set_xlabel("Sequence Length")
ax.set_ylabel("Frequency")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
def safe_similarity(
	pred: str,
	target: str
) -> tuple[int, float]:
	"""Edit distance and normalised similarity between a prediction and its target.

	Args:
		pred: generated sequence.
		target: reference sequence.

	Returns:
		(dist, sim): the Levenshtein distance and 1 - dist / max(len(pred), len(target)),
		clamped to [0, 1]. Two empty strings are identical: (0, 1.0).
	"""
	den = max(len(pred), len(target))
	if den <= 0:
		# Previously returned (1.0, 1.0), i.e. a distance of 1 for identical inputs.
		return 0, 1.0
	dist = editdistance.eval(pred, target)
	# Levenshtein distance never exceeds the longer length, so sim is already in [0, 1];
	# the clamp is only a guard and no longer overwrites the reported distance.
	sim = min(1.0, max(0.0, 1.0 - (dist / den)))
	return dist, sim

In [None]:
def evaluate_on_test(
	llm: DNATranslatorGPT,
	test_dataset: list,
	show_tqdm: bool = True
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
	"""Generate a prediction for each example and score it against its target.

	Args:
		llm: model wrapper; only ``llm.generate(example)`` is used.
		test_dataset: examples from ``llm.build_input``; each must have
			"protein_sequence" and "organism" keys (read below).
		show_tqdm: wrap the iteration in a tqdm progress bar.

	Returns:
		(results, metrics). ``results`` has one dict per example with keys target,
		prediction, edit_distance, similarity, organism, exact_match.
		``metrics`` holds "mean" and "std" (population) of similarity, "n" (number of
		examples) and "exact_match_rate"; the statistics are 0.0 for an empty dataset.
	"""
	results = []
	it = tqdm(test_dataset, desc="Eval (test)", leave=False) if show_tqdm else test_dataset

	for data in it:
		pred = llm.generate(data)
		target = data["protein_sequence"]
		dist, sim = safe_similarity(pred, target)
		results.append({
			"target": target,
			"prediction": pred,
			"edit_distance": dist,
			"similarity": sim,
			"organism": data["organism"],
			"exact_match": int(pred == target)
		})

	sims = [r["similarity"] for r in results]
	exact = [r["exact_match"] for r in results]

	return results, {
		"mean": float(np.mean(sims)) if sims else 0.0,
		"std": float(np.std(sims)) if sims else 0.0,
		# Recorded so the saved metrics.json also carries exact-match accuracy.
		"n": len(results),
		"exact_match_rate": float(np.mean(exact)) if exact else 0.0,
	}

In [None]:
def save_metrics_json(
	path: str,
	payload: dict[str, Any]
) -> None:
	"""Write ``payload`` as indented, non-ASCII-preserving UTF-8 JSON to ``<path>/metrics.json``.

	Creates ``path`` (with parents) when missing and overwrites any existing file.
	"""
	os.makedirs(path, exist_ok=True)
	metrics_file = os.path.join(path, "metrics.json")
	with open(metrics_file, mode="w", encoding="utf-8") as handle:
		json.dump(payload, handle, indent=2, ensure_ascii=False)

In [None]:
# Preparing Datasets
# Pass the train/eval example lists through llm.prepare_dataset (its output format
# is defined by DNATranslatorGPT). test_dataset stays raw for evaluate_on_test.
# NOTE(review): this rebinds the names in place, so re-running the cell would prepare
# already-prepared data — rebuild from the example-building cell first.
train_dataset, eval_dataset = (
	llm.prepare_dataset(split) for split in (train_dataset, eval_dataset)
)

In [None]:
best_mean = -math.inf
best_epoch = -1
epochs_no_improve = 0

history: list[dict[str, Any]] = []

save_last_dir = f"output/{CHECKPOINT_NAME}_last"
save_best_dir = f"output/{CHECKPOINT_NAME}_best"

# results_last.csv / results_best.csv are written directly into output/, so create it
# up front rather than relying on save_pretrained to have made it.
os.makedirs("output", exist_ok=True)

# NOTE(review): checkpoint selection and early stopping are driven by the TEST split,
# so the reported best test similarity is optimistically biased. The validation split
# (eval_dataset, already passed to llm.train) is the usual choice for selection.
# NOTE(review): llm.train runs one epoch per call — confirm it keeps optimizer and
# LR-scheduler state across calls instead of restarting them every global epoch.

best_results = []
# The context manager closes the progress bar even if training raises.
with tqdm(total=MAX_EPOCHS, desc="Global epochs", dynamic_ncols=True) as pbar:
	for epoch_idx in range(MAX_EPOCHS):
		global_epoch = epoch_idx + 1

		print(f"\n[Train] Global epoch {global_epoch}/{MAX_EPOCHS} starting...")

		llm.train(
			train_dataset=train_dataset,
			params=TrainParams(
				epochs=1,
				batch_size=4,
				gradient_accumulation=4,
				lr=3e-5
			),
			eval_dataset=eval_dataset
		)

		print(f"[Eval] Global epoch {global_epoch}: evaluating on test set...")
		results, metrics = evaluate_on_test(llm, test_dataset, show_tqdm=True)

		mean_similarity = metrics["mean"]
		std_similarity = metrics["std"]
		# "±" was previously mis-encoded as "Â±".
		print(f"[Eval] Global epoch {global_epoch}: mean similarity = {mean_similarity:.4f} ± {std_similarity:.4f}")

		print(f"[Save] Saving LAST -> {save_last_dir}")
		llm.save_pretrained(save_last_dir)
		pd.DataFrame(results).to_csv("output/results_last.csv")

		# Values reflect the state *before* this epoch's comparison. best_mean starts at
		# -inf, which json.dump would write as the non-standard token -Infinity, so it is
		# stored as null until a real best exists.
		row = {
			"global_epoch": global_epoch,
			"mean_similarity": mean_similarity,
			"std_similarity": std_similarity,
			"best_mean_so_far": best_mean if math.isfinite(best_mean) else None,
			"epochs_no_improve": epochs_no_improve,
		}
		history.append(row)
		save_metrics_json(save_last_dir, {
			"last_epoch": global_epoch,
			"metrics": metrics,
			"history_tail": history[-10:],
		})

		if mean_similarity > best_mean + MIN_DELTA:
			best_mean = mean_similarity
			best_epoch = global_epoch
			epochs_no_improve = 0

			print(f"[Best] New BEST at epoch {global_epoch}: {best_mean:.4f}. Saving BEST -> {save_best_dir}")
			llm.save_pretrained(save_best_dir)
			pd.DataFrame(results).to_csv("output/results_best.csv")
			best_results = results

			save_metrics_json(save_best_dir, {
				"best_epoch": best_epoch,
				"best_mean": best_mean,
				"metrics": metrics,
			})

		else:
			epochs_no_improve += 1
			print(f"[Best] No improvement (best={best_mean:.4f} @ epoch {best_epoch}). Patience: {epochs_no_improve}/{PATIENCE}")

		pbar.set_postfix({
			"mean": f"{mean_similarity:.4f}",
			"best": f"{best_mean:.4f}",
			"no_impr": f"{epochs_no_improve}/{PATIENCE}",
		})
		pbar.update(1)

		if epochs_no_improve >= PATIENCE:
			print(f"[Stop] Early stopping triggered: {PATIENCE} epochs without improvement.")
			break

print("\n[Done]")
print(f"- BEST: epoch {best_epoch}, mean similarity {best_mean:.4f} -> {save_best_dir}")
print(f"- LAST: epoch {history[-1]['global_epoch'] if history else 0} -> {save_last_dir}")

# Evaluation

In [None]:
def compute_metrics_from_similarities(
	similarities: list[float],
	exact_matches: list[int]
) -> dict[str, Any]:
	"""Summarise per-example similarity scores and exact-match flags.

	Args:
		similarities: one similarity score per example.
		exact_matches: one 0/1 exact-match flag per example.

	Returns:
		Dict, in this key order: "n" (int, number of similarity scores); then float
		"mean", "std" (population), "median", "p75", "p90" of the similarities;
		and "exact_match_rate" (mean of the flags). Every statistic is 0.0 for empty
		input. (The old annotation ``dict[str, float]`` was wrong for "n".)
	"""
	sims = np.asarray(similarities, dtype=float)
	exacts = np.asarray(exact_matches, dtype=float)
	has_sims = len(sims) > 0

	return {
		"n": int(len(sims)),
		"mean": float(np.mean(sims)) if has_sims else 0.0,
		"std": float(np.std(sims)) if has_sims else 0.0,
		"median": float(np.median(sims)) if has_sims else 0.0,
		"p75": float(np.percentile(sims, 75)) if has_sims else 0.0,
		"p90": float(np.percentile(sims, 90)) if has_sims else 0.0,
		"exact_match_rate": float(np.mean(exacts)) if len(exacts) else 0.0,
	}

# Metrics over every test example of the best epoch.
all_metrics = compute_metrics_from_similarities(
	[r["similarity"] for r in best_results],
	[r["exact_match"] for r in best_results]
)

# Collapse duplicate targets, keeping the highest-similarity prediction for each.
# NOTE: taking the max per target makes this a best-case (optimistic) view.
best_by_target: dict[str, dict] = {}

for r in best_results:
	t = r["target"]
	if t not in best_by_target or r["similarity"] > best_by_target[t]["similarity"]:
		best_by_target[t] = r

unique_target_results = list(best_by_target.values())

unique_target_metrics = compute_metrics_from_similarities(
	[r["similarity"] for r in unique_target_results],
	[r["exact_match"] for r in unique_target_results]
)

# best_results come from evaluate_on_test on the TEST split; the previous line printed
# len(eval_dataset) (the prepared validation data) under an "Eval" label.
print("Test Dataset Length:", len(test_dataset))

print("All (sequence-weighted):", all_metrics)
print("Unique target:", unique_target_metrics)

target_counts = defaultdict(int)
for r in best_results:
	target_counts[r["target"]] += 1

counts = np.array(list(target_counts.values()), dtype=int)
print("\nTarget repetition diagnostics:")
print("unique targets:", len(target_counts))
print("total sequences:", len(best_results))
# np.median / np.max on an empty array would yield nan / raise; skip when nothing to report.
if counts.size:
	print("median occurrences per target:", int(np.median(counts)))
	print("p90 occurrences per target:", int(np.percentile(counts, 90)))
	print("max occurrences for a target:", int(np.max(counts)))
else:
	print("no test results to summarize")


In [None]:
# How many best-epoch test examples each organism contributes.
org_counts = defaultdict(int)
for record in best_results:
	org_counts[record["organism"]] += 1

vals = np.fromiter(org_counts.values(), dtype=int)

print("Organisms in test:", len(vals))
print("Seqs per organism - median:", int(np.median(vals)))
print("Seqs per organism - p75:", int(np.percentile(vals, 75)))
print("Seqs per organism - p90:", int(np.percentile(vals, 90)))
print("Seqs per organism - max:", int(vals.max()))

# Number of organisms falling in each examples-per-organism bucket.
bins = {
	"1": int((vals == 1).sum()),
	"2-5": int(((vals >= 2) & (vals <= 5)).sum()),
	"6-20": int(((vals >= 6) & (vals <= 20)).sum()),
	">20": int((vals > 20).sum()),
}

print("\nBins (#seqs per organism):")
for bin_label, n_in_bin in bins.items():
	print(f"{bin_label}: {n_in_bin}")