In [None]:
import editdistance
import pandas as pd

from llms.protein_translator.gpt import DnaTranslatorGPT
from schemas.train_params import TrainParams
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [None]:
# Single random seed for the notebook: it is passed both to the model
# constructor and to train_test_split, so both use the same value.
seed: int = 42

In [None]:
# Load the raw examples table. With keep_default_na=False (and no na_values
# given) pandas does not parse any string as NaN, so markers such as "NA" or
# empty fields stay as literal text instead of becoming missing values.
df = pd.read_csv(
    "data/DNATranslator-IntronsExons.csv",
    keep_default_na=False,
)

In [None]:
# Instantiate the project's GPT-based translator from the "gpt2" checkpoint,
# handing it the notebook-wide seed.
llm = DnaTranslatorGPT(checkpoint="gpt2", seed=seed)

In [None]:
# Longest sequence (measured with len(), i.e. in characters) that is kept;
# rows with a longer sequence are skipped.
# NOTE(review): 715 presumably keeps the prompt within the model's context
# window -- confirm and document where the number comes from.
MAX_SEQUENCE_LENGTH = 715

# Turn every short-enough row into a model example via llm.build_input.
# itertuples() yields rows lazily and has no length, so `total` is passed
# explicitly; without it tqdm can only show a bare counter with no ETA.
all_dataset = [
    llm.build_input(
        sequence=row.sequence,
        target=row.target,
        organism=row.organism,
    )
    for row in tqdm(df.itertuples(), total=len(df))
    if len(row.sequence) <= MAX_SEQUENCE_LENGTH
]

# Hold out 5% of the examples for evaluation; the split is shuffled and
# seeded with the notebook-wide seed.
train_dataset, test_dataset = train_test_split(
    all_dataset,
    test_size=0.05,
    random_state=seed,
    shuffle=True,
)


In [None]:
# Hyperparameters for the fine-tuning run, built up front so they can be
# inspected separately from the training call.
train_params = TrainParams(
    epochs=2,
    batch_size=1,
    gradient_accumulation=4,
    lr=2e-5,
)

# Fine-tune on the training split only; the held-out split is not passed in.
llm.train(dataset=train_dataset, params=train_params)

In [None]:
# Persist the fine-tuned model under the name "GPT2-DNATranslator".
# NOTE(review): presumably this writes to a path relative to the current
# working directory -- confirm against DnaTranslatorGPT.save_pretrained.
llm.save_pretrained("GPT2-DNATranslator")

In [None]:
# Number of held-out examples that are scored; only a leading slice of the
# test split is used.
# NOTE(review): presumably limited because generation is slow -- confirm.
EVAL_SAMPLE_SIZE = 30

results = []

for example in tqdm(test_dataset[:EVAL_SAMPLE_SIZE]):
    pred = llm.generate(example)
    target = example["target"]

    # Edit distance between the generated output and the reference.
    dist = editdistance.eval(pred, target)

    # Normalise by the longer of the two so the score falls in [0, 1].
    # When both prediction and target are empty the denominator would be 0;
    # two empty outputs are identical, so that case scores 1.0 instead of
    # raising ZeroDivisionError.
    longest = max(len(pred), len(target))
    similarity = 1 - dist / longest if longest else 1.0

    results.append({
        "target": target,
        "pred": pred,
        "edit_dist": dist,
        "similarity": similarity,
    })

In [None]:
# Build the results table under its own name. Rebinding `df` here would
# overwrite the input DataFrame, so re-running the dataset-building cell
# afterwards would fail because the results table has no `sequence` column.
results_df = pd.DataFrame(results)

# Write one row per evaluated example; the positional index is not saved.
results_df.to_csv("DNATranslatorResults.csv", index=False)