In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# !pip install torch torchvision visdom mlflow scikit-image


In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
from skimage.metrics import peak_signal_noise_ratio, structural_similarity, mean_squared_error
from tqdm import tqdm

In [4]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
# Custom Dataset Class
class CustomImageDataset(Dataset):
    def __init__(self, image_dir_raw, image_dir_ref, transform=None):
        self.image_dir_raw = image_dir_raw
        self.image_dir_ref = image_dir_ref
        self.transform = transform
        self.raw_image_filenames = sorted([f for f in os.listdir(image_dir_raw) if os.path.isfile(os.path.join(image_dir_raw, f))])
        self.ref_image_filenames = sorted([f for f in os.listdir(image_dir_ref) if os.path.isfile(os.path.join(image_dir_ref, f))])
        
    def __len__(self):
        return len(self.raw_image_filenames)

    def __getitem__(self, idx):
        raw_img_path = os.path.join(self.image_dir_raw, self.raw_image_filenames[idx])
        ref_img_path = os.path.join(self.image_dir_ref, self.ref_image_filenames[idx])
        
        raw_image = Image.open(raw_img_path).convert("RGB")
        ref_image = Image.open(ref_img_path).convert("RGB")
        
        if self.transform:
            raw_image = self.transform(raw_image)
            ref_image = self.transform(ref_image)
            
        return raw_image, ref_image

In [6]:
# Define the Variational AutoEncoder Model
class VariationalAutoEncoder(nn.Module):
    def __init__(self, features_d):
        super(VariationalAutoEncoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            self._enc_block(3, features_d, 3, 2, 1),
            self._enc_block(features_d, features_d * 2, 3, 2, 1),
            self._enc_block(features_d * 2, features_d * 4, 3, 2, 1),
        )
        # Latent space
        self.conv_mu = nn.Conv2d(features_d * 4, features_d * 4, kernel_size=3, stride=1, padding=1)
        self.conv_logvar = nn.Conv2d(features_d * 4, features_d * 4, kernel_size=3, stride=1, padding=1)
        # Decoder
        self.decoder = nn.Sequential(
            self._dec_block(features_d * 4, features_d * 2, 3, 1, 1),
            self._dec_block(features_d * 2, features_d, 3, 1, 1),
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(features_d, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def _enc_block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _dec_block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, imgs):
        x = self.encoder(imgs)
        mu = self.conv_mu(x)
        logvar = self.conv_logvar(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_imgs = self.decoder(z)
        return reconstructed_imgs


In [7]:
# Initialize the model
model = VariationalAutoEncoder(features_d=64).to(device)

In [8]:
# Load Dataset
train_dir = '/kaggle/input/wec-task-2/Train'
test_dir = '/kaggle/input/wec-task-2/Test'

train_raw_dir = os.path.join(train_dir, 'Raw')
train_ref_dir = os.path.join(train_dir, 'Reference')

test_raw_dir = os.path.join(test_dir, 'Raw')
test_ref_dir = os.path.join(test_dir, 'Reference')

In [9]:
import torchvision.transforms as transforms

# Define the transform with resizing to (480, 640)
transform = transforms.Compose([
    transforms.Resize((480, 640)),  # Resize to fixed size
    transforms.ToTensor()           # Convert image to tensor
])

In [10]:
train_dataset = CustomImageDataset(image_dir_raw=train_raw_dir, image_dir_ref=train_ref_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

test_dataset = CustomImageDataset(image_dir_raw=test_raw_dir, image_dir_ref=test_ref_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [11]:
# Training Function
def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for noisy, clean in tqdm(train_loader):
            noisy, clean = noisy.to(device), clean.to(device)
            optimizer.zero_grad()
            reconstructed = model(noisy)
            loss = criterion(reconstructed, clean)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for noisy, clean in test_loader:
                noisy, clean = noisy.to(device), clean.to(device)
                reconstructed = model(noisy)
                loss = criterion(reconstructed, clean)
                test_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader)}, Test Loss: {test_loss/len(test_loader)}")


In [12]:
# Start training
train_model(model, train_loader, test_loader, epochs=100, lr=0.001)

100%|██████████| 88/88 [01:16<00:00,  1.15it/s]


Epoch [1/100], Train Loss: 0.03956297490830449, Test Loss: 0.03996192322423061


100%|██████████| 88/88 [01:00<00:00,  1.45it/s]


Epoch [2/100], Train Loss: 0.027369645339521496, Test Loss: 0.02728709625080228


100%|██████████| 88/88 [01:01<00:00,  1.42it/s]


Epoch [3/100], Train Loss: 0.023678678024390883, Test Loss: 0.029188112743819754


100%|██████████| 88/88 [01:01<00:00,  1.43it/s]


Epoch [4/100], Train Loss: 0.023284495861100204, Test Loss: 0.027300747848736744


100%|██████████| 88/88 [01:01<00:00,  1.43it/s]


Epoch [5/100], Train Loss: 0.02201411525972865, Test Loss: 0.026588474633172154


100%|██████████| 88/88 [01:00<00:00,  1.46it/s]


Epoch [6/100], Train Loss: 0.021266929688863456, Test Loss: 0.024316439405083656


100%|██████████| 88/88 [00:59<00:00,  1.47it/s]


Epoch [7/100], Train Loss: 0.020714619450948456, Test Loss: 0.023294739774428308


100%|██████████| 88/88 [00:59<00:00,  1.47it/s]


Epoch [8/100], Train Loss: 0.020652526650916447, Test Loss: 0.023287193849682808


100%|██████████| 88/88 [01:00<00:00,  1.46it/s]


Epoch [9/100], Train Loss: 0.020782027393579483, Test Loss: 0.02476969159518679


100%|██████████| 88/88 [00:59<00:00,  1.47it/s]


Epoch [10/100], Train Loss: 0.01998770177703012, Test Loss: 0.022542818255412083


100%|██████████| 88/88 [01:00<00:00,  1.46it/s]


Epoch [11/100], Train Loss: 0.019751833275553177, Test Loss: 0.02496666496153921


100%|██████████| 88/88 [00:59<00:00,  1.49it/s]


Epoch [12/100], Train Loss: 0.019598286736502567, Test Loss: 0.022656151986060042


100%|██████████| 88/88 [00:59<00:00,  1.49it/s]


Epoch [13/100], Train Loss: 0.019188034845600752, Test Loss: 0.022497140647222597


100%|██████████| 88/88 [00:58<00:00,  1.51it/s]


Epoch [14/100], Train Loss: 0.018839169018478555, Test Loss: 0.023266185230265062


100%|██████████| 88/88 [00:58<00:00,  1.50it/s]


Epoch [15/100], Train Loss: 0.018709992704151027, Test Loss: 0.022608997921148937


100%|██████████| 88/88 [00:57<00:00,  1.53it/s]


Epoch [16/100], Train Loss: 0.018735185307873922, Test Loss: 0.02294261734156559


100%|██████████| 88/88 [00:56<00:00,  1.54it/s]


Epoch [17/100], Train Loss: 0.018361072373491796, Test Loss: 0.022807437771310408


100%|██████████| 88/88 [00:56<00:00,  1.55it/s]


Epoch [18/100], Train Loss: 0.018803989091380077, Test Loss: 0.0216476006899029


100%|██████████| 88/88 [00:56<00:00,  1.56it/s]


Epoch [19/100], Train Loss: 0.018583242940208453, Test Loss: 0.02210529532749206


100%|██████████| 88/88 [00:57<00:00,  1.54it/s]


Epoch [20/100], Train Loss: 0.01814706283833154, Test Loss: 0.021839068504050374


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [21/100], Train Loss: 0.018598697740923275, Test Loss: 0.021738841198384762


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [22/100], Train Loss: 0.01813826826401055, Test Loss: 0.019897270714864135


100%|██████████| 88/88 [00:56<00:00,  1.55it/s]


Epoch [23/100], Train Loss: 0.018332867438650945, Test Loss: 0.02218797170401861


100%|██████████| 88/88 [00:56<00:00,  1.57it/s]


Epoch [24/100], Train Loss: 0.01801005678928711, Test Loss: 0.023657416478575517


100%|██████████| 88/88 [00:56<00:00,  1.56it/s]


Epoch [25/100], Train Loss: 0.017307989606210453, Test Loss: 0.02255436551058665


100%|██████████| 88/88 [00:58<00:00,  1.51it/s]


Epoch [26/100], Train Loss: 0.01763700352008031, Test Loss: 0.020460988045670092


100%|██████████| 88/88 [00:57<00:00,  1.53it/s]


Epoch [27/100], Train Loss: 0.018118651848371057, Test Loss: 0.02030041239534815


100%|██████████| 88/88 [00:56<00:00,  1.56it/s]


Epoch [28/100], Train Loss: 0.017381187581287868, Test Loss: 0.021464296073342364


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [29/100], Train Loss: 0.017864830114624718, Test Loss: 0.020156571292318404


100%|██████████| 88/88 [00:57<00:00,  1.54it/s]


Epoch [30/100], Train Loss: 0.017563206302425402, Test Loss: 0.02149519188484798


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [31/100], Train Loss: 0.01698417503344403, Test Loss: 0.02071305050048977


100%|██████████| 88/88 [00:56<00:00,  1.57it/s]


Epoch [32/100], Train Loss: 0.01719796198250895, Test Loss: 0.02276177133899182


100%|██████████| 88/88 [00:55<00:00,  1.59it/s]


Epoch [33/100], Train Loss: 0.017565026561814277, Test Loss: 0.02220271702390164


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [34/100], Train Loss: 0.017391937657852064, Test Loss: 0.021002098743338138


100%|██████████| 88/88 [00:55<00:00,  1.59it/s]


Epoch [35/100], Train Loss: 0.016833311687646943, Test Loss: 0.02069986592202137


100%|██████████| 88/88 [00:58<00:00,  1.52it/s]


Epoch [36/100], Train Loss: 0.01712602845774117, Test Loss: 0.020731750254829723


100%|██████████| 88/88 [00:56<00:00,  1.55it/s]


Epoch [37/100], Train Loss: 0.01688800062137571, Test Loss: 0.02322856105941658


100%|██████████| 88/88 [00:56<00:00,  1.57it/s]


Epoch [38/100], Train Loss: 0.017007206661880693, Test Loss: 0.022048520389944315


100%|██████████| 88/88 [00:55<00:00,  1.59it/s]


Epoch [39/100], Train Loss: 0.01669989761219106, Test Loss: 0.02073410855761419


100%|██████████| 88/88 [00:55<00:00,  1.57it/s]


Epoch [40/100], Train Loss: 0.01681119128426706, Test Loss: 0.01976752238503347


100%|██████████| 88/88 [00:58<00:00,  1.52it/s]


Epoch [41/100], Train Loss: 0.016651157099245625, Test Loss: 0.021425598417408764


100%|██████████| 88/88 [00:59<00:00,  1.49it/s]


Epoch [42/100], Train Loss: 0.0165989286350933, Test Loss: 0.0239929566741921


100%|██████████| 88/88 [00:59<00:00,  1.48it/s]


Epoch [43/100], Train Loss: 0.017066187468696047, Test Loss: 0.02063224813900888


100%|██████████| 88/88 [00:59<00:00,  1.47it/s]


Epoch [44/100], Train Loss: 0.016667891113849528, Test Loss: 0.019903375961196918


100%|██████████| 88/88 [00:59<00:00,  1.48it/s]


Epoch [45/100], Train Loss: 0.016284595561129125, Test Loss: 0.021170393874247868


100%|██████████| 88/88 [00:59<00:00,  1.48it/s]


Epoch [46/100], Train Loss: 0.016102558467537165, Test Loss: 0.02154820431799938


100%|██████████| 88/88 [00:58<00:00,  1.50it/s]


Epoch [47/100], Train Loss: 0.016394910698925905, Test Loss: 0.02071932062972337


100%|██████████| 88/88 [00:58<00:00,  1.49it/s]


Epoch [48/100], Train Loss: 0.016847573114897717, Test Loss: 0.01994247210677713


100%|██████████| 88/88 [00:58<00:00,  1.50it/s]


Epoch [49/100], Train Loss: 0.016051062615588307, Test Loss: 0.019981655292212963


100%|██████████| 88/88 [00:57<00:00,  1.52it/s]


Epoch [50/100], Train Loss: 0.01620407891459763, Test Loss: 0.019856491145522643


100%|██████████| 88/88 [00:59<00:00,  1.48it/s]


Epoch [51/100], Train Loss: 0.016433994804339654, Test Loss: 0.021900242116923135


100%|██████████| 88/88 [01:00<00:00,  1.46it/s]


Epoch [52/100], Train Loss: 0.01615870208479464, Test Loss: 0.027649190626107156


100%|██████████| 88/88 [00:59<00:00,  1.47it/s]


Epoch [53/100], Train Loss: 0.016457242594862528, Test Loss: 0.019596230161065858


100%|██████████| 88/88 [00:57<00:00,  1.52it/s]


Epoch [54/100], Train Loss: 0.016109734858301552, Test Loss: 0.021581739963342745


100%|██████████| 88/88 [00:59<00:00,  1.49it/s]


Epoch [55/100], Train Loss: 0.015597341444597325, Test Loss: 0.02127932101332893


100%|██████████| 88/88 [00:58<00:00,  1.51it/s]


Epoch [56/100], Train Loss: 0.016030477693262088, Test Loss: 0.021169987992228318


100%|██████████| 88/88 [01:04<00:00,  1.36it/s]


Epoch [57/100], Train Loss: 0.015875669704242187, Test Loss: 0.022121607791632414


100%|██████████| 88/88 [01:00<00:00,  1.46it/s]


Epoch [58/100], Train Loss: 0.016057868932627818, Test Loss: 0.01922401797492057


100%|██████████| 88/88 [00:57<00:00,  1.53it/s]


Epoch [59/100], Train Loss: 0.016418840799650006, Test Loss: 0.020417111130276073


100%|██████████| 88/88 [00:57<00:00,  1.52it/s]


Epoch [60/100], Train Loss: 0.015622378580949524, Test Loss: 0.02130759177574267


100%|██████████| 88/88 [00:57<00:00,  1.52it/s]


Epoch [61/100], Train Loss: 0.015971953472630543, Test Loss: 0.0200077046174556


100%|██████████| 88/88 [00:58<00:00,  1.51it/s]


Epoch [62/100], Train Loss: 0.015846851916814394, Test Loss: 0.02074914313076685


100%|██████████| 88/88 [00:58<00:00,  1.49it/s]


Epoch [63/100], Train Loss: 0.015720657475123353, Test Loss: 0.019319056494471926


100%|██████████| 88/88 [00:58<00:00,  1.50it/s]


Epoch [64/100], Train Loss: 0.01571424774275246, Test Loss: 0.02537685465843727


100%|██████████| 88/88 [00:57<00:00,  1.53it/s]


Epoch [65/100], Train Loss: 0.015580277329056778, Test Loss: 0.022638481808826327


100%|██████████| 88/88 [00:57<00:00,  1.54it/s]


Epoch [66/100], Train Loss: 0.015675860615870493, Test Loss: 0.0212403714346389


100%|██████████| 88/88 [00:56<00:00,  1.55it/s]


Epoch [67/100], Train Loss: 0.01575519274708561, Test Loss: 0.0233716672907273


100%|██████████| 88/88 [00:57<00:00,  1.53it/s]


Epoch [68/100], Train Loss: 0.015642760903574526, Test Loss: 0.01983445797426005


100%|██████████| 88/88 [00:56<00:00,  1.56it/s]


Epoch [69/100], Train Loss: 0.01580701316495172, Test Loss: 0.022197866928763688


100%|██████████| 88/88 [00:56<00:00,  1.55it/s]


Epoch [70/100], Train Loss: 0.015649746796539563, Test Loss: 0.021972639641414087


100%|██████████| 88/88 [00:55<00:00,  1.57it/s]


Epoch [71/100], Train Loss: 0.01533682841214944, Test Loss: 0.021428313688375056


100%|██████████| 88/88 [00:55<00:00,  1.57it/s]


Epoch [72/100], Train Loss: 0.015568634368140589, Test Loss: 0.019147328566759825


100%|██████████| 88/88 [00:55<00:00,  1.59it/s]


Epoch [73/100], Train Loss: 0.015211138260466132, Test Loss: 0.020601277356036007


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [74/100], Train Loss: 0.015380718501877378, Test Loss: 0.020542183929743867


100%|██████████| 88/88 [00:56<00:00,  1.57it/s]


Epoch [75/100], Train Loss: 0.015447062789462507, Test Loss: 0.019723078856865566


100%|██████████| 88/88 [00:56<00:00,  1.57it/s]


Epoch [76/100], Train Loss: 0.015124183349226687, Test Loss: 0.019900993017169338


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [77/100], Train Loss: 0.015306308526884426, Test Loss: 0.021436112855250638


100%|██████████| 88/88 [00:56<00:00,  1.55it/s]


Epoch [78/100], Train Loss: 0.015435594102283094, Test Loss: 0.021591553503337007


100%|██████████| 88/88 [00:56<00:00,  1.56it/s]


Epoch [79/100], Train Loss: 0.015084975176829506, Test Loss: 0.020599810251345236


100%|██████████| 88/88 [00:55<00:00,  1.57it/s]


Epoch [80/100], Train Loss: 0.015067757398355752, Test Loss: 0.022106731582122546


100%|██████████| 88/88 [00:56<00:00,  1.56it/s]


Epoch [81/100], Train Loss: 0.014913941749413922, Test Loss: 0.020457586583991844


100%|██████████| 88/88 [00:55<00:00,  1.57it/s]


Epoch [82/100], Train Loss: 0.015194974035363306, Test Loss: 0.02088103264880677


100%|██████████| 88/88 [00:56<00:00,  1.57it/s]


Epoch [83/100], Train Loss: 0.01495309080928564, Test Loss: 0.019276451455273975


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [84/100], Train Loss: 0.015284358661367813, Test Loss: 0.021022819234834362


100%|██████████| 88/88 [00:56<00:00,  1.55it/s]


Epoch [85/100], Train Loss: 0.014916108593090692, Test Loss: 0.020490937458816916


100%|██████████| 88/88 [00:59<00:00,  1.48it/s]


Epoch [86/100], Train Loss: 0.015072118501517583, Test Loss: 0.021111057663802058


100%|██████████| 88/88 [00:58<00:00,  1.51it/s]


Epoch [87/100], Train Loss: 0.014859267775054004, Test Loss: 0.019609239612085123


100%|██████████| 88/88 [00:56<00:00,  1.57it/s]


Epoch [88/100], Train Loss: 0.015073414479213005, Test Loss: 0.024972175325577457


100%|██████████| 88/88 [01:01<00:00,  1.42it/s]


Epoch [89/100], Train Loss: 0.015026693401688879, Test Loss: 0.01958139246562496


100%|██████████| 88/88 [00:56<00:00,  1.56it/s]


Epoch [90/100], Train Loss: 0.014863952105356888, Test Loss: 0.021863013699961204


100%|██████████| 88/88 [00:57<00:00,  1.54it/s]


Epoch [91/100], Train Loss: 0.014997934382832185, Test Loss: 0.023657252876243245


100%|██████████| 88/88 [00:57<00:00,  1.54it/s]


Epoch [92/100], Train Loss: 0.014676775158890947, Test Loss: 0.021342362587650616


100%|██████████| 88/88 [00:55<00:00,  1.58it/s]


Epoch [93/100], Train Loss: 0.014655528441918168, Test Loss: 0.02187836932716891


100%|██████████| 88/88 [00:55<00:00,  1.60it/s]


Epoch [94/100], Train Loss: 0.014746452353640714, Test Loss: 0.02019717579241842


100%|██████████| 88/88 [00:55<00:00,  1.59it/s]


Epoch [95/100], Train Loss: 0.014668978158045898, Test Loss: 0.020618005422875285


100%|██████████| 88/88 [00:57<00:00,  1.54it/s]


Epoch [96/100], Train Loss: 0.014471871937117116, Test Loss: 0.020249539520591497


100%|██████████| 88/88 [00:59<00:00,  1.49it/s]


Epoch [97/100], Train Loss: 0.014656295597722585, Test Loss: 0.022228010348044336


100%|██████████| 88/88 [00:58<00:00,  1.50it/s]


Epoch [98/100], Train Loss: 0.014466086619491265, Test Loss: 0.01949530156950156


100%|██████████| 88/88 [00:58<00:00,  1.51it/s]


Epoch [99/100], Train Loss: 0.014296139897355302, Test Loss: 0.020464008674025536


100%|██████████| 88/88 [00:58<00:00,  1.51it/s]


Epoch [100/100], Train Loss: 0.01480963271619244, Test Loss: 0.01877915777731687


In [13]:
import os
from PIL import Image

In [14]:
def evaluate_model(model, test_loader, output_dir="output_images"):
    model.eval()

    # Ensure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Initialize lists to store results
    psnr_list = []
    ssim_list = []
    mse_list = []

    with torch.no_grad():
        for batch_idx, (noisy, clean) in enumerate(test_loader):
            noisy, clean = noisy.to(device), clean.to(device)
            reconstructed = model(noisy)
            
            # Iterate through the batch to process each image individually
            for i in range(reconstructed.size(0)):  # Loop over batch size
                output_np = reconstructed[i].cpu().numpy().transpose(1, 2, 0)
                clean_np = clean[i].cpu().numpy().transpose(1, 2, 0)

                # Save reconstructed image
                reconstructed_img = (output_np * 255).astype(np.uint8)  # Convert to uint8 for saving
                img = Image.fromarray(reconstructed_img)
                img_save_path = os.path.join(output_dir, f"enhanced_{batch_idx}_{i}.png")
                img.save(img_save_path)
                
                # Calculate PSNR
                psnr = peak_signal_noise_ratio(clean_np, output_np, data_range=1.0)
                
                # Calculate SSIM with explicit data_range and win_size
                ssim = structural_similarity(clean_np, output_np, multichannel=True, win_size=3, data_range=1.0)
                
                # Calculate MSE
                mse = mean_squared_error(clean_np, output_np)

                # Append the results
                psnr_list.append(psnr)
                ssim_list.append(ssim)
                mse_list.append(mse)

        print(f'Average PSNR: {sum(psnr_list)/len(psnr_list):.4f}')
        print(f'Average SSIM: {sum(ssim_list)/len(ssim_list):.4f}')
        print(f'Average MSE: {sum(mse_list)/len(mse_list):.4f}')


In [15]:
# Evaluate the model
evaluate_model(model, test_loader)

Average PSNR: 18.7001
Average SSIM: 0.6258
Average MSE: 0.0187
