# Notebook 2: Applying Models to Satellite Imagery

This notebook demonstrates how to load pre-trained Sargassum detection models and apply them to Landsat-8 and Sentinel-2 scenes. It mirrors `2_classify.py` (TFLite) and `2_classify_ML.py` (XGBoost GPU).

**Workflow:**
1. Environment Setup
2. Configuration — choose which models to run
3. Define Core Processing Functions
4. Load Models & Artefacts
5. Classify Scenes (Landsat-8 + Sentinel-2)
6. Visualize Fractional Cover Maps

## 1. Environment Setup & Dependencies

```bash
pip install -r requirements.txt
```

**Key Library Versions:**
- `numpy==1.26.4`
- `scikit-learn==1.5.2`
- `xgboost==2.1.4`
- `tensorflow==2.17.0`
- `rasterio==1.4.3`
- `opencv-python==4.10.0.84`

In [None]:
# --- Standard Library Imports ---
import os
import glob
import math
import time
import joblib
import traceback

# --- Third-party Library Imports ---
import numpy as np
import rasterio
import rasterio.windows
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm

try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False
    print('Warning: OpenCV not found. Multi-resolution resampling will be unavailable.')

# --- TFLite (via TensorFlow) ---
# Silence TensorFlow's C++ INFO/WARNING logs; must be set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
TF_AVAILABLE = False
try:
    import tensorflow as tf
except ImportError:
    print('Warning: TensorFlow not found. TFLite models will be skipped.')
else:
    # Allocate GPU memory on demand instead of reserving it all up front.
    # TensorFlow raises RuntimeError if the devices were already initialised
    # (e.g. when this cell is re-run); that is harmless, so report and continue.
    try:
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f'Note: could not enable GPU memory growth ({e}).')
    TF_AVAILABLE = True
    print('TensorFlow loaded successfully.')

# --- XGBoost ---
XGB_AVAILABLE = False
try:
    import xgboost as xgb
except ImportError:
    print('Warning: XGBoost not found. XGBoost model will be skipped.')
else:
    XGB_AVAILABLE = True
    print('XGBoost loaded successfully.')

print('All libraries loaded.')

## 2. Configuration

Specify which models to run. TFLite models require TensorFlow; XGBoost models require XGBoost.
Edit `TFLITE_MODELS` and `XGBOOST_MODELS` to select which files to load.

In [None]:
# --- Select models to run ---
# Remove (or comment out) entries to skip individual models.
# Available TFLite files:
#   'mlp_classifier_int8.tflite', 'mlp_classifier_f16.tflite',
#   'cnn_classifier_int8.tflite', 'cnn_classifier_f16.tflite'
TFLITE_MODELS = [
    'mlp_classifier_int8.tflite',
    'mlp_classifier_f16.tflite',
    'cnn_classifier_int8.tflite',
    'cnn_classifier_f16.tflite',
]

# Available XGBoost files: 'xgboost_gpu_classifier.joblib'
XGBOOST_MODELS = ['xgboost_gpu_classifier.joblib']

# --- Static paths and options (normally left unchanged) ---
MODEL_ROOT_DIR = 'output/models_classification/'
DATA_ROOT_DIR = 'satellite_data/'
OUTPUT_DIR = 'output/fractional_cover_maps/'
# Class name, as stored in the label encoder, whose index selects the predicted class.
POSITIVE_CLASS_LABEL = 'sargassum'
# Set to True to skip the Landsat-8 -> Sentinel-2 reflectance harmonisation step.
DISABLE_L8_HARMONIZATION = False

os.makedirs(OUTPUT_DIR, exist_ok=True)
for caption, setting in (
    ('TFLite models ', TFLITE_MODELS),
    ('XGBoost models', XGBOOST_MODELS),
    ('Output dir    ', OUTPUT_DIR),
):
    print(f'{caption}: {setting}')

## 3. Define Core Processing Functions

Helper functions for loading the training artefacts, locating each scene's band files, and running block-by-block (windowed) inference.

In [None]:
# ── Sensor Configuration ──────────────────────────────────────────────────────
# Order of the feature columns fed to the scaler and the models: column i of the
# feature matrix built in classify_scene() holds the band mapped to
# TRAINING_BAND_ORDER[i].
TRAINING_BAND_ORDER = ['Blue', 'Green', 'Red', 'NIR', 'SWIR1']

# Per-band linear adjustment (slope * reflectance + intercept) applied to
# Landsat-8 reflectance after DN scaling, unless DISABLE_L8_HARMONIZATION is set.
# NOTE(review): the origin of these coefficients is not recorded here; presumably
# a published Landsat-8 -> Sentinel-2 cross-calibration — confirm it matches the
# harmonisation used when the training data were built.
L8_TO_S2_HARMONIZATION_COEFFS = {
    'Blue':  {'slope': 0.9778, 'intercept': 0.0048},
    'Green': {'slope': 1.0379, 'intercept': -0.0009},
    'Red':   {'slope': 1.0431, 'intercept': -0.0011},
    'NIR':   {'slope': 0.9043, 'intercept': 0.0040},
    'SWIR1': {'slope': 0.9872, 'intercept': -0.0001},
}

# Per-sensor settings read by find_scene_band_files() and classify_scene():
#   bands_needed         sensor band id -> native pixel size (m)
#   target_resolution    pixel size (m) of the output grid; bands whose native
#                        size differs are resampled onto it
#   band_map_to_train    sensor band id -> name in TRAINING_BAND_ORDER
#   scale, offset        DN -> reflectance as DN * scale + offset
#   file_patterns        glob patterns, relative to the scene folder, per band
#                        (searched recursively when a pattern contains '**')
#   harmonization_coeffs Landsat-8 only; see L8_TO_S2_HARMONIZATION_COEFFS
SENSOR_CONFIG = {
    'Landsat-8': {
        'name': 'Landsat-8',
        'bands_needed': {'B2': 30, 'B3': 30, 'B4': 30, 'B5': 30, 'B6': 30},
        'target_resolution': 30,
        'band_map_to_train': {'B2': 'Blue', 'B3': 'Green', 'B4': 'Red', 'B5': 'NIR', 'B6': 'SWIR1'},
        # NOTE(review): these values look like the USGS Collection-2 Level-2
        # surface-reflectance scaling, while the example scene in Section 5 is an
        # L1GT product — confirm the matched *_B?.TIF files are L2 SR bands.
        'scale': 0.0000275, 'offset': -0.2,
        'file_patterns': {
            'B2': ['*_B2.TIF'], 'B3': ['*_B3.TIF'], 'B4': ['*_B4.TIF'],
            'B5': ['*_B5.TIF'], 'B6': ['*_B6.TIF'],
        },
        'harmonization_coeffs': L8_TO_S2_HARMONIZATION_COEFFS,
    },
    'Sentinel-2': {
        'name': 'Sentinel-2',
        'bands_needed': {'B02': 10, 'B03': 10, 'B04': 10, 'B08': 10, 'B8A': 20, 'B11': 20},
        'target_resolution': 10,
        # B08 (10 m) and B8A (20 m) both map to 'NIR'; classify_scene() keeps only
        # the first one present in this dict's order, i.e. B08 whenever it was found.
        'band_map_to_train': {
            'B02': 'Blue', 'B03': 'Green', 'B04': 'Red',
            'B08': 'NIR',  'B8A': 'NIR',  'B11': 'SWIR1',
        },
        # NOTE(review): Level-2A products from processing baseline 04.00 onward
        # (e.g. the N0500 example scene) presumably store reflectance with a
        # -1000 DN offset (BOA_ADD_OFFSET); with offset 0.0 such reflectances
        # would be 0.1 too high. Confirm against how training reflectances were scaled.
        'scale': 0.0001, 'offset': 0.0,
        'file_patterns': {
            'B02': ['**/R10m/*_B02_10m.jp2'], 'B03': ['**/R10m/*_B03_10m.jp2'],
            'B04': ['**/R10m/*_B04_10m.jp2'], 'B08': ['**/R10m/*_B08_10m.jp2'],
            'B8A': ['**/R20m/*_B8A_20m.jp2'], 'B11': ['**/R20m/*_B11_20m.jp2'],
        },
    },
}
# Fill value and pixel type of the output GeoTIFFs written by classify_scene().
NODATA_VALUE = -9999.0
OUTPUT_DTYPE = np.float32


# ── Helper Functions ──────────────────────────────────────────────────────────
def load_scaler_and_encoder(scaler_path, label_encoder_path, positive_class_label):
    """Load the feature scaler and label encoder saved at training time.

    Parameters
    ----------
    scaler_path : str
        Path to the joblib-serialised feature scaler.
    label_encoder_path : str
        Path to the joblib-serialised label encoder.
    positive_class_label : str
        Class name whose encoded index is returned (passed through ``str``).

    Returns
    -------
    tuple
        ``(scaler, label_encoder, positive_index)`` with ``positive_index`` a
        plain ``int``. On failure the missing parts are ``None`` and the index
        is ``-1``: ``(None, None, -1)`` if the scaler cannot be loaded, and
        ``(scaler, None, -1)`` if the encoder cannot be loaded or cannot
        encode ``positive_class_label``.
    """
    try:
        # joblib.load unpickles the file: only load artefacts you trust.
        scaler = joblib.load(scaler_path)
    except Exception as e:
        print(f'Error loading scaler: {e}')
        return None, None, -1
    print(f'Scaler loaded: {scaler_path}')

    try:
        le = joblib.load(label_encoder_path)
    except Exception as e:
        print(f'Error loading LabelEncoder: {e}')
        return scaler, None, -1
    print(f'LabelEncoder loaded: {label_encoder_path}')

    try:
        pos_idx = int(le.transform([str(positive_class_label)])[0])
    except Exception as e:
        # e.g. the label is not one of the encoder's classes; report which exist.
        known = list(getattr(le, 'classes_', []))
        print(f'Error: positive class "{positive_class_label}" cannot be encoded '
              f'(known classes: {known}): {e}')
        return scaler, None, -1
    print(f'Positive class "{positive_class_label}" -> index {pos_idx}')
    return scaler, le, pos_idx


def find_scene_band_files(scene_path, scene_id, sensor_config):
    """Locate one file per configured band inside a scene folder.

    For each band in ``sensor_config['file_patterns']`` the glob patterns are
    tried in order (recursively when a pattern contains ``**``); the first
    pattern with matches wins and its alphabetically first match is used.

    Returns
    -------
    dict or None
        ``{band_id: file_path}`` when Blue, Green, Red, NIR and SWIR1 are all
        covered through ``sensor_config['band_map_to_train']``; otherwise
        ``None`` after printing the missing concepts.
    """
    band_files = {}
    for band_id, patterns in sensor_config['file_patterns'].items():
        for pattern in patterns:
            matches = glob.glob(os.path.join(scene_path, pattern), recursive='**' in pattern)
            if matches:
                band_files[band_id] = min(matches)
                break

    concept_of = sensor_config['band_map_to_train']
    found_concepts = {concept_of[b] for b in band_files if b in concept_of}
    missing = {'Blue', 'Green', 'Red', 'SWIR1', 'NIR'} - found_concepts
    if missing:
        print(f'Error: Missing bands for {scene_id}: {missing}')
        return None
    return band_files


def _nodata_free(values, nodata):
    """Return a boolean mask that is True where *values* is not the band's nodata value.

    Values are truncated to int64 before the comparison, so resampled floats
    that truncate to the nodata integer also count as nodata. A NaN nodata
    value is handled explicitly, because NaN never compares equal and
    ``int(nan)`` raises.
    """
    if np.isnan(nodata):
        return ~np.isnan(values)
    return values.astype(np.int64) != int(nodata)


def _read_window_bands(window, scene_band_map, full_bands, full_nodata, sensor_cfg, target_res):
    """Cut one output window out of every in-RAM band and resample it to the target grid.

    Returns ``(bands_raw, valid_mask)``. ``bands_raw`` maps each sensor band id
    to an array of the window's (height, width). ``valid_mask`` is True where no
    band holds its nodata value. Returns ``(None, None)`` if any band could not
    be produced; this includes coarser bands when OpenCV is unavailable.
    """
    wh, ww = window.height, window.width
    bands_raw, valid_mask = {}, None
    for s_bid in scene_band_map:
        try:
            ratio = sensor_cfg['bands_needed'][s_bid] / target_res
            if ratio == 1.0:
                native = full_bands[s_bid][window.row_off:window.row_off + wh,
                                           window.col_off:window.col_off + ww]
            else:
                # Source pixels covering the window on the coarser native grid.
                # NOTE(review): when the window offsets are not multiples of `ratio`,
                # the ceil-ed source slice is stretched onto the window, which shifts
                # it by a fraction of a pixel — confirm this is acceptable.
                rows = slice(int(window.row_off // ratio),
                             int(math.ceil((window.row_off + wh) / ratio)))
                cols = slice(int(window.col_off // ratio),
                             int(math.ceil((window.col_off + ww) / ratio)))
                native = full_bands[s_bid][rows, cols]
                if not CV2_AVAILABLE or native.size == 0:
                    return None, None
                native = cv2.resize(native.astype(np.float32), (ww, wh),
                                    interpolation=cv2.INTER_LANCZOS4)
            mask = _nodata_free(native, full_nodata[s_bid])
        except Exception as e:
            print(f'\n    Could not read band {s_bid} for {window}: {e}')
            return None, None
        bands_raw[s_bid] = native
        valid_mask = mask if valid_mask is None else (valid_mask & mask)
    return bands_raw, valid_mask


def _window_features(bands_raw, valid_mask, scene_band_map, concept_to_band,
                     scale, offset, harmonization):
    """Build the float64 feature matrix (n_valid_pixels, len(TRAINING_BAND_ORDER)).

    Each band is converted to reflectance as ``DN * scale + offset``. If
    ``harmonization`` has coefficients for the band's concept, the reflectance
    is then mapped through ``slope * x + intercept``. Columns follow
    TRAINING_BAND_ORDER.
    """
    spectra = {}
    for s_bid, raw in bands_raw.items():
        refl = raw[valid_mask].astype(np.float64) * scale + offset
        coeffs = harmonization.get(scene_band_map[s_bid])
        if coeffs is not None:
            refl = coeffs['slope'] * refl + coeffs['intercept']
        spectra[s_bid] = refl

    feats = np.zeros((int(valid_mask.sum()), len(TRAINING_BAND_ORDER)), dtype=np.float64)
    for col, concept in enumerate(TRAINING_BAND_ORDER):
        s_bid = concept_to_band.get(concept)
        if s_bid is not None and s_bid in spectra:
            feats[:, col] = spectra[s_bid]
    return feats


def _tflite_predict(interp, feats_scaled):
    """Run a TFLite interpreter on a (n_pixels, n_features) float feature matrix.

    The input is reshaped to (n, n_features, 1) when the model input is rank 3
    with a trailing dimension of 1. The batch dimension is resized when it
    differs. Integer (quantised) input tensors receive
    ``round(x / scale + zero_point)`` instead of a plain truncating cast, and
    integer outputs are de-quantised with ``(q - zero_point) * scale``.
    Returns the flattened output.
    """
    inp_d = interp.get_input_details()[0]
    inp_shape = [int(s) for s in inp_d['shape']]
    data = feats_scaled
    if len(inp_shape) == 3 and inp_shape[-1] == 1:
        data = data.reshape(data.shape[0], data.shape[1], 1)
    batch = data.shape[0]
    if inp_shape[0] != batch:
        interp.resize_tensor_input(inp_d['index'], [batch] + inp_shape[1:])
        interp.allocate_tensors()
        inp_d = interp.get_input_details()[0]

    in_dtype = np.dtype(inp_d['dtype'])
    if np.issubdtype(in_dtype, np.integer):
        in_scale, in_zero = inp_d['quantization']
        if in_scale:
            data = np.round(data / in_scale + in_zero)
        limits = np.iinfo(in_dtype)
        data = np.clip(data, limits.min, limits.max)
    interp.set_tensor(inp_d['index'], data.astype(in_dtype))
    interp.invoke()

    out_d = interp.get_output_details()[0]
    out = interp.get_tensor(out_d['index'])
    if np.issubdtype(np.dtype(out_d['dtype']), np.integer):
        out_scale, out_zero = out_d['quantization']
        if out_scale:
            out = (out.astype(np.float32) - out_zero) * out_scale
    return out.flatten()


def classify_scene(scene_id, band_files_dict, sensor_cfg, model_info,
                   scaler, positive_class_index, output_dir):
    """Classify a single scene with a single model and write a GeoTIFF.

    All bands are read fully into RAM first. The output is then produced block
    by block, following the block windows of the reference band: the first
    ``bands_needed`` entry at the target resolution. Bands with a different
    native resolution are resampled with Lanczos interpolation, which needs
    OpenCV. Pixels that are nodata in any band, have NaN features, or fall in a
    block that failed are written as ``NODATA_VALUE``. If processing is
    interrupted, the partial output file is deleted so the next run does not
    skip it.

    Parameters
    ----------
    scene_id : str
        Scene identifier, used in the output filename and in log messages.
    band_files_dict : dict
        ``{sensor_band_id: file_path}`` as returned by ``find_scene_band_files``.
    sensor_cfg : dict
        One entry of ``SENSOR_CONFIG``.
    model_info : dict
        ``{'name_prefix': str, 'model_obj': ..., 'type': 'tflite' | 'xgboost'}``.
    scaler : object
        Fitted feature scaler exposing ``transform``.
    positive_class_index : int
        Column of ``predict_proba`` used for XGBoost models. It is ignored for
        TFLite models, whose flattened output is written directly.
    output_dir : str
        Directory for the output GeoTIFF.

    Returns
    -------
    str
        Path of the output GeoTIFF. If the file already exists, the path is
        returned without reprocessing.
    """
    sensor_name = sensor_cfg['name']
    target_res = sensor_cfg['target_resolution']
    model_obj = model_info['model_obj']
    model_type = model_info['type']
    name_prefix = model_info['name_prefix']

    out_file = f'fractional_cover_{sensor_name}_{scene_id}_{target_res}m_{name_prefix}.tif'
    out_path = os.path.join(output_dir, out_file)

    if os.path.exists(out_path):
        print(f'    Output exists, skipping: {out_path}')
        return out_path

    # Sensor band -> training concept, keeping only the first NIR band present
    # (several sensor bands may map to 'NIR', e.g. Sentinel-2 B08 and B8A).
    scene_band_map, has_nir = {}, False
    for s_band, concept in sensor_cfg['band_map_to_train'].items():
        if s_band not in band_files_dict:
            continue
        if concept == 'NIR':
            if has_nir:
                continue
            has_nir = True
        scene_band_map[s_band] = concept
    # Training concept -> sensor band supplying it (first match wins).
    concept_to_band = {}
    for s_band, concept in scene_band_map.items():
        concept_to_band.setdefault(concept, s_band)

    # Reference band for the output grid and block windows
    ref_bid = next(b for b, r in sensor_cfg['bands_needed'].items()
                   if r == target_res and b in band_files_dict)

    with rasterio.open(band_files_dict[ref_bid]) as src:
        profile = src.profile.copy()
        for key in ('blockxsize', 'blockysize', 'tiled'):
            profile.pop(key, None)
        profile.update(dtype=OUTPUT_DTYPE, count=1, nodata=NODATA_VALUE, compress='lzw', driver='GTiff')
        w, h = src.width, src.height
        windows = [win for _, win in src.block_windows(1)]
    print(f'\n  Scene: {scene_id} | Model: {name_prefix} | Grid: {w}x{h}px')

    # Pre-load all bands into RAM
    full_bands, full_nodata = {}, {}
    for s_bid in scene_band_map:
        with rasterio.open(band_files_dict[s_bid]) as src:
            full_bands[s_bid] = src.read(1)
            # A missing nodata value is treated as 0; kept as float so NaN nodata works.
            full_nodata[s_bid] = 0.0 if src.nodata is None else float(src.nodata)
    print('    All bands loaded into RAM.')

    scale, offset = sensor_cfg['scale'], sensor_cfg['offset']
    harmonization = {}
    if sensor_name == 'Landsat-8' and not DISABLE_L8_HARMONIZATION:
        harmonization = sensor_cfg.get('harmonization_coeffs', {})

    t_pred, n_pix = 0.0, 0
    try:
        with rasterio.open(out_path, 'w', **profile) as dst:
            for window in tqdm(windows, desc=f'  Classifying ({name_prefix})', unit='block', leave=False):
                if window.height == 0 or window.width == 0:
                    continue
                out_blk = np.full((window.height, window.width), NODATA_VALUE, dtype=OUTPUT_DTYPE)

                bands_raw, valid_mask = _read_window_bands(
                    window, scene_band_map, full_bands, full_nodata, sensor_cfg, target_res)
                if valid_mask is None or not valid_mask.any():
                    dst.write(out_blk, 1, window=window)
                    continue

                feats = _window_features(bands_raw, valid_mask, scene_band_map,
                                         concept_to_band, scale, offset, harmonization)
                rows_ok = ~np.isnan(feats).any(axis=1)
                if not rows_ok.any():
                    dst.write(out_blk, 1, window=window)
                    continue

                feats_scaled = scaler.transform(feats[rows_ok].astype(np.float32))
                try:
                    t0 = time.perf_counter()
                    if model_type == 'tflite':
                        preds = _tflite_predict(model_obj, feats_scaled)
                    elif model_type == 'xgboost':
                        preds = model_obj.predict_proba(feats_scaled.astype(np.float32))[:, positive_class_index]
                    else:
                        raise ValueError(f'Unknown model type: {model_type}')
                    t_pred += time.perf_counter() - t0
                    n_pix += feats_scaled.shape[0]
                except Exception as e:
                    print(f'\n    Prediction error: {e}')
                    traceback.print_exc()
                    dst.write(out_blk, 1, window=window)
                    continue

                # Scatter predictions back: non-NaN rows -> valid pixels -> block.
                per_valid = np.full(feats.shape[0], NODATA_VALUE, dtype=OUTPUT_DTYPE)
                per_valid[rows_ok] = preds
                out_blk[valid_mask] = per_valid
                out_blk[~np.isfinite(out_blk)] = NODATA_VALUE
                dst.write(out_blk, 1, window=window)
    except BaseException:
        # Includes KeyboardInterrupt. A truncated GeoTIFF would be mistaken for
        # a finished map by the existence check above, so remove it.
        if os.path.exists(out_path):
            os.remove(out_path)
        raise

    if n_pix > 0 and t_pred > 0:
        print(f'    Inference: {n_pix:,} px | {n_pix/t_pred:,.0f} px/s | {t_pred/n_pix*1e6:.2f} µs/px')
    print(f'  Saved: {out_path}')
    return out_path

## 4. Load Models & Artefacts

Load the feature scaler, label encoder, and all selected models.

In [None]:
SCALER_PATH       = os.path.join(MODEL_ROOT_DIR, 'scaler_classification.joblib')
LABEL_ENC_PATH    = os.path.join(MODEL_ROOT_DIR, 'label_encoder_classification.joblib')


def _load_tflite_models(file_names):
    """Load each existing TFLite file from MODEL_ROOT_DIR and return model_info dicts."""
    loaded = []
    for fname in file_names:
        path = os.path.join(MODEL_ROOT_DIR, fname)
        if not os.path.exists(path):
            print(f'Skipping (not found): {path}')
            continue
        try:
            interp = tf.lite.Interpreter(model_path=path)
            interp.allocate_tensors()
        except Exception as e:
            print(f'Error loading {fname}: {e}')
            continue
        loaded.append({
            'name_prefix': os.path.splitext(fname)[0],
            'model_obj': interp,
            'type': 'tflite',
        })
        print(f'Loaded TFLite : {fname}')
    return loaded


def _load_xgboost_models(file_names):
    """Load each existing joblib XGBClassifier from MODEL_ROOT_DIR and return model_info dicts."""
    loaded = []
    for fname in file_names:
        path = os.path.join(MODEL_ROOT_DIR, fname)
        if not os.path.exists(path):
            print(f'Skipping (not found): {path}')
            continue
        try:
            # joblib.load unpickles the file: only load model files you trust.
            model = joblib.load(path)
            if not isinstance(model, xgb.XGBClassifier):
                print(f'Skipping {fname}: expected xgboost.XGBClassifier, got {type(model).__name__}.')
                continue
            model.set_params(device='cuda')  # request GPU inference
        except Exception as e:
            print(f'Error loading {fname}: {e}')
            continue
        loaded.append({
            'name_prefix': os.path.splitext(fname)[0],
            'model_obj': model,
            'type': 'xgboost',
        })
        print(f'Loaded XGBoost: {fname}')
    return loaded


scaler, le, positive_class_index = load_scaler_and_encoder(
    SCALER_PATH, LABEL_ENC_PATH, POSITIVE_CLASS_LABEL
)

all_loaded_models = []

if scaler is None:
    # classify_scene() scales every feature matrix with this scaler, so no model
    # can run without it; leave the model list empty instead of failing later.
    print('Error: the feature scaler could not be loaded; no models will be run.')
else:
    # Load TFLite models
    if TF_AVAILABLE:
        all_loaded_models += _load_tflite_models(TFLITE_MODELS)
    else:
        print('Skipping TFLite models (TensorFlow not available).')

    # Load XGBoost models. They index predict_proba by positive_class_index;
    # -1 (encoder failure) would silently select the wrong column.
    if not XGB_AVAILABLE:
        print('Skipping XGBoost models (XGBoost not available).')
    elif positive_class_index < 0:
        print('Skipping XGBoost models (positive class index unavailable from the label encoder).')
    else:
        all_loaded_models += _load_xgboost_models(XGBOOST_MODELS)

print(f'\nTotal models loaded: {len(all_loaded_models)}')

## 5. Classify Satellite Scenes

Run every loaded model on both the example Landsat-8 and Sentinel-2 scenes.
Each model/scene combination produces one GeoTIFF fractional cover map in `output/fractional_cover_maps/`. Maps that already exist are skipped, so delete a file to regenerate it.

In [None]:
LANDSAT_SCENE_ID  = 'LC08_L1GT_016046_20150723_20200908_02_T2'
SENTINEL_SCENE_ID = 'S2B_MSIL2A_20180827T160059_N0500_R097_T16QGH_20230624T195041.SAFE'

# Each entry: (sensor key in SENSOR_CONFIG, scene folder, scene id)
SCENES = [
    ('Landsat-8',  os.path.join(DATA_ROOT_DIR, 'landsat',  LANDSAT_SCENE_ID),  LANDSAT_SCENE_ID),
    ('Sentinel-2', os.path.join(DATA_ROOT_DIR, 'sentinel', SENTINEL_SCENE_ID), SENTINEL_SCENE_ID),
]

# Band discovery depends only on the scene, so do it once rather than once per model.
scene_inputs = []
for sensor_key, scene_path, scene_id in SCENES:
    sensor_cfg = SENSOR_CONFIG[sensor_key]
    band_files = find_scene_band_files(scene_path, scene_id, sensor_cfg)
    if not band_files:
        print(f'  Skipping {scene_id} (band files not found).')
        continue
    scene_inputs.append((scene_id, sensor_cfg, band_files))

all_output_paths = []

for model_info in all_loaded_models:
    print(f'\n===== Model: {model_info["name_prefix"]} ({model_info["type"]}) =====')
    for scene_id, sensor_cfg, band_files in scene_inputs:
        try:
            out = classify_scene(
                scene_id, band_files, sensor_cfg, model_info,
                scaler, positive_class_index, OUTPUT_DIR,
            )
        except Exception as e:
            # Report and carry on with the remaining scene/model combinations.
            print(f'  Failed to classify {scene_id} with {model_info["name_prefix"]}: {e}')
            traceback.print_exc()
            continue
        if out:
            all_output_paths.append(out)

print(f'\nAll done. {len(all_output_paths)} maps generated:')
for p in sorted(all_output_paths):
    print(f'  {p}')

## 6. Visualize Fractional Cover Maps

Display all generated maps. Pixel values are the predicted **probability of Sargassum presence** (0 = no Sargassum, 1 = Sargassum). Nodata pixels are left unpainted, so they appear white with the default Matplotlib style.

In [None]:
def show_cover_map(tif_path, title, ax, cmap='YlOrBr'):
    """Draw band 1 of a fractional-cover GeoTIFF on *ax* and return the image artist.

    Pixels equal to the file's nodata value become NaN and are left unpainted.
    The colour scale is fixed to [0, 1], and the axis frame and ticks are hidden.
    """
    with rasterio.open(tif_path) as src:
        nodata = src.nodata
        data = src.read(1).astype(np.float32)
    if nodata is not None:
        data[data == nodata] = np.nan

    image = ax.imshow(data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
    ax.set_title(title, fontsize=8)
    ax.axis('off')
    return image


tif_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, 'fractional_cover_*.tif')))
print(f'Found {len(tif_files)} fractional cover maps.')

if not tif_files:
    print('No output maps found. Run Section 5 first.')
else:
    # Two maps per row; one panel per GeoTIFF.
    n_cols = 2
    n_rows = math.ceil(len(tif_files) / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 7, n_rows * 5))
    axes_flat = np.ravel(axes)

    for panel_idx, tif_path in enumerate(tif_files):
        panel = axes_flat[panel_idx]
        # Panel title: filename without the common prefix and extension.
        label = os.path.basename(tif_path).replace('fractional_cover_', '').replace('.tif', '')
        try:
            im = show_cover_map(tif_path, label, panel)
            plt.colorbar(im, ax=panel, label='Sargassum probability', shrink=0.85, pad=0.02)
        except Exception as e:
            # Show the error in place of the map so the other panels still render.
            panel.text(0.5, 0.5, f'Error:\n{e}', ha='center', va='center',
                       transform=panel.transAxes, fontsize=8)
            panel.axis('off')

    # Blank out unused panels in the last row.
    for spare in axes_flat[len(tif_files):]:
        spare.axis('off')

    plt.suptitle('Sargassum Fractional Cover Maps', fontsize=14, fontweight='bold', y=1.01)
    plt.tight_layout()
    viz_path = os.path.join(OUTPUT_DIR, 'comparison_visualization.png')
    plt.savefig(viz_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f'Visualization saved to: {viz_path}')

## Conclusion

This notebook demonstrated the full inference workflow:
- **TFLite** models (INT8 / FLOAT16) via the TensorFlow Lite interpreter
- **XGBoost GPU** model via `predict_proba`

To extend the workflow:
- Add your own scenes under `satellite_data/landsat/` or `satellite_data/sentinel/`, and append a matching `(sensor, scene_path, scene_id)` entry to the `SCENES` list.
- Add or remove model filenames from `TFLITE_MODELS` / `XGBOOST_MODELS`.
- The output GeoTIFFs can be opened in QGIS, ArcGIS, or any GIS tool for further analysis.