# Search: random search over RGB band combinations

In [1]:
from dataset import build_hsi_testloader, get_wavelengths_from_metadata
from segmentation_util import build_segmentation_model, evaluate_model
import numpy as np
import torch
import random
import gc


def get_interval_from_wavelenths(start, end, wavelengths=None):
    """Return the first and last band indices whose wavelength lies in [start, end].

    Args:
        start: Lower wavelength bound (inclusive), in the same units as the
            wavelength metadata.
        end: Upper wavelength bound (inclusive).
        wavelengths: Optional 1-D sequence of per-band wavelengths. When
            omitted, it is loaded with ``get_wavelengths_from_metadata()``.
            Pass it explicitly to avoid reloading the metadata on every call.

    Returns:
        A ``(first_index, last_index)`` tuple of Python ints, both inclusive.

    Raises:
        ValueError: If no band wavelength falls inside ``[start, end]``.

    Note:
        Only the first and last matching indices are returned. If the
        wavelength array is not sorted ascending (TODO confirm it is), indices
        between the two may correspond to out-of-range wavelengths.
    """
    if wavelengths is None:
        wavelengths = get_wavelengths_from_metadata()
    wavelength_array = np.asarray(wavelengths)
    indices = np.flatnonzero((wavelength_array >= start) & (wavelength_array <= end))
    if indices.size == 0:
        # Fail with an explicit message instead of an opaque IndexError from indices[0].
        raise ValueError(f"No wavelengths found in the interval [{start}, {end}]")
    return int(indices[0]), int(indices[-1])


# Define the candidate band-index intervals (inclusive) for each RGB channel.
# NOTE(review): bounds are inclusive on both ends, so a band at exactly 500 or
# 600 would be eligible for two channels — confirm this overlap is intended.
red_interval = get_interval_from_wavelenths(600, 1000)
green_interval = get_interval_from_wavelenths(500, 600)
blue_interval = get_interval_from_wavelenths(400, 500)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

segmentation_model = build_segmentation_model(
    encoder="timm-regnetx_320", architecture="Linknet", device=device, in_channels=3
)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
segmentation_model.load_state_dict(
    torch.load("models/serene-sweep-9.pth", map_location=device)
)
segmentation_model.eval()  # put model in eval mode

num_random_samples = 50  # how many distinct band combinations to evaluate
random_seed = None  # set to an int for a reproducible search
if random_seed is not None:
    random.seed(random_seed)

# Never ask for more distinct combinations than the search space contains,
# otherwise the deduplicating sampler below could never finish.
total_combinations = (
    (red_interval[1] - red_interval[0] + 1)
    * (green_interval[1] - green_interval[0] + 1)
    * (blue_interval[1] - blue_interval[0] + 1)
)
num_trials = min(num_random_samples, int(total_combinations))

best_score = -1.0
best_bands = None
tried_bands = set()
wavelengths = get_wavelengths_from_metadata()
while len(tried_bands) < num_trials:
    # Randomly pick one band from each interval
    bands = (
        random.randint(red_interval[0], red_interval[1]),
        random.randint(green_interval[0], green_interval[1]),
        random.randint(blue_interval[0], blue_interval[1]),
    )
    if bands in tried_bands:
        # Evaluation is the expensive step; don't repeat an identical combination.
        continue
    tried_bands.add(bands)
    red_band, green_band, blue_band = bands

    print(
        f"Trying bands: red={wavelengths[red_band]}, green={wavelengths[green_band]}, blue={wavelengths[blue_band]}"
    )
    testloader_target = build_hsi_testloader(
        batch_size=1,
        rgb=True,
        rgb_channels=(red_band, green_band, blue_band),
    )

    # Evaluate the model on these chosen channels
    with torch.no_grad():
        _, _, _, _, dice_score = evaluate_model(
            segmentation_model, testloader_target, device, with_wandb=False
        )

    # Update best found so far (a NaN score never compares greater, so it is skipped)
    if dice_score > best_score:
        best_score = dice_score
        best_bands = (red_band, green_band, blue_band)
        print(f"New best score: {best_score:.4f}, bands: {best_bands}")

    # Drop Python references first so empty_cache can actually release GPU memory.
    del testloader_target
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

if best_bands is None:
    print("No band combination produced a valid Dice score.")
else:
    print(f"Best band combination: {best_bands} with Dice = {best_score:.4f}")
    print(
        f"Best wavelengths: red={wavelengths[best_bands[0]]}, "
        f"green={wavelengths[best_bands[1]]}, blue={wavelengths[best_bands[2]]}"
    )

Trying bands: red=844.303, green=540.177, blue=425.947
Precision: nan, Recall: 0.0699, F1 Score: 0.0846, Dice Score: 0.0846, Accuracy: 0.7891
New best score: 0.0846, bands: (610, 192, 35)
Trying bands: red=823.931, green=563.459, blue=475.423
Precision: nan, Recall: 0.0683, F1 Score: 0.0803, Dice Score: 0.0803, Accuracy: 0.7849
Trying bands: red=681.326, green=527.808, blue=421.582
Precision: nan, Recall: 0.1319, F1 Score: 0.1217, Dice Score: 0.1217, Accuracy: 0.7927
New best score: 0.1217, bands: (386, 175, 29)
Trying bands: red=839.937, green=571.462, blue=438.316
Precision: nan, Recall: 0.0588, F1 Score: 0.0736, Dice Score: 0.0736, Accuracy: 0.7858
Trying bands: red=896.688, green=529.263, blue=482.698
Precision: nan, Recall: 0.0358, F1 Score: 0.0510, Dice Score: 0.0510, Accuracy: 0.7837
Trying bands: red=710.429, green=590.379, blue=425.22
Precision: nan, Recall: 0.1690, F1 Score: 0.1430, Dice Score: 0.1430, Accuracy: 0.7911
New best score: 0.1430, bands: (426, 261, 34)
Trying band