In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# NOTE(review): the cells below use `pytorch_lightning.metrics.functional` and
# "log"/"progress_bar" result dicts, which look like pre-1.0 Lightning APIs —
# pin a compatible version here once the exact one is confirmed.
%pip install pytorch-lightning

In [None]:
# pytorch-nlp provides the `torchnlp` package imported below.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install pytorch-nlp

In [7]:
import pytorch_lightning as pl
import torch
from typing import List
from tqdm import tqdm
import re
import numpy as np
from argparse import Namespace
import logging

from transformers import ElectraTokenizer, ElectraModel

import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, LayerNorm
from torchnlp.metrics import get_accuracy, get_token_accuracy
from pytorch_lightning.metrics.functional import f1_score
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [8]:
class RasaIntentEntityDataset(torch.utils.data.Dataset):
    """Intent / entity dataset built from the lines of a Rasa NLU markdown file.

    Lines are interpreted as follows:

    * a line containing ``"## "`` and ``"intent:"`` opens an intent section whose name
      is the text after the first ``:``; any other ``"## "`` line closes the section;
    * a line with fewer than two non-blank characters closes the current section;
    * every other line inside an intent section is an example: its first two
      characters are dropped, the rest is lower-cased, and ``[value](type)``
      annotations are converted to plain text plus entity records.

    Entity labels use BIO-style tags: ``"O"`` is index 0 and every entity type gets a
    ``<type>_B`` and a ``<type>_I`` index.

    Each item is ``(tokens, intent_idx, entity_idx, text)``: the padded token ids, a
    one-element tensor with the intent index, one entity tag index per token position,
    and the cleaned example text.

    Args:
        markdown_lines: lines of the NLU markdown file.
        tokenizer: tokenizer whose type name contains ``"ElectraTokenizer"``.
        seq_len: fixed length that encoded examples are truncated / padded to.

    Raises:
        ValueError: if ``tokenizer`` is not an Electra tokenizer.
    """

    # "(type)" half of a "[value](type)" annotation.
    # NOTE(review): "1-2" only admits the digits 1 and 2 — "0-9" may have been intended.
    # Left as is because widening it changes which annotations are recognised.
    _ENTITY_TYPE_RE = re.compile(r"\([a-zA-Z_1-2]+\)")
    # "[value]" half of a "[value](type)" annotation.
    _ENTITY_VALUE_RE = re.compile(r"\[(.*?)\]")

    def __init__(self, markdown_lines: List[str], tokenizer, seq_len=128,):
        # Fail fast, before any parsing work is done.
        if "ElectraTokenizer" not in str(type(tokenizer)):
            raise ValueError('not supported tokenizer type')

        self.tokenizer = tokenizer
        # Prefer the ids the tokenizer itself reports; the literals are only fallbacks
        # (they match a vocabulary laid out as [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3).
        self.pad_token_id = self._special_token_id(tokenizer, "pad_token_id", 0)
        self.unk_token_id = self._special_token_id(tokenizer, "unk_token_id", 1)
        self.eos_token_id = self._special_token_id(tokenizer, "sep_token_id", 3)  # [SEP] token
        self.bos_token_id = self._special_token_id(tokenizer, "cls_token_id", 2)  # [CLS] token

        self.intent_dict = {}
        self.entity_dict = {}
        self.entity_dict["O"] = 0  # using BIO tagging

        self.dataset = []
        self.seq_len = seq_len

        # Pass 1: collect every intent name and entity type so that label indices are
        # assigned in sorted order, independent of the order of the file.
        intent_names = set()
        entity_types = set()
        for intent, raw_text in self._iter_examples(
            tqdm(markdown_lines, desc="Organizing Intent & Entity dictionary in NLU markdown file ...",)
        ):
            if raw_text is None:
                intent_names.add(intent)
            else:
                entity_types.update(self._parse_annotations(raw_text)[2])

        for intent_name in sorted(intent_names):
            self.intent_dict.setdefault(intent_name, len(self.intent_dict))

        for entity_type in sorted(entity_types):
            self.entity_dict.setdefault(entity_type + '_B', len(self.entity_dict))
            self.entity_dict.setdefault(entity_type + '_I', len(self.entity_dict))

        # Pass 2: build one record per example line.
        for intent, raw_text in self._iter_examples(
            tqdm(markdown_lines, desc="Extracting Intent & Entity in NLU markdown files...",)
        ):
            if raw_text is None:
                continue

            text, entity_values, entity_types_in_line = self._parse_annotations(raw_text)
            if len(text) == 0:
                continue

            record = {
                "text": text.strip(),
                "intent": intent,
                "intent_idx": self.intent_dict[intent],
                "entities": [],
            }

            # Values and types are paired purely by their order of appearance.
            for value, type_str in zip(entity_values, entity_types_in_line):
                entity_tokens = self.tokenize(value)
                if not entity_tokens:
                    continue

                # re.escape: the annotated value is literal text, not a pattern
                # (values such as "c++" or "1.5" would otherwise break or over-match).
                for occurrence in re.finditer(re.escape(value), text):
                    for i, entity_token in enumerate(entity_tokens):
                        bio_tag = type_str + ('_B' if i == 0 else '_I')
                        # NOTE(review): tokens that do not occur verbatim in the text
                        # (e.g. "##"-prefixed word pieces) yield start == -1; the
                        # offsets are not read anywhere in this class.
                        start = text.find(entity_token, occurrence.start(), occurrence.end())
                        record["entities"].append(
                            {
                                "start": start,
                                "end": start + len(entity_token),
                                "entity": type_str,
                                "value": entity_token,
                                "entity_idx": self.entity_dict[bio_tag],
                            }
                        )

            self.dataset.append(record)

        print(f"Intents: {self.intent_dict}")
        print(f"Entities: {self.entity_dict}")

    @staticmethod
    def _special_token_id(tokenizer, attribute, fallback):
        """Return ``tokenizer.<attribute>`` or ``fallback`` when it is missing / None."""
        value = getattr(tokenizer, attribute, None)
        return fallback if value is None else value

    @staticmethod
    def _iter_examples(lines):
        """Walk the markdown lines and yield ``(intent, raw_text)`` events.

        ``raw_text`` is ``None`` for an intent header line and the lower-cased example
        text (still carrying its annotations) for an example line inside a section.
        """
        current_intent = ""
        for line in lines:
            if len(line.strip()) < 2:
                current_intent = ""
                continue

            if "## " in line:
                if "intent:" in line:
                    current_intent = line.split(":")[1].strip()
                    yield current_intent, None
                else:
                    current_intent = ""
            elif current_intent != "":
                # drop the two-character list marker in front of the example
                yield current_intent, line[2:].strip().lower()

    @classmethod
    def _parse_annotations(cls, raw_text):
        """Split an annotated example into ``(clean_text, entity_values, entity_types)``."""
        entity_values = [
            match.group(1).replace("[", "").replace("]", "")
            for match in cls._ENTITY_VALUE_RE.finditer(raw_text)
        ]
        entity_types = [match.group(0)[1:-1] for match in cls._ENTITY_TYPE_RE.finditer(raw_text)]

        clean_text = cls._ENTITY_TYPE_RE.sub("", raw_text)  # remove (...) str
        clean_text = clean_text.replace("[", "").replace("]", "")  # remove '[',']' special char
        return clean_text, entity_values, entity_types

    def tokenize(self, text: str, skip_special_char=True):
        """Tokenize ``text``; with ``skip_special_char=False`` every '#' is stripped from the tokens.

        Raises:
            ValueError: if the stored tokenizer is not an Electra tokenizer.
        """
        if "ElectraTokenizer" not in str(type(self.tokenizer)):
            raise ValueError('not supported tokenizer type')

        tokens = self.tokenizer.tokenize(text)
        if skip_special_char:
            return tokens
        return [token.replace('#', '') for token in tokens]

    def encode(self, text: str, padding: bool = True, return_tensor: bool = True):
        """Encode ``text`` into token ids.

        Args:
            text: text to encode with ``tokenizer.encode``.
            padding: when True the result has exactly ``seq_len`` ids: longer inputs are
                cut to the first ``seq_len`` ids (presumably dropping a trailing [SEP] —
                verify against the tokenizer), shorter ones are right-padded with
                ``pad_token_id``.
            return_tensor: return a ``torch.LongTensor`` when True, a numpy array otherwise.
        """
        tokens = self.tokenizer.encode(text)
        if not torch.is_tensor(tokens):
            tokens = torch.tensor(tokens)
        tokens = tokens.long()

        if padding:
            if len(tokens) >= self.seq_len:
                tokens = tokens[: self.seq_len]
            else:
                pad_tensor = torch.full(
                    (self.seq_len - len(tokens),), self.pad_token_id, dtype=torch.long
                )
                tokens = torch.cat((tokens, pad_tensor), 0)

        if return_tensor:
            return tokens
        return tokens.numpy()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(tokens, intent_idx, entity_idx, text)`` for example ``idx``.

        ``entity_idx`` holds one tag index per token position (0 is the "O" tag).  A
        position is tagged when its token string equals the token of a recorded entity;
        the first matching entity wins.  Matching is by token text only, so every
        occurrence of that token in the sentence receives the tag.
        """
        record = self.dataset[idx]
        tokens = self.encode(record["text"])
        intent_idx = torch.tensor([record["intent_idx"]])

        entity_idx = np.zeros(self.seq_len, dtype=np.int64)  # O tag indicate 0(zero)

        # token text -> tag index; setdefault keeps the first entity for a given token
        tag_by_token = {}
        for entity_info in record["entities"]:
            tag_by_token.setdefault(entity_info["value"], entity_info["entity_idx"])

        if tag_by_token:
            token_strings = self.tokenizer.convert_ids_to_tokens(tokens.tolist())
            for position, token_string in enumerate(token_strings):
                # Consider [CLS](bos) token
                if position == 0:
                    continue
                # Exact match: a substring test would also tag short tokens such as
                # "a" or "in" that merely occur inside an entity word.
                tag = tag_by_token.get(token_string)
                if tag is not None:
                    entity_idx[position] = tag

        return tokens, intent_idx, torch.from_numpy(entity_idx), record["text"]

    def get_intent_idx(self):
        """Return the ``intent name -> index`` mapping."""
        return self.intent_dict

    def get_entity_idx(self):
        """Return the ``BIO tag -> index`` mapping."""
        return self.entity_dict

    def get_vocab_size(self):
        """Return the tokenizer's vocabulary size."""
        return self.tokenizer.vocab_size

    def get_seq_len(self):
        """Return the fixed sequence length of encoded examples."""
        return self.seq_len

In [9]:
class EmbeddingTransformer(nn.Module):
    """Shared encoder with an intent head and a per-token entity head.

    Args:
        backbone: ``None`` to train a plain ``nn.TransformerEncoder`` on top of learned
            token + position embeddings, or ``"electra"`` to use the pretrained
            ``google/electra-small-discriminator`` encoder.
        vocab_size: size of the token embedding table.
        seq_len: maximum sequence length (size of the position embedding table).
        intent_class_num: number of intent classes.
        entity_class_num: number of entity tag classes.
        d_model: hidden size; overridden by the backbone's hidden size when one is used.
        nhead, num_encoder_layers, dim_feedforward, dropout, activation: settings of
            the plain Transformer encoder (ignored when a backbone is used).
        pad_token_id: id of the padding token, masked out in the plain encoder.

    Raises:
        ValueError: if ``backbone`` is neither ``None`` nor ``"electra"``.
    """

    def __init__(
        self,
        backbone: None,
        vocab_size: int,
        seq_len: int,
        intent_class_num: int,
        entity_class_num: int,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        pad_token_id: int = 0,
    ):
        super(EmbeddingTransformer, self).__init__()
        self.backbone = backbone
        self.seq_len = seq_len
        self.pad_token_id = pad_token_id

        if backbone is None:
            self.encoder = nn.TransformerEncoder(
                TransformerEncoderLayer(
                    d_model, nhead, dim_feedforward, dropout, activation
                ),
                num_encoder_layers,
                LayerNorm(d_model),
            )
        elif backbone == "electra":  # pre-defined model architecture use
            self.encoder = ElectraModel.from_pretrained("google/electra-small-discriminator")
            d_model = self.encoder.config.hidden_size
        else:
            # Previously this fell through and failed later on the missing encoder.
            raise ValueError(f"not supported backbone type: {backbone}")

        # Only the plain encoder path reads these two tables; they are still created
        # for every backbone so that the module's parameter names stay unchanged.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(self.seq_len, d_model)

        self.intent_feature = nn.Linear(d_model, intent_class_num)
        self.entity_feature = nn.Linear(d_model, entity_class_num)

    def forward(self, x):
        """Compute intent and entity logits for a batch of token ids.

        Args:
            x: token ids of shape (N, S) with S <= ``seq_len``.

        Returns:
            ``(intent_logits, entity_logits)`` of shapes (N, intent_class_num) and
            (N, S, entity_class_num).
        """
        if self.backbone is None:
            positions = torch.arange(x.size(1), device=x.device)
            embedded = self.embedding(x) + self.position_embedding(positions)  # (N,S,E)
            padding_mask = x == self.pad_token_id  # (N,S), True where padded

            # nn.TransformerEncoder is sequence-first by default: (S,N,E) in and out.
            feature = self.encoder(
                embedded.transpose(0, 1), src_key_padding_mask=padding_mask
            ).transpose(0, 1)
        else:
            # NOTE(review): no attention mask is passed, so padding positions take part
            # in attention — confirm this is intended before changing it.
            feature = self.encoder(x)

            # The backbone may return a tuple or an output object instead of a tensor;
            # in both cases the first element is the last hidden state.
            if not torch.is_tensor(feature):
                feature = feature[0]  # last_hidden_state (N,S,E)

        # first token in sequence used to intent classification
        intent_feature = self.intent_feature(feature[:, 0, :])  # (N,E) -> (N,i_C)

        # every position is classified for entities, including the first token
        entity_feature = self.entity_feature(feature)  # (N,S,E) -> (N,S,e_C)

        return intent_feature, entity_feature

In [10]:
class DualIntentEntityTransformer(pl.LightningModule):
    """LightningModule that trains intent classification and entity tagging together.

    ``hparams`` (a ``Namespace`` or plain ``dict``) must provide: ``nlu_data``,
    ``tokenizer``, ``backbone``, ``d_model``, ``num_encoder_layers``, ``train_ratio``,
    ``batch_size``, ``optimizer``, ``intent_optimizer_lr`` and ``entity_optimizer_lr``.

    Two optimizers are configured over the same parameters: index 0 minimises the
    intent loss, index 1 the entity loss.
    """

    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        if type(self.hparams) == dict:
            self.hparams = Namespace(**self.hparams)

        self.dataset = RasaIntentEntityDataset(
            markdown_lines=self.hparams.nlu_data,
            tokenizer=self.hparams.tokenizer,
        )

        self.model = EmbeddingTransformer(
            backbone=self.hparams.backbone,
            vocab_size=self.dataset.get_vocab_size(),
            seq_len=self.dataset.get_seq_len(),
            intent_class_num=len(self.dataset.get_intent_idx()),
            entity_class_num=len(self.dataset.get_entity_idx()),
            d_model=self.hparams.d_model,
            num_encoder_layers=self.hparams.num_encoder_layers,
            pad_token_id=self.dataset.pad_token_id
        )

        self.train_ratio = self.hparams.train_ratio
        self.batch_size = self.hparams.batch_size
        self.optimizer = self.hparams.optimizer
        self.intent_optimizer_lr = self.hparams.intent_optimizer_lr
        self.entity_optimizer_lr = self.hparams.entity_optimizer_lr

        self.intent_loss_fn = nn.CrossEntropyLoss()
        # Down-weight the "O" tag (index 0) to counter the imbalance between untagged
        # and tagged tokens.
        entity_class_weights = torch.Tensor(
            [0.1] + [1.0] * (len(self.dataset.get_entity_idx()) - 1)
        )
        self.entity_loss_fn = nn.CrossEntropyLoss(weight=entity_class_weights)

    def forward(self, x):
        """Return ``(intent_logits, entity_logits)`` for a batch of token ids."""
        return self.model(x)

    def prepare_data(self):
        """Randomly split the dataset into train / validation parts and store label maps.

        NOTE(review): ``random_split`` is unseeded here, so the split differs between
        runs unless a global seed is set elsewhere.
        """
        train_length = int(len(self.dataset) * self.train_ratio)

        self.train_dataset, self.val_dataset = random_split(
            self.dataset, [train_length, len(self.dataset) - train_length],
        )

        self.hparams.intent_label = self.get_intent_label()
        self.hparams.entity_label = self.get_entity_label()

    def get_intent_label(self):
        """Build, store and return the ``str(index) -> intent name`` mapping."""
        self.intent_dict = {
            str(index): name for name, index in self.dataset.intent_dict.items()
        }
        return self.intent_dict

    def get_entity_label(self):
        """Build, store and return the ``str(index) -> entity tag`` mapping."""
        self.entity_dict = {
            str(index): tag for tag, index in self.dataset.entity_dict.items()
        }
        return self.entity_dict

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
        )

    def configure_optimizers(self):
        """Create one optimizer per task, each with a ``ReduceLROnPlateau`` scheduler.

        Both optimizers cover all parameters of the module; they differ only in their
        learning rate.
        """
        # SECURITY NOTE(review): `self.optimizer` comes from hparams and is evaluated as
        # Python source. Only use trusted hyper-parameters / checkpoints; consider
        # resolving the optimizer class by name instead of `eval`.
        intent_optimizer = eval(
            f"{self.optimizer}(self.parameters(), lr={self.intent_optimizer_lr})"
        )
        entity_optimizer = eval(
            f"{self.optimizer}(self.parameters(), lr={self.entity_optimizer_lr})"
        )

        return (
            [intent_optimizer, entity_optimizer],
            [
                ReduceLROnPlateau(intent_optimizer, patience=1),
                ReduceLROnPlateau(entity_optimizer, patience=1),
            ],
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the loss for the task selected by ``optimizer_idx`` (0 intent, 1 entity)."""
        self.model.train()

        tokens, intent_idx, entity_idx, text = batch
        intent_pred, entity_pred = self.forward(tokens)

        intent_acc = get_accuracy(intent_pred.argmax(1).cpu(), intent_idx.cpu())[0]

        # Entity accuracy is currently not computed or logged.
        tensorboard_logs = {
            "train/intent/acc": intent_acc,
        }

        if optimizer_idx == 0:
            # intent_idx arrives as (N, 1); the loss expects class indices of shape (N,)
            intent_loss = self.intent_loss_fn(intent_pred, intent_idx.squeeze(1))
            tensorboard_logs["train/intent/loss"] = intent_loss

            return {
                "loss": intent_loss,
                "log": tensorboard_logs,
            }

        if optimizer_idx == 1:
            # CrossEntropyLoss wants the class dimension second: (N, S, C) -> (N, C, S)
            entity_loss = self.entity_loss_fn(entity_pred.transpose(1, 2), entity_idx.long(),)
            tensorboard_logs["train/entity/loss"] = entity_loss

            return {
                "loss": entity_loss,
                "log": tensorboard_logs,
            }

    def validation_step(self, batch, batch_idx):
        """Compute intent accuracy / F1 and both losses for one validation batch."""
        self.model.eval()

        tokens, intent_idx, entity_idx, text = batch
        intent_pred, entity_pred = self.forward(tokens)

        # (N, 1) -> (N,), the same shape the intent loss uses
        intent_target = intent_idx.squeeze(1)

        intent_acc = get_accuracy(intent_pred.argmax(1).cpu(), intent_idx.cpu())[0]
        # Predictions are (N,); passing the (N, 1) target would make the elementwise
        # comparisons inside the metric broadcast to (N, N).
        intent_f1 = f1_score(intent_pred.argmax(1), intent_target)

        intent_loss = self.intent_loss_fn(intent_pred, intent_target)
        entity_loss = self.entity_loss_fn(entity_pred.transpose(1, 2), entity_idx.long(),)

        return {
            "val_intent_acc": torch.Tensor([intent_acc]),
            "val_intent_f1": intent_f1,
            "val_intent_loss": intent_loss,
            "val_entity_loss": entity_loss,
            "val_loss": intent_loss + entity_loss,
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation results and expose them for logging."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_intent_acc = torch.stack([x["val_intent_acc"] for x in outputs]).mean()
        avg_intent_f1 = torch.stack([x["val_intent_f1"] for x in outputs]).mean()

        tensorboard_logs = {
            "val/loss": avg_loss,
            "val/intent_acc": avg_intent_acc,
            "val/intent_f1": avg_intent_f1,
        }

        return {
            "val_loss": avg_loss,
            "val_intent_f1": avg_intent_f1,
            "log": tensorboard_logs,
            "progress_bar": tensorboard_logs,
        }

In [11]:
# Module-level placeholders. They appear unused in the cells shown here (Inferencer
# keeps its own model and label maps) — presumably leftovers kept for compatibility.
model = None
intent_dict, entity_dict = {}, {}

class Inferencer:
    """Loads a trained ``DualIntentEntityTransformer`` checkpoint and runs it on raw text.

    Args:
        checkpoint_path: path of the Lightning checkpoint to load.
    """

    def __init__(self, checkpoint_path: str):
        self.model = DualIntentEntityTransformer.load_from_checkpoint(checkpoint_path)
        self.model.model.eval()

        # Invert the dataset's "label -> index" maps into "index -> label" maps.
        self.intent_dict = {
            index: name for name, index in self.model.dataset.intent_dict.items()
        }
        self.entity_dict = {
            index: tag for tag, index in self.model.dataset.entity_dict.items()
        }

        logging.info("intent dictionary")
        logging.info(self.intent_dict)

        logging.info("entity dictionary")
        logging.info(self.entity_dict)

    def _entity_type(self, entity_idx):
        """Return the tag of ``entity_idx`` without its ``_B`` / ``_I`` suffix.

        The "O" tag has no underscore, so it maps to the empty string.
        """
        tag = self.entity_dict[entity_idx]
        return tag[:tag.rfind('_')]

    def is_same_entity(self, i, j):
        """Tell whether tag indices ``i`` and ``j`` belong to the same entity type."""
        # check whether XXX_B, XXX_I tag are same
        return self._entity_type(i).strip() == self._entity_type(j).strip()

    def _token_text(self, token_id):
        """Return the text of ``token_id`` with every '#' and surrounding blanks removed."""
        tokenizer = self.model.dataset.tokenizer
        return tokenizer.convert_ids_to_tokens([token_id])[0].replace("#", "").strip()

    def _span_to_entity(self, text, tokens, first, last, entity_idx_value):
        """Convert the predicted span ``first..last`` (inclusive) into an entity dict.

        ``first`` and ``last`` index the prediction list, which starts after the
        leading [CLS] token — hence the ``+ 1`` when looking up token ids.  Returns
        ``None`` when the span carries the "O" tag or its first token has no text.

        NOTE(review): positions are found with ``str.find`` from the start of the
        text, so a token that also occurs earlier in the sentence resolves to that
        earlier occurrence.
        """
        if self.entity_dict[entity_idx_value] == "O":  # ignore 'O' tag
            return None

        first_text = self._token_text(tokens[first + 1])
        if len(first_text) == 0:
            return None
        start_position = text.find(first_text)

        last_text = self._token_text(tokens[last + 1])
        end_position = text.find(last_text, start_position) + len(last_text)

        return {
            "start": start_position,
            "end": end_position,
            "value": text[start_position:end_position],
            "entity": self._entity_type(entity_idx_value),
        }

    def inference(self, text: str, intent_topk=5):
        """Predict the intent ranking and the entities of ``text``.

        Args:
            text: raw user utterance; it is stripped and lower-cased first.
            intent_topk: number of intents to rank; capped at the number of known intents.

        Returns:
            dict with keys ``"text"`` (normalised input), ``"intent"`` (best intent
            with ``"name"`` / ``"confidence"``), ``"intent_ranking"`` (list of
            ``{"confidence", "name"}``) and ``"entities"`` (list of
            ``{"start", "end", "value", "entity"}``).

        Raises:
            ValueError: if no model is loaded.
        """
        if self.model is None:
            raise ValueError(
                "model is not loaded, first call load_model(checkpoint_path)"
            )

        text = text.strip().lower()

        # encode text to token_indices
        tokens = self.model.dataset.encode(text)
        with torch.no_grad():  # inference only: no autograd graph needed
            intent_result, entity_result = self.model.forward(tokens.unsqueeze(0).cpu())

        # mapping intent result
        intent_probabilities = nn.Softmax(dim=1)(intent_result)[0]
        # topk raises when k exceeds the number of classes, so cap it
        topk = min(intent_topk, intent_probabilities.size(0))
        rank_values, rank_indices = torch.topk(intent_probabilities, k=topk)

        intent_ranking = [
            {"confidence": value, "name": self.intent_dict[index]}
            for value, index in zip(rank_values.tolist(), rank_indices.tolist())
        ]
        intent = {}
        if intent_ranking:
            intent["name"] = intent_ranking[0]["name"]
            intent["confidence"] = intent_ranking[0]["confidence"]

        # mapping entity result
        # skip the first & last sequence positions ([CLS] / final position)
        _, entity_indices = torch.max((entity_result)[0][1:-1, :], dim=1)
        entity_indices = entity_indices.tolist()[:len(text)]

        if type(tokens) == torch.Tensor:
            tokens = tokens.long().tolist()

        # A trailing "O" sentinel closes a span that reaches the last position.
        tags = entity_indices + [0]

        entities = []
        start_token_position = -1
        for i, entity_idx_value in enumerate(tags):
            if entity_idx_value != 0 and start_token_position == -1:
                start_token_position = i
            elif start_token_position >= 0 and not self.is_same_entity(tags[i - 1], entity_idx_value):
                entity = self._span_to_entity(
                    text, tokens, start_token_position, i - 1, tags[i - 1]
                )
                if entity is not None:
                    entities.append(entity)

                # When a different entity follows immediately, the current position
                # already opens the next span; otherwise no span is open.
                start_token_position = i if entity_idx_value != 0 else -1

        result = {
            "text": text,
            "intent": intent,
            "intent_ranking": intent_ranking,
            "entities": entities,
        }

        return result

In [26]:
# Load the trained checkpoint and run one sample query through the model.
checkpoint_path = "lightning_logs/version_35/checkpoints/epoch=19.ckpt"
sample_query = "list of tasks finished with priority 0"

inferencer = Inferencer(checkpoint_path)
intent_entity = inferencer.inference(sample_query, intent_topk=5)

Organizing Intent & Entity dictionary in NLU markdown file ...: 100%|██████████| 991/991 [00:00<00:00, 195170.93it/s]
Extracting Intent & Entity in NLU markdown files...: 100%|██████████| 991/991 [00:00<00:00, 18692.91it/s]


Intents: {'affirmative': 0, 'anythingElse': 1, 'bye': 2, 'greet': 3, 'listProjects': 4, 'listTasks': 5, 'negative': 6, 'nothingElse': 7}
Entities: {'O': 0, 'priority_B': 1, 'priority_I': 2, 'state_B': 3, 'state_I': 4}



In [27]:
# Display the full inference result (text, intent, intent ranking, entities) as the cell output.
intent_entity

{'text': 'list of tasks finished with priority 0',
 'intent': {'name': 'listTasks', 'confidence': 0.8964384198188782},
 'intent_ranking': [{'confidence': 0.8964384198188782, 'name': 'listTasks'},
  {'confidence': 0.05061880871653557, 'name': 'listProjects'},
  {'confidence': 0.01707865111529827, 'name': 'negative'},
  {'confidence': 0.008030795492231846, 'name': 'anythingElse'},
  {'confidence': 0.0074590472504496574, 'name': 'greet'}],
 'entities': [{'start': 14, 'end': 22, 'value': 'finished', 'entity': 'state'},
  {'start': 37, 'end': 38, 'value': '0', 'entity': 'priority'}]}

In [29]:
# Summarise the prediction: the best intent, then one line per extracted entity.
intent = intent_entity['intent']['name']
entities = intent_entity['entities']

print('intent: ', intent)
for found_entity in entities:
    print('entity: ', found_entity['entity'], '    value: ', found_entity['value'])

intent:  listTasks
entity:  state     value:  finished
entity:  priority     value:  0
