# Unified Hyperspectral Denoising: Training and Evaluation

This notebook consolidates the training and evaluation workflows of the provided scripts into a single, reproducible document. It inlines only local utilities (datasets, transforms, metrics) while importing model architectures from the `models` package by name.

In [1]:
pip install tensorboardX lmdb

Collecting tensorboardX
  Downloading tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)
Collecting lmdb
  Downloading lmdb-1.7.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (1.4 kB)
Downloading tensorboardx-2.6.4-py3-none-any.whl (87 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lmdb-1.7.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (295 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.1/295.1 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lmdb, tensorboardX
Successfully installed lmdb-1.7.5 tensorboardX-2.6.4
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Environment setup and imports
import os, math, time, json, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose

# Ensure offline behavior in network-restricted environments: when running on
# Kaggle (detected via the /kaggle directory or the KAGGLE_KERNEL_RUN_TYPE
# environment variable), default WANDB_MODE to 'offline' unless already set.
if os.path.exists('/kaggle') or 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
    os.environ.setdefault('WANDB_MODE', 'offline')

# Reproducibility seed and device selection: the same seed is given to
# Python's `random`, NumPy's global RNG and PyTorch.
seed = 2018
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)

Using device: cuda


In [3]:
# Make the local `models` package importable: every existing candidate root is
# put on sys.path, together with any immediate sub-directory that contains a
# `models/__init__.py` (i.e. looks like an unpacked repository folder).
candidate_roots = [
    os.path.abspath('.'),
    '/kaggle/working',
    '/kaggle/input',
    # User-provided dataset hints
    '/kaggle/input/hsi_denoising_all/other/default/5',
]
for root in candidate_roots:
    if not (root and os.path.isdir(root)):
        continue
    if root not in sys.path:
        sys.path.insert(0, root)
    # Scan one level deep under each root for unpacked repository folders.
    try:
        child_names = os.listdir(root)
    except OSError as exc:
        # Best effort: an unreadable directory must not abort the notebook,
        # but report it instead of hiding the failure.
        print(f'Warning: could not scan {root}: {exc}')
        continue
    for name in child_names:
        p = os.path.join(root, name)
        if p in sys.path or not os.path.isdir(p):
            continue
        # If it looks like a repo folder with models/
        if os.path.exists(os.path.join(p, 'models', '__init__.py')):
            sys.path.insert(0, p)

## Utility functions

In [4]:
# Learning-rate utilities and parameter initialization
import time, sys

def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group of `optimizer`.

    The new value is printed first, then stored under the 'lr' key of each
    entry in `optimizer.param_groups`.
    """
    print(f'Adjust Learning Rate => {lr:.4e}')
    for group in optimizer.param_groups:
        group.update(lr=lr)

def display_learning_rate(optimizer):
    """Print the learning rate of each parameter group and return them.

    Returns a list with one entry per group, in group order.
    """
    collected = []
    for position, group in enumerate(optimizer.param_groups):
        rate = group['lr']
        print('learning rate of group %d: %.4e' % (position, rate))
        collected += [rate]
    return collected

# Width of the textual progress bar in characters (read by progress_bar()).
TOTAL_BAR_LENGTH = 50.0
# Time of the previous progress_bar() call; used to compute the per-step time.
_last_time = time.time()
# Time at which the current bar started; progress_bar() resets it when it is
# called with current == 0.
_begin_time = _last_time

def _format_time(seconds):
    days = int(seconds / 3600/24)
    seconds -= days*3600*24
    hours = int(seconds / 3600)
    seconds -= hours*3600
    minutes = int(seconds / 60)
    seconds -= minutes*60
    secondsf = int(seconds)
    millis = int((seconds - secondsf)*1000)
    f, i = '', 1
    if days > 0 and i <= 2: f += str(days) + 'D'; i += 1
    if hours > 0 and i <= 2: f += str(hours) + 'h'; i += 1
    if minutes > 0 and i <= 2: f += str(minutes) + 'm'; i += 1
    if secondsf > 0 and i <= 2: f += str(secondsf) + 's'; i += 1
    if millis > 0 and i <= 2: f += str(millis) + 'ms'; i += 1
    return f or '0ms'

def progress_bar(current, total, msg=None):
    """Draw an in-place text progress bar for step `current` (0-based) of `total`.

    The line is rewritten with a carriage return and shows the bar, the
    1-based step counter, the time since the previous call and the time since
    the bar started (reset whenever `current` is 0). `msg`, when truthy, is
    appended. A newline is emitted once the last step is reached.
    """
    global _last_time, _begin_time
    if current == 0:
        _begin_time = time.time()

    filled = int(TOTAL_BAR_LENGTH * current / total)
    pending = int(TOTAL_BAR_LENGTH - filled) - 1
    bar = ''.join(['[', '=' * filled, '>', '.' * pending, ']'])

    now = time.time()
    step_elapsed = now - _last_time
    total_elapsed = now - _begin_time
    _last_time = now

    fields = [
        f"\r{bar} {current+1}/{total}",
        f"Step: {_format_time(step_elapsed)}",
        f"Tot: {_format_time(total_elapsed)}",
    ]
    if msg:
        fields.append(f"{msg}")
    sys.stdout.write(" | ".join(fields))
    if current >= total - 1:
        sys.stdout.write('\n')
    sys.stdout.flush()

def init_params(net, init_type='kn'):
    """Re-initialise the parameters of `net` in place.

    `init_type` selects the scheme for Conv2d/Conv3d weights: 'kn' / 'ku'
    (Kaiming normal / uniform, fan_out), 'xn' / 'xu' (Xavier normal /
    uniform). With 'edsr' nothing is touched. For every other value the
    convolution weights are left as they are, but biases are still zeroed.
    Linear weights are drawn from N(0, 1e-3) with zero bias, and LayerNorm is
    reset to weight 1 / bias 0.
    """
    print('use init scheme:', init_type)
    if init_type == 'edsr':
        return

    conv_weight_schemes = (
        ('kn', lambda w: nn.init.kaiming_normal_(w, mode='fan_out')),
        ('ku', lambda w: nn.init.kaiming_uniform_(w, mode='fan_out')),
        ('xn', nn.init.xavier_normal_),
        ('xu', nn.init.xavier_uniform_),
    )

    for module in net.modules():
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            for scheme_name, initialise in conv_weight_schemes:
                if init_type == scheme_name:
                    initialise(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=1e-3)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

## Dataset utilities

In [5]:
# Data transforms and dataset wrappers
import cv2
from PIL import Image
from scipy.io import loadmat

class AddNoise(object):
    """Additive white Gaussian noise with a fixed standard deviation.

    `sigma` is given on a 0-255 scale and stored divided by 255. Calling the
    object returns `img` plus noise drawn from NumPy's global RNG; `img`
    itself is not modified.
    """

    def __init__(self, sigma):
        self.sigma_ratio = sigma / 255.0

    def __call__(self, img):
        gaussian = np.random.randn(*img.shape)
        return img + gaussian * self.sigma_ratio

class AddNoiseBlindv1(object):
    """Additive Gaussian noise whose level is drawn anew on every call.

    Each call samples sigma uniformly from [min_sigma, max_sigma] (0-255
    scale), divides it by 255 and adds Gaussian noise of that standard
    deviation to `img`. Both draws use NumPy's global RNG.
    """

    def __init__(self, min_sigma, max_sigma):
        self.min_sigma, self.max_sigma = min_sigma, max_sigma

    def __call__(self, img):
        level = np.random.uniform(self.min_sigma, self.max_sigma) / 255
        return img + np.random.randn(*img.shape) * level

class HSI2Tensor(object):
    """Convert a NumPy array into a float32 torch tensor.

    When `use_2dconv` is false a leading axis of size 1 is prepended first;
    otherwise the array is converted with its shape unchanged.
    """

    def __init__(self, use_2dconv):
        self.use_2dconv = use_2dconv

    def __call__(self, hsi):
        array = hsi if self.use_2dconv else hsi[None]
        return torch.from_numpy(array).float()

class LoadMatHSI(object):
    """Turn a loaded .mat mapping into (input, gt[, sigma]) float tensors.

    The arrays stored under `input_key` and `gt_key` are transposed with
    axes (2, 0, 1), optionally passed through `transform`, and converted to
    float32 tensors. With `needsigma` the entry 'sigma' is returned as a
    third tensor.
    """

    def __init__(self, input_key, gt_key, needsigma=False, transform=None):
        self.input_key = input_key
        self.gt_key = gt_key
        self.needsigma = needsigma
        self.transform = transform

    def __call__(self, mat):
        # Move the last axis to the front for both arrays.
        input_arr = mat[self.input_key][:].transpose((2,0,1))
        if self.transform:
            input_arr = self.transform(input_arr)
        gt_arr = mat[self.gt_key][:].transpose((2,0,1))
        if self.transform:
            gt_arr = self.transform(gt_arr)

        tensors = (torch.from_numpy(input_arr).float(), torch.from_numpy(gt_arr).float())
        if not self.needsigma:
            return tensors
        return tensors + (torch.from_numpy(mat['sigma']).float(),)

class MatDataFromFolder(Dataset):
    """Dataset whose items are the contents of .mat files in a directory.

    Args:
        data_dir: directory that holds the files.
        suffix: only directory entries ending with this suffix are used when
            `fns` is not given.
        fns: optional explicit list of file names (relative to `data_dir`);
            their order is preserved.
        size: optional cap; when it does not exceed the number of files only
            the first `size` files are kept.

    Each item is whatever `scipy.io.loadmat` returns for the file.
    """

    def __init__(self, data_dir, suffix='.mat', fns=None, size=None):
        super().__init__()
        if fns is not None:
            names = list(fns)
        else:
            # os.listdir() returns entries in arbitrary order; sort so that
            # the sample order (and the subset kept by `size`) is the same
            # on every run and machine.
            names = sorted(fn for fn in os.listdir(data_dir) if fn.endswith(suffix))
        self.filenames = [os.path.join(data_dir, fn) for fn in names]
        if size is not None and size <= len(self.filenames):
            self.filenames = self.filenames[:size]

    def __getitem__(self, index):
        return loadmat(self.filenames[index])

    def __len__(self):
        return len(self.filenames)

class ImageTransformDataset(Dataset):
    """Wrap a dataset so that each item becomes an (input, target) pair.

    The target starts as a copy of the raw item (taken before `transform`
    runs); `transform` is applied to the input and `target_transform` to the
    copy. Either transform may be None, in which case that side is returned
    unchanged.
    """

    def __init__(self, dataset, transform, target_transform=None):
        super().__init__()
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        self.length = len(dataset)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        reference = sample.copy()
        img = sample if self.transform is None else self.transform(sample)
        target = reference if self.target_transform is None else self.target_transform(reference)
        return img, target

class TransformDatasetWrapper(Dataset):
    """Dataset view that applies `transform` to every item of `dataset`."""

    def __init__(self, dataset, transform):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        raw_item = self.dataset[idx]
        return self.transform(raw_item)

class LMDBDataset(Dataset):
    """Read float32 patches from an LMDB database.

    The database directory must contain `meta_info.txt`; its first line holds
    three comma-separated integers in parentheses that describe the patch
    size. Records are looked up by the zero-padded, 8-digit ASCII form of
    the index.

    Args:
        db_path: path of the LMDB directory.
        repeat: the dataset reports `repeat` times the number of records;
            indices wrap around modulo the record count.
    """

    def __init__(self, db_path, repeat=1):
        import lmdb
        self.db_path = db_path
        self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = int(txn.stat()['entries'])
        with open(os.path.join(db_path, 'meta_info.txt')) as fin:
            line = fin.readlines()[0]
            size = line.split('(')[1].split(')')[0]
            h, w, c = [int(s) for s in size.split(',')]
        # NOTE(review): `width` receives the value parsed as `h` and `height`
        # the one parsed as `w`; kept exactly as before — confirm against the
        # code that writes meta_info.txt.
        self.channels = c; self.width = h; self.height = w
        self.repeat = repeat

    def __getitem__(self, index):
        index = index % self.length
        key = f'{index:08d}'.encode('ascii')
        with self.env.begin(write=False) as txn:
            data = txn.get(key)
        if data is None:
            raise KeyError(f'No record {key!r} in LMDB database {self.db_path!r}')
        # np.fromstring's binary mode is deprecated; np.frombuffer is the
        # replacement. It returns a read-only view of the bytes, so copy to
        # hand out a writable, independent array as before.
        flat_x = np.frombuffer(data, dtype=np.float32).copy()
        x = flat_x.reshape(self.channels, self.height, self.width)
        return x

    def __len__(self):
        return self.length * self.repeat

def worker_init_fn(worker_id):
    """Give each DataLoader worker its own NumPy seed.

    The seed is the first word of the current global NumPy RNG state plus
    `worker_id`. The sum is taken with Python integers and reduced modulo
    2**32 so that it always lies in the range `np.random.seed` accepts
    (adding to the raw uint32 state word could overflow or exceed it).
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)

## Metrics and loss functions

In [6]:
# Band-wise metrics and structural similarity
class Bandwise(object):
    """Apply a two-image metric separately to every band.

    Bands are indexed along the third-from-last axis of `X` and `Y`. Each
    band is squeezed, moved to the CPU and converted to NumPy before being
    passed to `index_fn`; the per-band results are returned as a list.
    """

    def __init__(self, index_fn):
        self.index_fn = index_fn

    def __call__(self, X, Y):
        def _band_as_numpy(tensor, ch):
            return torch.squeeze(tensor[..., ch, :, :].data).cpu().numpy()

        scores = []
        for ch in range(X.shape[-3]):
            band_x = _band_as_numpy(X, ch)
            band_y = _band_as_numpy(Y, ch)
            scores.append(self.index_fn(band_x, band_y))
        return scores

# Prefer skimage.metrics when available
try:
    from skimage.metrics import structural_similarity as compare_ssim, peak_signal_noise_ratio as compare_psnr
except ImportError:
    from skimage.measure import compare_ssim, compare_psnr
from functools import partial
# Band-wise PSNR and SSIM. The dynamic range is passed explicitly to both
# metrics: the bands handed over by Bandwise are floating-point arrays, for
# which scikit-image cannot infer a data range (recent versions raise when it
# is omitted for SSIM). The value 1 matches the one already used for PSNR.
cal_bwpsnr = Bandwise(partial(compare_psnr, data_range=1))
cal_bwssim = Bandwise(partial(compare_ssim, data_range=1))

def cal_sam(X, Y, eps=1e-8):
    """Mean spectral angle (in radians) between two tensors.

    Both tensors are squeezed and converted to NumPy; sums are taken over
    axis 0 of the squeezed arrays (presumably the spectral axis — confirm
    against callers). `eps` guards against division by zero.

    The cosine is clipped to [-1, 1] before `arccos`: because `eps` is added
    to the numerator and to each norm separately, the ratio can exceed 1 for
    (near-)identical vectors of small norm, which previously produced NaN
    and poisoned the mean.
    """
    Xn = torch.squeeze(X.data).cpu().numpy()
    Yn = torch.squeeze(Y.data).cpu().numpy()
    dot = np.sum(Xn*Yn, axis=0) + eps
    norm_x = np.sqrt(np.sum(Xn**2, axis=0)) + eps
    norm_y = np.sqrt(np.sum(Yn**2, axis=0)) + eps
    cosine = np.clip(dot / norm_x / norm_y, -1.0, 1.0)
    return np.mean(np.arccos(cosine))

# SSIM loss and SAM loss
import torch.nn.functional as F

def _fspecial_gauss_1d(size, sigma):
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size//2
    g = torch.exp(-(coords**2) / (2*sigma**2))
    g /= g.sum()
    return g.unsqueeze(0).unsqueeze(0)

def gaussian_filter(input, win):
    """Filter a 4-D tensor separably with the per-channel kernel `win`.

    `win` is applied as a grouped convolution (one group per channel, no
    padding) along the last axis, then the two spatial axes are swapped and
    the same kernel is applied again, and the axes are swapped back.
    """
    _, channels, _, _ = input.shape

    def _filter_last_axis(tensor):
        return F.conv2d(tensor, win, stride=1, padding=0, groups=channels)

    swapped = _filter_last_axis(input).transpose(2, 3).contiguous()
    return _filter_last_axis(swapped).transpose(2, 3).contiguous()

def _ssim(X, Y, win, data_range=255, size_average=True, full=False):
    """Compute SSIM between `X` and `Y` using the Gaussian window `win`.

    Local means and second moments are obtained by filtering the
    channel-wise concatenation [X, Y, X*X, Y*Y, X*Y] in one pass. With
    `size_average` a single scalar is returned, otherwise one value per
    sample. With `full` the contrast-structure term is returned as well, as
    `(ssim, cs)`.
    """
    K1, K2 = 0.01, 0.03
    compensation = 1.0
    C1, C2 = (K1 * data_range)**2, (K2 * data_range)**2
    channel = X.shape[1]

    stacked = torch.cat([X, Y, X*X, Y*Y, X*Y], dim=1)
    stacked_win = win.repeat(5, 1, 1, 1).to(X.device, dtype=X.dtype)
    filtered = gaussian_filter(stacked, stacked_win)
    mu1, mu2, mean_xx, mean_yy, mean_xy = (
        filtered[:, k*channel:(k+1)*channel, :, :] for k in range(5)
    )

    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2
    sigma1_sq = compensation * (mean_xx - mu1_sq)
    sigma2_sq = compensation * (mean_yy - mu2_sq)
    sigma12 = compensation * (mean_xy - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    def _reduce(score_map):
        if size_average:
            return score_map.mean()
        return score_map.mean(-1).mean(-1).mean(-1)

    ssim_val = _reduce(ssim_map)
    cs = _reduce(cs_map)
    if full:
        return ssim_val, cs
    return ssim_val

def ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False):
    """Validate the inputs and compute SSIM between two 4-D tensors.

    Raises ValueError when the inputs are not 4-D, differ in tensor type or
    shape, or when `win_size` is even. When `win` is None a Gaussian window
    of `win_size` / `win_sigma` is built with one copy per channel of `X`.
    Returns the SSIM (averaged over the batch when `size_average`), or
    `(ssim, cs)` when `full` is set.
    """
    if X.ndim != 4:
        raise ValueError('Input images must 4-d tensor.')
    if X.type() != Y.type():
        raise ValueError('Input images must have the same dtype.')
    if X.shape != Y.shape:
        raise ValueError('Input images must have the same dimensions.')
    if win_size % 2 != 1:
        raise ValueError('Window size must be odd.')

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma).repeat(X.shape[1], 1, 1, 1)

    # Always request per-sample values, then average here if asked to.
    ssim_val, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, full=True)
    if size_average:
        ssim_val, cs = ssim_val.mean(), cs.mean()
    if full:
        return ssim_val, cs
    return ssim_val

class SSIMLoss(torch.nn.Module):
    """Loss defined as `1 - SSIM(X, Y)`.

    A Gaussian window with one copy per `channel` is built once at
    construction. 5-D inputs are reduced to 4-D by taking index 0 of their
    second axis before the SSIM is evaluated.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=None, size_average=True, channel=3):
        super().__init__()
        gauss = _fspecial_gauss_1d(win_size, win_sigma)
        self.win = gauss.repeat(channel, 1, 1, 1)
        self.size_average = size_average
        self.data_range = data_range

    def forward(self, X, Y):
        if X.ndimension() == 5:
            X = X[:,0,...]
        if Y.ndimension() == 5:
            Y = Y[:,0,...]
        similarity = ssim(X, Y, win=self.win, data_range=self.data_range, size_average=self.size_average)
        return 1 - similarity

class SAMLoss(torch.nn.Module):
    """Spectral-angle loss, returned in degrees.

    The angle between the two inputs is computed per position from sums over
    axis 1, NaN angles are treated as 0, and the total is divided by the
    number of positions whose norm product is positive (or left as the plain
    sum when there is none). 5-D inputs are reduced to 4-D by taking index 0
    of their second axis.

    `size_average` is accepted for interface compatibility but not used.
    """

    def __init__(self, size_average=False):
        super().__init__()

    def forward(self, img_base, img_out):
        if img_base.ndimension() == 5: img_base = img_base[:,0,...]
        if img_out.ndimension() == 5: img_out = img_out[:,0,...]
        sum1 = torch.sum(img_base * img_out, 1)
        sum2 = torch.sum(img_base * img_base, 1)
        sum3 = torch.sum(img_out * img_out, 1)
        t = (sum2 * sum3) ** 0.5
        numlocal = torch.gt(t, 0)
        num = torch.sum(numlocal)
        t = sum1 / t
        angle = torch.acos(t)
        # acos yields NaN for 0/0 and for cosines pushed past 1 by rounding;
        # those positions contribute an angle of 0.
        sumangle = torch.where(torch.isnan(angle), torch.full_like(angle, 0), angle).sum()
        averangle = sumangle if num == 0 else sumangle / num
        # Radians -> degrees. The previous hand-typed constant (3.14159256)
        # had two digits transposed; use the exact value.
        return averangle * 180 / math.pi

## Loss composition

In [7]:
# Combined and consistency losses
class MultipleLoss(nn.Module):
    """Weighted sum of several loss modules evaluated on the same pair.

    `losses` are registered in an nn.ModuleList. When `weight` is missing
    (or falsy) every loss gets the equal weight `1 / len(losses)`.
    """

    def __init__(self, losses, weight=None):
        super().__init__()
        self.losses = nn.ModuleList(losses)
        self.weight = weight or [1/len(self.losses)] * len(self.losses)

    def forward(self, predict, target):
        weighted_terms = (
            loss_fn(predict, target) * coeff
            for coeff, loss_fn in zip(self.weight, self.losses)
        )
        return sum(weighted_terms)

    def extra_repr(self):
        return 'weight={}'.format(self.weight)

class L1Consist(nn.Module):
    """Placeholder for a reconstruction-plus-consistency loss.

    The first two entries of `losses` are stored as `loss1` and `loss_cons`;
    `weight` defaults to equal weights. `forward` evaluates the first loss
    and then always raises RuntimeError, because the contrastive consistency
    term is not available in this repository.
    """

    def __init__(self, losses, weight=None):
        super().__init__()
        first, second = losses[0], losses[1]
        self.loss1 = first
        self.loss_cons = second
        self.weight = weight or [1/len(losses)] * len(losses)

    def forward(self, predict, target, inputs):
        total_loss = 0 + self.loss1(predict, target) * self.weight[0]
        # Contrastive consistency loss is not available in this repository.
        raise RuntimeError('ContrastLoss is not available; use one of {"l1", "l2", "smooth_l1", "ssim", "l2_ssim", "l2_sam"}.')

def build_criterion(loss_name):
    """Create the loss module that corresponds to `loss_name`.

    Supported names: 'l2', 'l1', 'smooth_l1', 'ssim', 'l2_ssim', 'l2_sam'.
    'cons' and 'cons_l2' raise RuntimeError (they need ContrastLoss, which is
    not present); any other name raises ValueError.
    """
    factories = (
        ('l2', lambda: nn.MSELoss()),
        ('l1', lambda: nn.L1Loss()),
        ('smooth_l1', lambda: nn.SmoothL1Loss()),
        ('ssim', lambda: SSIMLoss(data_range=1, channel=31)),
        ('l2_ssim', lambda: MultipleLoss([nn.MSELoss(), SSIMLoss(data_range=1, channel=31)], weight=[1, 2.5e-3])),
        ('l2_sam', lambda: MultipleLoss([nn.MSELoss(), SAMLoss()], weight=[1, 1e-3])),
    )
    for known_name, make in factories:
        if loss_name == known_name:
            return make()
    if loss_name in ('cons', 'cons_l2'):
        raise RuntimeError('Loss mode %s requires ContrastLoss which is not present in this repository.' % loss_name)
    raise ValueError('Unknown loss_name: %s' % loss_name)

## Configuration

In [8]:
# Paths
dataroot = '/kaggle/input/icvl64-31-hs-v2/content/ICVL64_31.db'  # LMDB path for training patches
val_data_dir = '/kaggle/input/icvl-test-512'  # directory of .mat files for validation
test_dir = '/kaggle/input/icvl-test-512'  # directory of .mat files for testing
# Validation will use val_data_dir directly (no separate eval_dir)
# NOTE(review): val_data_dir and test_dir are the same directory, so validation
# and test metrics are computed on the same files — confirm this is intended.
save_dir = 'checkpoints'
checkpoint_dir = os.path.join(save_dir, 'model')
# Create the output directories up front (no error if they already exist).
os.makedirs(save_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Training hyperparameters
batch_size = 16
learning_rate = 1e-4
weight_decay = 0.0
# When resuming, num_epochs indicates how many ADDITIONAL epochs to run beyond the checkpoint epoch count.
# Example: if the checkpoint is at epoch 10 and num_epochs=1, training will run for epoch 11 only.
num_epochs = 70
clip = 1e6  # max gradient norm handed to clip_grad_norm_ in the (commented-out) training loop
num_workers = 0  # DataLoader workers
no_cuda = False  # True forces the CPU even when CUDA is available (see device selection below)
no_log = True  # not read anywhere in the visible cells

# Control flags
chop = False
resume = True
resumePath = '/kaggle/input/sert-icvl-temp/model_latest.pth'  # Path to a checkpoint with keys {'net', 'optimizer', 'epoch'}; 'epoch' is treated as completed epochs

# Validation cadence
val_every = 5  # run validate_epoch every N epochs; saves best+latest when validation runs

# Model and loss
model_class_or_name = 'sert_base'  # Name of model in models package (callable, e.g., 'sert_base')
init_scheme = 'kn'  # ['kn','ku','xn','xu','edsr']
loss_name = 'l2'  # ['l1','l2','smooth_l1','ssim','l2_ssim','l2_sam']
use_2dconv = True  # will be auto-detected from model if attribute exists

# GPU configuration (mirrors hsi_setup --gpu-ids)
# Provide a comma-separated string like '0' or '0,1' to use multiple GPUs
_gpu_ids_str = '0'
def _parse_gpu_ids(args_str):
    parts = [s.strip() for s in args_str.split(',') if s.strip() != '']
    parsed = []
    for p in parts:
        try:
            v = int(p)
            if v >= 0:
                parsed.append(v)
        except ValueError:
            pass
    return parsed

gpu_ids = _parse_gpu_ids(_gpu_ids_str)
print('GPU IDs:', gpu_ids)

# Noise levels for simulated testing
sigma = None  # not read anywhere in the visible cells
sigma_test = 10  # sigma (0-255 scale) of the Gaussian noise added to validation/test inputs

# Scheduler: in the (commented-out) training loop the learning rate is set to
# base_lr * gamma once the absolute epoch number equals step_decay_epoch.
# NOTE(review): the earlier comment stated that the original script decays the
# LR at epoch 50, but the value configured here is 250 — confirm which is intended.
step_decay_epoch = 250
gamma = 0.1

# Training-time mode toggle
# In the original scripts, the forward pass during training runs with net.eval() set
# because Engine.__step() unconditionally calls self.net.eval().
# Set this to True to match script behavior; set to False for conventional train-mode forwards.
eval_during_train = False
print('eval_during_train =', eval_during_train)

# Meta-learning params (defined in options; not used in these loops)
update_lr = 0.5e-4
meta_lr = 0.5e-4
n_way = 1
k_spt = 2
k_qry = 5
task_num = 16
update_step = 5
update_step_test = 10

# Reproducibility (re-seeds the same three RNGs as the setup cell, with the same value)
seed = 2018
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)

# Device selection (same rule as the setup cell, but `no_cuda` can force the CPU)
device = 'cuda' if (not no_cuda and torch.cuda.is_available()) else 'cpu'
print('Using device:', device)

GPU IDs: [0]
eval_during_train = False
Using device: cuda


## Model initialization and optimization

In [9]:
import models  # import from local models/ package

print(f"=> creating model '{model_class_or_name}'")
net = getattr(models, model_class_or_name)()
# Initialize params as in helper
init_params(net, init_type=init_scheme)
# Auto-detect 2D/3D conv usage if attribute exists
if hasattr(net, 'use_2dconv'):
    use_2dconv = bool(getattr(net, 'use_2dconv'))
print('use_2dconv =', use_2dconv)

# Device and optional DataParallel based on gpu_ids
if device == 'cuda' and isinstance(gpu_ids, list) and len(gpu_ids) > 1:
    net = nn.DataParallel(net.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
else:
    net = net.to(device)

criterion = build_criterion(loss_name).to(device)
optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay, amsgrad=False)
base_lr = learning_rate
print('Criterion:', criterion)
print('Number of parameters:', sum(p.numel() for p in net.parameters())/1e6, 'M')

# Resume support: load checkpoint and lift epoch/optimizer state if requested
start_epoch_completed = 0
if resume and resumePath and os.path.exists(resumePath):
    print('==> Resuming from checkpoint:', resumePath)
    ckpt = torch.load(resumePath, map_location=device)
    # Accept both {'net': state_dict, ...} checkpoints and bare state dicts.
    state = ckpt.get('net', ckpt)
    try:
        load_result = net.load_state_dict(state, strict=False)
    except Exception as e:
        print('Warning: non-strict load of state_dict due to mismatch:', e)
    else:
        # strict=False skips keys that do not match without raising; report
        # them so that a partially loaded model does not go unnoticed.
        if load_result.missing_keys:
            print('Warning: parameters missing from checkpoint:', load_result.missing_keys)
        if load_result.unexpected_keys:
            print('Warning: unexpected parameters in checkpoint:', load_result.unexpected_keys)
    if 'optimizer' in ckpt:
        try:
            optimizer.load_state_dict(ckpt['optimizer'])
        except Exception as e:
            print('Warning: could not load optimizer state:', e)
    start_epoch_completed = int(ckpt.get('epoch', 0))
    print(f'Resumed from completed epoch: {start_epoch_completed}')
elif resume:
    # Resume was requested but there is nothing to load: say so instead of
    # silently continuing from freshly initialised weights.
    print('Warning: resume requested but checkpoint not found:', resumePath)



=> creating model 'sert_base'
3
use init scheme: kn
use_2dconv = True
Criterion: MSELoss()
Number of parameters: 1.905319 M
==> Resuming from checkpoint: /kaggle/input/sert-icvl-temp/model_latest.pth
Resumed from completed epoch: 200


## Dataset preparation

In [10]:
# Training dataset: LMDB patches with blind Gaussian noise
# (AddNoiseBlindv1 draws sigma from [10, 70] on every call; the target goes
# through HSI2Tensor only, so it stays noise-free).
try:
    icvl_dataset = LMDBDataset(dataroot)
    train_transform = Compose([
        AddNoiseBlindv1(10, 70),
        HSI2Tensor(use_2dconv)
    ])
    target_transform = HSI2Tensor(use_2dconv)
    train_dataset = ImageTransformDataset(icvl_dataset, train_transform, target_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device=='cuda'), worker_init_fn=worker_init_fn)
    print('Training samples:', len(train_dataset))
except Exception as e:
    # Best effort: when the training data cannot be set up, keep the notebook
    # usable for evaluation and signal the situation with train_loader = None.
    print('Warning: LMDBDataset unavailable; training may be skipped. Error:', e)
    train_loader = None

# Validation dataset: .mat files with simulated Gaussian noise on inputs only
val_mat_dataset = MatDataFromFolder(val_data_dir)


def _add_val_noise(pair):
    """Add Gaussian noise of level `sigma_test` to the input of an (input, gt) pair.

    The ground truth is passed through untouched. The same noise level is
    used whether the input arrives as a tensor or as an array (the array
    branch previously referenced an undefined name, `sigma_val`).
    """
    inp, gt = pair
    if isinstance(inp, torch.Tensor):
        inp = inp.numpy()
    return AddNoise(sigma_test)(inp), gt


if not use_2dconv:
    # Prepend a singleton axis to both input and gt when the model is not 2D-conv based.
    _val_load_mat = LoadMatHSI(input_key='input', gt_key='gt', transform=lambda x: x[...][None], needsigma=False)
else:
    _val_load_mat = LoadMatHSI(input_key='input', gt_key='gt', needsigma=False)

val_mat_transform = Compose([
    _val_load_mat,
    # Add noise only to the input, keep gt clean
    _add_val_noise,
])

def _apply_val_transform(sample):
    """Run the validation transform and return (input, gt) as float tensors."""
    inp, gt = val_mat_transform(sample)
    # Ensure both are torch tensors
    if isinstance(inp, np.ndarray): inp = torch.from_numpy(inp).float()
    if isinstance(gt, np.ndarray): gt = torch.from_numpy(gt).float()
    return inp, gt

val_dataset_wrapped = TransformDatasetWrapper(val_mat_dataset, _apply_val_transform)
val_loader = DataLoader(val_dataset_wrapped, batch_size=1, shuffle=False, num_workers=1, pin_memory=(device == 'cuda'))
print('Validation samples:', len(val_dataset_wrapped))


# Test dataset: .mat files with simulated noise on inputs only (uses test_dir)
mat_dataset = MatDataFromFolder(test_dir)


def _add_test_noise(pair):
    """Add Gaussian noise of level `sigma_test` to the input; keep gt clean."""
    inp, gt = pair
    array = inp.numpy() if isinstance(inp, torch.Tensor) else inp
    return AddNoise(sigma_test)(array), gt


_test_load_kwargs = dict(input_key='input', gt_key='gt', needsigma=False)
if not use_2dconv:
    # Prepend a singleton axis to both arrays when the model is not 2D-conv based.
    _test_load_kwargs['transform'] = lambda x: x[...][None]
mat_transform = Compose([LoadMatHSI(**_test_load_kwargs), _add_test_noise])


def _apply_mat_transform(sample):
    """Run the test transform and return (input, gt) as float tensors."""
    def _as_float_tensor(value):
        # mat_transform returns numpy arrays when noise added; ensure tensors
        return torch.from_numpy(value).float() if isinstance(value, np.ndarray) else value

    inp, gt = mat_transform(sample)
    return _as_float_tensor(inp), _as_float_tensor(gt)


mat_dataset_wrapped = TransformDatasetWrapper(mat_dataset, _apply_mat_transform)
mat_loader = DataLoader(mat_dataset_wrapped, batch_size=1, shuffle=False, num_workers=1, pin_memory=(device=='cuda'))
print('Test samples:', len(mat_dataset_wrapped))

Training samples: 3200
Validation samples: 50
Test samples: 50


## Training and validation

In [11]:
# Validation and testing routines
@torch.no_grad()
def _forward_step(net, inputs, targets):
    return net(inputs.float())


def validate_epoch(net, loader, name='val'):
    """Evaluate `net` on `loader` and return averaged image-quality metrics.

    The network is switched to eval mode for the whole pass. For every batch
    the criterion loss and band-wise PSNR are shown on the progress bar, and
    PSNR, RMSE (on a 0-255 scale), a global per-band SSIM, SAM (degrees) and
    ERGAS are computed with NumPy. A summary is printed at the end.

    NOTE(review): the per-band indexing after `squeeze()` presumes a single
    sample per batch; the loaders built in this notebook use batch_size=1.

    Args:
        net: model to evaluate.
        loader: iterable of (inputs, targets) batches.
        name: label used in the printed summary header.

    Returns:
        dict with keys 'psnr', 'rmse', 'ssim', 'sam', 'ergas' and 'loss';
        each value is a float, or None when the loader produced no batches.
    """
    # Evaluate in inference mode for metric stability
    was_training = net.training
    net.eval()
    validate_loss = 0.0
    total_psnr = 0.0
    total_sam = 0.0
    RMSE, SSIM, SAM, ERGAS, PSNR = [], [], [], [], []
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = _forward_step(net, inputs, targets)
        loss_data = criterion(outputs, targets).item()
        psnr = np.mean(cal_bwpsnr(outputs, targets))
        sam = cal_sam(outputs, targets)
        validate_loss += loss_data
        total_psnr += psnr
        total_sam += sam
        avg_loss = validate_loss / (batch_idx+1)
        avg_psnr = total_psnr / (batch_idx+1)
        progress_bar(batch_idx, len(loader), 'Loss: %.4e | PSNR: %.4f | AVGPSNR: %.4f' % (avg_loss, psnr, avg_psnr))
        # Scalar metrics
        psnr_bands = []
        h, w = inputs.shape[-2:]
        band = inputs.shape[-3]
        result = outputs.squeeze().detach().cpu().numpy()
        img = targets.squeeze().detach().cpu().numpy()
        for k in range(band):
            psnr_bands.append(10*np.log10((h*w)/np.sum((result[k]-img[k])**2)))
        PSNR.append(np.mean(psnr_bands))
        mse = np.sum((result-img)**2) / (band*h*w) * 255*255
        RMSE.append(np.sqrt(mse))
        ssim_vals = []
        k1, k2 = 0.01, 0.03
        for k in range(band):
            cov = np.cov(result[k].reshape(h*w), img[k].reshape(h*w))[0,1]
            ssim_vals.append((2*np.mean(result[k])*np.mean(img[k])+k1**2) * (2*cov + k2**2) / (np.mean(result[k])**2+np.mean(img[k])**2+k1**2) / (np.var(result[k])+np.var(img[k])+k2**2))
        SSIM.append(np.mean(ssim_vals))
        temp = (np.sum(result*img, 0) + np.spacing(1)) /(np.sqrt(np.sum(result**2, 0) + np.spacing(1))) /(np.sqrt(np.sum(img**2, 0) + np.spacing(1)))
        # Rounding can push the cosine marginally outside [-1, 1]; arccos
        # would then return NaN and poison the mean, so clip first.
        temp = np.clip(temp, -1.0, 1.0)
        SAM.append(np.mean(np.arccos(temp))*180/np.pi)
        ergas = 0.0
        for k in range(band):
            ergas += np.mean((img[k]-result[k])**2)/np.mean(img[k])**2
        ERGAS.append(100*np.sqrt(ergas/band))
    final = {
        'psnr': float(np.mean(PSNR)) if len(PSNR) else None,
        'rmse': float(np.mean(RMSE)) if len(RMSE) else None,
        'ssim': float(np.mean(SSIM)) if len(SSIM) else None,
        'sam': float(np.mean(SAM)) if len(SAM) else None,
        'ergas': float(np.mean(ERGAS)) if len(ERGAS) else None,
        'loss': float(validate_loss/len(loader)) if len(loader) else None
    }
    print('\n' + '='*60)
    print(f' {name.upper()} SUMMARY')
    print('='*60)
    if final['psnr'] is not None:
        print(f"PSNR: {final['psnr']:.4f} dB | SSIM: {final['ssim']:.4f} | SAM: {final['sam']:.4f}° | RMSE: {final['rmse']:.4f} | ERGAS: {final['ergas']:.4f}")
    print('='*60)
    # Restore train mode only if the net was training on entry and training
    # is not configured to run with eval-mode forwards.
    if was_training and not eval_during_train:
        net.train()
    return final


def test_eval(net, loader):
    """Evaluate `net` on `loader` via validate_epoch, labelling the summary 'test'.

    Returns the metrics dictionary produced by validate_epoch.
    """
    return validate_epoch(net, loader, name='test')

In [12]:
# # Training loop with resume-aware epoch counting
# metrics_log = []
# start_time = time.time()
# # start_epoch_completed is set in model setup (0 if not resuming)
# current_epoch_completed = start_epoch_completed
# best_psnr = -float('inf')
# best_ckpt_path = os.path.join(checkpoint_dir, 'model_best.pth')

# if 'train_loader' not in globals():
#     train_loader = None

# if train_loader is None:
#     print('Training loader is unavailable. Skipping training.')
# else:
#     # Run exactly num_epochs additional epochs beyond the checkpoint's completed epoch count
#     for add_epoch_idx in range(num_epochs):
#         # Epoch index to display and save as completed after this loop
#         epoch_to_run = current_epoch_completed + 1
#         # Match script behavior if eval_during_train is true
#         if eval_during_train:
#             net.eval()
#         else:
#             net.train()
#         train_loss_sum = 0.0
#         train_psnr_sum = 0.0
#         for batch_idx, (inputs, targets) in enumerate(train_loader):
#             inputs, targets = inputs.to(device), targets.to(device)
#             optimizer.zero_grad()
#             outputs = net(inputs.float())
#             loss = criterion(outputs, targets)
#             loss.backward()
#             torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
#             optimizer.step()
#             train_loss_sum += loss.item()
#             psnr = np.mean(cal_bwpsnr(outputs, targets))
#             train_psnr_sum += psnr
#             avg_loss = train_loss_sum / (batch_idx+1)
#             avg_psnr = train_psnr_sum / (batch_idx+1)
#             progress_bar(batch_idx, len(train_loader), 'Epoch: %d | AvgLoss: %.4e | Loss: %.4e | PSNR: %4e' % (epoch_to_run, avg_loss, loss.item(), psnr))
#         # Learning-rate step decay based on absolute epoch number
#         if epoch_to_run == step_decay_epoch:
#             adjust_learning_rate(optimizer, base_lr * gamma)
#         # Save latest checkpoint after the epoch is completed
#         model_latest_path = os.path.join(checkpoint_dir, 'model_latest.pth')
#         torch.save({'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_to_run}, model_latest_path)
#         # Periodic validation
#         if (epoch_to_run) % max(1, val_every) == 0 and 'val_loader' in globals() and val_loader is not None:
#             val_stats = validate_epoch(net, val_loader, name='val')
#             # Save per-validation checkpoint tagged with absolute epoch number
#             model_val_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch_to_run}.pth')
#             torch.save({'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_to_run, 'val': val_stats}, model_val_path)
#             # Track best by PSNR
#             current_psnr = val_stats.get('psnr') or -float('inf')
#             if current_psnr > best_psnr:
#                 best_psnr = current_psnr
#                 torch.save({'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_to_run, 'val': val_stats}, best_ckpt_path)
#                 print(f'New best PSNR {best_psnr:.4f} dB at epoch {epoch_to_run}. Saved to {best_ckpt_path}')
#             metrics_log.append({
#                 'epoch': epoch_to_run,
#                 'train_loss': float(avg_loss),
#                 'train_psnr': float(avg_psnr),
#                 'val_loss': val_stats.get('loss'),
#                 'psnr': val_stats.get('psnr'),
#                 'ssim': val_stats.get('ssim'),
#                 'sam': val_stats.get('sam'),
#                 'rmse': val_stats.get('rmse'),
#                 'ergas': val_stats.get('ergas')
#             })
#         else:
#             metrics_log.append({
#                 'epoch': epoch_to_run,
#                 'train_loss': float(avg_loss),
#                 'train_psnr': float(avg_psnr)
#             })
#         # Increment completed epoch count
#         current_epoch_completed = epoch_to_run

# elapsed = time.time() - start_time
# print('Total training time (s):', elapsed)

# # Final evaluation on the test set (noisy inputs)
# test_stats = test_eval(net, mat_loader) if 'mat_loader' in globals() else {}
# print('Test stats:', test_stats)

## Logging and visualization

In [13]:
# from IPython.display import display
# import pandas as pd
# import matplotlib.pyplot as plt
# import os

# # Persist metrics to CSV
# df = pd.DataFrame(metrics_log) if len(metrics_log) else pd.DataFrame([test_stats])
# csv_path = os.path.join(save_dir, 'training_metrics.csv')
# df.to_csv(csv_path, index=False)
# print('Saved metrics to', csv_path)
# display(df.head())

# # Plot only Train Loss
# if 'epoch' in df.columns and 'train_loss' in df.columns:
#     plt.figure(figsize=(8,5))
#     plt.plot(df['epoch'], df['train_loss'], label='Train Loss', color='tab:blue')
#     plt.title('Training Loss over Epochs')
#     plt.xlabel('Epoch')
#     plt.ylabel('Loss')
#     plt.legend()
#     plt.grid(True, linestyle='--', alpha=0.6)
#     plt.tight_layout()
#     plt.show()

## Simulated test

This section constructs the model, loads a checkpoint, creates a .mat DataLoader with Gaussian noise added to the inputs (sigma_test), and reports the resulting metrics.

In [14]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import models as _models

# Configuration
_test_arch = 'sert_base'                  # attribute looked up on the `models` package
_test_prefix = 'sert_base_gaussian_test'  # prefix of the results CSV file name
_resume = True
_resume_path = '/kaggle/input/sert-icvl-checkpoint/icvl_gaussian.pth'
_sigma_tests = [10,30,50,70, 'blind']  # Add 'blind' as a special case
_test_dir_override = '/kaggle/input/icvl-test-512'

# Create model
print(f"Creating model {_test_arch}")
_test_net = getattr(_models, _test_arch)()

# Device setup
if device == 'cuda' and isinstance(gpu_ids, list) and len(gpu_ids) > 1:
    _test_net = nn.DataParallel(_test_net.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
else:
    _test_net = _test_net.to(device)

# Handle to the real network regardless of DataParallel wrapping. Loading the
# checkpoint into this (rather than into the wrapper) keeps parameter names free
# of the wrapper's 'module.' prefix.
_test_core = _test_net.module if isinstance(_test_net, nn.DataParallel) else _test_net

# Load checkpoint
if _resume and _resume_path and os.path.exists(_resume_path):
    print("Loading checkpoint from", _resume_path)
    # NOTE: torch.load unpickles the file; only point _resume_path at trusted checkpoints.
    ckpt = torch.load(_resume_path, map_location=device)
    state = ckpt.get('net', ckpt)

    # A checkpoint saved from a DataParallel model carries 'module.'-prefixed keys.
    # With strict=False such a mismatch would skip every weight without any error,
    # so rename only those keys that do not match as-is but do match once stripped.
    _dp_prefix = 'module.'
    _expected_keys = set(_test_core.state_dict().keys())
    _renamed_keys = {
        _k: _k[len(_dp_prefix):]
        for _k in state
        if _k not in _expected_keys
        and _k.startswith(_dp_prefix)
        and _k[len(_dp_prefix):] in _expected_keys
    }
    if _renamed_keys:
        state = {_renamed_keys.get(_k, _k): _v for _k, _v in state.items()}

    # strict=False is kept for backward compatibility, but mismatches are reported
    # instead of being discarded.
    _load_report = _test_core.load_state_dict(state, strict=False)
    if _load_report.missing_keys or _load_report.unexpected_keys:
        print(
            f"WARNING: checkpoint does not fully match the model: "
            f"{len(_load_report.missing_keys)} missing key(s), "
            f"{len(_load_report.unexpected_keys)} unexpected key(s)"
        )
        if _load_report.missing_keys:
            print("  missing (first 5):", list(_load_report.missing_keys)[:5])
        if _load_report.unexpected_keys:
            print("  unexpected (first 5):", list(_load_report.unexpected_keys)[:5])
else:
    print("Checkpoint not found or resume disabled, using random weights")

# Handle DataParallel for use_2dconv: read the flag from the unwrapped network,
# defaulting to True when the model does not define it.
use_2dconv = getattr(_test_core, 'use_2dconv', True)
print('use_2dconv =', use_2dconv)

# Base dataset
_mat_dataset = MatDataFromFolder(_test_dir_override)

# Wrapper to run evaluation for a given sigma
def run_test_for_sigma(sigma_value):
    """Evaluate `_test_net` on the base .mat dataset with simulated input noise.

    Args:
        sigma_value: Either the string 'blind', which selects
            `AddNoiseBlindv1(10, 70)`, or a value passed to `AddNoise`.

    Returns:
        Whatever `validate_epoch` returns for this run.
    """

    def _fixed_sigma_noise(img):
        # A new AddNoise instance is built for every sample.
        return AddNoise(sigma_value)(img)

    add_noise = AddNoiseBlindv1(10, 70) if sigma_value == 'blind' else _fixed_sigma_noise

    def _as_float_tensor(value):
        # Only numpy arrays are converted; anything else is passed through untouched.
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value).float()
        return value

    def _prepare_sample(sample):
        loader_fn = LoadMatHSI(input_key='input', gt_key='gt', needsigma=False)
        clean_inp, target = loader_fn(sample)
        # The noise transform is applied to a numpy array, never to a tensor.
        if isinstance(clean_inp, torch.Tensor):
            clean_inp = clean_inp.numpy()
        noisy_inp = _as_float_tensor(add_noise(clean_inp))
        target = _as_float_tensor(target)
        if use_2dconv:
            return noisy_inp, target
        # Add T dimension for 3D model
        return noisy_inp.unsqueeze(0), target.unsqueeze(0)

    noisy_dataset = TransformDatasetWrapper(_mat_dataset, _prepare_sample)
    noisy_loader = DataLoader(
        noisy_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=(device=='cuda'),
    )

    print(f"Testing with noise sigma = {sigma_value}")
    _test_net.eval()
    return validate_epoch(_test_net, noisy_loader, name=f"test_sigma{sigma_value}")

# Run tests for all sigma values
# Metric keys copied from each stats dict into the results table, in column order.
_metric_cols = ['psnr', 'ssim', 'sam', 'rmse', 'ergas', 'loss']
_results_rows = []
for sigma in _sigma_tests:
    sigma_stats = run_test_for_sigma(sigma)
    if isinstance(sigma_stats, dict) and len(sigma_stats):
        row = {'run': f'sigma_{sigma}'}
        row.update({k: sigma_stats[k] for k in _metric_cols if k in sigma_stats})
        _results_rows.append(row)
    else:
        # Make a dropped run visible instead of silently omitting it from the table.
        print(f"No stats returned for sigma = {sigma}; skipping")

# Compile and save results
if len(_results_rows) > 0:
    results_df = pd.DataFrame(_results_rows)
    cols = [c for c in ['run'] + _metric_cols if c in results_df.columns]
    results_df = results_df[cols]
    _metric_formats = {
        'psnr': '{:.4f}', 'ssim': '{:.4f}', 'sam': '{:.4f}',
        'rmse': '{:.4f}', 'ergas': '{:.4f}', 'loss': '{:.4e}'
    }
    # Only pass formatters for columns that are actually present.
    display(results_df.style.format(
        {c: fmt for c, fmt in _metric_formats.items() if c in results_df.columns}
    ))
    # The directory may not exist when the training cells were not run.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    out_csv = os.path.join(save_dir, f"{_test_prefix}_results.csv")
    results_df.to_csv(out_csv, index=False)
    print("Saved test results table to:", out_csv)
else:
    print("No test results available")


Creating model sert_base
3
Loading checkpoint from /kaggle/input/sert-icvl-checkpoint/icvl_gaussian.pth
use_2dconv = True
Testing with noise sigma = 10

 TEST_SIGMA10 SUMMARY
PSNR: 47.7196 dB | SSIM: 0.9988 | SAM: 1.3612° | RMSE: 1.1227 | ERGAS: 3.8066
Testing with noise sigma = 30

 TEST_SIGMA30 SUMMARY
PSNR: 43.5608 dB | SSIM: 0.9969 | SAM: 1.7687° | RMSE: 1.8509 | ERGAS: 5.8625
Testing with noise sigma = 50

 TEST_SIGMA50 SUMMARY
PSNR: 41.3323 dB | SSIM: 0.9949 | SAM: 2.0575° | RMSE: 2.4169 | ERGAS: 7.3456
Testing with noise sigma = 70

 TEST_SIGMA70 SUMMARY
PSNR: 39.8217 dB | SSIM: 0.9928 | SAM: 2.2937° | RMSE: 2.8941 | ERGAS: 8.5476
Testing with noise sigma = blind

 TEST_SIGMABLIND SUMMARY
PSNR: 42.9495 dB | SSIM: 0.9961 | SAM: 1.8645° | RMSE: 2.1019 | ERGAS: 6.3266


Unnamed: 0,run,psnr,ssim,sam,rmse,ergas,loss
0,sigma_10,47.7196,0.9988,1.3612,1.1227,3.8066,1.98e-05
1,sigma_30,43.5608,0.9969,1.7687,1.8509,5.8625,5.3823e-05
2,sigma_50,41.3323,0.9949,2.0575,2.4169,7.3456,9.2037e-05
3,sigma_70,39.8217,0.9928,2.2937,2.8941,8.5476,0.00013219
4,sigma_blind,42.9495,0.9961,1.8645,2.1019,6.3266,7.6588e-05


Saved test results table to: checkpoints/sert_base_gaussian_test_results.csv
