# 04_dataset

In [5]:
from pathlib import Path
import os
import json

# Project root directory. The default is the path originally hard-coded in this
# notebook; set the NYS_WETLANDS_WORKDIR environment variable to run the
# notebook from another location without editing this cell.
# Every relative path used afterwards (e.g. "Data/Patches_v2") resolves
# against this directory because the process changes into it here.
workdir = Path(
    os.environ.get(
        "NYS_WETLANDS_WORKDIR",
        "/Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG/",
    )
)
os.chdir(workdir)
print(f"Current working directory: {Path.cwd()}")

def load_metadata(data_dir="Data/Patches_v2"):
    """Load the patch metadata stored as ``metadata.json`` inside ``data_dir``.

    Args:
        data_dir: Directory that contains ``metadata.json`` (str or path-like).
            A relative path is resolved against the current working directory.

    Returns:
        The parsed content of ``metadata.json``.

    Raises:
        FileNotFoundError: If ``metadata.json`` does not exist in ``data_dir``.
    """
    metadata_path = Path(data_dir) / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"Metadata not found at {metadata_path}.")

    # Read explicitly as UTF-8 so the result does not depend on the platform's
    # locale encoding (which open() would otherwise fall back to).
    with open(metadata_path, encoding="utf-8") as f:
        return json.load(f)

# Load metadata from the default patches directory ("Data/Patches_v2", relative
# to the current working directory) and print a short summary of its content.
metadata = load_metadata()
print(f"\nMetadata loaded:")
print(f"  in_channels: {metadata['in_channels']}")
print(f"  band_names: {metadata['band_names']}")
print(f"\nNormalization parameters:")
# One output line per entry of the 'normalization' mapping (key -> parameters).
for name, params in metadata['normalization'].items():
    print(f"  {name}: {params}")

Current working directory: /Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG

Metadata loaded:
  in_channels: 11
  band_names: ['r', 'g', 'b', 'nir', 'ndvi', 'ndwi', 'dem', 'chm', 'slope_5m', 'TPI_5m', 'Geomorph_5m']

Normalization parameters:
  r: {'type': 'divide', 'value': 255.0}
  g: {'type': 'divide', 'value': 255.0}
  b: {'type': 'divide', 'value': 255.0}
  nir: {'type': 'divide', 'value': 255.0}
  ndvi: {'type': 'shift_scale', 'shift': 1.0, 'scale': 2.0}
  ndwi: {'type': 'shift_scale', 'shift': 1.0, 'scale': 2.0}
  dem: {'type': 'minmax', 'min': 311.703369140625, 'max': 497.0373840332031}
  chm: {'type': 'minmax', 'min': 0.0, 'max': 38.871971130371094}
  slope_5m: {'type': 'minmax', 'min': 0.0, 'max': 40.392330169677734}
  TPI_5m: {'type': 'minmax', 'min': -0.0667724609375, 'max': 0.079681396484375}
  Geomorph_5m: {'type': 'minmax', 'min': 1.0, 'max': 10.0}


In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [7]:
class WetlandDataset(Dataset):
    """PyTorch Dataset for wetland segmentation patches.

    The input and label arrays are read with ``np.load`` at construction time.
    Each item is returned as a ``(X, y)`` pair of tensors: ``X`` as float32 and
    ``y`` as int64. When ``normalize`` is true, band ``i`` of ``X`` (its first
    axis) is rescaled using the normalization entry of ``band_names[i]``:

    * ``"divide"``      -> ``x / value``
    * ``"shift_scale"`` -> ``(x + shift) / scale``
    * ``"minmax"``      -> ``(x - min) / (max - min)``; a band whose ``min``
      equals ``max`` becomes all zeros instead of NaN/inf.

    Bands whose ``type`` is none of the above are returned unchanged.
    """

    def __init__(self, X_path, y_path, metadata, normalize=True):
        """
        Args:
            X_path: Path to input patches numpy file
            y_path: Path to label patches numpy file
            metadata: Metadata dict containing band_names and normalization params
            normalize: Whether to normalize inputs

        Raises:
            ValueError: If the two files hold a different number of patches.
        """
        self.X = np.load(X_path)
        self.y = np.load(y_path)
        # Fail early: __len__ is driven by X alone, so a shorter y would only
        # surface as an IndexError mid-epoch and a longer y would go unnoticed.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"Number of input patches ({len(self.X)}) does not match "
                f"number of label patches ({len(self.y)})."
            )
        self.normalize = normalize
        self.metadata = metadata
        self.band_names = metadata["band_names"]
        self.normalization = metadata["normalization"]

    def __len__(self):
        """Return the number of patches (length of the input array)."""
        return len(self.X)

    @staticmethod
    def _normalize_band(band, norm_params):
        """Return ``band`` rescaled as described by ``norm_params``."""
        kind = norm_params["type"]
        if kind == "divide":
            return band / norm_params["value"]
        if kind == "shift_scale":
            return (band + norm_params["shift"]) / norm_params["scale"]
        if kind == "minmax":
            min_val = norm_params["min"]
            span = norm_params["max"] - min_val
            if span == 0:
                # Constant band: dividing by a zero range would fill the input
                # with NaN/inf and poison the loss, so map it to zeros instead.
                return np.zeros_like(band)
            return (band - min_val) / span
        # Unrecognised normalization types are passed through untouched.
        return band

    def __getitem__(self, idx):
        """Return the ``(X, y)`` tensors for patch ``idx``.

        ``X`` is float32 (normalized per band when ``self.normalize`` is true)
        and ``y`` is int64. The arrays held by the dataset are not modified.
        """
        # astype() allocates a new C-ordered array by default-copy semantics,
        # so no additional .copy() is needed to protect the stored data.
        X = self.X[idx].astype(np.float32, order="C")
        y = self.y[idx].astype(np.int64, order="C")

        if self.normalize:
            for i, band_name in enumerate(self.band_names):
                X[i] = self._normalize_band(X[i], self.normalization[band_name])

        return torch.from_numpy(X), torch.from_numpy(y)


def get_dataloaders(data_dir="Data/Patches_v2", batch_size=16, num_workers=0):
    """Create training and validation DataLoaders using metadata.

    Expects ``data_dir`` to contain ``metadata.json`` together with
    ``X_train.npy``, ``y_train.npy``, ``X_val.npy`` and ``y_val.npy``.

    Args:
        data_dir: Directory holding the patch files and metadata.
        batch_size: Number of patches per batch for both loaders.
        num_workers: Number of worker processes passed to both DataLoaders.
            The default of 0 loads batches in the main process.

    Returns:
        Tuple ``(train_loader, val_loader, metadata)``. The training loader
        shuffles its samples; the validation loader keeps file order. Both
        datasets are created with normalization enabled.
    """
    data_dir = Path(data_dir)
    metadata = load_metadata(data_dir)

    def _dataset(split):
        # Both splits follow the same X_<split>.npy / y_<split>.npy naming.
        return WetlandDataset(
            data_dir / f"X_{split}.npy",
            data_dir / f"y_{split}.npy",
            metadata,
            normalize=True,
        )

    def _loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_loader = _loader(_dataset("train"), shuffle=True)
    val_loader = _loader(_dataset("val"), shuffle=False)

    return train_loader, val_loader, metadata

In [8]:
# === TEST THE DATASET ===
# Smoke test: build both loaders, pull a single training batch and print its
# shapes, dtypes, per-band value ranges and the label values it contains.
data_dir = "Data/Patches_v2"
train_loader, val_loader, metadata = get_dataloaders(data_dir, batch_size=16)

# len() of a DataLoader counts batches, not individual patches.
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Get one batch and check shapes/ranges
# The training loader shuffles and no seed is set here, so the batch drawn
# (and therefore the ranges and classes printed below) can differ between runs.
X_batch, y_batch = next(iter(train_loader))

print(f"\nBatch X shape: {X_batch.shape}")
print(f"Batch y shape: {y_batch.shape}")
print(f"X dtype: {X_batch.dtype}")
print(f"y dtype: {y_batch.dtype}")

print("\nNormalized band ranges:")
# Index i along dimension 1 selects one band, in metadata["band_names"] order;
# min/max are taken over the whole batch for that band.
for i, name in enumerate(metadata["band_names"]):
    band = X_batch[:, i, :, :]
    print(f"  {name}: min={band.min():.3f}, max={band.max():.3f}")

# Distinct label values present in this one batch (not the whole dataset).
print(f"\nLabel classes in batch: {torch.unique(y_batch).tolist()}")

Training batches: 19
Validation batches: 5

Batch X shape: torch.Size([16, 11, 256, 256])
Batch y shape: torch.Size([16, 256, 256])
X dtype: torch.float32
y dtype: torch.int64

Normalized band ranges:
  r: min=0.067, max=0.992
  g: min=0.133, max=0.996
  b: min=0.267, max=1.000
  nir: min=0.071, max=0.984
  ndvi: min=0.243, max=0.892
  ndwi: min=0.232, max=0.819
  dem: min=0.015, max=0.925
  chm: min=0.000, max=0.870
  slope_5m: min=0.000, max=0.851
  TPI_5m: min=0.046, max=0.814
  Geomorph_5m: min=0.000, max=1.000

Label classes in batch: [0, 1, 2, 3, 4]
