In [1]:
!uv pip install "git+https://github.com/stanfordnlp/dspy/@main"

[2K[2mResolved [1m69 packages[0m [2min 297ms[0m[0m                                        [0m
[2mUninstalled [1m1 package[0m [2min 3ms[0m[0m
[2K[2mInstalled [1m1 package[0m [2min 4ms[0m[0m git+https://github.com/stanfordnl[0m
 [31m-[39m [1mdspy[0m[2m==2.6.27[0m
 [32m+[39m [1mdspy[0m[2m==3.0.0b2 (from git+https://github.com/stanfordnlp/dspy/@6ead3da60268a54bf484fe6831a23d82a6d0f232)[0m


In [2]:
import dspy

In [None]:
dspy.__version__
lm = dspy.LM("groq/moonshotai/kimi-k2-instruct")# api_key = 'gsk_[REDACTED]')
#lm = dspy.LM("gpt-4.1")
dspy.configure(lm = lm)

say we want to do a chat that branches and grow like a tree all managed by the llm itself?
how might we do it?

The flow would be something like this:

user send prompt --> check for where best to attach it to the tree --> collect all messaged from that branch to the root --> sends that collect response --> append branch --> show user

In [5]:
import pydantic
class ScoreItem(pydantic.BaseModel):
    branchID: int
    reasoning: str
    rank: int
    relevance: float

class Score(dspy.Signature):
    """Score each conversation turn with rank, relevance, and comments."""
    user_prompt: str = dspy.InputField(desc = "user prompt for which we score conversation relevance")
    conversation_list: list[dict] = dspy.InputField(desc="List of conversation turns")
    scores_list: list[ScoreItem] = dspy.OutputField(desc="""List of ScoreItems, each with 'reasoning' (LLM's reasoning for that rank and score), 'rank' (1=highest), 'relevance' (0-1)
    conversation relevant to current user_prompt.
    return branchID to match to conversations.
    """)
    
scorer = dspy.Predict(Score)

In [6]:
conversations = [{"branchID": 1,"messages":[{"role":"system","content":"You are a concise physics tutor."},{"role":"user","content":"Why does a bicycle stay upright when moving but fall when still?"},{"role":"assistant","content":"Angular momentum of the wheels creates gyroscopic stability; at speed the wheels resist tilting. When stationary this stabilizing torque disappears, so gravity tips the bike over."}]},
{"branchID": 2,"messages":[{"role":"system","content":"You are a helpful culinary assistant."},{"role":"user","content":"Give me a quick vegan week-night pasta sauce."},{"role":"assistant","content":"Blend one ripe avocado, juice of ½ lemon, garlic clove, pinch of salt & pepper. Toss hot pasta with the creamy sauce and cherry tomatoes."}]},
{"branchID": 3,"messages":[{"role":"system","content":"You are a calm mindfulness coach."},{"role":"user","content":"I feel overwhelmed at work."},{"role":"assistant","content":"Pause. Close your eyes. Breathe in for 4, hold for 4, out for 6. Do this cycle three times, then choose one small task to complete. Momentum follows motion."},{"role":"user","content":"Thanks, I feel a bit lighter."}]},
{"branchID": 4,"messages":[{"role":"system","content":"You are a friendly travel advisor."},{"role":"user","content":"I have 48 h in Lisbon—what’s unmissable?"},{"role":"assistant","content":"Day 1: Alfama sunrise, tram 28, Pastéis de Belém. Day 2: LX Factory, ferry to Cacilhas for seafood sunset. Leave space to get lost."},{"role":"user","content":"Any budget tips?"},{"role":"assistant","content":"Walk instead of tuk-tuks, grab meals at Mercado da Ribeira, and many miradouros are free."}]},
{"branchID": 5,"messages":[{"role":"system","content":"You are a concise coding mentor."},{"role":"user","content":"Explain recursion in one sentence."},{"role":"assistant","content":"A function solves a problem by solving smaller versions of itself until hitting a trivial base case."},{"role":"user","content":"Show me a tiny Python example."},{"role":"assistant","content":"def fact(n): return 1 if n<2 else n*fact(n-1)\nprint(fact(5))  # 120"}]}]

In [7]:
scorer(user_prompt = "I am stressed", conversation_list = conversations)

Prediction(
    scores_list=[ScoreItem(branchID=3, reasoning="The conversation directly addresses the user's current emotional state of stress by offering mindfulness techniques and actionable advice for feeling overwhelmed, which aligns perfectly with the prompt 'I am stressed'.", rank=1, relevance=0.95), ScoreItem(branchID=1, reasoning="While this conversation is about physics and bicycles, it has no connection to the user's stated stress or emotional wellbeing, making it largely irrelevant to the current prompt.", rank=2, relevance=0.05), ScoreItem(branchID=2, reasoning="This is a cooking-related conversation about vegan pasta sauce, which has no bearing on the user's stress or emotional state, making it irrelevant to the current prompt.", rank=3, relevance=0.05), ScoreItem(branchID=4, reasoning="This travel advice conversation about Lisbon is completely unrelated to the user's expressed stress, offering no support or relevance to their emotional needs.", rank=4, relevance=0.05), Sc

we will iterativaly rank conversations in groups of 4 as in a tournaments

initially we will do both: 
1) check for all branches of 4 turns from root to 4th turn
2) also look at the last 4 turns of each branch tip.
3) present best candidate from both and check if the are part of the same path.
4) if so use that
5) else, explore that space to see where the conversation would be be attached. earlier is better then later, attach to root is ok!


In [None]:
from typing import Any
from pydantic.fields import FieldInfo

class PredictWithHistory(dspy.Module):
    def __init__(
        self,
        signature: str | type[dspy.Signature],
        **config: dict[str, Any],
    ):
        super().__init__()
        signature = dspy.ensure_signature(signature)
        extended_signature = signature.prepend(name="history",
                                               field=dspy.InputField(),
                                               type_=dspy.History)
        self.predict = dspy.Predict(extended_signature, **config)
        self.chat_history = []

    def forward(self, **kwargs):
        h = dspy.History(messages=self.chat_history)
        outputs = self.predict(history = h, **kwargs)
        self.chat_history.append({**kwargs, **outputs})
        return outputs


In [8]:
import dspy
import pydantic
from typing import List, Dict, Optional, Tuple, Any
from collections import defaultdict

class ConversationNode:
    """Represents a node in the conversation tree"""
    def __init__(self, branch_id: int, messages: List[Dict], parent: Optional['ConversationNode'] = None):
        self.branch_id = branch_id
        self.messages = messages
        self.parent = parent
        self.children: List['ConversationNode'] = []
        
    def get_path_to_root(self) -> List['ConversationNode']:
        """Get all nodes from this node to root"""
        path = []
        current = self
        while current:
            path.append(current)
            current = current.parent
        return list(reversed(path))
    
    def get_conversation_window(self, start_idx: int = 0, window_size: int = 4) -> List[Dict]:
        """Get a window of conversation from this branch"""
        return self.messages[start_idx:start_idx + window_size]
    
    def get_last_n_turns(self, n: int = 4) -> List[Dict]:
        """Get the last n turns of this branch"""
        return self.messages[-n:] if len(self.messages) >= n else self.messages

class AttachmentPoint(pydantic.BaseModel):
    branch_id: int
    attachment_index: int  # Where in the branch to attach (message index)
    reasoning: str
    score: float

class FindBestAttachment(dspy.Signature):
    """Find the best attachment point for a new conversation turn"""
    user_prompt: str = dspy.InputField(desc="The new user prompt to attach")
    candidate_points: List[Dict] = dspy.InputField(desc="List of potential attachment points with context")
    best_attachment: AttachmentPoint = dspy.OutputField(desc="The best attachment point with reasoning")

class ConversationTreeModule(dspy.Module):
    """DSPy module for managing tree-based conversation branching"""
    
    def __init__(self):
        super().__init__()
        self.scorer = dspy.Predict(Score)
        self.attachment_finder = dspy.Predict(FindBestAttachment)
        self.tree_root = None
        self.branches: Dict[int, ConversationNode] = {}
        
    def build_tree(self, conversations: List[Dict]):
        """Build the conversation tree from the provided conversations"""
        # For simplicity, assuming each conversation is a separate branch from root
        # In a real implementation, you'd parse parent-child relationships
        self.tree_root = ConversationNode(0, [], None)
        
        for conv in conversations:
            branch_id = conv["branchID"]
            messages = conv["messages"]
            node = ConversationNode(branch_id, messages, self.tree_root)
            self.tree_root.children.append(node)
            self.branches[branch_id] = node
    
    def get_4_turn_windows_from_root(self) -> List[Dict]:
        """Get all 4-turn windows starting from root for each branch"""
        windows = []
        for branch_id, node in self.branches.items():
            if len(node.messages) >= 4:
                window = node.get_conversation_window(0, 4)
                windows.append({
                    "branchID": branch_id,
                    "messages": window,
                    "window_type": "from_root"
                })
        return windows
    
    def get_last_4_turns_from_tips(self) -> List[Dict]:
        """Get the last 4 turns from each branch tip"""
        windows = []
        for branch_id, node in self.branches.items():
            window = node.get_last_n_turns(4)
            if window:  # Only include if there are messages
                windows.append({
                    "branchID": branch_id,
                    "messages": window,
                    "window_type": "from_tip"
                })
        return windows
    
    def check_same_path(self, branch_id1: int, branch_id2: int) -> bool:
        """Check if two branches are on the same path"""
        if branch_id1 == branch_id2:
            return True
        
        # Get paths to root for both branches
        node1 = self.branches.get(branch_id1)
        node2 = self.branches.get(branch_id2)
        
        if not node1 or not node2:
            return False
        
        path1 = set(n.branch_id for n in node1.get_path_to_root())
        path2 = set(n.branch_id for n in node2.get_path_to_root())
        
        # Check if one is ancestor of the other
        return branch_id1 in path2 or branch_id2 in path1
    
    def explore_attachment_space(self, user_prompt: str, branch_id: int) -> AttachmentPoint:
        """Explore where in the branch to attach the new conversation"""
        node = self.branches[branch_id]
        candidate_points = []
        
        # Check different attachment points in the branch
        for i in range(len(node.messages) + 1):
            context_before = node.messages[:i] if i > 0 else []
            context_after = node.messages[i:i+2] if i < len(node.messages) else []
            
            candidate_points.append({
                "branch_id": branch_id,
                "attachment_index": i,
                "context_before": context_before[-2:],  # Last 2 messages before
                "context_after": context_after,
                "position_description": f"After message {i}" if i > 0 else "At branch start"
            })
        
        # Also consider attaching to root
        candidate_points.append({
            "branch_id": 0,
            "attachment_index": 0,
            "context_before": [],
            "context_after": [],
            "position_description": "At conversation root (new branch)"
        })
        
        # Find best attachment point
        result = self.attachment_finder(
            user_prompt=user_prompt,
            candidate_points=candidate_points
        )
        
        return result.best_attachment
    
    def forward(self, user_prompt: str, conversations: List[Dict]) -> Dict[str, Any]:
        """Main forward pass to find best conversation branch and attachment point"""
        
        # Build tree if not already built
        if not self.branches:
            self.build_tree(conversations)
        
        # Step 1 & 2: Get windows from root and tips
        root_windows = self.get_4_turn_windows_from_root()
        tip_windows = self.get_last_4_turns_from_tips()
        
        # Score root windows
        root_scores = self.scorer(
            user_prompt=user_prompt,
            conversation_list=root_windows
        ).scores_list
        
        # Score tip windows  
        tip_scores = self.scorer(
            user_prompt=user_prompt,
            conversation_list=tip_windows
        ).scores_list
        
        # Step 3: Find best candidates from both
        best_root = max(root_scores, key=lambda x: x.relevance)
        best_tip = max(tip_scores, key=lambda x: x.relevance)
        
        # Step 4: Check if they're on the same path
        same_path = self.check_same_path(best_root.branchID, best_tip.branchID)
        
        if same_path:
            # Use the one with higher relevance
            best_branch_id = best_root.branchID if best_root.relevance >= best_tip.relevance else best_tip.branchID
            attachment_point = AttachmentPoint(
                branch_id=best_branch_id,
                attachment_index=len(self.branches[best_branch_id].messages),
                reasoning=f"Same path detected, using {'root' if best_root.relevance >= best_tip.relevance else 'tip'} window",
                score=max(best_root.relevance, best_tip.relevance)
            )
        else:
            # Step 5: Explore attachment space
            # Choose the branch with higher relevance to explore
            explore_branch_id = best_root.branchID if best_root.relevance >= best_tip.relevance else best_tip.branchID
            attachment_point = self.explore_attachment_space(user_prompt, explore_branch_id)
        
        # Compile the conversation context for the attachment
        if attachment_point.branch_id == 0:
            # New branch from root
            context_messages = []
        else:
            node = self.branches[attachment_point.branch_id]
            context_messages = node.messages[:attachment_point.attachment_index]
        
        return {
            "best_attachment": attachment_point,
            "context_messages": context_messages,
            "root_best": best_root,
            "tip_best": best_tip,
            "same_path": same_path
        }


In [9]:
# Initialize the module
tree_module = ConversationTreeModule()

# Your existing conversations
conversations = [
    {"branchID": 1,"messages":[{"role":"system","content":"You are a concise physics tutor."},{"role":"user","content":"Why does a bicycle stay upright when moving but fall when still?"},{"role":"assistant","content":"Angular momentum of the wheels creates gyroscopic stability; at speed the wheels resist tilting. When stationary this stabilizing torque disappears, so gravity tips the bike over."}]},
    {"branchID": 2,"messages":[{"role":"system","content":"You are a helpful culinary assistant."},{"role":"user","content":"Give me a quick vegan week-night pasta sauce."},{"role":"assistant","content":"Blend one ripe avocado, juice of ½ lemon, garlic clove, pinch of salt & pepper. Toss hot pasta with the creamy sauce and cherry tomatoes."}]},
    {"branchID": 3,"messages":[{"role":"system","content":"You are a calm mindfulness coach."},{"role":"user","content":"I feel overwhelmed at work."},{"role":"assistant","content":"Pause. Close your eyes. Breathe in for 4, hold for 4, out for 6. Do this cycle three times, then choose one small task to complete. Momentum follows motion."},{"role":"user","content":"Thanks, I feel a bit lighter."}]},
    {"branchID": 4,"messages":[{"role":"system","content":"You are a friendly travel advisor."},{"role":"user","content":"I have 48 h in Lisbon—what's unmissable?"},{"role":"assistant","content":"Day 1: Alfama sunrise, tram 28, Pastéis de Belém. Day 2: LX Factory, ferry to Cacilhas for seafood sunset. Leave space to get lost."},{"role":"user","content":"Any budget tips?"},{"role":"assistant","content":"Walk instead of tuk-tuks, grab meals at Mercado da Ribeira, and many miradouros are free."}]},
    {"branchID": 5,"messages":[{"role":"system","content":"You are a concise coding mentor."},{"role":"user","content":"Explain recursion in one sentence."},{"role":"assistant","content":"A function solves a problem by solving smaller versions of itself until hitting a trivial base case."},{"role":"user","content":"Show me a tiny Python example."},{"role":"assistant","content":"def fact(n): return 1 if n<2 else n*fact(n-1)\nprint(fact(5))  # 120"}]}
]

# Find best attachment for a new user prompt
result = tree_module(user_prompt="I am stressed", conversations=conversations)

print(f"Best attachment: Branch {result['best_attachment'].branch_id}, "
      f"Index {result['best_attachment'].attachment_index}")
print(f"Reasoning: {result['best_attachment'].reasoning}")
print(f"Same path: {result['same_path']}")

Best attachment: Branch 3, Index 4
Reasoning: Same path detected, using root window
Same path: True


In [10]:
import json
import sqlite3
from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, asdict
import uuid
from pathlib import Path
import pickle
from enum import Enum

class MessageRole(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

@dataclass
class Message:
    """Individual message in a conversation"""
    id: str
    role: MessageRole
    content: str
    timestamp: datetime
    metadata: Optional[Dict[str, Any]] = None
    
    def to_dict(self) -> Dict:
        data = asdict(self)
        data['timestamp'] = self.timestamp.isoformat()
        data['role'] = self.role.value
        return data
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'Message':
        data = data.copy()
        data['timestamp'] = datetime.fromisoformat(data['timestamp'])
        data['role'] = MessageRole(data['role'])
        return cls(**data)

@dataclass
class ConversationBranch:
    """A branch in the conversation tree"""
    id: str
    parent_id: Optional[str]  # None for root
    messages: List[Message]
    created_at: datetime
    updated_at: datetime
    metadata: Optional[Dict[str, Any]] = None
    attachment_point: Optional[int] = None  # Where this branch attaches to parent
    
    def to_dict(self) -> Dict:
        return {
            'id': self.id,
            'parent_id': self.parent_id,
            'messages': [msg.to_dict() for msg in self.messages],
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'metadata': self.metadata,
            'attachment_point': self.attachment_point
        }
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'ConversationBranch':
        data = data.copy()
        data['messages'] = [Message.from_dict(msg) for msg in data['messages']]
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        data['updated_at'] = datetime.fromisoformat(data['updated_at'])
        return cls(**data)

class ConversationTreeDB:
    """SQLite-based persistence for conversation trees"""
    
    def __init__(self, db_path: str = "conversation_tree.db"):
        self.db_path = db_path
        self._init_db()
    
    def _init_db(self):
        """Initialize database schema"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS branches (
                    id TEXT PRIMARY KEY,
                    parent_id TEXT,
                    attachment_point INTEGER,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (parent_id) REFERENCES branches(id)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS messages (
                    id TEXT PRIMARY KEY,
                    branch_id TEXT NOT NULL,
                    position INTEGER NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (branch_id) REFERENCES branches(id),
                    UNIQUE(branch_id, position)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tree_metadata (
                    tree_id TEXT PRIMARY KEY,
                    root_branch_id TEXT,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT
                )
            ''')
            
            # Indexes for performance
            conn.execute('CREATE INDEX IF NOT EXISTS idx_branch_parent ON branches(parent_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_message_branch ON messages(branch_id)')
    
    def save_branch(self, branch: ConversationBranch):
        """Save or update a branch"""
        with sqlite3.connect(self.db_path) as conn:
            # Save branch
            conn.execute('''
                INSERT OR REPLACE INTO branches 
                (id, parent_id, attachment_point, created_at, updated_at, metadata)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (
                branch.id,
                branch.parent_id,
                branch.attachment_point,
                branch.created_at.isoformat(),
                branch.updated_at.isoformat(),
                json.dumps(branch.metadata) if branch.metadata else None
            ))
            
            # Delete existing messages for this branch (for updates)
            conn.execute('DELETE FROM messages WHERE branch_id = ?', (branch.id,))
            
            # Save messages
            for position, msg in enumerate(branch.messages):
                conn.execute('''
                    INSERT INTO messages 
                    (id, branch_id, position, role, content, timestamp, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                ''', (
                    msg.id,
                    branch.id,
                    position,
                    msg.role.value,
                    msg.content,
                    msg.timestamp.isoformat(),
                    json.dumps(msg.metadata) if msg.metadata else None
                ))
    
    def load_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Load a branch by ID"""
        with sqlite3.connect(self.db_path) as conn:
            # Load branch info
            cursor = conn.execute('''
                SELECT parent_id, attachment_point, created_at, updated_at, metadata
                FROM branches WHERE id = ?
            ''', (branch_id,))
            
            row = cursor.fetchone()
            if not row:
                return None
            
            parent_id, attachment_point, created_at, updated_at, metadata = row
            
            # Load messages
            cursor = conn.execute('''
                SELECT id, role, content, timestamp, metadata
                FROM messages WHERE branch_id = ?
                ORDER BY position
            ''', (branch_id,))
            
            messages = []
            for msg_row in cursor:
                msg_id, role, content, timestamp, msg_metadata = msg_row
                messages.append(Message(
                    id=msg_id,
                    role=MessageRole(role),
                    content=content,
                    timestamp=datetime.fromisoformat(timestamp),
                    metadata=json.loads(msg_metadata) if msg_metadata else None
                ))
            
            return ConversationBranch(
                id=branch_id,
                parent_id=parent_id,
                messages=messages,
                created_at=datetime.fromisoformat(created_at),
                updated_at=datetime.fromisoformat(updated_at),
                metadata=json.loads(metadata) if metadata else None,
                attachment_point=attachment_point
            )
    
    def get_children(self, branch_id: str) -> List[str]:
        """Get all child branch IDs for a given branch"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                'SELECT id FROM branches WHERE parent_id = ?',
                (branch_id,)
            )
            return [row[0] for row in cursor]

class ConversationTreeManager:
    """Main manager for conversation tree operations"""
    
    def __init__(self, db_path: str = "conversation_tree.db", cache_size: int = 100):
        self.db = ConversationTreeDB(db_path)
        self.cache: Dict[str, ConversationBranch] = {}
        self.cache_size = cache_size
        self.tree_id = str(uuid.uuid4())
        self.root_branch_id: Optional[str] = None
    
    def create_root_branch(self) -> ConversationBranch:
        """Create the root branch of the conversation tree"""
        root = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=None,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            metadata={"is_root": True}
        )
        self.db.save_branch(root)
        self.root_branch_id = root.id
        self._cache_branch(root)
        return root
    
    def add_message(self, branch_id: str, role: MessageRole, content: str, 
                   metadata: Optional[Dict] = None) -> Message:
        """Add a message to an existing branch"""
        branch = self.get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")
        
        message = Message(
            id=str(uuid.uuid4()),
            role=role,
            content=content,
            timestamp=datetime.now(),
            metadata=metadata
        )
        
        branch.messages.append(message)
        branch.updated_at = datetime.now()
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return message
    
    def create_branch(self, parent_id: str, attachment_point: int,
                     initial_messages: Optional[List[Dict]] = None) -> ConversationBranch:
        """Create a new branch from a parent at a specific attachment point"""
        parent = self.get_branch(parent_id)
        if not parent:
            raise ValueError(f"Parent branch {parent_id} not found")
        
        if attachment_point > len(parent.messages):
            raise ValueError(f"Invalid attachment point {attachment_point}")
        
        # Create new branch
        branch = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=parent_id,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            attachment_point=attachment_point
        )
        
        # Add initial messages if provided
        if initial_messages:
            for msg_data in initial_messages:
                message = Message(
                    id=str(uuid.uuid4()),
                    role=MessageRole(msg_data['role']),
                    content=msg_data['content'],
                    timestamp=datetime.now(),
                    metadata=msg_data.get('metadata')
                )
                branch.messages.append(message)
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return branch
    
    def get_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Get a branch by ID (with caching)"""
        # Check cache first
        if branch_id in self.cache:
            return self.cache[branch_id]
        
        # Load from DB
        branch = self.db.load_branch(branch_id)
        if branch:
            self._cache_branch(branch)
        
        return branch
    
    def get_conversation_context(self, branch_id: str) -> List[Message]:
        """Get full conversation context from root to branch"""
        context = []
        current_id = branch_id
        
        # Walk up the tree collecting messages
        path = []
        while current_id:
            branch = self.get_branch(current_id)
            if not branch:
                break
            path.append(branch)
            current_id = branch.parent_id
        
        # Reverse to get root-to-leaf order
        path.reverse()
        
        # Collect messages respecting attachment points
        for i, branch in enumerate(path):
            if i == 0:  # Root branch
                context.extend(branch.messages)
            else:
                # Only include messages up to the attachment point from parent
                parent = path[i-1]
                attachment_point = branch.attachment_point or len(parent.messages)
                
                # Add messages from parent up to attachment point
                if i == 1:  # First child of root
                    context = parent.messages[:attachment_point]
                
                # Add this branch's messages
                context.extend(branch.messages)
        
        return context
    
    def reconstruct_tree(self) -> Dict[str, Any]:
        """Reconstruct the entire tree structure"""
        with sqlite3.connect(self.db.db_path) as conn:
            # Get all branches
            cursor = conn.execute('''
                SELECT id, parent_id, attachment_point, created_at, updated_at
                FROM branches
                ORDER BY created_at
            ''')
            
            tree = {"branches": {}, "root": None}
            
            for row in cursor:
                branch_id, parent_id, attachment_point, created_at, updated_at = row
                
                if parent_id is None:
                    tree["root"] = branch_id
                
                tree["branches"][branch_id] = {
                    "id": branch_id,
                    "parent_id": parent_id,
                    "attachment_point": attachment_point,
                    "children": self.db.get_children(branch_id),
                    "created_at": created_at,
                    "updated_at": updated_at
                }
            
            return tree
    
    def export_tree(self, filepath: str):
        """Export entire tree to JSON file"""
        tree_data = {"tree_id": self.tree_id, "branches": []}
        
        # Get all branch IDs
        with sqlite3.connect(self.db.db_path) as conn:
            cursor = conn.execute('SELECT id FROM branches')
            branch_ids = [row[0] for row in cursor]
        
        # Load and serialize each branch
        for branch_id in branch_ids:
            branch = self.get_branch(branch_id)
            if branch:
                tree_data["branches"].append(branch.to_dict())
        
        with open(filepath, 'w') as f:
            json.dump(tree_data, f, indent=2)
    
    def import_tree(self, filepath: str):
        """Import tree from JSON file"""
        with open(filepath, 'r') as f:
            tree_data = json.load(f)
        
        self.tree_id = tree_data.get("tree_id", str(uuid.uuid4()))
        
        for branch_data in tree_data["branches"]:
            branch = ConversationBranch.from_dict(branch_data)
            self.db.save_branch(branch)
            
            if branch.parent_id is None:
                self.root_branch_id = branch.id
    
    def _cache_branch(self, branch: ConversationBranch):
        """Add branch to cache with LRU eviction"""
        if len(self.cache) >= self.cache_size:
            # Simple eviction - remove oldest
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        
        self.cache[branch.id] = branch

# Integration with your DSPy module
class PersistentConversationTreeModule(ConversationTreeModule):
    """Extended module with persistence"""
    
    def __init__(self, db_path: str = "conversation_tree.db"):
        super().__init__()
        self.manager = ConversationTreeManager(db_path)
        
    def forward(self, user_prompt: str) -> Dict[str, Any]:
        """Forward with automatic persistence"""
        # Load current tree structure
        tree_structure = self.manager.reconstruct_tree()
        
        # Convert to format expected by parent class
        conversations = []
        for branch_id, branch_info in tree_structure["branches"].items():
            branch = self.manager.get_branch(branch_id)
            if branch and branch.messages:
                conversations.append({
                    "branchID": branch_id,
                    "messages": [{"role": msg.role.value, "content": msg.content} 
                                for msg in branch.messages]
                })
        
        # Get attachment decision
        result = super().forward(user_prompt, conversations)
        
        # Create new branch or extend existing one
        attachment = result["best_attachment"]
        
        if attachment.branch_id == "0" or attachment.branch_id == 0:
            # Create new branch from root
            if not self.manager.root_branch_id:
                self.manager.create_root_branch()
            
            new_branch = self.manager.create_branch(
                parent_id=self.manager.root_branch_id,
                attachment_point=0,
                initial_messages=[{"role": "user", "content": user_prompt}]
            )
            result["new_branch_id"] = new_branch.id
        else:
            # Add to existing branch or create sub-branch
            if attachment.attachment_index == len(self.manager.get_branch(str(attachment.branch_id)).messages):
                # Extend existing branch
                self.manager.add_message(
                    str(attachment.branch_id),
                    MessageRole.USER,
                    user_prompt
                )
                result["new_branch_id"] = str(attachment.branch_id)
            else:
                # Create new sub-branch
                new_branch = self.manager.create_branch(
                    parent_id=str(attachment.branch_id),
                    attachment_point=attachment.attachment_index,
                    initial_messages=[{"role": "user", "content": user_prompt}]
                )
                result["new_branch_id"] = new_branch.id
        
        return result

In [11]:
tree = PersistentConversationTreeModule("my_conversations.db")

# Process a user message
result = tree("Tell me about mindfulness techniques")

# The conversation is automatically persisted
print(f"Message attached to branch: {result['new_branch_id']}")

# Export for backup
tree.manager.export_tree("conversation_backup.json")

ValueError: max() iterable argument is empty

In [12]:
# Initialize the system
manager = ConversationTreeManager("chat_tree.db")

# Create root and add initial conversation
root = manager.create_root_branch()
manager.add_message(root.id, MessageRole.SYSTEM, "You are a helpful assistant")
manager.add_message(root.id, MessageRole.USER, "Hello!")
manager.add_message(root.id, MessageRole.ASSISTANT, "Hi! How can I help?")

# Branch off at a specific point
branch1 = manager.create_branch(
    parent_id=root.id,
    attachment_point=2,  # After "Hello!"
    initial_messages=[
        {"role": "user", "content": "Tell me about Python"}
    ]
)

# Get full context for a branch
context = manager.get_conversation_context(branch1.id)

# Export/Import
manager.export_tree("backup.json")
manager.import_tree("backup.json")

In [13]:
# Use the persistent version
tree_module = PersistentConversationTreeModule()

# Process messages - automatically persisted
result = tree_module("I need help with stress")

# Access the conversation context
branch_id = result["new_branch_id"]
context = tree_module.manager.get_conversation_context(branch_id)

# Continue the conversation
tree_module.manager.add_message(
    branch_id, 
    MessageRole.ASSISTANT,
    "I understand you're feeling stressed. Let's work through some techniques..."
)

ValueError: max() iterable argument is empty

In [16]:
"""
Complete Conversation Tree System with DSPy Integration
Single file version combining tree logic and persistence
"""

import dspy
import pydantic
import json
import sqlite3
from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, asdict
import uuid
from collections import defaultdict
from enum import Enum

# ============= DSPy Signatures and Models =============

class ScoreItem(pydantic.BaseModel):
    branchID: int
    reasoning: str
    rank: int
    relevance: float

class Score(dspy.Signature):
    """Score each conversation turn with rank, relevance, and comments."""
    user_prompt: str = dspy.InputField(desc="user prompt for which we score conversation relevance")
    conversation_list: list[dict] = dspy.InputField(desc="List of conversation turns")
    scores_list: list[ScoreItem] = dspy.OutputField(desc="""List of ScoreItems, each with 'reasoning' (LLM's reasoning for that rank and score), 'rank' (1=highest), 'relevance' (0-1)
    conversation relevant to current user_prompt.
    return branchID to match to conversations.
    """)

class AttachmentPoint(pydantic.BaseModel):
    branch_id: int
    attachment_index: int
    reasoning: str
    score: float

class FindBestAttachment(dspy.Signature):
    """Find the best attachment point for a new conversation turn"""
    user_prompt: str = dspy.InputField(desc="The new user prompt to attach")
    candidate_points: List[Dict] = dspy.InputField(desc="List of potential attachment points with context")
    best_attachment: AttachmentPoint = dspy.OutputField(desc="The best attachment point with reasoning")

# ============= Tree Structure Classes =============

class ConversationNode:
    """Represents a node in the conversation tree"""
    def __init__(self, branch_id: int, messages: List[Dict], parent: Optional['ConversationNode'] = None):
        self.branch_id = branch_id
        self.messages = messages
        self.parent = parent
        self.children: List['ConversationNode'] = []
        
    def get_path_to_root(self) -> List['ConversationNode']:
        """Get all nodes from this node to root"""
        path = []
        current = self
        while current:
            path.append(current)
            current = current.parent
        return list(reversed(path))
    
    def get_conversation_window(self, start_idx: int = 0, window_size: int = 4) -> List[Dict]:
        """Get a window of conversation from this branch"""
        return self.messages[start_idx:start_idx + window_size]
    
    def get_last_n_turns(self, n: int = 4) -> List[Dict]:
        """Get the last n turns of this branch"""
        return self.messages[-n:] if len(self.messages) >= n else self.messages

# ============= Persistence Classes =============

class MessageRole(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

@dataclass
class Message:
    """Individual message in a conversation"""
    id: str
    role: MessageRole
    content: str
    timestamp: datetime
    metadata: Optional[Dict[str, Any]] = None
    
    def to_dict(self) -> Dict:
        data = asdict(self)
        data['timestamp'] = self.timestamp.isoformat()
        data['role'] = self.role.value
        return data
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'Message':
        data = data.copy()
        data['timestamp'] = datetime.fromisoformat(data['timestamp'])
        data['role'] = MessageRole(data['role'])
        return cls(**data)

@dataclass
class ConversationBranch:
    """A branch in the conversation tree"""
    id: str
    parent_id: Optional[str]
    messages: List[Message]
    created_at: datetime
    updated_at: datetime
    metadata: Optional[Dict[str, Any]] = None
    attachment_point: Optional[int] = None
    
    def to_dict(self) -> Dict:
        return {
            'id': self.id,
            'parent_id': self.parent_id,
            'messages': [msg.to_dict() for msg in self.messages],
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'metadata': self.metadata,
            'attachment_point': self.attachment_point
        }
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'ConversationBranch':
        data = data.copy()
        data['messages'] = [Message.from_dict(msg) for msg in data['messages']]
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        data['updated_at'] = datetime.fromisoformat(data['updated_at'])
        return cls(**data)

# ============= Main DSPy Module =============

class ConversationTreeModule(dspy.Module):
    """DSPy module for managing tree-based conversation branching"""
    
    def __init__(self):
        super().__init__()
        self.scorer = dspy.Predict(Score)
        self.attachment_finder = dspy.Predict(FindBestAttachment)
        self.tree_root = None
        self.branches: Dict[int, ConversationNode] = {}
        
    def build_tree(self, conversations: List[Dict]):
        """Build the conversation tree from the provided conversations"""
        self.tree_root = ConversationNode(0, [], None)
        
        for conv in conversations:
            branch_id = conv["branchID"]
            messages = conv["messages"]
            node = ConversationNode(branch_id, messages, self.tree_root)
            self.tree_root.children.append(node)
            self.branches[branch_id] = node
    
    def get_4_turn_windows_from_root(self) -> List[Dict]:
        """Get all 4-turn windows starting from root for each branch"""
        windows = []
        for branch_id, node in self.branches.items():
            if len(node.messages) >= 4:
                window = node.get_conversation_window(0, 4)
                windows.append({
                    "branchID": branch_id,
                    "messages": window,
                    "window_type": "from_root"
                })
        return windows
    
    def get_last_4_turns_from_tips(self) -> List[Dict]:
        """Get the last 4 turns from each branch tip"""
        windows = []
        for branch_id, node in self.branches.items():
            window = node.get_last_n_turns(4)
            if window:
                windows.append({
                    "branchID": branch_id,
                    "messages": window,
                    "window_type": "from_tip"
                })
        return windows
    
    def check_same_path(self, branch_id1: int, branch_id2: int) -> bool:
        """Check if two branches are on the same path"""
        if branch_id1 == branch_id2:
            return True
        
        node1 = self.branches.get(branch_id1)
        node2 = self.branches.get(branch_id2)
        
        if not node1 or not node2:
            return False
        
        path1 = set(n.branch_id for n in node1.get_path_to_root())
        path2 = set(n.branch_id for n in node2.get_path_to_root())
        
        return branch_id1 in path2 or branch_id2 in path1
    
    def explore_attachment_space(self, user_prompt: str, branch_id: int) -> AttachmentPoint:
        """Explore where in the branch to attach the new conversation"""
        node = self.branches[branch_id]
        candidate_points = []
        
        for i in range(len(node.messages) + 1):
            context_before = node.messages[:i] if i > 0 else []
            context_after = node.messages[i:i+2] if i < len(node.messages) else []
            
            candidate_points.append({
                "branch_id": branch_id,
                "attachment_index": i,
                "context_before": context_before[-2:],
                "context_after": context_after,
                "position_description": f"After message {i}" if i > 0 else "At branch start"
            })
        
        candidate_points.append({
            "branch_id": 0,
            "attachment_index": 0,
            "context_before": [],
            "context_after": [],
            "position_description": "At conversation root (new branch)"
        })
        
        result = self.attachment_finder(
            user_prompt=user_prompt,
            candidate_points=candidate_points
        )
        
        return result.best_attachment
    
    def forward(self, user_prompt: str, conversations: List[Dict]) -> Dict[str, Any]:
        """Main forward pass to find best conversation branch and attachment point"""
        
        if not self.branches:
            self.build_tree(conversations)
        
        root_windows = self.get_4_turn_windows_from_root()
        tip_windows = self.get_last_4_turns_from_tips()
        
        if not root_windows and not tip_windows:
            return {
                "best_attachment": AttachmentPoint(
                    branch_id=0,
                    attachment_index=0,
                    reasoning="No existing conversations, starting new branch from root",
                    score=1.0
                ),
                "context_messages": [],
                "root_best": None,
                "tip_best": None,
                "same_path": True
            }
        
        root_scores = []
        tip_scores = []
        
        if root_windows:
            root_scores = self.scorer(
                user_prompt=user_prompt,
                conversation_list=root_windows
            ).scores_list
        
        if tip_windows:
            tip_scores = self.scorer(
                user_prompt=user_prompt,
                conversation_list=tip_windows
            ).scores_list
        
        best_root = max(root_scores, key=lambda x: x.relevance) if root_scores else None
        best_tip = max(tip_scores, key=lambda x: x.relevance) if tip_scores else None
        
        if not best_root and best_tip:
            best_branch_id = best_tip.branchID
            attachment_point = AttachmentPoint(
                branch_id=best_branch_id,
                attachment_index=len(self.branches[best_branch_id].messages),
                reasoning="Only tip windows available, using best tip match",
                score=best_tip.relevance
            )
        elif best_root and not best_tip:
            best_branch_id = best_root.branchID
            attachment_point = AttachmentPoint(
                branch_id=best_branch_id,
                attachment_index=len(self.branches[best_branch_id].messages),
                reasoning="Only root windows available, using best root match",
                score=best_root.relevance
            )
        else:
            same_path = self.check_same_path(best_root.branchID, best_tip.branchID)
            
            if same_path:
                best_branch_id = best_root.branchID if best_root.relevance >= best_tip.relevance else best_tip.branchID
                attachment_point = AttachmentPoint(
                    branch_id=best_branch_id,
                    attachment_index=len(self.branches[best_branch_id].messages),
                    reasoning=f"Same path detected, using {'root' if best_root.relevance >= best_tip.relevance else 'tip'} window",
                    score=max(best_root.relevance, best_tip.relevance)
                )
            else:
                explore_branch_id = best_root.branchID if best_root.relevance >= best_tip.relevance else best_tip.branchID
                attachment_point = self.explore_attachment_space(user_prompt, explore_branch_id)
        
        if attachment_point.branch_id == 0:
            context_messages = []
        else:
            node = self.branches[attachment_point.branch_id]
            context_messages = node.messages[:attachment_point.attachment_index]
        
        return {
            "best_attachment": attachment_point,
            "context_messages": context_messages,
            "root_best": best_root,
            "tip_best": best_tip,
            "same_path": same_path if best_root and best_tip else True
        }

# ============= Database Layer =============

class ConversationTreeDB:
    """SQLite-based persistence for conversation trees"""
    
    def __init__(self, db_path: str = "conversation_tree.db"):
        self.db_path = db_path
        self._init_db()
    
    def _init_db(self):
        """Initialize database schema"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS branches (
                    id TEXT PRIMARY KEY,
                    parent_id TEXT,
                    attachment_point INTEGER,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (parent_id) REFERENCES branches(id)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS messages (
                    id TEXT PRIMARY KEY,
                    branch_id TEXT NOT NULL,
                    position INTEGER NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (branch_id) REFERENCES branches(id),
                    UNIQUE(branch_id, position)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tree_metadata (
                    tree_id TEXT PRIMARY KEY,
                    root_branch_id TEXT,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT
                )
            ''')
            
            conn.execute('CREATE INDEX IF NOT EXISTS idx_branch_parent ON branches(parent_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_message_branch ON messages(branch_id)')
    
    def save_branch(self, branch: ConversationBranch):
        """Save or update a branch"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                INSERT OR REPLACE INTO branches 
                (id, parent_id, attachment_point, created_at, updated_at, metadata)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (
                branch.id,
                branch.parent_id,
                branch.attachment_point,
                branch.created_at.isoformat(),
                branch.updated_at.isoformat(),
                json.dumps(branch.metadata) if branch.metadata else None
            ))
            
            conn.execute('DELETE FROM messages WHERE branch_id = ?', (branch.id,))
            
            for position, msg in enumerate(branch.messages):
                conn.execute('''
                    INSERT INTO messages 
                    (id, branch_id, position, role, content, timestamp, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                ''', (
                    msg.id,
                    branch.id,
                    position,
                    msg.role.value,
                    msg.content,
                    msg.timestamp.isoformat(),
                    json.dumps(msg.metadata) if msg.metadata else None
                ))
    
    def load_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Load a branch by ID"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute('''
                SELECT parent_id, attachment_point, created_at, updated_at, metadata
                FROM branches WHERE id = ?
            ''', (branch_id,))
            
            row = cursor.fetchone()
            if not row:
                return None
            
            parent_id, attachment_point, created_at, updated_at, metadata = row
            
            cursor = conn.execute('''
                SELECT id, role, content, timestamp, metadata
                FROM messages WHERE branch_id = ?
                ORDER BY position
            ''', (branch_id,))
            
            messages = []
            for msg_row in cursor:
                msg_id, role, content, timestamp, msg_metadata = msg_row
                messages.append(Message(
                    id=msg_id,
                    role=MessageRole(role),
                    content=content,
                    timestamp=datetime.fromisoformat(timestamp),
                    metadata=json.loads(msg_metadata) if msg_metadata else None
                ))
            
            return ConversationBranch(
                id=branch_id,
                parent_id=parent_id,
                messages=messages,
                created_at=datetime.fromisoformat(created_at),
                updated_at=datetime.fromisoformat(updated_at),
                metadata=json.loads(metadata) if metadata else None,
                attachment_point=attachment_point
            )
    
    def get_children(self, branch_id: str) -> List[str]:
        """Get all child branch IDs for a given branch"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                'SELECT id FROM branches WHERE parent_id = ?',
                (branch_id,)
            )
            return [row[0] for row in cursor]

# ============= Manager Class =============

class ConversationTreeManager:
    """Main manager for conversation tree operations"""
    
    def __init__(self, db_path: str = "conversation_tree.db", cache_size: int = 100):
        self.db = ConversationTreeDB(db_path)
        self.cache: Dict[str, ConversationBranch] = {}
        self.cache_size = cache_size
        self.tree_id = str(uuid.uuid4())
        self.root_branch_id: Optional[str] = None
    
    def create_root_branch(self) -> ConversationBranch:
        """Create the root branch of the conversation tree"""
        root = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=None,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            metadata={"is_root": True}
        )
        self.db.save_branch(root)
        self.root_branch_id = root.id
        self._cache_branch(root)
        return root
    
    def add_message(self, branch_id: str, role: MessageRole, content: str, 
                   metadata: Optional[Dict] = None) -> Message:
        """Add a message to an existing branch"""
        branch = self.get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")
        
        message = Message(
            id=str(uuid.uuid4()),
            role=role,
            content=content,
            timestamp=datetime.now(),
            metadata=metadata
        )
        
        branch.messages.append(message)
        branch.updated_at = datetime.now()
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return message
    
    def create_branch(self, parent_id: str, attachment_point: int,
                     initial_messages: Optional[List[Dict]] = None) -> ConversationBranch:
        """Create a new branch from a parent at a specific attachment point"""
        parent = self.get_branch(parent_id)
        if not parent:
            raise ValueError(f"Parent branch {parent_id} not found")
        
        if attachment_point > len(parent.messages):
            raise ValueError(f"Invalid attachment point {attachment_point}")
        
        branch = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=parent_id,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            attachment_point=attachment_point
        )
        
        if initial_messages:
            for msg_data in initial_messages:
                message = Message(
                    id=str(uuid.uuid4()),
                    role=MessageRole(msg_data['role']),
                    content=msg_data['content'],
                    timestamp=datetime.now(),
                    metadata=msg_data.get('metadata')
                )
                branch.messages.append(message)
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return branch
    
    def get_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Get a branch by ID (with caching)"""
        if branch_id in self.cache:
            return self.cache[branch_id]
        
        branch = self.db.load_branch(branch_id)
        if branch:
            self._cache_branch(branch)
        
        return branch
    
    def get_conversation_context(self, branch_id: str) -> List[Message]:
        """Get full conversation context from root to branch"""
        context = []
        current_id = branch_id
        
        path = []
        while current_id:
            branch = self.get_branch(current_id)
            if not branch:
                break
            path.append(branch)
            current_id = branch.parent_id
        
        path.reverse()
        
        for i, branch in enumerate(path):
            if i == 0:
                context.extend(branch.messages)
            else:
                parent = path[i-1]
                attachment_point = branch.attachment_point or len(parent.messages)
                
                if i == 1:
                    context = parent.messages[:attachment_point]
                
                context.extend(branch.messages)
        
        return context
    
    def reconstruct_tree(self) -> Dict[str, Any]:
        """Reconstruct the entire tree structure"""
        with sqlite3.connect(self.db.db_path) as conn:
            cursor = conn.execute('''
                SELECT id, parent_id, attachment_point, created_at, updated_at
                FROM branches
                ORDER BY created_at
            ''')
            
            tree = {"branches": {}, "root": None}
            
            for row in cursor:
                branch_id, parent_id, attachment_point, created_at, updated_at = row
                
                if parent_id is None:
                    tree["root"] = branch_id
                
                tree["branches"][branch_id] = {
                    "id": branch_id,
                    "parent_id": parent_id,
                    "attachment_point": attachment_point,
                    "children": self.db.get_children(branch_id),
                    "created_at": created_at,
                    "updated_at": updated_at
                }
            
            return tree
    
    def export_tree(self, filepath: str):
        """Export entire tree to JSON file"""
        tree_data = {"tree_id": self.tree_id, "branches": []}
        
        with sqlite3.connect(self.db.db_path) as conn:
            cursor = conn.execute('SELECT id FROM branches')
            branch_ids = [row[0] for row in cursor]
        
        for branch_id in branch_ids:
            branch = self.get_branch(branch_id)
            if branch:
                tree_data["branches"].append(branch.to_dict())
        
        with open(filepath, 'w') as f:
            json.dump(tree_data, f, indent=2)
    
    def import_tree(self, filepath: str):
        """Import tree from JSON file"""
        with open(filepath, 'r') as f:
            tree_data = json.load(f)
        
        self.tree_id = tree_data.get("tree_id", str(uuid.uuid4()))
        
        for branch_data in tree_data["branches"]:
            branch = ConversationBranch.from_dict(branch_data)
            self.db.save_branch(branch)
            
            if branch.parent_id is None:
                self.root_branch_id = branch.id
    
    def _cache_branch(self, branch: ConversationBranch):
        """Add branch to cache with LRU eviction"""
        if len(self.cache) >= self.cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        
        self.cache[branch.id] = branch

# ============= Persistent DSPy Module =============

class PersistentConversationTreeModule(dspy.Module):
    """Extended module with persistence"""
    
    def __init__(self, db_path: str = "conversation_tree.db"):
        super().__init__()
        self.manager = ConversationTreeManager(db_path)
        self.base_module = ConversationTreeModule()
        
        tree_structure = self.manager.reconstruct_tree()
        if not tree_structure["root"]:
            root = self.manager.create_root_branch()
            self.manager.add_message(
                root.id, 
                MessageRole.SYSTEM, 
                "You are a helpful AI assistant that maintains context across branching conversations."
            )
        
    def forward(self, user_prompt: str) -> Dict[str, Any]:
        """Forward with automatic persistence"""
        tree_structure = self.manager.reconstruct_tree()
        
        conversations = []
        for branch_id, branch_info in tree_structure["branches"].items():
            branch = self.manager.get_branch(branch_id)
            if branch and branch.messages:
                conversations.append({
                    "branchID": branch_id,
                    "messages": [{"role": msg.role.value, "content": msg.content} 
                                for msg in branch.messages]
                })
        
        if not conversations or all(len(c["messages"]) <= 1 for c in conversations):
            root_id = tree_structure["root"]
            if not root_id:
                root = self.manager.create_root_branch()
                root_id = root.id
            
            new_branch = self.manager.create_branch(
                parent_id=root_id,
                attachment_point=1,
                initial_messages=[{"role": "user", "content": user_prompt}]
            )
            
            return {
                "best_attachment": AttachmentPoint(
                    branch_id=new_branch.id,
                    attachment_index=1,
                    reasoning="First conversation in tree",
                    score=1.0
                ),
                "context_messages": self.manager.get_conversation_context(new_branch.id),
                "new_branch_id": new_branch.id,
                "root_best": None,
                "tip_best": None,
                "same_path": True
            }
        
        result = self.base_module.forward(user_prompt, conversations)
        
        attachment = result["best_attachment"]
        
        if attachment.branch_id == "0" or attachment.branch_id == 0:
            root_id = tree_structure["root"]
            new_branch = self.manager.create_branch(
                parent_id=root_id,
                attachment_point=1,
                initial_messages=[{"role": "user", "content": user_prompt}]
            )
            result["new_branch_id"] = new_branch.id
        else:
            existing_branch = self.manager.get_branch(str(attachment.branch_id))
            if attachment.attachment_index == len(existing_branch.messages):
                self.manager.add_message(
                    str(attachment.branch_id),
                    MessageRole.USER,
                    user_prompt
                )
                result["new_branch_id"] = str(attachment.branch_id)
            else:
                new_branch = self.manager.create_branch(
                    parent_id=str(attachment.branch_id),
                    attachment_point=attachment.attachment_index,
                    initial_messages=[{"role": "user", "content": user_prompt}]
                )
                result["new_branch_id"] = new_branch.id
        
        result["context_messages"] = self.manager.get_conversation_context(result["new_branch_id"])
        
        return result


In [17]:
# Configure DSPy with your LM
# lm = dspy.LM("groq/llama-70b-8192")  # Or your preferred model
# dspy.configure(lm=lm)

# Initialize the persistent conversation tree
tree = PersistentConversationTreeModule("my_conversations.db")

# Process first user message
result = tree("Tell me about mindfulness techniques")
print(f"First message attached to branch: {result['new_branch_id']}")
print(f"Reasoning: {result['best_attachment'].reasoning}")

ValidationError: 1 validation error for AttachmentPoint
branch_id
  Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='f92795e4-650a-4392-bbd9-629dd945b88b', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/int_parsing

In [None]:
# Add assistant response
tree.manager.add_message(
    result['new_branch_id'],
    MessageRole.ASSISTANT,
    "Mindfulness techniques include meditation, deep breathing, body scans, and present-moment awareness. Would you like me to explain any of these in detail?"
)

# Continue conversation
result2 = tree("Yes, tell me about deep breathing")
print(f"\nSecond message attached to branch: {result2['new_branch_id']}")

# Start a different topic
result3 = tree("I need help with Python programming")
print(f"\nNew topic attached to branch: {result3['new_branch_id']}")

# Show tree structure
tree_structure = tree.manager.reconstruct_tree()
print(f"\nTree structure:")
print(f"- Root: {tree_structure['root']}")
print(f"- Total branches: {len(tree_structure['branches'])}")

# Export tree
tree.manager.export_tree("conversation_backup.json")
print("\nTree exported to conversation_backup.json")

In [18]:
# Initialize (creates DB if needed)
tree = PersistentConversationTreeModule("chat.db")

# First message creates initial branch
result = tree("Tell me about stress management")

ValidationError: 1 validation error for AttachmentPoint
branch_id
  Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='468ec9ab-ded2-435e-9080-dcb24a4682ab', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/int_parsing

In [None]:
# Continue conversation
tree.manager.add_message(
    result['new_branch_id'],
    MessageRole.ASSISTANT,
    "I can help with stress management techniques..."
)

# New user message finds best attachment point
result2 = tree("What about breathing exercises?")

# Different topic creates new branch
result3 = tree("Help me with Python coding")

# View full context
context = tree.manager.get_conversation_context(result['new_branch_id'])

In [19]:
"""
Simple usage example for the Conversation Tree System

Save the complete_conversation_tree.py file first, then run this example.
"""

import dspy

# If you haven't saved the complete module yet, uncomment this line:
# exec(open('complete_conversation_tree.py').read())

# Or import it normally if saved as a file:
from complete_conversation_tree import PersistentConversationTreeModule, MessageRole

ModuleNotFoundError: No module named 'complete_conversation_tree'

In [23]:
"""
Complete Conversation Tree System with DSPy Integration
Single file version combining tree logic and persistence
"""

import dspy
import pydantic
import json
import sqlite3
from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple, Union
from dataclasses import dataclass, asdict
import uuid
from collections import defaultdict
from enum import Enum

# ============= DSPy Signatures and Models =============

class ScoreItem(pydantic.BaseModel):
    branchID: Union[int, str]  # Support both int and str
    reasoning: str
    rank: int
    relevance: float

class Score(dspy.Signature):
    """Score each conversation turn with rank, relevance, and comments."""
    user_prompt: str = dspy.InputField(desc="user prompt for which we score conversation relevance")
    conversation_list: list[dict] = dspy.InputField(desc="List of conversation turns")
    scores_list: list[ScoreItem] = dspy.OutputField(desc="""List of ScoreItems, each with 'reasoning' (LLM's reasoning for that rank and score), 'rank' (1=highest), 'relevance' (0-1)
    conversation relevant to current user_prompt.
    return branchID to match to conversations.
    """)

class AttachmentPoint(pydantic.BaseModel):
    branch_id: Union[int, str]  # Support both int (for tree module) and str (for persistence)
    attachment_index: int
    reasoning: str
    score: float

class FindBestAttachment(dspy.Signature):
    """Find the best attachment point for a new conversation turn"""
    user_prompt: str = dspy.InputField(desc="The new user prompt to attach")
    candidate_points: List[Dict] = dspy.InputField(desc="List of potential attachment points with context")
    best_attachment: AttachmentPoint = dspy.OutputField(desc="The best attachment point with reasoning")

# ============= Tree Structure Classes =============

class ConversationNode:
    """Represents a node in the conversation tree"""
    def __init__(self, branch_id: Union[int, str], messages: List[Dict], parent: Optional['ConversationNode'] = None):
        self.branch_id = branch_id
        self.messages = messages
        self.parent = parent
        self.children: List['ConversationNode'] = []
        
    def get_path_to_root(self) -> List['ConversationNode']:
        """Get all nodes from this node to root"""
        path = []
        current = self
        while current:
            path.append(current)
            current = current.parent
        return list(reversed(path))
    
    def get_conversation_window(self, start_idx: int = 0, window_size: int = 4) -> List[Dict]:
        """Get a window of conversation from this branch"""
        return self.messages[start_idx:start_idx + window_size]
    
    def get_last_n_turns(self, n: int = 4) -> List[Dict]:
        """Get the last n turns of this branch"""
        return self.messages[-n:] if len(self.messages) >= n else self.messages

# ============= Persistence Classes =============

class MessageRole(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

@dataclass
class Message:
    """Individual message in a conversation"""
    id: str
    role: MessageRole
    content: str
    timestamp: datetime
    metadata: Optional[Dict[str, Any]] = None
    
    def to_dict(self) -> Dict:
        data = asdict(self)
        data['timestamp'] = self.timestamp.isoformat()
        data['role'] = self.role.value
        return data
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'Message':
        data = data.copy()
        data['timestamp'] = datetime.fromisoformat(data['timestamp'])
        data['role'] = MessageRole(data['role'])
        return cls(**data)

@dataclass
class ConversationBranch:
    """A branch in the conversation tree"""
    id: str
    parent_id: Optional[str]
    messages: List[Message]
    created_at: datetime
    updated_at: datetime
    metadata: Optional[Dict[str, Any]] = None
    attachment_point: Optional[int] = None
    
    def to_dict(self) -> Dict:
        return {
            'id': self.id,
            'parent_id': self.parent_id,
            'messages': [msg.to_dict() for msg in self.messages],
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'metadata': self.metadata,
            'attachment_point': self.attachment_point
        }
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'ConversationBranch':
        data = data.copy()
        data['messages'] = [Message.from_dict(msg) for msg in data['messages']]
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        data['updated_at'] = datetime.fromisoformat(data['updated_at'])
        return cls(**data)

# ============= Main DSPy Module =============

class ConversationTreeModule(dspy.Module):
    """DSPy module for managing tree-based conversation branching"""
    
    def __init__(self):
        super().__init__()
        self.scorer = dspy.Predict(Score)
        self.attachment_finder = dspy.Predict(FindBestAttachment)
        self.tree_root = None
        self.branches: Dict[Union[int, str], ConversationNode] = {}
        
    def build_tree(self, conversations: List[Dict]):
        """Build the conversation tree from the provided conversations"""
        self.tree_root = ConversationNode(0, [], None)
        
        for conv in conversations:
            branch_id = conv["branchID"]
            messages = conv["messages"]
            node = ConversationNode(branch_id, messages, self.tree_root)
            self.tree_root.children.append(node)
            self.branches[branch_id] = node
    
    def get_4_turn_windows_from_root(self) -> List[Dict]:
        """Get all 4-turn windows starting from root for each branch"""
        windows = []
        for branch_id, node in self.branches.items():
            if len(node.messages) >= 4:
                window = node.get_conversation_window(0, 4)
                windows.append({
                    "branchID": branch_id,
                    "messages": window,
                    "window_type": "from_root"
                })
        return windows
    
    def get_last_4_turns_from_tips(self) -> List[Dict]:
        """Get the last 4 turns from each branch tip"""
        windows = []
        for branch_id, node in self.branches.items():
            window = node.get_last_n_turns(4)
            if window:
                windows.append({
                    "branchID": branch_id,
                    "messages": window,
                    "window_type": "from_tip"
                })
        return windows
    
    def check_same_path(self, branch_id1: Union[int, str], branch_id2: Union[int, str]) -> bool:
        """Check if two branches are on the same path"""
        if branch_id1 == branch_id2:
            return True
        
        node1 = self.branches.get(branch_id1)
        node2 = self.branches.get(branch_id2)
        
        if not node1 or not node2:
            return False
        
        path1 = set(n.branch_id for n in node1.get_path_to_root())
        path2 = set(n.branch_id for n in node2.get_path_to_root())
        
        return branch_id1 in path2 or branch_id2 in path1
    
    def explore_attachment_space(self, user_prompt: str, branch_id: Union[int, str]) -> AttachmentPoint:
        """Explore where in the branch to attach the new conversation"""
        node = self.branches[branch_id]
        candidate_points = []
        
        for i in range(len(node.messages) + 1):
            context_before = node.messages[:i] if i > 0 else []
            context_after = node.messages[i:i+2] if i < len(node.messages) else []
            
            candidate_points.append({
                "branch_id": branch_id,
                "attachment_index": i,
                "context_before": context_before[-2:],
                "context_after": context_after,
                "position_description": f"After message {i}" if i > 0 else "At branch start"
            })
        
        candidate_points.append({
            "branch_id": 0,
            "attachment_index": 0,
            "context_before": [],
            "context_after": [],
            "position_description": "At conversation root (new branch)"
        })
        
        result = self.attachment_finder(
            user_prompt=user_prompt,
            candidate_points=candidate_points
        )
        
        return result.best_attachment
    
    def forward(self, user_prompt: str, conversations: List[Dict]) -> Dict[str, Any]:
        """Main forward pass to find best conversation branch and attachment point"""
        
        if not self.branches:
            self.build_tree(conversations)
        
        root_windows = self.get_4_turn_windows_from_root()
        tip_windows = self.get_last_4_turns_from_tips()
        
        if not root_windows and not tip_windows:
            return {
                "best_attachment": AttachmentPoint(
                    branch_id=0,
                    attachment_index=0,
                    reasoning="No existing conversations, starting new branch from root",
                    score=1.0
                ),
                "context_messages": [],
                "root_best": None,
                "tip_best": None,
                "same_path": True
            }
        
        root_scores = []
        tip_scores = []
        
        if root_windows:
            root_scores = self.scorer(
                user_prompt=user_prompt,
                conversation_list=root_windows
            ).scores_list
        
        if tip_windows:
            tip_scores = self.scorer(
                user_prompt=user_prompt,
                conversation_list=tip_windows
            ).scores_list
        
        best_root = max(root_scores, key=lambda x: x.relevance) if root_scores else None
        best_tip = max(tip_scores, key=lambda x: x.relevance) if tip_scores else None
        
        if not best_root and best_tip:
            best_branch_id = best_tip.branchID
            attachment_point = AttachmentPoint(
                branch_id=best_branch_id,
                attachment_index=len(self.branches[best_branch_id].messages),
                reasoning="Only tip windows available, using best tip match",
                score=best_tip.relevance
            )
        elif best_root and not best_tip:
            best_branch_id = best_root.branchID
            attachment_point = AttachmentPoint(
                branch_id=best_branch_id,
                attachment_index=len(self.branches[best_branch_id].messages),
                reasoning="Only root windows available, using best root match",
                score=best_root.relevance
            )
        else:
            same_path = self.check_same_path(best_root.branchID, best_tip.branchID)
            
            if same_path:
                best_branch_id = best_root.branchID if best_root.relevance >= best_tip.relevance else best_tip.branchID
                attachment_point = AttachmentPoint(
                    branch_id=best_branch_id,
                    attachment_index=len(self.branches[best_branch_id].messages),
                    reasoning=f"Same path detected, using {'root' if best_root.relevance >= best_tip.relevance else 'tip'} window",
                    score=max(best_root.relevance, best_tip.relevance)
                )
            else:
                explore_branch_id = best_root.branchID if best_root.relevance >= best_tip.relevance else best_tip.branchID
                attachment_point = self.explore_attachment_space(user_prompt, explore_branch_id)
        
        if attachment_point.branch_id == 0:
            context_messages = []
        else:
            node = self.branches[attachment_point.branch_id]
            context_messages = node.messages[:attachment_point.attachment_index]
        
        return {
            "best_attachment": attachment_point,
            "context_messages": context_messages,
            "root_best": best_root,
            "tip_best": best_tip,
            "same_path": same_path if best_root and best_tip else True
        }

# ============= Database Layer =============

class ConversationTreeDB:
    """SQLite-based persistence for conversation trees"""
    
    def __init__(self, db_path: str = "conversation_tree.db"):
        self.db_path = db_path
        self._init_db()
    
    def _init_db(self):
        """Initialize database schema"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS branches (
                    id TEXT PRIMARY KEY,
                    parent_id TEXT,
                    attachment_point INTEGER,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (parent_id) REFERENCES branches(id)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS messages (
                    id TEXT PRIMARY KEY,
                    branch_id TEXT NOT NULL,
                    position INTEGER NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (branch_id) REFERENCES branches(id),
                    UNIQUE(branch_id, position)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tree_metadata (
                    tree_id TEXT PRIMARY KEY,
                    root_branch_id TEXT,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT
                )
            ''')
            
            conn.execute('CREATE INDEX IF NOT EXISTS idx_branch_parent ON branches(parent_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_message_branch ON messages(branch_id)')
    
    def save_branch(self, branch: ConversationBranch):
        """Save or update a branch"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                INSERT OR REPLACE INTO branches 
                (id, parent_id, attachment_point, created_at, updated_at, metadata)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (
                branch.id,
                branch.parent_id,
                branch.attachment_point,
                branch.created_at.isoformat(),
                branch.updated_at.isoformat(),
                json.dumps(branch.metadata) if branch.metadata else None
            ))
            
            conn.execute('DELETE FROM messages WHERE branch_id = ?', (branch.id,))
            
            for position, msg in enumerate(branch.messages):
                conn.execute('''
                    INSERT INTO messages 
                    (id, branch_id, position, role, content, timestamp, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                ''', (
                    msg.id,
                    branch.id,
                    position,
                    msg.role.value,
                    msg.content,
                    msg.timestamp.isoformat(),
                    json.dumps(msg.metadata) if msg.metadata else None
                ))
    
    def load_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Load a branch by ID"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute('''
                SELECT parent_id, attachment_point, created_at, updated_at, metadata
                FROM branches WHERE id = ?
            ''', (branch_id,))
            
            row = cursor.fetchone()
            if not row:
                return None
            
            parent_id, attachment_point, created_at, updated_at, metadata = row
            
            cursor = conn.execute('''
                SELECT id, role, content, timestamp, metadata
                FROM messages WHERE branch_id = ?
                ORDER BY position
            ''', (branch_id,))
            
            messages = []
            for msg_row in cursor:
                msg_id, role, content, timestamp, msg_metadata = msg_row
                messages.append(Message(
                    id=msg_id,
                    role=MessageRole(role),
                    content=content,
                    timestamp=datetime.fromisoformat(timestamp),
                    metadata=json.loads(msg_metadata) if msg_metadata else None
                ))
            
            return ConversationBranch(
                id=branch_id,
                parent_id=parent_id,
                messages=messages,
                created_at=datetime.fromisoformat(created_at),
                updated_at=datetime.fromisoformat(updated_at),
                metadata=json.loads(metadata) if metadata else None,
                attachment_point=attachment_point
            )
    
    def get_children(self, branch_id: str) -> List[str]:
        """Get all child branch IDs for a given branch"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                'SELECT id FROM branches WHERE parent_id = ?',
                (branch_id,)
            )
            return [row[0] for row in cursor]

# ============= Manager Class =============

class ConversationTreeManager:
    """Main manager for conversation tree operations"""
    
    def __init__(self, db_path: str = "conversation_tree.db", cache_size: int = 100):
        self.db = ConversationTreeDB(db_path)
        self.cache: Dict[str, ConversationBranch] = {}
        self.cache_size = cache_size
        self.tree_id = str(uuid.uuid4())
        self.root_branch_id: Optional[str] = None
    
    def create_root_branch(self) -> ConversationBranch:
        """Create the root branch of the conversation tree"""
        root = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=None,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            metadata={"is_root": True}
        )
        self.db.save_branch(root)
        self.root_branch_id = root.id
        self._cache_branch(root)
        return root
    
    def add_message(self, branch_id: str, role: MessageRole, content: str, 
                   metadata: Optional[Dict] = None) -> Message:
        """Add a message to an existing branch"""
        branch = self.get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")
        
        message = Message(
            id=str(uuid.uuid4()),
            role=role,
            content=content,
            timestamp=datetime.now(),
            metadata=metadata
        )
        
        branch.messages.append(message)
        branch.updated_at = datetime.now()
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return message
    
    def create_branch(self, parent_id: str, attachment_point: int,
                     initial_messages: Optional[List[Dict]] = None) -> ConversationBranch:
        """Create a new branch from a parent at a specific attachment point"""
        parent = self.get_branch(parent_id)
        if not parent:
            raise ValueError(f"Parent branch {parent_id} not found")
        
        if attachment_point > len(parent.messages):
            raise ValueError(f"Invalid attachment point {attachment_point}")
        
        branch = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=parent_id,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            attachment_point=attachment_point
        )
        
        if initial_messages:
            for msg_data in initial_messages:
                message = Message(
                    id=str(uuid.uuid4()),
                    role=MessageRole(msg_data['role']),
                    content=msg_data['content'],
                    timestamp=datetime.now(),
                    metadata=msg_data.get('metadata')
                )
                branch.messages.append(message)
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return branch
    
    def get_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Get a branch by ID (with caching)"""
        if branch_id in self.cache:
            return self.cache[branch_id]
        
        branch = self.db.load_branch(branch_id)
        if branch:
            self._cache_branch(branch)
        
        return branch
    
    def get_conversation_context(self, branch_id: str) -> List[Message]:
        """Get full conversation context from root to branch"""
        context = []
        current_id = branch_id
        
        path = []
        while current_id:
            branch = self.get_branch(current_id)
            if not branch:
                break
            path.append(branch)
            current_id = branch.parent_id
        
        path.reverse()
        
        for i, branch in enumerate(path):
            if i == 0:
                context.extend(branch.messages)
            else:
                parent = path[i-1]
                attachment_point = branch.attachment_point or len(parent.messages)
                
                if i == 1:
                    context = parent.messages[:attachment_point]
                
                context.extend(branch.messages)
        
        return context
    
    def reconstruct_tree(self) -> Dict[str, Any]:
        """Reconstruct the entire tree structure"""
        with sqlite3.connect(self.db.db_path) as conn:
            cursor = conn.execute('''
                SELECT id, parent_id, attachment_point, created_at, updated_at
                FROM branches
                ORDER BY created_at
            ''')
            
            tree = {"branches": {}, "root": None}
            
            for row in cursor:
                branch_id, parent_id, attachment_point, created_at, updated_at = row
                
                if parent_id is None:
                    tree["root"] = branch_id
                
                tree["branches"][branch_id] = {
                    "id": branch_id,
                    "parent_id": parent_id,
                    "attachment_point": attachment_point,
                    "children": self.db.get_children(branch_id),
                    "created_at": created_at,
                    "updated_at": updated_at
                }
            
            return tree
    
    def export_tree(self, filepath: str):
        """Export entire tree to JSON file"""
        tree_data = {"tree_id": self.tree_id, "branches": []}
        
        with sqlite3.connect(self.db.db_path) as conn:
            cursor = conn.execute('SELECT id FROM branches')
            branch_ids = [row[0] for row in cursor]
        
        for branch_id in branch_ids:
            branch = self.get_branch(branch_id)
            if branch:
                tree_data["branches"].append(branch.to_dict())
        
        with open(filepath, 'w') as f:
            json.dump(tree_data, f, indent=2)
    
    def import_tree(self, filepath: str):
        """Import tree from JSON file"""
        with open(filepath, 'r') as f:
            tree_data = json.load(f)
        
        self.tree_id = tree_data.get("tree_id", str(uuid.uuid4()))
        
        for branch_data in tree_data["branches"]:
            branch = ConversationBranch.from_dict(branch_data)
            self.db.save_branch(branch)
            
            if branch.parent_id is None:
                self.root_branch_id = branch.id
    
    def _cache_branch(self, branch: ConversationBranch):
        """Add branch to cache with LRU eviction"""
        if len(self.cache) >= self.cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        
        self.cache[branch.id] = branch

# ============= Persistent DSPy Module =============

class PersistentConversationTreeModule(dspy.Module):
    """Extended module with persistence"""
    
    def __init__(self, db_path: str = "conversation_tree.db"):
        super().__init__()
        self.manager = ConversationTreeManager(db_path)
        self.base_module = ConversationTreeModule()
        
        tree_structure = self.manager.reconstruct_tree()
        if not tree_structure["root"]:
            root = self.manager.create_root_branch()
            self.manager.add_message(
                root.id, 
                MessageRole.SYSTEM, 
                "You are a helpful AI assistant that maintains context across branching conversations."
            )
        
    def forward(self, user_prompt: str) -> Dict[str, Any]:
        """Forward with automatic persistence"""
        tree_structure = self.manager.reconstruct_tree()
        
        conversations = []
        branch_id_map = {}  # Map string IDs to integers for the base module
        
        for idx, (branch_id, branch_info) in enumerate(tree_structure["branches"].items()):
            branch = self.manager.get_branch(branch_id)
            if branch and branch.messages:
                # Use integer IDs for the base module
                int_id = idx + 1
                branch_id_map[int_id] = branch_id
                conversations.append({
                    "branchID": int_id,
                    "messages": [{"role": msg.role.value, "content": msg.content} 
                                for msg in branch.messages]
                })
        
        if not conversations or all(len(c["messages"]) <= 1 for c in conversations):
            root_id = tree_structure["root"]
            if not root_id:
                root = self.manager.create_root_branch()
                root_id = root.id
            
            new_branch = self.manager.create_branch(
                parent_id=root_id,
                attachment_point=1,
                initial_messages=[{"role": "user", "content": user_prompt}]
            )
            
            return {
                "best_attachment": AttachmentPoint(
                    branch_id=new_branch.id,
                    attachment_index=1,
                    reasoning="First conversation in tree",
                    score=1.0
                ),
                "context_messages": self.manager.get_conversation_context(new_branch.id),
                "new_branch_id": new_branch.id,
                "root_best": None,
                "tip_best": None,
                "same_path": True
            }
        
        result = self.base_module.forward(user_prompt, conversations)
        
        attachment = result["best_attachment"]
        
        # Convert integer branch_id back to string UUID
        if isinstance(attachment.branch_id, int) and attachment.branch_id != 0:
            actual_branch_id = branch_id_map.get(attachment.branch_id)
        else:
            actual_branch_id = attachment.branch_id
        
        if attachment.branch_id == "0" or attachment.branch_id == 0:
            root_id = tree_structure["root"]
            new_branch = self.manager.create_branch(
                parent_id=root_id,
                attachment_point=1,
                initial_messages=[{"role": "user", "content": user_prompt}]
            )
            result["new_branch_id"] = new_branch.id
        else:
            existing_branch = self.manager.get_branch(str(actual_branch_id))
            if attachment.attachment_index == len(existing_branch.messages):
                self.manager.add_message(
                    str(actual_branch_id),
                    MessageRole.USER,
                    user_prompt
                )
                result["new_branch_id"] = str(actual_branch_id)
            else:
                new_branch = self.manager.create_branch(
                    parent_id=str(actual_branch_id),
                    attachment_point=attachment.attachment_index,
                    initial_messages=[{"role": "user", "content": user_prompt}]
                )
                result["new_branch_id"] = new_branch.id
        
        result["context_messages"] = self.manager.get_conversation_context(result["new_branch_id"])
        
        return result

# ============= Example Usage =============

if __name__ == "__main__":
    # Configure DSPy with your LM
    # lm = dspy.LM("groq/llama-70b-8192")  # Or your preferred model
    # dspy.configure(lm=lm)
    
    # Initialize the persistent conversation tree
    tree = PersistentConversationTreeModule("my_conversations.db")
    
    # Process first user message
    result = tree("Tell me about mindfulness techniques")
    print(f"First message attached to branch: {result['new_branch_id']}")
    print(f"Reasoning: {result['best_attachment'].reasoning}")
    
    # Add assistant response
    tree.manager.add_message(
        result['new_branch_id'],
        MessageRole.ASSISTANT,
        "Mindfulness techniques include meditation, deep breathing, body scans, and present-moment awareness. Would you like me to explain any of these in detail?"
    )
    
    # Continue conversation
    result2 = tree("Yes, tell me about deep breathing")
    print(f"\nSecond message attached to branch: {result2['new_branch_id']}")
    
    # Start a different topic
    result3 = tree("I need help with Python programming")
    print(f"\nNew topic attached to branch: {result3['new_branch_id']}")
    
    # Show tree structure
    tree_structure = tree.manager.reconstruct_tree()
    print(f"\nTree structure:")
    print(f"- Root: {tree_structure['root']}")
    print(f"- Total branches: {len(tree_structure['branches'])}")
    
    # Export tree
    tree.manager.export_tree("conversation_backup.json")
    print("\nTree exported to conversation_backup.json")



First message attached to branch: 9981d485-c148-4bd8-829c-0a64c4b5dc97
Reasoning: First conversation in tree





Second message attached to branch: 9981d485-c148-4bd8-829c-0a64c4b5dc97

New topic attached to branch: 755d5c1f-c40b-414c-9813-7a1bc4bd4d43

Tree structure:
- Root: 755d5c1f-c40b-414c-9813-7a1bc4bd4d43
- Total branches: 4

Tree exported to conversation_backup.json


In [24]:
import json
import sqlite3
from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple, Union
from dataclasses import dataclass, asdict
import uuid
from pathlib import Path
import pickle
from enum import Enum
import dspy
import pydantic

class MessageRole(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

class AttachmentPoint(pydantic.BaseModel):
    """Represents where to attach a new conversation turn"""
    branch_id: Union[int, str]  # Support both int and str
    attachment_index: int  # Where in the branch to attach (message index)
    reasoning: str
    score: float

@dataclass
class Message:
    """Individual message in a conversation"""
    id: str
    role: MessageRole
    content: str
    timestamp: datetime
    metadata: Optional[Dict[str, Any]] = None
    
    def to_dict(self) -> Dict:
        data = asdict(self)
        data['timestamp'] = self.timestamp.isoformat()
        data['role'] = self.role.value
        return data
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'Message':
        data = data.copy()
        data['timestamp'] = datetime.fromisoformat(data['timestamp'])
        data['role'] = MessageRole(data['role'])
        return cls(**data)

@dataclass
class ConversationBranch:
    """A branch in the conversation tree"""
    id: str
    parent_id: Optional[str]  # None for root
    messages: List[Message]
    created_at: datetime
    updated_at: datetime
    metadata: Optional[Dict[str, Any]] = None
    attachment_point: Optional[int] = None  # Where this branch attaches to parent
    
    def to_dict(self) -> Dict:
        return {
            'id': self.id,
            'parent_id': self.parent_id,
            'messages': [msg.to_dict() for msg in self.messages],
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'metadata': self.metadata,
            'attachment_point': self.attachment_point
        }
    
    @classmethod
    def from_dict(cls, data: Dict) -> 'ConversationBranch':
        data = data.copy()
        data['messages'] = [Message.from_dict(msg) for msg in data['messages']]
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        data['updated_at'] = datetime.fromisoformat(data['updated_at'])
        return cls(**data)

class ConversationTreeDB:
    """SQLite-based persistence for conversation trees"""
    
    def __init__(self, db_path: str = "conversation_tree.db"):
        self.db_path = db_path
        self._init_db()
    
    def _init_db(self):
        """Initialize database schema"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS branches (
                    id TEXT PRIMARY KEY,
                    parent_id TEXT,
                    attachment_point INTEGER,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (parent_id) REFERENCES branches(id)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS messages (
                    id TEXT PRIMARY KEY,
                    branch_id TEXT NOT NULL,
                    position INTEGER NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    metadata TEXT,
                    FOREIGN KEY (branch_id) REFERENCES branches(id),
                    UNIQUE(branch_id, position)
                )
            ''')
            
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tree_metadata (
                    tree_id TEXT PRIMARY KEY,
                    root_branch_id TEXT,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    metadata TEXT
                )
            ''')
            
            # Indexes for performance
            conn.execute('CREATE INDEX IF NOT EXISTS idx_branch_parent ON branches(parent_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_message_branch ON messages(branch_id)')
    
    def save_branch(self, branch: ConversationBranch):
        """Save or update a branch"""
        with sqlite3.connect(self.db_path) as conn:
            # Save branch
            conn.execute('''
                INSERT OR REPLACE INTO branches 
                (id, parent_id, attachment_point, created_at, updated_at, metadata)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (
                branch.id,
                branch.parent_id,
                branch.attachment_point,
                branch.created_at.isoformat(),
                branch.updated_at.isoformat(),
                json.dumps(branch.metadata) if branch.metadata else None
            ))
            
            # Delete existing messages for this branch (for updates)
            conn.execute('DELETE FROM messages WHERE branch_id = ?', (branch.id,))
            
            # Save messages
            for position, msg in enumerate(branch.messages):
                conn.execute('''
                    INSERT INTO messages 
                    (id, branch_id, position, role, content, timestamp, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                ''', (
                    msg.id,
                    branch.id,
                    position,
                    msg.role.value,
                    msg.content,
                    msg.timestamp.isoformat(),
                    json.dumps(msg.metadata) if msg.metadata else None
                ))
    
    def load_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Load a branch by ID"""
        with sqlite3.connect(self.db_path) as conn:
            # Load branch info
            cursor = conn.execute('''
                SELECT parent_id, attachment_point, created_at, updated_at, metadata
                FROM branches WHERE id = ?
            ''', (branch_id,))
            
            row = cursor.fetchone()
            if not row:
                return None
            
            parent_id, attachment_point, created_at, updated_at, metadata = row
            
            # Load messages
            cursor = conn.execute('''
                SELECT id, role, content, timestamp, metadata
                FROM messages WHERE branch_id = ?
                ORDER BY position
            ''', (branch_id,))
            
            messages = []
            for msg_row in cursor:
                msg_id, role, content, timestamp, msg_metadata = msg_row
                messages.append(Message(
                    id=msg_id,
                    role=MessageRole(role),
                    content=content,
                    timestamp=datetime.fromisoformat(timestamp),
                    metadata=json.loads(msg_metadata) if msg_metadata else None
                ))
            
            return ConversationBranch(
                id=branch_id,
                parent_id=parent_id,
                messages=messages,
                created_at=datetime.fromisoformat(created_at),
                updated_at=datetime.fromisoformat(updated_at),
                metadata=json.loads(metadata) if metadata else None,
                attachment_point=attachment_point
            )
    
    def get_children(self, branch_id: str) -> List[str]:
        """Get all child branch IDs for a given branch"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                'SELECT id FROM branches WHERE parent_id = ?',
                (branch_id,)
            )
            return [row[0] for row in cursor]

class ConversationTreeManager:
    """Main manager for conversation tree operations"""
    
    def __init__(self, db_path: str = "conversation_tree.db", cache_size: int = 100):
        self.db = ConversationTreeDB(db_path)
        self.cache: Dict[str, ConversationBranch] = {}
        self.cache_size = cache_size
        self.tree_id = str(uuid.uuid4())
        self.root_branch_id: Optional[str] = None
    
    def create_root_branch(self) -> ConversationBranch:
        """Create the root branch of the conversation tree"""
        root = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=None,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            metadata={"is_root": True}
        )
        self.db.save_branch(root)
        self.root_branch_id = root.id
        self._cache_branch(root)
        return root
    
    def add_message(self, branch_id: str, role: MessageRole, content: str, 
                   metadata: Optional[Dict] = None) -> Message:
        """Add a message to an existing branch"""
        branch = self.get_branch(branch_id)
        if not branch:
            raise ValueError(f"Branch {branch_id} not found")
        
        message = Message(
            id=str(uuid.uuid4()),
            role=role,
            content=content,
            timestamp=datetime.now(),
            metadata=metadata
        )
        
        branch.messages.append(message)
        branch.updated_at = datetime.now()
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return message
    
    def create_branch(self, parent_id: str, attachment_point: int,
                     initial_messages: Optional[List[Dict]] = None) -> ConversationBranch:
        """Create a new branch from a parent at a specific attachment point"""
        parent = self.get_branch(parent_id)
        if not parent:
            raise ValueError(f"Parent branch {parent_id} not found")
        
        if attachment_point > len(parent.messages):
            raise ValueError(f"Invalid attachment point {attachment_point}")
        
        # Create new branch
        branch = ConversationBranch(
            id=str(uuid.uuid4()),
            parent_id=parent_id,
            messages=[],
            created_at=datetime.now(),
            updated_at=datetime.now(),
            attachment_point=attachment_point
        )
        
        # Add initial messages if provided
        if initial_messages:
            for msg_data in initial_messages:
                message = Message(
                    id=str(uuid.uuid4()),
                    role=MessageRole(msg_data['role']),
                    content=msg_data['content'],
                    timestamp=datetime.now(),
                    metadata=msg_data.get('metadata')
                )
                branch.messages.append(message)
        
        self.db.save_branch(branch)
        self._cache_branch(branch)
        
        return branch
    
    def get_branch(self, branch_id: str) -> Optional[ConversationBranch]:
        """Get a branch by ID (with caching)"""
        # Check cache first
        if branch_id in self.cache:
            return self.cache[branch_id]
        
        # Load from DB
        branch = self.db.load_branch(branch_id)
        if branch:
            self._cache_branch(branch)
        
        return branch
    
    def get_conversation_context(self, branch_id: str) -> List[Message]:
        """Get full conversation context from root to branch"""
        context = []
        current_id = branch_id
        
        # Walk up the tree collecting messages
        path = []
        while current_id:
            branch = self.get_branch(current_id)
            if not branch:
                break
            path.append(branch)
            current_id = branch.parent_id
        
        # Reverse to get root-to-leaf order
        path.reverse()
        
        # Collect messages respecting attachment points
        for i, branch in enumerate(path):
            if i == 0:  # Root branch
                context.extend(branch.messages)
            else:
                # Only include messages up to the attachment point from parent
                parent = path[i-1]
                attachment_point = branch.attachment_point or len(parent.messages)
                
                # Add messages from parent up to attachment point
                if i == 1:  # First child of root
                    context = parent.messages[:attachment_point]
                
                # Add this branch's messages
                context.extend(branch.messages)
        
        return context
    
    def reconstruct_tree(self) -> Dict[str, Any]:
        """Reconstruct the entire tree structure"""
        with sqlite3.connect(self.db.db_path) as conn:
            # Get all branches
            cursor = conn.execute('''
                SELECT id, parent_id, attachment_point, created_at, updated_at
                FROM branches
                ORDER BY created_at
            ''')
            
            tree = {"branches": {}, "root": None}
            
            for row in cursor:
                branch_id, parent_id, attachment_point, created_at, updated_at = row
                
                if parent_id is None:
                    tree["root"] = branch_id
                
                tree["branches"][branch_id] = {
                    "id": branch_id,
                    "parent_id": parent_id,
                    "attachment_point": attachment_point,
                    "children": self.db.get_children(branch_id),
                    "created_at": created_at,
                    "updated_at": updated_at
                }
            
            return tree
    
    def export_tree(self, filepath: str):
        """Export entire tree to JSON file"""
        tree_data = {"tree_id": self.tree_id, "branches": []}
        
        # Get all branch IDs
        with sqlite3.connect(self.db.db_path) as conn:
            cursor = conn.execute('SELECT id FROM branches')
            branch_ids = [row[0] for row in cursor]
        
        # Load and serialize each branch
        for branch_id in branch_ids:
            branch = self.get_branch(branch_id)
            if branch:
                tree_data["branches"].append(branch.to_dict())
        
        with open(filepath, 'w') as f:
            json.dump(tree_data, f, indent=2)
    
    def import_tree(self, filepath: str):
        """Import tree from JSON file"""
        with open(filepath, 'r') as f:
            tree_data = json.load(f)
        
        self.tree_id = tree_data.get("tree_id", str(uuid.uuid4()))
        
        for branch_data in tree_data["branches"]:
            branch = ConversationBranch.from_dict(branch_data)
            self.db.save_branch(branch)
            
            if branch.parent_id is None:
                self.root_branch_id = branch.id
    
    def _cache_branch(self, branch: ConversationBranch):
        """Add branch to cache with LRU eviction"""
        if len(self.cache) >= self.cache_size:
            # Simple eviction - remove oldest
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        
        self.cache[branch.id] = branch

# Integration with your DSPy module
# Note: This assumes ConversationTreeModule is imported from your other module
# If using as a single file, copy the ConversationTreeModule class here

class PersistentConversationTreeModule(dspy.Module):
    """Extended module with persistence - inherits from ConversationTreeModule"""
    
    def __init__(self, db_path: str = "conversation_tree.db", base_module=None):
        super().__init__()
        self.manager = ConversationTreeManager(db_path)
        
        # If ConversationTreeModule is in a separate file, pass it as base_module
        # Otherwise, it should be defined above in this file
        if base_module:
            self.base_module = base_module()
        else:
            # Assumes ConversationTreeModule is defined in this file
            from conversation_tree_module import ConversationTreeModule, Score, scorer
            self.base_module = ConversationTreeModule()
            self.scorer = self.base_module.scorer
            self.attachment_finder = self.base_module.attachment_finder
        
        # Initialize root if needed
        tree_structure = self.manager.reconstruct_tree()
        if not tree_structure["root"]:
            root = self.manager.create_root_branch()
            # Add a system message to the root
            self.manager.add_message(
                root.id, 
                MessageRole.SYSTEM, 
                "You are a helpful AI assistant that maintains context across branching conversations."
            )
        
    def forward(self, user_prompt: str) -> Dict[str, Any]:
        """Forward with automatic persistence"""
        # Load current tree structure
        tree_structure = self.manager.reconstruct_tree()
        
        # Convert to format expected by base module
        conversations = []
        branch_id_map = {}  # Map integer IDs to string UUIDs
        
        for idx, (branch_id, branch_info) in enumerate(tree_structure["branches"].items()):
            branch = self.manager.get_branch(branch_id)
            if branch and branch.messages:
                # Use integer IDs for the base module
                int_id = idx + 1
                branch_id_map[int_id] = branch_id
                conversations.append({
                    "branchID": int_id,
                    "messages": [{"role": msg.role.value, "content": msg.content} 
                                for msg in branch.messages]
                })
        
        # Handle empty tree case - create first branch
        if not conversations or all(len(c["messages"]) <= 1 for c in conversations):
            # No substantial conversations yet, create first branch
            root_id = tree_structure["root"]
            if not root_id:
                root = self.manager.create_root_branch()
                root_id = root.id
            
            new_branch = self.manager.create_branch(
                parent_id=root_id,
                attachment_point=1,  # After system message
                initial_messages=[{"role": "user", "content": user_prompt}]
            )
            
            return {
                "best_attachment": AttachmentPoint(
                    branch_id=new_branch.id,
                    attachment_index=1,
                    reasoning="First conversation in tree",
                    score=1.0
                ),
                "context_messages": self.manager.get_conversation_context(new_branch.id),
                "new_branch_id": new_branch.id,
                "root_best": None,
                "tip_best": None,
                "same_path": True
            }
        
        # Get attachment decision from base module
        result = self.base_module.forward(user_prompt, conversations)
        
        # Create new branch or extend existing one
        attachment = result["best_attachment"]
        
        # Convert integer branch_id back to string UUID
        if isinstance(attachment.branch_id, int) and attachment.branch_id != 0:
            actual_branch_id = branch_id_map.get(attachment.branch_id)
        else:
            actual_branch_id = attachment.branch_id
        
        if attachment.branch_id == "0" or attachment.branch_id == 0:
            # Create new branch from root
            root_id = tree_structure["root"]
            new_branch = self.manager.create_branch(
                parent_id=root_id,
                attachment_point=1,  # After system message
                initial_messages=[{"role": "user", "content": user_prompt}]
            )
            result["new_branch_id"] = new_branch.id
        else:
            # Add to existing branch or create sub-branch
            existing_branch = self.manager.get_branch(str(actual_branch_id))
            if attachment.attachment_index == len(existing_branch.messages):
                # Extend existing branch
                self.manager.add_message(
                    str(actual_branch_id),
                    MessageRole.USER,
                    user_prompt
                )
                result["new_branch_id"] = str(actual_branch_id)
            else:
                # Create new sub-branch
                new_branch = self.manager.create_branch(
                    parent_id=str(actual_branch_id),
                    attachment_point=attachment.attachment_index,
                    initial_messages=[{"role": "user", "content": user_prompt}]
                )
                result["new_branch_id"] = new_branch.id
        
        # Add context messages to result
        result["context_messages"] = self.manager.get_conversation_context(result["new_branch_id"])
        
        return result

# Example usage
if __name__ == "__main__":
    # Initialize persistent module
    tree = PersistentConversationTreeModule("my_conversations.db")
    
    # Process first user message - creates initial branch
    result = tree("Tell me about mindfulness techniques")
    print(f"First message attached to branch: {result['new_branch_id']}")
    
    # Add assistant response to continue the conversation
    tree.manager.add_message(
        result['new_branch_id'],
        MessageRole.ASSISTANT,
        "Mindfulness techniques include meditation, deep breathing, body scans, and present-moment awareness. Would you like me to explain any of these in detail?"
    )
    
    # Process another user message - will attach to the most relevant branch
    result2 = tree("Yes, tell me about deep breathing")
    print(f"Second message attached to branch: {result2['new_branch_id']}")
    
    # Start a different conversation thread
    result3 = tree("I need help with Python programming")
    print(f"New topic attached to branch: {result3['new_branch_id']}")
    
    # Export for backup
    tree.manager.export_tree("conversation_backup.json")
    
    # Show tree structure
    tree_structure = tree.manager.reconstruct_tree()
    print(f"\nTree has {len(tree_structure['branches'])} branches")
    print(f"Root branch: {tree_structure['root']}")

"""
Usage Notes:

1. To use both modules together, save them in your project and import:
   
   from conversation_tree_module import ConversationTreeModule, Score, ScoreItem
   from conversation_persistence import PersistentConversationTreeModule
   
2. Or combine both files into one by copying the ConversationTreeModule 
   class and its dependencies into this file.

3. The system automatically:
   - Creates a root branch on first use
   - Finds the best attachment point for new messages
   - Persists all conversations to SQLite
   - Maintains conversation context across branches
   
4. Access conversation history:
   context = tree.manager.get_conversation_context(branch_id)
   
5. Export/Import trees:
   tree.manager.export_tree("backup.json")
   tree.manager.import_tree("backup.json")
"""

ModuleNotFoundError: No module named 'conversation_tree_module'

In [20]:
# 1. Configure DSPy with your LM
lm = dspy.LM("groq/moonshotai/kimi-k2-instruct")
dspy.configure(lm=lm)

# 2. Initialize the conversation tree
tree = PersistentConversationTreeModule("chat.db")

In [21]:
# 3. Process user messages - they automatically attach to the best branch
result = tree("Tell me about stress management")
print(f"Message attached to branch: {result['new_branch_id']}")

ValidationError: 1 validation error for AttachmentPoint
branch_id
  Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='ae2808d9-2ee6-4188-bb14-f44db41c00fc', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/int_parsing

In [None]:
# 4. Add assistant responses to continue the conversation
tree.manager.add_message(
    result['new_branch_id'],
    MessageRole.ASSISTANT,
    "I can help with stress management. Some effective techniques include deep breathing, meditation, and regular exercise. What aspect would you like to explore?"
)

# 5. Continue the conversation - it will stay on the same branch
result2 = tree("Tell me more about meditation")
print(f"Continued on branch: {result2['new_branch_id']}")

# 6. Start a new topic - it will create a new branch
result3 = tree("Help me debug my Python code")
print(f"New topic on branch: {result3['new_branch_id']}")

# 7. Get the full conversation context for any branch
context = tree.manager.get_conversation_context(result['new_branch_id'])
print(f"\nConversation history ({len(context)} messages):")
for msg in context:
    print(f"[{msg.role.value}]: {msg.content[:50]}...")

# 8. Export for backup
tree.manager.export_tree("backup.json")