In [None]:
import os
import torch
import numpy as np
import imageio
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from time import time as time

from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import torchnet as tnt
import functools

import mock
from tqdm import notebook as tqdm

from distutils.dir_util import copy_tree


# GLOBAL SETTINGS
PlotSize = 12                                     # Base plot size; default figures are (2*PlotSize) x PlotSize
matplotlib.rcParams['figure.figsize'] = [PlotSize*2, PlotSize]  
CMAP = matplotlib.colors.ListedColormap(['black', 'white', 'orange'])               # Three-colour map used when drawing label rasters
np.set_printoptions(precision=2, suppress=True)  # Array print precision

# CLASS AND FEATURE DESCRIPTION
# len(class_names) is used as the number of output classes of the network.
# Presumably the list index equals the integer label stored in the GT rasters - TODO confirm.
class_names = ['BACKGRD','PINUS','PICEA']

# PATHS TO TRAIN/TEST DATA
# Each root folder is expected to contain a 'GT/' sub-folder plus one sub-folder
# per modality ('CIR/', 'RGB/', 'PAN/'); see read_patch.
data_path = '../data/jeseniky/blackandwhite/'
training_set_path = data_path + 'train/'         # Relative path to training patch root folder
test_set_path =     data_path + 'test/'          # Relative path to test patch root folder

# Tile counts are taken from the number of entries in the GT folders.
num_of_training_tiles = len(os.listdir(training_set_path + 'GT/'))
num_of_test_tiles = len(os.listdir(test_set_path + 'GT/'))

# INPUT MODALITY SWITCHES (colour-infrared, RGB, panchromatic)
# read_patch only uses PAN when both CIR and RGB are disabled.
use_cir = False
use_rgb = False
use_pan = True

# MODEL NAME... USED AS FILENAME OF SAVED MODEL AND FOR APPROPRIATE RESULTS FOLDER
# NOTE(review): the cells shown here hard-code their file names instead of reading model_name.
model_name = 'U_Net'

In [None]:
def _load_patch_stack(folder, dtype, transform=None):
    """Read every file in ``folder`` in sorted name order and stack the results on a new first axis."""
    patches = []
    # os.listdir returns entries in arbitrary order; sorting makes the patch
    # order deterministic and identical across sibling folders with equal file names.
    for name in sorted(os.listdir(folder)):
        patch = imageio.imread(folder + name).astype(dtype)
        if transform is not None:
            patch = transform(patch)
        patches.append(patch)
    return np.stack(patches, axis=0)


def _scale_multiband(patch):
    """Reorder an (H, W, C) patch to (C, H, W) and divide by 255."""
    return patch.transpose([2, 0, 1]) / 255


def _scale_singleband(patch):
    """Divide an (H, W) patch by 255 and add a leading channel axis."""
    return np.expand_dims(patch / 255, axis=0)


def read_patch(root_folder, cir, rgb, pan, gt=True):
    """Load all patches below ``root_folder`` as tensors.

    Images are read as float32 and divided by 255. Files of every sub-folder are
    read in sorted name order, so features and ground truth are paired by file
    name rank (this assumes the sub-folders hold matching file names).

    Parameters
    ----------
    root_folder : str
        Folder (with trailing separator) containing the 'CIR/', 'RGB/', 'PAN/'
        and 'GT/' sub-folders that are requested.
    cir, rgb : bool
        Load the respective multi-band images. If both are set, the bands are
        concatenated along the channel axis (CIR first).
    pan : bool
        Load single-band images. Only used when neither ``cir`` nor ``rgb`` is set.
    gt : bool, default True
        Also load the label rasters from 'GT/' as int64.

    Returns
    -------
    features : torch.Tensor
        Stacked patches, channel axis second.
    ground_truth : torch.Tensor
        Only returned when ``gt`` is true; stacked label rasters.

    Raises
    ------
    ValueError
        If none of ``cir``, ``rgb`` and ``pan`` is set.
    """
    if not (cir or rgb or pan):
        raise ValueError('No valid data input.')

    if cir and rgb:
        cir_features = _load_patch_stack(root_folder + 'CIR/', np.float32, _scale_multiband)
        rgb_features = _load_patch_stack(root_folder + 'RGB/', np.float32, _scale_multiband)
        features = np.concatenate([cir_features, rgb_features], axis=1)
    elif cir:
        features = _load_patch_stack(root_folder + 'CIR/', np.float32, _scale_multiband)
    elif rgb:
        features = _load_patch_stack(root_folder + 'RGB/', np.float32, _scale_multiband)
    else:
        features = _load_patch_stack(root_folder + 'PAN/', np.float32, _scale_singleband)
    features = torch.from_numpy(features)

    if not gt:
        return features

    ground_truth = torch.from_numpy(_load_patch_stack(root_folder + 'GT/', np.int64))
    return features, ground_truth

In [None]:
### putting the dataset into the TensorDataset wrapper
# Load features (X) and labels (y) for the training split and the test split (X_t, y_t).
X, y = read_patch(training_set_path, use_cir, use_rgb, use_pan)
X_t, y_t = read_patch(test_set_path, use_cir, use_rgb, use_pan)

# Sanity check of the loaded feature tensors.
print(X.shape)
print(X_t.shape)


# train_set / test_set are read as globals by train(), eval() and classify().
train_set = tnt.dataset.TensorDataset(list([X, y]))
test_set  = tnt.dataset.TensorDataset(list([X_t, y_t]))
print(len(train_set))

In [None]:
class UNet(nn.Module):
    """
    U-Net for semantic segmentation with five resolution levels.

    The encoder is five pairs of 3x3 convolutions (c1..c10) separated by 2x2
    max-pooling. The decoder has four stages, each a 2x2 transposed convolution
    (c11, c14, c17, c20) whose output is concatenated with the encoder feature
    map of the same level and passed through two 3x3 convolutions. A 1x1
    convolution produces the per-class scores.
    """

    def __init__(self, n_channels, encoder_conv_width, decoder_conv_width, n_class, cuda):
        """
        Build the layers and initialise their weights.

        n_channels, int, number of input channels
        encoder_conv_width, int list (10 entries read), output widths of c1..c10
        decoder_conv_width, int list (12 entries read), three entries per decoder
            stage: width after concatenation, then the widths of the two convs
        n_class, int, number of classes
        cuda, truthy to move the module to the GPU after construction
        """
        super().__init__()  # required for every nn.Module subclass

        enc, dec = encoder_conv_width, decoder_conv_width

        def conv_relu(n_in, n_out):
            # 3x3 convolution with reflect padding (keeps spatial size) + in-place ReLU
            return nn.Sequential(
                nn.Conv2d(n_in, n_out, 3, padding=1, padding_mode='reflect'),
                nn.ReLU(True),
            )

        self.maxpool = nn.MaxPool2d(2, 2, return_indices=False)

        # encoder: c1 .. c10, each consuming the previous layer's width
        n_in = n_channels
        for k in range(10):
            setattr(self, 'c%d' % (k + 1), conv_relu(n_in, enc[k]))
            n_in = enc[k]

        # decoder: four stages of (up-convolution, conv, conv) -> c11 .. c22.
        # The up-convolution emits half of the stage's input width; the other
        # half comes from the skip connection concatenated in forward().
        for stage in range(4):
            first = 3 * stage
            setattr(self, 'c%d' % (11 + first),
                    nn.ConvTranspose2d(n_in, int(dec[first] / 2), kernel_size=2, stride=2))
            setattr(self, 'c%d' % (12 + first), conv_relu(dec[first], dec[first + 1]))
            setattr(self, 'c%d' % (13 + first), conv_relu(dec[first + 1], dec[first + 2]))
            n_in = dec[first + 2]

        # final classifying layer
        self.classifier = nn.Conv2d(dec[11], n_class, 1, padding=0)

        # weight initialisation of every 3x3 convolution, then the classifier;
        # the transposed convolutions keep PyTorch's default initialisation
        up_convs = (11, 14, 17, 20)
        for k in range(1, 23):
            if k not in up_convs:
                getattr(self, 'c%d' % k)[0].apply(self.init_weights)
        self.classifier.apply(self.init_weights)

        if cuda:  # put the model on the GPU memory
            self.cuda()

    def init_weights(self, layer):
        """Kaiming-normal (fan-out, ReLU gain) initialisation of ``layer.weight``."""
        nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input):
        """
        Run the network on a batch and return the per-class score maps.
        """
        # contracting path: keep the pre-pooling maps for the skip connections
        skip1 = self.c2(self.c1(input))
        skip2 = self.c4(self.c3(self.maxpool(skip1)))
        skip3 = self.c6(self.c5(self.maxpool(skip2)))
        skip4 = self.c8(self.c7(self.maxpool(skip3)))
        bottom = self.c10(self.c9(self.maxpool(skip4)))

        # expanding path: upsample, concatenate the skip on the channel axis, convolve twice
        up4 = self.c13(self.c12(torch.cat((self.c11(bottom), skip4), 1)))
        up3 = self.c16(self.c15(torch.cat((self.c14(up4), skip3), 1)))
        up2 = self.c19(self.c18(torch.cat((self.c17(up3), skip2), 1)))
        up1 = self.c22(self.c21(torch.cat((self.c20(up2), skip1), 1)))

        return self.classifier(up1)

In [None]:
def augment(obs, g_t):
    """Randomly augment a batch: additive noise on the tiles, shared random rotation.

    Parameters
    ----------
    obs : array-like or CPU tensor, rank 4
        Batch of tiles; the last two axes are rotated.
    g_t : array-like or CPU tensor, rank 3
        Batch of label rasters; the last two axes are rotated by the same amount.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        The augmented tiles and the rotated labels.

    Notes
    -----
    Uses NumPy's global random state, so ``np.random.seed`` makes it reproducible.
    """
    sigma, clip = 0.01, 0.03

    obs = np.asarray(obs)
    g_t = np.asarray(g_t)

    # Independent Gaussian noise for every element, clipped to +-clip.
    # (Drawing a single scalar would only shift the whole batch by one constant.)
    noise = np.clip(sigma * np.random.randn(*obs.shape), -clip, clip).astype(np.float32)
    obs = obs + noise

    # random rotation by 0, 90, 180 or 270 degrees, identical for tiles and labels
    n_turn = np.random.randint(4)
    # .copy() removes the negative strides left by rot90, which torch.from_numpy rejects
    obs = np.rot90(obs, n_turn, axes=(2, 3)).copy()
    g_t = np.rot90(g_t, n_turn, axes=(1, 2)).copy()

    return torch.from_numpy(obs), torch.from_numpy(g_t)

In [None]:
def train(model, optimizer, args):
    """Train ``model`` for one epoch on the global ``train_set``.

    Parameters
    ----------
    model : nn.Module
        Network to optimise; batches are moved to the device of its parameters.
    optimizer : torch.optim.Optimizer
        Optimiser bound to ``model``'s parameters.
    args : object
        Reads ``args.batch_size`` and ``args.class_weights`` (per-class loss weights).

    Returns
    -------
    float
        Mean of the per-batch losses of this epoch.
    """
    model.train()  # switch the model in training mode

    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    class_weights = torch.tensor(args.class_weights, device=device)

    # the loader function will take care of the batching
    loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, drop_last=True)
    loader = tqdm.tqdm(loader, ncols=500)

    # will keep track of the loss
    loss_meter = tnt.meter.AverageValueMeter()

    for tiles, gt in loader:
        optimizer.zero_grad()  # put gradient to zero

        tiles, gt = augment(tiles, gt)

        pred = model(tiles.to(device))  # compute the prediction

        # Loss is evaluated on the model's device; no round-trip of the logits to the CPU.
        loss = nn.functional.cross_entropy(pred, gt.to(device), weight=class_weights)

        loss.backward()  # compute gradients

        # Clamp every gradient element to [-1, 1] (value clipping, not norm clipping).
        nn.utils.clip_grad_value_(model.parameters(), 1.0)

        optimizer.step()  # one optimisation step

        loss_meter.add(loss.item())

    return loss_meter.value()[0]

def eval(model, args):
    """Compute the mean cross-entropy loss of ``model`` on the global ``test_set``.

    Tiles are evaluated one at a time without gradient tracking, on the device
    the model's parameters live on.

    NOTE(review): the loss here is unweighted, whereas train() applies
    ``args.class_weights``; the two curves are therefore not directly
    comparable. ``args`` is accepted for interface symmetry but not read.
    The function name shadows the built-in ``eval``.

    Returns
    -------
    float
        Mean of the per-tile losses.
    """
    model.eval()  # switch in eval mode
    device = next(model.parameters()).device

    loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, drop_last=False)
    loader = tqdm.tqdm(loader, ncols=500)

    loss_meter = tnt.meter.AverageValueMeter()

    with torch.no_grad():
        for tiles, gt in loader:
            pred = model(tiles.to(device))
            loss = nn.functional.cross_entropy(pred, gt.to(device))
            loss_meter.add(loss.item())

    return loss_meter.value()[0]


def train_full(args):
    """Build a UNet, train it for ``args.n_epoch`` epochs and plot the loss curves.

    Reads ``n_channel``, ``conv_width``, ``dconv_width``, ``n_class``, ``cuda``,
    ``lr``, ``n_epoch`` and ``n_epoch_test`` from ``args``. The test set is
    evaluated after the last epoch and, if ``n_epoch_test`` is non-zero, after
    every epoch whose (non-zero) index is a multiple of ``n_epoch_test``.

    Returns
    -------
    UNet
        The trained model.
    """
    # initialize the model
    model = UNet(args.n_channel, args.conv_width, args.dconv_width, args.n_class, args.cuda)

    print('Total number of parameters: {}'.format(sum([p.numel() for p in model.parameters()])))

    # Adam with the learning rate divided by 10 after epochs 40 and 60
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40,60], gamma=0.1)

    # ANSI escape codes: highlight the evaluation output, then restore the default style
    TESTCOLOR = '\033[104m'
    NORMALCOLOR = '\033[0m'

    train_loss = np.empty(args.n_epoch)
    # Evaluations are recorded together with the epoch they belong to, so the
    # number of entries and the plot positions always match what was actually run.
    test_epochs = []
    test_loss = []

    for i_epoch in range(args.n_epoch):
        # train one epoch
        print('Epoch ' + str(i_epoch))
        train_loss[i_epoch] = train(model, optimizer, args)
        scheduler.step()

        is_last_epoch = i_epoch == args.n_epoch - 1
        is_periodic_test = (args.n_epoch_test != 0
                            and i_epoch % args.n_epoch_test == 0
                            and i_epoch > 0)
        if is_last_epoch or is_periodic_test:
            print(TESTCOLOR)
            print('Evaluation')
            test_loss.append(eval(model, args))
            test_epochs.append(i_epoch)
            print(NORMALCOLOR)

    test_loss = np.array(test_loss)

    plt.figure(figsize=(10, 10))
    plt.subplot(1,1,1,ylim=(0,2), xlabel='Epoch #', ylabel='Loss')
    plt.plot(range(args.n_epoch), train_loss, label='train')
    plt.plot(test_epochs, test_loss, label='test')
    plt.legend()
    plt.show()
    print(train_loss)
    print(test_loss)

    return model

In [None]:
# Training and model hyper-parameters, bundled on one object and read as attributes.
args = mock.Mock(
    n_epoch=100,
    n_epoch_test=5,                   # periodicity of evaluation on test set
    batch_size=2,
    n_class=len(class_names),
    n_channel=1,                      # 6 if use_cir and use_rgb else 3
    conv_width=[64, 64, 128, 128, 256, 256, 512, 512, 1024, 1024],
    dconv_width=[1024, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64],
    class_weights=[0.2, 0.2, 0.6],    # per-class weights of the training loss
    cuda=True,
    lr=1e-3,
)

In [None]:
# Train one model with the current `args` and report the wall-clock duration.
a = time()
trained_model = train_full(args)
b = time()

print('Training finished in ' + str(b-a) + 's')

In [None]:
# training multiple models
# One full training run per learning rate; models are stored in the same order
# as `learning_rates` (trained_models[0] is the 1e-4 run).
# NOTE: args.lr is overwritten in place, so after this cell it holds the last value tried.
learning_rates = [1e-4, 1e-3]
trained_models = []
for i in learning_rates:
    args.lr = i
    print('Learning rate for this run is ' + str(i))
    a = time()
    trained_models.append(train_full(args))
    b = time()
    print('Training finished in ' + str(b-a) + 's')

## Result visualisation

In [None]:
def plot_rgb_cir_gt_pred(tile_index, data, gt, model, cir, rgb):
    """Plot one tile next to its ground truth and the model's prediction.

    Parameters
    ----------
    tile_index : int
        Index of the tile along the first axis of ``data`` and ``gt``.
    data : torch.Tensor, rank 4
        Feature tensor; ``data[tile_index]`` is fed to the model.
    gt : tensor or array, rank 3
        Label rasters.
    model : nn.Module
        Trained network; the tile is moved to the device of its parameters.
    cir, rgb : bool
        Which composites the channels hold. Both: the first three channels are
        shown as CIR and the last three as RGB. Exactly one: all channels are
        shown as that composite. Neither: the first channel is shown in grey.

    Also prints the number of predicted pixels per class.
    """
    tile = data[tile_index, :, :, :]

    device = next(model.parameters()).device
    with torch.no_grad():  # inference only, no autograd graph needed
        scores = model(tile[None, :, :, :].to(device)).cpu().numpy()
    pred = scores[0, :, :, :].argmax(0).squeeze()

    unique, counts = np.unique(pred, return_counts=True)
    print(dict(zip(unique, counts)))

    tile = tile.cpu().numpy()

    # Each panel is (image, title, imshow keyword arguments).
    if cir and rgb:
        panels = [
            (tile[:3].transpose([1, 2, 0]), 'NIR Red Green composite', {}),
            (tile[-3:].transpose([1, 2, 0]), 'Red Green Blue composite', {}),
        ]
    elif cir:
        panels = [(tile.transpose([1, 2, 0]), 'NIR Red Green composite', {})]
    elif rgb:
        panels = [(tile.transpose([1, 2, 0]), 'Red Green Blue composite', {})]
    else:
        # single-band input (e.g. panchromatic): show the first channel in grey
        panels = [(tile[0], 'Input band', {'cmap': 'gray'})]

    # Fix the colour range to the full class range; otherwise imshow rescales
    # each image to its own min/max and a class can change colour between panels.
    label_style = {'cmap': CMAP, 'vmin': 0, 'vmax': CMAP.N - 1}
    panels.append((gt[tile_index, :, :], 'GT Labels', label_style))
    panels.append((pred, 'Predicted Labels', label_style))

    plt.figure(facecolor='white')
    for position, (image, title, style) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(image, **style)
        plt.title(title)
        plt.axis('off')

In [None]:
# Visual spot check of the freshly trained model on test tile 750.
plot_rgb_cir_gt_pred(750, X_t, y_t, trained_model, use_cir, use_rgb)

## Computing accuracy metrics

In [None]:
def classify(model, args):
    """Predict a class raster for every tile of the global ``test_set``.

    Tiles are processed one at a time, in dataset order, on the device the
    model's parameters live on. ``args`` is accepted but not read.

    Returns
    -------
    numpy.ndarray
        Array with the shape and dtype of the global ``y_t`` holding, per pixel,
        the index of the highest-scoring class.
    """
    model.eval()  # switch in eval mode
    device = next(model.parameters()).device

    loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, drop_last=False)
    loader = tqdm.tqdm(loader, ncols=500)

    classified = np.empty_like(y_t.detach().numpy())

    with torch.no_grad():
        for index, (tiles, _) in enumerate(loader):
            pred = model(tiles.to(device)).cpu().numpy()
            # batch_size is 1: take the only sample explicitly, then argmax over classes
            classified[index, :, :] = pred[0].argmax(0)

    return classified

In [None]:
# Classify the whole test set with the single trained model and time it.
a = time()
Y_t = classify(trained_model, args)
b = time()
print('Inferrence finished in ' + str(b-a) + ' s')

# Flattened predictions, one entry per pixel, for the metric computation.
Y_t_flat = Y_t.flatten()

# Predicted class labels and their pixel counts.
unique, counts = np.unique(Y_t_flat, return_counts=True)
print(unique)
print(counts)

In [None]:
# write the results for all test tiles into one file, used later for visualisation
# The flat prediction vector is stored as a 2-row uint8 raster, so its length must be even.
# NOTE(review): the file is written into data_path + 'results/', the same folder the
# export section later iterates over patch by patch - confirm it is not picked up there.
imageio.imwrite(data_path + 'results/U_Net.tif', Y_t_flat.reshape(2,int(Y_t_flat.shape[0]/2)).astype(np.uint8), bigtiff=True)

In [None]:
# Flattened ground-truth labels of the test set, aligned element-wise with Y_t_flat.
y_t_flat = y_t.detach().numpy().flatten()

# Ground-truth class labels and their pixel counts, for comparison with the predictions.
unique, counts = np.unique(y_t_flat, return_counts=True)
print(unique)
print(counts)

In [None]:
# Compute accuracy metrics
# Per-class precision / recall / F1 (ground truth first, prediction second),
# overall pixel accuracy, and the unweighted mean of the per-class F1 scores.
precisions, recalls, f1_scores, supports = precision_recall_fscore_support(y_t_flat, Y_t_flat)
overall_accuracy = accuracy_score(y_t_flat, Y_t_flat)
mean_f1_score = sum(f1_scores)/len(f1_scores)

print('precisions [%]:      ', precisions*100)
print('recalls    [%]:      ', recalls*100)
print('f1_scores  [%]:      ', f1_scores*100)
print('')
print('overall accuracy: {:.2%}'.format(overall_accuracy))
print('mean f1 score:    {:.2%}'.format(mean_f1_score))

## Saving and reusing a trained model

In [None]:
# Path to the state_dictionary
# Used both for saving and for loading below; relative to the notebook's working directory.
state_dict_path = 'trained_models/U_Net_1989_1e-4.pt'

Save a model to state_dict_path:

In [None]:
# Save a trained model state_dictionary
# trained_models[0] is the run for the first entry of `learning_rates` (1e-4).
# Create the target folder first: torch.save does not create missing directories.
os.makedirs(os.path.dirname(state_dict_path) or '.', exist_ok=True)
torch.save(trained_models[0].state_dict(), state_dict_path)

Reuse a model at state_dict_path:

In [None]:
# Parameters for model definition
# Only the architecture settings are needed to rebuild the network for loading;
# they must match the values the saved model was trained with.
args = mock.Mock(
    n_class=len(class_names),
    n_channel=1,                      # 6 if use_cir and use_rgb else 3
    conv_width=[64, 64, 128, 128, 256, 256, 512, 512, 1024, 1024],
    dconv_width=[1024, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64],
    cuda=True,
)

In [None]:
# Load a trained model state_dictionary
# The architecture is rebuilt first; load_state_dict then copies the stored tensors into it.
model = UNet(args.n_channel, args.conv_width, args.dconv_width, args.n_class, args.cuda)
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(state_dict_path))
model.eval()

In [None]:
# Visual spot check of the reloaded model on test tile 8.
plot_rgb_cir_gt_pred(8, X_t, y_t, model, use_cir, use_rgb)

## Export results
Results are not georeferenced – use ArcPy_georeference_results.py for georeferencing and combining into a single raster

In [None]:
# Root folder holding the modality sub-folders to classify, and the folder the
# classified patches are written to.
source_path =  '../data/jeseniky/blackandwhite/'
results_path = source_path + 'results/'

In [None]:
# Load the patches to classify; no ground truth is read (gt=False), so a single tensor is returned.
in_features = read_patch(source_path, use_cir, use_rgb, use_pan, gt=False)

In [None]:
# Sanity check of the loaded feature tensor.
print(in_features.shape)

In [None]:
# Copy the input patches of one modality into the results folder; classify_and_export
# later overwrites each copied file with the predicted class raster of the same name.
# Priority is RGB, then CIR, then PAN. update=1 skips files that are already up to date.
# NOTE(review): distutils is deprecated and no longer shipped with Python 3.12+.
if use_rgb:
    copy_tree(source_path + 'RGB/', results_path, update=1)
elif use_cir:
    copy_tree(source_path + 'CIR/', results_path, update=1)
elif use_pan:
    copy_tree(source_path + 'PAN/', results_path, update=1)
else:
    print('no input files')

In [None]:
def classify_and_export(model_b, in_features_b, results_path_b):
    """Classify every patch and overwrite the same-named file in ``results_path_b``.

    Parameters
    ----------
    model_b : nn.Module
        Trained network; patches are moved to the device of its parameters.
    in_features_b : torch.Tensor, rank 4
        Patches to classify, one per file in ``results_path_b``.
    results_path_b : str
        Folder (with trailing separator) holding one file per patch. Each file
        is overwritten with the predicted class raster as uint8.

    Patches are paired with the files in sorted file-name order; this assumes
    ``in_features_b`` was stacked in that same order - TODO confirm against the loader.

    Raises
    ------
    ValueError
        If the number of files differs from the number of patches, because the
        pairing (and therefore every written file) would be unreliable.
    """
    patch_names = sorted(os.listdir(results_path_b))
    if len(patch_names) != len(in_features_b):
        raise ValueError(
            'Found {} files in {} but {} patches to classify.'.format(
                len(patch_names), results_path_b, len(in_features_b)))
    if not patch_names:
        return

    device = next(model_b.parameters()).device
    with torch.no_grad():  # inference only, avoids building autograd graphs
        for in_patch, patch in zip(in_features_b, patch_names):
            pred = model_b(in_patch[None, :, :, :].to(device)).cpu().numpy()
            pred = pred[0, :, :, :].argmax(0).squeeze()

            imageio.imwrite(results_path_b + patch, pred.astype(np.uint8))

In [None]:
# Classify the loaded patches with the reloaded model, write them to results_path, and time it.
a = time()
classify_and_export(model, in_features, results_path)
b = time()

print('Classification finished in ' + str(b-a) + 's')

### Export in bulk

In [None]:
# Bulk export: every entry of source_path is treated as one root folder of patches,
# and its results go to a same-named sub-folder of results_path.
# NOTE: this rebinds source_path / results_path from the single-folder export above.
source_path = '../data/2012/eastern/overlap/'
source_list = os.listdir(source_path)
results_path = '../data/2012/eastern/results/overlap/'

In [None]:
# NOTE(review): this lists the *western* area, while source_path is set to the eastern one -
# confirm whether this is an intentional check or a leftover.
print(os.listdir('../data/2012/western/overlap/'))

In [None]:
# Classify every source folder in turn and time the whole run.
a = time()
for source_name in source_list:
    # read_patch requires the pan flag as well; gt=False returns the features only
    in_features = read_patch(source_path + source_name + '/', use_cir, use_rgb, use_pan, gt=False)

    # Seed the results folder with copies of the inputs (same priority as the
    # single-folder export: RGB, then CIR, then PAN); they are overwritten below.
    if use_rgb:
        copy_tree(source_path + source_name + '/RGB/', results_path + source_name, update=1)
    elif use_cir:
        copy_tree(source_path + source_name + '/CIR/', results_path + source_name, update=1)
    elif use_pan:
        copy_tree(source_path + source_name + '/PAN/', results_path + source_name, update=1)
    else:
        print('no input files')

    classify_and_export(model, in_features, results_path + source_name + '/')
    del in_features  # free the patch tensor before loading the next folder
    print('Finished classifing of ' + results_path + source_name)
b = time()
print('This took ' + str(b-a) + ' s')