In [1]:
import random
from collections import defaultdict

import numpy as np
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from sklearn.datasets import fetch_20newsgroups

from hlda_utils import *
from hlda_final import HLDA_Node, HierarchicalLDA

In [2]:
# Fetch the 20 Newsgroups training split with headers, footers and quoted
# replies stripped. shuffle=True / random_state=42 are sklearn's defaults;
# they are spelled out because "the first N_DOCS documents" below depends on
# that fixed shuffle order.
newsgroups_data = fetch_20newsgroups(
    subset='train',
    remove=('headers', 'footers', 'quotes'),
    shuffle=True,
    random_state=42,
)

# Keep only the first N_DOCS documents (and their matching labels).
N_DOCS = 2000
raw_docs = newsgroups_data.data[:N_DOCS]
labels = newsgroups_data.target[:N_DOCS]
target_names = newsgroups_data.target_names
assert len(raw_docs) == len(labels), "documents and labels are misaligned"

# Print dataset information
print(f"Number of documents: {len(raw_docs)}")
print(f"Number of categories: {len(target_names)}")

Number of documents: 2000
Number of categories: 20


In [3]:
# Apply the preprocessing pipeline (stemming + lemmatization, tokens shorter
# than 2 chars dropped, vocabulary limited to words with frequency >= 5).
# Per its printed summary, it also drops documents that end up empty, which
# is why filtered labels are returned alongside the filtered documents.
filtered_docs, filtered_labels, vocab, word2idx, idx2word, corpus = full_preprocessing_pipeline(
    raw_docs,
    labels,
    stop_words=None,
    stemmer=PorterStemmer(),
    lemmatizer=WordNetLemmatizer(),
    min_word_length=2,
    min_freq=5,
)
# Documents and labels must still line up after empty documents are removed.
assert len(filtered_docs) == len(filtered_labels), "filtered docs/labels misaligned"

# Checking document lengths (in tokens) after filtering
doc_lengths = np.array([len(doc) for doc in filtered_docs])
print()
print("After filtering docs:")
print(f"  Avg length: {doc_lengths.mean():.2f} tokens")
print(f"  Median length: {np.median(doc_lengths)}")
print(f"  Min length: {doc_lengths.min()}")
print(f"  Max length: {doc_lengths.max()}")

Number of documents after filtering empty ones: 1940
Number of labels after filtering: 1940

Vocabulary size (words with freq >= 5): 4805
Sample vocabulary words: ['aa', 'aaa', 'ab', 'abbrevi', 'abc', 'abil', 'abl', 'abolish', 'abort', 'abraham']
Preprecessing done

After filtering docs:
  Avg length: 90.41 tokens
  Median length: 42.0
  Min length: 1
  Max length: 4780


In [4]:
# Fix the RNG state so this sampling run is reproducible and comparable with
# the other configurations below. NOTE(review): which generator
# HierarchicalLDA draws from is not visible here, so both NumPy's global RNG
# and the stdlib `random` module are seeded.
np.random.seed(0)
random.seed(0)

# Initialize the HLDA model: 5-level tree, gamma=1, eta=0.01, alpha=5.
hlda_model = HierarchicalLDA(
    corpus=corpus,
    vocabulary=vocab,
    gamma=1,
    eta=0.01,
    alpha=5,
    levels=5,
)

# Run 2000 Gibbs sampling iterations, printing the topic tree (8 words per
# topic, with counts) every 400 iterations.
hlda_model.gibbs_sampling(
    iterations=2000,
    topic_display_interval=400,
    top_n_words=8,
    show_word_counts=True,
)

Starting Hierarchical LDA sampling

10% done
20% done
*********************The 1 result**************************
topic=0 level=0 (docs=1940): one (442), mani (295), may (277), peopl (264), two (259), also (254), call (252), u (242)
    topic=1 level=1 (docs=790): would (380), know (235), one (232), think (212), go (211), like (203), say (200), could (195)
        topic=2 level=2 (docs=186): use (90), get (51), new (36), work (36), also (30), problem (27), set (25), system (25)
            topic=3 level=3 (docs=29): isbn (12), includ (11), power (11), acceler (9), book (8), technic (8), planetari (6), explor (6)
                topic=188 level=4 (docs=5): god (21), question (11), truth (9), say (8), human (7), argument (5), rule (5), deni (4)
                topic=265 level=4 (docs=5): space (34), propuls (27), satellit (18), lunar (18), rocket (16), design (15), fusion (14), refer (13)
                topic=303 level=4 (docs=6): god (13), christian (11), sabbath (6), work (5), serb (4

In [5]:
# Print the final topic tree: the output lists the top 8 words per topic with
# their counts. The positional args presumably correspond to gibbs_sampling's
# top_n_words / show_word_counts — TODO confirm against hlda_final.
hlda_model.exhibit_nodes(8,True)

topic=0 level=0 (docs=1940): also (293), two (291), may (290), mani (272), u (225), one (223), ask (215), first (214)
    topic=1 level=1 (docs=778): would (384), peopl (266), know (229), go (219), one (203), say (190), like (189), could (189)
        topic=2 level=2 (docs=158): use (78), get (50), work (45), thank (25), time (25), need (25), would (24), new (21)
            topic=3 level=3 (docs=21): say (7), rule (7), go (7), think (6), rumor (4), troubl (4), earth (3), genocid (3)
                topic=188 level=4 (docs=4): god (21), question (12), truth (10), argument (6), human (5), natur (4), proof (4), contradict (4)
                topic=303 level=4 (docs=5): god (12), christian (7), jesu (4), ethnic (4), clean (4), serb (4), convert (3), follow (3)
                topic=449 level=4 (docs=8): flyer (6), pick (6), cap (5), mike (5), trade (5), daigl (4), montreal (4), name (4)
                topic=5323 level=4 (docs=4): right (6), moon (4), mine (4), un (4), nation (4), whatev 

In [4]:
## Trying a smaller tree which is harder to branch out: 3 levels instead of 5
## and a much smaller gamma (0.05 vs 1), with alpha raised to 10.

# Same fixed seed as the first model so the configurations are comparable.
# NOTE(review): seeds both NumPy's global RNG and `random`, since the
# generator HierarchicalLDA uses is not visible here.
np.random.seed(0)
random.seed(0)

hlda_model2 = HierarchicalLDA(
    corpus=corpus,
    vocabulary=vocab,
    gamma=0.05,
    eta=0.01,
    alpha=10,
    levels=3,
)

# 2000 iterations, printing the tree (5 words per topic, with counts) every 400.
hlda_model2.gibbs_sampling(
    iterations=2000,
    topic_display_interval=400,
    top_n_words=5,
    show_word_counts=True,
)

Starting Hierarchical LDA sampling

10% done
20% done
*********************The 1 result**************************
topic=0 level=0 (docs=1940): also (395), mani (316), may (296), first (277), u (244)
    topic=1 level=1 (docs=430): would (291), one (255), get (202), know (200), like (191)
        topic=2 level=2 (docs=41): armenian (135), peopl (73), turkish (59), kill (40), russian (39)
        topic=3 level=2 (docs=7): armenian (87), u (72), peopl (61), said (54), say (44)
        topic=7 level=2 (docs=3): father (44), son (36), spirit (32), holi (22), creed (18)
        topic=10 level=2 (docs=4): shuttl (22), roll (11), launch (10), maneuv (9), space (9)
        topic=16 level=2 (docs=3): day (21), oxal (21), stone (20), kidney (18), vitamin (18)
        topic=17 level=2 (docs=18): tape (63), use (54), drive (48), driver (44), adaptec (41)
        topic=18 level=2 (docs=2): max (651), giz (141), bhj (123), ql (63), wm (63)
        topic=29 level=2 (docs=21): run (17), server (16), pr

In [7]:
# Print the final tree of hlda_model2 with the top 8 words per topic and their
# counts. The positional args presumably correspond to top_n_words /
# show_word_counts — TODO confirm against hlda_final.
hlda_model2.exhibit_nodes(8,True)

topic=0 level=0 (docs=1940): also (355), mani (306), year (261), may (259), first (252), seem (249), u (247), sinc (232)
    topic=1 level=1 (docs=332): would (200), one (183), like (163), use (161), get (145), know (127), time (110), well (110)
        topic=2 level=2 (docs=14): armenian (138), peopl (40), turkish (38), russian (37), genocid (33), armi (25), muslim (25), massacr (22)
        topic=3 level=2 (docs=7): armenian (87), said (71), u (70), peopl (68), say (64), one (49), start (44), apart (44)
        topic=7 level=2 (docs=3): father (44), son (36), spirit (32), holi (22), creed (18), proce (14), god (8), proceed (8)
        topic=17 level=2 (docs=8): tape (63), drive (46), adaptec (41), use (39), driver (39), problem (33), window (31), disk (25)
        topic=18 level=2 (docs=2): max (651), giz (141), bhj (123), ql (63), wm (63), ax (45), bj (39), gk (39)
        topic=29 level=2 (docs=11): server (11), program (11), request (10), segment (9), error (9), ncd (7), crash (7)

In [5]:
# 3-level tree like hlda_model2 (alpha=10) but with an even smaller
# gamma (0.01), making new branches less likely still.

# Same fixed seed as the other configurations so they are comparable.
# NOTE(review): seeds both NumPy's global RNG and `random`, since the
# generator HierarchicalLDA uses is not visible here.
np.random.seed(0)
random.seed(0)

hlda_model3 = HierarchicalLDA(
    corpus=corpus,
    vocabulary=vocab,
    gamma=0.01,
    eta=0.01,
    alpha=10,
    levels=3,
)

# 2000 iterations, printing the tree (5 words per topic, with counts) every 400.
hlda_model3.gibbs_sampling(
    iterations=2000,
    topic_display_interval=400,
    top_n_words=5,
    show_word_counts=True,
)

Starting Hierarchical LDA sampling

10% done
20% done
*********************The 1 result**************************
topic=0 level=0 (docs=1940): two (292), point (276), new (271), first (251), also (249)
    topic=1 level=1 (docs=471): one (276), would (237), use (223), know (174), get (170)
        topic=2 level=2 (docs=28): armenian (205), turkish (106), genocid (82), peopl (75), muslim (51)
        topic=3 level=2 (docs=3): period (42), pp (37), play (34), power (33), scorer (24)
        topic=4 level=2 (docs=2): mv (58), ah (48), sq (35), q (32), zv (31)
        topic=6 level=2 (docs=24): avail (85), widget (63), version (58), includ (50), server (48)
        topic=10 level=2 (docs=6): doug (25), undefin (12), symbol (12), motif (10), librari (6)
        topic=21 level=2 (docs=2): max (651), giz (141), bhj (123), ql (63), wm (63)
        topic=22 level=2 (docs=9): brave (7), team (7), player (7), whatev (6), houston (5)
        topic=27 level=2 (docs=11): tape (68), drive (46), adapt

In [8]:
# Print the final tree of hlda_model3 with the top 8 words per topic and their
# counts. The positional args presumably correspond to top_n_words /
# show_word_counts — TODO confirm against hlda_final.
hlda_model3.exhibit_nodes(8,True)

topic=0 level=0 (docs=1940): first (306), year (291), new (279), may (270), right (266), point (261), also (244), post (233)
    topic=1 level=1 (docs=415): one (223), would (182), know (155), use (154), get (141), like (135), time (115), make (109)
        topic=2 level=2 (docs=22): armenian (196), turkish (83), genocid (70), peopl (53), muslim (40), govern (39), armenia (39), russian (37)
        topic=3 level=2 (docs=4): period (42), pp (37), play (34), power (33), scorer (24), pt (23), philadelphia (15), calgari (15)
        topic=4 level=2 (docs=2): mv (58), ah (48), sq (35), q (32), zv (31), hz (30), ri (29), xte (24)
        topic=6 level=2 (docs=31): avail (87), widget (66), version (59), includ (57), support (48), server (47), motif (43), use (41)
        topic=21 level=2 (docs=2): max (651), giz (141), bhj (123), wm (63), ql (63), ax (45), gk (39), bj (39)
        topic=22 level=2 (docs=14): year (16), brave (15), last (7), pitch (7), lopez (7), atlanta (6), hit (6), young (6

In [6]:
# 3-level tree with the same gamma as hlda_model2 (0.05) but alpha lowered
# back to 5, isolating the effect of alpha.

# Same fixed seed as the other configurations so they are comparable.
# NOTE(review): seeds both NumPy's global RNG and `random`, since the
# generator HierarchicalLDA uses is not visible here.
np.random.seed(0)
random.seed(0)

hlda_model4 = HierarchicalLDA(
    corpus=corpus,
    vocabulary=vocab,
    gamma=0.05,
    eta=0.01,
    alpha=5,
    levels=3,
)

# 2000 iterations, printing the tree (5 words per topic, with counts) every 400.
hlda_model4.gibbs_sampling(
    iterations=2000,
    topic_display_interval=400,
    top_n_words=5,
    show_word_counts=True,
)

Starting Hierarchical LDA sampling

10% done
20% done
*********************The 1 result**************************
topic=0 level=0 (docs=1940): also (373), first (317), may (285), make (284), time (280)
    topic=1 level=1 (docs=267): would (182), know (147), get (135), one (128), like (122)
        topic=2 level=2 (docs=31): armenian (211), peopl (114), said (72), u (71), one (64)
        topic=53 level=2 (docs=10): armenian (71), turkish (49), genocid (47), peopl (23), govern (22)
        topic=55 level=2 (docs=16): offer (21), ship (11), best (10), use (10), plea (9)
        topic=61 level=2 (docs=4): air (15), exhaust (11), pressur (8), intak (8), system (7)
        topic=63 level=2 (docs=9): game (29), disk (12), sne (11), save (6), trade (5)
        topic=66 level=2 (docs=12): run (10), compil (9), crash (7), gcc (6), server (6)
        topic=82 level=2 (docs=8): secur (12), inform (7), director (6), system (6), list (5)
        topic=87 level=2 (docs=11): jerk (14), wheeli (10), 

In [9]:
# Print the final tree of hlda_model4 with the top 8 words per topic and their
# counts. The positional args presumably correspond to top_n_words /
# show_word_counts — TODO confirm against hlda_final.
hlda_model4.exhibit_nodes(8,True)

topic=0 level=0 (docs=1940): one (450), also (425), peopl (363), time (345), first (311), mani (293), may (279), point (278)
    topic=1 level=1 (docs=231): would (145), get (116), like (97), want (84), one (83), know (80), think (73), work (64)
        topic=2 level=2 (docs=16): armenian (113), peopl (82), u (76), said (72), say (67), one (55), woman (53), went (49)
        topic=53 level=2 (docs=7): armenian (57), turkish (48), genocid (43), govern (22), peopl (21), armenia (17), muslim (16), million (14)
        topic=55 level=2 (docs=17): offer (17), best (13), ship (11), sale (8), condit (8), includ (7), speaker (7), paid (6)
        topic=61 level=2 (docs=3): air (15), exhaust (11), intak (8), pressur (8), flow (6), system (5), aftermarket (5), vacuum (4)
        topic=63 level=2 (docs=4): game (21), disk (12), sne (10), save (6), floppi (5), copi (4), play (4), super (4)
        topic=66 level=2 (docs=17): use (15), compil (13), problem (11), run (11), motif (10), librari (7), g

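## Synthetic data generation

A small hLDA model is fit on the first 100 preprocessed documents, and its learned tree is then passed to `generate_synthetic_corpus` to sample synthetic documents.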

In [4]:
# Fit a small model on the first N_SOURCE_DOCS documents; hlda_model5 is the
# model handed to generate_synthetic_corpus in the next cell.
# NOTE(review): this rebinds `corpus` itself, so re-running any earlier model
# cell afterwards silently trains on the subset. Kept as-is in case later
# cells rely on the truncated `corpus`; prefer a separate name if they don't.
N_SOURCE_DOCS = 100
corpus = corpus[:N_SOURCE_DOCS]

# Fixed seed so the learned tree (and hence the synthetic corpus) is
# reproducible. NOTE(review): seeds both NumPy's global RNG and `random`,
# since the generator HierarchicalLDA uses is not visible here.
np.random.seed(0)
random.seed(0)

# Initialize the HLDA model: 5-level tree, gamma=1, eta=0.01, alpha=5.
hlda_model5 = HierarchicalLDA(
    corpus=corpus,
    vocabulary=vocab,
    gamma=1,
    eta=0.01,
    alpha=5,
    levels=5,
)
# 1000 iterations; display interval == iterations, so the tree is printed once
# at the end (8 words per topic, with counts).
hlda_model5.gibbs_sampling(
    iterations=1000,
    topic_display_interval=1000,
    top_n_words=8,
    show_word_counts=True,
)

Starting Hierarchical LDA sampling

10% done
20% done
30% done
40% done
50% done
60% done
70% done
80% done
90% done
100% done
*********************The 1 result**************************
topic=0 level=0 (docs=100): one (60), know (33), think (32), get (30), time (27), problem (23), mani (22), go (20)
    topic=1 level=1 (docs=29): use (24), would (19), system (11), could (9), run (7), control (6), anyon (5), data (5)
        topic=2 level=2 (docs=29): like (12), also (9), two (9), thank (8), first (8), well (7), year (7), peopl (7)
            topic=17 level=3 (docs=16): present (4), white (3), list (3), trajectori (3), network (3), repli (3), given (3), test (3)
                topic=2485 level=4 (docs=4): encrypt (4), class (4), weapon (4), enforc (3), strong (3), drug (3), restrict (3), clipper (3)
                topic=2513 level=4 (docs=4): request (6), fail (6), share (6), ncd (4), segment (4), record (3), memori (3), program (3)
                topic=4388 level=4 (docs=4): start

In [5]:
# Generate 100 synthetic documents of 250 requested tokens each from
# hlda_model5, with a fixed seed.
# NOTE(review): `alpha_levels` is defined but NOT passed to
# generate_synthetic_corpus, so it has no effect on the generated corpus.
# Wire it through if the helper accepts such a parameter. It is kept defined
# in case later cells read it.
num_docs = 100
doc_length = 250
alpha_levels = 1.0
seed = 0

synthetic_docs, doc_paths = generate_synthetic_corpus(
    hlda_model=hlda_model5,
    num_docs=num_docs,
    doc_length=doc_length,
    seed=seed,
)
# doc_paths is indexed in parallel with synthetic_docs below; check they agree.
assert len(synthetic_docs) == len(doc_paths) == num_docs, "unexpected synthetic corpus size"

print(f"Generated {len(synthetic_docs)} synthetic documents.\n")

Generated 100 synthetic documents.



In [6]:
# Display a few examples
for i, doc in enumerate(synthetic_docs[:3]):
    print(f"--- Synthetic Doc {i} ---")
    print(f"Path node indices: {doc_paths[i]}")
    print(f"First 20 tokens: {doc[:20]} ...\n")

--- Synthetic Doc 0 ---
Path node indices: [0, 9, 10, 3902, 4440]
First 20 tokens: [np.str_('absolut'), np.str_('thing'), np.str_('anoth'), np.str_('rose'), np.str_('differ'), np.str_('thing'), np.str_('note'), np.str_('glass'), np.str_('lead'), np.str_('anoth'), np.str_('chip'), np.str_('put'), np.str_('precis'), np.str_('person'), np.str_('interest'), np.str_('wrong'), np.str_('without'), np.str_('rang'), np.str_('probabl'), np.str_('rather')] ...

--- Synthetic Doc 1 ---
Path node indices: [0, 1, 2, 3267, 3341]
First 20 tokens: [np.str_('take'), np.str_('constant'), np.str_('reason'), np.str_('speed'), np.str_('without'), np.str_('street'), np.str_('lead'), np.str_('without'), np.str_('wide'), np.str_('month'), np.str_('see'), np.str_('tri'), np.str_('mani'), np.str_('thing'), np.str_('reason'), np.str_('get'), np.str_('sen'), np.str_('one'), np.str_('ca'), np.str_('know')] ...

--- Synthetic Doc 2 ---
Path node indices: [0, 9, 10, 11, 4082]
First 20 tokens: [np.str_('cours'), np.st