In [1]:
import os

# Expose GPUs 0-7 to CUDA. Done in the first cell, before the model libraries are imported,
# because CUDA reads this variable when it initialises.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7"})

# Libraries and Imports

In [2]:
import editdistance
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from llms.dna_translator.gpt import DNATranslatorGPT
from schemas.train_params import TrainParams

# Params and Files

In [3]:
# Random seed handed to the model wrapper and to the train/test split.
seed = 12

# Dataset file name (joined onto the processed-data directory in the next cell) and the
# directory name the fine-tuned model is saved under.
csv_path, pretrained_model_name = "tran-1000.csv", "DNATranGPT"

In [4]:
# Resolve on-disk locations for the input CSV, the tuned-model output and the base checkpoint.
# `csv_path` is rewritten from itself, so the prefix is only added once: re-running this
# cell must not produce "./storage/data/processed/./storage/data/processed/...".
processed_dir = "./storage/data/processed"
if not csv_path.startswith(f"{processed_dir}/"):
  csv_path = f"{processed_dir}/{csv_path}"
output_path = f"./storage/models/tuned/{pretrained_model_name}"
checkpoint = "./storage/models/base/gpt2"

# Reading the Dataset

In [5]:
# Load the processed CSV. keep_default_na=False stops pandas from turning strings such as
# "NA" or "" into NaN, so sequence columns stay plain text.
df = pd.read_csv(filepath_or_buffer=csv_path, keep_default_na=False)

# Loading the Model

In [6]:
# Project DNA-to-protein GPT wrapper, loaded from the local base checkpoint and seeded.
llm = DNATranslatorGPT(seed=seed, checkpoint=checkpoint)

# Data Processing

In [7]:
# One plain dict per CSV row, keyed by column name ("sequence", "organism", "target", ...).
data = df.to_dict("records")

In [8]:
# Turn every CSV row into a model example: DNA sequence + organism as input, the "target"
# column as the expected protein sequence.
all_dataset = [
  llm.build_input(
    dna_sequence=row["sequence"],
    organism=row["organism"],
    protein_sequence=row["target"],
  )
  for row in tqdm(data)
]

# Deterministic 95/5 shuffled split, reproducible through `seed`.
train_dataset, test_dataset = train_test_split(
  all_dataset, test_size=0.05, random_state=seed, shuffle=True
)

100%|██████████| 244090/244090 [00:00<00:00, 1914913.92it/s]


# Data Analysis

In [9]:
# Report how many examples ended up in each split.
for label, split in (("Train Dataset Len:", train_dataset), ("Test Dataset Len:", test_dataset)):
  print(label, len(split))

Train Dataset Len: 231885
Test Dataset Len: 12205


In [None]:
# Character length of each example's DNA input, per split (plotted by the histogram cell below).
# NOTE(review): the saved output of this cell is `KeyError: 'dna'`, which predates the
# "dna_sequence" key read here. Re-run and confirm build_input() returns that key.
def _dna_lengths(split):
  return [len(ex["dna_sequence"]) for ex in split]

train_lengths = _dna_lengths(train_dataset)
test_lengths = _dna_lengths(test_dataset)

KeyError: 'dna'

In [None]:
# Compare the DNA input-length distributions of the two splits.
# The split is 95/5, so raw counts would flatten the test series. Each histogram is therefore
# normalised to a density, and both use the same bin edges so bars line up and are comparable.
sns.set_theme(style="whitegrid", palette="muted", font_scale=1.2)

shared_bins = np.histogram_bin_edges(np.concatenate([train_lengths, test_lengths]), bins=40)

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(
  train_lengths, kde=True, bins=shared_bins, stat="density",
  color="skyblue", label="Train", ax=ax,
)
sns.histplot(
  test_lengths, kde=True, bins=shared_bins, stat="density",
  color="salmon", label="Test", ax=ax,
)

ax.set_title("Sequence Length Distribution", fontsize=16, weight="bold")
ax.set_xlabel("Sequence Length")
ax.set_ylabel("Density")
ax.legend()
fig.tight_layout()
plt.show()

# Training

In [None]:
# Fine-tuning hyper-parameters, collected in one place so they are easy to tune.
train_params = TrainParams(
  epochs=3,
  batch_size=16,
  gradient_accumulation=1,
  lr=3e-6,
  logging_steps=1000,
)

# Fine-tune on the training split only; the test split is held out for evaluation below.
llm.train(dataset=train_dataset, params=train_params)

# Saving the Model

In [None]:
# Persist the fine-tuned model to ./storage/models/tuned/<pretrained_model_name>.
save_dir = output_path
llm.save_pretrained(save_dir)

# Evaluation on the Test Split (without BLAST)

In [None]:
# Generate a protein for every held-out example and score it against the reference with
# normalised edit-distance similarity: 1 - dist / max(len(pred), len(target)), where 1.0 is
# an exact match. BLAST-based scoring is not computed here.
# NOTE(review): the dict passed to generate() also carries the reference "protein_sequence".
# Confirm generate() builds its prompt without it, to rule out target leakage.
results = []

# Loop variable is `example`, not `data`: reusing `data` would overwrite the CSV records list
# from the data-processing cell and break re-running it.
for example in tqdm(test_dataset):
  pred = llm.generate(example)
  target = example["protein_sequence"]

  dist = editdistance.eval(pred, target)
  longest = max(len(pred), len(target))
  # Two empty strings are identical; avoid a 0/0 ZeroDivisionError.
  similarity = 1 - dist / longest if longest else 1.0

  results.append({
    "target": target,
    "pred": pred,
    "edit_dist": dist,
    "similarity": similarity,
  })

In [None]:
# Overall similarity, then a breakdown by reference protein length.
similarities = [r["similarity"] for r in results]
mean_similarity = np.mean(similarities)
std_similarity = np.std(similarities)

print(f"Mean similarity: {mean_similarity:.4f} ± {std_similarity:.4f}")

# Use a new name instead of `df` so the raw CSV frame from the loading cell stays intact.
results_df = pd.DataFrame(results).assign(length=lambda d: d["target"].str.len())

# The last bin is open-ended and include_lowest keeps length-0 targets. Otherwise those rows
# fall outside every interval (NaN) and are silently dropped from the groupby.
length_bins = [0, 50, 100, 200, 400, np.inf]
results_df.groupby(
  pd.cut(results_df["length"], bins=length_bins, include_lowest=True),
  observed=False,
)["similarity"].mean()