# Analyze Word2Vec Embeddings

In [1]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from word2vec.train import Word2Vec_SkipGram

In [2]:
# Parameters
# viz_words: how many leading vocabulary rows are projected and plotted.
viz_words = 600
# Locations of the pickled vocabulary and of the trained model checkpoint.
vocab_path_str = "./.out/vocab.obj"
model_path_str = "./.model/best_model.ckpt"

In [3]:
# Preprocess params: wrap both string parameters in Path objects in one step
model_path, vocab_path = map(Path, (model_path_str, vocab_path_str))

In [4]:
# Load the model
# Restore the Word2Vec_SkipGram wrapper from the checkpoint file, then keep a
# handle on the network stored in its `model` attribute.
# NOTE(review): presumably a Lightning module, given `load_from_checkpoint`
# and the `.ckpt` path - confirm against word2vec.train.
model_pl = Word2Vec_SkipGram.load_from_checkpoint(model_path)
model = model_pl.model
model  # display model

SkipGramModel(
  (in_embed): Embedding(12928, 300)
  (out_embed): Embedding(12928, 300)
)

In [5]:
# Load vocabulary
# NOTE: unpickling can execute arbitrary code; only point vocab_path at a
# file you produced yourself.
with vocab_path.open("rb") as fvocab:
    vocab = pickle.load(fvocab)

# Index-to-token sequence; sliced later to label the plotted points.
itos = vocab.get_itos()
# Preview the tokens stored at indices 0 through 9.
vocab.lookup_tokens(list(range(10)))

['<unk>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a', '=']

In [6]:
# Expose the input-embedding weight matrix as a NumPy array.
# `detach()` is used instead of the legacy `.data` attribute: it yields the
# same values cut off from the autograd graph, but in-place changes made
# through it are still tracked by autograd's correctness checks.
embeddings = model.in_embed.weight.detach().cpu().numpy()
embeddings.shape

(12928, 300)

In [7]:
# Project the first `viz_words` embedding rows down to two dimensions.
# A fixed random_state makes the layout reproducible from run to run
# (t-SNE is stochastic otherwise); n_components=2 is spelled out because the
# plotting cell reads exactly columns 0 and 1 of the result.
tsne = TSNE(n_components=2, random_state=42)
embed_tsne = tsne.fit_transform(embeddings[:viz_words, :])

In [8]:
# Static matplotlib rendering of the t-SNE scatter, kept commented out.
# NOTE(review): presumably superseded by the interactive Plotly cell - confirm
# before deleting.
# fig, ax = plt.subplots(figsize=(16, 16))
# for idx in range(viz_words):
#     plt.scatter(*embed_tsne[idx, :], color='steelblue')
#     plt.annotate(itos[idx], (embed_tsne[idx, 0], embed_tsne[idx, 1]), alpha=0.7)

In [9]:
import plotly.express as px
import pandas as pd

# Gather the 2-D t-SNE coordinates and the matching tokens into one table,
# one row per plotted word.
df = pd.DataFrame(
    data={
        "x": embed_tsne[:, 0].tolist(),
        "y": embed_tsne[:, 1].tolist(),
        "token": itos[:viz_words],
    }
)
# hover_data takes a list (or dict) of column names; a set literal is not a
# documented input type, so the column is passed as a one-element list.
fig = px.scatter(
    df, x="x", y="y", hover_data=["token"], width=1200, height=1200
)
fig.show()