In [34]:
import numpy as np
from collections import defaultdict
from math import log
from functools import lru_cache
from scipy.special import gammaln
from graphviz import Digraph

# Classes

In [35]:
class Node:
    """
    One topic (restaurant table) in the hLDA tree.

    Attributes set by the constructor:
        parent: the node one level up, or None for the root.
        level: depth of this node (the root is 0).
        children: mapping from topic id to child Node.
        documents: number of documents whose path passes through this node.
        word_counts: per-word token counts assigned to this node.
        total_words: sum of all token counts assigned to this node.
    """
    def __init__(self, parent=None, level=0):
        # Position in the tree
        self.parent = parent
        self.level = level
        self.children = {}
        # Sufficient statistics
        self.documents = 0
        self.total_words = 0
        self.word_counts = defaultdict(int)

    def is_leaf(self):
        """Return True when the node has no children."""
        return not self.children

    def add_child(self, topic_id):
        """Create a child one level deeper, register it under ``topic_id`` and return it."""
        self.children[topic_id] = Node(parent=self, level=self.level + 1)
        return self.children[topic_id]

    def remove_child(self, topic_id):
        """Drop the child registered under ``topic_id``; unknown ids are ignored."""
        self.children.pop(topic_id, None)

In [36]:
class nCRPTree:
    """
    Implements the hierarchical LDA model using the nested Chinese Restaurant Process.
    """
    def __init__(self, gamma, eta, num_levels, vocab, m=0.5, pi=1.0):
        """
        Set up an empty tree and store the model hyper-parameters.

        Args:
            gamma: nCRP concentration used for the "new child" mass.
            eta: symmetric smoothing added to every per-node word count.
            num_levels: length of every document path (root included).
            vocab: collection of vocabulary items; only its length is used here.
            m, pi: parameters passed to the level prior.
        """
        # Hyper-parameters
        self.gamma, self.eta = gamma, eta
        self.m, self.pi = m, pi
        self.num_levels = num_levels

        # Vocabulary-derived constants
        self.vocab = vocab
        self.V = len(vocab)
        self.eta_sum = self.eta * self.V

        # Sampler state, keyed by document id
        self.root = Node()
        self.paths = {}
        self.levels = {}
        self.document_words = {}

    @lru_cache(maxsize=None)
    def cached_gammaln(self, x):
        return gammaln(x)

    def sample_ncrp_path(self, node):
        total_customers = node.documents
        topic_probabilities = {}

        # Existing children probabilities
        for topic_id, child in node.children.items():
            topic_probabilities[topic_id] = child.documents / (total_customers + self.gamma)

        # New child probability
        new_topic_key = max(node.children.keys(), default=0) + 1
        topic_probabilities[new_topic_key] = self.gamma / (total_customers + self.gamma)

        topics = list(topic_probabilities.keys())
        probs = np.array(list(topic_probabilities.values()))
        probs /= probs.sum()
        chosen = np.random.choice(topics, p=probs)
        is_new = (chosen not in node.children)
        return chosen, is_new

    def initialize_new_path(self, max_depth, document_id):
        current_node = self.root
        current_node.documents += 1
        path_nodes = [current_node]

        for level in range(1, max_depth):
            topic_id, is_new = self.sample_ncrp_path(current_node)
            if is_new:
                child_node = current_node.add_child(topic_id)
            else:
                child_node = current_node.children[topic_id]
            child_node.documents += 1
            path_nodes.append(child_node)
            current_node = child_node

        self.paths[document_id] = path_nodes
        return path_nodes

    def initialise_tree(self, corpus, max_depth):
        for doc_id, doc_words in enumerate(corpus):
            self.document_words[doc_id] = doc_words
            path_nodes = self.initialize_new_path(max_depth, doc_id)
            doc_levels = []
            num_levels = len(path_nodes)
            for w in doc_words:
                level = np.random.randint(0, num_levels)
                doc_levels.append(level)
                node = path_nodes[level]
                node.word_counts[w] += 1
                node.total_words += 1
            self.levels[doc_id] = doc_levels

    def add_document(self, document_id, path_nodes, level_word_counts):
        self.paths[document_id] = path_nodes
        for node in path_nodes:
            node.documents += 1
        for level, w_counts in level_word_counts.items():
            node = path_nodes[level]
            for w, cnt in w_counts.items():
                node.word_counts[w] += cnt
                node.total_words += cnt

    def remove_document(self, document_id):
        if document_id not in self.paths:
            return
        path_nodes = self.paths[document_id]
        doc_levels = self.levels[document_id]
        doc_words = self.document_words[document_id]

        # Decrement word counts
        for w, lvl in zip(doc_words, doc_levels):
            node = path_nodes[lvl]
            node.word_counts[w] -= 1
            if node.word_counts[w] == 0:
                del node.word_counts[w]
            node.total_words -= 1

        # Decrement document counts
        for node in path_nodes:
            node.documents -= 1

        # Prune empty leaves
        for node in reversed(path_nodes):
            if node.documents == 0 and node.is_leaf() and node.parent is not None:
                parent = node.parent
                remove_id = None
                for tid, cnode in parent.children.items():
                    if cnode == node:
                        remove_id = tid
                        break
                if remove_id is not None:
                    parent.remove_child(remove_id)

        del self.paths[document_id]
        del self.levels[document_id]
        del self.document_words[document_id]

    def get_level_word_counts(self, document_words, document_levels):
        """
        Group a document's tokens by assigned level.

        Returns:
            ``{level: {word: count}}`` built from the paired words and levels;
            levels with no tokens are absent.
        """
        tallies = {}
        for word, depth in zip(document_words, document_levels):
            bucket = tallies.setdefault(depth, {})
            bucket[word] = bucket.get(word, 0) + 1
        return tallies

    def level_likelihood(self, node, M):
        Nminus = node.word_counts
        sumNminus = node.total_words
        sumM = sum(M.values())
        sumN = sumNminus + sumM

        log_part1 = self.cached_gammaln(sumNminus + self.eta_sum)
        for w_count in Nminus.values():
            log_part1 -= self.cached_gammaln(w_count + self.eta)

        log_part2 = 0.0
        for w, n_w_minus in Nminus.items():
            n_w = n_w_minus + M.get(w, 0)
            log_part2 += self.cached_gammaln(n_w + self.eta)
        for w, mcount in M.items():
            if w not in Nminus:
                log_part2 += self.cached_gammaln(mcount + self.eta)

        log_part2 -= self.cached_gammaln(sumN + self.eta_sum)

        return log_part1 + log_part2

    def level_prior(self, parent, child_is_new, child_node=None):
        total_customers = parent.documents
        if child_is_new:
            return log(self.gamma) - log(total_customers + self.gamma)
        else:
            return log(child_node.documents) - log(total_customers + self.gamma)

    def sample_path_level(self, parent_node, level_word_counts, level_index):
        """
        Choose the child of ``parent_node`` for one level of a document path.

        Every existing child, plus one hypothetical empty child, is scored by
        log-prior + log-likelihood of the document's words at ``level_index``.
        One option is then sampled from the normalised scores. If the
        hypothetical child wins it is created under ``parent_node``.

        Returns:
            The chosen (possibly newly created) child node.
        """
        existing = list(parent_node.children.items())  # (topic_id, node)
        fresh_id = max(parent_node.children.keys(), default=0) + 1
        words_here = level_word_counts.get(level_index, {})

        # (log score, topic id, is_new) per option, existing children first.
        options = [
            (
                self.level_prior(parent_node, False, kid)
                + self.level_likelihood(kid, words_here),
                tid,
                False,
            )
            for tid, kid in existing
        ]

        # A throw-away empty node stands in for the not-yet-created child.
        placeholder = Node(parent=parent_node, level=parent_node.level + 1)
        options.append(
            (
                self.level_prior(parent_node, True)
                + self.level_likelihood(placeholder, words_here),
                fresh_id,
                True,
            )
        )

        # Subtract the maximum before exponentiating.
        scores = [option[0] for option in options]
        peak = max(scores)
        unnormalised = np.exp([score - peak for score in scores])
        distribution = unnormalised / unnormalised.sum()

        pick = np.random.choice(len(distribution), p=distribution)
        _, picked_id, picked_is_new = options[pick]

        if picked_is_new:
            return parent_node.add_child(picked_id)
        return parent_node.children[picked_id]

    def sample_path(self, document_id, document_words, document_levels):
        if document_id in self.paths:
            self.remove_document(document_id)

        level_word_counts = self.get_level_word_counts(document_words, document_levels)
        current_node = self.root
        path_nodes = [current_node]

        for ell in range(1, self.num_levels):
            child_node = self.sample_path_level(current_node, level_word_counts, ell)
            path_nodes.append(child_node)
            current_node = child_node

        self.add_document(document_id, path_nodes, level_word_counts)
        self.levels[document_id] = document_levels
        self.document_words[document_id] = document_words
        max_level = max(document_levels) if document_levels else 0
        assert max_level < self.num_levels, "Document level assignments exceed maximum tree depth."

    def compute_level_prior_probs(self, z_counts, max_z, m, pi):
        """
        Prior probability of each level 0..max_z under the GEM(m, pi)
        stick-breaking distribution, conditioned on the other tokens' levels.

        For level k:

            p(k) = (m*pi + #[z = k]) / (pi + #[z >= k])
                   * prod_{j<k} ((1-m)*pi + #[z > j]) / (pi + #[z >= j])

        The first factor is the expected probability of stopping at k; each
        factor in the product is the expected probability of passing level j,
        which depends on how many tokens sit strictly below j (not at j).
        Any remaining mass, 1 - sum(p), belongs to levels deeper than max_z.

        Args:
            z_counts: sequence where ``z_counts[k]`` is the number of tokens at level k.
            max_z: deepest level to evaluate (inclusive).
            m, pi: GEM mean and scale parameters.

        Returns:
            List of ``max_z + 1`` probabilities.
        """
        # at_or_below[k] = #[z >= k], built as a suffix sum.
        at_or_below = [0] * (max_z + 1)
        running = 0
        for k in range(max_z, -1, -1):
            running += z_counts[k]
            at_or_below[k] = running

        probs = []
        pass_mass = 1.0  # product of the "pass level j" factors for all j < k
        for k in range(max_z + 1):
            denom = pi + at_or_below[k]
            probs.append(pass_mass * (m * pi + z_counts[k]) / denom)
            strictly_below = at_or_below[k] - z_counts[k]  # #[z > k]
            pass_mass *= ((1 - m) * pi + strictly_below) / denom

        return probs

    def sample_level_assignment_for_word(self, document_id, n):
        doc_words = self.document_words[document_id]
        doc_levels = self.levels[document_id]
        old_level = doc_levels[n]
        w = doc_words[n]

        path_nodes = self.paths[document_id]
        old_node = path_nodes[old_level]
        old_node.word_counts[w] -= 1
        if old_node.word_counts[w] == 0:
            del old_node.word_counts[w]
        old_node.total_words -= 1

        doc_levels[n] = -1
        z_counts = defaultdict(int)
        for lvl in doc_levels:
            if lvl >= 0:
                z_counts[lvl] += 1

        max_z = max(z_counts.keys()) if z_counts else 0
        level_range = list(range(max_z + 1))
        z_counts_list = [z_counts.get(k, 0) for k in level_range]
        prior_probs = self.compute_level_prior_probs(z_counts_list, max_z, self.m, self.pi)

        # Word likelihood for existing levels
        word_likelihoods = []
        for k in level_range:
            node = path_nodes[k]
            w_count = node.word_counts.get(w, 0)
            likelihood = (w_count + self.eta) / (node.total_words + self.eta_sum)
            word_likelihoods.append(likelihood)

        for i in range(len(prior_probs)):
            prior_probs[i] *= word_likelihoods[i]

        sum_existing = sum(prior_probs)
        leftover = 1.0 - sum_existing
        final_levels = level_range[:]
        final_probs = prior_probs[:]

        # If leftover > 0, consider going beyond max_z as per eq (3)
        if leftover > 1e-15:
            # We do a sequence of Bernoulli trials:
            # For each deeper level ell > max_z:
            # p(success) = (1-m)*p(w|...) with p(w|...) = eta/eta_sum for a new empty node
            ell = max_z + 1
            p_w_new = self.eta / self.eta_sum
            # We must decide step-by-step:
            # We know we are going beyond max_z, so attempt ell:
            # Bernoulli trial with p=(1-m)*p_w_new:
            # If success assign ell and stop.
            # If fail, try ell+1, until we run out of levels.
            # If we reach last level and still fail, assign the last tried level anyway.

            chosen_level = None
            while ell < self.num_levels:
                p_success = (1 - self.m)*p_w_new
                # Bernoulli trial:
                success = (np.random.rand() < p_success)
                if success:
                    chosen_level = ell
                    break
                else:
                    ell += 1

            if chosen_level is None:
                # If we fail all the way, assign the deepest level:
                chosen_level = self.num_levels - 1

            # Assign chosen_level now
            new_level = chosen_level
        else:
            # Just choose from existing levels
            total = sum_existing
            if total == 0:
                # rare fallback
                new_level = max_z
            else:
                probs = [p/total for p in final_probs]
                new_level = np.random.choice(final_levels, p=probs)

        # Update counts
        new_node = path_nodes[new_level]
        new_node.word_counts[w] = new_node.word_counts.get(w,0) + 1
        new_node.total_words += 1
        doc_levels[n] = new_level

    def sample_levels_for_document(self, document_id):
        doc_words = self.document_words[document_id]
        for n in range(len(doc_words)):
            self.sample_level_assignment_for_word(document_id, n)

    def gibbs_sampling(self, corpus, num_iterations, burn_in=100, thinning=10):
        self.initialise_tree(corpus, max_depth=self.num_levels)
        for it in range(num_iterations):
            for doc_id in range(len(corpus)):
                document_words = self.document_words[doc_id]
                document_levels = self.levels[doc_id]
                self.sample_path(doc_id, document_words, document_levels)
                self.sample_levels_for_document(doc_id)

            if (it + 1) % thinning == 0:
                print(f"Iteration {it + 1} completed.")

            if it + 1 == burn_in:
                print(f"Burn-in period of {burn_in} iterations completed.")
        print("Gibbs sampling completed.")

# Helper Functions

In [37]:
def get_top_words(node, vocab, top_n=5):
    """
    Return up to ``top_n`` words of ``node`` ordered by descending count.

    ``vocab`` is accepted for signature compatibility and is not consulted.
    """
    ranked = sorted(node.word_counts.items(), key=lambda pair: pair[1], reverse=True)
    return [word for word, _ in ranked[:top_n]]

def print_tree(node, vocab, level=0):
    """
    Print ``node`` and its whole subtree, two spaces of indent per ``level``.

    For each node: its level, document count, total word count and its five
    most frequent words.
    """
    pad = "  " * level
    summary = f"{pad}Level {node.level}: docs={node.documents}, total_words={node.total_words}"
    print(summary)
    frequent = get_top_words(node, vocab, top_n=5)
    print(f"{pad}  Top words: {frequent}")
    for child in node.children.values():
        print_tree(child, vocab, level + 1)

def print_document_assignments(tree, doc_id):
    """Print each token of document ``doc_id`` with the level it is assigned to."""
    tokens = tree.document_words[doc_id]
    depths = tree.levels[doc_id]
    print(f"Document {doc_id}:")
    for position in range(min(len(tokens), len(depths))):
        print(f"  {tokens[position]} -> level {depths[position]}")

# Visualisation

In [38]:
def visualize_tree_graphviz(node, vocab, graph=None, parent_id=None, node_id_counter=None, label_map=None):
    """
    Recursively add ``node`` and its subtree to a Graphviz Digraph.

    Nodes are numbered in pre-order starting from the current value of the
    counter; each is labelled with its level, document count and its three
    most frequent words.

    Args:
        node (Node): The current node to visualize.
        vocab (list): List of vocabulary words (not consulted).
        graph (Digraph, optional): Graph to draw into. A new styled Digraph is
            created when omitted.
        parent_id (int, optional): ID of the parent node; when given, an edge
            parent -> current is added.
        node_id_counter (list, optional): Single-element list used as a mutable
            counter for node IDs. Defaults to ``[0]``.
        label_map (dict, optional): Mapping from node IDs to labels, filled in
            place. Created when omitted, independently of ``graph``.

    Returns:
        tuple: (graph, node_id_counter, label_map); ``node_id_counter[0]`` is
        the ID of the last node that was added.
    """
    if graph is None:
        graph = Digraph(comment='nCRP Tree')
        graph.attr('node', shape='box', style='filled', color='lightblue')

    # Initialised separately so a caller may pass its own graph without a label map.
    if label_map is None:
        label_map = {}

    if node_id_counter is None:
        node_id_counter = [0]

    current_id = node_id_counter[0]

    # Label: three most frequent words; str() so non-string word ids can be joined.
    top_words = sorted(node.word_counts, key=node.word_counts.get, reverse=True)[:3]
    label = f"Level {node.level}\nDocs: {node.documents}\nWords: {', '.join(str(w) for w in top_words)}"
    label_map[current_id] = label
    graph.node(str(current_id), label=label)

    if parent_id is not None:
        graph.edge(str(parent_id), str(current_id))

    for child_node in node.children.values():
        # Advance the shared counter so the child picks up the next free ID.
        node_id_counter[0] += 1
        visualize_tree_graphviz(
            child_node, vocab, graph, parent_id=current_id,
            node_id_counter=node_id_counter, label_map=label_map
        )

    return graph, node_id_counter, label_map

def print_tree_graphviz(root, vocab, filename='ncrp_tree', view=False):
    """
    Draw the tree rooted at ``root`` with Graphviz and render it as a PNG.

    Args:
        root (Node): Root node of the tree to draw.
        vocab (list): Vocabulary, forwarded to ``visualize_tree_graphviz``.
        filename (str, optional): Output name passed to ``render``. Defaults to 'ncrp_tree'.
        view (bool, optional): Forwarded to ``render``; whether to open the result. Defaults to False.
    """
    graph, _counter, _labels = visualize_tree_graphviz(root, vocab)
    render_options = {'view': view, 'format': 'png'}
    graph.render(filename, **render_options)
    print(f"Tree visualization saved as {filename}.png")

# Test

In [39]:
# Fix the global NumPy seed so the sampled trees and the printed output below
# are identical on every Restart & Run All (the sampler draws from np.random).
np.random.seed(0)

# Test Case 1: Small Synthetic Corpus
# Three-level tree over four short fruit documents.
corpus = [
    ["apple", "banana", "apple"],
    ["banana", "banana", "apple"],
    ["apple", "grape", "grape"],
    ["grape", "banana", "apple"]
]
vocab = sorted(set(word for doc in corpus for word in doc))

tree = nCRPTree(gamma=1.0, eta=0.1, num_levels=3, vocab=vocab, m=0.5, pi=1.0)
tree.gibbs_sampling(corpus, num_iterations=20, burn_in=5, thinning=5)

print("=== Tree Structure After Gibbs Sampling ===")
print_tree(tree.root, vocab)

print("\n=== Document Assignments ===")
for doc_id in range(len(corpus)):
    print_document_assignments(tree, doc_id)

# Test Case 2: Single Document, Single Level
# With num_levels=1 the path is just the root, so every token must stay at level 0.
corpus_single = [["apple","banana","apple"]]
vocab_single = sorted(set(word for doc in corpus_single for word in doc))
tree_single = nCRPTree(gamma=1.0, eta=0.1, num_levels=1, vocab=vocab_single)
tree_single.gibbs_sampling(corpus_single, num_iterations=5, burn_in=2, thinning=2)
print("\n=== Single Document, Single Level Tree ===")
print_tree(tree_single.root, vocab_single)
print_document_assignments(tree_single, 0)

# Test Case 3: Check with Minimal Depth and Multiple Docs
# Two levels: the root plus one sampled child per document.
corpus_small = [
    ["cat", "cat", "dog"],
    ["dog", "dog", "cat"]
]
vocab_small = sorted(set(word for doc in corpus_small for word in doc))
tree_small = nCRPTree(gamma=1.0, eta=0.1, num_levels=2, vocab=vocab_small)
tree_small.gibbs_sampling(corpus_small, num_iterations=10, burn_in=2, thinning=5)
print("\n=== Small Corpus, Two Levels ===")
print_tree(tree_small.root, vocab_small)
for doc_id in range(len(corpus_small)):
    print_document_assignments(tree_small, doc_id)

Iteration 5 completed.
Burn-in period of 5 iterations completed.
Iteration 10 completed.
Iteration 15 completed.
Iteration 20 completed.
Gibbs sampling completed.
=== Tree Structure After Gibbs Sampling ===
Level 0: docs=4, total_words=0
  Top words: []
  Level 1: docs=2, total_words=0
    Top words: []
    Level 2: docs=1, total_words=3
      Top words: ['banana', 'apple']
    Level 2: docs=1, total_words=3
      Top words: ['grape', 'apple']
  Level 1: docs=2, total_words=0
    Top words: []
    Level 2: docs=1, total_words=3
      Top words: ['apple', 'banana']
    Level 2: docs=1, total_words=3
      Top words: ['grape', 'banana', 'apple']

=== Document Assignments ===
Document 0:
  apple -> level 2
  banana -> level 2
  apple -> level 2
Document 1:
  banana -> level 2
  banana -> level 2
  apple -> level 2
Document 2:
  apple -> level 2
  grape -> level 2
  grape -> level 2
Document 3:
  grape -> level 2
  banana -> level 2
  apple -> level 2
Iteration 2 completed.
Burn-in period 

In [40]:
# Visualisation: render the tree fitted in Test Case 1 to a PNG via Graphviz.
# Uses the `tree` and `vocab` variables left in the kernel by the test cell.
print("=== Tree Structure After Gibbs Sampling ===")
print_tree_graphviz(tree.root, vocab, filename='tree_test_case_1', view=False)

=== Tree Structure After Gibbs Sampling ===
Tree visualization saved as tree_test_case_1.png


![nCRP tree learned in Test Case 1, rendered with Graphviz](../src/tree_test_case_1.png)