In [1]:
import torch
from torch import nn
torch.manual_seed(0)
import random
random.seed(0)
from transformers import BertTokenizer, BertModel
from util import *
from loss_functions import *
from models import *
from loss_functions import *
import numpy as np
import sys
import time
import umap
import umap.plot
import json
import pandas as pd

In [2]:
def load_model(model, config_path, model_path, device='cpu'):
    """Instantiate a generator from a JSON config and load its saved weights.

    Args:
        model: Model class (e.g. ``Mapping_and_Centroids``). It is called
            positionally as ``model(k_assignment, token_dim, hid_dim,
            sense_dim, modals, sub_module_use_mlp, device)``.
        config_path: Path to a JSON file whose ``'generator'`` section holds
            the six constructor keys listed above.
        model_path: Path to a state dict written with ``torch.save``.
        device: Device passed to the constructor and used as ``map_location``
            so a checkpoint saved on GPU still loads on a CPU-only machine.
            Defaults to ``'cpu'``, the value previously hard-coded.

    Returns:
        The model with its weights loaded, switched to eval mode.
    """
    # Context manager so the config file handle is closed deterministically.
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)
    gen_config = config['generator']
    generator = model(gen_config['k_assignment'],
                      gen_config['token_dim'],
                      gen_config['hid_dim'],
                      gen_config['sense_dim'],
                      gen_config['modals'],
                      gen_config['sub_module_use_mlp'],
                      device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    generator.load_state_dict(state_dict)
    # This notebook only uses loaded generators for inference/plotting, so turn
    # off training-mode behaviour (dropout, batch-norm updates) if the model has any.
    generator.eval()
    return generator

In [5]:
# Generators with per-modal centroids: the "mindist" (config1) and "mlp" (config2) checkpoints.
generator_mindist = load_model(model=Mapping_and_Centroids,
                               config_path='configs/config1.json',
                               model_path='checkpoints/model_mindist_1/generator.pth')
generator_mlp = load_model(model=Mapping_and_Centroids,
                           config_path='configs/config2.json',
                           model_path='checkpoints/model_mlp_1/generator.pth')

In [3]:
# Shared-centroid variants of the "mindist" (config1_1) and "mlp" (config2_1) generators.
generator_mindist_sc = load_model(model=Mapping_and_Shared_Centroids,
                                  config_path='configs/config1_1.json',
                                  model_path='checkpoints/model_mindist_1_share_centroids/generator.pth')
generator_mlp_sc = load_model(model=Mapping_and_Shared_Centroids,
                              config_path='configs/config2_1.json',
                              model_path='checkpoints/model_mlp_1_share_centroids/generator.pth')

In [31]:
# Second shared-centroid "mindist" checkpoint, built from the same config1_1.json.
generator_mindist_sc1 = load_model(model=Mapping_and_Shared_Centroids,
                                   config_path='configs/config1_1.json',
                                   model_path='checkpoints/model_mindist_1_share_centroids_1/generator.pth')

In [6]:
# Held-out sentences, one file per modal listed on generator_mindist.
test_sents = dict(
    (modal, get_sents_from_file("../modal_sentences/test/{}.txt".format(modal)))
    for modal in generator_mindist.modals
)

In [7]:
# bert-base-uncased tokenizer and encoder; the encoder also returns every layer's hidden states.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained("bert-base-uncased",
                                 output_hidden_states=True)

In [8]:
# Run each modal's test sentences through BERT on the CPU.
embedding_sentences = dict(
    (modal, get_bert_embeddings(tokenizer, bert, sents, modal, 'cpu'))
    for modal, sents in test_sents.items()
)

In [9]:
# get_bert_embeddings presumably returns (embeddings, sentences); split into two per-modal dicts.
embedding_tensors = {name: pair[0] for name, pair in embedding_sentences.items()}
sentences = {name: pair[1] for name, pair in embedding_sentences.items()}

In [16]:
# Quick inspection of the shared-centroid generator's output on the test set.
gen_out = generator_mindist_sc(embedding_tensors)
sense_embeddings, sense_class, centroids = (
    gen_out['embeddings_dict'],
    gen_out['sense_class_dict'],
    gen_out['centroids'],
)

centroids.shape

torch.Size([3, 300])

In [10]:
def plot(generator,
         embedding_tensors,
         sentences,
         share_centroids,
         n_neighbors=10,
         point_size=5,
         random_state=None):
    """Project sense embeddings and centroids to 2-D with UMAP and show an interactive plot.

    Runs ``generator`` on ``embedding_tensors``. For every modal it gathers the
    sense embeddings and the centroids, each paired with a label (used for
    colouring) and a hover text.

    Args:
        generator: Callable returning a dict with keys ``'embeddings_dict'``,
            ``'sense_class_dict'`` and ``'centroids'``.
        embedding_tensors: Input passed straight to ``generator``.
        sentences: Dict mapping each modal to its sentences. It is assumed to
            be aligned row-for-row with ``embeddings_dict[modal]``.
        share_centroids: If True, ``centroids`` is a single tensor plotted once
            with labels ``sense_<i>``. Otherwise it is a dict of per-modal
            tensors, labelled ``<modal>_centroid_<i>``.
        n_neighbors: UMAP ``n_neighbors`` (default 10, as before).
        point_size: Marker size in the interactive plot (default 5, as before).
        random_state: Forwarded to ``umap.UMAP`` for reproducible layouts.
            ``None`` keeps UMAP's default, non-deterministic behaviour.

    Raises:
        ValueError: If there is nothing to plot, or if the embeddings, labels
            and hover sentences end up with different lengths.
    """
    # Inference only: no autograd graph is needed just to plot.
    with torch.no_grad():
        gen_out = generator(embedding_tensors)
    sense_embeddings = gen_out['embeddings_dict']
    sense_class = gen_out['sense_class_dict']
    centroids = gen_out['centroids']

    # Collect row blocks and concatenate once, instead of repeated np.append.
    blocks = []
    labels = []
    hover_sentences = []

    if share_centroids:
        shared = centroids.detach().cpu().numpy()
        centroid_labels = ["sense_{}".format(label) for label in range(len(shared))]
        blocks.append(shared)
        labels.extend(centroid_labels)
        hover_sentences.extend(centroid_labels)

    for modal, modal_embeddings in sense_embeddings.items():
        # The modal's own points: rows of its embeddings, one label and sentence each.
        blocks.append(modal_embeddings.detach().cpu().numpy())
        labels.extend("{}_{}".format(modal, label) for label in sense_class[modal])
        hover_sentences.extend(sentences[modal])

        if not share_centroids:
            cents = centroids[modal].detach().cpu().numpy()
            centroid_labels = ["{}_centroid_{}".format(modal, i) for i in range(len(cents))]
            blocks.append(cents)
            labels.extend(centroid_labels)
            # Centroids have no sentence; their label doubles as hover text.
            hover_sentences.extend(centroid_labels)

    if not blocks:
        raise ValueError("plot: generator returned no embeddings to plot")
    embeddings = np.concatenate(blocks, axis=0)
    if not (len(embeddings) == len(labels) == len(hover_sentences)):
        raise ValueError(
            "plot: misaligned inputs ({} embeddings, {} labels, {} sentences)".format(
                len(embeddings), len(labels), len(hover_sentences)))

    hover_data = pd.DataFrame({'index': np.arange(len(labels)),
                               'label': labels,
                               'sentences': hover_sentences})
    umap.plot.output_notebook()
    mapper = umap.UMAP(n_neighbors=n_neighbors, random_state=random_state).fit(embeddings)
    p = umap.plot.interactive(mapper, labels=labels, hover_data=hover_data, point_size=point_size)
    umap.plot.show(p)
    

In [41]:
plot(generator_mindist_sc1, embedding_tensors, sentences, share_centroids=True)

In [11]:
plot(generator_mlp_sc, embedding_tensors, sentences, share_centroids=True)

In [26]:
plot(generator_mindist_sc, embedding_tensors, sentences, share_centroids=True)

In [12]:
plot(generator_mindist, embedding_tensors, sentences, share_centroids=False)

In [13]:
plot(generator_mlp, embedding_tensors, sentences, share_centroids=False)