# CAVE / Generic Hyperspectral-Multispectral Fusion Notebook

This notebook adapts the BASFE model (the model cell is left untouched) to a dataset laid out as follows:

```
root/
  test/
    X/  (LR-HSI *.mat)
    Y/  (HR-MSI *.mat)
  Z/
    train/
      X/       (HR-HSI ground-truth)
      X_blur/  (LR-HSI versions matched to X)
      Y/       (HR-MSI)
```

You may change the directory names in the configuration cell. The code will:
- Load training scenes (HR-HSI, LR-HSI, HR-MSI)
- Extract aligned patches
- Train the BASFE model
- Run inference on test scenes (LR-HSI + HR-MSI) and (optionally) compute metrics if HR-HSI GT is available

Assumptions (can be adjusted in config cell):
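- Each `.mat` file contains at least one 3D array of shape (H, W, Bands) with at least 3 bands. The first such array (in file order) is selected automatically.
- Each loaded cube is min-max normalized to [0,1] independently (per file, jointly over all bands).
- The scale factor is inferred from the HR vs. LR spatial dimensions of the first training pair and must be an integer ratio.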

Modify `CONFIG` dict below to fit your filesystem.


# Configuration & Environment Setup

In [None]:
# If running in Colab uncomment the following two lines.
# from google.colab import drive
# drive.mount('/content/drive')

import os, sys, math, random, json, glob, time, socket, platform
import numpy as np
import cv2 as cv
import scipy.io as sio
from scipy.io import savemat
import matplotlib.pyplot as plt
from matplotlib import pyplot
import tensorflow as tf
from tensorflow import keras
from keras import layers
from tqdm import tqdm

print('TensorFlow version:', tf.__version__)

# ---------------- Configuration ----------------
# For Kaggle, set ROOT_DIR to the dataset mount path, e.g. '/kaggle/input/cave-dataset'
# Ensure the folder structure under ROOT_DIR matches:
#   Z/train/X        (HR-HSI)
#   Z/train/X_blur   (LR-HSI)
#   Z/train/Y        (HR-MSI)
#   test/X           (LR-HSI test)
#   test/Y           (HR-MSI test)
# Optionally: a GT HR-HSI test folder you can point TEST_GT_HR_HSI_DIR to.
CONFIG = {
    'ROOT_DIR': '/kaggle/input/cave-dataset',   # CHANGE IF DIFFERENT
    'TRAIN_HR_HSI_DIR': 'Z/train/X',        # HR-HSI ground-truth
    'TRAIN_LR_HSI_DIR': 'Z/train/X_blur',   # LR-HSI (blurred / downsampled)
    'TRAIN_HR_MSI_DIR': 'Z/train/Y',        # HR-MSI
    'TEST_LR_HSI_DIR': 'test/X',            # LR-HSI test
    'TEST_HR_MSI_DIR': 'test/Y',            # HR-MSI test
    'TEST_GT_HR_HSI_DIR': None,             # e.g. 'test/GT_HR' if available
    'MSI_BANDS_SELECT': None,               # list[int] e.g. [10,15,20,25] or None
    'PATCH_HR_SIZE': 20,
    'PATCH_STRIDE': 7,
    'EDGE_OVERLAP': 2,
    'MAX_TRAIN_SCENES': None,               # reduce for fast debug
    'MAX_TEST_SCENES': None,
    'EPOCHS': 50,
    'BATCH_SIZE': 64,
    'LEARNING_RATE': 1e-4,
    'SAVE_MODEL_PATH': '/kaggle/working/BASFE_CAVE_init',
    'SAVE_MODEL_TRAINED_PATH': '/kaggle/working/BASFE_CAVE_trained',
    'LOG_DIR': '/kaggle/working/logs',
    'RESULTS_DIR': '/kaggle/working/results',
    'CHECKPOINT_DIR': '/kaggle/working/checkpoints',
    'CSV_LOG': '/kaggle/working/training_log.csv',
    'EARLY_STOP_PATIENCE': 10,
    'SEED': 42,
}

# Create the output directories, plus the parent folders of the file-valued
# outputs (model save targets and the CSV log), so later writes cannot fail
# on a missing directory.
for d in ['LOG_DIR', 'RESULTS_DIR', 'CHECKPOINT_DIR']:
    os.makedirs(CONFIG[d], exist_ok=True)
for key in ['SAVE_MODEL_PATH', 'SAVE_MODEL_TRAINED_PATH', 'CSV_LOG']:
    parent = os.path.dirname(CONFIG[key])
    if parent:
        os.makedirs(parent, exist_ok=True)

# Seed Python, NumPy and the Keras/TensorFlow backend. Seeding only `random`
# and NumPy leaves Keras weight initialisation unseeded, so training runs were
# not reproducible. (Bit-exact GPU runs may additionally need
# tf.config.experimental.enable_op_determinism() -- not enabled here.)
random.seed(CONFIG['SEED'])
np.random.seed(CONFIG['SEED'])
keras.utils.set_random_seed(CONFIG['SEED'])

# Quick environment summary
print('Host:', socket.gethostname())
print('Python:', sys.version.split()[0], 'Platform:', platform.platform())
print('GPU Available:', tf.config.list_physical_devices('GPU'))

# Utility: find first 3D array in .mat file
IGNORE_KEYS = {'__globals__', '__header__', '__version__'}

def load_first_cube(mat_path):
    """Return the first usable 3-D cube stored in a MATLAB ``.mat`` file.

    A variable qualifies when it is a NumPy array with ``ndim == 3`` and at
    least 3 entries along its last axis; variables are examined in the order
    ``scipy.io.loadmat`` returns them, skipping the loader's metadata keys.

    The cube is cast to float32 and min-max normalised over *all* its values
    (not per band) to [0, 1]. A constant cube is returned un-normalised.

    Returns:
        (cube, variable_name)

    Raises:
        ValueError: if no variable qualifies.
    """
    contents = sio.loadmat(mat_path)
    candidates = (
        (name, value)
        for name, value in contents.items()
        if name not in IGNORE_KEYS
        and isinstance(value, np.ndarray)
        and value.ndim == 3
        and value.shape[2] >= 3
    )
    match = next(candidates, None)
    if match is None:
        raise ValueError(f'No 3D cube found in {mat_path}')

    name, value = match
    cube = value.astype(np.float32)
    lo = cube.min()
    hi = cube.max()
    if hi > lo:
        cube = (cube - lo) / (hi - lo)
    return cube, name

# List .mat files utility

def list_mats(dir_path, root=None):
    """Return the sorted full paths of the ``.mat`` files in a data directory.

    Args:
        dir_path: Directory relative to ``root``. An absolute path is used as-is
            (standard ``os.path.join`` behaviour).
        root: Base directory; defaults to ``CONFIG['ROOT_DIR']`` when None.

    Returns:
        Full paths of regular files whose name ends in ``.mat``
        (case-insensitive), sorted by file name. Callers zip parallel
        directories together, so the sort order is what pairs scenes up.

    Raises:
        FileNotFoundError: if the resolved directory does not exist.
    """
    base = CONFIG['ROOT_DIR'] if root is None else root
    full = os.path.join(base, dir_path)
    if not os.path.isdir(full):
        raise FileNotFoundError(f'Directory not found: {full}')
    # isfile() excludes sub-directories that merely happen to end in ".mat".
    names = sorted(
        name for name in os.listdir(full)
        if name.lower().endswith('.mat') and os.path.isfile(os.path.join(full, name))
    )
    return [os.path.join(full, name) for name in names]

train_hr_hsi_files = list_mats(CONFIG['TRAIN_HR_HSI_DIR'])
train_lr_hsi_files = list_mats(CONFIG['TRAIN_LR_HSI_DIR'])
train_hr_msi_files = list_mats(CONFIG['TRAIN_HR_MSI_DIR'])

if CONFIG['MAX_TRAIN_SCENES']:
    train_hr_hsi_files = train_hr_hsi_files[:CONFIG['MAX_TRAIN_SCENES']]
    train_lr_hsi_files = train_lr_hsi_files[:CONFIG['MAX_TRAIN_SCENES']]
    train_hr_msi_files = train_hr_msi_files[:CONFIG['MAX_TRAIN_SCENES']]

print('Train HR-HSI:', len(train_hr_hsi_files), 'Train LR-HSI:', len(train_lr_hsi_files), 'Train HR-MSI:', len(train_hr_msi_files))
# Scenes are paired purely by sorted position, so the three lists must line up.
if not (len(train_hr_hsi_files) == len(train_lr_hsi_files) == len(train_hr_msi_files)):
    raise ValueError('Mismatch in training file counts')
if not train_hr_hsi_files:
    raise ValueError('No training .mat files found')


def _stem(path):
    """File name without directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


# Positional pairing is only trustworthy if the names agree; warn (not fail),
# since some datasets legitimately name the three folders' files differently.
mismatched = [
    (_stem(a), _stem(b), _stem(c))
    for a, b, c in zip(train_hr_hsi_files, train_lr_hsi_files, train_hr_msi_files)
    if len({_stem(a), _stem(b), _stem(c)}) > 1
]
if mismatched:
    print(f'WARNING: {len(mismatched)} training triplet(s) have differing file names '
          f'(paired by sorted order). First: {mismatched[0]}')

# Infer scale from first pair
hr_sample,_ = load_first_cube(train_hr_hsi_files[0])
lr_sample,_ = load_first_cube(train_lr_hsi_files[0])
scale_y = hr_sample.shape[0] / lr_sample.shape[0]
scale_x = hr_sample.shape[1] / lr_sample.shape[1]
if abs(scale_x - round(scale_x)) >= 1e-3 or abs(scale_y - round(scale_y)) >= 1e-3:
    raise ValueError(f'Scale must be integer (got {scale_y:.3f} x {scale_x:.3f})')
# A single SCALE is used downstream, so both axes must agree.
if round(scale_x) != round(scale_y):
    raise ValueError(f'Scale must be the same along both axes (got {scale_y:.3f} x {scale_x:.3f})')
SCALE = int(round(scale_x))
print('Inferred scale:', SCALE)

msi_sample,_ = load_first_cube(train_hr_msi_files[0])
if msi_sample.shape[:2] != hr_sample.shape[:2]:
    raise ValueError(f'HR-MSI spatial size {msi_sample.shape[:2]} != HR-HSI spatial size {hr_sample.shape[:2]}')

# Treat an empty selection like None, matching the truthiness test the
# patch-extraction and test cells use before applying the selection.
msi_select = CONFIG['MSI_BANDS_SELECT']
if msi_select:
    n_msi = msi_sample.shape[2]
    bad = [b for b in msi_select if not -n_msi <= b < n_msi]
    if bad:
        raise ValueError(f'MSI_BANDS_SELECT indices {bad} out of range for {n_msi} MSI bands')
H_BANDS = hr_sample.shape[2]
M_BANDS = len(msi_select) if msi_select else msi_sample.shape[2]
print('Bands HR-HSI:', H_BANDS, 'Bands HR-MSI (used):', M_BANDS)

hrsize = CONFIG['PATCH_HR_SIZE']
stride = CONFIG['PATCH_STRIDE']
num_filter = 32  # keep consistent with model cell
msi_bands = M_BANDS
hsi_bands = H_BANDS
scale = SCALE


# Build Training Patch Dataset

In [None]:
# Construct aligned patch tensors (mrdata: HR-MSI, lrdata: upsampled LR-HSI, hrdata: HR-HSI GT) with progress bars.
# All three are cut with identical window coordinates from the same scene, so
# index k of hrdata / lrdata / mrdata refers to the same spatial location.
start_time = time.time()
patches_hr = []
patches_lr = []
patches_mr = []
scene_patch_counts = {}

for hr_path, lr_path, msi_path in tqdm(list(zip(train_hr_hsi_files, train_lr_hsi_files, train_hr_msi_files)), desc='Scenes (train)'):
    scene = os.path.basename(hr_path)
    hr_cube, _ = load_first_cube(hr_path)
    lr_cube, _ = load_first_cube(lr_path)
    msi_cube, _ = load_first_cube(msi_path)
    if CONFIG['MSI_BANDS_SELECT']:
        msi_cube = msi_cube[:, :, CONFIG['MSI_BANDS_SELECT']]

    # Fail early with a readable message instead of silently misaligned
    # patches (different H x W) or an opaque np.stack / model-shape error later.
    if msi_cube.shape[:2] != hr_cube.shape[:2]:
        raise ValueError(f'{scene}: HR-MSI spatial size {msi_cube.shape[:2]} != HR-HSI {hr_cube.shape[:2]}')
    if hr_cube.shape[2] != hsi_bands or lr_cube.shape[2] != hsi_bands:
        raise ValueError(f'{scene}: expected {hsi_bands} HSI bands, got HR={hr_cube.shape[2]}, LR={lr_cube.shape[2]}')
    if msi_cube.shape[2] != msi_bands:
        raise ValueError(f'{scene}: expected {msi_bands} MSI bands, got {msi_cube.shape[2]}')

    # Bicubic upsampling of the LR-HSI onto the HR grid (cv.resize dsize is (width, height)).
    up_lr = cv.resize(lr_cube, (hr_cube.shape[1], hr_cube.shape[0]), interpolation=cv.INTER_CUBIC)
    h, w, _ = hr_cube.shape
    # "+ 1" so the last valid window start (h - hrsize) is included; without it
    # that start is dropped, and a scene exactly hrsize wide yields no patches.
    row_starts = np.arange(0, h - hrsize + 1, stride)
    col_starts = np.arange(0, w - hrsize + 1, stride)
    n_before = len(patches_hr)
    for r in row_starts:
        for c in col_starts:
            patches_hr.append(hr_cube[r:r+hrsize, c:c+hrsize, :])
            patches_lr.append(up_lr[r:r+hrsize, c:c+hrsize, :])
            patches_mr.append(msi_cube[r:r+hrsize, c:c+hrsize, :])
    scene_patch_counts[scene] = len(patches_hr) - n_before
    if scene_patch_counts[scene] == 0:
        print(f'WARNING: {scene} ({h}x{w}) is smaller than PATCH_HR_SIZE={hrsize}; no patches extracted.')

if not patches_hr:
    raise ValueError(f'No training patches extracted: PATCH_HR_SIZE={hrsize} exceeds every training scene.')

hrdata = np.stack(patches_hr, axis=0)
lrdata = np.stack(patches_lr, axis=0)
mrdata = np.stack(patches_mr, axis=0)
elapsed = time.time()-start_time
print(f'Patch extraction finished in {elapsed:.1f}s')
print('Training patches shapes -> hrdata:', hrdata.shape, 'lrdata:', lrdata.shape, 'mrdata:', mrdata.shape)
print('Patches per scene (first 5):', list(scene_patch_counts.items())[:5])

# Quick visualization: one patch, middle HSI band (HR and upsampled LR) next to the first MSI band.
sample_idx = min(10, hrdata.shape[0]-1)
mid_band = hrdata.shape[3]//2
panel = np.hstack((hrdata[sample_idx,:,:,mid_band], lrdata[sample_idx,:,:,mid_band], mrdata[sample_idx,:,:,0]))
fig, ax = plt.subplots(figsize=(6,6))
ax.imshow(panel, vmin=0, vmax=1, cmap='gray')
ax.set_title('HR | UpLR | MSI(first band)')
ax.axis('off')
plt.show()

print('Memory usage (approx) HR patches MB:', hrdata.nbytes/1e6)


In [None]:
# Quick sanity check on the training patch tensors (a sample patch is shown in
# the previous cell): print the global mean intensity of each tensor.
print('Sample mid training patch stats:')
hr_mean, lr_mean, mr_mean = (arr.mean() for arr in (hrdata, lrdata, mrdata))
print('HR mean', hr_mean, 'LR mean', lr_mean, 'MR mean', mr_mean)

# Model (Unmodified BASFE Architecture)

In [None]:
# BASFE fusion network (Keras functional API).
# Two encoders with identical structure -- one for the HR-MSI patch, one for the
# upsampled LR-HSI patch -- whose intermediate feature maps are all concatenated
# and decoded to `hsi_bands` output channels.
# Both inputs are (hrsize, hrsize, bands), i.e. the network is built for the
# training patch size; inference therefore runs on tiles of the same size.
# NOTE(review): activation="PReLU" is passed as a string and resolved by Keras'
# activation lookup; whether that string is accepted (and whether it yields a
# learnable PReLU) depends on the installed Keras version -- confirm.
num_filter = 32

def spec(inputs,nf):
    """Spectral residual block: two 1x1 convolutions (per-pixel channel mixing) plus a skip connection.

    `inputs` must already have `nf` channels so the residual sum `inputs + x` is valid.
    """
    x = tf.keras.layers.Conv2D(nf, 1, activation="PReLU", padding="same", use_bias=True)(inputs)
    x = tf.keras.layers.Conv2D(nf, 1, activation="PReLU", padding="same", use_bias=True)(x)
    return inputs + x
def spat(inputs,nf):
    """Spatial residual block: two 3x3 convolutions plus a skip connection.

    `inputs` must already have `nf` channels so the residual sum `inputs + x` is valid.
    """
    x = tf.keras.layers.Conv2D(nf, 3, activation="PReLU", padding="same", use_bias=True)(inputs)
    x = tf.keras.layers.Conv2D(nf, 3, activation="PReLU", padding="same", use_bias=True)(x)
    return inputs + x

# ****************************** MSI Encoder ******************************
msi_input = keras.Input(shape=(hrsize, hrsize, msi_bands), name="msi_input")
# Stage 1: 3x3 stem conv to num_filter channels, spectral + spatial residual
# blocks, then concatenate with the stem output.
x01 = layers.Conv2D(num_filter, 3, activation="PReLU", padding="same", use_bias=True)(msi_input)
x02 = spec(x01,num_filter)
x02 = spat(x02,num_filter)
x03 = layers.concatenate([x01, x02])

# Stage 2: same pattern; the concat also carries the stage-1 stem features (x01).
x04 = layers.Conv2D(num_filter, 3, activation="PReLU", padding="same", use_bias=True)(x03)
x05 = spec(x04,num_filter)
x05 = spat(x05,num_filter)
x06 = layers.concatenate([x01, x04, x05])

# Stage 3: 5x5 conv; x07 is rebound to the spectral-block output and x08 is
# the spatial-block output -- both are used by the fusion concat below.
x07 = layers.Conv2D(num_filter, 5, activation="PReLU", padding="same", use_bias=True)(x06)
x07 = spec(x07,num_filter)
x08 = spat(x07,num_filter)

# ****************************** LR_HSI Encoder ******************************
# Mirrors the MSI encoder layer-for-layer (separate weights).
lr_input = keras.Input(shape=(hrsize, hrsize,hsi_bands), name="lr_input")
x11 = layers.Conv2D(num_filter, 3, activation="PReLU", padding="same", use_bias=True)(lr_input)
x12 = spec(x11,num_filter)
x12 = spat(x12,num_filter)
x13 = layers.concatenate([x11, x12])

x14 = layers.Conv2D(num_filter, 3, activation="PReLU", padding="same", use_bias=True)(x13)
x15 = spec(x14,num_filter)
x15 = spat(x15,num_filter)
x16 = layers.concatenate([x11, x14, x15])

x17 = layers.Conv2D(num_filter, 5, activation="PReLU", padding="same", use_bias=True)(x16)
x17 = spec(x17,num_filter)
x18 = spat(x17,num_filter)

# ********************************* concat ***********************************
# Fuse 8 feature maps (4 per encoder, num_filter channels each).
x21 = layers.concatenate([x01, x04, x07, x08, x11, x14, x17, x18])

# Two 3x3 convs decode to hsi_bands channels; the output layer is named
# "fuse_output" so fit() can address it by name.
x22 = layers.Conv2D(hsi_bands, 3, activation="PReLU", padding="same", use_bias=True)(x21)
fuse_output = layers.Conv2D(hsi_bands, 3, activation="PReLU", padding="same", use_bias=True, name="fuse_output")(x22)

# Input order is [msi_input, lr_input]; positional (non-dict) inputs must follow it.
model = keras.Model(
    inputs=[msi_input, lr_input],
    outputs=[fuse_output],
)
model.summary()
# NOTE(review): plot_model typically requires pydot and graphviz to be installed -- confirm in this environment.
keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)



# Training BASFE on CAVE-style Dataset

In [None]:
def _keras_save_path(path):
    """Return `path`, appending '.keras' unless it already ends in '.keras' or '.h5'.

    Keras 3 (bundled with TensorFlow >= 2.16) raises a ValueError from
    `model.save()` when the target has neither extension, so the
    extension-less paths in CONFIG would otherwise crash this cell.
    """
    return path if path.endswith(('.keras', '.h5')) else path + '.keras'


# Save initial (untrained) model
init_save_path = _keras_save_path(CONFIG['SAVE_MODEL_PATH'])
model.save(init_save_path)
print('Saved untrained model to', init_save_path)

# Callbacks for richer logging.
# NOTE: no validation split is used, so checkpointing and early stopping both
# monitor the *training* loss.
callbacks = [
    keras.callbacks.CSVLogger(CONFIG['CSV_LOG'], append=False),
    keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CONFIG['CHECKPOINT_DIR'], 'epoch_{epoch:03d}_loss_{loss:.5f}.keras'),
        monitor='loss', save_best_only=True, save_weights_only=False, verbose=1),
    keras.callbacks.EarlyStopping(monitor='loss', patience=CONFIG['EARLY_STOP_PATIENCE'], restore_best_weights=True, verbose=1),
    keras.callbacks.TensorBoard(log_dir=CONFIG['LOG_DIR'], write_graph=False, update_freq='epoch'),
]

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']),
    loss=keras.losses.MeanSquaredError(),
)

print('Starting training: epochs', CONFIG['EPOCHS'], 'batch size', CONFIG['BATCH_SIZE'])
start_train = time.time()
history = model.fit(
    {"msi_input": mrdata, "lr_input": lrdata},
    {"fuse_output": hrdata},
    epochs=CONFIG['EPOCHS'],
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=True,
    verbose=1,
    callbacks=callbacks,
)
train_time = time.time()-start_train
print(f'Training completed in {train_time/60:.2f} min')

trained_save_path = _keras_save_path(CONFIG['SAVE_MODEL_TRAINED_PATH'])
model.save(trained_save_path)
print('Saved trained model to', trained_save_path)

losses = history.history.get('loss', [])
if losses:
    plt.figure(figsize=(6,4))
    plt.plot(losses, label='loss')
    plt.xlabel('Epoch'); plt.ylabel('MSE Loss'); plt.title('Training Loss Curve'); plt.grid(True); plt.legend()
    plt.tight_layout(); plt.show()

# Summary JSON. best_loss is recorded alongside final_loss because
# EarlyStopping(restore_best_weights=True) may roll the weights back to the
# best epoch, in which case the saved model corresponds to best_loss.
summary_info = {
    'epochs_ran': len(losses),
    'final_loss': float(losses[-1]) if losses else None,
    'best_loss': float(min(losses)) if losses else None,
    'train_time_sec': train_time,
    'num_patches': int(hrdata.shape[0]),
    'scale': int(scale),
    'hr_patch_size': int(hrsize),
    'bands_hsi': int(hsi_bands),
    'bands_msi': int(msi_bands),
    'model_path': trained_save_path,
}
with open(os.path.join(CONFIG['RESULTS_DIR'], 'training_summary.json'), 'w') as f:
    json.dump(summary_info, f, indent=2)
print('Saved training summary to training_summary.json')

# Reconstruction & Assessment on Test Scenes

In [None]:
# Prepare test inputs and run model prediction over each scene via tiling reconstruction with progress indicators.

def prepare_test_scene(lr_path, msi_path):
    """Load one test scene and bring its LR-HSI onto the HR-MSI pixel grid.

    Returns:
        (up_lr, msi_cube): the LR-HSI cube bicubically resized to the MSI's
        height and width, and the MSI cube (restricted to
        CONFIG['MSI_BANDS_SELECT'] when that setting is non-empty).
    """
    lr_cube = load_first_cube(lr_path)[0]
    msi_cube = load_first_cube(msi_path)[0]
    band_sel = CONFIG['MSI_BANDS_SELECT']
    if band_sel:
        msi_cube = msi_cube[:, :, band_sel]
    target_h, target_w = msi_cube.shape[:2]
    # cv.resize expects dsize as (width, height).
    up_lr = cv.resize(lr_cube, (target_w, target_h), interpolation=cv.INTER_CUBIC)
    return up_lr, msi_cube

TEST_LR = list_mats(CONFIG['TEST_LR_HSI_DIR'])
TEST_MSI = list_mats(CONFIG['TEST_HR_MSI_DIR'])
if CONFIG['MAX_TEST_SCENES']:
    TEST_LR = TEST_LR[:CONFIG['MAX_TEST_SCENES']]
    TEST_MSI = TEST_MSI[:CONFIG['MAX_TEST_SCENES']]
if len(TEST_LR) != len(TEST_MSI):
    raise ValueError('Mismatch in test file counts')

# Tiles of hrsize pixels are placed every `strider` pixels; only their inner
# region (EDGE pixels trimmed on each side, where "same" padding affects the
# prediction) is kept, so consecutive inner regions abut exactly.
EDGE = CONFIG['EDGE_OVERLAP']
strider = hrsize - 2*EDGE
if EDGE < 0 or strider <= 0:
    raise ValueError(f'EDGE_OVERLAP={EDGE} must satisfy 0 <= 2*EDGE < PATCH_HR_SIZE={hrsize}')


def _tile_starts(length, size, step):
    """Return window start offsets covering [0, length) with `size`-wide windows every `step`.

    Every start lies in [0, length - size], so each window fits entirely inside
    the image; the final window is placed flush with the far edge.

    Raises:
        ValueError: if the image is smaller than one window.
    """
    if length < size:
        raise ValueError(f'Image extent {length} is smaller than the patch size {size}')
    starts = list(range(0, length - size + 1, step))
    if starts[-1] != length - size:
        starts.append(length - size)
    return starts


reconstructed = {}

for lr_path, msi_path in tqdm(list(zip(TEST_LR, TEST_MSI)), desc='Scenes (test)'):
    scene_name = os.path.splitext(os.path.basename(lr_path))[0]
    msi_name = os.path.splitext(os.path.basename(msi_path))[0]
    if msi_name != scene_name:
        print(f'WARNING: pairing LR-HSI {scene_name!r} with HR-MSI {msi_name!r} (paired by sorted order).')
    up_lr, hr_msi = prepare_test_scene(lr_path, msi_path)
    H, W, _ = hr_msi.shape
    ii = _tile_starts(H, hrsize, strider)
    jj = _tile_starts(W, hrsize, strider)
    coords = [(i, j) for i in ii for j in jj]

    mrdatainput = np.stack([hr_msi[i:i+hrsize, j:j+hrsize, :] for i, j in coords], axis=0)
    lrdatainput = np.stack([up_lr[i:i+hrsize, j:j+hrsize, :] for i, j in coords], axis=0)

    # Feed inputs by name, as in model.fit, rather than relying on positional order.
    preds = model.predict({"msi_input": mrdatainput, "lr_input": lrdatainput}, verbose=0)

    out_cube = np.zeros((H, W, hsi_bands), dtype=np.float32)
    # Pass 1: paste whole tiles so the outer EDGE-pixel image border, which no
    # tile interior reaches, is still filled.
    for (i, j), tile in zip(coords, preds):
        out_cube[i:i+hrsize, j:j+hrsize, :] = tile
    # Pass 2: overwrite with tile interiors. The explicit end index
    # (hrsize - EDGE) keeps EDGE == 0 working; `EDGE:-EDGE` would be empty then.
    inner = slice(EDGE, hrsize - EDGE)
    for (i, j), tile in zip(coords, preds):
        out_cube[i+EDGE:i+hrsize-EDGE, j+EDGE:j+hrsize-EDGE, :] = tile[inner, inner, :]

    reconstructed[scene_name] = out_cube
    savemat(os.path.join(CONFIG['RESULTS_DIR'], f'{scene_name}_reconst.mat'), {f'reconst_{scene_name}': out_cube})
    print(f'Reconstructed {scene_name}: shape {out_cube.shape}')

print('Total test scenes reconstructed:', len(reconstructed))

In [None]:
# Metrics (per-scene if GT available) + consolidated logging
from math import log10

metrics_results = {}


def compute_metrics(pred, gt, scale, c1=1e-4, c2=1e-4):
    """Compute RMSE, PSNR, SAM, ERGAS, global SSIM and CC between two cubes.

    Args:
        pred, gt: Arrays of identical shape (H, W, L) -- reconstruction and
            reference on the same grid. PSNR uses a peak value of 1.0, i.e.
            assumes data in [0, 1].
        scale: Spatial resolution ratio HR / LR, used by ERGAS (100 / scale * ...).
        c1, c2: SSIM stabilising constants. The defaults keep this notebook's
            historical values; the common SSIM choice for a unit data range is
            c1 = 0.01**2 = 1e-4 and c2 = 0.03**2 = 9e-4.

    Returns:
        dict with float values for 'RMSE', 'PSNR', 'SAM_deg', 'ERGAS', 'MSSIM'
        and 'CC'. 'MSSIM' is the band-average of one *global* SSIM per band
        (whole-band statistics, no sliding window); 'CC' is the band-average
        Pearson correlation.

    Raises:
        ValueError: if the inputs are not equal-shaped 3-D cubes or are empty.
    """
    # float64: the reconstructions are float32, and the E[x^2] - E[x]^2
    # variance terms below can lose precision to cancellation in float32.
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if pred.ndim != 3 or pred.shape != gt.shape:
        raise ValueError(f'pred and gt must be 3-D cubes of equal shape, got {pred.shape} and {gt.shape}')
    n_pix = pred.shape[0] * pred.shape[1]
    n_bands = pred.shape[2]
    if n_pix == 0 or n_bands == 0:
        raise ValueError(f'Cannot compute metrics on an empty cube of shape {pred.shape}')

    spatial = (0, 1)
    mse_per_band = np.sum((pred - gt) ** 2, axis=spatial) / n_pix
    rmse_per_band = np.sqrt(mse_per_band)
    rmse_total = np.sqrt(np.sum(mse_per_band) / n_bands)
    psnr = 10 * log10(1.0 / (rmse_total ** 2 + 1e-12))

    # SAM: mean per-pixel angle between spectra. Because of the 1e-12 guard, a
    # pixel where either spectrum is all zeros counts as 90 degrees (not NaN).
    num = np.sum(pred * gt, axis=2)
    den = np.sqrt(np.sum(pred * pred, axis=2) * np.sum(gt * gt, axis=2)) + 1e-12
    sam = np.mean(np.arccos(np.clip(num / den, -1, 1))) * 180 / np.pi

    mean_p = np.sum(pred, axis=spatial) / n_pix
    mean_gt = np.sum(gt, axis=spatial) / n_pix
    ergas = 100 / scale * np.sqrt(np.sum((rmse_per_band / (mean_gt + 1e-12)) ** 2) / n_bands)

    sigma2_p = np.mean(pred ** 2, axis=spatial) - mean_p ** 2
    sigma2_g = np.mean(gt ** 2, axis=spatial) - mean_gt ** 2
    sigma_pg = np.mean(pred * gt, axis=spatial) - mean_p * mean_gt
    ssim = ((2 * mean_p * mean_gt + c1) * (2 * sigma_pg + c2)) / (
        (mean_p ** 2 + mean_gt ** 2 + c1) * (sigma2_p + sigma2_g + c2))

    dev_p = pred - mean_p
    dev_g = gt - mean_gt
    cc = np.sum(dev_p * dev_g, axis=spatial) / (
        np.sqrt(np.sum(dev_p ** 2, axis=spatial) * np.sum(dev_g ** 2, axis=spatial)) + 1e-12)

    return {
        'RMSE': float(rmse_total),
        'PSNR': float(psnr),
        'SAM_deg': float(sam),
        'ERGAS': float(ergas),
        'MSSIM': float(np.mean(ssim)),
        'CC': float(np.mean(cc)),
    }


if CONFIG['TEST_GT_HR_HSI_DIR']:
    gt_files = list_mats(CONFIG['TEST_GT_HR_HSI_DIR'])
    # GT files are matched to reconstructions by file stem (scene name).
    gt_map = {os.path.splitext(os.path.basename(f))[0]: f for f in gt_files}

    for scene, pred in reconstructed.items():
        if scene not in gt_map:
            print(f'No GT file for {scene}; skipping metrics.')
            continue
        # NOTE: load_first_cube min-max normalises the GT per file, so metrics
        # are measured against that normalised reference.
        gt_cube,_ = load_first_cube(gt_map[scene])
        # Compare only the rows, columns and bands both cubes have, so a GT
        # with fewer bands than the prediction is still comparable.
        H = min(gt_cube.shape[0], pred.shape[0])
        W = min(gt_cube.shape[1], pred.shape[1])
        B = min(gt_cube.shape[2], pred.shape[2])
        if gt_cube.shape != pred.shape:
            print(f'WARNING: {scene}: GT shape {gt_cube.shape} != prediction shape {pred.shape}; '
                  f'using the common {H}x{W}x{B} region.')
        metrics_results[scene] = compute_metrics(pred[:H, :W, :B], gt_cube[:H, :W, :B], scale)
        print('Metrics', scene, metrics_results[scene])

    if metrics_results:
        metric_names = next(iter(metrics_results.values())).keys()
        mean_metrics = {k: float(np.mean([m[k] for m in metrics_results.values()])) for k in metric_names}
        print(f'Mean over {len(metrics_results)} scene(s):', mean_metrics)
else:
    print('No GT directory configured for test metrics; skipping metric computation.')

# Save metrics JSON if any
os.makedirs(CONFIG['RESULTS_DIR'], exist_ok=True)
if metrics_results:
    with open(os.path.join(CONFIG['RESULTS_DIR'], 'metrics.json'), 'w') as f:
        json.dump(metrics_results, f, indent=2)
    print('Saved metrics.json')
print('Done.')