# ProdLDA implementation on the 20 Newsgroups dataset

## Loading libraries

In [None]:
# Load libraries
import logging
import pyro
import torch
import pandas as pd
import numpy as np

import torch.nn.functional as F

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

from zzz_utils import *
from prod_lda import *

# Log at INFO level, prefixing each message with the milliseconds elapsed
# since the logging module was loaded.
logging.basicConfig(
    level=logging.INFO,
    format="%(relativeCreated) 9d %(message)s",
)

# Fix the random seed and start from an empty Pyro parameter store.
pyro.set_rng_seed(1)
pyro.clear_param_store()

## Loading and processing the data

In [None]:
# Fetch the 20 Newsgroups corpus (the 'all' subset).
news = fetch_20newsgroups(subset='all')

# Bag-of-words counts. Terms that occur in more than half of the documents,
# in fewer than 30 documents, or in the English stop-word list are dropped.
vectorizer = CountVectorizer(max_df=0.5, min_df=30, stop_words='english')
# Dense (n_documents, n_terms) count matrix as a torch tensor.
docs = torch.from_numpy(vectorizer.fit_transform(news['data']).toarray())

# Vocabulary table: one row per retained term, with its column position in `docs`.
vocab = pd.DataFrame(columns=['word', 'index'])
# `get_feature_names()` was deprecated in scikit-learn 1.0 and removed in 1.2;
# prefer `get_feature_names_out()` and fall back only on older releases.
vocab['word'] = (
    vectorizer.get_feature_names_out()
    if hasattr(vectorizer, 'get_feature_names_out')
    else vectorizer.get_feature_names()
)
vocab['index'] = vocab.index

In [None]:
# Report the vocabulary size and the shape of the document-term matrix.
print(f'Dictionary size: {len(vocab):d}')
print(f'Corpus size: {docs.shape}')

## Testing variational inference

In [None]:
# Drop any parameters left over from a previous run of this cell.
pyro.clear_param_store()

nTopics = 10

# Use the first GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Cast the count matrix to float before moving it to the chosen device.
docs = docs.float()
docs = docs.to(device)

# Fit ProdLDA; the returned object is indexed by 'losses' and 'prodLDA' in later cells.
obj = fit_prod_lda(
    D=docs,
    nTopics=nTopics,
    nEpochs=20,
    batch_size=32,
    lr=0.01,
    seed=123,
)

In [None]:
# Plot the recorded ELBO loss values
losses = obj['losses']

plt.figure(figsize=(5, 2))
plt.gca().plot(losses)
plt.gca().set_xlabel("SVI step")
# Trailing semicolon keeps the notebook from echoing the returned Text object.
plt.gca().set_ylabel("ELBO loss");

## Word clouds

In [None]:
# Draw one word cloud per row of `phi` (presumably one row per topic; confirm
# against `get_phi` in prod_lda).
prodLDA = obj['prodLDA']
phi = prodLDA.get_phi()

# Size the grid from the number of rows instead of hard-coding 4x3, so that
# any topic count fits and no IndexError is raised for more than 12 topics.
n_cols = 3
n_rows = max(1, -(-phi.shape[0] // n_cols))  # ceiling division, at least one row

# squeeze=False keeps `axs` two-dimensional even when there is a single row.
# The figure size scales with the grid and equals the original (12, 12) for 4x3.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows), squeeze=False)
for n in range(phi.shape[0]):
    i, j = divmod(n, n_cols)
    plot_word_cloud(scale_zero_one(phi[n]), vocab, axs[i, j], 'Topic %d' % (n + 1))

# Hide every unused panel, not only the last one, so no empty frames remain.
for ax in axs.flat[phi.shape[0]:]:
    ax.axis('off')

plt.show()