<a href="https://colab.research.google.com/github/Isioman/Natural-Language-Processing-Project-Toxic-Spans-Detection/blob/main/Toxic_Spans_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the packages imported below: Hugging Face `transformers` and `profanity-filter`.
# NOTE(review): no versions are pinned, so a fresh run may pick up newer releases.
!pip install transformers
!pip install profanity-filter



In [None]:
import io
import os
import nltk 
import spacy
import string
import random
import itertools

import numpy as np
import pandas as pd
import torch as torch
import seaborn as sns
import tensorflow as tf
import matplotlib.pyplot as plt

from PIL import Image
from io import FileIO
from ast import literal_eval
from google.colab import auth
from nltk.corpus import stopwords
from collections import defaultdict
from googleapiclient.discovery import build
from profanity_filter import ProfanityFilter
from googleapiclient.http import MediaIoBaseDownload
from nltk.tokenize import sent_tokenize, word_tokenize

from wordcloud import WordCloud, STOPWORDS, ImageColorGenerator
from transformers import  BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix,precision_recall_fscore_support

%matplotlib inline

In [None]:
# Set seed values for all python libraries
def set_seed(seed_value):
    """Seed Python's `random`, NumPy and PyTorch (CPU and, if present, every CUDA device).

    Args:
        seed_value: integer seed handed unchanged to each library.
    """
    # The three CPU-side generators take the same single-argument call
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    # CUDA generators are only seeded when a GPU is actually visible
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

In [None]:
# Seed random, NumPy and PyTorch with 42 (via set_seed) before any data or model work
set_seed(42)

In [None]:
# Mount the google drive for reading and writing input output files.
# After this call the Drive contents are reachable under '/content/drive'.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Use GPU if available else use regular CPU.
# `device` is later used to move the BERT model onto the selected hardware.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Credit: Bridget McInnes code of BERT_Examples.ipynb 
# Get the files from the google drive
auth.authenticate_user()
drive_service = build('drive', 'v3')


def download_drive_file(service, file_id, destination):
  """Download one Google Drive file to a local path, printing progress per chunk.

  Args:
      service: Drive API client returned by `build('drive', 'v3')`.
      file_id: id of the file on Google Drive.
      destination: local file name the content is written to (overwritten if present).
  """
  request = service.files().get_media(fileId=file_id)
  # The context manager closes the local file even if a chunk download fails
  with io.FileIO(destination, 'w') as downloaded:
    downloader = MediaIoBaseDownload(downloaded, request)
    done = False
    while done is False:
      status, done = downloader.next_chunk()
      print("Download {}%.".format(int(status.progress() * 100)))


# (Drive file id, local file name) for each dataset split
DRIVE_FILES = [
  ('16yiT1Wl4CPxt8-3DTYnXzZV1s2AcFxcT', "tsd_train.csv"),  # Train data file
  # Trial data file; the 'tsd_trail.csv' spelling is kept because load_trial_dataset reads that name
  ('1vyQHb3TJJ9daawHU5GBI49T9PfH12bQb', "tsd_trail.csv"),
  ('1w_CDQjo9Qd4URP293cpS-et6k4VdLLwW', "tsd_test.csv"),   # Test data file
]

for file_id, destination in DRIVE_FILES:
  download_drive_file(drive_service, file_id, destination)

Download 100%.
Download 100%.
Download 100%.


In [None]:
# Download the NLTK 'punkt' tokenizer data and stop-word list, then create a ProfanityFilter using Spacy
nltk.download('punkt')
nltk.download('stopwords')

# Load the small English spaCy pipeline and register the profanity filter as its last component
nlp = spacy.load('en_core_web_sm')
profanity_filter = ProfanityFilter(nlps={'en': nlp})
nlp.add_pipe(profanity_filter.spacy_component, last=True)
# NOTE: this rebinds `stopwords` to spaCy's default stop words, shadowing the
# `from nltk.corpus import stopwords` import at the top of the notebook.
stopwords = nlp.Defaults.stop_words

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [None]:
#Credit: https://github.com/ipavlopoulos/toxic_spans/blob/master/evaluation/fix_spans.py
#This method is provided by the task organizers to extract contiguous ranges in the given span
# E.g. [1, 2, 3, 5, 6, 7] -> [(1,3), (5,7)]
def contiguous_ranges(span_list):
    """Collapse a list of offsets into (first, last) pairs, one per run of consecutive values.

    Args:
        span_list: sequence of integer offsets.

    Returns:
        List of (first, last) tuples in the order the runs appear; empty list for empty input.
    """
    # Inside a run of consecutive values, value - position stays constant, so it serves as the group key
    keyed_runs = itertools.groupby(enumerate(span_list), key=lambda pair: pair[1] - pair[0])
    runs = (list(members) for _, members in keyed_runs)
    return [(run[0][1], run[-1][1]) for run in runs]

In [None]:
#This method will perform minor edits by trimming the spans and removing singletons
#Credit: https://github.com/ipavlopoulos/toxic_spans/blob/master/evaluation/fix_spans.py
SPECIAL_CHARACTERS = string.whitespace
def fix_spans(spans, text, special_characters=SPECIAL_CHARACTERS):
  """Trim special characters from the edges of each contiguous span and drop very short spans.

  Args:
      spans: list of character offsets into `text`.
      text: the post the offsets refer to.
      special_characters: characters stripped from both ends of every range (whitespace by default).

  Returns:
      Flat list of the character offsets that remain after trimming.
  """
  kept_offsets = []
  for start, stop in contiguous_ranges(spans):
      # Move the left edge right while it sits on a special character
      while text[start] in special_characters and start < stop:
          start += 1
      # Move the right edge left while it sits on a special character
      while text[stop] in special_characters and start < stop:
          stop -= 1
      # Ranges with stop - start <= 1 (at most two characters) are discarded
      if stop - start <= 1:
          continue
      kept_offsets.extend(range(start, stop + 1))
  return kept_offsets

In [None]:
#This method is used to invoke the fix_spans method to perform minor edits to the spans
def clean_spans(spans, posts):
  """Apply fix_spans to every span list, pairing spans[i] with posts[i].

  Args:
      spans: list of per-post offset lists.
      posts: list of post texts, indexed in parallel with `spans`.

  Returns:
      List with one cleaned offset list per entry of `spans`.
  """
  return [fix_spans(offsets, posts[position]) for position, offsets in enumerate(spans)]

In [None]:
# This method is used to compute the number of spans in the train, trial and test data
def number_of_spans(spans):
  """Count posts by how many contiguous toxic ranges their span list contains.

  Args:
      spans: list of per-post offset lists.

  Returns:
      Tuple (posts with no offsets, posts with exactly one range, posts with more than one range).
  """
  empty_total = single_total = multi_total = 0

  for offsets in spans:
    if len(offsets) == 0:
      empty_total += 1
      continue
    range_count = len(contiguous_ranges(offsets))
    if range_count == 1:
      single_total += 1
    elif range_count > 1:
      multi_total += 1

  return empty_total, single_total, multi_total

In [None]:
# Method to load the training data and compute few exploratory metrics
def load_train_dataset():
  """Read 'tsd_train.csv', print span statistics and return the posts with their cleaned spans.

  Returns:
      Tuple (list of post texts, list of cleaned offset lists), aligned by position.
  """
  frame = pd.read_csv('tsd_train.csv')
  # literal_eval converts each cell of the spans column from text into a Python object
  frame["spans"] = frame.spans.apply(literal_eval)

  posts = frame["text"].values.tolist()
  raw_spans = frame["spans"].values.tolist()

  # Clean the spans with the helper provided by the SemEval organizers
  trimmed_spans = clean_spans(raw_spans, posts)

  # The statistics below are computed on the spans as read, not on the cleaned ones
  n_empty, n_single, n_multi = number_of_spans(raw_spans)

  summary = (
    ('Total Training Samples:', len(posts)),
    ('Empty Spans:', n_empty),
    ('Single Spans:', n_single),
    ('Multi Spans:', n_multi),
  )
  for label, value in summary:
    print(label, value)
  print('*************************************************************')

  return posts, trimmed_spans

In [None]:
# Method to load the trial data and compute few exploratory metrics
def load_trial_dataset():
  """Read 'tsd_trail.csv', print span statistics and return the posts with their cleaned spans.

  Returns:
      Tuple (list of post texts, list of cleaned offset lists), aligned by position.
  """
  # The file name spelling matches the name used when the file was downloaded
  frame = pd.read_csv('tsd_trail.csv')
  # literal_eval converts each cell of the spans column from text into a Python object
  frame["spans"] = frame.spans.apply(literal_eval)

  posts = frame["text"].values.tolist()
  raw_spans = frame["spans"].values.tolist()

  # Clean the spans with the helper provided by the SemEval organizers
  trimmed_spans = clean_spans(raw_spans, posts)

  # The statistics below are computed on the spans as read, not on the cleaned ones
  n_empty, n_single, n_multi = number_of_spans(raw_spans)

  summary = (
    ('Total Validation Samples:', len(posts)),
    ('Empty Spans:', n_empty),
    ('Single Spans:', n_single),
    ('Multi Spans:', n_multi),
  )
  for label, value in summary:
    print(label, value)
  print('*************************************************************')

  return posts, trimmed_spans

In [None]:
# Method to load the test data and compute few exploratory metrics
def load_test_dataset():
  """Read 'tsd_test.csv', print span statistics and return the posts with their cleaned spans.

  Returns:
      Tuple (list of post texts, list of cleaned offset lists), aligned by position.
  """
  frame = pd.read_csv('tsd_test.csv')
  # literal_eval converts each cell of the spans column from text into a Python object
  frame["spans"] = frame.spans.apply(literal_eval)

  posts = frame["text"].values.tolist()
  raw_spans = frame["spans"].values.tolist()

  # Clean the spans with the helper provided by the SemEval organizers
  trimmed_spans = clean_spans(raw_spans, posts)

  # The statistics below are computed on the spans as read, not on the cleaned ones
  n_empty, n_single, n_multi = number_of_spans(raw_spans)

  summary = (
    ('Total Test Samples:', len(posts)),
    ('Empty Spans:', n_empty),
    ('Single Spans:', n_single),
    ('Multi Spans:', n_multi),
  )
  for label, value in summary:
    print(label, value)
  print('*************************************************************')

  return posts, trimmed_spans

In [None]:
# This method will compute the maximum length of a post in train, trial and test data
def max_post_length(toxic_posts):
  """Locate the longest post.

  Args:
      toxic_posts: iterable of post texts.

  Returns:
      Tuple (index of the first longest post, its length); (0, 0) when there are no
      posts or every post is empty.
  """
  lengths = [len(post) for post in toxic_posts]
  longest = max(lengths, default=0)
  if longest == 0:
    return 0, 0
  # list.index returns the first occurrence, i.e. the earliest post of maximal length
  return lengths.index(longest), longest

In [None]:
# Invoke method to load train, trial and test data.
# The `trail_*` names hold the trial split; later cells refer to them by this spelling.
train_posts, train_spans = load_train_dataset()
trail_posts, trail_spans = load_trial_dataset()
test_posts, test_spans = load_test_dataset()

# max_post_length returns a tuple: (index of the longest post, its length in characters)
print('Train data max post length:', max_post_length(train_posts))
print('Trial data max post length:',max_post_length(trail_posts))
print('Test data max post length:',max_post_length(test_posts))

Total Training Samples: 7939
Empty Spans: 485
Single Spans: 5370
Multi Spans: 2084
*************************************************************
Total Validation Samples: 690
Empty Spans: 43
Single Spans: 448
Multi Spans: 199
*************************************************************
Total Test Samples: 2000
Empty Spans: 394
Single Spans: 1407
Multi Spans: 199
*************************************************************
Train data max post length: (212, 1000)
Trial data max post length: (324, 998)
Test data max post length: (713, 1000)


In [None]:
# This is used to collect all the toxic words in the train, trial and test data
def get_toxic_words_count(posts):
  """Count how often each profane word occurs across `posts`.

  Each post is split into sentences and then into word tokens with NLTK; every
  token is run through the spaCy pipeline `nlp`, whose profanity component
  exposes `doc._.is_profane`.

  Args:
      posts: iterable of post texts.

  Returns:
      dict mapping each profane word (as tokenized) to its number of occurrences.
  """
  toxic_dict = dict()
  for post in posts:
    for sentence in sent_tokenize(post):
      for word in word_tokenize(sentence):
        # Each word is parsed on its own, so the flag describes that single token
        if nlp(word)._.is_profane:
          toxic_dict[word] = toxic_dict.get(word, 0) + 1
  return toxic_dict

In [None]:
# Generate a word cloud for the toxic words
def toxic_word_cloud(toxic_word_dict):
  """Draw a word cloud from the keys of `toxic_word_dict`.

  Args:
      toxic_word_dict: dict whose keys are the words to draw (for example the
          result of get_toxic_words_count). Only the keys are used, so the
          counts stored as values do not influence the picture.
  """
  # Read the words from the argument rather than from a notebook-level variable
  text = ' '.join([str(elem) for elem in toxic_word_dict.keys()])
  # `stopwords` is the notebook-level stop-word collection
  wordcloud = WordCloud(width = 1200, height = 800, stopwords = stopwords).generate(text)
  plt.imshow(wordcloud, interpolation='bilinear')
  plt.axis("off")
  plt.tight_layout(pad = 0)
  plt.show()

In [None]:
# #Get toxic words for training data and generate word cloud
# train_toxic_words_dict = get_toxic_words_count(train_posts)
# toxic_word_cloud(train_toxic_words_dict)

# #Get toxic words for validation data and generate word cloud
# trail_toxic_words_dict = get_toxic_words_count(trail_posts)
# toxic_word_cloud(trail_toxic_words_dict)

# #Get toxic words for test data and generate word cloud
# test_toxic_words_dict = get_toxic_words_count(test_posts)
# toxic_word_cloud(test_toxic_words_dict)

In [None]:
# Define BERT Tokenizer and BERT Model (both from the 'bert-base-cased' checkpoint).
bertTokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
# num_labels=2: the token-classification head predicts one of two classes per token
bertModel = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=2)
# Move the model to the device chosen earlier (GPU when available, otherwise CPU)
bertModel.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [None]:
#Tokenize Training, Validation and Testing data.
# return_offsets_mapping=True keeps, for every token, its character offsets in the post;
# those offsets are needed to translate between token labels and character spans.
# padding/truncation are enabled and the result is returned as PyTorch tensors.
train_posts_encodings = bertTokenizer(train_posts, return_offsets_mapping=True,  padding=True, truncation=True, return_tensors="pt")
trial_posts_encodings = bertTokenizer(trail_posts, return_offsets_mapping=True, padding=True, truncation=True, return_tensors="pt")
test_posts_encodings = bertTokenizer(test_posts, return_offsets_mapping=True, padding=True, truncation=True, return_tensors="pt")

In [None]:
# This method is used to match the spans after BERT tokenizes words into sub-tokens
def encode_spans(encoding, span):
  """Turn character-level toxic offsets into one label per token.

  Args:
      encoding: object exposing `tokens` and `offsets` (one (start, end) pair per token).
      span: iterable of toxic character offsets for the post.

  Returns:
      List with one entry per token: -100 where the offset pair is (0, 0),
      1 where the token's character range contains a toxic offset, otherwise 0.
  """
  toxic_chars = set(span)
  token_labels = [0] * len(encoding.tokens)

  for position, char_range in enumerate(encoding.offsets):
    if char_range == (0, 0):
      # Tokens without a character range get the -100 marker instead of a class label
      token_labels[position] = -100
    elif any(char in toxic_chars for char in range(char_range[0], char_range[1])):
      token_labels[position] = 1

  return token_labels

In [None]:
# Update train, trial and test spans to match with the sub-tokens:
# each character-offset list becomes a per-token label list via encode_spans.
# Indexing the batch encoding with i yields the i-th post's encoding.
train_spans_updated = [encode_spans(train_posts_encodings[i], span) for i, span in enumerate(train_spans)]
trial_spans_updated = [encode_spans(trial_posts_encodings[i], span) for i, span in enumerate(trail_spans)]
test_spans_updated = [encode_spans(test_posts_encodings[i], span) for i, span in enumerate(test_spans)]

In [None]:
#Define the class for Toxic Spans Dataset
class ToxicDataset(torch.utils.data.Dataset):
    """Dataset pairing tokenizer encodings with per-token labels.

    Args:
        encodings: mapping of field name -> per-sample values (must provide `.items()`
            and support indexing each value by sample position).
        labels: sequence with one label list per sample.
    """
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    @staticmethod
    def _as_tensor(value):
        """Return `value` as a new tensor; tensors are copied with clone().detach()."""
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning on every call
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, idx):
        """Return a dict of tensors for sample `idx`, including its 'labels' entry."""
        item = {key: self._as_tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = self._as_tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of samples, taken from the labels sequence."""
        return len(self.labels)

In [None]:
# Create ToxicDataset objects for the training, trial and test datasets.
# Each dataset keeps a reference to its encodings object (no copy is made).
train_dataset = ToxicDataset(train_posts_encodings, train_spans_updated)
trial_dataset = ToxicDataset(trial_posts_encodings, trial_spans_updated)
test_dataset = ToxicDataset(test_posts_encodings, test_spans_updated)

In [None]:
# Offset Mappings are not passed to the BERT Model. So Remove them from the encodings.
# The popped mappings are kept in their own variables for post-processing the predictions.
train_offset_mapping = train_posts_encodings.pop("offset_mapping") 
trial_offset_mapping = trial_posts_encodings.pop("offset_mapping")
test_offset_mapping = test_posts_encodings.pop("offset_mapping")

In [None]:
# This method will do the post-processing steps as defined in the model
def post_process_predictions(predicted_offsets, posts_offset_mappings, posts_encodings):
    """Convert per-token predictions into per-post lists of toxic character offsets.

    Rules applied while walking the tokens of a post (starting at position 1 and
    stopping at the first '[SEP]'):
      * a token predicted 1 contributes its character range;
      * when the preceding token is also predicted 1, the characters between
        the two tokens are included as well;
      * a word split into '##' pieces is treated as a unit: if any piece is
        predicted 1, the characters of all its pieces are included once.

    Args:
        predicted_offsets: object whose `predictions` attribute supports
            `.argmax(-1)` and yields one row of class ids per post.
        posts_offset_mappings: per post, the (start, end) character offsets of each token.
        posts_encodings: indexable by post position; each item exposes `tokens`.

    Returns:
        List with one list of character offsets per post.
    """
    # Pick the highest-scoring class for every token
    predictions = predicted_offsets.predictions.argmax(-1)

    final_processed_char_offsets = list()

    # Iterate over each offset mapping, and its corresponding prediction
    for index, (offset, prediction) in enumerate(zip(posts_offset_mappings, predictions)):
        post_processed_offset = list()
        # Get tokens for each post
        offset_tokens = posts_encodings[index].tokens

        # Position 0 is skipped (presumably the '[CLS]' token -- confirm against the tokenizer)
        current_token_index = 1
        while current_token_index < len(offset_tokens):
            # If the token being processed is '[SEP]' then stop processing this post
            if offset_tokens[current_token_index] == '[SEP]':
                break
            char_offsets = list()

            current_token = offset_tokens[current_token_index]
            previous_token_prediction = prediction[current_token_index - 1]
            current_token_prediction = prediction[current_token_index]

            # The last token has no successor to look at
            if current_token_index + 1 == len(offset_tokens):
                if current_token_prediction == 1:
                    if previous_token_prediction == 1:
                        char_offsets.extend(get_last_offset_value(offset, current_token_index - 1))
                    char_offsets.extend(get_offset_values_from_range(offset, current_token_index))
                # Record the offsets before leaving the loop so they are not discarded
                post_processed_offset.extend(char_offsets)
                break

            next_token = offset_tokens[current_token_index + 1]
            next_token_prediction = prediction[current_token_index + 1]

            is_current_token_starts_with_hash = current_token.startswith('##')
            is_next_token_starts_with_hash = next_token.startswith('##')

            if (not is_current_token_starts_with_hash) and (not is_next_token_starts_with_hash):
                # Whole-word token followed by another whole-word token
                if current_token_prediction == 1:
                    if previous_token_prediction == 1:
                        # Characters between the previous token and this one
                        char_offsets.extend(get_last_offset_value(offset, current_token_index - 1))
                    char_offsets.extend(get_offset_values_from_range(offset, current_token_index))
                    if next_token_prediction == 1:
                        # Both tokens are toxic: add the gap and the next token, then skip past it
                        char_offsets.extend(get_last_offset_value(offset, current_token_index))
                        char_offsets.extend(get_offset_values_from_range(offset, current_token_index + 1))
                        current_token_index = current_token_index + 2
                    else:
                        current_token_index = current_token_index + 1
                else:
                    current_token_index = current_token_index + 1
            else:
                # The current or the next token is a '##' piece: gather the whole run of pieces
                hash_char_offsets = list()
                hash_char_predictions = list()

                if previous_token_prediction == 1:
                    hash_char_offsets.extend(get_last_offset_value(offset, current_token_index - 1))

                hash_char_offsets.extend(get_offset_values_from_range(offset, current_token_index))
                hash_char_predictions.append(current_token_prediction)

                current_token_index = current_token_index + 1
                while (current_token_index < len(offset_tokens)
                       and offset_tokens[current_token_index].startswith('##')):
                    hash_char_offsets.extend(get_offset_values_from_range(offset, current_token_index))
                    hash_char_predictions.append(prediction[current_token_index])
                    current_token_index = current_token_index + 1

                # Decide once, after all pieces are gathered, so offsets are added a single time
                if 1 in hash_char_predictions:
                    char_offsets.extend(hash_char_offsets)

            post_processed_offset.extend(char_offsets)
        final_processed_char_offsets.append(post_processed_offset)
    return final_processed_char_offsets

In [None]:
# This method will compute the offset value in a given range
def get_offset_values_from_range(offset_mapping, index):
  """List every character offset covered by the token at `index`.

  Args:
      offset_mapping: per-token (start, end) pairs for one post.
      index: position of the token.

  Returns:
      The integers start, start + 1, ..., end - 1 (end is exclusive).
  """
  range_start = offset_mapping[index][0]
  range_stop = offset_mapping[index][1]
  return list(range(range_start, range_stop))

In [None]:
# This method will return the offsets lying between a token and the token after it
def get_last_offset_value(offset_mapping, index):
  """List the character offsets between the end of token `index` and the start of token `index + 1`.

  Args:
      offset_mapping: per-token (start, end) pairs for one post.
      index: position of the first of the two neighbouring tokens.

  Returns:
      The integers from the first token's end up to, but excluding, the next
      token's start; empty when the tokens touch.
  """
  gap_start = offset_mapping[index][1]
  gap_stop = offset_mapping[index + 1][0]
  return list(range(gap_start, gap_stop))

In [None]:
# This method will compute per post precision
def compute_per_post_precision(pred_spans, true_spans):
  """Precision of the predicted character offsets for one post.

  Offsets are compared as sets, so an offset listed more than once counts once
  in both the overlap and the denominator.

  Args:
      pred_spans: predicted character offsets.
      true_spans: gold character offsets.

  Returns:
      1.0 when both are empty, 0.0 when exactly one of them is empty, otherwise
      |pred ∩ true| / |pred|.
  """
  pred_set = set(pred_spans)
  true_set = set(true_spans)

  # No gold offsets: perfect only if nothing was predicted
  if not true_set:
    return 1.0 if not pred_set else 0.0

  # Gold offsets exist but nothing was predicted
  if not pred_set:
    return 0.0

  return len(pred_set & true_set) / len(pred_set)

In [None]:
# This method will compute per post recall value
def compute_per_post_recall(pred_spans, true_spans):
  """Recall of the predicted character offsets for one post.

  Offsets are compared as sets, so an offset listed more than once counts once
  in both the overlap and the denominator.

  Args:
      pred_spans: predicted character offsets.
      true_spans: gold character offsets.

  Returns:
      1.0 when both are empty, 0.0 when exactly one of them is empty, otherwise
      |pred ∩ true| / |true|.
  """
  pred_set = set(pred_spans)
  true_set = set(true_spans)

  # No gold offsets: perfect only if nothing was predicted
  if not true_set:
    return 1.0 if not pred_set else 0.0

  # Return zero recall when nothing was predicted
  if not pred_set:
    return 0.0

  return len(pred_set & true_set) / len(true_set)

In [None]:
# This method will compute per post f1-score
def compute_per_post_f1_score(pred_spans, true_spans):
  """Character-offset F1 for one post.

  Offsets are compared as sets, so an offset listed more than once counts once
  in both the overlap and the denominator.

  Args:
      pred_spans: predicted character offsets.
      true_spans: gold character offsets.

  Returns:
      1.0 when both are empty, 0.0 when exactly one of them is empty, otherwise
      2 * |pred ∩ true| / (|pred| + |true|).
  """
  pred_spans_set = set(pred_spans)
  true_spans_set = set(true_spans)

  # No gold offsets: perfect only if nothing was predicted
  if not true_spans_set:
    return 1.0 if not pred_spans_set else 0.0

  if not pred_spans_set:
    return 0.0

  overlap = len(pred_spans_set & true_spans_set)
  return (2 * overlap) / (len(pred_spans_set) + len(true_spans_set))

In [None]:
# This method will compute system level precision, recall and f1-score values
def compute_all_evaluation_metrics(pred_spans, true_spans):
  """Average the per-post precision, recall and F1 over all posts.

  Args:
      pred_spans: list of predicted offset lists, one per post.
      true_spans: list of gold offset lists, aligned with `pred_spans`.

  Returns:
      Tuple (mean precision, mean recall, mean F1).
  """
  def _mean_of(per_post_metric):
    # Score every (prediction, gold) pair with the given metric and average the scores
    return np.mean([per_post_metric(pred, gold) for pred, gold in zip(pred_spans, true_spans)])

  return (
    _mean_of(compute_per_post_precision),
    _mean_of(compute_per_post_recall),
    _mean_of(compute_per_post_f1_score),
  )

In [None]:
# This method will compute the evaluation metrics for the predicted and true spans
def compute_metrics(pred_offsets, true_spans, offset_mappings, posts_encodings):  
  """Post-process raw predictions into character offsets and score them against the gold spans.

  Args:
      pred_offsets: prediction output passed on to post_process_predictions.
      true_spans: list of gold offset lists, one per post.
      offset_mappings: per-post token offset mappings.
      posts_encodings: tokenizer encodings for the same posts.

  Returns:
      Tuple (precision, recall, f1_score) as produced by compute_all_evaluation_metrics.
  """
  char_level_predictions = post_process_predictions(pred_offsets, offset_mappings, posts_encodings)
  return compute_all_evaluation_metrics(char_level_predictions, true_spans)

In [None]:
# Global matplotlib settings for the plots in this notebook.
# NOTE(review): plt.figure() here opens an empty figure as soon as the cell runs.
plt.figure(dpi=600)
# Larger axis labels and base font size
plt.rc('axes', labelsize=16)
plt.rc('font', size=13)   
#This method will plot the confusion matrix
def plot_confusion_matrices(encodings, predictions, labels):
  """Plot a normalised confusion matrix of token labels.

  For every post, the entries from position 1 up to (not including) the first
  '[SEP]' token are collected from both `predictions` and `labels`.

  Args:
      encodings: indexable by post position; each item exposes `tokens`.
      predictions: per-post sequences of predicted labels.
      labels: per-post sequences of gold labels.
  """
  y_true, y_pred = [], []

  for row, (pred, gold) in enumerate(zip(predictions, labels)):
    tokens = encodings[row].tokens
    # Scan forward from position 1 until the '[SEP]' token is found
    sep_position = 1
    while tokens[sep_position] != '[SEP]':
      sep_position += 1
    y_true.extend(gold[1: sep_position])
    y_pred.extend(pred[1: sep_position])

  # Keep the prediction list no longer than the gold list
  y_pred = y_pred[:len(y_true)]

  # Normal confusion matrix
  cf_matrix = confusion_matrix(y_true, y_pred)
  class_names = ['Neutral', 'Toxic']

  ax = plt.axes()
  # Each cell is shown as a share of all counted tokens
  sns.heatmap(
    cf_matrix / np.sum(cf_matrix),
    annot=True,
    fmt='.2%',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
    cmap="YlGnBu",
  )

  ax.set_xlabel('Predicted')
  ax.set_ylabel('Actual')

<Figure size 3600x2400 with 0 Axes>

In [None]:
# Configure Training Arguments for training the BERT Model
training_args = TrainingArguments(
  # NOTE(review): Drive is mounted at '/content/drive' earlier in this notebook, so
  # '/drive/MyDrive/' is a different directory -- confirm this is the intended location.
  output_dir = '/drive/MyDrive/',
  num_train_epochs = 2,                 # total number of training epochs
  per_device_train_batch_size = 16,     # batch size per device during training
  per_device_eval_batch_size = 16,      # batch size for evaluation
  warmup_steps = 500,                   # number of warmup steps for learning rate scheduler
  weight_decay = 0.01,                  # strength of weight decay
  do_eval = True,                       # whether to run evaluation on the val set
  evaluation_strategy = "steps",        # evaluate at a fixed step interval during training
  learning_rate = 2e-5,                 # explicit learning rate for fine-tuning
  disable_tqdm = False,                 # False keeps the tqdm progress bars enabled
)

In [None]:
# Configure the Trainer object for training the BERT Model
trainer = Trainer(
  model=bertModel,                  # the token-classification model to fine-tune
  args=training_args,               # the training arguments configured above
  train_dataset=train_dataset,      # ToxicDataset built from the training split
  eval_dataset=trial_dataset,       # ToxicDataset built from the trial split, used for evaluation
)

In [None]:
 #import torch
#torch.cuda.empty_cache()
#torch.cuda.memory_summary(device=None, abbreviated=False)

In [None]:
# Train the BERT Model (fine-tunes bertModel with the Trainer configured above)
trainer.train()

In [None]:
# Run the trained BERT Model on the trial and test data.
# The returned objects carry the raw per-token scores in `.predictions`.
trial_predictions = trainer.predict(trial_dataset)
test_predictions = trainer.predict(test_dataset)

In [None]:
# Post-process the predictions: turn per-token predictions into character offsets per post
trial_processed_preds = post_process_predictions(trial_predictions, trial_offset_mapping, trial_posts_encodings)
test_processed_preds = post_process_predictions(test_predictions, test_offset_mapping, test_posts_encodings)

In [None]:
# Compute final metrics; each result is a (precision, recall, f1) tuple
trial_metrics = compute_metrics(trial_predictions, trail_spans, trial_offset_mapping, trial_posts_encodings)
test_metrics = compute_metrics(test_predictions, test_spans, test_offset_mapping, test_posts_encodings)

# Print both reports with the same layout: a heading followed by the three scores
for heading, scores in (
  ("******Validation Evaluation Metrics**********", trial_metrics),
  ("*********Test Evaluation Metrics*************", test_metrics),
):
  print(heading)
  print('Precision:{:.3f}'.format(scores[0]))
  print('Recall:{:.3f}'.format(scores[1]))
  print('F1-Score:{:.3f}'.format(scores[2]))


In [None]:
#plot_confusion_matrices(test_posts_encodings, test_processed_preds, test_spans)

In [None]:
# Save Evaluation Metrics to File for trial data.
# Google Drive is mounted at /content/drive, so the file has to be created under
# /content/drive/MyDrive for it to be stored on the Drive.
with open('/content/drive/MyDrive/evaluation_metrics_trial_data.txt', "w") as writer:
  precision = trial_metrics[0]
  recall = trial_metrics[1]
  f1_score = trial_metrics[2]

  writer.write(f"Precision:{str(precision)} \n")
  writer.write(f"Recall: {str(recall)} \n")
  writer.write(f"F1_score: {str(f1_score)} \n")
# No explicit close(): leaving the `with` block already closes the file


In [None]:
# Save Evaluation Metrics to File for test data.
# Google Drive is mounted at /content/drive, so the file has to be created under
# /content/drive/MyDrive for it to be stored on the Drive.
with open('/content/drive/MyDrive/evaluation_metrics_test_data.txt', "w") as writer:
  precision = test_metrics[0]
  recall = test_metrics[1]
  f1_score = test_metrics[2]

  writer.write(f"Precision:{str(precision)} \n")
  writer.write(f"Recall: {str(recall)} \n")
  writer.write(f"F1_score: {str(f1_score)} \n")
# No explicit close(): leaving the `with` block already closes the file

In [None]:
# Save Predicted Spans for Trial data in a text file (one bracketed, space-separated list per post).
# Google Drive is mounted at /content/drive, so the file has to be created under
# /content/drive/MyDrive for it to be stored on the Drive.
with open('/content/drive/MyDrive/trial_predicted_spans.txt', "w") as writer:
  for pred in trial_processed_preds:
    pred_str = ' '.join([str(elem) for elem in pred])
    pred_str = "[" + pred_str + "]"
    writer.write(f"{str(pred_str)} \n")

In [None]:
# Save Predicted Spans for Test data in a text file (one bracketed, space-separated list per post).
# Google Drive is mounted at /content/drive, so the file has to be created under
# /content/drive/MyDrive for it to be stored on the Drive.
with open('/content/drive/MyDrive/test_predicted_spans.txt', "w") as writer:
  for pred in test_processed_preds:
    pred_str = ' '.join([str(elem) for elem in pred])
    pred_str = "[" + pred_str + "]"
    writer.write(f"{str(pred_str)} \n")

In [None]:
# Save Text of predicted spans for trial data.
# Google Drive is mounted at /content/drive, so the file has to be created under
# /content/drive/MyDrive for it to be stored on the Drive.
trial_pred_output_file = '/content/drive/MyDrive/trial_predicted_text.txt'
with open(trial_pred_output_file, "w") as fileWriter:
  for toxic_post, predicted_span in zip(trail_posts, trial_processed_preds):
    # Turn the flat offset list into (first, last) ranges and cut the matching text out of the post
    pred_span_ranges = contiguous_ranges(predicted_span)
    toxic_span_text = [toxic_post[first:last + 1] for first, last in pred_span_ranges]
    fileWriter.write(f"Original Toxic Post:{toxic_post} \n")
    fileWriter.write(f"Predicted Toxic Span Offset:{predicted_span} \n")
    fileWriter.write(f"Predicted Toxic Text:{str(toxic_span_text)} \n")
    fileWriter.write("*************************************************\n")
# No explicit close(): leaving the `with` block already closes the file

In [None]:
# Save Text of predicted spans for test data.
# Google Drive is mounted at /content/drive, so the file has to be created under
# /content/drive/MyDrive for it to be stored on the Drive.
test_pred_output_file = '/content/drive/MyDrive/test_predicted_text.txt'
with open(test_pred_output_file, "w") as fileWriter:
  for toxic_post, predicted_span in zip(test_posts, test_processed_preds):
    # Turn the flat offset list into (first, last) ranges and cut the matching text out of the post
    pred_span_ranges = contiguous_ranges(predicted_span)
    toxic_span_text = [toxic_post[first:last + 1] for first, last in pred_span_ranges]
    fileWriter.write(f"Original Toxic Post:{toxic_post} \n")
    fileWriter.write(f"Predicted Toxic Span Offset:{predicted_span} \n")
    fileWriter.write(f"Predicted Toxic Text:{str(toxic_span_text)} \n")
    fileWriter.write("*************************************************\n")
# No explicit close(): leaving the `with` block already closes the file


In [None]:
# Ask PyTorch to release the GPU memory it is caching, as the final step of the notebook
torch.cuda.empty_cache()