# Search

## Random Search

In [None]:
import sys

sys.path.append("..")  # Add parent directory to path

from src.dataset.dataset import build_hsi_testloader, get_wavelengths_from_metadata
from src.util.segmentation_util import (
    build_segmentation_model,
    evaluate_model,
    load_model,
)
from src.util.constants import MODELS_DIR
import numpy as np
import torch
import random
import gc


def get_interval_from_wavelenths(start, end, wavelength_array=None):
    """Return the first and last band indices whose wavelength lies in ``[start, end]``.

    Args:
        start: Lower wavelength bound (inclusive), in the units of the metadata.
        end: Upper wavelength bound (inclusive).
        wavelength_array: Optional per-band wavelengths. When omitted they are
            read via ``get_wavelengths_from_metadata()``; pass them in to avoid
            re-reading the metadata on every call.

    Returns:
        Tuple ``(first_index, last_index)`` of the matching bands.

    Raises:
        ValueError: If no band falls inside ``[start, end]``.
    """
    if wavelength_array is None:
        wavelength_array = get_wavelengths_from_metadata()
    wavelength_array = np.asarray(wavelength_array)
    indices = np.where((wavelength_array >= start) & (wavelength_array <= end))[0]
    if indices.size == 0:
        # Without this check indices[0] fails with a bare IndexError and no context.
        raise ValueError(f"No wavelengths found in the interval [{start}, {end}]")
    # NOTE(review): treating (first, last) as a contiguous band range assumes the
    # wavelengths are sorted ascending — confirm against the metadata.
    return indices[0], indices[-1]


# Define the intervals (inclusive band-index ranges for each colour slot).
# NOTE(review): the 500 nm and 600 nm boundaries are shared by two intervals, so a
# band lying exactly on a boundary can be drawn for two slots — confirm intended.
red_interval = get_interval_from_wavelenths(600, 1000)
green_interval = get_interval_from_wavelenths(500, 600)
blue_interval = get_interval_from_wavelenths(400, 500)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

segmentation_model = build_segmentation_model(
    encoder="timm-regnetx_320", architecture="Linknet", device=device, in_channels=3
)
model_path = MODELS_DIR / "serene-sweep-9.pth"
segmentation_model = load_model(segmentation_model, model_path, device=device)
segmentation_model.eval()

num_random_samples = 50
# Fixed seed so the reported best combination can be reproduced.
random_seed = 0
best_score = -1.0
best_bands = None
wavelengths = get_wavelengths_from_metadata()

# Never try to evaluate more distinct triplets than actually exist.
search_space_size = int(
    (red_interval[1] - red_interval[0] + 1)
    * (green_interval[1] - green_interval[0] + 1)
    * (blue_interval[1] - blue_interval[0] + 1)
)
num_trials = min(num_random_samples, search_space_size)
rng = random.Random(random_seed)
tried_bands = set()

while len(tried_bands) < num_trials:
    # Randomly pick one band from each interval (randint is inclusive at both ends).
    red_band = rng.randint(red_interval[0], red_interval[1])
    green_band = rng.randint(green_interval[0], green_interval[1])
    blue_band = rng.randint(blue_interval[0], blue_interval[1])
    bands = (red_band, green_band, blue_band)
    if bands in tried_bands:
        # This combination was already evaluated; skip the expensive re-run.
        continue
    tried_bands.add(bands)

    print(
        f"Trying bands: red={wavelengths[red_band]}, green={wavelengths[green_band]}, blue={wavelengths[blue_band]}"
    )
    testloader_target = build_hsi_testloader(
        batch_size=1,
        rgb=True,
        rgb_channels=bands,
    )

    # Evaluate the model on these chosen channels
    with torch.no_grad():
        _, _, _, _, dice_score = evaluate_model(
            segmentation_model, testloader_target, device, with_wandb=False
        )

    # Update best found so far (a NaN score never compares greater, so it is skipped)
    if dice_score > best_score:
        best_score = dice_score
        best_bands = bands
        print(f"New best score: {best_score:.4f}, bands: {best_bands}")

    # Release the loader and cached GPU memory before the next evaluation.
    del testloader_target
    torch.cuda.empty_cache()
    gc.collect()

if best_bands is None:
    print("No band combination improved on the initial score.")
else:
    print(f"Best band combination: {best_bands} with Dice = {best_score:.4f}")
    print(
        "Best wavelengths: "
        f"red={wavelengths[best_bands[0]]}, "
        f"green={wavelengths[best_bands[1]]}, "
        f"blue={wavelengths[best_bands[2]]}"
    )

Trying bands: red=768.635, green=564.914, blue=468.147
Precision: nan, Recall: 0.1095, F1 Score: 0.1112, Dice Score: 0.1112, Accuracy: 0.7894
New best score: 0.1112, bands: (506, 226, 93)
Trying bands: red=773.001, green=527.808, blue=466.692


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from src.util.segmentation_util import build_criterion, build_segmentation_model
from src.dataset.dataset import build_hsi_testloader

# Prefer GPU 2 (as originally configured), but fall back to the first GPU when
# fewer than three are visible, and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class ChannelSelectionNet(nn.Module):
    """Learnable, input-independent weighting over the spectral channels.

    Holds one trainable logit per channel, initialised to zero (a uniform
    weighting). ``forward`` returns the softmax of those logits, so the output
    is a probability vector over channels, not raw logits. The weighting is
    shared by the whole dataset; per-image gating would require passing
    features into ``forward``.

    Args:
        num_channels: Number of channels to weight (826 by default).
    """

    def __init__(self, num_channels=826):
        super().__init__()
        # Zero logits -> every channel starts with the same weight.
        self.logits = nn.Parameter(torch.zeros(num_channels))

    def forward(self):
        """Return softmax-normalised channel weights, shape ``[num_channels]``."""
        return torch.softmax(self.logits, dim=0)

testloader_target = build_hsi_testloader(
    batch_size=1,
)
# NOTE(review): the channel selection is optimised on the *test* loader, so any
# later evaluation on that same split is optimistically biased — consider a
# separate train/validation loader.

# Instantiate gating network
gating_net = ChannelSelectionNet(num_channels=826).to(device)

# Build / load the frozen segmentation model (SMP Linknet, etc.)
segmentation_model = build_segmentation_model(
    encoder='timm-regnetx_320',
    architecture='Linknet',
    device=device,
    in_channels=3
)
model_path = MODELS_DIR / "serene-sweep-9.pth"
segmentation_model = load_model(segmentation_model, model_path, device=device)
segmentation_model.eval()
for param in segmentation_model.parameters():
    param.requires_grad = False  # freeze the segmentation model

dice_loss = build_criterion('Dice')
num_epochs = 5
# Weight of the sparsity regulariser. An L1 penalty on softmax output does
# nothing: the weights are positive and sum to 1, so sum(|alpha|) == 1 is a
# constant with zero gradient. Penalising the entropy of alpha instead pushes
# the weight mass onto few channels.
entropy_lambda = 1e-3
optimizer = optim.Adam(gating_net.parameters(), lr=1e-2)

for epoch in range(num_epochs):
    gating_net.train()

    for batch_idx, (hsi_image, mask) in enumerate(testloader_target):
        hsi_image = hsi_image.to(device)  # presumably [B, 826, H, W] — TODO confirm
        mask = mask.to(device)  # must match the layout dice_loss expects for preds

        optimizer.zero_grad()

        # Forward gating net: softmax weights over channels, shape [826]
        alpha = gating_net()

        # Hard top-3 selection.
        topk_vals, topk_idx = torch.topk(alpha, 3)
        # Feed the bands in descending index order. This matches the
        # (red, green, blue) order passed as rgb_channels in the random-search
        # cell, assuming band index increases with wavelength. It also keeps a
        # band's input slot from flipping when the weight ranking changes.
        order = torch.argsort(topk_idx, descending=True)
        topk_idx = topk_idx[order]
        topk_vals = topk_vals[order]

        # Indexing with integer indices carries no gradient back to alpha, so
        # without help the gating net never learns. Straight-through trick:
        # topk_vals / topk_vals.detach() is exactly 1 in the forward pass, but
        # its gradient routes the Dice loss back to the selected logits.
        st_scale = (topk_vals / topk_vals.detach()).view(1, -1, 1, 1)
        selected_channels = hsi_image[:, topk_idx, :, :] * st_scale  # [B, 3, H, W]

        # Forward pass in the frozen segmentation model (no torch.no_grad(): the
        # gradient must flow through it back to gating_net; its own weights have
        # requires_grad=False, so they are not updated).
        preds = segmentation_model(selected_channels)

        # NOTE(review): whether build_criterion('Dice') expects logits depends on
        # its configuration — confirm it matches the segmentation model output.
        loss_dice = dice_loss(preds, mask)

        # Entropy of alpha (clamped to avoid log(0)); minimising it sharpens alpha.
        entropy = -(alpha * torch.log(alpha.clamp_min(1e-12))).sum()
        loss = loss_dice + entropy_lambda * entropy

        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch + 1}/{num_epochs}] Batch [{batch_idx}] "
              f"DiceLoss: {loss_dice.item():.4f} | Loss + Entropy: {loss.item():.4f}")

Epoch [0/5] Batch [0] DiceLoss: 0.7992 | Loss + L1: 0.8002
Epoch [0/5] Batch [1] DiceLoss: 0.9589 | Loss + L1: 0.9599
Epoch [0/5] Batch [2] DiceLoss: 0.7920 | Loss + L1: 0.7930
Epoch [0/5] Batch [3] DiceLoss: 0.8386 | Loss + L1: 0.8396
Epoch [0/5] Batch [4] DiceLoss: 0.7204 | Loss + L1: 0.7214
Epoch [1/5] Batch [0] DiceLoss: 0.7992 | Loss + L1: 0.8002
Epoch [1/5] Batch [1] DiceLoss: 0.9589 | Loss + L1: 0.9599
Epoch [1/5] Batch [2] DiceLoss: 0.7920 | Loss + L1: 0.7930
Epoch [1/5] Batch [3] DiceLoss: 0.8386 | Loss + L1: 0.8396
Epoch [1/5] Batch [4] DiceLoss: 0.7204 | Loss + L1: 0.7214
Epoch [2/5] Batch [0] DiceLoss: 0.7992 | Loss + L1: 0.8002
Epoch [2/5] Batch [1] DiceLoss: 0.9589 | Loss + L1: 0.9599
Epoch [2/5] Batch [2] DiceLoss: 0.7920 | Loss + L1: 0.7930
Epoch [2/5] Batch [3] DiceLoss: 0.8386 | Loss + L1: 0.8396
Epoch [2/5] Batch [4] DiceLoss: 0.7204 | Loss + L1: 0.7214
Epoch [3/5] Batch [0] DiceLoss: 0.7992 | Loss + L1: 0.8002
Epoch [3/5] Batch [1] DiceLoss: 0.9589 | Loss + L1: 0.95

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from src.util.segmentation_util import build_criterion, build_segmentation_model
from src.dataset.dataset import build_hsi_testloader


class ChannelSelectionNet(nn.Module):
    """Learn three soft mixtures of the spectral channels.

    ``alpha`` is a ``[3, num_channels]`` matrix of trainable logits, initialised
    to zero. ``forward`` normalises each row with a softmax, so every row is a
    probability vector over the channels. The training loop uses each row to
    collapse the spectral cube into one of three output channels.

    Args:
        num_channels: Number of input channels to mix (826 by default).
    """

    def __init__(self, num_channels=826):
        super().__init__()
        # Zero logits -> each of the 3 rows starts as a uniform mixture.
        self.alpha = nn.Parameter(torch.zeros(3, num_channels))

    def forward(self):
        """Return the row-wise softmax of ``alpha``, shape ``[3, num_channels]``."""
        return torch.softmax(self.alpha, dim=1)


# Prefer GPU 2 (as originally configured), but fall back to the first GPU when
# fewer than three are visible, and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build gating net
gating_net = ChannelSelectionNet(num_channels=826).to(device)

# Build and freeze segmentation model
segmentation_model = build_segmentation_model(
    encoder="timm-regnetx_320", architecture="Linknet", device=device, in_channels=3
)
model_path = MODELS_DIR / "serene-sweep-9.pth"
segmentation_model = load_model(segmentation_model, model_path, device=device)
segmentation_model.eval()
for param in segmentation_model.parameters():
    param.requires_grad = False

# Explicit 'Dice', matching the previous cell and the "DiceLoss" log label.
dice_loss_fn = build_criterion("Dice")
optimizer = optim.Adam(gating_net.parameters(), lr=1e-2)

# NOTE(review): this "trainloader" is built from the *test* split, so any later
# evaluation on that split is optimistically biased — consider a train/val loader.
trainloader = build_hsi_testloader(batch_size=1)

num_epochs = 10
# Weight of the sparsity regulariser. An L1 penalty on row-softmaxes is the
# constant 3 (each row sums to 1), so it has zero gradient and changes nothing.
# Penalising each row's entropy instead pushes every row toward a few channels.
entropy_lambda = 1e-3

for epoch in range(num_epochs):
    gating_net.train()
    for batch_idx, (hsi_image, mask) in enumerate(trainloader):
        hsi_image = hsi_image.to(device)  # presumably [B, 826, H, W] — TODO confirm
        mask = mask.to(device)  # [B, H, W] or [B, 1, H, W]; must match dice_loss_fn

        optimizer.zero_grad()
        alpha_soft = gating_net()  # shape [3, 826], each row sums to 1

        # Weighted sum of the 826 channels -> 3 channels, shape [B, 3, H, W]
        out_3channels = torch.einsum("bchw,rc->brhw", hsi_image, alpha_soft)

        preds = segmentation_model(out_3channels)
        loss_dice = dice_loss_fn(preds, mask)

        # Sum of the per-row entropies (clamped to avoid log(0)).
        entropy = -(alpha_soft * torch.log(alpha_soft.clamp_min(1e-12))).sum()
        loss = loss_dice + entropy_lambda * entropy

        loss.backward()
        optimizer.step()

        print(
            f"Epoch: {epoch}, Batch: {batch_idx}, DiceLoss: {loss_dice.item():.4f}, TotalLoss: {loss.item():.4f}"
        )


# After training, read the softmax weights back from the gating net; ideally each
# of the 3 rows is concentrated on a single channel.
gating_net.eval()
with torch.no_grad():
    alpha_soft = gating_net()  # shape [3, 826]

# For each of the 3 rows, report the channel with the largest weight
for i, row_weights in enumerate(alpha_soft):
    best_idx = torch.argmax(row_weights).item()
    print(f"Channel for row {i}: {best_idx}")

Epoch: 0, Batch: 0, DiceLoss: 0.7105, TotalLoss: 0.7135
Epoch: 0, Batch: 1, DiceLoss: 0.8017, TotalLoss: 0.8047
Epoch: 0, Batch: 2, DiceLoss: 0.6638, TotalLoss: 0.6668
Epoch: 0, Batch: 3, DiceLoss: 0.4322, TotalLoss: 0.4352
Epoch: 0, Batch: 4, DiceLoss: 0.5988, TotalLoss: 0.6018
Epoch: 1, Batch: 0, DiceLoss: 0.7084, TotalLoss: 0.7114
Epoch: 1, Batch: 1, DiceLoss: 0.8007, TotalLoss: 0.8037
Epoch: 1, Batch: 2, DiceLoss: 0.6623, TotalLoss: 0.6653
Epoch: 1, Batch: 3, DiceLoss: 0.4312, TotalLoss: 0.4342
Epoch: 1, Batch: 4, DiceLoss: 0.5954, TotalLoss: 0.5984
Epoch: 2, Batch: 0, DiceLoss: 0.7066, TotalLoss: 0.7096
Epoch: 2, Batch: 1, DiceLoss: 0.7999, TotalLoss: 0.8029
Epoch: 2, Batch: 2, DiceLoss: 0.6609, TotalLoss: 0.6639
Epoch: 2, Batch: 3, DiceLoss: 0.4302, TotalLoss: 0.4332
Epoch: 2, Batch: 4, DiceLoss: 0.5919, TotalLoss: 0.5949
Epoch: 3, Batch: 0, DiceLoss: 0.7048, TotalLoss: 0.7078
Epoch: 3, Batch: 1, DiceLoss: 0.7991, TotalLoss: 0.8021
Epoch: 3, Batch: 2, DiceLoss: 0.6596, TotalLoss: