Encoder

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Residual function block (functional form: weights are passed in, not owned)
def resnet_block(x, channels, params=None, training=True):
    """Apply a two-conv residual block: relu(bn(conv(relu(bn(conv(x))))) + x).

    Args:
        x: input batch; ``x.shape[1]`` must equal ``channels`` because the block
            output is added back onto ``x``.
        channels: expected number of input (and output) channels.
        params: mapping holding the twelve tensors
            ``weight1, bias1, running_mean1, running_var1, weight_bn1, bias_bn1,
            weight2, bias2, running_mean2, running_var2, weight_bn2, bias_bn2``.
            When omitted, they are looked up as module-level globals of the same
            names (which is how this function originally obtained them).
        training: forwarded to ``F.batch_norm`` for both normalizations
            (the original always passed ``True``).

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        NameError: ``params`` is omitted and a required global is not defined.
        ValueError: ``x.shape[1] != channels``.
    """
    names = (
        'weight1', 'bias1', 'running_mean1', 'running_var1', 'weight_bn1', 'bias_bn1',
        'weight2', 'bias2', 'running_mean2', 'running_var2', 'weight_bn2', 'bias_bn2',
    )
    if params is None:
        scope = globals()
        missing = [name for name in names if name not in scope]
        if missing:
            raise NameError(
                "resnet_block parameters are not defined: " + ", ".join(missing)
            )
        params = {name: scope[name] for name in names}

    if x.shape[1] != channels:
        raise ValueError(
            f"resnet_block expected {channels} input channels, got {x.shape[1]}"
        )

    p = params
    shortcut = x
    out = F.conv2d(x, p['weight1'], bias=p['bias1'], stride=1, padding=1)
    out = F.batch_norm(out, p['running_mean1'], p['running_var1'],
                       p['weight_bn1'], p['bias_bn1'], training=training)
    out = F.relu(out)
    out = F.conv2d(out, p['weight2'], bias=p['bias2'], stride=1, padding=1)
    out = F.batch_norm(out, p['running_mean2'], p['running_var2'],
                       p['weight_bn2'], p['bias_bn2'], training=training)
    # Identity shortcut: requires the conv stack to preserve the channel count.
    return F.relu(out + shortcut)


def build_encoder():
    """Build the convolutional encoder as a flat ``nn.Sequential`` of 27 layers.

    There are three stages. Each stage is one downsampling conv-BN-ReLU (stride 2)
    followed by two 3x3 conv-BN-ReLU units at constant width. Those two units are
    labelled "residual block" in this notebook, but no skip addition happens inside
    this Sequential. Channel widths are 3 -> 64 -> 128 -> 128, and the spatial size
    shrinks by 8x overall (e.g. 256x256 -> 32x32).
    """
    def conv_bn_relu(in_ch, out_ch, kernel, stride, pad):
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    # (in_channels, out_channels, kernel, stride, padding) of each stage's downsampling conv.
    stage_specs = [
        (3, 64, 7, 2, 3),
        (64, 128, 3, 2, 1),
        (128, 128, 3, 2, 1),
    ]

    modules = []
    for in_ch, out_ch, kernel, stride, pad in stage_specs:
        modules += conv_bn_relu(in_ch, out_ch, kernel, stride, pad)
        for _ in range(2):
            modules += conv_bn_relu(out_ch, out_ch, 3, 1, 1)

    return nn.Sequential(*modules)


# Quick smoke test: push a random batch of two 3x256x256 inputs through a fresh encoder.
encoder = build_encoder()
x = torch.randn((2, 3, 256, 256))
z = encoder(x)



In [None]:
def encoder_forward(x, encoder, *, legacy_layout=False):
    """Run ``encoder`` on ``x`` and collect U-Net skip features.

    By default every layer of ``encoder`` is applied in order. A skip is recorded
    just before each strided ``nn.Conv2d`` (except one at index 0), i.e. at the end
    of each resolution stage, and once more for the final output. For the 27-layer
    encoder from ``build_encoder`` this records the outputs of layers 8, 17 and 26,
    which are the ends of the three "residual blocks".

    Args:
        x: input batch fed to the first layer.
        encoder: iterable/indexable sequence of layers (e.g. ``nn.Sequential``).
        legacy_layout: if True, replay the previous hand-unrolled walk. That walk
            applied only the first 21 layers and took skips after layers 6, 13
            and 20. Use it with checkpoints trained under that walk.

    Returns:
        ``(x, skips)``: the final activation and the skip tensors, shallowest
        first. The last skip is the same tensor as the returned activation.
    """
    skips = []

    if legacy_layout:
        for idx in range(21):
            x = encoder[idx](x)
            if idx in (6, 13, 20):
                skips.append(x)
        return x, skips

    for idx, layer in enumerate(encoder):
        # A strided conv starts a new, lower-resolution stage, so the activation
        # arriving here is the finished output of the previous stage.
        starts_new_stage = isinstance(layer, nn.Conv2d) and any(s > 1 for s in layer.stride)
        if starts_new_stage and idx > 0:
            skips.append(x)
        x = layer(x)

    # The deepest stage's output serves as both the latent and the last skip.
    skips.append(x)
    return x, skips


Decoder

In [None]:
# Upsampling block: doubles height/width and maps in_channels -> out_channels.
def up_block(in_channels, out_channels):
    """Return ConvTranspose2d(k=4, s=2, p=1) -> BatchNorm2d -> ReLU(inplace).

    With kernel 4, stride 2 and padding 1 the transposed conv outputs exactly
    twice the input height and width.
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, norm, activation)

# U-Net decoder
def build_unet(latent_channels=128, base_channels=64):
    """Build the decoder as an ``nn.ModuleList``: three up-blocks plus a final 3x3 conv.

    The 2nd and 3rd up-blocks take twice the previous block's output width as input,
    because ``forward_unet`` concatenates a same-width skip before each of them.
    The last conv reads 64 + 64 channels and produces a 3-channel image.

    NOTE: ``base_channels`` is accepted but not used; the skip widths are fixed
    at (64, 128, 128) to match ``build_encoder``.
    """
    shallow_ch, mid_ch, deep_ch = 64, 128, 128

    up_stages = [
        up_block(latent_channels, deep_ch),   # latent -> 128, then cat deepest skip
        up_block(deep_ch * 2, mid_ch),        # 128+128 -> 128, then cat middle skip
        up_block(mid_ch * 2, shallow_ch),     # 128+128 -> 64, then cat shallowest skip
    ]
    head = nn.Conv2d(shallow_ch * 2, 3, kernel_size=3, stride=1, padding=1)

    return nn.ModuleList(up_stages + [head])

# Decoder forward pass
def forward_unet(x, skips, decoder_layers):
    """Run every decoder stage except the last, fusing skips deepest-first, then the head.

    After stage ``k``, ``skips[-(k + 1)]`` (if it exists) is resized with nearest
    interpolation to the stage output's height/width and concatenated on dim 1.
    Stages beyond ``len(skips)`` get no skip. The final layer is applied to
    the result.
    """
    up_layers = decoder_layers[:-1]
    for depth, stage in enumerate(up_layers):
        x = stage(x)
        if depth >= len(skips):
            continue
        partner = skips[len(skips) - 1 - depth]
        partner = F.interpolate(partner, size=x.shape[-2:], mode='nearest')
        x = torch.cat([x, partner], dim=1)

    head = decoder_layers[-1]
    return head(x)

Diffusion

In [None]:
# Linear beta (noise variance) schedule
def get_beta_schedule(n_timesteps, beta_start=1e-4, beta_end=0.02):
    """Return ``n_timesteps`` betas spaced evenly from ``beta_start`` to ``beta_end`` (inclusive)."""
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=n_timesteps)
    return schedule

# Per-step alphas and their running product
def compute_alpha_terms(betas):
    """Return ``(alphas, alphas_cumprod)`` with ``alphas = 1 - betas`` and
    ``alphas_cumprod[t] = prod(alphas[:t + 1])``."""
    alphas = 1.0 - betas
    return alphas, alphas.cumprod(dim=0)

# Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form
def q_sample(x0, t, alphas_cumprod, noise=None):
    """Noise clean data ``x0`` to timestep(s) ``t``.

    Computes ``x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise`` with
    ``abar_t = alphas_cumprod[t]``.

    Args:
        x0: batch of clean data, batch dimension first, any rank >= 1.
        t: timestep indices, either one per batch element (length B) or a
            scalar applied to the whole batch.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        noise: optional noise tensor shaped like ``x0``; drawn with
            ``torch.randn_like(x0)`` when omitted.

    Returns:
        ``(x_t, noise)``.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    # Reshape abar_t to (B, 1, ..., 1) (or (1, 1, ..., 1) for a scalar t) so it
    # broadcasts over every non-batch dimension of x0 regardless of its rank.
    alpha_t = alphas_cumprod[t].reshape(-1, *([1] * (x0.dim() - 1)))
    x_n = torch.sqrt(alpha_t) * x0 + torch.sqrt(1 - alpha_t) * noise
    return x_n, noise


Loss Function

In [None]:
def loss_function(predicted_x0, true_x0):
    """Reconstruction loss: mean squared error averaged over every element."""
    return F.mse_loss(input=predicted_x0, target=true_x0, reduction='mean')


Load Datasets

In [None]:
import os
import random
from PIL import Image
import torch
import torchvision.transforms.functional as TF

# 1. Get all image file paths from a directory
def get_image_paths(directory, extensions=('.png',)):
    """Return the sorted full paths of files in ``directory`` ending with one of ``extensions``.

    Matching is case-sensitive and only the top level of ``directory`` is
    scanned. Sub-directories are skipped even if their names match.

    Args:
        directory: folder to list.
        extensions: a suffix string or an iterable of suffixes (default ``('.png',)``).
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(suffixes) and os.path.isfile(os.path.join(directory, name))
    )

# 2. Load one image as a fixed-size tensor
def process_image(image_path, crop_size=256):
    """Load ``image_path`` as RGB and return a (3, crop_size, crop_size) tensor.

    If both sides are at least ``crop_size``, a random ``crop_size`` square is
    taken, with its offset drawn from Python's ``random`` module. Otherwise the
    whole image is resized to ``crop_size`` x ``crop_size``, which does not
    preserve aspect ratio.
    """
    image = Image.open(image_path).convert('RGB')
    width, height = image.size

    too_small = width < crop_size or height < crop_size
    if too_small:
        image = TF.resize(image, (crop_size, crop_size))
    else:
        top = random.randint(0, height - crop_size)
        left = random.randint(0, width - crop_size)
        image = TF.crop(image, top, left, crop_size, crop_size)

    return TF.to_tensor(image)

# 3. Build a batch of randomly chosen images
def create_batch(image_paths, batch_size=4, crop_size=256):
    """Pick ``batch_size`` distinct paths at random and stack their processed tensors.

    Raises ``ValueError`` (from ``random.sample``) if there are fewer than
    ``batch_size`` paths.
    """
    return torch.stack(
        [process_image(path, crop_size) for path in random.sample(image_paths, batch_size)]
    )


In [None]:
# DIV2K dataset locations (Google Drive mount)
train_dir = '/content/drive/MyDrive/datasets/DIV2K_train_HR'
val_dir = '/content/drive/MyDrive/datasets/DIV2K_valid_HR'

train_image_paths, val_image_paths = get_image_paths(train_dir), get_image_paths(val_dir)

# Sanity check: draw one batch of four 256x256 samples and report its shape.
train_batch = create_batch(train_image_paths, batch_size=4, crop_size=256)
print(f"Train batch shape: {train_batch.shape}")

Training

In [None]:
import torch
import torch.optim as optim
import os

# 1. Build the model on the best available device and create the optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = build_encoder().to(device)
decoder_layers = build_unet(latent_channels=128, base_channels=64).to(device)
optimizer = optim.Adam(
    [*encoder.parameters(), *decoder_layers.parameters()],
    lr=5e-4,
)

# 2. Linear beta schedule; only the cumulative alpha products are moved to the device.
n_timesteps = 1000
betas = get_beta_schedule(n_timesteps)
alphas, alphas_cumprod = compute_alpha_terms(betas)
alphas_cumprod = alphas_cumprod.to(device)

# 3. Dataset and checkpoint locations.
train_image_paths = get_image_paths('/dgxa_home/se22uari031/DIV2K_train_HR')
val_image_paths = get_image_paths('/dgxa_home/se22uari031/DIV2K_valid_HR')
checkpoint_dir = '/dgxa_home/se22uari031/model_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')

def compute_psnr(img1, img2):
    """Return the PSNR (dB) between ``img1`` and ``img2`` using a peak value of 1.0.

    The MSE is averaged over all elements. Identical inputs (MSE == 0) yield
    the sentinel value 100 instead of infinity.
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        return 100
    ratio = 1.0 / torch.sqrt(mse)
    db = 20 * torch.log10(ratio)
    return db.item()

def validate_model(encoder, decoder_layers, val_image_paths, batch_size=4, crop_size=256, seed=0):
    """Compute average reconstruction loss and PSNR over the validation images.

    The first ``(len(val_image_paths) // batch_size) * batch_size`` paths are
    each evaluated exactly once, in order; a trailing partial batch is skipped.
    Images go through ``process_image``. When ``seed`` is not None, Python's
    ``random`` module is seeded with it for the duration of the call, so
    ``process_image``'s random crops are the same on every call. The caller's
    ``random`` state is restored afterwards. Pass ``seed=None`` to keep
    unseeded crops.

    Both models are put in eval mode for the pass and back in train mode
    afterwards. Inputs are moved to the module-level ``device``.

    Returns:
        ``(avg_loss, avg_psnr)`` averaged over batches.

    Raises:
        ValueError: fewer than ``batch_size`` validation paths.
    """
    import random

    num_batches = len(val_image_paths) // batch_size
    if num_batches == 0:
        raise ValueError(
            f"need at least batch_size={batch_size} validation images, got {len(val_image_paths)}"
        )

    rng_state = random.getstate() if seed is not None else None
    if seed is not None:
        random.seed(seed)

    encoder.eval()
    decoder_layers.eval()
    total_loss = 0.0
    total_psnr = 0.0
    try:
        with torch.no_grad():
            for b in range(num_batches):
                paths = val_image_paths[b * batch_size:(b + 1) * batch_size]
                real_images = torch.stack([process_image(p, crop_size) for p in paths]).to(device)
                latent, skips = encoder_forward(real_images, encoder)
                predicted_x0 = forward_unet(latent, skips, decoder_layers)
                total_loss += loss_function(predicted_x0, real_images).item()
                total_psnr += compute_psnr(predicted_x0, real_images)
    finally:
        # Always restore the caller's RNG stream and training mode, even on error.
        if rng_state is not None:
            random.setstate(rng_state)
        encoder.train()
        decoder_layers.train()

    avg_loss = total_loss / num_batches
    avg_psnr = total_psnr / num_batches
    print(f"[Validation] Loss: {avg_loss:.6f}, PSNR: {avg_psnr:.2f} dB")
    return avg_loss, avg_psnr


# Training hyperparameters.
num_epochs = 300
batch_size = 32  # training batch size; validation below passes its own batch_size=4

# Best validation metrics seen so far (loss: lower is better, PSNR: higher is better).
best_val_loss = float('inf')
best_val_psnr = 0.0



for epoch in range(num_epochs):
    running_loss = 0.0

    # One "epoch" is len(train_image_paths) // batch_size steps. Each step draws an
    # independent random batch via create_batch, so an epoch need not visit every image.
    for _ in range(len(train_image_paths) // batch_size):
        real_images = create_batch(train_image_paths, batch_size=batch_size, crop_size=256).to(device)
        # NOTE(review): skips come from this clean (un-noised) encoder pass; only
        # the latent is noised below, so the decoder still sees clean skip features.
        latent, skips = encoder_forward(real_images, encoder)
        # One uniformly sampled diffusion timestep per image in the batch.
        t = torch.randint(0, n_timesteps, (real_images.shape[0],)).to(device)
        noisy_latent, noise = q_sample(latent, t, alphas_cumprod)
        # The decoder output is compared directly to the input images
        # (x0 prediction); the sampled noise is not used in the loss.
        predicted_x0 = forward_unet(noisy_latent, skips, decoder_layers)
        loss = loss_function(predicted_x0, real_images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # NOTE(review): raises ZeroDivisionError if there are fewer than batch_size training images.
    avg_loss = running_loss / (len(train_image_paths) // batch_size)
    print(f"[Epoch {epoch+1}/{num_epochs}] Average Loss: {avg_loss:.6f}")

    # Validate every 10 epochs
    # NOTE: validate_model decodes the un-noised latent, unlike the training step above.
    if (epoch + 1) % 10 == 0:
        val_loss, val_psnr = validate_model(encoder, decoder_layers, val_image_paths, batch_size=4)

        # SIMPLE direct check: save if ANY metric improves
        # The two bests are tracked independently, so a save can happen when one
        # metric improves while the other gets worse.
        if (val_loss < best_val_loss) or (val_psnr > best_val_psnr):
            best_val_loss = min(best_val_loss, val_loss)
            best_val_psnr = max(best_val_psnr, val_psnr)

            torch.save({
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder_layers.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'metrics': {'val_loss': val_loss, 'val_psnr': val_psnr}
            }, best_model_path)
            print(f"[Best Model Saved] Epoch {epoch+1}")

    # Checkpoint every 100 epochs (written alongside best_model.pth in checkpoint_dir)
    if (epoch + 1) % 100 == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"model_checkpoint_epoch_{epoch+1}.pth")
        torch.save({
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder_layers.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }, checkpoint_path)
        print(f"[Checkpoint Saved] at: {checkpoint_path}")


Load The Trained Model

In [None]:
import torch

# Rebuild the training architecture on the available device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = build_encoder().to(device)
decoder_layers = build_unet(latent_channels=128, base_channels=64).to(device)

# Restore the best weights saved by the training loop (tensors mapped onto `device`).
checkpoint_path = '/dgxa_home/se22uari031/model_checkpoints/best_model.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder_layers.load_state_dict(checkpoint['decoder_state_dict'])

# Inference mode: BatchNorm uses its running statistics.
encoder.eval()
decoder_layers.eval()

Validation: Loss and PSNR




In [None]:
import torch.nn.functional as F
import math

# PSNR with a peak value of 1.0
def compute_psnr(img1, img2):
    """Peak signal-to-noise ratio in dB between two tensors (peak value 1.0).

    Returns the fixed value 100 when the tensors are identical (zero MSE),
    where PSNR would otherwise be infinite.
    """
    mse = F.mse_loss(img1, img2)
    return 100 if mse == 0 else (20 * torch.log10(1.0 / torch.sqrt(mse))).item()

# Validation Function
def validate_model(encoder, decoder_layers, val_image_paths, batch_size=4, crop_size=256, seed=0):
    """Report average reconstruction loss and PSNR over the validation images.

    The first ``(len(val_image_paths) // batch_size) * batch_size`` paths are
    each evaluated exactly once, in order; a trailing partial batch is skipped.
    When ``seed`` is not None, Python's ``random`` module is seeded with it for
    the duration of the call, so ``process_image``'s random crops are
    reproducible. The caller's ``random`` state is restored afterwards.

    Both models are evaluated in eval mode and left in train mode afterwards.
    Inputs are moved to the module-level ``device``.

    Returns:
        ``(avg_loss, avg_psnr)`` averaged over batches.

    Raises:
        ValueError: fewer than ``batch_size`` validation paths.
    """
    import random

    num_batches = len(val_image_paths) // batch_size
    if num_batches == 0:
        raise ValueError(
            f"need at least batch_size={batch_size} validation images, got {len(val_image_paths)}"
        )

    rng_state = random.getstate() if seed is not None else None
    if seed is not None:
        random.seed(seed)

    encoder.eval()
    decoder_layers.eval()

    total_loss = 0.0
    total_psnr = 0.0
    try:
        with torch.no_grad():
            for b in range(num_batches):
                paths = val_image_paths[b * batch_size:(b + 1) * batch_size]
                real_images = torch.stack([process_image(p, crop_size) for p in paths])
                real_images = real_images.to(device)

                latent, skips = encoder_forward(real_images, encoder)
                predicted_x0 = forward_unet(latent, skips, decoder_layers)

                total_loss += loss_function(predicted_x0, real_images).item()
                total_psnr += compute_psnr(predicted_x0, real_images)
    finally:
        if rng_state is not None:
            random.setstate(rng_state)
        encoder.train()
        decoder_layers.train()

    avg_loss = total_loss / num_batches
    avg_psnr = total_psnr / num_batches

    print(f"Validation - Average Loss: {avg_loss:.6f}, Average PSNR: {avg_psnr:.2f} dB")
    return avg_loss, avg_psnr


In [None]:
# Evaluate the loaded model on the DIV2K validation split (Google Drive copy).
val_image_paths = get_image_paths('/content/drive/MyDrive/datasets/DIV2K_valid_HR')
validate_model(encoder, decoder_layers, val_image_paths=val_image_paths, batch_size=4, crop_size=256)


Display Images

In [None]:
import matplotlib.pyplot as plt
import torch

# Reconstruct a batch of images with the (un-noised) encoder -> decoder path
def reconstruct_images(encoder, decoder_layers, real_images):
    """Return the decoder's reconstruction of ``real_images``.

    Switches both models to eval mode (they are left in eval mode) and runs
    without gradient tracking.
    """
    for module in (encoder, decoder_layers):
        module.eval()

    with torch.no_grad():
        latent, skips = encoder_forward(real_images, encoder)
        return forward_unet(latent, skips, decoder_layers)

# Show an original image next to its reconstruction
def display_images(real_images, reconstructed_images, idx=0):
    """Plot image ``idx`` from each batch side by side.

    real_images: batch of real images [B, 3, H, W]
    reconstructed_images: batch of predicted images [B, 3, H, W]
    idx: which image in the batch to display
    Values are clipped to [0, 1] for display.
    """
    def to_hwc(batch):
        # Pick one image, move it to CPU numpy, reorder C,H,W -> H,W,C and clip.
        return batch[idx].detach().cpu().numpy().transpose(1, 2, 0).clip(0, 1)

    panels = [
        (to_hwc(real_images), "Original Image"),
        (to_hwc(reconstructed_images), "Reconstructed Image"),
    ]

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for ax, (img, title) in zip(axs, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.show()


In [None]:

val_image_paths = get_image_paths('/content/drive/MyDrive/datasets/DIV2K_valid_HR')

# Reconstruct a batch of four validation samples and show the first three pairs.
real_images = create_batch(val_image_paths, batch_size=4, crop_size=256).to(device)
reconstructed_images = reconstruct_images(encoder, decoder_layers, real_images)

for idx in range(3):
    display_images(real_images, reconstructed_images, idx=idx)

Calculating Compression Ratio

In [1]:
# Pixel-based compression ratio: number of values per image before vs. after encoding.
# Re-encode `real_images` here so the latent is guaranteed to come from the same
# images. Otherwise `latent` is whatever global an earlier cell left behind,
# possibly from a different batch.
with torch.no_grad():
    latent = encoder_forward(real_images, encoder)[0]

# NOTE: forward_unet also consumes the skip tensors returned by encoder_forward;
# they are not counted in this ratio.
input_image_size = real_images[0].numel()
latent_size = latent[0].numel()

print(f"Input image values: {input_image_size}")
print(f"Latent values: {latent_size}")
compression_ratio = input_image_size / latent_size
print(f"Compression Ratio: {compression_ratio:.2f}x")


In [None]:
# Memory-based size helper
def tensor_size_bytes(tensor):
    """Return the number of bytes occupied by the elements of ``tensor``."""
    bytes_per_element = tensor.element_size()
    return tensor.numel() * bytes_per_element

# Compare per-image footprints so the ratio does not depend on batch size.
# `real_images` and the global `latent` can come from batches of different sizes
# (e.g. the display batch vs. a training or test batch).
input_size = tensor_size_bytes(real_images[0])
latent_size = tensor_size_bytes(latent[0])

compression_ratio = input_size / latent_size

print(f"Input image size (per image): {input_size / 1e6:.4f} MB")
print(f"Latent tensor size (per image): {latent_size / 1e6:.4f} MB")
print(f"Compression ratio: {compression_ratio:.2f}x")


Test on a Single Image

In [None]:
import torch
from torchvision.utils import save_image


# Rebuild the model on the available device and load the Drive copy of the best checkpoint.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = build_encoder().to(device)
decoder_layers = build_unet(latent_channels=128, base_channels=64).to(device)

checkpoint = torch.load('/content/drive/MyDrive/datasets/best_model.pth', map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder_layers.load_state_dict(checkpoint['decoder_state_dict'])

encoder.eval()
decoder_layers.eval()

# Single image to reconstruct.
test_image_path = '/content/sample_photos/FjikPptEbZg.jpg'

def process_image(image_path, crop_size=256):
    """Load ``image_path`` as RGB and resize it to ``crop_size`` x ``crop_size`` (no cropping).

    NOTE: this redefines the global ``process_image`` from the dataset cell, which
    random-crops large images. Once this cell has run, ``create_batch`` resolves
    ``process_image`` to this resize-only version.
    """
    from PIL import Image
    import torchvision.transforms as T

    pipeline = T.Compose([
        T.Resize((crop_size, crop_size)),
        T.ToTensor(),
    ])
    rgb_image = Image.open(image_path).convert('RGB')
    return pipeline(rgb_image)

# Load the test image as a batch of one.
real_image = process_image(test_image_path).unsqueeze(0).to(device)

# Inference only: disable autograd so no graph or intermediate activations are kept.
with torch.no_grad():
    latent1, skips = encoder_forward(real_image, encoder)
    reconstructed_image = forward_unet(latent1, skips, decoder_layers)

# Clamp to the displayable [0, 1] range before scoring and saving.
reconstructed_image = torch.clamp(reconstructed_image, 0, 1)


def compute_psnr(img1, img2):
    """Return the PSNR (dB) between ``img1`` and ``img2`` using a peak value of 1.0.

    Matches the notebook's other ``compute_psnr`` definitions: it returns a
    Python float, and identical inputs (MSE == 0) give the sentinel 100
    instead of ``inf``.
    """
    import torch.nn.functional as F
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        return 100
    return (20 * torch.log10(1.0 / torch.sqrt(mse))).item()

psnr_val = compute_psnr(reconstructed_image, real_image)
print(f"PSNR on {test_image_path}: {psnr_val:.2f} dB")

# Save to the working directory as reconstructed_<original file name>.
save_image(reconstructed_image, 'reconstructed_' + test_image_path.rsplit('/', 1)[-1])


Plot the Original and Reconstructed Image

In [None]:
import matplotlib.pyplot as plt

# Convert both (1, C, H, W) tensors to (H, W, C) numpy arrays for matplotlib.
real_image_np = real_image.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
reconstructed_np = reconstructed_image.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

# Side-by-side plot: original on the left, reconstruction on the right.
plt.figure(figsize=(10, 5))
for pos, (panel_img, panel_title) in enumerate(
    [(real_image_np, 'Original Image'), (reconstructed_np, 'Reconstructed Image')], start=1
):
    plt.subplot(1, 2, pos)
    plt.imshow(panel_img)
    plt.title(panel_title)
    plt.axis('off')

plt.show()


Load Images from a Folder and Reconstruct Them

In [None]:
# `transforms` was never imported in this notebook, so define it here.
import torchvision.transforms as transforms

# Resize every test image to the 256x256 training resolution and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # Normalized to [0, 1]
])

def load_test_images(folder_path, max_images=5):
    """Load up to ``max_images`` images from ``folder_path`` as a single batch.

    Files are matched by a case-insensitive ``.png`` / ``.jpg`` / ``.jpeg``
    extension. They are taken in sorted order so the selection is reproducible,
    since ``os.listdir`` order is arbitrary. Each image is converted to RGB and
    passed through the module-level ``transform``.

    Returns:
        ``(images, image_paths)``: the stacked tensor batch and the paths used.

    Raises:
        ValueError: no matching image files were found.
    """
    image_paths = sorted(
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )[:max_images]
    if not image_paths:
        raise ValueError(f"no .png/.jpg/.jpeg images found in {folder_path!r}")

    images = []
    for img_path in image_paths:
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(img_path) as img:
            images.append(transform(img.convert('RGB')))

    return torch.stack(images), image_paths

# Load up to five sample photos as one batch.
test_images, image_paths = load_test_images('/content/sample_photos')
print(f"Loaded {len(test_images)} test images")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = encoder.to(device).eval()
decoder_layers = decoder_layers.to(device).eval()

test_images = test_images.to(device)

# Inference only: disable autograd so no graph or intermediate activations are kept.
with torch.no_grad():
    latent, skips = encoder_forward(test_images, encoder)
    reconstructed_images = forward_unet(latent, skips, decoder_layers)



SSIM and MS-SSIM

In [None]:
def compute_ssim(img1, img2):
    """Return the SSIM score between two single images given as 3-D (C, H, W) tensors.

    Both tensors are moved to CPU numpy arrays in (H, W, C) layout, and ``ssim``
    is called with ``channel_axis=-1`` and ``data_range=1.0``.

    NOTE(review): ``ssim`` is not defined or imported anywhere in this file;
    presumably ``from skimage.metrics import structural_similarity as ssim``.
    Confirm and add the import, otherwise this raises NameError.
    """
    first, second = (img.permute(1, 2, 0).detach().cpu().numpy() for img in (img1, img2))
    return ssim(first, second, channel_axis=-1, data_range=1.0)


# Per-image SSIM and MS-SSIM between each test image and its reconstruction.
# NOTE(review): `ms_ssim` is not imported in this file (presumably pytorch_msssim.ms_ssim) — confirm.
for i in range(len(test_images)):
    original = test_images[i]
    recon = reconstructed_images[i]
    ssim_val = compute_ssim(original, recon)
    msssim_val = ms_ssim(original[None], recon[None], data_range=1.0).item()
    print(f"Image: {image_paths[i]} | SSIM: {ssim_val:.4f} | MS-SSIM: {msssim_val:.4f}")

Plot sample images with their SSIM and MS-SSIM scores

In [None]:
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Resize reconstructions to the originals' height/width before scoring.
reconstructed_images_resized = F.interpolate(
    reconstructed_images, size=test_images.shape[-2:], mode='bilinear', align_corners=False
)

# Only pair up images that exist in both batches.
num_images = min(test_images.shape[0], reconstructed_images_resized.shape[0])

plt.figure(figsize=(10, 4 * num_images))

for i in range(num_images):
    original = test_images[i]
    recon = reconstructed_images_resized[i]

    # Display copies: (H, W, C) numpy arrays clipped to [0, 1].
    orig_img = original.permute(1, 2, 0).detach().cpu().numpy().clip(0, 1)
    recon_img = recon.permute(1, 2, 0).detach().cpu().numpy().clip(0, 1)

    # Metrics are computed on the unclipped tensors.
    ssim_val = compute_ssim(original, recon)
    msssim_val = ms_ssim(original.unsqueeze(0), recon.unsqueeze(0), data_range=1.0).item()

    row_panels = (
        (orig_img, f"Original: {os.path.basename(image_paths[i])}"),
        (recon_img, f"Reconstructed\nSSIM: {ssim_val:.4f} | MS-SSIM: {msssim_val:.4f}"),
    )
    for column, (panel_img, caption) in enumerate(row_panels, start=1):
        plt.subplot(num_images, 2, i * 2 + column)
        plt.imshow(panel_img)
        plt.axis('off')
        plt.title(caption, fontsize=10)

plt.tight_layout()
plt.show()