In [1]:
from tqdm.auto import tqdm
from torch import nn
import torch

# my imports
import sys
sys.path.insert(1, '../')
from src.data.load_data import load_tokenized_data
from src.models.roberta_toxicity_classifier import RTCModel

Some weights of the model checkpoint at SkolkovoInstitute/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
# Instantiate the project's toxicity-classifier wrapper (src.models.roberta_toxicity_classifier).
model = RTCModel()

# Smoke test on a single raw string; as the cell's last expression its return
# value is displayed. NOTE(review): the two numbers shown are presumably the
# raw (non-toxic, toxic) logits — confirm against RTCModel.
model('you are amazing')

Some weights of the model checkpoint at SkolkovoInstitute/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[4.656709671020508, -4.911500453948975]

In [3]:
# Load the dataset through the project loader, handing it the classifier's own
# tokenizer. NOTE(review): presumably this tokenizes the raw TSV once and reuses
# `cache_path` on later runs — confirm in src.data.load_data.
df = load_tokenized_data(
    path='../data/raw/filtered.tsv',
    cache_path='../data/processed/tokenized_roberta.tsv',
    tokenizer=model.tokenizer,
    flatten=True,
)

In [4]:
class ToxicDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each row's ``text`` value with its ``toxicity`` value.

    Every item is a dict with two keys: ``'input_ids'`` (taken from the
    ``text`` column) and ``'labels'`` (taken from the ``toxicity`` column).
    """

    def __init__(self, dataframe):
        """Build one item dict per row of ``dataframe``, eagerly.

        Args:
            dataframe: frame exposing a ``text`` column and a ``toxicity``
                column. NOTE(review): ``text`` presumably already holds
                token-id sequences rather than raw strings — confirm against
                the loader that produces the frame.
        """
        self.raw_data = dataframe

        self.texts = dataframe['text'].tolist()
        self.targets = dataframe['toxicity'].tolist()

        # Pair the two columns positionally; zip stops at the shorter one.
        self.inputs = [
            {'input_ids': token_ids, 'labels': score}
            for token_ids, score in zip(self.texts, self.targets)
        ]

    def __len__(self):
        """Return the number of prepared items."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the prepared item dict stored at position ``idx``."""
        return self.inputs[idx]

In [5]:
# Wrap the loaded dataframe so each row is available as an
# {'input_ids': ..., 'labels': ...} dict.
dataset = ToxicDataset(df)

In [7]:
# Spot-check the first few samples one at a time and print every sample where
# the classifier's decision disagrees with the thresholded dataset label.
MAX_PREVIEW_SAMPLES = 11   # same number of samples the original counter-based loop examined
PREVIEW_THRESHOLD = 0.5    # a label above this value counts as toxic

softmax = nn.Softmax(dim=1)
n_preview = min(MAX_PREVIEW_SAMPLES, len(dataset))

# Inference only: no_grad avoids building an autograd graph for every sample.
with torch.no_grad():
    for idx in tqdm(range(n_preview), total=n_preview):
        sample = dataset[idx]

        # reshape(1, -1) turns the flat id list into a batch containing one sequence.
        input_ids = torch.tensor(sample['input_ids']).reshape(1, -1)
        logits = model(input_ids).logits
        output = softmax(logits)

        # Column 1 is treated as the "toxic" score and column 0 as "non-toxic".
        roberta_toxic = bool(output[0, 1] > output[0, 0])
        label_toxic = bool(sample['labels'] > PREVIEW_THRESHOLD)

        if roberta_toxic != label_toxic:
            # Decode with the tokenizer owned by the model: a bare `tokenizer`
            # name is never defined in this notebook and fails on a fresh kernel.
            text = model.tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
            readable_output = output[0].tolist()
            print(f'`{text}`: output {readable_output} vs. label {sample["labels"]:0.2f}')
    

  0%|          | 0/1155554 [00:00<?, ?it/s]

`I'm not gonna have a child......with the same genetic disorder as me who's gonna die. L...`: output [0.8488573431968689, 0.15114262700080872] vs. label 0.95
`Briggs, what the hell is going on?`: output [0.5176685452461243, 0.48233142495155334] vs. label 0.84


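* Result with a toxicity threshold of 0.5: the classifier disagrees with the thresholded labels on 13.1% of the evaluated samples (discrepancy).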

In [7]:
# Measure, over the whole dataset, the fraction of samples on which the
# classifier's decision differs from the thresholded dataset label.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=model.collate_batch)

total_discrepancy = 0   # number of samples where model and label disagree so far
total = 0               # number of samples evaluated so far
toxicity_threshold = 0.5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Evaluation only: no_grad stops autograd from recording a graph for each batch,
# which would otherwise cost memory and time for no benefit.
with torch.no_grad(), tqdm(dataloader, total=len(dataloader), desc='Evaluating') as pb:
    for samples in pb:
        logits = model(samples)

        # Comparing raw logits gives the same decision as comparing softmax
        # probabilities, because softmax preserves ordering within a row.
        roberta_toxic = logits[:, 1] > logits[:, 0]
        label_toxic = samples['labels'].to(device) > toxicity_threshold

        # XOR is True exactly where the two boolean decisions differ.
        total_discrepancy += torch.sum(roberta_toxic ^ label_toxic).item()
        total += len(samples.input_ids)

        # Running disagreement rate, shown live in the progress bar.
        pb.set_postfix({'Discrepancy': total_discrepancy / total})

Evaluating:   0%|          | 0/36112 [00:00<?, ?it/s]

KeyboardInterrupt: 