## Libraries and Imports

In [None]:
import os

# Restrict this process to a single GPU. This cell runs before the
# model-related imports so the restriction is in place before any library
# initialises CUDA (which is when the variable is read).
os.environ.update(CUDA_VISIBLE_DEVICES="7")

In [None]:
import random

import pandas as pd
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score)
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from llms.exin_classifier.dnabert import ExInClassifierDNABERT
from schemas.train_params import TrainParams

## Params and Files

In [None]:
# Experiment configuration.
# Shared seed: passed to the model, to train_test_split and to the
# class-balancing sampler.
seed: int = 42

# File name (relative to the processed-data directory) of the input CSV.
csv_path: str = "exin-256.csv"
# Name used for the output directory of the tuned model.
pretrained_model_name: str = "ExInDNABERT2Model"

In [None]:
# Resolve the configured names into concrete paths under ./storage.
# The data-directory prefix is only added once, so re-running this cell
# does not produce "./storage/data/processed/./storage/data/processed/...".
processed_dir = "./storage/data/processed"
if not csv_path.startswith(processed_dir):
    csv_path = f"{processed_dir}/{csv_path}"

# Directory later passed to llm.save_pretrained().
output_path = f"./storage/models/tuned/{pretrained_model_name}"
# Checkpoint directory passed to ExInClassifierDNABERT (no placeholders,
# so a plain string rather than an f-string).
checkpoint = "./storage/models/base/dnabert2"

## Reading the Dataset

In [None]:
# keep_default_na=False: blank cells and strings such as "NA" / "nan" are
# kept as literal text instead of being converted to NaN.
df = pd.read_csv(filepath_or_buffer=csv_path, keep_default_na=False)

## Loading the Model

In [None]:
# Build the DNABERT-based exon/intron classifier from the local checkpoint,
# using the shared experiment seed.
llm = ExInClassifierDNABERT(seed=seed, checkpoint=checkpoint)

## Data Processing

In [None]:
# One dict per CSV row, keyed by column name (e.g. "sequence", "target").
data = df.to_dict("records")

In [None]:
# Turn every CSV record into a model-ready example using the classifier's
# own input builder.
# NOTE(review): the code below assumes build_input() returns a dict that
# carries the label under "target" — confirm against ExInClassifierDNABERT.
all_dataset = [
    llm.build_input(
        sequence=record["sequence"],
        target=record.get("target"),
    )
    for record in tqdm(data)
]

# Hold out 5% of the examples for evaluation; random_state makes the split
# reproducible.
train_dataset, test_dataset = train_test_split(
    all_dataset,
    test_size=0.05,
    random_state=seed,
    shuffle=True,
)

# Balance the training split by downsampling the majority class so EXON and
# INTRON contribute the same number of examples. Examples with any other
# label are not kept for training. The test split is left untouched.
exons = [ex for ex in train_dataset if ex.get("target") == "EXON"]
introns = [ex for ex in train_dataset if ex.get("target") == "INTRON"]

min_len = min(len(exons), len(introns))
if min_len == 0:
    # Without this guard training would silently receive an empty dataset.
    raise ValueError(
        "Cannot balance the training split: found "
        f"{len(exons)} EXON and {len(introns)} INTRON examples."
    )

random.seed(seed)
exons_sample = random.sample(exons, min_len)
introns_sample = random.sample(introns, min_len)

# Interleave the two classes so batches are not ordered by label.
train_dataset = exons_sample + introns_sample
random.shuffle(train_dataset)

In [None]:
# Report split sizes; the train size is that of the class-balanced subset.
for caption, split in (
    ("Train Dataset Len:", train_dataset),
    ("Test Dataset Len:", test_dataset),
):
    print(caption, len(split))

In [None]:
# Fine-tune for a single epoch with batch size 1 and no gradient
# accumulation (effective batch size 1).
train_params = TrainParams(
    epochs=1,
    batch_size=1,
    gradient_accumulation=1,
    lr=2e-5,
    logging_steps=25000,
)
llm.train(dataset=train_dataset, params=train_params)

In [None]:
# Persist the fine-tuned classifier to output_path
# (./storage/models/tuned/<pretrained_model_name>).
llm.save_pretrained(output_path)

In [None]:
# Run inference on every held-out example and collect gold / predicted
# label pairs for the metrics computed afterwards.
y_true = []
y_pred = []

# The loop variable is deliberately not called `data`: that name holds the
# list of CSV records, and overwriting it breaks a re-run of the
# data-processing cell (it would iterate a dict's keys instead of records).
for example in tqdm(test_dataset):
    # NOTE(review): the full example, including its "target", is passed to
    # generate(); confirm generate() ignores the label so it cannot leak.
    y_pred.append(llm.generate(example))
    y_true.append(example["target"])

In [None]:
# Per-class precision / recall / F1, treating each label in turn as the
# positive class, followed by overall accuracy.
metric_rows = (
    ("  Precision :", precision_score),
    ("  Recall    :", recall_score),
    ("  F1        :", f1_score),
)
for label, header in (("INTRON", "INTRON class:"), ("EXON", "EXON class:")):
    print(header)
    for caption, scorer in metric_rows:
        print(caption, scorer(y_true, y_pred, pos_label=label))
    print()
print("  Accuracy  :", accuracy_score(y_true, y_pred))