In [None]:
# ==================== INSTALLATIONS & IMPORTS ====================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import PIL.Image as Image
import io
import requests
from torchvision import transforms
from torchvision.models import vgg19
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import time
import os
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import cv2
from PIL import Image as PILImage
from io import BytesIO
import warnings
import glob
# Silence all Python warnings for a cleaner notebook output.
# NOTE(review): this also hides deprecation warnings from torch/torchvision.
warnings.filterwarnings('ignore')

print("✅ All imports successful!")

In [None]:
# ==================== OPTIMIZED CONFIGURATION ====================

class Config:
    """Central hyper-parameter and path settings shared by every cell.

    All values are plain class attributes, so ``Config.levels`` and
    ``Config().levels`` read the same value.
    """

    # ---- model architecture ----
    levels = 3              # number of lifting decomposition levels
    hidden_channels = 96    # width of the CNNs inside each lifting unit
    num_mixtures = 5        # mixture components predicted by the entropy model

    # ---- optimisation ----
    batch_size = 4
    learning_rate = 1e-4
    num_epochs = 50

    # ---- input geometry ----
    image_size = 256        # images are resized to image_size x image_size

    # ---- rate / distortion trade-off ----
    rate_weight = 5e-4      # multiplier on the rate term of the training loss
    # Per-subband multipliers on the quantisation step, ordered LL, LH, HL, HH.
    subband_importance = [1.0, 1.3, 1.3, 1.8]

    # ---- filesystem locations (Kaggle layout) ----
    model_dir = "/kaggle/working/iwavev3_models/"
    sample_images_dir = "/kaggle/working/sample_images/"
    train_data_path = "/kaggle/input/kodim-shivam/*"   # glob pattern

config = Config()

# Make sure the output folders exist before any cell writes into them.
os.makedirs(config.model_dir, exist_ok=True)
os.makedirs(config.sample_images_dir, exist_ok=True)

print("✅ Configuration setup complete!")

In [None]:
# ==================== OPTIMIZED DATASET CLASS ====================

class ImageCompressionDataset(Dataset):
    """Map-style dataset yielding ``(image_tensor, path)`` pairs.

    Each file is opened with PIL, converted to RGB, resized to
    ``image_size x image_size`` and turned into a tensor by
    ``transforms.ToTensor``. A file that fails to load does not stop
    iteration: the error is printed and an all-zero
    ``(3, image_size, image_size)`` tensor is returned in its place.
    """

    def __init__(self, image_paths, image_size=256):
        self.image_paths = image_paths
        self.image_size = image_size
        pipeline = [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            rgb = PILImage.open(path).convert('RGB')
            return self.transform(rgb), path
        except Exception as e:
            # Best effort: keep batch shapes intact instead of aborting.
            print(f" Error loading {path}: {e}")
            blank = torch.zeros(3, self.image_size, self.image_size)
            return blank, path
def custom_collate_fn(batch):
    """Collate ``(image, path)`` samples into ``(stacked_images, paths)``.

    Images are stacked along a new leading batch dimension; the paths are
    returned as a plain tuple instead of going through default collation.
    """
    image_column, path_column = zip(*batch)
    return torch.stack(image_column, dim=0), path_column

print(" Dataset class defined!")

# ==================== OPTIMIZED MODEL COMPONENTS ====================

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU in between and an identity skip.

    Computes ``x + conv2(relu(conv1(x)))``. Padding 1 and equal in/out
    channels keep the tensor shape unchanged.
    """

    def __init__(self, channels=96):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        return branch + x

In [None]:
#TRANSFORMS
class AdditiveTransformUnit(nn.Module):
    """Predict/update operator for additive lifting: a three-layer 5x5 CNN.

    Maps ``(B, in_channels, H, W)`` to a tensor of the same shape, which the
    lifting transform adds to / subtracts from another polyphase component.
    """

    def __init__(self, in_channels=3, hidden_channels=96):
        super(AdditiveTransformUnit, self).__init__()
        layers = []
        widths = [in_channels, hidden_channels, hidden_channels]
        # conv -> ReLU twice, then a projection back to in_channels.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 5, padding=2), nn.ReLU(inplace=True)]
        layers.append(nn.Conv2d(hidden_channels, in_channels, 5, padding=2))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class AffineTransformUnit(nn.Module):
    """Predict/update operator for affine lifting.

    A three-layer 5x5 CNN emits ``2 * in_channels`` maps that are split into a
    ``shift`` and a ``scale``. The scale is squashed through a sigmoid into
    roughly [0.5, 2.0], so it stays positive and can be divided by in the
    inverse transform.
    """

    def __init__(self, in_channels=3, hidden_channels=96):
        super(AffineTransformUnit, self).__init__()
        layers = []
        widths = [in_channels, hidden_channels, hidden_channels]
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 5, padding=2), nn.ReLU(inplace=True)]
        # Two maps per input channel: shift and raw scale.
        layers.append(nn.Conv2d(hidden_channels, 2 * in_channels, 5, padding=2))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        raw_shift, raw_scale = torch.chunk(self.net(x), 2, dim=1)
        return raw_shift, 0.5 + 1.5 * torch.sigmoid(raw_scale)

In [None]:
class WaveletLikeTransform(nn.Module):
    """Learned multi-level lifting transform (iWave-style) and its inverse.

    Each level splits its input into four polyphase components (even/even,
    even/odd, odd/even, odd/odd samples) and applies learned predict
    (``P_units``) and update (``U_units``) steps to produce ``ll, lh, hl, hh``.
    The single P/U pair of a level is reused for every lifting step of that
    level, and ``inverse`` undoes the steps in reverse order, so
    ``inverse(forward(x))`` recovers ``x`` up to floating-point error whenever
    ``x``'s spatial size is already a multiple of ``2 ** levels``.

    Args:
        levels: number of decomposition levels.
        transform_type: ``'additive'`` (shift-only lifting) or ``'affine'``
            (shift + positive-scale lifting).
        hidden_channels: width of the lifting CNNs. ``None`` (default) reads
            ``config.hidden_channels``, exactly as before.

    Raises:
        ValueError: if ``transform_type`` is not a supported value.
    """

    def __init__(self, levels=3, transform_type='affine', hidden_channels=None):
        super(WaveletLikeTransform, self).__init__()
        if transform_type not in ('additive', 'affine'):
            raise ValueError(
                f"transform_type must be 'additive' or 'affine', got {transform_type!r}"
            )
        if hidden_channels is None:
            hidden_channels = config.hidden_channels
        self.levels = levels
        self.transform_type = transform_type

        unit_cls = AdditiveTransformUnit if transform_type == 'additive' else AffineTransformUnit
        # All P units are created before all U units (as before), so seeded
        # initialisation and state_dict keys (P_units.*, U_units.*) are unchanged.
        self.P_units = nn.ModuleList([unit_cls(3, hidden_channels) for _ in range(levels)])
        self.U_units = nn.ModuleList([unit_cls(3, hidden_channels) for _ in range(levels)])

    def ensure_divisible(self, x, divisor=32):
        """Bilinearly resize ``x`` so H and W are multiples of ``divisor`` (rounded up).

        NOTE(review): this resamples instead of padding, and ``inverse`` does
        not undo it, so a non-divisible input is reconstructed at the enlarged
        size.
        """
        h, w = x.shape[2], x.shape[3]
        new_h = ((h + divisor - 1) // divisor) * divisor
        new_w = ((w + divisor - 1) // divisor) * divisor

        if new_h != h or new_w != w:
            x = F.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)
        return x

    def forward_single_level(self, x, level):
        """Apply one lifting level to ``x``.

        Odd spatial sizes are reflect-padded by one row/column first;
        ``forward`` never triggers this because it resizes to a multiple of
        ``2 ** levels``.

        Returns:
            ``(ll, lh, hl, hh)``, each at half the (padded) resolution of ``x``.
        """
        if x.shape[2] % 2 != 0:
            x = F.pad(x, (0, 0, 0, 1), mode='reflect')
        if x.shape[3] % 2 != 0:
            x = F.pad(x, (0, 1, 0, 0), mode='reflect')

        x_ll = x[:, :, 0::2, 0::2]
        x_lh = x[:, :, 0::2, 1::2]
        x_hl = x[:, :, 1::2, 0::2]
        x_hh = x[:, :, 1::2, 1::2]

        predict = self.P_units[level]
        update = self.U_units[level]

        if self.transform_type == 'additive':
            hh = x_hh - predict(x_ll)
            l_temp = x_ll + update(hh)

            # One prediction from l_temp serves both remaining detail bands.
            p_l = predict(l_temp)
            hl = x_hl - p_l
            lh = x_lh - p_l

            ll = l_temp + update(hl) + update(lh)
        else:
            shift_p, scale_p = predict(x_ll)
            hh = scale_p * (x_hh - shift_p)

            shift_u, scale_u = update(hh)
            l_temp = scale_u * (x_ll + shift_u)

            shift_p2, scale_p2 = predict(l_temp)
            hl = scale_p2 * (x_hl - shift_p2)
            lh = scale_p2 * (x_lh - shift_p2)

            # Note: the final affine update uses only hl (not lh).
            shift_u2, scale_u2 = update(hl)
            ll = scale_u2 * (l_temp + shift_u2)

        return ll, lh, hl, hh

    def forward(self, x):
        """Decompose ``x`` into ``3 * levels + 1`` subbands.

        Returns:
            ``[lh_0, hl_0, hh_0, lh_1, hl_1, hh_1, ..., ll_final]``: the detail
            bands of each level from finest to coarsest, then the coarsest LL.
        """
        x = self.ensure_divisible(x, 2 ** self.levels)

        subbands = []
        current = x
        for level in range(self.levels):
            ll, lh, hl, hh = self.forward_single_level(current, level)
            subbands.extend([lh, hl, hh])
            current = ll

        subbands.append(current)
        return subbands

    def inverse_single_level(self, ll, lh, hl, hh, level):
        """Undo ``forward_single_level`` and re-interleave the four components."""
        predict = self.P_units[level]
        update = self.U_units[level]

        if self.transform_type == 'additive':
            l_temp = ll - update(hl) - update(lh)
            p_l = predict(l_temp)
            x_lh = lh + p_l
            x_hl = hl + p_l

            x_ll = l_temp - update(hh)
            x_hh = hh + predict(x_ll)
        else:
            shift_u2, scale_u2 = update(hl)
            l_temp = (ll / scale_u2) - shift_u2

            shift_p2, scale_p2 = predict(l_temp)
            x_lh = (lh / scale_p2) + shift_p2
            x_hl = (hl / scale_p2) + shift_p2

            shift_u, scale_u = update(hh)
            x_ll = (l_temp / scale_u) - shift_u

            shift_p, scale_p = predict(x_ll)
            x_hh = (hh / scale_p) + shift_p

        batch, channels, height, width = ll.shape
        # new_zeros keeps ll's dtype and device (torch.zeros defaulted to float32).
        reconstructed = ll.new_zeros((batch, channels, height * 2, width * 2))

        reconstructed[:, :, 0::2, 0::2] = x_ll
        reconstructed[:, :, 0::2, 1::2] = x_lh
        reconstructed[:, :, 1::2, 0::2] = x_hl
        reconstructed[:, :, 1::2, 1::2] = x_hh

        return reconstructed

    def inverse(self, subbands):
        """Rebuild the image from the list produced by ``forward``."""
        current = subbands[-1]
        for level in reversed(range(self.levels)):
            lh, hl, hh = subbands[3 * level: 3 * level + 3]
            current = self.inverse_single_level(current, lh, hl, hh, level)
        return current

In [None]:
#IMPROVED ENTROPY MODEL CLASS
class ImprovedEntropyModel(nn.Module):
    """Predict mixture parameters for each subband from the subband itself.

    One shared 3-layer CNN maps a 3-channel subband to ``3 * num_mixtures``
    channels, split into mixture weights (softmax over the mixture axis),
    means (``tanh``, so within [-1, 1]) and scales (``exp`` of a logit clamped
    to [-3, 3]).
    """

    def __init__(self, num_mixtures=5):
        super(ImprovedEntropyModel, self).__init__()
        self.num_mixtures = num_mixtures
        width = 128
        self.context_net = nn.Sequential(
            nn.Conv2d(3, width, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, width, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, 3 * num_mixtures, 3, padding=1),
        )

    def _mixture_params(self, band):
        """Return ``(weights, means, scales)``, each shaped ``(B, num_mixtures, H, W)``."""
        logits, mean_logits, scale_logits = torch.split(
            self.context_net(band), self.num_mixtures, dim=1)
        weights = torch.softmax(logits, dim=1)
        means = torch.tanh(mean_logits)                  # bounded to [-1, 1]
        scales = torch.exp(scale_logits.clamp(-3, 3))    # within [e^-3, e^3]
        return weights, means, scales

    def forward(self, subbands):
        return [self._mixture_params(band) for band in subbands]

In [None]:
#DequantizationModule
class DequantizationModule(nn.Module):
    """Refinement network applied after the inverse transform.

    conv(3 -> channels) + ReLU, two ``ResidualBlock``s, conv(channels -> 3)
    and a final ``tanh``. There is no skip connection, so the output fully
    replaces the input and lies in [-1, 1].
    NOTE(review): the images fed to the codec come from ``ToTensor`` (values
    in [0, 1]); confirm the [-1, 1] output range is intended.
    """

    def __init__(self, channels=96):
        super(DequantizationModule, self).__init__()
        head = [nn.Conv2d(3, channels, 3, padding=1), nn.ReLU()]
        body = [ResidualBlock(channels) for _ in range(2)]
        tail = [nn.Conv2d(channels, 3, 3, padding=1), nn.Tanh()]
        # Same layer order/indices as before, so state_dict keys are unchanged.
        self.net = nn.Sequential(*head, *body, *tail)

    def forward(self, x):
        return self.net(x)

In [None]:
#PERCEPTUAL POST PROCESSING
class PerceptualPostProcessing(nn.Module):
    """Residual refinement stage: returns ``x + tanh(net(x))``.

    ``net`` is conv(3 -> channels) + ReLU, one ``ResidualBlock`` and
    conv(channels -> 3) followed by ``tanh``, so each output pixel differs
    from the input by at most 1 per channel.
    """

    def __init__(self, channels=96):
        super(PerceptualPostProcessing, self).__init__()
        stages = [
            nn.Conv2d(3, channels, 3, padding=1),
            nn.ReLU(),
            ResidualBlock(channels),
            nn.Conv2d(channels, 3, 3, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        correction = self.net(x)
        return correction + x

In [None]:
class iWaveV3_Base(nn.Module):
    """Shared transform / quantisation / rate machinery of the iWaveV3 models.

    Holds the learned lifting transform, the mixture entropy model and a
    learnable base quantisation step ``qstep``. Subclasses implement
    ``forward``.
    """

    # RGB -> YCbCr rotation (BT.601 coefficients) used by channel-aware
    # quantisation. Applied WITHOUT the +0.5 chroma offset and WITHOUT
    # clamping: wavelet subbands are signed coefficients, not pixels.
    _RGB_TO_YCBCR = (
        (0.299, 0.587, 0.114),
        (-0.168736, -0.331264, 0.5),
        (0.5, -0.418688, -0.081312),
    )
    # Relative step of the (Y, Cb, Cr) channels: chroma is quantised twice as
    # coarsely as luma (same ratio as the former fixed 0.02 / 0.04 / 0.04).
    _YCBCR_STEP_RATIO = (1.0, 2.0, 2.0)
    # log(1e-10): probability floor used by the rate estimate.
    _LOG_PROB_FLOOR = -23.025850929940457

    def __init__(self, levels=3, transform_type='affine'):
        super(iWaveV3_Base, self).__init__()
        self.levels = levels
        self.transform_type = transform_type

        self.transform = WaveletLikeTransform(levels, transform_type)
        self.entropy_model = ImprovedEntropyModel(config.num_mixtures)
        # Base step; scaled per subband (importance) and per colour channel.
        self.qstep = nn.Parameter(torch.tensor(0.03))

    def _subband_importance(self, subband_idx):
        """Step multiplier for position ``subband_idx`` in ``self.transform``'s output.

        The transform emits ``[lh, hl, hh]`` per level followed by the final
        ``ll``; ``config.subband_importance`` is ordered LL, LH, HL, HH.
        """
        importance = config.subband_importance
        if subband_idx >= 3 * self.levels:
            return importance[0]                 # final LL band
        return importance[1 + subband_idx % 3]   # LH / HL / HH of some level

    def adaptive_quantize(self, x, subband_idx):
        """Scalar quantisation with step ``qstep * importance(subband_idx)``.

        Training: add uniform noise in [-step/2, step/2) as a differentiable
        proxy for rounding. Eval: round to the nearest multiple of the step.
        """
        effective_qstep = self.qstep * self._subband_importance(subband_idx)

        if self.training:
            noise = (torch.rand_like(x) - 0.5) * effective_qstep
            return x + noise
        return torch.round(x / effective_qstep) * effective_qstep

    def _color_matrices(self, x):
        """RGB->YCbCr matrix and its exact inverse on ``x``'s device/dtype."""
        forward = torch.tensor(self._RGB_TO_YCBCR, dtype=torch.float64)
        inverse = torch.linalg.inv(forward)
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        return forward.to(x.device, dtype), inverse.to(x.device, dtype)

    def channel_aware_quantize(self, x, subband_idx):
        """Quantise 3-channel subbands in a YCbCr-rotated basis.

        Luma uses step ``qstep * importance`` and chroma twice that, so the
        ``qstep`` parameter (and any schedule that sets it) controls the rate.
        The rotation is inverted exactly and nothing is clamped, so signed
        coefficients survive. Inputs without 3 channels fall back to
        ``adaptive_quantize``.
        """
        if x.shape[1] != 3:
            return self.adaptive_quantize(x, subband_idx)

        to_ycbcr, to_rgb = self._color_matrices(x)
        ratio = torch.tensor(self._YCBCR_STEP_RATIO, dtype=to_ycbcr.dtype, device=x.device)
        step = (self.qstep * self._subband_importance(subband_idx) * ratio).view(1, 3, 1, 1)

        ycbcr = torch.einsum('ij,bjhw->bihw', to_ycbcr, x)
        if self.training:
            ycbcr = ycbcr + (torch.rand_like(ycbcr) - 0.5) * step
        else:
            ycbcr = torch.round(ycbcr / step) * step
        return torch.einsum('ij,bjhw->bihw', to_rgb, ycbcr)

    def calculate_rate_improved(self, quantized_subbands, entropy_params):
        """Mean negative log-likelihood of the subbands under a Laplace mixture.

        Each of the C coefficient channels of a subband is scored against all
        K mixture components (``weights``/``means``/``scales`` are
        ``(B, K, H, W)``); components are combined with ``logsumexp`` for
        numerical stability and the probability is floored at 1e-10 as before.
        Returns the per-coefficient mean averaged over subbands (nats), or 0
        for an empty list.

        NOTE(review): this scores the continuous Laplace density rather than
        the probability mass of a quantisation bin, so it is a relative rate
        proxy (it can be negative), not a bit count.
        """
        if not quantized_subbands:
            return 0

        total_rate = 0
        for subband, (weights, means, scales) in zip(quantized_subbands, entropy_params):
            scale = torch.clamp(scales, min=1e-6).unsqueeze(2)   # (B, K, 1, H, W)
            mean = means.unsqueeze(2)                            # (B, K, 1, H, W)
            coeff = subband.unsqueeze(1)                         # (B, 1, C, H, W)
            log_density = -torch.abs(coeff - mean) / scale - torch.log(2 * scale)
            log_weight = torch.log(torch.clamp(weights, min=1e-12)).unsqueeze(2)
            log_mix = torch.logsumexp(log_weight + log_density, dim=1)   # (B, C, H, W)
            rate = -torch.clamp(log_mix, min=self._LOG_PROB_FLOOR)
            total_rate = total_rate + rate.mean()

        return total_rate / len(quantized_subbands)

In [None]:
#Color conversion

def rgb_to_ycbcr(x):
    """Convert an RGB batch to YCbCr (BT.601 coefficients).

    Args:
        x: tensor indexed as ``(batch, 3, H, W)`` with RGB on dim 1.

    Returns:
        Tensor of the same shape on ``x``'s device, with Y in channel 0 and
        Cb, Cr in channels 1-2 offset by +0.5 (for RGB in [0, 1] this puts
        chroma in [0, 1]). Floating inputs keep their dtype; the matrix used
        to be float32-only, which made einsum fail for float64 inputs.
    """
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    matrix = torch.tensor([[0.299, 0.587, 0.114],
                           [-0.168736, -0.331264, 0.5],
                           [0.5, -0.418688, -0.081312]],
                          dtype=dtype, device=x.device)
    ycbcr = torch.einsum('ij,bjhw->bihw', matrix, x)
    ycbcr[:, 1:] += 0.5  # Center Cb, Cr
    return ycbcr

def ycbcr_to_rgb(x):
    """Convert a YCbCr batch (chroma offset by +0.5) back to RGB in [0, 1].

    Args:
        x: tensor indexed as ``(batch, 3, H, W)`` with Y, Cb, Cr on dim 1.
            It is not modified.

    Returns:
        RGB tensor of the same shape on ``x``'s device, clamped to [0, 1].
        Floating inputs keep their dtype (the matrix used to be float32-only).
    """
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    matrix = torch.tensor([[1.0, 0.0, 1.402],
                           [1.0, -0.344136, -0.714136],
                           [1.0, 1.772, 0.0]],
                          dtype=dtype, device=x.device)
    centered = x.clone()
    centered[:, 1:] -= 0.5  # Uncenter Cb, Cr
    rgb = torch.einsum('ij,bjhw->bihw', matrix, centered)
    return torch.clamp(rgb, 0, 1)

In [None]:
#IWAVEV3 OBJ
class iWaveV3_Obj(iWaveV3_Base):
    """Objective-quality iWaveV3 codec.

    Pipeline: lifting transform -> channel-aware quantisation -> inverse
    transform -> ``DequantizationModule``; entropy parameters and the rate are
    only computed when ``training`` is True.
    """

    def __init__(self, levels=3, transform_type='affine'):
        super(iWaveV3_Obj, self).__init__(levels, transform_type)
        self.dequantization = DequantizationModule()

    def forward(self, x, training=True):
        """Encode/decode a 3-channel image batch ``x``.

        Args:
            x: ``(B, 3, H, W)`` batch.
            training: compute the rate term when True; return 0 otherwise.
                Noise-vs-rounding quantisation is decided by ``self.training``
                (``model.train()`` / ``model.eval()``), not by this flag.

        Returns:
            ``(reconstructed, quantized_subbands, rate)``.
        """
        subbands = self.transform(x)
        # Both former branches were identical; quantisation mode follows self.training.
        quantized_subbands = [self.channel_aware_quantize(sb, i) for i, sb in enumerate(subbands)]

        reconstructed = self.transform.inverse(quantized_subbands)
        reconstructed = self.dequantization(reconstructed)

        if training:
            entropy_params = self.entropy_model(quantized_subbands)
            rate = self.calculate_rate_improved(quantized_subbands, entropy_params)
        else:
            rate = 0  # entropy parameters are not needed, skip computing them

        return reconstructed, quantized_subbands, rate

In [None]:
#IwaveV3 perp
class iWaveV3_Perp(iWaveV3_Base):
    """Perceptual iWaveV3 codec: the Obj pipeline plus an optional
    ``PerceptualPostProcessing`` stage after dequantisation."""

    def __init__(self, levels=3, transform_type='affine'):
        super(iWaveV3_Perp, self).__init__(levels, transform_type)
        self.dequantization = DequantizationModule()
        self.perceptual_pp = PerceptualPostProcessing()

    def forward(self, x, training=True, use_perceptual=True):
        """Encode/decode a 3-channel image batch ``x``.

        Args:
            x: ``(B, 3, H, W)`` batch.
            training: compute the rate term when True; return 0 otherwise.
                Noise-vs-rounding quantisation is decided by ``self.training``.
            use_perceptual: apply ``perceptual_pp`` after dequantisation.

        Returns:
            ``(reconstructed, quantized_subbands, rate)``.
        """
        subbands = self.transform(x)
        # Both former branches were identical; quantisation mode follows self.training.
        quantized_subbands = [self.channel_aware_quantize(sb, i) for i, sb in enumerate(subbands)]

        reconstructed = self.transform.inverse(quantized_subbands)
        reconstructed = self.dequantization(reconstructed)
        if use_perceptual:
            reconstructed = self.perceptual_pp(reconstructed)

        if training:
            entropy_params = self.entropy_model(quantized_subbands)
            rate = self.calculate_rate_improved(quantized_subbands, entropy_params)
        else:
            rate = 0  # entropy parameters are not needed, skip computing them

        return reconstructed, quantized_subbands, rate
print("Optimized model components defined!")


In [None]:
# ==================== OPTIMIZED METRICS & UTILITIES ====================

class VGGPerceptualLoss(nn.Module):
    """MSE between early VGG-19 feature maps of two image batches.

    Only the first conv layer + ReLU of ``vgg19().features`` is used and all
    its parameters are frozen. Inputs are expected in [0, 1] and are
    normalised with the ImageNet mean/std that torchvision's pretrained VGG
    weights were trained with (previously they were mapped to [-1, 1]).
    """

    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        try:
            vgg = vgg19(weights="IMAGENET1K_V1").features
        except TypeError:
            # torchvision < 0.13 only understands the legacy flag.
            vgg = vgg19(pretrained=True).features
        self.slice1 = nn.Sequential(*list(vgg)[:2])

        for param in self.parameters():
            param.requires_grad = False

        # Non-persistent buffers: follow .to(device) but stay out of state_dict.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
                             persistent=False)
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
                             persistent=False)

    def forward(self, x, target):
        x_feat = self.slice1((x - self.mean) / self.std)
        target_feat = self.slice1((target - self.mean) / self.std)
        return F.mse_loss(x_feat, target_feat)

def calculate_metrics(original, reconstructed):
    """Compute PSNR and SSIM between two images.

    Args:
        original, reconstructed: image tensors shaped ``(1, 3, H, W)`` or
            ``(3, H, W)``; values are clipped to [0, 1] before scoring.

    Returns:
        ``(psnr, ssim)`` with ``data_range=1.0``. If scikit-image rejects the
        SSIM computation, ``ssim`` is ``nan`` and the error is printed
        (previously a made-up 0.5 was reported).

    Raises:
        ValueError: if a tensor contains more than one image.
    """
    def _to_hwc(tensor):
        # Drop the batch axis explicitly (squeeze() also removed H/W of size 1).
        tensor = tensor.detach()
        if tensor.dim() == 4:
            if tensor.shape[0] != 1:
                raise ValueError(f"expected a single image, got a batch of {tensor.shape[0]}")
            tensor = tensor[0]
        return np.clip(tensor.permute(1, 2, 0).cpu().numpy(), 0, 1)

    original_np = _to_hwc(original)
    reconstructed_np = _to_hwc(reconstructed)

    psnr = peak_signal_noise_ratio(original_np, reconstructed_np, data_range=1.0)

    # SSIM needs an odd window no larger than the smallest spatial side.
    min_dim = min(original_np.shape[0], original_np.shape[1])
    win_size = min(7, min_dim)
    if win_size % 2 == 0:
        win_size -= 1

    try:
        ssim = structural_similarity(original_np, reconstructed_np,
                                     win_size=win_size, channel_axis=2, data_range=1.0)
    except Exception as e:
        print(f" SSIM computation failed: {e}")
        ssim = float('nan')

    return psnr, ssim

def calculate_bpp_improved(quantized_subbands, image_size):
    """Estimate bits per pixel from the empirical entropy of each subband.

    Every distinct value in a subband is one symbol; the subband costs
    ``H(symbols) * numel`` bits (zeroth-order entropy from exact symbol
    counts; the old ``histc`` binning could merge distinct symbols). The total
    is divided by ``image_size[0] * image_size[1]`` times the batch size, so
    the result is bpp per image.

    Only meaningful for rounded (eval-mode) subbands: with training noise
    nearly every value is unique.

    Returns:
        Bits per pixel as a float (0.0 for an empty subband list).
    """
    if not quantized_subbands:
        return 0.0

    first = quantized_subbands[0]
    batch = first.shape[0] if first.dim() == 4 else 1

    total_bits = 0.0
    for subband in quantized_subbands:
        values = subband.detach().cpu().flatten()
        if values.numel() == 0:
            continue
        _, counts = torch.unique(values, return_counts=True)
        prob = counts.double() / values.numel()
        entropy = -(prob * torch.log2(prob)).sum()   # counts >= 1, so no log(0)
        total_bits += float(entropy) * values.numel()

    total_pixels = image_size[0] * image_size[1] * batch
    return total_bits / total_pixels

def calculate_bpp_simple(quantized_subbands, image_size):
    """Crude fallback bpp estimate that charges a flat 2 bits per coefficient."""
    bits_per_coefficient = 2
    coefficient_count = 0
    for band in quantized_subbands:
        coefficient_count += band.numel()
    pixel_count = image_size[0] * image_size[1]
    return (coefficient_count * bits_per_coefficient) / pixel_count

def load_and_preprocess_image(image_path, target_size=None):
    """
    Enhanced image loading with better preprocessing

    Args:
        image_path: path to a PIL-readable image.
        target_size: size passed to ``transforms.Resize`` (``(height, width)``
            or an int); defaults to ``(config.image_size, config.image_size)``.

    Returns:
        ``(image_tensor, original_size, processed_size)``: a ``(1, 3, H, W)``
        tensor, PIL's ``(width, height)`` of the file, and ``(height, width)``
        of the returned tensor. On any loading error the error is printed and
        a zero tensor of ``target_size`` is returned (best effort), with both
        sizes describing that dummy.
    """
    if target_size is None:
        target_size = (config.image_size, config.image_size)

    try:
        # Context manager closes the underlying file once the pixels are read.
        with PILImage.open(image_path) as opened:
            image = opened.convert('RGB')
        original_size = image.size  # (width, height)

        # Convert to tensor and resize
        transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor()
        ])

        image_tensor = transform(image)
        # Tensor layout is (C, H, W); report (height, width) as documented.
        processed_size = (image_tensor.shape[1], image_tensor.shape[2])

        return image_tensor.unsqueeze(0), original_size, processed_size

    except Exception as e:
        print(f" Error loading image {image_path}: {e}")
        # Return dummy tensor
        if isinstance(target_size, int):
            height, width = target_size, target_size
        else:
            height, width = target_size
        dummy_tensor = torch.zeros(1, 3, height, width)
        return dummy_tensor, (width, height), (height, width)

print(" Optimized metrics and utilities defined!")



In [None]:
# ==================== CLEAN OUTPUT VISUALIZATION ====================

def display_clean_results(original_tensor, iwave_obj, iwave_perp, device, image_path):
    """
    Display clean results: Original, Compressed (both methods), Reconstructed

    Runs both models with ``training=False`` at a base quantisation step of
    0.025, draws a 2x3 comparison grid with a PSNR/SSIM(/BPP) box under each
    model panel, saves the figure and the individual images to
    ``config.model_dir``, and prints a metrics table.

    Side effect: leaves both models' ``qstep`` set to 0.025.
    NOTE(review): the "reconstructed" images come from feeding the decoded
    output through the model a second time (a second compression pass), not
    from decoding a stored bitstream. ``image_path`` is currently unused.

    Returns:
        dict with ``'compressed_metrics'`` and ``'reconstructed_metrics'``.
    """
    print("\n" + "="*60)
    print("  CLEAN COMPRESSION RESULTS")
    print("="*60)

    # Set quantization step for testing (in place, so the parameter keeps its device).
    qstep = 0.025
    with torch.no_grad():
        iwave_obj.qstep.fill_(qstep)
        iwave_perp.qstep.fill_(qstep)

    image_hw = (original_tensor.shape[2], original_tensor.shape[3])

    with torch.no_grad():
        # Compress with both methods
        compressed_obj, quantized_obj, _ = iwave_obj(original_tensor, training=False)
        compressed_perp, quantized_perp, _ = iwave_perp(original_tensor, training=False, use_perceptual=True)

        # Second pass over the decoded output (see NOTE in the docstring)
        reconstructed_obj, _, _ = iwave_obj(compressed_obj, training=False)
        reconstructed_perp, _, _ = iwave_perp(compressed_perp, training=False, use_perceptual=True)

        psnr_obj, ssim_obj = calculate_metrics(original_tensor, compressed_obj)
        psnr_perp, ssim_perp = calculate_metrics(original_tensor, compressed_perp)
        psnr_recon_obj, ssim_recon_obj = calculate_metrics(original_tensor, reconstructed_obj)
        psnr_recon_perp, ssim_recon_perp = calculate_metrics(original_tensor, reconstructed_perp)

        try:
            bpp_obj = calculate_bpp_improved(quantized_obj, image_hw)
            bpp_perp = calculate_bpp_improved(quantized_perp, image_hw)
        except Exception:
            # Fall back to the crude estimate (a bare except also caught Ctrl-C).
            bpp_obj = calculate_bpp_simple(quantized_obj, image_hw)
            bpp_perp = calculate_bpp_simple(quantized_perp, image_hw)

    def _to_display(tensor):
        return np.clip(tensor.detach().squeeze().permute(1, 2, 0).cpu().numpy(), 0, 1)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    panels = [
        (axes[0, 0], original_tensor, '1. Original Image', 16),
        (axes[0, 1], compressed_obj, '2. iWaveV3-Obj Compressed', 16),
        (axes[0, 2], compressed_perp, '3. iWaveV3-Perp Compressed', 16),
        (axes[1, 0], original_tensor, 'Reference (Original)', 14),
        (axes[1, 1], reconstructed_obj, '4. iWaveV3-Obj Reconstructed', 16),
        (axes[1, 2], reconstructed_perp, '5. iWaveV3-Perp Reconstructed', 16),
    ]
    for ax, tensor, title, fontsize in panels:
        ax.imshow(_to_display(tensor))
        ax.set_title(title, fontsize=fontsize, fontweight='bold', pad=20)
        ax.axis('off')

    # Each metrics box sits directly under the panel it describes; the old
    # figure-level x positions put them under the neighbouring column.
    annotations = [
        (axes[0, 1], f'PSNR: {psnr_obj:.2f} dB\nSSIM: {ssim_obj:.4f}\nBPP: {bpp_obj:.3f}', "lightblue"),
        (axes[0, 2], f'PSNR: {psnr_perp:.2f} dB\nSSIM: {ssim_perp:.4f}\nBPP: {bpp_perp:.3f}', "lightgreen"),
        (axes[1, 1], f'PSNR: {psnr_recon_obj:.2f} dB\nSSIM: {ssim_recon_obj:.4f}', "lightcoral"),
        (axes[1, 2], f'PSNR: {psnr_recon_perp:.2f} dB\nSSIM: {ssim_recon_perp:.4f}', "lightyellow"),
    ]
    for ax, text, colour in annotations:
        ax.text(0.5, -0.04, text, transform=ax.transAxes, fontsize=12, ha='center', va='top',
                bbox=dict(boxstyle="round,pad=0.3", facecolor=colour))

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.1, top=0.9, hspace=0.35)
    plt.savefig(f"{config.model_dir}clean_compression_results.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    # Print detailed metrics table
    print("\n" + "="*80)
    print(" DETAILED COMPRESSION METRICS")
    print("="*80)
    print(f"{'Stage':<25} {'Model':<15} {'PSNR(dB)':<10} {'SSIM':<8} {'BPP':<8}")
    print("-"*80)
    print(f"{'Compressed':<25} {'iWaveV3-Obj':<15} {psnr_obj:<10.2f} {ssim_obj:<8.4f} {bpp_obj:<8.3f}")
    print(f"{'Compressed':<25} {'iWaveV3-Perp':<15} {psnr_perp:<10.2f} {ssim_perp:<8.4f} {bpp_perp:<8.3f}")
    print(f"{'Reconstructed':<25} {'iWaveV3-Obj':<15} {psnr_recon_obj:<10.2f} {ssim_recon_obj:<8.4f} {'-':<8}")
    print(f"{'Reconstructed':<25} {'iWaveV3-Perp':<15} {psnr_recon_perp:<10.2f} {ssim_recon_perp:<8.4f} {'-':<8}")
    print("-"*80)

    # Save individual images
    save_individual_images(original_tensor, compressed_obj, compressed_perp,
                           reconstructed_obj, reconstructed_perp)

    return {
        'compressed_metrics': {
            'obj_psnr': psnr_obj, 'obj_ssim': ssim_obj, 'obj_bpp': bpp_obj,
            'perp_psnr': psnr_perp, 'perp_ssim': ssim_perp, 'perp_bpp': bpp_perp
        },
        'reconstructed_metrics': {
            'obj_psnr': psnr_recon_obj, 'obj_ssim': ssim_recon_obj,
            'perp_psnr': psnr_recon_perp, 'perp_ssim': ssim_recon_perp
        }
    }


    
        

In [None]:
#Saving individual images
def save_individual_images(original, compressed_obj, compressed_perp, recon_obj, recon_perp):
    """Save all individual images as PNGs in ``config.model_dir``.

    Each tensor is clamped to [0, 1] before conversion. ``ToPILImage``
    multiplies float tensors by 255 and casts to uint8, so the models'
    out-of-range outputs would otherwise overflow that cast and appear as
    speckle. The Obj output is a tanh in [-1, 1], and the Perp output adds a
    residual on top of it.
    """
    to_pil = transforms.ToPILImage()
    named_images = [
        (original, "1_original.png"),
        (compressed_obj, "2_compressed_obj.png"),
        (compressed_perp, "3_compressed_perp.png"),
        (recon_obj, "4_reconstructed_obj.png"),
        (recon_perp, "5_reconstructed_perp.png"),
    ]

    saved_paths = []
    for tensor, filename in named_images:
        path = f"{config.model_dir}{filename}"
        image = tensor.detach().squeeze().clamp(0, 1).cpu()
        to_pil(image).save(path)
        saved_paths.append(path)

    print(f"\n Individual images saved:")
    for index, path in enumerate(saved_paths, start=1):
        print(f"   {index}. {path}")

In [None]:
# ==================== OPTIMIZED TRAINING ====================

def initialize_models(device):
    """Build both iWaveV3 variants on ``device`` with damped Xavier init.

    Every ``Conv2d`` gets Xavier-uniform weights (gain 0.5) and a zero bias;
    any ``BatchNorm2d`` is reset to weight 1 / bias 0. Both models are built
    before either is re-initialised, which keeps the random-number order of
    the original implementation.

    Returns:
        ``(iwave_obj, iwave_perp)``.
    """
    def init_weights(module):
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            return
        if isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)

    iwave_obj = iWaveV3_Obj(levels=config.levels, transform_type='affine').to(device)
    iwave_perp = iWaveV3_Perp(levels=config.levels, transform_type='affine').to(device)

    for model in (iwave_obj, iwave_perp):
        model.apply(init_weights)

    return iwave_obj, iwave_perp
def progressive_training(model, dataloader, device, model_name):
    """Progressive training for better compression.

    Trains ``model`` in place with Adam. The base step ``model.qstep`` is
    reset at the start of every epoch to 0.02 (epochs 0-14), 0.03 (15-34) and
    0.04 (35+). If ``model_name`` contains ``"Perp"``, a VGG perceptual term
    (weight 0.01) is added and the model is called with
    ``use_perceptual=True``. Batches whose loss is NaN/inf are skipped.

    Returns:
        The same ``model`` object, switched to eval mode.
    """
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-5)
    mse_loss = nn.MSELoss()
    use_perceptual = "Perp" in model_name
    perceptual_loss = VGGPerceptualLoss().to(device) if use_perceptual else None

    model.train()

    for epoch in range(config.num_epochs):
        # Progressive quantization - start easy, get harder (more compression)
        if epoch < 15:
            epoch_qstep = 0.02
        elif epoch < 35:
            epoch_qstep = 0.03
        else:
            epoch_qstep = 0.04
        # Update in place: assigning ``.data`` a freshly built CPU tensor
        # silently moved the parameter off the model's device.
        with torch.no_grad():
            model.qstep.fill_(epoch_qstep)

        total_loss = 0
        num_batches = 0

        for batch_imgs, _paths in dataloader:
            batch_imgs = batch_imgs.to(device)
            optimizer.zero_grad()

            if use_perceptual:
                reconstructed, _, rate = model(batch_imgs, training=True, use_perceptual=True)
                distortion = mse_loss(reconstructed, batch_imgs)
                percep = perceptual_loss(reconstructed, batch_imgs)
                loss = distortion + 0.01 * percep + config.rate_weight * rate
            else:
                reconstructed, _, rate = model(batch_imgs, training=True)
                distortion = mse_loss(reconstructed, batch_imgs)
                loss = distortion + config.rate_weight * rate

            # Skip non-finite batches instead of corrupting the weights.
            if torch.isfinite(loss):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

        if num_batches > 0 and (epoch + 1) % 10 == 0:
            avg_loss = total_loss / num_batches
            current_qstep = model.qstep.item()
            print(f'   Epoch {epoch+1}/{config.num_epochs}, Loss: {avg_loss:.6f}, Qstep: {current_qstep:.4f}')

    model.eval()
    return model

def train_models_on_dataset():
    """Train both iWaveV3 variants on every file matched by ``config.train_data_path``.

    Returns:
        ``(iwave_obj, iwave_perp, device)`` with both models in eval mode.

    Raises:
        FileNotFoundError: if the glob pattern matches nothing. Previously,
            training silently ran zero batches and returned models that had
            never been optimised.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f" Using device: {device}")

    # Sorted so the file list (and therefore dataset indices) is reproducible.
    image_paths = sorted(glob.glob(config.train_data_path))
    print(f" Found {len(image_paths)} images for training")
    if not image_paths:
        raise FileNotFoundError(
            f"No training images matched {config.train_data_path!r}; "
            "check config.train_data_path"
        )

    dataset = ImageCompressionDataset(image_paths, image_size=config.image_size)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=custom_collate_fn
    )

    iwave_obj, iwave_perp = initialize_models(device)

    print(f" Starting progressive training for {config.num_epochs} epochs...")

    print("\n Training iWaveV3-Obj with progressive compression...")
    iwave_obj = progressive_training(iwave_obj, dataloader, device, "iWaveV3-Obj")

    print("\n Training iWaveV3-Perp with progressive compression...")
    iwave_perp = progressive_training(iwave_perp, dataloader, device, "iWaveV3-Perp")

    print(" Training completed!")
    return iwave_obj, iwave_perp, device

In [None]:
# ==================== TESTING INTERFACE ====================

def _set_qstep(models, qstep):
    """Overwrite the ``qstep`` tensor of every model in ``models`` in place.

    ``fill_`` writes into the existing tensor, so its shape, dtype and device
    are preserved. Re-binding ``.data`` to a fresh ``torch.tensor(qstep)``
    would silently turn it into a 0-d float32 tensor.
    """
    for model in models:
        model.qstep.data.fill_(qstep)


def test_single_image_clean(iwave_obj, iwave_perp, device, test_image_path=None):
    """Evaluate both trained codecs on one image at several quantisation steps.

    For each compression level, the ``qstep`` of both models is set and the
    image is run through each model with ``training=False``; the perceptual
    model additionally gets ``use_perceptual=True``. PSNR, SSIM and BPP are
    then printed. Finally the "Balanced" step is applied and
    ``display_clean_results`` produces the detailed visualisation.

    Args:
        iwave_obj: trained iWaveV3-Obj model (must expose a ``qstep`` tensor).
        iwave_perp: trained iWaveV3-Perp model (must expose a ``qstep`` tensor).
        device: device the image tensor is moved to.
        test_image_path: image to evaluate. When ``None`` (the default), the
            user is prompted for it. If the path does not exist, the first file
            matched by ``config.train_data_path`` is used instead, or
            "/kaggle/input/archive/kodim03.png" when nothing matches.

    Returns:
        Whatever ``display_clean_results`` returns.

    Raises:
        FileNotFoundError: if neither the requested image nor the fallback exists.

    Note:
        Both models are left with the "Balanced" ``qstep`` after this call.
    """
    print("\n" + "="*60)
    print(" CLEAN TESTING INTERFACE")
    print("="*60)

    if test_image_path is None:
        try:
            test_image_path = input(" Enter the path to your test image: ").strip()
        except EOFError:
            # No stdin available (e.g. headless "Run All"): use the default image.
            test_image_path = ""

    if not os.path.exists(test_image_path):
        print(" Image path not found! Using default image...")
        # Sorted + files only, so the fallback image is deterministic.
        image_paths = sorted(
            path for path in glob.glob(config.train_data_path) if os.path.isfile(path)
        )
        test_image_path = image_paths[0] if image_paths else "/kaggle/input/archive/kodim03.png"
        if not os.path.exists(test_image_path):
            raise FileNotFoundError(f"Fallback test image not found: {test_image_path}")

    print(f" Processing: {test_image_path}")

    image_tensor, original_size, processed_size = load_and_preprocess_image(test_image_path)
    image_tensor = image_tensor.to(device)

    print(f" Original size: {original_size}, Processed size: {processed_size}")

    # Test different compression levels; per the level names, a larger qstep
    # trades quality for stronger compression.
    balanced_qstep = 0.025
    compression_levels = [
        {'name': 'High Quality', 'qstep': 0.015},
        {'name': 'Balanced', 'qstep': balanced_qstep},
        {'name': 'High Compression', 'qstep': 0.035},
    ]

    print(f"\n Testing {len(compression_levels)} compression levels...")

    for level in compression_levels:
        print(f"\n {level['name']} (qstep={level['qstep']}):")

        _set_qstep((iwave_obj, iwave_perp), level['qstep'])

        with torch.no_grad():
            # Outputs: (reconstruction, quantized representation, third output unused here).
            compressed_obj, quantized_obj, _ = iwave_obj(image_tensor, training=False)
            compressed_perp, quantized_perp, _ = iwave_perp(image_tensor, training=False, use_perceptual=True)

            psnr_obj, ssim_obj = calculate_metrics(image_tensor, compressed_obj)
            psnr_perp, ssim_perp = calculate_metrics(image_tensor, compressed_perp)

            # Both BPP values must come from the same estimator to stay
            # comparable, so if the improved one fails for either model, fall
            # back to the simple estimator for both -- and say so.
            try:
                bpp_obj = calculate_bpp_improved(quantized_obj, processed_size)
                bpp_perp = calculate_bpp_improved(quantized_perp, processed_size)
            except Exception as exc:
                print(f"   calculate_bpp_improved failed ({exc!r}); using calculate_bpp_simple")
                bpp_obj = calculate_bpp_simple(quantized_obj, processed_size)
                bpp_perp = calculate_bpp_simple(quantized_perp, processed_size)

            print(f"   iWaveV3-Obj:  PSNR: {psnr_obj:.2f}dB, SSIM: {ssim_obj:.4f}, BPP: {bpp_obj:.3f}")
            print(f"   iWaveV3-Perp: PSNR: {psnr_perp:.2f}dB, SSIM: {ssim_perp:.4f}, BPP: {bpp_perp:.3f}")

    # Display clean results with the balanced setting.
    print("\n🎯 Displaying detailed results with Balanced compression...")
    _set_qstep((iwave_obj, iwave_perp), balanced_qstep)

    results = display_clean_results(image_tensor, iwave_obj, iwave_perp, device, test_image_path)

    return results
    

In [None]:
# ==================== MAIN EXECUTION ====================

if __name__ == "__main__":
    # End-to-end run: train both codecs, then evaluate them on a single image.
    print(" iWaveV3 - Clean Compression Pipeline\n" + "=" * 60)

    start_time = time.time()

    try:
        # Stage 1: train iWaveV3-Obj and iWaveV3-Perp on the configured image set.
        iwave_obj, iwave_perp, device = train_models_on_dataset()

        # Stage 2: per-level metrics plus the detailed "Balanced" visualisation.
        test_results = test_single_image_clean(iwave_obj, iwave_perp, device)

        total_time = time.time() - start_time
        print(
            f"\n  Total execution time: {total_time:.2f} seconds\n"
            " Clean compression pipeline completed successfully!"
        )
        print(
            f"\n All results saved in: {config.model_dir}\n"
            "   - clean_compression_results.png (Complete visualization)\n"
            "   - 1_original.png to 5_reconstructed_perp.png (Individual images)"
        )

    except Exception as e:
        # Report the failure with its full traceback but do not propagate it.
        import traceback
        print(f" Error during execution: {e}")
        traceback.print_exc()

print(" iWaveV3 Clean Output Pipeline Ready!")