In [1]:
import xarray as xr
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

In [2]:
path = 'C:/Users/ABHISHEK/OneDrive/Documents/SSH Data/ostst-single-layer-fd-lat-40-urms-5-kf-13-kr-4-beta.nc'

ds = xr.open_dataset(path)

In [3]:
ssh = ds['ssh']  


x_norm = (ds.x - ds.x.min()) / (ds.x.max() - ds.x.min())
y_norm = (ds.y - ds.y.min()) / (ds.y.max() - ds.y.min())

In [4]:
ssh_mean = ssh.mean()
ssh_std = ssh.std()
ssh_norm = (ssh - ssh_mean) / ssh_std

In [5]:
print(f"SSH Original range: [{ssh.min().values:.4f}, {ssh.max().values:.4f}]")
print(f"SSH Normalized range: [{ssh_norm.min().values:.4f}, {ssh_norm.max().values:.4f}]")
print(f"SSH Mean: {ssh_mean.values:.4f}, SSH Std: {ssh_std.values:.4f}")

SSH Original range: [-0.0719, 0.0775]
SSH Normalized range: [-3.3165, 3.5742]
SSH Mean: 0.0000, SSH Std: 0.0217


In [41]:
def create_enhanced_ssh_dataset(ssh_norm, max_slices=100, train_ratio=0.8):
    """
    Create enhanced training dataset utilizing maximum available data
    """
    total_available = len(ssh_norm.t)
    print(f"Total available time slices: {total_available}")
    
    # Calculating max_slices
    if total_available > max_slices:
        step_size = max(1, total_available // max_slices)
        time_indices = np.arange(0, total_available, step_size)[:max_slices]
    else:
        time_indices = np.arange(0, total_available)
    
    print(f"Using {len(time_indices)} time slices (step size: {step_size if total_available > max_slices else 1})")
    
    # Extract SSH slices
    ssh_slices = ssh_norm.isel(t=time_indices).values
    
    # Split into training and validation
    split_idx = int(train_ratio * len(ssh_slices))
    
    train_slices = ssh_slices[:split_idx]
    val_slices = ssh_slices[split_idx:]
    
    print(f"Training slices: {len(train_slices)}")
    print(f"Validation slices: {len(val_slices)}")
    
    return train_slices, val_slices, len(time_indices)

In [42]:
ssh_train_enhanced, ssh_val_enhanced, total_slices_used = create_enhanced_ssh_dataset(ssh_norm, max_slices=100, train_ratio=0.8)

Total available time slices: 1826
Using 100 time slices (step size: 18)
Training slices: 80
Validation slices: 20


In [43]:
def create_training_data(ssh_slices, missing_rate=0.15):
    """Create training data with masks for multiple time slices"""
    n_times, height, width = ssh_slices.shape
    
    inputs = []
    targets = []
    
    for i in range(n_times):
        ssh_slice = ssh_slices[i]
        
        
        mask = np.ones_like(ssh_slice)
        missing_pixels = np.random.rand(*ssh_slice.shape) < missing_rate
        mask[missing_pixels] = 0
        
       
        masked_ssh = ssh_slice * mask
        
        
        input_tensor = np.stack([masked_ssh, mask], axis=-1)
        target_tensor = np.stack([ssh_slice, mask], axis=-1)
        
        inputs.append(input_tensor)
        targets.append(target_tensor)
    
    return np.array(inputs), np.array(targets)

In [44]:
X_train_enhanced, y_train_enhanced = create_training_data(ssh_train_enhanced, missing_rate=0.15)
X_val_enhanced, y_val_enhanced = create_training_data(ssh_val_enhanced, missing_rate=0.15)

In [45]:

print(f"Enhanced training data shape: {X_train_enhanced.shape}")
print(f"Enhanced validation data shape: {X_val_enhanced.shape}")
print(f"Total data utilization: {total_slices_used/len(ssh_norm.t)*100:.1f}%")

Enhanced training data shape: (80, 256, 256, 2)
Enhanced validation data shape: (20, 256, 256, 2)
Total data utilization: 5.5%


In [49]:
def build_enhanced_unet_2d(input_shape, dropout_rate=0.2):
    inputs = layers.Input(shape=input_shape)
    
    # Encoder
    c1 = layers.Conv2D(32, (3,3), activation='relu', padding='same')(inputs)
    c1 = layers.Dropout(dropout_rate)(c1)
    c1 = layers.Conv2D(32, (3,3), activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2,2))(c1)
    
    c2 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(p1)
    c2 = layers.Dropout(dropout_rate)(c2)
    c2 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2,2))(c2)
    
    c3 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(p2)
    c3 = layers.Dropout(dropout_rate)(c3)
    c3 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D((2,2))(c3)
    
    # Bottleneck
    c4 = layers.Conv2D(256, (3,3), activation='relu', padding='same')(p3)
    c4 = layers.Dropout(dropout_rate)(c4)
    c4 = layers.Conv2D(256, (3,3), activation='relu', padding='same')(c4)
    
    # Decoder
    u1 = layers.UpSampling2D((2,2))(c4)
    u1 = layers.Concatenate()([u1, c3])
    c5 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(u1)
    c5 = layers.Dropout(dropout_rate)(c5)
    c5 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(c5)
    
    u2 = layers.UpSampling2D((2,2))(c5)
    u2 = layers.Concatenate()([u2, c2])
    c6 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(u2)
    c6 = layers.Dropout(dropout_rate)(c6)
    c6 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(c6)
    
    u3 = layers.UpSampling2D((2,2))(c6)
    u3 = layers.Concatenate()([u3, c1])
    c7 = layers.Conv2D(32, (3,3), activation='relu', padding='same')(u3)
    c7 = layers.Dropout(dropout_rate)(c7)
    c7 = layers.Conv2D(32, (3,3), activation='relu', padding='same')(c7)
    
    outputs = layers.Conv2D(1, (1,1), activation='linear')(c7)
    
    return models.Model(inputs, outputs)

In [50]:
import tensorflow.keras.backend as K
def masked_mse(y_true, y_pred):
    """Masked Mean Squared Error loss"""
    mask = y_true[..., 1:2]
    target = y_true[..., :1]
    diff = (target - y_pred) * mask
    return K.sum(K.square(diff)) / (K.sum(mask) + 1e-8)

def masked_mae(y_true, y_pred):
    """Masked Mean Absolute Error"""
    mask = y_true[..., 1:2]
    target = y_true[..., :1]
    diff = K.abs(target - y_pred) * mask
    return K.sum(diff) / (K.sum(mask) + 1e-8)

def r2_metric(y_true, y_pred):
    """R² score for masked data"""
    mask = y_true[..., 1:2]
    target = y_true[..., :1]
    
    y_pred_masked = y_pred * mask
    y_true_masked = target * mask
    
    ss_res = K.sum(K.square(y_true_masked - y_pred_masked))
    ss_tot = K.sum(K.square(y_true_masked - K.mean(y_true_masked)))
    return 1 - (ss_res / (ss_tot + K.epsilon()))

In [51]:
input_shape = X_train_enhanced.shape[1:]
model_enhanced = build_enhanced_unet_2d(input_shape)
model_enhanced.compile(
    optimizer='adam',
    loss=masked_mse,
    metrics=[masked_mae, r2_metric]
)

In [52]:
from tensorflow.keras import callbacks
early_stop_enhanced = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=30,  # Increased patience for larger dataset
    restore_best_weights=True,
    verbose=1
)

reduce_lr_enhanced = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=15,  # Increased patience
    min_lr=1e-8,
    verbose=1
)


In [53]:
print("Training with enhanced dataset...")
history_enhanced = model_enhanced.fit(
    X_train_enhanced, y_train_enhanced,
    validation_data=(X_val_enhanced, y_val_enhanced),
    epochs=150,  # Increased epochs for larger dataset
    batch_size=16,  # Increased batch size for efficiency
    callbacks=[early_stop_enhanced, reduce_lr_enhanced],
    verbose=1
)

Training with enhanced dataset...
Epoch 1/150


ResourceExhaustedError: Graph execution error:

Detected at node 'model_3/conv2d_58/Relu' defined at (most recent call last):
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
      app.start()
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
      self.io_loop.start()
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\tornado\platform\asyncio.py", line 211, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\asyncio\base_events.py", line 601, in run_forever
      self._run_once()
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\asyncio\base_events.py", line 1905, in _run_once
      handle._run()
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
      await result
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
      await super().execute_request(stream, ident, parent)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
      res = shell.run_cell(
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\IPython\core\interactiveshell.py", line 3048, in run_cell
      result = self._run_cell(
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\IPython\core\interactiveshell.py", line 3103, in _run_cell
      result = runner(coro)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\IPython\core\interactiveshell.py", line 3308, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\IPython\core\interactiveshell.py", line 3490, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\IPython\core\interactiveshell.py", line 3550, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\ABHISHEK\AppData\Local\Temp\ipykernel_15824\1064350715.py", line 2, in <module>
      history_enhanced = model_enhanced.fit(
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\layers\convolutional\base_conv.py", line 314, in call
      return self.activation(outputs)
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\activations.py", line 317, in relu
      return backend.relu(
    File "C:\Users\ABHISHEK\anaconda3\envs\ssh_prediction\lib\site-packages\keras\backend.py", line 5366, in relu
      x = tf.nn.relu(x)
Node: 'model_3/conv2d_58/Relu'
OOM when allocating tensor with shape[16,192,128,128] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model_3/conv2d_58/Relu}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_4731]

In [None]:
def simple_slice_comparison(model, ssh_slices, ssh_mean, ssh_std, n_comparisons=4):
    """
    Simple visualization of training slices vs true slices
    """
    
    fig, axes = plt.subplots(3, n_comparisons, figsize=(16, 12))
    
    for i in range(n_comparisons):
        # Create test data for this slice
        X_test, y_test = create_training_data(ssh_slices[i:i+1], missing_rate=0.15)
        
        # Make prediction
        pred = model.predict(X_test, verbose=0)[0, :, :, 0]
        
        # Extract components
        original_slice = ssh_slices[i]
        masked_input = X_test[0, :, :, 0]
        mask = X_test[0, :, :, 1]
        
        # Denormalize for visualization
        original_denorm = original_slice * ssh_std.values + ssh_mean.values
        masked_denorm = masked_input * ssh_std.values + ssh_mean.values
        pred_denorm = pred * ssh_std.values + ssh_mean.values
        
        # Set consistent color scale
        vmin, vmax = original_denorm.min(), original_denorm.max()
        
        # Row 1: TRUE SSH
        im1 = axes[0, i].imshow(original_denorm, cmap='RdYlBu_r', vmin=vmin, vmax=vmax)
        axes[0, i].set_title(f'TRUE SSH - Slice {i+1}', fontweight='bold')
        plt.colorbar(im1, ax=axes[0, i], fraction=0.046)
        
        # Row 2: MASKED INPUT (what model receives for training)
        masked_display = np.where(mask == 1, masked_denorm, np.nan)
        im2 = axes[1, i].imshow(masked_display, cmap='RdYlBu_r', vmin=vmin, vmax=vmax)
        axes[1, i].set_title(f'TRAINING INPUT - Slice {i+1}\n(15% missing)', fontweight='bold')
        plt.colorbar(im2, ax=axes[1, i], fraction=0.046)
        
        # Row 3: MODEL PREDICTION
        im3 = axes[2, i].imshow(pred_denorm, cmap='RdYlBu_r', vmin=vmin, vmax=vmax)
        axes[2, i].set_title(f'PREDICTION - Slice {i+1}', fontweight='bold')
        plt.colorbar(im3, ax=axes[2, i], fraction=0.046)
        
        # Calculate and print metrics
        observed_mask = mask == 1
        if np.sum(observed_mask) > 0:
            r2 = r2_score(original_denorm[observed_mask], pred_denorm[observed_mask])
            rmse = np.sqrt(mean_squared_error(original_denorm[observed_mask], pred_denorm[observed_mask]))
            print(f"Slice {i+1} - R²: {r2:.4f}, RMSE: {rmse:.4f}m")
    
    # Add row labels
    axes[0, 0].set_ylabel('TRUE SSH', fontsize=12, fontweight='bold')
    axes[1, 0].set_ylabel('TRAINING INPUT', fontsize=12, fontweight='bold')
    axes[2, 0].set_ylabel('PREDICTION', fontsize=12, fontweight='bold')
    
    plt.suptitle('SSH Inpainting: True vs Training Slices Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

In [None]:
# Enhanced analysis function (your existing function modified)
def plot_essential_analysis(history, model, X_val, y_val, ssh_mean, ssh_std):
    """Create focused plots for essential model analysis"""
    
    # Make predictions
    predictions = model.predict(X_val)
    
    # Select first validation sample for detailed analysis
    sample_idx = 0
    true_ssh = y_val[sample_idx, :, :, 0]
    pred_ssh = predictions[sample_idx, :, :, 0]
    mask = y_val[sample_idx, :, :, 1]
    
    # Denormalize data
    true_ssh_denorm = true_ssh * ssh_std.values + ssh_mean.values
    pred_ssh_denorm = pred_ssh * ssh_std.values + ssh_mean.values
    
    # Only use observed points (where mask = 1)
    observed_mask = mask == 1
    true_obs = true_ssh_denorm[observed_mask]
    pred_obs = pred_ssh_denorm[observed_mask]
    
    # Calculate comprehensive metrics
    r2 = r2_score(true_obs, pred_obs)
    rmse = np.sqrt(mean_squared_error(true_obs, pred_obs))
    mae = mean_absolute_error(true_obs, pred_obs)
    
    # Calculate different accuracy thresholds
    residuals = pred_obs - true_obs
    accuracy_1pct = np.mean(np.abs(residuals) < 0.01 * np.abs(true_obs)) * 100
    accuracy_5pct = np.mean(np.abs(residuals) < 0.05 * np.abs(true_obs)) * 100
    accuracy_10pct = np.mean(np.abs(residuals) < 0.10 * np.abs(true_obs)) * 100
    
    # Correlation coefficient
    correlation = np.corrcoef(true_obs, pred_obs)[0, 1]
    
    # Create subplot layout (2x2)
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Loss Function Analysis (Overfitting Check)
    ax1 = axes[0, 0]
    ax1.plot(history.history['loss'], label='Training Loss', linewidth=2.5, color='blue')
    ax1.plot(history.history['val_loss'], label='Validation Loss', linewidth=2.5, color='red')
    ax1.set_title('Loss Function - Overfitting Check', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Masked MSE Loss', fontsize=12)
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    
    # Add overfitting assessment
    final_train_loss = history.history['loss'][-1]
    final_val_loss = history.history['val_loss'][-1]
    overfitting_gap = final_val_loss - final_train_loss
    
    if overfitting_gap > 0.1:
        overfitting_status = "Potential Overfitting"
        color = 'red'
    elif overfitting_gap > 0.05:
        overfitting_status = "Slight Overfitting"
        color = 'orange'
    else:
        overfitting_status = "Good Generalization"
        color = 'green'
    
    ax1.text(0.02, 0.98, f'Status: {overfitting_status}\nGap: {overfitting_gap:.4f}', 
             transform=ax1.transAxes, fontsize=10, verticalalignment='top',
             bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))
    
    # 2. Learning Rate Schedule
    ax2 = axes[0, 1]
    if 'lr' in history.history:
        ax2.plot(history.history['lr'], linewidth=2.5, color='green')
        ax2.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Learning Rate', fontsize=12)
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)
    else:
        ax2.text(0.5, 0.5, 'Learning Rate\nHistory Not Available', 
                transform=ax2.transAxes, ha='center', va='center', 
                fontsize=12, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
        ax2.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    
    # 3. Prediction vs True SSH Scatter Plot
    ax3 = axes[1, 0]
    ax3.scatter(true_obs, pred_obs, alpha=0.6, s=15, c='blue', edgecolors='none')
    
    # Perfect prediction line
    min_val = min(true_obs.min(), pred_obs.min())
    max_val = max(true_obs.max(), pred_obs.max())
    ax3.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2.5, label='Perfect Prediction')
    
    ax3.set_xlabel('True SSH (m)', fontsize=12)
    ax3.set_ylabel('Predicted SSH (m)', fontsize=12)
    ax3.set_title('Prediction vs True SSH\n(Observed Points Only)', fontsize=14, fontweight='bold')
    ax3.legend(fontsize=11)
    ax3.grid(True, alpha=0.3)
    ax3.set_aspect('equal', adjustable='box')
    
    # 4. Performance Metrics Display
    ax4 = axes[1, 1]
    ax4.axis('off')
    
    # Final training metrics
    final_train_r2 = history.history['r2_metric'][-1] if 'r2_metric' in history.history else 'N/A'
    final_val_r2 = history.history['val_r2_metric'][-1] if 'val_r2_metric' in history.history else 'N/A'
    
    metrics_text = f"""
    **PERFORMANCE METRICS**
    
    ═══════════════════════════
    
    **R² Score:** {r2:.4f}
    **RMSE:** {rmse:.4f} m
    **MAE:** {mae:.4f} m
    **Correlation:** {correlation:.4f}
    
    **Accuracy Metrics:**
    • Within ±1%: {accuracy_1pct:.1f}%
    • Within ±5%: {accuracy_5pct:.1f}%
    • Within ±10%: {accuracy_10pct:.1f}%
    
    **Training Results:**
    • Final Train Loss: {final_train_loss:.4f}
    • Final Val Loss: {final_val_loss:.4f}
    • Final Train R²: {final_train_r2:.4f}
    • Final Val R²: {final_val_r2:.4f}
    
    **Data Info:**
    • Observed Points: {len(true_obs):,}
    • SSH Mean: {ssh_mean.values:.4f} m
    • SSH Std: {ssh_std.values:.4f} m
    """
    
    # Color-code the metrics based on performance
    if r2 > 0.9:
        bg_color = "lightgreen"
    elif r2 > 0.8:
        bg_color = "lightyellow"
    else:
        bg_color = "lightcoral"
    
    ax4.text(0.05, 0.95, metrics_text, transform=ax4.transAxes, fontsize=11,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle="round,pad=0.4", facecolor=bg_color, alpha=0.8))
    
    ax4.set_title('Model Performance Summary', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    return r2, rmse, mae, accuracy_5pct, correlation

In [None]:
# Apply analysis to enhanced model
r2_enh, rmse_enh, mae_enh, accuracy_enh, correlation_enh = plot_essential_analysis(
    history_enhanced, model_enhanced, X_val_enhanced, y_val_enhanced, ssh_mean, ssh_std
)

# Use simple slice comparison with enhanced training data
simple_slice_comparison(model_enhanced, ssh_train_enhanced[:4], ssh_mean, ssh_std)

# Print final summary
print("="*60)
print("                    ENHANCED MODEL SUMMARY")
print("="*60)
print(f"Training Data Utilization: {total_slices_used}/{len(ssh_norm.t)} ({total_slices_used/len(ssh_norm.t)*100:.1f}%)")
print(f"Training slices: {len(ssh_train_enhanced)}")
print(f"Validation slices: {len(ssh_val_enhanced)}")
print(f"R² Score:           {r2_enh:.4f}")
print(f"RMSE:              {rmse_enh:.4f} m")
print(f"MAE:               {mae_enh:.4f} m")
print(f"Accuracy (±5%):    {accuracy_enh:.2f}%")
print(f"Correlation:       {correlation_enh:.4f}")
print("="*60)