In [14]:
import random

import numpy as np
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sklearn.datasets import fetch_20newsgroups

from hlda import nCRPTree, Node
from utils import *

In [15]:
# Step 1: Fetch the 20 Newsgroups dataset (subset='all' = train + test).
# Headers, signature footers and quoted replies are stripped so each document
# is only the poster's own text. shuffle/random_state are sklearn's defaults,
# written out because they determine which posts land in the small subset
# taken further down.
newsgroups = fetch_20newsgroups(
    subset='all',
    remove=('headers', 'footers', 'quotes'),
    shuffle=True,
    random_state=42,
)
raw_corpus = newsgroups.data

# Step 2: Define a preprocessing pipeline
def preprocess_corpus(documents, stop_words, tokenizer=None):
    """
    Turn raw documents into lists of content tokens.

    Each document is lowercased, tokenized, and filtered so that only purely
    alphabetic tokens not contained in ``stop_words`` remain. A document that
    ends up with no tokens yields an empty list, so the output stays aligned
    index-for-index with ``documents``.

    Parameters
    ----------
    documents : iterable of str
        Raw document texts.
    stop_words : container of str
        Words to drop. Tokens are lowercased before the membership test, so
        entries must be lowercase to match; a set gives O(1) lookups.
    tokenizer : callable, optional
        Maps a string to a list of token strings. Defaults to NLTK's
        ``word_tokenize``; pass e.g. ``str.split`` for a lightweight,
        NLTK-free tokenizer.

    Returns
    -------
    list of list of str
        One token list per input document, in input order.
    """
    if tokenizer is None:
        # Resolved at call time so the default still tracks the module-level
        # NLTK import.
        tokenizer = word_tokenize
    return [
        [
            token
            for token in tokenizer(doc.lower())
            # isalpha() drops numbers, punctuation and mixed tokens like "e-mail".
            if token.isalpha() and token not in stop_words
        ]
        for doc in documents
    ]

# Install the NLTK data this pipeline needs if it is missing, so the notebook
# survives Restart & Run All on a fresh machine: stopwords.words() needs
# 'stopwords', and word_tokenize needs 'punkt' or 'punkt_tab' (which one
# depends on the installed NLTK version).
for _nltk_path, _nltk_package in (
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
):
    try:
        nltk.data.find(_nltk_path)
    except LookupError:
        nltk.download(_nltk_package, quiet=True)

# Create a set of English stopwords (set → O(1) membership tests)
stop_words = set(stopwords.words('english'))

# Preprocess the corpus.
# NOTE(review): this tokenizes every post in the 'all' split even though only
# `subset_size` documents are used below; if `preprocessed_corpus` is not
# needed elsewhere, preprocessing a slice of `raw_corpus` is much faster.
preprocessed_corpus = preprocess_corpus(raw_corpus, stop_words)

# We have a very large corpus here, so use a small subset for demonstration.
# Posts can be left with no tokens (e.g. nothing remains after stripping
# headers/footers/quotes, or only stopwords were left). They give the topic
# model no words to learn from, so skip them when picking the subset.
subset_size = 50
subset_corpus = [doc for doc in preprocessed_corpus if doc][:subset_size]
assert subset_corpus, "no non-empty documents after preprocessing"

# Build a sorted vocabulary from the subset
vocab = sorted({word for doc in subset_corpus for word in doc})

# Step 3: Initialize hLDA; the Gibbs sampler itself runs in the next cell.
# Gibbs sampling is stochastic, so seed the global RNGs before building the
# tree to make the learned hierarchy reproducible under Restart & Run All.
# NOTE(review): assumes `hlda` draws from the `random` module and/or numpy's
# global RNG; if nCRPTree owns a separate generator, seed that one instead.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

# NOTE(review): hyperparameter roles are presumably the standard hLDA ones
# (gamma: nCRP concentration, eta: topic-word Dirichlet smoothing, m/pi: GEM
# prior over level proportions) — confirm against hlda.nCRPTree.
tree = nCRPTree(
    gamma=1.0,
    eta=0.1,
    num_levels=10,   # You can adjust the number of levels
    vocab=vocab,
    m=0.5,
    pi=1.0
)

# Run a small number of Gibbs iterations for testing
num_iterations = 100
burn_in = 10
thinning = 10

In [16]:
# Fit the tree to the preprocessed subset; the sampler reports its own progress
# (iteration counter and end of burn-in) as it runs.
tree.gibbs_sampling(
    subset_corpus,
    num_iterations=num_iterations,
    burn_in=burn_in,
    thinning=thinning,
)


Iteration 10 completed.
Burn-in period of 10 iterations completed.
Iteration 20 completed.
Iteration 30 completed.
Iteration 40 completed.
Iteration 50 completed.
Iteration 60 completed.
Iteration 70 completed.
Iteration 80 completed.
Iteration 90 completed.
Iteration 100 completed.
Gibbs sampling completed.


In [17]:
# Step 4: Render the sampled tree with graphviz. The output name lives in one
# constant so the confirmation message cannot drift from the file actually
# written; per this cell's output, print_tree_graphviz saves it as
# "<filename>.png".
TREE_FILENAME = 'ncrp_tree_example'

print("Sampling completed. Now generating the tree visualization...")
print_tree_graphviz(tree.root, vocab, filename=TREE_FILENAME, view=False)
print(f"Visualization generated: {TREE_FILENAME}.png")

Sampling completed. Now generating the tree visualization...
Tree visualization saved as ncrp_tree_example.png
Visualization generated: ncrp_tree_example.png


![nCRP topic tree learned by hLDA on the 50-document 20 Newsgroups subset](../image/ncrp_tree_example.png)