## Required packages
- spacy==2.2.4
    - pip install spacy==2.2.4
- the spaCy English model en_core_web_sm
    - python -m spacy download en_core_web_sm

In [29]:
import pandas as pd
import numpy as np
from ast import literal_eval
import os
import string
import itertools

import spacy

- Utility functions for cleaning the character-level span annotations

In [26]:
# Default set of characters trimmed from the edges of a span by _fix_spans:
# the ASCII whitespace characters in string.whitespace.
SPECIAL_CHARACTERS = string.whitespace

def _contiguous_ranges(span_list):
    """
    Collapse a list of character offsets into runs of consecutive values.

    Example: [1, 2, 3, 5, 6, 7] -> [(1, 3), (5, 7)].
    Each tuple is (first offset, last offset) of a run, both inclusive.
    A new run starts whenever an offset is not exactly one more than the
    offset before it, so the input is expected to be in ascending order.
    """
    ranges = []
    run_start = run_end = None

    for offset in span_list:
        if run_start is None:
            # Very first offset opens the first run.
            run_start = run_end = offset
        elif offset == run_end + 1:
            # Still consecutive: stretch the current run.
            run_end = offset
        else:
            # Gap (or repeat / step back): close the run and open a new one.
            ranges.append((run_start, run_end))
            run_start = run_end = offset

    if run_start is not None:
        ranges.append((run_start, run_end))
    return ranges


def _fix_spans(spans, text, special_characters=SPECIAL_CHARACTERS, collapse=False):
    """
    Clean character-level span annotations so they line up with whole words.

    The offsets in `spans` are grouped into contiguous ranges and each range is
    processed independently:
        1. Characters in `special_characters` are trimmed from both edges.
        2. If an edge sits in the middle of a word (the character just outside
           the range is alphanumeric), the edge is moved to the word limit:
           outwards when `collapse` is false (expand), inwards when `collapse`
           is true.
        3. Ranges whose inclusive end points differ by 1 or less, i.e. ranges of
           one or two characters, are dropped.

    Returns the surviving offsets as a sorted list without duplicates.
    Expanded neighbouring ranges may overlap, so the offsets are collected in a
    set to keep each character at most once.
    """
    # Direction each edge moves while it is inside a word.
    begin_step = 1 if collapse else -1
    end_step = -begin_step
    last_index = len(text) - 1

    cleaned_spans = set()

    for begin, end in _contiguous_ranges(spans):
        # Remove special characters from both edges
        while text[begin] in special_characters and begin < end:
            begin += 1
        while text[end] in special_characters and begin < end:
            end -= 1

        # Keep full word: move each edge until the character just outside
        # the range is no longer alphanumeric (or the text/range limit is hit)
        while 0 < begin < end and text[begin - 1].isalnum():
            begin += begin_step
        while last_index > end > begin and text[end + 1].isalnum():
            end += end_step

        # Drop very short ranges: only ranges of at least three characters
        # (end - begin > 1 with inclusive bounds) are kept
        if end - begin > 1:
            cleaned_spans.update(range(begin, end + 1))

    return sorted(cleaned_spans)


- Dataset class; it implements `__len__` and `__getitem__`, so it can be adapted to a `torch.utils.data.Dataset`

In [32]:
class ToxicDataset:
    """
    Toxic-spans dataset tokenized with spaCy and labelled per token.

    On construction the CSV split is read, the character-level span
    annotations are cleaned with `_fix_spans`, every post is tokenized and
    each token receives a 0/1 label (1 when the first character of the token
    lies inside a cleaned toxic span).

    Indexing returns `(tokens, label_ids, offsets, original_spans)` for a post.
    """

    def __init__(self, data_path, split):
        """
        @param data_path: base_path of the data folder
        @param split: name of the csv file without .csv suffix
        """
        self.tokenizer = spacy.load("en_core_web_sm")

        self.original_sentences, self.original_spans, self.fixed_spans = \
            self.get_sentences_from_data_split(data_path, split)

        # Labels are derived from the cleaned spans, not the raw annotations.
        self.tokens, self.offsets, self.labels_ids = \
            self.preprocess_and_tokenize(self.original_sentences, self.fixed_spans)

    def __len__(self):
        """Number of posts in the split."""
        return len(self.original_sentences)

    def __getitem__(self, index):
        """
        @param index: position of the post in the split
        @return: (tokens, label_ids, offsets, original_spans) of that post
        """
        return (
            self.tokens[index],
            self.labels_ids[index],
            self.offsets[index],
            self.original_spans[index],
        )

    @staticmethod
    def get_sentences_from_data_split(data_path, split):
        """
        Read `<data_path>/<split>.csv` and return its posts and span annotations.

        @param data_path: base_path of the data folder
        @param split: name of the csv file without .csv suffix
        @return: (sentences, original_spans, fixed_spans), three parallel lists
        """
        data = pd.read_csv(os.path.join(data_path, split + '.csv'))
        sentences = data['text'].tolist()

        if split == 'tsd_test':
            # The 'tsd_test' split is treated as unlabelled: the 'spans'
            # column is not read. Every row gets its own empty lists so that
            # mutating one entry cannot leak into another.
            original_spans = [[] for _ in sentences]
            fixed_spans = [[] for _ in sentences]
            return sentences, original_spans, fixed_spans

        # The 'spans' column stores each list of offsets as its Python literal.
        original_spans = [literal_eval(raw_span) for raw_span in data['spans']]
        fixed_spans = [
            _fix_spans(span, sentence)
            for span, sentence in zip(original_spans, sentences)
        ]
        return sentences, original_spans, fixed_spans

    def preprocess_and_tokenize(self, sentences, spans):
        """
        Tokenize every post and label each token as toxic (1) or not (0).

        @param sentences: a list of posts, List[str]
        @param spans: a list of toxic spans, List[List[int]]
        @return: (tokens, offsets, label_ids), one inner list per post;
                 offsets are (start, end) character positions, end exclusive
        """
        all_tokens = []
        all_offsets = []
        all_label_ids = []

        for sentence, span in zip(sentences, spans):
            tokenized = self.tokenizer(sentence)
            tokens = [token.text for token in tokenized]
            token_offset = [(token.idx, token.idx + len(token.text)) for token in tokenized]

            # One set per post: constant-time membership per token instead of
            # scanning the whole span list for every token.
            toxic_offsets = set(span)

            all_tokens.append(tokens)
            all_offsets.append(token_offset)
            all_label_ids.append([self.off2tox(offset, toxic_offsets) for offset in token_offset])

        return all_tokens, all_offsets, all_label_ids

    @staticmethod
    def off2tox(offsets, spans):
        """
        Return 1 if the token starts inside a toxic span, else 0.

        @param offsets: a tuple indicates the start and end position of the token in the sentence
        @param spans: toxic character offsets, any container supporting `in` (list or set of int)
        """
        # (0, 0) is treated as the offset of a padded item and is never toxic.
        # NOTE(review): nothing in this class produces (0, 0); presumably it
        # comes from padding added by downstream batching - confirm.
        if offsets == (0, 0):
            return 0
        # Only the first character of the token decides the label.
        return int(offsets[0] in spans)

- Usage: index the dataset to get the tokens and label_ids of each post

In [33]:
# Build the dataset from data/tsd_trial.csv; the constructor reads the CSV,
# cleans the spans and tokenizes every post up front.
dataset = ToxicDataset("data", "tsd_trial")

In [35]:
# First post: its tokens, per-token 0/1 labels, token character offsets and
# the original annotated span offsets.
tokens, label_ids, offsets, original_spans = dataset[0]

In [39]:
# Show the tokens next to their labels (1 = the token starts inside a cleaned
# toxic span).
print(tokens)
# print(offsets)
print(label_ids)
# print(original_spans)

['Because', 'he', "'s", 'a', 'moron', 'and', 'a', 'bigot', '.', 'It', "'s", 'not', 'any', 'more', 'complicated', 'than', 'that', '.']
[0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
