# IMPORTS

In [None]:
from osgeo import gdal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import traceback
from torch.utils.data import Dataset, DataLoader
import skimage as sk
import rasterio
import gc
import os
from PIL import Image
from numpy import random
import cv2

import torchvision
from torchvision import transforms

from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CyclicLR
from torch.optim.lr_scheduler import OneCycleLR

import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian

from sklearn.metrics import precision_recall_fscore_support, cohen_kappa_score, jaccard_score, accuracy_score

In [None]:
# Root directory that is walked for training `.tif` images and their
# `*_classified.tif` label rasters (passed to load_tif_images_from_directory).
TRAINING_IMAGE_PATH = ""
# Cache files for the extracted patches and their binary masks
# (written with np.save and read back with np.load further down).
LOADED_IMAGE_SAVE_PATH = "loaded_ims.npy"
LOADED_LABEL_SAVE_PATH = "loaded_labels.npy"
# NOTE(review): the three paths below are not referenced in the cells visible
# here; presumably they are used by the training / evaluation cells - confirm.
MODEL_SAVE_PATH = ""
MODEL_LOAD_PATH = ""
TEST_IMAGE_SAVE_PATH = ""

# 1. Data preparation

Loads the `.tif` images and their classified labels from the directory given in `root_directory`. Each image and its label must be in the same folder: if the image is named `image.tif`, its label must be named `image_classified.tif`.

In [None]:


def load_tif_images_from_directory(root_directory):
    """Load every multi-band `.tif` image below `root_directory` with its label.

    For an image `<name>.tif` the label raster is expected next to it as
    `<name>_classified.tif`. Files ending in `classified.tif` are never treated
    as images.

    Parameters
    ----------
    root_directory : str
        Directory that is walked recursively with `os.walk`.

    Returns
    -------
    (list[np.ndarray], list[np.ndarray])
        `tif_images[k]` is a float32 array of shape (bands, rows, cols) in which
        each band is divided by its own maximum; `tif_labels[k]` is band 1 of
        the matching label raster as float32. Both lists have the same length
        and order.

    Notes
    -----
    Loading is best effort: an image that cannot be read, is not multi-band, or
    has no readable label is reported on stdout and skipped, so one broken file
    does not abort the whole data preparation.
    """
    tif_images = []
    tif_labels = []

    for root, _dirs, files in os.walk(root_directory):
        for file in files:
            lowered = file.lower()
            if not lowered.endswith('.tif') or lowered.endswith('classified.tif'):
                continue

            file_path = os.path.join(root, file)
            label_path = file_path[:-4] + "_classified.tif"
            try:
                ds = gdal.Open(file_path, gdal.GA_ReadOnly)
                if ds is None:
                    # gdal.Open signals failure by returning None unless exceptions are enabled.
                    raise IOError("GDAL could not open the file")
                # ReadAsArray already yields (bands, rows, cols) for multi-band rasters.
                array = ds.ReadAsArray().astype(dtype="float32")
                ds = None  # drop the reference so GDAL releases the dataset
                if array.ndim != 3:
                    raise ValueError(f"expected a multi-band raster, got shape {array.shape}")

                # Per-band normalisation; an all-zero band keeps its zeros instead of becoming NaN.
                band_max = array.max(axis=(1, 2), keepdims=True)
                band_max[band_max == 0] = 1.0
                array = array / band_max

                with rasterio.open(label_path) as label_ds:
                    label_img = label_ds.read(1).astype('f4')
            except Exception as e:
                print(f"Skipping {file_path}: {e}")
                continue

            # Append only after both reads succeeded so the two lists stay aligned.
            tif_labels.append(label_img)
            tif_images.append(array)
    return tif_images, tif_labels




Pads the input `array` to the specified `target_shape`. If the given `array` is a label image (it has only two dimensions, e.g. `(128, 128)`), the padding is applied to every dimension; if `array` is 3-dimensional, the padding is applied only to the two spatial dimensions and the channel dimension is left untouched.

In [None]:
def pad_to_shape(array, target_shape):
    """Zero-pad `array` symmetrically so its spatial size reaches `target_shape`.

    Parameters
    ----------
    array : np.ndarray
        Either a 2-D array (rows, cols) or a 3-D array whose first axis is left
        unpadded (channels, rows, cols).
    target_shape : tuple[int, int]
        Desired (rows, cols). Axes that are already at least this large are left
        as they are; nothing is ever cropped.

    Returns
    -------
    np.ndarray
        The padded copy. When the missing amount is odd, the extra row/column is
        added after the data.

    Raises
    ------
    ValueError
        If `array` is neither 2-D nor 3-D.
    """
    if array.ndim == 2:
        padding = []
        spatial_shape = array.shape
    elif array.ndim == 3:
        padding = [(0, 0)]  # never pad the leading (channel) axis
        spatial_shape = array.shape[1:]
    else:
        raise ValueError(
            f"pad_to_shape expects a 2-D or 3-D array, got {array.ndim} dimension(s)"
        )

    for current_size, target_size in zip(spatial_shape, target_shape):
        pad_total = max(target_size - current_size, 0)
        pad_before = pad_total // 2
        padding.append((pad_before, pad_total - pad_before))

    return np.pad(array, padding, mode='constant', constant_values=0)


Using `sk.measure.regionprops` (https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_regionprops.html) we cut out a 128x128 patch from the original image, centered on the center of the bounding box of each labeled object. If the patch is clipped by the image border and is therefore smaller than the desired size, it is zero-padded with the `pad_to_shape` function defined above. The number of patches that contain no waste is capped by the limit hard-coded in the loop (850); `no_waste_counter` counts how many such patches have been kept so far.

In [None]:
# Build the patch dataset: one 128x128 patch per connected labelled region.
ims, labels = load_tif_images_from_directory(TRAINING_IMAGE_PATH)
cut_ims = []
cut_labels = []
# Number of patches without any positive pixel that have been kept so far.
no_waste_counter = 0
for i in range(len(labels)):
    # Connected-component labelling of the class raster, then one regionprops
    # entry per component (components of every class, not only the positive one).
    labeled_img = sk.measure.label(labels[i])
    regions = sk.measure.regionprops(labeled_img)

    for props in regions:
        minr, minc, maxr, maxc = props.bbox
        # Centre of the bounding box as (row, col).
        center = np.array(((minr + maxr)//2,(minc + maxc)//2))
        # NOTE: despite the names, *_x indexes rows and *_y indexes columns.
        start_x = center[0]-64
        start_y = center[1]-64
        end_x = (start_x + 128)
        end_y = (start_y + 128)
        # Clip the window to the image; a clipped window is smaller than
        # 128x128 and is zero-padded below rather than shifted.
        if start_x < 0:
            start_x = 0
        if start_y < 0:
            start_y = 0
        if end_x > ims[i].shape[1]:
            end_x = ims[i].shape[1]
        if end_y > ims[i].shape[2]:
            end_y = ims[i].shape[2]



        cut = ims[i][:,start_x:end_x, start_y:end_y]
        cut_label = labels[i][start_x:end_x, start_y:end_y]
        # The check hard-codes 4 bands; for any other band count the padding
        # call still runs and is simply a no-op on full-size patches.
        if cut.shape != (4,128,128):
            cut = pad_to_shape(cut, (128, 128))
            cut_label = pad_to_shape(cut_label, (128,128))
        # Binarise the mask: label value 100 becomes 1, everything else 0.
        # NOTE(review): presumably 100 is the class id of waste - confirm.
        cut_label_filtered = np.copy(cut_label)
        cut_label_filtered[cut_label != 100] = 0
        cut_label_filtered[cut_label == 100] = 1
        # Keep at most 850 patches with an empty mask; keep every patch that
        # contains at least one positive pixel.
        if not np.isin(1, cut_label_filtered) and no_waste_counter < 850:
                cut_ims.append(cut)
                cut_labels.append(cut_label_filtered)
                no_waste_counter += 1
        elif np.isin(1, cut_label_filtered):
            cut_ims.append(cut)
            cut_labels.append(cut_label_filtered)


# Stack into arrays with a leading patch axis (raises if no patch was collected).
cut_ims = np.stack(cut_ims,axis=0)
cut_labels = np.stack(cut_labels, axis=0)



We randomly permute the images and labels (with the same permutation) so that patches with and without waste are more evenly distributed across the dataset.

In [None]:
# Shuffle patches and masks with one shared permutation so pairs stay aligned.
# NOTE(review): no random seed is set in the visible cells, so the order (and
# therefore the train/val/test split below) differs between runs.
perm = np.random.permutation(cut_ims.shape[0])
cut_ims = cut_ims[perm,:,:,:]
cut_labels = cut_labels[perm,:,:]
# Cache the shuffled arrays on disk so the extraction does not have to be repeated.
np.save(LOADED_IMAGE_SAVE_PATH,cut_ims)
np.save(LOADED_LABEL_SAVE_PATH, cut_labels)

In [None]:
# Reload the cached patches / masks written by the previous cell.
cut_ims = np.load(LOADED_IMAGE_SAVE_PATH)
cut_labels = np.load(LOADED_LABEL_SAVE_PATH)


The `WasteTransforms` class transforms the images and the labels, using the same randomly generated seed to get the same results. Some color distortion can be applied to the images, but not the labels.

In [None]:
class WasteTransforms:
    """Paired augmentation for an image and its segmentation mask.

    Two identical geometric pipelines are kept, one per input. Before each
    pipeline runs, the global torch RNG is reset to the same freshly drawn
    seed, so both draw the same random parameters and stay spatially aligned.
    """

    def __init__(self):
        image_ops = [
            transforms.RandomAffine(degrees=12, translate=(0.2, 0.2), scale=(0.85, 1.15)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # Colour distortion (image only) is currently disabled:
            # transforms.ColorJitter(brightness=0.1, hue=0.1, contrast=0.1, saturation=0.1)
        ]
        label_ops = [
            transforms.RandomAffine(degrees=12, translate=(0.2, 0.2), scale=(0.85, 1.15)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ]
        self.image_transform = transforms.Compose(image_ops)
        self.label_transform = transforms.Compose(label_ops)

    def __call__(self, image, label):
        """Return `(transformed_image, transformed_label)` using one shared seed."""
        shared_seed = torch.randint(0, 2**32, (1,)).item()

        results = []
        for pipeline, tensor in ((self.image_transform, image), (self.label_transform, label)):
            # Re-seeding the global RNG makes both pipelines sample identical parameters.
            torch.manual_seed(shared_seed)
            results.append(pipeline(tensor))

        transformed_image, transformed_label = results
        return transformed_image, transformed_label


The `WasteDataSet` class must define the functions `__init__`, `__getitem__`, and `__len__` from the `torch.utils.data.Dataset` class. We convert the `numpy.ndarray`s to `torch.Tensor`s so that the neural networks can work with them. In `__getitem__`, we apply the transformations to the images and labels (only on the training dataset). During training, we iterate through the images using `torch.utils.data.DataLoader`.

In [None]:
class WasteDataSet(Dataset):
    """In-memory dataset of image patches and their masks as float32 tensors.

    Parameters
    ----------
    ims : np.ndarray
        Image patches, indexed along the first axis.
    labels : array-like
        Masks, indexed along the first axis in the same order as `ims`.
    transforms : callable or None
        Optional paired augmentation called as `transforms(image, label)` and
        returning the transformed pair; `None` disables augmentation.
    """

    def __init__(self, ims, labels ,transforms):
        float_images = ims.astype(np.float32)
        self.ims = torch.tensor(float_images, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
        self.transforms = transforms

    def __len__(self):
        """Number of samples (size of the first axis of the image tensor)."""
        return len(self.ims)

    def __getitem__(self, idx):
        """Return `(image, label)` for `idx`, augmented when transforms are set."""
        image = self.ims[idx]
        # The mask gets a leading axis of size 1 before it is handed to the transform.
        label = self.labels[idx].unsqueeze(0)
        if self.transforms is None:
            return image, label.squeeze()
        image, label = self.transforms(image, label)
        return image, label.squeeze()
    
# Split indices: the first 60% of the patches are used for training, the next
# 39% for validation and the final 1% for testing.
split1 = round(len(cut_ims)*0.6)
split2 = round(len(cut_ims)*.99)

# Augmentation is applied to the training set only.
train_transforms = WasteTransforms()

train_set = WasteDataSet(cut_ims[:split1], cut_labels[:split1], transforms= train_transforms)
val_set = WasteDataSet(cut_ims[split1:split2], cut_labels[split1:split2], transforms=None)
test_set = WasteDataSet(cut_ims[split2:], cut_labels[split2:],  transforms=None)

# NOTE(review): validation and test loaders are shuffled as well, and the test
# loader drops its last incomplete batch, so up to 7 test patches are skipped.
train_dataloader = DataLoader(train_set, batch_size=24, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size= 8, shuffle=True, drop_last=True)
        

## 2. Architecture

These functions are used to align the output of different layers, due to the `//2` sometimes they are 1 pixel smaller/larger than another layer's output.

In [None]:
def align_layers(layers):
    """Center-crop 4-D feature maps to a common size and concatenate channels.

    Every tensor in `layers` is cropped (around its centre) to the smallest
    height and the smallest width found among them, then all are concatenated
    along dimension 1.

    Side effects: empties the CUDA cache and runs the garbage collector on
    every call.
    """
    target_height = min(layer.size(2) for layer in layers)
    target_width = min(layer.size(3) for layer in layers)

    cropped = []
    for layer in layers:
        top = (layer.size(2) - target_height) // 2
        left = (layer.size(3) - target_width) // 2
        cropped.append(layer[:, :, top:(top + target_height), left:(left + target_width)])

    torch.cuda.empty_cache()
    gc.collect()
    return torch.cat(cropped, dim = 1)

def center_crop(layer, target_height, target_width):
    """Crop a 4-D tensor around its centre to `target_height` x `target_width`.

    Only dimensions 2 and 3 are cropped; when the surplus is odd, the extra
    row/column is removed from the end.
    """
    layer_height, layer_width = layer.size()[2:]
    top = (layer_height - target_height) // 2
    left = (layer_width - target_width) // 2
    bottom = top + target_height
    right = left + target_width
    return layer[:, :, top:bottom, left:right]

The building blocks of the `UNET` and `UNETPP` networks.

In [None]:
class ConvolutionBlock(nn.Module):
    """Two successive (3x3 conv -> batch norm -> ReLU -> dropout 0.1) stages.

    The first stage maps `in_channels` to `out_channels`; the second keeps
    `out_channels`. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Dropout(0.1),
            ])
            stage_in = out_channels
        self.conv_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to `x`."""
        return self.conv_layers(x)
        

class EncoderBlock(nn.Module):
    """Encoder stage: a `ConvolutionBlock` followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvolutionBlock(in_channels, out_channels)
        self.mpool = nn.MaxPool2d((2,2))

    def forward(self, x):
        """Return `(skip, out)`: the features before pooling and after pooling."""
        features = self.conv(x)
        return features, self.mpool(features)
    
class DecoderBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, conv block.

    `up_conv` doubles the spatial size and maps `in_channels` to `out_channels`;
    the skip tensor is expected to carry `out_channels` channels as well, so the
    convolution block receives `2 * out_channels` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.conv = ConvolutionBlock(out_channels + out_channels, out_channels)

    def forward(self, x, skip):
        """Upsample `x`, fuse it with `skip` (cropped if needed) and convolve."""
        upsampled = self.up_conv(x)
        same_height = upsampled.size(2) == skip.size(2)
        same_width = upsampled.size(3) == skip.size(3)
        if not (same_height and same_width):
            # Odd input sizes leave the skip one pixel off; crop it to match.
            skip = center_crop(skip, upsampled.size(2), upsampled.size(3))

        fused = torch.cat([upsampled, skip], axis=1)
        return self.conv(fused)

class UpConv(nn.Module):
    """2x upsampling by a stride-2 transposed convolution followed by dropout 0.2."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride= 2)
        self.up_conv = nn.Sequential(transposed, nn.Dropout(0.2))

    def forward(self, x):
        """Return `x` upsampled to twice its height and width."""
        upsampled = self.up_conv(x)
        return upsampled

The `UNET` class defines the architecture that is described in the paper https://arxiv.org/pdf/1505.04597. Number of in and out channels can be given.

In [None]:
class UNET(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input image.
    out_channels : int
        Number of output channels of the final 1x1 convolution (1 for binary
        segmentation, >1 for multi-class).
    """

    def __init__(self, in_channels = 4,out_channels = 1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.e1 = EncoderBlock(in_channels, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128,256)
        self.e4 = EncoderBlock(256, 512)

        self.b = ConvolutionBlock(512, 1024)

        self.d1 = DecoderBlock(1024, 512)
        self.d2 = DecoderBlock(512,256)
        self.d3 = DecoderBlock(256, 128)
        self.d4 = DecoderBlock(128,64)
        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Return raw logits of shape (batch, out_channels, H, W)."""
        skip1, out = self.e1(x)
        skip2, out = self.e2(out)
        skip3, out = self.e3(out)
        skip4, out = self.e4(out)

        out = self.b(out)
        out = self.d1(out, skip4)
        out = self.d2(out, skip3)
        out = self.d3(out, skip2)
        out = self.d4(out, skip1)
        # No squeeze here: squeezing before the 1x1 head dropped the batch axis
        # for batch size 1 and fed an unbatched tensor into the convolution.
        out = self.output(out)

        return out

    def predict(self, x):
        """Return probabilities: sigmoid for one channel, channel-wise softmax otherwise."""
        out = self.forward(x)
        if self.out_channels == 1:
            out = torch.sigmoid(out)
        elif self.out_channels > 1:
            # Class scores live on the channel axis; torch.softmax requires `dim`.
            out = torch.softmax(out, dim=1)
        return out

The `UNETPP` class defines the architecture that is described in the paper https://arxiv.org/pdf/1912.05074. It can be given a pretrained `UNET` as a backbone, with an option to freeze the pretrained weights. Deep supervision can also be enabled/disabled 

In [None]:
class UNETPP(nn.Module):
    """UNet++ (nested U-Net) with optional pretrained encoder and deep supervision.

    Parameters
    ----------
    pretrained_unet : UNET or None
        When given, its encoder stages `e1`..`e4` are reused as the first four
        encoder stages of this network (the modules are shared, not copied).
    freeze_weights : bool
        Only used with `pretrained_unet`: disables gradients for the four
        reused encoder stages.
    deep_vision : bool
        Enables deep supervision: one 1x1 output head per nested decoder level.
    in_channels, out_channels : int
        Input image channels / output channels of the 1x1 head(s).
    """

    def __init__(self,pretrained_unet = None, freeze_weights = False ,deep_vision=False, in_channels = 4, out_channels = 1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pretrained_unet = pretrained_unet
        self.deep_vision = deep_vision
        self.upsamp = nn.Upsample(scale_factor=2,mode='bilinear' , align_corners=True)
        if self.pretrained_unet is not None:
            self.conv0_0 = self.pretrained_unet.e1
            self.conv1_0 = self.pretrained_unet.e2
            self.conv2_0 = self.pretrained_unet.e3
            self.conv3_0 = self.pretrained_unet.e4
            self.conv4_0 = EncoderBlock(512, 1024)
            if freeze_weights:
                for block in (self.conv0_0, self.conv1_0, self.conv2_0, self.conv3_0):
                    for param in block.parameters():
                        param.requires_grad = False
        else:
            self.conv0_0 = EncoderBlock(in_channels, 64)
            self.conv1_0 = EncoderBlock(64, 128)
            self.conv2_0 = EncoderBlock(128, 256)
            self.conv3_0 = EncoderBlock(256, 512)
            self.conv4_0 = EncoderBlock(512, 1024)

        # Nested decoder nodes X(i, j): j same-level inputs plus one upsampled input.
        self.conv0_1 = ConvolutionBlock(64 + 128, 64)
        self.conv0_2 = ConvolutionBlock(2*64 + 128, 64)
        self.conv0_3 = ConvolutionBlock(3*64 + 128, 64)
        self.conv0_4 = ConvolutionBlock(4*64 + 128, 64)

        self.conv1_1 = ConvolutionBlock(128 + 256, 128)
        self.conv1_2 = ConvolutionBlock(2*128 + 256, 128)
        self.conv1_3 = ConvolutionBlock(3*128 + 256, 128)

        self.conv2_1 = ConvolutionBlock(256 + 512, 256)
        self.conv2_2 = ConvolutionBlock(2*256 + 512, 256)

        self.conv3_1 = ConvolutionBlock(512 + 1024, 512)

        self.up_conv1_0 = UpConv(128, 128)
        self.up_conv2_0 = UpConv(256, 256)
        self.up_conv3_0 = UpConv(512, 512)

        self.up_conv1_1 = UpConv(128, 128)
        self.up_conv2_1 = UpConv(256, 256)
        self.up_conv1_2 = UpConv(128, 128)

        self.up_conv4_0 = UpConv(1024, 1024)
        self.up_conv3_1 = UpConv(512, 512)
        self.up_conv2_2 = UpConv(256, 256)
        self.up_conv1_3 = UpConv(128, 128)

        if self.deep_vision:
            self.deep1 = nn.Conv2d(64, out_channels, kernel_size=1)
            self.deep2 = nn.Conv2d(64, out_channels, kernel_size=1)
            self.deep3 = nn.Conv2d(64, out_channels, kernel_size=1)
            self.deep4 = nn.Conv2d(64, out_channels, kernel_size=1)
        else:
            self.deep = nn.Conv2d(64, out_channels, kernel_size=1)

    def _nested_features(self, x, merge):
        """Run encoder and nested decoder; return the four top-level nodes.

        `merge` receives a list of feature maps and must return them joined
        along the channel axis. The node evaluation order is identical for
        training and prediction.
        """
        skip0_0, x0_0 = self.conv0_0(x)
        skip1_0, x1_0 = self.conv1_0(x0_0)
        skip2_0, x2_0 = self.conv2_0(x1_0)
        skip3_0, x3_0 = self.conv3_0(x2_0)
        # Only the un-pooled output of the deepest stage is used.
        skip4_0, _ = self.conv4_0(x3_0)

        x0_1 = self.conv0_1(merge([skip0_0, self.up_conv1_0(skip1_0)]))
        x1_1 = self.conv1_1(merge([skip1_0, self.up_conv2_0(skip2_0)]))
        x0_2 = self.conv0_2(merge([skip0_0, x0_1, self.up_conv1_1(x1_1)]))

        x2_1 = self.conv2_1(merge([skip2_0, self.up_conv3_0(skip3_0)]))
        x1_2 = self.conv1_2(merge([skip1_0, x1_1, self.up_conv2_1(x2_1)]))
        x0_3 = self.conv0_3(merge([skip0_0, x0_1, x0_2, self.up_conv1_2(x1_2)]))

        x3_1 = self.conv3_1(merge([skip3_0, self.up_conv4_0(skip4_0)]))
        x2_2 = self.conv2_2(merge([skip2_0, x2_1, self.up_conv3_1(x3_1)]))
        x1_3 = self.conv1_3(merge([skip1_0, x1_1, x1_2, self.up_conv2_2(x2_2)]))
        x0_4 = self.conv0_4(merge([skip0_0, x0_1, x0_2, x0_3, self.up_conv1_3(x1_3)]))

        return x0_1, x0_2, x0_3, x0_4

    def forward(self, x):
        """Return logits; with deep supervision, a list of four logit maps.

        Feature maps are concatenated without cropping, so the input height and
        width must make every upsampled map match its skip connection.
        """
        x0_1, x0_2, x0_3, x0_4 = self._nested_features(
            x, lambda tensors: torch.cat(tensors, dim=1)
        )

        if self.deep_vision:
            out1 = self.deep1(x0_1)
            out2 = self.deep2(x0_2)
            out3 = self.deep3(x0_3)
            out4 = self.deep4(x0_4)
            return [out1, out2, out3, out4]
        return self.deep(x0_4)

    def predict(self, x):
        """Return probabilities for `x`.

        Unlike `forward`, feature maps are center-cropped to a common size
        before concatenation (`align_layers`), so arbitrary input sizes work.
        With deep supervision the four heads' logits are averaged first.
        """
        x0_1, x0_2, x0_3, x0_4 = self._nested_features(x, align_layers)

        if self.deep_vision:
            out1 = self.deep1(x0_1)
            out2 = self.deep2(x0_2)
            out3 = self.deep3(x0_3)
            out4 = self.deep4(x0_4)
            height = out4.size(2)
            width = out4.size(3)
            out1 = center_crop(out1, height, width)
            out2 = center_crop(out2, height, width)
            out3 = center_crop(out3, height, width)
            out = (out1 + out2 + out3 + out4) / 4
        else:
            out = self.deep(x0_4)

        if self.out_channels == 1:
            out = torch.sigmoid(out)
        elif self.out_channels > 1:
            # Class scores live on the channel axis; torch.softmax requires `dim`.
            out = torch.softmax(out, dim=1)
        return out


Classes for `R2UNET` and `R2AttentionUNET` (https://arxiv.org/pdf/1802.06955, https://arxiv.org/pdf/1804.03999), might not be the best suited for this task, but interesting nonetheless.

In [None]:
class RecurrentBlock(nn.Module):
    """Recurrent convolution: one shared conv-BN-ReLU layer applied repeatedly.

    Parameters
    ----------
    out_channel : int
        Number of input and output channels of the shared layer.
    t : int
        Number of recurrent refinement steps. `t = 0` reduces the block to a
        single application of the shared layer.
    """

    def __init__(self, out_channel, t = 2):
        super().__init__()
        self.t = t
        self.out_channel = out_channel
        self.conv_layer = nn.Sequential(
            nn.Conv2d(self.out_channel, self.out_channel, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),

        )

    def forward(self, x):
        """Return the state after one initial pass plus `t` recurrent passes."""
        # The initial pass sits outside the loop so the state is defined even for t == 0.
        x1 = self.conv_layer(x)
        for _ in range(self.t):
            # Each step feeds the input plus the previous state through the same layer.
            x1 = self.conv_layer(x + x1)
        return x1
class RecurrentResidualBlock(nn.Module):
    """1x1 channel projection followed by a residual stack of two recurrent blocks.

    The input is first projected from `in_channels` to `out_channels`; the
    projection is then added to the output of two `RecurrentBlock`s and a
    channel dropout applied to that same projection.
    """

    def __init__(self,in_channels, out_channels, t = 2, dropout = 0.1):
        super().__init__()
        recurrent_pair = [RecurrentBlock(out_channel=out_channels, t = t) for _ in range(2)]
        self.rec_res_layer = nn.Sequential(*recurrent_pair, nn.Dropout2d(dropout))
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Return `projection + recurrent_stack(projection)`."""
        projected = self.conv(x)
        refined = self.rec_res_layer(projected)
        return projected + refined
    
class UpSampleConv(nn.Module):
    """2x upsampling followed by 3x3 conv, batch norm and ReLU."""

    def __init__(self,in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.up_sample = nn.Sequential(*stages)

    def forward(self, x):
        """Return `x` at twice the spatial size with `out_channels` channels."""
        return self.up_sample(x)
        
        
class AttentionBlock(nn.Module):
    """Additive attention gate as used in Attention U-Net.

    Parameters
    ----------
    F_g : int
        Channels of the gating signal `g` (decoder features).
    F_l : int
        Channels of the skip connection `x` (encoder features).
    F_int : int
        Channels of the intermediate representation.

    The gate computes `alpha = sigmoid(psi(relu(W_g(g) + W_x(x))))` and returns
    `x * alpha`. The ReLU (with its dropout) is applied to the summed
    projections *before* `psi`; applying it after the sigmoid made the ReLU a
    no-op and let `Dropout2d(0.5)` zero the whole single-channel attention map.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, padding=0, bias= True),
            nn.BatchNorm2d(F_int),
            nn.Dropout2d(0.1)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
            nn.Dropout2d(0.1)

        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),

        )
        self.relu = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Dropout2d(0.5)
            )

    def forward(self, g, x):
        """Return the skip features `x` scaled by the attention map derived from `g` and `x`."""
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        if x1.size(2) != g1.size(2) or x1.size(3) != g1.size(3):
            # The skip connection can be a pixel larger than the gating signal.
            x1 = center_crop(x1, g1.size(2), g1.size(3))
        attention = self.psi(self.relu(g1 + x1))
        if x.size(2) != attention.size(2) or x.size(3) != attention.size(3):
            x = center_crop(x, attention.size(2), attention.size(3))
        return x*attention
    
class R2UNET(nn.Module):
    """Recurrent residual U-Net (R2U-Net).

    Five recurrent residual encoder stages separated by 2x2 max pooling, and
    four decoder stages that upsample, concatenate the matching encoder
    features (encoder first) and refine with a recurrent residual block.
    `forward` returns raw logits from a final 1x1 convolution.
    """

    def __init__(self, in_channels = 4, out_channels = 1, t = 2):
        super().__init__()

        # Encoder
        self.rec_res_layer1 = RecurrentResidualBlock(in_channels, 64, t=t)
        self.rec_res_layer2 = RecurrentResidualBlock(64, 128, t=t)
        self.rec_res_layer3 = RecurrentResidualBlock(128, 256, t=t)
        self.rec_res_layer4 = RecurrentResidualBlock(256, 512, t=t, dropout=0.3)
        self.rec_res_layer5 = RecurrentResidualBlock(512, 1024, t=t, dropout=0.3)

        # Decoder upsampling
        self.up1 = UpSampleConv(1024, 512)
        self.up2 = UpSampleConv(512, 256)
        self.up3 = UpSampleConv(256, 128)
        self.up4 = UpSampleConv(128, 64)

        # Decoder refinement after skip concatenation
        self.up_rec_res_layer1 = RecurrentResidualBlock(1024, 512, t=t, dropout=0.3)
        self.up_rec_res_layer2 = RecurrentResidualBlock(512, 256, t=t, dropout=0.3)
        self.up_rec_res_layer3 = RecurrentResidualBlock(256, 128, t=t)
        self.up_rec_res_layer4 = RecurrentResidualBlock(128, 64, t=t)

        self.out = nn.Conv2d(64, out_channels=out_channels, kernel_size=1)

        self.max = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Return logits for `x`; frees cached memory before every pass."""
        gc.collect()
        torch.cuda.empty_cache()

        encoder_stages = (
            self.rec_res_layer1,
            self.rec_res_layer2,
            self.rec_res_layer3,
            self.rec_res_layer4,
            self.rec_res_layer5,
        )
        features = []
        current = x
        for depth, stage in enumerate(encoder_stages):
            if depth > 0:
                # Every stage after the first works on the pooled previous output.
                current = self.max(current)
            current = stage(current)
            features.append(current)

        decoder_stages = (
            (self.up1, self.up_rec_res_layer1),
            (self.up2, self.up_rec_res_layer2),
            (self.up3, self.up_rec_res_layer3),
            (self.up4, self.up_rec_res_layer4),
        )
        decoded = features.pop()  # deepest encoder output
        for upsample, refine in decoder_stages:
            skip = features.pop()
            decoded = upsample(decoded)
            decoded = refine(torch.cat((skip, decoded), dim = 1))

        return self.out(decoded)
    
    

In [None]:
class R2AttentionUNET(nn.Module):
    """Recurrent residual U-Net with attention gates on the skip connections.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input image.
    out_channels : int
        Number of output channels of the final 1x1 convolution.
    t : int
        Number of recurrent steps in every recurrent block.

    `forward` returns raw logits and concatenates features without cropping;
    `predict` center-crops features to a common size before concatenation and
    returns probabilities.
    """

    def __init__(self, in_channels= 4, out_channels = 1, t = 2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.max = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsamp = nn.Upsample(scale_factor=2)

        self.rec_res_layer1 = RecurrentResidualBlock(in_channels,64,t)
        self.rec_res_layer2 = RecurrentResidualBlock(64, 128, t)
        self.rec_res_layer3 = RecurrentResidualBlock(128, 256, t)
        self.rec_res_layer4 = RecurrentResidualBlock(256, 512, t, dropout=0.3)
        self.rec_res_layer5 = RecurrentResidualBlock(512, 1024, t, dropout=0.3)

        self.up1 = UpSampleConv(1024, 512)
        self.up2 = UpSampleConv(512, 256)
        self.up3 = UpSampleConv(256, 128)
        self.up4 = UpSampleConv(128, 64)

        self.att1 = AttentionBlock(512, 512, 256)
        self.att2 = AttentionBlock(256, 256, 128)
        self.att3 = AttentionBlock(128, 128, 64)
        self.att4 = AttentionBlock(64,64,32)

        self.up_rec_res_layer1 = RecurrentResidualBlock(1024, 512, t, dropout=0.3)
        self.up_rec_res_layer2 = RecurrentResidualBlock(512, 256, t)
        self.up_rec_res_layer3 = RecurrentResidualBlock(256, 128, t)
        self.up_rec_res_layer4 = RecurrentResidualBlock(128,64, t)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def _logits(self, x, merge):
        """Shared encoder/decoder pass; `merge` joins (gated skip, decoder) pairs on the channel axis."""
        x1 = self.rec_res_layer1(x)

        x2 = self.max(x1)
        x2 = self.rec_res_layer2(x2)

        x3 = self.max(x2)
        x3 = self.rec_res_layer3(x3)

        x4 = self.max(x3)
        x4 = self.rec_res_layer4(x4)

        x5 = self.max(x4)
        x5 = self.rec_res_layer5(x5)

        # Each decoder step: upsample, gate the skip with the upsampled signal, merge, refine.
        d1 = self.up1(x5)
        x4 = self.att1(g=d1, x=x4)
        d1 = self.up_rec_res_layer1(merge((x4, d1)))

        d2 = self.up2(d1)
        x3 = self.att2(g=d2, x=x3)
        d2 = self.up_rec_res_layer2(merge((x3, d2)))

        d3 = self.up3(d2)
        x2 = self.att3(g=d3, x=x2)
        d3 = self.up_rec_res_layer3(merge((x2, d3)))

        d4 = self.up4(d3)
        x1 = self.att4(g=d4, x=x1)
        d4 = self.up_rec_res_layer4(merge((x1, d4)))

        return self.out(d4)

    def forward(self,x):
        """Return raw logits; features are concatenated without cropping."""
        return self._logits(x, lambda pair: torch.cat(pair, dim=1))

    def center_crop(self, layers):
        """Center-crop all 4-D tensors in `layers` to their common minimum size and concatenate channels.

        Side effects: empties the CUDA cache and runs the garbage collector.
        """
        layer_heights = [layer.size(2) for layer in layers]
        layer_widths = [layer.size(3) for layer in layers]
        target_height = min(layer_heights)
        target_width = min(layer_widths)
        rtn = []
        for layer, layer_height, layer_width in zip(layers, layer_heights, layer_widths):
            diff_y = (layer_height - target_height) // 2
            diff_x = (layer_width - target_width) // 2
            rtn.append(layer[:, :, diff_y:(diff_y + target_height), diff_x:(diff_x + target_width)])
        torch.cuda.empty_cache()
        gc.collect()
        return torch.cat(rtn, dim = 1)

    def predict(self,x):
        """Return probabilities: sigmoid for one channel, channel-wise softmax otherwise."""
        out = self._logits(x, self.center_crop)
        if self.out_channels == 1:
            out = torch.sigmoid(out)
        elif self.out_channels > 1:
            # Same convention as the other models in this notebook for multi-class output.
            out = torch.softmax(out, dim=1)
        return out
    


## 3. Training the model

Some loss functions designed to handle the class imbalance that is present in the task. The function `torch.nn.BCEWithLogitsLoss` with `pos_weight` can also be used for this purpose.

In [None]:

class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over all elements.

    `alpha` weights positive targets (negatives get `1 - alpha`); `gamma` is
    the focusing exponent that down-weights well-classified elements.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predicition, target):
        """Return the mean focal loss for logits `predicition` and 0/1 `target`."""
        prob = torch.sigmoid(predicition)
        bce = F.binary_cross_entropy_with_logits(predicition, target, reduction="none")
        # Probability the model assigns to the true class of each element.
        prob_true = prob * target + (1 - prob) * (1 - target)
        focused = bce * ((1 - prob_true) ** self.gamma)
        class_weight = self.alpha * target + (1 - self.alpha) * (1 - target)
        return (class_weight * focused).mean()
class BCEDiceLoss(nn.Module):
    """Weighted sum of soft Dice loss and binary cross-entropy on logits.

    The result is `gamma * dice + (1 - gamma) * bce`. `pos_weight` is passed
    straight to `binary_cross_entropy_with_logits`.
    """

    def __init__(self, gamma, pos_weight):
        super().__init__()
        self.gamma = gamma
        self.pos_weight = pos_weight
        self.epsilon = 1e-8

    def forward(self, inputs, labels):
        """Return the combined loss for logits `inputs` and 0/1 `labels` (both are flattened)."""
        logits = torch.flatten(inputs)
        targets = torch.flatten(labels)

        bce_loss = F.binary_cross_entropy_with_logits(input=logits, target=targets, pos_weight=self.pos_weight)

        probs = torch.sigmoid(logits)
        intersect = (probs * targets).sum()
        # epsilon keeps the ratio defined when both prediction and target are empty.
        dice_coefficient = (2.*intersect + self.epsilon) / (probs.sum() + targets.sum() + self.epsilon)
        dice_loss = 1 - dice_coefficient

        return self.gamma*dice_loss + (1-self.gamma)*bce_loss
    

The `train` function trains the neural network. It uses the DataLoaders defined earlier to iterate through the images. It needs an `optimizer` and a `scheduler`, a `loss_function` that is suited for the task, and the number of epochs (`n_epochs`) for training. `delta` can be used together with `patience` for early stopping, and to only save the models which have shown improvement. If the batch size would be too small due to GPU memory limitations, `accumulation_steps` makes the optimizer step only after `accumulation_steps` batches have been processed, effectively increasing the batch size. `deep_vision` and `pretrained` are for UNET++ models.

In [None]:
def _supervised_loss(outputs, target, loss_function, deep_vision):
    """Loss of one batch; with deep supervision, the mean loss over all output heads."""
    target = target.squeeze()
    if not deep_vision:
        return loss_function(outputs.squeeze(), target)
    total = 0
    for head in outputs:
        total = total + loss_function(head.squeeze(), target)
    # Average over however many heads the model returned (previously hard-coded to 4).
    return total / len(outputs)


def train(model, device,train_dataloader, val_dataloader, optimizer, scheduler,loss_function ,n_epochs=50, delta=0.01, patience=10, accumulation_steps = 1, deep_vision = 0, pretrained = 0):
    """Train ``model`` and checkpoint it whenever the validation loss improves.

    Args:
        model: network to train; moved to ``device`` in place.
        device: torch device used for the model and every batch.
        train_dataloader: iterable of ``(x, y)`` training batches.
        val_dataloader: iterable of ``(x, y)`` validation batches.
        optimizer: optimizer stepped every ``accumulation_steps`` batches.
        scheduler: learning-rate scheduler, stepped ONCE PER EPOCH. A
            ``ReduceLROnPlateau`` receives the epoch's mean validation loss.
        loss_function: callable ``(prediction, target) -> scalar loss``.
        n_epochs: maximum number of epochs.
        delta: minimum decrease of the validation loss that counts as an
            improvement. It is divided by 10 whenever the validation loss
            drops below it.
        patience: training stops once more than ``patience`` consecutive
            epochs bring no improvement.
        accumulation_steps: number of batches whose gradients are accumulated
            before an optimizer step. Each batch loss is divided by this value
            so the accumulated gradient matches a larger batch.
        deep_vision: truthy when the model returns several outputs (deep
            supervision); the loss is then averaged over all of them.
        pretrained: stored in the checkpoint only.

    Returns:
        ``(train_losses, val_losses)``, one mean loss per completed epoch, or
        ``(None, None)`` if training was interrupted or raised.
    """
    model.to(device)
    train_losses = []
    val_losses = []
    best_loss = np.inf
    no_improvement_count = 0
    print("training began")
    try:
        for epoch in range(n_epochs):
            model.train()
            train_epoch_losses = []
            optimizer.zero_grad()
            pending_gradients = False
            for i, (x, y) in enumerate(train_dataloader):
                x, y = x.to(device), y.to(device)
                loss = _supervised_loss(model(x), y, loss_function, deep_vision)

                # Store a plain float so the autograd graph is not kept alive.
                train_epoch_losses.append(loss.item())

                # Scale so that the summed gradients equal those of one big batch.
                (loss / accumulation_steps).backward()
                pending_gradients = True

                if (i + 1) % accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    pending_gradients = False

            # Apply a trailing partial accumulation group instead of discarding it.
            if pending_gradients:
                optimizer.step()
                optimizer.zero_grad()

            train_losses.append(torch.tensor(train_epoch_losses).mean())

            model.eval()
            val_epoch_losses = []
            with torch.no_grad():
                for x, y in val_dataloader:
                    x, y = x.to(device), y.to(device)
                    loss = _supervised_loss(model(x), y, loss_function, deep_vision)
                    val_epoch_losses.append(loss.mean().item())

            avg_epoch_loss = torch.tensor(val_epoch_losses).mean()
            val_losses.append(avg_epoch_loss)

            # ReduceLROnPlateau cannot be stepped without the monitored metric.
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(avg_epoch_loss.item())
            else:
                scheduler.step()

            print(f'Tranining {epoch+1}/{n_epochs} done, training loss: {train_losses[-1]}, validation loss: {val_losses[-1]}')

            if avg_epoch_loss + delta < best_loss:
                best_loss = avg_epoch_loss
                path = f"{MODEL_SAVE_PATH}{epoch + 1}.sav"
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'deep_vision': deep_vision,
                    'pretrained': pretrained
                }, path)
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            # Tighten the improvement threshold once the loss itself is below it.
            if avg_epoch_loss < delta:
                delta /= 10

            if no_improvement_count > patience:
                print("early stopping")
                return train_losses, val_losses

        return train_losses, val_losses
    except KeyboardInterrupt as e:
        print("Training interrupted by user")
        return None, None
    except Exception as e:
        print("Error occured during training:")
        traceback.print_exc()
        return None, None


Here we define the learning rate, optimizer, scheduler, and loss function for the training process, and start the training.

In [None]:
# Free memory left over from earlier cells before allocating a new model.
print(gc.collect())
torch.cuda.empty_cache()

u_model = UNET()

lr = 0.0005
n_epoch = 50
patience = 10
delta = 0.01
accumulation_steps = 1
is_pretrained = 0
is_deep_supervision = 0

# Only parameters that require gradients are optimized (frozen layers are skipped).
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, u_model.parameters()), lr, weight_decay=1e-5)
# train() calls scheduler.step() once per epoch, so the one-cycle schedule must
# span n_epoch steps. Sizing it per batch would leave the learning rate stuck
# at the very start of the warm-up for the whole run.
scheduler = OneCycleLR(optimizer, max_lr=0.001, total_steps=n_epoch)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# pos_weight up-weights the positive class to counter the class imbalance.
weight = torch.tensor([15.0]).to(device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=weight)

train_losses, val_losses  =  train(model=u_model, device = device, train_dataloader = train_dataloader,val_dataloader= val_dataloader,
                                   optimizer= optimizer, scheduler=scheduler, loss_function=loss_fn,
                                   n_epochs=n_epoch, delta=delta, patience=patience, 
                                   accumulation_steps=accumulation_steps, pretrained=is_pretrained,deep_vision=is_deep_supervision)
# train() returns (None, None) when it was interrupted or failed.
if train_losses is not None:
    plt.plot(train_losses, label='Training')
    plt.plot(val_losses, label='Validation')
    plt.xlabel('Epochs')
    plt.ylabel('CrossEntropy')
    plt.legend()
    plt.show()


## 4. Evaluation

Evaluates the model on the test set. It saves a figure which shows the original image, the continuous prediction, a binary classification where the threshold is 0.5 probability, and the ground truth. It also calculates some metrics for numerical evaluation.

In [None]:
checkpoint = torch.load(MODEL_LOAD_PATH)
u_model.load_state_dict(checkpoint['model_state_dict'])
u_model.eval()

# Probability above which a pixel is classified as positive.
threshold = 0.5

precision = []
recall = []
focus = []  # holds the F-score; the name is kept because the metrics cell reads it
support = []
kappa = []
jaccard = []
accuracy = []

# j is the 1-based batch number, i the sample index inside the batch; both are
# only used to build the file name of the saved figure.
for j, (x, y) in enumerate(test_dataloader, start=1):
    with torch.no_grad():
        pred = u_model.predict(x)
    print(pred.shape)

    curr_pred = pred.squeeze()
    if curr_pred.dim() == 2:
        # A batch of one loses its batch axis in squeeze(); restore it so the
        # loop below iterates over samples instead of image rows.
        curr_pred = curr_pred.unsqueeze(0)

    for i in range(curr_pred.shape[0]):
        display = curr_pred[i].data.cpu().numpy()
        binary_mask = display > threshold
        display2 = binary_mask.astype(display.dtype)

        # NOTE(review): assumes x is channel-first with at least 3 channels - confirm.
        input_image = x[i].data.cpu().numpy()
        input_image = np.ascontiguousarray(np.transpose(input_image, (1, 2, 0)))
        cpu_y = y[i].data.cpu().numpy()

        # Panels: input | ground truth on top, binary mask | continuous prediction below.
        f, ax = plt.subplots(2, 2)
        ax[0, 0].imshow(input_image[:, :, :3])
        ax[0, 1].imshow(cpu_y)
        ax[1, 0].imshow(display2)
        ax[1, 1].imshow(display)
        f.savefig(TEST_IMAGE_SAVE_PATH + str(j) + "_" + str(i))
        plt.close(f)

        prediction = binary_mask.astype(int).flatten()
        cpu_y = cpu_y.astype(int).flatten()

        p, r, fs, s = precision_recall_fscore_support(cpu_y, prediction, zero_division=0)
        epoch_jaccard = jaccard_score(cpu_y, prediction, zero_division=0)
        epoch_kappa = cohen_kappa_score(cpu_y, prediction)
        epoch_accuracy = accuracy_score(cpu_y, prediction)

        # Skip NaN kappa values so they do not poison the mean.
        if not np.isnan(epoch_kappa.item()):
            kappa.append(epoch_kappa)
        # Per-class scores are averaged per image (unweighted mean over classes).
        precision.append(p.mean())
        recall.append(r.mean())
        focus.append(fs.mean())
        support.append(s)
        jaccard.append(epoch_jaccard)
        accuracy.append(epoch_accuracy)

precision = np.asarray(precision)
recall = np.asarray(recall)
focus = np.asarray(focus)

kappa = np.asarray(kappa)
jaccard = np.asarray(jaccard)
accuracy = np.asarray(accuracy)


Writes the calculated metrics to a txt file.

In [None]:

# Append the mean of every metric collected by the evaluation cell to metrics.txt.
metric_values = [
    ("precision", precision),
    ("recall", recall),
    ("focus", focus),
    ("kappa", kappa),
    ("jaccard", jaccard),
    ("accuracy", accuracy),
]

# The context manager closes the file even if a write fails.
with open("metrics.txt", "a") as metrics_file:
    for label, values in metric_values:
        # Every entry ends with a newline, including the last one. The file is
        # opened in append mode, so without it the next run's first entry would
        # be glued onto the "accuracy" line.
        metrics_file.write(label + ": " + str(np.mean(values)) + "\n")

Reads the images from Drina, Raho, and Kiskore.

In [None]:
def drina_images(root_directory):
    """Load and normalize every eligible ``.tif`` image below ``root_directory``.

    Files whose lower-cased name ends in ``classified.tif`` or contains
    ``udm2`` are skipped. For each remaining file, bands 1-4 are read and each
    band is divided by its own 99.9th percentile.
    NOTE(review): the original code names bands 1-4 red, green, blue and NIR -
    confirm the band order for the sensor in use.

    Args:
        root_directory: directory that is walked recursively.

    Returns:
        ``(tif_images, paths)``: a list of arrays with the four scaled bands
        stacked along the first axis, and the matching file paths with the
        last four characters (the ``.tif`` extension) removed. Files that fail
        to load are reported with ``print`` and skipped.
    """
    tif_images = []
    paths = []
    for root, dirs, files in os.walk(root_directory):
        for file in files:
            lowered = file.lower()
            if not lowered.endswith('.tif') or lowered.endswith('classified.tif') or 'udm2' in lowered:
                continue
            file_path = os.path.join(root, file)
            try:
                # The context manager closes the dataset; it was previously leaked.
                with rasterio.open(file_path) as img:
                    bands = [img.read(band_index) for band_index in (1, 2, 3, 4)]

                # Scaling by the 99.9th percentile limits the influence of outlier pixels.
                scaled = [band.astype('f4') / np.percentile(band, 99.9) for band in bands]

                tif_images.append(np.asarray(scaled))
                paths.append(file_path[:-4])
            except Exception as e:
                # Best effort: report the problem and continue with the next file.
                print(e)

    return tif_images, paths


Creates a list of model names, their paths, and a save directory for the processed images.

In [None]:
# Parallel lists: entry k of each list describes the same saved model.
model_paths = []
model_names = []
save_dir = []

MODEL_FOLDER = ""
SERVER_IMAGE_PATH = ""
for root, dirs, files in os.walk(MODEL_FOLDER):
    for file in files:
        if file.lower().endswith('.sav'):
            # os.path.join keeps the trailing separator and works on every OS,
            # unlike the hard-coded Windows "\\test_output\\".
            save_dir.append(os.path.join(root, "test_output", ""))
            # Model name is the file name without its ".sav" extension.
            model_names.append(file[:-4])
            model_paths.append(os.path.join(root, file))
print(save_dir)
print(model_paths)

Evaluates the models on the images of Drina, Raho, and Kiskore, and saves a `.tif` image next to the original image with the model name, and also saves a figure containing the original image, the continuous prediction, and the classification image in the `save_dir` directory.

In [None]:
import tifffile

images, paths = drina_images(SERVER_IMAGE_PATH)

# Fall back to the CPU when no GPU is available instead of failing on .cuda().
inference_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Probability above which a pixel is classified as positive in the saved figure.
treshold = 0.9

for j in range(len(model_names)):
    u_model = UNETPP().to(inference_device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(model_paths[j], map_location=inference_device)
    u_model.load_state_dict(checkpoint['model_state_dict'])
    u_model.eval()
    print(len(paths))

    # savefig does not create missing directories.
    os.makedirs(save_dir[j], exist_ok=True)

    for i in range(len(images)):
        # Add the batch axis expected by the model.
        img = torch.tensor(images[i], dtype=torch.float32, device=inference_device).unsqueeze(0)
        with torch.no_grad():
            pred = u_model.predict(img)

        display = pred.data.cpu().numpy().squeeze()
        display2 = np.copy(display)
        display2[display > treshold] = 1
        display2[display <= treshold] = 0

        # Back to height x width x channels for plotting.
        img_cpu = np.transpose(np.squeeze(img.data.cpu().numpy()), (1, 2, 0))

        f, ax = plt.subplots(1, 3)
        ax[0].imshow(img_cpu[:, :, :3])
        ax[1].imshow(display)
        ax[2].imshow(display2)

        # NOTE(review): the .tif receives the continuous prediction (display),
        # not the thresholded mask, although it is named "classified" - confirm.
        tifffile.imwrite(paths[i] + "_" + model_names[j]  + "_" + "_classified.tif", display)
        f.savefig(save_dir[j] +  model_names[j] + "_" + str(i) + ".png")
        plt.close(f)
    