In [1]:
!pip install gputil h5py timm PyWavelets -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m96.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m77.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01

In [2]:
!pip install -q timm --upgrade

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.8/60.8 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from scipy.io import loadmat
import numpy as np
import cv2
import os
import timm
from timm.models.swin_transformer import SwinTransformerBlock
import pywt
import warnings
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, cohen_kappa_score, classification_report
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import random
import h5py
import GPUtil
from torch.amp import GradScaler
from torch import amp
torch.set_float32_matmul_precision('medium')   # or 'high' if your GPU supports it
torch.backends.cudnn.benchmark = True
# Suppress the UserWarning from PyWavelets for small patches
warnings.filterwarnings("ignore", message="Level value of .* is too high: all coefficients will experience boundary effects.")

print("✅ Setup complete. All packages installed and libraries imported.")

✅ Setup complete. All packages installed and libraries imported.


In [4]:
def create_lr_hr_pairs(hsi_cube, scale_factor=4):
    """Creates Low-Res (downsampled) and High-Res pairs for training."""
    hr_image = hsi_cube
    blurred_hr = cv2.GaussianBlur(hr_image, (5, 5), 0)
    lr_h, lr_w = hr_image.shape[0] // scale_factor, hr_image.shape[1] // scale_factor
    lr_downsampled = cv2.resize(blurred_hr, (lr_w, lr_h), interpolation=cv2.INTER_CUBIC)
    return lr_downsampled, hr_image

class WaveletDenoise3D(nn.Module):
    """Implements the 3D wavelet denoising method from the paper."""
    def __init__(self, wavelet='db4', level=1):
        super().__init__()
        self.wavelet = wavelet
        self.level = level
        self.coeffs_to_remove = ['aad', 'ada', 'daa', 'ddd']

    def forward(self, x):
        input_numpy = x.detach().cpu().numpy()
        output_numpy = np.empty_like(input_numpy)
        for i in range(input_numpy.shape[0]):
            volume = input_numpy[i]
            coeffs = pywt.wavedecn(volume, self.wavelet, level=self.level)
            for detail_dict in coeffs[1:]:
                for key in self.coeffs_to_remove:
                    if key in detail_dict:
                        detail_dict[key] = np.zeros_like(detail_dict[key])
            denoised_volume = pywt.waverecn(coeffs, self.wavelet)
            c, h, w = volume.shape
            output_numpy[i] = denoised_volume[:c, :h, :w]
        return torch.from_numpy(output_numpy).to(x.device)

def denoise_cube(lr_cube):
    """Applies the 3D wavelet denoising to an entire HSI cube one time before training."""
    print("Starting one-time denoising of the LR cube...")
    temp_tensor = torch.from_numpy(lr_cube).permute(2, 0, 1).unsqueeze(0)
    denoiser = WaveletDenoise3D()
    denoised_tensor = denoiser(temp_tensor)
    denoised_cube = denoised_tensor.squeeze(0).permute(1, 2, 0).numpy()
    print("Denoising complete.")
    return denoised_cube

print("✅ Data preparation functions defined.")

✅ Data preparation functions defined.


In [5]:
import os
import random
import numpy as np
import cv2

def load_scene_from_pngs(scene_path):
    """Load a 31-band HSI scene from sorted PNGs in a directory."""
    # ⚠️ Only include files that match the *_ms_XX.png format
    png_files = [
        f for f in os.listdir(scene_path)
        if f.endswith('.png') and '_ms_' in f
    ]
    
    # ✅ Sort by band number extracted from filename
    def band_number(filename):
        parts = filename.split('_')
        band_str = parts[-1].replace('.png', '')
        return int(band_str)

    png_files = sorted(png_files, key=band_number)

    if len(png_files) != 31:
        print(f"    ❌ Skipping: Found {len(png_files)} bands, expected 31.")
        return None

    first_img_path = os.path.join(scene_path, png_files[0])
    sample_img = cv2.imread(first_img_path, cv2.IMREAD_GRAYSCALE)
    if sample_img is None:
        print(f"    ❌ Corrupt image: {png_files[0]}")
        return None
    
    h, w = sample_img.shape
    hsi_cube = np.zeros((h, w, 31), dtype=np.float32)

    for i, fname in enumerate(png_files):
        file_path = os.path.join(scene_path, fname)
        band_img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if band_img is None:
            print(f"    ⚠️ Skipping corrupt image: {file_path}")
            return None
        # ✅ Normalize 8-bit image
        hsi_cube[:, :, i] = band_img.astype(np.float32) / 255.0

    return hsi_cube

# Dummy placeholders – replace with your actual functions
def create_lr_hr_pairs(hsi_cube, scale_factor):
    hr = hsi_cube[np.newaxis, ...]  # Add batch dimension
    lr = cv2.resize(hsi_cube, (hsi_cube.shape[1] // scale_factor, hsi_cube.shape[0] // scale_factor))
    lr = cv2.resize(lr, (hsi_cube.shape[1], hsi_cube.shape[0]))  # Upsample back
    lr = lr[np.newaxis, ...]
    return lr, hr

def denoise_cube(cube):
    return cube  # Placeholder (no denoising)

def prepare_data_cubes_from_folders(scene_folder_list, scale_factor=4):
    hr_cubes, lr_denoised_cubes = [], []

    for scene_path in scene_folder_list:
        print(f"  Processing: {os.path.basename(scene_path)}")
        hsi_cube = load_scene_from_pngs(scene_path)

        if hsi_cube is None:
            continue

        # Crop to multiple of scale
        h, w, c = hsi_cube.shape
        hsi_cube = hsi_cube[:h - (h % scale_factor), :w - (w % scale_factor), :]

        # Normalize each band (again, optional — already normalized)
        for i in range(c):
            band = hsi_cube[:, :, i]
            min_val, max_val = np.min(band), np.max(band)
            if max_val > min_val:
                hsi_cube[:, :, i] = (band - min_val) / (max_val - min_val)

        lr_down, hr = create_lr_hr_pairs(hsi_cube, scale_factor=scale_factor)
        lr_denoised = denoise_cube(lr_down)

        hr_cubes.append(hr)
        lr_denoised_cubes.append(lr_denoised)

    if not hr_cubes:
        raise ValueError("❌ No valid 31-band scenes found.")

    return np.concatenate(hr_cubes, axis=0), np.concatenate(lr_denoised_cubes, axis=0)

# --- Traverse the double-nested dataset structure ---
dataset_path = '/kaggle/input/cave-hsi/'

scene_folders = []

for scene_name in os.listdir(dataset_path):
    outer_path = os.path.join(dataset_path, scene_name)
    inner_path = os.path.join(outer_path, scene_name)
    
    if os.path.isdir(inner_path):
        png_files = [f for f in os.listdir(inner_path) if f.endswith('.png') and '_ms_' in f]
        if len(png_files) == 31:
            scene_folders.append(inner_path)
        else:
            print(f"⚠️ Skipping {scene_name}: found {len(png_files)} PNGs, expected 31.")

print(f"✅ Ready to process {len(scene_folders)} valid scenes.")
hr_cubes, lr_denoised_cubes = prepare_data_cubes_from_folders(scene_folders)
print(f"✅ HR shape: {hr_cubes.shape}, LR shape: {lr_denoised_cubes.shape}")

✅ Ready to process 32 valid scenes.
  Processing: oil_painting_ms
  Processing: superballs_ms
  Processing: egyptian_statue_ms
  Processing: fake_and_real_tomatoes_ms
  Processing: photo_and_face_ms
  Processing: glass_tiles_ms
  Processing: beads_ms
  Processing: fake_and_real_lemon_slices_ms
  Processing: hairs_ms
  Processing: chart_and_stuffed_toy_ms
  Processing: watercolors_ms
  Processing: clay_ms
  Processing: stuffed_toys_ms
  Processing: fake_and_real_peppers_ms
  Processing: fake_and_real_strawberries_ms
  Processing: sponges_ms
  Processing: face_ms
  Processing: cd_ms
  Processing: fake_and_real_beers_ms
  Processing: real_and_fake_apples_ms
  Processing: feathers_ms
  Processing: fake_and_real_food_ms
  Processing: jelly_beans_ms
  Processing: balloons_ms
  Processing: thread_spools_ms
  Processing: flowers_ms
  Processing: paints_ms
  Processing: pompoms_ms
  Processing: fake_and_real_sushi_ms
  Processing: cloth_ms
  Processing: real_and_fake_peppers_ms
  Processing: fa

In [6]:
class HSISuperResolutionDataset(Dataset):
    def __init__(self, hr_cube, lr_downsampled_cube, hr_patch_size=63, scale_factor=4):
        self.hr_cube = hr_cube
        self.lr_cube = lr_downsampled_cube
        self.hr_patch_size = hr_patch_size
        self.lr_patch_size = hr_patch_size // scale_factor
        self.scale_factor = scale_factor

        self.n_scenes, self.h, self.w, self.bands = self.hr_cube.shape
        self.patch_step = self.hr_patch_size // 2
        self.band_groups = self.bands // 5
        self.patch_coords = [
            (scene_idx, r, c)
            for scene_idx in range(self.n_scenes)
            for r in range(0, self.h - self.hr_patch_size + 1, self.patch_step)
            for c in range(0, self.w - self.hr_patch_size + 1, self.patch_step)
        ]
        self.total_samples = len(self.patch_coords) * self.band_groups

    def __len__(self):
        return self.total_samples

    def __getitem__(self, index):
        patch_index = index % len(self.patch_coords)
        band_group_index = index // len(self.patch_coords)
        scene_idx, r, c = self.patch_coords[patch_index]
        hr_patch = self.hr_cube[scene_idx, r:r + self.hr_patch_size, c:c + self.hr_patch_size, :]
        lr_patch = self.lr_cube[scene_idx,
                                r // self.scale_factor:r // self.scale_factor + self.lr_patch_size,
                                c // self.scale_factor:c // self.scale_factor + self.lr_patch_size, :]
        start_band = band_group_index * 5
        end_band = start_band + 5
        hr_patch = hr_patch[:, :, start_band:end_band]
        lr_patch = lr_patch[:, :, start_band:end_band]

        if np.random.rand() > 0.5:
            lr_patch = np.ascontiguousarray(np.flip(lr_patch, axis=1))
            hr_patch = np.ascontiguousarray(np.flip(hr_patch, axis=1))
        if np.random.rand() > 0.5:
            lr_patch = np.ascontiguousarray(np.flip(lr_patch, axis=0))
            hr_patch = np.ascontiguousarray(np.flip(hr_patch, axis=0))

        lr_tensor = torch.from_numpy(lr_patch.copy()).float().permute(2, 0, 1)
        center_band_idx = hr_patch.shape[2] // 2
        hr_tensor = torch.from_numpy(hr_patch[:, :, center_band_idx:center_band_idx+1].copy()).float().permute(2, 0, 1)
        return lr_tensor, hr_tensor

In [7]:
class ShallowFeatureExtractor(nn.Module):
    def __init__(self, in_channels=5, base_channels=64):
        super().__init__()
        self.conv2d_initial = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        self.res_block = nn.Sequential(
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1)
        )
        self.spatial_conv3d = nn.Conv3d(base_channels, base_channels, kernel_size=(3, 3, 3), padding=1)
        self.spectral_conv3d = nn.Conv3d(base_channels, base_channels, kernel_size=(3, 3, 3), padding=1)
        self.fusion_conv3d = nn.Conv3d(base_channels, base_channels, kernel_size=(1, 1, 1))

    def forward(self, x):
        t = self.conv2d_initial(x)
        t = t + self.res_block(t)
        t_3d = t.unsqueeze(2)
        feat_spa = self.spatial_conv3d(t_3d)
        feat_spec = self.spectral_conv3d(t_3d)
        fused_feat = feat_spa + feat_spec
        return self.fusion_conv3d(fused_feat).squeeze(2)

class SwinTransformerBlockWrapper(nn.Module):
    def __init__(self, dim, num_heads, window_size=7):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size

    def forward(self, x):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        swin_block = SwinTransformerBlock(
            dim=self.dim,
            input_resolution=(H, W),
            num_heads=self.num_heads,
            window_size=7,
            shift_size=3
        ).to(x.device)
        x = x.view(B, H, W, C)
        x = swin_block(x)
        return x.view(B, -1, C)

class SpectralAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.sigma = nn.Parameter(torch.ones(1))  # Learnable

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (self.sigma * q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
class DSAL(nn.Module):
    def __init__(self, dim, num_heads, window_size=7):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.norm3, self.norm4 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.spatial_attention = SwinTransformerBlockWrapper(dim, num_heads, window_size=7)
        self.spectral_attention = SpectralAttention(dim, num_heads)
        self.mlp1 = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.mlp2 = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        x = x + self.spatial_attention(self.norm1(x))
        x = x + self.mlp1(self.norm2(x))
        x = x + self.spectral_attention(self.norm3(x))
        x = x + self.mlp2(self.norm4(x))
        return x

class DSST_Module(nn.Module):
    def __init__(self, dim, num_heads, num_dsal=6, window_size=7):
        super().__init__()
        self.dsal_blocks = nn.ModuleList([DSAL(dim, num_heads, window_size=7) for _ in range(num_dsal)])
        self.conv_final = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        B, C, H, W = x.shape
        x_res = x
        x = x.flatten(2).transpose(1, 2)
        for block in self.dsal_blocks:
            x = block(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return self.conv_final(x) + x_res
        
class DSSTSR(nn.Module):
    def __init__(self, in_channels=5, base_channels=64, out_channels=1, scale_factor=4, num_dsst_modules=4, window_size=7):
        super().__init__()
        self.scale_factor = scale_factor
        self.shallow_extractor = ShallowFeatureExtractor(in_channels=in_channels, base_channels=base_channels)
        self.deep_extractor = nn.Sequential(*[DSST_Module(base_channels, num_heads=8, window_size=window_size)
                                              for _ in range(num_dsst_modules)])
        self.upsampler = nn.Sequential(
            nn.Conv2d(base_channels, out_channels * (scale_factor ** 2), 3, padding=1),
            nn.PixelShuffle(scale_factor)
        )

    def forward(self, x):
        center_band_idx = x.shape[1] // 2
        bicubic_upsampled_input = F.interpolate(
            x[:, center_band_idx:center_band_idx+1, :, :], 
            scale_factor=self.scale_factor, 
            mode='bicubic', 
            align_corners=False
        )
        shallow_features = self.shallow_extractor(x)
        deep_features = self.deep_extractor(shallow_features)
        residual = self.upsampler(deep_features)
        return bicubic_upsampled_input + residual

print("✅ Model architecture classes defined.")

✅ Model architecture classes defined.


In [8]:
class HSI_Classifier(nn.Module):
    def __init__(self, num_bands, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(num_bands, 100)
        self.fc2 = nn.Linear(100, num_classes)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class PixelDataset(Dataset):
    def __init__(self, X, y):
        self.X, self.y = torch.from_numpy(X).float(), torch.from_numpy(y).long()
    def __len__(self): return len(self.X)
    def __getitem__(self, idx): return self.X[idx], self.y[idx]

def run_classification_evaluation(sr_image_cube, gt_path):
    print("\n--- Starting Classification Evaluation ---")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load ground truth
    gt_data = loadmat(gt_path)
    gt_map = gt_data['indian_pines_gt']
    
    # --- ADD THIS CODE TO CROP THE GROUND TRUTH MAP ---
    # Ensure gt_map has the same spatial dimensions as the sr_image_cube
    h, w, _ = sr_image_cube.shape
    gt_map = gt_map[:h, :w]
    # --- End of New Code ---
    
    # Prepare pixel data (this line will now work correctly)
    X_pixels, y_pixels = sr_image_cube[gt_map > 0], gt_map[gt_map > 0] - 1 
    num_classes = len(np.unique(y_pixels))
    
    # Paper's "30 samples per class" strategy
    train_size = 30 * num_classes
    if train_size >= len(y_pixels): # Make sure train_size is not too large
        train_size = int(len(y_pixels) * 0.5)

    X_train, X_test, y_train, y_test = train_test_split(
        X_pixels, y_pixels, train_size=train_size, stratify=y_pixels, random_state=42
    )
    
    train_loader = DataLoader(PixelDataset(X_train, y_train), batch_size=128, shuffle=True)
    test_loader = DataLoader(PixelDataset(X_test, y_test), batch_size=128, shuffle=False)
    
    classifier = HSI_Classifier(num_bands=sr_image_cube.shape[2], num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=0.001)
    
    print("Training classifier...")
    for epoch in range(100): # Train for a reasonable number of epochs
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = classifier(data)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
    
    print("Evaluating classifier...")
    classifier.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.to(device)
            outputs = classifier(data)
            _, predicted = torch.max(outputs.data, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    oa = accuracy_score(all_labels, all_preds)
    kappa = cohen_kappa_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    aa = np.mean([report[str(i)]['recall'] for i in range(num_classes)])
    
    return {'OA': oa, 'AA': aa, 'Kappa': kappa}
print("✅ Classification experiment functions defined.")

✅ Classification experiment functions defined.


In [9]:
import torch
import numpy as np
import time
import gc

def t4_optimized_reconstruction(model, lr_scene, hr_shape, scale_factor, device):
    h, w, bands = hr_shape
    reconstructed = np.zeros((h, w, bands), dtype=np.float32)
    weight_map = np.zeros_like(reconstructed)

    print(f" T4-optimized reconstruction: {h}×{w}×{bands}")
    start_time = time.time()

    patch_size = 63  # as per paper
    step_size = 48   # as per paper

    lr_h, lr_w = lr_scene.shape[:2]

    torch.cuda.empty_cache()
    total_memory = torch.cuda.get_device_properties(device).total_memory
    allocated_memory = torch.cuda.memory_allocated(device)
    free_memory = total_memory - allocated_memory

    if free_memory > 8e9:
        initial_batch_size = 12
    elif free_memory > 4e9:
        initial_batch_size = 8
    else:
        initial_batch_size = 4

    print(f" Available memory: {free_memory/1e9:.1f}GB, using batch size: {initial_batch_size}")

    # Cover entire image including edges
    patch_coords = []
    for r in range(0, lr_h - patch_size + 1, step_size):
        for c in range(0, lr_w - patch_size + 1, step_size):
            patch_coords.append((r, c))
    if (lr_h - patch_size) % step_size != 0:
        for c in range(0, lr_w - patch_size + 1, step_size):
            patch_coords.append((lr_h - patch_size, c))
    if (lr_w - patch_size) % step_size != 0:
        for r in range(0, lr_h - patch_size + 1, step_size):
            patch_coords.append((r, lr_w - patch_size))
    if (lr_h - patch_size) % step_size != 0 and (lr_w - patch_size) % step_size != 0:
        patch_coords.append((lr_h - patch_size, lr_w - patch_size))

    print(f" Processing {len(patch_coords)} patches across {bands} bands")

    model.eval()
    with torch.no_grad():
        band_to_input_bands = {}
        for band_idx in range(bands):
            if band_idx < 2:
                input_bands = [0, 1, 2, 3, 4]
            elif band_idx > bands - 3:
                input_bands = [bands-5, bands-4, bands-3, bands-2, bands-1]
            else:
                input_bands = [band_idx-2, band_idx-1, band_idx, band_idx+1, band_idx+2]
            band_to_input_bands[band_idx] = tuple(input_bands)

        input_to_output_bands = {}
        for output_band, input_bands in band_to_input_bands.items():
            input_to_output_bands.setdefault(input_bands, []).append(output_band)

        print(f" Optimization: {bands} bands grouped into {len(input_to_output_bands)} unique inputs")

        for input_idx, (input_bands, output_bands) in enumerate(input_to_output_bands.items()):
            print(f" Processing input {input_idx+1}/{len(input_to_output_bands)}: bands {input_bands} → output bands {output_bands}")
            lr_input_5bands = lr_scene[:, :, list(input_bands)]

            all_patches = []
            valid_coords = []

            for r, c in patch_coords:
                patch = lr_input_5bands[r:r+patch_size, c:c+patch_size, :]
                if patch.shape[:2] == (patch_size, patch_size):
                    all_patches.append(patch)
                    valid_coords.append((r, c))

            if not all_patches:
                continue

            current_batch_size = initial_batch_size

            for batch_start in range(0, len(all_patches), current_batch_size):
                batch_end = min(batch_start + current_batch_size, len(all_patches))
                batch_patches = all_patches[batch_start:batch_end]
                batch_coords = valid_coords[batch_start:batch_end]

                success = False
                retry_count = 0

                while not success and retry_count < 3:
                    try:
                        batch_tensor = torch.stack([
                            torch.from_numpy(patch).float().permute(2, 0, 1)
                            for patch in batch_patches
                        ]).to(device, non_blocking=True)

                        with torch.amp.autocast('cuda'):
                            hr_batch = model(batch_tensor)

                        hr_batch_cpu = hr_batch.cpu().numpy()

                        for i, (r, c) in enumerate(batch_coords):
                            hr_patch = hr_batch_cpu[i, 0]

                            hr_r = r * scale_factor
                            hr_c = c * scale_factor
                            hr_r_end = min(hr_r + hr_patch.shape[0], h)
                            hr_c_end = min(hr_c + hr_patch.shape[1], w)

                            patch_h = hr_r_end - hr_r
                            patch_w = hr_c_end - hr_c

                            if patch_h > 0 and patch_w > 0:
                                for output_band in output_bands:
                                    reconstructed[hr_r:hr_r_end, hr_c:hr_c_end, output_band] += hr_patch[:patch_h, :patch_w]
                                    weight_map[hr_r:hr_r_end, hr_c:hr_c_end, output_band] += 1

                        success = True

                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            retry_count += 1
                            current_batch_size = max(1, current_batch_size // 2)
                            print(f" OOM detected, reducing batch size to {current_batch_size}, retry {retry_count}/3")
                            torch.cuda.empty_cache()
                            gc.collect()
                        else:
                            raise e

                if batch_start % (current_batch_size * 4) == 0:
                    torch.cuda.empty_cache()

            print(f" Completed input {input_idx+1}/{len(input_to_output_bands)}")

    # Normalize overlapping contributions
    weight_map[weight_map == 0] = 1  # avoid division by zero
    reconstructed /= weight_map

    total_time = time.time() - start_time
    print(f" T4-optimized reconstruction completed in {total_time:.1f} seconds")

    return reconstructed


In [10]:
# Debug: Check what model files exist
import os
print("\n🔍 Checking for saved model files...")
potential_files = [
    'best_dsstsr_weights.pth',
    'best_dsstsr_model_fixed.pth', 
    'final_dsstsr_checkpoint.pth',
    'early_stopped_model.pth'
]

available_files = []
for file in potential_files:
    if os.path.exists(file):
        file_size = os.path.getsize(file) / (1024*1024)  # MB
        print(f"  ✅ Found: {file} ({file_size:.1f} MB)")
        available_files.append(file)
    else:
        print(f"  ❌ Missing: {file}")

if available_files:
    model_path = available_files[0]  # Use the first available file
    print(f"\n📁 Using model file: {model_path}")
else:
    print("\n⚠️ No model files found!")



🔍 Checking for saved model files...
  ❌ Missing: best_dsstsr_weights.pth
  ❌ Missing: best_dsstsr_model_fixed.pth
  ❌ Missing: final_dsstsr_checkpoint.pth
  ❌ Missing: early_stopped_model.pth

⚠️ No model files found!


In [11]:
import concurrent.futures
import threading

def process_scene_on_gpu(gpu_id, model_path, lr_scene, hr_shape, scale_factor, scene_name):
    """Process one scene on specific GPU with robust model loading"""
    device = torch.device(f'cuda:{gpu_id}')
    torch.cuda.set_device(gpu_id)
    torch.cuda.empty_cache()
    
    # Load model on this GPU
    model = DSSTSR(in_channels=5, out_channels=1, base_channels=64, scale_factor=scale_factor)
    
    try:
        # Load the model file
        checkpoint = torch.load(model_path, map_location=device)
        
        # Handle both checkpoint and weights-only files
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            # This is a full checkpoint
            state_dict = checkpoint['model_state_dict']
            epoch_info = checkpoint.get('epoch', 'unknown')
            val_loss = checkpoint.get('val_loss', 'unknown')
            print(f"    GPU {gpu_id}: Loaded checkpoint from epoch {epoch_info} (val_loss: {val_loss})")
        else:
            # This is weights-only
            state_dict = checkpoint
            print(f"    GPU {gpu_id}: Loaded weights-only file")
        
        # Handle DataParallel state dict (remove 'module.' prefix)
        if any(k.startswith('module.') for k in state_dict.keys()):
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            print(f"    GPU {gpu_id}: Removed DataParallel prefix from state dict")
        
        # Load the state dict into model
        model.load_state_dict(state_dict)
        print(f"    GPU {gpu_id}: ✅ Model loaded successfully for {scene_name}")
        
    except Exception as e:
        print(f"    GPU {gpu_id}: ❌ Error loading model: {e}")
        print(f"    GPU {gpu_id}: ⚠️ Using randomly initialized model for {scene_name}")
    
    model.to(device)
    model.eval()
    
    print(f"    GPU {gpu_id}: 🔄 Processing {scene_name}...")
    
    # Run reconstruction
    try:
        result = t4_optimized_reconstruction(model, lr_scene, hr_shape, scale_factor, device)
        print(f"    GPU {gpu_id}: ✅ Completed {scene_name}")
        return scene_name, result
    
    except Exception as e:
        print(f"    GPU {gpu_id}: ❌ Error during reconstruction of {scene_name}: {e}")
        # Return dummy result to prevent the entire process from failing
        dummy_result = torch.zeros(hr_shape, dtype=torch.float32)
        return scene_name, dummy_result.numpy()

def parallel_scene_processing(scenes_data, model_path, scale_factor=4):
    """Process scenes in parallel across GPUs with better error handling"""
    num_gpus = torch.cuda.device_count()
    print(f"🚀 Starting parallel processing on {num_gpus} GPUs")
    print(f"📁 Using model file: {model_path}")
    
    # Verify model file exists
    if not os.path.exists(model_path):
        print(f"❌ Model file not found: {model_path}")
        available_models = [f for f in os.listdir('.') if f.endswith('.pth')]
        if available_models:
            print(f"Available .pth files: {available_models}")
            model_path = available_models[0]
            print(f"🔄 Using fallback model: {model_path}")
        else:
            raise FileNotFoundError(f"No model file found: {model_path}")
    
    # Show model file info
    model_size_mb = os.path.getsize(model_path) / (1024 * 1024)
    print(f"📊 Model file size: {model_size_mb:.1f} MB")
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_gpus) as executor:
        print(f"⚡ Submitting {len(scenes_data)} scenes to {num_gpus} GPUs...")
        futures = []
        
        for i, (lr_scene, hr_shape, scene_name) in enumerate(scenes_data):
            gpu_id = i % num_gpus  # Round-robin GPU assignment
            print(f"  📤 Submitting {scene_name} to GPU {gpu_id}")
            
            future = executor.submit(
                process_scene_on_gpu, 
                gpu_id, model_path, lr_scene, hr_shape, scale_factor, scene_name
            )
            futures.append(future)
        
        # Collect results with progress tracking
        results = {}
        completed = 0
        total = len(futures)
        
        print(f"\n🔄 Processing {total} scenes...")
        for future in concurrent.futures.as_completed(futures):
            try:
                scene_name, result = future.result()
                results[scene_name] = result
                completed += 1
                print(f"  ✅ [{completed}/{total}] Completed: {scene_name}")
                
            except Exception as e:
                completed += 1
                print(f"  ❌ [{completed}/{total}] Failed: {e}")
                continue
    
    print(f"\n🎉 Parallel processing complete! Processed {len(results)}/{total} scenes successfully")
    return results

# Helper function to find the best model file
def find_best_model_file():
    """Find the best available model file in priority order"""
    model_candidates = [
        ('best_dsstsr_weights.pth', 'Best model weights only'),
        ('best_dsstsr_model_fixed.pth', 'Best model checkpoint'), 
        ('final_dsstsr_checkpoint.pth', 'Final training checkpoint'),
        ('early_stopped_model.pth', 'Early stopped model')
    ]
    
    print("\n🔍 Searching for model files...")
    for model_path, description in model_candidates:
        if os.path.exists(model_path):
            size_mb = os.path.getsize(model_path) / (1024 * 1024)
            print(f"  ✅ Found: {model_path} ({size_mb:.1f} MB) - {description}")
            return model_path
        else:
            print(f"  ❌ Missing: {model_path}")
    
    # If no standard files found, look for any .pth file
    pth_files = [f for f in os.listdir('.') if f.endswith('.pth')]
    if pth_files:
        fallback = pth_files[0]
        print(f"  🔄 Using fallback: {fallback}")
        return fallback
    
    raise FileNotFoundError("No model files found!")


In [12]:
import time
import torch
import numpy as np
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader

def init_weights(m):
    """Applies Kaiming initialization to conv and linear layers."""
    if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
            
def compute_metrics(y_true, y_pred):
    y_true, y_pred = np.clip(y_true, 0, 1), np.clip(y_pred, 0, 1)

    psnr_val = psnr(y_true, y_pred, data_range=1.0)

    ssim_val = ssim(
        y_true, y_pred, data_range=1.0, multichannel=True, channel_axis=-1, win_size=11
    )

    y_true_flat = y_true.reshape(-1, y_true.shape[-1])
    y_pred_flat = y_pred.reshape(-1, y_pred.shape[-1])

    dot_products = np.sum(y_true_flat * y_pred_flat, axis=1)
    norms_true = np.linalg.norm(y_true_flat, axis=1)
    norms_pred = np.linalg.norm(y_pred_flat, axis=1)

    valid_mask = (norms_true > 1e-10) & (norms_pred > 1e-10)
    sam_values = np.zeros(len(dot_products))

    if np.any(valid_mask):
        cos_angles = dot_products[valid_mask] / (norms_true[valid_mask] * norms_pred[valid_mask])
        cos_angles = np.clip(cos_angles, -1, 1)
        sam_values[valid_mask] = np.arccos(cos_angles)

    sam_mean = np.mean(sam_values) * 180 / np.pi
    return {'PSNR': psnr_val, 'SSIM': ssim_val, 'SAM': sam_mean}

class EarlyStopping:
    """Early stopping to stop training when validation loss doesn't improve."""
    def __init__(self, patience=20, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float('inf')
        self.counter = 0
        self.best_weights = None
        
    def __call__(self, val_loss, model):
        """Returns True if training should stop, False otherwise"""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement found
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
            return False
        else:
            # No improvement
            self.counter += 1
            if self.counter >= self.patience:
                if self.restore_best_weights and self.best_weights is not None:
                    print(f"Restoring best model weights from validation loss: {self.best_loss:.6f}")
                    model.load_state_dict(self.best_weights)
                return True
            return False

def find_best_model_file():
    """Find the best available model file in priority order"""
    model_candidates = [
        ('best_dsstsr_weights.pth', 'Best model weights only'),
        ('best_dsstsr_model_fixed.pth', 'Best model checkpoint'), 
        ('final_dsstsr_checkpoint.pth', 'Final training checkpoint'),
        ('early_stopped_model.pth', 'Early stopped model')
    ]
    
    print("\n🔍 Searching for model files...")
    for model_path, description in model_candidates:
        if os.path.exists(model_path):
            size_mb = os.path.getsize(model_path) / (1024 * 1024)
            print(f"  ✅ Found: {model_path} ({size_mb:.1f} MB) - {description}")
            return model_path
        else:
            print(f"  ❌ Missing: {model_path}")
    
    # If no standard files found, look for any .pth file
    pth_files = [f for f in os.listdir('.') if f.endswith('.pth')]
    if pth_files:
        fallback = pth_files[0]
        print(f"  🔄 Using fallback: {fallback}")
        return fallback
    
    raise FileNotFoundError("No model files found!")

def run_experiment_fixed(hr_train_cube, lr_train_denoised_cube, test_files, num_epochs=175, scale_factor=4, hr_patch_size=63, val_folders=None):
    import time, os, cv2
    import torch
    import numpy as np
    from torch.amp import autocast, GradScaler
    from torch.utils.data import DataLoader
    
    print("\n--- Starting Super-Resolution Experiment (Optimized Version) ---")
    lr_patch_size = hr_patch_size // scale_factor
    
    # Prepare validation data
    if val_folders is not None:
        print("Using provided validation folders...")
        hr_val_scene, lr_val_denoised_scene = prepare_data_cubes_from_folders(val_folders[:1])  # Use first val folder
    else:
        print("Preparing validation data from the first test scene...")
        hr_val_scene, lr_val_denoised_scene = prepare_data_cubes_from_folders([test_files[0]], scale_factor=scale_factor)

    train_dataset = HSISuperResolutionDataset(hr_train_cube, lr_train_denoised_cube, hr_patch_size, scale_factor)
    val_dataset = HSISuperResolutionDataset(hr_val_scene, lr_val_denoised_scene, hr_patch_size, scale_factor)

    # Optimized data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=96,              # Increased batch size
        shuffle=True, 
        num_workers=4,               # More workers
        pin_memory=True, 
        prefetch_factor=8,           # Higher prefetch
        persistent_workers=True,     # Keep workers alive
        drop_last=True              # Consistent batch sizes
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=96, 
        shuffle=False, 
        num_workers=4, 
        pin_memory=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    def paper_faithful_init(m):
        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    
    model = DSSTSR(in_channels=5, out_channels=1, base_channels=64, scale_factor=scale_factor, window_size=7)
    model.apply(paper_faithful_init)
    model.to(device)

    if torch.cuda.device_count() > 1:
        print(f"Activating data parallelism on {torch.cuda.device_count()} GPUs!")
        model = torch.nn.DataParallel(model)

    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
    
    # Optimized learning rate schedule with warmup
    def lr_lambda(epoch):
        if epoch < 10:  # Warmup
            return (epoch + 1) / 10
        else:
            return 0.5 ** ((epoch - 10) // 25)  # Faster decay

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler('cuda')

    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=15,           # Reduced patience
        min_delta=0.001,      # Minimum improvement threshold
        restore_best_weights=True
    )

    best_val_loss = float('inf')
    best_epoch = 0
    validation_frequency = 5  # Validate every 5 epochs
    
    print("Starting training with early stopping...")
    print(f"Early stopping: patience={early_stopping.patience}, validation every {validation_frequency} epochs")

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0.0
        epoch_start_time = time.time()
        batch_losses = []

        for i, (lr_batch, hr_batch) in enumerate(train_loader):
            # Convert to channels_last for better performance
            lr_batch = lr_batch.to(device, memory_format=torch.channels_last)
            hr_batch = hr_batch.to(device, memory_format=torch.channels_last)
            
            optimizer.zero_grad()
            
            with autocast(device_type='cuda'):
                preds = model(lr_batch)
                min_h = min(preds.shape[2], hr_batch.shape[2])
                min_w = min(preds.shape[3], hr_batch.shape[3])
                preds = preds[:, :, :min_h, :min_w]
                hr_batch = hr_batch[:, :, :min_h, :min_w]
                loss = criterion(preds, hr_batch)
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            
            batch_losses.append(loss.item())
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase - only every N epochs to save time
        if epoch % validation_frequency == 0:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for lr_batch, hr_batch in val_loader:
                    lr_batch = lr_batch.to(device, memory_format=torch.channels_last)
                    hr_batch = hr_batch.to(device, memory_format=torch.channels_last)
                    
                    with autocast(device_type='cuda'):
                        preds = model(lr_batch)
                        min_h = min(preds.shape[2], hr_batch.shape[2])
                        min_w = min(preds.shape[3], hr_batch.shape[3])
                        preds = preds[:, :, :min_h, :min_w]
                        hr_batch = hr_batch[:, :, :min_h, :min_w]
                        loss = criterion(preds, hr_batch)
                    val_loss += loss.item()

            avg_val_loss = val_loss / len(val_loader)
            
            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_epoch = epoch
                
                # Save checkpoint
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss,
                    'best_val_loss': best_val_loss
                }
                torch.save(checkpoint, 'best_dsstsr_model_fixed.pth')
                torch.save(model.state_dict(), 'best_dsstsr_weights.pth')  # Weights only

            # Check early stopping
            if early_stopping(avg_val_loss, model):
                print(f"\n🛑 Early stopping triggered at epoch {epoch}")
                print(f"   Best validation loss: {early_stopping.best_loss:.6f}")
                print(f"   No improvement for {early_stopping.patience * validation_frequency} epochs")
                print(f"   Training stopped {num_epochs - epoch} epochs early")
                break

        scheduler.step()

        # Clean logging every 5 epochs
        if epoch % 5 == 0:
            epoch_time = time.time() - epoch_start_time
            recent_losses = batch_losses[-10:] if len(batch_losses) >= 10 else batch_losses
            avg_recent_loss = np.mean(recent_losses)
            
            print(f"\n[Epoch {epoch:3d}/{num_epochs}] Train Loss: {avg_train_loss:.4f} | Recent Batch Loss: {avg_recent_loss:.4f}")
            print(f"                     Time: {epoch_time:.1f}s | LR: {scheduler.get_last_lr()[0]:.2e}")
            
            if epoch % validation_frequency == 0:
                print(f"                     Val Loss: {avg_val_loss:.4f} | Best: {best_val_loss:.4f} (Epoch {best_epoch})")
                print(f"                     Early Stop Counter: {early_stopping.counter}/{early_stopping.patience}")
            
            # GPU memory info
            if torch.cuda.is_available():
                memory_used = torch.cuda.memory_allocated(device) / 1024**3
                memory_total = torch.cuda.get_device_properties(device).total_memory / 1024**3
                print(f"                     GPU Memory: {memory_used:.1f}/{memory_total:.1f} GB ({memory_used/memory_total*100:.1f}%)")

    else:
        print(f"\n✅ Training completed all {num_epochs} epochs without early stopping")
        print(f"   Best validation loss: {best_val_loss:.6f} at epoch {best_epoch}")

    # Save final checkpoint
    final_checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': avg_train_loss,
        'best_val_loss': best_val_loss,
        'final': True
    }
    torch.save(final_checkpoint, 'final_dsstsr_checkpoint.pth')

    print("\n--- Training Complete. Starting evaluation... ---")
    
    # Find the best available model file
    try:
        model_path = find_best_model_file()
    except FileNotFoundError as e:
        print(f"❌ {e}")
        return None, None

    print(f"\n--- Starting evaluation on {len(test_files)} test scenes ---")
    print("Preparing scene data for parallel processing...")
    scenes_data = []
    hr_test_scenes = []
    
    for test_file_path in test_files:
        hr_test_scene, lr_test_denoised_scene = prepare_data_cubes_from_folders([test_file_path], scale_factor=scale_factor)
        if hr_test_scene.ndim == 4: hr_test_scene = hr_test_scene[0]
        if lr_test_denoised_scene.ndim == 4: lr_test_denoised_scene = lr_test_denoised_scene[0]
        
        scene_name = os.path.basename(test_file_path)
        scenes_data.append((lr_test_denoised_scene, hr_test_scene.shape, scene_name))
        hr_test_scenes.append((scene_name, hr_test_scene))
    
    print(f"Prepared {len(scenes_data)} scenes for parallel processing")
    
    # Run parallel processing across GPUs
    results = parallel_scene_processing(scenes_data, model_path, scale_factor)
    
    # Compute metrics
    print("Computing evaluation metrics...")
    all_psnr, all_ssim, all_sam = [], [], []
    hr_scenes_dict = {name: scene for name, scene in hr_test_scenes}
    
    for scene_name, reconstructed_hr_cube in results.items():
        hr_original = hr_scenes_dict[scene_name]
        print(f"    Computing metrics for {scene_name}...")
        scene_metrics = compute_metrics(hr_original, reconstructed_hr_cube)
        all_psnr.append(scene_metrics['PSNR'])
        all_ssim.append(scene_metrics['SSIM'])
        all_sam.append(scene_metrics['SAM'])
        print(f"    ✅ {scene_name} - PSNR: {scene_metrics['PSNR']:.2f}, SSIM: {scene_metrics['SSIM']:.4f}, SAM: {scene_metrics['SAM']:.2f}")

    final_metrics = {
        'PSNR': np.mean(all_psnr),
        'SSIM': np.mean(all_ssim),
        'SAM': np.mean(all_sam)
    }

    print(f"\n🎉 Evaluation Complete!")
    return final_metrics, reconstructed_hr_cube


In [13]:
import os
dataset_path = '/kaggle/input/cave-hsi/'

print(f"--- Attempting to list contents of: {dataset_path} ---")
try:
    top_level_contents = os.listdir(dataset_path)
    print("Top-level contents are:")
    print(top_level_contents)

    if top_level_contents:
        # Check the first item in the list
        first_item_name = top_level_contents[0]
        first_item_path = os.path.join(dataset_path, first_item_name)
        
        if os.path.isdir(first_item_path):
            print(f"\n--- Contents of the first folder '{first_item_name}' ---")
            print(os.listdir(first_item_path))
        else:
            print(f"\n'{first_item_name}' is a file, not a folder.")

except FileNotFoundError:
    print(f"❌ CRITICAL ERROR: The path '{dataset_path}' does not exist. Please check your dataset name.")

--- Attempting to list contents of: /kaggle/input/cave-hsi/ ---
Top-level contents are:
['oil_painting_ms', 'superballs_ms', 'egyptian_statue_ms', 'fake_and_real_tomatoes_ms', 'photo_and_face_ms', 'glass_tiles_ms', 'beads_ms', 'fake_and_real_lemon_slices_ms', 'hairs_ms', 'chart_and_stuffed_toy_ms', 'watercolors_ms', 'clay_ms', 'stuffed_toys_ms', 'fake_and_real_peppers_ms', 'fake_and_real_strawberries_ms', 'sponges_ms', 'face_ms', 'cd_ms', 'fake_and_real_beers_ms', 'real_and_fake_apples_ms', 'feathers_ms', 'fake_and_real_food_ms', 'jelly_beans_ms', 'balloons_ms', 'thread_spools_ms', 'flowers_ms', 'paints_ms', 'pompoms_ms', 'fake_and_real_sushi_ms', 'cloth_ms', 'real_and_fake_peppers_ms', 'fake_and_real_lemons_ms']

--- Contents of the first folder 'oil_painting_ms' ---
['oil_painting_ms']


In [None]:
import os
import random
import numpy as np

# --- 1. Define File Path and Get List of All Valid Scene Folders ---
dataset_path = '/kaggle/input/cave-hsi'
all_scene_folders = []

for scene_name in os.listdir(dataset_path):
    outer_path = os.path.join(dataset_path, scene_name)
    inner_path = os.path.join(outer_path, scene_name)

    if os.path.isdir(inner_path):
        png_files = [f for f in os.listdir(inner_path) if f.endswith('.png') and '_ms_' in f]
        if len(png_files) == 31:
            all_scene_folders.append(inner_path)

# Set random seed for reproducible splits
random.seed(42)
random.shuffle(all_scene_folders)

# Split into 20 train, 6 validation, 6 test
train_folders = all_scene_folders[:20]
val_folders = all_scene_folders[20:26]  # 6 validation scenes
test_folders = all_scene_folders[26:32]  # 6 test scenes

print(f"📁 Dataset split: {len(train_folders)} training, {len(val_folders)} validation, {len(test_folders)} test scenes")
print(f"   Total scenes used: {len(train_folders) + len(val_folders) + len(test_folders)} out of {len(all_scene_folders)} available")

# --- 2. Prepare Data Cubes ---
print("\n🔧 Preparing Training Data")
hr_train_cube, lr_train_denoised_cube = prepare_data_cubes_from_folders(train_folders)
print(f"✅ Final training HR shape: {hr_train_cube.shape}, LR shape: {lr_train_denoised_cube.shape}")
import os
for file in ['best_dsstsr_weights.pth', 'best_dsstsr_model_fixed.pth', 'final_dsstsr_checkpoint.pth']:
    if os.path.exists(file):
        os.remove(file)
# --- 3. Run Super-Resolution Experiment ---
print("\n🚀 Running Super-Resolution (Optimized DSSTSR Model)")
sr_metrics, reconstructed_hr_cube = run_experiment_fixed(
    hr_train_cube=hr_train_cube,
    lr_train_denoised_cube=lr_train_denoised_cube,
    test_files=test_folders,# Combine validation and test folders
    val_folders=val_folders,
    num_epochs=175,
    hr_patch_size=63
)

# --- 4. Print SR Results ---
print("\n📊 ✅ Final Super-Resolution Results:")
for metric, value in sr_metrics.items():
    print(f"{metric}: {value:.4f}")

print(f"\n🔍 Training completed on {len(train_folders)} scenes")
print(f"🔍 Validation used first scene from validation set during training")
print(f"🔍 Final evaluation on {len(val_folders) + len(test_folders)} scenes total")


📁 Dataset split: 20 training, 6 validation, 6 test scenes
   Total scenes used: 32 out of 32 available

🔧 Preparing Training Data
  Processing: paints_ms
  Processing: glass_tiles_ms
  Processing: watercolors_ms
  Processing: sponges_ms
  Processing: flowers_ms
  Processing: clay_ms
  Processing: jelly_beans_ms
  Processing: beads_ms
  Processing: real_and_fake_apples_ms
  Processing: stuffed_toys_ms
  Processing: face_ms
  Processing: chart_and_stuffed_toy_ms
  Processing: fake_and_real_sushi_ms
  Processing: fake_and_real_strawberries_ms
  Processing: thread_spools_ms
  Processing: feathers_ms
  Processing: real_and_fake_peppers_ms
  Processing: superballs_ms
  Processing: fake_and_real_peppers_ms
  Processing: fake_and_real_beers_ms
✅ Final training HR shape: (20, 512, 512, 31), LR shape: (20, 512, 512, 31)

🚀 Running Super-Resolution (Optimized DSSTSR Model)

--- Starting Super-Resolution Experiment (Optimized Version) ---
Using provided validation folders...
  Processing: egyptian