In [257]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration, Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from datasets import load_dataset, concatenate_datasets, ClassLabel, Dataset
import numpy as np
import huggingface_hub
from sklearn.metrics import f1_score
import gc
import torch.nn as nn
from tqdm import tqdm
import pandas as pd

token = 'Nope'
huggingface_hub.login(token)

In [31]:
def add_language_column(dataset, language):
    # lang_map = {'en': 0, 'de': 1, 'am': 2, 'es': 3, 'ru': 4, 'zh': 5, 'ar': 6, 'uk': 7, 'hi': 8}
    return dataset.add_column("Language", [language] * len(dataset))

dataset = load_dataset('textdetox/multilingual_toxic_spans')

datasets_with_language = []
for lang, data in dataset.items():
    dataset_with_language = add_language_column(data, lang)
    datasets_with_language.append(dataset_with_language)

toxic_spans = concatenate_datasets(datasets_with_language)

toxic_spans = toxic_spans.filter(lambda example: example['Negative Connotations'] is not None)

synthdetoxm = load_dataset('s-nlp/synthdetoxm')['train']

In [134]:
def just_tokenize(examples):
    tokenized_input = tokenizer([examples['Sentence']], padding='max_length', max_length=128, return_tensors="pt", is_split_into_words=True)
    
    word_tokens = tokenized_input.tokens()
    word_ids = tokenized_input.word_ids()

    toxic_tokens = tokenizer([examples['Negative Connotations'].replace(',', ' ')], padding='max_length', max_length=128, return_tensors="pt", is_split_into_words=True).tokens()

    return {'Word tokens': word_tokens, 'Toxic tokens': toxic_tokens, 'input_ids': tokenized_input['input_ids'], 'attention_mask': tokenized_input['attention_mask']}

def align_labels_with_tokens(example):
    word_tokens = example['Word tokens']
    toxic_something = example['Toxic tokens']
    toxic_tokens = []
    for tk in toxic_something:
        if tk not in tokenizer.all_special_tokens:
            toxic_tokens.append(tk)
    
    token_classes = []
    for wt in word_tokens:
        if wt in toxic_tokens and wt!='▁' and wt!='' and wt!=' ':
            token_classes.append(1)
        else:
            token_classes.append(0)

    return {'labels': token_classes}


def go_squeeze(examples):
    input_ids = examples['input_ids'].squeeze()
    attention_mask = examples['attention_mask'].squeeze()
    labels = examples['labels'].squeeze()

    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

toxic_spans = toxic_spans.map(
    just_tokenize,
    batched=False,
)

toxic_spans = toxic_spans.map(
    align_labels_with_tokens,
    batched=False,
)

toxic_spans.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

toxic_spans = toxic_spans.map(
    go_squeeze,
    batched=False,
)

Map:   0%|          | 0/8729 [00:00<?, ? examples/s]

In [217]:
# toxic_spans[0]

In [138]:
def preprocess_function(examples):
    inputs = []
    targets = []
    
    for tox, neu, lang in zip(examples['toxic_sentence'], examples['neutral_sentence'], examples['lang']):
        lang_map = {
            'de': 'german',
            'fr': 'french',
            'es': 'spanish',
            'ru': 'russian'
        }
        if tox:  # If toxic text is not empty
            inputs.append("Detoxify and return answer in " + lang_map[lang] + ": " + tox)
            targets.append(neu)
    
    model_inputs = tokenizer(inputs, padding='max_length', truncation=True, max_length=128, return_tensors="pt")
    
    # Tokenize target texts
    labels = tokenizer(targets, padding='max_length', truncation=True, max_length=128, return_tensors="pt").input_ids
    
    model_inputs["labels"] = labels
    
    return model_inputs

# Preprocess the dataset
# train_test = dataset['train'].train_test_split(test_size=0.2, shuffle=True, seed=66)
# train_dataset = train_test['train']
# val_dataset = train_test['test']

synthdetoxm = synthdetoxm.map(preprocess_function, batched=True)

synthdetoxm.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/16000 [00:00<?, ? examples/s]

In [252]:
toxic_spans

Dataset({
    features: ['Sentence', 'Negative Connotations', 'Language', 'Word tokens', 'Toxic tokens', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 8729
})

In [159]:
synthdetoxm

Dataset({
    features: ['toxic_sentence', 'neutral_sentence', 'lang', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 16000
})

In [259]:
tokenizer = AutoTokenizer.from_pretrained('bigscience/mt0-base')

class TwoLossesModel(MT5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.classification_head = nn.Linear(config.d_model, 2)  # Assuming binary classification

model = TwoLossesModel.from_pretrained("bigscience/mt0-base")

Some weights of TwoLossesModel were not initialized from the model checkpoint at bigscience/mt0-base and are newly initialized: ['classification_head.bias', 'classification_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [240]:
df_for_collator = Dataset.from_dict({'cls_input_ids':toxic_spans['input_ids'][:8000], 
                                     'cls_attention_mask': toxic_spans['attention_mask'][:8000], 
                                     'cls_labels': synthdetoxm['labels'][:8000], 
                                     'detox_input_ids': synthdetoxm['input_ids'][:8000], 
                                     'detox_attention_mask': synthdetoxm['attention_mask'][:8000],
                                     'detox_labels': synthdetoxm['labels'][:8000]})

In [180]:
def prepare_dataset(toxic_spans_data, synthdetoxm):
    df_for_collator = pd.DataFrame(np.array([toxic_spans_data['Sentence'][:8000], 
                                         toxic_spans_data['Negative Connotations'][:8000], 
                                         synthdetoxm['toxic_sentence'][:8000], 
                                         synthdetoxm['neutral_sentence'][:8000], 
                                         synthdetoxm['lang'][:8000]]).T,
                               columns=['cls_data', 'cls_labels', 'detox_data', 'detox_labels', 'detox_lang'])
    return Dataset.from_pandas(df_for_collator)


In [241]:
# dataset_prepared = prepare_dataset(toxic_spans_data, synthdetoxm)

train_test = df_for_collator.train_test_split(test_size=0.2, shuffle=True, seed=42)
train_dataset = train_test['train']
eval_dataset = train_test['test']

In [245]:
train_dataset.set_format(type='torch', columns=['cls_input_ids', 'cls_attention_mask', 'cls_labels', 'detox_input_ids', 'detox_attention_mask', 'detox_labels'])
eval_dataset.set_format(type='torch', columns=['cls_input_ids', 'cls_attention_mask', 'cls_labels', 'detox_input_ids', 'detox_attention_mask', 'detox_labels'])

In [321]:
class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Define your classification loss
        self.classification_loss_fn = nn.CrossEntropyLoss()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Unpack inputs
        cls_input_ids = inputs.get("cls_input_ids")
        cls_attention_mask = inputs.get("cls_attention_mask")
        cls_labels = inputs.get("cls_labels")
        detox_input_ids = inputs.get("detox_input_ids")
        detox_attention_mask = inputs.get("detox_attention_mask")
        detox_labels = inputs.get("detox_labels")

        # for i in [cls_input_ids, cls_attention_mask, cls_labels, detox_input_ids, detox_attention_mask, detox_labels]:
        #     print(i.shape)

        # Forward pass through the model for detoxification
        detox_loss = torch.tensor(0.0, device=detox_input_ids.device)
        if detox_labels is not None:
            detox_outputs = model(input_ids=detox_input_ids, attention_mask=detox_attention_mask, labels=detox_labels)
            detox_loss = detox_outputs.loss

        classification_loss = torch.tensor(0.0, device=detox_input_ids.device)
        if cls_labels is not None:
            encoder_outputs = model.encoder(input_ids=cls_input_ids, attention_mask=cls_attention_mask, return_dict=True)
            hidden_states = encoder_outputs.last_hidden_state

            # Mean pooling over the sequence length
            # pooled_states = hidden_states.mean(dim=1)  # [batch_size, d_model]
            # print(pooled_states.shape)
            classification_logits = model.classification_head(hidden_states)

            print(classification_logits.shape, cls_labels.shape)
            classification_loss = self.classification_loss_fn(classification_logits, cls_labels)

        # Combine the losses (you can use a weighted sum if needed)
        total_loss = seq2seq_loss + classification_loss

        if return_outputs:
            return total_loss, outputs

        return total_loss

# Example usage
training_args = Seq2SeqTrainingArguments(
    output_dir="./two_losses_results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    warmup_steps=50,
    logging_steps=20,
    save_total_limit=3,
    metric_for_best_model='loss',
    greater_is_better=False,
    remove_unused_columns=False,
    report_to="none",
    seed=42,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # data_collator=data_collator,
)
# trainer.train()

In [322]:
trainer.train()

torch.Size([8, 128, 2]) torch.Size([8, 128])


RuntimeError: Expected target size [8, 2], got [8, 128]