## Libraries and Imports

In [None]:
import editdistance
import pandas as pd

from llms.dna_translator.gpt import DnaTranslatorGPT
from schemas.train_params import TrainParams
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from utils.blast_analysis import blast_analysis

## Params and Files

In [None]:
# Seed handed to the model constructor and to train_test_split further down.
seed = 42

# File names of the pretraining / fine-tuning CSVs and the names under which
# the pretrained / fine-tuned models are saved; the next cell turns them into paths.
pretrain_csv_name = "prot-1000.csv"
finetune_csv_name = "dna-1000.csv"
pretrained_model_name = "DNAPTGPTModel"
finetuned_model_name = "DNAFTGPTModel"

In [None]:
# Input CSV paths and model output paths, all relative to the current working
# directory ("./storage/...").
pretrain_csv_path = f"./storage/data/processed/{pretrain_csv_name}"
finetune_csv_path = f"./storage/data/processed/{finetune_csv_name}"
pretrained_output_path = f"./storage/models/tuned/{pretrained_model_name}"
finetuned_output_path = f"./storage/models/tuned/{finetuned_model_name}"

## Reading Datasets

In [None]:
# keep_default_na=False stops pandas from turning its default NA markers
# (empty cells, "NA", "NaN", ...) into NaN, so those cells keep their literal text.
pretrain_df = pd.read_csv(pretrain_csv_path, keep_default_na=False)
finetune_df = pd.read_csv(finetune_csv_path, keep_default_na=False)

## Loading the Model

In [None]:
# Instantiate the translator from the "gpt2" checkpoint with the notebook seed.
llm = DnaTranslatorGPT(checkpoint="gpt2", seed=seed)

## Data Processing

In [None]:
# Turn each DataFrame into a list of per-row dicts keyed by column name.
pretrain_data = pretrain_df.to_dict(orient="records")
finetune_data = finetune_df.to_dict(orient="records")

In [None]:
# One pretraining example per row, built from the row's "sequence" column.
pretrain_dataset = [
    llm.build_input_for_pretrain(protein_sequence=row["sequence"])
    for row in tqdm(pretrain_data)
]

# One fine-tuning example per row: "sequence" is passed as the DNA sequence,
# "target" as the protein sequence, and "organism" is optional (None when the
# row has no such key).
finetune_all_dataset = [
    llm.build_input_for_finetune(
        dna_sequence=row["sequence"],
        protein_sequence=row["target"],
        organism=row.get("organism"),
    )
    for row in tqdm(finetune_data)
]

# Hold out 5% of the fine-tuning examples as a test set; the split is shuffled
# with the notebook seed so it is reproducible across runs.
finetune_train_dataset, finetune_test_dataset = train_test_split(
    finetune_all_dataset,
    test_size=0.05,
    random_state=seed,
    shuffle=True
)

In [None]:
# Pretraining stage: a single epoch over the pretraining examples.
llm.pretrain(
    dataset=pretrain_dataset,
    params=TrainParams(
        epochs=1,
        batch_size=1,
        gradient_accumulation=1,
        lr=2e-5,
    ),
)

In [None]:
# Persist the model as it stands after pretraining, before fine-tuning runs.
llm.save_pretrained(pretrained_output_path)

In [None]:
# Fine-tuning stage: a single epoch over the training split only; the test
# split is kept aside for the evaluation loop.
llm.finetune(
    dataset=finetune_train_dataset,
    params=TrainParams(
        epochs=1,
        batch_size=1,
        gradient_accumulation=1,
        lr=2e-5,
    ),
)

In [None]:
# Persist the model again after fine-tuning, under the fine-tuned output path.
llm.save_pretrained(finetuned_output_path)

In [None]:
# Evaluate on the held-out fine-tuning examples: generate a prediction for each
# one and score it against that example's "protein_sequence" entry.
results = []

for data in tqdm(finetune_test_dataset):
    # NOTE(review): assumes llm.generate returns a sequence of the same kind as
    # the target (len() and an edit distance are taken of both) - confirm.
    pred = llm.generate(data)
    target = data["protein_sequence"]

    dist = editdistance.eval(pred, target)
    # Normalise the edit distance by the longer of the two sequences. The extra
    # 1 keeps the denominator non-zero when both sequences are empty (dist is 0
    # in that case, so similarity is 1.0); any non-empty pair is unaffected.
    similarity = 1 - dist / max(len(pred), len(target), 1)

    blast_results = blast_analysis(pred=pred, target=target)

    # blast_analysis must return a mapping: its entries are merged into the row
    # and would override the keys above on a name clash.
    results.append({
        "target": target,
        "pred": pred,
        "edit_dist": dist,
        "similarity": similarity,
        **blast_results,
    })

In [None]:
# Collect the per-example result dicts into a DataFrame, one row per test example.
df = pd.DataFrame(results)