In [1]:
import itertools
import string
import os
import pandas as pd
import numpy as np
from ast import literal_eval

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

- conda install -c conda-forge transformers
- conda install -c conda-forge huggingface_hub==0.2.1


In [12]:
from transformers import BertTokenizerFast

In [6]:
SPECIAL_CHARACTERS = string.whitespace

def _contiguous_ranges(span_list):
    """
    Group character-level labels into range of intervals: [1, 2, 3, 5, 6, 7] -> [(1,3), (5,7)].
    Returns begin and end inclusive
    """
    output = []
    for _, span in itertools.groupby(enumerate(span_list), lambda p: p[1] - p[0]):
        span = list(span)
        output.append((span[0][1], span[-1][1]))
    return output


def _fix_spans(spans, text, special_characters=SPECIAL_CHARACTERS, collapse=False):
    """
    Applies minor edits to trim spans and remove singletons.
    If spans begin/end in the middle of a word, correct according to collapse strategy:
        If false, expand spans until word limits; if true collapse until word limits
    """
    cleaned_spans = []
    
    for begin, end in _contiguous_ranges(spans):
        # Remove special characters
        while text[begin] in special_characters and begin < end:
            begin += 1
        while text[end] in special_characters and begin < end:
            end -= 1
            
        # Keep full word
        while 0 < begin < end and text[begin - 1].isalnum():
            offset_move = 1 if collapse else -1
            begin += offset_move
        while len(text) - 1 > end > begin and text[end + 1].isalnum():
            offset_move = -1 if collapse else 1
            end += offset_move
            
        # Remove singletons (only one character)
        if end - begin > 1:
            cleaned_spans.extend(range(begin, end + 1))
            
    return cleaned_spans

def get_sentences_from_data_split(data_path, split):
    """
    @param data_path: base path to load data
    @param split: tsd_train or tsd_trial or tsd_test
    return sentences (List[str]), original_spans (List[List[int]]), fixed_spans (List[List[int]])
    """
    sentences, original_spans, fixed_spans = [], [], []
    data = pd.read_csv(os.path.join(data_path, split + '.csv'))
    
    for i in range(data.shape[0]):
        if split == 'tsd_test':
            span = fixed_span = []
        else:
            span = literal_eval(data['spans'][i])
            fixed_span = _fix_spans(span, data['text'][i])
        
        sentences.append(data['text'][i])
        original_spans.append(span)
        fixed_spans.append(fixed_span)

    return sentences, original_spans, fixed_spans

In [7]:
data = pd.read_csv("data/tsd_trial.csv")

In [8]:
sentences, original_spans, fixed_spans = get_sentences_from_data_split("data", "tsd_trial")

## Tokenize and offset label
- BertTokenizer will return the subwords

In [24]:
def preprocess_and_tokenize(self, sentences, spans):
    all_token_ids = []
    all_offsets = []
    all_att_masks = []
    all_special_masks = []
    all_label_ids = []

    for sentence, span in zip(sentences, spans):
        # Pad to 512. All sentences in the dataset have a lower number of tokens.
        tokenized = self.tokenizer(sentence, padding='max_length', max_length=512, return_attention_mask=True,
                                   return_special_tokens_mask=True,
                                   return_offsets_mapping=True, return_token_type_ids=False)

        all_token_ids.append(tokenized['input_ids'])
        all_offsets.append(tokenized['offset_mapping'])
        all_att_masks.append(tokenized['attention_mask'])
        all_special_masks.append(tokenized['special_tokens_mask'])
        all_label_ids.append([self.off2tox(offset, span) for offset in tokenized['offset_mapping']])

    return all_token_ids, all_offsets, all_att_masks, all_special_masks, all_label_ids


def off2tox(offsets, spans):
    # Padded items
    if offsets == (0, 0):
        return 0
    toxicity = offsets[0] in spans
    return int(toxicity)

In [None]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

tokenized = tokenizer(sentences[0], return_attention_mask=True,
                                   return_special_tokens_mask=True,
                                   return_offsets_mapping=True, return_token_type_ids=False)

In [9]:
sentences[0]

"Because he's a moron and a bigot. It's not any more complicated than that."

In [29]:
fixed_spans[0]

[15, 16, 17, 18, 19, 27, 28, 29, 30, 31]

In [37]:
print(sentences[0][15:20], sentences[0][27:32])

moron bigot


In [40]:
print([sentences[0][i:j+1] for (i,j) in tokenized["offset_mapping"]])

['B', 'Because ', "he'", "'s", 's ', 'a ', 'moro', 'on ', 'and ', 'a ', 'bigo', 'ot.', '. ', "It'", "'s", 's ', 'not ', 'any ', 'more ', 'complicated ', 'than ', 'that.', '.', 'B']


In [26]:
[off2tox(offset, fixed_spans[0]) for offset in tokenized['offset_mapping']]

[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [3]:
class ToxicDataset(Dataset):
    def __init__(self, data_path, split):
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

        self.original_sentences, self.original_spans, self.fixed_spans = \
            self.get_sentences_from_data_split(data_path, split)

        self.token_ids, self.offsets, self.att_masks, self.special_masks, self.labels_ids = \
            self.preprocess_and_tokenize(self.original_sentences, self.fixed_spans)

    def __len__(self):
        return len(self.original_sentences)

    def __getitem__(self, index):
        token_ids = self.token_ids[index]
        offsets = self.offsets[index]
        att_masks = self.att_masks[index]
        special_masks = self.special_masks[index]
        label_ids = self.labels_ids[index]
        original_spans = self.original_spans[index]
        # Add padding to original_spans, which is the only one that is not padded yet.
        # All span sets are shorter than 1024
        original_spans.extend([-1] * (1024 - len(original_spans)))

        # To Tensor
        token_ids = torch.tensor(token_ids, dtype=torch.long)
        offsets = torch.tensor(offsets, dtype=torch.long)
        att_masks = torch.tensor(att_masks, dtype=torch.long)
        special_masks = torch.tensor(special_masks, dtype=torch.long)
        label_ids = torch.tensor(label_ids, dtype=torch.long)
        original_spans = torch.tensor(original_spans, dtype=torch.long)

        return token_ids, att_masks, label_ids, offsets, original_spans, special_masks

    @staticmethod
    def get_sentences_from_data_split(data_path, split):
        sentences = []
        original_spans = []
        fixed_spans = []
        data = pd.read_csv(os.path.join(data_path, split + '.csv'))
        
        for i in range(data.shape[0]):
            if split == 'tsd_test':
                span = fixed_span = []
            else:
                span = literal_eval(data['spans'][i])
                fixed_span = _fix_spans(span, data['text'][i])
            sentences.append(data['text'][i])
            original_spans.append(span)
            fixed_spans.append(fixed_span)

        return sentences, original_spans, fixed_spans

    def preprocess_and_tokenize(self, sentences, spans):
        all_token_ids = []
        all_offsets = []
        all_att_masks = []
        all_special_masks = []
        all_label_ids = []

        for sentence, span in zip(sentences, spans):
            # Pad to 512. All sentences in the dataset have a lower number of tokens.
            tokenized = self.tokenizer(sentence, padding='max_length', max_length=512, return_attention_mask=True,
                                       return_special_tokens_mask=True,
                                       return_offsets_mapping=True, return_token_type_ids=False)

            all_token_ids.append(tokenized['input_ids'])
            all_offsets.append(tokenized['offset_mapping'])
            all_att_masks.append(tokenized['attention_mask'])
            all_special_masks.append(tokenized['special_tokens_mask'])
            all_label_ids.append([self.off2tox(offset, span) for offset in tokenized['offset_mapping']])

        return all_token_ids, all_offsets, all_att_masks, all_special_masks, all_label_ids

    @staticmethod
    def off2tox(offsets, spans):
        # Padded items
        if offsets == (0, 0):
            return 0
        toxicity = offsets[0] in spans
        return int(toxicity)