In [1]:
from collections import Counter
import numpy as np

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

In [3]:
data_path = "data/text8.txt"

In [4]:
with open(data_path) as file:
    data = file.read()

In [5]:
data = data.split(" ")

In [6]:
freq = Counter(data)

In [7]:
freq_threshold = 5

vocab = dict()
count = 0

for key, value in freq.items():
    if value > freq_threshold:
        vocab[key] = count
        count += 1
    
inverse_vocab = {value: key for key, value in vocab.items()}

In [8]:
vocab["the"]

15

In [9]:
context_width = 3

def get_prob_word(context):
    # Need to add negative sampling
    return [1/len(context) for _ in range(len(context))]

def sample_context(data, index):
    # Assume that index > context_width for now
    # Bug when index + context +1 is larger than list
    
    context = data[index-context_width:index+context_width+1]
    context.pop(3)
    context = list(filter(lambda x: x in vocab.keys(), context))
    context = np.random.choice(context, p=get_prob_word(context))
    return vocab[context]

In [10]:
sample_context(data, 10)

10

In [11]:
device = "cuda"

class TextDataSet(Dataset):
    def __init__(self, data, context_width):
        self.data = data
        self.context_width = context_width
        
    def __len__(self):
        return len(data)
    
    def __getitem__(self, idx):
        x = self.data[idx]
        
        if x not in vocab.keys():
            return None, None
        
        x = vocab[x]        
        y = sample_context(data, idx)
        
        return x, y
    
def collate_fn(batch):
    batch = list(filter(lambda z: z[0] is not None, batch))
    
    x = [z[0] for z in batch]
    x = torch.tensor(x).to(device)
    
    y = [z[1] for z in batch]
    y = torch.tensor(y).to(device)
    
    return x, y

In [12]:
batch_size=64
dataset = TextDataSet(data, context_width)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)

In [13]:
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_size)
    
    def forward(self, x):
        return self.embedding_layer(x)

In [14]:
embedding_size = 50
model = Word2Vec(len(vocab), embedding_size).to(device)

In [15]:
num_epochs = 1
lr = 1e-3

optimizer = optim.Adam(model.parameters(), lr=lr)

In [16]:
for epoch in range(num_epochs):
    loss = 0.0
    
    for _, (x, y) in enumerate(train_loader):
        model.zero_grad()
        output = model(x)
        batch_loss = F.nll_loss(output, y)
        
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss
        break
        
    print(loss)

tensor(1.00000e-02 *
       9.9813, device='cuda:0')
