In [33]:
%matplotlib inline
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt
import os, sys
from PIL import Image
import torch

In [34]:
# Helper functions

def load_image(infilename):
    data = mpimg.imread(infilename)
    return data


def img_float_to_uint8(img):
    rimg = img - np.min(img)
    rimg = (rimg / np.max(rimg) * 255).round().astype(np.uint8)
    return rimg


# Concatenate an image and its groundtruth
def concatenate_images(img, gt_img):
    nChannels = len(gt_img.shape)
    w = gt_img.shape[0]
    h = gt_img.shape[1]
    if nChannels == 3:
        cimg = np.concatenate((img, gt_img), axis=1)
    else:
        gt_img_3c = np.zeros((w, h, 3), dtype=np.uint8)
        gt_img8 = img_float_to_uint8(gt_img)
        gt_img_3c[:, :, 0] = gt_img8
        gt_img_3c[:, :, 1] = gt_img8
        gt_img_3c[:, :, 2] = gt_img8
        img8 = img_float_to_uint8(img)
        cimg = np.concatenate((img8, gt_img_3c), axis=1)
    return cimg

In [35]:
# print("Image size = " + str(imgs[0].shape[0]) + "," + str(imgs[0].shape[1]))

# # Show first image and its groundtruth image
# cimg = concatenate_images(imgs[0], gt_imgs[0])
# fig1 = plt.figure(figsize=(10, 10))
# plt.imshow(cimg, cmap="Greys_r")

In [36]:
NUM_CHANNELS = 3  # RGB images
PIXEL_DEPTH = 255
NUM_LABELS = 2
TRAINING_SIZE = 20
VALIDATION_SIZE = 5  # Size of the validation set.
SEED = 66478  # Set to None for random seed.
BATCH_SIZE = 16  # 64
NUM_EPOCHS = 100
RESTORE_MODEL = False  # If True, restore existing model instead of training a new one
RECORDING_STEP = 0
IMG_PATCH_SIZE = 16

out_channels = 64  # Number of filters in the first convolutional layer
kernel_size = 3  # Size of the convolutional kernel
input_size = 16  # Define input size based on your data
hidden_size = 256  # Size of the hidden layer
output_size = 2  # Number of output classes

In [56]:
def extract_data(filename, num_images):
    """Extract the images into a 4D tensor [image index, y, x, channels].
    Values are rescaled from [0, 255] down to [-0.5, 0.5].
    """
    imgs = []
    for i in range(1, num_images + 1):
        imageid = "satImage_%.3d" % i
        image_filename = filename + imageid + ".png"
        if os.path.isfile(image_filename):
            print("Loading " + image_filename)
            img = mpimg.imread(image_filename)
            imgs.append(img)
        else:
            print("File " + image_filename + " does not exist")

    num_images = len(imgs)
    IMG_WIDTH = imgs[0].shape[0]
    IMG_HEIGHT = imgs[0].shape[1]
    N_PATCHES_PER_IMAGE = (IMG_WIDTH / IMG_PATCH_SIZE) * (IMG_HEIGHT / IMG_PATCH_SIZE)

    img_patches = [
        img_crop(imgs[i], IMG_PATCH_SIZE, IMG_PATCH_SIZE) for i in range(num_images)
    ]
    data = [
        img_patches[i][j]
        for i in range(len(img_patches))
        for j in range(len(img_patches[i]))
    ]

    return np.asarray(data)

def extract_labels(filename, num_images):
    """Extract the labels into a 1-hot matrix [image index, label index]."""
    gt_imgs = []
    for i in range(1, num_images + 1):
        imageid = "satImage_%.3d" % i
        image_filename = filename + imageid + ".png"
        if os.path.isfile(image_filename):
            print("Loading " + image_filename)
            img = mpimg.imread(image_filename)
            gt_imgs.append(img)
        else:
            print("File " + image_filename + " does not exist")

    num_images = len(gt_imgs)
    gt_patches = [
        img_crop(gt_imgs[i], IMG_PATCH_SIZE, IMG_PATCH_SIZE) for i in range(num_images)
    ]
    data = np.asarray(
        [
            gt_patches[i][j]
            for i in range(len(gt_patches))
            for j in range(len(gt_patches[i]))
        ]
    )
    labels = np.asarray(
        [value_to_class(np.mean(data[i])) for i in range(len(data))]
    )

    # Convert to dense 1-hot representation.
    return labels.astype(np.float32)

In [57]:
from torch.utils.data import Dataset

class RoadSegmentationDataset(Dataset):
    
    SPLITS = {
      'train': list(range(0, 60)),
      'val':   list(range(61, 70)),
      'test':  list(range(71, 100))
    }
    
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.image_filenames = [f for f in os.listdir(os.path.join(data_dir, "images")) if f.endswith('.png')]
        self.label_filenames = [f for f in os.listdir(os.path.join(data_dir, "groundtruth")) if f.endswith('.png')]

        
    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.data_dir, "images", self.image_filenames[idx])
        label_path = os.path.join(self.data_dir, "groundtruth", self.label_filenames[idx])

        train_data = extract_data(img_path, TRAINING_SIZE)
        train_labels = extract_labels(label_path, TRAINING_SIZE)

        # Convert to PyTorch tensors
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float()
        label_tensor = torch.from_numpy(label).float()

        return img_tensor, label_tensor

In [58]:
# Adjust number of samples
data_dir = "../data/training/"
train_data_filename = data_dir + "images/"
train_labels_filename = data_dir + "groundtruth/"

train_data = extract_data(train_data_filename, TRAINING_SIZE)
train_labels = extract_labels(train_labels_filename, TRAINING_SIZE)

num_epochs = NUM_EPOCHS

c0 = np.sum(train_labels[:, 0] == 1)
c1 = np.sum(train_labels[:, 1] == 1)

print(f"Number of data points per class: c0 = {c0} c1 = {c1}")

min_c = min(c0, c1)
idx0 = np.where(train_labels[:, 0] == 1)[0][:min_c]
idx1 = np.where(train_labels[:, 1] == 1)[0][:min_c]

new_indices = np.concatenate((idx0, idx1))
print(len(new_indices))

train_data = train_data[new_indices]
train_labels = train_labels[new_indices]

train_size = train_labels.shape[0]

c0 = np.sum(train_labels[:, 0] == 1)
c1 = np.sum(train_labels[:, 1] == 1)
print(f"Number of data points per class after balancing: c0 = {c0} c1 = {c1}")

Loading ../data/training/images/satImage_001.png
Loading ../data/training/images/satImage_002.png
Loading ../data/training/images/satImage_003.png
Loading ../data/training/images/satImage_004.png
Loading ../data/training/images/satImage_005.png
Loading ../data/training/images/satImage_006.png
Loading ../data/training/images/satImage_007.png
Loading ../data/training/images/satImage_008.png
Loading ../data/training/images/satImage_009.png
Loading ../data/training/images/satImage_010.png
Loading ../data/training/images/satImage_011.png
Loading ../data/training/images/satImage_012.png
Loading ../data/training/images/satImage_013.png
Loading ../data/training/images/satImage_014.png
Loading ../data/training/images/satImage_015.png
Loading ../data/training/images/satImage_016.png
Loading ../data/training/images/satImage_017.png
Loading ../data/training/images/satImage_018.png
Loading ../data/training/images/satImage_019.png
Loading ../data/training/images/satImage_020.png
Loading ../data/trai

In [59]:
dataset = RoadSegmentationDataset(data_dir)

# Test the dataset initialization
sample_img, sample_label = dataset[0]
print(f"Sample Image Shape: {sample_img.shape}")
print(f"Sample Label Shape: {sample_label.shape}")

File ../data/training/images\satImage_001.pngsatImage_001.png does not exist
File ../data/training/images\satImage_001.pngsatImage_002.png does not exist
File ../data/training/images\satImage_001.pngsatImage_003.png does not exist
File ../data/training/images\satImage_001.pngsatImage_004.png does not exist
File ../data/training/images\satImage_001.pngsatImage_005.png does not exist
File ../data/training/images\satImage_001.pngsatImage_006.png does not exist
File ../data/training/images\satImage_001.pngsatImage_007.png does not exist
File ../data/training/images\satImage_001.pngsatImage_008.png does not exist
File ../data/training/images\satImage_001.pngsatImage_009.png does not exist
File ../data/training/images\satImage_001.pngsatImage_010.png does not exist
File ../data/training/images\satImage_001.pngsatImage_011.png does not exist
File ../data/training/images\satImage_001.pngsatImage_012.png does not exist
File ../data/training/images\satImage_001.pngsatImage_013.png does not exist

IndexError: list index out of range