# Causality Lab Integration - Quick Start

This notebook demonstrates how to run Intel Labs causality-lab structure-learning algorithms (RAI, FCI, ICD) alongside TabPFN attention-based causal discovery (AttnSCM) on tabular data.

We'll cover:
1. Basic causal discovery with causality-lab algorithms
2. Comparing different methods
3. TabPFN-enhanced causal discovery
4. Applying the methods to a real-world dataset (Pima Indians Diabetes from OpenML)

In [None]:
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from attn_scm.core import AttnSCM
from attn_scm.metrics import compute_graph_metrics
from baselines import CAUSALITY_LAB_AVAILABLE, run_rai, run_fci, run_icd
from utils import generate_synthetic_dataset, visualize_graph

# Report whether the optional causality-lab backend could be imported;
# every causality-lab section below is guarded on this flag.
print(f"Causality Lab available: {CAUSALITY_LAB_AVAILABLE}")

# Seed NumPy's legacy global RNG so np.random-based code is reproducible.
SEED = 42
np.random.seed(SEED)

## 1. Generate Synthetic Data with Known Causal Structure

In [None]:
# Generate a synthetic dataset from a linear-Gaussian SCM with a known graph.
# `true_adj` is the ground-truth adjacency used to score every method below.
X, y, true_adj = generate_synthetic_dataset(
    n_nodes=10,
    n_samples=500,
    edge_prob=0.3,
    scm_type='linear_gaussian',
    seed=42
)

print(f"Data shape: {X.shape}")
print(f"Number of true edges: {true_adj.sum()}")

# Derive tick labels from the matrix itself so they stay correct if n_nodes changes
# (rows are plotted as sources, columns as targets).
row_labels = [f'X{i}' for i in range(true_adj.shape[0])]
col_labels = [f'X{j}' for j in range(true_adj.shape[1])]

# Visualize ground truth
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(true_adj, cmap='RdBu_r', center=0, square=True, ax=ax,
            xticklabels=col_labels,
            yticklabels=row_labels,
            cbar_kws={'label': 'Edge'})
ax.set_title('Ground Truth Causal Graph')
ax.set_xlabel('Target')
ax.set_ylabel('Source')
fig.tight_layout()
plt.show()

## 2. Causal Discovery with Different Methods

### 2.1 AttnSCM (Attention-based)

In [None]:
# AttnSCM derives a causal graph from TabPFN attention heads. With
# threshold_method='otsu' the returned adjacency is presumably already
# thresholded (TODO confirm in attn_scm.core).
model_attnscm = AttnSCM(
    top_k_heads=5,
    threshold_method='otsu',
    device='cpu',
)
adj_attnscm = model_attnscm.fit(X, y)

# Score the discovered graph against the ground truth from Section 1.
metrics_attnscm = compute_graph_metrics(adj_attnscm, true_adj)

print("AttnSCM Results:")
print(f"  SHD: {metrics_attnscm['shd']}")
print(f"  F1 (directed): {metrics_attnscm['f1_directed']:.3f}")
print(f"  Precision: {metrics_attnscm['precision']:.3f}")
print(f"  Recall: {metrics_attnscm['recall']:.3f}")
print(f"  Predicted edges: {adj_attnscm.sum()}")

### 2.2 RAI (Recursive Autonomy Identification)

In [None]:
if CAUSALITY_LAB_AVAILABLE:
    # Name features after the actual columns of X (not a hard-coded count) so
    # the names always line up with the data passed to the learner.
    feature_names = [f'X{i}' for i in range(X.shape[1])]

    # Run RAI
    adj_rai = run_rai(X, alpha=0.05, feature_names=feature_names)
    
    # Compute metrics
    metrics_rai = compute_graph_metrics(adj_rai, true_adj)
    
    print("RAI Results:")
    print(f"  SHD: {metrics_rai['shd']}")
    print(f"  F1 (directed): {metrics_rai['f1_directed']:.3f}")
    print(f"  Precision: {metrics_rai['precision']:.3f}")
    print(f"  Recall: {metrics_rai['recall']:.3f}")
    print(f"  Predicted edges: {adj_rai.sum()}")
else:
    print("Causality Lab not available. Skipping RAI.")

### 2.3 FCI (Fast Causal Inference)

In [None]:
if CAUSALITY_LAB_AVAILABLE:
    # Derive feature names here too, so this cell does not silently depend on
    # state left behind by the RAI cell (re-running it alone now works).
    feature_names = [f'X{i}' for i in range(X.shape[1])]

    # Run FCI (handles latent confounders)
    adj_fci = run_fci(X, alpha=0.05, feature_names=feature_names)
    
    # Compute metrics
    metrics_fci = compute_graph_metrics(adj_fci, true_adj)
    
    print("FCI Results:")
    print(f"  SHD: {metrics_fci['shd']}")
    print(f"  F1 (directed): {metrics_fci['f1_directed']:.3f}")
    print(f"  Precision: {metrics_fci['precision']:.3f}")
    print(f"  Recall: {metrics_fci['recall']:.3f}")
    print(f"  Predicted edges: {adj_fci.sum()}")
else:
    print("Causality Lab not available. Skipping FCI.")

### 2.4 ICD (Iterative Causal Discovery)

In [None]:
if CAUSALITY_LAB_AVAILABLE:
    # Derive feature names here too, so this cell does not silently depend on
    # state left behind by the RAI cell (re-running it alone now works).
    feature_names = [f'X{i}' for i in range(X.shape[1])]

    # Run ICD
    adj_icd = run_icd(X, alpha=0.05, feature_names=feature_names)
    
    # Compute metrics
    metrics_icd = compute_graph_metrics(adj_icd, true_adj)
    
    print("ICD Results:")
    print(f"  SHD: {metrics_icd['shd']}")
    print(f"  F1 (directed): {metrics_icd['f1_directed']:.3f}")
    print(f"  Precision: {metrics_icd['precision']:.3f}")
    print(f"  Recall: {metrics_icd['recall']:.3f}")
    print(f"  Predicted edges: {adj_icd.sum()}")
else:
    print("Causality Lab not available. Skipping ICD.")

## 3. Visual Comparison

In [None]:
# Side-by-side heatmaps of each adjacency matrix (rows = source, columns = target).
methods = ['Ground Truth', 'AttnSCM']
adjacencies = [true_adj, adj_attnscm]

if CAUSALITY_LAB_AVAILABLE:
    methods.extend(['RAI', 'FCI', 'ICD'])
    adjacencies.extend([adj_rai, adj_fci, adj_icd])

# squeeze=False always returns a 2-D array of Axes, so iterating over axes[0]
# works even if the method list ever shrinks to a single entry.
fig, axes = plt.subplots(1, len(methods), figsize=(5 * len(methods), 4), squeeze=False)

for ax, method, adj in zip(axes[0], methods, adjacencies):
    # Label ticks from each matrix's own shape instead of a hard-coded node count.
    sns.heatmap(adj, cmap='RdBu_r', center=0, square=True, ax=ax,
                xticklabels=[f'X{j}' for j in range(adj.shape[1])],
                yticklabels=[f'X{i}' for i in range(adj.shape[0])],
                cbar_kws={'label': 'Edge'})
    ax.set_title(method)
    ax.set_xlabel('Target')
    ax.set_ylabel('Source')

fig.tight_layout()
plt.show()

## 4. Performance Comparison

In [None]:
# Build one row of metrics per method (pandas is imported in the top import cell).
results = [{'Method': 'AttnSCM', **metrics_attnscm}]

if CAUSALITY_LAB_AVAILABLE:
    results.extend([
        {'Method': 'RAI', **metrics_rai},
        {'Method': 'FCI', **metrics_fci},
        {'Method': 'ICD', **metrics_icd},
    ])

df_results = pd.DataFrame(results)

# Show the headline columns in a fixed order. Any column that
# compute_graph_metrics does not return is skipped rather than raising KeyError.
summary_columns = ['Method', 'shd', 'f1_directed', 'precision', 'recall',
                   'n_predicted_edges', 'n_true_edges']
df_results = df_results[[col for col in summary_columns if col in df_results.columns]]

print("\nPerformance Comparison:")
print(df_results.to_string(index=False))

# Visualize F1 scores
fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(df_results['Method'], df_results['f1_directed'])
ax.set_ylabel('F1 Score (Directed)')
ax.set_title('Causal Discovery Performance Comparison')
ax.set_ylim(0, 1)
ax.grid(axis='y', alpha=0.3)
fig.tight_layout()
plt.show()

## 5. TabPFN-Enhanced Causal Discovery

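Now let's check whether TabPFN's attention patterns can improve a traditional constraint-based method. We rerun RAI with a conditional-independence test weighted by the AttnSCM attention matrix and compare its directed F1 with standard RAI.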

In [None]:
if CAUSALITY_LAB_AVAILABLE:
    # Re-run RAI with a conditional-independence (CI) test that is given the
    # AttnSCM attention scores, then compare the result with standard RAI (2.2).
    from baselines.tabpfn_causality_adapter import CondIndepAttentionWeighted
    from causality_lab.learn_structure import LearnStructRAI
    from causality_lab.data import Dataset
    
    # Extract attention matrix from TabPFN
    # binarize=False requests the un-thresholded scores from the model fitted in 2.1.
    attention_matrix = model_attnscm.get_adjacency_matrix(binarize=False)
    
    # Create attention-weighted conditional independence test
    # NOTE(review): threshold=0.05 presumably plays the role of alpha=0.05 used by
    # run_rai above (CI significance level), and the attention weights presumably
    # modulate that decision per variable pair - TODO confirm in tabpfn_causality_adapter.
    dataset = Dataset(X, var_names=feature_names)
    cond_indep_test = CondIndepAttentionWeighted(
        dataset,
        threshold=0.05,
        attention_weights=attention_matrix
    )
    
    # Run RAI with attention weighting
    # Nodes are identified by their feature-name strings; `feature_names` comes
    # from the RAI cell in Section 2.2.
    nodes_set = set(feature_names)
    rai_learner = LearnStructRAI(nodes_set, cond_indep_test)
    rai_learner.learn_structure()
    
    # Convert to adjacency
    # Rows are sources and columns are targets, matching the heatmaps above.
    n_features = len(feature_names)
    adj_rai_attention = np.zeros((n_features, n_features), dtype=int)
    name_to_idx = {name: idx for idx, name in enumerate(feature_names)}
    
    # Edges without `source`/`target` attributes, or whose endpoint names are not
    # in feature_names, are skipped silently.
    # NOTE(review): if the learned graph represents undirected edges (e.g. a CPDAG)
    # without these attributes, they are dropped here, which may bias the comparison
    # with run_rai - TODO confirm against causality_lab's graph API and run_rai's conversion.
    for edge in rai_learner.graph.edges:
        if hasattr(edge, 'source') and hasattr(edge, 'target'):
            source_idx = name_to_idx.get(edge.source.name)
            target_idx = name_to_idx.get(edge.target.name)
            if source_idx is not None and target_idx is not None:
                adj_rai_attention[source_idx, target_idx] = 1
    
    # Compute metrics
    metrics_rai_attention = compute_graph_metrics(adj_rai_attention, true_adj)
    
    print("RAI with Attention Weighting Results:")
    print(f"  SHD: {metrics_rai_attention['shd']}")
    print(f"  F1 (directed): {metrics_rai_attention['f1_directed']:.3f}")
    print(f"  Precision: {metrics_rai_attention['precision']:.3f}")
    print(f"  Recall: {metrics_rai_attention['recall']:.3f}")
    
    # Compare with standard RAI
    # Requires `metrics_rai` from Section 2.2; a positive value means a higher directed F1.
    improvement = metrics_rai_attention['f1_directed'] - metrics_rai['f1_directed']
    print(f"\nImprovement over standard RAI: {improvement:+.3f}")
else:
    print("Causality Lab not available. Skipping TabPFN-enhanced discovery.")

## 6. Real-World Dataset Example

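Let's apply these methods to a real-world dataset from OpenML: Pima Indians Diabetes (dataset ID 37). There is no ground-truth graph here, so we compare the methods by how much their discovered edges agree rather than by accuracy.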

In [None]:
import openml
from sklearn.preprocessing import LabelEncoder

# Load the Pima Indians Diabetes dataset (OpenML ID 37). It gets its own name so
# it does not reuse `dataset`, which Section 5 binds to a causality-lab Dataset.
openml_dataset = openml.datasets.get_dataset(37)
X_real, y_real, categorical_indicator, feature_names_real = openml_dataset.get_data(
    dataset_format='array',
    target=openml_dataset.default_target_attribute
)

# Encode labels
le = LabelEncoder()
y_real = le.fit_transform(y_real)

# Limit to TabPFN constraints. Note: this keeps the FIRST rows, not a random sample.
MAX_SAMPLES_REAL = 500
X_real = X_real[:MAX_SAMPLES_REAL]
y_real = y_real[:MAX_SAMPLES_REAL]

print(f"Dataset shape: {X_real.shape}")
print(f"Features: {list(feature_names_real)}")

In [None]:
# Run AttnSCM on real data
model_real = AttnSCM(top_k_heads=5, threshold_method='otsu', device='cpu')
adj_real_attnscm = model_real.fit(X_real, y_real)

print(f"Discovered edges: {adj_real_attnscm.sum()}")
print(f"Graph sparsity: {1 - adj_real_attnscm.sum() / (X_real.shape[1]**2):.3f}")

# Visualize discovered graph
plt.figure(figsize=(10, 8))
sns.heatmap(adj_real_attnscm, cmap='RdBu_r', center=0, square=True,
            xticklabels=feature_names_real,
            yticklabels=feature_names_real,
            cbar_kws={'label': 'Edge'})
plt.title('Discovered Causal Graph: Diabetes Dataset (AttnSCM)')
plt.xlabel('Target')
plt.ylabel('Source')
plt.tight_layout()
plt.show()

In [None]:
if CAUSALITY_LAB_AVAILABLE:
    # Compare with RAI
    adj_real_rai = run_rai(X_real, alpha=0.05, feature_names=list(feature_names_real))
    
    print(f"\nRAI discovered edges: {adj_real_rai.sum()}")
    
    # Compare edge overlap
    attn_edges = adj_real_attnscm == 1
    rai_edges = adj_real_rai == 1
    common_edges = np.sum(attn_edges & rai_edges)
    print(f"Common edges between AttnSCM and RAI: {common_edges}")
    # Entry-wise agreement also counts pairs where BOTH methods find no edge, so on
    # sparse graphs it is high even when few edges are shared. The Jaccard overlap
    # (shared edges / edges found by either method) reports shared edges only.
    print(f"Edge agreement: {np.mean(adj_real_attnscm == adj_real_rai):.3f}")
    union_edges = np.sum(attn_edges | rai_edges)
    edge_jaccard = common_edges / union_edges if union_edges else float('nan')
    print(f"Edge Jaccard overlap: {edge_jaccard:.3f}")
else:
    print("Causality Lab not available. Skipping RAI comparison.")

## Summary

This notebook demonstrated:

1. **Multiple causal discovery methods**: AttnSCM, RAI, FCI, ICD
2. **Performance comparison** on synthetic data with known ground truth
3. **TabPFN-enhanced discovery** using attention-weighted conditional independence tests
4. **Real-world application** to an OpenML dataset (Pima Indians Diabetes)

Key takeaways:
- Different methods have different assumptions and strengths
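- TabPFN's attention patterns can be used to guide traditional causal discovery; whether they actually help should be checked against the unweighted baseline (see the improvement reported in Section 5)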
- Real-world causal discovery requires validation through domain knowledge or predictive utility

For more examples, see:
- `experiments/exp_a_causality_lab.py` - Full synthetic benchmark
- `experiments/exp_e_tabpfn_enhanced_causality.py` - TabPFN enhancement experiments
- `experiments/exp_f_real_datasets.py` - Real-world dataset evaluation