In [1]:
import os
import random

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from torchvision.transforms.functional import to_pil_image
from sklearn.metrics import f1_score, accuracy_score

# personal coded classes
from unet import UNet
from apply_augmentation import DataAugmentation
from dynamic_augmentation_pipeline import RoadSegmentationDataset
from test_model import test_model
from train_workflow import train_workflow
from train_workflow import train_workflow

In [2]:
# Dataset locations, relative to the notebook's working directory.
training_dir = "training/"
testing_dir = "test_set_images/"

# Training inputs and their ground-truth masks live in sibling sub-folders of training_dir.
image_dir, gt_dir = [os.path.join(training_dir, sub) for sub in ("images/", "groundtruth/")]

In [3]:
def load_images_and_masks(image_dir, gt_dir, valid_extensions=(".png", ".jpg", ".jpeg")):
    """
    Load images and their corresponding ground truth masks.

    Images and masks are paired by file name: after sorting, the i-th image and the
    i-th mask must have the same name stem (extension may differ). Each file is
    decoded immediately and its file handle closed, instead of keeping one handle
    open per returned image.

    Args:
        image_dir (str): Directory containing the input images.
        gt_dir (str): Directory containing the ground truth masks.
        valid_extensions (tuple or str): Valid image file extension(s), matched
            case-insensitively.

    Returns:
        list: List of images as PIL.Image objects.
        list: List of ground truth masks as PIL.Image objects.
        On any error the message is printed and two empty lists are returned.
    """

    def _list_files(directory, extensions):
        # Sorted so that images and masks line up deterministically.
        return sorted(f for f in os.listdir(directory) if f.lower().endswith(extensions))

    def _stem(name):
        return os.path.splitext(name)[0]

    def _load(path):
        # Image.open is lazy and keeps the file open; load() decodes the pixels so the
        # image remains usable after the context manager closes the file.
        with Image.open(path) as img:
            img.load()
        return img

    try:
        if isinstance(valid_extensions, str):
            valid_extensions = (valid_extensions,)
        extensions = tuple(ext.lower() for ext in valid_extensions)

        image_files = _list_files(image_dir, extensions)
        gt_files = _list_files(gt_dir, extensions)

        # Check if the number of images matches the number of masks
        if len(image_files) != len(gt_files):
            raise ValueError("Mismatch between number of images and ground truth masks.")

        # Equal counts are not enough: one missing or misnamed file would silently
        # shift every image/mask pair after it.
        mismatch = next(
            ((img_f, gt_f) for img_f, gt_f in zip(image_files, gt_files) if _stem(img_f) != _stem(gt_f)),
            None,
        )
        if mismatch is not None:
            raise ValueError(
                f"Image and ground truth file names do not match: {mismatch[0]!r} vs {mismatch[1]!r}."
            )

        images = [_load(os.path.join(image_dir, f)) for f in image_files]
        gt_masks = [_load(os.path.join(gt_dir, f)) for f in gt_files]

        print(f"Loaded {len(images)} images and {len(gt_masks)} ground truth masks.")
        return images, gt_masks

    except Exception as e:
        # Best-effort contract kept from the original: report and return empty lists.
        print(f"Error loading images and masks: {e}")
        return [], []

In [4]:
images, gt_masks = load_images_and_masks(image_dir, gt_dir)
# load_images_and_masks already prints the counts. On failure it prints the error and
# returns empty lists instead of raising, so stop here rather than fail obscurely later.
assert len(images) > 0 and len(images) == len(gt_masks), "No training images/masks were loaded (see the error above)."

Loaded 100 images and 100 ground truth masks.
Loaded 100 images and 100 ground truth masks.


In [5]:
# Key model parameters
IMG_PATCH_SIZE = 16  # Patch side in pixels; must evenly divide the image width and height

# Data parameters
VALIDATION_SIZE = 0.2  # Fraction of patches held out for validation
TRAINING_SIZE = 1.0 - VALIDATION_SIZE  # Derived, so the two fractions can never disagree
BATCH_SIZE = 32  # Larger batch sizes are preferred if memory allows; smaller for fine-tuning

# Training parameters
NUM_EPOCHS = 50  # Start with 50 and adjust based on observed convergence behavior
LEARNING_RATE = 1e-4  # starting point for Adam optimizer
WEIGHT_DECAY = 1e-4  # L2 regularization to prevent overfitting
PATIENCE = 10  # Stop training if no improvement for 10 epochs

# Randomization
SEED = 66478  # Fixed seed for reproducibility (set to None for random seed)
if SEED is not None:
    # train_test_split is the only consumer of SEED otherwise. Also seed the global RNGs
    # so weight initialisation, DataLoader shuffling and augmentation are reproducible.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)  # seeds the RNGs of all devices, CUDA included

# Model restoration
RESTORE_MODEL = False  # Set to True if you want to resume training from a previous checkpoint

In [6]:
# Extract patches from a given image
def img_crop(im, w, h):
    """Split an array into rectangular patches.

    Axis 0 is cut into steps of ``w`` and axis 1 into steps of ``h``. Any trailing
    axes (e.g. channels) are kept whole. Patches are ordered column-block first:
    for each axis-1 offset, all axis-0 offsets in increasing order. Edge patches
    may be smaller when the sizes do not divide evenly.

    Returns:
        list: The patches, as slices (views) of ``im``.
    """
    n_axis0, n_axis1 = im.shape[0], im.shape[1]
    # Basic slicing on the first two axes leaves any trailing channel axis intact.
    return [
        im[top:top + w, left:left + h]
        for left in range(0, n_axis1, h)
        for top in range(0, n_axis0, w)
    ]

def extract_data(filename, num_images, patch_size=None):
    """Load numbered satellite images and cut them into square patches.

    Tries to read ``satImage_001.png`` ... ``satImage_<num_images>.png`` from the
    directory ``filename`` with ``mpimg.imread``. Missing files are reported and
    skipped. Each loaded image is split with ``img_crop``.

    No value rescaling is applied: pixels are returned exactly as ``mpimg.imread``
    produced them (for PNG files matplotlib returns floats in [0, 1]).

    Args:
        filename (str): Directory containing the images (despite the parameter name).
        num_images (int): Number of sequentially numbered images to try to load.
        patch_size (int, optional): Patch side length in pixels. Defaults to the
            global ``IMG_PATCH_SIZE``.

    Returns:
        np.ndarray: All patches stacked along a new first axis, i.e.
            ``[patch index, patch rows, patch cols, channels]`` for RGB input.
            Patches of the first loaded image come first, in ``img_crop`` order.
    """
    if patch_size is None:
        patch_size = IMG_PATCH_SIZE

    imgs = []
    for i in range(1, num_images + 1):
        image_filename = os.path.join(filename, "satImage_%.3d.png" % i)
        if os.path.isfile(image_filename):
            print("Loading " + image_filename)
            imgs.append(mpimg.imread(image_filename))
        else:
            print("File " + image_filename + " does not exist")

    # Flatten the per-image patch lists into one list (image order, then img_crop order).
    data = []
    for img in imgs:
        data.extend(img_crop(img, patch_size, patch_size))

    return np.asarray(data)

# Assign a label to a patch v
def value_to_class(v, foreground_threshold=0.25):
    """
    Assign a one-hot label to a patch based on its mean ground-truth value.

    Args:
        v (array-like): Ground-truth patch. Its mean is used as the road fraction.
            This assumes road pixels are ~1 and background ~0 (TODO confirm mask encoding).
        foreground_threshold (float): Fraction (not percentage) of road pixels the
            patch must strictly exceed to be labelled road. Defaults to 0.25.

    Returns:
        list: ``[0, 1]`` for road, ``[1, 0]`` for background.
    """
    road_fraction = np.mean(v)
    return [0, 1] if road_fraction > foreground_threshold else [1, 0]

# Extract label images
def extract_labels(filename, num_images, patch_size=None):
    """Load numbered ground-truth masks and convert every patch to a class index.

    Tries to read ``satImage_001.png`` ... ``satImage_<num_images>.png`` from the
    directory ``filename``. Missing files are reported and skipped. Each mask is
    split with ``img_crop`` in the same order ``extract_data`` uses for the images.
    A patch is labelled 1 (road) when its mean exceeds the ``value_to_class``
    threshold (0.25), and 0 (background) otherwise.

    Args:
        filename (str): Directory containing the masks (despite the parameter name).
        num_images (int): Number of sequentially numbered masks to try to load.
        patch_size (int, optional): Patch side length in pixels. Defaults to the
            global ``IMG_PATCH_SIZE``.

    Returns:
        np.ndarray: 1-D int64 array with one class index per patch (not one-hot).
    """
    if patch_size is None:
        patch_size = IMG_PATCH_SIZE

    gt_imgs = []
    for i in range(1, num_images + 1):
        image_filename = os.path.join(filename, "satImage_%.3d.png" % i)
        if os.path.isfile(image_filename):
            print("Loading " + image_filename)
            gt_imgs.append(mpimg.imread(image_filename))
        else:
            print("File " + image_filename + " does not exist")

    # NOTE(review): np.mean (inside value_to_class) averages every element of the patch.
    # This assumes masks load as single-channel 2-D arrays; an RGB(A) mask would mix in
    # extra channels. Confirm.
    labels = [
        value_to_class(patch)[1]  # one-hot [background, road] -> the road flag is the class index
        for gt_img in gt_imgs
        for patch in img_crop(gt_img, patch_size, patch_size)
    ]

    # Return as integer class indices
    return np.asarray(labels, dtype=np.int64)

In [7]:
# Extract patches and labels (one class label per patch)
image_patches = extract_data(image_dir, len(images))
label_patches = extract_labels(gt_dir, len(gt_masks))

# Both extractors skip missing files independently, so verify that patches and
# labels still line up one-to-one before training on them.
assert image_patches.shape[0] == label_patches.shape[0], (
    f"{image_patches.shape[0]} image patches vs {label_patches.shape[0]} labels"
)

print(f"Extracted {image_patches.shape[0]} image patches of size {IMG_PATCH_SIZE}x{IMG_PATCH_SIZE}")
print(f"Extracted {label_patches.shape[0]} patch labels (one class index per patch)")

Loading training/images/satImage_001.png
Loading training/images/satImage_002.png
Loading training/images/satImage_003.png
Loading training/images/satImage_004.png
Loading training/images/satImage_005.png
Loading training/images/satImage_006.png
Loading training/images/satImage_007.png
Loading training/images/satImage_008.png
Loading training/images/satImage_009.png
Loading training/images/satImage_010.png
Loading training/images/satImage_011.png
Loading training/images/satImage_012.png
Loading training/images/satImage_013.png
Loading training/images/satImage_014.png
Loading training/images/satImage_015.png
Loading training/images/satImage_016.png
Loading training/images/satImage_017.png
Loading training/images/satImage_018.png
Loading training/images/satImage_019.png
Loading training/images/satImage_020.png
Loading training/images/satImage_021.png
Loading training/images/satImage_022.png
Loading training/images/satImage_023.png
Loading training/images/satImage_024.png
Loading training

In [8]:
# Convert image and label patches to PyTorch tensors.
# The isinstance guards make the cell safe to re-run. Without them, a second run would call
# torch.from_numpy on a tensor (TypeError) or permute the channel axis twice.
if isinstance(image_patches, np.ndarray):
    # [N, H, W, C] (NumPy / matplotlib layout) -> [N, C, H, W] (PyTorch layout)
    image_patches = torch.from_numpy(image_patches).float().permute(0, 3, 1, 2)
if isinstance(label_patches, np.ndarray):
    label_patches = torch.from_numpy(label_patches).long()  # Shape: [num_patches]

# Print the shapes of tensors
print(f"Image Patches Tensor Shape: {image_patches.shape}")
print(f"Label Patches Tensor Shape: {label_patches.shape}")

Image Patches Tensor Shape: torch.Size([62500, 3, 16, 16])
Label Patches Tensor Shape: torch.Size([62500])


In [9]:
# Split into training and validation datasets.
# stratify keeps the road/background ratio approximately the same in both splits,
# which matters because the classes are imbalanced.
# NOTE(review): patches are split individually, so neighbouring patches of the same
# satellite image can land in both sets, which may make validation scores optimistic.
# Consider splitting by image instead. TODO
train_images, val_images, train_labels, val_labels = train_test_split(
    image_patches,
    label_patches,
    test_size=VALIDATION_SIZE,
    random_state=SEED,
    stratify=np.asarray(label_patches),
)

# Print the sizes of training and validation sets
print(f"Training Images: {train_images.shape[0]}, Training Labels: {train_labels.shape[0]}")
print(f"Validation Images: {val_images.shape[0]}, Validation Labels: {val_labels.shape[0]}")

Training Images: 50000, Training Labels: 50000
Validation Images: 12500, Validation Labels: 12500


In [10]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [11]:
# Build the U-Net, then place it on the selected device.
# (Assuming UNet is a torch.nn.Module, .to() moves it in place and returns the same module.)
model = UNet()
model = model.to(device)
print("U-Net model instantiated and moved to:", device)

U-Net model instantiated and moved to: cuda


In [12]:
# Class weights for the imbalanced background/road labels, computed from the training
# split with the inverse-frequency ("balanced") rule  w_c = N / (num_classes * N_c)
# instead of being hardcoded. With ~25% road patches this gives roughly [0.67, 1.99].
NUM_CLASSES = 2
class_counts = torch.bincount(torch.as_tensor(train_labels), minlength=NUM_CLASSES).float()
# clamp avoids a division by zero if a class is absent from the split.
class_weights = (class_counts.sum() / (NUM_CLASSES * class_counts.clamp(min=1))).to(device)
print(f"Class counts: {class_counts.tolist()} -> class weights: {class_weights.tolist()}")
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

In [13]:
# Adam optimiser; weight_decay adds an L2 penalty on the parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    weight_decay=WEIGHT_DECAY,
    lr=LEARNING_RATE,
)

# Halve the learning rate after 3 epochs without a (relative, 1e-4) improvement of the
# monitored metric. mode="max" because higher is better for that metric, presumably the
# validation F1 score that train_workflow passes to scheduler.step(). Confirm.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3, threshold=1e-4
)

In [14]:
# Augmentation object for the training split only (validation gets augmentations=None).
# NOTE(review): 400x400 is presumably the full satellite image size, while the datasets
# below hold IMG_PATCH_SIZE patches. Confirm how DataAugmentation uses these dimensions.
data_augmentation = DataAugmentation(img_width=400, img_height=400, patch_size=IMG_PATCH_SIZE)

train_dataset = RoadSegmentationDataset(images=train_images, labels=train_labels, augmentations=data_augmentation)
val_dataset = RoadSegmentationDataset(images=val_images, labels=val_labels, augmentations=None)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader, val_loader = (
    DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False),
)

In [15]:
# Run the training workflow (project module). Its return value is kept as `history`.
# save_path presumably receives the best checkpoint. Confirm in train_workflow.
# NOTE(review): the last recorded run of this cell raised
#   IndexError: invalid index of a 0-dim tensor
# from inside train_workflow or code it calls (traceback not shown). This is likely a scalar
# tensor, e.g. a loss or a per-sample label, being indexed like [0]; use .item() there. TODO
history = train_workflow(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=NUM_EPOCHS,
    patience=PATIENCE,
    save_path="road_segmentation_best_model.pth",
)

IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number

In [16]:
def load_test_images(testing_dir, valid_extensions=(".png", ".jpg", ".jpeg")):
    """
    Load test images from the directory.

    Expects one sub-folder per test image, with the image inside named after its
    folder: ``<testing_dir>/<name>/<name><ext>``. For each folder, the extensions in
    ``valid_extensions`` are tried in order and the first existing file is used.
    Folders without a matching file, and non-folder entries, are skipped.

    Args:
        testing_dir (str): Path to the directory containing test image folders.
        valid_extensions (tuple or str): Extension(s) to try, in priority order.

    Returns:
        list: List of test images as RGB PIL.Image objects, detached from their files.
        list: Corresponding folder names for identification, aligned with the images.
    """
    if isinstance(valid_extensions, str):
        valid_extensions = (valid_extensions,)

    test_images = []
    test_filenames = []

    for folder in sorted(os.listdir(testing_dir)):
        folder_path = os.path.join(testing_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        # First "<folder><ext>" that exists, trying extensions in the given order.
        candidates = (os.path.join(folder_path, folder + ext) for ext in valid_extensions)
        image_path = next((path for path in candidates if os.path.isfile(path)), None)
        if image_path is None:
            continue

        try:
            # convert() returns an independent image, so the source file can be closed right away.
            with Image.open(image_path) as img:
                rgb_img = img.convert("RGB")  # Ensure images are RGB
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            continue
        test_images.append(rgb_img)
        test_filenames.append(folder)  # Save folder name for identification

    print(f"Loaded {len(test_images)} test images.")
    return test_images, test_filenames

In [32]:
# Convert array of labels to an image
def label_to_img(imgwidth, imgheight, w, h, labels):
    """Paint per-patch predictions back into a full-size label map.

    Patches are consumed in the same order ``img_crop`` produces them: the outer
    loop runs over axis 1 in steps of ``h``, the inner loop over axis 0 in steps of ``w``.

    Args:
        imgwidth (int): Size of axis 0 of the output map.
        imgheight (int): Size of axis 1 of the output map.
        w (int): Patch extent along axis 0.
        h (int): Patch extent along axis 1.
        labels (sequence): One entry per patch whose first element is the background
            score: above 0.5 -> background (0), otherwise road (1).

    Returns:
        np.ndarray: Float array of shape ``(imgwidth, imgheight)`` filled with 0/1.
    """
    label_map = np.zeros([imgwidth, imgheight])
    offsets = ((top, left) for left in range(0, imgheight, h) for top in range(0, imgwidth, w))
    for patch_idx, (top, left) in enumerate(offsets):
        is_background = labels[patch_idx][0] > 0.5
        label_map[top:top + w, left:left + h] = 0 if is_background else 1
    return label_map

In [None]:
def img_float_to_uint8(img, pixel_depth=255):
    """Linearly stretch an array to ``[0, pixel_depth]`` and cast it to uint8.

    The minimum maps to 0 and the maximum to ``pixel_depth``. A constant array
    (zero dynamic range) maps to all zeros instead of dividing by zero.

    Args:
        img (array-like): Numeric array of any shape.
        pixel_depth (int): Value the maximum maps to; 255 for 8-bit images.
            It used to come from a ``PIXEL_DEPTH`` global that this notebook never defines.

    Returns:
        np.ndarray: uint8 array with the same shape as ``img``.
    """
    img = np.asarray(img)
    rimg = img - np.min(img)
    peak = np.max(rimg)
    if peak == 0:
        # Constant image: nothing to stretch, and 0/0 would produce NaNs.
        return np.zeros(rimg.shape, dtype=np.uint8)
    return (rimg / peak * pixel_depth).round().astype(np.uint8)

In [None]:
def concatenate_images(img, gt_img):
    """Place ``img`` and its ground truth side by side (along axis 1).

    If ``gt_img`` already has three axes, both arrays are concatenated as-is.
    Otherwise the 2-D ground truth is stretched to uint8 with
    ``img_float_to_uint8`` and replicated into three identical channels. ``img``
    is also converted to uint8 so both halves share a dtype.

    Returns:
        np.ndarray: The concatenated array.
    """
    if len(gt_img.shape) == 3:
        return np.concatenate((img, gt_img), axis=1)

    rows, cols = gt_img.shape[0], gt_img.shape[1]
    gt_gray = img_float_to_uint8(gt_img)
    gt_rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for channel in range(3):
        gt_rgb[:, :, channel] = gt_gray
    return np.concatenate((img_float_to_uint8(img), gt_rgb), axis=1)

In [None]:
def make_img_overlay(img, predicted_img):
    """
    Blend a red prediction mask 50/50 over the input image.

    ``predicted_img`` is cast to uint8 and multiplied by 255 into the red channel of
    an otherwise black RGB mask. ``img`` is stretched to uint8 with
    ``img_float_to_uint8``. Both are converted to RGBA and blended with alpha 0.5.

    Returns:
        PIL.Image.Image: The blended RGBA image.
    """
    rows, cols = img.shape[:2]
    red_mask = np.zeros((rows, cols, 3), dtype=np.uint8)
    red_mask[..., 0] = predicted_img.astype(np.uint8) * 255  # Red channel for predictions

    base_rgba = Image.fromarray(img_float_to_uint8(img), "RGB").convert("RGBA")
    mask_rgba = Image.fromarray(red_mask, "RGB").convert("RGBA")
    return Image.blend(base_rgba, mask_rgba, 0.5)