# Introduction to Weight-of-Thought

This notebook introduces Weight-of-Thought (WoT), a neural reasoning framework that goes beyond traditional Chain-of-Thought (CoT) reasoning by representing reasoning as an interconnected graph of nodes rather than a linear sequence of steps.
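A toy comparison shows what the graph structure changes. The stand-alone sketch below is not part of the `wot` package, and `chain_adjacency`, `complete_adjacency` and `hops` are hypothetical helpers. It builds an 8-node chain and an 8-node fully connected graph, then uses breadth-first search to count how many hops information needs to travel between two nodes.

```python
from collections import deque


def chain_adjacency(n):
    """Linear chain: node i passes information only to node i + 1."""
    return {i: [i + 1] if i + 1 < n else [] for i in range(n)}


def complete_adjacency(n):
    """Fully connected directed graph without self-loops."""
    return {i: [j for j in range(n) if j != i] for i in range(n)}


def hops(adj, source, target):
    """Breadth-first search: minimum number of edges from source to target."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node == target:
            return dist[node]
        for nxt in adj[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return None  # target is unreachable from source


n = 8
chain, graph = chain_adjacency(n), complete_adjacency(n)
print("chain edges:", sum(len(v) for v in chain.values()))  # chain edges: 7
print("graph edges:", sum(len(v) for v in graph.values()))  # graph edges: 56
print("hops 0 -> 7 (chain):", hops(chain, 0, 7))            # 7
print("hops 0 -> 7 (graph):", hops(graph, 0, 7))            # 1
print("hops 7 -> 0 (chain):", hops(chain, 7, 0))            # None
```

In the chain, information from the first node needs seven hops to reach the last one and can never flow backward; in the fully connected graph every node reaches every other node in a single hop, at the cost of 56 directed connections instead of 7.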

## Setup and Imports

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import seaborn as sns
import torch

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Make the weight-of-thought package importable. This assumes the notebook
# is run from a directory one level below the repository root.
sys.path.append('..')

# Import WoT components
from wot.models import WOTReasoner

## The Weight-of-Thought Architecture

Weight-of-Thought is composed of several key components:

1. **Language Encoder**: Processes input text using a transformer-based model
2. **Node Network**: A set of specialized nodes that process and exchange information
3. **Message Passing System**: Allows nodes to share information through weighted connections
4. **Reasoning Steps**: Sequential reasoning layers that refine the reasoning representation
5. **Task-Specific Output Heads**: Specialized outputs for different reasoning tasks
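To make the message-passing component concrete, here is a minimal NumPy sketch of one generic round of attention-weighted message passing. It is not taken from the `wot` package: the scaled dot-product scoring, the softmax over senders, the `tanh` residual update, and the names `message_passing_round` and `w_msg` are all illustrative assumptions.

```python
import numpy as np


def softmax(x, axis):
    # Subtract the max for numerical stability; exp(-inf) becomes 0
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def message_passing_round(h, w_msg):
    """One hypothetical round of weighted message passing.

    h:     (num_nodes, hidden_dim) node states
    w_msg: (hidden_dim, hidden_dim) message transform (illustrative)
    Returns updated node states and the (from, to) edge-weight matrix.
    """
    num_nodes, hidden_dim = h.shape
    # Scaled dot-product compatibility between every pair of nodes
    scores = (h @ h.T) / np.sqrt(hidden_dim)
    # Forbid self-loops by masking the diagonal before normalizing
    scores[np.eye(num_nodes, dtype=bool)] = -np.inf
    # edges[i, j] = weight of the message from node i to node j;
    # normalizing each column means every receiver mixes its incoming messages
    edges = softmax(scores, axis=0)
    messages = h @ w_msg            # what each node sends
    incoming = edges.T @ messages   # row j = weighted sum over senders i
    # Residual update keeps each node's previous state in the mix
    return np.tanh(h + incoming), edges


rng = np.random.default_rng(0)
h = rng.normal(size=(8, 16))                  # 8 nodes, hidden_dim 16
w_msg = rng.normal(scale=0.1, size=(16, 16))  # random stand-in for learned weights
for _ in range(3):                            # a few rounds of exchange
    h, edges = message_passing_round(h, w_msg)
print(edges.sum(axis=0))  # each column sums to 1 (up to rounding)
print(np.diag(edges))     # self-loops are masked, so the diagonal is 0
```

In this sketch `edges[i, j]` weights the message from node i to node j, so every column sums to 1 and the diagonal is zero.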

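The cell below draws a schematic of this architecture. The edge weights between reasoning nodes are random; they only illustrate that connections carry different weights and are not taken from a trained model.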

In [None]:
# Create a simple visualization of the WoT architecture

# Create a directed graph
G = nx.DiGraph()

# Add nodes
input_node = "Input Text"
encoder_node = "Language Encoder"
reasoning_nodes = [f"Node {i+1}" for i in range(8)]
reasoning_steps = [f"Reasoning\nStep {i+1}" for i in range(4)]
output_node = "Task-Specific Output"

G.add_node(input_node, pos=(0.0, 0.5))
G.add_node(encoder_node, pos=(0.15, 0.5))

# Position reasoning nodes on a circle, rotated by half a slot so that no
# node overlaps the encoder or the reasoning steps on the horizontal axis
node_radius = 0.18
node_center = (0.45, 0.5)
for i, node in enumerate(reasoning_nodes):
    angle = 2 * np.pi * (i + 0.5) / len(reasoning_nodes)
    x = node_center[0] + node_radius * np.cos(angle)
    y = node_center[1] + node_radius * np.sin(angle)
    G.add_node(node, pos=(x, y))

# Position reasoning steps in a row to the right of the circle, spaced
# widely enough that they do not overlap one another
for i, step in enumerate(reasoning_steps):
    G.add_node(step, pos=(0.8 + i * 0.15, 0.5))

G.add_node(output_node, pos=(1.45, 0.5))

# Add edges
G.add_edge(input_node, encoder_node)
G.add_edge(encoder_node, reasoning_nodes[0])

# Connect reasoning nodes (fully connected)
for i, source in enumerate(reasoning_nodes):
    for j, target in enumerate(reasoning_nodes):
        if i != j:  # No self-loops
            G.add_edge(source, target, weight=np.random.uniform(0.1, 1.0))

# Connect last node to first reasoning step
G.add_edge(reasoning_nodes[-1], reasoning_steps[0])

# Connect reasoning steps in sequence
for i in range(len(reasoning_steps) - 1):
    G.add_edge(reasoning_steps[i], reasoning_steps[i+1])

# Connect last reasoning step to output
G.add_edge(reasoning_steps[-1], output_node)

# Draw the network
pos = nx.get_node_attributes(G, 'pos')
plt.figure(figsize=(12, 8))

# Draw different node groups with different colors
nx.draw_networkx_nodes(G, pos, nodelist=[input_node], node_color='lightblue', node_size=2000, alpha=0.8)
nx.draw_networkx_nodes(G, pos, nodelist=[encoder_node], node_color='lightgreen', node_size=2000, alpha=0.8)
nx.draw_networkx_nodes(G, pos, nodelist=reasoning_nodes, node_color='salmon', node_size=1500, alpha=0.8)
nx.draw_networkx_nodes(G, pos, nodelist=reasoning_steps, node_color='lightgray', node_size=1500, alpha=0.8)
nx.draw_networkx_nodes(G, pos, nodelist=[output_node], node_color='gold', node_size=2000, alpha=0.8)

# Draw edges with varying width based on weight
edge_weights = [G[u][v].get('weight', 0.5) for u, v in G.edges()]
nx.draw_networkx_edges(G, pos, width=edge_weights, alpha=0.6, arrowsize=10, arrowstyle='->')

# Draw labels
nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')

plt.axis('off')
plt.title('Weight-of-Thought Architecture', fontsize=16)
plt.tight_layout()
plt.show()

## Loading a Pre-trained Model

Let's load a pre-trained WoT model, if one is available, and examine its configuration.

In [None]:
# Initialize a WOT reasoner
wot_reasoner = WOTReasoner()

# Load pre-trained model if available
model_path = '../results/models/wot_model_final.pt'
if os.path.exists(model_path):
    print(f"Loading pre-trained model from {model_path}")
    wot_reasoner.load_model(model_path)
else:
    print("No pre-trained model found. Using a newly initialized (untrained) model.")

# Print model information
print("\nModel Configuration:")
print(f"Hidden dimension: {wot_reasoner.wot_model.hidden_dim}")
print(f"Number of nodes: {wot_reasoner.wot_model.num_nodes}")
print(f"Number of reasoning steps: {wot_reasoner.wot_model.num_reasoning_steps}")
print(f"Device: {wot_reasoner.device}")

## Reasoning Tasks

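The WoT reasoner is queried through `infer`, which takes a question and a task type and returns a predicted answer. Let's try five examples covering syllogisms, number sequences, algebra, combinatorics, and geometry: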

In [None]:
# Example reasoning tasks
examples = [
    {
        'question': 'If all Bloops are Razzies and all Razzies are Wazzies, are all Bloops definitely Wazzies? Answer with Yes or No.',
        'type': 'syllogism',
        'answer': 'Yes',
        'description': 'Syllogistic reasoning with transitive property'
    },
    {
        'question': 'What is the next number in the sequence: 2, 4, 6, 8, 10, ...?',
        'type': 'math_sequence',
        'answer': '12',
        'description': 'Pattern recognition in arithmetic sequence'
    },
    {
        'question': 'John has 3 times as many apples as Mary. Together, they have 40 apples. How many apples does John have?',
        'type': 'algebra',
        'answer': '30',
        'description': 'Algebraic word problem with ratio'
    },
    {
        'question': 'In a room of 5 people, everyone shakes hands with everyone else exactly once. How many handshakes are there in total?',
        'type': 'combinatorics',
        'answer': '10',
        'description': 'Combinatorial counting problem'
    },
    {
        'question': 'Is every square a rectangle? Answer with Yes or No.',
        'type': 'geometry',
        'answer': 'Yes',
        'description': 'Geometric properties and classification'
    }
]

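def answers_match(predicted, expected):
    """Compare numerically when both answers parse as numbers, else as case-insensitive text."""
    predicted, expected = str(predicted).strip(), str(expected).strip()
    try:
        return abs(float(predicted) - float(expected)) < 1e-6
    except ValueError:
        return predicted.lower() == expected.lower()

# Run inference on each example
for i, example in enumerate(examples):
    question = example['question']
    task_type = example['type']
    true_answer = example['answer']

    print(f"\nExample {i+1}: {example['description']}")
    print(f"Task Type: {task_type}")
    print(f"Question: {question}")
    print(f"True Answer: {true_answer}")

    # Run inference
    predicted_answer = wot_reasoner.infer(question, task_type)
    print(f"Predicted Answer: {predicted_answer}")
    print(f"Correct: {answers_match(predicted_answer, true_answer)}")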
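The reference answers for the three numeric examples can be checked independently of the model with a few lines of standard-library Python:

```python
from fractions import Fraction
from math import comb

# Arithmetic sequence: constant difference, so next term = last term + difference
seq = [2, 4, 6, 8, 10]
diff = seq[1] - seq[0]
assert all(b - a == diff for a, b in zip(seq, seq[1:]))
next_term = seq[-1] + diff

# Algebra: john = 3 * mary and john + mary = 40, so 4 * mary = 40
mary = Fraction(40, 3 + 1)
john = 3 * mary

# Handshakes: each unordered pair of the 5 people shakes hands once, C(5, 2)
handshakes = comb(5, 2)

print(next_term, john, handshakes)  # 12 30 10
```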

## Visualizing the Reasoning Process

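One of the key advantages claimed for the WoT architecture is interpretability: its internal weights can be inspected after inference. The cell below runs the syllogism example and plots whichever of the following the model exposes: node attention weights, reasoning-step attention weights, and the edge matrices produced by message passing.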

In [None]:
# Select an example for visualization
vis_example = examples[0]  # Syllogism example
question = vis_example['question']
task_type = vis_example['type']

# Run inference to populate attention weights
_ = wot_reasoner.infer(question, task_type)

# Create a visualization figure
fig = plt.figure(figsize=(15, 10))
gs = plt.GridSpec(2, 2, figure=fig)

# 1. Node Attention Weights
if hasattr(wot_reasoner.wot_model, 'node_attention_weights') and wot_reasoner.wot_model.node_attention_weights is not None:
    ax1 = fig.add_subplot(gs[0, 0])
    node_attention = wot_reasoner.wot_model.node_attention_weights.detach().cpu().numpy()
    
    # Plot as bar chart
    ax1.bar(range(wot_reasoner.wot_model.num_nodes), node_attention.squeeze())
    ax1.set_xlabel('Node Index')
    ax1.set_ylabel('Attention Weight')
    ax1.set_title('Node Attention Weights')

# 2. Reasoning Step Attention Weights
if hasattr(wot_reasoner.wot_model, 'reasoning_attention_weights') and wot_reasoner.wot_model.reasoning_attention_weights is not None:
    ax2 = fig.add_subplot(gs[0, 1])
    reasoning_attention = wot_reasoner.wot_model.reasoning_attention_weights.detach().cpu().numpy()
    
    # Plot as bar chart
    ax2.bar(range(wot_reasoner.wot_model.num_reasoning_steps), reasoning_attention.squeeze())
    ax2.set_xlabel('Reasoning Step')
    ax2.set_ylabel('Attention Weight')
    ax2.set_title('Reasoning Step Attention Weights')

# 3. Edge Attention Matrix
if hasattr(wot_reasoner.wot_model, 'edge_matrices') and wot_reasoner.wot_model.edge_matrices is not None:
    ax3 = fig.add_subplot(gs[1, 0])
    
    # Get the edge matrix from the last message passing iteration
    edge_matrix = wot_reasoner.wot_model.edge_matrices[-1].detach().cpu().numpy()
    
    # Average across batch dimension if needed
    if len(edge_matrix.shape) > 2:
        edge_matrix = np.mean(edge_matrix, axis=0)
    
    # Plot as heatmap
    sns.heatmap(edge_matrix, annot=True, fmt='.2f', cmap='viridis', ax=ax3)
    ax3.set_xlabel('To Node')
    ax3.set_ylabel('From Node')
    ax3.set_title('Edge Attention Matrix')

# 4. Graph Visualization
ax4 = fig.add_subplot(gs[1, 1])

# Create a directed graph
G = nx.DiGraph()

# Add nodes
for i in range(wot_reasoner.wot_model.num_nodes):
    G.add_node(i)

# Add edges with weights > threshold from the edge matrix
if hasattr(wot_reasoner.wot_model, 'edge_matrices') and wot_reasoner.wot_model.edge_matrices is not None:
    edge_matrix = wot_reasoner.wot_model.edge_matrices[-1].detach().cpu().numpy()
    if len(edge_matrix.shape) > 2:
        edge_matrix = np.mean(edge_matrix, axis=0)
    
    threshold = 0.1
    for i in range(wot_reasoner.wot_model.num_nodes):
        for j in range(wot_reasoner.wot_model.num_nodes):
            if i != j and edge_matrix[i, j] > threshold:
                G.add_edge(i, j, weight=edge_matrix[i, j])

# Get node sizes based on node attention
if hasattr(wot_reasoner.wot_model, 'node_attention_weights') and wot_reasoner.wot_model.node_attention_weights is not None:
    node_attention = wot_reasoner.wot_model.node_attention_weights.detach().cpu().numpy()
    node_sizes = node_attention.squeeze() * 1000  # Scale for visualization
else:
    node_sizes = [300] * wot_reasoner.wot_model.num_nodes

# Get edge weights
edge_weights = [G[u][v]['weight'] * 2 for u, v in G.edges()]

# Draw the network
pos = nx.spring_layout(G, seed=42)
nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color='skyblue', alpha=0.8, ax=ax4)
nx.draw_networkx_labels(G, pos, ax=ax4)
nx.draw_networkx_edges(G, pos, width=edge_weights, alpha=0.5, edge_color='gray', 
                       connectionstyle='arc3,rad=0.1', arrowsize=15, ax=ax4)

ax4.set_title(f'Reasoning Network for {task_type} task')
ax4.axis('off')

# Add the overall title before laying out the panels, and reserve room for it
fig.suptitle(f'Weight-of-Thought Reasoning Analysis for:\n"{question}"', fontsize=14)
fig.tight_layout(rect=[0, 0, 1, 0.93])
plt.show()
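If, as their names suggest, the node and reasoning-step weights plotted above are softmax-normalized attention weights, each set is non-negative and sums to 1, so each bar shows that node's (or step's) share of a pooled representation. The NumPy sketch below illustrates this kind of attention pooling under those assumptions; `attention_pool`, `w_score` and `w_steps` are hypothetical names, and random arrays stand in for learned parameters and node states.

```python
import numpy as np


def attention_pool(states, w_score):
    """Hypothetical attention pooling.

    states:  (n_items, hidden_dim), e.g. node states or reasoning-step outputs
    w_score: (hidden_dim,) scoring vector (illustrative)
    Returns the pooled vector and the weights (non-negative, summing to 1).
    """
    scores = states @ w_score
    e = np.exp(scores - scores.max())  # subtract max for numerical stability
    weights = e / e.sum()
    return weights @ states, weights


rng = np.random.default_rng(42)
hidden_dim = 16

# Pool 8 node states into a single vector
node_states = rng.normal(size=(8, hidden_dim))
pooled_nodes, node_w = attention_pool(node_states, rng.normal(size=hidden_dim))

# 4 sequential reasoning steps, each refining the previous step's output
w_steps = rng.normal(scale=0.1, size=(4, hidden_dim, hidden_dim))
x, step_outputs = pooled_nodes, []
for w in w_steps:
    x = np.tanh(x + x @ w)
    step_outputs.append(x)

# Pool over the step outputs to get the final representation
final, step_w = attention_pool(np.stack(step_outputs), rng.normal(size=hidden_dim))

print(node_w.round(3), node_w.sum())  # 8 weights; sum is 1 up to rounding
print(step_w.round(3), step_w.sum())  # 4 weights; sum is 1 up to rounding
```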

## Conclusion

Compared with traditional Chain-of-Thought reasoning, the Weight-of-Thought architecture is designed to offer several advantages:

1. **Beyond sequential reasoning**: Represents reasoning as an interconnected graph rather than a linear chain
2. **Parallel processing**: Different aspects of reasoning can be processed simultaneously by different nodes
3. **Adaptive information flow**: Dynamic attention mechanisms focus on the most relevant connections
4. **Interpretability**: Node attention weights, reasoning-step attention weights, and edge matrices can be inspected directly, as in the visualization above
5. **Task specialization**: Different nodes can specialize for different types of reasoning tasks

In the next notebook, we'll explore how to train a WoT model on custom reasoning tasks.