In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import random
import numpy as np

from tqdm import tqdm, tqdm_notebook

from collections import Counter

from poutyne.framework import Model, Experiment, OptimizerPolicy, sgdr_phases

In [2]:
# One seed shared by every random number generator the notebook imports, so a
# fresh "Restart & Run All" starts from the same random state each time.
SEED = 1010101011
torch.manual_seed(SEED)
random.seed(SEED)
# NumPy is imported above but was never seeded; the value fits NumPy's
# 32-bit seed range.
np.random.seed(SEED)

In [3]:
# Path of the text corpus that DS reads line by line (each line is treated as
# one sentence). The path is relative to the notebook's working directory.
dataset = "../data/postprocessed/ds_blogs.txt"
# Alternative corpus path kept for quick switching (presumably a small test
# file - verify before use):
# dataset = "testds.txt"

In [4]:
class DS(torch.utils.data.Dataset):
    """Map-style dataset of CBOW training pairs built from a text corpus.

    The corpus file is read line by line; every line is treated as one
    sentence, lower-cased and split on whitespace. Each sample is a pair
    ``(context, target)`` where ``target`` is the index of a token and
    ``context`` holds the indices of the ``window_size // 2`` tokens on each
    side of it, padded with ``"<null>"`` (index 0) at the sentence borders.

    This is a dataset, not a loader, so it derives from ``Dataset`` and is
    meant to be wrapped in a ``DataLoader``.

    Args:
        corpus_path: path of the corpus file, one sentence per line.
        window_size: nominal window width. The context actually holds
            ``2 * (window_size // 2)`` tokens, which equals
            ``window_size - 1`` only for odd window sizes.
        min_occurences: minimum number of occurrences a token needs in order
            to enter the vocabulary.

    Attributes:
        counter: occurrence count of every lower-cased token in the corpus.
        vocab: set of tokens that reached ``min_occurences``.
        token2idx: token -> integer index; ``"<null>"`` is always index 0 and
            the remaining tokens are numbered in sorted order.
        sentences_len: number of lines read from the corpus.
        ds_x: list of context index lists, one per sample.
        ds_y: list of target indices, aligned with ``ds_x``.
    """

    def __init__(self, corpus_path, window_size, min_occurences):
        self.corpus_path = corpus_path

        self.counter = Counter()
        self.token2idx = {
            "<null>": 0
        }

        self.window_size = window_size
        self.half_of_window_size = self.window_size // 2
        self.min_occurences = min_occurences

        self.sentences_len = 0

        # First pass over the corpus: count tokens, tokenised exactly the way
        # build_ds() tokenises them (lower-case, whitespace split).
        with open(corpus_path) as f:
            for sentence in tqdm(f, desc="Counting tokens"):
                self.sentences_len += 1
                self.counter.update(sentence.lower().split())

        # Build the vocabulary from the tokens that are frequent enough.
        self.vocab = {
            token
            for token, occurrences in self.counter.items()
            if occurrences >= self.min_occurences
        }

        # Build token2idx. Iterating the set directly would assign indices in
        # hash order, which changes between interpreter runs and would make
        # saved weights impossible to map back to tokens; sorted order is
        # stable. setdefault() keeps "<null>" at index 0 even if the corpus
        # contains that literal token, so no two tokens can share an index.
        for token in tqdm(sorted(self.vocab), desc="Building token2idx"):
            self.token2idx.setdefault(token, len(self.token2idx))

        print("Tokens: {}".format(len(self.token2idx)))

        # Second pass: materialise every (context, target) pair as indices.
        self.ds_x = []
        self.ds_y = []
        for inputs, output in tqdm(self.build_ds(), desc="Building dataset"):
            inputs, output = self.numericalize(inputs, output)
            self.ds_x.append(inputs)
            self.ds_y.append(output)

    def build_ds(self):
        """Yield ``(context_tokens, target_token)`` pairs as strings.

        A sentence is used only if it has at least ``window_size`` tokens and
        all of its tokens are in the vocabulary. Within a sentence the pairs
        are produced from the last token to the first.
        """
        half = self.half_of_window_size
        padding = ["<null>"] * half

        with open(self.corpus_path) as f:
            for sentence in f:
                tokens = sentence.lower().split()

                if len(tokens) < self.window_size or not self.vocab.issuperset(tokens):
                    continue

                padded = padding + tokens + padding

                for position in range(len(tokens) - 1, -1, -1):
                    # The token at `position` sits at `centre` in the padded list.
                    centre = position + half
                    inputs_left = padded[centre - half:centre]
                    inputs_right = padded[centre + 1:centre + 1 + half]

                    yield inputs_left + inputs_right, tokens[position]

    def numericalize(self, inputs, output):
        """Map context tokens and the target token to their integer indices."""
        return [self.token2idx[token] for token in inputs], self.token2idx[output]

    def __getitem__(self, index):
        """Return sample ``index`` as a pair of int64 tensors (context, target)."""
        return (
            torch.tensor(self.ds_x[index], dtype=torch.long),
            torch.tensor(self.ds_y[index], dtype=torch.long),
        )

    def __len__(self):
        """Number of (context, target) samples."""
        return len(self.ds_x)

In [5]:
# Hyper-parameters and run configuration.
lr = 1e-2                      # learning rate passed to the optimizer
bs = 5000                      # mini-batch size used by the DataLoader
seq_len = 7                    # window size for DS / sequence length for the model
target_vectors = 100           # output width of the model's first linear layer
embedding_size = 30            # dimensionality of each token embedding
epochs = 5                     # number of training epochs
minimal_token_occurences = 10  # rarer tokens are left out of the vocabulary

# The run was set up for GPU index 4. Hosts with fewer GPUs (or none) fall
# back to the first GPU or to the CPU so the notebook stays runnable.
if torch.cuda.device_count() > 4:
    device = torch.device("cuda:4")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
# Read the corpus and materialise every training pair in memory.
ds = DS(
    corpus_path=dataset,
    window_size=seq_len,
    min_occurences=minimal_token_occurences,
)

Counting tokens: 8257486it [00:49, 167305.13it/s]
Building token2idx: 100%|██████████| 95067/95067 [00:00<00:00, 1323890.39it/s]
Building dataset: 39823it [00:00, 59383.51it/s]

Tokens: 95068


Building dataset: 71289233it [04:36, 257794.72it/s]


In [7]:
# Batch the dataset with four worker processes and pinned host memory.
# NOTE(review): shuffle=False serves the samples in the order they were
# built from the corpus - confirm this is intended for training.
DL = DataLoader(
    ds,
    batch_size=bs,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [8]:
class Sense2VecCBOW(nn.Module):
    """CBOW-style network: embed the context, project it, score the vocabulary.

    Each input row must hold ``sequence_length - 1`` token indices. Their
    embeddings are concatenated, passed through ``fc_in`` (no activation is
    applied) and then through ``fc_out``, which produces one score per
    vocabulary entry.

    Args:
        vocab_size: number of distinct token indices.
        embedding_size: dimensionality of each token embedding.
        vectors: output width of ``fc_in`` / input width of ``fc_out``.
        sequence_length: window size; the model consumes one index fewer.
    """

    def __init__(self, vocab_size, embedding_size, vectors, sequence_length):
        super().__init__()

        self.vocab_size, self.embedding_size, self.vectors = (
            vocab_size,
            embedding_size,
            vectors,
        )

        context_tokens = sequence_length - 1
        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        self.fc_in = nn.Linear(embedding_size * context_tokens, vectors)
        self.fc_out = nn.Linear(vectors, vocab_size)

        self.init_weights()

    def init_weights(self):
        """Re-draw both linear weight matrices uniformly from [-0.1, 0.1)."""
        bound = 0.1
        for layer in (self.fc_in, self.fc_out):
            layer.weight.data.uniform_(-bound, bound)

    def get_weights(self):
        """Return the ``fc_out`` weight matrix as nested Python lists."""
        output_matrix = self.fc_out.weight.detach()
        return output_matrix.cpu().tolist()

    def forward(self, x):
        """Map a batch of context indices to per-token vocabulary scores."""
        embedded = self.embeddings(x)
        # Concatenate the embeddings of each row into one flat vector.
        flattened = embedded.reshape(embedded.size(0), -1)
        hidden = self.fc_in(flattened)
        # x = torch.relu(x)
        return self.fc_out(hidden)

In [9]:
# Size the network from the dataset vocabulary and the configured dimensions.
model = Sense2VecCBOW(
    vocab_size=len(ds.token2idx),
    embedding_size=embedding_size,
    vectors=target_vectors,
    sequence_length=seq_len,
)

In [10]:
# Cross-entropy over the per-token scores the model returns; the model's
# forward() applies no softmax, so the loss receives raw scores.
criterion = nn.CrossEntropyLoss()
# RMSprop over every model parameter, using the learning rate set in the
# configuration cell.
optimizer = optim.RMSprop(model.parameters(), lr=lr)

In [11]:
# Learning-rate schedule built from SGDR phases with rates (3e-2, 3e-4).
# It is not active: the training call below has its `callbacks` argument
# commented out.
# NOTE(review): `len(DL)` (batches per epoch) and `epochs` are passed as the
# first two positional arguments and 2 as the last - check against the poutyne
# `sgdr_phases` signature whether the second argument is a cycle count; if it
# is, the schedule may span more than `epochs` epochs.
policy = OptimizerPolicy(
    sgdr_phases(len(DL), epochs, (3e-2, 3e-4), 2)
)

In [12]:
# Wrap model, optimizer and loss in a poutyne Experiment. Checkpoints are
# written under '../experiments/t6' (see the training log below).
# The monitored metric is 'acc' in 'max' mode, i.e. a higher value counts as
# an improvement.
experiment = Experiment(
    '../experiments/t6',
    model,
    optimizer=optimizer,
    loss_function=criterion,
    batch_metrics=['accuracy'],
    monitor_metric='acc',
    monitor_mode='max',
    device=device
)

In [None]:
# Train for `epochs` epochs on the batches served by DL.
experiment.train(DL, 
                 epochs=epochs,
                 # The learning-rate policy is currently disabled; to enable
                 # it, pass: callbacks=[policy]
                )

Epoch 1/5 1107.05s Step 14258/14258: loss: 4.202043, acc: 32.395840
Epoch 1: acc improved from -inf to 32.39584, saving file to ../experiments/t6/checkpoint_epoch_1.ckpt
Epoch 2/5 1100.82s Step 14258/14258: loss: 4.038350, acc: 33.773256
Epoch 2: acc improved from 32.39584 to 33.77326, saving file to ../experiments/t6/checkpoint_epoch_2.ckpt
Epoch 3/5 1136.27s Step 14258/14258: loss: 4.014506, acc: 33.949107
Epoch 3: acc improved from 33.77326 to 33.94911, saving file to ../experiments/t6/checkpoint_epoch_3.ckpt
Epoch 4/5 ETA 923s Step 4416/14258: loss: 3.941869, acc: 35.0000008