# Final Report — Brain Tumor XAI (Improved)
**Author:** Ajaychary Kandukuri

This notebook reproduces evaluation and visualizes explainability outputs (Grad-CAM, Integrated Gradients) for the Brain Tumor XAI project.

**Notes**
- Run cells from top to bottom.
- Paths assume your repo root contains `scripts/xai_extended.py`, `data/processed`, `checkpoints/best_model.pth`, `outputs/xai_demo_from_train` and `outputs/streamlit_xai`.
- If the notebook lives in `reports/`, the first code cell changes the working directory to the project root.


In [1]:
# Environment & imports — robust to cwd / import problems
import sys
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader

# Detect project root and set cwd so relative paths work whether notebook sits in /reports or repo root
NB_PATH = Path.cwd()
if (NB_PATH / "scripts").exists() and (NB_PATH / "data").exists():
    PROJECT_ROOT = NB_PATH
elif (NB_PATH / ".." / "scripts").resolve().exists():
    PROJECT_ROOT = (NB_PATH / "..").resolve()
else:
    # fallback: assume current working dir is project root
    PROJECT_ROOT = NB_PATH

os.chdir(PROJECT_ROOT)
print("Working directory set to project root:", PROJECT_ROOT)

# Paths used throughout the notebook
DATA_ROOT = PROJECT_ROOT / "data" / "processed"
CHECKPOINT = PROJECT_ROOT / "checkpoints" / "best_model.pth"
OUT_XAI_TRAIN = PROJECT_ROOT / "outputs" / "xai_demo_from_train"
OUT_XAI_STREAM = PROJECT_ROOT / "outputs" / "streamlit_xai"

print("Paths:")
print(" DATA_ROOT:", DATA_ROOT)
print(" CHECKPOINT:", CHECKPOINT)
print(" OUT_XAI_TRAIN:", OUT_XAI_TRAIN)
print(" OUT_XAI_STREAM:", OUT_XAI_STREAM)

# Ensure matplotlib inline display
%matplotlib inline


Working directory set to project root: C:\Users\ajayc\Documents\xai_brain_tumor
Paths:
 DATA_ROOT: C:\Users\ajayc\Documents\xai_brain_tumor\data\processed
 CHECKPOINT: C:\Users\ajayc\Documents\xai_brain_tumor\checkpoints\best_model.pth
 OUT_XAI_TRAIN: C:\Users\ajayc\Documents\xai_brain_tumor\outputs\xai_demo_from_train
 OUT_XAI_STREAM: C:\Users\ajayc\Documents\xai_brain_tumor\outputs\streamlit_xai


In [2]:
# Dynamic import of scripts/xai_extended.py so notebook won't fail due to sys.path issues
import importlib.util
XAI_PATH = PROJECT_ROOT / "scripts" / "xai_extended.py"
if not XAI_PATH.exists():
    raise FileNotFoundError(f"Expected xai_extended.py at {XAI_PATH}; please ensure file exists.")

spec = importlib.util.spec_from_file_location("xai_extended", str(XAI_PATH))
if spec is None or spec.loader is None:
    raise ImportError(f"Could not build an import spec for {XAI_PATH}")
xai_mod = importlib.util.module_from_spec(spec)
# Register the module before executing it, as in the importlib
# "Importing a source file directly" recipe: code in the module that looks
# itself up via sys.modules (e.g. pickling, dataclasses) then works, and
# re-running this cell replaces any stale copy.
sys.modules[spec.name] = xai_mod
try:
    spec.loader.exec_module(xai_mod)
except Exception:
    # Do not leave a half-initialised module registered.
    sys.modules.pop(spec.name, None)
    raise

# Functions we'll use (None if the module does not define them).
load_model = getattr(xai_mod, "load_model", None)
explain_image_with_models = getattr(xai_mod, "explain_image_with_models", None)
print("Loaded explain_image_with_models:", explain_image_with_models is not None)


  from .autonotebook import tqdm as notebook_tqdm


Loaded explain_image_with_models: True


In [None]:
# Dataset counts and sample images
from collections import Counter
import matplotlib.pyplot as plt

def class_counts(
    folder: Path,
    exts=(".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"),
):
    """Count image files per class sub-directory of ``folder``.

    Only direct children of each class directory are counted, and only regular
    files whose suffix (case-insensitive) is in ``exts``. The default list matches
    the suffixes torchvision's ImageFolder accepts, so these counts agree with
    what the DataLoader actually sees. Stray files (e.g. Thumbs.db, notes.txt)
    and dotted sub-directories are ignored.

    Args:
        folder: split directory (e.g. data/processed/train) with one sub-dir per class.
        exts: iterable of accepted file suffixes, including the leading dot.

    Returns:
        dict mapping class-directory name -> image count, in sorted directory order;
        ``{}`` if ``folder`` does not exist.
    """
    if not folder.exists():
        return {}
    allowed = {e.lower() for e in exts}
    return {
        class_dir.name: sum(
            1 for f in class_dir.iterdir() if f.is_file() and f.suffix.lower() in allowed
        )
        for class_dir in sorted(p for p in folder.iterdir() if p.is_dir())
    }

print("Train counts:", class_counts(DATA_ROOT / "train"))
print("Val counts:  ", class_counts(DATA_ROOT / "val"))
print("Test counts: ", class_counts(DATA_ROOT / "test"))

# Show up to SAMPLES_PER_CLASS samples from each class in test.
# Files are sorted because glob order is filesystem-dependent (otherwise the grid
# changes between machines), and only image suffixes are considered so stray files
# such as Thumbs.db do not make Image.open raise.
SAMPLE_CLASSES = ["yes", "no"]
SAMPLES_PER_CLASS = 4
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"}

fig, axs = plt.subplots(len(SAMPLE_CLASSES), SAMPLES_PER_CLASS, figsize=(12, 6), squeeze=False)
for i, cls in enumerate(SAMPLE_CLASSES):
    folder = DATA_ROOT / "test" / cls
    sample_paths = (
        sorted(
            p for p in folder.glob("*.*")
            if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
        )[:SAMPLES_PER_CLASS]
        if folder.exists()
        else []
    )
    for j in range(SAMPLES_PER_CLASS):
        ax = axs[i, j]
        ax.axis("off")
        if j < len(sample_paths):
            # Context manager closes the underlying file handle once converted.
            with Image.open(sample_paths[j]) as im:
                ax.imshow(im.convert("RGB"))
            ax.set_title(f"{cls}/{sample_paths[j].name}", fontsize=8)
        else:
            ax.set_title("---")
fig.suptitle("Sample test images (yes / no)")
fig.tight_layout()
plt.show()


In [None]:
# Build ResNet-18 skeleton and load checkpoint (tolerant of common checkpoint layouts)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2

# NOTE: torchvision >= 0.13 deprecates `pretrained` in favour of `weights=None`;
# kept here for compatibility with older torchvision versions.
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)
print("Model created. Device:", device)

if CHECKPOINT.exists():
    print("Loading checkpoint:", CHECKPOINT)
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    state = torch.load(str(CHECKPOINT), map_location=device)

    # Accept either a bare state_dict or a wrapper dict that stores it under a common key.
    sd = state
    if isinstance(state, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in state:
                sd = state[wrapper_key]
                break

    if isinstance(sd, dict):
        # Strip the DataParallel/DDP "module." prefix only at the START of each key;
        # str.replace would also mangle any key containing "module." elsewhere.
        prefix = "module."
        sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

    try:
        load_result = model.load_state_dict(sd, strict=False)
    except Exception as e:
        # Shape mismatches (e.g. a different num_classes) raise even with strict=False.
        # Re-raise: evaluating an untrained skeleton would produce meaningless metrics.
        print("Warning: loading checkpoint raised:", e)
        raise
    print("Checkpoint loaded (len state_dict) =", len(sd) if isinstance(sd, dict) else "unknown")
    # strict=False silently skips mismatched keys; surface them so a checkpoint
    # whose keys do not line up with this architecture is not mistaken for a good load.
    if load_result.missing_keys or load_result.unexpected_keys:
        print("Warning: missing keys:", load_result.missing_keys)
        print("Warning: unexpected keys:", load_result.unexpected_keys)
else:
    print("No checkpoint found at", CHECKPOINT, "-- evaluation will use untrained skeleton.")


In [None]:
# Prepare test loader and evaluate
# ImageNet normalisation statistics, matching the ResNet-18 skeleton above.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_dir = DATA_ROOT / "test"
if not test_dir.exists():
    raise FileNotFoundError(f"Expected test folder at {test_dir}")

test_ds = datasets.ImageFolder(str(test_dir), transform=transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=0)

model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    # Batch-local names avoid clobbering `imgs` / `labels`, which other cells
    # use for different things (sample paths, class names).
    for batch_images, batch_labels in test_loader:
        logits = model(batch_images.to(device))
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
        y_true.extend(batch_labels.tolist())

if not y_true:
    print("No test samples found.")
else:
    n_correct = sum(int(t == p) for t, p in zip(y_true, y_pred))
    acc = n_correct / len(y_true)
    print(f"Test accuracy: {acc:.4f} ({n_correct}/{len(y_true)})")


In [None]:
# Confusion matrix and classification report
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import pandas as pd
# Class names in label-index order as reported by the test dataset.
labels = test_ds.classes if hasattr(test_ds, "classes") else [str(i) for i in range(num_classes)]
label_ids = list(range(len(labels)))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt="d", xticklabels=labels, yticklabels=labels, cmap="Blues", ax=ax)
ax.set(xlabel="Predicted", ylabel="True", title="Confusion Matrix")
plt.show()

# Pass labels explicitly (as for the confusion matrix) so the report stays aligned
# with target_names even if some class never appears in y_true or y_pred.
print(classification_report(y_true, y_pred, labels=label_ids, target_names=labels, zero_division=0))


In [None]:
# Display saved XAI overlays and optionally re-run explainers on a sample
from IPython.display import display

def show_folder(
    folder: Path,
    n=8,
    exts=(".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp"),
):
    """Print a listing of image files in ``folder`` and display up to ``n`` of them.

    Only regular files whose suffix (case-insensitive) is in ``exts`` are listed,
    so auxiliary outputs saved next to the overlays (e.g. .json / .npy) do not make
    ``Image.open`` raise. Files are shown in sorted-name order.

    Args:
        folder: directory containing saved overlay images.
        n: maximum number of images to display.
        exts: iterable of accepted file suffixes, including the leading dot.

    Returns:
        None. Prints a message and returns early if ``folder`` does not exist.
    """
    if not folder.exists():
        print("Folder not found:", folder)
        return
    allowed = {e.lower() for e in exts}
    files = sorted(f for f in folder.glob("*.*") if f.is_file() and f.suffix.lower() in allowed)
    print(f"{len(files)} files in {folder}; showing up to {n}")
    for f in files[:n]:
        print("-", f.name)
        # Context manager closes the file; load() first so display has the pixels.
        with Image.open(f) as im:
            im.load()
            display(im)

print("Overlays from training run:")
show_folder(OUT_XAI_TRAIN, n=12)

print("\nOverlays from Streamlit run:")
show_folder(OUT_XAI_STREAM, n=12)

# Optionally re-run explainers on a chosen sample: the first "yes" test image,
# otherwise the first "no" one. Sorted so the choice is deterministic across filesystems.
_yes_dir = DATA_ROOT / "test" / "yes"
_no_dir = DATA_ROOT / "test" / "no"
yes_list = sorted(_yes_dir.glob("*.*")) if _yes_dir.exists() else []
no_list = sorted(_no_dir.glob("*.*")) if _no_dir.exists() else []
sample = str(yes_list[0]) if yes_list else (str(no_list[0]) if no_list else None)

if explain_image_with_models is None:
    # getattr() in the import cell returns None when the function is missing;
    # calling it would only yield a confusing "'NoneType' object is not callable".
    print("explain_image_with_models not found in scripts/xai_extended.py; skipping re-run.")
elif sample:
    print("Re-running explainers on:", sample)
    try:
        res = explain_image_with_models(sample, model=model, use_gradcam=True, use_ig=True)
        # Display whichever outputs the explainer produced, in a fixed order.
        for key in ("original_pil", "gradcam_overlay", "ig_overlay"):
            if res.get(key) is not None:
                display(res[key])
    except Exception as e:
        # Best-effort demo: report the failure but keep the notebook running.
        print("Explainers failed:", e)
else:
    print("No sample available to re-run explainers.")


In [None]:
# If you saved a CSV training log (epoch, train_loss, val_loss, val_acc), show curves.
# Only the columns actually present are plotted, so partial logs do not raise KeyError.
log_path = PROJECT_ROOT / "outputs" / "training_log.csv"
if log_path.exists():
    import pandas as pd  # local import keeps this cell runnable on its own
    training_log = pd.read_csv(log_path)
    if "epoch" not in training_log.columns:
        print("Training log has no 'epoch' column; columns found:", list(training_log.columns))
    else:
        loss_cols = [c for c in ("train_loss", "val_loss") if c in training_log.columns]
        if loss_cols:
            ax = training_log.plot(x="epoch", y=loss_cols, marker="o", figsize=(8, 4))
            ax.set_title("Loss curves")
            plt.show()
        if "val_acc" in training_log.columns:
            ax = training_log.plot(x="epoch", y="val_acc", marker="o", figsize=(6, 3))
            ax.set_title("Validation accuracy")
            plt.show()
        if not loss_cols and "val_acc" not in training_log.columns:
            print("Training log has none of train_loss / val_loss / val_acc; columns found:",
                  list(training_log.columns))
else:
    print("No training log found at", log_path)


In [None]:
# Quick troubleshooting outputs: interpreter/torch environment plus presence of
# the files and folders this notebook depends on.
import sys
_diagnostics = (
    ("Python executable:", sys.executable),
    ("Torch version:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
    ("Working dir:", Path.cwd()),
    ("scripts/xai_extended.py exists:", (PROJECT_ROOT / "scripts" / "xai_extended.py").exists()),
    ("checkpoint exists:", CHECKPOINT.exists()),
    ("data/processed/test exists:", (DATA_ROOT / "test").exists()),
)
for _diag_label, _diag_value in _diagnostics:
    print(_diag_label, _diag_value)


## Interpretation, limitations & next steps

**Interpretation**
- The test accuracy reported above reflects the model's current performance on the provided test split only.
- Grad-CAM and Integrated Gradients overlays help highlight image regions contributing to the prediction.

**Limitations**
- Dataset size and class balance may bias the model.
- Model trained on 2D slices; volumetric (3D) MRIs require different handling.
- Evaluation uses a single split — cross-validation would provide more robust performance estimates.

**Next steps**
1. Log training metrics (CSV/TensorBoard) and include training curves.
2. Add k-fold cross-validation.
3. Experiment with domain pretraining or medical-image-specific backbones.
4. Add uncertainty estimation and calibration.
5. Prepare README + environment files and push to GitHub.
