In [None]:
! pip install pytorch-lightning

In [None]:
! pip install pytorch-nlp

In [None]:
! pip install transformers

In [13]:
from __future__ import unicode_literals, print_function, division

from collections import OrderedDict
from tqdm import tqdm
from typing import List

from transformers import ElectraTokenizer, ElectraModel

import torch
import numpy as np
import re

import os
import numpy as np
import json
import pandas as pd
from sklearn import metrics
from typing import List, Dict, Sequence
from sklearn import metrics

import glob
from pytorch_lightning.callbacks.base import Callback

from torch.nn import TransformerEncoder, TransformerEncoderLayer, LayerNorm
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

from pytorch_lightning import Trainer
from argparse import Namespace

from torchnlp.metrics import get_accuracy, get_token_accuracy
from pytorch_lightning.metrics.functional import f1_score

import logging

In [14]:
class RasaIntentEntityDataset(torch.utils.data.Dataset):
    """Intent / entity dataset built from Rasa-style NLU markdown lines.

    The parser looks for ``## intent:<name>`` headers and treats every
    following non-blank line as an example of that intent (the first two
    characters of the line, e.g. ``"- "``, are dropped). Examples are
    lower-cased, ``[value](type)`` annotations are removed from the text, and
    each annotated value is tokenized so that its first token is tagged
    ``<type>_B`` and the remaining tokens ``<type>_I``.

    Indexing the dataset returns ``(tokens, intent_idx, entity_idx, text)``:
    a 1-D long tensor of ``seq_len`` token ids, a one-element tensor with the
    intent index, a ``seq_len`` tensor of entity tag indices (0 is ``"O"``),
    and the cleaned example text.

    Args:
        markdown_lines: Lines of the NLU markdown file.
        tokenizer: Tokenizer whose type name contains ``ElectraTokenizer``.
        seq_len: Fixed length every encoded example is padded/truncated to.

    Raises:
        ValueError: If the tokenizer type is not supported.
    """

    # "(type)" half of a "[value](type)" annotation.
    # NOTE(review): the character class only admits the digits 1 and 2 -
    # confirm whether entity types containing other digits should match.
    _ENTITY_TYPE_RE = re.compile(r"\(([a-zA-Z_1-2]+)\)")
    # "[value]" half of a "[value](type)" annotation (non-greedy).
    _ENTITY_VALUE_RE = re.compile(r"\[(.*?)\]")

    def __init__(self, markdown_lines: List[str], tokenizer, seq_len=128,):
        self.intent_dict = {}
        self.entity_dict = {"O": 0}  # BIO tagging: index 0 is the "outside" tag

        self.dataset = []
        self.seq_len = seq_len

        intent_value_list = []
        entity_type_list = []
        current_intent_focus = ""

        # Pass 1: collect every intent name and entity type so that the label
        # dictionaries are complete before any example is converted.
        for line in tqdm(markdown_lines, desc="Organizing Intent & Entity dictionary in NLU markdown file ...",):
            if len(line.strip()) < 2:
                # A (nearly) blank line closes the current intent section.
                current_intent_focus = ""
                continue

            if "## " in line:
                if "intent:" in line:
                    current_intent_focus = line.split(":")[1].strip()
                    intent_value_list.append(current_intent_focus)
                else:
                    current_intent_focus = ""
            elif current_intent_focus != "":
                text = line[2:].strip().lower()
                entity_type_list.extend(self._ENTITY_TYPE_RE.findall(text))

        # dataset tokenizer setting
        if "ElectraTokenizer" in str(type(tokenizer)):
            self.tokenizer = tokenizer
            # NOTE(review): special-token ids are hard-coded - confirm they
            # agree with the vocabulary of the tokenizer actually passed in.
            self.pad_token_id = 0
            self.unk_token_id = 1
            self.eos_token_id = 3  # [SEP] token
            self.bos_token_id = 2  # [CLS] token
        else:
            raise ValueError('not supported tokenizer type')

        # Sorted insertion keeps the label indices independent of file order.
        for intent_value in sorted(intent_value_list):
            if intent_value not in self.intent_dict:
                self.intent_dict[intent_value] = len(self.intent_dict)

        for entity_type in sorted(entity_type_list):
            for suffix in ('_B', '_I'):
                tag = str(entity_type) + suffix
                if tag not in self.entity_dict:
                    self.entity_dict[tag] = len(self.entity_dict)

        current_intent_focus = ""

        # Pass 2: turn every example line into a record with its entities.
        for line in tqdm(markdown_lines, desc="Extracting Intent & Entity in NLU markdown files...",):
            if len(line.strip()) < 2:
                current_intent_focus = ""
                continue

            if "## " in line:
                if "intent:" in line:
                    current_intent_focus = line.split(":")[1].strip()
                else:
                    current_intent_focus = ""
                continue

            if current_intent_focus == "":
                continue

            text = line[2:].strip().lower()

            entity_value_list = [
                found.group(1).replace("[", "").replace("]", "")
                for found in self._ENTITY_VALUE_RE.finditer(text)
            ]
            example_entity_types = self._ENTITY_TYPE_RE.findall(text)

            text = self._ENTITY_TYPE_RE.sub("", text)  # remove (...) str
            text = text.replace("[", "").replace("]", "")  # remove '[',']' special char

            if len(text) == 0:
                continue

            each_data_dict = {
                "text": text.strip(),
                "intent": current_intent_focus,
                "intent_idx": self.intent_dict[current_intent_focus],
                "entities": [],
            }

            for value, type_str in zip(entity_value_list, example_entity_types):
                entity_tokens = self.tokenize(value)
                # The value is literal text, not a pattern: escape it so that
                # characters such as '+', '?' or '(' cannot break the search.
                for entity in re.finditer(re.escape(value), text):
                    for i, entity_token in enumerate(entity_tokens):
                        BIO_type_str = type_str + ('_B' if i == 0 else '_I')
                        # NOTE(review): find() returns -1 when the token is
                        # not a literal substring of the span (e.g. pieces
                        # carrying a '##' prefix), so "start"/"end" may be
                        # unreliable for such tokens.
                        token_start = text.find(entity_token, entity.start(), entity.end())
                        each_data_dict["entities"].append(
                            {
                                "start": token_start,
                                "end": token_start + len(entity_token),
                                "entity": type_str,
                                "value": entity_token,
                                "entity_idx": self.entity_dict[BIO_type_str],
                            }
                        )

            self.dataset.append(each_data_dict)

        print(f"Intents: {self.intent_dict}")
        print(f"Entities: {self.entity_dict}")

    def tokenize(self, text: str, skip_special_char=True):
        """Split ``text`` into tokens with the wrapped tokenizer.

        With ``skip_special_char=True`` the tokens are returned unchanged;
        otherwise every ``'#'`` character is removed from each token.

        Raises:
            ValueError: If the tokenizer type is not supported.
        """
        if "ElectraTokenizer" in str(type(self.tokenizer)):
            if skip_special_char:
                return self.tokenizer.tokenize(text)
            else:
                return [token.replace('#', '') for token in self.tokenizer.tokenize(text)]
        else:
            raise ValueError('not supported tokenizer type')

    def encode(self, text: str, padding: bool = True, return_tensor: bool = True):
        """Encode ``text`` into token ids.

        Args:
            text: Sentence to encode with ``tokenizer.encode``.
            padding: When true, truncate to ``seq_len`` or right-pad with
                ``pad_token_id`` up to ``seq_len``.
            return_tensor: Return a long tensor when true, else a numpy array.
        """
        tokens = self.tokenizer.encode(text)
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens).long()
        else:
            tokens = tokens.long()

        if padding:
            if len(tokens) >= self.seq_len:
                tokens = tokens[: self.seq_len]
            else:
                pad_tensor = torch.full(
                    (self.seq_len - len(tokens),), self.pad_token_id, dtype=torch.long
                )
                tokens = torch.cat((tokens, pad_tensor), 0)

        if return_tensor:
            return tokens
        else:
            return tokens.numpy()

    def __len__(self):
        """Number of parsed examples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(tokens, intent_idx, entity_idx, text)`` for example ``idx``."""
        sample = self.dataset[idx]
        tokens = self.encode(sample["text"])

        intent_idx = torch.tensor([sample["intent_idx"]])

        entity_idx = np.array(self.seq_len * [0])  # O tag indicate 0(zero)

        entities = sample["entities"]
        if entities and "ElectraTokenizer" in str(type(self.tokenizer)):
            # A position gets the tag of the first recorded entity token
            # whose text contains the token found at that position.
            for token_seq, token_value in enumerate(tokens):
                # Consider [CLS](bos) token
                if token_seq == 0:
                    continue

                token_text = self.tokenizer.convert_ids_to_tokens([token_value.item()])[0]
                for entity_info in entities:
                    if token_text in entity_info["value"]:
                        entity_idx[token_seq] = entity_info["entity_idx"]
                        break

        entity_idx = torch.from_numpy(entity_idx)

        return tokens, intent_idx, entity_idx, sample["text"]

    def get_intent_idx(self):
        """Mapping from intent name to intent index."""
        return self.intent_dict

    def get_entity_idx(self):
        """Mapping from BIO entity tag to tag index."""
        return self.entity_dict

    def get_vocab_size(self):
        """Vocabulary size reported by the tokenizer."""
        return self.tokenizer.vocab_size

    def get_seq_len(self):
        """Fixed sequence length of encoded examples."""
        return self.seq_len

In [15]:
class RasaIntentEntityValidDataset(torch.utils.data.Dataset):
    """Validation dataset built from Rasa-style NLU markdown lines.

    The parser looks for ``## intent:<name>`` headers and treats every
    following non-blank line as an example of that intent (the first two
    characters of the line, e.g. ``"- "``, are dropped). Unlike the training
    dataset in this notebook, the example text keeps its original casing.
    ``[value](type)`` annotations are removed from the text, and each
    annotated value is tokenized so that its first token is tagged
    ``<type>_B`` and the remaining tokens ``<type>_I``.

    Indexing the dataset returns ``(tokens, intent_idx, entity_idx, text)``:
    a 1-D long tensor of ``seq_len`` token ids, a one-element tensor with the
    intent index, a ``seq_len`` tensor of entity tag indices (0 is ``"O"``),
    and the cleaned example text.

    Args:
        markdown_lines: Lines of the NLU markdown file.
        tokenizer: Tokenizer whose type name contains ``ElectraTokenizer``.
        seq_len: Fixed length every encoded example is padded/truncated to.

    Raises:
        ValueError: If the tokenizer type is not supported.
    """

    # "(type)" half of a "[value](type)" annotation.
    # NOTE(review): the character class only admits the digits 1 and 2 -
    # confirm whether entity types containing other digits should match.
    _ENTITY_TYPE_RE = re.compile(r"\(([a-zA-Z_1-2]+)\)")
    # "[value]" half of a "[value](type)" annotation (non-greedy).
    _ENTITY_VALUE_RE = re.compile(r"\[(.*?)\]")

    def __init__(self, markdown_lines: List[str], tokenizer, seq_len=128,):
        self.intent_dict = {}
        self.entity_dict = {"O": 0}  # BIO tagging: index 0 is the "outside" tag

        self.dataset = []
        self.seq_len = seq_len
        self.tokenizer = tokenizer

        intent_value_list = []
        entity_type_list = []
        current_intent_focus = ""

        # Pass 1: collect every intent name and entity type so that the label
        # dictionaries are complete before any example is converted.
        for line in tqdm(markdown_lines, desc="Organizing Intent & Entity dictionary in NLU markdown file ...",):
            if len(line.strip()) < 2:
                # A (nearly) blank line closes the current intent section.
                current_intent_focus = ""
                continue

            if "## " in line:
                if "intent:" in line:
                    current_intent_focus = line.split(":")[1].strip()
                    intent_value_list.append(current_intent_focus)
                else:
                    current_intent_focus = ""
            elif current_intent_focus != "":
                text = line[2:].strip()
                entity_type_list.extend(self._ENTITY_TYPE_RE.findall(text))

        # dataset tokenizer setting
        if "ElectraTokenizer" in str(type(tokenizer)):
            self.tokenizer = tokenizer
            # NOTE(review): special-token ids are hard-coded - confirm they
            # agree with the vocabulary of the tokenizer actually passed in.
            self.pad_token_id = 0
            self.unk_token_id = 1
            self.eos_token_id = 3  # [SEP] token
            self.bos_token_id = 2  # [CLS] token
        else:
            raise ValueError('not supported tokenizer type')

        # Sorted insertion keeps the label indices independent of file order.
        for intent_value in sorted(intent_value_list):
            if intent_value not in self.intent_dict:
                self.intent_dict[intent_value] = len(self.intent_dict)

        for entity_type in sorted(entity_type_list):
            for suffix in ('_B', '_I'):
                tag = str(entity_type) + suffix
                if tag not in self.entity_dict:
                    self.entity_dict[tag] = len(self.entity_dict)

        current_intent_focus = ""

        # Pass 2: turn every example line into a record with its entities.
        for line in tqdm(markdown_lines, desc="Extracting Intent & Entity in NLU markdown files...",):
            if len(line.strip()) < 2:
                current_intent_focus = ""
                continue

            if "## " in line:
                if "intent:" in line:
                    current_intent_focus = line.split(":")[1].strip()
                else:
                    current_intent_focus = ""
                continue

            if current_intent_focus == "":
                continue

            # The line is not stripped here; only the stored "text" is.
            text = line[2:]

            entity_value_list = [
                found.group(1).replace("[", "").replace("]", "")
                for found in self._ENTITY_VALUE_RE.finditer(text)
            ]
            example_entity_types = self._ENTITY_TYPE_RE.findall(text)

            text = self._ENTITY_TYPE_RE.sub("", text)  # remove (...) str
            text = text.replace("[", "").replace("]", "")  # remove '[',']' special char

            if len(text) == 0:
                continue

            each_data_dict = {
                "text": text.strip(),
                "intent": current_intent_focus,
                "intent_idx": self.intent_dict[current_intent_focus],
                "entities": [],
            }

            for value, type_str in zip(entity_value_list, example_entity_types):
                entity_tokens = self.tokenize(value)
                # The value is literal text, not a pattern: escape it so that
                # characters such as '+', '?' or '(' cannot break the search.
                for entity in re.finditer(re.escape(value), text):
                    for i, entity_token in enumerate(entity_tokens):
                        BIO_type_str = type_str + ('_B' if i == 0 else '_I')
                        # NOTE(review): find() returns -1 when the token is
                        # not a literal substring of the span (e.g. pieces
                        # carrying a '##' prefix), so "start"/"end" may be
                        # unreliable for such tokens.
                        token_start = text.find(entity_token, entity.start(), entity.end())
                        each_data_dict["entities"].append(
                            {
                                "start": token_start,
                                "end": token_start + len(entity_token),
                                "entity": type_str,
                                "value": entity_token,
                                "entity_idx": self.entity_dict[BIO_type_str],
                            }
                        )

            self.dataset.append(each_data_dict)

        print(f"Intents: {self.intent_dict}")
        print(f"Entities: {self.entity_dict}")

    def tokenize(self, text: str, skip_special_char=True):
        """Split ``text`` into tokens with the wrapped tokenizer.

        With ``skip_special_char=True`` the tokens are returned unchanged;
        otherwise every ``'#'`` character is removed from each token.

        Raises:
            ValueError: If the tokenizer type is not supported.
        """
        if "ElectraTokenizer" in str(type(self.tokenizer)):
            if skip_special_char:
                return self.tokenizer.tokenize(text)
            else:
                return [token.replace('#', '') for token in self.tokenizer.tokenize(text)]
        else:
            raise ValueError('not supported tokenizer type')

    def encode(self, text: str, padding: bool = True, return_tensor: bool = True):
        """Encode ``text`` into token ids.

        Args:
            text: Sentence to encode with ``tokenizer.encode``.
            padding: When true, truncate to ``seq_len`` or right-pad with
                ``pad_token_id`` up to ``seq_len``.
            return_tensor: Return a long tensor when true, else a numpy array.
        """
        tokens = self.tokenizer.encode(text)
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens).long()
        else:
            tokens = tokens.long()

        if padding:
            if len(tokens) >= self.seq_len:
                tokens = tokens[: self.seq_len]
            else:
                pad_tensor = torch.full(
                    (self.seq_len - len(tokens),), self.pad_token_id, dtype=torch.long
                )
                tokens = torch.cat((tokens, pad_tensor), 0)

        if return_tensor:
            return tokens
        else:
            return tokens.numpy()

    def __len__(self):
        """Number of parsed examples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(tokens, intent_idx, entity_idx, text)`` for example ``idx``."""
        sample = self.dataset[idx]
        tokens = self.encode(sample["text"])

        intent_idx = torch.tensor([sample["intent_idx"]])

        entity_idx = np.array(self.seq_len * [0])  # O tag indicate 0(zero)

        entities = sample["entities"]
        if entities and "ElectraTokenizer" in str(type(self.tokenizer)):
            # A position gets the tag of the first recorded entity token
            # whose text contains the token found at that position.
            for token_seq, token_value in enumerate(tokens):
                # Consider [CLS](bos) token
                if token_seq == 0:
                    continue

                token_text = self.tokenizer.convert_ids_to_tokens([token_value.item()])[0]
                for entity_info in entities:
                    if token_text in entity_info["value"]:
                        entity_idx[token_seq] = entity_info["entity_idx"]
                        break

        entity_idx = torch.from_numpy(entity_idx)

        return tokens, intent_idx, entity_idx, sample["text"]

    def get_intent_idx(self):
        """Mapping from intent name to intent index."""
        return self.intent_dict

    def get_entity_idx(self):
        """Mapping from BIO entity tag to tag index."""
        return self.entity_dict

    def get_vocab_size(self):
        """Vocabulary size reported by the tokenizer."""
        return self.tokenizer.vocab_size

    def get_seq_len(self):
        """Fixed sequence length of encoded examples."""
        return self.seq_len

In [16]:
#tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
#nlu_data_train = open("RASA_NLU_training_dataset.md", encoding="utf-8").readlines()
#dataset_train = RasaIntentEntityValidDataset(markdown_lines=nlu_data_train, tokenizer=tokenizer)

In [17]:
def confusion_matrix(pred: list, label:list, label_index:dict, file_name:str=None, output_dir='results'):
    """Build a relabelled confusion matrix and optionally dump it as JSON.

    Args:
        pred: Predicted class values, passed first to ``ConfusionMatrix``.
        label: Reference class values, passed second to ``ConfusionMatrix``.
        label_index: Mapping handed to ``relabel`` to rename the classes.
        file_name: JSON file name; ``'confusion_matrix.json'`` when omitted.
        output_dir: Directory to write into (created if missing); pass
            ``None`` to skip writing.

    Returns:
        The ``matrix`` attribute of the relabelled confusion matrix.
    """
    # NOTE(review): ConfusionMatrix is not imported in the visible part of
    # this notebook - presumably pycm.ConfusionMatrix; verify before use.
    matrix_builder = ConfusionMatrix(pred, label)
    matrix_builder.relabel(mapping=label_index)
    raw_counts = matrix_builder.matrix
    # Read but never written out; only the raw matrix is persisted.
    normalized_counts = matrix_builder.normalized_matrix

    target_name = 'confusion_matrix.json' if file_name is None else file_name
    normalized_target_name = target_name.replace('.', '_normalized.')

    if output_dir is None:
        return raw_counts

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    target_path = os.path.join(output_dir, target_name)
    with open(target_path, 'w') as fp:
        json.dump(raw_counts, fp, indent=4)

    return raw_counts


def show_rasa_metrics(pred, label, labels=None, target_names=None, output_dir='results', file_name=None):
    """Compute a classification report and optionally save it as JSON.

    Args:
        pred: Predicted class values.
        label: Reference class values.
        labels: Forwarded to ``metrics.classification_report``.
        target_names: Forwarded to ``metrics.classification_report``.
        output_dir: Directory to write into (created if missing); pass
            ``None`` to skip writing.
        file_name: JSON file name; ``'reports.json'`` when omitted.

    Returns:
        The report dictionary produced with ``output_dict=True``.
    """
    # classification_report takes the reference values first.
    report = metrics.classification_report(
        label,
        pred,
        labels=labels,
        target_names=target_names,
        output_dict=True,
    )

    report_file = 'reports.json' if file_name is None else file_name

    if output_dir is None:
        return report

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    report_path = os.path.join(output_dir, report_file)
    with open(report_path, 'w') as fp:
        json.dump(report, fp, indent=4)

    return report


def show_entity_metrics(pred, label,  output_dir='results', file_name=None):
    """Compute the entity report and optionally save it as JSON.

    Args:
        pred: Predicted entities, one list of entity dicts per sentence.
        label: Reference entities, one list of entity dicts per sentence.
        output_dir: Directory to write into (created if missing); pass
            ``None`` to skip writing.
        file_name: JSON file name; ``'reports.json'`` when omitted.

    Returns:
        The dictionary returned by ``Entity_Matrics.generate_report``.
    """
    # Entity_Matrics takes the reference labels first, predictions second.
    report = Entity_Matrics(label, pred).generate_report()

    report_file = 'reports.json' if file_name is None else file_name

    if output_dir is None:
        return report

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    report_path = os.path.join(output_dir, report_file)
    with open(report_path, 'w') as fp:
        json.dump(report, fp, indent=4)

    return report

class Entity_Matrics:
    """Per-entity-type precision / recall / F1 over lists of entity dicts.

    Each sentence is a sequence of dicts with the keys ``'entity'``,
    ``'start'``, ``'end'`` and ``'value'``. Only entity types that occur in
    the reference labels are scored.

    Args:
        sents_true_labels: Reference entities, one sequence per sentence.
        sents_pred_labels: Predicted entities, aligned with the references.
    """

    def __init__(self, sents_true_labels: Sequence[Sequence[Dict]], sents_pred_labels:Sequence[Sequence[Dict]]):
        self.sents_true_labels = sents_true_labels
        self.sents_pred_labels = sents_pred_labels
        self.types = set(entity['entity'] for sent in sents_true_labels for entity in sent)
        self.confusion_matrices = {type: {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0} for type in self.types}
        self.scores = {type: {'p': 0, 'r': 0, 'f1': 0} for type in self.types}

    @staticmethod
    def _safe_div(numerator, denominator):
        """Return ``numerator / denominator``, or 0 for a zero denominator."""
        return numerator / denominator if denominator else 0

    @staticmethod
    def _f1(precision, recall):
        """Harmonic mean of precision and recall, 0 when both are 0."""
        if precision + recall == 0:
            return 0
        return 2*(precision * recall / (precision + recall))

    @staticmethod
    def _support(matrix):
        """Sum of all four counters of one confusion matrix."""
        return sum(matrix.values())

    def cal_confusion_matrices(self) -> Dict[str, Dict]:
        """Calculate confusion matrices for all sentences.

        The counters are reset first, so calling this (or
        ``generate_report``) repeatedly does not accumulate counts.

        For every reference entity, predictions of the same type are
        compared: an exact start/end/value match is a TP; a prediction that
        shares the start or the end but has a different value counts as one
        FP and one FN. A reference entity that did not get exactly one such
        hit adds a further FN.

        NOTE(review): predictions that overlap no reference entity are never
        counted as FP here - confirm this is the intended precision.
        """
        self.confusion_matrices = {type: {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0} for type in self.types}

        for true_labels, pred_labels in zip(self.sents_true_labels, self.sents_pred_labels):
            for true_label in true_labels:
                entity_type = true_label['entity']
                counts = self.confusion_matrices[entity_type]
                prediction_hit_count = 0
                for pred_label in pred_labels:
                    if pred_label['entity'] != entity_type:
                        continue
                    same_start = pred_label['start'] == true_label['start']
                    same_end = pred_label['end'] == true_label['end']
                    same_value = pred_label['value'] == true_label['value']
                    if same_start and same_end and same_value:  # TP
                        counts['TP'] += 1
                        prediction_hit_count += 1
                    elif (same_start or same_end) and not same_value:  # boundary error, count FN, FP
                        counts['FP'] += 1
                        counts['FN'] += 1
                        prediction_hit_count += 1
                if prediction_hit_count != 1:  # FN, model cannot make a prediction for true_label
                    counts['FN'] += 1

        return self.confusion_matrices

    def cal_scores(self) -> Dict[str, Dict]:
        """Calculate precision, recall, f1 per entity type from the counters."""
        scores = {type: {'p': 0, 'r': 0, 'f1': 0} for type in self.types}

        for entity_type, matrix in self.confusion_matrices.items():
            precision = self._safe_div(matrix['TP'], matrix['TP'] + matrix['FP'])
            recall = self._safe_div(matrix['TP'], matrix['TP'] + matrix['FN'])

            scores[entity_type]['p'] = precision
            scores[entity_type]['r'] = recall
            if precision == 0 or recall == 0:
                scores[entity_type]['f1'] = 0
            else:
                scores[entity_type]['f1'] = 2*precision*recall / (precision+recall)

        self.scores = scores
        return self.scores

    def print_confusion_matrices(self):
        """Print the raw counters of every entity type."""
        for entity_type, matrix in self.confusion_matrices.items():
            print(f"{entity_type}: {matrix}")

    def print_scores(self):
        """Print f1 / precision / recall of every entity type."""
        for entity_type, score in self.scores.items():
            print(f"{entity_type}: f1 {score['f1']:.4f}, precision {score['p']:.4f}, recall {score['r']:.4f}")

    def cal_micro_avg(self):
        """Compute the micro average from the counters summed over all types.

        Zero denominators yield 0 instead of raising ``ZeroDivisionError``.
        """
        sum_TP = 0
        sum_FP = 0
        sum_FN = 0
        support = 0
        for matrix in self.confusion_matrices.values():
            sum_TP += matrix['TP']
            sum_FP += matrix['FP']
            sum_FN += matrix['FN']
            support += self._support(matrix)

        precision = self._safe_div(sum_TP, sum_TP + sum_FP)
        recall = self._safe_div(sum_TP, sum_TP + sum_FN)

        self.micro_avg = dict()
        self.micro_avg['precision'] = precision
        self.micro_avg['recall'] = recall
        self.micro_avg['f1-score'] = self._f1(precision, recall)
        self.micro_avg['support'] = support

    def cal_macro_avg(self):
        """Compute the unweighted mean of the per-type precision and recall.

        With no entity types, or when both means are 0, the values fall back
        to 0 instead of NaN.
        """
        precisions = [score['p'] for score in self.scores.values()]
        recalls = [score['r'] for score in self.scores.values()]
        support = sum(self._support(matrix) for matrix in self.confusion_matrices.values())

        precision = float(np.array(precisions).mean()) if precisions else 0
        recall = float(np.array(recalls).mean()) if recalls else 0

        self.macro_avg = dict()
        self.macro_avg['precision'] = precision
        self.macro_avg['recall'] = recall
        self.macro_avg['f1-score'] = self._f1(precision, recall)
        self.macro_avg['support'] = support

    def cal_weight_avg(self):
        """Compute precision / recall from support-weighted counters.

        Each type's TP / FP / FN are scaled by that type's share of the total
        support before being summed. Zero denominators yield 0.
        """
        matrices = list(self.confusion_matrices.values())
        tp = np.array([matrix['TP'] for matrix in matrices])
        fp = np.array([matrix['FP'] for matrix in matrices])
        fn = np.array([matrix['FN'] for matrix in matrices])
        counts = np.array([self._support(matrix) for matrix in matrices])
        support = sum(self._support(matrix) for matrix in matrices)

        # Guard the normalisation: with no support every weight is zero.
        weight = counts / counts.sum() if support else np.zeros(len(matrices))

        weighted_tp = float((weight * tp).sum())
        weighted_fp = float((weight * fp).sum())
        weighted_fn = float((weight * fn).sum())
        precision = self._safe_div(weighted_tp, weighted_tp + weighted_fp)
        recall = self._safe_div(weighted_tp, weighted_tp + weighted_fn)

        self.weight_avg = dict()
        self.weight_avg['precision'] = precision
        self.weight_avg['recall'] = recall
        self.weight_avg['f1-score'] = self._f1(precision, recall)
        self.weight_avg['support'] = support

    def generate_report(self):
        """Run every calculation and return a JSON-serialisable report.

        Returns:
            A dict with one entry per entity type plus ``'micro avg'``,
            ``'macro avg'`` and ``'weighted avg'``; each entry holds
            ``precision``, ``recall``, ``f1-score`` and ``support``.
        """
        self.cal_confusion_matrices()
        self.cal_scores()
        self.cal_micro_avg()
        self.cal_macro_avg()
        self.cal_weight_avg()

        report = dict()
        for entity_type, score in self.scores.items():
            report[entity_type] = dict()
            report[entity_type]['precision'] = score['p']
            report[entity_type]['recall'] = score['r']
            report[entity_type]['f1-score'] = score['f1']
            report[entity_type]['support'] = self._support(self.confusion_matrices[entity_type])
        report['micro avg'] = self.micro_avg
        report['macro avg'] = self.macro_avg
        report['weighted avg'] = self.weight_avg
        return report

In [18]:
class NERDecoder(object):
    """Turn per-token entity tag indices back into entity spans of a text.

    Args:
        entity_dict: Mapping from tag index to tag name (``"O"``,
            ``"<type>_B"``, ``"<type>_I"``).
        tokenizer: Tokenizer whose type name contains ``ElectraTokenizer``.
    """

    def __init__(self, entity_dict:dict, tokenizer):
        self.entity_dict = entity_dict
        self.tokenizer = tokenizer

    def _tag_prefix(self, tag_idx):
        """Tag name of ``tag_idx`` up to (excluding) its last underscore."""
        tag_name = self.entity_dict[tag_idx]
        return tag_name[:tag_name.rfind('_')]

    def _token_text(self, token_idx):
        """Return the token string for ``token_idx`` with ``'#'`` removed.

        Raises:
            ValueError: If the tokenizer type is not supported.
        """
        if "ElectraTokenizer" not in str(type(self.tokenizer)):
            # Previously this path left the token text unbound and failed
            # later with an UnboundLocalError.
            raise ValueError('not supported tokenizer type')
        return self.tokenizer.convert_ids_to_tokens([token_idx])[0].replace("#", "")

    def is_same_entity(self, i, j):
        """Check whether tag indices ``i`` and ``j`` share the same entity type."""
        # check whether XXX_B, XXX_I tag are same
        return self._tag_prefix(i).strip() == self._tag_prefix(j).strip()

    def process(self, tokens, entity_indices, text):
        """Decode tag indices into a list of entity dicts.

        Args:
            tokens: Token ids of the sentence (tensor or indexable sequence).
            entity_indices: Tag index per position; must support ``tolist()``.
            text: The sentence the tokens were produced from.

        Returns:
            A list of dicts with ``start``, ``end``, ``value`` and ``entity``.

        Raises:
            ValueError: If a span is found and the tokenizer is unsupported.
        """
        # mapping entity result
        entities = []

        # NOTE(review): the tag list is cut to the character length of
        # ``text`` - confirm this is the intended bound.
        entity_indices = entity_indices.tolist()[:len(text)]
        start_token_position = -1

        if type(tokens) == torch.Tensor:
            tokens = tokens.long().tolist()

        # NOTE(review): a span still open when the loop ends is not emitted.
        for i, entity_idx_value in enumerate(entity_indices):
            if entity_idx_value != 0 and start_token_position == -1:
                # First tagged position of a new span.
                start_token_position = i
            elif start_token_position >= 0 and not self.is_same_entity(entity_indices[i-1], entity_indices[i]):
                # The entity type changed: the span ended at the previous position.
                end_token_position = i - 1
                previous_tag = entity_indices[i-1]

                # find start text position
                # NOTE(review): token lookups are shifted by one position
                # relative to the tag index - presumably to skip the leading
                # [CLS] token; verify against how the tags are aligned.
                start_token = self._token_text(tokens[start_token_position + 1]).strip()

                if len(start_token) == 0:
                    start_token_position = -1
                    continue

                start_position = text.find(start_token)

                # find end text position
                end_token = self._token_text(tokens[end_token_position + 1]).strip()
                end_position = text.find(end_token, start_position) + len(end_token)

                if self.entity_dict[previous_tag] != "O":  # ignore 'O' tag
                    entities.append(
                        {
                            "start": start_position,
                            "end": end_position,
                            "value": text[start_position:end_position],
                            "entity": self._tag_prefix(previous_tag),
                        }
                    )

                    start_token_position = -1

                if entity_idx_value == 0:
                    start_token_position = -1

        return entities

In [19]:
def show_intent_report(dataset, pl_module, tokenizer, file_name=None, output_dir=None, cuda=True):
    """Evaluate intent classification on ``dataset`` and write the report.

    Args:
        dataset: Dataset yielding ``(input_ids, intent_idx, entity_idx, text)``.
        pl_module: Object exposing ``model`` and ``intent_dict``; the keys of
            ``intent_dict`` are converted with ``int`` and used as label ids,
            the values as label names.
        tokenizer: Unused; kept for signature compatibility.
        file_name: Forwarded to ``show_rasa_metrics``.
        output_dir: Forwarded to ``show_rasa_metrics``.
        cuda: When truthy, the model and inputs are moved to the GPU.

    Returns:
        The report returned by ``show_rasa_metrics``.
    """
    ##generate rasa performance matrics
    preds = np.array([])
    targets = np.array([])
    label_dict = {int(k): v for k, v in pl_module.intent_dict.items()}

    model = pl_module.model
    model.eval()
    if cuda > 0:
        # Moved once up front instead of on every batch.
        model = model.cuda()

    dataloader = DataLoader(dataset, batch_size=32)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="load intent dataset"):
            #dataset follows RasaIntentEntityDataset which defined in this package
            input_ids, intent_idx, entity_idx, text = batch
            if cuda > 0:
                input_ids = input_ids.cuda()
            intent_pred, entity_pred = model(input_ids)
            preds = np.append(preds, intent_pred.argmax(1).cpu().numpy())
            targets = np.append(targets, intent_idx.cpu().numpy())

    preds = preds.astype(int)
    targets = targets.astype(int)

    labels = list(label_dict.keys())
    target_names = list(label_dict.values())

    return show_rasa_metrics(pred=preds, label=targets, labels=labels, target_names=target_names, file_name=file_name, output_dir=output_dir)

def show_entity_report(dataset, pl_module, tokenizer, file_name=None, output_dir=None, cuda=True):
    """Evaluate entity extraction on ``dataset`` and write the report.

    Both the reference tags and the predicted tags of every example are
    decoded with ``NERDecoder`` and the resulting entity lists are compared
    by ``show_entity_metrics``.

    Args:
        dataset: Dataset yielding ``(input_ids, intent_idx, entity_idx, text)``.
        pl_module: Object exposing ``model`` and ``entity_dict``; the keys of
            ``entity_dict`` are converted with ``int`` and used as tag ids,
            the values as tag names.
        tokenizer: Tokenizer handed to ``NERDecoder``.
        file_name: Forwarded to ``show_entity_metrics``.
        output_dir: Forwarded to ``show_entity_metrics``.
        cuda: When truthy, the model and inputs are moved to the GPU.

    Returns:
        The report returned by ``show_entity_metrics``.
    """
    ##generate rasa performance matrics
    label_dict = {int(k): v for k, v in pl_module.entity_dict.items()}

    model = pl_module.model
    model.eval()
    if cuda > 0:
        # Keep the model on the same device as the inputs.
        model = model.cuda()

    decoder = NERDecoder(label_dict, tokenizer)
    dataloader = DataLoader(dataset, batch_size=32)

    preds = list()
    targets = list()

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="load entity dataset"):
            input_ids, intent_idx, entity_idx, token = batch
            if cuda > 0:
                input_ids = input_ids.cuda()
            _, entity_result = model(input_ids)

            entity_result = entity_result.detach().cpu()
            # Most likely tag per position.
            _, entity_indices = torch.max(entity_result, dim=-1)

            for i in range(entity_idx.shape[0]):
                token_ids = input_ids[i].cpu().numpy()
                targets.append(decoder.process(token_ids, entity_idx[i].numpy(), token[i]))
                preds.append(decoder.process(token_ids, entity_indices[i].numpy(), token[i]))

    return show_entity_metrics(pred=preds, label=targets, file_name=file_name, output_dir=output_dir)


In [20]:
class PerfCallback(Callback):
    """Callback that writes intent and entity reports once training has finished.

    Args:
        file_path: optional path to an NLU markdown file to evaluate; when ``None``
            the module's ``val_dataset`` is evaluated instead.
        gpu_num: number of GPUs; any value greater than zero enables CUDA for the reports.
        report_nm: base report file name; ``_intent`` / ``_entity`` is inserted before
            its extension. Falls back to ``report.json`` when ``None``.
        output_dir: directory that receives a ``results`` sub-folder; when ``None``
            the newest run folder below ``root_path`` is used.
        root_path: parent of the ``lightning_logs`` folder (defaults to the working directory).
    """

    def __init__(self, file_path=None, gpu_num=0, report_nm=None, output_dir=None, root_path=None):
        self.file_path = file_path
        self.cuda = gpu_num > 0
        self.report_nm = report_nm
        self.output_dir = output_dir

        if root_path is None:
            self.root_path = 'lightning_logs'
        else:
            self.root_path = os.path.join(root_path, 'lightning_logs')

    @staticmethod
    def _natural_key(path):
        """Sort key that compares digit runs numerically (``version_9`` < ``version_35``)."""
        # re.split with a capture group alternates text / digits, so odd slots are numbers.
        return [int(part) if pos % 2 else part for pos, part in enumerate(re.split(r'(\d+)', path))]

    def _latest_log_dir(self):
        """Return the highest-numbered folder directly below ``root_path`` (or ``root_path`` itself)."""
        run_dirs = glob.glob(os.path.join(self.root_path, "**/"), recursive=False)
        if not run_dirs:
            # No run folder found: fall back to the root instead of failing with IndexError.
            return self.root_path
        return max(run_dirs, key=self._natural_key)

    def _report_names(self):
        """Build the ``(intent, entity)`` report file names from ``report_nm``."""
        report_nm = self.report_nm if self.report_nm is not None else 'report.json'
        # Only the extension separator is touched; other dots in the name stay intact.
        stem, ext = os.path.splitext(report_nm)
        return stem + '_intent' + ext, stem + '_entity' + ext

    def on_train_end(self, trainer, pl_module):
        """Evaluate validation (or freshly loaded) data and write both reports."""
        # The tokenizer lives on the LightningModule's dataset; the inner network
        # (pl_module.model) has no ``dataset`` attribute.
        tokenizer = pl_module.dataset.tokenizer

        if self.file_path is None:
            print("evaluate valid data")
            dataset = pl_module.val_dataset
        else:
            print("evaluate new data")
            with open(self.file_path, encoding="utf-8") as nlu_file:
                self.nlu_data = nlu_file.readlines()
            dataset = RasaIntentEntityValidDataset(markdown_lines=self.nlu_data, tokenizer=tokenizer)

        if self.output_dir is None:
            self.output_dir = self._latest_log_dir()
        self.output_dir = os.path.join(self.output_dir, 'results')

        intent_report_nm, entity_report_nm = self._report_names()
        show_intent_report(dataset, pl_module, tokenizer, file_name=intent_report_nm, output_dir=self.output_dir, cuda=self.cuda)
        show_entity_report(dataset, pl_module, tokenizer, file_name=entity_report_nm, output_dir=self.output_dir, cuda=self.cuda)

In [21]:
class EmbeddingTransformer(nn.Module):
    """Shared encoder with two linear heads: one for intents, one for per-token entities.

    Args:
        backbone: ``None`` to use a freshly initialised ``nn.TransformerEncoder`` on top of
            learned token/position embeddings, or ``"electra"`` to use the pretrained
            ``google/electra-small-discriminator`` encoder.
        vocab_size: number of rows of the token embedding table.
        seq_len: number of rows of the position embedding table.
        intent_class_num: output size of the intent head.
        entity_class_num: output size of the entity head.
        d_model: hidden size (replaced by the backbone's ``hidden_size`` for ``"electra"``).
        nhead, num_encoder_layers, dim_feedforward, dropout, activation: forwarded to the
            built-in transformer encoder when ``backbone`` is ``None``.
        pad_token_id: token id treated as padding by the built-in encoder path.

    Raises:
        ValueError: if ``backbone`` is neither ``None`` nor ``"electra"``.
    """

    def __init__(
        self,
        backbone,
        vocab_size: int,
        seq_len: int,
        intent_class_num: int,
        entity_class_num: int,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        pad_token_id: int = 0,
    ):
        super(EmbeddingTransformer, self).__init__()
        self.backbone = backbone
        self.seq_len = seq_len
        self.pad_token_id = pad_token_id

        if backbone is None:
            self.encoder = nn.TransformerEncoder(
                TransformerEncoderLayer(
                    d_model, nhead, dim_feedforward, dropout, activation
                ),
                num_encoder_layers,
                LayerNorm(d_model),
            )
        elif backbone == "electra":  # pre-defined model architecture
            self.encoder = ElectraModel.from_pretrained("google/electra-small-discriminator")
            d_model = self.encoder.config.hidden_size
        else:
            # Fail early and clearly instead of hitting a missing ``self.encoder`` later.
            raise ValueError(
                "unsupported backbone {!r}; expected None or 'electra'".format(backbone)
            )

        # Kept for every backbone so that existing checkpoints keep the same state_dict keys.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(self.seq_len, d_model)

        self.intent_feature = nn.Linear(d_model, intent_class_num)
        self.entity_feature = nn.Linear(d_model, entity_class_num)

    def forward(self, x):
        """Encode token ids ``x`` of shape (N, S) and return ``(intent_logits, entity_logits)``.

        ``intent_logits`` has shape (N, intent_class_num) and is computed from the first
        sequence position; ``entity_logits`` has shape (N, S, entity_class_num).
        """
        if self.backbone is None:
            # Built-in encoder: token + learned position embeddings, padding masked out.
            positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
            embedded = self.embedding(x) + self.position_embedding(positions)
            padding_mask = x == self.pad_token_id
            # nn.TransformerEncoder is sequence-first here: (N,S,E) -> (S,N,E) and back.
            feature = self.encoder(
                embedded.transpose(0, 1), src_key_padding_mask=padding_mask
            ).transpose(0, 1)
        else:
            # NOTE(review): no attention mask is passed, so padding positions are
            # attended to — confirm this is intended before changing it.
            feature = self.encoder(x)

            # Encoders may return a tuple or an output object rather than a bare tensor;
            # in both cases element 0 is the last hidden state (N,S,E).
            if not torch.is_tensor(feature):
                feature = feature[0]

        # first token in sequence is used for intent classification
        intent_feature = self.intent_feature(feature[:, 0, :])  # (N,E) -> (N,i_C)

        # every token in sequence is used for entity classification
        entity_feature = self.entity_feature(feature)  # (N,S,E) -> (N,S,e_C)

        return intent_feature, entity_feature

In [22]:
class DualIntentEntityTransformer(pl.LightningModule):
    """LightningModule that trains intent classification and entity tagging on one encoder.

    The constructor argument is a ``Namespace`` (or plain ``dict``) providing at least:
    ``nlu_data``, ``tokenizer``, ``backbone``, ``d_model``, ``num_encoder_layers``,
    ``train_ratio``, ``batch_size``, ``optimizer``, ``intent_optimizer_lr`` and
    ``entity_optimizer_lr``.
    """

    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        # NOTE(review): exact type comparison kept on purpose; an isinstance check
        # would additionally match dict subclasses and change behaviour.
        if type(self.hparams) == dict:
            self.hparams = Namespace(**self.hparams)

        self.dataset = RasaIntentEntityDataset(
            markdown_lines=self.hparams.nlu_data,
            tokenizer=self.hparams.tokenizer,
        )

        self.model = EmbeddingTransformer(
            backbone=self.hparams.backbone,
            vocab_size=self.dataset.get_vocab_size(),
            seq_len=self.dataset.get_seq_len(),
            intent_class_num=len(self.dataset.get_intent_idx()),
            entity_class_num=len(self.dataset.get_entity_idx()),
            d_model=self.hparams.d_model,
            num_encoder_layers=self.hparams.num_encoder_layers,
            pad_token_id=self.dataset.pad_token_id
        )

        self.train_ratio = self.hparams.train_ratio
        self.batch_size = self.hparams.batch_size
        self.optimizer = self.hparams.optimizer
        self.intent_optimizer_lr = self.hparams.intent_optimizer_lr
        self.entity_optimizer_lr = self.hparams.entity_optimizer_lr

        self.intent_loss_fn = nn.CrossEntropyLoss()
        # reduce the weight of class 0 (the "O" tag) to counter the entity class imbalance
        self.entity_loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([0.1] + [1.0] * (len(self.dataset.get_entity_idx()) - 1)))

    def forward(self, x):
        """Return ``(intent_logits, entity_logits)`` from the wrapped network."""
        return self.model(x)

    def prepare_data(self):
        """Randomly split the dataset by ``train_ratio`` and record the label maps in hparams."""
        train_length = int(len(self.dataset) * self.train_ratio)

        self.train_dataset, self.val_dataset = random_split(
            self.dataset, [train_length, len(self.dataset) - train_length],
        )

        self.hparams.intent_label = self.get_intent_label()
        self.hparams.entity_label = self.get_entity_label()

    def get_intent_label(self):
        """Build, store and return the ``str(index) -> intent name`` mapping."""
        self.intent_dict = {}
        for k, v in self.dataset.intent_dict.items():
            self.intent_dict[str(v)] = k
        return self.intent_dict

    def get_entity_label(self):
        """Build, store and return the ``str(index) -> entity tag`` mapping."""
        self.entity_dict = {}
        for k, v in self.dataset.entity_dict.items():
            self.entity_dict[str(v)] = k
        return self.entity_dict

    def train_dataloader(self):
        """DataLoader over the training split."""
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
        )
        return train_loader

    def val_dataloader(self):
        """DataLoader over the validation split."""
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
        )
        return val_loader

    def _resolve_optimizer_class(self):
        """Map the configured optimizer (name or callable) to an optimizer class.

        Names are looked up in ``torch.optim`` first and in this module's globals
        second, which covers what the previous ``eval``-based lookup could resolve
        without executing arbitrary text.

        Raises:
            ValueError: if the name cannot be resolved.
        """
        spec = self.optimizer
        if callable(spec):
            return spec

        name = str(spec).rsplit(".", 1)[-1]  # accept "Adam" as well as "torch.optim.Adam"
        optimizer_cls = getattr(torch.optim, name, None)
        if optimizer_cls is None:
            optimizer_cls = globals().get(name)
        if optimizer_cls is None:
            raise ValueError("unknown optimizer {!r}".format(spec))
        return optimizer_cls

    def configure_optimizers(self):
        """Create one optimizer per task (index 0: intent, index 1: entity) with LR schedulers.

        Both optimizers are built over all parameters of this module and differ only
        in their learning rate.
        """
        optimizer_cls = self._resolve_optimizer_class()
        intent_optimizer = optimizer_cls(self.parameters(), lr=self.intent_optimizer_lr)
        entity_optimizer = optimizer_cls(self.parameters(), lr=self.entity_optimizer_lr)

        return (
            [intent_optimizer, entity_optimizer],
            # [StepLR(intent_optimizer, step_size=1),StepLR(entity_optimizer, step_size=1),],
            [
                ReduceLROnPlateau(intent_optimizer, patience=1),
                ReduceLROnPlateau(entity_optimizer, patience=1),
            ],
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the task that belongs to ``optimizer_idx``.

        ``optimizer_idx == 0`` yields the intent loss, ``optimizer_idx == 1`` the
        entity loss; intent accuracy is logged in both cases.
        """
        self.model.train()

        tokens, intent_idx, entity_idx, text = batch
        intent_pred, entity_pred = self.forward(tokens)

        intent_acc = get_accuracy(intent_pred.argmax(1).cpu(), intent_idx.cpu())[0]
        #entity_acc = get_token_accuracy(entity_pred.argmax(2), entity_idx, ignore_index=self.dataset.pad_token_id)[0]

        tensorboard_logs = {
            "train/intent/acc": intent_acc,
            #"train/entity/acc": entity_acc,
        }

        if optimizer_idx == 0:
            intent_loss = self.intent_loss_fn(intent_pred, intent_idx.squeeze(1))
            tensorboard_logs["train/intent/loss"] = intent_loss

            return {
                "loss": intent_loss,
                "log": tensorboard_logs,
            }

        if optimizer_idx == 1:
            # CrossEntropyLoss wants the class dimension second: (N,S,C) -> (N,C,S)
            entity_loss = self.entity_loss_fn(entity_pred.transpose(1, 2), entity_idx.long(),)
            tensorboard_logs["train/entity/loss"] = entity_loss

            return {
                "loss": entity_loss,
                "log": tensorboard_logs,
            }

    def validation_step(self, batch, batch_idx):
        """Return intent accuracy/F1 and both losses for one validation batch."""
        self.model.eval()

        tokens, intent_idx, entity_idx, text = batch
        intent_pred, entity_pred = self.forward(tokens)

        # Drop the trailing singleton dimension once so that predictions (N,) and
        # targets share one shape for both the F1 score and the loss.
        intent_target = intent_idx.squeeze(1)

        intent_acc = get_accuracy(intent_pred.argmax(1).cpu(), intent_idx.cpu())[0]
        #entity_acc = get_token_accuracy(entity_pred.argmax(2), entity_idx, ignore_index=self.dataset.pad_token_id)[0]
        intent_f1 = f1_score(intent_pred.argmax(1), intent_target)

        intent_loss = self.intent_loss_fn(intent_pred, intent_target)
        entity_loss = self.entity_loss_fn(entity_pred.transpose(1, 2), entity_idx.long(),)

        return {
            "val_intent_acc": torch.Tensor([intent_acc]),
            #"val_entity_acc": torch.Tensor([entity_acc]),
            "val_intent_f1": intent_f1,
            "val_intent_loss": intent_loss,
            "val_entity_loss": entity_loss,
            "val_loss": intent_loss + entity_loss,
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation outputs and expose them for logging."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_intent_acc = torch.stack([x["val_intent_acc"] for x in outputs]).mean()
        #avg_entity_acc = torch.stack([x["val_entity_acc"] for x in outputs]).mean()
        avg_intent_f1 = torch.stack([x["val_intent_f1"] for x in outputs]).mean()

        tensorboard_logs = {
            "val/loss": avg_loss,
            "val/intent_acc": avg_intent_acc,
            #"val/entity_acc": avg_entity_acc,
            "val/intent_f1": avg_intent_f1,
        }

        return {
            "val_loss": avg_loss,
            "val_intent_f1": avg_intent_f1,
            #"val_entity_acc": avg_entity_acc,
            "log": tensorboard_logs,
            "progress_bar": tensorboard_logs,
        }

In [23]:
def train(
    file_path,
    # training args
    train_ratio=0.8,
    batch_size=None,
    optimizer="Adam",
    intent_optimizer_lr=1e-5,
    entity_optimizer_lr=2e-5,
    checkpoint_path=os.getcwd(),
    max_epochs=20,
    tokenizer_type="char",
    # model args
    # refered below link to optimize model
    # https://www.notion.so/A-Primer-in-BERTology-What-we-know-about-how-BERT-works-aca45feaba2747f09f1a3cdd1b1bbe16
    backbone=None,
    d_model=256,
    num_encoder_layers=2,
    **kwargs
):
    """Build a ``Trainer`` and a ``DualIntentEntityTransformer`` and fit the model.

    Args:
        file_path: path of the NLU markdown file; its lines become ``nlu_data``.
        train_ratio: fraction of the dataset used for training.
        batch_size: DataLoader batch size; when ``None`` the trainer is created with
            ``auto_scale_batch_size='binsearch'``.
        optimizer: optimizer name handed to the model.
        intent_optimizer_lr: learning rate of the intent optimizer.
        entity_optimizer_lr: learning rate of the entity optimizer.
        checkpoint_path: trainer ``default_root_dir`` and root for the report callback.
        max_epochs: maximum number of training epochs.
        tokenizer_type: tokenizer identifier used when no backbone is selected.
        backbone: ``None`` or ``"electra"`` (the latter loads the matching tokenizer).
        d_model: hidden size handed to the model.
        num_encoder_layers: encoder depth handed to the model.
        **kwargs: extra hyper-parameters; they override the values assembled here.
    """
    gpu_num = torch.cuda.device_count()

    if backbone is None:
        report_nm = "diet_{}_tokenizer_report.json".format(tokenizer_type)
    else:
        report_nm = "{}_report.json".format(backbone)

    # Both trainer variants share everything except the batch-size auto scaling.
    trainer_kwargs = dict(
        default_root_dir=checkpoint_path,
        max_epochs=max_epochs,
        gpus=gpu_num,
        callbacks=[
            PerfCallback(
                gpu_num=gpu_num, report_nm=report_nm, root_path=checkpoint_path
            )
        ],
    )
    if batch_size is None:
        trainer_kwargs["auto_scale_batch_size"] = 'binsearch'
    trainer = Trainer(**trainer_kwargs)

    # Read the training data with a context manager so the handle is closed again.
    with open(file_path, encoding="utf-8") as nlu_file:
        nlu_lines = nlu_file.readlines()

    model_args = {
        # training args
        "max_epochs": max_epochs,
        "nlu_data": nlu_lines,
        "train_ratio": train_ratio,
        "batch_size": batch_size,
        "optimizer": optimizer,
        "intent_optimizer_lr": intent_optimizer_lr,
        "entity_optimizer_lr": entity_optimizer_lr,
    }

    if backbone is None:
        model_args["tokenizer"] = tokenizer_type
    elif backbone == "electra":
        model_args["tokenizer"] = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
    # Any other backbone gets no tokenizer here; one may still arrive through **kwargs.

    # model args
    model_args["backbone"] = backbone
    model_args["d_model"] = d_model
    model_args["num_encoder_layers"] = num_encoder_layers

    # caller-supplied extras win over the defaults assembled above
    model_args.update(kwargs)

    hparams = Namespace(**model_args)

    model = DualIntentEntityTransformer(hparams)

    trainer.fit(model)

In [24]:
 train(
     file_path="Intent_Entity_dataset.md",
     # model args
     backbone="electra",
     num_encoder_layers=3,
     # training args
     train_ratio=0.8,
     batch_size=32,
     optimizer="Adam",
     intent_optimizer_lr=1e-5,
     entity_optimizer_lr=2e-5,
     max_epochs=20,
     checkpoint_path=os.getcwd(),
 )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
Organizing Intent & Entity dictionary in NLU markdown file ...: 100%|██████████| 991/991 [00:00<00:00, 104046.54it/s]
Extracting Intent & Entity in NLU markdown files...: 100%|██████████| 991/991 [00:00<00:00, 32986.18it/s]


Intents: {'affirmative': 0, 'anythingElse': 1, 'bye': 2, 'greet': 3, 'listProjects': 4, 'listTasks': 5, 'negative': 6, 'nothingElse': 7}
Entities: {'O': 0, 'priority_B': 1, 'priority_I': 2, 'state_B': 3, 'state_I': 4}



  | Name           | Type                 | Params
--------------------------------------------------------
0 | model          | EmbeddingTransformer | 21 M  
1 | intent_loss_fn | CrossEntropyLoss     | 0     
2 | entity_loss_fn | CrossEntropyLoss     | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..

load intent dataset:   0%|          | 0/7 [00:00<?, ?it/s][A
load intent dataset:  14%|█▍        | 1/7 [00:00<00:00,  7.71it/s][A

evaluate valid data



load intent dataset:  29%|██▊       | 2/7 [00:00<00:00,  7.70it/s][A
load intent dataset:  43%|████▎     | 3/7 [00:00<00:00,  7.88it/s][A
load intent dataset:  57%|█████▋    | 4/7 [00:00<00:00,  7.80it/s][A
load intent dataset:  71%|███████▏  | 5/7 [00:00<00:00,  7.79it/s][A
load intent dataset: 100%|██████████| 7/7 [00:00<00:00,  8.64it/s][A

load entity dataset:   0%|          | 0/7 [00:00<?, ?it/s][A
load entity dataset:  14%|█▍        | 1/7 [00:00<00:00,  7.69it/s][A
load entity dataset:  29%|██▊       | 2/7 [00:00<00:00,  7.69it/s][A
load entity dataset:  43%|████▎     | 3/7 [00:00<00:00,  7.61it/s][A
load entity dataset:  57%|█████▋    | 4/7 [00:00<00:00,  7.38it/s][A
load entity dataset:  71%|███████▏  | 5/7 [00:00<00:00,  7.37it/s][A
load entity dataset: 100%|██████████| 7/7 [00:00<00:00,  8.17it/s][A





In [None]:
# Module-level placeholders; the Inferencer class defined in this cell stores its
# model and label dictionaries on the instance and does not read these names.
model = None
intent_dict, entity_dict = {}, {}

class Inferencer:
    """Load a trained ``DualIntentEntityTransformer`` checkpoint and parse single utterances."""

    def __init__(self, checkpoint_path: str):
        """Restore the model from ``checkpoint_path`` and build index -> label lookups."""
        self.model = DualIntentEntityTransformer.load_from_checkpoint(checkpoint_path)
        self.model.model.eval()

        # The dataset maps label name -> index; invert both maps for decoding predictions.
        self.intent_dict = {idx: name for name, idx in self.model.dataset.intent_dict.items()}
        self.entity_dict = {idx: name for name, idx in self.model.dataset.entity_dict.items()}

        logging.info("intent dictionary")
        logging.info(self.intent_dict)

        logging.info("entity dictionary")
        logging.info(self.entity_dict)

    def _entity_name(self, tag_idx):
        """Return the tag label without its trailing ``_B`` / ``_I`` part.

        A label without an underscore (such as ``"O"``) loses its last character,
        which is what the tag comparison below has always relied on.
        """
        label = self.entity_dict[tag_idx]
        return label[:label.rfind('_')]

    def is_same_entity(self, i, j):
        """Check whether tag indices ``i`` and ``j`` belong to the same entity type (XXX_B / XXX_I)."""
        return self._entity_name(i).strip() == self._entity_name(j).strip()

    def _token_text(self, token_idx):
        """Convert one token id to its surface text with ``#`` markers and whitespace removed.

        Raises:
            TypeError: if the tokenizer offers no ``convert_ids_to_tokens`` method.
        """
        tokenizer = self.model.dataset.tokenizer
        if not hasattr(tokenizer, "convert_ids_to_tokens"):
            raise TypeError(
                "tokenizer {!r} cannot convert ids back to tokens".format(type(tokenizer))
            )
        return tokenizer.convert_ids_to_tokens([token_idx])[0].replace("#", "").strip()

    def _locate_entity(self, text, tokens, start_token, end_token, tag_idx):
        """Turn a run of tagged tokens into an entity dict, or ``None`` if it cannot be located.

        ``start_token`` / ``end_token`` index the tag list, which skips the first
        sequence position; hence the ``+ 1`` when indexing ``tokens``.
        """
        if self.entity_dict[tag_idx] == "O":  # ignore 'O' tag
            return None

        first_piece = self._token_text(tokens[start_token + 1])
        if len(first_piece) == 0:
            return None
        start_position = text.find(first_piece)
        if start_position < 0:
            # The token text does not occur in the utterance (e.g. a special token).
            return None

        last_piece = self._token_text(tokens[end_token + 1])
        last_position = text.find(last_piece, start_position)
        if last_position < 0:
            return None
        end_position = last_position + len(last_piece)

        return {
            "start": start_position,
            "end": end_position,
            "value": text[start_position:end_position],
            "entity": self._entity_name(tag_idx),
        }

    def inference(self, text: str, intent_topk=5):
        """Predict the intent ranking and entities for ``text``.

        Args:
            text: raw utterance; it is stripped and lower-cased before encoding.
            intent_topk: number of ranked intents to return; capped at the number of
                intent classes the model outputs.

        Returns:
            dict with keys ``text``, ``intent`` (best ``name`` and ``confidence``),
            ``intent_ranking`` (list of ``confidence`` / ``name`` dicts) and
            ``entities`` (list of ``start`` / ``end`` / ``value`` / ``entity`` dicts).

        Raises:
            ValueError: if no model is attached to this instance.
        """
        if self.model is None:
            raise ValueError(
                "model is not loaded, create the Inferencer with a valid checkpoint_path"
            )

        text = text.strip().lower()

        # encode text to token_indices
        tokens = self.model.dataset.encode(text)
        with torch.no_grad():  # inference only, no autograd graph needed
            intent_result, entity_result = self.model.forward(tokens.unsqueeze(0).cpu())

        # mapping intent result
        intent_probs = nn.Softmax(dim=1)(intent_result)[0]
        # topk raises when k exceeds the number of classes, so cap it.
        rank_values, rank_indices = torch.topk(
            intent_probs, k=min(intent_topk, intent_probs.shape[0])
        )
        intent_ranking = [
            {"confidence": value, "name": self.intent_dict[index]}
            for value, index in zip(rank_values.tolist(), rank_indices.tolist())
        ]
        intent = {}
        if intent_ranking:
            intent["name"] = intent_ranking[0]["name"]
            intent["confidence"] = intent_ranking[0]["confidence"]

        # mapping entity result
        # except first & last sequence position (BOS or [CLS] token & EOS or [SEP] token)
        _, entity_indices = torch.max((entity_result)[0][1:-1, :], dim=1)
        entity_indices = entity_indices.tolist()[:len(text)]

        if torch.is_tensor(tokens):
            tokens = tokens.long().tolist()

        entities = []
        start_token_position = -1  # -1 means "currently outside of any entity"

        for i, entity_idx_value in enumerate(entity_indices):
            if start_token_position == -1:
                if entity_idx_value != 0:
                    start_token_position = i
                continue

            previous_tag = entity_indices[i - 1]
            if self.is_same_entity(previous_tag, entity_idx_value):
                continue

            # The entity type changed: close the run that ended on the previous token.
            entity = self._locate_entity(text, tokens, start_token_position, i - 1, previous_tag)
            if entity is not None:
                entities.append(entity)

            # A different entity may start right here; do not lose its first token.
            start_token_position = i if entity_idx_value != 0 else -1

        # An entity that runs up to the last inspected token is still open: emit it.
        if start_token_position >= 0:
            entity = self._locate_entity(
                text, tokens, start_token_position, len(entity_indices) - 1, entity_indices[-1]
            )
            if entity is not None:
                entities.append(entity)

        result = {
            "text": text,
            "intent": intent,
            "intent_ranking": intent_ranking,
            "entities": entities,
        }

        return result

In [None]:
# Restore the trained checkpoint, then parse one sample utterance; the final
# expression is left bare so the notebook displays the resulting dict.
inferencer = Inferencer(
    checkpoint_path="lightning_logs/version_35/checkpoints/epoch=19.ckpt"
)
inferencer.inference(text="list of tasks", intent_topk=5)