# Model Training Pipeline

First, load all necessary libraries and the data.

In [None]:
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import optuna
import optuna.visualization as vis
import torch.optim as optim
import numpy as np 
import joblib
import gc
from torch.utils.data import DataLoader
from preprocessing import *
from utils import *
from datasets import *
from CNN_AE_helper import *
from CNN3d import *
from torchvision.transforms import v2
from scipy.ndimage import binary_erosion
import plotly
import plotly.io as pio

In [None]:
# FX10 camera
# Alternative data locations used on other machines:
#IMG_DIR = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/'
#IMG_DIR = '/home/u0158953/data/Strawberries/PotsprocessedData/cropped_hdf5'
IMG_DIR = '/home/r0979317/Documents/Thesis_Strawberries/Data/cropped_hdf5'
CAMERA = 'FX10'

# Healthy leaves: 07-15 and 18-19 September 2023 (16 and 17 SEPT are not included).
DATES = [f'{day:02d}SEPT2023' for day in (*range(7, 16), 18, 19)]
# NOTE(review): '2D' is included because some files from the FX17 camera are
# mistakenly named 2D instead of 4D; confirm whether this also applies to FX10.
TRAYS = ['3D', '4C', '4D', '2D']
healthy_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA,
                                date_stamps=DATES, tray_ids=TRAYS)

# All diseased samples come from tray 3C; the stage is determined by the date.
TRAYS = ['3C']

# Early diseased leaves: 07 SEPT
DATES = [f'{day:02d}SEPT2023' for day in (7,)]
early_diseased_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA,
                                       date_stamps=DATES, tray_ids=TRAYS)

# Mid diseased leaves: 08-09 SEPT
DATES = [f'{day:02d}SEPT2023' for day in (8, 9)]
mid_diseased_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA,
                                     date_stamps=DATES, tray_ids=TRAYS)

# Late diseased leaves: 10-15 SEPT
DATES = [f'{day:02d}SEPT2023' for day in range(10, 16)]
late_diseased_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA,
                                      date_stamps=DATES, tray_ids=TRAYS)

# Number of samples in each category
for label, files in (('Healthy', healthy_FX10),
                     ('Early diseased', early_diseased_FX10),
                     ('Mid diseased', mid_diseased_FX10),
                     ('Late diseased', late_diseased_FX10)):
    print(f'{label}: {len(files)}')

Split the healthy data into training, validation, and test sets.

In [None]:
# Hold out 20 % of the healthy files as the test set, then 18.5 % of the rest
# (about 14.8 % of all files) as the validation set. The fixed random_state
# makes the split reproducible.
train_val_files, test = train_test_split(healthy_FX10, test_size=0.20, random_state=10)
train, validation = train_test_split(train_val_files, test_size=0.185, random_state=10)

In [None]:
# Dataset / preprocessing configuration shared by all data splits.
#INPUT_DATA = healthy_FX10[0:10]    # [0:40] just to speed up the process for now
# Mask folders used on other machines:
#MASK_FOLDER = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_masks'
#MASK_FOLDER = "/home/u0158953/data/Strawberries/PotsprocessedData/cropped_masks"
MASK_FOLDER = '/home/r0979317/Documents/Thesis_Strawberries/Data/cropped_masks'

BATCH_SIZE = 2
MASK_METHOD = 1             # 0 for only leaf, 1 for leaf+stem

# Important wavelengths obtained from pca_bandselect.ipynb (17 bands).
# 819.25 can be added if needed.
BAND_SELECTION = [
    489.3, 505.1, 542.21, 550.2, 558.21, 582.31, 625.4, 660.62, 674.2,
    679.64, 701.44, 717.81, 736.94, 745.15, 783.52, 866.08, 951.83,
]

# Savitzky-Golay settings, only relevant for the "savgol"/"stacked" methods.
# NOTE(review): Savitzky-Golay windows are classically odd-length; confirm that
# a window length of 4 is accepted by the implementation in preprocessing.
POLYORDER = 2
WINDOW_LENGTH = 4
PREPROCESS_METHOD = "normal"  # Can be: "normal", "savgol", "snv", "stacked"
PATCH_PROBABILITY = 0.0       # Fraction of samples the dataloader randomly zooms into

# Scaler and PCA models fitted on healthy data.
# NOTE(review): every dataset in this notebook is built with pca=None and
# scaler=None, so these two objects appear to be unused here.
SCALER = joblib.load('/home/r0979317/Documents/Thesis_Strawberries/Thesis_code/master_thesis/models/pca/scaler_healthy.joblib')
PCA = joblib.load('/home/r0979317/Documents/Thesis_Strawberries/Thesis_code/master_thesis/models/pca/pca_model_healthy.joblib')

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')
print(f'Total number of GPUs: {torch.cuda.device_count()}')

Data augmentation

In [None]:
# Define data augmentations
def _tensor_resize_steps():
    """Return fresh instances of the deterministic steps shared by both pipelines.

    The steps convert to a tensor image, resize to 256x256 and cast to float32.
    """
    return [
        v2.ToImage(),
        v2.Resize((256, 256)),    # By default this uses bilinear interpolation which is good.
        v2.ToDtype(torch.float32, scale=True),
    ]


# Training pipeline: shared steps followed by random flip / rotation augmentation.
train_transforms = v2.Compose(_tensor_resize_steps() + [
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=(0, 180), interpolation=v2.InterpolationMode.BILINEAR),
])

# Evaluation pipeline (validation, test and diseased data): shared steps only, no augmentation.
test_transforms = v2.Compose(_tensor_resize_steps())

In [None]:
# Create Dataset and DataLoader objects for every data split.
#
# All splits share the preprocessing settings collected in `_hsi_dataset_kwargs`,
# so training, validation, test and diseased data are always processed the same way.
# Only the file list, the transform pipeline (augmentation for training only)
# and the shuffle flag differ between splits.
#
# NOTE(review): `deriv=2` is applied to the training set as well as to every
# evaluation set, so train/eval preprocessing cannot diverge when a
# derivative-based preprocess method is selected. Presumably `deriv` is only
# used by such methods; confirm against HsiDataset.
_hsi_dataset_kwargs = dict(
    apply_mask=True,
    mask_method=MASK_METHOD,
    min_wavelength=430,
    normalize=True,
    selected_bands=BAND_SELECTION,
    pca=None,
    scaler=None,
    polyorder=POLYORDER,
    deriv=2,
    window_length=WINDOW_LENGTH,
    preprocess_method=PREPROCESS_METHOD,
    patch_probability=PATCH_PROBABILITY,
)

#####################
### TRAINING DATA ###
#####################
dataset_train_hsi = HsiDataset(train, MASK_FOLDER, transform=train_transforms, **_hsi_dataset_kwargs)
dataloader_train_hsi = DataLoader(dataset_train_hsi, batch_size=BATCH_SIZE, shuffle=True)

#######################
### VALIDATION DATA ###
#######################
dataset_validation_hsi = HsiDataset(validation, MASK_FOLDER, transform=test_transforms, **_hsi_dataset_kwargs)
dataloader_validation_hsi = DataLoader(dataset_validation_hsi, batch_size=BATCH_SIZE, shuffle=False)

#################
### TEST DATA ###
#################
dataset_test_hsi = HsiDataset(test, MASK_FOLDER, transform=test_transforms, **_hsi_dataset_kwargs)
dataloader_test_hsi = DataLoader(dataset_test_hsi, batch_size=BATCH_SIZE, shuffle=False)

###########################
### EARLY DISEASED DATA ###
###########################
dataset_early_diseased_hsi = HsiDataset(early_diseased_FX10, MASK_FOLDER, transform=test_transforms, **_hsi_dataset_kwargs)
dataloader_early_diseased_hsi = DataLoader(dataset_early_diseased_hsi, batch_size=BATCH_SIZE, shuffle=False)

#########################
### MID DISEASED DATA ###
#########################
dataset_mid_diseased_hsi = HsiDataset(mid_diseased_FX10, MASK_FOLDER, transform=test_transforms, **_hsi_dataset_kwargs)
dataloader_mid_diseased_hsi = DataLoader(dataset_mid_diseased_hsi, batch_size=BATCH_SIZE, shuffle=False)

##########################
### LATE DISEASED DATA ###
##########################
dataset_late_diseased_hsi = HsiDataset(late_diseased_FX10, MASK_FOLDER, transform=test_transforms, **_hsi_dataset_kwargs)
dataloader_late_diseased_hsi = DataLoader(dataset_late_diseased_hsi, batch_size=BATCH_SIZE, shuffle=False)

# Training process

- Training of a full model
- Different hyperparameter tuning pipelines

In [None]:
#############################################
##### Train the model
#############################################

# Seed the CPU and all CUDA generators so weight initialisation and batch
# shuffling are reproducible between runs.
torch.manual_seed(10)
torch.cuda.manual_seed_all(10)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Chosen architecture and hyperparameters. These change depending on the
# architecture; the available models can be found in CNN3d.py.
model = CNN3DAE_Exact_Configurable(
    base_channels=[72, 128, 128],
    dropout_p=0.06405922911448286,
    pool_type="avg",
    use_batchnorm=[False, False, False],
)
model = model.to(device)

criterion = torch.nn.L1Loss()    # MAE reconstruction loss
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=5.3302310878295036e-05,
    weight_decay=8.389079019851564e-05,
)

# Train for 75 epochs with train_autoencoder from CNN_AE_helper.py, with model saving enabled.
train_losses, val_losses = train_autoencoder(
    75,
    model,
    dataloader_train_hsi,
    device,
    dataloader_validation_hsi,
    criterion=criterion,
    optimizer=optimizer,
    save_model=True,
    save_path='/home/r0979317/Documents/Thesis_Strawberries/models/third_finetuning_model.pth',
    noise=False,
)

### Hyperparameter tuning

In [None]:
#############################################
#### First hyperparameter tuning set up
#############################################

def objective(trial, train=dataloader_train_hsi, validation=dataloader_validation_hsi):
    """Optuna objective for the first search over ``CNN3DAEMAX_try`` autoencoders.

    The trial samples:
    - the depth (4-6 layers);
    - strictly increasing encoder widths (multiples of 8, capped at 256);
    - the learning rate and weight decay;
    - the reconstruction loss (MSE or MAE).

    The model is trained for 20 epochs and the lowest validation loss is returned.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        train: Training DataLoader. The default is evaluated once, when this
            cell is executed.
        validation: Validation DataLoader. The default is evaluated once, when
            this cell is executed.

    Returns:
        The minimum of the validation losses returned by ``train_autoencoder``.

    Raises:
        optuna.exceptions.TrialPruned: if the summed channel count exceeds 900.

    NOTE(review): ``trial`` is not passed to ``train_autoencoder`` here, so
    presumably no intermediate values are reported and the study's pruner
    cannot stop these trials early (``objective_checkerboard`` does pass it).
    """
    trial_start = time.time()

    # === Number of Layers ===
    n_layers = trial.suggest_int("n_layers", 4, 6)

    # === Channels per layer ===
    layers_list = []
    current_min = 8
    for i in range(n_layers):
        ch = trial.suggest_int(f"enc_ch_{i}", current_min, 256, step=8)
        layers_list.append(ch)
        # The next layer must be strictly wider. Once the 256 cap is reached,
        # the remaining layers stay at 256.
        current_min = min(ch + 8, 256)

    # Prune oversized models early since large models often run out of GPU memory.
    if sum(layers_list) > 900:
        raise optuna.exceptions.TrialPruned()

    # === LR, weight decay, loss ===
    # Dropout and the optimizer type are not tuned in this search; an initial
    # fine-tuning found RMSprop to perform worse than Adam.
    lr = trial.suggest_float("lr", 1e-5, 5e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 7e-3, log=True)
    criterion_name = trial.suggest_categorical("criterion", ["MSE", "MAE"])

    print(f'Model {trial.number} with Layers: {layers_list}, LR: {lr:.1e}, WD: {weight_decay:.1e}, Loss criterion: {criterion_name}')

    # === Model ===
    model = CNN3DAEMAX_try(
        layers_list=layers_list,
        input_dim=1,
        kernel_sizes=3,
        strides=(1, 2, 2),
        paddings=1).to(device)

    # === Optimizer and loss ===
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = {"MSE": nn.MSELoss, "MAE": nn.L1Loss}[criterion_name]()

    train_losses, val_losses = train_autoencoder(
        num_epochs=20,
        model=model,
        dataloader_train=train,
        device=device,
        dataloader_valid=validation,
        criterion=criterion,
        optimizer=optimizer,
        save_model=False
    )

    best_val_loss = min(val_losses)

    # === Logging ===
    elapsed = (time.time() - trial_start) / 60
    print(f'Trial {trial.number} finished in {elapsed:.2f} min')
    print(f'Minimal val loss: {best_val_loss:.6f}')

    # === Cleanup ===
    # Drop references, collect garbage first so unreachable tensors are freed,
    # then let PyTorch release the now-unused cached GPU memory.
    del model, optimizer, criterion, train_losses, val_losses
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)

    return best_val_loss

In [None]:
#############################################
#### Second hyperparameter tuning set up
#############################################
def objective_finetuning(trial, train=dataloader_train_hsi, validation=dataloader_validation_hsi):
    """Optuna objective for the second (fine-tuning) search over ``CNN3DAE_finetuning``.

    The trial samples:
    - the depth (2-6 layers);
    - strictly increasing encoder widths (multiples of 8, capped at 256);
    - dropout, learning rate and weight decay;
    - the activation and the loss function (MAE or MSE);
    - one batch-norm flag shared by all layers;
    - the pooling type.

    The model is trained for 20 epochs and the lowest validation loss is returned.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        train: Training DataLoader. The default is evaluated once, when this
            cell is executed.
        validation: Validation DataLoader. The default is evaluated once, when
            this cell is executed.

    Returns:
        The minimum of the validation losses returned by ``train_autoencoder``.

    Raises:
        optuna.exceptions.TrialPruned: if the summed channel count exceeds 900,
            or if pooling is requested for a model with more than 4 layers.

    NOTE(review): ``trial`` is not passed to ``train_autoencoder`` here, so
    presumably no intermediate values are reported for the pruner to act on.
    """
    trial_start = time.time()

    # === Number of Layers ===
    n_layers = trial.suggest_int("n_layers", 2, 6)

    # === Dynamically define channels per layer ===
    layers_list = []
    current_min = 8
    for i in range(n_layers):
        ch = trial.suggest_int(f"enc_ch_{i}", current_min, 256, step=8)
        layers_list.append(ch)
        # The next layer must be strictly wider. Once the 256 cap is reached,
        # the remaining layers stay at 256.
        current_min = min(ch + 8, 256)

    # Prune oversized models early since large models often run out of GPU memory.
    if sum(layers_list) > 900:
        raise optuna.exceptions.TrialPruned()

    # === Dropout, LR, etc. ===
    dropout_p = trial.suggest_float("dropout", 0.0, 0.5)
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)

    # ===== Activation choice =====
    act_name = trial.suggest_categorical("activation", ["relu", "gelu", "elu"])

    # ===== Loss function choice =====
    loss_name = trial.suggest_categorical("loss_fn", ["MAE", "MSE"])
    loss_fn = torch.nn.L1Loss() if loss_name == "MAE" else torch.nn.MSELoss()

    # ===== BatchNorm choice (one flag for all layers) =====
    use_bn_flag = trial.suggest_categorical("use_bn", [False, True])
    use_batchnorm = [use_bn_flag] * n_layers

    # ===== Global pooling choice =====
    # Always sample from the same three options, but prune the invalid
    # combination "pooling with more than 4 layers" right away so it never
    # skews the results.
    pool_type = trial.suggest_categorical("pool_type", ["none", "max", "avg"])
    if n_layers > 4 and pool_type != "none":
        raise optuna.exceptions.TrialPruned()

    print(f'Model {trial.number} with Layers: {layers_list}, LR: {lr:.1e}, WD: {weight_decay:.1e}, Loss criterion: {loss_name}, Pooling: {pool_type}, BatchNorm: {use_bn_flag}, activation: {act_name}')

    # ===== Instantiate model =====
    model = CNN3DAE_finetuning(
        layers_list=layers_list,
        strides=[(1, 2, 2)] * n_layers,
        dropout_p=dropout_p,
        pool_type=pool_type,
        use_batchnorm=use_batchnorm,
        activation=act_name,
    ).to(device)

    # === Optimizer ===
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_losses, val_losses = train_autoencoder(
        num_epochs=20,
        model=model,
        dataloader_train=train,
        device=device,
        dataloader_valid=validation,
        criterion=loss_fn,
        optimizer=optimizer,
        save_model=False
    )

    best_val_loss = min(val_losses)

    # === Logging ===
    elapsed = (time.time() - trial_start) / 60
    print(f'Trial {trial.number} finished in {elapsed:.2f} min')
    print(f'Minimal val loss: {best_val_loss:.6f}')

    # === Cleanup ===
    # All names below are always bound at this point: both prune checks raise
    # before the model is built, so no NameError guard is needed. Collect
    # garbage first, then release PyTorch's now-unused cached GPU memory.
    del model, optimizer, loss_fn, train_losses, val_losses
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)

    return best_val_loss

Check that the model reproduces the exact input shape through upsampling alone, without interpolation.

In [None]:
# Sanity check: push a random dummy volume through one configuration of
# CNN3DAE_Exact_Configurable and print the output shape, to compare it with
# the input shape. A separate name is used so this check does not overwrite
# `model` (e.g. a model trained in an earlier cell).
# NOTE(review): the dummy input has 18 spectral slices, while BAND_SELECTION
# lists 17 wavelengths; confirm which depth the real data has.
shape_check_model = CNN3DAE_Exact_Configurable(
    input_dim=1,
    base_channels=[8, 16, 32, 64, 128],
    dropout_p=0.2,
    pool_type="avg",              # alternatives: "max" or "none"
    use_batchnorm=[False, True, True, True, True],
)
shape_check_model.eval()  # disable dropout and use BatchNorm running stats for the check

x = torch.randn(1, 1, 18, 256, 256)
with torch.no_grad():  # only the output shape is needed, so skip building an autograd graph
    out = shape_check_model(x)
print(out.shape)

In [None]:
#############################################
#### Third hyperparameter tuning set up
#############################################
def objective_checkerboard(trial,
                           train=dataloader_train_hsi,
                           validation=dataloader_validation_hsi):
    """Optuna objective for the third search over ``CNN3DAE_Exact_Configurable``.

    The trial samples:
    - the depth (3-5 layers);
    - strictly increasing encoder widths (multiples of 8, capped at 128);
    - dropout, learning rate and weight decay;
    - the pooling type;
    - a separate batch-norm flag for each layer.

    The model is trained with an L1 (MAE) loss for 20 epochs. ``trial`` is
    forwarded to ``train_autoencoder`` so the study's pruner can act on it.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        train: Training DataLoader. The default is evaluated once, when this
            cell is executed.
        validation: Validation DataLoader. The default is evaluated once, when
            this cell is executed.

    Returns:
        The minimum of the validation losses returned by ``train_autoencoder``.

    Raises:
        optuna.exceptions.TrialPruned: if the summed channel count exceeds 600
            (and, presumably, when ``train_autoencoder`` prunes via ``trial``).
    """
    trial_start = time.time()

    # Number of layers
    n_layers = trial.suggest_int("n_layers", 3, 5)

    # Channels per layer, forced to increase. Once the 128 cap is reached,
    # the remaining layers stay at 128.
    layers_list = []
    current_min = 8
    for i in range(n_layers):
        ch = trial.suggest_int(f"enc_ch_{i}", current_min, 128, step=8)
        layers_list.append(ch)
        current_min = min(ch + 8, 128)

    # Prune overly large models so we don't run out of CUDA memory
    if sum(layers_list) > 600:
        raise optuna.exceptions.TrialPruned()

    # Dropout, LR, weight decay
    dropout_p    = trial.suggest_float("dropout",      0.0, 0.5)
    lr           = trial.suggest_float("lr",           1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)

    # Pooling choice
    pool_type = trial.suggest_categorical("pool_type", ["none", "max", "avg"])

    # Per-layer batch normalization
    use_batchnorm = [
        trial.suggest_categorical(f"use_bn_layer_{i}", [False, True])
        for i in range(n_layers)
    ]

    print(
        f"Trial {trial.number}: layers={layers_list}, "
        f"lr={lr:.1e}, wd={weight_decay:.1e}, "
        f"pool={pool_type}, bn={use_batchnorm}"
    )

    # Build model
    model = CNN3DAE_Exact_Configurable(
        input_dim=1,
        base_channels=layers_list,
        dropout_p=dropout_p,
        pool_type=pool_type,
        use_batchnorm=use_batchnorm,
    ).to(device)

    # A freshly created optimizer has no gradients to clear, so no zero_grad() is needed here.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    loss_fn = torch.nn.L1Loss()

    train_losses, val_losses = train_autoencoder(
        num_epochs=20,
        model=model,
        dataloader_train=train,
        device=device,
        dataloader_valid=validation,
        criterion=loss_fn,
        optimizer=optimizer,
        save_model=False,
        trial=trial
    )

    best_val_loss = min(val_losses)
    elapsed = (time.time() - trial_start) / 60
    print(f"Trial {trial.number} finished in {elapsed:.2f} min — best val loss {best_val_loss:.6f}")

    # Cleanup: collect garbage first, then release PyTorch's unused cached GPU memory.
    del model, optimizer, loss_fn, train_losses, val_losses
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)
    return best_val_loss

In [None]:
# Path of the SQLite database file (e.g. ".../study_name.db") that backs the Optuna study;
# it is used as `sqlite:///{study_path}` when creating and loading the study.
# NOTE(review): "..." is a placeholder and must be replaced before running the tuning cells.
study_path = "/home/r0979317/Documents/Thesis_Strawberries/..."

In [None]:
# Choose the pruner for the study. Set USE_PERCENTILE_PRUNER to True to use
# the PercentilePruner configured below; otherwise the default MedianPruner
# (the alternative used for tuning runs 1 and 2) is used.
# NOTE(review): pruners act on intermediate values reported through `trial`;
# in this notebook only objective_checkerboard passes `trial` to train_autoencoder.
USE_PERCENTILE_PRUNER = False

if USE_PERCENTILE_PRUNER:
    pruner = optuna.pruners.PercentilePruner(
        percentile=50.0,         # Only keep top 50% of trials
        n_startup_trials=6,      # Don't prune the first 6 trials
        n_warmup_steps=2,        # Start checking after the second epoch
        interval_steps=1,        # Check every epoch
        n_min_trials=1,          # Minimum # of trials to start pruning comparisons
    )
else:
    pruner = optuna.pruners.MedianPruner()

# Create the study, or resume it if it already exists in the SQLite file at study_path.
study = optuna.create_study(
    direction="minimize",
    pruner=pruner,
    study_name="...",       # Here the name of the study has to be defined
    storage=f"sqlite:///{study_path}",
    load_if_exists=True,
)

# Run at most 100 trials or until the 40000 s timeout, whichever comes first.
study.optimize(objective_checkerboard, n_trials=100, timeout=40000, gc_after_trial=True)

In [None]:
# Re-open the stored study from its SQLite file so it can be inspected and
# visualized without re-running the optimization.
# NOTE(review): study_name is a placeholder (a previous run used
# "zoomed_finetuning_study"); it must match the name given at creation time.
study_test = optuna.load_study(study_name="...", storage=f"sqlite:///{study_path}")

In [None]:
# Summarize the best trial of `study` (its objective value and sampled parameters).
best_trial = study.best_trial

summary_lines = [
    "Best trial:",
    f"  Value (objective score): {best_trial.value}",
    "  Params:",
]
summary_lines += [f"    {key}: {value}" for key, value in best_trial.params.items()]
print("\n".join(summary_lines))

Visualize the hyperparameter tuning process

In [None]:

# Set the default Plotly renderer so figures displayed with `.show()` in the cells below appear inline.
pio.renderers.default = "notebook"  # or "iframe" if this fails

In [None]:
# Rank all completed trials of the loaded study by objective value
# (lowest = best) and print the best N_TOP_TRIALS of them.
N_TOP_TRIALS = 9  # number of best trials to print

completed_trials = [t for t in study_test.trials if t.state == optuna.trial.TrialState.COMPLETE]
top_trials = sorted(completed_trials, key=lambda t: t.value)[:N_TOP_TRIALS]

for rank, ranked_trial in enumerate(top_trials, start=1):
    print(f"\nRank {rank}:")
    print(f"  Trial number: {ranked_trial.number}")
    print(f"  Objective value: {ranked_trial.value:.6f}")
    print(f"  Params: {ranked_trial.params}")

In [None]:
# Show the optimization history first, then the hyperparameter importances, for `study`.
for make_plot in (vis.plot_optimization_history, vis.plot_param_importances):
    make_plot(study).show()

In [None]:
# Parallel-coordinate plot of the sampled hyperparameters for the study loaded from storage above.
vis.plot_parallel_coordinate(study_test).show()

### Loading the trained model

In [None]:
# Rebuild the architecture the checkpoint was saved from; its stored state is
# then loaded by load_checkpoint (CNN_AE_helper.py).
model = CNN3DAE_TightDropout(
    layers_list=[18, 32, 64, 64],
    input_dim=1,
    strides=(1, 2, 2),
    paddings=(0, 1, 1),
).to(device)

# The optimizer is handed to load_checkpoint together with the model
# (presumably so its saved state is restored as well; confirm in the helper).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Load the checkpoint
checkpoint_path = '/home/r0979317/Documents/Thesis_Strawberries/models/first_pc_dropout_model.pth'
model, optimizer, train_losses, val_losses, last_epoch = load_checkpoint(
    model, checkpoint_path, device, optimizer
)