# Phase 2a: Training Interpretable Transformer (Multi-Hop Reasoning Test)

This notebook trains a small transformer on **multi-hop inductive reasoning** tasks.

**Experimental Design**:
- **depth-2, depth_of_truth=0**: Model must traverse **1 hop** up the ontology tree
- **depth-3, depth_of_truth=0**: Model must traverse **2 hops** up the ontology tree

By filtering both depths to `depth_of_truth=0` (root-level targets), we get a fair comparison that directly tests whether transformers can compose multiple inference steps.

**Research Question**: 
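Can transformers perform multi-hop reasoning? If the 1-hop task is learned but the 2-hop task is not, this is evidence of a limit on compositional reasoning for this architecture and training setup (data or optimization effects still need to be ruled out separately).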

**Training Setup**:
1. Set `TREE_DEPTH = 2` or `TREE_DEPTH = 3` in the configuration cell
2. Run all cells in order
3. Training takes ~30-60 min on Colab GPU
4. Download the trained model at the end

**Data Format (True Induction)**:
- Observations include both concept membership AND property assertions
- Model must induce the hidden rule from observations (not pattern completion)
- All samples require finding the ROOT concept (Occam's Razor / parsimony)

In [8]:
# Check GPU
!nvidia-smi

zsh:1: command not found: nvidia-smi


## Cell 1: All Imports and Class Definitions

This cell defines the Tokenizer, Dataset, and Model classes.

In [None]:
#@title Run this cell first - Defines all classes { display-mode: "form" }

import json
import math
import random
from typing import List, Dict, Optional, Tuple, Any, Set
from dataclasses import dataclass, field
from collections import deque, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

import matplotlib.pyplot as plt
from IPython.display import clear_output
import re

# ============================================================================
# TOKENIZER (with configurable vocabulary size for deeper trees)
# ============================================================================

class SymbolicOntologyTokenizer:
    """Tokenizer for symbolic ontology notation.

    The vocabulary is closed and built once in ``_build_vocab``: control
    tokens, section markers, logic symbols, the newline separator, then the
    concept (``c0..``), property (``p1..``) and entity (``e0..``) symbols.
    Any token outside the vocabulary is encoded as ``UNK_TOKEN``.
    """

    PAD_TOKEN = "<PAD>"
    BOS_TOKEN = "<BOS>"
    EOS_TOKEN = "<EOS>"
    UNK_TOKEN = "<UNK>"
    WORLD_MODEL_TOKEN = "[WORLD_MODEL]"
    OBSERVATIONS_TOKEN = "[OBSERVATIONS]"
    TASK_TOKEN = "[TASK]"
    ANSWER_TOKEN = "[ANSWER]"
    FORALL_TOKEN = "‚àÄx:"
    IMPLIES_TOKEN = "->"
    OPEN_PAREN = "("
    CLOSE_PAREN = ")"
    PRED_X = "(x)"

    # Compiled once: one side of a quantified rule ("c3(x)") and a ground
    # fact about an entity ("c3(e7)" / "p2(e7)").
    _RULE_SIDE_RE = re.compile(r'([cp]\d+)\(x\)')
    _GROUND_FACT_RE = re.compile(r'([cp]\d+)\(([e]\d+)\)')

    def __init__(self, max_concepts=30, max_properties=15, max_entities=30, max_seq_len=512):
        """Store the vocabulary limits and build the token <-> id tables.

        Args:
            max_concepts: number of concept symbols, ``c0 .. c{max_concepts-1}``.
            max_properties: number of property symbols, ``p1 .. p{max_properties}``.
            max_entities: number of entity symbols, ``e0 .. e{max_entities-1}``.
            max_seq_len: stored on the instance; not used by the tokenizer itself.
        """
        self.max_concepts = max_concepts
        self.max_properties = max_properties
        self.max_entities = max_entities
        self.max_seq_len = max_seq_len
        self._build_vocab()

    def _build_vocab(self):
        """Assign consecutive ids to every token, in a fixed order.

        The order determines the ids (PAD=0, BOS=1, EOS=2, UNK=3, ...), so it
        must not change if saved checkpoints are to stay loadable.
        """
        vocab = [
            self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN,
            self.WORLD_MODEL_TOKEN, self.OBSERVATIONS_TOKEN,
            self.TASK_TOKEN, self.ANSWER_TOKEN,
            self.FORALL_TOKEN, self.IMPLIES_TOKEN,
            self.OPEN_PAREN, self.CLOSE_PAREN, self.PRED_X,
            "\n",
        ]
        vocab.extend(f"c{i}" for i in range(self.max_concepts))
        # Property ids are 1-based, unlike concepts and entities.
        vocab.extend(f"p{i}" for i in range(1, self.max_properties + 1))
        vocab.extend(f"e{i}" for i in range(self.max_entities))

        self.token_to_id = {}
        self.id_to_token = {}
        for token_id, token in enumerate(vocab):
            self.token_to_id[token] = token_id
            self.id_to_token[token_id] = token

        self.vocab_size = len(vocab)
        self.pad_token_id = self.token_to_id[self.PAD_TOKEN]
        self.bos_token_id = self.token_to_id[self.BOS_TOKEN]
        self.eos_token_id = self.token_to_id[self.EOS_TOKEN]
        self.unk_token_id = self.token_to_id[self.UNK_TOKEN]

    def _tokenize_statement(self, statement):
        """Split one logical statement into vocabulary tokens.

        Handles quantified rules (``<forall> cA(x) -> pB(x)``) and ground
        facts (``cA(eB)`` / ``pA(eB)``). Parts that do not match the expected
        shape are dropped silently, so an unparseable line yields ``[]``.
        """
        statement = statement.strip()
        tokens = []

        # Strip the quantifier by its real length. A hard-coded width leaves
        # stray characters behind whenever the prefix is not exactly that
        # long, and the rule-side patterns then fail to match.
        if statement.startswith(self.FORALL_TOKEN):
            tokens.append(self.FORALL_TOKEN)
            statement = statement[len(self.FORALL_TOKEN):].strip()

        if self.IMPLIES_TOKEN in statement:
            left, _, right = statement.partition(self.IMPLIES_TOKEN)
            left_match = self._RULE_SIDE_RE.match(left.strip())
            if left_match:
                tokens.extend((left_match.group(1), self.PRED_X))
            tokens.append(self.IMPLIES_TOKEN)
            right_match = self._RULE_SIDE_RE.match(right.strip())
            if right_match:
                tokens.extend((right_match.group(1), self.PRED_X))
        else:
            fact_match = self._GROUND_FACT_RE.match(statement)
            if fact_match:
                tokens.extend((
                    fact_match.group(1),
                    self.OPEN_PAREN,
                    fact_match.group(2),
                    self.CLOSE_PAREN,
                ))
        return tokens

    def tokenize(self, text, add_special_tokens=True):
        """Tokenize a multi-line prompt.

        Every kept line is followed by a ``"\\n"`` token. Blank lines and the
        free-text task description (lines starting with ``"Infer"``) are
        skipped. BOS/EOS are added when ``add_special_tokens`` is true.
        """
        section_tokens = (
            self.WORLD_MODEL_TOKEN,
            self.OBSERVATIONS_TOKEN,
            self.TASK_TOKEN,
            self.ANSWER_TOKEN,
        )
        tokens = [self.BOS_TOKEN] if add_special_tokens else []

        for raw_line in text.strip().split('\n'):
            line = raw_line.strip()
            if not line or line.startswith("Infer"):
                continue
            if line in section_tokens:
                tokens.append(line)
            else:
                tokens.extend(self._tokenize_statement(line))
            tokens.append("\n")

        if add_special_tokens:
            tokens.append(self.EOS_TOKEN)
        return tokens

    def encode(self, text, add_special_tokens=True):
        """Return token ids for ``text``; unknown tokens map to the UNK id."""
        tokens = self.tokenize(text, add_special_tokens)
        return [self.token_to_id.get(t, self.unk_token_id) for t in tokens]

    def decode(self, token_ids, skip_special_tokens=True):
        """Concatenate the tokens for ``token_ids`` into a string.

        PAD/BOS/EOS are dropped when ``skip_special_tokens`` is true; ids
        outside the vocabulary decode to ``UNK_TOKEN``.
        """
        special_ids = {self.pad_token_id, self.bos_token_id, self.eos_token_id}
        tokens = []
        for tid in token_ids:
            if skip_special_tokens and tid in special_ids:
                continue
            tokens.append(self.id_to_token.get(tid, self.UNK_TOKEN))
        return "".join(tokens)

# ============================================================================
# DATA GENERATOR (True Induction Format - from generate_symbolic_ontology.py)
# ============================================================================

@dataclass
class OntologyNode:
    """Represents a concept node in the ontology tree.

    Nodes are stored by ``concept_id`` in ``SymbolicOntologyGenerator.nodes``
    and are rendered as ``c{concept_id}`` in generated statements.
    """
    # Index of this concept; also its key in the generator's ``nodes`` dict.
    concept_id: int
    # Distance from the root (the generator creates the root with depth 0).
    depth: int
    # ``concept_id`` of the parent, or None for the root.
    parent_id: Optional[int] = None
    # ``concept_id`` values of direct children, in creation order.
    children_ids: List[int] = field(default_factory=list)
    # Property ids (rendered as ``p{id}``) asserted directly on this concept.
    properties: List[int] = field(default_factory=list)
    # Entity ids (rendered as ``e{id}``); the generator only fills this on leaves.
    members: List[int] = field(default_factory=list)


class SymbolicOntologyGenerator:
    """
    Builds a random concept tree used to create induction tasks.

    The tree is complete: every non-leaf concept has ``branching_factor``
    children, leaves sit at depth ``depth - 1`` and each leaf owns
    ``members_per_leaf`` entities. Each concept receives at most one
    property, chosen so that it does not repeat a property of an ancestor.

    This is the TRUE INDUCTION version:
    1. Samples include concept membership AND property observations
    2. The model has to induce the rule (not complete a pattern)
    3. It tests Occam's Razor / parsimony reasoning
    """

    def __init__(self, depth: int = 3, branching_factor: int = 2,
                 num_properties: int = 5, property_assignment_prob: float = 0.4,
                 members_per_leaf: int = 2, seed: Optional[int] = None):
        # Seeding touches the module-level RNG, so it also affects later
        # callers of ``random``.
        if seed is not None:
            random.seed(seed)

        self.depth = depth
        self.branching_factor = branching_factor
        self.num_properties = num_properties
        self.property_assignment_prob = property_assignment_prob
        self.members_per_leaf = members_per_leaf

        self.nodes: Dict[int, OntologyNode] = {}
        self.num_concepts = 0
        self.num_members = 0
        self._generate_structure()

    def _generate_structure(self):
        """Create all concepts breadth-first, assigning properties and members."""
        self.nodes[0] = OntologyNode(concept_id=0, depth=0)
        self.num_concepts = 1

        pending = deque([0])
        while pending:
            node_id = pending.popleft()
            node = self.nodes[node_id]

            # One coin flip per concept decides whether it carries a property.
            if random.random() < self.property_assignment_prob:
                taken = self._get_ancestor_properties(node_id)
                free = [p for p in range(1, self.num_properties + 1)
                        if p not in taken]
                if free:
                    node.properties.append(random.choice(free))

            if node.depth >= self.depth - 1:
                # Leaf: hand out fresh, globally numbered entity ids.
                for _ in range(self.members_per_leaf):
                    node.members.append(self.num_members)
                    self.num_members += 1
                continue

            # Interior concept: create its children and schedule them.
            for _ in range(self.branching_factor):
                child_id = self.num_concepts
                self.nodes[child_id] = OntologyNode(
                    concept_id=child_id,
                    depth=node.depth + 1,
                    parent_id=node_id,
                )
                node.children_ids.append(child_id)
                pending.append(child_id)
                self.num_concepts += 1

    def _get_ancestor_properties(self, concept_id: int) -> Set[int]:
        """Collect the properties of ``concept_id`` and of all its ancestors."""
        collected: Set[int] = set()
        cursor: Optional[int] = concept_id
        while cursor is not None:
            current = self.nodes[cursor]
            collected.update(current.properties)
            cursor = current.parent_id
        return collected

    def _get_all_descendants(self, concept_id: int) -> List[int]:
        """Return ``concept_id`` followed by all descendants in BFS order."""
        ordered = [concept_id]
        cursor = 0
        # The list doubles as the BFS queue: children are appended while an
        # index walks forward through it.
        while cursor < len(ordered):
            ordered.extend(self.nodes[ordered[cursor]].children_ids)
            cursor += 1
        return ordered

    def _get_all_descendant_members(self, concept_id: int) -> List[Tuple[int, int]]:
        """Return ``(entity_id, owning_concept_id)`` for every entity under a concept."""
        return [
            (member_id, owner_id)
            for owner_id in self._get_all_descendants(concept_id)
            for member_id in self.nodes[owner_id].members
        ]


def generate_inductive_sample(gen: "SymbolicOntologyGenerator") -> Optional[Dict]:
    """
    Generate a TRUE INDUCTIVE reasoning task from one ontology.

    A (concept, property) rule of the ontology is hidden; the prompt lists
    the remaining world model plus observations about up to three entities
    below that concept, and the target is the hidden rule.

    Key difference from pattern-completion:
    - Observations include BOTH concept membership AND property assertions
    - Model must induce the rule from scratch

    Args:
        gen: a generated ontology (reads ``nodes``, ``depth``,
            ``branching_factor`` and ``_get_all_descendant_members``).

    Returns:
        A dict with ``input``, ``target``, ``task_type`` and ``metadata``,
        or ``None`` when no concept has a property and at least two
        entities beneath it.
    """
    # A rule is usable as a target only if at least two entities fall under it.
    candidates = []
    for concept_id, node in gen.nodes.items():
        if not node.properties:
            continue
        descendant_members = gen._get_all_descendant_members(concept_id)
        if len(descendant_members) < 2:
            continue
        candidates.extend(
            (concept_id, prop_id, descendant_members) for prop_id in node.properties
        )

    if not candidates:
        return None

    target_c, target_p, all_members = random.choice(candidates)
    target_node = gen.nodes[target_c]

    # Entities are drawn uniformly at random from everything under the
    # target; nothing forces them to come from different subtrees. The
    # ``is_parsimony_test`` flag below records whether more than one leaf
    # concept was actually hit.
    # NOTE(review): when all drawn entities share one subtree, a more
    # specific concept explains the observations equally well, yet the label
    # is still the target concept -- confirm this is intended.
    num_observations = min(3, len(all_members))
    if len(all_members) <= num_observations:
        selected = all_members
    else:
        selected = random.sample(all_members, num_observations)

    # TRUE INDUCTION FORMAT: each entity contributes its concept membership
    # and the property observation.
    observations = []
    for entity_id, leaf_id in selected:
        observations.append(f"c{leaf_id}(e{entity_id})")
        observations.append(f"p{target_p}(e{entity_id})")

    # Ground truth hypothesis (parsimonious)
    gt_hypothesis = f"‚àÄx: c{target_c}(x) -> p{target_p}(x)"

    # World model: everything the ontology states except the hidden rule.
    world_model = []
    for node in gen.nodes.values():
        if node.parent_id is not None:
            world_model.append(f"‚àÄx: c{node.concept_id}(x) -> c{node.parent_id}(x)")

        for prop_id in node.properties:
            if not (node.concept_id == target_c and prop_id == target_p):
                world_model.append(f"‚àÄx: c{node.concept_id}(x) -> p{prop_id}(x)")

        for member_id in node.members:
            world_model.append(f"c{node.concept_id}(e{member_id})")

    # Shuffle order is kept (world model first) so seeded runs produce the
    # same prompts as before.
    random.shuffle(world_model)
    random.shuffle(observations)

    prompt = "[WORLD_MODEL]\n" + "\n".join(world_model)
    prompt += "\n[OBSERVATIONS]\n" + "\n".join(observations)
    prompt += "\n[TASK]\nInfer the most general rule that explains all observations."

    # Sorted by id so the metadata is reproducible across interpreter
    # sessions; iterating a set of strings depends on hash randomization.
    observed_leaf_ids = sorted({leaf_id for _, leaf_id in selected})

    return {
        "input": prompt,
        "target": gt_hypothesis,
        "task_type": "inductive",
        "metadata": {
            "target_concept": f"c{target_c}",
            "target_property": f"p{target_p}",
            "depth_of_truth": target_node.depth,
            "tree_depth": gen.depth,
            "branching_factor": gen.branching_factor,
            "num_observations": len(selected),
            "is_parsimony_test": len(observed_leaf_ids) > 1,
            "observed_entities": [f"e{eid}" for eid, _ in selected],
            "observed_leaf_concepts": [f"c{lid}" for lid in observed_leaf_ids],
        }
    }


def generate_dataset(num_samples: int, min_depth: int = 2, max_depth: int = 4,
                     min_branch: int = 2, max_branch: int = 3, seed: int = 42) -> List[Dict]:
    """Build up to ``num_samples`` inductive examples from random ontologies.

    The module-level RNG is reseeded with ``seed`` first. Each attempt draws
    a depth and a branching factor, builds a fresh ontology and tries to
    derive one sample from it. At most ``num_samples * 10`` attempts are
    made, so the returned list can be shorter than requested.
    """
    random.seed(seed)
    collected: List[Dict] = []

    for _ in range(num_samples * 10):
        if len(collected) >= num_samples:
            break

        tree_depth = random.randint(min_depth, max_depth)
        fan_out = random.randint(min_branch, max_branch)
        ontology = SymbolicOntologyGenerator(
            depth=tree_depth,
            branching_factor=fan_out,
            num_properties=5,
            property_assignment_prob=0.4,
        )

        example = generate_inductive_sample(ontology)
        if not example:
            continue
        # Ids are dense: the n-th accepted sample gets id n.
        example["id"] = len(collected)
        collected.append(example)

    return collected

# ============================================================================
# DATASET CLASS
# ============================================================================

class SymbolicOntologyDataset(Dataset):
    """Turns prompt/target samples into fixed-length causal-LM examples.

    Each item is ``prompt + [ANSWER] + "\\n" + target + EOS``, right-padded to
    ``max_input_len + max_target_len``. ``labels`` equals ``input_ids`` on the
    target span and is ``-100`` everywhere else (prompt, answer marker and
    padding), so a loss using ``ignore_index=-100`` only scores the answer.
    """

    def __init__(self, samples, tokenizer, max_input_len=512, max_target_len=32):
        """
        Args:
            samples: sequence of dicts with ``"input"`` and ``"target"`` strings.
            tokenizer: object providing ``encode``, ``token_to_id``,
                ``ANSWER_TOKEN``, ``eos_token_id`` and ``pad_token_id``.
            max_input_len: token budget for the prompt, including the two
                answer-marker tokens.
            max_target_len: token budget for the target, including EOS.
        """
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        tok = self.tokenizer
        sample = self.samples[idx]

        input_ids = tok.encode(sample["input"], add_special_tokens=True)
        target_ids = tok.encode(sample["target"], add_special_tokens=False)
        target_ids = target_ids + [tok.eos_token_id]

        answer_token_id = tok.token_to_id[tok.ANSWER_TOKEN]
        newline_id = tok.token_to_id["\n"]

        # The prompt's own EOS is dropped; the sequence ends after the target.
        if input_ids and input_ids[-1] == tok.eos_token_id:
            input_ids = input_ids[:-1]

        # Leave room for the "[ANSWER]" and newline tokens.
        max_input_tokens = self.max_input_len - 2
        if len(input_ids) > max_input_tokens:
            input_ids = input_ids[:max_input_tokens]
        if len(target_ids) > self.max_target_len:
            target_ids = target_ids[:self.max_target_len]

        max_len = self.max_input_len + self.max_target_len
        sequence = input_ids + [answer_token_id, newline_id] + target_ids
        sequence = sequence[:max_len]

        target_start = min(len(input_ids) + 2, len(sequence))
        padding_len = max_len - len(sequence)

        full_ids = sequence + [tok.pad_token_id] * padding_len
        attention_mask = [1] * len(sequence) + [0] * padding_len

        # Supervise only the target span. Padding must be masked as well;
        # otherwise the many PAD positions dominate the averaged loss and any
        # metric that selects positions with ``labels != -100``.
        labels = (
            [-100] * target_start
            + sequence[target_start:]
            + [-100] * padding_len
        )

        return {
            "input_ids": torch.tensor(full_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

# ============================================================================
# MODEL
# ============================================================================

@dataclass
class TransformerConfig:
    """Hyperparameters shared by every module of the interpretable transformer."""
    vocab_size: int = 100
    max_seq_len: int = 512
    n_layers: int = 4
    # NOTE(review): the attention module defined in this file is single-head
    # and never reads ``n_heads`` -- confirm whether it is used elsewhere.
    n_heads: int = 1
    d_model: int = 64
    d_ff: int = 256
    dropout: float = 0.0
    pad_token_id: int = 0
    # True: positions are concatenated to the token embedding.
    # False: a position embedding is added to it.
    concat_pos_emb: bool = True
    pre_ln: bool = True
    causal: bool = True
    use_ff: bool = True

    @property
    def embedding_dim(self):
        """Width of the residual stream, as consumed by every layer."""
        return (self.d_model + self.max_seq_len) if self.concat_pos_emb else self.d_model

class SingleHeadAttention(nn.Module):
    """Scaled dot-product self-attention with exactly one head.

    Inputs and outputs have width ``config.embedding_dim``; queries, keys and
    values are projected down to ``config.d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.embedding_dim
        head_dim = config.d_model
        self.proj_q = nn.Linear(width, head_dim, bias=True)
        self.proj_k = nn.Linear(width, head_dim, bias=True)
        self.proj_v = nn.Linear(width, head_dim, bias=True)
        self.proj_out = nn.Linear(head_dim, width, bias=True)
        self.dropout = nn.Dropout(config.dropout)
        self.scale = math.sqrt(head_dim)

    def forward(self, x, attention_mask=None, return_attention=False):
        """Attend over ``x``.

        Args:
            x: 3-D tensor; the second axis is the sequence axis.
            attention_mask: optional tensor whose zero entries mark key
                positions that must not be attended to.
            return_attention: when true, the attention weights are returned
                as the second element instead of ``None``.
        """
        _, seq_len, _ = x.shape
        queries, keys, values = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale

        if self.config.causal:
            # Strict upper triangle = keys that lie in the future of a query.
            future = torch.ones(
                seq_len, seq_len, device=x.device, dtype=torch.bool
            ).triu(diagonal=1)
            scores = scores.masked_fill(future, float('-inf'))

        if attention_mask is not None:
            # Broadcast over the query axis so masked keys are hidden from
            # every query.
            blocked_keys = (attention_mask == 0).unsqueeze(1)
            scores = scores.masked_fill(blocked_keys, float('-inf'))

        attention_weights = self.dropout(F.softmax(scores, dim=-1))
        output = self.proj_out(torch.matmul(attention_weights, values))

        return output, (attention_weights if return_attention else None)

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, config):
        super().__init__()
        width = config.embedding_dim
        self.fc1 = nn.Linear(width, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``; the shape is preserved."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

class TransformerBlock(nn.Module):
    """One residual block: attention, then an optional feed-forward sub-layer.

    ``config.pre_ln`` selects where LayerNorm sits: before each sub-layer
    (pre-LN) or after each residual addition (post-LN).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.embedding_dim
        self.ln_attn = nn.LayerNorm(width)
        self.attn = SingleHeadAttention(config)
        # Without a feed-forward sub-layer both attributes stay None.
        self.ln_ff = nn.LayerNorm(width) if config.use_ff else None
        self.ff = FeedForward(config) if config.use_ff else None
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None, return_attention=False):
        """Return ``(new_x, attention_weights_or_None)``."""
        pre_norm = self.config.pre_ln

        # --- attention sub-layer ---
        attn_input = self.ln_attn(x) if pre_norm else x
        attn_out, attn_weights = self.attn(attn_input, attention_mask, return_attention)
        x = x + self.dropout(attn_out)
        if not pre_norm:
            x = self.ln_attn(x)

        # --- feed-forward sub-layer (optional) ---
        if self.ff is None:
            return x, attn_weights
        if pre_norm:
            x = x + self.dropout(self.ff(self.ln_ff(x)))
        else:
            x = self.ln_ff(x + self.dropout(self.ff(x)))
        return x, attn_weights

class InterpretableTransformer(nn.Module):
    """Small decoder-style transformer with optional one-hot positions.

    With ``config.concat_pos_emb`` the position of each token is a one-hot
    vector of length ``config.max_seq_len`` concatenated to the token
    embedding, so the residual stream has width
    ``d_model + max_seq_len``. Otherwise a learned position embedding is
    added to the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)

        if config.concat_pos_emb:
            # Row i of the identity matrix is the one-hot code of position i.
            self.register_buffer('position_embedding', torch.eye(config.max_seq_len))
        else:
            self.position_embedding = nn.Embedding(config.max_seq_len, config.d_model)

        self.dropout_embedding = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_final = nn.LayerNorm(config.embedding_dim)
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) weights, zero biases, zero padding-embedding row."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()

    def get_embeddings(self, input_ids):
        """Embed ``input_ids`` of shape (batch, seq_len).

        Returns a (batch, seq_len, embedding_dim) tensor.

        Raises:
            ValueError: if ``seq_len`` exceeds ``config.max_seq_len``.
        """
        batch_size, seq_len = input_ids.shape
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len "
                f"{self.config.max_seq_len}"
            )
        token_emb = self.token_embedding(input_ids)

        if self.config.concat_pos_emb:
            # Keep ALL max_seq_len columns: the layers expect a fixed width of
            # d_model + max_seq_len. Cropping the columns to seq_len as well
            # only works when seq_len == max_seq_len and breaks any shorter
            # input, e.g. during generation.
            pos_emb = self.position_embedding[:seq_len]
            pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)
            embeddings = torch.cat([token_emb, pos_emb], dim=-1)
        else:
            positions = torch.arange(seq_len, device=input_ids.device)
            pos_emb = self.position_embedding(positions)
            embeddings = token_emb + pos_emb

        return self.dropout_embedding(embeddings)

    def forward(self, input_ids, attention_mask=None, return_hidden_states=False, return_attention=False):
        """Run the model.

        Returns a dict with ``"logits"`` and, on request, ``"hidden_states"``
        (embeddings plus the output of every layer) and ``"attentions"`` (one
        entry per layer).
        """
        x = self.get_embeddings(input_ids)
        hidden_states = [x] if return_hidden_states else None
        attentions = [] if return_attention else None

        for layer in self.layers:
            x, attn_weights = layer(x, attention_mask, return_attention)
            if return_hidden_states:
                hidden_states.append(x)
            if return_attention:
                attentions.append(attn_weights)

        x = self.ln_final(x)
        logits = self.lm_head(x)

        output = {"logits": logits}
        if return_hidden_states:
            output["hidden_states"] = hidden_states
        if return_attention:
            output["attentions"] = attentions
        return output

    def generate(self, input_ids, max_new_tokens=32, temperature=1.0, do_sample=False,
                 eos_token_id=2):
        """Autoregressively extend ``input_ids``.

        Decoding stops after ``max_new_tokens`` tokens, when every sequence
        in the batch has just produced ``eos_token_id``, or when the sequence
        reaches ``config.max_seq_len``.

        Args:
            input_ids: (batch, seq_len) prompt token ids, without padding.
            max_new_tokens: upper bound on generated tokens.
            temperature: softmax temperature, used only when sampling.
            do_sample: sample from the distribution instead of taking argmax.
            eos_token_id: id that ends generation; the default 2 matches the
                EOS id of ``SymbolicOntologyTokenizer``.

        Note: the model is switched to eval mode and left in it.
        """
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # No position code exists beyond max_seq_len.
                if input_ids.size(1) >= self.config.max_seq_len:
                    break
                outputs = self(input_ids)
                logits = outputs["logits"][:, -1, :]
                if do_sample:
                    probs = F.softmax(logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if (next_token == eos_token_id).all():
                    break
        return input_ids

def count_parameters(model):
    """Return the number of scalar parameters of ``model`` that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# ============================================================================
# HELPER: Calculate required vocab size for a given depth
# ============================================================================

def get_vocab_requirements(depth: int, branching: int = 2, members_per_leaf: int = 2):
    """Return ``(num_concepts, num_entities)`` for a complete tree.

    Level ``i`` of the tree holds ``branching**i`` concepts, for
    ``i = 0 .. depth-1``; the last level holds the leaves, each of which
    owns ``members_per_leaf`` entities.
    """
    level_sizes = [branching ** level for level in range(depth)]
    concept_total = sum(level_sizes)
    leaf_total = branching ** (depth - 1)
    return concept_total, leaf_total * members_per_leaf

# Show how the vocabulary grows with tree depth for a binary tree
_rule = "-" * 45
print("Vocabulary requirements by depth (branching=2):")
print(_rule)
for d in range(2, 9):
    nc, ne = get_vocab_requirements(d, branching=2)
    print(f"  Depth {d}: {nc:3d} concepts, {ne:3d} entities")
print(_rule)
print("‚úì All classes defined (TRUE INDUCTION format)!")

## Cell 2: Configuration

In [20]:
#@title Training Configuration { display-mode: "form" }

#@markdown ## ⚠️ IMPORTANT: Set Tree Depth for Training
#@markdown Change this value to train different models:
#@markdown - `TREE_DEPTH = 2` → depth-2 model (should achieve ~98% accuracy)
#@markdown - `TREE_DEPTH = 3` → depth-3 model (should plateau at ~50% accuracy)

# Used as both min_depth and max_depth when the data is generated in Cell 3.
TREE_DEPTH = 3  #@param {type:"integer"}

#@markdown ---
#@markdown ### Data Settings
NUM_TRAIN_SAMPLES = 2000  #@param {type:"integer"}
NUM_VAL_SAMPLES = 400  #@param {type:"integer"}
# NOTE(review): Cell 3 reassigns both branching factors to 2 before any data
# is generated, so the values chosen here do not affect the dataset.
BRANCHING_FACTOR_MIN = 2  #@param {type:"integer"}
BRANCHING_FACTOR_MAX = 3  #@param {type:"integer"}

#@markdown ### Model Architecture
N_LAYERS = 4  #@param {type:"integer"}
# NOTE(review): passed to TransformerConfig, but the attention module defined
# in Cell 1 is single-head and does not read n_heads.
N_HEADS = 1  #@param {type:"integer"}
D_MODEL = 64  #@param {type:"integer"}
D_FF = 256  #@param {type:"integer"}

#@markdown ### Training Hyperparameters
BATCH_SIZE = 32  #@param {type:"integer"}
LEARNING_RATE = 1e-3  #@param {type:"number"}
NUM_EPOCHS = 1000  #@param {type:"integer"}
GRAD_CLIP = 1.0  #@param {type:"number"}

#@markdown ### Logging
EVAL_INTERVAL = 50  #@param {type:"integer"}
PLOT_INTERVAL = 100  #@param {type:"integer"}

# Derived settings
# SymbolicOntologyDataset pads every example to MAX_INPUT_LEN + MAX_TARGET_LEN
# tokens, which is also the model's max_seq_len.
MAX_INPUT_LEN = 512
MAX_TARGET_LEN = 32
MAX_SEQ_LEN = MAX_INPUT_LEN + MAX_TARGET_LEN

# Device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
print(f"\n{'='*50}")
print(f"TRAINING CONFIGURATION")
print(f"{'='*50}")
print(f"Tree Depth: {TREE_DEPTH} (min={TREE_DEPTH}, max={TREE_DEPTH})")
# This prints the requested size; the actual number of training samples is
# determined (and printed) in Cell 3 after filtering.
print(f"Training samples: {NUM_TRAIN_SAMPLES}")
print(f"Model: {N_LAYERS} layers, {N_HEADS} head(s), d_model={D_MODEL}")
print(f"Output file will be: depth{TREE_DEPTH}_trained.pt")
print(f"{'='*50}")

# Seed for reproducibility
# (generate_dataset reseeds `random` itself with its own `seed` argument)
torch.manual_seed(42)
random.seed(42)

Using device: cpu

TRAINING CONFIGURATION
Tree Depth: 3 (min=3, max=3)
Training samples: 2000
Model: 4 layers, 1 head(s), d_model=64
Output file will be: depth3_trained.pt


## Cell 3: Generate Data and Initialize Model

In [None]:
# FAIR COMPARISON: Filter ALL depths to depth_of_truth=0
# This tests multi-hop reasoning:
#   - depth-2, dot=0: 1 hop
#   - depth-3, dot=0: 2 hops
#   - depth-4, dot=0: 3 hops
#   - depth-5+: 4+ hops
FILTER_DEPTH_OF_TRUTH = 0  # Always filter to root-level targets

# Auto-calculate vocabulary requirements
# (branching=2 matches the branching factor that is forced further below)
required_concepts, required_entities = get_vocab_requirements(TREE_DEPTH, branching=2)
print(f"Depth-{TREE_DEPTH} requires: {required_concepts} concepts, {required_entities} entities")

# Set vocabulary sizes with some margin
VOCAB_CONCEPTS = max(30, required_concepts + 5)
VOCAB_ENTITIES = max(30, required_entities + 5)
print(f"Using vocabulary: {VOCAB_CONCEPTS} concepts, {VOCAB_ENTITIES} entities")

# Force branching=2 for all depths (consistent experiment)
# This overrides the BRANCHING_FACTOR_* values chosen in the configuration cell.
BRANCHING_FACTOR_MIN = 2
BRANCHING_FACTOR_MAX = 2
print(f"Branching factor fixed at 2 for consistency")

# Adaptive oversampling - deeper trees have rarer depth_of_truth=0
# Maps tree depth -> multiplier on the number of raw samples to generate
# before filtering; depths missing from the table fall back to 100.
oversample_factors = {2: 4, 3: 10, 4: 30, 5: 100, 6: 200, 7: 400, 8: 800}
oversample_factor = oversample_factors.get(TREE_DEPTH, 100)

print(f"\nGenerating training data (depth={TREE_DEPTH}, TRUE INDUCTION format)...")
print(f"Filtering to depth_of_truth={FILTER_DEPTH_OF_TRUTH} for fair comparison")
print(f"Using oversample_factor={oversample_factor}")

# generate_dataset reseeds the global `random` module with `seed`.
raw_train = generate_dataset(
    NUM_TRAIN_SAMPLES * oversample_factor,
    min_depth=TREE_DEPTH,
    max_depth=TREE_DEPTH,
    min_branch=BRANCHING_FACTOR_MIN,
    max_branch=BRANCHING_FACTOR_MAX,
    seed=42
)

# Filter to depth_of_truth=0 only
# NOTE(review): the filtered list is never truncated to NUM_TRAIN_SAMPLES, so
# the training set can be larger than requested -- confirm this is intended.
train_samples = [s for s in raw_train if s['metadata']['depth_of_truth'] == FILTER_DEPTH_OF_TRUTH]
print(f"Filtered: {len(train_samples)} samples with depth_of_truth={FILTER_DEPTH_OF_TRUTH} (from {len(raw_train)} raw)")

# Check if we have enough samples
# Top up with further seeds until the requested count is reached or the seed
# list is exhausted.
if len(train_samples) < NUM_TRAIN_SAMPLES:
    print(f"WARNING: Only got {len(train_samples)} samples, wanted {NUM_TRAIN_SAMPLES}")
    print(f"Generating more samples with different seeds...")
    for extra_seed in [999, 1234, 5678, 9999, 11111, 22222, 33333, 44444]:
        if len(train_samples) >= NUM_TRAIN_SAMPLES:
            break
        extra_raw = generate_dataset(
            NUM_TRAIN_SAMPLES * oversample_factor,
            min_depth=TREE_DEPTH,
            max_depth=TREE_DEPTH,
            min_branch=BRANCHING_FACTOR_MIN,
            max_branch=BRANCHING_FACTOR_MAX,
            seed=extra_seed
        )
        extra_filtered = [s for s in extra_raw if s['metadata']['depth_of_truth'] == FILTER_DEPTH_OF_TRUTH]
        train_samples.extend(extra_filtered)
        print(f"  Added {len(extra_filtered)} more samples (total: {len(train_samples)})")

# NOTE(review): validation data comes from a different seed but is not
# de-duplicated against the training set, so identical prompts may appear in
# both splits. There is also no top-up loop for validation.
print(f"\nGenerating validation data (depth={TREE_DEPTH})...")
raw_val = generate_dataset(
    NUM_VAL_SAMPLES * oversample_factor,
    min_depth=TREE_DEPTH,
    max_depth=TREE_DEPTH,
    min_branch=BRANCHING_FACTOR_MIN,
    max_branch=BRANCHING_FACTOR_MAX,
    seed=123
)

# Filter to depth_of_truth=0 only
val_samples = [s for s in raw_val if s['metadata']['depth_of_truth'] == FILTER_DEPTH_OF_TRUTH]
print(f"Filtered: {len(val_samples)} samples with depth_of_truth={FILTER_DEPTH_OF_TRUTH} (from {len(raw_val)} raw)")

# Show reasoning test info
# With the target at the root, a leaf is TREE_DEPTH - 1 subtype edges away.
hops_required = TREE_DEPTH - 1
print(f"\n{'='*50}")
print(f"MULTI-HOP REASONING TEST")
print(f"{'='*50}")
print(f"Tree depth: {TREE_DEPTH}")
print(f"Target depth: {FILTER_DEPTH_OF_TRUTH} (root)")
print(f"Hops required: {hops_required}")
print(f"Branching factor: {BRANCHING_FACTOR_MIN}")
print(f"Training samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")
print(f"{'='*50}")

# Show a sample
# NOTE(review): when no samples were produced only a message is printed; the
# cell then continues and builds the datasets and loaders from empty lists.
if train_samples:
    sample = train_samples[0]
    print(f"\n[SAMPLE INPUT] (first 400 chars):\n{sample['input'][:400]}...")
    print(f"\n[TARGET]: {sample['target']}")
else:
    print("\nERROR: No training samples generated!")
    print("This depth may be too deep to reliably generate depth_of_truth=0 samples.")

# Create tokenizer with appropriate vocabulary
# max_properties keeps its default (15); generate_dataset uses 5 properties.
tokenizer = SymbolicOntologyTokenizer(
    max_concepts=VOCAB_CONCEPTS,
    max_entities=VOCAB_ENTITIES,
    max_seq_len=MAX_SEQ_LEN
)
print(f"\nVocabulary size: {tokenizer.vocab_size}")

# Create datasets
train_dataset = SymbolicOntologyDataset(train_samples, tokenizer, MAX_INPUT_LEN, MAX_TARGET_LEN)
val_dataset = SymbolicOntologyDataset(val_samples, tokenizer, MAX_INPUT_LEN, MAX_TARGET_LEN)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Create model with matching vocabulary
# max_seq_len must equal the padded example length produced by the dataset.
model_config = TransformerConfig(
    vocab_size=tokenizer.vocab_size,
    max_seq_len=MAX_SEQ_LEN,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_model=D_MODEL,
    d_ff=D_FF,
    pad_token_id=tokenizer.pad_token_id,
)

model = InterpretableTransformer(model_config).to(DEVICE)
n_params = count_parameters(model)
print(f"\nModel parameters: {n_params:,} ({n_params/1e6:.2f}M)")

In [None]:
# Sanity-check the depth_of_truth filter applied when the data was generated
from collections import Counter

train_dot = Counter(s['metadata']['depth_of_truth'] for s in train_samples)
val_dot = Counter(s['metadata']['depth_of_truth'] for s in val_samples)
print(f"Training depth_of_truth distribution: {train_dot}")
print(f"Validation depth_of_truth distribution: {val_dot}")

# Every sample in both splits must target the root (depth_of_truth == 0)
assert not any(s['metadata']['depth_of_truth'] != 0 for s in train_samples), "Filtering failed for train!"
assert not any(s['metadata']['depth_of_truth'] != 0 for s in val_samples), "Filtering failed for val!"
print(f"‚úì All samples have depth_of_truth=0 (root-level targets)")

# Recap of the experiment that is about to run
print("\n".join([
    f"\n{'='*50}",
    f"EXPERIMENTAL SETUP SUMMARY",
    f"{'='*50}",
    f"Tree depth: {TREE_DEPTH}",
    f"Hops to traverse: {TREE_DEPTH - 1}",
    f"Training samples: {len(train_samples)}",
    f"Validation samples: {len(val_samples)}",
    f"\nThis tests whether the model can perform {TREE_DEPTH - 1}-hop",
    f"upward traversal in the ontology tree to find the root concept.",
    f"{'='*50}",
]))

## Cell 4: Training Utilities

In [None]:
def compute_loss(model, batch, device):
    """Return the next-token cross-entropy loss for one batch.

    Args:
        model: callable invoked as ``model(input_ids, attention_mask=...)``
            that returns a mapping with a ``"logits"`` entry.
        batch: mapping with ``"input_ids"``, ``"labels"`` and
            ``"attention_mask"`` tensors.
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        Scalar tensor: cross-entropy between the logits at position ``t`` and
        the label at position ``t + 1``; labels equal to ``-100`` are ignored.
    """
    tokens, targets, mask = (
        batch[key].to(device) for key in ("input_ids", "labels", "attention_mask")
    )

    logits = model(tokens, attention_mask=mask)["logits"]
    vocab_dim = logits.size(-1)

    # Drop the last prediction and the first label so that each remaining
    # logit row lines up with the token that follows it, then flatten.
    flat_logits = logits[:, :-1, :].reshape(-1, vocab_dim)
    flat_targets = targets[:, 1:].reshape(-1)

    return F.cross_entropy(flat_logits, flat_targets, ignore_index=-100)

def compute_accuracy(model, batch, device):
    """Compute token-level and sequence-level next-token accuracy for a batch.

    Args:
        model: callable invoked as ``model(input_ids, attention_mask=...)``
            that returns a mapping with a ``"logits"`` entry.
        batch: mapping with ``"input_ids"``, ``"labels"`` and
            ``"attention_mask"`` tensors.
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        dict with
        ``"token_acc"``: fraction of scored positions (label != -100) whose
        argmax prediction equals the label, or 0.0 when nothing is scored;
        ``"seq_acc"``: fraction of rows in which every scored position is
        correct (a row with no scored positions counts as correct).
    """
    tokens = batch["input_ids"].to(device)
    targets = batch["labels"].to(device)
    mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        predicted = model(tokens, attention_mask=mask)["logits"].argmax(dim=-1)

    # The prediction at position t is compared with the label at position t + 1.
    guesses = predicted[:, :-1]
    expected = targets[:, 1:]
    scored = expected != -100
    hits = (guesses == expected) & scored

    n_scored = scored.sum().item()
    token_acc = hits.sum().item() / n_scored if n_scored > 0 else 0.0

    # A row is correct when no scored position is wrong; unscored positions
    # are treated as satisfied, so a row with nothing scored is correct.
    row_ok = (hits | ~scored).all(dim=1)
    seq_acc = row_ok.sum().item() / row_ok.size(0)

    return {"token_acc": token_acc, "seq_acc": seq_acc}

def evaluate(model, dataloader, device, max_batches=None):
    """Average loss and accuracy of ``model`` over batches from ``dataloader``.

    The model is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if a batch raises.

    Args:
        model: module to evaluate; passed through to ``compute_loss`` and
            ``compute_accuracy``.
        dataloader: iterable of batches.
        device: device forwarded to the per-batch helpers.
        max_batches: if truthy, stop after this many batches; ``None`` (or 0)
            evaluates every batch.

    Returns:
        dict with ``"loss"``, ``"token_acc"`` and ``"seq_acc"``. Each value is
        the unweighted mean of the per-batch values, so a smaller final batch
        counts as much as a full one.

    Raises:
        ValueError: if the dataloader yields no batches (there is nothing to
            average).
    """
    was_training = model.training
    model.eval()
    total_loss, total_token_acc, total_seq_acc, n = 0.0, 0.0, 0.0, 0

    try:
        with torch.no_grad():
            for batch in dataloader:
                loss = compute_loss(model, batch, device)
                acc = compute_accuracy(model, batch, device)
                total_loss += loss.item()
                total_token_acc += acc["token_acc"]
                total_seq_acc += acc["seq_acc"]
                n += 1
                if max_batches and n >= max_batches:
                    break
    finally:
        # Restore whatever mode the caller had instead of forcing train mode.
        model.train(was_training)

    if n == 0:
        raise ValueError(
            "evaluate() received no batches; the dataloader is empty, so no metrics can be computed"
        )

    return {"loss": total_loss / n, "token_acc": total_token_acc / n, "seq_acc": total_seq_acc / n}

# Text-based logger (no inline plots to avoid crashes)
class TextLogger:
    """Record training metrics, print them as text, and save curves to a file.

    Attributes:
        history: dict of parallel lists keyed by 'step', 'train_loss',
            'val_loss', 'train_seq_acc' and 'val_seq_acc'. The two validation
            lists hold ``None`` for steps logged without validation metrics.
        plot_dir: directory prefix used to build the default plot filename.
    """

    def __init__(self, plot_dir="."):
        self.history = {'step': [], 'train_loss': [], 'val_loss': [], 'train_seq_acc': [], 'val_seq_acc': []}
        self.plot_dir = plot_dir

    def log(self, step, train_metrics, val_metrics=None):
        """Append one row to the history and print a one-line summary.

        Args:
            step: global training step of this measurement.
            train_metrics: mapping with 'loss' and 'seq_acc'.
            val_metrics: optional mapping with 'loss' and 'seq_acc'; when it
                is missing (or empty) ``None`` is stored for the validation
                columns and only the training part is printed.
        """
        self.history['step'].append(step)
        self.history['train_loss'].append(train_metrics['loss'])
        self.history['train_seq_acc'].append(train_metrics['seq_acc'])
        self.history['val_loss'].append(val_metrics['loss'] if val_metrics else None)
        self.history['val_seq_acc'].append(val_metrics['seq_acc'] if val_metrics else None)

        # Print text-based stats
        msg = f"Step {step:6d} | Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['seq_acc']:.1%}"
        if val_metrics:
            msg += f" | Val Loss: {val_metrics['loss']:.4f} | Val Acc: {val_metrics['seq_acc']:.1%}"
            # Status indicator
            t, v = train_metrics['seq_acc'], val_metrics['seq_acc']
            if v > 0.9:
                # Check mark written as an escape so the source stays ASCII
                # and cannot be mangled by a wrong file encoding.
                msg += " \u2713"
            elif t > 0.9 and v < 0.5:
                msg += " (grokking?)"
        print(msg)

    def _validation_series(self, key):
        """Return (steps, values) for history[key], skipping entries stored as None."""
        pairs = [(s, v) for s, v in zip(self.history['step'], self.history[key]) if v is not None]
        return [s for s, _ in pairs], [v for _, v in pairs]

    def save_plot(self, filename=None):
        """Save training curves to file instead of displaying inline.

        Draws three panels: loss (log scale), sequence accuracy with 90% / 50%
        reference lines, and the train-minus-validation accuracy gap.

        Args:
            filename: output path; defaults to
                ``{plot_dir}/training_curves_depth{TREE_DEPTH}.png``.
        """
        if filename is None:
            filename = f"{self.plot_dir}/training_curves_depth{TREE_DEPTH}.png"

        steps = self.history['step']
        val_loss_steps, val_loss = self._validation_series('val_loss')
        val_acc_steps, val_acc = self._validation_series('val_seq_acc')

        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        try:
            loss_ax, acc_ax, gap_ax = axes

            # Loss
            loss_ax.plot(steps, self.history['train_loss'], 'b-', label='Train', lw=2)
            if val_loss:
                loss_ax.plot(val_loss_steps, val_loss, 'r-', label='Val', lw=2)
            loss_ax.set_xlabel('Step')
            loss_ax.set_ylabel('Loss')
            loss_ax.set_title('Loss')
            loss_ax.legend()
            loss_ax.set_yscale('log')
            loss_ax.grid(True, alpha=0.3)

            # Sequence Accuracy
            acc_ax.plot(steps, self.history['train_seq_acc'], 'b-', label='Train', lw=2)
            if val_acc:
                acc_ax.plot(val_acc_steps, val_acc, 'r-', label='Val', lw=2)
            acc_ax.axhline(y=0.9, color='g', ls='--', alpha=0.5, label='90%')
            acc_ax.axhline(y=0.5, color='orange', ls='--', alpha=0.5, label='50%')
            acc_ax.set_xlabel('Step')
            acc_ax.set_ylabel('Accuracy')
            acc_ax.set_title('Sequence Accuracy')
            acc_ax.set_ylim(-0.05, 1.05)
            acc_ax.grid(True, alpha=0.3)
            # The legend is built after the reference lines are drawn;
            # otherwise their '90%' / '50%' labels would not be included.
            acc_ax.legend()

            # Gap
            if val_acc:
                gaps = [
                    t - v
                    for t, v in zip(self.history['train_seq_acc'], self.history['val_seq_acc'])
                    if v is not None
                ]
                colors = ['green' if g < 0.2 else 'orange' if g < 0.5 else 'red' for g in gaps]
                # Bar width is in x-axis (step) units, so size it from the
                # smallest gap between logged steps; a fixed width of ~1 step
                # is invisible when evaluations are hundreds of steps apart.
                spacing = min(
                    (b - a for a, b in zip(val_acc_steps, val_acc_steps[1:]) if b > a),
                    default=1,
                )
                gap_ax.bar(val_acc_steps, gaps, color=colors, alpha=0.7, width=0.8 * spacing)
            gap_ax.set_xlabel('Step')
            gap_ax.set_ylabel('Train - Val')
            gap_ax.set_title('Generalization Gap')
            gap_ax.grid(True, alpha=0.3)

            fig.tight_layout()
            fig.savefig(filename, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        print(f"Plot saved to {filename}")

# Module-level logger used by the training cell; plots are written to the
# current working directory.
logger = TextLogger(plot_dir=".")
# Check mark written as an escape so the source stays ASCII and cannot be
# mangled by a wrong file encoding.
print("\u2713 Training utilities ready (text-based logging, plots saved to files)")

## Cell 5: Train! 🚀

In [None]:
# Optimizer and scheduler
# The scheduler is stepped once per batch in the loop below, so T_max is the
# total number of optimizer steps (batches per epoch * epochs) and the cosine
# schedule spans the whole run, bottoming out at eta_min.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
total_steps = len(train_loader) * NUM_EPOCHS
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)

# How often to save plot to file (less frequent to avoid overhead)
SAVE_PLOT_INTERVAL = 1000

print(f"Total steps: {total_steps}")
print(f"Training for {NUM_EPOCHS} epochs...")
print(f"Logging every {EVAL_INTERVAL} steps, saving plot every {SAVE_PLOT_INTERVAL} steps")
print("="*80)

# global_step counts optimizer updates across all epochs; best_val_acc tracks
# the highest validation sequence accuracy seen at any evaluation point.
global_step = 0
best_val_acc = 0.0

for epoch in range(NUM_EPOCHS):
    model.train()
    
    for batch in train_loader:
        # Forward pass and loss for this batch.
        loss = compute_loss(model, batch, DEVICE)
        
        # Standard update: clear old gradients, backpropagate, clip the global
        # gradient norm to GRAD_CLIP, then step the optimizer and the
        # per-batch learning-rate schedule.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        scheduler.step()
        
        global_step += 1
        
        # Evaluate and log (text only)
        # Training metrics are estimated on at most 20 batches of the training
        # loader; validation metrics use the whole validation loader.
        if global_step % EVAL_INTERVAL == 0:
            train_metrics = evaluate(model, train_loader, DEVICE, max_batches=20)
            val_metrics = evaluate(model, val_loader, DEVICE)
            logger.log(global_step, train_metrics, val_metrics)
            
            # Checkpoint only on a strict improvement of validation sequence
            # accuracy; the file is overwritten each time.
            if val_metrics['seq_acc'] > best_val_acc:
                best_val_acc = val_metrics['seq_acc']
                torch.save(model.state_dict(), 'best_model.pt')
        
        # Save plot to file periodically (not displayed inline)
        if global_step % SAVE_PLOT_INTERVAL == 0:
            logger.save_plot()

# Save final plot
logger.save_plot()

print("\n" + "="*80)
print("Training complete!")
print(f"Best validation accuracy: {best_val_acc:.1%}")
print(f"Final plot saved to: training_curves_depth{TREE_DEPTH}.png")

## Cell 6: Final Results

In [None]:
# Final evaluation
# Evaluate on the full training and validation loaders (no batch cap) and
# report the numbers next to the best validation accuracy seen during training.
train_metrics = evaluate(model, train_loader, DEVICE)
val_metrics = evaluate(model, val_loader, DEVICE)

print("="*80)
print("FINAL RESULTS")
print("="*80)
print(f"Tree Depth: {TREE_DEPTH} ({TREE_DEPTH - 1}-hop reasoning)")
print(f"Train: Loss={train_metrics['loss']:.4f}, Seq Acc={train_metrics['seq_acc']:.1%}")
print(f"Val:   Loss={val_metrics['loss']:.4f}, Seq Acc={val_metrics['seq_acc']:.1%}")
print(f"\nBest Val Seq Acc: {best_val_acc:.1%}")
print("="*80)

# Display the final saved plot
from IPython.display import Image, display
plot_path = f"training_curves_depth{TREE_DEPTH}.png"
try:
    display(Image(filename=plot_path))
except Exception as exc:
    # Inline display is best-effort: fall back to printing the path and say
    # why. Catching Exception rather than using a bare `except:` lets
    # KeyboardInterrupt / SystemExit propagate.
    print(f"Plot saved at: {plot_path}")
    print(f"(inline display unavailable: {exc})")

## Cell 7: Test Generation (Optional)

In [None]:
# Test on a sample
# Qualitative check: prompt the model with the first validation sample up to
# the answer marker, let it generate, and print the result next to the
# expected target tokens.
model.eval()
sample = val_dataset[0]
input_ids = sample['input_ids'].unsqueeze(0).to(DEVICE)

# Find [ANSWER] position
answer_token_id = tokenizer.token_to_id['[ANSWER]']
answer_pos = (input_ids[0] == answer_token_id).nonzero()

if len(answer_pos) > 0:
    # NOTE(review): the +2 presumably keeps the [ANSWER] token plus one
    # following token in the prompt -- confirm against the dataset format.
    prompt_end = answer_pos[0].item() + 2
    prompt = input_ids[:, :prompt_end]
    
    # Generation is inference only, so run it without tracking gradients.
    with torch.no_grad():
        generated = model.generate(prompt, max_new_tokens=20)
    
    print("=== INPUT ===")
    print(tokenizer.decode(prompt[0].tolist()))
    print("\n=== GENERATED ===")
    gen_part = generated[0, prompt_end:].tolist()
    print(tokenizer.decode(gen_part))
    print("\n=== EXPECTED ===")
    labels = sample['labels'].tolist()
    # Keep only real target tokens: drop ignored (-100) and padding positions.
    target_ids = [tok for tok in labels if tok != -100 and tok != tokenizer.pad_token_id]
    print(tokenizer.decode(target_ids))
else:
    # Report the problem instead of printing nothing at all.
    print("No [ANSWER] token found in the first validation sample; skipping generation.")

## Cell 8: Download Model

In [None]:
# Save with depth-specific filename
output_filename = f'depth{TREE_DEPTH}_trained.pt'

# Full checkpoint: current model weights plus the configuration, training
# history and experiment metadata.
# NOTE(review): 'config' stores the model_config object itself, so loading
# this file presumably needs its class to be importable -- confirm before
# sharing the checkpoint.
torch.save({
    'model_state_dict': model.state_dict(),
    'config': model_config,
    'history': logger.history,
    'best_val_acc': best_val_acc,
    'tree_depth': TREE_DEPTH,
    'hops_required': TREE_DEPTH - 1,
    'data_format': 'true_induction',
    'filter_depth_of_truth': FILTER_DEPTH_OF_TRUTH,
}, output_filename)

print(f"Model saved to {output_filename}")
print(f"  Tree depth: {TREE_DEPTH}")
print(f"  Hops required: {TREE_DEPTH - 1}")
print(f"  Best validation accuracy: {best_val_acc:.1%}")
print(f"  Data format: True Induction (depth_of_truth={FILTER_DEPTH_OF_TRUTH})")

# Also save just the state dict for easier loading
state_dict_filename = f'depth{TREE_DEPTH}_trained_state_dict.pt'
torch.save(model.state_dict(), state_dict_filename)
print(f"\nState dict also saved to {state_dict_filename}")

# Download (Colab only)
# Only a failed import means "not in Colab". A failing download inside Colab
# is reported as such instead of being mislabelled, and neither case aborts
# the cell because the files are already on disk.
try:
    from google.colab import files
except ImportError:
    print(f"\nNot running in Colab - files saved locally")
else:
    try:
        files.download(output_filename)
    except Exception as exc:
        print(f"\nDownload failed ({exc}) - files saved locally")
    else:
        # Check mark written as an escape so the source stays ASCII and
        # cannot be mangled by a wrong file encoding.
        print(f"\n\u2713 {output_filename} downloaded!")