In [None]:
import os
import torch
import numpy as np
import imageio
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from time import time as time

from sklearn.metrics import precision_recall_fscore_support, accuracy_score, jaccard_score
import torchnet as tnt

from distutils.dir_util import copy_tree
from sklearn.model_selection import StratifiedKFold
from tqdm import notebook as tqdm

from model_definitions import KrakonosNet as Net
from imagery_handling import read_patch, classify_and_export

# GLOBAL SETTINGS
PlotSize = 12                                     # Base plot size; the default figure is twice as wide as it is high
matplotlib.rcParams['figure.figsize'] = [PlotSize*2, PlotSize]  
CMAP = matplotlib.colors.ListedColormap(['black', 'white', 'orange'])               # Three-colour map used by the plotting functions for label rasters
np.set_printoptions(precision=2, suppress=True)  # Array print precision

# PATHS TO TRAIN/TEST DATA
# NOTE(review): absolute, machine-specific path - adjust before running on another machine.
dataset_path = 'e:/datasets/test_unet/Krkonose2012/overlap'
# The number of tiles is taken from the number of entries in the ground-truth ('GT') sub-folder
num_of_tiles = len(os.listdir(os.path.join(dataset_path, 'GT')))
print(f'Number of tiles to be processed: \n{num_of_tiles}\n')

# USE CIR, RGB, PAN DATA
use_cir = False
use_rgb = False
use_pan = False

# USE multi/hyperspectral DATA (first value is a boolean similar to use_rgb etc. and second value is the number of bands)
use_mhs = (True, 6)
# CIR and RGB are counted as 3 bands each, PAN as 1 band and MHS as use_mhs[1] bands
print(f'Total number of bands used: \n{use_cir*3 + use_rgb*3 + use_pan + use_mhs[0]*use_mhs[1]}')

# MODEL NAME... USED AS FILENAME OF SAVED MODEL AND FOR APPROPRIATE RESULTS FOLDER
model_name = 'U_Net'
# Folder into which the cross-validation loop saves the state_dict of every fold
model_save_folder = os.path.join(dataset_path, 'models')

In [None]:
#User defined variables

In [None]:
def read_imagery(img_dir):
    """Read every image patch in a folder into one float32 array.

    Each file is read with ``imageio.imread``, converted to float32, moved
    from (axis0, axis1, axis2) to (axis2, axis0, axis1) order and divided
    by 255. The patches are then stacked along a new leading axis, in the
    order returned by ``os.listdir``.

    Args:
        img_dir: folder containing the patches; every file must decode to a
            3-D array and all patches must share one shape to be stackable.

    Returns:
        numpy.ndarray of dtype float32 with shape
        (n_files, axis2, axis0, axis1) of the input patches.
    """
    def load_patch(file_name):
        # Read one file, put the last (band) axis first and rescale by 1/255
        raw = imageio.imread(os.path.join(img_dir, file_name)).astype(np.float32)
        return raw[:, :, :].transpose([2, 0, 1]) / 255

    return np.stack([load_patch(file_name) for file_name in os.listdir(img_dir)], axis=0)

def read_patch(root_folder, cir, rgb, pan, mhs, gt=True):
    """Read the selected image modalities (and optionally the labels) as tensors.

    Every selected modality is read from its sub-folder of ``root_folder``
    ('CIR', 'RGB', 'PAN', 'MHS') and the modalities are concatenated along
    the channel axis (axis 1) in the order CIR, RGB, PAN, MHS. This matches
    the way the notebook computes the number of input channels
    (``cir*3 + rgb*3 + pan + mhs[0]*mhs[1]``).

    Args:
        root_folder: folder containing the modality sub-folders and 'GT'.
        cir: truthy to read the 'CIR' sub-folder.
        rgb: truthy to read the 'RGB' sub-folder.
        pan: truthy to read the single-band 'PAN' sub-folder.
        mhs: tuple whose first element enables the 'MHS' sub-folder; the
            second element (number of bands) is not used by this function.
        gt: when True the label rasters in 'GT' are read as well.

    Returns:
        ``features`` (torch tensor built from the float32 image arrays), or
        the tuple ``(features, ground_truth)`` when ``gt`` is True, where
        ``ground_truth`` is an int64 tensor stacked from the 'GT' files.

    Raises:
        ValueError: if no modality is selected.
    """
    # NOTE(review): files are paired across folders purely by os.listdir order,
    # which Python documents as arbitrary - confirm that all folders list their
    # tiles in the same order on the target file system.
    modalities = []

    if cir:
        modalities.append(read_imagery(os.path.join(root_folder, 'CIR')))
    if rgb:
        modalities.append(read_imagery(os.path.join(root_folder, 'RGB')))
    if pan:
        pan_dir = os.path.join(root_folder, 'PAN')
        pan_list = []
        for file in os.listdir(pan_dir):
            pan_patch = imageio.imread(os.path.join(pan_dir, file)).astype(np.float32)
            pan_patch = pan_patch * 1/255
            # PAN rasters are 2-D; add the channel axis the other modalities already have
            pan_list.append(np.expand_dims(pan_patch, axis=0))
        modalities.append(np.stack(pan_list, axis=0))
    # mhs is a (flag, n_bands) tuple, so the flag itself has to be tested:
    # a non-empty tuple is always truthy.
    if mhs[0]:
        modalities.append(read_imagery(os.path.join(root_folder, 'MHS')))

    if not modalities:
        # Previously this only printed a message and then failed on an unbound variable
        raise ValueError('No valid data input.')

    if len(modalities) == 1:
        features = modalities[0]
    else:
        features = np.concatenate(modalities, axis=1)
    features = torch.from_numpy(features)

    if not gt:
        return features

    gt_dir = os.path.join(root_folder, 'GT')
    gt_list = []
    for file in os.listdir(gt_dir):
        gt_patch = imageio.imread(os.path.join(gt_dir, file)).astype(np.int64)
        # assigns 0 to classes 3 and above
        # gt_patch[gt_patch > 2] = 0
        gt_list.append(gt_patch[:,:])

    ground_truth = torch.from_numpy(np.stack(gt_list, axis=0))
    return features, ground_truth

In [None]:
### Read the tiles and wrap them into a torchnet TensorDataset
data_features, data_labels = read_patch(dataset_path, use_cir, use_rgb, use_pan, use_mhs)

# Number of distinct label values found in each tile (used to stratify the cross-validation folds)
class_count_patch = [
    len(np.unique(data_labels[tile_idx, :, :]))
    for tile_idx in range(data_labels.shape[0])
]

print(f'Size of image data: \n{data_features.shape}\n')
print(f'Size of reference data: \n{data_labels.shape}\n')

# Features and labels are indexed together, tile by tile
dataset = tnt.dataset.TensorDataset([data_features, data_labels])
print(len(dataset))

In [None]:
# Label values present in the reference data and the number of pixels carrying each of them.
# `unique` is read again by the hyper-parameter cell, which uses len(unique) as the number of classes.
unique, counts = np.unique(data_labels, return_counts=True)
print(f'Class labels: \n{unique}\n')
print(f'Number of pixels in a class: \n{counts}')

In [None]:
def augment(obs, g_t):
    """Augment one batch with per-pixel Gaussian noise and a random 90-degree rotation.

    Args:
        obs: image batch (numpy array or CPU tensor) with at least four axes;
            it is rotated in the plane of its axes 2 and 3.
        g_t: label batch (numpy array or CPU tensor) with at least three axes;
            it is rotated in the plane of its axes 1 and 2 by the same number
            of turns, so images and labels stay aligned.

    Returns:
        Tuple ``(obs, g_t)`` of torch tensors holding the augmented batch.
    """
    sigma, clip = 0.01, 0.03

    obs = np.asarray(obs)
    g_t = np.asarray(g_t)

    # One independent noise sample per element. Calling randn() without a shape
    # yields a single scalar, which would only shift the brightness of the whole
    # batch instead of adding noise.
    noise = np.clip(sigma * np.random.randn(*obs.shape), -clip, clip).astype(np.float32)
    obs = obs + noise

    # random rotation by 0, 90, 180 or 270 degrees
    n_turn = np.random.randint(4)  # number of 90 degree turns, random int between 0 and 3
    # rot90 returns a view; copy so that torch.from_numpy gets a regular array
    obs = np.rot90(obs, n_turn, axes=(2, 3)).copy()
    g_t = np.rot90(g_t, n_turn, axes=(1, 2)).copy()

    return torch.from_numpy(obs), torch.from_numpy(g_t)

In [None]:
def train(model, optimizer, args):
    """Train ``model`` for one epoch and return the mean training loss.

    Batches are drawn from the module-level ``dataset`` with the sampler
    stored in ``args['train_subsampler']``; every batch is augmented with
    ``augment`` before the forward pass.

    Args:
        model: network to train; its input is moved to the GPU with ``.cuda()``.
        optimizer: optimizer that updates the parameters of ``model``.
        args: dict providing 'batch_size', 'train_subsampler' and
            'class_weights' (per-class weights of the cross-entropy loss).

    Returns:
        Mean of the per-batch loss values over the epoch.
    """
    model.train()  # switch the model to training mode

    # the loader takes care of the batching
    loader = torch.utils.data.DataLoader(dataset, batch_size=args['batch_size'], sampler=args['train_subsampler'])
    loader = tqdm.tqdm(loader, ncols=500)

    # keeps track of the loss
    loss_meter = tnt.meter.AverageValueMeter()

    # the weights do not change during the epoch, so build the tensor once
    class_weights = torch.tensor(args['class_weights'])

    for tiles, gt in loader:
        optimizer.zero_grad()  # reset the gradients
        tiles, gt = augment(tiles, gt)

        pred = model(tiles.cuda())  # compute the prediction

        # the loss is evaluated on the CPU, where the labels and the weights live
        loss = nn.functional.cross_entropy(pred.cpu(), gt, weight=class_weights)
        loss.backward()  # compute gradients

        # Clamp every gradient element to [-1, 1] (value clipping, not norm clipping).
        # clip_grad_value_ skips parameters without a gradient, which the previous
        # manual loop over p.grad could not handle.
        nn.utils.clip_grad_value_(model.parameters(), 1)

        optimizer.step()  # one optimizer step
        loss_meter.add(loss.item())

    return loss_meter.value()[0]

def eval(model, sampler):
    """Return the mean (unweighted) cross-entropy loss over the sampled tiles.

    Tiles are drawn one at a time from the module-level ``dataset`` using
    ``sampler``. The model is switched to evaluation mode and no gradients
    are recorded.

    Note: inside this notebook the name shadows Python's built-in ``eval``.

    Args:
        model: network to evaluate; its input is moved to the GPU with ``.cuda()``.
        sampler: sampler selecting the tiles of ``dataset`` to evaluate on.

    Returns:
        Mean of the per-tile loss values.
    """
    model.eval()  # evaluation mode

    meter = tnt.meter.AverageValueMeter()
    batches = tqdm.tqdm(
        torch.utils.data.DataLoader(dataset, batch_size=1, sampler=sampler),
        ncols=500,
    )

    with torch.no_grad():
        for tiles, gt in batches:
            scores = model(tiles.cuda()).cpu()
            meter.add(nn.functional.cross_entropy(scores, gt).item())

    mean_loss = meter.value()[0]
    return mean_loss


def train_full(args):
    """Build a network, train it for ``args['n_epoch']`` epochs and plot the losses.

    The model is validated after every ``args['n_epoch_test']``-th epoch and
    always after the last epoch. The last validation loss is written back to
    ``args['loss_test']``.

    Args:
        args: dict with the model definition ('n_channel', 'conv_width',
            'dconv_width', 'n_class', 'cuda'), the optimisation settings
            ('lr', 'scheduler_milestones', 'scheduler_gamma', 'n_epoch',
            'n_epoch_test'), the validation sampler ('test_subsampler') and
            the keys read by ``train``.

    Returns:
        The trained model.
    """
    # initialize the model
    model = Net(args['n_channel'], args['conv_width'], args['dconv_width'], args['n_class'], args['cuda'])

    print(f'Total number of parameters: {sum([p.numel() for p in model.parameters()])}')

    # Adam optimizer with a step-wise learning-rate decay
    optimizer = optim.Adam(model.parameters(), lr=args['lr'])
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['scheduler_milestones'],
                                               gamma=args['scheduler_gamma'])

    n_epoch = args['n_epoch']
    n_epoch_test = args['n_epoch_test']

    train_loss = np.empty(n_epoch)
    # Record the validation epochs and losses as they occur. A pre-sized array of
    # n_epoch // n_epoch_test + 1 entries has one slot too many whenever n_epoch is
    # a multiple of n_epoch_test, which left an uninitialised value in the result
    # and made the validation curve impossible to plot.
    test_epochs = []
    test_loss = []

    for i_epoch in range(n_epoch):
        # train one epoch
        print(f'Epoch #{str(i_epoch+1)}')
        train_loss[i_epoch] = train(model, optimizer, args)
        scheduler.step()

        # Periodic testing on the validation set
        if (i_epoch == n_epoch - 1) or ((i_epoch + 1) % n_epoch_test == 0):
            print('Evaluation')
            test_epochs.append(i_epoch + 1)
            test_loss.append(eval(model, args['test_subsampler']))

    test_loss = np.array(test_loss)

    plt.figure(figsize=(10, 10))
    plt.subplot(1,1,1,ylim=(0,2), xlabel='Epoch #', ylabel='Loss')
    # both curves use 1-based epoch numbers so that they line up on the x axis
    plt.plot(range(1, n_epoch + 1), train_loss, label='Training loss')
    plt.plot(test_epochs, test_loss, label='Validation loss')
    plt.legend()
    plt.show()
    print(train_loss)
    print(test_loss)
    # NaN signals that no validation took place (only possible when n_epoch is 0)
    args['loss_test'] = test_loss[-1] if len(test_loss) else float('nan')

    return model

In [None]:
# All model and training hyper-parameters are collected in a single dict
args = {
    # --- network definition ---
    'n_channel': use_cir*3 + use_rgb*3 + use_pan + use_mhs[0]*use_mhs[1],
    'cuda': True,
    'n_class': len(unique),
    # layer widths: the listed widths divided by 4
    'conv_width': np.divide([64,64,128,128,256,256,512,512,1024,1024],        4).astype(np.int32),
    'dconv_width': np.divide([1024,512,512,512,256,256,256,128,128,128,64,64], 4).astype(np.int32),

    # --- cross-validation and learning-rate schedule ---
    'crossval_nfolds': 3,
    'n_epoch_test': 2,          # validate after every n_epoch_test-th epoch
    'scheduler_milestones': [60,80,95],
    'scheduler_gamma': 0.3,
    'class_weights': [0.1, 0.1, 0.8],

    # --- optimisation ---
    'n_epoch': 3,
    'lr': 5e-4,
    'batch_size': 1,
}

# Short summary of the most important settings
print('\n'.join(
    f'{caption}\n    {args[key]}'
    for caption, key in (
        ('Number of models to be trained:', 'crossval_nfolds'),
        ('Number of spectral channels:', 'n_channel'),
        ('Initial learning rate:', 'lr'),
        ('Batch size:', 'batch_size'),
        ('Number of training epochs:', 'n_epoch'),
    )
))

In [None]:
# torch.save does not create missing folders, so make sure the target exists
os.makedirs(model_save_folder, exist_ok=True)

# Stratified k-fold split of the tiles; the strata are the per-tile class counts
kfold = StratifiedKFold(n_splits = args['crossval_nfolds'], shuffle=True)
trained_models = []  # one (model, last validation loss) tuple per fold
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset, class_count_patch)):
    print(f'Training starts for model number {str(fold+1)}')

    a = time()
    args['train_subsampler'] = torch.utils.data.SubsetRandomSampler(train_ids)
    args['test_subsampler'] = torch.utils.data.SubsetRandomSampler(test_ids)

    # train_full stores the last validation loss in args['loss_test']
    trained_models.append((train_full(args), args['loss_test']))

    state_dict_path = os.path.join(model_save_folder, f'fold_{str(fold)}.pt')
    torch.save(trained_models[fold][0].state_dict(), state_dict_path)
    print(f'Model saved to: {state_dict_path}')
    print(f'Training finished in {str(time()-a)}s')
    print('\n\n')

fold_losses = [loss for _, loss in trained_models]
print(f'Resulting loss for individual folds: \n{fold_losses}')
print(f'Mean loss across all folds: \n{np.mean(fold_losses)}')

## Result visualisation

In [None]:
def plot_rgb_cir_gt_pred(tile_index, data, gt, model, cir, rgb):
    """Plot the CIR and/or RGB composite of one tile next to its labels and prediction.

    The tile is classified with ``model`` on the GPU and the number of pixels
    predicted per class is printed. When neither ``cir`` nor ``rgb`` is set,
    only the (empty) figure is created.

    Args:
        tile_index: index of the tile along the first axis of ``data`` and ``gt``.
        data: 4-D tensor of tiles; the first three bands are shown as the CIR
            composite and the last three as the RGB composite when both flags
            are set, otherwise all bands are passed to ``imshow``.
        gt: label rasters, indexed as ``gt[tile_index, :, :]``.
        model: network used for the prediction.
        cir: whether ``data`` contains CIR bands.
        rgb: whether ``data`` contains RGB bands.
    """
    plt.figure(facecolor='white')

    tile = data[tile_index,:,:,:]
    scores = model(tile[None,:,:,:].cuda()).cpu().detach().numpy()
    pred = scores[0,:,:,:].argmax(0).squeeze()

    labels, pixel_counts = np.unique(pred, return_counts=True)
    print(dict(zip(labels, pixel_counts)))

    tile = tile.cpu().numpy()

    # Collect the composite panel(s) as (image, colour map, title)
    if cir and rgb:
        panels = [
            (tile[:3].transpose([1,2,0]), None, 'NIR Red Green composite'),
            (tile[-3:].transpose([1,2,0]), None, 'Red Green Blue composite'),
        ]
    elif cir or rgb:
        composite_title = 'NIR Red Green composite' if cir else 'Red Green Blue composite'
        panels = [(tile.transpose([1,2,0]), None, composite_title)]
    else:
        return

    # The label panels always come last
    panels.append((gt[tile_index,:,:], CMAP, 'GT Labels'))
    panels.append((pred, CMAP, 'Predicted Labels'))

    for position, (image, colour_map, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(image, colour_map)
        plt.title(title)
        plt.axis('off')

In [None]:
def plot_data_ref_pred(tile_index, data, ref, model, vis_bands, vis_title='data'):
    """Plot selected bands of one tile next to its reference labels and prediction.

    The tile is classified with ``model`` on the GPU and the number of pixels
    predicted per class is printed.

    Args:
        tile_index: index of the tile along the first axis of ``data`` and ``ref``.
        data: 4-D tensor of tiles.
        ref: reference label rasters, indexed as ``ref[tile_index, :, :]``.
        model: network used for the prediction.
        vis_bands: indices of the bands shown in the first panel.
        vis_title: title of the first panel.
    """
    plt.figure(facecolor='white')

    tile = data[tile_index,:,:,:]
    scores = model(tile[None,:,:,:].cuda()).cpu().detach().numpy()
    pred = scores[0,:,:,:].argmax(0).squeeze()

    labels, pixel_counts = np.unique(pred, return_counts=True)
    print(dict(zip(labels, pixel_counts)))

    tile = tile.cpu().numpy()

    # (image, colour map, title) for the three panels, left to right
    panels = (
        (tile[vis_bands].transpose([1,2,0]), None, vis_title),
        (ref[tile_index,:,:], CMAP, 'Reference labels'),
        (pred, CMAP, 'Predicted Labels'),
    )
    for position, (image, colour_map, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, colour_map)
        plt.title(title)
        plt.axis('off')

In [None]:
vis_tile = 7    # index of the tile to visualise
model_fold = 0  # cross-validation fold whose model is used when no model is loaded yet

# Keep an already defined `model` (e.g. one reloaded from disk); otherwise fall
# back to the model trained for the selected fold.
try:
    model
except NameError:
    model = trained_models[model_fold][0]

# use_mhs is a (flag, n_bands) tuple: test the flag itself, because a non-empty
# tuple is always truthy and the CIR/RGB branch could never be reached.
if use_mhs[0]:
    plot_data_ref_pred(vis_tile, data_features, data_labels, model, [3,4,5], 'RGB composite')
else:
    plot_rgb_cir_gt_pred(vis_tile, data_features, data_labels, model, use_cir, use_rgb)

## Computing accuracy metrics

In [None]:
def classify(model, data):
    """Classify every tile of ``data`` and return the predicted label rasters.

    Args:
        model: network used for the prediction; it is switched to evaluation
            mode and its input is moved to the GPU with ``.cuda()``.
        data: dataset yielding ``(tiles, labels)`` pairs; it is traversed in
            order, one tile per batch. The labels are not used.

    Returns:
        int64 numpy array with one predicted label raster per tile, stacked
        along the first axis (an empty array when ``data`` has no tiles).
    """
    model.eval()  # switch to eval mode
    loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False, drop_last=False)
    loader = tqdm.tqdm(loader, ncols=500)

    # Collect the predictions instead of pre-allocating from the global `y_t`,
    # which is only defined in a later cell and need not match `data`.
    predictions = []

    with torch.no_grad():
        for tiles, _ in loader:
            scores = model(tiles.cuda()).cpu().detach().numpy()
            # drop the batch axis explicitly, then pick the best class per pixel
            predictions.append(scores[0].argmax(0))

    if not predictions:
        return np.empty((0,), dtype=np.int64)
    return np.stack(predictions, axis=0).astype(np.int64)

In [None]:
# Reference labels of all tiles
y_t = data_labels

# Classify the whole dataset and report how long it took
a = time()
Y_t = classify(model, dataset)
b = time()
print(f'Inferrence finished in {b-a} s')

Y_t_flat = Y_t.flatten()

# Predicted label values and the number of pixels assigned to each of them
unique, counts = np.unique(Y_t_flat, return_counts=True)
print(unique, counts, sep='\n')

In [None]:
background_class = 255 # class number representing the background class (used for validation)

# The reference labels have to be flattened BEFORE they are masked: the masking
# used to run in a cell that came first and read y_t_flat before it existed.
y_t_flat = y_t.detach().numpy().flatten()

unique, counts = np.unique(y_t_flat, return_counts=True)
print(unique)
print(counts)

# Remove the background pixels from both the prediction and the reference.
# One mask computed from the unmasked reference is applied to freshly flattened
# arrays, so the cell can be re-run without the two arrays getting out of step.
valid_pixels = y_t_flat != background_class
Y_t_flat = Y_t.flatten()[valid_pixels]
y_t_flat = y_t_flat[valid_pixels]

In [None]:
# Per-class metrics computed from the flattened reference (first) and prediction (second)
precisions, recalls, f1_scores, _ = precision_recall_fscore_support(y_t_flat, Y_t_flat, zero_division=0)
jaccard_index = jaccard_score(y_t_flat, Y_t_flat, average=None)

# Overall metrics; the means weight every class equally
overall_accuracy = accuracy_score(y_t_flat, Y_t_flat)
mean_f1_score = sum(f1_scores)/len(f1_scores)
mean_iou_score = sum(jaccard_index)/len(jaccard_index)

for caption, per_class in (
    ('precisions [%]:      ', precisions),
    ('recalls    [%]:      ', recalls),
    ('f1-scores  [%]:      ', f1_scores),
    ('IoU scores [%]:      ', jaccard_index),
):
    print(f'{caption}{per_class*100}')
print('')
for caption, score in (
    ('overall accuracy:    ', overall_accuracy),
    ('mean f1 score:       ', mean_f1_score),
    ('mean IoU score:      ', mean_iou_score),
):
    print(f'{caption}{score:.2%}')

## Saving and reusing a trained model

In [None]:
# Path to the state_dictionary.
# Built from model_save_folder (where the cross-validation loop stores
# 'fold_<n>.pt') instead of repeating the absolute, machine-specific path.
state_dict_path = os.path.join(model_save_folder, 'fold_2.pt')

Save a model to state_dict_path:

In [None]:
# Save a trained model state_dictionary
# NOTE(review): `trained_model` is not assigned anywhere in the visible notebook - the
# cross-validation loop fills the list `trained_models` and already saves every fold.
# Confirm which model is meant before running this cell; as written it raises NameError.
torch.save(trained_model.state_dict(), state_dict_path)

Reuse a model at state_dict_path:

In [None]:
# Parameters for model definition (this replaces the `args` dict used for training)
args = {
    'n_class': 3,
    'n_channel': 6, # 6 if use_cir and use_rgb else 3
    'conv_width': np.divide([64,64,128,128,256,256,512,512,1024,1024],        4).astype(np.int32),
    'dconv_width': np.divide([1024,512,512,512,256,256,256,128,128,128,64,64], 4).astype(np.int32),
    'cuda': True,
}

# Rebuild the network, restore the weights stored at state_dict_path and switch to eval mode
model = Net(args['n_channel'], args['conv_width'], args['dconv_width'], args['n_class'], args['cuda'])
model.load_state_dict(torch.load(state_dict_path))
model.eval()

## Export results
The exported results are not georeferenced. Use ArcPy_georeference_results.py to georeference them and to combine them into a single raster.

In [None]:
# Folder holding the tiles to classify; the predictions go to its 'results' sub-folder.
# NOTE(review): absolute, machine-specific path - it appears to point at the same folder
# as dataset_path; confirm and adjust before running on another machine.
source_path = 'E:\\datasets\\test_unet\\Krkonose2012\\overlap'
results_path = os.path.join(source_path, 'results')

In [None]:
# Read only the image data of the tiles to export (gt=False: no ground truth is read)
in_features = read_patch(source_path, use_cir, use_rgb, use_pan, use_mhs, gt=False)
print(in_features.shape)

In [None]:
# The exported rasters reuse the file names of the first enabled modality,
# checked in the order RGB, CIR, PAN, MHS.
for enabled, folder in (
    (use_rgb, 'RGB'),
    (use_cir, 'CIR'),
    (use_pan, 'PAN'),
    (use_mhs[0], 'MHS'),
):
    if enabled:
        filenames = os.listdir(os.path.join(source_path, folder))
        break
else:
    print('no input files')

In [None]:
def classify_and_export(model_b, in_features_b, results_path_b, fnames):
    """Classify the tiles one by one and write each prediction as a raster file.

    Args:
        model_b: network used for the prediction; its input is moved to the
            GPU with ``.cuda()``.
        in_features_b: 4-D tensor of tiles; tile ``i`` is written under the
            ``i``-th name of ``fnames``.
        results_path_b: output folder; it is created when it does not exist.
        fnames: output file names, in the same order as the tiles.
    """
    # imageio does not create missing folders, so make sure the target exists
    os.makedirs(results_path_b, exist_ok=True)

    # inference only: no gradients need to be recorded
    with torch.no_grad():
        for i, patch in enumerate(fnames):
            in_patch = in_features_b[i,:,:,:]
            pred = model_b(in_patch[None,:,:,:].cuda()).cpu().detach().numpy()
            # most likely class per pixel
            pred = pred[0,:,:,:].argmax(0).squeeze()

            imageio.imwrite(os.path.join(results_path_b, patch), pred.astype(np.uint8))

In [None]:
# Classify every tile of in_features and write one raster per entry of filenames into results_path
classify_and_export(model, in_features, results_path, filenames)