In [1]:
import numpy as np
import pandas as pd
from hlda_utils import *
from hlda_final import *

from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer

In [2]:
###########################################
# LOAD & PREPROCESS 20 NEWSGROUPS
###########################################

dataset = fetch_20newsgroups(subset='train', remove=('headers','footers','quotes'))
all_docs = dataset.data
all_labels = dataset.target

print("Total raw documents in 20newsgroups (train):", len(all_docs))

_, _, vocab, word2idx, idx2word, corpus_indices = \
    full_preprocessing_pipeline(all_docs, all_labels, 
                                min_word_length=2, 
                                min_freq=100)

print(f"Vocabulary size after min_freq=100: {len(vocab)}")

Total raw documents in 20newsgroups (train): 11314
Documents after filtering: 10999
Vocabulary size: 1998
Preprocessing complete.
Vocabulary size after min_freq=100: 1998


In [3]:
###########################################
# 1) SELECT 500 SUITABLY-LONG DOCS
###########################################

selected_corpus = []
selected_count = 0
for doc in corpus_indices:
    if len(doc) >= 100:
        selected_corpus.append(doc)
        selected_count += 1
        if selected_count == 500:
            break

print(f"Number of selected documents: {len(selected_corpus)} (each >= 100 tokens).")


Number of selected documents: 500 (each >= 100 tokens).


## Synthetic data test

In [4]:
ground_truth_model = hLDA(
    corpus=selected_corpus,
    vocabulary=vocab,
    L=3,
    alpha=10.0,
    gamma=1.0,
    eta=0.1,  
    seed=42,
    verbose=False
)

time = ground_truth_model.gibbs_sampling(
    iterations=5000,
    display_interval=5000,
    top_n=5
)

Starting hierarchical LDA sampling...
Iteration 5000
Topic Node 0 (level=0, docs=500): to(2214), it(1679), is(1625), you(1553), that(1310)
    Topic Node 1 (level=1, docs=260): the(3467), of(1595), in(1084), and(1073), to(952)
        Topic Node 2 (level=2, docs=195): the(1943), of(1230), and(769), that(746), to(737)
        Topic Node 6 (level=2, docs=3): period(42), pp(37), play(30), power(29), pt(27)
        Topic Node 11 (level=2, docs=30): the(205), of(54), is(50), chip(49), and(49)
        Topic Node 14 (level=2, docs=10): the(85), of(85), and(79), turkish(31), presid(21)
        Topic Node 72 (level=2, docs=6): of(48), and(46), for(36), in(29), sequenc(24)
        Topic Node 286 (level=2, docs=13): the(26), run(22), program(21), have(18), window(18)
        Topic Node 8262 (level=2, docs=3): may(20), ma(16), june(13), april(9), me(9)
    Topic Node 3 (level=1, docs=166): the(2282), of(1048), and(644), in(573), to(550)
        Topic Node 4 (level=2, docs=82): the(973), and(690), 

In [8]:
doc_lengths = [len(doc) for doc in selected_corpus]
avg_length = int(np.mean(doc_lengths))
print(f"Average document length in selected corpus = {avg_length} tokens")

# Generate 100 synthetic documents using the ground truth model
num_synth_docs = 500
synth_corpus = ground_truth_model.generate_synthetic_corpus(
    num_docs=num_synth_docs,
    words_per_doc=avg_length,
    alpha_dir=10.0,   # or any other Dirichlet parameter you prefer
    rng=None          # uses the model's default RNG
)
print(f"Generated {len(synth_corpus)} synthetic documents, each ~{avg_length} words")


Average document length in selected corpus = 321 tokens
Generated 500 synthetic documents, each ~321 words


In [9]:
synthetic_model = hLDA(
    corpus=synth_corpus,
    vocabulary=vocab,
    L=3,
    alpha=10.0,
    gamma=1.0,
    eta=0.1,
    seed=42,
    verbose=False
)

time_synth = synthetic_model.gibbs_sampling(
    iterations=5000,
    display_interval=5000,
    top_n=5,
    show=True
)
print(f"Synthetic-data model trained in ~{time_synth} minutes.\n")

Starting hierarchical LDA sampling...
Iteration 5000
Topic Node 0 (level=0, docs=500): the(3603), to(2635), and(1782), is(1695), it(1647)
    Topic Node 1 (level=1, docs=433): the(3506), of(2216), and(1355), in(1268), to(1134)
        Topic Node 2 (level=2, docs=41): vote(123), your(73), will(65), be(65), the(58)
        Topic Node 3 (level=2, docs=25): the(156), of(130), and(115), turkish(51), presid(37)
        Topic Node 6 (level=2, docs=331): the(2607), is(1198), of(905), to(821), that(615)
        Topic Node 9 (level=2, docs=19): period(137), pp(111), pt(94), play(92), power(82)
        Topic Node 13 (level=2, docs=11): none(41), kill(18), of(18), were(17), attack(15)
        Topic Node 22 (level=2, docs=6): lo(21), lead(16), cub(15), pitcher(14), stl(13)
    Topic Node 4 (level=1, docs=8): ad(8), two(7), pressur(7), were(5), issu(5)
        Topic Node 23 (level=2, docs=8): the(83), good(53), veri(48), excel(47), cover(26)
    Topic Node 7 (level=1, docs=18): the(219), of(91), in(

In [10]:
print("\n--- Ground Truth Model Structure (brief) ---")
ground_truth_model.exhibit_topics(top_n=5, structure=True)

print("\n--- New Synthetic Model Structure (brief) ---")
synthetic_model.exhibit_topics(top_n=5, structure=True)


--- Ground Truth Model Structure (brief) ---
Topic Node 0 (level=0, children=4)
  Topic Node 1 (level=1, children=7)
    Topic Node 2 (level=2, children=0)
    Topic Node 6 (level=2, children=0)
    Topic Node 11 (level=2, children=0)
    Topic Node 14 (level=2, children=0)
    Topic Node 72 (level=2, children=0)
    Topic Node 286 (level=2, children=0)
    Topic Node 8262 (level=2, children=0)
  Topic Node 3 (level=1, children=8)
    Topic Node 4 (level=2, children=0)
    Topic Node 5 (level=2, children=0)
    Topic Node 15 (level=2, children=0)
    Topic Node 51 (level=2, children=0)
    Topic Node 1079 (level=2, children=0)
    Topic Node 6676 (level=2, children=0)
    Topic Node 8787 (level=2, children=0)
    Topic Node 8859 (level=2, children=0)
  Topic Node 7 (level=1, children=4)
    Topic Node 8 (level=2, children=0)
    Topic Node 9 (level=2, children=0)
    Topic Node 201 (level=2, children=0)
    Topic Node 8864 (level=2, children=0)
  Topic Node 12 (level=1, children=8)
  

In [11]:
doc_lengths = [len(doc) for doc in selected_corpus]
avg_length = int(np.mean(doc_lengths))
print(f"Average document length in selected corpus = {avg_length} tokens")

# Generate 100 synthetic documents using the ground truth model
num_synth_docs = 100
synth_corpus1 = ground_truth_model.generate_synthetic_corpus(
    num_docs=num_synth_docs,
    words_per_doc=avg_length,
    alpha_dir=10.0,   # or any other Dirichlet parameter you prefer
    rng=None          # uses the model's default RNG
)
print(f"Generated {len(synth_corpus)} synthetic documents, each ~{avg_length} words")


Average document length in selected corpus = 321 tokens
Generated 500 synthetic documents, each ~321 words


In [12]:
synthetic_model1 = hLDA(
    corpus=synth_corpus,
    vocabulary=vocab,
    L=3,
    alpha=10.0,
    gamma=1.0,
    eta=0.1,
    seed=42,
    verbose=False
)

time_synth1 = synthetic_model.gibbs_sampling(
    iterations=5000,
    display_interval=5000,
    top_n=5,
    show=True
)
print(f"Synthetic-data model trained in ~{time_synth} minutes.\n")

Starting hierarchical LDA sampling...
Iteration 5000
Topic Node 0 (level=0, docs=500): the(3786), to(2604), is(1768), and(1537), for(1497)
    Topic Node 1 (level=1, docs=430): the(3460), of(2385), and(1558), in(1429), to(1104)
        Topic Node 2 (level=2, docs=33): vote(124), will(68), be(61), your(59), the(51)
        Topic Node 3 (level=2, docs=24): of(143), and(108), the(104), turkish(49), presid(41)
        Topic Node 6 (level=2, docs=335): the(2518), is(1131), of(943), to(827), it(559)
        Topic Node 9 (level=2, docs=19): period(137), pp(111), pt(93), play(90), power(77)
        Topic Node 13 (level=2, docs=15): of(65), and(60), none(44), were(21), kill(20)
        Topic Node 22 (level=2, docs=4): lo(21), lead(15), cub(15), pitcher(14), stl(13)
    Topic Node 4 (level=1, docs=7): orient(6), two(5), howev(5), edit(5), is(4)
        Topic Node 23 (level=2, docs=7): the(58), good(49), excel(48), veri(45), of(33)
    Topic Node 7 (level=1, docs=16): the(182), of(94), and(62), i

In [13]:
print("\n--- Ground Truth Model Structure (brief) ---")
ground_truth_model.exhibit_topics(top_n=5, structure=True)

print("\n--- New Synthetic Model Structure (brief) ---")
synthetic_model1.exhibit_topics(top_n=5, structure=True)


--- Ground Truth Model Structure (brief) ---
Topic Node 0 (level=0, children=4)
  Topic Node 1 (level=1, children=7)
    Topic Node 2 (level=2, children=0)
    Topic Node 6 (level=2, children=0)
    Topic Node 11 (level=2, children=0)
    Topic Node 14 (level=2, children=0)
    Topic Node 72 (level=2, children=0)
    Topic Node 286 (level=2, children=0)
    Topic Node 8262 (level=2, children=0)
  Topic Node 3 (level=1, children=8)
    Topic Node 4 (level=2, children=0)
    Topic Node 5 (level=2, children=0)
    Topic Node 15 (level=2, children=0)
    Topic Node 51 (level=2, children=0)
    Topic Node 1079 (level=2, children=0)
    Topic Node 6676 (level=2, children=0)
    Topic Node 8787 (level=2, children=0)
    Topic Node 8859 (level=2, children=0)
  Topic Node 7 (level=1, children=4)
    Topic Node 8 (level=2, children=0)
    Topic Node 9 (level=2, children=0)
    Topic Node 201 (level=2, children=0)
    Topic Node 8864 (level=2, children=0)
  Topic Node 12 (level=1, children=8)
  

| **Model**       | **Total Nodes** | **Nodes at Level 2** | **Nodes at Level 3** | **No. of Documents** |
|-----------------|-----------------|----------------------|----------------------|----------------------|
| "Ground Truth"    | 32              | 4                    | 27                   | 500                  |
| Synthetic 1     | 18              | 5                    | 12                   | 500                  |
| Synthetic 2     | 30              | 6                    | 22                   | 100                  |

## Toy example

In [3]:
vocab = ["cat","dog","mouse","pizza","music","food","film","code"]
docs = [
    [0,0,1,1,6,7],  # "cat cat dog dog film code"
    [3,3,4,4,5,2],  # "pizza pizza music music food mouse"
    [0,3,4,1,7,7],  # mixture of cat,pizza,music,dog,code...
]

In [4]:
model = hLDA(
    corpus=docs,
    vocabulary=vocab,
    L=2,      # root + 1 child level
    alpha=5.0,
    gamma=1.0,
    eta=0.1,
    seed=42,
)

model.gibbs_sampling(iterations=20, display_interval=10, top_n=3, show=True)

Starting hierarchical LDA sampling...
Iteration 10
Topic Node 0 (level=0, docs=3): code(3), dog(3), mouse(1)
    Topic Node 4 (level=1, docs=2): music(3), pizza(3), food(1)
    Topic Node 5 (level=1, docs=1): cat(2), film(1), code(0)
Iteration 20
Topic Node 0 (level=0, docs=3): music(3), pizza(3), dog(3)
    Topic Node 4 (level=1, docs=3): code(3), cat(3), mouse(1)
Total topic nodes created = 2


0.0

In [5]:
print("\n--- Tree Structure (structure=True) ---")
model.exhibit_topics(structure=True)


--- Tree Structure (structure=True) ---
Topic Node 0 (level=0, children=1)
  Topic Node 4 (level=1, children=0)


In [6]:
syn_corpus = model.generate_synthetic_corpus(
    num_docs=3,          # create 3 new docs
    words_per_doc=6,     # each doc has 6 words
    alpha_dir=3.0,       # Dirichlet parameter for doc-level mixing
    rng=None             # uses model's default random state
)
print("\n--- Synthetic Docs ---")
for i, sdoc in enumerate(syn_corpus):
    text_tokens = [vocab[w] for w in sdoc]
    print(f"Doc {i}: {text_tokens}")


--- Synthetic Docs ---
Doc 0: ['cat', 'film', 'pizza', 'code', 'code', 'dog']
Doc 1: ['pizza', 'pizza', 'dog', 'music', 'food', 'cat']
Doc 2: ['film', 'food', 'code', 'music', 'mouse', 'music']
