# Inference into model performance

This notebook is a pipeline for analysing the trained models using several metrics and visualizations.

Load the necessary packages, data and the model

In [None]:
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np 
import joblib
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from scipy.ndimage import binary_erosion, binary_dilation
from sklearn.manifold import TSNE
from sklearn.metrics import (
    confusion_matrix, roc_auc_score, average_precision_score,
    precision_recall_fscore_support, balanced_accuracy_score, matthews_corrcoef,
    classification_report
)
from sklearn.preprocessing import minmax_scale
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import umap.umap_ as UMAP
import os
from sklearn.metrics import PrecisionRecallDisplay
from tqdm import tqdm 
from scipy.stats import mannwhitneyu
from PIL import Image

# Custom functions
from preprocessing import *
from utils import *
from datasets import *
from CNN_AE_helper import *
from CNN3d import *
from vein_detection import *
from visualization_helper import *
from reconstruction_error import *
from roc_percision_recall import *


In [3]:
# --- File selection for the FX10 camera ---
# IMG_DIR is the folder holding the cropped .hdf5 files. The path is machine-specific;
# the commented line below is the alternative location used on another machine.
IMG_DIR = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5'
# IMG_DIR = '/home/u0158953/data/Strawberries/PotsprocessedData/cropped_hdf5'
CAMERA = 'FX10'

# NOTE: DATES and TRAYS are rebound for every group below, so once this cell has run
# they hold the values of the last group (late diseased).

# Healthy leaves: all recording dates, trays 3D/4C/4D/2D
DATES = ['07SEPT2023', '08SEPT2023', '09SEPT2023', '10SEPT2023', '11SEPT2023', '12SEPT2023',
         '13SEPT2023', '14SEPT2023', '15SEPT2023', '18SEPT2023', '19SEPT2023']
TRAYS = ['3D', '4C', '4D', '2D']    # Some files from the FX17 camera are mistakenly named in 2D instead of 4D
healthy_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA, date_stamps=DATES, tray_ids=TRAYS)

# Early diseased leaves: tray 3C on the first recording date
DATES = ['07SEPT2023']
TRAYS = ['3C']
early_diseased_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA, date_stamps=DATES, tray_ids=TRAYS)

# Mid diseased leaves: tray 3C on the two following dates
DATES = ['08SEPT2023', '09SEPT2023']
TRAYS = ['3C']
mid_diseased_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA, date_stamps=DATES, tray_ids=TRAYS)

# Late diseased leaves: tray 3C from 10 to 15 September
DATES = ['10SEPT2023', '11SEPT2023', '12SEPT2023', '13SEPT2023', '14SEPT2023', '15SEPT2023']
TRAYS = ['3C']
late_diseased_FX10 = filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA, date_stamps=DATES, tray_ids=TRAYS)

# Number of samples in each category (sanity check that the filters matched files)
print(f'Healthy: {len(healthy_FX10)}')
print(f'Early diseased: {len(early_diseased_FX10)}')
print(f'Mid diseased: {len(mid_diseased_FX10)}')
print(f'Late diseased: {len(late_diseased_FX10)}')

Healthy: 321
Early diseased: 12
Mid diseased: 25
Late diseased: 30


**Split data into train, validation and test set**

In [None]:
# Hold out 20 % of the healthy files as test set, then split 18.5 % of the remainder
# (roughly 15 % of all healthy files) off as validation set. The seed is fixed so the
# split is the same on every run.
SPLIT_SEED = 10
train_and_validation, test = train_test_split(healthy_FX10, test_size=0.20, random_state=SPLIT_SEED)
train, validation = train_test_split(train_and_validation, test_size=0.185, random_state=SPLIT_SEED)

256


In [None]:
# --- Notebook configuration ---
# Folder with the leaf masks (machine-specific path; alternatives for other machines are commented out).
MASK_FOLDER = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_masks/'
#MASK_FOLDER = "/home/u0158953/data/Strawberries/PotsprocessedData/cropped_masks"
#MASK_FOLDER = '/home/r0979317/Documents/Thesis_Strawberries/Data/cropped_masks'
BATCH_SIZE = 8
MASK_METHOD = 1    # 0 for only leaf, 1 for leaf+stem
# Important wavelengths obtained from pca_bandselect.ipynb (17 values; 819.25 is not part of the list).
BAND_SELECTION = [489.3, 505.1, 542.21, 550.2, 558.21, 582.31, 625.4, 660.62, 674.2, 679.64,
                  701.44, 717.81, 736.94, 745.15, 783.52, 866.08, 951.83]    # Important wavelengths obtained from pca_bandselect.ipynb 819.25,
# POLYORDER and WINDOW_LENGTH are forwarded to HsiDataset as polyorder / window_length.
POLYORDER = 2
WINDOW_LENGTH = 4 
PREPROCESS_METHOD = "normal"
PATCH_PROBABILITY = 0.0 # amount of data the dataloader randomly zooms into
# Scaler and PCA model fitted on healthy data. joblib.load unpickles the file, so only load
# files from a trusted source. The datasets in this notebook are created with pca=None and
# scaler=None, so these two objects are loaded but not passed to them.
# NOTE(review): the recorded output links to scikit-learn's model-persistence notes, which
# suggests a warning is raised on load - confirm the scikit-learn version matches.
#SCALER = joblib.load('/home/r0979317/Documents/Thesis_Strawberries/Thesis_code/master_thesis/models/pca/scaler_healthy.joblib')    
#PCA = joblib.load('/home/r0979317/Documents/Thesis_Strawberries/Thesis_code/master_thesis/models/pca/pca_model_healthy.joblib')    
SCALER = joblib.load('C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/GitLab/master_thesis/models/pca/scaler_healthy.joblib')
PCA = joblib.load('C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/GitLab/master_thesis/models/pca/pca_model_healthy.joblib')
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
print(f'Total number of GPUs: {torch.cuda.device_count()}')

Using cpu device
Total number of GPUs: 0


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


**Data augmentations**

In [12]:
# Define data augmentations
IMAGE_SIZE = (256, 256)


def _base_steps():
    """Return fresh instances of the steps shared by the training and evaluation pipelines."""
    return [
        v2.ToImage(),
        v2.Resize(IMAGE_SIZE),    # By default this uses bilinear interpolation which is good.
        v2.ToDtype(torch.float32, scale=True),
    ]


# Training: shared steps followed by a random horizontal flip and a random rotation in [0, 180] degrees.
train_transforms = v2.Compose(_base_steps() + [
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=(0, 180), interpolation=v2.InterpolationMode.BILINEAR),
])

# Evaluation: shared steps only, no random augmentation.
test_transforms = v2.Compose(_base_steps())

In [13]:
# Create Dataset and DataLoader
# Keyword arguments that every HsiDataset in this notebook shares.
_DATASET_KWARGS = dict(
    apply_mask=True,
    mask_method=MASK_METHOD,
    min_wavelength=430,
    normalize=True,
    selected_bands=BAND_SELECTION,
    pca=None,
    scaler=None,
    polyorder=POLYORDER,
    window_length=WINDOW_LENGTH,
    preprocess_method=PREPROCESS_METHOD,
    patch_probability=PATCH_PROBABILITY,
)


def _evaluation_dataset(file_paths):
    """Build a dataset without random augmentation (test transforms, deriv=2)."""
    return HsiDataset(file_paths, MASK_FOLDER, transform=test_transforms, deriv=2, **_DATASET_KWARGS)


def _evaluation_loader(dataset):
    """Wrap a dataset in a DataLoader that keeps the sample order."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)


# --- Training data (augmented, shuffled) ---
# NOTE(review): unlike all evaluation datasets, the training dataset does not pass `deriv`;
# confirm that HsiDataset's default matches what the model was trained with.
dataset_train_hsi = HsiDataset(train, MASK_FOLDER, transform=train_transforms, **_DATASET_KWARGS)
dataloader_train_hsi = DataLoader(dataset_train_hsi, batch_size=BATCH_SIZE, shuffle=True)

# --- Validation data ---
dataset_validation_hsi = _evaluation_dataset(validation)
dataloader_validation_hsi = _evaluation_loader(dataset_validation_hsi)

# --- Test data ---
dataset_test_hsi = _evaluation_dataset(test)
dataloader_test_hsi = _evaluation_loader(dataset_test_hsi)

# --- Early diseased data ---
dataset_early_diseased_hsi = _evaluation_dataset(early_diseased_FX10)
dataloader_early_diseased_hsi = _evaluation_loader(dataset_early_diseased_hsi)

# --- Mid diseased data ---
dataset_mid_diseased_hsi = _evaluation_dataset(mid_diseased_FX10)
dataloader_mid_diseased_hsi = _evaluation_loader(dataset_mid_diseased_hsi)

# --- Late diseased data ---
dataset_late_diseased_hsi = _evaluation_dataset(late_diseased_FX10)
dataloader_late_diseased_hsi = _evaluation_loader(dataset_late_diseased_hsi)

# Load Model that needs to be investigated

In [15]:
# --- Instantiate model ---
model = CNN3DAEMAX_try(
    layers_list=[18, 18, 32, 32],
    input_dim=1,
    kernel_sizes=3,
    strides=(1, 2, 2),
    paddings=1,
).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Load the checkpoint; load_checkpoint hands back the model and optimizer together with
# the recorded training/validation losses and the last epoch.
CHECKPOINT_PATH = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/GitLab/master_thesis/models/3D_autoencoder/first_normal_max_model.pth'
model, optimizer, train_losses, val_losses, last_epoch = load_checkpoint(
    model, CHECKPOINT_PATH, device, optimizer
)

Checkpoint loaded from C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/GitLab/master_thesis/models/3D_autoencoder/first_normal_max_model.pth, Last Epoch: 163


  checkpoint = torch.load(path, map_location=device)  # Ensures compatibility across devices


In [16]:
# Save all plots in this directory
model_name = "max_normal_test"
save_path = f'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Images/{model_name}'
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race of
# testing os.path.exists() before creating the folder.
os.makedirs(save_path, exist_ok=True)

# Inference

As inference techniques we will conduct:

Metrics without threshold
- AUC
- AUPRC
- FPR@95%TPR

Threshold based metrics
- Max-value
- 95-quantile
- mean + 2*standard deviation
- Max-Area?

- The overall reconstruction error value
- NO of 99% extreme reconstruction errors
- NOP of 99% extreme reconstruction error per band
- Vein area




**Note: if a stacked model was chosen, skip ahead to the section "Stacked Model analysis" for inference.**

## ROC and Precision-Recall for binary evaluation

### Overall reconstruction error

In [None]:
# Metric configuration for the overall reconstruction error
ERROR = 'mae'
ERROR_METRIC = "image_mean"
MASK_AFTER = True
REMOVE_EDGES = True
SIZE = 256

# Options passed unchanged to both the ROC and the precision-recall comparison.
shared_options = dict(
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    mask_after=MASK_AFTER,
    remove_edges=REMOVE_EDGES,
    mask_resize=SIZE,
    MASK_FOLDER=MASK_FOLDER,
)
stage_names = ["Early", "Mid", "Late"]
stage_loaders = [dataloader_early_diseased_hsi, dataloader_mid_diseased_hsi, dataloader_late_diseased_hsi]

# The healthy reference is the healthy test set (dataloader_test_hsi).
auc_results = compare_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_AUC_ROC.png',
    **shared_options,
)

pr_auc_results = compare_pr_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_AUC_PR.png',
    **shared_options,
)

**Reconstruction Error Visualization**

First visualize errors per band and then overall

In [None]:
# Per-band reconstruction error (MAE) on the training data
visu_error_per_band(
    model,
    dataloader_train_hsi,
    device,
    save_path=f'{save_path}/reconstruction_error_per_band_train.png',
    error_type='mae',
    vmax=0.2,
    minmax=False,
)

In [None]:
# Visualization with edges removed (per the original note, visu_reconstruction_error
# would include the edges). One figure set per data split / disease stage.
reconstruction_runs = [
    ("validation", dataloader_validation_hsi),
    ("test", dataloader_test_hsi),
    ("early", dataloader_early_diseased_hsi),
    ("mid", dataloader_mid_diseased_hsi),
    ("late", dataloader_late_diseased_hsi),
]
for run_tag, run_loader in reconstruction_runs:
    visu_reconstruction_error_edge_removed(
        model,
        run_loader,
        device,
        save_path=f'{save_path}/reconstruction_error_{run_tag}',
        error_type='mae',
        vmax=0.2,
        minmax=False,
    )

## Alternative Errors

1. Extreme pixel error rate per band
2. Latent space error
3. Disease specific error

To get Extreme pixel error per band:

-> calculate 99% quantile of validation dataset to get thresholds

In [None]:
# Per-band pixel thresholds taken at the 0.99 quantile on the healthy validation data
pix_threshold_band_noedge = get_pixel_threshold_per_band_edge_removal(
    model,
    dataloader_validation_hsi,
    device,
    MASK_FOLDER,
    quantile=0.99,
    error_metric="MAE",
)
# print the thresholds
print(pix_threshold_band_noedge)

In [None]:
# Summary statistics of the per-band thresholds
minimum = np.min(pix_threshold_band_noedge)
maximum = np.max(pix_threshold_band_noedge)
# quartiles in one call: 25th, 50th (median) and 75th percentile
p25, median, p75 = np.percentile(pix_threshold_band_noedge, [25, 50, 75])
# tails of the threshold distribution
other_quants = np.percentile(pix_threshold_band_noedge, [1, 5, 10, 90, 95, 99])

print(f'minimum: {minimum}, maximum: {maximum}, median: {median}, p25: {p25}, p75: {p75}')
print(other_quants)

minimum: 0.06094030287116785, maximum: 0.17580827474594152, median: 0.08709860444068918, p25: 0.07988026991486558, p75: 0.15381613880395906
[0.06137524 0.06311501 0.0639855  0.1699026  0.17530691 0.175708  ]


### Extreme pixel error rate per band 

In [None]:
# Metric configuration for the extreme pixel error rate per band
ERROR = 'mae'
ERROR_METRIC = "extreme_bands_99"
MASK_AFTER = True
REMOVE_EDGES = True
SIZE = 256

# Options passed unchanged to both the ROC and the precision-recall comparison;
# threshold_band carries the per-band thresholds computed on the validation data.
shared_options = dict(
    threshold_band=pix_threshold_band_noedge,
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    mask_after=MASK_AFTER,
    remove_edges=REMOVE_EDGES,
    mask_resize=SIZE,
    MASK_FOLDER=MASK_FOLDER,
)
stage_names = ["Early", "Mid", "Late"]
stage_loaders = [dataloader_early_diseased_hsi, dataloader_mid_diseased_hsi, dataloader_late_diseased_hsi]

# The healthy reference is the healthy test set (dataloader_test_hsi).
auc_results = compare_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_AUC_ROC.png',
    **shared_options,
)

pr_auc_results = compare_pr_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_AUC_PR.png',
    **shared_options,
)

In [None]:
# Function to conduct the statistical analysis on a band level
def stats_threshold_exceedance_dataset(
    model,
    dataloader,
    device,
    thresholds,
    MASK_FOLDER,
    bands_to_show=None
):
    """
    Per-band share of non-edge pixels whose absolute reconstruction error exceeds a threshold.

    For every image delivered by ``dataloader`` and every selected band ``b``, the
    percentage of valid (non-edge) pixels with ``|input - reconstruction| > thresholds[b]``
    is computed.

    Parameters
    ----------
    model : callable torch module
        Called as ``model(imgs.unsqueeze(1))``; the result is squeezed on dim 1 again.
    dataloader : iterable of ``(imgs, idxs)``
        ``imgs`` has shape [B, C, H, W]; ``idxs`` are indices into
        ``dataloader.dataset.img_paths``.
    device : torch device or str
        Device the images are moved to before the forward pass.
    thresholds : indexable
        ``thresholds[b]`` is the error threshold of band ``b``.
    MASK_FOLDER : str
        Folder holding one ``.png`` mask per image, named like the ``.hdf5`` file.
    bands_to_show : iterable of int, optional
        Band indices to evaluate. ``None`` evaluates all ``C`` bands.

    Returns
    -------
    dict
        Maps band -> {'mean': float, 'min': float, 'max': float,
        'all_percs': np.ndarray of shape (N_images,)}.
        If the dataloader yields no images, the statistics are NaN and ``all_percs``
        is empty; if additionally ``bands_to_show`` is None, an empty dict is returned.
    """
    model.eval()

    # Materialise the band selection once so it can be iterated repeatedly.
    band_ids = None if bands_to_show is None else list(bands_to_show)
    per_band_percs = None if band_ids is None else [[] for _ in band_ids]

    with torch.no_grad():
        for imgs, idxs in tqdm(dataloader, desc="Computing stats"):
            # imgs: [B, C, H, W]
            imgs = imgs.to(device)
            recons = model(imgs.unsqueeze(1)).squeeze(1)  # [B, C, H, W]

            B, C, H, W = imgs.shape
            if band_ids is None:
                # no explicit selection: evaluate every band of the input
                band_ids = list(range(C))
                per_band_percs = [[] for _ in band_ids]

            # one device -> host transfer per batch instead of one per image
            imgs_np = imgs.cpu().numpy()
            recons_np = recons.cpu().numpy()

            for bi in range(B):
                data_np = imgs_np[bi]       # [C, H, W]
                recon_np = recons_np[bi]    # [C, H, W]
                idx_img = idxs[bi].item()

                # mask file shares the image's base name, with a .png extension
                mask_path = os.path.join(
                    MASK_FOLDER,
                    os.path.basename(dataloader.dataset.img_paths[idx_img])
                        .replace('.hdf5', '.png')
                )
                # context manager closes the file handle deterministically
                with Image.open(mask_path) as mask_file:
                    mask = mask_file.convert('L').resize((W, H), resample=Image.NEAREST)
                mask = np.array(mask)
                edge = extract_full_edge(mask)
                # assumes extract_full_edge marks edge pixels with 1 and others with 0 - TODO confirm
                valid = (1 - edge).astype(bool)

                # for each band, compute % exceedance
                for ib, b in enumerate(band_ids):
                    err = np.abs(data_np[b] - recon_np[b])
                    valid_err = err[valid]
                    if valid_err.size == 0:
                        perc = 0.0
                    else:
                        exceed = (valid_err > thresholds[b]).sum()
                        perc = 100.0 * exceed / valid_err.size

                    per_band_percs[ib].append(perc)

    if band_ids is None:
        # empty dataloader and no explicit band list: there is nothing to report
        return {}

    results = {}
    for ib, b in enumerate(band_ids):
        arr = np.array(per_band_percs[ib], dtype=float)
        if arr.size == 0:
            # no images were processed; min()/max() of an empty array would raise
            mean_val = min_val = max_val = float('nan')
        else:
            mean_val = float(arr.mean())
            min_val = float(arr.min())
            max_val = float(arr.max())
        results[b] = {
            'mean': mean_val,
            'min':  min_val,
            'max':  max_val,
            'all_percs': arr
        }

    return results

**Compute Mann-Whitney U test**

Shows in which bands the distribution of the extreme error rate differs significantly between healthy and diseased leaves.

In [None]:
bands = list(range(17)) # for SNV model choose 210 here
print(bands)

# Per-band exceedance statistics for the healthy test set and each disease stage.
stage_loaders = {
    "healthy": dataloader_test_hsi,
    "early": dataloader_early_diseased_hsi,
    "mid": dataloader_mid_diseased_hsi,
    "late": dataloader_late_diseased_hsi,
}
stage_stats = {
    stage: stats_threshold_exceedance_dataset(
        model=model,
        dataloader=stage_loader,
        device=device,
        thresholds=pix_threshold_band_noedge,
        MASK_FOLDER=MASK_FOLDER,
        bands_to_show=bands     # or None for all bands
    )
    for stage, stage_loader in stage_loaders.items()
}
stats_healthy = stage_stats["healthy"]
stats_early = stage_stats["early"]
stats_mid = stage_stats["mid"]
stats_late = stage_stats["late"]


# apply the Mann-Whitney U test to see if the error rates are significantly different distributed
records = []
labels = ["healthy", "early", "mid", "late"]
dicts = [stats_healthy, stats_early, stats_mid, stats_late]

# Bonferroni factor = number of healthy-vs-stage comparisons per band. The correction is
# applied within each band only; it does not adjust for testing several bands.
n_comparisons = len(labels) - 1

# Iterate over `bands` (not a hard-coded range) so the test follows the band selection above.
for band in bands:
    grp_healthy = stats_healthy[band]['all_percs']
    for label, stats_dict in zip(labels[1:], dicts[1:]):
        grp_sub = stats_dict[band]['all_percs']
        stat, p = mannwhitneyu(grp_healthy, grp_sub, alternative='two-sided')
        p_corr = min(p * n_comparisons, 1.0)
        records.append({
            'Band': band,
            'Comparison': f"healthy vs {label}",
            'U': stat,
            'p-value (Bonf)': p_corr
        })

df = pd.DataFrame(records).set_index(['Band', 'Comparison'])


# Get only the bands that are significant
significant = df[df["p-value (Bonf)"]<0.05]
print(significant)

**Visualize Errors above the 99% quantile**

Pixels exceeding the 99% quantile of a band are shown in red.

In [None]:
# One figure per (stage, image, band): stages are processed in the order
# healthy -> early -> mid -> late, each with 3 images and 17 bands.
exceedance_runs = [
    ("healthy", dataloader_test_hsi),
    ("early", dataloader_early_diseased_hsi),
    ("mid", dataloader_mid_diseased_hsi),
    ("late", dataloader_late_diseased_hsi),
]

for stage_tag, stage_loader in exceedance_runs:
    for img_idx in range(3):
        for band_idx in range(17):
            visualize_threshold_exceedance_edge_removal(
                model=model,
                dataloader=stage_loader,
                device=device,
                thresholds=pix_threshold_band_noedge,
                save_path=f'{save_path}/pixel_exceeding_threshold_{stage_tag}_image{img_idx}_band{band_idx}.png',
                MASK_FOLDER=MASK_FOLDER,
                # optional:
                band_colors={band: [1, 0, 0] for band in range(17)},  # all red overlays
                select_img=img_idx,        # which image to plot (0, 1, 2)
                bands_to_show=[band_idx],  # a single band per figure
            )

### Latent space error

In [None]:
# Metric configuration for the latent space error
ERROR = 'mae'
ERROR_METRIC = "latent"
MASK_AFTER = True
REMOVE_EDGES = True
SIZE = 256

# Options passed unchanged to both the ROC and the precision-recall comparison.
shared_options = dict(
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    mask_after=MASK_AFTER,
    remove_edges=REMOVE_EDGES,
    mask_resize=SIZE,
    MASK_FOLDER=MASK_FOLDER,
)
stage_names = ["Early", "Mid", "Late"]
stage_loaders = [dataloader_early_diseased_hsi, dataloader_mid_diseased_hsi, dataloader_late_diseased_hsi]

# The healthy reference is the healthy test set (dataloader_test_hsi).
auc_results = compare_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_benfenati_AUC_ROC.png',
    **shared_options,
)

pr_auc_results = compare_pr_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_benfenati_AUC_PR.png',
    **shared_options,
)

### Disease-specific error metric (Vein errors)

In [None]:
# Metric configuration for the disease-specific (vein) error
ERROR = 'mae'
ERROR_METRIC = "veins"
MASK_AFTER = True
REMOVE_EDGES = True
SIZE = 256

# Options passed unchanged to both the ROC and the precision-recall comparison.
shared_options = dict(
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    mask_after=MASK_AFTER,
    remove_edges=REMOVE_EDGES,
    mask_resize=SIZE,
    MASK_FOLDER=MASK_FOLDER,
)
stage_names = ["Early", "Mid", "Late"]
stage_loaders = [dataloader_early_diseased_hsi, dataloader_mid_diseased_hsi, dataloader_late_diseased_hsi]

# The healthy reference is the healthy test set (dataloader_test_hsi).
auc_results = compare_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_AUC_ROC.png',
    **shared_options,
)

pr_auc_results = compare_pr_auc_across_stages(
    model,
    dataloader_test_hsi,
    list(stage_loaders),
    list(stage_names),
    device,
    save_path=f'{save_path}/{ERROR_METRIC}_AUC_PR.png',
    **shared_options,
)

**Visualizations**

First visualize vein detection, then the errors in the veins.

Taken from vein_detection.py

In [None]:
# Vein detection on the validation data; figures are written with the prefix below.
vein_detection(
    dataloader_validation_hsi,
    device,
    MASK_FOLDER,
    save_path=f'{save_path}/veins_validation_',
    plot=True,
)

In [None]:
# Labels passed to visu_vein_error: 0 for the healthy splits, 1 for the diseased stages.
label_test = 0
label_disease = 1

vein_error_runs = [
    ("validation", dataloader_validation_hsi, label_test),
    ("healthy", dataloader_test_hsi, label_test),
    ("early", dataloader_early_diseased_hsi, label_disease),
    ("mid", dataloader_mid_diseased_hsi, label_disease),
    ("late", dataloader_late_diseased_hsi, label_disease),
]
for run_tag, run_loader, run_label in vein_error_runs:
    visu_vein_error(
        model,
        run_loader,
        run_label,
        device,
        MASK_FOLDER,
        save_path=f'{save_path}/vein_errors_{run_tag}_',
        plot=True,
    )

**Shows training and validation loss of model**

In [None]:
# Plot the training and validation loss curves restored from the checkpoint.
# NOTE(review): the third argument is hard-coded to 166, while the recorded checkpoint
# output reports "Last Epoch: 163" - confirm what plot_losses expects here and whether
# it should be derived from last_epoch or len(train_losses) instead.
plot_losses(train_losses, val_losses, 166, save_path)

## Stacked Model analysis

In [None]:
# Ensure required imports
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

# Evaluation settings for the stacked model (image-level mean error).
ERROR = 'mae'                # reconstruction error type: 'mae' or 'mse'
ERROR_METRIC = 'image_mean'  # scoring metric: 'image_mean', 'extreme_bands_99', 'latent', or 'veins'
MASK_FOLDER = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_masks/'
SAVE_DIR = save_path         # output directory for the plots; must already exist

# Inputs shared by the ROC and PR comparisons (healthy test set vs. each disease stage).
stage_loaders = (
    dataloader_early_diseased_hsi,
    dataloader_mid_diseased_hsi,
    dataloader_late_diseased_hsi,
)
stage_names = ("Early", "Mid", "Late")
shared_args = dict(
    model=model,
    healthy_dataloader=dataloader_test_hsi,
    device=device,
    thresholds=None,
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    MASK_FOLDER=MASK_FOLDER,
)

# ROC AUC per disease stage (stacked inputs). Fresh lists are built for every call.
auc_results = compare_stacked_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_ROC.png',
    **shared_args,
)
print("ROC AUC results:", auc_results)

# PR AUC per disease stage (stacked inputs).
pr_auc_results = compare_stacked_pr_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_PR.png',
    **shared_args,
)
print("PR AUC results:", pr_auc_results)


**Warning: the `extreme_bands_99` metric runs out of memory easily.**

In [None]:

# Evaluation settings for the stacked model (extreme-band error metric).
ERROR = 'mae'                      # reconstruction error type: 'mae' or 'mse'
ERROR_METRIC = 'extreme_bands_99'  # scoring metric: 'image_mean', 'extreme_bands_99', 'latent', or 'veins'
MASK_FOLDER = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_masks/'
SAVE_DIR = save_path               # output directory for the plots; must already exist

# Thresholds are obtained from the validation loader and forwarded to both comparisons.
thresholds = get_stacked_error_thresholds(model, dataloader_validation_hsi, device, MASK_FOLDER)

# Inputs shared by the ROC and PR comparisons (healthy test set vs. each disease stage).
stage_loaders = (
    dataloader_early_diseased_hsi,
    dataloader_mid_diseased_hsi,
    dataloader_late_diseased_hsi,
)
stage_names = ("Early", "Mid", "Late")
shared_args = dict(
    model=model,
    healthy_dataloader=dataloader_test_hsi,
    device=device,
    thresholds=thresholds,
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    MASK_FOLDER=MASK_FOLDER,
)

# ROC AUC per disease stage (stacked inputs). Fresh lists are built for every call.
auc_results = compare_stacked_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_ROC.png',
    **shared_args,
)
print("ROC AUC results:", auc_results)

# PR AUC per disease stage (stacked inputs).
pr_auc_results = compare_stacked_pr_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_PR.png',
    **shared_args,
)
print("PR AUC results:", pr_auc_results)


In [None]:
# Evaluation settings for the stacked model (latent-space metric).
ERROR = 'mae'            # reconstruction error type: 'mae' or 'mse'
ERROR_METRIC = 'latent'  # scoring metric: 'image_mean', 'extreme_bands_99', 'latent', or 'veins'
MASK_FOLDER = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_masks/'
SAVE_DIR = save_path     # output directory for the plots; must already exist

# Inputs shared by the ROC and PR comparisons (healthy test set vs. each disease stage).
stage_loaders = (
    dataloader_early_diseased_hsi,
    dataloader_mid_diseased_hsi,
    dataloader_late_diseased_hsi,
)
stage_names = ("Early", "Mid", "Late")
shared_args = dict(
    model=model,
    healthy_dataloader=dataloader_test_hsi,
    device=device,
    thresholds=None,
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    MASK_FOLDER=MASK_FOLDER,
)

# ROC AUC per disease stage (stacked inputs). Fresh lists are built for every call.
auc_results = compare_stacked_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_ROC.png',
    **shared_args,
)
print("ROC AUC results:", auc_results)

# PR AUC per disease stage (stacked inputs).
pr_auc_results = compare_stacked_pr_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_PR.png',
    **shared_args,
)
print("PR AUC results:", pr_auc_results)


In [None]:
# Evaluation settings for the stacked model (vein-based metric).
ERROR = 'mae'           # reconstruction error type: 'mae' or 'mse'
ERROR_METRIC = 'veins'  # scoring metric: 'image_mean', 'extreme_bands_99', 'latent', or 'veins'
MASK_FOLDER = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_masks/'
SAVE_DIR = save_path    # output directory for the plots; must already exist

# Inputs shared by the ROC and PR comparisons (healthy test set vs. each disease stage).
stage_loaders = (
    dataloader_early_diseased_hsi,
    dataloader_mid_diseased_hsi,
    dataloader_late_diseased_hsi,
)
stage_names = ("Early", "Mid", "Late")
shared_args = dict(
    model=model,
    healthy_dataloader=dataloader_test_hsi,
    device=device,
    thresholds=None,
    error_metric=ERROR_METRIC,
    error_type=ERROR,
    MASK_FOLDER=MASK_FOLDER,
)

# ROC AUC per disease stage (stacked inputs). Fresh lists are built for every call.
auc_results = compare_stacked_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_ROC.png',
    **shared_args,
)
print("ROC AUC results:", auc_results)

# PR AUC per disease stage (stacked inputs).
pr_auc_results = compare_stacked_pr_auc_across_stages(
    diseased_dataloaders=list(stage_loaders),
    stages_labels=list(stage_names),
    save_path=f'{SAVE_DIR}/{ERROR_METRIC}_AUC_PR.png',
    **shared_args,
)
print("PR AUC results:", pr_auc_results)


# MAX threshold

The following shows the separation results when the maximum validation value is chosen as the threshold.

In [None]:
# Image-wise aggregated reconstruction error
# Seed the CPU and all CUDA generators so any randomness in the run is repeatable.
torch.manual_seed(10)
torch.cuda.manual_seed_all(10)
# NOTE(review): the surrounding text speaks of the maximum *validation* value, but the
# loader passed as the second argument here is dataloader_test_hsi — confirm which split
# get_recon_error_threshold is meant to derive the threshold from.
# `file_path` receives the output directory used throughout this notebook.
threshold_max = get_recon_error_threshold(model, dataloader_test_hsi, dataloader_early=dataloader_early_diseased_hsi,
                                      dataloader_mid=dataloader_mid_diseased_hsi, dataloader_late=dataloader_late_diseased_hsi, device=device, file_path = save_path)
print(threshold_max)

In [92]:
# Notebook-level settings; elsewhere in this notebook they are forwarded to the stage
# comparison helpers as error_type, mask_after, remove_edges and mask_resize respectively.
ERROR = 'mae'
MASK_AFTER = False
REMOVE_EDGES = True   
SIZE = 256

## Outdated Visualizations

**FURTHER VISUALIZATIONS NOT SHOWN IN THE THESIS FOR PERSONAL DATA EXPLORATION**

In [None]:
# model.load_state_dict(torch.load("/home/r0979317/Documents/Thesis_Strawberries/models/third_fc_model.pth", map_location="cuda" if torch.cuda.is_available() else "cpu"))
# Run the helpers on whatever device the model parameters currently live on.
device = next(model.parameters()).device
print("Model loaded and on device:", device)

# Data groups to plot, keyed by the name used in plot titles and file names.
datasets = {
    "Healthy": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# One normalized histogram of pixel-wise reconstruction errors per group.
for name, loader in datasets.items():
    print(f"Processing {name}...")
    errors = get_pixel_reconstruction_errors(model, loader, device)

    # Limit the x-range to the 99th percentile so a few outliers do not flatten the plot.
    upper = np.percentile(errors, 99)

    plt.figure(figsize=(8, 5))
    plt.hist(errors, bins=100, range=(0, upper), color='skyblue', edgecolor='black', density=True)
    plt.title(f"Reconstruction Error Distribution - {name}")
    plt.xlabel("Pixelwise Reconstruction Error")
    plt.ylabel("Frequency (normalized)")
    plt.grid(True)
    plt.tight_layout()
    # save_path is a directory (it is joined with '/' everywhere else in this notebook);
    # without the separator the file would be written next to that directory instead of inside it.
    plt.savefig(f'{save_path}/pixelwise_bar_{name}.png')
    plt.show()

In [None]:
# Data groups to compare, keyed by the tick label used in the violin plot.
datasets = {
    "Healthy": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# Step 1: collect the pixel-wise reconstruction errors of every group.
all_errors = {}
for name, loader in datasets.items():
    print(f"Processing {name}...")
    all_errors[name] = get_pixel_reconstruction_errors(model, loader, device)

lower_percentile = 5
upper_percentile = 95

# Step 2: global bounds — the smallest lower percentile and the largest upper percentile
# over all groups — then keep only the values inside them (values are dropped, not saturated).
clip_lower = min(np.percentile(errors, lower_percentile) for errors in all_errors.values())
clip_upper = max(np.percentile(errors, upper_percentile) for errors in all_errors.values())

labels = list(all_errors)
error_data = [
    errors[(errors >= clip_lower) & (errors <= clip_upper)]
    for errors in all_errors.values()
]

# Step 3: violin plot, one violin per group.
plt.figure(figsize=(10, 6))
plt.violinplot(error_data, showmeans=True, showmedians=True)
plt.xticks(ticks=np.arange(1, len(labels) + 1), labels=labels, rotation=30)
plt.title("Pixelwise Reconstruction Error Distribution")
plt.ylabel("Pixelwise Reconstruction Error")
plt.grid(True)
plt.tight_layout()
# save_path is a directory (joined with '/' elsewhere in this notebook); the separator
# keeps the figure inside it rather than next to it.
plt.savefig(f'{save_path}/pixelwise_violin_quantile_{lower_percentile}_{upper_percentile}.png')
plt.show()

In [None]:
# Data groups to compare, keyed by the tick label used in the violin plot.
datasets = {
    "Healthy": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# Step 1: collect the pixel-wise reconstruction errors of every group.
all_errors = {}
for name, loader in datasets.items():
    print(f"Processing {name}...")
    all_errors[name] = get_pixel_reconstruction_errors(model, loader, device)

# Upper tail only: from the 95th percentile up to the maximum.
lower_percentile = 95
upper_percentile = 100

# Step 2: global bounds — the smallest lower percentile and the largest upper percentile
# over all groups — then keep only the values inside them (values are dropped, not saturated).
clip_lower = min(np.percentile(errors, lower_percentile) for errors in all_errors.values())
clip_upper = max(np.percentile(errors, upper_percentile) for errors in all_errors.values())

labels = list(all_errors)
error_data = [
    errors[(errors >= clip_lower) & (errors <= clip_upper)]
    for errors in all_errors.values()
]

# Step 3: violin plot, one violin per group.
plt.figure(figsize=(10, 6))
plt.violinplot(error_data, showmeans=True, showmedians=True)
plt.xticks(ticks=np.arange(1, len(labels) + 1), labels=labels, rotation=30)
plt.title("Pixelwise Reconstruction Error Distribution")
plt.ylabel("Pixelwise Reconstruction Error")
plt.grid(True)
plt.tight_layout()
# save_path is a directory (joined with '/' elsewhere in this notebook); the separator
# keeps the figure inside it rather than next to it.
plt.savefig(f'{save_path}/pixelwise_violin_quantile_{lower_percentile}_{upper_percentile}.png')
plt.show()

In [None]:
# Groups shown in the per-band plots. Defined locally so this cell does not depend on
# cells further down the notebook and still runs after Restart & Run All.
dataloader_list = [dataloader_test_hsi, dataloader_early_diseased_hsi, dataloader_mid_diseased_hsi, dataloader_late_diseased_hsi]
group_labels = ['Healthy', 'Early Diseased', 'Mid Diseased', 'Severely Diseased']
colors = ['green', 'yellow', 'orange', 'red']

# NOTE(review): the per-band thresholds are derived from the early-diseased loader
# (quantile 0.75), not from healthy data — confirm this is intended.
thresholds = get_pixel_threshold_per_band(model, dataloader_early_diseased_hsi, device, quantile=0.75)

# Per-image, per-band scores for every group. The early-diseased loader is part of
# dataloader_list, so no separate pass over it is needed.
all_scores_per_band = [
    classify_leaves_per_band(model, dataloader, device, thresholds)
    for dataloader in dataloader_list
]

# One figure per band, one scatter column per group.
n_bands = len(thresholds)
for band_idx in range(n_bands):
    plt.figure(figsize=(8, 4))

    for group_idx, group_scores in enumerate(all_scores_per_band):
        # Scores of this band for all images in the group
        band_scores = [img_scores[band_idx] for img_scores in group_scores]
        plt.scatter(
            [group_labels[group_idx]] * len(band_scores),
            band_scores,
            label=group_labels[group_idx],
            color=colors[group_idx],
            alpha=0.7
        )

    #plt.axhline(y=thresholds[band_idx], color='red', linestyle='--', label=f'Threshold ({thresholds[band_idx]:.2f})')
    plt.title(f"Band {band_idx+1} - Per-image Error Scores by Group")
    plt.ylabel("Aggregated Error (Above Threshold)")
    plt.xlabel("Group")
    plt.legend()
    plt.tight_layout()
    # save_path is a directory (joined with '/' elsewhere in this notebook).
    plt.savefig(f'{save_path}/pixelwise_perimage_band{band_idx+1}.png')
    plt.show()
    plt.close()  # release the figure; this loop creates one figure per band


### Pixelwise reconstruction error

#### Plot the pixel-wise error per data group

Plot the distribution (normalized frequency) of the pixel-wise reconstruction error for each data group.

In [None]:
# model.load_state_dict(torch.load("/home/r0979317/Documents/Thesis_Strawberries/models/third_fc_model.pth", map_location="cuda" if torch.cuda.is_available() else "cpu"))
# Run the helpers on whatever device the model parameters currently live on.
device = next(model.parameters()).device
print("Model loaded and on device:", device)

# Data groups to plot, keyed by the name used in plot titles and file names.
datasets = {
    "Healthy": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# One normalized histogram of pixel-wise reconstruction errors per group.
for name, loader in datasets.items():
    print(f"Processing {name}...")
    errors = get_pixel_reconstruction_errors(model, loader, device)

    # Limit the x-range to the 99th percentile so a few outliers do not flatten the plot.
    upper = np.percentile(errors, 99)

    plt.figure(figsize=(8, 5))
    plt.hist(errors, bins=100, range=(0, upper), color='skyblue', edgecolor='black', density=True)
    plt.title(f"Reconstruction Error Distribution - {name}")
    plt.xlabel("Pixelwise Reconstruction Error")
    plt.ylabel("Frequency (normalized)")
    plt.grid(True)
    plt.tight_layout()
    # save_path is a directory (it is joined with '/' everywhere else in this notebook);
    # without the separator the file would be written next to that directory instead of inside it.
    plt.savefig(f'{save_path}/pixelwise_bar_{name}.png')
    plt.show()

Plot the same data as a Violin plot

In [None]:
# Data groups to compare, keyed by the tick label used in the violin plot.
datasets = {
    "Healthy": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# Step 1: collect the pixel-wise reconstruction errors of every group.
all_errors = {}
for name, loader in datasets.items():
    print(f"Processing {name}...")
    all_errors[name] = get_pixel_reconstruction_errors(model, loader, device)

lower_percentile = 5
upper_percentile = 95

# Step 2: global bounds — the smallest lower percentile and the largest upper percentile
# over all groups — then keep only the values inside them (values are dropped, not saturated).
clip_lower = min(np.percentile(errors, lower_percentile) for errors in all_errors.values())
clip_upper = max(np.percentile(errors, upper_percentile) for errors in all_errors.values())

labels = list(all_errors)
error_data = [
    errors[(errors >= clip_lower) & (errors <= clip_upper)]
    for errors in all_errors.values()
]

# Step 3: violin plot, one violin per group.
plt.figure(figsize=(10, 6))
plt.violinplot(error_data, showmeans=True, showmedians=True)
plt.xticks(ticks=np.arange(1, len(labels) + 1), labels=labels, rotation=30)
plt.title("Pixelwise Reconstruction Error Distribution")
plt.ylabel("Pixelwise Reconstruction Error")
plt.grid(True)
plt.tight_layout()
# save_path is a directory (joined with '/' elsewhere in this notebook); the separator
# keeps the figure inside it rather than next to it.
plt.savefig(f'{save_path}/pixelwise_violin_quantile_{lower_percentile}_{upper_percentile}.png')
plt.show()

In [None]:
# Data groups to compare, keyed by the tick label used in the violin plot.
datasets = {
    "Healthy": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# Step 1: collect the pixel-wise reconstruction errors of every group.
all_errors = {}
for name, loader in datasets.items():
    print(f"Processing {name}...")
    all_errors[name] = get_pixel_reconstruction_errors(model, loader, device)

# Upper tail only: from the 95th percentile up to the maximum.
lower_percentile = 95
upper_percentile = 100

# Step 2: global bounds — the smallest lower percentile and the largest upper percentile
# over all groups — then keep only the values inside them (values are dropped, not saturated).
clip_lower = min(np.percentile(errors, lower_percentile) for errors in all_errors.values())
clip_upper = max(np.percentile(errors, upper_percentile) for errors in all_errors.values())

labels = list(all_errors)
error_data = [
    errors[(errors >= clip_lower) & (errors <= clip_upper)]
    for errors in all_errors.values()
]

# Step 3: violin plot, one violin per group.
plt.figure(figsize=(10, 6))
plt.violinplot(error_data, showmeans=True, showmedians=True)
plt.xticks(ticks=np.arange(1, len(labels) + 1), labels=labels, rotation=30)
plt.title("Pixelwise Reconstruction Error Distribution")
plt.ylabel("Pixelwise Reconstruction Error")
plt.grid(True)
plt.tight_layout()
# save_path is a directory (joined with '/' elsewhere in this notebook); the separator
# keeps the figure inside it rather than next to it.
plt.savefig(f'{save_path}/pixelwise_violin_quantile_{lower_percentile}_{upper_percentile}.png')
plt.show()

#### Pixel error classification and visualization per image

In [None]:
# Pixel-error threshold obtained from the healthy test loader with quantile=0.99.
# It is reused below both as the classification threshold and as the horizontal
# reference line in the per-image scatter plot.
pix_error = get_pixel_error_threshold(model, dataloader_test_hsi, device, quantile = 0.99)

In [None]:
# Score the early-diseased images with the aggregate variant of the classifier,
# using the pixel-error threshold computed above.
score_early = classify_leaves_pixel_error_aggregate(model, dataloader_early_diseased_hsi, device, pix_error)
# classify_leaves_pixel_error_mean classifies based on the mean error per image

In [None]:
# All data groups to be considered, in plotting order: healthy test set, then the early,
# mid and late diseased sets. If this list changes, update `colors` and `group_labels`
# as well so the three stay index-aligned.
dataloader_list = [dataloader_test_hsi, dataloader_early_diseased_hsi, dataloader_mid_diseased_hsi, dataloader_late_diseased_hsi]

In [None]:
# Per-image mean-error scores, one entry per group in dataloader_list order.
errors = [
    classify_leaves_pixel_error_mean(model, loader, device, pix_error)
    for loader in dataloader_list
]

In [None]:
# Colors and labels of the groups. Both lists are indexed in parallel with
# dataloader_list (healthy, early, mid, late), so keep all three in the same order.
colors = ['green', 'yellow', 'orange', 'red']
group_labels = ['Healthy', 'Early Diseased','Mid Diseased', 'Severely Diseased']

In [None]:
# Overall per-image error: one scatter column per group, with the pixel-error
# threshold drawn as a dashed reference line.
plt.figure(figsize=(10, 5))
for i, (scores, color, label) in enumerate(zip(errors, colors, group_labels)):
    x = [i] * len(scores)  # all images of a group share the same x position
    plt.scatter(x, scores, color=color, label=label, alpha=0.7)

plt.xticks(ticks=range(len(group_labels)), labels=group_labels)
plt.axhline(y=pix_error, color='red', linestyle='--', label=f'Threshold ({pix_error:.2f})')
plt.ylabel("Aggregated Error (Above Threshold)")
plt.title("Per-image Error Scores by Group")
plt.legend()
plt.grid(True)
# save_path is a directory (joined with '/' elsewhere in this notebook); the separator
# keeps the figure inside it rather than next to it.
plt.savefig(f'{save_path}/pixelwise_image.png')
plt.show()

#### Pixel error classification and visualization per image and band

In [None]:
# Per-band thresholds (quantile=0.75) and per-image, per-band scores for the
# early-diseased group.
# NOTE(review): the thresholds are derived from the early-diseased loader rather than
# from healthy data — confirm this is intended.
thresholds = get_pixel_threshold_per_band(model, dataloader_early_diseased_hsi, device, quantile=0.75)
scores = classify_leaves_per_band(model, dataloader_early_diseased_hsi, device, thresholds)

# Transpose for plotting (bands × images)
scores_by_band = list(zip(*scores))  # list of 17 lists

In [None]:
# NOTE(review): the per-band thresholds are derived from the early-diseased loader
# (quantile 0.75), not from healthy data — confirm this is intended.
thresholds = get_pixel_threshold_per_band(model, dataloader_early_diseased_hsi, device, quantile=0.75)

# Per-image, per-band scores for every group. The early-diseased loader is part of
# dataloader_list, so no separate pass over it is needed.
all_scores_per_band = [
    classify_leaves_per_band(model, dataloader, device, thresholds)
    for dataloader in dataloader_list
]

# One figure per band, one scatter column per group.
n_bands = len(thresholds)
for band_idx in range(n_bands):
    plt.figure(figsize=(8, 4))

    for group_idx, group_scores in enumerate(all_scores_per_band):
        # Scores of this band for all images in the group
        band_scores = [img_scores[band_idx] for img_scores in group_scores]
        plt.scatter(
            [group_labels[group_idx]] * len(band_scores),
            band_scores,
            label=group_labels[group_idx],
            color=colors[group_idx],
            alpha=0.7
        )

    #plt.axhline(y=thresholds[band_idx], color='red', linestyle='--', label=f'Threshold ({thresholds[band_idx]:.2f})')
    plt.title(f"Band {band_idx+1} - Per-image Error Scores by Group")
    plt.ylabel("Aggregated Error (Above Threshold)")
    plt.xlabel("Group")
    plt.legend()
    plt.tight_layout()
    # save_path is a directory (joined with '/' elsewhere in this notebook).
    plt.savefig(f'{save_path}/pixelwise_perimage_band{band_idx+1}.png')
    plt.show()
    plt.close()  # release the figure; this loop creates one figure per band


**COMPUTES VIOLIN PLOT PER BAND**

In [None]:
def get_pixel_errors_per_band(model, dataloader, device):
    """Collect squared reconstruction errors of every pixel, grouped by band.

    Each batch yielded by ``dataloader`` is a ``(images, _)`` pair; the second
    element is ignored. A singleton axis is inserted at position 1 before the
    images are fed to ``model``, and removed again from the squared difference
    between reconstruction and input.

    Args:
        model: Module called on the expanded batch; switched to eval mode first.
        dataloader: Iterable of ``(images, _)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        CPU tensor with one row per band (axis 1 of the per-batch error) holding
        the errors of all remaining positions of all batches, concatenated in
        iteration order.
    """
    model.eval()
    per_batch = []

    with torch.no_grad():
        for images, _ in dataloader:
            inputs = images.unsqueeze(1).to(device)
            reconstruction = model(inputs)
            squared = torch.square(reconstruction - inputs).squeeze(1)

            # Bring the band axis to the front and flatten everything else into one axis.
            n_bands = squared.shape[1]
            flattened = squared.transpose(0, 1).reshape(n_bands, -1)
            per_batch.append(flattened.cpu())

    # Join the batches along the flattened pixel axis.
    return torch.cat(per_batch, dim=1)

In [None]:
# Groups for the per-band violin plots, in plotting order.
datasets = {
    "Validation": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# One per-band error tensor per group, in the same order as `datasets`.
all_pixel_errors_per_band = []
for name, loader in datasets.items():
    print(f"Processing {name}...")
    all_pixel_errors_per_band.append(get_pixel_errors_per_band(model, loader, device))


In [None]:
# Number of bands = first axis of the per-group error tensors.
n_bands = all_pixel_errors_per_band[0].shape[0]

# Percentile bounds used to tame outliers; the same for every band.
lower_bound = 5
upper_bound = 95

for band_idx in range(n_bands):
    plt.figure(figsize=(10, 6))

    # Errors of this band, one array per group.
    band_group_errors = [
        errors_tensor[band_idx].numpy() for errors_tensor in all_pixel_errors_per_band
    ]

    # Both percentiles come from a single pass over the pooled errors of all groups.
    pooled_errors = np.concatenate(band_group_errors)
    lower_clip, upper_clip = np.percentile(pooled_errors, [lower_bound, upper_bound])

    # np.clip saturates out-of-range values at the bounds (it does not drop them),
    # so the violins can show extra mass at both ends.
    band_group_errors_clipped = [np.clip(errors, lower_clip, upper_clip) for errors in band_group_errors]

    # Create violin plot
    parts = plt.violinplot(
        band_group_errors_clipped,
        showmeans=True,
        showmedians=True,
        showextrema=True
    )

    # Color each violin body like its group.
    for idx, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors[idx])
        pc.set_alpha(0.6)

    plt.xticks(
        ticks=np.arange(1, len(group_labels) + 1),
        labels=group_labels,
        rotation=30
    )
    # Draw threshold once per band
    # plt.axhline(y=thresholds[band_idx], color='red', linestyle='--', label=f'Threshold ({thresholds[band_idx]:.2f})')
    plt.title(f"Band {band_idx+1} - Pixelwise Reconstruction Errors by Group")
    plt.ylabel("Pixelwise Reconstruction Error")
    plt.xlabel("Group")
    plt.grid(True)
    plt.tight_layout()
    # save_path is a directory (joined with '/' elsewhere in this notebook).
    plt.savefig(f'{save_path}/pixelwise_perimage_violin_band{band_idx+1}_{lower_bound}_{upper_bound}.png')
    plt.show()
    plt.close()  # release the figure; this loop creates one figure per band

In [None]:
# Spectra visualization for the early-diseased loader. The helper returns three values,
# unpacked here as the error plus x and y coordinates (presumably of the sampled pixels —
# confirm in visualization_helper). Settings are fixed inline: 100 pixels, MAE,
# masking after reconstruction, edge removal and a mask size of 256.
error, x_coords, y_coords = visualize_pixel_spectra_combined_3d(
    model, 
    dataloader_early_diseased_hsi, 
    device, 
    n_pixels=100,
    error_type='mae',
    mask_after=True,
    remove_edges=True,
    mask_resize=256
)

### Inference into the latent space (UMAP)

In [34]:
# Latent representations, labels and images per group. Each group gets a fixed label:
# 0 = healthy (taken from the training loader), 1 = early, 3 = late.
# The mid-stage extraction (label 2) is currently disabled, so latent_mid, labels_mid
# and images_mid are NOT created by this cell.
latent_healthy, labels_healthy, images_healthy = get_lat_representations(model, dataloader_train_hsi, device, assigned_label=0)
latent_early, labels_early, images_early = get_lat_representations(model, dataloader_early_diseased_hsi, device, assigned_label=1)
#latent_mid, labels_mid, images_mid = get_lat_representations(model, dataloader_mid_diseased_hsi, device, assigned_label=2)
latent_late, labels_late, images_late = get_lat_representations(model, dataloader_late_diseased_hsi, device, assigned_label=3)

In [None]:
# Stack the groups that are actually extracted in this notebook (healthy, early, late).
# The mid-stage extraction is commented out, so latent_mid / labels_mid / images_mid do
# not exist on a fresh kernel and referencing them raises a NameError. To include the
# mid stage again, re-enable its get_lat_representations call and add the three *_mid
# variables back to the lists below (between early and late).
latent_all = np.concatenate([latent_healthy, latent_early, latent_late], axis=0)
labels_all = np.concatenate([labels_healthy, labels_early, labels_late], axis=0)
images_all = torch.cat([images_healthy, images_early, images_late], dim=0)

In [None]:
# Interactive UMAP of the latent vectors with the matching images and group labels.
# n_neighbors=3 and min_dist=0.05 override the defaults listed in the trailing comment
# (n_neighbors=15, min_dist=0.1).
plot_umap_interactive(latent_all, images_all, labels=labels_all, n_neighbors=3, min_dist=0.05) #latent_array, images, labels=None, RGB_bands=[9, 3, 5], image_scale=0.2, click_threshold=1.0, n_neighbors=15, min_dist=0.1