# CAVE / Generic Hyperspectral-Multispectral Fusion Notebook

This notebook adapts the BASFE model (model cell left untouched) to work on a dataset laid out like:

```
root/
  test/
    X/  (LR-HSI *.mat)
    Y/  (HR-MSI *.mat)
  Z/
    train/
      X/       (HR-HSI ground-truth)
      X_blur/  (LR-HSI versions matched to X)
      Y/       (HR-MSI)
```

You may change the directory names in the configuration cell. The code will:
- Load training scenes (HR-HSI, LR-HSI, HR-MSI)
- Extract aligned patches
- Train the BASFE model
- Run inference on test scenes (LR-HSI + HR-MSI) and, when `USE_GT` is enabled and an HR-HSI ground-truth directory is found, compute quality metrics (RMSE, PSNR, SAM, ERGAS, SSIM, CC)

Assumptions (can be adjusted in config cell):
- Each `.mat` file contains at least one 3D array of shape (H, W, Bands) with Bands ≥ 3. The first such array, in the order `scipy.io.loadmat` returns variables, is selected automatically.
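- Every loaded cube is min–max normalized to [0, 1] independently (per file). Absolute intensity levels are therefore not preserved across files.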
- The scale factor is inferred from the spatial dimensions of the first HR-HSI / LR-HSI training pair. It must be an integer ratio.

Modify `CONFIG` dict below to fit your filesystem.


# Configuration & Environment Setup

In [None]:
# If running in Colab uncomment the following two lines.
# from google.colab import drive
# drive.mount('/content/drive')

import os, sys, math, random, json, glob, time, socket, platform, shutil
import numpy as np
import cv2 as cv
import scipy.io as sio
from scipy.io import savemat
import matplotlib.pyplot as plt
from matplotlib import pyplot
import tensorflow as tf
from tensorflow import keras
from keras import layers
from tqdm.auto import tqdm

print('TensorFlow version:', tf.__version__)

# ---------------- Configuration ----------------
# Added performance tuning knobs for large datasets.
CONFIG = {
    'ROOT_DIR': '/kaggle/input/cave-dataset',   # CHANGE IF DIFFERENT
    'TRAIN_HR_HSI_DIR_CAND': ['Z/train/X', 'train/X', 'Train/X'],
    'TRAIN_LR_HSI_DIR_CAND': ['Z/train/X_blur', 'train/X_blur', 'Train/X_blur'],
    'TRAIN_HR_MSI_DIR_CAND': ['Z/train/Y', 'train/Y', 'Train/Y'],
    'TEST_LR_HSI_DIR_CAND':  ['Z/test/X', 'test/X', 'Test/X'],
    'TEST_HR_MSI_DIR_CAND':  ['Z/test/Y', 'test/Y', 'Test/Y'],
    'TEST_GT_HR_HSI_DIR_CAND': ['Z/test/GT_HR', 'test/GT_HR', 'Test/GT_HR'],
    'USE_GT': False,
    'MSI_BANDS_SELECT': None,          # list[int] or None
    'PATCH_HR_SIZE': 20,
    'PATCH_STRIDE': 7,
    'EDGE_OVERLAP': 2,
    'MAX_TRAIN_SCENES': None,
    'MAX_TEST_SCENES': None,
    'EPOCHS': 50,
    'BATCH_SIZE': 64,
    'LEARNING_RATE': 1e-4,
    'SAVE_MODEL_PATH': '/kaggle/working/BASFE_CAVE_init.keras',
    'SAVE_MODEL_TRAINED_PATH': '/kaggle/working/BASFE_CAVE_trained.keras',
    'LOG_DIR': '/kaggle/working/logs',
    'RESULTS_DIR': '/kaggle/working/results',
    'CHECKPOINT_DIR': '/kaggle/working/checkpoints',
    'CSV_LOG': '/kaggle/working/training_log.csv',
    'EARLY_STOP_PATIENCE': 10,
    'SEED': 42,
    # Memory / Patch Controls
    'PATCH_MEMORY_CAP_MB': 4000,
    'SUBSAMPLE_PATCH_RATE': 1,
    'TRUNCATE_TO_CAP': True,
    'MIXED_PRECISION': 'fp16',  # 'fp16','bf16',None, or a native Keras policy name such as 'mixed_float16'
    # Progress / Verbosity
    'SHOW_CONFIG_SUMMARY': True,
    'TRAIN_PROGRESS_EVERY_BATCHES': 25,
    'PRED_PATCH_PROGRESS': True,
    'SKIP_PLOT': False,          # Skip sample visualization for speed
    # New Performance Knobs
    'FAST_DEBUG_MODE': False,    # If True overrides some heavy params for quick smoke run
    'RANDOM_PATCHES_PER_SCENE': None,  # If set (int), sample that many random patch positions per scene
    'TOTAL_PATCHES_LIMIT': None, # Hard stop once this many patches collected (after subsampling)
    'BUILD_PATCHES_MODE': 'grid',# 'grid' or 'random' (random = uniform random positions limited by RANDOM_PATCHES_PER_SCENE)
    'ESTIMATE_ONLY': False,      # If True, only estimate patch counts & memory then stop (no training)
}

# FAST_DEBUG_MODE overrides
if CONFIG['FAST_DEBUG_MODE']:
    CONFIG['EPOCHS'] = min(CONFIG['EPOCHS'], 2)
    CONFIG['MAX_TRAIN_SCENES'] = 1 if CONFIG['MAX_TRAIN_SCENES'] is None else min(CONFIG['MAX_TRAIN_SCENES'],1)
    CONFIG['MAX_TEST_SCENES'] = 1 if CONFIG['MAX_TEST_SCENES'] is None else min(CONFIG['MAX_TEST_SCENES'],1)
    CONFIG['SUBSAMPLE_PATCH_RATE'] = max(CONFIG['SUBSAMPLE_PATCH_RATE'], 4)
    CONFIG['RANDOM_PATCHES_PER_SCENE'] = 200 if CONFIG['RANDOM_PATCHES_PER_SCENE'] is None else min(CONFIG['RANDOM_PATCHES_PER_SCENE'],200)
    print('FAST_DEBUG_MODE active: parameters reduced for speed.')

for d in ['LOG_DIR','RESULTS_DIR','CHECKPOINT_DIR']:
    os.makedirs(CONFIG[d], exist_ok=True)

# Optional: enable mixed precision.
# Keras only accepts its own policy names ('mixed_float16', 'mixed_bfloat16', 'float32', ...).
# Passing the short CONFIG aliases ('fp16'/'bf16') directly to Policy() raises ValueError,
# which used to be caught below, so mixed precision was silently never enabled.
# Translate the aliases first; any other value is passed through unchanged.
_MIXED_PRECISION_ALIASES = {'fp16': 'mixed_float16', 'bf16': 'mixed_bfloat16'}
if CONFIG['MIXED_PRECISION']:
    _mp_policy_name = _MIXED_PRECISION_ALIASES.get(str(CONFIG['MIXED_PRECISION']).lower(), CONFIG['MIXED_PRECISION'])
    try:
        from tensorflow.keras import mixed_precision
        policy = mixed_precision.Policy(_mp_policy_name)
        mixed_precision.set_global_policy(policy)
        print('Mixed precision enabled with policy:', policy)
    except Exception as e:
        # Best effort: an unsupported policy/backend falls back to full precision.
        print('Mixed precision request failed, proceeding without:', e)

random.seed(CONFIG['SEED'])
np.random.seed(CONFIG['SEED'])
# Seed TensorFlow as well; otherwise weight initialisation and fit(shuffle=True) differ between runs.
tf.random.set_seed(CONFIG['SEED'])

def resolve_dir(candidates, required=True):
    """Pick the first candidate sub-directory that exists under ``CONFIG['ROOT_DIR']``.

    Args:
        candidates: paths relative to ``CONFIG['ROOT_DIR']``, tried in the given order.
        required: if True, failing to find any candidate is an error; otherwise ``None``
            is returned.

    Returns:
        The matching *relative* path (not joined with ROOT_DIR), or ``None``.

    Raises:
        FileNotFoundError: when no candidate exists and ``required`` is True.
    """
    match = next(
        (rel for rel in candidates if os.path.isdir(os.path.join(CONFIG['ROOT_DIR'], rel))),
        None,
    )
    if match is None and required:
        raise FileNotFoundError(f"None of these candidate paths exist under ROOT_DIR={CONFIG['ROOT_DIR']}: {candidates}")
    return match

# Resolve every logical input directory against its list of candidate layouts.
TRAIN_HR_HSI_DIR = resolve_dir(CONFIG['TRAIN_HR_HSI_DIR_CAND'])
TRAIN_LR_HSI_DIR = resolve_dir(CONFIG['TRAIN_LR_HSI_DIR_CAND'])
TRAIN_HR_MSI_DIR = resolve_dir(CONFIG['TRAIN_HR_MSI_DIR_CAND'])
TEST_LR_HSI_DIR = resolve_dir(CONFIG['TEST_LR_HSI_DIR_CAND'])
TEST_HR_MSI_DIR = resolve_dir(CONFIG['TEST_HR_MSI_DIR_CAND'])
# Test ground truth is optional and only searched for when USE_GT is enabled.
if CONFIG['USE_GT']:
    TEST_GT_HR_HSI_DIR = resolve_dir(CONFIG['TEST_GT_HR_HSI_DIR_CAND'], required=False)
else:
    TEST_GT_HR_HSI_DIR = None

print('Resolved directories:')
for _dir_label, _dir_value in (
    ('  TRAIN_HR_HSI_DIR ->', TRAIN_HR_HSI_DIR),
    ('  TRAIN_LR_HSI_DIR ->', TRAIN_LR_HSI_DIR),
    ('  TRAIN_HR_MSI_DIR ->', TRAIN_HR_MSI_DIR),
    ('  TEST_LR_HSI_DIR  ->', TEST_LR_HSI_DIR),
    ('  TEST_HR_MSI_DIR  ->', TEST_HR_MSI_DIR),
    ('  TEST_GT_HR_HSI_DIR ->', TEST_GT_HR_HSI_DIR),
):
    print(_dir_label, _dir_value)

# Environment fingerprint, handy when comparing runs across machines.
print('Host:', socket.gethostname())
print('Python:', sys.version.split()[0], 'Platform:', platform.platform())
print('GPU Available:', tf.config.list_physical_devices('GPU'))

# Metadata entries that scipy.io.loadmat adds to every result; they never hold image data.
IGNORE_KEYS = {'__globals__', '__header__', '__version__'}

def load_first_cube(mat_path):
    """Load the first 3-D array with at least 3 channels from a MATLAB ``.mat`` file.

    Variables are scanned in the order ``scipy.io.loadmat`` returns them, skipping the
    metadata names in ``IGNORE_KEYS``.

    Returns:
        ``(cube, key)``: the array cast to float32 and min-max scaled to [0, 1] (left
        unscaled when every value is equal), plus the variable name it came from.

    Raises:
        ValueError: if the file holds no qualifying 3-D array.
    """
    contents = sio.loadmat(mat_path)
    for name, value in contents.items():
        qualifies = (
            name not in IGNORE_KEYS
            and isinstance(value, np.ndarray)
            and value.ndim == 3
            and value.shape[2] >= 3
        )
        if not qualifies:
            continue
        cube = value.astype(np.float32)
        lo, hi = cube.min(), cube.max()
        if hi > lo:
            cube = (cube - lo) / (hi - lo)
        return cube, name
    raise ValueError(f'No 3D cube found in {mat_path}')

def list_mats(rel_dir):
    """Return absolute paths of every ``.mat`` file in ``CONFIG['ROOT_DIR']/rel_dir``.

    The extension match is case-insensitive and the paths are sorted by file name.
    Callers in this notebook zip lists from different directories, so scene pairing
    relies on that ordering.

    Raises:
        FileNotFoundError: if the directory does not exist.
    """
    folder = os.path.join(CONFIG['ROOT_DIR'], rel_dir)
    if not os.path.isdir(folder):
        raise FileNotFoundError(f'Directory not found: {folder}')
    mat_names = sorted(name for name in os.listdir(folder) if name.lower().endswith('.mat'))
    return [os.path.join(folder, name) for name in mat_names]

train_hr_hsi_files = list_mats(TRAIN_HR_HSI_DIR)
train_lr_hsi_files = list_mats(TRAIN_LR_HSI_DIR)
train_hr_msi_files = list_mats(TRAIN_HR_MSI_DIR)

if CONFIG['MAX_TRAIN_SCENES']:
    train_hr_hsi_files = train_hr_hsi_files[:CONFIG['MAX_TRAIN_SCENES']]
    train_lr_hsi_files = train_lr_hsi_files[:CONFIG['MAX_TRAIN_SCENES']]
    train_hr_msi_files = train_hr_msi_files[:CONFIG['MAX_TRAIN_SCENES']]

print('Train HR-HSI:', len(train_hr_hsi_files), 'Train LR-HSI:', len(train_lr_hsi_files), 'Train HR-MSI:', len(train_hr_msi_files))
assert len(train_hr_hsi_files)==len(train_lr_hsi_files)==len(train_hr_msi_files), 'Mismatch in training file counts'
# Everything below samples the first training scene, so an empty directory must fail clearly
# here rather than with a bare IndexError.
if not train_hr_hsi_files:
    raise FileNotFoundError(f"No .mat training files found in {os.path.join(CONFIG['ROOT_DIR'], TRAIN_HR_HSI_DIR)}")

# Scenes are paired purely by sorted position (zip over the three lists). Differing file names
# usually mean the directories are not aligned, so report them.
_name_mismatches = []
for _triplet in zip(train_hr_hsi_files, train_lr_hsi_files, train_hr_msi_files):
    _names = tuple(os.path.basename(p) for p in _triplet)
    if len(set(_names)) > 1:
        _name_mismatches.append(_names)
if _name_mismatches:
    print(f'WARNING: {len(_name_mismatches)} training triplet(s) have differing file names, e.g. {_name_mismatches[0]}')

hr_sample,_ = load_first_cube(train_hr_hsi_files[0])
lr_sample,_ = load_first_cube(train_lr_hsi_files[0])
scale_y = hr_sample.shape[0] / lr_sample.shape[0]
scale_x = hr_sample.shape[1] / lr_sample.shape[1]
assert abs(scale_x - round(scale_x))<1e-3 and abs(scale_y - round(scale_y))<1e-3, 'Scale must be integer'
# A single SCALE is used downstream, so the ratio has to be the same along both axes.
assert round(scale_x) == round(scale_y), f'Scale must be isotropic (got x={scale_x}, y={scale_y})'
SCALE = int(round(scale_x))
print('Inferred scale:', SCALE)

msi_sample,_ = load_first_cube(train_hr_msi_files[0])
# The upsampled LR-HSI feeds lr_input (hsi_bands channels), and the HR-MSI is cut at the same
# coordinates as the HR-HSI. Catch layout mismatches here rather than deep inside training.
assert lr_sample.shape[2] == hr_sample.shape[2], f'LR-HSI has {lr_sample.shape[2]} bands but HR-HSI has {hr_sample.shape[2]}'
assert msi_sample.shape[:2] == hr_sample.shape[:2], f'HR-MSI spatial size {msi_sample.shape[:2]} differs from HR-HSI {hr_sample.shape[:2]}'
H_BANDS = hr_sample.shape[2]
M_BANDS = msi_sample.shape[2] if CONFIG['MSI_BANDS_SELECT'] is None else len(CONFIG['MSI_BANDS_SELECT'])
print('Bands HR-HSI:', H_BANDS, 'Bands HR-MSI (used):', M_BANDS)

hrsize = CONFIG['PATCH_HR_SIZE']
stride = CONFIG['PATCH_STRIDE']
num_filter = 32
msi_bands = M_BANDS
hsi_bands = H_BANDS
scale = SCALE

RESOLVED_PATHS = {
    'TRAIN_HR_HSI_DIR': TRAIN_HR_HSI_DIR,
    'TRAIN_LR_HSI_DIR': TRAIN_LR_HSI_DIR,
    'TRAIN_HR_MSI_DIR': TRAIN_HR_MSI_DIR,
    'TEST_LR_HSI_DIR': TEST_LR_HSI_DIR,
    'TEST_HR_MSI_DIR': TEST_HR_MSI_DIR,
    'TEST_GT_HR_HSI_DIR': TEST_GT_HR_HSI_DIR,
}

if CONFIG['SHOW_CONFIG_SUMMARY']:
    print('\n--- CONFIG SUMMARY ---')
    show_keys = [
        'ROOT_DIR','PATCH_HR_SIZE','PATCH_STRIDE','EDGE_OVERLAP','EPOCHS','BATCH_SIZE',
        'LEARNING_RATE','PATCH_MEMORY_CAP_MB','SUBSAMPLE_PATCH_RATE','MIXED_PRECISION',
        'RANDOM_PATCHES_PER_SCENE','TOTAL_PATCHES_LIMIT','BUILD_PATCHES_MODE','FAST_DEBUG_MODE'
    ]
    for k in show_keys:
        print(f'{k}:', CONFIG[k])
    print('----------------------\n')

# Build Training Patch Dataset

In [None]:
# Construct aligned patch tensors with performance modes (grid or random) and limits.
# Each scene yields co-located hrsize x hrsize patches from three cubes on the HR grid:
#   patches_hr <- HR-HSI (training target)
#   patches_lr <- bicubically upsampled LR-HSI
#   patches_mr <- HR-MSI
# Collection stops early at the memory cap or at TOTAL_PATCHES_LIMIT.
start_time = time.time()
patches_hr = []
patches_lr = []
patches_mr = []
scene_patch_counts = {}
# A falsy PATCH_MEMORY_CAP_MB (None or 0) disables the cap in every check below.
cap_bytes = (CONFIG['PATCH_MEMORY_CAP_MB'] or 0) * 1_000_000
subsample = max(1, CONFIG['SUBSAMPLE_PATCH_RATE'])

est_patch_bytes = None
MODE = CONFIG['BUILD_PATCHES_MODE']
rand_per_scene = CONFIG['RANDOM_PATCHES_PER_SCENE']
limit_total = CONFIG['TOTAL_PATCHES_LIMIT']
if MODE not in ('grid', 'random'):
    raise ValueError(f"BUILD_PATCHES_MODE must be 'grid' or 'random', got {MODE!r}")

scene_iter = tqdm(list(zip(train_hr_hsi_files, train_lr_hsi_files, train_hr_msi_files)), desc=f'Scenes (train) mode={MODE}', leave=True)
for scene_index, (hr_path, lr_path, msi_path) in enumerate(scene_iter):
    scene_name = os.path.basename(hr_path)
    hr_cube,_ = load_first_cube(hr_path)
    lr_cube,_ = load_first_cube(lr_path)
    msi_cube,_ = load_first_cube(msi_path)
    if CONFIG['MSI_BANDS_SELECT']:
        msi_cube = msi_cube[:,:,CONFIG['MSI_BANDS_SELECT']]
    h, w, _ = hr_cube.shape
    # All three cubes are cut at identical coordinates, so the MSI must share the HR grid.
    # Otherwise patches would be silently misaligned, or np.stack would fail later.
    if msi_cube.shape[:2] != (h, w):
        raise ValueError(f'{os.path.basename(msi_path)}: HR-MSI size {msi_cube.shape[:2]} does not match HR-HSI size {(h, w)}')
    # A scene smaller than one patch has no valid positions (random mode would crash in randint).
    if h < hrsize or w < hrsize:
        print(f'Skipping {scene_name}: {h}x{w} is smaller than PATCH_HR_SIZE={hrsize}')
        scene_patch_counts[scene_name] = 0
        continue
    up_lr = cv.resize(lr_cube, (w, h), interpolation=cv.INTER_CUBIC)
    c_before = len(patches_hr)

    if MODE == 'grid':
        # np.arange excludes its stop value. The +1 keeps the last valid upper-left corner
        # (h - hrsize), matching the inclusive range random mode draws from.
        row_starts = np.arange(0, h - hrsize + 1, stride)
        col_starts = np.arange(0, w - hrsize + 1, stride)
        total_grid = len(row_starts) * len(col_starts)
        patch_iter = tqdm(total=total_grid, desc=f'Patches[{scene_name}]', leave=False)
        local_added = 0
        stop_scene = False
        for i in row_starts:
            for j in col_starts:
                # Keep every `subsample`-th grid position.
                if (local_added % subsample) == 0:
                    patches_hr.append(hr_cube[i:i+hrsize, j:j+hrsize, :])
                    patches_lr.append(up_lr[i:i+hrsize, j:j+hrsize, :])
                    patches_mr.append(msi_cube[i:i+hrsize, j:j+hrsize, :])
                    if est_patch_bytes is None:
                        est_patch_bytes = (patches_hr[-1].nbytes + patches_lr[-1].nbytes + patches_mr[-1].nbytes)
                    total_bytes = len(patches_hr) * est_patch_bytes
                    if cap_bytes and total_bytes > cap_bytes:
                        print(f"Memory cap reached (~{total_bytes/1e6:.1f} MB). Stopping patch collection.")
                        stop_scene = True
                        break
                local_added += 1
                patch_iter.update(1)
                if limit_total and len(patches_hr) >= limit_total:
                    stop_scene = True
                    break
            if stop_scene:
                break
        patch_iter.close()
    else:  # random mode
        # Upper-left coordinates are drawn uniformly from [0, max_i] x [0, max_j] (inclusive).
        max_i = h - hrsize
        max_j = w - hrsize
        n_samples = rand_per_scene if rand_per_scene else 0
        if n_samples == 0:
            # estimate comparable to grid count but reduced by stride
            approx_rows = max(1, (h - hrsize)//stride)
            approx_cols = max(1, (w - hrsize)//stride)
            n_samples = int(approx_rows * approx_cols)
        patch_iter = tqdm(range(n_samples), desc=f'RandPatches[{scene_name}]', leave=False)
        for k in patch_iter:
            i = random.randint(0, max_i)
            j = random.randint(0, max_j)
            if (k % subsample) == 0:
                patches_hr.append(hr_cube[i:i+hrsize, j:j+hrsize, :])
                patches_lr.append(up_lr[i:i+hrsize, j:j+hrsize, :])
                patches_mr.append(msi_cube[i:i+hrsize, j:j+hrsize, :])
                if est_patch_bytes is None:
                    est_patch_bytes = (patches_hr[-1].nbytes + patches_lr[-1].nbytes + patches_mr[-1].nbytes)
                total_bytes = len(patches_hr) * est_patch_bytes
                if cap_bytes and total_bytes > cap_bytes:
                    print(f"Memory cap reached (~{total_bytes/1e6:.1f} MB). Stopping patch collection.")
                    break
                if limit_total and len(patches_hr) >= limit_total:
                    break
        patch_iter.close()
    scene_patch_counts[scene_name] = len(patches_hr) - c_before
    if (cap_bytes and est_patch_bytes and (len(patches_hr) * est_patch_bytes) > cap_bytes) or (limit_total and len(patches_hr) >= limit_total):
        print('Global stop condition reached (memory or total patches).')
        break

if len(patches_hr) == 0:
    raise RuntimeError('No patches collected; adjust configuration (patch size/stride/mode).')

if CONFIG['ESTIMATE_ONLY']:
    est_total_mb = (len(patches_hr) * est_patch_bytes)/1e6 if est_patch_bytes else 0
    print(f'[ESTIMATE_ONLY] Collected {len(patches_hr)} patches (~{est_total_mb:.1f} MB for all three tensors). Stopping before stacking.')
else:
    hrdata = np.stack(patches_hr, axis=0)
    lrdata = np.stack(patches_lr, axis=0)
    mrdata = np.stack(patches_mr, axis=0)

    # Only truncate when a cap is actually configured; with cap_bytes == 0 the comparison
    # below would otherwise shrink the dataset to zero samples.
    if CONFIG['TRUNCATE_TO_CAP'] and cap_bytes and est_patch_bytes and (hrdata.nbytes + lrdata.nbytes + mrdata.nbytes) > cap_bytes:
        triplet_bytes = (hrdata[0].nbytes + lrdata[0].nbytes + mrdata[0].nbytes)
        max_samples = int(cap_bytes // triplet_bytes)
        if max_samples < hrdata.shape[0]:
            idx = np.random.permutation(hrdata.shape[0])[:max_samples]
            hrdata = hrdata[idx]; lrdata = lrdata[idx]; mrdata = mrdata[idx]
            print(f'Truncated to {max_samples} samples to meet memory cap.')

    elapsed = time.time()-start_time
    print(f'Patch extraction finished in {elapsed:.1f}s')
    print('Training patches shapes -> hrdata:', hrdata.shape, 'lrdata:', lrdata.shape, 'mrdata:', mrdata.shape)
    print('Approx total patch tensors memory MB:', (hrdata.nbytes+lrdata.nbytes+mrdata.nbytes)/1e6)
    print('Patches per scene (first 5):', list(scene_patch_counts.items())[:5])

    if not CONFIG['SKIP_PLOT']:
        if hrdata.shape[0] > 10:
            n = min(10, hrdata.shape[0]-1)
        else:
            n = 0
        i = hrdata.shape[3]//2
        p = np.hstack((hrdata[n,:,:,i], lrdata[n,:,:,i], mrdata[n,:,:,0]))
        fig, ax = plt.subplots(figsize=(6,6))
        pyplot.imshow(p, vmin=0, vmax=1, cmap='gray')
        pyplot.title('HR | UpLR | MSI(first band)')
        pyplot.axis('off')
        plt.show()
    print('Memory usage (approx) HR patches MB:', hrdata.nbytes/1e6 if not CONFIG['ESTIMATE_ONLY'] else 'N/A (estimate only)')

In [None]:
# Placeholder cell retained (was sample visualization) -- already visualized above.
# hrdata/lrdata/mrdata are only stacked when ESTIMATE_ONLY is False. Guard so this cell
# does not raise NameError in estimate-only runs.
if CONFIG['ESTIMATE_ONLY']:
    print('ESTIMATE_ONLY is set: no patch tensors were stacked, skipping stats.')
else:
    print('Sample mid training patch stats:')
    print('HR mean', hrdata.mean(), 'LR mean', lrdata.mean(), 'MR mean', mrdata.mean())

# Model (Unmodified BASFE Architecture)

In [None]:
# Model definition using explicit PReLU layers (fix for activation='PReLU' ValueError)
# Architectural logic (number of convolutions, filter sizes, skip connections, concat order) kept identical.
num_filter = 32

# Residual helper units shared by both encoders. Each one is Conv -> PReLU -> Conv -> PReLU
# followed by an identity skip connection:
#   spec -> 1x1 kernels (mixes channels only)
#   spat -> 3x3 kernels (mixes a 3x3 neighbourhood)
def spec(inputs, nf):
    """Residual unit of two 1x1 Conv2D + PReLU stages added back onto ``inputs``.

    ``nf`` must equal the channel count of ``inputs`` for the residual Add to be valid.
    """
    feat = inputs
    for _ in range(2):
        feat = layers.Conv2D(nf, 1, padding="same", use_bias=True)(feat)
        feat = layers.PReLU(shared_axes=[1,2])(feat)
    return layers.Add()([inputs, feat])

def spat(inputs, nf):
    """Residual unit of two 3x3 Conv2D + PReLU stages added back onto ``inputs``.

    ``nf`` must equal the channel count of ``inputs`` for the residual Add to be valid.
    """
    feat = inputs
    for _ in range(2):
        feat = layers.Conv2D(nf, 3, padding="same", use_bias=True)(feat)
        feat = layers.PReLU(shared_axes=[1,2])(feat)
    return layers.Add()([inputs, feat])

# ****************************** MSI Encoder ******************************
# Input: one HR-MSI patch of hrsize x hrsize pixels with msi_bands channels (batch dim implicit).
msi_input = keras.Input(shape=(hrsize, hrsize, msi_bands), name="msi_input")
# Stage 1: 3x3 conv + PReLU, then a spec/spat residual pair. x03 concatenates the stage input (x01) and output (x02).
x01 = layers.Conv2D(num_filter, 3, padding="same", use_bias=True)(msi_input)
x01 = layers.PReLU(shared_axes=[1,2])(x01)
x02 = spec(x01,num_filter)
x02 = spat(x02,num_filter)
x03 = layers.Concatenate()([x01, x02])

# Stage 2: same pattern applied to x03. x06 densely concatenates x01, x04 and x05.
x04 = layers.Conv2D(num_filter, 3, padding="same", use_bias=True)(x03)
x04 = layers.PReLU(shared_axes=[1,2])(x04)
x05 = spec(x04,num_filter)
x05 = spat(x05,num_filter)
x06 = layers.Concatenate()([x01, x04, x05])

# Stage 3: 5x5 conv + PReLU, then spec and spat. Note x07 is rebound to the spec() output,
# and that rebound tensor is what the final concatenation receives.
x07 = layers.Conv2D(num_filter, 5, padding="same", use_bias=True)(x06)
x07 = layers.PReLU(shared_axes=[1,2])(x07)
x07 = spec(x07,num_filter)
x08 = spat(x07,num_filter)

# ****************************** LR_HSI Encoder ******************************
# Mirror of the MSI encoder (same layer sequence, separate weights). Its input is the
# bicubically upsampled LR-HSI patch with hsi_bands channels.
lr_input = keras.Input(shape=(hrsize, hrsize,hsi_bands), name="lr_input")
x11 = layers.Conv2D(num_filter, 3, padding="same", use_bias=True)(lr_input)
x11 = layers.PReLU(shared_axes=[1,2])(x11)
x12 = spec(x11,num_filter)
x12 = spat(x12,num_filter)
x13 = layers.Concatenate()([x11, x12])

x14 = layers.Conv2D(num_filter, 3, padding="same", use_bias=True)(x13)
x14 = layers.PReLU(shared_axes=[1,2])(x14)
x15 = spec(x14,num_filter)
x15 = spat(x15,num_filter)
x16 = layers.Concatenate()([x11, x14, x15])

x17 = layers.Conv2D(num_filter, 5, padding="same", use_bias=True)(x16)
x17 = layers.PReLU(shared_axes=[1,2])(x17)
x17 = spec(x17,num_filter)
x18 = spat(x17,num_filter)

# ********************************* concat ***********************************
# Fuse 8 feature maps of num_filter channels each:
#   x01/x04/x07/x08 from the MSI branch and x11/x14/x17/x18 from the LR-HSI branch.
x21 = layers.Concatenate()([x01, x04, x07, x08, x11, x14, x17, x18])

# Two 3x3 conv + PReLU layers map the fused features to hsi_bands output channels.
x22 = layers.Conv2D(hsi_bands, 3, padding="same", use_bias=True)(x21)
x22 = layers.PReLU(shared_axes=[1,2])(x22)
fuse_output = layers.Conv2D(hsi_bands, 3, padding="same", use_bias=True, name="fuse_output")(x22)
fuse_output = layers.PReLU(shared_axes=[1,2], name='fuse_output_prelu')(fuse_output)
# NOTE(review): when a mixed-precision global policy is active (configuration cell), this final
# PReLU also computes in reduced precision. The Keras mixed-precision guide usually keeps the
# output layer in float32 — confirm whether that matters for reconstruction accuracy.

# IMPORTANT: provide a single tensor (not list) for outputs to avoid target structure ambiguity in Keras 3 / optree.
model = keras.Model(
    inputs=[msi_input, lr_input],
    outputs=fuse_output,
    name='BASFE_Fusion'
)
model.summary()
# plot_model may fail in headless or restricted envs; wrap in try/except
try:
    keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)
except Exception as e:
    print('plot_model skipped:', e)

# Training BASFE on CAVE-style Dataset

In [None]:
# Persist the freshly initialised (untrained) model. A failure is reported but does not stop the notebook.
_initial_model_path = CONFIG['SAVE_MODEL_PATH']
try:
    model.save(_initial_model_path)
except Exception as e:
    print('WARNING: initial model save failed:', e)
else:
    print('Initial model saved to', _initial_model_path)

# Custom progress callback
class BatchProgress(keras.callbacks.Callback):
    """Print loss, elapsed time and (optionally) a per-epoch ETA every ``every_n`` batches.

    Args:
        every_n: print interval in training batches; values below 1 are treated as 1.
        total_batches: batches per epoch. When given, an ETA for the current epoch is printed.
    """

    def __init__(self, every_n=50, total_batches=None):
        super().__init__()
        # Guard against 0/None, which would make the modulo below raise ZeroDivisionError.
        self.every_n = max(1, int(every_n or 1))
        self.total_batches = total_batches
        self.start_time = None
        self.epoch_start_time = None

    def on_train_begin(self, logs=None):
        self.start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        # The batch index restarts at 0 every epoch, so throughput (and thus the ETA) must be
        # measured from the start of the current epoch, not from the start of training.
        self.epoch_start_time = time.time()

    def on_batch_end(self, batch, logs=None):
        if (batch+1) % self.every_n != 0:
            return
        now = time.time()
        elapsed = now - (self.start_time or now)  # total time since training began
        epoch_elapsed = now - (self.epoch_start_time or self.start_time or now)
        loss = (logs or {}).get('loss')
        loss_str = f'{float(loss):.5f}' if loss is not None else 'n/a'
        if self.total_batches:
            rate = (batch+1) / max(epoch_elapsed, 1e-9)
            remaining_batches = self.total_batches - (batch+1)
            eta = remaining_batches / max(rate,1e-6)
            print(f"Batch {batch+1}/{self.total_batches} - loss {loss_str} - elapsed {elapsed:.1f}s - ETA {eta:.1f}s")
        else:
            print(f"Batch {batch+1} - loss {loss_str} - elapsed {elapsed:.1f}s")

# Estimate total batches for progress messages
# ESTIMATE_ONLY runs never stack hrdata/lrdata/mrdata. Stop with a clear message instead of a NameError.
if CONFIG['ESTIMATE_ONLY']:
    raise RuntimeError('ESTIMATE_ONLY is set: patch tensors were not built, so training is skipped. Set ESTIMATE_ONLY=False to train.')
steps_per_epoch = math.ceil(hrdata.shape[0] / CONFIG['BATCH_SIZE'])

# Callbacks for richer logging
callbacks = []
callbacks.append(keras.callbacks.CSVLogger(CONFIG['CSV_LOG'], append=False))
callbacks.append(keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(CONFIG['CHECKPOINT_DIR'], 'epoch_{epoch:03d}_loss_{loss:.5f}.keras'),
    monitor='loss', save_best_only=True, save_weights_only=False, verbose=1))
callbacks.append(keras.callbacks.EarlyStopping(monitor='loss', patience=CONFIG['EARLY_STOP_PATIENCE'], restore_best_weights=True, verbose=1))
callbacks.append(keras.callbacks.TensorBoard(log_dir=CONFIG['LOG_DIR'], write_graph=False, update_freq='epoch'))
callbacks.append(BatchProgress(every_n=CONFIG['TRAIN_PROGRESS_EVERY_BATCHES'], total_batches=steps_per_epoch))

model.compile(
    # Take the optimizer from the same `keras` namespace used for the model and callbacks.
    # `tf.optimizers` may resolve to a different Keras implementation (e.g. legacy tf_keras).
    optimizer=keras.optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']),
    loss=keras.losses.MeanSquaredError(),
)

print('Starting training: epochs', CONFIG['EPOCHS'], 'batch size', CONFIG['BATCH_SIZE'], 'steps/epoch', steps_per_epoch)
start_train = time.time()

# Inputs are passed by layer name, so their order does not matter.
history = model.fit(
    {"msi_input": mrdata, "lr_input": lrdata},
    hrdata,
    epochs=CONFIG['EPOCHS'],
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=True,
    verbose=1,
    callbacks=callbacks,
)
train_time = time.time()-start_train
print(f'Training completed in {train_time/60:.2f} min')

try:
    model.save(CONFIG['SAVE_MODEL_TRAINED_PATH'])
    print('Trained model saved to', CONFIG['SAVE_MODEL_TRAINED_PATH'])
except Exception as e:
    print('WARNING: trained model save failed:', e)

plt.figure(figsize=(6,4))
plt.plot(history.history['loss'], label='loss')
plt.xlabel('Epoch'); plt.ylabel('MSE Loss'); plt.title('Training Loss Curve'); plt.grid(True); plt.legend();
plt.tight_layout(); plt.show()

# Summary JSON
summary_info = {
    'epochs_ran': len(history.history['loss']),
    'final_loss': float(history.history['loss'][-1]),
    'train_time_sec': train_time,
    'num_patches': int(hrdata.shape[0]),
    'scale': scale,
    'hr_patch_size': hrsize,
    'bands_hsi': hsi_bands,
    'bands_msi': msi_bands,
    'mixed_precision': CONFIG['MIXED_PRECISION'],
    'patch_memory_cap_mb': CONFIG['PATCH_MEMORY_CAP_MB'],
    'subsample_rate': CONFIG['SUBSAMPLE_PATCH_RATE'],
}
with open(os.path.join(CONFIG['RESULTS_DIR'], 'training_summary.json'), 'w') as f:
    json.dump(summary_info, f, indent=2)
print('Saved training summary to training_summary.json')

# Reconstruction & Assessment on Test Scenes

In [None]:
# Prepare test inputs and run model prediction over each scene via tiling reconstruction with progress indicators.

def prepare_test_scene(lr_path, msi_path):
    """Load one test pair and return ``(up_lr, msi_cube)`` on the HR-MSI pixel grid.

    The LR-HSI cube is bicubically upsampled (``cv.INTER_CUBIC``) to the MSI's height and
    width. When ``CONFIG['MSI_BANDS_SELECT']`` is set, the MSI keeps only those bands.
    """
    lr_cube = load_first_cube(lr_path)[0]
    msi_cube = load_first_cube(msi_path)[0]
    band_selection = CONFIG['MSI_BANDS_SELECT']
    if band_selection:
        msi_cube = msi_cube[:, :, band_selection]
    target_h, target_w = msi_cube.shape[:2]
    up_lr = cv.resize(lr_cube, (target_w, target_h), interpolation=cv.INTER_CUBIC)
    return up_lr, msi_cube

TEST_LR = list_mats(RESOLVED_PATHS['TEST_LR_HSI_DIR'])
TEST_MSI = list_mats(RESOLVED_PATHS['TEST_HR_MSI_DIR'])
if CONFIG['MAX_TEST_SCENES']:
    TEST_LR = TEST_LR[:CONFIG['MAX_TEST_SCENES']]
    TEST_MSI = TEST_MSI[:CONFIG['MAX_TEST_SCENES']]
assert len(TEST_LR)==len(TEST_MSI), 'Mismatch in test file counts'

EDGE = CONFIG['EDGE_OVERLAP']
strider = hrsize - 2*EDGE
# Neighbouring tiles are placed `strider` apart so their EDGE-trimmed interiors abut exactly.
# A non-positive step would never advance.
assert EDGE >= 0 and strider > 0, f'EDGE_OVERLAP={EDGE} is too large for PATCH_HR_SIZE={hrsize}'


def _tile_starts(length, size, step):
    """Return tile start offsets so that full ``size``-wide tiles cover ``[0, length)``.

    Starts advance by ``step``. The last tile is pinned to ``length - size`` so every tile
    lies completely inside the image. Requires ``length >= size``.
    """
    starts = list(range(0, length - size + 1, step))
    if starts[-1] != length - size:
        starts.append(length - size)
    return np.asarray(starts)


reconstructed = {}

# Batch size for prediction to limit memory in case of many patches
PRED_BATCH = max(1, 2048 // (hrsize*hrsize))  # heuristic; adjust if needed
print('Prediction batch heuristic:', PRED_BATCH)

for lr_path, msi_path in tqdm(list(zip(TEST_LR, TEST_MSI)), desc='Scenes (test)', leave=True):
    scene_name = os.path.splitext(os.path.basename(lr_path))[0]
    up_lr, hr_msi = prepare_test_scene(lr_path, msi_path)
    H, W, _ = hr_msi.shape
    if H < hrsize or W < hrsize:
        print(f'Skipping {scene_name}: {H}x{W} is smaller than PATCH_HR_SIZE={hrsize}')
        continue
    ii = _tile_starts(H, hrsize, strider)
    jj = _tile_starts(W, hrsize, strider)

    mr_patches = []
    lr_patches = []
    grid_total = len(ii)*len(jj)
    for i in ii:
        for j in jj:
            mr_patches.append(hr_msi[i:i+hrsize, j:j+hrsize, :])
            lr_patches.append(up_lr[i:i+hrsize, j:j+hrsize, :])
    mrdatainput = np.stack(mr_patches, axis=0)
    lrdatainput = np.stack(lr_patches, axis=0)

    preds_list = []
    batch_iter = tqdm(range(0, mrdatainput.shape[0], PRED_BATCH), desc=f'Predict[{scene_name}]', leave=False)
    for start in batch_iter:
        end = start + PRED_BATCH
        # Feed inputs by name, exactly as in model.fit, so the mapping never depends on positional order.
        batch_preds = model.predict({"msi_input": mrdatainput[start:end], "lr_input": lrdatainput[start:end]}, verbose=0)
        preds_list.append(batch_preds)
    preds = np.concatenate(preds_list, axis=0)

    out_cube = np.zeros((H, W, hsi_bands), dtype=np.float32)
    # Pass 1: paste whole tiles so the image border (the outer EDGE pixels) is filled.
    count = 0
    for i in ii:
        for j in jj:
            out_cube[i:i+hrsize, j:j+hrsize, :] = preds[count]
            count += 1
    # Pass 2: overwrite with each tile's EDGE-trimmed interior, discarding the tile margins.
    # Explicit end index: the slice 'EDGE:-EDGE' would be empty when EDGE == 0.
    inner = slice(EDGE, hrsize - EDGE)
    count = 0
    for i in ii:
        for j in jj:
            out_cube[i+EDGE:i+hrsize-EDGE, j+EDGE:j+hrsize-EDGE, :] = preds[count, inner, inner, :]
            count += 1

    reconstructed[scene_name] = out_cube
    savemat(os.path.join(CONFIG['RESULTS_DIR'], f'{scene_name}_reconst.mat'), {f'reconst_{scene_name}': out_cube})
    print(f'Reconstructed {scene_name}: shape {out_cube.shape}')

print('Total test scenes reconstructed:', len(reconstructed))

In [None]:
# Metrics (per-scene if GT available) + consolidated logging
from math import log10

metrics_results = {}


def compute_metrics(pred, gt, scale):
    """Compute fusion quality metrics between a reconstructed cube and its ground truth.

    Args:
        pred, gt: arrays of identical shape (H, W, L). PSNR assumes a peak value of 1.0,
            i.e. data in [0, 1].
        scale: spatial up-sampling factor, used in ERGAS' 100/scale term.

    Returns:
        dict with keys:
            RMSE: overall root-mean-square error.
            PSNR: peak signal-to-noise ratio in dB.
            SAM_deg: mean spectral angle in degrees.
            ERGAS: relative global error.
            MSSIM: mean over bands of a single-window (whole-image) SSIM.
            CC: mean per-band correlation coefficient.

    Raises:
        ValueError: if ``pred`` and ``gt`` shapes differ.
    """
    if pred.shape != gt.shape:
        raise ValueError(f'Shape mismatch: pred {pred.shape} vs gt {gt.shape}')
    # Accumulate in float64: the sums below run over every pixel of a full scene.
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    n = pred.shape[0] * pred.shape[1]
    L = pred.shape[2]
    temp = np.sum((pred - gt) ** 2, axis=(0, 1)) / n          # per-band MSE
    rmse_per_band = np.sqrt(temp)
    rmse_total = np.sqrt(np.sum(temp) / L)
    psnr = 10 * log10(1.0 / (rmse_total ** 2 + 1e-12))
    num = np.sum(pred * gt, axis=2)
    den = np.sqrt(np.sum(pred * pred, axis=2) * np.sum(gt * gt, axis=2)) + 1e-12
    sam = np.mean(np.arccos(np.clip(num / den, -1, 1))) * 180 / np.pi
    mean_gt = np.sum(gt, axis=(0, 1)) / n
    ergas = 100 / scale * np.sqrt(np.sum((rmse_per_band / (mean_gt + 1e-12)) ** 2) / L)
    mean_p = np.sum(pred, axis=(0, 1)) / n
    c1 = .0001; c2 = .0001
    ssim = []
    for i in range(L):
        sigma2_p = np.mean(pred[:, :, i] ** 2) - mean_p[i] ** 2
        sigma2_g = np.mean(gt[:, :, i] ** 2) - mean_gt[i] ** 2
        sigma_pg = np.mean(pred[:, :, i] * gt[:, :, i]) - mean_p[i] * mean_gt[i]
        ssim.append(((2 * mean_p[i] * mean_gt[i] + c1) * (2 * sigma_pg + c2))
                    / ((mean_p[i] ** 2 + mean_gt[i] ** 2 + c1) * (sigma2_p + sigma2_g + c2)))
    cc = []
    for i in range(L):
        dp = pred[:, :, i] - mean_p[i]
        dg = gt[:, :, i] - mean_gt[i]
        cc.append(np.sum(dp * dg) / (np.sqrt(np.sum(dp ** 2) * np.sum(dg ** 2)) + 1e-12))
    return {
        'RMSE': float(rmse_total),
        'PSNR': float(psnr),
        'SAM_deg': float(sam),
        'ERGAS': float(ergas),
        'MSSIM': float(np.mean(ssim)),
        'CC': float(np.mean(cc)),
    }


# The GT directory is resolved in the configuration cell. CONFIG only stores the *_CAND candidate
# lists, so the previous lookup CONFIG['TEST_GT_HR_HSI_DIR'] always raised KeyError.
gt_dir = RESOLVED_PATHS.get('TEST_GT_HR_HSI_DIR')

if not reconstructed:
    print('No reconstructed scenes found; skipping metrics.')
elif gt_dir:
    gt_files = list_mats(gt_dir)
    gt_map = {os.path.splitext(os.path.basename(f))[0]: f for f in gt_files}
    unmatched = [s for s in reconstructed if s not in gt_map]
    if unmatched:
        print(f'No GT file found for {len(unmatched)} scene(s), e.g. {unmatched[0]}')

    for scene in tqdm(reconstructed.keys(), desc='Metrics scenes', leave=True):
        if scene not in gt_map:
            continue
        try:
            pred_cube = reconstructed[scene]
            gt_cube,_ = load_first_cube(gt_map[scene])
            # Compare on the common spatial extent and the reconstruction's band count.
            gh = min(gt_cube.shape[0], pred_cube.shape[0])
            gw = min(gt_cube.shape[1], pred_cube.shape[1])
            bands = pred_cube.shape[2]
            gt_crop = gt_cube[:gh, :gw, :bands]
            pred_crop = pred_cube[:gh, :gw, :bands]
            metrics_results[scene] = compute_metrics(pred_crop, gt_crop, scale)
        except Exception as e:
            # Best effort per scene: one bad GT file must not abort the whole evaluation.
            print(f'Failed metrics for {scene}:', e)
    for s, m in metrics_results.items():
        print('Metrics', s, m)
else:
    print('No GT directory configured for test metrics; skipping metric computation.')

# Save metrics JSON if any
os.makedirs(CONFIG['RESULTS_DIR'], exist_ok=True)
if metrics_results:
    with open(os.path.join(CONFIG['RESULTS_DIR'], 'metrics.json'), 'w') as f:
        json.dump(metrics_results, f, indent=2)
    print('Saved metrics.json')
print('Done.')