# Run

In [None]:
# Install a pinned diffusers release.
!pip install diffusers==0.8.0
# Remove the JAX packages that are already present in the runtime.
# NOTE(review): presumably done to avoid an import conflict with this diffusers release — confirm.
!pip uninstall jax jaxlib -y

Collecting diffusers==0.8.0
  Downloading diffusers-0.8.0-py3-none-any.whl.metadata (31 kB)
Downloading diffusers-0.8.0-py3-none-any.whl (433 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m433.8/433.8 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: diffusers
Successfully installed diffusers-0.8.0
Found existing installation: jax 0.4.33
Uninstalling jax-0.4.33:
  Successfully uninstalled jax-0.4.33
Found existing installation: jaxlib 0.4.33
Uninstalling jaxlib-0.4.33:
  Successfully uninstalled jaxlib-0.4.33


In [None]:
# Mount Google Drive at /content/drive (requires the Google Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#Much better vae module
# Load the pretrained VAE "stabilityai/sd-vae-ft-mse" and move it to the GPU.
from diffusers import AutoencoderKL
# `device` and `vae` are notebook-wide globals: unquantize() and quantize_function()
# in later cells read them directly instead of taking them as arguments.
device = "cuda:0"
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
# Inference only: switch the module to evaluation mode.
vae.eval()
# Cast the weights to float16 (quantize_function casts its inputs with .half() to match).
# As the last expression of the cell, this call also prints the module structure.
vae.half()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  return torch.load(checkpoint_file, map_location="cpu")


AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

# Pixel-Level Self-Attention Model

In [None]:
# Install necessary libraries (if not already installed)
# NOTE(review): these packages are unpinned, and `diffusers` is listed again even though
# the first cell installs diffusers==0.8.0 — confirm this line leaves that pin in place.
!pip install ftfy pycocotools diffusers optuna

# Import Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.cuda.amp import autocast
from diffusers import AutoencoderKL
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.utils.checkpoint


class SelfAttention(nn.Module):
    """Self-attention across the spatial positions of a 4-D feature map.

    Queries and keys are 1x1-convolved down to ``in_dim // 8`` channels, values
    keep ``in_dim`` channels. The attended features are added back to the input
    through the learnable scalar ``gamma``, which starts at zero, so a freshly
    constructed layer returns its input unchanged.

    In evaluation mode (``self.training`` is False) attention is computed on a
    bilinearly downsampled copy of the input (by ``downsample_factor``) and the
    result is upsampled back to the input size; in training mode the full
    resolution is used.

    Args:
        in_dim: number of channels of the input feature map.
        downsample_factor: spatial reduction applied in evaluation mode only.
    """

    def __init__(self, in_dim, downsample_factor=2):
        super().__init__()
        self.in_dim = in_dim
        self.downsample_factor = downsample_factor

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        n, channels, dim2, dim3 = x.size()
        reduce_resolution = not self.training

        # Evaluation mode trades resolution for memory: attention runs on a smaller map.
        if reduce_resolution:
            feats = F.interpolate(
                x,
                scale_factor=1 / self.downsample_factor,
                mode='bilinear',
                align_corners=False,
            )
        else:
            feats = x
        feat_dim2, feat_dim3 = feats.size(2), feats.size(3)
        positions = feat_dim2 * feat_dim3

        # Flatten the spatial grid so every position attends to every other position.
        queries = self.query_conv(feats).view(n, -1, positions).transpose(1, 2)
        keys = self.key_conv(feats).view(n, -1, positions)
        values = self.value_conv(feats).view(n, -1, positions)

        weights = self.softmax(torch.bmm(queries, keys))
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(n, channels, feat_dim2, feat_dim3)

        if reduce_resolution:
            # Bring the attended features back to the input resolution.
            attended = F.interpolate(attended, size=(dim2, dim3), mode='bilinear', align_corners=False)

        return self.gamma * attended + x
class UNetWithParams(nn.Module):
    """U-Net whose input is conditioned on one scalar parameter per sample.

    The scalar is embedded to 16 channels by a linear layer, broadcast over the
    spatial grid and concatenated to the input image. The network has four
    encoder stages (2x max-pool between stages), a bottleneck, and four decoder
    stages (stride-2 transposed convolutions) with skip connections. Self
    attention is applied after the second encoder stage and after the third
    decoder stage.

    Because of the four pool/upsample pairs and the skip concatenations, the
    input height and width must be divisible by 16.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the predicted output.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(UNetWithParams, self).__init__()
        self.in_channels = in_channels

        # Scalar conditioning value -> 16-channel embedding.
        self.param_fc = nn.Linear(1, 16)

        self.enc1 = self.conv_block(in_channels + 16, 64)
        self.enc2 = self.conv_block(64, 128)
        self.att1 = SelfAttention(128, downsample_factor=2)  # Adjusted attention
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        self.bottleneck = self.conv_block(512, 1024)

        # Each decoder conv block takes twice the upsampled channels: upsampled + skip.
        self.dec4 = self.up_conv(1024, 512)
        self.dec4_conv = self.conv_block(1024, 512)

        self.dec3 = self.up_conv(512, 256)
        self.dec3_conv = self.conv_block(512, 256)
        self.att2 = SelfAttention(256, downsample_factor=2)  # Adjusted attention

        self.dec2 = self.up_conv(256, 128)
        self.dec2_conv = self.conv_block(256, 128)

        self.dec1 = self.up_conv(128, 64)
        self.dec1_conv = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return block

    def up_conv(self, in_channels, out_channels):
        """Return a transposed convolution that doubles the spatial resolution."""
        up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        return up

    def forward(self, x, params):
        """Run the network.

        Args:
            x: image batch of shape (B, in_channels, H, W), H and W divisible by 16.
            params: conditioning values whose last dimension is 1 — either one
                value per sample (B, 1) or a single value, e.g. shape (1,),
                which is then shared by every sample of the batch.

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        # Embed parameters
        param_embedding = self.param_fc(params)
        param_embedding = param_embedding.view(-1, 16, 1, 1)
        # Expand to the image batch size (not -1) so that a single conditioning
        # value also works with batches larger than one.
        param_embedding = param_embedding.expand(x.size(0), -1, x.size(2), x.size(3))
        x = torch.cat((x, param_embedding), dim=1)

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        e2 = self.att1(e2)
        e3 = self.enc3(F.max_pool2d(e2, 2))
        e4 = self.enc4(F.max_pool2d(e3, 2))

        # Bottleneck
        b = self.bottleneck(F.max_pool2d(e4, 2))

        # Decoder: upsample, concatenate the matching encoder output, then convolve.
        d4 = self.dec4(b)
        d4 = torch.cat((d4, e4), dim=1)
        d4 = self.dec4_conv(d4)

        d3 = self.dec3(d4)
        d3 = torch.cat((d3, e3), dim=1)
        d3 = self.dec3_conv(d3)
        d3 = self.att2(d3)

        d2 = self.dec2(d3)
        d2 = torch.cat((d2, e2), dim=1)
        d2 = self.dec2_conv(d2)

        d1 = self.dec1(d2)
        d1 = torch.cat((d1, e1), dim=1)
        d1 = self.dec1_conv(d1)

        out = self.final_conv(d1)
        return out




# Load Model

In [None]:
# Mount Google Drive so the checkpoint below is reachable.
from google.colab import drive
drive.mount('/content/drive')
import torch

#model=UNetWithConditioning()
model=UNetWithParams()
# Path to the saved checkpoint on Google Drive
model_path = '/content/drive/MyDrive/Models/PixelLevelRestorationVAECont2.1/unet_best_checkpoint.pth'

# Load the entire checkpoint (which contains more than just the model's state_dict).
# map_location="cpu" restores every tensor onto the CPU, where the freshly built
# model lives, so loading does not depend on the device the checkpoint was saved
# from. quantize_function() moves the model to `device` before inference.
checkpoint = torch.load(model_path, map_location="cpu")

# Extract only the model's state_dict from the checkpoint
model.load_state_dict(checkpoint['model_state_dict'])


# Set the model to evaluation mode for inference. SelfAttention behaves
# differently in this mode (it downsamples before computing attention).
model.eval()


# model2=UNetWithConditioningResidual()

# # Path to your saved model (assuming it's in Google Drive as previously discussed)
# model_path = '/content/drive/MyDrive/Models/Model_Train3.1/unet_final_checkpoint.pth'

# # Load the entire checkpoint (which contains more than just the model's state_dict)
# checkpoint = torch.load(model_path)

# # Extract only the model's state_dict from the checkpoint
# model2.load_state_dict(checkpoint['model_state_dict'])


# # Set the model to evaluation mode if you’re doing inference
# model2.eval()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  checkpoint = torch.load(model_path)


UNetWithParams(
  (param_fc): Linear(in_features=1, out_features=16, bias=True)
  (enc1): Sequential(
    (0): Conv2d(19, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (enc2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (att1): SelfAttention(
    (query_conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
    (key_conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
    (value_conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
    (softmax): Softmax(dim=-1)
  )
  (enc3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1),

# Quantize

## Load Files

In [None]:
def load_5125(image, left=0, right=0, top=0, bottom=0):
    """Crop an image, letterbox it onto a 768x768 black canvas and normalise it.

    Despite the "512" in the name, the target canvas is 768x768.

    Args:
        image: PIL image, (C, H, W) torch tensor, or (H, W, 3) NumPy array.
        left, right, top, bottom: number of pixels to remove from each border.

    Returns:
        float32 tensor of shape (3, 768, 768) with values in [-1, 1]. The
        padding is filled with black pixels, i.e. the value -1.

    Raises:
        ValueError: if the crop leaves no pixels.
    """
    # Convert image to NumPy array if necessary
    if isinstance(image, Image.Image):
        image = np.array(image)
    elif isinstance(image, torch.Tensor):
        image = image.permute(1, 2, 0).cpu().numpy()

    # Process image cropping
    h, w, _ = image.shape
    image = image[top:h - bottom, left:w - right]

    # Measure the image again after cropping: the scale must be derived from the
    # pixels that are actually resized, otherwise an asymmetric crop is stretched.
    h, w, _ = image.shape
    if h == 0 or w == 0:
        raise ValueError("Cropping removed the entire image.")

    # Compute new dimensions while maintaining aspect ratio
    target_size = (768, 768)  # width, height
    scale = min(target_size[0] / w, target_size[1] / h)
    # round() instead of int(): floating point error in w * scale must not leave
    # the limiting side one pixel short; clamp so a side is never 0 or oversized.
    new_w = min(target_size[0], max(1, round(w * scale)))
    new_h = min(target_size[1], max(1, round(h * scale)))

    # Resize the image
    image_resized = Image.fromarray(image.astype(np.uint8)).resize((new_w, new_h), Image.BICUBIC)
    image_resized = np.array(image_resized)

    # Create a new image and paste the resized image into its centre
    image_padded = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
    pad_left = (target_size[0] - new_w) // 2
    pad_top = (target_size[1] - new_h) // 2
    image_padded[pad_top:pad_top + new_h, pad_left:pad_left + new_w, :] = image_resized

    # Convert to tensor and normalize: [0, 255] -> [0, 1] -> [-1, 1], then HWC -> CHW
    image = torch.from_numpy(image_padded).float() / 255.0
    image = (image * 2.0) - 1.0
    image = image.permute(2, 0, 1)

    return image

def load_5123_CenterCropping(image, left=0, right=0, top=0, bottom=0):
    """Crop borders, centre-crop to a square, resize to 768x768 and normalise.

    Args:
        image: PIL image, (C, H, W) torch tensor, or (H, W, C) NumPy array.
        left, right, top, bottom: border widths to remove; each is clamped so
            that at least one pixel remains.

    Returns:
        float32 tensor of shape (C, 768, 768) with values in [-1, 1]. The
        tensor is left on the CPU; the caller moves it to a device.
    """
    if isinstance(image, Image.Image):
        pixels = np.array(image)
    elif isinstance(image, torch.Tensor):
        pixels = image.permute(1, 2, 0).cpu().numpy()
    else:
        pixels = image

    # Clamp the requested borders, then remove them.
    rows, cols, _ = pixels.shape
    left = min(left, cols - 1)
    right = min(right, cols - left - 1)
    top = min(top, rows - 1)
    bottom = min(bottom, rows - top - 1)
    pixels = pixels[top:rows - bottom, left:cols - right]

    # Trim the longer side symmetrically so the image becomes square.
    rows, cols, _ = pixels.shape
    if rows < cols:
        start = (cols - rows) // 2
        pixels = pixels[:, start:start + rows]
    elif cols < rows:
        start = (rows - cols) // 2
        pixels = pixels[start:start + cols]

    # Resize to 768x768 with Pillow's default resampling filter.
    resized = np.array(Image.fromarray(pixels.astype(np.uint8)).resize((768, 768)))

    # [0, 255] -> [0, 1] -> [-1, 1], then HWC -> CHW.
    tensor = torch.from_numpy(resized).float() / 255.0
    tensor = (tensor * 2.0) - 1.0
    return tensor.permute(2, 0, 1)

import numpy as np
import torch
from PIL import Image



def simple_load(img, target_size=(512, 512)):
    """Resize an image to ``target_size`` (aspect ratio ignored) and normalise it.

    Args:
        img: PIL image, or an array accepted by ``Image.fromarray``.
        target_size: (width, height) passed to ``Image.resize``.

    Returns:
        Tuple ``(image_tensor, img_resized)``: the float32 (C, H, W) tensor with
        values in [-1, 1] and the resized PIL image it was built from.
    """
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)

    # Direct resize with Lanczos resampling; the image is stretched if needed.
    img_resized = pil_img.resize(target_size, Image.LANCZOS)

    # [0, 255] -> [0, 1] -> [-1, 1], then HWC -> CHW.
    pixels = np.array(img_resized)
    image_tensor = torch.from_numpy(pixels).float() / 255.0
    image_tensor = ((image_tensor * 2.0) - 1.0).permute(2, 0, 1)

    return image_tensor, img_resized






## Quantization Function

In [None]:
from utils import image_grid
!pip install brotli
import brotli
import zlib
import numpy as np
import torch
from PIL import Image
from torch.cuda.amp import autocast

def quantize(latents, parameter=0.58215):
    """Quantise the first latent of a batch to 8-bit integers.

    Values are divided by ``255 * parameter``, shifted by 0.5 and clamped to
    [0, 1] before being mapped to 0..255, so a larger ``parameter`` gives a
    coarser quantisation. Only batch element 0 is returned.

    Args:
        latents: tensor of shape (B, C, H, W).
        parameter: quantisation scale.

    Returns:
        uint8 NumPy array of shape (H, W, C).
    """
    step = 255 * parameter
    unit_range = (latents / step + 0.5).clamp(0, 1)
    # Take the first batch element and move channels last: (C, H, W) -> (H, W, C).
    first = unit_range.detach().cpu()[0].permute(1, 2, 0).numpy()
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    return (first * 255.0 + 0.5).astype(np.uint8)

def unquantize(quantized, parameter=0.58215):
    """Invert ``quantize``: map an 8-bit (H, W, C) array back to latent values.

    Args:
        quantized: NumPy array of shape (H, W, C) holding values in 0..255.
        parameter: the quantisation scale that was used to produce ``quantized``.

    Returns:
        float32 tensor of shape (1, C, H, W) on the notebook-global ``device``.
    """
    unit_range = quantized.astype(np.float32) / 255.0
    # Add the batch axis and move channels first: (H, W, C) -> (1, C, H, W).
    batched = np.expand_dims(unit_range, 0).transpose(0, 3, 1, 2)
    latents = (batched - 0.5) * (255 * parameter)
    return torch.from_numpy(latents).to(device)

def quantize_function(image_path, model, parameter = 0.58215, target_bpp=0.15, tolerance=0.005, max_iterations=20):
    """Compress an image to roughly ``target_bpp`` and return its restoration.

    Steps:
      1. Load the image with ``load_5123_CenterCropping`` and encode it with the
         notebook-global ``vae`` (latents are multiplied by 0.18215).
      2. Search for the quantisation ``parameter`` whose Brotli-compressed,
         8-bit latents come closest to ``target_bpp`` bits per pixel.
      3. Write the best payload to ``compressed_latents.bin`` and its
         dtype/shape to ``latents_metadata.npz`` (both in the working
         directory), then read them back so decoding uses only stored data.
      4. Decode with the VAE and add the residual predicted by ``model``, which
         is conditioned on the selected parameter.

    Relies on the globals ``vae``, ``device``, ``brotli`` and ``image_grid``.

    Args:
        image_path: path of the image to compress.
        model: restoration network called as ``model(decoded_image, noise_level)``.
        parameter: initial quantisation scale.
        target_bpp: desired bits per pixel of the compressed latents.
        tolerance: search stops once ``|bpp - target_bpp| <= tolerance``.
        max_iterations: maximum number of search steps (at least 1).

    Returns:
        The result of ``image_grid`` applied to the restored image batch.

    Raises:
        ValueError: if ``max_iterations`` is smaller than 1.
    """
    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1.")

    with Image.open(image_path) as source:
        img = source.convert('RGB')
    x0 = load_5123_CenterCropping(img)  # Load and preprocess image
    x0 = x0.half().to(device)  # The VAE weights are half precision
    x0 = x0.unsqueeze(0)

    # Pixel count of the image that is actually encoded (height * width).
    image_pixels = x0.shape[-2] * x0.shape[-1]

    # no_grad: encoding is inference only, no autograd graph is needed.
    with torch.no_grad(), autocast(enabled=True):
        w0 = (vae.encode(x0).latent_dist.mode() * 0.18215).float()

    best_parameter = parameter
    best_bpp = float('inf')
    best_quantized = None
    best_compressed = None

    # Bracket around the target. A larger parameter quantises more coarsely and
    # is expected to lower the bpp:
    #   too_fine   = largest tried parameter whose bpp was above the target
    #   too_coarse = smallest tried parameter whose bpp was below the target
    too_fine = None
    too_coarse = None

    for iteration in range(max_iterations):
        quantized = quantize(w0, parameter)

        # Compress the raw bytes of the 8-bit latents with Brotli
        compressed_bytes = brotli.compress(quantized.tobytes(), quality=11)

        # Calculate bpp: compressed bits divided by number of image pixels
        bpp = (len(compressed_bytes) * 8) / image_pixels

        print(f"Iteration {iteration + 1}: Parameter = {parameter}, Brotli Compressed bpp = {bpp:.6f}")

        # Check if this is the closest bpp we've achieved
        if abs(bpp - target_bpp) < abs(best_bpp - target_bpp):
            best_bpp = bpp
            best_parameter = parameter
            best_quantized = quantized
            best_compressed = compressed_bytes

        # Check if we've reached the target bpp within the tolerance
        if abs(bpp - target_bpp) <= tolerance:
            print("Target bpp reached!")
            break

        if bpp > target_bpp:
            too_fine = parameter if too_fine is None else max(too_fine, parameter)
        else:
            too_coarse = parameter if too_coarse is None else min(too_coarse, parameter)

        if too_fine is not None and too_coarse is not None and too_fine < too_coarse:
            # The target is bracketed: take the geometric midpoint. A fixed 10%
            # step would keep jumping back and forth across the target here.
            parameter = (too_fine * too_coarse) ** 0.5
        elif bpp > target_bpp:
            parameter *= 1.1  # Increase parameter to lower bpp
        else:
            parameter *= 0.9  # Decrease parameter to raise bpp

    # Save the best quantized latents (payload compressed during the search)
    latents_numpy = best_quantized
    latents_dtype = latents_numpy.dtype
    latents_shape = latents_numpy.shape

    with open("compressed_latents.bin", "wb") as f:
        f.write(best_compressed)

    # Save the dtype and shape for later use
    np.savez("latents_metadata.npz", dtype=latents_dtype.name, shape=latents_shape)

    # Load the compressed data and metadata back from disk
    with open("compressed_latents.bin", "rb") as f:
        compressed_latents = f.read()

    with np.load("latents_metadata.npz") as metadata:
        latents_dtype_name = metadata['dtype'].item()
        latents_shape = tuple(metadata['shape'])
    latents_dtype = np.dtype(latents_dtype_name)

    # Decompress the data
    decompressed_bytes = brotli.decompress(compressed_latents)
    decompressed_numpy = np.frombuffer(decompressed_bytes, dtype=latents_dtype)
    decompressed_numpy = decompressed_numpy.reshape(latents_shape)

    latents1 = unquantize(decompressed_numpy, best_parameter)
    model = model.to(device)
    # The restoration model is conditioned on the parameter that was used.
    noise_level = torch.tensor([best_parameter], dtype=torch.float32).to(device)  # Single value for noise level
    latents1 = latents1.half()

    with torch.no_grad():
        # Decode the unquantized latents to images (undo the 0.18215 scaling)
        decoded_unquantized_images = vae.decode(1 / 0.18215 * (latents1)).sample

        # Model prediction
        residual_tensor = model(decoded_unquantized_images, noise_level)

        # Reconstructed images
        x0_dec = decoded_unquantized_images + residual_tensor

    img = image_grid(x0_dec)  # Create an image grid for the decoded image

    return img




# PSNR

In [None]:
from skimage.metrics import structural_similarity as get_ssim
from skimage.metrics import peak_signal_noise_ratio as get_psnr
from PIL import Image

def psnr_function(gt, img):
    """Compute PSNR and SSIM between a reference image and a processed image.

    Args:
        gt: reference image (anything ``np.array`` turns into an (H, W, 3) array).
        img: processed image with the same layout.

    Returns:
        ``(psnr, ssim)`` on success. ``None`` if an input is not a 3-channel
        image of at least 7x7 pixels, or if a metric raises; a message is
        printed in both cases, so callers must check for ``None``.
    """
    gt_array = np.array(gt)
    img_array = np.array(img)

    # Debugging dimensions and channels
    print(f"GT Image Shape: {gt_array.shape}, Processed Image Shape: {img_array.shape}")

    # Check dimensions and channels
    if gt_array.ndim != 3 or img_array.ndim != 3:
        print("Error: Images must have three dimensions (height, width, channels).")
        return None
    # The SSIM window below is 7x7, so both images need at least 7 pixels per side.
    if gt_array.shape[0] < 7 or gt_array.shape[1] < 7 or img_array.shape[0] < 7 or img_array.shape[1] < 7:
        print("Warning: Images are too small for SSIM calculation. Ensure images are at least 7x7 pixels.")
        return None

    # Ensure images have three channels
    if gt_array.shape[2] != 3 or img_array.shape[2] != 3:
        print("Error: Images must have three channels.")
        return None

    # SSIM's data_range is the distance between the minimum and maximum
    # *possible* values. For 8-bit images that is 255, independent of the
    # values a particular image happens to contain (a flat image would
    # otherwise yield a range of 0). Other dtypes keep the observed range.
    if gt_array.dtype == np.uint8 and img_array.dtype == np.uint8:
        data_range = 255
    else:
        data_range = img_array.max() - img_array.min()

    # Calculate metrics with explicitly set parameters
    try:
        psnr_value = get_psnr(gt_array, img_array)
        ssim_value = get_ssim(gt_array, img_array, win_size=7, multichannel=True, channel_axis=2, data_range=data_range)
        return psnr_value, ssim_value

    except Exception as e:
        # Best effort: report the problem and signal failure with None.
        print("Failed to calculate metrics:", str(e))
        return None


# Average

In [None]:
import os
from PIL import Image
import numpy as np

def load_with_aspect_ratio2(img, target_size=(512, 512)):
    """Letterbox an image onto a black canvas of ``target_size``.

    The image is scaled to fit inside the canvas while keeping its aspect ratio
    and is pasted in the centre.

    Args:
        img: PIL image or (H, W, 3) uint8 array.
        target_size: (width, height) of the canvas.

    Returns:
        PIL image of size ``target_size``.
    """
    image = np.array(img)
    h, w, _ = image.shape
    target_w, target_h = target_size

    # Scaling factor: width against width, height against height. (Mixing the
    # axes only happens to work when the target is square.)
    scale = min(target_w / w, target_h / h)
    # round() instead of int(): floating point error in w * scale must not leave
    # the limiting side one pixel short; clamp so a side is never 0 or oversized.
    new_w = min(target_w, max(1, round(w * scale)))
    new_h = min(target_h, max(1, round(h * scale)))

    # Resize the image
    image = Image.fromarray(image).resize((new_w, new_h), Image.LANCZOS)
    image = np.array(image)

    # Create a new image with the target size and paste the resized image
    new_image = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    offset_y = (target_h - new_h) // 2
    offset_x = (target_w - new_w) // 2
    new_image[offset_y:offset_y + new_h, offset_x:offset_x + new_w] = image

    # Convert the padded array back to a PIL image
    image = Image.fromarray(new_image)

    return image

# Custom Preprocessing Function
def load_5123_CenterCropping2(image_path, left=0, right=0, top=0, bottom=0):
    """Open an image file, crop borders, centre-crop to a square and resize to 768x768.

    Args:
        image_path: path of the image file.
        left, right, top, bottom: border widths to remove; each is clamped so
            that at least one pixel remains.

    Returns:
        PIL RGB image of size 768x768.
    """
    # The file is always opened here, so the input is always a PIL image; the
    # context manager closes the file handle as soon as the pixels are copied.
    with Image.open(image_path) as source:
        image = np.array(source.convert('RGB'))

    # Process image cropping
    h, w, _ = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Centre-crop the longer side so the image becomes square
    h, w, _ = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]

    # Resize to 768x768
    image = Image.fromarray(image.astype(np.uint8)).resize((768, 768))

    return image


def load_5122(image_path, left=0, right=0, top=0, bottom=0):
    """Crop an image and letterbox it onto a 768x768 black canvas.

    Despite the "512" in the name, the target canvas is 768x768.

    Args:
        image_path: file path (str) or PIL image.
        left, right, top, bottom: border widths to remove; each is clamped so
            that at least one pixel remains.

    Returns:
        PIL RGB image of size 768x768 with the resized input centred on it.

    Raises:
        ValueError: if ``image_path`` is neither a str nor a PIL image.
    """
    # Load image
    if isinstance(image_path, str):
        with Image.open(image_path) as source:
            image = source.convert('RGB')
    elif isinstance(image_path, Image.Image):
        image = image_path
    else:
        raise ValueError("Input must be a file path or PIL Image.")

    # Convert to NumPy array
    image = np.array(image)

    # Process image cropping
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Measure the image again after cropping: the scale must be derived from the
    # pixels that are actually resized, otherwise an asymmetric crop is stretched.
    h, w, c = image.shape

    # Compute new dimensions while maintaining aspect ratio
    target_size = (768, 768)  # width, height
    scale = min(target_size[0] / w, target_size[1] / h)
    # round() instead of int(): floating point error in w * scale must not leave
    # the limiting side one pixel short; clamp so a side is never 0 or oversized.
    new_w = min(target_size[0], max(1, round(w * scale)))
    new_h = min(target_size[1], max(1, round(h * scale)))

    # Resize the image
    image_resized = Image.fromarray(image.astype(np.uint8)).resize((new_w, new_h), Image.BICUBIC)

    # Create a new image and paste the resized image into its centre
    image_padded = Image.new('RGB', target_size, (0, 0, 0))  # You can change the fill color if needed
    pad_left = (target_size[0] - new_w) // 2
    pad_top = (target_size[1] - new_h) // 2
    image_padded.paste(image_resized, (pad_left, pad_top))

    return image_padded  # Returns a PIL Image object



# Custom Preprocessing Function
def load_5124(image, left=0, right=0, top=0, bottom=0):
    """Crop borders, centre-crop to a square and resize to 768x768.

    Args:
        image: PIL image or array with at least three channels; only the first
            three channels are kept.
        left, right, top, bottom: border widths to remove; each is clamped so
            that at least one pixel remains.

    Returns:
        PIL image of size 768x768.
    """
    # Keep the first three channels only (drops e.g. an alpha channel).
    image = np.array(image)[:, :, :3]

    h, w, c = image.shape
    left = min(left, w-1)
    right = min(right, w - left - 1)
    # Clamp the top border against the image height. It was clamped against
    # h - left - 1 before, which tied the vertical crop to the left border.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h-bottom, left:w-right]

    # Centre-crop the longer side so the image becomes square
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = Image.fromarray(image).resize((768, 768))
    return image

def calculate_average_psnr(folder_path):
    """
    Compress every image in a folder and report the average PSNR and SSIM.

    Each image file is loaded with ``load_5123_CenterCropping2`` as reference,
    compressed and restored with ``quantize_function`` (using the global
    ``model``, target_bpp=0.25, tolerance=0.0005, max_iterations=20) and scored
    with ``psnr_function``. The last scored reference/restored pair is
    displayed and both averages are printed.

    Parameters:
    folder_path (str): The path to the folder containing the images.

    Returns:
    float: The average PSNR across all scored images, or 0.0 if none was scored.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    psnr_sum = 0
    ssim_sum = 0
    image_count = 0
    last_pair = None

    # Sorted for a reproducible processing order; the extension check ignores case.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(image_extensions):
            continue

        image_path = os.path.join(folder_path, filename)
        original_image = load_5123_CenterCropping2(image_path)

        # Quantize the image
        quantized_image = quantize_function(image_path, model, target_bpp=0.25, tolerance=0.0005, max_iterations=20)
        #quantized_image=load_5124(quantized_image)

        # psnr_function returns None when the metrics cannot be computed.
        metrics = psnr_function(original_image, quantized_image)
        if metrics is None:
            print(f"Skipping {filename}: metrics could not be computed.")
            continue

        psnr_value, ssim_value = metrics
        psnr_sum += psnr_value
        ssim_sum += ssim_value
        image_count += 1
        last_pair = (original_image, quantized_image)

    if image_count == 0:
        return 0.0  # Avoid division by zero if no images were scored

    # Show the last scored pair (only once at least one image exists).
    display(last_pair[0])
    display(last_pair[1])

    average_psnr = psnr_sum / image_count
    average_ssim = ssim_sum / image_count
    print(f"Average PSNR: {average_psnr}")
    print(f"Average SSIM: {average_ssim}")
    return average_psnr
# Run the evaluation on the image files located directly in /content/.
calculate_average_psnr("/content/")

  with autocast(enabled=True):


Iteration 1: Parameter = 0.58215, Brotli Compressed bpp = 0.147664
Iteration 2: Parameter = 0.5239349999999999, Brotli Compressed bpp = 0.156250
Iteration 3: Parameter = 0.47154149999999995, Brotli Compressed bpp = 0.166368
Iteration 4: Parameter = 0.42438734999999994, Brotli Compressed bpp = 0.175944
Iteration 5: Parameter = 0.38194861499999994, Brotli Compressed bpp = 0.185371
Iteration 6: Parameter = 0.3437537534999999, Brotli Compressed bpp = 0.194906
Iteration 7: Parameter = 0.30937837814999997, Brotli Compressed bpp = 0.204983
Iteration 8: Parameter = 0.278440540335, Brotli Compressed bpp = 0.214966
Iteration 9: Parameter = 0.2505964863015, Brotli Compressed bpp = 0.225532
Iteration 10: Parameter = 0.22553683767135, Brotli Compressed bpp = 0.236382
Iteration 11: Parameter = 0.202983153904215, Brotli Compressed bpp = 0.246908
Iteration 12: Parameter = 0.18268483851379352, Brotli Compressed bpp = 0.256022
Iteration 13: Parameter = 0.20095332236517288, Brotli Compressed bpp = 0.2476