# UNet Model Testing Notebook

This notebook tests the UNet model for urban crop yield prediction. The UNet architecture is used for image-to-image prediction, converting environmental data into crop yield predictions.

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import sys

# Make the current working directory importable so the `src.*` imports below
# resolve (this assumes the notebook is started from the project root).
# The membership check keeps repeated executions of this cell from piling up
# duplicate '.' entries on sys.path.
if '.' not in sys.path:
    sys.path.append('.')

# Import custom modules
from src.models.unet import UNet
from src.data.dataset import UrbanCropDataset
from src.training import ModelTrainer
from src.utils import set_seed, get_device, ensure_dir, visualize_prediction, calculate_metrics

# Seed the random number generators through the project helper so that runs
# are repeatable (which RNGs are covered is up to `set_seed` -- see src.utils).
set_seed(42)

# Resolve the compute device once; later cells read this module-level `device`.
device = get_device()
print("Using device:", device)

## 2. Configuration

In [None]:
# Central settings for this notebook: paths, optimisation and data splits.
config = dict(
    data_dir='data/processed',        # where the preprocessed .npy arrays live
    model_dir='models',               # checkpoints / final weights are written here
    visualize_dir='visualizations',   # saved figures go here
    batch_size=16,                    # samples per mini-batch
    epochs=10,                        # kept small for a quick test run
    learning_rate=1e-3,               # optimiser step size
    test_split=0.2,                   # share of ALL samples held out for testing
    val_split=0.1,                    # share of the remaining (non-test) samples used for validation
)

# Output directories must exist before anything is saved into them.
ensure_dir(config['model_dir'])
ensure_dir(config['visualize_dir'])

## 3. Load Data

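This notebook expects the preprocessed data to already be saved as NumPy arrays in `data/processed`: `years_array_32_segmented_prevUrb.npy` (inputs) and `crops_array_32_segmented_prevUrb.npy` (targets). If these files are missing, the cell below falls back to random dummy data so the rest of the notebook can still run. To get meaningful results, run the data processing scripts first.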

In [None]:
# Load processed data
def load_data(config):
    """Read the preprocessed input and target arrays from ``config['data_dir']``.

    Args:
        config: Dict providing ``data_dir``, the directory that holds
            ``years_array_32_segmented_prevUrb.npy`` (inputs) and
            ``crops_array_32_segmented_prevUrb.npy`` (targets).

    Returns:
        Tuple ``(X_data, y_data)`` of the arrays loaded with ``np.load``.

    Raises:
        FileNotFoundError: If either of the two files is missing.
    """
    data_dir = config['data_dir']
    features_file = os.path.join(data_dir, 'years_array_32_segmented_prevUrb.npy')
    targets_file = os.path.join(data_dir, 'crops_array_32_segmented_prevUrb.npy')

    # Guard clause: both files are required, so fail before loading either.
    if not (os.path.exists(features_file) and os.path.exists(targets_file)):
        raise FileNotFoundError(
            f"Preprocessed data not found at {features_file} and {targets_file}. "
            "Please run data preprocessing first."
        )

    print("Loading preprocessed data...")
    features = np.load(features_file)
    targets = np.load(targets_file)
    print(f"X_data shape: {features.shape}")
    print(f"y_data shape: {targets.shape}")
    return features, targets

# Load the real data when available; otherwise fall back to random dummy arrays
# so the remaining cells can still be run end to end.
try:
    X_data, y_data = load_data(config)
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Using dummy data for demonstration purposes...")
    # Create dummy data for demonstration
    X_data = np.random.rand(100, 15, 32, 32)  # [samples, channels, height, width]
    y_data = np.random.rand(100, 3, 32, 32)   # [samples, channels, height, width]
    print(f"Dummy X_data shape: {X_data.shape}")
    print(f"Dummy y_data shape: {y_data.shape}")
else:
    # Only the load sits inside `try`: a failure while plotting must never be
    # mistaken for missing data and silently replace real data with dummy data.
    plt.figure(figsize=(15, 5))

    # Input features (first sample, first 3 channels), channels-first -> HWC
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(X_data[0][:3], (1, 2, 0)))
    plt.title('Input Features (First 3 Channels)')
    plt.axis('off')

    # Target output (first sample), channels-first -> HWC
    plt.subplot(1, 2, 2)
    plt.imshow(np.transpose(y_data[0], (1, 2, 0)))
    plt.title('Target Output')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

## 4. Create Data Loaders

Split the data into training, validation, and test sets, and create data loaders.

In [None]:
def create_data_loaders(X_data, y_data, config):
    """Split the data into train/validation/test subsets and wrap each in a DataLoader.

    The test fraction (``config['test_split']``) is taken from all samples; the
    validation fraction (``config['val_split']``) is then taken from what
    remains. Both ``random_split`` calls use a generator seeded with 42, so the
    partition is reproducible across runs and a saved checkpoint is always
    evaluated on the same held-out test samples.

    Args:
        X_data: Input array passed to ``UrbanCropDataset``.
        y_data: Target array passed to ``UrbanCropDataset``.
        config: Dict with ``test_split``, ``val_split`` and ``batch_size``.
            Optional ``num_workers`` (default 2) sets the number of DataLoader
            worker processes; use 0 to load in the main process.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.

    Raises:
        ValueError: If the split leaves no samples for training.
    """
    dataset = UrbanCropDataset(X_data, y_data)

    total_samples = len(dataset)
    test_size = int(total_samples * config['test_split'])
    train_val_size = total_samples - test_size
    val_size = int(train_val_size * config['val_split'])
    train_size = train_val_size - val_size

    print(f"Total samples: {total_samples}")
    print(f"Training samples: {train_size}")
    print(f"Validation samples: {val_size}")
    print(f"Test samples: {test_size}")

    if train_size <= 0:
        # A shuffled DataLoader over an empty subset would otherwise fail later
        # with a much less helpful sampler error.
        raise ValueError(
            f"No training samples left after splitting {total_samples} samples "
            f"(test_split={config['test_split']}, val_split={config['val_split']})."
        )

    # Two-stage split: hold out the test set first, then carve validation out
    # of the remainder. Fixed seeds keep the partition stable between runs.
    train_val_dataset, test_dataset = random_split(
        dataset, [train_val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )
    train_dataset, val_dataset = random_split(
        train_val_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )

    batch_size = config['batch_size']
    num_workers = config.get('num_workers', 2)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader, test_loader


# Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(X_data, y_data, config)

## 5. Initialize UNet Model

In [None]:
# Initialize UNet model
def initialize_model(config):
    """Create a UNet and a compiled ``ModelTrainer`` for it.

    Args:
        config: Dict with ``learning_rate``. Optional ``in_channels``
            (default 15) and ``out_channels`` (default 3) set the UNet's channel
            counts; they should match axis 1 of ``X_data`` / ``y_data``
            (assumed ``[samples, channels, height, width]``).

    Returns:
        Tuple ``(model, trainer)``. The trainer is constructed with the
        notebook-global ``device`` and compiled with Adam and an MSE criterion.
    """
    model = UNet(
        in_channels=config.get('in_channels', 15),
        out_channels=config.get('out_channels', 3),
    )
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model created with {n_trainable} trainable parameters")

    trainer = ModelTrainer(model, device=device)
    trainer.compile(
        optimizer='adam',
        learning_rate=config['learning_rate'],
        criterion='mse'
    )
    return model, trainer

# Initialize model and trainer
model, trainer = initialize_model(config)

## 6. Train the Model

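Train the UNet model for a few epochs; the training call below is commented out, so uncomment it to run. The best model by validation loss is checkpointed, and early stopping (patience 5) is applied on validation loss. For a full training run, increase `epochs` in the configuration.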

In [None]:
def train_model(trainer, train_loader, val_loader, config):
    """Fit ``trainer`` with checkpoint and early-stopping callbacks, then save final weights.

    The ``model_checkpoint`` callback monitors ``val_loss`` and is asked to
    keep only the best model at ``<model_dir>/unet_best.pth``. After training
    ends, normally or through Ctrl+C, the current state is saved to
    ``<model_dir>/unet_final.pth``.

    Args:
        trainer: Compiled ``ModelTrainer``.
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        config: Dict with ``model_dir`` and ``epochs``. Optional ``patience``
            (default 5) is passed to the ``early_stopping`` callback.

    Returns:
        The same ``trainer`` object.
    """
    from src.training import model_checkpoint, early_stopping

    checkpoint_path = os.path.join(config['model_dir'], 'unet_best.pth')
    callbacks = [
        model_checkpoint(checkpoint_path, monitor='val_loss', save_best_only=True),
        early_stopping(patience=config.get('patience', 5), monitor='val_loss'),
    ]

    print(f"Starting training for {config['epochs']} epochs...")
    try:
        trainer.fit(
            train_loader,
            val_loader=val_loader,
            epochs=config['epochs'],
            callbacks=callbacks
        )
    except KeyboardInterrupt:
        # Deliberately not re-raised: keep the progress made so far and still
        # save the final weights below.
        print("Training interrupted by user")

    model_path = os.path.join(config['model_dir'], 'unet_final.pth')
    trainer.save_model(model_path)

    return trainer

# Train the model (uncomment to run training)
# trainer = train_model(trainer, train_loader, val_loader, config)

## 7. Plot Training History

After training, plot the loss and metrics over epochs.

In [None]:
# Plot training history (run after training)
def plot_history(trainer, config):
    """Save the trainer's training-history plot as ``<visualize_dir>/unet_history.png``.

    The plotting itself is delegated to ``trainer.plot_history``.

    Args:
        trainer: Object exposing ``plot_history(save_path=...)``.
        config: Dict providing ``visualize_dir``.
    """
    trainer.plot_history(
        save_path=os.path.join(config['visualize_dir'], 'unet_history.png')
    )
    
# If you've trained the model, uncomment to plot the history
# plot_history(trainer, config)

## 8. Load Best Model and Evaluate

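Load the best model from its checkpoint (`models/unet_best.pth`). If no checkpoint exists, an untrained model is returned instead. The loaded model is evaluated on the test set in the next section.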

In [None]:
def load_best_model(config):
    """Build a fresh UNet and trainer, then restore weights from the best checkpoint if present.

    Args:
        config: Dict with ``learning_rate`` and ``model_dir``. Optional
            ``in_channels`` (default 15) and ``out_channels`` (default 3) must
            match the architecture the checkpoint was saved from.

    Returns:
        Tuple ``(model, trainer)``. If ``<model_dir>/unet_best.pth`` does not
        exist, a message is printed and the untrained model is returned.
    """
    model = UNet(
        in_channels=config.get('in_channels', 15),
        out_channels=config.get('out_channels', 3),
    )
    trainer = ModelTrainer(model, device=device)
    trainer.compile(optimizer='adam', learning_rate=config['learning_rate'], criterion='mse')

    checkpoint_path = os.path.join(config['model_dir'], 'unet_best.pth')
    if os.path.exists(checkpoint_path):
        print(f"Loading model from {checkpoint_path}")
        trainer.load_model(checkpoint_path)
    else:
        print(f"Checkpoint not found at {checkpoint_path}. Using untrained model.")

    return model, trainer

# Load the best model
# Uncomment after training or if you have a saved model
# best_model, best_trainer = load_best_model(config)

## 9. Evaluate on Test Set

In [None]:
def evaluate_model(trainer, test_loader, config):
    """Print test-set loss and MAE, and save a figure of five sample predictions.

    Args:
        trainer: ``ModelTrainer`` whose ``evaluate`` returns ``(loss, metrics)``;
            ``metrics`` must contain an ``'mae'`` entry.
        test_loader: DataLoader over the held-out test split.
        config: Dict providing ``visualize_dir`` for the saved figure
            (``unet_predictions.png``).

    Returns:
        The ``(test_loss, test_metrics)`` pair produced by ``trainer.evaluate``.
    """
    print("Evaluating model on test set...")
    loss, metrics = trainer.evaluate(test_loader)
    print(f"Test Loss: {loss:.4f}, Test MAE: {metrics['mae']:.4f}")

    figure_path = os.path.join(config['visualize_dir'], 'unet_predictions.png')
    visualize_prediction(trainer.model, test_loader, num_samples=5, save_path=figure_path)

    return loss, metrics

# Evaluate on test set
# Uncomment after loading a trained model
# test_loss, test_metrics = evaluate_model(best_trainer, test_loader, config)

## 10. Make Predictions on New Data

Use the trained model to make a prediction for an individual sample of `X_data` and plot it next to its input.

In [None]:
def predict_on_sample(model, X_data, idx=0):
    """Run ``model`` on sample ``idx`` of ``X_data`` and plot the input next to the prediction.

    The model is switched to eval mode and moved to the notebook-global
    ``device`` (both changes persist after the call). The input panel shows
    the first three channels of the sample; both panels are transposed from
    channels-first to height x width x channels for ``imshow``.

    Args:
        model: A ``torch.nn.Module`` taking float32 batches.
        X_data: Array indexed as ``[sample, channel, height, width]``
            (assumed from the transposes below).
        idx: Index of the sample to predict on.

    Returns:
        The prediction as a NumPy array with a leading batch dimension of 1.
    """
    model.eval()
    model.to(device)

    # Slicing (not indexing) keeps the batch dimension.
    batch = torch.tensor(X_data[idx:idx + 1], dtype=torch.float32).to(device)
    with torch.no_grad():
        output = model(batch)
    output = output.cpu().numpy()

    plt.figure(figsize=(10, 5))
    panels = (
        (X_data[idx][:3], 'Input (First 3 Channels)'),
        (output[0], 'Model Prediction'),
    )
    for position, (chw_image, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(np.transpose(chw_image, (1, 2, 0)))
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    return output

# Make a prediction on a sample
# Uncomment after loading a trained model
# sample_idx = 10  # Change to try different samples
# prediction = predict_on_sample(best_model, X_data, idx=sample_idx)

## 11. Conclusion

This notebook demonstrated how to load, train, evaluate, and use the UNet model for urban crop yield prediction. For a full training run, increase the number of epochs and consider tuning hyperparameters (e.g. learning rate and batch size) for better performance.