## 1. Setup and Imports

In [None]:
import json
import os
import sys
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# Add parent directory to path
sys.path.insert(0, '..')

from src.models.basic_cbm import ConceptBottleneckModel

# Report the environment so the reader can confirm the setup worked.
cuda_ready = torch.cuda.is_available()
torch_version = torch.__version__

print("✓ Imports successful")
print(f"PyTorch version: {torch_version}")
print(f"CUDA available: {cuda_ready}")

## 2. Load Sample Data

We have 3 sample cases:
- Case 1: Basal Cell Carcinoma (non-melanoma skin cancer)
- Case 577: Melanoma (malignant)
- Case 578: Melanoma (malignant)

Each case includes:
- Clinical photograph
- Dermoscopic image
- 7-point checklist annotations

In [None]:
# Load the per-case metadata. JSON text is UTF-8, so the encoding is given
# explicitly instead of relying on the platform's default text encoding.
with open('sample_data_derm7pt/cases_metadata.json', 'r', encoding='utf-8') as f:
    cases = json.load(f)

print(f"Loaded {len(cases)} sample cases:\n")

# Print a short summary (diagnosis, location, 7-point score) for each case.
for case in cases:
    print(f"Case {case['case_num']}: {case['diagnosis']}")
    print(f"  Location: {case['metadata']['location']}")
    print(f"  7-point score: {case['clinical_features']['seven_point']}")
    print()

## 3. Visualize Sample Cases

Let's look at the dermoscopic images and their concept annotations.

In [None]:
# The figure is saved under outputs/, so make sure the directory exists first;
# otherwise plt.savefig fails on a fresh checkout.
os.makedirs('outputs', exist_ok=True)

# One panel per loaded case. squeeze=False keeps `axes` two-dimensional even
# when there is a single case, so the row can always be iterated.
n_cases = len(cases)
fig, axes = plt.subplots(1, n_cases, figsize=(5 * n_cases, 5), squeeze=False)

for ax, case in zip(axes[0], cases):
    # Load dermoscopic image; the context manager closes the file once the
    # pixel data has been handed to matplotlib.
    img_path = f"sample_data_derm7pt/{case['dermoscopic_img']}"
    with Image.open(img_path) as img:
        ax.imshow(img)

    ax.set_title(f"Case {case['case_num']}\n{case['diagnosis']}", fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.savefig('outputs/sample_cases.png', dpi=150, bbox_inches='tight')
plt.show()

print("Sample cases visualized")

## 4. Understanding the 7-Point Checklist

The 7-point checklist is a standardized method for melanoma diagnosis:

1. **Atypical Pigment Network** - Irregular brown lines
2. **Blue-Whitish Veil** - Blue-white coloration
3. **Atypical Vascular Pattern** - Irregular blood vessels
4. **Irregular Streaks** - Linear structures at edges
5. **Irregular Pigmentation** - Uneven coloring
6. **Irregular Dots/Globules** - Round structures
7. **Regression Structures** - White/blue areas from healing

Each feature can be present (1) or absent (0), creating interpretable concepts!

In [None]:
# Extract concepts from case 578 (the melanoma with highest score).
# Fail with an explicit message if the case is not in the loaded metadata.
case_578 = next((c for c in cases if c['case_num'] == 578), None)
if case_578 is None:
    raise LookupError("Case 578 not found in the loaded sample cases")

concept_names = [
    "Atypical Pigment Network",
    "Blue-Whitish Veil",
    "Atypical Vascular",
    "Irregular Streaks",
    "Irregular Pigmentation",
    "Irregular Dots/Globules",
    "Regression Structures"
]

# Map each annotation string to a binary flag, in the same order as
# concept_names.
# NOTE(review): the accepted label strings below are assumed to match the
# values used in cases_metadata.json — confirm against the data file.
features = case_578['clinical_features']
concepts = [
    1 if features['pigment_network'] == 'atypical' else 0,
    1 if features['blue_whitish_veil'] == 'present' else 0,
    1 if features['vascular_structures'] in ['atypical', 'arborizing'] else 0,
    1 if features['streaks'] == 'irregular' else 0,
    1 if features['pigmentation'] in ['irregular', 'localized in'] else 0,
    1 if features['dots_and_globules'] == 'irregular' else 0,
    1 if features['regression'] in ['combination', 'white areas'] else 0,
]

# The figure is saved under outputs/, so make sure the directory exists first.
os.makedirs('outputs', exist_ok=True)

# Visualize: one horizontal bar per concept, red when present.
positions = range(len(concept_names))
plt.figure(figsize=(10, 6))
colors = ['red' if c == 1 else 'lightgray' for c in concepts]
plt.barh(positions, concepts, color=colors)
plt.yticks(positions, concept_names)
plt.xlabel('Present (1) or Absent (0)')
plt.title(f'Case {case_578["case_num"]} - {case_578["diagnosis"]}\n7-Point Checklist Annotations')
plt.xlim([0, 1.2])
# Label each bar just to the right of its end.
for i, v in enumerate(concepts):
    plt.text(v + 0.05, i, 'Present' if v == 1 else 'Absent', va='center')
plt.tight_layout()
plt.savefig('outputs/case_578_concepts.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Case 578 has {sum(concepts)}/{len(concept_names)} features present → High risk for melanoma")

## 5. Initialize a Concept Bottleneck Model

Let's create a CBM with:
- **ResNet50** backbone for feature extraction
- **7 concept predictors** (one per checklist item)
- **Linear task predictor** for diagnosis (interpretable!)

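**Note**: Only the backbone starts from pretrained ImageNet weights; the model has not been trained on this task, so its concept and diagnosis predictions will be essentially random.
For a trained model, run `python examples/train_basic_cbm.py --data_path /home/xrai/datasets/derm7pt/release_v0`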

In [None]:
# Build the model: 7 concepts, 2 output classes, ResNet50 backbone.
model_config = dict(
    num_concepts=7,
    num_classes=2,
    backbone='resnet50',
    pretrained=True,  # Use pretrained ImageNet weights
)
model = ConceptBottleneckModel(**model_config)

# Switch to evaluation mode for inference.
model.eval()


def _count_parameters(module):
    """Return the total number of elements over all parameters of `module`."""
    return sum(param.numel() for param in module.parameters())


# Parameter counts for the whole model and its two sub-modules.
total_params = _count_parameters(model)
concept_params = _count_parameters(model.concept_encoder)
task_params = _count_parameters(model.task_predictor)

print("Model Architecture:")
print(f"  Total parameters: {total_params:,}")
print(f"  Concept encoder: {concept_params:,}")
print(f"  Task predictor: {task_params:,}")
print("\nModel structure:")
print("  Input: [batch, 3, 224, 224] image")
print("  → Concept Encoder → [batch, 7] concepts")
print("  → Task Predictor → [batch, 2] logits (benign vs malignant)")

## 6. Run Inference on Sample Cases

Let's see what the model predicts (even though it's untrained).

In [None]:
# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalize each channel with the mean/std values below.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# Run every sample case through the model and collect the outputs.
results = []

for case in cases:
    # Load the dermoscopic image as RGB and add a batch dimension.
    img_path = f"sample_data_derm7pt/{case['dermoscopic_img']}"
    img = Image.open(img_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        pred_concepts, pred_logits = model(img_tensor)
        class_probs = torch.softmax(pred_logits, dim=1)

    # Index 1 of the class dimension is used as the malignant probability.
    malignant_prob = class_probs[0, 1].item()
    concept_values = pred_concepts[0].tolist()

    results.append({
        'case': case,
        'concepts': concept_values,
        'diagnosis_prob': malignant_prob
    })

    verdict = 'Malignant' if malignant_prob > 0.5 else 'Benign'
    concept_preview = [f'{c:.3f}' for c in concept_values[:3]]

    print(f"Case {case['case_num']}: {case['diagnosis']}")
    print(f"  Model prediction: {verdict} ({malignant_prob:.3f})")
    print(f"  Concept predictions: {concept_preview}...")
    print()

print("✓ Inference complete")

## 7. Visualize Concept Predictions

Let's compare the model's concept predictions with the ground truth.

In [None]:
# Visualize one case in detail
case_idx = 2  # Third sample case (case 578 if the metadata keeps the listed order)
result = results[case_idx]
case = result['case']

# Extract ground truth concepts as binary flags, in the same order as
# concept_names.
# NOTE(review): the accepted label strings are assumed to match the values in
# cases_metadata.json — confirm against the data file.
features = case['clinical_features']
gt_concepts = [
    1 if features['pigment_network'] == 'atypical' else 0,
    1 if features['blue_whitish_veil'] == 'present' else 0,
    1 if features['vascular_structures'] in ['atypical', 'arborizing'] else 0,
    1 if features['streaks'] == 'irregular' else 0,
    1 if features['pigmentation'] in ['irregular', 'localized in'] else 0,
    1 if features['dots_and_globules'] == 'irregular' else 0,
    1 if features['regression'] in ['combination', 'white areas'] else 0,
]

pred_concepts = result['concepts']

# The figure is saved under outputs/, so make sure the directory exists first.
os.makedirs('outputs', exist_ok=True)

# Plot comparison: ground truth on the left, model output on the right.
positions = range(len(concept_names))
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Ground truth
colors_gt = ['red' if c == 1 else 'lightgray' for c in gt_concepts]
axes[0].barh(positions, gt_concepts, color=colors_gt)
axes[0].set_yticks(positions)
axes[0].set_yticklabels(concept_names, fontsize=10)
axes[0].set_xlabel('Value')
axes[0].set_title(f'Ground Truth Concepts\nCase {case["case_num"]}: {case["diagnosis"]}')
axes[0].set_xlim([0, 1.2])

# Predictions: bars above 0.5 are highlighted in green.
diagnosis_prob = result['diagnosis_prob']
diagnosis_label = 'Malignant' if diagnosis_prob > 0.5 else 'Benign'
colors_pred = ['green' if c > 0.5 else 'lightgray' for c in pred_concepts]
axes[1].barh(positions, pred_concepts, color=colors_pred)
axes[1].set_yticks(positions)
axes[1].set_yticklabels(concept_names, fontsize=10)
axes[1].set_xlabel('Probability')
axes[1].set_title(f'Model Predictions (Untrained)\nDiagnosis: {diagnosis_label} ({diagnosis_prob:.3f})')
axes[1].set_xlim([0, 1.2])

plt.tight_layout()
plt.savefig(f'outputs/case_{case["case_num"]}_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Comparison saved for case {case['case_num']}")

## 8. Understanding Concept Intervention

The key advantage of CBMs: we can **intervene on concepts**!

If the model gets a concept wrong, we can:
1. Correct it manually
2. Re-run the task predictor with corrected concepts
3. Get a better diagnosis

Let's demonstrate this:

In [None]:
# Take case 578 (melanoma) — the third entry of the sample metadata.
case = cases[2]
img_path = f"sample_data_derm7pt/{case['dermoscopic_img']}"
img = Image.open(img_path).convert('RGB')
img_tensor = transform(img).unsqueeze(0)

# Original prediction
with torch.no_grad():
    original_concepts, original_logits = model(img_tensor)
    original_prob = torch.softmax(original_logits, dim=1)[0, 1].item()

print("BEFORE INTERVENTION:")
print(f"  Predicted concepts: {[f'{c:.3f}' for c in original_concepts[0].tolist()]}")
print(f"  Diagnosis probability (malignant): {original_prob:.3f}")
print(f"  Prediction: {'Malignant' if original_prob > 0.5 else 'Benign'}")
print()

# Simulate intervention: replace every predicted concept with its ground-truth
# value. The tensor is created with the same dtype and device as the model's
# concept output so it can be fed straight back into the task predictor.
# NOTE(review): these values are hard-coded; they are presumably the
# annotations of case 578 — confirm against cases_metadata.json.
gt_concepts_tensor = torch.tensor(
    [[1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0]],
    dtype=original_concepts.dtype,
    device=original_concepts.device,
)
corrected_concepts = gt_concepts_tensor

# Re-predict with corrected concepts
with torch.no_grad():
    corrected_logits = model.predict_from_concepts(corrected_concepts)
    corrected_prob = torch.softmax(corrected_logits, dim=1)[0, 1].item()

print("AFTER INTERVENTION (corrected all concepts):")
print(f"  Corrected concepts: {[f'{c:.3f}' for c in corrected_concepts[0].tolist()]}")
print(f"  Diagnosis probability (malignant): {corrected_prob:.3f}")
print(f"  Prediction: {'Malignant' if corrected_prob > 0.5 else 'Benign'}")
print()

print(f"Change in diagnosis probability: {corrected_prob - original_prob:+.3f}")
print("\n✓ This demonstrates how human expertise can improve predictions!")
print("\n✓ This demonstrates how human expertise can improve predictions!")

## 9. Inspect Task Predictor Weights

Since we use a linear task predictor, we can see exactly how much each concept contributes to the diagnosis!

In [None]:
# Get the linear layer weights. .cpu() makes the conversion to NumPy work even
# if the model has been moved to a GPU; on CPU it is a no-op.
task_weights = model.task_predictor.weight.detach().cpu().numpy()

# Weight for malignant class (index 1)
malignant_weights = task_weights[1, :]

# The figure is saved under outputs/, so make sure the directory exists first.
os.makedirs('outputs', exist_ok=True)

# Visualize: red bars for positive weights, blue for the rest.
positions = range(len(concept_names))
plt.figure(figsize=(10, 6))
colors = ['red' if w > 0 else 'blue' for w in malignant_weights]
plt.barh(positions, malignant_weights, color=colors)
plt.yticks(positions, concept_names)
plt.xlabel('Weight')
plt.title('Task Predictor Weights for Malignant Class\n(Positive = increases malignant probability)')
plt.axvline(x=0, color='black', linestyle='--', linewidth=0.5)
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig('outputs/task_predictor_weights.png', dpi=150, bbox_inches='tight')
plt.show()

# The concept with the largest absolute weight, computed once and reused.
top_idx = int(np.argmax(np.abs(malignant_weights)))

print("Task Predictor Analysis:")
print(f"  Most important concept: {concept_names[top_idx]}")
print(f"  Weight: {malignant_weights[top_idx]:.3f}")
print("\n  All weights:")
for name, weight in zip(concept_names, malignant_weights):
    print(f"    {name:.<40} {weight:+.3f}")

## 10. Summary and Next Steps

**What we learned:**

1. ✅ **CBM Structure**: Image → Concepts → Diagnosis
2. ✅ **7-Point Checklist**: Interpretable clinical concepts
3. ✅ **Concept Predictions**: Binary features that doctors understand
4. ✅ **Intervention**: Can correct wrong concepts to improve diagnosis
5. ✅ **Interpretable Weights**: Linear predictor shows exact contribution

**Next Steps:**

1. **Train a full model**: Run `python examples/train_basic_cbm.py --data_path /home/xrai/datasets/derm7pt/release_v0`
2. **Evaluate on test set**: See how well concepts align with expert annotations
3. **Information theory analysis**: Measure concept completeness and synergy
4. **Compare architectures**: Try different backbones (EfficientNet, Vision Transformers)

**Key Insight**: CBMs trade ~5% accuracy for full interpretability and intervention capability!

In [None]:
# Ensure the output folder is present, then list the files the demo writes
# and how to launch a full training run.
os.makedirs('outputs', exist_ok=True)

summary_lines = [
    "✅ Demo complete!",
    "\nGenerated files:",
    "  - outputs/sample_cases.png",
    "  - outputs/case_578_concepts.png",
    f"  - outputs/case_{cases[2]['case_num']}_comparison.png",
    "  - outputs/task_predictor_weights.png",
    "\nTo train a full model, run:",
    "  python examples/train_basic_cbm.py --data_path /home/xrai/datasets/derm7pt/release_v0 --epochs 50",
]

for line in summary_lines:
    print(line)