In [27]:
# Adapted from a personal database of code, frankly I don't remember the reference
from sklearn.metrics import confusion_matrix
from data_loading import *

import numpy as np
import matplotlib.pyplot as plt
import itertools
from rich.progress import track
from vocabulary import Vocabulary
from ner_pretrain_noisycorpus import NERv1_PRE
from gensim.models import KeyedVectors
import torch
from torch.profiler import profile, record_function, ProfilerActivity

from typing import List, Any, Dict
from configuration import TRAINING_PATH, DEV_PATH, TEST_PATH, MODEL_PATH, UNK_WORD, TAG_DICT, PRETRAINED_PATH
from configuration import EMBEDDING_SIZE, HIDDEN_SIZE, BATCH_SIZE, EPOCHS, NUM_LAYERS, BIDIRECTIONAL, PRETRAINED

from IPython.display import HTML, display


In [28]:
def _cstr(s, color='black'):
    if s == ' ':
        return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
    else:
        return f'<text style=color:#000;background-color:{color}>{s} </text>'

# print html
def _print_color(t):
    display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))

# get appropriate color for value
def _get_clr(value):
    colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
            '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
            '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
            '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
    value = int((value * 100) / 5)
    if value == len(colors): value -= 1  # fixing bugs...
    return colors[value]

def print_colourbar():
    color_range = torch.linspace(-2.5, 2.5, 20)
    to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
    _print_color(to_print)


In [29]:
print_colourbar()

In [30]:
def visualize_values(output_values, result_list):
    text_colours = []
    for i in range(len(output_values)):
        text = (result_list[i], _get_clr(output_values[i]))
        text_colours.append(text)
    _print_color(text_colours)

def plot_state(data, state, b, decoder):
    actual_data = decoder(data[b, :, :].numpy())
    seq_len = len(actual_data)
    seq_len_w_pad = len(state)
    for s in range(state.size(2)):
        states = torch.sigmoid(state[:, b, s])
        visualize_values(states[seq_len_w_pad - seq_len:], list(actual_data))

visualize_values([-1.00, -0.66, 0.13, 0.66, 1.00], ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC'])

In [36]:

import sys
import os
import torch.nn

module_path = os.path.abspath(os.path.join('..'))
module_path1 = os.path.abspath('/'.join(module_path.split('/')[:-1]))
if module_path1 not in sys.path:
    sys.path.append(module_path1+"/model")

print(module_path)
print(module_path1)

class StudentModel():

    # STUDENT: construct here your model
    # this class should be loading your weights and vocabulary

    def __init__(self, device):
        self.device = device
        self.dataset = DatasetNER(module_path1+'/'+TRAINING_PATH, verbose=False)
        self.weights = self.init_weights(PRETRAINED)
        self.vocab = self.init_vocab(self.dataset, self.weights)
        self.tags = TAG_DICT
        self.emb_size = EMBEDDING_SIZE
        self.hidden_size = HIDDEN_SIZE
        self.num_layers = NUM_LAYERS
        self.bidirectional = BIDIRECTIONAL
        self.num_classes = len(self.tags)
        self.model = NERv1_PRE(len(self.vocab), self.emb_size, self.hidden_size, self.num_layers, self.num_classes, self.bidirectional, device, self.weights)
        print("Model loaded")

    def init_weights(self, pretrained = False):
        if pretrained:
            weights = KeyedVectors.load(module_path1+'/'+PRETRAINED_PATH)
            print(f'Loaded pretrained embeddings at {PRETRAINED_PATH}')
            return weights
        else:
            return None

    def init_vocab(self, dataset, weights=None):
        if weights is not None:
            tokens = [weights.index_to_key[i] for i in range(weights.vectors.shape[0])]
            return Vocabulary(tokens, min_freq=1, unk_token=UNK_WORD)
        else:
            return Vocabulary(dataset.tokens, min_freq=1, unk_token=UNK_WORD)

    def expert_postprocess(self, pred_lbl):
        for sentence in pred_lbl:
            for i, token in enumerate(sentence):
                if token.startswith("I-"):
                    if i == 0:
                        sentence[i] = "B-" + token[2:]
                        print("I- tag at the beginning of sentence")
                    elif sentence[i-1] == "O":
                        sentence[i] = "B-" + token[2:]
                        print("I- tag without B- tag")
                    elif (sentence[i-1].startswith("I-") or sentence[i-1].startswith("B-")) and sentence[i-1][2:] != sentence[i][2:]:
                        sentence[i] = "I-" + sentence[i-1][2:]
                        print("I- tag with different tag")

    def predict(self, tokens: List[List[str]]) -> List[List[str]]:
        # STUDENT: implement here your predict function
        # remember to respect the same order of tokens!
        self.model.load_state_dict(torch.load(module_path1+'/'+MODEL_PATH, map_location=torch.device(self.device)))
        self.model.eval()
        sentences = []
        offsets = []
        for list in tokens:
            index_phrase = torch.tensor(Vocabulary.word_to_index(list, self.vocab), dtype=torch.int64)
            sentences.append(index_phrase)
            offsets.append(index_phrase.size(0))
        sentences = torch.cat(sentences)
        predicted_indexes = self.model(sentences)
        predicted_labels = [Vocabulary.index_to_tags(phrase.tolist(), self.tags) for phrase in torch.split(predicted_indexes.argmax(1), offsets)]
        for elem in predicted_indexes:
            visualize_values(torch.nn.functional.softmax(elem), 
            ["B-CORP", "B-CW", "B-GRP", "B-LOC", "B-PER", "B-PROD", "I-CORP", "I-CW", "I-GRP", "I-LOC", "I-PER", "I-PROD", "O"])
        self.expert_postprocess(predicted_labels)
        return predicted_labels


'''     def visualize_values(output_values, result_list):
            text_colours = []
            for i in range(len(output_values)):
                text = (result_list[i], _get_clr(output_values[i]))
                text_colours.append(text)
            _print_color(text_colours)

        def plot_state(data, state, b, decoder):
            actual_data = decoder(data[b, :, :].numpy())
            seq_len = len(actual_data)
            seq_len_w_pad = len(state)
            for s in range(state.size(2)):
                states = torch.sigmoid(state[:, b, s])
                visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))

        def plot_state(list_of_words, cell_state, list_of_13scores):
            visualize_values(list_of_words, str_with_color) '''

model = StudentModel("cpu")
print(model.predict(["God dog".split() ]))

/home/ant/Documents/nlp2022-hw1/hw1
/home/ant/Documents/nlp2022-hw1
Loaded pretrained embeddings at model/glove_pretrained_300
Model loaded


  visualize_values(torch.nn.functional.softmax(elem),


I- tag with different tag
[['B-PROD', 'I-PROD']]


In [29]:

def flat_list(l: List[List[Any]]) -> List[Any]:
    return [_e for e in l for _e in e]

def count(l: List[Any]) -> Dict[Any, int]:
    d = {}
    for e in l:
        d[e] = 1 + d.get(e, 0)
    return d

def visualize():
    valid_dataset = DatasetNER(module_path1+'/'+DEV_PATH, verbose=True) 

    labels_s = valid_dataset.labels
    tokens_s = valid_dataset.tokens

    predictions_s = []
    batch_size = 32

    model = StudentModel("cpu")
    for i in track(range(0, len(tokens_s), batch_size), description="Visualizing"):
        batch = tokens_s[i : i + batch_size]
        predictions_s += model.predict(batch)

    flat_labels_s = flat_list(labels_s)
    flat_predictions_s = flat_list(predictions_s)

    label_distribution = count(flat_labels_s)
    pred_distribution = count(flat_predictions_s)

    print(f"# instances: {len(flat_list(labels_s))}")

    keys = set(label_distribution.keys()) | set(pred_distribution.keys())
    for k in keys:
        print(
            f"\t# {k}: ({label_distribution.get(k, 0)}, {pred_distribution.get(k, 0)})"
        )

visualize()

# instances: 12751
	# I-PER: (329, 329)
	# I-LOC: (153, 129)
	# I-CW: (261, 232)
	# O: (10240, 10428)
	# B-CW: (170, 173)
	# B-LOC: (243, 229)
	# B-CORP: (133, 105)
	# I-PROD: (87, 81)
	# I-GRP: (377, 342)
	# B-GRP: (190, 180)
	# I-CORP: (119, 80)
	# B-PER: (300, 305)
	# B-PROD: (149, 138)
