# 04_dataset

In [1]:
from pathlib import Path
import os
workdir = Path("/Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG/")
print(workdir)
os.chdir(workdir)
current_working_dir = Path.cwd()
print(f"Current working directory is now: {current_working_dir}")

/Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG
Current working directory is now: /Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [9]:
class WetlandDataset(Dataset):
    """PyTorch Dataset for wetland segmentation patches."""
    
    def __init__(self, X_path, y_path, normalize=True):
        """
        Args:
            X_path: Path to input patches numpy file (N, 7, 128, 128)
            y_path: Path to label patches numpy file (N, 128, 128)
            normalize: Whether to normalize inputs
        """
        self.X = np.load(X_path)
        self.y = np.load(y_path)
        self.normalize = normalize
        
        # Normalization parameters
        # Bands: R, G, B, NIR, NDWI, NDVI, DEM
        self.rgb_nir_max = 255.0
        self.ndwi_ndvi_range = (-1.0, 1.0)
        self.dem_range = (311.0, 410.0)  # From observed data
        self.chm_range = (0.0, 36.0)  # From observed data
        self.slp_range = (0.0, 67)
        self.tpi_range = (0.0, 12)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        X = self.X[idx].astype(np.float32).copy()
        y = self.y[idx].astype(np.int64).copy()
        
        if self.normalize:
            # R, G, B, NIR (bands 0-3): divide by 255
                #0 to 4 not including 4
            X[0:4] = X[0:4] / self.rgb_nir_max
            
            # NDWI, NDVI (bands 4-5): shift from [-1,1] to [0,1]
                #4 to 6 not including 6
            X[4:6] = (X[4:6] + 1.0) / 2.0
            
            # DEM (band 6): min-max normalize
            X[6] = (X[6] - self.dem_range[0]) / (self.dem_range[1] - self.dem_range[0])

            # CHM (band 7): min-max normalize
            X[7] = (X[7] - self.chm_range[0]) / (self.chm_range[1] - self.chm_range[0])

            # Slp (band 8): min-max normalize
            X[8] = (X[8] - self.slp_range[0]) / (self.slp_range[1] - self.slp_range[0])
        
        return torch.from_numpy(X), torch.from_numpy(y)


def get_dataloaders(train_X_path, train_y_path, val_X_path, val_y_path, batch_size=16):
    """Create training and validation DataLoaders."""
    
    train_dataset = WetlandDataset(train_X_path, train_y_path, normalize=True)
    val_dataset = WetlandDataset(val_X_path, val_y_path, normalize=True)
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0  # Set to 0 for Windows compatibility; increase on Linux
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0
    )
    
    return train_loader, val_loader

In [10]:
# === TEST THE DATASET ===
if __name__ == "__main__":
    train_loader, val_loader = get_dataloaders(
        "Data/Patches_v2/X_train.npy",
        "Data/Patches_v2/y_train.npy",
        "Data/Patches_v2/X_val.npy",
        "Data/Patches_v2/y_val.npy",
        batch_size=16
    )
    
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    
    # Get one batch and check shapes/ranges
    X_batch, y_batch = next(iter(train_loader))
    
    print(f"\nBatch X shape: {X_batch.shape}")  # Should be (16, 7, 128, 128)
    print(f"Batch y shape: {y_batch.shape}")    # Should be (16, 128, 128)
    print(f"X dtype: {X_batch.dtype}")
    print(f"y dtype: {y_batch.dtype}")
    
    print("\nNormalized band ranges:")
    band_names = ['R', 'G', 'B', 'NIR', 'NDWI', 'NDVI', 'DEM', "CHM"]
    for i, name in enumerate(band_names):
        band = X_batch[:, i, :, :]
        print(f"  {name}: min={band.min():.3f}, max={band.max():.3f}")
    
    print(f"\nLabel classes in batch: {torch.unique(y_batch).tolist()}")

Training batches: 19
Validation batches: 5

Batch X shape: torch.Size([16, 12, 256, 256])
Batch y shape: torch.Size([16, 256, 256])
X dtype: torch.float32
y dtype: torch.int64

Normalized band ranges:
  R: min=0.035, max=1.000
  G: min=0.122, max=0.984
  B: min=0.267, max=1.000
  NIR: min=0.086, max=0.988
  NDWI: min=0.330, max=0.930
  NDVI: min=0.218, max=0.740
  DEM: min=0.023, max=1.879
  CHM: min=0.000, max=0.997

Label classes in batch: [0, 1, 2, 3, 4]
