In [None]:
# Some checks

!ipython kernelspec list
import sys
print(sys.version)

In [None]:
# Imports 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerEncoder, TransformerEncoderLayer


from collections import defaultdict, Counter

import math
from typing import Tuple
import numpy as np
import pandas as pd
import pickle
import bcolz

import time
import random
import functools


from nltk.corpus import wordnet as wn
from bert_embedding import BertEmbedding
import nltk
from nltk.corpus import stopwords
from string import punctuation
from nltk import pos_tag, WordNetLemmatizer
from pprint import pprint
import string


import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina'
plt.style.use('seaborn')

In [None]:
# Seed every random number generator used in this notebook so that each run
# starts from the same state.

SEED = 1234

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Restrict cuDNN to deterministic kernels.
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Utils Functions (Loading Data, Vocab, Numericalization)

In [None]:
def to_tensor(string_list):
    """Convert the whitespace-split pieces of a bracketed vector into a tensor.

    The pieces typically look like ``["[0.1", "0.2", "0.3]"]``. Square brackets
    are removed wherever they occur at the edges of a piece, so a piece that
    carries both brackets (``"[0.1]"``) or consists of a bracket only
    (``"["``, ``"]"``) is handled as well.

    :param string_list: iterable of strings holding the numbers of one vector
    :return: 1-D ``torch.float32`` tensor with one entry per number
    :raises ValueError: if a piece is not a valid float after bracket removal
    """
    values = []
    for piece in string_list:
        # ``split()`` drops the empty string left behind by a lone bracket.
        for token in piece.strip("[]").split():
            values.append(float(token))
    return torch.tensor(values, dtype=torch.float32)

In [None]:
def parse_data(file):
    """
    Reads a whitespace-separated corpus file with one token per line and a
    blank line between sentences.

    For each token line the second column is taken as the lemma, the
    sixth-from-last column as the synset identifier and the last five columns
    as the label vector (converted with ``to_tensor``).

    A sentence is emitted whenever a blank line follows at least one token
    line, and once more at the end of the file, so a file that does not end
    with a blank line no longer loses its last sentence. Runs of blank lines
    do not produce empty examples.

    :param file: path to the corpus file (read as UTF-8)
    :return: list of examples of the form
             [[lemmas, synset_ids, label_tensors], [lemmas, synset_ids, label_tensors], ...]
    """
    examples = []
    lemmas, synset_offset, labels = [], [], []

    with open(file, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                columns = line.split()
                lemmas.append(columns[1])
                synset_offset.append(columns[-6])
                labels.append(to_tensor(columns[-5:]))
            elif lemmas:
                # Blank line closes the sentence collected so far.
                examples.append([lemmas, synset_offset, labels])
                lemmas, synset_offset, labels = [], [], []

    # Flush the final sentence when the file has no trailing blank line.
    if lemmas:
        examples.append([lemmas, synset_offset, labels])

    return examples

In [None]:
def clean_untagged(data):
    """Drop every token whose synset entry is 'no-synset'.

    The input is left untouched; the cleaned data is built as new lists. The
    previous implementation aliased the input and deleted in place, so the
    "original" it returned was the already-cleaned object.

    :param data: list of entries; entry[0], entry[1] and entry[2] are parallel
                 per-token lists and entry[1] holds the synset identifiers
    :return: tuple (original_data, cleaned_data)
    """
    cleaned = []
    for entry in data:
        keep = [syn != 'no-synset' for syn in entry[1]]
        # Filter the three parallel per-token lists with the same mask.
        filtered = [[item for item, kept in zip(field, keep) if kept]
                    for field in entry[:3]]
        # Any further fields are carried over unchanged.
        cleaned.append(filtered + list(entry[3:]))

    return data, cleaned

In [None]:
def data_id(data):
    """Key every instance of ``data`` by its position, written as a string.

    Result layout: {"0": data[0], "1": data[1], ...}
    """
    return {str(position): instance for position, instance in enumerate(data)}

In [None]:
def load_vocab(data, embed_size=300, embeddings=None):
    """Build the vocabulary of ``data`` and a matching embedding weight matrix.

    :param data: either a sequence of token strings, or a sequence of
                 instances whose first element is a list of tokens
    :param embed_size: length of each embedding vector
    :param embeddings: mapping token -> vector; when omitted the notebook-level
                       ``glove`` table is used
    :return: tuple (target_vocab, weights_matrix). ``target_vocab`` is a
             sorted list of the unique tokens, and row ``i`` of
             ``weights_matrix`` (shape ``(len(target_vocab), embed_size)``)
             belongs to ``target_vocab[i]``.

    The vocabulary is sorted because the iteration order of a set of strings
    changes between Python processes; indices derived from it would otherwise
    not match a model saved in an earlier session.
    """
    if embeddings is None:
        embeddings = glove

    # Collect all tokens of the dataset.
    if len(data) == 0:
        tokens = []
    elif isinstance(data[0], str):
        tokens = data
    else:
        tokens = [token for instance in data for token in instance[0]]

    # Remove duplicates and fix the order.
    target_vocab = sorted(set(tokens))

    weights_matrix = np.zeros((len(target_vocab), embed_size))
    for i, word in enumerate(target_vocab):
        try:
            weights_matrix[i] = embeddings[word]
        except KeyError:
            # Token without a pretrained vector: random initialisation.
            weights_matrix[i] = np.random.normal(scale=0.6, size=(embed_size, ))

    return target_vocab, weights_matrix

In [None]:
def numericalize(tokens_list, vocab):
    """Replace every token by its position in ``vocab``.

    :param tokens_list: tokens to convert; each one must occur in ``vocab``
    :param vocab: iterable of tokens; the enumeration order defines the indices
    :return: 1-D ``torch.long`` tensor of indices
    :raises KeyError: if a token is not part of ``vocab``
    """
    token_to_index = dict((word, index) for index, word in enumerate(vocab))
    indices = [token_to_index[token] for token in tokens_list]
    return torch.tensor(indices, dtype=torch.long)


In [None]:
# Preprocessing of raw input sentences

# Tokens filtered out by preprocess(): the union of NLTK's English stop words
# and every character in string.punctuation.
EN_STOPWORDS_PUNCT = set(stopwords.words('english')).union(set(string.punctuation))
# Lemmatizer instance shared by all preprocess() calls.
WN_LEMMATIZER = WordNetLemmatizer()


def tags4wn(tag):
    """Convert a Penn Treebank POS tag to a WordNet POS letter (n, a, v, r).

    Penn Treebank tags: https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html

    Only the first two characters of the tag are inspected, so e.g. "NNS" and
    "NNP" both map to "n". Tags that are not covered, as well as values that
    cannot be sliced or looked up (such as None), are treated as nouns.

    :param tag: Penn Treebank tag string
    :return: one of "n", "a", "v", "r"
    """
    tag_conversion = {"NN": "n",  # noun
                      "JJ": "a",  # adjective
                      "VB": "v",  # verb
                      "RB": "r"}  # adverb
    try:
        return tag_conversion[tag[:2]]
    except (KeyError, TypeError):
        # Unknown or unusable tag: fall back to noun. Only lookup and slicing
        # failures are caught, so unrelated errors are no longer swallowed.
        return "n"

def preprocess(sentence):
    """Turn a raw sentence into a list of (lemma, WordNet POS letter) pairs.

    Steps: tokenize with NLTK, lowercase, drop stop words and punctuation,
    POS-tag the remaining tokens, then lemmatize each token using the WordNet
    POS letter derived from its tag.
    """
    lowered = (token.lower() for token in nltk.word_tokenize(sentence))
    content_tokens = [token for token in lowered if token not in EN_STOPWORDS_PUNCT]

    lemmatized_sentence = []
    for token, penn_tag in pos_tag(content_tokens):
        wn_tag = tags4wn(penn_tag)
        lemma = WN_LEMMATIZER.lemmatize(token, pos=wn_tag)
        lemmatized_sentence.append((lemma, wn_tag))

    return lemmatized_sentence


# Data Loader

### ----- Run the following cells only once (they parse the corpus and build the ID-indexed dataset)

In [None]:
# Directory containing the corpus files; the file names below are appended to it.
path = "../data/test_transformer/"
# File names of the training, validation and test portions of the dataset.
train_path = "train.csv"
validate_path = "validate.csv"
test_path = "test.csv"

In [None]:
# Parse the training file into a list of [lemmas, synset_ids, label_tensors] examples.
data = parse_data(path + train_path)

In [None]:
# Remove every token whose synset entry is 'no-synset' from the parsed data.
orig, data = clean_untagged(data)

In [None]:
# Key every dataset entry by its position (as a string): {"0": entry0, "1": entry1, ...}
datasetID = data_id(data)

In [None]:
# store the dataset arranged by ID in a .pt file
# Saving and loading data to/from .pt
# save
#torch.save(datasetID, path + "pwngc_id.pt")


### ----- Restart from here once `pwngc_id.pt` has been saved:

In [None]:
# Load the ID-indexed dataset from disk. The torch.save call that writes this
# file is commented out in the cells around it, so the file must already exist.
datasetID = torch.load(path + "pwngc_id.pt")

In [None]:
# Maps a partition name ("train", "validate") to the list of dataset IDs in it;
# filled by the next cell.
splittings = {}


In [None]:
# split training and validation data by ID

N_train = 10
N_valid = 5
N_test = 5  # NOTE(review): defined but no test partition is sampled in this cell

# Sample from ordered lists only. The previous version sampled the validation
# IDs from list(set(...) - set(...)); the order of a set of strings changes
# between Python processes (hash randomization), so the split was not
# reproducible even with a fixed random seed.
_all_ids = list(datasetID)

# choose N training instances, randomly!
splittings["train"] = random.sample(_all_ids, N_train)

# validation instances are drawn from the IDs not used for training
_train_ids = set(splittings["train"])
_remaining_ids = [sample_id for sample_id in _all_ids if sample_id not in _train_ids]
splittings["validate"] = random.sample(_remaining_ids, N_valid)
# splittings

In [None]:
# Saving and loading data to/from .pt
# save
#torch.save(datasetID, path + "pwngc_id.pt")

In [2]:
# Dataset
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that serves the examples selected by a list of IDs.

    The file at ``path2data`` is loaded with ``torch.load`` and must contain a
    mapping from ID to an entry whose elements 0, 1 and 2 are returned by
    ``__getitem__`` as ``X``, ``tag_y`` and ``y`` respectively.
    """

    def __init__(self, list_ids, path2data):
        """
        :param list_ids: IDs (keys of the stored mapping) served by this dataset
        :param path2data: path of the file to load with ``torch.load``
        """
        self.list_ids = list_ids
        self.path2data = path2data
        # Must live on the instance: __getitem__ reads self.dataset. It was
        # previously bound to a local variable and lost after __init__.
        self.dataset = torch.load(self.path2data)

    def __len__(self):
        "Total Number of samples."

        return len(self.list_ids)

    def __getitem__(self, index):
        "Extracts one Example of data."

        example_id = self.list_ids[index]
        entry = self.dataset[example_id]

        # data
        X = entry[0]
        tag_y = entry[1]
        y = entry[2]

        return X, y, tag_y

NameError: name 'torch' is not defined

# Loading GloVe (from the local cache) and the WordNet vocabulary / spatial tags

In [None]:
# Directory that holds the pickled GloVe table loaded in the next cell.
glove_path = "./.vector_cache"

In [None]:
# Load the pickled GloVe lookup table (used by load_vocab as token -> vector).
# The file is opened in a `with` block so the handle is closed after loading.
# NOTE(security): pickle.load executes code embedded in the file; only load a
# cache file that was produced locally or comes from a trusted source.
with open(f'{glove_path}/840B.300_glove.pkl', 'rb') as glove_file:
    glove = pickle.load(glove_file)


In [None]:
# Vocabulary array for sense inference.
# NOTE(review): presumably rows of [word, synset], matching how label_in_vicinity
# indexes `target_vocab` entries with [1] — confirm against the file.
target_VOCAB = np.load("WORDNET_VOCAB_exp_01.npy")

In [None]:
# Spatial tag array for sense inference.
# NOTE(review): presumably one 5-value tag per row, aligned row-by-row with
# target_VOCAB — confirm against the file.
SPATIAL_TAGS = np.load("WORDNET_SPATIAL_TAGS_exp_01.npy")

# The Model

In [None]:
def create_emb_layer(weights_matrix, non_trainable=False):
    """Create an ``nn.Embedding`` initialised from a NumPy weight matrix.

    :param weights_matrix: 2-D NumPy array; one row per token
    :param non_trainable: when true, the embedding weights are frozen
    :return: tuple (embedding layer, number of rows, number of columns)
    """
    vocab_size, vector_dim = weights_matrix.shape

    layer = nn.Embedding(vocab_size, vector_dim)
    # Copy the given vectors into the layer's weight parameter.
    layer.load_state_dict({'weight': torch.from_numpy(weights_matrix)})

    if non_trainable:
        layer.weight.requires_grad_(False)

    return layer, vocab_size, vector_dim

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence, then dropout.

    Even feature indices receive ``sin(position * div_term)`` and odd feature
    indices ``cos(position * div_term)``. The table is precomputed for
    ``max_len`` positions and stored as the non-trainable buffer ``pe`` of
    shape ``[max_len, 1, d_model]``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: size of the feature dimension (even or odd)
            dropout: dropout probability applied after adding the encoding
            max_len: number of positions for which the table is built
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns;
        # trimming div_term keeps the assignment shapes aligned.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim];
               seq_len must not exceed ``max_len``

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [None]:
class TransformerEncoderRegressor(nn.Module):
    """Transformer encoder followed by two linear layers that regress 5 values.

    Pipeline of ``forward``: token embedding (``self.emb``) scaled by
    sqrt(d_model) -> positional encoding -> transformer encoder ->
    linear layer ``fc`` (d_model -> embedding_dim) -> sum over dim 1 ->
    linear layer ``output`` (embedding_dim -> 5).

    NOTE(review): ``self.embedding`` is built from ``weights_matrix`` (frozen)
    but is never used in ``forward``; the randomly initialised ``self.emb`` is
    used instead. Confirm which of the two is intended.
    NOTE(review): ``out_features`` is stored but the width of the output layer
    is hard-coded to 5.
    """

    def __init__(self, weights_matrix:np.ndarray, 
                 ntoken: int, out_features:int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        """
        Args:
            weights_matrix: pretrained embedding matrix passed to create_emb_layer;
                            its second dimension becomes ``embedding_dim``
            ntoken: number of rows of the trainable embedding ``self.emb``
            out_features: stored on the instance only (see class note)
            d_model: feature size of ``self.emb`` and of the encoder
            nhead: number of attention heads per encoder layer
            d_hid: feed-forward size inside each encoder layer
            nlayers: number of stacked encoder layers
            dropout: dropout used by the positional encoding and the encoder layers
        """
        
        super().__init__()
        
        self.model_type = 'Transformer'
        
        self.d_model = d_model
        
        self.weights_matrix = weights_matrix
        
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        # Embedding layer initialised from weights_matrix and frozen (second argument True)
        self.embedding, num_embeddings, embedding_dim = create_emb_layer(self.weights_matrix, True)
        
        # Multi-head attention mechanism is included in TransformerEncoderLayer
        # d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, 
        # layer_norm_eps=1e-05, batch_first=False, norm_first=False, device=None, dtype=None
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout) # activation left at its default
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers, norm=None)
        
        
#         padding_idx (int, optional) – If specified, the entries at padding_idx do not contribute to the gradient;
#         therefore, the embedding vector at padding_idx is not updated during training,
#         i.e. it remains as a fixed “pad”. For a newly constructed Embedding, the embedding vector at
#         padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector.
        # Trainable token embedding actually used by forward(); no padding_idx is set.
        self.emb = nn.Embedding(ntoken, d_model) 
        self.out_features = out_features
        
        # Linear layer applied to every encoder output vector: d_model -> embedding_dim
        self.fc = nn.Linear(d_model, embedding_dim)
        
        # No! Here I am just redoing fully connected connections
        # Linear Layer: affine transformation of last hidden layer into shape (1, embedding_dim)
        #self.context_vec = nn.Linear(d_model, embedding_dim)
        
        #self.decoder = nn.Linear(d_model, ntoken)
        
        # Design notes (not implemented here):
        # Now, I need to have a Linear space that takes the whole/subset dataframe as input, extracts its spatial_context_vec,
        # based on Glove-word-vector + spatial_point,
        # then calculates softmax on this distribution
        # choose the argmax
        # get its spatial tags
        # calculate distance loss between them
        # do backprop! 
        # Nx300 into Nx227733: matmul product of two matrices Nx300 and 300x227733 --> Nx227733
        # apply softmax to get the probabilities
        # apply argmax to get the maximum indices
        # use the indices to get the synset names as well as the mapping to coordinates
        # into Nx5: mapping to the coordinates
        
        # Regression head: embedding_dim -> 5 values
        self.output = nn.Linear(embedding_dim, 5)
        #self.wn_embeddings = nn.Linear(1, target_matrix.shape[0])

        self.init_weights()
        
        # Example configuration noted by the author:
#         weights_matrix = weights_matrix, 
#                                     ntoken= # false: 300,
#                                     out_features=5,
#                                     d_model=300,
#                                     d_hid=200,
#                                     nlayers=2,
#                                     nhead=2,
#                                     dropout=0.2
        
        
        # -------------------------------------

    def init_weights(self) -> None:
        "Initialize the weights of ``self.emb`` uniformly in [-0.1, 0.1]; other layers keep their defaults."
        initrange = 0.1
        self.emb.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.zero_()
        # self.decoder.weight.data.uniform_(-initrange, initrange)
        
        #self.output.bias.data.zero_()
        #self.output.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        """
        Args:
            src: Tensor of token indices, shape [seq_len, batch_size]

        Returns:
            Output of the final linear layer; its last dimension has size 5.
            With the input layout above, the sum over dim=1 removes the batch
            axis, so the result has shape [seq_len, 5].
            NOTE(review): confirm that summing over dim=1 (the batch axis for
            this layout) rather than over the sequence axis is intended.
        """
        
        #src = self.encoder(src) * math.sqrt(self.d_model)
        src = torch.mul(self.emb(src), math.sqrt(self.d_model)) # author's open question: scale by 1/sqrt instead?
#         print("Embedding", src.shape)
#         print('-' * 80)
        
        
        src = self.pos_encoder(src)
#         print("Positional Encoding", src.shape)
#         print('-' * 80)
        
        
        encoder_output = self.transformer_encoder(src) # no attention mask is passed
#         print("Encoder", encoder_output.shape)
#         # print(encoder_output)
#         print('-' * 80)
        
        
        linear_layer = self.fc(encoder_output)
#         print("Linear Layer", linear_layer.shape)
#         # print(linear_layer)
#         print('-' * 80)

        # calculate the sum/weighted sum/ ?? on the linear layer to get the context vector of size (1, embd_dim)
        context_vec = torch.sum(linear_layer, dim=1)
#         print("Final Context Vector", context_vec.shape)
#         # print(context_vec)
#         print('-' * 80)
        
        # regression output
        coordinates = self.output(context_vec)
#         print("Coordinates from Context Vector", coordinates.size())
#         # print(coordinates)
#         print('-'*80)
        return coordinates


# Geometric Loss

In [None]:
def coo2point(coo):
    """Convert spatial parameters into 2-D Cartesian points.

    ``coo`` holds, in this order: ``l0`` and ``alpha`` (polar length and angle
    in degrees of the center point), ``l_i`` and ``beta_i`` (length and angle
    in degrees of the offset from the center to the sense point, with the
    angle measured relative to ``alpha``) and a radius, which is not used here.

    The computation uses torch operations only, so when ``coo`` is part of an
    autograd graph (e.g. a model prediction) gradients flow back to it. The
    previous version went through ``math.cos``/``math.sin`` and rebuilt fresh
    leaf tensors, which cut the result off from the prediction.

    :param coo: tensor, array or sequence with at least five numeric entries
    :return: tuple (sense_pt, center_pt) of ``torch.float64`` tensors of shape (2,)
    """
    # as_tensor keeps the autograd history of tensor inputs while converting
    # the dtype, and also accepts lists / NumPy arrays.
    coo = torch.as_tensor(coo, dtype=torch.float64)

    l0, alpha, l_i, beta_i = coo[0], coo[1], coo[2], coo[3]

    # torch.cos() and torch.sin() take angles in radian
    alpha_rad = alpha * math.pi / 180
    beta_i_rad = beta_i * math.pi / 180

    center_pt = torch.stack((l0 * torch.cos(alpha_rad),
                             l0 * torch.sin(alpha_rad)))
    offset = torch.stack((l_i * torch.cos(alpha_rad + beta_i_rad),
                          l_i * torch.sin(alpha_rad + beta_i_rad)))
    sense_pt = center_pt + offset
    return sense_pt, center_pt




def distance_loss(pred_pt, original_pt, include_r=False, pt_sphere=False):
    """
    Calculates a distance between the sense points of two spatial tags.

    Let ``d`` be the Euclidean distance between the two sense points (from
    ``coo2point``), ``r1`` the last entry of ``pred_pt`` and ``r2`` the last
    entry of ``original_pt``. The result is, in order of precedence:

    * ``d - r2`` if every entry of ``original_pt`` is zero (token without
      sense tag), regardless of the flags;
    * ``d + r2`` if ``pt_sphere`` is true;
    * ``max(r1 + d - 2 * r2, 0)`` if ``include_r`` is true;
    * ``d - r2`` otherwise.

    :param pred_pt: predicted spatial parameters (5 values, radius last)
    :param original_pt: reference spatial parameters (5 values, radius last)
    :param include_r: if set to true, include both radii in the distance.
                      It gives more freedom/tolerance degrees to the loss function.
    :param pt_sphere: if set to true, return the point distance plus ``r2``
    :return: 0-dim tensor
    """
    r1 = pred_pt[-1]
    r2 = original_pt[-1]

    pred_sense, _ = coo2point(pred_pt)
    orig_sense, _ = coo2point(original_pt)

    point_dist = torch.linalg.norm(torch.sub(pred_sense, orig_sense))
    loss = point_dist - r2

    # very strong assumption for the words that are not sense-tagged
    # If I want more tolerance, I could neglect those tokens from the beginning
    # (Comparing against the scalar 0 avoids allocating a CPU zero tensor,
    # which would fail for labels living on another device.)
    if torch.all(torch.as_tensor(original_pt) == 0):
        return loss

    if pt_sphere:
        return point_dist + r2

    if include_r:
        # NOTE(review): ``loss`` already contains ``- r2``, so r2 is subtracted
        # twice here, whereas is_contained() tests ``r1 + d - r2``. Kept as is;
        # confirm which form is intended.
        tolerant_loss = r1 + loss - r2

        # Clamp with a tensor operation so the result stays a tensor (the old
        # code returned the Python float 0.0 here, on which ``.backward()``
        # cannot be called).
        return torch.clamp(tolerant_loss, min=0.0)

    return loss
   


def geometric_loss(pred_list, label_list, include_r=False):
    """Sum ``distance_loss`` over all tokens of one sentence.

    :param pred_list: tensor with one predicted spatial tag per row
    :param label_list: tensor with one reference spatial tag per row
    :param include_r: forwarded to ``distance_loss``
    :return: accumulated loss; the float 0.0 when there are no rows
    :raises AssertionError: if the two inputs differ in their first dimension
    """
    n_pred = pred_list.size(0)
    n_label = label_list.size(0)
    # both inputs must describe the same number of tokens
    assert n_pred == n_label

    total = 0.0
    for pred_row, label_row in zip(pred_list, label_list):
        total += distance_loss(pred_row, label_row, include_r)

    return total

# Sense Inference

In [None]:
def is_contained(pred, sphere_coo, compare_spheres=False):
    """Test whether a prediction lies inside a reference sphere.

    With ``compare_spheres`` equal to False only the predicted sense point is
    tested: it must be within the reference radius of the reference sense
    point. Otherwise the predicted radius is taken into account as well:
    ``pred_radius + distance - reference_radius <= 0``.

    :param pred: predicted spatial parameters (radius last)
    :param sphere_coo: reference spatial parameters (radius last)
    :return: True or False
    """
    pred_sense, _ = coo2point(pred)
    target_sense, _ = coo2point(sphere_coo)

    pred_radius = pred[-1]
    target_radius = sphere_coo[-1]

    if compare_spheres == False:
        dx = pred_sense[0] - target_sense[0]
        dy = pred_sense[1] - target_sense[1]
        inside = dx**2 + dy**2 <= target_radius**2
    else:
        inside = pred_radius + torch.linalg.norm(pred_sense - target_sense) - target_radius <= 0

    return bool(inside)
    


def _k_nearest(distances, candidate_idx, k):
    """Return positions and values of the (at most) k smallest candidate distances.

    The returned positions index the full ``distances`` tensor, not the
    candidate subset.
    """
    subset = torch.index_select(distances, 0, candidate_idx)
    top_dist, local_idx = torch.topk(subset, min(k, subset.size(0)), largest=False)
    return candidate_idx[local_idx], top_dist


def vicinity_matrix(spatial_params, target_vocab: np.ndarray, spatial_tags: np.ndarray, k=5):#, include_sphere=True, include_r=True) -> [str]:
    """
    Compares the predicted spatial parameters with every known spatial tag and
    collects, per relation, the k nearest vocabulary entries.

    Relations (dictionary key -> cell ``[row][col]`` of the matrix, matching
    ``decode_key``):

    * "A" -> [2][0]: prediction fully contained (``is_contained`` with
      ``compare_spheres=True``), ranked by sense-point distance
    * "B" -> [2][1]: ``distance_loss(..., include_r=True) > 0``, ranked by
      sense-point distance
    * "C" -> [0][2]: not fully contained, ranked by ``distance_loss(include_r=True)``
    * "D" -> [1][2]: not fully contained, ranked by ``distance_loss(pt_sphere=True)``
    * "E" -> [2][2]: not fully contained, ranked by sense-point distance

    Fixes over the previous version: the "not contained" mask is a real
    logical negation (``~`` on a Python bool produced -1/-2, so C/D/E were
    never filled); C/D/E are written to column 2 (column 3 is out of range);
    the stored indices refer to rows of ``target_vocab`` / ``spatial_tags``
    rather than to positions inside the filtered subset; and fewer than ``k``
    candidates no longer raise in ``torch.topk`` (unused slots stay zero).

    NOTE(review): relation "B" is labelled "partially contained" but its mask
    is simply "tolerant loss > 0"; confirm the intended definition.

    :param spatial_params: predicted spatial parameters (5 values)
    :param target_vocab: vocabulary array, row-aligned with ``spatial_tags``
    :param spatial_tags: array with one spatial tag per row
    :param k: maximum number of neighbours kept per relation
    :return: tuple (indices, matrix, synsets) where ``indices[key]`` is a
             tensor of row indices, ``matrix`` has shape (3, 3, k, 2) with
             (index, distance) pairs in the last axis, and ``synsets[key]``
             is ``[vocabulary rows, distances]``
    """
    N = len(spatial_tags)

    # convert spatial_tags to tensor
    spatial_tags = torch.from_numpy(spatial_tags)

    synsets = {}
    indices = {}

    # ------------------------------------------------------------------
    # Distance and containment of the prediction w.r.t. every spatial tag
    # ------------------------------------------------------------------
    dist_spheres = torch.empty(N)
    dist_pt_sphere = torch.empty(N)
    dist_pts = torch.empty(N)
    full_contained = torch.zeros(N, dtype=torch.bool)
    part_contained = torch.zeros(N, dtype=torch.bool)

    for i, tag in enumerate(spatial_tags):
        tolerant = distance_loss(spatial_params, tag, include_r=True)
        dist_spheres[i] = tolerant
        dist_pt_sphere[i] = distance_loss(spatial_params, tag, pt_sphere=True)
        dist_pts[i] = distance_loss(spatial_params, tag, include_r=False)
        full_contained[i] = is_contained(spatial_params, tag, compare_spheres=True)
        part_contained[i] = bool(tolerant > 0)

    # logical negation on a boolean tensor
    disconnected = ~full_contained

    # ------------------------------------------------------------------
    # Vicinity matrix: row=3, col=3, topk=k, 2 = (index, distance)
    # ------------------------------------------------------------------
    matrix = torch.zeros((3, 3, k, 2))

    # (key, row, col, distances used for ranking, candidate mask)
    relations = (
        ("A", 2, 0, dist_pts, full_contained),
        ("B", 2, 1, dist_pts, part_contained),
        ("C", 0, 2, dist_spheres, disconnected),
        ("D", 1, 2, dist_pt_sphere, disconnected),
        ("E", 2, 2, dist_pts, disconnected),
    )

    for key, row, col, distances, mask in relations:
        candidate_idx = mask.nonzero(as_tuple=True)[0]
        if candidate_idx.size(0) == 0:
            continue

        top_idx, top_dist = _k_nearest(distances, candidate_idx, k)

        synsets[key] = [np.take(target_vocab, top_idx.numpy(), 0), top_dist]
        indices[key] = top_idx
        # index, distance (without synsets: tensors cannot hold strings)
        pairs = torch.stack((top_idx.to(top_dist.dtype), top_dist), dim=1)
        matrix[row, col, :top_idx.size(0)] = pairs

    return indices, matrix, synsets

def decode_key(key, mtx):
    """Return the cell of ``mtx`` that belongs to a vicinity key.

    Keys "A".."E" map to the cells [2, 0], [2, 1], [0, 2], [1, 2] and [2, 2];
    any other key yields None.
    """
    cell_of_key = (
        ("A", (2, 0)),
        ("B", (2, 1)),
        ("C", (0, 2)),
        ("D", (1, 2)),
        ("E", (2, 2)),
    )
    for name, cell in cell_of_key:
        if key == name:
            return mtx[cell]
    return None
    

def label_in_vicinity(vicinity_matrix, vicinity_synsets, target_vocab, spatial_tags, true_label):
    """Check whether the synset of ``true_label`` occurs among the vicinity synsets.

    Steps:
    1. Round ``true_label`` to two decimals; an all-zero label returns
       ``(False, ['no-synset'])`` immediately.
    2. Find the first row of ``spatial_tags`` equal to the rounded label and
       take the corresponding row of ``target_vocab``.
    3. For every entry of ``vicinity_synsets`` test whether that row's synset
       (element [1]) occurs in the entry; additionally test all vocabulary
       rows selected by the ``np.where`` comparison below.
    4. Print one line per collected check and return whether any check hit.

    :param vicinity_matrix: matrix indexed through ``decode_key`` to fetch distances
    :param vicinity_synsets: mapping key -> values indexed as ``val[:, 1]`` / ``val[:, -1]``
    :param target_vocab: vocabulary array; element [1] of a row is used as the synset
    :param spatial_tags: iterable of spatial tags compared with ``np.array_equal``
    :param true_label: reference spatial tag (5 values)
    :return: tuple (in_vicinity, associated_syn); ``associated_syn`` is a list
             holding either a set of matching vocabulary entries or 'no-synset'

    NOTE(review): ``val[:, 1]`` requires 2-D array-like values, while the
    ``vicinity_matrix`` function in this file stores plain lists
    ``[synsets, distances]`` under each key — confirm what callers pass.
    """
    
    checked_synsets = []
    contained = []
    checks = 0  # incremented below but never read
    predicted = []
    distances = []
    
    in_vicinity = False
    associated_syn = []
    
    # true label is either one of the possibilities [word, synset] or a randomly chosen one
    
    # induce subset of word-synset name 
    
    #spatial_tags = torch.from_numpy(spatial_tags)
    #idx_label = (spatial_tags == true_label).nonzero(as_tuple=True)[0]
    # convert to a NumPy array so it can be compared with the spatial tags
    true_label = np.array(true_label, dtype=np.float64)
    # keep spatial tag an np.ndarray
    rounded_l = np.round(true_label, decimals=2)
    
    # An all-zero label is treated as "no sense tag": nothing to look up.
    if np.all(rounded_l == np.zeros(5)): #true_label): #torch.all(torch.eq(rounded_l, true_label)):
        in_vicinity = False #True
        associated_syn.append('no-synset')
        return in_vicinity, associated_syn
    
    try:
        # detecting the true label from the spatial_tags
        # (.index(True) raises ValueError when no tag matches; handled below)
        idx = [[np.array_equal(rounded_l, tag) for tag in spatial_tags].index(True)]
#         print("Found {} matching word-synset tags.".format(len(idx)))
        word_synset = target_vocab[idx] #list of list 
#         print("Matching word-synset", word_synset)
        # check if word_synset is within the vicinity matrix
        if len(word_synset) != 0:
            for e in word_synset:
                for key, val in vicinity_synsets.items():
#                     print("Searching in vicinity ... ")

#                     print("Checking if true label is in vicinity ...")
                    checked_synsets.append(e)
                    is_there = e[1] in val[:, 1]
                    checks += 1
                    contained.append(is_there)
                    
#                     print("1")
#                     print(checked_synsets)
#                     print(checks)
#                     print(contained)
                    
                    if is_there:
#                         print("The main true label <{}> is in the vacinity of the predicted tag.".format(e))
                        idx_e = np.where(val[:, 1] == e[1])
                        predicted.append(val[idx_e])
#                         print("Predicted 1: ", predicted)
                        distances.append(decode_key(key, vicinity_matrix)[idx_e][1])
#                         print("Distances 1: ", distances)
                    else:
#                         print("The main true label is not in vicinity ... ")
                        distances.append('no-distance')
#                         print("Searching if alternative true label synsets are in vicinity ... ")
                    # induce all the word-synset tuples that have same synset as true label.
                    # This double check is necessary since I choose the spatial tags in the training data randomly sometimes.
                    # get indices of all word-synsets sharing same synset (not same word)
                    # NOTE(review): `_` is not assigned anywhere in this function; in IPython it
                    # is the last cell output, in plain Python it is undefined. It does not act
                    # as a wildcard — confirm the intended comparison (likely the synset column only).
                    ix = np.where(target_vocab == [_, e[1]])[0] # add [0] to indicate only the row index, not the column
#                     print("Indices ", ix)
                    if len(ix) != 0:
                        pos_syn = target_vocab[ix]
                        
#                         print("Possible synsets: ", pos_syn)
#                         print(target_vocab[:10])
                        for t in pos_syn:
                            checks += 1
                            checked_synsets.append(t)
                            # NOTE(review): this check reads the last column (-1) while the
                            # check above reads column 1 — confirm both mean the synset column.
                            is_near = t[1] in val[:, -1]
                            contained.append(is_near)
#                             print("2")
#                             print(checked_synsets)
#                             print(checks)
#                             print(contained)
                            if is_near == True:                                    
#                                 print("... The word-synset <{}> is in the vicinity of the predicted tag.".format(t))
                                idx_t = np.where(val[:, -1] == t[1])
                                predicted.append(val[idx_t])
#                                 print("Predicted 2: ", predicted)
                                distances.append(decode_key(key, vicinity_matrix)[idx_t][1])
#                                 print("Distances 2: ", distances)
                            else:
                                distances.append('no-distance')
                    else: 
                        print("... There are no other possibilites for word-synset <{}>".format(e))
                            
        else:
            print("Cannot find the suitable synset of this spatial tag!")

        
    except ValueError as ve:
        # Raised e.g. when no spatial tag equals the rounded label; the message
        # is printed and the function continues with whatever was collected.
        print(ve)
#         print("Found no index for the true label. Something went wrong ...")
#         print("Comparing <true label = {}> with <rounded label = {}>".format(true_label, rounded_l))
    
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Statistics
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
#     print("~" * 80)
#     print("Statistics")
#     print("~" * 80)
    
#     print("Predicted Spatial Tag = ", spatial_params)
#     print("Checked Spatial Tag(s) ; contained? ; Predicted ; distances = ({}):".format(len(checked_synsets)))
    # NOTE(review): `predicted` only grows on hits whereas the other three lists grow on
    # every check, so zip() stops at the number of hits and the printed rows may not
    # belong together.
    for s, c, p, d in zip(checked_synsets, contained, predicted, distances):
        print(s, ";", c, ";", "\n", p, ";", d)
        print("-"*100)
        
#     print("True Spatial Tag(s) is in vicinity of predicted tag: ", contained)
    # positions of the checks that found the synset in the vicinity
    contained_idx = np.where(np.array(contained) == True)
    
#     print("contained_idx", contained_idx)
#     print("checked_idx", np.array(checked_synsets)[contained_idx])
#     print("slice", np.array(checked_synsets)[:, 1])
#     print("check_slice", np.array(checked_synsets)[:, 1][contained_idx])

    if len(contained_idx[0]) > 0:
#         print()
#         print(contained_idx)
        # NOTE(review): set() needs hashable elements; if the selected entries are array
        # rows this raises TypeError — confirm the element type of target_vocab rows.
        only_syn = set(np.array(checked_synsets)[contained_idx])#[:, 1])
        associated_syn.append(only_syn)
#         print("True Sense Tag(s) = ({}) --> ".format(len(only_syn)), only_syn)
#         print("Prediction is correct!")
        in_vicinity = True
#         print("Distance(predicted_sense, nearest_true_sense) = ({}): ".format(len(np.array(predicted)[contained_idx])))
#         for p, d in zip(np.array(predicted), distances):
#               print(p, d)
              
    else:
#         print("Prediction is false ..")
#         print("All synsets in the vicinity of the predicted tag are not true senses ..")
#         print("Please check manually if the synsets in the vicinity are generalizations of the true labels.")
        in_vicinity = False
        associated_syn.append("no-synset")
    
    
    return in_vicinity, associated_syn
    

In [None]:
def count_parameters(model):
    """Return the number of trainable scalar parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted, which
    makes the figure usable for comparing the capacity of different models.

    :param model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``)
    :return: total element count over all parameters that require gradients
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Training / Validation

In [None]:
# NOTE: the data must already be split before training; `RegTagger.train` reads the
# 'train' and 'validate' entries of `splittings`.

class RegTagger:
    """Trains, evaluates and applies a transformer regressor that predicts one
    spatial tag (a vector of ``out_features=5`` values) per input token.

    The class relies on helpers defined elsewhere in this notebook:
    ``Dataset``, ``load_vocab``, ``numericalize``, ``geometric_loss``,
    ``vicinity_matrix``, ``label_in_vicinity``, ``preprocess`` and
    ``TransformerEncoderRegressor``.
    """

    def __init__(self, use_cuda, device):
        """
        :param use_cuda: flag stored on the instance (not read by this class itself)
        :param device: device the model and all tensors are moved to
        """
        self.use_cuda = use_cuda
        self.device = device
        # Set by train(); test() and tag() refuse to run without it.
        self.model = None
        # Embedding size used by train(); test() falls back to it.
        self.embed_size = None
        # NOTE(review): the notebook also sets cudnn.deterministic = True at the top;
        # benchmark mode may pick non-deterministic kernels -- confirm this is intended.
        torch.backends.cudnn.benchmark = True

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _make_loader(sentence_ids, path, batch_size, shuffle, num_workers):
        """Wrap the notebook's ``Dataset`` in a DataLoader that yields raw lists of samples."""
        params = {'batch_size': batch_size,
                  'shuffle': shuffle,
                  # identity collate: every batch stays a plain list of samples
                  'collate_fn': lambda x: x,
                  'num_workers': num_workers}  # set 0 if training on Windows machine
        return torch.utils.data.DataLoader(Dataset(sentence_ids, path), **params)

    @staticmethod
    def _count_tokens(generator):
        """Iterate once over ``generator`` and return ``(n_sentences, n_tokens)``."""
        n_sentences = 0
        n_tokens = 0
        for batch in generator:
            for sentence, _label, _syn in batch:
                n_sentences += 1
                n_tokens += len(sentence)
        return n_sentences, n_tokens

    def _require_model(self):
        """Return the trained model or raise if ``train`` has not produced one yet."""
        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError("RegTagger has no trained model yet; call train() first.")
        return model

    def _evaluate_sentence(self, tokens, labels, vocab, target_vocab, spatial_tags, k, verbose=False):
        """Score one sentence without tracking gradients.

        :return: ``(loss_value, n_tokens, hits, predicted)`` where ``hits`` holds one
                 flag per token (first return value of ``label_in_vicinity``) and
                 ``predicted`` the matching second return values.
        """
        with torch.no_grad():
            token_ids = numericalize(tokens, vocab).to(self.device)
            # labels is a list of equally sized tensors -> stack into a single tensor
            gold = torch.stack(labels).to(self.device)
            out = self.model(token_ids)

            # During validation and testing the check is less strict: a point that
            # resides within the label sphere counts as a correctly identified sense.
            loss = geometric_loss(out, gold, include_r=True)

            hits = []
            predicted = []
            for i, word_tag in enumerate(out):
                _, vmat, vsyn = vicinity_matrix(spatial_params=word_tag,
                                                target_vocab=target_vocab,
                                                spatial_tags=spatial_tags, k=k)
                in_vic, pred_syn = label_in_vicinity(vicinity_matrix=vmat, vicinity_synsets=vsyn,
                                                     target_vocab=target_vocab,
                                                     spatial_tags=spatial_tags, true_label=gold[i])
                if verbose:
                    print("In Vicinity? --> {}".format(in_vic))
                    print("Predicted synsets --> {}".format(pred_syn))
                hits.append(in_vic)
                predicted.append(pred_syn)

        return loss.item(), len(token_ids), hits, predicted

    @staticmethod
    def _plot_history(history):
        """Plot the recorded training loss and validation sense accuracy on twin y-axes."""
        fig, ax_loss = plt.subplots()

        color = 'tab:red'
        # Both series are indexed by processed sentence, not by wall-clock time.
        ax_loss.set_xlabel('step (sentence)')
        ax_loss.set_ylabel('loss', color=color)
        ax_loss.plot(history["train_loss"], color=color)
        ax_loss.tick_params(axis='y', labelcolor=color)

        ax_acc = ax_loss.twinx()  # second axes that shares the same x-axis

        color = 'tab:blue'
        ax_acc.set_ylabel('accuracy', color=color)
        ax_acc.plot(history["sense_accuracy"], color=color)
        ax_acc.tick_params(axis='y', labelcolor=color)

        fig.tight_layout()  # otherwise the right y-label is slightly clipped
        plt.show()

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def train(self, batch_size: int, num_workers: int, max_epochs: int,
              splittings: dict, path2data: str, data: list, embed_size: int,
              target_vocab: list, spatial_tags: list,
              k=5,
              d_model=300, d_hid=200, nlayers=2, nhead=2, dropout=0.2,
              lr=5.0, gamma=0.95,
              shuffle=True,
              val_candidates=100):
        """Train a fresh ``TransformerEncoderRegressor`` and validate it after every epoch.

        :param batch_size: number of sentences per DataLoader batch
        :param num_workers: DataLoader worker processes
        :param max_epochs: number of passes over the training split
        :param splittings: dict with the entries ``'train'`` and ``'validate'``
        :param path2data: path handed to the notebook's ``Dataset``
        :param data: corpus handed to ``load_vocab`` to build vocabulary and embeddings
        :param embed_size: embedding size handed to ``load_vocab``
        :param target_vocab: candidate synsets passed to the vicinity helpers
        :param spatial_tags: spatial tags passed to the vicinity helpers
        :param k: forwarded to ``vicinity_matrix`` (presumably the neighbourhood size -- confirm)
        :param d_model, d_hid, nlayers, nhead, dropout: model hyper-parameters
        :param lr: initial SGD learning rate
        :param gamma: multiplicative learning-rate decay applied after every epoch
        :param shuffle: shuffle flag used for both the training and validation loaders
        :param val_candidates: only the first ``val_candidates`` entries of
            ``target_vocab`` / ``spatial_tags`` are searched during validation
            (``None`` = all); the default keeps the previous hard-coded 100
        :return: ``defaultdict(list)`` with the keys ``'train_loss'``,
            ``'validation_loss'`` and ``'sense_accuracy'`` (one entry per sentence)
        """
        # Training and validation data generators
        training_generator = self._make_loader(splittings['train'], path2data,
                                               batch_size, shuffle, num_workers)
        validation_generator = self._make_loader(splittings['validate'], path2data,
                                                 batch_size, shuffle, num_workers)

        # history to store the losses
        history = defaultdict(list)

        VOCAB, weights_matrix = load_vocab(data, embed_size=embed_size)
        self.embed_size = embed_size  # remembered so test() can rebuild the vocabulary

        # Count sentences and words up front to normalize the loss / accuracy.
        nb_train_sentences, nb_words_training = self._count_tokens(training_generator)
        _, nb_words_validation = self._count_tokens(validation_generator)

        n_batches = np.ceil(nb_train_sentences / batch_size)
        mean_words = nb_words_training / n_batches  # average number of words per batch

        self.model = TransformerEncoderRegressor(weights_matrix=weights_matrix,
                                                 ntoken=len(VOCAB),
                                                 out_features=5,
                                                 d_model=d_model,
                                                 d_hid=d_hid,
                                                 nlayers=nlayers,
                                                 nhead=nhead,
                                                 dropout=dropout)
        self.model.to(self.device)

        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

        # Loop-invariant candidate subsets used by the validation pass.
        val_vocab = target_vocab[:val_candidates]
        val_tags = spatial_tags[:val_candidates]

        train_loss = float('nan')  # stays NaN only if the training split is empty

        for epoch in range(max_epochs):

            t0 = time.time()
            loss_sum = 0
            self.model.train()

            print("Training ...")
            for batch in training_generator:
                # The model is updated after every single sentence of the batch.
                for tokens, labels, _synsets in batch:
                    token_ids = numericalize(tokens, VOCAB).to(self.device)
                    gold = torch.stack(labels).to(self.device)

                    out = self.model(token_ids)
                    loss = geometric_loss(out, gold) / mean_words

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                    optimizer.step()
                    loss_sum += loss.item()

                    # NOTE(review): this divides the running epoch sum by the length of
                    # the current sentence; kept as-is, confirm it is the intended metric.
                    train_loss = loss_sum / len(token_ids)
                    history['train_loss'].append(train_loss)

            # Decay the learning rate only after this epoch's optimizer steps; stepping
            # the scheduler first would skip the initial learning rate.
            scheduler.step()

            # Evaluate on the validation set after every epoch.
            print("Validation ...")
            vloss_sum = 0
            correct_sense = 0
            batch_acc = float('nan')

            # eval mode disables dropout; gradients are disabled in _evaluate_sentence
            self.model.eval()

            for batch in validation_generator:
                print("New Batch for Validation")
                print("#" * 100)

                for tokens, labels, synsets in batch:
                    print("Input Sentence")
                    print(tokens)
                    print("Labels:")
                    print(synsets)

                    loss_value, n_tokens, hits, predicted = self._evaluate_sentence(
                        tokens, labels, VOCAB, val_vocab, val_tags, k)

                    vloss_sum += loss_value
                    # NOTE(review): same running-sum normalisation as the training loss.
                    history['validation_loss'].append(vloss_sum / n_tokens)

                    print(hits)
                    print(predicted)

                    n_correct = sum(1 for hit in hits if hit)
                    correct_sense += n_correct
                    # despite the name this is the accuracy of the current sentence
                    batch_acc = n_correct / n_tokens
                    history["sense_accuracy"].append(batch_acc)

                t1 = time.time()
                print(f'Epoch {epoch}: train loss = {train_loss:.4f}, batch accuracy: {batch_acc:.4f}, time = {t1-t0:.4f}')

            # Guard against an empty validation split.
            sense_accuracy = correct_sense / nb_words_validation if nb_words_validation else 0.0
            print("The sense accuracy on the validation set is {} %".format(sense_accuracy * 100))

        self._plot_history(history)

        return history

    def test(self, testing_data, path, data, batch_size, num_workers, target_vocab, spatial_tags,
             k=5, shuffle=True, embed_size=None):
        """Evaluate the trained model on held-out data and report the sense accuracy.

        :param testing_data: sample ids handed to the notebook's ``Dataset``
        :param path: path handed to the notebook's ``Dataset``
        :param data: corpus handed to ``load_vocab`` (should match the one used in ``train``)
        :param batch_size: number of sentences per DataLoader batch
        :param num_workers: DataLoader worker processes
        :param target_vocab: candidate synsets passed to the vicinity helpers
        :param spatial_tags: spatial tags passed to the vicinity helpers
        :param k: forwarded to ``vicinity_matrix``
        :param shuffle: DataLoader shuffle flag
        :param embed_size: embedding size for ``load_vocab``; defaults to the value
            remembered from ``train``
        :return: ``defaultdict(list)`` with the keys ``'testing_loss'`` and ``'sense_accuracy'``
        :raises RuntimeError: if no model has been trained yet
        :raises ValueError: if ``embed_size`` is neither given nor known from ``train``
        """
        self._require_model()

        if embed_size is None:
            embed_size = getattr(self, "embed_size", None)
        if embed_size is None:
            raise ValueError("embed_size is unknown: pass it explicitly or call train() first.")

        testing_generator = self._make_loader(testing_data, path, batch_size, shuffle, num_workers)

        # Count words to normalize the overall accuracy.
        _, nb_words_testing = self._count_tokens(testing_generator)

        VOCAB, _weights_matrix = load_vocab(data, embed_size=embed_size)

        history = defaultdict(list)
        tloss_sum = 0
        correct_sense = 0
        batch_acc = float('nan')

        t0 = time.time()

        # eval mode disables dropout; gradients are disabled in _evaluate_sentence
        self.model.eval()

        for batch in testing_generator:
            print("Batches for testing")
            print("#" * 100)

            for tokens, labels, synsets in batch:
                print("Input Sentence")
                print(tokens)
                print("Labels:")
                print(synsets)

                loss_value, n_tokens, hits, predicted = self._evaluate_sentence(
                    tokens, labels, VOCAB, target_vocab, spatial_tags, k, verbose=True)

                tloss_sum += loss_value
                # NOTE(review): running sum divided by the current sentence length,
                # mirroring the validation metric in train().
                history['testing_loss'].append(tloss_sum / n_tokens)

                print(hits)
                print(predicted)

                n_correct = sum(1 for hit in hits if hit)
                correct_sense += n_correct
                batch_acc = n_correct / n_tokens
                history["sense_accuracy"].append(batch_acc)

            t1 = time.time()
            print(f'batch accuracy: {batch_acc:.4f}, time = {t1-t0:.4f}')

        # Normalize by the number of *test* words; guard against an empty test set.
        sense_accuracy = correct_sense / nb_words_testing if nb_words_testing else 0.0

        print("The sense accuracy on the testing set is {} %".format(sense_accuracy * 100))

        return history

    def tag(self, sentence, embed_size, target_vocab, spatial_tags, k):
        """Predict, for every token of ``sentence``, the synsets in the vicinity of the model output.

        :param sentence: raw string, or a list/tuple of tokens
            (e.g. ``['fall', 'in', 'catastrophes']``) that is joined with spaces
        :param embed_size: embedding size handed to ``load_vocab``
        :param target_vocab: candidate synsets passed to ``vicinity_matrix``
        :param spatial_tags: spatial tags passed to ``vicinity_matrix``
        :param k: forwarded to ``vicinity_matrix``
        :return: list with one entry per lemmatized token: the third return value of
            ``vicinity_matrix`` (a mapping, since ``.items()`` is called on it)
        :raises TypeError: if ``sentence`` is neither a string nor a list/tuple
        :raises RuntimeError: if no model has been trained yet
        """
        print("Initial Input: ", sentence)

        if isinstance(sentence, (list, tuple)):
            text = " ".join(sentence)
        elif isinstance(sentence, str):
            text = sentence
        else:
            raise TypeError(
                "sentence must be a str or a list of tokens, got {}".format(type(sentence).__name__))

        model = self._require_model()

        # preprocess returns the lemmatized sentence; the first item of each entry is the token
        lemm_sentence = preprocess(text)
        tokens = [entry[0] for entry in lemm_sentence]

        N = len(tokens)
        tags = '?' * N  # placeholder tag column for the printed table
        print("Lemmatized Sentence: ", tokens)

        if not tokens:
            return []

        # NOTE(review): the vocabulary is rebuilt from this sentence alone, so the token
        # indices may not line up with the vocabulary the model was trained on -- confirm.
        vocab, _wmat = load_vocab(tokens, embed_size)

        num_data = numericalize(tokens, vocab).to(self.device)

        # Inference only: disable dropout and gradient tracking.
        model.eval()
        with torch.no_grad():
            out = model(num_data)

        predicted_synsets = []
        for word_tag in out:
            _, _, vsyn = vicinity_matrix(spatial_params=word_tag,
                                         target_vocab=target_vocab,
                                         spatial_tags=spatial_tags, k=k)
            predicted_synsets.append(vsyn)

        for i, token in enumerate(tokens):
            print(token, "\t", tags[i], "\t", predicted_synsets[i].items())
            print()

        return predicted_synsets

        

     

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Build the tagger and train it on the pre-computed splits for 10 epochs.
T = RegTagger(use_cuda=use_cuda, device=device)
T.train(
    batch_size=5,
    num_workers=0,
    max_epochs=10,
    splittings=splittings,
    path2data=path + "pwngc_id.pt",
    data=data,
    embed_size=300,
    target_vocab=target_VOCAB,
    spatial_tags=SPATIAL_TAGS,
)

In [None]:
# Tag a hand-written example sentence against the first 100 candidate synsets.
sentence = "Hundred babies are one years old .".split(" ")
sentag = T.tag(
    sentence,
    embed_size=300,
    target_vocab=target_VOCAB[:100],
    spatial_tags=SPATIAL_TAGS[:100],
    k=5,
)
