In [None]:
pip install rasterio

Collecting rasterio
  Downloading rasterio-1.3.9-cp310-cp310-manylinux2014_x86_64.whl (20.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.6/20.6 MB[0m [31m68.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Collecting snuggs>=1.4.1 (from rasterio)
  Downloading snuggs-1.4.7-py3-none-any.whl (5.4 kB)
Installing collected packages: snuggs, affine, rasterio
Successfully installed affine-2.4.0 rasterio-1.3.9 snuggs-1.4.7


In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import models
import torchvision.transforms as transforms
from torchvision.transforms import Compose, Resize, v2
from torchvision.transforms.functional import to_tensor, hflip, vflip, rotate, adjust_gamma
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

import os
import time
import random

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support

from PIL import Image
import rasterio

In [None]:
class HarveyData(Dataset):
    """Dataset that pairs pre/post imagery, auxiliary raster layers and a label mask.

    All rasters are read with rasterio and kept in memory as tensors at construction time.
    Files are matched across layers purely by their position in each sub-directory's sorted
    listing, so every sub-directory must list its files in the same sample order.

    Args:
        dataset_dir: Provide a path to either "./dataset/training" or "./dataset/testing".
        image_size: Side length (in pixels) of the square every layer is resized to in __getitem__.
        augment_data: When True, __getitem__ applies random geometric augmentation.
        verbose_logging: When True, print the path of every raster right after it is loaded.
    """

    #Names of the per-day sub-directories under 'rain/' and 'stream_elev/'
    #(presumably dates, e.g. 824 = August 24 -- TODO confirm).
    _DAYS = ('824', '825', '826', '827', '828', '829', '830')

    #One row per layer, in loading order:
    #(attribute holding the tensors, attribute holding the file names, sub-directory, read every band?, dtype).
    #A dtype of None keeps whatever dtype rasterio returns for that layer.
    _LAYERS = (
        [
            ('pre_images', 'pre_image_paths', 'pre_img', True, torch.float32),
            ('post_images', 'post_image_paths', 'post_img', True, torch.float32),
            ('masks', 'mask_paths', 'PDE_labels', False, torch.int64),
            ('elevation', 'elevation_paths', 'elevation', False, torch.float32),
            ('hand', 'hand_paths', 'hand', False, torch.int16),
            ('imperviousness', 'imperviousness_paths', 'imperviousness', False, torch.float32),
            ('distance_coast', 'distance_coast_paths', 'distance_to_coast', False, torch.float32),
            ('distance_stream', 'distance_stream_paths', 'distance_to_stream', False, torch.float32),
        ]
        + [('rain_' + day, 'rain_' + day + '_paths', 'rain/' + day, False, None) for day in _DAYS]
        + [('stream_elev_' + day, 'stream_elev_' + day + '_paths', 'stream_elev/' + day, False, None) for day in _DAYS]
    )

    #Single-channel layers in the order they are concatenated after the pre and post images.
    _META_ORDER = (
        ['elevation', 'hand', 'imperviousness', 'distance_coast', 'distance_stream']
        + ['rain_' + day for day in _DAYS]
        + ['stream_elev_' + day for day in _DAYS]
    )

    def __init__(self, dataset_dir, image_size = 224, augment_data=True, verbose_logging=False):
        super(HarveyData, self).__init__()
        self.dataset_dir = dataset_dir
        self.image_size = image_size
        self.augment_data = augment_data

        #Create self.<layer>_paths (sorted file names) and an empty self.<layer> list for every layer.
        for tensors_attr, paths_attr, subdir, _, _ in self._LAYERS:
            setattr(self, paths_attr, sorted(os.listdir(os.path.join(dataset_dir, subdir))))
            setattr(self, tensors_attr, [])

        #The pre-event image folder defines how many samples there are.
        self.num_images = len(self.pre_image_paths)

        for i in range(self.num_images):
            for tensors_attr, paths_attr, subdir, all_bands, dtype in self._LAYERS:
                path = os.path.join(dataset_dir, subdir, getattr(self, paths_attr)[i])
                getattr(self, tensors_attr).append(self._read_layer(path, all_bands, dtype))
                if verbose_logging: print("Loading " + path)

    @staticmethod
    def _read_layer(path, all_bands, dtype):
        """Read one raster into a (channels, height, width) tensor.

        With all_bands=True every band is read; otherwise only band 1 is read and a
        leading channel axis is added. dtype=None keeps the raster's own dtype.
        """
        with rasterio.open(path) as src:
            data = src.read() if all_bands else src.read(1)
        tensor = torch.tensor(data, dtype=dtype)
        return tensor if all_bands else tensor.unsqueeze(0)

    def normalize_image(self, image):
        """Min-max scale a tensor to [0, 1]; a constant tensor maps to all zeros instead of NaN."""
        lowest = torch.min(image)
        value_range = torch.max(image) - lowest
        if value_range == 0:
            #Every value is identical, so there is nothing to scale (and dividing would give NaN).
            return torch.zeros_like(image)
        return (image - lowest) / value_range

    def __getitem__(self, idx):
        """Return (combined_image, mask) for sample idx.

        combined_image stacks, along the channel axis: pre image, post image, then the
        single-channel layers in _META_ORDER. mask keeps its leading channel axis of size 1.
        """
        size = (self.image_size, self.image_size)

        #These are the normalization values used by the pretrained weights in DeepLabv3
        mean_normalize_rgb_channels = [0.485, 0.456, 0.406]
        std_normalize_rgb_channels = [0.229, 0.224, 0.225]

        image_transforms = v2.Compose([
                           v2.ToImage(),
                           v2.ToDtype(torch.float32, scale=True),
                           v2.Resize(size, antialias=True),
                           v2.Normalize(mean=mean_normalize_rgb_channels, std=std_normalize_rgb_channels)
        ])
        #Nearest-neighbour keeps the mask made of real class ids; bilinear resizing would
        #blend neighbouring labels into ids that do not exist in the source mask.
        mask_transforms = v2.Resize(size, interpolation=transforms.InterpolationMode.NEAREST)
        meta_transforms = v2.Compose([
                          v2.ToImage(),
                          v2.ToDtype(torch.float32, scale=True),
                          v2.Resize(size, antialias=True)
        ])
        int_meta_transforms = v2.Compose([
                              v2.ToImage(),
                              v2.ToDtype(torch.int16, scale=True),
                              v2.Resize(size, antialias=True)
        ])

        layers = [image_transforms(self.pre_images[idx]), image_transforms(self.post_images[idx])]
        for name in self._META_ORDER:
            #'hand' is stored as int16 and keeps its integer pipeline; everything else is float.
            layer_transforms = int_meta_transforms if name == 'hand' else meta_transforms
            layers.append(layer_transforms(getattr(self, name)[idx]))
        mask = mask_transforms(self.masks[idx])

        if self.augment_data:
            layers, mask = self._augment(layers, mask)

        #Concatenate the pre and post disaster images, as well as the meta-attributes, together along the channel dimension.
        combined_image = torch.cat(layers, dim=0)
        return combined_image, mask

    def _augment(self, layers, mask):
        """Apply one or two distinct random geometric augmentations identically to every layer and the mask.

        Modes: 0 = vertical flip, 1 = horizontal flip, 2 = random resized crop, 3 = rotation.
        One mode is always drawn; with probability 0.5 a second, different mode is drawn too.
        """
        modes = [0, 1, 2, 3]
        selected = {int(np.random.choice(modes))}
        if np.random.random() > 0.5:
            remaining = [mode for mode in modes if mode not in selected]
            selected.add(int(np.random.choice(remaining)))

        #Independent checks (not elif) so that both selected modes are actually applied.
        if 0 in selected:
            # flip image vertically
            layers = [vflip(layer) for layer in layers]
            mask = vflip(mask)

        if 1 in selected:
            # flip image horizontally
            layers = [hflip(layer) for layer in layers]
            mask = hflip(mask)

        if 2 in selected:
            # crop image: draw the crop window ONCE so every layer and the mask stay pixel-aligned
            top, left, height, width = transforms.RandomResizedCrop.get_params(
                layers[0], scale=[0.08, 1.0], ratio=[3.0 / 4.0, 4.0 / 3.0])
            out_size = [self.image_size, self.image_size]
            layers = [transforms.functional.resized_crop(layer, top, left, height, width, out_size, antialias=True)
                      for layer in layers]
            mask = transforms.functional.resized_crop(mask, top, left, height, width, out_size,
                                                      interpolation=transforms.InterpolationMode.NEAREST,
                                                      antialias=False)

        if 3 in selected:
            # rotate image: one shared angle for every layer (including rain_830) and the mask
            random_degree = random.randint(1, 359)
            layers = [rotate(layer, random_degree) for layer in layers]
            mask = rotate(mask, random_degree)

        return layers, mask

    def get_item_resize_only(self, idx, image_size):
        """Return (combined_image, mask) for sample idx with resizing only (no normalization, no augmentation).

        combined_image stacks pre image, post image, elevation and imperviousness along the
        channel axis. Values are the raw raster values; they are not rescaled here.
        """
        #The stored layers are already (channels, height, width) tensors, so they are resized directly.
        size = (image_size, image_size)
        resize = v2.Resize(size, antialias=True)
        #Nearest-neighbour so the resized mask only contains class ids present in the original.
        resize_mask = v2.Resize(size, interpolation=transforms.InterpolationMode.NEAREST)

        #Resize the images to the same size as was used during training.
        pre_image = resize(self.pre_images[idx])
        post_image = resize(self.post_images[idx])
        mask = resize_mask(self.masks[idx])

        elevation = resize(self.elevation[idx])
        imperviousness = resize(self.imperviousness[idx])

        #Concatenate the pre and post disaster images, as well as the meta attributes, together along the channel dimension.
        combined_image = torch.cat([pre_image, post_image, elevation, imperviousness], dim=0)
        return combined_image, mask

    def __len__(self):
        """Number of samples (taken from the pre-event image folder)."""
        return self.num_images

In [None]:
class DeepLabV3(nn.Module):
    """torchvision DeepLabv3 (ResNet-50 backbone) with a configurable input stem and output head.

    Args:
        num_input_channels: Number of channels of the input tensor.
        num_classes: Number of classes predicted per pixel by the main classifier.
    """
    def __init__(self, num_input_channels, num_classes):
        super().__init__()
        self.deeplabv3_weights = torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT
        self.resnet50_weights = models.ResNet50_Weights.DEFAULT
        self.deeplabv3 = torchvision.models.segmentation.deeplabv3_resnet50(weights=self.deeplabv3_weights, weights_backbone=self.resnet50_weights)

        #Replaces the first convolution of the backbone so the model accepts num_input_channels channels.
        #The replacement layer is freshly initialised, so the pretrained weights of this one layer are not reused.
        self.deeplabv3.backbone.conv1 = nn.Conv2d(num_input_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)

        #Replaces the final classifier layer so that it outputs num_classes channels.
        self.deeplabv3.classifier[-1] = torch.nn.Conv2d(in_channels=256, out_channels=num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        """Run the wrapped model on x and return its output unchanged (the training loop in this notebook reads the 'out' entry)."""
        #Call the module rather than .forward() so any registered hooks run as well.
        return self.deeplabv3(x)

In [None]:
class FocalLoss(nn.Module):
    """Focal loss built on top of cross-entropy: (1 - pt) ** gamma * CE, where pt = exp(-CE).

    Args:
        gamma: Focusing exponent applied to (1 - pt).
        alpha: Optional weighting. A scalar multiplies every element of the loss. A list, tuple
            or 1-D tensor is treated as one weight per class and is looked up with the target
            class of each element (only when target holds integer class indices).
        reduction: 'mean' or 'sum'; any other value returns the unreduced, per-element loss.
    """
    def __init__(self, gamma=2, alpha=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss of the logits `input` against `target`."""
        ce_loss = F.cross_entropy(input, target, reduction='none')
        #pt is the probability the model assigned to the target class.
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            per_class = isinstance(self.alpha, (list, tuple)) or (torch.is_tensor(self.alpha) and self.alpha.dim() == 1)
            integer_target = not (torch.is_floating_point(target) or torch.is_complex(target))
            if per_class and integer_target:
                #Pick, for every element, the weight that belongs to its target class.
                class_weights = torch.as_tensor(self.alpha, dtype=focal_loss.dtype, device=focal_loss.device)
                focal_loss = class_weights[target] * focal_loss
            else:
                focal_loss = self.alpha * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [None]:
def visualize_results(num_results, predictions, images=None, masks=None, randomize_images=False):
    """Plot num_results rows of (combined image, predicted mask, ground-truth mask).

    Args:
        num_results: Number of rows (samples) to draw.
        predictions: Iterable of batches of predicted masks; it is flattened into one list of samples.
        images: Optional iterable of batches of input images, each sample laid out channels-first.
        masks: Optional iterable of batches of ground-truth masks.
        randomize_images: When True, draw num_results distinct samples at random;
            otherwise show the first num_results samples.

    When images or masks is missing, the sample is fetched from the module-level
    `test_dataset` using the module-level `image_size`.
    """
    #squeeze=False keeps `axes` two-dimensional even when num_results == 1.
    fig, axes = plt.subplots(num_results, 3, figsize=(32, 32), squeeze=False)

    predictions_flat = [item for sublist in predictions for item in sublist]
    #Identity checks: comparing an array-like with `== None` would be evaluated element-wise.
    use_dataset = images is None or masks is None
    if not use_dataset:
        images_flat = [item for sublist in images for item in sublist]
        masks_flat = [item for sublist in masks for item in sublist]

    if randomize_images:
        # Choose num_results number of images at random from the results (every index is eligible).
        image_idxs = random.sample(range(len(predictions_flat)), num_results)
    else:
        image_idxs = list(range(num_results))

    for row, sample_idx in enumerate(image_idxs):
        # Plot the input image and ground truth mask
        if use_dataset:
            image, mask = test_dataset.get_item_resize_only(sample_idx, image_size)

            #Reorder the channels for matplotlib.
            image = torch.permute(image, (1, 2, 0))
            mask = torch.permute(mask, (1, 2, 0))

            #Channels 0-2 are drawn first and channels 3-5 are overlaid half-transparent.
            axes[row, 0].imshow(image.numpy()[:, :, 0:3], aspect='equal')
            axes[row, 0].imshow(image.numpy()[:, :, 3:6], alpha=0.5, aspect='equal')
            axes[row, 2].imshow(mask.numpy(), cmap="viridis", aspect='equal')
        else:
            image = images_flat[sample_idx]
            mask = masks_flat[sample_idx]

            #Reorder the channels for matplotlib.
            image = np.transpose(image, (1, 2, 0))

            #Channels 0-2 are drawn first and channels 3-5 are overlaid half-transparent.
            axes[row, 0].imshow(image[:, :, 0:3], aspect='equal')
            axes[row, 0].imshow(image[:, :, 3:6], alpha=0.5, aspect='equal')
            axes[row, 2].imshow(mask, cmap="viridis", aspect='equal')

        axes[row, 0].set_title("Combined Image")
        axes[row, 0].axis('off')

        axes[row, 2].set_title("Ground Truth Mask")
        axes[row, 2].axis('off')

        # Plot the predicted image
        axes[row, 1].imshow(predictions_flat[sample_idx], cmap="viridis", aspect='equal')
        axes[row, 1].set_title("Predicted Image")
        axes[row, 1].axis('off')

    plt.show()

batch_size = 3
num_input_channels = 25
num_classes = 5
lr = 1e-4
image_size = 750  #520x520 is the image size used by the pretrained weights in DeepLabv3

train = True
test = True

# Whether the models parameters should be saved following the completion of a run.
save = True
#Whether an existing models parameters should be loaded before the run.
load = False
#Tracking the highest current F1 score and the epoch it occurred on for the model to know when to best save the model.
highest_f1 = (0, 0)

cwd = os.getcwd()

if train:
    print("Loading train images")
    train_dataset = HarveyData(os.path.join(cwd, 'drive/MyDrive/Flood Damage Extent Detection/dataset/training'), image_size=image_size, augment_data=True)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

if test:
    print("Loading test images")
    test_dataset = HarveyData(os.path.join(cwd, 'drive/MyDrive/Flood Damage Extent Detection/dataset/testing'), image_size=image_size, augment_data=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print("Finished loading images")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = DeepLabV3(num_input_channels, num_classes)
if (load):
    if (os.path.exists('drive/MyDrive/Flood Damage Extent Detection/DeepLabv3.pt')):
        print("Loading model.")
        model.load_state_dict(torch.load('drive/MyDrive/Flood Damage Extent Detection/DeepLabv3.pt'))
    else:
        print('Could not load model. File does not exist.')
model.to(device)
#model_preprocess = model.deeplabv3_weights.transforms()

criterion = FocalLoss(reduction='mean')#torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)

#softmax = nn.Softmax(dim=1)

num_epochs = 50

images = []
masks = []
predicted_images = []

# Main loop: for every epoch, optionally run one training pass over
# train_dataloader, then optionally evaluate on test_dataloader, print
# macro and per-class precision/recall/F1, checkpoint the model whenever the
# macro F1 improves, and on the last epoch save the final weights and
# visualize the predictions.

def _squeeze_keep_batch(tensor):
    """Drop every size-1 dimension of `tensor` except the leading (batch) one.

    A bare `.squeeze()` would also remove the batch dimension when a batch
    holds a single sample (e.g. the last, partial batch), which changes the
    tensor's rank for that batch only. For batches with more than one sample
    the result is identical to `.squeeze()`.
    """
    kept_dims = [size for size in tensor.shape[1:] if size != 1]
    return tensor.reshape(tensor.shape[0], *kept_dims)


#Training
start_time = time.time()
for epoch in range(num_epochs):
    if train:
        model.train()
        epoch_loss = 0
        for batch_idx, (image, mask) in enumerate(train_dataloader):
            image = image.to(device)
            mask = _squeeze_keep_batch(mask).to(device)

            outputs = model(image)['out']

            loss = criterion(outputs, mask)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()

            # The criterion is constructed with reduction='mean', so the batch
            # loss is reported as-is rather than divided by the batch size again.
            print('Batch %d --- Loss: %.4f' % (batch_idx, loss.item()))
        # epoch_loss is a sum of per-batch losses, so average over batches, not samples.
        print('Epoch %d / %d --- Average Loss: %.4f' % (epoch + 1, num_epochs, epoch_loss / max(len(train_dataloader), 1)))

    if test:
        total_loss = 0.0

        total_macro_precision = 0.0
        total_macro_recall = 0.0
        total_macro_f1 = 0.0

        # NumPy accumulators so that adding the per-class score arrays is an
        # explicit element-wise sum.
        total_class_precision = np.zeros(num_classes)
        total_class_recall = np.zeros(num_classes)
        total_class_f1 = np.zeros(num_classes)

        #Testing
        model.eval()
        with torch.no_grad():
            for image, mask in test_dataloader:
                image = image.to(device)
                mask = _squeeze_keep_batch(mask).to(device)

                outputs = model(image)['out']

                loss = criterion(outputs, mask)
                total_loss += loss.item()

                # Index of the highest-scoring entry along dim 1 for every position.
                predicted = torch.argmax(outputs, dim=1, keepdim=False)

                image = image.cpu().numpy()
                mask = mask.cpu().numpy()
                predicted = predicted.cpu().numpy()

                # Metrics are computed per image and averaged over the test set below.
                for sample_idx in range(len(mask)):
                    true_flat = mask[sample_idx].flatten()
                    pred_flat = predicted[sample_idx].flatten()

                    # Calculate scores globally.
                    precision, recall, f1, _ = precision_recall_fscore_support(true_flat, pred_flat, average='macro', zero_division=0.0)
                    total_macro_precision += precision
                    total_macro_recall += recall
                    total_macro_f1 += f1

                    # Calculate scores by class.
                    precision, recall, f1, _ = precision_recall_fscore_support(true_flat, pred_flat, labels=[0, 1, 2, 3, 4], average=None, zero_division=0.0)
                    total_class_precision += precision
                    total_class_recall += recall
                    total_class_f1 += f1

                # Only the last epoch's inputs and predictions are kept for visualization.
                if (epoch + 1 == num_epochs):
                    images.append(image)
                    masks.append(mask)
                    predicted_images.append(predicted)

        # total_loss sums per-batch losses -> divide by the number of batches;
        # the score totals sum per-image values -> divide by the number of images.
        num_test_batches = max(len(test_dataloader), 1)
        num_test_images = max(len(test_dataset), 1)

        average_loss = total_loss / num_test_batches

        average_macro_precision = total_macro_precision / num_test_images
        average_macro_recall = total_macro_recall / num_test_images
        average_macro_f1 = total_macro_f1 / num_test_images

        average_class_precision = total_class_precision / num_test_images
        average_class_recall = total_class_recall / num_test_images
        average_class_f1 = total_class_f1 / num_test_images

        print('Average Macro Precision: %.4f ---- Average Macro Recall: %.4f ---- Average F1 Score: %.4f ---- Average Loss: %.4f' % (average_macro_precision, average_macro_recall, average_macro_f1, average_loss))
        print('Average No Damage Precision: %.4f ---- Average No Damage Recall: %.4f ---- Average No Damage F1: %.4f' % (average_class_precision[0], average_class_recall[0], average_class_f1[0]))
        print('Average Minor Precision: %.4f ---- Average Minor Recall: %.4f ---- Average Minor F1: %.4f' % (average_class_precision[1], average_class_recall[1], average_class_f1[1]))
        print('Average Moderate Precision: %.4f ---- Average Moderate Recall: %.4f ---- Average Moderate F1: %.4f' % (average_class_precision[2], average_class_recall[2], average_class_f1[2]))
        print('Average Major Precision: %.4f ---- Average Major Recall: %.4f ---- Average Major F1: %.4f' % (average_class_precision[3], average_class_recall[3], average_class_f1[3]))
        print('Average Background Precision: %.4f ---- Average Background Recall: %.4f ---- Average Background F1: %.4f' % (average_class_precision[4], average_class_recall[4], average_class_f1[4]))

        #Save model
        # Checkpoint whenever the macro F1 beats the best value seen so far.
        if (save and average_macro_f1 > highest_f1[0]):
            highest_f1 = (average_macro_f1, epoch + 1)
            save_file = f'drive/MyDrive/Flood Damage Extent Detection/DeepLabv3_epoch_{highest_f1[1]}_{highest_f1[0]}.pt'
            torch.save(model.state_dict(), save_file)
            print(f'Saved Model to {save_file}')

        # End of the run: report timing, save the final weights and show results.
        if (epoch + 1 == num_epochs):
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Elapsed Time at Epoch {epoch + 1} : {elapsed_time} seconds")

            if save:
                save_file = f'drive/MyDrive/Flood Damage Extent Detection/DeepLabv3_epoch_{epoch + 1}_{average_macro_f1}.pt'
                torch.save(model.state_dict(), save_file)
                print(f'Saved Model to {save_file}')

            visualize_results(6, predicted_images, images, masks)

            # Release the buffered arrays once they have been plotted.
            images.clear()
            masks.clear()
            predicted_images.clear()

            start_time = time.time()

Loading train images
Loading test images
Finished loading images


Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|██████████| 161M/161M [00:10<00:00, 16.8MB/s]


Batch 0 --- Loss: 0.4661
Batch 1 --- Loss: 0.4612
Batch 2 --- Loss: 0.4696
Batch 3 --- Loss: 0.4633
Batch 4 --- Loss: 0.4602
Batch 5 --- Loss: 0.4466
Batch 6 --- Loss: 0.4604
Batch 7 --- Loss: 0.4559
Batch 8 --- Loss: 0.4503
Batch 9 --- Loss: 0.4552
Batch 10 --- Loss: 0.4224
Batch 11 --- Loss: 0.4316
Batch 12 --- Loss: 0.4193
Batch 13 --- Loss: 0.4167
Batch 14 --- Loss: 0.4045
Batch 15 --- Loss: 0.4425
Batch 16 --- Loss: 0.4129
Batch 17 --- Loss: 0.4040
Batch 18 --- Loss: 0.4190
Batch 19 --- Loss: 0.3882
Batch 20 --- Loss: 0.4112
Batch 21 --- Loss: 0.3968
Batch 22 --- Loss: 0.3903
Batch 23 --- Loss: 0.3845
Batch 24 --- Loss: 0.3852
Batch 25 --- Loss: 0.3893
Batch 26 --- Loss: 0.3700
Batch 27 --- Loss: 0.3843
Batch 28 --- Loss: 0.3811
Batch 29 --- Loss: 0.3697
Batch 30 --- Loss: 0.3654
Batch 31 --- Loss: 0.3341
Batch 32 --- Loss: 0.3606
Batch 33 --- Loss: 0.3487
Batch 34 --- Loss: 0.3490
Batch 35 --- Loss: 0.3131
Batch 36 --- Loss: 0.3380
Batch 37 --- Loss: 0.3342
Batch 38 --- Loss: 0.2