In [None]:
from cityscapesscripts.download import downloader
# Log in once; the returned session is passed to every downloader call below.
session = downloader.login()

# The return value is not stored; this call sits mid-cell, so nothing is displayed from it.
downloader.get_available_packages(session=session)

# Fine annotations plus the matching left-camera images.
package_list = [
    'gtFine_trainvaltest.zip',
    'leftImg8bit_trainvaltest.zip',
]

# Archives are saved into the current working directory.
downloader.download_packages(
    session=session,
    package_names=package_list,
    destination_path='.',
)


In [None]:
# Annotation (gtFine) archive and the directory it is unpacked into.
# NOTE(review): the download cell saves to '.', so these absolute paths assume the
# notebook's working directory is /content - confirm before running elsewhere.
zip_path1 = '/content/gtFine_trainvaltest.zip'
extract_path1 = '/content'

In [None]:
# Image (leftImg8bit) archive and the directory it is unpacked into.
# NOTE(review): the download cell saves to '.', so these absolute paths assume the
# notebook's working directory is /content - confirm before running elsewhere.
zip_path2 = '/content/leftImg8bit_trainvaltest.zip'
extract_path2 = '/content'

In [None]:
import zipfile
# Unpack both archives in turn: annotations first, then the camera images.
for archive_path, target_dir in ((zip_path1, extract_path1), (zip_path2, extract_path2)):
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall(target_dir)

In [None]:
# Paths to the extracted directories.
# Later cells join a split name (e.g. 'train') and then a city folder onto each of these.
# NOTE(review): presumably these are the top-level folders contained in the two
# archives unpacked into /content - confirm after extraction.
gt_fine_dir = '/content/gtFine'
left_img_dir = '/content/leftImg8bit'

In [None]:
import os
import json
import numpy as np
from PIL import Image

def extract_unique_colors(gt_fine_dir, split='train'):
    """Index every distinct RGB color found in one split's color annotations.

    Walks ``<gt_fine_dir>/<split>/<city>/*_gtFine_color.png``, collects the
    distinct RGB triples and gives each one a stable, contiguous class index
    (0, 1, 2, ...) in order of first appearance. Cities and files are visited
    in sorted order so the assignment is reproducible between runs.

    The mapping is also cached to ``class2color.json`` in the current working
    directory. Despite the file name, the cached direction is
    color -> class index; keys are written as ``"(r, g, b)"`` strings.

    Args:
        gt_fine_dir: Root directory holding one sub-directory per split.
        split: Name of the split sub-directory to scan.

    Returns:
        dict mapping ``(r, g, b)`` tuples of Python ints to class indices.
    """
    unique_colors = {}
    labels_dir = os.path.join(gt_fine_dir, split)

    for city in sorted(os.listdir(labels_dir)):
        city_dir = os.path.join(labels_dir, city)
        label_files = sorted(f for f in os.listdir(city_dir) if f.endswith('_gtFine_color.png'))

        for file in label_files:
            label_path = os.path.join(city_dir, file)
            label_image = Image.open(label_path).convert('RGB')
            label_array = np.array(label_image)
            # Collapse H x W x C into a list of pixels and keep the distinct rows.
            colors = np.unique(label_array.reshape(-1, label_array.shape[2]), axis=0)
            for color in colors:
                # Plain ints keep str(key) as "(r, g, b)" regardless of how the
                # installed NumPy prints its scalars.
                key = tuple(int(channel) for channel in color)
                # Only register unseen colors; re-assigning a known color
                # would hand the same index to several colors.
                if key not in unique_colors:
                    unique_colors[key] = len(unique_colors)
            print('file '+ label_path+' extracted')

    # Cache the class2color and color2class mappings
    with open('class2color.json', 'w') as f:
        json.dump({str(k): v for k, v in unique_colors.items()}, f)

    return unique_colors

def load_class2color_mapping():
    """Load the cached color -> class index mapping from ``class2color.json``.

    The file is read from the current working directory. Its keys are
    stringified color tuples. Both the plain form ``"(128, 64, 128)"`` and
    the form produced when the tuple held NumPy scalars on NumPy >= 2
    (``"(np.uint8(128), np.uint8(64), np.uint8(128))"``) are accepted.

    Note that, despite the function name, the returned dictionary is keyed by
    color.

    Returns:
        dict mapping tuples of ints (one per channel) to the stored class index.

    Raises:
        ValueError: If a key contains no integer channel values.
    """
    import re  # local import: only needed for key parsing

    with open('class2color.json', 'r') as f:
        class2color = json.load(f)

    # A channel value is a digit run that is not the tail of an identifier,
    # so the "8" in "uint8" is skipped while "(128" and " 64" are picked up.
    channel_pattern = re.compile(r'(?<![A-Za-z_0-9])\d+')

    color2class = {}
    for key, class_idx in class2color.items():
        channels = tuple(int(token) for token in channel_pattern.findall(key))
        if not channels:
            raise ValueError(f"Cannot parse color key {key!r} in class2color.json")
        color2class[channels] = class_idx
    return color2class

def color_to_class(label, color2class):
    """Translate a color-coded label image into a 2-D array of class indices.

    Args:
        label: Array whose last axis holds the color channels of each pixel.
        color2class: Mapping from color tuples to integer class indices.

    Returns:
        ``int64`` array of shape ``(label.shape[0], label.shape[1])``. Pixels
        whose color is not a key of ``color2class`` keep the initial value 0.
    """
    height, width = label.shape[0], label.shape[1]
    class_map = np.zeros((height, width), dtype=np.int64)

    for rgb, idx in color2class.items():
        # A pixel matches only when every channel equals the reference color.
        hits = (label == np.array(rgb)).all(axis=-1)
        class_map[hits] = idx

    return class_map


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CityscapesFineDataset(Dataset):
    """Pairs ``*_leftImg8bit.png`` images with their ``*_gtFine_color.png`` annotations.

    Files are discovered under ``<left_img_dir>/<split>/<city>/`` and the
    label path is derived under ``<gt_fine_dir>/<split>/<city>/``. Each item
    is ``(image, label)`` where ``label`` is a ``LongTensor`` of class indices
    obtained by looking up every pixel's color.

    Args:
        gt_fine_dir: Root directory of the annotations.
        left_img_dir: Root directory of the images.
        class2color: Color/class mapping in either direction, i.e.
            ``{class_idx: (r, g, b)}`` or ``{(r, g, b): class_idx}``. The
            mapping returned by ``load_class2color_mapping`` is of the second
            form, so both are accepted and normalised.
        split: Split sub-directory to read.
        image_transform: Optional callable applied to the RGB image.
        label_transform: Optional callable applied to the RGB label image
            before it is converted to class indices.
    """

    def __init__(self, gt_fine_dir, left_img_dir, class2color, split='train', image_transform=None, label_transform=None):
        self.split = split
        # Blindly inverting the argument breaks when it is already keyed by
        # color, so detect the orientation entry by entry instead.
        self.color2class, self.class2color = self._normalize_mapping(class2color)
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.images_dir = os.path.join(left_img_dir, split)
        self.labels_dir = os.path.join(gt_fine_dir, split)
        # Sorted listings make the sample order independent of the filesystem.
        self.cities = sorted(os.listdir(self.images_dir))
        self.files = []

        for city in self.cities:
            img_dir = os.path.join(self.images_dir, city)
            label_dir = os.path.join(self.labels_dir, city)

            for file in sorted(os.listdir(img_dir)):
                if file.endswith('_leftImg8bit.png'):
                    img_path = os.path.join(img_dir, file)
                    label_file = file.replace('_leftImg8bit.png', '_gtFine_color.png')
                    label_path = os.path.join(label_dir, label_file)
                    self.files.append((img_path, label_path))

        print(f"Found {len(self.files)} image-label pairs in the {split} split.")

    @staticmethod
    def _normalize_mapping(mapping):
        """Return ``(color2class, class2color)`` built from a mapping in either direction."""
        color2class = {}
        class2color = {}
        for key, value in mapping.items():
            # A tuple key can only be a color; otherwise the key is the class index.
            if isinstance(key, tuple):
                color, class_idx = key, value
            else:
                color, class_idx = value, key
            color = tuple(int(channel) for channel in color)
            class_idx = int(class_idx)
            color2class[color] = class_idx
            class2color[class_idx] = color
        return color2class, class2color

    def __len__(self):
        """Number of image/label pairs found for the split."""
        return len(self.files)

    def __getitem__(self, index):
        """Load one pair, apply the transforms and convert label colors to class indices."""
        img_path, label_path = self.files[index]
        image = Image.open(img_path).convert('RGB')
        label = Image.open(label_path).convert('RGB')  # Load the color annotations

        if self.image_transform:
            image = self.image_transform(image)

        if self.label_transform:
            label = self.label_transform(label)

        label = torch.from_numpy(color_to_class(np.array(label), self.color2class)).long()  # Convert label to LongTensor

        return image, label

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(image, label)`` pairs into two batched tensors."""
        images, labels = zip(*batch)
        images = torch.stack(images, dim=0)
        labels = torch.stack(labels, dim=0)
        return images, labels


In [None]:
def load_data(gt_fine_dir, left_img_dir, class2color, split, batch_size=8, num_workers=4, shuffle=True):
    """Build a ``DataLoader`` over one Cityscapes split at 256x256 resolution.

    Images are resized and converted to tensors. Label images are resized
    with nearest-neighbour interpolation so that no blended colors, which
    would not be present in the color mapping, are introduced.

    Args:
        gt_fine_dir: Root directory of the annotations.
        left_img_dir: Root directory of the images.
        class2color: Color/class mapping forwarded to ``CityscapesFineDataset``.
        split: Split sub-directory to load.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Whether to reshuffle the samples every epoch. Defaults to
            ``True`` (the previous fixed behaviour); pass ``False`` for
            evaluation splits.

    Returns:
        The configured ``DataLoader``.
    """
    image_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    label_transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=Image.NEAREST)
    ])

    dataset = CityscapesFineDataset(
        gt_fine_dir=gt_fine_dir,
        left_img_dir=left_img_dir,
        class2color=class2color,
        split=split,
        image_transform=image_transform,
        label_transform=label_transform
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=CityscapesFineDataset.collate_fn,
    )
    return dataloader

# Build the color cache only when class2color.json is not on disk yet;
# otherwise the existing file is reused as is.
if not os.path.exists('class2color.json'):
    print("Extracting unique colors from the dataset and creating class2color.json...")
    extract_unique_colors(gt_fine_dir)

# NOTE: despite the variable name, the loader returns a dict keyed by color
# tuples whose values are class indices.
class2color = load_class2color_mapping()

# Training loader: batches of 8, four worker processes.
dataloader = load_data(
    gt_fine_dir,
    left_img_dir,
    class2color,
    split='train',
    batch_size=8,
    num_workers=4,
)


In [None]:
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CityscapesFineDataset(Dataset):
    """Pairs ``*_leftImg8bit.png`` images with their ``*_gtFine_color.png`` annotations.

    Files are discovered under ``<left_img_dir>/<split>/<city>/`` and the
    label path is derived under ``<gt_fine_dir>/<split>/<city>/``. Each item
    is ``(image, label)`` where ``label`` is a ``LongTensor`` of class indices
    obtained by looking up every pixel's color.

    Args:
        gt_fine_dir: Root directory of the annotations.
        left_img_dir: Root directory of the images.
        class2color: Color/class mapping in either direction, i.e.
            ``{class_idx: (r, g, b)}`` or ``{(r, g, b): class_idx}``. The
            mapping returned by ``load_class2color_mapping`` is of the second
            form, so both are accepted and normalised.
        split: Split sub-directory to read.
        image_transform: Optional callable applied to the RGB image.
        label_transform: Optional callable applied to the RGB label image
            before it is converted to class indices.
    """

    def __init__(self, gt_fine_dir, left_img_dir, class2color, split='train', image_transform=None, label_transform=None):
        self.split = split
        # Blindly inverting the argument breaks when it is already keyed by
        # color, so detect the orientation entry by entry instead.
        self.color2class, self.class2color = self._normalize_mapping(class2color)
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.images_dir = os.path.join(left_img_dir, split)
        self.labels_dir = os.path.join(gt_fine_dir, split)
        # Sorted listings make the sample order independent of the filesystem.
        self.cities = sorted(os.listdir(self.images_dir))
        self.files = []

        for city in self.cities:
            img_dir = os.path.join(self.images_dir, city)
            label_dir = os.path.join(self.labels_dir, city)

            for file in sorted(os.listdir(img_dir)):
                if file.endswith('_leftImg8bit.png'):
                    img_path = os.path.join(img_dir, file)
                    label_file = file.replace('_leftImg8bit.png', '_gtFine_color.png')
                    label_path = os.path.join(label_dir, label_file)
                    self.files.append((img_path, label_path))

    @staticmethod
    def _normalize_mapping(mapping):
        """Return ``(color2class, class2color)`` built from a mapping in either direction."""
        color2class = {}
        class2color = {}
        for key, value in mapping.items():
            # A tuple key can only be a color; otherwise the key is the class index.
            if isinstance(key, tuple):
                color, class_idx = key, value
            else:
                color, class_idx = value, key
            color = tuple(int(channel) for channel in color)
            class_idx = int(class_idx)
            color2class[color] = class_idx
            class2color[class_idx] = color
        return color2class, class2color

    def __len__(self):
        """Number of image/label pairs found for the split."""
        return len(self.files)

    def __getitem__(self, index):
        """Load one pair, apply the transforms and convert label colors to class indices."""
        img_path, label_path = self.files[index]
        image = Image.open(img_path).convert('RGB')
        label = Image.open(label_path).convert('RGB')  # Load the color annotations

        if self.image_transform:
            image = self.image_transform(image)

        if self.label_transform:
            label = self.label_transform(label)

        label = torch.from_numpy(color_to_class(np.array(label), self.color2class)).long()  # Convert label to LongTensor

        return image, label

# Images: resize to 256x256, then convert to a tensor.
image_transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# Labels: resize only, with nearest-neighbour sampling so the annotation
# colors are not blended into values absent from the color mapping.
label_transform = transforms.Compose([transforms.Resize((256, 256), interpolation=Image.NEAREST)])

# Create dataset and dataloader instances.
# This rebinds the module-level `dataloader` name with default worker and
# collate settings.
dataset = CityscapesFineDataset(
    gt_fine_dir=gt_fine_dir,
    left_img_dir=left_img_dir,
    class2color=class2color,
    split='train',
    image_transform=image_transform,
    label_transform=label_transform,
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net style encoder/decoder producing one output channel per class.

    Four convolutional stages (64, 128, 256, 512 channels) are separated by
    2x2 max-pooling. Three transposed-convolution stages upsample back, each
    concatenated with the matching encoder feature map before two further
    convolutions. A final 1x1 convolution maps 64 channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path.
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Expanding path: each decoder block receives the upsampled features
        # concatenated with a skip connection, hence twice the channel count.
        self.upconv3 = self.upconv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)

        self.upconv2 = self.upconv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)

        self.upconv1 = self.upconv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels):
        """Transposed convolution that doubles the spatial resolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """Run the network on a 4-D batch and return the per-class output map."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        bottleneck = self.enc4(self.pool(skip3))

        features = self.dec3(self.crop_and_concat(skip3, self.upconv3(bottleneck)))
        features = self.dec2(self.crop_and_concat(skip2, self.upconv2(features)))
        features = self.dec1(self.crop_and_concat(skip1, self.upconv1(features)))

        return self.final_conv(features)

    def crop_and_concat(self, enc, dec):
        """Center-crop ``enc`` to the spatial size of ``dec`` and join them on the channel axis."""
        target_height, target_width = dec.size()[2], dec.size()[3]
        skip = self.center_crop(enc, target_height, target_width)
        return torch.cat([skip, dec], dim=1)

    def center_crop(self, layer, max_height, max_width):
        """Cut a centered ``max_height`` x ``max_width`` window out of a 4-D tensor."""
        _, _, height, width = layer.size()
        top = (height - max_height) // 2
        left = (width - max_width) // 2
        rows = slice(top, top + max_height)
        cols = slice(left, left + max_width)
        return layer[:, :, rows, cols]

# Three input channels match the RGB images produced by the dataset; the output
# width is a hard-coded class count.
# NOTE(review): the class indices used as targets come from the colors found by
# extract_unique_colors, so their number is data-dependent - confirm that 34
# covers every index before training.
model = UNet(in_channels=3, out_channels=34)  # Assuming 34 classes for Cityscapes


In [None]:
import torch.optim as optim

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Per-pixel classification loss, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Move the class axis last and flatten, giving one row of logits per
        # pixel and one target index per pixel.
        outputs = outputs.permute(0, 2, 3, 1).reshape(-1, outputs.shape[1])
        labels = labels.reshape(-1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss of the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

print("Training complete.")
