# Wetland Training Dataset — Smart Spatial Split

**Output:** `wetland_dataset_smart_split.npz`

## Split Strategy

A **mixed split** is used because Class 1 (only 19,225 pixels total) is entirely
confined to the western portion of the study area and cannot be geographically split:

| Class | Split type | Rationale |
|-------|----------|----------|
| 0, 2, 3, 4, 5 | **Geographic** — left/right column tile split (western tiles → test, eastern tiles → train) | Pixels spread across the entire map |
| **1** | **Random 75/25 within its zone** | All 19,225 pixels in cols 1000–6311; no geo split possible |

> **Documented limitation:** Class 1 train and test pixels are drawn at random from the same zone, so the Class 1 train/test split is not geographically independent.
> Note this limitation in the report.

In [None]:
# CELL 1: Setup
import os, gc
from google.colab import drive
# Mount Google Drive unless the mount point is already present.
if os.path.exists('/content/drive'):
    print('Drive already mounted')
else:
    drive.mount('/content/drive')

!pip install -q rasterio tqdm
import numpy as np
import rasterio
from pathlib import Path
from tqdm import tqdm

# Check disk space — tiles are ~1.2 GB each × 88 = ~106 GB total.
# Only attempt local copy if there is enough space.
import shutil
free_gb = shutil.disk_usage('/content').free / (1024**3)
print(f'Free disk space: {free_gb:.1f} GB')
DRIVE_TILES  = '/content/drive/MyDrive/EarthEngine'
LOCAL_TILES  = '/content/EarthEngine'

# Count tiles on both sides so that a local directory left behind by an
# interrupted copy is not mistaken for a complete mirror of the Drive folder.
drive_tile_count = len(list(Path(DRIVE_TILES).glob('*.tif')))
local_tile_count = (
    len(list(Path(LOCAL_TILES).glob('*.tif'))) if os.path.exists(LOCAL_TILES) else 0
)

if free_gb > 110 and not os.path.exists(LOCAL_TILES):
    print('Enough space — copying tiles to local disk for faster I/O (~10 min)...')
    # copytree raises on failure, so 'Copy done!' is only printed after a real copy.
    shutil.copytree(DRIVE_TILES, LOCAL_TILES)
    TILES_PATH = LOCAL_TILES
    print('Copy done! Using local tiles.')
elif local_tile_count > 0 and local_tile_count >= drive_tile_count:
    TILES_PATH = LOCAL_TILES
    print(f'Local tiles already present — using {LOCAL_TILES}')
else:
    TILES_PATH = DRIVE_TILES
    if os.path.exists(LOCAL_TILES):
        # The directory exists but is empty or has fewer tiles than Drive.
        print(f'Local copy at {LOCAL_TILES} is incomplete '
              f'({local_tile_count}/{drive_tile_count} tiles) — reading tiles from Drive.')
        print(f'Delete {LOCAL_TILES} and re-run this cell to retry the local copy.')
    else:
        print(f'Not enough disk space ({free_gb:.1f} GB) — reading tiles from Drive.')
    print('Extraction will take ~60-90 min (Drive I/O bound). This is expected.')

print(f'\nTile source: {TILES_PATH}')
print('Setup complete!')

In [None]:
# CELL 2: Configuration
# Paths, split geometry and per-class sampling budgets used by the later cells.
print('=' * 70, 'CONFIGURATION', '=' * 70, sep='\n')

# Label raster, embedding tile directory (chosen in Cell 1) and output archive.
labels_file = '/content/drive/MyDrive/bow_river_wetlands_10m_final.tif'
embeddings_dir = Path(TILES_PATH)
output_file = '/content/drive/MyDrive/wetland_dataset_smart_split.npz'

# Column threshold for the geographic split (tile column offsets are compared to it).
TEST_COL_MAX = 8192

# Class 1 is split at random inside this row/column zone.
CLASS1_TRAIN_FRACTION = 0.75
CLASS1_ROW_MIN = 764
CLASS1_ROW_MAX = 15197
CLASS1_COL_MIN = 1000
CLASS1_COL_MAX = 6311

# Target number of training pixels per class.
train_samples_per_class = {
    0: 600_000,
    1: 19_225,
    2: 150_000,
    3: 500_000,
    4: 150_000,
    5: 100_000,
}
# Test budget is a quarter of the train budget, with a floor of 1000 pixels.
test_samples_per_class = {
    label: max(1000, int(budget * 0.25))
    for label, budget in train_samples_per_class.items()
}
# Budgets restricted to the geographically split classes (everything except Class 1).
train_samples_geo = {
    label: budget for label, budget in train_samples_per_class.items() if label != 1
}
test_samples_geo = {
    label: budget for label, budget in test_samples_per_class.items() if label != 1
}

# Stop here if any input is missing.
assert os.path.exists(labels_file), f'Labels TIF not found: {labels_file}'
assert embeddings_dir.exists(),     f'Embeddings dir not found: {embeddings_dir}'
tile_count = sum(1 for _ in embeddings_dir.glob('*.tif'))
assert tile_count > 0, f'No .tif tiles found in {embeddings_dir}'

print(f'Labels:        {labels_file}')
print(f'Embeddings:    {embeddings_dir}  ({tile_count} tiles found)')
print(f'Output:        {output_file}')
print(f'TEST_COL_MAX:  {TEST_COL_MAX}')
print(f'Train target (geo): {sum(train_samples_geo.values()):,}')
print(f'Test target  (geo): {sum(test_samples_geo.values()):,}')
print('Configuration validated!')

In [None]:
# CELL 3: Discover tiles and do column-based geographic split
print('=' * 70)
print('DISCOVERING TILES — COLUMN SPLIT')
print('=' * 70)

all_tile_files = sorted(embeddings_dir.glob('*.tif'))
print(f'Found {len(all_tile_files)} total tiles')

# Tile stems are expected to end in '-<row offset>-<col offset>'.
# Files that do not match are left out of the split and reported below.
tile_info = []
skipped_tiles = []
for tf in all_tile_files:
    parts = tf.stem.split('-')
    try:
        if len(parts) < 3:
            raise ValueError('too few "-"-separated fields')
        tile_info.append((int(parts[-2]), int(parts[-1]), tf))
    except ValueError:
        skipped_tiles.append(tf.name)

print(f'Parseable tiles: {len(tile_info)}')
if skipped_tiles:
    print(f'WARNING: {len(skipped_tiles)} tile(s) ignored (no row/col offsets in filename): {skipped_tiles}')
if not tile_info:
    raise RuntimeError('No tiles could be parsed — check that tile filenames contain row/col offsets like *-RRRR-CCCC.tif')

# Tiles are assigned by their column offset: below TEST_COL_MAX -> test, otherwise train.
test_tiles  = [p for r, c, p in tile_info if c < TEST_COL_MAX]
train_tiles = [p for r, c, p in tile_info if c >= TEST_COL_MAX]

print(f'Train tiles (eastern, col >= {TEST_COL_MAX}): {len(train_tiles)}')
print(f'Test  tiles (western, col <  {TEST_COL_MAX}): {len(test_tiles)}')
print(f'Test fraction: {len(test_tiles)/len(tile_info)*100:.1f}% of all tiles')

# Both sides of the split must be non-empty before any sampling is attempted.
if not test_tiles:
    raise RuntimeError('No test tiles found — check TEST_COL_MAX')
if not train_tiles:
    raise RuntimeError('No train tiles found — check TEST_COL_MAX')
print('\nSpatial split defined!')

In [None]:
# CELL 4: Sample Class 1 separately (random 75/25 within its bounding zone)
print('=' * 70)
print('SAMPLING CLASS 1 (RANDOM SPLIT WITHIN ZONE)')
print('=' * 70)

# Seeds NumPy's global RNG; the later cells draw from the same stream.
np.random.seed(42)
class1_y_all, class1_x_all = [], []
with rasterio.open(labels_file) as src:
    for block_id, window in tqdm(list(src.block_windows(1)), desc='Scanning for Class 1'):
        r0 = window.row_off; c0 = window.col_off
        # Skip blocks that do not overlap the Class 1 zone at all.
        if r0 + window.height <= CLASS1_ROW_MIN or r0 > CLASS1_ROW_MAX: continue
        if c0 + window.width  <= CLASS1_COL_MIN or c0 > CLASS1_COL_MAX: continue
        chunk = src.read(1, window=window)
        y_loc, x_loc = np.where(chunk == 1)
        if len(y_loc) == 0: continue
        # Convert block-local indices to raster-global row/column indices.
        class1_y_all.append(y_loc + r0)
        class1_x_all.append(x_loc + c0)

# np.concatenate([]) would fail with an unhelpful message; raise a clear error instead.
if not class1_y_all:
    raise RuntimeError('No Class 1 pixels found — check the CLASS1_ROW/COL bounds in Cell 2 and the labels raster')

class1_y = np.concatenate(class1_y_all)
class1_x = np.concatenate(class1_x_all)
del class1_y_all, class1_x_all; gc.collect()

# Shuffle once, then cut at the train fraction.
shuf = np.random.permutation(len(class1_y))
class1_y = class1_y[shuf]; class1_x = class1_x[shuf]; del shuf

n_train1 = int(len(class1_y) * CLASS1_TRAIN_FRACTION)
class1_train_y = class1_y[:n_train1].copy()
class1_train_x = class1_x[:n_train1].copy()
class1_test_y  = class1_y[n_train1:].copy()
class1_test_x  = class1_x[n_train1:].copy()
del class1_y, class1_x; gc.collect()

print(f'Total Class 1 pixels: {n_train1 + len(class1_test_y):,}')
print(f'  Train: {len(class1_train_y):,}  |  Test: {len(class1_test_y):,}')
print('\n⚠ NOTE: Class 1 split is NOT geographically independent (documented limitation).')

In [None]:
# CELL 5: Sample geographic classes (0, 2, 3, 4, 5) using tile-based column split
print('=' * 70, 'SAMPLING GEOGRAPHIC CLASSES (0, 2, 3, 4, 5)', '=' * 70, sep='\n')

def sample_coords_from_tiles(tile_paths, samples_per_class, split_name):
    """Sample label-raster pixel coordinates per class inside a tile set's bounding box.

    The label raster (module-level ``labels_file``) is scanned block by block in
    random order. For every class in ``samples_per_class``, matching pixels that
    fall inside the bounding box of the given tiles are collected until that
    class's budget is reached.

    Args:
        tile_paths: Paths of embedding tiles whose stems end in
            ``-<row offset>-<col offset>``.
        samples_per_class: Mapping of class id -> maximum number of pixels to collect.
        split_name: Label used in progress output and error messages.

    Returns:
        Dict ``{cls: {'y': [arrays], 'x': [arrays]}}`` holding row and column
        indices expressed in label-raster coordinates.

    Raises:
        RuntimeError: If ``tile_paths`` is empty or none of the tiles is usable.
    """
    if not tile_paths:
        raise RuntimeError(f'No tiles provided for {split_name}. Check tile discovery in Cell 3.')

    tile_info_local = []
    for tf in tile_paths:
        parts = tf.stem.split('-')
        try:
            r = int(parts[-2])
            c = int(parts[-1])
        except (ValueError, IndexError):
            print(f'  {split_name}: skipping {tf.name} — no row/col offsets in filename')
            continue
        try:
            with rasterio.open(tf) as s:
                tile_info_local.append((r, c, s.height, s.width))
        except Exception as exc:
            # Best effort: one unreadable tile must not abort sampling, but the
            # user should see it because it shrinks the bounding box.
            print(f'  {split_name}: skipping {tf.name} — cannot open ({exc})')
    if not tile_info_local:
        raise RuntimeError(f'{split_name}: could not open any tiles. Check file paths and naming.')

    # Bounding box of all usable tiles (max bounds are exclusive).
    # NOTE(review): assumes tile offsets share the label raster's pixel grid — confirm.
    bbox_row_min = min(r     for r,c,h,w in tile_info_local)
    bbox_row_max = max(r+h   for r,c,h,w in tile_info_local)
    bbox_col_min = min(c     for r,c,h,w in tile_info_local)
    bbox_col_max = max(c+w   for r,c,h,w in tile_info_local)
    print(f'  {split_name} bbox: rows {bbox_row_min}–{bbox_row_max}, cols {bbox_col_min}–{bbox_col_max}')

    sampled   = {cls: {'y': [], 'x': []} for cls in samples_per_class}
    collected = {cls: 0 for cls in samples_per_class}

    with rasterio.open(labels_file) as src:
        windows = list(src.block_windows(1))
        # Random block order so the budget is not filled from one corner of the raster.
        np.random.shuffle(windows)
        for idx, (block_id, window) in tqdm(enumerate(windows), total=len(windows), desc=f'{split_name}'):
            r0 = window.row_off; c0 = window.col_off
            rh = window.height;  cw = window.width
            # Skip blocks entirely outside the bounding box without reading them.
            if r0+rh <= bbox_row_min or r0 >= bbox_row_max: continue
            if c0+cw <= bbox_col_min or c0 >= bbox_col_max: continue
            chunk = src.read(1, window=window)
            for cls in samples_per_class:
                if collected[cls] >= samples_per_class[cls]: continue
                y_l, x_l = np.where(chunk == cls)
                if len(y_l) == 0: continue
                y_g = y_l + r0; x_g = x_l + c0
                # A block may straddle the bbox edge; keep only pixels inside it.
                in_b = (y_g>=bbox_row_min)&(y_g<bbox_row_max)&(x_g>=bbox_col_min)&(x_g<bbox_col_max)
                y_g = y_g[in_b]; x_g = x_g[in_b]
                if len(y_g) == 0: continue
                needed = samples_per_class[cls] - collected[cls]
                if len(y_g) > needed:
                    # Subsample so the class budget is never exceeded.
                    s = np.random.choice(len(y_g), needed, replace=False)
                    y_g = y_g[s]; x_g = x_g[s]
                sampled[cls]['y'].append(y_g)
                sampled[cls]['x'].append(x_g)
                collected[cls] += len(y_g)
            if all(collected[c] >= samples_per_class[c] for c in samples_per_class):
                print(f'  All samples collected after {idx+1} blocks'); break
    print(f'  {split_name} summary:')
    for cls in samples_per_class:
        print(f'    Class {cls}: {collected[cls]:,} / {samples_per_class[cls]:,}')
    return sampled

# Train is sampled before test; both calls draw from NumPy's global RNG.
train_sampled_geo = sample_coords_from_tiles(
    tile_paths=train_tiles,
    samples_per_class=train_samples_geo,
    split_name='TRAIN (geo)',
)
test_sampled_geo = sample_coords_from_tiles(
    tile_paths=test_tiles,
    samples_per_class=test_samples_geo,
    split_name='TEST (geo)',
)
print('\nGeographic class sampling complete!')

In [None]:
# CELL 6: Consolidate all coordinates; aggressively free intermediate memory

def consolidate_geo(sampled, budgets):
    """Flatten per-class coordinate chunks into parallel y / x / label arrays.

    Args:
        sampled: Dict ``{cls: {'y': [arrays], 'x': [arrays]}}`` as returned by
            ``sample_coords_from_tiles``.
        budgets: Mapping of class id -> maximum number of coordinates to keep;
            its iteration order sets the class order in the output.

    Returns:
        Tuple ``(y, x, labels)`` of 1-D arrays of equal length. Labels are
        int64. All three are empty when no class has any coordinates.
    """
    all_y, all_x, all_lbl = [], [], []
    for cls, budget in budgets.items():
        chunks = sampled[cls]
        if not chunks['y']:
            continue
        ys = np.concatenate(chunks['y'])
        xs = np.concatenate(chunks['x'])
        # Enforce the per-class cap in case more than the budget was collected.
        if len(ys) > budget:
            ys, xs = ys[:budget], xs[:budget]
        all_y.append(ys)
        all_x.append(xs)
        all_lbl.append(np.full(len(ys), cls, dtype=np.int64))

    # np.concatenate rejects an empty list; return empty arrays instead of crashing.
    if not all_lbl:
        return (np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64))
    return np.concatenate(all_y), np.concatenate(all_x), np.concatenate(all_lbl)

# Flatten the geographic samples of each split, dropping the source dict right away.
geo_train_y, geo_train_x, geo_train_lbl = consolidate_geo(train_sampled_geo, train_samples_geo)
del train_sampled_geo
gc.collect()

geo_test_y, geo_test_x, geo_test_lbl = consolidate_geo(test_sampled_geo, test_samples_geo)
del test_sampled_geo
gc.collect()

# Append the Class 1 coordinates (label 1) to the train split.
train_labels = np.concatenate([geo_train_lbl, np.full(len(class1_train_y), 1, dtype=np.int64)])
train_y = np.concatenate([geo_train_y, class1_train_y])
train_x = np.concatenate([geo_train_x, class1_train_x])
del geo_train_y, geo_train_x, geo_train_lbl, class1_train_y, class1_train_x
gc.collect()

# Same for the test split.
test_labels = np.concatenate([geo_test_lbl, np.full(len(class1_test_y), 1, dtype=np.int64)])
test_y = np.concatenate([geo_test_y, class1_test_y])
test_x = np.concatenate([geo_test_x, class1_test_x])
del geo_test_y, geo_test_x, geo_test_lbl, class1_test_y, class1_test_x
gc.collect()

# Shuffle each split with one shared permutation so y, x and labels stay aligned.
# The train permutation is drawn before the test one.
order = np.random.permutation(len(train_labels))
train_y = train_y[order]
train_x = train_x[order]
train_labels = train_labels[order]
del order

order = np.random.permutation(len(test_labels))
test_y = test_y[order]
test_x = test_x[order]
test_labels = test_labels[order]
del order
gc.collect()

print(f'Train coordinates: {len(train_labels):,}  classes: {np.unique(train_labels)}')
print(f'Test  coordinates: {len(test_labels):,}   classes: {np.unique(test_labels)}')

In [None]:
# CELL 7: Extract embeddings
# Tile reads peak at ~1.2 GB each. Persistent arrays peak at ~450 MB.
# del statements prevent OOM between train and test extraction.
print('', '=' * 70, 'EXTRACTING EMBEDDINGS', '=' * 70, sep='\n')

def _extract_from_tile(tf, y_indices, x_indices, X, found, n_bands):
    """Fill ``X`` / ``found`` in place for every requested pixel that lies inside tile ``tf``."""
    # Check the filename first: it is cheap and avoids opening files that
    # cannot be placed on the grid.
    parts = tf.stem.split('-')
    try:
        r_off = int(parts[-2])
        c_off = int(parts[-1])
    except (ValueError, IndexError):
        return

    with rasterio.open(tf) as src:
        if src.count != n_bands:
            return
        th, tw = src.height, src.width
        mask = ((y_indices >= r_off) & (y_indices < r_off + th) &
                (x_indices >= c_off) & (x_indices < c_off + tw))
        # Only pay for the full tile read when at least one pixel is needed from it.
        if not mask.any():
            return
        tile_data = src.read()
        ly = y_indices[mask] - r_off
        lx = x_indices[mask] - c_off
        vals = tile_data[:, ly, lx].T  # (pixels, bands)
        del tile_data  # free the full tile immediately

    # Pixels with NaN in any band are left as "not found".
    valid = ~np.isnan(vals).any(axis=1)
    g_idx = np.where(mask)[0]
    X[g_idx[valid]] = vals[valid]
    found[g_idx[valid]] = True


def extract_embeddings(tile_files, y_indices, x_indices, desc, n_bands=64):
    """Look up the embedding vector of each requested pixel across a set of tiles.

    Args:
        tile_files: Paths of embedding tiles whose stems end in
            ``-<row offset>-<col offset>``; tiles with other names or a band
            count different from ``n_bands`` are skipped.
        y_indices: 1-D array of row indices, in the coordinate system the tile
            offsets refer to.
        x_indices: 1-D array of column indices, same length as ``y_indices``.
        desc: Progress-bar label.
        n_bands: Expected number of bands per tile (embedding dimension).

    Returns:
        Tuple ``(X, found)``: ``X`` is float32 of shape ``(len(y_indices), n_bands)``
        and ``found`` is a boolean mask of the rows that were filled. Rows with
        ``found == False`` stay all-zero.
    """
    n = len(y_indices)
    X = np.zeros((n, n_bands), dtype=np.float32)
    found = np.zeros(n, dtype=bool)
    with tqdm(total=len(tile_files), desc=desc, unit=' tiles') as pbar:
        for tf in tile_files:
            try:
                _extract_from_tile(tf, y_indices, x_indices, X, found, n_bands)
            except Exception as e:
                # Best effort: report the failing tile and carry on with the rest.
                print(f'\nError {tf.name}: {e}')
            pbar.update(1)
            pbar.set_postfix({'found': f'{found.sum():,}/{n:,}'})
    print(f'  Extracted {found.sum():,} / {n:,}')
    return X, found


# Train split first; its temporaries are released before the test split is extracted.
print('\n-- TRAIN --')
X_train_raw, train_found = extract_embeddings(
    all_tile_files, train_y, train_x, desc='Train tiles'
)
# Keep only the samples for which an embedding was found.
X_train, y_train = X_train_raw[train_found], train_labels[train_found]
del X_train_raw, train_found, train_y, train_x, train_labels
gc.collect()  # free ~360 MB before test
print(f'X_train: {X_train.shape}')

print('\n-- TEST --')
X_test_raw, test_found = extract_embeddings(
    all_tile_files, test_y, test_x, desc='Test tiles'
)
X_test, y_test = X_test_raw[test_found], test_labels[test_found]
del X_test_raw, test_found, test_y, test_x, test_labels
gc.collect()  # free before save
print(f'X_test: {X_test.shape}')

print('\nExtraction complete!')

In [None]:
# CELL 8: Compute class weights and save to Drive
print('', '=' * 70, 'SAVING DATASET', '=' * 70, sep='\n')

# Inverse-frequency weights from the training labels, rescaled to sum to 6.
# Classes absent from y_train keep weight 0.
unique_cls, counts = np.unique(y_train, return_counts=True)
class_weights = np.zeros(6, dtype=np.float32)
class_weights[unique_cls] = 1.0 / counts
class_weights = class_weights / class_weights.sum() * 6

print('Class weights (from train):')
for cls, weight in enumerate(class_weights):
    print(f'  Class {cls}: {weight:.4f}')

# Everything is stored in one compressed archive; the split threshold is kept for reference.
np.savez_compressed(
    output_file,
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    class_weights=class_weights,
    test_col_max=np.array(TEST_COL_MAX, dtype=np.int64),
)
print(f'\nSaved: {output_file}')
print(f'  X_train: {X_train.shape}  |  X_test: {X_test.shape}')

In [None]:
# CELL 9: Verification
# Re-open the saved archive and check it; raise if anything is wrong, so that
# 'Verification passed!' is only printed for a dataset that passed the checks.
with np.load(output_file) as d:
    saved_keys = list(d.keys())
    # Each d[...] access decompresses the array again, so read each one once
    # and keep only the small summary values.
    saved_X_train = d['X_train']
    train_shape = saved_X_train.shape
    train_has_nan = bool(np.isnan(saved_X_train).any())
    del saved_X_train

    saved_X_test = d['X_test']
    test_shape = saved_X_test.shape
    test_has_nan = bool(np.isnan(saved_X_test).any())
    del saved_X_test

    saved_y_train = d['y_train']
    saved_y_test = d['y_test']

print(f'Arrays: {saved_keys}')
print(f'X_train: {train_shape}  NaN={train_has_nan}')
print(f'X_test:  {test_shape}   NaN={test_has_nan}')
print(f'y_train classes: {np.unique(saved_y_train)}')
print(f'y_test  classes: {np.unique(saved_y_test)}')

verification_problems = []
if train_has_nan:
    verification_problems.append('X_train contains NaN')
if test_has_nan:
    verification_problems.append('X_test contains NaN')
if train_shape[0] != len(saved_y_train):
    verification_problems.append(
        f'X_train has {train_shape[0]} rows but y_train has {len(saved_y_train)} labels')
if test_shape[0] != len(saved_y_test):
    verification_problems.append(
        f'X_test has {test_shape[0]} rows but y_test has {len(saved_y_test)} labels')
if len(saved_y_train) == 0 or len(saved_y_test) == 0:
    verification_problems.append('train or test split is empty')
del saved_y_train, saved_y_test

if verification_problems:
    raise RuntimeError('Verification FAILED: ' + '; '.join(verification_problems))
print('Verification passed!')