# Wetland Mapping - Dataset Preparation
This notebook prepares the balanced dataset for wetland classification using Google Earth Engine embeddings.

## Setup Instructions:
1. Upload all of your Google Earth Engine embedding TIF files to a Kaggle dataset (55 files from `Google_Dataset/`)
2. Upload `bow_river_wetlands_10m_final.tif` (wetland labels)
3. Run this notebook to create the balanced 1.5M sample dataset
4. Download the output `wetland_dataset_1.5M.npz` for training

In [None]:
# Install dependencies
!pip install -q rasterio tqdm

In [None]:
# Clone your repository (gee_embed_CNN_dev branch) and make the checkout the
# working directory, so relative paths used by later cells (Google_Dataset/,
# build_vrt_and_verify.py, the labels TIF) resolve inside it.
# NOTE(review): re-running this cell without restarting the kernel runs the
# clone again from inside the checkout — confirm that is acceptable.
!git clone -b gee_embed_CNN_dev https://github.com/Jcub05/Wetland-Mapping-ELEC498-Group-46.git
%cd Wetland-Mapping-ELEC498-Group-46

## Step 1: Organize Your Uploaded Files
Copy your uploaded TIF files from the Kaggle input directory to the locations the repository scripts expect.

In [None]:
import os
import shutil

# Kaggle datasets are typically mounted at /kaggle/input/
# Adjust these paths based on how you name your Kaggle dataset
KAGGLE_INPUT = '/kaggle/input/wetland-embeddings'  # Change to match your dataset name
LABELS_FILENAME = 'bow_river_wetlands_10m_final.tif'

# Fail early with an actionable message instead of a bare os.listdir error.
if not os.path.isdir(KAGGLE_INPUT):
    raise FileNotFoundError(
        f"Kaggle input directory not found: {KAGGLE_INPUT}. "
        "Update KAGGLE_INPUT to match your dataset name."
    )

# Create Google_Dataset directory if it doesn't exist
os.makedirs('Google_Dataset', exist_ok=True)

# Copy all embedding TIF files (sorted so the log order is stable between runs)
print("Copying embedding tiles...")
n_copied = 0
for tile_name in sorted(os.listdir(KAGGLE_INPUT)):
    if tile_name.endswith('.tif') and 'embeddings' in tile_name:
        src = os.path.join(KAGGLE_INPUT, tile_name)
        dst = os.path.join('Google_Dataset', tile_name)
        shutil.copy(src, dst)
        n_copied += 1
        print(f"  Copied {tile_name}")

if n_copied == 0:
    # Nothing matched the '*embeddings*.tif' filter; the VRT step would have no input.
    print(f"⚠ Warning: No embedding tiles found in {KAGGLE_INPUT}")

# Copy labels file. A missing file only warns here (best effort); the loading
# cell will fail later when it tries to open it.
labels_path = os.path.join(KAGGLE_INPUT, LABELS_FILENAME)
if os.path.exists(labels_path):
    shutil.copy(labels_path, LABELS_FILENAME)
    print("✓ Copied labels file")
else:
    print("⚠ Warning: Labels file not found. Make sure it's uploaded!")

## Step 2: Build VRT (Virtual Raster)
Combine all 55 tiled TIF files into a single virtual raster

In [None]:
# Run the VRT builder from your repo
# NOTE(review): presumably this script writes bow_river_embeddings_2020_matched.vrt,
# which the loading cell opens — confirm against build_vrt_and_verify.py.
!python build_vrt_and_verify.py

## Step 3: Load Balanced Dataset
Run the optimized dataloader to create ~1.5M balanced samples

In [None]:
import rasterio
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict

# Load files
embeddings_file = "bow_river_embeddings_2020_matched.vrt"
labels_file = "bow_river_wetlands_10m_final.tif"

# Read band 1 of the label raster fully into memory.
print("Loading labels...")
with rasterio.open(labels_file) as labels_ds:
    labels_full = labels_ds.read(1)
    print(f"Labels (original): {labels_full.shape}")

# The embeddings handle is intentionally left open: a later cell reads
# `embeddings_src` and is responsible for closing it.
print(f"\nOpening embeddings VRT: {embeddings_file}")
embeddings_src = rasterio.open(embeddings_file)
print(f"Embeddings: {embeddings_src.count} bands x {embeddings_src.height} x {embeddings_src.width}")

# Crop labels to match embeddings
# NOTE(review): cropping from the top-left corner assumes both rasters share
# the same origin and pixel grid — confirm against the georeferencing.
labels = labels_full[:embeddings_src.height, :embeddings_src.width]
print(f"Labels (cropped): {labels.shape}")

# Verify alignment. An explicit check (rather than a bare `assert`) still runs
# under `python -O`, and the handle is released before failing so it does not leak.
expected_shape = (embeddings_src.height, embeddings_src.width)
if expected_shape != labels.shape:
    embeddings_src.close()
    raise AssertionError(
        f"Dimension mismatch! embeddings {expected_shape} vs labels {labels.shape}"
    )
print("✓ Dimensions match!")

In [None]:
# Analyze class distribution
# Pixels count as labeled when their value lies in the closed range 0..5.
at_least_min = labels >= 0
at_most_max = labels <= 5
valid_mask = np.logical_and(at_least_min, at_most_max)
valid_count = valid_mask.sum()
print(f"\nTotal labeled pixels: {valid_count:,} out of {labels.size:,} ({100*valid_count/labels.size:.2f}%)")

# Per-class pixel counts, restricted to the labeled pixels.
labeled_values = labels[valid_mask]
unique_classes, class_counts = np.unique(labeled_values, return_counts=True)
print("\nClass distribution:")
for position in range(len(unique_classes)):
    cls = unique_classes[position]
    count = class_counts[position]
    print(f"  Class {cls}: {count:,} pixels ({100*count/valid_count:.2f}%)")

In [None]:
# Balanced sampling strategy (~1.5M samples)
# One seeded generator drives BOTH the per-class draws and the final shuffle,
# so the whole sample set is reproducible. (Seeding only before the shuffle
# would leave the per-class draws unseeded.)
SEED = 42
rng = np.random.default_rng(SEED)

# Target number of pixels to draw per label value.
samples_per_class = {
    0: 600_000,   # Background
    1: 19_225,    # Use ALL (smallest class)
    2: 150_000,   # Moderate wetland type
    3: 500_000,   # Largest wetland class
    4: 150_000,   # Moderate wetland type
    5: 100_000,   # Moderate wetland type
}
total_target = sum(samples_per_class.values())
print(f"\nBalanced sampling strategy (target: {total_target:,} samples)")

sampled_indices_y = []
sampled_indices_x = []
sampled_labels = []

for cls in unique_classes:
    # Row/column coordinates of every pixel carrying this label.
    y_idx, x_idx = np.nonzero(labels == cls)

    n_available = len(y_idx)
    n_target = samples_per_class[cls]

    if n_available > n_target:
        # Draw without replacement so no pixel appears twice.
        sample_idx = rng.choice(n_available, n_target, replace=False)
    else:
        # Not more than the target: keep every pixel of this class.
        sample_idx = np.arange(n_available)
        if n_available < n_target:
            # Warn only on a real shortfall, not when available == target.
            print(f"  ⚠ Class {cls}: only {n_available:,} available (target: {n_target:,})")

    n_sample = len(sample_idx)
    sampled_indices_y.append(y_idx[sample_idx])
    sampled_indices_x.append(x_idx[sample_idx])
    sampled_labels.append(np.full(n_sample, cls))

    print(f"  Class {cls}: sampled {n_sample:,} / {n_available:,} pixels")

# Combine and shuffle (same permutation applied to coordinates and labels)
y_indices = np.concatenate(sampled_indices_y)
x_indices = np.concatenate(sampled_indices_x)
y = np.concatenate(sampled_labels)

shuffle_idx = rng.permutation(len(y_indices))
y_indices = y_indices[shuffle_idx]
x_indices = x_indices[shuffle_idx]
y = y[shuffle_idx]

print(f"\nTotal balanced samples: {len(y):,}")

In [None]:
# Calculate class weights for loss function
# Inverse-frequency weights, rescaled so they sum to NUM_CLASSES. A class that
# is absent from `y` keeps weight 0.
NUM_CLASSES = 6  # label values 0..5

unique_sampled, sampled_counts = np.unique(y, return_counts=True)
class_weights = torch.zeros(NUM_CLASSES)
for cls, count in zip(unique_sampled, sampled_counts):
    # Convert numpy scalars to plain Python numbers so tensor indexing does
    # not depend on the label raster's dtype.
    class_weights[int(cls)] = 1.0 / float(count)
class_weights = class_weights / class_weights.sum() * NUM_CLASSES

print("\nClass weights for loss function:")
for cls in range(NUM_CLASSES):
    print(f"  Class {cls}: {class_weights[cls]:.4f}")
print("\n💡 Use: nn.CrossEntropyLoss(weight=class_weights)")

## Step 4: Extract Embeddings (Optimized Batch Reading)
This step reads the raster one row at a time and extracts every sampled pixel in that row, giving a ~100x speedup over pixel-by-pixel reads.

In [None]:
# Extract embeddings using optimized batching
# Each raster row containing at least one sampled pixel is read once, and all
# sampled columns in that row are copied into X in a single vectorized step.
print("\nReading embeddings for sampled pixels (optimized batching)...")
n_samples = len(y_indices)

# Group samples by row: a stable sort on the row coordinate makes samples that
# share a row contiguous; np.unique then yields each row and where its run starts.
order = np.argsort(y_indices, kind="stable")
unique_rows, group_starts = np.unique(y_indices[order], return_index=True)
group_stops = np.append(group_starts[1:], n_samples)

print(f"Grouped {n_samples:,} samples into {len(unique_rows):,} unique rows")

# Open a dedicated handle for this cell so it can be re-run on its own; the
# handle opened by the loading cell may already have been closed by a previous run.
sample_count = 0
with rasterio.open(embeddings_file) as src:
    X = np.zeros((n_samples, src.count), dtype=np.float32)

    # Read row by row
    with tqdm(total=len(unique_rows), desc="Reading rows", unit=" rows") as pbar:
        for row_idx, start, stop in zip(unique_rows, group_starts, group_stops):
            row = int(row_idx)
            # Read entire row at once: window is ((row_start, row_stop), (col_start, col_stop))
            row_data = src.read(window=((row, row + 1), (0, src.width)))
            row_data = row_data[:, 0, :]  # (bands, width)

            # Positions in X / x_indices of the samples that fall in this row
            members = order[start:stop]
            X[members, :] = row_data[:, x_indices[members]].T
            sample_count += len(members)

            pbar.update(1)

# Release the handle opened by the loading cell, if it is still open.
if not embeddings_src.closed:
    embeddings_src.close()

print(f"\n✓ Loaded {sample_count:,} samples")
print(f"  X shape: {X.shape} ({X.nbytes / (1024**3):.2f} GB)")
print(f"  y shape: {y.shape}")

## Step 5: Save Dataset
Save the prepared dataset for training

In [None]:
# Save dataset
# Writes features, labels, loss weights and the per-class sampling targets into
# one compressed .npz archive.
output_file = 'wetland_dataset_1.5M.npz'
np.savez_compressed(
    output_file,
    X=X,
    y=y,
    class_weights=class_weights.numpy(),
    # Targets from the sampling plan, in dict order (not the realised counts).
    samples_per_class=np.array(list(samples_per_class.values()))
)

print(f"\n✓ Dataset saved to {output_file}")
print(f"  File size: {os.path.getsize(output_file) / (1024**3):.2f} GB")
# The loading hint is built from output_file so it stays correct if the name changes.
print("\nTo load for training:")
print(f"  data = np.load('{output_file}')")
print("  X, y = data['X'], data['y']")
print("  class_weights = torch.from_numpy(data['class_weights'])")

## Optional: Quick Dataset Verification

In [None]:
# Verify the saved dataset
# Each access to an .npz member decompresses it, so read what is needed once
# inside a context manager that also closes the archive.
with np.load(output_file) as data:
    X_shape = data['X'].shape
    y_saved = data['y']
    weights_shape = data['class_weights'].shape

print("Dataset contents:")
print(f"  X: {X_shape} - embeddings")
print(f"  y: {y_saved.shape} - labels")
print(f"  class_weights: {weights_shape}")
print("\nClass distribution in dataset:")
unique, counts = np.unique(y_saved, return_counts=True)
for cls, count in zip(unique, counts):
    print(f"  Class {cls}: {count:,} samples ({100*count/len(y_saved):.2f}%)")