In [None]:
import torch
import pandas as pd

from datasets import Dataset

from sklearn.metrics import (
    accuracy_score, 
    precision_recall_fscore_support)

from transformers import (
    DistilBertTokenizerFast,        
    DistilBertForSequenceClassification,  
    Trainer,                     
    TrainingArguments,
)


In [None]:
# Prefix for the data/, models/ and logs/ paths built in later cells.
# NOTE: while this is left empty, those f-string paths begin with "/" and
# therefore resolve from the filesystem root.
main_path = ""  # your path

# Pick the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [ ]:
# Read the cleaned test split from disk.
df_test = pd.read_csv(filepath_or_buffer=f'{main_path}/data/cleaned/test.csv')

# Wrap the frame as a `datasets.Dataset`; note that the test split is the one
# evaluated in this notebook, under the name `val_dataset`.
val_dataset = Dataset.from_pandas(df=df_test)

In [ ]:
# Tokenizer for the same checkpoint name that the model cell loads.
tokenizer = DistilBertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path='distilbert-base-uncased',
)

def tokenize_function(examples: dict) -> dict:
    """Tokenize the 'post' entries of a batch with the notebook-level tokenizer.

    Args:
        examples: Mapping that holds the raw texts under the key 'post'.

    Returns:
        Whatever the tokenizer returns for those texts, truncated and padded
        to a fixed length of 128 tokens.
    """
    posts = examples['post']
    # Fixed-length encoding: every sequence is cut or padded to 128 tokens.
    encoding_options = dict(
        truncation=True,
        padding='max_length',
        max_length=128,
    )
    return tokenizer(posts, **encoding_options)

# Apply the tokenizer batch-wise so the tokenizer's outputs are added to the dataset.
val_dataset = val_dataset.map(
    function=tokenize_function,
    batched=True,
)

In [ ]:
# Have the dataset hand back the listed columns as torch tensors.
# NOTE(review): assumes test.csv provides a 'label' column — confirm.
val_dataset.set_format(
    columns=['input_ids', 'attention_mask', 'label'],
    type='torch',
)

In [ ]:
# Base DistilBERT checkpoint with a 3-label classification head, moved to the
# selected device. No training happens in this notebook.
# NOTE(review): the classification head is presumably freshly initialised,
# since the checkpoint is the base model — confirm.
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=3,
).to(device)

# Freeze every weight: no parameter should track gradients.
for param in model.parameters():
    param.requires_grad_(False)

In [ ]:
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for one evaluation run.

    Args:
        pred: Object exposing ``label_ids`` (the true class ids) and
            ``predictions`` (per-class scores with the class axis last).
            ``predictions`` may also be a tuple of model outputs, in which
            case its first element is taken as the scores.

    Returns:
        dict with the keys 'accuracy', 'f1', 'precision' and 'recall'.
    """
    labels = pred.label_ids
    scores = pred.predictions
    # A model that returns extra outputs besides the logits yields a tuple
    # here; by Trainer convention the logits come first. Without this unwrap
    # the `.argmax` call below would raise AttributeError on the tuple.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = scores.argmax(-1)
    # Weighted averaging accounts for class imbalance; zero_division=0 scores
    # classes that were never predicted as 0 instead of emitting a warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [ ]:
# Settings for the baseline run.
# NOTE(review): the eval/save strategies look relevant only to training runs,
# while this notebook only calls `trainer.evaluate()` — confirm they are needed.
training_args = TrainingArguments(
    # Where artefacts and logs are written.
    output_dir=f'{main_path}/models/baseline',
    logging_dir=f'{main_path}/logs/baseline',
    # Evaluation batch size per device.
    per_device_eval_batch_size=64,
    # Scheduling of logging, evaluation and checkpointing.
    logging_steps=400,
    eval_strategy='epoch',
    save_strategy='epoch',
    # Full precision and a fixed seed.
    fp16=False,
    seed=42,
)

# Evaluation-only Trainer: no training dataset is supplied, only the dataset
# to evaluate on and the metric callback.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=val_dataset,
)

In [ ]:
%%time
# The cell magic above has to stay on the first line of the cell, so comments go below it.
# Run one evaluation pass over the Trainer's eval dataset; the result is kept
# in `metrics` and printed by the next cell.
metrics = trainer.evaluate()

In [ ]:
# Print each reported metric on its own line as "name: value".
for key in metrics:
    value = metrics[key]
    print(f"{key}: {value}")