In [1]:
import numpy as np

# Fixed seed for NumPy's global (legacy) generator so every sampled path and
# word assignment below is reproducible after Restart & Run All.
SEED = 3288
np.random.seed(SEED)

# Constructing nCRP tree

In [2]:
class Node:
    """
    One node (topic / "table") of the nCRP tree.

    Attributes:
    - children: Dictionary mapping a child label to the child Node
    - documents: Number of documents whose path passes through this node
    - words: Dictionary mapping a document ID to the list of that document's
      words assigned to this node
    """

    def __init__(self):
        # A freshly created node is empty: no sub-topics, no documents, no words
        self.children = dict()
        self.documents = 0
        self.words = dict()

class nCRPTree:
    """
    Topic tree grown with the nested Chinese Restaurant Process (nCRP).

    Each node acts as a restaurant whose children are its tables. A document
    enters at the root and, level by level, picks a child with the CRP; the
    list of child labels it picks is the document's path.

    Attributes:
    - root: Root Node; every document passes through it, but it is never part
      of a stored path and never receives words
    - alpha: CRP concentration parameter (larger values make new tables more likely)
    - paths: Dictionary mapping document ID -> list of child labels
    """

    def __init__(self, alpha):
        """
        Create an empty tree.

        Parameters:
        - alpha: Concentration parameter of the CRP used at every node
        """
        self.root = Node()
        self.alpha = alpha  # Concentration parameter
        self.paths = {}  # Dictionary to keep track of paths for each document

    def sample_new_path(self, max_depth, document_id):
        """
        Sample a path through the tree using the nCRP.

        The document is counted at the root and at every node it visits. This
        method only ever adds counts: calling it again for a document that is
        already in the tree counts that document a second time.

        Parameters:
        - max_depth: Maximum depth of the tree (i.e. L-level tree), counting the root
        - document_id: ID of the document being sampled

        Returns:
        - path: A list of max_depth - 1 topics (node labels) below the root
        """
        current_node = self.root
        current_node.documents += 1
        path = []  # Track the path as a list of node labels (not pointers)

        # The root is level 1, so only max_depth - 1 choices have to be made
        for _ in range(1, max_depth):
            sampled_topic = self.CRP(current_node)
            path.append(sampled_topic)

            # Open a new table if the CRP picked a label that does not exist yet
            if sampled_topic not in current_node.children:
                current_node.children[sampled_topic] = Node()

            # Step down and count the document at the chosen child
            current_node = current_node.children[sampled_topic]
            current_node.documents += 1

        # Remember the path for the document
        self.paths[document_id] = path
        return path

    def CRP(self, node):
        """
        Basic CRP process: choose one of the node's tables or a new one.

        An existing child is chosen with probability n_child / (alpha + n) and
        a new child with probability alpha / (alpha + n), where n is the number
        of documents already seated at the node's children. Normalising over
        the children (rather than over node.documents - 1) keeps the
        probabilities summing to 1 even if some earlier documents stopped at
        this node because they were sampled with a smaller max_depth.

        Parameters:
        - node: The Node whose children are the candidate tables

        Returns:
        - sampled_topic: Label of the sampled topic (Not the Node)
        """
        # If there are no tables, create a new table with probability 1
        if not node.children:
            return np.int64(1)  # The new table has a key = 1

        topics = list(node.children.keys())
        counts = [child_node.documents for child_node in node.children.values()]
        denominator = self.alpha + sum(counts)

        # Probability of joining each of the existing tables (topics)
        probabilities = [count / denominator for count in counts]

        # Probability of creating a new table; its key is one past the largest key
        topics.append(max(topics) + 1)
        probabilities.append(self.alpha / denominator)

        # Probabilities sum to 1 by construction, so normalization is not needed
        return np.random.choice(topics, p=probabilities)

    def forget(self):
        """Reset the tree to its initial state and clear paths."""
        self.root = Node()
        self.paths = {}  # Clear all paths

    def initialise_tree(self, corpus, max_depth):
        """
        Initialise the tree by first randomly sampling a path using the nCRP
        and then randomly assigning words from each document to nodes along their respective paths.

        Any previous contents of the tree are discarded first. Document IDs are
        the positions in corpus, so without the reset a second call would count
        every document twice and append its words again.

        Parameters:
        - corpus: List of documents (each document is a list of words)
        - max_depth: Maximum depth of the tree, counting the root (must be at
          least 2 for any non-empty document, because the root holds no words)

        Raises:
        - ValueError: If a document has words but max_depth leaves no node
          below the root to put them in

        The word assignments are directly modified in each node's 'words' attribute.
        """
        self.forget()

        for doc_id, doc_words in enumerate(corpus):
            # Sample a path (node labels) and resolve it to the actual nodes
            path = self.sample_new_path(max_depth, doc_id)
            node_pointers = self._set_pointers(path)
            n_nodes = len(node_pointers)

            for word in doc_words:
                if n_nodes == 0:
                    raise ValueError(
                        "max_depth must be at least 2 to assign words: "
                        "the sampled path contains no node below the root"
                    )

                # One uniform draw per word, picking the node on the path that receives it
                node_assigned = node_pointers[np.random.randint(0, n_nodes)]
                node_assigned.words.setdefault(doc_id, []).append(word)

    def _set_pointers(self, path):
        """
        Traverse the path and set pointers to the nodes along the path.

        Parameters:
        - path: A list of node labels (topics) representing the path through the tree

        Returns:
        - node_pointers: A list of pointers to the actual nodes along the path
          (the root itself is not included)
        """
        node_pointers = []
        current_node = self.root

        for topic in path:
            current_node = current_node.children[topic]
            node_pointers.append(current_node)

        return node_pointers

Testing nCRP process and initialisation

In [3]:
# Toy corpus: seven short documents, each a list of word tokens.
# initialise_tree() numbers documents from 0 via enumerate(); the printing
# helpers below add 1, so "Document 1" is corpus[0].
corpus = [
    ["apple", "banana", "orange", "grape"],              # Document 1
    ["cat", "dog", "fish", "bird", "hamster"],           # Document 2
    ["car", "bike", "bus", "train", "plane", "ship"],    # Document 3
    ["house", "building", "apartment", "cabin"],         # Document 4
    ["sun", "moon", "stars", "galaxy"],                  # Document 5
    ["river", "lake", "ocean", "sea"],                   # Document 6
    ["earth", "mars", "jupiter", "saturn", "venus"],     # Document 7
]

# Initialize the tree and set parameters
alpha = 1.0  # CRP concentration parameter passed to nCRPTree
tree = nCRPTree(alpha)
max_depth = 4  # counts the root, so every sampled path has max_depth - 1 = 3 labels

In [4]:
# Sample one path per document and scatter each document's words uniformly at
# random over the nodes on its path (mutates `tree` in place).
tree.initialise_tree(corpus, max_depth)

In [5]:
# Helper that lists the sampled path of every document
def print_document_paths(tree):
    """
    Print one line per document showing the path stored in tree.paths.

    Parameters:
    - tree: Object with a 'paths' dictionary (document ID -> list of node labels)

    Document IDs are shown 1-based (stored ID + 1).
    """
    print("Document Paths:")
    for doc_id in tree.paths:
        path = tree.paths[doc_id]
        print(f"Document {doc_id + 1} Path: {path}")

# Helper that prints the word assignments of the whole tree, grouped by level
def print_tree_word_assignments(node, level=0, level_dict=None):
    """
    Collect the word assignments of 'node' and all its descendants and, when
    called with level == 0, print them level by level.

    Parameters:
    - node: Node to start from (needs 'words' and 'children' dictionaries)
    - level: Depth assigned to 'node'; descendants get level + 1, level + 2, ...
    - level_dict: Optional dictionary (level -> list of summary strings) that
      the collected summaries are appended to; a new one is used if omitted

    Nodes without words contribute nothing. Nothing is printed unless level == 0.
    """
    if level_dict is None:
        level_dict = {}

    # Explicit-stack pre-order walk; children are pushed in reverse so they are
    # popped (visited) in their dictionary order.
    to_visit = [(node, level)]
    while to_visit:
        current, depth = to_visit.pop()

        if current.words:
            summary = ' | '.join(
                f"Document {doc_id + 1}: {words}"
                for doc_id, words in current.words.items()
            )
            level_dict.setdefault(depth, []).append(summary)

        for child_node in reversed(list(current.children.values())):
            to_visit.append((child_node, depth + 1))

    # Only the top-level call prints, in increasing level order
    if level == 0:
        for lvl in sorted(level_dict):
            for assignment in level_dict[lvl]:
                print(f"Level {lvl}: {assignment}")
        
# First, print all the paths for each document
print_document_paths(tree)

# Then, print word assignments at each node.
# Words are only ever placed on nodes below the root, so no "Level 0" line appears.
print("\nWord assignments across the tree:")
print_tree_word_assignments(tree.root)

Document Paths:
Document 1 Path: [np.int64(1), np.int64(1), np.int64(1)]
Document 2 Path: [np.int64(1), np.int64(1), np.int64(1)]
Document 3 Path: [np.int64(1), np.int64(1), np.int64(1)]
Document 4 Path: [np.int64(1), np.int64(1), np.int64(2)]
Document 5 Path: [np.int64(1), np.int64(1), np.int64(2)]
Document 6 Path: [np.int64(1), np.int64(2), np.int64(1)]
Document 7 Path: [np.int64(2), np.int64(1), np.int64(1)]

Word assignments across the tree:
Level 1: Document 1: ['apple', 'grape'] | Document 2: ['cat', 'dog'] | Document 3: ['plane'] | Document 5: ['galaxy'] | Document 6: ['ocean']
Level 1: Document 7: ['earth', 'jupiter', 'saturn', 'venus']
Level 2: Document 1: ['banana', 'orange'] | Document 2: ['fish'] | Document 3: ['bus', 'train', 'ship'] | Document 4: ['building', 'apartment', 'cabin'] | Document 5: ['sun', 'moon']
Level 2: Document 6: ['river', 'sea']
Level 3: Document 2: ['bird', 'hamster'] | Document 3: ['car', 'bike']
Level 3: Document 4: ['house'] | Document 5: ['stars']
