In [1]:
import ast
import csv
import random
import statistics
import sys

import sklearn
import spacy

import semeval2021
import fix_spans

import spans_detection
import numpy as np
import pandas as pd

def spans_to_ents(doc, spans, label):
  """Groups character offsets into contiguous (start, end, label) entities.

  Walks the tokens of `doc` in order. Consecutive tokens that each cover at
  least one offset in `spans` are merged into a single entity. Tokens whose
  `pos_` is 'SPACE' are skipped: they neither extend nor terminate a run.

  Args:
    doc: iterable of tokens exposing `pos_`, `idx` and `text`.
    spans: set of character offsets.
    label: label attached to every produced entity.

  Returns:
    List of (start_char, end_char, label) tuples; end_char is exclusive.
  """
  entities = []
  run_start = None  # None while no run of matching tokens is open.
  run_end = 0
  for token in doc:
    if token.pos_ == 'SPACE':
      continue
    token_end = token.idx + len(token.text)
    if not spans.isdisjoint(range(token.idx, token_end)):
      if run_start is None:
        run_start = token.idx
      # Every matching token pushes the end of the current run forward.
      run_end = token_end
    elif run_start is not None:
      # First non-matching token after a run closes that run.
      entities.append((run_start, run_end, label))
      run_start = None
  if run_start is not None:
    # The last token was still part of a run.
    entities.append((run_start, run_end, label))
  return entities


def read_datafile(filename):
  """Reads a csv file of span annotations and returns (spans, text) pairs.

  Each row must have a `spans` column holding a Python literal (parsed with
  `ast.literal_eval`) and a `text` column. The parsed spans are passed
  through `fix_spans.fix_spans` together with the text before being stored.

  Args:
    filename: path of the utf8-encoded csv file.

  Returns:
    List of (fixed_spans, text) tuples, one per row, in file order.
  """
  # NOTE(review): the csv module docs recommend opening with newline=''.
  # It is not added here because it would change how newlines embedded in
  # quoted text fields are read, and therefore could shift the character
  # offsets the spans refer to -- confirm against the data before changing.
  with open(filename, encoding='utf8') as csvfile:
    return [
        (fix_spans.fix_spans(ast.literal_eval(row['spans']), row['text']),
         row['text'])
        for row in csv.DictReader(csvfile)]

def evaluate(dataset, toxic_list, model):
  """Prints the mean F1 of the tagger / word-list intersection ensemble.

  For every (spans, text) pair, the character offsets covered by the
  entities `model` predicts for the text are intersected with the offsets
  returned for the same text by `spans_detection.get_spans_from_dataset`.
  The sorted intersection is scored against `spans` with `semeval2021.f1`
  and the mean of all scores is printed.

  Args:
    dataset: sequence of (spans, text) pairs; it is iterated twice.
    toxic_list: word list forwarded to
      `spans_detection.get_spans_from_dataset`.
    model: callable mapping a text to a doc whose `ents` expose
      `start_char` and `text`.

  Raises:
    statistics.StatisticsError: if `dataset` is empty.
  """
  texts = [text for _, text in dataset]
  detected_spans = spans_detection.get_spans_from_dataset(texts, toxic_list)

  intersection_scores = []
  for i, (spans, text) in enumerate(dataset):
    # Character offsets covered by the entities the tagger predicts.
    tagger_offsets = set()
    for ent in model(text).ents:
      tagger_offsets.update(
          range(ent.start_char, ent.start_char + len(ent.text)))

    # Ensemble prediction: offsets flagged by both the tagger and the
    # word-list detector for the i-th text.
    ensemble = sorted(tagger_offsets.intersection(detected_spans[i]))
    intersection_scores.append(semeval2021.f1(ensemble, spans))

  print('Average F1 Ensemble Intersection %g'
        % statistics.mean(intersection_scores))

def main():
  """Train and eval a spacy named entity tagger for toxic spans.

  Loads the train/trial/test csv files and the toxic word dictionary, trains
  a blank English pipeline with a single 'TOXIC' NER label for 25 epochs,
  then prints the ensemble score on the test set via `evaluate`.
  """
  # Read training data
  print('Load training data')
  train = read_datafile('./data/tsd_train.csv')

  # Read trial data
  # NOTE(review): `trial` is loaded but not used anywhere below.
  print('Load trial data')
  trial = read_datafile('./data/tsd_trial.csv')

  # Read test data
  print('Load test data')
  test = read_datafile('./data/tsd_test_with_ground_truth.csv')

  # Get the toxic word list. The context manager closes the file even if
  # reading fails.
  with open('Toxic words dictionary.txt', 'r',
            encoding="ISO-8859-1") as dict_file:
    content = dict_file.read()
  # NOTE(review): if the file ends with a newline this list contains an
  # empty string -- confirm the span detector tolerates that.
  toxic_list = content.split('\n')
  unique_toxic_list = np.unique(toxic_list)

  # Convert training data to Spacy Entities
  nlp = spacy.load("en_core_web_sm")

  print('Prepare training data')
  training_data = []
  for spans, text in train:
    doc = nlp(text)
    ents = spans_to_ents(doc, set(spans), 'TOXIC')
    training_data.append((doc.text, {'entities': ents}))

  toxic_tagging = spacy.blank('en')
  toxic_tagging.vocab.strings.add('TOXIC')
  # NOTE(review): the pipe is created from `nlp` but added to
  # `toxic_tagging` -- confirm this is intended.
  ner = nlp.create_pipe("ner")
  toxic_tagging.add_pipe(ner, last=True)
  ner.add_label('TOXIC')

  # Every pipe not listed here is disabled while training.
  pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
  unaffected_pipes = [
      pipe for pipe in toxic_tagging.pipe_names
      if pipe not in pipe_exceptions]

  print('Training phase')
  with toxic_tagging.disable_pipes(*unaffected_pipes):
    toxic_tagging.begin_training()
    for _ in range(25):
      random.shuffle(training_data)
      losses = {}
      batches = spacy.util.minibatch(
          training_data, size=spacy.util.compounding(
              4.0, 32.0, 1.001))
      for batch in batches:
        texts, annotations = zip(*batch)
        toxic_tagging.update(texts, annotations, drop=0.5, losses=losses)
      print("Losses", losses)

  # Evaluation
  print('')
  print('Evaluation of the test data')
  evaluate(test, unique_toxic_list, toxic_tagging)


# Run the train/evaluate pipeline only when the file is executed directly.
if __name__ == '__main__':
  main()

Load training data
Load trial data
Load test data
Prepare training data
Training phase
Losses {'ner': 31736.713823175938}

Evaluation test
Average F1 Ensemble Intersection 0.633024
