In [1]:
from pathlib import Path
import os

# Project root. Set the NYS_WETLANDS_GHG_ROOT environment variable to run this
# notebook on another machine; otherwise fall back to the original local path.
# All later relative paths (e.g. "Data/Patches_v2", "Models") resolve against it.
workdir = Path(
    os.environ.get(
        "NYS_WETLANDS_GHG_ROOT",
        "/Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG/",
    )
).expanduser()
os.chdir(workdir)
print(f"Current working directory: {Path.cwd()}")

Current working directory: /Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
import sys
import time

# Add script directory to Python path
script_dir = Path("/Users/Anthony/Data and Analysis Local/NYS_Wetlands_GHG/Python_Code_Analysis/DL_Learning")
sys.path.insert(0, str(script_dir))

# Import modules (these now use metadata)
from _04_dataset import get_dataloaders, load_metadata
from _05_unet_model import UNet

# Load metadata
data_dir = "Data/Patches_v2"
metadata = load_metadata(data_dir)

print(f"Loaded metadata from {data_dir}")
print(f"  in_channels: {metadata['in_channels']}")
print(f"  num_classes: {metadata['num_classes']}")
print(f"  patch_size: {metadata['patch_size']}")
print(f"  band_names: {metadata['band_names']}")
print(f"  class_names: {metadata['class_names']}")

Loaded metadata from Data/Patches_v2
  in_channels: 11
  num_classes: 5
  patch_size: 256
  band_names: ['r', 'g', 'b', 'nir', 'ndvi', 'ndwi', 'dem', 'chm', 'slope_5m', 'TPI_5m', 'Geomorph_5m']
  class_names: ['Background', 'EMW', 'FSW', 'SSW', 'OWW']


In [3]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device`` and pushed through ``model``. It is then
    scored with ``criterion`` and back-propagated, and ``optimizer`` takes one
    step. A progress line is printed every 10th batch.

    Returns:
        float: arithmetic mean of the per-batch loss values. Every batch counts
        equally, regardless of how many samples it holds.
    """
    model.train()
    epoch_loss = 0.0

    for step, (inputs, targets) in enumerate(train_loader, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear stale gradients, then forward -> loss -> backward -> update.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        epoch_loss += loss_value

        if step % 10 == 0:
            print(f"    Batch {step}/{len(train_loader)}, Loss: {loss_value:.4f}")

    return epoch_loss / len(train_loader)


def validate(model, val_loader, criterion, device, metadata):
    """Evaluate ``model`` on ``val_loader`` without updating weights.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable of ``(X, y)`` batches.
        criterion: Loss, applied as ``criterion(model(X), y)``.
        device: Device that each batch is moved to.
        metadata: Mapping with ``"num_classes"`` (int) and ``"class_names"``
            (sequence indexed by class id).

    Returns:
        ``(avg_loss, overall_acc, class_acc)`` where
            avg_loss: mean of the per-batch loss values;
            overall_acc: fraction of labelled pixels (labels in
                ``[0, num_classes)``) predicted correctly, or NaN if there are none;
            class_acc: ``{class_name: float}``, the fraction of that class's
                pixels predicted as that class (per-class recall). NaN when the
                class has no labelled pixels, since accuracy is undefined there.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    num_classes = metadata["num_classes"]
    class_names = metadata["class_names"]

    running_loss = 0.0
    n_batches = 0
    # Integer counters kept on the device. This avoids a host sync per class per
    # batch, and the exact-integer limit (2**24) that float32 counters would hit.
    correct_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)
    total_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)

    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)

            outputs = model(X)
            running_loss += criterion(outputs, y).item()
            n_batches += 1

            # Predicted class = argmax over dim 1 of the model output.
            preds = torch.argmax(outputs, dim=1)
            hits = preds == y
            for c in range(num_classes):
                mask = y == c
                total_per_class[c] += mask.sum()
                correct_per_class[c] += (hits & mask).sum()

    if n_batches == 0:
        raise ValueError("val_loader yielded no batches")

    avg_loss = running_loss / n_batches

    # Single transfer to host; plain Python ints from here on.
    correct = correct_per_class.cpu().tolist()
    total = total_per_class.cpu().tolist()

    class_acc = {
        class_names[c]: (correct[c] / total[c] if total[c] > 0 else float("nan"))
        for c in range(num_classes)
    }

    labelled = sum(total)
    overall_acc = sum(correct) / labelled if labelled > 0 else float("nan")

    return avg_loss, overall_acc, class_acc

In [4]:
def _class_weights_from_labels(labels, num_classes):
    """Return inverse-frequency class weights aligned with class ids 0..num_classes-1.

    Weights are ``1 / frequency``, scaled so the most frequent class gets 1.0.
    Unlike ``np.unique``, the result always has exactly ``num_classes`` entries,
    indexed by class id. A class absent from ``labels`` therefore cannot shift
    the other weights out of position, and the vector length always matches the
    model's output channels.

    Args:
        labels: Integer label array of any shape.
        num_classes: Number of classes the model predicts.

    Returns:
        ``(weights, counts)``: float64 weights and int64 per-class counts, both of
        length ``num_classes``. Absent classes get weight 0.0; they never occur
        as targets, so the value is unused by the loss.

    Raises:
        ValueError: if no label falls in ``[0, num_classes)``.
    """
    flat = np.asarray(labels).ravel().astype(np.int64, copy=False)
    # Drop out-of-range codes (e.g. a nodata value) so they cannot resize the counts.
    flat = flat[(flat >= 0) & (flat < num_classes)]
    counts = np.bincount(flat, minlength=num_classes)
    if counts.sum() == 0:
        raise ValueError("no labels in range [0, num_classes) found in training data")

    weights = np.zeros(num_classes, dtype=np.float64)
    present = counts > 0
    frequencies = counts[present] / counts.sum()
    weights[present] = 1.0 / frequencies
    weights[present] /= weights[present].min()  # most frequent class -> 1.0
    return weights, counts


def main():
    """Train the U-Net on the patches in ``data_dir`` and save checkpoints.

    Relies on names defined in earlier notebook cells: ``data_dir``,
    ``train_one_epoch``, ``validate``, ``get_dataloaders`` and ``UNet``.

    Writes to ``Models/``:
        best_model.pth        checkpoint with the lowest validation loss seen
        final_model.pth       checkpoint after the last epoch
        training_history.npy  dict of per-epoch train_loss / val_loss / val_acc
    """
    # === CONFIGURATION ===
    output_dir = Path("Models")
    output_dir.mkdir(exist_ok=True)

    num_epochs = 25
    batch_size = 10
    learning_rate = 0.001
    seed = 42

    # Seed before building loaders/model so weight initialisation is reproducible.
    # NOTE(review): assumes get_dataloaders shuffles via torch's global RNG — confirm.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else
                          "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    # === LOAD DATA (using metadata-aware dataloaders) ===
    print("\nLoading data...")
    train_loader, val_loader, meta = get_dataloaders(data_dir, batch_size=batch_size)
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    # === COMPUTE CLASS WEIGHTS FROM TRAINING DATA ===
    print("\nComputing class weights from training data...")
    y_train = np.load(Path(data_dir) / "y_train.npy")
    weights, counts = _class_weights_from_labels(y_train, meta["num_classes"])
    del y_train  # only the counts are needed; release the label array

    class_weights = torch.tensor(weights, dtype=torch.float32)
    print("Class weights:")
    for c, w in zip(meta["class_names"], weights):
        print(f"  {c}: {w:.2f}")
    missing = [name for name, n in zip(meta["class_names"], counts) if n == 0]
    if missing:
        print(f"  WARNING: classes absent from training labels (weight 0): {missing}")

    # === CREATE MODEL (using metadata for in_channels and num_classes) ===
    print("\nInitializing model...")
    model = UNet(
        in_channels=meta["in_channels"],
        num_classes=meta["num_classes"],
        base_filters=32
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # === LOSS AND OPTIMIZER ===
    class_weights = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # === TRAINING LOOP ===
    print("\nStarting training...")
    print("=" * 60)

    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    # Defined up front so the final checkpoint is still writable when num_epochs == 0.
    val_loss = val_acc = None

    for epoch in range(num_epochs):
        epoch_start = time.time()
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 40)

        # Train
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)

        # Validate (pass metadata for class names)
        val_loss, val_acc, class_acc = validate(model, val_loader, criterion, device, meta)

        epoch_time = time.time() - epoch_start

        # Log results
        print(f"\n  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  Val Acc:    {val_acc:.4f}")
        print(f"  Time:       {epoch_time:.1f}s")
        print("  Per-class accuracy:")
        for name, acc in class_acc.items():
            print(f"    {name}: {acc:.4f}")

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                # 0-based epoch index (final_model.pth stores num_epochs instead)
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
                'metadata': meta  # Save metadata with model
            }, output_dir / "best_model.pth")
            print("  [Saved new best model]")

    # === SAVE FINAL MODEL AND HISTORY ===
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_acc': val_acc,
        'metadata': meta
    }, output_dir / "final_model.pth")

    # np.save stores the dict as a 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(output_dir / "training_history.npy", history)

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Models saved to: {output_dir}")


if __name__ == "__main__":
    main()

Using device: mps

Loading data...
Training batches: 31
Validation batches: 8

Computing class weights from training data...
Class weights:
  Background: 1.00
  EMW: 23.42
  FSW: 14.10
  SSW: 14.43
  OWW: 64.02

Initializing model...
Total parameters: 7,768,421

Starting training...

Epoch 1/25
----------------------------------------
    Batch 10/31, Loss: 1.4466
    Batch 20/31, Loss: 1.5466
    Batch 30/31, Loss: 1.5079

  Train Loss: 1.4576
  Val Loss:   1.5551
  Val Acc:    0.2893
  Time:       9.9s
  Per-class accuracy:
    Background: 0.2830
    EMW: 0.3450
    FSW: 0.4080
    SSW: 0.0541
    OWW: 0.5943
  [Saved new best model]

Epoch 2/25
----------------------------------------
    Batch 10/31, Loss: 1.4278
    Batch 20/31, Loss: 1.2934
    Batch 30/31, Loss: 1.5818

  Train Loss: 1.3135
  Val Loss:   4.5355
  Val Acc:    0.3660
  Time:       7.4s
  Per-class accuracy:
    Background: 0.3316
    EMW: 0.0386
    FSW: 0.9300
    SSW: 0.0272
    OWW: 0.6491

Epoch 3/25
---------