### ATIS joint NMT

In [1]:
# Reload edited modules automatically before each cell runs, so changes to
# project code are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import pandas as pd
import warnings
import os
import sys

# Put the parent directory on the import path (presumably where the `modules`
# package imported by later cells lives — confirm against the repo layout).
sys.path.append("../")
# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

Download the ATIS dataset from [here](https://github.com/Microsoft/CNTK/tree/master/Examples/LanguageUnderstanding/ATIS/Data).

### Run NER model

In [2]:
import os


# Locations of the filtered ATIS train/validation CSV files.
data_path = "/datadrive/JointSLU/data/"
train_path = os.path.join(data_path, "train_filtered.csv")
valid_path = os.path.join(data_path, "valid_filtered.csv")

# Directory holding the converted PyTorch BERT checkpoint.
# Fixed: the original literal began with a stray space (" /datadrive/..."),
# which made it an invalid path, and the checkpoint path repeated the
# directory instead of reusing this variable.
model_dir = "/datadrive/models/multi_cased_L-12_H-768_A-12/"
init_checkpoint_pt = os.path.join(model_dir, "pytorch_model.bin")

# BERT config and vocabulary are read from a different root (/datadrive/bert)
# than the checkpoint (/datadrive/models).
# NOTE(review): confirm both directories hold the same multi_cased release.
bert_dir = "/datadrive/bert/multi_cased_L-12_H-768_A-12/"
bert_config_file = os.path.join(bert_dir, "bert_config.json")
vocab_file = os.path.join(bert_dir, "vocab.txt")

In [3]:
import torch
# Select GPU 0 only when CUDA is actually present; calling set_device on a
# CPU-only host raises before the availability check is ever displayed.
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.set_device(0)
# Displayed as (is CUDA available, index of the current device or None).
(cuda_available, torch.cuda.current_device() if cuda_available else None)

(True, 0)

#### Create data loaders

In [4]:
from modules import BertNerData as NerData

INFO:summarizer.preprocessing.cleaner:'pattern' package not found; tag filters are not available for English


In [5]:
# Build the train/validation data loaders from the CSV files and BERT vocab.
# NOTE(review): data_type is "bert_uncased" while every BERT path above points
# at a "multi_cased" model — confirm the casing setting is intentional.
data = NerData.create(
    train_path,
    valid_path,
    vocab_file,
    data_type="bert_uncased",
    is_cls=True,
)

In [6]:
# Sizes of the datasets behind the training and validation loaders.
tuple(len(loader.dataset) for loader in (data.train_dl, data.valid_dl))

(9445, 888)

In [7]:
# Sizes of the label-to-index mapping and of the index-to-class mapping.
tuple(map(len, (data.label2idx, data.id2cls)))

(154, 17)

In [8]:
# Read the `slots` column of slt_flt.csv as the list of labels to report on.
# The path is derived from `data_path` instead of repeating the absolute
# directory, so relocating the data only requires changing one variable.
sup_labels = list(pd.read_csv(os.path.join(data_path, "slt_flt.csv")).slots)
len(sup_labels)

106

#### Create NER model

Set the encoder and decoder parameters as proposed [here](https://arxiv.org/pdf/1609.01454.pdf).

In [9]:
from modules.models.bert_models import BertBiLSTMAttnNMTJoint

In [10]:
# Instantiate the joint model; its first two arguments are the sizes of the
# label vocabulary and of the class vocabulary taken from the data object.
model = BertBiLSTMAttnNMTJoint.create(
    len(data.label2idx),
    len(data.cls2idx),
    bert_config_file,
    init_checkpoint_pt,
    enc_hidden_dim=256,
)

In [11]:
# Display the decoder sub-module to inspect its layers and their sizes.
model.decoder

NMTJointDecoder(
  (embedding): Embedding(154, 64)
  (lstm): LSTM(576, 256, batch_first=True)
  (attn): Linear(in_features=256, out_features=256, bias=True)
  (slot_out): Linear(in_features=512, out_features=154, bias=True)
  (loss): CrossEntropyLoss()
  (intent_loss): CrossEntropyLoss()
  (intent_out): Linears(
    (linears): ModuleList(
      (0): Linear(in_features=512, out_features=128, bias=True)
    )
    (output_linear): Linear(in_features=128, out_features=17, bias=True)
  )
)

#### Create learner

In [12]:
from modules import NerLearner

In [13]:
# Training budget and learner setup; the best checkpoint is written to
# `best_model_path`.
num_epochs = 100
learner = NerLearner(
    model,
    data,
    best_model_path="/datadrive/models/atis/joint_nmt.cpt",
    lr=0.01,
    clip=1.0,
    sup_labels=sup_labels,
    # Total step count: number of training batches times number of epochs.
    t_total=num_epochs * len(data.train_dl),
)

INFO:root:Don't use lr scheduler...


In [None]:
# Train for `num_epochs` epochs, using "prec" as the target metric.
learner.fit(num_epochs, target_metric="prec")

### Get best results

In [16]:
# Restore saved weights before evaluating (presumably from the learner's
# `best_model_path` — TODO confirm against NerLearner.load_model).
learner.load_model()

#### Get span results for the validation dataset (where train support > 3)

In [17]:
import pandas as pd
# Read the `sup_slots` column of sup_slots.csv. The path is derived from
# `data_path` rather than repeating the absolute directory.
# NOTE(review): `sup_slots` is not referenced by any later cell shown in this
# notebook — confirm whether it is still needed.
sup_slots = list(pd.read_csv(os.path.join(data_path, "sup_slots.csv")).sup_slots)

In [18]:
from modules.data.bert_data import get_bert_data_loader_for_predict

In [19]:
# Build a prediction loader over the validation CSV. Reuses `valid_path`
# instead of concatenating `data_path + "valid_filtered.csv"`, which only
# works while `data_path` ends with a slash.
dl = get_bert_data_loader_for_predict(valid_path, learner)

In [20]:
# Run inference over the prediction loader. The call returns two values,
# unpacked here as `preds` and `preds_cls` (presumably per-token slot
# predictions and per-utterance class predictions — confirm).
preds, preds_cls = learner.predict(dl)

HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

In [21]:
from modules.train.train import validate_step


# Evaluate the loaded model on the validation loader; two textual reports
# come back, one for labels and one for classes.
clf_report, clf_report_cls = validate_step(
    learner.data.valid_dl,
    learner.model,
    learner.data.id2label,
    learner.sup_labels,
    learner.data.id2cls,
)

HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

Mean IOB precision

In [23]:
import numpy as np


def mean_precision(report):
    """Average the second column of a textual report over rows with support > 0.

    The report is split on newlines; the first two lines and the last five
    lines are skipped, matching the slice the notebook has always used. For
    each remaining row the second whitespace-separated field is read as a
    float and the last field as an integer support count. Blank rows are
    ignored.

    Args:
        report: the report as a single newline-separated string.

    Returns:
        The mean of the second-column values over rows whose last field is
        greater than zero (NaN when no row qualifies, as with np.mean([])).
    """
    rows = (line.split() for line in report.split("\n")[2:-5])
    return np.mean([float(cols[1]) for cols in rows if cols and int(cols[-1]) > 0])


mean_precision(clf_report)

0.9129245283018869

Span mean precision

In [24]:
from modules.utils.plot_metrics import get_bert_span_report


# Keep the span-level report under its own name. The original assigned it to
# `clf_report`, overwriting the IOB report from validate_step, so re-running
# the "Mean IOB precision" cell afterwards silently showed span precision.
span_clf_report = get_bert_span_report(dl, preds)
# Skip the first two and last five report lines; average column 1 over rows
# whose last column (support) is positive.
span_rows = [line.split() for line in span_clf_report.split("\n")[2:-5]]
np.mean([float(cols[1]) for cols in span_rows if int(cols[-1]) > 0])

0.8206811594202899

Classification mean precision

In [25]:
# Skip the first two and last five lines of the class report; average
# column 1 over rows whose last column (support) is positive.
cls_rows = [line.split() for line in clf_report_cls.split("\n")[2:-5]]
np.mean([float(cols[1]) for cols in cls_rows if int(cols[-1]) > 0])

0.8878125