In [1]:
import gc
import math
import os
from typing import List, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import rasterio
import rasterio.plot
import seaborn as sns
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import yaml
from torch import nn, optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.models import (EfficientNet_V2_S_Weights, Swin_V2_S_Weights,
                                efficientnet_v2_s, swin_v2_s)

from ml_commons import *

cudnn.benchmark = True
sns.set_theme()

Load configuration file that specifies training routine

In [2]:
config = yaml.safe_load(open('ml_config.yml'))

In [3]:
prefix_dir = config['paths']['prefix_dir']
dataset_dir = os.path.join(prefix_dir, config['paths']['dataset_dir'])
output_dir = os.path.join(config['paths']['machine_learning_dir'], 'output')

Release memory to ensure enough capacity on the selected device

In [4]:
gc.collect()
with torch.no_grad():
    torch.cuda.empty_cache()

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')

Using cuda for inference


Create model using transfer learning from either EfficientNetV2 or SwinTransformerV2

In [6]:
def print_model(model:nn.Module, name:str):
    spacing = '  '
    model_str = spacing + f'\n{spacing}'.join(str(model).splitlines())
    print(f'--- {name} ---\n{model_str}\n{"-" * (8 + len(name))}')

In [8]:
print(f'Using model: ' + config['model']['name'])
model_name = str(config['model']['name']).lower()
use_transfer_learning = config['model']['use_transfer_learning']
if model_name == 'swintransformer':
    model = swin_v2_s(weights = Swin_V2_S_Weights.DEFAULT if use_transfer_learning else None)
elif model_name == 'efficientnet':
    model = efficientnet_v2_s(weights = EfficientNet_V2_S_Weights.DEFAULT if use_transfer_learning else None)
else:
    raise RuntimeError(f'Model "' + config['model']['name'] + '" is unknown')

if config['model']['freeze_parameters']:
    for param in model.parameters(): #freeze model
        param.requires_grad = False

print_model(get_model_head(model), 'Initial Model Head')
num_features = get_num_features(model)
print(f'Classification layer has {num_features} input features')
new_model_head = nn.Sequential()
# Before linear layer
if config['processing']['use_dropout']:
    new_model_head.append(nn.Dropout(p=config['processing']['dropout_p'], inplace=True))
# Linear layer
linear_layer = nn.Linear(num_features, 1 if config['processing']['use_one_neuron_regression'] else len(classes))
new_model_head.append(linear_layer)
# After linear layer
if (config['processing']['use_ordinal_regression'] or config['processing']['use_one_neuron_regression']) \
        and config['processing']['activation_function'] != False:
    activation_function = str(config['processing']['activation_function']).lower()
    if activation_function == 'sigmoid':
        activation_function = nn.Sigmoid()
    elif activation_function == 'relu':
        activation_function = nn.ReLU()
    elif activation_function == 'tanh':
        activation_function = nn.Tanh()
    else:
        raise RuntimeError(f'Unkown activation function: {activation_function}')
    new_model_head = new_model_head.append(activation_function)

set_model_head(model, new_model_head)
print_model(get_model_head(model), 'Modified Model Head')
# model = nn.DataParallel(model)
model = model.to(device)

Using model: EfficientNet
--- Initial Model Head ---
  Sequential(
    (0): Dropout(p=0.2, inplace=True)
    (1): Linear(in_features=1280, out_features=1000, bias=True)
  )
--------------------------
Classification layer has 1280 input features
--- Modified Model Head ---
  Sequential(
    (0): Linear(in_features=1280, out_features=5, bias=True)
    (1): Sigmoid()
  )
---------------------------


Load weights and define prerequisite functions for model training

In [9]:
training_weights = pd.read_csv(os.path.join(dataset_dir, 'training_weights.csv'), index_col='label')
training_weights.T

label,CLR,FEW,SCT,BKN,OVC
weight,2.505284,13.22179,9.509328,6.538807,3.743665


In [10]:
def prediction_to_label_ordinal_regression(pred: torch.Tensor) -> torch.Tensor:
    return (pred > 0.5).cumprod(axis=1).sum(axis=1) - 1

In [11]:
class OrdinalRegression:
    def __init__(self, weights:Optional[torch.Tensor]) -> None:
        if weights is None:
            self.weights = torch.Tensor([1] * len(classes)).to(device)
        else:
            self.weights = weights
    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        modified_targets = torch.zeros_like(predictions)
        for i, target in enumerate(targets):
            modified_targets[i, 0 : target + 1] = 1
        loss_function_name = str(config['processing']['ordinal_regression_loss']).lower()
        if loss_function_name == 'mse':
            loss = torch.mean((nn.MSELoss(reduction='none')(predictions, modified_targets) * self.weights).sum(axis=1))
        elif loss_function_name == 'l1':
            loss = torch.mean((nn.L1Loss(reduction='none')(predictions, modified_targets) * self.weights).sum(axis=1))
        else:
            raise RuntimeError(f'Unknown loss function: {loss_function_name}')
        return loss

In [12]:
def prediction_to_label_one_neuron_regression(pred: torch.Tensor) -> torch.Tensor:
    return ((pred * max(classes)).round()).int().flatten()

In [13]:
class OneNeuronRegression:
    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        modified_predictions = predictions.flatten()
        classes_max = max(classes)
        loss_function_name = str(config['processing']['ordinal_regression_loss']).lower()
        if loss_function_name == 'mse':
            loss = torch.mean(nn.MSELoss(reduction='none')(modified_predictions * classes_max, targets.float()))
        elif loss_function_name == 'l1':
            loss = torch.mean(nn.L1Loss(reduction='none')(modified_predictions * classes_max, targets.float()))
        else:
            raise RuntimeError(f'Unknown loss function: {loss_function_name}')
        return loss

Load datasets and apply techniques like data augmentation or use a weighted sampler

In [14]:
dataset_transforms = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
]
composed_transforms = transforms.Compose(dataset_transforms)

In [15]:
use_manual_labels = config['processing']['use_manual_labels']

In [16]:
train_dataset       = AimlsseImageDataset(DatasetType.TRAINING,     dataset_dir, transfrom=composed_transforms, use_manual_labels=use_manual_labels)
validation_dataset  = AimlsseImageDataset(DatasetType.VALIDATION,   dataset_dir, transfrom=None, use_manual_labels=use_manual_labels)

In [17]:
def mean_std(loader:DataLoader):
  sum, squared_sum, num_batches = 0,0,0
  for data, _, _ in loader:
    sum += torch.mean(data,dim=[0,1,2])
    squared_sum += torch.mean(data**2,dim=[0,1,2])
    num_batches += 1
  mean = sum/num_batches
  std = (squared_sum/num_batches - mean**2)**0.5
  return mean, std

In [18]:
def batch_normalization(dataset:AimlsseImageDataset, dataset_type:DatasetType, dataset_transforms):
    mean, std = mean_std(dataset)
    print(f'{dataset_type.name} - mean {mean:.3f}, std {std:.3f}')
    if dataset_transforms is None:
        dataset_transforms = []
    return AimlsseImageDataset(dataset_type, dataset_dir,
                               transfrom = transforms.Compose(dataset_transforms + [transforms.Normalize(mean, std)]),
                               use_manual_labels=use_manual_labels)

In [19]:
if config['processing']['batch_normalization']:
    train_dataset       = batch_normalization(train_dataset, DatasetType.TRAINING, dataset_transforms)
    validation_dataset  = batch_normalization(validation_dataset, DatasetType.VALIDATION, None)

In [20]:
if config['processing']['use_weighted_sampler']:
    num_samples = len(train_dataset)
    weights = [0] * num_samples
    for i in range(num_samples):
        label = train_dataset.get_label(i)
        weights[i] = training_weights.loc[class_names]['weight'].iloc[label]
    training_sampler = WeightedRandomSampler(weights, num_samples)
    train_dataloader =  DataLoader(train_dataset,       batch_size=config['processing']['batch_size'], sampler=training_sampler)
else:
    train_dataloader =  DataLoader(train_dataset,       batch_size=config['processing']['batch_size'], shuffle=True)
validation_dataloader = DataLoader(validation_dataset,  batch_size=config['processing']['batch_size'], shuffle=True)

If enabled in the config file, show example images of the dataset with corresponding labels

In [21]:
sample_batch_index = 0

In [22]:
if config['output']['show_samples']:
    plot_samples(train_dataset, config['processing']['batch_size'], sample_batch_index)
    sample_batch_index += 1

Define all prerequisites that are necessary for model training, depending on the settings in the configuration file

In [23]:
if config['processing']['use_weighted_loss_function']:
    loss_function_weights = torch.tensor(training_weights['weight'].to_list())
    loss_function_weights = loss_function_weights.to(device)
else:
    loss_function_weights = None
print(f'Loss weights: {loss_function_weights}')

if config['processing']['use_ordinal_regression']:
    criterion = OrdinalRegression(loss_function_weights)
    outputs_to_predictions = prediction_to_label_ordinal_regression
elif config['processing']['use_one_neuron_regression']:
    criterion = OneNeuronRegression()
    if loss_function_weights is not None:
        raise Warning('Unable to use one neuron regression with loss function weights')
    outputs_to_predictions = prediction_to_label_one_neuron_regression
else:
    criterion = nn.CrossEntropyLoss(loss_function_weights)
    outputs_to_predictions = lambda outputs: torch.max(outputs, 1)[1]

learning_rate = math.pow(10, -config['processing']['learning_rate_exp'])
weight_decay = math.pow(10, -config['processing']['weight_decay_exp']) if config['processing']['use_weight_decay'] else 0.0
optimizer_name = str(config['processing']['optimizer']).lower()
if optimizer_name == 'adam':
    optimizer = optim.Adam(get_model_head(model).parameters(), lr=learning_rate, weight_decay=weight_decay)
elif optimizer_name == 'sgd':
    optimizer = optim.SGD(get_model_head(model).parameters(), lr=learning_rate, momentum=config['processing']['momentum'])
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

Loss weights: None


Train the machine learning model

In [24]:
model_data = ModelData(train_dataset, validation_dataset, train_dataloader, validation_dataloader)
checkpoint_filepath = os.path.join(config['paths']['machine_learning_dir'], 'checkpoints', 'chk.pt')
print(f'Model Checkpoints will be stored in: {checkpoint_filepath}')
output_filepath = os.path.join(config['paths']['machine_learning_dir'], 'output', config['output']['output_name'])
print(f'The results will be stored in: {output_filepath}')
model_trained = train_model(model, device, model_data, criterion, outputs_to_predictions, optimizer, scheduler,
                            checkpoint_filepath, num_epochs=config['processing']['num_epochs'],
                            batch_accumulation=config['processing']['batch_accumulation'], config=config)
print('Copying data from checkpoint to results..')
state = load_state(checkpoint_filepath)
save_state(output_filepath, state)
print(f'Results stored in: {output_filepath}')
print('Done!')

Model Checkpoints will be stored in: ML\checkpoints\chk.pt
The results will be stored in: ML\output\efficientnet_v2_s_preset_3_16km_300_weighted_sampler_ord_regr_batchacc_all.pt
Epoch 0/31
----------


  0%|          | 0/638 [00:00<?, ?it/s]

Training took 319.3 [s]
	Loss: 0.0020 Acc: 0.0753
	CLR - Precision: 0.181 Recall: 0.246 F1-Score: 0.208
	FEW - Precision: 0.218 Recall: 0.072 F1-Score: 0.109
	SCT - Precision: 0.234 Recall: 0.019 F1-Score: 0.035
	BKN - Precision: 0.217 Recall: 0.027 F1-Score: 0.048
	OVC - Precision: 0.177 Recall: 0.018 F1-Score: 0.032
Total -> Precision: 0.205 Recall: 0.075 F1-Score: 0.086


  0%|          | 0/80 [00:00<?, ?it/s]

Validation took 8.1 [s]
	Loss: 0.0019 Acc: 0.1292
	CLR - Precision: 0.323 Recall: 0.163 F1-Score: 0.217
	FEW - Precision: 0.037 Recall: 0.093 F1-Score: 0.053
	SCT - Precision: 0.060 Recall: 0.151 F1-Score: 0.086
	BKN - Precision: 0.123 Recall: 0.108 F1-Score: 0.115
	OVC - Precision: 0.367 Recall: 0.072 F1-Score: 0.121
Total -> Precision: 0.272 Recall: 0.129 F1-Score: 0.162
Epoch 1/31
----------


  0%|          | 0/638 [00:00<?, ?it/s]

Training took 235.1 [s]
	Loss: 0.0018 Acc: 0.1545
	CLR - Precision: 0.165 Recall: 0.159 F1-Score: 0.162
	FEW - Precision: 0.224 Recall: 0.151 F1-Score: 0.180
	SCT - Precision: 0.196 Recall: 0.289 F1-Score: 0.233
	BKN - Precision: 0.202 Recall: 0.145 F1-Score: 0.169
	OVC - Precision: 0.280 Recall: 0.029 F1-Score: 0.053
Total -> Precision: 0.213 Recall: 0.155 F1-Score: 0.160


  0%|          | 0/80 [00:00<?, ?it/s]

Validation took 8.1 [s]
	Loss: 0.0018 Acc: 0.1174
	CLR - Precision: 0.361 Recall: 0.088 F1-Score: 0.141
	FEW - Precision: 0.031 Recall: 0.107 F1-Score: 0.048
	SCT - Precision: 0.081 Recall: 0.462 F1-Score: 0.137
	BKN - Precision: 0.179 Recall: 0.165 F1-Score: 0.172
	OVC - Precision: 0.500 Recall: 0.052 F1-Score: 0.095
Total -> Precision: 0.332 Recall: 0.117 F1-Score: 0.128
Epoch 2/31
----------


  0%|          | 0/638 [00:00<?, ?it/s]

KeyboardInterrupt: 