# Wetland Training Dataset — Smart Spatial Split

**Output:** `wetland_dataset_smart_split.npz`

## Split Strategy

Classes 1 and 2 are both confined to the western portion of the map and cannot be
geographically split. A **mixed split** is used:

| Class | Split type | Rationale |
|-------|----------|----------|
| 0, 3, 4, 5 | **Geographic** — tile-column split: tiles with column offset < 8192 → test, ≥ 8192 → train | Pixels spread across entire map |
| **1** | **Random 75/25 within zone** | All 19,225 pixels in cols 1,000–6,311 |
| **2** | **Random within zone, 150,000 train / 37,500 test (80/20)** | Subsampled from all 901,620 pixels in cols 747–7,037 |

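> **Documented limitation:** The train/test splits for classes 1 and 2 are not geographically
> independent: both sets are drawn at random from the same western zone, so neighbouring pixels
> can fall on either side of the split and test scores for these classes may be optimistic.
> This should be noted in the report.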

In [None]:
# CELL 1: Setup
import os, gc, shutil
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
else:
    print('Drive already mounted')

!pip install -q rasterio tqdm
import numpy as np
import rasterio
from pathlib import Path
from tqdm import tqdm

DRIVE_TILES = '/content/drive/MyDrive/EarthEngine'
LOCAL_TILES  = '/content/EarthEngine'
free_gb = shutil.disk_usage('/content').free / (1024**3)
print(f'Free disk space: {free_gb:.1f} GB')

if free_gb > 110 and not os.path.exists(LOCAL_TILES):
    print('Copying tiles to local disk (~10 min)...')
    !cp -r {DRIVE_TILES} {LOCAL_TILES}
    TILES_PATH = LOCAL_TILES
elif os.path.exists(LOCAL_TILES) and len(list(Path(LOCAL_TILES).glob('*.tif'))) > 0:
    TILES_PATH = LOCAL_TILES
    print('Local tiles already present.')
else:
    TILES_PATH = DRIVE_TILES
    print(f'Insufficient disk ({free_gb:.1f} GB) — reading from Drive. Extraction will take ~60-90 min.')

print(f'Tile source: {TILES_PATH}')
print('Setup complete!')

In [None]:
# CELL 2: Configuration
# Defines input/output paths, the geographic split threshold and per-class sample
# budgets, then validates that the inputs exist before any expensive work starts.
print('=' * 70)
print('CONFIGURATION')
print('=' * 70)

labels_file    = '/content/drive/MyDrive/bow_river_wetlands_10m_final.tif'
embeddings_dir = Path(TILES_PATH)
output_file    = '/content/drive/MyDrive/wetland_dataset_smart_split.npz'

# Geographic split threshold (classes 0, 3, 4, 5 only).
# Compared against the column offset parsed from each tile file name.
TEST_COL_MAX = 8192

# Classes to handle with random within-zone split (both are confined to western region).
# row/col bounds are used to skip label blocks outside the zone; budgets cap sample counts.
# Class 1 budgets are 75/25 (14,418 / 4,806); class 2 budgets are 80/20 (150,000 / 37,500).
WESTERN_CLASSES = {
    1: {'row_min': 764,  'row_max': 15197, 'col_min': 1000, 'col_max': 6311,  'train_budget': 14_418, 'test_budget': 4_806},
    2: {'row_min': 45,   'row_max': 15175, 'col_min': 747,  'col_max': 7037,  'train_budget': 150_000, 'test_budget': 37_500},
}

# Geographic classes (split by tile column).
# Test budget is 25% of the train budget per class, with a floor of 1,000 samples.
train_samples_geo = {0: 600_000, 3: 500_000, 4: 150_000, 5: 100_000}
test_samples_geo  = {cls: max(1000, int(n * 0.25)) for cls, n in train_samples_geo.items()}

# Explicit raises instead of `assert`, so validation still runs under `python -O`.
if not os.path.exists(labels_file):
    raise FileNotFoundError(f'Labels TIF not found: {labels_file}')
if not embeddings_dir.exists():
    raise FileNotFoundError(f'Embeddings dir not found: {embeddings_dir}')
tile_count = len(list(embeddings_dir.glob('*.tif')))
if tile_count == 0:
    raise FileNotFoundError(f'No .tif files in {embeddings_dir}')

print(f'Labels:     {labels_file}')
print(f'Embeddings: {embeddings_dir}  ({tile_count} tiles)')
print(f'Output:     {output_file}')
print(f'TEST_COL_MAX: {TEST_COL_MAX}')
print(f'Western classes (random split): {list(WESTERN_CLASSES.keys())}')
print(f'Geographic classes: {list(train_samples_geo.keys())}')
print('Configuration validated!')

In [None]:
# CELL 3: Discover tiles and do column-based geographic split
# Tile names are expected to end in "-ROW-COL"; the two trailing integers are parsed
# as the tile's row/column offsets and the column offset decides train vs. test.
print('=' * 70)
print('DISCOVERING TILES — COLUMN SPLIT')
print('=' * 70)

all_tile_files = sorted(embeddings_dir.glob('*.tif'))
tile_info = []       # (row_offset, col_offset, path) for every parseable tile
skipped_tiles = []   # names that do not match *-ROW-COL.tif
for tf in all_tile_files:
    parts = tf.stem.split('-')
    if len(parts) < 3:
        skipped_tiles.append(tf.name)
        continue
    try:
        row_off, col_off = int(parts[-2]), int(parts[-1])
    except ValueError:
        skipped_tiles.append(tf.name)
        continue
    tile_info.append((row_off, col_off, tf))

# Report skipped files instead of dropping them silently: a naming mismatch would
# otherwise shrink the dataset without any visible sign.
if skipped_tiles:
    preview = ', '.join(skipped_tiles[:5]) + (' ...' if len(skipped_tiles) > 5 else '')
    print(f'WARNING: skipped {len(skipped_tiles)} tile(s) with unparseable names: {preview}')

if not tile_info:
    raise RuntimeError('No parseable tiles found — check tile naming (*-ROW-COL.tif)')

# Tiles whose column offset is below the threshold form the test side; the rest train.
test_tiles  = [p for r, c, p in tile_info if c < TEST_COL_MAX]
train_tiles = [p for r, c, p in tile_info if c >= TEST_COL_MAX]

print(f'Total tiles: {len(tile_info)}')
print(f'Train tiles (col >= {TEST_COL_MAX}): {len(train_tiles)}')
print(f'Test  tiles (col <  {TEST_COL_MAX}): {len(test_tiles)}')
print(f'Test fraction: {len(test_tiles)/len(tile_info)*100:.1f}% of all tiles')

if not test_tiles:
    raise RuntimeError('No test tiles found — check TEST_COL_MAX')
if not train_tiles:
    raise RuntimeError('No train tiles found — check TEST_COL_MAX')
print('Spatial split defined!')

In [None]:
# CELL 4: Sample Classes 1 and 2 (random split within their western zones)
# For each western class: scan the label raster block by block, collect every pixel
# of that class from blocks overlapping the configured zone, subsample to the total
# budget, shuffle, and cut into train/test.
print('=' * 70)
print('SAMPLING WESTERN CLASSES (1, 2) — RANDOM WITHIN-ZONE SPLIT')
print('=' * 70)
np.random.seed(42)  # fixes the global NumPy RNG so the sampling below is repeatable

western_train_coords = {}  # cls -> (y, x)
western_test_coords  = {}

for cls, cfg in WESTERN_CLASSES.items():
    print(f'\n--- Class {cls} (zone: rows {cfg["row_min"]}–{cfg["row_max"]}, cols {cfg["col_min"]}–{cfg["col_max"]}) ---')
    y_all, x_all = [], []

    with rasterio.open(labels_file) as src:
        for block_id, window in tqdm(list(src.block_windows(1)), desc=f'Class {cls} scan'):
            r0 = window.row_off; c0 = window.col_off
            # Skip blocks that cannot overlap the zone (zone bounds are inclusive).
            if r0 + window.height <= cfg['row_min'] or r0 > cfg['row_max']: continue
            if c0 + window.width  <= cfg['col_min'] or c0 > cfg['col_max']: continue
            chunk = src.read(1, window=window)
            y_l, x_l = np.where(chunk == cls)
            if len(y_l) == 0: continue
            # Convert block-local indices to full-raster coordinates.
            y_all.append(y_l + r0)
            x_all.append(x_l + c0)

    # Fail with a clear message instead of np.concatenate's "need at least one array".
    if not y_all:
        raise RuntimeError(f'Class {cls}: no pixels found in the configured zone — check WESTERN_CLASSES bounds.')

    y_cls = np.concatenate(y_all); x_cls = np.concatenate(x_all)
    del y_all, x_all; gc.collect()
    print(f'  Found {len(y_cls):,} pixels')

    # Subsample to budget total then split
    total_budget = cfg['train_budget'] + cfg['test_budget']
    if len(y_cls) > total_budget:
        idx = np.random.choice(len(y_cls), total_budget, replace=False)
        y_cls = y_cls[idx]; x_cls = x_cls[idx]

    shuf = np.random.permutation(len(y_cls))
    y_cls = y_cls[shuf]; x_cls = x_cls[shuf]; del shuf

    n_avail = len(y_cls)
    if n_avail >= total_budget:
        n_train = cfg['train_budget']
    else:
        # Fewer pixels than budgeted: keep the configured train/test ratio rather than
        # letting train absorb everything and leaving test short or empty.
        n_train = int(round(n_avail * cfg['train_budget'] / total_budget))
        print(f'  WARNING: only {n_avail:,} pixels for a budget of {total_budget:,}; splitting proportionally.')
    n_test = min(cfg['test_budget'], n_avail - n_train)

    western_train_coords[cls] = (y_cls[:n_train].copy(), x_cls[:n_train].copy())
    western_test_coords[cls]  = (y_cls[n_train:n_train + n_test].copy(),
                                  x_cls[n_train:n_train + n_test].copy())
    del y_cls, x_cls; gc.collect()

    print(f'  Train: {len(western_train_coords[cls][0]):,}  |  Test: {len(western_test_coords[cls][0]):,}')

print('\n⚠ NOTE: Classes 1 and 2 splits are NOT geographically independent (documented limitation).')

In [None]:
# CELL 5: Sample geographic classes (0, 3, 4, 5) using tile-based column split
# Stage banner: rule, title, rule — one line each.
print('=' * 70, 'SAMPLING GEOGRAPHIC CLASSES (0, 3, 4, 5)', '=' * 70, sep='\n')

def sample_coords_from_tiles(tile_paths, samples_per_class, split_name):
    """Sample label-raster pixel coordinates per class within a tile set's bounding box.

    The row/column offsets are parsed from each tile file name (trailing "-ROW-COL")
    and combined with the tile's height/width to form one bounding box covering all
    tiles. The label raster (module-level `labels_file`) is then read block by block
    in shuffled order; for every requested class, pixels inside the bounding box are
    collected until that class's budget is met.

    Sampling is greedy per block: a block contributes all of its matching pixels
    (or a random subset once the remaining budget is smaller), so samples come from
    the blocks visited first rather than uniformly from the whole box.

    Args:
        tile_paths: paths (objects with `.stem` / `.name`, e.g. `pathlib.Path`) of the
            tiles that define the region.
        samples_per_class: mapping class id -> maximum number of pixels to collect.
        split_name: label used in progress output and error messages.

    Returns:
        dict mapping class id -> {'y': [arrays], 'x': [arrays]} of full-raster row and
        column indices, one array per contributing block.

    Raises:
        RuntimeError: if `tile_paths` is empty or no tile could be parsed and opened.
    """
    if not tile_paths:
        raise RuntimeError(f'No tiles provided for {split_name}.')

    tile_info_local = []   # (row_offset, col_offset, height, width) per usable tile
    skipped = []           # human-readable reasons for tiles left out of the bbox
    for tf in tile_paths:
        parts = tf.stem.split('-')
        try:
            r = int(parts[-2]); c = int(parts[-1])
        except (IndexError, ValueError):
            skipped.append(f'{tf.name} (unparseable name)')
            continue
        try:
            with rasterio.open(tf) as tile:
                tile_info_local.append((r, c, tile.height, tile.width))
        except Exception as e:
            # Best effort: an unreadable tile is left out, but reported below.
            skipped.append(f'{tf.name} ({e})')
    if skipped:
        preview = ', '.join(skipped[:5]) + (' ...' if len(skipped) > 5 else '')
        print(f'  WARNING: {split_name}: skipped {len(skipped)} tile(s): {preview}')
    if not tile_info_local:
        raise RuntimeError(f'{split_name}: could not open any tiles.')

    # Bounding box of all usable tiles; max bounds are exclusive.
    bbox_row_min = min(r   for r,c,h,w in tile_info_local)
    bbox_row_max = max(r+h for r,c,h,w in tile_info_local)
    bbox_col_min = min(c   for r,c,h,w in tile_info_local)
    bbox_col_max = max(c+w for r,c,h,w in tile_info_local)
    print(f'  {split_name} bbox: rows {bbox_row_min}–{bbox_row_max}, cols {bbox_col_min}–{bbox_col_max}')

    sampled   = {cls: {'y': [], 'x': []} for cls in samples_per_class}
    collected = {cls: 0 for cls in samples_per_class}

    with rasterio.open(labels_file) as src:
        windows = list(src.block_windows(1))
        # Visit blocks in random order so samples are not taken from one corner only.
        np.random.shuffle(windows)
        for idx, (block_id, window) in tqdm(enumerate(windows), total=len(windows), desc=split_name):
            r0 = window.row_off; c0 = window.col_off
            rh = window.height;  cw = window.width
            # Skip blocks that lie entirely outside the bounding box.
            if r0+rh <= bbox_row_min or r0 >= bbox_row_max: continue
            if c0+cw <= bbox_col_min or c0 >= bbox_col_max: continue
            chunk = src.read(1, window=window)
            for cls in samples_per_class:
                if collected[cls] >= samples_per_class[cls]: continue
                y_l, x_l = np.where(chunk == cls)
                if len(y_l) == 0: continue
                # Block-local -> full-raster coordinates, then clip to the bbox because
                # a block may straddle its edge.
                y_g = y_l + r0; x_g = x_l + c0
                in_b = (y_g>=bbox_row_min)&(y_g<bbox_row_max)&(x_g>=bbox_col_min)&(x_g<bbox_col_max)
                y_g = y_g[in_b]; x_g = x_g[in_b]
                if len(y_g) == 0: continue
                needed = samples_per_class[cls] - collected[cls]
                if len(y_g) > needed:
                    pick = np.random.choice(len(y_g), needed, replace=False)
                    y_g = y_g[pick]; x_g = x_g[pick]
                sampled[cls]['y'].append(y_g)
                sampled[cls]['x'].append(x_g)
                collected[cls] += len(y_g)
            # Stop scanning as soon as every class has reached its budget.
            if all(collected[c] >= samples_per_class[c] for c in samples_per_class):
                print(f'  All collected after {idx+1} blocks'); break
    print(f'  {split_name} summary:')
    for cls in samples_per_class:
        print(f'    Class {cls}: {collected[cls]:,} / {samples_per_class[cls]:,}')
    return sampled

# Sample the train side first, then the test side (both draw from the global NumPy RNG).
train_sampled_geo = sample_coords_from_tiles(
    tile_paths=train_tiles,
    samples_per_class=train_samples_geo,
    split_name='TRAIN (geo)',
)
test_sampled_geo = sample_coords_from_tiles(
    tile_paths=test_tiles,
    samples_per_class=test_samples_geo,
    split_name='TEST (geo)',
)
print('\nGeographic class sampling complete!')

In [None]:
# CELL 6: Consolidate all coordinates; aggressively free intermediate memory

def consolidate_geo(sampled, budgets):
    """Flatten per-class coordinate chunks into three aligned 1-D arrays.

    Args:
        sampled: mapping class id -> {'y': [arrays], 'x': [arrays]} of row/column
            index chunks. Classes that are missing or have no chunks are skipped.
        budgets: mapping class id -> maximum number of samples kept for that class;
            its key order defines the class order in the output.

    Returns:
        (y, x, labels): row indices, column indices and int64 class labels, all of the
        same length. Three empty int64 arrays are returned when no class has samples,
        so callers can concatenate the result without special-casing.
    """
    all_y, all_x, all_lbl = [], [], []
    for cls, budget in budgets.items():
        chunks = sampled.get(cls)
        if not chunks or not chunks['y']:
            continue
        # Cap at the budget; slicing is a no-op when fewer samples were collected.
        ys = np.concatenate(chunks['y'])[:budget]
        xs = np.concatenate(chunks['x'])[:budget]
        all_y.append(ys)
        all_x.append(xs)
        all_lbl.append(np.full(len(ys), cls, dtype=np.int64))

    if not all_lbl:
        # np.concatenate([]) would raise ValueError; return well-typed empties instead.
        return (np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64))
    return np.concatenate(all_y), np.concatenate(all_x), np.concatenate(all_lbl)

# Flatten the geographic-class samples into (y, x, label) arrays. Each intermediate
# dict is deleted right after use so its memory can be reclaimed before the next step.
geo_tr_y, geo_tr_x, geo_tr_lbl = consolidate_geo(train_sampled_geo, train_samples_geo)
del train_sampled_geo; gc.collect()
geo_te_y, geo_te_x, geo_te_lbl = consolidate_geo(test_sampled_geo,  test_samples_geo)
del test_sampled_geo; gc.collect()

# Build western class arrays
# (index [0] of each stored tuple holds row indices, index [1] holds column indices;
# labels are generated to match the number of coordinates per class)
w_tr_y   = np.concatenate([western_train_coords[c][0] for c in WESTERN_CLASSES])
w_tr_x   = np.concatenate([western_train_coords[c][1] for c in WESTERN_CLASSES])
w_tr_lbl = np.concatenate([np.full(len(western_train_coords[c][0]), c, dtype=np.int64) for c in WESTERN_CLASSES])
w_te_y   = np.concatenate([western_test_coords[c][0]  for c in WESTERN_CLASSES])
w_te_x   = np.concatenate([western_test_coords[c][1]  for c in WESTERN_CLASSES])
w_te_lbl = np.concatenate([np.full(len(western_test_coords[c][0]),  c, dtype=np.int64) for c in WESTERN_CLASSES])
del western_train_coords, western_test_coords; gc.collect()

# Merge geographic and western samples into one train set ...
train_y = np.concatenate([geo_tr_y, w_tr_y])
train_x = np.concatenate([geo_tr_x, w_tr_x])
train_labels = np.concatenate([geo_tr_lbl, w_tr_lbl])
del geo_tr_y, geo_tr_x, geo_tr_lbl, w_tr_y, w_tr_x, w_tr_lbl; gc.collect()

# ... and one test set.
test_y = np.concatenate([geo_te_y, w_te_y])
test_x = np.concatenate([geo_te_x, w_te_x])
test_labels = np.concatenate([geo_te_lbl, w_te_lbl])
del geo_te_y, geo_te_x, geo_te_lbl, w_te_y, w_te_x, w_te_lbl; gc.collect()

# Shuffle each set with a single permutation applied to y, x and labels alike, so the
# three arrays stay aligned while classes are no longer stored in contiguous runs.
shuf = np.random.permutation(len(train_labels))
train_y, train_x, train_labels = train_y[shuf], train_x[shuf], train_labels[shuf]; del shuf
shuf = np.random.permutation(len(test_labels))
test_y, test_x, test_labels = test_y[shuf], test_x[shuf], test_labels[shuf]
del shuf; gc.collect()

# Report the sample count and the classes present in each set.
print(f'Train: {len(train_labels):,}  classes: {sorted(np.unique(train_labels).tolist())}')
print(f'Test:  {len(test_labels):,}   classes: {sorted(np.unique(test_labels).tolist())}')

In [None]:
# CELL 7: Extract embeddings — del raw buffers between train/test to prevent OOM
# Stage banner: blank line, rule, title, rule.
print('', '=' * 70, 'EXTRACTING EMBEDDINGS', '=' * 70, sep='\n')

def _fill_from_tile(tf, y_indices, x_indices, X, found, n_bands):
    """Copy embeddings for the sample pixels inside tile `tf` into `X`, marking `found`.

    Tiles whose name does not end in "-ROW-COL", whose band count differs from
    `n_bands`, or that contain none of the requested pixels are left untouched.
    """
    parts = tf.stem.split('-')
    try:
        r_off = int(parts[-2]); c_off = int(parts[-1])
    except (IndexError, ValueError):
        return  # name carries no row/column offsets; nothing to place

    with rasterio.open(tf) as src:
        if src.count != n_bands:
            return
        # Samples whose full-raster coordinates fall inside this tile's footprint.
        mask = ((y_indices >= r_off) & (y_indices < r_off + src.height) &
                (x_indices >= c_off) & (x_indices < c_off + src.width))
        if not mask.any():
            return  # avoid reading a tile that holds none of the samples
        tile_data = src.read()

    g_idx = np.flatnonzero(mask)
    # Fancy-index all bands at the tile-local positions -> (n_samples_in_tile, n_bands).
    vals = tile_data[:, y_indices[g_idx] - r_off, x_indices[g_idx] - c_off].T
    del tile_data  # release the full tile before the next one is read
    # Keep only samples whose embedding has no NaN in any band.
    valid = ~np.isnan(vals).any(axis=1)
    X[g_idx[valid]] = vals[valid]
    found[g_idx[valid]] = True


def extract_embeddings(tile_files, y_indices, x_indices, desc, n_bands=64):
    """Look up the embedding vector of every sample pixel across a set of tiles.

    Each tile's position in the full raster is taken from the two trailing integers of
    its file name ("-ROW-COL"). For every tile, the samples that fall inside it are
    read and stored; samples with a NaN in any band are not stored.

    Args:
        tile_files: iterable-with-length of tile paths (objects with `.stem`/`.name`).
        y_indices: NumPy array of full-raster row indices of the samples.
        x_indices: NumPy array of full-raster column indices, aligned with `y_indices`.
        desc: label for the progress bar.
        n_bands: expected band count per tile; tiles with a different count are skipped.

    Returns:
        (X, found): `X` is a float32 array of shape (len(y_indices), n_bands), zero where
        nothing was found; `found` is a boolean array marking the rows that were filled.

    A tile that raises while being processed is reported and skipped, so one corrupt
    file does not abort the whole extraction.
    """
    n = len(y_indices)
    X = np.zeros((n, n_bands), dtype=np.float32)
    found = np.zeros(n, dtype=bool)
    with tqdm(total=len(tile_files), desc=desc, unit=' tiles') as pbar:
        for tf in tile_files:
            try:
                _fill_from_tile(tf, y_indices, x_indices, X, found, n_bands)
            except Exception as e:
                print(f'\nError {tf.name}: {e}')
            pbar.update(1)
            pbar.set_postfix({'found': f'{found.sum():,}/{n:,}'})
    print(f'  Extracted {found.sum():,} / {n:,}')
    return X, found

# Train and test are extracted one after the other, and every large intermediate is
# deleted before the next extraction starts, to keep peak memory down.
print('\n-- TRAIN --')
# All tiles are scanned; each sample's coordinates decide which tile supplies it.
X_train_raw, train_found = extract_embeddings(all_tile_files, train_y, train_x, 'Train tiles')
# Keep only the samples for which an embedding was found (boolean-mask indexing copies).
X_train = X_train_raw[train_found]
y_train = train_labels[train_found]
del X_train_raw, train_found, train_y, train_x, train_labels; gc.collect()
print(f'X_train: {X_train.shape}  classes: {sorted(np.unique(y_train).tolist())}')

print('\n-- TEST --')
X_test_raw, test_found = extract_embeddings(all_tile_files, test_y, test_x, 'Test tiles')
# Same filtering for the test set.
X_test = X_test_raw[test_found]
y_test = test_labels[test_found]
del X_test_raw, test_found, test_y, test_x, test_labels; gc.collect()
print(f'X_test: {X_test.shape}  classes: {sorted(np.unique(y_test).tolist())}')

print('\nExtraction complete!')

In [None]:
# CELL 8: Compute class weights and save
# Weights are inverse class frequencies from the training labels, rescaled so that
# they sum to the number of classes. A class absent from y_train keeps weight 0.
print('\n' + '=' * 70)
print('SAVING DATASET')
print('=' * 70)

N_CLASSES = 6  # labels 0–5

# An empty training set would produce NaN weights (0 / 0) and a useless file.
if len(y_train) == 0:
    raise RuntimeError('Training set is empty — no embeddings were extracted; nothing to save.')

unique_cls, counts = np.unique(y_train, return_counts=True)
# Negative labels would silently index from the end; labels >= N_CLASSES would crash.
if unique_cls.min() < 0 or unique_cls.max() >= N_CLASSES:
    raise ValueError(f'Unexpected class labels in y_train: {unique_cls.tolist()} (expected 0–{N_CLASSES - 1})')

class_weights = np.zeros(N_CLASSES, dtype=np.float32)
class_weights[unique_cls] = 1.0 / counts
class_weights = class_weights / class_weights.sum() * N_CLASSES

print('Class weights (from train):')
for cls in range(N_CLASSES):
    print(f'  Class {cls}: {class_weights[cls]:.4f}')

# The split threshold is stored alongside the arrays so the file records how it was built.
np.savez_compressed(
    output_file,
    X_train=X_train, y_train=y_train,
    X_test=X_test,   y_test=y_test,
    class_weights=class_weights,
    test_col_max=np.array(TEST_COL_MAX, dtype=np.int64),
)
print(f'\nSaved: {output_file}')
print(f'  X_train: {X_train.shape}  |  X_test: {X_test.shape}')
print('\nNext steps:')
print('  1. Download wetland_dataset_smart_split.npz from Drive')
print('  2. Place in repo root')
print('  3. Run: python random_forest_spatial/model_rf_spatial.py')

In [None]:
# CELL 9: Verification
# Reload the saved archive and check it, instead of only printing its contents:
# features must be NaN-free and each feature matrix must match its label vector in length.
with np.load(output_file) as d:  # context manager closes the archive even on error
    print(f'Arrays: {list(d.keys())}')
    X_tr, y_tr = d['X_train'], d['y_train']
    X_te, y_te = d['X_test'],  d['y_test']

nan_train = bool(np.isnan(X_tr).any())
nan_test  = bool(np.isnan(X_te).any())
print(f'X_train: {X_tr.shape}  NaN={nan_train}')
print(f'X_test:  {X_te.shape}   NaN={nan_test}')
print(f'y_train classes: {sorted(np.unique(y_tr).tolist())}')
print(f'y_test  classes: {sorted(np.unique(y_te).tolist())}')

problems = []
if nan_train: problems.append('X_train contains NaN')
if nan_test:  problems.append('X_test contains NaN')
if len(X_tr) != len(y_tr): problems.append(f'X_train/y_train length mismatch ({len(X_tr)} vs {len(y_tr)})')
if len(X_te) != len(y_te): problems.append(f'X_test/y_test length mismatch ({len(X_te)} vs {len(y_te)})')
del X_tr, y_tr, X_te, y_te

if problems:
    raise RuntimeError('Verification failed: ' + '; '.join(problems))
print('Verification passed!')