In [25]:
import numpy as np
from collections import defaultdict

# Constructing nCRP tree

In [26]:
class Node:
    total_nodes = 0
    last_node_id = 0

    def __init__(self, parent=None, level=0):
        self.node_id = Node.last_node_id
        Node.last_node_id += 1

        self.children = {}         # Dictionary to store child nodes
        self.documents = 0         # Number of documents passing through this node
        self.word_counts = {}      # Word counts at this node
        self.total_words = 0       # Total number of words at this node
        self.parent = parent       # Parent node
        self.level = level         # Level in the tree

    def is_leaf(self):
        """Check if this node is a leaf node (no children)."""
        return len(self.children) == 0

    def add_child(self, topic_id):
        """Adds a child node with a given topic ID."""
        child_node = Node(parent=self, level=self.level + 1)
        self.children[topic_id] = child_node
        Node.total_nodes += 1
        return child_node

    def remove_child(self, topic_id):
        """Removes a child node with a given topic ID."""
        if topic_id in self.children:
            del self.children[topic_id]
            Node.total_nodes -= 1
            

# Path Sampling in nCRPTree

Path sampling in the `nCRPTree` involves a sequence of steps to determine the most probable hierarchical path (sequence of nodes) for a given document. The process combines prior probabilities (based on the nested Chinese Restaurant Process, nCRP) and likelihood values (based on the words in the document) to sample a path probabilistically.

## Overview of Path Sampling
The main function responsible for path sampling is `sample_path`. It orchestrates the process by:
1. Removing the document from its current path (if already assigned).
2. Collecting word counts for each level of the document.
3. Computing the prior probabilities of paths using the nCRP.
4. Computing the likelihood of the document for each path.
5. Combining the prior and likelihood to compute posterior probabilities over all paths.
6. Sampling a new path based on the posterior probabilities.
7. Reassigning the document to the sampled path and updating the tree structure.

---

## Detailed Explanation of Each Step

### 1. **Remove Document from Current Path**
The document is removed from its current path in the tree using the `remove_document` function. This involves:
- Traversing the path assigned to the document and decrementing the `documents` count for each node.
- Removing nodes if they become empty (i.e., have zero documents).

This ensures the document does not bias the path probabilities during re-sampling.

---

### 2. **Collect Word Counts for Each Level**
The function collects word counts for the document at each level using the `document_words` and `document_levels` inputs. The result is a dictionary, `level_word_counts`, where:
- Keys represent levels in the tree.
- Values are dictionaries of word frequencies at each level.

This is necessary for computing the likelihood of the document given each path.

---

### 3. **Compute nCRP Prior**
The nested Chinese Restaurant Process (nCRP) prior is computed using the `compute_ncrp_prior` function. This calculates the log-probability of each path based on the hierarchical structure of the tree:
$$
p(\text{path}) = \prod_{\text{nodes in path}} \frac{n_i}{\gamma + n_{\text{parent}}}
$$
where:
- $n_i$ is the number of documents at the node.
- $\gamma$ is the concentration parameter.
- $n_{\text{parent}}$ is the total number of documents at the parent node.

Logarithms are used for numerical stability:
$$
\log p(\text{path}) = \sum_{\text{nodes in path}} \left[ \log(n_i) - \log(\gamma + n_{\text{parent}}) \right]
$$
The prior for creating a new path at a node is:
$$
\log p(\text{new path}) = \log(\gamma) - \log(\gamma + n_{\text{parent}})
$$

---

### 4. **Compute Document Likelihood**
The likelihood of the document given a path is computed using the `compute_doc_likelihood` function. The likelihood is based on the Dirichlet-multinomial distribution:
$$
p(w | z) \propto \prod_{i=1}^{\text{word count}} \frac{\eta + \text{word count in topic} + i - 1}{\eta_{\text{sum}} + \text{total words in topic} + i - 1}
$$
where:
- $\eta$ is the smoothing parameter for topic-word distributions.
- $\eta_{\text{sum}}$ is the total smoothing across all words in the vocabulary.

The log-likelihood for each node is computed recursively for all levels in the tree.

---

### 5. **Compute Posterior Over Paths**
The prior and likelihood are combined to compute the posterior probabilities for each path using the `compute_posterior_over_paths` function:
$$
\log p(\text{path} | w) = \log p(\text{path}) + \log p(w | \text{path})
$$
To prevent numerical underflow when exponentiating, the log-probabilities are normalized:
$$
\text{weights} = \exp(\log p(\text{path} | w) - \max(\log p(\text{path} | w)))
$$
The probabilities are then normalized to sum to 1:
$$
\text{probabilities} = \frac{\text{weights}}{\sum \text{weights}}
$$

---

### 6. **Sample a New Path**
Using the posterior probabilities, a new path is sampled with the `sample_new_path` function:
- A node is chosen probabilistically based on the computed probabilities.
- If the sampled node is not a leaf, a new child node is created to extend the path.

This step ensures that the tree can dynamically grow to accommodate new topics.

---

### 7. **Reassign Document to New Path**
The document is reassigned to the sampled path using the `add_document` function. This involves:
- Incrementing the `documents` count for each node in the new path.
- Updating the word counts at each level of the path to include the words in the document.

---

## Final Formula Summary
The probability of a path is computed as:
$$
p(\text{path} | w) \propto p(\text{path}) \cdot p(w | \text{path})
$$
where:
- $p(\text{path})$ is the nCRP prior.
- $p(w | \text{path})$ is the document likelihood.

The path is sampled using these posterior probabilities.

---

## Functions Involved
1. **`remove_document`**: Removes the document from its current path.
2. **`compute_ncrp_prior`**: Computes the log-prior for all paths using nCRP.
3. **`compute_doc_likelihood`**: Computes the log-likelihood of the document for each path.
4. **`compute_posterior_over_paths`**: Combines prior and likelihood to compute posterior probabilities.
5. **`sample_new_path`**: Samples a path based on the posterior probabilities.
6. **`add_document`**: Reassigns the document to the sampled path and updates the tree.

Each of these steps is crucial for ensuring that the path sampling process adheres to the hierarchical probabilistic model described by the nCRP and Dirichlet-multinomial distributions.

In [27]:
class nCRPTree:
    def __init__(self, alpha, gamma, eta, num_levels, vocab):
        """
        Initialize the nCRP tree.

        Parameters:
        - alpha: float, smoothing parameter for document-topic distributions.
        - gamma: float, concentration parameter for the nested CRP.
        - eta: float, smoothing parameter for topic-word distributions.
        - num_levels: int, maximum depth of the hierarchical tree.
        - vocab: list, vocabulary for the corpus.
        """
        self.root = Node()
        self.alpha = alpha  # Smoothing on doc-topic distributions
        self.gamma = gamma  # Concentration parameter for nCRP
        self.eta = eta      # Smoothing on topic-word distributions
        self.eta_sum = eta * len(vocab)
        self.vocab = vocab
        self.num_levels = num_levels
        self.paths = {}  # Dictionary to keep track of paths for each document
        self.document_leaves = {}  # Mapping from document IDs to their leaf nodes
        self.levels = {}  # Mapping from document IDs to word-level assignments

    def forget(self):
        """Reset the tree to its initial state and clear paths."""
        self.root = Node()
        self.paths = {}  # Clear all paths
    
    def CRP(self, node):
        """
        Basic CRP process.

        Returns:
        - sampled_topic: Label of the sampled topic (Not the Node)
        """
        total_documents = node.documents  # Including the incoming document
        topic_probabilities = {}

        # If there are no tables, create a new table with probability 1
        if not node.children:
            return np.int64(1)  # The new table has a key = 1

        else:
            # Calculating the probability of joining each of the existing tables (topics)
            for topic, child_node in node.children.items():
                topic_probabilities[topic] = child_node.documents / (self.alpha + total_documents - 1)

            # Probability of creating a new table (topic)
            new_table_key = np.max(list(node.children.keys())) + 1
            topic_probabilities[new_table_key] = self.alpha / (self.alpha + total_documents - 1)

            topics = list(topic_probabilities.keys())
            probabilities = list(topic_probabilities.values())

            # Since probabilities sum to 1, normalization is not needed
            sampled_topic = np.random.choice(topics, p=probabilities)

            return sampled_topic

    def initialize_new_path(self, max_depth, document_id):
        """
        Sample a path through the tree using the nCRP during initialization.

        Parameters:
        - max_depth: Maximum depth of the tree (i.e. L-level tree).
        - document_id: ID of the document being sampled.

        Returns:
        - path: A list of topics (node labels) representing the path through the tree.
        """
        current_node = self.root
        current_node.documents += 1
        path = []  # Track the path as a list of node labels (not pointers)

        for level in range(1, max_depth):
            # Use the CRP function to sample a topic
            sampled_topic = self.CRP(current_node)
            path.append(sampled_topic)

            # Create new table if needed
            if sampled_topic not in current_node.children:
                current_node.add_child(sampled_topic)

            current_node = current_node.children[sampled_topic]
            current_node.documents += 1

        self.paths[document_id] = path
        return path
    

    def initialise_tree(self, corpus, max_depth):
        """
        Initialise the tree by sampling a path for each document using the nCRP
        and assigning word counts to nodes along their respective paths.

        Parameters:
        - corpus: List of documents (each document is a list of words)
        - max_depth: Maximum depth of the tree
        """
        for doc_id, doc_words in enumerate(corpus):
            path = self.initialize_new_path(max_depth, doc_id)

            # Traverse the path and assign words to nodes along the path
            current_node = self.root
            current_node.documents += 1  # Increment document count at the root
            for level, topic_id in enumerate(path):
                if topic_id not in current_node.children:
                    current_node.add_child(topic_id)
                current_node = current_node.children[topic_id]

                # Assign word counts to this node
                current_node.documents += 1
                for word in doc_words:
                    current_node.word_counts[word] = current_node.word_counts.get(word, 0) + 1
                    current_node.total_words += 1

    def _set_pointers(self, path):
        """
        Traverse the path and set pointers to the nodes along the path.

        Parameters:
        - path: A list of node labels (topics) representing the path through the tree

        Returns:
        - node_pointers: A list of pointers to the actual nodes along the path
        """
        node_pointers = []
        current_node = self.root

        for topic in path:
            current_node = current_node.children[topic]
            node_pointers.append(current_node)

        return node_pointers
    
    def add_document(self, document_id, sampled_node, level_word_counts):
        path = []
        node = sampled_node
        while node is not None:
            path.insert(0, node)
            node.documents += 1
            node = node.parent
        self.paths[document_id] = path
        # Update word counts along the path
        for level, node in enumerate(path):
            word_counts = level_word_counts.get(level, {})
            for word, count in word_counts.items():
                node.word_counts[word] = node.word_counts.get(word, 0) + count
                node.total_words += count
    
    def remove_document(self, document_id):
        """
        Remove a document from its current path in the tree and update statistics.

        Parameters:
        - document_id: int, ID of the document to be removed.
        """
        if document_id not in self.paths:
            return  # Document is not in the tree, no action needed

        path = self.paths[document_id]  # Get the path for this document
        current_node = self.root

        # Iterate over the path and decrement counts
        for topic in path:
            # Move to the child node
            child_node = current_node.children.get(topic)
            if child_node is None:
                continue

            # Decrement the document count
            child_node.documents -= 1

            # Remove the node if it's empty
            if child_node.documents == 0:
                del current_node.children[topic]
                Node.total_nodes -= 1
            else:
                # Adjust word counts for this node
                if document_id in self.levels:
                    word_counts_at_level = self.levels[document_id].get(child_node.level, {})
                    for word, count in word_counts_at_level.items():
                        child_node.word_counts[word] -= count
                        if child_node.word_counts[word] == 0:
                            del child_node.word_counts[word]
                        child_node.total_words -= count

            # Move to the next node
            current_node = child_node

        # Remove the document from paths and levels
        del self.paths[document_id]
        if document_id in self.levels:
            del self.levels[document_id]
    
    def compute_ncrp_prior(self, node, weight, node_weights):
        """
        Compute the nested CRP prior recursively for all paths from the given node.

        Parameters:
        - node: Node object, the starting node for the computation.
        - weight: float, cumulative log probability of the path up to the current node.
        - node_weights: dict, stores the computed log probabilities for all nodes.
        """
        total_customers = node.documents
        for topic_id, child_node in node.children.items():
            child_weight = weight + np.log(child_node.documents / (self.gamma + total_customers))
            self.compute_ncrp_prior(child_node, child_weight, node_weights)
            
        # Weight for creating a new path from this node
        new_path_weight = weight + np.log(self.gamma / (self.gamma + total_customers))
        node_weights[node] = new_path_weight
    
    def compute_doc_likelihood(self, node, level_word_counts, weight, node_weights, level=0):
        """
        Compute the document likelihood for words at each level along the path.

        Parameters:
        - node: Node object, the starting node for likelihood computation.
        - level_word_counts: dict, mapping levels to word counts for the document.
        - weight: float, cumulative log likelihood up to the current node.
        - node_weights: dict, stores the computed likelihood values for nodes.
        - level: int, current level of the tree (default is 0).
        """
        node_weight = 0.0
        word_counts = level_word_counts.get(level, {})
        total_words = node.total_words
        for word, count in word_counts.items():
            word_count_at_node = node.word_counts.get(word, 0)
            for i in range(count):
                node_weight += np.log((self.eta + word_count_at_node + i) /
                                    (self.eta_sum + total_words + i))
        total_weight = weight + node_weight
        node_weights[node] += total_weight
        # Recurse into children
        for child_node in node.children.values():
            self.compute_doc_likelihood(child_node, level_word_counts, total_weight, node_weights, level + 1)
        
    def compute_posterior_over_paths(self, node_weights):
        """
        Compute the posterior probabilities over paths based on prior and likelihood.

        Parameters:
        - node_weights: dict, combined prior and likelihood values for nodes.

        Returns:
        - nodes: list, the Node objects for which the posterior is computed.
        - probabilities: numpy array, normalized probabilities for the nodes.
        """
        nodes = list(node_weights.keys())
        weights = np.array(list(node_weights.values()))
        max_weight = np.max(weights)
        weights = np.exp(weights - max_weight)  # For numerical stability
        probabilities = weights / np.sum(weights)
        return nodes, probabilities
    
    def sample_new_path(self, nodes, probabilities):
        """
        Sample a new path based on posterior probabilities.

        Parameters:
        - nodes: list, Node objects to sample from.
        - probabilities: numpy array, probabilities corresponding to the nodes.

        Returns:
        - sampled_node: Node object, the sampled node.
        """
        sampled_node = np.random.choice(nodes, p=probabilities)
        # If the sampled node is not a leaf, we need to create a new leaf from it
        if not sampled_node.is_leaf():
            new_topic_id = max(sampled_node.children.keys(), default=0) + 1
            sampled_node = sampled_node.add_child(new_topic_id)
        return sampled_node
                
    def sample_path(self, document_id, document_words, document_levels):
        # Remove the document from its current path
        if document_id in self.paths:
            self.remove_document(document_id)
        # Collect word counts per level
        level_word_counts = {}
        for word, level in zip(document_words, document_levels):
            level_word_counts.setdefault(level, {})
            level_word_counts[level][word] = level_word_counts[level].get(word, 0) + 1
        # Compute the nCRP prior
        node_weights = defaultdict(float)
        self.compute_ncrp_prior(self.root, 0.0, node_weights)
        # Compute the document likelihood for each node
        self.compute_doc_likelihood(self.root, level_word_counts, 0.0, node_weights)
        # Compute posterior over paths
        nodes, probabilities = self.compute_posterior_over_paths(node_weights)
        # Sample a new path
        sampled_node = self.sample_new_path(nodes, probabilities)
        # Add the document back to the tree
        self.add_document(document_id, sampled_node, level_word_counts)
        # Update the document's leaf node
        self.document_leaves[document_id] = sampled_node
    

In [28]:
# Vocabulary and corpus
vocab = ["a", "b", "c"]
corpus = [[0, 1], [1, 2], [0, 2]]  # Toy corpus with 3 documents
tree = nCRPTree(alpha=1.0, gamma=1.0, eta=0.1, num_levels=3, vocab=vocab)

print("Initializing tree...")
tree.initialise_tree(corpus, max_depth=3)

print("\nTree Structure After Initialization:")
print(f"Root Node: Documents={tree.root.documents}")
for topic_id, child_node in tree.root.children.items():
    print(f"  Topic {topic_id}: Level={child_node.level}, Documents={child_node.documents}, Total Words={child_node.total_words}")
    for sub_topic_id, sub_child_node in child_node.children.items():
        print(f"    Sub-Topic {sub_topic_id}: Level={sub_child_node.level}, Documents={sub_child_node.documents}, Total Words={sub_child_node.total_words}")

Initializing tree...

Tree Structure After Initialization:
Root Node: Documents=6
  Topic 1: Level=1, Documents=4, Total Words=4
    Sub-Topic 1: Level=2, Documents=4, Total Words=4
  Topic 2: Level=1, Documents=2, Total Words=2
    Sub-Topic 1: Level=2, Documents=2, Total Words=2


In [29]:
print("\nRemoving document 1...")
tree.remove_document(document_id=1)

print("\nTree Structure After Document Removal:")
print(f"Root Node: Documents={tree.root.documents}")
for topic_id, child_node in tree.root.children.items():
    print(f"  Topic {topic_id}: Level={child_node.level}, Documents={child_node.documents}, Total Words={child_node.total_words}")
    for sub_topic_id, sub_child_node in child_node.children.items():
        print(f"    Sub-Topic {sub_topic_id}: Level={sub_child_node.level}, Documents={sub_child_node.documents}, Total Words={sub_child_node.total_words}")


Removing document 1...

Tree Structure After Document Removal:
Root Node: Documents=6
  Topic 1: Level=1, Documents=4, Total Words=4
    Sub-Topic 1: Level=2, Documents=4, Total Words=4
  Topic 2: Level=1, Documents=1, Total Words=2
    Sub-Topic 1: Level=2, Documents=1, Total Words=2


In [30]:
new_document = [0, 1, 2]
print("\nSampling a path for a new document...")
tree.sample_path(document_id=3, document_words=new_document, document_levels=[0, 1, 2])

print("\nTree Structure After Sampling New Path:")
print(f"Root Node: Documents={tree.root.documents}")
for topic_id, child_node in tree.root.children.items():
    print(f"  Topic {topic_id}: Level={child_node.level}, Documents={child_node.documents}, Total Words={child_node.total_words}")
    for sub_topic_id, sub_child_node in child_node.children.items():
        print(f"    Sub-Topic {sub_topic_id}: Level={sub_child_node.level}, Documents={sub_child_node.documents}, Total Words={sub_child_node.total_words}")


Sampling a path for a new document...

Tree Structure After Sampling New Path:
Root Node: Documents=7
  Topic 1: Level=1, Documents=4, Total Words=4
    Sub-Topic 1: Level=2, Documents=4, Total Words=4
  Topic 2: Level=1, Documents=1, Total Words=2
    Sub-Topic 1: Level=2, Documents=1, Total Words=2
  Topic 3: Level=1, Documents=1, Total Words=1


In [31]:
print("\nComputing nCRP prior...")
node_weights = defaultdict(float)
tree.compute_ncrp_prior(tree.root, 0.0, node_weights)

print("\nComputed nCRP Priors:")
for node, weight in node_weights.items():
    print(f"Node {node.node_id} (Level {node.level}): Prior Weight={weight}")


Computing nCRP prior...

Computed nCRP Priors:
Node 2 (Level 2): Prior Weight=-2.525728644308255
Node 1 (Level 1): Prior Weight=-2.3025850929940455
Node 4 (Level 2): Prior Weight=-3.4657359027997265
Node 3 (Level 1): Prior Weight=-2.772588722239781
Node 5 (Level 1): Prior Weight=-2.772588722239781
Node 0 (Level 0): Prior Weight=-2.0794415416798357


In [32]:
print("\nComputing document likelihood...")
level_word_counts = {0: {0: 2}, 1: {1: 1}, 2: {2: 1}}  # Example word counts for a document
node_weights = defaultdict(float)
tree.compute_doc_likelihood(tree.root, level_word_counts, 0.0, node_weights)

print("\nComputed Document Likelihoods:")
for node, weight in node_weights.items():
    print(f"Node {node.node_id} (Level {node.level}): Likelihood Weight={weight}")


Computing document likelihood...

Computed Document Likelihoods:
Node 0 (Level 0): Likelihood Weight=-0.25802586286889284
Node 1 (Level 1): Likelihood Weight=-1.621330705764085
Node 2 (Level 2): Likelihood Weight=-2.984635548659277
Node 3 (Level 1): Likelihood Weight=-0.9956248059996718
Node 4 (Level 2): Likelihood Weight=-1.7332237491304507
Node 5 (Level 1): Likelihood Weight=-0.425079947532059


In [33]:
print("\nComputing posterior over paths...")
nodes, probabilities = tree.compute_posterior_over_paths(node_weights)

print("\nPosterior Probabilities:")
for node, prob in zip(nodes, probabilities):
    print(f"Node {node.node_id} (Level {node.level}): Probability={prob}")


Computing posterior over paths...

Posterior Probabilities:
Node 0 (Level 0): Probability=0.3478983400128264
Node 1 (Level 1): Probability=0.088997249770723
Node 2 (Level 2): Probability=0.02276673831344076
Node 3 (Level 1): Probability=0.16638616261483002
Node 4 (Level 2): Probability=0.07957599081578828
Node 5 (Level 1): Probability=0.2943755184723915
