# TabPFN Causal Discovery with Attn-SCM

This notebook demonstrates:
1. Loading a pretrained TabPFN classifier model
2. Loading small tabular datasets for causal graph discovery
3. Applying the **Attn-SCM** algorithm to extract causal graphs from TabPFN attention patterns
4. Visualizing attention weights across layers and heads
5. Plotting the extracted causal graph as a DAG

**Paper**: Zero-Shot Causal Graph Extraction from Tabular Foundation Models via Attention Map Decoding

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/arberzela/CircuitPFN/blob/main/notebooks/colab_tabpfn_causal_discovery.ipynb)

## Setup & Installation

First, let's install all required dependencies.

In [None]:
# Install dependencies
!pip install torch numpy scipy scikit-learn pandas matplotlib seaborn networkx
!pip install tabpfn
!pip install openml

# Clone the CircuitPFN repository (if running in Colab)
import os
if not os.path.exists('CircuitPFN'):
    !git clone https://github.com/arberzela/CircuitPFN.git
else:
    print("Repository already exists")

# Enter the repository in either case so the local imports below resolve
%cd CircuitPFN



## Imports

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from tabpfn import TabPFNClassifier

# Import from the repository
import sys
sys.path.append('.')

from attn_scm.core import AttnSCM
from attn_scm.attention import AttentionExtractor
from utils.data_generation import generate_synthetic_dataset
from utils.visualization import plot_adjacency_matrix, plot_graph_network
from attn_scm.metrics import compute_graph_metrics

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

# Set random seed for reproducibility
np.random.seed(42)

print("✓ All imports successful!")

✓ All imports successful!


## 1. Load Pretrained TabPFN Classifier

TabPFN (Tabular Prior-Data Fitted Network) is a transformer-based model pre-trained on synthetic tabular data.
It performs in-context learning and can make predictions on new datasets without additional training.

In [None]:
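# Initialize pretrained TabPFN classifier.
# N_ensemble_configurations is the TabPFN v1 (0.1.x) argument name;
# TabPFN v2 replaced it with n_estimators.
tabpfn_model = TabPFNClassifier(device='cpu', N_ensemble_configurations=4)

print("✓ TabPFN model loaded successfully!")
print(f"  Device: {tabpfn_model.device}")
# The limits below are hard-coded here, not read from the model
print("  Max samples: 1000")
print("  Max features: 100")
print("\nTabPFN uses pre-trained weights from meta-learning on synthetic datasets.")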

## 2. Load Small Tabular Datasets for Causal Discovery

We'll use three datasets (two generated, one downloaded):
- **Dataset 1**: Synthetic data with known causal structure (Linear-Gaussian SCM)
- **Dataset 2**: Synthetic data with non-linear relationships
- **Dataset 3**: Real-world dataset from OpenML
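
To make "known causal structure" concrete, here is a minimal sketch of how a linear-Gaussian SCM can be sampled with NumPy. It is not the repository's `generate_synthetic_dataset`: the function name `sample_linear_gaussian_scm`, the weight range, and the convention `adj[i, j] = 1` meaning "i causes j" are assumptions made for illustration, and no target `y` is produced.

```python
import numpy as np

def sample_linear_gaussian_scm(n_nodes=10, n_samples=500, edge_prob=0.3, seed=42):
    """Hypothetical stand-in for the linear-Gaussian case (illustration only)."""
    rng = np.random.default_rng(seed)

    # Random DAG: keep only entries above the diagonal, so node i can only
    # point to nodes j > i. adj[i, j] = 1 is read as "i causes j" (assumed).
    adj = np.triu(rng.random((n_nodes, n_nodes)) < edge_prob, k=1).astype(int)

    # Edge weights bounded away from zero, with random sign (assumed range).
    magnitude = rng.uniform(0.5, 2.0, size=adj.shape)
    sign = rng.choice([-1.0, 1.0], size=adj.shape)
    weights = adj * magnitude * sign

    # Ancestral sampling: the index order is already a topological order,
    # so each column only depends on columns that were filled earlier.
    X = np.zeros((n_samples, n_nodes))
    for j in range(n_nodes):
        noise = rng.normal(size=n_samples)
        X[:, j] = X[:, :j] @ weights[:j, j] + noise

    return X, adj

X_demo, adj_demo = sample_linear_gaussian_scm()
print("Data shape:", X_demo.shape)
print("Number of edges:", int(adj_demo.sum()))
```

Because every variable is a weighted sum of its parents plus independent Gaussian noise, the adjacency matrix returned alongside the data is the exact ground truth that a discovery method should recover.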

### Dataset 1: Synthetic Linear-Gaussian SCM

In [None]:
# Generate synthetic dataset with known causal graph
X1, y1, true_adj1 = generate_synthetic_dataset(
    n_nodes=10,              # 10 features
    n_samples=500,           # 500 samples
    edge_prob=0.3,           # 30% edge probability
    scm_type='linear_gaussian',  # Linear Gaussian SCM
    seed=42
)

feature_names1 = [f'X{i}' for i in range(10)]

print("Dataset 1: Synthetic Linear-Gaussian SCM")
print(f"  Shape: {X1.shape}")
print(f"  Features: {X1.shape[1]}")
print(f"  Samples: {X1.shape[0]}")
print(f"  True edges: {int(true_adj1.sum())}")
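n_nodes1 = true_adj1.shape[0]
# Density as a fraction of all possible directed edges (self-loops excluded)
print(f"  Graph density: {true_adj1.sum() / (n_nodes1 * (n_nodes1 - 1)):.2%}")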

# Visualize ground truth
fig, ax = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap
sns.heatmap(true_adj1, cmap='Blues', square=True, ax=ax[0],
            xticklabels=feature_names1, yticklabels=feature_names1,
            cbar_kws={'label': 'Edge'})
ax[0].set_title('Ground Truth Causal Graph (Adjacency Matrix)', fontweight='bold')
ax[0].set_xlabel('Target')
ax[0].set_ylabel('Source')

# Network graph
G1 = nx.DiGraph(true_adj1)
pos1 = nx.spring_layout(G1, seed=42, k=2)
nx.draw_networkx_nodes(G1, pos1, node_color='lightblue', node_size=800, ax=ax[1])
nx.draw_networkx_edges(G1, pos1, edge_color='gray', arrows=True, 
                       arrowsize=15, arrowstyle='->', ax=ax[1])
nx.draw_networkx_labels(G1, pos1, {i: feature_names1[i] for i in range(10)}, 
                       font_size=10, font_weight='bold', ax=ax[1])
ax[1].set_title('Ground Truth Causal DAG', fontweight='bold')
ax[1].axis('off')

plt.tight_layout()
plt.show()

### Dataset 2: Synthetic Non-linear SCM

In [None]:
# Generate non-linear dataset
X2, y2, true_adj2 = generate_synthetic_dataset(
    n_nodes=8,
    n_samples=400,
    edge_prob=0.35,
    scm_type='nonlinear_anm',  # Non-linear Additive Noise Model
    seed=123
)

feature_names2 = [f'F{i}' for i in range(8)]

print("Dataset 2: Synthetic Non-linear SCM")
print(f"  Shape: {X2.shape}")
print(f"  True edges: {int(true_adj2.sum())}")
print(f"  Mechanism: Polynomial + Sigmoid functions")

### Dataset 3: Real-world Dataset (OpenML)

In [None]:
import openml
from sklearn.preprocessing import LabelEncoder

# Load a small real-world dataset from OpenML
# Using diabetes dataset (37) which is small and well-known
dataset = openml.datasets.get_dataset(37)  # Pima Indians Diabetes
X3_full, y3_full, categorical, feature_names3 = dataset.get_data(
    dataset_format='array',
    target=dataset.default_target_attribute
)

# Encode labels
le = LabelEncoder()
y3_full = le.fit_transform(y3_full)

# Keep the first 500 rows for speed (the full dataset has 768 rows,
# which is already below TabPFN's 1000-sample limit)
X3 = X3_full[:500]
y3 = y3_full[:500]

print("Dataset 3: Real-world (Pima Indians Diabetes)")
print(f"  Shape: {X3.shape}")
print(f"  Features: {list(feature_names3)}")
print(f"  Classes: {np.unique(y3)}")
print(f"  Note: Ground truth causal graph is unknown for real-world data")

## 3. Apply Attn-SCM Algorithm to TabPFN

The **Attn-SCM** (Attention-based Structural Causal Model) algorithm works in 4 steps:

1. **Extract attention maps** from TabPFN's transformer layers
2. **Identify structural heads** using entropy-based filtering
3. **Aggregate attention** into a raw adjacency matrix
4. **Post-process** with thresholding and directionality enforcement
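
Steps 2 and 3 can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the code in `attn_scm`: it assumes attention is stored as a dict mapping layer index to an array of shape `(heads, d, d)` (the layout used in the visualization cells of this notebook), that "structural" heads are the ones with the *lowest* row-wise entropy, and that aggregation is a plain mean. The helper names `head_entropy` and `aggregate_structural_heads` are hypothetical.

```python
import numpy as np

def head_entropy(attn, eps=1e-12):
    """Mean row-wise Shannon entropy of one (d, d) attention map."""
    p = attn / (attn.sum(axis=1, keepdims=True) + eps)  # renormalise each row
    return float(-(p * np.log(p + eps)).sum(axis=1).mean())

def aggregate_structural_heads(attention_dict, top_k=5):
    """Select the top_k lowest-entropy heads and average their maps.

    Assumption: low entropy (focused attention) marks a structural head.
    """
    scored = [
        (head_entropy(layer_attn[h]), layer, h)
        for layer, layer_attn in attention_dict.items()
        for h in range(layer_attn.shape[0])
    ]
    scored.sort()  # ascending entropy: most focused heads first
    selected = [(layer, h) for _, layer, h in scored[:top_k]]
    raw = np.mean([attention_dict[layer][h] for layer, h in selected], axis=0)
    return selected, raw

# Toy example: one sharply peaked head among uniform ones
d = 4
uniform = np.full((d, d), 1.0 / d)
peaked = np.full((d, d), 0.01) + 0.96 * np.eye(d)
toy_attention = {0: np.stack([uniform, peaked]), 1: np.stack([uniform, uniform])}

selected, raw = aggregate_structural_heads(toy_attention, top_k=1)
print("Selected (layer, head):", selected)
```

In the toy example the peaked head in layer 0 is selected because its rows concentrate almost all mass on one entry, while the uniform heads have the maximum possible entropy, log(d).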

### 3.1 Apply to Dataset 1 (Linear-Gaussian)

In [None]:
# Initialize Attn-SCM model
attnscm_model1 = AttnSCM(
    top_k_heads=5,                    # Select top 5 structural heads
    threshold_method='otsu',          # Automatic threshold selection (Otsu's method)
    directionality_method='asymmetry', # Enforce directed edges via asymmetry
    head_scoring_method='entropy',    # Score heads by entropy
    device='cpu'
)

print("Applying Attn-SCM to Dataset 1...")
print("="*60)

# Extract causal graph (this fits TabPFN and extracts attention)
pred_adj1 = attnscm_model1.fit(X1, y1, feature_names=feature_names1)

print("\n" + "="*60)
print("✓ Causal graph extraction complete!")
print(f"\nPredicted edges: {int((pred_adj1 > 0).sum())}")
print(f"True edges: {int(true_adj1.sum())}")

### 3.2 Apply to Dataset 2 (Non-linear)

In [None]:
attnscm_model2 = AttnSCM(
    top_k_heads=5,
    threshold_method='otsu',
    directionality_method='asymmetry',
    device='cpu'
)

print("Applying Attn-SCM to Dataset 2...")
pred_adj2 = attnscm_model2.fit(X2, y2, feature_names=feature_names2)
print(f"✓ Complete! Predicted edges: {int((pred_adj2 > 0).sum())}")

### 3.3 Apply to Dataset 3 (Real-world)

In [None]:
attnscm_model3 = AttnSCM(
    top_k_heads=5,
    threshold_method='otsu',
    directionality_method='asymmetry',
    device='cpu'
)

print("Applying Attn-SCM to Dataset 3 (Diabetes)...")
pred_adj3 = attnscm_model3.fit(X3, y3, feature_names=list(feature_names3))
print(f"✓ Complete! Predicted edges: {int((pred_adj3 > 0).sum())}")

## 4. Visualize Attention Weights

Let's visualize the attention patterns extracted from TabPFN at different layers and heads.

In [None]:
# Visualize attention maps from selected structural heads
print("Selected Structural Heads:")
print(attnscm_model1.structural_heads_)

# Plot attention from top 4 structural heads
n_heads_to_plot = min(4, len(attnscm_model1.structural_heads_))
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.ravel()

for idx, (layer_idx, head_idx) in enumerate(attnscm_model1.structural_heads_[:n_heads_to_plot]):
    # Get attention for this layer and head
    layer_attn = attnscm_model1.attention_dict_[layer_idx]
    
    # layer_attn shape: (heads, d, d)
    if head_idx < layer_attn.shape[0]:
        head_attn = layer_attn[head_idx]  # (d, d)
        
        # Plot
        sns.heatmap(head_attn, cmap='viridis', square=True, ax=axes[idx],
                   xticklabels=feature_names1, yticklabels=feature_names1,
                   cbar_kws={'label': 'Attention Weight'})
        axes[idx].set_title(f'Layer {layer_idx}, Head {head_idx}\n(Structural Head #{idx+1})', 
                          fontweight='bold')
        axes[idx].set_xlabel('To Feature')
        axes[idx].set_ylabel('From Feature')

plt.suptitle('Attention Weights from Top Structural Heads (Dataset 1)', 
             fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

### Raw Aggregated Attention (before post-processing)

In [None]:
# Visualize raw aggregated attention (before thresholding)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Raw attention
sns.heatmap(attnscm_model1.adjacency_raw_, cmap='YlOrRd', square=True, ax=axes[0],
           xticklabels=feature_names1, yticklabels=feature_names1,
           cbar_kws={'label': 'Attention Score'})
axes[0].set_title('Raw Aggregated Attention\n(Before Thresholding)', fontweight='bold')
axes[0].set_xlabel('To')
axes[0].set_ylabel('From')

# After thresholding
sns.heatmap(pred_adj1, cmap='Blues', square=True, ax=axes[1],
           xticklabels=feature_names1, yticklabels=feature_names1,
           cbar_kws={'label': 'Edge Presence'})
axes[1].set_title('After Thresholding & Directionality\n(Final Prediction)', fontweight='bold')
axes[1].set_xlabel('To')
axes[1].set_ylabel('From')

# Ground truth
sns.heatmap(true_adj1, cmap='Greens', square=True, ax=axes[2],
           xticklabels=feature_names1, yticklabels=feature_names1,
           cbar_kws={'label': 'Edge Presence'})
axes[2].set_title('Ground Truth', fontweight='bold')
axes[2].set_xlabel('To')
axes[2].set_ylabel('From')

plt.tight_layout()
plt.show()
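
To show what separates the left panel from the middle one, here is a hedged sketch of the post-processing step: Otsu's method picks the threshold that maximises the between-class variance of the off-diagonal scores, and an asymmetry rule keeps only the stronger direction of each node pair. The function names `otsu_threshold` and `postprocess`, the tie-breaking rule, and the reading of `raw[i, j]` as the score for "i causes j" are assumptions; the `AttnSCM` implementation may differ.

```python
import numpy as np

def otsu_threshold(values):
    """Exact Otsu threshold for a 1-D array (no histogram binning)."""
    v = np.sort(np.asarray(values, dtype=float).ravel())
    best_t, best_var = v[0], -1.0
    for k in range(1, len(v)):
        if v[k] == v[k - 1]:
            continue  # equal values cannot be split
        w0 = k / len(v)
        between = w0 * (1 - w0) * (v[:k].mean() - v[k:].mean()) ** 2
        if between > best_var:
            best_var, best_t = between, (v[k - 1] + v[k]) / 2
    return float(best_t)

def postprocess(raw):
    """Threshold off-diagonal scores, then keep the stronger direction per pair."""
    raw = np.asarray(raw, dtype=float)
    d = raw.shape[0]
    off_diag = ~np.eye(d, dtype=bool)
    t = otsu_threshold(raw[off_diag])
    kept = np.where(off_diag & (raw > t), raw, 0.0)
    # Asymmetry rule: drop the weaker of i->j and j->i.
    # Exact ties are broken by dropping the lower-triangle entry (assumed rule).
    lower = np.tril(np.ones((d, d), dtype=bool), k=-1)
    drop = (kept < kept.T) | ((kept == kept.T) & lower)
    kept[drop] = 0.0
    return kept, t

raw_demo = np.array([[0.0, 0.9, 0.1],
                     [0.7, 0.0, 0.1],
                     [0.1, 0.1, 0.0]])
adj_demo, t_demo = postprocess(raw_demo)
print("Edges kept:", np.argwhere(adj_demo > 0).tolist())  # [[0, 1]]
```

Both 0.9 and 0.7 survive the threshold, but they connect the same pair of nodes, so only the larger score remains as a directed edge.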

### Layer-wise Attention Comparison

In [None]:
# Compare attention patterns across different layers
layer_indices = sorted(attnscm_model1.attention_dict_.keys())[:3]

fig, axes = plt.subplots(1, len(layer_indices), figsize=(6*len(layer_indices), 5))

for idx, layer_idx in enumerate(layer_indices):
    # Average attention across all heads in this layer
    layer_attn = attnscm_model1.attention_dict_[layer_idx]
    avg_attn = layer_attn.mean(axis=0)  # Average over heads
    
    ax = axes[idx] if len(layer_indices) > 1 else axes
    sns.heatmap(avg_attn, cmap='RdYlBu_r', square=True, ax=ax,
               xticklabels=feature_names1, yticklabels=feature_names1,
               cbar_kws={'label': 'Avg Attention'})
    ax.set_title(f'Layer {layer_idx}\n(Averaged over all heads)', fontweight='bold')
    ax.set_xlabel('To Feature')
    ax.set_ylabel('From Feature')

plt.suptitle('Layer-wise Attention Patterns', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 5. Plot Extracted Causal Graph DAGs

Now let's visualize the extracted causal graphs as network diagrams.
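
Before reading a plotted graph as a DAG, it is worth checking that it really is acyclic: keeping only the stronger direction of each node pair removes two-node cycles, but longer cycles can survive. The sketch below uses Kahn's algorithm on an adjacency matrix. The helper name `is_acyclic` is hypothetical and `adj[i, j] != 0` is assumed to mean an edge from `i` to `j`.

```python
import numpy as np

def is_acyclic(adj):
    """Kahn's algorithm: repeatedly remove nodes that have no incoming edges."""
    a = np.asarray(adj) != 0
    remaining = list(range(a.shape[0]))
    while remaining:
        sub = a[np.ix_(remaining, remaining)]
        in_degree = sub.sum(axis=0)  # column sums = number of parents
        sources = {n for n, deg in zip(remaining, in_degree) if deg == 0}
        if not sources:
            return False  # every remaining node has a parent, so there is a cycle
        remaining = [n for n in remaining if n not in sources]
    return True

chain = np.array([[0, 1, 0],
                  [0, 0, 1],
                  [0, 0, 0]])   # 0 -> 1 -> 2
triangle = np.array([[0, 1, 0],
                     [0, 0, 1],
                     [1, 0, 0]])  # 0 -> 1 -> 2 -> 0, no pair has both directions

print(is_acyclic(chain))     # True
print(is_acyclic(triangle))  # False
```

With the graphs built in the cells below, `nx.is_directed_acyclic_graph(G_pred)` performs the same check directly on the NetworkX object.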

### 5.1 Dataset 1: Predicted vs Ground Truth

In [None]:
# Compare predicted and true graphs
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Ground Truth
G_true = nx.DiGraph((true_adj1 > 0).astype(int))
pos_true = nx.spring_layout(G_true, seed=42, k=2, iterations=50)

nx.draw_networkx_nodes(G_true, pos_true, node_color='lightgreen', 
                      node_size=1200, ax=axes[0])
nx.draw_networkx_edges(G_true, pos_true, edge_color='darkgreen', 
                      arrows=True, arrowsize=20, arrowstyle='->', 
                      connectionstyle='arc3,rad=0.1', width=2, ax=axes[0])
nx.draw_networkx_labels(G_true, pos_true, 
                       {i: feature_names1[i] for i in range(10)},
                       font_size=11, font_weight='bold', ax=axes[0])
axes[0].set_title(f'Ground Truth Causal Graph\n({int(true_adj1.sum())} edges)', 
                 fontsize=14, fontweight='bold')
axes[0].axis('off')

# Predicted
G_pred = nx.DiGraph((pred_adj1 > 0).astype(int))
# Reuse the ground-truth layout so each node sits in the same place in both panels
pos_pred = pos_true

nx.draw_networkx_nodes(G_pred, pos_pred, node_color='lightcoral', 
                      node_size=1200, ax=axes[1])
nx.draw_networkx_edges(G_pred, pos_pred, edge_color='darkred', 
                      arrows=True, arrowsize=20, arrowstyle='->', 
                      connectionstyle='arc3,rad=0.1', width=2, ax=axes[1])
nx.draw_networkx_labels(G_pred, pos_pred, 
                       {i: feature_names1[i] for i in range(10)},
                       font_size=11, font_weight='bold', ax=axes[1])
axes[1].set_title(f'Predicted Causal Graph (Attn-SCM)\n({int((pred_adj1 > 0).sum())} edges)', 
                 fontsize=14, fontweight='bold')
axes[1].axis('off')

plt.suptitle('Dataset 1: Linear-Gaussian SCM', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

### Evaluation Metrics

In [None]:
# Compute evaluation metrics
metrics1 = compute_graph_metrics(pred_adj1, true_adj1)

print("="*60)
print("Dataset 1 - Evaluation Metrics")
print("="*60)
print(f"Structural Hamming Distance (SHD): {metrics1['shd']}")
print(f"F1 Score (directed):               {metrics1['f1_directed']:.3f}")
print(f"Precision:                         {metrics1['precision']:.3f}")
print(f"Recall:                            {metrics1['recall']:.3f}")
print(f"True Positive Edges:               {metrics1['tp']}")
print(f"False Positive Edges:              {metrics1['fp']}")
print(f"False Negative Edges:              {metrics1['fn']}")
print("="*60)

# Visualize metrics
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

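# Confusion matrix over directed node pairs (rows: truth, columns: prediction).
# TN is derived from the number of off-diagonal pairs, assuming tp/fp/fn
# count directed edges and ignore self-loops.
n_pairs = true_adj1.shape[0] * (true_adj1.shape[0] - 1)
tn = n_pairs - metrics1['tp'] - metrics1['fp'] - metrics1['fn']
confusion_data = np.array([[metrics1['tp'], metrics1['fn']],
                           [metrics1['fp'], tn]]).astype(int)
sns.heatmap(confusion_data, annot=True, fmt='d', cmap='Blues', ax=axes[0],
           xticklabels=['Predicted Edge', 'Predicted No Edge'],
           yticklabels=['True Edge', 'True No Edge'])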
axes[0].set_title('Edge Prediction Confusion', fontweight='bold')

# Bar chart of metrics
metrics_data = [metrics1['precision'], metrics1['recall'], metrics1['f1_directed']]
axes[1].bar(['Precision', 'Recall', 'F1 Score'], metrics_data, 
           color=['steelblue', 'coral', 'seagreen'])
axes[1].set_ylim([0, 1])
axes[1].set_ylabel('Score')
axes[1].set_title('Performance Metrics', fontweight='bold')
axes[1].grid(axis='y', alpha=0.3)

for i, v in enumerate(metrics_data):
    axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontweight='bold')

plt.tight_layout()
plt.show()
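
For readers who want to see how such numbers can be derived, here is a minimal NumPy sketch of directed-edge metrics. It is not the repository's `compute_graph_metrics`: the function name `directed_edge_metrics` is hypothetical, and the SHD convention used here (a reversed edge counts once) is one common choice that may differ from the one used above.

```python
import numpy as np

def directed_edge_metrics(pred_adj, true_adj):
    """Precision, recall, F1 and SHD for directed graphs (self-loops ignored)."""
    pred = np.asarray(pred_adj) != 0
    true = np.asarray(true_adj) != 0
    np.fill_diagonal(pred, False)
    np.fill_diagonal(true, False)

    tp = int((pred & true).sum())    # edge predicted with the correct direction
    fp = int((pred & ~true).sum())   # predicted edge absent from the truth
    fn = int((~pred & true).sum())   # true edge that was not predicted

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # SHD: number of unordered node pairs whose edge status differs,
    # so a missing, extra, or reversed edge each counts once.
    mismatch = pred != true
    shd = int(np.triu(mismatch | mismatch.T, k=1).sum())

    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision,
            "recall": recall, "f1_directed": f1, "shd": shd}

true_demo = np.array([[0, 1, 0],
                      [0, 0, 1],
                      [0, 0, 0]])   # 0 -> 1, 1 -> 2
pred_demo = np.array([[0, 0, 1],
                      [1, 0, 1],
                      [0, 0, 0]])   # 0 -> 2 (extra), 1 -> 0 (reversed), 1 -> 2 (correct)

demo_metrics = directed_edge_metrics(pred_demo, true_demo)
print(demo_metrics)
```

In the toy example the reversed edge shows up as one false positive and one false negative in the directed counts, but contributes only 1 to the SHD.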

### 5.2 Dataset 2: Non-linear SCM

In [None]:
# Plot Dataset 2 graphs
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Ground Truth
G_true2 = nx.DiGraph((true_adj2 > 0).astype(int))
pos_true2 = nx.spring_layout(G_true2, seed=123, k=2)
nx.draw(G_true2, pos_true2, with_labels=True, 
        labels={i: feature_names2[i] for i in range(8)},
        node_color='lightgreen', node_size=1200, 
        edge_color='darkgreen', width=2, arrows=True, 
        arrowsize=20, font_weight='bold', ax=axes[0])
axes[0].set_title(f'Ground Truth\n({int(true_adj2.sum())} edges)', 
                 fontsize=14, fontweight='bold')

# Predicted
G_pred2 = nx.DiGraph((pred_adj2 > 0).astype(int))
pos_pred2 = nx.spring_layout(G_pred2, seed=123, k=2)
nx.draw(G_pred2, pos_pred2, with_labels=True,
        labels={i: feature_names2[i] for i in range(8)},
        node_color='lightcoral', node_size=1200,
        edge_color='darkred', width=2, arrows=True,
        arrowsize=20, font_weight='bold', ax=axes[1])
axes[1].set_title(f'Predicted (Attn-SCM)\n({int((pred_adj2 > 0).sum())} edges)', 
                 fontsize=14, fontweight='bold')

plt.suptitle('Dataset 2: Non-linear SCM', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Metrics
metrics2 = compute_graph_metrics(pred_adj2, true_adj2)
print(f"\nDataset 2 Metrics: SHD={metrics2['shd']}, F1={metrics2['f1_directed']:.3f}, "
      f"Precision={metrics2['precision']:.3f}, Recall={metrics2['recall']:.3f}")

### 5.3 Dataset 3: Real-world (Diabetes)

In [None]:
# Plot real-world dataset graph
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# Network graph
G_pred3 = nx.DiGraph((pred_adj3 > 0).astype(int))
pos_pred3 = nx.spring_layout(G_pred3, seed=42, k=3, iterations=50)

nx.draw_networkx_nodes(G_pred3, pos_pred3, node_color='skyblue', 
                      node_size=1500, ax=axes[0])
nx.draw_networkx_edges(G_pred3, pos_pred3, edge_color='gray', 
                      arrows=True, arrowsize=15, arrowstyle='->', 
                      connectionstyle='arc3,rad=0.1', width=1.5, ax=axes[0])
nx.draw_networkx_labels(G_pred3, pos_pred3, 
                       {i: feature_names3[i] for i in range(len(feature_names3))},
                       font_size=9, font_weight='bold', ax=axes[0])
axes[0].set_title(f'Discovered Causal Graph\n({int((pred_adj3 > 0).sum())} edges)', 
                 fontsize=14, fontweight='bold')
axes[0].axis('off')

# Adjacency heatmap
sns.heatmap(pred_adj3, cmap='Blues', square=True, ax=axes[1],
           xticklabels=feature_names3, yticklabels=feature_names3,
           cbar_kws={'label': 'Edge Weight'})
axes[1].set_title('Adjacency Matrix', fontsize=14, fontweight='bold')
axes[1].set_xlabel('To Feature')
axes[1].set_ylabel('From Feature')

plt.suptitle('Dataset 3: Pima Indians Diabetes (Real-world)', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\nGraph statistics:")
print(f"  Nodes: {len(feature_names3)}")
print(f"  Edges: {int((pred_adj3 > 0).sum())}")
print(f"  Sparsity: {1 - (pred_adj3 > 0).sum() / (len(feature_names3)**2):.2%}")
print(f"  Avg in-degree: {(pred_adj3 > 0).sum(axis=0).mean():.2f}")
print(f"  Avg out-degree: {(pred_adj3 > 0).sum(axis=1).mean():.2f}")

### Markov Blanket Analysis

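The Markov blanket of a variable consists of its parents, its children, and its children's other parents.
Conditioned on its Markov blanket, a variable is independent of every other variable in the graph, which makes the blanket a natural candidate set for feature selection.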
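
The definition translates directly into a few array lookups. The sketch below is an illustration, not the `get_markov_blanket` method used in the next cell: the function name `markov_blanket` is hypothetical and `adj[i, j] != 0` is assumed to mean an edge from `i` to `j`.

```python
import numpy as np

def markov_blanket(adj, node):
    """Parents, children, and co-parents of `node` in a directed graph."""
    a = np.asarray(adj) != 0
    parents = set(np.flatnonzero(a[:, node]).tolist())   # column: incoming edges
    children = set(np.flatnonzero(a[node, :]).tolist())  # row: outgoing edges
    spouses = set()
    for child in children:
        spouses |= set(np.flatnonzero(a[:, child]).tolist())
    return sorted((parents | children | spouses) - {node})

# Toy graph: 0 -> 2, 1 -> 2, 2 -> 3, 4 -> 3, and node 5 is isolated
toy_adj = np.zeros((6, 6), dtype=int)
for src, dst in [(0, 2), (1, 2), (2, 3), (4, 3)]:
    toy_adj[src, dst] = 1

print(markov_blanket(toy_adj, 2))  # [0, 1, 3, 4]
print(markov_blanket(toy_adj, 0))  # [1, 2]
print(markov_blanket(toy_adj, 5))  # []
```

Node 4 belongs to the blanket of node 2 even though no edge connects them, because both are parents of node 3.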

In [None]:
# Get Markov Blanket for each feature in Dataset 3
print("Markov Blankets for Real-world Dataset Features:")
print("="*60)

for i, feature_name in enumerate(feature_names3):
    mb = attnscm_model3.get_markov_blanket(i)
    mb_names = [feature_names3[j] for j in mb]
    print(f"{feature_name:20s} -> MB size: {len(mb):2d} | {mb_names}")

print("="*60)

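# Visualize the Markov blanket of the glucose feature.
# OpenML dataset 37 names this column 'plas' (plasma glucose concentration);
# 'glucose' is accepted too in case the feature names differ.
names3 = list(feature_names3)
glucose_idx = next(i for i, name in enumerate(names3) if name in ('plas', 'glucose'))
mb_glucose = list(attnscm_model3.get_markov_blanket(glucose_idx))

# Create subgraph
nodes_to_show = [glucose_idx] + mb_glucose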
subgraph_adj = pred_adj3[np.ix_(nodes_to_show, nodes_to_show)]
subgraph_names = [feature_names3[i] for i in nodes_to_show]

fig, ax = plt.subplots(figsize=(10, 8))
G_sub = nx.DiGraph((subgraph_adj > 0).astype(int))
pos_sub = nx.spring_layout(G_sub, seed=42, k=2)

# Color nodes: target in orange, MB in lightblue
node_colors = ['orange'] + ['lightblue'] * len(mb_glucose)

nx.draw_networkx_nodes(G_sub, pos_sub, node_color=node_colors, 
                      node_size=1500, ax=ax)
nx.draw_networkx_edges(G_sub, pos_sub, edge_color='gray', 
                      arrows=True, arrowsize=20, arrowstyle='->', 
                      width=2, ax=ax)
nx.draw_networkx_labels(G_sub, pos_sub, 
                       {i: subgraph_names[i] for i in range(len(subgraph_names))},
                       font_size=11, font_weight='bold', ax=ax)

ax.set_title(f"Markov Blanket of 'glucose'\n(Orange = Target, Blue = MB)", 
            fontsize=14, fontweight='bold')
ax.axis('off')
plt.tight_layout()
plt.show()

## Summary

In this notebook, we demonstrated:

1. ✅ **Loaded pretrained TabPFN classifier** - A transformer-based foundation model for tabular data

2. ✅ **Loaded three tabular datasets**:
   - Synthetic Linear-Gaussian SCM with known ground truth
   - Synthetic Non-linear SCM with complex mechanisms
   - Real-world Diabetes dataset from OpenML

3. ✅ **Applied Attn-SCM algorithm** - Extracted causal graphs by:
   - Extracting attention maps from TabPFN's transformer layers
   - Identifying structural heads using entropy-based scoring
   - Aggregating attention into adjacency matrices
   - Post-processing with thresholding and directionality

4. ✅ **Visualized attention weights** - Showed:
   - Individual attention heads (layer + head combinations)
   - Raw aggregated attention before post-processing
   - Layer-wise attention patterns

5. ✅ **Plotted causal graph DAGs** - Including:
   - Network visualizations comparing predicted vs ground truth
   - Adjacency matrix heatmaps
   - Evaluation metrics (SHD, F1, Precision, Recall)
   - Markov Blanket analysis for feature selection

### Key Insights

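- **Zero-shot causal discovery**: Attn-SCM extracts a graph from a pretrained TabPFN without any gradient-based training on the target dataset
- **Interpretable attention**: The selected structural heads can be inspected directly as heatmaps and compared with the ground-truth graph on the synthetic datasets
- **Fast inference**: Extraction relies on TabPFN forward passes plus lightweight post-processing rather than a search over graphs; runtime against constraint-based methods is not measured in this notebook
- **Real-world applicability**: The method runs on real data such as the diabetes dataset, where the result cannot be scored because the true graph is unknown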

### References

- **TabPFN**: Hollmann et al. (2023) - "TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second"
- **Attn-SCM**: Zero-Shot Causal Graph Extraction from Tabular Foundation Models via Attention Map Decoding

---

**Repository**: https://github.com/arberzela/CircuitPFN  
**Paper**: [Link to paper]
