# CausalXray Demo Notebook

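This notebook demonstrates the end-to-end CausalXray pipeline: downloading sample data, loading a pretrained model, running inference, and visualizing causal attributions.

**Prerequisites:**
- The `causalxray` package must be importable from the notebook kernel.
- The Kaggle API must be configured for the download step.
- A pretrained checkpoint must be present at `./models/causalxray_pretrained.pth`.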

In [1]:
# Import necessary libraries (stdlib -> third-party -> project-local)
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Project-local package: it must be installed (or on sys.path) in this kernel,
# otherwise this cell fails with ModuleNotFoundError.
from causalxray import CausalXray
from causalxray.data import CausalTransforms
from causalxray.utils import AttributionVisualizer

print('Libraries imported successfully')

ModuleNotFoundError: No module named 'causalxray'

## Download Sample Data

In [None]:
# Download small sample subset of NIH CXR14 dataset
!kaggle datasets download -d nih-chest-xrays/sample -p ./data/sample_nih --unzip

data_dir = Path('./data/sample_nih')
print(f'Data directory set to: {data_dir}')

## Load Pretrained Model

In [None]:
# Load the pretrained CausalXray checkpoint, using the GPU when one is available.
model_path = Path('./models/causalxray_pretrained.pth')
if not model_path.is_file():
    # Fail with an actionable message instead of an opaque error from the loader.
    raise FileNotFoundError(
        f'Pretrained checkpoint not found at {model_path.resolve()}; '
        'place the model file there before running this cell.'
    )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load_checkpoint returns a pair; only the model is used here, the second item is discarded.
model, _ = CausalXray.load_checkpoint(str(model_path), device)
model.to(device)
model.eval()  # switch to evaluation mode before running inference
print('Model loaded and set to evaluation mode')

## Load and Preprocess a Sample Image

In [None]:
# Pick a sample image deterministically (sorted, so every re-run uses the same file)
# and preprocess it with the test-time transforms. rglob also finds images that the
# archive extracted into nested subfolders.
image_paths = sorted(data_dir.rglob('*.png'))
if not image_paths:
    raise FileNotFoundError(f'No PNG images found under {data_dir}; run the download cell first.')
sample_image_path = image_paths[0]

# Copy inside the context manager so the file handle is closed promptly while the
# pixel data stays usable afterwards.
with Image.open(sample_image_path) as img:
    image = img.copy()
# NOTE(review): the image mode is passed through as stored on disk (e.g. 'L' vs 'RGBA');
# confirm it matches what CausalTransforms(mode='test') expects.

transforms = CausalTransforms(mode='test')
image_tensor = transforms(image).unsqueeze(0).to(device)  # prepend a batch dimension of size 1
print(f'Sample image loaded: {sample_image_path.name}')

## Run Inference and Generate Causal Attributions

In [None]:
# Class index -> human-readable label; order matches the probability columns.
CLASS_NAMES = ['Normal', 'Pneumonia']

# Run inference with causal attribution.
# NOTE(review): if model.predict derives attributions from gradients, wrapping it in
# torch.no_grad() could suppress them; confirm predict() handles this internally.
with torch.no_grad():
    outputs = model.predict(image_tensor, return_probabilities=True, return_attributions=True)

predicted_class = outputs['predicted_class'].item()  # single-image batch -> Python int
# presumably shape (1, len(CLASS_NAMES)): row 0 is the only image in the batch
probabilities = outputs['probabilities'].detach().cpu().numpy()
attributions = outputs.get('attributions', {})  # empty dict when the model returned none

print(f'Predicted class: {predicted_class} ({CLASS_NAMES[predicted_class]})')
print('Probabilities: ' + ', '.join(
    f'{name}={probabilities[0, idx]:.4f}' for idx, name in enumerate(CLASS_NAMES)
))

## Visualize Attribution Maps

In [None]:
# Visualize causal attribution heatmaps alongside the original image.
if not attributions:
    print('Model returned no attribution maps; skipping visualization.')
else:
    # Drop the batch dimension added during preprocessing. detach() first, because
    # Tensor.numpy() refuses tensors that still require grad.
    attribution_maps = {
        name: attr.detach().cpu().numpy()[0] for name, attr in attributions.items()
    }
    # NOTE(review): np.array(image) is the untransformed, full-resolution image; confirm
    # the visualizer rescales attribution maps to match its size.
    visualizer = AttributionVisualizer()
    fig = visualizer.visualize_attribution_comparison(
        np.array(image),
        attribution_maps,
        prediction={
            'class': predicted_class,
            'probability': probabilities[0][predicted_class],
        },
    )
    plt.show()