# Training the SepHeads model on the MHS dataset

__Objective:__ train the SepHeads model (a pre-trained DeBERTa text encoder with annotator-specific classification heads) on the MHS (Measuring Hate Speech) dataset.

In [1]:
import os
import sys
from copy import deepcopy
import pandas as pd
import torch
import datasets
import transformers

sys.path.append('../../modules/')

from custom_logger import get_logger
from data_utils import subsample_dataset
from model_utils import get_deberta_model
from models import DebertaWithAnnotatorHeads, DebertaWithAnnotatorHeadsPretrained, DebertaWithAnnotatorHeadsPretrainedConfig
from training_metrics import compute_metrics_sklearn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = get_logger('sepheads_model_training_mhs')

%load_ext autoreload
%autoreload 2

In [2]:
# Load data.
# `DATASET_NAME` selects which entry of `DATASET_PATHS` is read by the loading code below.
DATASET_NAME = 'mhs'

# Locations of the CSV files for each supported dataset.
# NOTE: the loading code below indexes the selected entry with ['train'] and ['test'], so it only
# works for the dict-valued entries ('kumar', 'mhs'); 'popquorn' is a single path.
DATASET_PATHS = {
    'popquorn': '../data/samples/POPQUORN_offensiveness.csv',
    'kumar': {
        'train': '/data1/moscato/personalised-hate-boundaries-data/data/kumar_perspective_clean/kumar_processed_with_ID_and_full_perspective_clean_train.csv',
        # 'train':  '/data/milanlp/moscato/personal_hate_bounds_data/kumar_processed_with_ID_and_full_perspective_clean.csv',
        'test': '/data1/moscato/personalised-hate-boundaries-data/data/kumar_perspective_clean/kumar_processed_with_ID_and_full_perspective_clean_test.csv',
        'annotators_data': '/data1/moscato/personalised-hate-boundaries-data/data/kumar_perspective_clean/annotators_data.csv'
    },
    # Data used for the original training.
    # NOTE: if this block is re-enabled, add a trailing comma after its closing brace and
    # disable the 'mhs' entry below (a dict literal keeps only the last duplicate key).
    # 'mhs': {
    #     'train': '/data1/moscato/personalised-hate-boundaries-data/data/measuring_hate_speech_data_clean/mhs_clean_train.csv',
    #     'test': '/data1/moscato/personalised-hate-boundaries-data/data/measuring_hate_speech_data_clean/mhs_clean_test.csv'
    # }
    # New data (more samples, 10 samples per annotator).
    'mhs': {
        'train': '/data1/moscato/personalised-hate-boundaries-data/data/measuring_hate_speech_data_clean/mhs_clean_train_10_samples_per_annotator.csv',
        'test': '/data1/moscato/personalised-hate-boundaries-data/data/measuring_hate_speech_data_clean/mhs_clean_test_10_samples_per_annotator.csv'
    }
}

# Paths of the train/test CSV files for the selected dataset.
selected_paths = DATASET_PATHS[DATASET_NAME]

logger.info(f'Loading data from: {DATASET_PATHS[DATASET_NAME]}')

# Read both splits the same way; the 'extreme_annotator' column is discarded on load.
training_data, test_data = (
    pd.read_csv(selected_paths[split_name]).drop(columns=['extreme_annotator'])
    for split_name in ('train', 'test')
)

# The annotator IDs are taken from the training split only and kept in sorted order.
annotator_ids = sorted(training_data['annotator_id'].unique())

logger.info(
    f'N annotators: {len(annotator_ids)} | N training samples: {len(training_data)}'
    f' | N test samples: {len(test_data)}'
)

2025-05-12 19:06:51,582 - sepheads_model_training_mhs - INFO - Loading data from: {'train': '/data1/moscato/personalised-hate-boundaries-data/data/measuring_hate_speech_data_clean/mhs_clean_train_10_samples_per_annotator.csv', 'test': '/data1/moscato/personalised-hate-boundaries-data/data/measuring_hate_speech_data_clean/mhs_clean_test_10_samples_per_annotator.csv'}
2025-05-12 19:06:51,640 - sepheads_model_training_mhs - INFO - N annotators: 2219 | N training samples: 21742 | N test samples: 9379


In [3]:
# Smallest number of (non-null) `text_id` entries any single annotator has,
# reported as (training split, test split).
tuple(
    split_frame.groupby('annotator_id')['text_id'].count().min()
    for split_frame in (training_data, test_data)
)

(np.int64(7), np.int64(3))

In [4]:
# Mean of the `toxic_score` label over all training annotations.
# NOTE(review): if the label is binary 0/1 (two distinct values are logged in the next cell),
# this is the share of annotations labelled toxic — confirm.
training_data['toxic_score'].mean()

np.float64(0.35525710606199984)

In [5]:
# Build the SepHeads model: the pieces of a DeBERTa classifier are reused as the shared
# text encoder, and `DebertaWithAnnotatorHeads` is given the list of annotator IDs.
logger.info('Instantiating the SepHeads model')

# Directory passed to `get_deberta_model` as the model location.
model_dir = '/data1/shared_models/'

# Number of distinct label values present in the training annotations.
num_labels = len(training_data['toxic_score'].unique())

logger.info(f'N labels found in training data: {num_labels}')

# Options forwarded to the DeBERTa factory (defaults noted for reference).
deberta_options = dict(
    use_custom_head=False,
    pooler_out_features=768,  # Default: 768.
    pooler_drop_prob=0.0,  # Default: 0.0
    classifier_drop_prob=0.1,  # Default: 0.1
    use_fast_tokenizer=False,
)

deberta_tokenizer, deberta_model = get_deberta_model(
    num_labels,
    model_dir,
    device,
    **deberta_options
)

# Deep-copy the sub-modules so the SepHeads model does not share objects with
# `deberta_model`, which is deleted right after.
encoder_copy = deepcopy(deberta_model.deberta)
pooler_copy = deepcopy(deberta_model.pooler)
dropout_copy = deepcopy(deberta_model.dropout)

deberta_with_annotator_heads_model = DebertaWithAnnotatorHeads(
    deberta_encoder=encoder_copy,
    deberta_pooler=pooler_copy,
    deberta_dropout=dropout_copy,
    num_labels=num_labels,
    annotator_ids=annotator_ids,
)

# The full DeBERTa classifier is no longer needed once its parts have been copied.
del deberta_model

2025-05-12 19:07:09,139 - sepheads_model_training_mhs - INFO - Instantiating the SepHeads model
2025-05-12 19:07:09,140 - sepheads_model_training_mhs - INFO - N labels found in training data: 2
2025-05-12 19:07:09,140 - sepheads_model_training_mhs - INFO - Instantiating DeBERTa tokenizer
2025-05-12 19:07:09,829 - sepheads_model_training_mhs - INFO - Instantiating DeBERTa model with default classification head
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
def tokenize_function(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the notebook's DeBERTa tokenizer.

    Sequences are truncated to, and padded up to, a fixed length of 512 tokens.
    ``return_tensors`` is deliberately left unset, so the tokenizer's default
    output format is returned.

    Args:
        examples: Mapping with a ``"text"`` entry (used with ``Dataset.map(..., batched=True)``
            in this notebook).

    Returns:
        The output of ``deberta_tokenizer`` for the given text(s).
    """
    fixed_length_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': 512,
    }

    return deberta_tokenizer(examples["text"], **fixed_length_options)

In [7]:
# Create tokenized datasets.
logger.info('Creating tokenized datasets')

# Fixed seed so the pre-shuffled order of the training set is the same on every
# Restart & Run All (an unseeded shuffle gives a different order each time).
SHUFFLE_SEED = 42


def build_tokenized_dataset(annotations, shuffle=False, seed=None):
    """Turn a DataFrame of annotations into a tokenized ``datasets.Dataset``.

    Args:
        annotations: DataFrame with (at least) the columns ``text``, ``toxic_score``
            and ``annotator_id``.
        shuffle: If True, shuffle the rows and flatten the resulting indices mapping.
        seed: Seed forwarded to ``Dataset.shuffle`` (only used when ``shuffle`` is True).

    Returns:
        A dataset holding ``label`` (from ``toxic_score``), ``annotator_ids``
        (from ``annotator_id``) and whatever columns ``tokenize_function`` adds;
        the raw ``text`` column is removed.
    """
    dataset = (
        # Create the dataset object from the DataFrame.
        datasets.Dataset.from_dict(
            annotations[['text', 'toxic_score', 'annotator_id']]
            .rename(
                columns={
                    'toxic_score': 'label',
                    'annotator_id': 'annotator_ids',
                }
            )
            .to_dict(orient='list')
        )
        # Tokenize.
        .map(tokenize_function, batched=True)
        # The raw text is not needed once it has been tokenized.
        .remove_columns("text")
    )

    if shuffle:
        dataset = dataset.shuffle(seed=seed).flatten_indices()

    return dataset


# Only the training split is shuffled; the test split keeps the row order of `test_data`.
tokenized_training_data = build_tokenized_dataset(training_data, shuffle=True, seed=SHUFFLE_SEED)
tokenized_test_data = build_tokenized_dataset(test_data)

2025-05-12 19:07:26,128 - sepheads_model_training_mhs - INFO - Creating tokenized datasets


Map:   0%|          | 0/21742 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/21742 [00:00<?, ? examples/s]

Map:   0%|          | 0/9379 [00:00<?, ? examples/s]

In [8]:
# Experiment identity: used for both the checkpoint directory and the TensorBoard log directory.
EXPERIMENT_ID = 'sepheads_model_training_mhs_enlarged_dataset_1'
MODEL_OUTPUT_DIR = f'/data1/moscato/personalised-hate-boundaries-data/models/mhs/{EXPERIMENT_ID}/'
N_EPOCHS = 10

training_args = transformers.TrainingArguments(
    # --- Output locations ---
    output_dir=MODEL_OUTPUT_DIR,
    logging_dir=f'../../tensorboard_logs/{EXPERIMENT_ID}/',
    push_to_hub=False,
    # --- Optimisation ---
    num_train_epochs=N_EPOCHS,
    learning_rate=5e-5,
    warmup_ratio=0.0,  # For linear warmup of learning rate.
    per_device_train_batch_size=16,  # Default: 8.
    gradient_accumulation_steps=1,  # Default: 1.
    per_device_eval_batch_size=32,  # Default: 8.
    # --- Evaluation and checkpointing (both once per epoch) ---
    eval_strategy="epoch",
    save_strategy="epoch",  # Options: 'no', 'epoch', 'steps' (requires the `save_steps` argument to be set though).
    save_total_limit=2,
    load_best_model_at_end=True,
    # NOTE(review): assumes `compute_metrics_sklearn` reports an "f1" entry — confirm.
    metric_for_best_model="f1",
    # label_names=list(roberta_classifier.config.id2label.keys()),
    # --- Logging ---
    logging_strategy='epoch',
    logging_first_step=True,
    # logging_steps=10,
    disable_tqdm=False
)

# Pads each batch with the DeBERTa tokenizer when batches are assembled.
data_collator = transformers.DataCollatorWithPadding(tokenizer=deberta_tokenizer)

# NOTE(review): the test split doubles as the evaluation set here, so the "best" checkpoint
# is selected on test data — confirm this is intended.
# NOTE(review): recent `transformers` releases deprecate `tokenizer=` in favour of
# `processing_class=`; check the installed version before switching.
trainer = transformers.Trainer(
    model=deberta_with_annotator_heads_model,
    args=training_args,
    train_dataset=tokenized_training_data,
    eval_dataset=tokenized_test_data,
    data_collator=data_collator,
    tokenizer=deberta_tokenizer,
    compute_metrics=compute_metrics_sklearn,
)

  trainer = transformers.Trainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [9]:
training_output = trainer.train()

logger.info('Saving config file in the checkpoints directories')

# The config is the same for every checkpoint, so it is built once.
# `num_labels` is the value the model was instantiated with, which keeps the saved
# config consistent with the trained model instead of relying on a hard-coded 2.
config = DebertaWithAnnotatorHeadsPretrainedConfig(
    num_labels=num_labels,
    # Plain Python ints, so the IDs can be serialized with the config.
    annotator_ids=[int(aid) for aid in annotator_ids],
    deberta_model_dir=model_dir
)

for checkpoint_dir in sorted(os.listdir(MODEL_OUTPUT_DIR)):
    checkpoint_path = os.path.join(MODEL_OUTPUT_DIR, checkpoint_dir)

    # The output directory may also contain plain files (e.g. pickled predictions or CSV
    # exports); only the `checkpoint-<step>` directories should receive a config.
    if not (checkpoint_dir.startswith('checkpoint-') and os.path.isdir(checkpoint_path)):
        continue

    config.save_pretrained(checkpoint_path)



Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

## Check: manually reproduce the metrics seen during training

In [15]:
import numpy as np
from tqdm.notebook import tqdm
import pickle
from sklearn.metrics import classification_report
from pytorch_utils import send_batch_to_device

In [7]:
# Show what the experiment's output directory contains (checkpoint directories and any
# other files saved next to them).
os.listdir(MODEL_OUTPUT_DIR)

['checkpoint-1611',
 'checkpoint-1790',
 'checkpoint_1611_hatecheck_text_encodings.pkl',
 'checkpoint_1611_hatecheck_predicted_logits.pkl',
 'checkpoint_1611_hatecheck_predictions_catalog.csv',
 'checkpoint_1611_hatecheck_predictions_annotators_data.csv']

In [8]:
# Step count identifying the checkpoint to evaluate.
checkpoint_steps = 1611

# Checkpoints live in `checkpoint-<steps>/` sub-directories of the experiment's output directory.
checkpoint_dirname = f'checkpoint-{checkpoint_steps}/'
CHECKPOINT_PATH = os.path.join(MODEL_OUTPUT_DIR, checkpoint_dirname)

**Note:** to load the model as a `DebertaWithAnnotatorHeadsPretrained` object (see the `use_pretrained_object` option below), the checkpoint directory must contain a config file. If training did not save one (e.g. because a `DebertaWithAnnotatorHeads` was trained instead of a `DebertaWithAnnotatorHeadsPretrained`), the config must be created and placed in the checkpoint's directory. In that case, make sure that the annotator IDs are the same as the ones used for training (i.e. that the data above was loaded exactly as during training).

In [42]:
# Optional, disabled by default: create the config needed to load the checkpoint as a
# `DebertaWithAnnotatorHeadsPretrained` object and write it into the checkpoint directory.
# Uncomment only if the checkpoint directory does not contain a config yet, and only if
# `annotator_ids` was built from the same data that was used for training.
# config = DebertaWithAnnotatorHeadsPretrainedConfig(
#     num_labels=2,
#     annotator_ids=[int(aid) for aid in annotator_ids],
#     deberta_model_dir=model_dir
# )

# config.save_pretrained(CHECKPOINT_PATH)

# NOTE(review): `MODEL_DIR` is not defined in the visible notebook; presumably
# `CHECKPOINT_PATH` is meant here — confirm before uncommenting.
# # config = DebertaWithAnnotatorHeadsPretrainedConfig.from_pretrained(MODEL_DIR)

In [9]:
# Rebuild the SepHeads classifier from the selected checkpoint directory and move it to `device`.
# NOTE(review): the "newly initialized" weights warning printed by this call refers to the base
# DeBERTa checkpoint; presumably those weights are then replaced by the fine-tuned ones — confirm.
classifier_loaded = DebertaWithAnnotatorHeadsPretrained.from_pretrained(CHECKPOINT_PATH).to(device=device)

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Logits of each test sample, in the row order of `test_data`.
logits_per_sample = []

# Inference only: no gradients are needed anywhere in the loop.
with torch.no_grad():
    # One forward pass per test annotation (batch size 1).
    for sample in tqdm(test_data.to_dict(orient='records')):
        # Same tokenization settings as `tokenize_function`, but returning PyTorch tensors.
        encoded_text = deberta_tokenizer(
            sample['text'],
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # The model also receives the ID of the annotator who labelled this sample.
        annotator_tensor = torch.LongTensor([sample['annotator_id']]).to(device=device)

        model_inputs = send_batch_to_device(
            dict(**encoded_text, annotator_ids=annotator_tensor),
            device,
            return_batch=True
        )

        logits_per_sample.append(classifier_loaded(**model_inputs)['logits'])

# Stack the per-sample logits into one array and take the highest-scoring class as the prediction.
predicted_logits = torch.cat(logits_per_sample).cpu().numpy()
predicted_toxic_score = np.argmax(predicted_logits, axis=-1)

  0%|          | 0/2257 [00:00<?, ?it/s]

In [17]:
# Destination for the pickled test-set logits.
# Built from `MODEL_OUTPUT_DIR` and `checkpoint_steps` so that it always points at the
# experiment and checkpoint evaluated above; the previous hard-coded path referred to a
# different experiment directory.
PREDICTED_LOGITS_OUTPUT_PATH = os.path.join(
    MODEL_OUTPUT_DIR,
    f'checkpoint_{checkpoint_steps}_mhs_predicted_logits.pkl'
)

# Saving is disabled by default; uncomment to persist the logits.
# with open(PREDICTED_LOGITS_OUTPUT_PATH, 'wb') as f:
#     pickle.dump(predicted_logits, f)

In [20]:
# Compare the manually computed predictions with the test labels (per-class precision/recall/F1).
test_report = classification_report(
    y_true=test_data['toxic_score'],
    y_pred=predicted_toxic_score
)
print(test_report)

              precision    recall  f1-score   support

           0       0.74      0.92      0.82      1211
           1       0.87      0.64      0.74      1046

    accuracy                           0.79      2257
   macro avg       0.81      0.78      0.78      2257
weighted avg       0.80      0.79      0.78      2257

