# Bidirectional LSTM-UNet Model Testing Notebook

This notebook tests the Bidirectional LSTM-UNet model for urban crop yield prediction. This model enhances the basic LSTM-UNet by using bidirectional LSTMs for improved temporal information processing.

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import sys

# Add src to path for imports
sys.path.append('.')

# Import custom modules
from src.models.lstm_unet import BidirectionalLSTMUNet
from src.data.dataset import TemporalUrbanCropDataset
from src.training import LSTMUNetTrainer
from src.utils import set_seed, get_device, ensure_dir, visualize_prediction, calculate_metrics

# Seed the random number generators before any data or model is created so
# repeated runs are comparable (set_seed is provided by src.utils).
set_seed(42)

# Pick the compute device once; later cells pass this module-level `device`
# to the trainers they construct. Presumably CUDA when available — see
# src.utils.get_device.
device = get_device()
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# Experiment settings shared by every cell below.
config = dict(
    data_dir='data/processed',       # folder the preprocessed .npy arrays are read from
    model_dir='models',              # folder checkpoints are written to
    visualize_dir='visualizations',  # folder figures are written to
    batch_size=16,                   # samples per DataLoader batch
    epochs=10,                       # kept small for a quick test run
    learning_rate=0.001,             # optimizer learning rate
    test_split=0.2,                  # fraction of all samples held out for testing
    val_split=0.1,                   # fraction of the non-test samples used for validation
    time_steps=23,                   # temporal length passed to the dataset and the models
)
# NOTE(review): the model-building cells derive the per-step channel count
# from a 15-channel input by default (15 // time_steps); with time_steps=23
# that is 0 — confirm this value against the channel count of the data.

# Output folders must exist before checkpoints or figures are saved.
ensure_dir(config['model_dir'])
ensure_dir(config['visualize_dir'])

## 3. Load Data

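For this notebook, we assume that the data has already been processed and saved as NumPy arrays. If the files are missing, the cell below falls back to random dummy data so the rest of the notebook can still be exercised; run the data processing scripts first to obtain meaningful results.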

In [None]:
# Load processed data
def load_data(config):
    """Read the preprocessed input and target arrays from ``config['data_dir']``.

    Args:
        config: Mapping that provides the ``data_dir`` folder.

    Returns:
        Tuple ``(X_data, y_data)`` of the arrays stored in
        ``years_array_32_segmented_prevUrb.npy`` and
        ``crops_array_32_segmented_prevUrb.npy``.

    Raises:
        FileNotFoundError: If either of the two files does not exist.
    """
    data_dir = config['data_dir']
    x_path = os.path.join(data_dir, 'years_array_32_segmented_prevUrb.npy')
    y_path = os.path.join(data_dir, 'crops_array_32_segmented_prevUrb.npy')

    # Guard clause: both files are required, so bail out before reading either.
    if not (os.path.exists(x_path) and os.path.exists(y_path)):
        raise FileNotFoundError(f"Preprocessed data not found at {x_path} and {y_path}. Please run data preprocessing first.")

    print("Loading preprocessed data...")
    X_data, y_data = np.load(x_path), np.load(y_path)
    print(f"X_data shape: {X_data.shape}")
    print(f"y_data shape: {y_data.shape}")
    return X_data, y_data

# Only the load itself is guarded: if the try also wrapped the plotting code,
# a FileNotFoundError raised while drawing would silently replace the real
# arrays with random dummy data.
try:
    X_data, y_data = load_data(config)
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Using dummy data for demonstration purposes...")
    # Create dummy data for demonstration
    # Note: For Bidirectional LSTM-UNet, we'll reshape this later
    # NOTE(review): 15 input channels cannot be split evenly over
    # config['time_steps'] == 23 — confirm the intended dummy shape.
    X_data = np.random.rand(100, 15, 32, 32)  # [samples, channels, height, width]
    y_data = np.random.rand(100, 3, 32, 32)   # [samples, channels, height, width]
    print(f"Dummy X_data shape: {X_data.shape}")
    print(f"Dummy y_data shape: {y_data.shape}")
else:
    # Sample visualization of the data (real data only)
    plt.figure(figsize=(15, 5))

    # Input features: first sample, first 3 channels, moved to channels-last
    # because imshow expects (height, width, channels).
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(X_data[0][:3], (1, 2, 0)))
    plt.title('Input Features (First 3 Channels)')
    plt.axis('off')

    # Target output of the same sample, also moved to channels-last.
    plt.subplot(1, 2, 2)
    plt.imshow(np.transpose(y_data[0], (1, 2, 0)))
    plt.title('Target Output')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

## 4. Create Data Loaders with Temporal Reshaping

For Bidirectional LSTM-UNet, we need to reshape the data to include temporal information. We'll use the TemporalUrbanCropDataset class which handles this reshaping.

In [None]:
def create_data_loaders(X_data, y_data, config):
    """Wrap the arrays in a temporal dataset and build train/val/test loaders.

    The test share is ``int(len(dataset) * test_split)``, the validation share
    is ``int(remaining * val_split)`` and everything left is training data.
    Both random splits draw from a generator seeded with 42.

    Args:
        X_data: Input array handed to ``TemporalUrbanCropDataset``.
        y_data: Target array handed to ``TemporalUrbanCropDataset``.
        config: Reads ``time_steps``, ``test_split``, ``val_split`` and
            ``batch_size``.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.
    """
    full_set = TemporalUrbanCropDataset(X_data, y_data, time_steps=config['time_steps'])

    # Split sizes: test first, then validation out of what remains.
    n_total = len(full_set)
    n_test = int(n_total * config['test_split'])
    n_val = int((n_total - n_test) * config['val_split'])
    n_train = n_total - n_test - n_val

    print(f"Total samples: {n_total}")
    print(f"Training samples: {n_train}")
    print(f"Validation samples: {n_val}")
    print(f"Test samples: {n_test}")

    def seeded():
        # A fresh generator per split keeps each split independently reproducible.
        return torch.Generator().manual_seed(42)

    trainval_set, test_set = random_split(
        full_set, [n_train + n_val, n_test], generator=seeded()
    )
    train_set, val_set = random_split(
        trainval_set, [n_train, n_val], generator=seeded()
    )

    def as_loader(subset, shuffle):
        return DataLoader(
            subset, batch_size=config['batch_size'], shuffle=shuffle, num_workers=2
        )

    return as_loader(train_set, True), as_loader(val_set, False), as_loader(test_set, False)

# Build the three loaders used by the rest of the notebook.
train_loader, val_loader, test_loader = create_data_loaders(X_data, y_data, config)

# Peek at a single training batch to confirm the shapes after temporal
# reshaping; nothing is printed if the loader yields no batches.
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    x_batch, y_batch = _first_batch
    print("Batch shapes after temporal reshaping:")
    print(f"x_batch shape: {x_batch.shape}")  # Should be [batch, time_steps, height, width, channels_per_step]
    print(f"y_batch shape: {y_batch.shape}")  # Should be [batch, channels, height, width]

## 5. Initialize Bidirectional LSTM-UNet Model

In [None]:
# Initialize Bidirectional LSTM-UNet model
def initialize_model(config):
    """Build a fresh Bidirectional LSTM-UNet and a compiled trainer for it.

    Args:
        config: Notebook settings. Reads ``time_steps`` and
            ``learning_rate``; optionally reads ``input_channels`` (total
            channel count of one input sample, default 15), which is divided
            by ``time_steps`` to get the per-step channel count.

    Returns:
        ``(model, trainer)``. The trainer wraps ``model`` on the module-level
        ``device`` and is compiled with ``optimizer='adam'`` and
        ``criterion='mse'``.
    """
    import warnings

    # The channel count used to be hard-wired as 15 // time_steps, which is 0
    # for the notebook's default time_steps of 23. 15 stays the default so
    # existing configs behave the same, but it can now be overridden and the
    # degenerate case is reported instead of passing silently.
    input_channels = config.get('input_channels', 15)
    channels_per_step = input_channels // config['time_steps']
    if channels_per_step < 1:
        warnings.warn(
            f"input_channels={input_channels} split over "
            f"time_steps={config['time_steps']} gives 0 channels per step; set "
            "config['input_channels'] to the channel count of the input data "
            "or lower config['time_steps'].",
            stacklevel=2,
        )

    # input_shape=(time_steps, height, width, channels_per_step)
    model = BidirectionalLSTMUNet(
        input_shape=(config['time_steps'], 32, 32, channels_per_step),
        lstm_units=16,
        unet_filters=16,
    )

    print(f"Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters")

    # Create trainer
    trainer = LSTMUNetTrainer(model, device=device)
    trainer.compile(
        optimizer='adam',
        learning_rate=config['learning_rate'],
        criterion='mse'
    )

    return model, trainer

# Build the model/trainer pair that the training cell below operates on.
model, trainer = initialize_model(config)

## 6. Train the Model

Train the Bidirectional LSTM-UNet model for a few epochs. For full training, you would increase the number of epochs.

In [None]:
def train_model(trainer, train_loader, val_loader, config):
    """Fit ``trainer`` with checkpointing and early stopping, then save it.

    Args:
        trainer: Trainer object exposing ``fit`` and ``save_model``.
        train_loader: Loader passed to ``fit`` as the training data.
        val_loader: Loader passed to ``fit`` as ``val_loader``.
        config: Reads ``model_dir`` and ``epochs``.

    Returns:
        The same ``trainer`` that was passed in.
    """
    from src.training import model_checkpoint, early_stopping

    model_dir = config['model_dir']

    # Both callbacks monitor the validation loss.
    best_ckpt = os.path.join(model_dir, 'bilstm-unet_best.pth')
    checkpoint_cb = model_checkpoint(best_ckpt, monitor='val_loss', save_best_only=True)
    stopping_cb = early_stopping(patience=5, monitor='val_loss')

    n_epochs = config['epochs']
    print(f"Starting training for {n_epochs} epochs...")
    try:
        trainer.fit(
            train_loader,
            val_loader=val_loader,
            epochs=n_epochs,
            callbacks=[checkpoint_cb, stopping_cb],
        )
    except KeyboardInterrupt:
        # A manual stop is not an error: fall through and save what we have.
        print("Training interrupted by user")

    # The final weights are saved whether training finished or was interrupted.
    trainer.save_model(os.path.join(model_dir, 'bilstm-unet_final.pth'))

    return trainer

# Train the model (uncomment to run training)
# trainer = train_model(trainer, train_loader, val_loader, config)

## 7. Plot Training History

After training, plot the loss and metrics over epochs.

In [None]:
# Plot training history (run after training)
def plot_history(trainer, config):
    """Ask ``trainer`` to plot its history and save it as a PNG.

    The figure path is ``bilstm-unet_history.png`` inside
    ``config['visualize_dir']``; the drawing itself is delegated to
    ``trainer.plot_history``.
    """
    trainer.plot_history(
        save_path=os.path.join(config['visualize_dir'], 'bilstm-unet_history.png')
    )
    
# If you've trained the model, uncomment to plot the history
# plot_history(trainer, config)

## 8. Load Best Model and Evaluate

Load the best model from checkpoint and evaluate on the test set.

In [None]:
def load_best_model(config):
    """Rebuild the Bidirectional LSTM-UNet and restore the best checkpoint.

    Args:
        config: Notebook settings. Reads ``time_steps``, ``learning_rate``
            and ``model_dir``; optionally reads ``input_channels`` (total
            channel count of one input sample, default 15), which is divided
            by ``time_steps`` to get the per-step channel count.

    Returns:
        ``(model, trainer)``. If ``bilstm-unet_best.pth`` is missing from
        ``model_dir`` a message is printed and the model keeps its freshly
        initialized weights.
    """
    import warnings

    # The channel count used to be hard-wired as 15 // time_steps, which is 0
    # for the notebook's default time_steps of 23. 15 stays the default so
    # existing configs behave the same, but it can now be overridden and the
    # degenerate case is reported instead of passing silently.
    input_channels = config.get('input_channels', 15)
    channels_per_step = input_channels // config['time_steps']
    if channels_per_step < 1:
        warnings.warn(
            f"input_channels={input_channels} split over "
            f"time_steps={config['time_steps']} gives 0 channels per step; set "
            "config['input_channels'] to the channel count of the input data "
            "or lower config['time_steps'].",
            stacklevel=2,
        )

    # The architecture must be rebuilt before checkpoint weights can be loaded.
    model = BidirectionalLSTMUNet(
        input_shape=(config['time_steps'], 32, 32, channels_per_step),
        lstm_units=16,
        unet_filters=16,
    )

    trainer = LSTMUNetTrainer(model, device=device)
    trainer.compile(optimizer='adam', learning_rate=config['learning_rate'], criterion='mse')

    # Load checkpoint
    checkpoint_path = os.path.join(config['model_dir'], 'bilstm-unet_best.pth')
    if os.path.exists(checkpoint_path):
        print(f"Loading model from {checkpoint_path}")
        trainer.load_model(checkpoint_path)
    else:
        print(f"Checkpoint not found at {checkpoint_path}. Using untrained model.")

    return model, trainer

# Load the best model
# Uncomment after training or if you have a saved model
# best_model, best_trainer = load_best_model(config)

## 9. Evaluate on Test Set

In [None]:
def evaluate_model(trainer, test_loader, config):
    """Evaluate on the test set and save a grid of example predictions.

    Args:
        trainer: Trainer exposing ``evaluate(loader)`` and a ``model``
            attribute.
        test_loader: Loader yielding ``(x, y)`` batches.
        config: Reads ``visualize_dir``.

    Returns:
        ``(test_loss, test_metrics)`` exactly as returned by
        ``trainer.evaluate``; ``test_metrics`` must contain ``'mae'``.
    """
    from itertools import islice

    print("Evaluating model on test set...")
    test_loss, test_metrics = trainer.evaluate(test_loader)
    print(f"Test Loss: {test_loss:.4f}, Test MAE: {test_metrics['mae']:.4f}")

    vis_path = os.path.join(config['visualize_dir'], 'bilstm-unet_predictions.png')

    # Custom visualization for Bidirectional LSTM-UNet (adapting for temporal data)
    model = trainer.model
    model.eval()
    device = get_device()
    model.to(device)

    # One figure row per batch, for at most `max_rows` batches. islice stops
    # after the last needed batch instead of fetching one more before breaking.
    max_rows = 5
    samples = []
    with torch.no_grad():
        for x, y in islice(test_loader, max_rows):
            pred = model(x.to(device))
            samples.append((x.cpu().numpy(), y.cpu().numpy(), pred.cpu().numpy()))

    if not samples:
        # plt.subplots(0, 3) would raise; the metrics are still valid.
        print("No test batches available; skipping prediction visualization.")
        return test_loss, test_metrics

    # squeeze=False keeps `axes` two-dimensional even for a single row;
    # otherwise a one-batch loader would make 2-D indexing fail.
    fig, axes = plt.subplots(
        len(samples), 3, figsize=(15, 5 * len(samples)), squeeze=False
    )

    for row_axes, (x, y, pred) in zip(axes, samples):
        # Only the first sample of each batch is drawn. Targets/predictions are
        # moved to channels-last for imshow; the input shows its last time
        # step (assumes x is [batch, time, ...] — TODO confirm).
        panels = (
            ('Input (Last Time Step)', x[0, -1]),
            ('Ground Truth', np.transpose(y[0], (1, 2, 0))),
            ('Prediction', np.transpose(pred[0], (1, 2, 0))),
        )
        for ax, (title, image) in zip(row_axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(vis_path)
    plt.show()
    # Release the figure so repeated evaluations do not accumulate open figures.
    plt.close(fig)

    return test_loss, test_metrics

# Evaluate on test set
# Uncomment after loading a trained model
# test_loss, test_metrics = evaluate_model(best_trainer, test_loader, config)

## 10. Compare with Standard LSTM-UNet

Compare the performance of the Bidirectional LSTM-UNet with the standard LSTM-UNet model to see if bidirectional processing improves performance.

In [None]:
def compare_with_lstm_unet(config):
    """Compare Bidirectional LSTM-UNet with standard LSTM-UNet on the test split.

    Both models are rebuilt, their ``*_best.pth`` checkpoints are loaded from
    ``config['model_dir']`` and each is evaluated on the test loader created
    from the notebook-level ``X_data`` / ``y_data``. A bar chart of loss and
    MAE is saved to ``config['visualize_dir']``.

    Args:
        config: Notebook settings. Reads ``time_steps``, ``learning_rate``,
            ``model_dir`` and ``visualize_dir``; optionally reads
            ``input_channels`` (total channel count of one input sample,
            default 15).

    Returns:
        ``(bilstm_metrics, lstm_metrics)``. When either checkpoint is missing
        no comparison is made and ``(None, None)`` is returned, so callers
        that unpack two values keep working.
    """
    import warnings

    from src.models.lstm_unet import LSTMUNet

    # The channel count used to be hard-wired as 15 // time_steps, which is 0
    # for the notebook's default time_steps of 23. 15 stays the default so
    # existing configs behave the same, but it can now be overridden and the
    # degenerate case is reported instead of passing silently.
    input_channels = config.get('input_channels', 15)
    channels_per_step = input_channels // config['time_steps']
    if channels_per_step < 1:
        warnings.warn(
            f"input_channels={input_channels} split over "
            f"time_steps={config['time_steps']} gives 0 channels per step; set "
            "config['input_channels'] to the channel count of the input data "
            "or lower config['time_steps'].",
            stacklevel=2,
        )
    input_shape = (config['time_steps'], 32, 32, channels_per_step)

    def build_trainer(model_cls):
        # Same hyper-parameters for both architectures so the comparison is fair.
        net = model_cls(input_shape=input_shape, lstm_units=16, unet_filters=16)
        wrapped = LSTMUNetTrainer(net, device=device)
        wrapped.compile(optimizer='adam', learning_rate=config['learning_rate'], criterion='mse')
        return wrapped

    bilstm_trainer = build_trainer(BidirectionalLSTMUNet)
    lstm_trainer = build_trainer(LSTMUNet)

    # Load checkpoints if available; both are required for a comparison.
    candidates = (
        ("Bidirectional LSTM-UNet", bilstm_trainer, 'bilstm-unet_best.pth'),
        ("LSTM-UNet", lstm_trainer, 'lstm-unet_best.pth'),
    )
    models_loaded = True
    for label, candidate, filename in candidates:
        path = os.path.join(config['model_dir'], filename)
        if os.path.exists(path):
            print(f"Loading {label} from {path}")
            candidate.load_model(path)
        else:
            print(f"{label} checkpoint not found. Using untrained model.")
            models_loaded = False

    if not models_loaded:
        print("Cannot compare models without trained checkpoints.")
        # A bare `return` here would make the two-value unpacking at the call
        # site fail with a TypeError.
        return None, None

    # Create test data loaders (uses the notebook-level X_data / y_data)
    _, _, test_loader = create_data_loaders(X_data, y_data, config)

    # Evaluate both models
    print("\nEvaluating Bidirectional LSTM-UNet:")
    bilstm_loss, bilstm_metrics = bilstm_trainer.evaluate(test_loader)

    print("\nEvaluating standard LSTM-UNet:")
    lstm_loss, lstm_metrics = lstm_trainer.evaluate(test_loader)

    # Compare results
    print("\nComparison:")
    print(f"Bidirectional LSTM-UNet - Loss: {bilstm_loss:.4f}, MAE: {bilstm_metrics['mae']:.4f}")
    print(f"Standard LSTM-UNet - Loss: {lstm_loss:.4f}, MAE: {lstm_metrics['mae']:.4f}")

    # Plot comparison: one group of two bars per metric.
    plt.figure(figsize=(10, 6))
    metrics = ['loss', 'mae']
    values = [
        [bilstm_loss, bilstm_metrics['mae']],
        [lstm_loss, lstm_metrics['mae']]
    ]

    x = np.arange(len(metrics))
    width = 0.35

    plt.bar(x - width/2, values[0], width, label='Bidirectional LSTM-UNet')
    plt.bar(x + width/2, values[1], width, label='Standard LSTM-UNet')

    plt.xlabel('Metrics')
    plt.ylabel('Values')
    plt.title('Model Comparison')
    plt.xticks(x, metrics)
    plt.legend()

    plt.savefig(os.path.join(config['visualize_dir'], 'bilstm_vs_lstm_comparison.png'))
    plt.show()

    return bilstm_metrics, lstm_metrics

# Compare with standard LSTM-UNet
# Uncomment to run comparison if you have trained models
# bilstm_metrics, lstm_metrics = compare_with_lstm_unet(config)

## 11. Analyze Bidirectional Processing

Analyze how bidirectional processing affects predictions compared to unidirectional processing, including feeding the standard LSTM-UNet a time-reversed copy of the input.

In [None]:
def analyze_bidirectional_effect(config):
    """Compare bidirectional, forward and time-reversed predictions on a few samples.

    For each of the first three samples the function plots the ground truth
    next to three predictions — Bidirectional LSTM-UNet, standard LSTM-UNet,
    and standard LSTM-UNet fed the input with its time axis reversed — saves
    the figure to ``config['visualize_dir']`` and prints the MAE of each.

    Args:
        config: Notebook settings. Reads ``data_dir`` (via ``load_data``),
            ``time_steps``, ``model_dir`` and ``visualize_dir``; optionally
            reads ``input_channels`` (total channel count of one input
            sample, default 15).

    Raises:
        FileNotFoundError: Propagated from ``load_data`` when the
            preprocessed arrays are missing (there is no dummy-data fallback
            here).
    """
    import warnings

    from src.models.lstm_unet import LSTMUNet

    # Number of samples that are predicted, plotted and reported.
    max_samples = 3

    # Load data
    X_data, y_data = load_data(config)

    # The channel count used to be hard-wired as 15 // time_steps, which is 0
    # for the notebook's default time_steps of 23. 15 stays the default so
    # existing configs behave the same, but it can now be overridden and the
    # degenerate case is reported instead of passing silently.
    input_channels = config.get('input_channels', 15)
    channels_per_step = input_channels // config['time_steps']
    if channels_per_step < 1:
        warnings.warn(
            f"input_channels={input_channels} split over "
            f"time_steps={config['time_steps']} gives 0 channels per step; set "
            "config['input_channels'] to the channel count of the input data "
            "or lower config['time_steps'].",
            stacklevel=2,
        )
    input_shape = (config['time_steps'], 32, 32, channels_per_step)

    # Temporal dataset over the first five samples; the time-reversed inputs
    # are produced on the fly below rather than as a second dataset.
    original_dataset = TemporalUrbanCropDataset(X_data[:5], y_data[:5], time_steps=config['time_steps'])

    # Initialize models
    bilstm_model = BidirectionalLSTMUNet(input_shape=input_shape, lstm_units=16, unet_filters=16)
    lstm_model = LSTMUNet(input_shape=input_shape, lstm_units=16, unet_filters=16)

    # Load pre-trained models if available
    bilstm_path = os.path.join(config['model_dir'], 'bilstm-unet_best.pth')
    lstm_path = os.path.join(config['model_dir'], 'lstm-unet_best.pth')

    device = get_device()
    bilstm_model.to(device)
    lstm_model.to(device)

    if os.path.exists(bilstm_path) and os.path.exists(lstm_path):
        bilstm_checkpoint = torch.load(bilstm_path, map_location=device)
        lstm_checkpoint = torch.load(lstm_path, map_location=device)

        # The checkpoints are expected to store weights under 'model_state_dict'.
        bilstm_model.load_state_dict(bilstm_checkpoint['model_state_dict'])
        lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])
        print("Pre-trained models loaded successfully.")
    else:
        print("Pre-trained models not found. Using untrained models for demonstration.")

    # Make predictions
    bilstm_model.eval()
    lstm_model.eval()

    results = []

    with torch.no_grad():
        # Index-based access: only __len__/__getitem__ are relied on. Inference
        # is limited to the samples that are actually displayed.
        for i in range(min(max_samples, len(original_dataset))):
            x, y = original_dataset[i]
            x = x.unsqueeze(0).to(device)  # Add batch dimension

            # Forward predictions
            bilstm_pred = bilstm_model(x)
            lstm_pred = lstm_model(x)

            # Reverse the time dimension (dim 1 after batching) for the
            # standard LSTM to simulate backward processing; torch.flip
            # already returns a new tensor, so no clone is needed.
            lstm_reversed_pred = lstm_model(torch.flip(x, [1]))

            results.append({
                'x': x.cpu().numpy(),
                'y': y.cpu().numpy(),
                'bilstm_pred': bilstm_pred.cpu().numpy(),
                'lstm_pred': lstm_pred.cpu().numpy(),
                'lstm_reversed_pred': lstm_reversed_pred.cpu().numpy()
            })

    # One 2x2 figure per sample: ground truth plus the three predictions.
    for idx, result in enumerate(results):
        target = result['y']
        # Predictions carry a leading batch dimension of 1; the target does not.
        panels = (
            ('Ground Truth', target),
            ('Bidirectional LSTM-UNet', result['bilstm_pred'][0]),
            ('Standard LSTM-UNet (Forward)', result['lstm_pred'][0]),
            ('Standard LSTM-UNet (Reversed)', result['lstm_reversed_pred'][0]),
        )

        fig = plt.figure(figsize=(15, 10))
        for position, (title, image) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            # Channels-first -> channels-last for imshow.
            plt.imshow(np.transpose(image, (1, 2, 0)))
            plt.title(title)
            plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(config['visualize_dir'], f'bidirectional_analysis_sample_{idx}.png'))
        plt.show()
        # Release the figure so the per-sample figures do not pile up.
        plt.close(fig)

        # Calculate metrics for comparison
        bilstm_mae = np.mean(np.abs(result['bilstm_pred'][0] - target))
        lstm_mae = np.mean(np.abs(result['lstm_pred'][0] - target))
        lstm_reversed_mae = np.mean(np.abs(result['lstm_reversed_pred'][0] - target))

        print(f"Sample {idx}:")
        print(f"Bidirectional LSTM-UNet MAE: {bilstm_mae:.4f}")
        print(f"Standard LSTM-UNet (Forward) MAE: {lstm_mae:.4f}")
        print(f"Standard LSTM-UNet (Reversed) MAE: {lstm_reversed_mae:.4f}")
        print()

# Analyze bidirectional processing
# Uncomment to run analysis if you have trained models
# analyze_bidirectional_effect(config)

## 12. Conclusion

This notebook demonstrated how to load, train, evaluate, and use the Bidirectional LSTM-UNet model for spatio-temporal urban crop yield prediction. The Bidirectional LSTM-UNet model enhances the standard LSTM-UNet by processing temporal information in both forward and backward directions, potentially capturing more complex temporal dependencies.

For full training, increase the number of epochs and consider tuning the hyperparameters for better performance. Comparing against the standard LSTM-UNet model can provide insight into the benefits of bidirectional processing for your specific dataset.