# Calculate the Thresholds for each Image Channel

Author(s): Peer Schütt

In [None]:
import sys, os

# Add the parent of the current working directory to the import path —
# presumably where the project modules imported below (autoencoder,
# MultispectralImageDataset, utils) live; verify against the repo layout.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)
import json
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from autoencoder import Autoencoder
from MultispectralImageDataset import MultispectralImageDataset
from utils import load_all_img_paths, loss_func, set_global_random_seed
# NOTE(review): torch and numpy are imported a second time here; harmless but redundant.
import torch
import matplotlib.pyplot as plt
import numpy as np

# Device used for the model and the datasets throughout this notebook.
DEVICE = "cpu"
from pathlib import Path

In [None]:
# Select the trained autoencoder to derive thresholds for (uncomment exactly one name).
pretrained_model_folder = "trained_models/"

# pretrained_model_name = "2025-04-18_16-14-34_musero_autoencoder_NIR" # NIR 1024
# pretrained_model_name = "2025-04-19_12-45-55_musero_autoencoder_NIR" # NIR 2048
# pretrained_model_name = "2025-04-18_09-48-55_musero_autoencoder_VIS" # VIS 1024
pretrained_model_name = "2025-04-24_13-32-54_musero_autoencoder_VIS" # VIS 2048
pretrained_model_full_path = f"{pretrained_model_folder}{pretrained_model_name}"

# A trained model is stored as <name>.json (contains the "config" dict) and <name>.pth (weights).
pretrained_model_location_dict = f"{pretrained_model_full_path}.json"
pretrained_model_location_weights = f"{pretrained_model_full_path}.pth"

# Per-model folder for cached losses and computed thresholds. Derived from
# pretrained_model_folder so both stay in sync if the folder is changed.
model_precomputes_save_folder_path = f"{pretrained_model_folder}{pretrained_model_name}/"
Path(model_precomputes_save_folder_path).mkdir(parents=True, exist_ok=True)

# JSON is UTF-8 by specification; pass it explicitly instead of relying on the locale default.
with open(pretrained_model_location_dict, encoding="utf-8") as f:
    whole_dict = json.load(f)
config = whole_dict["config"]

# Seed the global RNGs from the training config so the random train/val split
# further below is reproducible.
set_global_random_seed(config)


## Load or compute the per-pixel loss values for the train/val/test sets

In [None]:
# Per-split cache files of per-pixel losses. The cached branch is taken only
# when all three files exist.
train_losses_path = f"{model_precomputes_save_folder_path}all_train_losses.npy"
val_losses_path = f"{model_precomputes_save_folder_path}all_val_losses.npy"
test_losses_path = f"{model_precomputes_save_folder_path}all_test_losses.npy"

if all(os.path.isfile(p) for p in (train_losses_path, val_losses_path, test_losses_path)):
    print("Loading precomputed losses!")
    # NOTE(review): allow_pickle=True is only needed for object arrays; loading a
    # pickled .npy can execute arbitrary code, so only load trusted files here.
    all_train_losses = np.load(train_losses_path, allow_pickle=True)
    print("Loaded train losses!")
    all_val_losses = np.load(val_losses_path, allow_pickle=True)
    print("Loaded val losses!")
    all_test_losses = np.load(test_losses_path, allow_pickle=True)
    print("Loaded test losses!")

else:
    print("Computing new losses!")

    transform = transforms.Compose([
        transforms.CenterCrop([config['img_size_x'], config['img_size_y']])
    ])
    model = Autoencoder(**config).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (DEVICE is hard-coded to "cpu").
    model.load_state_dict(
        torch.load(pretrained_model_location_weights, map_location=DEVICE, weights_only=True)
    )
    model.eval()

    # Evaluation-time dataset settings: overfit and augment_data disabled.
    dataset_params = {
        'transform': transform,
        'img_type': config["img_type"],
        'channels_to_use': config['channels_to_use'],
        'device': DEVICE,
        'overfit': False,
        'augment_data': False,
    }

    # Shared DataLoader settings. Pinned memory only helps host->GPU copies,
    # so it is disabled when running on the CPU.
    dataloader_params = {
        'batch_size': config['batch_size'],
        'num_workers': 4,
        'pin_memory': DEVICE != "cpu",
    }

    from utils import sample_data_paths  # its values are passed to load_all_img_paths as folders

    # Train and val both come from the "train" split of every sample folder.
    train_img_paths = load_all_img_paths(
        list(sample_data_paths.values()), config["img_type"], split="train"
    )
    train_val_dataset = MultispectralImageDataset(train_img_paths, **dataset_params)

    # ceil() keeps at least one validation sample for a non-empty dataset and val_split > 0.
    dataset_size = len(train_val_dataset)
    val_size = int(np.ceil(dataset_size * config['val_split']))
    train_size = dataset_size - val_size

    # NOTE(review): random_split draws from the global torch RNG seeded by
    # set_global_random_seed(config). The resulting val set matches the one used
    # during training only if the RNG state here equals the one at split time in
    # training (model construction above may also consume RNG) — confirm.
    train_set, val_set = torch.utils.data.random_split(
        train_val_dataset,
        [train_size, val_size]
    )

    train_data_loader = DataLoader(train_set, shuffle=True, **dataloader_params)
    val_data_loader = DataLoader(val_set, shuffle=False, **dataloader_params)

    # Test images come from the "test" split of the same sample folders.
    test_img_paths = load_all_img_paths(
        list(sample_data_paths.values()), config["img_type"], split="test"
    )
    test_dataset = MultispectralImageDataset(test_img_paths, **dataset_params)
    test_data_loader = DataLoader(test_dataset, shuffle=False, **dataloader_params)

    def channel_loss(model, dataloader, device=DEVICE):
        """Run ``model`` over ``dataloader`` and collect the per-pixel losses.

        Returns a list with one CPU tensor per batch, as produced by
        ``loss_func(raw_images, outputs, per_img_and_pixel=True)``.
        """
        model.eval()
        batch_losses = []
        with torch.no_grad():
            # Each batch is a 3-tuple; only the middle element (raw images) is used.
            for _, raw_images, _ in tqdm(dataloader):
                raw_images = raw_images.to(device)
                outputs = model(raw_images)
                batch_losses.append(
                    loss_func(raw_images, outputs, per_img_and_pixel=True).cpu()
                )
        return batch_losses

    def stack_channel_losses(batch_losses, n_channels=9):
        """Concatenate per-batch loss tensors into one ``(N, n_channels, H, W)`` numpy array.

        Each batch tensor is reshaped to ``(-1, n_channels, H, W)``, with H and W
        taken from its dims 2 and 3, before concatenating along dim 0.
        """
        return torch.cat(
            [t.reshape(-1, n_channels, t.shape[2], t.shape[3]) for t in batch_losses],
            dim=0,
        ).numpy()

    # all_train_losses = stack_channel_losses(channel_loss(model, train_data_loader))
    all_val_losses = stack_channel_losses(channel_loss(model, val_data_loader))
    # all_test_losses = stack_channel_losses(channel_loss(model, test_data_loader))

    # NOTE: nothing is written to the cache here, so the loading branch above is
    # only taken if the .npy files were produced elsewhere. To populate it,
    # enable the train/test computations above together with these lines:
    # np.save(train_losses_path, all_train_losses, allow_pickle=True)
    # np.save(val_losses_path, all_val_losses, allow_pickle=True)
    # np.save(test_losses_path, all_test_losses, allow_pickle=True)

## Plot the per-channel loss distributions

In [None]:
# One histogram of per-pixel validation losses per channel, on a 3x3 grid
# (this notebook works with 9 channels throughout).
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
axes = axes.flatten()
for i, ax in enumerate(axes):
    # Every 100th image (axis 0), all pixels of channel i — subsampled to keep plotting fast.
    # train_pixels = all_train_losses[::100, i, :, :].flatten()
    val_pixels = all_val_losses[::100, i, :, :].flatten()
    # test_pixels = all_test_losses[::100, i, :, :].flatten()

    # Overlaid, semi-transparent histograms (train/test disabled for now).
    # ax.hist(train_pixels, bins=50, alpha=0.5, color='blue', label='Train')
    # ax.hist(test_pixels, bins=50, alpha=0.5, color='red', label='Test')
    ax.hist(val_pixels, bins=50, alpha=0.5, color='green', label='Validation')

    ax.set_title(f'Channel {i} loss distribution')
    # The plotted values are reconstruction losses, not raw pixel intensities.
    ax.set_xlabel('Per-pixel loss')
    ax.set_ylabel('Frequency')
    # Bars taller than 1000 are cut off — presumably so the sparse high-loss tail stays visible.
    ax.set_ylim(0, 1000)
    ax.legend()
plt.tight_layout()
plt.show()

## Calculate percentiles of the loss values

In [None]:
# Bring the channel axis to the front, then flatten the rest: row c of the
# result holds every loss value of channel c across all images and pixels.
# reshaped_train = all_train_losses.swapaxes(0, 1).reshape(9, -1)
reshaped_val = all_val_losses.swapaxes(0, 1).reshape(9, -1)
# reshaped_test = all_test_losses.swapaxes(0, 1).reshape(9, -1)

In [None]:
# train_95 = np.percentile(reshaped_train[:,::100], 95, axis=1)
# train_99 = np.percentile(reshaped_train[:,::100], 99, axis=1)

In [None]:
# Per-channel 95th and 99th percentiles of the validation losses, estimated
# from every 100th value of each channel row; one call yields both rows.
val_95, val_99 = np.percentile(reshaped_val[:, ::100], [95, 99], axis=1)

In [None]:
# test_95 = np.percentile(reshaped_test[:,::100], 95, axis=1)
# test_99 = np.percentile(reshaped_test[:,::100], 99, axis=1)

In [None]:
# Report the per-channel 95th-percentile loss thresholds (train/test disabled).
print("95th Percentile")
# print("Train: ", train_95)
print("Val: ", val_95)
# print("Test: ", test_95)

In [None]:
# Report the per-channel 99th-percentile loss thresholds (train/test disabled).
print("99th Percentile")
# print("Train: ", train_99)
print("Val: ", val_99)
# print("Test: ", test_99)

In [None]:
# Persist the per-channel validation thresholds into the model's precompute
# folder. np.save appends ".npy", so the files are thresholds_val_99.npy and
# thresholds_val_95.npy. Saved in the same order as before: 99th, then 95th.
for threshold_name, threshold_values in (
    ("thresholds_val_99", val_99),
    ("thresholds_val_95", val_95),
):
    np.save(f"{model_precomputes_save_folder_path}{threshold_name}", threshold_values)

# Train/test variants (require the corresponding percentiles to be computed):
# np.save(f"{model_precomputes_save_folder_path}thresholds_train_99", train_99)
# np.save(f"{model_precomputes_save_folder_path}thresholds_test_99", test_99)
# np.save(f"{model_precomputes_save_folder_path}thresholds_train_95", train_95)
# np.save(f"{model_precomputes_save_folder_path}thresholds_test_95", test_95)