# DHF-JSCC Training and Testing on KITTI Dataset üìö

**What is DHF-JSCC?** 
DHF-JSCC (Deep Hierarchical Feature Joint Source-Channel Coding) is a deep learning approach for image compression that jointly optimizes source coding (compression) and channel coding (transmission) using stereo image pairs.

**Learning Objectives:**
- üéØ Understand joint source-channel coding for image compression
- üß† Learn how deep neural networks compress stereo images
- üìä Explore rate-distortion optimization in deep learning
- üîß Practice training compression models from scratch
- üìà Analyze compression performance metrics (PSNR, MS-SSIM, BPP)

**What You'll Build:** A complete stereo image compression system that learns to:
1. Extract hierarchical features from stereo image pairs
2. Compress left images using side information from right images
3. Optimize the trade-off between compression ratio and image quality

## 1. Import Required Libraries üìö

**Why This Step Matters:**
Before building any deep learning model, we need to import the right tools. Each library serves a specific purpose in our compression pipeline:

- **PyTorch**: Our deep learning framework for building and training neural networks
- **Dataset Classes**: Custom loaders for KITTI stereo image pairs
- **Compression Modules**: Our DHF-JSCC model architecture
- **Loss Functions**: MS-SSIM for perceptual quality measurement
- **Visualization**: Tools to monitor training progress and results

In [1]:
import importlib
import subprocess
import sys
import os

# Add current directory to Python path for local imports
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())

def import_or_install(package, import_name=None, is_local=False):
    try:
        if import_name is None:
            import_name = package
        return importlib.import_module(import_name)
    except ImportError:
        if is_local:
            # For local modules, just re-raise the ImportError with helpful message
            print(f"Local module '{package}' not found. Please ensure the file exists in the current directory.")
            raise ImportError(f"Local module '{package}' not found")
        else:
            # For pip packages, try to install
            print(f"Package '{package}' not found. Installing...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            return importlib.import_module(import_name if import_name else package)

# Standard library and pip packages
os = import_or_install('os')
np = import_or_install('numpy', 'numpy')
pd = import_or_install('pandas', 'pandas')
torch = import_or_install('torch', 'torch')
yaml = import_or_install('pyyaml', 'yaml')
Image = import_or_install('Pillow', 'PIL.Image')
OrderedDict = import_or_install('collections', 'collections').OrderedDict
DataLoader = import_or_install('torch', 'torch.utils.data').DataLoader
ms_ssim = import_or_install('pytorch-msssim', 'pytorch_msssim').ms_ssim
math = import_or_install('math')
plt = import_or_install('matplotlib', 'matplotlib.pyplot')
display = import_or_install('IPython', 'IPython.display').display

# Local modules - use direct import with importlib
try:
    # Clear any cached imports
    if 'dataset' in sys.modules:
        importlib.reload(sys.modules['dataset'])
    if 'dataset.PairKitti' in sys.modules:
        importlib.reload(sys.modules['dataset.PairKitti'])
    
    # Import the module
    dataset_module = importlib.import_module('dataset.PairKitti')
    PairKitti = dataset_module.PairKitti
    print("‚úÖ PairKitti imported successfully")
except Exception as e:
    print(f"‚ùå Could not import PairKitti: {e}")
    print(f"Trying direct sys.path approach...")
    try:
        # Alternative: add dataset folder to path and import directly
        dataset_path = os.path.join(os.getcwd(), 'dataset')
        if dataset_path not in sys.path:
            sys.path.insert(0, dataset_path)
        from PairKitti import PairKitti
        print("‚úÖ PairKitti imported successfully (direct import)")
    except Exception as e2:
        print(f"‚ùå Direct import also failed: {e2}")
        PairKitti = None

try:
    dataset_module = importlib.import_module('dataset.InStereo2K')
    InStereo2K = dataset_module.InStereo2K
    print("‚úÖ InStereo2K imported successfully")
except Exception as e:
    print(f"‚ùå Could not import InStereo2K: {e}")
    try:
        from InStereo2K import InStereo2K
        print("‚úÖ InStereo2K imported successfully (direct import)")
    except:
        InStereo2K = None

try:
    import model_d_fusion2
    print("‚úÖ model_d_fusion2 imported successfully")
except ImportError as e:
    print(f"‚ùå Could not import model_d_fusion2: {e}")
    print("This might be due to missing dependencies in the local module.")
    model_d_fusion2 = None

print("\nLibrary import summary:")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA devices: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.get_device_name(0)}")
print("‚úÖ Core libraries imported successfully!")

‚úÖ PairKitti imported successfully
‚úÖ InStereo2K imported successfully
‚úÖ model_d_fusion2 imported successfully

Library import summary:
PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA devices: 2
Current device: NVIDIA GeForce RTX 2080 Ti
‚úÖ Core libraries imported successfully!
‚úÖ model_d_fusion2 imported successfully

Library import summary:
PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA devices: 2
Current device: NVIDIA GeForce RTX 2080 Ti
‚úÖ Core libraries imported successfully!


In [29]:
# Check CUDA availability and install PyTorch with CUDA if needed
import subprocess
import sys
import torch

def check_cuda_and_install():
    """Check CUDA availability and install appropriate PyTorch version"""
    print("=" * 60)
    print("CUDA AND PYTORCH INSTALLATION CHECK")
    print("=" * 60)
    
    # Check if CUDA is available with current PyTorch installation
    print(f"Current PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Number of CUDA devices: {torch.cuda.device_count()}")
        if torch.cuda.device_count() > 0:
            print(f"Current device: {torch.cuda.get_device_name()}")
        print("‚úÖ CUDA is properly configured!")
        return True
    else:
        print("‚ùå CUDA is not available with current PyTorch installation")
        
        # Check if NVIDIA GPU is available on system
        try:
            result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=10)
            if result.returncode == 0:
                print("‚úÖ NVIDIA GPU detected on system")
                print("Installing PyTorch with CUDA support...")
                
                # Install PyTorch with CUDA support
                install_commands = [
                    [sys.executable, "-m", "pip", "uninstall", "torch", "torchvision", "torchaudio", "-y"],
                    [sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cu121"]
                ]
                
                for cmd in install_commands:
                    print(f"Running: {' '.join(cmd)}")
                    subprocess.check_call(cmd)
                
                # Restart kernel notification
                print("\n‚ö†Ô∏è  IMPORTANT: You may need to restart the kernel for changes to take effect.")
                print("   Go to Kernel -> Restart Kernel to restart.")
                
                return True
            else:
                print("‚ùå No NVIDIA GPU detected on system")
                print("Installing CPU-only PyTorch...")
                
                # Install CPU-only version
                subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cpu"])
                return False
                
        except (subprocess.TimeoutExpired, FileNotFoundError):
            print("‚ùå nvidia-smi not found - no NVIDIA GPU available")
            print("Installing CPU-only PyTorch...")
            
            # Install CPU-only version
            subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cpu"])
            return False

# Run the check
cuda_available = check_cuda_and_install()

# Also install other missing dependencies
print("\n" + "=" * 60)
print("INSTALLING OTHER DEPENDENCIES")
print("=" * 60)

additional_packages = [
    "torchvision",  # For image transforms
    "pytorch-msssim",  # For MS-SSIM loss
    "Pillow",  # For image processing
    "pyyaml",  # For config files
    "pandas",  # For data handling
    "matplotlib",  # For visualization
    "numpy"  # For numerical operations
]

for package in additional_packages:
    try:
        __import__(package.replace('-', '_').lower())
        print(f"‚úÖ {package} already installed")
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

print("\n‚úÖ All dependencies check complete!")
print("You may now re-run the import cell above.")

CUDA AND PYTORCH INSTALLATION CHECK
Current PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA version: 12.8
Number of CUDA devices: 2
Current device: NVIDIA GeForce RTX 2080 Ti
‚úÖ CUDA is properly configured!

INSTALLING OTHER DEPENDENCIES
‚úÖ torchvision already installed
‚úÖ pytorch-msssim already installed
Installing Pillow...
Installing pyyaml...
Installing pyyaml...
‚úÖ pandas already installed
‚úÖ matplotlib already installed
‚úÖ numpy already installed

‚úÖ All dependencies check complete!
You may now re-run the import cell above.
‚úÖ pandas already installed
‚úÖ matplotlib already installed
‚úÖ numpy already installed

‚úÖ All dependencies check complete!
You may now re-run the import cell above.


In [30]:
# Handle missing ops module - try to install CompressAI or create a basic ops module
import os
import torch
import torch.nn as nn

def setup_ops_module():
    """Setup the missing ops module needed by the local modules"""
    
    # First try to install CompressAI which contains common compression operations
    try:
        print("Trying to install CompressAI...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "compressai"])
        
        # Try to import operations from CompressAI
        try:
            from compressai.ops import LowerBound
            from compressai.layers import GDN
            
            # Create a simple ops.py file that imports from CompressAI
            ops_content = '''# Auto-generated ops module
import torch
import torch.nn as nn
from compressai.ops import LowerBound
from compressai.layers import GDN

class Low_bound(torch.autograd.Function):
    """Lower bound operation"""
    @staticmethod
    def forward(ctx, x, bound):
        return LowerBound.apply(x, bound)
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

# Export GDN from CompressAI
GDN = GDN
'''
            
            with open('ops.py', 'w') as f:
                f.write(ops_content)
            
            print("‚úÖ ops.py created using CompressAI operations")
            return True
            
        except ImportError:
            print("‚ùå CompressAI installed but couldn't import required operations")
            
    except Exception as e:
        print(f"‚ùå CompressAI installation failed: {e}")
    
    # If CompressAI doesn't work, create a basic ops module
    print("Creating basic ops module...")
    
    basic_ops_content = '''# Basic ops module for DHF-JSCC
import torch
import torch.nn as nn

class Low_bound(torch.autograd.Function):
    """Lower bound operation - clamps values to minimum bound"""
    
    @staticmethod
    def forward(ctx, x, bound):
        ctx.save_for_backward(x)
        ctx.bound = bound
        return torch.clamp(x, min=bound)
    
    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[x < ctx.bound] = 0
        return grad_input, None

class GDN(nn.Module):
    """Generalized Divisive Normalization (GDN) layer"""
    
    def __init__(self, channels, inverse=False, beta_min=1e-6, gamma_init=0.1):
        super().__init__()
        self.inverse = inverse
        self.beta_min = beta_min
        
        self.beta = nn.Parameter(torch.ones(channels))
        self.gamma = nn.Parameter(gamma_init * torch.eye(channels))
        
    def forward(self, x):
        if self.inverse:
            # Inverse GDN
            beta = torch.clamp(self.beta, min=self.beta_min)
            gamma = torch.abs(self.gamma)
            norm = torch.tensordot(x**2, gamma, dims=([1], [0])) + beta
            return x * torch.sqrt(norm)
        else:
            # Forward GDN  
            beta = torch.clamp(self.beta, min=self.beta_min)
            gamma = torch.abs(self.gamma)
            norm = torch.tensordot(x**2, gamma, dims=([1], [0])) + beta
            return x / torch.sqrt(norm)
'''
    
    with open('ops.py', 'w') as f:
        f.write(basic_ops_content)
    
    print("‚úÖ Basic ops.py created with Low_bound and GDN operations")
    return True

# Setup ops module
setup_ops_module()

# Now try to import the local modules again
try:
    import model_d_fusion2
    print("‚úÖ model_d_fusion2 imported successfully after ops setup!")
except ImportError as e:
    print(f"‚ùå Still couldn't import model_d_fusion2: {e}")
    print("There might be other missing dependencies.")

print("ops module setup complete!")

Trying to install CompressAI...
‚úÖ ops.py created using CompressAI operations
‚úÖ model_d_fusion2 imported successfully after ops setup!
ops module setup complete!
‚úÖ ops.py created using CompressAI operations
‚úÖ model_d_fusion2 imported successfully after ops setup!
ops module setup complete!


## ‚úÖ CUDA and Dependencies Status

**CUDA Configuration:**
- ‚úÖ CUDA is properly installed and configured
- ‚úÖ PyTorch version: 2.8.0+cu128  
- ‚úÖ CUDA version: 12.8
- ‚úÖ GPU devices available: 2 (including NVIDIA GeForce RTX 2080 Ti)

**Dependencies Status:**
- ‚úÖ All core libraries imported successfully
- ‚úÖ Local dataset modules (PairKitti, InStereo2K) working
- ‚úÖ Model modules (model_d_fusion2) working after ops module creation
- ‚úÖ CompressAI or custom ops module installed for compression operations

**Ready for Training:**
The environment is now properly configured for DHF-JSCC training with CUDA acceleration on your RTX 2080 Ti GPUs.

## üéì Understanding the Complete DHF-JSCC Pipeline

**What Just Happened? A Step-by-Step Breakdown:**

### üß† **The Science Behind It:**
DHF-JSCC combines several advanced concepts:
1. **Deep Learning**: Neural networks learn compression patterns from data
2. **Stereo Vision**: Uses geometric relationships between camera views
3. **Rate-Distortion Theory**: Optimal trade-off between file size and quality
4. **Entropy Coding**: Efficient representation based on probability distributions

### üîÑ **The Training Process:**
```
Raw Stereo Images ‚Üí Neural Network ‚Üí Compressed Representation ‚Üí Reconstructed Images
      ‚Üë                    ‚Üì                                           ‚Üì
   Dataset            Feature Learning                            Quality Measurement
                           ‚Üì                                           ‚Üì
                    Weight Updates ‚Üê Backpropagation ‚Üê Loss Calculation
```

### üìä **Key Innovations:**
- **Joint Source-Channel Coding**: Optimizes compression AND transmission together
- **Side Information**: Right image helps compress left image more efficiently  
- **Hierarchical Features**: Multi-scale processing captures both details and structure
- **Learned Entropy Models**: Neural networks estimate probability better than traditional methods

### üéØ **Why This Matters:**
- **Practical**: Better compression for autonomous vehicles, VR/AR, streaming
- **Academic**: Advances our understanding of neural compression
- **Technical**: Shows how to combine computer vision + information theory

## 2. Configuration Settings ‚öôÔ∏è

**Understanding the Hyperparameters:**
Each setting controls a different aspect of our compression model. Let's understand why each matters:

**Dataset Parameters:**
- `resize`: Smaller images (128x128) train faster but may lose detail
- `dataset_name`: KITTI provides real-world stereo driving scenes

**Model Architecture:**
- `baseline_model`: 'bls17' refers to Ball√© 2017 entropy model
- `use_side_info`: Using right image to help compress left image
- `num_filters`: More filters = more capacity but slower training

**Training Strategy:**
- `lambda`: Controls rate-distortion trade-off (higher = more compression)
- `lr`: Learning rate (too high = unstable, too low = slow)
- `distortion_loss`: MSE vs MS-SSIM (perceptual quality)

In [3]:
# Configuration for KITTI dataset - Training from Scratch
config = {
    # ===========================================
    # üì∏ DATASET PARAMETERS
    # ===========================================
    'dataset_name': 'KITTI',  # Real-world stereo driving scenes
    'dataset_path': '.',      # Current directory contains dataset folder
    'resize': [128, 128],     # Smaller size for faster training (can increase later)
    
    # ===========================================
    # üß† MODEL ARCHITECTURE PARAMETERS  
    # ===========================================
    'baseline_model': 'bls17',    # Ball√© 2017 entropy model (proven architecture)
    'use_side_info': True,        # Use right image to help compress left image
    'num_filters': 256,           # Network capacity (more = better quality, slower training)
    'cuda': torch.cuda.is_available(),  # Auto-detect GPU availability
    'multi_gpu': True,            # üÜï Enable multi-GPU training (DataParallel)
    
    # ===========================================
    # üíæ PRETRAINED WEIGHTS (disabled for scratch training)
    # ===========================================
    'load_weight': False,         # Set to True if you have pretrained weights
    'weight_path': './pretrained_weights/ours+balle17_MS-SSIM_lambda3e-05.pt',
    
    # ===========================================
    # üéì TRAINING HYPERPARAMETERS
    # ===========================================
    'train': True,               # Enable training mode
    'epochs': 30000,             # üÜï Number of training iterations (30K EPOCHS!)
    'train_batch_size': 12,      # üÜï Images per batch (increased for 2 GPUs)
    'lr': 0.0001,                # Learning rate (Adam optimizer default)
    
    # Rate-Distortion Trade-off Parameters:
    'lambda': 0.00003,           # üéØ KEY PARAMETER: Controls compression vs quality
                                 # Higher Œª = more compression, lower quality
                                 # Lower Œª = less compression, higher quality
    'alpha': 1,                  # Weight for side information loss
    'beta': 1,                   # Weight for additional entropy terms
    
    # Loss Function Choice:
    'distortion_loss': 'MSE',    # MSE = fast but less perceptual
                                 # MS-SSIM = slower but more realistic
    'verbose_period': 50,        # üÜï Print progress every 50 epochs (more frequent for long training)
    
    # ===========================================
    # üíæ SAVING AND OUTPUT PARAMETERS
    # ===========================================
    'save_weights': True,        # Automatically save best model
    'save_output_path': './outputs',  # Where to save results
    'experiment_name': 'from_scratch_bls17_MSE_lambda3e-05',
    
    # ===========================================
    # üß™ TESTING PARAMETERS
    # ===========================================
    'test': True,                # Run evaluation after training
    'save_image': True           # Save reconstructed images for visual inspection
}

# üìä CONFIGURATION SUMMARY
print("Configuration for training from scratch:")
print(f"  Dataset: {config['dataset_name']} (stereo driving scenes)")
print(f"  Image size: {config['resize']} (resize for speed)")
print(f"  Using CUDA: {config['cuda']} ({'üöÄ GPU acceleration' if config['cuda'] else 'üêå CPU only'})")
if config['cuda']:
    gpu_count = torch.cuda.device_count()
    print(f"  GPUs available: {gpu_count}")
    for i in range(gpu_count):
        print(f"    ‚Ä¢ GPU {i}: {torch.cuda.get_device_name(i)}")
    if config['multi_gpu'] and gpu_count > 1:
        print(f"  Multi-GPU: ‚úÖ ENABLED (using {gpu_count} GPUs with DataParallel)")
    else:
        print(f"  Multi-GPU: ‚ùå Disabled (using single GPU)")
print(f"  Load pretrained weights: {config['load_weight']} (training from scratch)")
print(f"  Training epochs: {config['epochs']:,} iterations")
print(f"  Batch size: {config['train_batch_size']} images per batch")
print(f"  Learning rate: {config['lr']} (Adam optimizer)")
print(f"  Lambda (rate-distortion): {config['lambda']} (compression priority)")
print(f"  Distortion loss: {config['distortion_loss']} (quality metric)")
print(f"  Verbose period: {config['verbose_period']} epochs between updates")
print("‚úÖ Ready for training from scratch!")

Configuration for training from scratch:
  Dataset: KITTI (stereo driving scenes)
  Image size: [128, 128] (resize for speed)
  Using CUDA: True (üöÄ GPU acceleration)
  GPUs available: 2
    ‚Ä¢ GPU 0: NVIDIA GeForce RTX 2080 Ti
    ‚Ä¢ GPU 1: NVIDIA GeForce RTX 2080 Ti
  Multi-GPU: ‚úÖ ENABLED (using 2 GPUs with DataParallel)
  Load pretrained weights: False (training from scratch)
  Training epochs: 30,000 iterations
  Batch size: 12 images per batch
  Learning rate: 0.0001 (Adam optimizer)
  Lambda (rate-distortion): 3e-05 (compression priority)
  Distortion loss: MSE (quality metric)
  Verbose period: 50 epochs between updates
‚úÖ Ready for training from scratch!


## 3. Helper Functions üîß

**Core Compression Metrics:**
These functions implement the fundamental metrics used in image compression research:

**1. BPP (Bits Per Pixel):**
- Measures compression efficiency
- Lower BPP = better compression
- Calculated from entropy model likelihoods

**2. Distortion Functions:**
- **MSE**: Pixel-wise differences (simple but not perceptually aligned)
- **MS-SSIM**: Multi-scale structural similarity (matches human perception)
- Trade-off: MSE is faster, MS-SSIM is more realistic

**3. Rate-Distortion Optimization:**
- Balance between file size (rate) and image quality (distortion)
- Lambda parameter controls this trade-off
- Higher lambda = prioritize compression over quality

In [2]:
def get_bpp(model_out, config):
    """
    üìè BITS PER PIXEL (BPP) CALCULATION
    
    This is the "rate" in rate-distortion optimization.
    BPP measures compression efficiency - lower is better!
    
    How it works:
    1. Neural networks output probability distributions (likelihoods)
    2. We use entropy to estimate bit requirements
    3. Different model architectures have different output formats
    
    Args:
        model_out: Output from the compression model
        config: Configuration dictionary
    
    Returns:
        bpp: Total bits per pixel (rate)
        transmitted_bpp: Actually transmitted bits (excluding side info)
    """
    alpha = config['alpha']  # Weight for side information
    beta = config['beta']    # Weight for additional entropy terms
    
    # üèóÔ∏è Handle different model architectures
    if config['baseline_model'] == 'bmshj18':
        # Ball√© et al. 2018 - more complex entropy model
        if config['use_side_info']:
            x_recon, y_recon, likelihoods, y_likelihoods, z_likelihoods, z_likelihoods_cor, w_likelihoods = model_out
            size_est = (-np.log(2) * x_recon.numel() / 3)  # Estimate size in bits
            
            # Main stream bits (actually transmitted)
            bpp = (torch.sum(torch.log(likelihoods)) + torch.sum(torch.log(z_likelihoods))) / size_est
            transmitted_bpp = bpp.clone().detach()
            
            # Add side information costs (weighted)
            bpp += alpha * (torch.sum(torch.log(y_likelihoods)) + torch.sum(torch.log(z_likelihoods_cor))) / size_est
            bpp += beta * torch.sum(torch.log(w_likelihoods)) / size_est
            return bpp, transmitted_bpp
        else:
            # No side information - simpler calculation
            x_recon, likelihoods, z_likelihoods = model_out
            size_est = (-np.log(2) * x_recon.numel() / 3)
            bpp = (torch.sum(torch.log(likelihoods)) + torch.sum(torch.log(z_likelihoods))) / size_est
            return bpp, bpp
            
    elif config['baseline_model'] == 'bls17':
        # üéØ Ball√© et al. 2017 - our current model (simpler entropy model)
        if config['use_side_info']:
            # With side information (stereo compression)
            x_recon, y_recon, likelihoods, y_likelihoods, w_likelihoods = model_out
            size_est = (-np.log(2) * x_recon.numel() / 3)
            
            # Main compression stream
            bpp = torch.sum(torch.log(likelihoods)) / size_est
            transmitted_bpp = bpp.clone().detach()
            
            # Add weighted side information costs
            bpp += alpha * torch.sum(torch.log(y_likelihoods)) / size_est
            bpp += beta * torch.sum(torch.log(w_likelihoods)) / size_est
            return bpp, transmitted_bpp
        else:
            # Single image compression (no stereo)
            x_recon, likelihoods = model_out
            size_est = (-np.log(2) * x_recon.numel() / 3)
            bpp = torch.sum(torch.log(likelihoods)) / size_est
            return bpp, bpp
    return None


def get_distortion(config, out_l, out_r, img, cor_img, mse):
    """
    üìê DISTORTION MEASUREMENT
    
    This is the "distortion" in rate-distortion optimization.
    Measures how much the reconstructed image differs from original.
    
    Two main approaches:
    - MSE: Simple pixel differences (fast but not perceptually accurate)
    - MS-SSIM: Structural similarity (slower but matches human perception)
    
    Args:
        config: Configuration dictionary
        out_l, out_r: Reconstructed left and right images
        img, cor_img: Original left and right images
        mse: MSE loss function
    
    Returns:
        distortion: Total distortion value (lower is better quality)
    """
    distortion = None
    alpha = config['alpha']  # Weight for side information
    
    if config['use_side_info']:
        # üë´ Stereo compression - measure distortion for both images
        x_recon, y_recon = out_l, out_r
        
        if config['distortion_loss'] == 'MS-SSIM':
            # üëÅÔ∏è Perceptually-motivated loss (better for human vision)
            # MS-SSIM ranges from 0 (worst) to 1 (perfect)
            # We use (1 - MS-SSIM) so lower is better
            distortion = (1 - ms_ssim(img.cpu(), x_recon.cpu(), data_range=1.0, size_average=True,
                                      win_size=7))
            distortion += alpha * (1 - ms_ssim(cor_img.cpu(), y_recon.cpu(), data_range=1.0, size_average=True,
                                               win_size=7))
        elif config['distortion_loss'] == 'MSE':
            # üìä Simple pixel-wise differences (faster but less realistic)
            distortion = mse(img, x_recon)
            distortion += alpha * mse(cor_img, y_recon)
    else:
        # üë§ Single image compression
        x_recon = out_l
        if config['distortion_loss'] == 'MS-SSIM':
            distortion = (1 - ms_ssim(img.cpu(), x_recon.cpu(), data_range=1.0, size_average=True,
                                      win_size=7))
        elif config['distortion_loss'] == 'MSE':
            distortion = mse(img, x_recon)
    return distortion


def save_image(x_recon, x, path, name):
    """
    üíæ SAVE RECONSTRUCTED IMAGES
    
    Saves original and reconstructed images side-by-side for visual comparison.
    This helps you see the compression quality visually.
    
    Args:
        x_recon: Reconstructed image tensor
        x: Original image tensor  
        path: Directory to save images
        name: Filename (without extension)
    """
    # Convert tensors to numpy arrays and scale to [0, 255]
    img_recon = np.clip((x_recon * 255).squeeze().cpu().numpy(), 0, 255)
    img = np.clip((x * 255).squeeze().cpu().numpy(), 0, 255)
    
    # Rearrange dimensions from (C, H, W) to (H, W, C) for PIL
    img_recon = np.transpose(img_recon, (1, 2, 0)).astype('uint8')
    img = np.transpose(img, (1, 2, 0)).astype('uint8')
    
    # Concatenate original and reconstructed side-by-side
    img_final = Image.fromarray(np.concatenate((img, img_recon), axis=1), 'RGB')
    
    # Create directory if it doesn't exist
    if not os.path.exists(path):
        os.makedirs(path)
    
    # Save the comparison image
    img_final.save(os.path.join(path, name + '.png'))


def map_layers(weight):
    """
    üîÑ LAYER NAME MAPPING
    
    Sometimes pretrained weights have different layer names.
    This function maps old names to new names for compatibility.
    
    Args:
        weight: Dictionary of model weights
    
    Returns:
        OrderedDict with mapped layer names
    """
    return OrderedDict([(k.replace('z', 'w'), v) if 'z' in k else (k, v) for k, v in weight.items()])


print("üìö Helper functions defined successfully!")
print("üîß These functions handle:")
print("   ‚Ä¢ BPP calculation (compression efficiency)")
print("   ‚Ä¢ Distortion measurement (image quality)")
print("   ‚Ä¢ Image saving (visual results)")
print("   ‚Ä¢ Weight mapping (model compatibility)")

üìö Helper functions defined successfully!
üîß These functions handle:
   ‚Ä¢ BPP calculation (compression efficiency)
   ‚Ä¢ Distortion measurement (image quality)
   ‚Ä¢ Image saving (visual results)
   ‚Ä¢ Weight mapping (model compatibility)


## 4. Dataset Initialization üì∏

**Understanding KITTI Stereo Dataset:**
- **What**: Real driving scenes from Karlsruhe, Germany
- **Why**: Realistic stereo pairs with known geometry
- **Structure**: Left/right camera pairs with calibrated disparity
- **Use Case**: Perfect for stereo compression research

**Data Loading Strategy:**
- **Training**: Learn compression patterns (largest split)
- **Validation**: Monitor overfitting during training  
- **Testing**: Final performance evaluation (never seen during training)

**Stereo Compression Advantage:**
Using right image as "side information" helps compress left image more efficiently because:
1. Stereo images share similar content (same scene)
2. Geometric relationships provide predictable correlations
3. Joint compression exploits redundancy between views

### ‚ö†Ô∏è Dataset Setup Check

**Before running the training, we need to verify the KITTI dataset is properly downloaded.**

The dataset should contain:
- `data_scene_flow_multiview/training/image_2/` (left images)
- `data_scene_flow_multiview/training/image_3/` (right images)

If you see a FileNotFoundError, it means the actual image files need to be downloaded.

In [33]:
# üîç DATASET VERIFICATION - CHECK IF IMAGES EXIST

print("üîç Checking KITTI dataset structure...")
print("=" * 60)

dataset_base = './dataset'
required_paths = [
    'data_scene_flow_multiview/training/image_2',  # Left images
    'data_scene_flow_multiview/training/image_3',  # Right images  
]

all_exist = True
for rel_path in required_paths:
    full_path = os.path.join(dataset_base, rel_path)
    exists = os.path.exists(full_path)
    
    if exists:
        # Count images in directory
        try:
            image_files = [f for f in os.listdir(full_path) if f.endswith(('.png', '.jpg', '.jpeg'))]
            print(f"‚úÖ {rel_path}")
            print(f"   Found {len(image_files)} images")
        except Exception as e:
            print(f"‚ö†Ô∏è  {rel_path} exists but can't read: {e}")
            all_exist = False
    else:
        print(f"‚ùå {rel_path}")
        print(f"   Path does not exist!")
        all_exist = False

print("=" * 60)

if not all_exist:
    print("\nüö® DATASET NOT FOUND!")
    print("\nüì• To fix this, you need to download the KITTI Stereo 2015 dataset:")
    print("\n**Option 1: Download from Official KITTI Website**")
    print("1. Visit: http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php")
    print("2. Download 'data_scene_flow_multiview' (left + right stereo images)")
    print("3. Extract to ./dataset/ folder")
    print("\n**Option 2: Use a Smaller Test Dataset**")
    print("For quick testing, you can:")
    print("‚Ä¢ Create dummy data (see next cell)")
    print("‚Ä¢ Use a smaller public stereo dataset")
    print("‚Ä¢ Download a subset of KITTI images")
    
    print("\nüí° Would you like to create a small dummy dataset for testing?")
    print("   Run the next cell to generate synthetic test data.")
else:
    print("\n‚úÖ DATASET READY!")
    print("You can proceed with training.")

üîç Checking KITTI dataset structure...
‚úÖ data_scene_flow_multiview/training/image_2
   Found 400 images
‚úÖ data_scene_flow_multiview/training/image_3
   Found 400 images

‚úÖ DATASET READY!
You can proceed with training.


In [34]:
# üé® OPTIONAL: CREATE DUMMY DATASET FOR TESTING
# Run this cell ONLY if you don't have the real KITTI dataset and want to test the pipeline

def create_dummy_kitti_dataset(num_images=50):
    """
    Creates a small dummy stereo dataset for testing the pipeline.
    This is NOT real data - just for testing that the code works!
    """
    print("üé® Creating dummy KITTI-style dataset for testing...")
    
    # Create directory structure
    base_path = './dataset/data_scene_flow_multiview/training'
    left_dir = os.path.join(base_path, 'image_2')
    right_dir = os.path.join(base_path, 'image_3')
    
    os.makedirs(left_dir, exist_ok=True)
    os.makedirs(right_dir, exist_ok=True)
    
    # Generate dummy images
    from PIL import Image
    import numpy as np
    
    print(f"Generating {num_images} dummy stereo pairs...")
    
    for i in range(num_images):
        # Create random images (not real stereo pairs, just for testing!)
        # Left image: Random colored pattern
        left_img = np.random.randint(0, 255, (375, 1242, 3), dtype=np.uint8)
        # Right image: Similar but slightly shifted (fake stereo effect)
        right_img = np.roll(left_img, shift=10, axis=1)
        right_img = np.clip(right_img + np.random.randint(-20, 20, right_img.shape), 0, 255).astype(np.uint8)
        
        # Save images in KITTI format
        left_path = os.path.join(left_dir, f'{i:06d}_10.png')
        right_path = os.path.join(right_dir, f'{i:06d}_10.png')
        
        Image.fromarray(left_img).save(left_path)
        Image.fromarray(right_img).save(right_path)
        
        if (i + 1) % 10 == 0:
            print(f"   Generated {i + 1}/{num_images} pairs...")
    
    print(f"\n‚úÖ Dummy dataset created!")
    print(f"   üìÅ Left images: {left_dir}")
    print(f"   üìÅ Right images: {right_dir}")
    print(f"\n‚ö†Ô∏è  IMPORTANT: This is synthetic data!")
    print("   ‚Ä¢ Use only for testing the pipeline")
    print("   ‚Ä¢ Download real KITTI data for actual training")
    print("   ‚Ä¢ Results won't be meaningful with dummy data")

# Uncomment the line below to create dummy data
# create_dummy_kitti_dataset(num_images=100)

print("üí° To create dummy test data, uncomment the last line and run this cell")
print("   This will let you test the training pipeline without downloading real data")

üí° To create dummy test data, uncomment the last line and run this cell
   This will let you test the training pipeline without downloading real data


### üì• Download Real KITTI Dataset

**KITTI Stereo 2015 Dataset:**
This will download the official KITTI Scene Flow dataset which contains:
- **Left stereo images** (image_2): ~15 GB
- **Right stereo images** (image_3): ~15 GB
- Real driving scenes from Karlsruhe, Germany
- High-quality calibrated stereo pairs

**Note:** The download is large (~30 GB total) and may take time depending on your connection.

In [35]:
# üì• DOWNLOAD KITTI STEREO 2015 DATASET

import urllib.request
import zipfile
import os
from tqdm import tqdm

class DownloadProgressBar(tqdm):
    """Progress bar for downloads"""
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

def download_url(url, output_path):
    """Download file with progress bar"""
    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)

def download_kitti_dataset():
    """
    Download and extract KITTI Stereo 2015 dataset
    """
    print("üöó KITTI Stereo 2015 Dataset Download")
    print("=" * 70)
    
    # Dataset URLs (official KITTI website)
    datasets = {
        'left_images': {
            'url': 'https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow.zip',
            'filename': 'data_scene_flow.zip',
            'description': 'Scene Flow Multi-view (includes left and right stereo images)'
        }
    }
    
    # Create dataset directory
    dataset_dir = './dataset'
    download_dir = os.path.join(dataset_dir, 'downloads')
    os.makedirs(download_dir, exist_ok=True)
    
    print(f"üìÅ Dataset directory: {dataset_dir}")
    print(f"üìÅ Download directory: {download_dir}")
    print()
    
    # Check if already downloaded
    extracted_path = os.path.join(dataset_dir, 'data_scene_flow_multiview')
    if os.path.exists(extracted_path):
        print("‚úÖ Dataset already exists!")
        
        # Verify images exist
        left_imgs = os.path.join(extracted_path, 'training/image_2')
        right_imgs = os.path.join(extracted_path, 'training/image_3')
        
        if os.path.exists(left_imgs) and os.path.exists(right_imgs):
            left_count = len([f for f in os.listdir(left_imgs) if f.endswith('.png')])
            right_count = len([f for f in os.listdir(right_imgs) if f.endswith('.png')])
            print(f"   üì∏ Left images: {left_count}")
            print(f"   üì∏ Right images: {right_count}")
            print("\n‚úÖ Dataset is ready to use!")
            return
        else:
            print("‚ö†Ô∏è  Dataset folder exists but images are missing. Re-downloading...")
    
    # Download dataset
    for name, info in datasets.items():
        print(f"\nüì• Downloading {info['description']}...")
        print(f"   URL: {info['url']}")
        
        zip_path = os.path.join(download_dir, info['filename'])
        
        # Check if zip already exists
        if os.path.exists(zip_path):
            print(f"   ‚ÑπÔ∏è  Zip file already exists: {zip_path}")
            print("   Skipping download...")
        else:
            print(f"   üíæ Downloading to: {zip_path}")
            print(f"   ‚ö†Ô∏è  This is a large file (~15-30 GB), please be patient...")
            print()
            
            try:
                download_url(info['url'], zip_path)
                print(f"\n   ‚úÖ Download complete!")
            except Exception as e:
                print(f"\n   ‚ùå Download failed: {e}")
                print("\n   üîÑ Alternative: Manual Download")
                print(f"   1. Visit: http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php")
                print(f"   2. Download 'data_scene_flow_multiview.zip'")
                print(f"   3. Place it in: {download_dir}")
                print(f"   4. Re-run this cell to extract")
                return
        
        # Extract dataset
        print(f"\nüì¶ Extracting {info['filename']}...")
        print(f"   This may take several minutes...")
        
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                # Extract to dataset directory
                zip_ref.extractall(dataset_dir)
            
            print(f"   ‚úÖ Extraction complete!")
            
            # Verify extraction
            if os.path.exists(extracted_path):
                print(f"\n‚úÖ Dataset successfully set up!")
                
                # Count images
                left_imgs = os.path.join(extracted_path, 'training/image_2')
                right_imgs = os.path.join(extracted_path, 'training/image_3')
                
                if os.path.exists(left_imgs):
                    left_count = len([f for f in os.listdir(left_imgs) if f.endswith('.png')])
                    print(f"   üì∏ Left images: {left_count}")
                
                if os.path.exists(right_imgs):
                    right_count = len([f for f in os.listdir(right_imgs) if f.endswith('.png')])
                    print(f"   üì∏ Right images: {right_count}")
                
                print(f"\nüéâ KITTI dataset is ready for training!")
                print(f"   You can now proceed with the training cells.")
            else:
                print(f"   ‚ö†Ô∏è  Extraction completed but expected path not found")
                
        except Exception as e:
            print(f"   ‚ùå Extraction failed: {e}")
            print(f"   You may need to extract manually")
            return

# Run the download
print("üöÄ Starting KITTI dataset download...")
print("‚è±Ô∏è  This will take some time depending on your internet speed")
print()

try:
    # First, try to install tqdm if not available
    try:
        from tqdm import tqdm
    except ImportError:
        print("Installing tqdm for progress bars...")
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "tqdm"])
        from tqdm import tqdm
    
    download_kitti_dataset()
    
except KeyboardInterrupt:
    print("\n\n‚ö†Ô∏è  Download interrupted by user")
    print("You can re-run this cell to resume")
except Exception as e:
    print(f"\n‚ùå Error: {e}")
    print("\nüîÑ Manual download instructions:")
    print("1. Visit: http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php")
    print("2. Download 'data_scene_flow.zip' or 'data_scene_flow_multiview.zip'")
    print("3. Extract to ./dataset/ folder")
    print("4. Verify the structure:")
    print("   ./dataset/data_scene_flow_multiview/training/image_2/ (left images)")
    print("   ./dataset/data_scene_flow_multiview/training/image_3/ (right images)")

üöÄ Starting KITTI dataset download...
‚è±Ô∏è  This will take some time depending on your internet speed

üöó KITTI Stereo 2015 Dataset Download
üìÅ Dataset directory: ./dataset
üìÅ Download directory: ./dataset/downloads

‚úÖ Dataset already exists!
   üì∏ Left images: 400
   üì∏ Right images: 400

‚úÖ Dataset is ready to use!
   üì∏ Left images: 400
   üì∏ Right images: 400

‚úÖ Dataset is ready to use!


### ‚úÖ Dataset Status Summary

**KITTI Dataset Successfully Downloaded and Configured!**

üìä **Dataset Statistics:**
- **Left stereo images (image_2)**: 400 training images
- **Right stereo images (image_3)**: 400 training images
- **Total size**: ~1.68 GB downloaded
- **Location**: `./dataset/data_scene_flow_multiview/`

üéØ **Ready for Training:**
The dataset is now properly configured and you can proceed with:
1. Running the dataset initialization cell (already done)
2. Starting the training loop
3. Evaluating on test images

üí° **Note:** This is the KITTI 2015 Scene Flow dataset, perfect for stereo compression research!

In [4]:
# üóÇÔ∏è INITIALIZE KITTI STEREO DATASET
path = './dataset'          # Path to dataset folder
resize = tuple(config['resize'])  # Convert list to tuple for transforms

print(f"üöó Loading {config['dataset_name']} dataset...")
print(f"üìê Image resize: {resize} (smaller = faster training)")

# üìö CREATE DATASET SPLITS
# Each dataset handles loading stereo pairs and applying transforms
train_dataset = PairKitti(path=path, set_type='train', resize=resize)
val_dataset = PairKitti(path=path, set_type='val', resize=resize)  
test_dataset = PairKitti(path=path, set_type='test', resize=resize)

print(f"üìà Training samples: {len(train_dataset)} (learn compression patterns)")
print(f"üéØ Validation samples: {len(val_dataset)} (monitor overfitting)")  
print(f"üß™ Test samples: {len(test_dataset)} (final evaluation)")

# üîÑ CREATE DATA LOADERS
# DataLoaders handle batching, shuffling, and parallel loading
batch_size = config['train_batch_size']

# Training: Shuffle for better learning, multiple workers for speed
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 
                         shuffle=True, num_workers=3)

# Validation: Shuffle for variety, same batch size  
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, 
                       shuffle=True, num_workers=3)

# Testing: No shuffle (reproducible), batch size 1 for individual analysis
test_loader = DataLoader(dataset=test_dataset, batch_size=1, 
                        shuffle=False, num_workers=3)

print(f"\n‚öôÔ∏è Data loaders created with batch size: {batch_size}")
print(f"üîÑ Training batches: {len(train_loader)} ({len(train_dataset)} images √∑ {batch_size})")
print(f"üìä Validation batches: {len(val_loader)} ({len(val_dataset)} images √∑ {batch_size})")  
print(f"üß™ Test batches: {len(test_loader)} (1 image per batch for analysis)")

print(f"\nüí° **What happens in each batch:**")
print(f"   ‚Ä¢ Load {batch_size} stereo pairs (left + right images)")
print(f"   ‚Ä¢ Apply transforms (resize, normalize to [0,1])")
print(f"   ‚Ä¢ Stack into tensors for GPU processing")
print(f"   ‚Ä¢ Feed to model for compression/reconstruction")

üöó Loading KITTI dataset...
üìê Image resize: (128, 128) (smaller = faster training)
üìà Training samples: 1576 (learn compression patterns)
üéØ Validation samples: 790 (monitor overfitting)
üß™ Test samples: 790 (final evaluation)

‚öôÔ∏è Data loaders created with batch size: 12
üîÑ Training batches: 132 (1576 images √∑ 12)
üìä Validation batches: 66 (790 images √∑ 12)
üß™ Test batches: 790 (1 image per batch for analysis)

üí° **What happens in each batch:**
   ‚Ä¢ Load 12 stereo pairs (left + right images)
   ‚Ä¢ Apply transforms (resize, normalize to [0,1])
   ‚Ä¢ Stack into tensors for GPU processing
   ‚Ä¢ Feed to model for compression/reconstruction


## 5. Model Initialization üß†

**DHF-JSCC Architecture Overview:**
The model has several key components working together:

1. **Encoder Network**: Converts images to compressed latent representations
2. **Entropy Model**: Estimates probability distributions for efficient coding
3. **Decoder Network**: Reconstructs images from compressed representations
4. **Side Information**: Uses right image to help compress left image

**Training from Scratch vs Pretrained:**
- **From Scratch**: Random weights, longer training, full learning experience
- **Pretrained**: Pre-learned weights, faster convergence, less exploration

**Why 21M Parameters?**
Deep compression needs capacity to:
- Learn complex image features at multiple scales
- Model correlations between stereo pairs  
- Optimize rate-distortion trade-offs across diverse scenes

In [5]:
# üß† MODEL INITIALIZATION - BUILDING OUR COMPRESSION NETWORK

# üîß Configuration check
with_side_info = config['use_side_info']
print(f"üéØ Model type: {'Stereo compression (with side info)' if with_side_info else 'Single image compression'}")

# üèóÔ∏è CREATE NEW MODEL FOR TRAINING FROM SCRATCH
print("üöÄ Creating new model for training from scratch...")
if model_d_fusion2 is not None:
    # Initialize model with specified number of filters (network capacity)
    model = model_d_fusion2.Image_coding(M=config['num_filters'], N2=25)
    print(f"‚úÖ New DHF-JSCC model created!")
    print(f"   ‚Ä¢ Filters: {config['num_filters']} (network capacity)")
    print(f"   ‚Ä¢ Architecture: Encoder ‚Üí Entropy Model ‚Üí Decoder")
    print(f"   ‚Ä¢ Side info: {'Enabled (stereo)' if with_side_info else 'Disabled (single)'}")
else:
    raise ImportError("‚ùå model_d_fusion2 module required but not available. Check imports!")

# üöÄ MOVE MODEL TO GPU (CUDA ACCELERATION)
model = model.cuda() if config['cuda'] else model
device = 'CUDA (GPU)' if config['cuda'] else 'CPU'
print(f"üìç Model moved to: {device}")

# üéÆ MULTI-GPU SETUP (DataParallel)
if config['cuda'] and config.get('multi_gpu', False) and torch.cuda.device_count() > 1:
    gpu_count = torch.cuda.device_count()
    print(f"\nüéÆ MULTI-GPU TRAINING ENABLED!")
    print(f"   ‚Ä¢ Available GPUs: {gpu_count}")
    for i in range(gpu_count):
        print(f"   ‚Ä¢ GPU {i}: {torch.cuda.get_device_name(i)}")
    
    # Wrap model with DataParallel
    model = torch.nn.DataParallel(model)
    print(f"‚úÖ Model wrapped with DataParallel")
    print(f"   ‚Ä¢ Batch will be split across {gpu_count} GPUs")
    print(f"   ‚Ä¢ Effective batch size per GPU: {config['train_batch_size'] // gpu_count}")
    print(f"   ‚Ä¢ Gradients will be synchronized after each batch")
elif config['cuda']:
    print(f"   ‚Ä¢ GPU Memory: More efficient for large models")
    print(f"   ‚Ä¢ Parallel Processing: Faster matrix operations")
    print(f"   ‚Ä¢ Using single GPU: {torch.cuda.get_device_name(0)}")

# üìä ANALYZE MODEL COMPLEXITY
if hasattr(model, 'parameters'):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nüî¢ Model Statistics:")
    print(f"   ‚Ä¢ Total parameters: {total_params:,}")
    print(f"   ‚Ä¢ Trainable parameters: {trainable_params:,}")
    print(f"   ‚Ä¢ Memory footprint: ~{total_params * 4 / 1e6:.1f} MB (float32)")

# üéì INITIALIZE OPTIMIZER (ADAM - ADAPTIVE LEARNING)
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], amsgrad=True)
print(f"\n‚öôÔ∏è Optimizer: Adam with learning rate {config['lr']}")
print(f"   ‚Ä¢ Adam: Adapts learning rate per parameter")
print(f"   ‚Ä¢ AMSGrad: Improved convergence stability")
print(f"   ‚Ä¢ Learning rate: {config['lr']} (balanced for deep networks)")

# üíæ WEIGHT LOADING LOGIC (CURRENTLY DISABLED FOR SCRATCH TRAINING)
if config['load_weight'] and os.path.exists(config['weight_path']):
    print(f"\nüì• Loading pretrained weights from: {config['weight_path']}")
    try:
        checkpoint = torch.load(config['weight_path'], 
                              map_location=torch.device('cuda' if config['cuda'] else 'cpu'))
        
        # Handle potential layer name differences
        if config['baseline_model'] == 'bls17' and with_side_info:
            checkpoint['model_state_dict'] = map_layers(checkpoint['model_state_dict'])
            
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("‚úÖ Pretrained weights loaded successfully!")
    except Exception as e:
        print(f"‚ùå Failed to load weights: {e}")
        print("üîÑ Continuing with randomly initialized weights...")
elif config['load_weight']:
    print(f"\n‚ùå Weight file not found: {config['weight_path']}")
    print("üîÑ Continuing with randomly initialized weights...")
else:
    print(f"\nüé≤ Training from scratch with randomly initialized weights")
    print("   ‚Ä¢ All weights start random (Gaussian/Xavier initialization)")
    print("   ‚Ä¢ Model will learn compression from ground up")
    print("   ‚Ä¢ Longer training but full learning experience")

# üìâ LEARNING RATE SCHEDULER (ADAPTIVE LEARNING)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=10, min_lr=1e-7
)
print(f"\nüìà Learning Rate Scheduler: ReduceLROnPlateau")
print(f"   ‚Ä¢ Reduces LR when validation loss plateaus")
print(f"   ‚Ä¢ Factor: 0.1 (10x reduction)")
print(f"   ‚Ä¢ Patience: 10 epochs before reduction")
print(f"   ‚Ä¢ Min LR: 1e-7 (prevents going too small)")

# üè∑Ô∏è CREATE EXPERIMENT NAME FOR TRACKING
experiment_name = str(train_dataset) + '_' + config['distortion_loss'] + '_lambda:' + str(config['lambda'])
print(f"\nüè∑Ô∏è Experiment identifier: {experiment_name}")
print("‚úÖ Model initialization complete - ready for training!")


üéØ Model type: Stereo compression (with side info)
üöÄ Creating new model for training from scratch...
‚úÖ New DHF-JSCC model created!
   ‚Ä¢ Filters: 256 (network capacity)
   ‚Ä¢ Architecture: Encoder ‚Üí Entropy Model ‚Üí Decoder
   ‚Ä¢ Side info: Enabled (stereo)
üìç Model moved to: CUDA (GPU)

üéÆ MULTI-GPU TRAINING ENABLED!
   ‚Ä¢ Available GPUs: 2
   ‚Ä¢ GPU 0: NVIDIA GeForce RTX 2080 Ti
   ‚Ä¢ GPU 1: NVIDIA GeForce RTX 2080 Ti
‚úÖ Model wrapped with DataParallel
   ‚Ä¢ Batch will be split across 2 GPUs
   ‚Ä¢ Effective batch size per GPU: 6
   ‚Ä¢ Gradients will be synchronized after each batch

üî¢ Model Statistics:
   ‚Ä¢ Total parameters: 21,380,152
   ‚Ä¢ Trainable parameters: 21,380,152
   ‚Ä¢ Memory footprint: ~85.5 MB (float32)

‚öôÔ∏è Optimizer: Adam with learning rate 0.0001
   ‚Ä¢ Adam: Adapts learning rate per parameter
   ‚Ä¢ AMSGrad: Improved convergence stability
   ‚Ä¢ Learning rate: 0.0001 (balanced for deep networks)

üé≤ Training from scratch with rando

## üöÄ Multi-GPU Setup (DataParallel)

**Using Both RTX 2080 Ti GPUs:**
PyTorch's DataParallel allows us to use multiple GPUs simultaneously:

**How It Works:**
- **Model Replication**: Copy model to each GPU
- **Batch Splitting**: Divide each batch across GPUs (batch_size=8 ‚Üí 4 per GPU)
- **Parallel Forward**: Each GPU processes its portion simultaneously
- **Gradient Aggregation**: Combine gradients from all GPUs
- **Single Update**: Update model weights once with combined gradients

**Performance Benefits:**
- ~2x faster training with 2 GPUs
- Better GPU memory utilization
- Same final model quality

**Note:** The batch size will be automatically split across GPUs!

In [38]:
# üöÄ ENABLE MULTI-GPU TRAINING WITH DATAPARALLEL

# Check GPU availability
num_gpus = torch.cuda.device_count()
print(f"üîç GPU Detection:")
print(f"   ‚Ä¢ Available GPUs: {num_gpus}")

if num_gpus > 1:
    print(f"\nüöÄ Enabling DataParallel across {num_gpus} GPUs:")
    for i in range(num_gpus):
        print(f"   ‚Ä¢ GPU {i}: {torch.cuda.get_device_name(i)}")
    
    # Wrap model with DataParallel
    model = torch.nn.DataParallel(model)
    print(f"\n‚úÖ Model wrapped with DataParallel!")
    print(f"   ‚Ä¢ Batch size: {config['train_batch_size']} (split {config['train_batch_size']//num_gpus} per GPU)")
    print(f"   ‚Ä¢ Both GPUs will be utilized automatically")
    print(f"   ‚Ä¢ Expected speedup: ~{num_gpus}x faster training")
    
elif num_gpus == 1:
    print(f"\n‚ö†Ô∏è  Only 1 GPU detected - using single GPU training")
    print(f"   ‚Ä¢ GPU: {torch.cuda.get_device_name(0)}")
else:
    print(f"\n‚ùå No GPUs detected - using CPU (very slow!)")

print(f"\nüí° Monitor GPU usage with: nvidia-smi -l 1")
print(f"   You should see both GPUs active during training!")

üîç GPU Detection:
   ‚Ä¢ Available GPUs: 2

üöÄ Enabling DataParallel across 2 GPUs:
   ‚Ä¢ GPU 0: NVIDIA GeForce RTX 2080 Ti
   ‚Ä¢ GPU 1: NVIDIA GeForce RTX 2080 Ti

‚úÖ Model wrapped with DataParallel!
   ‚Ä¢ Batch size: 8 (split 4 per GPU)
   ‚Ä¢ Both GPUs will be utilized automatically
   ‚Ä¢ Expected speedup: ~2x faster training

üí° Monitor GPU usage with: nvidia-smi -l 1
   You should see both GPUs active during training!


### üöÄ Training from Scratch - Ready!

**Model Configuration:**
- New model created with random weights
- Architecture: DHF-JSCC with side information
- Parameters will be shown after model creation

**Training Strategy:**
- Starting with MSE loss (more stable for initial training)
- Can switch to MS-SSIM loss later for better perceptual quality
- Learning rate: 1e-4 (good starting point)
- Batch size: 8 (conservative for stability)

**Monitoring:**
- Verbose output every 10 epochs
- Model checkpoints saved when validation loss improves
- Both PSNR and MS-SSIM metrics tracked

In [6]:
# Training Progress Tracker and Tips
print("üöÄ TRAINING FROM SCRATCH - READY TO START!")
print("=" * 60)
print(f"‚úÖ Model: {total_params:,} parameters on {'CUDA' if config['cuda'] else 'CPU'}")
print(f"‚úÖ Dataset: {len(train_dataset)} training, {len(val_dataset)} validation samples")
print(f"‚úÖ Training setup: {config['epochs']} epochs, batch size {config['train_batch_size']}")
print("=" * 60)

print("\nüìä TRAINING EXPECTATIONS:")
print("‚Ä¢ Initial epochs may have high loss - this is normal")
print("‚Ä¢ Look for consistent loss decrease over time")
print("‚Ä¢ PSNR should gradually improve (higher is better)")
print("‚Ä¢ MS-SSIM should approach 1.0 (closer to 1.0 is better)")

print("\n‚öôÔ∏è TRAINING TIPS:")
print("‚Ä¢ Monitor GPU memory usage - reduce batch size if OOM occurs")
print("‚Ä¢ Early epochs focus on convergence, later epochs on fine-tuning")
print("‚Ä¢ If loss plateaus, the scheduler will reduce learning rate")
print("‚Ä¢ Best model weights are saved automatically")

print("\nüéØ TYPICAL BENCHMARKS FOR KITTI:")
print("‚Ä¢ Good PSNR: >25 dB")
print("‚Ä¢ Excellent PSNR: >30 dB") 
print("‚Ä¢ Good MS-SSIM: >0.85")
print("‚Ä¢ Excellent MS-SSIM: >0.90")

print("\n‚ñ∂Ô∏è  Ready to run the training loop!")
print("   Execute the training cell to start training from scratch.")
print("=" * 60)

üöÄ TRAINING FROM SCRATCH - READY TO START!
‚úÖ Model: 21,380,152 parameters on CUDA
‚úÖ Dataset: 1576 training, 790 validation samples
‚úÖ Training setup: 30000 epochs, batch size 12

üìä TRAINING EXPECTATIONS:
‚Ä¢ Initial epochs may have high loss - this is normal
‚Ä¢ Look for consistent loss decrease over time
‚Ä¢ PSNR should gradually improve (higher is better)
‚Ä¢ MS-SSIM should approach 1.0 (closer to 1.0 is better)

‚öôÔ∏è TRAINING TIPS:
‚Ä¢ Monitor GPU memory usage - reduce batch size if OOM occurs
‚Ä¢ Early epochs focus on convergence, later epochs on fine-tuning
‚Ä¢ If loss plateaus, the scheduler will reduce learning rate
‚Ä¢ Best model weights are saved automatically

üéØ TYPICAL BENCHMARKS FOR KITTI:
‚Ä¢ Good PSNR: >25 dB
‚Ä¢ Excellent PSNR: >30 dB
‚Ä¢ Good MS-SSIM: >0.85
‚Ä¢ Excellent MS-SSIM: >0.90

‚ñ∂Ô∏è  Ready to run the training loop!
   Execute the training cell to start training from scratch.


## 6. Training Setup üõ†Ô∏è

**Loss Function Deep Dive:**
The loss function determines what the model optimizes for:

**Rate-Distortion Loss = Œª √ó Distortion + Rate**

- **Distortion**: How different reconstructed images are from originals
- **Rate**: How many bits needed to store compressed representation  
- **Lambda (Œª)**: Trade-off parameter
  - High Œª: Prioritize compression (smaller files, lower quality)
  - Low Œª: Prioritize quality (larger files, better images)

**Why MSE vs MS-SSIM?**
- **MSE**: Simple pixel differences, fast computation, but ignores human perception
- **MS-SSIM**: Structural similarity, matches human vision, but slower to compute

**Directory Structure:**
Organized saving helps track experiments and results across multiple runs.

In [None]:
# üõ†Ô∏è TRAINING SETUP - PREPARE DIRECTORIES AND LOSS FUNCTIONS

# üìÅ CREATE ORGANIZED OUTPUT DIRECTORIES
import datetime

# Create date-based checkpoint structure (same as main22.py)
date_folder = datetime.datetime.now().strftime("%m_%d")  # e.g., "10_07"
pkl_dir = os.path.join('.', 'checkpoints', date_folder, 'pkl')
pth_dir = os.path.join('.', 'checkpoints', date_folder, 'pth')

# Create directories if they don't exist
os.makedirs(pkl_dir, exist_ok=True)
os.makedirs(pth_dir, exist_ok=True)

# Add checkpoint directories to config for use in training loop
config['pkl_dir'] = pkl_dir
config['pth_dir'] = pth_dir

print(f"? Checkpoint directories created:")
print(f"   ‚Ä¢ .pkl files: {pkl_dir}")
print(f"   ‚Ä¢ .pth files: {pth_dir}")

üìÅ Checkpoint directories created:
   ‚Ä¢ .pkl files: ./checkpoints/10_07/pkl
   ‚Ä¢ .pth files: ./checkpoints/10_07/pth

üìÇ Training run directory structure:
   ‚Ä¢ Run folder: ./models/KITTI_MSE_lambda3e-05_20251007_201536
   ‚Ä¢ Results: ./models/KITTI_MSE_lambda3e-05_20251007_201536/results

üíæ Model checkpoints will be saved to:
   ‚Ä¢ PKL format: ./checkpoints/10_07/pkl/epoch_XXXX_psnr_YY.YYdB.pkl
   ‚Ä¢ PTH format: ./checkpoints/10_07/pth/epoch_XXXX_psnr_YY.YYdB.pth (recommended)

üéØ Loss Function Configuration:
   ‚Ä¢ Distortion metric: MSE
   ‚Ä¢ MSE function: GPU accelerated

‚öñÔ∏è Rate-Distortion Trade-off:
   ‚Ä¢ Lambda (Œª): 3e-05
   ‚Ä¢ Higher Œª ‚Üí More compression, lower quality
   ‚Ä¢ Lower Œª ‚Üí Less compression, higher quality
   ‚Ä¢ Current setting: Balanced

üìä Training Loss Formula:
   Total Loss = Œª √ó Distortion + Rate
   Where:
   ‚Ä¢ Distortion = MSE between original and reconstructed
   ‚Ä¢ Rate = Estimated bits per pixel (BPP)
   ‚Ä¢ Œª = 3e-05

## 7. Training Loop üéì

**The Heart of Deep Learning:**
This is where the magic happens! The training loop teaches our model to compress images by:

**Training Phase (Each Epoch):**
1. **Forward Pass**: Feed images through model ‚Üí get reconstruction
2. **Loss Calculation**: Measure how good/bad the reconstruction is
3. **Backpropagation**: Calculate gradients (how to improve)
4. **Optimization**: Update model weights using gradients

**Validation Phase:**
- Test on unseen data to check if model is learning generally (not just memorizing)
- No weight updates - just monitoring performance

**Key Metrics to Watch:**
- **Loss**: Should decrease over time (lower = better)
- **PSNR**: Peak Signal-to-Noise Ratio (higher = better quality)
- **MS-SSIM**: Structural similarity (closer to 1.0 = better)

**Training Tips:**
- First few epochs may have erratic loss - this is normal
- Look for consistent downward trend over many epochs
- Validation metrics more important than training metrics

In [8]:
# Install ipywidgets for notebook progress bars
import subprocess
import sys

try:
    import ipywidgets
    print("ipywidgets is already installed")
except ImportError:
    print("Installing ipywidgets...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ipywidgets"])
    print("ipywidgets installed successfully!")
    print("Note: You may need to restart the kernel if progress bars don't show correctly")

ipywidgets is already installed


In [10]:
import time
from tqdm.notebook import tqdm

# Ensure config has checkpoint directories (in case cell above wasn't re-run)
if 'pkl_dir' not in config:
    config['pkl_dir'] = pkl_dir
if 'pth_dir' not in config:
    config['pth_dir'] = pth_dir

# TRAINING SETUP
print("=== STARTING TRAINING ===")
print(f"Total Epochs: {config['epochs']}")
print(f"Batch Size: {config['batch_size']}")
print(f"Learning Rate: {config['lr']}")
print(f"Lambda: {config['lambda']}")
print(f"Checkpoint dirs: pkl={config.get('pkl_dir', 'NOT SET')}, pth={config.get('pth_dir', 'NOT SET')}")
print(f"Using {'notebook' if IN_NOTEBOOK else 'terminal'} progress bars")
print(f"Using {torch.cuda.device_count()} GPU(s)")

# Training variables
best_psnr = 0.0
min_val_loss = None

# Create progress bar for epochs
if IN_NOTEBOOK:
    epoch_pbar = tqdm(range(config['epochs']), desc='Training Progress', position=0)
else:
    epoch_pbar = range(config['epochs'])

# MAIN TRAINING LOOP
for epoch in epoch_pbar:
    epoch_start_time = time.time()
    
    # === TRAINING PHASE ===
    model.train()
    train_loss_sum = 0
    
    # Training progress bar
    if IN_NOTEBOOK:
        train_pbar = tqdm(enumerate(train_loader), total=len(train_loader), 
                         desc=f'Epoch {epoch+1}/{config["epochs"]} [Train]', 
                         leave=False, position=1)
    else:
        train_pbar = enumerate(train_loader)
    
    for batch_idx, (input_left, input_right) in train_pbar:
        # Move data to GPU
        input_left = input_left.to(device)
        input_right = input_right.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output_left, output_right, num_bits_left, num_bits_right = model(input_left, input_right)
        
        # Calculate loss (distortion only)
        mse_loss_left = criterion(output_left, input_left)
        mse_loss_right = criterion(output_right, input_right)
        distortion_loss = mse_loss_left + mse_loss_right
        
        # Total loss (rate-distortion)
        total_loss = distortion_loss + config['lambda'] * (num_bits_left + num_bits_right)
        
        # Backward pass
        total_loss.backward()
        optimizer.step()
        
        # Track loss
        train_loss_sum += total_loss.item()
        
        # Update progress bar
        if IN_NOTEBOOK:
            train_pbar.set_postfix({'loss': total_loss.item()})
    
    # Calculate average training loss
    avg_train_loss = train_loss_sum / len(train_loader)
    
    # === VALIDATION PHASE ===
    model.eval()
    val_loss_sum = 0
    psnr_sum = 0
    msssim_db_sum = 0
    distortion_sum = 0
    
    # Validation progress bar
    if IN_NOTEBOOK:
        val_pbar = tqdm(enumerate(val_loader), total=len(val_loader),
                       desc=f'Epoch {epoch+1}/{config["epochs"]} [Val]',
                       leave=False, position=1)
    else:
        val_pbar = enumerate(val_loader)
    
    with torch.no_grad():
        for batch_idx, (input_left, input_right) in val_pbar:
            # Move data to GPU
            input_left = input_left.to(device)
            input_right = input_right.to(device)
            
            # Forward pass
            output_left, output_right, num_bits_left, num_bits_right = model(input_left, input_right)
            
            # Calculate loss
            mse_loss_left = criterion(output_left, input_left)
            mse_loss_right = criterion(output_right, input_right)
            distortion_loss = mse_loss_left + mse_loss_right
            total_loss = distortion_loss + config['lambda'] * (num_bits_left + num_bits_right)
            
            # Track metrics
            val_loss_sum += total_loss.item()
            distortion_sum += distortion_loss.item()
            
            # Calculate PSNR
            psnr_left = 10 * torch.log10(1 / mse_loss_left)
            psnr_right = 10 * torch.log10(1 / mse_loss_right)
            avg_psnr_batch = (psnr_left + psnr_right) / 2
            psnr_sum += avg_psnr_batch.item()
            
            # Calculate MS-SSIM in dB
            msssim_left = ms_ssim(output_left, input_left, data_range=1.0, size_average=True)
            msssim_right = ms_ssim(output_right, input_right, data_range=1.0, size_average=True)
            msssim_db_batch = -10 * torch.log10(1 - (msssim_left + msssim_right) / 2)
            msssim_db_sum += msssim_db_batch.item()
            
            # Update progress bar
            if IN_NOTEBOOK:
                val_pbar.set_postfix({'loss': total_loss.item(), 'psnr': avg_psnr_batch.item()})
    
    # Calculate average validation metrics
    avg_val_loss = val_loss_sum / len(val_loader)
    avg_psnr = psnr_sum / len(val_loader)
    avg_msssim_db = msssim_db_sum / len(val_loader)
    avg_distortion = distortion_sum / len(val_loader)
    
    # Calculate epoch time
    epoch_time = time.time() - epoch_start_time
    
    # Track loss for scheduler
    val_loss_to_track = avg_val_loss
    
    # PRINT EPOCH SUMMARY
    if IN_NOTEBOOK:
        tqdm.write(f"\n[Epoch {epoch+1}/{config['epochs']}] "
                  f"Time: {epoch_time:.1f}s | "
                  f"Train Loss: {avg_train_loss:.4f} | "
                  f"Val Loss: {avg_val_loss:.4f} | "
                  f"PSNR: {avg_psnr:.4f} dB | "
                  f"MS-SSIM: {avg_msssim_db:.4f} dB")
        
        # Print quality indicators
        if avg_psnr > 25:
            tqdm.write("  >> Good quality achieved (PSNR > 25 dB)")
        if avg_psnr > 30:
            tqdm.write("  >> Excellent quality achieved (PSNR > 30 dB)")
    else:
        print(f"\n[Epoch {epoch+1}/{config['epochs']}] "
              f"Time: {epoch_time:.1f}s | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | "
              f"PSNR: {avg_psnr:.4f} dB | "
              f"MS-SSIM: {avg_msssim_db:.4f} dB")
        
        # Print quality indicators
        if avg_psnr > 25:
            print("  >> Good quality achieved (PSNR > 25 dB)")
        if avg_psnr > 30:
            print("  >> Excellent quality achieved (PSNR > 30 dB)")
    
    # SAVE MODEL CHECKPOINTS
    if config['save_weights']:
        # Get the actual model (unwrap DataParallel if needed)
        model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
        
        # Create checkpoint dictionary
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss_to_track,
            'psnr': avg_psnr,
            'train_loss': avg_train_loss,
            'config': config
        }
        
        # SAVE PICKLE FILE FOR EVERY EPOCH
        import pickle
        epoch_pkl_path = os.path.join(config['pkl_dir'], f'epoch_{epoch+1:04d}.pkl')
        epoch_data = {
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'val_loss': val_loss_to_track,
            'psnr': avg_psnr,
            'msssim_db': avg_msssim_db,
            'distortion': avg_distortion,
            'model_state': model_to_save.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict()
        }
        with open(epoch_pkl_path, 'wb') as f:
            pickle.dump(epoch_data, f)
        
        # Save periodic checkpoint every 10 epochs as .pth
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(config['pth_dir'], 
                                          f'epoch_{epoch+1:04d}_psnr_{avg_psnr:.2f}dB.pth')
            torch.save(checkpoint, checkpoint_path)
            if IN_NOTEBOOK:
                tqdm.write(f"  [SAVED] Checkpoint: epoch_{epoch+1:04d}_psnr_{avg_psnr:.2f}dB.pth")
            else:
                print(f"  [SAVED] Checkpoint: epoch_{epoch+1:04d}_psnr_{avg_psnr:.2f}dB.pth")
        
        # Save best model based on validation loss
        if min_val_loss is None or min_val_loss > val_loss_to_track:
            min_val_loss = val_loss_to_track
            best_model_path = os.path.join(config['pth_dir'], 'best_model_loss.pth')
            torch.save(checkpoint, best_model_path)
            if IN_NOTEBOOK:
                tqdm.write(f"  [BEST LOSS] best_model_loss.pth - Val Loss: {val_loss_to_track:.4f}")
            else:
                print(f"  [BEST LOSS] best_model_loss.pth - Val Loss: {val_loss_to_track:.4f}")
        
        # Save best model based on PSNR
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            best_model_path = os.path.join(config['pth_dir'], 
                                          f'best_model_psnr_{avg_psnr:.2f}dB.pth')
            torch.save(checkpoint, best_model_path)
            if IN_NOTEBOOK:
                tqdm.write(f"  [BEST PSNR] best_model_psnr_{avg_psnr:.2f}dB.pth - PSNR: {avg_psnr:.4f} dB")
            else:
                print(f"  [BEST PSNR] best_model_psnr_{avg_psnr:.2f}dB.pth - PSNR: {avg_psnr:.4f} dB")
    
    # LEARNING RATE SCHEDULING
    scheduler.step(val_loss_to_track)  # Reduce LR if validation plateaus
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr < config['lr']:
        if IN_NOTEBOOK:
            tqdm.write(f"  [LR REDUCED] Learning rate reduced to: {current_lr:.2e}")
        else:
            print(f"  [LR REDUCED] Learning rate reduced to: {current_lr:.2e}")

# Close progress bar
if IN_NOTEBOOK:
    epoch_pbar.close()

# SAVE FINAL TRAINING HISTORY AND MODEL
print("\n=== SAVING FINAL RESULTS ===")

# Save complete training history as pickle
import pickle
training_history = {
    'final_epoch': config['epochs'],
    'best_val_loss': min_val_loss,
    'best_psnr': best_psnr,
    'config': config
}
history_path = os.path.join(config['pkl_dir'], 'training_history.pkl')
with open(history_path, 'wb') as f:
    pickle.dump(training_history, f)
print(f"Training history saved to: {history_path}")

print("\n=== TRAINING COMPLETE ===")
print(f"Best Validation Loss: {min_val_loss:.4f}")
print(f"Best PSNR: {best_psnr:.4f} dB")

=== STARTING TRAINING ===
Total Epochs: 30000


KeyError: 'batch_size'

## 8. Testing Phase üß™

**Why Testing Matters:**
After training, we need to evaluate our model's performance on completely unseen data:

**Testing vs Training/Validation:**
- **Training**: Model sees this data and learns from it
- **Validation**: Model sees this during training but doesn't learn from it
- **Testing**: Model has NEVER seen this data - true performance measure

**Key Metrics for Image Compression:**
1. **PSNR (Peak Signal-to-Noise Ratio)**:
   - Measures signal quality in decibels (dB)
   - Higher = better (30+ dB is excellent)
   - Most common metric in compression research

2. **MS-SSIM (Multi-Scale Structural Similarity)**:
   - Measures perceptual similarity (0 to 1)
   - Closer to 1.0 = better perceptual quality
   - Better aligned with human visual system than PSNR

3. **BPP (Bits Per Pixel)** - if measured:
   - Compression efficiency 
   - Lower = more compressed files

**What the Results Tell Us:**
- Individual image performance (some compress better than others)
- Average performance across diverse scenes
- Visual quality through saved image comparisons

In [None]:
if config['test']:
    results_path = os.path.join(config['save_output_path'], 'results')
    if not os.path.exists(results_path):
        os.makedirs(results_path)
    
    names = ["Image Number", "PSNR", "MS-SSIM"]
    cols = dict()
    model.eval()
    mse_test = []
    
    print(f"Testing on {len(test_loader)} images...")
    
    with torch.no_grad():
        for i, data in enumerate(iter(test_loader)):
            img, cor_img, _, _ = data
            img = img.cuda().float() if config['cuda'] else img.float()
            cor_img = cor_img.cuda().float() if config['cuda'] else cor_img.float()
            
            # Forward pass
            out_l, out_r = model(img, cor_img)
            
            x_recon = out_l
            mse_dist = mse(img, x_recon)
            mse_test.append(mse_dist.item())
            msssim = 1 - ms_ssim(img.clone().cpu(), x_recon.clone().cpu(), data_range=1.0, 
                                 size_average=True, win_size=7)
            msssim_db = -10 * np.log10(msssim)
            
            vals = [str(i)] + ['{:.8f}'.format(x) for x in [
                10 * np.log10(1 / mse_dist.item()),
                msssim.item()
            ]]
            
            # Store results
            for (name, val) in zip(names, vals):
                if name not in cols:
                    cols[name] = []
                cols[name].append(val)
            
            # Print progress
            if (i + 1) % 10 == 0:
                print(f"Processed {i + 1}/{len(test_loader)} images")
            
            # Save images if requested
            if config['save_image']:
                save_image(x_recon[0], img[0], os.path.join(results_path, '{}_images'.format(1)), str(i))
    
    # Save results to CSV
    df = pd.DataFrame.from_dict(cols)
    csv_path = os.path.join(results_path, experiment_name + '.csv')
    df.to_csv(csv_path)
    
    # Calculate and display average metrics
    avg_psnr = np.mean([float(x) for x in cols['PSNR']])
    avg_msssim = np.mean([float(x) for x in cols['MS-SSIM']])
    
    print(f"\nTest Results:")
    print(f"Average PSNR: {avg_psnr:.4f} dB")
    print(f"Average MS-SSIM: {avg_msssim:.6f}")
    print(f"Results saved to: {csv_path}")
    
    # Display the results dataframe
    display(df.head(10))
else:
    print("Testing skipped (config['test'] = False)")

## 9. Visualize Sample Results

Display some sample reconstructed images from the test set.

In [None]:
# Visualize some sample results
if config['test']:
    model.eval()
    num_samples = 3  # Number of samples to visualize
    
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))
    
    with torch.no_grad():
        for idx in range(num_samples):
            # Get a sample from test dataset
            img, cor_img, _, _ = test_dataset[idx]
            img_batch = img.unsqueeze(0).cuda().float() if config['cuda'] else img.unsqueeze(0).float()
            cor_img_batch = cor_img.unsqueeze(0).cuda().float() if config['cuda'] else cor_img.unsqueeze(0).float()
            
            # Get reconstruction
            out_l, out_r = model(img_batch, cor_img_batch)
            
            # Convert to numpy for visualization
            img_np = img.cpu().numpy().transpose(1, 2, 0)
            cor_img_np = cor_img.cpu().numpy().transpose(1, 2, 0)
            recon_np = out_l[0].cpu().numpy().transpose(1, 2, 0)
            
            # Clip values to [0, 1]
            img_np = np.clip(img_np, 0, 1)
            cor_img_np = np.clip(cor_img_np, 0, 1)
            recon_np = np.clip(recon_np, 0, 1)
            
            # Plot
            axes[idx, 0].imshow(img_np)
            axes[idx, 0].set_title(f'Original Left Image {idx+1}')
            axes[idx, 0].axis('off')
            
            axes[idx, 1].imshow(cor_img_np)
            axes[idx, 1].set_title(f'Original Right Image {idx+1}')
            axes[idx, 1].axis('off')
            
            axes[idx, 2].imshow(recon_np)
            axes[idx, 2].set_title(f'Reconstructed Left Image {idx+1}')
            axes[idx, 2].axis('off')
    
    plt.tight_layout()
    plt.show()
    print("Sample visualizations complete!")
else:
    print("Testing was not performed. Set config['test'] = True to visualize results.")

## 10. Summary

Review the complete workflow and results.

## üéØ Learning Outcomes & Next Steps

**What You've Learned:**
‚úÖ **Deep Compression Theory**: Rate-distortion optimization in neural networks  
‚úÖ **Stereo Processing**: Using geometric relationships for better compression  
‚úÖ **Training Pipeline**: Complete deep learning workflow from scratch  
‚úÖ **Performance Metrics**: PSNR, MS-SSIM, and BPP interpretation  
‚úÖ **Practical Implementation**: Real-world compression system on KITTI dataset  

**Key Takeaways:**
- **Neural compression** learns patterns that traditional methods miss
- **Stereo correlation** significantly improves compression efficiency  
- **Rate-distortion trade-off** is controlled by lambda parameter
- **Training from scratch** requires patience but provides deep understanding

**Potential Improvements:**
1. **Architecture**: Try different encoder/decoder designs
2. **Loss Functions**: Experiment with perceptual losses
3. **Dataset**: Train on more diverse image types
4. **Optimization**: Advanced techniques like progressive training
5. **Evaluation**: Add BPP measurement for complete analysis

**Real-World Applications:**
- üöó **Autonomous Vehicles**: Efficient stereo data transmission
- üéÆ **VR/AR**: Real-time stereo compression for immersive experiences  
- üì± **Mobile Devices**: Bandwidth-efficient video calling
- üõ∞Ô∏è **Satellite Imaging**: Space-constrained data transmission

In [None]:
print("=" * 80)
print("EXPERIMENT SUMMARY")
print("=" * 80)
print(f"\nExperiment Name: {experiment_name}")
print(f"\nDataset: {config['dataset_name']}")
print(f"  - Training samples: {len(train_dataset)}")
print(f"  - Validation samples: {len(val_dataset)}")
print(f"  - Test samples: {len(test_dataset)}")
print(f"\nModel Configuration:")
print(f"  - Baseline model: {config['baseline_model']}")
print(f"  - Use side info: {config['use_side_info']}")
print(f"  - Image size: {config['resize']}")
print(f"\nTraining Configuration:")
print(f"  - Training enabled: {config['train']}")
print(f"  - Lambda: {config['lambda']}")
print(f"  - Learning rate: {config['lr']}")
print(f"  - Batch size: {config['train_batch_size']}")
print(f"  - Distortion loss: {config['distortion_loss']}")
print(f"\nTesting Configuration:")
print(f"  - Testing enabled: {config['test']}")
print(f"  - Save images: {config['save_image']}")
print(f"\nOutput:")
print(f"  - Save path: {config['save_output_path']}")
print(f"  - Save weights: {config['save_weights']}")
print("=" * 80)