<a href="https://colab.research.google.com/github/HamdanXI/nlp_adventure/blob/main/bert-paradetox-with-editOps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets
!pip install transformers
!pip install transformers[torch]

In [None]:
from datasets import load_dataset

dataset = load_dataset("HamdanXI/paradetox_with_editOps")

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['en_toxic_comment', 'en_neutral_comment', 'edit_ops'],
        num_rows: 19744
    })
})

In [32]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def find_max_lengths(toxic_comments, neutral_comments, edit_operations):
    max_len_toxic = max(len(tokenizer.encode(comment)) for comment in toxic_comments)
    max_len_neutral = max(len(tokenizer.encode(comment)) for comment in neutral_comments)
    max_len_ops = max(len(tokenizer.encode(" ".join(ops))) for ops in edit_operations)
    return max_len_toxic, max_len_neutral, max_len_ops

In [21]:
toxic_comments = [item['en_toxic_comment'] for item in dataset['train']]
neutral_comment = [item['en_neutral_comment'] for item in dataset['train']]
edit_operations = []

In [29]:
for item in dataset['train']:
    ops_as_string = ' '.join([' '.join(op) for op in item['edit_ops']])
    edit_operations.append(ops_as_string)

In [33]:
max_len_toxic, max_len_neutral, max_len_ops = find_max_lengths(toxic_comments, neutral_comment, edit_operations)

print(f"Maximum length for toxic comments: {max_len_toxic}")
print(f"Maximum length for neutral comments: {max_len_neutral}")
print(f"Maximum length for edit operations: {max_len_ops}")

Maximum length for toxic comments: 35
Maximum length for neutral comments: 35
Maximum length for edit operations: 197


In [34]:
def preprocess_data(samples):
    processed_comments = []
    labels = []

    for sample in samples:
        toxic_comment = sample['en_toxic_comment']
        edit_ops = sample['edit_ops']

        words = toxic_comment.split()

        for operation in sorted(edit_ops, key=lambda op: int(op[2]), reverse=True):
            op_type, text, index = operation[:3]
            index = int(index)

            if op_type == "replace":
                words[index:index+len(text.split())] = ['[MASK]']
            elif op_type == "delete":
                del words[index:index+len(text.split())]
            elif op_type == "insert":
                words.insert(index, '[INSERT]')

        masked_comment = ' '.join(words)

        encoded_comment = tokenizer.encode_plus(
            masked_comment,
            max_length=35,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        label_ids = tokenizer.encode_plus(
            sample['en_neutral_comment'],
            max_length=197,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )['input_ids']

        processed_comments.append(encoded_comment)
        labels.append(label_ids)

    return processed_comments, labels

In [52]:
processed_comments, labels = preprocess_data(dataset["train"])

In [51]:
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [53]:
import torch
from torch.utils.data import DataLoader, TensorDataset

input_ids = torch.cat([c['input_ids'] for c in processed_comments], dim=0)
attention_mask = torch.cat([c['attention_mask'] for c in processed_comments], dim=0)
labels_prepared = torch.cat(labels, dim=0).squeeze()

labels_prepared = labels_prepared[:, :input_ids.size(1)]

dataset_tensor = TensorDataset(input_ids, attention_mask, labels_prepared)
loader = DataLoader(dataset_tensor, batch_size=8, shuffle=True)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
from transformers import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()

for epoch in range(3):
    for batch in loader:
        b_input_ids, b_attention_mask, b_labels = batch
        b_input_ids = b_input_ids.to(device)
        b_attention_mask = b_attention_mask.to(device)
        b_labels = b_labels.to(device)

        outputs = model(b_input_ids, attention_mask=b_attention_mask, labels=b_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print(f"Loss: {loss.item()}")

Loss: 16.14881706237793
Loss: 12.795496940612793
Loss: 17.108789443969727
Loss: 15.307631492614746
Loss: 14.40674114227295
Loss: 16.537792205810547
Loss: 15.448683738708496
Loss: 15.114798545837402
Loss: 14.580933570861816
Loss: 15.276163101196289
Loss: 15.468316078186035
Loss: 15.902769088745117
Loss: 15.4613037109375
Loss: 15.623305320739746
Loss: 16.15754508972168
Loss: 14.628803253173828
Loss: 14.474635124206543
Loss: 15.366050720214844
Loss: 15.661334991455078
Loss: 14.950448036193848
Loss: 14.652499198913574
Loss: 15.817520141601562
Loss: 16.234935760498047
Loss: 16.74676513671875
Loss: 13.299561500549316
Loss: 17.144176483154297
Loss: 15.955431938171387
Loss: 14.132960319519043
Loss: 14.591659545898438
Loss: 16.16320037841797
Loss: 15.052130699157715
Loss: 16.371063232421875
Loss: 15.799471855163574
Loss: 14.796860694885254
Loss: 16.4093074798584
Loss: 15.544042587280273
Loss: 16.86667251586914
Loss: 15.185837745666504
Loss: 15.290634155273438
Loss: 14.576803207397461
Loss: 15.5