# Demo: PneumoniaMNIST CNN + Grad-CAM

This notebook guides you through loading data, training a CNN, evaluating accuracy, and visualizing Grad-CAM heatmaps.
The dataset is [PneumoniaMNIST](https://paperswithcode.com/dataset/medmnist), a subset of the MedMNIST collection listed on Papers with Code. For more datasets and models, see the [MedMNIST website](https://medmnist.com) and the [MedMNIST GitHub repository](https://github.com/MedMNIST/MedMNIST).
This notebook runs end-to-end on your machine, using the `src/` modules:  
1. Load & inspect data  
2. Define & train a CNN  
3. Plot training metrics  
4. Visualize Grad-CAM  


## Imports and Setup
```python
# Ensure src/ is on path
import sys, os
sys.path.append(os.path.abspath('../src'))

# Core libraries
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Workshop modules
from data import get_dataloaders
from models import SimpleCNN
from gradcam import apply_gradcam
```

## Check device
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Running on", device)
```

## 1. Load & Inspect Data  
Using `get_dataloaders()` from `src/data.py`.  
```python
batch_size = 64
train_loader, test_loader = get_dataloaders(batch_size=batch_size, data_dir='./data')
```
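`get_dataloaders()` lives in `src/data.py`, and its internals aren't shown here. If you want to smoke-test the rest of the notebook without downloading anything, a stand-in like the one below can help. It is a sketch that assumes the real loaders yield images shaped `(B, 1, 28, 28)` and labels shaped `(B, 1)`, which matches how MedMNIST returns PneumoniaMNIST labels. `make_fake_loaders` is a hypothetical name, and the images are random noise, so any accuracy it produces is meaningless.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_fake_loaders(batch_size=64, n_train=256, n_test=64, seed=0):
    """Hypothetical offline stand-in for get_dataloaders(): random images, binary labels."""
    g = torch.Generator().manual_seed(seed)

    def make_split(n):
        imgs = torch.rand(n, 1, 28, 28, generator=g)       # grayscale 28x28 in [0, 1)
        labels = torch.randint(0, 2, (n, 1), generator=g)  # shape (n, 1), like MedMNIST labels
        return TensorDataset(imgs, labels)

    train_loader = DataLoader(make_split(n_train), batch_size=batch_size,
                              shuffle=True, generator=g)
    test_loader = DataLoader(make_split(n_test), batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

train_fake, test_fake = make_fake_loaders(batch_size=64)
imgs, labels = next(iter(train_fake))
print(imgs.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64, 1])
```

Assigning `train_loader, test_loader = make_fake_loaders()` instead of calling `get_dataloaders()` lets the training and Grad-CAM cells run offline.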

### Peek one batch
```python
imgs, labels = next(iter(train_loader))
imgs = imgs[:6]
labels = labels[:6].reshape(-1)   # flatten (6, 1) labels to (6,)
fig, axes = plt.subplots(2,3, figsize=(8,5))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(imgs[i].squeeze(), cmap='gray')
    ax.set_title(f"Label: {labels[i].item()}")
    ax.axis('off')
plt.suptitle("Sample Batch from PneumoniaMNIST")
plt.show()
```
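A single batch says little about class balance, which matters when reading accuracy: on an imbalanced set, a model that always predicts the majority class can still score well. Counting labels across a whole loader is cheap. This sketch assumes labels arrive shaped `(B, 1)` or `(B,)` and are integers `0..num_classes-1`; `count_labels` is a hypothetical helper, demonstrated here on synthetic data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def count_labels(loader, num_classes=2):
    """Hypothetical helper: tally how many samples of each class a loader yields."""
    counts = torch.zeros(num_classes, dtype=torch.long)
    for _, labels in loader:
        # reshape(-1) flattens (B, 1) or (B,) labels without turning a batch of one into a scalar
        counts += torch.bincount(labels.reshape(-1).long(), minlength=num_classes)
    return counts

# Tiny synthetic example: 3 samples of class 0, 5 of class 1
demo = TensorDataset(torch.zeros(8, 1, 28, 28), torch.tensor([[0]] * 3 + [[1]] * 5))
counts = count_labels(DataLoader(demo, batch_size=3))
print(counts.tolist())  # [3, 5]
baseline = counts.max().item() / counts.sum().item()
print("majority-class baseline accuracy:", baseline)  # 0.625
```

Running `count_labels(train_loader)` and `count_labels(test_loader)` on the real data gives the baseline that the test accuracy in step 3 should beat.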

## 2. Define Model
```python
model = SimpleCNN().to(device)
print(model)
```
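`SimpleCNN` is defined in `src/models.py`, which isn't reproduced here. For orientation, below is a hypothetical architecture consistent with how the notebook uses the model: it takes `(B, 1, 28, 28)` input, returns two logits per image (matching `CrossEntropyLoss` and `argmax(dim=1)`), and exposes a `features` `nn.Sequential` whose third-from-last module is the last convolution, so `model.features[-3]` in step 4 would select a `Conv2d`. The layer sizes are assumptions, not the workshop's actual model.

```python
import torch
import torch.nn as nn

class SimpleCNNSketch(nn.Module):
    """Hypothetical stand-in for src/models.SimpleCNN (layer sizes are assumptions)."""
    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # 28x28 -> 28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                               # -> 14x14
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # last conv, i.e. features[-3]
            nn.ReLU(),
            nn.MaxPool2d(2),                               # -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, num_classes),           # raw logits for CrossEntropyLoss
        )

    def forward(self, x):
        return self.classifier(self.features(x))

sketch = SimpleCNNSketch()
out = sketch(torch.zeros(4, 1, 28, 28))
print(out.shape)                  # torch.Size([4, 2])
print(type(sketch.features[-3]))  # <class 'torch.nn.modules.conv.Conv2d'>
```

Note that the `Linear` input size is tied to the 28×28 input, so a model like this would need its classifier resized if you change the image size.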

## 3. Train for 3 epochs  
Track training loss and test accuracy.
```python
# Loss, optimizer, history
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_losses, test_accs = [], []

for epoch in range(1, 4):
    model.train()
    total_loss = 0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
        # reshape(-1): (B, 1) -> (B,); unlike squeeze(), stays 1-D if the last batch has one image
        imgs, labels = imgs.to(device), labels.reshape(-1).long().to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.reshape(-1).long().to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = correct/total
    test_accs.append(acc)
    print(f"[Epoch {epoch}] Loss: {avg_loss:.4f}, Test Acc: {acc:.4f}")
```
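The evaluation half of the loop can be factored into a reusable function, which makes it easy to score the model again after training or on another split. This is a sketch, not code from `src/`; `evaluate` is a hypothetical name. It flattens labels with `reshape(-1)` so a final batch containing a single image still yields a 1-D label tensor.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate(model, loader, device):
    """Hypothetical helper: return classification accuracy of `model` on `loader`."""
    model.eval()
    correct = total = 0
    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.reshape(-1).long().to(device)  # (B, 1) -> (B,), safe when B == 1
        preds = model(imgs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total

# Smoke test: zero weights and bias [0, 1] make this model always predict class 1
always_one = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 2))
nn.init.zeros_(always_one[1].weight)
with torch.no_grad():
    always_one[1].bias.copy_(torch.tensor([0.0, 1.0]))
data = TensorDataset(torch.rand(5, 1, 28, 28), torch.tensor([[1], [1], [0], [1], [1]]))
acc = evaluate(always_one, DataLoader(data, batch_size=2), torch.device("cpu"))
print(acc)  # 0.8 (4 of 5 labels are 1; the last batch holds a single image)
```

With a helper like this, the evaluation block in the training loop could shrink to `acc = evaluate(model, test_loader, device)`.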

### Training Loss & Test Accuracy  
```python
fig, ax1 = plt.subplots()
ax1.plot(range(1,4), train_losses, marker='o', label='Train Loss')
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.legend(loc='upper left')

ax2 = ax1.twinx()
ax2.plot(range(1,4), test_accs, marker='s', linestyle='--', label='Test Acc')
ax2.set_ylabel('Accuracy'); ax2.legend(loc='upper right')

plt.title("Training Metrics")
plt.show()
```

## 4. Define & Apply Grad-CAM  
Using `apply_gradcam()` from `src/gradcam.py`.
```python
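# Pick the layer to hook: the last conv layer in model.features
# (index -3 assumes two non-conv modules, e.g. activation and pooling, follow it)
target_layer = model.features[-3]
assert isinstance(target_layer, nn.Conv2d), f"expected a Conv2d, got {target_layer}"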
```
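`apply_gradcam()` hides the mechanics. Grad-CAM takes the gradient of the chosen class score with respect to the target layer's activation maps, averages that gradient over the spatial dimensions to get one weight per channel, sums the weighted activation maps, applies ReLU, and upsamples the result to the input size. Below is a minimal sketch of that recipe, assuming a single-image batch, a model that returns raw logits, and a target layer that is not followed by an in-place operation. `gradcam_sketch` is hypothetical and may differ from the workshop's `apply_gradcam` (for example in normalization or interpolation mode).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gradcam_sketch(model, x, target_layer, class_idx=None):
    """Hypothetical Grad-CAM: returns (cam as HxW numpy array in [0, 1], predicted class)."""
    model.eval()
    store = {}
    handle = target_layer.register_forward_hook(lambda m, inp, out: store.update(acts=out))
    try:
        logits = model(x)        # x: (1, C, H, W); must NOT run under torch.no_grad()
    finally:
        handle.remove()          # always detach the hook
    pred = int(logits.argmax(dim=1).item())
    target = pred if class_idx is None else class_idx
    acts = store["acts"]         # (1, K, h, w) activations of target_layer
    grads, = torch.autograd.grad(logits[0, target], acts)
    weights = grads.mean(dim=(2, 3), keepdim=True)           # (1, K, 1, 1) channel weights
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))  # (1, 1, h, w)
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam[0, 0]
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)                           # scale to [0, 1]
    return cam.detach().cpu().numpy(), pred

# Demo on a tiny hypothetical CNN (random weights, so the heatmap itself is meaningless)
torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(), nn.Linear(16 * 7 * 7, 2),
)
cam, pred = gradcam_sketch(net, torch.rand(1, 1, 28, 28), target_layer=net[3])
print(cam.shape, pred)  # (28, 28) and a class index in {0, 1}
```

The sketch returns a `(cam, pred)` pair in the form the display cell below unpacks, so `gradcam_sketch(model, tensor, target_layer)` can be swapped in to compare against `apply_gradcam`.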

## 5. Display Grad-CAM overlays for the first 3 test images  
```python
fig, axes = plt.subplots(2,3, figsize=(10,6))
for i in range(3):
    img, lbl = test_loader.dataset[i]
    tensor = img.unsqueeze(0).to(device)
    cam, pred = apply_gradcam(model, tensor, target_layer)

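    gray = img.squeeze().cpu().numpy()
    # rescale to [0, 1] so the blend stays in imshow's valid range even if images are normalized
    gray = (gray - gray.min()) / (gray.max() - gray.min() + 1e-8)
    heatmap = plt.get_cmap('jet')(cam)[..., :3]   # RGBA -> RGB; expects cam as HxW in [0, 1]
    overlay = heatmap * 0.4 + gray[..., None] * 0.6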

    # original
    axes[0,i].imshow(gray, cmap='gray')
    axes[0,i].set_title(f"True: {lbl.item()}")
    axes[0,i].axis('off')
    # overlay
    axes[1,i].imshow(overlay)
    axes[1,i].set_title(f"Pred: {pred}")
    axes[1,i].axis('off')

plt.suptitle("Grad-CAM Results")
plt.tight_layout()
plt.show()
```

## 6. Try It Yourself!  
- Change the learning rate or number of epochs.  
- Hook an earlier conv block (e.g. `model.features[0]`).  
- Increase image size to 128×128 in `get_dataloaders()`.  
- Swap in another MedMNIST dataset (e.g. DermaMNIST, whose images are RGB with 7 classes, so the model's input channels and output size must change).  
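For hooking a different layer, rather than guessing indices such as `features[0]` or `features[-3]`, you can list every convolution in the feature extractor and pick one by position. A minimal sketch, assuming the model keeps its convolutions in an `nn.Sequential` named `features`, as the notebook's indexing implies; `conv_layers` is a hypothetical helper. Earlier layers typically give higher-resolution but less class-specific heatmaps.

```python
import torch.nn as nn

def conv_layers(features):
    """Hypothetical helper: return (index, module) for every Conv2d in an nn.Sequential."""
    return [(i, m) for i, m in enumerate(features) if isinstance(m, nn.Conv2d)]

# Example feature extractor (layer sizes are illustrative only)
features = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
convs = conv_layers(features)
print([i for i, _ in convs])  # [0, 3]
first_conv, last_conv = convs[0][1], convs[-1][1]
```

With the notebook's model, `target_layer = conv_layers(model.features)[0][1]` would hook the first convolution.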