# VAE-GAN Architectures for Image Compression - Demo

This notebook demonstrates how to use the three implemented VAE-GAN architectures for image compression:
1. β-VAE-GAN
2. VQ-VAE-GAN
3. Hierarchical VAE-GAN

## Setup

First, let's import the necessary libraries and set up the environment.

In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision import transforms
from IPython.display import display

# Make the model packages in the parent directory importable.
# NOTE(review): checkpoints and images are later read from './output' and
# './data' relative to the current working directory, not from '..';
# confirm both layouts agree.
sys.path.append('..')

# Image -> tensor conversion used by load_image(): ToTensor only, with no
# resizing or normalization.
transform = transforms.Compose([transforms.ToTensor()])

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## Load Pre-trained Models

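Let's load the pre-trained checkpoint for each architecture. Edit the `*_model_path` variables below to point to your trained models. If a checkpoint is missing, that model is set to `None` and its sections below are skipped.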

In [None]:
# Load β-VAE-GAN
from beta_vae_gan.model import BetaVAEGAN
from beta_vae_gan.config import Config as BetaConfig

beta_config = BetaConfig()
beta_model_path = './output/beta_vae_gan/checkpoints/best_model.pth'
beta_model = BetaVAEGAN(beta_config)
if not os.path.exists(beta_model_path):
    # No checkpoint: mark the model unavailable so later cells skip it.
    print(f"β-VAE-GAN model not found at {beta_model_path}")
    beta_model = None
else:
    # The file is loaded as a state_dict; map_location lets a checkpoint
    # saved on another device load on the current one.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True where supported).
    beta_model.load_state_dict(torch.load(beta_model_path, map_location=device))
    beta_model = beta_model.to(device).eval()
    print("β-VAE-GAN model loaded successfully")

In [None]:
# Load VQ-VAE-GAN
from vq_vae_gan.model import VQVAEGAN, compress_image
from vq_vae_gan.config import Config as VQConfig

vq_config = VQConfig()
vq_model_path = './output/vq_vae_gan/checkpoints/best_model.pth'
vq_model = VQVAEGAN(vq_config)
if not os.path.exists(vq_model_path):
    # No checkpoint: mark the model unavailable so later cells skip it.
    print(f"VQ-VAE-GAN model not found at {vq_model_path}")
    vq_model = None
else:
    # The file is loaded as a state_dict; map_location lets a checkpoint
    # saved on another device load on the current one.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True where supported).
    vq_model.load_state_dict(torch.load(vq_model_path, map_location=device))
    vq_model = vq_model.to(device).eval()
    print("VQ-VAE-GAN model loaded successfully")

In [None]:
# Load Hierarchical VAE-GAN
from hierarchical_vae_gan.model import HierarchicalVAEGAN, compress_hierarchical_model
from hierarchical_vae_gan.config import Config as HierarchicalConfig

hier_config = HierarchicalConfig()
hier_model_path = './output/hierarchical_vae_gan/checkpoints/best_model.pth'
hier_model = HierarchicalVAEGAN(hier_config)
if not os.path.exists(hier_model_path):
    # No checkpoint: mark the model unavailable so later cells skip it.
    print(f"Hierarchical VAE-GAN model not found at {hier_model_path}")
    hier_model = None
else:
    # The file is loaded as a state_dict; map_location lets a checkpoint
    # saved on another device load on the current one.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True where supported).
    hier_model.load_state_dict(torch.load(hier_model_path, map_location=device))
    hier_model = hier_model.to(device).eval()
    print("Hierarchical VAE-GAN model loaded successfully")

## Load Test Images

Let's load the first four images of the Kodak dataset (`kodim01.png`–`kodim04.png`) from `./data/kodak`. Missing files are reported and skipped.

In [None]:
def load_image(path):
    """Load an image file as RGB and convert it to a batched tensor.

    Args:
        path: Path to the image file.

    Returns:
        Tuple ``(img, tensor)``. ``img`` is the RGB ``PIL.Image``. ``tensor``
        is ``transform(img)`` with a leading batch dimension of size 1 added.
    """
    # Open inside a context manager so the file handle is released
    # deterministically. convert() returns a new, fully loaded image that
    # stays valid after the source file is closed.
    with Image.open(path) as src:
        img = src.convert('RGB')
    tensor = transform(img).unsqueeze(0)
    return img, tensor

# Load the first four Kodak images (kodim01.png .. kodim04.png); missing
# files are reported and skipped.
kodak_dir = './data/kodak'
test_images = []
test_tensors = []

for i in range(1, 5):
    img_path = os.path.join(kodak_dir, f'kodim{i:02d}.png')
    if not os.path.exists(img_path):
        print(f"Image {i} not found at {img_path}")
        continue
    img, tensor = load_image(img_path)
    test_images.append(img)
    test_tensors.append(tensor)
    print(f"Loaded image {i} with shape {tensor.shape}")

# Show whatever was loaded, side by side in a single row.
if test_images:
    n_shown = len(test_images)
    plt.figure(figsize=(15, 5))
    for i, img in enumerate(test_images):
        plt.subplot(1, n_shown, i + 1)
        plt.imshow(img)
        plt.title(f"Test Image {i+1}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

## Compression and Reconstruction with β-VAE-GAN

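Now let's compress and reconstruct the test images using the β-VAE-GAN model. For this model, BPP is estimated from the size of the raw floating-point latent, with no quantization or entropy coding, so it is an upper bound on the achievable rate.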

In [None]:
def compress_with_beta_vae_gan(model, image_tensor, bits_per_element=32):
    """Compress and reconstruct an image with the β-VAE-GAN.

    The "compressed" representation is the latent ``z`` returned by the model,
    counted at ``bits_per_element`` bits per latent value. No quantization or
    entropy coding is modelled, so BPP and compression ratio describe raw
    latent storage.

    Args:
        model: Loaded β-VAE-GAN, or ``None`` if unavailable.
        image_tensor: Batched image tensor (channel axis at dim -3). The numpy
            conversion assumes a batch size of 1.
        bits_per_element: Bits used to store each latent value. The default 32
            matches float32; pass e.g. 16 to estimate a float16 latent.

    Returns:
        ``(original_np, recon_np, bpp, compression_ratio)``, where the images are
        (H, W, C) numpy arrays. Returns ``(None, None, 0, 0)`` if ``model`` is
        ``None``.
    """
    if model is None:
        print("β-VAE-GAN model not loaded")
        return None, None, 0, 0

    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        # Forward pass; only the reconstruction and the latent z are used here.
        recon_x, z, _, _ = model(image_tensor)

        # Rate estimate: every latent value stored raw at bits_per_element bits.
        total_bits = z.numel() * bits_per_element
        # Pixel count = all elements / channel count. Dim -3 is the channel
        # axis, as the (1, 2, 0) permute below implies.
        total_pixels = image_tensor.numel() / image_tensor.shape[-3]
        bpp = total_bits / total_pixels

        # Baseline: the uncompressed image at 8 bits per channel value.
        original_size = image_tensor.numel() * 8
        compression_ratio = original_size / total_bits

        # Convert to (H, W, C) numpy arrays for visualization.
        original_np = image_tensor.cpu().squeeze().permute(1, 2, 0).numpy()
        recon_np = recon_x.cpu().squeeze().permute(1, 2, 0).numpy()

        return original_np, recon_np, bpp, compression_ratio

# Compress and reconstruct each test image with β-VAE-GAN, showing one
# original/reconstruction pair per row.
if beta_model is not None and test_tensors:
    plt.figure(figsize=(15, 10))
    for i, tensor in enumerate(test_tensors):
        original_np, recon_np, bpp, compression_ratio = compress_with_beta_vae_gan(beta_model, tensor)

        if original_np is not None and recon_np is not None:
            # Measure PSNR on the same clipped [0, 1] images that are displayed,
            # so out-of-range decoder values don't skew the reported metric.
            original_np = np.clip(original_np, 0, 1)
            recon_np = np.clip(recon_np, 0, 1)
            mse = np.mean((original_np - recon_np) ** 2)
            # Peak value is 1.0; a perfect reconstruction has infinite PSNR.
            psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')

            # Plot original and reconstructed
            plt.subplot(len(test_tensors), 2, 2*i+1)
            plt.imshow(original_np)
            plt.title(f"Original Image {i+1}")
            plt.axis('off')

            plt.subplot(len(test_tensors), 2, 2*i+2)
            plt.imshow(recon_np)
            plt.title(f"β-VAE-GAN: PSNR={psnr:.2f}dB, BPP={bpp:.4f}, CR={compression_ratio:.2f}x")
            plt.axis('off')

    plt.tight_layout()
    plt.show()

## Compression and Reconstruction with VQ-VAE-GAN

Now let's compress and reconstruct the test images using the VQ-VAE-GAN model.

In [None]:
def compress_with_vq_vae_gan(model, image_tensor):
    """Run VQ-VAE-GAN compression followed by decompression on one image.

    Args:
        model: Loaded VQ-VAE-GAN, or ``None`` if unavailable.
        image_tensor: Batched image tensor. The numpy conversion assumes a
            batch size of 1.

    Returns:
        ``(original_np, recon_np, bpp, compression_ratio)``, where the images are
        (H, W, C) numpy arrays and ``bpp``/``compression_ratio`` are read from
        the metadata returned by ``compress_image``. Returns
        ``(None, None, 0, 0)`` when ``model`` is ``None``.
    """
    if model is None:
        print("VQ-VAE-GAN model not loaded")
        return None, None, 0, 0

    def to_hwc(t):
        # (1, C, H, W) tensor on any device -> (H, W, C) numpy array.
        return t.cpu().squeeze().permute(1, 2, 0).numpy()

    with torch.no_grad():
        x = image_tensor.to(device)
        payload, metadata = compress_image(model, x, device)
        # NOTE(review): compression uses the module-level compress_image()
        # while decompression uses a method on the model; confirm VQVAEGAN
        # actually defines decompress_image.
        x_hat = model.decompress_image(payload, metadata, device)

        bpp = metadata['bpp']
        compression_ratio = metadata['compression_ratio']
        return to_hwc(x), to_hwc(x_hat), bpp, compression_ratio

# Compress and reconstruct each test image with VQ-VAE-GAN, showing one
# original/reconstruction pair per row.
if vq_model is not None and test_tensors:
    plt.figure(figsize=(15, 10))
    for i, tensor in enumerate(test_tensors):
        original_np, recon_np, bpp, compression_ratio = compress_with_vq_vae_gan(vq_model, tensor)

        if original_np is not None and recon_np is not None:
            # Measure PSNR on the same clipped [0, 1] images that are displayed,
            # so out-of-range decoder values don't skew the reported metric.
            original_np = np.clip(original_np, 0, 1)
            recon_np = np.clip(recon_np, 0, 1)
            mse = np.mean((original_np - recon_np) ** 2)
            # Peak value is 1.0; a perfect reconstruction has infinite PSNR.
            psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')

            # Plot original and reconstructed
            plt.subplot(len(test_tensors), 2, 2*i+1)
            plt.imshow(original_np)
            plt.title(f"Original Image {i+1}")
            plt.axis('off')

            plt.subplot(len(test_tensors), 2, 2*i+2)
            plt.imshow(recon_np)
            plt.title(f"VQ-VAE-GAN: PSNR={psnr:.2f}dB, BPP={bpp:.4f}, CR={compression_ratio:.2f}x")
            plt.axis('off')

    plt.tight_layout()
    plt.show()

## Compression and Reconstruction with Hierarchical VAE-GAN

Finally, let's compress and reconstruct the test images using the Hierarchical VAE-GAN model.

In [None]:
def compress_with_hierarchical_vae_gan(model, image_tensor, bit_allocation=None):
    """Compress and reconstruct an image with the Hierarchical VAE-GAN.

    Args:
        model: Loaded Hierarchical VAE-GAN, or ``None`` if unavailable.
        image_tensor: Batched image tensor. The numpy conversion assumes a
            batch size of 1.
        bit_allocation: Per-level bit budget, passed straight through to
            ``compress_hierarchical_model``. Defaults to ``[8, 6, 4]``.

    Returns:
        ``(original_np, recon_np, bpp, compression_ratio)``, where the images are
        (H, W, C) numpy arrays and ``bpp``/``compression_ratio`` are read from
        the metrics returned by ``compress_hierarchical_model``. Returns
        ``(None, None, 0, 0)`` when ``model`` is ``None``.
    """
    if model is None:
        print("Hierarchical VAE-GAN model not loaded")
        return None, None, 0, 0

    if bit_allocation is None:
        # Build a fresh list per call. A list literal as the default value would
        # be one object shared by every call, so any mutation would leak
        # into later calls.
        bit_allocation = [8, 6, 4]

    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        # Simulate compression with the requested bit allocation
        metrics, reconstructed, _ = compress_hierarchical_model(
            model, image_tensor, bit_allocation=bit_allocation, device=device
        )

        bpp = metrics['bpp']
        compression_ratio = metrics['compression_ratio']

        # Convert to (H, W, C) numpy arrays for visualization.
        original_np = image_tensor.cpu().squeeze().permute(1, 2, 0).numpy()
        recon_np = reconstructed.cpu().squeeze().permute(1, 2, 0).numpy()

        return original_np, recon_np, bpp, compression_ratio

# Compress and reconstruct each test image with the Hierarchical VAE-GAN,
# showing one original/reconstruction pair per row.
if hier_model is not None and test_tensors:
    plt.figure(figsize=(15, 10))
    for i, tensor in enumerate(test_tensors):
        original_np, recon_np, bpp, compression_ratio = compress_with_hierarchical_vae_gan(
            hier_model, tensor, bit_allocation=[8, 6, 4]
        )

        if original_np is not None and recon_np is not None:
            # Measure PSNR on the same clipped [0, 1] images that are displayed,
            # so out-of-range decoder values don't skew the reported metric.
            original_np = np.clip(original_np, 0, 1)
            recon_np = np.clip(recon_np, 0, 1)
            mse = np.mean((original_np - recon_np) ** 2)
            # Peak value is 1.0; a perfect reconstruction has infinite PSNR.
            psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')

            # Plot original and reconstructed
            plt.subplot(len(test_tensors), 2, 2*i+1)
            plt.imshow(original_np)
            plt.title(f"Original Image {i+1}")
            plt.axis('off')

            plt.subplot(len(test_tensors), 2, 2*i+2)
            plt.imshow(recon_np)
            plt.title(f"Hierarchical VAE-GAN: PSNR={psnr:.2f}dB, BPP={bpp:.4f}, CR={compression_ratio:.2f}x")
            plt.axis('off')

    plt.tight_layout()
    plt.show()

## Comparing Different Bit Allocations for Hierarchical VAE-GAN

Let's compare the effect of different bit allocations on the hierarchical model. We compress the first test image with each allocation and plot the resulting PSNR against BPP.

In [None]:
# Different bit allocations to try
bit_allocations = [
    [8, 6, 4],   # Default - more bits for higher levels
    [6, 6, 6],   # Equal allocation
    [4, 6, 8],   # More bits for deeper levels
    [10, 5, 2],  # Heavy focus on highest level
    [2, 5, 10],  # Heavy focus on deepest level
]

# Compress the first test image with every bit allocation. Show the
# reconstructions in one grid, then plot the rate-distortion points in a
# separate figure.
if hier_model is not None and test_tensors:
    test_tensor = test_tensors[0]  # Use the first test image

    # One panel for the original plus one per allocation, all on a single
    # grid so every panel is filled and no axes overlap.
    n_cols = 3
    n_panels = len(bit_allocations) + 1
    n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
    plt.figure(figsize=(6 * n_cols, 5 * n_rows))

    # Plot original image in the first panel
    original_img = test_tensor.cpu().squeeze().permute(1, 2, 0).numpy()
    plt.subplot(n_rows, n_cols, 1)
    plt.imshow(np.clip(original_img, 0, 1))
    plt.title("Original Image")
    plt.axis('off')

    results = []

    # Compress with each bit allocation
    for i, bits in enumerate(bit_allocations):
        bits_str = ", ".join(map(str, bits))

        original_np, recon_np, bpp, compression_ratio = compress_with_hierarchical_vae_gan(
            hier_model, test_tensor, bit_allocation=bits
        )
        if original_np is None or recon_np is None:
            continue

        # Measure PSNR on the same clipped [0, 1] images that are displayed.
        original_np = np.clip(original_np, 0, 1)
        recon_np = np.clip(recon_np, 0, 1)
        mse = np.mean((original_np - recon_np) ** 2)
        psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')

        results.append({
            'bit_allocation': bits,
            'psnr': psnr,
            'bpp': bpp,
            'compression_ratio': compression_ratio,
            'recon_np': recon_np,
        })

        # Panel i + 2: panel 1 holds the original image.
        plt.subplot(n_rows, n_cols, i + 2)
        plt.imshow(recon_np)
        plt.title(f"Bits=[{bits_str}]\nPSNR={psnr:.2f}dB, BPP={bpp:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Rate-distortion scatter in its own figure. Skip it when nothing was
    # produced, which avoids an empty legend.
    if results:
        plt.figure(figsize=(8, 6))
        for result in results:
            plt.scatter(result['bpp'], result['psnr'], label=f"Bits={result['bit_allocation']}")
        plt.xlabel('Bits per Pixel (BPP)')
        plt.ylabel('PSNR (dB)')
        plt.title('Rate-Distortion for Different Bit Allocations')
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()

## Side-by-Side Comparison of All Models

Finally, let's compare all models side by side on the same test image.

In [None]:
# Compare all loaded models on the same test image in a 2x2 grid:
# original top-left, then one panel per model.
if test_tensors and (beta_model is not None or vq_model is not None or hier_model is not None):
    test_tensor = test_tensors[0]  # Use the first test image

    plt.figure(figsize=(20, 15))

    # Plot original image
    original_img = test_tensor.cpu().squeeze().permute(1, 2, 0).numpy()
    plt.subplot(2, 2, 1)
    plt.imshow(np.clip(original_img, 0, 1))
    plt.title("Original Image")
    plt.axis('off')

    # (display name, model, compression callable, subplot position). Each
    # callable returns (original_np, recon_np, bpp, compression_ratio).
    comparisons = [
        ("β-VAE-GAN", beta_model, compress_with_beta_vae_gan, 2),
        ("VQ-VAE-GAN", vq_model, compress_with_vq_vae_gan, 3),
        ("Hierarchical VAE-GAN", hier_model,
         lambda m, t: compress_with_hierarchical_vae_gan(m, t, bit_allocation=[8, 6, 4]), 4),
    ]

    for cmp_name, cmp_model, compress_fn, position in comparisons:
        if cmp_model is None:
            continue
        original_np, recon_np, bpp, compression_ratio = compress_fn(cmp_model, test_tensor)
        if original_np is None or recon_np is None:
            continue

        # Measure PSNR on the same clipped [0, 1] images that are displayed.
        original_np = np.clip(original_np, 0, 1)
        recon_np = np.clip(recon_np, 0, 1)
        mse = np.mean((original_np - recon_np) ** 2)
        psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')

        plt.subplot(2, 2, position)
        plt.imshow(recon_np)
        plt.title(f"{cmp_name}\nPSNR={psnr:.2f}dB, BPP={bpp:.4f}, CR={compression_ratio:.2f}x")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

## Conclusion

In this notebook, we demonstrated how to use three different VAE-GAN architectures for image compression:

1. **β-VAE-GAN**: A basic VAE-GAN with a continuous latent space, controlled by the β parameter.
2. **VQ-VAE-GAN**: A Vector-Quantized VAE-GAN with a discrete latent space using a learned codebook.
3. **Hierarchical VAE-GAN**: A hierarchical model with multiple levels of latent representations at different scales.

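Each architecture has its own trade-offs between reconstruction quality and compression ratio; compare the PSNR, BPP, and CR values reported above for your trained checkpoints. Hierarchical models are often expected to give the best perceptual quality and VQ-VAE-GANs better compression rates, but confirm this against your own results, keeping in mind that PSNR is not a perceptual metric.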

For practical applications, the choice of architecture would depend on the specific requirements of the task, including the desired balance between compression ratio, reconstruction quality, and computational resources.