In [None]:
# !pip install --upgrade tensorflow-gpu==2.2 -q # must be at least 2.2 to use transformers
# !pip install pytorch_lightning -q
# !pip install transformers -q

In [None]:
from argparse import ArgumentParser
from ast import literal_eval
import html
import string

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from sklearn.utils import class_weight

from ipywidgets import *

from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoConfig

In [None]:
# Uncased Finnish BERT tokenizer; the model cells below use the same TurkuNLP checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(
    'TurkuNLP/bert-base-finnish-uncased-v1'
)

In [None]:
# Load the export: rows whose `label` cell is filled in are the annotated training data,
# all remaining rows are treated as the (unlabelled) test set.
df = pd.read_csv('row_df_fi_labeled.csv', index_col=0)

train_df = df.loc[df['label'].notna()].copy()
test_df = df.drop(index=train_df.index)

# `label` holds serialized Python lists; literal_eval parses them back into list objects
# (it only evaluates literals, so it is safe on file contents).
train_df['label'] = train_df['label'].map(literal_eval)

train_df.shape[0]

In [None]:
# Preview the labelled rows instead of dumping the whole frame into the saved notebook output.
train_df.head()

In [None]:
# BIO tag vocabulary: integer class id <-> tag string. The ids are the model's output classes.
int2label = {idx: tag for idx, tag in enumerate('OBI')}
label2int = {tag: idx for idx, tag in int2label.items()}

label2int

In [None]:
# Flatten every per-word tag of the training set into one array.
all_labels = np.array([tag for seq in train_df['label'].tolist() for tag in seq])

# 'balanced' weights are inversely proportional to tag frequency. They are ordered as 'OBI',
# which matches the integer ids in int2label, and are used to weight the training loss.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=list('OBI'),
    y=all_labels,
)

class_weights

In [None]:
# Translation table that surrounds every ASCII punctuation character with spaces, so that
# punctuation becomes a separate token after whitespace splitting.
table = str.maketrans({ch: ' ' + ch + ' ' for ch in string.punctuation})


def tokenize(s):
    """Split `s` into words and standalone punctuation marks.

    This is the word-level tokenization that the per-row `label` lists are zipped against
    (one tag per token) when building the training dataset.
    """
    spaced = s.translate(table)
    # str.split() without arguments already discards leading/trailing whitespace
    return spaced.split()

In [None]:
class Ds(torch.utils.data.Dataset):
    """Map-style dataset over pre-encoded token id sequences.

    Args:
        tokens: list of token id lists, one per example.
        labels: optional list of tag-string lists aligned with `tokens`; when given,
            each item also carries `labels` converted to integer ids via `label2int`.
    """

    def __init__(self, tokens, labels=None):
        self.tokens = tokens
        self.labels = labels

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, i):
        example = dict(input_ids=self.tokens[i])
        if self.labels is None:
            return example
        example['labels'] = list(map(label2int.__getitem__, self.labels[i]))
        return example

In [None]:
def create_dataset(df, train=True):
    """Build a `Ds` from a dataframe with an `answer` text column.

    With `train=True` the dataframe must also have a `label` column holding one tag per
    word of `answer`, where words are the output of `tokenize`. Each word's tag is copied
    onto every BERT sub-word piece of that word, and the [CLS]/[SEP] positions get 'O'.

    Args:
        df: dataframe with an `answer` column (and `label` when `train` is True).
        train: build (token ids, tags) pairs when True; token ids only otherwise.

    Returns:
        Ds holding the token id lists, plus the per-token tag lists when `train` is True.
    """
    data_tokens = []
    data_labels = []

    for row in df.itertuples():
        if train:
            tokens, labels = _encode_with_word_labels(row.answer, row.label)
            data_labels.append(labels)
        else:
            tokens = tokenizer.encode(row.answer)
        data_tokens.append(tokens)

    return Ds(data_tokens, data_labels if train else None)


def _encode_with_word_labels(text, word_labels):
    """Encode `text` with the model tokenizer and align one tag with every resulting token id."""
    labels = ['O']  # tag for [CLS]
    tokens = [tokenizer.cls_token_id]

    for word, word_label in zip(tokenize(text), word_labels):
        subword_ids = tokenizer.encode(word, add_special_tokens=False)
        # List repetition rather than `n * word_label` keeps multi-character tags intact.
        labels.extend([word_label] * len(subword_ids))
        tokens.extend(subword_ids)

    labels.append('O')  # tag for [SEP]
    tokens.append(tokenizer.sep_token_id)

    # Encoding word by word must reproduce encoding the whole text at once; this also catches
    # rows whose label list is shorter than their word list (zip would silently truncate).
    assert tokens == tokenizer.encode(text)

    # token and label sequences should have equal lengths
    assert len(labels) == len(tokens)

    return tokens, labels

In [None]:
# Encode the annotated rows into (token ids, per-token tags) pairs.
train_ds = create_dataset(df=train_df, train=True)

In [None]:
def collate_fn(examples, label_pad_value=-1):
    """Pad a list of `Ds` items into one batch of tensors.

    Args:
        examples: list of dicts with `input_ids` (list of ints) and optionally `labels`.
        label_pad_value: value used to pad `labels`. The default -1 is what
            `Model.training_step` expects: it asserts that the non -1 labels line up with
            the attention mask. Note that torch's CrossEntropyLoss ignores -100 by
            default, not -1.

    Returns:
        The dict-like object returned by `tokenizer.pad`, with `input_ids`,
        `attention_mask` (and `labels`, when present) converted to tensors.
    """
    batch = tokenizer.pad(examples, padding='longest')

    if 'labels' in batch:
        # tokenizer.pad does not pad `labels`, so pad them manually
        batch['labels'] = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(seq) for seq in batch['labels']],
            batch_first=True,
            padding_value=label_pad_value,
        )

    # return_tensors='pt' cannot be used because stacking the ragged `labels` would fail,
    # so convert the padded fields manually
    for key in ('input_ids', 'attention_mask'):
        batch[key] = torch.tensor(batch[key])

    return batch

In [None]:
# Smoke test: build a 3-label model from the config alone (no pretrained weights) and
# collate the first 8 training examples, to check that a batch flows through it.
test_config = AutoConfig.from_pretrained('TurkuNLP/bert-base-finnish-uncased-v1', num_labels=3)
test_model = AutoModelForTokenClassification.from_config(test_config)

test_batch = collate_fn([train_ds[idx] for idx in range(8)])

In [None]:
# The batch includes `labels`, so element 0 of the output is the loss and element 1 the logits;
# the logits' last dimension should equal num_labels (3).
out = test_model(**test_batch)
out[1].size()

In [None]:
class Model(pl.LightningModule):
    """Lightning wrapper around a BERT token classifier with 3 output classes (O/B/I).

    The loss is a class-weighted cross-entropy using the notebook-level `class_weights`.
    Training data comes from the notebook-level `train_ds` and `collate_fn`.

    Args:
        lr: AdamW learning rate.
        batch_size: training batch size.
        pretrained: when True (default), load the pretrained
            'TurkuNLP/bert-base-finnish-uncased-v1' weights; only the 3-class head starts
            fresh. When False, build the network from its config only, i.e. randomly
            initialised (the previous behaviour).
        **kwargs: any further hyperparameters (e.g. Trainer flags from argparse); they
            are only recorded by `save_hyperparameters`.
    """

    def __init__(self, lr, batch_size, pretrained=True, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        checkpoint = 'TurkuNLP/bert-base-finnish-uncased-v1'
        config = AutoConfig.from_pretrained(checkpoint, num_labels=3)
        if pretrained:
            # from_config() only builds the architecture with random weights; fine-tuning
            # needs the pretrained encoder weights loaded.
            self.model = AutoModelForTokenClassification.from_pretrained(checkpoint, config=config)
        else:
            self.model = AutoModelForTokenClassification.from_config(config)

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped transformers model."""
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Return the class-weighted loss for one batch produced by `collate_fn`."""
        labels = batch['labels']
        attention_mask = batch['attention_mask']

        # collate_fn pads labels with -1 exactly where attention_mask is 0
        assert (labels != -1).sum() == attention_mask.sum()

        # Call the model without `labels`: its built-in unweighted loss would be discarded
        # anyway, and element 0 of the output is then the logits.
        logits = self(input_ids=batch['input_ids'], attention_mask=attention_mask)[0]
        loss = self._weighted_loss(logits, labels, attention_mask)

        print(f'loss {self.global_step}:', loss)

        return loss

    def _weighted_loss(self, logits, labels, attention_mask=None):
        """Class-weighted cross-entropy over the non-padded tokens.

        This is transformers' BERT token-classification loss with `class_weights` added.
        """
        num_labels = self.model.num_labels
        weight = torch.as_tensor(class_weights, dtype=torch.float, device=logits.device)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weight)

        flat_logits = logits.view(-1, num_labels)
        flat_labels = labels.view(-1)

        if attention_mask is not None:
            # Only keep active parts of the loss: padded positions get ignore_index.
            active = attention_mask.view(-1) == 1
            flat_labels = torch.where(
                active, flat_labels, torch.full_like(flat_labels, loss_fct.ignore_index)
            )

        return loss_fct(flat_logits, flat_labels)

    def configure_optimizers(self):
        """AdamW over all parameters with the `lr` hyperparameter."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])

    def train_dataloader(self):
        """Shuffled loader over the notebook-level `train_ds`; the last partial batch is dropped."""
        return DataLoader(train_ds,
                          batch_size=self.hparams['batch_size'],
                          shuffle=True,
                          num_workers=4,
                          drop_last=True,
                          collate_fn=collate_fn)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend `parent_parser` with the `--lr` and `--batch_size` hyperparameters."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--batch_size', default=16, type=int)
        return parser


In [None]:
pl.seed_everything(1234)

# ------------
# args: Lightning Trainer flags + model hyperparameters, parsed from an empty argv
# so that every value takes its default inside the notebook
# ------------
parser = Model.add_model_specific_args(
    pl.Trainer.add_argparse_args(ArgumentParser())
)
args = parser.parse_args([])

# ------------
# model
# ------------
model = Model(**vars(args))

# ------------
# training: exactly `epochs` epochs, no logger, gpus=0
# ------------
epochs = 3

trainer = pl.Trainer(
    logger=False,
    min_epochs=epochs,
    max_epochs=epochs,
    gpus=0,
)
trainer.fit(model)

# ------------
# testing
# ------------

In [None]:
# Encode the unlabelled rows (token ids only, no tags).
test_ds = create_dataset(df=test_df, train=False)

In [None]:
# Inference loader over the test rows; shuffle is not enabled, so batches follow test_df order.
dl = DataLoader(test_ds, collate_fn=collate_fn, batch_size=32)

In [None]:
# Evaluation mode (disables dropout) before predicting; the trailing `;` hides the module repr.
model.eval();

In [None]:
# CSS prepended to every rendered prediction. The class names .B/.I/.O are the tag strings
# from int2label, which the preview cell writes into each token's `class` attribute:
# B tokens are orange, I tokens blue, O tokens black.
style = """
<style>

.B {
    color: orange
}

.I {
    color: blue
}

.O {
    color: black
}

</style>
"""

In [None]:
htmls = []

# BERT special tokens are not part of the answer text, so they are not rendered.
skip_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}

# Preview only the first test batch (iterate `model.train_dataloader()` instead of `dl`
# to inspect predictions on the training data).
preview_batch = next(iter(dl), None)

if preview_batch is not None:
    # inference only: no autograd graph is needed
    with torch.no_grad():
        logits = model(input_ids=preview_batch['input_ids'],
                       attention_mask=preview_batch['attention_mask'])[0]
    pred = logits.argmax(-1)

    for input_ids, pred_ids in zip(preview_batch['input_ids'], pred):
        tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())
        tags = [int2label[tag_id] for tag_id in pred_ids.tolist()]

        # Escape the token text: punctuation such as '<', '>' and '&' forms standalone
        # tokens and would otherwise break the generated markup.
        spans = [
            f'<span class="{tag}">{html.escape(token)}</span>'
            for token, tag in zip(tokens, tags)
            if token not in skip_tokens
        ]

        htmls.append(HTML(style + ' '.join(spans)))

In [None]:
# Stack the rendered predictions vertically, one example per row.
VBox(children=htmls)