In [1]:
import os
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import numpy as np 
import joblib
from torch.utils.data import DataLoader
from preprocessing import *
from utils import *
from datasets import *
from CNN_AE_helper import *
from CNN3d import *
from torchvision.transforms import v2
from scipy.ndimage import binary_erosion
# Imported after the project star-imports so these bindings are the ones later cells see.
import torch.nn.functional as F
from sklearn.preprocessing import minmax_scale

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# FX10 camera
#IMG_DIR = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/'
IMG_DIR = '/home/u0158953/data/Strawberries/PotsprocessedData/cropped_hdf5'
CAMERA = 'FX10'

# (date stamps, tray ids) that define each disease stage.
# NOTE(review): tray '2D' is part of the healthy group because some files are mistakenly named
# 2D instead of 4D (the original note attributes this to the FX17 camera) — confirm it also
# applies to FX10 and that no genuine tray-2D plants are pulled in.
STAGE_SELECTION = {
    'Healthy': (
        ['07SEPT2023', '08SEPT2023', '09SEPT2023', '10SEPT2023', '11SEPT2023', '12SEPT2023',
         '13SEPT2023', '14SEPT2023', '15SEPT2023', '18SEPT2023', '19SEPT2023'],
        ['4C', '4D', '2D'],
    ),
    'Early diseased': (['07SEPT2023'], ['3C']),
    'Mid diseased': (['08SEPT2023', '09SEPT2023'], ['3C']),
    'Late diseased': (
        ['10SEPT2023', '11SEPT2023', '12SEPT2023', '13SEPT2023', '14SEPT2023', '15SEPT2023'],
        ['3C'],
    ),
}

# One filter_filenames call per stage instead of four copy-pasted cells.
stage_files = {
    stage: filter_filenames(folder_path=IMG_DIR, camera_id=CAMERA, date_stamps=dates, tray_ids=trays)
    for stage, (dates, trays) in STAGE_SELECTION.items()
}
healthy_FX10 = stage_files['Healthy']
early_diseased_FX10 = stage_files['Early diseased']
mid_diseased_FX10 = stage_files['Mid diseased']
late_diseased_FX10 = stage_files['Late diseased']

# Number of samples in each category
for stage, files in stage_files.items():
    print(f'{stage}: {len(files)}')

Healthy: 221
Early diseased: 12
Mid diseased: 25
Late diseased: 30


In [4]:
# Hold out 20 % of the healthy cubes for testing, then 20 % of the remainder for validation.
# A fixed random_state makes the split (and every downstream metric) reproducible across
# kernel restarts.
SPLIT_SEED = 42
train_val, test = train_test_split(healthy_FX10, test_size=0.20, random_state=SPLIT_SEED)
train, validation = train_test_split(train_val, test_size=0.2, random_state=SPLIT_SEED)

In [18]:
 # For now these only apply for the HSI data, but eventually we will only need that
INPUT_DATA = healthy_FX10[0:10]    # [0:40] just to speed up the process for now
# NOTE(review): INPUT_DATA is not used by the dataset cells below.

# Masks live next to the cropped cubes (<data root>/cropped_hdf5 -> <data root>/cropped_masks).
# Deriving the folder from IMG_DIR guarantees images and masks come from the same machine.
MASK_FOLDER = os.path.join(os.path.dirname(os.path.normpath(IMG_DIR)), 'cropped_masks')
BATCH_SIZE = 8
MASK_METHOD = 1    # 0 for only leaf, 1 for leaf+stem
BAND_SELECTION = [489.3, 505.1, 542.21, 550.2, 558.21, 582.31, 625.4, 660.62, 674.2, 679.64,
                  701.44, 717.81, 736.94, 745.15, 783.52, 849.54, 951.83]    # Important wavelengths obtained from pca_bandselect.ipynb
POLYORDER = 2
WINDOW_LENGTH = 4    # NOTE(review): Savitzky-Golay windows are usually odd — confirm an even length is accepted
PREPROCESS_METHOD = "normal"
#SCALER = joblib.load('models/pca/scaler_healthy.joblib')    # Only needed if we want to transform data to PCA space
#PCA = joblib.load('models/pca/pca_model_healthy.joblib')    # Only needed if we want to transform data to PCA space

# Same device selection as the training cell, so every cell agrees on where tensors live.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
print(f'Total number of GPUs: {torch.cuda.device_count()}')

Using cpu device
Total number of GPUs: 0


In [19]:
# Define data augmentations
# Training pipeline: random horizontal flip and random rotation in [0, 180] degrees (bilinear),
# then resize to 256x256 and convert to float32.
train_transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomRotation(degrees=(0, 180), interpolation=v2.InterpolationMode.BILINEAR),
        v2.Resize((256, 256)),  # bilinear interpolation by default
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Evaluation pipeline: the same resize / dtype steps without any random augmentation.
test_transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize((256, 256)),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

In [20]:
# Create Dataset and DataLoader for training.
# The preprocessing arguments match the validation / test / diseased datasets (PREPROCESS_METHOD
# and deriv=2); otherwise the autoencoder is trained on differently preprocessed spectra than it
# is evaluated on.
# NOTE(review): whether `deriv` has any effect when preprocess_method == "normal" depends on
# HsiDataset — confirm.
dataset_train_hsi = HsiDataset(
    train, MASK_FOLDER, transform=train_transforms,
    apply_mask=True, mask_method=MASK_METHOD, min_wavelength=430, normalize=True,
    selected_bands=BAND_SELECTION, polyorder=POLYORDER, deriv=2, window_length=WINDOW_LENGTH,
    preprocess_method=PREPROCESS_METHOD,
)
dataloader_train_hsi = DataLoader(dataset_train_hsi, batch_size=BATCH_SIZE, shuffle=True, collate_fn=padded_collate)

OSError: [Errno 22] Unable to synchronously open file (file read failed: time = Tue Apr 22 12:41:55 2025
, filename = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/FX10_13SEPT2023_4D5_3.hdf5', file descriptor = 3, errno = 22, error message = 'Invalid argument', buf = 00000046631EC180, total read size = 8, bytes this sub-read = 8, bytes actually read = 18446744073709551615, offset = 0)

In [21]:
# Un-normalised, all-band variant of the training set, used by the mean-spectrum plot below.
dataset_train_hsi_normal = HsiDataset(
    train,
    MASK_FOLDER,
    transform=train_transforms,
    apply_mask=True,
    mask_method=MASK_METHOD,
    min_wavelength=430,
    normalize=False,
    selected_bands=None,
    polyorder=POLYORDER,
    window_length=WINDOW_LENGTH,
    preprocess_method="normal",
)
dataloader_train_hsi_normal = DataLoader(
    dataset_train_hsi_normal,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=padded_collate,
)

OSError: [Errno 22] Unable to synchronously open file (file read failed: time = Tue Apr 22 12:42:03 2025
, filename = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/FX10_13SEPT2023_4D5_3.hdf5', file descriptor = 3, errno = 22, error message = 'Invalid argument', buf = 00000046631EC180, total read size = 8, bytes this sub-read = 8, bytes actually read = 18446744073709551615, offset = 0)

In [22]:
# Training set with the second-derivative (deriv=2) setting.
# NOTE(review): with PREPROCESS_METHOD == "normal" this may not actually apply Savitzky-Golay
# filtering, despite the variable name — confirm in HsiDataset.
dataset_train_hsi_savgol = HsiDataset(
    train,
    MASK_FOLDER,
    transform=train_transforms,
    apply_mask=True,
    mask_method=MASK_METHOD,
    min_wavelength=430,
    normalize=True,
    selected_bands=BAND_SELECTION,
    polyorder=POLYORDER,
    deriv=2,
    window_length=WINDOW_LENGTH,
    preprocess_method=PREPROCESS_METHOD,
)
dataloader_train_hsi_savgol = DataLoader(
    dataset_train_hsi_savgol,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=padded_collate,
)

OSError: [Errno 22] Unable to synchronously open file (file read failed: time = Tue Apr 22 12:42:04 2025
, filename = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/FX10_13SEPT2023_4D5_3.hdf5', file descriptor = 3, errno = 22, error message = 'Invalid argument', buf = 00000046631EC180, total read size = 8, bytes this sub-read = 8, bytes actually read = 18446744073709551615, offset = 0)

In [23]:
# Inspect raw (un-normalised, all-band) spectra: plot the spatially averaged spectrum of the
# first few images of one training batch.
MAX_SPECTRA = 10
data, _ = next(iter(dataloader_train_hsi_normal))  # first batch only; assumes (images, labels)
# assumes data is (B, C, H, W) with C spectral bands — TODO confirm against HsiDataset
# NOTE(review): the mean runs over every pixel; if masked-out background pixels are zero they
# pull the spectrum towards zero — average over leaf pixels only if absolute levels matter.
num_to_plot = min(MAX_SPECTRA, data.shape[0])  # a batch may hold fewer than MAX_SPECTRA images

fig, ax = plt.subplots(figsize=(12, 6))
for i in range(num_to_plot):
    mean_spectrum = data[i].detach().cpu().numpy().mean(axis=(1, 2))
    ax.plot(mean_spectrum, label=f"Image {i}")

ax.set_xlabel("Wavelength Index")
ax.set_ylabel("Mean Intensity")
ax.set_title(f"Mean Spectra of {num_to_plot} Images")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

NameError: name 'dataloader_train_hsi_normal' is not defined

In [24]:
# Create Dataset and DataLoader for validation.
# Validation must be deterministic: use the non-augmenting test_transforms (random flips and
# rotations would make the validation loss noisy between epochs) and a fixed order.
dataset_validation_hsi = HsiDataset(validation, MASK_FOLDER, transform=test_transforms, 
                               apply_mask=True, mask_method=MASK_METHOD, min_wavelength=430, normalize=True, selected_bands=BAND_SELECTION, 
                               polyorder=POLYORDER, deriv = 2, window_length=WINDOW_LENGTH, preprocess_method = PREPROCESS_METHOD)
dataloader_validation_hsi = DataLoader(dataset_validation_hsi, batch_size=BATCH_SIZE, shuffle=False, collate_fn=padded_collate)

OSError: [Errno 22] Unable to synchronously open file (file read failed: time = Tue Apr 22 12:42:06 2025
, filename = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/FX10_09SEPT2023_4D2_0.hdf5', file descriptor = 3, errno = 22, error message = 'Invalid argument', buf = 00000046631EC180, total read size = 8, bytes this sub-read = 8, bytes actually read = 18446744073709551615, offset = 0)

In [25]:
# Create Dataset and DataLoader for the held-out healthy test set.
# Test data gets no random augmentation (test_transforms) and is not shuffled, so evaluation
# results are reproducible.
dataset_test_hsi = HsiDataset(test, MASK_FOLDER, transform=test_transforms, 
                               apply_mask=True, mask_method=MASK_METHOD, min_wavelength=430, normalize=True, selected_bands=BAND_SELECTION, 
                               polyorder=POLYORDER, deriv = 2, window_length=WINDOW_LENGTH, preprocess_method = PREPROCESS_METHOD)
dataloader_test_hsi = DataLoader(dataset_test_hsi, batch_size=BATCH_SIZE, shuffle=False, collate_fn=padded_collate)

OSError: [Errno 22] Unable to synchronously open file (file read failed: time = Tue Apr 22 12:42:06 2025
, filename = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/FX10_15SEPT2023_4D4_5.hdf5', file descriptor = 3, errno = 22, error message = 'Invalid argument', buf = 00000046631EC180, total read size = 8, bytes this sub-read = 8, bytes actually read = 18446744073709551615, offset = 0)

In [26]:
def build_eval_loader(file_list):
    """Build an evaluation HsiDataset and an unshuffled DataLoader for ``file_list``.

    Uses test_transforms (no random augmentation) and the same HsiDataset preprocessing
    arguments as the validation / test cells.

    Returns:
        (dataset, dataloader)
    """
    dataset = HsiDataset(file_list, MASK_FOLDER, transform=test_transforms,
                         apply_mask=True, mask_method=MASK_METHOD, min_wavelength=430, normalize=True,
                         selected_bands=BAND_SELECTION, polyorder=POLYORDER, deriv=2,
                         window_length=WINDOW_LENGTH, preprocess_method=PREPROCESS_METHOD)
    # NOTE(review): unlike the healthy loaders this uses the default collate (no padded_collate);
    # confirm all cubes have the same shape after the transforms.
    return dataset, DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)


# Early / mid / late diseased
dataset_early_diseased_hsi, dataloader_early_diseased_hsi = build_eval_loader(early_diseased_FX10)
dataset_mid_diseased_hsi, dataloader_mid_diseased_hsi = build_eval_loader(mid_diseased_FX10)
dataset_late_diseased_hsi, dataloader_late_diseased_hsi = build_eval_loader(late_diseased_FX10)

OSError: [Errno 22] Unable to synchronously open file (file read failed: time = Tue Apr 22 12:42:07 2025
, filename = 'C:/Users/leonw/OneDrive - KU Leuven/Master Thesis/Data_cropped/cropped_hdf5/FX10_07SEPT2023_3C1_2.hdf5', file descriptor = 3, errno = 22, error message = 'Invalid argument', buf = 00000046631EC180, total read size = 8, bytes this sub-read = 8, bytes actually read = 18446744073709551615, offset = 0)

In [None]:
import cv2

# Heuristic per-pixel attention map for every cube in `batch`: normalised spectral standard
# deviation plus Canny edges of the band-averaged image, rescaled to [0, 1].
# NOTE(review): `batch` is not created in any earlier cell; define it explicitly (e.g. the first
# batch of a dataloader) so this cell survives Restart & Run All.
attn_masks = []
for img in batch:
    img = img.numpy()  # assumes one (C, H, W) cube per iteration — TODO confirm
    spectral_std = img.std(axis=0)
    gray = img.mean(axis=0)
    # Min-max scale to [0, 255] before the uint8 cast. Dividing by gray.max() alone wraps around
    # for negative values (e.g. after derivative preprocessing) and gives 0/0 for a blank image.
    gray_shifted = gray - gray.min()
    gray_peak = gray_shifted.max()
    if gray_peak > 0:
        gray_uint8 = (gray_shifted / gray_peak * 255).astype(np.uint8)
    else:
        gray_uint8 = np.zeros(gray.shape, dtype=np.uint8)
    edges = cv2.Canny(gray_uint8, 50, 150).astype(np.float32)
    # combine: both terms are in [0, 1] before summing, then the sum is rescaled to [0, 1]
    spectral_std_norm = spectral_std / (spectral_std.max() + 1e-8)
    edges_norm = edges / 255.0
    attn = spectral_std_norm + edges_norm
    attn = attn / (attn.max() + 1e-8)
    attn_masks.append(torch.tensor(attn, dtype=torch.float32))

In [None]:
import numpy as np
from scipy.ndimage import gaussian_filter, gaussian_filter1d

def compute_hessian_eigenvalues(image, sigma=1.0):
    """Per-pixel eigenvalues of the Gaussian-smoothed 2x2 Hessian of a 2-D ``image``.

    The Hessian entries are Gaussian second derivatives (``gaussian_filter`` with ``order``
    (2, 0), (0, 2) and (1, 1)); the eigenvalues use the closed form
    0.5 * (a + c +/- sqrt((a - c)**2 + 4 * b**2)).

    Returns:
        (major, lambda1, lambda2): lambda1 >= lambda2 element-wise and ``major`` is their
        element-wise maximum (numerically equal to lambda1).
    """
    d_aa, d_bb, d_ab = (
        gaussian_filter(image, sigma=sigma, order=order)
        for order in ((2, 0), (0, 2), (1, 1))
    )

    root = np.sqrt((d_aa - d_bb)**2 + 4 * d_ab**2)
    eig_hi = 0.5 * (d_aa + d_bb + root)
    eig_lo = 0.5 * (d_aa + d_bb - root)

    return np.maximum(eig_hi, eig_lo), eig_hi, eig_lo

# Band-by-band application of compute_hessian_eigenvalues to a (bands, H, W) cube.
def apply_hessian_to_hsi(hsi_cube, sigma=1.0):
    """Return the stacked per-band "major" Hessian eigenvalue maps of ``hsi_cube``.

    Bands are taken along axis 0; the result has the same leading size as ``hsi_cube``.
    """
    return np.stack([compute_hessian_eigenvalues(band, sigma)[0] for band in hsi_cube])

import numpy as np
from scipy.ndimage import gaussian_filter, gaussian_filter1d

def steger_line_detection(image, sigma=1.0):
    """Steger-style line detector for a single 2-D image.

    Builds the Gaussian-derivative Hessian at every pixel and takes the eigenpair with the
    smaller (most negative) eigenvalue ``lambda2``, i.e. the direction across a bright line.
    One Newton step along that direction estimates the sub-pixel line centre.

    Args:
        image: 2-D array (axis 0 = rows = y, axis 1 = columns = x). Non-float input is
            converted to float64 so the derivative filters are not truncated to integers.
        sigma: standard deviation of the Gaussian derivative filters.

    Returns:
        line_strength: ``|lambda2|`` per pixel.
        x_subpixel, y_subpixel: estimated sub-pixel column / row of the line centre for every
            pixel, clipped to the image. Pixels with (near-)zero ``lambda2`` keep their own
            integer position instead of producing NaN.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)

    # 1. First and second Gaussian derivatives. `order` is given per axis, so x derivatives
    #    (along columns) use order=(0, k) and y derivatives (along rows) use order=(k, 0).
    Ix = gaussian_filter(image, sigma=sigma, order=(0, 1))
    Iy = gaussian_filter(image, sigma=sigma, order=(1, 0))
    Ixx = gaussian_filter(image, sigma=sigma, order=(0, 2))
    Iyy = gaussian_filter(image, sigma=sigma, order=(2, 0))
    Ixy = gaussian_filter(image, sigma=sigma, order=(1, 1))

    # 2. Per-pixel Hessian [[Ixx, Ixy], [Ixy, Iyy]]. eigh returns eigenvalues in ascending
    #    order with orthonormal eigenvectors as columns, so index 0 is lambda2 and
    #    eigvecs[..., :, 0] = (nx, ny) is its eigenvector, also for diagonal/degenerate Hessians.
    hessian = np.stack(
        [np.stack([Ixx, Ixy], axis=-1), np.stack([Ixy, Iyy], axis=-1)], axis=-2
    )
    eigvals, eigvecs = np.linalg.eigh(hessian)
    lambda2 = eigvals[..., 0]
    nx = eigvecs[..., 0, 0]
    ny = eigvecs[..., 1, 0]

    # 3. Newton step to the zero crossing of the first derivative along (nx, ny):
    #    t = -(grad . n) / (n^T H n) = -(grad . n) / lambda2 for a unit eigenvector.
    #    Flat regions (lambda2 == 0) have no defined line centre, so t = 0 there.
    flat = np.abs(lambda2) < np.finfo(lambda2.dtype).eps
    safe_lambda2 = np.where(flat, 1.0, lambda2)
    t = np.where(flat, 0.0, -(Ix * nx + Iy * ny) / safe_lambda2)

    rows, cols = np.indices(image.shape)
    x_subpixel = np.clip(cols + t * nx, 0, image.shape[1] - 1)
    y_subpixel = np.clip(rows + t * ny, 0, image.shape[0] - 1)

    # 4. Line strength: magnitude of the across-line curvature
    line_strength = np.abs(lambda2)
    return line_strength, x_subpixel, y_subpixel

# Band-by-band application of steger_line_detection to a (bands, H, W) cube.
def apply_steger_to_hsi(hsi_cube, sigma=1.0):
    """Return the stacked per-band Steger line-strength maps of ``hsi_cube``.

    Bands are taken along axis 0; the sub-pixel positions from the detector are discarded.
    """
    return np.stack([steger_line_detection(band, sigma)[0] for band in hsi_cube])

In [None]:
# Ridge / line detection on every cube in `batch`:
# - average the per-band Hessian and Steger responses over the spectral axis,
# - threshold the normalised Steger map,
# - show it next to an RGB composite.
# NOTE(review): `batch` is not created in any earlier cell; define it explicitly (e.g. the first
# batch of a dataloader) so this cell survives Restart & Run All.
LINE_THRESHOLD = 0.6  # tune this!
RGB_BANDS = [9, 3, 5]  # NOTE(review): band indices used as R/G/B — confirm for the loaded band set

for img in batch:
    hsi_np = img.numpy().astype(np.float32)  # assumes one (C, H, W) cube per iteration — TODO confirm
    # Per-band responses, each with the same (C, H, W) shape as hsi_np
    hessian_response = apply_hessian_to_hsi(hsi_np, sigma=1.0)
    steger_response = apply_steger_to_hsi(hsi_np, sigma=1.0)
    ridge_map = np.mean(hessian_response, axis=0)
    line_map = np.mean(steger_response, axis=0)
    norm_ridge_map = minmax_scale(ridge_map.ravel()).reshape(ridge_map.shape)
    norm_line_map = minmax_scale(line_map.ravel()).reshape(line_map.shape)
    binary_mask = norm_line_map > LINE_THRESHOLD

    # Rescale the composite to [0, 1]; skip the division for a constant (e.g. fully masked) cube.
    rgb = hsi_np[RGB_BANDS]  # fancy indexing copies, so hsi_np is not modified below
    rgb = rgb - rgb.min()
    rgb_max = rgb.max()
    if rgb_max > 0:
        rgb = rgb / rgb_max
    rgb = rgb.transpose(1, 2, 0)  # [H, W, 3]

    fig, (ax_rgb, ax_lines) = plt.subplots(1, 2, figsize=(10, 4))
    ax_rgb.imshow(rgb)
    ax_rgb.set_title("Input RGB Composite")
    ax_rgb.axis('off')

    line_img = ax_lines.imshow(binary_mask, cmap='hot')
    fig.colorbar(line_img, ax=ax_lines)
    ax_lines.set_title("Thresholded Line Map")
    ax_lines.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# Collapse the per-band responses of the last processed cube into single 2-D maps by
# averaging over the spectral axis (axis 0).
ridge_map, line_map = (np.mean(response, axis=0) for response in (hessian_response, steger_response))

In [None]:
from sklearn.preprocessing import minmax_scale

# Rescale each map to [0, 1] over all of its pixels (minmax_scale on the flattened values,
# then restore the 2-D shape).
norm_ridge_map, norm_line_map = (
    minmax_scale(response_map.ravel()).reshape(response_map.shape)
    for response_map in (ridge_map, line_map)
)

In [None]:
# Binarise the normalised Steger line map with a hand-tuned threshold and display it.
threshold = 0.09  # tune this!
binary_mask = np.greater(norm_line_map, threshold)

plt.imshow(binary_mask, cmap='gray')
plt.title("Thresholded Line Map")
plt.axis('off')
plt.show()

In [None]:
# Load the mask of one example cube for the edge-extraction experiments below.
example_mask_file = 'FX10_10SEPT2023_4D2_1.png'
mask = read_mask('/home/u0158953/data/Strawberries/PotsprocessedData/cropped_masks/' + example_mask_file)

In [None]:
import numpy as np
from scipy.ndimage import binary_erosion

def extract_edge(mask, edge_width=2):
    """Return the ``edge_width``-pixel-thick inner boundary of ``mask`` as a uint8 0/1 array.

    Every non-zero mask value counts as foreground. The foreground is eroded ``edge_width``
    times with scipy's default structuring element (pixels outside the image count as
    background, so foreground touching the border is part of the edge). The edge is the
    foreground that does not survive the erosion; ``edge_width <= 0`` yields an empty edge.
    """
    foreground = mask.astype(bool)
    # iterations < 1 would mean "erode until stable" in scipy, so only delegate positive widths
    core = binary_erosion(foreground, iterations=edge_width) if edge_width > 0 else foreground.copy()
    return (foreground ^ core).astype(np.uint8)

In [None]:
edge_line = extract_edge(mask, edge_width=2)  # 2-pixel inner boundary of the example mask

In [None]:
# Load one example cube and zero out everything outside the mask.
hsi_np = LoadHSI('/home/u0158953/data/Strawberries/PotsprocessedData/cropped_hdf5/FX10_10SEPT2023_4D2_1.hdf5', return_wlens=False)
# Label 2 is mapped to 1 so both labelled regions are kept (presumably 1 = leaf, 2 = stem — confirm).
leaf_stem_mask = np.where(mask == 2, 1, mask)
hsi_np = hsi_np * leaf_stem_mask

In [None]:
# RGB preview of the masked example cube next to the extracted mask edge.
# Cast to float first: in-place division on an integer cube would raise.
rgb = hsi_np[[9, 3, 5]].astype(float)  # NOTE(review): bands 9/3/5 used as R/G/B — confirm
rgb -= rgb.min()
rgb_max = rgb.max()
if rgb_max > 0:  # a constant (e.g. fully masked) crop would otherwise give 0/0
    rgb /= rgb_max
rgb = rgb.transpose(1, 2, 0)  # [H, W, 3]

fig, (ax_rgb, ax_edge) = plt.subplots(1, 2, figsize=(10, 4))
ax_rgb.imshow(rgb)
ax_rgb.set_title("Input RGB Composite")
ax_rgb.axis('off')

edge_img = ax_edge.imshow(edge_line, cmap='hot')
fig.colorbar(edge_img, ax=ax_edge)
ax_edge.set_title("Mask Edge")
ax_edge.axis('off')

fig.tight_layout()
plt.show()

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# --- Function to compute pixelwise error for an entire dataloader ---
def get_reconstruction_errors(model, dataloader, device):
    """Per-pixel reconstruction errors of ``model`` over every batch of ``dataloader``.

    Each batch is unpacked as ``(images, _)`` with ``images`` presumably (B, C, H, W) — TODO
    confirm against HsiDataset / padded_collate. A singleton channel axis is inserted so the 3D
    autoencoder receives (B, 1, C, H, W). A pixel's error is the squared difference between
    reconstruction and input, summed over the C spectral bands. All pixels are included
    (also any masked-out background).

    Args:
        model: autoencoder; it is switched to eval mode and run without gradients.
        dataloader: any iterable of ``(images, _)`` pairs.
        device: device the batches are moved to.

    Returns:
        1-D numpy array with B*H*W errors per batch, concatenated over all batches
        (empty if the loader yields no batches).
    """
    model.eval()
    all_errors = []

    with torch.no_grad():
        for batch, _ in dataloader:
            batch = batch.unsqueeze(1).to(device)  # [B, 1, C, H, W]
            recon = model(batch)                   # same shape
            # Sum over the spectral axis (dim 2), then drop the singleton channel -> [B, H, W]
            error = (recon - batch).pow(2).sum(dim=2).squeeze(1)
            all_errors.append(error.cpu().reshape(-1))

    if not all_errors:
        # torch.cat([]) raises; an empty loader simply has no errors to report.
        return np.empty(0, dtype=np.float32)
    return torch.cat(all_errors).numpy()


# --- Load trained model ---
# NOTE(review): the architecture arguments must match those used when third_fc_model.pth was saved.
model = CNN3DAEFC(layers_list=[18,32,32,64,64,128], input_dim=1, kernel_sizes=3, strides=(1, 2, 2), paddings=1, z_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Load the checkpoint
model, optimizer, train_losses, val_losses, last_epoch = load_checkpoint(model, '/home/r0979317/Documents/Thesis_Strawberries/models/third_fc_model.pth', device, optimizer)
# Evaluate on whichever device the restored parameters live on.
device = next(model.parameters()).device
print("✅ Model loaded and on device:", device)

# --- Datasets to process ---
datasets = {
    "Validation": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# --- Compute and plot separately ---
for name, loader in datasets.items():
    print(f"Processing {name}...")
    errors = get_reconstruction_errors(model, loader, device)
    if errors.size == 0:
        # np.percentile raises on an empty array
        print(f"No samples in {name}, skipping.")
        continue

    # Clip the histogram at the 99th percentile so a few extreme pixels do not squash the plot
    upper = np.percentile(errors, 99)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(errors, bins=100, range=(0, upper), color='skyblue', edgecolor='black', density=True)
    ax.set_title(f"Reconstruction Error Distribution - {name}")
    ax.set_xlabel("Pixelwise Reconstruction Error")
    ax.set_ylabel("Frequency (normalized)")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:
class CNN3DAEMAX_try(nn.Module):
    """Experimental 3D convolutional autoencoder that downsamples with max pooling.

    Encoder: for every entry of ``layers_list``, a stride-1 ``Conv3d`` (padding ``paddings``),
    then ``MaxPool3d(kernel_sizes, stride=strides)`` without padding, then ReLU.
    Decoder: ``ConvTranspose3d`` + ReLU blocks through the reversed channel list, then a final
    ``ConvTranspose3d`` back to ``input_dim`` channels with no activation.
    ``forward`` trilinearly resizes the decoder output to a fixed (17, 256, 256) volume, which
    absorbs any size mismatch between encoder and decoder.

    Input is presumably (B, input_dim, 17, H=256, W=256), matching the dummy tensor built in
    ``__init__`` — TODO confirm against the data pipeline.

    NOTE(review): ``forward`` uses ``F`` (torch.nn.functional). Make sure it is imported before
    this class is used, otherwise the first forward pass raises NameError.
    """
    def __init__(
        self,
        layers_list=[32, 64, 64, 128],
        input_dim=1,
        kernel_sizes=3,
        strides=(1, 2, 2),
        paddings=1
    ):
        super(CNN3DAEMAX_try, self).__init__()
        self.layers_list = layers_list

        # Encoder
        # The conv keeps the size (stride 1); the un-padded max pool does the downsampling.
        encoder = []
        in_channels = input_dim
        for out_channels in layers_list:
            encoder.append(nn.Conv3d(in_channels, out_channels, kernel_size=kernel_sizes, stride=1, padding=paddings))
            encoder.append(nn.MaxPool3d(kernel_sizes, stride=strides))
            encoder.append(nn.ReLU(inplace=True))
            in_channels = out_channels
        self.encoder = nn.Sequential(*encoder)

        # Get output shape of encoder
        # A zero tensor is pushed through the encoder once at construction time. feature_shape
        # is stored for inspection but is not used elsewhere in this class.
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_dim, 17, 256, 256)
            dummy_output = self.encoder(dummy_input)
            self.feature_shape = dummy_output.shape[1:]  # (C, D, H, W)

        # Decoder (reversed)
        # Starts from the encoder's last channel count (layers_list[-1]).
        decoder = []
        rev_layers = layers_list[::-1]
        in_channels = rev_layers[0]

        for out_channels in rev_layers[1:]:
            decoder.append(nn.ConvTranspose3d(
                in_channels,
                out_channels,
                kernel_size=kernel_sizes,
                stride=strides,
                padding=paddings,
                output_padding=(0, 1, 1)  # helps recover 256 from downsampling
            ))
            decoder.append(nn.ReLU(inplace=True))
            in_channels = out_channels

        # Final layer to bring channels back to input_dim (e.g., 1)
        decoder.append(nn.ConvTranspose3d(
            in_channels,
            input_dim,
            kernel_size=kernel_sizes,
            stride=strides,
            padding=paddings,
            output_padding=(0, 1, 1)
        ))
        self.decoder = nn.Sequential(*decoder)

    def forward(self, x):
        """Encode, decode and resize the reconstruction to (17, 256, 256)."""
        out = self.encoder(x)
        out = self.decoder(out)
        out = F.interpolate(out, size=(17, 256, 256), mode="trilinear", align_corners=False)
        return out

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN3DAE_VarKernels(nn.Module):
    """3D convolutional autoencoder with per-layer kernel sizes and one long skip connection.

    Encoder: ``Conv3d`` + ReLU blocks with kernels (1,3,3), (1,3,3), (3,3,3), (3,3,3), using
    stride ``strides`` and padding ``paddings``.
    Decoder: ``ConvTranspose3d`` + ReLU blocks through the reversed channel list, then
    ``final_layer`` back to ``input_dim`` channels. The first encoder feature map is resized
    and added to the output of the last decoder block. The result is trilinearly resized to a
    fixed (17, 256, 256) volume.

    NOTE(review): ``zip`` truncates the encoder to len(encoder_kernel_sizes) == 4 layers, but
    the decoder walks the full reversed ``layers_list``. A ``layers_list`` longer than 4
    entries therefore gives mismatched channels; keep it at <= 4.
    NOTE(review): because of the ``[1:]`` slice, decoder block i uses decoder_kernel_sizes[i + 1]
    rather than decoder_kernel_sizes[i] (the mirror of its encoder layer). The decoder therefore
    does not mirror the encoder's depth changes, and the final interpolate makes up the
    difference. Changing this alters weight shapes and breaks existing checkpoints — confirm
    whether it is intended.
    """
    def __init__(
        self,
        layers_list=[18, 32, 32, 64],
        input_dim=1,
        strides=(1, 2, 2),
        paddings=(0, 1, 1)
    ):
        super(CNN3DAE_VarKernels, self).__init__()
        self.layers_list = layers_list
        self.strides = strides
        self.paddings = paddings

        # Define varying kernel sizes for encoder
        # (1, 3, 3) kernels mix only spatially; (3, 3, 3) kernels also mix neighbouring bands.
        self.encoder_kernel_sizes = [(1, 3, 3), (1, 3, 3), (3, 3, 3), (3, 3, 3)]
        self.decoder_kernel_sizes = self.encoder_kernel_sizes[::-1]

        # --- Encoder ---
        self.enc_blocks = nn.ModuleList()
        in_channels = input_dim
        for out_channels, ks in zip(layers_list, self.encoder_kernel_sizes):
            block = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=ks, stride=strides, padding=self.paddings),
                nn.ReLU(inplace=True)
            )
            self.enc_blocks.append(block)
            in_channels = out_channels

        # --- Decoder ---
        # Pairs consecutive reversed channel counts: (c_n -> c_{n-1}), ..., (c_2 -> c_1).
        # `idx` is unused.
        rev_layers = layers_list[::-1]
        self.dec_blocks = nn.ModuleList()
        for idx, (in_channels, out_channels, ks) in enumerate(zip(rev_layers, rev_layers[1:], self.decoder_kernel_sizes[1:])):
            block = nn.Sequential(
                nn.ConvTranspose3d(
                    in_channels, out_channels,
                    kernel_size=ks, stride=strides,
                    padding=self.paddings, output_padding=(0, 1, 1)
                ),
                nn.ReLU(inplace=True)
            )
            self.dec_blocks.append(block)

        # Final output layer
        # Maps layers_list[0] channels back to input_dim; no activation.
        self.final_layer = nn.ConvTranspose3d(
            rev_layers[-1], input_dim,
            kernel_size=self.decoder_kernel_sizes[-1], stride=strides,
            padding=self.paddings, output_padding=(0, 1, 1)
        )

    def _encode(self, x):
        """Run the encoder and return the output of every encoder block, in order."""
        features = []
        for block in self.enc_blocks:
            x = block(x)
            features.append(x)
        return features

    def forward(self, x):
        """Encode, decode with one skip connection, and resize the output to (17, 256, 256)."""
        enc_feats = self._encode(x)
        x = enc_feats[-1]

        for idx, block in enumerate(self.dec_blocks):
            x = block(x)
            if idx == len(self.dec_blocks) - 1:  # last decoder block before final layer
                # Channels match by construction (both layers_list[0]); sizes are aligned by interpolation.
                x = x + F.interpolate(enc_feats[0], size=x.shape[2:], mode="trilinear", align_corners=False)

        x = self.final_layer(x)

        # Resize to original input shape
        x = F.interpolate(x, size=(17, 256, 256), mode='trilinear', align_corners=False)
        return x

In [None]:
class CNN3DAE_TightDropout(nn.Module):
    """3D convolutional autoencoder with variable kernels, a dropout bottleneck and a skip connection.

    Encoder: ``Conv3d`` + ReLU blocks with kernels (1,3,3), (1,3,3), (3,3,3), (3,3,3), followed by
    ``bottleneck_encoder`` (Conv3d 3x3x3 + ReLU + Dropout3d(``dropout_p``)) that downsamples once
    more.
    Decoder, in order:
      - ``bottleneck_decoder`` (ConvTranspose3d 3x3x3 + ReLU) mirrors the bottleneck;
      - ``ConvTranspose3d`` + ReLU blocks through the reversed channel list;
      - the first encoder feature map is resized and added after the last decoder block;
      - ``final_layer`` maps back to ``input_dim`` channels.
    The output is trilinearly resized to a fixed (17, 256, 256) volume.

    NOTE(review): ``zip`` limits the encoder to 4 layers while the decoder walks the full reversed
    ``layers_list``; keep ``layers_list`` at <= 4 entries or the channels will not match.
    NOTE(review): decoder block i uses decoder_kernel_sizes[i + 1] (``[1:]`` slice) instead of the
    mirror of its encoder layer. The final interpolate hides the resulting depth difference.
    Changing it alters weight shapes and breaks existing checkpoints (e.g. first_dropout_model.pth).
    """
    def __init__(
        self,
        layers_list=[18, 32, 32, 64],
        input_dim=1,
        strides=(1, 2, 2),
        paddings=(0, 1, 1),
        dropout_p=0.2
    ):
        super(CNN3DAE_TightDropout, self).__init__()

        self.strides = strides
        self.paddings = paddings
        self.dropout_p = dropout_p

        # Kernel sizes per encoder layer
        self.encoder_kernel_sizes = [(1, 3, 3), (1, 3, 3), (3, 3, 3), (3, 3, 3)]
        self.decoder_kernel_sizes = self.encoder_kernel_sizes[::-1]

        # === Encoder ===
        self.enc_blocks = nn.ModuleList()
        in_channels = input_dim
        for out_channels, ks in zip(layers_list, self.encoder_kernel_sizes):
            self.enc_blocks.append(
                nn.Sequential(
                    nn.Conv3d(in_channels, out_channels, kernel_size=ks, stride=strides, padding=paddings),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels

        # Extra encoder bottleneck block for tighter latent space
        # Keeps the channel count, downsamples with `strides`; Dropout3d is only active in train mode.
        self.bottleneck_encoder = nn.Sequential(
            nn.Conv3d(in_channels, in_channels, kernel_size=(3, 3, 3), stride=strides, padding=paddings),
            nn.ReLU(inplace=True),
            nn.Dropout3d(p=dropout_p)
        )

        # === Decoder ===
        rev_layers = layers_list[::-1]
        self.dec_blocks = nn.ModuleList()
        for in_ch, out_ch, ks in zip(rev_layers, rev_layers[1:], self.decoder_kernel_sizes[1:]):
            self.dec_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose3d(
                        in_ch, out_ch,
                        kernel_size=ks, stride=strides,
                        padding=paddings, output_padding=(0, 1, 1)
                    ),
                    nn.ReLU(inplace=True)
                )
            )

        # Extra decoder block to match bottleneck layer
        self.bottleneck_decoder = nn.Sequential(
            nn.ConvTranspose3d(
                rev_layers[0], rev_layers[0],
                kernel_size=(3, 3, 3), stride=strides,
                padding=paddings, output_padding=(0, 1, 1)
            ),
            nn.ReLU(inplace=True)
        )

        # Final output layer
        # Maps layers_list[0] channels back to input_dim; no activation.
        self.final_layer = nn.ConvTranspose3d(
            rev_layers[-1], input_dim,
            kernel_size=self.decoder_kernel_sizes[-1], stride=strides,
            padding=paddings, output_padding=(0, 1, 1)
        )

    def _encode(self, x):
        """Return (per-block encoder features, bottleneck output)."""
        features = []
        for block in self.enc_blocks:
            x = block(x)
            features.append(x)
        x = self.bottleneck_encoder(x)
        return features, x

    def forward(self, x):
        """Encode, decode with one skip connection, and resize the output to (17, 256, 256)."""
        enc_feats, x = self._encode(x)
        x = self.bottleneck_decoder(x)

        for idx, block in enumerate(self.dec_blocks):
            x = block(x)
            if idx == len(self.dec_blocks) - 1:
                # Final skip connection (after full spatial upsampling)
                skip = F.interpolate(enc_feats[0], size=x.shape[2:], mode="trilinear", align_corners=False)
                x = x + skip

        x = self.final_layer(x)
        x = F.interpolate(x, size=(17, 256, 256), mode='trilinear', align_corners=False)
        return x

In [None]:
class CNN3DAE_Strict(nn.Module):
    """3D convolutional autoencoder with a two-stage dropout bottleneck and a skip connection.

    Encoder: ``Conv3d`` + ReLU blocks with kernels (1,3,3), (1,3,3), (3,3,3), (3,3,3), followed by
    ``bottleneck_encoder`` with two (Conv3d 3x3x3 + ReLU + Dropout3d(``dropout_p``)) stages, each
    downsampling with ``strides``.
    Decoder, in order:
      - ``bottleneck_decoder`` (two ConvTranspose3d 3x3x3 + ReLU) mirrors the bottleneck;
      - ``ConvTranspose3d`` + ReLU blocks through the reversed channel list;
      - the first encoder feature map is resized and added after the last decoder block;
      - ``final_layer`` maps back to ``input_dim`` channels.
    The output is trilinearly resized to a fixed (17, 256, 256) volume.

    NOTE(review): ``zip`` limits the encoder to 4 layers while the decoder walks the full reversed
    ``layers_list``; keep ``layers_list`` at <= 4 entries or the channels will not match.
    NOTE(review): decoder block i uses decoder_kernel_sizes[i + 1] (``[1:]`` slice) instead of the
    mirror of its encoder layer. The final interpolate hides the resulting depth difference.
    Changing it alters weight shapes and breaks existing checkpoints.
    """
    def __init__(
        self,
        layers_list=[18, 32, 32, 64],
        input_dim=1,
        strides=(1, 2, 2),
        paddings=(0, 1, 1),
        dropout_p=0.2
    ):
        super(CNN3DAE_Strict, self).__init__()

        self.strides = strides
        self.paddings = paddings
        self.dropout_p = dropout_p

        # Kernel sizes per encoder layer
        self.encoder_kernel_sizes = [(1, 3, 3), (1, 3, 3), (3, 3, 3), (3, 3, 3)]
        self.decoder_kernel_sizes = self.encoder_kernel_sizes[::-1]

        # === Encoder ===
        self.enc_blocks = nn.ModuleList()
        in_channels = input_dim
        for out_channels, ks in zip(layers_list, self.encoder_kernel_sizes):
            self.enc_blocks.append(
                nn.Sequential(
                    nn.Conv3d(in_channels, out_channels, kernel_size=ks, stride=strides, padding=paddings),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels

        # Extra encoder bottleneck block for tighter latent space
        # Two channel-preserving, downsampling stages; Dropout3d is only active in train mode.
        self.bottleneck_encoder = nn.Sequential(
            nn.Conv3d(in_channels, in_channels, kernel_size=(3, 3, 3), stride=strides, padding=paddings),
            nn.ReLU(inplace=True),
            nn.Dropout3d(p=dropout_p),
            nn.Conv3d(in_channels, in_channels, kernel_size=(3, 3, 3), stride=strides, padding=paddings),
            nn.ReLU(inplace=True),
            nn.Dropout3d(p=dropout_p)
        )

        # === Decoder ===
        rev_layers = layers_list[::-1]
        self.dec_blocks = nn.ModuleList()
        for in_ch, out_ch, ks in zip(rev_layers, rev_layers[1:], self.decoder_kernel_sizes[1:]):
            self.dec_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose3d(
                        in_ch, out_ch,
                        kernel_size=ks, stride=strides,
                        padding=paddings, output_padding=(0, 1, 1)
                    ),
                    nn.ReLU(inplace=True)
                )
            )

        # Extra decoder block to match bottleneck layer
        # Two upsampling stages, mirroring the two bottleneck convolutions.
        self.bottleneck_decoder = nn.Sequential(
            nn.ConvTranspose3d(
                rev_layers[0], rev_layers[0],
                kernel_size=(3, 3, 3), stride=strides,
                padding=paddings, output_padding=(0, 1, 1)
            ),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(
                rev_layers[0], rev_layers[0],
                kernel_size=(3, 3, 3), stride=strides,
                padding=paddings, output_padding=(0, 1, 1)
            ),
            nn.ReLU(inplace=True)
        )

        # Final output layer
        # Maps layers_list[0] channels back to input_dim; no activation.
        self.final_layer = nn.ConvTranspose3d(
            rev_layers[-1], input_dim,
            kernel_size=self.decoder_kernel_sizes[-1], stride=strides,
            padding=paddings, output_padding=(0, 1, 1)
        )

    def _encode(self, x):
        """Return (per-block encoder features, bottleneck output)."""
        features = []
        for block in self.enc_blocks:
            x = block(x)
            features.append(x)
        x = self.bottleneck_encoder(x)
        return features, x

    def forward(self, x):
        """Encode, decode with one skip connection, and resize the output to (17, 256, 256)."""
        enc_feats, x = self._encode(x)
        x = self.bottleneck_decoder(x)

        for idx, block in enumerate(self.dec_blocks):
            x = block(x)
            if idx == len(self.dec_blocks) - 1:
                # Final skip connection (after full spatial upsampling)
                skip = F.interpolate(enc_feats[0], size=x.shape[2:], mode="trilinear", align_corners=False)
                x = x + skip

        x = self.final_layer(x)
        x = F.interpolate(x, size=(17, 256, 256), mode='trilinear', align_corners=False)
        return x

In [None]:
from pytorch_msssim import ssim

def hybrid_loss(output, target, alpha=0.8, beta=0.01):
    """Weighted MSE + (1 - SSIM) loss plus a spectral-smoothness penalty.

    ``alpha`` weights the MSE term and ``1 - alpha`` weights the SSIM term.
    SSIM comes from ``pytorch_msssim.ssim`` with ``data_range=1.0``. ``beta``
    scales the mean squared difference between neighbouring slices along
    dimension 2 of ``output``.
    """
    pixel_term = F.mse_loss(output, target)
    structural_term = 1 - ssim(output, target, data_range=1.0, size_average=True)

    # Penalise jumps between adjacent entries along dim 2 (the band axis for
    # inputs shaped [B, 1, C, H, W] as used elsewhere in this notebook)
    band_steps = output[:, :, 1:, :, :] - output[:, :, :-1, :, :]
    smoothness_term = band_steps.pow(2).mean()

    return alpha * pixel_term + (1 - alpha) * structural_term + beta * smoothness_term

In [None]:
# Now instantiate the model and the trainer.
# Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA) for reproducible runs.
SEED = 10
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Loss function ---
# Uses `hybrid_loss` from the previous cell:
# alpha * MSE + (1 - alpha) * (1 - SSIM) + beta * spectral smoothness.
# It is deliberately not re-defined here, so there is a single definition to maintain.

# --- Instantiate model ---
model = CNN3DAE_Strict(
    layers_list=[18, 32, 64, 64],
    input_dim=1,
    strides=(1, 2, 2),
    paddings=(0, 1, 1)
).to(device)

# --- Optimizer ---
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# --- Train ---
N_EPOCHS = 200
STRICT_MODEL_PATH = '/home/r0979317/Documents/Thesis_Strawberries/models/first_strict_model.pth'
train_losses, val_losses = train_autoencoder(
    N_EPOCHS,
    model,
    dataloader_train_hsi,
    device,
    dataloader_validation_hsi,
    criterion=hybrid_loss,
    optimizer=optimizer,
    save_model=True,
    save_path=STRICT_MODEL_PATH
)

In [None]:
# Rebuild the dropout variant and its optimizer, then restore weights,
# optimizer state and loss history from the saved checkpoint.
# NOTE(review): the training cell above trains `CNN3DAE_Strict` and saves it to
# first_strict_model.pth. This cell instead loads `CNN3DAE_TightDropout` from
# first_dropout_model.pth and overwrites `model`, `train_losses` and
# `val_losses` - confirm that this is intended.
DROPOUT_CHECKPOINT = '/home/r0979317/Documents/Thesis_Strawberries/models/first_dropout_model.pth'

model = CNN3DAE_TightDropout(
    layers_list=[18, 32, 64, 64],
    input_dim=1,
    strides=(1, 2, 2),
    paddings=(0, 1, 1),
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

model, optimizer, train_losses, val_losses, last_epoch = load_checkpoint(
    model, DROPOUT_CHECKPOINT, device, optimizer
)

In [None]:
from scipy.ndimage import binary_erosion

In [None]:
# Debug: print the shapes of one validation batch, its reconstruction error,
# and the edge mask of its first sample.
model.eval()
with torch.no_grad():
    for images, indices in dataloader_validation_hsi:
        # Add the singleton channel dim and put the batch on the model's device.
        images = images.unsqueeze(1).to(device)  # [B, 1, C, H, W]
        outputs = model(images)
        recon_error = torch.abs(images - outputs)
        print(recon_error.shape)

        for i, idx in enumerate(indices):
            print(idx)
            edge_mask = dataset_validation_hsi.get_edge_mask(idx).float()  # (H, W)
            print(i)
            # `i` is the position inside this batch; `idx` is the dataset-wide
            # index and must not be used to index `images`.
            print(images[i].shape)
            print(edge_mask.shape)
            break
        break

In [None]:
# Visualise the first validation sample: pseudo-RGB composite next to its edge mask.
model.eval()
with torch.no_grad():
    for images, indices in dataloader_validation_hsi:
        # Add the singleton channel dim and put the batch on the model's device.
        images = images.unsqueeze(1).to(device)  # [B, 1, C, H, W]

        outputs = model(images)
        recon_error = torch.abs(images - outputs)

        for i, idx in enumerate(indices):
            print(idx)
            edge_mask = dataset_validation_hsi.get_edge_mask(idx).float()  # (H, W)
            # Back to CPU for numpy / matplotlib.
            image = images[i].squeeze(0).cpu()  # (C, H, W)
            print(image.shape)

            rgb = image[[9, 3, 5]].numpy()  # Pick bands for RGB
            rgb = rgb - rgb.min()
            rgb = rgb / (rgb.max() + 1e-6)  # epsilon guards against a constant composite
            rgb = rgb.transpose(1, 2, 0)  # [H, W, 3]

            fig, (ax_rgb, ax_mask) = plt.subplots(1, 2, figsize=(10, 4))
            ax_rgb.imshow(rgb)
            ax_rgb.set_title("Input RGB Composite")
            ax_rgb.axis('off')

            mask_im = ax_mask.imshow(edge_mask.cpu().numpy(), cmap='hot')
            fig.colorbar(mask_im, ax=ax_mask)
            ax_mask.set_title("Thresholded Line Map")
            ax_mask.axis('off')

            fig.tight_layout()
            plt.show()

            break
        break

# TODO: edge-restricted reconstruction error (work in progress, not executed).
# `recon_error` is [B, 1, C, H, W], so the spatial size is its last two dims:
#     H, W = recon_error.shape[-2], recon_error.shape[-1]
#     # Pad the (H, W) mask on the bottom/right up to the image size
#     pad_height = H - edge_mask.shape[0]
#     pad_width = W - edge_mask.shape[1]
#     edge_mask = torch.nn.functional.pad(edge_mask, (0, pad_width, 0, pad_height))
#     # Treat mask value 2 like 1, then broadcast over the spectral bands
#     edge_mask = torch.where(edge_mask == 2, torch.ones_like(edge_mask), edge_mask)
#     edge_error = recon_error[i, 0] * edge_mask.to(recon_error.device)  # (C, H, W)

## Inference with the trained model

In [None]:
# Plot the training and validation loss curves.
# NOTE(review): 199 looks like the last epoch index of the 200-epoch training run.
# After restoring a checkpoint, `last_epoch` may be the intended value - confirm
# against plot_losses' signature.
plot_losses(train_losses, val_losses, 199)

In [None]:
# Re-seed so threshold estimation is reproducible, then derive the
# reconstruction-error threshold from the test and diseased loaders.
torch.manual_seed(10)
torch.cuda.manual_seed_all(10)

threshold = get_recon_error_threshold(
    model,
    dataloader_test_hsi,
    dataloader_early=dataloader_early_diseased_hsi,
    dataloader_mid=dataloader_mid_diseased_hsi,
    dataloader_late=dataloader_late_diseased_hsi,
    device=device,
)
print(threshold)

In [None]:
# --- Function to compute pixelwise error for an entire dataloader ---
def get_reconstruction_errors(model, dataloader, device):
    """Collect the per-pixel squared reconstruction error over a whole loader.

    For every batch, the squared difference between input and reconstruction
    is summed over the spectral axis. This gives one error value per
    (image, row, column) location.

    Args:
        model: Autoencoder mapping ``[B, 1, C, H, W]`` to the same shape.
        dataloader: Iterable yielding ``(images, _)`` with ``images`` shaped
            ``[B, C, H, W]``.
        device: Device the batches are moved to.

    Returns:
        np.ndarray: 1-D array holding every pixel error from every image.
    """
    model.eval()
    flat_chunks = []

    with torch.no_grad():
        for images, _unused in dataloader:
            inputs = images.unsqueeze(1).to(device)  # [B, 1, C, H, W]
            residual = model(inputs) - inputs
            per_pixel = residual.pow(2).sum(dim=2).squeeze(1)  # [B, H, W]
            flat_chunks.append(per_pixel.cpu().flatten())

    return torch.cat(flat_chunks).numpy()


# Optional: reload a specific checkpoint instead of using the in-memory model.
# model.load_state_dict(torch.load("/home/r0979317/Documents/Thesis_Strawberries/models/third_fc_model.pth", map_location="cuda" if torch.cuda.is_available() else "cpu"))
device = next(model.parameters()).device
print("✅ Model loaded and on device:", device)

# --- Datasets to process ---
datasets = {
    "Validation": dataloader_validation_hsi,
    "Early Diseased": dataloader_early_diseased_hsi,
    "Mid Diseased": dataloader_mid_diseased_hsi,
    "Late Diseased": dataloader_late_diseased_hsi,
}

# --- Compute and plot separately ---
for name, loader in datasets.items():
    print(f"Processing {name}...")
    errors = get_reconstruction_errors(model, loader, device)

    # Clip the x-range at the 99th percentile so a few outliers don't squash the histogram.
    upper = np.percentile(errors, 99)

    # --- Plot individual histogram ---
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(errors, bins=100, range=(0, upper), color='skyblue', edgecolor='black', density=True)
    ax.set_title(f"Reconstruction Error Distribution - {name}")
    ax.set_xlabel("Pixelwise Reconstruction Error")
    # density=True normalises the histogram to a probability density, not counts.
    ax.set_ylabel("Density")
    ax.grid(True)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # several figures are created in this loop; free each one
    

In [None]:
# --- Function to compute pixelwise error for an entire dataloader ---
def get_pixel_reconstruction_errors(model, dataloader, device, quantile=0.75):
    """Return the ``quantile`` of all per-pixel reconstruction errors.

    Despite its name, this returns a single threshold value, not the errors.
    Per-pixel errors are the squared input/reconstruction difference summed
    over the spectral axis, pooled across every image in ``dataloader``.

    Args:
        model: Autoencoder mapping ``[B, 1, C, H, W]`` to the same shape.
        dataloader: Iterable yielding ``(images, _)`` with ``images`` shaped
            ``[B, C, H, W]``.
        device: Device the batches are moved to.
        quantile: Quantile in [0, 1] passed to ``np.quantile``.

    Returns:
        The requested quantile of the pooled pixel errors (numpy scalar).
    """
    model.eval()
    pooled = []

    with torch.no_grad():
        for images, _unused in dataloader:
            inputs = images.unsqueeze(1).to(device)  # [B, 1, C, H, W]
            sq_err = (model(inputs) - inputs).pow(2).sum(dim=2)  # [B, 1, H, W]
            pooled.append(sq_err.cpu().flatten())

    every_pixel = torch.cat(pooled).numpy()
    return np.quantile(every_pixel, quantile)

In [None]:
def classify_leaves(model, dataloader, device, threshold):
    """Score each image by the total error of its high-error pixels.

    The per-pixel error is the squared input/reconstruction difference summed
    over the spectral axis. An image's score is the sum of its pixel errors
    that exceed ``threshold``, or 0.0 when none do.

    Args:
        model: Autoencoder mapping ``[B, 1, C, H, W]`` to the same shape.
        dataloader: Iterable yielding ``(images, _)`` with ``images`` shaped
            ``[B, C, H, W]``.
        device: Device the batches are moved to.
        threshold: Pixel-error threshold (e.g. from
            ``get_pixel_reconstruction_errors``).

    Returns:
        list[float]: One score per image, in loader order.
    """
    model.eval()
    image_scores = []

    with torch.no_grad():
        for batch, _ in dataloader:
            batch = batch.unsqueeze(1).to(device)  # shape: [B, 1, C, H, W]
            recon = model(batch)
            pixel_errors = (recon - batch).pow(2).squeeze(1).sum(dim=1)  # [B, H, W]

            for err_map in pixel_errors:
                # Sum of above-threshold pixel errors. `.item()` yields a plain
                # float, so scores don't keep device tensors alive and can be
                # plotted or put into pandas directly.
                score = err_map[err_map > threshold].sum().item()
                image_scores.append(score)

    return image_scores



In [None]:
def classify_leaves_mean(model, dataloader, device, threshold):
    """Score each image by its mean per-pixel reconstruction error.

    The per-pixel error is the squared input/reconstruction difference summed
    over the spectral axis. Each image's score is the mean of that error map.
    ``threshold`` is accepted (mirroring ``classify_leaves``) but is not used.

    Returns:
        list[float]: One mean error per image, in loader order.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for images, _unused in dataloader:
            inputs = images.unsqueeze(1).to(device)  # [B, 1, D, H, W]
            per_pixel = (model(inputs) - inputs).pow(2).squeeze(1).sum(dim=1)  # [B, H, W]
            scores.extend(pixel_map.mean().item() for pixel_map in per_pixel)
    return scores

In [None]:
# Pixel-level anomaly threshold: the 75th percentile (the default `quantile`) of the
# per-pixel reconstruction errors on the test loader. Despite its name, `pix_error`
# holds this single threshold value, not the errors themselves.
pix_error = get_pixel_reconstruction_errors(model, dataloader_test_hsi, device)

In [None]:
# Per-image scores for the early-diseased set: sum of pixel errors above `pix_error`.
score_early = classify_leaves(model, dataloader_early_diseased_hsi, device, pix_error)

In [None]:
# Mean per-pixel reconstruction error for every image, one list per group, in the
# order test, early, mid, late. (`pix_error` is passed but unused by
# classify_leaves_mean.)
dataloader_list = [dataloader_test_hsi, dataloader_early_diseased_hsi, dataloader_mid_diseased_hsi, dataloader_late_diseased_hsi]
errors = [classify_leaves_mean(model, loader, device, pix_error) for loader in dataloader_list]

In [None]:
import pandas as pd

In [None]:
# Summary statistics of the per-image mean errors for the first group (test loader).
pd.Series(errors[0]).describe()

In [None]:
import matplotlib.pyplot as plt

# One column of points per group; each point is one image's score from `errors`.
colors = ['green', 'yellow', 'orange', 'red']
labels = ['Healthy', 'Early Diseased','Mid Diseased', 'Severely Diseased']

fig, ax = plt.subplots(figsize=(10, 5))
for i, (scores, color, label) in enumerate(zip(errors, colors, labels)):
    ax.scatter([i] * len(scores), scores, color=color, label=label, alpha=0.7)

ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
# `errors` is produced by classify_leaves_mean: the mean per-pixel error over the
# whole image, with no threshold applied.
ax.set_ylabel("Mean Pixel Reconstruction Error")
ax.set_title("Per-image Error Scores by Group")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Debug: per-pixel error map and the mean above-threshold error for the first
# early-diseased image.
model.eval()
with torch.no_grad():  # inference only - don't build an autograd graph
    for batch, _ in dataloader_early_diseased_hsi:
        print(batch.shape)
        for img in batch:
            print(img.shape)
            img = img.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, C, H, W]
            recon = model(img)
            error = (recon - img).pow(2).sum(dim=2)
            print(error.shape)
            # Pixel threshold computed by get_pixel_reconstruction_errors above.
            high_error = error[error > pix_error]
            # Mean of the above-threshold errors (NaN if no pixel exceeds it).
            compar = high_error.mean()
            print(compar)

            break
        break

In [None]:
def visualize_reconstruction_error(model, dataloader, select_img=0, device=device):
    """Plot the original, its reconstruction, and the per-pixel error for one image.

    Only the first batch of ``dataloader`` is used; ``select_img`` picks the
    image inside that batch. The RGB composites use bands [9, 3, 5].

    Args:
        model: Autoencoder taking ``[B, 1, C, H, W]`` and returning the same shape.
        dataloader: Iterable yielding ``(images, _)`` with ``images`` shaped
            ``[B, C, H, W]``.
        select_img: Index of the image within the first batch.
        device: Device used for the forward pass. NOTE: the default is bound to
            the notebook's ``device`` when this cell runs, not when it is called.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()  # Set to evaluation mode

    # Take only the first batch. An empty loader would otherwise surface as an
    # obscure NameError further down.
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        raise ValueError("dataloader yielded no batches")
    data = first_batch[0].unsqueeze(1).to(device)  # [B, 1, C, H, W]
    with torch.no_grad():
        recon = model(data)  # Forward pass

    original = data[select_img, 0].cpu().numpy()  # Shape: (C, H, W)
    reconstructed = recon[select_img, 0].cpu().numpy()  # Shape: (C, H, W)

    # Absolute error averaged over spectral bands -> (H, W).
    # NOTE: get_reconstruction_errors / classify_leaves use squared error summed
    # over bands instead, so the magnitudes here are not directly comparable.
    error_map_aggregated = np.abs(original - reconstructed).mean(axis=0)

    RGB_bands = [9, 3, 5]

    # imshow clips float RGB data to [0, 1]. Scale both composites with the
    # *original's* range so they are visible and remain comparable to each other.
    lo = original[RGB_bands].min()
    span = original[RGB_bands].max() - lo + 1e-6

    def to_display_rgb(cube):
        # (C, H, W) -> (H, W, 3) in [0, 1]
        return np.clip((cube[RGB_bands].transpose(1, 2, 0) - lo) / span, 0.0, 1.0)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(to_display_rgb(original))
    axes[0].set_title("Original")
    axes[0].axis('off')

    axes[1].imshow(to_display_rgb(reconstructed))
    axes[1].set_title("Reconstructed")
    axes[1].axis('off')

    err_im = axes[2].imshow(error_map_aggregated, cmap='hot')
    axes[2].set_title("Reconstruction Error Map")
    axes[2].axis('off')
    fig.colorbar(err_im, ax=axes[2])

    fig.tight_layout()
    plt.show()

In [None]:
# Inspect reconstruction quality for image 2 of the first validation batch.
visualize_reconstruction_error(model, dataloader_validation_hsi, 2, device=device)

In [None]:
from sklearn.manifold import TSNE
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import umap.umap_ as UMAP

In [None]:
def get_pseudo_rgb_image(img, RGB_bands=[9, 3, 5]):
    """
    Extracts pseudo-RGB image from a hyperspectral image tensor.

    Parameters:
        img (torch.Tensor): Hyperspectral image tensor with shape (C, H, W), or
            (1, C, H, W) when a sample still carries the model's singleton
            channel dimension; that leading dimension is dropped.
        RGB_bands (list): Band indices to use for R, G, and B.

    Returns:
        numpy.ndarray: Pseudo-RGB image (H, W, 3) with values normalized between 0 and 1.
    """
    # Drop a leading singleton channel dim so band indexing hits the spectral axis
    # (otherwise img[RGB_bands] would index the size-1 dim and raise IndexError).
    if img.dim() == 4 and img.shape[0] == 1:
        img = img[0]

    # Extract the bands and convert to numpy array
    img_rgb = img[list(RGB_bands)].cpu().numpy()  # shape: (3, H, W)
    img_rgb = img_rgb.transpose(1, 2, 0)  # shape: (H, W, 3)
    # Min-max normalize for display; epsilon avoids 0/0 on a constant image
    img_rgb = (img_rgb - img_rgb.min()) / (img_rgb.max() - img_rgb.min() + 1e-6)
    return img_rgb

def plot_umap_interactive(latent_array, images, labels=None, RGB_bands=[9, 3, 5], image_scale=0.2, click_threshold=1.0, n_neighbors=15, min_dist=0.1):
    """
    Applies UMAP on latent representations and creates an interactive plot.
    Clicking on a point toggles display of its corresponding pseudo-RGB image.

    Parameters:
        latent_array (np.ndarray): Latent vectors (num_samples x latent_dim)
        images (torch.Tensor): Hyperspectral images (num_samples x C x H x W)
        labels (np.ndarray or None): Optional labels for coloring
        RGB_bands (list): Indices of hyperspectral bands used as R, G, B
        image_scale (float): Zoom level of the image thumbnails
        click_threshold (float): Max distance (in UMAP data units) to register a click on a point
        n_neighbors (int): UMAP parameter for local structure
        min_dist (float): UMAP parameter for cluster spread

    Notes:
        Clicks are only delivered with an interactive matplotlib backend
        (e.g. ``%matplotlib widget``). With the static inline backend the plot
        renders but clicking does nothing.
    """
    reducer = UMAP.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, random_state=42)
    umap_result = reducer.fit_transform(latent_array)
    xs, ys = umap_result[:, 0], umap_result[:, 1]

    fig, ax = plt.subplots(figsize=(10, 8))
    if labels is not None:
        scatter = ax.scatter(xs, ys, c=labels, cmap='viridis', s=50)
        fig.colorbar(scatter, ax=ax, label="Label")
    else:
        # A single fixed colour needs no colormap; passing one only triggers a warning.
        ax.scatter(xs, ys, c='blue', s=50)

    ax.set_title("Interactive UMAP: Click a point to toggle image")
    ax.set_xlabel("UMAP Dimension 1")
    ax.set_ylabel("UMAP Dimension 2")

    # Currently displayed thumbnails, keyed by sample index.
    displayed_annotations = {}

    def on_click(event):
        # Only consider clicks inside the axes
        if event.inaxes != ax:
            return

        # Distance (in data coordinates) from the click to every point
        distances = np.hypot(xs - event.xdata, ys - event.ydata)
        closest_idx = int(np.argmin(distances))

        # Ignore clicks that are not close enough to any point
        if distances[closest_idx] > click_threshold:
            return

        if closest_idx in displayed_annotations:
            # Already shown: remove the thumbnail
            displayed_annotations.pop(closest_idx).remove()
        else:
            # Show the pseudo-RGB thumbnail at the point's position
            img_rgb = get_pseudo_rgb_image(images[closest_idx], RGB_bands=RGB_bands)
            imagebox = OffsetImage(img_rgb, zoom=image_scale)
            ab = AnnotationBbox(imagebox, (xs[closest_idx], ys[closest_idx]), frameon=False)
            displayed_annotations[closest_idx] = ab
            ax.add_artist(ab)

        fig.canvas.draw_idle()

    # Connect the click event handler
    fig.canvas.mpl_connect('button_press_event', on_click)
    plt.show()

In [None]:
def get_latent_representations(model, dataloader, device, assigned_label=None):
    """
    Passes images through the encoder and collects latent representations, labels, and original images.

    Parameters:
        model (torch.nn.Module): Trained autoencoder model; must expose an ``encoder`` attribute.
        dataloader: Iterable yielding either ``(images, indices)`` pairs (as the HSI
            loaders in this notebook do) or bare image batches shaped [B, C, H, W].
        device (torch.device): Device to run the model on.
        assigned_label (int, optional): Label to assign to all samples in this dataloader.

    Returns:
        latent_array (np.ndarray): Array of latent representations (num_samples x latent_dim).
        labels_array (np.ndarray or None): Labels for each sample, or None when
            ``assigned_label`` is None.
        all_images (torch.Tensor): Original images on CPU, shape (num_samples, C, H, W).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.to(device).eval()
    latent_list = []
    labels_list = []
    image_list = []

    with torch.no_grad():
        for batch in dataloader:
            # Unpack (images, indices) pairs; a bare tensor batch is used as-is.
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            images = batch.to(device)  # [B, C, H, W]
            latent = model.encoder(images.unsqueeze(1))  # encoder input: [B, 1, C, H, W]
            latent = latent.reshape(latent.size(0), -1)  # Flatten latent representation

            latent_list.append(latent.cpu().numpy())
            # Store without the added channel dim so each sample is (C, H, W),
            # which is what band-indexing consumers (pseudo-RGB thumbnails) expect.
            image_list.append(images.cpu())

            if assigned_label is not None:
                labels_list.extend([assigned_label] * images.size(0))

    if not latent_list:
        raise ValueError("dataloader yielded no batches")

    latent_array = np.concatenate(latent_list, axis=0)
    labels_array = np.array(labels_list) if labels_list else None
    all_images = torch.cat(image_list, dim=0)

    return latent_array, labels_array, all_images

In [None]:
# Latent codes, labels and images for each group, labelled
# 0 = healthy (training loader), 1 = early, 2 = mid, 3 = late diseased.
latent_healthy, labels_healthy, images_healthy = get_latent_representations(model, dataloader_train_hsi, device, assigned_label=0)
latent_early, labels_early, images_early = get_latent_representations(model, dataloader_early_diseased_hsi, device, assigned_label=1)
latent_mid, labels_mid, images_mid = get_latent_representations(model, dataloader_mid_diseased_hsi, device, assigned_label=2)
latent_late, labels_late, images_late = get_latent_representations(model, dataloader_late_diseased_hsi, device, assigned_label=3)

In [None]:
# Stack all groups into single arrays for UMAP; the order matches across latents, labels and images.
latent_all = np.concatenate((latent_healthy, latent_early, latent_mid, latent_late))
labels_all = np.concatenate((labels_healthy, labels_early, labels_mid, labels_late))
images_all = torch.cat((images_healthy, images_early, images_mid, images_late), dim=0)

In [None]:
# Interactive UMAP of all latent codes. The small n_neighbors emphasises local
# structure, given the small diseased groups.
plot_umap_interactive(latent_all, images_all, labels=labels_all, n_neighbors=3, min_dist=0.05)

In [None]:
# Define the model architecture and optimizer again
model = CNN3DAE(layers_list=[32,64,128], input_dim=1, kernel_sizes=3, strides=(1, 2, 2), paddings=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Load the checkpoint
model, optimizer, train_losses, val_losses, last_epoch = load_checkpoint(model, '/home/r0979317/Documents/Thesis_Strawberries/models/first_model.pth', device, optimizer)

In [None]:
# Same reconstruction check as before, now for the restored first CNN3DAE model.
visualize_reconstruction_error(model, dataloader_validation_hsi, 2, device=device)