### Assignment 6
    Akhil Kanna Devarashetti

Question:

    Write a PyTorch version of the word2vec skip-gram model presented in Chapter 14 of d2l.
    In particular, make DL24.py error-free.
    Implement get_similar_tokens as an application of the word embedding model
    (Section 14.4.3 of d2l and also the last slide of the lecture).


In [None]:
import collections
import random
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim

In [None]:
# Read the whole PTB training split into memory. The context manager closes
# the file handle as soon as the read finishes instead of leaving it to the
# garbage collector.
with open('ptb/ptb.train.txt', "r") as ptb_file:
    raw_text = ptb_file.read()
# Uncomment to debug on a tiny prefix of the corpus:
#raw_text = raw_text[:600]

In [None]:
## Split the raw text into sentences (one per line) and flatten into tokens
sentences = [row.split() for row in raw_text.split('\n')]
tokens = [word for row in sentences for word in row]
num_tokens = len(tokens)

# Count occurrences and keep the tokens whose count passes the threshold.
# The threshold is 0, so every token passes; the original note suggests a
# threshold of 10 was intended.
counter = collections.Counter(tokens)
uniq_tokens = [word for word, freq in counter.items() if freq >= 0]

# Lookup tables in both directions: index -> token (list) and
# token -> index (dict).
idx_to_token = uniq_tokens
token_to_idx = {word: position for position, word in enumerate(uniq_tokens)}

In [None]:
# Subsample the corpus: every token occurrence is kept with probability
# (num_tokens / 10**order_of_magnitude) / count(token), so frequent words are
# dropped far more often than rare ones.
# NOTE(review): d2l keeps a token with probability sqrt(1e-4 / f(w)); this
# rule is much more aggressive -- confirm it is intended.

# Both values are loop-invariant, so compute them once instead of once per
# token. order_of_magnitude stays a module-level name because later cells
# use it to size the negative-sampling candidate pool.
order_of_magnitude = round(math.log10(num_tokens))
keep_scale = num_tokens / (10 ** order_of_magnitude)

subsampled = []
for line in sentences:
    # One uniform draw per token, in corpus order (same RNG usage as before).
    subsampled.append(
        [token for token in line
         if random.uniform(0, 1) < keep_scale / counter[token]])


In [None]:
# Map the subsampled sentences from token strings to vocabulary indices.
corpus = [[token_to_idx.get(tk) for tk in line] for line in subsampled]

# From here on `tokens` and `counter` refer to token *indices* of the
# subsampled corpus (they previously held the raw token strings).
tokens = [tk for line in corpus for tk in line]
counter = collections.Counter(tokens)

# Negative-sampling weights: count ** 0.75 for every vocabulary index.
# The list must span the whole vocabulary, not len(counter): subsampling can
# remove every occurrence of a token, which makes len(counter) smaller than
# the vocabulary and would cut off the weights of all higher indices.
# Indices absent from the counter get weight 0 (Counter returns 0).
sampling_weights = [counter[i]**0.75 for i in range(len(idx_to_token))]
population = list(range(len(sampling_weights)))

# Pre-draw a large pool of negative candidates that is consumed cyclically
# when the training examples are built.
candidates = random.choices(population, sampling_weights, k=(10**order_of_magnitude))

In [None]:
# Build the skip-gram training examples: for every position in every
# sentence, collect the context words inside a randomly sized window and draw
# K negative samples per context word from the pre-drawn candidate pool.
max_window_size = 5
K = 5            # negatives drawn per context word
j = 0            # read cursor into `candidates`, wraps around at the end
data = []        # rows of [center, context, negatives]
for line in corpus:
    # A sentence needs at least two tokens to form a (center, context) pair.
    if len(line) < 2:
        continue
    for i, center_word in enumerate(line):
        window_size = random.randint(1, max_window_size)
        lo = max(0, i - window_size)
        hi = min(len(line), i + 1 + window_size)
        # Context = every position in the window except the center itself.
        # Built once (the previous code rebuilt the same list once per
        # context index).
        context = [line[pos] for pos in range(lo, hi) if pos != i]
        context_ids = set(context)  # O(1) membership test for the filter below
        neg = []
        while len(neg) < len(context) * K:
            ne = candidates[j]
            j = (j + 1) % len(candidates)
            # A word that is a true context of this center is not a negative.
            if ne not in context_ids:
                neg.append(ne)
        data.append([center_word, context, neg])

In [None]:
# The widest context+negative list sets the padded row width.
max_len = max(len(ctx) + len(negs) for _, ctx, negs in data)

# Three parallel lists, one entry per training example:
#   centers            - the center word index
#   contexts_negatives - context indices, then negatives, zero-padded to max_len
#   labels             - 1 for each context position, 0 everywhere else
centers, contexts_negatives, labels = [], [], []
for center, context, negative in data:
    padding = max_len - (len(context) + len(negative))
    centers.append(center)
    contexts_negatives.append(context + negative + [0] * padding)
    labels.append([1] * len(context) + [0] * (max_len - len(context)))

In [None]:
class PTBdataset(torch.utils.data.Dataset):
    """Map-style dataset of padded skip-gram training examples.

    Indexing returns a ``(center, contexts_negatives, labels)`` triple of
    NumPy values taken from the same row of the three stored arrays.

    Args:
        center_ids: One center-word index per example. Defaults to the
            module-level ``centers`` list.
        context_negative_ids: One padded row of context and negative indices
            per example. Defaults to the module-level ``contexts_negatives``.
        label_rows: One padded row of 0/1 labels per example. Defaults to the
            module-level ``labels``.
    """

    def __init__(self, center_ids=None, context_negative_ids=None, label_rows=None):
        # Zero-argument super(): the previous `super(PTBdataset).__init__()`
        # built an unbound super object, so the parent initializer never ran.
        super().__init__()
        # Fall back to the notebook's module-level lists so that the existing
        # `PTBdataset()` call keeps working unchanged.
        if center_ids is None:
            center_ids = centers
        if context_negative_ids is None:
            context_negative_ids = contexts_negatives
        if label_rows is None:
            label_rows = labels
        # Centers are stored as a column so that a batch has shape (batch, 1).
        self.centers = np.array(center_ids).reshape(-1, 1)
        self.contexts_negatives = np.array(context_negative_ids)
        self.labels = np.array(label_rows)

    def __len__(self):
        """Return the number of training examples."""
        return len(self.centers)

    def __getitem__(self, idx):
        """Return the (center, contexts_negatives, labels) row at ``idx``."""
        return self.centers[idx], self.contexts_negatives[idx], self.labels[idx]

In [None]:
# Wrap the padded examples in a shuffling mini-batch loader.
pdata = PTBdataset()
data_iter = torch.utils.data.DataLoader(dataset=pdata, batch_size=512, shuffle=True)

# Embedding table dimensions: one row per vocabulary entry, 100 columns.
embed_size = 100
vocab_size = len(idx_to_token)

In [None]:
# Two embedding tables held in one container so that a single optimizer sees
# all parameters. The training loop indexes them separately: net[0] is
# applied to center words and net[1] to context/negative words.
net = nn.Sequential(
    nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size),
    nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size),
)

# Sigmoid turns the raw dot products into probabilities for BCELoss.
m = nn.Sigmoid()
loss = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

In [None]:
# Train the two embedding tables with binary cross-entropy over
# (center, context-or-negative) pairs.
# NOTE(review): padded positions carry index 0 and label 0 and are not
# masked, so they are trained as negatives of token 0 -- confirm this is
# acceptable or add a mask.
for epoch in range(500):
    for i, batch in enumerate(data_iter):
        center, context_negative, label = batch
        v = net[0](center.to(torch.int64))            # (batch, 1, embed)
        u = net[1](context_negative.to(torch.int64))  # (batch, max_len, embed)
        # Batched dot product of each center vector with its own
        # context/negative vectors: (batch, 1, embed) x (batch, embed, max_len)
        # -> (batch, 1, max_len). tensordot with its default dims=2 contracts
        # the wrong axes here. The singleton axis is squeezed so that the
        # prediction shape matches `label`, as BCELoss requires.
        pred = torch.bmm(v, u.transpose(1, 2)).squeeze(1)
        l = loss(m(pred), label.to(torch.float32))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        # Report the per-batch loss on every 50th epoch only.
        if (epoch + 1) % 50 == 0:
            print(epoch, i, float(l))

In [None]:
# Derived from the code in the link: https://d2l.ai/chapter_natural-language-processing-pretraining/similarity-analogy.html#finding-synonyms

def knn(W, x, k):
    """Return the k rows of W most cosine-similar to x.

    Args:
        W: 2-D tensor whose rows are the candidate vectors.
        x: Query tensor; it is flattened to 1-D before the product.
        k: Number of neighbours to return.

    Returns:
        A pair ``(indices, similarities)``: a NumPy array with the row
        indices of the top-k matches (best first) and a list with their
        cosine similarities as Python floats.
    """
    dot_products = W @ x.reshape(-1,)
    # The small epsilon keeps an all-zero row of W from dividing by zero.
    row_norms = ((W * W).sum(dim=1) + 1e-9).sqrt()
    query_norm = (x * x).sum().sqrt()
    cos = dot_products / (row_norms * query_norm)
    best_values, best_indices = torch.topk(cos, k=k)
    return best_indices.cpu().numpy(), best_values.tolist()

def get_similar_tokens(query_token, k, embed):
    """Print the k tokens whose embeddings are most cosine-similar to a query.

    Args:
        query_token: Token string; it must be a key of ``token_to_idx``.
        k: Number of neighbours to print.
        embed: Embedding layer whose ``weight`` rows are the token vectors.
    """
    weights = embed.weight.data
    query_vec = weights[token_to_idx[query_token]]
    # Ask for one extra neighbour and skip the first hit, which is expected
    # to be the query token itself.
    neighbour_ids, similarities = knn(weights, query_vec, k+1)
    for rank in range(1, len(neighbour_ids)):
        print('cosine sim=%.3f: %s' % (similarities[rank], (idx_to_token[neighbour_ids[rank]])))

# Show the 3 nearest neighbours of 'group' in the center-word embedding table.
get_similar_tokens('group', k=3, embed=net[0])

In [None]:
# Tokens to probe against the trained center-word embeddings (net[0]).
words_for_similarity = ['chip']

for word in words_for_similarity:
    print(f"\nSimilarity for '{word}':")
    get_similar_tokens(word, k=3, embed=net[0])
