In [None]:
from torchcrf import CRF
import torch
import torch.nn as nn
from torch_optimizer import Ranger
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from torch.utils.data import DataLoader, TensorDataset, random_split
from utils import *
from metrics import f1score
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from multiprocessing import cpu_count
from platform import system
from os import environ

# Disable tokenizer-level parallelism via the environment (presumably to avoid
# the HF tokenizers fork warning once DataLoader workers are spawned).
environ.update(TOKENIZERS_PARALLELISM="false")
# Seed Python / NumPy / torch RNGs for a reproducible run.
pl.seed_everything(42)

In [None]:
# ---- Optimisation / training settings ----
# NOTE(review): 0.25 is an unusually large learning rate for Ranger; confirm it is intended.
LEARNING_RATE = 0.25
BATCH_SIZE = 128
WEIGHT_DECAY = 0.01
EPOCHS = 50
MAX_LEN = None      # passed as `maxlen` when encoding the corpora
CELL_TYPE = "lstm"  # recurrent cell used by both encoders
# DataLoader worker processes: one per CPU, except on Windows (load in the main process).
N_JOBS = 0 if system() == "Windows" else cpu_count()

# Sequence-label tag -> index, in the order B, I, O, E, S, <, >, $ (0..7).
TAG2IDX = {tag: idx for idx, tag in enumerate("BIOES<>$")}

# POS tag -> index; 'PAD_AUX' (0) is the padding id of the POS-tag embedding.
POS_TAGS2IDX = {tag: idx for idx, tag in enumerate([
    'PAD_AUX',
    'ADJ', 'ADP', 'ADV', 'AUX', 'CONJ', 'CCONJ', 'DET', 'INTJ', 'NOUN',
    'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X', 'SPACE',
])}

In [None]:
class SEQ2SEQ_POS_TAGS_ENCODER(pl.LightningModule):
    """RNN + CRF sequence tagger that fuses a token encoder and a POS-tag encoder.

    Token ids and POS-tag ids are embedded and encoded by two independent
    recurrent stacks (LSTM or GRU). Their per-position outputs are summed,
    passed through ReLU, dropout and a linear layer to produce ``num_tags``
    emission scores, which a CRF layer turns into a sequence loss (training /
    evaluation) or decoded tag paths (prediction).

    Batches are expected as ``(token_ids, pos_ids, masks, labels)`` tuples.

    Args:
        input_dim1: number of rows of the token embedding (vocabulary size).
        input_dim2: number of rows of the POS-tag embedding.
        cell_type: ``"lstm"`` or ``"gru"`` (case-insensitive).
        embed_dim1, embed_dim2: embedding sizes of the token / POS encoders.
        dropout: dropout probability, applied between stacked RNN layers and
            before the output projection.
        cell_dim1, cell_dim2: hidden sizes of the two encoders. They must be
            equal because the two encoder outputs are added element-wise.
        bidirectional: whether both encoders are bidirectional.
        num_layers1, num_layers2: number of stacked RNN layers per encoder.
        num_tags: number of output tags / CRF states.
        use_scheduler: attach a cosine-with-warmup LR schedule stepped once
            per optimiser step; requires ``train_dataset`` to size it.
        train_dataset, val_dataset, test_dataset: datasets served by the
            module's own dataloaders (``test_dataset`` is also used for
            ``predict``).
        learning_rate, weight_decay, batch_size, max_epochs: optimisation
            settings; ``None`` falls back to the module-level constants
            LEARNING_RATE, WEIGHT_DECAY, BATCH_SIZE and EPOCHS, read at
            construction time.
        padding_idx1: padding index of the token embedding; ``None`` keeps the
            historical value ``TAG2IDX['$']``.

    Raises:
        ValueError: on an unknown ``cell_type``, mismatched ``cell_dim1`` /
            ``cell_dim2``, or ``use_scheduler=True`` without ``train_dataset``.
    """

    def __init__(self,
                 input_dim1,
                 input_dim2,
                 cell_type="lstm",
                 embed_dim1=128,
                 embed_dim2=32,
                 dropout=0.5,
                 cell_dim1=128,
                 cell_dim2=128,
                 bidirectional=True,
                 num_layers1=3,
                 num_layers2=1,
                 num_tags=len(TAG2IDX),
                 use_scheduler=True,
                 train_dataset=None,
                 val_dataset=None,
                 test_dataset=None,
                 learning_rate=None,
                 weight_decay=None,
                 batch_size=None,
                 max_epochs=None,
                 padding_idx1=None):

        super().__init__()

        cell_type = cell_type.lower()
        if cell_type not in ("lstm", "gru"):
            raise ValueError(f"cell_type must be 'lstm' or 'gru', got {cell_type!r}")
        if cell_dim1 != cell_dim2:
            # forward() adds the two encoder outputs element-wise.
            raise ValueError(f"cell_dim1 ({cell_dim1}) and cell_dim2 ({cell_dim2}) must be equal")

        # NOTE(review): the token embedding's padding index comes from the *tag*
        # map (TAG2IDX['$']), not from the vocabulary; confirm it equals the pad
        # token id emitted by get_encoded_input.
        if padding_idx1 is None:
            padding_idx1 = TAG2IDX['$']

        self.embedding1 = nn.Embedding(num_embeddings=input_dim1,
                                       embedding_dim=embed_dim1,
                                       padding_idx=padding_idx1)

        self.embedding2 = nn.Embedding(num_embeddings=input_dim2,
                                       embedding_dim=embed_dim2,
                                       padding_idx=POS_TAGS2IDX['PAD_AUX'])

        rnn_cls = nn.LSTM if cell_type == "lstm" else nn.GRU
        self.cell1 = self._make_rnn(rnn_cls, embed_dim1, cell_dim1,
                                    num_layers1, bidirectional, dropout)
        self.cell2 = self._make_rnn(rnn_cls, embed_dim2, cell_dim2,
                                    num_layers2, bidirectional, dropout)

        c = 2 if bidirectional else 1
        self.fc = nn.Linear(c * cell_dim1, num_tags)
        self.crf = CRF(num_tags=num_tags, batch_first=True)
        self.dropout = nn.Dropout(p=dropout)
        ## Hyperparameters ##
        self.use_scheduler = use_scheduler
        self.learning_rate = LEARNING_RATE if learning_rate is None else learning_rate
        self.weight_decay = WEIGHT_DECAY if weight_decay is None else weight_decay
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.max_epochs = EPOCHS if max_epochs is None else max_epochs
        ## Datasets ##
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        ## steps ##
        if self.use_scheduler:
            if train_dataset is None:
                raise ValueError("use_scheduler=True requires a train_dataset to size the LR schedule")
            # The train loader uses drop_last=False, so the final partial batch
            # is also an optimiser step: ceil(len / batch_size) steps per epoch.
            steps_per_epoch = (len(train_dataset) + self.batch_size - 1) // self.batch_size
            self.total_steps = steps_per_epoch * self.max_epochs

    @staticmethod
    def _make_rnn(rnn_cls, input_size, hidden_size, num_layers, bidirectional, dropout):
        """Build a batch-first LSTM/GRU encoder."""
        return rnn_cls(input_size=input_size,
                       hidden_size=hidden_size,
                       # RNN dropout only acts between stacked layers; PyTorch
                       # warns when it is non-zero for a single layer.
                       dropout=dropout if num_layers > 1 else 0.0,
                       num_layers=num_layers,
                       bidirectional=bidirectional,
                       # DataLoader batches are (batch, seq) and the CRF is
                       # batch_first; without this the RNN would recur across
                       # the sentences of a batch instead of along each one.
                       batch_first=True)

    def _make_dataloader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader with this module's batch size and N_JOBS workers."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=N_JOBS,
                          drop_last=False)

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)

    def predict_dataloader(self):
        # Prediction runs over the test split.
        return self._make_dataloader(self.test_dataset)

    def forward(self, input_ids1, input_ids2):
        """Compute CRF emission scores.

        Args:
            input_ids1: token ids, assumed (batch, seq) as produced by the
                dataloaders.
            input_ids2: POS-tag ids with the same shape as ``input_ids1``.

        Returns:
            Emission scores of shape (batch, seq, num_tags).
        """
        out1, _ = self.cell1(self.embedding1(input_ids1))
        out2, _ = self.cell2(self.embedding2(input_ids2))
        # Element-wise fusion of the two encoders (requires equal widths).
        out = F.relu(out1 + out2)
        out = self.dropout(out)
        return self.fc(out)

    def _shared_evaluation_step(self, batch, batch_idx):
        """Return ``(loss, recall, precision, f1)`` for one batch.

        The loss is the negated value returned by ``self.crf`` (log-likelihood
        under its default reduction); metrics compare the gold labels with the
        CRF-decoded tag sequences.
        """
        ids1, ids2, masks, lbls = batch
        emissions = self(ids1, ids2)
        loss = -self.crf(emissions, lbls, mask=masks)
        pred = self.crf.decode(emissions, mask=masks)
        r, p, f1 = f1score(lbls, pred)
        return loss, r, p, f1

    def _log_metrics(self, stage, loss, r, p, f1):
        """Log epoch-level loss/recall/precision/F1 under ``{stage}_*`` keys."""
        for name, value in (("loss", loss), ("recall", r),
                            ("precision", p), ("f1score", f1)):
            self.log(f"{stage}_{name}", value, on_step=False, on_epoch=True, prog_bar=True)

    def training_step(self, batch, batch_idx):
        loss, r, p, f1 = self._shared_evaluation_step(batch, batch_idx)
        self._log_metrics("train", loss, r, p, f1)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, r, p, f1 = self._shared_evaluation_step(batch, batch_idx)
        self._log_metrics("val", loss, r, p, f1)

    def test_step(self, batch, batch_idx):
        loss, r, p, f1 = self._shared_evaluation_step(batch, batch_idx)
        self._log_metrics("test", loss, r, p, f1)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the CRF-decoded tag index sequences for a batch (labels ignored)."""
        ids1, ids2, masks, _ = batch
        return self.crf.decode(self(ids1, ids2), mask=masks)

    def configure_optimizers(self):
        """Ranger optimiser, optionally with a per-step cosine-with-warmup schedule."""
        optimizer = Ranger(self.parameters(),
                           lr=self.learning_rate,
                           weight_decay=self.weight_decay)

        if not self.use_scheduler:
            return [optimizer]

        scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                                    num_warmup_steps=1,
                                                    num_training_steps=self.total_steps)
        # total_steps counts optimiser steps over all epochs, so the scheduler
        # must advance per step. Stepping per epoch would cover only part of
        # the cosine curve and keep the warmup multiplier (0 at step 0) for
        # the whole first epoch.
        lr_scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [lr_scheduler]

In [None]:
# Vocabulary file: one token per line; the line number becomes the token id.
with open("../data/full_vocab_290818_tree_bank_tokenier.txt", mode="r", encoding="utf-8") as vocab_file:
    vocab = [line.strip() for line in vocab_file]

# token -> id (if a token repeats, its last line number wins).
VOCAB2IDX = {token: idx for idx, token in enumerate(vocab)}

In [None]:
# Encode the training file with the helper from the star-imported utils module
# (presumably): token ids, POS-tag ids, CRF masks and tag labels.
encoded_input, pos_tags, masks, extended_labels = get_encoded_input("../data/train_290818.txt",
                                                                    tag2idx=TAG2IDX,
                                                                    vocab2idx=VOCAB2IDX,
                                                                    pos_tags2idx=POS_TAGS2IDX,
                                                                    return_pos_tags=True,
                                                                    maxlen=MAX_LEN)

L = len(extended_labels)

dataset = TensorDataset(torch.LongTensor(encoded_input),
                        torch.LongTensor(pos_tags),
                        torch.BoolTensor(masks),
                        torch.LongTensor(extended_labels))

# Hold out 10% of the training sentences for validation.
val_sz = int(0.1 * L)
train_sz = L - val_sz
# Use a dedicated seeded generator so the split is identical every time this
# cell runs, instead of depending on how much of the global RNG was consumed.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, (train_sz, val_sz), generator=split_generator)

In [None]:
# Test split, encoded with the same vocabulary / tag maps as the training data.
encoded_input, pos_tags, masks, extended_labels = get_encoded_input(
    "../data/test_290818.txt",
    tag2idx=TAG2IDX,
    vocab2idx=VOCAB2IDX,
    pos_tags2idx=POS_TAGS2IDX,
    return_pos_tags=True,
    maxlen=MAX_LEN,
)

# Same tensor order as the training dataset: token ids, POS ids, masks, labels.
test_dataset = TensorDataset(
    torch.LongTensor(encoded_input),
    torch.LongTensor(pos_tags),
    torch.BoolTensor(masks),
    torch.LongTensor(extended_labels),
)

In [None]:
# The LightningModule receives the datasets and builds its own DataLoaders.
model = SEQ2SEQ_POS_TAGS_ENCODER(
    input_dim1=len(VOCAB2IDX),
    input_dim2=len(POS_TAGS2IDX),
    cell_type=CELL_TYPE,
    bidirectional=True,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    use_scheduler=True,
)

# NOTE(review): patience equals max_epochs (EPOCHS), so early stopping
# effectively can never trigger within a run; lower it if it is wanted.
earlystopping_callback = EarlyStopping(
    monitor="val_f1score",
    mode="max",
    min_delta=1e-4,
    patience=EPOCHS,
)

# Keep only the single best (highest validation F1) weights-only checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath="./",
    filename=f"seq2seq-{CELL_TYPE}-ner-with-pos-tags-encoder",
    monitor="val_f1score",
    mode="max",
    save_top_k=1,
    save_weights_only=True,
)

trainer = pl.Trainer(
    accelerator="gpu",
    precision=16,
    max_epochs=EPOCHS,
    log_every_n_steps=1,
    callbacks=[earlystopping_callback, checkpoint_callback],
)

In [None]:
# Train; the callbacks above track val_f1score and save the best checkpoint.
trainer.fit(model=model)

In [None]:
# Restore the best weights before testing. ModelCheckpoint records the file it
# actually wrote in `best_model_path`. When a checkpoint with the same name
# already exists, Lightning writes a versioned "-v1" file instead, so the fixed
# name could point at a stale run. The fixed name is used only as a fallback,
# when fit() has not run in this kernel.
best_ckpt_path = (checkpoint_callback.best_model_path
                  or f"./seq2seq-{CELL_TYPE}-ner-with-pos-tags-encoder.ckpt")
# load_state_dict copies into the model's existing parameters (whatever their
# device), so the checkpoint tensors can be materialised on CPU.
best_checkpoint = torch.load(best_ckpt_path, map_location="cpu")
model.load_state_dict(best_checkpoint["state_dict"])
trainer.test(model)