# Predictions on the HateCheck dataset with a SepHeads model fine-tuned on the MHS data

__Objective:__ load a SepHeads model fine-tuned on the MHS dataset and evaluate it on the HateCheck dataset.

**Notes:**
- For each sample, we want to produce a prediction with each annotator head, in order to simulate each annotator's behaviour via the model.

In [26]:
import os
import sys
import pandas as pd
import numpy as np
import pickle
from tqdm.notebook import tqdm
import datasets
import torch

sys.path.append('../../modules/')

from model_utils import get_deberta_model
from models import DebertaWithAnnotatorHeadsPretrained
from pytorch_utils import send_batch_to_device

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# This device is used below for the tokenizer/model loading call, the
# fine-tuned model and the tokenized batches.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Cache directory handed to `datasets.load_dataset` when fetching HateCheck.
DATASET_DIR = '/data1/shared_datasets/'

# CSV with per-annotator metadata (read with `pd.read_csv` below).
ANNOTATORS_DATA_PATH = '/data1/moscato/personalised-hate-boundaries-data/data/measuring_hate_speech_data_clean/annotators_data_mhs.csv'

# Directory passed as `model_dir` to `get_deberta_model`.
DEBERTA_MODEL_DIR = '/data1/shared_models/'
# Checkpoint of the fine-tuned SepHeads model, loaded with `from_pretrained`.
# The same string is parsed further down to derive the output directory and
# the checkpoint number used in the output file names.
CHECKPOINT_PATH = '/data1/moscato/personalised-hate-boundaries-data/models/mhs/sepheads_model_training_mhs_test_2/checkpoint-1611/'

## Load data

Load HateCheck data.

Notes:
- The `case_id` field is NOT a progressive integer, but it identifies texts uniquely nevertheless.

In [3]:
# Fetch (or read from the local cache) every split of the HateCheck dataset.
hatecheck_splits = datasets.load_dataset(
    path="Paul/hatecheck",
    cache_dir=DATASET_DIR
)

# Only the 'test' split is used; convert it to a pandas DataFrame.
hatecheck_data = hatecheck_splits['test'].to_pandas()

hatecheck_data

Unnamed: 0,functionality,case_id,test_case,label_gold,target_ident,direction,focus_words,focus_lemma,ref_case_id,ref_templ_id,templ_id,case_templ
0,derog_neg_emote_h,1,I hate women.,hateful,women,general,hate,hate,,,1,I hate [IDENTITY_P].
1,derog_neg_emote_h,2,I hate trans people.,hateful,trans people,general,hate,hate,,,1,I hate [IDENTITY_P].
2,derog_neg_emote_h,3,I hate gay people.,hateful,gay people,general,hate,hate,,,1,I hate [IDENTITY_P].
3,derog_neg_emote_h,4,I hate black people.,hateful,black people,general,hate,hate,,,1,I hate [IDENTITY_P].
4,derog_neg_emote_h,5,I hate disabled people.,hateful,disabled people,general,hate,hate,,,1,I hate [IDENTITY_P].
...,...,...,...,...,...,...,...,...,...,...,...,...
3723,spell_leet_h,3897,Gay people are all complete 4ssholes.,hateful,gay people,general,4ssholes,asshole,1232.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.
3724,spell_leet_h,3898,Black people are all complete 4ssholes.,hateful,black people,general,4ssholes,asshole,1233.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.
3725,spell_leet_h,3899,Disabled people are all complete 4ssholes.,hateful,disabled people,general,4ssholes,asshole,1234.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.
3726,spell_leet_h,3900,Muslims are all complete 4ssholes.,hateful,Muslims,general,4ssholes,asshole,1235.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.


Load annotators data.

In [4]:
# Read the per-annotator metadata table from disk.
annotators_data = pd.read_csv(filepath_or_buffer=ANNOTATORS_DATA_PATH)

annotators_data

Unnamed: 0,annotator_id,n_annotated_texts,annotator_severity,annotator_severity_quantile,extreme_annotator_severity_quantile,extreme_annotator
0,20,14,-0.23,0.370947,0.9,False
1,23,13,-0.28,0.328145,0.9,False
2,28,15,-0.06,0.479896,0.9,False
3,38,12,0.10,0.601816,0.9,False
4,39,12,0.81,0.953307,0.9,True
...,...,...,...,...,...,...
766,11062,19,0.92,0.984436,0.9,True
767,11110,12,-0.35,0.276265,0.9,False
768,11111,21,0.37,0.769131,0.9,False
769,11122,13,0.26,0.704280,0.9,False


## Load fine-tuned model

In [5]:
# Only the tokenizer is kept from this call; the second returned object is
# discarded (the fine-tuned model is loaded separately just below).
tokenizer, _ = get_deberta_model(
    device=device,
    model_dir=DEBERTA_MODEL_DIR,
    num_labels=2,
    use_custom_head=False
)

# Switch between the two ways of restoring the fine-tuned SepHeads model.
use_pretrained_object = True

if not use_pretrained_object:
    # NOTE(review): neither `load_sepheads_model_safetensors` nor
    # `annotator_ids` is imported or defined in the visible part of this
    # notebook, so this branch raises a NameError unless they are provided
    # first -- confirm where they should come from before enabling it.
    model = load_sepheads_model_safetensors(
        checkpoint_path=CHECKPOINT_PATH,
        annotator_ids=annotator_ids,
        deberta_model_dir=DEBERTA_MODEL_DIR,
        device=device
    )
else:
    # Restore the model from the checkpoint directory and move it to `device`.
    model = DebertaWithAnnotatorHeadsPretrained.from_pretrained(CHECKPOINT_PATH)
    model = model.to(device=device)

# Put the model in evaluation mode (trailing `;` suppresses the cell output).
model.eval();

# Keep only the annotators listed in `model.annotator_ids`, i.e. the ones
# for which the loaded model exposes a head, and renumber the rows.
annotators_data = (
    annotators_data
    .loc[annotators_data['annotator_id'].isin(model.annotator_ids)]
    .reset_index(drop=True)
)

2025-05-06 09:19:19,164 - get_deberta_model - INFO - Instantiating DeBERTa tokenizer
2025-05-06 09:19:19,637 - get_deberta_model - INFO - Instantiating DeBERTa model with default classification head
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Inference on the HateCheck dataset

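Select all (or a subset of) the extreme annotators and all (or a randomly selected subset of) the non-extreme ones for inference.

**Note:** the selection cell below is currently commented out, so inference is run for every annotator kept in `annotators_data` (i.e. all the annotators the model was trained on).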

In [13]:
# n_extreme_annotators = None
# n_non_extreme_anotators = 'all'

# if n_extreme_annotators is None:
#     selected_annotator_ids = annotators_data[annotators_data['extreme_annotator']]['annotator_id'].tolist()

# if n_non_extreme_anotators == 'all':
#     selected_annotator_ids += annotators_data[~annotators_data['extreme_annotator']]['annotator_id'].tolist()
# elif n_non_extreme_anotators == 'equal':
#     selected_annotator_ids += annotators_data[
#         ~annotators_data['extreme_annotator']
#     ].sample(frac=1.)['annotator_id'].iloc[:len(selected_annotator_ids)].tolist()
# elif isinstance(n_non_extreme_anotators, int):
#     selected_annotator_ids += annotators_data[
#         ~annotators_data['extreme_annotator']
#     ].sample(frac=1.)['annotator_id'].iloc[:n_non_extreme_anotators].tolist()
# else:
#     raise NotImplementedError("n_non_extreme_annotator must be 'all', 'equal' or an integer")

# selected_annotator_ids = pd.Series(selected_annotator_ids, name='annotator_id')

# selected_annotator_ids

In [7]:
# test_batch_size = 16

# text_encodings = {}

# for test_batch in tqdm(np.array_split(hatecheck_data, np.ceil(len(hatecheck_data) / test_batch_size))):
#     with torch.no_grad():
#         latent_rep_batch = model.deberta_encoder(**send_batch_to_device(
#             dict(
#                 **tokenizer(
#                     test_batch['test_case'].tolist(),
#                     padding='max_length',
#                     truncation=True,
#                     max_length=512,
#                     return_tensors='pt'
#                 ),
#                 # **{'annotator_ids': torch.LongTensor([model.annotator_ids[0]] * len(test_batch)).to(device=device)}
#             ),
#             device,
#             return_batch=True
#         ))[0]

#         latent_rep_batch = model.deberta_dropout(model.deberta_pooler(
#             latent_rep_batch
#         ))

#     text_encodings.update({
#         case_id: latent_rep_batch[i, ...]
#         for i, case_id in enumerate(test_batch['case_id'])
#     })

In [8]:
# with open('/data1/moscato/personalised-hate-boundaries-data/models/sepheads_model_training_test_subsampling_2/checkpoint_42903_hateckeck_text_encodings.pkl', 'wb') as f:
#     pickle.dump(text_encodings, f)

In [15]:
# Run every HateCheck text through the shared encoder once, then through the
# classification head of every annotator in `annotators_data`.
#
# Produces:
# - `text_encodings`: list with one tensor of pooled text representations per
#   batch, in the same order as the rows of `hatecheck_data`.
# - `predictions`: DataFrame with one row per (case_id, annotator_id) pair,
#   holding the text representation and the logits predicted by that
#   annotator's head. Rows are ordered by text first, then by annotator.
test_batch_size = 16

prediction_records = []

# Note: text encodings are indexed IN THE SAME ORDER AS THE SAMPLES IN THE
#       HATECHECK DATASET.
text_encodings = []

# Resolve each annotator's classification head once, instead of searching
# `model.annotator_ids` again for every (text, annotator) pair inside the
# loop. The lookup itself is the same as before (`list.index`), so an
# annotator missing from the model still raises a ValueError -- only earlier.
annotator_heads = [
    (
        annotator_id,
        model.classification_heads[model.annotator_ids.index(annotator_id)]
    )
    for annotator_id in annotators_data['annotator_id']
]

# Slice the DataFrame with `iloc` rather than `np.array_split`: splitting a
# DataFrame through NumPy relies on the deprecated `DataFrame.swapaxes` and
# fails on an empty frame. Row order (and hence output order) is unchanged.
batch_starts = range(0, len(hatecheck_data), test_batch_size)

for batch_start in tqdm(batch_starts):
    test_batch = hatecheck_data.iloc[batch_start:batch_start + test_batch_size]

    with torch.no_grad():
        # Tokenize the batch; every text is padded/truncated to 512 tokens.
        encoded_batch = dict(
            **tokenizer(
                test_batch['test_case'].tolist(),
                padding='max_length',
                truncation=True,
                max_length=512,
                return_tensors='pt'
            ),
        )

        # Compute the latent representation of the text: encoder output,
        # followed by the model's pooler and dropout modules.
        latent_rep_batch = model.deberta_encoder(**send_batch_to_device(
            encoded_batch,
            device,
            return_batch=True
        ))[0]

        latent_rep_batch = model.deberta_dropout(model.deberta_pooler(
            latent_rep_batch
        ))

        # Append text latent representations to the final list.
        text_encodings.append(latent_rep_batch)

        # For each latent representation in the batch and each
        # annotator, compute the predicted logits.
        for i, case_id in enumerate(test_batch['case_id']):
            text_encoding = latent_rep_batch[i, ...]

            for annotator_id, annotator_head in annotator_heads:
                prediction_records.append({
                    'case_id': case_id,
                    'annotator_id': annotator_id,
                    'text_encoding': text_encoding,
                    'logits': annotator_head(text_encoding)
                })

predictions = pd.DataFrame(prediction_records)

  0%|          | 0/233 [00:00<?, ?it/s]

In [16]:
# Inspect the per-(case_id, annotator_id) prediction records built above.
predictions

Unnamed: 0,case_id,annotator_id,text_encoding,logits
0,1,20,"[tensor(0.5236, device='cuda:0'), tensor(-0.01...","[tensor(0.2953, device='cuda:0'), tensor(-1.41..."
1,1,23,"[tensor(0.5236, device='cuda:0'), tensor(-0.01...","[tensor(2.2350, device='cuda:0'), tensor(-0.77..."
2,1,28,"[tensor(0.5236, device='cuda:0'), tensor(-0.01...","[tensor(1.1834, device='cuda:0'), tensor(-2.46..."
3,1,38,"[tensor(0.5236, device='cuda:0'), tensor(-0.01...","[tensor(1.3103, device='cuda:0'), tensor(-3.60..."
4,1,39,"[tensor(0.5236, device='cuda:0'), tensor(-0.01...","[tensor(1.3361, device='cuda:0'), tensor(-2.94..."
...,...,...,...,...
2031755,3901,11018,"[tensor(-0.1637, device='cuda:0'), tensor(2.05...","[tensor(0.7067, device='cuda:0'), tensor(0.605..."
2031756,3901,11062,"[tensor(-0.1637, device='cuda:0'), tensor(2.05...","[tensor(0.3051, device='cuda:0'), tensor(0.621..."
2031757,3901,11111,"[tensor(-0.1637, device='cuda:0'), tensor(2.05...","[tensor(-1.8385, device='cuda:0'), tensor(0.16..."
2031758,3901,11122,"[tensor(-0.1637, device='cuda:0'), tensor(2.05...","[tensor(-1.9693, device='cuda:0'), tensor(0.55..."


In [27]:
# Derive, from CHECKPOINT_PATH, the directory that contains the checkpoint
# folder and the checkpoint number (the integer after the last '-' in the
# checkpoint folder name). Stripping trailing slashes first makes the parsing
# work whether or not CHECKPOINT_PATH ends with '/'; previously a path
# without the trailing slash selected the wrong folder and failed in `int()`.
_checkpoint_dir = CHECKPOINT_PATH.rstrip('/')

# `os.path.join(..., '')` appends the trailing separator, as before.
MODEL_BASE_DIR = os.path.join(os.path.dirname(_checkpoint_dir), '')
LOADED_CHECKPOINT = int(os.path.basename(_checkpoint_dir).split('-')[-1])

# Output artefacts are stored next to the checkpoint folder and tagged with
# the checkpoint number.
TEXT_ENCODINGS_PATH = os.path.join(MODEL_BASE_DIR, f'checkpoint_{LOADED_CHECKPOINT}_hatecheck_text_encodings.pkl')
PREDICTED_LOGITS_PATH = os.path.join(MODEL_BASE_DIR, f'checkpoint_{LOADED_CHECKPOINT}_hatecheck_predicted_logits.pkl')
PREDICTIONS_CATALOG_PATH = os.path.join(MODEL_BASE_DIR, f'checkpoint_{LOADED_CHECKPOINT}_hatecheck_predictions_catalog.csv')
HATECHECK_PREDICTIONS_ANNOTATORS_DATA_PATH = os.path.join(MODEL_BASE_DIR, f'checkpoint_{LOADED_CHECKPOINT}_hatecheck_predictions_annotators_data.csv')

In [30]:
# text_encodings = torch.cat(text_encodings)

# with open(TEXT_ENCODINGS_PATH, 'wb') as f:
#     pickle.dump(text_encodings, f)

In [32]:
# predicted_logits = torch.cat([
#     logits[None, ...]
#     for logits in predictions['logits']
# ])

# with open(PREDICTED_LOGITS_PATH, 'wb') as f:
#     pickle.dump(predicted_logits, f)

In [33]:
# predictions[['case_id', 'annotator_id']].to_csv(
#     PREDICTIONS_CATALOG_PATH,
#     index=None
# )

In [35]:
# annotators_data[annotators_data['annotator_id'].isin(predictions['annotator_id'])].to_csv(
#     HATECHECK_PREDICTIONS_ANNOTATORS_DATA_PATH,
#     index=None
# )

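Load the previously saved predictions from disk: the text encodings, the predicted logits, the catalog mapping each prediction to its `(case_id, annotator_id)` pair, and the data of the annotators the predictions were made for.

**Note:** the cells that write these files are commented out above, so the files must already exist for the next cell to run.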

In [36]:
def _read_pickle(path):
    """Open the file at `path` in binary mode and return the unpickled object."""
    # Only use this on files produced by this project: unpickling can execute
    # arbitrary code.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pickled artefacts: text encodings and predicted logits.
text_encodings_loaded = _read_pickle(TEXT_ENCODINGS_PATH)
predicted_logits_loaded = _read_pickle(PREDICTED_LOGITS_PATH)

# CSV artefacts: the predictions catalog and the corresponding annotators data.
predictions_catalog = pd.read_csv(PREDICTIONS_CATALOG_PATH)
hatecheck_predictions_annotators_data = pd.read_csv(HATECHECK_PREDICTIONS_ANNOTATORS_DATA_PATH)

In [37]:
# Show the tables loaded from disk in the previous cell.
display(predictions_catalog)
display(hatecheck_predictions_annotators_data)

# Report the shapes of the objects loaded from disk. The previous version
# referenced `text_encodings` (a Python list after the inference cell) and
# `predicted_logits` (only assigned in a commented-out cell), so it failed on
# a fresh kernel; the `*_loaded` objects are the ones this section defines.
text_encodings_loaded.shape, predicted_logits_loaded.shape

Unnamed: 0,case_id,annotator_id
0,1,20
1,1,23
2,1,28
3,1,38
4,1,39
...,...,...
2031755,3901,11018
2031756,3901,11062
2031757,3901,11111
2031758,3901,11122


Unnamed: 0,annotator_id,n_annotated_texts,annotator_severity,annotator_severity_quantile,extreme_annotator_severity_quantile,extreme_annotator
0,20,14,-0.23,0.370947,0.9,False
1,23,13,-0.28,0.328145,0.9,False
2,28,15,-0.06,0.479896,0.9,False
3,38,12,0.10,0.601816,0.9,False
4,39,12,0.81,0.953307,0.9,True
...,...,...,...,...,...,...
540,11018,14,-0.57,0.153048,0.9,False
541,11062,19,0.92,0.984436,0.9,True
542,11111,21,0.37,0.769131,0.9,False
543,11122,13,0.26,0.704280,0.9,False


(torch.Size([3728, 768]), torch.Size([2031760, 2]))