# 🖥️ Local GPU (24 GB) + 5.5 GB Dataset Optimized Version

---

## ⚡ This notebook is optimized for:
- **Local GPU:** 24 GB VRAM (RTX 3090/4090, A5000, etc.)
- **Dataset Size:** ~5.5 GB (streaming-safe loading)
- **Memory Safety:** No full dataset caching, mini-batch processing

## 🔧 Key Optimizations:
| Feature | Setting | Reason |
|---------|---------|--------|
| RBM Batch Size | 32 | Conv-RBM stages; larger than the Kaggle setting, uses ~8-12 GB (estimate) |
| FC-RBM Batch Size | 64 | Lighter memory footprint |
| Classifier Batch Size | 128 | Most memory-efficient stage (no Gibbs sampling) |
| Feature Caching | Disk-backed | Avoids RAM overflow with the 5.5 GB dataset |
| NUM_WORKERS | 4 | Parallel data loading |
| PERSISTENT_WORKERS | True | Reduces DataLoader overhead |

## ⚠️ Critical Differences from the Kaggle Version:
1. **No full latent caching**: features are extracted on the fly or saved to disk
2. **Periodic GPU cache clearing**: reduces memory fragmentation
3. **Stage-specific batch sizes**: tuned per training phase
4. **Resume capability**: checkpoints saved after each epoch
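The resume capability in item 4 can be sketched as follows. This is a minimal, stdlib-only illustration under assumptions: the helper names `save_checkpoint` and `load_latest_checkpoint`, the `epoch_XXXX.pkl` naming scheme, and the use of `pickle` are all hypothetical. With PyTorch you would normally store `model.state_dict()` objects via `torch.save`/`torch.load` instead.

```python
import os
import pickle
import tempfile


def save_checkpoint(state: dict, ckpt_dir: str, epoch: int) -> str:
    """Atomically write `state` for `epoch` (hypothetical helper)."""
    os.makedirs(ckpt_dir, exist_ok=True)
    # Zero-padded names make lexicographic order match epoch order.
    final_path = os.path.join(ckpt_dir, f"epoch_{epoch:04d}.pkl")
    # Write to a temp file in the same directory, then rename over the final
    # name, so a crash mid-write never leaves a truncated checkpoint behind.
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump({"epoch": epoch, **state}, f)
    os.replace(tmp_path, final_path)
    return final_path


def load_latest_checkpoint(ckpt_dir: str):
    """Return the newest checkpoint dict, or None if there is none."""
    if not os.path.isdir(ckpt_dir):
        return None
    names = sorted(
        n for n in os.listdir(ckpt_dir) if n.startswith("epoch_") and n.endswith(".pkl")
    )
    if not names:
        return None
    with open(os.path.join(ckpt_dir, names[-1]), "rb") as f:
        return pickle.load(f)


# Resume: continue from the epoch after the last completed one.
ckpt = load_latest_checkpoint("./outputs_local_5gb/models")
start_epoch = 0 if ckpt is None else ckpt["epoch"] + 1
```

Writing to a temporary file and then calling `os.replace` means an interrupted save leaves the previous checkpoint intact rather than a corrupt one.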

---

# Convolutional Deep Belief Network (CDBN) for Eye OCT Image Classification

This notebook implements a CDBN pipeline for multi-class classification of Eye OCT (Optical Coherence Tomography) images using PyTorch.

**Notebook Structure:**
1. Environment Setup & Imports
2. Central Configuration
3. Dataset Loading
4. Dataset Sanity Checks

---

# 🔧 Local GPU Configuration (24 GB + 5.5 GB Dataset)

This cell configures the notebook for optimal performance on a local high-end GPU.

**Optimizations for large dataset:**
- GPU memory management (95% allocation cap)
- CUDA optimizations for RTX/Quadro GPUs
- Streaming-friendly data loading
- Stage-specific batch sizes
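The 95% cap works out as follows. `torch.cuda.set_per_process_memory_fraction(0.95)` limits how much memory this process's caching allocator may use, and allocations beyond that limit raise an out-of-memory error. The helper below (a hypothetical name) only does the arithmetic; the exact numbers depend on what `total_memory` reports for your card.

```python
def memory_budget(total_gb: float, fraction: float = 0.95):
    """Return (cap, buffer) in GB for a per-process memory fraction."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    cap_gb = total_gb * fraction      # upper bound for this process
    buffer_gb = total_gb - cap_gb     # headroom left untouched
    return cap_gb, buffer_gb


cap, buf = memory_budget(24.0)
print(f"cap={cap:.1f} GB, buffer={buf:.1f} GB")  # cap=22.8 GB, buffer=1.2 GB
```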

In [None]:
# =============================================================================
# LOCAL GPU ENVIRONMENT CONFIGURATION (24 GB + 5.5 GB Dataset)
# =============================================================================
# Optimized for local high-end GPUs with large datasets

import torch
import os

# =============================================================================
# STEP A: GPU Detection and Validation
# =============================================================================

print("=" * 70)
print("LOCAL GPU CONFIGURATION (24 GB + 5.5 GB Dataset)")
print("=" * 70)

# Check CUDA availability
cuda_available = torch.cuda.is_available()
print(f"CUDA Available: {cuda_available}")

if cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_props = torch.cuda.get_device_properties(0)
    gpu_memory = gpu_props.total_memory / 1e9
    print(f"GPU Name: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Multi-Processor Count: {gpu_props.multi_processor_count}")
    print(f"Compute Capability: {gpu_props.major}.{gpu_props.minor}")
else:
    print("⚠️  WARNING: No GPU detected!")
    print("   This notebook requires a GPU for efficient training.")

# Assert CUDA is available (will stop execution if not)
assert cuda_available, "❌ CRITICAL: CUDA is required but not available."
print("✓ GPU assertion passed")

# =============================================================================
# STEP C: CUDA Performance Optimizations (Local GPU)
# =============================================================================

# Allow cuDNN to auto-tune for best algorithm (CRITICAL for large batches)
torch.backends.cudnn.benchmark = True

# Non-deterministic for performance (set to True only for debugging)
torch.backends.cudnn.deterministic = False

# Cap this process at 95% of GPU memory: allocations beyond the cap raise an
# out-of-memory error instead of exhausting the device (~1.2 GB headroom on 24 GB)
torch.cuda.set_per_process_memory_fraction(0.95, device=0)

print("\nCUDA Optimizations:")
print(f"  cudnn.deterministic: {torch.backends.cudnn.deterministic}")
print(f"  cudnn.benchmark: {torch.backends.cudnn.benchmark}")
print(f"  Memory fraction: 95% (leaves ~{gpu_memory * 0.05:.1f} GB buffer)")

# =============================================================================
# GPU Info Cell (STEP C)
# =============================================================================

def print_gpu_memory():
    """Print current GPU memory usage."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(0) / 1e9
        reserved = torch.cuda.memory_reserved(0) / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        free = total - reserved  # not yet reserved by this process (other processes may hold some)
        print(f"GPU Memory | Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB | Free: {free:.2f} GB")
        return allocated, reserved, free
    return 0, 0, 0

print("\nInitial GPU Memory Status:")
print_gpu_memory()

# =============================================================================
# Global Debug Flag (STEP H: Large Dataset Visualization Policy)
# =============================================================================
# DEBUG=True enables visualizations BUT limited to:
# - Only ONE batch (not full dataset)
# - Only FIRST and LAST epochs
# This prevents memory spikes from large visualizations

DEBUG = True  # Enabled for local development, but with safeguards

print(f"\nDebug Mode: {DEBUG}")
print("  → Visualizations enabled but LIMITED:")
print("  → Only 1 batch visualized per stage")
print("  → Only first & last epoch plots")

# =============================================================================
# Local Output Directory (STEP I)
# =============================================================================
# All outputs saved locally with checkpoint support

SAVE_DIR = "./outputs_local_5gb"

# Create output subdirectories
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, "models"), exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, "plots"), exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, "logs"), exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, "latent_cache"), exist_ok=True)  # For disk-backed features

print(f"\nOutput Directory: {SAVE_DIR}")
print("  → models/       - Model checkpoints")
print("  → plots/        - Training plots")
print("  → logs/         - Training logs")
print("  → latent_cache/ - Disk-backed latent features (5.5 GB dataset safety)")

print("\n" + "=" * 70)
print("✓ LOCAL GPU CONFIGURATION COMPLETE")
print("=" * 70)
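The DEBUG policy described in this cell (one batch, first and last epoch only) reduces to a single predicate. `should_visualize` is a hypothetical helper, not something defined elsewhere in this notebook; a training loop would call it before doing any plotting.

```python
def should_visualize(debug: bool, epoch: int, num_epochs: int, batch_idx: int) -> bool:
    """Hypothetical gate: plot only batch 0 of the first and last epochs."""
    if not debug:
        return False
    is_edge_epoch = epoch in (0, num_epochs - 1)
    return is_edge_epoch and batch_idx == 0


plotted = [e for e in range(5) if should_visualize(True, e, 5, batch_idx=0)]
print(plotted)  # [0, 4]
```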

In [None]:
# =============================================================================
# PRE-FLIGHT VALIDATION: Estimate Memory Requirements
# =============================================================================
# This cell estimates GPU memory requirements BEFORE training starts
# and warns if memory is likely to exceed available capacity.

def estimate_memory_requirements():
    """
    Estimate GPU memory requirements for CDBN training.
    
    Returns estimated memory in GB and prints warnings if too high.
    """
    print("=" * 70)
    print("PRE-FLIGHT VALIDATION: Memory Estimation")
    print("=" * 70)
    
    # Constants for estimation
    IMAGE_SIZE = (128, 128)
    BYTES_PER_FLOAT = 4  # float32
    
    # Batch sizes
    rbm_batch = 32      # mirrors Config.RBM_BATCH_SIZE (defined in STEP 2; keep in sync)
    fc_batch = 64       # mirrors Config.FC_RBM_BATCH_SIZE
    clf_batch = 128     # mirrors Config.CLASSIFIER_BATCH_SIZE
    
    # Architecture sizes
    conv1_filters = 32
    conv1_kernel = 7
    conv2_filters = 64
    conv2_kernel = 5
    fc_hidden = 256
    
    # Calculate output sizes
    h1_size = (IMAGE_SIZE[0] - conv1_kernel + 1, IMAGE_SIZE[1] - conv1_kernel + 1)  # 122x122
    pool_size = (h1_size[0] // 2, h1_size[1] // 2)  # 61x61
    h2_size = (pool_size[0] - conv2_kernel + 1, pool_size[1] - conv2_kernel + 1)  # 57x57
    flat_size = conv2_filters * h2_size[0] * h2_size[1]  # 64 * 57 * 57
    
    # Memory per batch (in bytes)
    # Conv-RBM-1 training batch
    input_mem = rbm_batch * 1 * IMAGE_SIZE[0] * IMAGE_SIZE[1] * BYTES_PER_FLOAT
    h1_mem = rbm_batch * conv1_filters * h1_size[0] * h1_size[1] * BYTES_PER_FLOAT
    conv1_weights = conv1_filters * 1 * conv1_kernel * conv1_kernel * BYTES_PER_FLOAT
    conv1_velocity = conv1_weights  # Momentum velocity tensor
    
    # Conv-RBM-2 training batch
    h1_pooled_mem = rbm_batch * conv1_filters * pool_size[0] * pool_size[1] * BYTES_PER_FLOAT
    h2_mem = rbm_batch * conv2_filters * h2_size[0] * h2_size[1] * BYTES_PER_FLOAT
    conv2_weights = conv2_filters * conv1_filters * conv2_kernel * conv2_kernel * BYTES_PER_FLOAT
    conv2_velocity = conv2_weights
    
    # FC-RBM training
    flat_mem = fc_batch * flat_size * BYTES_PER_FLOAT
    fc_weights = flat_size * fc_hidden * BYTES_PER_FLOAT  # largest parameter tensor (~53M params, ~0.2 GiB)
    fc_velocity = fc_weights
    
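    # Peak memory estimates (in GB). Only activations, weights, and momentum are
    # counted; Gibbs-chain states and unfold() intermediates are not, so treat
    # these figures as lower bounds.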
    conv1_peak = (input_mem + h1_mem + conv1_weights + conv1_velocity) / (1024**3)
    conv2_peak = (h1_pooled_mem + h2_mem + conv2_weights + conv2_velocity) / (1024**3)
    fc_peak = (flat_mem + fc_weights + fc_velocity) / (1024**3)
    
    # PyTorch overhead (fragmentation, cudnn workspace, etc.)
    overhead_factor = 1.5
    total_estimated = max(conv1_peak, conv2_peak, fc_peak) * overhead_factor
    
    print(f"\nEstimated Peak GPU Memory per Stage:")
    print(f"  Conv-RBM-1: ~{conv1_peak:.2f} GB (batch={rbm_batch})")
    print(f"  Conv-RBM-2: ~{conv2_peak:.2f} GB (batch={rbm_batch})")
    print(f"  FC-RBM:     ~{fc_peak:.2f} GB (batch={fc_batch})")
    print(f"\nFC-RBM Weights: {flat_size:,} × {fc_hidden} = {flat_size * fc_hidden:,} parameters")
    print(f"FC-RBM Weight Size: {fc_weights / (1024**3):.2f} GB")
    
    # Check against available GPU memory
    if torch.cuda.is_available():
        gpu_total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        gpu_available = (torch.cuda.get_device_properties(0).total_memory - 
                        torch.cuda.memory_allocated()) / (1024**3)
        
        print(f"\nGPU Memory Status:")
        print(f"  Total GPU Memory:     {gpu_total:.2f} GB")
        print(f"  Available (current):  {gpu_available:.2f} GB")
        print(f"  Estimated Peak (×{overhead_factor}): {total_estimated:.2f} GB")
        
        # Safety check
        if total_estimated > gpu_total * 0.95:
            print(f"\n⚠️ WARNING: Estimated memory ({total_estimated:.1f} GB) may exceed")
            print(f"   the 95% cap of GPU memory ({gpu_total:.1f} GB total)!")
            print("   Consider reducing batch sizes or using gradient checkpointing.")
        elif total_estimated > gpu_total * 0.75:
            print(f"\n⚡ CAUTION: Memory usage will be high ({total_estimated:.1f} GB / {gpu_total:.1f} GB)")
            print("   Monitor with print_gpu_memory() during training.")
        else:
            print(f"\n✓ Memory estimates look safe for this GPU ({gpu_total:.1f} GB)")
    else:
        print("\n⚠️ No CUDA GPU detected - running on CPU")
    
    print("=" * 70)
    return total_estimated

# Run pre-flight validation
estimated_peak_memory = estimate_memory_requirements()
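The layer sizes hard-coded in `estimate_memory_requirements()` follow from "valid" convolutions (no padding, stride 1, so `out = in - kernel + 1`) and the `// 2` pooling step the estimate assumes. The sketch below recomputes them; `cdbn_shapes` is a hypothetical helper that mirrors the estimator's arithmetic and is not part of the pipeline.

```python
def cdbn_shapes(image_hw=(128, 128), k1=7, f2=64, k2=5, pool=2, fc_hidden=256):
    """Recompute the sizes used by estimate_memory_requirements()."""
    h, w = image_hw
    # 'valid' convolution, stride 1: out = in - kernel + 1
    h1 = (h - k1 + 1, w - k1 + 1)
    # non-overlapping pooling with window `pool` (floors odd sizes)
    p1 = (h1[0] // pool, h1[1] // pool)
    h2 = (p1[0] - k2 + 1, p1[1] - k2 + 1)
    flat = f2 * h2[0] * h2[1]          # FC-RBM visible units
    fc_params = flat * fc_hidden       # FC-RBM weight count
    return h1, p1, h2, flat, fc_params


h1, p1, h2, flat, fc_params = cdbn_shapes()
print(h1, p1, h2)                     # (122, 122) (61, 61) (57, 57)
print(f"{flat:,} inputs to FC-RBM")   # 207,936 inputs to FC-RBM
print(f"{fc_params:,} FC weights = {fc_params * 4 / 1024**3:.2f} GiB in float32")
# 53,231,616 FC weights = 0.20 GiB in float32
```

Because `flat` scales with the square of the input side, doubling `IMAGE_SIZE` roughly quadruples the FC-RBM weight matrix.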

## STEP 1: Environment Setup & Imports

Import only the necessary libraries and configure the computing environment.

In [None]:
# =============================================================================
# STEP 1: Environment Setup & Imports
# =============================================================================

import os
import random

# Core deep learning libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Vision utilities
import torchvision
from torchvision import datasets, transforms

# Numerical and visualization
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# Set Global Random Seeds for Reproducibility
# -----------------------------------------------------------------------------
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set CUDA seeds if GPU is available
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)  # For multi-GPU setups
    # NOTE: cudnn settings are configured in the Local GPU Configuration cell above
    # (deterministic=False, benchmark=True). These trade bit-exact reproducibility
    # for speed, so GPU results may still vary slightly between runs despite seeding.

# -----------------------------------------------------------------------------
# Device Configuration
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("=" * 60)
print("ENVIRONMENT CONFIGURATION")
print("=" * 60)
print(f"PyTorch Version    : {torch.__version__}")
print(f"Torchvision Version: {torchvision.__version__}")
print(f"NumPy Version      : {np.__version__}")
print(f"Random Seed        : {SEED}")
print(f"Device             : {device}")

if torch.cuda.is_available():
    print(f"GPU Name           : {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version       : {torch.version.cuda}")
    print(f"GPU Memory         : {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("GPU                : Not available, using CPU")

print("=" * 60)
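The seeding above can be wrapped in one reusable function. This is a sketch with a hypothetical name, `seed_everything`; it treats NumPy and PyTorch as optional so it runs anywhere. `torch.manual_seed` seeds the RNGs on all devices, including CUDA. Even so, with `cudnn.benchmark=True` and `cudnn.deterministic=False`, identical seeds do not guarantee bit-identical GPU results.

```python
import random


def seed_everything(seed: int = 42) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds CPU and all CUDA devices
    except ImportError:
        pass


seed_everything(42)
a = [random.random() for _ in range(3)]
seed_everything(42)
b = [random.random() for _ in range(3)]
print(a == b)  # True
```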

## STEP 2: Central Configuration

All hyperparameters and paths are centralized here for easy modification.
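The stage-specific batch sizes can also be grouped into a small immutable object, so a stage looks up its batch size by name instead of reading a shared default. A sketch, assuming the hypothetical names `StageBatchSizes` and `for_stage`; the values mirror `Config`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageBatchSizes:
    """Hypothetical frozen mirror of the batch sizes in Config."""
    rbm: int = 32          # Conv-RBM-1 and Conv-RBM-2
    fc_rbm: int = 64       # FC-RBM
    classifier: int = 128  # supervised classifier

    def for_stage(self, stage: str) -> int:
        sizes = {"rbm": self.rbm, "fc_rbm": self.fc_rbm, "classifier": self.classifier}
        if stage not in sizes:
            raise KeyError(f"unknown stage {stage!r}; expected one of {sorted(sizes)}")
        return sizes[stage]


batches = StageBatchSizes()
print(batches.for_stage("fc_rbm"))  # 64
```

`frozen=True` blocks accidental reassignment mid-run. The notebook's `Config` class, by contrast, stays mutable on purpose, because `NUM_CLASSES` and `CLASS_NAMES` are filled in after the dataset loads.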

In [None]:
# =============================================================================
# STEP 2: Central Configuration (Local GPU + 5.5 GB Dataset)
# =============================================================================

class Config:
    """
    Central configuration class for the CDBN OCT Classification pipeline.
    
    LOCAL GPU (24 GB) + 5.5 GB DATASET OPTIMIZATIONS:
    -------------------------------------------------
    - Stage-specific batch sizes (STEP D)
    - Streaming-friendly DataLoader settings (STEP B)
    - No full-dataset caching (STEP F)
    - Disk-backed feature storage
    """
    
    # -------------------------------------------------------------------------
    # Dataset Paths (STEP B: Local Dataset Path)
    # -------------------------------------------------------------------------
    # TODO: Replace with your actual local dataset path
    # Example: "C:/Datasets/OCT2017" or "/home/user/datasets/OCT"
    DATASET_ROOT = "C:/path/to/your/OCT/dataset"  # <-- UPDATE THIS PATH
    
    TRAIN_DIR = os.path.join(DATASET_ROOT, "train")
    VAL_DIR = os.path.join(DATASET_ROOT, "val")
    TEST_DIR = os.path.join(DATASET_ROOT, "test")
    
    # -------------------------------------------------------------------------
    # Output Directory (Local saving with checkpoints)
    # -------------------------------------------------------------------------
    OUTPUT_DIR = SAVE_DIR  # Uses ./outputs_local_5gb from config cell
    LATENT_CACHE_DIR = os.path.join(SAVE_DIR, "latent_cache")  # Disk-backed features
    
    # -------------------------------------------------------------------------
    # Image Parameters
    # -------------------------------------------------------------------------
    IMAGE_SIZE = (128, 128)      # (height, width) - resize all images to this
    NUM_CHANNELS = 1              # Grayscale images
    
    # -------------------------------------------------------------------------
    # STEP D: Stage-Specific Batch Sizes (VERY IMPORTANT)
    # -------------------------------------------------------------------------
    # RBMs need smaller batches due to:
    # 1. unfold() operation creates large intermediate tensors
    # 2. Gibbs sampling stores multiple states in memory
    # 3. Momentum buffers consume GPU memory
    #
    # Classifier can use larger batches (no Gibbs sampling)
    
    RBM_BATCH_SIZE = 32        # Conv-RBM-1 & Conv-RBM-2 (~8-12 GB expected)
    FC_RBM_BATCH_SIZE = 64     # FC-RBM (lighter memory footprint)
    CLASSIFIER_BATCH_SIZE = 128  # Classifier training (most efficient)
    
    # Default batch size (for dataset loading)
    BATCH_SIZE = RBM_BATCH_SIZE
    
    # -------------------------------------------------------------------------
    # DataLoader Parameters (STEP B: Streaming-Friendly Settings)
    # -------------------------------------------------------------------------
    # NUM_WORKERS=4: Local machines typically have more CPU cores than Kaggle;
    # parallel data loading reduces GPU idle time
    NUM_WORKERS = 4
    
    # PIN_MEMORY=True: Speeds up host-to-GPU transfers
    # Pre-pins memory on CPU for faster CUDA copies
    PIN_MEMORY = torch.cuda.is_available()
    
    # PERSISTENT_WORKERS=True: Keeps DataLoader workers alive between epochs
    # Reduces worker spawn overhead for large datasets (~5.5 GB)
    PERSISTENT_WORKERS = True
    
    # PREFETCH_FACTOR=2: Each worker pre-loads 2 batches ahead
    # Helps keep the GPU fed while bounding host RAM use (streaming safety)
    PREFETCH_FACTOR = 2
    
    # -------------------------------------------------------------------------
    # Device Configuration
    # -------------------------------------------------------------------------
    DEVICE = device
    
    # -------------------------------------------------------------------------
    # Number of Classes (will be inferred from dataset)
    # -------------------------------------------------------------------------
    NUM_CLASSES = None  # To be set after loading dataset
    CLASS_NAMES = None  # To be set after loading dataset


# Create global config instance
config = Config()

# Display configuration
print("=" * 60)
print("CENTRAL CONFIGURATION (Local GPU + 5.5 GB Dataset)")
print("=" * 60)
print(f"Dataset Root       : {config.DATASET_ROOT}")
print(f"Output Directory   : {config.OUTPUT_DIR}")
print(f"Latent Cache       : {config.LATENT_CACHE_DIR}")
print(f"Image Size         : {config.IMAGE_SIZE}")
print(f"Number of Channels : {config.NUM_CHANNELS}")
print(f"\nStage-Specific Batch Sizes (STEP D):")
print(f"  RBM Batch Size       : {config.RBM_BATCH_SIZE}")
print(f"  FC-RBM Batch Size    : {config.FC_RBM_BATCH_SIZE}")
print(f"  Classifier Batch Size: {config.CLASSIFIER_BATCH_SIZE}")
print(f"\nDataLoader Settings (Streaming-Safe):")
print(f"  Num Workers        : {config.NUM_WORKERS}")
print(f"  Pin Memory         : {config.PIN_MEMORY}")
print(f"  Persistent Workers : {config.PERSISTENT_WORKERS}")
print(f"  Prefetch Factor    : {config.PREFETCH_FACTOR}")
print(f"\nDevice             : {config.DEVICE}")
print("=" * 60)

# =============================================================================
# STEP B: Dataset Path Sanity Check
# =============================================================================
print("\n" + "=" * 60)
print("LOCAL DATASET SANITY CHECK")
print("=" * 60)

print(f"\nDataset Root: {config.DATASET_ROOT}")
if os.path.exists(config.DATASET_ROOT):
    # Calculate dataset size
    total_size = 0
    file_count = 0
    for dirpath, dirnames, filenames in os.walk(config.DATASET_ROOT):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
            file_count += 1
    
    print("✓ Dataset found!")
    print(f"  Total files: {file_count:,}")
    print(f"  Total size: {total_size / 1e9:.2f} GB")
    
    # List subdirectories
    for item in os.listdir(config.DATASET_ROOT):
        item_path = os.path.join(config.DATASET_ROOT, item)
        if os.path.isdir(item_path):
            subdir_files = sum([len(files) for _, _, files in os.walk(item_path)])
            print(f"  üìÅ {item}/ ({subdir_files:,} files)")
else:
    print("⚠️  WARNING: Dataset path does not exist!")
    print("   Please update Config.DATASET_ROOT with your local path")
    print(f"   Expected: {config.DATASET_ROOT}")
    
print("=" * 60)
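The `os.walk` loop above counts every file, including stray non-image files. A sketch that counts only image files per split and class, closer to what `ImageFolder` actually loads, could look like this. `summarize_image_folder` and the `IMAGE_EXTS` set are assumptions; torchvision accepts a somewhat wider list of extensions.

```python
from pathlib import Path

# Assumed subset of the extensions ImageFolder accepts.
IMAGE_EXTS = {".jpeg", ".jpg", ".png", ".bmp", ".tif", ".tiff"}


def summarize_image_folder(root):
    """Count images per split and class for a root/split/class/img layout."""
    summary = {}
    for split_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        per_class = {}
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            # rglob also finds images in nested subfolders of a class directory
            per_class[class_dir.name] = sum(
                1 for f in class_dir.rglob("*")
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
        summary[split_dir.name] = per_class
    return summary
```

For example, `summarize_image_folder(config.DATASET_ROOT)["train"]` would give the per-class training counts before any `ImageFolder` is built.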

## STEP 3: Dataset Loading

Load the OCT dataset using `torchvision.datasets.ImageFolder`:
- Convert images to grayscale (1 channel)
- Resize to configured dimensions
- Normalize pixel values to [0, 1] range
- Create DataLoaders for all splits
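For PIL inputs, `transforms.Grayscale(num_output_channels=1)` converts with `img.convert("L")`, which PIL documents as the ITU-R 601-2 luma transform. `ToTensor` then divides 8-bit values by 255. The arithmetic is sketched below with hypothetical helper names. PIL's internal integer rounding may differ from this float version by one grey level.

```python
def pil_luma(r: int, g: int, b: int) -> float:
    """ITU-R 601-2 luma weights, as documented for PIL's convert("L")."""
    return r * 299 / 1000 + g * 587 / 1000 + b * 114 / 1000


def to_unit_range(value_uint8: float) -> float:
    """ToTensor's scaling of 8-bit pixel values into [0, 1]."""
    return value_uint8 / 255.0


for rgb in [(0, 0, 0), (255, 255, 255), (255, 0, 0)]:
    y = round(pil_luma(*rgb))
    print(rgb, "->", y, "->", round(to_unit_range(y), 4))
```

Already-grayscale OCT scans pass through the luma step unchanged, so for this dataset the main effect is the division by 255.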

In [None]:
# =============================================================================
# STEP 3: Dataset Loading
# =============================================================================

# -----------------------------------------------------------------------------
# Define Image Transformations
# -----------------------------------------------------------------------------
# Transform pipeline:
#   1. Convert to grayscale (handles both RGB and grayscale inputs)
#   2. Resize to target dimensions
#   3. Convert to tensor (scales to [0, 1] automatically for uint8 images)

data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=config.NUM_CHANNELS),  # Ensure grayscale
    transforms.Resize(config.IMAGE_SIZE),                           # Resize to (128, 128)
    transforms.ToTensor(),                                          # Convert to tensor, scale to [0, 1]
])

print("Data Transforms Pipeline:")
print(data_transforms)
print()

In [None]:
# -----------------------------------------------------------------------------
# Load Datasets using ImageFolder
# -----------------------------------------------------------------------------
# ImageFolder expects: root/class_name/image.ext

print("Loading datasets...")

# Training dataset
train_dataset = datasets.ImageFolder(
    root=config.TRAIN_DIR,
    transform=data_transforms
)

# Validation dataset
val_dataset = datasets.ImageFolder(
    root=config.VAL_DIR,
    transform=data_transforms
)

# Test dataset
test_dataset = datasets.ImageFolder(
    root=config.TEST_DIR,
    transform=data_transforms
)

print("✓ All datasets loaded successfully!")
print()

In [None]:
# -----------------------------------------------------------------------------
# Create DataLoaders (Streaming-Safe for 5.5 GB Dataset)
# -----------------------------------------------------------------------------
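# persistent_workers and prefetch_factor keep workers busy so disk I/O overlaps
# with GPU compute. ImageFolder itself reads images lazily (one file per
# __getitem__ call), so the full 5.5 GB dataset is never held in RAM at once.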

train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,                    # Shuffle training data
    num_workers=config.NUM_WORKERS,
    pin_memory=config.PIN_MEMORY,
    drop_last=False,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,                   # No shuffle for validation
    num_workers=config.NUM_WORKERS,
    pin_memory=config.PIN_MEMORY,
    drop_last=False,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,                   # No shuffle for test
    num_workers=config.NUM_WORKERS,
    pin_memory=config.PIN_MEMORY,
    drop_last=False,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None
)

print("✓ DataLoaders created successfully!")
print(f"  - Persistent workers: {config.PERSISTENT_WORKERS}")
print(f"  - Prefetch factor: {config.PREFETCH_FACTOR}")
print(f"  - Pin memory: {config.PIN_MEMORY}")
print()
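The three `DataLoader` calls above repeat the same arguments and always use `config.BATCH_SIZE`, while later stages would need `FC_RBM_BATCH_SIZE` or `CLASSIFIER_BATCH_SIZE`. One way to avoid drift is a small factory, sketched below; `loader_kwargs` is a hypothetical name. It leaves out `persistent_workers` and `prefetch_factor` entirely when `num_workers == 0`. PyTorch rejects `persistent_workers=True` without workers, and in some older releases `prefetch_factor` defaulted to 2, so even passing `None` could raise.

```python
def loader_kwargs(batch_size, shuffle, num_workers=4, pin_memory=True,
                  persistent_workers=True, prefetch_factor=2):
    """Build DataLoader kwargs; worker-only options appear only if num_workers > 0."""
    kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "drop_last": False,
    }
    if num_workers > 0:
        kwargs["persistent_workers"] = persistent_workers
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


print(loader_kwargs(64, shuffle=False, num_workers=0))
# {'batch_size': 64, 'shuffle': False, 'num_workers': 0, 'pin_memory': True, 'drop_last': False}
```

A loader for a specific stage would then read `DataLoader(train_dataset, **loader_kwargs(config.FC_RBM_BATCH_SIZE, shuffle=True))`.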

In [None]:
# -----------------------------------------------------------------------------
# Update Configuration with Inferred Values
# -----------------------------------------------------------------------------

# Infer number of classes from the training dataset
config.NUM_CLASSES = len(train_dataset.classes)
config.CLASS_NAMES = train_dataset.classes

print("=" * 60)
print("DATASET CONFIGURATION UPDATED")
print("=" * 60)
print(f"Number of Classes  : {config.NUM_CLASSES}")
print(f"Class Names        : {config.CLASS_NAMES}")
print("=" * 60)

## STEP 4: Dataset Sanity Checks

Verify the dataset loading was successful:
- Print sample counts per split
- Display class mappings
- Inspect batch tensor shapes
- Visualize sample images
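The batch counts in the statistics table follow directly from the loader settings. With `drop_last=False` (used by all three loaders here), a loader yields `ceil(N / B)` batches, and the last one may be smaller, which is why the shape check compares only the non-batch dimensions. A sketch with a hypothetical helper name:

```python
def loader_batches(num_samples: int, batch_size: int, drop_last: bool = False):
    """Return (number of batches, size of the last batch)."""
    if drop_last:
        n = num_samples // batch_size
        return n, (batch_size if n else 0)
    n = (num_samples + batch_size - 1) // batch_size  # integer ceil division
    last = num_samples - (n - 1) * batch_size if n else 0
    return n, last


print(loader_batches(1000, 32))        # (32, 8)
print(loader_batches(1000, 32, True))  # (31, 32)
```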

In [None]:
# =============================================================================
# STEP 4: Dataset Sanity Checks
# =============================================================================

# -----------------------------------------------------------------------------
# 4.1 Print Number of Samples per Split
# -----------------------------------------------------------------------------

print("=" * 60)
print("DATASET STATISTICS")
print("=" * 60)
print(f"{'Split':<12} {'Samples':>10} {'Batches':>10}")
print("-" * 35)
print(f"{'Train':<12} {len(train_dataset):>10} {len(train_loader):>10}")
print(f"{'Validation':<12} {len(val_dataset):>10} {len(val_loader):>10}")
print(f"{'Test':<12} {len(test_dataset):>10} {len(test_loader):>10}")
print("-" * 35)
print(f"{'Total':<12} {len(train_dataset) + len(val_dataset) + len(test_dataset):>10}")
print("=" * 60)

In [None]:
# -----------------------------------------------------------------------------
# 4.2 Print Class Names and Class-to-Index Mapping
# -----------------------------------------------------------------------------

print("=" * 60)
print("CLASS INFORMATION")
print("=" * 60)
print(f"\nClass Names: {train_dataset.classes}")
print(f"\nClass-to-Index Mapping:")
for class_name, class_idx in train_dataset.class_to_idx.items():
    print(f"  {class_name:<20} -> {class_idx}")
print("=" * 60)

# Count samples per class in training set
print("\nSamples per Class (Training Set):")
print("-" * 35)
class_counts = {}
for _, label in train_dataset.samples:
    class_name = train_dataset.classes[label]
    class_counts[class_name] = class_counts.get(class_name, 0) + 1

for class_name, count in sorted(class_counts.items()):
    print(f"  {class_name:<20} : {count:>6} samples")
print("-" * 35)
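The per-class loop above can be written more compactly with `collections.Counter`. For `ImageFolder`, `samples` is a list of `(path, class_index)` pairs, and `targets` holds the indices alone. The sketch below (hypothetical name `count_per_class`) also reports classes with zero samples, so a missing class folder shows up as 0 rather than disappearing from the output.

```python
from collections import Counter


def count_per_class(samples, classes):
    """Count ImageFolder-style (path, label) pairs per class name."""
    counts = Counter(label for _, label in samples)
    # Include classes with zero samples so gaps stay visible
    return {name: counts.get(idx, 0) for idx, name in enumerate(classes)}


fake_samples = [("a.jpeg", 0), ("b.jpeg", 0), ("c.jpeg", 2)]
print(count_per_class(fake_samples, ["CNV", "DME", "NORMAL"]))
# {'CNV': 2, 'DME': 0, 'NORMAL': 1}
```

In this notebook the call would be `count_per_class(train_dataset.samples, train_dataset.classes)`.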

In [None]:
# -----------------------------------------------------------------------------
# 4.3 Fetch One Batch and Print Tensor Shape
# -----------------------------------------------------------------------------

print("=" * 60)
print("BATCH INSPECTION")
print("=" * 60)

# Get one batch from training loader
sample_batch_images, sample_batch_labels = next(iter(train_loader))

print(f"\nBatch Images Tensor:")
print(f"  Shape      : {sample_batch_images.shape}")
print(f"  Data Type  : {sample_batch_images.dtype}")
print(f"  Min Value  : {sample_batch_images.min().item():.4f}")
print(f"  Max Value  : {sample_batch_images.max().item():.4f}")
print(f"  Mean Value : {sample_batch_images.mean().item():.4f}")

print(f"\nBatch Labels Tensor:")
print(f"  Shape      : {sample_batch_labels.shape}")
print(f"  Data Type  : {sample_batch_labels.dtype}")
print(f"  Labels     : {sample_batch_labels.tolist()}")

# Verify expected shape
expected_shape = (config.BATCH_SIZE, config.NUM_CHANNELS, config.IMAGE_SIZE[0], config.IMAGE_SIZE[1])
actual_shape = tuple(sample_batch_images.shape)

print(f"\nShape Verification:")
print(f"  Expected   : {expected_shape}")
print(f"  Actual     : {actual_shape}")

# Note: Last batch might be smaller if dataset size is not divisible by batch size
if actual_shape[1:] == expected_shape[1:]:
    print("  Status     : ✓ Shape matches expected dimensions!")
else:
    print("  Status     : ✗ Shape mismatch detected!")

print("=" * 60)
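The shape check above compares only `shape[1:]`, because the last batch may be smaller. A slightly stricter variant also bounds the batch dimension and verifies the rank. This is a sketch; `batch_shape_ok` is a hypothetical helper that accepts any shape-like sequence, including `torch.Size`.

```python
def batch_shape_ok(actual, expected_chw, max_batch):
    """True if `actual` is (B, C, H, W) with 1 <= B <= max_batch and matching C, H, W."""
    if len(actual) != 4:
        return False
    b, *chw = actual
    return 1 <= b <= max_batch and tuple(chw) == tuple(expected_chw)


print(batch_shape_ok((32, 1, 128, 128), (1, 128, 128), max_batch=32))  # True
print(batch_shape_ok((8, 1, 128, 128), (1, 128, 128), max_batch=32))   # True
print(batch_shape_ok((32, 3, 128, 128), (1, 128, 128), max_batch=32))  # False
```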

In [None]:
# -----------------------------------------------------------------------------
# 4.4 Visualize One Batch of Images with Labels
# -----------------------------------------------------------------------------

def visualize_batch(images, labels, class_names, num_images=16, figsize=(12, 12)):
    """
    Visualize a batch of images in a grid layout.
    
    Args:
        images: Tensor of shape (batch_size, channels, height, width)
        labels: Tensor of class indices
        class_names: List of class name strings
        num_images: Number of images to display (default: 16)
        figsize: Figure size tuple
    """
    # Limit to available images and maximum display count
    num_images = min(num_images, len(images))
    
    # Calculate grid dimensions
    grid_size = int(np.ceil(np.sqrt(num_images)))
    
    # Create figure
    fig, axes = plt.subplots(grid_size, grid_size, figsize=figsize)
    axes = axes.flatten() if grid_size > 1 else [axes]
    
    for idx in range(len(axes)):
        ax = axes[idx]
        
        if idx < num_images:
            # Get image and label
            img = images[idx].squeeze().numpy()  # Remove channel dim for grayscale
            label = labels[idx].item()
            class_name = class_names[label]
            
            # Display image
            ax.imshow(img, cmap='gray')
            ax.set_title(f'{class_name}\n(class {label})', fontsize=10)
        
        ax.axis('off')
    
    plt.suptitle('Sample Batch from Training Set\n(OCT Grayscale Images)', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Visualize the sample batch (only if DEBUG mode is enabled)
if DEBUG:
    print("Visualizing sample batch from training set...")
    visualize_batch(
        images=sample_batch_images,
        labels=sample_batch_labels,
        class_names=config.CLASS_NAMES,
        num_images=16,
        figsize=(12, 12)
    )
else:
    print("⏭️  Skipping batch visualization (DEBUG=False)")
    print("   Set DEBUG=True in the Local GPU Configuration cell to enable visualizations")
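`visualize_batch` sizes its grid as `ceil(sqrt(num_images))`. An integer-only equivalent avoids float rounding and makes the edge cases easy to check. It is sketched below as a hypothetical helper, `grid_side`. When the side is 1, `plt.subplots(1, 1)` returns a single `Axes` rather than an array, which is why `visualize_batch` wraps it in a list.

```python
import math


def grid_side(num_images: int) -> int:
    """Smallest square grid side that fits num_images (integer ceil(sqrt(n)))."""
    return math.isqrt(num_images - 1) + 1 if num_images > 0 else 0


print([grid_side(n) for n in (1, 9, 10, 16, 17)])  # [1, 3, 4, 4, 5]
```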

In [None]:
# -----------------------------------------------------------------------------
# Summary: Data Loaders and Configuration
# -----------------------------------------------------------------------------

print("\n" + "=" * 60)
print("SUMMARY: READY FOR NEXT STEPS")
print("=" * 60)
print("\nData Loaders Available:")
print(f"  • train_loader : {len(train_loader)} batches")
print(f"  • val_loader   : {len(val_loader)} batches")
print(f"  • test_loader  : {len(test_loader)} batches")

print("\nConfiguration Object (config):")
print(f"  • NUM_CLASSES  : {config.NUM_CLASSES}")
print(f"  • CLASS_NAMES  : {config.CLASS_NAMES}")
print(f"  • IMAGE_SIZE   : {config.IMAGE_SIZE}")
print(f"  • BATCH_SIZE   : {config.BATCH_SIZE}")
print(f"  • DEVICE       : {config.DEVICE}")

print("\nData Format:")
print(f"  • Input Shape  : ({config.BATCH_SIZE}, {config.NUM_CHANNELS}, {config.IMAGE_SIZE[0]}, {config.IMAGE_SIZE[1]})")
print("  • Pixel Range  : [0.0, 1.0]")
print("  • Data Type    : torch.float32")
print("=" * 60)

In [None]:
# =============================================================================
# STEP G: Final Sanity Summary Cell (Pre-Training Validation)
# =============================================================================
# This cell validates all critical conditions before training begins.
# Execution will STOP if any condition fails.

print("\n" + "=" * 70)
print("üîç KAGGLE PRE-TRAINING SANITY CHECK")
print("=" * 70)

# Collect all critical info
sanity_checks = []

# 1. Device check
print(f"\n1. Device: {config.DEVICE}")
if config.DEVICE.type == 'cuda':
    print(f"   GPU Name: {torch.cuda.get_device_name(0)}")
    sanity_checks.append(("GPU Available", True))
else:
    print("   ⚠️  WARNING: Running on CPU - training will be very slow!")
    sanity_checks.append(("GPU Available", False))

# 2. Batch size check (24 GB local GPU setting)
print(f"\n2. RBM Batch Size: {config.RBM_BATCH_SIZE}")
if config.RBM_BATCH_SIZE <= 32:
    print("   ✓ Within the batch size configured for a 24 GB GPU")
    sanity_checks.append(("Batch Size Safe", True))
else:
    print("   ⚠️  WARNING: RBM batch size > 32 may cause OOM on a 24 GB GPU")
    sanity_checks.append(("Batch Size Safe", False))

# 3. Image size check
print(f"\n3. Image Size: {config.IMAGE_SIZE}")
sanity_checks.append(("Image Size Valid", config.IMAGE_SIZE[0] > 0 and config.IMAGE_SIZE[1] > 0))

# 4. Training samples check
print(f"\n4. Training Samples: {len(train_dataset)}")
if len(train_dataset) > 0:
    sanity_checks.append(("Training Data", True))
else:
    sanity_checks.append(("Training Data", False))

# 5. Class names check
print(f"\n5. Classes ({config.NUM_CLASSES}):")
for i, name in enumerate(config.CLASS_NAMES):
    print(f"   [{i}] {name}")
sanity_checks.append(("Classes Detected", config.NUM_CLASSES > 0))

# 6. Output directory check
print(f"\n6. Output Directory: {config.OUTPUT_DIR}")
output_writable = os.access(config.OUTPUT_DIR, os.W_OK)
if output_writable:
    print("   ✓ Directory is writable")
    sanity_checks.append(("Output Writable", True))
else:
    print("   ✗ Directory is NOT writable!")
    sanity_checks.append(("Output Writable", False))

# 7. Debug mode status
print(f"\n7. Debug Mode: {DEBUG}")
if not DEBUG:
    print("   ‚Üí Visualizations minimized for production run")

# Summary
print("\n" + "=" * 70)
print("SANITY CHECK SUMMARY")
print("=" * 70)

all_passed = True
for check_name, passed in sanity_checks:
    status = "✓ PASS" if passed else "✗ FAIL"
    print(f"  {check_name:<20} : {status}")
    if not passed and check_name in ["Training Data", "Classes Detected", "Output Writable"]:
        all_passed = False

print("=" * 70)

# Critical assertion
if not all_passed:
    raise RuntimeError("❌ CRITICAL: Sanity check failed! Please fix the issues above before training.")

print("\n✅ ALL CRITICAL CHECKS PASSED - Ready for training!")
print("=" * 70)
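The cell above decides which failures are fatal through a hard-coded list of check names. A minimal sketch of an alternative, assuming each check record carries its own `critical` flag; `run_sanity_checks` and the `(name, passed, critical)` tuple layout are hypothetical, not part of the notebook:

```python
from typing import List, Tuple

# Hypothetical check record: (name, passed, critical)
Check = Tuple[str, bool, bool]

def run_sanity_checks(checks: List[Check]) -> List[str]:
    """Print a PASS/WARN/FAIL table and stop on critical failures.

    Returns the names of non-critical checks that failed, so the caller
    can decide whether to log them.
    """
    warned, critical_failures = [], []
    for name, passed, critical in checks:
        status = "PASS" if passed else ("FAIL" if critical else "WARN")
        print(f"  {name:<20} : {status}")
        if not passed:
            (critical_failures if critical else warned).append(name)
    if critical_failures:
        raise RuntimeError(f"Critical sanity checks failed: {critical_failures}")
    return warned

# Example: a CPU-only run warns but continues; missing training data would raise.
print(run_sanity_checks([
    ("GPU Available", False, False),
    ("Training Data", True, True),
]))
```

Keeping criticality next to each check means adding a new check cannot silently bypass the stop condition.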

---

# Dataset Loaded Successfully: Ready for CDBN Pretraining

✅ **Environment configured** with PyTorch and GPU support (if available)

✅ **Central configuration** established with all hyperparameters

✅ **Datasets loaded** for train, validation, and test splits

✅ **DataLoaders created** with proper batching and shuffling

✅ **Sanity checks passed**: tensor shapes and pixel ranges verified

---

**Next Steps:**
- STEP 5: Implement the Convolutional RBM (Conv-RBM) module
- STEP 6: Implement the Contrastive Divergence (CD-k) trainer
- STEP 7: Unsupervised pretraining of the first Conv-RBM layer
- STEP 8: Stack layers to form the Convolutional Deep Belief Network (CDBN)
- Fine-tuning with a supervised classifier

## STEP 5: Convolutional Restricted Boltzmann Machine (Conv-RBM) Module

Implement a Conv-RBM as a building block for the CDBN:

**Architecture:**
- **Visible units:** Gaussian (real-valued OCT pixels in [0, 1])
- **Hidden units:** Bernoulli (binary activations)
- **Weight sharing:** Via convolutional kernels
- **No pooling** in this layer

**Energy Function (Gaussian-Bernoulli RBM):**

$$E(v, h) = \sum_i \frac{(v_i - a_i)^2}{2\sigma^2} - \sum_j b_j h_j - \sum_{i,j} \frac{v_i}{\sigma^2} W_{ij} h_j$$

where $\sigma = 1$ (unit variance) for normalized inputs.
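As a quick numerical check of the energy function above (with σ = 1), the sketch below enumerates every binary hidden configuration of a tiny dense RBM, forms P(h | v) ∝ exp(-E(v, h)), and confirms that the marginals equal sigmoid(b_j + Σ_i v_i W_ij), the conditional that `hidden_probabilities` computes with a convolution. The sizes and random parameters are arbitrary illustrative choices.

```python
import itertools
import numpy as np

def energy(v, h, W, a, b, sigma=1.0):
    """Gaussian-Bernoulli RBM energy E(v, h) for one (v, h) pair."""
    return (np.sum((v - a) ** 2) / (2 * sigma**2)
            - b @ h
            - (v / sigma**2) @ W @ h)

rng = np.random.default_rng(0)
n_v, n_h = 4, 3
W = rng.normal(scale=0.5, size=(n_v, n_h))
a = rng.normal(size=n_v)   # visible bias
b = rng.normal(size=n_h)   # hidden bias
v = rng.random(n_v)        # real-valued "pixels" in [0, 1)

# Brute force: P(h | v) ∝ exp(-E(v, h)) over all 2**n_h binary configurations
configs = np.array(list(itertools.product([0.0, 1.0], repeat=n_h)))
weights = np.exp([-energy(v, h, W, a, b) for h in configs])
p_h_given_v = weights / weights.sum()
marginal = p_h_given_v @ configs  # P(h_j = 1 | v) for each hidden unit j

# Closed form: the (v - a)^2 term does not depend on h, so P(h | v) factorizes
closed_form = 1.0 / (1.0 + np.exp(-(b + v @ W)))
print(np.allclose(marginal, closed_form))  # True
```

The quadratic visible term cancels in the normalization, which is why the hidden units stay conditionally independent Bernoulli units even with real-valued inputs.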

In [None]:
# =============================================================================
# STEP 5: Convolutional Restricted Boltzmann Machine (Conv-RBM) Module
# =============================================================================

class ConvRBM(nn.Module):
    """
    Convolutional Restricted Boltzmann Machine (Conv-RBM) with:
        - Gaussian visible units (real-valued, for normalized image pixels)
        - Bernoulli hidden units (binary activations)
        - Weight sharing via 2D convolution
    
    This is a generative model that learns to reconstruct input images
    while capturing local features in its hidden feature maps (hierarchical
    features emerge once several layers are stacked into a CDBN).
    
    Architecture:
        visible (v) <---> hidden (h)
        
        v: [batch, visible_channels, H, W]  - Input image
        h: [batch, hidden_channels, H', W'] - Hidden feature maps
        
        where H' = H - kernel_size + 1, W' = W - kernel_size + 1 (valid convolution)
    
    Energy Function (Gaussian-Bernoulli):
        E(v, h) = sum_i (v_i - a_i)^2 / (2σ²) - sum_j b_j * h_j - sum_ij (v_i / σ²) * W_ij * h_j
        
        With σ = 1 (unit variance) for [0, 1] normalized inputs.
    
    Conditional Distributions:
        P(h_j = 1 | v) = sigmoid(b_j + sum_i W_ij * v_i)  [Bernoulli]
        P(v_i | h)     = N(a_i + sum_j W_ij * h_j, σ²)    [Gaussian with σ=1]
    
    Args:
        visible_channels (int): Number of input channels (1 for grayscale)
        hidden_channels (int): Number of hidden feature maps to learn
        kernel_size (int): Size of the convolutional kernel (square)
        
    Attributes:
        W: Convolutional weights [hidden_channels, visible_channels, kernel_size, kernel_size]
        h_bias: Hidden bias [hidden_channels] - one bias per feature map
        v_bias: Visible bias [visible_channels] - one bias per input channel
    """
    
    def __init__(
        self,
        visible_channels: int = 1,
        hidden_channels: int = 32,
        kernel_size: int = 7
    ):
        super(ConvRBM, self).__init__()
        
        # Store configuration
        self.visible_channels = visible_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        
        # ---------------------------------------------------------------------
        # Learnable Parameters
        # ---------------------------------------------------------------------
        
        # Convolutional weights: W
        # Shape: [hidden_channels, visible_channels, kernel_size, kernel_size]
        # Initialized with small Gaussian values (std 0.01), a common RBM initialization
        self.W = nn.Parameter(
            torch.randn(hidden_channels, visible_channels, kernel_size, kernel_size) * 0.01
        )
        
        # Hidden bias: b_j (one per feature map)
        # Shape: [hidden_channels]
        # Initialized to zero
        self.h_bias = nn.Parameter(
            torch.zeros(hidden_channels)
        )
        
        # Visible bias: a_i (one per input channel)
        # Shape: [visible_channels]
        # Initialized to zero
        self.v_bias = nn.Parameter(
            torch.zeros(visible_channels)
        )
        
    def hidden_probabilities(self, v: torch.Tensor) -> torch.Tensor:
        """
        Compute hidden unit activation probabilities given visible units.
        
        P(h_j = 1 | v) = sigmoid(b_j + conv(v, W))
        
        The convolution implements the sum: sum_i W_ij * v_i
        with weight sharing across spatial locations.
        
        Args:
            v: Visible units tensor
               Shape: [batch_size, visible_channels, H, W]
               
        Returns:
            h_prob: Hidden activation probabilities
                    Shape: [batch_size, hidden_channels, H', W']
                    where H' = H - kernel_size + 1 (valid convolution)
        """
        # Convolve visible with weights (valid convolution, no padding)
        # Input:  [batch, visible_channels, H, W]
        # Weight: [hidden_channels, visible_channels, kernel_size, kernel_size]
        # Output: [batch, hidden_channels, H - kernel_size + 1, W - kernel_size + 1]
        conv_out = nn.functional.conv2d(v, self.W, bias=None, padding=0)
        
        # Add hidden bias (broadcast across spatial dimensions)
        # h_bias shape: [hidden_channels] -> [1, hidden_channels, 1, 1]
        h_bias_expanded = self.h_bias.view(1, -1, 1, 1)
        
        # Pre-activation: b_j + conv(v, W)
        pre_activation = conv_out + h_bias_expanded
        
        # Apply sigmoid to get probabilities (Bernoulli hidden units)
        h_prob = torch.sigmoid(pre_activation)
        
        return h_prob
    
    def sample_hidden(self, v: torch.Tensor) -> tuple:
        """
        Sample binary hidden states from visible units.
        
        h ~ Bernoulli(P(h=1|v))
        
        Args:
            v: Visible units tensor
               Shape: [batch_size, visible_channels, H, W]
               
        Returns:
            h_prob: Hidden activation probabilities
                    Shape: [batch_size, hidden_channels, H', W']
            h_sample: Binary hidden samples (0 or 1)
                      Shape: [batch_size, hidden_channels, H', W']
        """
        # Get hidden probabilities
        h_prob = self.hidden_probabilities(v)
        
        # Sample from Bernoulli distribution
        # Each unit is independently set to 1 with probability h_prob
        h_sample = torch.bernoulli(h_prob)
        
        return h_prob, h_sample
    
    def visible_probabilities(self, h: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct visible units (mean) given hidden units.
        
        For Gaussian visible units with unit variance:
        E[v_i | h] = a_i + conv_transpose(h, W)
        
        The transposed convolution implements: sum_j W_ij * h_j
        
        Args:
            h: Hidden units tensor (probabilities or samples)
               Shape: [batch_size, hidden_channels, H', W']
               
        Returns:
            v_mean: Reconstructed visible mean
                    Shape: [batch_size, visible_channels, H, W]
                    where H = H' + kernel_size - 1
        """
        # Transposed convolution to upsample hidden to visible space
        # Input:  [batch, hidden_channels, H', W']
        # Weight: [hidden_channels, visible_channels, kernel_size, kernel_size]
        # Output: [batch, visible_channels, H' + kernel_size - 1, W' + kernel_size - 1]
        conv_transpose_out = nn.functional.conv_transpose2d(h, self.W, bias=None, padding=0)
        
        # Add visible bias (broadcast across spatial dimensions)
        # v_bias shape: [visible_channels] -> [1, visible_channels, 1, 1]
        v_bias_expanded = self.v_bias.view(1, -1, 1, 1)
        
        # Reconstructed visible mean: a_i + conv_transpose(h, W)
        v_mean = conv_transpose_out + v_bias_expanded
        
        return v_mean
    
    def sample_visible(self, h: torch.Tensor) -> tuple:
        """
        Sample visible units from hidden units (Gaussian sampling).
        
        v ~ N(E[v|h], σ²) where σ = 1 (unit variance)
        
        For normalized inputs [0, 1], we use unit variance Gaussian.
        The mean is computed from visible_probabilities().
        
        Args:
            h: Hidden units tensor (probabilities or samples)
               Shape: [batch_size, hidden_channels, H', W']
               
        Returns:
            v_mean: Reconstructed visible mean
                    Shape: [batch_size, visible_channels, H, W]
            v_sample: Gaussian visible samples
                      Shape: [batch_size, visible_channels, H, W]
        """
        # Get visible mean (reconstruction)
        v_mean = self.visible_probabilities(h)
        
        # Sample from Gaussian: v = mean + noise (σ = 1)
        # Add standard Gaussian noise
        v_sample = v_mean + torch.randn_like(v_mean)
        
        return v_mean, v_sample
    
    def forward(self, v: torch.Tensor) -> tuple:
        """
        Forward pass: compute hidden probabilities and samples.
        
        This is the inference direction: visible -> hidden
        
        Args:
            v: Input visible units (batch of images)
               Shape: [batch_size, visible_channels, H, W]
               
        Returns:
            h_prob: Hidden activation probabilities
                    Shape: [batch_size, hidden_channels, H', W']
            h_sample: Binary hidden samples
                      Shape: [batch_size, hidden_channels, H', W']
        """
        h_prob, h_sample = self.sample_hidden(v)
        return h_prob, h_sample
    
    def get_output_size(self, input_size: tuple) -> tuple:
        """
        Calculate hidden layer spatial dimensions given input size.
        
        For valid convolution: output_size = input_size - kernel_size + 1
        
        Args:
            input_size: (H, W) of input images
            
        Returns:
            output_size: (H', W') of hidden feature maps
        """
        H, W = input_size
        H_out = H - self.kernel_size + 1
        W_out = W - self.kernel_size + 1
        return (H_out, W_out)
    
    def __repr__(self):
        return (
            f"ConvRBM(\n"
            f"  visible_channels={self.visible_channels},\n"
            f"  hidden_channels={self.hidden_channels},\n"
            f"  kernel_size={self.kernel_size},\n"
            f"  W shape={tuple(self.W.shape)},\n"
            f"  h_bias shape={tuple(self.h_bias.shape)},\n"
            f"  v_bias shape={tuple(self.v_bias.shape)}\n"
            f")"
        )


print("✓ ConvRBM class defined successfully!")

### Conv-RBM Verification

Test the ConvRBM module to ensure correct tensor shapes and functionality.

In [None]:
# =============================================================================
# Conv-RBM Verification
# =============================================================================

print("=" * 70)
print("CONV-RBM MODULE VERIFICATION")
print("=" * 70)

# -----------------------------------------------------------------------------
# 5.1 Instantiate ConvRBM
# -----------------------------------------------------------------------------

# Create Conv-RBM instance with specified parameters
conv_rbm = ConvRBM(
    visible_channels=config.NUM_CHANNELS,  # 1 (grayscale)
    hidden_channels=32,                     # 32 feature maps
    kernel_size=7                           # 7×7 kernels
)

# Move to device
conv_rbm = conv_rbm.to(config.DEVICE)

print("\nConv-RBM Instance:")
print(conv_rbm)

print(f"\nDevice: {next(conv_rbm.parameters()).device}")

# -----------------------------------------------------------------------------
# 5.2 Parameter Summary
# -----------------------------------------------------------------------------

print("\n" + "-" * 50)
print("PARAMETER SUMMARY")
print("-" * 50)

total_params = 0
for name, param in conv_rbm.named_parameters():
    num_params = param.numel()
    total_params += num_params
    # Tuples do not accept width format specs, so convert the shape to str first
    print(f"  {name:<10} : shape={str(tuple(param.shape)):<25} params={num_params:,}")

print(f"\n  Total trainable parameters: {total_params:,}")

# -----------------------------------------------------------------------------
# 5.3 Test with One Batch from DataLoader
# -----------------------------------------------------------------------------

print("\n" + "-" * 50)
print("FORWARD PASS TEST")
print("-" * 50)

# Get one batch from training loader
test_batch_v, test_batch_labels = next(iter(train_loader))
test_batch_v = test_batch_v.to(config.DEVICE)

print(f"\nInput (visible):")
print(f"  Shape       : {test_batch_v.shape}")
print(f"  Expected    : [batch_size, {config.NUM_CHANNELS}, {config.IMAGE_SIZE[0]}, {config.IMAGE_SIZE[1]}]")
print(f"  Device      : {test_batch_v.device}")
print(f"  Dtype       : {test_batch_v.dtype}")
print(f"  Value range : [{test_batch_v.min().item():.4f}, {test_batch_v.max().item():.4f}]")

# Compute expected output size
expected_h_size = conv_rbm.get_output_size(config.IMAGE_SIZE)
print(f"\nExpected hidden spatial size: {expected_h_size}")

# Forward pass (inference: v -> h)
with torch.no_grad():
    h_prob, h_sample = conv_rbm(test_batch_v)

print(f"\nHidden Probabilities (h_prob):")
print(f"  Shape       : {h_prob.shape}")
print(f"  Expected    : [{test_batch_v.shape[0]}, 32, {expected_h_size[0]}, {expected_h_size[1]}]")
print(f"  Value range : [{h_prob.min().item():.4f}, {h_prob.max().item():.4f}]")
print(f"  Mean        : {h_prob.mean().item():.4f}")

print(f"\nHidden Samples (h_sample):")
print(f"  Shape       : {h_sample.shape}")
print(f"  Unique vals : {torch.unique(h_sample).tolist()}")  # Should be [0, 1]
print(f"  Sparsity    : {(h_sample == 0).float().mean().item():.2%} zeros")

# -----------------------------------------------------------------------------
# 5.4 Test Reconstruction (h -> v)
# -----------------------------------------------------------------------------

print("\n" + "-" * 50)
print("RECONSTRUCTION TEST (h -> v)")
print("-" * 50)

with torch.no_grad():
    # Reconstruct visible from hidden probabilities
    v_mean, v_sample = conv_rbm.sample_visible(h_prob)

print(f"\nReconstructed Visible Mean (v_mean):")
print(f"  Shape       : {v_mean.shape}")
print(f"  Expected    : {test_batch_v.shape}")
print(f"  Value range : [{v_mean.min().item():.4f}, {v_mean.max().item():.4f}]")

print(f"\nReconstructed Visible Sample (v_sample):")
print(f"  Shape       : {v_sample.shape}")
print(f"  Value range : [{v_sample.min().item():.4f}, {v_sample.max().item():.4f}]")

# -----------------------------------------------------------------------------
# 5.5 Shape Assertions
# -----------------------------------------------------------------------------

print("\n" + "-" * 50)
print("SHAPE ASSERTIONS")
print("-" * 50)

# Check hidden shape
expected_h_shape = (test_batch_v.shape[0], 32, expected_h_size[0], expected_h_size[1])
assert h_prob.shape == expected_h_shape, f"Hidden prob shape mismatch: {h_prob.shape} vs {expected_h_shape}"
assert h_sample.shape == expected_h_shape, f"Hidden sample shape mismatch: {h_sample.shape} vs {expected_h_shape}"
print(f"  ✓ Hidden shapes correct: {h_prob.shape}")

# Check reconstruction shape matches input
assert v_mean.shape == test_batch_v.shape, f"Reconstruction shape mismatch: {v_mean.shape} vs {test_batch_v.shape}"
assert v_sample.shape == test_batch_v.shape, f"Reconstruction sample shape mismatch: {v_sample.shape} vs {test_batch_v.shape}"
print(f"  ✓ Reconstruction shapes correct: {v_mean.shape}")

# Check hidden values are probabilities (between 0 and 1)
assert h_prob.min() >= 0 and h_prob.max() <= 1, "Hidden probabilities out of [0, 1] range"
print(f"  ✓ Hidden probabilities in valid range [0, 1]")

# Check hidden samples are binary
assert set(torch.unique(h_sample).tolist()).issubset({0.0, 1.0}), "Hidden samples not binary"
print(f"  ✓ Hidden samples are binary {{0, 1}}")

print("\n" + "=" * 70)
print("ALL VERIFICATIONS PASSED!")
print("=" * 70)

In [None]:
# -----------------------------------------------------------------------------
# 5.6 Visualize Input vs Reconstruction
# -----------------------------------------------------------------------------

def visualize_reconstruction(original, reconstructed, num_images=4, figsize=(12, 5)):
    """
    Visualize original images alongside their reconstructions.
    
    Args:
        original: Original input tensor [batch, 1, H, W]
        reconstructed: Reconstructed tensor [batch, 1, H, W]
        num_images: Number of image pairs to show
        figsize: Figure size
    """
    num_images = min(num_images, len(original))
    
    # squeeze=False keeps a 2-D axes array even when num_images == 1
    fig, axes = plt.subplots(2, num_images, figsize=figsize, squeeze=False)
    
    for i in range(num_images):
        # Original (fixed [0, 1] display range so brightness is comparable)
        orig_img = original[i].squeeze().detach().cpu().numpy()
        axes[0, i].imshow(orig_img, cmap='gray', vmin=0, vmax=1)
        axes[0, i].set_title(f'Original {i+1}')
        
        # Reconstruction, clipped to the input pixel range [0, 1]
        recon_img = reconstructed[i].squeeze().detach().cpu().numpy()
        recon_img = np.clip(recon_img, 0, 1)
        axes[1, i].imshow(recon_img, cmap='gray', vmin=0, vmax=1)
        axes[1, i].set_title(f'Recon {i+1}')
        
        # Hide ticks only; axis('off') would also hide the row labels set below
        for ax in axes[:, i]:
            ax.set_xticks([])
            ax.set_yticks([])
    
    axes[0, 0].set_ylabel('Original', fontsize=12)
    axes[1, 0].set_ylabel('Reconstructed', fontsize=12)
    
    plt.suptitle('Conv-RBM: Input vs Reconstruction (Untrained)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Visualize (reconstruction will be poor since RBM is untrained)
print("Visualizing input vs reconstruction (untrained RBM - expect poor reconstruction):")
visualize_reconstruction(test_batch_v, v_mean, num_images=4)

---

# Conv-RBM Module Implemented: Ready for Contrastive Divergence

✅ **ConvRBM class** implemented with Gaussian visible / Bernoulli hidden units

✅ **Core methods verified:**
- `hidden_probabilities(v)`: computes P(h|v) via convolution + sigmoid
- `sample_hidden(v)`: Bernoulli sampling of hidden states
- `visible_probabilities(h)`: reconstructs the visible mean via transposed convolution
- `sample_visible(h)`: Gaussian sampling with unit variance
- `forward(v)`: returns hidden probabilities and samples

✅ **Shape verification passed**: the input → hidden → reconstruction cycle preserves shapes

✅ **Weight sharing** implemented via `nn.functional.conv2d` and `conv_transpose2d`
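Why `conv_transpose2d` with the same `W` implements Σ_j W_ij h_j: it is the adjoint (transpose) of the valid `conv2d` used for P(h | v), so ⟨conv2d(v, W), h⟩ = ⟨v, conv_transpose2d(h, W)⟩ for any v and h. A small numerical check, with arbitrary sizes chosen only for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C_v, C_h, K, H = 2, 1, 4, 3, 8
W = torch.randn(C_h, C_v, K, K)                # same layout as ConvRBM.W
v = torch.randn(B, C_v, H, H)                  # visible batch
h = torch.randn(B, C_h, H - K + 1, H - K + 1)  # hidden maps at valid-conv size

# <conv2d(v, W), h>: the map used by hidden_probabilities (before bias and sigmoid)
lhs = (F.conv2d(v, W) * h).sum()
# <v, conv_transpose2d(h, W)>: the map used by visible_probabilities (before bias)
rhs = (v * F.conv_transpose2d(h, W)).sum()

print(torch.allclose(lhs, rhs, rtol=1e-4, atol=1e-3))  # True
```

Because the two maps are transposes of each other, the up and down passes share one weight tensor exactly as the dense RBM shares `W` and `W.T`.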

---

**Key Dimensions:**
- Input: `[batch, 1, 128, 128]`
- Hidden: `[batch, 32, 122, 122]` (after 7×7 valid convolution)
- Reconstruction: `[batch, 1, 128, 128]` (matches input)
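These dimensions follow from valid-convolution arithmetic. A small helper (hypothetical, not in the notebook) reproduces them together with the layer's parameter count, assuming 128 × 128 single-channel inputs and the 32-filter, 7 × 7 layer used in the verification cell:

```python
def conv_rbm_shapes(image_hw, visible_channels=1, hidden_channels=32, kernel_size=7):
    """Valid-convolution shape and parameter arithmetic for one ConvRBM layer."""
    H, W = image_hw
    hidden_hw = (H - kernel_size + 1, W - kernel_size + 1)          # conv2d, no padding
    recon_hw = (hidden_hw[0] + kernel_size - 1,                     # conv_transpose2d
                hidden_hw[1] + kernel_size - 1)
    n_params = (hidden_channels * visible_channels * kernel_size**2  # W
                + hidden_channels                                    # h_bias
                + visible_channels)                                  # v_bias
    return hidden_hw, recon_hw, n_params

print(conv_rbm_shapes((128, 128)))  # ((122, 122), (128, 128), 1601)
```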

**Next Steps:**
- STEP 6: Implement Contrastive Divergence (CD-k) training algorithm
- STEP 7: Train the Conv-RBM layer on OCT images
- STEP 8: Stack layers to form the full CDBN

## STEP 6: Contrastive Divergence (CD-k) Trainer

Implement the CD-k algorithm for unsupervised training of the Conv-RBM.

**Contrastive Divergence Algorithm:**

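The gradient of the log-likelihood with respect to the weights is a difference of two correlations:

$$\frac{\partial \log P(v)}{\partial W} = \langle v \, h^\top \rangle_{\text{data}} - \langle v \, h^\top \rangle_{\text{model}}$$

The model expectation is intractable, so CD-k approximates it with statistics collected after k steps of Gibbs sampling started from the data.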

**CD-k Steps:**
1. **Positive Phase:** Clamp visible units to data, sample hidden states
2. **Gibbs Sampling:** Run k steps of alternating sampling (h → v → h)
3. **Negative Phase:** Use reconstructed visible and hidden for model statistics
4. **Update:** $\Delta W = \eta \cdot (\langle v_0 h_0 \rangle - \langle v_k h_k \rangle)$
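The four steps can be sketched in NumPy for a dense (fully connected) Gaussian-Bernoulli RBM with σ = 1. This is an illustrative sketch, not the notebook's trainer: `cd_k_step`, the dict-based parameters, and the toy data are assumptions. It mirrors the choices `CDTrainer` makes below (visible mean instead of a noisy sample, hidden probabilities in the statistics, momentum, weight decay on W only) but omits its clamping and gradient clipping.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def cd_k_step(v0, params, velocity, rng, lr=0.05, k=1, momentum=0.5, weight_decay=1e-4):
    """One CD-k update for a dense Gaussian-Bernoulli RBM (sigma = 1).

    v0: [batch, n_visible]. params/velocity: dicts with "W" [n_visible, n_hidden],
    "a" [n_visible] (visible bias) and "b" [n_hidden] (hidden bias), updated in place.
    Returns the MSE between v0 and the k-step reconstruction.
    """
    assert k >= 1
    W, a, b = params["W"], params["a"], params["b"]
    # 1. Positive phase: visible units clamped to the data
    h0_prob = sigmoid(b + v0 @ W)
    h_sample = (rng.random(h0_prob.shape) < h0_prob).astype(v0.dtype)
    # 2. k steps of block Gibbs sampling (visible mean, as in CDTrainer)
    for step in range(k):
        vk = a + h_sample @ W.T
        hk_prob = sigmoid(b + vk @ W)
        if step < k - 1:
            h_sample = (rng.random(hk_prob.shape) < hk_prob).astype(v0.dtype)
    # 3. Data statistics minus model statistics
    n = v0.shape[0]
    grads = {
        "W": (v0.T @ h0_prob - vk.T @ hk_prob) / n - weight_decay * W,
        "a": (v0 - vk).mean(axis=0),
        "b": (h0_prob - hk_prob).mean(axis=0),
    }
    # 4. Momentum update: velocity = momentum * velocity + grad; theta += lr * velocity
    for name, grad in grads.items():
        velocity[name] = momentum * velocity[name] + grad
        params[name] += lr * velocity[name]  # in-place, so the caller's arrays change
    return float(((vk - v0) ** 2).mean())

rng = np.random.default_rng(42)
data = rng.random((64, 16))  # toy "pixels" in [0, 1)
params = {"W": 0.01 * rng.standard_normal((16, 8)), "a": np.zeros(16), "b": np.zeros(8)}
velocity = {name: np.zeros_like(p) for name, p in params.items()}
losses = [cd_k_step(data, params, velocity, rng) for _ in range(300)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

In the convolutional trainer the dense products `v0.T @ h0_prob` become cross-correlations between the input and the hidden feature maps, and the bias statistics are averaged over spatial positions as well as the batch.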

In [None]:
# =============================================================================
# STEP 6: Contrastive Divergence (CD-k) Trainer
# =============================================================================

class CDTrainer:
    """
    Contrastive Divergence (CD-k) Trainer for Convolutional RBM.
    
    CD-k is an approximate method for training RBMs without computing the 
    intractable partition function. It estimates the gradient of the 
    log-likelihood using k steps of Gibbs sampling.
    
    Algorithm:
        1. POSITIVE PHASE (data-driven):
           - Clamp v0 = input data
           - Compute h0_prob = P(h|v0)
           - Sample h0 ~ Bernoulli(h0_prob)
           
        2. GIBBS SAMPLING (k steps, starting from h0):
           For i = 1, ..., k:
               - v_i = sample_visible(h_{i-1})
               - h_i = sample_hidden(v_i)
           
        3. NEGATIVE PHASE (model-driven):
           - Use vk and hk from Gibbs chain
           
        4. PARAMETER UPDATES:
           ΔW = lr * (⟨v0 ⊗ h0⟩ - ⟨vk ⊗ hk⟩)
           Δh_bias = lr * (⟨h0⟩ - ⟨hk⟩)
           Δv_bias = lr * (⟨v0⟩ - ⟨vk⟩)
           
           where ⊗ denotes the correlation computed via convolution
    
    Args:
        rbm: ConvRBM instance to train
        learning_rate: Learning rate for parameter updates
        k: Number of Gibbs sampling steps (default: 1)
        device: Device for computation
        momentum: Momentum coefficient for updates (default: 0.0)
        weight_decay: L2 regularization coefficient applied to W (default: 0.0001)
    """
    
    def __init__(
        self,
        rbm: ConvRBM,
        learning_rate: float = 0.01,
        k: int = 1,
        device: torch.device = None,
        momentum: float = 0.0,
        weight_decay: float = 0.0001
    ):
        self.rbm = rbm
        self.lr = learning_rate
        self.k = k
        self.device = device if device else torch.device('cpu')
        self.momentum = momentum
        self.weight_decay = weight_decay
        
        # Initialize velocity terms for momentum-based updates
        # These accumulate the running average of gradients
        self.W_velocity = torch.zeros_like(rbm.W.data)
        self.h_bias_velocity = torch.zeros_like(rbm.h_bias.data)
        self.v_bias_velocity = torch.zeros_like(rbm.v_bias.data)
        
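    # CD updates are applied manually, so no autograd graph is needed. Without
    # no_grad, the graphs would be retained through the velocity buffers and
    # memory would grow with every batch.
    @torch.no_grad()
    def train_batch(self, v0: torch.Tensor) -> float: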
        """
        Train the RBM on a single batch using CD-k.
        
        This method performs one complete CD-k update:
        1. Positive phase with clamped data
        2. k steps of Gibbs sampling  
        3. Negative phase with fantasy particles
        4. Parameter updates using the difference
        
        Args:
            v0: Input visible batch (real OCT images)
                Shape: [batch_size, channels, height, width]
                
        Returns:
            reconstruction_loss: MSE between v0 and vk (reconstruction error)
        """
        batch_size = v0.shape[0]
        
        # Move to device
        v0 = v0.to(self.device)
        
        # =====================================================================
        # POSITIVE PHASE (Data-Driven Statistics)
        # =====================================================================
        # The positive phase captures the correlation between data and hidden
        # units when the visible units are clamped to the training data.
        #
        # Math: ‚ü®v_i h_j‚ü©_data = E_data[v_i * P(h_j=1|v)]
        
        # Compute hidden probabilities given visible data
        # h0_prob: P(h=1|v0), shape: [batch, hidden_channels, H', W']
        h0_prob = self.rbm.hidden_probabilities(v0)
        
        # NUMERICAL STABILITY: Clamp probabilities to valid range
        h0_prob = torch.clamp(h0_prob, min=1e-7, max=1.0 - 1e-7)
        
        # Sample hidden states from probabilities (Bernoulli sampling)
        # h0_sample: binary hidden states {0, 1}
        h0_sample = torch.bernoulli(h0_prob)
        
        # Compute positive phase correlation (data statistics)
        # For convolutional RBM, this is computed via convolution:
        # positive_W = conv2d(v0, h0_prob) averaged over batch
        #
        # This computes: sum over all (i,j) spatial positions of v0_i * h0_j
        # with weight sharing across spatial locations
        positive_W = self._compute_weight_gradient(v0, h0_prob)
        
        # Positive phase for biases (spatial average)
        # h_bias gradient: mean activation per feature map
        positive_h_bias = h0_prob.mean(dim=(0, 2, 3))  # [hidden_channels]
        
        # v_bias gradient: mean activation per visible channel
        positive_v_bias = v0.mean(dim=(0, 2, 3))  # [visible_channels]
        
        # =====================================================================
        # GIBBS SAMPLING (k steps)
        # =====================================================================
        # Run k steps of alternating Gibbs sampling to get "fantasy particles"
        # that represent the model's equilibrium distribution.
        #
        # Each step: h -> v -> h (block Gibbs sampling)
        
        # Start Gibbs chain from the sampled hidden states
        hk_sample = h0_sample
        
        for step in range(self.k):
            # -------------------------------------------------------------
            # Step 1: Sample visible given hidden (v ~ P(v|h))
            # -------------------------------------------------------------
            # For Gaussian visible units: v = mean + noise
            # Mean: E[v|h] = v_bias + conv_transpose(h, W)
            vk_mean = self.rbm.visible_probabilities(hk_sample)
            
            # For Gaussian-Bernoulli RBM, we can use the mean directly
            # (sampling adds noise which can hurt learning)
            # vk_sample = vk_mean + torch.randn_like(vk_mean)  # Noisy version
            vk_sample = vk_mean  # Use mean (deterministic, more stable)
            
            # NUMERICAL STABILITY: Clamp reconstructions to a wide bound [-5, 5]
            # (a safeguard against runaway values, not the [0, 1] pixel range)
            vk_sample = torch.clamp(vk_sample, min=-5.0, max=5.0)
            
            # -------------------------------------------------------------
            # Step 2: Sample hidden given visible (h ~ P(h|v))  
            # -------------------------------------------------------------
            # For Bernoulli hidden units: h ~ Bernoulli(sigmoid(h_bias + conv(v, W)))
            hk_prob = self.rbm.hidden_probabilities(vk_sample)
            
            # NUMERICAL STABILITY: Clamp probabilities to valid range
            hk_prob = torch.clamp(hk_prob, min=1e-7, max=1.0 - 1e-7)
            
            # Sample hidden (except on last step where we use probabilities)
            if step < self.k - 1:
                hk_sample = torch.bernoulli(hk_prob)
            # On last step, keep hk_prob for gradient computation
        
        # =====================================================================
        # NEGATIVE PHASE (Model-Driven Statistics)
        # =====================================================================
        # The negative phase captures the correlation when the network runs
        # freely (fantasy particles from Gibbs sampling).
        #
        # Math: ‚ü®v_i h_j‚ü©_model ‚âà E_model[v_i * P(h_j=1|v)]
        
        # Use probabilities (not samples) for more stable gradients
        # This is the "probability-based" CD variant
        negative_W = self._compute_weight_gradient(vk_sample, hk_prob)
        
        # Negative phase for biases
        negative_h_bias = hk_prob.mean(dim=(0, 2, 3))  # [hidden_channels]
        negative_v_bias = vk_sample.mean(dim=(0, 2, 3))  # [visible_channels]
        
        # =====================================================================
        # PARAMETER UPDATES
        # =====================================================================
        # Update rule: θ_new = θ_old + lr * ((positive - negative) - weight_decay * θ)
        # (weight decay is applied to W only)
        #
        # With momentum: v_new = momentum * v_old + gradient
        #                θ_new = θ_old + lr * v_new
        
        # Compute gradients (positive - negative)
        W_grad = positive_W - negative_W
        h_bias_grad = positive_h_bias - negative_h_bias
        v_bias_grad = positive_v_bias - negative_v_bias
        
        # NUMERICAL STABILITY: Clip gradients to prevent explosion
        max_grad = 1.0
        W_grad = torch.clamp(W_grad, min=-max_grad, max=max_grad)
        h_bias_grad = torch.clamp(h_bias_grad, min=-max_grad, max=max_grad)
        v_bias_grad = torch.clamp(v_bias_grad, min=-max_grad, max=max_grad)
        
        # Apply weight decay (L2 regularization)
        # This penalizes large weights to prevent overfitting
        W_grad -= self.weight_decay * self.rbm.W.data
        
        # Update velocity with momentum
        self.W_velocity = self.momentum * self.W_velocity + W_grad
        self.h_bias_velocity = self.momentum * self.h_bias_velocity + h_bias_grad
        self.v_bias_velocity = self.momentum * self.v_bias_velocity + v_bias_grad
        
        # Apply updates to parameters (no autograd, manual update)
        with torch.no_grad():
            self.rbm.W.data += self.lr * self.W_velocity
            self.rbm.h_bias.data += self.lr * self.h_bias_velocity
            self.rbm.v_bias.data += self.lr * self.v_bias_velocity
        
        # =====================================================================
        # RECONSTRUCTION LOSS
        # =====================================================================
        # Compute MSE between original input and reconstruction
        # This is a proxy for how well the RBM can reconstruct the input
        reconstruction_loss = nn.functional.mse_loss(vk_sample, v0).item()
        
        return reconstruction_loss
    
    def _compute_weight_gradient(
        self, 
        v: torch.Tensor, 
        h: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the correlation statistic ⟨v ⊗ h⟩ used in the W update.
        
        For a convolutional RBM (σ = 1), -∂E/∂W averaged over batch and positions is:
            ⟨v ⊗ h⟩[hc,vc,i,j] = (1/B) * (1/H'W') * Σ_b Σ_x,y h[b,hc,x,y] * v[b,vc,x+i,y+j]
            
        Processes one sample at a time to reduce peak memory usage.
        
        Args:
            v: Visible units [batch, v_channels, H, W]
            h: Hidden units/probabilities [batch, h_channels, H', W']
            
        Returns:
            W_grad: Gradient for weights [h_channels, v_channels, kernel, kernel]
        """
        batch_size = v.shape[0]
        K = self.rbm.kernel_size
        h_channels = self.rbm.hidden_channels
        v_channels = self.rbm.visible_channels
        
        # Get spatial dimensions of hidden layer
        h_height, h_width = h.shape[2], h.shape[3]
        spatial_size = h_height * h_width  # For normalization
        
        # Initialize gradient accumulator
        W_grad = torch.zeros(
            h_channels, v_channels, K, K,
            device=v.device, dtype=v.dtype
        )
        
        # Process one sample at a time to save memory
        for b in range(batch_size):
            # Extract patches for single sample: [v_ch, H', W', K, K]
            v_unfold_b = v[b].unfold(1, K, 1).unfold(2, K, 1)  # [v_ch, H', W', K, K]
            h_b = h[b]  # [h_ch, H', W']
            
            # Accumulate: W_grad[o,v,i,j] += sum_hw h[o,h,w] * v_patch[v,h,w,i,j]
            W_grad += torch.einsum('ohw,vhwij->ovij', h_b, v_unfold_b)
        
        # Normalize by batch size AND spatial size for numerical stability
        W_grad /= (batch_size * spatial_size)
        
        return W_grad
    
    def get_reconstruction(self, v: torch.Tensor) -> torch.Tensor:
        """
        Get reconstruction of visible input (for visualization).
        
        Args:
            v: Input visible units
            
        Returns:
            v_recon: Reconstructed visible units
        """
        with torch.no_grad():
            v = v.to(self.device)
            h_prob = self.rbm.hidden_probabilities(v)
            v_recon = self.rbm.visible_probabilities(h_prob)
        return v_recon


print("✓ CDTrainer class defined successfully!")

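The per-sample `einsum` loop in `_compute_weight_gradient` computes a cross-correlation between the visible input and the hidden maps. As a sketch (not the notebook's implementation), the same statistic can be obtained with a single `torch.nn.functional.conv2d` call by swapping the batch and channel axes, so that the batch becomes the reduction (input-channel) dimension and the hidden maps act as kernels. `weight_grad_loop` and `weight_grad_conv2d` are hypothetical helper names; the layout assumptions (`[batch, channels, H, W]`, stride 1, no padding, `H' = H - K + 1`) follow the docstring above.

```python
import torch
import torch.nn.functional as F


def weight_grad_loop(v: torch.Tensor, h: torch.Tensor, K: int) -> torch.Tensor:
    """Reference: per-sample unfold + einsum, as in _compute_weight_gradient."""
    B = v.shape[0]
    spatial = h.shape[2] * h.shape[3]
    grad = torch.zeros(h.shape[1], v.shape[1], K, K, dtype=v.dtype)
    for b in range(B):
        patches = v[b].unfold(1, K, 1).unfold(2, K, 1)  # [v_ch, H', W', K, K]
        grad += torch.einsum('ohw,vhwij->ovij', h[b], patches)
    return grad / (B * spatial)


def weight_grad_conv2d(v: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Same statistic with one conv2d call.

    conv2d computes out[n, o, i, j] = sum_c sum_{x,y} inp[n, c, i+x, j+y] * w[o, c, x, y].
    With inp = v^T ([v_ch, B, H, W]) and w = h^T ([h_ch, B, H', W']) this is
    out[vc, hc, i, j] = sum_b sum_{x,y} v[b, vc, x+i, y+j] * h[b, hc, x, y].
    """
    B = v.shape[0]
    spatial = h.shape[2] * h.shape[3]
    out = F.conv2d(v.transpose(0, 1).contiguous(),
                   h.transpose(0, 1).contiguous())  # [v_ch, h_ch, K, K]
    return out.transpose(0, 1) / (B * spatial)      # [h_ch, v_ch, K, K]


# Small random check (sizes are arbitrary)
torch.manual_seed(0)
K = 3
v = torch.rand(4, 2, 10, 10)
h = torch.rand(4, 5, 10 - K + 1, 10 - K + 1)
g_loop = weight_grad_loop(v, h, K)
g_conv = weight_grad_conv2d(v, h)
print(g_loop.shape, torch.allclose(g_loop, g_conv, atol=1e-5))
```

The conv2d form replaces the Python loop with one kernel launch; whether it also lowers peak memory compared with the per-sample loop depends on the backend, so it is worth profiling on the target GPU before switching.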
## STEP 7: Unsupervised Pretraining Loop

Train the Conv-RBM-1 layer using Contrastive Divergence on OCT images.

**Training Configuration:**
- Unsupervised learning (labels ignored)
- CD-1 (single Gibbs step)
- Track reconstruction loss per epoch

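To make the CD-1 update concrete without the convolutional bookkeeping, here is a minimal sketch on a tiny fully connected binary RBM in NumPy. It follows the update steps visible in `CDTrainer.train_batch` above (weight decay subtracted from the weight gradient, momentum, then a manual parameter update), but `TinyRBM`, its toy data and its hyperparameters are hypothetical and chosen only for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class TinyRBM:
    """Hypothetical dense binary RBM, used only to illustrate one CD-1 step."""

    def __init__(self, n_visible, n_hidden, lr=0.1, momentum=0.5,
                 weight_decay=1e-4, seed=0):
        self.rng = np.random.default_rng(seed)
        self.W = 0.01 * self.rng.standard_normal((n_visible, n_hidden))
        self.v_bias = np.zeros(n_visible)
        self.h_bias = np.zeros(n_hidden)
        self.lr, self.momentum, self.weight_decay = lr, momentum, weight_decay
        self.W_vel = np.zeros_like(self.W)
        self.v_vel = np.zeros_like(self.v_bias)
        self.h_vel = np.zeros_like(self.h_bias)

    def reconstruct(self, v):
        """Deterministic mean-field reconstruction v -> p(h|v) -> p(v|h)."""
        h = sigmoid(v @ self.W + self.h_bias)
        return sigmoid(h @ self.W.T + self.v_bias)

    def cd1_step(self, v0):
        B = v0.shape[0]
        # Positive phase: hidden probabilities and a binary sample given the data
        h0_prob = sigmoid(v0 @ self.W + self.h_bias)
        h0 = (self.rng.random(h0_prob.shape) < h0_prob).astype(v0.dtype)
        # One Gibbs step (k = 1): reconstruct visibles, then recompute hiddens
        v1_prob = sigmoid(h0 @ self.W.T + self.v_bias)
        h1_prob = sigmoid(v1_prob @ self.W + self.h_bias)
        # CD-1 gradient estimates (data term minus reconstruction term)
        W_grad = (v0.T @ h0_prob - v1_prob.T @ h1_prob) / B
        W_grad -= self.weight_decay * self.W          # L2 penalty
        v_grad = (v0 - v1_prob).mean(axis=0)
        h_grad = (h0_prob - h1_prob).mean(axis=0)
        # Momentum, then a manual (gradient-ascent) parameter update
        self.W_vel = self.momentum * self.W_vel + W_grad
        self.v_vel = self.momentum * self.v_vel + v_grad
        self.h_vel = self.momentum * self.h_vel + h_grad
        self.W += self.lr * self.W_vel
        self.v_bias += self.lr * self.v_vel
        self.h_bias += self.lr * self.h_vel
        return float(np.mean((v1_prob - v0) ** 2))


# Toy data: two complementary binary patterns, repeated
data = np.array([[1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1]] * 8, dtype=float)
rbm = TinyRBM(n_visible=6, n_hidden=4)
mse_before = float(np.mean((rbm.reconstruct(data) - data) ** 2))
for _ in range(2000):
    rbm.cd1_step(data)
mse_after = float(np.mean((rbm.reconstruct(data) - data) ** 2))
print(f"MSE before: {mse_before:.4f}, after: {mse_after:.4f}")
```

With near-zero initial weights every reconstruction probability sits close to 0.5, so the starting MSE on binary data is close to 0.25; the convolutional trainer replaces the dense products above with the cross-correlation computed in `_compute_weight_gradient`.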
In [None]:
# =============================================================================
# STEP 7: Unsupervised Pretraining Loop for Conv-RBM-1
# =============================================================================

# -----------------------------------------------------------------------------
# Training Configuration (STEP C: Kaggle-Safe Values)
# -----------------------------------------------------------------------------

# Conv-RBM-1 hyperparameters
CONVRBM1_CONFIG = {
    'visible_channels': config.NUM_CHANNELS,  # 1 (grayscale)
    'hidden_channels': 32,                     # Number of learned filters
    'kernel_size': 7,                          # 7×7 kernels
}

# CD Training hyperparameters (Kaggle-optimized)
# EPOCHS: Reduced from 10 to 5-7 for Kaggle time limits (~9 hour sessions)
# CD-k=1: Single Gibbs step is fastest and usually sufficient
# LR: Reduced for numerical stability with large spatial dimensions
CD_CONFIG = {
    'learning_rate': 0.001,    # Reduced from 0.01 for stability
    'k': 1,                    # CD-1 (single Gibbs step - fastest)
    'momentum': 0.5,           # Momentum coefficient
    'weight_decay': 0.0001,    # L2 regularization
    'num_epochs': 5,           # Reduced for Kaggle (was 10)
}

print("=" * 70)
print("CONV-RBM-1 PRETRAINING CONFIGURATION (Kaggle-Optimized)")
print("=" * 70)
print("\nConv-RBM Architecture:")
for key, value in CONVRBM1_CONFIG.items():
    print(f"  {key:<20}: {value}")
print("\nCD Training Parameters:")
for key, value in CD_CONFIG.items():
    print(f"  {key:<20}: {value}")
print(f"\n⚡ Note: num_epochs = {CD_CONFIG['num_epochs']} (reduced from 10 for Kaggle time limits)")
print("=" * 70)

In [None]:
# -----------------------------------------------------------------------------
# Initialize Conv-RBM-1 and CD Trainer
# -----------------------------------------------------------------------------

# Create fresh Conv-RBM-1 instance
conv_rbm_1 = ConvRBM(
    visible_channels=CONVRBM1_CONFIG['visible_channels'],
    hidden_channels=CONVRBM1_CONFIG['hidden_channels'],
    kernel_size=CONVRBM1_CONFIG['kernel_size']
).to(config.DEVICE)

print("Conv-RBM-1 Architecture:")
print(conv_rbm_1)

# Initialize CD Trainer
cd_trainer = CDTrainer(
    rbm=conv_rbm_1,
    learning_rate=CD_CONFIG['learning_rate'],
    k=CD_CONFIG['k'],
    device=config.DEVICE,
    momentum=CD_CONFIG['momentum'],
    weight_decay=CD_CONFIG['weight_decay']
)

print(f"\nCD Trainer initialized with k={CD_CONFIG['k']} Gibbs steps")
print(f"Device: {config.DEVICE}")

In [None]:
# -----------------------------------------------------------------------------
# Unsupervised Pretraining Loop (Memory-Optimized for Local 24GB GPU)
# -----------------------------------------------------------------------------

def train_convrbm(
    rbm: ConvRBM,
    trainer: CDTrainer,
    train_loader: DataLoader,
    num_epochs: int,
    device: torch.device,
    memory_cleanup_freq: int = 100  # Clear cache every N batches
) -> dict:
    """
    Train a Conv-RBM using Contrastive Divergence (Memory-Optimized).
    
    This is UNSUPERVISED learning - we only use the images, not the labels.
    The RBM learns to reconstruct the input distribution.
    
    Optimized for large datasets on local GPU with aggressive memory cleanup.
    
    Args:
        rbm: ConvRBM instance to train
        trainer: CDTrainer instance
        train_loader: DataLoader for training data
        num_epochs: Number of training epochs
        device: Device for computation
        memory_cleanup_freq: Frequency of memory cleanup (batches)
        
    Returns:
        history: Dictionary containing training history
    """
    history = {
        'epoch_loss': [],           # Average loss per epoch
        'batch_losses': [],         # All batch losses (for detailed analysis)
    }
    
    print("\n" + "=" * 70)
    print("STARTING UNSUPERVISED PRETRAINING (Memory-Optimized)")
    print("=" * 70)
    print(f"Training on {len(train_loader.dataset)} samples")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Batches per epoch: {len(train_loader)}")
    print(f"Total epochs: {num_epochs}")
    print(f"Memory cleanup frequency: every {memory_cleanup_freq} batches")
    if torch.cuda.is_available():
        print_gpu_memory()
    print("-" * 70)
    
    # Training loop
    for epoch in range(num_epochs):
        epoch_losses = []
        
        # Progress tracking
        epoch_start = time.time()
        
        # Iterate over batches (UNSUPERVISED - ignore labels)
        for batch_idx, (images, _) in enumerate(train_loader):
            # Train on batch using CD-k
            # images: [batch_size, 1, 128, 128]
            # labels are ignored (unsupervised)
            batch_loss = trainer.train_batch(images)
            epoch_losses.append(batch_loss)
            
            # Explicit cleanup of batch data
            del images
            
            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                print(f"  Epoch {epoch+1}/{num_epochs} | "
                      f"Batch {batch_idx+1}/{len(train_loader)} | "
                      f"Loss: {batch_loss:.6f}")
            
            # Aggressive memory cleanup for large datasets
            if torch.cuda.is_available() and (batch_idx + 1) % memory_cleanup_freq == 0:
                torch.cuda.empty_cache()
        
        epoch_time = time.time() - epoch_start
        
        # Compute epoch statistics
        avg_epoch_loss = np.mean(epoch_losses)
        history['epoch_loss'].append(avg_epoch_loss)
        history['batch_losses'].extend(epoch_losses)
        
        # Clear batch losses list to free memory
        epoch_losses.clear()
        
        # Print epoch summary with GPU memory status
        print(f"Epoch {epoch+1:2d}/{num_epochs} | "
              f"Avg Loss: {avg_epoch_loss:.6f} | "
              f"Time: {epoch_time:.1f}s", end="")
        
        if torch.cuda.is_available():
            mem_alloc = torch.cuda.memory_allocated() / 1024**3
            print(f" | GPU: {mem_alloc:.2f} GB")
        else:
            print()
        
        # Clear cache after each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    print("-" * 70)
    print("PRETRAINING COMPLETED!")
    print(f"Final reconstruction loss: {history['epoch_loss'][-1]:.6f}")
    if torch.cuda.is_available():
        print_gpu_memory()
    print("=" * 70)
    
    return history


# Run pretraining
print("Starting Conv-RBM-1 pretraining...")
training_history = train_convrbm(
    rbm=conv_rbm_1,
    trainer=cd_trainer,
    train_loader=train_loader,
    num_epochs=CD_CONFIG['num_epochs'],
    device=config.DEVICE,
    memory_cleanup_freq=100  # Clean every 100 batches
)

In [None]:
# -----------------------------------------------------------------------------
# Plot Training Progress
# -----------------------------------------------------------------------------

def plot_training_history(history: dict, figsize=(14, 5)):
    """
    Plot the training history of the Conv-RBM.
    
    Args:
        history: Dictionary containing 'epoch_loss' and 'batch_losses'
        figsize: Figure size
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    
    # Plot 1: Loss per epoch
    ax1 = axes[0]
    epochs = range(1, len(history['epoch_loss']) + 1)
    ax1.plot(epochs, history['epoch_loss'], 'b-o', linewidth=2, markersize=8)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Reconstruction Loss (MSE)', fontsize=12)
    ax1.set_title('Average Loss per Epoch', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_xticks(epochs)
    
    # Plot 2: Loss per batch (smoothed)
    ax2 = axes[1]
    batch_losses = history['batch_losses']
    ax2.plot(batch_losses, 'b-', alpha=0.3, linewidth=0.5, label='Raw')
    
    # Add smoothed line (moving average)
    window = min(50, len(batch_losses) // 10)
    if window > 1:
        smoothed = np.convolve(batch_losses, np.ones(window)/window, mode='valid')
        ax2.plot(range(window-1, len(batch_losses)), smoothed, 'r-', 
                 linewidth=2, label=f'Smoothed (window={window})')
    
    ax2.set_xlabel('Batch', fontsize=12)
    ax2.set_ylabel('Reconstruction Loss (MSE)', fontsize=12)
    ax2.set_title('Loss per Batch', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    plt.suptitle('Conv-RBM-1 Training Progress', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    print("\nTraining Summary:")
    print(f"  Initial loss (epoch 1): {history['epoch_loss'][0]:.6f}")
    print(f"  Final loss (epoch {len(history['epoch_loss'])}): {history['epoch_loss'][-1]:.6f}")
    print(f"  Improvement: {(1 - history['epoch_loss'][-1]/history['epoch_loss'][0])*100:.1f}%")


# Plot training history (STEP D: Conditional visualization)
if DEBUG:
    plot_training_history(training_history)
else:
    # Print summary only without plots
    print("\n📊 Training Summary (DEBUG=False, plots disabled):")
    print(f"  Initial loss (epoch 1): {training_history['epoch_loss'][0]:.6f}")
    print(f"  Final loss (epoch {len(training_history['epoch_loss'])}): {training_history['epoch_loss'][-1]:.6f}")
    print(f"  Improvement: {(1 - training_history['epoch_loss'][-1]/training_history['epoch_loss'][0])*100:.1f}%")

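The smoothing in `plot_training_history` is a trailing moving average: `np.convolve(..., mode='valid')` only returns positions where the whole window overlaps the data, so it yields `len(batch_losses) - window + 1` points, and plotting them at x = `window - 1, ..., len - 1` places each average at the last batch it covers. A small standalone check with made-up loss values:

```python
import numpy as np

losses = np.array([4.0, 2.0, 6.0, 4.0, 8.0, 6.0])  # made-up batch losses
window = 3

smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')
x = np.arange(window - 1, len(losses))  # same x-positions as range(window-1, len(...))

# smoothed[k] is the mean of losses[k : k + window], plotted at x[k] = k + window - 1
for xi, s in zip(x, smoothed):
    print(xi, round(float(s), 3))
```

Here the four averages are 4.0, 4.0, 6.0 and 6.0, plotted at batches 2 to 5.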
## STEP 8: Visualization

### 8.1 Reconstruction Quality
Visualize how well the trained Conv-RBM-1 reconstructs OCT images.

In [None]:
# =============================================================================
# STEP 8.1: Reconstruction Quality Visualization
# =============================================================================

def visualize_reconstructions(
    rbm: ConvRBM,
    data_loader: DataLoader,
    trainer: CDTrainer,
    num_samples: int = 8,
    figsize: tuple = (16, 6)
):
    """
    Visualize original OCT images alongside their reconstructions.
    
    Args:
        rbm: Trained ConvRBM
        data_loader: DataLoader to get samples from
        trainer: CDTrainer (for reconstruction method)
        num_samples: Number of samples to display
        figsize: Figure size
    """
    # Get a batch of images (clamp num_samples in case the batch is smaller)
    images, labels = next(iter(data_loader))
    num_samples = min(num_samples, images.shape[0])
    images = images[:num_samples].to(config.DEVICE)
    labels = labels[:num_samples]
    
    # Get reconstructions
    with torch.no_grad():
        reconstructions = trainer.get_reconstruction(images)
    
    # Create visualization (squeeze=False keeps a 2-D axes array even if num_samples == 1)
    fig, axes = plt.subplots(3, num_samples, figsize=figsize, squeeze=False)
    
    for i in range(num_samples):
        # Original image
        orig = images[i].squeeze().cpu().numpy()
        axes[0, i].imshow(orig, cmap='gray', vmin=0, vmax=1)
        axes[0, i].set_title(f'{config.CLASS_NAMES[int(labels[i])]}', fontsize=9)
        
        # Reconstruction (clipped to the displayable [0, 1] range)
        recon = reconstructions[i].squeeze().cpu().numpy()
        recon_clipped = np.clip(recon, 0, 1)
        axes[1, i].imshow(recon_clipped, cmap='gray', vmin=0, vmax=1)
        
        # Absolute difference (error) map
        diff = np.abs(orig - recon_clipped)
        axes[2, i].imshow(diff, cmap='hot', vmin=0, vmax=0.5)
    
    # Hide ticks and frames instead of calling ax.axis('off'),
    # which would also hide the row labels set below
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    
    # Row labels
    axes[0, 0].set_ylabel('Original', fontsize=12, fontweight='bold')
    axes[1, 0].set_ylabel('Reconstructed', fontsize=12, fontweight='bold')
    axes[2, 0].set_ylabel('|Difference|', fontsize=12, fontweight='bold')
    
    plt.suptitle('Conv-RBM-1: Original vs Reconstruction\n(After Unsupervised Pretraining)', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Reconstruction MSE on one full batch (not the whole dataset; with a
    # shuffling loader this is a different batch from the one plotted above)
    with torch.no_grad():
        all_images, _ = next(iter(data_loader))
        all_images = all_images.to(config.DEVICE)
        all_recons = trainer.get_reconstruction(all_images)
        mse = nn.functional.mse_loss(all_recons, all_images).item()
        
    print(f"\nReconstruction Quality Metrics:")
    print(f"  MSE on sample batch: {mse:.6f}")
    print(f"  RMSE: {np.sqrt(mse):.6f}")


# Visualize reconstructions from training set
print("Visualizing reconstructions from training set:")
visualize_reconstructions(conv_rbm_1, train_loader, cd_trainer, num_samples=8)

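`visualize_reconstructions` reports the MSE of a single batch, which changes between calls when the loader shuffles. A sketch of a dataset-level average follows; it sums squared errors and divides by the total number of pixels, so a smaller final batch does not skew the result. `mean_reconstruction_mse` is a hypothetical helper that accepts any callable mapping an image batch to its reconstruction (for example `cd_trainer.get_reconstruction`).

```python
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


@torch.no_grad()
def mean_reconstruction_mse(
    reconstruct: Callable[[torch.Tensor], torch.Tensor],
    loader: DataLoader,
    device: torch.device,
    max_batches: Optional[int] = None,
) -> float:
    """Per-pixel MSE averaged over every sample seen (batches weighted by size)."""
    total_sq_err = 0.0
    total_elems = 0
    for batch_idx, (images, _) in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        images = images.to(device)
        recon = reconstruct(images)
        # Sum (not mean) so batches of different sizes are weighted correctly
        total_sq_err += F.mse_loss(recon, images, reduction='sum').item()
        total_elems += images.numel()
    return total_sq_err / max(total_elems, 1)
```

For example, `mean_reconstruction_mse(cd_trainer.get_reconstruction, val_loader, config.DEVICE, max_batches=50)` would average over at most 50 validation batches; it has to run before the memory-cleanup cell deletes `cd_trainer`.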
In [None]:
# Visualize reconstructions from validation set (unseen during training)
print("\nVisualizing reconstructions from validation set (unseen data):")
visualize_reconstructions(conv_rbm_1, val_loader, cd_trainer, num_samples=8)

### 8.2 Learned Convolutional Filters
Visualize the 32 learned filters (7×7 kernels) from Conv-RBM-1.

These filters represent the features the RBM has learned to detect in OCT images.

In [None]:
# =============================================================================
# STEP 8.2: Learned Convolutional Filters Visualization
# =============================================================================

def visualize_filters(
    rbm: ConvRBM,
    num_filters: int = 16,
    figsize: tuple = (12, 12),
    normalize: bool = True
):
    """
    Visualize learned convolutional filters from the Conv-RBM.
    
    Each filter represents a pattern the RBM has learned to detect.
    For OCT images, these might include edges, textures, layer boundaries, etc.
    
    Args:
        rbm: Trained ConvRBM
        num_filters: Number of filters to display (default: 16)
        figsize: Figure size
        normalize: Whether to normalize each filter for visualization
    """
    # Get weights: [hidden_channels, visible_channels, kernel_h, kernel_w]
    weights = rbm.W.data.cpu().numpy()
    
    # For grayscale (1 channel), squeeze the visible channel dimension
    # weights shape: [32, 1, 7, 7] -> [32, 7, 7]
    if weights.shape[1] == 1:
        weights = weights.squeeze(1)
    
    num_filters = min(num_filters, weights.shape[0])
    grid_size = int(np.ceil(np.sqrt(num_filters)))
    
    fig, axes = plt.subplots(grid_size, grid_size, figsize=figsize)
    axes = axes.flatten()
    
    for i in range(len(axes)):
        ax = axes[i]
        
        if i < num_filters:
            # Get filter
            filt = weights[i]
            
            if normalize:
                # Normalize to [0, 1] for visualization
                filt_min = filt.min()
                filt_max = filt.max()
                if filt_max - filt_min > 1e-8:
                    filt = (filt - filt_min) / (filt_max - filt_min)
                else:
                    filt = np.zeros_like(filt) + 0.5
            
            ax.imshow(filt, cmap='gray')
            ax.set_title(f'Filter {i+1}', fontsize=9)
        
        ax.axis('off')
    
    plt.suptitle(f'Conv-RBM-1 Learned Filters ({num_filters} of {weights.shape[0]})\n'
                 f'Kernel Size: {rbm.kernel_size}×{rbm.kernel_size}',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Print filter statistics
    print("\nFilter Statistics:")
    print(f"  Total filters: {weights.shape[0]}")
    print(f"  Kernel size: {rbm.kernel_size}×{rbm.kernel_size}")
    print(f"  Weight range: [{weights.min():.4f}, {weights.max():.4f}]")
    print(f"  Weight mean: {weights.mean():.4f}")
    print(f"  Weight std: {weights.std():.4f}")


# Visualize first 16 filters (STEP D: Conditional visualization)
if DEBUG:
    print("Visualizing first 16 learned filters:")
    visualize_filters(conv_rbm_1, num_filters=16, figsize=(10, 10))
else:
    print("⏭️  Skipping filter visualization (DEBUG=False)")

In [None]:
# Visualize all 32 filters (STEP D: Conditional visualization)
if DEBUG:
    print("\nVisualizing all 32 learned filters:")
    visualize_filters(conv_rbm_1, num_filters=32, figsize=(12, 12))
else:
    # Print filter statistics without visualization
    weights = conv_rbm_1.W.data.cpu().numpy()
    print("\n📊 Filter Statistics (DEBUG=False, plots disabled):")
    print(f"  Total filters: {weights.shape[0]}")
    print(f"  Kernel size: {conv_rbm_1.kernel_size}×{conv_rbm_1.kernel_size}")
    print(f"  Weight range: [{weights.min():.4f}, {weights.max():.4f}]")
    print(f"  Weight mean: {weights.mean():.4f}")
    print(f"  Weight std: {weights.std():.4f}")

In [None]:
# =============================================================================
# Visualize Hidden Activations for Sample Images
# =============================================================================

def visualize_hidden_activations(
    rbm: ConvRBM,
    images: torch.Tensor,
    num_images: int = 2,
    num_feature_maps: int = 8,
    figsize: tuple = (16, 8)
):
    """
    Visualize hidden layer activations for sample images.
    
    This shows what features the RBM detects in each input image.
    
    Args:
        rbm: Trained ConvRBM
        images: Input images tensor
        num_images: Number of images to show
        num_feature_maps: Number of feature maps to display per image
        figsize: Figure size
    """
    num_images = min(num_images, images.shape[0])
    images = images[:num_images].to(config.DEVICE)
    
    # Get hidden activations
    with torch.no_grad():
        h_prob = rbm.hidden_probabilities(images)
    
    # Create visualization
    fig, axes = plt.subplots(num_images, num_feature_maps + 1, figsize=figsize)
    
    for img_idx in range(num_images):
        # Show original image
        orig = images[img_idx].squeeze().cpu().numpy()
        axes[img_idx, 0].imshow(orig, cmap='gray')
        axes[img_idx, 0].set_title('Input' if img_idx == 0 else '', fontsize=10)
        axes[img_idx, 0].axis('off')
        
        # Show feature map activations
        for fm_idx in range(num_feature_maps):
            activation = h_prob[img_idx, fm_idx].cpu().numpy()
            axes[img_idx, fm_idx + 1].imshow(activation, cmap='viridis')
            if img_idx == 0:
                axes[img_idx, fm_idx + 1].set_title(f'FM {fm_idx+1}', fontsize=9)
            axes[img_idx, fm_idx + 1].axis('off')
    
    plt.suptitle('Hidden Layer Activations (Feature Maps)\n'
                 'Brighter regions = stronger activations',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Get sample images and visualize activations
sample_images, _ = next(iter(train_loader))
print("Visualizing hidden activations for sample images:")
visualize_hidden_activations(conv_rbm_1, sample_images, num_images=3, num_feature_maps=8)

In [None]:
# =============================================================================
# Summary: Conv-RBM-1 Training Results
# =============================================================================

print("\n" + "=" * 70)
print("CONV-RBM-1 PRETRAINING SUMMARY")
print("=" * 70)

print("\nArchitecture:")
print(f"  • Input shape: [{config.BATCH_SIZE}, {config.NUM_CHANNELS}, {config.IMAGE_SIZE[0]}, {config.IMAGE_SIZE[1]}]")
print(f"  • Hidden shape: [{config.BATCH_SIZE}, {CONVRBM1_CONFIG['hidden_channels']}, "
      f"{config.IMAGE_SIZE[0] - CONVRBM1_CONFIG['kernel_size'] + 1}, "
      f"{config.IMAGE_SIZE[1] - CONVRBM1_CONFIG['kernel_size'] + 1}]")
print(f"  • Kernel size: {CONVRBM1_CONFIG['kernel_size']}×{CONVRBM1_CONFIG['kernel_size']}")
print(f"  • Number of filters: {CONVRBM1_CONFIG['hidden_channels']}")

print("\nTraining Configuration:")
print(f"  • Learning rate: {CD_CONFIG['learning_rate']}")
print(f"  • Gibbs steps (k): {CD_CONFIG['k']}")
print(f"  • Momentum: {CD_CONFIG['momentum']}")
print(f"  • Weight decay: {CD_CONFIG['weight_decay']}")
print(f"  • Epochs: {CD_CONFIG['num_epochs']}")

print("\nTraining Results:")
print(f"  • Initial loss: {training_history['epoch_loss'][0]:.6f}")
print(f"  • Final loss: {training_history['epoch_loss'][-1]:.6f}")
improvement = (1 - training_history['epoch_loss'][-1]/training_history['epoch_loss'][0]) * 100
print(f"  • Improvement: {improvement:.1f}%")

print("\nLearned Weights Statistics:")
weights = conv_rbm_1.W.data.cpu().numpy()
print(f"  • Weight range: [{weights.min():.4f}, {weights.max():.4f}]")
print(f"  • Weight mean: {weights.mean():.6f}")
print(f"  • Weight std: {weights.std():.6f}")

print("\n" + "=" * 70)
print("Conv-RBM-1 is ready for use in the CDBN stack!")
print("=" * 70)

---

# Conv-RBM-1 Unsupervised Pretraining Completed

✅ **CDTrainer class** implemented with the Contrastive Divergence (CD-k) algorithm

✅ **Unsupervised pretraining** completed on OCT training images
- Labels were NOT used (purely unsupervised)
- 5 epochs of CD-1 training (`CD_CONFIG['num_epochs']`)
- Manual parameter updates (no PyTorch optimizers)

✅ **Reconstruction quality** verified
- Training set reconstructions visualized
- Validation set reconstructions visualized  
- MSE reconstruction loss tracked

✅ **Learned filters** visualized (when `DEBUG=True`)
- 32 convolutional filters (7×7)
- Filters can be inspected for edge-like and texture-like patterns

✅ **Hidden activations** visualized
- Feature maps show which patterns the RBM responds to
- Responses are spatially localized to input features

---

**Key Components Created:**
- `CDTrainer` class: Implements CD-k with momentum and weight decay
- `conv_rbm_1`: Trained Conv-RBM with 32 filters
- `training_history`: Loss curves for analysis

**Next Steps:**
- STEP 9: Add a probabilistic pooling layer
- STEP 10: Freeze Conv-RBM-1 and prepare pooled features for a second Conv-RBM layer
- STEP 11: Build complete CDBN architecture
- STEP 12: Fine-tune with supervised classifier for OCT classification

In [None]:
# =============================================================================
# STEP E: Memory Cleanup After Conv-RBM-1 Training
# =============================================================================
# IMPORTANT: Free GPU memory between training stages to reduce the risk of
# OOM errors. This cleanup was first written for 16 GB Kaggle GPUs (T4/P100)
# and is kept for the 24 GB local setup with the 5.5 GB dataset.
#
# We delete the CD trainer (which holds velocity buffers) but KEEP:
# - conv_rbm_1: The trained model (needed for inference)
# - training_history: The loss history (for analysis)

print("🧹 Memory Cleanup: Conv-RBM-1 Training Complete")
print("-" * 50)

# Delete the trainer (it holds the momentum velocity buffers and a reference to conv_rbm_1)
del cd_trainer

# Clear CUDA cache to free GPU memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Report memory status
    allocated = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

print("✓ CD Trainer deleted, CUDA cache cleared")
print("✓ conv_rbm_1 retained for next layers")
print("-" * 50)

## STEP 9: Probabilistic Pooling Layer

Implement probabilistic pooling to reduce spatial dimensions while preserving the probabilistic interpretation of the hidden units.

**Key Differences from Max Pooling:**
- **Max pooling:** Takes the maximum value in each region (deterministic; discards the other probabilities in the region)
- **Probabilistic pooling:** Computes the probability that at least one unit in the region is active

**Probabilistic Pooling Rule:**
For a 2×2 pooling region with hidden probabilities $p_1, p_2, p_3, p_4$, treating the four units as independent:

$$P(\text{pool active}) = 1 - \prod_{i=1}^{4}(1 - p_i) \le \min\left(1, \sum_{i=1}^{4} p_i\right)$$

The clipped sum on the right is an upper bound (the union bound) and is close to the exact value when all $p_i$ are small. The default `mode='sum'` uses it for computational efficiency; `mode='prob'` computes the exact product.

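A worked comparison of the three rules on two hypothetical 2×2 regions, one with high and one with low activation probabilities, shows when the clipped sum is a good stand-in for the exact expression. `pool_exact` and `pool_sum` are illustrative helpers, not part of the notebook.

```python
import numpy as np


def pool_exact(p):
    """P(at least one active) = 1 - prod(1 - p_i), assuming independent units."""
    p = np.asarray(p, dtype=float)
    return 1.0 - np.prod(1.0 - p)


def pool_sum(p):
    """Clipped-sum approximation min(1, sum p_i), as used by mode='sum'."""
    return min(1.0, float(np.sum(p)))


for region in ([0.9, 0.8, 0.7, 0.6], [0.05, 0.02, 0.01, 0.01]):
    print(region,
          "max: %.4f" % max(region),
          "exact: %.4f" % pool_exact(region),
          "sum: %.4f" % pool_sum(region))
```

For the high-activity region the rules give 0.9 (max), 0.9976 (exact) and 1.0 (clipped sum); for the low-activity region they give 0.05, about 0.0875 and 0.09. The exact value always lies between the max and the clipped sum, and the sum is tight only when the probabilities are small.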
In [None]:
# =============================================================================
# STEP 9: Probabilistic Pooling Layer
# =============================================================================

class ProbabilisticPooling(nn.Module):
    """
    Probabilistic Pooling Layer for Convolutional Deep Belief Networks.
    
    This layer performs spatial down-sampling of hidden activation probabilities
    while preserving probabilistic semantics, unlike traditional max pooling.
    
    WHY PROBABILISTIC POOLING INSTEAD OF MAX POOLING?
    -------------------------------------------------
    
    1. MAX POOLING (Traditional CNNs):
       - Takes the maximum activation in each pooling region
       - Discards the other activations in the region
       - Not the probability of any pooled event: the max of the p_i is only a
         lower bound on P(at least one unit active)
       - Example: [0.9, 0.8, 0.7, 0.6] → 0.9 (ignores the other high activations)
    
    2. PROBABILISTIC POOLING (CDBNs):
       - Computes the probability that at least one unit in the region is active
       - Preserves the probabilistic interpretation of hidden units
       - Exact formula (assuming independent units):
         P(pool) = 1 - ∏(1 - p_i) over all p_i in the region
       - Approximation: P(pool) ≈ min(1, Σp_i) (clipped sum; an upper bound on
         the exact value, close to it when all p_i are small)
       - Example: [0.9, 0.8, 0.7, 0.6] → min(1.0, 3.0) = 1.0 (exact: 0.9976)
    
    3. KEY DIFFERENCES:
       - Max pooling: sparse, winner-take-all
       - Probabilistic pooling: aggregative, preserves uncertainty
       - Probabilistic pooling fits the probabilistic semantics of generative
         models such as RBMs
       - During fine-tuning, gradients reach every unit in a region (max pooling
         routes them only to the maximum), although in 'sum' mode they vanish
         wherever the sum exceeds 1 and is clipped
    
    Args:
        pool_size (int): Size of the pooling window (default: 2 for 2×2 pooling)
        mode (str): Pooling mode, 'sum' (default) or 'prob' (exact probabilistic)
        
    Input:
        x: Hidden probabilities [batch, channels, H, W]
        
    Output:
        pooled: Pooled probabilities [batch, channels, H//pool_size, W//pool_size]
    """
    
    def __init__(self, pool_size: int = 2, mode: str = 'sum'):
        super(ProbabilisticPooling, self).__init__()
        
        self.pool_size = pool_size
        self.mode = mode
        
        # No learnable parameters in this layer
        # Pooling is a fixed operation that aggregates probabilities
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply probabilistic pooling to input probabilities.
        
        Args:
            x: Input tensor of probabilities [batch, channels, H, W]
               Values should be in range [0, 1]
               
        Returns:
            pooled: Pooled probabilities [batch, channels, H', W']
                    where H' = H // pool_size, W' = W // pool_size
        """
        if self.mode == 'sum':
            # SUM-BASED APPROXIMATION (fast and effective)
            # -----------------------------------------
            # Sum the probabilities in each pooling region and clip to [0, 1]
            # This approximates: P(at least one active) ‚âà sum(p_i)
            # Valid when probabilities are small (sum << 1)
            
            # Use average pooling and multiply by pool_size^2 to get sum
            # Then clip to [0, 1] range
            pooled = nn.functional.avg_pool2d(x, kernel_size=self.pool_size)
            pooled = pooled * (self.pool_size ** 2)  # Convert avg to sum
            pooled = torch.clamp(pooled, 0.0, 1.0)   # Clip to valid probability range
            
        elif self.mode == 'prob':
            # EXACT PROBABILISTIC POOLING
            # ---------------------------
            # P(pool active) = 1 - ∏(1 - p_i) over all p_i in the pooling region
            # This is the exact probability that at least one unit is active,
            # assuming the units in the region are independent
            
            batch, channels, H, W = x.shape
            H_out = H // self.pool_size
            W_out = W // self.pool_size
            
            # Crop trailing rows/columns that do not fill a whole window, matching
            # the floor behaviour of avg_pool2d in 'sum' mode (otherwise the
            # reshape below fails when H or W is not divisible by pool_size)
            x = x[:, :, :H_out * self.pool_size, :W_out * self.pool_size]
            
            # Reshape to extract pooling regions
            # [batch, channels, H, W] -> [batch, channels, H_out, pool_size, W_out, pool_size]
            # (reshape rather than view, since the cropped tensor may be non-contiguous)
            x_reshaped = x.reshape(batch, channels, H_out, self.pool_size, W_out, self.pool_size)
            
            # Compute (1 - p_i) for each element
            one_minus_p = 1.0 - x_reshaped
            
            # Product over the pooling region dimensions (dims 3 and 5)
            # ∏(1 - p_i) over all pool_size × pool_size elements
            prod_term = one_minus_p.prod(dim=3).prod(dim=-1)  # [batch, channels, H_out, W_out]
            
            # P(at least one active) = 1 - ∏(1 - p_i)
            pooled = 1.0 - prod_term
            
        else:
            raise ValueError(f"Unknown pooling mode: {self.mode}. Use 'sum' or 'prob'.")
        
        return pooled
    
    def get_output_size(self, input_size: tuple) -> tuple:
        """
        Calculate output spatial dimensions after pooling.
        
        Args:
            input_size: (H, W) of input feature maps
            
        Returns:
            output_size: (H // pool_size, W // pool_size)
        """
        H, W = input_size
        return (H // self.pool_size, W // self.pool_size)
    
    def __repr__(self):
        return f"ProbabilisticPooling(pool_size={self.pool_size}, mode='{self.mode}')"


print("✓ ProbabilisticPooling class defined successfully!")

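The `'prob'` branch multiplies the `pool_size²` factors directly. As an alternative sketch (not what `ProbabilisticPooling` does), the same product can be computed in log space with `avg_pool2d`, because ∏(1 - p_i) = exp(pool_size² · mean(log(1 - p_i))). This reuses the floor behaviour of `avg_pool2d` for odd-sized maps, and computing 1 - exp(·) with `expm1` keeps precision when the pooled probability is tiny. `exact_prob_pool_logspace` and `exact_prob_pool_loop` are hypothetical names; if gradients are taken through this path, clamping p slightly below 1 may be needed, since log(1 - 1) is -inf.

```python
import torch
import torch.nn.functional as F


def exact_prob_pool_logspace(p: torch.Tensor, pool_size: int = 2) -> torch.Tensor:
    """1 - prod(1 - p_i) over non-overlapping pool_size x pool_size windows.

    p: probabilities in [0, 1], shape [batch, channels, H, W]. Trailing rows or
    columns that do not fill a whole window are dropped (avg_pool2d floors).
    """
    log_one_minus_p = torch.log1p(-p)                   # log(1 - p_i) <= 0
    mean_log = F.avg_pool2d(log_one_minus_p, kernel_size=pool_size)
    log_prod = mean_log * (pool_size ** 2)              # sum of log(1 - p_i)
    return -torch.expm1(log_prod)                       # 1 - exp(sum) = 1 - prod


def exact_prob_pool_loop(p: torch.Tensor, pool_size: int = 2) -> torch.Tensor:
    """Slow reference implementation with explicit loops over windows."""
    B, C, H, W = p.shape
    Ho, Wo = H // pool_size, W // pool_size
    out = torch.empty(B, C, Ho, Wo, dtype=p.dtype)
    for i in range(Ho):
        for j in range(Wo):
            win = p[:, :, i * pool_size:(i + 1) * pool_size,
                    j * pool_size:(j + 1) * pool_size]
            out[:, :, i, j] = 1.0 - (1.0 - win).flatten(2).prod(dim=-1)
    return out


torch.manual_seed(0)
p = torch.rand(2, 3, 5, 7, dtype=torch.float64)  # odd sizes on purpose
fast = exact_prob_pool_logspace(p)
slow = exact_prob_pool_loop(p)
print(fast.shape, torch.allclose(fast, slow))
```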
In [None]:
# -----------------------------------------------------------------------------
# Test Probabilistic Pooling
# -----------------------------------------------------------------------------

print("=" * 70)
print("PROBABILISTIC POOLING VERIFICATION")
print("=" * 70)

# Create pooling layer
prob_pool = ProbabilisticPooling(pool_size=2, mode='sum')
print(f"\nPooling Layer: {prob_pool}")

# Test with sample hidden activations from Conv-RBM-1
with torch.no_grad():
    # Get a batch of images
    test_images, _ = next(iter(train_loader))
    test_images = test_images.to(config.DEVICE)
    
    # Get hidden probabilities from Conv-RBM-1
    h1_prob = conv_rbm_1.hidden_probabilities(test_images)
    
    # Apply probabilistic pooling
    h1_pooled = prob_pool(h1_prob)

print(f"\nShape Flow:")
print(f"  Input images        : {test_images.shape}")
print(f"  Conv-RBM-1 hidden   : {h1_prob.shape}")
print(f"  After pooling (2×2) : {h1_pooled.shape}")

# Verify dimensions
expected_pooled_h = h1_prob.shape[2] // 2
expected_pooled_w = h1_prob.shape[3] // 2
assert h1_pooled.shape == (test_images.shape[0], CONVRBM1_CONFIG['hidden_channels'],
                           expected_pooled_h, expected_pooled_w), \
    "Pooling output shape mismatch!"
print("\n✓ Pooling shape verification passed!")

# Check value ranges
print(f"\nValue Statistics:")
print(f"  Hidden probs range  : [{h1_prob.min().item():.4f}, {h1_prob.max().item():.4f}]")
print(f"  Pooled probs range  : [{h1_pooled.min().item():.4f}, {h1_pooled.max().item():.4f}]")
print(f"  Pooled values in [0,1]: {((h1_pooled >= 0) & (h1_pooled <= 1)).all().item()}")

print("=" * 70)
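For reference, one widely used formulation of probabilistic pooling is the probabilistic max-pooling of Lee et al. (2009): within each non-overlapping 2×2 block at most one detection unit may be on, and the pooling unit is on if any of them is. The NumPy sketch below implements that formulation from hidden pre-activations (bottom-up input before the sigmoid). It is an illustrative assumption, not the notebook's `ProbabilisticPooling` class, whose `mode='sum'` behaviour may differ; under this formulation pooled values always fall strictly inside (0, 1).

```python
import numpy as np


def prob_max_pool(pre_act: np.ndarray, pool_size: int = 2):
    """Probabilistic max-pooling (Lee et al., 2009) over non-overlapping blocks.

    pre_act: hidden pre-activations (bottom-up input, before the sigmoid),
             shape [C, H, W]; H and W are cropped to multiples of pool_size.
    Returns (detection_probs [C, H', W'], pooled_probs [C, H'/p, W'/p]).
    """
    C, H, W = pre_act.shape
    Hp, Wp = H // pool_size, W // pool_size
    p2 = pool_size * pool_size
    x = pre_act[:, :Hp * pool_size, :Wp * pool_size]
    # Group each block: [C, Hp, p, Wp, p] -> [C, Hp, Wp, p, p] -> [C, Hp, Wp, p*p]
    blocks = x.reshape(C, Hp, pool_size, Wp, pool_size).transpose(0, 1, 3, 2, 4)
    blocks = blocks.reshape(C, Hp, Wp, p2)
    # Softmax over the p*p units plus one implicit "all units off" state (energy 0);
    # subtracting m = max(block max, 0) keeps exp() from overflowing.
    m = np.maximum(blocks.max(axis=-1, keepdims=True), 0.0)
    e = np.exp(blocks - m)
    denom = e.sum(axis=-1, keepdims=True) + np.exp(-m)
    det = e / denom                 # P(detection unit on | v), at most one per block
    pooled = det.sum(axis=-1)       # P(pooling unit on) = 1 - P(all units off)
    # Restore the spatial layout of the detection probabilities
    det = det.reshape(C, Hp, Wp, pool_size, pool_size).transpose(0, 1, 3, 2, 4)
    return det.reshape(C, Hp * pool_size, Wp * pool_size), pooled


rng = np.random.default_rng(0)
det, pooled = prob_max_pool(rng.normal(size=(32, 6, 6)))
print(det.shape, pooled.shape)  # (32, 6, 6) (32, 3, 3)
```

With all-zero inputs, each of the four units gets probability 1/5 and the pooling unit 4/5, which makes the "extra off state" in the softmax easy to see.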

## STEP 10: Freeze Conv-RBM-1 and Prepare Pooled Features

Freeze Conv-RBM-1 parameters and compute pooled hidden features for training Conv-RBM-2.
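Whether to precompute pooled features or compute them on the fly is mostly arithmetic: a full cache holds N × C × H' × W' float32 values. The helper below is a hypothetical back-of-envelope estimator assuming a "valid" stride-1 convolution followed by non-overlapping pooling; the image count, image size, and kernel size in the example are placeholders, not values read from this notebook's config.

```python
def pooled_cache_bytes(num_images: int, channels: int, img_h: int, img_w: int,
                       kernel_size: int, pool_size: int = 2,
                       bytes_per_value: int = 4) -> int:
    """Bytes needed to cache pooled Conv-RBM-1 features for a whole dataset.

    Assumes a 'valid' stride-1 convolution (H - k + 1) followed by
    non-overlapping pool_size x pool_size pooling, stored as float32 by default.
    """
    conv_h = img_h - kernel_size + 1
    conv_w = img_w - kernel_size + 1
    pooled_h, pooled_w = conv_h // pool_size, conv_w // pool_size
    return num_images * channels * pooled_h * pooled_w * bytes_per_value


# Hypothetical example: 80,000 grayscale 128x128 images, 32 filters of size 7x7
n_bytes = pooled_cache_bytes(80_000, channels=32, img_h=128, img_w=128, kernel_size=7)
print(f"{n_bytes / 1024**3:.1f} GiB")  # 35.5 GiB
```

Under these assumptions the cache is roughly 7 times the float32 size of the input images themselves, which is consistent with the notebook's choice to stream features through the frozen Conv-RBM-1 inside `train_convrbm2` rather than caching them.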

In [None]:
# =============================================================================
# STEP 10: Freeze Conv-RBM-1 and Prepare Pooled Features
# =============================================================================

# -----------------------------------------------------------------------------
# 10.1 Freeze Conv-RBM-1 Parameters
# -----------------------------------------------------------------------------
# After unsupervised pretraining, we freeze the first layer to preserve
# the learned features while training subsequent layers.

print("=" * 70)
print("FREEZING CONV-RBM-1 AND PREPARING FEATURES FOR CONV-RBM-2")
print("=" * 70)

# Freeze all parameters in Conv-RBM-1
for param in conv_rbm_1.parameters():
    param.requires_grad = False

print("\n✓ Conv-RBM-1 parameters frozen:")
for name, param in conv_rbm_1.named_parameters():
    print(f"  {name}: requires_grad = {param.requires_grad}")

# Set to evaluation mode (only affects layers such as dropout/batchnorm; ConvRBM has none)
conv_rbm_1.eval()
print("\n✓ Conv-RBM-1 set to evaluation mode")

In [None]:
# -----------------------------------------------------------------------------
# 10.2 Create Feature Extraction Pipeline
# -----------------------------------------------------------------------------

class FeatureExtractor:
    """
    Extracts pooled hidden features from Conv-RBM-1 for training Conv-RBM-2.
    
    Pipeline: Input Image → Conv-RBM-1 → Hidden Probabilities → Pooling → Features
    
    This class can either:
    1. Compute features on-the-fly (memory efficient, slower)
    2. Precompute all features (faster training, but the full feature cache
       can exceed host RAM for a dataset of this size)
    """
    
    def __init__(
        self,
        conv_rbm: ConvRBM,
        pooling: ProbabilisticPooling,
        device: torch.device
    ):
        self.conv_rbm = conv_rbm
        self.pooling = pooling
        self.device = device
        
        # Ensure RBM is in eval mode
        self.conv_rbm.eval()
        
    @torch.no_grad()
    def extract_features(self, images: torch.Tensor) -> torch.Tensor:
        """
        Extract pooled features from a batch of images.
        
        Args:
            images: Input images [batch, 1, H, W]
            
        Returns:
            pooled_features: Pooled hidden probabilities [batch, hidden_channels, H', W']
        """
        images = images.to(self.device)
        
        # Step 1: Get hidden probabilities from Conv-RBM-1
        h_prob = self.conv_rbm.hidden_probabilities(images)
        
        # Step 2: Apply probabilistic pooling
        h_pooled = self.pooling(h_prob)
        
        return h_pooled
    
    @torch.no_grad()
    def precompute_all_features(self, data_loader: DataLoader) -> torch.Tensor:
        """
        Precompute pooled features for the entire dataset.
        
        This is faster for training but requires more memory.
        
        Args:
            data_loader: DataLoader for the dataset
            
        Returns:
            all_features: Tensor of all pooled features [N, channels, H', W']
        """
        all_features = []
        
        print("Precomputing pooled features...")
        for batch_idx, (images, _) in enumerate(data_loader):
            features = self.extract_features(images)
            all_features.append(features.cpu())
            
            if (batch_idx + 1) % 50 == 0:
                print(f"  Processed {batch_idx + 1}/{len(data_loader)} batches")
        
        all_features = torch.cat(all_features, dim=0)
        print(f"✓ Precomputed {all_features.shape[0]} feature maps")
        
        return all_features


# Create feature extractor
feature_extractor = FeatureExtractor(
    conv_rbm=conv_rbm_1,
    pooling=prob_pool,
    device=config.DEVICE
)

print("\n✓ FeatureExtractor created")

In [None]:
# -----------------------------------------------------------------------------
# 10.3 Compute Pooled Feature Dimensions
# -----------------------------------------------------------------------------

# Calculate the dimensions for Conv-RBM-2 input
with torch.no_grad():
    sample_images, _ = next(iter(train_loader))
    sample_images = sample_images.to(config.DEVICE)
    
    # Pass through Conv-RBM-1
    sample_h1 = conv_rbm_1.hidden_probabilities(sample_images)
    
    # Pass through pooling
    sample_pooled = prob_pool(sample_h1)

print("\n" + "-" * 50)
print("FEATURE DIMENSIONS FOR CONV-RBM-2")
print("-" * 50)
print(f"\nInput to Conv-RBM-1:")
print(f"  Shape: {sample_images.shape}")
print(f"  [batch={sample_images.shape[0]}, channels={sample_images.shape[1]}, "
      f"H={sample_images.shape[2]}, W={sample_images.shape[3]}]")

print(f"\nConv-RBM-1 Hidden (before pooling):")
print(f"  Shape: {sample_h1.shape}")
print(f"  [batch={sample_h1.shape[0]}, feature_maps={sample_h1.shape[1]}, "
      f"H={sample_h1.shape[2]}, W={sample_h1.shape[3]}]")

print(f"\nPooled Features (input to Conv-RBM-2):")
print(f"  Shape: {sample_pooled.shape}")
print(f"  [batch={sample_pooled.shape[0]}, feature_maps={sample_pooled.shape[1]}, "
      f"H={sample_pooled.shape[2]}, W={sample_pooled.shape[3]}]")

# Store dimensions for Conv-RBM-2 configuration
POOLED_CHANNELS = sample_pooled.shape[1]
POOLED_HEIGHT = sample_pooled.shape[2]
POOLED_WIDTH = sample_pooled.shape[3]

print(f"\n✓ Conv-RBM-2 will receive {POOLED_CHANNELS} input channels")
print(f"✓ Spatial dimensions: {POOLED_HEIGHT}×{POOLED_WIDTH}")
print("=" * 70)

## STEP 11: Conv-RBM-2 Module

Create the second Conv-RBM layer that learns higher-level features from the pooled representations.
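Conv-RBM-2 computes hidden probabilities the same way as Conv-RBM-1, only over 32 input channels instead of one. Assuming the usual binary-hidden Conv-RBM formulation, P(h_k = 1 | v) = σ(Σ_c (v_c ⋆ W_{k,c}) + b_k), where ⋆ is a "valid" cross-correlation. The NumPy sketch below is illustrative only: the weight layout `[hidden, visible, k, k]` matches the one noted in the filter-visualization code, but the internals of `ConvRBM.hidden_probabilities` are not shown here and may differ (for example in bias handling).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv_rbm_hidden_probs(v: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """P(h=1|v) for a binary-hidden Conv-RBM (illustrative sketch).

    v: visible maps [N, C_in, H, W]
    W: filters      [C_out, C_in, k, k]
    b: hidden bias  [C_out] (one shared bias per feature map)
    Returns:        [N, C_out, H - k + 1, W - k + 1]
    """
    k = W.shape[-1]
    # Sliding k x k patches: [N, C_in, H-k+1, W-k+1, k, k]
    patches = sliding_window_view(v, (k, k), axis=(2, 3))
    # Valid cross-correlation, summed over input channels and kernel positions
    pre = np.einsum('nchwij,ocij->nohw', patches, W) + b[None, :, None, None]
    return 1.0 / (1.0 + np.exp(-pre))


rng = np.random.default_rng(0)
v = rng.random((2, 32, 12, 12))                 # toy stand-in for pooled Conv-RBM-1 output
W = 0.01 * rng.standard_normal((64, 32, 5, 5))  # 64 filters, 32 input channels, 5x5
h = conv_rbm_hidden_probs(v, W, np.zeros(64))
print(h.shape)  # (2, 64, 8, 8)
```

The spatial size shrinks by k - 1 (12 - 5 + 1 = 8), which is the same rule the configuration cell below uses to compute `conv_rbm2_output_h` and `conv_rbm2_output_w`.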

In [None]:
# =============================================================================
# STEP 11: Conv-RBM-2 Module
# =============================================================================

# -----------------------------------------------------------------------------
# Conv-RBM-2 Configuration
# -----------------------------------------------------------------------------

CONVRBM2_CONFIG = {
    'visible_channels': POOLED_CHANNELS,  # 32 (from Conv-RBM-1 pooled output)
    'hidden_channels': 64,                 # More feature maps for higher-level features
    'kernel_size': 5,                      # 5×5 kernels (smaller than layer 1)
}

print("=" * 70)
print("CONV-RBM-2 CONFIGURATION")
print("=" * 70)
print("\nArchitecture:")
for key, value in CONVRBM2_CONFIG.items():
    print(f"  {key:<20}: {value}")

# Calculate output dimensions
conv_rbm2_output_h = POOLED_HEIGHT - CONVRBM2_CONFIG['kernel_size'] + 1
conv_rbm2_output_w = POOLED_WIDTH - CONVRBM2_CONFIG['kernel_size'] + 1

print(f"\nExpected output dimensions:")
print(f"  Input  : [batch, {CONVRBM2_CONFIG['visible_channels']}, {POOLED_HEIGHT}, {POOLED_WIDTH}]")
print(f"  Output : [batch, {CONVRBM2_CONFIG['hidden_channels']}, {conv_rbm2_output_h}, {conv_rbm2_output_w}]")
print("=" * 70)

In [None]:
# -----------------------------------------------------------------------------
# Instantiate Conv-RBM-2
# -----------------------------------------------------------------------------

# Reuse the same ConvRBM class defined earlier
conv_rbm_2 = ConvRBM(
    visible_channels=CONVRBM2_CONFIG['visible_channels'],
    hidden_channels=CONVRBM2_CONFIG['hidden_channels'],
    kernel_size=CONVRBM2_CONFIG['kernel_size']
).to(config.DEVICE)

print("Conv-RBM-2 Architecture:")
print(conv_rbm_2)

# Print parameter count
total_params_rbm2 = sum(p.numel() for p in conv_rbm_2.parameters())
print(f"\nTotal parameters in Conv-RBM-2: {total_params_rbm2:,}")

# Verify with a forward pass
with torch.no_grad():
    test_h2_prob, test_h2_sample = conv_rbm_2(sample_pooled)

print(f"\nVerification forward pass:")
print(f"  Input  : {sample_pooled.shape}")
print(f"  Output : {test_h2_prob.shape}")

expected_shape = (sample_pooled.shape[0], CONVRBM2_CONFIG['hidden_channels'], 
                  conv_rbm2_output_h, conv_rbm2_output_w)
assert test_h2_prob.shape == expected_shape, f"Shape mismatch: {test_h2_prob.shape} vs {expected_shape}"
print(f"✓ Forward pass shape verification passed!")

## STEP 12: Unsupervised Pretraining of Conv-RBM-2

Train Conv-RBM-2 using CD-k on the pooled features from Conv-RBM-1.

**Training Pipeline:**
```
Image → Conv-RBM-1 (frozen) → Pool → Conv-RBM-2 (train) → Hidden Features
```
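The `CDTrainer` internals are defined earlier and not repeated here. As a reminder of what a single CD-1 update with momentum and L2 weight decay does, here is a minimal NumPy sketch for a small fully connected binary RBM. It is a simplified dense analogue written under stated assumptions, not the notebook's convolutional trainer: the positive phase uses P(h | v) on the data, the negative phase runs one Gibbs step, and the reconstruction MSE is returned only as a monitoring signal (the same kind of loss the training loop reports).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cd1_step(v0, W, a, b, vel, lr=0.005, momentum=0.5, weight_decay=1e-4, rng=None):
    """One CD-1 update for a dense binary RBM (illustrative sketch).

    v0: batch of visible vectors [N, V] with values in [0, 1]
    W:  weights [V, H]; a: visible bias [V]; b: hidden bias [H]
    vel: momentum buffer for W, same shape as W
    W, a, b and vel are updated in place. Returns the reconstruction MSE.
    """
    if rng is None:
        rng = np.random.default_rng()
    # Positive phase: hidden probabilities and a binary sample given the data
    h0_prob = sigmoid(v0 @ W + b)
    h0 = (rng.random(h0_prob.shape) < h0_prob).astype(v0.dtype)
    # Negative phase: one Gibbs step v0 -> h0 -> v1 -> h1 (mean-field for v1, h1)
    v1 = sigmoid(h0 @ W.T + a)
    h1_prob = sigmoid(v1 @ W + b)
    n = v0.shape[0]
    # <v h>_data - <v h>_recon, minus the L2 weight-decay term
    grad_W = (v0.T @ h0_prob - v1.T @ h1_prob) / n - weight_decay * W
    vel *= momentum
    vel += lr * grad_W
    W += vel
    a += lr * (v0 - v1).mean(axis=0)
    b += lr * (h0_prob - h1_prob).mean(axis=0)
    return float(((v0 - v1) ** 2).mean())


rng = np.random.default_rng(0)
data = (rng.random((64, 16)) < 0.3).astype(float)   # toy binary data
W = 0.01 * rng.standard_normal((16, 8))
a, b, vel = np.zeros(16), np.zeros(8), np.zeros((16, 8))
losses = [cd1_step(data, W, a, b, vel, lr=0.1, rng=rng) for _ in range(200)]
print(f"reconstruction MSE: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

The defaults mirror `CD2_CONFIG` (learning rate 0.005, momentum 0.5, weight decay 1e-4); the toy run uses a larger learning rate so the drop in reconstruction error is visible within a few hundred steps.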

In [None]:
# =============================================================================
# STEP 12: Unsupervised Pretraining of Conv-RBM-2
# =============================================================================

# -----------------------------------------------------------------------------
# CD Training Configuration for Conv-RBM-2
# -----------------------------------------------------------------------------

CD2_CONFIG = {
    'learning_rate': 0.005,    # Slightly lower LR for second layer
    'k': 1,                    # CD-1
    'momentum': 0.5,           # Momentum
    'weight_decay': 0.0001,    # L2 regularization
    'num_epochs': 5,           # Reduced from 10 (originally for Kaggle time limits)
}

print("=" * 70)
print("CONV-RBM-2 TRAINING CONFIGURATION")
print("=" * 70)
for key, value in CD2_CONFIG.items():
    print(f"  {key:<20}: {value}")
print("=" * 70)

# Initialize CD Trainer for Conv-RBM-2
cd_trainer_2 = CDTrainer(
    rbm=conv_rbm_2,
    learning_rate=CD2_CONFIG['learning_rate'],
    k=CD2_CONFIG['k'],
    device=config.DEVICE,
    momentum=CD2_CONFIG['momentum'],
    weight_decay=CD2_CONFIG['weight_decay']
)

print("\n✓ CD Trainer for Conv-RBM-2 initialized")

In [None]:
# -----------------------------------------------------------------------------
# Training Loop for Conv-RBM-2 (Memory-Optimized for Local 24GB GPU)
# -----------------------------------------------------------------------------

def train_convrbm2(
    conv_rbm_1: ConvRBM,
    conv_rbm_2: ConvRBM,
    pooling: ProbabilisticPooling,
    trainer: CDTrainer,
    train_loader: DataLoader,
    num_epochs: int,
    device: torch.device,
    memory_cleanup_freq: int = 100
) -> dict:
    """
    Train Conv-RBM-2 on pooled features from Conv-RBM-1 (Memory-Optimized).
    
    Pipeline per batch:
        1. Get images from dataloader (ignore labels - unsupervised)
        2. Extract hidden probabilities from frozen Conv-RBM-1
        3. Apply probabilistic pooling
        4. Train Conv-RBM-2 on pooled features using CD-k
    
    Args:
        conv_rbm_1: Pretrained and frozen Conv-RBM-1
        conv_rbm_2: Conv-RBM-2 to train
        pooling: ProbabilisticPooling layer
        trainer: CDTrainer for Conv-RBM-2
        train_loader: DataLoader for training images
        num_epochs: Number of training epochs
        device: Computation device
        memory_cleanup_freq: Frequency of memory cleanup (batches)
        
    Returns:
        history: Training history dictionary
    """
    history = {
        'epoch_loss': [],
        'batch_losses': [],
    }
    
    # Ensure Conv-RBM-1 is frozen and in eval mode
    conv_rbm_1.eval()
    for param in conv_rbm_1.parameters():
        param.requires_grad = False
    
    print("\n" + "=" * 70)
    print("STARTING CONV-RBM-2 UNSUPERVISED PRETRAINING (Memory-Optimized)")
    print("=" * 70)
    print(f"Training on pooled features from {len(train_loader.dataset)} images")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Batches per epoch: {len(train_loader)}")
    print(f"Memory cleanup frequency: every {memory_cleanup_freq} batches")
    if torch.cuda.is_available():
        print_gpu_memory()
    print("-" * 70)
    
    for epoch in range(num_epochs):
        epoch_losses = []
        epoch_start = time.time()
        
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            
            # -------------------------------------------------------------
            # Step 1: Extract pooled features from Conv-RBM-1 (frozen)
            # -------------------------------------------------------------
            with torch.no_grad():
                # Get hidden probabilities from Conv-RBM-1
                h1_prob = conv_rbm_1.hidden_probabilities(images)
                
                # Apply probabilistic pooling
                h1_pooled = pooling(h1_prob)
            
            # -------------------------------------------------------------
            # Step 2: Train Conv-RBM-2 on pooled features
            # -------------------------------------------------------------
            # h1_pooled is now the "visible" input to Conv-RBM-2
            batch_loss = trainer.train_batch(h1_pooled)
            epoch_losses.append(batch_loss)
            
            # Explicit memory cleanup
            del images, h1_prob, h1_pooled
            
            # Progress update
            if (batch_idx + 1) % 50 == 0:
                print(f"  Epoch {epoch+1}/{num_epochs} | "
                      f"Batch {batch_idx+1}/{len(train_loader)} | "
                      f"Loss: {batch_loss:.6f}")
            
            # Aggressive memory cleanup for large datasets
            if torch.cuda.is_available() and (batch_idx + 1) % memory_cleanup_freq == 0:
                torch.cuda.empty_cache()
        
        epoch_time = time.time() - epoch_start
        
        # Epoch statistics
        avg_epoch_loss = np.mean(epoch_losses)
        history['epoch_loss'].append(avg_epoch_loss)
        history['batch_losses'].extend(epoch_losses)
        
        # Clear batch losses to free memory
        epoch_losses.clear()
        
        # Print epoch summary with GPU memory status
        print(f"Epoch {epoch+1:2d}/{num_epochs} | "
              f"Avg Loss: {avg_epoch_loss:.6f} | "
              f"Time: {epoch_time:.1f}s", end="")
        
        if torch.cuda.is_available():
            mem_alloc = torch.cuda.memory_allocated() / 1024**3
            print(f" | GPU: {mem_alloc:.2f} GB")
        else:
            print()
        
        # Clear cache after each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    print("-" * 70)
    print("CONV-RBM-2 PRETRAINING COMPLETED!")
    print(f"Final reconstruction loss: {history['epoch_loss'][-1]:.6f}")
    if torch.cuda.is_available():
        print_gpu_memory()
    print("=" * 70)
    
    return history


# Train Conv-RBM-2
print("Starting Conv-RBM-2 pretraining on pooled features...")
training_history_2 = train_convrbm2(
    conv_rbm_1=conv_rbm_1,
    conv_rbm_2=conv_rbm_2,
    pooling=prob_pool,
    trainer=cd_trainer_2,
    train_loader=train_loader,
    num_epochs=CD2_CONFIG['num_epochs'],
    device=config.DEVICE,
    memory_cleanup_freq=100
)

In [None]:
# -----------------------------------------------------------------------------
# Plot Conv-RBM-2 Training Progress
# -----------------------------------------------------------------------------

def plot_training_comparison(history1: dict, history2: dict, figsize=(14, 5)):
    """
    Compare training progress of Conv-RBM-1 and Conv-RBM-2.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    
    # Plot 1: Conv-RBM-2 training loss
    ax1 = axes[0]
    epochs = range(1, len(history2['epoch_loss']) + 1)
    ax1.plot(epochs, history2['epoch_loss'], 'g-o', linewidth=2, markersize=8)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Reconstruction Loss (MSE)', fontsize=12)
    ax1.set_title('Conv-RBM-2 Training Loss', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_xticks(epochs)
    
    # Plot 2: Comparison of both layers
    ax2 = axes[1]
    epochs1 = range(1, len(history1['epoch_loss']) + 1)
    epochs2 = range(1, len(history2['epoch_loss']) + 1)
    ax2.plot(epochs1, history1['epoch_loss'], 'b-o', linewidth=2, markersize=6, label='Conv-RBM-1')
    ax2.plot(epochs2, history2['epoch_loss'], 'g-s', linewidth=2, markersize=6, label='Conv-RBM-2')
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Reconstruction Loss (MSE)', fontsize=12)
    ax2.set_title('Training Comparison', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    plt.suptitle('CDBN Layer-wise Pretraining Progress', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
    
    # Print comparison
    print("\nTraining Comparison:")
    print(f"  Conv-RBM-1: {history1['epoch_loss'][0]:.6f} → {history1['epoch_loss'][-1]:.6f} "
          f"({(1 - history1['epoch_loss'][-1]/history1['epoch_loss'][0])*100:.1f}% improvement)")
    print(f"  Conv-RBM-2: {history2['epoch_loss'][0]:.6f} → {history2['epoch_loss'][-1]:.6f} "
          f"({(1 - history2['epoch_loss'][-1]/history2['epoch_loss'][0])*100:.1f}% improvement)")


# Plot training comparison
plot_training_comparison(training_history, training_history_2)

In [None]:
# =============================================================================
# STEP E: Memory Cleanup After Conv-RBM-2 Pretraining
# =============================================================================
# Free GPU memory by deleting the CD trainer, which holds momentum (velocity)
# buffers that are no longer needed. Originally added for Kaggle's 16 GB GPU
# limit; on a 24 GB local GPU it still releases memory before later stages.

# Delete the Conv-RBM-2 trainer
del cd_trainer_2

# Clear CUDA cache to release fragmented memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"✓ Conv-RBM-2 trainer deleted, CUDA cache cleared")
    print(f"  GPU memory allocated: {allocated:.2f} GB, reserved: {reserved:.2f} GB")
else:
    print("✓ Conv-RBM-2 trainer deleted (CPU mode)")

print("  Conv-RBM-2 weights are preserved in conv_rbm_2 module")
print("=" * 70)

## STEP 13: Visualization

### 13.1 Conv-RBM-2 Learned Filters
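Visualize the learned 5×5 filters of Conv-RBM-2 (the first 16 of its 64 filters are shown).

Note: each filter now spans 32 input channels (the pooled Conv-RBM-1 features), so it cannot be shown as a single grayscale patch. We therefore show both the filter energy (L2 norm across input channels) and individual channel slices.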
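The filter energy is the per-pixel L2 norm across input channels, E_k(i, j) = sqrt(Σ_c W_{k,c,i,j}²), followed by min-max scaling for display. The short NumPy sketch below isolates just that computation, vectorized over all filters at once; it assumes weights laid out as `[hidden, visible, k, k]` and uses an epsilon in the denominator, as Part 2 of the visualization function does.

```python
import numpy as np


def filter_energy_maps(weights: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Per-filter L2 norm across input channels, min-max scaled to [0, 1].

    weights: [C_out, C_in, k, k]  ->  returns [C_out, k, k]
    """
    energy = np.sqrt((weights ** 2).sum(axis=1))         # collapse input channels
    lo = energy.min(axis=(1, 2), keepdims=True)          # per-filter minimum
    hi = energy.max(axis=(1, 2), keepdims=True)          # per-filter maximum
    return (energy - lo) / (hi - lo + eps)


rng = np.random.default_rng(0)
maps = filter_energy_maps(rng.standard_normal((64, 32, 5, 5)))
print(maps.shape)  # (64, 5, 5)
```

Because the norm discards sign, an energy map shows where a filter is strong, not whether it responds to bright or dark structure; the per-channel slices fill in that information.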

In [None]:
# =============================================================================
# STEP 13.1: Conv-RBM-2 Learned Filters Visualization
# =============================================================================

def visualize_multilayer_filters(
    rbm: ConvRBM,
    num_filters: int = 16,
    num_channels_to_show: int = 4,
    figsize: tuple = (14, 10),
    layer_name: str = "Conv-RBM-2"
):
    """
    Visualize filters from a multi-channel Conv-RBM.
    
    For Conv-RBM-2, each filter has shape [visible_channels, kernel_h, kernel_w].
    We show multiple visualizations:
    1. Filter energy (L2 norm across input channels)
    2. Individual channel slices for selected filters
    
    Args:
        rbm: Trained ConvRBM
        num_filters: Number of filters to display
        num_channels_to_show: Number of input channel slices to show per filter
        figsize: Figure size
        layer_name: Name for the title
    """
    # Get weights: [hidden_channels, visible_channels, kernel_h, kernel_w]
    weights = rbm.W.data.cpu().numpy()
    
    num_filters = min(num_filters, weights.shape[0])
    num_channels = weights.shape[1]
    
    # -------------------------------------------------------------------------
    # Part 1: Filter Energy Visualization (L2 norm across input channels)
    # -------------------------------------------------------------------------
    grid_size = int(np.ceil(np.sqrt(num_filters)))
    
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
    axes = axes.flatten()
    
    for i in range(len(axes)):
        ax = axes[i]
        
        if i < num_filters:
            # Compute L2 norm across input channels: sqrt(sum over channels of W^2)
            # This gives a single 2D visualization per filter
            filter_energy = np.sqrt(np.sum(weights[i] ** 2, axis=0))
            
            # Normalize for visualization
            fe_min, fe_max = filter_energy.min(), filter_energy.max()
            if fe_max - fe_min > 1e-8:
                filter_energy = (filter_energy - fe_min) / (fe_max - fe_min)
            
            ax.imshow(filter_energy, cmap='viridis')
            ax.set_title(f'F{i+1}', fontsize=9)
        
        ax.axis('off')
    
    plt.suptitle(f'{layer_name} Filter Energies (L2 norm across {num_channels} input channels)\n'
                 f'Kernel Size: {rbm.kernel_size}×{rbm.kernel_size}',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # -------------------------------------------------------------------------
    # Part 2: Individual Channel Slices for Top Filters
    # -------------------------------------------------------------------------
    num_filters_detail = min(8, num_filters)
    num_ch = min(num_channels_to_show, num_channels)
    
    # squeeze=False keeps axes 2-D even when only one filter row is shown
    fig, axes = plt.subplots(num_filters_detail, num_ch + 1, figsize=figsize, squeeze=False)
    
    for f_idx in range(num_filters_detail):
        # First column: filter energy
        filter_energy = np.sqrt(np.sum(weights[f_idx] ** 2, axis=0))
        fe_norm = (filter_energy - filter_energy.min()) / (filter_energy.max() - filter_energy.min() + 1e-8)
        axes[f_idx, 0].imshow(fe_norm, cmap='viridis')
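        axes[f_idx, 0].set_title('Energy' if f_idx == 0 else '', fontsize=9)
        axes[f_idx, 0].set_ylabel(f'F{f_idx+1}', fontsize=10)
        # Hide ticks only: axis('off') would also hide the filter-index y-label
        axes[f_idx, 0].set_xticks([])
        axes[f_idx, 0].set_yticks([])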
        
        # Remaining columns: individual input channel slices
        for ch_idx in range(num_ch):
            channel_slice = weights[f_idx, ch_idx]
            cs_norm = (channel_slice - channel_slice.min()) / (channel_slice.max() - channel_slice.min() + 1e-8)
            axes[f_idx, ch_idx + 1].imshow(cs_norm, cmap='gray')
            if f_idx == 0:
                axes[f_idx, ch_idx + 1].set_title(f'Ch{ch_idx+1}', fontsize=9)
            axes[f_idx, ch_idx + 1].axis('off')
    
    plt.suptitle(f'{layer_name} Filters: Energy + Individual Channel Slices\n'
                 f'(Showing first {num_ch} of {num_channels} input channels)',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print(f"\n{layer_name} Filter Statistics:")
    print(f"  Total filters      : {weights.shape[0]}")
    print(f"  Input channels     : {weights.shape[1]}")
    print(f"  Kernel size        : {rbm.kernel_size}×{rbm.kernel_size}")
    print(f"  Weight range       : [{weights.min():.4f}, {weights.max():.4f}]")
    print(f"  Weight mean        : {weights.mean():.6f}")
    print(f"  Weight std         : {weights.std():.6f}")


# Visualize Conv-RBM-2 filters
if DEBUG:
    print("Visualizing Conv-RBM-2 learned filters:")
    visualize_multilayer_filters(conv_rbm_2, num_filters=16, num_channels_to_show=6)
else:
    print("✓ Conv-RBM-2 visualization skipped (DEBUG=False)")

### 13.2 Conv-RBM-2 Hidden Feature Maps
Visualize the hierarchical features learned by the 2-layer CDBN stack.

In [None]:
# =============================================================================
# STEP 13.2: Hierarchical Feature Visualization
# =============================================================================

def visualize_hierarchical_features(
    images: torch.Tensor,
    labels: torch.Tensor,
    conv_rbm_1: ConvRBM,
    conv_rbm_2: ConvRBM,
    pooling: ProbabilisticPooling,
    class_names: list,
    num_images: int = 4,
    num_feature_maps: int = 6,
    figsize: tuple = (18, 16)
):
    """
    Visualize the hierarchical feature extraction through the CDBN.
    
    Shows: Input → Conv-RBM-1 Features → Pooled → Conv-RBM-2 Features
    
    Args:
        images: Input images tensor
        labels: Class labels
        conv_rbm_1: First Conv-RBM layer
        conv_rbm_2: Second Conv-RBM layer
        pooling: Probabilistic pooling layer
        class_names: List of class names
        num_images: Number of images to display
        num_feature_maps: Number of feature maps per layer
        figsize: Figure size
    """
    num_images = min(num_images, images.shape[0])
    images = images[:num_images].to(config.DEVICE)
    labels = labels[:num_images]
    
    # Forward pass through the CDBN stack
    with torch.no_grad():
        # Layer 1: Input → Hidden
        h1_prob = conv_rbm_1.hidden_probabilities(images)
        
        # Pooling
        h1_pooled = pooling(h1_prob)
        
        # Layer 2: Pooled → Hidden
        h2_prob = conv_rbm_2.hidden_probabilities(h1_pooled)
    
    # Create visualization: one row per image, four columns
    # (Input | Conv-RBM-1 montage | Pooled montage | Conv-RBM-2 montage)
    num_cols = 4
    # squeeze=False keeps axes 2-D even when num_images == 1
    fig, axes = plt.subplots(num_images, num_cols, figsize=figsize, squeeze=False)
    
    for img_idx in range(num_images):
        # Column 1: Original image, titled with its image index and class
        orig = images[img_idx].squeeze().cpu().numpy()
        axes[img_idx, 0].imshow(orig, cmap='gray')
        class_name = class_names[int(labels[img_idx])]
        row_title = f'Img {img_idx+1}: {class_name}'
        axes[img_idx, 0].set_title(f'Input\n{row_title}' if img_idx == 0 else row_title, fontsize=10)
        axes[img_idx, 0].axis('off')
        
        # Column 2: Conv-RBM-1 feature map montage
        h1_montage = create_feature_montage(h1_prob[img_idx].cpu().numpy(), num_feature_maps)
        axes[img_idx, 1].imshow(h1_montage, cmap='viridis')
        if img_idx == 0:
            axes[img_idx, 1].set_title(f'Conv-RBM-1\n({h1_prob.shape[2]}×{h1_prob.shape[3]})', fontsize=10)
        axes[img_idx, 1].axis('off')
        
        # Column 3: Pooled feature map montage
        pooled_montage = create_feature_montage(h1_pooled[img_idx].cpu().numpy(), num_feature_maps)
        axes[img_idx, 2].imshow(pooled_montage, cmap='viridis')
        if img_idx == 0:
            axes[img_idx, 2].set_title(f'Pooled\n({h1_pooled.shape[2]}×{h1_pooled.shape[3]})', fontsize=10)
        axes[img_idx, 2].axis('off')
        
        # Column 4: Conv-RBM-2 feature map montage
        h2_montage = create_feature_montage(h2_prob[img_idx].cpu().numpy(), num_feature_maps)
        axes[img_idx, 3].imshow(h2_montage, cmap='viridis')
        if img_idx == 0:
            axes[img_idx, 3].set_title(f'Conv-RBM-2\n({h2_prob.shape[2]}×{h2_prob.shape[3]})', fontsize=10)
        axes[img_idx, 3].axis('off')
    
    plt.suptitle('Hierarchical Feature Extraction through 2-Layer CDBN\n'
                 'Input → Conv-RBM-1 → Pooling → Conv-RBM-2',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def create_feature_montage(feature_maps: np.ndarray, num_maps: int = 6) -> np.ndarray:
    """
    Create a montage of feature maps for visualization.
    
    Args:
        feature_maps: [channels, H, W] array
        num_maps: Number of maps to include
        
    Returns:
        montage: Combined image array
    """
    num_maps = min(num_maps, feature_maps.shape[0])
    
    # Arrange in a 2-row grid
    rows = 2
    cols = (num_maps + 1) // 2
    
    h, w = feature_maps.shape[1], feature_maps.shape[2]
    montage = np.zeros((rows * h, cols * w))
    
    for i in range(num_maps):
        row = i // cols
        col = i % cols
        fm = feature_maps[i]
        # Normalize
        fm = (fm - fm.min()) / (fm.max() - fm.min() + 1e-8)
        montage[row*h:(row+1)*h, col*w:(col+1)*w] = fm
    
    return montage


# Get sample images and visualize
sample_images, sample_labels = next(iter(train_loader))
if DEBUG:
    print("Visualizing hierarchical feature extraction:")
    visualize_hierarchical_features(
        images=sample_images,
        labels=sample_labels,
        conv_rbm_1=conv_rbm_1,
        conv_rbm_2=conv_rbm_2,
        pooling=prob_pool,
        class_names=config.CLASS_NAMES,
        num_images=4,
        num_feature_maps=6
    )
else:
    print("✓ Hierarchical feature visualization skipped (DEBUG=False)")

In [None]:
# -----------------------------------------------------------------------------
# Detailed Feature Map Visualization for Conv-RBM-2
# -----------------------------------------------------------------------------

def visualize_layer2_activations(
    images: torch.Tensor,
    labels: torch.Tensor,
    conv_rbm_1: ConvRBM,
    conv_rbm_2: ConvRBM,
    pooling: ProbabilisticPooling,
    class_names: list,
    num_images: int = 2,
    num_feature_maps: int = 16,
    figsize: tuple = (20, 8)
):
    """
    Detailed visualization of Conv-RBM-2 feature maps.
    """
    num_images = min(num_images, images.shape[0])
    images = images[:num_images].to(config.DEVICE)
    labels = labels[:num_images]
    
    # Forward pass
    with torch.no_grad():
        h1_prob = conv_rbm_1.hidden_probabilities(images)
        h1_pooled = pooling(h1_prob)
        h2_prob = conv_rbm_2.hidden_probabilities(h1_pooled)
    
    num_feature_maps = min(num_feature_maps, h2_prob.shape[1])
    grid_cols = int(np.ceil(np.sqrt(num_feature_maps)))
    grid_rows = int(np.ceil(num_feature_maps / grid_cols))
    
    for img_idx in range(num_images):
        fig, axes = plt.subplots(grid_rows, grid_cols + 1, figsize=figsize)
        axes = axes.flatten()
        
        # First subplot: original image
        orig = images[img_idx].squeeze().cpu().numpy()
        axes[0].imshow(orig, cmap='gray')
        axes[0].set_title(f'Input\n{class_names[labels[img_idx]]}', fontsize=10)
        axes[0].axis('off')
        
        # Feature maps
        for fm_idx in range(num_feature_maps):
            ax = axes[fm_idx + 1]
            fm = h2_prob[img_idx, fm_idx].cpu().numpy()
            ax.imshow(fm, cmap='viridis')
            ax.set_title(f'FM {fm_idx+1}', fontsize=8)
            ax.axis('off')
        
        # Hide unused axes
        for ax in axes[num_feature_maps + 1:]:
            ax.axis('off')
        
        plt.suptitle(f'Conv-RBM-2 Feature Maps for Sample {img_idx+1}\n'
                     f'(first {num_feature_maps} of {h2_prob.shape[1]} feature maps, '
                     f'each {h2_prob.shape[2]}×{h2_prob.shape[3]})',
                     fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()


# Visualize Conv-RBM-2 activations
if DEBUG:
    print("\nDetailed Conv-RBM-2 feature map visualization:")
    visualize_layer2_activations(
        images=sample_images,
        labels=sample_labels,
        conv_rbm_1=conv_rbm_1,
        conv_rbm_2=conv_rbm_2,
        pooling=prob_pool,
        class_names=config.CLASS_NAMES,
        num_images=2,
        num_feature_maps=16
    )
else:
    print("✓ Conv-RBM-2 activations visualization skipped (DEBUG=False)")

In [None]:
# =============================================================================
# Summary: 2-Layer CDBN Stack
# =============================================================================

print("\n" + "=" * 70)
print("2-LAYER CDBN STACK SUMMARY")
print("=" * 70)

print("\n" + "-" * 50)
print("ARCHITECTURE")
print("-" * 50)

print("\nLayer-by-Layer Dimensions:")
print(f"  Input               : [{config.BATCH_SIZE}, {config.NUM_CHANNELS}, {config.IMAGE_SIZE[0]}, {config.IMAGE_SIZE[1]}]")

# Conv-RBM-1
h1_h = config.IMAGE_SIZE[0] - CONVRBM1_CONFIG['kernel_size'] + 1
h1_w = config.IMAGE_SIZE[1] - CONVRBM1_CONFIG['kernel_size'] + 1
print(f"  Conv-RBM-1 Hidden   : [{config.BATCH_SIZE}, {CONVRBM1_CONFIG['hidden_channels']}, {h1_h}, {h1_w}]")

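# Pooling (non-overlapping pool_size × pool_size blocks)
pool_size = prob_pool.pool_size
pool_h = h1_h // pool_size
pool_w = h1_w // pool_size
print(f"  After Pooling ({pool_size}×{pool_size}) : [{config.BATCH_SIZE}, {CONVRBM1_CONFIG['hidden_channels']}, {pool_h}, {pool_w}]")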

# Conv-RBM-2
h2_h = pool_h - CONVRBM2_CONFIG['kernel_size'] + 1
h2_w = pool_w - CONVRBM2_CONFIG['kernel_size'] + 1
print(f"  Conv-RBM-2 Hidden   : [{config.BATCH_SIZE}, {CONVRBM2_CONFIG['hidden_channels']}, {h2_h}, {h2_w}]")

print("\n" + "-" * 50)
print("TRAINABLE PARAMETERS")
print("-" * 50)

params_rbm1 = sum(p.numel() for p in conv_rbm_1.parameters())
params_rbm2 = sum(p.numel() for p in conv_rbm_2.parameters())

print(f"  Conv-RBM-1          : {params_rbm1:,} parameters")
print(f"  Conv-RBM-2          : {params_rbm2:,} parameters")
print(f"  Total               : {params_rbm1 + params_rbm2:,} parameters")

print("\n" + "-" * 50)
print("TRAINING RESULTS")
print("-" * 50)

print(f"\n  Conv-RBM-1 ({len(training_history['epoch_loss'])} epochs):")
print(f"    Initial loss: {training_history['epoch_loss'][0]:.6f}")
print(f"    Final loss  : {training_history['epoch_loss'][-1]:.6f}")

print(f"\n  Conv-RBM-2 ({len(training_history_2['epoch_loss'])} epochs):")
print(f"    Initial loss: {training_history_2['epoch_loss'][0]:.6f}")
print(f"    Final loss  : {training_history_2['epoch_loss'][-1]:.6f}")

print("\n" + "-" * 50)
print("COMPONENTS CREATED")
print("-" * 50)

print("\n  Models:")
print(f"    • conv_rbm_1  : Pretrained, frozen ({CONVRBM1_CONFIG['hidden_channels']} filters, {CONVRBM1_CONFIG['kernel_size']}×{CONVRBM1_CONFIG['kernel_size']})")
print(f"    • conv_rbm_2  : Pretrained ({CONVRBM2_CONFIG['hidden_channels']} filters, {CONVRBM2_CONFIG['kernel_size']}×{CONVRBM2_CONFIG['kernel_size']})")
print(f"    • prob_pool   : 2×2 probabilistic pooling")

print("\n  Utilities:")
print(f"    • cd_trainer    : CD-1 trainer for Conv-RBM-1")
print(f"    • cd_trainer_2  : CD-1 trainer for Conv-RBM-2")
print(f"    • feature_extractor : Pipeline for extracting pooled features")

print("\n" + "=" * 70)
print("2-LAYER CDBN READY FOR CLASSIFICATION!")
print("=" * 70)

---

# Conv-RBM-2 Stacked and Pretrained: Hierarchical Features Learned

✅ **STEP 9: Probabilistic Pooling Layer** implemented
- 2×2 pooling with sum-based probability aggregation
- Unlike max pooling, the outputs remain probabilities rather than hard maxima
- No learnable parameters

✅ **STEP 10: Conv-RBM-1 Frozen** and feature extraction pipeline created
- All Conv-RBM-1 parameters set to `requires_grad = False`
- `FeatureExtractor` class for computing pooled features

✅ **STEP 11: Conv-RBM-2** instantiated
- Input: 32 channels (pooled features from Conv-RBM-1)
- Output: 64 feature maps with 5×5 kernels
- Reuses the same `ConvRBM` class architecture

✅ **STEP 12: Conv-RBM-2 Pretrained** with Contrastive Divergence
- 10 epochs of CD-1 training on pooled features
- Unsupervised (no labels used)
- Reconstruction loss tracked

✅ **STEP 13: Visualizations** completed
- Conv-RBM-2 filter energies and channel slices
- Hierarchical feature extraction through both layers
- Detailed Conv-RBM-2 activation maps
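
STEP 9 describes the pooling only as "sum-based probability aggregation". One standard rule that matches this description is probabilistic max-pooling from Lee et al. (2009). Under that rule, at most one detection unit in each 2×2 block can be on, and the pooling unit's probability is the sum of the block's detection probabilities. The NumPy sketch below implements that rule. It assumes this is roughly what `ProbabilisticPooling` computes; that class is defined elsewhere in the notebook and may differ. `prob_max_pool` is a hypothetical helper name.

```python
import numpy as np

def prob_max_pool(pre_act: np.ndarray, block: int = 2):
    """Probabilistic max-pooling (Lee et al., 2009) over non-overlapping blocks.

    pre_act: detection-layer inputs I(h_ij) = b + conv(v, filter)_ij, shape [C, H, W].
    Returns (h_prob [C, Hb*block, Wb*block], p_prob [C, Hb, Wb]).
    """
    C, H, W = pre_act.shape
    Hb, Wb = H // block, W // block
    x = pre_act[:, :Hb * block, :Wb * block]  # drop a trailing odd row/column
    # Gather each block x block window into the last axis: [C, Hb, Wb, block*block]
    x = (x.reshape(C, Hb, block, Wb, block)
          .transpose(0, 1, 3, 2, 4)
          .reshape(C, Hb, Wb, block * block))
    # Softmax over {all units off} plus {exactly one unit on}, shifted for stability
    m = np.maximum(x.max(axis=-1, keepdims=True), 0.0)
    e_on = np.exp(x - m)
    e_off = np.exp(-m)
    h_block = e_on / (e_off + e_on.sum(axis=-1, keepdims=True))
    # The pooling unit is on iff some detection unit in its block is on
    p_prob = h_block.sum(axis=-1)
    # Scatter detection probabilities back to their spatial positions
    h_prob = (h_block.reshape(C, Hb, Wb, block, block)
                     .transpose(0, 1, 3, 2, 4)
                     .reshape(C, Hb * block, Wb * block))
    return h_prob, p_prob
```

Applied to a 122×122 detection map, this yields a 61×61 pooled map, which matches the shapes used in this notebook.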

---

**CDBN Architecture Summary:**

```
Input [B, 1, 128, 128]
    ↓
Conv-RBM-1 (7×7, 32 filters) → [B, 32, 122, 122]
    ↓
Probabilistic Pooling (2×2) → [B, 32, 61, 61]
    ↓
Conv-RBM-2 (5×5, 64 filters) → [B, 64, 57, 57]
    ↓
Ready for Classification
```
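
The spatial sizes in the diagram follow from two rules. A stride-1 "valid" convolution with a k×k kernel shrinks each side by k − 1, and non-overlapping 2×2 pooling halves each side, rounding down. The sketch below recomputes every shape from the kernel sizes and channel counts above. It assumes stride 1 and no padding, which is consistent with 128 → 122. `cdbn_shapes` is a hypothetical helper.

```python
def conv_valid(size: int, kernel: int) -> int:
    # Stride-1 convolution without padding: each side shrinks by kernel - 1
    return size - kernel + 1

def pool_out(size: int, block: int = 2) -> int:
    # Non-overlapping pooling; a trailing odd row/column is dropped
    return size // block

def cdbn_shapes(image=128, k1=7, c1=32, k2=5, c2=64, block=2):
    """Per-sample shapes [C, H, W] through the 2-layer CDBN (square inputs)."""
    s1 = conv_valid(image, k1)
    s_pool = pool_out(s1, block)
    s2 = conv_valid(s_pool, k2)
    return {
        "input": (1, image, image),
        "conv_rbm_1": (c1, s1, s1),
        "pool": (c1, s_pool, s_pool),
        "conv_rbm_2": (c2, s2, s2),
        "flatten": (c2 * s2 * s2,),
    }

for name, shape in cdbn_shapes().items():
    print(f"{name:>10}: {shape}")
```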

**Next Steps:**
- STEP 14: Flatten Conv-RBM-2 features into vectors
- STEP 15: Define a Fully Connected RBM (FC-RBM) on the flattened features
- STEP 16: Implement Contrastive Divergence for the FC-RBM
- STEP 17: Train the FC-RBM on the hierarchical features

## STEP 14: Flattening & Feature Vector Preparation

Before feeding Conv-RBM-2 hidden activations to the Fully Connected RBM (FC-RBM), we must flatten the 4D tensor into a 2D matrix.

**Why Flattening is Required:**
1. **Conv-RBM output:** `[B, C, H, W]`, a 4D tensor with spatial structure
2. **FC-RBM input:** `[B, D]`, a 2D tensor where D = C × H × W
3. FC-RBMs use dense weight matrices without weight sharing
4. Flattening preserves every value, but spatial locality becomes implicit in the index order
5. This transition moves from local feature detection to global pattern learning
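
Both NumPy and PyTorch flatten in row-major (C) order. This means element `(c, h, w)` of one sample ends up at index `c·H·W + h·W + w` of its feature vector, and reshaping back recovers the original tensor exactly. The sketch below uses NumPy arrays in place of the PyTorch tensors. `flat_index` is a hypothetical helper.

```python
import numpy as np

def flat_index(c, h, w, H, W):
    """Row-major (C-order) position of element (c, h, w) in one flattened sample."""
    return (c * H + h) * W + w

B, C, H, W = 2, 64, 57, 57                # Conv-RBM-2 output shape in this notebook
x = np.arange(B * C * H * W).reshape(B, C, H, W)
flat = x.reshape(B, -1)                   # [B, D] with D = C * H * W
restored = flat.reshape(B, C, H, W)       # exact inverse: nothing is lost

print(flat.shape)                                                # (2, 207936)
print(flat[1, flat_index(5, 10, 20, H, W)] == x[1, 5, 10, 20])  # True
```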

In [None]:
# =============================================================================
# STEP 14: Flattening & Feature Vector Preparation
# =============================================================================

class FeatureFlattener:
    """
    Flattens Conv-RBM-2 hidden probabilities for input to FC-RBM.
    
    WHY FLATTENING IS REQUIRED BEFORE FC-RBM:
    -----------------------------------------
    
    1. CONVOLUTIONAL RBM OUTPUT:
       - Shape: [batch_size, num_feature_maps, height, width]
       - Preserves spatial structure of learned features
       - Each position represents a local feature detection
       - Example: [32, 64, 57, 57] = 32 samples × 64 feature maps × 57×57 spatial
    
    2. FULLY CONNECTED RBM INPUT:
       - Shape: [batch_size, num_features]
       - No spatial structure - each unit connects to all hidden units
       - Treats all features equally regardless of original position
       - Example: [32, 207936] = 32 samples × (64 × 57 × 57) features
    
    3. MATHEMATICAL JUSTIFICATION:
       - Conv-RBMs: W has shape [K, C, k, k] with weight sharing
       - FC-RBMs: W has shape [D, n_hidden] where D = C × H × W
       - The transition from local (convolutional) to global (fully connected)
         allows the model to learn arbitrary combinations of features
    
    4. INFORMATION PRESERVATION:
       - Flattening is a bijective (one-to-one) transformation
       - No information is lost, but spatial relationships become implicit
       - The FC-RBM must learn spatial relationships from data
    
    This layer computes the feature dimension D and provides utilities
    for flattening and unflattening operations.
    """
    
    def __init__(self, feature_shape: tuple):
        """
        Initialize the flattener with the expected feature shape.
        
        Args:
            feature_shape: Shape of Conv-RBM-2 hidden output [C, H, W]
                          (without batch dimension)
        """
        self.feature_shape = feature_shape  # [C, H, W]
        self.num_channels = feature_shape[0]
        self.height = feature_shape[1]
        self.width = feature_shape[2]
        
        # Compute flattened dimension: D = C × H × W
        self.flat_dim = self.num_channels * self.height * self.width
        
    def flatten(self, x: torch.Tensor) -> torch.Tensor:
        """
        Flatten 4D feature tensor to 2D.
        
        Args:
            x: Input tensor [batch, channels, height, width]
            
        Returns:
            Flattened tensor [batch, D] where D = channels × height × width
        """
        batch_size = x.shape[0]
        # Reshape: [B, C, H, W] -> [B, C*H*W]
        # reshape (unlike view) also accepts non-contiguous inputs, copying if needed
        return x.reshape(batch_size, -1)
    
    def unflatten(self, x: torch.Tensor) -> torch.Tensor:
        """
        Unflatten 2D tensor back to 4D.
        
        Args:
            x: Flattened tensor [batch, D]
            
        Returns:
            Unflattened tensor [batch, channels, height, width]
        """
        batch_size = x.shape[0]
        return x.reshape(batch_size, self.num_channels, self.height, self.width)
    
    def __repr__(self):
        return (f"FeatureFlattener(\n"
                f"  feature_shape={self.feature_shape},\n"
                f"  flat_dim={self.flat_dim:,}\n"
                f")")


print("✓ FeatureFlattener class defined successfully!")

In [None]:
# -----------------------------------------------------------------------------
# Compute Feature Dimensions and Create Flattener
# -----------------------------------------------------------------------------

print("=" * 70)
print("FEATURE FLATTENING SETUP")
print("=" * 70)

# Freeze Conv-RBM-2 for feature extraction
for param in conv_rbm_2.parameters():
    param.requires_grad = False
conv_rbm_2.eval()

print("\n✓ Conv-RBM-2 parameters frozen")

# Compute feature dimensions by passing a sample through the pipeline
with torch.no_grad():
    # Get a sample batch
    sample_images, _ = next(iter(train_loader))
    sample_images = sample_images.to(config.DEVICE)
    
    # Forward through Conv-RBM-1
    h1_prob = conv_rbm_1.hidden_probabilities(sample_images)
    
    # Pooling
    h1_pooled = prob_pool(h1_prob)
    
    # Forward through Conv-RBM-2
    h2_prob = conv_rbm_2.hidden_probabilities(h1_pooled)

# Get feature shape (without batch dimension)
feature_shape = h2_prob.shape[1:]  # [C, H, W]

print(f"\nConv-RBM-2 Hidden Shape: {h2_prob.shape}")
print(f"  Batch size        : {h2_prob.shape[0]}")
print(f"  Feature maps (C)  : {h2_prob.shape[1]}")
print(f"  Height (H)        : {h2_prob.shape[2]}")
print(f"  Width (W)         : {h2_prob.shape[3]}")

# Create flattener
flattener = FeatureFlattener(feature_shape)
print(f"\n{flattener}")

# Test flattening
h2_flat = flattener.flatten(h2_prob)
print(f"\nFlattened shape: {h2_flat.shape}")
print(f"  Feature dimension D = {flattener.flat_dim:,}")

# Verify unflatten recovers original shape
h2_unflat = flattener.unflatten(h2_flat)
assert h2_unflat.shape == h2_prob.shape, "Unflatten shape mismatch!"
print(f"✓ Flatten/Unflatten verified")

# Store the flat dimension for FC-RBM
FLAT_DIM = flattener.flat_dim

print(f"\n" + "=" * 70)
print(f"FC-RBM will have {FLAT_DIM:,} visible units")
print("=" * 70)

## STEP 15: Fully Connected RBM (FC-RBM)

Implement a standard Fully Connected Restricted Boltzmann Machine with Bernoulli visible and hidden units.

**FC-RBM Energy Function:**
$$E(v, h) = -\sum_i a_i v_i - \sum_j b_j h_j - \sum_{i,j} v_i W_{ij} h_j$$

**Conditional Distributions:**
$$P(h_j = 1 | v) = \sigma(b_j + \sum_i W_{ij} v_i)$$
$$P(v_i = 1 | h) = \sigma(a_i + \sum_j W_{ij} h_j)$$

where $\sigma(x) = 1/(1 + e^{-x})$ is the sigmoid function.
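
Both conditionals follow from the energy function. So does the free energy used later in `FCRBM.free_energy`, $F(v) = -\log \sum_h e^{-E(v,h)}$. These claims can be checked numerically on an RBM small enough to enumerate every configuration. The NumPy sketch below compares brute-force sums over all binary states with the closed forms above. It uses random toy parameters, and its helper names are hypothetical.

```python
import itertools
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def energy(v, h, W, a, b):
    # E(v, h) = -sum_i a_i v_i - sum_j b_j h_j - sum_ij v_i W_ij h_j
    return -(a @ v) - (b @ h) - v @ W @ h

def free_energy(v, W, a, b):
    # F(v) = -sum_i a_i v_i - sum_j log(1 + exp(b_j + sum_i W_ij v_i))
    return -(a @ v) - np.sum(np.logaddexp(0.0, b + v @ W))

def brute_force_checks(n_v=4, n_h=3, seed=0):
    rng = np.random.default_rng(seed)
    W = rng.normal(0.0, 0.5, size=(n_v, n_h))
    a = rng.normal(0.0, 0.1, size=n_v)
    b = rng.normal(0.0, 0.1, size=n_h)
    v = rng.integers(0, 2, size=n_v).astype(float)
    h = rng.integers(0, 2, size=n_h).astype(float)

    # P(h_j = 1 | v): enumerate all hidden configurations for the fixed v
    all_h = np.array(list(itertools.product([0.0, 1.0], repeat=n_h)))
    w_h = np.exp([-energy(v, hh, W, a, b) for hh in all_h])
    p_h_brute = (w_h[:, None] * all_h).sum(axis=0) / w_h.sum()

    # P(v_i = 1 | h): enumerate all visible configurations for the fixed h
    all_v = np.array(list(itertools.product([0.0, 1.0], repeat=n_v)))
    w_v = np.exp([-energy(vv, h, W, a, b) for vv in all_v])
    p_v_brute = (w_v[:, None] * all_v).sum(axis=0) / w_v.sum()

    # Each entry pairs (brute force, closed form)
    return {
        "p_h": (p_h_brute, sigmoid(b + v @ W)),
        "p_v": (p_v_brute, sigmoid(a + W @ h)),
        "free_energy": (-np.log(w_h.sum()), free_energy(v, W, a, b)),
    }

for name, (brute, closed) in brute_force_checks().items():
    print(name, np.allclose(brute, closed))
```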

In [None]:
# =============================================================================
# STEP 15: Fully Connected RBM (FC-RBM)
# =============================================================================

class FCRBM(nn.Module):
    """
    Fully Connected Restricted Boltzmann Machine (FC-RBM).
    
    This is a standard RBM with dense connections between visible and hidden layers.
    Unlike Conv-RBM, there is no weight sharing - each connection has its own weight.
    
    Architecture:
        visible (v) <---> hidden (h)
        
        v: [batch_size, n_visible] - Visible units (flattened features)
        h: [batch_size, n_hidden]  - Hidden units (learned representations)
        W: [n_visible, n_hidden]   - Weight matrix (dense connections)
    
    Unit Types:
        - Visible units: Bernoulli (Conv-RBM-2 probabilities in [0, 1], used as visible means)
        - Hidden units: Bernoulli (binary latent representations)
    
    Energy Function (Bernoulli-Bernoulli RBM):
        E(v, h) = -sum_i(a_i * v_i) - sum_j(b_j * h_j) - sum_ij(v_i * W_ij * h_j)
        
        where:
            a_i = visible bias for unit i
            b_j = hidden bias for unit j
            W_ij = weight connecting visible unit i to hidden unit j
    
    Conditional Distributions:
        P(h_j = 1 | v) = sigmoid(b_j + sum_i(W_ij * v_i))
        P(v_i = 1 | h) = sigmoid(a_i + sum_j(W_ij * h_j))
        
        These follow from the energy function via the Boltzmann distribution.
    
    Args:
        n_visible (int): Number of visible units (flattened feature dimension)
        n_hidden (int): Number of hidden units (latent representation size)
        
    Attributes:
        W: Weight matrix [n_visible, n_hidden]
        v_bias: Visible bias [n_visible]
        h_bias: Hidden bias [n_hidden]
    """
    
    def __init__(self, n_visible: int, n_hidden: int):
        super(FCRBM, self).__init__()
        
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        
        # ---------------------------------------------------------------------
        # Learnable Parameters
        # ---------------------------------------------------------------------
        
        # Weight matrix: W[i,j] connects visible unit i to hidden unit j
        # Shape: [n_visible, n_hidden]
        # Initialization: small Gaussian values (std = 0.01), a common RBM
        # heuristic; this is a fixed scale, not Xavier/Glorot scaling
        std = 0.01  # Small initial weights
        self.W = nn.Parameter(
            torch.randn(n_visible, n_hidden) * std
        )
        
        # Visible bias: a_i for each visible unit
        # Shape: [n_visible]
        # Initialization: Zero (or can initialize to log(p/(1-p)) for mean activation p)
        self.v_bias = nn.Parameter(
            torch.zeros(n_visible)
        )
        
        # Hidden bias: b_j for each hidden unit
        # Shape: [n_hidden]
        # Initialization: Zero
        self.h_bias = nn.Parameter(
            torch.zeros(n_hidden)
        )
    
    def hidden_probabilities(self, v: torch.Tensor) -> torch.Tensor:
        """
        Compute hidden unit probabilities given visible units.
        
        MATH:
            P(h_j = 1 | v) = sigmoid(b_j + sum_i(W_ij * v_i))
                           = sigmoid(b_j + v @ W[:, j])
            
            In matrix form for all hidden units:
            P(h | v) = sigmoid(h_bias + v @ W)
        
        Args:
            v: Visible units [batch_size, n_visible]
               Values should be probabilities in [0, 1]
               
        Returns:
            h_prob: Hidden probabilities [batch_size, n_hidden]
                    P(h_j = 1 | v) for each hidden unit j
        """
        # Linear transformation: v @ W + h_bias
        # v: [batch, n_visible]
        # W: [n_visible, n_hidden]
        # Result: [batch, n_hidden]
        pre_activation = torch.mm(v, self.W) + self.h_bias
        
        # Apply sigmoid to get probabilities
        # sigmoid(x) = 1 / (1 + exp(-x))
        h_prob = torch.sigmoid(pre_activation)
        
        return h_prob
    
    def sample_hidden(self, v: torch.Tensor) -> tuple:
        """
        Sample hidden states given visible units.
        
        MATH:
            h_j ~ Bernoulli(P(h_j = 1 | v))
            
            Each hidden unit is independently sampled.
        
        Args:
            v: Visible units [batch_size, n_visible]
            
        Returns:
            h_prob: Hidden probabilities [batch_size, n_hidden]
            h_sample: Binary hidden samples [batch_size, n_hidden]
        """
        h_prob = self.hidden_probabilities(v)
        
        # Sample from Bernoulli distribution
        # Each unit is 1 with probability h_prob, 0 otherwise
        h_sample = torch.bernoulli(h_prob)
        
        return h_prob, h_sample
    
    def visible_probabilities(self, h: torch.Tensor) -> torch.Tensor:
        """
        Compute visible unit probabilities given hidden units.
        
        MATH:
            P(v_i = 1 | h) = sigmoid(a_i + sum_j(W_ij * h_j))
                           = sigmoid(a_i + W[i, :] @ h)
            
            In matrix form for all visible units:
            P(v | h) = sigmoid(v_bias + h @ W^T)
        
        Args:
            h: Hidden units [batch_size, n_hidden]
               Can be probabilities or binary samples
               
        Returns:
            v_prob: Visible probabilities [batch_size, n_visible]
                    P(v_i = 1 | h) for each visible unit i
        """
        # Linear transformation: h @ W^T + v_bias
        # h: [batch, n_hidden]
        # W^T: [n_hidden, n_visible]
        # Result: [batch, n_visible]
        pre_activation = torch.mm(h, self.W.t()) + self.v_bias
        
        # Apply sigmoid for Bernoulli visible units
        v_prob = torch.sigmoid(pre_activation)
        
        return v_prob
    
    def sample_visible(self, h: torch.Tensor) -> tuple:
        """
        Sample visible states given hidden units.
        
        MATH:
            v_i ~ Bernoulli(P(v_i = 1 | h))
            
            Each visible unit is independently sampled.
        
        Args:
            h: Hidden units [batch_size, n_hidden]
            
        Returns:
            v_prob: Visible probabilities [batch_size, n_visible]
            v_sample: Binary visible samples [batch_size, n_visible]
        """
        v_prob = self.visible_probabilities(h)
        
        # Sample from Bernoulli distribution
        v_sample = torch.bernoulli(v_prob)
        
        return v_prob, v_sample
    
    def forward(self, v: torch.Tensor) -> tuple:
        """
        Forward pass: compute hidden probabilities and samples.
        
        This is the inference direction: visible -> hidden
        
        Args:
            v: Input visible units [batch_size, n_visible]
            
        Returns:
            h_prob: Hidden probabilities [batch_size, n_hidden]
            h_sample: Binary hidden samples [batch_size, n_hidden]
        """
        return self.sample_hidden(v)
    
    def free_energy(self, v: torch.Tensor) -> torch.Tensor:
        """
        Compute the free energy of visible configurations.
        
        MATH:
            F(v) = -sum_i(a_i * v_i) - sum_j(log(1 + exp(b_j + sum_i(W_ij * v_i))))
            
            The free energy is useful for computing the log-likelihood gradient
            and for model comparison.
        
        Args:
            v: Visible units [batch_size, n_visible]
            
        Returns:
            free_energy: Free energy for each sample [batch_size]
        """
        # First term: v @ v_bias (visible bias contribution, negated in the return)
        # Shape: [batch_size]
        v_term = torch.mv(v, self.v_bias)
        
        # Second term: hidden contribution
        # pre_activation: b_j + sum_i(W_ij * v_i)
        # Shape: [batch_size, n_hidden]
        pre_activation = torch.mm(v, self.W) + self.h_bias
        
        # log(1 + exp(x)) = softplus(x)
        # Sum over hidden units
        # Shape: [batch_size]
        h_term = torch.sum(nn.functional.softplus(pre_activation), dim=1)
        
        # Free energy: F(v) = -v_term - h_term
        return -v_term - h_term
    
    def __repr__(self):
        return (f"FCRBM(\n"
                f"  n_visible={self.n_visible:,},\n"
                f"  n_hidden={self.n_hidden:,},\n"
                f"  W shape={tuple(self.W.shape)},\n"
                f"  total_params={self.n_visible * self.n_hidden + self.n_visible + self.n_hidden:,}\n"
                f")")


print("✓ FCRBM class defined successfully!")

In [None]:
# -----------------------------------------------------------------------------
# Instantiate FC-RBM
# -----------------------------------------------------------------------------

# FC-RBM Configuration
FCRBM_CONFIG = {
    'n_visible': FLAT_DIM,      # From Conv-RBM-2 flattened output
    'n_hidden': 256,             # Latent representation size (tune as needed)
}

print("=" * 70)
print("FC-RBM CONFIGURATION")
print("=" * 70)
print(f"\nArchitecture:")
print(f"  n_visible (D)     : {FCRBM_CONFIG['n_visible']:,}")
print(f"  n_hidden          : {FCRBM_CONFIG['n_hidden']:,}")
print(f"  Weight matrix     : [{FCRBM_CONFIG['n_visible']:,}, {FCRBM_CONFIG['n_hidden']}]")

# Create FC-RBM instance
fc_rbm = FCRBM(
    n_visible=FCRBM_CONFIG['n_visible'],
    n_hidden=FCRBM_CONFIG['n_hidden']
).to(config.DEVICE)

print(f"\n{fc_rbm}")

# Parameter count
total_params_fc = sum(p.numel() for p in fc_rbm.parameters())
print(f"\nTotal parameters: {total_params_fc:,}")

# Test forward pass
with torch.no_grad():
    test_h_prob, test_h_sample = fc_rbm(h2_flat.to(config.DEVICE))

print(f"\nVerification forward pass:")
print(f"  Input shape  : {h2_flat.shape}")
print(f"  Output shape : {test_h_prob.shape}")

assert test_h_prob.shape == (h2_flat.shape[0], FCRBM_CONFIG['n_hidden']), "Shape mismatch!"
print(f"✓ Forward pass verification passed!")
print("=" * 70)

## STEP 16: Contrastive Divergence for FC-RBM

Implement CD-k training for the Fully Connected RBM.

The algorithm is similar to Conv-RBM CD but uses matrix operations instead of convolutions.
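
A small toy problem can exercise the CD-k recipe end to end. The NumPy sketch below follows the same update rule as the `FCRBMTrainer` cell that comes next. It uses mean-field visible reconstructions, samples hidden states between Gibbs steps, applies momentum, and puts L2 decay on `W` only. It trains on two complementary 6-bit patterns. The toy data, the hyperparameters, and the names `cd_k_update`, `recon_mse` and `demo` are illustrative assumptions, not values taken from the notebook.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def cd_k_update(params, vel, v0, rng, lr=0.1, k=1, momentum=0.5, weight_decay=1e-4):
    """One CD-k step (k >= 1) on a batch v0 [n, n_visible]; updates params in place."""
    W, a, b = params["W"], params["a"], params["b"]
    n = v0.shape[0]
    # Positive phase: data-driven hidden probabilities, then a sampled hidden state
    h0_prob = sigmoid(v0 @ W + b)
    hk = (rng.random(h0_prob.shape) < h0_prob).astype(float)
    for step in range(k):
        vk = sigmoid(hk @ W.T + a)               # mean-field visible reconstruction
        hk_prob = sigmoid(vk @ W + b)
        if step < k - 1:                         # resample hidden between Gibbs steps
            hk = (rng.random(hk_prob.shape) < hk_prob).astype(float)
    # Batch-averaged positive minus negative statistics
    grads = {
        "W": (v0.T @ h0_prob - vk.T @ hk_prob) / n - weight_decay * W,
        "a": (v0 - vk).mean(axis=0),
        "b": (h0_prob - hk_prob).mean(axis=0),
    }
    for name, g in grads.items():
        vel[name] = momentum * vel[name] + g
        params[name] += lr * vel[name]
    return float(np.mean((vk - v0) ** 2))

def recon_mse(params, v):
    # Deterministic reconstruction v -> P(h|v) -> P(v|h)
    h = sigmoid(v @ params["W"] + params["b"])
    v_rec = sigmoid(h @ params["W"].T + params["a"])
    return float(np.mean((v_rec - v) ** 2))

def demo(steps=2000, seed=0):
    rng = np.random.default_rng(seed)
    data = np.array([[1, 1, 1, 0, 0, 0],
                     [0, 0, 0, 1, 1, 1]], dtype=float)
    params = {"W": rng.normal(0.0, 0.01, size=(6, 4)),
              "a": np.zeros(6), "b": np.zeros(4)}
    vel = {name: np.zeros_like(p) for name, p in params.items()}
    before = recon_mse(params, data)
    for _ in range(steps):
        cd_k_update(params, vel, data, rng)
    return before, recon_mse(params, data)

before, after = demo()
print(f"reconstruction MSE: {before:.4f} -> {after:.4f}")
```

With near-zero initial weights every reconstruction is close to 0.5, so the starting MSE is about 0.25. The error should then fall as the hidden units learn to separate the two patterns.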

In [None]:
# =============================================================================
# STEP 16: Contrastive Divergence Trainer for FC-RBM
# =============================================================================

class FCRBMTrainer:
    """
    Contrastive Divergence (CD-k) Trainer for Fully Connected RBM.
    
    CD-k ALGORITHM FOR FC-RBM:
    --------------------------
    
    1. POSITIVE PHASE (Data Statistics):
       - Clamp v0 = input data (flattened features from Conv-RBM-2)
       - Compute h0_prob = P(h|v0) = sigmoid(h_bias + v0 @ W)
       - Sample h0 ~ Bernoulli(h0_prob)
       - Positive statistics: ⟨v0^T @ h0_prob⟩_data (averaged over the batch)
    
    2. GIBBS SAMPLING (k steps):
       For each step i:
           - Compute v_i_prob = sigmoid(v_bias + h_{i-1} @ W^T)
           - Use v_i = v_i_prob (mean-field; this trainer does not sample v)
           - Compute h_i_prob = sigmoid(h_bias + v_i @ W)
           - Sample h_i ~ Bernoulli(h_i_prob) (skipped on the last step)
    
    3. NEGATIVE PHASE (Model Statistics):
       - Use vk and hk_prob from the Gibbs chain
       - Negative statistics: ⟨vk^T @ hk_prob⟩_model (averaged over the batch)
    
    4. PARAMETER UPDATES (momentum, plus L2 weight decay on W):
       grad_W      = positive_W - negative_W - weight_decay * W
       grad_v_bias = mean(v0 - vk)
       grad_h_bias = mean(h0_prob - hk_prob)
       velocity    = momentum * velocity + grad   (per parameter)
       θ          += lr * velocity
    
    Args:
        rbm: FCRBM instance to train
        learning_rate: Learning rate for updates
        k: Number of Gibbs sampling steps
        device: Computation device
        momentum: Momentum coefficient
        weight_decay: L2 regularization coefficient
    """
    
    def __init__(
        self,
        rbm: FCRBM,
        learning_rate: float = 0.01,
        k: int = 1,
        device: torch.device = None,
        momentum: float = 0.0,
        weight_decay: float = 0.0001
    ):
        self.rbm = rbm
        self.lr = learning_rate
        self.k = k
        self.device = device if device else torch.device('cpu')
        self.momentum = momentum
        self.weight_decay = weight_decay
        
        # Initialize velocity terms for momentum
        self.W_velocity = torch.zeros_like(rbm.W.data)
        self.v_bias_velocity = torch.zeros_like(rbm.v_bias.data)
        self.h_bias_velocity = torch.zeros_like(rbm.h_bias.data)
    
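    # Run without autograd: CD computes its own updates, and tracking gradients
    # here would make W_velocity keep every previous batch's graph (and memory) alive.
    @torch.no_grad()
    def train_batch(self, v0: torch.Tensor) -> float: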
        """
        Train the FC-RBM on a single batch using CD-k.
        
        Args:
            v0: Input visible batch (flattened features)
                Shape: [batch_size, n_visible]
                Values should be probabilities in [0, 1]
                
        Returns:
            reconstruction_loss: MSE between v0 and vk
        """
        batch_size = v0.shape[0]
        v0 = v0.to(self.device)
        
        # =====================================================================
        # POSITIVE PHASE
        # =====================================================================
        # Compute hidden probabilities from clamped visible data
        # h0_prob = sigmoid(h_bias + v0 @ W)
        
        h0_prob = self.rbm.hidden_probabilities(v0)
        
        # Sample hidden states for Gibbs chain initialization
        h0_sample = torch.bernoulli(h0_prob)
        
        # Positive phase statistics (outer product averaged over batch)
        # positive_W = (v0^T @ h0_prob) / batch_size
        # This computes: sum over batch of v0_i * h0_j for all (i,j)
        positive_W = torch.mm(v0.t(), h0_prob) / batch_size
        
        # Bias gradients
        positive_v_bias = v0.mean(dim=0)
        positive_h_bias = h0_prob.mean(dim=0)
        
        # =====================================================================
        # GIBBS SAMPLING (k steps)
        # =====================================================================
        hk = h0_sample
        
        for step in range(self.k):
            # -----------------------------------------------------------------
            # Reconstruct visible from hidden
            # -----------------------------------------------------------------
            # vk_prob = sigmoid(v_bias + hk @ W^T)
            vk_prob = self.rbm.visible_probabilities(hk)
            
            # Use probabilities (more stable than sampling for visible)
            vk = vk_prob
            
            # -----------------------------------------------------------------
            # Compute hidden from reconstructed visible
            # -----------------------------------------------------------------
            # hk_prob = sigmoid(h_bias + vk @ W)
            hk_prob = self.rbm.hidden_probabilities(vk)
            
            # Sample for next Gibbs step (except last step)
            if step < self.k - 1:
                hk = torch.bernoulli(hk_prob)
        
        # =====================================================================
        # NEGATIVE PHASE
        # =====================================================================
        # Negative phase statistics from the k-step reconstructions
        # negative_W = (vk^T @ hk_prob) / batch_size
        
        negative_W = torch.mm(vk.t(), hk_prob) / batch_size
        
        # Bias gradients
        negative_v_bias = vk.mean(dim=0)
        negative_h_bias = hk_prob.mean(dim=0)
        
        # =====================================================================
        # PARAMETER UPDATES
        # =====================================================================
        # Gradient: positive - negative
        # Update: θ += lr * velocity (velocity accumulates the gradient with momentum)
        
        W_grad = positive_W - negative_W
        v_bias_grad = positive_v_bias - negative_v_bias
        h_bias_grad = positive_h_bias - negative_h_bias
        
        # Apply weight decay (L2 regularization)
        W_grad -= self.weight_decay * self.rbm.W.data
        
        # Update with momentum
        self.W_velocity = self.momentum * self.W_velocity + W_grad
        self.v_bias_velocity = self.momentum * self.v_bias_velocity + v_bias_grad
        self.h_bias_velocity = self.momentum * self.h_bias_velocity + h_bias_grad
        
        # Apply updates
        with torch.no_grad():
            self.rbm.W.data += self.lr * self.W_velocity
            self.rbm.v_bias.data += self.lr * self.v_bias_velocity
            self.rbm.h_bias.data += self.lr * self.h_bias_velocity
        
        # =====================================================================
        # RECONSTRUCTION LOSS
        # =====================================================================
        reconstruction_loss = nn.functional.mse_loss(vk, v0).item()
        
        return reconstruction_loss
    
    def get_reconstruction(self, v: torch.Tensor) -> torch.Tensor:
        """
        Get reconstruction of visible input.
        
        Args:
            v: Input visible units
            
        Returns:
            v_recon: Reconstructed visible units
        """
        with torch.no_grad():
            v = v.to(self.device)
            h_prob = self.rbm.hidden_probabilities(v)
            v_recon = self.rbm.visible_probabilities(h_prob)
        return v_recon


print("✓ FCRBMTrainer class defined successfully!")

## STEP 17: Train FC-RBM on Hierarchical Features

Now we train the FC-RBM using the complete feature extraction pipeline:

**Pipeline:** Image → Conv-RBM-1 → Pool → Conv-RBM-2 → Flatten → FC-RBM

The FC-RBM receives the flattened probabilistic features from Conv-RBM-2 and learns a compressed latent representation of 256 hidden units.
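
The visible layer is 207,936 units wide, so the FC-RBM weight matrix dominates memory during this step. The arithmetic below gives a back-of-the-envelope estimate under several assumptions. It assumes float32 storage and roughly five W-sized buffers alive during `FCRBMTrainer.train_batch` (`W`, `W_velocity`, `positive_W`, `negative_W`, `W_grad`). It uses a batch of 64, the FC-RBM batch size listed in the notebook's configuration table. It ignores allocator overhead and short-lived temporaries, so treat the result as a lower bound. `fc_rbm_footprint_mib` is a hypothetical helper.

```python
def fc_rbm_footprint_mib(n_visible, n_hidden, batch_size,
                         bytes_per_value=4, w_sized_buffers=5):
    """Rough FC-RBM training memory estimate in MiB (float32 by default)."""
    mib = 1024 ** 2
    n_weights = n_visible * n_hidden
    return {
        "params": n_weights + n_visible + n_hidden,       # W + v_bias + h_bias
        "W_mib": n_weights * bytes_per_value / mib,       # one W-sized tensor
        # Assumed live W-sized buffers: W, W_velocity, positive_W, negative_W, W_grad
        "w_buffers_mib": w_sized_buffers * n_weights * bytes_per_value / mib,
        "batch_input_mib": batch_size * n_visible * bytes_per_value / mib,
    }

est = fc_rbm_footprint_mib(n_visible=64 * 57 * 57, n_hidden=256, batch_size=64)
for key, value in est.items():
    print(f"{key:>16}: {value:,.2f}")
```

Under these assumptions, each W-sized tensor takes about 203 MiB and the five buffers take roughly 1 GiB, well within 24 GB. Every W-sized term scales linearly with `n_hidden`.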

In [None]:
# =============================================================================
# STEP 17: Define Complete Feature Extraction Pipeline
# =============================================================================

class HierarchicalFeatureExtractor:
    """
    Complete pipeline to extract features from OCT images through the CDBN.
    
    PIPELINE ARCHITECTURE:
    ----------------------
    Input Image [B, 1, 128, 128]
          ↓
    Conv-RBM-1: 7×7 conv → 32 channels
          ↓ [B, 32, 122, 122]
    Prob-Pool: 2×2 → [B, 32, 61, 61]
          ↓
    Conv-RBM-2: 5×5 conv → 64 channels
          ↓ [B, 64, 57, 57]
    Flatten
          ↓ [B, 207936]
    FC-RBM (trained separately; not applied by this class)
          ↓ [B, 256]
    
    All Conv-RBMs are FROZEN (pretrained, no further updates).
    """
    
    def __init__(
        self,
        conv_rbm_1: ConvRBM,
        pooling: ProbabilisticPooling,
        conv_rbm_2: ConvRBM,
        flattener: FeatureFlattener,
        device: torch.device
    ):
        self.conv_rbm_1 = conv_rbm_1
        self.pooling = pooling
        self.conv_rbm_2 = conv_rbm_2
        self.flattener = flattener
        self.device = device
        
        # Ensure all Conv-RBMs are frozen
        self.conv_rbm_1.eval()
        self.conv_rbm_2.eval()
        for param in self.conv_rbm_1.parameters():
            param.requires_grad = False
        for param in self.conv_rbm_2.parameters():
            param.requires_grad = False
    
    def extract_flat_features(self, images: torch.Tensor) -> torch.Tensor:
        """
        Extract flattened features ready for FC-RBM.
        
        Args:
            images: Input images [B, 1, 128, 128]
            
        Returns:
            flat_features: Flattened features [B, n_visible]
        """
        with torch.no_grad():
            images = images.to(self.device)
            
            # Layer 1: Conv-RBM-1 → probabilities
            h1_prob = self.conv_rbm_1.hidden_probabilities(images)
            
            # Layer 2: Probabilistic pooling
            h1_pooled = self.pooling(h1_prob)
            
            # Layer 3: Conv-RBM-2 → probabilities
            h2_prob = self.conv_rbm_2.hidden_probabilities(h1_pooled)
            
            # Layer 4: Flatten (FeatureFlattener is not callable; use .flatten)
            flat_features = self.flattener.flatten(h2_prob)
        
        return flat_features


# Create the hierarchical feature extractor
hierarchical_extractor = HierarchicalFeatureExtractor(
    conv_rbm_1=conv_rbm_1,
    pooling=prob_pool,
    conv_rbm_2=conv_rbm_2,
    flattener=flattener,
    device=config.DEVICE
)

# Verify with a test batch
test_batch = next(iter(train_loader))[0][:4]  # Get 4 images
test_flat = hierarchical_extractor.extract_flat_features(test_batch)
print(f"✓ HierarchicalFeatureExtractor created!")
print(f"  Input shape: {test_batch.shape}")
print(f"  Output shape: {test_flat.shape}")
print(f"  Output range: [{test_flat.min():.4f}, {test_flat.max():.4f}]")

In [None]:
# =============================================================================
# STEP 17: Train FC-RBM on Flattened Hierarchical Features (Memory-Optimized)
# =============================================================================

# Training configuration for FC-RBM
FC_RBM_EPOCHS = 5        # Local GPU: can use more epochs if needed
FC_RBM_LR = 0.001        # Lower learning rate for high-dimensional input
FC_RBM_K = 1             # CD-1
FC_RBM_MOMENTUM = 0.5
FC_RBM_WEIGHT_DECAY = 0.0001
MEMORY_CLEANUP_FREQ = 50  # Clean GPU cache every N batches

# Create trainer
fc_trainer = FCRBMTrainer(
    rbm=fc_rbm,
    learning_rate=FC_RBM_LR,
    k=FC_RBM_K,
    device=config.DEVICE,
    momentum=FC_RBM_MOMENTUM,
    weight_decay=FC_RBM_WEIGHT_DECAY
)

print("=" * 70)
print("FC-RBM TRAINING (Memory-Optimized for Local GPU)")
print("=" * 70)
print(f"Epochs: {FC_RBM_EPOCHS}")
print(f"Learning Rate: {FC_RBM_LR}")
print(f"CD-k: {FC_RBM_K}")
print(f"Momentum: {FC_RBM_MOMENTUM}")
print(f"Weight Decay: {FC_RBM_WEIGHT_DECAY}")
print(f"Visible Units: {fc_rbm.n_visible:,}")
print(f"Hidden Units: {fc_rbm.n_hidden}")
print(f"Memory cleanup frequency: every {MEMORY_CLEANUP_FREQ} batches")
if torch.cuda.is_available():
    print_gpu_memory()
print("=" * 70)

# Training loop with memory optimization
fc_training_history = []

for epoch in range(FC_RBM_EPOCHS):
    epoch_losses = []
    epoch_start = time.time()
    
    for batch_idx, (images, _) in enumerate(train_loader):
        # Extract flattened features through the hierarchical pipeline
        flat_features = hierarchical_extractor.extract_flat_features(images)
        
        # Train FC-RBM on flattened features
        loss = fc_trainer.train_batch(flat_features)
        epoch_losses.append(loss)
        
        # Explicit memory cleanup
        del images, flat_features
        
        # Progress update
        if (batch_idx + 1) % 20 == 0:
            print(f"  Epoch [{epoch+1}/{FC_RBM_EPOCHS}] "
                  f"Batch [{batch_idx+1}/{len(train_loader)}] "
                  f"Recon Loss: {loss:.6f}")
        
        # Aggressive memory cleanup for large datasets
        if torch.cuda.is_available() and (batch_idx + 1) % MEMORY_CLEANUP_FREQ == 0:
            torch.cuda.empty_cache()
    
    epoch_time = time.time() - epoch_start
    
    # Epoch summary
    avg_loss = np.mean(epoch_losses)
    fc_training_history.append(avg_loss)
    
    # Clear batch losses to free memory
    epoch_losses.clear()
    
    # Print epoch summary with GPU memory
    print(f"▶ Epoch [{epoch+1}/{FC_RBM_EPOCHS}] Average Loss: {avg_loss:.6f} | "
          f"Time: {epoch_time:.1f}s", end="")
    if torch.cuda.is_available():
        mem_alloc = torch.cuda.memory_allocated() / 1024**3
        print(f" | GPU: {mem_alloc:.2f} GB")
    else:
        print()
    
    # Clear cache after each epoch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print("-" * 70)

print("\n" + "=" * 70)
print("FC-RBM TRAINING COMPLETE!")
print(f"Final Loss: {fc_training_history[-1]:.6f}")
if torch.cuda.is_available():
    print_gpu_memory()
print("=" * 70)

In [None]:
# =============================================================================
# Visualize FC-RBM Training Progress
# =============================================================================

if DEBUG:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: FC-RBM training loss
    ax1 = axes[0]
    ax1.plot(range(1, len(fc_training_history) + 1), fc_training_history, 
             'b-o', linewidth=2, markersize=8, label='FC-RBM')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Reconstruction Loss (MSE)', fontsize=12)
    ax1.set_title('FC-RBM Training Loss', fontsize=14)
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Plot 2: All three RBMs training comparison (if histories available)
    ax2 = axes[1]
    if 'training_history' in dir() and len(training_history) > 0:
        ax2.plot(range(1, len(training_history) + 1), training_history, 
                 'r-s', linewidth=2, markersize=6, label='Conv-RBM-1', alpha=0.7)
    if 'training_history_2' in dir() and len(training_history_2) > 0:
        ax2.plot(range(1, len(training_history_2) + 1), training_history_2, 
                 'g-^', linewidth=2, markersize=6, label='Conv-RBM-2', alpha=0.7)
    ax2.plot(range(1, len(fc_training_history) + 1), fc_training_history, 
             'b-o', linewidth=2, markersize=6, label='FC-RBM', alpha=0.7)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Reconstruction Loss (MSE)', fontsize=12)
    ax2.set_title('All RBMs Training Comparison', fontsize=14)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(Config.OUTPUT_DIR, 'fc_rbm_training.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ FC-RBM training plot saved to output directory")
else:
    print("✓ FC-RBM visualization skipped (DEBUG=False)")
    print(f"  Final FC-RBM Loss: {fc_training_history[-1]:.6f}")

In [None]:
# =============================================================================
# STEP E: Memory Cleanup After FC-RBM Pretraining
# =============================================================================
# Free GPU memory by deleting the FC-RBM trainer (it holds large momentum
# velocity tensors). This was critical under Kaggle's 16 GB GPU limit and still
# frees headroom for the supervised phase on a 24 GB local GPU.

# Delete the FC-RBM trainer
del fc_trainer

# Clear CUDA cache to release fragmented memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print("✓ FC-RBM trainer deleted, CUDA cache cleared")
    print(f"  GPU Memory | Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB")
else:
    print("✓ FC-RBM trainer deleted (CPU mode)")

print("  FC-RBM weights are preserved in the fc_rbm module")
print("=" * 70)
print("\n✓ ALL RBM PRETRAINING COMPLETE: GPU memory freed for the supervised phase")

## STEP 18: Extract Latent Representations from Complete CDBN

With all three RBM layers pretrained, we now define a complete feature extraction function that takes OCT images and outputs the top-level latent representation from the FC-RBM.

**Complete CDBN Pipeline:**
$$
\mathbf{z} = h^{(3)} = \sigma\left(\mathbf{b}_h^{(3)} + \text{flatten}\left(h^{(2)}\right) \mathbf{W}^{(3)}\right)
$$

where $h^{(2)}$ is the hidden-unit probability map of Conv-RBM-2. Pooling is applied only after Conv-RBM-1; the Conv-RBM-2 output is flattened directly.
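The tensor shapes in this pipeline follow from two assumptions about the layers: stride-1 convolutions without padding ("valid"), and non-overlapping 2×2 pooling. A minimal sketch of that arithmetic (the helper names are illustrative and not part of the notebook):

```python
def valid_conv_size(size: int, kernel: int) -> int:
    """Spatial size after a stride-1 convolution with no padding."""
    return size - kernel + 1


def pool_size(size: int, pool: int) -> int:
    """Spatial size after non-overlapping pooling (floor division)."""
    return size // pool


def cdbn_shapes(image_size=128, k1=7, c1=32, pool=2, k2=5, c2=64, n_hidden=256):
    """Trace per-sample shapes: Conv-RBM-1 -> pool -> Conv-RBM-2 -> flatten -> FC-RBM."""
    s1 = valid_conv_size(image_size, k1)   # Conv-RBM-1 feature maps
    s1p = pool_size(s1, pool)              # after probabilistic pooling
    s2 = valid_conv_size(s1p, k2)          # Conv-RBM-2 feature maps
    n_visible_fc = c2 * s2 * s2            # FC-RBM visible units
    return {
        "conv_rbm_1": (c1, s1, s1),
        "pooled_1": (c1, s1p, s1p),
        "conv_rbm_2": (c2, s2, s2),
        "flattened": (n_visible_fc,),
        "latent": (n_hidden,),
    }


for name, shape in cdbn_shapes().items():
    print(f"{name:>10}: {shape}")
```

With the defaults, the trace gives 128 → 122 → 61 → 57 and 64 × 57 × 57 = 207,936 FC-RBM visible units. Changing a kernel or pool size changes that count, so the FC-RBM's visible layer would have to be rebuilt to match.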

In [None]:
# =============================================================================
# STEP 18: Complete CDBN Latent Representation Extractor
# =============================================================================

class CDBNFeatureExtractor:
    """
    Complete Convolutional Deep Belief Network Feature Extractor.
    
    FULL CDBN ARCHITECTURE:
    -----------------------
    
    Input: OCT Image [B, 1, 128, 128]
           ↓
    ┌──────────────────────────────────────┐
    │  Conv-RBM-1: Gaussian-Bernoulli      │
    │  32 filters, 7×7 kernels             │
    │  Output: [B, 32, 122, 122]           │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  Probabilistic Pooling               │
    │  2×2 pool, sum-based                 │
    │  Output: [B, 32, 61, 61]             │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  Conv-RBM-2: Bernoulli-Bernoulli     │
    │  64 filters, 5×5 kernels             │
    │  Output: [B, 64, 57, 57]             │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  Flatten                             │
    │  Output: [B, 207936]                 │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  FC-RBM: Bernoulli-Bernoulli         │
    │  256 hidden units                    │
    │  Output: [B, 256]                    │
    └──────────────────────────────────────┘
           ↓
    Latent Representation z ∈ [0, 1]^256
    
    This latent representation can be used for:
    - Supervised classification (add softmax layer)
    - Fine-tuning with backpropagation
    - Clustering
    - Visualization
    
    All parameters are FROZEN after unsupervised pretraining.
    """
    
    def __init__(
        self,
        conv_rbm_1: ConvRBM,
        pooling: ProbabilisticPooling,
        conv_rbm_2: ConvRBM,
        flattener: FeatureFlattener,
        fc_rbm: FCRBM,
        device: torch.device
    ):
        self.conv_rbm_1 = conv_rbm_1
        self.pooling = pooling
        self.conv_rbm_2 = conv_rbm_2
        self.flattener = flattener
        self.fc_rbm = fc_rbm
        self.device = device
        
        # Freeze all layers
        self._freeze_all()
    
    def _freeze_all(self):
        """Freeze all parameters in the CDBN."""
        for module in [self.conv_rbm_1, self.conv_rbm_2, self.fc_rbm]:
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
    
    def extract_latent(self, images: torch.Tensor) -> torch.Tensor:
        """
        Extract top-level latent representation from OCT images.
        
        Args:
            images: Batch of OCT images
                    Shape: [B, 1, 128, 128]
                    Values: normalized to [0, 1]
                    
        Returns:
            latent: Top-level hidden probabilities from FC-RBM
                    Shape: [B, n_hidden] = [B, 256]
                    Values: probabilities in [0, 1]
        """
        with torch.no_grad():
            images = images.to(self.device)
            
            # Layer 1: Conv-RBM-1
            # h1 = sigmoid(conv(v, W1) + b1)
            h1_prob = self.conv_rbm_1.hidden_probabilities(images)
            
            # Layer 2: Probabilistic Pooling (2×2)
            h1_pooled = self.pooling(h1_prob)
            
            # Layer 3: Conv-RBM-2
            # h2 = sigmoid(conv(h1_pooled, W2) + b2)
            h2_prob = self.conv_rbm_2.hidden_probabilities(h1_pooled)
            
            # Layer 4: Flatten
            flat = self.flattener(h2_prob)
            
            # Layer 5: FC-RBM
            # z = sigmoid(flat @ W3 + b3)
            latent = self.fc_rbm.hidden_probabilities(flat)
        
        return latent
    
    def extract_all_representations(self, images: torch.Tensor) -> dict:
        """
        Extract representations at all layers (for visualization/analysis).
        
        Args:
            images: Batch of OCT images [B, 1, 128, 128]
            
        Returns:
            Dictionary containing representations at each layer
        """
        with torch.no_grad():
            images = images.to(self.device)
            
            h1 = self.conv_rbm_1.hidden_probabilities(images)
            h1_pooled = self.pooling(h1)
            h2 = self.conv_rbm_2.hidden_probabilities(h1_pooled)
            flat = self.flattener(h2)
            latent = self.fc_rbm.hidden_probabilities(flat)
        
        return {
            'input': images,
            'conv_rbm_1': h1,
            'pooled_1': h1_pooled,
            'conv_rbm_2': h2,
            'flattened': flat,
            'latent': latent
        }


# Create the complete CDBN feature extractor
cdbn_extractor = CDBNFeatureExtractor(
    conv_rbm_1=conv_rbm_1,
    pooling=prob_pool,
    conv_rbm_2=conv_rbm_2,
    flattener=flattener,
    fc_rbm=fc_rbm,
    device=Config.DEVICE
)

print("=" * 70)
print("CDBN FEATURE EXTRACTOR CREATED")
print("=" * 70)

In [None]:
# =============================================================================
# Verify CDBN Feature Extraction
# =============================================================================

# Test with a batch of images
test_images, test_labels = next(iter(train_loader))
test_images = test_images[:8]  # Use 8 images for testing
test_labels = test_labels[:8]

# Extract latent representations
latent_representations = cdbn_extractor.extract_latent(test_images)

print("=" * 70)
print("CDBN LATENT EXTRACTION VERIFICATION")
print("=" * 70)
print(f"Input shape:  {test_images.shape}")
print(f"Output shape: {latent_representations.shape}")
print(f"Output dtype: {latent_representations.dtype}")
print(f"Output device: {latent_representations.device}")
print(f"Output range: [{latent_representations.min():.4f}, {latent_representations.max():.4f}]")
print(f"Output mean:  {latent_representations.mean():.4f}")
print(f"Output std:   {latent_representations.std():.4f}")
print("=" * 70)

# Verify expected shape
expected_shape = (8, 256)
assert latent_representations.shape == expected_shape, \
    f"Expected {expected_shape}, got {latent_representations.shape}"
print(f"✓ Shape verification passed: [B, n_hidden] = {latent_representations.shape}")

In [None]:
# =============================================================================
# Visualize All Layer Representations
# =============================================================================

# Get all representations for visualization
all_reps = cdbn_extractor.extract_all_representations(test_images[:1])

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# 1. Input image
ax = axes[0, 0]
ax.imshow(test_images[0, 0].cpu().numpy(), cmap='gray')
ax.set_title(f'Input Image\n{test_images[0].shape}', fontsize=11)
ax.axis('off')

# Helper: tile the first 16 feature maps (downsampled by `stride` for display)
# into a 4x4 grid. The tile size is taken from the downsampled maps themselves:
# sizes such as 122, 61 and 57 give 31, 16 and 15 rows after ::4, which do not
# fit cells sized by `shape // 4` and would raise a broadcasting error.
def make_feature_grid(fmaps, n_maps=16, n_cols=4, stride=4):
    tiles = fmaps[:n_maps, ::stride, ::stride]
    n, th, tw = tiles.shape
    n_rows = int(np.ceil(n / n_cols))
    grid = np.zeros((n_rows * th, n_cols * tw))
    for i in range(n):
        r, c = divmod(i, n_cols)
        grid[r*th:(r+1)*th, c*tw:(c+1)*tw] = tiles[i]
    return grid

# 2. Conv-RBM-1 features (first 16 channels as a grid)
ax = axes[0, 1]
h1 = all_reps['conv_rbm_1'][0].cpu().numpy()
ax.imshow(make_feature_grid(h1), cmap='hot')
ax.set_title(f'Conv-RBM-1 Output\n{all_reps["conv_rbm_1"].shape}', fontsize=11)
ax.axis('off')

# 3. Pooled features
ax = axes[0, 2]
h1_pool = all_reps['pooled_1'][0].cpu().numpy()
ax.imshow(make_feature_grid(h1_pool), cmap='hot')
ax.set_title(f'After Pooling\n{all_reps["pooled_1"].shape}', fontsize=11)
ax.axis('off')

# 4. Conv-RBM-2 features
ax = axes[1, 0]
h2 = all_reps['conv_rbm_2'][0].cpu().numpy()
ax.imshow(make_feature_grid(h2), cmap='hot')
ax.set_title(f'Conv-RBM-2 Output\n{all_reps["conv_rbm_2"].shape}', fontsize=11)
ax.axis('off')

# 5. Flattened representation (show as 1D bar)
ax = axes[1, 1]
flat = all_reps['flattened'][0].cpu().numpy()
# Show a subset (every 1000th value)
flat_subset = flat[::1000]
ax.bar(range(len(flat_subset)), flat_subset, color='steelblue', alpha=0.7)
ax.set_title(f'Flattened (sampled)\n{all_reps["flattened"].shape}', fontsize=11)
ax.set_xlabel('Feature Index (×1000)')
ax.set_ylabel('Activation')

# 6. Final latent representation
ax = axes[1, 2]
latent = all_reps['latent'][0].cpu().numpy()
ax.bar(range(len(latent)), latent, color='darkgreen', alpha=0.7)
ax.set_title(f'FC-RBM Latent\n{all_reps["latent"].shape}', fontsize=11)
ax.set_xlabel('Hidden Unit')
ax.set_ylabel('Probability')
ax.set_ylim([0, 1])

plt.suptitle('CDBN Layer-by-Layer Representations', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(Config.OUTPUT_DIR, 'cdbn_layer_representations.png'), 
            dpi=150, bbox_inches='tight')
plt.show()

print("✓ Layer-by-layer visualization saved to output directory")

---

# ✅ FC-RBM Pretrained: Unsupervised CDBN Complete

## Summary

The complete CDBN has been built and pretrained using **unsupervised learning** with Contrastive Divergence:

### Architecture Summary

| Layer | Type | Configuration | Output Shape |
|-------|------|--------------|--------------|
| Input | OCT Image | 128×128 grayscale | [B, 1, 128, 128] |
| Layer 1 | Conv-RBM-1 | 32 filters, 7×7, Gaussian-Bernoulli | [B, 32, 122, 122] |
| Pool | Probabilistic | 2×2 sum-based | [B, 32, 61, 61] |
| Layer 2 | Conv-RBM-2 | 64 filters, 5×5, Bernoulli-Bernoulli | [B, 64, 57, 57] |
| Flatten | - | 64 × 57 × 57 | [B, 207936] |
| Layer 3 | FC-RBM | 256 hidden, Bernoulli-Bernoulli | [B, 256] |
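Counting weights per layer shows where the memory goes. A rough sketch, counting weight tensors only (bias layouts depend on the `ConvRBM`/`FCRBM` implementations and are omitted) and assuming float32 storage and a single-channel input:

```python
def weight_counts(c_in=1, c1=32, k1=7, c2=64, k2=5, n_visible_fc=207_936, n_hidden=256):
    """Number of weights in each RBM layer (biases excluded)."""
    return {
        "Conv-RBM-1": c1 * c_in * k1 * k1,     # [32, 1, 7, 7] filters
        "Conv-RBM-2": c2 * c1 * k2 * k2,       # [64, 32, 5, 5] filters
        "FC-RBM": n_visible_fc * n_hidden,     # dense [207936, 256] matrix
    }


BYTES_PER_FLOAT32 = 4
counts = weight_counts()
for layer, n in counts.items():
    mib = n * BYTES_PER_FLOAT32 / 2**20
    print(f"{layer:<11} {n:>12,} weights  ~{mib:8.2f} MiB")
```

The FC-RBM matrix holds 53,231,616 weights, about 203 MiB, while the two convolutional layers together have fewer than 53,000. If the trainer keeps a momentum velocity buffer of the same shape (an assumption about `FCRBMTrainer`), training roughly doubles that footprint, which is why the trainer is deleted once pretraining is done.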

### Training Summary

- ✅ **Conv-RBM-1**: Trained to extract low-level edge/texture features
- ✅ **Conv-RBM-2**: Trained to extract higher-level compositional features
- ✅ **FC-RBM**: Trained to learn a compact 256-dimensional latent representation

### Key Points

1. **No labels used**: all training was unsupervised
2. **No backpropagation**: Contrastive Divergence was used throughout
3. **Greedy layer-wise training**: each layer was trained in turn on the output of the already-trained, frozen layers below it
4. **Probabilistic representations**: all layer outputs are probabilities in [0, 1]
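For reference, a single CD-1 update for a Bernoulli-Bernoulli RBM (the FC-RBM case) can be sketched in NumPy as below. It is a common textbook formulation with momentum and L2 weight decay, named after the hyperparameters used in this notebook. It illustrates the idea and is not the notebook's `FCRBMTrainer`, whose exact update rule may differ. The toy data is made up.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cd1_step(v0, W, b_v, b_h, vel, lr=0.1, momentum=0.5, weight_decay=1e-4, rng=None):
    """One CD-1 update on a batch v0 [B, n_visible]; returns reconstruction MSE.

    W, b_v, b_h and the velocity dict `vel` are updated in place.
    """
    if rng is None:
        rng = np.random.default_rng()
    batch = v0.shape[0]
    # Positive phase: hidden probabilities given the data, then a binary sample
    h0_prob = sigmoid(v0 @ W + b_h)
    h0 = (rng.random(h0_prob.shape) < h0_prob).astype(v0.dtype)
    # Negative phase: one Gibbs step (probabilities are used for v1 and h1)
    v1_prob = sigmoid(h0 @ W.T + b_v)
    h1_prob = sigmoid(v1_prob @ W + b_h)
    # CD-1 approximations to the log-likelihood gradients (all computed before updating)
    grads = {
        "W": (v0.T @ h0_prob - v1_prob.T @ h1_prob) / batch - weight_decay * W,
        "b_v": (v0 - v1_prob).mean(axis=0),
        "b_h": (h0_prob - h1_prob).mean(axis=0),
    }
    # Momentum ("velocity") update, then an in-place step on each parameter
    for name, param in (("W", W), ("b_v", b_v), ("b_h", b_h)):
        vel[name] = momentum * vel[name] + lr * grads[name]
        param += vel[name]
    return float(np.mean((v0 - v1_prob) ** 2))


rng = np.random.default_rng(0)
# Toy binary data: two 16-pixel patterns, each repeated 8 times
patterns = np.array([[1] * 8 + [0] * 8, [1] * 4 + [0] * 12], dtype=float)
data = np.repeat(patterns, 8, axis=0)
n_visible, n_hidden = 16, 8
W = 0.01 * rng.standard_normal((n_visible, n_hidden))
b_v = np.zeros(n_visible)
b_h = np.zeros(n_hidden)
vel = {"W": np.zeros_like(W), "b_v": np.zeros_like(b_v), "b_h": np.zeros_like(b_h)}

errors = [cd1_step(data, W, b_v, b_h, vel, rng=rng) for _ in range(300)]
print(f"recon MSE: first={errors[0]:.4f}, last={errors[-1]:.4f}")
```

No labels enter the update, and no gradient is backpropagated through other layers: each RBM only compares data-driven and reconstruction-driven statistics of its own units.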

### Next Steps

The pretrained CDBN can now be used for:
- **Supervised fine-tuning**: Add a classifier and fine-tune with labeled data
- **Feature extraction**: Use latent representations for downstream tasks
- **Transfer learning**: Apply to related medical imaging tasks

---

# PART A: Classifier on Frozen CDBN (Feature-Based)

Now that the CDBN is pretrained using unsupervised learning, we add a supervised classification head to perform multi-class Eye OCT classification.

**Strategy:**
1. Freeze all CDBN layers (Conv-RBM-1, Pool, Conv-RBM-2, FC-RBM)
2. Add a single linear layer (Softmax Classifier)
3. Train ONLY the classifier head using labeled data
4. Use CrossEntropyLoss and Adam optimizer

This approach treats the pretrained CDBN as a **fixed feature extractor**.

## STEP 19: Softmax Classifier Head

We implement a simple linear classifier that takes the 256-dimensional latent representation from the FC-RBM and outputs class probabilities.

**Architecture:**
$$
\hat{y} = \text{softmax}(\mathbf{z} \mathbf{W}_{clf} + \mathbf{b}_{clf})
$$

where $\mathbf{z} \in \mathbb{R}^{256}$ is the latent vector from FC-RBM.

Note: PyTorch's `CrossEntropyLoss` applies log-softmax internally (it combines `LogSoftmax` and `NLLLoss`), so the model only needs a linear layer that outputs raw logits.
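A small NumPy check of the two facts this design relies on. First, cross-entropy computed from raw logits through a numerically stable log-softmax equals the negative log of the softmax probability of the true class. Second, `argmax` over logits gives the same prediction as `argmax` over probabilities, because softmax is monotonic. The logits below are arbitrary example values.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax along the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits, labels):
    """Mean cross-entropy from raw logits (mirrors CrossEntropyLoss with default settings)."""
    log_p = log_softmax(logits)
    return -log_p[np.arange(len(labels)), labels].mean()


logits = np.array([[2.0, 0.5, -1.0, 0.0],
                   [0.1, 0.2, 3.0, -0.5]])
labels = np.array([0, 2])

probs = np.exp(log_softmax(logits))
naive = -np.log(probs[np.arange(len(labels)), labels]).mean()

print(cross_entropy(logits, labels), naive)
print(logits.argmax(axis=1), probs.argmax(axis=1))
```

This is why `predict()` can take the `argmax` of the logits directly, while `predict_proba()` applies softmax only when probabilities are actually needed.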

In [None]:
# =============================================================================
# STEP 19: Softmax Classifier Head
# =============================================================================

class CDBNClassifier(nn.Module):
    """
    Softmax Classifier Head for CDBN.
    
    ARCHITECTURE:
    -------------
    Input:  Latent vector z from FC-RBM [B, n_hidden]
    Output: Class logits [B, n_classes]
    
    The classifier is a single linear layer:
        logits = z @ W_clf + b_clf
    
    Softmax is applied implicitly by CrossEntropyLoss during training.
    For inference, use softmax to get probabilities.
    
    DESIGN CHOICES:
    ---------------
    1. Single linear layer (no hidden layers) - keeps the classifier simple,
       so accuracy mainly reflects how linearly separable the frozen CDBN features are
    2. No dropout - the latent representation is already compressed
    3. Xavier initialization - suitable for linear layers with softmax
    
    Args:
        n_latent: Dimensionality of latent vector (default: 256)
        n_classes: Number of output classes
    """
    
    def __init__(self, n_latent: int = 256, n_classes: int = 4):
        super(CDBNClassifier, self).__init__()
        
        self.n_latent = n_latent
        self.n_classes = n_classes
        
        # Single linear layer: z -> logits
        self.fc = nn.Linear(n_latent, n_classes)
        
        # Xavier initialization for better convergence
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)
    
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through classifier.
        
        Args:
            z: Latent vector from FC-RBM [B, n_latent]
            
        Returns:
            logits: Raw class scores [B, n_classes]
                    (NOT probabilities - apply softmax for probabilities)
        """
        logits = self.fc(z)
        return logits
    
    def predict_proba(self, z: torch.Tensor) -> torch.Tensor:
        """
        Get class probabilities.
        
        Args:
            z: Latent vector from FC-RBM [B, n_latent]
            
        Returns:
            probs: Class probabilities [B, n_classes]
        """
        logits = self.forward(z)
        return torch.softmax(logits, dim=1)
    
    def predict(self, z: torch.Tensor) -> torch.Tensor:
        """
        Get predicted class indices.
        
        Args:
            z: Latent vector from FC-RBM [B, n_latent]
            
        Returns:
            predictions: Predicted class indices [B]
        """
        logits = self.forward(z)
        return torch.argmax(logits, dim=1)


# Determine number of classes from dataset
n_classes = len(train_dataset.classes)
class_names = train_dataset.classes

print("=" * 70)
print("CLASSIFIER CONFIGURATION")
print("=" * 70)
print(f"Number of classes: {n_classes}")
print(f"Class names: {class_names}")
print(f"Latent dimension: {fc_rbm.n_hidden}")
print("=" * 70)

# Create classifier
classifier = CDBNClassifier(
    n_latent=fc_rbm.n_hidden,
    n_classes=n_classes
).to(Config.DEVICE)

# Verify classifier
test_z = torch.randn(4, fc_rbm.n_hidden).to(Config.DEVICE)
test_logits = classifier(test_z)
test_probs = classifier.predict_proba(test_z)
test_preds = classifier.predict(test_z)

print(f"\n✓ CDBNClassifier created!")
print(f"  Input shape:  {test_z.shape}")
print(f"  Logits shape: {test_logits.shape}")
print(f"  Probs shape:  {test_probs.shape}")
print(f"  Preds shape:  {test_preds.shape}")
print(f"  Prob sum (should be 1.0): {test_probs.sum(dim=1).mean():.4f}")

## STEP 20: Dataset Wrapper for Latent Features

To speed up training, we **cache** the latent representations extracted by the frozen CDBN. This avoids recomputing features every epoch.

**Process:**
1. Pass all images through the frozen CDBN once
2. Store the (latent_vector, label) pairs on disk as `.pt` files under `config.LATENT_CACHE_DIR`
3. Create PyTorch Datasets for train/val/test splits
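As written, the `torch.load` calls in the cells below read each whole `.pt` file into memory, so the cached features end up fully in RAM. That is fine for 256-d latents (about 1 KiB per sample in float32). If truly on-demand loading is ever needed, one option is to store the latents as `.npy` files and open them with `np.load(..., mmap_mode="r")`, so rows are read from disk only when indexed. A minimal sketch, assuming NumPy files instead of the notebook's `.pt` cache (the class, function and file names are hypothetical):

```python
import os
import tempfile

import numpy as np


def save_split(cache_dir, split_name, latents, labels):
    """Write one split's latents and labels as .npy files (hypothetical layout)."""
    os.makedirs(cache_dir, exist_ok=True)
    np.save(os.path.join(cache_dir, f"{split_name}_latents.npy"), latents.astype(np.float32))
    np.save(os.path.join(cache_dir, f"{split_name}_labels.npy"), labels.astype(np.int64))


class MemmapLatentStore:
    """Reads latent rows from disk only when indexed; labels are small and loaded eagerly."""

    def __init__(self, cache_dir, split_name):
        self.latents = np.load(os.path.join(cache_dir, f"{split_name}_latents.npy"), mmap_mode="r")
        self.labels = np.load(os.path.join(cache_dir, f"{split_name}_labels.npy"))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # np.array(...) copies just this one row out of the memory-mapped file
        return np.array(self.latents[idx]), int(self.labels[idx])


rng = np.random.default_rng(0)
latents = rng.random((1000, 256)).astype(np.float32)   # stand-in for CDBN latents
labels = rng.integers(0, 4, size=1000)

with tempfile.TemporaryDirectory() as tmp:
    save_split(tmp, "train", latents, labels)
    store = MemmapLatentStore(tmp, "train")
    n = len(store)
    z, y = store[123]
    del store  # release the memory map before the directory is removed
print(n, z.shape, y)
```

A `torch.utils.data.Dataset` subclass with the same `__len__`/`__getitem__` (returning `torch.from_numpy(...)` for the row) would plug into the existing `DataLoader` code unchanged.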

In [None]:
# =============================================================================
# STEP 20: Dataset Wrapper for Cached Latent Features (Disk-Backed)
# =============================================================================

class DiskBackedLatentDataset(Dataset):
    """
    Dataset backed by latent features cached on disk by
    extract_and_cache_features_to_disk().
    
    The cached .pt tensors are loaded once into CPU memory with torch.load.
    This stays cheap because each sample is a single 256-d float vector
    (about 1 KiB in float32, 64x smaller than a 128x128 float32 input image).
    
    BENEFITS:
    ---------
    1. The CDBN forward pass over the images runs only once per split
    2. Raw images are never held in RAM; only the compact latents are
    3. Faster subsequent runs (features are reloaded from disk)
    
    Args:
        cache_dir: Directory where latent vectors are cached
        split_name: Name of the split ('train', 'val', 'test')
    """
    
    def __init__(self, cache_dir: str, split_name: str):
        self.cache_dir = cache_dir
        self.split_name = split_name
        self.latent_file = os.path.join(cache_dir, f"{split_name}_latents.pt")
        self.labels_file = os.path.join(cache_dir, f"{split_name}_labels.pt")
        
        # Both cache files must exist; otherwise run the extraction step first
        for path in (self.latent_file, self.labels_file):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Cache file not found: {path}")
        
        self.latent_vectors = torch.load(self.latent_file, map_location='cpu')
        self.labels = torch.load(self.labels_file, map_location='cpu')
        print(f"  ✓ Loaded cached {split_name} features: {len(self.labels)} samples")
    
    def __len__(self) -> int:
        return len(self.labels)
    
    def __getitem__(self, idx: int) -> tuple:
        return self.latent_vectors[idx], self.labels[idx]


class LatentFeatureDataset(Dataset):
    """
    Dataset of cached latent features from the frozen CDBN (in-memory).
    
    This dataset stores precomputed latent vectors, eliminating the need
    to run the CDBN forward pass during each training epoch.
    
    For smaller datasets where RAM is available.
    
    Args:
        latent_vectors: Tensor of latent vectors [N, n_hidden]
        labels: Tensor of class labels [N]
    """
    
    def __init__(self, latent_vectors: torch.Tensor, labels: torch.Tensor):
        self.latent_vectors = latent_vectors
        self.labels = labels
        
        assert len(latent_vectors) == len(labels), \
            f"Mismatch: {len(latent_vectors)} vectors, {len(labels)} labels"
    
    def __len__(self) -> int:
        return len(self.labels)
    
    def __getitem__(self, idx: int) -> tuple:
        return self.latent_vectors[idx], self.labels[idx]


def extract_and_cache_features_to_disk(
    image_loader: DataLoader,
    feature_extractor: CDBNFeatureExtractor,
    device: torch.device,
    cache_dir: str,
    split_name: str,
    memory_cleanup_freq: int = 50
) -> tuple:
    """
    Extract latent features batch by batch and cache them to disk.
    
    Images are streamed through the frozen CDBN one batch at a time; each batch
    of latents is moved to the CPU immediately, so neither the raw images nor
    GPU activations accumulate. If both cache files already exist they are
    loaded instead, so delete them after retraining the CDBN to avoid reusing
    stale features.
    
    Args:
        image_loader: DataLoader with (image, label) batches
        feature_extractor: Frozen CDBN feature extractor
        device: Computation device (unused here; the extractor moves
                images to its own device)
        cache_dir: Directory to cache features
        split_name: Name of the split ('train', 'val', 'test')
        memory_cleanup_freq: Call torch.cuda.empty_cache() every this many batches
        
    Returns:
        latent_vectors: CPU tensor of latent vectors [N, n_hidden]
        labels: CPU tensor of class labels [N]
    """
    os.makedirs(cache_dir, exist_ok=True)
    
    latent_file = os.path.join(cache_dir, f"{split_name}_latents.pt")
    labels_file = os.path.join(cache_dir, f"{split_name}_labels.pt")
    
    # Check if already cached
    if os.path.exists(latent_file) and os.path.exists(labels_file):
        print(f"  ✓ {split_name} features already cached, loading...")
        latent_vectors = torch.load(latent_file, map_location='cpu')
        labels = torch.load(labels_file, map_location='cpu')
        return latent_vectors, labels
    
    all_latents = []
    all_labels = []
    
    print(f"Extracting {split_name} latent features (disk-backed)...")
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(image_loader):
            # Extract latent representations
            latents = feature_extractor.extract_latent(images)
            
            # Store on CPU immediately to free GPU memory
            all_latents.append(latents.cpu())
            all_labels.append(labels.cpu())
            
            # Explicit cleanup
            del images, latents
            
            # Periodic memory cleanup
            if torch.cuda.is_available() and (batch_idx + 1) % memory_cleanup_freq == 0:
                torch.cuda.empty_cache()
            
            if (batch_idx + 1) % 50 == 0:
                print(f"  Processed batch {batch_idx + 1}/{len(image_loader)}")
    
    # Concatenate all batches
    latent_vectors = torch.cat(all_latents, dim=0)
    labels_tensor = torch.cat(all_labels, dim=0)
    
    # Free the list memory
    del all_latents, all_labels
    
    # Save to disk
    torch.save(latent_vectors, latent_file)
    torch.save(labels_tensor, labels_file)
    
    print(f"✓ Extracted and cached {len(latent_vectors)} latent vectors to disk")
    print(f"  Cache files: {latent_file}")
    
    # Clear GPU cache after extraction
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return latent_vectors, labels_tensor


# Legacy function for backward compatibility
def extract_and_cache_features(
    image_loader: DataLoader,
    feature_extractor: CDBNFeatureExtractor,
    device: torch.device
) -> tuple:
    """
    Extract latent features from all images in a dataloader (in-memory version).
    
    For smaller datasets. For large datasets (5.5 GB+), use 
    extract_and_cache_features_to_disk() instead.
    """
    all_latents = []
    all_labels = []
    
    print("Extracting latent features...")
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(image_loader):
            latents = feature_extractor.extract_latent(images)
            all_latents.append(latents.cpu())
            all_labels.append(labels)
            
            del images, latents
            
            if batch_idx % 20 == 0:
                print(f"  Processed batch {batch_idx}/{len(image_loader)}")
            
            if torch.cuda.is_available() and batch_idx % 100 == 0:
                torch.cuda.empty_cache()
    
    latent_vectors = torch.cat(all_latents, dim=0)
    labels = torch.cat(all_labels, dim=0)
    
    print(f"✓ Extracted {len(latent_vectors)} latent vectors")
    return latent_vectors, labels


print("✓ LatentFeatureDataset classes and extraction functions defined")
print(f"  Cache directory: {config.LATENT_CACHE_DIR}")

In [None]:
# =============================================================================
# Create Train/Val/Test Splits and Cache Features (Disk-Backed for 5.5GB Dataset)
# =============================================================================

from torch.utils.data import random_split

# Split training data into train and validation sets (80/20)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_subset, val_subset = random_split(
    train_dataset, 
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)  # Reproducibility
)

# Create DataLoaders for splitting (using streaming-safe settings)
train_split_loader = DataLoader(
    train_subset, 
    batch_size=config.BATCH_SIZE, 
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=config.PIN_MEMORY,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None
)
val_split_loader = DataLoader(
    val_subset, 
    batch_size=config.BATCH_SIZE, 
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=config.PIN_MEMORY
)
test_loader_for_caching = DataLoader(
    test_dataset, 
    batch_size=config.BATCH_SIZE, 
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=config.PIN_MEMORY
)

print("=" * 70)
print("DATASET SPLITS")
print("=" * 70)
print(f"Training samples:   {len(train_subset)}")
print(f"Validation samples: {len(val_subset)}")
print(f"Test samples:       {len(test_dataset)}")
print(f"Feature cache dir:  {config.LATENT_CACHE_DIR}")
print("=" * 70)

# Extract and cache latent features to DISK for each split.
# Images are streamed batch by batch and only the compact latent vectors are kept,
# so the 5.5 GB image set never sits in RAM and later runs skip the CDBN pass.

print("\n--- Caching Training Features to Disk ---")
train_latents, train_labels = extract_and_cache_features_to_disk(
    train_split_loader, cdbn_extractor, config.DEVICE,
    cache_dir=config.LATENT_CACHE_DIR,
    split_name='train',
    memory_cleanup_freq=50
)

print("\n--- Caching Validation Features to Disk ---")
val_latents, val_labels = extract_and_cache_features_to_disk(
    val_split_loader, cdbn_extractor, config.DEVICE,
    cache_dir=config.LATENT_CACHE_DIR,
    split_name='val',
    memory_cleanup_freq=50
)

print("\n--- Caching Test Features to Disk ---")
test_latents, test_labels = extract_and_cache_features_to_disk(
    test_loader_for_caching, cdbn_extractor, config.DEVICE,
    cache_dir=config.LATENT_CACHE_DIR,
    split_name='test',
    memory_cleanup_freq=50
)

# Clear GPU memory after feature extraction
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print_gpu_memory()

# Create LatentFeatureDatasets (now using cached tensors)
train_latent_dataset = LatentFeatureDataset(train_latents, train_labels)
val_latent_dataset = LatentFeatureDataset(val_latents, val_labels)
test_latent_dataset = LatentFeatureDataset(test_latents, test_labels)

# Create DataLoaders for latent features (larger batch size = CLASSIFIER_BATCH_SIZE)
train_latent_loader = DataLoader(
    train_latent_dataset, 
    batch_size=config.CLASSIFIER_BATCH_SIZE,  # 128 for classifier
    shuffle=True,
    num_workers=0,  # Features are small, no need for workers
    pin_memory=config.PIN_MEMORY
)
val_latent_loader = DataLoader(
    val_latent_dataset, 
    batch_size=config.CLASSIFIER_BATCH_SIZE, 
    shuffle=False
)
test_latent_loader = DataLoader(
    test_latent_dataset, 
    batch_size=config.CLASSIFIER_BATCH_SIZE, 
    shuffle=False
)

print("\n" + "=" * 70)
print("CACHED LATENT FEATURE DATASETS (Disk-Backed)")
print("=" * 70)
print(f"Train: {len(train_latent_dataset)} samples, {len(train_latent_loader)} batches")
print(f"Val:   {len(val_latent_dataset)} samples, {len(val_latent_loader)} batches")
print(f"Test:  {len(test_latent_dataset)} samples, {len(test_latent_loader)} batches")
print(f"Latent vector dim:    {train_latents.shape[1]}")
print(f"Classifier batch size: {config.CLASSIFIER_BATCH_SIZE}")
print(f"Cache location:        {config.LATENT_CACHE_DIR}")
print("=" * 70)
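The cell above relies on `extract_and_cache_features_to_disk` and `LatentFeatureDataset`, which are defined earlier in the notebook. As a rough illustration of what "disk-backed" caching can mean, the sketch below streams latent batches into NumPy memory-mapped `.npy` files and then reopens them lazily, so only one batch sits in RAM at a time. The helper names `write_latents_memmap` and `open_latents` are hypothetical; this is not the notebook's own implementation.

```python
import os
import tempfile

import numpy as np


def write_latents_memmap(batches, n_samples, latent_dim, path):
    """Stream (latent_batch, label_batch) pairs into .npy files on disk.

    Hypothetical helper: `batches` yields NumPy arrays shaped [B, latent_dim]
    and [B]; only the current batch is held in RAM.
    """
    feats = np.lib.format.open_memmap(
        path + "_x.npy", mode="w+", dtype=np.float32, shape=(n_samples, latent_dim)
    )
    labels = np.lib.format.open_memmap(
        path + "_y.npy", mode="w+", dtype=np.int64, shape=(n_samples,)
    )
    offset = 0
    for x, y in batches:
        b = x.shape[0]
        feats[offset:offset + b] = x
        labels[offset:offset + b] = y
        offset += b
    feats.flush()
    labels.flush()
    del feats, labels  # release the write handles
    return offset


def open_latents(path):
    """Reopen the cached arrays read-only; pages are read from disk on access."""
    x = np.load(path + "_x.npy", mmap_mode="r")
    y = np.load(path + "_y.npy", mmap_mode="r")
    return x, y


# Tiny demo with random data in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "train")
    rng = np.random.default_rng(0)
    fake_batches = [(rng.random((4, 256), dtype=np.float32), np.arange(4)) for _ in range(3)]
    n_written = write_latents_memmap(fake_batches, 12, 256, prefix)
    x, y = open_latents(prefix)
    print(n_written, x.shape, y.dtype)
    del x, y  # release the memory maps before the directory is removed
```

A `Dataset` wrapping the returned arrays can index them per sample, which keeps peak RAM roughly independent of the split size.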

## STEP 21: Supervised Training (Frozen CDBN)

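Train only the classifier head on the cached latent features.

**Training Configuration:**
- Loss: CrossEntropyLoss (applies log-softmax internally, so the head should output raw logits)
- Optimizer: Adam (learning rate 1e-3, the Adam default)
- Scheduler: ReduceLROnPlateau on validation accuracy (factor 0.5, patience 5)
- Epochs: 30 by default (the training cell below sets `CLASSIFIER_EPOCHS = 20`)
- Metrics: training and validation loss and accuracy; the weights with the best validation accuracy are restored at the end

The CDBN remains **completely frozen**: only the classifier weights are updated, and the CDBN is not run at all in this stage because its features were cached above.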
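Because `nn.CrossEntropyLoss` applies log-softmax itself, the per-sample loss is just "log-sum-exp of the logits minus the logit of the true class". The plain-Python sketch below mirrors that definition (it is not PyTorch's code) and checks it against `-log(softmax)`:

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: log-softmax followed by negative log-likelihood.

    Uses the log-sum-exp trick (subtracting the max logit) for numerical stability.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 0.5, -1.0, 0.0]   # one sample, four classes
target = 0
loss = cross_entropy(logits, target)
# Same value as taking -log of the softmax probability of the true class
assert math.isclose(loss, -math.log(softmax(logits)[target]))
print(f"loss = {loss:.4f}")
```

If the classifier head applied softmax before the loss, the probabilities would be passed through softmax a second time, which flattens the gradients.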

In [None]:
# =============================================================================
# STEP 21: Supervised Training Loop for Classifier Head
# =============================================================================

def train_classifier(
    classifier: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 30,
    learning_rate: float = 1e-3,
    device: torch.device = None
) -> dict:
    """
    Train the classifier head on cached latent features.
    
    TRAINING PROCEDURE:
    -------------------
    1. Forward pass: z -> logits
    2. Compute CrossEntropyLoss
    3. Backward pass (only classifier gradients)
    4. Adam optimizer step
    5. Evaluate on validation set
    
    Args:
        classifier: CDBNClassifier instance
        train_loader: DataLoader with (latent, label) pairs
        val_loader: Validation DataLoader
        n_epochs: Number of training epochs
        learning_rate: Learning rate for Adam
        device: Computation device
        
    Returns:
        history: Dictionary with training metrics
    """
    if device is None:
        device = torch.device('cpu')
    
    classifier = classifier.to(device)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)
    
    # Learning rate scheduler (reduce on plateau)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
    
    # History tracking
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    best_val_acc = 0.0
    best_model_state = None
    
    print("=" * 70)
    print("CLASSIFIER TRAINING (FROZEN CDBN)")
    print("=" * 70)
    print(f"Epochs: {n_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Optimizer: Adam")
    print(f"Loss: CrossEntropyLoss")
    print("=" * 70)
    
    for epoch in range(n_epochs):
        # =====================================================================
        # TRAINING PHASE
        # =====================================================================
        classifier.train()
        train_losses = []
        train_correct = 0
        train_total = 0
        
        for latents, labels in train_loader:
            latents = latents.to(device)
            labels = labels.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            logits = classifier(latents)
            loss = criterion(logits, labels)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Track metrics
            train_losses.append(loss.item())
            _, predicted = torch.max(logits, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
        
        train_loss = np.mean(train_losses)
        train_acc = train_correct / train_total
        
        # =====================================================================
        # VALIDATION PHASE
        # =====================================================================
        classifier.eval()
        val_losses = []
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for latents, labels in val_loader:
                latents = latents.to(device)
                labels = labels.to(device)
                
                logits = classifier(latents)
                loss = criterion(logits, labels)
                
                val_losses.append(loss.item())
                _, predicted = torch.max(logits, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)
        
        val_loss = np.mean(val_losses)
        val_acc = val_correct / val_total
        
        # Update scheduler
        scheduler.step(val_acc)
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
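            # Clone the tensors: state_dict() shares storage with the live parameters
            best_model_state = {k: v.detach().clone() for k, v in classifier.state_dict().items()}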
        
        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Print progress
        print(f"Epoch [{epoch+1:2d}/{n_epochs}] | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    
    # Restore best model
    if best_model_state is not None:
        classifier.load_state_dict(best_model_state)
        print(f"\n✓ Restored best model with Val Acc: {best_val_acc:.4f}")
    
    print("=" * 70)
    print("TRAINING COMPLETE!")
    print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    print("=" * 70)
    
    return history


print("✓ train_classifier function defined")
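The best-model snapshot in `train_classifier` has to copy the tensors, not just the dictionary: `state_dict()` returns tensors that share storage with the live parameters, and `optimizer.step()` updates those parameters in place. The toy `nn.Linear` below (not part of the pipeline) shows the difference between a shallow dict copy and a cloned snapshot:

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

shallow = model.state_dict().copy()                                    # new dict, same tensors
cloned = {k: v.detach().clone() for k, v in model.state_dict().items()}  # independent tensors
snapshot = copy.deepcopy(model.state_dict())                           # also independent

# Simulate an optimizer step that changes the weights in place
with torch.no_grad():
    model.weight.add_(1.0)

print(torch.equal(shallow["weight"], model.weight))  # True: the "snapshot" moved too
print(torch.equal(cloned["weight"], model.weight))   # False: the clone kept the old values
```

Either the dict comprehension with `.clone()` or `copy.deepcopy` gives a snapshot that survives further training.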

In [None]:
# =============================================================================
# Train the Classifier Head
# =============================================================================

# Training configuration
CLASSIFIER_EPOCHS = 20  # Below the function default of 30; time-limit setting kept from the Kaggle version
CLASSIFIER_LR = 1e-3

# Train classifier on cached latent features
frozen_history = train_classifier(
    classifier=classifier,
    train_loader=train_latent_loader,
    val_loader=val_latent_loader,
    n_epochs=CLASSIFIER_EPOCHS,
    learning_rate=CLASSIFIER_LR,
    device=Config.DEVICE
)
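`train_classifier` passes `mode='max', factor=0.5, patience=5` to `ReduceLROnPlateau`, so the learning rate is halved only after validation accuracy has failed to improve for more than 5 consecutive epochs (with the default relative threshold of 1e-4). A throwaway parameter and a flat accuracy curve make the timing visible:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=5
)

lrs = []
for epoch in range(8):
    scheduler.step(0.80)   # validation accuracy that never improves after the first epoch
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

Here the learning rate stays at 1e-3 for the first six calls and drops to 5e-4 on the seventh: one improving epoch followed by six without improvement. A 20-epoch run can therefore see at most three halvings.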

In [None]:
# =============================================================================
# Visualize Training Progress (Frozen CDBN)
# =============================================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Loss curves
ax1 = axes[0]
epochs = range(1, len(frozen_history['train_loss']) + 1)
ax1.plot(epochs, frozen_history['train_loss'], 'b-', linewidth=2, label='Train Loss')
ax1.plot(epochs, frozen_history['val_loss'], 'r-', linewidth=2, label='Val Loss')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Classifier Training Loss (Frozen CDBN)', fontsize=14)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Accuracy curves
ax2 = axes[1]
ax2.plot(epochs, frozen_history['train_acc'], 'b-', linewidth=2, label='Train Acc')
ax2.plot(epochs, frozen_history['val_acc'], 'r-', linewidth=2, label='Val Acc')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy', fontsize=12)
ax2.set_title('Classifier Training Accuracy (Frozen CDBN)', fontsize=14)
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

plt.tight_layout()
plt.savefig(os.path.join(Config.OUTPUT_DIR, 'frozen_cdbn_classifier_training.png'), 
            dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Best Validation Accuracy: {max(frozen_history['val_acc']):.4f}")

# PART B: End-to-End Fine-Tuning (Optional)

Now we wrap the entire CDBN + Classifier into a single `nn.Module` and fine-tune all layers with backpropagation.

**Key Differences from Frozen Training:**
1. All CDBN parameters are **unfrozen**
2. Use a **smaller learning rate** (1e-4) to avoid destroying pretrained features
3. Monitor for **overfitting**, since every CDBN weight is now trainable (the FC-RBM weight matrix alone is 207,936 × 256 ≈ 53.2M parameters)
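To put the parameter count in numbers, the weight shapes listed in the Step 22 `EndToEndCDBN` docstring give the counts below. Biases are excluded because their exact layout in `ConvRBM`/`FCRBM` is not shown here, and `n_classes = 4` is a placeholder, so `e2e_model.count_parameters()` will report somewhat different totals:

```python
# Weight counts implied by the layer shapes in the EndToEndCDBN docstring.
# Biases are left out; n_classes is a placeholder (use len(class_names)).
conv1 = 32 * 1 * 7 * 7          # Conv-RBM-1 filters
conv2 = 64 * 32 * 5 * 5         # Conv-RBM-2 filters
fc = 207_936 * 256              # FC-RBM weight matrix
n_classes = 4                   # placeholder value
head = 256 * n_classes          # classifier weights (excluding bias)

total_cdbn = conv1 + conv2 + fc
print(f"Conv-RBM-1: {conv1:,}")
print(f"Conv-RBM-2: {conv2:,}")
print(f"FC-RBM:     {fc:,}")
print(f"Head only:  {head:,}")
print(f"FC-RBM share of CDBN weights: {fc / total_cdbn:.2%}")
```

Almost all of the new trainable capacity sits in the FC-RBM, which is where overfitting pressure is most likely to come from.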

## STEP 22: End-to-End CDBN Model

We create a unified `nn.Module` that encapsulates the entire pipeline:

**Pipeline:**
$$
\text{Image} \xrightarrow{\text{Conv-RBM-1}} \xrightarrow{\text{Pool}} \xrightarrow{\text{Conv-RBM-2}} \xrightarrow{\text{Flatten}} \xrightarrow{\text{FC-RBM}} \xrightarrow{\text{Classifier}} \text{Logits}
$$

All weights start from the pretrained RBM layers and the classifier head trained in Step 21.
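The tensor sizes in this pipeline follow from "valid" convolution arithmetic, $\text{out} = \text{in} - k + 1$, and non-overlapping 2×2 pooling. A small helper (hypothetical, not part of the notebook) reproduces the 207,936 flatten size shown in the model docstring and is handy if the input resolution or kernel sizes change:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution (standard formula, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_cdbn_shapes(img=128, k1=7, c1=32, pool=2, k2=5, c2=64):
    """Per-sample shapes through Conv-RBM-1, pooling, Conv-RBM-2 and flatten.

    Assumes valid convolutions (no padding, stride 1) and non-overlapping
    pool x pool blocks, as the docstring's numbers imply.
    """
    h1 = conv_out(img, k1)          # Conv-RBM-1
    p1 = h1 // pool                 # 2x2 probabilistic pooling
    h2 = conv_out(p1, k2)           # Conv-RBM-2
    flat = c2 * h2 * h2             # Flatten
    return {"h1": (c1, h1, h1), "pooled": (c1, p1, p1), "h2": (c2, h2, h2), "flat": flat}


print(trace_cdbn_shapes())
# {'h1': (32, 122, 122), 'pooled': (32, 61, 61), 'h2': (64, 57, 57), 'flat': 207936}
```

Changing the image size changes `flat`, so the FC-RBM input dimension (`flatten_dim`) must be updated to match.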

In [None]:
# =============================================================================
# STEP 22: End-to-End CDBN Model with Classifier
# =============================================================================

class EndToEndCDBN(nn.Module):
    """
    End-to-End Convolutional Deep Belief Network for Classification.
    
    COMPLETE ARCHITECTURE:
    ----------------------
    
    Input: OCT Image [B, 1, 128, 128]
           ↓
    ┌──────────────────────────────────────┐
    │  Conv-RBM-1 (Pretrained)             │
    │  W: [32, 1, 7, 7], uses only         │
    │  hidden_probabilities path           │
    │  Output: [B, 32, 122, 122]           │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  Probabilistic Pooling               │
    │  2×2 sum-based pooling               │
    │  Output: [B, 32, 61, 61]             │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  Conv-RBM-2 (Pretrained)             │
    │  W: [64, 32, 5, 5]                   │
    │  Output: [B, 64, 57, 57]             │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  Flatten                             │
    │  Output: [B, 207936]                 │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  FC-RBM (Pretrained)                 │
    │  Linear: 207936 → 256                │
    │  Output: [B, 256]                    │
    └──────────────────────────────────────┘
           ↓
    ┌──────────────────────────────────────┐
    │  Classifier (Trained)                │
    │  Linear: 256 → n_classes             │
    │  Output: [B, n_classes]              │
    └──────────────────────────────────────┘
    
    FINE-TUNING CONSIDERATIONS:
    ---------------------------
    1. All pretrained weights are copied (not referenced)
    2. Pooling is non-parametric (no learnable weights, but gradients pass through it)
    3. For fine-tuning, all RBM layers use their forward paths
       (hidden_probabilities) which are differentiable
    4. Use small learning rate to preserve pretrained features
    
    Args:
        conv_rbm_1: Pretrained ConvRBM (layer 1)
        pooling: ProbabilisticPooling module
        conv_rbm_2: Pretrained ConvRBM (layer 2)
        fc_rbm: Pretrained FCRBM
        classifier: Trained CDBNClassifier
        flatten_dim: Size of the flattened Conv-RBM-2 output (64 * 57 * 57 = 207936)
    """
    
    def __init__(
        self,
        conv_rbm_1: ConvRBM,
        pooling: ProbabilisticPooling,
        conv_rbm_2: ConvRBM,
        fc_rbm: FCRBM,
        classifier: CDBNClassifier,
        flatten_dim: int
    ):
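        super(EndToEndCDBN, self).__init__()
        import copy  # local import keeps this cell self-contained
        
        # Deep-copy the pretrained layers so fine-tuning leaves the originals
        # (and the frozen classifier evaluated in Step 24) untouched
        self.conv_rbm_1 = copy.deepcopy(conv_rbm_1)
        self.pooling = copy.deepcopy(pooling)
        self.conv_rbm_2 = copy.deepcopy(conv_rbm_2)
        self.fc_rbm = copy.deepcopy(fc_rbm)
        self.classifier = copy.deepcopy(classifier)
        self.flatten_dim = flatten_dim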
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through entire CDBN + Classifier.
        
        Args:
            x: Input images [B, 1, 128, 128]
            
        Returns:
            logits: Class logits [B, n_classes]
        """
        # Layer 1: Conv-RBM-1 (differentiable path)
        h1 = self.conv_rbm_1.hidden_probabilities(x)
        
        # Layer 2: Probabilistic Pooling
        h1_pooled = self.pooling(h1)
        
        # Layer 3: Conv-RBM-2
        h2 = self.conv_rbm_2.hidden_probabilities(h1_pooled)
        
        # Layer 4: Flatten
        flat = h2.view(h2.size(0), -1)
        
        # Layer 5: FC-RBM
        z = self.fc_rbm.hidden_probabilities(flat)
        
        # Layer 6: Classifier
        logits = self.classifier(z)
        
        return logits
    
    def get_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Get latent representation before classifier."""
        with torch.no_grad():
            h1 = self.conv_rbm_1.hidden_probabilities(x)
            h1_pooled = self.pooling(h1)
            h2 = self.conv_rbm_2.hidden_probabilities(h1_pooled)
            flat = h2.view(h2.size(0), -1)
            z = self.fc_rbm.hidden_probabilities(flat)
        return z
    
    def freeze_cdbn(self):
        """Freeze all CDBN layers (for feature-based training)."""
        for module in [self.conv_rbm_1, self.conv_rbm_2, self.fc_rbm]:
            for param in module.parameters():
                param.requires_grad = False
        print("✓ CDBN layers frozen")
    
    def unfreeze_all(self):
        """Unfreeze all layers (for end-to-end fine-tuning)."""
        for param in self.parameters():
            param.requires_grad = True
        print("✓ All layers unfrozen for fine-tuning")
    
    def count_parameters(self) -> dict:
        """Count trainable and total parameters."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'total': total,
            'trainable': trainable,
            'frozen': total - trainable
        }


# Create end-to-end model from pretrained components
e2e_model = EndToEndCDBN(
    conv_rbm_1=conv_rbm_1,
    pooling=prob_pool,
    conv_rbm_2=conv_rbm_2,
    fc_rbm=fc_rbm,
    classifier=classifier,
    flatten_dim=FLAT_DIM
).to(Config.DEVICE)

# Verify the model with a dummy batch (no_grad avoids building an autograd graph)
test_input = torch.randn(4, 1, 128, 128).to(Config.DEVICE)
with torch.no_grad():
    test_output = e2e_model(test_input)

print("=" * 70)
print("END-TO-END CDBN MODEL")
print("=" * 70)
print(f"Input shape:  {test_input.shape}")
print(f"Output shape: {test_output.shape}")
params = e2e_model.count_parameters()
print(f"Total parameters:     {params['total']:,}")
print(f"Trainable parameters: {params['trainable']:,}")
print("=" * 70)

## STEP 23: Fine-Tuning with Backpropagation

Now we unfreeze all layers and fine-tune the entire network with a small learning rate.

**Fine-Tuning Strategy:**
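- Learning rate: 1e-4 (10x smaller than classifier training), halved by ReduceLROnPlateau (patience 3) when validation accuracy plateaus
- All parameters updated via backpropagation, with gradient-norm clipping at 1.0
- Monitor for overfitting (train acc >> val acc; the loop flags a gap above 0.1)
- Best-checkpoint selection on validation accuracy: training runs for all epochs, and the best weights are restored at the end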
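The fine-tuning loop keeps the best validation checkpoint but always runs the full `n_epochs`. If actual early stopping is wanted, a minimal patience-based helper could look like the sketch below; the `EarlyStopping` class and its `min_delta` parameter are assumptions, not part of the notebook.

```python
class EarlyStopping:
    """Patience-based early stopping on a metric to maximize (e.g. val accuracy).

    Hypothetical helper. `min_delta` is the smallest increase that counts
    as an improvement.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
history = [0.70, 0.78, 0.80, 0.79, 0.80, 0.78]
for epoch, val_acc in enumerate(history, start=1):
    if stopper.step(val_acc):
        print(f"Stopping after epoch {epoch} (best val acc {stopper.best:.2f})")
        break
```

Inside `finetune_e2e_model`, the check would sit right after `scheduler.step(val_acc)`, with `break` ending the epoch loop; the existing best-checkpoint restore still runs afterwards. Using a stopping patience larger than the scheduler's patience gives the reduced learning rate a chance to help first.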

In [None]:
# =============================================================================
# STEP 23: End-to-End Fine-Tuning Training Loop
# =============================================================================

def finetune_e2e_model(
    model: EndToEndCDBN,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 20,
    learning_rate: float = 1e-4,
    device: torch.device = None
) -> dict:
    """
    Fine-tune the entire CDBN + Classifier end-to-end.
    
    FINE-TUNING PROCEDURE:
    ----------------------
    1. Unfreeze all parameters
    2. Use small learning rate to preserve pretrained features
    3. Standard backpropagation through entire network
    4. Monitor overfitting behavior
    
    Args:
        model: EndToEndCDBN instance
        train_loader: DataLoader with (image, label) pairs
        val_loader: Validation DataLoader
        n_epochs: Number of fine-tuning epochs
        learning_rate: Learning rate (should be small, e.g., 1e-4)
        device: Computation device
        
    Returns:
        history: Dictionary with training metrics
    """
    if device is None:
        device = torch.device('cpu')
    
    model = model.to(device)
    
    # Unfreeze all layers for fine-tuning
    model.unfreeze_all()
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    
    # Use different learning rates for different parts (optional refinement)
    # Here we use a single small LR for simplicity
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )
    
    # History tracking
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    best_val_acc = 0.0
    best_model_state = None
    
    print("=" * 70)
    print("END-TO-END FINE-TUNING")
    print("=" * 70)
    params = model.count_parameters()
    print(f"Epochs: {n_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Trainable Parameters: {params['trainable']:,}")
    print("=" * 70)
    
    for epoch in range(n_epochs):
        # =====================================================================
        # TRAINING PHASE
        # =====================================================================
        model.train()
        train_losses = []
        train_correct = 0
        train_total = 0
        
        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass through entire network
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            
            # Backward pass (gradients flow through all layers)
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            # Track metrics
            train_losses.append(loss.item())
            _, predicted = torch.max(logits, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
            
            # Progress update
            if batch_idx % 20 == 0:
                print(f"  Epoch [{epoch+1}/{n_epochs}] "
                      f"Batch [{batch_idx}/{len(train_loader)}] "
                      f"Loss: {loss.item():.4f}")
        
        train_loss = np.mean(train_losses)
        train_acc = train_correct / train_total
        
        # =====================================================================
        # VALIDATION PHASE
        # =====================================================================
        model.eval()
        val_losses = []
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                
                logits = model(images)
                loss = criterion(logits, labels)
                
                val_losses.append(loss.item())
                _, predicted = torch.max(logits, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)
        
        val_loss = np.mean(val_losses)
        val_acc = val_correct / val_total
        
        # Update scheduler
        scheduler.step(val_acc)
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
        
        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Print epoch summary
        overfit_gap = train_acc - val_acc
        overfit_warning = " ⚠️ OVERFITTING" if overfit_gap > 0.1 else ""
        
        print(f"▶ Epoch [{epoch+1:2d}/{n_epochs}] | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | "
              f"Gap: {overfit_gap:.4f}{overfit_warning}")
        print("-" * 70)
    
    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\n✓ Restored best model with Val Acc: {best_val_acc:.4f}")
    
    print("=" * 70)
    print("FINE-TUNING COMPLETE!")
    print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    print("=" * 70)
    
    return history


print("✓ finetune_e2e_model function defined")
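The comment in `finetune_e2e_model` mentions per-layer learning rates as an optional refinement. PyTorch optimizers support this through parameter groups. The sketch below uses small stand-in modules (hypothetical, not the CDBN layers) and gives the pretrained body a 10x lower rate than the head; that ratio is a common heuristic, not a tuned value.

```python
import torch
import torch.nn as nn

# Stand-ins for the pretrained CDBN body and the classifier head (hypothetical shapes;
# the head size assumes 16x16 inputs, which give 10x10 maps after a 7x7 conv).
backbone = nn.Sequential(nn.Conv2d(1, 8, 7), nn.Sigmoid(), nn.Flatten())
head = nn.Linear(8 * 10 * 10, 4)

base_lr = 1e-4
optimizer = torch.optim.Adam([
    {"params": backbone.parameters(), "lr": base_lr * 0.1},  # pretrained layers: move slowly
    {"params": head.parameters(), "lr": base_lr},             # head: full fine-tuning LR
])

for group in optimizer.param_groups:
    print(group["lr"], sum(p.numel() for p in group["params"]))
```

With `EndToEndCDBN`, the first group would collect the parameters of `conv_rbm_1`, `conv_rbm_2`, and `fc_rbm`, and the second those of `classifier`. `ReduceLROnPlateau` scales every group's rate by the same factor, so the ratio between groups is preserved.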

In [None]:
# =============================================================================
# Create DataLoaders for Image-based Training (Fine-tuning)
# =============================================================================

# Need image-based loaders for fine-tuning (not cached features)
train_image_loader = DataLoader(train_subset, batch_size=Config.BATCH_SIZE, shuffle=True)
val_image_loader = DataLoader(val_subset, batch_size=Config.BATCH_SIZE, shuffle=False)

print("✓ Image DataLoaders created for fine-tuning")
print(f"  Train batches: {len(train_image_loader)}")
print(f"  Val batches:   {len(val_image_loader)}")
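The image loaders above use PyTorch's default single-process loading, whereas the configuration table at the top of the notebook lists `NUM_WORKERS = 4` and `PERSISTENT_WORKERS = True`. A sketch of a loader that applies those settings, using a random stand-in dataset instead of `train_subset` (the local `NUM_WORKERS` constant stands in for whatever attribute the notebook's config uses):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 64 fake single-channel 128x128 images with 4 class labels
images = torch.rand(64, 1, 128, 128)
labels = torch.randint(0, 4, (64,))
dataset = TensorDataset(images, labels)

NUM_WORKERS = 4  # from the configuration table; 0 disables worker processes
use_cuda = torch.cuda.is_available()

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=use_cuda,                 # page-locked host memory speeds host-to-GPU copies
    persistent_workers=NUM_WORKERS > 0,  # only valid when worker processes exist
)
print(len(loader), loader.num_workers, loader.persistent_workers)
```

`persistent_workers=True` raises a `ValueError` when `num_workers` is 0, hence the guard. On platforms that spawn workers (Windows, some notebook setups), the dataset must be picklable for multi-worker loading to work.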

In [None]:
# =============================================================================
# Fine-Tune the End-to-End Model
# =============================================================================

# Fine-tuning configuration
FINETUNE_EPOCHS = 10  # Below the function default of 20; time-limit setting kept from the Kaggle version
FINETUNE_LR = 1e-4  # Small learning rate to preserve pretrained features

# Fine-tune the model
finetune_history = finetune_e2e_model(
    model=e2e_model,
    train_loader=train_image_loader,
    val_loader=val_image_loader,
    n_epochs=FINETUNE_EPOCHS,
    learning_rate=FINETUNE_LR,
    device=Config.DEVICE
)

In [None]:
# =============================================================================
# Visualize Fine-Tuning Results
# =============================================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Fine-tuning Loss
ax1 = axes[0, 0]
epochs_ft = range(1, len(finetune_history['train_loss']) + 1)
ax1.plot(epochs_ft, finetune_history['train_loss'], 'b-', linewidth=2, label='Train Loss')
ax1.plot(epochs_ft, finetune_history['val_loss'], 'r-', linewidth=2, label='Val Loss')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Fine-Tuning Loss', fontsize=14)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Fine-tuning Accuracy
ax2 = axes[0, 1]
ax2.plot(epochs_ft, finetune_history['train_acc'], 'b-', linewidth=2, label='Train Acc')
ax2.plot(epochs_ft, finetune_history['val_acc'], 'r-', linewidth=2, label='Val Acc')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy', fontsize=12)
ax2.set_title('Fine-Tuning Accuracy', fontsize=14)
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

# Plot 3: Comparison - Frozen vs Fine-tuned (Validation Accuracy)
ax3 = axes[1, 0]
epochs_frozen = range(1, len(frozen_history['val_acc']) + 1)
ax3.plot(epochs_frozen, frozen_history['val_acc'], 'g-o', linewidth=2, 
         markersize=4, label='Frozen CDBN', alpha=0.7)
ax3.plot(epochs_ft, finetune_history['val_acc'], 'b-s', linewidth=2, 
         markersize=4, label='Fine-tuned', alpha=0.7)
ax3.set_xlabel('Epoch', fontsize=12)
ax3.set_ylabel('Validation Accuracy', fontsize=12)
ax3.set_title('Frozen vs Fine-tuned Comparison', fontsize=14)
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_ylim([0, 1])

# Plot 4: Overfitting Analysis
ax4 = axes[1, 1]
train_val_gap = [t - v for t, v in zip(finetune_history['train_acc'], finetune_history['val_acc'])]
ax4.bar(epochs_ft, train_val_gap, color='orange', alpha=0.7)
ax4.axhline(y=0.1, color='red', linestyle='--', label='Overfitting Threshold')
ax4.set_xlabel('Epoch', fontsize=12)
ax4.set_ylabel('Train - Val Accuracy Gap', fontsize=12)
ax4.set_title('Overfitting Analysis', fontsize=14)
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(Config.OUTPUT_DIR, 'finetune_results.png'), 
            dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Fine-tuning Results:")
print(f"  Best Frozen Val Acc:     {max(frozen_history['val_acc']):.4f}")
print(f"  Best Fine-tuned Val Acc: {max(finetune_history['val_acc']):.4f}")

# PART C: Evaluation

Final evaluation on the held-out test set with comprehensive metrics.

## STEP 24: Test Set Evaluation

Comprehensive evaluation of both models:
1. **Frozen CDBN + Classifier** (feature-based)
2. **Fine-tuned End-to-End CDBN**

Metrics computed:
- Overall accuracy
- Per-class accuracy (equivalent to per-class recall)
- Confusion matrix
- Classification report
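Per-class accuracy, as computed in the evaluation functions, is the same quantity that `classification_report` calls recall: correct predictions for a class divided by that class's support. Deriving all of these from one confusion matrix makes the relationship explicit. The matrix below is made-up illustrative data, not a result from this notebook.

```python
import numpy as np


def per_class_metrics(cm):
    """Overall accuracy plus per-class recall, precision, and F1 from a confusion matrix.

    Rows are true labels and columns are predictions, matching
    sklearn.metrics.confusion_matrix. Classes with no support or no
    predictions get 0.0 instead of dividing by zero.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    support = cm.sum(axis=1)        # samples that truly belong to each class
    predicted = cm.sum(axis=0)      # samples predicted as each class
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    overall = tp.sum() / cm.sum()
    return overall, recall, precision, f1


cm = [[50, 5, 0],
      [10, 30, 10],
      [0, 0, 45]]
overall, recall, precision, f1 = per_class_metrics(cm)
print(f"overall accuracy: {overall:.3f}")
print("per-class accuracy (recall):", np.round(recall, 3))
print("precision:", np.round(precision, 3))
```

Overall accuracy is a support-weighted average of the per-class recalls, so on an imbalanced test set a weak minority class can hide behind a high overall number.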

In [None]:
# =============================================================================
# STEP 24: Test Set Evaluation Functions
# =============================================================================

from sklearn.metrics import confusion_matrix, classification_report, accuracy_score


def evaluate_frozen_model(
    classifier: CDBNClassifier,
    test_latent_loader: DataLoader,
    class_names: list,
    device: torch.device
) -> dict:
    """
    Evaluate the frozen CDBN + Classifier on test set.
    
    Args:
        classifier: Trained CDBNClassifier
        test_latent_loader: DataLoader with cached latent features
        class_names: List of class names
        device: Computation device
        
    Returns:
        metrics: Dictionary with evaluation metrics
    """
    classifier.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for latents, labels in test_latent_loader:
            latents = latents.to(device)
            preds = classifier.predict(latents)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    
    # Compute metrics
    overall_acc = accuracy_score(all_labels, all_preds)
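    # Pass labels explicitly so the matrix and report always cover every class,
    # even if a class is missing from this split's labels or predictions
    class_ids = list(range(len(class_names)))
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=class_ids)
    class_report = classification_report(all_labels, all_preds,
                                         labels=class_ids,
                                         target_names=class_names,
                                         output_dict=True,
                                         zero_division=0)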
    
    # Per-class accuracy
    per_class_acc = {}
    for i, name in enumerate(class_names):
        mask = all_labels == i
        if mask.sum() > 0:
            per_class_acc[name] = (all_preds[mask] == i).mean()
    
    return {
        'overall_accuracy': overall_acc,
        'confusion_matrix': conf_matrix,
        'classification_report': class_report,
        'per_class_accuracy': per_class_acc,
        'predictions': all_preds,
        'labels': all_labels
    }


def evaluate_e2e_model(
    model: EndToEndCDBN,
    test_image_loader: DataLoader,
    class_names: list,
    device: torch.device
) -> dict:
    """
    Evaluate the end-to-end CDBN on test set.
    
    Args:
        model: EndToEndCDBN instance
        test_image_loader: DataLoader with test images
        class_names: List of class names
        device: Computation device
        
    Returns:
        metrics: Dictionary with evaluation metrics
    """
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in test_image_loader:
            images = images.to(device)
            logits = model(images)
            preds = torch.argmax(logits, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    
    # Compute metrics
    overall_acc = accuracy_score(all_labels, all_preds)
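    # Pass labels explicitly so the matrix and report always cover every class,
    # even if a class is missing from this split's labels or predictions
    class_ids = list(range(len(class_names)))
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=class_ids)
    class_report = classification_report(all_labels, all_preds,
                                         labels=class_ids,
                                         target_names=class_names,
                                         output_dict=True,
                                         zero_division=0)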
    
    # Per-class accuracy
    per_class_acc = {}
    for i, name in enumerate(class_names):
        mask = all_labels == i
        if mask.sum() > 0:
            per_class_acc[name] = (all_preds[mask] == i).mean()
    
    return {
        'overall_accuracy': overall_acc,
        'confusion_matrix': conf_matrix,
        'classification_report': class_report,
        'per_class_accuracy': per_class_acc,
        'predictions': all_preds,
        'labels': all_labels
    }


print("✓ Evaluation functions defined")

In [None]:
# =============================================================================
# Evaluate Frozen CDBN + Classifier
# =============================================================================

print("=" * 70)
print("EVALUATING FROZEN CDBN + CLASSIFIER")
print("=" * 70)

frozen_metrics = evaluate_frozen_model(
    classifier=classifier,
    test_latent_loader=test_latent_loader,
    class_names=class_names,
    device=Config.DEVICE
)

print(f"\nOverall Test Accuracy: {frozen_metrics['overall_accuracy']:.4f}")
print("\nPer-Class Accuracy:")
for name, acc in frozen_metrics['per_class_accuracy'].items():
    print(f"  {name}: {acc:.4f}")

print("\nClassification Report:")
for name in class_names:
    metrics = frozen_metrics['classification_report'][name]
    print(f"  {name}:")
    print(f"    Precision: {metrics['precision']:.4f}")
    print(f"    Recall:    {metrics['recall']:.4f}")
    print(f"    F1-Score:  {metrics['f1-score']:.4f}")
    print(f"    Support:   {metrics['support']}")

In [None]:
# =============================================================================
# Evaluate Fine-Tuned End-to-End Model
# =============================================================================

# Create test image loader
test_image_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

print("=" * 70)
print("EVALUATING FINE-TUNED END-TO-END CDBN")
print("=" * 70)

e2e_metrics = evaluate_e2e_model(
    model=e2e_model,
    test_image_loader=test_image_loader,
    class_names=class_names,
    device=Config.DEVICE
)

print(f"\nOverall Test Accuracy: {e2e_metrics['overall_accuracy']:.4f}")
print("\nPer-Class Accuracy:")
for name, acc in e2e_metrics['per_class_accuracy'].items():
    print(f"  {name}: {acc:.4f}")

print("\nClassification Report:")
for name in class_names:
    metrics = e2e_metrics['classification_report'][name]
    print(f"  {name}:")
    print(f"    Precision: {metrics['precision']:.4f}")
    print(f"    Recall:    {metrics['recall']:.4f}")
    print(f"    F1-Score:  {metrics['f1-score']:.4f}")
    print(f"    Support:   {metrics['support']}")

In [None]:
# =============================================================================
# Visualize Confusion Matrices
# =============================================================================

import os

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# One panel per model: (axis, metrics dict, title)
panels = [
    (axes[0], frozen_metrics, 'Frozen CDBN + Classifier\nConfusion Matrix'),
    (axes[1], e2e_metrics, 'Fine-Tuned End-to-End CDBN\nConfusion Matrix'),
]

for ax, metrics, title in panels:
    cm = np.asarray(metrics['confusion_matrix'])
    im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)

    # Annotate each cell with its count; white text on dark cells for contrast
    threshold = cm.max() / 2
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            ax.text(j, i, cm[i, j], ha="center", va="center",
                    color="white" if cm[i, j] > threshold else "black")

    ax.set_xticks(range(len(class_names)))
    ax.set_yticks(range(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_yticklabels(class_names)
    fig.colorbar(im, ax=ax)

plt.tight_layout()
plt.savefig(os.path.join(Config.OUTPUT_DIR, 'confusion_matrices.png'),
            dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# =============================================================================
# Model Comparison Summary
# =============================================================================

print("=" * 70)
print("MODEL COMPARISON SUMMARY")
print("=" * 70)
print(f"\n{'Model':<35} {'Test Accuracy':<15}")
print("-" * 50)
print(f"{'Frozen CDBN + Classifier':<35} {frozen_metrics['overall_accuracy']:.4f}")
print(f"{'Fine-Tuned End-to-End CDBN':<35} {e2e_metrics['overall_accuracy']:.4f}")
print("-" * 50)

# Determine best model
if e2e_metrics['overall_accuracy'] > frozen_metrics['overall_accuracy']:
    improvement = e2e_metrics['overall_accuracy'] - frozen_metrics['overall_accuracy']
    # Accuracies are fractions, so (difference * 100) is in percentage points
    print(f"\n✓ Fine-tuning improved accuracy by {improvement:.4f} "
          f"({improvement*100:.2f} percentage points)")
    best_model = "Fine-Tuned End-to-End CDBN"
    best_acc = e2e_metrics['overall_accuracy']
else:
    # Also covers ties: fine-tuning gave no accuracy gain
    best_model = "Frozen CDBN + Classifier"
    best_acc = frozen_metrics['overall_accuracy']
    print("\n✓ Frozen CDBN performs at least as well (fine-tuning gave no accuracy gain)")

print(f"\n🏆 Best Model: {best_model}")
print(f"   Test Accuracy: {best_acc:.4f}")
print("=" * 70)

# Per-class comparison
print("\nPER-CLASS ACCURACY COMPARISON:")
print("-" * 70)
print(f"{'Class':<20} {'Frozen CDBN':<15} {'Fine-Tuned':<15} {'Difference':<15}")
print("-" * 70)
for name in class_names:
    frozen_acc = frozen_metrics['per_class_accuracy'].get(name, 0)
    e2e_acc = e2e_metrics['per_class_accuracy'].get(name, 0)
    diff = e2e_acc - frozen_acc
    # The '+' format flag always shows the sign, e.g. +0.0120 / -0.0050
    print(f"{name:<20} {frozen_acc:<15.4f} {e2e_acc:<15.4f} {diff:+.4f}")
print("-" * 70)

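Because both accuracies are fractions, their difference multiplied by 100 is an absolute change in percentage points, not a relative percentage gain. A minimal sketch distinguishing the two; the helper name `accuracy_change` is hypothetical and the numbers are illustrative, not results from this notebook:

```python
def accuracy_change(baseline_acc, new_acc):
    """Return (absolute change in percentage points, relative change in %)."""
    if baseline_acc <= 0:
        raise ValueError("baseline accuracy must be positive")
    absolute_pp = (new_acc - baseline_acc) * 100
    relative_pct = (new_acc - baseline_acc) / baseline_acc * 100
    return absolute_pp, relative_pct


# Illustrative numbers only
pp, rel = accuracy_change(0.80, 0.84)
print(f"{pp:+.2f} percentage points ({rel:+.2f}% relative)")
```

Going from 0.80 to 0.84 is a 4-point absolute gain but a 5% relative gain, which is why the summary cell reports the difference in percentage points.
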
In [None]:
# =============================================================================
# Save Models and Results
# =============================================================================

# Save model checkpoints
torch.save({
    'classifier_state_dict': classifier.state_dict(),
    'classifier_config': {
        'n_latent': classifier.n_latent,
        'n_classes': classifier.n_classes
    },
    'training_history': frozen_history,
    'test_metrics': frozen_metrics
}, os.path.join(Config.OUTPUT_DIR, 'frozen_cdbn_classifier.pt'))

torch.save({
    'model_state_dict': e2e_model.state_dict(),
    'training_history': finetune_history,
    'test_metrics': e2e_metrics
}, os.path.join(Config.OUTPUT_DIR, 'finetuned_e2e_cdbn.pt'))

# Save results summary to text file
with open(os.path.join(Config.OUTPUT_DIR, 'evaluation_results.txt'), 'w') as f:
    f.write("=" * 70 + "\n")
    f.write("CDBN CLASSIFICATION RESULTS\n")
    f.write("=" * 70 + "\n\n")
    
    f.write("FROZEN CDBN + CLASSIFIER\n")
    f.write("-" * 40 + "\n")
    f.write(f"Test Accuracy: {frozen_metrics['overall_accuracy']:.4f}\n\n")
    f.write("Per-Class Accuracy:\n")
    for name, acc in frozen_metrics['per_class_accuracy'].items():
        f.write(f"  {name}: {acc:.4f}\n")
    
    f.write("\n\nFINE-TUNED END-TO-END CDBN\n")
    f.write("-" * 40 + "\n")
    f.write(f"Test Accuracy: {e2e_metrics['overall_accuracy']:.4f}\n\n")
    f.write("Per-Class Accuracy:\n")
    for name, acc in e2e_metrics['per_class_accuracy'].items():
        f.write(f"  {name}: {acc:.4f}\n")

print(f"✓ Models and results saved to {Config.OUTPUT_DIR}:")
print("  - frozen_cdbn_classifier.pt")
print("  - finetuned_e2e_cdbn.pt")
print("  - evaluation_results.txt")

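The checkpoints bundle weights together with Python objects such as the training histories and metric dicts. The confusion matrix is indexed as `[i, j]`, so it is presumably a NumPy array. Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which can refuse such objects. Reloading a checkpoint you trust may therefore need `weights_only=False`. The sketch below round-trips a checkpoint with the same key layout as `frozen_cdbn_classifier.pt`. `LinearClassifier`, `load_classifier` and the 4-class default are hypothetical stand-ins, since the notebook's classifier class is defined elsewhere.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


class LinearClassifier(nn.Module):
    """Hypothetical stand-in for the notebook's classifier (n_latent -> n_classes)."""

    def __init__(self, n_latent=256, n_classes=4):
        super().__init__()
        self.n_latent = n_latent
        self.n_classes = n_classes
        self.fc = nn.Linear(n_latent, n_classes)

    def forward(self, x):
        return self.fc(x)


def load_classifier(path, device="cpu"):
    # weights_only=False unpickles arbitrary objects: only use it on trusted files
    ckpt = torch.load(path, map_location=device, weights_only=False)
    cfg = ckpt["classifier_config"]
    model = LinearClassifier(cfg["n_latent"], cfg["n_classes"])
    model.load_state_dict(ckpt["classifier_state_dict"])
    return model.to(device).eval(), ckpt["test_metrics"]


# Round trip using the same keys as frozen_cdbn_classifier.pt
original = LinearClassifier()
path = os.path.join(tempfile.mkdtemp(), "frozen_cdbn_classifier.pt")
torch.save({
    "classifier_state_dict": original.state_dict(),
    "classifier_config": {"n_latent": original.n_latent,
                          "n_classes": original.n_classes},
    "test_metrics": {"confusion_matrix": np.eye(4, dtype=int)},
}, path)

restored, metrics = load_classifier(path)
```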
---

# ✅ CDBN Supervised Training and Evaluation Completed

## Summary

The complete CDBN pipeline for Eye OCT Classification has been implemented and evaluated.

### Architecture

| Component | Description |
|-----------|-------------|
| Conv-RBM-1 | 32 filters, 7×7, Gaussian-Bernoulli |
| Pooling | 2×2 probabilistic pooling |
| Conv-RBM-2 | 64 filters, 5×5, Bernoulli-Bernoulli |
| FC-RBM | 207,936 → 256 units |
| Classifier | 256 → n_classes linear layer |

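The FC-RBM input size of 207,936 equals 64 × 57 × 57, that is, 64 Conv-RBM-2 feature maps of 57 × 57. One input resolution consistent with that figure is 128 × 128. This assumes stride-1 "valid" (unpadded) convolutions and a single 2 × 2 pooling step between the two Conv-RBMs, neither of which is stated in this summary. A sketch of the shape arithmetic under those assumptions (function names are illustrative):

```python
def valid_conv(size, kernel):
    """Spatial size after a stride-1 convolution with no padding."""
    return size - kernel + 1


def fc_input_dim(input_size, k1=7, pool=2, k2=5, n_filters2=64):
    """Flattened Conv-RBM-2 output size for a square input image.

    Assumes: valid 7x7 conv -> 2x2 pooling -> valid 5x5 conv, no pooling after layer 2.
    """
    s = valid_conv(input_size, k1)   # Conv-RBM-1 hidden maps
    s = s // pool                    # 2x2 probabilistic pooling
    s = valid_conv(s, k2)            # Conv-RBM-2 hidden maps
    return n_filters2 * s * s


print(fc_input_dim(128))  # 128 -> 122 -> 61 -> 57; 64 * 57 * 57 = 207936
```
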
### Training Phases

| Phase | Description | Optimizer |
|-------|-------------|-----------|
| **Unsupervised Pretraining** | Layer-wise contrastive divergence (CD) training | Manual CD updates (no autograd optimizer) |
| **Feature-based Training** | Train classifier on frozen CDBN | Adam (lr=1e-3) |
| **End-to-End Fine-tuning** | Backprop through entire network | Adam (lr=1e-4) |

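The layer-wise CD row means the RBM parameters are updated with the contrastive-divergence rule instead of a gradient-based optimizer. As an illustration only, a CD-1 step for a fully connected Bernoulli-Bernoulli RBM (comparable to the FC-RBM) looks roughly like this in NumPy. The notebook's own trainers work on PyTorch tensors, and its Conv-RBMs apply the same idea with convolutions. Momentum (the velocity tensors mentioned below), weight decay and sparsity terms are omitted here.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cd1_step(v0, W, b_vis, b_hid, lr=0.01, rng=None):
    """One CD-1 update for a Bernoulli-Bernoulli RBM, in place on W and biases.

    v0: (batch, n_visible) data in [0, 1]; W: (n_visible, n_hidden).
    Returns the mean squared reconstruction error for monitoring.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Positive phase: hidden probabilities given the data, then a binary sample
    h0_prob = sigmoid(v0 @ W + b_hid)
    h0_sample = (rng.random(h0_prob.shape) < h0_prob).astype(v0.dtype)
    # Negative phase: one Gibbs step (reconstruct visibles, then hiddens)
    v1_prob = sigmoid(h0_sample @ W.T + b_vis)
    h1_prob = sigmoid(v1_prob @ W + b_hid)
    # CD gradient estimate: <v h>_data - <v h>_reconstruction, batch-averaged
    batch = v0.shape[0]
    W += lr * (v0.T @ h0_prob - v1_prob.T @ h1_prob) / batch
    b_vis += lr * (v0 - v1_prob).mean(axis=0)
    b_hid += lr * (h0_prob - h1_prob).mean(axis=0)
    return float(((v0 - v1_prob) ** 2).mean())


rng = np.random.default_rng(0)
data = (rng.random((64, 20)) < 0.3).astype(float)   # toy binary data
W = 0.01 * rng.standard_normal((20, 8))
b_vis, b_hid = np.zeros(20), np.zeros(8)
errors = [cd1_step(data, W, b_vis, b_hid, lr=0.1, rng=rng) for _ in range(50)]
```
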
### Key Results

| Model | Test Accuracy |
|-------|---------------|
| Frozen CDBN + Classifier | See output above |
| Fine-tuned End-to-End | See output above |

### Saved Artifacts

- `frozen_cdbn_classifier.pt` - Classifier weights and metrics
- `finetuned_e2e_cdbn.pt` - Full model weights and metrics
- `evaluation_results.txt` - Text summary
- `confusion_matrices.png` and other training plots (PNG files)

### Key Observations

1. **Unsupervised pretraining** provides a good initialization for classification
2. **Frozen CDBN** approach is fast and often sufficient for smaller datasets
3. **Fine-tuning** can improve performance but requires careful regularization
4. **Per-class metrics** reveal class-specific challenges

---

# 🚀 Notebook Hardened for Kaggle: Ready for Full Execution

## Kaggle-Specific Modifications Applied

### ✅ STEP A: GPU Detection & CUDA Settings
- GPU assertion at startup (fails fast if no GPU)
- `torch.backends.cudnn.benchmark = True` for auto-tuning
- `torch.backends.cudnn.deterministic = False` for performance

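STEP A can be sketched as a small setup function. The name `configure_cuda` and the `require_gpu` flag are hypothetical. `torch.cuda.is_available`, `torch.cuda.get_device_name` and the two `torch.backends.cudnn` flags are standard PyTorch. This sketch raises `RuntimeError` instead of using `assert`, because assertions are skipped when Python runs with `-O`. With `benchmark=True`, cuDNN picks the fastest convolution algorithms for the input shapes it sees, at the cost of bit-for-bit reproducibility between runs.

```python
import torch


def configure_cuda(require_gpu=True):
    """Fail fast without a GPU, then favour speed over determinism in cuDNN."""
    if not torch.cuda.is_available():
        if require_gpu:
            raise RuntimeError("No CUDA GPU detected: enable a GPU accelerator first.")
        return torch.device("cpu")
    # Auto-tune convolution algorithms for the fixed input sizes used in training
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    return torch.device("cuda")


device = configure_cuda(require_gpu=False)  # on Kaggle, keep the default True
```
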
### ✅ STEP B: Dataset Paths & Sanity Checks
- Dataset path: `/kaggle/input/<YOUR_DATASET_NAME>/OCT`
- Pre-training validation of GPU, paths, and data shapes
- RuntimeError if critical checks fail

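A minimal version of the STEP B path checks might look like the sketch below. It assumes the dataset root contains one subdirectory per split, such as `train/` and `test/`. That layout is a hypothetical example, so adjust `required_splits` to match the real one. The function collects every failure before raising, so a single error message lists all the problems. It would be called as `preflight_dataset_check("/kaggle/input/<YOUR_DATASET_NAME>/OCT")` once the placeholder is filled in.

```python
from pathlib import Path


def preflight_dataset_check(root, required_splits=("train", "test")):
    """Raise RuntimeError if the dataset root or any required split is missing or empty."""
    root = Path(root)
    problems = []
    if not root.is_dir():
        problems.append(f"dataset root not found: {root}")
    else:
        for split in required_splits:
            split_dir = root / split
            if not split_dir.is_dir():
                problems.append(f"missing split directory: {split_dir}")
            elif not any(split_dir.iterdir()):
                problems.append(f"empty split directory: {split_dir}")
    if problems:
        raise RuntimeError("Pre-flight check failed:\n  " + "\n  ".join(problems))
    return root
```
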
### ✅ STEP C: Conservative Epochs & Batch Size
| Component | Original | Kaggle | Comment |
|-----------|----------|--------|---------|
| Batch Size | 32 | **16** | Fits 16 GB T4/P100 |
| NUM_WORKERS | 4 | **2** | Kaggle CPU cores |
| Conv-RBM-1 | 10 epochs | **5** | Time limit |
| Conv-RBM-2 | 10 epochs | **5** | Time limit |
| FC-RBM | 10 epochs | **5** | Time limit |
| Classifier | 30 epochs | **20** | Time limit |
| Fine-tuning | 20 epochs | **10** | Time limit |

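The original and Kaggle settings in the table above can be kept side by side with a frozen dataclass and `dataclasses.replace`. The class `TrainSettings` and its field names are hypothetical, since the notebook's `Config` class may name them differently. The values are the ones listed in the table.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainSettings:
    batch_size: int = 32
    num_workers: int = 4
    conv_rbm1_epochs: int = 10
    conv_rbm2_epochs: int = 10
    fc_rbm_epochs: int = 10
    classifier_epochs: int = 30
    finetune_epochs: int = 20


ORIGINAL = TrainSettings()
# Conservative overrides for a 16 GB T4/P100 and the Kaggle session time limit
KAGGLE = replace(ORIGINAL, batch_size=16, num_workers=2,
                 conv_rbm1_epochs=5, conv_rbm2_epochs=5, fc_rbm_epochs=5,
                 classifier_epochs=20, finetune_epochs=10)
```

A frozen dataclass prevents later cells from silently mutating the settings, and `replace` returns a new object, so `ORIGINAL` stays untouched.
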
### ✅ STEP D: Visualization Guards
- `DEBUG = False` disables heavy visualizations
- Set `DEBUG = True` for local debugging only

### ✅ STEP E: Memory Cleanup
- `torch.cuda.empty_cache()` after each RBM training
- Trainer objects deleted to free velocity tensors

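The STEP E pattern is sketched below with a hypothetical helper and a dummy trainer. Deleting the last reference to a trainer lets Python free its velocity (momentum) tensors. `torch.cuda.empty_cache()` then returns the now-unused blocks cached by PyTorch's allocator to the driver. `empty_cache()` cannot free tensors that are still referenced, so the `del` has to come first.

```python
import gc

import torch


def release_gpu_memory():
    """Collect unreachable objects, then release PyTorch's cached CUDA blocks."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class DummyTrainer:
    """Hypothetical stand-in for an RBM trainer holding momentum buffers."""

    def __init__(self, shape, device):
        self.velocity_W = torch.zeros(shape, device=device)


device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = DummyTrainer((1024, 1024), device)
# ... train the RBM ...
del trainer            # drop the last reference so the velocity tensor can be freed
release_gpu_memory()
```
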
### ✅ STEP F: Output Directory
- All outputs saved to `/kaggle/working`
- Subdirectories: `models/`, `plots/`, `logs/`

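Creating the STEP F layout can be sketched with `pathlib`. The helper `make_output_dirs` is hypothetical. Because `mkdir(parents=True, exist_ok=True)` is idempotent, re-running the cell after a restart is safe. A call such as `dirs = make_output_dirs()` followed by `dirs["models"] / "finetuned_e2e_cdbn.pt"` would then give a path for a checkpoint.

```python
from pathlib import Path


def make_output_dirs(root="/kaggle/working", subdirs=("models", "plots", "logs")):
    """Create the output root and its standard subdirectories; return their paths."""
    root = Path(root)
    paths = {name: root / name for name in subdirs}
    for path in paths.values():
        path.mkdir(parents=True, exist_ok=True)   # safe to call on every run
    return paths
```
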
### ✅ STEP G: Sanity Summary
- Pre-flight check cell validates environment before training

---

## Before Running on Kaggle

1. **Upload your dataset** and update `<YOUR_DATASET_NAME>` in Config
2. **Enable GPU** in Kaggle notebook settings
3. **Run All** cells sequentially
4. **Download** results from `/kaggle/working`

---

**Expected Runtime:** ~6-8 hours on a Kaggle T4 GPU (within the 9-hour session limit)

---

# ✅ Local GPU + 5.5 GB Dataset Optimized: Ready for Execution

## Optimization Summary

This notebook has been optimized for:

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| **GPU Memory** | 24 GB | Up to 95% of VRAM usable (per-process memory cap) |
| **Dataset Size** | ~5.5 GB | Streaming DataLoaders, no full RAM caching |
| **RBM Batch Size** | 32 | Conv-RBM training, roughly 8-12 GB of VRAM |
| **FC-RBM Batch Size** | 64 | Lighter memory footprint than the Conv-RBMs |
| **Classifier Batch Size** | 128 | Inputs are 256-d latents, so per-sample memory is small |
| **Workers** | 4 | Persistent workers with prefetch |
| **Feature Caching** | Disk-backed | `./outputs_local_5gb/latent_cache/` |

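A disk-backed feature cache can be sketched with NumPy's `open_memmap`. This function creates a `.npy` file on disk and exposes it as an array that is filled batch by batch, so only one batch of features is held in RAM at a time. The function `cache_features`, the `extract_fn` callback and the file layout are assumptions for illustration, and the notebook's own caching code may use a different format. The sketch writes to a temporary file first and renames it only after every sample has been written, so an interrupted run never leaves behind a partial cache that would be reused.

```python
import os

import numpy as np
from numpy.lib.format import open_memmap


def cache_features(batches, n_samples, feature_dim, path, extract_fn):
    """Write extracted features to a .npy file batch by batch; reuse it if present.

    batches: iterable of input arrays; extract_fn maps a batch to (b, feature_dim).
    """
    if os.path.exists(path):
        cached = np.load(path, mmap_mode="r")
        if cached.shape == (n_samples, feature_dim):
            return cached                      # reuse across runs
    tmp_path = path + ".tmp.npy"
    out = open_memmap(tmp_path, mode="w+", dtype=np.float32,
                      shape=(n_samples, feature_dim))
    start = 0
    for batch in batches:
        feats = np.asarray(extract_fn(batch), dtype=np.float32)
        out[start:start + len(feats)] = feats
        start += len(feats)
    if start != n_samples:
        raise ValueError(f"expected {n_samples} samples, wrote {start}")
    out.flush()
    del out                                    # close the memmap before renaming
    os.replace(tmp_path, path)
    return np.load(path, mmap_mode="r")
```
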
## Key Optimizations Applied

1. **Streaming-Safe DataLoaders**
   - `persistent_workers=True` 
   - `prefetch_factor=2`
   - `pin_memory=True`

2. **Memory Management**
   - Periodic `torch.cuda.empty_cache()` every 100 batches
   - Explicit tensor deletion after each batch
   - GPU memory monitoring with `print_gpu_memory()`

3. **Disk-Backed Feature Caching**
   - Latent features saved to disk during extraction
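   - Avoids holding all extracted features in RAM (their size depends on the cached layer's dimensionality and the number of samples, not on the 5.5 GB image dataset)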
   - Reusable across runs

4. **Pre-Flight Validation**
   - Memory estimation before training starts
   - Warns if memory may exceed capacity

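The DataLoader settings from item 1 can be sketched as follows. `make_loader` is a hypothetical helper. PyTorch rejects `persistent_workers=True` when `num_workers=0`, and recent versions also reject an explicit `prefetch_factor` in that case, so the sketch passes those options only when worker processes are used. `pin_memory` only helps when batches are copied to a CUDA device. Pass the stage-specific batch size (32, 64 or 128 in the table above) for each training phase.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size, shuffle, num_workers=4, prefetch_factor=2):
    """Streaming-friendly DataLoader: samples are read lazily by the dataset."""
    kwargs = dict(batch_size=batch_size, shuffle=shuffle,
                  pin_memory=torch.cuda.is_available())
    if num_workers > 0:
        # Keep worker processes alive between epochs and queue batches ahead
        kwargs.update(num_workers=num_workers, persistent_workers=True,
                      prefetch_factor=prefetch_factor)
    return DataLoader(dataset, **kwargs)


# Usage on a toy in-memory dataset (num_workers=0 keeps the example portable)
toy = TensorDataset(torch.randn(10, 1, 8, 8), torch.arange(10))
loader = make_loader(toy, batch_size=4, shuffle=False, num_workers=0)
```
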
## Execution Checklist

- [ ] Verify GPU is detected (run cell 4)
- [ ] Check pre-flight memory estimation
- [ ] Confirm dataset path in Config class
- [ ] Run all cells sequentially

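A pre-flight memory estimate is mostly arithmetic on tensor sizes. The rough sketch below uses float32 (4 bytes per element). It assumes a 128 × 128 input, which is consistent with the 207,936-unit FC-RBM but is not stated explicitly. It sizes the Conv-RBM-1 hidden activations for one batch of 32 and the FC-RBM weight matrix. Actual peak usage is higher, because CD training also keeps reconstructions, negative-phase activations and momentum buffers, so treat these figures as lower bounds. Compare the total against the 95% cap (about 0.95 × 24 GB ≈ 22.8 GB).

```python
BYTES_FLOAT32 = 4
MIB = 1024 ** 2


def tensor_mib(*shape, bytes_per_elem=BYTES_FLOAT32):
    """Size in MiB of a dense tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem / MIB


# Conv-RBM-1 hidden probabilities: batch 32, 32 filters, 122x122 maps (128 - 7 + 1)
conv1_hidden = tensor_mib(32, 32, 122, 122)
# FC-RBM weight matrix: 207,936 visible x 256 hidden units
fc_weights = tensor_mib(207_936, 256)
print(f"Conv-RBM-1 hidden activations: {conv1_hidden:.1f} MiB")
print(f"FC-RBM weights:                {fc_weights:.1f} MiB")
```
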
## Expected Runtime

| Stage | Estimated Time |
|-------|----------------|
| Conv-RBM-1 (5 epochs) | ~15-30 min |
| Conv-RBM-2 (3 epochs) | ~10-20 min |
| FC-RBM (3 epochs) | ~5-10 min |
| Classifier (30 epochs) | ~5-10 min |
| **Total** | **~35-70 min** |

---

**Ready to run!** Execute cells from top to bottom.