In [76]:
import torch
import torch.nn.functional as F
import os
import re

from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from datetime import datetime

# Global seed for every stochastic step in this notebook.
seed = 265
torch.manual_seed(seed)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Training on device {device}.")

Training on device cpu.


In [77]:
# Directory in which the cached intermediates of this notebook are written.
PATH_GENERATED = './generated/'
# Default minimum number of occurrences for a token to enter the vocabulary.
MIN_FREQ = 100
# Shared tokenizer: torchtext's 'basic_english' tokenizer.
TOKENIZER = get_tokenizer('basic_english')

def read_files(datapath='./data_train/'):
    """
    Read every ``.txt`` file in ``datapath`` and return all of their lines.

    Files are read in sorted filename order, so the concatenated result (and every
    cache derived from it) is reproducible; ``os.listdir`` order is arbitrary.
    Files are decoded as UTF-8 explicitly instead of relying on the platform's
    default encoding.

    Args:
        datapath: directory containing the text files (trailing separator optional).

    Returns:
        list[str]: all lines of all files, newline characters included.
    """
    filenames = sorted(f for f in os.listdir(datapath) if f.endswith('.txt'))

    texts = []
    for fname in filenames:
        # os.path.join works whether or not datapath ends with a separator.
        with open(os.path.join(datapath, fname), encoding='utf-8') as f:
            texts.extend(f.readlines())
    return texts

def tokenize(texts, tokenizer=TOKENIZER):
    """
    Tokenize every string in ``texts`` and return one flat list of tokens.

    Args:
        texts: iterable of strings (e.g. the lines returned by ``read_files``).
        tokenizer: callable mapping a string to a list of tokens.

    Returns:
        list[str]: the tokens of all texts, concatenated in order.
    """
    tokens = []
    for line in texts:
        tokens.extend(tokenizer(line))
    return tokens

def yield_tokens(texts, tokenizer=TOKENIZER):
    """
    Clean each text and yield its tokens, one list of tokens per text.

    Before tokenizing:
    - every word containing a digit is replaced by a space;
    - every word containing an upper-case letter is replaced by a space. This is
      meant to drop proper names, but it also drops any capitalised word, such as
      a sentence-initial "The" or the pronoun "I";
    - runs of whitespace are collapsed to a single space.

    Args:
        texts: iterable of strings (e.g. lines of a book).
        tokenizer: callable mapping a string to a list of tokens.

    Yields:
        list[str]: the tokens of one cleaned text.
    """
    # Raw strings: '\w' / '\s' in a plain literal are invalid escape sequences
    # (DeprecationWarning, SyntaxWarning since Python 3.12). The patterns are
    # compiled once up front instead of on every re.sub call.
    no_digits = re.compile(r'\w*[0-9]+\w*')
    no_names = re.compile(r'\w*[A-Z]+\w*')
    no_spaces = re.compile(r'\s+')

    for text in texts:
        text = no_digits.sub(' ', text)
        text = no_names.sub(' ', text)
        text = no_spaces.sub(' ', text)
        yield tokenizer(text)

def count_freqs(words, vocab):
    """
    Count how many times each vocabulary entry occurs in ``words``.

    Args:
        words: iterable of tokens.
        vocab: object supporting ``len(vocab)`` and ``vocab[token] -> int index``
            (e.g. a torchtext Vocab, where unknown tokens map to its default index).

    Returns:
        torch.int32 tensor of length ``len(vocab)``; entry i is the number of tokens
        in ``words`` whose index is i.

    Raises:
        IndexError: if ``vocab`` returns an index >= ``len(vocab)``.
    """
    n_vocab = len(vocab)
    # Look every token up once, then count with a single vectorised bincount
    # instead of one scalar tensor update per token (millions for the training set).
    indices = torch.tensor([vocab[w] for w in words], dtype=torch.long)
    counts = torch.bincount(indices, minlength=n_vocab)
    if counts.numel() > n_vocab:
        raise IndexError("vocab returned an index outside [0, len(vocab))")
    return counts.to(dtype=torch.int)

def create_vocabulary(lines, min_freq=MIN_FREQ):
    """
    Build a torchtext vocabulary (list of known tokens) from a list of strings.

    Lines are cleaned and tokenized by ``yield_tokens``. Tokens occurring fewer than
    ``min_freq`` times are dropped. The special token ``"<unk>"`` is always present
    and is used as the index for unknown tokens.

    Args:
        lines: list of strings to build the vocabulary from.
        min_freq: minimum number of occurrences for a token to be kept.

    Returns:
        The torchtext ``Vocab``.
    """
    vocab = build_vocab_from_iterator(yield_tokens(lines), min_freq=min_freq, specials=["<unk>"])
    # yield_tokens drops every word containing an upper-case letter, so the pronoun
    # "I" is never counted; add its lower-case form back explicitly. The membership
    # check matters because Vocab.append_token raises if the token already exists
    # (e.g. enough lower-case "i" survived cleaning to pass min_freq).
    if "i" not in vocab:
        vocab.append_token("i")
    vocab.set_default_index(vocab["<unk>"])
    return vocab

In [78]:
# ----------------------- Tokenize texts -------------------------------
# The tokenized splits are cached in PATH_GENERATED. The cache is only used when
# all three files exist, so a partially written cache is rebuilt instead of
# crashing on the first missing file.

os.makedirs(PATH_GENERATED, exist_ok=True)  # torch.save does not create directories

if all(os.path.isfile(PATH_GENERATED + f"words_{split}.pt") for split in ("train", "val", "test")):
    words_train = torch.load(PATH_GENERATED + "words_train.pt")
    words_val   = torch.load(PATH_GENERATED + "words_val.pt")
    words_test  = torch.load(PATH_GENERATED + "words_test.pt")
else:
    lines_books_train = read_files('./data_train/')
    lines_books_val   = read_files('./data_val/')
    lines_books_test  = read_files('./data_test/')

    words_train = tokenize(lines_books_train)
    words_val   = tokenize(lines_books_val)
    words_test  = tokenize(lines_books_test)

    torch.save(words_train, PATH_GENERATED + "words_train.pt")
    torch.save(words_val, PATH_GENERATED + "words_val.pt")
    torch.save(words_test, PATH_GENERATED + "words_test.pt")



# ----------------------- Create vocabulary ----------------------------

VOCAB_FNAME = "vocabulary.pt"
if os.path.isfile(PATH_GENERATED + VOCAB_FNAME):
    vocab = torch.load(PATH_GENERATED + VOCAB_FNAME)
else:
    # Re-read the raw training lines here rather than reusing `lines_books_train`:
    # that variable only exists when the tokenization cache was rebuilt in this
    # session. With words_*.pt cached but no vocabulary.pt, it would raise NameError.
    os.makedirs(PATH_GENERATED, exist_ok=True)
    vocab = create_vocabulary(read_files('./data_train/'), min_freq=MIN_FREQ)
    torch.save(vocab, PATH_GENERATED + VOCAB_FNAME)
    


# ------------------------ Quick analysis ------------------------------

VOCAB_SIZE = len(vocab)
freqs = count_freqs(words_train, vocab)
# (count in words_train, token) for every vocabulary entry, in vocabulary-index order.
occurences = list(zip(freqs.tolist(), vocab.lookup_tokens(range(VOCAB_SIZE))))

In [79]:
n_print = 10
print("Total number of words in the training dataset:     ", len(words_train))
print("Total number of words in the validation dataset:   ", len(words_val))
print("Total number of words in the test dataset:         ", len(words_test))
print("Number of distinct words in the training dataset:  ", len(set(words_train)))
print("Number of distinct words kept (vocabulary size):   ", VOCAB_SIZE)

# `occurences` follows vocabulary-index order, which need not match the counts in
# words_train. The vocabulary is built from the regex-cleaned text (yield_tokens),
# not from words_train, and "i" is appended at its end regardless of its count.
# Sort by count so the report really lists the most frequent tokens.
most_frequent = sorted(occurences, key=lambda pair: pair[0], reverse=True)[:n_print]
print(f"The {n_print} most occuring words:\n {most_frequent}")

Total number of words in the training dataset:      2684706
Total number of words in the validation dataset:    49526
Total number of words in the test dataset:          124152
Number of distinct words in the training dataset:   52105
Number of distinct words kept (vocabulary size):    1880
The 10 most occuring words:
 [(433907, '<unk>'), (182537, ','), (151278, 'the'), (123727, '.'), (82289, 'and'), (65661, 'of'), (62763, 'to'), (49230, 'a'), (41477, 'in'), (31052, 'that')]


In [80]:
CONTEXT_SIZE = 3

# ---------------- Define context / target pairs -----------------------
def create_dataset(text, vocab, context_size=CONTEXT_SIZE):
    """
    Create a pytorch dataset of context / target pairs from a tokenized text.

    For every position i with i + context_size < len(text), the context is the
    vocabulary indices of text[i:i + context_size]. The target is the index of
    text[i + context_size], i.e. the word that follows the context.

    Args:
        text: list of tokens.
        vocab: object with ``vocab[token] -> int index``.
        context_size: number of preceding words used to predict the next one (>= 1).

    Returns:
        TensorDataset of ``contexts`` (int64, shape (N, context_size)) and ``targets``
        (int64, shape (N,)), with N = max(len(text) - context_size, 0). Both tensors
        are moved to the global ``device``.

    Raises:
        ValueError: if ``context_size`` < 1.
    """
    if context_size < 1:
        raise ValueError("context_size must be at least 1")

    # Transform each word to its index in the vocabulary, as one tensor.
    indices = torch.tensor([vocab[w] for w in text], dtype=torch.long)
    n_pairs = max(len(text) - context_size, 0)

    if n_pairs == 0:
        # Too short to form a single pair: return an empty dataset with the usual
        # shapes (stacking an empty list of contexts would raise).
        contexts = torch.empty((0, context_size), dtype=torch.long)
        targets = torch.empty((0,), dtype=torch.long)
    else:
        # Stride-1 sliding windows of length context_size. The last window has no
        # following word, so keep only the first n_pairs. contiguous() materialises
        # the overlapping view into a plain (n_pairs, context_size) tensor.
        contexts = indices.unfold(0, context_size, 1)[:n_pairs].contiguous()
        targets = indices[context_size:]

    return TensorDataset(contexts.to(device=device), targets.to(device=device))

In [81]:
def load_dataset(words, vocab, fname):
    """
    Return the context/target dataset cached at ``PATH_GENERATED + fname``.

    When that file does not exist yet, the dataset is built from ``words`` with
    ``create_dataset`` and saved to that path before being returned.
    """
    path = PATH_GENERATED + fname
    if not os.path.isfile(path):
        dataset = create_dataset(words, vocab)
        torch.save(dataset, path)
        return dataset
    return torch.load(path)

# One cached context/target dataset per split.
data_train = load_dataset(words=words_train, vocab=vocab, fname="data_train.pt")
data_val = load_dataset(words=words_val, vocab=vocab, fname="data_val.pt")
data_test = load_dataset(words=words_test, vocab=vocab, fname="data_test.pt")

In [82]:
class Word2Vec(nn.Module):
    """
    Next-word prediction model over a fixed-size window of context words.

    The context word indices are embedded, and the embeddings are flattened
    (concatenated) and passed through one ReLU hidden layer. The result is then
    projected to one logit per vocabulary entry.
    """

    def __init__(self, embedding, context_size=CONTEXT_SIZE, hidden_dim=128):
        """
        Args:
            embedding: nn.Embedding for the context words. It is stored as-is (not
                copied), so its weights are part of this model's parameters.
            context_size: number of context words per example.
            hidden_dim: width of the hidden layer (default 128).
        """
        super().__init__()

        vocab_size, embedding_dim = embedding.weight.shape
        self.embedding = embedding

        self.fc1 = nn.Linear(embedding_dim * context_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """
        Args:
            x: integer tensor of context word indices, presumably of shape
               (batch, context_size). Everything after dim 0 is flattened.

        Returns:
            Logits with one entry per vocabulary word along the last dimension.
        """
        out = self.embedding(x)
        out = F.relu(self.fc1(torch.flatten(out, 1)))
        return self.fc2(out)

In [86]:
def train(n_epochs, optimizer, model, loss_fn, train_loader, device=None, print_every=5):
    """
    Train ``model`` for ``n_epochs`` epochs and return the mean training loss per epoch.

    Args:
        n_epochs: number of passes over ``train_loader``.
        optimizer: optimizer over ``model``'s parameters.
        model: module mapping a batch of contexts to outputs accepted by ``loss_fn``.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        train_loader: sized iterable of ``(contexts, targets)`` batches.
        device: device the batches are moved to; ``None`` leaves them where they are.
        print_every: the epoch loss is printed at epoch 1 and every ``print_every``
            epochs (positive int).

    Returns:
        list[float]: mean batch loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` contains no batches.
    """
    n_batch = len(train_loader)
    if n_batch == 0:
        # The per-epoch average below would otherwise divide by zero.
        raise ValueError("train_loader is empty")

    losses_train = []
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for epoch in range(1, n_epochs + 1):

        loss_train = 0.0
        for contexts, targets in train_loader:

            contexts = contexts.to(device=device)
            targets = targets.to(device=device)

            outputs = model(contexts)

            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_train += loss.item()

        epoch_loss = loss_train / n_batch
        losses_train.append(epoch_loss)

        if epoch == 1 or epoch % print_every == 0:
            print('{}  |  Epoch {}  |  Training loss {:.5f}'.format(
                datetime.now().time(), epoch, epoch_loss))
    return losses_train


def compute_accuracy(model, loader, device=None):
    """
    Return the fraction of examples in ``loader`` whose arg-max prediction equals the target.

    Switches ``model`` to eval mode (and leaves it there) and runs without gradients.
    Batches are moved to ``device`` first (``None`` leaves them where they are).
    Raises ZeroDivisionError if ``loader`` yields no examples.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_contexts, batch_targets in loader:
            batch_contexts = batch_contexts.to(device=device)
            batch_targets = batch_targets.to(device=device)

            predictions = model(batch_contexts).max(dim=1).indices
            n_seen += len(batch_targets)
            n_correct += int((predictions == batch_targets).sum())

    return n_correct / n_seen

In [84]:
batch_sizes = [64, 128, 256]
embedding_dims = [10, 16]

# Full grid: batch size is the outer loop, embedding dimension the inner one.
hparams = [
    dict(batch_size=batch_size, embedding_dim=embedding_dim)
    for batch_size in batch_sizes
    for embedding_dim in embedding_dims
]

print(f"We are testing {len(hparams)} different hyper parameters.")

We are testing 6 different hyper parameters.


In [None]:
# Hyper-parameter search. Expensive: trains one model per entry of `hparams`.
# NOTE(review): this cell neither saves nor loads the trained models, so every run
# retrains all of them from scratch.
models = []
embeddings = []
train_acc = []
val_acc = []

for param in hparams:
    print(f'Now training with parameters {param}')
    train_loader = DataLoader(data_train, batch_size=param['batch_size'], shuffle=True)
    # Used for evaluation only; the order does not affect accuracy.
    val_loader   = DataLoader(data_val, batch_size=param['batch_size'], shuffle=False)

    # Seed *before* creating the embedding so that its initial weights, like the
    # rest of the model's, are reproducible for every configuration.
    torch.manual_seed(seed)
    embedding = nn.Embedding(VOCAB_SIZE, param['embedding_dim'])
    model = Word2Vec(embedding).to(device=device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    n_epochs=3
    train(n_epochs, optimizer, model, loss_fn, train_loader)

    models.append(model)
    embeddings.append(embedding)
    train_acc.append(compute_accuracy(model, train_loader))
    val_acc.append(compute_accuracy(model, val_loader))
    print()


In [69]:
torch.manual_seed(seed)

MODEL_FNAME = "model.pt"
EMBEDDING_DIM = 16

# Load the cached *initial* embedding. This is not a pre-trained embedding: when
# the file is missing, a freshly (randomly) initialised nn.Embedding is created and
# saved, so that re-runs start from the same initial weights.
if os.path.isfile("embedding.pt"):
    embedding = torch.load("embedding.pt").to(device=device)
else:
    embedding = nn.Embedding(len(vocab), EMBEDDING_DIM)
    torch.save(embedding, "embedding.pt")


batch_size = 512
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
# Evaluation loaders: the order does not affect accuracy.
val_loader   = DataLoader(data_val, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(data_test, batch_size=batch_size, shuffle=False)

model = Word2Vec(embedding)

if os.path.isfile(PATH_GENERATED + MODEL_FNAME):
    model = torch.load(PATH_GENERATED + MODEL_FNAME)
    model.to(device)
else:
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    n_epochs=10

    train(n_epochs, optimizer, model, loss_fn, train_loader)
    # Save a CPU version so the file loads without a GPU, then move the model back.
    # Module.to() works in place, and the dataset tensors were created on `device`,
    # so evaluating a CPU model on CUDA batches below would fail.
    torch.save(model.to(device="cpu"), PATH_GENERATED + MODEL_FNAME)
    model.to(device)

acc_train = compute_accuracy(model, train_loader)
acc_val = compute_accuracy(model, val_loader)
print("Training Accuracy:     %.4f" %acc_train)
print("Validation Accuracy:   %.4f" %acc_val)

16:59:16.270479  |  Epoch 1  |  Training loss 4.28919
17:04:15.352456  |  Epoch 5  |  Training loss 3.87675
17:11:14.967903  |  Epoch 10  |  Training loss 3.81888
Training Accuracy:     0.2438
Validation Accuracy:   0.2309


In [70]:
# Embedding vector the model uses for the word "the" (detached from autograd).
model.embedding.weight.detach()[vocab['the']]

tensor([ 0.7452, -1.4874,  0.5123,  1.3052,  0.2934,  0.5553, -0.8712, -2.1692,
        -0.2182, -0.2197,  1.8711,  0.8342, -1.7075,  0.7678,  1.1717,  1.7013])

tensor([ 1.2513, -1.4227,  0.1993,  1.1488,  0.1734,  0.2339, -0.5957, -1.4216,
        -0.2747,  0.6735,  1.7830,  0.4373, -2.4252,  0.7154,  1.1046,  1.4459])

tensor([ 0.4724, -2.2209, -0.5557,  1.0608,  0.1002,  0.4285, -0.7487, -2.3987,
        -0.1100,  0.7012,  1.6068, -0.0888, -2.1806,  0.5736,  0.3631,  1.3329])