In [5]:
import numpy as np
from collections import defaultdict
from math import log, exp
from scipy.special import gammaln

In [6]:
class Node:
    def __init__(self, parent=None, level=0):
        self.children = {}         # Dictionary to store child nodes {topic_id: Node}
        self.documents = 0         # Number of documents passing through this node
        self.word_counts = defaultdict(int)  # Word counts at this node
        self.total_words = 0       # Total number of words at this node
        self.parent = parent       # Parent node
        self.level = level         # Level in the tree

    def is_leaf(self):
        """Check if this node is a leaf node (no children)."""
        return len(self.children) == 0

    def add_child(self, topic_id):
        """Adds a child node with a given topic ID."""
        child_node = Node(parent=self, level=self.level + 1)
        self.children[topic_id] = child_node
        return child_node

    def remove_child(self, topic_id):
        """Removes a child node with a given topic ID."""
        if topic_id in self.children:
            del self.children[topic_id]


In [7]:
class nCRPTree:
    def __init__(self, alpha, gamma, eta, num_levels, vocab):
        """
        Initialize the nCRP tree.

        Parameters:
        - alpha: float, smoothing parameter for document-topic distributions.
        - gamma: float, concentration parameter for the nested CRP.
        - eta: float, smoothing parameter for topic-word distributions.
        - num_levels: int, maximum depth of the hierarchical tree.
        - vocab: list, vocabulary for the corpus.
        """
        self.root = Node()
        self.alpha = alpha          # Smoothing on doc-topic distributions
        self.gamma = gamma          # Concentration parameter for nCRP
        self.eta = eta              # Smoothing on topic-word distributions
        self.eta_sum = eta * len(vocab)
        self.vocab = vocab
        self.num_levels = num_levels
        self.paths = {}             # Dictionary to keep track of paths for each document
        self.levels = {}            # Mapping from document IDs to word-level assignments
        self.document_words = {}    # Mapping from document_id to list of words in the document

    def CRP(self, node):
        """
        Basic CRP process.

        Returns:
        - sampled_topic: Label of the sampled topic (Not the Node)
        """
        total_words = node.total_words  # Use total words at the current node
        topic_probabilities = {}

        # If there are no children, create a new topic
        if not node.children:
            return 1  # Create a new topic with key = 1

        # Calculating probabilities for existing topics
        for topic, child_node in node.children.items():
            topic_probabilities[topic] = child_node.total_words / (self.alpha + total_words)

        # Probability of creating a new topic
        new_topic_key = max(node.children.keys(), default=0) + 1
        topic_probabilities[new_topic_key] = self.alpha / (self.alpha + total_words)

        # Normalize probabilities
        topics = list(topic_probabilities.keys())
        probabilities = list(topic_probabilities.values())
        total_prob = sum(probabilities)
        probabilities = [p / total_prob for p in probabilities]  # Normalize

        # Ensure sum of probabilities is 1 (for debugging purposes)
        assert abs(sum(probabilities) - 1.0) < 1e-6, "Probabilities do not sum to 1"
        return np.random.choice(topics, p=probabilities)

    def initialize_new_path(self, max_depth, document_id):
        """
        Sample a path through the tree using the nCRP during initialization.

        Parameters:
        - max_depth: Maximum depth of the tree (i.e., L-level tree).
        - document_id: ID of the document being sampled.

        Returns:
        - path_nodes: A list of Node objects representing the path through the tree.
        """
        current_node = self.root
        current_node.documents += 1
        path_nodes = [current_node]  # Start with the root node

        for level in range(1, max_depth):
            # Use the CRP function to sample a topic
            sampled_topic = self.CRP(current_node)

            # Create new child if needed
            if sampled_topic not in current_node.children:
                child_node = current_node.add_child(sampled_topic)
            else:
                child_node = current_node.children[sampled_topic]

            child_node.documents += 1
            path_nodes.append(child_node)
            current_node = child_node

        self.paths[document_id] = path_nodes  # Store path as node pointers
        return path_nodes

    def initialise_tree(self, corpus, max_depth):
        """
        Initialize the tree by sampling a path for each document using the nCRP
        and assigning word counts to nodes along their respective paths.

        Parameters:
        - corpus: List of documents (each document is a list of word indices)
        - max_depth: Maximum depth of the tree
        """
        for doc_id, doc_words in enumerate(corpus):
            # Store document words for easy access later
            self.document_words[doc_id] = doc_words

            # Sample a path for this document
            path_nodes = self.initialize_new_path(max_depth, doc_id)

            # Randomly assign levels to words and update word counts
            doc_levels = []
            num_levels = len(path_nodes)
            for word in doc_words:
                level = np.random.randint(0, num_levels)
                doc_levels.append(level)
                node = path_nodes[level]
                node.word_counts[word] += 1
                node.total_words += 1

            # Store levels for this document
            self.levels[doc_id] = doc_levels

    def add_document(self, document_id, path_nodes, level_word_counts):
        """
        Add a document to the tree by updating path and word counts.

        Parameters:
        - document_id: int, ID of the document
        - path_nodes: list of Node objects representing the path
        - level_word_counts: dict mapping levels to word counts
        """
        self.paths[document_id] = path_nodes
        # Increment document counts along the path
        for node in path_nodes:
            node.documents += 1
        # Update word counts along the path
        for level, word_counts in level_word_counts.items():
            node = path_nodes[level]
            for word, count in word_counts.items():
                node.word_counts[word] += count
                node.total_words += count

    def remove_document(self, document_id):
        """
        Remove a document from its current path in the tree and update attributes.

        Parameters:
        - document_id: int, ID of the document to be removed.
        """
        if document_id not in self.paths:
            return  # Document is not in the tree, no action needed

        path_nodes = self.paths[document_id]  # List of node pointers
        doc_levels = self.levels[document_id]  # List of levels assigned to each word
        doc_words = self.document_words[document_id]

        # Decrement word counts at the nodes along the path
        for word, level in zip(doc_words, doc_levels):
            node = path_nodes[level]
            node.word_counts[word] -= 1
            if node.word_counts[word] == 0:
                del node.word_counts[word]
            node.total_words -= 1

        # Decrement document counts along the path
        for node in path_nodes:
            node.documents -= 1

        # Remove nodes that have zero documents and no children
        for node in reversed(path_nodes):
            if node.documents == 0 and node.is_leaf():
                parent = node.parent
                if parent:
                    # Find the topic_id corresponding to this child
                    topic_id_to_remove = None
                    for topic_id, child in parent.children.items():
                        if child == node:
                            topic_id_to_remove = topic_id
                            break
                    if topic_id_to_remove is not None:
                        parent.remove_child(topic_id_to_remove)

        # Remove document from paths, levels, and document_words
        del self.paths[document_id]
        del self.levels[document_id]
        del self.document_words[document_id]

    def compute_ncrp_prior(self, node, weight, node_weights):
        """
        Compute the nested CRP prior recursively for all paths from the given node.

        Parameters:
        - node: Node object, the starting node for the computation.
        - weight: float, cumulative log probability of the path up to the current node.
        - node_weights: dict, stores the computed log probabilities for all nodes.
        """
        total_customers = node.documents
        for topic_id, child_node in node.children.items():
            child_weight = weight + np.log(child_node.documents / (self.gamma + total_customers))
            self.compute_ncrp_prior(child_node, child_weight, node_weights)

        # Weight for creating a new path (topic) from this node
        new_path_weight = weight + np.log(self.gamma / (self.gamma + total_customers))
        node_weights[node] += new_path_weight  # Accumulate weight

    def compute_doc_likelihood(self, node, level_word_counts, weight, node_weights, level=0):
        """
        Compute the document likelihood for words at each level along the path.

        Parameters:
        - node: Node object, the starting node for likelihood computation.
        - level_word_counts: dict, mapping levels to word counts for the document.
        - weight: float, cumulative log likelihood up to the current node.
        - node_weights: dict, stores the computed likelihood values for nodes.
        - level: int, current level of the tree (default is 0).
        """
        node_weight = 0.0
        word_counts = level_word_counts.get(level, {})
        total_words = node.total_words
        for word, count in word_counts.items():
            word_count_at_node = node.word_counts.get(word, 0)
            for _ in range(count):
                node_weight += np.log((self.eta + word_count_at_node) / (self.eta_sum + total_words))

        total_weight = weight + node_weight
        node_weights[node] += total_weight
        # Recurse into children
        for child_node in node.children.values():
            self.compute_doc_likelihood(child_node, level_word_counts, total_weight, node_weights, level + 1)

    def compute_posterior_over_paths(self, node_weights):
        """
        Compute the posterior probabilities over paths based on prior and likelihood.

        Parameters:
        - node_weights: dict, combined prior and likelihood values for nodes.

        Returns:
        - nodes: list, the Node objects for which the posterior is computed.
        - probabilities: numpy array, normalized probabilities for the nodes.
        """
        nodes = list(node_weights.keys())
        weights = np.array(list(node_weights.values()))
        max_weight = np.max(weights)
        weights = np.exp(weights - max_weight)  # For numerical stability
        probabilities = weights / np.sum(weights)
        return nodes, probabilities

    def sample_new_path(self, nodes, probabilities):
        """
        Sample a new path based on posterior probabilities.

        Parameters:
        - nodes: list, Node objects to sample from.
        - probabilities: numpy array, probabilities corresponding to the nodes.

        Returns:
        - sampled_node: Node object, the sampled node.
        """
        sampled_node = np.random.choice(nodes, p=probabilities)
        # If the sampled node is not a leaf, create a new child under it
        if not sampled_node.is_leaf():
            new_topic_id = max(sampled_node.children.keys(), default=0) + 1
            sampled_node = sampled_node.add_child(new_topic_id)
        return sampled_node

    def sample_path(self, document_id, document_words, document_levels):
        """
        Sample a new path for a document.

        Parameters:
        - document_id: int, ID of the document
        - document_words: list, words in the document
        - document_levels: list, levels assigned to each word
        """
        # Remove the document from its current path
        if document_id in self.paths:
            self.remove_document(document_id)

        # Collect word counts per level
        level_word_counts = {}
        for word, level in zip(document_words, document_levels):
            level_word_counts.setdefault(level, {})
            level_word_counts[level][word] = level_word_counts[level].get(word, 0) + 1

        # Compute the nCRP prior
        node_weights = defaultdict(float)
        self.compute_ncrp_prior(self.root, 0.0, node_weights)

        # Compute the document likelihood for each node
        self.compute_doc_likelihood(self.root, level_word_counts, 0.0, node_weights)

        # Compute posterior over paths
        nodes, probabilities = self.compute_posterior_over_paths(node_weights)

        # Sample a new path
        sampled_node = self.sample_new_path(nodes, probabilities)

        # Build the path from root to the sampled node
        path_nodes = []
        node = sampled_node
        while node is not None:
            path_nodes.insert(0, node)
            node = node.parent

        # Add the document back to the tree
        self.add_document(document_id, path_nodes, level_word_counts)
        

In [8]:
def print_tree(node, indent=0):
    """
    Helper function to print the tree structure.

    Parameters:
    - node: Node object, starting node (default is root)
    - indent: int, current indentation level
    """
    prefix = "  " * indent
    if node.parent is None:
        print(f"{prefix}Root Node: Documents={node.documents}, Total Words={node.total_words}")
    else:
        # Find the topic ID
        parent = node.parent
        topic_id = None
        for tid, child in parent.children.items():
            if child == node:
                topic_id = tid
                break
        print(f"{prefix}Topic {topic_id}: Level={node.level}, Documents={node.documents}, Total Words={node.total_words}")
    for child in node.children.values():
        print_tree(child, indent + 1)

def sum_total_words(node):
    """Helper function to sum total_words in the tree."""
    total = node.total_words
    for child in node.children.values():
        total += sum_total_words(child)
    return total

def compute_total_words(node):
    """Helper function to compute total words from word_counts."""
    total = sum(node.word_counts.values())
    for child in node.children.values():
        total += compute_total_words(child)
    return total

def main():
    # Vocabulary and corpus
    vocab = ["a", "b", "c"]
    corpus = [
        [0, 1],    # Document 0
        [1, 2],    # Document 1
        [0, 2],    # Document 2
    ]  # Toy corpus with 3 documents
    tree = nCRPTree(alpha=1.0, gamma=1.0, eta=0.1, num_levels=3, vocab=vocab)

    print("Initializing tree...")
    tree.initialise_tree(corpus, max_depth=3)

    print("\nTree Structure After Initialization:")
    print_tree(tree.root)

    # Verify total_words consistency
    total_words = sum(len(doc) for doc in corpus)
    tree_total_words = sum_total_words(tree.root)
    print(f"\nTotal Words in Tree: {tree_total_words} (Expected: {total_words})")
    assert tree_total_words == total_words, "Total words mismatch after initialization!"

    print("\nRemoving document 1...")
    tree.remove_document(document_id=1)

    print("\nTree Structure After Document Removal:")
    print_tree(tree.root)

    # Verify total_words consistency
    total_words_after_removal = total_words - len(corpus[1])
    tree_total_words_after_removal = sum_total_words(tree.root)
    print(f"\nTotal Words in Tree: {tree_total_words_after_removal} (Expected: {total_words_after_removal})")
    assert tree_total_words_after_removal == total_words_after_removal, "Total words mismatch after removal!"

    print("\nAdding a new document (Document 3) with words [1, 2]:")
    new_doc = [1, 2]
    doc_id = len(corpus)
    # Initialize path for the new document
    path_nodes = tree.initialize_new_path(max_depth=3, document_id=doc_id)
    # Assign levels to words
    level_word_counts = {0: {1:1, 2:1}}  # Assigning both words to level 0
    # Add the document to the tree
    tree.add_document(document_id=doc_id, path_nodes=path_nodes, level_word_counts=level_word_counts)

    print("\nTree Structure After Adding Document 3:")
    print_tree(tree.root)

    # Verify total_words consistency
    total_words_final = total_words_after_removal + len(new_doc)
    tree_total_words_final = sum_total_words(tree.root)
    print(f"\nTotal Words in Tree: {tree_total_words_final} (Expected: {total_words_final})")
    assert tree_total_words_final == total_words_final, "Total words mismatch after adding a new document!"

    print("\nAttempting to remove a non-existent document (Document 999)...")
    tree.remove_document(document_id=999)

    print("\nTree Structure After Attempting to Remove Non-Existent Document:")
    print_tree(tree.root)

    # Verify total_words consistency remains unchanged
    tree_total_words_final_after_invalid_removal = sum_total_words(tree.root)
    print(f"\nTotal Words in Tree: {tree_total_words_final_after_invalid_removal} (Expected: {total_words_final})")
    assert tree_total_words_final_after_invalid_removal == total_words_final, "Total words changed after attempting to remove a non-existent document!"

    print("\nAll test cases passed successfully!")

if __name__ == "__main__":
    main()

Initializing tree...

Tree Structure After Initialization:
Root Node: Documents=3, Total Words=0
  Topic 1: Level=1, Documents=3, Total Words=2
    Topic 1: Level=2, Documents=1, Total Words=1
    Topic 2: Level=2, Documents=2, Total Words=3

Total Words in Tree: 6 (Expected: 6)

Removing document 1...

Tree Structure After Document Removal:
Root Node: Documents=2, Total Words=0
  Topic 1: Level=1, Documents=2, Total Words=1
    Topic 1: Level=2, Documents=1, Total Words=1
    Topic 2: Level=2, Documents=1, Total Words=2

Total Words in Tree: 4 (Expected: 4)

Adding a new document (Document 3) with words [1, 2]:

Tree Structure After Adding Document 3:
Root Node: Documents=4, Total Words=2
  Topic 1: Level=1, Documents=2, Total Words=1
    Topic 1: Level=2, Documents=1, Total Words=1
    Topic 2: Level=2, Documents=1, Total Words=2
  Topic 2: Level=1, Documents=2, Total Words=0
    Topic 1: Level=2, Documents=2, Total Words=0

Total Words in Tree: 6 (Expected: 6)

Attempting to remove 