# Wetland Training Dataset — Smart Spatial Split

**Output:** `wetland_dataset_smart_split.npz`

This notebook avoids **spatial data leakage** (train and test pixels drawn from the same neighbourhood) by holding out a **data-driven test region**
instead of an arbitrary geographic fraction.

## Why a column-based (east/west) split?

Analysis of the labels raster (`bow_river_wetlands_10m_final.tif`) revealed that the
rare wetland classes are geographically concentrated in the **western (left) portion** of the study area:

| Class | Pixel count | Col range |
|-------|------------|----------|
| Class 1 (rare) | 19,225 | cols 1,000 – 6,311 |
| Class 2 | 901,620 | cols 747 – 7,037 |

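A row-based (north/south) split left these classes out of the test region entirely.
This notebook instead holds out all tiles with column offset **< `TEST_COL_MAX`** (8,192 of 31,427 columns, ≈ the left 26% of the map),
which is guaranteed to include every pixel of Classes 1 and 2.

> **Caveat:** by the same token, the training region (column offset ≥ `TEST_COL_MAX`) contains *no* Class 1 or Class 2 pixels,
> so a model trained on this split never sees those classes. Check the per-class counts printed during sampling before training.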

In [None]:
# CELL 1: Setup
print('Setting up environment...')

import os
from google.colab import drive

if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
else:
    print('Drive already mounted')

!pip install -q rasterio tqdm

import numpy as np
import rasterio
from pathlib import Path
from tqdm import tqdm

print('Setup complete!')

In [None]:
# CELL 2: Configuration
print('=' * 70)
print('CONFIGURATION')
print('=' * 70)

labels_file    = '/content/drive/MyDrive/bow_river_wetlands_10m_final.tif'
embeddings_dir = Path('/content/drive/MyDrive/EarthEngine')
output_file    = '/content/drive/MyDrive/wetland_dataset_smart_split.npz'

# Validate inputs first: the labels raster is opened below to size the split.
assert os.path.exists(labels_file), 'Labels TIF not found'
assert embeddings_dir.exists(), 'Embeddings dir not found'

# ─────────────────────────────────────────────────────────────────────
# SMART SPLIT: Hold out tiles whose column offset is LESS than this value.
#
# This value is chosen to fully enclose the rare wetland classes:
#   Class 1: max col = 6,311
#   Class 2: max col = 7,037
#
# TEST_COL_MAX = 8,192 adds ~1,155 col pixels of buffer past Class 2's edge.
# This captures all Class 1 and Class 2 pixels in the test set.
# (≈ 26% of the 31,427-col map; the exact share is computed from the raster below.)
#
# NOTE(review): the flip side is that the TRAIN region (col_offset >= TEST_COL_MAX)
# then contains no Class 1 / Class 2 pixels — assuming the max-col figures above are
# accurate and tiles do not overlap — so train_samples_per_class[1] and [2] cannot be
# met and a model trained on this split never sees those classes.
TEST_COL_MAX = 8192
# ─────────────────────────────────────────────────────────────────────

# Per-class sample budgets
train_samples_per_class = {
    0: 600_000,
    1: 19_225,
    2: 150_000,
    3: 500_000,
    4: 150_000,
    5: 100_000,
}
# Test budget: 25% of the training budget per class, but never fewer than 1,000.
TEST_BUDGET_FRACTION = 0.25
TEST_BUDGET_MIN      = 1000
test_samples_per_class = {
    cls: max(TEST_BUDGET_MIN, int(n * TEST_BUDGET_FRACTION))
    for cls, n in train_samples_per_class.items()
}

# Derive the held-out share of the map from the actual raster width instead of
# hard-coding it (approximate: tile extents may not end exactly at TEST_COL_MAX).
with rasterio.open(labels_file) as labels_src:
    map_width = labels_src.width
test_pct = 100.0 * min(TEST_COL_MAX, map_width) / map_width

print(f'Labels:         {labels_file}')
print(f'Embeddings dir: {embeddings_dir}')
print(f'Output:         {output_file}')
print(f'\nTest region:  tiles with col_offset < {TEST_COL_MAX} (left ~{test_pct:.0f}% of {map_width:,} cols)')
print(f'Train region: tiles with col_offset >= {TEST_COL_MAX} (right ~{100 - test_pct:.0f}%)')
print(f'\nTrain target: {sum(train_samples_per_class.values()):,} samples')
print(f'Test target:  {sum(test_samples_per_class.values()):,} samples')

print('\nConfiguration validated!')

In [None]:
# CELL 3: Discover tiles and split into train / test sets by column
print('=' * 70)
print('DISCOVERING TILES AND SPLITTING (COLUMN-BASED)')
print('=' * 70)

all_tile_files = sorted(embeddings_dir.glob('*.tif'))
print(f'Found {len(all_tile_files)} total tiles')

# Parse row/col offsets from filenames: *-RRRRRRRRRR-CCCCCCCCCC.tif
# NOTE(review): offsets are assumed to be pixel positions in the labels raster's grid —
# confirm the embedding tiles and the labels raster share the same origin and resolution.
tile_info      = []   # (row_off, col_off, path)
unparsed_names = []   # files whose names do not end in -<row>-<col>
for tf in all_tile_files:
    parts = tf.stem.split('-')
    offsets = None
    if len(parts) >= 3:
        try:
            offsets = (int(parts[-2]), int(parts[-1]))
        except ValueError:
            pass
    if offsets is None:
        unparsed_names.append(tf.name)
    else:
        tile_info.append((offsets[0], offsets[1], tf))

print(f'Parseable tiles: {len(tile_info)}')
if unparsed_names:
    print(f'  Ignored {len(unparsed_names)} file(s) with unexpected names, e.g. {unparsed_names[:3]}')
if not tile_info:
    raise RuntimeError(f'No tiles with parseable row/col offsets found in {embeddings_dir}.')

# Split: test = left columns (col_off < TEST_COL_MAX), train = rest
test_tiles  = [p for r, c, p in tile_info if c < TEST_COL_MAX]
train_tiles = [p for r, c, p in tile_info if c >= TEST_COL_MAX]

all_col_offsets = sorted({c for _, c, _ in tile_info})
col_min = all_col_offsets[0]
col_max = all_col_offsets[-1]

print(f'\nColumn range of all tiles: {col_min} — {col_max}')
print(f'TEST_COL_MAX threshold:    {TEST_COL_MAX}')
print(f'Train tiles: {len(train_tiles)}')
print(f'Test tiles:  {len(test_tiles)}')
print(f'Test fraction: {len(test_tiles) / len(tile_info) * 100:.1f}% of all tiles')

# Both splits must be non-empty for the sampling and extraction cells to make sense.
if not test_tiles:
    raise RuntimeError('No test tiles found — increase TEST_COL_MAX or check tile naming.')
if not train_tiles:
    raise RuntimeError('No train tiles found — decrease TEST_COL_MAX or check tile naming.')

print('\nTest tiles (left column band):')
for p in test_tiles:
    print(f'  {p.name}')

print('\nSpatial split defined!')

In [None]:
# CELL 4: Sample pixel coordinates restricted to each split's bounding box
print('\n' + '=' * 70)
print('SAMPLING PIXEL COORDINATES')
print('=' * 70)

# Seeds the global NumPy RNG used below for block order and per-block subsampling
# (and, later, for the shuffle in the consolidation step).
np.random.seed(42)

def sample_coords_from_tiles(tile_paths, samples_per_class, split_name):
    """Sample labelled pixel coordinates inside the bounding box of ``tile_paths``.

    Blocks of the labels raster (module-level ``labels_file``) are visited in a random
    order. From each block, up to the remaining per-class budget is drawn without
    replacement from the pixels of that class lying inside the bounding box. Sampling
    stops once every class budget is met or all blocks have been visited.

    NOTE(review): the bounding box is a rectangle; if the tiles do not cover it fully,
    pixels in the gaps can be sampled but have no tile to take an embedding from.

    Args:
        tile_paths: embedding tile paths named ``*-<row_off>-<col_off>.tif``.
        samples_per_class: dict class id -> maximum number of coordinates to collect.
        split_name: label used in progress and summary output.

    Returns:
        (sampled, collected): ``sampled[cls]`` is ``{'y': [...], 'x': [...]}`` holding lists
        of global row / column index arrays; ``collected[cls]`` is how many were gathered.

    Raises:
        RuntimeError: if ``tile_paths`` is non-empty but none of the tiles can be read.
    """
    if not tile_paths:
        return {cls: {'y': [], 'x': []} for cls in samples_per_class}, {cls: 0 for cls in samples_per_class}

    # Extent of every readable tile, in labels-raster pixel coordinates.
    tile_extents = []   # (row_off, col_off, height, width)
    unusable     = []   # (tile name, reason) for tiles left out of the bounding box
    for tf in tile_paths:
        parts = tf.stem.split('-')
        try:
            r, c = int(parts[-2]), int(parts[-1])
        except (ValueError, IndexError):
            unusable.append((tf.name, 'unparseable offsets'))
            continue
        try:
            with rasterio.open(tf) as tile_src:
                tile_extents.append((r, c, tile_src.height, tile_src.width))
        except Exception as e:  # best-effort: one unreadable tile should not abort the split
            unusable.append((tf.name, str(e)))

    if unusable:
        print(f'\n  WARNING: {len(unusable)} {split_name} tile(s) left out of the bounding box, e.g.:')
        for name, reason in unusable[:5]:
            print(f'    {name}: {reason}')
    if not tile_extents:
        raise RuntimeError(f'None of the {len(tile_paths)} {split_name} tiles could be read.')

    bbox_row_min = min(r for r, c, h, w in tile_extents)
    bbox_row_max = max(r + h for r, c, h, w in tile_extents)
    bbox_col_min = min(c for r, c, h, w in tile_extents)
    bbox_col_max = max(c + w for r, c, h, w in tile_extents)

    print(f'\n  {split_name} tile bounding box: rows {bbox_row_min}–{bbox_row_max}, cols {bbox_col_min}–{bbox_col_max}')

    sampled   = {cls: {'y': [], 'x': []} for cls in samples_per_class}
    collected = {cls: 0 for cls in samples_per_class}

    with rasterio.open(labels_file) as src:
        windows = list(src.block_windows(1))
        # Random block order, so stopping early does not favour one part of the raster.
        np.random.shuffle(windows)

        for idx, (_, window) in enumerate(tqdm(windows, desc=f'{split_name} blocks')):
            row_off = window.row_off
            col_off = window.col_off
            win_h   = window.height
            win_w   = window.width

            # Skip blocks outside this split's tile bounding box
            if row_off + win_h <= bbox_row_min or row_off >= bbox_row_max:
                continue
            if col_off + win_w <= bbox_col_min or col_off >= bbox_col_max:
                continue

            labels_chunk = src.read(1, window=window)

            for cls in samples_per_class:
                if collected[cls] >= samples_per_class[cls]:
                    continue
                y_local, x_local = np.where(labels_chunk == cls)
                if len(y_local) == 0:
                    continue

                y_global = y_local + row_off
                x_global = x_local + col_off

                # Keep only pixels inside this split's bbox (a block can straddle its edge)
                in_bbox = (
                    (y_global >= bbox_row_min) & (y_global < bbox_row_max) &
                    (x_global >= bbox_col_min) & (x_global < bbox_col_max)
                )
                y_global = y_global[in_bbox]
                x_global = x_global[in_bbox]
                if len(y_global) == 0:
                    continue

                needed = samples_per_class[cls] - collected[cls]
                if len(y_global) > needed:
                    idx_s = np.random.choice(len(y_global), needed, replace=False)
                    y_global = y_global[idx_s]
                    x_global = x_global[idx_s]

                sampled[cls]['y'].append(y_global)
                sampled[cls]['x'].append(x_global)
                collected[cls] += len(y_global)

            if all(collected[cls] >= samples_per_class[cls] for cls in samples_per_class):
                print(f'\n  Got all {split_name} samples after {idx+1} blocks')
                break

    print(f'  {split_name} collection summary:')
    for cls in samples_per_class:
        shortfall = '' if collected[cls] >= samples_per_class[cls] else '  <-- below budget'
        print(f'    Class {cls}: {collected[cls]:,} / {samples_per_class[cls]:,}{shortfall}')

    return sampled, collected


# TRAIN is sampled before TEST: both draw from the global NumPy RNG, so this call
# order is part of what makes the sampled coordinates reproducible.
train_sampled, train_collected = sample_coords_from_tiles(
    tile_paths=train_tiles,
    samples_per_class=train_samples_per_class,
    split_name='TRAIN',
)
test_sampled, test_collected = sample_coords_from_tiles(
    tile_paths=test_tiles,
    samples_per_class=test_samples_per_class,
    split_name='TEST',
)

print('\nCoordinate sampling complete!')

In [None]:
# CELL 5: Consolidate sampled coords into flat arrays

def consolidate(sampled, samples_per_class):
    """Flatten per-class coordinate chunks into jointly shuffled (y, x, label) arrays.

    Args:
        sampled: dict class id -> ``{'y': [arrays], 'x': [arrays]}`` of global row / column
            indices (the first value returned by ``sample_coords_from_tiles``).
        samples_per_class: dict class id -> maximum number of coordinates to keep for that
            class; classes are visited in this dict's order.

    Returns:
        (y_idx, x_idx, labels): equal-length 1-D arrays, shuffled together with the global
        NumPy RNG. All three are empty when no class has any coordinates.
    """
    all_y, all_x, all_labels = [], [], []
    for cls, budget in samples_per_class.items():
        if not sampled[cls]['y']:
            continue
        # Truncate defensively in case more than the budget was collected.
        ys = np.concatenate(sampled[cls]['y'])[:budget]
        xs = np.concatenate(sampled[cls]['x'])[:budget]
        all_y.append(ys)
        all_x.append(xs)
        all_labels.append(np.full(len(ys), cls))

    if not all_labels:
        # np.concatenate([]) raises ValueError, so return typed empty arrays instead.
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.intp), np.empty(0, dtype=np.int_)

    y_idx  = np.concatenate(all_y)
    x_idx  = np.concatenate(all_x)
    labels = np.concatenate(all_labels)
    shuf = np.random.permutation(len(labels))
    return y_idx[shuf], x_idx[shuf], labels[shuf]


# Flatten each split's per-class coordinate lists into aligned (y, x, label) arrays.
train_y_idx, train_x_idx, train_labels = consolidate(train_sampled, train_samples_per_class)
test_y_idx, test_x_idx, test_labels = consolidate(test_sampled, test_samples_per_class)

print(
    f'Train coordinates: {len(train_labels):,}',
    f'Test  coordinates: {len(test_labels):,}',
    sep='\n',
)

In [None]:
# CELL 6: Extract embeddings for each split
print('\n' + '=' * 70)
print('EXTRACTING EMBEDDINGS')
print('=' * 70)

def extract_embeddings(tile_files, y_indices, x_indices, desc, n_bands=64):
    """Look up the embedding vector at each (row, col) coordinate from the given tiles.

    Each tile's offsets are parsed from its ``*-<row_off>-<col_off>.tif`` name. The tile is
    matched to the coordinates inside its extent, and only the smallest window enclosing
    those coordinates is read, not the whole tile. A coordinate whose vector contains any
    NaN, or that lies in no usable tile, is left unmarked.

    Args:
        tile_files: embedding tile paths.
        y_indices, x_indices: 1-D integer arrays of global row / column indices.
        desc: progress-bar label.
        n_bands: expected band count per tile; tiles with a different count are skipped.

    Returns:
        (X, found): ``X`` is float32 of shape ``(len(y_indices), n_bands)`` and is zero where
        nothing was found; ``found`` is a boolean mask of coordinates that got a valid vector.
    """
    from rasterio.windows import Window

    n = len(y_indices)
    X = np.zeros((n, n_bands), dtype=np.float32)
    found = np.zeros(n, dtype=bool)
    skipped = []   # (tile name, reason) for tiles that could not be used

    with tqdm(total=len(tile_files), desc=desc, unit=' tiles') as pbar:
        for tile_file in tile_files:
            try:
                parts = tile_file.stem.split('-')
                try:
                    r_off = int(parts[-2])
                    c_off = int(parts[-1])
                except (ValueError, IndexError):
                    skipped.append((tile_file.name, 'unparseable offsets'))
                    continue

                with rasterio.open(tile_file) as src:
                    if src.count != n_bands:
                        skipped.append((tile_file.name, f'{src.count} bands'))
                        continue

                    mask = (
                        (y_indices >= r_off) & (y_indices < r_off + src.height) &
                        (x_indices >= c_off) & (x_indices < c_off + src.width)
                    )
                    if not mask.any():
                        continue

                    local_y = y_indices[mask] - r_off
                    local_x = x_indices[mask] - c_off

                    # Read only the sub-window that encloses the requested pixels: the whole
                    # multi-band tile can be far larger than the pixels actually needed.
                    row0, col0 = int(local_y.min()), int(local_x.min())
                    window = Window(col0, row0,
                                    int(local_x.max()) - col0 + 1,
                                    int(local_y.max()) - row0 + 1)
                    tile_data = src.read(window=window)  # (n_bands, win_h, win_w)
                    vals = tile_data[:, local_y - row0, local_x - col0].T  # (n_px, n_bands)

                    # NOTE(review): only NaN is treated as missing; a non-NaN src.nodata value
                    # would pass through — confirm how the embedding tiles encode nodata.
                    valid = ~np.isnan(vals).any(axis=1)
                    g_idx = np.flatnonzero(mask)
                    X[g_idx[valid]] = vals[valid]
                    found[g_idx[valid]] = True

            except Exception as e:
                print(f'\nError {tile_file.name}: {e}')
            finally:
                # Exactly one progress tick per tile, whichever path was taken above.
                pbar.update(1)
                pbar.set_postfix({'found': f'{found.sum():,}/{n:,}'})

    if skipped:
        print(f'  Skipped {len(skipped)} tile(s), e.g. {skipped[:3]}')
    print(f'  Extracted {found.sum():,} / {n:,}')
    return X, found


print('\n-- TRAIN --')
X_train_raw, train_found = extract_embeddings(train_tiles, train_y_idx, train_x_idx, 'Train tiles')
# Keep only coordinates that received a valid (NaN-free) embedding.
X_train, y_train = X_train_raw[train_found], train_labels[train_found]

print('\n-- TEST --')
X_test_raw, test_found = extract_embeddings(test_tiles, test_y_idx, test_x_idx, 'Test tiles')
X_test, y_test = X_test_raw[test_found], test_labels[test_found]

print(
    f'\nFinal train set: {X_train.shape}',
    f'Final test set:  {X_test.shape}',
    'Extraction complete!',
    sep='\n',
)

In [None]:
# CELL 7: Compute class weights and save
print('\n' + '=' * 70)
print('SAVING DATASET')
print('=' * 70)

if len(y_train) == 0:
    raise RuntimeError('Training set is empty — nothing to weight or save.')

# Inverse-frequency class weights, indexed by class id and normalised to sum to the
# number of class ids (0 .. highest configured id). A class absent from y_train keeps 0.
configured_classes = sorted(train_samples_per_class)
n_classes = configured_classes[-1] + 1
unique_cls, counts = np.unique(y_train, return_counts=True)
class_weights = np.zeros(n_classes, dtype=np.float32)
class_weights[unique_cls] = 1.0 / counts
class_weights = class_weights / class_weights.sum() * n_classes

# Surface classes the spatial split left out instead of silently giving them weight 0.
missing_in_train = sorted(set(configured_classes) - set(unique_cls.tolist()))
missing_in_test  = sorted(set(configured_classes) - set(np.unique(y_test).tolist()))
if missing_in_train:
    print(f'WARNING: no training samples for class(es) {missing_in_train} — '
          'a model trained on this split cannot learn them.')
if missing_in_test:
    print(f'WARNING: no test samples for class(es) {missing_in_test} — '
          'they cannot be evaluated on this split.')

print('Class weights (from train):')
for cls in range(n_classes):
    print(f'  Class {cls}: {class_weights[cls]:.4f}')

print(f'\nSaving to: {output_file}')
np.savez_compressed(
    output_file,
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    class_weights=class_weights,
    test_col_max=np.array(TEST_COL_MAX, dtype=np.int64),
)

output_name = Path(output_file).name
print('\n' + '=' * 70)
print(f'DATASET SAVED: {output_name}')
print('=' * 70)
print(f'  X_train: {X_train.shape}  (train pixels — tiles with col_offset >= {TEST_COL_MAX})')
print(f'  y_train: {y_train.shape}')
print(f'  X_test:  {X_test.shape}   (test pixels — tiles with col_offset < {TEST_COL_MAX})')
print(f'  y_test:  {y_test.shape}  classes present: {np.unique(y_test).tolist()}')
print(f'  test_col_max: {TEST_COL_MAX}')
print('\nNext steps:')
print(f'  1. Download {output_name} from Google Drive')
print('  2. Place it in the repo root (same level as random_forest/)')
print('  3. Run: python random_forest_spatial/model_rf_spatial.py')

In [None]:
# CELL 8: Quick verification
print('\n' + '=' * 70)
print('VERIFICATION')
print('=' * 70)

# Reload the saved file and check the invariants downstream training relies on.
# Any violation raises here instead of reaching the "passed" message.
with np.load(output_file) as saved:
    print(f'Arrays: {list(saved.keys())}')
    X_tr, y_tr = saved['X_train'], saved['y_train']
    X_te, y_te = saved['X_test'], saved['y_test']
    saved_col_max = int(saved['test_col_max'])

print(f'X_train: {X_tr.shape}  dtype={X_tr.dtype}  NaN={np.isnan(X_tr).any()}')
print(f'X_test:  {X_te.shape}   dtype={X_te.dtype}   NaN={np.isnan(X_te).any()}')
print(f'y_train classes: {np.unique(y_tr)}')
print(f'y_test  classes: {np.unique(y_te)}')
print(f'test_col_max:    {saved_col_max}')

assert len(X_tr) == len(y_tr), 'X_train / y_train length mismatch'
assert len(X_te) == len(y_te), 'X_test / y_test length mismatch'
assert len(X_tr) > 0 and len(X_te) > 0, 'empty train or test split'
assert not np.isnan(X_tr).any() and not np.isnan(X_te).any(), 'NaN in saved features'
assert X_tr.shape[1:] == X_te.shape[1:], 'train / test feature dimensions differ'
assert saved_col_max == TEST_COL_MAX, 'saved test_col_max does not match the configuration'

print('\nVerification passed!')