<a href="https://colab.research.google.com/github/Isioman/Natural-Language-Processing-Project-Toxic-Spans-Detection/blob/main/Copy_of_Toxic_Spans_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.12.3-py3-none-any.whl (3.1 MB)
[K     |████████████████████████████████| 3.1 MB 11.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.1.0-py3-none-any.whl (59 kB)
[K     |████████████████████████████████| 59 kB 5.0 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 41.6 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 33.9 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 34.8 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Atte

In [None]:
import numpy as np # mathematical functions
import torch as torch 
import pandas as pd
import tensorflow as tf
from ast import literal_eval
import matplotlib.pyplot as plt
import string
import itertools
import nltk

from collections import defaultdict
from tqdm import tqdm, trange
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import  BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments

In [None]:

# Mount Google Drive at /drive; training_args later writes results and logs under /drive/MyDrive.
from google.colab import drive
drive.mount('/drive')

Mounted at /drive


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
#Credit: https://github.com/ipavlopoulos/toxic_spans/blob/master/evaluation/fix_spans.py
#This method is provided by the task organizers to extract contiguous ranges in the given span 
# E.g. [1, 2, 3, 5, 6, 7] -> [(1,3), (5,7)]
def contiguous_ranges(span_list):
    """Collapse a sequence of offsets into (first, last) runs of consecutive values.

    A new run starts whenever a value is not exactly one more than the value
    before it, so the input order is respected (no sorting is done).
    Both ends of every returned pair are inclusive.
    """
    ranges = []
    run_start = run_end = None
    for offset in span_list:
        if run_end is not None and offset == run_end + 1:
            # Still inside the current run; just extend it.
            run_end = offset
            continue
        if run_end is not None:
            # The current run is broken; record it before starting a new one.
            ranges.append((run_start, run_end))
        run_start = run_end = offset
    if run_end is not None:
        ranges.append((run_start, run_end))
    return ranges

In [None]:
#This method performs minor edits: it trims whitespace from both ends of each contiguous span and drops spans shorter than three characters (not only singletons)
#Credit: https://github.com/ipavlopoulos/toxic_spans/blob/master/evaluation/fix_spans.py
SPECIAL_CHARACTERS = string.whitespace
def fix_spans(spans, text, special_characters=SPECIAL_CHARACTERS):
    """Trim edge characters from every contiguous span and drop very short spans.

    Args:
        spans: character offsets marked toxic in `text`.
        text: the post the offsets index into.
        special_characters: characters stripped from both ends of each
            contiguous range (whitespace by default).

    Returns:
        A flat list of the surviving offsets. Ranges spanning only one or
        two characters after trimming are discarded.
    """
    kept = []
    for left, right in contiguous_ranges(spans):
        # Shrink from the left while the edge character is special.
        while text[left] in special_characters and left < right:
            left += 1
        # Then shrink from the right the same way.
        while text[right] in special_characters and left < right:
            right -= 1
        # Keep only ranges covering at least three characters.
        if right - left > 1:
            kept += range(left, right + 1)
    return kept

In [None]:
#This method is used to invoke the fix_spans method to perform minor edits to the spans
def clean_spans(spans, posts):
  """Apply fix_spans to each post's span list; spans[i] must belong to posts[i]."""
  return [fix_spans(post_span, posts[post_idx]) for post_idx, post_span in enumerate(spans)]

In [None]:
def number_of_spans(spans):
  """Count posts by how many contiguous toxic ranges they contain.

  Returns:
    (empty, single, multi): the number of posts with no toxic offsets, with
    exactly one contiguous range, and with more than one range.
  """
  empty = single = multi = 0
  for post_span in spans:
    if not len(post_span):
      empty += 1
      continue
    range_count = len(contiguous_ranges(post_span))
    if range_count == 1:
      single += 1
    elif range_count > 1:
      multi += 1
  return empty, single, multi

In [None]:
def load_train_dataset(path='/content/train.csv', limit=None):
  """Load and clean the training split.

  Reads a CSV with a `text` column and a `spans` column whose cells are
  Python-literal lists of toxic character offsets. Each span list is cleaned
  with clean_spans, and span statistics are printed. The statistics are
  computed on the raw (uncleaned) spans.

  Args:
    path: CSV file to read (Colab's /content/train.csv by default).
    limit: maximum number of samples to return; None (default) returns all.
      Pass limit=1 for a quick single-sample smoke run. The previous
      hard-coded `[:1]` slice silently returned one sample while the printed
      statistics described the whole file.

  Returns:
    (texts, cleaned_spans): parallel lists of post texts and cleaned toxic
    character offsets.
  """
  toxic_data = pd.read_csv(path)
  toxic_data["spans"] = toxic_data.spans.apply(literal_eval)
  texts, spans = toxic_data["text"], toxic_data["spans"]

  #Put the text and spans in list
  toxic_text_list = texts.values.tolist()
  toxic_spans_list = spans.values.tolist()

  #Clean the spans to remove singletons and trimming spaces. Code provided by SemEval organizers
  cleaned_spans = clean_spans(toxic_spans_list, toxic_text_list)

  # Get number of spans (computed on the raw spans, before cleaning)
  empty_spans_count, single_spans_count, multi_spans_count = number_of_spans(toxic_spans_list)

  print('Total Training Samples:', len(toxic_text_list))
  print('Empty Spans:', empty_spans_count)
  print('Single Spans:', single_spans_count)
  print('Multi Spans:', multi_spans_count)
  print('*************************************************************')

  # Slicing with limit=None keeps every sample.
  return toxic_text_list[:limit], cleaned_spans[:limit]

In [None]:
def load_trial_dataset(path='/content/trial.csv', limit=None):
  """Load and clean the trial (validation) split.

  Reads a CSV with a `text` column and a `spans` column whose cells are
  Python-literal lists of toxic character offsets. Each span list is cleaned
  with clean_spans, and span statistics are printed. The statistics are
  computed on the raw (uncleaned) spans.

  Args:
    path: CSV file to read (Colab's /content/trial.csv by default).
    limit: maximum number of samples to return; None (default) returns all.
      Pass limit=1 for a quick single-sample smoke run. The previous
      hard-coded `[:1]` slice silently returned one sample while the printed
      statistics described the whole file.

  Returns:
    (texts, cleaned_spans): parallel lists of post texts and cleaned toxic
    character offsets.
  """
  toxic_data = pd.read_csv(path)
  toxic_data["spans"] = toxic_data.spans.apply(literal_eval)
  texts, spans = toxic_data["text"], toxic_data["spans"]

  #Put the text and spans in list
  toxic_text_list = texts.values.tolist()
  toxic_spans_list = spans.values.tolist()

  #Clean the spans to remove singletons and trimming spaces. Code provided by SemEval organizers
  cleaned_spans = clean_spans(toxic_spans_list, toxic_text_list)

  # Get number of spans (computed on the raw spans, before cleaning)
  empty_spans_count, single_spans_count, multi_spans_count = number_of_spans(toxic_spans_list)

  print('Total Validation Samples:', len(toxic_text_list))
  print('Empty Spans:', empty_spans_count)
  print('Single Spans:', single_spans_count)
  print('Multi Spans:', multi_spans_count)
  print('*************************************************************')

  # Slicing with limit=None keeps every sample.
  return toxic_text_list[:limit], cleaned_spans[:limit]

In [None]:
def load_test_dataset(path='/content/test.csv', limit=None):
  """Load and clean the test split.

  Reads a CSV with a `text` column and a `spans` column whose cells are
  Python-literal lists of toxic character offsets. Each span list is cleaned
  with clean_spans, and span statistics are printed. The statistics are
  computed on the raw (uncleaned) spans.

  Args:
    path: CSV file to read (Colab's /content/test.csv by default).
    limit: maximum number of samples to return; None (default) returns all.
      Pass limit=1 for a quick single-sample smoke run. The previous
      hard-coded `[:1]` slice silently returned one sample while the printed
      statistics described the whole file.

  Returns:
    (texts, cleaned_spans): parallel lists of post texts and cleaned toxic
    character offsets.
  """
  toxic_data = pd.read_csv(path)
  toxic_data["spans"] = toxic_data.spans.apply(literal_eval)
  texts, spans = toxic_data["text"], toxic_data["spans"]

  #Put the text and spans in list
  toxic_text_list = texts.values.tolist()
  toxic_spans_list = spans.values.tolist()

  #Clean the spans to remove singletons and trimming spaces. Code provided by SemEval organizers
  cleaned_spans = clean_spans(toxic_spans_list, toxic_text_list)

  # Get number of spans (computed on the raw spans, before cleaning)
  empty_spans_count, single_spans_count, multi_spans_count = number_of_spans(toxic_spans_list)

  print('Total Test Samples:', len(toxic_text_list))
  print('Empty Spans:', empty_spans_count)
  print('Single Spans:', single_spans_count)
  print('Multi Spans:', multi_spans_count)
  print('*************************************************************')

  # Slicing with limit=None keeps every sample.
  return toxic_text_list[:limit], cleaned_spans[:limit]

In [None]:
def max_post_length(toxic_posts):
  """Return (index, length) of the longest post.

  Ties go to the earliest post. An empty input gives (0, 0).
  """
  longest = max(enumerate(toxic_posts), key=lambda pair: len(pair[1]), default=None)
  if longest is None:
    return 0, 0
  return longest[0], len(longest[1])

In [None]:
# Load all three splits (each loader prints its own span statistics), then report
# (index, length in characters) of the longest post in each split.
train_posts, train_spans = load_train_dataset()
trial_posts, trial_spans = load_trial_dataset()
test_posts, test_spans = load_test_dataset()

print('Train data max post length:', max_post_length(train_posts))
print('Trial data max post length:',max_post_length(trial_posts))
print('Test data max post length:',max_post_length(test_posts))

Total Training Samples: 7939
Empty Spans: 485
Single Spans: 5370
Multi Spans: 2084
*************************************************************
Total Validation Samples: 690
Empty Spans: 43
Single Spans: 448
Multi Spans: 199
*************************************************************
Total Test Samples: 2000
Empty Spans: 394
Single Spans: 1407
Multi Spans: 199
*************************************************************
Train data max post length: (0, 98)
Trial data max post length: (0, 74)
Test data max post length: (0, 156)


In [None]:
# Load the uncased BERT tokenizer (the fast variant, which supports return_offsets_mapping)
# and a token-classification head with 2 labels (0 = non-toxic, 1 = toxic, matching the
# labels produced by compute_encoded_spans). The "weights not initialized" warning is
# expected: the classification head is new and is trained below.
bertTokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
bertModel = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=2)
# As the last expression of the cell, the returned module is displayed as output.
bertModel.to(device)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-u

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [None]:
  def compute_encoded_spans(post_encoding, span):
    """Label each sub-word token of one post as toxic (1), non-toxic (0) or ignored (-100).

    Args:
      post_encoding: fast-tokenizer encoding of one post. Only its `offsets`
        attribute is read, as a sequence of (start, end) character pairs.
      span: iterable of toxic character offsets for the post.

    Returns:
      One label per token:
        -100 for tokens whose offset is (0, 0), i.e. the special and padding
          tokens seen in the offset mappings above;
        1 if any character of the token is in `span`;
        0 otherwise.
    """
    # A set makes each membership test O(1) instead of scanning the span list.
    toxic_chars = set(span)
    encoded_spans = []
    for first_offset, second_offset in post_encoding.offsets:
      if first_offset == 0 and second_offset == 0:
        # -100 is the conventional "ignore" label for token-classification losses.
        encoded_spans.append(-100)
      elif toxic_chars.isdisjoint(range(first_offset, second_offset)):
        encoded_spans.append(0)
      else:
        encoded_spans.append(1)
    return encoded_spans

In [None]:
#Update spans to match the sub-word tokens
def encode_spans(post_encodings, spans):
  """Build token-level labels for every post by applying compute_encoded_spans.

  post_encodings[i] must be the tokenizer encoding of the post whose toxic
  character offsets are spans[i].
  """
  labels_per_post = []
  for post_idx, post_span in enumerate(spans):
    labels_per_post.append(compute_encoded_spans(post_encodings[post_idx], post_span))
  return labels_per_post

In [None]:
#Tokenize Training, Validation and Testing data
# padding=True pads each split to its longest post; truncation=True cuts posts that exceed
# the model's maximum length. Offsets are requested so that character-level spans can be
# mapped onto sub-word tokens later.
train_posts_encodings = bertTokenizer(train_posts, return_offsets_mapping=True, padding=True, truncation=True)
trial_posts_encodings = bertTokenizer(trial_posts, return_offsets_mapping=True,padding=True, truncation=True)
test_posts_encodings = bertTokenizer(test_posts, return_offsets_mapping=True, padding=True, truncation=True)

# Sanity check: show the trial ids and attention mask, and decode the first trial post back to text.
trial_input_ids = trial_posts_encodings['input_ids']
attention_masks = trial_posts_encodings['attention_mask']

actual_tokens = bertTokenizer.decode(trial_input_ids[0])

print(trial_input_ids)
print(attention_masks)
print(actual_tokens)

[[101, 2138, 2002, 1005, 1055, 1037, 22822, 2239, 1998, 1037, 2502, 4140, 1012, 2009, 1005, 1055, 2025, 2151, 2062, 8552, 2084, 2008, 1012, 102]]
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
[CLS] because he's a moron and a bigot. it's not any more complicated than that. [SEP]


In [None]:
#Compute updated span encodings for training, trial and test data
encoded_train_spans = encode_spans(train_posts_encodings, train_spans)
encoded_trial_spans = encode_spans(trial_posts_encodings, trial_spans)
encoded_test_spans = encode_spans(test_posts_encodings, test_spans)

# Plain print is enough for integers; tf.print pulled in TensorFlow just for this.
print(len(encoded_train_spans))
print(len(encoded_trial_spans))
print(len(encoded_test_spans))

# Show only the first example of each split so the output stays readable on the full data.
print(encoded_train_spans[:1])
print(encoded_trial_spans[:1])
print(encoded_test_spans[:1])

print('Validation Span:', trial_spans[:1])

1
1
1
[[-100, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100]]
[[-100, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100]]
[[-100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, -100]]
Validation Span: [[15, 16, 17, 18, 19, 27, 28, 29, 30, 31]]


In [None]:
#Credit: https://huggingface.co/transformers/custom_datasets.html#tok-ner
class ToxicSpansDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with per-token labels.

    Item `idx` is a dict with one tensor per encoding key (e.g. input_ids,
    attention_mask) plus a 'labels' tensor built from labels[idx].
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        sample = dict()
        for name, values in self.encodings.items():
            sample[name] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        # The dataset size is defined by the number of label sequences.
        return len(self.labels)

In [None]:
#Converting the training, trial and test dataset into Pytorch Dataset object
train_dataset = ToxicSpansDataset(train_posts_encodings, encoded_train_spans)
trial_dataset = ToxicSpansDataset(trial_posts_encodings, encoded_trial_spans)
test_dataset = ToxicSpansDataset(test_posts_encodings, encoded_test_spans)

# Plain print is enough for integers; tf.print pulled in TensorFlow just for this.
print(len(train_dataset))
print(len(trial_dataset))
print(len(test_dataset))

1
1
1


In [None]:
# Offset mappings are needed only for encoding the spans. They are not needed for training the model.
# NOTE: the datasets above hold references to these same encoding objects, so this pop also
# removes 'offset_mapping' from every dataset item (presumably the model's forward() would
# reject it as an input). Run this cell exactly once: a second run raises KeyError.
train_offset_mapping = train_posts_encodings.pop("offset_mapping")
trial_offset_mapping = trial_posts_encodings.pop("offset_mapping")
test_offset_mapping = test_posts_encodings.pop("offset_mapping")

# Show only the first example of each split to keep the output readable on the full data.
print(train_offset_mapping[:1])
print(trial_offset_mapping[:1])
print(test_offset_mapping[:1])

print(len(trial_offset_mapping))

[[(0, 0), (0, 7), (8, 15), (16, 19), (20, 30), (31, 40), (41, 48), (49, 50), (51, 59), (60, 63), (64, 75), (76, 78), (79, 86), (86, 87), (87, 88), (88, 89), (89, 90), (91, 98), (0, 0)]]
[[(0, 0), (0, 7), (8, 10), (10, 11), (11, 12), (13, 14), (15, 18), (18, 20), (21, 24), (25, 26), (27, 30), (30, 32), (32, 33), (34, 36), (36, 37), (37, 38), (39, 42), (43, 46), (47, 51), (52, 63), (64, 68), (69, 73), (73, 74), (0, 0)]]
[[(0, 0), (0, 4), (4, 5), (5, 6), (7, 12), (12, 13), (14, 18), (19, 22), (23, 26), (27, 33), (33, 34), (35, 38), (39, 40), (41, 43), (44, 52), (53, 57), (58, 61), (62, 69), (70, 74), (75, 79), (80, 83), (84, 92), (92, 93), (94, 101), (102, 103), (103, 105), (106, 109), (110, 117), (118, 124), (124, 125), (126, 129), (129, 131), (131, 132), (133, 136), (136, 139), (140, 147), (147, 148), (149, 155), (155, 156), (0, 0)]]
1


In [None]:
#Create DataLoader objects for training and testing the model in batches
# train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=512, shuffle=True)

In [None]:
# Configuring Training Arguments for training the BERT Model
training_args = TrainingArguments(
  output_dir='/drive/MyDrive/results',  # checkpoints/results go to the mounted Google Drive
  num_train_epochs=1,                 # total number of training epochs
  per_device_train_batch_size=16,     # batch size per device during training
  per_device_eval_batch_size=16,      # batch size for evaluation
  warmup_steps=500,                   # number of warmup steps for learning rate scheduler
                                      # NOTE(review): on the full 7939-sample train split, one epoch at batch 16 is
                                      # ~497 steps, so warmup would never finish — confirm this is intended
  weight_decay=0.01,                  # strength of weight decay
  logging_dir='/drive/MyDrive/logs',   # directory for storing logs
  logging_steps=200,                   # log training loss every 200 steps (checkpoint saving is governed by save_steps, not this)
  do_eval=True,                       # whether to run evaluation on the val set
  evaluation_strategy="steps",        # evaluation is done (and logged) every logging_steps (renamed eval_strategy in later transformers releases)
  learning_rate=5e-5,                 # 5e-5 is default learning rate
  disable_tqdm=False,                  # keep tqdm progress bars enabled (True would remove them)
)

In [None]:
# Trainer Object
# No compute_metrics is passed, so evaluation on trial_dataset reports only the loss;
# character-level precision/recall/F1 are computed manually in later cells.
trainer = Trainer(
  model=bertModel,                 # the instantiated Transformers model to be trained
  args=training_args,              # training arguments, defined above
  train_dataset=train_dataset,       
  eval_dataset=trial_dataset   
)

In [None]:
# Fine-tune bertModel on train_dataset, evaluating on trial_dataset as configured in training_args.
print('Training started......')
trainer.train()
print('Training completed.....')

***** Running training *****
  Num examples = 1
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1


Training started......


Step,Training Loss,Validation Loss




Training completed. Do not forget to share your model on huggingface.co/models =)




Training completed.....


In [None]:
# Run inference on the trial (validation) and test splits; `.predictions` holds the
# per-token class scores that later cells argmax over.
# NOTE(review): `trail_` is a typo for `trial_`; the name is kept because later cells use it.
trail_predictions = trainer.predict(trial_dataset)
test_predictions = trainer.predict(test_dataset)

***** Running Prediction *****
  Num examples = 1
  Batch size = 16


***** Running Prediction *****
  Num examples = 1
  Batch size = 16


In [None]:
# Show the first trial post's tokens next to the predicted labels. Previously the tokens
# expression was not the cell's last expression, so it displayed nothing.
trial_tokens = trial_posts_encodings[0].tokens
trial_token_labels = trail_predictions.predictions.argmax(-1)
trial_tokens, trial_token_labels


array([[1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1]])

In [None]:
def system_precision_recall_f1(toxic_char_preds, gold_char_offsets):
  """Average per-post character-level precision, recall and F1.

  Each post is scored on the sets of predicted and gold toxic character
  offsets:
    - Empty gold: 1.0 on all three metrics if the prediction is also empty,
      otherwise 0.0.
    - Non-empty gold with an empty prediction: 0.0.

  Args:
    toxic_char_preds: predicted toxic offsets, one iterable per post.
    gold_char_offsets: gold toxic offsets, one iterable per post, same order.

  Returns:
    numpy array [precision, recall, f1] averaged over posts. All three are
    NaN when there are no posts, which keeps the result unpackable.

  Raises:
    ValueError: if the inputs hold different numbers of posts. zip() would
      otherwise silently score only the shorter prefix.
  """
  def per_post_precision_recall_f1(predictions, gold):
    if len(gold) == 0:
        return [1.0, 1.0, 1.0] if len(predictions) == 0 else [0.0, 0.0, 0.0]

    if len(predictions) == 0:
        return [0.0, 0.0, 0.0]

    predictions_set = set(predictions)
    gold_set = set(gold)
    nom = len(predictions_set.intersection(gold_set))
    precision = nom / len(predictions_set)
    recall = nom / len(gold_set)
    f1_score = (2 * nom) / (len(predictions_set) + len(gold_set))

    return [float(precision), float(recall), float(f1_score)]

  toxic_char_preds = list(toxic_char_preds)
  gold_char_offsets = list(gold_char_offsets)
  if len(toxic_char_preds) != len(gold_char_offsets):
    raise ValueError(
        f'got {len(toxic_char_preds)} predicted posts but {len(gold_char_offsets)} gold posts')
  if not toxic_char_preds:
    # The mean over zero posts is undefined; return NaN for each metric.
    return np.full(3, np.nan)

  # get the respective metrics per post
  precision_recall_f1_scores = [per_post_precision_recall_f1(toxic_offsets, gold_offsets)
                                for toxic_offsets, gold_offsets in zip(toxic_char_preds, gold_char_offsets)]

  # compute average precision, recall and f1 score of all posts
  return np.array(precision_recall_f1_scores).mean(axis=0)

In [None]:
# For tokenizing sentences
# Download NLTK's Punkt sentence-tokenizer data and load the English model. The late-fusion
# post-processing uses sentence_tokenizer.span_tokenize to split posts into sentences.
# NOTE(review): this loads a pickled model; newer NLTK releases ship Punkt as 'punkt_tab'
# instead — confirm this path exists for the installed NLTK version.
nltk.download('punkt')
sentence_tokenizer = nltk.data.load('tokenizers/punkt/PY3/english.pickle')
def toxic_character_offsets_with_thresholding(post_num, tokens, offset_mapping, prediction, val_sentences_info, prediction_score, threshold, verbose=False):
  """Convert token-level predictions for one post into toxic character offsets.

  Tokens are grouped into words: a head token followed by its '##' sub-tokens.
  A word is kept when both hold:
    - any of its tokens is predicted 1;
    - its largest per-token score (prediction_score[i].max()) is >= threshold.
  When the previous token's last character is already toxic, the characters
  between the two tokens are added too. Adjacent toxic words therefore form
  one contiguous span.

  Args:
    post_num: index of the post (unused; kept for interface compatibility).
    tokens: the post's tokens, starting with [CLS]; processing stops at [SEP].
    offset_mapping: (start, end) character offsets, one pair per token.
    prediction: predicted label per token (1 = toxic).
    val_sentences_info: unused; kept for interface compatibility.
    prediction_score: per-token class scores; row i is reduced with .max().
    threshold: minimum score for a word to be kept.
    verbose: if True, print the accumulated offsets/scores after each kept
      word. This is a debug aid; it was previously always on.

  Returns:
    (toxic_offsets, scores): the kept character offsets, and (token, score)
    pairs for every token of every kept word.
  """
  toxic_offsets = []
  scores = []
  n = len(tokens)
  i = 1           # start from 1 as 0th token is [CLS]
  while i < n:
    # stop looping after processing all post tokens
    if tokens[i] == '[SEP]':
      break

    cur_toxic = []
    # if the previous token's last character is toxic, bridge the gap to this token
    if len(toxic_offsets) > 0 and toxic_offsets[-1] == offset_mapping[i-1][1] - 1:
      cur_toxic.extend(range(offset_mapping[i-1][1], offset_mapping[i][0]))

    # add the character offsets of this head BPE
    cur_toxic.extend(range(offset_mapping[i][0], offset_mapping[i][1]))
    cur_score = [(tokens[i], prediction_score[i].max())]
    cur_labels = [prediction[i]]

    # process all sub-tokens of the current head BPE, each with its OWN offsets
    # (reusing the head's offsets would drop the sub-token characters)
    i += 1
    while i < n and '##' in tokens[i]:
      cur_toxic.extend(range(offset_mapping[i][0], offset_mapping[i][1]))
      cur_score.append((tokens[i], prediction_score[i].max()))
      cur_labels.append(prediction[i])
      i += 1

    # word is predicted toxic if any sub-token is predicted toxic by model
    prediction_label = max(cur_labels) == 1
    # include cur_toxic offsets if any sub-token confidence score reaches the threshold
    passed_threshold = max(score for _, score in cur_score) >= threshold

    # include to global toxic offsets list only if both predicted label and threshold criteria pass
    if prediction_label and passed_threshold:
      toxic_offsets.extend(cur_toxic)
      scores.extend(cur_score)
      if verbose:
        print('Toxic offsets:', toxic_offsets)
        print('Scores:', scores)

  return toxic_offsets, scores

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
def character_offsets_with_thresholding(val_text_encodings, val_offset_mapping, predictions, val_sentences_info, prediction_scores, threshold=-float('inf')):
  """Run toxic_character_offsets_with_thresholding on every post.

  Returns one (toxic_offsets, scores) pair per post, in order.
  """
  results = []
  for post_idx, (post_offsets, post_prediction) in enumerate(zip(val_offset_mapping, predictions)):
    post_tokens = val_text_encodings[post_idx].tokens
    results.append(toxic_character_offsets_with_thresholding(
        post_idx, post_tokens, post_offsets, post_prediction,
        val_sentences_info, prediction_scores[post_idx], threshold))
  return results

In [None]:
def compute_metrics(pred, gold_char_offsets, val_offset_mapping, val_text_encodings, val_sentences_info, threshold=-float('inf'), verbose=False):
  """Score token predictions at the character level against gold spans.

  Args:
    pred: prediction output whose `predictions` attribute holds per-token
      class scores; argmax over the last axis gives the predicted labels.
    gold_char_offsets: gold toxic character offsets, one list per post.
    val_offset_mapping: per-post token (start, end) character offsets.
    val_text_encodings: tokenizer encodings; element i must expose `.tokens`.
    val_sentences_info: forwarded unchanged to character_offsets_with_thresholding.
    threshold: minimum word score to keep (the default keeps every word).
    verbose: if True, print the extracted toxic character offsets. This was
      previously always printed, which floods the output on a full split.

  Returns:
    dict with 'precision', 'recall' and 'f1' as Python floats.
  """
  # get the sub-token predictions made by the model
  prediction_scores = pred.predictions
  predictions = prediction_scores.argmax(-1)

  # retrieve the toxic character offsets of these predictions
  toxic_char_preds_object = character_offsets_with_thresholding(val_text_encodings, val_offset_mapping, predictions, val_sentences_info, prediction_scores, threshold)
  toxic_char_offsets = [offsets for offsets, _ in toxic_char_preds_object]

  if verbose:
    print('Toxic Character offsets:', toxic_char_offsets)

  # compute the precision, recall and f1 score on the validation set
  precision, recall, f1 = system_precision_recall_f1(toxic_char_offsets, gold_char_offsets)

  return {
    'precision': float(precision),
    'recall': float(recall),
    'f1': float(f1)
  }

In [None]:
# Character-level evaluation of the trial predictions. The thresholding path does not read
# val_sentences_info, so an empty placeholder is passed.
# NOTE(review): the late-fusion variant indexes val_sentences_info by post text, so it needs a
# mapping, not this list.
val_sentences_info = []

val_results = compute_metrics(trail_predictions, trial_spans, trial_offset_mapping, trial_posts_encodings, val_sentences_info, threshold=-float('inf'))
print(val_results)

In [None]:
# Recompute the per-post (toxic character offsets, scores) pairs for the trial set and keep
# them in val_toxic_char_preds for inspection. This repeats work done inside compute_metrics.
# NOTE(review): the two commented-out lines duplicate live code and can be removed.
# predictions = trail_predictions.predictions.argmax(-1)

predictions = trail_predictions.predictions.argmax(-1)
trial_prediction_scores = trail_predictions.predictions
# trial_prediction_scores = trail_predictions.predictions
val_toxic_char_preds = character_offsets_with_thresholding(trial_posts_encodings, trial_offset_mapping, predictions, val_sentences_info, trial_prediction_scores, threshold=-float('inf'))

In [None]:
#post processing
def toxic_character_offsets_with_late_fusion(post_num, text, tokens, offset_mapping, prediction, val_sentences_info, prediction_score, threshold, sentence_splitter=None):
  """Extract toxic character offsets for one post, keeping only words inside toxic sentences.

  Words are a head token plus its '##' sub-tokens. A word is kept when all
  three hold:
    (a) any of its tokens is predicted 1;
    (b) its largest per-token score (prediction_score[i].max()) is >= threshold;
    (c) at least one of its character offsets lies inside a sentence whose
        classification has 'Pred' == 1.
  When the previous token's last character is already toxic, the characters
  between the two tokens are included as well.

  Args:
    post_num: index of the post (unused; kept for interface compatibility).
    text: the post text; also the key into val_sentences_info.
    tokens: the post's tokens, starting with [CLS]; processing stops at [SEP].
    offset_mapping: (start, end) character offsets, one pair per token.
    prediction: predicted label per token (1 = toxic).
    val_sentences_info: mapping text -> {"(start, end)": {'Pred': label, ...}},
      keyed by the sentence spans produced by the splitter.
    prediction_score: per-token class scores; row i is reduced with .max().
    threshold: minimum score for a word to be kept.
    sentence_splitter: object with span_tokenize(text). Defaults to the
      module-level NLTK `sentence_tokenizer`.

  Returns:
    (toxic_offsets, scores): the kept character offsets and (token, score) pairs.

  Raises:
    KeyError: if `text`, or a sentence span that gets looked up, is missing
      from val_sentences_info.
  """
  if sentence_splitter is None:
    sentence_splitter = sentence_tokenizer
  # retrieve the sentence classifications for this post
  sentence_classifications = val_sentences_info[text]
  # Materialize the sentence spans: span_tokenize yields them lazily, and a generator
  # would be exhausted by the first word that scans it, hiding every later sentence.
  sentence_spans = list(sentence_splitter.span_tokenize(text))

  def _in_toxic_sentence(char_idx):
    # Spans are half-open: text[start:end] is the sentence.
    for start_sen, end_sen in sentence_spans:
      if start_sen <= char_idx < end_sen and sentence_classifications[f'({start_sen}, {end_sen})']['Pred'] == 1:
        return True
    return False

  toxic_offsets = []
  scores = []
  n = len(tokens)
  i = 1           # start from 1 as 0th token is [CLS]
  while i < n:
    # stop looping after processing all post tokens
    if tokens[i] == '[SEP]':
      break

    cur_toxic = []
    # if the previous token's last character is toxic, bridge the gap to this token
    if len(toxic_offsets) > 0 and toxic_offsets[-1] == offset_mapping[i-1][1] - 1:
      cur_toxic.extend(range(offset_mapping[i-1][1], offset_mapping[i][0]))

    # add the character offsets of this head BPE
    cur_toxic.extend(range(offset_mapping[i][0], offset_mapping[i][1]))
    cur_score = [(tokens[i], prediction_score[i].max())]
    cur_labels = [prediction[i]]

    # process all sub-tokens of the current head BPE
    i += 1
    while i < n and '##' in tokens[i]:
      cur_toxic.extend(range(offset_mapping[i][0], offset_mapping[i][1]))
      cur_score.append((tokens[i], prediction_score[i].max()))
      cur_labels.append(prediction[i])
      i += 1

    # word is predicted toxic if any sub-token is predicted toxic by model
    prediction_label = max(cur_labels) == 1
    # and passes if any sub-token confidence score reaches the threshold
    passed_threshold = max(score for _, score in cur_score) >= threshold

    # the sentence check runs last, so non-candidate words skip the lookups
    if prediction_label and passed_threshold and any(_in_toxic_sentence(idx) for idx in cur_toxic):
      toxic_offsets.extend(cur_toxic)
      scores.extend(cur_score)

  return toxic_offsets, scores

In [None]:
def mt_dnn_post_character_offsets(post_num, labels, tokens, offsets):
  """Map per-token labels for one post to toxic character offsets (no thresholding).

  A word (head token plus its '##' pieces) is toxic when any of its tokens is
  labelled 1. When the previous token's last character is already toxic, the
  characters between the two tokens are added too. post_num is unused.
  Processing starts after [CLS] and stops at [SEP].
  """
  toxic_offsets = []
  total = len(tokens)
  head = 1  # position 0 holds [CLS]
  while head < total and tokens[head] != '[SEP]':
    word_chars = []
    # Join onto the preceding toxic text by also marking the characters in between.
    if toxic_offsets and toxic_offsets[-1] == offsets[head - 1][1] - 1:
      word_chars += range(offsets[head - 1][1], offsets[head][0])
    word_labels = []
    piece = head
    # Consume the head token and every following '##' continuation piece.
    while True:
      word_chars += range(offsets[piece][0], offsets[piece][1])
      word_labels.append(labels[piece])
      piece += 1
      if piece >= total or '##' not in tokens[piece]:
        break
    if max(word_labels) == 1:
      toxic_offsets += word_chars
    head = piece
  return toxic_offsets


In [None]:
#print(toxic_offsets)

In [None]:
def character_offsets_with_late_fusion(texts, val_text_encodings, val_offset_mapping, predictions, val_sentences_info, prediction_scores, threshold):
  """Run toxic_character_offsets_with_late_fusion on every post.

  Returns one (toxic_offsets, scores) pair per post, in order.
  """
  fused = []
  for post_idx, (post_offsets, post_prediction) in enumerate(zip(val_offset_mapping, predictions)):
    fused.append(toxic_character_offsets_with_late_fusion(
        post_idx, texts[post_idx], val_text_encodings[post_idx].tokens, post_offsets,
        post_prediction, val_sentences_info, prediction_scores[post_idx], threshold))
  return fused



In [None]:
def mt_dnn_character_offsets(tokens, predictions, offsets):
  """Run mt_dnn_post_character_offsets on every post; tokens/offsets are indexed like predictions."""
  per_post = []
  for post_idx, post_labels in enumerate(predictions):
    per_post.append(mt_dnn_post_character_offsets(post_idx, post_labels, tokens[post_idx], offsets[post_idx]))
  return per_post