<div align='left' style="width:29%;overflow:hidden;">
<a href='http://inria.fr'>
<img src='https://github.com/lmarti/jupyter_custom/raw/master/imgs/inr_logo_rouge.png' alt='Inria logo' title='Inria logo'/>
</a>
</div>

# Hierarchical Topic Modeling

> In this notebook we're going to extend our previous topic modeling approaches to model hierarchical topics.
Even though single-level topic modeling helps navigate the vast number of papers in the CORD-19 dataset, our hypothesis is that hierarchical topic modeling will provide a much easier way to sift through them.

The following references have caught our attention:

- https://radimrehurek.com/gensim/models/hdpmodel.html: an implementation of Hierarchical Dirichlet Processes (HDP) using the topic modelling library `gensim`.
- https://github.com/joewandy/hlda: an implementation of Hierarchical Latent Dirichlet Allocation (hLDA) in Python.
- https://datascience.stackexchange.com/questions/128/latent-dirichlet-allocation-vs-hierarchical-dirichlet-process: a comparison between LDA and HDP.
- https://developer.squareup.com/blog/inferring-label-hierarchies-with-hlda/: a write-up about Square's experience using hLDA to hierarchically classify customer support articles.
- https://www.hindawi.com/journals/sp/2017/4382348/: a journal article.

Preliminary notes:

- LDA models each document as a mixture of a fixed number of topics, and each topic as a distribution over words; both are drawn from Dirichlet priors.
- hLDA extends LDA by arranging topics in a tree: each document follows a path from the root to a leaf, and its words are drawn from the topics along that path, from general (root) to specific (leaves).
- HDP's main difference with respect to LDA is that the number of topics isn't a hyperparameter but is inferred from the data. However, we discard it because it doesn't build a hierarchical topic structure.
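To make the first note concrete, here is a minimal sketch of LDA's generative story in plain Python. The toy topics, the `alpha` value and the helper names are illustrative assumptions, not part of any library; fitting an LDA model amounts to inferring the hidden topics and per-document proportions that best explain observed documents under this process.

```python
import random


def sample_dirichlet(alphas, rng):
    """Draw from a Dirichlet distribution by normalizing Gamma samples."""
    draws = [rng.gammavariate(a, 1.0) for a in alphas]
    total = sum(draws)
    return [d / total for d in draws]


def generate_lda_document(topics, alpha, doc_length, rng):
    """Generate one document following LDA's generative story.

    `topics` is a list of word distributions (one dict word -> prob per topic).
    """
    # 1. Draw the document's topic proportions from a symmetric Dirichlet.
    theta = sample_dirichlet([alpha] * len(topics), rng)
    words = []
    for _ in range(doc_length):
        # 2. Pick a topic for this word position according to theta...
        topic = rng.choices(range(len(topics)), weights=theta)[0]
        # 3. ...then pick a word from that topic's word distribution.
        topic_words, probs = zip(*topics[topic].items())
        words.append(rng.choices(topic_words, weights=probs)[0])
    return words


rng = random.Random(0)
# Two hand-made, hypothetical topics over a toy vocabulary.
toy_topics = [
    {"virus": 0.5, "infection": 0.3, "host": 0.2},
    {"patient": 0.6, "clinical": 0.3, "pneumonia": 0.1},
]
print(generate_lda_document(toy_topics, alpha=0.5, doc_length=8, rng=rng))
```

A small `alpha` concentrates each document on few topics; a large one spreads its words more evenly across all of them.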

First, we'll try hierarchical topic modeling with hLDA, and then we'll compare it to a manually built hierarchical LDA model.

## Hierarchical Latent Dirichlet Allocation (hLDA)

This technique was presented in the 2004 NeurIPS paper "Hierarchical Topic Models and the Nested Chinese Restaurant Process" by David Blei et al. available at: https://papers.nips.cc/paper/2466-hierarchical-topic-models-and-the-nested-chinese-restaurant-process.pdf.

A quick Google search yields at least two implementations:

- https://github.com/blei-lab/hlda: implemented in C by the original authors. Last commit was in 2014.
- https://github.com/joewandy/hlda: implemented in Python. Last commit was in 2017.

We'll use the second one since it publishes a Jupyter Notebook with an example using the library.
First, we'll install it.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
!pip install -q -r requirements.txt

Now we'll proceed to load the CORD-19 dataset papers.

In [None]:
from pathlib import Path

from risotto.references import load_papers_from_metadata_file, paper_as_markdown


CORD19_DATASET_FOLDER = Path("./datasets/CORD-19-research-challenge")

papers, _ = load_papers_from_metadata_file(CORD19_DATASET_FOLDER)

We've loaded the papers into memory.
Now, we'll process them to produce text strings with their contents.

In [None]:
from risotto.lda import process_papers_file_contents


docs = process_papers_file_contents(papers)

The last preprocessing step is to vectorize each paper.
We'll represent them using `scikit-learn`'s `CountVectorizer`.
We purposefully avoid representations such as `tf-idf`: LDA is a generative model of raw token counts, so it expects integer term frequencies rather than reweighted values.
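As an illustration, here is a minimal pure-Python sketch of what `CountVectorizer` with `max_features` does on already-tokenized text: keep the most frequent tokens across the corpus, index them, and count their occurrences per document. The function name is hypothetical and this is only an approximation (scikit-learn's tie-breaking between equally frequent tokens may differ), not a replacement for the real class.

```python
from collections import Counter


def count_vectorize(tokenized_docs, max_features):
    """Build a capped vocabulary and a dense document-term count matrix."""
    # Keep the `max_features` most frequent tokens across the whole corpus.
    corpus_counts = Counter(tok for doc in tokenized_docs for tok in doc)
    kept = [tok for tok, _ in corpus_counts.most_common(max_features)]
    # Assign column indices in alphabetical order, as scikit-learn does.
    vocab = {tok: idx for idx, tok in enumerate(sorted(kept))}
    matrix = []
    for doc in tokenized_docs:
        row = [0] * len(vocab)
        for tok in doc:
            if tok in vocab:  # out-of-vocabulary tokens are dropped
                row[vocab[tok]] += 1
        matrix.append(row)
    return matrix, vocab


toy_docs = [["virus", "host", "virus"], ["patient", "virus"], ["bat", "host"]]
toy_matrix, toy_vocab = count_vectorize(toy_docs, max_features=3)
print(toy_vocab)   # {'host': 0, 'patient': 1, 'virus': 2}
print(toy_matrix)  # [[1, 0, 2], [0, 1, 1], [1, 0, 0]]
```

Each row is a bag-of-words vector: word order is discarded and only the counts remain, which is exactly the input LDA works with.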

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
from fastprogress.fastprogress import progress_bar

from risotto.lda import tokenizer


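def get_hlda_corpus(docs, max_vocab_size=2**13):
    """Tokenize `docs` and map each token to its index in a capped vocabulary."""
    # Fit a vectorizer only to build the vocabulary: the `max_vocab_size`
    # most frequent tokens across the corpus.
    count_vectorizer = CountVectorizer(
        tokenizer=tokenizer,
        lowercase=True,
        max_features=max_vocab_size,
    )
    count_vectorizer.fit(docs)
    vocab = count_vectorizer.vocabulary_  # token -> column index
    docs_tokenized = []
    docs_tokens_idxs = []
    for doc in progress_bar(docs):
        tokens = [token.lower() for token in tokenizer(doc)]
        # Keep only in-vocabulary tokens; hLDA expects lists of integer ids.
        idxs = [vocab[token] for token in tokens if token in vocab]
        docs_tokenized.append(tokens)
        docs_tokens_idxs.append(idxs)
    # get_feature_names() was removed in scikit-learn 1.2;
    # get_feature_names_out() (available since 1.0) replaces it.
    vocab_list = list(count_vectorizer.get_feature_names_out())
    return docs_tokenized, docs_tokens_idxs, vocab_list


docs_tokenized, docs_tokens_idxs, vocab = get_hlda_corpus(docs)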



`docs_tokenized` is a list of lists with the tokens of each paper.
`docs_tokens_idxs` is a list of lists with the vocabulary indices of each paper's tokens (tokens outside the vocabulary are dropped).
Finally, `vocab` is the list of vocabulary terms, where `vocab[i]` is the token with index `i`.

In [None]:
import random

from risotto.hlda_sampler import HierarchicalLDA


hlda = HierarchicalLDA(
    corpus=random.sample(docs_tokens_idxs, int(len(docs_tokens_idxs) * 0.1)),
    vocab=vocab,
    alpha=10,
    gamma=1,
    eta=0.1,
    seed=0,
    verbose=True,
    num_levels=3,
)
hlda.estimate(
    num_samples=500,
    display_topics=50,
    n_words=5,
    with_weights=False,
)

HierarchicalLDA sampling

.................................................. 50
topic=0 level=0 (documents=3888): disease, study, high, include, increase, 
    topic=1 level=1 (documents=1195): virus, infection, viral, human, coronavirus, 
        topic=2 level=2 (documents=58): virus, host, rna, species, bat, 
        topic=3 level=2 (documents=214): patient, covid-19, respiratory, clinical, pneumonia, 
        topic=7 level=2 (documents=156): assay, detection, virus, test, method, 
        topic=8 level=2 (documents=191): protein, cell, bind, domain, antibody, 
        topic=12 level=2 (documents=62): sample, virus, sequence, detect, detection, 
        topic=15 level=2 (documents=299): health, public, outbreak, pandemic, care, 
        topic=28 level=2 (documents=40): sequence, genome, viral, gene, database, 
        topic=35 level=2 (documents=15): und, die, der, des, ist, 
        topic=36 level=2 (documents=67): cell, apoptosis, pathway, signal, activation, 
        topic=40 leve

.................................................. 100
topic=0 level=0 (documents=3888): disease, study, include, increase, analysis, 
    topic=1 level=1 (documents=1143): virus, infection, viral, human, influenza, 
        topic=2 level=2 (documents=53): virus, host, species, bat, rna, 
        topic=3 level=2 (documents=195): patient, covid-19, clinical, respiratory, pneumonia, 
        topic=7 level=2 (documents=135): assay, detection, test, sample, method, 
        topic=8 level=2 (documents=178): protein, cell, bind, domain, fusion, 
        topic=12 level=2 (documents=60): sample, sequence, virus, strain, detect, 
        topic=15 level=2 (documents=284): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=44): sequence, genome, viral, rna, gene, 
        topic=35 level=2 (documents=12): und, die, der, des, sars, 
        topic=36 level=2 (documents=72): cell, apoptosis, pathway, signal, activation, 
        topic=40 level=2 (documents=20): protein, ex

.................................................. 150
topic=0 level=0 (documents=3888): disease, study, include, increase, control, 
    topic=1 level=1 (documents=1135): virus, infection, viral, human, influenza, 
        topic=2 level=2 (documents=57): host, virus, species, bat, rna, 
        topic=3 level=2 (documents=192): patient, covid-19, clinical, pneumonia, respiratory, 
        topic=7 level=2 (documents=128): assay, detection, method, sample, test, 
        topic=8 level=2 (documents=179): protein, cell, bind, domain, peptide, 
        topic=12 level=2 (documents=57): sample, sequence, detect, isolate, species, 
        topic=15 level=2 (documents=290): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=45): sequence, genome, viral, rna, gene, 
        topic=35 level=2 (documents=11): und, die, der, des, sars, 
        topic=36 level=2 (documents=71): cell, apoptosis, pathway, signal, activation, 
        topic=40 level=2 (documents=16): protein,

.................................................. 200
topic=0 level=0 (documents=3888): disease, study, include, analysis, new, 
    topic=1 level=1 (documents=1129): virus, infection, viral, human, respiratory, 
        topic=2 level=2 (documents=58): virus, host, species, bat, rna, 
        topic=3 level=2 (documents=186): patient, covid-19, clinical, pneumonia, day, 
        topic=7 level=2 (documents=127): assay, detection, sample, test, method, 
        topic=8 level=2 (documents=177): protein, cell, bind, domain, peptide, 
        topic=12 level=2 (documents=61): sample, sequence, detect, species, canine, 
        topic=15 level=2 (documents=287): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=47): sequence, genome, viral, datum, database, 
        topic=35 level=2 (documents=12): und, die, der, des, ist, 
        topic=36 level=2 (documents=72): cell, apoptosis, pathway, signal, expression, 
        topic=40 level=2 (documents=16): protein, expre

.................................................. 250
topic=0 level=0 (documents=3888): disease, study, include, increase, control, 
    topic=1 level=1 (documents=1112): virus, infection, viral, human, respiratory, 
        topic=2 level=2 (documents=53): virus, host, species, bat, rna, 
        topic=3 level=2 (documents=180): patient, covid-19, clinical, pneumonia, case, 
        topic=7 level=2 (documents=123): assay, detection, method, sample, test, 
        topic=8 level=2 (documents=171): protein, cell, bind, domain, peptide, 
        topic=12 level=2 (documents=62): sample, sequence, virus, isolate, species, 
        topic=15 level=2 (documents=290): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=48): sequence, genome, viral, rna, gene, 
        topic=35 level=2 (documents=10): und, die, der, des, ist, 
        topic=36 level=2 (documents=73): cell, apoptosis, pathway, signal, protein, 
        topic=40 level=2 (documents=15): protein, expressio

.................................................. 300
topic=0 level=0 (documents=3888): disease, study, include, control, model, 
    topic=1 level=1 (documents=1094): virus, infection, viral, human, coronavirus, 
        topic=2 level=2 (documents=47): host, virus, species, bat, rna, 
        topic=3 level=2 (documents=175): patient, covid-19, clinical, respiratory, pneumonia, 
        topic=7 level=2 (documents=123): assay, detection, test, sample, method, 
        topic=8 level=2 (documents=171): protein, cell, bind, domain, peptide, 
        topic=12 level=2 (documents=61): sample, sequence, detect, isolate, species, 
        topic=15 level=2 (documents=284): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=46): sequence, genome, viral, rna, gene, 
        topic=35 level=2 (documents=13): und, die, der, des, sars, 
        topic=36 level=2 (documents=73): cell, apoptosis, pathway, signal, activation, 
        topic=40 level=2 (documents=18): protein, 

.................................................. 350
topic=0 level=0 (documents=3888): disease, study, include, control, infectious, 
    topic=1 level=1 (documents=1109): virus, infection, viral, human, respiratory, 
        topic=2 level=2 (documents=49): host, species, virus, bat, rna, 
        topic=3 level=2 (documents=175): patient, covid-19, clinical, pneumonia, respiratory, 
        topic=7 level=2 (documents=120): assay, detection, sample, test, method, 
        topic=8 level=2 (documents=169): protein, cell, bind, domain, peptide, 
        topic=12 level=2 (documents=67): sample, sequence, detect, isolate, species, 
        topic=15 level=2 (documents=285): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=49): sequence, genome, viral, rna, datum, 
        topic=35 level=2 (documents=14): und, die, der, des, sars, 
        topic=36 level=2 (documents=74): cell, apoptosis, pathway, signal, protein, 
        topic=40 level=2 (documents=15): protei

.................................................. 400
topic=0 level=0 (documents=3888): disease, study, include, model, control, 
    topic=1 level=1 (documents=1098): virus, infection, viral, human, respiratory, 
        topic=2 level=2 (documents=50): host, virus, species, bat, evolution, 
        topic=3 level=2 (documents=173): patient, covid-19, pneumonia, clinical, case, 
        topic=7 level=2 (documents=123): assay, detection, test, sample, pcr, 
        topic=8 level=2 (documents=167): protein, cell, bind, domain, membrane, 
        topic=12 level=2 (documents=56): sample, sequence, detect, isolate, species, 
        topic=15 level=2 (documents=278): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=52): sequence, genome, datum, gene, approach, 
        topic=35 level=2 (documents=10): und, die, der, des, sars, 
        topic=36 level=2 (documents=76): cell, apoptosis, pathway, signal, induce, 
        topic=40 level=2 (documents=16): protein, ex

.................................................. 450
topic=0 level=0 (documents=3888): disease, study, include, control, increase, 
    topic=1 level=1 (documents=1101): virus, infection, viral, human, respiratory, 
        topic=2 level=2 (documents=46): host, virus, species, bat, rna, 
        topic=3 level=2 (documents=168): patient, covid-19, clinical, pneumonia, respiratory, 
        topic=7 level=2 (documents=118): assay, detection, test, method, sample, 
        topic=8 level=2 (documents=164): protein, cell, bind, domain, membrane, 
        topic=12 level=2 (documents=64): sample, sequence, detect, isolate, species, 
        topic=15 level=2 (documents=282): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=50): sequence, genome, viral, rna, datum, 
        topic=35 level=2 (documents=16): und, die, der, des, ist, 
        topic=36 level=2 (documents=76): cell, apoptosis, pathway, signal, protein, 
        topic=40 level=2 (documents=16): protein,

.................................................. 500
topic=0 level=0 (documents=3888): disease, study, include, control, high, 
    topic=1 level=1 (documents=1082): virus, infection, viral, human, respiratory, 
        topic=2 level=2 (documents=53): host, virus, species, bat, evolution, 
        topic=3 level=2 (documents=168): patient, covid-19, pneumonia, clinical, respiratory, 
        topic=7 level=2 (documents=120): assay, detection, method, sample, test, 
        topic=8 level=2 (documents=168): protein, cell, bind, domain, membrane, 
        topic=12 level=2 (documents=58): sample, sequence, species, detect, isolate, 
        topic=15 level=2 (documents=273): health, public, outbreak, care, pandemic, 
        topic=28 level=2 (documents=48): sequence, genome, viral, datum, rna, 
        topic=35 level=2 (documents=11): und, die, der, des, sars, 
        topic=36 level=2 (documents=74): cell, apoptosis, pathway, signal, expression, 
        topic=40 level=2 (documents=17): pr

Sampling 10% of the papers yields a sub-dataset of 3,888 papers (the `documents` count of the root topic).
The number of topics at each level is determined by the nested Chinese Restaurant Process and is mainly governed by `gamma`, the nCRP concentration parameter, although `alpha` and `eta` also influence it indirectly.
Training the model on the 10% sample took roughly an hour per 50 iterations, so the 500 iterations took around ten hours.
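The role of `gamma` can be illustrated with a minimal sketch of the nested Chinese Restaurant Process prior on its own, in plain Python. Each document walks down the tree; at every node it follows an existing branch with probability proportional to the number of documents that already took it, or opens a new branch with weight `gamma`. The class and function names are illustrative assumptions, not the API of `risotto.hlda_sampler`, and the real Gibbs sampler additionally weighs each candidate path by how well its topics explain the document's words.

```python
import random


class Node:
    """A topic node in the nested CRP tree."""

    def __init__(self, level):
        self.level = level
        self.customers = 0  # documents whose path passes through this node
        self.children = []


def choose_child(node, gamma, rng):
    """CRP step: pick an existing child proportionally to its customers,
    or open a new child with weight `gamma`."""
    weights = [child.customers for child in node.children] + [gamma]
    idx = rng.choices(range(len(weights)), weights=weights)[0]
    if idx == len(node.children):  # the "new table"
        node.children.append(Node(node.level + 1))
    return node.children[idx]


def sample_path(root, num_levels, gamma, rng):
    """Draw a root-to-leaf path of length `num_levels` for one document."""
    path = [root]
    root.customers += 1
    while len(path) < num_levels:
        child = choose_child(path[-1], gamma, rng)
        child.customers += 1
        path.append(child)
    return path


def count_topics_per_level(root):
    """Count how many topic nodes exist at each level of the tree."""
    counts = {}
    stack = [root]
    while stack:
        node = stack.pop()
        counts[node.level] = counts.get(node.level, 0) + 1
        stack.extend(node.children)
    return counts


rng = random.Random(0)
root = Node(level=0)
for _ in range(1000):  # 1,000 hypothetical documents
    sample_path(root, num_levels=3, gamma=1.0, rng=rng)
# In expectation, a larger gamma yields more topics per level.
print(count_topics_per_level(root))
```

Because popular branches attract more documents ("rich get richer"), the tree stays fairly compact even with many documents, which matches the handful of first- and second-level topics in the output above.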

To avoid retraining the model, we'll pickle it so it can be loaded in later experiments.

In [None]:
import pickle

# Dump the model
with open("hlda.pkl", "wb") as dump_file:
    pickle.dump(hlda, dump_file)

Now, we'll load the dumped model.

In [None]:
import pickle

# Load the model
with open("hlda.pkl", "rb") as dump_file:
    hlda = pickle.load(dump_file)

## Manual Hierarchical LDA

In this section we'll attempt to build a hierarchical topic model manually.
First, we'll model topics with standard LDA, using the same number of topics that hLDA found at `level=1`.
Afterwards, for each group of documents assigned to a first-level topic, we'll run a new LDA topic modelling step.

In [None]:
from sklearn.feature_extraction.text import CountVectorizer

from risotto.lda import tokenizer


def get_lda_corpus(docs, max_vocab_size=2**13):
    count_vectorizer = CountVectorizer(
        tokenizer=tokenizer,
        lowercase=True,
        max_features=max_vocab_size,
    )
    vectorized_docs = count_vectorizer.fit_transform(docs)
    return vectorized_docs, count_vectorizer

vectorized_docs, count_vectorizer = get_lda_corpus(docs)

In [None]:
from sklearn.decomposition import LatentDirichletAllocation

lda = LatentDirichletAllocation(
    n_components=8,
    verbose=2,
    n_jobs=4,
)
lda = lda.fit(vectorized_docs)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   17.0s remaining:   17.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   18.1s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 1 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   13.7s remaining:   13.7s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   14.6s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 2 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   10.7s remaining:   10.7s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   12.0s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 3 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    9.3s remaining:    9.3s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   10.4s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 4 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    7.4s remaining:    7.4s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    8.5s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 5 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    7.2s remaining:    7.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    7.8s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 6 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    6.7s remaining:    6.7s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    7.3s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 7 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    6.2s remaining:    6.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    6.9s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 8 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    6.2s remaining:    6.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    6.9s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 9 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    6.1s remaining:    6.1s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    6.6s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


iteration: 10 of max_iter: 10


[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    5.6s remaining:    5.6s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    6.3s finished


The following cell will print the most relevant tokens of each modelled component.
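`print_topic_words` lives in `risotto.lda` and its source isn't shown here. A plausible minimal sketch of the idea, assuming it ranks each topic's words by weight: take each row of the topic-word weight matrix, sort its columns by decreasing weight, and map the best column indices back to vocabulary terms. The function name and toy data below are hypothetical.

```python
def top_words_per_topic(topic_word_weights, vocab, n_words):
    """Return the `n_words` highest-weighted words of each topic.

    `topic_word_weights` has one row per topic and one column per vocabulary
    entry, like the `components_` attribute of a fitted scikit-learn LDA.
    """
    top_words = []
    for weights in topic_word_weights:
        # Sort column indices by decreasing weight and keep the first n_words
        best = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
        top_words.append([vocab[i] for i in best[:n_words]])
    return top_words


example_vocab = ["bat", "host", "patient", "virus"]
example_weights = [[0.1, 2.0, 0.3, 5.0], [0.2, 0.1, 4.0, 1.5]]
for topic_idx, words in enumerate(top_words_per_topic(example_weights, example_vocab, 2)):
    print(f"Topic {topic_idx}: {' '.join(words)}")
# Topic 0: virus host
# Topic 1: patient virus
```

With the fitted model, the same helper could be called as `top_words_per_topic(lda.components_, count_vectorizer.get_feature_names_out(), 5)`, since both are indexable by position.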

In [None]:
from risotto.lda import print_topic_words

print_topic_words(lda, count_vectorizer, 5)

Topic 0: patient respiratory infection virus influenza
Topic 1: virus viral infection host human
Topic 2: de cat la en disease
Topic 3: drug activity disease development new
Topic 4: cell mouse infection protein expression
Topic 5: protein virus sequence vaccine gene
Topic 6: sample virus detection assay test
Topic 7: health disease outbreak model case


Now, we'll assign each paper to its dominant topic (the one with the highest probability) and group the papers by topic.

In [None]:
from collections import defaultdict

from scipy.sparse import vstack


def group_docs_by_topics(model, vectorized_docs):
    """Group the rows of `vectorized_docs` by their dominant topic under `model`."""
    # Topic distribution of each document, shape (n_docs, n_topics);
    # each document is assigned to its most probable topic.
    docs_classified = model.transform(vectorized_docs)
    docs_topics = docs_classified.argmax(1)
    clustered_docs = defaultdict(list)

    for vectorized_doc, topic_idx in zip(vectorized_docs, docs_topics):
        clustered_docs[topic_idx].append(vectorized_doc)

    # Stack each group's single-row sparse matrices into one matrix per topic
    stacked_clustered_docs = {}
    for topic_idx, docs_list in clustered_docs.items():
        stacked_clustered_docs[topic_idx] = vstack(docs_list)

    return stacked_clustered_docs


grouped_docs = group_docs_by_topics(lda, vectorized_docs)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    5.3s remaining:    5.3s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    6.0s finished


In [None]:
grouped_docs

{6: <3264x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 233098 stored elements in Compressed Sparse Row format>,
 1: <4462x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 293323 stored elements in Compressed Sparse Row format>,
 7: <9418x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 544812 stored elements in Compressed Sparse Row format>,
 0: <5166x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 350595 stored elements in Compressed Sparse Row format>,
 5: <4973x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 352836 stored elements in Compressed Sparse Row format>,
 4: <4877x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 355218 stored elements in Compressed Sparse Row format>,
 3: <4549x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 264729 stored elements in Compressed Sparse Row format>,
 2: <2173x8192 sparse matrix of type '<class 'numpy.int64'>'
 	with 86607 stored elements in Compressed Sparse Row format>}

Then, we'll fit a separate LDA model on each group of papers.

In [None]:
print_topic_words(lda, count_vectorizer, 5)

models = {}
for topic_idx, group_docs in progress_bar(grouped_docs.items()):
    print(f"Topic ID #{topic_idx}; documents = {group_docs.shape[0]}")
    
    models[topic_idx] = LatentDirichletAllocation(
        n_components=4,
        verbose=0,
        n_jobs=4,
    )
    models[topic_idx] = models[topic_idx].fit(group_docs)
    
    print_topic_words(models[topic_idx], count_vectorizer, 5)
    print("\n", end="")

Topic 0: patient respiratory infection virus influenza
Topic 1: virus viral infection host human
Topic 2: de cat la en disease
Topic 3: drug activity disease development new
Topic 4: cell mouse infection protein expression
Topic 5: protein virus sequence vaccine gene
Topic 6: sample virus detection assay test
Topic 7: health disease outbreak model case


Topic ID #6; documents = 3264
Topic 0: assay detection virus sample test
Topic 1: calve diarrhea group pedv pig
Topic 2: virus air study particle result
Topic 3: virus sample disease infection study

Topic ID #1; documents = 4462
Topic 0: viral protein response virus infection
Topic 1: virus cell protein viral replication
Topic 2: coronavirus respiratory mers-cov infection vaccine
Topic 3: virus human bat disease host

Topic ID #7; documents = 9418
Topic 0: patient infection care hospital respiratory
Topic 1: model case epidemic datum covid-19
Topic 2: virus disease human infection transmission
Topic 3: health disease public infectious system

Topic ID #0; documents = 5166
Topic 0: influenza virus respiratory test infection
Topic 1: patient infection respiratory syndrome acute
Topic 2: respiratory infection virus child viral
Topic 3: patient study group clinical covid-19

Topic ID #5; documents = 4973
Topic 0: protein antibody bind epitope sars-cov
Topic 1: sequence gene virus genome a