In [None]:
import os
import torch
import numpy as np
import imageio
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from osgeo import gdal
from time import perf_counter

from sklearn.metrics import precision_recall_fscore_support, accuracy_score, jaccard_score
import torchnet as tnt

from distutils.dir_util import copy_tree
from sklearn.model_selection import KFold, StratifiedKFold
from tqdm import notebook as tqdm

from model_definitions import SpectroSpatialNet
import image_preprocessing
import inference_utils

# GLOBAL SETTINGS
PlotSize = 12  # base plot size; figures default to twice as wide as they are tall
matplotlib.rcParams['figure.figsize'] = [2 * PlotSize, PlotSize]
CMAP = matplotlib.colors.ListedColormap(['black', 'white', 'orange'])  # colour mapping
np.set_printoptions(suppress=True, precision=2)  # compact array printing

# PATHS TO TRAINING DATA (all input rasters live in the same data folder)
DATA_DIR = r'C:\Users\dd\Documents\NATUR_CUNI\additional_projects\prace\ETRAINEE\data'
trainingdata_path = DATA_DIR + r'\LH_202008_54pasem_9cm.tif'
referencedata_path = DATA_DIR + r'\LH_trenovaci_2019a2020_rasterized.tif'

In [None]:
# Read the hyperspectral imagery and the rasterised reference (training labels).
loaded_raster = image_preprocessing.read_gdal(trainingdata_path, referencedata_path)
# These arrays are not tiled yet (tiling happens in the next cell), so label them as loaded.
print(f'Loaded imagery shape {loaded_raster["imagery"].shape}')
print(f'Loaded reference shape {loaded_raster["reference"].shape}')

In [None]:
# Tiling parameters: tile size, overlap between neighbouring tiles and start offset.
tile_shape = (256, 256)
overlap = 128
offset = (0, 0)

# Cut the imagery and the reference into matching tiles.
dataset_tiles = image_preprocessing.tile_training(
    loaded_raster, tile_shape, overlap, offset
)
print(f'Tiled imagery shape {dataset_tiles["imagery"].shape}',
      f'Tiled reference shape {dataset_tiles["reference"].shape}',
      sep='\n')

In [None]:
# Keep only the tiles that the helper deems useful; 65535 is passed as the nodata value.
filtered_tiles = image_preprocessing.filter_useful_tiles(
    dataset_tiles, nodata_vals=[65535], is_training=True
)
print(f'Filtered imagery shape {filtered_tiles["imagery"].shape}',
      f'Filtered reference shape {filtered_tiles["reference"].shape}',
      sep='\n')

In [None]:
# Normalise the tiles; in training mode the helper also returns the class labels
# present in the reference (`unique`) and the number of pixels per class (`counts`).
preprocessed_tiles, unique, counts = image_preprocessing.normalize_tiles_3d(
    filtered_tiles, nodata_vals=[65535], is_training=True
)
print(f'Preprocessed imagery shape {preprocessed_tiles["imagery"].shape}',
      f'Preprocessed reference shape {preprocessed_tiles["reference"].shape}',
      sep='\n')

# Pair every imagery tile with its reference tile in one indexable dataset.
dataset = tnt.dataset.TensorDataset(
    [preprocessed_tiles['imagery'], preprocessed_tiles['reference']]
)
print(dataset)

print(f'Class labels: \n{unique}\n',
      f'Number of pixels in a class: \n{counts}',
      sep='\n')

In [None]:
def augment(obs, g_t):
    """Randomly augment a batch with Gaussian noise and a random 90-degree rotation.

    Parameters
    ----------
    obs : numpy.ndarray or CPU torch.Tensor
        Input batch. Its two trailing axes are treated as the spatial axes
        (assumed ``(..., H, W)`` -- TODO confirm for the 3D tile layout).
    g_t : numpy.ndarray or CPU torch.Tensor
        Matching reference labels; its two trailing axes are rotated by the
        same number of quarter turns as ``obs``.

    Returns
    -------
    tuple of torch.Tensor
        The augmented ``obs`` and ``g_t``.
    """
    sigma, clip = 0.01, 0.03
    # Work in numpy explicitly; mixing torch tensors and numpy arrays in
    # arithmetic is fragile.
    obs = np.asarray(obs)
    g_t = np.asarray(g_t)

    # Independent clipped Gaussian noise for every element. A single
    # ``np.random.randn()`` scalar would only shift the whole batch by a constant.
    noise = np.clip(sigma * np.random.randn(*obs.shape), -clip, clip).astype(np.float32)
    obs = obs + noise

    # Random rotation by 0, 90, 180 or 270 degrees, applied identically to
    # imagery and labels. The copies make the arrays contiguous for torch.from_numpy.
    n_turn = np.random.randint(4)
    obs = np.rot90(obs, n_turn, axes=(-2, -1)).copy()
    g_t = np.rot90(g_t, n_turn, axes=(-2, -1)).copy()

    return torch.from_numpy(obs), torch.from_numpy(g_t)

In [None]:
def train(model, optimizer, args):
    """Train ``model`` for one epoch on the training subset of the global ``dataset``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise; it is switched to training mode.
    optimizer : torch.optim.Optimizer
        Optimiser over ``model.parameters()`` (Adam in ``train_full``).
    args : dict
        Must contain ``'batch_size'``, ``'train_subsampler'`` (sampler of indices
        into ``dataset``) and ``'class_weights'`` (per-class loss weights).

    Returns
    -------
    float
        Mean of the per-batch training losses over the epoch.
    """
    model.train()  # switch the model to training mode
    # Send batches to wherever the model's parameters live. A hard-coded
    # ``.cuda()`` would crash when the model runs on the CPU.
    device = next(model.parameters()).device

    # The DataLoader takes care of batching; the sampler restricts it to the training fold.
    loader = torch.utils.data.DataLoader(dataset, batch_size=args['batch_size'],
                                         sampler=args['train_subsampler'])
    loader = tqdm.tqdm(loader, ncols=500)

    loss_meter = tnt.meter.AverageValueMeter()  # running mean of batch losses

    for tiles, gt in loader:
        optimizer.zero_grad()  # reset gradients from the previous step

        # tiles, gt = augment(tiles, gt)  # optional data augmentation (disabled)

        pred = model(tiles.to(device))

        # The loss is computed on the CPU. That is where `gt` (from the DataLoader)
        # and `class_weights` (a CPU tensor) live; gradients flow back through `.cpu()`.
        loss = nn.functional.cross_entropy(pred.cpu(), gt, weight=args['class_weights'])
        loss.backward()

        # Clamp every gradient element to [-1, 1]. This is value clipping, not
        # norm clipping. clip_grad_value_ also skips parameters whose .grad is None.
        nn.utils.clip_grad_value_(model.parameters(), 1)

        optimizer.step()  # one optimiser step
        loss_meter.add(loss.item())

    return loss_meter.value()[0]

def eval(model, sampler, class_weights=None):
    """Return the mean cross-entropy loss of ``model`` on a subset of the global ``dataset``.

    Note: the name shadows Python's builtin ``eval``; it is kept so existing
    callers keep working.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate; it is switched to evaluation mode.
    sampler : torch.utils.data.Sampler
        Sampler of indices into ``dataset`` (e.g. the validation fold).
    class_weights : torch.Tensor, optional
        Per-class loss weights. With the default ``None``, all classes count
        equally. Pass the training weights to get a loss comparable to the
        (weighted) training loss.

    Returns
    -------
    float
        Mean of the per-tile losses (batch size 1).
    """
    model.eval()  # switch to evaluation mode
    # Send tiles to the model's own device instead of hard-coding `.cuda()`.
    device = next(model.parameters()).device

    loader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=sampler)
    loader = tqdm.tqdm(loader, ncols=500)

    loss_meter = tnt.meter.AverageValueMeter()

    with torch.no_grad():  # no gradients needed for evaluation
        for tiles, gt in loader:
            pred = model(tiles.to(device))
            loss = nn.functional.cross_entropy(pred.cpu(), gt, weight=class_weights)
            loss_meter.add(loss.item())

    return loss_meter.value()[0]


def train_full(args):
    """Train a fresh SpectroSpatialNet for ``args['n_epoch']`` epochs and plot the losses.

    The model is evaluated on ``args['test_subsampler']`` every
    ``args['n_epoch_test']`` epochs and always after the final epoch.
    A value of 0 for ``n_epoch_test`` means "final epoch only".

    Parameters
    ----------
    args : dict
        Model and training parameters. Reads ``'lr'``, ``'scheduler_milestones'``,
        ``'scheduler_gamma'``, ``'n_epoch'`` and ``'n_epoch_test'``, plus the keys
        used by ``train``/``eval`` and ``SpectroSpatialNet``.

    Returns
    -------
    torch.nn.Module
        The trained model.

    Side effects
    ------------
    Stores the last validation loss in ``args['loss_test']``. It is NaN if no
    evaluation ran (``n_epoch == 0``).
    """
    model = SpectroSpatialNet(args)
    print(f'Total number of parameters: {sum(p.numel() for p in model.parameters())}')

    # Adam optimiser; MultiStepLR multiplies the LR by gamma at each milestone epoch.
    optimizer = optim.Adam(model.parameters(), lr=args['lr'])
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['scheduler_milestones'],
                                               gamma=args['scheduler_gamma'])

    n_epoch = args['n_epoch']
    n_epoch_test = args['n_epoch_test']
    train_loss = np.empty(n_epoch)
    test_epochs = []
    test_loss = []

    for i_epoch in range(n_epoch):
        epoch = i_epoch + 1
        print(f'Epoch #{epoch}')
        train_loss[i_epoch] = train(model, optimizer, args)
        scheduler.step()

        # Periodic validation. Guard against n_epoch_test == 0, which would
        # otherwise raise ZeroDivisionError in the modulo.
        is_last_epoch = epoch == n_epoch
        is_test_epoch = bool(n_epoch_test) and epoch % n_epoch_test == 0
        if is_last_epoch or is_test_epoch:
            print('Evaluation')
            test_epochs.append(epoch)
            test_loss.append(eval(model, args['test_subsampler']))

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.plot(range(1, n_epoch + 1), train_loss, label='Training loss')
    ax.plot(test_epochs, test_loss, label='Validation loss')
    ax.set(ylim=(0, 5), xlabel='Epoch #', ylabel='Loss', title='Training and validation loss')
    ax.legend()
    plt.show()
    print(train_loss)
    print(test_loss)

    # Avoid IndexError when no evaluation ran.
    args['loss_test'] = test_loss[-1] if test_loss else float('nan')

    return model

## Training a 3D network

In [None]:
args = {  # Dict to store all model and training parameters
    'n_channel': 1,
    'cuda': True,
    'n_class': len(unique),  # number of class labels found in the reference data
    # Layer widths passed to SpectroSpatialNet, scaled down by a factor of 32
    # (presumably encoder / decoder widths -- see model_definitions).
    'size_e': np.divide([64,64,128,128,256,256,512,512,1024,1024],        32).astype(np.int32),
    'size_d': np.divide(np.multiply([1024,512,512,512,256,256,256,128,128,128,64,64], 3), 32).astype(np.int32),

    'crossval_nfolds': 3,       # number of cross-validation folds (= models trained)
    'n_epoch_test': 2,          # periodicity (in epochs) of evaluation on the validation fold
    # Epochs at which MultiStepLR multiplies the LR by scheduler_gamma. With
    # n_epoch = 5 below, these milestones are never reached.
    'scheduler_milestones': [60,80,90],
    'scheduler_gamma': 0.3,
    # One cross-entropy weight per class label; class 0 has weight 0.0 and
    # therefore does not contribute to the training loss.
    'class_weights': torch.tensor([0.0, 0.1966495481867119, 0.3430579108649065, 0.03228185060941156, 0.16631530937067654, 0.14084359532291527, 0.02901643207368558, 0.012932577826270997, 0.021546657524297137, 0.057356118221124304]),

    'n_epoch': 5,
    'lr': 5e-6,
    'batch_size': 4,
}
# cross_entropy needs exactly one weight per class. Fail here with a clear
# message rather than deep inside the training loop.
assert len(args['class_weights']) == args['n_class'], (
    f"class_weights has {len(args['class_weights'])} entries but the data has "
    f"{args['n_class']} classes"
)
model_save_folder = r'C:\Users\dd\Documents\NATUR_CUNI\additional_projects\prace\ETRAINEE\data\models'
print(f'''Number of models to be trained:
    {args['crossval_nfolds']}
Number of spectral channels:
    {args['n_channel']}
Initial learning rate:
    {args['lr']}
Batch size:
    {args['batch_size']}
Number of training epochs:
    {args['n_epoch']}''')

In [None]:
# Training a 3D network with k-fold cross-validation
# A fixed random_state makes the shuffled fold split reproducible between runs.
kfold = KFold(n_splits=args['crossval_nfolds'], shuffle=True, random_state=0)
os.makedirs(model_save_folder, exist_ok=True)  # torch.save does not create missing folders
trained_models = []  # one (model, final validation loss) tuple per fold
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    print(f'Training starts for model number {fold + 1}')

    start_time = perf_counter()
    args['train_subsampler'] = torch.utils.data.SubsetRandomSampler(train_ids)
    args['test_subsampler'] = torch.utils.data.SubsetRandomSampler(test_ids)

    # train_full stores the fold's final validation loss in args['loss_test'].
    fold_model = train_full(args)
    trained_models.append((fold_model, args['loss_test']))

    state_dict_path = os.path.join(model_save_folder, f'fold_{fold}.pt')
    torch.save(fold_model.state_dict(), state_dict_path)
    print(f'Model saved to: {state_dict_path}')
    print(f'Training finished in {perf_counter() - start_time}s')
    print('\n\n')

fold_losses = [loss for _, loss in trained_models]
print(f'Resulting loss for individual folds: \n{fold_losses}')
print(f'Mean loss across all folds: \n{np.mean(fold_losses)}')

## Saving and reusing a trained model

In [None]:
# Path to a model state_dict (saved weights) file. The save and load cells
# below both use this path.
state_dict_path = r'C:\Users\dd\Documents\NATUR_CUNI\additional_projects\prace\ETRAINEE\data\models\fold_2.pt'

Save a trained model's weights (its `state_dict`) to `state_dict_path`:

In [None]:
# Save a trained model's state_dict. The cross-validation loop stores one
# (model, validation loss) tuple per fold in `trained_models`, so pick the fold
# with the lowest final validation loss; a bare `trained_model` is never defined.
# NOTE: this writes to `state_dict_path` as set above; choose a filename that
# reflects the selected fold.
best_fold = int(np.argmin([loss for _, loss in trained_models]))
trained_model = trained_models[best_fold][0]
torch.save(trained_model.state_dict(), state_dict_path)
print(f'Saved model of fold {best_fold} to: {state_dict_path}')

Load and reuse the model weights stored at `state_dict_path`:

In [None]:
# Parameters for the model definition. They must match the parameters used to
# train the saved weights, otherwise load_state_dict fails on mismatched shapes.
args = {
    'n_class': 10,
    'n_channel': 1,
    'size_e': np.divide([64,64,128,128,256,256,512,512,1024,1024],        32).astype(np.int32),
    'size_d': np.divide(np.multiply([1024,512,512,512,256,256,256,128,128,128,64,64], 3), 32).astype(np.int32),
    'cuda': True,
}

# Load a trained model state_dict. map_location='cpu' lets weights saved from a
# GPU be read on any machine; load_state_dict then copies them into the model's
# own parameters. torch.load unpickles the file, so only load trusted files.
model = SpectroSpatialNet(args)
model.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
model.eval()

## Export results
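Results are not georeferenced – use `ArcPy_georeference_results.py` to georeference them and combine them into a single raster. (Note: `export_result` below receives the source raster's `geoinfo`, so check whether this step is still needed.)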

In [None]:
# Raster to classify (here the same scene as the training imagery).
source_path = r'C:\Users\dd\Documents\NATUR_CUNI\additional_projects\prace\ETRAINEE\data\LH_202008_54pasem_9cm.tif'
tile_shape = (256, 256)
overlap = 128
offset_topleft = (0, 0)

start = perf_counter()
# Read from `source_path`; reading `trainingdata_path` here would silently
# ignore the path configured above.
raster_orig = image_preprocessing.read_gdal_with_geoinfo(source_path, offset_topleft)
print(raster_orig['geoinfo'])

dataset_full_tiles = image_preprocessing.run_tiling_dims(raster_orig['imagery'], out_shape=tile_shape,
                                                         out_overlap=overlap, offset=offset_topleft)
# NOTE(review): training preprocessing used nodata_vals=[65535]; confirm that 0
# is the intended nodata value for inference.
dataset_full = image_preprocessing.normalize_tiles_3d(dataset_full_tiles, nodata_vals=[0])
# NOTE: this rebinds the global `dataset` that train()/eval() read. Re-run the
# training preprocessing cells before training again.
dataset = tnt.dataset.TensorDataset(dataset_full['imagery'])
end = perf_counter()
print(f'Loading the imagery took {end - start} seconds.')

print(dataset_full_tiles['imagery'].shape)
print(dataset_full_tiles['dimensions'])

In [None]:
output_path = r'C:\Users\dd\Documents\NATUR_CUNI\additional_projects\prace\ETRAINEE\data\test_result_3d.tif'

start = perf_counter()
# Classify all tiles and merge them back into one array of class labels.
arr_class = inference_utils.combine_tiles_2d(model, dataset, tile_shape, overlap, dataset_full_tiles['dimensions'])
print(np.unique(arr_class, return_counts=True))  # class labels and their pixel counts
inference_utils.export_result(output_path, arr_class, raster_orig['geoinfo'])
end = perf_counter()
print(f'The processing took {end - start} seconds.')

# Class labels are categorical. Nearest-neighbour display prevents downsampled
# pixels from being blended into misleading intermediate label values.
fig, ax = plt.subplots()
im = ax.imshow(arr_class, interpolation='nearest')
ax.set_title('Predicted class map')
fig.colorbar(im, ax=ax, label='Class label')
plt.show()