# Bert Multilabel Classification for Toxic Comments

## Import

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import sklearn
import string
import re

# Feature Extraction
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics
# Cleaning
import nltk
from nltk.corpus import stopwords
import contractions
# Part of the Speech Tagging and Lemmatization
from nltk.corpus import wordnet
from nltk.corpus import brown
from nltk.stem import WordNetLemmatizer
import torch
from wordcloud import WordCloud
import seaborn as sns
from PIL import Image

# Models
import transformers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader 
from transformers import BertModel
# Visualization
import matplotlib.pyplot as plt

stop = set(stopwords.words('english'))
nltk.download("stopwords")
nltk.download("wordnet")
nltk.download('brown')
plt.style.use('ggplot')

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Caverna\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Caverna\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\Caverna\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!


## Import data e preprocessing

In [2]:
train_data = pd.read_csv('data/train.csv')
test_data = pd.read_csv('data/test.csv')

def cleaning(text):

    # Lower case
    text = text.lower()

    # Remove Contractions
    text = contractions.fix(text)

    # Remove Slangs
    text = slang_clean(text)

    # Remove special characters
    text = re.sub('\[.*?\]', '', text)

    # Remove links
    text = re.sub('https?://\S+|www\.\S+', '', text)

    # Remove html tags
    text = re.sub ('<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});', '', text)

    # Remove non ASCII characters
    text = text.encode("ascii", "ignore")
    text = text.decode()

    # Remove punctuation
    text = re.sub('<.*?>+', '', text)
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)

    # Remove words with numbers in them
    text = re.sub('\w*\d\w*', '', text)

    # Spellings correction
    #text = TextBlob(text).correct()

    # Remove stop words
    text = ' '.join(word for word in text.split(' ') if word not in stop)
    
    # Remove \n and \r
    text = re.sub(r'(\n)+', ' ', text)
    text = re.sub(r'(\r)+', '', text)

    # Remove starting and ending spaces
    text = re.sub(r'^\s+|\s+$', '', text)

    # Remove extra spaces
    text = re.sub(r'\s+', ' ', text)

    return text

def slang_clean(text):
        """
            Other manual text cleaning techniques
        """
        # Typos, slang and other
        sample_typos_slang = {
                                "w/e": "whatever",
                                "usagov": "usa government",
                                "recentlu": "recently",
                                "ph0tos": "photos",
                                "amirite": "am i right",
                                "exp0sed": "exposed",
                                "<3": "love",
                                "luv": "love",
                                "amageddon": "armageddon",
                                "trfc": "traffic",
                                "16yr": "16 year"
                                }

        # Acronyms
        sample_acronyms =  { 
                            "mh370": "malaysia airlines flight 370",
                            "okwx": "oklahoma city weather",
                            "arwx": "arkansas weather",    
                            "gawx": "georgia weather",  
                            "scwx": "south carolina weather",  
                            "cawx": "california weather",
                            "tnwx": "tennessee weather",
                            "azwx": "arizona weather",  
                            "alwx": "alabama weather",
                            "usnwsgov": "united states national weather service",
                            "2mw": "tomorrow"
                            }

        
        # Some common abbreviations 
        sample_abbr = {
                        "$" : " dollar ",
                        "€" : " euro ",
                        "4ao" : "for adults only",
                        "a.m" : "before midday",
                        "a3" : "anytime anywhere anyplace",
                        "aamof" : "as a matter of fact",
                        "acct" : "account",
                        "adih" : "another day in hell",
                        "afaic" : "as far as i am concerned",
                        "afaict" : "as far as i can tell",
                        "afaik" : "as far as i know",
                        "afair" : "as far as i remember",
                        "afk" : "away from keyboard",
                        "app" : "application",
                        "approx" : "approximately",
                        "apps" : "applications",
                        "asap" : "as soon as possible",
                        "asl" : "age, sex, location",
                        "atk" : "at the keyboard",
                        "ave." : "avenue",
                        "aymm" : "are you my mother",
                        "ayor" : "at your own risk", 
                        "b&b" : "bed and breakfast",
                        "b+b" : "bed and breakfast",
                        "b.c" : "before christ",
                        "b2b" : "business to business",
                        "b2c" : "business to customer",
                        "b4" : "before",
                        "b4n" : "bye for now",
                        "b@u" : "back at you",
                        "bae" : "before anyone else",
                        "bak" : "back at keyboard",
                        "bbbg" : "bye bye be good",
                        "bbc" : "british broadcasting corporation",
                        "bbias" : "be back in a second",
                        "bbl" : "be back later",
                        "bbs" : "be back soon",
                        "be4" : "before",
                        "bfn" : "bye for now",
                        "blvd" : "boulevard",
                        "bout" : "about",
                        "brb" : "be right back",
                        "bros" : "brothers",
                        "brt" : "be right there",
                        "bsaaw" : "big smile and a wink",
                        "btw" : "by the way",
                        "bwl" : "bursting with laughter",
                        "c/o" : "care of",
                        "cet" : "central european time",
                        "cf" : "compare",
                        "cia" : "central intelligence agency",
                        "csl" : "can not stop laughing",
                        "cu" : "see you",
                        "cul8r" : "see you later",
                        "cv" : "curriculum vitae",
                        "cwot" : "complete waste of time",
                        "cya" : "see you",
                        "cyt" : "see you tomorrow",
                        "dae" : "does anyone else",
                        "dbmib" : "do not bother me i am busy",
                        "diy" : "do it yourself",
                        "dm" : "direct message",
                        "dwh" : "during work hours",
                        "e123" : "easy as one two three",
                        "eet" : "eastern european time",
                        "eg" : "example",
                        "embm" : "early morning business meeting",
                        "encl" : "enclosed",
                        "encl." : "enclosed",
                        "etc" : "and so on",
                        "faq" : "frequently asked questions",
                        "fawc" : "for anyone who cares",
                        "fb" : "facebook",
                        "fc" : "fingers crossed",
                        "fig" : "figure",
                        "fimh" : "forever in my heart", 
                        "ft." : "feet",
                        "ft" : "featuring",
                        "ftl" : "for the loss",
                        "ftw" : "for the win",
                        "fwiw" : "for what it is worth",
                        "fyi" : "for your information",
                        "g9" : "genius",
                        "gahoy" : "get a hold of yourself",
                        "gal" : "get a life",
                        "gcse" : "general certificate of secondary education",
                        "gfn" : "gone for now",
                        "gg" : "good game",
                        "gl" : "good luck",
                        "glhf" : "good luck have fun",
                        "gmt" : "greenwich mean time",
                        "gmta" : "great minds think alike",
                        "gn" : "good night",
                        "g.o.a.t" : "greatest of all time",
                        "goat" : "greatest of all time",
                        "goi" : "get over it",
                        "gps" : "global positioning system",
                        "gr8" : "great",
                        "gratz" : "congratulations",
                        "gyal" : "girl",
                        "h&c" : "hot and cold",
                        "hp" : "horsepower",
                        "hr" : "hour",
                        "hrh" : "his royal highness",
                        "ht" : "height",
                        "ibrb" : "i will be right back",
                        "ic" : "i see",
                        "icq" : "i seek you",
                        "icymi" : "in case you missed it",
                        "idc" : "i do not care",
                        "idgadf" : "i do not give a damn fuck",
                        "idgaf" : "i do not give a fuck",
                        "idk" : "i do not know",
                        "ie" : "that is",
                        "i.e" : "that is",
                        "ifyp" : "i feel your pain",
                        "IG" : "instagram",
                        "iirc" : "if i remember correctly",
                        "ilu" : "i love you",
                        "ily" : "i love you",
                        "imho" : "in my humble opinion",
                        "imo" : "in my opinion",
                        "imu" : "i miss you",
                        "iow" : "in other words",
                        "irl" : "in real life",
                        "j4f" : "just for fun",
                        "jic" : "just in case",
                        "jk" : "just kidding",
                        "jsyk" : "just so you know",
                        "l8r" : "later",
                        "lb" : "pound",
                        "lbs" : "pounds",
                        "ldr" : "long distance relationship",
                        "lmao" : "laugh my ass off",
                        "lmfao" : "laugh my fucking ass off",
                        "lol" : "laughing out loud",
                        "ltd" : "limited",
                        "ltns" : "long time no see",
                        "m8" : "mate",
                        "mf" : "motherfucker",
                        "mfs" : "motherfuckers",
                        "mfw" : "my face when",
                        "mofo" : "motherfucker",
                        "mph" : "miles per hour",
                        "mr" : "mister",
                        "mrw" : "my reaction when",
                        "ms" : "miss",
                        "mte" : "my thoughts exactly",
                        "nagi" : "not a good idea",
                        "nbc" : "national broadcasting company",
                        "nbd" : "not big deal",
                        "nfs" : "not for sale",
                        "ngl" : "not going to lie",
                        "nhs" : "national health service",
                        "nrn" : "no reply necessary",
                        "nsfl" : "not safe for life",
                        "nsfw" : "not safe for work",
                        "nth" : "nice to have",
                        "nvr" : "never",
                        "nyc" : "new york city",
                        "oc" : "original content",
                        "og" : "original",
                        "ohp" : "overhead projector",
                        "oic" : "oh i see",
                        "omdb" : "over my dead body",
                        "omg" : "oh my god",
                        "omw" : "on my way",
                        "p.a" : "per annum",
                        "p.m" : "after midday",
                        "pm" : "prime minister",
                        "poc" : "people of color",
                        "pov" : "point of view",
                        "pp" : "pages",
                        "ppl" : "people",
                        "prw" : "parents are watching",
                        "ps" : "postscript",
                        "pt" : "point",
                        "ptb" : "please text back",
                        "pto" : "please turn over",
                        "qpsa" : "what happens", #"que pasa",
                        "ratchet" : "rude",
                        "rbtl" : "read between the lines",
                        "rlrt" : "real life retweet", 
                        "rofl" : "rolling on the floor laughing",
                        "roflol" : "rolling on the floor laughing out loud",
                        "rotflmao" : "rolling on the floor laughing my ass off",
                        "rt" : "retweet",
                        "ruok" : "are you ok",
                        "sfw" : "safe for work",
                        "sk8" : "skate",
                        "smh" : "shake my head",
                        "sq" : "square",
                        "srsly" : "seriously", 
                        "ssdd" : "same stuff different day",
                        "stfu" : "shut the fuck up",
                        "tbh" : "to be honest",
                        "tbs" : "tablespooful",
                        "tbsp" : "tablespooful",
                        "tfw" : "that feeling when",
                        "thks" : "thank you",
                        "tho" : "though",
                        "thx" : "thank you",
                        "tia" : "thanks in advance",
                        "til" : "today i learned",
                        "tl;dr" : "too long i did not read",
                        "tldr" : "too long i did not read",
                        "tmb" : "tweet me back",
                        "tntl" : "trying not to laugh",
                        "ttyl" : "talk to you later",
                        "u" : "you",
                        "u2" : "you too",
                        "u4e" : "yours for ever",
                        "utc" : "coordinated universal time",
                        "w/" : "with",
                        "w/o" : "without",
                        "w8" : "wait",
                        "wassup" : "what is up",
                        "wb" : "welcome back",
                        "wtf" : "what the fuck",
                        "wtg" : "way to go",
                        "wtpa" : "where the party at",
                        "wuf" : "where are you from",
                        "wuzup" : "what is up",
                        "wywh" : "wish you were here",
                        "yd" : "yard",
                        "ygtr" : "you got that right",
                        "ynk" : "you never know",
                        "zzz" : "sleeping bored and tired"
                        }
            
        sample_typos_slang_pattern = re.compile(r'(?<!\w)(' + '|'.join(re.escape(key) for key in sample_typos_slang.keys()) + r')(?!\w)')
        sample_acronyms_pattern = re.compile(r'(?<!\w)(' + '|'.join(re.escape(key) for key in sample_acronyms.keys()) + r')(?!\w)')
        sample_abbr_pattern = re.compile(r'(?<!\w)(' + '|'.join(re.escape(key) for key in sample_abbr.keys()) + r')(?!\w)')
        
        text = sample_typos_slang_pattern.sub(lambda x: sample_typos_slang[x.group()], text)
        text = sample_acronyms_pattern.sub(lambda x: sample_acronyms[x.group()], text)
        text = sample_abbr_pattern.sub(lambda x: sample_abbr[x.group()], text)
        
        return text


train_data['text_clean'] = train_data['comment_text'].apply(lambda x: cleaning(x))
test_data['text_clean'] = test_data['comment_text'].apply(lambda x: cleaning(x))


# Part of Speech Tagging
wordnet_map = {"N":wordnet.NOUN, 
               "V":wordnet.VERB, 
               "J":wordnet.ADJ, 
               "R":wordnet.ADV
              }
    
train_sents = brown.tagged_sents(categories='news')
t0 = nltk.DefaultTagger('NN')
t1 = nltk.UnigramTagger(train_sents, backoff=t0)
t2 = nltk.BigramTagger(train_sents, backoff=t1)

def pos_tag_wordnet(text, pos_tag_type="pos_tag"):
    """
        Create pos_tag with wordnet format
    """
    
    pos_tagged_text = t2.tag(text)
    
    # map the pos tagging output with wordnet output 
    pos_tagged_text = [(word, wordnet_map.get(pos_tag[0])) if pos_tag[0] in wordnet_map.keys() else (word, wordnet.NOUN) for (word, pos_tag) in pos_tagged_text ]
   
    return pos_tagged_text

# Lemmatization
def lemmatize_word(text):
    
    lemmatizer = WordNetLemmatizer()
    lemma = [lemmatizer.lemmatize(word, tag) for word, tag in text]
    return lemma

# Apply Pos Tagging
train_data['separated'] = train_data['text_clean'].apply(lambda x: [x for x in x.split()])
train_data['text_pos'] = train_data['separated'].apply(lambda x: pos_tag_wordnet(x)) 
test_data['separated'] = test_data['text_clean'].apply(lambda x: [x for x in x.split()])
test_data['text_pos'] = test_data['separated'].apply(lambda x: pos_tag_wordnet(x))

# Apply Lemmatization
train_data['text_lem_wpos'] = train_data['text_pos'].apply(lambda x: lemmatize_word(x))
train_data['text_lem'] = [' '.join(map(str,l)) for l in train_data['text_lem_wpos']]
test_data['text_lem_wpos'] = test_data['text_pos'].apply(lambda x: lemmatize_word(x))
test_data['text_lem'] = [' '.join(map(str,l)) for l in test_data['text_lem_wpos']]


header = ['text_lem', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

test_labels = pd.read_csv('data/test_labels.csv')
test_wlabels = test_data.merge(test_labels)
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
for s in labels:
    test_wlabels = test_wlabels[test_wlabels[s] != -1]
train_data.to_csv('data/train_clean.csv', columns = header, index = False)
test_wlabels.to_csv('data/test_clean.csv', columns = header, index = False)



## Data analysis

In [None]:
mask = np.array(Image.open('./images/wikipedia_mask.jpg'))

def generate_wordcloud(df, clm):
    text = []
    comments = train_data.loc[df[clm] == 1]['text_clean']

    for c in comments:
        text.append(c) 
    words = ' '.join(text)
    return WordCloud(stopwords=stop, background_color='white', mask=mask, height=1500, width=1500).generate(words)

train_toxic = generate_wordcloud(train_data, 'toxic')
train_sev_toxic = generate_wordcloud(train_data, 'severe_toxic')
train_obscene = generate_wordcloud(train_data, 'obscene')
train_threat = generate_wordcloud(train_data, 'threat')
train_insult = generate_wordcloud(train_data, 'insult')
train_id_hate = generate_wordcloud(train_data, 'identity_hate')
train_general = WordCloud(stopwords=stop, background_color='white', height=1500, width=4500).generate(" ".join(train['text_clean']))

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(45, 10), gridspec_kw = {'wspace':0.01, 'hspace':0.1})
axes.imshow(train_general)
axes.axis('off')
axes.set_title('General Word Cloud')

plt.show()
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(15, 10), gridspec_kw = {'wspace':0.01, 'hspace':0.1})


axes[0][0].imshow(train_toxic)
axes[0][0].axis('off')
axes[0][0].set_title('Toxic Word Cloud')
axes[0][0].set_aspect('equal')

axes[0][1].imshow(train_sev_toxic)
axes[0][1].axis('off')
axes[0][1].set_title('Severely Toxic Word Cloud')
axes[0][1].set_aspect('equal')

axes[0][2].imshow(train_obscene)
axes[0][2].axis('off')
axes[0][2].set_title('Obscene Word Cloud')
axes[0][2].set_aspect('equal')

axes[1][0].imshow(train_threat)
axes[1][0].axis('off')
axes[1][0].set_title('Threat Word Cloud')
axes[1][0].set_aspect('equal')

axes[1][1].imshow(train_insult)
axes[1][1].axis('off')
axes[1][1].set_title('Insult Word Cloud')
axes[1][1].set_aspect('equal')

axes[1][2].imshow(train_id_hate)
axes[1][2].axis('off')
axes[1][2].set_title('Identity Hate Word Cloud')
axes[1][2].set_aspect('equal')


plt.show()

In [4]:

print(f'Non tossico: {len(train_data[train_data["toxic"] == 0])}') #14227


labels =  ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
for s in labels:
    print(f'{s}: {len(train_data[train_data[s] == 1])}')

count = {}
for s in labels:
    count[s] = 0

for d in range(len(train_data)):
    for s in labels:
        if train_data['toxic'][d] == 0 and train_data[s][d] == 1:
            count[s] += 1

print("..........")
for k in count.keys():
    
    print(f'{k}: {count[k]}')
        

print('..........')

count_bad_ppl = 0
for d in range(len(train_data)):
    for s in labels:
        if train_data[s][d] == 1:
            count_bad_ppl += 1
            break
        
print(f'Bad people: {count_bad_ppl}')

count_good_ppl = 0
for d in range(len(train_data)):
        if train_data['toxic'][d] == 0 and train_data['severe_toxic'][d] == 0 and train_data['obscene'][d] == 0 and train_data['threat'][d] == 0 and train_data['insult'][d] == 0 and train_data['identity_hate'][d] == 0:
            count_good_ppl += 1
            

print(f'Good people: {count_good_ppl}')

Non tossico: 33381
toxic: 15294
severe_toxic: 1595
obscene: 8449
threat: 478
insult: 7877
identity_hate: 1405


KeyError: 0

In [None]:

train_serie = pd.Series(train_data['text_clean'],), 
for d in range(len(train_data)):
        if train_data['toxic'][d] == 0 and train_data['severe_toxic'][d] == 0 and train_data['obscene'][d] == 0 and train_data['threat'][d] == 0 and train_data['insult'][d] == 0 and train_data['identity_hate'][d] == 0:
            count_good_ppl += 1
            if count_good_ppl == 32450:
                break
            else:
                train_data = train_data['toxic'][d]

## Caricamento dati già puliti

In [3]:
train_data = pd.read_csv('data/train_clean.csv')
test_data = pd.read_csv('data/test_clean.csv')

In [14]:
len(train_data)

32450

## Data balance (subsampling)

In [16]:
#32450
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(train_data, test_size=0.05)
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
train_toxic = train_df[train_df[labels].sum(axis=1) > 0]
train_clean = train_df[train_df[labels].sum(axis=1) == 0]

train_df = pd.concat([
  train_toxic,
  train_clean.sample(20000)
])


In [15]:
# delete from dataframe the rows with nan values
test_data = test_data[~test_data['text_lem'].isna()]
test_data['text_lem'].isna().sum()
train_data = train_data[~train_data['text_lem'].isna()]
train_data['text_lem'].isna().sum()

0

## Setup torch e bert

In [17]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))
    

else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
Device name: NVIDIA GeForce RTX 2080 Ti


In [9]:
type(train_df['text_lem'].tolist())

list

In [18]:
# PreProcessing for bert
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from transformers import AdamW, get_linear_schedule_with_warmup
import random
import time
from tqdm import tqdm



tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

def preprocessing_for_bert(data):
    
    '''
    Add special tokens to the start and end of each sentence.
    Pad & truncate all sentences to a single constant length.
    Explicitly differentiate real tokens from padding tokens with the “attention mask”.

    '''

    input_ids = []
    attention_masks = []

    for sent in data:
        encoded_sent = tokenizer.encode_plus(
            text = sent,
            add_special_tokens = True,
            max_length = 300,
            pad_to_max_length = True,
            truncation = True,
            return_attention_mask = True
        )
        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))
    
    # Convert the lists into tensors.
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks

class BertClassifier(nn.Module):
    def __init__(self, freeze_bert=False):
        super(BertClassifier, self).__init__()
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 50, 6

        # Instantiate BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # self.LSTM = nn.LSTM(D_in,D_in,bidirectional=True)
        # self.clf = nn.Linear(D_in*2,2)

        # Instantiate an one-layer feed-forward classifier
        self.classifier = nn.Sequential(nn.Linear(D_in, H), nn.Dropout(0.3), nn.Linear(H, D_out))
        

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        
    def forward(self, input_ids, attention_mask):
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # Extract the last hidden state of the token `[CLS]` for classification task
        
        last_hidden_state_cls = outputs[0][:, 0, :]        
        logits = self.classifier(last_hidden_state_cls)

        return logits

def initialize_model(epochs=4):
    """Initialize the Bert Classifier, the optimizer and the learning rate scheduler.
    """
    # Instantiate Bert Classifier
    bert_classifier = BertClassifier(freeze_bert=False)

    # Tell PyTorch to run the model on GPU
    bert_classifier.to(device)

    # Create the optimizer
    optimizer = AdamW(bert_classifier.parameters(),
                      lr=1e-4,    # Default learning rate
                      eps=1e-8    # Default epsilon value
                      )

    # Total number of training steps
    total_steps = len(train_dataloader) * epochs

    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0, # Default value
                                                num_training_steps=total_steps)
    return bert_classifier, optimizer, scheduler

def set_seed(seed_value=42):
    """Set seed for reproducibility.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):
    """Train the BertClassifier model.
    """
    # Start training loop
    print("Start training...\n")
    best_accuracy = 0
    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        # Print the header of the result table
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*70)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # Reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_counts = 0, 0, 0

        # Put the model into the training mode
        model.train()
        step = 0
        # For each batch of training data...
        for batch in tqdm(train_dataloader):
            batch_counts +=1
            # Load batch to GPU
            b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()
            b_input_ids = b_input_ids.to(device,dtype=torch.long)
            b_attn_mask = b_attn_mask.to(device,dtype=torch.long)
            b_labels = b_labels.to(device,dtype=torch.float)
            # Perform a forward pass. This will return logits.
            logits = model(b_input_ids, b_attn_mask)

            logits = torch.sigmoid(logits)
            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            batch_loss += loss.item()
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update parameters and the learning rate
            optimizer.step()
            scheduler.step()

            # Print the loss values and time elapsed for every 20 batches
            if (step % 20 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                # Calculate time elapsed for 20 batches
                time_elapsed = time.time() - t0_batch

                # Print training results
                #print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")
                
                # Reset batch tracking variables
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()
           
        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)

        print("-"*70)
        # =======================================
        #               Evaluation
        # =======================================
        if evaluation == True:
            # After the completion of each training epoch, measure the model's performance
            # on our validation set.
            val_loss, val_accuracy = evaluate(model, val_dataloader)

            # Print performance over the entire training data
            time_elapsed = time.time() - t0_epoch
            
            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-"*70)
        print("\n")
        step += 1
        output_folder = "bert_classifier_multilabel"
        filename="./trained_models/"+str(output_folder)

        # create folder if it does not exist

        if not os.path.exists(filename):
            os.makedirs(filename)

        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            filename="./trained_models/"+str(output_folder)+"/best.pt"
            torch.save(model.state_dict(), filename)

        filename="./trained_models/"+str(output_folder)+"/last.pt"
        torch.save(model.state_dict(), filename)
    
    print("Training complete!")

def evaluate(model, val_dataloader):
    """After the completion of each training epoch, measure the model's performance
    on our validation set.
    """
    # Put the model into the evaluation mode. The dropout layers are disabled during
    # the test time.
    model.eval()

    # Tracking variables
    val_accuracy = []
    val_loss = []

    # For each batch in our validation set...
    for batch in val_dataloader:
        # Load batch to GPU
        b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

        # Compute logits
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)

        # Compute loss
        logits = torch.sigmoid(logits)
        loss = loss_fn(logits, b_labels.float())
        val_loss.append(loss.item())

        # Get the predictions
        preds = logits
        # Calculate the accuracy rate
        correct_val = 0
        res = 0
        for i in range(6):
            res = 1 if preds[0,i]>0.5 else 0
            if res == b_labels[0,i]:
                correct_val += 1
        val_accuracy.append(correct_val/6)

    # Compute the average accuracy and loss over the validation set.
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)

    return val_loss, val_accuracy

#Splitting

train_inputs, train_masks = preprocessing_for_bert(train_df['text_lem'].tolist())
val_inputs, val_masks = preprocessing_for_bert(val_df['text_lem'].tolist())


# Convert other data types to torch.Tensor
train_labels = torch.tensor(train_df[labels].to_numpy())
val_labels = torch.tensor(val_df[labels].to_numpy())
# For fine-tuning BERT, the authors recommend a batch size of 16 or 32.
batch_size = 16
# Create the DataLoader for our training set
train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

# Create the DataLoader for our validation set
val_dataset = TensorDataset(val_inputs, val_masks, val_labels)
val_sampler = SequentialSampler(val_dataset)
val_dataloader = DataLoader(val_dataset, sampler=val_sampler, batch_size=batch_size)

# Specify loss function
loss_fn = nn.BCEWithLogitsLoss()
bert_classifier, optimizer, scheduler = initialize_model(epochs=2)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
import gc
torch.cuda.empty_cache()
gc.collect()

16

## Train e valutazione

In [21]:
loss_fn = nn.BCELoss()
train(bert_classifier, train_dataloader, val_dataloader, epochs=5, evaluation=True)

Start training...

 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------


100%|██████████| 2215/2215 [14:19<00:00,  2.58it/s]


----------------------------------------------------------------------
   1    |    -    |   0.171176   |  0.056176  |   0.98    |  923.54  
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------


100%|██████████| 2215/2215 [14:18<00:00,  2.58it/s]


----------------------------------------------------------------------
   2    |    -    |   0.123939   |  0.050089  |   0.98    |  924.27  
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------


100%|██████████| 2215/2215 [14:10<00:00,  2.60it/s]


----------------------------------------------------------------------
   3    |    -    |   0.102258   |  0.050089  |   0.98    |  914.36  
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------


100%|██████████| 2215/2215 [14:18<00:00,  2.58it/s]


----------------------------------------------------------------------
   4    |    -    |   0.102119   |  0.050089  |   0.98    |  924.47  
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------


100%|██████████| 2215/2215 [14:10<00:00,  2.61it/s]


----------------------------------------------------------------------
   5    |    -    |   0.102489   |  0.050089  |   0.98    |  914.96  
----------------------------------------------------------------------


Training complete!


In [27]:
import torch.nn.functional as F
def bert_predict(model, test_dataloader):
    """Perform a forward pass on the trained BERT model to predict probabilities
    on the test set.
    """
    # Put the model into the evaluation mode. The dropout layers are disabled during
    # the test time.
    model.eval()

    all_logits = []
    labels = []

    # For each batch in our test set...
    for batch in test_dataloader:
        # Load batch to GPU
        b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

        # Compute logits
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)
        all_logits.append(torch.sigmoid(logits).cpu().detach().numpy().tolist())
        labels.append(b_labels)
    # Concatenate logits from each batch
    

    # Apply softmax to calculate probabilities
    #probs = F.softmax(all_logits, dim=1).cpu().numpy()
    return all_logits, labels

from sklearn.metrics import f1_score, roc_curve, auc

def evaluate_roc(probs, y_true):
    """
    - Print AUC and accuracy on the test set
    - Plot ROC
    @params    probs (np.array): an array of predicted probabilities with shape (len(y_true), 6)
    @params    y_true (np.array): an array of the true values with shape (len(y_true),)
    """
    preds = probs
    fpr, tpr, threshold = roc_curve(y_true, preds)
    roc_auc = auc(fpr, tpr)
    print(f'AUC: {roc_auc:.4f}')
    
    # Get accuracy over the test set
    y_pred = np.where(preds >= 0.5, 1, 0)
    accuracy = f1_score(y_true, y_pred)
    print(f'F1: {accuracy}')
    
    # Plot ROC AUC
    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()


In [27]:
for l in y_val:
        for i in range(6):
                if l[i] == -1:
                        print (" ci sono con -1 nelle labels")

In [24]:


test_inputs, test_masks = preprocessing_for_bert(test_data['text_lem'].tolist())
test_y = torch.tensor(test_data[labels].to_numpy())

# Create the DataLoader for our test set
test_dataset = TensorDataset(test_inputs, test_masks, test_y)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=16)


In [28]:
#best_model = BertClassifier()
#best_model.load_state_dict(torch.load('trained_models/bert_classifier_multilabel/best.pt'))
#best_model.to(device)y

probs, _ = bert_predict(bert_classifier, test_dataloader)

In [134]:
def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    '''
    Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case
    http://stackoverflow.com/q/32239577/395857
    
    Take in np.array for y_true and y_pred. E.g.
    y_true = np.array([[0,1,0],
                       [0,1,1],
                       [1,0,1],
                       [0,0,1]])

    y_pred = np.array([[0,1,1],
                       [0,1,1],
                       [0,1,0],
                       [0,0,0]])
    '''
    acc_list = []
    for i in range(y_true.shape[0]):
        set_true = set( np.where(y_true[i])[0] )
        set_pred = set( np.where(y_pred[i])[0] )
        #print('\nset_true: {0}'.format(set_true))
        #print('set_pred: {0}'.format(set_pred))
        tmp_a = None
        if len(set_true) == 0 and len(set_pred) == 0:
            tmp_a = 1
        else:
            tmp_a = len(set_true.intersection(set_pred))/\
                    float( len(set_true.union(set_pred)) )
        #print('tmp_a: {0}'.format(tmp_a))
        acc_list.append(tmp_a)
    return np.mean(acc_list)

In [175]:
sentence = "STFU stupid nigga, you have to suck my dick, l2p!!!"
sentence = cleaning(sentence)
#sentence = pos_tag_wordnet(sentence)
#sentence = lemmatize_word(sentence)

print(sentence)

stfu stupid nigga suck dick


In [None]:
test_sentence, test_sentence_masks = preprocessing_for_bert(sentence)
test_sentence_y = torch.tensor([])

sentence_dataset = TensorDataset(test_sentence, test_sentence_masks, test_y)
sentence_sampler = SequentialSampler(sentence_dataset)
sentence_dataloader = DataLoader(sentence_dataset, sampler=sentence_sampler, batch_size=16)

In [118]:
a = np.array([])
a = np.append(a, 1)
print(a)


TypeError: append() missing 1 required positional argument: 'values'

In [170]:
outputs = []
for p in range(len(probs)):
    for b in range(len(probs[p])):
        aux = []
        for l in range(6):
            aux.append(int(probs[p][b][l] >= 0.5))
        outputs.append(aux)
        



In [171]:

f1 = f1_score(test_label, outputs, average = 'weighted')
print("F1 score: ", f1)
hammingscore = hamming_score(test_label, outputs)

print("Hamming score: ", hammingscore)

F1 score:  0.6048898420042245
Hamming score:  0.8307334246963753
