In [1]:
import torch
import torch.utils.data.dataloader as dataloader
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import os
import re
import sys
import gc
import time
import copy
from tqdm import tqdm
from collections import Counter

In [2]:
text = []
for file in os.listdir('Holmes_Training_Data/'):
    with open(os.path.join('Holmes_Training_Data', file), 'r', errors='ignore') as f:
        text.extend(f.read().splitlines())

text = [x.replace('*', '') for x in text]
text = [re.sub('[^ \fA-Za-z0-9_]', '', x) for x in text]
text = [x for x in text if x != '']

In [3]:
raw_text = []
for x in text:
    raw_text.extend(x.split(' '))
raw_text = [x for x in raw_text if x != '']

In [4]:
del text
gc.collect()

0

In [5]:
vocab = set(raw_text)
vocab_size = len(vocab)
freqs = Counter(raw_text)

In [6]:
print(vocab_size)

369970


In [7]:
def make_context_vector(context, word_to_ix):
    idxs = [word_to_ix[w] for w in context]
    return torch.tensor(idxs, dtype=torch.long)


word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {i: word for i, word in enumerate(vocab)}
data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
#     context_ids = make_context_vector(context, word_to_ix)
#     label = torch.tensor([word_to_ix[target]], dtype=torch.long)
    data.append((context, target))
print(data[:5])
del raw_text
del vocab
gc.collect()

[(['The', 'Project', 'Etext', 'of'], 'Gutenberg'), (['Project', 'Gutenberg', 'of', 'Prester'], 'Etext'), (['Gutenberg', 'Etext', 'Prester', 'John'], 'of'), (['Etext', 'of', 'John', 'by'], 'Prester'), (['of', 'Prester', 'by', 'John'], 'John')]


0

In [8]:
def neg_sample(num_samples, positives=[]):
    freqs_pow = torch.Tensor([freqs[ix_to_word[i]] for i in range(vocab_size)]).pow(0.75)
    dist = freqs_pow / freqs_pow.sum()
    w = np.random.choice(len(dist), (len(positives), num_samples), p=dist.numpy())
    if positives.is_cuda:
        return torch.tensor(w).to(device)
    else:
        return torch.tensor(w)

In [9]:
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.embeddings.weight.data.uniform_(-0.5 / vocab_size, 0.5 / vocab_size)
    def forward(self, inputs, label):
        negs = neg_sample(5, label)
        u_embeds = self.embeddings(label).view(len(label), -1)
        v_embeds_pos = self.embeddings(inputs).mean(dim=1)
        v_embeds_neg = self.embeddings(negs).mean(dim=1)
        loss1 = torch.diag(torch.matmul(u_embeds, v_embeds_pos.transpose(0, 1)))
        loss2 = torch.diag(torch.matmul(u_embeds, v_embeds_neg.transpose(0, 1)))
        loss1 = -torch.log(1 / (1 + torch.exp(-loss1)))
        loss2 = -torch.log(1 / (1 + torch.exp(loss2)))
        loss = (loss1.mean() + loss2.mean())
        return(loss)

In [10]:
CONTEXT_SIZE = 2
batch_size = 8192 * 2
device = torch.device('cuda:0')
# device = torch.device('cpu')
losses = []
# loss_function = nn.NLLLoss()
# loss_function = NEGLoss(ix_to_word, word_freqs)
model = CBOW(vocab_size, embedding_dim=100,
             context_size=CONTEXT_SIZE*2)
model.to(device)
# model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.1)

In [11]:
data_iter = torch.utils.data.DataLoader(data, batch_size=batch_size, 
                                        shuffle=False)

In [12]:
for epoch in range(1):
    total_loss = torch.Tensor([0])
    for context, target in tqdm(data_iter):
        context_ids = []
        for i in range(len(context[0])):
            context_ids.append(make_context_vector([context[j][i] for j in range(len(context))], word_to_ix))
        context_ids = torch.stack(context_ids)
        context_ids = context_ids.to(device)
        model.zero_grad()
        label = make_context_vector(target, word_to_ix)
        label = label.to(device)
        loss = model(context_ids, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    losses.append(total_loss)
    print('epoch %d loss %.4f' %(epoch, total_loss))
print(losses)

100%|██████████| 2522/2522 [17:09<00:00,  2.45it/s]
  0%|          | 0/2522 [00:00<?, ?it/s]

epoch 0 loss 3496.2087


100%|██████████| 2522/2522 [17:08<00:00,  2.45it/s]

epoch 1 loss 3496.2087
[tensor([3496.2087]), tensor([3496.2087])]



