In [1]:
pip install rasterio



In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import models
from torchvision import tv_tensors
import torchvision.transforms as transforms
from torchvision.transforms import Compose, Resize, v2
from torchvision.transforms.functional import to_tensor, hflip, vflip, rotate, adjust_gamma
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

import os
import time
import random

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support

from PIL import Image
import rasterio

In [3]:
class HarveyData(Dataset):
    """Hurricane Harvey flood-damage segmentation dataset, held fully in memory.

    ``dataset_dir`` (e.g. "./dataset/training" or "./dataset/testing") must contain the
    sub-directories ``pre_img``, ``post_img``, ``PDE_labels``, ``elevation``, ``hand`` and
    ``imperviousness``. Files are paired across sub-directories by their position after
    sorting by file name, so every sub-directory must hold the same number of files.

    Args:
        dataset_dir: Path of the dataset split.
        image_size: Side length (pixels) that images, masks and rasters are resized to.
        augment_data: Whether ``__getitem__`` applies random augmentations. Use False for
            evaluation splits so metrics are computed on unaltered data.
    """

    # Normalization values used by the pretrained weights in DeepLabv3.
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)
    # Augmentation ids: 0 vflip, 1 hflip, 2 zoom, 3 gamma, 4 elastic, 5 color jitter, 6 rotate.
    NUM_AUGMENTATIONS = 7

    def __init__(self, dataset_dir, image_size=224, augment_data=True):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.image_size = image_size
        self.augment_data = augment_data

        listing = {
            subdir: sorted(os.listdir(os.path.join(dataset_dir, subdir)))
            for subdir in ('pre_img', 'post_img', 'PDE_labels', 'elevation', 'hand', 'imperviousness')
        }
        counts = {subdir: len(files) for subdir, files in listing.items()}
        if len(set(counts.values())) > 1:
            # Pairing is by sorted position, so unequal counts would silently mismatch samples.
            raise ValueError(f'Sub-directories of {dataset_dir} hold different numbers of files: {counts}')

        self.pre_image_paths = listing['pre_img']
        self.post_image_paths = listing['post_img']
        self.mask_paths = listing['PDE_labels']
        self.elevation_paths = listing['elevation']
        self.hand_paths = listing['hand']
        self.imperviousness_paths = listing['imperviousness']
        self.num_images = len(self.pre_image_paths)

        def full_path(subdir, name):
            return os.path.join(dataset_dir, subdir, name)

        # RGB conversion guarantees the 3 channels the normalization in __getitem__ expects.
        self.pre_images = [self._load_image(full_path('pre_img', f), 'RGB') for f in self.pre_image_paths]
        self.post_images = [self._load_image(full_path('post_img', f), 'RGB') for f in self.post_image_paths]
        self.masks = [self._load_image(full_path('PDE_labels', f), 'L') for f in self.mask_paths]

        self.elevation = [self._load_raster(full_path('elevation', f)) for f in self.elevation_paths]
        self.hand = [self._load_raster(full_path('hand', f)) for f in self.hand_paths]
        self.imperviousness = [self._load_raster(full_path('imperviousness', f)) for f in self.imperviousness_paths]

    @staticmethod
    def _load_image(path, mode):
        """Fully decode the image at ``path`` into ``mode`` and close the file.

        ``Image.open`` alone is lazy and keeps the file handle open, which can exhaust
        file descriptors when a whole dataset is cached in memory.
        """
        with Image.open(path) as img:
            return img.convert(mode)

    @staticmethod
    def _load_raster(path):
        """Read band 1 of the raster at ``path`` as a (1, H, W) tensor."""
        with rasterio.open(path) as src:
            return torch.tensor(src.read(1)).unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``(combined_image, mask)`` for sample ``idx``.

        ``combined_image`` is the normalized pre- and post-disaster RGB images stacked along
        the channel dimension (6 channels). ``mask`` is a (1, image_size, image_size) int64
        tensor of label ids.
        """
        to_float = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
        pre_image = to_float(self.pre_images[idx])
        post_image = to_float(self.post_images[idx])
        elevation = to_float(self.elevation[idx])
        hand = to_float(self.hand[idx])
        imperviousness = to_float(self.imperviousness[idx])
        # Wrapping the labels as a Mask makes every v2 geometric transform below use
        # nearest-neighbour interpolation, so no in-between (invalid) class ids appear.
        mask = tv_tensors.Mask(torch.as_tensor(np.array(self.masks[idx]), dtype=torch.int64).unsqueeze(0))

        resize = v2.Resize((self.image_size, self.image_size), antialias=True)
        layers = list(resize(pre_image, post_image, elevation, hand, imperviousness, mask))

        if self.augment_data:
            layers = self._augment(layers)

        pre_image, post_image, elevation, hand, imperviousness, mask = layers

        # Normalize last: the photometric augmentations expect values in [0, 1].
        normalize = v2.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)
        pre_image = normalize(pre_image)
        post_image = normalize(post_image)

        #Concatenate the pre and post disaster images together along the channel dimension.
        #combined_image = torch.cat([pre_image, post_image, elevation, imperviousness], dim=0)
        combined_image = torch.cat([pre_image, post_image], dim=0)
        return combined_image, mask.as_subclass(torch.Tensor)

    def _augment(self, layers):
        """Randomly augment ``layers`` = [pre, post, elevation, hand, imperviousness, mask].

        One augmentation is always applied and, with probability 0.5, a second distinct one.
        Geometric augmentations transform every layer with the same random parameters, so
        images, rasters and labels stay aligned. Photometric augmentations only touch the
        pre/post RGB images (indices 0 and 1).
        """
        modes = [np.random.randint(self.NUM_AUGMENTATIONS)]
        if np.random.random() > 0.5:
            modes.append(np.random.choice([m for m in range(self.NUM_AUGMENTATIONS) if m != modes[0]]))

        layers = list(layers)
        for mode in modes:
            if mode == 0:
                # flip vertically
                layers = [v2.functional.vertical_flip(t) for t in layers]
            elif mode == 1:
                # flip horizontally
                layers = [v2.functional.horizontal_flip(t) for t in layers]
            elif mode == 2:
                # zoom: a single call so every layer shares the same crop box
                zoom = v2.RandomResizedCrop(self.image_size, antialias=True)
                layers = list(zoom(*layers))
            elif mode == 3:
                # gamma on the RGB images only (same gamma for pre and post)
                gamma = 2.25 * np.random.random() + 0.25
                layers[0] = v2.functional.adjust_gamma(layers[0], gamma)
                layers[1] = v2.functional.adjust_gamma(layers[1], gamma)
            elif mode == 4:
                # elastic deformation: a single call so every layer shares one displacement field
                elastic = v2.ElasticTransform(sigma=10)
                layers = list(elastic(*layers))
            elif mode == 5:
                # brightness/contrast/saturation/hue; separate calls, so pre and post
                # receive independently sampled jitter
                jitter = v2.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25)
                layers[0] = jitter(layers[0])
                layers[1] = jitter(layers[1])
            elif mode == 6:
                # rotate every layer by the same angle
                # NOTE(review): uncovered corners are filled with 0 in all layers, including
                # the mask (label 0) — confirm that label is appropriate for padding.
                angle = random.randint(1, 359)
                layers = [v2.functional.rotate(t, angle) for t in layers]
        return layers

    def get_item_resize_only(self, idx, image_size):
        """Return an un-augmented, un-normalized ``(combined_image, mask)`` for display.

        ``combined_image`` stacks pre RGB, post RGB, elevation and imperviousness (8 channels).
        The rasters are converted as in ``__getitem__``. ``mask`` is a float (1, H, W) tensor
        holding the original integer label values.
        """
        pre_image = to_tensor(self.pre_images[idx])
        post_image = to_tensor(self.post_images[idx])

        # to_tensor scales 8-bit data to [0, 1]; scale back and round to recover exact label ids.
        mask = (to_tensor(self.masks[idx]) * 255).round()

        # The rasters are already tensors (to_tensor only accepts PIL images / ndarrays).
        elevation = v2.functional.to_dtype(self.elevation[idx], torch.float32, scale=True)
        imperviousness = v2.functional.to_dtype(self.imperviousness[idx], torch.float32, scale=True)

        size = (image_size, image_size)
        resize = v2.Resize(size, antialias=True)
        # Nearest-neighbour keeps the mask restricted to valid class ids.
        resize_mask = v2.Resize(size, interpolation=transforms.InterpolationMode.NEAREST)

        pre_image = resize(pre_image)
        post_image = resize(post_image)
        elevation = resize(elevation)
        imperviousness = resize(imperviousness)
        mask = resize_mask(mask)

        combined_image = torch.cat([pre_image, post_image, elevation, imperviousness], dim=0)
        return combined_image, mask

    def __len__(self):
        return self.num_images

In [4]:
class DeepLabV3(nn.Module):
    """torchvision DeepLabv3-ResNet101 adapted to multi-channel input and a custom class count.

    The model is built with ``DeepLabV3_ResNet101_Weights.DEFAULT``. Its first backbone
    convolution is replaced to accept ``num_input_channels`` channels, and the last layer of
    its classifier head is replaced to output ``num_classes`` channels.

    Args:
        num_input_channels: Number of channels of the input tensor.
        num_classes: Number of segmentation classes (output channels).
        reuse_pretrained_stem: If True, initialise the new first convolution from the
            pretrained filters, tiled across the input channels and rescaled by
            ``pretrained_in / num_input_channels`` so activations keep a similar magnitude.
            The pretrained layers that follow then still receive familiar features. If
            False, the new convolution keeps PyTorch's default random initialisation.
    """

    def __init__(self, num_input_channels, num_classes, reuse_pretrained_stem=True):
        super(DeepLabV3, self).__init__()
        self.deeplabv3_weights = torchvision.models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT
        self.resnet101_weights = models.ResNet101_Weights.DEFAULT
        self.deeplabv3 = torchvision.models.segmentation.deeplabv3_resnet101(weights=self.deeplabv3_weights, weights_backbone=self.resnet101_weights)

        # Replace the first backbone convolution so it accepts num_input_channels channels.
        pretrained_conv1 = self.deeplabv3.backbone.conv1
        new_conv1 = nn.Conv2d(num_input_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        if reuse_pretrained_stem:
            with torch.no_grad():
                pretrained_weight = pretrained_conv1.weight
                pretrained_in = pretrained_weight.shape[1]
                repeats = -(-num_input_channels // pretrained_in)  # ceil division
                tiled = pretrained_weight.repeat(1, repeats, 1, 1)[:, :num_input_channels]
                new_conv1.weight.copy_(tiled * (pretrained_in / num_input_channels))
        self.deeplabv3.backbone.conv1 = new_conv1

        # Replace the final classifier layer so it outputs num_classes channels.
        self.deeplabv3.classifier[-1] = torch.nn.Conv2d(in_channels=256, out_channels=num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        """Run the wrapped model and return its output unchanged (callers read ``['out']``)."""
        # Call the module itself, not .forward(), so registered hooks run.
        return self.deeplabv3(x)

In [5]:
class FocalLoss(nn.Module):
    """Multi-class focal loss built on ``F.cross_entropy``.

    Per element: ``alpha_t * (1 - p_t) ** gamma * CE``, where ``CE`` is the cross-entropy
    and ``p_t = exp(-CE)`` is the predicted probability of the target class.

    Args:
        gamma: Focusing parameter; 0 reduces to (alpha-weighted) cross-entropy.
        alpha: Optional weighting. A scalar multiplies every element. A sequence or 1-D
            tensor of length C is used as per-class weights, indexed by each element's target.
        reduction: 'mean' or 'sum'; any other value returns the unreduced per-element loss.
    """

    def __init__(self, gamma=2, alpha=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Raw class scores (logits), class dimension at index 1, as required by
                ``F.cross_entropy`` (which applies log-softmax itself — do not pre-softmax).
            target: Integer class indices with the shape of ``input`` minus the class dim.
        """
        ce_loss = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            alpha = torch.as_tensor(self.alpha, dtype=focal_loss.dtype, device=focal_loss.device)
            if alpha.ndim > 0:
                # Per-class weights: pick each element's weight by its target class. The clamp
                # keeps ignore_index (-100) entries indexable; their CE is 0 anyway.
                alpha = alpha[target.clamp(min=0)]
            focal_loss = alpha * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [6]:
def visualize_results(num_results, predictions, images=None, masks=None, randomize_images=False):
    fig, axes = plt.subplots(num_results, 3, figsize=(32, 32))

    predictions_flat = [item for sublist in predictions for item in sublist]
    if (images != None):
        images_flat = [item for sublist in images for item in sublist]
    if (masks != None):
        masks_flat = [item for sublist in masks for item in sublist]

    if (randomize_images):
        # Choose num_results number of images at random from the results.
        image_idxs = random.sample(range(0, len(predictions_flat) - 1), num_results)
    else:
        image_idxs = [i for i in range(1, num_results + 2)]

    for i in range(num_results):
        # Plot the input image and ground truth mask
        if (images == None or masks == None):
            image, mask = test_dataset.get_item_resize_only(image_idxs[i], image_size)

            #Reorder the channels for matplotlib.
            image = torch.permute(image, (1, 2, 0))
            mask = torch.permute(mask, (1, 2, 0))

            axes[i, 0].imshow(image.numpy()[:, :, 0:3], aspect='equal')
            axes[i, 0].imshow(image.numpy()[:, :, 3:6], alpha=0.5, aspect='equal')
            #axes[i, 0].imshow(image.numpy()[:, :, 6:7], alpha=0.5, aspect='equal')
            #axes[i, 0].imshow(image.numpy()[:, :, 7:8], alpha=0.5, aspect='equal')
            axes[i, 2].imshow(mask.numpy(), cmap="viridis", aspect='equal')
        else:
            image = images_flat[image_idxs[i]]
            mask = masks_flat[image_idxs[i]]

            #Reorder the channels for matplotlib.
            image = np.transpose(image, (1, 2, 0))
            #mask = np.transpose(mask, (1, 2, 0))

            axes[i, 0].imshow(image[:, :, 0:3], aspect='equal')
            axes[i, 0].imshow(image[:, :, 3:6], alpha=0.5, aspect='equal')
            #axes[i, 0].imshow(image[:, :, 6:7], alpha=0.5, aspect='equal')
            #axes[i, 0].imshow(image[:, :, 7:8], alpha=0.5, aspect='equal')
            axes[i, 2].imshow(mask, cmap="viridis", aspect='equal')

        axes[i, 0].set_title("Combined Image")
        axes[i, 0].axis('off')

        axes[i, 2].set_title("Ground Truth Mask")
        axes[i, 2].axis('off')

        # Plot the predicted image
        axes[i, 1].imshow(predictions_flat[image_idxs[i]], cmap="viridis", aspect='equal')
        axes[i, 1].set_title("Predicted Image")
        axes[i, 1].axis('off')

    plt.show()

# ---- Configuration ----
batch_size = 2
num_input_channels = 6  # pre- and post-disaster RGB images stacked along the channel dimension
num_classes = 5
class_names = ['No Damage', 'Minor', 'Moderate', 'Major', 'Background']  # position == label id
lr = 1e-4
image_size = 800  #520x520 is the image size used by the pretrained weights in DeepLabv3
num_epochs = 50
seed = 0
# Whether the models parameters should be saved following the completion of a run.
save = True
#Whether an existing models parameters should be loaded before the run.
load = False

project_dir = os.path.join(os.getcwd(), 'drive/MyDrive/Flood Damage Extent Detection')
model_path = os.path.join(project_dir, 'DeepLabv3.pt')

# Seed every RNG used by the data pipeline (random, numpy) and by training/shuffling (torch).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# ---- Data ----
print("Loading train and test images")
train_dataset = HarveyData(os.path.join(project_dir, 'dataset/training'), image_size=image_size)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# The test split must not be randomly augmented, otherwise metrics change from run to run.
test_dataset = HarveyData(os.path.join(project_dir, 'dataset/testing'), image_size=image_size, augment_data=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print("Finished loading images")

# ---- Model ----
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = DeepLabV3(num_input_channels, num_classes)
if load:
    if os.path.exists(model_path):
        print("Loading model.")
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print('Could not load model. File does not exist.')
model.to(device)

criterion = FocalLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)

images = []
masks = []
predicted_images = []

# ---- Training / evaluation ----
start_time = time.time()
for epoch in range(num_epochs):
    is_last_epoch = epoch + 1 == num_epochs

    model.train()
    epoch_loss = 0.0
    for batch_idx, (image, mask) in enumerate(train_dataloader):
        image = image.to(device)
        # squeeze(1) drops only the channel axis; a bare squeeze() would also drop the
        # batch axis of a size-1 final batch.
        mask = mask.squeeze(1).to(device)

        # FocalLoss is built on F.cross_entropy, which expects raw logits: no softmax here.
        logits = model(image)['out']
        loss = criterion(logits, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        print('Batch %d --- Loss: %.4f' % (batch_idx, loss.item() / image.shape[0]))
    print('Epoch %d / %d --- Average Loss: %.4f' % (epoch + 1, num_epochs, epoch_loss / len(train_dataset)))

    total_loss = 0.0

    total_macro_precision = 0.0
    total_macro_recall = 0.0
    total_macro_f1 = 0.0

    # numpy arrays, not lists: `list += ndarray` would extend the list instead of adding.
    total_class_precision = np.zeros(num_classes)
    total_class_recall = np.zeros(num_classes)
    total_class_f1 = np.zeros(num_classes)

    model.eval()
    with torch.no_grad():
        for image, mask in test_dataloader:
            image = image.to(device)
            mask = mask.squeeze(1).to(device)

            logits = model(image)['out']
            total_loss += criterion(logits, mask).item()

            # argmax over raw logits equals argmax over their softmax.
            predicted = torch.argmax(logits, dim=1)

            image = image.cpu().numpy()
            mask = mask.cpu().numpy()
            predicted = predicted.cpu().numpy()

            for sample_mask, sample_pred in zip(mask, predicted):
                y_true = sample_mask.flatten()
                y_pred = sample_pred.flatten()

                # Calculate scores globally.
                precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0.0)
                total_macro_precision += precision
                total_macro_recall += recall
                total_macro_f1 += f1

                # Calculate scores by class.
                precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, labels=list(range(num_classes)), average=None, zero_division=0.0)
                total_class_precision += precision
                total_class_recall += recall
                total_class_f1 += f1

            if is_last_epoch:
                images.append(image)
                masks.append(mask)
                predicted_images.append(predicted)

    num_test = len(test_dataset)
    average_loss = total_loss / num_test

    average_macro_precision = total_macro_precision / num_test
    average_macro_recall = total_macro_recall / num_test
    average_macro_f1 = total_macro_f1 / num_test

    average_class_precision = total_class_precision / num_test
    average_class_recall = total_class_recall / num_test
    average_class_f1 = total_class_f1 / num_test

    print('Average Macro Precision: %.4f ---- Average Macro Recall: %.4f ---- Average F1 Score: %.4f ---- Average Loss: %.4f' % (average_macro_precision, average_macro_recall, average_macro_f1, average_loss))
    for class_id, name in enumerate(class_names):
        print('Average %s Precision: %.4f ---- Average %s Recall: %.4f ---- Average %s F1: %.4f' % (name, average_class_precision[class_id], name, average_class_recall[class_id], name, average_class_f1[class_id]))

    if is_last_epoch:
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Elapsed Time at Epoch {epoch + 1} : {elapsed_time} seconds")

        if save:
            torch.save(model.state_dict(), model_path)

        visualize_results(6, predicted_images, images, masks)

        images.clear()
        masks.clear()
        predicted_images.clear()

        start_time = time.time()

Output hidden; open in https://colab.research.google.com to view.