# Wetland Training Dataset — Middle Row Band Split

**Output:** `wetland_dataset_middle_split.npz`

## Split Strategy

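Holds out a **horizontal strip from the middle of the map** (roughly the 40%–60% span of the tile row-offset range) as the test region. The band edges are snapped to actual tile row offsets, so the held-out share is only approximately 20%.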
Training uses tiles from both the **northern** and **southern** portions of the map.

This eliminates the east/west domain shift problem from the column-based split because:
- The model sees landscape features from **both sides** of the held-out region
- All 6 wetland classes appear in the test band (the row ranges of Classes 1 and 2 both overlap the middle strip)
- No random within-zone fallback needed — **purely geographic split for all classes**

| Class | Total pixels | Row range | In test band? |
|-------|------------|----------|---------------|
| 0 | 628M | 0–20,606 | ✅ |
| 1 | 19,225 | 764–15,197 | ✅ |
| 2 | 901,620 | 45–15,175 | ✅ |
| 3 | 14.6M | 1–20,606 | ✅ |
| 4 | 2.3M | 344–20,221 | ✅ |
| 5 | 1.5M | 3–19,112 | ✅ |

In [None]:
# CELL 1: Setup
import os, gc, shutil
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
else:
    print('Drive already mounted')

!pip install -q rasterio tqdm
import numpy as np
import rasterio
from pathlib import Path
from tqdm import tqdm

# Tile directory on the mounted Drive, and its optional mirror on the local disk.
DRIVE_TILES = '/content/drive/MyDrive/EarthEngine'
LOCAL_TILES  = '/content/EarthEngine'
# Free space on /content, expressed in GiB (bytes / 1024**3).
free_gb = shutil.disk_usage('/content').free / (1024**3)
print(f'Free disk space: {free_gb:.1f} GB')

# Decide where later cells read tiles from (TILES_PATH):
#   1. more than 110 GB free and no local directory yet -> copy from Drive, use the copy
#   2. local directory exists and holds at least one .tif -> reuse it
#   3. anything else (including an existing but empty local directory) -> read from Drive
if free_gb > 110 and not os.path.exists(LOCAL_TILES):
    print('Copying tiles to local disk...')
    # NOTE(review): the outcome of `cp` is not checked; LOCAL_TILES is selected
    # even if the copy failed or was only partial.
    !cp -r {DRIVE_TILES} {LOCAL_TILES}
    TILES_PATH = LOCAL_TILES
elif os.path.exists(LOCAL_TILES) and len(list(Path(LOCAL_TILES).glob('*.tif'))) > 0:
    TILES_PATH = LOCAL_TILES
    print('Local tiles already present.')
else:
    TILES_PATH = DRIVE_TILES
    print(f'Reading from Drive ({free_gb:.1f} GB free). Extraction ~60-90 min.')

print(f'Tile source: {TILES_PATH}')
print('Setup complete!')

In [None]:
# CELL 2: Configuration
banner = '=' * 70
print(banner)
print('CONFIGURATION')
print(banner)

# Label raster, embedding tile directory (TILES_PATH is set during setup)
# and the archive this notebook writes.
labels_file    = '/content/drive/MyDrive/bow_river_wetlands_10m_final.tif'
embeddings_dir = Path(TILES_PATH)
output_file    = '/content/drive/MyDrive/wetland_dataset_middle_split.npz'

# ─── MIDDLE ROW BAND SPLIT ──────────────────────────────────────────
# Tiles whose row_offset lies in the middle [40%, 60%] of the tile row
# range are held out for testing; everything north and south of that
# band is used for training. The ~20% band overlaps the row ranges of
# Class 1 (764-15197) and Class 2 (45-15175), so both classes can show
# up on either side of the split.
TEST_BAND_FRAC_LOW  = 0.40   # band start, as a fraction of the tile row range
TEST_BAND_FRAC_HIGH = 0.60   # band end,   as a fraction of the tile row range
# ────────────────────────────────────────────────────────────────────

# Upper bounds on the number of pixels sampled per class in each split.
train_samples_per_class = {
    0: 600_000,
    1: 14_418,    # ~75% of all 19,225 Class 1 pixels (from north + south)
    2: 150_000,
    3: 500_000,
    4: 150_000,
    5: 100_000,
}
test_samples_per_class = {
    0: 150_000,
    1: 4_806,     # remaining ~25% of Class 1
    2: 37_500,
    3: 125_000,
    4: 37_500,
    5: 25_000,
}

# Fail fast when an input is missing, before any long-running cell starts.
assert Path(labels_file).exists(), f'Labels TIF not found: {labels_file}'
assert embeddings_dir.exists(),     f'Embeddings dir not found: {embeddings_dir}'
tile_count = sum(1 for _ in embeddings_dir.glob('*.tif'))
assert tile_count > 0, f'No .tif files in {embeddings_dir}'

train_target = sum(train_samples_per_class.values())
test_target = sum(test_samples_per_class.values())

print(f'Labels:     {labels_file}')
print(f'Embeddings: {embeddings_dir}  ({tile_count} tiles)')
print(f'Output:     {output_file}')
print(f'Test band:  rows {TEST_BAND_FRAC_LOW*100:.0f}%–{TEST_BAND_FRAC_HIGH*100:.0f}% of tile row range')
print('Train:      north + south of that band')
print(f'Train target: {train_target:,}')
print(f'Test target:  {test_target:,}')
print('Configuration validated!')

In [None]:
# CELL 3: Discover tiles and split into middle band (test) vs north+south (train)
print('=' * 70)
print('DISCOVERING TILES — MIDDLE ROW BAND SPLIT')
print('=' * 70)

all_tile_files = sorted(embeddings_dir.glob('*.tif'))

# Parse (row_offset, col_offset, path) from names shaped like "<prefix>-ROW-COL.tif".
# Files that do not follow that pattern are left out of the split, but they are
# counted so the omission is visible instead of silent.
tile_info = []
unparsed_tiles = []
for tf in all_tile_files:
    parts = tf.stem.split('-')
    if len(parts) < 3:
        unparsed_tiles.append(tf)
        continue
    try:
        tile_info.append((int(parts[-2]), int(parts[-1]), tf))
    except ValueError:
        unparsed_tiles.append(tf)

if unparsed_tiles:
    print(f'Skipped {len(unparsed_tiles)} file(s) without a parseable -ROW-COL suffix')
if not tile_info:
    raise RuntimeError('No parseable tiles — check tile naming (*-ROW-COL.tif)')

# Determine test band row bounds from actual tile offsets
all_row_offsets = sorted({r for r, _, _ in tile_info})
row_min = all_row_offsets[0]
row_max = all_row_offsets[-1]
total_row_range = row_max - row_min
if total_row_range == 0:
    # With a single tile row there is nothing to hold out, and the percentage
    # report below would otherwise fail with ZeroDivisionError.
    raise RuntimeError('All tiles share one row offset — cannot define a middle row band.')

TEST_ROW_MIN = row_min + int(total_row_range * TEST_BAND_FRAC_LOW)
TEST_ROW_MAX = row_min + int(total_row_range * TEST_BAND_FRAC_HIGH)

# Snap to nearest actual tile row offsets
TEST_ROW_MIN = min(all_row_offsets, key=lambda r: abs(r - TEST_ROW_MIN))
TEST_ROW_MAX = min(all_row_offsets, key=lambda r: abs(r - TEST_ROW_MAX))

# A tile belongs to the test band when its row offset is inside the (inclusive)
# snapped bounds; every other tile is a training tile.
test_tiles  = [p for r, c, p in tile_info if TEST_ROW_MIN <= r <= TEST_ROW_MAX]
train_tiles = [p for r, c, p in tile_info if r < TEST_ROW_MIN or r > TEST_ROW_MAX]

print(f'Total tiles:      {len(tile_info)}')
print(f'Row offset range: {row_min} — {row_max}')
print(f'Test band:        rows {TEST_ROW_MIN} — {TEST_ROW_MAX} '
      f'({(TEST_ROW_MAX - TEST_ROW_MIN) / total_row_range * 100:.1f}% of range)')
print(f'Train tiles (north + south): {len(train_tiles)}')
print(f'Test  tiles (middle):        {len(test_tiles)}')

if not test_tiles:
    raise RuntimeError('No test tiles found. Adjust TEST_BAND_FRAC_LOW/HIGH.')
if not train_tiles:
    raise RuntimeError('No train tiles found.')

print('\nMiddle band split defined!')

In [None]:
# CELL 4: Sample pixel coordinates from each split
# All 6 classes are sampled geographically — no special cases needed.
print('=' * 70)
print('SAMPLING PIXEL COORDINATES')
print('=' * 70)

np.random.seed(42)


def _tile_coverage_mask(t_r0, t_c0, t_r1, t_c1, r0, c0, rh, cw):
    """Return a (rh, cw) bool mask of block pixels covered by any tile, or None.

    The tile footprints are given as four parallel arrays of end-exclusive
    bounds; the block starts at (r0, c0) and is rh x cw pixels. None means
    that no tile intersects the block at all.
    """
    hits = np.flatnonzero((t_r0 < r0 + rh) & (t_r1 > r0) &
                          (t_c0 < c0 + cw) & (t_c1 > c0))
    if hits.size == 0:
        return None
    covered = np.zeros((rh, cw), dtype=bool)
    for t in hits:
        # Clip the tile rectangle to the block and convert to block-local indices.
        covered[max(int(t_r0[t]) - r0, 0):min(int(t_r1[t]) - r0, rh),
                max(int(t_c0[t]) - c0, 0):min(int(t_c1[t]) - c0, cw)] = True
    return covered


def sample_coords_from_tiles(tile_paths, samples_per_class, split_name):
    """Sample label-raster pixel coordinates that lie inside the given tiles.

    Only pixels covered by the footprint of at least one tile in `tile_paths`
    are eligible. Filtering by the bounding box of the tiles is not enough:
    for a split made of two separate regions (north + south) the bounding box
    also spans the region between them, which would let held-out pixels leak
    into the sample.

    Reads the module-level `labels_file` and draws from the global NumPy RNG.
    Assumes the label raster and the tiles share one pixel grid, with each
    tile's top-left pixel at the (ROW, COL) encoded in its file name —
    TODO confirm against the export that produced the tiles.

    Args:
        tile_paths: paths of the tiles of this split, named `*-ROW-COL.tif`.
        samples_per_class: mapping class id -> maximum number of pixels to draw.
        split_name: label used in progress output and error messages.

    Returns:
        dict mapping class id -> {'y': [arrays], 'x': [arrays]} holding the
        global row / column indices collected per label block.

    Raises:
        RuntimeError: if `tile_paths` is empty or none of the tiles can be read.
    """
    if not tile_paths:
        raise RuntimeError(f'No tiles provided for {split_name}.')

    # One end-exclusive (row0, col0, row1, col1) rectangle per readable tile.
    footprints = []
    for tf in tile_paths:
        parts = tf.stem.split('-')
        try:
            r = int(parts[-2]); c = int(parts[-1])
            with rasterio.open(tf) as s:
                footprints.append((r, c, r + s.height, c + s.width))
        except Exception as e:
            # Stay best-effort, but report the tile instead of dropping it silently.
            print(f'  {split_name}: skipping {tf.name}: {e}')
    if not footprints:
        raise RuntimeError(f'{split_name}: could not open any tiles.')

    t_r0, t_c0, t_r1, t_c1 = np.asarray(footprints, dtype=np.int64).T
    bbox_row_min = int(t_r0.min()); bbox_row_max = int(t_r1.max())
    bbox_col_min = int(t_c0.min()); bbox_col_max = int(t_c1.max())
    print(f'  {split_name} bbox: rows {bbox_row_min}–{bbox_row_max}, cols {bbox_col_min}–{bbox_col_max}')

    sampled   = {cls: {'y': [], 'x': []} for cls in samples_per_class}
    collected = {cls: 0 for cls in samples_per_class}

    with rasterio.open(labels_file) as src:
        # Visit label blocks in random order so samples are spread over the region.
        windows = list(src.block_windows(1))
        np.random.shuffle(windows)
        for idx, (block_id, window) in tqdm(enumerate(windows), total=len(windows), desc=split_name):
            r0 = window.row_off; c0 = window.col_off
            rh = window.height;  cw = window.width
            covered = _tile_coverage_mask(t_r0, t_c0, t_r1, t_c1, r0, c0, rh, cw)
            if covered is None:
                continue  # block lies outside every tile of this split; skip the read
            chunk = src.read(1, window=window)
            for cls in samples_per_class:
                needed = samples_per_class[cls] - collected[cls]
                if needed <= 0:
                    continue
                y_l, x_l = np.nonzero((chunk == cls) & covered)
                if len(y_l) == 0:
                    continue
                y_g = y_l + r0; x_g = x_l + c0
                if len(y_g) > needed:
                    # Block offers more than the remaining budget: keep a random subset.
                    s = np.random.choice(len(y_g), needed, replace=False)
                    y_g = y_g[s]; x_g = x_g[s]
                sampled[cls]['y'].append(y_g)
                sampled[cls]['x'].append(x_g)
                collected[cls] += len(y_g)
            if all(collected[c] >= samples_per_class[c] for c in samples_per_class):
                print(f'  All collected after {idx+1} blocks'); break
    print(f'  {split_name} summary:')
    for cls in samples_per_class:
        print(f'    Class {cls}: {collected[cls]:,} / {samples_per_class[cls]:,}')
    return sampled

# Sample label-pixel coordinates for each split, capped by the per-class budgets.
# Both calls draw from the global NumPy RNG seeded above, so the call order
# (train first, then test) is part of what makes the result reproducible.
train_sampled = sample_coords_from_tiles(train_tiles, train_samples_per_class, 'TRAIN (north+south)')
test_sampled  = sample_coords_from_tiles(test_tiles,  test_samples_per_class,  'TEST  (middle band)')
print('\nCoordinate sampling complete!')

In [None]:
# CELL 5: Consolidate and free memory

def consolidate(sampled, budgets):
    """Flatten per-class coordinate chunks into three parallel 1-D arrays.

    Args:
        sampled: mapping class id -> {'y': [arrays], 'x': [arrays]}.
        budgets: mapping class id -> maximum number of samples kept for it;
            its key order defines the class order of the output.

    Returns:
        (rows, cols, labels): concatenated row indices, column indices and
        int64 class labels. Classes without any chunk are left out.

    Raises:
        ValueError: from np.concatenate when no class has any chunk.
    """
    row_parts, col_parts, label_parts = [], [], []
    for cls, cap in budgets.items():
        chunks = sampled[cls]
        if not chunks['y']:
            continue
        rows = np.concatenate(chunks['y'])
        cols = np.concatenate(chunks['x'])
        if len(rows) > cap:
            # Enforce the budget by keeping the leading `cap` samples.
            rows = rows[:cap]
            cols = cols[:cap]
        row_parts.append(rows)
        col_parts.append(cols)
        label_parts.append(np.full(len(rows), cls, dtype=np.int64))
    merged_rows = np.concatenate(row_parts)
    merged_cols = np.concatenate(col_parts)
    merged_labels = np.concatenate(label_parts)
    return merged_rows, merged_cols, merged_labels

# Flatten each split into parallel (row, col, label) arrays. The per-class
# chunk lists are released right after they have been merged.
train_y, train_x, train_labels = consolidate(train_sampled, train_samples_per_class)
del train_sampled
gc.collect()
test_y, test_x, test_labels = consolidate(test_sampled, test_samples_per_class)
del test_sampled
gc.collect()

# consolidate() returns the samples grouped by class; shuffle each split with
# one shared permutation so rows, columns and labels stay aligned. The train
# permutation is drawn before the test permutation.
order = np.random.permutation(len(train_labels))
train_y = train_y[order]
train_x = train_x[order]
train_labels = train_labels[order]
del order
order = np.random.permutation(len(test_labels))
test_y = test_y[order]
test_x = test_x[order]
test_labels = test_labels[order]
del order
gc.collect()

train_classes = sorted(np.unique(train_labels).tolist())
test_classes = sorted(np.unique(test_labels).tolist())
print(f'Train: {len(train_labels):,}  classes: {train_classes}')
print(f'Test:  {len(test_labels):,}   classes: {test_classes}')

In [None]:
# CELL 6: Extract embeddings
print('\n' + '=' * 70)
print('EXTRACTING EMBEDDINGS')
print('=' * 70)


def _copy_tile_embeddings(tf, y_indices, x_indices, X, found, n_bands):
    """Fill X / found in place for the requested pixels that fall inside one tile.

    The tile is skipped when its name has no integer -ROW-COL suffix, when its
    band count differs from n_bands, or when none of the requested pixels lie
    inside it. Pixels with NaN in any band are left unfilled.
    """
    parts = tf.stem.split('-')
    try:
        r_off = int(parts[-2]); c_off = int(parts[-1])
    except (ValueError, IndexError):
        return  # file name does not end in -ROW-COL
    with rasterio.open(tf) as src:
        if src.count != n_bands:
            return
        inside = ((y_indices >= r_off) & (y_indices < r_off + src.height) &
                  (x_indices >= c_off) & (x_indices < c_off + src.width))
        if not inside.any():
            return  # nothing requested from this tile; avoid reading it
        g_idx = np.flatnonzero(inside)
        tile_data = src.read()
        # Fancy indexing copies the values, so the full tile can be released at once.
        vals = tile_data[:, y_indices[g_idx] - r_off, x_indices[g_idx] - c_off].T
        del tile_data
    valid = ~np.isnan(vals).any(axis=1)
    X[g_idx[valid]] = vals[valid]
    found[g_idx[valid]] = True


def extract_embeddings(tile_files, y_indices, x_indices, desc, n_bands=64):
    """Look up the embedding vector of every requested pixel in the given tiles.

    Args:
        tile_files: tile paths named `*-ROW-COL.tif`; ROW / COL are taken as the
            global pixel offset of the tile's top-left corner.
        y_indices: 1-D array of global row indices.
        x_indices: 1-D array of global column indices, parallel to y_indices.
        desc: label shown on the progress bar.
        n_bands: number of bands a tile must have to be used (default 64).

    Returns:
        (X, found): X is an (n, n_bands) float32 array and found an (n,) bool
        array; rows of X where found is False were not filled and stay zero.

    A tile that raises while being read is reported and skipped, so one bad
    file does not abort the whole extraction.
    """
    n = len(y_indices)
    X = np.zeros((n, n_bands), dtype=np.float32)
    found = np.zeros(n, dtype=bool)
    with tqdm(total=len(tile_files), desc=desc, unit=' tiles') as pbar:
        for tf in tile_files:
            try:
                _copy_tile_embeddings(tf, y_indices, x_indices, X, found, n_bands)
            except Exception as e:
                print(f'\nError {tf.name}: {e}')
            pbar.update(1)
            pbar.set_postfix({'found': f'{found.sum():,}/{n:,}'})
    print(f'  Extracted {found.sum():,} / {n:,}')
    return X, found

# Each split reads embeddings only from its own tiles. A coordinate that lies
# outside its split's tiles is left unfound and then dropped by the `found`
# mask, so the train set cannot pick up pixels from the held-out band
# (and the test set cannot pick up pixels from the training regions).
print('\n-- TRAIN --')
X_train_raw, train_found = extract_embeddings(train_tiles, train_y, train_x, 'Train tiles')
X_train = X_train_raw[train_found]
y_train = train_labels[train_found]
# Release the coordinate arrays and the unfiltered matrix before the next split.
del X_train_raw, train_found, train_y, train_x, train_labels; gc.collect()
print(f'X_train: {X_train.shape}  classes: {sorted(np.unique(y_train).tolist())}')

print('\n-- TEST --')
X_test_raw, test_found = extract_embeddings(test_tiles, test_y, test_x, 'Test tiles')
X_test = X_test_raw[test_found]
y_test = test_labels[test_found]
del X_test_raw, test_found, test_y, test_x, test_labels; gc.collect()
print(f'X_test: {X_test.shape}  classes: {sorted(np.unique(y_test).tolist())}')

print('\nExtraction complete!')

In [None]:
# CELL 7: Compute class weights and save
rule = '=' * 70
print('\n' + rule)
print('SAVING DATASET')
print(rule)

# Inverse-frequency weights over the training labels, rescaled so the six
# weights sum to 6. A class that is absent from y_train keeps weight 0.
unique_cls, counts = np.unique(y_train, return_counts=True)
class_weights = np.zeros(6, dtype=np.float32)
class_weights[unique_cls] = 1.0 / counts
class_weights = class_weights / class_weights.sum() * 6

print('Class weights (from train):')
for cls, weight in enumerate(class_weights):
    print(f'  Class {cls}: {weight:.4f}')

# Everything written to the archive; the test band bounds are stored as
# 0-d int64 arrays next to the data.
payload = {
    'X_train': X_train,
    'y_train': y_train,
    'X_test': X_test,
    'y_test': y_test,
    'class_weights': class_weights,
    'test_row_min': np.array(TEST_ROW_MIN, dtype=np.int64),
    'test_row_max': np.array(TEST_ROW_MAX, dtype=np.int64),
}
np.savez_compressed(output_file, **payload)

print(f'\nSaved: {output_file}')
print(f'  X_train: {X_train.shape}  |  X_test: {X_test.shape}')
print(f'  test_row_min: {TEST_ROW_MIN}  test_row_max: {TEST_ROW_MAX}')

In [None]:
# CELL 8: Verification
# Reload the archive, print a summary, and actually check it before reporting
# success. The context manager closes the file even if a read fails, and each
# array is decompressed only once.
with np.load(output_file) as d:
    array_names = list(d.keys())
    X_train_saved, y_train_saved = d['X_train'], d['y_train']
    X_test_saved, y_test_saved = d['X_test'], d['y_test']
    saved_row_min = int(d['test_row_min'])
    saved_row_max = int(d['test_row_max'])

train_has_nan = np.isnan(X_train_saved).any()
test_has_nan = np.isnan(X_test_saved).any()

print(f'Arrays: {array_names}')
print(f'X_train: {X_train_saved.shape}  NaN={train_has_nan}')
print(f'X_test:  {X_test_saved.shape}   NaN={test_has_nan}')
print(f'y_train classes: {sorted(np.unique(y_train_saved).tolist())}')
print(f'y_test  classes: {sorted(np.unique(y_test_saved).tolist())}')
print(f'test_row_min: {saved_row_min}  test_row_max: {saved_row_max}')

# Collect every failed check so that one run reports all problems at once.
problems = []
if train_has_nan:
    problems.append('X_train contains NaN')
if test_has_nan:
    problems.append('X_test contains NaN')
if len(X_train_saved) != len(y_train_saved):
    problems.append(f'X_train has {len(X_train_saved)} rows but y_train has {len(y_train_saved)}')
if len(X_test_saved) != len(y_test_saved):
    problems.append(f'X_test has {len(X_test_saved)} rows but y_test has {len(y_test_saved)}')
if len(y_train_saved) == 0 or len(y_test_saved) == 0:
    problems.append('train or test split is empty')
if saved_row_min > saved_row_max:
    problems.append('test_row_min is greater than test_row_max')

del X_train_saved, y_train_saved, X_test_saved, y_test_saved
if problems:
    raise RuntimeError('Verification failed: ' + '; '.join(problems))
print('Verification passed!')