# Evaluation of DISC on the Idioment dataset
This notebook evaluates the DISC model on the Idioment dataset and computes sentence-level recall. A visualisation tool is included at the end.
Preparing the data and running inference shouldn't take long (less than a minute); however, the cell that loads the data handler and the detector model can take some time.

When running the notebook for the first time, make sure a model checkpoint is present in `DISC/checkpoints`. A checkpoint can be downloaded via this [link](https://drive.google.com/file/d/1pGX1F03FYWymXcZ0hjJ7kYi3fySW5n54/view?usp=sharing). See also the README file.
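
Before loading the model, it can help to confirm that a checkpoint is actually in place. The following is a minimal sketch, not part of the DISC code base: the helper name `find_checkpoints` is hypothetical, the `*.mdl` pattern is taken from the checkpoint file name printed further down, and the relative path assumes the notebook is started from the directory that contains `DISC/`.

```python
from pathlib import Path


def find_checkpoints(checkpoint_dir="DISC/checkpoints", pattern="*.mdl"):
    """Return the checkpoint files found in checkpoint_dir, sorted by name.

    Hypothetical helper: the directory and the file pattern are assumptions
    based on the path and file name shown in this notebook.
    """
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return []
    return sorted(directory.glob(pattern))


checkpoints = find_checkpoints()
if not checkpoints:
    print("No checkpoint found in DISC/checkpoints; download one before loading the model.")
else:
    print("Found checkpoint(s):", [p.name for p in checkpoints])
```

If the list is empty, download the checkpoint from the link above and place it in the checkpoint directory before running the model-loading cell.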

## Imports

In [1]:
from IPython.display import display, HTML, clear_output
import torch
import numpy as np
from tqdm import tqdm
from src.utils.model_util import load_model_from_checkpoint
from src.model.read_comp_triflow import ReadingComprehensionDetector as DetectorMdl
from config import Config as config
from demo_helper.data_processor import DataHandler
from demo_helper.visualize_helper import simple_scoring_viz
import nltk
import pandas as pd
from os.path import join
import random
nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")

[nltk_data] Downloading package punkt to /home/basg/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/basg/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [2]:
# Use config.py to change data and model settings
data_handler = DataHandler(config)
detector_model = load_model_from_checkpoint(DetectorMdl, data_handler.config)


Loading Pre-trained Glove Embeddings...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Loading model from /home/basg/Projects/idioms/idiom_recognition/DISC/checkpoints/ReadComp_magpie_random_cross_attn-glove-char-pos-tri_latest.mdl
=> loading checkpoint '/home/basg/Projects/idioms/idiom_recognition/DISC/checkpoints/ReadComp_magpie_random_cross_attn-glove-char-pos-tri_latest.mdl'
=> loaded checkpoint '/home/basg/Projects/idioms/idiom_recognition/DISC/checkpoints/ReadComp_magpie_random_cross_attn-glove-char-pos-tri_latest.mdl'


## Data loading and preparation

In [3]:
IDIOMENT_PATH = "../../idioment/"

idioms = pd.read_csv(join(IDIOMENT_PATH, "idioms.csv"))
sentences = pd.read_csv(join(IDIOMENT_PATH, "sentences.csv"))
print(sentences.columns)
print(sentences)

Index(['id', 'sentence', 'sentiment'], dtype='object')
        id                                           sentence sentiment
0        1  How much of the forecast was genuine and how m...     other
1        2  I did touch them one time you see but of cours...     other
2        3  We find that choice theorists admit that they ...     other
3        4              Well, here I am with an olive branch.  positive
4        5  Its rudder and fin were both knocked out, and ...  negative
...    ...                                                ...       ...
2516  5971  We were both rigid as enemies, longing to come...  negative
2517  5974  We wouldn't wish to do that, but nor would we ...     other
2518  5977       But in the next move tit for tat retaliates.  negative
2519  5978  By this point some of my readers will be up in...  negative
2520  5980  Members of that nobility had been prominent in...     other

[2521 rows x 3 columns]


In [4]:
sents = sentences["sentence"].to_list()
sents[:5]

['How much of the forecast was genuine and how much was fixed, it is a moot point.',
 'I did touch them one time you see but of course there was nothing doing, he wanted me.',
 'We find that choice theorists admit that they introduce a style of moral paternalism at odds with liberal values.',
 'Well, here I am with an olive branch.',
 'Its rudder and fin were both knocked out, and a four-foot-long gash in the shell meant even repairs on the bank were out of the question.']

In [5]:
# Need to make batches, otherwise it crashes when doing inference
batch_size = 16
batches = [sents[b:b + batch_size] for b in range(0, len(sents), batch_size)]
assert len(batches) == np.ceil(len(sents) / batch_size)
print(len(batches))

data = []
for b in batches:
    d = data_handler.prepare_input(b)
    data.append(d)

print(len(data))

158
158
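
The batching step above can be sketched as a small standalone helper. This is an illustrative version only: the names `make_batches` and `expected_batch_count` are hypothetical, and the dummy sentences stand in for the 2521 Idioment sentences. It uses integer ceiling division instead of `np.ceil`, which avoids comparing an integer with a float.

```python
def make_batches(items, batch_size):
    """Split a list into consecutive chunks of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [items[start:start + batch_size]
            for start in range(0, len(items), batch_size)]


def expected_batch_count(n_items, batch_size):
    """Integer ceiling division: the number of batches needed for n_items."""
    return -(-n_items // batch_size)


# Dummy data with the same size as the sentence list used in this notebook
dummy_sents = [f"sentence {i}" for i in range(2521)]
dummy_batches = make_batches(dummy_sents, 16)

assert len(dummy_batches) == expected_batch_count(len(dummy_sents), 16)
print(len(dummy_batches), len(dummy_batches[-1]))  # prints: 158 9
```

The last batch holds only 9 sentences, so any code that assumes full batches of 16 should be checked against that case.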


In [6]:
type(data[0])

dict

## Inference

In [7]:
# If this gives an error when using a GPU (on WSL2), take a look at this: https://discuss.pytorch.org/t/libcudnn-cnn-infer-so-8-library-can-not-found/164661
SKIPPED_BATCHES = {66, 96, 113}  # these batches always give an unknown CUDA error
outputs = []
for i, batch in enumerate(tqdm(data)):
    if i in SKIPPED_BATCHES:
        print("skipping this batch (it always gives an unknown CUDA error)")
        # Keep a placeholder so that outputs stays aligned with data
        outputs.append([[], []])
        continue
    with torch.no_grad():
        ys_ = detector_model(batch)
        probs = torch.nn.functional.softmax(ys_, dim=-1)
    ys_ = ys_.cpu().detach().numpy()
    probs = probs.cpu().detach().numpy()
    # idiom_class_probs = probs[:, :, -1].tolist()
    # Predicted label per token: the class with the highest score
    predicts = np.argmax(ys_, axis=2)
    outputs.append([probs, predicts])

print(len(data))
print(len(outputs))

  0%|          | 0/158 [00:00<?, ?it/s]

 43%|████▎     | 68/158 [00:12<00:10,  8.38it/s]

skipping this batch (it always gives an unknown CUDA error)


 61%|██████    | 96/158 [00:16<00:07,  8.28it/s]

skipping this batch (it always gives an unknown CUDA error)


 73%|███████▎  | 115/158 [00:18<00:04,  9.98it/s]

skipping this batch (it always gives an unknown CUDA error)


100%|██████████| 158/158 [00:23<00:00,  6.60it/s]

158
158
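
The inference loop hard-codes the indices of the failing batches. As an alternative, the failing batches could be recorded while the loop runs. The sketch below is an assumption-laden illustration rather than the notebook's method: `run_all_batches` and `fake_run_batch` are hypothetical names, and `fake_run_batch` only simulates a failure. Note that some CUDA errors leave the device in an unusable state, in which case catching the exception does not help and the hard-coded skip remains necessary.

```python
def run_all_batches(batches, run_batch, placeholder=None):
    """Run run_batch on every batch and record the indices that fail.

    run_batch is a stand-in for the real model call. A failed batch gets
    `placeholder` in the output list so that outputs stays aligned with batches.
    """
    outputs, failed = [], []
    for index, batch in enumerate(batches):
        try:
            outputs.append(run_batch(batch))
        except RuntimeError as err:
            print(f"batch {index} failed: {err}")
            failed.append(index)
            outputs.append(placeholder)
    return outputs, failed


def fake_run_batch(batch):
    """Simulated model call: fails on any batch that contains the string 'bad'."""
    if "bad" in batch:
        raise RuntimeError("simulated failure")
    return [len(sentence) for sentence in batch]


demo_outputs, demo_failed = run_all_batches(
    [["a", "bb"], ["bad"], ["ccc"]], fake_run_batch
)
print(demo_failed)  # prints: [1]
```

Keeping the list of failed indices makes it easy to report how many sentences were left out of the evaluation.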





## Evaluation

In [8]:
n_idioms_found = 0
n_sents = 0
# Iterate through the output batches
for probs, predicts in outputs:
    # Iterate through the sentences in this batch
    for s in predicts:
        n_sents += 1
        # 4 means idiomatic usage. So if there is at least one 4 in a sentence, an idiom was found
        if 4 in s:
            n_idioms_found += 1
            # TODO: find idiom and check whether it is correct (use the idioms.csv or something)
            

In [9]:
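# Note: this ratio is recall only under the assumption that every sentence contains an idiom.
# Sentences in the 3 skipped batches (3 * 16 = 48 sentences) are not counted in n_sents.
print("idioms found:", n_idioms_found)
print("total number of sentences:", n_sents)
print("recall:", n_idioms_found / n_sents)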

idioms found: 1962
total number of sentences: 2473
recall: 0.7933683784876668
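
The evaluation above only checks whether any token in a sentence is labelled as idiomatic; the TODO about finding the idiom itself is still open. A possible first step is to collect the contiguous runs of tokens that carry the idiom label. The sketch below is illustrative only: `extract_idiom_spans` is a hypothetical helper, the label `4` is the one this notebook treats as idiomatic usage, and the example tokens and the `0` labels are made up for the demonstration.

```python
IDIOM_LABEL = 4  # label id that this notebook treats as idiomatic usage


def extract_idiom_spans(tokens, labels, idiom_label=IDIOM_LABEL):
    """Return the contiguous runs of tokens whose label equals idiom_label."""
    spans, current = [], []
    for token, label in zip(tokens, labels):
        if label == idiom_label:
            current.append(token)
        elif current:
            # The run of idiom tokens has ended: store it and start over
            spans.append(current)
            current = []
    if current:
        spans.append(current)
    return spans


# Made-up example; the 0 labels are placeholders for "not idiomatic"
example_tokens = ["[CLS]", "here", "i", "am", "with", "an", "olive", "branch", ".", "[SEP]"]
example_labels = [0, 0, 0, 0, 0, 0, 4, 4, 0, 0]

print(extract_idiom_spans(example_tokens, example_labels))  # prints: [['olive', 'branch']]
```

The extracted spans could then be compared with the entries in `idioms.csv`. That comparison is left out here because it depends on the columns of that file and on how word-piece tokens are merged back into words.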


## Visualisation

In [10]:
code = ""  # if number, go to that batch, "n" = next batch, "r" = random batch, "q" = quit, "" = nothing
b_nr = 0
prev_nr = None

while code != "q":
    # Display the batch (only when the batch number has changed)
    if b_nr != prev_nr:
        clear_output(wait=True)

        print("output batch number:", b_nr)
        probs, predicts = outputs[b_nr]
        if len(probs) == 0:
            # Batches skipped during inference are stored as [[], []] and cannot be displayed
            print("this batch was skipped during inference")
        else:
            # Probability of the last class for every token
            idiom_class_probs = probs[:, :, -1].tolist()
            sentences_tkns = data[b_nr]['xs_bert'].cpu().detach().numpy().tolist()
            sentences_tkns = [data_handler.tokenizer.convert_ids_to_tokens(s) for s in sentences_tkns]
            for i in range(len(sentences_tkns)):
                s = simple_scoring_viz(sentences_tkns[i], idiom_class_probs[i], 'YlGn')
                display(HTML(s))

    # Get new batch
    code = input('Please enter command:\n"n" = next batch, "r" = random batch, "q" = quit, input a number to go to that batch\n').lower()
    prev_nr = b_nr
    if code == "n":
        b_nr = (b_nr + 1) % len(outputs)
    elif code == "r":
        b_nr = random.randrange(0, len(outputs))
    elif code.isdecimal():
        # isdecimal() only accepts plain digits, so int() cannot fail here
        n = int(code)
        if n < len(outputs):
            b_nr = n
        else:
            print("please enter a valid batch number")
    elif code == "q":
        print("quitting...")
    else:
        print("not a valid command")


output batch number: 80


quitting...
