In [2]:
import math
import numpy as np
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import structural_similarity as ssim

from PIL import Image
import glob

import matplotlib.pyplot as plt

In [5]:
import os 
from torch.utils.data import Dataset
import torchvision.transforms as transforms

In [14]:
class FoggyCityscape(Dataset):
    def __init__(self, data_dir, transform = None):
        self.data_dir = data_dir
        self.transform = transform
        self.haze_dir = os.path.join(data_dir, "hazy")
        self.clean_dir = os.path.join(data_dir, "clean")

        # Sort for consistency
        self.haze_images = sorted(os.listdir(self.haze_dir))  
        self.clean_images = sorted(os.listdir(self.clean_dir))  

        # Error handling if number of haze and clear images don't match
        if len(self.haze_images) != len(self.clean_images):
            raise ValueError("Number of haze images does not match number of clear images.")

    def __len__(self):
        return len(self.haze_images)

    def __getitem__(self, idx):
        haze_image_name = self.haze_images[idx]
        clean_image_name = self.clean_images[idx]
        haze_image_path = os.path.join(self.haze_dir, haze_image_name)
        clean_image_path = os.path.join(self.clean_dir, clean_image_name)

        haze_image = Image.open(haze_image_path)
        clean_image = Image.open(clean_image_path)

        if self.transform:
            haze_image = self.transform(haze_image)
            clean_image = self.transform(clean_image)
            
        return haze_image, clean_image

In [15]:
os.getcwd()

'/root/dehaze-involution/models'

In [16]:
# Define transformation (if any)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
data_dir = "../dataset/foggy_cityscape/train"
dataset = FoggyCityscape(data_dir, transform)

In [17]:
# Access an item from the dataset
index = 0  # Change this to test different images
haze_image, clear_image= dataset[index]

# Visualize the images
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(haze_image.permute(1, 2, 0))  # Permute dimensions for visualization
plt.title("Haze Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(clear_image.permute(1, 2, 0))  # Permute dimensions for visualization
plt.title("Clear Image")
plt.axis("off")

plt.show()

# Optionally, print image paths
print("Haze Image Path:")
print("Clear Image Path:")


UnboundLocalError: local variable 'clean_image_path' referenced before assignment