# CausalXray Results Visualization

This notebook visualizes the CausalXray experimental results: model performance metrics, cross-domain generalization, attribution agreement with radiologists, and confounder impact. All inputs are JSON files read from `./experiment_results`.

In [None]:
# Import libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path

# Global plotting defaults shared by every figure in this notebook:
# seaborn's white background with grid lines, and a 10x6-inch default canvas.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (10, 6)})
print("Libraries imported")

## Load Experiment Results

In [None]:
# Load per-model summary results from JSON files.
# `Path.read_text` opens *and closes* each file; the previous
# `json.load(open(...))` pattern left the file handles open until garbage collection.
results_dir = Path("./experiment_results")
results = {
    model_name: json.loads(
        (results_dir / f"{model_name}_results.json").read_text(encoding="utf-8")
    )
    for model_name in ("baseline", "causalxray")
}

print("Results loaded for:", list(results.keys()))

## Performance Comparison

In [None]:
# Build a long-format frame: one row per (model, metric) pair.
# Each results entry must contain every key in `metrics` (KeyError otherwise).
metrics = ["accuracy", "auc", "sensitivity", "specificity"]
data = [
    {"Model": model_name, "Metric": metric, "Value": res[metric]}
    for model_name, res in results.items()
    for metric in metrics
]
df = pd.DataFrame(data)

# Plot performance comparison.
fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(x="Metric", y="Value", hue="Model", data=df, palette="viridis", ax=ax)
ax.set_title("Model Performance Comparison")
ax.set_ylabel("Score")
# The axis is zoomed to start at 0.7 so small differences are visible. A fixed
# 0.7 floor would make any bar with a lower score vanish entirely, so the floor
# is lowered (never below 0) when the data requires it.
y_min = min(0.7, df["Value"].min() - 0.05)
ax.set_ylim(max(0.0, y_min), 1.0)
plt.show()

## Cross-Domain Generalization

In [None]:
# Load cross-domain results. The plot below reads the columns "domain_distance",
# "accuracy" and "model" from these records.
# `read_text` closes the file, unlike the former `json.load(open(...))`.
domain_results = json.loads(
    (results_dir / "cross_domain_results.json").read_text(encoding="utf-8")
)
domain_df = pd.DataFrame(domain_results)

# Plot accuracy against domain distance, one line (and marker style) per model.
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    data=domain_df,
    x="domain_distance",
    y="accuracy",
    hue="model",
    style="model",
    markers=True,
    ax=ax,
)
ax.set_title("Cross-Domain Performance Degradation")
ax.set_xlabel("Domain Distance (MMD)")
ax.set_ylabel("Accuracy")
ax.grid(True)
plt.show()

## Attribution Consistency Analysis

In [None]:
# Load attribution consistency metrics. The plot below reads the columns "method"
# and "radiologist_agreement" from these records.
# `read_text` closes the file, unlike the former `json.load(open(...))`.
attribution_data = json.loads(
    (results_dir / "attribution_metrics.json").read_text(encoding="utf-8")
)
attrib_df = pd.DataFrame(attribution_data)

# Plot the distribution of radiologist agreement for each attribution method.
# `hue` mirrors `x` so the palette is applied explicitly: passing `palette`
# without `hue` is deprecated in seaborn >= 0.13. The legend this creates is
# redundant with the x-axis labels, so it is removed.
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(
    x="method",
    y="radiologist_agreement",
    hue="method",
    data=attrib_df,
    palette="coolwarm",
    dodge=False,
    ax=ax,
)
legend = ax.get_legend()
if legend is not None:
    legend.remove()
ax.set_title("Attribution Method Agreement with Radiologists")
ax.set_xlabel("Attribution Method")
ax.set_ylabel("Agreement Score")
# Keep the original [0.5, 1.0] zoom. Extend the lower bound when any score is
# below 0.5 so whiskers and outliers are not silently clipped out of view.
y_min = min(0.5, attrib_df["radiologist_agreement"].min() - 0.05)
ax.set_ylim(y_min, 1.0)
plt.show()

## Confounder Analysis

In [None]:
# Load confounder analysis results. The heatmap indexes rows by the
# "confounder" column; every remaining column becomes a heatmap column and
# presumably must be numeric for `annot`/`fmt=".2f"`. Verify against the JSON
# producer.
# `read_text` closes the file, unlike the former `json.load(open(...))`.
confounder_data = json.loads(
    (results_dir / "confounder_analysis.json").read_text(encoding="utf-8")
)
conf_df = pd.DataFrame(confounder_data)

# Plot confounder impact: rows = confounders, columns = performance metrics.
# NOTE(review): "coolwarm" is a diverging map. If these values are signed deltas,
# consider `center=0` so the neutral colour means "no impact". Confirm.
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(
    conf_df.set_index("confounder"),
    annot=True,
    cmap="coolwarm",
    fmt=".2f",
    linewidths=0.5,
    ax=ax,
)
ax.set_title("Confounder Impact on Model Performance")
ax.set_xlabel("Performance Metric")
ax.set_ylabel("Confounder")
plt.show()