# Magnetic Field Forecasting Neural Network (model2024)

This notebook implements a dual-branch neural network for 15-minute forecasting of Earth's surface magnetic field variations. The model combines two data streams: (1) large-scale solar and geophysical parameters sampled at 15-minute intervals, and (2) high-resolution SECS-based current patterns sampled at 1-minute intervals. The architecture uses LSTM networks for temporal processing, convolutional layers for spatial feature extraction, and cross-attention mechanisms to fuse information across different spatial and temporal scales.

Key Features:
- 15-minute ahead forecasting of magnetic field components (Be, Bn, Bu)
- Dual-branch architecture combining global and local information
- Memory-efficient batch generation with time-shifted targets
- Multi-scale temporal processing (24h and 3h lookback windows)
- Cross-attention mechanism for feature fusion

## Table of Contents

1. Setup and Imports
    * Import required libraries
    * Configure GPU settings
    * Set random seeds
    * Define data paths and constants

2. Data Loading and Exploration
- 2.1 Target Data
    * Load magnetic field components (Be, Bn, Bu)
    * Process timestamps
    * Analyze data distribution
- 2.2 Branch 1: Large-scale Data
    * Load solar/geophysical parameters (15-min resolution)
    * Organize feature groups (auroral, solar, seasonal, etc.)
    * Validate data quality and timestamps
- 2.3 Branch 2: SECS Data
    * Load SECS grid data (1-min resolution)
    * Verify spatial dimensions (21x21x3)
    * Check temporal alignment

3. Data Pipeline
    * Initialize BatchGenerator with 15-min forecast horizon
    * Create time-shifted training windows
    * Implement efficient batch generation
    * Split data (70/15/15)
    * Create TensorFlow datasets

4. Model Architecture
- 4.1 Branch 1: LSTM Network
    * Process 24h of 15-min resolution data
    * Extract temporal features from global parameters
- 4.2 Branch 2: CNN-LSTM Network
    * Process 3h of 1-min resolution SECS data
    * Extract spatiotemporal features
- 4.3 Cross-Attention Fusion
    * Multi-head attention for feature interaction
    * Combine multi-scale information
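- 4.4 Decoder
    * Dense layers for feature processing
    * Output layer for 15-min ahead predictions
- 4.5 Complete Model
    * Assemble both branches, cross-attention, and decoder
- 4.6 Model Summary and Visualization
    * Rebuild the model with dataset input shapes and plot its architecture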

5. Training
    * Define custom loss function
    * Configure Adam optimizer
    * Set up callbacks (LR scheduling, early stopping)
    * Monitor training progress
    * Save best model

6. Evaluation
    * Calculate prediction metrics (MSE, RMSE)
    * Analyze forecast accuracy
    * Assess temporal performance
    * Compare with baseline models

7. Visualization
    * Plot predicted vs actual values
    * Visualize attention patterns
    * Display component-wise performance
    * Show error distributions
    * Create forecast examples

The model aims to provide accurate 15-minute forecasts of magnetic field variations by leveraging both long-term global patterns and short-term local dynamics. The dual-branch architecture and cross-attention mechanism enable the model to learn relevant features at different spatial and temporal scales while maintaining computational efficiency through careful batch generation and data handling.

In [1]:
# 1. Setup and Imports

# Import required libraries
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import os
from sklearn.preprocessing import StandardScaler
from tensorflow.keras import layers, Model
import warnings
import json
from pathlib import Path
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Check GPU availability and configure
print("TensorFlow version:", tf.__version__)
print("GPU Available:", tf.config.list_physical_devices('GPU'))

# Configure GPU memory growth
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth enabled")
    except RuntimeError as e:
        print(e)

# Define paths
DATA_DIR = "/Users/akv020/Tensorflow/fennomag-net/source/model2024/data"
TARGET_PATH = os.path.join(DATA_DIR, "target.csv")
GEODATA_PATH = os.path.join(DATA_DIR, "geodata.csv")
SECS_DATA_PATH = os.path.join(DATA_DIR, "secs_data.npy")
SECS_TIMESTAMPS_PATH = os.path.join(DATA_DIR, "secs_timestamps.npy")

# Create output directories for model artifacts
OUTPUT_DIR = Path("model_outputs")
OUTPUT_DIR.mkdir(exist_ok=True)
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(exist_ok=True)
LOG_DIR = OUTPUT_DIR / "logs"
LOG_DIR.mkdir(exist_ok=True)

# Model parameters
# Temporal parameters
BRANCH1_LOOKBACK = 96    # 24 hours at 15-minute intervals
BRANCH2_LOOKBACK = 180   # 3 hours at 1-minute intervals
FORECAST_HORIZON = 15    # 15-minute forecast horizon
BATCH_SIZE = 32

# Model architecture parameters
EMBEDDING_DIM = 64       # Dimension of feature embeddings
NUM_HEADS = 4           # Number of attention heads
DROPOUT_RATE = 0.2      # Dropout rate for regularization

# Training parameters
EPOCHS = 100
LEARNING_RATE = 1e-3
MIN_LR = 1e-6
PATIENCE = 10           # Early stopping patience

# Print configuration
print("\nModel Configuration:")
print(f"Branch 1 Lookback: {BRANCH1_LOOKBACK} timesteps ({BRANCH1_LOOKBACK//4} hours)")
print(f"Branch 2 Lookback: {BRANCH2_LOOKBACK} timesteps ({BRANCH2_LOOKBACK//60} hours)")
print(f"Forecast Horizon: {FORECAST_HORIZON} minutes")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Embedding Dimension: {EMBEDDING_DIM}")
print(f"Number of Attention Heads: {NUM_HEADS}")
print(f"Dropout Rate: {DROPOUT_RATE}")

# Verify data files exist
print("\nVerifying data files:")
for path in [TARGET_PATH, GEODATA_PATH, SECS_DATA_PATH, SECS_TIMESTAMPS_PATH]:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required data file not found: {path}")
    else:
        print(f"✓ Found: {path}")

# Set up plotting style
plt.style.use('default')
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['font.size'] = 12
plt.rcParams['axes.grid'] = True
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['axes.titlesize'] = 16
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 12
plt.rcParams['figure.titlesize'] = 16

# Set up color scheme for consistent visualization
COLORS = {
    'Be': '#1f77b4',  # Blue
    'Bn': '#2ca02c',  # Green
    'Bu': '#ff7f0e',  # Orange
    'train': '#1f77b4',
    'val': '#ff7f0e',
    'test': '#2ca02c'
}

# Function to save model configuration
def save_config():
    """Save model configuration to JSON file."""
    config = {
        'temporal_params': {
            'branch1_lookback': BRANCH1_LOOKBACK,
            'branch2_lookback': BRANCH2_LOOKBACK,
            'forecast_horizon': FORECAST_HORIZON,
            'batch_size': BATCH_SIZE
        },
        'model_params': {
            'embedding_dim': EMBEDDING_DIM,
            'num_heads': NUM_HEADS,
            'dropout_rate': DROPOUT_RATE
        },
        'training_params': {
            'epochs': EPOCHS,
            'learning_rate': LEARNING_RATE,
            'min_lr': MIN_LR,
            'patience': PATIENCE
        }
    }
    
    with open(OUTPUT_DIR / 'model_config.json', 'w') as f:
        json.dump(config, f, indent=4)
    print("\nModel configuration saved to:", OUTPUT_DIR / 'model_config.json')

# Save initial configuration
save_config()

TensorFlow version: 2.9.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU memory growth enabled

Model Configuration:
Branch 1 Lookback: 96 timesteps (24 hours)
Branch 2 Lookback: 180 timesteps (3 hours)
Forecast Horizon: 15 minutes
Batch Size: 32
Embedding Dimension: 64
Number of Attention Heads: 4
Dropout Rate: 0.2

Verifying data files:
✓ Found: /Users/akv020/Tensorflow/fennomag-net/source/model2024/data/target.csv
✓ Found: /Users/akv020/Tensorflow/fennomag-net/source/model2024/data/geodata.csv
✓ Found: /Users/akv020/Tensorflow/fennomag-net/source/model2024/data/secs_data.npy
✓ Found: /Users/akv020/Tensorflow/fennomag-net/source/model2024/data/secs_timestamps.npy

Model configuration saved to: model_outputs/model_config.json
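
The constants above define each sample. Branch 1 uses 96 fifteen-minute steps ending at the current time t (covering t - 23 h 45 min through t), Branch 2 uses 180 one-minute steps ending at t, and the target lies 15 minutes after t. `BatchGenerator` is imported from a local module whose internals are not shown here. The sketch below is a hypothetical helper, `extract_window_indices`, that locates these slices with `numpy.searchsorted` on synthetic timestamps. It assumes sorted, gap-free timestamp arrays and is not the generator's actual implementation.

```python
import numpy as np
import pandas as pd

def extract_window_indices(current_time, branch1_times, branch2_times, target_times,
                           branch1_lookback=96, branch2_lookback=180, horizon_min=15):
    """Hypothetical helper: locate input windows ending at t and the target at t + horizon.

    Assumes each timestamp array is sorted, gap-free, and of dtype datetime64.
    Returns (branch1_slice, branch2_slice, target_index).
    """
    t = pd.Timestamp(current_time).to_datetime64()
    # Position of the last sample at or before t in each input stream
    i1 = int(np.searchsorted(branch1_times, t, side='right')) - 1
    i2 = int(np.searchsorted(branch2_times, t, side='right')) - 1
    if i1 + 1 < branch1_lookback or i2 + 1 < branch2_lookback:
        raise ValueError("not enough history before current_time")
    # The target must exist exactly horizon_min minutes after t
    target_time = t + np.timedelta64(horizon_min, 'm')
    j = int(np.searchsorted(target_times, target_time))
    if j >= len(target_times) or target_times[j] != target_time:
        raise ValueError("no target sample at t + horizon")
    return (slice(i1 + 1 - branch1_lookback, i1 + 1),
            slice(i2 + 1 - branch2_lookback, i2 + 1),
            j)

# Synthetic timestamps: 15-min Branch 1; 1-min Branch 2 and target
b1 = pd.date_range('2024-01-01', periods=200, freq='15min').to_numpy()
b2 = pd.date_range('2024-01-01', periods=3000, freq='1min').to_numpy()
tgt = b2
s1, s2, j = extract_window_indices(pd.Timestamp('2024-01-02 00:00'), b1, b2, tgt)
print("Branch 1 window:", b1[s1][0], "to", b1[s1][-1])
print("Branch 2 window:", b2[s2][0], "to", b2[s2][-1])
print("Target time:", tgt[j])
```

Searching with `side='right'` and subtracting one returns the sample at t itself when it exists, so each window ends exactly at the current time and never includes future data.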


In [None]:
# 2. Data Loading and Exploration

# 2.1 Target Data
print("Loading target data...")
target_df = pd.read_csv(TARGET_PATH)
target_df['DateTime'] = pd.to_datetime(target_df['DateTime'])
print(f"Target data shape: {target_df.shape}")
print("\nFirst few rows of target data:")
print(target_df.head())
print("\nTarget data info:")
target_df.info()  # info() prints directly and returns None

# Calculate basic statistics for target components
print("\nTarget data statistics:")
print(target_df[['Be_0_0', 'Bn_0_0', 'Bu_0_0']].describe())

# 2.2 Branch 1: Large-scale Data
print("\nLoading Branch 1 (large-scale) data...")
geo_df = pd.read_csv(GEODATA_PATH)
geo_df['DateTime'] = pd.to_datetime(geo_df['DateTime'])

# Create feature groups for better organization
feature_groups = {
    'auroral': ['SME', 'SML', 'SMU'],
    'ring_current': ['SYM_D', 'SYM_H'],
    'disturbance': ['ASY_D', 'ASY_H'],
    'solar': ['Sunspot', 'f107', 'ap_index', 'Lyman'],
    'seasonal': ['DOY_cos', 'DOY_sin', 'TOD_cos', 'TOD_sin'],
    'geometry': ['SolarZenithAngle'],
    'previous_magnetic': ['BE', 'BN', 'BU'],
    'magnetic_std': ['stdE', 'stdN', 'stdU']
}

print(f"Branch 1 data shape: {geo_df.shape}")
print("\nFeature groups:")
for group, features in feature_groups.items():
    print(f"{group}: {features}")

# Estimate the sampling interval from the median spacing between consecutive timestamps
branch1_interval = geo_df['DateTime'].sort_values().diff().median()
print(f"\nBranch 1 sampling interval: {branch1_interval}")

# 2.3 Branch 2: SECS Data
print("\nLoading Branch 2 (SECS) data...")
secs_data = np.load(SECS_DATA_PATH)
secs_timestamps = np.load(SECS_TIMESTAMPS_PATH)
secs_timestamps = pd.to_datetime(secs_timestamps)

print(f"SECS data shape: {secs_data.shape}")
print(f"Number of SECS timestamps: {len(secs_timestamps)}")

# Estimate the sampling interval from the median spacing between consecutive timestamps
branch2_interval = pd.Series(secs_timestamps).sort_values().diff().median()
print(f"Branch 2 sampling interval: {branch2_interval}")

# Basic data validation
print("\nData validation:")
print(f"Target data time range: {target_df['DateTime'].min()} to {target_df['DateTime'].max()}")
print(f"Branch 1 time range: {geo_df['DateTime'].min()} to {geo_df['DateTime'].max()}")
print(f"Branch 2 time range: {secs_timestamps.min()} to {secs_timestamps.max()}")

# Check for missing values
print("\nMissing values check:")
print("Target data missing values:")
print(target_df.isnull().sum())
print("\nBranch 1 missing values:")
print(geo_df.isnull().sum())

# Visualize data distributions
plt.figure(figsize=(15, 10))

# Target data distribution
plt.subplot(221)
for col, color in zip(['Be_0_0', 'Bn_0_0', 'Bu_0_0'], [COLORS['Be'], COLORS['Bn'], COLORS['Bu']]):
    plt.hist(target_df[col], bins=50, alpha=0.5, label=col, color=color)
plt.title('Target Data Distribution')
plt.xlabel('Magnetic Field Component Value')
plt.ylabel('Count')
plt.legend()

# Branch 1 feature distribution (auroral indices)
plt.subplot(222)
for col in feature_groups['auroral']:
    plt.hist(geo_df[col], bins=50, alpha=0.5, label=col)
plt.title('Auroral Indices Distribution')
plt.xlabel('Index Value')
plt.ylabel('Count')
plt.legend()

# Branch 1 feature distribution (solar parameters)
plt.subplot(223)
for col in feature_groups['solar']:
    plt.hist(geo_df[col], bins=50, alpha=0.5, label=col)
plt.title('Solar Parameters Distribution')
plt.xlabel('Parameter Value')
plt.ylabel('Count')
plt.legend()

# Branch 2 data distribution
plt.subplot(224)
plt.hist(secs_data.ravel(), bins=50, alpha=0.5, color='blue')
plt.title('SECS Data Distribution')
plt.xlabel('Current Density Value')
plt.ylabel('Count')

plt.tight_layout()
plt.show()

# Additional temporal analysis
time_span = target_df['DateTime'].max() - target_df['DateTime'].min()
span_days = time_span.total_seconds() / 86400  # fractional days, avoids truncation
print("\nTemporal Analysis:")
print(f"Total time span: {time_span}")
print(f"Number of days: {time_span.days}")
print(f"Average samples per day: {len(target_df) / span_days:.2f}")

# Save data statistics
data_stats = {
    'target_stats': target_df[['Be_0_0', 'Bn_0_0', 'Bu_0_0']].describe().to_dict(),
    'temporal_info': {
        'total_days': time_span.days,
        'samples_per_day': len(target_df) / span_days,
        'branch1_interval': str(branch1_interval),
        'branch2_interval': str(branch2_interval)
    }
}

with open(OUTPUT_DIR / 'data_statistics.json', 'w') as f:
    json.dump(data_stats, f, indent=4)
print("\nData statistics saved to:", OUTPUT_DIR / 'data_statistics.json')
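
Dividing the overall time span by the number of samples slightly underestimates the sampling interval, because n samples cover only n - 1 intervals, and any data gap inflates it. The median spacing between consecutive timestamps avoids both problems. A small pandas check on synthetic one-minute timestamps:

```python
import pandas as pd

# 181 one-minute samples spanning exactly 3 hours
times = pd.Series(pd.date_range('2024-01-01 00:00', periods=181, freq='1min'))

# Span divided by sample count: biased low (180 intervals / 181 samples)
naive = (times.max() - times.min()) / len(times)

# Median spacing between consecutive timestamps: exact for regular sampling
robust = times.diff().median()

# Add a second block after a one-hour gap: the span-based estimate inflates,
# while the median spacing is unchanged
with_gap = pd.concat([times, times + pd.Timedelta(hours=4)], ignore_index=True)
gap_naive = (with_gap.max() - with_gap.min()) / len(with_gap)
gap_robust = with_gap.diff().median()

print(naive, robust)
print(gap_naive, gap_robust)
```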

In [None]:
# 3. Data Pipeline

print("\nInitializing Data Pipeline...")
print(f"Forecast Horizon: {FORECAST_HORIZON} minutes")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Train/Val/Test Split: 70/15/15")

# Import the BatchGenerator
from batch_generator import BatchGenerator

# Define the valid observation window (hard-coded bounds)
valid_start = pd.Timestamp('2024-01-02 00:00:00')  # First valid observation
valid_end = pd.Timestamp('2024-12-31 23:30:00')    # Last valid observation

print("\nValid Observation Windows:")
print(f"Start of valid observations: {valid_start}")
print(f"End of valid observations: {valid_end}")
print(f"Total valid observation period: {(valid_end - valid_start).total_seconds() / 3600:.1f} hours")

# Initialize the batch generator with time-shifted targets
print("\nInitializing BatchGenerator...")
batch_generator = BatchGenerator(
    target_path=TARGET_PATH,
    geodata_path=GEODATA_PATH,
    secs_data_path=SECS_DATA_PATH,
    secs_timestamps_path=SECS_TIMESTAMPS_PATH,
    batch_size=BATCH_SIZE,
    train_ratio=0.7,
    val_ratio=0.15,
    forecast_horizon=FORECAST_HORIZON,
    valid_start=valid_start,
    valid_end=valid_end
)

# Create TensorFlow datasets for each split
print("\nCreating TensorFlow datasets...")
train_dataset = batch_generator.create_tf_dataset(split='train')
val_dataset = batch_generator.create_tf_dataset(split='val')
test_dataset = batch_generator.create_tf_dataset(split='test')

# Verify dataset shapes and temporal alignment
print("\nVerifying dataset shapes and temporal alignment:")
for name, dataset in [('Training', train_dataset), ('Validation', val_dataset), ('Test', test_dataset)]:
    print(f"\n{name} Dataset:")
    for batch in dataset.take(1):
        # Print shapes
        print(f"Branch 1 input shape: {batch['branch1_input'].shape}")  # (batch_size, 96, n_features)
        print(f"Branch 2 input shape: {batch['branch2_input'].shape}")  # (batch_size, 180, 21, 21, 3)
        print(f"Target shape: {batch['target'].shape}")  # (batch_size, 3)
        
        # Verify temporal windows
        print("\nTemporal window verification:")
        print(f"Branch 1: {BRANCH1_LOOKBACK} timesteps = {BRANCH1_LOOKBACK // 4} hours ending at current time t")
        print(f"Branch 2: {BRANCH2_LOOKBACK} timesteps = {BRANCH2_LOOKBACK // 60} hours ending at current time t")
        print(f"Target: {FORECAST_HORIZON} minutes ahead of current time t")
        
        # Get a sample from the batch
        sample_idx = 0  # First sample in batch
        
        # Get the corresponding indices for the split
        if name == 'Training':
            indices = batch_generator.train_indices
        elif name == 'Validation':
            indices = batch_generator.val_indices
        else:
            indices = batch_generator.test_indices
            
        if len(indices) > 0:
            # Use the first index of the split for a deterministic example. If the
            # dataset is shuffled, this timestamp need not match batch sample 0.
            idx = indices[0]
            
            # Get timestamps
            current_time = batch_generator.target_df['DateTime'].iloc[idx]
            target_time = current_time + pd.Timedelta(minutes=FORECAST_HORIZON)
            
            print("\nSample timestamps:")
            print(f"Current time (t): {current_time}")
            print(f"Branch 1 window: [{current_time - pd.Timedelta(hours=24)} ... {current_time}]")
            print(f"Branch 2 window: [{current_time - pd.Timedelta(hours=3)} ... {current_time}]")
            print(f"Target time: {target_time} (t + {FORECAST_HORIZON} min)")
            
            # Print sample values
            print("\nSample values:")
            print("Branch 1 input (first and last timestep):")
            print("First timestep (t-24h):", batch['branch1_input'][sample_idx, 0, :5].numpy(), "...")
            print("Last timestep (t):", batch['branch1_input'][sample_idx, -1, :5].numpy(), "...")
            
            print("\nBranch 2 input (first and last timestep, center pixel):")
            print("First timestep (t-3h):", batch['branch2_input'][sample_idx, 0, 10, 10, :].numpy())
            print("Last timestep (t):", batch['branch2_input'][sample_idx, -1, 10, 10, :].numpy())
            
            print(f"\nTarget values (t + {FORECAST_HORIZON} min):")
            print(f"Be: {batch['target'][sample_idx, 0].numpy():.2f}")
            print(f"Bn: {batch['target'][sample_idx, 1].numpy():.2f}")
            print(f"Bu: {batch['target'][sample_idx, 2].numpy():.2f}")
        break

# Visualize temporal alignment.
# Note: `batch` still holds the last batch inspected in the loop above (test split).
plt.figure(figsize=(15, 5))

# Plot sample sequences with timestamps
plt.subplot(131)
sample_branch1 = batch['branch1_input'][0, :, 0].numpy()  # First feature over time
plt.plot(sample_branch1, color=COLORS['train'], label='Branch 1')
plt.axvline(x=BRANCH1_LOOKBACK - 1, color='red', linestyle='--', label='Current time (t)')
plt.title('Branch 1 Sequence\n(24 hours ending at t)')
plt.xlabel('Timestep (15-min intervals)')
plt.ylabel('Feature Value')
plt.legend()

plt.subplot(132)
sample_branch2 = batch['branch2_input'][0, :, 10, 10, 0].numpy()  # Center pixel, first channel
plt.plot(sample_branch2, color=COLORS['train'], label='Branch 2')
plt.axvline(x=BRANCH2_LOOKBACK - 1, color='red', linestyle='--', label='Current time (t)')
plt.title('Branch 2 Sequence\n(3 hours ending at t)')
plt.xlabel('Timestep (1-min intervals)')
plt.ylabel('Current Density')
plt.legend()

plt.subplot(133)
target_components = ['Be', 'Bn', 'Bu']
target_values = batch['target'][0].numpy()
# One bar per component, so the 'Component' x-axis is meaningful
plt.bar(target_components, target_values,
        color=[COLORS[comp] for comp in target_components])
plt.title(f'Target Values\n(t + {FORECAST_HORIZON} min)')
plt.xlabel('Component')
plt.ylabel('Magnetic Field Value')

plt.tight_layout()
plt.show()

print("\nData pipeline setup complete. Ready for model training.")
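
The split ratios passed to `BatchGenerator` are 70/15/15, but how the generator assigns samples to splits is not shown in this notebook. For forecasting, a contiguous, time-ordered split is the usual choice, because splitting after a random shuffle lets nearly identical neighbouring windows appear in both training and evaluation data. A minimal sketch under that assumption; `chronological_split` and its `gap` option are hypothetical:

```python
import numpy as np

def chronological_split(n_samples, train_ratio=0.7, val_ratio=0.15, gap=0):
    """Hypothetical helper: contiguous, time-ordered train/val/test index blocks.

    `gap` drops that many samples at the start of the validation and test blocks,
    shrinking the overlap between their lookback windows and the preceding split.
    """
    n_train = int(n_samples * train_ratio)
    n_val = int(n_samples * val_ratio)
    idx = np.arange(n_samples)
    train = idx[:n_train]
    val = idx[n_train + gap:n_train + n_val]
    test = idx[n_train + n_val + gap:]
    return train, val, test

train, val, test = chronological_split(1000)
print(len(train), len(val), len(test))
```

Setting `gap` to the lookback length (in samples) would remove the overlap entirely, at the cost of discarding that many samples per boundary.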

## 4.1 Branch 1: LSTM Network
The first branch processes the large-scale solar and geophysical data using LSTM layers to capture temporal dependencies.
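
Branch 1 stacks two LSTM layers; the second uses `return_sequences=False`, so the 96-step sequence is reduced to its final hidden state before the dense embedding. The NumPy sketch below runs a single LSTM layer with random weights and the Keras default activations (tanh, with sigmoid gates) to show how a (96, 22) input becomes one 64-dimensional vector. It is a simplified illustration, not the Keras implementation; the gate block order (input, forget, candidate, output) follows the Keras convention.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_last_state(x, W, U, b):
    """Run one LSTM layer over x (timesteps, features); return the final hidden state.

    W: (features, 4*units), U: (units, 4*units), b: (4*units,).
    """
    units = U.shape[0]
    h = np.zeros(units)
    c = np.zeros(units)
    for x_t in x:
        z = x_t @ W + h @ U + b
        i = sigmoid(z[:units])                  # input gate
        f = sigmoid(z[units:2 * units])         # forget gate
        g = np.tanh(z[2 * units:3 * units])     # candidate cell content
        o = sigmoid(z[3 * units:])              # output gate
        c = f * c + i * g                       # keep part of the memory, write new content
        h = o * np.tanh(c)                      # gated view of the memory
    return h

rng = np.random.default_rng(0)
timesteps, features, units = 96, 22, 64
x = rng.normal(size=(timesteps, features))
W = rng.normal(scale=0.1, size=(features, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)
h_final = lstm_last_state(x, W, U, b)
print(h_final.shape)  # (64,)
```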

In [None]:

# Branch 1 input shape: (timesteps, features). EMBEDDING_DIM is defined in Section 1.
branch1_shape = (BRANCH1_LOOKBACK, sum(len(cols) for cols in feature_groups.values()))  # (96, 22)

def create_branch1_model(input_shape, embedding_dim=64, dropout_rate=0.2):
    """Create the LSTM-based model for Branch 1 (large-scale data).
    
    Args:
        input_shape: Shape of the input data (timesteps, features)
        embedding_dim: Dimension of the embedding space
        dropout_rate: Dropout rate for regularization
        
    Returns:
        Keras model for Branch 1
    """
    inputs = layers.Input(shape=input_shape, name='branch1_input')
    
    # First LSTM layer with return sequences
    x = layers.LSTM(embedding_dim, return_sequences=True, 
                    name='branch1_lstm1')(inputs)
    x = layers.BatchNormalization(name='branch1_bn1')(x)
    x = layers.Dropout(dropout_rate, name='branch1_dropout1')(x)
    
    # Second LSTM layer
    x = layers.LSTM(embedding_dim, return_sequences=False, 
                    name='branch1_lstm2')(x)
    x = layers.BatchNormalization(name='branch1_bn2')(x)
    x = layers.Dropout(dropout_rate, name='branch1_dropout2')(x)
    
    # Dense layer to create embedding
    outputs = layers.Dense(embedding_dim, activation='relu', 
                          name='branch1_embedding')(x)
    
    return Model(inputs=inputs, outputs=outputs, name='branch1_model')

# Test Branch 1
print("\nTesting Branch 1:")
branch1_model = create_branch1_model(branch1_shape, EMBEDDING_DIM)
branch1_model.summary()
# Test with sample input
sample_branch1 = tf.random.normal((1, *branch1_shape))
branch1_output = branch1_model(sample_branch1)
print(f"Branch 1 output shape: {branch1_output.shape}")  # Should be (1, 64)

## 4.2 Branch 2: CNN-LSTM Network
The second branch processes the SECS grid data using a combination of convolutional and LSTM layers to capture both spatial and temporal patterns.
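
The shape comments in Branch 2 follow from two size rules: a stride-1 'valid' convolution with a k x k kernel maps n to n - k + 1, and 2x2 max pooling with the default 'valid' padding maps n to floor(n / 2). A short check of the 21x21 SECS grid through both Conv2D and pooling stages:

```python
def conv_valid(n, kernel=3):
    """Output size of a stride-1 'valid' convolution."""
    return n - kernel + 1

def max_pool(n, pool=2):
    """Output size of non-overlapping 'valid' pooling (floors odd sizes)."""
    return n // pool

size = 21
trace = [size]
for _ in range(2):          # two Conv2D + MaxPooling2D stages
    size = conv_valid(size)
    trace.append(size)
    size = max_pool(size)
    trace.append(size)

filters_last = 32           # filters in the second Conv2D layer
flat_features = size * size * filters_last
print(trace, flat_features)  # [21, 19, 9, 7, 3] 288
```

Each of the 180 time steps therefore enters the first Branch 2 LSTM as a 288-dimensional vector.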

In [None]:

# Branch 2 input shape: (timesteps, height, width, channels)
branch2_shape = (BRANCH2_LOOKBACK, 21, 21, 3)

def create_branch2_model(input_shape, embedding_dim=64, dropout_rate=0.2):
    """Create the CNN-LSTM model for Branch 2 (SECS data).
    
    This model:
    1. Processes each time step with Conv2D layers to extract spatial features
    2. Flattens the spatial features
    3. Feeds the flattened features through LSTM layers to capture temporal patterns
    
    Args:
        input_shape: Shape of the input data (timesteps, height, width, channels)
        embedding_dim: Dimension of the embedding space
        dropout_rate: Dropout rate for regularization
        
    Returns:
        Keras model for Branch 2
    """
    inputs = layers.Input(shape=input_shape, name='branch2_input')
    
    # Process each time step with Conv2D layers
    # We'll use a TimeDistributed wrapper to apply the same Conv2D operations to each time step
    
    # First Conv2D layer: 21x21 -> 19x19
    x = layers.TimeDistributed(
        layers.Conv2D(64, kernel_size=(3, 3), activation='relu', padding='valid'),
        name='branch2_conv1'
    )(inputs)
    x = layers.TimeDistributed(
        layers.BatchNormalization(),
        name='branch2_bn1'
    )(x)
    x = layers.TimeDistributed(
        layers.MaxPooling2D(pool_size=(2, 2)),
        name='branch2_pool1'
    )(x)  # 19x19 -> 9x9
    
    # Second Conv2D layer: 9x9 -> 7x7
    x = layers.TimeDistributed(
        layers.Conv2D(32, kernel_size=(3, 3), activation='relu', padding='valid'),
        name='branch2_conv2'
    )(x)
    x = layers.TimeDistributed(
        layers.BatchNormalization(),
        name='branch2_bn2'
    )(x)
    x = layers.TimeDistributed(
        layers.MaxPooling2D(pool_size=(2, 2)),
        name='branch2_pool2'
    )(x)  # 7x7 -> 3x3
    
    # Flatten the spatial dimensions for each time step
    # Shape: (batch, timesteps, 3*3*32)
    x = layers.TimeDistributed(
        layers.Flatten(),
        name='branch2_flatten'
    )(x)
    
    # Apply dropout to the flattened features
    x = layers.Dropout(dropout_rate, name='branch2_dropout1')(x)
    
    # First LSTM layer with return sequences
    x = layers.LSTM(embedding_dim, return_sequences=True, 
                    name='branch2_lstm1')(x)
    x = layers.BatchNormalization(name='branch2_bn3')(x)
    x = layers.Dropout(dropout_rate, name='branch2_dropout2')(x)
    
    # Second LSTM layer
    x = layers.LSTM(embedding_dim, return_sequences=False, 
                    name='branch2_lstm2')(x)
    x = layers.BatchNormalization(name='branch2_bn4')(x)
    x = layers.Dropout(dropout_rate, name='branch2_dropout3')(x)
    
    # Dense layer to create embedding
    outputs = layers.Dense(embedding_dim, activation='relu', 
                          name='branch2_embedding')(x)
    
    return Model(inputs=inputs, outputs=outputs, name='branch2_model')

print("\nTesting Branch 2:")
branch2_model = create_branch2_model(branch2_shape, EMBEDDING_DIM)
branch2_model.summary()
# Test with sample input
sample_branch2 = tf.random.normal((1, *branch2_shape))
branch2_output = branch2_model(sample_branch2)
print(f"Branch 2 output shape: {branch2_output.shape}")  # Should be (1, 64)

## 4.3 Cross-Attention Fusion
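Each branch is reduced to a single embedding of size `EMBEDDING_DIM`. The cross-attention layer projects one branch's embedding to a query and the other's to keys and values, splits them into heads, and lets each query head attend over the other branch's heads. The layer is applied in both directions, and the two attended embeddings are concatenated for the decoder.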
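
Because each branch delivers one embedding vector, the `CrossAttention` layer reshapes its projections to (batch, 1, num_heads, key_dim) and multiplies over the last two axes, so the scores compare heads with heads rather than time steps. The NumPy sketch below reproduces that arithmetic with random stand-ins for the projected query, key, and value, and contrasts it with standard per-head attention over a single key token, where the softmax weight is always 1 and the query has no effect. No trained weights are involved.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
batch, num_heads, key_dim = 2, 4, 16

# Random stand-ins for the projected query (one branch) and key/value (other branch)
q = rng.normal(size=(batch, num_heads * key_dim))
k = rng.normal(size=(batch, num_heads * key_dim))
v = rng.normal(size=(batch, num_heads * key_dim))

# (a) Head-to-head attention, mirroring the reshape and matmul in CrossAttention
qh = q.reshape(batch, 1, num_heads, key_dim)
kh = k.reshape(batch, 1, num_heads, key_dim)
vh = v.reshape(batch, 1, num_heads, key_dim)
scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(key_dim)   # (batch, 1, heads, heads)
weights = softmax(scores, axis=-1)
head_mix = (weights @ vh).reshape(batch, num_heads * key_dim)

# (b) Standard per-head attention over a single key token
qs = q.reshape(batch, num_heads, 1, key_dim)
ks = k.reshape(batch, num_heads, 1, key_dim)
vs = v.reshape(batch, num_heads, 1, key_dim)
single = softmax(qs @ ks.transpose(0, 1, 3, 2) / np.sqrt(key_dim), axis=-1)  # (batch, heads, 1, 1)
token_out = (single @ vs).reshape(batch, num_heads * key_dim)

print(scores.shape)               # (2, 1, 4, 4)
print(np.allclose(single, 1.0))   # True: a single key always gets weight 1
print(np.allclose(token_out, v))  # True: output equals the values; query ignored
```

Attending over sequences instead (for example by keeping `return_sequences=True` in the branch LSTMs) would give the query something to select among; that is a design alternative, not what this notebook implements.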

In [None]:
class CrossAttention(layers.Layer):
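    """Cross-attention between two embedding vectors, one per branch.
    
    The first input provides the query and the second the keys and values.
    Each branch is summarized as a single vector, so the projections are split
    into `num_heads` chunks of size `key_dim` and every query head attends over
    the key/value heads of the other branch (scores have shape
    (batch, 1, num_heads, num_heads)). Standard attention over a single key
    token would always assign it weight 1, making the query irrelevant.
    """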
    
    def __init__(self, num_heads=4, key_dim=16, **kwargs):
        super(CrossAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        
    def build(self, input_shape):
        # Create projection matrices for query, key, and value
        self.query_dense = layers.Dense(self.num_heads * self.key_dim)
        self.key_dense = layers.Dense(self.num_heads * self.key_dim)
        self.value_dense = layers.Dense(self.num_heads * self.key_dim)
        
        # Output projection
        self.output_dense = layers.Dense(input_shape[0][-1])
        
        super(CrossAttention, self).build(input_shape)
        
    def call(self, inputs):
        # Unpack inputs
        query, key_value = inputs
        
        # Project inputs to query, key, and value
        query = self.query_dense(query)
        key = self.key_dense(key_value)
        value = self.value_dense(key_value)
        
        # Split each projection into heads: (batch, 1, num_heads, key_dim)
        batch_size = tf.shape(query)[0]
        
        query = tf.reshape(query, [batch_size, 1, self.num_heads, self.key_dim])
        key = tf.reshape(key, [batch_size, 1, self.num_heads, self.key_dim])
        value = tf.reshape(value, [batch_size, 1, self.num_heads, self.key_dim])
        
        # Head-to-head scores, shape (batch, 1, num_heads, num_heads): each query head attends over the key heads
        attention_scores = tf.matmul(query, key, transpose_b=True)
        attention_scores = attention_scores / tf.math.sqrt(tf.cast(self.key_dim, tf.float32))
        attention_weights = tf.nn.softmax(attention_scores, axis=-1)
        
        # Apply attention weights to values
        attention_output = tf.matmul(attention_weights, value)
        
        # Reshape and project output
        attention_output = tf.reshape(attention_output, [batch_size, self.num_heads * self.key_dim])
        output = self.output_dense(attention_output)
        
        return output
    
    def get_config(self):
        config = super(CrossAttention, self).get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim
        })
        return config
    
# Test Cross-Attention
print("\nTesting Cross-Attention:")
cross_attention = CrossAttention(num_heads=4, key_dim=16)
# Test with sample inputs
attended_features = cross_attention([branch1_output, branch2_output])
print(f"Cross-attention output shape: {attended_features.shape}")  # Should be (1, 64)

# Test reverse direction
attended_features_reverse = cross_attention([branch2_output, branch1_output])
print(f"Reverse cross-attention output shape: {attended_features_reverse.shape}")  # Should be (1, 64)


## 4.4 Decoder
The decoder takes the fused features and produces the final predictions for the three magnetic field components.
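
The decoder input is 2 * EMBEDDING_DIM = 128 wide because the two attended embeddings are concatenated. As a cross-check on `decoder.summary()`, the parameter count follows directly from the layer sizes: a Dense layer has n_in * n_out weights plus n_out biases, and BatchNormalization has trainable gamma and beta plus non-trainable moving mean and variance. The totals below should agree with the summary.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def batchnorm_params(n):
    """gamma and beta are trainable; moving mean and variance are not."""
    return {'trainable': 2 * n, 'non_trainable': 2 * n}

embedding_dim = 64
n_in = 2 * embedding_dim            # two concatenated branch embeddings

trainable = 0
non_trainable = 0
for n_out in (128, 64):             # hidden Dense + BatchNormalization blocks
    trainable += dense_params(n_in, n_out)
    bn = batchnorm_params(n_out)
    trainable += bn['trainable']
    non_trainable += bn['non_trainable']
    n_in = n_out
trainable += dense_params(n_in, 3)  # linear output for Be, Bn, Bu

print(trainable, non_trainable, trainable + non_trainable)  # 25347 384 25731
```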

In [None]:
# 4.4 Decoder
def create_decoder(embedding_dim, dropout_rate=0.2):
    """
    Creates the decoder part of the model that produces final predictions.
    
    Args:
        embedding_dim (int): Dimension of the input features (will be doubled due to concatenation)
        dropout_rate (float): Dropout rate for regularization
        
    Returns:
        tf.keras.Model: Decoder model
    """
    # Note: input shape is 2*embedding_dim because we concatenate two vectors
    inputs = layers.Input(shape=(2*embedding_dim,))
    
    # Dense layers with batch normalization and dropout
    x = layers.Dense(128, activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    
    x = layers.Dense(64, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    
    # Output layer for 3 magnetic field components
    outputs = layers.Dense(3, activation='linear', name='decoder_output')(x)
    
    return tf.keras.Model(inputs=inputs, outputs=outputs, name='decoder')

# Test Decoder
print("\nTesting Decoder:")
decoder = create_decoder(EMBEDDING_DIM)
decoder.summary()
# Test with sample input: two concatenated embeddings, i.e. 2 * EMBEDDING_DIM = 128 features
sample_fused = tf.random.normal((1, 2 * EMBEDDING_DIM))
decoder_output = decoder(sample_fused)
print(f"Decoder output shape: {decoder_output.shape}")  # Should be (1, 3)


## 4.5 Complete Model
Now we'll combine all components to create the complete dual-branch neural network model.

In [None]:

# NUM_HEADS and DROPOUT_RATE are defined in Section 1 (Setup and Imports)

def create_dual_branch_model(branch1_shape, branch2_shape, embedding_dim=64, num_heads=4, dropout_rate=0.2):
    """Create the complete dual-branch neural network model.
    
    Args:
        branch1_shape: Shape of Branch 1 input (timesteps, features)
        branch2_shape: Shape of Branch 2 input (timesteps, height, width, channels)
        embedding_dim: Dimension of the embedding space
        num_heads: Number of attention heads
        dropout_rate: Dropout rate for regularization
        
    Returns:
        Complete Keras model
    """
    # Create input layers
    branch1_input = layers.Input(shape=branch1_shape, name='branch1_input')
    branch2_input = layers.Input(shape=branch2_shape, name='branch2_input')
    
    # Create branch models
    branch1_model = create_branch1_model(branch1_shape, embedding_dim, dropout_rate)
    branch2_model = create_branch2_model(branch2_shape, embedding_dim, dropout_rate)
    
    # Process inputs through branch models
    branch1_features = branch1_model(branch1_input)
    branch2_features = branch2_model(branch2_input)
    
    # Apply cross-attention in both directions
    branch1_attended = CrossAttention(num_heads=num_heads, key_dim=embedding_dim//num_heads, 
                                      name='branch1_attention')([branch1_features, branch2_features])
    branch2_attended = CrossAttention(num_heads=num_heads, key_dim=embedding_dim//num_heads, 
                                      name='branch2_attention')([branch2_features, branch1_features])
    
    # Concatenate attended features
    fused_features = layers.Concatenate(name='feature_fusion')([branch1_attended, branch2_attended])
    
    # Create and apply decoder
    decoder = create_decoder(embedding_dim, dropout_rate)
    outputs = decoder(fused_features)
    
    # Create and return the complete model
    model = Model(inputs=[branch1_input, branch2_input], outputs=outputs, name='dual_branch_model')
    
    return model

# Test Complete Model
print("\nTesting Complete Model:")
model = create_dual_branch_model(
    branch1_shape=branch1_shape,
    branch2_shape=branch2_shape,
    embedding_dim=EMBEDDING_DIM,
    num_heads=NUM_HEADS,
    dropout_rate=DROPOUT_RATE
)
model.summary()

# Test with sample inputs
sample_output = model([sample_branch1, sample_branch2])
print(f"Complete model output shape: {sample_output.shape}")  # Should be (1, 3)

## 4.6 Model Summary and Visualization
Rebuild the model with input shapes taken from an actual training batch, so that it matches the data pipeline, and visualize its architecture.

In [None]:

# Get input shapes from the datasets
sample_batch = next(iter(train_dataset))
branch1_shape = sample_batch['branch1_input'].shape[1:]
branch2_shape = sample_batch['branch2_input'].shape[1:]

print(f"Branch 1 input shape: {branch1_shape}")
print(f"Branch 2 input shape: {branch2_shape}")

# Create the model
model = create_dual_branch_model(
    branch1_shape=branch1_shape,
    branch2_shape=branch2_shape,
    embedding_dim=EMBEDDING_DIM,
    num_heads=NUM_HEADS,
    dropout_rate=DROPOUT_RATE
)

# Display model summary
model.summary()

# Visualize model architecture
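try:
    import os
    from tensorflow.keras.utils import plot_model
    plot_model(model, to_file='model_architecture.png', show_shapes=True)
    # In notebooks, plot_model may print a warning and return without writing
    # the file when pydot/graphviz are missing, so check that the file exists
    if os.path.exists('model_architecture.png'):
        print("\nModel architecture saved to 'model_architecture.png'")
    else:
        print("\nCould not visualize model architecture. Install pydot and graphviz for visualization.")
except ImportError:
    print("\nCould not visualize model architecture. Install pydot and graphviz for visualization.")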

## 5. Training

In [None]:
## 5. Training
# 5.1 Define Loss Functions
def custom_loss(y_true, y_pred):
    """
    Sum of the per-component mean squared errors for the magnetic field
    components (Be, Bn, Bu). All three components are weighted equally,
    so the result equals three times the MSE over all components.
    """
    # Split predictions and targets into components
    be_pred, bn_pred, bu_pred = tf.unstack(y_pred, axis=1)
    be_true, bn_true, bu_true = tf.unstack(y_true, axis=1)
    
    # Calculate MSE for each component
    be_loss = tf.keras.losses.mean_squared_error(be_true, be_pred)
    bn_loss = tf.keras.losses.mean_squared_error(bn_true, bn_pred)
    bu_loss = tf.keras.losses.mean_squared_error(bu_true, bu_pred)
    
    # Equal weights for all three components
    return be_loss + bn_loss + bu_loss

# 5.2 Configure Optimizer with Gradient Clipping
optimizer = tf.keras.optimizers.Adam(
    learning_rate=5e-4,  # half of Adam's default learning rate (1e-3)
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    clipnorm=1.0  # clip each weight's gradient to a norm of at most 1.0
)

# 5.3 Set up Callbacks
# Halve the learning rate when val_loss has not improved for 7 epochs
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=7,
    min_lr=1e-6,
    verbose=1
)

# Stop after 15 epochs without val_loss improvement and restore the best weights
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

# Model checkpointing
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1
)

# TensorBoard logging
tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir='./logs',
    write_graph=True,
    update_freq='epoch'
)

# 5.4 Compile Model
model.compile(
    optimizer=optimizer,
    loss=custom_loss,
    metrics=[
        tf.keras.metrics.MeanSquaredError(name='mse'),
        tf.keras.metrics.RootMeanSquaredError(name='rmse')
    ]
)

# 5.5 Train Model
EPOCHS = 100
# BATCH_SIZE is only used below to compute step counts; it must match the
# batch size the tf.data datasets were built with
BATCH_SIZE = 64

# Sample counts copied from the earlier data split calculation (hard-coded;
# update them if the split changes)
train_samples = 24527
val_samples = 5255

# Floor division: each epoch draws only full batches
train_steps = train_samples // BATCH_SIZE
val_steps = val_samples // BATCH_SIZE

print(f"\nDataset sizes:")
print(f"Training steps per epoch: {train_steps}")
print(f"Validation steps per epoch: {val_steps}")
print(f"Total training samples: {train_samples}")
print(f"Total validation samples: {val_samples}\n")

# Create a function to extract inputs and targets from the dataset
def prepare_dataset(dataset):
    return dataset.map(lambda x: (
        {'branch1_input': x['branch1_input'], 'branch2_input': x['branch2_input']},
        x['target']
    ))

# Prepare the datasets
train_dataset_prepared = prepare_dataset(train_dataset)
val_dataset_prepared = prepare_dataset(val_dataset)

# Train the model
history = model.fit(
    train_dataset_prepared,
    validation_data=val_dataset_prepared,
    epochs=EPOCHS,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=[
        lr_scheduler,
        early_stopping,
        checkpoint,
        tensorboard
    ],
    verbose=1
)

# 5.6 Plot Training History
plt.figure(figsize=(15, 5))

# Plot loss
plt.subplot(1, 3, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot MSE
plt.subplot(1, 3, 2)
plt.plot(history.history['mse'], label='Training MSE')
plt.plot(history.history['val_mse'], label='Validation MSE')
plt.title('Model MSE')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.legend()

# Plot RMSE
plt.subplot(1, 3, 3)
plt.plot(history.history['rmse'], label='Training RMSE')
plt.plot(history.history['val_rmse'], label='Validation RMSE')
plt.title('Model RMSE')
plt.xlabel('Epoch')
plt.ylabel('RMSE')
plt.legend()

plt.tight_layout()
plt.show()

# 5.7 Save Training History
# Convert numpy types to Python native types
history_dict = {}
for key, values in history.history.items():
    history_dict[key] = [float(v) for v in values]

with open('training_history.json', 'w') as f:
    json.dump(history_dict, f)
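`custom_loss` adds the three per-component MSEs with equal weights. Because every component has the same number of samples, this equals three times the MSE over all components, so the reported `loss` should track roughly three times the `mse` metric. If one component should count more, one option is a per-component weight vector. The NumPy sketch below checks the factor-of-3 identity on a toy batch and shows such a weighting. The function names and the weights `(1.0, 1.0, 2.0)` are purely illustrative.

```python
import numpy as np


def per_component_mse(y_true, y_pred):
    # MSE of each column (Be, Bn, Bu), averaged over the batch axis
    return np.mean((y_true - y_pred) ** 2, axis=0)


def weighted_component_loss(y_true, y_pred, weights=(1.0, 1.0, 1.0)):
    # Weighted sum of per-component MSEs; all-ones weights mirror custom_loss
    w = np.asarray(weights, dtype=float)
    return float(np.dot(w, per_component_mse(y_true, y_pred)))


rng = np.random.default_rng(42)
y_true = rng.normal(size=(64, 3))                       # toy batch: 64 x (Be, Bn, Bu)
y_pred = y_true + rng.normal(scale=0.1, size=(64, 3))   # toy predictions

equal_weights = weighted_component_loss(y_true, y_pred)
overall_mse = float(np.mean((y_true - y_pred) ** 2))
print(np.isclose(equal_weights, 3 * overall_mse))  # True

# Hypothetical weighting that counts Bu twice (illustration only)
bu_weighted = weighted_component_loss(y_true, y_pred, weights=(1.0, 1.0, 2.0))
```

In TensorFlow the same idea could be written as `tf.reduce_sum(weights * tf.reduce_mean(tf.square(y_true - y_pred), axis=0))` with `weights` a float tensor of shape `(3,)`. The weight values would have to be chosen and validated for this task.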
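The two callbacks interact predictably. Keras's `ReduceLROnPlateau` multiplies the rate by `factor` but never goes below `min_lr`. Starting from `5e-4`, it therefore takes nine halvings to reach the `1e-6` floor: eight give about `1.95e-6`, and the ninth is clamped. During one uninterrupted plateau, the scheduler (patience 7, default `cooldown=0`) fires after 7 and 14 epochs, and `EarlyStopping` (patience 15) ends training after 15. A plateau can therefore cause at most two reductions before training stops.

The pure-Python sketch below models this arithmetic with simplified versions of the callbacks. It ignores one difference in their defaults: `ReduceLROnPlateau` uses `min_delta=1e-4`, while `EarlyStopping` uses `min_delta=0`. A very small improvement can therefore reset one counter but not the other.

```python
def lr_after_reductions(initial_lr, factor, min_lr, n):
    """Learning rate after n plateau reductions, clamped at min_lr."""
    lr = initial_lr
    for _ in range(n):
        lr = max(lr * factor, min_lr)
    return lr


def reductions_to_floor(initial_lr, factor, min_lr):
    """Number of reductions until the rate first reaches min_lr."""
    lr, n = initial_lr, 0
    while lr > min_lr:
        lr = max(lr * factor, min_lr)
        n += 1
    return n


def plateau_events(lr_patience, es_patience):
    """Epochs (counted from the last improvement) at which the LR is reduced,
    and the epoch at which training stops, assuming neither callback sees an
    improvement during the plateau and cooldown=0."""
    reductions = list(range(lr_patience, es_patience + 1, lr_patience))
    return reductions, es_patience


print(reductions_to_floor(5e-4, 0.5, 1e-6))  # 9
print(plateau_events(7, 15))                 # ([7, 14], 15)
```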
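Because `train_steps` and `val_steps` use floor division, each epoch draws only full batches. With 24,527 training samples and a batch size of 64, that is 383 steps covering 24,512 samples, 15 fewer than the training set. Validation runs 82 steps covering 5,248 of the 5,255 samples, 7 fewer. If the datasets yield a final partial batch, rounding up would cover every sample instead. A small sketch of both options; the helper name `steps_for` is illustrative.

```python
import math


def steps_for(num_samples, batch_size, include_partial=False):
    # Floor division counts only full batches; ceil also counts the last partial one
    if include_partial:
        return math.ceil(num_samples / batch_size)
    return num_samples // batch_size


BATCH_SIZE = 64
for name, n in [("train", 24527), ("val", 5255)]:
    full = steps_for(n, BATCH_SIZE)
    covered = full * BATCH_SIZE
    print(f"{name}: {full} steps, {covered} samples per epoch, {n - covered} left over; "
          f"ceil gives {steps_for(n, BATCH_SIZE, include_partial=True)} steps")
```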

In [None]:
## 7. Visualization of Predictions
import matplotlib.dates as mdates

# 7.1 Generate Predictions for a Single Batch
print("\nGenerating predictions for a single batch...")

# Load the best model
best_model = tf.keras.models.load_model('best_model.h5', 
                                      custom_objects={
                                          'CrossAttention': CrossAttention,
                                          'custom_loss': custom_loss
                                      })

# Prepare test dataset
test_dataset_prepared = prepare_dataset(test_dataset)

# Get a single batch
dataset_iterator = iter(test_dataset_prepared)
batch = next(dataset_iterator)
inputs, targets = batch
targets = targets.numpy()  # convert the tf.Tensor to a NumPy array for pandas
predictions = best_model.predict(inputs, verbose=0)

# Get timestamps for this batch. This assumes the test dataset is not shuffled,
# so the i-th sample of the first batch corresponds to test_indices[i]
batch_indices = batch_generator.test_indices[:len(targets)]
batch_timestamps = pd.to_datetime(batch_generator.target_df['DateTime'].iloc[batch_indices])

# Create a DataFrame for this batch
results_df = pd.DataFrame({
    'DateTime': batch_timestamps,
    'Be_true': targets[:, 0],
    'Bn_true': targets[:, 1],
    'Bu_true': targets[:, 2],
    'Be_pred': predictions[:, 0],
    'Bn_pred': predictions[:, 1],
    'Bu_pred': predictions[:, 2]
})

# Get unique days in this batch
unique_days = results_df['DateTime'].dt.date.unique()
print(f"\nFound {len(unique_days)} unique days in this batch")

# Select 3 random days (or all days if fewer than 3)
num_days = min(3, len(unique_days))
random_days = np.random.choice(unique_days, size=num_days, replace=False)
random_days.sort()  # Sort for chronological order

# 7.2 Create Visualization
plt.figure(figsize=(20, 15))

# Components and their labels
components = ['Be', 'Bn', 'Bu']
component_labels = ['East', 'North', 'Up']
# Create subplots (3 components × number of days)
for day_idx, day in enumerate(random_days):
    # Get data for this day
    day_data = results_df[results_df['DateTime'].dt.date == day]
    day_times = day_data['DateTime']
    
    # Print time range for debugging
    print(f"\nDay {day}:")
    print(f"Time range: {day_times.min()} to {day_times.max()}")
    print(f"Number of samples: {len(day_times)}")
    
    # Plot each component
    for comp_idx, (comp, label) in enumerate(zip(components, component_labels)):
        subplot_idx = day_idx * 3 + comp_idx + 1
        plt.subplot(num_days, 3, subplot_idx)
        
        # Plot true and predicted values
        plt.plot(day_times, day_data[f'{comp}_true'], 'b-', label='True', alpha=0.7)
        plt.plot(day_times, day_data[f'{comp}_pred'], 'r--', label='Predicted', alpha=0.7)
        
        # Customize plot
        plt.title(f'{label} Component - {day}\n{day_times.min().strftime("%H:%M")} to {day_times.max().strftime("%H:%M")}')
        plt.xlabel('Time')
        plt.ylabel('Magnetic Field (nT)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Format x-axis to show hours
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        plt.xticks(rotation=45)
        
        # Add RMSE for this day and component
        rmse = np.sqrt(np.mean((day_data[f'{comp}_true'] - day_data[f'{comp}_pred'])**2))
        plt.text(0.02, 0.98, f'RMSE: {rmse:.2f}', 
                transform=plt.gca().transAxes, 
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

# 7.3 Save Sample Predictions
sample_results = {'dates': [str(day) for day in random_days], 'predictions': {}}
for day in random_days:
    day_data = results_df[results_df['DateTime'].dt.date == day]
    sample_results['predictions'][str(day)] = {
        'times': [str(t) for t in day_data['DateTime']],
        'true_values': day_data[['Be_true', 'Bn_true', 'Bu_true']].values.tolist(),
        'predicted_values': day_data[['Be_pred', 'Bn_pred', 'Bu_pred']].values.tolist()
    }

with open('sample_predictions.json', 'w') as f:
    json.dump(sample_results, f, indent=4)
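The figure reports RMSE only for the (at most three) sampled days. To summarize every day in `results_df` at once, a per-day, per-component RMSE can be computed with NumPy grouping. A minimal sketch follows, assuming timezone-naive timestamps. The function name `rmse_by_day` and the synthetic two-day arrays are illustrative stand-ins for the `DateTime`, `*_true`, and `*_pred` columns.

```python
import numpy as np


def rmse_by_day(timestamps, y_true, y_pred):
    """Per-day RMSE for each component.

    timestamps: datetime64 array, shape (n,)
    y_true, y_pred: arrays of shape (n, n_components)
    Returns (days, rmse) where rmse has shape (n_days, n_components).
    """
    days = np.asarray(timestamps).astype("datetime64[D]")   # truncate to calendar day
    unique_days, day_idx = np.unique(days, return_inverse=True)
    day_idx = day_idx.ravel()
    counts = np.bincount(day_idx, minlength=len(unique_days))
    sq_err = (np.asarray(y_true) - np.asarray(y_pred)) ** 2
    rmse = np.empty((len(unique_days), sq_err.shape[1]))
    for c in range(sq_err.shape[1]):
        # Sum of squared errors per day for component c
        sums = np.bincount(day_idx, weights=sq_err[:, c], minlength=len(unique_days))
        rmse[:, c] = np.sqrt(sums / counts)
    return unique_days, rmse


# Synthetic data: 24 h of 15-min samples starting at noon, so two calendar days
start = np.datetime64("2024-03-01T12:00")
times = start + np.arange(96) * np.timedelta64(15, "m")
y_true = np.zeros((96, 3))
on_first_day = (times < np.datetime64("2024-03-02"))[:, None]
y_pred = np.where(on_first_day, 1.0, 2.0) * np.ones((96, 3))  # 1 nT error, then 2 nT

days, rmse = rmse_by_day(times, y_true, y_pred)
print(days)
print(rmse)
```

Applied to the notebook's data, it would be called as `rmse_by_day(results_df['DateTime'].to_numpy(), results_df[['Be_true', 'Bn_true', 'Bu_true']].to_numpy(), results_df[['Be_pred', 'Bn_pred', 'Bu_pred']].to_numpy())`.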