# Test a fine-tuned SepHeads model on the HateCheck dataset

__Objective:__ load a SepHeads model fine-tuned on the Kumar dataset and evaluate it on the HateCheck dataset.

**Notes:**
- For each sample, we want to produce a prediction with each annotator head, in order to simulate each annotator's behaviour via the model.

In [1]:
import sys
import pandas as pd
import numpy as np
import pickle
from tqdm.notebook import tqdm
import datasets
import torch

sys.path.append('../modules/')

from model_utils import get_deberta_model
from models import DebertaWithAnnotatorHeadsPretrained
from pytorch_utils import send_batch_to_device

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

%load_ext autoreload
%autoreload 2

In [2]:
# --- Shared resources -------------------------------------------------------
# Cache directory handed to `datasets.load_dataset`.
DATASET_DIR = '/data1/shared_datasets/'
# Directory handed to `get_deberta_model` as `model_dir`.
DEBERTA_MODEL_DIR = '/data1/shared_models/'

# --- Project-specific inputs ------------------------------------------------
# Checkpoint of the fine-tuned SepHeads model to evaluate.
CHECKPOINT_PATH = '/data1/moscato/personalised-hate-boundaries-data/models/sepheads_model_training_test_subsampling_2/checkpoint-42903/'
# CSV with one row of statistics per annotator.
ANNOTATORS_DATA_PATH = '/data1/moscato/personalised-hate-boundaries-data/data/kumar_perspective_clean/annotators_data.csv'

## Load data

Load HateCheck data.

Notes:
- The `case_id` field is NOT a contiguous sequence of integers (some values are skipped), but it still identifies each text uniquely.

In [3]:
# Fetch HateCheck through the `datasets` library (cached under DATASET_DIR)
# and convert its 'test' split to a pandas DataFrame.
hatecheck_splits = datasets.load_dataset(path="Paul/hatecheck", cache_dir=DATASET_DIR)
hatecheck_data = hatecheck_splits['test'].to_pandas()

hatecheck_data

Unnamed: 0,functionality,case_id,test_case,label_gold,target_ident,direction,focus_words,focus_lemma,ref_case_id,ref_templ_id,templ_id,case_templ
0,derog_neg_emote_h,1,I hate women.,hateful,women,general,hate,hate,,,1,I hate [IDENTITY_P].
1,derog_neg_emote_h,2,I hate trans people.,hateful,trans people,general,hate,hate,,,1,I hate [IDENTITY_P].
2,derog_neg_emote_h,3,I hate gay people.,hateful,gay people,general,hate,hate,,,1,I hate [IDENTITY_P].
3,derog_neg_emote_h,4,I hate black people.,hateful,black people,general,hate,hate,,,1,I hate [IDENTITY_P].
4,derog_neg_emote_h,5,I hate disabled people.,hateful,disabled people,general,hate,hate,,,1,I hate [IDENTITY_P].
...,...,...,...,...,...,...,...,...,...,...,...,...
3723,spell_leet_h,3897,Gay people are all complete 4ssholes.,hateful,gay people,general,4ssholes,asshole,1232.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.
3724,spell_leet_h,3898,Black people are all complete 4ssholes.,hateful,black people,general,4ssholes,asshole,1233.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.
3725,spell_leet_h,3899,Disabled people are all complete 4ssholes.,hateful,disabled people,general,4ssholes,asshole,1234.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.
3726,spell_leet_h,3900,Muslims are all complete 4ssholes.,hateful,Muslims,general,4ssholes,asshole,1235.0,256.0,866,[IDENTITY_P] are all complete 4ssholes.


Load annotators data.

In [5]:
# Per-annotator table read from the CSV at ANNOTATORS_DATA_PATH.
annotators_data = pd.read_csv(filepath_or_buffer=ANNOTATORS_DATA_PATH)

annotators_data

Unnamed: 0,annotator_id,n_annotated_texts,mean_deviation,percent_lower,extreme_annotator
0,0,40,-0.080000,7.500000,False
1,1,60,-0.061667,5.000000,False
2,2,100,-0.041631,9.000000,False
3,3,60,-0.023333,6.666667,False
4,4,20,-0.010000,0.000000,False
...,...,...,...,...,...
17105,17105,20,-0.130000,25.000000,False
17106,17106,20,0.155000,5.000000,False
17107,17107,20,-0.145000,25.000000,False
17108,17108,20,0.205000,5.000000,False


## Load fine-tuned model

In [6]:
# Only the tokenizer returned by `get_deberta_model` is kept; the second
# returned value is discarded.
tokenizer, _ = get_deberta_model(
    use_custom_head=False,
    model_dir=DEBERTA_MODEL_DIR,
    device=device,
    num_labels=2
)

# Toggle between the two ways of restoring the fine-tuned SepHeads model.
use_pretrained_object = True

if not use_pretrained_object:
    # NOTE(review): neither `load_sepheads_model_safetensors` nor
    # `annotator_ids` is defined in the cells visible in this notebook, so
    # this branch looks like it would fail with a NameError if enabled.
    # Confirm where they are supposed to come from before flipping the flag.
    model = load_sepheads_model_safetensors(
        checkpoint_path=CHECKPOINT_PATH,
        annotator_ids=annotator_ids,
        deberta_model_dir=DEBERTA_MODEL_DIR,
        device=device
    )
else:
    model = DebertaWithAnnotatorHeadsPretrained.from_pretrained(CHECKPOINT_PATH)
    model = model.to(device=device)

# Inference only: put the model in evaluation mode (`;` hides the cell output).
model.eval();

# Subset annotators data to the annotators the model was trained on.
trained_annotator_mask = annotators_data['annotator_id'].isin(model.annotator_ids)
annotators_data = annotators_data.loc[trained_annotator_mask].reset_index(drop=True)

2025-03-14 11:29:00,085 - get_deberta_model - INFO - Instantiating DeBERTa tokenizer
2025-03-14 11:29:00,648 - get_deberta_model - INFO - Instantiating DeBERTa model with default classification head
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Inference on the HateCheck dataset

Select all (or a subset of) the extreme annotators and all (or a randomly selected subset of) the non-extreme ones for inference.

In [22]:
# Annotator-selection settings.
# - n_extreme_annotators: None keeps every extreme annotator; an integer keeps
#   a random subset of at most that many.
# - n_non_extreme_anotators: 'all', 'equal' (as many as the selected extreme
#   annotators) or an integer (random subset of at most that many).
n_extreme_annotators = None
n_non_extreme_anotators = 'all'

# Seed used for the random subsets, so that re-running the notebook on a
# fresh kernel selects the same annotators.
ANNOTATOR_SAMPLING_SEED = 42


def _random_subset(annotator_ids, size):
    """Return up to `size` IDs drawn without replacement from `annotator_ids`."""
    shuffled_ids = annotator_ids.sample(frac=1., random_state=ANNOTATOR_SAMPLING_SEED)

    return shuffled_ids.iloc[:size].tolist()


is_extreme = annotators_data['extreme_annotator']
extreme_annotator_ids = annotators_data.loc[is_extreme, 'annotator_id']
non_extreme_annotator_ids = annotators_data.loc[~is_extreme, 'annotator_id']

# Extreme annotators come first in the final list. Every supported setting
# assigns `selected_annotator_ids` here, so the `+=` below never hits an
# undefined name.
if n_extreme_annotators is None:
    selected_annotator_ids = extreme_annotator_ids.tolist()
elif isinstance(n_extreme_annotators, int):
    selected_annotator_ids = _random_subset(extreme_annotator_ids, n_extreme_annotators)
else:
    raise NotImplementedError("n_extreme_annotators must be None or an integer")

if n_non_extreme_anotators == 'all':
    selected_annotator_ids += non_extreme_annotator_ids.tolist()
elif n_non_extreme_anotators == 'equal':
    # As many non-extreme annotators as extreme ones selected above.
    selected_annotator_ids += _random_subset(
        non_extreme_annotator_ids,
        len(selected_annotator_ids)
    )
elif isinstance(n_non_extreme_anotators, int):
    selected_annotator_ids += _random_subset(
        non_extreme_annotator_ids,
        n_non_extreme_anotators
    )
else:
    raise NotImplementedError("n_non_extreme_annotator must be 'all', 'equal' or an integer")

selected_annotator_ids = pd.Series(selected_annotator_ids, name='annotator_id')

selected_annotator_ids

0         603
1         633
2        1856
3        2008
4        2980
        ...  
2878    15955
2879    15968
2880    16243
2881    16688
2882    16903
Name: annotator_id, Length: 2883, dtype: int64

In [214]:
# test_batch_size = 16

# text_encodings = {}

# for test_batch in tqdm(np.array_split(hatecheck_data, np.ceil(len(hatecheck_data) / test_batch_size))):
#     with torch.no_grad():
#         latent_rep_batch = model.deberta_encoder(**send_batch_to_device(
#             dict(
#                 **tokenizer(
#                     test_batch['test_case'].tolist(),
#                     padding='max_length',
#                     truncation=True,
#                     max_length=512,
#                     return_tensors='pt'
#                 ),
#                 # **{'annotator_ids': torch.LongTensor([model.annotator_ids[0]] * len(test_batch)).to(device=device)}
#             ),
#             device,
#             return_batch=True
#         ))[0]

#         latent_rep_batch = model.deberta_dropout(model.deberta_pooler(
#             latent_rep_batch
#         ))

#     text_encodings.update({
#         case_id: latent_rep_batch[i, ...]
#         for i, case_id in enumerate(test_batch['case_id'])
#     })

In [151]:
# with open('/data1/moscato/personalised-hate-boundaries-data/models/sepheads_model_training_test_subsampling_2/checkpoint_42903_hateckeck_text_encodings.pkl', 'wb') as f:
#     pickle.dump(text_encodings, f)

In [42]:
test_batch_size = 16

predictions = []

# Note: text encodings are indexed AS THE SAMPLES IN THE
#       HATECHECK DATASET.
text_encodings = []

# Resolve the classification head of every selected annotator once, up front,
# instead of running a linear `list.index` search for each (text, annotator)
# pair inside the loop. `setdefault` keeps the first position of an ID, which
# is what `list.index` returns. An annotator unknown to the model now raises
# a KeyError here, before any inference is run.
head_position_by_annotator = {}
for head_position, known_annotator_id in enumerate(model.annotator_ids):
    head_position_by_annotator.setdefault(known_annotator_id, head_position)

selected_heads = [
    (annotator_id, model.classification_heads[head_position_by_annotator[annotator_id]])
    for annotator_id in selected_annotator_ids
]

# Slice the DataFrame positionally rather than passing it to `np.array_split`,
# which relies on the deprecated `DataFrame.swapaxes` and emits a
# FutureWarning. Rows are still visited in dataset order.
for batch_start in tqdm(range(0, len(hatecheck_data), test_batch_size)):
    test_batch = hatecheck_data.iloc[batch_start:batch_start + test_batch_size]

    with torch.no_grad():
        # Compute the latent representation of the text.
        latent_rep_batch = model.deberta_encoder(**send_batch_to_device(
            dict(
                **tokenizer(
                    test_batch['test_case'].tolist(),
                    padding='max_length',
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ),
            ),
            device,
            return_batch=True
        ))[0]

        latent_rep_batch = model.deberta_dropout(model.deberta_pooler(
            latent_rep_batch
        ))

        # Append text latent representations to the final list.
        text_encodings.append(latent_rep_batch)

        # For each latent representation in the batch and each
        # annotator, compute the predicted logits.
        for i, case_id in enumerate(test_batch['case_id']):
            text_encoding = latent_rep_batch[i, ...]

            for annotator_id, classification_head in selected_heads:
                predictions.append({
                    'case_id': case_id,
                    'annotator_id': annotator_id,
                    'text_encoding': text_encoding,
                    'logits': classification_head(text_encoding)
                })

predictions = pd.DataFrame(predictions)

  return bound(*args, **kwds)


  0%|          | 0/233 [00:00<?, ?it/s]

In [44]:
# Display the DataFrame of per-(case_id, annotator_id) predictions.
predictions

Unnamed: 0,case_id,annotator_id,text_encoding,logits
0,1,603,"[tensor(-0.1653, device='cuda:0'), tensor(0.13...","[tensor(0.5429, device='cuda:0'), tensor(-0.23..."
1,1,633,"[tensor(-0.1653, device='cuda:0'), tensor(0.13...","[tensor(2.0611, device='cuda:0'), tensor(-0.68..."
2,1,1856,"[tensor(-0.1653, device='cuda:0'), tensor(0.13...","[tensor(0.8629, device='cuda:0'), tensor(-0.85..."
3,1,2008,"[tensor(-0.1653, device='cuda:0'), tensor(0.13...","[tensor(0.9548, device='cuda:0'), tensor(-1.07..."
4,1,2980,"[tensor(-0.1653, device='cuda:0'), tensor(0.13...","[tensor(-0.3418, device='cuda:0'), tensor(0.48..."
...,...,...,...,...
10747819,3901,15955,"[tensor(-0.1543, device='cuda:0'), tensor(-0.1...","[tensor(-0.9926, device='cuda:0'), tensor(2.41..."
10747820,3901,15968,"[tensor(-0.1543, device='cuda:0'), tensor(-0.1...","[tensor(-0.3213, device='cuda:0'), tensor(1.07..."
10747821,3901,16243,"[tensor(-0.1543, device='cuda:0'), tensor(-0.1...","[tensor(-0.1416, device='cuda:0'), tensor(0.04..."
10747822,3901,16688,"[tensor(-0.1543, device='cuda:0'), tensor(-0.1...","[tensor(-0.8853, device='cuda:0'), tensor(0.86..."


In [37]:
# Common prefix of every artifact written by this notebook. It is spelled out
# once so that the four paths below cannot drift apart through a typo.
HATECHECK_ARTIFACTS_PREFIX = (
    '/data1/moscato/personalised-hate-boundaries-data/models/'
    'sepheads_model_training_test_subsampling_2/checkpoint_42903_hatecheck_'
)

# Pickled tensor of text encodings.
TEXT_ENCODINGS_PATH = HATECHECK_ARTIFACTS_PREFIX + 'text_encodings.pkl'
# Pickled tensor of predicted logits, one row per (case_id, annotator_id) pair.
PREDICTED_LOGITS_PATH = HATECHECK_ARTIFACTS_PREFIX + 'predicted_logits.pkl'
# CSV with the (case_id, annotator_id) pair of each prediction.
PREDICTIONS_CATALOG_PATH = HATECHECK_ARTIFACTS_PREFIX + 'predictions_catalog.csv'
# CSV with the annotators data restricted to the annotators used for inference.
HATECHECK_PREDICTIONS_ANNOTATORS_DATA_PATH = HATECHECK_ARTIFACTS_PREFIX + 'predictions_annotators_data.csv'

In [63]:
# Concatenate the per-batch encodings into a single tensor and save it.
# This cell must run: a later cell reads TEXT_ENCODINGS_PATH back from disk.
# The `isinstance` guard makes the cell safe to re-run after
# `text_encodings` has already been turned into a tensor.
if isinstance(text_encodings, list):
    text_encodings = torch.cat(text_encodings)

with open(TEXT_ENCODINGS_PATH, 'wb') as f:
    # Pickle a CPU copy so the file can also be loaded on a machine without a GPU.
    pickle.dump(text_encodings.cpu(), f)

In [64]:
# Stack the per-pair logits into one tensor whose rows follow the row order
# of `predictions`.
predicted_logits = torch.stack(predictions['logits'].tolist())

with open(PREDICTED_LOGITS_PATH, 'wb') as f:
    # Pickle a CPU copy: a tensor pickled while on a CUDA device cannot be
    # restored with a plain `pickle.load` on a machine without a GPU.
    pickle.dump(predicted_logits.cpu(), f)

In [65]:
# Persist which (case_id, annotator_id) pair each prediction belongs to,
# without writing the DataFrame index.
catalog_columns = ['case_id', 'annotator_id']
predictions_catalog_frame = predictions[catalog_columns]
predictions_catalog_frame.to_csv(PREDICTIONS_CATALOG_PATH, index=None)

In [68]:
# Save the annotators data restricted to the annotators that appear in the
# predictions, without writing the DataFrame index.
predicted_annotator_mask = annotators_data['annotator_id'].isin(predictions['annotator_id'])
annotators_data.loc[predicted_annotator_mask].to_csv(
    HATECHECK_PREDICTIONS_ANNOTATORS_DATA_PATH,
    index=None
)

Load the predictions.

In [71]:
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE: unpickling can execute arbitrary code. Only use this on files
    # produced by a trusted source (here, the cells of this notebook).
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


# Tensors saved as pickles.
text_encodings_loaded = _read_pickle(TEXT_ENCODINGS_PATH)
predicted_logits_loaded = _read_pickle(PREDICTED_LOGITS_PATH)

# Tables saved as CSV files.
hatecheck_predictions_annotators_data = pd.read_csv(HATECHECK_PREDICTIONS_ANNOTATORS_DATA_PATH)
predictions_catalog = pd.read_csv(PREDICTIONS_CATALOG_PATH)

In [72]:
# Inspect what was read back from disk. The shapes shown are those of the
# LOADED tensors, so this cell checks the saved artifacts themselves and does
# not depend on the in-memory objects built during inference.
display(predictions_catalog)
display(hatecheck_predictions_annotators_data)

text_encodings_loaded.shape, predicted_logits_loaded.shape

Unnamed: 0,case_id,annotator_id
0,1,603
1,1,633
2,1,1856
3,1,2008
4,1,2980
...,...,...
10747819,3901,15955
10747820,3901,15968
10747821,3901,16243
10747822,3901,16688


Unnamed: 0,annotator_id,n_annotated_texts,mean_deviation,percent_lower,extreme_annotator
0,1,60,-0.061667,5.000000,False
1,2,100,-0.041631,9.000000,False
2,3,60,-0.023333,6.666667,False
3,14,60,0.075000,0.000000,False
4,15,80,0.071364,16.250000,False
...,...,...,...,...,...
2878,16935,20,-0.470000,80.000000,True
2879,16974,20,-0.650000,90.000000,True
2880,16992,20,-0.510000,60.000000,True
2881,17056,20,-0.460000,70.000000,True


(torch.Size([3728, 768]), torch.Size([10747824, 2]))