In [1]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install albumentations
!pip install opencv-python-headless matplotlib
!pip install cityscapesscripts

!pip install tqdm

In [2]:
# Mount Google Drive: the dataset and checkpoints below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [3]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision import datasets, utils, transforms
from PIL import Image
from tqdm import tqdm

import torch.optim as optim
import torch.nn.functional as F
from cityscapesscripts.helpers.labels import trainId2label as t2l
import matplotlib.pyplot as plt
import random


# Debugging aid: CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so device-side
# errors surface at the offending Python line. It slows training noticeably; set this
# flag to False for full-speed runs.
DEBUG_CUDA_SYNC = True

# Set the variable before the first CUDA query below, so it is already in place when
# CUDA initialises (setting it afterwards may have no effect).
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    print('Running on the GPU')
else:
    DEVICE = 'cpu'
    print('Running on the CPU')


In [4]:
class UNET(nn.Module):
    """U-Net producing per-pixel class logits.

    Encoder: five double-conv stages with widths in_channels -> 64 -> 128 -> 256 ->
    512 -> 1024, with 2x2 max-pooling between consecutive stages. Decoder: four
    transposed-conv upsamplings, each followed by concatenation with the matching
    encoder activation and a double-conv. A final 1x1 conv maps 64 channels to
    ``classes`` logits.

    Submodule attribute names and construction order are fixed, so existing
    checkpoints (state_dict keys and optimizer parameter order) stay loadable.

    Args:
        in_channels: number of input image channels.
        classes: number of output logit channels.
    """

    def __init__(self, in_channels=3, classes=19):
        super().__init__()
        self.layers = [in_channels, 64, 128, 256, 512, 1024]

        # Encoder: one double-conv per consecutive (input width, output width) pair.
        encoder_pairs = zip(self.layers[:-1], self.layers[1:])
        self.double_conv_downs = nn.ModuleList(
            self.__double_conv(c_in, c_out) for c_in, c_out in encoder_pairs)

        # Decoder input widths, deepest first: [1024, 512, 256, 128].
        reversed_widths = self.layers[::-1]
        decoder_widths = reversed_widths[:-2]
        self.up_trans = nn.ModuleList(
            nn.ConvTranspose2d(width, target, kernel_size=2, stride=2)
            for width, target in zip(decoder_widths, reversed_widths[1:-1]))
        self.double_conv_ups = nn.ModuleList(
            self.__double_conv(width, width // 2) for width in decoder_widths)

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(64, classes, kernel_size=1)

    def __double_conv(self, in_channels, out_channels):
        """Conv3x3 (no bias) -> BatchNorm -> ReLU -> Conv3x3 -> ReLU; padding keeps H, W."""
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        skips = []
        bottleneck = len(self.double_conv_downs) - 1
        for stage, down in enumerate(self.double_conv_downs):
            x = down(x)
            if stage < bottleneck:
                # Keep the pre-pool activation for the decoder's skip connection.
                skips.append(x)
                x = self.max_pool_2x2(x)

        for up, up_conv, skip in zip(self.up_trans, self.double_conv_ups, reversed(skips)):
            x = up(x)
            # Odd spatial sizes lose a pixel when pooled; stretch back to the skip's size.
            if x.shape != skip.shape:
                x = TF.resize(x, skip.shape[2:])
            x = up_conv(torch.cat((skip, x), dim=1))

        return self.final_conv(x)

In [5]:
class CityscapesDataset(Dataset):
    """Cityscapes semantic segmentation: (image tensor, training-id label) pairs.

    Images are discovered under
        <root_dir>/leftImg8bit_trainvaltest/leftImg8bit/<split>/<city>/*_leftImg8bit.png
    and labels are looked up at
        <root_dir>/gtFine_trainvaltest/<gtFine|gtCoarse>/<split>/<city>/*_<gtFine|gtCoarse>_labelIds.png
    Images whose label file is missing are skipped with a warning.

    Label ids listed in ``valid_classes`` are remapped to training ids 0..18 (in list
    order). Every other id becomes 255, which callers pass to the loss as ``ignore_index``.

    Args:
        split: split sub-folder, e.g. 'train' or 'val'.
        root_dir: dataset root directory.
        target_type: accepted for API compatibility; not used.
        mode: 'fine' or 'coarse' annotations.
        transform: optional callable applied to the image tensor only. The label is
            NOT transformed, so the transform must not change the spatial size.
        eval: if True, items also include the image path relative to the split folder.

    Raises:
        ValueError: if ``mode`` is not 'fine' or 'coarse'.
    """

    def __init__(self, split, root_dir, target_type='semantic', mode='fine', transform=None, eval=False):
        self.transform = transform
        self.split = split
        self.eval = eval

        annotation_dirs = {'fine': 'gtFine', 'coarse': 'gtCoarse'}
        if mode not in annotation_dirs:
            raise ValueError(f"mode must be 'fine' or 'coarse', got {mode!r}")
        self.mode = annotation_dirs[mode]

        # NOTE(review): coarse labels are also looked up inside gtFine_trainvaltest;
        # confirm this matches the on-disk layout before using mode='coarse'.
        self.label_path = os.path.join(root_dir, 'gtFine_trainvaltest', self.mode, self.split)
        self.rgb_path = os.path.join(root_dir, 'leftImg8bit_trainvaltest', 'leftImg8bit', self.split)

        self.XImg_list = []
        self.yLabel_list = []

        # Sorted listings make dataset indices reproducible across runs.
        for city in sorted(os.listdir(self.rgb_path)):
            rgb_city_path = os.path.join(self.rgb_path, city)
            if not os.path.isdir(rgb_city_path):
                continue
            for img in sorted(os.listdir(rgb_city_path)):
                if not img.endswith('_leftImg8bit.png'):
                    continue
                # x_leftImg8bit.png -> x_gtFine_labelIds.png (x_gtCoarse_labelIds.png for coarse).
                label_img = img.replace('leftImg8bit', self.mode).replace('.png', '_labelIds.png')
                label_img_path = os.path.join(city, label_img)
                if os.path.exists(os.path.join(self.label_path, label_img_path)):
                    # Append image and label together. Adding the image when its label
                    # is missing would shift every later image/label pairing.
                    self.XImg_list.append(os.path.join(city, img))
                    self.yLabel_list.append(label_img_path)
                else:
                    print(f"Warning: Missing label for image {img}")

        print(f"Number of images: {len(self.XImg_list)}, Number of labels: {len(self.yLabel_list)}")
        print(f"First 5 images: {self.XImg_list[:5]}")
        print(f"First 5 labels: {self.yLabel_list[:5]}")

        # Kept for reference: any id not in valid_classes is mapped to 255.
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]

        self.class_mapping = {original: new for new, original in enumerate(self.valid_classes)}

    def __len__(self):
        return len(self.XImg_list)

    def __getitem__(self, index):
        """Return ``(image, y)``, or ``(image, y, relative_image_path)`` when ``eval`` is set.

        ``image`` comes from torchvision's ToTensor (then ``transform``, if any).
        ``y`` is an int64 tensor of training ids, with 255 for ignored pixels.

        Raises:
            OSError: if the image or label still cannot be read after retrying.
        """
        image_rel = self.XImg_list[index]
        label_rel = self.yLabel_list[index]
        image = self._load_with_retries(os.path.join(self.rgb_path, image_rel), 'image', image_rel)
        label = self._load_with_retries(os.path.join(self.label_path, label_rel), 'label', label_rel)

        image = transforms.ToTensor()(image)
        y = torch.from_numpy(self._remap_label_ids(np.array(label), self.class_mapping))

        if self.transform is not None:
            image = self.transform(image)

        if self.eval:
            return image, y, image_rel
        return image, y

    @staticmethod
    def _load_with_retries(path, kind, name, max_attempts=10):
        """Open and fully decode an image file, retrying failures up to ``max_attempts`` times.

        Raises:
            OSError: after the last failed attempt (the last error is chained). The
                previous behaviour of returning None only made DataLoader's collate
                step fail later with a less informative error.
        """
        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                with Image.open(path) as img:
                    # copy() forces a full decode inside the retry loop and lets the
                    # file handle close immediately.
                    return img.copy()
            except Exception as e:
                last_error = e
                print(f"Attempt {attempt} failed to load {kind} {name}: {e}")
        raise OSError(f"Could not load {kind} {name} after {max_attempts} attempts") from last_error

    @staticmethod
    def _remap_label_ids(label_ids, class_mapping, ignore_index=255):
        """Map label ids via ``class_mapping``; every unmapped id becomes ``ignore_index``.

        Returns an int64 array with the same shape as ``label_ids``.
        """
        remapped = np.full(label_ids.shape, ignore_index, dtype=np.int64)
        for original_class, new_class in class_mapping.items():
            remapped[label_ids == original_class] = new_class
        return remapped

In [6]:
def get_cityscapes_data(
    mode,
    split,
    relabelled,
    root_dir='/content/drive/MyDrive/Cityscapes/',
    target_type="semantic",
    transforms=None,
    batch_size=1,
    eval=False,
    shuffle=True,
    pin_memory=True,
):
    """Build a CityscapesDataset and wrap it in a torch DataLoader.

    Args:
        mode: 'fine' or 'coarse', forwarded to CityscapesDataset.
        split: split folder, e.g. 'train' or 'val'.
        relabelled: unused; kept so existing call sites keep working.
        root_dir: dataset root directory.
        target_type: forwarded to CityscapesDataset.
        transforms: optional image transform. This name shadows the torchvision
            ``transforms`` module inside this function only.
        batch_size, shuffle, pin_memory: forwarded to the DataLoader.
        eval: when True, the dataset also yields each image's relative path.

    Returns:
        torch.utils.data.DataLoader over the dataset.
    """
    dataset = CityscapesDataset(
        split=split,
        root_dir=root_dir,
        target_type=target_type,
        mode=mode,
        transform=transforms,
        eval=eval,
    )
    loader_options = dict(batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
    return torch.utils.data.DataLoader(dataset, **loader_options)

def save_as_images(tensor_pred, folder, image_name):
    """Save a label/prediction tensor as ``<folder>/<image_name>.png``.

    Args:
        tensor_pred: tensor accepted by torchvision's ToPILImage after ``.byte()``
            (i.e. converted to uint8).
        folder: output directory; created if it does not exist.
        image_name: file stem, without the ``.png`` extension.

    Returns:
        The path of the written file.
    """
    pil_image = transforms.ToPILImage()(tensor_pred.byte())
    os.makedirs(folder, exist_ok=True)
    # Use os.path.join rather than a hard-coded backslash. On Linux/Colab '\' is an
    # ordinary filename character, so the old f-string wrote "folder\name.png" into
    # the working directory instead of inside `folder`.
    filename = os.path.join(folder, f"{image_name}.png")
    pil_image.save(filename)
    return filename

In [7]:
# ---- Training configuration ---------------------------------------------------
# Resume from a saved checkpoint in MODEL_DIR instead of starting from scratch.
LOAD_MODEL = True

# Dataset root on the mounted Drive; checkpoints go to its MODEL/ sub-folder.
ROOT_DIR = '/content/drive/MyDrive/Cityscapes/'
MODEL_DIR = os.path.join(ROOT_DIR, 'MODEL/')

# Input resolution the images are resized to, and optimisation hyper-parameters.
IMG_HEIGHT, IMG_WIDTH = 1024, 2048
BATCH_SIZE = 1
LEARNING_RATE = 1e-3
EPOCHS = 20

os.makedirs(MODEL_DIR, exist_ok=True)

def train_function(data, model, optimizer, loss_fn, device, ignore_index=None):
    """Run one training epoch and return ``(mean batch loss, pixel accuracy in %)``.

    Args:
        data: iterable of ``(X, y)`` batches. The logits are flattened as
            (N, C, H, W) -> (N*H*W, C) and targets as -> (N*H*W,), so y presumably
            holds (N, H, W) class indices.
        model: network returning (N, C, H, W) logits.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called with the flattened logits and targets.
        device: device the batches are moved to.
        ignore_index: target value excluded from the accuracy. Defaults to
            ``loss_fn.ignore_index`` when the loss defines one (e.g.
            nn.CrossEntropyLoss), so accuracy is measured on the same pixels as the loss.

    Returns:
        ``(avg_loss, accuracy)``. Each is NaN when there are no batches or no
        scorable pixels, respectively.
    """
    if ignore_index is None:
        ignore_index = getattr(loss_fn, 'ignore_index', None)

    print('Entering into train function')
    model.train()
    loss_values = []
    correct = 0
    total = 0

    for X, y in tqdm(data):
        X, y = X.to(device), y.to(device)
        preds = model(X)

        num_classes = preds.shape[1]
        preds = preds.permute(0, 2, 3, 1).reshape(-1, num_classes)
        y = y.reshape(-1)

        loss = loss_fn(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicted = preds.argmax(dim=1)
        if ignore_index is not None:
            # Ignored pixels (e.g. void = 255) can never be predicted correctly;
            # counting them would deflate the accuracy.
            scored = y != ignore_index
            total += scored.sum().item()
            correct += ((predicted == y) & scored).sum().item()
        else:
            total += y.numel()
            correct += (predicted == y).sum().item()

        loss_values.append(loss.item())

    avg_loss = sum(loss_values) / len(loss_values) if loss_values else float('nan')
    accuracy = correct / total * 100 if total else float('nan')
    return avg_loss, accuracy

def main():
    """Train the U-Net on the Cityscapes 'train' split, saving a checkpoint every epoch.

    When LOAD_MODEL is True, training resumes from
    ``MODEL_DIR/model_epoch{resume_epoch - 1}.pth`` (or ``model.pth`` if resume_epoch
    is 0). This restores weights, optimizer state and the loss/accuracy histories.
    Otherwise training starts from scratch at epoch 0. Epoch ``e`` is saved to
    ``MODEL_DIR/model_epoch{e}.pth``. The module-level ``epoch`` holds the first
    epoch trained in this run.
    """
    global epoch
    resume_epoch = 15  # first epoch to run when resuming from a checkpoint
    # Without a checkpoint, start at 0 so all EPOCHS epochs are trained.
    epoch = resume_epoch if LOAD_MODEL else 0

    if LOAD_MODEL:
        MODEL_NAME = f'model_epoch{epoch-1}.pth' if epoch else 'model.pth'
        MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

    LOSS_VALS = []
    ACCURACY_VALS = []

    transform = transforms.Compose([
        # InterpolationMode replaces the deprecated integer PIL constant.
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
    ])

    train_set = get_cityscapes_data(
        split='train',
        mode='fine',
        relabelled=True,
        root_dir=ROOT_DIR,
        transforms=transform,
        batch_size=BATCH_SIZE,
    )

    print('Data Loaded Successfully!')

    unet = UNET(in_channels=3, classes=19).to(DEVICE).train()
    optimizer = optim.Adam(unet.parameters(), lr=LEARNING_RATE)
    loss_function = nn.CrossEntropyLoss(ignore_index=255)

    if LOAD_MODEL:
        print(f'Loading Model {MODEL_PATH}...')
        # map_location lets a GPU-saved checkpoint be resumed on a CPU runtime.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        unet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        epoch = checkpoint['epoch'] + 1
        LOSS_VALS = checkpoint['loss_values']
        # Restore the accuracy history as well. Otherwise the saved loss/accuracy
        # lists fall out of step and plot_metrics rejects the checkpoint.
        ACCURACY_VALS = checkpoint.get('accuracy_values', [])
        print("Model successfully loaded!")

    for e in range(epoch, EPOCHS):
        print(f'Epoch: {e}')
        avg_loss, accuracy = train_function(train_set, unet, optimizer, loss_function, DEVICE)
        LOSS_VALS.append(avg_loss)
        ACCURACY_VALS.append(accuracy)
        MODEL_NAME = f'model_epoch{e}.pth'
        MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

        torch.save({
            'model_state_dict': unet.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': LOSS_VALS,
            'accuracy_values': ACCURACY_VALS
        }, MODEL_PATH)
        print(f"Epoch {e} completed! Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
        print("Model successfully saved!")

if __name__ == '__main__':
    main()


In [8]:
# ---- Evaluation configuration -------------------------------------------------
ROOT_DIR_CITYSCAPES = '/content/drive/MyDrive/Cityscapes/'
IMAGE_HEIGHT, IMAGE_WIDTH = 1024, 2048

# Checkpoint that is evaluated and whose training curves are plotted.
MODEL_PATH = f"{ROOT_DIR_CITYSCAPES}MODEL/model_epoch14.pth"

EVAL = True        # run evaluate(MODEL_PATH)
PLOT_LOSS = True   # run plot_metrics(MODEL_PATH)

def map_labels_to_colors(label, label_map=None):
    """Convert a 2-D map of class indices into an RGB image.

    Args:
        label: 2-D array-like of class indices. A PIL image or NumPy array both
            work, since it is converted with ``np.asarray``.
        label_map: mapping from index to an object with a ``.color`` RGB triple.
            Defaults to cityscapesscripts' ``trainId2label`` (``t2l``), i.e. ``label``
            holds training ids. Pass a different mapping (e.g. cityscapesscripts'
            ``id2label``) to colour raw label ids.

    Returns:
        (H, W, 3) uint8 array. Indices absent from ``label_map`` stay black.
    """
    if label_map is None:
        label_map = t2l
    label = np.asarray(label)

    color_mapped = np.zeros((label.shape[0], label.shape[1], 3), dtype=np.uint8)
    for index, info in label_map.items():
        color_mapped[label == index] = info.color
    return color_mapped

def plot_images(original_image, true_label, predicted_label, name,
                out_dir='saved_images/multiclass_1'):
    """Show image, ground truth and prediction side by side, and save the figure.

    Args:
        original_image: image accepted by ``imshow``.
        true_label, predicted_label: 2-D training-id maps, coloured with
            ``map_labels_to_colors``.
        name: stem of the saved file ``<out_dir>/<name>_comparison.png``.
        out_dir: output folder; created if it does not exist.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(original_image)
    axes[0].set_title('Original Image')

    axes[1].imshow(map_labels_to_colors(true_label))
    axes[1].set_title('Ground Truth')

    axes[2].imshow(map_labels_to_colors(predicted_label))
    axes[2].set_title('Prediction')

    fig.tight_layout()
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f'{name}_comparison.png'))
    plt.show()
    # Release the figure; otherwise every call keeps a figure alive in pyplot's state.
    plt.close(fig)

def _cityscapes_stem(paths):
    """Return the frame name (e.g. 'aachen_000000_000019') for a batch's image path.

    ``paths`` is the path element of a collated batch (a list/tuple holding one path
    such as 'aachen/aachen_000000_000019_leftImg8bit.png') or a plain path string.
    """
    rel_path = paths[0] if isinstance(paths, (list, tuple)) else paths
    base = os.path.basename(rel_path)
    suffix = '_leftImg8bit.png'
    return base[:-len(suffix)] if base.endswith(suffix) else os.path.splitext(base)[0]


def save_predictions(data, model, num_samples=5):
    """Predict on the first ``num_samples`` batches, save the predictions, and plot them.

    Assumes batch_size == 1 (tensors are squeezed to single images) and that ``data``
    yields ``(X, y, path)`` triples, as CityscapesDataset does with ``eval=True``.

    For each batch, this function:
      * accumulates pixel accuracy over pixels whose label is not 255 (ignored);
      * saves the prediction, converted to Cityscapes label ids, as
        ``saved_images/multiclass_1/<name>_prediction.png``;
      * keeps the image, ground truth and prediction (training ids) for plotting.
    Finally, the collected results are shown in random order with ``plot_images``.
    """
    model.eval()
    location = 'saved_images/multiclass_1'
    os.makedirs(location, exist_ok=True)

    # Training id -> label id lookup table, built once instead of per pixel.
    train_to_id = np.zeros(256, dtype=np.uint8)
    for train_id, info in t2l.items():
        if 0 <= train_id < 256:
            train_to_id[train_id] = info.id

    resize_image = transforms.Resize((1024, 2048))
    # Label maps need nearest-neighbour resizing; bilinear would blend class ids
    # into meaningless values.
    resize_labels = transforms.Resize((1024, 2048), interpolation=transforms.InterpolationMode.NEAREST)

    total_pixels = 0
    correct_pixels = 0
    all_results = []

    with torch.no_grad():
        for batch_index, (X, y, paths) in enumerate(tqdm(data)):
            if batch_index >= num_samples:
                break
            X = X.to(DEVICE)
            # argmax over the raw logits equals argmax over their softmax.
            pred_labels = model(X).argmax(dim=1).squeeze().cpu().numpy()
            true_labels = y.squeeze().cpu().numpy()

            # Score only labelled pixels: the model never predicts the 255 ignore value.
            labelled = true_labels != 255
            total_pixels += int(labelled.sum())
            correct_pixels += int(np.sum((pred_labels == true_labels) & labelled))

            name = _cityscapes_stem(paths)

            # The saved file holds label ids. Plotting keeps training ids, which is
            # what map_labels_to_colors colours by default.
            pred_ids_image = resize_labels(Image.fromarray(train_to_id[pred_labels]))
            pred_ids_image.save(os.path.join(location, f"{name}_prediction.png"))

            all_results.append((
                resize_image(transforms.ToPILImage()(X.squeeze().cpu())),
                resize_labels(Image.fromarray(true_labels.astype(np.uint8))),
                resize_labels(Image.fromarray(pred_labels.astype(np.uint8))),
                name,
            ))

    if total_pixels:
        print(f"Pixel-wise Accuracy: {correct_pixels / total_pixels:.4f}")
    else:
        print("Pixel-wise Accuracy: n/a (no labelled pixels)")

    random.shuffle(all_results)
    # Slice instead of range(5) so fewer results than requested do not raise IndexError.
    for original_image, true_label, pred_label, name in all_results[:num_samples]:
        plot_images(original_image, true_label, pred_label, name)

def evaluate(path, split='val'):
    """Load the checkpoint at ``path`` and run ``save_predictions`` on a Cityscapes split.

    Args:
        path: checkpoint written by ``main`` (must contain 'model_state_dict').
        split: split to evaluate. Defaults to 'val', so the reported accuracy comes
            from images the model was not trained on. Pass 'train' to score
            training images instead.
    """
    T = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=transforms.InterpolationMode.NEAREST)
    ])

    # batch_size stays at its default of 1, which save_predictions relies on.
    test_set = get_cityscapes_data(
        root_dir=ROOT_DIR_CITYSCAPES,
        split=split,
        mode='fine',
        relabelled=True,
        transforms=T,
        shuffle=True,
        eval=True
    )

    print('Data has been loaded!')

    net = UNET(in_channels=3, classes=19).to(DEVICE)
    # map_location lets a GPU-saved checkpoint load on a CPU runtime.
    checkpoint = torch.load(path, map_location=DEVICE)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    print(f'{path} has been loaded and initialized')
    save_predictions(test_set, net)

def plot_metrics(path):
    """Plot the per-epoch loss and training accuracy histories stored in a checkpoint.

    Args:
        path: checkpoint written by ``main`` (keys 'loss_values', 'accuracy_values', 'epoch').

    Raises:
        ValueError: if the loss and accuracy histories have different lengths.
    """
    # The checkpoint's model/optimizer tensors are not needed here, so map them to
    # the CPU. This also lets a GPU-saved checkpoint load on a CPU-only runtime.
    checkpoint = torch.load(path, map_location='cpu')
    losses = checkpoint['loss_values']
    accuracies = checkpoint['accuracy_values']
    epoch = checkpoint['epoch']

    epoch_list = list(range(len(losses)))

    if len(epoch_list) != len(accuracies):
        raise ValueError(f"Epochs ({len(epoch_list)}) and Accuracy values ({len(accuracies)}) must have the same length.")

    fig, ax1 = plt.subplots()

    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color='tab:red')
    ax1.plot(epoch_list, losses, color='tab:red', label='Loss')
    ax1.tick_params(axis='y', labelcolor='tab:red')

    ax2 = ax1.twinx()
    ax2.set_ylabel('Accuracy', color='tab:blue')
    ax2.plot(epoch_list, accuracies, color='tab:blue', label='Accuracy')
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    plt.title(f"Loss and Accuracy over {epoch+1} epoch/s")
    fig.tight_layout()
    plt.show()

if __name__ == '__main__':
    # Run each post-training step whose flag is enabled in the evaluation config,
    # in order: evaluation first, then the training-curve plot.
    for enabled, step in ((EVAL, evaluate), (PLOT_LOSS, plot_metrics)):
        if enabled:
            step(MODEL_PATH)

In [9]:
# Release the Colab VM once everything above has finished. This disconnects the
# runtime; TODO confirm anything needed afterwards has been saved to Drive first.
from google.colab import runtime as colab_runtime
colab_runtime.unassign()