## Introduction

The goal of this tutorial is to implement a Variational Autoencoder (VAE) for Topic Models. The aim is to give you sense of: 


*   How topic models can be implemented under Variational Autoencoder (VAE)
*   How the "*reparametrization trick*" enables backpropogation through latent variables


Frist, we need to import neccesary packages:

In [122]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn.functional as F
import math
import os
import string
import numpy as np
import random
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
from collections import OrderedDict
from tqdm.notebook import tqdm

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [123]:
###############
# Torch setup #
###############
print('Torch version: {}, CUDA: {}'.format(torch.__version__, torch.version.cuda))
cuda_available = torch.cuda.is_available()

if not torch.cuda.is_available():
  print('WARNING: You may want to change the runtime to GPU for faster training!')
  DEVICE = 'cpu'
else:
  DEVICE = 'cuda:0'

#########################
# Some helper functions #
#########################
def fix_seed(seed=None):
  """Sets the seeds of random number generators."""
  if seed is None:
    # Take a random seed
    seed = time.time()
  seed = int(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  return seed

fix_seed(1234)

Torch version: 1.7.0+cu101, CUDA: 10.1


1234

## Data Preprocessing
### Download dataset

We experiment on a standard news corpora: the ***20NewsGroups*** and download it using scikit-learn. This dataset consists of 20k news articles classified into 20 topics.

In [124]:
from sklearn.datasets import fetch_20newsgroups

train_news_group = fetch_20newsgroups(subset='train')
test_news_group = fetch_20newsgroups(subset='test')

train_data = train_news_group['data']
test_data = test_news_group['data']

print("Size of training data:", len(train_data))
print("Size of test data:", len(test_data))
print("All topics:", train_news_group.target_names)

Size of training data: 11314
Size of test data: 7532
All topics: ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']


### Preprocess Dataset

In this section, we define the functions to do conventional preprocessing and build the vocabulary.

In [125]:
def preprocess(samples):
    output = []
    for item in samples:
        words = item.replace('\n', '').strip().lower().split(' ')
        punctuations = (string.punctuation).replace("'", "")
        trans_table = str.maketrans('', '', punctuations)
        stripped_words = [word.translate(trans_table) for word in words]
        words = [str for str in stripped_words if str]
        words = [word for word in words if not word.isdigit()]
        words = [str for str in words if str]
        output.append(words)
    return output

train_prep = preprocess(train_data)
test_prep = preprocess(test_data)

In [126]:
def get_vocab(data):
    vocab = {}

    stops = set(stopwords.words('english'))
    ### -------------- TODO --------------- ###
    # remove stop words and count frequency of words
    for sent in data:
        for word in sent:
            if word in stops:
                continue
            if word in vocab:
                vocab[word] += 1
            else:
                vocab[word] = 1
    return vocab


vocab_total = get_vocab(train_prep + test_prep)
print("Total number of words in vocabulary:", len(vocab_total))
sorted(vocab_total.items(),key = lambda x:x[1],reverse = True)

vocab = vocab_total

'''
Here we filter vocabulary to save some training time, otherwise our model input dimension would be huge (V=350k+).
You can uncomment the line to include more words (around 52k words in vocabulary), which would help classifing the topics (Q4).
'''
vocab = {k:v for k,v in list(vocab_total.items())[:5000]}
# vocab = {k:v for k,v in vocab_total.items() if v > 3}
vocab_size = len(vocab)
print("Vocabulary size after filtering:", vocab_size)

word2idx = {k:n for n,(k,v) in enumerate(vocab.items())}

Total number of words in vocabulary: 356832
Vocabulary size after filtering: 5000


In [127]:
train_doc = [[word for word in doc if word in vocab] for doc in train_prep]
train_doc = [doc for doc in train_doc if len(doc) > 5]

test_doc = [[word for word in doc if word in vocab] for doc in test_prep]
test_doc = [doc for doc in test_doc if len(doc) > 5]

### Process Bag-of-words Inputs

Next we define multiple helper functions to create input batches. Our inputs are represented in bag-of-word (bow) where each article/document is represented with a vector of **V** elements. We will also do the batching in this section, so the inputs to the models would be in the dimension of *(batch_size, vocab_size)*.

In [128]:
from collections import Counter
def data_set(data_url):
    """process data input."""
    data = []
    word_count = []
    for words in data_url:
        word2freq = dict(Counter(words))
        doc = {}
        count = 0

        for word,freq in word2freq.items():
            doc[int(word2idx[word])] = freq
            count += freq

        if count > 0:
            data.append(doc)
            word_count.append(count)

    return data, word_count

In [129]:
def create_batches(data_size, batch_size, shuffle=True):
    """create a batch of indices."""
    batches = []
    ids = list(range(data_size))
    if shuffle:
        random.shuffle(ids)
    for i in range(data_size // batch_size):
        start = i * batch_size
        end = (i + 1) * batch_size
        batches.append(ids[start:end])
    # the batch of which the length is less than batch_size
    rest = data_size % batch_size
    if rest > 0:
        batches.append(ids[-rest:] + [-1] * (batch_size - rest))  # -1 as padding
    return batches

In [130]:
def fetch_data(data, count, idx_batch, vocab_size):
    """fetch input data by batch."""
    batch_size = len(idx_batch)
    data_batch = np.zeros((batch_size, vocab_size))
    count_batch = []
    mask = np.zeros(batch_size)
    indices = []
    values = []
    for i, doc_id in enumerate(idx_batch):
        if doc_id != -1:
            for word_id, freq in data[doc_id].items():
                data_batch[i, word_id] = freq
            count_batch.append(count[doc_id])
            mask[i]=1.0
        else:
            count_batch.append(0)
    return data_batch, count_batch, mask

### Question 1: Finish the neural structures of the VAE encoder and decoder, and the reparamerisation trick.

In [131]:
class TopicModel(nn.Module):
    def __init__(self, 
                 vocab_size,
                 input_size,
                 n_hidden,
                 n_topic, 
                 batch_size):
        super(TopicModel, self).__init__()

        self.vocab_size = vocab_size
        self.n_hidden = n_hidden
        self.n_topic = n_topic
        self.batch_size = batch_size
 
        ### -------------- TODO --------------- ###
        self.mu_layer = nn.Linear(n_hidden, n_topic)
        self.logsigm_layer = nn.Linear(n_hidden, n_topic)

        ### -------------- TODO --------------- ###
        self.encoder = nn.Sequential(nn.Linear(input_size, n_hidden),
                                     nn.ReLU(),
                                     nn.Linear(n_hidden, n_hidden),
                                     nn.ReLU())
        
        ### -------------- TODO --------------- ###
        self.decoder =  nn.Sequential(nn.Linear(n_topic, vocab_size),
                                      nn.LogSoftmax(dim=-1))


    def zero_bias(self,):
        self.mu_layer.bias.data.fill_(0.0)
        self.logsigm_layer.bias.data.fill_(0.0)
        
    def forward(self, input):

        # encoder forward
        doc_vec = self.encoder(input)
        mu = self.mu_layer(doc_vec)
        logsigm = self.logsigm_layer(doc_vec)
        
        # reparameterisation
        ### -------------- TODO --------------- ###
        eps = torch.normal(0,1, size=(self.batch_size, self.n_topic), device=DEVICE)
        z = mu + logsigm.exp()*eps

        # decoder forward
        logits = self.decoder(z)
        
        # recons
        ### -------------- TODO --------------- ###
        recons = -torch.sum(logits*input, dim=1)
        
        # kld
        ### -------------- TODO --------------- ###
        kld = -0.5 * torch.sum(1 - torch.square(mu) + 2*logsigm - torch.exp(2*logsigm), dim=1)
        
        loss = torch.mean(recons + kld)
        recons = torch.mean(recons)
        kld = torch.mean(kld)

        # print(loss, recons, kld)
        

        return loss, recons, kld

### Question 2: Finish the Training File

In [132]:
def main_train():
    num_epoch = 20
    batch_size = 64
    vocab_size = len(vocab)
    n_hidden = 256
    n_topic = 50
    learning_rate = 0.0001
    alternate_epochs = 5
    
    train_set, train_count = data_set(train_doc)
    
    ### -------------- TODO --------------- ###
    model = TopicModel(vocab_size=vocab_size, 
                       input_size=vocab_size, 
                       n_hidden=n_hidden, 
                       n_topic=n_topic, 
                       batch_size=batch_size)
    
    model.zero_bias()
    model.to(DEVICE)

    ### -------------- TODO --------------- ###
    optimizer_enc = torch.optim.Adam([{'params':model.encoder.parameters()}, 
                                      {'params':model.mu_layer.parameters()},
                                      {'params':model.logsigm_layer.parameters()}],
                                     lr = learning_rate,
                                     eps= 1e-8)
    optimizer_dec = torch.optim.Adam([{'params':model.decoder.parameters()}], 
                                     lr = learning_rate,
                                     eps= 1e-8)
    
    for epoch in range(num_epoch):
        train_batches = create_batches(len(train_set), batch_size, shuffle=True)
        model.train() 
        
        ### -------------- TODO --------------- ###
        # Question: why do we need two optimizers #
        # Answer: To help escape local minimum, but it might not be necessary.
        for switch in range(0, 2): 
            if switch == 0:
                optimizer = optimizer_dec
                print_mode = 'updating decoder'
            else:
                optimizer = optimizer_enc
                print_mode = 'updating encoder'
                
            loss_epoch = 0.0
            recons_epoch = 0.0
            kld_epoch = 0.0
            count = 0
    
            for i in range(alternate_epochs):
                                 
                for idx_batch in train_batches:
                    data_batch, count_batch, mask = fetch_data(train_set, train_count, idx_batch, vocab_size)
                    input = torch.from_numpy(data_batch).float().to(DEVICE)
                    loss, recons, kld = model(input)
                    
                    # optimize
                    optimizer.zero_grad()      
                    loss.backward()        
                    optimizer.step()        
                    loss_epoch += loss
                    recons_epoch += recons
                    kld_epoch += kld
                    count += 1

            print(f'Epoch {epoch}, loss={loss_epoch/count}, recons={recons_epoch/count}, kld={kld_epoch/count}')

    return model
    

In [133]:
model = main_train()

Epoch 0, loss=596.6555786132812, recons=596.5907592773438, kld=0.06485079973936081
Epoch 0, loss=582.9100952148438, recons=579.7217407226562, kld=3.187455892562866
Epoch 1, loss=550.35205078125, recons=546.1802978515625, kld=4.171688556671143
Epoch 1, loss=523.10986328125, recons=511.15966796875, kld=11.950454711914062
Epoch 2, loss=515.8670043945312, recons=501.99493408203125, kld=13.872203826904297
Epoch 2, loss=510.50341796875, recons=497.9371032714844, kld=12.566522598266602
Epoch 3, loss=506.73638916015625, recons=492.8454284667969, kld=13.891047477722168
Epoch 3, loss=503.03216552734375, recons=490.366455078125, kld=12.665414810180664
Epoch 4, loss=499.25836181640625, recons=486.4047546386719, kld=12.853713035583496
Epoch 4, loss=496.37567138671875, recons=483.6053466796875, kld=12.770222663879395
Epoch 5, loss=493.532470703125, recons=480.9125671386719, kld=12.620038032531738
Epoch 5, loss=491.2666320800781, recons=478.1276550292969, kld=13.13899040222168
Epoch 6, loss=488.98471

In [134]:
# save model to use in Q4
torch.save(model.state_dict(), "vae.pt")

### Question 3: Code qualitative analysis for topics (p(x|z))
Now that we have the VAE trained with 50 candidate topics, we can explore how the VAE model cluster words with similar topics together.

In the following section, you will also need to evaluate the perplexity of the VAE model.

In [142]:
#Add meta information (authors, time, geolocation etc.) to improve quality of the topics

associations = {
    'jesus': ['prophet', 'jesus', 'matthew', 'christ', 'worship', 'church'],
    'comp ': ['floppy', 'windows', 'microsoft', 'monitor', 'workstation', 'macintosh', 
              'printer', 'programmer', 'colormap', 'scsi', 'jpeg', 'compression'],
    'car  ': ['wheel', 'tire'],
    'polit': ['amendment', 'libert', 'regulation', 'president'],
    'crime': ['violent', 'homicide', 'rape'],
    'midea': ['lebanese', 'israel', 'lebanon', 'palest'],
    'sport': ['coach', 'hitter', 'pitch'],
    'gears': ['helmet', 'bike'],
    'nasa ': ['orbit', 'spacecraft'],
}
def identify_topic_in_line(line):
    topics = []
    for topic, keywords in associations.items():
        for word in keywords:
            if word in line:
                topics.append(topic)
                break
    return topics

def print_top_words(beta, feature_names, n_top_words=10):
    print('---------------Printing the Topics------------------')
    for i in range(len(beta)):
        line = " ".join([feature_names[j][0] for j in beta[i].argsort()[:-n_top_words - 1:-1]])
        topics = identify_topic_in_line(line)
        print('|'.join(topics))
        print('     {}'.format(line))
    print('---------------End of Topics------------------')


def print_perp(model):
    cost=[]
    model.eval()
    test_set, test_count = data_set(test_doc)
    test_batches = create_batches(len(test_set), 64)
    
    ### -------------- TODO --------------- ###
    loss_sum = 0
    word_count = 0
    for idx_batch in test_batches:
        data_batch, count_batch, mask = fetch_data(test_set, test_count, idx_batch, vocab_size)
        test_input = torch.from_numpy(data_batch).float().to(DEVICE)
        loss, recons, kld = model(test_input)

        loss_sum += (loss.item() * 64)
        word_count += np.sum(count_batch)

    ppl = np.exp(loss_sum / word_count)
    print('The approximated perplexity is: ', ppl)

In [None]:
# perplexity on test data
print_perp(model)

In [143]:
# model latent topics
emb = model.decoder[0].weight.data.cpu().numpy().T
print_top_words(emb, sorted(vocab.items(), key=lambda x:x[1]))

---------------Printing the Topics------------------

     grubb loney gripe jupiter's celica grubbsubject smog hill ulysses marlborough

     those'new adeos effi highprimers dangerous runsgameand pharisees used								 roundassuming autotheft

     war ago regularly center box goes truth office ask buy

     essence ab condition putting anywhere approximately ialines intend hill gravity

     dlecointgarnetacnsfsuedu rosenblattsubject sys ludicrous suspended noobvious lloyd piaget requirements galaxies	o

     lloyd csyphersuafhpuarkedu jupiter rodcfchpcom flynnuniversity piaget systems acceptance september celicagt

     classic intend launched 32in used								 gravity approximately rumors travel griffin

     rosenblattsubject brain andwingless it'shard titan's updating wrongness posting scaliness aroundan

     narrowed claim11 topicsthe messagex jap tocomprehend 35andy dansjdcgssmotcom theseason isclean

     theodore smith year's richardson escape convenient purchasing rod for	

### Question 4: Use Topics to do Classification
In this section, you will use both the article and the labels to train a topic classifier. Firstly, you may train a vanilla classifier, and you are likely to get around 83% validation accuracy with a vocabulary of 50k (the accuracy might be lower with a small vocabulary). Then you can use the pre-trained VAE encoder as the classifier encoder and fine-tune it to see what happens.

In [115]:
'''
Uncomment these lines to train the classifier on larger vocabulary
NOTE: if you want to use the pre-trained VAE encoder, the vocab size for the classifier should be the same as the VAE model
'''

# vocab = {k:v for k,v in vocab_total.items() if v > 3}
# vocab_size = len(vocab)
# print("Vocabulary size after filtering:", vocab_size)
# word2idx = {k:n for n,(k,v) in enumerate(vocab.items())}

Vocabulary size after filtering: 52199


In [103]:
def fetch_labelled_data(data, labels, idx_batch, vocab_size):
    """fetch input data and labels by batch."""
    batch_size = len(idx_batch)
    data_batch = np.zeros((batch_size, vocab_size))
    label_batch = []

    texts_batch = [data[i] for i in idx_batch]
    label_batch = [labels[i] for i in idx_batch]

    for i, text in enumerate(texts_batch):
        for word in text:
            if word in vocab:
                data_batch[i, word2idx[word]] += 1

    return data_batch, np.array(label_batch)

In [104]:
class VanillaClassifier(nn.Module):
    def __init__(self, input_size, n_hidden, n_class, dp):
        super().__init__()

        ### --------------------- TODO ----------------------- ###
        # construct a same encoder architecture as the VAE encoder 
        self.encoder = nn.Sequential(nn.Linear(input_size, n_hidden),
                                     nn.ReLU(),
                                     nn.Linear(n_hidden, n_hidden),
                                     nn.ReLU(),
                                    )
        self.dropout = nn.Dropout(dp)
        self.output = nn.Linear(n_hidden, n_class, bias=True)

    def forward(self, input):
        doc_vec = self.dropout(self.encoder(input))
        logits = self.output(doc_vec)
        return logits

class VAEClassifier(nn.Module):

    def __init__(self, vae, n_class, dp):
        super().__init__()

        self.encoder = vae.encoder
        self.vae_output = vae.n_hidden
        
        self.dropout = nn.Dropout(dp)
        self.output = nn.Linear(self.vae_output, n_class, bias=True)

    def forward(self, input):
        doc_vec = self.dropout(self.encoder(input))

        logits = self.output(doc_vec)

        return logits

In [105]:
def evaluate(classifier, idx_batches, data, labels, vocab_size, criterion):
    with torch.no_grad():
        total_loss = 0
        total_acc = 0
        val_count = 0

        for idx_batch in idx_batches:
            ### --------------------- TODO ----------------------- ###
            # compute validation loss and accuracy
            data_batch, label_batch = fetch_labelled_data(data, labels, idx_batch, vocab_size)

            input = torch.from_numpy(data_batch).float().to(DEVICE)
            target = torch.from_numpy(label_batch).to(DEVICE)

            pred = classifier(input)

            _, predictions = torch.max(pred, dim=1)
            loss = criterion(pred, target)

            total_acc += torch.mean((predictions == target).float())
            total_loss += loss.item()
            val_count += 1

    return total_loss/val_count, total_acc/val_count

In [108]:
def cls_train(train_data, train_labels, valid_data, valid_labels, vocab_size):
    num_epoch = 10
    batch_size = 64
    dropout = 0.1
    n_hidden = 64
    learning_rate = 0.0001
    
    # model.load_state_dict(torch.load("vae.pt"))
    # classifier = VAEClassifier(model, 20, dropout)
    classifier = VanillaClassifier(vocab_size, n_hidden, 20, dropout)

    classifier.to(DEVICE)

    optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr = learning_rate)
    
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(num_epoch):
        train_batches = create_batches(len(train_prep), batch_size, shuffle=True)
        valid_batches = create_batches(len(test_prep), batch_size, shuffle=False)
        
                
        train_loss = 0.0
        count = 0
        acc = 0
        classifier.train()              
        for idx_batch in train_batches:
            optimizer.zero_grad()
            ### --------------------- TODO ----------------------- ###
            # finish training loop
            data_batch, label_batch = fetch_labelled_data(train_data, train_labels, idx_batch, vocab_size)
            input = torch.from_numpy(data_batch).float().to(DEVICE)
            target = torch.from_numpy(label_batch).to(DEVICE)

            pred = classifier(input)

            _, predictions = torch.max(pred, dim=1)
            acc += torch.mean((predictions == target).float())

            loss = criterion(pred, target)
            
            # optimize 
            loss.backward()        
            optimizer.step()        
            train_loss += loss
            count += 1
        
        # validation
        classifier.eval()
        valid_loss, valid_acc = evaluate(classifier, valid_batches, valid_data, test_labels, vocab_size, criterion)
        print(f'Epoch {epoch},\ttrain_loss={train_loss/count:.3f},\ttrain_acc={acc/count:.3f},\tvalid_loss={valid_loss:.3f},\tvalid_acc={valid_acc:.3f}')

In [109]:
cls_train(train_prep, train_labels, test_prep, test_labels, vocab_size)

Epoch 0,	train_loss=2.964,	train_acc=0.139,	valid_loss=2.917,	valid_acc=0.258
Epoch 1,	train_loss=2.785,	train_acc=0.393,	valid_loss=2.699,	valid_acc=0.425
Epoch 2,	train_loss=2.484,	train_acc=0.513,	valid_loss=2.412,	valid_acc=0.514
Epoch 3,	train_loss=2.152,	train_acc=0.594,	valid_loss=2.135,	valid_acc=0.579
Epoch 4,	train_loss=1.856,	train_acc=0.653,	valid_loss=1.907,	valid_acc=0.613
Epoch 5,	train_loss=1.616,	train_acc=0.688,	valid_loss=1.741,	valid_acc=0.628
Epoch 6,	train_loss=1.438,	train_acc=0.714,	valid_loss=1.618,	valid_acc=0.639
Epoch 7,	train_loss=1.296,	train_acc=0.733,	valid_loss=1.526,	valid_acc=0.644
Epoch 8,	train_loss=1.173,	train_acc=0.755,	valid_loss=1.453,	valid_acc=0.645
Epoch 9,	train_loss=1.077,	train_acc=0.771,	valid_loss=1.399,	valid_acc=0.652


Since the training of VAE takes a long time, the vocabulary was truncated into 5000 words. Therefore, the benefit from VAE pre-training might not seem very obvious. You may try and train a new VAE with a larger vocabulary (e.g. 50k and train for 20 epochs), it would help in classifing the topics.

In addition, the validation accuracy might be low if your vocabulary size is only 5,000 (valid acc around 65%). The below code is just a baseline utilizing all the words, and it would achieve around 83% accuracy on the validation set.

In [110]:
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer()
train_features = vectorizer.fit_transform(train_data)
test_features = vectorizer.transform(test_data)

train_labels = np.array(train_news_group['target'])
test_labels = np.array(test_news_group['target'])

vocab_size = len(vectorizer.vocabulary_)
print(vocab_size)

130107


In [111]:
def fetch_labelled_data(features, labels, idx_batch, vocab_size=None):
    idxs = np.array(idx_batch)

    feature_batch = features[idxs, :].toarray()
    label_batch = labels[idxs]
    return feature_batch, label_batch

In [112]:
cls_train(train_features, train_labels, test_features, test_labels, vocab_size)

Epoch 0,	train_loss=2.844,	train_acc=0.261,	valid_loss=2.673,	valid_acc=0.510
Epoch 1,	train_loss=2.259,	train_acc=0.746,	valid_loss=2.130,	valid_acc=0.734
Epoch 2,	train_loss=1.587,	train_acc=0.880,	valid_loss=1.642,	valid_acc=0.793
Epoch 3,	train_loss=1.084,	train_acc=0.916,	valid_loss=1.334,	valid_acc=0.807
Epoch 4,	train_loss=0.773,	train_acc=0.938,	valid_loss=1.137,	valid_acc=0.821
Epoch 5,	train_loss=0.563,	train_acc=0.961,	valid_loss=1.012,	valid_acc=0.822
Epoch 6,	train_loss=0.426,	train_acc=0.970,	valid_loss=0.921,	valid_acc=0.827
Epoch 7,	train_loss=0.327,	train_acc=0.976,	valid_loss=0.862,	valid_acc=0.828
Epoch 8,	train_loss=0.258,	train_acc=0.983,	valid_loss=0.835,	valid_acc=0.828
Epoch 9,	train_loss=0.207,	train_acc=0.986,	valid_loss=0.797,	valid_acc=0.828
