In [1]:
from timeit import default_timer as timer
from datetime import timedelta

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tqdm

In [3]:
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

Using device cuda


In [4]:
%load_ext autoreload
%autoreload 2

import data_handling as data
import preprocess as pp
from models import nvdm

In [27]:
# # Original data
# DATA_RAW_PATH = "./data/bds_1.txt"
# IDs, BDs = data.load_raw(DATA_RAW_PATH)

In [5]:
# Data that has already been preprocessed
# Generated by applying pp.preprocess_text() to each BD,
# then saved to a TSV
DATA_CLEAN_PATH = "./data/bds_1_clean.txt"
# load_raw returns two sequences that are paired up element-wise when empty
# entries are filtered out (presumably entry IDs and their description
# texts -- confirm against data_handling.load_raw).
IDs_raw, BDs_raw = data.load_raw(DATA_CLEAN_PATH)

In [6]:
# Drop every entry whose BD is empty, keeping IDs and BDs aligned
# position by position.
IDs, BDs = [], []
for iid, bd in zip(IDs_raw, BDs_raw):
    if len(bd) == 0:
        continue
    IDs.append(iid)
    BDs.append(bd)

print(len(IDs), len(BDs))

2034 2034


Following PyTorch's tutorial for data setup.
https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html

In [7]:
# Token frequency table over the whole corpus.
# Tokens in the cleaned data are separated by single spaces.
counter = Counter(token for desc in BDs for token in desc.split(" "))

In [8]:
# PyTorch torchtext vocabulary converts tokens to indices and vice versa.
# Also has an '<unk>' for OOV words (might be useful later).
# The printed length is 10001 rather than 10000: the '<unk>' special is
# counted on top of the max_size regular tokens.
vocab = Vocab(counter,
              max_size=10000,
              min_freq=1,
              specials=['<unk>'])
print(len(vocab))
# actual is 70770 without max_size restriction

10001


In [9]:
class BDDataset(Dataset):
    """Minimal dataset over an in-memory collection of passages.

    Exists only so the passages can be handed to a PyTorch DataLoader.
    """

    def __init__(self, data):
        # Keep a reference to the passages; nothing is copied or transformed.
        self.data = data

    def __getitem__(self, idx):
        """Return the passage stored at position ``idx``."""
        passages = self.data
        return passages[idx]

    def __len__(self):
        """Return the number of stored passages."""
        passages = self.data
        return len(passages)

In [10]:
# "Preprocessing" function: just splits the text
# The file's text is already preprocessed.
def text_pipeline(text):
    """Map a space-separated string of tokens to a list of vocabulary indices."""
    token_ids = []
    for token in text.split(" "):
        token_ids.append(vocab[token])
    return token_ids

def collate_batch(batch):
    """ Turn a batch of texts into the flat tensors an EmbeddingBag-based model consumes.

    Adapted from https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html,
    minus the labels, which are not needed here.

    Returns a pair ``(text_list, offsets)``, both moved to ``device``:
    ``text_list`` is every text's token indices concatenated into one 1-D
    int64 tensor, and ``offsets`` holds the position in ``text_list`` at which
    each text starts.
    """
    token_tensors = [
        torch.tensor(text_pipeline(entry), dtype=torch.int64)
        for entry in batch
    ]

    # Start positions are the running sum of the preceding lengths: a leading 0
    # for the first text, and the last length dropped since nothing follows it.
    lengths = [0] + [tokens.size(0) for tokens in token_tensors]
    offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)

    text_list = torch.cat(token_tensors)
    return text_list.to(device), offsets.to(device)

In [11]:
# Create data loader to iterate over dataset in batches during training/evaluation
dataset = BDDataset(BDs)
batch_size = 64
# shuffle=False: batches are visited in the same order on every pass.
# collate_batch turns each batch of raw strings into (token indices, offsets).
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
# Model hyperparameters, passed to nvdm.NVDM(...) when the model is built.
hidden_size = 500
num_topics = 10

In [None]:
# Training setup

# How many times the whole training schedule is repeated.
outer_epochs = 200

# How many passes over the data each encoder/decoder turn gets when training
# in an alternating manner (with joint training: passes per outer epoch).
inner_epochs = 1

model = nvdm.NVDM(len(vocab), hidden_size, num_topics, 1, device)
model = model.to(device)
model.train()

# Optimizers for the alternating scheme, kept separate so that one half of the
# model can be updated while the other half stays fixed. The encoder-side
# optimizer also covers the linear layers that output mu and log(sigma)
# (not the most elegant method but works).
optim_encoder = torch.optim.Adam(
    [
        *model.encoder.parameters(),
        *model.mu.parameters(),
        *model.log_sigma.parameters(),
    ],
    lr=0.0001)
optim_decoder = torch.optim.Adam(model.decoder.parameters(), lr=0.001)

# Optimizer for joint training: covers every parameter of the model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

start_time = timer()

for epoch in range(outer_epochs):

    # Option 1 (disabled): train the encoder and decoder in turns, fixing the
    # parameters of the other every time. To enable it, wrap the loop below in
    #     for switch in range(0, 2):
    #         optimizer = optim_encoder if switch == 0 else optim_decoder
    # and report "Enc" / "Dec" instead of "All". The author's code trains the
    # decoder first; not sure if the order matters.

    # Option 2 (active): train everything together.
    for alt_epoch in range(inner_epochs):

        # Running totals over the batches of this pass, used only for printing.
        loss_sum = rec_sum = kl_sum = 0.0
        n = len(data_loader)

        for text, offsets in data_loader:
            optimizer.zero_grad()
            logits, loss_dict = model(text.to(device), offsets.to(device), kl_weight=1.0)
            loss = loss_dict["total"]
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            rec_sum += loss_dict["rec"].item()
            kl_sum += loss_dict["kl"].item()

        # Per-batch averages for this pass.
        model_str = "All"
        print(f"[Time: {timedelta(seconds=timer() - start_time)}, Epoch {epoch + 1}, {model_str} {alt_epoch + 1}] Loss {loss_sum/n}, Rec {rec_sum/n}, KL {kl_sum/n}")

Some observations:

- The original paper alternates between the encoder and decoder when training, i.e., it first trains the decoder for some number of iterations (e.g. 10) with the encoder's parameters fixed, then trains the encoder with the decoder's parameters fixed. Together this is one epoch, which is repeated until convergence. However, this results in poor training performance here: the KL is observed to fluctuate, and the encoder and decoder are unable to jointly converge. When they are trained together, both the reconstruction loss and the KL appear to go down.
- Right now we weight the reconstruction and KL losses equally: $L_{total} = L_{rec} + L_{KL}$. We could define a hyperparameter $\beta$ so that $L_{total} = L_{rec} + \beta L_{KL}$, which might help balance the two.

In [35]:
MODELSAVE_PATH = "./modelsaves/nvdm_200epochs.pt"
# torch.save(model.state_dict(), MODELSAVE_PATH)

# Rebuild the architecture with the notebook's hyperparameters, then restore
# the saved weights into it.
model = nvdm.NVDM(len(vocab), hidden_size, num_topics, 1, device)
# map_location remaps the stored tensors onto the device chosen for this
# session, so a checkpoint written on a GPU machine also loads on a CPU-only one.
model.load_state_dict(torch.load(MODELSAVE_PATH, map_location=device))
# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

NVDM(
  (embed_bow): EmbeddingBag(10001, 10001, mode=sum)
  (encoder): Sequential(
    (0): Linear(in_features=10001, out_features=500, bias=True)
    (1): Tanh()
    (2): Linear(in_features=500, out_features=500, bias=True)
    (3): Tanh()
  )
  (mu): Linear(in_features=500, out_features=10, bias=True)
  (log_sigma): Linear(in_features=500, out_features=10, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=10001, bias=True)
  )
)

In [36]:
# Pull out the vocab-topic matrix (called R in the paper): the weight of the
# decoder's linear layer, with dimensions |V| x K
# (vocab size x number of topics).
decoder = model.decoder[0]
# Work on a detached copy rather than on the model's own parameter.
weights = decoder.weight.detach().clone()
weights.size()

torch.Size([10001, 10])

In [37]:
# Look at some words
# manual KNN: for each probe word, list the vocabulary entries whose rows of
# the vocab-topic matrix are closest to it by cosine similarity.
from nltk.stem import PorterStemmer
PORTER_STEMMER = PorterStemmer()

# Set of words used in the original paper
candidates = ["weapons", "medical", "companies", "define", "israel", "book"]

for candidate in candidates:
    # Probe words are stemmed before the lookup; the printed neighbours
    # suggest the cleaned corpus is stemmed as well.
    test_word = PORTER_STEMMER.stem(candidate)

    # Skip probes that are not in the vocabulary. Without this check the
    # lookup below would either raise or resolve to a fallback index
    # (presumably '<unk>' with the legacy torchtext Vocab -- verify), and the
    # neighbours listed would not belong to the probe word at all.
    if test_word not in vocab.stoi:
        print(f"{test_word} is not in the vocabulary, skipping")
        print("-----------")
        continue

    idx = vocab.stoi[test_word]
    print(test_word, idx)

    # Show the 15 most similar words (based on cosine similarity). The probe
    # itself comes first with similarity 1.0. k is capped at the number of
    # rows so a small vocabulary cannot make topk fail.
    sims = F.cosine_similarity(weights[idx].unsqueeze(0), weights)
    sim_vals, sim_idxs = torch.topk(sims, min(15, weights.size(0)))

    # Show ith nearest word and its score.
    for word_idx, sim in zip(sim_idxs.tolist(), sim_vals.tolist()):
        print(f"{vocab.itos[word_idx]}\t{sim}")

    print("-----------")

weapon 6165
weapon	1.0
command	0.9859694838523865
alarm	0.9724387526512146
catv	0.9687132239341736
nondisclosur	0.9681808948516846
harsh	0.9667993783950806
microwav	0.9648943543434143
dana	0.961925208568573
surg	0.959961473941803
newest	0.9583815932273865
warfar	0.9568072557449341
lightweight	0.9552127718925476
humid	0.9520108699798584
lockhe	0.9518946409225464
electromagnet	0.9515492916107178
-----------
medic 235
medic	1.0
health	0.9257354736328125
healthcar	0.9163874983787537
care	0.8989317417144775
afford	0.8853657841682434
bodi	0.870684027671814
hospit	0.851254403591156
patient	0.8428614735603333
often	0.8355170488357544
both	0.8289807438850403
treatment	0.8267985582351685
physician	0.8263426423072815
supplement	0.8239941596984863
fine	0.8207406997680664
age	0.8182672262191772
-----------
compani 3
compani	1.0
corpor	0.9685456156730652
As	0.9668803811073303
offic	0.9664745330810547
maintain	0.9538002610206604
sourc	0.9514853358268738
factor	0.9485275745391846
meet	0.94704395532608

In [41]:
# Look at the words that weigh most heavily on each topic vector:
# for every column of the vocab-topic matrix, list the 30 words with the
# largest absolute weight.

V, K = weights.size()
for topic in range(K):
    print(f"Topic {topic+1}")
    # k is capped at the vocabulary size so topk cannot fail on a small vocab.
    vals, idxs = torch.topk(torch.abs(weights[:, topic]), min(30, V))
    # The word index gets its own name; it previously reused the topic loop's
    # variable, so the topic index was lost inside the loop body.
    for word_idx, val in zip(idxs.tolist(), vals.tolist()):
        print(f"{vocab.itos[word_idx]}\t{val}")
    print("------------")

Topic 1
iiot	1.5719436407089233
mifid	1.5622806549072266
sef	1.552869439125061
sdk	1.5519018173217773
ria	1.4965932369232178
modem	1.4920049905776978
guidewir	1.4848800897598267
interbodi	1.4842369556427002
cgm	1.4779086112976074
ghz	1.474395990371704
dexcom	1.469714879989624
plane	1.4693533182144165
clia	1.4622814655303955
interoper	1.4570252895355225
mbp	1.453446626663208
iot	1.4483671188354492
peek	1.4462167024612427
voip	1.4388659000396729
exoskeleton	1.4301104545593262
m2m	1.423481822013855
cpt	1.4216123819351196
cellfx	1.420039415359497
mridium	1.4180893898010254
bluetooth	1.4171274900436401
excim	1.412894368171692
saa	1.408990502357483
ott	1.4067158699035645
router	1.4054046869277954
oracl	1.400571346282959
labview	1.3922574520111084
------------
Topic 2
biotherapeut	2.1294736862182617
cd20	2.0815372467041016
dendrit	2.0779407024383545
lymphoblast	2.07721209526062
linker	2.0732641220092773
antigen	2.061290979385376
cellecti	2.0539138317108154
cd19	2.0537850856781006
microenviron

# Part 2 Full Evaluation

In [21]:
import os
import re

from nltk.stem import PorterStemmer


# Directory holding one saved state dict per topic count.
MODELSAVE_PATH = "./modelsaves"

# Maps number of topics -> restored NVDM model in evaluation mode.
models_k = dict()

# os.listdir returns entries in arbitrary order; sorting makes the load order
# (and which file wins if two share a topic count) reproducible.
for filename in sorted(os.listdir(MODELSAVE_PATH)):

    # The topic count is read from the second "_"-separated piece of the file
    # name with its first character dropped (presumably names shaped like
    # "nvdm_k5_....pt" -- the naming scheme is not shown here, confirm).
    # Files that do not fit are skipped rather than aborting the whole cell.
    try:
        k = int(filename.split("_")[1][1:])
    except (IndexError, ValueError):
        print(f"Skipping {filename!r}: cannot read a topic count from its name")
        continue

    # Loop-local names are used on purpose: the notebook-level `num_topics`
    # and `model` are no longer overwritten by whichever file comes last.
    model_k = nvdm.NVDM(len(vocab), hidden_size, k, 1, device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model_k.load_state_dict(
        torch.load(os.path.join(MODELSAVE_PATH, filename), map_location=device))
    model_k.eval()
    models_k[k] = model_k


In [28]:
def analysis(model, candidates=None, num_neighbors=15, num_topic_words=30):
    """Print a qualitative report on a trained model's vocab-topic matrix.

    The report has two parts, both read from the weight of the first decoder
    layer (known as R in the paper, dimensions |V| x K):

    1. For each candidate word, its nearest vocabulary entries by cosine
       similarity between rows of the matrix.
    2. For each topic (column), the words with the largest absolute weight.

    Args:
        model: Model exposing ``model.decoder[0].weight``.
        candidates: Words to probe; each is stemmed before the lookup.
            Defaults to the set of words used in the original paper.
        num_neighbors: How many nearest words to list per candidate. The
            candidate itself is the first entry.
        num_topic_words: How many words to list per topic.

    Returns:
        None. Everything is printed.
    """
    if candidates is None:
        # Set of words used in the original paper
        candidates = ["weapons", "medical", "companies", "define", "israel", "book"]

    stemmer = PorterStemmer()

    # Extract the vocab-topic matrix as a detached copy.
    decoder = model.decoder[0]
    weights = decoder.weight.detach().clone()
    vocab_size, topic_count = weights.size()

    # Cap k at the number of rows so topk cannot fail on a small vocabulary.
    neighbor_count = min(num_neighbors, vocab_size)
    for candidate in candidates:
        test_word = stemmer.stem(candidate)

        # Skip probes that are not in the vocabulary. Without this check the
        # lookup below would either raise or resolve to a fallback index
        # (presumably '<unk>' with the legacy torchtext Vocab -- verify), and
        # the neighbours listed would not belong to the probe word.
        if test_word not in vocab.stoi:
            print(f"{test_word} is not in the vocabulary, skipping")
            print("-----------")
            continue

        idx = vocab.stoi[test_word]
        print(test_word, idx)

        # Cosine similarity of this word's row against every row.
        sims = F.cosine_similarity(weights[idx].unsqueeze(0), weights)
        sim_vals, sim_idxs = torch.topk(sims, neighbor_count)

        # Show ith nearest word and its score.
        for word_idx, sim in zip(sim_idxs.tolist(), sim_vals.tolist()):
            print(f"{vocab.itos[word_idx]}\t{sim}")

        print("-----------")

    topic_word_count = min(num_topic_words, vocab_size)
    for topic in range(topic_count):
        print(f"Topic {topic+1}")
        vals, idxs = torch.topk(torch.abs(weights[:, topic]), topic_word_count)
        # The word index has its own name so it does not shadow the topic index.
        for word_idx, val in zip(idxs.tolist(), vals.tolist()):
            print(f"{vocab.itos[word_idx]}\t{val}")
        print("------------")

In [30]:
# Report for the model stored under 5 topics (models_k is keyed by topic count).
analysis(models_k[5])

weapon 6165
weapon	1.0
headset	0.9978045225143433
diod	0.9977712631225586
fpga	0.9975942969322205
upcom	0.9969573020935059
microwav	0.9967873692512512
ergonom	0.9967591762542725
microdisplay	0.9961642026901245
ubiquitor	0.995185375213623
microprocessor	0.9949306845664978
avx	0.9949303269386292
waveguid	0.9942797422409058
ip	0.9927443861961365
sonar	0.99266117811203
labview	0.9923931956291199
-----------
medic 235
medic	1.0
univers	0.9583813548088074
devic	0.9522219300270081
rapid	0.9455239176750183
inc	0.9404319524765015
care	0.9309444427490234
hospit	0.9130054712295532
promot	0.9118832349777222
behavior	0.9088248014450073
version	0.9068402647972107
healthcar	0.8899935483932495
rapidli	0.8878748416900635
foundat	0.8871327042579651
still	0.8849836587905884
patient	0.8839438557624817
-----------
compani 3
compani	1.0
entir	0.9895139932632446
corpor	0.9875907897949219
maintain	0.9875349998474121
collect	0.9873965382575989
adopt	0.9861133694648743
As	0.9843692183494568
influenc	0.983251214