In [None]:
import pandas as pd

from llms.exin_classifier.gpt import ExInClassifierGPT
from llms.exin_classifier.bert import ExInClassifierBERT
from llms.exin_classifier.dnabert import ExInClassifierDNABERT
from schemas.train_params import TrainParams
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [None]:
# Single random seed shared by the classifier constructors and by the
# train/test split further down, so repeated runs use the same value.
seed = 42

In [None]:
# Load the CSV named for the GPT classifier.
# keep_default_na=False stops pandas from turning its default NA markers
# (empty fields, "NA", "NaN", ...) into NaN, so such values stay as strings.
# NOTE(review): unlike the BERT CSV loaded in the next cell, this path has no
# "data/" prefix — confirm where the file actually lives.
# NOTE(review): the next cell also assigns `df`; on a top-to-bottom run that
# later load is the one that remains in effect.
df = pd.read_csv("ExIn-GPT.csv", keep_default_na=False)

In [None]:
# Load the CSV named for the BERT-style classifiers.
# keep_default_na=False stops pandas from turning its default NA markers
# (empty fields, "NA", "NaN", ...) into NaN, so such values stay as strings.
# NOTE(review): this rebinds `df`, replacing whatever the previous cell loaded;
# presumably only one of the two loading cells is meant to be run — confirm.
df = pd.read_csv("data/ExIn-BERTs.csv", keep_default_na=False)

In [None]:
# Destination handed to llm.save_pretrained() once training has finished.
output = "pretrained_model"

In [None]:
# GPT-backed exon/intron classifier built from the "gpt2" checkpoint.
# NOTE(review): the two cells after this one also assign `llm`; on a full
# top-to-bottom run the last assignment executed is the one used later.
llm = ExInClassifierGPT(checkpoint="gpt2", seed=seed)

In [None]:
# BERT-backed exon/intron classifier built from the "bert-base-uncased" checkpoint.
# NOTE(review): this rebinds `llm` (the neighbouring cells assign it too);
# presumably only one classifier cell is meant to be run per session — confirm.
llm = ExInClassifierBERT(checkpoint="bert-base-uncased", seed=seed)

In [None]:
# DNABERT-backed classifier. No checkpoint name is passed to the constructor;
# load_checkpoint() is called on the fresh instance instead.
# NOTE(review): `llm` ends up bound to whatever load_checkpoint() returns —
# presumably the classifier itself; verify against its implementation.
llm = ExInClassifierDNABERT(seed=seed).load_checkpoint()

In [None]:
# Convert the frame to a list with one {column: value} dict per row; the
# dataset-building cell iterates over this list.
data = df.to_dict(orient="records")

In [None]:
# Turn every CSV record into a model input via the classifier's build_input().
# Only the sequence and its target (None when the column is absent) are
# forwarded; the optional organism / gene / before / after arguments are
# currently not passed.
all_dataset = [
    llm.build_input(
        sequence=row["sequence"],
        target=row.get("target"),
    )
    for row in tqdm(data)
]

# Hold out 5% of the examples for evaluation. The split is shuffled with the
# shared seed, so the same partition is produced on every run.
train_dataset, test_dataset = train_test_split(
    all_dataset,
    test_size=0.05,
    shuffle=True,
    random_state=seed,
)

In [None]:
# Training hyper-parameters; the meaning of each field is defined by
# TrainParams in schemas.train_params.
train_params = TrainParams(
    epochs=1,
    batch_size=1,
    gradient_accumulation=1,
    lr=2e-5,
)

# Train the selected classifier on the training portion of the split.
llm.train(dataset=train_dataset, params=train_params)

In [None]:
# Save the trained classifier through its save_pretrained() method, using the
# destination configured in `output`.
llm.save_pretrained(output)

In [None]:
# Run the classifier over the held-out examples, collecting each prediction
# alongside its reference label for the metrics cell that follows.
#
# The loop variable is deliberately not called `data`: that name holds the
# list of CSV records built earlier, and rebinding it here would leave it
# pointing at a single example, so re-running the dataset-building cell
# afterwards would iterate over the wrong object and fail.
#
# NOTE(review): each example is passed to generate() with its "target" field
# still present — confirm generate() does not read the label at inference time.
y_true = []
y_pred = []

for sample in tqdm(test_dataset):
    answer = llm.generate(sample)
    y_pred.append(answer)
    y_true.append(sample["target"])


In [None]:
# Per-class report: each label in turn is treated as the positive class and
# its precision, recall and F1 are printed, followed by a blank line.
# Overall accuracy is printed once at the end.
# NOTE(review): these scorers use their default binary averaging; presumably
# every prediction is exactly "INTRON" or "EXON" — confirm the generative
# model cannot emit anything else.
for positive_label in ("INTRON", "EXON"):
    print(f"{positive_label} class:")
    print("  Precision :", precision_score(y_true, y_pred, pos_label=positive_label))
    print("  Recall    :", recall_score(y_true, y_pred, pos_label=positive_label))
    print("  F1        :", f1_score(y_true, y_pred, pos_label=positive_label))
    print()

print("  Accuracy  :", accuracy_score(y_true, y_pred))