In [1]:
from pathlib import Path
import os

# Project root. Set the NYS_WETLANDS_GHG_DIR environment variable to run this
# notebook on another machine; otherwise the original author's local path is used.
workdir = Path(os.environ.get(
    "NYS_WETLANDS_GHG_DIR",
    "/Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG/",
))
os.chdir(workdir)
print(f"Current working directory: {Path.cwd()}")

# === CONFIGURATION ===
# Must match the output from NYS_03_create_patches_v2.ipynb
data_dir = Path("Data/Patches_v2")
cluster_id = 208  # Cluster to load, or None for legacy files
huc_id = None     # Specific HUC to load, or None to combine all HUCs in cluster

# Training hyperparameters
num_epochs = 2
batch_size = 10
learning_rate = 0.001
base_filters = 32  # U-Net base filter count

# Output directory for models (relative to workdir after the chdir above)
output_dir = Path("Models")
output_dir.mkdir(parents=True, exist_ok=True)

print("\nData Configuration:")
print(f"  data_dir: {data_dir}")
print(f"  cluster_id: {cluster_id}")
print(f"  huc_id: {huc_id or 'All HUCs in cluster'}")
print("\nTraining Configuration:")
print(f"  num_epochs: {num_epochs}")
print(f"  batch_size: {batch_size}")
print(f"  learning_rate: {learning_rate}")
print(f"  base_filters: {base_filters}")
print(f"  output_dir: {output_dir}")
print(f"  output_dir: {output_dir}")

Current working directory: /Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG

Data Configuration:
  data_dir: Data/Patches_v2
  cluster_id: 208
  huc_id: All HUCs in cluster

Training Configuration:
  num_epochs: 2
  batch_size: 10
  learning_rate: 0.001
  base_filters: 32
  output_dir: Models


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import json
import time
import sys

# Make the project helper modules (NYS_04_dataset, NYS_05_unet_model) importable.
# Resolve to an absolute path (robust to a later chdir) and insert it only once,
# so re-running this cell does not keep prepending duplicate sys.path entries.
script_dir = Path("Python_Code_Analysis/DL_Implement/").resolve()
if str(script_dir) not in sys.path:
    sys.path.insert(0, str(script_dir))

# Import modules
from NYS_04_dataset import get_dataloaders, find_patch_files, load_and_merge_metadata
from NYS_05_unet_model import UNet

# === LOAD DATA ===
print("Loading data...")
train_loader, val_loader, metadata = get_dataloaders(
    data_dir,
    cluster_id=cluster_id,
    huc_id=huc_id,
    batch_size=batch_size
)

print("\nDataset Summary:")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  in_channels: {metadata['in_channels']}")
print(f"  num_classes: {metadata['num_classes']}")
print(f"  band_names: {metadata['band_names']}")
if "hucs_included" in metadata:
    # NOTE(review): a recorded run listed 'metadata' entries next to real HUC codes
    # here, which suggests the HUC parsing in NYS_04_dataset misreads some file
    # names — TODO confirm in load_and_merge_metadata.
    print(f"  HUCs included: {metadata['hucs_included']}")

Current working directory: /Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG

Configuration:
  data_dir: Data/Patches_v2
  cluster_id: 208
  huc_id: All HUCs in cluster
Found 8 training file(s)
Found 8 validation file(s)

Dataset Summary:
  Total training patches: 2940
  Total validation patches: 742
  Training batches: 147
  Validation batches: 37
  HUCs included: ['metadata', '041402011002', 'metadata', '041402011004', 'metadata', 'metadata', 'metadata', 'metadata', 'metadata', 'metadata']

Batch shapes:
  X: torch.Size([16, 11, 256, 256]) (dtype: torch.float32)
  y: torch.Size([16, 256, 256]) (dtype: torch.int64)

Normalized band ranges (first batch):
  r: min=0.000, max=0.992
  g: min=0.086, max=0.996
  b: min=0.235, max=1.000
  nir: min=0.086, max=0.984
  ndvi: min=0.113, max=1.000
  ndwi: min=0.186, max=0.897
  dem: min=0.016, max=0.897
  chm: min=0.000, max=0.834
  slope_5m: min=0.000, max=0.669
  TPI_5m: min=0.241, max=0.693
  Geomorph_5m: min=0.100, max=1.000

Label clas

In [3]:
def train_one_epoch(model, train_loader, criterion, optimizer, device, log_every=10):
    """Train ``model`` for one pass over ``train_loader`` and return the average loss.

    Args:
        model: network to optimise; it is switched to training mode.
        train_loader: iterable of ``(X, y)`` batches that supports ``len()``.
        criterion: loss callable ``criterion(outputs, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.
        log_every: print the current batch loss every ``log_every`` batches;
            ``0``/``None`` disables progress output. Defaults to 10.

    Returns:
        float: mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; cannot train for an epoch")

    model.train()
    running_loss = 0.0

    for batch_idx, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device; fetch the value once and reuse it.
        batch_loss = loss.item()
        running_loss += batch_loss

        if log_every and (batch_idx + 1) % log_every == 0:
            print(f"    Batch {batch_idx + 1}/{num_batches}, Loss: {batch_loss:.4f}")

    return running_loss / num_batches


def validate(model, val_loader, criterion, device, metadata):
    """Evaluate ``model`` on ``val_loader``; return loss plus overall and per-class accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode and run without grad.
        val_loader: iterable of ``(X, y)`` batches that supports ``len()``.
        criterion: loss callable ``criterion(outputs, y)`` returning a scalar tensor.
        device: device each batch is moved to before the forward pass.
        metadata: dict providing ``"num_classes"`` and ``"class_names"``
            (names indexed by class id).

    Returns:
        tuple ``(avg_loss, overall_acc, class_acc)``:
            avg_loss (float): mean of the per-batch losses.
            overall_acc (float): correctly predicted pixels / pixels whose label is in
                ``range(num_classes)``; NaN if there are no such pixels.
            class_acc (dict[str, float]): per-class recall (pixels of class c predicted
                as c / pixels of class c); 0.0 for classes absent from the labels.

    Raises:
        ValueError: if ``val_loader`` contains no batches.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; cannot validate")

    model.eval()
    running_loss = 0.0
    num_classes = metadata["num_classes"]
    class_names = metadata["class_names"]

    # Integer counters: float32 only represents integers exactly up to 2**24, and a
    # validation split of 256x256 patches easily holds tens of millions of pixels.
    correct_per_class = torch.zeros(num_classes, dtype=torch.long)
    total_per_class = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)

            outputs = model(X)
            running_loss += criterion(outputs, y).item()

            # Predicted class per pixel = argmax over the class dimension
            preds = torch.argmax(outputs, dim=1)

            for c in range(num_classes):
                mask = (y == c)
                total_per_class[c] += mask.sum().item()
                correct_per_class[c] += ((preds == c) & mask).sum().item()

    avg_loss = running_loss / num_batches

    # Per-class accuracy as plain Python floats (consistent type for every entry)
    class_acc = {}
    for c in range(num_classes):
        total = total_per_class[c].item()
        class_acc[class_names[c]] = (
            correct_per_class[c].item() / total if total > 0 else 0.0
        )

    grand_total = total_per_class.sum().item()
    overall_acc = (
        correct_per_class.sum().item() / grand_total if grand_total > 0 else float("nan")
    )

    return avg_loss, overall_acc, class_acc

In [5]:
def compute_class_weights(data_dir, cluster_id, huc_id, class_names):
    """Compute inverse-frequency class weights from the training label patches.

    Pixel counts are accumulated file by file with ``np.bincount`` rather than
    concatenating every ``y_train`` array, so peak memory is one label file instead
    of the whole training set.

    Args:
        data_dir, cluster_id, huc_id: forwarded to ``find_patch_files`` to locate
            the ``'y_train'`` label files.
        class_names: class names indexed by class id; its length sets the number
            of classes.

    Returns:
        torch.Tensor: float32 tensor of shape ``(len(class_names),)`` whose entry
        ``c`` is the weight of class id ``c`` (the layout expected by
        ``nn.CrossEntropyLoss(weight=...)``). The most frequent class gets 1.0.
        A class with no training pixels gets the largest observed weight, so the
        tensor always has exactly one entry per class.

    Raises:
        ValueError: if no label files or labelled pixels are found, or a label
            value lies outside ``[0, len(class_names))``.
    """
    num_classes = len(class_names)
    files = find_patch_files(data_dir, cluster_id, huc_id)
    label_paths = list(files['y_train'])
    if not label_paths:
        raise ValueError("No y_train files found; cannot compute class weights")

    print("Loading y_train files for class weight computation...")
    counts = np.zeros(num_classes, dtype=np.int64)
    n_patches = 0
    patch_shape = ()
    for path in label_paths:
        labels = np.load(path)
        n_patches += labels.shape[0]
        patch_shape = labels.shape[1:]
        if labels.size == 0:
            continue
        lo, hi = labels.min(), labels.max()
        if lo < 0 or hi >= num_classes:
            raise ValueError(
                f"{path}: label values span [{lo}, {hi}] but only "
                f"{num_classes} classes are defined"
            )
        # minlength pins the output to one bin per class id, present or not.
        counts += np.bincount(
            labels.astype(np.int64, copy=False).ravel(), minlength=num_classes
        )
    print(f"  Combined y_train shape: {(n_patches, *patch_shape)}")

    total = counts.sum()
    if total == 0:
        raise ValueError("y_train files contain no labelled pixels")

    # Inverse frequency: 1 / (count / total) == total / count, for present classes only
    # (an absent class would otherwise get an infinite weight).
    present = counts > 0
    weights = np.empty(num_classes, dtype=np.float64)
    weights[present] = total / counts[present]
    weights[present] /= weights[present].min()  # Normalize so smallest weight is 1.0
    weights[~present] = weights[present].max()

    print("\nClass distribution and weights:")
    for c in range(num_classes):
        pct = counts[c] / total * 100
        note = "" if present[c] else " (absent from training data)"
        print(f"  {class_names[c]}: {counts[c]:,} pixels ({pct:.2f}%) -> weight: {weights[c]:.2f}{note}")

    return torch.tensor(weights, dtype=torch.float32)


def _make_checkpoint(model, optimizer, epoch, val_loss, val_acc):
    """Build the dict passed to ``torch.save`` for both the best and final checkpoints.

    Reads the notebook-level ``metadata``, ``cluster_id``, ``huc_id``,
    ``base_filters`` and ``learning_rate`` globals.
    """
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_acc': val_acc,
        'metadata': metadata,
        'config': {
            'cluster_id': cluster_id,
            'huc_id': huc_id,
            'base_filters': base_filters,
            'learning_rate': learning_rate,
        },
    }


def main():
    """Train the U-Net on the loaders built earlier in the notebook and save results.

    Relies on notebook globals from the configuration and data-loading cells:
    ``train_loader``, ``val_loader``, ``metadata``, ``data_dir``, ``cluster_id``,
    ``huc_id``, ``num_epochs``, ``learning_rate``, ``base_filters`` and ``output_dir``.

    Writes to ``output_dir``:
        - ``best_model.pth`` each time the validation loss improves,
        - ``final_model.pth`` after the last epoch,
        - ``training_history.npy`` (a pickled dict; load with
          ``np.load(..., allow_pickle=True).item()``) and ``training_history.json``.

    Returns:
        dict: ``{'train_loss': [...], 'val_loss': [...], 'val_acc': [...]}``,
        one entry per epoch.
    """
    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else
                          "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    # === COMPUTE CLASS WEIGHTS ===
    class_weights = compute_class_weights(
        data_dir, cluster_id, huc_id, metadata["class_names"]
    )

    # === CREATE MODEL ===
    print("\nInitializing model...")
    model = UNet(
        in_channels=metadata["in_channels"],
        num_classes=metadata["num_classes"],
        base_filters=base_filters
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # === LOSS AND OPTIMIZER ===
    class_weights = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # === TRAINING LOOP ===
    print("\nStarting training...")
    print("=" * 60)

    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    # Defined up front so the final checkpoint can be written even if num_epochs == 0.
    val_loss = val_acc = float('nan')

    for epoch in range(num_epochs):
        epoch_start = time.time()
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 40)

        # Train
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, class_acc = validate(model, val_loader, criterion, device, metadata)

        epoch_time = time.time() - epoch_start

        # Log results
        print(f"\n  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  Val Acc:    {val_acc:.4f}")
        print(f"  Time:       {epoch_time:.1f}s")
        print("  Per-class accuracy:")
        for name, acc in class_acc.items():
            print(f"    {name}: {acc:.4f}")

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Save best model ('epoch' here is the 0-based index of the epoch just run)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(
                _make_checkpoint(model, optimizer, epoch, val_loss, val_acc),
                output_dir / "best_model.pth",
            )
            print("  [Saved new best model]")

    # === SAVE FINAL MODEL AND HISTORY ===
    # 'epoch' in the final checkpoint is the number of epochs run (kept as before).
    torch.save(
        _make_checkpoint(model, optimizer, num_epochs, val_loss, val_acc),
        output_dir / "final_model.pth",
    )

    np.save(output_dir / "training_history.npy", history)
    # Plain-JSON copy of the same history, readable without allow_pickle.
    with open(output_dir / "training_history.json", "w") as fh:
        json.dump(history, fh, indent=2)

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Models saved to: {output_dir}")

    return history


# Run training (comment out to prevent execution)
history = main()

Using device: mps
Loading y_train files for class weight computation...
  Combined y_train shape: (2337, 256, 256)

Class distribution and weights:
  Background: 134,791,435 pixels (88.01%) -> weight: 1.00
  EMW: 4,141,171 pixels (2.70%) -> weight: 32.55
  FSW: 7,176,036 pixels (4.69%) -> weight: 18.78
  SSW: 4,557,607 pixels (2.98%) -> weight: 29.58
  OWW: 2,491,383 pixels (1.63%) -> weight: 54.10

Initializing model...
Total parameters: 7,768,421

Starting training...

Epoch 1/2
----------------------------------------
    Batch 10/234, Loss: 1.4849
    Batch 20/234, Loss: 1.3901
    Batch 30/234, Loss: 1.4328
    Batch 40/234, Loss: 1.2212
    Batch 50/234, Loss: 1.5189
    Batch 60/234, Loss: 1.3520
    Batch 70/234, Loss: 1.3949
    Batch 80/234, Loss: 1.3633
    Batch 90/234, Loss: 1.1958
    Batch 100/234, Loss: 1.0425
    Batch 110/234, Loss: 0.8465
    Batch 120/234, Loss: 1.1509
    Batch 130/234, Loss: 1.3416
    Batch 140/234, Loss: 1.3075
    Batch 150/234, Loss: 1.3492
  

In [None]:
# Export this notebook to a plain .py script (--to script), dropping any cells tagged "remove".
!jupyter nbconvert --to script Python_Code_Analysis/DL_Implement/NYS_06_train.ipynb --TagRemovePreprocessor.remove_cell_tags='{"remove"}'