In [None]:
import torch
from transformers import AlbertTokenizer, AlbertForSequenceClassification, Trainer, TrainingArguments
import sys
import os
import evaluate
from torch.utils.data import random_split

sys.path.append(os.path.join(os.getcwd(),'dataset'))
from TwitterTextDataset import TwitterTextDataset


# Tokenizer for the pretrained "albert-base-v2" checkpoint; it is handed to
# the project's TwitterTextDataset together with the 'data' argument
# (presumably a path to the raw data -- confirm against TwitterTextDataset).
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
dataset = TwitterTextDataset('data', tokenizer)

# 80/20 train/validation split. The split is driven by an explicitly seeded
# generator so that every run (and every resume from a checkpoint) sees the
# same partition; an unseeded split makes metrics incomparable across runs
# and can leak validation examples into training when a run is restarted.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


# Sequence-classification head with two output labels on top of ALBERT.
model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)

# Metrics reported by compute_metrics, and a cache so that each one is
# loaded through `evaluate.load` only once instead of on every evaluation.
_METRIC_NAMES = ("accuracy", "precision", "recall", "f1")
_METRIC_CACHE = {}


def _get_metric(name):
    """Return the ``evaluate`` metric called *name*, loading it at most once."""
    if name not in _METRIC_CACHE:
        _METRIC_CACHE[name] = evaluate.load(name)
    return _METRIC_CACHE[name]


def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall and F1 for one evaluation pass.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``. The
            predicted class is the argmax of ``logits`` over the last axis.

    Returns:
        A dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"``
        and ``"f1"``. Precision, recall and F1 use ``average="binary"``.
    """
    # Resolve all metrics up front so a loading failure surfaces before any
    # scoring work is done.
    metrics = {name: _get_metric(name) for name in _METRIC_NAMES}

    logits, labels = eval_pred
    predictions = torch.argmax(torch.tensor(logits), dim=-1)

    scores = {
        "accuracy": metrics["accuracy"].compute(
            predictions=predictions, references=labels
        )["accuracy"],
    }
    # The remaining metrics share the same call shape and averaging mode.
    for name in ("precision", "recall", "f1"):
        scores[name] = metrics[name].compute(
            predictions=predictions, references=labels, average="binary"
        )[name]
    return scores

# Training configuration, grouped by concern and then unpacked into
# TrainingArguments.
_training_config = {
    # Where checkpoints and logs are written.
    "output_dir": "./results-abert-base-v2",
    "logging_dir": "./logs",
    "logging_steps": 10,
    # Evaluate and checkpoint on the same schedule (once per epoch), and
    # request that the best checkpoint be reloaded when training finishes.
    "eval_strategy": "epoch",
    "save_strategy": "epoch",
    "load_best_model_at_end": True,
    # Optimisation settings.
    "num_train_epochs": 3,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "weight_decay": 0.01,
}
training_args = TrainingArguments(**_training_config)

# Wire the model, its configuration, both dataset splits and the metric
# callback into a single Trainer instance.
trainer = Trainer(
    # What to train and how.
    model=model,
    args=training_args,
    # Scoring callback passed to the Trainer for evaluation.
    compute_metrics=compute_metrics,
    # Data: training split and held-out validation split.
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Run fine-tuning.
trainer.train()

# Score the trained model on the evaluation dataset and report the result.
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)

# Persist the model weights and the tokenizer side by side in one directory,
# model first, then tokenizer.
for _component in (model, tokenizer):
    _component.save_pretrained("albert_base_v2_model")

print("Training complete. Model saved.")
