# Graph Classes and Operations

This notebook demonstrates how to use the graph classes and operations in the causal meta-learning library.

In [None]:
# Import necessary modules
import sys
import os

# Make the repository root (the parent of the notebook's working directory) importable,
# so that `causal_meta` can be imported without installing the package.
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
already_importable = root_dir in sys.path
if not already_importable:
    sys.path.append(root_dir)

# Import common libraries
import numpy as np
import matplotlib.pyplot as plt

# Import the causal meta-learning library
from causal_meta.graph import Graph, DirectedGraph, CausalGraph
import causal_meta.graph.visualization as viz

## 1. Base Graph Class

The `Graph` class is the foundation for all graph implementations in the library.

In [None]:
# Build a base Graph whose nodes carry `label`/`color` attributes and whose edges carry a `weight`
graph = Graph()

node_specs = {
    'A': {'label': 'Node A', 'color': 'red'},
    'B': {'label': 'Node B', 'color': 'blue'},
    'C': {'label': 'Node C', 'color': 'green'},
}
for node_id, attrs in node_specs.items():
    graph.add_node(node_id, **attrs)

weighted_edges = [('A', 'B', 0.5), ('B', 'C', 1.0)]
for src, dst, edge_weight in weighted_edges:
    graph.add_edge(src, dst, weight=edge_weight)

# Show what the graph now contains
print(f"Nodes: {graph.get_nodes()}")
print(f"Edges: {graph.get_edges()}")

## 2. Directed Graph Class

The `DirectedGraph` class extends the base `Graph` class to provide functionality for directed edges. It includes methods for working with predecessors, successors, paths, cycles, and topological sorting.

In [None]:
# Build a small DAG: A fans out to B and C, both of which feed D, which feeds E
digraph = DirectedGraph()

for name in ('A', 'B', 'C', 'D', 'E'):
    digraph.add_node(name)

dag_edges = (('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D'), ('D', 'E'))
for tail, head in dag_edges:
    digraph.add_edge(tail, head)

# Draw it on an explicit Axes
fig, ax = plt.subplots(figsize=(10, 6))
viz.plot_graph(digraph, ax=ax, title="Directed Graph Example")
plt.show()

### Working with Directed Graph Operations

In [None]:
# Predecessors have an edge *into* the node; successors receive an edge *out of* it
print(f"Predecessors of 'D': {digraph.get_predecessors('D')}")
print(f"Successors of 'A': {digraph.get_successors('A')}")

# Reachability is directional: A reaches E via D, but E has no outgoing edges
print(f"\nPath from 'A' to 'E': {digraph.has_path('A', 'E')}")
print(f"Path from 'E' to 'A': {digraph.has_path('E', 'A')}")

# Enumerate every path between two nodes (here A -> B -> D and A -> C -> D)
paths = digraph.find_all_paths('A', 'D')
print("\nAll paths from 'A' to 'D':")
for i, path in enumerate(paths, 1):
    print(f"  Path {i}: {path}")

# Check for cycles
print(f"\nDoes the graph have cycles? {digraph.has_cycle()}")

# Build a 3-node cycle A -> B -> C -> A for comparison
cycle_graph = DirectedGraph()
for node in ['A', 'B', 'C']:
    cycle_graph.add_node(node)
for src, dst in [('A', 'B'), ('B', 'C'), ('C', 'A')]:
    cycle_graph.add_edge(src, dst)
print(f"Does the cycle graph have cycles? {cycle_graph.has_cycle()}")

# A topological order exists only for DAGs. The DAG is sorted unconditionally. Only the call
# expected to fail is guarded, so any error is clearly attributed to the cycle graph.
print(f"\nTopological sort of the digraph: {digraph.topological_sort()}")
try:
    print(f"Topological sort of the cycle graph: {cycle_graph.topological_sort()}")
except ValueError as e:
    # NOTE(review): assumes topological_sort signals a cyclic graph with ValueError — confirm
    print(f"Topological sort error (cycle graph): {e}")

### Adjacency Matrix Representation

In [None]:
# Get the adjacency matrix representation together with the node -> index mapping
adjacency_matrix, node_indices = digraph.get_adjacency_matrix()

print("Node indices:")
print(node_indices)
print("\nAdjacency Matrix:")
print(adjacency_matrix)

# node_indices is used as a mapping whose keys are node names (see the tick labels), so its
# values are presumably row/column positions. Order the labels by that position rather than
# by dict insertion order, which need not match the matrix layout.
# NOTE(review): confirm the mapping direction against DirectedGraph.get_adjacency_matrix.
ordered_nodes = sorted(node_indices, key=node_indices.get)

fig, ax = plt.subplots(figsize=(8, 6))
image = ax.imshow(adjacency_matrix, cmap='Blues')
fig.colorbar(image, ax=ax, label='Edge presence')
ax.set_title('Adjacency Matrix of Directed Graph')
ax.set_xticks(range(len(ordered_nodes)))
ax.set_xticklabels(ordered_nodes, rotation=45)
ax.set_yticks(range(len(ordered_nodes)))
ax.set_yticklabels(ordered_nodes)
ax.grid(False)
plt.show()

## 3. Causal Graph Class

The `CausalGraph` class extends the `DirectedGraph` class with causal semantics. It adds methods for querying parents, children, and Markov blankets, for testing d-separation, and for performing interventions.

In [None]:
def build_xyz_causal_graph(edges):
    """Return a CausalGraph over nodes X, Y, Z (added in that order) with the given edges."""
    cg = CausalGraph()
    for name in ('X', 'Y', 'Z'):
        cg.add_node(name)
    for parent, child in edges:
        cg.add_edge(parent, child)
    return cg


def show_causal_graph(cg, title):
    """Draw `cg` on a fresh 8x6 figure under the given title."""
    plt.figure(figsize=(8, 6))
    viz.plot_causal_graph(cg, ax=plt.gca(), title=title)
    plt.show()


# Fork: Z is a common cause of X and Y (X <- Z -> Y)
fork_graph = build_xyz_causal_graph([('Z', 'X'), ('Z', 'Y')])
show_causal_graph(fork_graph, "Fork Causal Structure (X ← Z → Y)")

# Collider: X and Y are both causes of Z (X -> Z <- Y)
collider_graph = build_xyz_causal_graph([('X', 'Z'), ('Y', 'Z')])
show_causal_graph(collider_graph, "Collider Causal Structure (X → Z ← Y)")

### Working with Causal Relations

In [None]:
# Fork (X ← Z → Y): Z is a common cause of X and Y
print("=== Fork Structure (X ← Z → Y) ===")
print(f"Parents of X: {fork_graph.get_parents('X')}")
print(f"Children of Z: {fork_graph.get_children('Z')}")
print(f"Markov blanket of Z: {fork_graph.get_markov_blanket('Z')}")
# is_confounder is called with X and Y only, so its answer concerns the pair (X, Y) and does not
# test Z specifically. NOTE(review): confirm the exact semantics of CausalGraph.is_confounder.
print(f"Are X and Y confounded? {fork_graph.is_confounder('X', 'Y')}")

# d-separation in the fork: the path X ← Z → Y is open marginally and blocked by conditioning
# on Z, so by the d-separation rules one expects False without Z and True given Z.
print(f"\nAre X and Y d-separated? {fork_graph.is_d_separated('X', 'Y')}")
print(f"Are X and Y d-separated given Z? {fork_graph.is_d_separated('X', 'Y', {'Z'})}")

# Collider (X → Z ← Y): X and Y are both causes of Z
print("\n=== Collider Structure (X → Z ← Y) ===")
print(f"Parents of Z: {collider_graph.get_parents('Z')}")
print(f"Is Z a collider? {collider_graph.is_collider('Z')}")

# d-separation in the collider: the path is blocked marginally and opened by conditioning on
# the collider Z, so one expects the opposite pattern to the fork (True, then False).
print(f"\nAre X and Y d-separated? {collider_graph.is_d_separated('X', 'Y')}")
print(f"Are X and Y d-separated given Z? {collider_graph.is_d_separated('X', 'Y', {'Z'})}")

### Intervening on Causal Graphs

In [None]:
# Build a DAG: A -> B, A -> C, B -> D, C -> D, C -> E
complex_graph = CausalGraph()
for node in ['A', 'B', 'C', 'D', 'E']:
    complex_graph.add_node(node)
for parent, child in [('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D'), ('C', 'E')]:
    complex_graph.add_edge(parent, child)

# Visualize the original graph
fig, ax = plt.subplots(figsize=(10, 6))
viz.plot_causal_graph(complex_graph, ax=ax, title="Original Causal Graph")
plt.show()

# Under the standard graph-surgery semantics, do(C) removes the edges *into* C and keeps its
# outgoing edges. C is chosen instead of the root A: A has no parents, so do(A) would leave the
# graph unchanged and the comparison below would show no difference.
# NOTE(review): assumes do_intervention implements this edge-cutting semantics — confirm.
intervened_graph = complex_graph.do_intervention('C')

# Visualize the intervened graph
fig, ax = plt.subplots(figsize=(10, 6))
viz.plot_causal_graph(intervened_graph, ax=ax, title="Causal Graph after do(C)")
plt.show()

# Compare the incoming edge (expected to be cut) with an outgoing edge (expected to be kept)
print(f"Original graph has edge A→C: {complex_graph.has_edge('A', 'C')}")
print(f"Intervened graph has edge A→C: {intervened_graph.has_edge('A', 'C')}")
print(f"Intervened graph keeps edge C→D: {intervened_graph.has_edge('C', 'D')}")
print(f"\nParents of 'C' in original graph: {complex_graph.get_parents('C')}")
print(f"Parents of 'C' in intervened graph: {intervened_graph.get_parents('C')}")

### Backdoor Paths and Adjustment Sets

In [None]:
# Backdoor example: X → Y is the effect of interest, confounded along two routes
#   X ← Z1 → Y              (short backdoor path)
#   X ← Z1 ← Z2 → Z3 → Y    (longer backdoor path)
backdoor_graph = CausalGraph()
for name in ('X', 'Y', 'Z1', 'Z2', 'Z3'):
    backdoor_graph.add_node(name)

backdoor_edges = [
    ('X', 'Y'),    # direct causal effect
    ('Z1', 'X'),   # Z1 is a common cause of X and Y ...
    ('Z1', 'Y'),
    ('Z2', 'Z1'),  # ... and Z2 opens a longer route through Z1 and Z3
    ('Z2', 'Z3'),
    ('Z3', 'Y'),
]
for parent, child in backdoor_edges:
    backdoor_graph.add_edge(parent, child)

# Visualize the backdoor graph
fig, ax = plt.subplots(figsize=(10, 6))
viz.plot_causal_graph(backdoor_graph, ax=ax, title="Causal Graph with Backdoor Paths")
plt.show()

# List the backdoor paths from X to Y
backdoor_paths = backdoor_graph.get_backdoor_paths('X', 'Y')
print("Backdoor paths from X to Y:")
for idx, bd_path in enumerate(backdoor_paths, start=1):
    print(f"  Path {idx}: {bd_path}")

# By the backdoor criterion, {Z1} and {Z1, Z3} block both backdoor paths,
# whereas {Z2} leaves X ← Z1 → Y open.
print("\nAdjustment set validity:")
candidate_sets = [("{Z1}", {'Z1'}), ("{Z2}", {'Z2'}), ("{Z1, Z3}", {'Z1', 'Z3'})]
for label, candidate in candidate_sets:
    print(f"Is {label} a valid adjustment set? {backdoor_graph.is_valid_adjustment_set('X', 'Y', candidate)}")

## 4. Advanced Graph Operations

Now let's explore some more advanced operations on graphs: connectivity checks and the analysis of connected and strongly connected components.

In [None]:
# Build a directed graph made of three weakly connected pieces:
#   - a 3-cycle A -> B -> C -> A
#   - node D, which is added but given no edges (an isolated node)
#   - a 4-cycle E -> F -> G -> H -> E
disconnected_graph = DirectedGraph()
for node in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']:
    disconnected_graph.add_node(node)

for src, dst in [('A', 'B'), ('B', 'C'), ('C', 'A'),
                 ('E', 'F'), ('F', 'G'), ('G', 'H'), ('H', 'E')]:
    disconnected_graph.add_edge(src, dst)

# Visualize the disconnected graph
fig, ax = plt.subplots(figsize=(12, 8))
viz.plot_graph(disconnected_graph, ax=ax, title="Disconnected Directed Graph")
plt.show()

# Check if the graph is connected (it shouldn't be)
print(f"Is the graph connected? {disconnected_graph.is_connected()}")

# Find connected components (expected: {A, B, C}, {D}, {E, F, G, H})
components = disconnected_graph.get_connected_components()
print("\nConnected components:")
for i, component in enumerate(components, 1):
    print(f"  Component {i}: {component}")

# Strongly connected components: each cycle forms one. By definition the isolated node D is
# also its own trivial SCC (not a cycle); whether it is listed depends on the library.
strongly_connected = disconnected_graph.get_strongly_connected_components()
print("\nStrongly connected components:")
for i, component in enumerate(strongly_connected, 1):
    print(f"  Component {i}: {component}")

## 5. Path Finding and Analysis

Let's explore path finding in a more complex graph. We enumerate every path between two nodes, find the shortest one, and highlight it in the visualization.

In [None]:
# Build a DAG with several routes of different lengths from A to G
path_graph = CausalGraph()
nodes = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
for node in nodes:
    path_graph.add_node(node)

# Three 4-hop routes (A-B-D-F-G, A-C-D-F-G, A-C-E-F-G) plus one 2-hop shortcut through B
path_graph.add_edge('A', 'B')
path_graph.add_edge('A', 'C')
path_graph.add_edge('B', 'D')
path_graph.add_edge('C', 'D')
path_graph.add_edge('C', 'E')
path_graph.add_edge('D', 'F')
path_graph.add_edge('E', 'F')
path_graph.add_edge('F', 'G')
path_graph.add_edge('B', 'G')  # Shortcut: A → B → G (2 hops) is the shortest A-to-G path

# Visualize the path graph
fig, ax = plt.subplots(figsize=(12, 8))
viz.plot_causal_graph(path_graph, ax=ax, title="Complex Path Graph")
plt.show()

# Enumerate every A-to-G path; hop count is one less than the number of nodes on the path
all_paths = path_graph.find_all_paths('A', 'G')
print(f"Found {len(all_paths)} different paths from A to G:")
for i, path in enumerate(all_paths, 1):
    print(f"  Path {i}: {path} (length: {len(path)-1})")

# Find shortest path
shortest_path = path_graph.find_shortest_path('A', 'G')
print(f"\nShortest path from A to G: {shortest_path}")

# Highlight the 2-hop shortcut A → B → G, the unique fewest-hop path in this graph
highlighted_path = ['A', 'B', 'G']
path_edges = list(zip(highlighted_path, highlighted_path[1:]))

fig, ax = plt.subplots(figsize=(12, 8))
viz.plot_causal_graph(path_graph, ax=ax, title="Graph with Highlighted Path",
                      highlight_edges=path_edges, highlight_nodes=highlighted_path,
                      highlight_edge_color='red', highlight_node_color='yellow')
plt.show()

## Summary

In this notebook, we explored the graph classes provided by the causal meta-learning library:

1. The base `Graph` class provides fundamental operations for working with nodes, edges, and their attributes.
2. The `DirectedGraph` class extends this with directed edge semantics, path finding, cycle detection, topological sorting, adjacency-matrix export, and connected/strongly connected component analysis.
3. The `CausalGraph` class adds causal reasoning capabilities, including d-separation, Markov blanket identification, intervention operations, and backdoor-path and adjustment-set analysis.

These graph implementations provide a solid foundation for causal modeling, inference, and optimization tasks in the library. In the next notebooks, we'll explore graph visualization, generation, and causal environments in more detail.