In [4]:
import torch
from models import skipgram
import pandas as pd
import torch.nn.functional as F
import wandb

run = wandb.init(project="word2vec")
artifact = run.use_artifact("model-weights:latest")
datadir = artifact.download()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [5]:
vocab = pd.read_csv('./data/vocab.generated.csv', index_col='token')

def getIdFromToken(token: str):
    try:
        return int(vocab.at[token, 'id'])
    except:
        return int(vocab.at['unk', 'id'])

def getTokenFromId(id: float):
    return vocab[vocab['id'] == id].index

In [6]:

model = skipgram.Model(vocab.size + 1, skipgram.EMBEDDING_DIM)

state_dict = torch.load('./data/weights.generated.pt')
# model.load_state_dict(weights)
# model.eval()

print(state_dict)


OrderedDict({'embeddings.weight': tensor([[ 1.9269,  1.4873,  0.9007,  ..., -0.4879, -0.9138, -0.6581],
        [ 0.0780,  0.5258, -0.4880,  ...,  0.4880,  0.7846,  0.0286],
        [ 0.6531,  0.6823,  1.3485,  ...,  2.0001,  0.7429, -0.7462],
        ...,
        [-1.0039,  0.4795,  0.2109,  ..., -0.1502, -1.1138, -0.6855],
        [-0.5354, -1.8264,  1.7771,  ...,  0.0122,  0.5465, -1.7636],
        [ 0.1913,  1.1691,  0.8277,  ..., -1.0659, -0.4108, -0.3927]]), 'linear.weight': tensor([[-0.0745, -0.0141,  0.0439,  ...,  0.0650,  0.1096, -0.1003],
        [ 0.0433,  0.0672,  0.3224,  ...,  0.2404,  0.1880, -0.3250],
        [-0.1883,  0.2614,  0.3691,  ...,  0.3557,  0.2716,  0.1511],
        ...,
        [ 0.0066, -0.4087,  0.1351,  ..., -0.2987, -0.0707, -0.2673],
        [-0.1029,  0.2545,  0.1688,  ..., -0.0226, -0.0985, -0.0734],
        [-0.3080, -0.1655,  0.1879,  ...,  0.0040,  0.2891, -0.3028]]), 'linear.bias': tensor([0.1127, 0.1187, 0.0247,  ..., 0.0525, 0.1041, 0.0242])})

In [8]:
#find the n most similar words
def most_similar(word, n=5):
    word_idx = getIdFromToken(word)
    A =  state_dict['embeddings.weight'][word_idx].unsqueeze(0)
    word_similarities = []
    for i in range(len(vocab)):
        B =  model.embeddings.weight[i].unsqueeze(0)
        cosine_similarity = F.cosine_similarity(A, B, dim=1)
        word_similarities.append((getTokenFromId(i), cosine_similarity.item()))
    word_similarities = sorted(word_similarities, key=lambda x: x[1], reverse=True)
    return word_similarities[:n]

most_similar('dog')

[(Index(['1520'], dtype='object', name='token'), 0.5886226892471313),
 (Index(['yutu'], dtype='object', name='token'), 0.5071915984153748),
 (Index(['alarm'], dtype='object', name='token'), 0.5047330856323242),
 (Index(['susan'], dtype='object', name='token'), 0.5008740425109863),
 (Index(['modernist'], dtype='object', name='token'), 0.49624010920524597)]