In [None]:
from ast import literal_eval
from multiprocessing import cpu_count
from os import environ
from platform import system

import numpy as np
import pytorch_lightning as pl
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas import read_csv
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from sklearn.metrics import f1_score as f1score_elementwise
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import train_test_split
from spacy.tokens import Doc
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset
from torchcrf import CRF

from metrics import f1score as f1score_featurewise

# Turn off HuggingFace `tokenizers` internal parallelism (it warns about / can deadlock
# after the process forks, e.g. into DataLoader worker processes).
environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Small English spaCy pipeline, used later to POS-tag the sentences.
nlp = spacy.load("en_core_web_sm")
# Seed the Python, NumPy and torch RNGs so the run is repeatable.
pl.seed_everything(42)

In [None]:
def convert_to_one_hot(values, num_classes=None):
    """One-hot encode a sequence of class indices.

    Args:
        values: sequence (list, tuple or array) of integer class indices.
        num_classes: width of every one-hot row. Defaults to ``len(tag2idx)``,
            looked up at call time: the encoded POS features are fed to the CRF
            as emission scores, and the CRF is built with ``num_tags=len(tag2idx)``,
            so the row width has to match the number of tags.

    Returns:
        A list with one row per index; each row is a list of ``num_classes``
        Python ints that is all zeros except for a 1 at that index.

    Raises:
        IndexError: if an index is out of range for ``num_classes``.
    """
    if num_classes is None:
        num_classes = len(tag2idx)
    # Convert explicitly so tuples are treated as a list of indices, not as a
    # multi-dimensional index into the identity matrix.
    indices = np.asarray(values, dtype=np.intp)
    return np.eye(num_classes, dtype=int)[indices].tolist()

In [None]:
def post_pad_sequences(sequences, max_len=None, pad_value=0):
    """Right-pad (and optionally right-truncate) sequences to a common length.

    Every sequence is first padded with ``pad_value`` up to the length of the
    longest one. If ``max_len`` is given and smaller than that length, every
    padded sequence is then cut to ``max_len``. A ``max_len`` larger than the
    longest sequence is ignored (a notice is printed): sequences are never
    padded beyond the longest one.

    Args:
        sequences: list of lists (each is concatenated with a list of pads).
        max_len: optional upper bound on the output length.
        pad_value: element appended to the shorter sequences.

    Returns:
        A new list of new lists, all of the same length. The inputs are not
        modified. An empty input yields an empty list.
    """
    if not sequences:
        return []
    longest = max(len(seq) for seq in sequences)
    if max_len is not None and max_len > longest:
        print("Unnecessary extra padding detected!\nSequences will only be padded to the longest length")
    padded = [seq + [pad_value] * (longest - len(seq)) for seq in sequences]
    if max_len is None or max_len >= longest:
        return padded
    return [seq[:max_len] for seq in padded]

In [None]:
## Training hyperparameters ##
LEARNING_RATE = 2e-3
BATCH_SIZE = 16
EPOCHS = 25
# DataLoader worker processes: 0 (load in the main process) on Windows, one per CPU elsewhere.
N_JOBS = 0 if system() == "Windows" else cpu_count()

## Vocabularies ##
# Tag -> index, in the order B, I, O, E, S, X ('X' is the label used for padding).
tag2idx = {tag: idx for idx, tag in enumerate(("B", "I", "O", "E", "S", "X"))}
# Coarse POS bucket -> index.
pos2idx = {pos: idx for idx, pos in enumerate(("NOUN", "VERB", "ADJ", "ADV", "PRON", "OTHER"))}

In [None]:
class CRF4NER(pl.LightningModule):
    """Linear-chain CRF tagger trained and evaluated with PyTorch Lightning.

    There is no encoder: the first tensor of every batch is passed straight to
    the CRF as its emission scores, so its last dimension must equal
    ``num_tags``. The CRF is the module's only submodule.

    Batches from the train/val/test datasets are ``(emissions, mask, labels)``;
    prediction batches are ``(emissions, mask)``.

    Args:
        eval_metric: ``"feature-wise"`` uses ``f1score_featurewise``;
            ``"element-wise"`` computes token-level, macro-averaged
            recall/precision/F1 with scikit-learn.
        num_tags: number of CRF output tags (defaults to ``len(tag2idx)``).
        train_dataset, val_dataset, test_dataset: datasets served by the
            matching ``*_dataloader`` hooks.

    Raises:
        ValueError: if ``eval_metric`` is not one of the two supported names.
    """

    def __init__(self,
                 eval_metric="feature-wise",
                 num_tags=len(tag2idx),
                 train_dataset=None,
                 val_dataset=None,
                 test_dataset=None):

        super().__init__()
        if eval_metric not in ("feature-wise", "element-wise"):
            raise ValueError(
                f"eval_metric must be 'feature-wise' or 'element-wise', got {eval_metric!r}")
        self.crf = CRF(num_tags=num_tags, batch_first=True)
        # Both choices are called as metric(gold, pred) and unpacked as (recall, precision, f1).
        self.eval_metric = (f1score_featurewise if eval_metric == "feature-wise"
                            else self._elementwise_scores)
        ## Hyperparameters ##
        self.batch_size = BATCH_SIZE
        self.learning_rate = LEARNING_RATE
        ## Datasets ##
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset


    def _build_dataloader(self, dataset, shuffle=False):
        """DataLoader over ``dataset`` with the shared batch size and worker count."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=N_JOBS,
                          drop_last=False)


    def train_dataloader(self):
        return self._build_dataloader(self.train_dataset, shuffle=True)


    def val_dataloader(self):
        return self._build_dataloader(self.val_dataset)


    def test_dataloader(self):
        return self._build_dataloader(self.test_dataset)


    def forward(self, input_ids, mask=None):
        """Viterbi-decode the best tag sequence for every item in the batch.

        Args:
            input_ids: emission scores; presumably (batch, seq_len, num_tags) since
                the CRF is built with ``batch_first=True`` — confirm against callers.
            mask: optional boolean mask of valid positions (batch, seq_len).

        Returns:
            A list (one entry per batch item) of lists of tag indices, covering only
            the masked-in positions.
        """
        # The emission tensors built in this notebook are integer (LongTensor) one-hot
        # vectors; the CRF scores are floating point.
        return self.crf.decode(input_ids.float(), mask=mask)


    @staticmethod
    def _elementwise_scores(gold, pred):
        """Token-level macro-averaged ``(recall, precision, f1)``.

        ``gold`` rows are padded to the batch length while ``pred`` rows only cover
        the decoded (masked-in) positions. The masks built in this notebook are
        contiguous prefixes, so each gold row is cut to its prediction's length
        before both are flattened for scikit-learn.
        """
        y_true = [tag
                  for gold_row, pred_row in zip(gold, pred)
                  for tag in gold_row[:len(pred_row)]]
        y_pred = [tag for pred_row in pred for tag in pred_row]
        options = dict(average="macro", zero_division=0)
        return (recall_score(y_true, y_pred, **options),
                precision_score(y_true, y_pred, **options),
                f1score_elementwise(y_true, y_pred, **options))


    def _shared_evaluation_step(self, batch, batch_idx):
        """Return ``(loss, recall, precision, f1)`` for one ``(emissions, mask, labels)`` batch."""
        ids, masks, lbls = batch
        emissions = ids.float()
        # torchcrf returns the log-likelihood (summed over the batch with its default
        # reduction); its negation is the loss to minimise.
        loss = -self.crf(emissions, lbls, mask=masks)
        # Decoding only feeds the metrics, so keep it out of the autograd graph.
        with torch.no_grad():
            pred = self.crf.decode(emissions, mask=masks)
        r, p, f1 = self.eval_metric(lbls.tolist(), pred)
        return loss, r, p, f1


    def _log_stage_metrics(self, stage, loss, r, p, f1):
        """Log loss/recall/precision/F1 as ``<stage>_<name>``, epoch-aggregated and on the progress bar."""
        for name, value in (("loss", loss), ("recall", r), ("precision", p), ("f1score", f1)):
            self.log(f"{stage}_{name}", value, on_epoch=True, prog_bar=True)


    def training_step(self, batch, batch_idx):
        loss, r, p, f1 = self._shared_evaluation_step(batch, batch_idx)
        self._log_stage_metrics("train", loss, r, p, f1)
        return loss


    def validation_step(self, batch, batch_idx):
        loss, r, p, f1 = self._shared_evaluation_step(batch, batch_idx)
        self._log_stage_metrics("val", loss, r, p, f1)


    def test_step(self, batch, batch_idx):
        loss, r, p, f1 = self._shared_evaluation_step(batch, batch_idx)
        self._log_stage_metrics("test", loss, r, p, f1)


    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Decode an ``(emissions, mask)`` batch into lists of tag indices."""
        ids, masks = batch
        return self(ids, mask=masks)


    def configure_optimizers(self):
        # SGD with Nesterov momentum over the CRF parameters.
        return SGD(self.parameters(),
                   lr=self.learning_rate,
                   momentum=0.9,
                   nesterov=True)

In [None]:
# Load the training file. Each row holds a word list and a tag list, both stored as
# Python-literal strings, plus a third column that is not used here.
train_rows = read_csv("../../data/train_290818.txt",
                      sep=" ",
                      header=None,
                      encoding="utf-8").values.tolist()

train_words = [literal_eval(words) for (words, _, _) in train_rows]
# Keep only the part of each tag before the first '-' (its B/I/O/E/S prefix).
train_tags = [[tag.split('-')[0] for tag in literal_eval(tags)] for (_, tags, _) in train_rows]

# POS-tag the dataset's own tokens. Tagging ' '.join(words) would let spaCy re-tokenize
# (e.g. split hyphenated words), after which the POS tags no longer line up with the labels.
train_pos_raw = []
for words in train_words:
    doc = Doc(nlp.vocab, words=words)
    for _, component in nlp.pipeline:
        doc = component(doc)
    train_pos_raw.append([token.pos_ for token in doc])

# Collapse spaCy's POS set onto the pos2idx buckets: PROPN counts as NOUN and any tag
# without a bucket of its own becomes OTHER.
train_pos = [["NOUN" if tag == "PROPN" else (tag if tag in pos2idx else "OTHER") for tag in sent]
             for sent in train_pos_raw]
assert all(len(p) == len(t) for p, t in zip(train_pos, train_tags)), "POS tags and labels are misaligned"

train_lens = [len(sent) for sent in train_pos]
train_pos_padded = post_pad_sequences(train_pos, pad_value="OTHER")
# 'X' is the padding label; label rows end up as long as the padded POS rows.
train_tags_padded = post_pad_sequences(train_tags, pad_value='X')
train_padded_len = len(train_pos_padded[0])

train_encoded = [convert_to_one_hot([pos2idx[p] for p in sent]) for sent in train_pos_padded]
train_masks = [[1] * n + [0] * (train_padded_len - n) for n in train_lens]
train_label_ids = [[tag2idx[tag] for tag in tags] for tags in train_tags_padded]

# Hold out 10% for validation. random_state pins the split so re-running this cell
# reproduces it (42 mirrors the seed_everything call in the setup cell).
(input_ids_train, input_ids_val,
 attn_masks_train, attn_masks_val,
 extended_labels_train, extended_labels_val) = train_test_split(train_encoded,
                                                                train_masks,
                                                                train_label_ids,
                                                                test_size=0.1,
                                                                shuffle=True,
                                                                random_state=42)

input_ids_train = torch.tensor(input_ids_train, dtype=torch.long)
attn_masks_train = torch.tensor(attn_masks_train, dtype=torch.bool)
extended_labels_train = torch.tensor(extended_labels_train, dtype=torch.long)

input_ids_val = torch.tensor(input_ids_val, dtype=torch.long)
attn_masks_val = torch.tensor(attn_masks_val, dtype=torch.bool)
extended_labels_val = torch.tensor(extended_labels_val, dtype=torch.long)

train_dataset = TensorDataset(input_ids_train, attn_masks_train, extended_labels_train)
val_dataset = TensorDataset(input_ids_val, attn_masks_val, extended_labels_val)

In [None]:
# Load the test file (same row layout as the training file: word list, tag list, unused column).
test_rows = read_csv("../../data/test_290818.txt",
                     sep=" ",
                     header=None,
                     encoding="utf-8").values.tolist()

test_words = [literal_eval(words) for (words, _, _) in test_rows]
# Keep only the part of each tag before the first '-' (its B/I/O/E/S prefix).
test_tags = [[tag.split('-')[0] for tag in literal_eval(tags)] for (_, tags, _) in test_rows]

# POS-tag the dataset's own tokens so the tags stay aligned with the per-word labels
# (re-tokenizing ' '.join(words) with spaCy can change the token count).
test_pos_raw = []
for words in test_words:
    doc = Doc(nlp.vocab, words=words)
    for _, component in nlp.pipeline:
        doc = component(doc)
    test_pos_raw.append([token.pos_ for token in doc])

# Collapse spaCy's POS set onto the pos2idx buckets: PROPN counts as NOUN and any tag
# without a bucket of its own becomes OTHER.
test_pos = [["NOUN" if tag == "PROPN" else (tag if tag in pos2idx else "OTHER") for tag in sent]
            for sent in test_pos_raw]
assert all(len(p) == len(t) for p, t in zip(test_pos, test_tags)), "POS tags and labels are misaligned"

test_lens = [len(sent) for sent in test_pos]
test_pos_padded = post_pad_sequences(test_pos, pad_value="OTHER")
# 'X' is the padding label; label rows end up as long as the padded POS rows.
test_tags_padded = post_pad_sequences(test_tags, pad_value='X')
test_padded_len = len(test_pos_padded[0])

input_ids_test = torch.tensor([convert_to_one_hot([pos2idx[p] for p in sent]) for sent in test_pos_padded],
                              dtype=torch.long)
attn_masks_test = torch.tensor([[1] * n + [0] * (test_padded_len - n) for n in test_lens],
                               dtype=torch.bool)
extended_labels_test = torch.tensor([[tag2idx[tag] for tag in tags] for tags in test_tags_padded],
                                    dtype=torch.long)

test_dataset = TensorDataset(input_ids_test, attn_masks_test, extended_labels_test)

In [None]:
model = CRF4NER(train_dataset=train_dataset,
                val_dataset=val_dataset,
                test_dataset=test_dataset)

# Stop once val F1 has not improved by at least 1e-4 for 5 consecutive validation checks.
earlystopping_callback = EarlyStopping(monitor="val_f1score",
                                       min_delta=1e-4,
                                       patience=5,
                                       mode="max")

# Keep only the best (highest val F1) weights. If the target file already exists from an
# earlier run, Lightning writes a versioned name (e.g. "...-v1.ckpt") instead, so reload via
# checkpoint_callback.best_model_path rather than the literal filename.
checkpoint_callback = ModelCheckpoint(dirpath="./",
                                      filename="crf-ner-val-f1score",
                                      save_top_k=1,
                                      mode="max",
                                      monitor="val_f1score",
                                      save_weights_only=True)

# 16-bit mixed precision on a CUDA GPU when one is present; otherwise run on the CPU in
# full precision instead of failing on a hard-coded GPU request.
use_cuda = torch.cuda.is_available()
trainer = pl.Trainer(accelerator="gpu" if use_cuda else "cpu",
                     max_epochs=EPOCHS,
                     precision=16 if use_cuda else 32,
                     log_every_n_steps=1,
                     callbacks=[earlystopping_callback, checkpoint_callback])

In [None]:
# Train with the model's own train/val dataloaders; the callbacks handle early stopping and checkpointing.
trainer.fit(model=model)

In [None]:
# Reload the best weights before testing. best_model_path is the file ModelCheckpoint actually
# wrote (it can carry a "-vN" suffix when an older checkpoint already existed); fall back to the
# configured name only when it is empty (e.g. fit has not run in this kernel session).
best_ckpt_path = checkpoint_callback.best_model_path or "./crf-ner-val-f1score.ckpt"
# map_location="cpu" lets a checkpoint saved on a GPU be loaded on any machine; load_state_dict
# then copies the values onto whatever device the model's parameters live on.
best_state = torch.load(best_ckpt_path, map_location="cpu")["state_dict"]
model.load_state_dict(best_state)
trainer.test(model)