# Wetland Training Dataset — Smart Spatial Split

**Output:** `wetland_dataset_smart_split.npz`

## Split Strategy

A **mixed split** is used because Class 1 (only 19,225 pixels total) is entirely
confined to the western portion of the study area and cannot be geographically split:

| Class | Split type | Rationale |
|-------|----------|----------|
| 0, 2, 3, 4, 5 | **Geographic** — left/right column tile split | Pixels spread across entire map |
| **1** | **Random 75/25 within its zone** | All 19,225 pixels in cols 1000–6311; no geo split possible |

The geographic test region consists of the tiles with `col_offset < TEST_COL_MAX` (roughly the
left 22% of the map). Class 2 is also concentrated there, which ensures all classes appear in the test set.

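> **Documented limitation:** Because Class 1 pixels are spatially confined, its
> train/test split is not geographically independent: its training pixels lie inside the
> western region that is held out as the test region for the other classes. This should be noted in the report.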

In [None]:
# CELL 1: Setup
print('Setting up environment...')
import os
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
else:
    print('Drive already mounted')
!pip install -q rasterio tqdm
import numpy as np
import rasterio
from pathlib import Path
from tqdm import tqdm
print('Setup complete!')

In [None]:
# CELL 2: Configuration
# Defines input/output paths, the geographic split threshold, the Class 1
# random-split parameters and the per-class sample budgets used by later cells.
print('=' * 70)
print('CONFIGURATION')
print('=' * 70)

labels_file    = '/content/drive/MyDrive/bow_river_wetlands_10m_final.tif'
embeddings_dir = Path('/content/drive/MyDrive/EarthEngine')
output_file    = '/content/drive/MyDrive/wetland_dataset_smart_split.npz'

# ─── GEOGRAPHIC SPLIT (Classes 0, 2, 3, 4, 5) ────────────────────────
# Hold out tiles with col_offset < TEST_COL_MAX as the test region.
# 8192 comfortably encloses Class 2 (max col 7037) with buffer.
TEST_COL_MAX = 8192

# ─── CLASS 1 RANDOM SPLIT ─────────────────────────────────────────────
# Class 1 is entirely in cols 1000–6311 (western zone).
# We do a random 75/25 split of all its pixels.
CLASS1_TRAIN_FRACTION = 0.75

# Known bounding box of Class 1 pixels (from class distribution analysis)
CLASS1_ROW_MIN, CLASS1_ROW_MAX = 764,  15197
CLASS1_COL_MIN, CLASS1_COL_MAX = 1000,  6311
# ─────────────────────────────────────────────────────────────────────

# Per-class sample budgets
train_samples_per_class = {
    0: 600_000,
    1: 19_225,   # all available — split handled separately
    2: 150_000,
    3: 500_000,
    4: 150_000,
    5: 100_000,
}
# Test budget is 25% of the train budget, but never fewer than 1000 samples.
test_samples_per_class = {cls: max(1000, int(n * 0.25)) for cls, n in train_samples_per_class.items()}

# Geographic budgets exclude Class 1 (handled separately)
train_samples_geo = {cls: n for cls, n in train_samples_per_class.items() if cls != 1}
test_samples_geo  = {cls: n for cls, n in test_samples_per_class.items()  if cls != 1}

# Fail fast with explicit exceptions: `assert` statements are removed when
# Python runs with -O, and the messages now name the path that was checked.
if not os.path.exists(labels_file):
    raise FileNotFoundError(f'Labels TIF not found: {labels_file}')
if not embeddings_dir.exists():
    raise FileNotFoundError(f'Embeddings dir not found: {embeddings_dir}')

print(f'Labels:        {labels_file}')
print(f'Embeddings:    {embeddings_dir}')
print(f'Output:        {output_file}')
print(f'TEST_COL_MAX:  {TEST_COL_MAX}  (geo test = tiles with col < this value)')
print(f'Class 1 split: random {CLASS1_TRAIN_FRACTION*100:.0f}/{(1-CLASS1_TRAIN_FRACTION)*100:.0f} within zone')
print(f'\nTrain target (geo classes): {sum(train_samples_geo.values()):,}')
print(f'Test target  (geo classes): {sum(test_samples_geo.values()):,}')
print('Configuration validated!')

In [None]:
# CELL 3: Discover tiles and do column-based geographic split
print('=' * 70)
print('DISCOVERING TILES — COLUMN SPLIT')
print('=' * 70)

all_tile_files = sorted(embeddings_dir.glob('*.tif'))
print(f'Found {len(all_tile_files)} total tiles')

# A tile is usable only if its stem has at least three '-'-separated parts and
# the last two parse as integers; they are stored as (row_offset, col_offset).
tile_info = []
skipped_tiles = []
for tf in all_tile_files:
    parts = tf.stem.split('-')
    if len(parts) < 3:
        skipped_tiles.append(tf.name)
        continue
    try:
        tile_info.append((int(parts[-2]), int(parts[-1]), tf))
    except ValueError:
        skipped_tiles.append(tf.name)

print(f'Parseable tiles: {len(tile_info)}')
if skipped_tiles:
    # Surface ignored files instead of dropping them silently.
    print(f'  Skipped {len(skipped_tiles)} file(s) with unparseable names, e.g. {skipped_tiles[:5]}')

# Guard before the percentage below, which would otherwise divide by zero.
if not tile_info:
    raise RuntimeError('No parseable tiles found — check embeddings_dir or tile naming.')

# Geographic split: LEFT = test (contains Classes 1 and 2), RIGHT = train
test_tiles  = [p for r, c, p in tile_info if c < TEST_COL_MAX]
train_tiles = [p for r, c, p in tile_info if c >= TEST_COL_MAX]

print(f'Train tiles (eastern, col >= {TEST_COL_MAX}): {len(train_tiles)}')
print(f'Test  tiles (western, col <  {TEST_COL_MAX}): {len(test_tiles)}')
print(f'Test fraction: {len(test_tiles)/len(tile_info)*100:.1f}% of all tiles')

if not test_tiles:
    raise RuntimeError('No test tiles found — check TEST_COL_MAX or tile naming.')
if not train_tiles:
    raise RuntimeError('No train tiles found — check TEST_COL_MAX or tile naming.')

print('\nSpatial split defined!')

In [None]:
# CELL 4: Sample Class 1 separately (random 75/25 within its bounding zone)
print('=' * 70)
print('SAMPLING CLASS 1 (RANDOM SPLIT WITHIN ZONE)')
print('=' * 70)

# Seeds NumPy's global RNG; the permutation below (and any later np.random
# call that does not reseed) draws from this state.
np.random.seed(42)

class1_y_all, class1_x_all = [], []

with rasterio.open(labels_file) as src:
    windows = list(src.block_windows(1))
    for block_id, window in tqdm(windows, desc='Scanning for Class 1'):
        row_off = window.row_off
        col_off = window.col_off
        win_h   = window.height
        win_w   = window.width

        # Skip blocks outside Class 1's known bounding box
        # (the *_MAX bounds are treated as inclusive, hence the strict '>').
        if row_off + win_h <= CLASS1_ROW_MIN or row_off > CLASS1_ROW_MAX:
            continue
        if col_off + win_w <= CLASS1_COL_MIN or col_off > CLASS1_COL_MAX:
            continue

        chunk = src.read(1, window=window)
        y_loc, x_loc = np.where(chunk == 1)
        if len(y_loc) == 0:
            continue

        # Convert block-local indices to full-raster coordinates.
        class1_y_all.append(y_loc + row_off)
        class1_x_all.append(x_loc + col_off)

# np.concatenate raises an unhelpful error on an empty list; fail with a
# message that points at the likely cause instead.
if not class1_y_all:
    raise RuntimeError('No Class 1 pixels found — check labels_file and the CLASS1_* bounding box.')

class1_y = np.concatenate(class1_y_all)
class1_x = np.concatenate(class1_x_all)
print(f'\nTotal Class 1 pixels found: {len(class1_y):,}')

# Random shuffle then 75/25 split
shuf = np.random.permutation(len(class1_y))
class1_y = class1_y[shuf]
class1_x = class1_x[shuf]

n_train1 = int(len(class1_y) * CLASS1_TRAIN_FRACTION)
class1_train_y, class1_train_x = class1_y[:n_train1],  class1_x[:n_train1]
class1_test_y,  class1_test_x  = class1_y[n_train1:],  class1_x[n_train1:]

print(f'Class 1 train: {len(class1_train_y):,} pixels (random {CLASS1_TRAIN_FRACTION*100:.0f}%)')
print(f'Class 1 test:  {len(class1_test_y):,} pixels (random {(1-CLASS1_TRAIN_FRACTION)*100:.0f}%)')
print('\n⚠ NOTE: Class 1 split is NOT geographically independent.')
print('  This is documented as a known limitation (Class 1 is spatially confined).')

In [None]:
# CELL 5: Sample geographic classes (0, 2, 3, 4, 5) using tile-based column split
print('=' * 70)
print('SAMPLING GEOGRAPHIC CLASSES (0, 2, 3, 4, 5)')
print('=' * 70)

def sample_coords_from_tiles(tile_paths, samples_per_class, split_name):
    """Sample label-raster pixel coordinates per class within the tiles' bounding box.

    Row/column offsets are parsed from the last two '-'-separated parts of each
    tile's file stem and combined with the tile's height and width to build one
    bounding box that encloses all readable tiles. The blocks of the label
    raster (module-level ``labels_file``) are then visited in random order; for
    every class still under budget, the pixels of that class inside the
    bounding box are collected, randomly subsampled when a block holds more
    than is still needed.

    NOTE(review): sampling is limited to the tiles' *bounding box*, not their
    exact footprint, and blocks are consumed whole until a budget is met, so
    samples may cluster in the first blocks visited — confirm this is intended.

    Args:
        tile_paths: Iterable of tile paths (objects with ``.stem``/``.name``).
        samples_per_class: Mapping of class id -> maximum number of samples.
        split_name: Label used in progress and summary output.

    Returns:
        Tuple ``(sampled, collected)`` where ``sampled[cls]`` is a dict with
        lists of coordinate arrays under ``'y'`` and ``'x'`` (full-raster
        coordinates) and ``collected[cls]`` is the number of samples gathered.

    Raises:
        ValueError: If none of the tiles could be parsed and opened.
    """
    tile_info_local = []
    for tf in tile_paths:
        try:
            parts = tf.stem.split('-')
            r = int(parts[-2])
            c = int(parts[-1])
            with rasterio.open(tf) as s:
                tile_info_local.append((r, c, s.height, s.width))
        except (ValueError, IndexError, rasterio.errors.RasterioIOError) as exc:
            # Best effort: report the tile that was skipped rather than hiding it.
            print(f'  Skipping tile {tf.name}: {exc}')

    if not tile_info_local:
        raise ValueError(f'No readable tiles for {split_name} — cannot compute bounding box.')

    bbox_row_min = min(r for r, c, h, w in tile_info_local)
    bbox_row_max = max(r + h for r, c, h, w in tile_info_local)
    bbox_col_min = min(c for r, c, h, w in tile_info_local)
    bbox_col_max = max(c + w for r, c, h, w in tile_info_local)
    print(f'\n  {split_name} bbox: rows {bbox_row_min}–{bbox_row_max}, cols {bbox_col_min}–{bbox_col_max}')

    sampled   = {cls: {'y': [], 'x': []} for cls in samples_per_class}
    collected = {cls: 0 for cls in samples_per_class}

    with rasterio.open(labels_file) as src:
        windows = list(src.block_windows(1))
        # Visit blocks in random order (uses NumPy's global RNG state).
        np.random.shuffle(windows)
        for idx, (block_id, window) in tqdm(enumerate(windows), total=len(windows), desc=f'{split_name} blocks'):
            row_off = window.row_off
            col_off = window.col_off
            win_h = window.height
            win_w = window.width
            # Skip blocks that do not intersect the bounding box at all.
            if row_off + win_h <= bbox_row_min or row_off >= bbox_row_max:
                continue
            if col_off + win_w <= bbox_col_min or col_off >= bbox_col_max:
                continue

            chunk = src.read(1, window=window)
            for cls in samples_per_class:
                if collected[cls] >= samples_per_class[cls]:
                    continue
                y_l, x_l = np.where(chunk == cls)
                if len(y_l) == 0:
                    continue
                # Block-local -> full-raster coordinates.
                y_g = y_l + row_off
                x_g = x_l + col_off
                # A block may straddle the bbox edge; keep only pixels inside it.
                in_b = ((y_g >= bbox_row_min) & (y_g < bbox_row_max) &
                        (x_g >= bbox_col_min) & (x_g < bbox_col_max))
                y_g = y_g[in_b]
                x_g = x_g[in_b]
                if len(y_g) == 0:
                    continue
                needed = samples_per_class[cls] - collected[cls]
                if len(y_g) > needed:
                    # Take only what is still missing, chosen at random.
                    idx_s = np.random.choice(len(y_g), needed, replace=False)
                    y_g = y_g[idx_s]
                    x_g = x_g[idx_s]
                sampled[cls]['y'].append(y_g)
                sampled[cls]['x'].append(x_g)
                collected[cls] += len(y_g)
            if all(collected[c] >= samples_per_class[c] for c in samples_per_class):
                print(f'\n  Got all {split_name} samples after {idx+1} blocks')
                break

    print(f'  {split_name} summary:')
    for cls in samples_per_class:
        print(f'    Class {cls}: {collected[cls]:,} / {samples_per_class[cls]:,}')
    return sampled, collected


# Only the coordinate dicts (first return value) are kept; the per-class
# count dicts returned alongside them are discarded.
train_sampled_geo = sample_coords_from_tiles(
    train_tiles, train_samples_geo, 'TRAIN (geo)'
)[0]
test_sampled_geo = sample_coords_from_tiles(
    test_tiles, test_samples_geo, 'TEST (geo)'
)[0]
print('\nGeographic class sampling complete!')

In [None]:
# CELL 6: Consolidate all coordinates into flat arrays

def consolidate_geo(sampled, budgets):
    """Flatten per-class coordinate chunks into parallel y / x / label arrays.

    For every class in ``budgets`` the coordinate chunks in ``sampled[cls]``
    are concatenated and truncated to that class's budget. Classes with no
    chunks are skipped.

    Args:
        sampled: Mapping of class id -> {'y': [chunks], 'x': [chunks]}.
        budgets: Mapping of class id -> maximum number of samples to keep.

    Returns:
        Tuple ``(all_y, all_x, all_labels)`` of equal-length 1-D arrays.
        Labels are int64. All three are empty int64 arrays when no class
        contributed any sample.
    """
    all_y, all_x, all_labels = [], [], []
    for cls, budget in budgets.items():
        chunks = sampled[cls]
        if not chunks['y']:
            continue
        # Slicing is a no-op when the class is already within budget.
        ys = np.concatenate(chunks['y'])[:budget]
        xs = np.concatenate(chunks['x'])[:budget]
        all_y.append(ys)
        all_x.append(xs)
        # Explicit dtype so labels do not depend on the platform's default int.
        all_labels.append(np.full(len(ys), cls, dtype=np.int64))

    if not all_labels:
        # np.concatenate([]) raises; return well-formed empty arrays instead.
        return (np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64))
    return np.concatenate(all_y), np.concatenate(all_x), np.concatenate(all_labels)

# Geo classes
geo_train_y, geo_train_x, geo_train_lbl = consolidate_geo(train_sampled_geo, train_samples_geo)
geo_test_y,  geo_test_x,  geo_test_lbl  = consolidate_geo(test_sampled_geo,  test_samples_geo)

# Merge with Class 1 coords (Class 1 labels are all ones)
train_y = np.concatenate([geo_train_y, class1_train_y])
train_x = np.concatenate([geo_train_x, class1_train_x])
train_labels = np.concatenate([geo_train_lbl, np.ones(len(class1_train_y), dtype=np.int64)])

test_y = np.concatenate([geo_test_y,  class1_test_y])
test_x = np.concatenate([geo_test_x,  class1_test_x])
test_labels = np.concatenate([geo_test_lbl, np.ones(len(class1_test_y), dtype=np.int64)])

# Shuffle each split with one permutation applied to y, x and labels alike,
# so the three arrays stay aligned.
shuf_tr = np.random.permutation(len(train_labels))
train_y, train_x, train_labels = train_y[shuf_tr], train_x[shuf_tr], train_labels[shuf_tr]

shuf_te = np.random.permutation(len(test_labels))
test_y,  test_x,  test_labels  = test_y[shuf_te],  test_x[shuf_te],  test_labels[shuf_te]

print(f'Train coordinates: {len(train_labels):,}')
print(f'  Classes in train: {np.unique(train_labels)}')
print(f'Test  coordinates: {len(test_labels):,}')
print(f'  Classes in test:  {np.unique(test_labels)}')

In [None]:
# CELL 7: Extract embeddings
print('\n' + '=' * 70)
print('EXTRACTING EMBEDDINGS')
print('=' * 70)

def extract_embeddings(tile_files, y_indices, x_indices, desc, n_bands=64):
    """Look up the embedding vector of each requested pixel in the tiles.

    Each tile's row/column offset is parsed from the last two '-'-separated
    parts of its file stem. Pixels whose full-raster coordinates fall inside a
    tile are read from it. A pixel counts as found only if none of its band
    values is NaN.

    Args:
        tile_files: Sequence of tile paths (objects with ``.stem``/``.name``).
        y_indices: 1-D array of full-raster row coordinates.
        x_indices: 1-D array of full-raster column coordinates.
        desc: Progress-bar description.
        n_bands: Required band count; tiles with a different count are
            skipped. Defaults to 64.

    Returns:
        Tuple ``(X, found)``: ``X`` is a float32 array of shape
        ``(len(y_indices), n_bands)``, zero-filled for pixels that were not
        found; ``found`` is a boolean array marking the rows that were filled.
    """
    n = len(y_indices)
    X = np.zeros((n, n_bands), dtype=np.float32)
    found = np.zeros(n, dtype=bool)
    with tqdm(total=len(tile_files), desc=desc, unit=' tiles') as pbar:
        for tf in tile_files:
            try:
                with rasterio.open(tf) as src:
                    if src.count != n_bands:
                        continue
                    parts = tf.stem.split('-')
                    try:
                        r_off = int(parts[-2])
                        c_off = int(parts[-1])
                    except (ValueError, IndexError):
                        # File name does not carry row/col offsets; skip it.
                        continue
                    th, tw = src.height, src.width
                    mask = ((y_indices >= r_off) & (y_indices < r_off + th) &
                            (x_indices >= c_off) & (x_indices < c_off + tw))
                    if mask.any():
                        # Band count was checked above, so this is (n_bands, th, tw).
                        tile_data = src.read()
                        ly = y_indices[mask] - r_off
                        lx = x_indices[mask] - c_off
                        vals = tile_data[:, ly, lx].T
                        # Drop pixels with a NaN in any band.
                        valid = ~np.isnan(vals).any(axis=1)
                        g_idx = np.where(mask)[0]
                        X[g_idx[valid]] = vals[valid]
                        found[g_idx[valid]] = True
            except Exception as e:
                # Best effort per tile: report the failure and keep going.
                print(f'\nError {tf.name}: {e}')
            finally:
                # Advance the bar exactly once per tile, whatever happened above.
                pbar.update(1)
                pbar.set_postfix({'found': f'{found.sum():,}/{n:,}'})
    print(f'  Extracted {found.sum():,} / {n:,}')
    return X, found

# Both splits are looked up against all_tile_files rather than their own tile
# subset: geo-class train pixels sit in eastern (train) tiles, while Class 1
# train pixels sit in western (test) tiles, so each pixel is resolved by
# whichever tile actually contains it.
print('\n-- TRAIN --')
X_train_raw, train_found = extract_embeddings(
    tile_files=all_tile_files,
    y_indices=train_y,
    x_indices=train_x,
    desc='Train (all tiles)',
)
# Keep only the rows whose embedding was actually found.
X_train, y_train = X_train_raw[train_found], train_labels[train_found]

print('\n-- TEST --')
X_test_raw, test_found = extract_embeddings(
    tile_files=all_tile_files,
    y_indices=test_y,
    x_indices=test_x,
    desc='Test (all tiles)',
)
X_test, y_test = X_test_raw[test_found], test_labels[test_found]

print(f'\nFinal train set: {X_train.shape}')
print(f'Final test set:  {X_test.shape}')
print('Extraction complete!')

In [None]:
# CELL 8: Compute class weights and save
print('\n' + '=' * 70)
print('SAVING DATASET')
print('=' * 70)

# Inverse-frequency weights from the training labels, rescaled so that the
# six weights sum to 6. Classes absent from y_train keep a weight of 0.
unique_cls, counts = np.unique(y_train, return_counts=True)
class_weights = np.zeros(6, dtype=np.float32)
class_weights[unique_cls] = 1.0 / counts
class_weights = class_weights / class_weights.sum() * 6

print('Class weights (from train):')
for cls in range(6):
    print(f'  Class {cls}: {class_weights[cls]:.4f}')

# Everything that goes into the archive, keyed by its name in the .npz file.
dataset_arrays = {
    'X_train': X_train,
    'y_train': y_train,
    'X_test': X_test,
    'y_test': y_test,
    'class_weights': class_weights,
    'test_col_max': np.array(TEST_COL_MAX, dtype=np.int64),
}
np.savez_compressed(output_file, **dataset_arrays)

print(f'\nSaved: {output_file}')
print(f'  X_train: {X_train.shape}  (eastern tiles + random Class 1)')
print(f'  X_test:  {X_test.shape}   (western tiles + random Class 1)')
print(f'  test_col_max: {TEST_COL_MAX}')
print('\nNext steps:')
print('  1. Download wetland_dataset_smart_split.npz from Google Drive')
print('  2. Place in repo root')
print('  3. Run: python random_forest_spatial/model_rf_spatial.py')

In [None]:
# CELL 9: Verification
# Reloads the saved archive and checks it, so that "Verification passed!" is
# only printed when the checks below actually hold.
print('\n' + '=' * 70)
print('VERIFICATION')
print('=' * 70)

# `with` closes the archive even if one of the checks raises.
with np.load(output_file) as d:
    print(f'Arrays: {list(d.keys())}')
    # Read each array once; every d[...] access loads it from the archive again.
    saved_X_train, saved_X_test = d["X_train"], d["X_test"]
    saved_y_train, saved_y_test = d["y_train"], d["y_test"]
    train_has_nan = bool(np.isnan(saved_X_train).any())
    test_has_nan = bool(np.isnan(saved_X_test).any())
    train_classes = np.unique(saved_y_train)
    test_classes = np.unique(saved_y_test)
    print(f'X_train: {saved_X_train.shape}  NaN={train_has_nan}')
    print(f'X_test:  {saved_X_test.shape}   NaN={test_has_nan}')
    print(f'y_train classes: {train_classes}')
    print(f'y_test  classes: {test_classes}')
    print(f'test_col_max:    {int(d["test_col_max"])}')

if train_has_nan or test_has_nan:
    raise ValueError('Verification failed: saved embeddings contain NaN values.')
if len(saved_X_train) != len(saved_y_train) or len(saved_X_test) != len(saved_y_test):
    raise ValueError('Verification failed: feature and label arrays differ in length.')
if not np.array_equal(train_classes, test_classes):
    # Not fatal, but worth flagging before training on the dataset.
    print('\n⚠ WARNING: train and test sets do not contain the same classes.')

print('\nVerification passed!')