# LSTM-UNet Model Testing Notebook

This notebook tests the LSTM-UNet hybrid model for urban crop yield prediction. This model combines LSTM for temporal processing with U-Net for spatial processing, making it suitable for spatio-temporal data.

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import sys

# Make the project root importable so `src.*` resolves. The membership check
# keeps re-runs of this cell from appending duplicate '.' entries to sys.path.
if '.' not in sys.path:
    sys.path.append('.')

# Import custom modules
from src.models.lstm_unet import LSTMUNet
from src.data.dataset import TemporalUrbanCropDataset
from src.training import LSTMUNetTrainer
from src.utils import set_seed, get_device, ensure_dir, visualize_prediction, calculate_metrics

# Seed the random number generators for reproducibility (set_seed is
# project-defined; presumably covers numpy and torch — confirm in src.utils).
set_seed(42)

# Pick the compute device once; later cells reuse this global `device`.
device = get_device()
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# Notebook-wide settings read by the cells below.
config = dict(
    data_dir='data/processed',        # processed .npy arrays are read from here
    model_dir='models',               # checkpoints are written here
    visualize_dir='visualizations',   # figures and animations are written here
    batch_size=16,                    # mini-batch size for every DataLoader
    epochs=10,                        # kept small for a quick test run
    learning_rate=0.001,              # optimizer learning rate
    test_split=0.2,                   # fraction of all samples held out for testing
    val_split=0.1,                    # fraction of the remaining (non-test) samples used for validation
    time_steps=23,                    # sequence length used for the LSTM input
)

# Make sure the output directories exist before anything is saved into them.
for _output_key in ('model_dir', 'visualize_dir'):
    ensure_dir(config[_output_key])

## 3. Load Data

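For this notebook, we assume that the data has already been processed and saved as NumPy arrays in `config['data_dir']`. If those files are missing, the cell below falls back to random dummy data so the remaining cells can still run. Run the data processing scripts first to get meaningful results.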

In [None]:
# Load processed data
def load_data(config):
    """Load the preprocessed input/target arrays from ``config['data_dir']``.

    Args:
        config: Settings dict; only ``config['data_dir']`` is read.

    Returns:
        Tuple ``(X_data, y_data)`` of numpy arrays loaded with ``np.load`` from
        ``years_array_32_segmented_prevUrb.npy`` and
        ``crops_array_32_segmented_prevUrb.npy`` respectively.

    Raises:
        FileNotFoundError: If either file is missing. The message names only
            the file(s) that are actually absent.
        ValueError: If the two arrays have a different number of samples
            (length of the first axis), which would misalign inputs and targets.
    """
    x_path = os.path.join(config['data_dir'], 'years_array_32_segmented_prevUrb.npy')
    y_path = os.path.join(config['data_dir'], 'crops_array_32_segmented_prevUrb.npy')

    missing = [path for path in (x_path, y_path) if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(
            f"Preprocessed data not found at {' and '.join(missing)}. "
            "Please run data preprocessing first."
        )

    print("Loading preprocessed data...")
    X_data = np.load(x_path)
    y_data = np.load(y_path)
    print(f"X_data shape: {X_data.shape}")
    print(f"y_data shape: {y_data.shape}")

    # Fail fast: a sample-count mismatch would silently pair wrong inputs/targets.
    if len(X_data) != len(y_data):
        raise ValueError(
            f"X_data has {len(X_data)} samples but y_data has {len(y_data)}."
        )
    return X_data, y_data

try:
    X_data, y_data = load_data(config)
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Using dummy data for demonstration purposes...")
    # Create dummy data for demonstration, laid out as
    # [samples, channels, height, width].
    # Note: For LSTM-UNet, we'll reshape this later.
    # The model setup derives channels per step as channels // time_steps, so
    # the input channel count must be at least config['time_steps'] (the old
    # 15 channels with 23 time steps gave 0 channels per step). One channel per
    # time step keeps that division exact.
    dummy_channels = config['time_steps']
    X_data = np.random.rand(100, dummy_channels, 32, 32)
    y_data = np.random.rand(100, 3, 32, 32)
    print(f"Dummy X_data shape: {X_data.shape}")
    print(f"Dummy y_data shape: {y_data.shape}")
else:
    # Preview the first real sample. This runs outside the try-block so the
    # FileNotFoundError handler covers only loading, not plotting.
    plt.figure(figsize=(15, 5))

    # Input features (first sample, first 3 channels)
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(X_data[0][:3], (1, 2, 0)))
    plt.title('Input Features (First 3 Channels)')
    plt.axis('off')

    # Target output
    plt.subplot(1, 2, 2)
    plt.imshow(np.transpose(y_data[0], (1, 2, 0)))
    plt.title('Target Output')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

## 4. Create Data Loaders with Temporal Reshaping

For LSTM-UNet, we need to reshape the data to include temporal information. We'll use the `TemporalUrbanCropDataset` class, which handles this reshaping.

In [None]:
def create_data_loaders(X_data, y_data, config):
    """Build train/validation/test DataLoaders over a temporally reshaped dataset.

    The samples are wrapped in ``TemporalUrbanCropDataset`` with
    ``config['time_steps']``. They are then split reproducibly in two stages:

    * ``config['test_split']`` of all samples goes to the test set;
    * ``config['val_split']`` of the remaining samples goes to validation;
    * everything left over is used for training.

    Optional config keys (defaults equal the previously hard-coded values):
        ``seed`` (42): seed for both ``random_split`` generators.
        ``num_workers`` (2): worker processes per DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles.

    Raises:
        ValueError: If the split fractions leave no samples for training.
    """
    # Create dataset with temporal reshaping
    dataset = TemporalUrbanCropDataset(X_data, y_data, time_steps=config['time_steps'])
    seed = config.get('seed', 42)
    num_workers = config.get('num_workers', 2)

    # Determine the number of samples for each split
    total_samples = len(dataset)
    test_size = int(total_samples * config['test_split'])
    train_val_size = total_samples - test_size
    val_size = int(train_val_size * config['val_split'])
    train_size = train_val_size - val_size

    print(f"Total samples: {total_samples}")
    print(f"Training samples: {train_size}")
    print(f"Validation samples: {val_size}")
    print(f"Test samples: {test_size}")

    # A shuffling DataLoader over an empty subset fails with an obscure sampler
    # error, so report the real cause here instead.
    if train_size <= 0:
        raise ValueError(
            f"No training samples left: total={total_samples}, "
            f"test_split={config['test_split']}, val_split={config['val_split']}."
        )

    # Two-stage split (test first, then validation out of the remainder), each
    # with its own freshly seeded generator so the assignment is reproducible.
    train_val_dataset, test_dataset = random_split(
        dataset, [train_val_size, test_size],
        generator=torch.Generator().manual_seed(seed)
    )
    train_dataset, val_dataset = random_split(
        train_val_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )

    def make_loader(subset, shuffle):
        """Wrap one split in a DataLoader with the shared batch/worker settings."""
        return DataLoader(
            subset, batch_size=config['batch_size'], shuffle=shuffle,
            num_workers=num_workers
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)
    return train_loader, val_loader, test_loader

# Build the train/val/test loaders from the arrays loaded above.
train_loader, val_loader, test_loader = create_data_loaders(X_data, y_data, config)

# Peek at the first training batch only, to confirm the temporal reshape
# produced the expected layout. Nothing is printed if the loader is empty.
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    x_batch, y_batch = _first_batch
    print(f"Batch shapes after temporal reshaping:")
    print(f"x_batch shape: {x_batch.shape}")  # expected [batch, time_steps, height, width, channels_per_step]
    print(f"y_batch shape: {y_batch.shape}")  # expected [batch, channels, height, width]

## 5. Initialize LSTM-UNet Model

In [None]:
# Initialize LSTM-UNet model
def initialize_model(config, in_channels=None):
    """Build the LSTM-UNet model and a compiled trainer for it.

    The per-sample channel axis is split across ``config['time_steps']``
    sequence steps, so each step gets ``in_channels // time_steps`` channels.

    Args:
        config: Settings dict; reads ``time_steps`` and ``learning_rate``
            (and ``in_channels`` as a fallback, see below).
        in_channels: Total number of input channels per sample, e.g.
            ``X_data.shape[1]`` for ``[samples, channels, height, width]``
            data. When omitted, ``config.get('in_channels', 15)`` is used;
            15 is the value previously hard-coded here.

    Returns:
        ``(model, trainer)``: the ``LSTMUNet`` and an ``LSTMUNetTrainer``
        compiled with Adam and MSE loss on the global ``device``.

    Raises:
        ValueError: If there are fewer input channels than time steps, which
            would give the model zero channels per step.
    """
    if in_channels is None:
        in_channels = config.get('in_channels', 15)
    time_steps = config['time_steps']

    channels_per_step = in_channels // time_steps
    if channels_per_step < 1:
        raise ValueError(
            f"{in_channels} input channels cannot be split across {time_steps} "
            "time steps (channels_per_step would be 0); pass in_channels "
            "explicitly or lower config['time_steps']."
        )
    # NOTE(review): when in_channels is not a multiple of time_steps, the
    # remainder is dropped by this floor division — confirm that
    # TemporalUrbanCropDataset reshapes the same way.

    # input_shape = (time_steps, height, width, channels_per_step)
    model = LSTMUNet(input_shape=(time_steps, 32, 32, channels_per_step),
                     lstm_units=16, unet_filters=16)

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model created with {n_trainable} trainable parameters")

    # Create trainer
    trainer = LSTMUNetTrainer(model, device=device)
    trainer.compile(
        optimizer='adam',
        learning_rate=config['learning_rate'],
        criterion='mse'
    )

    return model, trainer

# Initialize model and trainer. The channel count is taken from the loaded
# data ([samples, channels, height, width]) instead of a hard-coded constant.
model, trainer = initialize_model(config, in_channels=X_data.shape[1])

## 6. Train the Model

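Train the LSTM-UNet model for a few epochs (`config['epochs']`). The training call at the bottom of the cell is commented out; uncomment it to run training. For full training, increase the number of epochs.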

In [None]:
def train_model(trainer, train_loader, val_loader, config):
    """Train with checkpointing and early stopping, then save the final weights.

    Callbacks:
        * ``model_checkpoint`` saves the best-``val_loss`` weights to
          ``<model_dir>/lstm-unet_best.pth``;
        * ``early_stopping`` monitors ``val_loss`` with a patience of
          ``config.get('patience', 5)`` epochs (5 was previously hard-coded).

    A ``KeyboardInterrupt`` during ``trainer.fit`` stops training but is not
    re-raised. The weights reached so far are still saved to
    ``<model_dir>/lstm-unet_final.pth``.

    Returns:
        The same ``trainer`` object, after training.
    """
    # Imported lazily; these callbacks are only needed when training runs.
    from src.training import model_checkpoint, early_stopping

    checkpoint_path = os.path.join(config['model_dir'], 'lstm-unet_best.pth')
    callbacks = [
        model_checkpoint(checkpoint_path, monitor='val_loss', save_best_only=True),
        early_stopping(patience=config.get('patience', 5), monitor='val_loss'),
    ]

    print(f"Starting training for {config['epochs']} epochs...")
    try:
        # The return value of fit() was never used here (and was unbound after
        # an interrupt). Presumably the trainer keeps its own history for
        # plot_history — confirm in src.training.
        trainer.fit(
            train_loader,
            val_loader=val_loader,
            epochs=config['epochs'],
            callbacks=callbacks
        )
    except KeyboardInterrupt:
        print("Training interrupted by user")

    # Deliberately saved even after an interrupt, so partial progress is kept.
    model_path = os.path.join(config['model_dir'], 'lstm-unet_final.pth')
    trainer.save_model(model_path)

    return trainer

# Train the model (uncomment to run training)
# trainer = train_model(trainer, train_loader, val_loader, config)

## 7. Plot Training History

After training, plot the loss and metrics over epochs.

In [None]:
# Plot training history (run after training)
def plot_history(trainer, config):
    """Ask the trainer to plot its training history and save it as a PNG.

    The output file is ``<visualize_dir>/lstm-unet_history.png``. The plotting
    itself is delegated to ``trainer.plot_history``. Returns ``None``.
    """
    trainer.plot_history(
        save_path=os.path.join(config['visualize_dir'], 'lstm-unet_history.png')
    )
    
# If you've trained the model, uncomment to plot the history
# plot_history(trainer, config)

## 8. Load Best Model and Evaluate

Load the best model from checkpoint and evaluate on the test set.

In [None]:
def load_best_model(config, in_channels=None):
    """Rebuild the LSTM-UNet and load the best checkpoint if one exists.

    The model must be constructed with the same shape as the one the
    checkpoint was saved from. The per-sample channel axis is split across
    ``config['time_steps']`` steps (``in_channels // time_steps`` each).

    Args:
        config: Settings dict; reads ``time_steps``, ``learning_rate`` and
            ``model_dir`` (and ``in_channels`` as a fallback, see below).
        in_channels: Total number of input channels per sample, e.g.
            ``X_data.shape[1]``. When omitted, ``config.get('in_channels', 15)``
            is used; 15 is the value previously hard-coded here.

    Returns:
        ``(model, trainer)``. If ``<model_dir>/lstm-unet_best.pth`` does not
        exist, a message is printed and the untrained model is returned.

    Raises:
        ValueError: If there are fewer input channels than time steps, which
            would give the model zero channels per step.
    """
    if in_channels is None:
        in_channels = config.get('in_channels', 15)
    time_steps = config['time_steps']

    channels_per_step = in_channels // time_steps
    if channels_per_step < 1:
        raise ValueError(
            f"{in_channels} input channels cannot be split across {time_steps} "
            "time steps (channels_per_step would be 0); pass in_channels "
            "explicitly or lower config['time_steps']."
        )

    # input_shape = (time_steps, height, width, channels_per_step)
    model = LSTMUNet(input_shape=(time_steps, 32, 32, channels_per_step),
                     lstm_units=16, unet_filters=16)

    trainer = LSTMUNetTrainer(model, device=device)
    trainer.compile(optimizer='adam', learning_rate=config['learning_rate'], criterion='mse')

    # Load checkpoint
    checkpoint_path = os.path.join(config['model_dir'], 'lstm-unet_best.pth')
    if os.path.exists(checkpoint_path):
        print(f"Loading model from {checkpoint_path}")
        trainer.load_model(checkpoint_path)
    else:
        print(f"Checkpoint not found at {checkpoint_path}. Using untrained model.")

    return model, trainer

# Load the best model
# Uncomment after training or if you have a saved model
# best_model, best_trainer = load_best_model(config, in_channels=X_data.shape[1])

## 9. Evaluate on Test Set

In [None]:
def _as_displayable(img):
    """Return a channel-last image array in a shape ``plt.imshow`` accepts.

    ``imshow`` only takes 2-D arrays or 3-D arrays with 3 (RGB) or 4 (RGBA)
    channels. Any other 3-D array is reduced to its first channel, so a
    single-channel ``(H, W, 1)`` image becomes ``(H, W)``.
    """
    if img.ndim == 3 and img.shape[-1] not in (3, 4):
        return img[..., 0]
    return img


def evaluate_model(trainer, test_loader, config, num_samples=5):
    """Report test loss/MAE and save a grid of example predictions.

    Args:
        trainer: Trainer exposing ``evaluate(loader) -> (loss, metrics)``
            (metrics must contain ``'mae'``) and a ``model`` attribute.
        test_loader: DataLoader yielding ``(x, y)`` batches.
        config: Settings dict; only ``visualize_dir`` is read.
        num_samples: Number of test batches to visualize, showing the first
            sample of each. Defaults to 5, the previously hard-coded value.

    Returns:
        ``(test_loss, test_metrics)`` as returned by ``trainer.evaluate``.
        The figure is saved to ``<visualize_dir>/lstm-unet_predictions.png``.
    """
    print("Evaluating model on test set...")
    test_loss, test_metrics = trainer.evaluate(test_loader)
    print(f"Test Loss: {test_loss:.4f}, Test MAE: {test_metrics['mae']:.4f}")

    vis_path = os.path.join(config['visualize_dir'], 'lstm-unet_predictions.png')

    # Custom visualization for LSTM-UNet (adapting for temporal data)
    model = trainer.model
    model.eval()
    device = get_device()
    model.to(device)

    # Collect (input, target, prediction) for the first `num_samples` batches.
    samples = []
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(test_loader):
            if batch_idx >= num_samples:
                break
            pred = model(x.to(device))
            samples.append((x.cpu().numpy(), y.cpu().numpy(), pred.cpu().numpy()))

    # plt.subplots(0, 3) raises, so there is nothing sensible to draw here.
    if not samples:
        print("Test loader yielded no batches; skipping prediction plots.")
        return test_loss, test_metrics

    # squeeze=False keeps `axes` 2-D even for a single row. Without it,
    # plt.subplots returns a 1-D array and axes[i, 0] raises IndexError.
    fig, axes = plt.subplots(len(samples), 3, figsize=(15, 5 * len(samples)), squeeze=False)

    for i, (x, y, pred) in enumerate(samples):
        # Input: last time step of the first sample in the batch. Assumes x is
        # [batch, time_steps, height, width, channels] — TODO confirm.
        axes[i, 0].imshow(_as_displayable(x[0, -1]))
        axes[i, 0].set_title('Input (Last Time Step)')
        axes[i, 0].axis('off')

        # Ground truth (channel-first -> channel-last for imshow)
        axes[i, 1].imshow(_as_displayable(np.transpose(y[0], (1, 2, 0))))
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        # Prediction (channel-first -> channel-last for imshow)
        axes[i, 2].imshow(_as_displayable(np.transpose(pred[0], (1, 2, 0))))
        axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig(vis_path)
    plt.show()

    return test_loss, test_metrics

# Evaluate on test set
# Uncomment after loading a trained model
# test_loss, test_metrics = evaluate_model(best_trainer, test_loader, config)

## 10. Analyze Temporal Predictions

Since LSTM-UNet handles temporal data, we can analyze how predictions change over time.

In [None]:
def analyze_temporal_predictions(model, X_data, y_data, config, sample_idx=0):
    """Plot one sample's target next to predictions made with different time-step counts.

    For each candidate count ``ts``, the sample at ``sample_idx`` is wrapped in
    a fresh ``TemporalUrbanCropDataset(..., time_steps=ts)`` and passed through
    ``model``. The candidates are 5, 10, 15, 20 and ``config['time_steps']``,
    de-duplicated, restricted to values ``<= config['time_steps']``, and
    sorted. At most five predictions plus the target are shown, which fits the
    2x3 grid.

    NOTE(review): the model is built for ``config['time_steps']`` steps.
    Whether it accepts inputs reshaped with fewer steps depends on ``LSTMUNet``
    and ``TemporalUrbanCropDataset``, neither of which is visible here — confirm.

    The figure is saved to ``<visualize_dir>/lstm-unet_temporal_analysis.png``.
    Uses the global ``device``.
    """
    model.eval()
    model.to(device)

    # Drop duplicates (e.g. time_steps == 20) and counts the data cannot provide.
    max_steps = config['time_steps']
    time_steps_to_test = sorted({ts for ts in (5, 10, 15, 20, max_steps) if ts <= max_steps})

    # Keep a leading batch axis of 1 so the dataset receives array-shaped input.
    x_sample = X_data[sample_idx:sample_idx + 1]
    y_sample = y_data[sample_idx:sample_idx + 1]

    plt.figure(figsize=(15, 10))

    # Target output (channel-first -> channel-last for imshow)
    plt.subplot(2, 3, 1)
    plt.imshow(np.transpose(y_data[sample_idx], (1, 2, 0)))
    plt.title('Target Output')
    plt.axis('off')

    # Panels 2..6 hold one prediction per time-step count.
    for panel, ts in enumerate(time_steps_to_test, start=2):
        temp_dataset = TemporalUrbanCropDataset(x_sample, y_sample, time_steps=ts)
        x_temp, _ = temp_dataset[0]

        with torch.no_grad():
            pred = model(x_temp.unsqueeze(0).to(device))  # add batch dimension

        plt.subplot(2, 3, panel)
        plt.imshow(np.transpose(pred[0].cpu().numpy(), (1, 2, 0)))
        plt.title(f'Prediction with {ts} Time Steps')
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(config['visualize_dir'], 'lstm-unet_temporal_analysis.png'))
    plt.show()

# Analyze temporal predictions
# Uncomment after loading a trained model
# analyze_temporal_predictions(best_model, X_data, y_data, config)

## 11. Create Temporal Animation

Visualize how the model predictions change over time.

In [None]:
def create_temporal_animation(model, X_data, y_data, config, sample_idx=0, min_time_steps=5):
    """Save a GIF of one sample's prediction as more time steps are used.

    One frame is produced for each ``t`` in
    ``min_time_steps .. config['time_steps']`` (inclusive). For each frame,
    the sample at ``sample_idx`` is wrapped in
    ``TemporalUrbanCropDataset(..., time_steps=t)`` and passed through
    ``model``. The frames are channel-last ``(H, W, C)`` predictions, stacked
    and handed to ``src.utils.create_time_series_animation`` at 2 fps.

    NOTE(review): whether a model built for ``config['time_steps']`` steps
    accepts inputs reshaped with fewer steps depends on ``LSTMUNet`` and
    ``TemporalUrbanCropDataset`` — confirm.

    Args:
        min_time_steps: First time-step count to render. Defaults to 5, the
            previously hard-coded start.

    Returns:
        Path of the saved animation,
        ``<visualize_dir>/lstm-unet_temporal_animation.gif``.

    Raises:
        ValueError: If ``min_time_steps`` is outside
            ``1 .. config['time_steps']``, which would produce no frames.
    """
    from src.utils import create_time_series_animation

    max_steps = config['time_steps']
    if not 1 <= min_time_steps <= max_steps:
        raise ValueError(
            f"min_time_steps={min_time_steps} must be between 1 and "
            f"config['time_steps']={max_steps}; otherwise there are no frames."
        )

    model.eval()
    model.to(device)

    # Keep a leading batch axis of 1 so the dataset receives array-shaped input.
    x_sample = X_data[sample_idx:sample_idx + 1]
    y_sample = y_data[sample_idx:sample_idx + 1]

    predictions = []
    for t in range(min_time_steps, max_steps + 1):
        temp_dataset = TemporalUrbanCropDataset(x_sample, y_sample, time_steps=t)
        x_temp, _ = temp_dataset[0]

        with torch.no_grad():
            pred = model(x_temp.unsqueeze(0).to(device))  # add batch dimension

        # Channel-first -> channel-last frame.
        predictions.append(np.transpose(pred[0].cpu().numpy(), (1, 2, 0)))

    animation_path = os.path.join(config['visualize_dir'], 'lstm-unet_temporal_animation.gif')
    create_time_series_animation(
        data=np.array(predictions),
        title="LSTM-UNet Predictions Over Time",
        save_path=animation_path,
        fps=2
    )

    return animation_path

# Create temporal animation
# Uncomment after loading a trained model
# animation_path = create_temporal_animation(best_model, X_data, y_data, config)

## 12. Conclusion

This notebook demonstrated how to load, train, evaluate, and use the LSTM-UNet model for spatio-temporal urban crop yield prediction. The LSTM-UNet model combines LSTM for temporal processing with U-Net for spatial processing, making it suitable for datasets with both spatial and temporal components.

For full training, increase the number of epochs and consider tuning hyperparameters for better performance.