In [1]:
from timeit import default_timer as timer
from datetime import timedelta

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tqdm

In [3]:
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

Using device cuda


In [4]:
%load_ext autoreload
%autoreload 2

import data_handling as data
import preprocess as pp
from models import nvdm

In [5]:
# # Original data
# DATA_RAW_PATH = "./data/bds_1.txt"
# IDs, BDs = data.load_raw(DATA_RAW_PATH)

In [6]:
# Data that has already been preprocessed
# Generated by applying pp.preprocess_text() to each BD,
# then saved to a TSV
DATA_CLEAN_PATH = "./data/bds_1_clean.txt"
IDs_raw, BDs_raw = data.load_raw(DATA_CLEAN_PATH)

In [7]:
# Some entries have empty BDs; keep only the (ID, BD) pairs whose BD is non-empty,
# preserving the original order so IDs[i] still labels BDs[i].
kept_pairs = [(iid, bd) for iid, bd in zip(IDs_raw, BDs_raw) if len(bd) > 0]
IDs = [iid for iid, _ in kept_pairs]
BDs = [bd for _, bd in kept_pairs]

print(len(IDs), len(BDs))

2034 2034


The data setup below follows PyTorch's text classification tutorial, without the labels (this model does not use them):
https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html

In [8]:
# Token frequency table over every BD. The cleaned text is tokens joined by
# single spaces, so splitting on " " recovers them.
counter = Counter()
for passage_tokens in (passage.split(" ") for passage in BDs):
    counter.update(passage_tokens)

In [9]:
# torchtext's Vocab maps tokens to indices and back. The '<unk>' special
# stands in for out-of-vocabulary tokens (might be useful later).
# Only the 10000 most frequent tokens are kept (plus '<unk>');
# without the max_size cap the vocabulary would be 70770.
vocab = Vocab(
    counter,
    specials=['<unk>'],
    min_freq=1,
    max_size=10000,
)
print(len(vocab))

10001


In [10]:
class BDDataset(Dataset):
    """Minimal map-style dataset over an in-memory sequence of passages.

    It adds nothing beyond indexing and length; it exists only so the
    passages can be fed through a PyTorch DataLoader.

    Args:
        data: indexable sequence of passages (here, the cleaned BD strings).
    """

    def __init__(self, data):
        # Stored as-is; no copy and no preprocessing.
        self.data = data

    def __getitem__(self, idx):
        """Return the passage at position `idx`."""
        return self.data[idx]

    def __len__(self):
        """Number of passages held."""
        return len(self.data)

In [11]:
# "Preprocessing" function: just splits the text
# The file's text is already preprocessed.
def text_pipeline(text):
    """Return the vocabulary index of every single-space-separated token in `text`.

    The text was preprocessed upstream, so no further normalisation happens here.
    """
    indices = []
    for token in text.split(" "):
        indices.append(vocab[token])
    return indices

def collate_batch(batch):
    """Flatten a batch of passages into one index tensor plus start offsets.

    Adapted from https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html,
    minus the labels, which are not needed here.

    Args:
        batch: sequence of space-joined passage strings (as yielded by BDDataset).

    Returns:
        (text, offsets): `text` is a 1-D int64 tensor with every passage's token
        indices laid end to end; `offsets[i]` is the position in `text` where
        passage i starts (the input layout EmbeddingBag expects). Both tensors
        are moved to `device`.
    """
    encoded = [torch.tensor(text_pipeline(passage), dtype=torch.int64) for passage in batch]
    lengths = [piece.size(0) for piece in encoded]
    # Each passage starts where all the passages before it end.
    starts = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)
    flat = torch.cat(encoded)
    return flat.to(device), starts.to(device)

In [12]:
# Batch the passages for training/evaluation. shuffle=False keeps batches in
# the same order as IDs/BDs.
batch_size = 64
dataset = BDDataset(BDs)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate_batch,
    shuffle=False,
)

In [13]:
# Training setup

# Total number of outer rounds; each round trains the encoder, then the decoder.
outer_epochs = 100

# Epochs for training the encoder/decoder on each alternation.
# Full schedule = outer_epochs * 2 phases * inner_epochs passes over the data.
inner_epochs = 10

hidden_size = 500
num_topics = 10

# Seed torch so parameter initialisation and any sampling inside the model
# repeat between runs.
SEED = 0
torch.manual_seed(SEED)

model = nvdm.NVDM(len(vocab), hidden_size, num_topics, 1, device)
model = model.to(device)
model.train()

# Separate the encoder from decoder parameters when training in an alternating manner.
# The encoder group also includes the linear layers that output mu and log(sigma).
optim_encoder = torch.optim.Adam(
    list(model.encoder.parameters()) +
    list(model.mu.parameters()) +
    list(model.log_sigma.parameters()),
    lr=0.0001)
optim_decoder = torch.optim.Adam(model.decoder.parameters(), lr=0.001)


def _set_trainable(optimizer, trainable):
    """Set requires_grad on every parameter managed by `optimizer`."""
    for group in optimizer.param_groups:
        for param in group["params"]:
            param.requires_grad_(trainable)


def _train_one_epoch(model, loader, optimizer):
    """Make one pass over `loader`, stepping only `optimizer`.

    Returns:
        (loss, rec, kl): means over batches of loss_dict["total"],
        loss_dict["rec"] and loss_dict["kl"] as returned by the model.
    """
    loss_sum = rec_sum = kl_sum = 0.0
    for text, offsets in loader:
        text = text.to(device)
        offsets = offsets.to(device)

        optimizer.zero_grad()
        _, loss_dict = model(text, offsets)
        loss = loss_dict["total"]
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        rec_sum += loss_dict["rec"].item()
        kl_sum += loss_dict["kl"].item()

    n = len(loader)
    return loss_sum / n, rec_sum / n, kl_sum / n


# Each phase: (log label, optimizer to step, optimizer whose parameters stay fixed).
# Author's code trains the decoder first; not sure if the order matters.
phases = [
    ("Enc", optim_encoder, optim_decoder),
    ("Dec", optim_decoder, optim_encoder),
]

start_time = timer()

try:
    for epoch in range(outer_epochs):
        # Train the encoder and decoder in turns, fixing the parameters of the other every time.
        for model_str, active_optim, frozen_optim in phases:
            # Freeze the other half so backward() neither computes nor keeps
            # accumulating gradients for parameters this phase never steps.
            # Gradients of the active parameters are unaffected.
            _set_trainable(frozen_optim, False)
            _set_trainable(active_optim, True)

            for alt_epoch in range(inner_epochs):
                loss_avg, rec_avg, kl_avg = _train_one_epoch(model, data_loader, active_optim)
                print(f"[Time: {timedelta(seconds=timer() - start_time)}, Epoch {epoch + 1}, {model_str} {alt_epoch + 1}] Loss {loss_avg}, Rec {rec_avg}, KL {kl_avg}")
finally:
    # Re-enable gradients on both optimizers' parameters even if training is
    # interrupted (e.g. KeyboardInterrupt), so later cells see a trainable model.
    _set_trainable(optim_encoder, True)
    _set_trainable(optim_decoder, True)

[Time: 0:00:06.763975, Epoch 1, Enc 1] Loss 56350.20568847656, Rec 56103.13659667969, KL 247.06936407089233
[Time: 0:00:13.298490, Epoch 1, Enc 2] Loss 56267.88000488281, Rec 56054.489013671875, KL 213.39070320129395
[Time: 0:00:19.994478, Epoch 1, Enc 3] Loss 56259.10107421875, Rec 56048.7607421875, KL 210.3403434753418


KeyboardInterrupt: 