# 3D CNN Analysis & Visualization
This notebook allows you to interactively inspect the results of your trained model.
You can visualize predictions, check Grad-CAM heatmaps, and explore the dataset without running terminal scripts.

In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Make the project's `src/` directory importable. It is inserted at the FRONT
# of sys.path so the local modules (dataset, model, gradcam) take precedence
# over any installed packages with the same names (e.g. the PyPI `dataset`
# package). `sys.path.append` would let such packages shadow the local files.
# The path is made absolute so a later os.chdir() cannot break it. The
# membership guard keeps re-runs of this cell from stacking duplicates.
SRC_DIR = os.path.abspath('../src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from dataset import ProcessedLunaDataset
from model import Simple3DCNN
from gradcam import GradCAM

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = '../results/model_epoch_20.pth'  # Adjust if needed
DATA_DIR = '../data/processed'

print(f"Using device: {DEVICE}")

In [None]:
# Load Dataset
# `dataset` is bound to None up front. If the processed data is missing, later
# cells can then detect that explicitly instead of hitting a NameError.
dataset = None
# DATA_DIR is passed as `processed_dir`, so require a directory, not just any path.
if os.path.isdir(DATA_DIR):
    dataset = ProcessedLunaDataset(processed_dir=DATA_DIR, augment=False)
    print(f"Loaded dataset with {len(dataset)} samples")
else:
    print("Data not found! Please run preprocessing first.")

In [None]:
# Load Model
model = Simple3DCNN().to(DEVICE)

if os.path.exists(MODEL_PATH):
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    print("Model loaded successfully!")
else:
    # Without a checkpoint the model keeps its freshly constructed (untrained)
    # weights, so any predictions below are not meaningful.
    print(f"Model not found at {MODEL_PATH}")

# Switch to inference mode unconditionally. Previously eval() ran only on the
# success path, so a missing checkpoint left the model in training mode. Any
# dropout / batch-norm layers behave differently in that mode.
model.eval()

In [None]:
def visualize_sample(index=None, target_layer_name='conv4', target_class=1):
    """Plot one sample's middle slice next to its Grad-CAM heatmap.

    Draws three panels: the raw middle slice (titled with the true label),
    the Grad-CAM heatmap for that slice, and the heatmap overlaid on the
    image (titled with the model's scalar output).

    Args:
        index: Dataset index to visualize. If None, a random index is drawn.
        target_layer_name: Layer name handed to ``GradCAM``. Defaults to
            'conv4', the previously hard-coded value.
        target_class: Class passed to ``GradCAM.generate_cam``. Defaults to
            1, the previously hard-coded value.

    Raises:
        RuntimeError: If no dataset was loaded (``dataset`` is None).
        ValueError: If a random sample is requested from an empty dataset.
    """
    if dataset is None:
        raise RuntimeError("Dataset is not loaded. Please run preprocessing and the dataset cell first.")
    if index is None:
        n_samples = len(dataset)
        if n_samples == 0:
            # np.random.randint(0) would raise an opaque "high <= 0" error.
            raise ValueError("Dataset is empty; nothing to visualize.")
        index = np.random.randint(n_samples)

    # Get Data
    input_tensor, label = dataset[index]
    input_tensor = input_tensor.unsqueeze(0).to(DEVICE)  # Add batch dim

    # Prediction. `.item()` requires the model output to hold exactly one value.
    # NOTE(review): this is displayed as "Pred" and named pred_prob. That assumes
    # the model emits a probability (e.g. sigmoid output) rather than a raw logit.
    # Confirm in model.py.
    with torch.no_grad():
        output = model(input_tensor)
        pred_prob = output.item()

    # Grad-CAM. close() sits in `finally` so it runs even if generate_cam raises.
    # Presumably close() detaches hooks GradCAM registered on the shared `model`;
    # confirm in gradcam.py.
    gradcam = GradCAM(model, target_layer_name=target_layer_name)
    try:
        cams = gradcam.generate_cam(input_tensor, target_class=target_class)
        cam = cams[0]
    finally:
        gradcam.close()

    # Plot the middle slice along axis 2. The [0, 0, z] indexing plus imshow
    # implies a (1, C, D, H, W) batched layout. TODO confirm against
    # ProcessedLunaDataset.
    z = input_tensor.shape[2] // 2  # Middle slice
    # detach() is defensive: .numpy() raises on a tensor that requires grad.
    img_slice = input_tensor.detach().cpu().numpy()[0, 0, z]
    # Assumes the CAM is at the input's depth resolution, so the same z lines up.
    # Confirm in gradcam.py.
    cam_slice = cam[z]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img_slice, cmap='gray')
    axes[0].set_title(f"Original (True: {int(label.item())})")

    axes[1].imshow(cam_slice, cmap='jet')
    axes[1].set_title("Grad-CAM Heatmap")

    axes[2].imshow(img_slice, cmap='gray')
    axes[2].imshow(cam_slice, cmap='jet', alpha=0.5)
    axes[2].set_title(f"Overlay (Pred: {pred_prob:.4f})")

    plt.show()

# Run Visualization
visualize_sample()