In [None]:
from collections import Counter
import numpy as np
import math
import random
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Location of the text8 corpus, relative to the notebook's working directory.
data_path = "data/text8.txt"

In [None]:
# Use the GPU when one is available; fall back to the CPU so the notebook still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
context_width = 3  # number of neighbouring tokens taken on each side of the centre word
seed = 42

In [None]:
max_tokens = 500_000  # use only the first 500k tokens to keep training fast

with open(data_path, encoding="utf-8") as file:
    data = file.read()

# split() with no argument splits on any run of whitespace and drops empty
# strings (e.g. the one a leading space or trailing newline would produce).
data = data.split()
data = data[:max_tokens]

In [None]:
# Keep only words seen at least `freq_threshold` times and give each one a
# dense integer id, assigned in order of first appearance in `data`.
freq_threshold = 5

freq = Counter(data)
frequent_words = [word for word, occurrences in freq.items() if occurrences >= freq_threshold]
vocab = {word: word_id for word_id, word in enumerate(frequent_words)}
count = len(vocab)

# Reverse lookup: integer id -> word.
inverse_vocab = {word_id: word for word, word_id in vocab.items()}

In [None]:
# Drop out-of-vocabulary tokens, then recount frequencies on the filtered corpus.
data = [token for token in data if token in vocab]
freq = Counter(data)

In [None]:
# Relative frequency of every vocabulary word within the filtered corpus.
total_tokens = len(data)
normed_freq = {word: freq[word] / total_tokens for word in vocab}

In [None]:
def prob_dropping(frequency, t=5e-4):
    """Return the subsampling drop probability ``1 - sqrt(t / frequency)``.

    Args:
        frequency: relative frequency of a word in the corpus (must be > 0).
        t: subsampling threshold.

    Returns:
        The drop probability as a float. It is negative when
        ``frequency < t``; this function does not clamp it.
    """
    keep_ratio = math.sqrt(t / frequency)
    return 1.0 - keep_ratio

# The drop probability depends only on the word, so evaluate it once per
# vocabulary entry and look it up per token. Negative values (words rarer than
# the threshold `t`) are clamped to 0, so such words are always kept.
drop_prob_by_word = {
    word: max(0.0, prob_dropping(word_freq)) for word, word_freq in normed_freq.items()
}
prob_drop_word = [drop_prob_by_word[word] for word in data]

In [None]:
# This cell rebinds `data`, so re-running it without re-running the cells above
# would pair tokens with the wrong drop probabilities. Catch that early.
assert len(data) == len(prob_drop_word), "re-run the preceding cells before subsampling again"

# Subsampling: keep each token with probability 1 - p_drop, using exactly one
# random draw per token (no per-token list allocation as with random.choices).
random.seed(seed)
data = [word for word, p_drop in zip(data, prob_drop_word) if random.random() >= p_drop]

In [None]:
class TextDataSet(Dataset):
    """Dataset with one item per token position in ``data``.

    Item ``idx`` pairs the id of ``data[idx]`` with the ids of up to
    ``context_width`` tokens on each side of it (fewer at the corpus edges).

    Args:
        data: sequence of tokens; every token must be a key of ``vocab``.
        vocab: mapping token -> integer id.
        context_width: number of neighbouring tokens taken on each side.
    """

    def __init__(self, data, vocab, context_width):
        self.data = data
        self.context_width = context_width
        self.vocab = vocab
        # Map tokens to ids once here instead of on every __getitem__ call.
        self.token_ids = [vocab[token] for token in data]

    def get_context(self, idx):
        """Return the ids of tokens within ``context_width`` of ``idx``, excluding ``idx`` itself."""
        first_index = max(0, idx - self.context_width)
        last_index = min(len(self.token_ids), idx + self.context_width + 1)
        return self.token_ids[first_index:idx] + self.token_ids[idx + 1:last_index]

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, idx):
        """Return ``(x, y)`` int64 arrays of equal length.

        ``x`` repeats the centre-word id once per context word and ``y`` holds
        the context ids, so ``(x[i], y[i])`` is one training pair.
        """
        y = np.array(self.get_context(idx), dtype=np.int64)
        # Explicit dtype: an empty context must still give an integer array
        # (np.array([]) would default to float64).
        x = np.full(len(y), self.token_ids[idx], dtype=np.int64)
        return x, y

In [None]:
def collate_fn(batch):
    """Merge a list of ``(x, y)`` array pairs into two flat LongTensors on ``device``.

    Each dataset item already holds one (centre, context) pair per position,
    so a batch is simply the concatenation of all items' pairs.
    """
    centres, contexts = zip(*batch)
    x = torch.from_numpy(np.concatenate(centres)).long().to(device)
    y = torch.from_numpy(np.concatenate(contexts)).long().to(device)
    return x, y

In [None]:
batch_size = 256
dataset = TextDataSet(data, vocab, context_width)
# Load in the main process (num_workers=0): collate_fn moves each batch to
# `device` itself.
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

In [None]:
class Word2Vec(nn.Module):
    """Word2vec model: embedding lookup followed by a full-softmax output layer.

    Args:
        vocab_size: number of words (rows of the embedding table and output units).
        embedding_size: dimensionality of the learned word vectors.
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_size)
        self.output_layer = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        """Map integer word ids ``x`` to log-probabilities of shape ``(*x.shape, vocab_size)``."""
        x = self.embedding_layer(x)
        x = self.output_layer(x)
        # The vocabulary axis is always the last one, whatever the batch shape of x.
        return F.log_softmax(x, dim=-1)

In [None]:
# Seed torch before building the model so its initial weights are reproducible.
torch.manual_seed(seed)
embedding_size = 50
vocab_size = len(vocab)
model = Word2Vec(vocab_size, embedding_size).to(device)

In [None]:
# Training hyper-parameters: plain SGD over all model parameters.
num_epochs = 2
lr = 0.1
optimizer = optim.SGD(params=model.parameters(), lr=lr)

In [None]:
def get_closest_words(model, idx, num_closest=5, index_to_word=None):
    """Return the words whose embeddings are most cosine-similar to word ``idx``.

    Args:
        model: model exposing ``embedding_layer.weight`` (one row per word id).
        idx: integer id of the query word.
        num_closest: number of neighbours to return; capped at ``vocab_size - 1``.
        index_to_word: mapping id -> word; defaults to the global ``inverse_vocab``.

    Returns:
        List of ``(word, similarity)`` tuples, most similar first. The query
        word itself is always excluded.
    """
    if index_to_word is None:
        index_to_word = inverse_vocab

    weights = model.embedding_layer.weight.detach().cpu().numpy()
    query = weights[idx]

    # Cosine similarity of every row against the query. Rows with zero norm
    # get similarity 0 instead of a division by zero.
    row_norms = np.linalg.norm(weights, axis=1)
    denom = row_norms * np.linalg.norm(query)
    similarity = np.divide(
        weights @ query, denom, out=np.zeros_like(row_norms), where=denom > 0
    )

    # Exclude the query explicitly rather than assuming it sorts last: ties
    # (e.g. identical vectors) can place another word above it.
    similarity[idx] = -np.inf
    num_closest = max(0, min(num_closest, len(similarity) - 1))
    closest = np.argsort(similarity)[::-1][:num_closest]

    return [(index_to_word[int(i)], float(similarity[i])) for i in closest]

In [None]:
# Word whose nearest neighbours are printed during training to eyeball progress.
example_word = "men"
# Fail fast with a clear message instead of a KeyError inside the training loop.
assert example_word in vocab, f"{example_word!r} is not in the vocabulary"

In [None]:
# Train on (centre, context) id pairs with NLL over the full softmax, printing
# the example word's nearest neighbours before each epoch and after training.
torch.manual_seed(seed)  # reseed torch's global RNG, which the shuffling DataLoader draws from
for epoch in range(1, num_epochs + 1):
    print(f"----- Epoch {epoch} -----")
    print(get_closest_words(model, vocab[example_word]))

    epoch_loss = 0.0
    for x, y in train_loader:
        optimizer.zero_grad()
        output = model(x)
        batch_loss = F.nll_loss(output, y)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()

    # max(1, ...) avoids a ZeroDivisionError if the loader is empty.
    print("Loss : {}".format(epoch_loss / max(1, len(train_loader))))

# Neighbours of the final, fully trained embeddings.
print("----- After training -----")
print(get_closest_words(model, vocab[example_word]))