In [27]:
import torch
import torch.nn.functional as F
import os
import re

from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

seed = 265
torch.manual_seed(seed)

device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
print(f"Training on device {device}.")

Training on device cpu.


In [16]:
TOKENIZER = get_tokenizer('basic_english')
PATH_GENERATED = './generated/'
MIN_FREQ = 100

def read_files(datapath='./data_train/'):
    files = os.listdir(datapath)
    files = [datapath + f for f in files if f.endswith('.txt')]

    texts = []
    for file in files:
        with open(file) as f:
            texts += f.readlines()
    return texts

def tokenize(texts, tokenizer=TOKENIZER):
    tokenized_text = []
    for text in texts:
        tokenized_text += tokenizer(text)
    return tokenized_text

def yield_tokens(texts, tokenizer=TOKENIZER):
    # Match any word containing digit
    no_digits = '\w*[0-9]+\w*'
    # Match word containing a uppercase 
    no_names = '\w*[A-Z]+\w*'
    # Match any sequence containing more than one space
    no_spaces = '\s+'

    for text in texts:
        text = re.sub(no_digits, ' ', text)
        text = re.sub(no_names, ' ', text)
        text = re.sub(no_spaces, ' ', text)
        yield tokenizer(text)

def count_freqs(words, vocab):
    freqs = torch.zeros(len(vocab), dtype=torch.int)
    for w in words:
        freqs[vocab[w]] += 1
    return freqs

def create_vocabulary(lines, min_freq=MIN_FREQ):
    """
    Create a vocabulary (list of known tokens) from a list of strings
    """
    # vocab contains the vocabulary found in the data, associating an index to each word
    vocab = build_vocab_from_iterator(yield_tokens(lines), min_freq=min_freq, specials=["<unk>"])
    # Since we removed all words with an uppercase when building the vocabulary, we skipped the word "I"
    vocab.append_token("i")
    # Value of default index. This index will be returned when OOV (Out Of Vocabulary) token is queried.
    vocab.set_default_index(vocab["<unk>"])
    return vocab

In [17]:
# ----------------------- Tokenize texts -------------------------------
# Load tokenized versions of texts if you have already generated it
# Otherwise, create it and save it
if os.path.isfile(PATH_GENERATED + "words_train.pt"):
    words_train = torch.load(PATH_GENERATED + "words_train.pt")
    words_val   = torch.load(PATH_GENERATED + "words_val.pt")
    words_test  = torch.load(PATH_GENERATED + "words_test.pt")
else:
    # Get lists of strings, one for each line in each .txt files in 'datapath' 
    lines_books_train = read_files('./data_train/')
    lines_books_val   = read_files('./data_val/')
    lines_books_test  = read_files('./data_test/')

    # List of words contained in the dataset
    words_train = tokenize(lines_books_train)
    words_val   = tokenize(lines_books_val)
    words_test  = tokenize(lines_books_test)
    
    torch.save(words_train, PATH_GENERATED + "words_train.pt")
    torch.save(words_val, PATH_GENERATED + "words_val.pt")
    torch.save(words_test, PATH_GENERATED + "words_test.pt")



# ----------------------- Create vocabulary ----------------------------
VOCAB_FNAME = "vocabulary.pt"
# Load vocabulary if you have already generated it
# Otherwise, create it and save it
if os.path.isfile(PATH_GENERATED + VOCAB_FNAME):
    vocab = torch.load(PATH_GENERATED + VOCAB_FNAME)
else:
    # Create vocabulary based on the words in the training dataset
    vocab = create_vocabulary(lines_books_train, min_freq=MIN_FREQ)
    torch.save(vocab, PATH_GENERATED + VOCAB_FNAME)
    


# ------------------------ Quick analysis ------------------------------
VOCAB_SIZE = len(vocab)
print("Total number of words in the training dataset:     ", len(words_train))
print("Total number of words in the validation dataset:   ", len(words_val))
print("Total number of words in the test dataset:         ", len(words_test))
print("Number of distinct words in the training dataset:  ", len(set(words_train)))
print("Number of distinct words kept (vocabulary size):   ", VOCAB_SIZE)

freqs = count_freqs(words_train, vocab)
print("occurences:\n", [(f.item(), w) for (f, w) in zip(freqs, vocab.lookup_tokens(range(VOCAB_SIZE)))])

Total number of words in the training dataset:      2684706
Total number of words in the validation dataset:    49526
Total number of words in the test dataset:          124152
Number of distinct words in the training dataset:   52105
Number of distinct words kept (vocabulary size):    1880
occurences:
 [(433907, '<unk>'), (182537, ','), (151278, 'the'), (123727, '.'), (82289, 'and'), (65661, 'of'), (62763, 'to'), (49230, 'a'), (41477, 'in'), (31052, 'that'), (37167, 'he'), (29046, 'was'), (26508, 'his'), (26354, 'it'), (20862, 'with'), (20159, 'had'), (19965, 'is'), (15692, 'not'), (16593, 'as'), (15705, 'on'), (14464, 'him'), (15317, 'for'), (15838, 'at'), (15952, 'you'), (13255, 'be'), (12698, 'her'), (12798, 's'), (11924, 'which'), (11808, '!'), (11740, 'all'), (10338, '?'), (10205, 'have'), (10405, 'from'), (13251, 'but'), (11464, 'this'), (9439, 'by'), (11496, 'they'), (8797, 'said'), (8800, 'are'), (11055, 'she'), (9537, 'one'), (8219, 'were'), (8564, 'who'), (8345, 'so'), (9409

In [18]:
# ------------------------ Define targets ------------------------------
def compute_label(w):
    """
    helper function to define MAP_TARGET
    
    - 0 = 'unknown word'
    - 1 = 'punctuation' (i.e. the '<unk>' token)
    - 2 = 'is an actual word'
    """
    if w in ['<unk>']:
        return 0
    elif w in [',', '.', '(', ')', '?', '!']:
        return 1
    else:
        return 2

# true labels for this task:
MAP_TARGET = {
    vocab[w]: compute_label(w) for w in vocab.lookup_tokens(range(VOCAB_SIZE))
}

# context size for this task 
CONTEXT_SIZE = 3

# ---------------- Define context / target pairs -----------------------
def create_dataset(
    text, vocab, 
    context_size=CONTEXT_SIZE, map_target=MAP_TARGET
):
    """
    Create a pytorch dataset of context / target pairs from a text
    """
    
    n_text = len(text)
    n_vocab = len(vocab)
    
    # Change labels if only a few target are kept, otherwise, each word is
    # associated with its index in the vocabulary
    if map_target is None:
        map_target = {i:i for i in range(n_vocab)}
    
    # Transform the text as a list of integers.
    txt = [vocab[w] for w in text]

    # Start constructing the context / target pairs...
    contexts = []
    targets = []
    for i in range(n_text - context_size):
        
        # Word used to define target
        t = txt[i + context_size]
        
        # Context before the target
        c = txt[i:i + context_size]
        
        targets.append(map_target[t])
        contexts.append(torch.tensor(c))
            
    # contexts of shape (N_dataset, context_size)
    # targets of shape  (N_dataset)
    contexts = torch.stack(contexts)
    targets = torch.tensor(targets)
    # Create a pytorch dataset out of these context / target pairs
    return TensorDataset(contexts, targets)

In [19]:
def load_dataset(words, vocab, fname):
    """
    Load dataset if already generated, otherwise, create it and save it
    """
    # If already generated
    if os.path.isfile(PATH_GENERATED + fname):
        dataset = torch.load(PATH_GENERATED + fname)
    else:
        # Create context / target dataset based on the list of strings
        dataset = create_dataset(words, vocab)
        torch.save(dataset, PATH_GENERATED + fname)
    return dataset

data_train = load_dataset(words_train, vocab, "data_train.pt")
data_val   = load_dataset(words_val, vocab, "data_val.pt")
data_test  = load_dataset(words_test, vocab, "data_test.pt")

In [20]:
class MyMLP(nn.Module):
    
    def __init__(self, embedding=None, context_size=CONTEXT_SIZE):
        super().__init__()
        
        (vocab_size, embedding_dim) = embedding.weight.shape
        # Instantiate an embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
            
        # Regular MLP
        self.fc1 = nn.Linear(embedding_dim*context_size, 128)
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        # x is of shape (N, context_size) but contains integers which can
        # be seen as equivalent to (N, context_size, vocab_size) since one hot
        # encoding is used under the hood
        out = self.embedding(x)
        # out is now of shape (N, context_size, embedding_dim)
        
        out = F.relu(self.fc1(torch.flatten(out, 1)))
        # out is now of shape (N, context_size*embedding_dim)
        
        out = self.fc2(out)
        return out

In [None]:
torch.manual_seed(seed)

# Load the pretrained embedding 
if os.path.isfile("embedding.pt"):
    embedding = torch.load("embedding.pt").to(device=device)
else:
    raise ValueError("Embedding not found at the given location")

MODEL_FNAME = "model.pt"

batch_size = 512
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(data_val, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(data_test, batch_size=batch_size, shuffle=True)

model = MyMLP(embedding)

if os.path.isfile(PATH_GENERATED + MODEL_FNAME):
    # Load the trained model
    model = torch.load(PATH_GENERATED + MODEL_FNAME)
    model.to(device)
else:
    # Or train the model...
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    n_epochs=30

    train(n_epochs, optimizer, model, loss_fn, train_loader)
    # ... and save it
    torch.save(model.to(device="cpu"), PATH_GENERATED + MODEL_FNAME)

acc_train = compute_accuracy(model, train_loader)
acc_val = compute_accuracy(model, val_loader)
print("Training Accuracy:     %.4f" %acc_train)
print("Validation Accuracy:   %.4f" %acc_val)