In [217]:
import pandas as pd
import pickle as pk
from tqdm import tqdm
import torch
import torch.nn as nn
import dictionary_corpus
from dictionary_corpus import Corpus
import numpy as np
from sklearn.decomposition import PCA, SparsePCA, KernelPCA, IncrementalPCA
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os
# Directory holding the pickled per-type embedding tables (`all_sent_<type>`) that are read further down.
out_dir = 'contextual_embeddings'
# Cap DataFrame reprs at 30 rows / 30 columns so the displayed tables stay short.
pd.set_option('display.max_rows', 30)
pd.set_option('display.max_columns', 30)
# Select the WebAgg backend.
# NOTE(review): pyplot has already been imported when this line runs — confirm the switch
# takes effect with the matplotlib version in use.
matplotlib.use('webagg')
# Embedding types; used as dict keys and as file-name suffixes throughout the notebook.
# NOTE(review): presumably the cell (c) / hidden (h) states of LSTM layers 0 and 1 — confirm.
types = ['c0','h0','c1','h1']
# load the model and the corpus
# NOTE: torch.load unpickles the checkpoint, so only load files from a trusted source.
# map_location remaps every stored tensor onto the CPU.
model = torch.load('hidden650_batch128_dropout0.2_lr20.0.pt',map_location=torch.device('cpu'))
# Corpus comes from the project-local dictionary_corpus module; the meaning of the
# empty-string argument is defined there.
corpus = Corpus('')
# NOTE(review): `ntokens` is not defined in any visible cell; the line below would fail if re-enabled.
# print("Vocab size %d", ntokens)

In [240]:
def euclidean(x, y):
    """Return the Euclidean (L2) distance between ``x`` and ``y``, rounded to 2 decimals.

    ``x`` and ``y`` must support ``x - y``; the difference is flattened by
    ``np.linalg.norm`` before the norm is taken.
    """
    difference = x - y
    return np.linalg.norm(difference).round(2)
def cos(x, y):
    """Return the cosine similarity of ``x`` and ``y``, rounded to 2 decimals.

    This is a similarity, not a distance: 1.0 for vectors pointing the same way,
    0.0 for orthogonal vectors and -1.0 for opposite vectors.
    """
    numerator = np.dot(x, y)
    denominator = np.linalg.norm(x) * np.linalg.norm(y)
    return (numerator / denominator).round(2)

def modulated_word_heat(d0, tar: str):
    """Pairwise distance matrix between every occurrence of one target word.

    Parameters
    ----------
    d0 : pandas.DataFrame
        Assumed to have the columns ``tensors``, ``labels``, ``file``, ``prev`` and
        ``target``, with a unique index (the index value is part of each id).
    tar : str
        Target word whose occurrences are compared with each other.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by ``id`` (``"<index>_<prev>_<file>"``) with one column per id;
        each cell is ``euclidean`` applied to the two rows' tensors. Empty when no
        row matches ``tar``. Any column beyond the five listed above is carried
        through unchanged, in front of the distance columns.
    """
    # Boolean indexing (instead of ``where``) keeps the original dtypes, so numeric
    # ``file`` values are not upcast to float before they are formatted into the ids.
    # Rows with a missing value in any column are discarded, as before.
    rows = d0[d0['target'] == tar].dropna().drop(['labels', 'target'], axis=1)
    ids = [
        f'{idx}_{prev}_{file}'
        for idx, prev, file in zip(rows.index, rows['prev'], rows['file'])
    ]
    tensors = rows['tensors']

    # One distance column per occurrence. The reference tensor is looked up once per
    # column rather than once per cell, and the columns are collected first so the
    # frame is assembled in a single step instead of being grown column by column.
    dist_cols = {}
    for col_id, ref in zip(ids, tensors):
        dist_cols[col_id] = tensors.apply(lambda x: euclidean(x, ref))

    passthrough = rows.drop(['prev', 'file', 'tensors'], axis=1)
    heat = pd.concat([passthrough, pd.DataFrame(dist_cols, index=rows.index)], axis=1)
    heat.index = pd.Index(ids, name='id')
    return heat

def all_mod_heat(d0, words: list):
    """Pairwise distance matrix between every occurrence of the given target words.

    Parameters
    ----------
    d0 : pandas.DataFrame
        Assumed to have the columns ``tensors``, ``labels``, ``file``, ``prev`` and
        ``target``, with a unique index (the index value is part of each id).
        ``target`` and ``file`` must be strings: they are concatenated into the sort key.
    words : list
        Target words to keep; rows with any other target are discarded.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by ``id`` (``"<index>_<prev> <target>_<file>"``) with one column
        per id; each cell is ``euclidean`` applied to the two rows' tensors. Rows and
        columns are ordered by ``target + file``. Any column beyond the five listed
        above is carried through unchanged, in front of the distance columns.
    """
    # Keep only the requested words and drop rows with a missing value in any column.
    rows = d0[d0['target'].isin(words)].dropna()

    # Order by word, then by source file. A stable sort keeps the input order inside
    # each (word, file) group, so the layout is the same on every run.
    order = (rows['target'] + rows['file']).to_numpy().argsort(kind='stable')
    rows = rows.iloc[order]

    ids = [
        f'{idx}_{prev} {target}_{file}'
        for idx, prev, target, file in zip(rows.index, rows['prev'], rows['target'], rows['file'])
    ]
    tensors = rows['tensors']

    # One distance column per occurrence. The reference tensor is looked up once per
    # column rather than once per cell, and the columns are collected first so the
    # frame is assembled in a single step instead of being grown column by column.
    dist_cols = {}
    for col_id, ref in tqdm(list(zip(ids, tensors))):
        dist_cols[col_id] = tensors.apply(lambda x: euclidean(x, ref))

    passthrough = rows.drop(['prev', 'file', 'target', 'labels', 'tensors'], axis=1)
    heat = pd.concat([passthrough, pd.DataFrame(dist_cols, index=rows.index)], axis=1)
    heat.index = pd.Index(ids, name='id')
    return heat

In [330]:
# Load one pickled embedding table per type from `out_dir`; keys are the entries of `types`.
# Later cells read a 'tensors', 'target' and 'file' column from each loaded table.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
embed_dict = {}
for i in types:
    with open(f'{out_dir}/all_sent_{i}','rb') as f:
        embed_dict[i] = pk.load(f)

def l2_normalisation(embed_dict, types=types):
    """Return a copy of ``embed_dict`` whose tensors are scaled by a per-type factor.

    For every type in ``types`` the ``tensors`` column is divided by the L2 norm of
    the mean tensor of that type.

    Parameters
    ----------
    embed_dict : dict
        Maps an embedding type to a DataFrame with a ``tensors`` column.
    types : list
        Keys of ``embed_dict`` to normalise; other keys are passed through untouched.

    Returns
    -------
    dict
        New dict holding new DataFrames for the normalised types. The input dict and
        its frames are left unmodified, so the raw and the normalised embeddings can
        be used side by side and re-running the cell gives the same result.
    """
    normalised = dict(embed_dict)
    for t in types:
        frame = embed_dict[t]
        # Scale factor: L2 norm of the mean tensor of this embedding type.
        scale = np.linalg.norm(frame['tensors'].agg('mean'))
        # ``assign`` builds a new frame instead of overwriting the caller's column.
        normalised[t] = frame.assign(tensors=frame['tensors'] / scale)
    return normalised
l2_normalised_embed_dict = l2_normalisation(embed_dict)

In [307]:
def single_type_word_dist(df, extraction_type, skip_words=('went',)):
    """Per-word distance between the mean embedding in file '1' and in file '2'.

    Parameters
    ----------
    df : pandas.DataFrame
        Embedding table of a single type; the ``target``, ``file`` and ``tensors``
        columns are read. ``file`` is compared against the strings ``'1'`` and ``'2'``.
    extraction_type : str
        Embedding type label; only used to name the output column.
    skip_words : collection of str, optional
        Target words left out of the result. Defaults to ``('went',)``, the word the
        original implementation always excluded.

    Returns
    -------
    pandas.DataFrame
        Columns ``target`` and ``file_dist_<extraction_type>``, one row per remaining
        target word in sorted order (groupby order).
    """
    dist_col = f'file_dist_{extraction_type}'
    records = []
    for word, word_rows in df.groupby(by='target'):
        if word in skip_words:
            continue
        # Mean tensor of the word's occurrences in each of the two files.
        mean_file_1 = word_rows.loc[word_rows['file'] == '1', 'tensors'].agg('mean')
        mean_file_2 = word_rows.loc[word_rows['file'] == '2', 'tensors'].agg('mean')
        records.append({'target': word, dist_col: euclidean(mean_file_1, mean_file_2)})
    # Explicit columns keep the output schema stable even when no word is left.
    return pd.DataFrame(records, columns=['target', dist_col])

In [309]:
def four_types_of_dist(embed_dict, types=types, normalised=False):
    """Merge the per-word file distances of several embedding types into one table.

    Parameters
    ----------
    embed_dict : dict
        Maps an embedding type to its embedding DataFrame.
    types : list
        Embedding types to include, in output column order.
    normalised : bool
        When true, each ``file_dist_<type>`` column is divided by the L2 norm of the
        mean tensor of that type and rounded to 2 decimals again.

    Returns
    -------
    pandas.DataFrame
        Column ``target`` plus one ``file_dist_<type>`` column per type. Only words
        present for every type are kept (inner merge on ``target``). An empty frame
        with just a ``target`` column is returned when ``types`` is empty.
    """
    df = None
    for t in types:
        delta_df = single_type_word_dist(embed_dict[t], t)
        # Start from the first type, whatever it is called, and merge the rest onto it.
        df = delta_df if df is None else pd.merge(df, delta_df, on='target')
    if df is None:
        return pd.DataFrame(columns=['target'])
    if normalised:
        for t in types:
            col = f'file_dist_{t}'
            # The distances were already rounded by single_type_word_dist; they are
            # rescaled here and rounded a second time.
            scale = np.linalg.norm(embed_dict[t]['tensors'].agg('mean'))
            df[col] = (df[col] / scale).round(2)
    return df

In [315]:
# Per-word distance between the file '1' and file '2' mean embeddings, one column per
# embedding type, rescaled via normalised=True and indexed by target word.
four_types_dist_all_words = four_types_of_dist(embed_dict, normalised=True).set_index('target')
four_types_dist_all_words

Unnamed: 0_level_0,file_dist_c0,file_dist_h0,file_dist_c1,file_dist_h1
target,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Switzerland,0.42,0.92,0.50,0.79
bank,0.66,1.34,1.28,2.10
book,1.19,2.12,1.57,1.98
books,1.10,2.04,1.44,1.83
can,1.04,1.94,1.26,1.64
...,...,...,...,...
thought,1.14,2.55,1.51,1.98
tomatoes,0.50,0.94,0.78,1.13
transistor,0.46,0.87,0.51,0.86
watch,1.17,2.42,1.91,2.32


# `df1` holds the embeddings of all the data

In [301]:
# NOTE(review): `embed_dicts` (plural) is not defined in any visible cell — only `embed_dict` is.
# On a fresh kernel this raises NameError unless it is defined elsewhere; confirm which table was intended.
df1 = embed_dicts['h1'].copy()
# Pairwise distance matrix between every occurrence of the word 'duck' in the 'h1' table.
wh = modulated_word_heat(df1,'duck')
wh

Unnamed: 0_level_0,692_a_2,693_little_2,694_wild_2,695_a_2,696_odd_2,697_sitting_2,698_a_2,699_last_2,700_a_2,701_a_2,702_the_2,703_rubber_2,704_black_2,705_dead_2,706_a_2,...,1074_the_1,1075_the_1,1076_of_1,1077_crispy_1,1078_her_1,1079_whole_1,1080_and_1,1081_of_1,1082_roasted_1,1083_fried_1,1084_roast_1,1085_or_1,1086_the_1,"1087_,_1","1088_,_1"
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1
692_a_2,0.00,5.16,5.54,4.70,4.89,5.69,4.04,5.23,4.97,4.36,5.38,4.44,5.45,4.60,4.68,...,4.14,5.14,5.55,4.90,4.36,5.28,5.67,5.43,4.94,4.63,4.78,4.81,5.25,5.48,4.83
693_little_2,5.16,0.00,4.32,4.18,3.95,4.97,3.99,4.14,4.57,4.31,4.65,4.12,4.04,3.65,4.14,...,4.44,4.09,4.66,3.83,4.48,3.84,5.58,4.92,4.21,4.04,4.29,5.03,4.38,5.66,5.09
694_wild_2,5.54,4.32,0.00,4.83,4.82,5.44,4.62,4.58,4.88,4.64,4.54,4.92,4.57,4.23,4.74,...,5.11,4.82,4.69,4.50,5.27,4.39,5.91,5.06,4.72,4.96,4.77,4.93,4.45,5.77,5.00
695_a_2,4.70,4.18,4.83,0.00,4.03,4.55,3.37,4.33,3.57,3.41,4.76,4.05,4.66,3.36,2.50,...,4.81,4.55,4.56,4.14,4.63,4.21,6.04,4.87,4.48,4.70,4.79,5.11,4.52,5.51,5.12
696_odd_2,4.89,3.95,4.82,4.03,0.00,4.10,3.28,4.55,4.34,3.49,4.75,4.16,4.45,3.39,3.53,...,4.70,4.62,4.77,3.77,4.49,4.40,6.35,5.39,4.59,4.77,4.84,4.94,4.75,6.02,5.34
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1084_roast_1,4.78,4.29,4.77,4.79,4.84,5.45,4.33,5.16,5.30,4.69,5.26,4.46,5.01,4.56,4.87,...,4.64,4.50,5.46,4.49,4.85,5.03,5.53,5.01,4.48,4.14,0.00,4.75,4.55,5.38,4.57
1085_or_1,4.81,5.03,4.93,5.11,4.94,5.75,4.49,5.43,5.50,4.98,5.62,5.19,5.30,4.71,4.99,...,4.78,5.32,5.14,4.54,4.87,5.13,4.96,4.35,4.77,4.88,4.75,0.00,4.80,4.75,4.78
1086_the_1,5.25,4.38,4.45,4.52,4.75,5.69,4.68,5.06,5.03,4.64,4.44,4.82,4.64,4.57,4.36,...,4.38,4.58,4.93,4.48,4.55,4.63,5.50,4.14,4.63,4.48,4.55,4.80,0.00,5.41,5.03
"1087_,_1",5.48,5.66,5.77,5.51,6.02,6.25,5.39,5.96,5.94,5.79,6.03,5.73,6.26,5.65,5.75,...,5.39,5.86,5.60,5.36,5.48,5.76,5.38,4.46,5.64,5.18,5.38,4.75,5.41,0.00,4.02


In [302]:
# Heatmap mask: 1 on the upper triangle (diagonal included), 0 below it, so that
# only the lower triangle of the distance matrix is drawn.
mask1 = np.triu(np.ones_like(wh))

In [303]:
# Draw the masked distance matrix on an explicitly created axes.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(wh, vmin=2, vmax=10, mask=mask1, ax=ax)
ax.set_title(f'heat within {3}')
plt.show()

# Single word experiment

In [344]:
# Word studied in the cells below.
target = 'moves'
# For each embedding type, keep the rows of the normalised table whose target is
# `target`, dropping rows that have a missing value in any column.
# Note: `t` and `df` stay bound at notebook scope after the loop.
four_types_target_embeddings = {}
for t in types:
    df = l2_normalised_embed_dict[t]
    four_types_target_embeddings[t] = df[df['target']==target].dropna()

In [345]:
# File-to-file distances of the selected word for each embedding type. The row is
# looked up through `target` so this cell follows the word chosen above instead of
# a second, hard-coded copy of it.
wind_dist = four_types_dist_all_words.loc[target]
wind_dist

file_dist_c0    1.12
file_dist_h0    2.49
file_dist_c1    1.52
file_dist_h1    2.06
Name: moves, dtype: float64

In [352]:
# One pairwise distance matrix per embedding type, in the order of `types`;
# the plotting cell below indexes this list positionally.
heat_data_list = []
for i in types:
    heat_data = all_mod_heat(four_types_target_embeddings[i],[target])
    heat_data_list.append(heat_data)

100%|██████████████████████████████████████████| 39/39 [00:00<00:00, 217.85it/s]
100%|██████████████████████████████████████████| 39/39 [00:00<00:00, 213.64it/s]
100%|██████████████████████████████████████████| 39/39 [00:00<00:00, 215.37it/s]
100%|██████████████████████████████████████████| 39/39 [00:00<00:00, 214.98it/s]


In [349]:
# heat_data_list[0]

In [348]:
# Shared mask hiding the upper triangle (diagonal included); it is sized from the
# first matrix and reused for all four panels.
mask_d = np.triu(np.ones_like(heat_data_list[0]))
f, axes = plt.subplots(2, 2, figsize=(6, 6))
for panel_no, panel_label in enumerate(('c0', 'h0', 'c1', 'h1')):
    panel_ax = axes[panel_no // 2][panel_no % 2]
    panel_ax.set_title(f'{target} {panel_label}')
    # Only the first panel keeps its row labels; the others suppress them.
    panel_kwargs = {} if panel_no == 0 else {'yticklabels': False}
    sns.heatmap(heat_data_list[panel_no], vmin=1, vmax=4, xticklabels=False,
                mask=mask_d, ax=panel_ax, **panel_kwargs)
plt.show()