Template structure for RISE MICCAI tutorials.
We suggest following this structure so that all tutorials share a common layout.

# 📓 [Tutorial Title: e.g., "Lung Nodule Detection with PyTorch"]

**Creator:** [Your Name]  
**Date:** [Date of Creation/Last Update]  
**Video Link:** [Link to YouTube Video Tutorial] (to be added later)

---

## 🌟 1. Title & Overview

<!-- 
* **Concise title:** Make it catchy and informative.
* **Short description:** 1-2 sentences summarizing the tutorial's content and goal.
* **Learning objectives:** Use bullet points for clarity. What will the user be able to do after completing this tutorial?
-->

**Title:** `[Insert Concise Title of the Tutorial Here]`

**Description:**
`[Insert a brief (1-2 sentences) description of what this tutorial covers. e.g., "This tutorial walks you through the process of building and training a convolutional neural network (CNN) to detect lung nodules in CT scans using Python, PyTorch, and a publicly available dataset."]`

**Learning Objectives:**
By the end of this tutorial, you will be able to:
* Understand `[Concept 1, e.g., the basics of medical image analysis for nodule detection]`
* Implement `[Technique 1, e.g., a data loading pipeline for DICOM images]`
* Build `[Model/Tool, e.g., a simple CNN model for binary classification]`
* Evaluate `[Skill, e.g., the performance of your model using appropriate metrics]`
* Visualize `[Output, e.g., model predictions on sample images]`

---

## 🚀 2. Introduction

<!--
* **Background:** Provide context. What is the problem you're solving?
* **Importance:** Why is this specific topic/task significant in medical imaging?
* **Real-world use cases:** Give concrete examples of how this is applied.
* Keep this section engaging and motivate the reader.
-->

`[Provide background information on the topic. Explain the problem domain, e.g., "Lung cancer is a leading cause of cancer-related deaths worldwide. Early detection through screening programs, often involving CT scans, is crucial for improving patient outcomes. Identifying suspicious lung nodules in these scans is a key step..."]`

**Why is this important in medical imaging?**
`[Explain the significance. e.g., "Automating or assisting in the detection of lung nodules can help radiologists by reducing their workload, improving detection consistency, and potentially identifying subtle nodules that might be missed..."]`

**Real-world use cases and applications:**
* `[Use Case 1, e.g., Computer-Aided Detection (CAD) systems in hospitals]`
* `[Use Case 2, e.g., Assisting in large-scale screening programs]`
* `[Use Case 3, e.g., Quantitative analysis of nodule characteristics over time]`

---

## 🛠️ 3. Pre-requisites

<!--
* **Libraries:** List all necessary Python libraries. Provide pip install commands.
* **Dataset:** Specify the dataset.
* **Dataset Link & Description:** Provide a direct link and a brief overview of the dataset (source, type of images, annotations, etc.).
* **Download & Preparation:** Step-by-step instructions on how to download and prepare the dataset. This might include unzipping, organizing files, or running a preprocessing script.
-->

**Required Python Libraries:**
Make sure you have the following libraries installed. You can install them using pip:

In [None]:
!pip install numpy pandas matplotlib seaborn scikit-learn opencv-python pydicom torch torchvision torcheval tqdm
# !pip install [any-other-specific-library]

**Dataset to Use:**
* **Name:** `[e.g., LUNA16 (Lung Nodule Analysis 16)]`
* **Link:** `[Provide a direct URL to the dataset or its official page]`
* **Description:** `[Briefly describe the dataset: e.g., "The LUNA16 dataset contains chest CT scans from the LIDC-IDRI database, with annotations for lung nodules. It's widely used for developing and evaluating nodule detection algorithms."]`

**How to Download & Prepare the Dataset:**
1.  `[Step 1: e.g., Download the dataset from the provided link. You might need to register.]`
2.  `[Step 2: e.g., Extract the downloaded files to a specific directory, e.g., './data/LUNA16']`
3.  `[Step 3: e.g., Run any provided preprocessing scripts or detail manual steps for organizing files if needed. For example, converting .mhd/.raw to .npy or .png if preferred for easier loading, or creating a CSV file mapping image paths to labels.]`

In [None]:
# Optional: Include a small Python snippet here if there's a simple script for preparation
# import os
# def prepare_dataset(raw_data_path, processed_data_path):
#     # Your preparation code here
#     print("Dataset preparation complete.")
# prepare_dataset('./data/LUNA16_raw', './data/LUNA16_processed')
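
If your dataset does not ship with a label file, you can generate the `metadata.csv` that section 4.2 expects (columns `image_id,label`). The sketch below assumes the preprocessed images are `.npy` files in a single directory and that you have already derived a 0/1 label for each file name from the dataset's annotations. `build_metadata_csv` and the `labels_by_id` dictionary are hypothetical names, not part of any library.

```python
import glob
import os

import pandas as pd


def build_metadata_csv(image_dir, labels_by_id, csv_path):
    """Hypothetical helper: write image_id,label rows for every .npy file in image_dir.

    labels_by_id: dict mapping a file name (e.g. 'image_001.npy') to a 0/1 label.
    Files without an entry in labels_by_id are skipped and reported.
    """
    rows, missing = [], []
    for path in sorted(glob.glob(os.path.join(image_dir, '*.npy'))):
        image_id = os.path.basename(path)
        if image_id in labels_by_id:
            rows.append({'image_id': image_id, 'label': int(labels_by_id[image_id])})
        else:
            missing.append(image_id)
    if missing:
        print(f"Skipped {len(missing)} file(s) without a label, e.g. {missing[0]}")
    df = pd.DataFrame(rows, columns=['image_id', 'label'])
    df.to_csv(csv_path, index=False)
    return df


# Example call (labels here are made up for illustration):
# build_metadata_csv('./data/LUNA16_processed',
#                    {'image_001.npy': 1, 'image_002.npy': 0},
#                    './data/LUNA16_processed/metadata.csv')
```

Reporting unlabeled files, rather than silently assigning them a class, makes gaps in the annotation mapping visible before training.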

---

## 💻 4. Hands-On Code Implementation

<!--
* **Structure:** Break this section into logical sub-sections (as listed).
* **Code Blocks:** Use clear, well-commented code blocks.
* **Markdown Explanations:** Precede each code block with a Markdown cell explaining *what* the code does and *why*.
* **Step-by-step:** Guide the user through each part of the implementation.
-->

### 4.1 Importing Necessary Libraries

We'll start by importing all the Python libraries we'll need for this tutorial.

In [None]:
# Standard libraries
import os
import glob
import random
import time

# Data manipulation
import numpy as np
import pandas as pd

# Image processing and medical imaging
import cv2 # OpenCV for image manipulation
import pydicom # For reading DICOM files
# from PIL import Image # If using PIL

# Deep Learning - PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
# from torch.utils.tensorboard import SummaryWriter # For TensorBoard logging

# Scikit-learn for metrics and utilities
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
# For segmentation tasks, you might use:
# from sklearn.metrics import jaccard_score # (Dice is often custom or from other libs)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Progress bar
from tqdm.notebook import tqdm # Use tqdm.notebook for Jupyter

# Ensure reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
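
Seeding alone does not make runs fully repeatable: cuDNN may select non-deterministic convolution algorithms, and each `DataLoader` worker has its own random state. A minimal sketch following the pattern described in the PyTorch reproducibility notes; `seed_worker` mirrors that pattern, while `make_seeded_loader` is a hypothetical helper of our own.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ask cuDNN for deterministic convolution algorithms (may be slower; no effect on CPU-only runs)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def seed_worker(worker_id):
    """Re-seed NumPy and `random` in each DataLoader worker from PyTorch's per-worker seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_seeded_loader(dataset, batch_size, seed=42, **kwargs):
    """Hypothetical helper: a shuffling DataLoader whose order depends only on `seed`."""
    generator = torch.Generator()
    generator.manual_seed(seed)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      worker_init_fn=seed_worker, generator=generator, **kwargs)


# Two loaders built with the same seed yield batches in the same order
demo = TensorDataset(torch.arange(10, dtype=torch.float32))
first = [batch[0].tolist() for batch in make_seeded_loader(demo, batch_size=4, seed=0)]
second = [batch[0].tolist() for batch in make_seeded_loader(demo, batch_size=4, seed=0)]
print(first == second)  # True
```

In the tutorial you could build the training loader with, for example, `make_seeded_loader(train_dataset, BATCH_SIZE, num_workers=2)`. Deterministic cuDNN settings trade some speed for repeatability.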

### 4.2 Loading the Dataset

Now, let's define how we'll load our medical images and their corresponding labels. For PyTorch, we typically create a custom `Dataset` class.

In [None]:
# Define paths (adjust as necessary)
DATASET_DIR = './data/LUNA16_processed/' # Or wherever you stored it
METADATA_CSV = os.path.join(DATASET_DIR, 'metadata.csv') # Assuming you have/create one

# Example: Custom PyTorch Dataset Class
class MedicalImageDataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): DataFrame with image IDs and labels.
            image_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_files = dataframe['image_id'].values
        self.labels = dataframe['label'].values # Or masks for segmentation
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.image_dir, self.image_files[idx])
        
        # Example: Loading a .npy file (adjust for .dcm, .png, etc.)
        image = np.load(img_name).astype(np.float32) 
        # If 2D slices from 3D scan, image might be (H, W) or (H, W, C)
        # If 3D scan, image might be (D, H, W)
        
        # Add a channel dimension if the image is grayscale: (H, W) -> (1, H, W)
        if image.ndim == 2:
            image = np.expand_dims(image, axis=0)
        elif image.ndim == 3 and image.shape[0] not in (1, 3):
            # Probably a (D, H, W) volume. Pick the option that matches your model:
            # - 3D CNN: add a channel dimension -> (1, D, H, W)
            #   image = np.expand_dims(image, axis=0)
            # - 2D CNN: take the middle slice -> (1, H, W)
            #   mid_slice_idx = image.shape[0] // 2
            #   image = np.expand_dims(image[mid_slice_idx], axis=0)
            pass  # Adjust based on your data and model

        # torchvision transforms such as Resize and Normalize expect tensors (or PIL images),
        # not NumPy arrays, so convert to a float32 tensor before transforming.
        image = torch.from_numpy(image)

        label = self.labels[idx]
        # Convert the label to a tensor with the dtype the loss function expects
        label = torch.tensor(label, dtype=torch.float32)  # For BCEWithLogitsLoss
        # label = torch.tensor(label, dtype=torch.long)  # For CrossEntropyLoss

        if self.transform:
            # 2D transforms may need adjustment (or a custom implementation) for (1, D, H, W) volumes
            image = self.transform(image)

        return image, label  # DataLoader collates these into batches

# Load metadata (assuming you have a CSV with image_id and label)
# This is a placeholder. You'll need to create this based on your dataset.
# Example metadata.csv structure:
# image_id,label
# image_001.npy,1
# image_002.npy,0
try:
    metadata_df = pd.read_csv(METADATA_CSV)
except FileNotFoundError:
    print(f"Warning: {METADATA_CSV} not found. Using dummy metadata; create the file or adjust the path.")
    # Dummy dataframe for demonstration only (random labels)
    metadata_df = pd.DataFrame({
        'image_id': [f'dummy_image_{i:03d}.npy' for i in range(100)],
        'label': [random.randint(0, 1) for _ in range(100)]
    })
    # The matching dummy .npy files are created in DUMMY_IMG_DIR in section 4.3.

# Split dataset: train, validation, test
# 70% train, 15% validation, 15% test; stratify by label when both classes are present
def _stratify_col(df):
    return df['label'] if 'label' in df.columns and df['label'].nunique() > 1 else None

train_df, temp_df = train_test_split(metadata_df, test_size=0.3, random_state=42, stratify=_stratify_col(metadata_df))
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=_stratify_col(temp_df))

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"Test samples: {len(test_df)}")
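
A random split of slices or patches can place images from the same patient in both the training and test sets, which inflates test scores. If your metadata records which patient each image comes from, you can split by patient instead. The sketch below assumes a hypothetical `patient_id` column (for LUNA16, the scan's `seriesuid` is a possible stand-in, though one patient may have more than one scan) and uses scikit-learn's `GroupShuffleSplit`; `split_by_patient` is our own helper name.

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit


def split_by_patient(df, group_col='patient_id', test_size=0.3, seed=42):
    """Hypothetical helper: split df in two so that no value of group_col appears in both parts."""
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    first_idx, second_idx = next(splitter.split(df, groups=df[group_col]))
    return df.iloc[first_idx].reset_index(drop=True), df.iloc[second_idx].reset_index(drop=True)


# Toy table: 40 images from 10 made-up patients (4 images each)
toy = pd.DataFrame({
    'image_id': [f'img_{i:03d}.npy' for i in range(40)],
    'label': [i % 2 for i in range(40)],
    'patient_id': [f'patient_{i // 4}' for i in range(40)],
})
train_part, temp_part = split_by_patient(toy, test_size=0.3)
val_part, test_part = split_by_patient(temp_part, test_size=0.5)
print(len(train_part), len(val_part), len(test_part))
```

`GroupShuffleSplit` splits by group, not by image, so the split proportions are only approximate and labels are not stratified. If you need both grouping and stratification, scikit-learn (version 1.0 and later) also provides `StratifiedGroupKFold`.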

### 4.3 Preprocessing Steps

Preprocessing is crucial for good model performance. This includes normalization, resizing, and data augmentation.
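
For CT data, a common first intensity step is to clip values to a Hounsfield unit (HU) window and rescale them to [0, 1] before the transforms below. The window used here (-1000 to 400 HU, i.e. from air up through soft tissue) is an assumption; choose bounds that suit your task. For DICOM input, raw pixel values are usually converted to HU first with the `RescaleSlope` and `RescaleIntercept` tags (`hu = ds.pixel_array * ds.RescaleSlope + ds.RescaleIntercept`). `window_ct` is a hypothetical helper you could call in `MedicalImageDataset.__getitem__` right after loading, or when preparing the `.npy` files.

```python
import numpy as np


def window_ct(hu_image, hu_min=-1000.0, hu_max=400.0):
    """Hypothetical helper: clip a CT image in HU to [hu_min, hu_max] and rescale it to [0, 1]."""
    clipped = np.clip(np.asarray(hu_image, dtype=np.float32), hu_min, hu_max)
    return (clipped - hu_min) / (hu_max - hu_min)


print(window_ct(np.array([-2000, -1000, -300, 400, 3000])).tolist())  # [0.0, 0.0, 0.5, 1.0, 1.0]
```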

In [None]:
# Define transforms.
# mean=[0.5], std=[0.5] below are placeholders (not ImageNet statistics, and not statistics
# of any medical dataset): compute them from your training data. For grayscale images, use
# one value each; for 3-channel images, three. 3D volumes usually need custom transforms.
#
# These transforms operate on (C, H, W) float tensors, so MedicalImageDataset must convert
# each NumPy array with torch.from_numpy() before applying them. transforms.ToTensor() is
# only needed for PIL images or (H, W, C) NumPy arrays: it returns a (C, H, W) tensor and
# rescales uint8 inputs to [0, 1].
train_transforms = transforms.Compose([
    # transforms.ToTensor(), # Use if loading PIL images or HWC numpy. Not needed if data is already C,H,W tensor or loaded as such.
    transforms.Resize((128, 128), antialias=True), # Or your target size
    # transforms.RandomHorizontalFlip(), # Example augmentation
    # transforms.RandomRotation(10),   # Example augmentation
    transforms.Normalize(mean=[0.5], std=[0.5]) # For single channel. If 3 channels (e.g. RGB-like): [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
                                                # Calculate these from your dataset!
])

val_test_transforms = transforms.Compose([
    # transforms.ToTensor(),
    transforms.Resize((128, 128), antialias=True),
    transforms.Normalize(mean=[0.5], std=[0.5])
])


# Create Datasets and DataLoaders.
# This example reads the dummy .npy files from DUMMY_IMG_DIR; point image_dir at
# DATASET_DIR when using your real data.
DUMMY_IMG_DIR = './data/dummy_images/'
os.makedirs(DUMMY_IMG_DIR, exist_ok=True)
if not os.listdir(DUMMY_IMG_DIR):  # Create dummy files only if the directory is empty
    print("Creating dummy .npy files for demonstration...")
    for fname in metadata_df['image_id']:
        # 2D slices of shape (1, H, W) with H = W = 64
        np.save(os.path.join(DUMMY_IMG_DIR, fname), np.random.rand(1, 64, 64).astype(np.float32))

train_dataset = MedicalImageDataset(dataframe=train_df, image_dir=DUMMY_IMG_DIR, transform=train_transforms)
val_dataset = MedicalImageDataset(dataframe=val_df, image_dir=DUMMY_IMG_DIR, transform=val_test_transforms)
test_dataset = MedicalImageDataset(dataframe=test_df, image_dir=DUMMY_IMG_DIR, transform=val_test_transforms)

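BATCH_SIZE = 32
# Worker processes load batches in parallel. If they fail in Jupyter (common on Windows/macOS
# when the Dataset class is defined inside the notebook), set NUM_WORKERS = 0.
NUM_WORKERS = 2
PIN_MEMORY = torch.cuda.is_available()  # Pinned memory only speeds up CPU-to-GPU copies
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)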

# Visualize a sample batch (optional, but good for verification)
def show_batch(data_loader, n_images=5):
    images, labels = next(iter(data_loader))
    fig, axes = plt.subplots(1, n_images, figsize=(15, 3))
    for i in range(n_images):
        img = images[i].squeeze().cpu().numpy() # Remove channel dim if 1, and move to CPU
        # Denormalize for visualization if needed: img = img * 0.5 + 0.5
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f"Label: {labels[i].item():.0f}")
        axes[i].axis('off')
    plt.show()

print("Sample batch from training loader:")
show_batch(train_loader)
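
The `Normalize(mean=[0.5], std=[0.5])` values above are placeholders. One way to compute real statistics is a single pass over the training images only (never the validation or test sets, to avoid leakage), accumulating per-channel sums and sums of squares. `channel_mean_std` is a hypothetical helper; it assumes each array is already `(C, H, W)` and preprocessed the same way as at training time (for example, after HU windowing).

```python
import numpy as np


def channel_mean_std(arrays):
    """Hypothetical helper: per-channel mean and std over every pixel of (C, H, W) arrays."""
    total = total_sq = None
    count = 0
    for arr in arrays:
        a = np.asarray(arr, dtype=np.float64)
        flat = a.reshape(a.shape[0], -1)  # (C, H*W)
        s, sq = flat.sum(axis=1), (flat ** 2).sum(axis=1)
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += flat.shape[1]
    if count == 0:
        raise ValueError("No arrays were provided.")
    mean = total / count
    # Var = E[x^2] - E[x]^2, clipped at 0 to absorb floating-point round-off
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean, std


# Possible use on the training split:
# mean, std = channel_mean_std(np.load(os.path.join(DATASET_DIR, f)) for f in train_df['image_id'])
# transforms.Normalize(mean=mean.tolist(), std=std.tolist())
mean, std = channel_mean_std([np.zeros((1, 2, 2)), np.full((1, 2, 2), 2.0)])
print(mean.tolist(), std.tolist())  # [1.0] [1.0]
```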

### 4.4 Model Implementation

Here, we'll define a simple Convolutional Neural Network (CNN) for our medical image classification task.

In [None]:
# Example: A simple CNN for 2D image classification
# Input: (Batch, 1, ImageSize, ImageSize), e.g. (32, 1, 128, 128)
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=1): # num_classes=1 for binary classification with sigmoid
        super(SimpleCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1), # Output: (16, 128, 128)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # Output: (16, 64, 64)
            
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1), # Output: (32, 64, 64)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # Output: (32, 32, 32)
            
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), # Output: (64, 32, 32)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # Output: (64, 16, 16)
        )
        
        # Calculate the flattened size after conv layers
        # For input (1, 128, 128), after 3 maxpools by 2: 128 / (2*2*2) = 128 / 8 = 16
        self.flattened_size = 64 * 16 * 16 
        
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes) # Output layer
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x # Raw logits

# Initialize the model
model = SimpleCNN(num_classes=1).to(device) # num_classes=1 for binary (output will be passed to sigmoid)
print(model)

# Test with a dummy input
dummy_input = torch.randn(2, 1, 128, 128).to(device) # Batch size 2, 1 channel, 128x128
output = model(dummy_input)
print(f"Dummy input shape: {dummy_input.shape}")
print(f"Model output shape: {output.shape}") # Should be (2, 1) for binary classification
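
The hard-coded `self.flattened_size = 64 * 16 * 16` holds only for 128 x 128 inputs. Along each spatial axis, `nn.Conv2d` and `nn.MaxPool2d` (default floor mode) produce `out = (in + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1`, as given in the PyTorch layer documentation, so you can recompute the size for other inputs. The function names below are our own; `simple_cnn_flattened_size` assumes the three conv + pool blocks of `SimpleCNN`.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis for nn.Conv2d / nn.MaxPool2d (ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def simple_cnn_flattened_size(input_size, num_blocks=3, out_channels=64):
    """Hypothetical helper: flattened feature size of SimpleCNN for a square input."""
    size = input_size
    for _ in range(num_blocks):
        size = conv_output_size(size, kernel_size=3, stride=1, padding=1)  # Conv2d keeps the size
        size = conv_output_size(size, kernel_size=2, stride=2)             # MaxPool2d halves it (floor)
    return out_channels * size * size


print(simple_cnn_flattened_size(128))  # 16384 (= 64 * 16 * 16)
print(simple_cnn_flattened_size(100))  # 9216 (= 64 * 12 * 12)
```

An alternative that avoids the calculation entirely is `nn.LazyLinear(128)`, which infers its input size on the first forward pass.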

### 4.5 Training and Evaluation

We'll now set up the training loop and define the loss function (Binary Cross-Entropy with Logits for binary classification), an optimizer (e.g., Adam), and, optionally, a learning-rate scheduler.

In [None]:
# Loss function, optimizer, and scheduler
criterion = nn.BCEWithLogitsLoss() # Includes sigmoid, expects raw logits from model
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # Optional

NUM_EPOCHS = 10 # Adjust as needed

def train_one_epoch(model, data_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_labels = []
    all_predictions = []

    for inputs, labels in tqdm(data_loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device).unsqueeze(1) # Ensure labels are [B, 1] for BCEWithLogitsLoss

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        
        # Store predictions and labels for epoch metrics
        preds_proba = torch.sigmoid(outputs).detach().cpu().numpy()
        all_predictions.extend(preds_proba)
        all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / len(data_loader.dataset)
    all_labels = np.array(all_labels).flatten()
    all_predictions = np.array(all_predictions).flatten()
    
    # Calculate metrics
    # Threshold probabilities for binary classification metrics
    # You might want to tune this threshold based on validation set performance (e.g., using ROC curve)
    threshold = 0.5 
    binary_preds = (all_predictions >= threshold).astype(int)
    
    accuracy = accuracy_score(all_labels, binary_preds)
    precision = precision_score(all_labels, binary_preds, zero_division=0)
    recall = recall_score(all_labels, binary_preds, zero_division=0)
    f1 = f1_score(all_labels, binary_preds, zero_division=0)
    
    # AUC is computed from probabilities, not thresholded predictions
    try:
        auc = roc_auc_score(all_labels, all_predictions)
    except ValueError:  # Only one class present in labels: AUC is undefined
        auc = float('nan')

    metrics = {'loss': epoch_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}
    return metrics

def evaluate_model(model, data_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_labels = []
    all_predictions = [] # Store probabilities for AUC

    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device).unsqueeze(1)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)

            preds_proba = torch.sigmoid(outputs).cpu().numpy()
            all_predictions.extend(preds_proba)
            all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / len(data_loader.dataset)
    all_labels = np.array(all_labels).flatten()
    all_predictions = np.array(all_predictions).flatten()

    threshold = 0.5
    binary_preds = (all_predictions >= threshold).astype(int)
    
    accuracy = accuracy_score(all_labels, binary_preds)
    precision = precision_score(all_labels, binary_preds, zero_division=0)
    recall = recall_score(all_labels, binary_preds, zero_division=0)
    f1 = f1_score(all_labels, binary_preds, zero_division=0)
    try:
        auc = roc_auc_score(all_labels, all_predictions)
    except ValueError:  # Only one class present in labels: AUC is undefined
        auc = float('nan')

    metrics = {'loss': epoch_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}
    return metrics

# Training loop
history = {'train_loss': [], 'train_acc': [], 'train_auc': [],
           'val_loss': [], 'val_acc': [], 'val_auc': []}

print("Starting training...")
start_time = time.time()

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    
    train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_metrics = evaluate_model(model, val_loader, criterion, device)
    
    # scheduler.step()  # Uncomment together with the scheduler definition above

    print(f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['accuracy']:.4f} | Train AUC: {train_metrics['auc']:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f}   | Val Acc: {val_metrics['accuracy']:.4f}   | Val AUC: {val_metrics['auc']:.4f}")

    history['train_loss'].append(train_metrics['loss'])
    history['train_acc'].append(train_metrics['accuracy'])
    history['train_auc'].append(train_metrics['auc'])
    history['val_loss'].append(val_metrics['loss'])
    history['val_acc'].append(val_metrics['accuracy'])
    history['val_auc'].append(val_metrics['auc'])

    # Optional: save the best model (initialize best_val_auc = 0.0 before the loop)
    # if val_metrics['auc'] > best_val_auc:
    #     best_val_auc = val_metrics['auc']
    #     torch.save(model.state_dict(), 'best_model.pth')

end_time = time.time()
print(f"\nTraining finished in {(end_time - start_time)/60:.2f} minutes.")

# Optional: Save the final model
# torch.save(model.state_dict(), 'final_model.pth')
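
Early stopping, one of the remedies for overfitting listed in section 5, fits naturally next to the commented-out best-model checkpoint in the loop above. A minimal sketch, assuming validation AUC is the monitored metric; the `EarlyStopping` class is hypothetical, not a PyTorch API.

```python
import math


class EarlyStopping:
    """Hypothetical helper: signal a stop when a metric has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.num_bad_epochs = 0

    def step(self, value):
        """Record one epoch's metric; return (improved, should_stop)."""
        value = float(value)
        if math.isnan(value):  # e.g. AUC undefined: treat as no improvement
            improved = False
        elif self.best is None:
            improved = True
        elif self.mode == 'max':
            improved = value > self.best + self.min_delta
        else:
            improved = value < self.best - self.min_delta
        if improved:
            self.best = value
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return improved, self.num_bad_epochs >= self.patience


# Possible use inside the training loop above:
# stopper = EarlyStopping(patience=3, mode='max')   # create once, before the loop
# for epoch in range(NUM_EPOCHS):
#     ...
#     improved, should_stop = stopper.step(val_metrics['auc'])
#     if improved:
#         torch.save(model.state_dict(), 'best_model.pth')
#     if should_stop:
#         print(f"Early stopping after epoch {epoch + 1}")
#         break
```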

### 4.6 Performance Metrics

Let's visualize the training progress and evaluate the model on the test set.

In [None]:
# Plot training history
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history['train_auc'], label='Train AUC') # Or 'train_acc'
plt.plot(history['val_auc'], label='Validation AUC') # Or 'val_acc'
plt.title('AUC Over Epochs') # Or 'Accuracy Over Epochs'
plt.xlabel('Epoch')
plt.ylabel('AUC') # Or 'Accuracy'
plt.legend()

plt.tight_layout()
plt.show()

# Evaluate on the test set
print("\nEvaluating on Test Set...")
test_metrics = evaluate_model(model, test_loader, criterion, device)
print(f"Test Loss: {test_metrics['loss']:.4f}")
print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
print(f"Test Precision: {test_metrics['precision']:.4f}")
print(f"Test Recall: {test_metrics['recall']:.4f}")
print(f"Test F1-Score: {test_metrics['f1']:.4f}")
print(f"Test AUC: {test_metrics['auc']:.4f}")

# Confusion matrix for the test set (evaluate_model returns only summary metrics, so collect the probabilities here)
model.eval()
all_labels_test = []
all_preds_test_proba = []
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Getting Test Preds"):
        inputs, labels = inputs.to(device), labels.to(device).unsqueeze(1)
        outputs = model(inputs)
        preds_proba = torch.sigmoid(outputs).cpu().numpy()
        all_preds_test_proba.extend(preds_proba)
        all_labels_test.extend(labels.cpu().numpy())

all_labels_test = np.array(all_labels_test).flatten()
all_preds_test_proba = np.array(all_preds_test_proba).flatten()
all_preds_test_binary = (all_preds_test_proba >= 0.5).astype(int)

cm = confusion_matrix(all_labels_test, all_preds_test_binary)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Predicted Negative', 'Predicted Positive'], yticklabels=['Actual Negative', 'Actual Positive'])
plt.title('Confusion Matrix - Test Set')
plt.ylabel('Actual Label')
plt.xlabel('Predicted Label')
plt.show()
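
The 0.5 threshold used above is a default, not a tuned value. One common choice is the threshold that maximizes Youden's J statistic (sensitivity + specificity - 1) on the **validation** set; tuning it on the test set would bias the reported results. The sketch below uses scikit-learn's `roc_curve` and `confusion_matrix`; `youden_threshold` and `sensitivity_specificity` are hypothetical helper names. To use them, collect validation labels and probabilities with a loop like the one above, run over `val_loader` instead of `test_loader`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve


def youden_threshold(labels, probs):
    """Hypothetical helper: threshold maximizing Youden's J = TPR - FPR."""
    fpr, tpr, thresholds = roc_curve(labels, probs)
    return float(thresholds[np.argmax(tpr - fpr)])


def sensitivity_specificity(labels, probs, threshold):
    """Hypothetical helper: sensitivity (recall of positives) and specificity (recall of negatives)."""
    preds = (np.asarray(probs) >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
    specificity = tn / (tn + fp) if (tn + fp) > 0 else float('nan')
    return float(sensitivity), float(specificity)


# Toy validation data; pick the threshold here, then apply it unchanged to the test set
val_labels = np.array([0, 0, 0, 1, 1])
val_probs = np.array([0.10, 0.30, 0.60, 0.55, 0.90])
t = youden_threshold(val_labels, val_probs)
print(t, sensitivity_specificity(val_labels, val_probs, t))
```

Sensitivity and specificity are often reported alongside AUC in medical imaging because they correspond directly to missed findings and false alarms at a fixed operating point.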

### 4.7 Visualization of Results

Visualizing the model's predictions can provide insights into its behavior.

In [None]:
def visualize_predictions(model, data_loader, device, num_images=5, threshold=0.5):
    model.eval()
    images, labels = next(iter(data_loader))
    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(images)
        probs = torch.sigmoid(outputs)
        preds = (probs >= threshold).float()

    images = images.cpu()
    labels = labels.cpu()
    preds = preds.cpu()
    probs = probs.cpu()

    fig, axes = plt.subplots(1, num_images, figsize=(15, 4))
    for i in range(min(num_images, len(images))):
        img = images[i].squeeze().numpy() # Assuming single channel
        # Denormalize if needed: img = img * 0.5 + 0.5
        axes[i].imshow(img, cmap='gray')
        true_label = labels[i].item()
        pred_label = preds[i].item()
        pred_prob = probs[i].item()
        axes[i].set_title(f"True: {true_label:.0f}\nPred: {pred_label:.0f} ({pred_prob:.2f})", 
                          color=("green" if true_label == pred_label else "red"))
        axes[i].axis('off')
    plt.suptitle("Sample Model Predictions on Test Set", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for suptitle
    plt.show()

print("\nVisualizing sample predictions from the test set:")
visualize_predictions(model, test_loader, device, num_images=5)

# For heatmaps/attention maps (e.g., Grad-CAM), you need an additional library or your own implementation.
# Conceptual example with the pytorch-grad-cam package (pip install grad-cam):
# from pytorch_grad_cam import GradCAM
# from pytorch_grad_cam.utils.image import show_cam_on_image
# target_layers = [model.conv_layers[6]]  # Last nn.Conv2d in SimpleCNN (conv_layers[-1] is a MaxPool2d)
# cam = GradCAM(model=model, target_layers=target_layers)
# input_tensor = ...  # A sample image tensor of shape (1, 1, 128, 128)
# grayscale_cam = cam(input_tensor=input_tensor, targets=None)  # Or specify a target class
# # show_cam_on_image expects an (H, W, 3) float32 image in [0, 1]; stack a grayscale image 3 times
# visualization = show_cam_on_image(image_np_array, grayscale_cam[0, :], use_rgb=True)
# plt.imshow(visualization)
# plt.title("Grad-CAM")
# plt.show()
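
If you prefer not to add a dependency, Grad-CAM for a single-logit model can be written with a forward hook: weight each feature map of a target layer by the spatial mean of the logit's gradient, sum over channels, apply ReLU, and upsample to the input size. This is a minimal hand-written sketch; the `grad_cam` function is our own and not the pytorch-grad-cam API. For `SimpleCNN`, a reasonable target layer is `model.conv_layers[6]`, the last `nn.Conv2d`. It must run outside `torch.no_grad()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model, image, target_layer):
    """Hypothetical helper: Grad-CAM heatmap for a model with a single output logit.

    image: tensor of shape (1, C, H, W) on the model's device.
    Returns an (H, W) tensor scaled to [0, 1].
    """
    store = {}

    def forward_hook(module, inputs, output):
        store['activations'] = output
        # Tensor hook: receives d(logit)/d(output) during the backward pass
        output.register_hook(lambda grad: store.__setitem__('gradients', grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        model.zero_grad()
        logit = model(image)[0, 0]
        logit.backward()
    finally:
        handle.remove()

    acts = store['activations'].detach()                                 # (1, K, h, w)
    weights = store['gradients'].detach().mean(dim=(2, 3), keepdim=True)  # (1, K, 1, 1)
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))              # (1, 1, h, w)
    cam = F.interpolate(cam, size=tuple(image.shape[-2:]), mode='bilinear', align_corners=False)[0, 0]
    model.zero_grad()  # Discard the parameter gradients created by this backward pass
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)


# Tiny stand-in network for illustration (with SimpleCNN, use target_layer=model.conv_layers[6])
torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 1))
heatmap = grad_cam(net, torch.randn(1, 1, 16, 16), target_layer=net[0])
print(heatmap.shape)  # torch.Size([16, 16])
```

To overlay the result on a test image, you could plot the image with `cmap='gray'` and then `plt.imshow(heatmap.cpu(), cmap='jet', alpha=0.4)` on the same axes.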

---

## ⚠️ 5. Challenges & Limitations

<!--
* **Common Pitfalls:** Discuss potential issues users might encounter (e.g., overfitting, data imbalance, slow training) and how to address them.
* **Known Limitations:** Be transparent about the limitations of the presented approach/model (e.g., "This simple CNN might not perform well on very complex cases," "Doesn't handle 3D context well").
* **Ethical Considerations:** If relevant (especially in medical AI), discuss biases in data, model fairness, interpretability, and responsible AI use.
-->

**Common Pitfalls and How to Avoid Them:**
* **Overfitting:** `[e.g., The model performs well on training data but poorly on validation/test data. Solutions: More data, data augmentation, regularization (Dropout, L2), early stopping, simpler model.]`
* **Data Imbalance:** `[e.g., If one class is much more frequent. Solutions: Weighted loss functions, oversampling minority class, undersampling majority class, using metrics like F1-score or AUC over accuracy.]`
* **Vanishing/Exploding Gradients:** `[e.g., Especially in deeper networks. Solutions: Proper weight initialization, batch normalization, gradient clipping, using activation functions like ReLU.]`
* **Incorrect Preprocessing:** `[e.g., Normalization values incorrect, images not resized consistently. Solution: Double-check preprocessing steps, visualize augmented data.]`
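
Two standard PyTorch remedies for the class imbalance described above are a positive-class weight in `BCEWithLogitsLoss` (its `pos_weight` argument) and oversampling the minority class with `WeightedRandomSampler`. A sketch, assuming binary 0/1 labels such as `train_df['label']`; `imbalance_tools` is a hypothetical helper. Usually you apply one of the two, since combining them compensates for the imbalance twice.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler


def imbalance_tools(labels):
    """Hypothetical helper: return (pos_weight, sampler) for binary 0/1 labels."""
    labels = np.asarray(labels).astype(int)
    n_pos = int(labels.sum())
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("Both classes must be present to compute class weights.")
    # Option 1: weight positive examples in the loss by n_neg / n_pos
    pos_weight = torch.tensor([n_neg / n_pos], dtype=torch.float32)
    # Option 2: draw each example with probability proportional to 1 / (its class count),
    # so both classes appear about equally often per epoch
    per_class = np.array([1.0 / n_neg, 1.0 / n_pos])
    sample_weights = torch.as_tensor(per_class[labels], dtype=torch.double)
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
    return pos_weight, sampler


labels_demo = np.array([0] * 8 + [1] * 2)
pos_weight, sampler = imbalance_tools(labels_demo)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # call .to(device) when training on GPU
print(pos_weight.item(), len(sampler))  # 4.0 10
```

With the sampler, build the training loader as `DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)`; `sampler` and `shuffle=True` are mutually exclusive.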

**Known Limitations of This Approach:**
* `[e.g., The baseline model presented is relatively simple and may not achieve state-of-the-art performance.]`
* `[e.g., This tutorial focuses on 2D slices; a full 3D approach might yield better results for volumetric data but is more computationally intensive.]`
* `[e.g., The dataset size might be limited for training very deep networks from scratch.]`
* `[e.g., Interpretability of the "black-box" CNN model can be challenging, though techniques like Grad-CAM can help.]`

**Ethical Considerations & Biases (if relevant):**
* **Dataset Bias:** `[e.g., If the training data is not diverse (e.g., lacks representation from certain demographics or scanner types), the model may perform poorly on underrepresented groups. Strive for diverse and representative datasets.]`
* **Algorithmic Bias:** `[e.g., The model might inadvertently learn spurious correlations, leading to biased predictions.]`
* **Clinical Validation:** `[e.g., Any AI model for medical use requires rigorous clinical validation before deployment. This tutorial is for educational purposes only.]`
* **Accountability & Transparency:** `[e.g., Who is responsible if the model makes an error? How can decisions be explained to clinicians and patients?]`
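
One practical check for the dataset bias described above is to report performance per subgroup rather than only overall. A sketch, assuming you have a hypothetical per-image attribute (for example scanner manufacturer from the DICOM `Manufacturer` tag, acquisition site, or patient sex) aligned with the test labels and predicted probabilities; `auc_by_group` is our own helper name.

```python
import pandas as pd
from sklearn.metrics import roc_auc_score


def auc_by_group(labels, probs, groups):
    """Hypothetical helper: AUC and counts per subgroup; AUC is NaN where a subgroup has one class."""
    df = pd.DataFrame({'label': labels, 'prob': probs, 'group': groups})
    rows = []
    for name, part in df.groupby('group'):
        if part['label'].nunique() == 2:
            auc = roc_auc_score(part['label'], part['prob'])
        else:
            auc = float('nan')
        rows.append({'group': name, 'n': len(part),
                     'n_positive': int(part['label'].sum()), 'auc': auc})
    return pd.DataFrame(rows)


# Toy example with a made-up 'scanner' attribute per test image
print(auc_by_group(labels=[0, 1, 0, 1, 1, 1],
                   probs=[0.1, 0.9, 0.2, 0.8, 0.7, 0.6],
                   groups=['A', 'A', 'A', 'A', 'B', 'B']))
```

Reporting the counts next to each AUC matters: a subgroup with few positives gives an unstable estimate, and one with a single class gives none.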

---

## 📚 6. Further Reading & References

<!--
* **Research Papers:** Link to seminal papers or recent advancements.
* **Books:** Suggest relevant textbooks.
* **Online Resources:** Blogs, courses, documentation.
* **Related Implementations:** Links to GitHub repos with similar or more advanced models.
-->

**Research Papers:**
* `[e.g., Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv:1505.04597 - Link]`
* `[e.g., Setio, A. A. A., et al. (2017). Validation, comparison, and combination of algorithms for automatic detection of pulmonary nodules in CT images: The LUNA16 challenge. Medical Image Analysis. - Link]`

**Books:**
* `[e.g., "Deep Learning" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville]`
* `[e.g., "Deep Learning for Medical Image Analysis", edited by S. Kevin Zhou, Hayit Greenspan, and Dinggang Shen]`

**Online Resources & Courses:**
* `[e.g., PyTorch Official Tutorials: https://pytorch.org/tutorials/]`
* `[e.g., fast.ai course: https://course.fast.ai/]`
* `[e.g., Stanford CS231n: Convolutional Neural Networks for Visual Recognition: http://cs231n.stanford.edu/]`
* `[e.g., Grand Challenge (for datasets and challenges): https://grand-challenge.org/]`

**Related Open-Source Implementations:**
* `[e.g., Link to a relevant GitHub repository for U-Net implementations]`
* `[e.g., Link to MONAI (Medical Open Network for AI) library: https://monai.io/]`

---

**Congratulations on completing the tutorial!** We hope you found it informative and useful.
If you have any questions or feedback, please feel free to reach out or leave a comment on the YouTube video.