# 9. Visualization

This notebook showcases various visualization techniques for exploring and understanding the mathematical concepts and theorems discovered by our system.

## 9.1 Importing Required Modules

In [None]:
import sys
import os

# Add the src directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', 'src')))

from probabilistic_model import mathematical_concept_model
from symbolic_reasoning import SymbolicReasoning
from structure_learning import learn_concept_structure

import torch
import sympy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

print("Imports complete!")

## 9.2 Generating Data for Visualization

In [None]:
# Sample the model once on a random 1000-element input vector
input_data = torch.randn(1000)
concepts, observations = mathematical_concept_model(input_data)

# `concepts` is iterated as a sequence of levels; collapse it into one flat list
flat_concepts = [concept for tier in concepts for concept in tier]

print(f"Generated {len(flat_concepts)} concepts.")

## 9.3 Visualizing Concept Spaces

In [None]:
def visualize_concept_space_2d(concepts, n_points=1000, n_concepts=5):
    """Plot the leading concepts as curves over the interval [-5, 5].

    Args:
        concepts: Sliceable sequence of callables. Each one is called with a
            1-D tensor of ``n_points`` sample positions, and its (detached)
            result is plotted against those positions.
        n_points: Number of evenly spaced sample positions.
        n_concepts: Maximum number of concepts, taken from the front of
            ``concepts``, to draw on the shared axes.
    """
    x = np.linspace(-5, 5, n_points)
    # np.linspace yields float64. Cast to torch's default dtype so concepts
    # receive the same dtype here as in the torch.linspace-based cells of this
    # notebook. The tensor is identical for every concept, so build it once.
    x_tensor = torch.as_tensor(x, dtype=torch.get_default_dtype())

    plt.figure(figsize=(12, 8))
    for i, concept in enumerate(concepts[:n_concepts]):
        y = concept(x_tensor).detach().numpy()
        plt.plot(x, y, label=f'Concept {i}')

    plt.title('2D Visualization of Concept Space')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend()
    plt.grid(True)
    plt.show()

# Draw the 2D curves; by default only the first five entries of flat_concepts are plotted
visualize_concept_space_2d(flat_concepts)

def visualize_concept_space_3d(concepts, n_points=50, n_concepts=3):
    """Plot the leading concepts as surfaces over the square [-5, 5] x [-5, 5].

    Args:
        concepts: Sliceable sequence of callables. Each one is called with a
            tensor of shape ``(n_points**2, 2)`` holding the (x, y) grid
            coordinates; its result is reshaped to the grid shape and drawn
            as a surface.
        n_points: Number of samples along each axis of the grid.
        n_concepts: Maximum number of concepts, taken from the front of
            ``concepts``, to draw. One subplot column is reserved per
            requested concept.
    """
    x = np.linspace(-5, 5, n_points)
    y = np.linspace(-5, 5, n_points)
    X, Y = np.meshgrid(x, y)

    # One (x, y) row per grid point. The grid does not depend on the concept,
    # so build it once. It is cast to torch's default dtype to match the
    # torch.linspace-based cells of this notebook (np.linspace is float64).
    grid = torch.as_tensor(
        np.stack([X.ravel(), Y.ravel()], axis=1),
        dtype=torch.get_default_dtype(),
    )

    n_cols = max(n_concepts, 1)  # avoid a zero-width figure/subplot grid
    fig = plt.figure(figsize=(5 * n_cols, 5))

    for i, concept in enumerate(concepts[:n_concepts]):
        ax = fig.add_subplot(1, n_cols, i + 1, projection='3d')
        Z = concept(grid).reshape(X.shape).detach().numpy()
        surf = ax.plot_surface(X, Y, Z, cmap='viridis')
        ax.set_title(f'Concept {i}')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)

    plt.tight_layout()
    plt.show()

# Draw the 3D surfaces; by default only the first three entries of flat_concepts are plotted
visualize_concept_space_3d(flat_concepts)

## 9.4 Visualizing Concept Hierarchy

In [None]:
def visualize_concept_hierarchy(concepts):
    """Draw the level structure of ``concepts`` as a directed graph.

    Every concept becomes a node named ``L<level>C<index>``. Each node above
    level 0 receives an incoming edge from every node of the preceding level.

    Args:
        concepts: Sequence of levels, each itself a sized sequence of
            concepts. The string form of each concept is used in its label.
    """
    G = nx.DiGraph()

    for level, level_concepts in enumerate(concepts):
        for i, concept in enumerate(level_concepts):
            node_id = f"L{level}C{i}"
            G.add_node(node_id, label=f"L{level}C{i}\n{concept}")

            if level > 0:
                # Connect to all concepts in the previous level
                G.add_edges_from(
                    (f"L{level-1}C{j}", node_id)
                    for j in range(len(concepts[level - 1]))
                )

    pos = nx.spring_layout(G)
    plt.figure(figsize=(15, 10))
    # with_labels=False: the 'label' attribute drawn below already starts with
    # the node id, so letting nx.draw print the ids as well would overprint
    # two different texts on every node.
    nx.draw(G, pos, with_labels=False, node_color='lightblue',
            node_size=3000, arrows=True)

    labels = nx.get_node_attributes(G, 'label')
    nx.draw_networkx_labels(G, pos, labels, font_size=6)

    plt.title("Concept Hierarchy")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Pass the nested per-level structure (not flat_concepts): the function iterates levels, then concepts
visualize_concept_hierarchy(concepts)

## 9.5 Visualizing Theorem Networks

In [None]:
def generate_theorems(concepts, n_theorems=10):
    """Generate theorems, each from a randomly drawn pair of concepts.

    Args:
        concepts: Pool of concepts to sample from. It is handed to
            ``np.random.choice`` with ``size=2, replace=False``, so it must
            contain at least two entries.
        n_theorems: Number of theorems to generate.

    Returns:
        A list with one ``SymbolicReasoning.generate_theorem`` result per draw.
    """
    sr = SymbolicReasoning()
    # Sample from the `concepts` argument. The previous version silently read
    # the notebook-global `flat_concepts`, ignoring what the caller passed in.
    return [
        sr.generate_theorem(np.random.choice(concepts, size=2, replace=False))
        for _ in range(n_theorems)
    ]

def visualize_theorem_network(concepts, theorems):
    """Draw concepts and theorems as an undirected two-colour graph.

    Concept nodes are named ``C<index>`` and theorem nodes ``T<index>``. A
    theorem is linked to a concept whenever the concept's string form occurs
    as a substring of the theorem's string form.

    Args:
        concepts: Iterable of concepts (re-iterated once per theorem).
        theorems: Iterable of theorems.
    """
    G = nx.Graph()
    concept_nodes = []
    theorem_nodes = []

    # Register every concept, remembering its node name for drawing
    for c_idx, concept in enumerate(concepts):
        G.add_node(f"C{c_idx}", label=f"C{c_idx}\n{concept}", type='concept')
        concept_nodes.append(f"C{c_idx}")

    # Register every theorem and link it to the concepts it mentions
    for t_idx, theorem in enumerate(theorems):
        G.add_node(f"T{t_idx}", label=f"T{t_idx}\n{theorem}", type='theorem')
        theorem_nodes.append(f"T{t_idx}")
        G.add_edges_from(
            (f"C{c_idx}", f"T{t_idx}")
            for c_idx, concept in enumerate(concepts)
            if str(concept) in str(theorem)
        )

    pos = nx.spring_layout(G)
    plt.figure(figsize=(15, 10))

    # Concepts first, then theorems, each group with its own colour and size
    node_styles = (
        (concept_nodes, 'lightblue', 3000),
        (theorem_nodes, 'lightgreen', 4000),
    )
    for nodelist, colour, size in node_styles:
        nx.draw_networkx_nodes(G, pos, nodelist=nodelist,
                               node_color=colour, node_size=size)

    nx.draw_networkx_edges(G, pos)

    node_labels = nx.get_node_attributes(G, 'label')
    nx.draw_networkx_labels(G, pos, node_labels, font_size=6)

    plt.title("Theorem Network")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Generate the default number of theorems (10) from the flattened concepts,
# then plot them against that same flat concept list
theorems = generate_theorems(flat_concepts)
visualize_theorem_network(flat_concepts, theorems)

## 9.6 Interactive Concept Explorer

In [None]:
from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets

def plot_concept(concept_index, x_min, x_max):
    """Plot a single concept from the global ``flat_concepts`` list.

    Args:
        concept_index: Position of the concept within ``flat_concepts``.
        x_min: Left end of the sampled interval.
        x_max: Right end of the sampled interval.
    """
    selected = flat_concepts[concept_index]

    # Evaluate the concept on 1000 evenly spaced points of [x_min, x_max]
    xs = torch.linspace(x_min, x_max, 1000)
    ys = selected(xs).detach().numpy()

    plt.figure(figsize=(10, 6))
    plt.plot(xs, ys)
    plt.title(f'Concept {concept_index}: {selected}')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.grid(True)
    plt.show()

# Wire plot_concept to sliders: concept_index covers every index of
# flat_concepts, x_min is confined to [-10, 0] and x_max to [0, 10],
# so x_min can never exceed x_max.
interact(plot_concept, 
         concept_index=widgets.IntSlider(min=0, max=len(flat_concepts)-1, step=1, value=0),
         x_min=widgets.FloatSlider(min=-10, max=0, step=0.1, value=-5),
         x_max=widgets.FloatSlider(min=0, max=10, step=0.1, value=5))

## 9.7 Concept Similarity Heatmap

In [None]:
def compute_concept_similarity(concepts, n_points=1000):
    """Return the pairwise correlation matrix of concept outputs on [-5, 5].

    Each concept is evaluated on ``n_points`` evenly spaced samples, and the
    resulting output vectors are correlated with one another.

    Args:
        concepts: Non-empty iterable of callables. Each is called with a 1-D
            tensor and must return a tensor of one common shape so the
            outputs can be stacked.
        n_points: Number of sample positions.

    Returns:
        A 2-D NumPy array of shape ``(n, n)`` for ``n`` concepts. A concept
        whose output is constant has zero variance, so its row and column
        will generally be NaN.
    """
    x = torch.linspace(-5, 5, n_points)
    outputs = torch.stack([concept(x) for concept in concepts])
    n_concepts = outputs.shape[0]
    # torch.corrcoef collapses the single-variable case to a 0-d tensor;
    # reshape so callers (e.g. a heatmap) always receive an (n, n) matrix.
    similarity = torch.corrcoef(outputs).reshape(n_concepts, n_concepts)
    return similarity.detach().numpy()

def plot_concept_similarity(concepts):
    """Show the concept-similarity matrix of ``concepts`` as a heatmap.

    Args:
        concepts: Concepts forwarded unchanged to
            ``compute_concept_similarity``.
    """
    similarity_matrix = compute_concept_similarity(concepts)

    plt.figure(figsize=(12, 10))
    sns.heatmap(similarity_matrix, annot=False, cmap='viridis')

    # Both axes index the same concept list
    axis_caption = 'Concept Index'
    plt.title('Concept Similarity Heatmap')
    plt.xlabel(axis_caption)
    plt.ylabel(axis_caption)
    plt.show()

# Restrict the heatmap to the first 50 concepts so individual cells stay legible
plot_concept_similarity(flat_concepts[:50])  # Plot for first 50 concepts for better visibility

This notebook demonstrates various visualization techniques for exploring and understanding the mathematical concepts and theorems discovered by the system. It shows how to:

1. Visualize concept spaces in 2D and 3D, allowing us to see how different concepts behave over their input domains.
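2. Visualize the concept hierarchy, showing how concepts are organized into levels. The plot links each concept to every concept in the preceding level, as an approximation of how more complex concepts are built from simpler ones.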
3. Create a theorem network, illustrating the relationships between concepts and the theorems that connect them.
4. Provide an interactive concept explorer, allowing for dynamic exploration of individual concepts.
5. Generate a concept similarity heatmap, helping to identify groups of related concepts.

These visualizations serve several important purposes:
- They help us understand the structure and relationships within our discovered mathematical knowledge.
- They can reveal patterns or clusters of similar concepts, potentially leading to new insights.
- They make our system's outputs more interpretable and accessible to human mathematicians.
- They can guide further exploration by highlighting interesting or unusual concepts or relationships.

By leveraging these visualization techniques, we can gain deeper insights into the mathematical structures discovered by the system, potentially leading to new mathematical understanding or avenues for further research.