# CLEANN: Causal Lens for Explaining Attention Networks

This notebook demonstrates:
1. Loading attention matrices from transformer models
2. Applying the **CLEANN** algorithm to extract causal graphs from attention patterns
3. Generating minimal explanations for target tokens/features
4. Visualizing learned causal structures and explanation sets

**Paper**: [CLEANN](https://arxiv.org/abs/2310.20307)

**Source**: Intel Labs Causality Lab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/arberzela/CircuitPFN/blob/main/notebooks/colab_cleann_causal_explanation.ipynb)

## Setup & Installation

First, let's install all required dependencies, including the Intel Labs causality-lab library.

In [None]:
# Install dependencies
!pip install torch numpy scipy scikit-learn pandas matplotlib seaborn networkx
!pip install tabpfn
!pip install openml

# Install Intel Labs causality-lab
!pip install git+https://github.com/IntelLabs/causality-lab.git

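# Clone the CircuitPFN repository (if running in Colab) and move into it,
# unless this cell was already run and we are inside the repository
import os
if os.path.basename(os.getcwd()) != 'CircuitPFN':
    if not os.path.exists('CircuitPFN'):
        !git clone https://github.com/arberzela/CircuitPFN.git
    else:
        print("Repository already exists")
    %cd CircuitPFN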

## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from tabpfn import TabPFNClassifier

# Import from causality-lab
from causal_discovery_algs import LearnStructICD
from causal_discovery_algs.icd import create_pds_tree
from causal_discovery_utils.cond_indep_tests import CondIndepParCorr
from causal_discovery_utils.stat_utils import cov_to_corr

# Import from the CircuitPFN repository
import sys
sys.path.append('.')

from attn_scm.attention import AttentionExtractor
from utils.data_generation import generate_synthetic_dataset

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

# Set random seed for reproducibility
np.random.seed(42)

print("✓ All imports successful!")

## CLEANN Implementation

The CLEANN algorithm extracts causal explanations from attention patterns:
1. Converts attention matrices to correlation matrices
2. Learns causal graph structure using conditional independence tests
3. Generates minimal explanation sets for target nodes
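Step 1 above can be illustrated in a few lines of NumPy. The `CLEANN` class below forms `A @ A.T` from the attention matrix `A` and passes it to causality-lab's `cov_to_corr`; this sketch assumes that function performs the usual normalization `corr[i, j] = cov[i, j] / sqrt(cov[i, i] * cov[j, j])`. The helper name `attention_to_correlation` and the toy matrix are illustrative only.

```python
import numpy as np

def attention_to_correlation(attn: np.ndarray) -> np.ndarray:
    """Correlation-style matrix of the rows of an attention matrix."""
    # Gram matrix of attention rows: cov[i, j] = <attn[i, :], attn[j, :]>
    cov = attn @ attn.T
    # Divide by sqrt(cov[i, i] * cov[j, j]) so the diagonal becomes 1
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)

# Toy row-stochastic attention matrix (each row sums to 1, as after a softmax)
attn = np.array([[0.7, 0.2, 0.1],
                 [0.6, 0.3, 0.1],
                 [0.1, 0.1, 0.8]])
corr = attention_to_correlation(attn)
print(np.round(corr, 3))
```

`A @ A.T` is not mean-centered, and attention weights are non-negative, so every entry of the result lies in [0, 1]: rows that attend to similar features (rows 0 and 1 here) come out strongly correlated, while row 2 correlates only weakly with them.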

In [None]:
class CLEANN:
    """CLEANN: Causal Lens for Explaining Attention Networks
    
    Adapted from Intel Labs causality-lab:
    https://github.com/IntelLabs/causality-lab/blob/main/causal_reasoning/cleann_explainer.py
    """
    
    def __init__(self, attention_matrix: np.ndarray, num_samples: int, 
                 p_val_th: float = 0.05, explanation_tester=None, 
                 nodes_set=None, search_minimal=True, 
                 structure_learning_class=LearnStructICD):
        """
        Initialize CLEANN explainer.
        
        Parameters:
        -----------
        attention_matrix : np.ndarray
            Attention weights matrix (n_features x n_features)
        num_samples : int
            Number of samples used to generate attention (for statistical tests)
        p_val_th : float
            P-value threshold for conditional independence tests
        explanation_tester : callable, optional
            Function to validate if a set is a valid explanation
        nodes_set : set, optional
            Set of node indices to consider (default: all nodes)
        search_minimal : bool
            If True, search for minimal explanation sets only
        structure_learning_class : class
            Structure learning algorithm class (default: ICD)
        """
        # Calculate correlation matrix from attention matrix
        cov_matrix = np.matmul(attention_matrix, attention_matrix.transpose())
        corr_mat = cov_to_corr(cov_matrix)
        
        # Prepare for learning a graph
        num_vars, _ = corr_mat.shape
        if nodes_set is None:
            nodes_set = set(range(num_vars))
        self.nodes_set = nodes_set
        
        # Setup conditional independence test
        self.ci_test = CondIndepParCorr(
            threshold=p_val_th, 
            dataset=None, 
            num_records=num_samples, 
            num_vars=num_vars, 
            count_tests=True, 
            use_cache=True
        )
        self.ci_test.correlation_matrix = corr_mat
        
        self.StructureLearning = structure_learning_class
        self.graph = None
        
        # Initialize for evaluating explanations
        self.results = dict()
        self.is_explanation = explanation_tester
        self._search_minimal = search_minimal
        
        # Store correlation matrix for visualization
        self.correlation_matrix = corr_mat
    
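    def learn_graph(self):
        """Learn the causal graph (ICD by default) and cache it on the instance."""
        icd_alg = self.StructureLearning(
            nodes_set=self.nodes_set, 
            ci_test=self.ci_test
        )
        icd_alg.learn_structure()
        # Cache the result so get_adjacency_matrix() and explain() reuse it
        # instead of re-running structure learning
        self.graph = icd_alg.graph
        return self.graph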
    
    def explain(self, target_node_idx: int, max_set_size=None, max_range=None):
        """
        Identify minimal explanation set for the target node.
        
        Parameters:
        -----------
        target_node_idx : int
            Index of the node to explain
        max_set_size : int, optional
            Maximum size of explanation sets to consider
        max_range : int, optional
            Maximum distance in PDS tree to search
            
        Returns:
        --------
        list
            List of explanation sets (tuples of node indices and depths)
        """
        # Learn a Graph if one hasn't been learned already
        if self.graph is None:
            self.graph = self.learn_graph()
        
        # Create a PDS-tree rooted at the target node
        pds_tree, full_explain_set = create_pds_tree(
            self.graph, target_node_idx, max_depth=max_range
        )
        max_pds_tree_depth = pds_tree.get_max_depth()
        
        results = dict()
        results['pds_tree'] = pds_tree
        results['full_explanation_set'] = full_explain_set
        results['max_pds_tree_depth'] = max_pds_tree_depth
        
        if max_set_size is None:
            max_size = len(full_explain_set)
        else:
            max_size = max_set_size
        
        explanations_list = []
        if self.is_explanation is None:
            if len(full_explain_set) <= max_size:
                # Second element is the PDS-tree depth, matching the (set, depth) pairs below
                explanations_list.append([full_explain_set, max_pds_tree_depth])
        else:
            found_explanation = False
            for set_size in range(1, max_size+1):
                sets_list = pds_tree.get_subsets_list(
                    set_nodes=full_explain_set, subset_size=set_size
                )
                sets_list.sort(key=lambda x: x[1])
                for possible_explanation_set in sets_list:
                    if self.is_explanation(list(possible_explanation_set[0]), 
                                         target_node_idx):
                        explanations_list.append(possible_explanation_set)
                        found_explanation = True
                if found_explanation and self._search_minimal:
                    break
        
        results['explanations'] = explanations_list
        self.results[target_node_idx] = results
        return explanations_list
    
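    def get_adjacency_matrix(self):
        """Convert the learned graph to a binary adjacency matrix (adj[parent, child] = 1).

        Only relations listed in `self.graph.parents` are recorded. ICD learns a
        partial ancestral graph (PAG), so edges whose endpoints are left
        undetermined may not appear here.
        """
        if self.graph is None:
            self.graph = self.learn_graph()
        
        # Size by the full variable count so node indices stay valid even when
        # nodes_set is a subset of the variables
        num_nodes = self.correlation_matrix.shape[0]
        adj_matrix = np.zeros((num_nodes, num_nodes))
        
        for node in self.graph.nodes:
            parents = self.graph.parents[node]
            for parent in parents:
                adj_matrix[parent, node] = 1
        
        return adj_matrix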

print("✓ CLEANN class defined successfully!")

## 1. Load Pretrained TabPFN Classifier

We'll use TabPFN to extract attention patterns that CLEANN will analyze.

In [None]:
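# Initialize pretrained TabPFN classifier.
# N_ensemble_configurations is the TabPFN v0.1 keyword; TabPFN v2 renamed it to
# n_estimators, so install a v0.1 release (e.g. pip install "tabpfn<2") if this raises a TypeError.
tabpfn_model = TabPFNClassifier(device='cpu', N_ensemble_configurations=4)

print("✓ TabPFN model loaded successfully!")
print(f"  Device: {tabpfn_model.device}")
print("  Max training samples (TabPFN v0.1 limit): 1000")
print("  Max features (TabPFN v0.1 limit): 100")
print("\nTabPFN's weights were pre-trained on synthetic datasets drawn from a prior, so no further training is needed.")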

## 2. Load Tabular Datasets

We'll use three datasets to demonstrate CLEANN:
- **Dataset 1**: Synthetic data with known causal structure (Linear-Gaussian SCM)
- **Dataset 2**: Synthetic data with non-linear relationships
- **Dataset 3**: Real-world dataset from OpenML

### Dataset 1: Synthetic Linear-Gaussian SCM

In [None]:
# Generate synthetic dataset with known causal graph
X1, y1, true_adj1 = generate_synthetic_dataset(
    n_nodes=10,              # 10 features
    n_samples=500,           # 500 samples
    edge_prob=0.3,           # 30% edge probability
    scm_type='linear_gaussian',  # Linear Gaussian SCM
    seed=42
)

feature_names1 = [f'X{i}' for i in range(10)]

print("Dataset 1: Synthetic Linear-Gaussian SCM")
print(f"  Shape: {X1.shape}")
print(f"  Features: {X1.shape[1]}")
print(f"  Samples: {X1.shape[0]}")
print(f"  True edges: {int(true_adj1.sum())}")
print(f"  Graph density: {true_adj1.sum() / (10*9):.2%}")  # edges / n(n-1) possible directed edges

# Visualize ground truth
fig, ax = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap
sns.heatmap(true_adj1, cmap='Blues', square=True, ax=ax[0],
            xticklabels=feature_names1, yticklabels=feature_names1,
            cbar_kws={'label': 'Edge'})
ax[0].set_title('Ground Truth Causal Graph (Adjacency Matrix)', fontweight='bold')
ax[0].set_xlabel('Target')
ax[0].set_ylabel('Source')

# Network graph
G1 = nx.DiGraph(true_adj1)
pos1 = nx.spring_layout(G1, seed=42, k=2)
nx.draw_networkx_nodes(G1, pos1, node_color='lightblue', node_size=800, ax=ax[1])
nx.draw_networkx_edges(G1, pos1, edge_color='gray', arrows=True, 
                       arrowsize=15, arrowstyle='->', ax=ax[1])
nx.draw_networkx_labels(G1, pos1, {i: feature_names1[i] for i in range(10)}, 
                       font_size=10, font_weight='bold', ax=ax[1])
ax[1].set_title('Ground Truth Causal DAG', fontweight='bold')
ax[1].axis('off')

plt.tight_layout()
plt.show()

### Dataset 2: Synthetic Non-linear SCM

In [None]:
# Generate non-linear dataset
X2, y2, true_adj2 = generate_synthetic_dataset(
    n_nodes=8,
    n_samples=400,
    edge_prob=0.35,
    scm_type='nonlinear_anm',  # Non-linear Additive Noise Model
    seed=123
)

feature_names2 = [f'F{i}' for i in range(8)]

print("Dataset 2: Synthetic Non-linear SCM")
print(f"  Shape: {X2.shape}")
print(f"  True edges: {int(true_adj2.sum())}")
print(f"  Mechanism: Polynomial + Sigmoid functions")

### Dataset 3: Real-world Dataset (OpenML)

In [None]:
import openml
from sklearn.preprocessing import LabelEncoder

# Load a small real-world dataset from OpenML:
# dataset ID 37 is the well-known Pima Indians Diabetes data (768 rows, 8 numeric features)
dataset = openml.datasets.get_dataset(37)  # Pima Indians Diabetes
X3_full, y3_full, categorical, feature_names3 = dataset.get_data(
    dataset_format='array',
    target=dataset.default_target_attribute
)

# Encode labels
le = LabelEncoder()
y3_full = le.fit_transform(y3_full)

# The full dataset (768 rows) already fits TabPFN's 1,000-sample limit; keep the first 500 rows
X3 = X3_full[:500]
y3 = y3_full[:500]

print("Dataset 3: Real-world (Pima Indians Diabetes)")
print(f"  Shape: {X3.shape}")
print(f"  Features: {list(feature_names3)}")
print(f"  Classes: {np.unique(y3)}")
print(f"  Note: Ground truth causal graph is unknown for real-world data")

## 3. Extract Attention Patterns from TabPFN

First, we need to extract attention matrices from TabPFN's transformer layers.

In [None]:
# Initialize attention extractor
attention_extractor1 = AttentionExtractor(tabpfn_model)

# Fit TabPFN and extract attention patterns
print("Extracting attention patterns from TabPFN...")
print("="*60)

attention_dict1 = attention_extractor1.extract_attention(X1, y1)

print("\n" + "="*60)
print("✓ Attention extraction complete!")
print(f"\nExtracted attention from {len(attention_dict1)} layers")
print(f"Attention shape per layer: {attention_dict1[0].shape}")
print(f"  (num_heads, num_features, num_features)")

### Aggregate Attention Across Layers and Heads

CLEANN works with a single attention matrix, so we'll aggregate across layers and heads.
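Before aggregating real attention, it helps to see what the three `method` options of `aggregate_attention` (defined in the next cell) do. Each attention row is a softmax distribution, so averaging layers keeps rows summing to 1, summing scales every entry by the number of layers, and an element-wise max breaks the normalization. Because the attention-to-correlation step divides by the diagonal, a global rescaling cancels out: under the `A @ A.T` conversion used by the `CLEANN` class, `'sum'` and `'mean'` yield the same correlation matrix. The check below uses random toy tensors with arbitrary shapes, not TabPFN output, and assumes `cov_to_corr` performs the usual diagonal normalization.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Softmax over the last axis, so every row is a probability distribution."""
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def to_corr(a: np.ndarray) -> np.ndarray:
    """Gram matrix of the rows of `a`, normalized by its diagonal."""
    cov = a @ a.T
    s = np.sqrt(np.diag(cov))
    return cov / np.outer(s, s)

# Toy attention tensor: 3 layers x 4 heads x 5 x 5 (random, NOT real TabPFN attention)
attn = softmax_rows(rng.normal(size=(3, 4, 5, 5)))
per_layer = attn.mean(axis=1)  # average over heads -> (layers, d, d)

agg_mean = per_layer.mean(axis=0)
agg_max = per_layer.max(axis=0)
agg_sum = per_layer.sum(axis=0)

print("row sums, mean:", np.round(agg_mean.sum(axis=1), 3))  # all 1.0
print("row sums, sum: ", np.round(agg_sum.sum(axis=1), 3))   # all 3.0 (= number of layers)
print("row sums, max: ", np.round(agg_max.sum(axis=1), 3))   # >= 1.0, varies by row
print("mean and sum give the same correlation:",
      np.allclose(to_corr(agg_mean), to_corr(agg_sum)))
```

So, for CLEANN's purposes, the real choice is between averaging (or, equivalently, summing) and taking the element-wise max.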

In [None]:
def aggregate_attention(attention_dict, method='mean'):
    """
    Aggregate attention matrices across layers and heads.
    
    Parameters:
    -----------
    attention_dict : dict
        Dictionary mapping layer indices to attention tensors
    method : str
        Aggregation across layers: 'mean', 'max', or 'sum'
        (heads are always averaged first)
    
    Returns:
    --------
    np.ndarray
        Aggregated attention matrix (n_features x n_features)
    """
    all_attention = []
    
    for layer_idx in attention_dict:
        layer_attn = attention_dict[layer_idx]  # (heads, d, d)
        # Average across heads
        avg_attn = layer_attn.mean(axis=0)  # (d, d)
        all_attention.append(avg_attn)
    
    # Stack all layers
    stacked = np.stack(all_attention, axis=0)  # (layers, d, d)
    
    # Aggregate across layers
    if method == 'mean':
        aggregated = stacked.mean(axis=0)
    elif method == 'max':
        aggregated = stacked.max(axis=0)
    elif method == 'sum':
        aggregated = stacked.sum(axis=0)
    else:
        raise ValueError(f"Unknown aggregation method: {method}")
    
    return aggregated

# Aggregate attention for Dataset 1
aggregated_attention1 = aggregate_attention(attention_dict1, method='mean')

print(f"Aggregated attention matrix shape: {aggregated_attention1.shape}")
print(f"Attention values range: [{aggregated_attention1.min():.4f}, {aggregated_attention1.max():.4f}]")

# Visualize aggregated attention
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(aggregated_attention1, cmap='viridis', square=True, ax=ax,
            xticklabels=feature_names1, yticklabels=feature_names1,
            cbar_kws={'label': 'Attention Weight'})
ax.set_title('Aggregated Attention Matrix (Dataset 1)\nAveraged across layers and heads', 
             fontweight='bold')
ax.set_xlabel('To Feature')
ax.set_ylabel('From Feature')
plt.tight_layout()
plt.show()

## 4. Apply CLEANN Algorithm

Now we'll apply CLEANN to learn the causal graph structure from attention patterns.

### 4.1 Apply CLEANN to Dataset 1 (Linear-Gaussian)

In [None]:
# Initialize CLEANN with aggregated attention
cleann_model1 = CLEANN(
    attention_matrix=aggregated_attention1,
    num_samples=X1.shape[0],
    p_val_th=0.05,  # p-value threshold for conditional independence
    search_minimal=True
)

print("Applying CLEANN to Dataset 1...")
print("="*60)

# Learn causal graph structure
learned_graph1 = cleann_model1.learn_graph()
learned_adj1 = cleann_model1.get_adjacency_matrix()

print("\n" + "="*60)
print("✓ Causal graph learning complete!")
print(f"\nPredicted edges: {int(learned_adj1.sum())}")
print(f"True edges: {int(true_adj1.sum())}")

# Visualize correlation matrix (intermediate step)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Original attention
sns.heatmap(aggregated_attention1, cmap='viridis', square=True, ax=axes[0],
            xticklabels=feature_names1, yticklabels=feature_names1,
            cbar_kws={'label': 'Attention'})
axes[0].set_title('Aggregated Attention Matrix', fontweight='bold')

# Correlation matrix
sns.heatmap(cleann_model1.correlation_matrix, cmap='RdBu_r', square=True, ax=axes[1],
            xticklabels=feature_names1, yticklabels=feature_names1,
            vmin=-1, vmax=1, center=0,
            cbar_kws={'label': 'Correlation'})
axes[1].set_title('Correlation Matrix\n(Computed from Attention)', fontweight='bold')

# Learned causal graph
sns.heatmap(learned_adj1, cmap='Blues', square=True, ax=axes[2],
            xticklabels=feature_names1, yticklabels=feature_names1,
            cbar_kws={'label': 'Edge'})
axes[2].set_title('Learned Causal Graph\n(via ICD Algorithm)', fontweight='bold')

for ax in axes:
    ax.set_xlabel('To Feature')
    ax.set_ylabel('From Feature')

plt.tight_layout()
plt.show()

### 4.2 Apply CLEANN to Dataset 2 (Non-linear)

In [None]:
# Extract attention for Dataset 2
attention_extractor2 = AttentionExtractor(tabpfn_model)
attention_dict2 = attention_extractor2.extract_attention(X2, y2)
aggregated_attention2 = aggregate_attention(attention_dict2, method='mean')

# Apply CLEANN
cleann_model2 = CLEANN(
    attention_matrix=aggregated_attention2,
    num_samples=X2.shape[0],
    p_val_th=0.05,
    search_minimal=True
)

print("Applying CLEANN to Dataset 2...")
learned_graph2 = cleann_model2.learn_graph()
learned_adj2 = cleann_model2.get_adjacency_matrix()
print(f"✓ Complete! Predicted edges: {int(learned_adj2.sum())}")

### 4.3 Apply CLEANN to Dataset 3 (Real-world)

In [None]:
# Extract attention for Dataset 3
attention_extractor3 = AttentionExtractor(tabpfn_model)
attention_dict3 = attention_extractor3.extract_attention(X3, y3)
aggregated_attention3 = aggregate_attention(attention_dict3, method='mean')

# Apply CLEANN
cleann_model3 = CLEANN(
    attention_matrix=aggregated_attention3,
    num_samples=X3.shape[0],
    p_val_th=0.05,
    search_minimal=True
)

print("Applying CLEANN to Dataset 3 (Diabetes)...")
learned_graph3 = cleann_model3.learn_graph()
learned_adj3 = cleann_model3.get_adjacency_matrix()
print(f"✓ Complete! Predicted edges: {int(learned_adj3.sum())}")

## 5. Generate Explanations for Target Features

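CLEANN can identify minimal explanation sets for specific target features. The models above were created without an `explanation_tester`, so here `explain()` returns the full explanation set from the PDS tree (if it has at most `max_set_size` nodes) rather than searching for a minimal subset.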
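When an `explanation_tester` is supplied, `CLEANN.explain()` checks candidate subsets of the full explanation set in order of increasing size and, with `search_minimal=True`, stops after the first size that yields a valid explanation. The sketch below mirrors that loop with plain `itertools.combinations`; unlike `pds_tree.get_subsets_list`, it does not order candidates by PDS-tree depth. `minimal_explanations` and `toy_tester` are hypothetical names; the tester stands in for any callable with the `(nodes, target) -> bool` signature that `explain()` uses.

```python
from itertools import combinations

def minimal_explanations(candidates, target, is_explanation, max_size=None, search_minimal=True):
    """Return valid explanation sets, smallest size first."""
    candidates = sorted(candidates)
    max_size = len(candidates) if max_size is None else max_size
    found = []
    for size in range(1, max_size + 1):
        for subset in combinations(candidates, size):
            if is_explanation(list(subset), target):
                found.append(subset)
        # Mirror explain(): stop once any set of this size is a valid explanation
        if found and search_minimal:
            break
    return found

# Hypothetical tester: a set explains the target if it contains both nodes 2 and 7
def toy_tester(nodes, target):
    return {2, 7}.issubset(nodes)

result = minimal_explanations({1, 2, 4, 7}, target=5, is_explanation=toy_tester)
print(result)  # [(2, 7)]
```

With `search_minimal=False` the loop keeps going and also returns the supersets `(1, 2, 7)`, `(2, 4, 7)` and `(1, 2, 4, 7)`, which is why the minimal search is the default.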

### 5.1 Explain Target Features in Dataset 1

In [None]:
# Select target features to explain
target_features = [0, 5, 9]  # X0, X5, X9

print("Generating explanations for target features...")
print("="*60)

for target_idx in target_features:
    print(f"\nTarget: {feature_names1[target_idx]} (index {target_idx})")
    print("-" * 60)
    
    # Get explanation set
    explanations = cleann_model1.explain(target_idx, max_set_size=5)
    
    if explanations:
        for i, (explain_set, depth) in enumerate(explanations):
            explain_names = [feature_names1[idx] for idx in explain_set]
            print(f"  Explanation {i+1}:")
            print(f"    Set: {explain_names}")
            print(f"    Size: {len(explain_set)}")
            print(f"    Max depth in PDS tree: {depth}")
    else:
        print("  No explanation found within max_set_size")

print("\n" + "="*60)

### Visualize Explanation Set

In [None]:
# Visualize explanation for a specific target
target_to_visualize = 5  # X5

if target_to_visualize in cleann_model1.results:
    result = cleann_model1.results[target_to_visualize]
    full_explain_set = result['full_explanation_set']
    
    # Create subgraph including target and its explanation set
    nodes_to_show = sorted([target_to_visualize] + list(full_explain_set))
    subgraph_adj = learned_adj1[np.ix_(nodes_to_show, nodes_to_show)]
    subgraph_names = [feature_names1[i] for i in nodes_to_show]
    
    fig, ax = plt.subplots(figsize=(12, 9))
    
    G_sub = nx.DiGraph(subgraph_adj)
    pos_sub = nx.spring_layout(G_sub, seed=42, k=2)
    
    # Color nodes: target in red, explanatory features in lightblue
    node_colors = ['red' if nodes_to_show[i] == target_to_visualize 
                   else 'lightblue' for i in range(len(nodes_to_show))]
    
    nx.draw_networkx_nodes(G_sub, pos_sub, node_color=node_colors, 
                          node_size=1500, ax=ax)
    nx.draw_networkx_edges(G_sub, pos_sub, edge_color='gray', 
                          arrows=True, arrowsize=20, arrowstyle='->', 
                          width=2, connectionstyle='arc3,rad=0.1', ax=ax)
    nx.draw_networkx_labels(G_sub, pos_sub, 
                           {i: subgraph_names[i] for i in range(len(subgraph_names))},
                           font_size=11, font_weight='bold', ax=ax)
    
    ax.set_title(f"Explanation Set for '{feature_names1[target_to_visualize]}'\n" +
                f"(Red = Target, Blue = Explanatory Features)",
                fontsize=14, fontweight='bold')
    ax.axis('off')
    plt.tight_layout()
    plt.show()
    
    print(f"\nExplanation for {feature_names1[target_to_visualize]}:")
    print(f"  Full explanation set: {[feature_names1[i] for i in full_explain_set]}")
    print(f"  Set size: {len(full_explain_set)}")
else:
    print(f"No explanation generated for {feature_names1[target_to_visualize]}")

### 5.2 Explain Features in Real-world Dataset

In [None]:
# Explain all features in the diabetes dataset
print("Generating explanations for Diabetes dataset features:")
print("="*60)

for target_idx in range(len(feature_names3)):
    feature_name = feature_names3[target_idx]
    print(f"\n{feature_name}:")
    
    explanations = cleann_model3.explain(target_idx, max_set_size=5)
    
    if explanations:
        for i, (explain_set, depth) in enumerate(explanations[:1]):  # Show first explanation
            explain_names = [feature_names3[idx] for idx in explain_set]
            print(f"  Explained by: {explain_names} (size: {len(explain_set)})")
    else:
        print("  No explanation found within max_set_size")

print("\n" + "="*60)

## 6. Visualize Learned Causal Graphs

Compare the learned graphs with ground truth (where available).

### 6.1 Dataset 1: Predicted vs Ground Truth

In [None]:
# Compare predicted and true graphs
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Ground Truth
G_true = nx.DiGraph(true_adj1)
pos_true = nx.spring_layout(G_true, seed=42, k=2, iterations=50)

nx.draw_networkx_nodes(G_true, pos_true, node_color='lightgreen', 
                      node_size=1200, ax=axes[0])
nx.draw_networkx_edges(G_true, pos_true, edge_color='darkgreen', 
                      arrows=True, arrowsize=20, arrowstyle='->', 
                      connectionstyle='arc3,rad=0.1', width=2, ax=axes[0])
nx.draw_networkx_labels(G_true, pos_true, 
                       {i: feature_names1[i] for i in range(10)},
                       font_size=11, font_weight='bold', ax=axes[0])
axes[0].set_title(f'Ground Truth Causal Graph\n({int(true_adj1.sum())} edges)', 
                 fontsize=14, fontweight='bold')
axes[0].axis('off')

# Predicted (CLEANN)
G_pred = nx.DiGraph(learned_adj1)
pos_pred = pos_true  # reuse the ground-truth layout so nodes line up across panels

nx.draw_networkx_nodes(G_pred, pos_pred, node_color='lightcoral', 
                      node_size=1200, ax=axes[1])
nx.draw_networkx_edges(G_pred, pos_pred, edge_color='darkred', 
                      arrows=True, arrowsize=20, arrowstyle='->', 
                      connectionstyle='arc3,rad=0.1', width=2, ax=axes[1])
nx.draw_networkx_labels(G_pred, pos_pred, 
                       {i: feature_names1[i] for i in range(10)},
                       font_size=11, font_weight='bold', ax=axes[1])
axes[1].set_title(f'Learned Causal Graph (CLEANN)\n({int(learned_adj1.sum())} edges)', 
                 fontsize=14, fontweight='bold')
axes[1].axis('off')

plt.suptitle('Dataset 1: Linear-Gaussian SCM', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

### Evaluation Metrics
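`compute_graph_metrics` is imported from the CircuitPFN repository, and this notebook does not show its definitions. For orientation, here is one common set of conventions for comparing binary adjacency matrices (`adj[i, j] = 1` means `i -> j`): precision, recall and F1 over directed edges, plus an SHD that charges each node pair whose edge status differs once, so a reversed edge costs 1. This is a sketch under those assumptions; `directed_edge_metrics` is a hypothetical name, and SHD conventions vary (some count a reversal as 2), so the repository's numbers may not match it exactly.

```python
import numpy as np

def directed_edge_metrics(pred: np.ndarray, true: np.ndarray) -> dict:
    """Compare binary adjacency matrices where adj[i, j] = 1 means i -> j."""
    pred = (np.asarray(pred) != 0).astype(int)
    true = (np.asarray(true) != 0).astype(int)
    # Directed-edge confusion counts (a reversed edge is one FP plus one FN)
    tp = int(((pred == 1) & (true == 1)).sum())
    fp = int(((pred == 1) & (true == 0)).sum())
    fn = int(((pred == 0) & (true == 1)).sum())
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # SHD: number of unordered pairs {i, j} whose edge status differs
    # (missing, extra and reversed edges each cost 1)
    n = true.shape[0]
    shd = sum(
        (pred[i, j], pred[j, i]) != (true[i, j], true[j, i])
        for i in range(n) for j in range(i + 1, n)
    )
    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision,
            "recall": recall, "f1_directed": f1, "shd": int(shd)}

# Ground truth X0 -> X1 -> X2; prediction reverses X0 -> X1 and adds X0 -> X2
true = np.array([[0, 1, 0],
                 [0, 0, 1],
                 [0, 0, 0]])
pred = np.array([[0, 0, 1],
                 [1, 0, 1],
                 [0, 0, 0]])
m = directed_edge_metrics(pred, true)
print(m["tp"], m["fp"], m["fn"], m["shd"])  # 1 2 1 2
print(round(m["f1_directed"], 3))           # 0.4
```

Note how the single reversed edge is penalized twice in precision/recall (one FP and one FN) but only once in this SHD.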

In [None]:
# Import metrics calculation from CircuitPFN
from attn_scm.metrics import compute_graph_metrics

# Compute evaluation metrics
metrics1 = compute_graph_metrics(learned_adj1, true_adj1)

print("="*60)
print("Dataset 1 - Evaluation Metrics")
print("="*60)
print(f"Structural Hamming Distance (SHD): {metrics1['shd']}")
print(f"F1 Score (directed):               {metrics1['f1_directed']:.3f}")
print(f"Precision:                         {metrics1['precision']:.3f}")
print(f"Recall:                            {metrics1['recall']:.3f}")
print(f"True Positive Edges:               {metrics1['tp']}")
print(f"False Positive Edges:              {metrics1['fp']}")
print(f"False Negative Edges:              {metrics1['fn']}")
print("="*60)

# Visualize metrics
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart of metrics
metrics_data = [metrics1['precision'], metrics1['recall'], metrics1['f1_directed']]
axes[0].bar(['Precision', 'Recall', 'F1 Score'], metrics_data, 
           color=['steelblue', 'coral', 'seagreen'])
axes[0].set_ylim([0, 1])
axes[0].set_ylabel('Score')
axes[0].set_title('Performance Metrics', fontweight='bold')
axes[0].grid(axis='y', alpha=0.3)

for i, v in enumerate(metrics_data):
    axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontweight='bold')

# Confusion-style visualization
edge_data = [metrics1['tp'], metrics1['fp'], metrics1['fn']]
axes[1].bar(['True\nPositives', 'False\nPositives', 'False\nNegatives'], 
           edge_data, color=['green', 'orange', 'red'])
axes[1].set_ylabel('Count')
axes[1].set_title('Edge Prediction Breakdown', fontweight='bold')
axes[1].grid(axis='y', alpha=0.3)

for i, v in enumerate(edge_data):
    axes[1].text(i, v + 0.5, str(v), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

### 6.2 Dataset 2: Non-linear SCM

In [None]:
# Plot Dataset 2 graphs
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Ground Truth
G_true2 = nx.DiGraph(true_adj2)
pos_true2 = nx.spring_layout(G_true2, seed=123, k=2)
nx.draw(G_true2, pos_true2, with_labels=True, 
        labels={i: feature_names2[i] for i in range(8)},
        node_color='lightgreen', node_size=1200, 
        edge_color='darkgreen', width=2, arrows=True, 
        arrowsize=20, font_weight='bold', ax=axes[0])
axes[0].set_title(f'Ground Truth\n({int(true_adj2.sum())} edges)', 
                 fontsize=14, fontweight='bold')

# Predicted (CLEANN)
G_pred2 = nx.DiGraph(learned_adj2)
pos_pred2 = nx.spring_layout(G_pred2, seed=123, k=2)
nx.draw(G_pred2, pos_pred2, with_labels=True,
        labels={i: feature_names2[i] for i in range(8)},
        node_color='lightcoral', node_size=1200,
        edge_color='darkred', width=2, arrows=True,
        arrowsize=20, font_weight='bold', ax=axes[1])
axes[1].set_title(f'Learned (CLEANN)\n({int(learned_adj2.sum())} edges)', 
                 fontsize=14, fontweight='bold')

plt.suptitle('Dataset 2: Non-linear SCM', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Metrics
metrics2 = compute_graph_metrics(learned_adj2, true_adj2)
print(f"\nDataset 2 Metrics: SHD={metrics2['shd']}, F1={metrics2['f1_directed']:.3f}, "
      f"Precision={metrics2['precision']:.3f}, Recall={metrics2['recall']:.3f}")

### 6.3 Dataset 3: Real-world (Diabetes)

In [None]:
# Plot real-world dataset graph
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# Network graph
G_pred3 = nx.DiGraph(learned_adj3)
pos_pred3 = nx.spring_layout(G_pred3, seed=42, k=3, iterations=50)

nx.draw_networkx_nodes(G_pred3, pos_pred3, node_color='skyblue', 
                      node_size=1500, ax=axes[0])
nx.draw_networkx_edges(G_pred3, pos_pred3, edge_color='gray', 
                      arrows=True, arrowsize=15, arrowstyle='->', 
                      connectionstyle='arc3,rad=0.1', width=1.5, ax=axes[0])
nx.draw_networkx_labels(G_pred3, pos_pred3, 
                       {i: feature_names3[i] for i in range(len(feature_names3))},
                       font_size=9, font_weight='bold', ax=axes[0])
axes[0].set_title(f'Discovered Causal Graph (CLEANN)\n({int(learned_adj3.sum())} edges)', 
                 fontsize=14, fontweight='bold')
axes[0].axis('off')

# Adjacency heatmap
sns.heatmap(learned_adj3, cmap='Blues', square=True, ax=axes[1],
           xticklabels=feature_names3, yticklabels=feature_names3,
           cbar_kws={'label': 'Edge Presence'})
axes[1].set_title('Adjacency Matrix', fontsize=14, fontweight='bold')
axes[1].set_xlabel('To Feature')
axes[1].set_ylabel('From Feature')

plt.suptitle('Dataset 3: Pima Indians Diabetes (Real-world)', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\nGraph statistics:")
print(f"  Nodes: {len(feature_names3)}")
print(f"  Edges: {int(learned_adj3.sum())}")
print(f"  Sparsity: {1 - learned_adj3.sum() / (len(feature_names3) * (len(feature_names3) - 1)):.2%}")
print(f"  Avg in-degree: {learned_adj3.sum(axis=0).mean():.2f}")
print(f"  Avg out-degree: {learned_adj3.sum(axis=1).mean():.2f}")

## 7. Conditional Independence Test Statistics

CLEANN decides which edges to keep using conditional independence (CI) tests. Let's inspect the test statistics for Dataset 1 and see how the p-value threshold affects the number of learned edges.
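The CI tests here are partial-correlation tests (`CondIndepParCorr`). A standard way to run such a test from a correlation matrix, and the one sketched below, is to read the partial correlation off the inverse of the relevant correlation submatrix and apply Fisher's z-transform; we assume, but have not verified, that causality-lab's implementation is equivalent. `parcorr_pvalue` is a hypothetical name. In PC-style algorithms such as ICD, an edge is removed when the p-value exceeds the threshold (independence is not rejected), so raising `p_val_th` generally removes fewer edges and yields a denser graph.

```python
import numpy as np
from scipy.stats import norm

def parcorr_pvalue(corr: np.ndarray, i: int, j: int, cond: list, n: int) -> float:
    """Two-sided p-value for H0: X_i is independent of X_j given X_cond (Gaussian data)."""
    idx = [i, j] + list(cond)
    # Partial correlation from the precision (inverse) of the correlation submatrix
    prec = np.linalg.inv(corr[np.ix_(idx, idx)])
    r = -prec[0, 1] / np.sqrt(prec[0, 0] * prec[1, 1])
    r = float(np.clip(r, -0.999999, 0.999999))  # keep arctanh finite
    # Fisher z-statistic, approximately standard normal under H0
    z = np.sqrt(n - len(cond) - 3) * np.arctanh(r)
    return float(2 * norm.sf(abs(z)))

# Correlations of an exact chain X0 -> X1 -> X2: corr(X0, X2) = corr(X0, X1) * corr(X1, X2)
corr = np.array([[1.00, 0.60, 0.36],
                 [0.60, 1.00, 0.60],
                 [0.36, 0.60, 1.00]])
n = 500
p_marginal = parcorr_pvalue(corr, 0, 2, [], n)   # tiny: X0 and X2 are dependent
p_given_x1 = parcorr_pvalue(corr, 0, 2, [1], n)  # about 1: independent given X1
alpha = 0.05
print(f"p(X0 _||_ X2)      = {p_marginal:.2e} -> keep edge: {p_marginal <= alpha}")
print(f"p(X0 _||_ X2 | X1) = {p_given_x1:.3f} -> keep edge: {p_given_x1 <= alpha}")
```

Because `num_records` enters through `sqrt(n - |cond| - 3)`, the `num_samples` passed to `CLEANN` directly controls how easily independence is rejected.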

In [None]:
# Get CI test statistics from CLEANN
print("Conditional Independence Test Statistics (Dataset 1):")
print("="*60)
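# Attribute names below are kept from the original notebook and may differ across
# causality-lab versions; getattr() prints 'n/a' instead of raising if one is missing
ci = cleann_model1.ci_test
print(f"Total CI tests performed: {getattr(ci, 'count', 'n/a')}")
print(f"P-value threshold: {getattr(ci, 'threshold', 'n/a')}")
print(f"Cache enabled: {getattr(ci, 'use_cache', 'n/a')}")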
print("="*60)

# Visualize p-value threshold impact
p_vals = [0.001, 0.01, 0.05, 0.1, 0.2]
edge_counts = []

print("\nExploring different p-value thresholds:")
for p_val in p_vals:
    cleann_temp = CLEANN(
        attention_matrix=aggregated_attention1,
        num_samples=X1.shape[0],
        p_val_th=p_val,
        search_minimal=True
    )
    temp_graph = cleann_temp.learn_graph()
    temp_adj = cleann_temp.get_adjacency_matrix()
    edge_count = int(temp_adj.sum())
    edge_counts.append(edge_count)
    print(f"  p-value = {p_val:.3f} -> {edge_count} edges")

# Plot p-value sensitivity
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(p_vals, edge_counts, marker='o', linewidth=2, markersize=10, 
        color='steelblue', label='Learned edges')
ax.axhline(y=int(true_adj1.sum()), color='green', linestyle='--', 
          linewidth=2, label='Ground truth edges')
ax.set_xlabel('P-value Threshold', fontsize=12, fontweight='bold')
ax.set_ylabel('Number of Edges', fontsize=12, fontweight='bold')
ax.set_title('Impact of P-value Threshold on Graph Sparsity', 
            fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## Summary

In this notebook, we demonstrated:

1. ✅ **Loaded pretrained TabPFN classifier** - A transformer-based foundation model for tabular data

2. ✅ **Loaded three tabular datasets**:
   - Synthetic Linear-Gaussian SCM with known ground truth
   - Synthetic Non-linear SCM with complex mechanisms
   - Real-world Diabetes dataset from OpenML

3. ✅ **Extracted attention patterns** - Aggregated attention matrices from TabPFN's transformer layers

4. ✅ **Applied CLEANN algorithm** - Learned causal graphs by:
   - Converting attention to correlation matrices
   - Using conditional independence tests (partial correlation)
   - Learning graph structure with the ICD (Iterative Causal Discovery) algorithm

5. ✅ **Generated minimal explanations** - Identified explanation sets for target features using PDS trees

6. ✅ **Visualized learned causal graphs** - Including:
   - Network visualizations comparing predicted vs ground truth
   - Adjacency matrix heatmaps
   - Evaluation metrics (SHD, F1, Precision, Recall)
   - Explanation set subgraphs

7. ✅ **Analyzed CI test statistics** - Explored impact of p-value threshold on graph sparsity

### Key Insights

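- **Attention-based causal discovery**: CLEANN feeds correlations derived from attention into a causal structure-learning algorithm
- **Minimal explanations**: PDS trees enable efficient search for minimal explanation sets
- **Statistical grounding**: Edges are kept or removed by partial-correlation CI tests at a chosen significance level; formal guarantees hold only under the structure-learning algorithm's assumptions (e.g. faithfulness and correct CI decisions)
- **Real-world applicability**: CLEANN yields candidate causal structures on real data, but without ground truth (as for the Diabetes dataset) they cannot be validated here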

### Comparison: CLEANN vs Attn-SCM

| Aspect | CLEANN | Attn-SCM |
|--------|--------|----------|
| **Approach** | Correlation + CI tests | Direct attention aggregation |
| **Algorithm** | ICD structure learning | Entropy-based head selection |
| **Output** | Causal graph + explanations | Causal graph only |
| **Statistical foundation** | Partial correlation tests | Thresholding + directionality |
| **Interpretability** | High (minimal explanations) | Medium (attention patterns) |

### References

- **CLEANN**: Rohekar et al. (2023) - [Causal Interpretation of Self-Attention in Pre-Trained Transformers](https://arxiv.org/abs/2310.20307)
- **Intel Labs Causality Lab**: https://github.com/IntelLabs/causality-lab
- **TabPFN**: Hollmann et al. (2023) - "TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second"

---

**Repository**: https://github.com/arberzela/CircuitPFN
