In [None]:
# Step up to the parent directory so the relative paths used further down
# ("data/final/...", the `src` package import) are resolved from there.
# NOTE(review): not idempotent - every re-run of this cell climbs one more level.
%cd ..

In [None]:
from sklearn.metrics import cohen_kappa_score
import numpy as np
import pandas as pd
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from tqdm.auto import tqdm
from sklearn.metrics import recall_score, precision_score, f1_score
import nltk
from src.dr_hatespeech.attack import load_attack
# Fetch the NLTK 'punkt' resource. No cell shown here calls an NLTK tokenizer
# directly, so this is presumably needed by the imported project code - TODO confirm.
nltk.download('punkt')
# Show full cell text and up to 200 rows when displaying frames.
# Fully qualified option names are used on purpose: pandas resolves short names
# by regex, and a bare 'max_rows' becomes ambiguous (OptionError) on versions
# that also register `styler.render.max_rows`.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', 200)

## Load data and models

In [None]:
# Read the validation split from its parquet file and preview the first rows.
val_path = Path("data/final/val-off.parquet")
val_df = pd.read_parquet(val_path)
val_df.head()

In [None]:
# Tokenizer and classifier are loaded from the same pretrained checkpoint id.
danlp_checkpoint = 'DaNLP/da-electra-hatespeech-detection'
danlp_tok = AutoTokenizer.from_pretrained(danlp_checkpoint)
danlp_model = AutoModelForSequenceClassification.from_pretrained(danlp_checkpoint)

In [None]:
# `load_attack` is a project helper (src.dr_hatespeech.attack); its two return
# values are used below as the (tokenizer, model) pair passed to `get_logits`.
attack_tok, attack_model = load_attack()

## Run models on dataset

In [None]:
def get_logits(text: str, tok, model) -> torch.Tensor:
    """Score a single document with a sequence classifier.

    Args:
        text: The document to score.
        tok: Tokenizer, called as ``tok(text, return_tensors='pt', ...)``. It must
            expose ``model_max_length`` and is not modified by this function.
        model: Classifier called with ``input_ids`` and ``attention_mask``; the
            first element of its output is taken to be the logits.

    Returns:
        A 0-dim tensor that is positive exactly when the last class is favoured:

        * one output logit: that logit, unchanged;
        * several output logits: the log-odds of the last class against all the
          others, ``logits[-1] - logsumexp(logits[:-1])``. Under a softmax this
          is positive iff the last class has probability above 0.5, so callers
          can threshold at 0 for both kinds of head.
    """
    # A model_max_length above 100_000 is treated as "no usable limit configured";
    # cap this call at 512 tokens instead of overwriting the attribute on the
    # caller's tokenizer. `None` lets the tokenizer fall back to its own limit.
    max_length = 512 if tok.model_max_length > 100_000 else None
    toks = tok(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(input_ids=toks["input_ids"], attention_mask=toks["attention_mask"])[0]
    # Drop the batch dimension: exactly one document is scored per call.
    if len(logits.shape) == 2:
        logits = logits[0]
    if logits.shape[-1] == 1:
        return logits[-1]
    # The raw last-class logit is not a decision score for a multi-class head
    # (it can be > 0 while another class has a larger logit), so return the
    # margin against the remaining classes instead.
    return logits[-1] - torch.logsumexp(logits[:-1], dim=0)

In [None]:
# Score every validation document with the DaNLP model; a positive score is
# read as the positive ("Offensive") prediction.
danlp_preds = torch.stack([
    get_logits(doc, danlp_tok, danlp_model) for doc in tqdm(val_df.text, leave=False)
]) > 0
# Store a plain NumPy bool array so the column dtype does not depend on how
# pandas coerces a torch.Tensor; `danlp_preds` itself stays a tensor.
val_df["danlp_preds"] = danlp_preds.cpu().numpy()
val_df.head()

In [None]:
# Score every validation document with the attack model; a positive score is
# read as the positive ("Offensive") prediction.
attack_preds = torch.stack([
    get_logits(doc, attack_tok, attack_model) for doc in tqdm(val_df.text, leave=False)
]) > 0
# Store a plain NumPy bool array so the column dtype does not depend on how
# pandas coerces a torch.Tensor; `attack_preds` itself stays a tensor.
val_df["attack_preds"] = attack_preds.cpu().numpy()
val_df.head()

## Compare models

In [None]:
# Encode the gold labels as booleans on a fresh frame (val_df is left untouched),
# then keep only the rows where the two models predict differently.
disagreement_df = (
    val_df
    .assign(label=val_df["label"].map({"Not offensive": False, "Offensive": True}))
    .query("danlp_preds != attack_preds")
)
disagreement_df.head()

In [None]:
# Gold-label breakdown of the disagreement rows that the DaNLP model got wrong.
(
    disagreement_df
    .query("danlp_preds != label")["label"]
    .value_counts()
)

In [None]:
# Disagreement rows that the DaNLP model got wrong, displayed for manual inspection.
(
    disagreement_df
    .query("danlp_preds != label")
)

In [None]:
# Gold-label breakdown of the disagreement rows that the attack model got wrong.
(
    disagreement_df
    .query("attack_preds != label")["label"]
    .value_counts()
)

In [None]:
# Disagreement rows that the attack model got wrong, displayed for manual inspection.
(
    disagreement_df
    .query("attack_preds != label")
)