In [None]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# Root of the local Cityscapes copy; every split directory below hangs off it.
dataset_directory = '/home/maith/Desktop/cityscapes'

# Image and label roots for the train and val splits, built in one pass.
train_images_dir, train_labels_dir, val_images_dir, val_labels_dir = (
    os.path.join(dataset_directory, relative_dir)
    for relative_dir in (
        'leftImg8bit/train/',
        'gtFine/train/',
        'leftImg8bit/val/',
        'gtFine/val/',
    )
)

# Image preprocessing: resize to 256x512, convert to a tensor, then normalise
# each channel with the given mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((256, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

from torchvision.transforms import functional as TF

class CityscapesDataset(Dataset):
    """Paired Cityscapes images and ``gtFine_labelIds`` masks.

    The image root is expected to contain one sub-directory per city. Every
    file in a city directory whose name contains ``leftImg8bit`` is treated
    as an image and is paired with the file of the same name (with
    ``leftImg8bit`` replaced by ``gtFine_labelIds``) in the matching city
    directory under the label root.

    Args:
        image_dir: Root directory holding the per-city image folders.
        label_dir: Root directory holding the per-city label folders.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the single-channel
            (mode ``"L"``) label ``PIL.Image``.
    """

    def __init__(self, image_dir, label_dir, transform=None, target_transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.target_transform = target_transform
        self.image_paths = []
        self.label_paths = []

        # Sorted traversal keeps the sample order deterministic across runs.
        for city in sorted(os.listdir(image_dir)):
            city_images_dir = os.path.join(image_dir, city)
            if not os.path.isdir(city_images_dir):
                # Stray files in the split root (README, .DS_Store, ...) are
                # not cities; listing them would raise NotADirectoryError.
                continue
            city_labels_dir = os.path.join(label_dir, city)
            for file_name in sorted(os.listdir(city_images_dir)):
                if 'leftImg8bit' not in file_name:
                    continue
                self.image_paths.append(os.path.join(city_images_dir, file_name))
                label_name = file_name.replace('leftImg8bit', 'gtFine_labelIds')
                self.label_paths.append(os.path.join(city_labels_dir, label_name))

    def __len__(self):
        """Return the number of image/label pairs found at construction."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load and return the ``(image, label)`` pair at ``index``.

        The image is converted to RGB and the label to mode ``"L"`` before
        the optional transforms are applied.
        """
        # ``convert`` returns a new, fully loaded image, so the underlying
        # files can be closed as soon as the ``with`` blocks exit.
        with Image.open(self.image_paths[index]) as source:
            image = source.convert('RGB')
        with Image.open(self.label_paths[index]) as source:
            label = source.convert('L')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

# Images: resize to 256x512, convert to a float tensor in [0, 1], then
# normalise each channel with the given mean / std.
image_transform = transforms.Compose([
    transforms.Resize((256, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Labels: nearest-neighbour resizing so that no new (interpolated) class ids
# appear, then PILToTensor, which keeps the raw uint8 ids in a (1, H, W)
# tensor. ToTensor must not be used here: it divides by 255 and returns
# floats, turning class ids 0, 1, 2, ... into 0.0, 0.0039, 0.0078, ...
# NOTE: cast the labels with ``.long()`` before passing them to a loss that
# expects integer class targets.
label_transform = transforms.Compose([
    transforms.Resize((256, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor()
])

train_dataset = CityscapesDataset(train_images_dir, train_labels_dir, transform=image_transform, target_transform=label_transform)
val_dataset = CityscapesDataset(val_images_dir, val_labels_dir, transform=image_transform, target_transform=label_transform)

# Shuffle only the training split; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

In [None]:
def explore_dataset(loader, num_batches=1):
    """Print image shapes and the label values seen in the first batches.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs of tensors,
            e.g. a ``DataLoader``.
        num_batches: Maximum number of batches to inspect. Values of zero or
            below inspect nothing.

    Prints the shape of every inspected image batch, then the set of distinct
    label values collected from those batches and the size of that set.
    """
    unique_labels = set()
    if num_batches > 0:
        for seen, (images, labels) in enumerate(loader, start=1):
            print("Shape of the images:", images.shape)
            # torch.unique flattens its input, so labels of any rank work
            # without squeezing a channel dimension first; tolist() also
            # avoids the detour through NumPy.
            unique_labels.update(torch.unique(labels).tolist())
            if seen >= num_batches:
                # Stop here, before the iterator is asked to load a further
                # batch that would only be thrown away.
                break

    print("Unique label values:", unique_labels)
    print("Number of classes:", len(unique_labels))

# Explore the dataset: inspect the first few training batches and print the
# image shapes and label values found in them.
print("Exploring training data:")
# A larger num_batches covers more of the data but takes longer to run.
explore_dataset(train_loader, num_batches=3)