# This notebook contains a mishmash of offenseval code and Andrew's current BERT model

In [2]:
#pip install fire
#pip install torch
#pip install torchtext
#pip install transformers

In [3]:
#Necessary modules from offenseval are given below. Installs may be needed as above
%load_ext autoreload
%autoreload 2
import os
from datetime import datetime
import fire
import torch
from torchtext import data
import torch.nn as nn
import html
import re
from transformers import (
    AdamW, BertForSequenceClassification, BertTokenizer,
    get_constant_schedule_with_warmup
)

In [4]:
%load_ext autoreload
%autoreload 2
import os
from datetime import datetime
import fire
import torch
from torchtext import data
import torch.nn as nn
import pickle
from transformers import (
    AdamW, BertForSequenceClassification, BertTokenizer,
    get_constant_schedule_with_warmup
)
#The functions below are defined individually in later cells instead of being imported
#from offenseval.nn import (
#    Tokenizer,
#    train, evaluate, train_cycle, save_model, load_model
#)

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#I think the model file below is missing. Hopefully it doesn't cause further issues
#model, TEXT = load_model("../models/bert.uncased.sample.mean06.ft.pt", device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


ModuleNotFoundError: No module named 'offenseval'

In [2]:
class Tokenizer:
    """
    Tweet tokenizer: normalises the raw text, then delegates to a BERT tokenizer.

    Normalisation optionally HTML-unescapes the text and then swaps hours,
    years, numbers and URLs for placeholder strings before the wrapped
    tokenizer sees the text.
    """
    def __init__(self, bert_tokenizer, html_unescape=True, max_len=128):
        """
        Arguments:
        ----------
        bert_tokenizer:
            Object providing `tokenize(text)` and `convert_tokens_to_ids(...)`
        html_unescape: Boolean (default True)
            Whether to run `html.unescape` on the text before tokenizing
        max_len: int (default 128)
            Maximum number of tokens returned by `tokenize`
        """
        self.bert_tokenizer = bert_tokenizer
        self.max_len = max_len
        self._html_unescape = html_unescape

        # (placeholder, regex) pairs. Order matters: substitutions run in this
        # order and each one operates on the output of the previous one.
        raw_patterns = [
            ("<hour>", r"\d{1,2}\:\d{2}"),
            ("<year>", r"(1(7|8|9)|2(0|1))\d\d"),
            ("<num>", r"\d+(\.)?\d*"),
            # @stephenhay's from https://mathiasbynens.be/demo/url-regex
            ("<url>", r"(https?|ftp)://[^\s/$.?#].[^\s]*"),
        ]
        self._patterns = {
            placeholder: re.compile(regex) for placeholder, regex in raw_patterns
        }

    def replace_patterns(self, text):
        """Return `text` with every registered pattern replaced by its placeholder."""
        normalised = text
        for placeholder, regex in self._patterns.items():
            normalised = regex.sub(placeholder, normalised)
        return normalised

    def tokenize(self, text):
        """Normalise `text`, tokenize it and keep at most `max_len` tokens."""
        cleaned = html.unescape(text) if self._html_unescape else text
        pieces = self.bert_tokenizer.tokenize(self.replace_patterns(cleaned))
        return pieces[:self.max_len]

    def convert_tokens_to_ids(self, *args, **kwargs):
        """Forward the call unchanged to the wrapped BERT tokenizer."""
        delegate = self.bert_tokenizer.convert_tokens_to_ids
        return delegate(*args, **kwargs)

In [1]:
def train(model, iterator, optimizer, criterion, get_target,
          scheduler=None, max_grad_norm=None, ncols=500):
    """
    Trains the model for one full epoch.

    NOTE: the progress bar relies on a `tqdm` callable being available in the
    notebook namespace; it is not imported by this function.

    Arguments:
    model: torch.nn.Module
        Model to be trained. It is called as `model(text)`; if the result is
        not a tensor (e.g. a tuple or a transformers output object), its
        first element is used as the logits.
    iterator:
        A sized iterable over the train batches. Each batch must expose
        `batch.text` as a `(text, lengths)` pair
    optimizer: torch.optim.Optimizer
        An optimizer
    criterion:
        Loss function, called as `criterion(logits.view(-1), target)`
    get_target: a function
        Function receiving a batch and returning the targets. The targets must
        line up element-wise with the flattened logits
    scheduler: (optional) A scheduler
        Scheduler that will be called (if given) after each call to `optimizer.step()`
    max_grad_norm: float (optional, default None)
        If not none, applies gradient clipping using the given norm
    ncols: int
        Width forwarded to the progress bar

    Returns:
    (mean loss, mean accuracy) averaged over the batches of the epoch
    """
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    pbar = tqdm(enumerate(iterator), total=len(iterator), ncols=ncols)
    for i, batch in pbar:
        # Zero gradients first
        optimizer.zero_grad()
        # We assume we always get the length (unused here)
        text, _ = batch.text
        target = get_target(batch)

        predictions = model(text)
        if not torch.is_tensor(predictions):
            # Sequence-classification models return a tuple or an indexable
            # output object; in both cases the logits come first.
            predictions = predictions[0]

        # Flatten once so loss and accuracy see exactly the same shape
        logits = predictions.view(-1)

        loss = criterion(logits, target)
        # Calculate gradients
        loss.backward()
        # Gradient clipping
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        if scheduler:
            scheduler.step()

        # Calculate metrics directly with torch: fraction of rounded sigmoid
        # outputs that equal the target. No external metric helper is needed
        # and the flattened logits avoid a (batch, 1) vs (batch,) mismatch.
        with torch.no_grad():
            preds = torch.round(torch.sigmoid(logits.detach()))
            acc = (preds == target).float().mean().item()

        epoch_loss += loss.item()
        epoch_acc += acc

        # Update Pbar
        lr = optimizer.param_groups[0]["lr"]

        desc = f"Loss {epoch_loss / (i+1):.3f} -- Acc {epoch_acc / (i+1):.3f} -- LR {lr*1e5:.3f}e-5"
        pbar.set_description(desc)

    return epoch_loss / len(iterator), epoch_acc / len(iterator)



def train_cycle(model, optimizer, criterion, scheduler,
                train_it, dev_it, epochs, get_target, model_path,
                monitor="f1", early_stopping_tolerance=5, ncols=100,
                max_grad_norm=1.0, dev_get_target=None):
    """
    Runs up to `epochs` train/evaluate rounds with early stopping.

    Each epoch calls `train(...)` on `train_it` and `evaluate(...)` on
    `dev_it`. Whenever the dev report improves on the best one seen so far,
    the model's `state_dict()` is written to `model_path`.

    Arguments:
    model, optimizer, criterion, scheduler:
        Forwarded to `train` (and `model`/`criterion` to `evaluate`)
    train_it, dev_it:
        Iterators over the train and dev batches
    epochs: int
        Maximum number of epochs
    get_target: a function
        Function receiving a train batch and returning the targets
    model_path:
        Destination of the best checkpoint. Only the state dict is saved, so
        it has to be restored with `model.load_state_dict(torch.load(...))`
    monitor: "f1" or "loss"
        What to monitor for Early Stopping (`report.macro_f1`, higher is
        better, or `report.loss`, lower is better)
    early_stopping_tolerance: int
        Number of consecutive epochs without improvement before stopping
    ncols: int
        Progress-bar width forwarded to `train`
    max_grad_norm: float (default 1.0)
        Gradient clipping norm forwarded to `train`
    dev_get_target: a function (optional)
        Function receiving a dev batch and returning the targets. Defaults to
        reading `batch.subtask_a`, which is what this function always did

    Returns:
    The best evaluation report seen (None if no epoch completed)

    Raises:
    ValueError if `monitor` is not "loss" or "f1"
    """

    if monitor not in {"loss", "f1"}:
        raise ValueError("Monitor should be 'loss' or 'f1'")

    if dev_get_target is None:
        # Dev targets may legitimately differ from the train targets, so the
        # historical default is kept instead of reusing `get_target`.
        def dev_get_target(batch):
            return batch.subtask_a

    epochs_without_improvement = 0
    best_report = None

    def improves_performance(best_report, report):
        if best_report is None:
            return True
        if monitor == "loss":
            return report.loss < best_report.loss
        # monitor == "f1" (validated above)
        return report.macro_f1 > best_report.macro_f1

    # No outer progress bar: `train` draws its own bar per epoch and the
    # per-epoch summary is printed below.
    for epoch in range(epochs):
        print(f"\n\nEpoch {epoch}")
        try:
            train_loss, train_acc = train(
                model, train_it, optimizer, criterion, get_target=get_target,
                max_grad_norm=max_grad_norm, scheduler=scheduler, ncols=ncols
            )
            report = evaluate(
                model, dev_it, criterion, get_target=dev_get_target
            )

            desc = f'Train: Loss: {train_loss:.3f} Acc: {train_acc*100:.2f}%'
            desc += '\nVal.' + str(report)

            print(desc)
            if improves_performance(best_report, report):
                best_report = report
                epochs_without_improvement = 0
                torch.save(model.state_dict(), model_path)
                print(f"Best model so far ({report}) saved at {model_path}")
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= early_stopping_tolerance:
                    print("Early stopping")
                    break
        except KeyboardInterrupt:
            # Allow interrupting a long run from the notebook without a traceback
            print("Stopping training!")
            break

    return best_report


def create_criterion(device, weight_with=None):
    """
    Creates a `torch.nn.BCEWithLogitsLoss`.
    If weight_with is given (and non-empty), the positive class ('OFF') is
    weighted with the "balanced" class-weight ratio relative to 'NOT'.

    Balanced weights are n_samples / (n_classes * count(label)); normalised so
    that the 'NOT' weight is 1, the 'OFF' weight is count('NOT') / count('OFF').

    Arguments:
    ----------
    device: "cuda" or "cpu"
    weight_with: data.Dataset
        Iterable of rows exposing `subtask_a` with values 'NOT' or 'OFF'

    Raises:
    -------
    ValueError if a label other than 'NOT'/'OFF' is found, or if one of the
    two classes is absent (the ratio would be undefined or zero)
    """
    if weight_with:
        labels = [row.subtask_a for row in weight_with]

        n_not = labels.count('NOT')
        n_off = labels.count('OFF')

        if n_not + n_off != len(labels):
            raise ValueError("subtask_a labels must be 'NOT' or 'OFF'")
        if n_not == 0 or n_off == 0:
            raise ValueError("Both 'NOT' and 'OFF' must be present to compute class weights")

        # Same value as balanced_weight['OFF'] / balanced_weight['NOT']
        pos_weight = torch.tensor([n_not / n_off], dtype=torch.float)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    else:
        criterion = nn.BCEWithLogitsLoss()

    criterion = criterion.to(device)
    return criterion

In [None]:
def save_model(model, TEXT, output_path):
    """
    Persist `model` and its `TEXT` object side by side.

    The model is written to `output_path` with `torch.save`; `TEXT` is pickled
    next to it, at the same path with the extension replaced by `.vocab.pkl`.
    """
    stem = os.path.splitext(output_path)[0]
    vocab_path = f"{stem}.vocab.pkl"

    torch.save(model, output_path)
    with open(vocab_path, "wb") as vocab_file:
        pickle.dump(TEXT, vocab_file)

    for message in (f"Model saved to {output_path}", f"Vocab saved to {vocab_path}"):
        print(message)

def load_model(model_path, device):
    """
    Load a model saved with `torch.save(model, path)` plus its pickled TEXT.

    The TEXT object is looked up next to the model, at the same path with the
    extension replaced by `.vocab.pkl`. If that file does not exist the error
    is printed and `TEXT` is returned as None.

    SECURITY NOTE: both `pickle.load` and a full-object `torch.load` execute
    arbitrary code from the file. Only load checkpoints you trust.

    Arguments:
    ----------
    model_path: path of the serialized model
    device: passed to `torch.load` as `map_location`

    Returns:
    --------
    (model, TEXT)
    """
    base, _ = os.path.splitext(model_path)
    vocab_path = f"{base}.vocab.pkl"

    try:
        with open(vocab_path, "rb") as f:
            TEXT = pickle.load(f)
    except FileNotFoundError as e:
        # Best effort: a missing vocab is reported but does not abort loading
        print(e)
        print("Returning null TEXT")
        TEXT = None

    try:
        # The checkpoint holds a whole pickled module, not just tensors, so
        # full unpickling must be requested explicitly: recent PyTorch
        # releases default to `weights_only=True` and would refuse the file.
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the `weights_only` keyword
        model = torch.load(model_path, map_location=device)

    return model, TEXT

In [None]:
#Taken from  https://github.com/finiteautomata/offenseval2020/blob/master/offenseval/nn/models/bert_for_sequence.py
from transformers import BertPreTrainedModel

class BertSeqModel(nn.Module):
    """
    BERT encoder followed by dropout and a single linear classification layer.
    """
    def __init__(self, bert, dropout=0.1, num_labels=1):
        """
        Arguments:
        ---------
        bert: BertModel
            Encoder; its `config.hidden_size` sets the classifier input width
        dropout: float
            Dropout probability applied before the classifier
        num_labels: int
            Number of outputs of the linear layer
        """
        super().__init__()
        encoder_width = bert.config.hidden_size

        self.bert = bert
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(encoder_width, num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        adapter=None,
    ):
        """
        Returns the classifier output for the given inputs.

        All arguments except `labels` and `adapter` are forwarded to the
        encoder. `labels` is accepted but not used.

        Adapter: a function
            Function to be applied between BERT and the linear layer
            (after dropout)
        """
        passthrough = dict(
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.bert(input_ids, **passthrough)

        # Element 1 of the encoder output is used as the pooled representation
        features = self.dropout(encoded[1])
        if adapter:
            features = adapter(features)

        return self.classifier(features)

In [None]:
#Taken from https://github.com/finiteautomata/offenseval2020/blob/master/offenseval/nn/models/bert_gru.py

In [None]:
class BERTGRUSequenceClassifier(nn.Module):

    """
    BERT + GRU model: BERT token representations are fed to a GRU whose final
    hidden state(s) go through dropout and a linear output layer.

    Inspired on Ben Trevett's implementation:
    https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/6%20-%20Transformers%20for%20Sentiment%20Analysis.ipynb
    """
    def __init__(self,
                 bert,
                 hidden_dim,
                 output_dim,
                 n_layers=1,
                 bidirectional=False,
                 finetune_bert=False,
                 dropout=0.2):
        """
        Arguments:
        ---------
        bert: encoder whose `config.to_dict()['hidden_size']` sets the GRU input size
        hidden_dim: GRU hidden size
        output_dim: number of outputs of the final linear layer
        n_layers: number of GRU layers
        bidirectional: whether the GRU is bidirectional
        finetune_bert: if False, BERT is run under `torch.no_grad()`
        dropout: dropout probability (inside the GRU only when n_layers >= 2)
        """
        super().__init__()

        self.bert = bert
        self.finetune_bert = finetune_bert

        bert_width = bert.config.to_dict()['hidden_size']
        # Inter-layer GRU dropout is only meaningful with at least two layers
        gru_dropout = dropout if n_layers >= 2 else 0
        # A bidirectional GRU yields one final state per direction
        classifier_in = hidden_dim * 2 if bidirectional else hidden_dim

        self.rnn = nn.GRU(bert_width,
                          hidden_dim,
                          num_layers=n_layers,
                          bidirectional=bidirectional,
                          batch_first=True,
                          dropout=gru_dropout)
        self.out = nn.Linear(classifier_in, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        #text = [batch size, sent len]

        if self.finetune_bert:
            embedded = self.bert(text)[0]
        else:
            # Frozen encoder: do not track gradients through BERT
            with torch.no_grad():
                embedded = self.bert(text)[0]
        #embedded = [batch size, sent len, emb dim]

        _, hidden = self.rnn(embedded)
        #hidden = [n layers * n directions, batch size, hid dim]

        if self.rnn.bidirectional:
            # Concatenate the last two entries of `hidden` (one per direction)
            last_state = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            last_state = hidden[-1]
        #last_state = [batch size, hid dim (* 2 if bidirectional)]

        #output = [batch size, out dim]
        return self.out(self.dropout(last_state))