In [26]:
import numpy as np
import random
import json

from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sklearn.datasets import fetch_20newsgroups

from hlda import nCRPTree
from utils import *

In [27]:
# Step 1: Fetch every 20 Newsgroups post, stripping the 'headers', 'footers'
# and 'quotes' sections so mostly the message bodies remain.
newsgroups = fetch_20newsgroups(
    subset='all',
    remove=('headers', 'footers', 'quotes'),
)
raw_corpus = newsgroups.data

# Step 2: Define a preprocessing pipeline
def preprocess_corpus(documents, stop_words):
    """
    Turn raw text documents into lists of cleaned tokens.

    Each document is lowercased and split with ``word_tokenize``; a token is
    kept only when it is purely alphabetic and not in ``stop_words``.

    Returns one token list per input document, in input order.
    """
    def _clean(text):
        # Lowercase before tokenizing so stopword matching is case-insensitive
        return [
            token
            for token in word_tokenize(text.lower())
            if token.isalpha() and token not in stop_words
        ]

    return [_clean(text) for text in documents]

# Create a set of English stopwords
stop_words = set(stopwords.words('english'))

# Using only a subset of the corpus. Preprocess just that slice instead of
# tokenizing the whole corpus and discarding everything past `subset_size`:
# preprocess_corpus treats each document independently, so the result is the
# same as slicing after preprocessing, at a fraction of the cost.
subset_size = 100
subset_corpus = preprocess_corpus(raw_corpus[:subset_size], stop_words)

# Build a vocabulary from the subset (sorted so its order is reproducible)
vocab = sorted(set(word for doc in subset_corpus for word in doc))

# Step 3: Initialize hLDA; the sampler itself is run in the next cell
tree = nCRPTree(
    gamma=1.0,
    eta=0.1,
    num_levels=20,   # Max_level
    vocab=vocab,
    m=0.5,
    pi=1.0
)

# Gibbs sampling schedule, passed to tree.gibbs_sampling in the next cell
num_iterations = 1000
burn_in = 100
thinning = 10

In [28]:
# Run the sampler on the preprocessed newsgroups subset
tree.gibbs_sampling(
    subset_corpus,
    num_iterations=num_iterations,
    burn_in=burn_in,
    thinning=thinning,
)

Gibbs sampling completed.


In [29]:
print("Sampling completed.")

# Step 4: Render the inferred tree to ncrp_tree_example.png (not opened)
print_tree_graphviz(
    tree.root,
    vocab,
    filename='ncrp_tree_example',
    view=False,
)

Sampling completed.
Tree visualization saved as ncrp_tree_example.png


![nCRP topic tree inferred by hLDA from a 100-document subset of 20 Newsgroups](../image/ncrp_tree_example.png)

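# Synthetic dataset

The newsgroups corpus has no ground-truth topic hierarchy. We therefore also generate a synthetic corpus from a known topic tree, so that the hierarchy inferred by hLDA can be compared against the true one.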

## Synthetic data generation

In [30]:
# Seed both RNGs used by the generator below: `random` picks the child topic
# at each level in sample_path, `np.random` samples the words in
# generate_document.
np.random.seed(42)
random.seed(42)

# A hierarchical topic tree with associated words.
# Every node is a dict: the "_words" entry maps a word to an unnormalized
# weight (weights are summed along a root-to-leaf path and normalized by
# combine_word_distributions); every other key is a child topic name mapping
# to a nested node of the same shape. A node whose only key is "_words" is a
# leaf. Every root-to-leaf path visits exactly 4 nodes
# (root -> subject -> subfield -> leaf).
# Note: some words appear at more than one level (e.g. "data" in the root and
# in "Science"), so their weights add up along a path.
topic_tree = {
    "_words": {
        "abstract": 10, "concept": 8, "entity": 5, "system": 7, "structure": 6,
        "analysis": 5, "data": 5, "model": 5, "process": 5, "method": 4
    },
    "Science": {
        "_words": {
            "science": 10, "theory": 8, "experiment": 7, "research": 7, "hypothesis": 6,
            "laboratory": 5, "data": 5, "statistic": 4, "evidence": 4
        },
        "Physics": {
            "_words": {
                "force": 10, "energy": 8, "quantum": 7, "particle": 7, "field": 6, 
                "relativity": 5, "momentum": 5, "photon": 5, "wave": 5, "magnetism": 4
            },
            "Astrophysics": {
                "_words": {
                    "galaxy": 10, "cosmic": 8, "nebula": 7, "cosmology": 7, "darkmatter": 6, 
                    "supernova": 5, "telescope": 5, "exoplanet": 5, "gravity": 5, "orbit": 4
                }
            },
            "Particle Physics": {
                "_words": {
                    "hadron": 10, "boson": 8, "fermion": 7, "quark": 7, "collider": 6,
                    "neutrino": 5, "charm": 5, "lepton": 5, "spin": 5, "muon": 4
                }
            }
        },
        "Chemistry": {
            "_words": {
                "molecule": 10, "reaction": 8, "compound": 7, "acid": 7, "base": 6, 
                "electron": 5, "proton": 5, "catalyst": 5, "radical": 5, "solvent": 4
            },
            "Organic Chemistry": {
                "_words": {
                    "carbon": 10, "hydrocarbon": 8, "polymer": 7, "enzyme": 7, "synthesis": 6,
                    "aminoacid": 5, "peptide": 5, "aldehyde": 5, "ketone": 5, "ester": 4
                }
            },
            "Inorganic Chemistry": {
                "_words": {
                    "metal": 10, "alloy": 8, "mineral": 7, "oxide": 7, "salt": 6,
                    "ion": 5, "complex": 5, "crystal": 5, "cluster": 5, "silicate": 4
                }
            }
        },
        "Biology": {
            "_words": {
                "cell": 10, "gene": 8, "evolution": 7, "species": 7, "ecosystem": 6,
                "organism": 5, "metabolism": 5, "phylum": 5, "genome": 5, "bacteria": 4
            },
            "Microbiology": {
                "_words": {
                    "virus": 10, "fungus": 8, "plasmid": 7, "antibiotic": 7, "microbe": 6,
                    "biofilm": 5, "pathogen": 5, "yeast": 5, "protist": 5, "spore": 4
                }
            },
            "Neuroscience": {
                "_words": {
                    "neuron": 10, "synapse": 8, "cortex": 7, "axon": 7, "neurotransmitter": 6,
                    "gliacell": 5, "cognition": 5, "memory": 5, "perception": 5, "nervous": 4
                }
            }
        }
    },
    "Sports": {
        "_words": {
            "competition": 10, "team": 9, "player": 9, "coach": 7, "tournament": 6,
            "league": 5, "training": 5, "score": 5, "referee": 5, "stadium": 4
        },
        "Football": {
            "_words": {
                "ball": 10, "goal": 8, "pitch": 7, "tackle": 7, "striker": 6,
                "midfielder": 5, "defender": 5, "penalty": 5, "corner": 5, "foul": 4
            },
            "American Football": {
                "_words": {
                    "quarterback": 10, "touchdown": 8, "linebacker": 7, "helmet": 7, "endzone": 6,
                    "scrimmage": 5, "fumble": 5, "receiver": 5, "fieldgoal": 5, "blitz": 4
                }
            },
            "Soccer": {
                "_words": {
                    "football": 10, "soccer": 8, "worldcup": 7, "uefa": 7, "maradona": 6,
                    "ronaldo": 5, "messi": 5, "beckham": 5, "fifa": 5, "leaguecup": 4
                }
            }
        },
        "Basketball": {
            "_words": {
                "basket": 10, "dribble": 8, "hoop": 7, "court": 7, "dunk": 6,
                "rebound": 5, "assist": 5, "guard": 5, "forward": 5, "center": 4
            },
            "NBA": {
                "_words": {
                    "nba": 10, "lakers": 8, "bulls": 7, "celtics": 7, "knicks": 6,
                    "lebron": 5, "jordan": 5, "kobe": 5, "shaquille": 5, "warriors": 4
                }
            },
            "FIBA": {
                "_words": {
                    "fiba": 10, "olympics": 8, "international": 7, "coachclinic": 7, "eurobasket": 6,
                    "rebounddrill": 5, "pickandroll": 5, "zone": 5, "man2man": 5, "backboard": 4
                }
            }
        }
    },
    "Music": {
        "_words": {
            "note": 10, "melody": 9, "tune": 8, "rhythm": 7, "harmony": 6,
            "composition": 5, "instrument": 5, "rehearsal": 5, "performance": 5, "conductor": 4
        },
        "Rock": {
            "_words": {
                "guitar": 10, "band": 9, "amplifier": 8, "riff": 7, "drum": 6,
                "bass": 5, "vocalist": 5, "solo": 5, "tour": 5, "album": 4
            },
            "Heavy Metal": {
                "_words": {
                    "metallica": 10, "megadeth": 8, "ironmaiden": 7, "blackmetal": 7, "distortion": 6,
                    "headbang": 5, "thrash": 5, "riffage": 5, "growl": 5, "moshpit": 4
                }
            },
            "Punk Rock": {
                "_words": {
                    "punk": 10, "ramones": 8, "anarchy": 7, "hardcore": 7, "mosh": 6,
                    "diy": 5, "gutter": 5, "scene": 5, "fastbeat": 5, "underground": 4
                }
            }
        },
        "Jazz": {
            "_words": {
                "saxophone": 10, "improv": 9, "swing": 8, "bassline": 7, "piano": 6,
                "trumpet": 5, "brushes": 5, "ensemble": 5, "standard": 5, "club": 4
            },
            "Bebop": {
                "_words": {
                    "charlieparker": 10, "dizzygillespie": 8, "fasttempo": 7, "complexharmony": 7, "jam": 6,
                    "bebopscale": 5, "alteredchord": 5, "contrafact": 5, "sessions": 5, "birdland": 4
                }
            },
            "Smooth Jazz": {
                "_words": {
                    "groove": 10, "softsax": 8, "fusionsound": 7, "laidback": 7, "lounge": 6,
                    "chill": 5, "background": 5, "ambient": 5, "radio": 5, "crossover": 4
                }
            }
        }
    }
}


def get_children(topic_node):
    """
    Map each child topic name of ``topic_node`` to its node dict.

    Every key other than the reserved ``"_words"`` entry is a child. The
    input node is not modified; a new dict is returned.
    """
    child_map = dict(topic_node)
    child_map.pop("_words", None)
    return child_map

def is_leaf_topic(topic_node):
    """Return True when ``topic_node`` has no child topics (nothing but ``"_words"``)."""
    return not get_children(topic_node)

def sample_path(topic_node):
    """
    Walk from ``topic_node`` down to a leaf, picking one child uniformly at
    random (with the ``random`` module) at every level.

    Returns the visited node dicts in order: ``topic_node`` first, the leaf last.
    """
    visited = [topic_node]
    children = get_children(topic_node)
    while children:
        # Choosing among the values picks the same position as choosing a key
        # and looking it up, so the random stream is consumed identically.
        next_node = random.choice(list(children.values()))
        visited.append(next_node)
        children = get_children(next_node)
    return visited

def combine_word_distributions(path_nodes):
    """
    Merge the ``"_words"`` weights of every node on a path into one distribution.

    Weights of words that occur on several nodes are added together, then all
    weights are divided by their grand total.

    Returns ``(words, probs)``: words in first-seen order and the matching
    probabilities (both empty for an empty path).
    """
    weight_by_word = {}
    pairs = (pair for node in path_nodes for pair in node["_words"].items())
    for word, weight in pairs:
        weight_by_word[word] = weight_by_word.get(word, 0) + weight

    grand_total = sum(weight_by_word.values())
    vocabulary = [*weight_by_word]
    return vocabulary, [weight_by_word[word] / grand_total for word in vocabulary]

def generate_document(topic_tree, num_words=100):
    """
    Generate one synthetic document from ``topic_tree``.

    A root-to-leaf path is sampled, the word weights along it are merged into
    a single distribution, and ``num_words`` words are drawn from it with
    replacement using ``np.random``.

    Returns ``(doc_words, path_nodes)``: the words as a list of Python strings
    and the node dicts on the sampled path (root first).
    """
    path_nodes = sample_path(topic_tree)
    vocabulary, weights = combine_word_distributions(path_nodes)
    sampled = np.random.choice(vocabulary, size=num_words, p=weights)
    return sampled.tolist(), path_nodes

def find_key_for_node(parent_node, child_node_ref):
    """
    Return the child-topic key of ``parent_node`` whose value *is*
    ``child_node_ref`` (identity, not equality), or None if there is none.

    The reserved ``"_words"`` entry is never returned.
    """
    return next(
        (
            key
            for key, value in parent_node.items()
            if key != "_words" and value is child_node_ref
        ),
        None,
    )

def get_path_labels(topic_tree, path_nodes):
    """
    Translate a path of node references into the topic names along it.

    ``path_nodes[0]`` is assumed to be ``topic_tree`` itself and is always
    labelled "Root"; every later node is labelled with the key under which it
    is stored in the node before it (None when it is not a child of that node).
    """
    # Pair each node after the root with the node it should hang from.
    parents = [topic_tree, *path_nodes[1:-1]]
    children = path_nodes[1:]
    return ["Root"] + [
        find_key_for_node(parent, child) for parent, child in zip(parents, children)
    ]

In [31]:
# Parameters for synthetic corpus generation
num_documents = 50
num_words_per_doc = 100

# Draw each document together with the root-to-leaf topic path it came from,
# recording that path as human-readable topic labels.
documents, document_paths = [], []
for _ in range(num_documents):
    doc_words, path_nodes = generate_document(topic_tree, num_words=num_words_per_doc)
    path_labels = get_path_labels(topic_tree, path_nodes)
    documents.append(doc_words)
    document_paths.append(path_labels)

# Persist the corpus and its ground-truth paths so later cells can reload them
corpus_record = {"documents": documents, "document_paths": document_paths}
with open("synthetic_corpus_extended.json", "w") as f:
    json.dump(corpus_record, f)

## Fitting hLDA to the synthetic corpus

In [32]:
# Load the synthetic corpus (documents + ground-truth label paths) from disk
with open("synthetic_corpus_extended.json", "r") as f:
    data = json.load(f)
documents, document_paths = data["documents"], data["document_paths"]

In [33]:
# Build vocabulary from documents. Sort it so its order is identical on every
# kernel restart: iterating a set of strings follows the per-process string
# hash seed, which would make the fit irreproducible even with seeded RNGs.
# (The newsgroups vocabulary earlier in the notebook is sorted the same way.)
vocab = sorted(set(word for doc in documents for word in doc))

# Initialize the nCRPTree instance
gamma = 1.0
eta = 0.1
num_levels = 4  # every root-to-leaf path in topic_tree visits 4 nodes
m = 0.5
pi = 1.0
tree = nCRPTree(gamma=gamma, eta=eta, num_levels=num_levels, vocab=vocab, m=m, pi=pi)

# Run Gibbs sampling
num_iterations = 2000
tree.gibbs_sampling(documents, num_iterations=num_iterations, burn_in=100, thinning=100)



Gibbs sampling completed.


In [34]:
# 1. Visualize the predicted hierarchy from the hLDA model
#    (written to predicted_hierarchy.png, not opened)
print_tree_graphviz(
    tree.root,
    vocab,
    filename='predicted_hierarchy',
    view=False,
)

# 2. Visualize the actual hierarchy:
def visualize_hierarchy_dict(hierarchy_dict, parent_id=None, graph=None, node_id_counter=None, node_name=None):
    """
    Visualize the actual hierarchy defined by a dictionary (similar to topic_tree).

    Each node becomes a box labelled with its topic name (when known) and its
    three highest-weighted words, plus an edge from its parent. Children are
    visited depth-first in dictionary order and get ids in visiting order.

    Args:
        hierarchy_dict: node dict; "_words" maps word -> weight, every other
            key maps a child topic name to a nested node dict.
        parent_id: id of the parent graph node, or None for the root.
        graph: graph object to draw into (needs ``node`` and ``edge``); a new
            ``Digraph`` is created when None.
        node_id_counter: one-element list holding the next free node id,
            shared across the recursion. Reset to [0] when a new graph is
            created; started at [0] when omitted with an existing graph.
        node_name: topic name shown on this node. Defaults to "Root" for the
            root; a non-root node without a name shows only its words.

    Returns:
        (graph, node_id_counter) after drawing this node and its subtree.
    """
    if graph is None:
        # NOTE(review): Digraph is not imported explicitly in this notebook;
        # presumably it comes from `from utils import *` — confirm.
        graph = Digraph(comment='Actual Hierarchy')
        graph.attr('node', shape='box', style='filled', color='lightgreen')
        node_id_counter = [0]
    elif node_id_counter is None:
        # A caller-supplied graph may arrive without a counter; start at 0
        node_id_counter = [0]

    current_id = node_id_counter[0]
    node_id_counter[0] += 1

    # Show the three highest-weighted words to make the node meaningful
    node_words = hierarchy_dict.get("_words", {})
    top_words = sorted(node_words, key=node_words.get, reverse=True)[:3]
    words_line = "Words: " + ", ".join(top_words)

    if node_name is None and parent_id is None:
        node_name = "Root"
    label = f"{node_name}\n{words_line}" if node_name is not None else words_line
    graph.node(str(current_id), label=label)

    # Add edge if not root
    if parent_id is not None:
        graph.edge(str(parent_id), str(current_id))

    # Recurse into child topics; the child's key is its topic name, so pass it
    # down to be shown in the child's label.
    for child_key, child_node in hierarchy_dict.items():
        if child_key == "_words":
            continue
        graph, node_id_counter = visualize_hierarchy_dict(
            child_node,
            parent_id=current_id,
            graph=graph,
            node_id_counter=node_id_counter,
            node_name=child_key,
        )

    return graph, node_id_counter

# Generate and save the ground-truth hierarchy so it can be compared with
# predicted_hierarchy.png
actual_graph = visualize_hierarchy_dict(topic_tree)[0]
actual_graph.render('actual_hierarchy', view=False, format='png')

print("Actual hierarchy saved as actual_hierarchy.png")

Tree visualization saved as predicted_hierarchy.png
Actual hierarchy saved as actual_hierarchy.png


In [35]:
# Pick three distinct documents at random and save a path visualization for
# each one (doc_<id>_path.png).
doc_ids_to_visualize = random.sample(range(len(documents)), 3)

for doc_id in doc_ids_to_visualize:
    visualize_document_path(
        tree,
        doc_id,
        vocab,
        filename=f'doc_{doc_id}_path',
        view=False,
        top_n=5,
    )

print("Selected document paths visualized. Check doc_<id>_path.png files.")


Document 23 path visualization saved as doc_23_path.png
Document 19 path visualization saved as doc_19_path.png
Document 15 path visualization saved as doc_15_path.png
Selected document paths visualized. Check doc_<id>_path.png files.
