In [5]:
import torch
import torch.nn as nn

EMBED_DIMENSION = 300  # width of every learned word vector
EMBED_MAX_NORM = 1  # nn.Embedding rescales looked-up rows whose L2 norm exceeds this


class SkipGram_Model(nn.Module):
    """Skip-gram network: an embedding lookup followed by a linear projection
    that scores every vocabulary entry.

    Args:
        vocab_size: number of distinct tokens; sets both the height of the
            embedding table and the number of output scores.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        # Registration order matters: the embedding table is registered first,
        # so it is the first tensor yielded by ``self.parameters()``.
        self.embeddings = nn.Embedding(vocab_size, EMBED_DIMENSION, max_norm=EMBED_MAX_NORM)
        self.linear = nn.Linear(EMBED_DIMENSION, vocab_size)

    def forward(self, inputs_):
        """Map token ids to unnormalized scores over the whole vocabulary."""
        return self.linear(self.embeddings(inputs_))

# NOTE(review): the checkpoint is a pickled nn.Module (not just a state_dict),
# and unpickling can execute arbitrary code — only load files you trust.
# weights_only=False: torch>=2.6 defaults to weights_only=True, which refuses to
# unpickle custom classes such as SkipGram_Model (the keyword needs torch>=1.13).
# map_location="cpu": everything downstream works on CPU/numpy, and this lets a
# checkpoint saved during a GPU run load on a CPU-only machine.
model = torch.load('my_model5.pth', map_location="cpu", weights_only=False)

# get vocab from its file.json: maps each word to its embedding row index.
import json
# JSON text is UTF-8 by spec; don't rely on the platform default encoding.
with open('vocab.json', encoding='utf-8') as f:
    vocab = json.load(f)

In [6]:
import numpy as np
# Embedding table of the loaded model (the nn.Embedding weight), detached from
# autograd and copied to a numpy array.
embeddings = model.embeddings.weight.detach().cpu().numpy()

# Row-wise L2 normalization, so that a dot product between rows of
# embeddings_norm is their cosine similarity.
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
# An all-zero row would otherwise produce 0/0 = NaN and poison every ranking
# computed from embeddings_norm; leave such rows as zero vectors instead.
embeddings_norm = embeddings / np.where(norms == 0, 1, norms)


In [7]:
def getWordByID(id):
    """Reverse lookup in the global ``vocab`` mapping (word -> index).

    Scans the whole dict, so each call costs O(len(vocab)).

    Returns:
        The first word whose index equals ``id``, or the string
        "Word not found" if no entry has that index.
    """
    return next((word for word, idx in vocab.items() if idx == id), "Word not found")


def get_top_similar(word: str, topN: int = 5):
    """Find the ``topN`` words whose embeddings are closest to ``word``.

    Uses the global ``vocab`` (word -> row index) and the row-normalized matrix
    ``embeddings_norm``, so a dot product between rows is cosine similarity.

    Args:
        word: query word.
        topN: number of neighbours to return.

    Returns:
        ``(word, {neighbour: similarity, ...})`` ordered from most to least
        similar, or the string "Out of vocabulary word" when ``word`` is not in
        ``vocab`` or maps to index 0. The two cases used to differ: index 0
        printed the message and returned None.
    """
    # Index 0 is treated as out-of-vocabulary too (presumably the unknown-token
    # slot — confirm against the vocabulary builder).
    if word not in vocab or vocab[word] == 0:
        return "Out of vocabulary word"
    word_id = vocab[word]

    # Cosine similarity of every vocabulary row against the query row.
    dists = embeddings_norm @ embeddings_norm[word_id]

    # Rank from most to least similar (stable, so ties keep index order) and
    # drop the query word by id. Assuming it sits in slot 0 breaks when another
    # word ties with it (e.g. a duplicate vector) or rounding reorders the top.
    ranked = np.argsort(-dists, kind="stable")
    topN_ids = ranked[ranked != word_id][:topN]

    topN_dict = {getWordByID(sim_word_id): dists[sim_word_id] for sim_word_id in topN_ids}
    return word, topN_dict

In [8]:
# Probe words, including deliberate misspellings/truncations seen in the corpus.
listOfChec = [
    "father", "kids", "beautiful", "cant", "most", "bihday",
    "acti", "pretty", "today", "down", "intelligent",
]

# Show each probe's nearest neighbours (or the out-of-vocabulary message).
for word in listOfChec:
    print(get_top_similar(word))


('father', {'dad': 0.9047119, 'papa': 0.8454022, 'fathersday': 0.8272533, 'daddy': 0.7851294, 'fathers': 0.78254426})
('kids', {'babies': 0.75217015, 'parents': 0.7090419, 'faces': 0.6878179, 'eyes': 0.6774586, 'customers': 0.6727046})
('beautiful', {'lovely': 0.8039284, 'friend': 0.79314685, 'singing': 0.7827179, 'gorgeous': 0.78039616, 'planned': 0.7487232})
('cant', {'cannot': 0.9375809, 'ugh': 0.7699217, 'see': 0.76709133, 'either': 0.6912188, 'watch': 0.6822595})
('most', {'greatest': 0.75844485, 'best': 0.7475693, 'spirit': 0.72895867, 'pageant': 0.7234706, 'candidates': 0.7214298})
('bihday', {'bday': 0.9237944, 'anniversary': 0.85013413, 'cake': 0.8446919, 'celebrating': 0.81943274, 'rg': 0.81838214})
Out of vocabulary word
('pretty', {'so': 0.8132702, 'boring': 0.73079556, 'too': 0.7271501, 'glad': 0.72213, 'ok': 0.7192018})
('today', {'tomorrow': 0.7445729, 'thursday': 0.72951555, 'sunny': 0.7254831, 'training': 0.7179226, 'staing': 0.71141094})
('down', {'musical': 0.6605348

In [9]:
def getSimilarity(word1, word2):
    """Cosine similarity between the embeddings of two words.

    Reads the global ``vocab`` (word -> row index) and the raw embedding
    matrix ``embeddings``.

    Returns:
        The cosine similarity as a numpy scalar, or ``0`` when either word is
        missing from ``vocab`` or has an all-zero embedding. Because ``0`` is
        also the "no score" value, a genuine similarity of exactly 0 cannot be
        told apart from a missing word.
    """
    if word1 not in vocab or word2 not in vocab:
        return 0
    vec1 = embeddings[vocab[word1]]
    vec2 = embeddings[vocab[word2]]
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    # A zero vector has no direction; returning 0 avoids 0/0 = NaN, which
    # would otherwise propagate into any aggregate built from these scores.
    if norm_product == 0:
        return 0
    return np.dot(vec1, vec2) / norm_product

In [10]:
# Sanity check: cosine similarity of two antonym-ish words.
getSimilarity(word1="sad", word2="funny")

0.7285083

In [11]:
# Gold pairs: each data row is whitespace-separated, with word1 and word2 in the
# first two columns and the gold similarity score in the fourth (index 3).
# The header row starts with "word1".
dataGiven = []
with open("SimLex-999/SimLex-999.txt", "r", encoding="utf-8") as f:
    for line in f:
        lineLis = line.split()
        # Skip blank lines (e.g. a trailing newline) as well as the header row.
        if not lineLis or lineLis[0] == "word1":
            continue
        dataGiven.append([lineLis[0], lineLis[1], float(lineLis[3])])

# Model scores for the same pairs: cosine similarity scaled by 10 to be roughly
# comparable with the gold scores (cosine can be negative, so the scaled value
# spans -10..10). getSimilarity returns 0 for out-of-vocabulary pairs.
dataGot = [[w1, w2, getSimilarity(w1, w2) * 10] for w1, w2, _ in dataGiven]

# List every scored pair as: predicted, gold, word1, word2; then the count.
num = 0
for got, given in zip(dataGot, dataGiven):
    if got[2] != 0:
        print(got[2], given[2], got[0], got[1])
        num += 1
print(num)

3.744574785232544 1.58 old new
6.310285329818726 8.77 hard difficult
6.845147609710693 0.95 hard easy
5.428016185760498 9.17 happy glad
7.355071306228638 9.58 stupid dumb
5.000467300415039 8.42 bad awful
6.3031744956970215 0.58 easy difficult
4.880278706550598 7.78 bad terrible
5.298185348510742 1.38 hard simple
4.094451069831848 9.57 insane crazy
3.523419499397278 0.95 happy mad
5.22808313369751 9.47 large huge
4.91813063621521 8.05 hard tough
4.935943186283112 6.83 new fresh
4.063326716423035 1.28 happy angry
4.803216457366943 9.4 simple easy
3.3346158266067505 0.87 old fresh
5.5012595653533936 0.72 weird normal
3.600437641143799 9.2 weird odd
7.285082936286926 0.95 sad funny
8.08781385421753 8.05 wonderful great
7.548066973686218 6.38 guilty ashamed
7.263550162315369 6.5 beautiful wonderful
2.9763615131378174 8.27 confident sure
5.14740526676178 9.55 large big
3.6645162105560303 3.17 strong proud
3.498377799987793 0.75 dumb intelligent
3.9724135398864746 0.35 bad great
4.70750331878

In [12]:
# (predicted, gold) for every pair the model scored; 0 is getSimilarity's
# "missing word" value, so those pairs are excluded.
scored = [(got[2], given[2]) for got, given in zip(dataGot, dataGiven) if got[2] != 0]
num = len(scored)
if num == 0:
    raise ValueError("no SimLex pair has both words in the vocabulary")

# Signed mean error: positive and negative errors cancel, so abs(ME) can look
# small even when individual predictions are far off.
ME = sum(pred - gold for pred, gold in scored) / num
# Mean absolute error does not cancel and is the more informative magnitude.
MAE = sum(abs(pred - gold) for pred, gold in scored) / num
print(abs(ME), num)
print(MAE)

# Agreement rate: fraction of pairs where prediction and gold fall on the same
# side of `thresh`. A value exactly at the threshold counts as "not above" on
# both sides, whereas a strict >/< pair would count it as a disagreement.
# NOTE(review): on a ~0-10 scale a 0.5 cut-off leaves very few pairs below it,
# so this rate is dominated by the majority side — confirm the intended threshold.
thresh = 0.5
numData = num
num = sum(1 for pred, gold in scored if (pred > thresh) == (gold > thresh))
print(num / numData)



1.0615117245286196 295
0.9593220338983051
