In [None]:
import datasets 
from transformers import PreTrainedTokenizerFast
from transformers import Trainer, TrainingArguments
from transformers import RobertaForSequenceClassification
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import os 
import warnings
from collections import defaultdict

# Length passed as `max_length` to the tokenizer in handle_sample.
MAX_LENGTH = 64

# Set WANDB_MODE before the Trainer (which reports to "wandb") is created.
os.environ.update(WANDB_MODE="offline")

# Install a blanket "ignore" filter for every warning category.
warnings.simplefilter("ignore")

In [None]:
def handle_sample(sample):
    """Tokenize a batch and flatten the result into one row per tokenizer entry.

    Args:
        sample: batch mapping with parallel sequences under 'text' and 'label'.

    Returns:
        A plain dict of lists. It holds one list per key returned by the
        tokenizer plus a 'label' list; every entry the tokenizer produced for
        a text becomes its own row and carries that text's label.
    """
    columns = defaultdict(list)

    for text, label in zip(sample['text'], sample['label']):
        encoding = tokenizer(
            text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding='max_length',
            return_overflowing_tokens=True,
            return_special_tokens_mask=True,
        )
        n_rows = len(encoding['input_ids'])

        # Copy row `row` of every tokenizer field, then repeat the label so
        # all columns stay the same length.
        for row in range(n_rows):
            for key, values in encoding.items():
                columns[key].append(values[row])
            columns['label'].append(label)

    return dict(columns)

# Raw dataset and tokenizer are both read from the parent directory.
dataset = datasets.load_from_disk("../data/raw")
tokenizer = PreTrainedTokenizerFast.from_pretrained("../MalBERTa")

# Apply handle_sample in batches. The columns named by the 'test' split are
# removed, so only the columns handle_sample returns remain.
processed_dataset = dataset.map(
    function=handle_sample,
    batched=True,
    batch_size=64,
    num_proc=8,
    remove_columns=dataset['test'].column_names,
)

# Bare expression so the notebook renders the processed dataset.
processed_dataset

In [None]:
def compute_metrics(eval_pred):
    """Score predictions for the Trainer.

    Args:
        eval_pred: a pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` over its last axis.

    Returns:
        Dict with 'accuracy', 'precision', 'recall' and 'f1'. The last three
        use weighted averaging with ``zero_division=0``.
    """
    logits, labels = eval_pred
    predicted = logits.argmax(axis=-1)

    # Options shared by the three averaged scorers.
    weighted = {"average": "weighted", "zero_division": 0}

    scores = {"accuracy": accuracy_score(labels, predicted)}
    for name, scorer in (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    ):
        scores[name] = scorer(labels, predicted, **weighted)
    return scores

# Fine-tune a RoBERTa sequence-classification model on the tokenized splits.

# Load the weights from the same checkpoint directory the tokenizer is loaded
# from ("../MalBERTa"). The previous "./MalBERTa" pointed at a different
# location than the tokenizer and the data, which both live under "../".
# NOTE(review): num_labels is not passed here, so the size of the classifier
# head is whatever the checkpoint config / library default gives. Confirm it
# matches the number of distinct labels in the dataset.
model = RobertaForSequenceClassification.from_pretrained("../MalBERTa")

# One source of truth for the batch size, so the eval cadence below cannot
# drift away from the value handed to TrainingArguments.
TRAIN_BATCH_SIZE = 256
EVALS_PER_EPOCH = 10

# Evaluate about EVALS_PER_EPOCH times over the single training epoch.
# Clamped to at least 1: for a train split smaller than
# TRAIN_BATCH_SIZE * EVALS_PER_EPOCH rows the integer division yields 0,
# which is not a usable step interval.
# NOTE(review): the estimate uses the per-device batch size, i.e. it assumes
# a single device. Revisit if training runs on several.
eval_steps = max(
    1,
    (len(processed_dataset['train']) // TRAIN_BATCH_SIZE) // EVALS_PER_EPOCH,
)

train_args = TrainingArguments(
    output_dir="./MalBERTa-classifier",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=512,
    save_strategy="no",  # no checkpoints are written during training
    eval_strategy="steps",
    logging_steps=100,
    eval_steps=eval_steps,
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    args=train_args,
    processing_class=tokenizer,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['test'],
    compute_metrics=compute_metrics,
)

trainer.train()