In [1]:
from collections import Counter
import numpy as np

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

In [3]:
data_path = "data/text8.txt"

In [4]:
with open(data_path) as file:
    data = file.read()

data = data.split(" ")

In [5]:
freq = Counter(data)

In [6]:
freq_threshold = 5

vocab = dict()
count = 0

for key, value in freq.items():
    if value > freq_threshold:
        vocab[key] = count
        count += 1
    
inverse_vocab = {value: key for key, value in vocab.items()}

In [7]:
normed_freq = {key: freq[key]/len(data) for (key, value) in vocab.items()}

In [8]:
def get_prob_word(context):
    # Need to add negative sampling
    return [1/len(context) for _ in range(len(context))]

In [9]:
device = "cuda"
context_width = 3

class TextDataSet(Dataset):
    def __init__(self, data, context_width):
        self.data = data
        self.context_width = context_width
        
    def sample_context(self, idx):    
        first_index = max(0, idx - self.context_width)
        last_index = min(len(self.data), idx + self.context_width + 1)

        context = self.data[first_index:idx] + self.data[idx+1:last_index]
        context = list(filter(lambda x: x in vocab.keys(), context))

        if len(context) == 0:
            return None

        context = np.random.choice(context, p=get_prob_word(context))
        return vocab[context]
        
    def __len__(self):
        return len(data)
    
    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.sample_context(idx)
        
        if x not in vocab.keys() or y is None:
            return None, None
        
        x = vocab[x]
        
        return x, y

In [10]:
def collate_fn(batch):
    batch = list(filter(lambda z: z[0] is not None, batch))
    
    x = [z[0] for z in batch]
    x = torch.tensor(x).to(device)
    
    y = [z[1] for z in batch]
    y = torch.tensor(y).to(device)
    
    return x, y

In [11]:
batch_size = 64
dataset = TextDataSet(data, context_width)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)

In [12]:
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_size, sparse=True)
        self.output_layer = nn.Linear(embedding_size, vocab_size)
    
    def forward(self, x):
        x = self.embedding_layer(x)
        #x = F.relu(x)
        x = self.output_layer(x)
        return x

In [13]:
embedding_size = 100
model = Word2Vec(len(vocab), embedding_size).to(device)

In [14]:
num_epochs = 1
lr = 1e-3
momentum = 0.9

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [15]:
for epoch in range(num_epochs):
    loss = 0.0
    
    for _, (x, y) in enumerate(train_loader):
        model.zero_grad()
        output = model(x)
        batch_loss = F.nll_loss(output, y)
        
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss
        break
        
    print(loss)

tensor(1.00000e-02 *
       4.8362, device='cuda:0')
