# Wetland Training Dataset — Spatial Split

**Output:** `wetland_dataset_spatial_split.npz`

This notebook fixes the **spatial data leakage** problem in the original dataset creator.
Instead of a random pixel-level 80/20 split, the dataset is split **by tile region**:
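- Training pixels are sampled only from tiles in the top ~80% of the study area (by row offset; row offsets grow downward, so these are the tiles with the smallest offsets)
- Test pixels are sampled only from tiles in the bottom ~20% (the tiles with the largest row offsets — a contiguous geographic block the model never sees during training)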

This produces a geographically honest evaluation — the model is tested on a region it has genuinely never encountered.

In [None]:
# CELL 1: Setup
# Mounts Google Drive (if not already mounted), installs the raster/progress-bar
# packages and imports everything the later cells rely on.
print('Setting up environment...')

import os
from google.colab import drive

# Only mount when the mount point is absent, so re-running the cell is harmless.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
else:
    print('Drive already mounted')

# IPython shell escape: install the third-party packages imported below.
!pip install -q rasterio tqdm

import numpy as np
import rasterio
from pathlib import Path
from tqdm import tqdm

print('Setup complete!')

In [None]:
# CELL 2: Configuration
# Defines the input/output paths, the spatial hold-out fraction and the
# per-class sample budgets, then checks that the inputs exist before any
# expensive work starts.
print('=' * 70)
print('CONFIGURATION')
print('=' * 70)

labels_file    = '/content/drive/MyDrive/bow_river_wetlands_10m_final.tif'
embeddings_dir = Path('/content/drive/MyDrive/EarthEngine')
output_file    = '/content/drive/MyDrive/wetland_dataset_spatial_split.npz'

# ──────────────────────────────────────────────────────────────
# SPATIAL SPLIT: what fraction of tiles (by row offset) to hold
# out as the test region. 0.20 = bottom 20% of tile rows.
TEST_TILE_FRACTION = 0.20
# ──────────────────────────────────────────────────────────────

# Per-class sample budgets (training pixels only)
train_samples_per_class = {
    0: 600_000,
    1: 19_225,
    2: 150_000,
    3: 500_000,
    4: 150_000,
    5: 100_000,
}
# Test budget: 25% of each class's training budget (which is ~20% of the
# combined train + test total), with a floor of 1,000 samples per class.
test_samples_per_class = {cls: max(1000, int(n * 0.25)) for cls, n in train_samples_per_class.items()}

print(f'Labels:     {labels_file}')
print(f'Embeddings: {embeddings_dir}')
print(f'Output:     {output_file}')
print(f'Test tile fraction: {TEST_TILE_FRACTION*100:.0f}% of row range held out')
print(f'\nTrain target: {sum(train_samples_per_class.values()):,} samples')
print(f'Test target:  {sum(test_samples_per_class.values()):,} samples')

# Raise explicitly instead of using `assert`: assertions are stripped when
# Python runs with -O, which would let a missing input go unnoticed.
if not os.path.exists(labels_file):
    raise FileNotFoundError(f'Labels not found: {labels_file}')
if not embeddings_dir.exists():
    raise FileNotFoundError(f'Embeddings dir not found: {embeddings_dir}')

print('Configuration validated!')

In [None]:
# CELL 3: Discover tiles and split into train / test sets
# Tiles are assigned to a split purely by the row offset encoded in their
# filename: the tiles with the largest row offsets form the held-out test block.
print('=' * 70)
print('DISCOVERING TILES AND SPLITTING SPATIALLY')
print('=' * 70)

all_tile_files = sorted(embeddings_dir.glob('*.tif'))
print(f'Found {len(all_tile_files)} total tiles')

# Parse row/column offsets from filenames: *-RRRRRRRRRR-CCCCCCCCCC.tif
tile_info = []       # list of (row_offset, col_offset, path)
unparsed_names = []  # files whose names do not follow the pattern above
for tf in all_tile_files:
    parts = tf.stem.split('-')
    if len(parts) < 3:
        unparsed_names.append(tf.name)
        continue
    try:
        row_off = int(parts[-2])
        col_off = int(parts[-1])
    except ValueError:
        unparsed_names.append(tf.name)
        continue
    tile_info.append((row_off, col_off, tf))

print(f'Parseable tiles: {len(tile_info)}')
if unparsed_names:
    # Report what was ignored so a naming problem is visible, not silent.
    print(f'Ignored {len(unparsed_names)} file(s) with unparseable names:')
    for name in unparsed_names:
        print(f'  {name}')

if not tile_info:
    # Without this guard the row-range lookup below fails with a bare IndexError.
    raise RuntimeError(
        f'No tiles with parseable row/column offsets found in {embeddings_dir} — check tile naming.'
    )

# Determine the cutoff row: the bottom TEST_TILE_FRACTION of the row-offset
# range (i.e. the largest row offsets) becomes the test region.
all_row_offsets = sorted(set(r for r, c, p in tile_info))
row_min = all_row_offsets[0]
row_max = all_row_offsets[-1]
total_row_range = row_max - row_min + 1
cutoff_rows_from_bottom = int(total_row_range * TEST_TILE_FRACTION)
test_row_min = row_max - cutoff_rows_from_bottom + 1

train_tiles = [p for r, c, p in tile_info if r < test_row_min]
test_tiles  = [p for r, c, p in tile_info if r >= test_row_min]

print(f'\nRow range: {row_min} – {row_max} (range = {total_row_range})')
print(f'Test region: rows >= {test_row_min}  (bottom {TEST_TILE_FRACTION*100:.0f}% of range)')
print(f'Train tiles: {len(train_tiles)}')
print(f'Test tiles:  {len(test_tiles)}')

if not test_tiles:
    raise RuntimeError('No test tiles found — increase TEST_TILE_FRACTION or check tile naming.')
if not train_tiles:
    # Mirror of the check above: a split with no training region is unusable.
    raise RuntimeError('No train tiles found — decrease TEST_TILE_FRACTION or check tile naming.')

print('\nTest tiles:')
for p in test_tiles:
    print(f'  {p.name}')

print('\nSpatial split defined!')

In [None]:
# CELL 4: Sample coordinates from labels raster, restricted per split
print('\n' + '=' * 70)
print('SAMPLING PIXEL COORDINATES')
print('=' * 70)

# Seed NumPy's global RNG so the random calls made later in the notebook
# (block shuffling, per-class subsampling, final permutation) are repeatable.
np.random.seed(42)

def sample_coords_from_tiles(tile_paths, samples_per_class, split_name):
    """Sample labelled pixel coordinates inside the bounding box of a tile set.

    The bounding box is the union extent (in raster pixel coordinates) of all
    usable tiles, derived from the row/column offsets encoded in each tile's
    filename (``*-ROW-COL.tif``) plus the tile's height and width. The labels
    raster (module-level ``labels_file``) is then scanned block by block in
    random order, and for every class up to ``samples_per_class[cls]`` pixel
    coordinates lying inside the bounding box are collected.

    Args:
        tile_paths: Paths of the embedding tiles belonging to this split.
        samples_per_class: Mapping of class id -> maximum number of pixels
            to collect for that class.
        split_name: Label used in progress and summary output.

    Returns:
        A tuple ``(sampled, collected)`` where ``sampled[cls]`` is
        ``{'y': [...], 'x': [...]}`` holding lists of coordinate arrays
        (global row / column indices) and ``collected[cls]`` is the number of
        pixels gathered for that class. Both are empty / zero when there are
        no tiles or none of them is usable.
    """
    sampled = {cls: {'y': [], 'x': []} for cls in samples_per_class}
    collected = {cls: 0 for cls in samples_per_class}
    if not tile_paths:
        return sampled, collected

    # Collect (row_off, col_off, height, width) for every usable tile.
    tile_extents = []
    for tf in tile_paths:
        parts = tf.stem.split('-')
        try:
            r = int(parts[-2])
            c = int(parts[-1])
        except (ValueError, IndexError):
            print(f'  Warning: cannot parse offsets from {tf.name}; tile skipped')
            continue
        try:
            with rasterio.open(tf) as src:
                tile_extents.append((r, c, src.height, src.width))
        except Exception as exc:
            # Best-effort: one unreadable tile must not abort the whole split,
            # but the failure is reported instead of being swallowed.
            print(f'  Warning: cannot open {tf.name}: {exc}; tile skipped')

    if not tile_extents:
        # Same outcome as "no tiles at all" rather than a bare ValueError
        # from min() on an empty sequence.
        print(f'  Warning: no usable {split_name} tiles; nothing sampled')
        return sampled, collected

    # Bounding box of these tiles (in raster pixel coords, end-exclusive)
    bbox_row_min = min(r for r, c, h, w in tile_extents)
    bbox_row_max = max(r + h for r, c, h, w in tile_extents)
    bbox_col_min = min(c for r, c, h, w in tile_extents)
    bbox_col_max = max(c + w for r, c, h, w in tile_extents)

    print(f'\n  {split_name} tile bounding box: rows {bbox_row_min}–{bbox_row_max}, cols {bbox_col_min}–{bbox_col_max}')

    with rasterio.open(labels_file) as src:
        windows = list(src.block_windows(1))
        # Visit blocks in random order so samples are not all drawn from the
        # first blocks of the raster.
        np.random.shuffle(windows)

        for idx, (block_id, window) in tqdm(enumerate(windows), total=len(windows), desc=f'{split_name} blocks'):
            row_off = window.row_off
            col_off = window.col_off
            win_h   = window.height
            win_w   = window.width

            # Skip blocks outside this split's tile bounding box
            if row_off + win_h <= bbox_row_min or row_off >= bbox_row_max:
                continue
            if col_off + win_w <= bbox_col_min or col_off >= bbox_col_max:
                continue

            labels_chunk = src.read(1, window=window)

            for cls in samples_per_class:
                if collected[cls] >= samples_per_class[cls]:
                    continue
                y_local, x_local = np.where(labels_chunk == cls)
                if len(y_local) == 0:
                    continue

                y_global = y_local + row_off
                x_global = x_local + col_off

                # A block may straddle the bounding-box edge: keep only the
                # pixels that are actually inside it.
                in_bbox = (
                    (y_global >= bbox_row_min) & (y_global < bbox_row_max) &
                    (x_global >= bbox_col_min) & (x_global < bbox_col_max)
                )
                y_global = y_global[in_bbox]
                x_global = x_global[in_bbox]
                if len(y_global) == 0:
                    continue

                # Never exceed the class budget: subsample the last block.
                needed = samples_per_class[cls] - collected[cls]
                if len(y_global) > needed:
                    idx_s = np.random.choice(len(y_global), needed, replace=False)
                    y_global = y_global[idx_s]
                    x_global = x_global[idx_s]

                sampled[cls]['y'].append(y_global)
                sampled[cls]['x'].append(x_global)
                collected[cls] += len(y_global)

            if all(collected[cls] >= samples_per_class[cls] for cls in samples_per_class):
                print(f'\n  Got all {split_name} samples after {idx+1} blocks')
                break

    print(f'  {split_name} collection summary:')
    for cls in samples_per_class:
        print(f'    Class {cls}: {collected[cls]:,} / {samples_per_class[cls]:,}')

    return sampled, collected


# Draw coordinates for each split independently, each restricted to the
# bounding box of its own tiles.
train_sampled, train_collected = sample_coords_from_tiles(
    tile_paths=train_tiles,
    samples_per_class=train_samples_per_class,
    split_name='TRAIN',
)
test_sampled, test_collected = sample_coords_from_tiles(
    tile_paths=test_tiles,
    samples_per_class=test_samples_per_class,
    split_name='TEST',
)

print('\nCoordinate sampling complete!')

In [None]:
# CELL 5: Helper — consolidate sampled coords into flat arrays

def consolidate(sampled, samples_per_class):
    """Flatten per-class coordinate chunks into shuffled, aligned arrays.

    Args:
        sampled: Mapping of class id -> ``{'y': [...], 'x': [...]}`` where
            each list holds coordinate arrays (as produced by the sampler).
        samples_per_class: Mapping of class id -> maximum number of samples
            to keep for that class; extra coordinates are truncated.

    Returns:
        A tuple ``(y_idx, x_idx, labels)`` of equal-length 1-D arrays in a
        common random order. All three are empty when no class has any
        sampled coordinates.
    """
    ys_parts, xs_parts, label_parts = [], [], []
    for cls, budget in samples_per_class.items():
        if not sampled[cls]['y']:
            continue
        # Slicing is a no-op when the class is at or under its budget.
        ys = np.concatenate(sampled[cls]['y'])[:budget]
        xs = np.concatenate(sampled[cls]['x'])[:budget]
        ys_parts.append(ys)
        xs_parts.append(xs)
        label_parts.append(np.full(len(ys), cls))

    if not ys_parts:
        # np.concatenate raises on an empty list; return empty arrays so
        # callers get a well-formed (if empty) result instead.
        return (
            np.empty(0, dtype=np.int64),
            np.empty(0, dtype=np.int64),
            np.empty(0, dtype=np.int64),
        )

    y_idx  = np.concatenate(ys_parts)
    x_idx  = np.concatenate(xs_parts)
    labels = np.concatenate(label_parts)

    # One shared permutation keeps coordinates and labels aligned.
    shuf = np.random.permutation(len(labels))
    return y_idx[shuf], x_idx[shuf], labels[shuf]

# Flatten and shuffle the sampled coordinates of each split.
train_y_idx, train_x_idx, train_labels = consolidate(
    sampled=train_sampled,
    samples_per_class=train_samples_per_class,
)
test_y_idx, test_x_idx, test_labels = consolidate(
    sampled=test_sampled,
    samples_per_class=test_samples_per_class,
)

print(f'Train coordinates: {len(train_labels):,}')
print(f'Test  coordinates: {len(test_labels):,}')

In [None]:
# CELL 6: Extract embeddings for each split
# Section banner for the embedding-extraction stage.
print('\n' + '=' * 70)
print('EXTRACTING EMBEDDINGS')
print('=' * 70)

def extract_embeddings(tile_files, y_indices, x_indices, desc, n_bands=64):
    """Look up the embedding vector of every sampled pixel in a set of tiles.

    Each tile's position in the global raster is taken from the row/column
    offsets encoded in its filename (``*-ROW-COL.tif``). Pixels whose
    coordinates fall inside a tile are read from it; pixels with a NaN in any
    band are left unfilled.

    Args:
        tile_files: Sized collection of tile paths to read from.
        y_indices: NumPy array of global row indices of the pixels.
        x_indices: NumPy array of global column indices, aligned with
            ``y_indices``.
        desc: Label shown on the progress bar.
        n_bands: Number of bands a tile must have to be used; tiles with a
            different band count are skipped. Defaults to 64.

    Returns:
        A tuple ``(X, found)`` where ``X`` is a float32 array of shape
        ``(len(y_indices), n_bands)`` and ``found`` is a boolean array marking
        the rows of ``X`` that were filled from a tile. Rows not marked in
        ``found`` are all zeros.
    """
    n = len(y_indices)
    X = np.zeros((n, n_bands), dtype=np.float32)
    found = np.zeros(n, dtype=bool)

    with tqdm(total=len(tile_files), desc=desc, unit=' tiles') as pbar:
        for tile_file in tile_files:
            try:
                with rasterio.open(tile_file) as src:
                    parts = tile_file.stem.split('-')
                    try:
                        r_off = int(parts[-2])
                        c_off = int(parts[-1])
                    except (ValueError, IndexError):
                        r_off = c_off = None

                    if src.count != n_bands:
                        print(f'\nSkipping {tile_file.name}: {src.count} bands, expected {n_bands}')
                    elif r_off is None:
                        print(f'\nSkipping {tile_file.name}: cannot parse row/column offsets')
                    else:
                        th, tw = src.height, src.width
                        in_y = (y_indices >= r_off) & (y_indices < r_off + th)
                        in_x = (x_indices >= c_off) & (x_indices < c_off + tw)
                        mask = in_y & in_x

                        # Only read the tile when at least one pixel needs it.
                        if mask.any():
                            tile_data = src.read()  # (n_bands, H, W)

                            local_y = y_indices[mask] - r_off
                            local_x = x_indices[mask] - c_off
                            vals = tile_data[:, local_y, local_x].T  # (n_px, n_bands)

                            # Discard pixels with a NaN in any band.
                            valid = ~np.isnan(vals).any(axis=1)
                            g_idx = np.where(mask)[0]
                            X[g_idx[valid]] = vals[valid]
                            found[g_idx[valid]] = True

            except Exception as e:
                # Best-effort: report the tile and carry on with the rest.
                print(f'\nError {tile_file.name}: {e}')
            pbar.update(1)
            pbar.set_postfix({'found': f'{found.sum():,}/{n:,}'})

    print(f'  Extracted {found.sum():,} / {n:,}')
    return X, found


# Train split: read embeddings from train tiles only, then keep just the
# pixels that were actually found.
print('\n-- TRAIN --')
X_train_raw, train_found = extract_embeddings(
    tile_files=train_tiles,
    y_indices=train_y_idx,
    x_indices=train_x_idx,
    desc='Train tiles',
)
X_train, y_train = X_train_raw[train_found], train_labels[train_found]

# Test split: same procedure, restricted to the held-out tiles.
print('\n-- TEST --')
X_test_raw, test_found = extract_embeddings(
    tile_files=test_tiles,
    y_indices=test_y_idx,
    x_indices=test_x_idx,
    desc='Test tiles',
)
X_test, y_test = X_test_raw[test_found], test_labels[test_found]

print(f'\nFinal train set: {X_train.shape}')
print(f'Final test set:  {X_test.shape}')
print('Extraction complete!')

In [None]:
# CELL 7: Compute class weights and save
print('\n' + '=' * 70)
print('SAVING DATASET')
print('=' * 70)

# An empty training set would make the normalisation below divide 0 by 0 and
# write NaN class weights to disk, so stop before saving anything.
if len(y_train) == 0:
    raise RuntimeError('Training set is empty — nothing to save. Check the sampling and extraction output above.')

# Class ids are the keys of the sampling budget; size the weight vector from
# them instead of hard-coding the class count.
n_classes = max(train_samples_per_class) + 1

# Class weights from training distribution only: inverse class frequency,
# rescaled so the weights sum to the number of classes. Classes absent from
# the training set keep a weight of 0.
unique_cls, counts = np.unique(y_train, return_counts=True)
class_weights = np.zeros(n_classes, dtype=np.float32)
class_weights[unique_cls] = 1.0 / counts
class_weights = class_weights / class_weights.sum() * n_classes  # normalize

print('Class weights (from train):')
for cls in range(n_classes):
    print(f'  Class {cls}: {class_weights[cls]:.4f}')

print(f'\nSaving to: {output_file}')
np.savez_compressed(
    output_file,
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    class_weights=class_weights,
    test_row_min=np.array(test_row_min, dtype=np.int64),
)

print('\n' + '=' * 70)
print('DATASET SAVED: wetland_dataset_spatial_split.npz')
print('=' * 70)
print(f'  X_train: {X_train.shape}  (train pixels from geographic top ~80%)')
print(f'  y_train: {y_train.shape}')
print(f'  X_test:  {X_test.shape}   (test pixels from geographic bottom ~20%)')
print(f'  y_test:  {y_test.shape}')
print(f'  test_row_min: {test_row_min}  (use this in visualize_test_region.py)')
print(f'\nNext steps:')
print(f'  1. Download wetland_dataset_spatial_split.npz from Google Drive')
print(f'  2. Place it in the repo root (same level as random_forest/)')
print(f'  3. Run: python random_forest/model_rf_spatial.py')
print(f'  4. Run: python random_forest/visualize_test_region.py <embeddings_dir>')

In [None]:
# CELL 8: Quick verification
# Reloads the saved archive, prints a summary of its contents and fails
# loudly if the arrays are inconsistent.
print('\n' + '=' * 70)
print('VERIFICATION')
print('=' * 70)

# The context manager closes the archive even if a check below raises.
with np.load(output_file) as d:
    print(f'Arrays: {list(d.keys())}')
    # Indexing an .npz archive decompresses the array on every access, so
    # read each array once and reuse it.
    saved_X_train = d['X_train']
    saved_X_test = d['X_test']
    saved_y_train = d['y_train']
    saved_y_test = d['y_test']
    saved_test_row_min = int(d['test_row_min'])

train_has_nan = np.isnan(saved_X_train).any()
test_has_nan = np.isnan(saved_X_test).any()

print(f'X_train: {saved_X_train.shape}  dtype={saved_X_train.dtype}  NaN={train_has_nan}')
print(f'X_test:  {saved_X_test.shape}   dtype={saved_X_test.dtype}   NaN={test_has_nan}')
print(f'y_train classes: {np.unique(saved_y_train)}')
print(f'y_test  classes: {np.unique(saved_y_test)}')
print(f'test_row_min: {saved_test_row_min}')

# Only report success when the basic invariants actually hold.
if train_has_nan or test_has_nan:
    raise RuntimeError('Verification failed: saved feature arrays contain NaN values.')
if len(saved_X_train) != len(saved_y_train) or len(saved_X_test) != len(saved_y_test):
    raise RuntimeError('Verification failed: feature and label arrays differ in length.')

print('\nVerification passed!')