In [None]:
import torch
# Use Apple's Metal (MPS) backend when this PyTorch build/machine supports it,
# otherwise fall back to the CPU. The availability flag is printed for the reader.
mps_available = torch.backends.mps.is_available()
print(mps_available)
device = torch.device("mps") if mps_available else torch.device("cpu")

In [None]:
import sqlite3
import pandas as pd

from contextlib import closing

# Project SQLite database, relative to the notebook's working directory.
DB_PATH = '../../../giicg.db'

# closing() guarantees conn.close() runs even if the query raises.
# (A bare `with sqlite3.connect(...) as conn:` only commits/rolls back the
# transaction on exit; it does NOT close the connection.)
with closing(sqlite3.connect(DB_PATH)) as conn:
    all_prompts = pd.read_sql("Select * from expanded_roberta_prompts", conn)

In [None]:
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Label vocabulary for the classifier. JSON is UTF-8 by spec, so read it as
# such rather than relying on the platform's default encoding.
with open("finetune/label2id.json", "r", encoding="utf-8") as f:
    label2id = json.load(f)

model_name = "Mayaryin/gender-prompt-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
num_labels = len(label2id)

# num_labels sizes the classification head (see model_init), and compute_metrics
# compares argmax indices directly with label ids (presumably encoded with this
# mapping), so the ids must be exactly 0..num_labels-1.
assert sorted(label2id.values()) == list(range(num_labels)), (
    "label2id.json ids must be contiguous integers 0..num_labels-1"
)

def model_init():
    """Return a new sequence-classification model loaded from ``model_name``.

    Each call runs ``from_pretrained`` again, so every invocation yields a
    separate model object. This makes the function suitable as a factory
    (e.g. ``Trainer(model_init=...)``; the caller is not shown in this cell).
    The classification head is configured with ``num_labels`` outputs.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
    return model

# tokenize_function leaves sequences unpadded; this collator pads each batch to
# its own longest sequence at batch-assembly time (padding=True is the default).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

def tokenize_function(examples):
    """Tokenize the ``conversational`` text field of ``examples``.

    ``examples`` is a mapping with a ``"conversational"`` key. Presumably it is
    a batch dict from ``datasets.Dataset.map``; the caller is not shown here.
    Sequences are truncated but deliberately not padded: ``data_collator``
    pads each batch later, so batches are only as long as they need to be.
    """
    texts = examples["conversational"]
    encoded = tokenizer(texts, padding=False, truncation=True)
    return encoded

def compute_metrics(pred):
    """Compute classification metrics from an evaluation pass.

    Args:
        pred: An ``EvalPrediction``-like object (the shape of argument a
            ``transformers.Trainer`` passes to ``compute_metrics``) with:
            ``label_ids``, the gold class ids; and ``predictions``, per-class
            scores whose argmax over the last axis is the predicted class. For
            some models ``predictions`` is a tuple whose first element holds
            these scores.

    Returns:
        dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``. The
        last three use support-weighted averaging over classes.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Models that return extra outputs hand back a tuple; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)

    # zero_division=0 reproduces sklearn's default score for a class that is
    # never predicted, but without emitting an UndefinedMetricWarning on every
    # evaluation (which floods notebook/training output).
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1_score(labels, preds, average='weighted', zero_division=0),
        'precision': precision_score(labels, preds, average='weighted', zero_division=0),
        'recall': recall_score(labels, preds, average='weighted', zero_division=0),
    }