## Libraries and Imports

In [None]:
import pandas as pd

from llms.nucl_classifier.bert import NuclBERT
from schemas.train_params import TrainParams
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from tqdm import tqdm

## Params and Files

In [None]:
# Random seed; passed to the model constructor and to the train/test split.
seed = 42

# Dataset file name only; a later cell prefixes it with the processed-data directory.
csv_path = "nucl-500.csv"
# Name of the tuned model; used as the output folder name when saving.
pretrained_model_name = "NuclBERTModel"

In [None]:
# Resolve the storage locations used by the rest of the notebook.
processed_dir = "./storage/data/processed"
# Only prefix a bare file name: without this guard, re-running the cell would
# prepend the directory a second time and the CSV read would fail.
if not csv_path.startswith(f"{processed_dir}/"):
    csv_path = f"{processed_dir}/{csv_path}"
# Destination folder for the tuned model.
output_path = f"./storage/models/tuned/{pretrained_model_name}"
# Checkpoint path handed to the NuclBERT constructor.
checkpoint = "storage/models/base/bert"

## Reading Dataset

In [None]:
# keep_default_na=False (with no na_values given) means no string value such as
# "NA" or "" is parsed as NaN.
# NOTE(review): presumably needed because sequences/labels can contain such
# tokens -- confirm.
df = pd.read_csv(csv_path, keep_default_na=False)

## Loading the Model

In [None]:
# Build the classifier wrapper from the configured checkpoint; the seed is
# forwarded to the constructor.
llm = NuclBERT(checkpoint=checkpoint, seed=seed)

## Data Processing

In [None]:
# One dict per DataFrame row, keyed by column name.
data = df.to_dict(orient="records")

In [None]:
# Turn every CSV record into a model input example.
# "organism" is read with .get(), so None is passed when the key is absent.
all_dataset = [
    llm.build_input(
        sequence=row["sequence"],
        target=row["target"],
        organism=row.get("organism"),
    )
    for row in tqdm(data)
]

# Hold out 5% of the examples for evaluation; the fixed seed makes the
# shuffled split reproducible.
train_dataset, test_dataset = train_test_split(
    all_dataset, test_size=0.05, random_state=seed, shuffle=True
)

In [None]:
# Training hyper-parameters: a single epoch, one example per batch.
# gradient_accumulation=1 presumably means no accumulation -- confirm against
# TrainParams.
train_params = TrainParams(
    epochs=1,
    batch_size=1,
    gradient_accumulation=1,
    lr=2e-5,
)

# Train on the training split only; the held-out split is used for evaluation.
llm.train(dataset=train_dataset, params=train_params)

In [None]:
# Save the trained model to output_path (the tuned-models folder).
llm.save_pretrained(output_path)

In [None]:
# Run the model on every held-out example and collect its prediction next to
# the reference target.
refs = []
preds = []

# The loop variable is deliberately not called `data`: that name holds the list
# of CSV records built earlier, and overwriting it would break a re-run of the
# dataset-building cell.
for example in tqdm(test_dataset):
    preds.append(llm.generate(example))
    refs.append(example["target"])

In [None]:
# Flatten references and predictions into aligned per-position label lists.
# Assumes each ref/pred is a sequence of single labels (e.g. a string of
# "I"/"E"/"U" characters) -- TODO confirm against llm.generate.
all_refs = []
all_preds = []
for ref_seq, pred_seq in zip(refs, preds):
    # zip stops at the shorter sequence, so only overlapping positions are scored.
    # NOTE(review): positions missing from (or extra in) a prediction are
    # therefore not penalised -- confirm this is intended.
    for ref_label, pred_label in zip(ref_seq, pred_seq):
        all_refs.append(ref_label)
        all_preds.append(pred_label)

acc = accuracy_score(all_refs, all_preds)

print("Accuracy: ", acc)

# Per-class metrics, reported in the order given by `labels`;
# zero_division=0 yields 0 instead of a warning for undefined scores.
labels = ["I", "E", "U"]
precision, recall, f1, support = precision_recall_fscore_support(
    all_refs,
    all_preds,
    labels=labels,
    average=None,
    zero_division=0,
)

for label, label_precision, label_recall, label_f1 in zip(labels, precision, recall, f1):
    print(f"Class: '{label}'")
    print(f"  - Precision: {label_precision:.4f}")
    print(f"  - Recall:   {label_recall:.4f}")
    print(f"  - F1-Score: {label_f1:.4f}\n")