In [None]:
# !pip install pykrige
from pykrige.ok import OrdinaryKriging
from pykrige.uk import UniversalKriging
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import (
    Input, Dense, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, Flatten, Reshape, Dropout, LSTM,
    RepeatVector
)
from keras.models import Model
from keras import backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from tensorflow.keras.models import load_model
from tqdm import tqdm
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, concatenate, Input, ZeroPadding2D, Cropping2D
from tensorflow.keras.models import Model
import itertools
from math import factorial, comb
import random
import os

# ===============================
# 1. Paths and run settings used by the data-loading / evaluation cells below
# ===============================
# Input dataset directory.
out_dir = './dataset'
# Experiment selector; presumably forwarded as `method` to the data loaders
# (load_my_data_opt documents 0 or 1) — confirm at the call site.
method = 0
# Output locations for trained models, metric files and saved predictions.
model_dir, metric_dir, pred_dir = (
    './model_standard/',
    './metrics_mean/',
    './pred_mean/',
)

## Utilities

In [None]:
def visualize_results_jet_contour(x_data, y_data, pred_data, mask_data, sample_index=0, num_samples=8,
                      arrow_scale=15, arrow_width=0.008, arrow_headwidth=4, arrow_headlength=5,
                      skip=1, hspace=0.1, wspace=0.2):
    """
    Plot filled velocity-magnitude contours overlaid with black direction arrows.

    The figure has 3 rows (interpolated input, prediction, ground truth) and one column per
    sample. Within a column all three panels share one colour normalisation AND one set of
    contour levels, so magnitudes are directly comparable; each panel's colorbar shows that
    shared jet scale (it is built from the contour plot, not from the uncoloured arrows).

    Parameters:
      x_data: Input fields; channels 0/1 are read as u/v (e.g. (N, 15, 15, 3) — not enforced).
      y_data: Ground-truth fields; channels 0/1 are u/v.
      pred_data: Model predictions; channels 0/1 are u/v.
      mask_data: Mask array; cells where channel 2 == 1 are drawn as sensor dots on row 1.
      sample_index: Index of the first sample to plot.
      num_samples: Maximum number of consecutive samples to plot (default: 8).
      arrow_scale: quiver ``scale``; a smaller value gives longer arrows (default: 15).
      arrow_width: Width of arrow shaft (default: 0.008).
      arrow_headwidth: Width of arrow head as multiple of shaft width (default: 4).
      arrow_headlength: Length of arrow head as multiple of shaft width (default: 5).
      skip: Draw an arrow at every ``skip``-th grid point in each direction (default: 1).
      hspace, wspace: Initial gridspec spacing; note that ``plt.tight_layout()`` at the end
        recomputes subplot spacing and may override these.

    Raises:
      ValueError: if ``sample_index`` leaves no samples to plot.
    """
    # Validate before touching matplotlib so a bad index fails with a clear message.
    max_samples = min(num_samples, len(x_data) - sample_index)
    if max_samples <= 0:
        raise ValueError(
            f"No samples to plot: sample_index={sample_index}, num_samples={num_samples}, "
            f"len(x_data)={len(x_data)}"
        )

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import Normalize

    n_levels = 15  # number of filled contour bands per panel
    row_labels = ('Interpolated Input', 'Prediction', 'Ground Truth')

    # squeeze=False keeps `axs` 2-D even for a single column.
    fig, axs = plt.subplots(3, max_samples, figsize=(max_samples * 3, 12),
                            gridspec_kw={'hspace': hspace, 'wspace': wspace}, squeeze=False)

    # Row/column index grids: y_grid varies along axis 1 of the field, x_grid along axis 2.
    y_grid, x_grid = np.mgrid[0:x_data.shape[1], 0:x_data.shape[2]]

    for col in range(max_samples):
        idx = sample_index + col

        # (u, v) pairs for the three rows, top to bottom.
        fields = [(src[idx, :, :, 0], src[idx, :, :, 1]) for src in (x_data, pred_data, y_data)]
        magnitudes = [np.sqrt(u ** 2 + v ** 2) for u, v in fields]

        # One colour scale per column; fall back to 1.0 for all-zero/NaN fields so the
        # normalisation and the level list below are never degenerate.
        vmax = max(float(np.max(m)) for m in magnitudes)
        if not vmax > 0:
            vmax = 1.0
        norm = Normalize(vmin=0, vmax=vmax)
        levels = np.linspace(0.0, vmax, n_levels + 1)

        for row, ((u, v), mag) in enumerate(zip(fields, magnitudes)):
            ax = axs[row, col]

            contour = ax.contourf(x_grid, y_grid, mag, levels=levels, cmap='jet',
                                  norm=norm, alpha=0.7)
            ax.quiver(x_grid[::skip, ::skip], y_grid[::skip, ::skip],
                      u[::skip, ::skip], v[::skip, ::skip],
                      scale=arrow_scale,
                      width=arrow_width,
                      headwidth=arrow_headwidth,
                      headlength=arrow_headlength,
                      color='black')

            if row == 0:
                ax.set_title(f'Sample {idx}')
                sensor_rows, sensor_cols = np.where(mask_data[idx, :, :, 2] == 1)
                ax.scatter(sensor_cols, sensor_rows, c='black', s=10, marker='o')
            if col == 0:
                ax.set_ylabel(row_labels[row], rotation=90, labelpad=15, va='center')

            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            # Colorbar from the contour set: it carries the jet colormap and shared norm.
            fig.colorbar(contour, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

In [None]:
def visualize_results_jet(x_data, y_data, pred_data, mask_data, sample_index=0, num_samples=8,
                      arrow_scale=15, arrow_width=0.008, arrow_headwidth=4, arrow_headlength=5,
                      skip=1, hspace=0.1, wspace=0.2):
    """
    Plot velocity fields as arrows coloured by magnitude (jet colormap).

    The figure has 3 rows (interpolated input, prediction, ground truth) and one column per
    sample. Within a column all three panels share one colour normalisation, so arrow colours
    are directly comparable; each panel gets its own colorbar showing that scale.

    Parameters:
      x_data: Input fields; channels 0/1 are read as u/v (e.g. (N, 15, 15, 3) — not enforced).
      y_data: Ground-truth fields; channels 0/1 are u/v.
      pred_data: Model predictions; channels 0/1 are u/v.
      mask_data: Mask array; cells where channel 2 == 1 are drawn as sensor dots on row 1.
      sample_index: Index of the first sample to plot.
      num_samples: Maximum number of consecutive samples to plot (default: 8).
      arrow_scale: quiver ``scale``; a smaller value gives longer arrows (default: 15).
      arrow_width: Width of arrow shaft (default: 0.008).
      arrow_headwidth: Width of arrow head as multiple of shaft width (default: 4).
      arrow_headlength: Length of arrow head as multiple of shaft width (default: 5).
      skip: Draw an arrow at every ``skip``-th grid point in each direction (default: 1).
      hspace, wspace: Initial gridspec spacing; note that ``plt.tight_layout()`` at the end
        recomputes subplot spacing and may override these.

    Raises:
      ValueError: if ``sample_index`` leaves no samples to plot.
    """
    # Validate before touching matplotlib so a bad index fails with a clear message.
    max_samples = min(num_samples, len(x_data) - sample_index)
    if max_samples <= 0:
        raise ValueError(
            f"No samples to plot: sample_index={sample_index}, num_samples={num_samples}, "
            f"len(x_data)={len(x_data)}"
        )

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import Normalize

    row_labels = ('Interpolated Input', 'Prediction', 'Ground Truth')

    # squeeze=False keeps `axs` 2-D even for a single column.
    fig, axs = plt.subplots(3, max_samples, figsize=(max_samples * 3, 12),
                            gridspec_kw={'hspace': hspace, 'wspace': wspace}, squeeze=False)

    # Row/column index grids: y_grid varies along axis 1 of the field, x_grid along axis 2.
    y_grid, x_grid = np.mgrid[0:x_data.shape[1], 0:x_data.shape[2]]

    for col in range(max_samples):
        idx = sample_index + col

        # (u, v) pairs for the three rows, top to bottom.
        fields = [(src[idx, :, :, 0], src[idx, :, :, 1]) for src in (x_data, pred_data, y_data)]
        magnitudes = [np.sqrt(u ** 2 + v ** 2) for u, v in fields]

        # One colour scale per column; fall back to 1.0 for all-zero/NaN fields so the
        # normalisation is never degenerate.
        vmax = max(float(np.max(m)) for m in magnitudes)
        if not vmax > 0:
            vmax = 1.0
        norm = Normalize(vmin=0, vmax=vmax)

        for row, ((u, v), mag) in enumerate(zip(fields, magnitudes)):
            ax = axs[row, col]

            arrows = ax.quiver(x_grid[::skip, ::skip], y_grid[::skip, ::skip],
                               u[::skip, ::skip], v[::skip, ::skip],
                               mag[::skip, ::skip],
                               cmap='jet', norm=norm,
                               scale=arrow_scale,
                               width=arrow_width,
                               headwidth=arrow_headwidth,
                               headlength=arrow_headlength)

            if row == 0:
                ax.set_title(f'Sample {idx}')
                sensor_rows, sensor_cols = np.where(mask_data[idx, :, :, 2] == 1)
                ax.scatter(sensor_cols, sensor_rows, c='black', s=10, marker='o')
            if col == 0:
                ax.set_ylabel(row_labels[row], rotation=90, labelpad=15, va='center')

            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            fig.colorbar(arrows, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

In [None]:
from keras.saving import register_keras_serializable
@register_keras_serializable()
def weighted_vector_loss(y_true, y_pred):
    """
    Robust loss for 2-component (u, v) vector fields that emphasises wind direction and
    high-wind-speed regions while guarding against NaN/Inf.

    total = 1.0 * mean(magnitude_weights * (|true| - |pred|)^2)
          + 2.0 * mean(direction_weight * direction_loss)

    Args:
      y_true: Ground-truth tensor; channels 0 and 1 of the last axis are u and v.
      y_pred: Predicted tensor with the same layout.
      NOTE(review): the per-sample min/max over axes [1, 2] assumes a 4-D
      (batch, H, W, C) input — confirm against the model output shape.

    Returns:
      Scalar loss tensor.

    Raises:
      tf.errors.InvalidArgumentError: raised by tf.debugging.check_numerics if NaN/Inf
      appears in the magnitude weights or in either reduced loss term.
    """
    import tensorflow as tf

    # Numerical-stability constant (inside sqrt, as a divisor floor, and for clipping).
    epsilon = 1e-6

    # Extract U and V components
    true_u = y_true[..., 0]
    true_v = y_true[..., 1]
    pred_u = y_pred[..., 0]
    pred_v = y_pred[..., 1]

    # Wind-speed magnitude; epsilon inside the sqrt keeps the gradient finite at zero
    # and makes the minimum magnitude sqrt(1e-6) = 1e-3.
    true_magnitude = tf.sqrt(tf.square(true_u) + tf.square(true_v) + epsilon)
    pred_magnitude = tf.sqrt(tf.square(pred_u) + tf.square(pred_v) + epsilon)

    # 1. Squared error on wind-speed magnitude
    magnitude_loss = tf.square(true_magnitude - pred_magnitude)

    # 2. Per-sample min-max normalised true magnitude in [0, 1]
    max_magnitude = tf.reduce_max(true_magnitude, axis=[1, 2], keepdims=True)
    min_magnitude = tf.reduce_min(true_magnitude, axis=[1, 2], keepdims=True)
    magnitude_range = tf.maximum(max_magnitude - min_magnitude, epsilon)
    normalized_magnitude = tf.clip_by_value(
        (true_magnitude - min_magnitude) / magnitude_range,
        0.0,
        1.0
    )

    # Bounded weight 1 + 5*x^2 in [1, 6]: favours high-speed regions without the
    # blow-up risk of an exponential weight.
    magnitude_weights = 1.0 + 5.0 * tf.square(normalized_magnitude)
    magnitude_weights = tf.debugging.check_numerics(
        magnitude_weights,
        "NaN/Inf found in magnitude_weights"
    )

    # 3. Direction: cosine similarity between unit vectors.
    # (Redundant with the epsilon inside the sqrt above, kept as an explicit divisor floor.)
    true_norm = tf.maximum(true_magnitude, epsilon)
    pred_norm = tf.maximum(pred_magnitude, epsilon)

    true_u_norm = true_u / true_norm
    true_v_norm = true_v / true_norm
    pred_u_norm = pred_u / pred_norm
    pred_v_norm = pred_v / pred_norm

    cos_similarity = true_u_norm * pred_u_norm + true_v_norm * pred_v_norm
    cos_similarity_safe = tf.clip_by_value(cos_similarity, -1.0 + epsilon, 1.0 - epsilon)

    # 4. Direction loss with an extra penalty beyond 90 degrees:
    #    cos >= 0 -> 1 - cos       (range [0, 1])
    #    cos <  0 -> 2 - 2 * cos   (range (2, 4])
    # NOTE(review): the branches disagree at cos = 0 (1.0 vs 2.0), so the loss value
    # jumps at exactly 90 degrees; kept as-is to preserve training behaviour.
    direction_loss = tf.where(
        cos_similarity_safe < 0,
        2.0 - 2.0 * cos_similarity_safe,
        1.0 - cos_similarity_safe
    )

    # 5. Direction matters more at higher wind speed: weight in [0.3, 1.0]
    direction_weight = 0.3 + 0.7 * normalized_magnitude

    # 6. Weighted terms, reduced separately and combined additively (not multiplied)
    #    to avoid compounding the two weights.
    weighted_magnitude_loss = magnitude_weights * magnitude_loss
    weighted_direction_loss = direction_weight * direction_loss

    lambda_magnitude = 1.0
    lambda_direction = 2.0  # direction term weighted twice as heavily

    final_magnitude_loss = tf.reduce_mean(weighted_magnitude_loss)
    final_direction_loss = tf.reduce_mean(weighted_direction_loss)

    final_magnitude_loss = tf.debugging.check_numerics(
        final_magnitude_loss,
        "NaN/Inf found in final_magnitude_loss"
    )
    final_direction_loss = tf.debugging.check_numerics(
        final_direction_loss,
        "NaN/Inf found in final_direction_loss"
    )

    total_loss = lambda_magnitude * final_magnitude_loss + lambda_direction * final_direction_loss
    return total_loss

In [None]:
def apply_kriging_reconstruction(x_test, y_test, method='ordinary', random_state=None):
    """
    Reconstruct (u, v) flow fields from sparse sensors with Kriging, using variogram
    models fitted to experimental semivariograms.

    Pipeline:
      1. On up to 100 random samples, compute the experimental semivariogram of the sensor
         values for each component and fit gaussian / spherical / exponential models; keep
         the fit with the lowest MSE per component.
      2. Hold-out check: on up to 30 samples NOT used in step 1, krige from ~70% of the
         sensors and report the RMSE on the remaining ~30%. Skipped when every sample was
         already used in step 1 (e.g. len(x_test) <= 100).
      3. Krige every sample onto the full grid with the selected models.

    The theoretical models are written in PyKrige's own parameterisation (partial sill,
    effective range, nugget), so the fitted values are passed to PyKrige unchanged via
    ``variogram_parameters={'psill': ..., 'range': ..., 'nugget': ...}``.

    Parameters:
      x_test: array (N, H, W, C); cells where channel 2 == 1 are sensors. Channels 0/1 are
              copied to the output when Kriging cannot run for a sample/component.
      y_test: array (N, H, W, >=2); channels 0/1 (u, v) are sampled at the sensor cells.
      method: 'ordinary' (case-insensitive) for OrdinaryKriging, anything else for
              UniversalKriging.
      random_state: optional seed or numpy Generator for the random sample/point selection.
              None (default) uses the global ``np.random`` state.

    Returns:
      Array shaped like y_test (``np.zeros_like``); channels 0 and 1 hold the reconstructed
      u and v fields.
    """
    import numpy as np
    from tqdm import tqdm
    from pykrige.ok import OrdinaryKriging
    from pykrige.uk import UniversalKriging
    from scipy.optimize import curve_fit
    from scipy.spatial.distance import pdist

    rng = np.random if random_state is None else np.random.default_rng(random_state)
    use_ordinary = str(method).lower() == 'ordinary'

    print(f"Reconstructing flow field using {method} Kriging method...")

    # Grid dimensions (pixel coordinates)
    h, w = x_test.shape[1:3]
    grid_x = np.arange(0, w, 1.0)
    grid_y = np.arange(0, h, 1.0)
    n_total = len(x_test)

    variogram_models = ['gaussian', 'spherical', 'exponential']

    # --- Theoretical variograms, in PyKrige's parameterisation -------------------------
    def gaussian_variogram(d, range_param, psill, nugget):
        # PyKrige's Gaussian uses an effective range: length scale = range * 4/7.
        return nugget + psill * (1.0 - np.exp(-(d ** 2) / (range_param * 4.0 / 7.0) ** 2))

    def spherical_variogram(d, range_param, psill, nugget):
        d = np.asarray(d, dtype=float)
        ratio = d / range_param
        return np.where(d <= range_param,
                        nugget + psill * (1.5 * ratio - 0.5 * ratio ** 3),
                        nugget + psill)

    def exponential_variogram(d, range_param, psill, nugget):
        # PyKrige's exponential uses an effective range: length scale = range / 3.
        return nugget + psill * (1.0 - np.exp(-d / (range_param / 3.0)))

    model_funcs = {
        'gaussian': gaussian_variogram,
        'spherical': spherical_variogram,
        'exponential': exponential_variogram,
    }

    def calculate_experimental_variogram(x, y, values, n_lags=20, max_dist=None):
        """Binned semivariance gamma(h) = mean((z_i - z_j)^2) / 2 over point pairs i < j."""
        coords = np.column_stack((x, y)).astype(float)
        pair_dist = pdist(coords, 'euclidean')  # condensed: one entry per pair i < j
        pair_sq_diff = pdist(np.asarray(values, dtype=float).reshape(-1, 1), 'sqeuclidean')

        if max_dist is None:
            max_dist = np.max(pair_dist) / 2

        lag_width = max_dist / n_lags
        lags = np.arange(lag_width / 2, max_dist, lag_width)  # bin centres

        in_range = pair_dist <= max_dist
        bin_idx = (pair_dist[in_range] / lag_width).astype(int)
        keep = bin_idx < len(lags)
        bin_idx = bin_idx[keep]
        sq_diff = pair_sq_diff[in_range][keep]

        counts = np.bincount(bin_idx, minlength=len(lags))[:len(lags)]
        sums = np.bincount(bin_idx, weights=sq_diff, minlength=len(lags))[:len(lags)]
        valid = counts > 0
        return lags[valid], sums[valid] / (2 * counts[valid])

    def fit_variogram_model(lags, gamma, model_type):
        """Fit one model; return ((range, psill, nugget), mse), or defaults with inf mse on failure."""
        model_func = model_funcs.get(model_type, exponential_variogram)
        try:
            lower = [0.01, 0.01, 0.0]
            upper = [np.max(lags) * 2, np.max(gamma) * 2, np.max(gamma) * 0.5]
            # Clip the initial guess into the box so a low-variance field does not make x0 infeasible.
            p0 = np.clip([np.mean(lags), np.max(gamma), 0.0], lower, upper)
            params, _ = curve_fit(model_func, lags, gamma, p0=p0,
                                  bounds=(lower, upper), maxfev=1000)
            mse = np.mean((gamma - model_func(lags, *params)) ** 2)
            return params, mse
        except Exception as e:
            print(f"Failed to fit {model_type} model: {e}")
            return [3.0, 1.0, 0.0], float('inf')

    def build_kriger(xs, ys, vals, params):
        """Create the PyKrige interpolator for one component with fixed variogram parameters."""
        kriging_cls = OrdinaryKriging if use_ordinary else UniversalKriging
        return kriging_cls(
            xs, ys, vals,
            variogram_model=params['model'],
            variogram_parameters={
                'psill': params['psill'],  # partial sill: the fitted models are nugget + psill * f(d)
                'range': params['range'],
                'nugget': params['nugget'],
            },
            verbose=False,
            enable_plotting=False,
        )

    # --- 1. Variogram fitting ----------------------------------------------------------
    sample_size = min(100, n_total)
    sample_indices = rng.choice(n_total, sample_size, replace=False)

    print("Finding optimal parameters using experimental variogram fitting method...")

    default_params = {'model': 'gaussian', 'range': 3.0, 'psill': 1.0, 'nugget': 0.0}
    best_params = [dict(default_params), dict(default_params)]  # index 0 = U, 1 = V
    best_mse = [float('inf'), float('inf')]

    for i in tqdm(sample_indices, desc="Variogram fitting progress"):
        y_coords, x_coords = np.where(x_test[i, :, :, 2] == 1)

        if len(x_coords) < 10:  # too few points for a meaningful semivariogram
            continue

        for c in range(2):
            values = y_test[i, y_coords, x_coords, c]
            try:
                lags, gamma = calculate_experimental_variogram(x_coords, y_coords, values)
                if len(lags) < 3:  # too few populated lag bins to fit 3 parameters
                    continue

                for model_type in variogram_models:
                    params, mse = fit_variogram_model(lags, gamma, model_type)
                    if mse < best_mse[c]:
                        best_mse[c] = mse
                        best_params[c] = {'model': model_type,
                                          'range': float(params[0]),
                                          'psill': float(params[1]),
                                          'nugget': float(params[2])}
            except Exception as e:
                print(f"Variogram calculation failed for sample {i}, component {c}: {e}")
                continue

    print(f"Optimal variogram parameters - U component: {best_params[0]}")
    print(f"Optimal variogram parameters - V component: {best_params[1]}")

    # --- 2. Hold-out validation on samples not used for fitting -------------------------
    print("Validating optimal parameters using cross-validation...")

    used_for_fitting = set(np.asarray(sample_indices).tolist())
    holdout_pool = [i for i in range(n_total) if i not in used_for_fitting]
    validation_size = min(30, len(holdout_pool))
    if validation_size == 0:
        print("No held-out samples left for cross-validation; skipping.")
        validation_indices = []
    else:
        validation_indices = rng.choice(holdout_pool, validation_size, replace=False)

    cv_errors = ([], [])  # per-component RMSEs of successful kriging runs only

    for i in tqdm(validation_indices, desc="Cross-validation progress"):
        y_coords, x_coords = np.where(x_test[i, :, :, 2] == 1)
        if len(x_coords) < 5:
            continue

        # Hold out ~30% of the sensors (at least one) as test points.
        n_test = max(int(len(x_coords) * 0.3), 1)
        test_idx = rng.choice(len(x_coords), n_test, replace=False)
        is_train = np.ones(len(x_coords), dtype=bool)
        is_train[test_idx] = False
        train_idx = np.flatnonzero(is_train)

        for c in range(2):
            values = y_test[i, y_coords, x_coords, c]
            try:
                krig = build_kriger(x_coords[train_idx], y_coords[train_idx],
                                    values[train_idx], best_params[c])
                pred_values, _ = krig.execute('points', x_coords[test_idx], y_coords[test_idx])
                cv_errors[c].append(np.sqrt(np.mean((pred_values - values[test_idx]) ** 2)))
            except Exception as e:
                print(f"Cross-validation Kriging failed for sample {i}, component {c}: {e}")

    if cv_errors[0] or cv_errors[1]:
        rmse_u = np.mean(cv_errors[0]) if cv_errors[0] else float('nan')
        rmse_v = np.mean(cv_errors[1]) if cv_errors[1] else float('nan')
        print(f"Cross-validation RMSE - U: {rmse_u:.6f}, V: {rmse_v:.6f}")

    # --- 3. Full-grid reconstruction ----------------------------------------------------
    print("Predicting using optimal parameters...")
    pred_kriging = np.zeros_like(y_test)

    for i in tqdm(range(n_total), desc="Kriging reconstruction progress"):
        y_coords, x_coords = np.where(x_test[i, :, :, 2] == 1)

        for c in range(2):
            if len(x_coords) < 3:
                # Too few sensors to krige: fall back to the input channel.
                pred_kriging[i, :, :, c] = x_test[i, :, :, c]
                continue

            values = y_test[i, y_coords, x_coords, c]
            try:
                krig = build_kriger(x_coords, y_coords, values, best_params[c])
                z, _ = krig.execute('grid', grid_x, grid_y)  # shape (len(grid_y), len(grid_x))
                pred_kriging[i, :, :, c] = z
            except Exception as e:
                print(f"Kriging interpolation failed for sample {i}, component {c}: {e}")
                pred_kriging[i, :, :, c] = x_test[i, :, :, c]

    return pred_kriging

In [None]:
def load_my_data_opt(sensor_num, method, data_dir='./optimal_dataset'):
    """
    Load train/validation/test splits generated from optimal sensor positions.

    Each array is memory-mapped from
    ``{data_dir}/{case}_method{method}_sensor{sensor_num}_{kind}.npy`` where ``kind``
    is ``X``, ``y`` or ``X_perturbed``.

    Parameters:
        sensor_num: Number of sensors the files were generated with
        method: Split strategy
            0 -> train on 0deg_1,2,3; test on 22deg_1,2 + 45deg_1,2,3
            1 -> train on 0deg_1 + 22deg_1 + 45deg_1; test on the remaining cases
        data_dir: Directory containing the .npy files

    Returns:
        (x_train, y_train, x_val, y_val, x_test, x_perturbed_test, y_test,
         min_vals, max_vals, y_test_raw, x_test_raw) -- the first seven are
        normalized per sample by ``scale_data``; min_vals/max_vals are the
        bounds used for y_test; the last two are the unscaled test arrays.

    Raises:
        FileNotFoundError: a required file is missing
        ValueError: ``method`` is neither 0 nor 1
    """

    def read_case(case_name, data_type):
        """Memory-map one generated array for this method/sensor count, failing if absent."""
        path = os.path.join(data_dir, f"{case_name}_method{method}_sensor{sensor_num}_{data_type}.npy")
        if not os.path.exists(path):
            raise FileNotFoundError(f"File does not exist: {path}")
        return np.load(path, mmap_mode='r')

    print(f"Loading optimal data: Method {method}, Sensor_num {sensor_num}")

    try:
        if method == 0:
            train_cases = ('0deg_1', '0deg_2', '0deg_3')
            test_cases = ('22deg_1', '22deg_2', '45deg_1', '45deg_2', '45deg_3')
            split_notes = ("Method 0 - Training data: 0deg_1,2,3",
                           "Method 0 - Test data: 22deg_1,2 + 45deg_1,2,3")
        elif method == 1:
            train_cases = ('0deg_1', '22deg_1', '45deg_1')
            test_cases = ('0deg_2', '0deg_3', '22deg_2', '45deg_2', '45deg_3')
            split_notes = ("Method 1 - Training data: 0deg_1 + 22deg_1 + 45deg_1",
                           "Method 1 - Test data: 0deg_2,3 + 22deg_2 + 45deg_2,3")
        else:
            raise ValueError(f"Unsupported method value: {method}, only 0 or 1 supported")

        # Load order: (X, y) per training case, (X, y) per test case, then the
        # perturbed inputs of every test case (same case order as the clean test set).
        train_x, train_y = [], []
        for case in train_cases:
            train_x.append(read_case(case, 'X'))
            train_y.append(read_case(case, 'y'))

        test_x, test_y = [], []
        for case in test_cases:
            test_x.append(read_case(case, 'X'))
            test_y.append(read_case(case, 'y'))

        test_x_perturbed = [read_case(case, 'X_perturbed') for case in test_cases]

        x_data = np.concatenate(train_x, axis=0)
        y_data = np.concatenate(train_y, axis=0)
        x_test = np.concatenate(test_x, axis=0)
        x_perturbed_test = np.concatenate(test_x_perturbed, axis=0)
        y_test = np.concatenate(test_y, axis=0)

        for note in split_notes:
            print(note)

        print(f"Training data shape: X={x_data.shape}, y={y_data.shape}")
        print(f"Test data shape: X={x_test.shape}, y={y_test.shape}")
        print(f"Perturbed test data shape: X_perturbed={x_perturbed_test.shape}")

        # Fixed seed keeps the train/validation split reproducible across runs.
        x_train, x_val, y_train, y_val = train_test_split(x_data, y_data, test_size=0.2, random_state=42)

        x_train_scaled = scale_data(x_train)[0]
        x_val_scaled = scale_data(x_val)[0]
        x_test_scaled = scale_data(x_test)[0]
        x_perturbed_test_scaled = scale_data(x_perturbed_test)[0]
        y_train_scaled = scale_data(y_train)[0]
        y_val_scaled = scale_data(y_val)[0]
        # Only the y_test bounds are kept: they are what predictions get restored with.
        y_test_scaled, min_vals, max_vals = scale_data(y_test)

        print("✅ Data loading completed!")

        return (x_train_scaled, y_train_scaled, x_val_scaled, y_val_scaled,
                x_test_scaled, x_perturbed_test_scaled, y_test_scaled,
                min_vals, max_vals, y_test, x_test)

    except FileNotFoundError as e:
        print(f"❌ File loading failed: {e}")
        print("Please ensure batch_generate_data.py has been run to generate required data files")
        raise
    except Exception as e:
        print(f"❌ Error occurred during data loading: {e}")
        raise

In [None]:
def load_my_data(sensor_num, method):
    """
    Load the ``out_dir`` dataset and build scaled train/validation/test splits.

    Files follow ``{out_dir}/{prefix}{x|y}{k}_data_{sensor_num}.npy`` and
    ``{out_dir}/{prefix}x{k}_perturbed_data_{sensor_num}.npy`` where ``prefix`` is
    '' for the 0deg cases, '22deg_' or '45deg_'. Only the arrays required by the
    selected split are memory-mapped.

    Parameters:
        sensor_num: Number of sensors; selects the ``*_{sensor_num}.npy`` files
        method: Split strategy
            0 -> train/val on 0deg cases 1-3; test on 22deg cases 1-2 + 45deg cases 1-3
            1 -> train/val on case 1 of 0deg/22deg/45deg; test on the remaining cases

    Returns:
        (x_train, y_train, x_val, y_val, x_test, x_perturbed_test, y_test,
         min_vals, max_vals, y_test_raw, x_test_raw) -- the first seven are
        normalized per sample by ``scale_data``; min_vals/max_vals are the bounds
        used for y_test; the last two are the unscaled test arrays.

    Raises:
        ValueError: ``method`` is neither 0 nor 1 (checked before any file is read).
    """
    if method not in (0, 1):
        raise ValueError(f"Unsupported method value: {method}, only 0 or 1 supported")

    def load(prefix, name):
        # Memory-map one array; prefix is '' (0deg), '22deg_' or '45deg_'.
        return np.load(f'{out_dir}/{prefix}{name}_{sensor_num}.npy', mmap_mode='r')

    # (prefix, case number) pairs, in the order they are concatenated.
    if method == 0:
        train_cases = [('', 1), ('', 2), ('', 3)]
        test_cases = [('22deg_', 1), ('22deg_', 2), ('45deg_', 1), ('45deg_', 2), ('45deg_', 3)]
    else:
        train_cases = [('', 1), ('22deg_', 1), ('45deg_', 1)]
        test_cases = [('', 2), ('', 3), ('22deg_', 2), ('45deg_', 2), ('45deg_', 3)]

    x_data = np.concatenate([load(p, f'x{k}_data') for p, k in train_cases], axis=0)
    y_data = np.concatenate([load(p, f'y{k}_data') for p, k in train_cases], axis=0)
    x_test = np.concatenate([load(p, f'x{k}_data') for p, k in test_cases], axis=0)
    x_perturbed_test = np.concatenate([load(p, f'x{k}_perturbed_data') for p, k in test_cases], axis=0)
    y_test = np.concatenate([load(p, f'y{k}_data') for p, k in test_cases], axis=0)

    # Fixed seed keeps the train/validation split reproducible across runs.
    x_train, x_val, y_train, y_val = train_test_split(x_data, y_data, test_size=0.2, random_state=42)

    # Per the notebook's shape notes, x arrays are (N, 15, 15, 3) and y arrays
    # (N, 15, 15, 2) -- TODO confirm against the generated files.
    x_train_input, _, _ = scale_data(x_train)
    x_val_input, _, _ = scale_data(x_val)
    x_test_input, _, _ = scale_data(x_test)
    x_perturbed_test_input, _, _ = scale_data(x_perturbed_test)
    y_train_output, _, _ = scale_data(y_train)
    y_val_output, _, _ = scale_data(y_val)
    # Only the y_test bounds are returned: they are what predictions get restored with.
    y_test_output, min_vals, max_vals = scale_data(y_test)

    return (x_train_input, y_train_output, x_val_input, y_val_output, x_test_input,
            x_perturbed_test_input, y_test_output, min_vals, max_vals, y_test, x_test)

In [None]:
def load_perturbed_data(sensor_num):
    """
    Load the 0deg-only dataset: cases 1-2 for train/validation, case 3 for testing.

    Parameters:
        sensor_num: Number of sensors; selects the ``{out_dir}/*_{sensor_num}.npy`` files

    Returns:
        (x_train, y_train, x_val, y_val, x_test, x_perturbed_test, y_test,
         min_vals, max_vals, y_test_raw, x_perturbed_test_raw) -- the first seven
        are normalized per sample by ``scale_data``; min_vals/max_vals are the
        bounds used for y_test. Unlike ``load_my_data``, the last element is the
        unscaled *perturbed* test input, not the clean one.
    """
    def read(name):
        return np.load(f'{out_dir}/{name}_{sensor_num}.npy', mmap_mode='r')

    x1, y1 = read('x1_data'), read('y1_data')
    x2, y2 = read('x2_data'), read('y2_data')
    x3, y3 = read('x3_data'), read('y3_data')
    x3_perturbed = read('x3_perturbed_data')

    pool_x = np.concatenate([x1, x2], axis=0)
    pool_y = np.concatenate([y1, y2], axis=0)
    # Fixed seed keeps the train/validation split reproducible across runs.
    x_train, x_val, y_train, y_val = train_test_split(pool_x, pool_y, test_size=0.2, random_state=42)

    # Scale in a fixed order; only the bounds of the clean test targets are kept.
    (x_train_s, x_val_s, x_test_s, x_pert_s, y_train_s, y_val_s) = [
        scale_data(arr)[0] for arr in (x_train, x_val, x3, x3_perturbed, y_train, y_val)
    ]
    y_test_s, min_vals, max_vals = scale_data(y3)

    return (x_train_s, y_train_s, x_val_s, y_val_s, x_test_s, x_pert_s, y_test_s,
            min_vals, max_vals, y3, x3_perturbed)

In [None]:
def restore_original_scale(scaled_data, min_vals, max_vals, epsilon=1e-8):
    """
    Undo ``scale_data``: map min-max normalized values back to physical units.

    Parameters:
        scaled_data: Array whose last axis holds channels. With 3 or more channels
            only channels 0-1 (U, V) are rescaled and the remaining channels (e.g.
            the sensor mask) are copied unchanged; otherwise every channel is rescaled.
        min_vals, max_vals: Per-sample bounds as returned by ``scale_data``
            (shape (N, 1, 1, 1)), broadcast against ``scaled_data``.
        epsilon: Must equal the epsilon ``scale_data`` adds to its denominator
            (1e-8) so that this function is its exact inverse. Without it the
            restored field is shrunk by a factor range / (range + epsilon), which
            is significant for small-magnitude samples.

    Returns:
        A new array; ``scaled_data`` is left untouched.
    """
    span = max_vals - min_vals + epsilon
    restored_data = np.copy(scaled_data)

    if scaled_data.shape[-1] >= 3:
        # Rescale U, V only; the copied mask channel stays as-is.
        restored_data[..., :2] = scaled_data[..., :2] * span + min_vals
    else:
        restored_data = scaled_data * span + min_vals

    return restored_data

def scale_data(data):
    """
    Min-max normalize every sample independently.

    With 3 or more channels, channels 0-1 (U, V) are normalized with a single
    min/max shared by both and the remaining channels (e.g. the sensor mask) are
    copied as-is. With fewer channels the whole sample is normalized.

    Parameters:
        data: Array of shape (samples, height, width, channels)

    Returns:
        (scaled_data, min_vals, max_vals): a float32 array shaped like ``data`` and
        float32 per-sample bounds of shape (samples, 1, 1, 1). 1e-8 is added to the
        denominator so a constant sample maps to 0 instead of dividing by zero.
    """
    n_samples = data.shape[0]
    has_mask = data.shape[-1] >= 3

    scaled = np.zeros_like(data, dtype=np.float32)
    lows = np.zeros((n_samples, 1, 1, 1), dtype=np.float32)
    highs = np.zeros_like(lows)

    for k in range(n_samples):
        sample = data[k]
        field = sample[:, :, :2] if has_mask else sample
        lo, hi = np.min(field), np.max(field)
        lows[k], highs[k] = lo, hi
        normalized = (field - lo) / (hi - lo + 1e-8)

        if has_mask:
            scaled[k, :, :, :2] = normalized
            scaled[k, :, :, 2:] = sample[:, :, 2:]
        else:
            scaled[k] = normalized

    return scaled, lows, highs

In [None]:
def visualize_results(x_data, y_data, pred_data, mask_data, sample_index=0, num_samples=8, model='', max_samples=None):
    """
    Plot the U component of several samples as a 4-row grid (one column per sample).

    Rows: interpolated input (sensor locations in red), prediction, ground truth,
    and absolute prediction error. All U panels share one colour scale and all
    error panels share another, so columns are directly comparable.

    Parameters:
      x_data: Interpolated inputs, indexed as x_data[idx, :, :, c] (c=0 is U, c=1 is V)
      y_data: Ground truth with the same U/V channel layout
      pred_data: Model predictions with the same U/V channel layout
      mask_data: Array whose channel 2 is the sensor mask (sensors where == 1)
      sample_index: Offset added to every entry of ``max_samples``
      num_samples: Number of consecutive samples shown when ``max_samples`` is None
      model: Figure title (e.g. the model name)
      max_samples: Iterable of sample offsets to plot. Despite the name this is a
          list of indices, not a count. Defaults to ``range(num_samples)`` clipped
          to the data size. Iteration stops at the first offset whose
          ``sample_index + offset`` is out of range.
    """
    if max_samples is None:
        max_samples = range(max(0, min(num_samples, len(x_data) - sample_index)))

    sample_ids = []
    for offset in max_samples:
        idx = sample_index + offset
        if idx >= len(x_data):
            break
        sample_ids.append(idx)
    if not sample_ids:
        return

    n_cols = len(sample_ids)
    # squeeze=False keeps axs 2-D even for a single column.
    fig, axs = plt.subplots(4, n_cols, figsize=(n_cols * 2.5, 10), squeeze=False,
                            gridspec_kw={'hspace': -0.5, 'wspace': 0.25})
    cmap = 'viridis'

    # Shared colour limits across all plotted samples.
    u_min, u_max = float('inf'), float('-inf')
    error_max = 0
    for idx in sample_ids:
        u_panels = (x_data[idx, :, :, 0], pred_data[idx, :, :, 0], y_data[idx, :, :, 0])
        u_min = min(u_min, *(np.min(p) for p in u_panels))
        u_max = max(u_max, *(np.max(p) for p in u_panels))
        # The error colour limit covers both U and V errors.
        u_error = np.abs(pred_data[idx, :, :, 0] - y_data[idx, :, :, 0])
        v_error = np.abs(pred_data[idx, :, :, 1] - y_data[idx, :, :, 1])
        error_max = max(error_max, np.max(u_error), np.max(v_error))

    row_labels = ('Interpolated U', 'Predicted U', 'Ground Truth U', 'Error')
    for col, idx in enumerate(sample_ids):
        sensor_rows, sensor_cols = np.where(mask_data[idx, :, :, 2] == 1)
        error = np.abs(pred_data[idx, :, :, 0] - y_data[idx, :, :, 0])
        panels = (
            (x_data[idx, :, :, 0], cmap, u_min, u_max),
            (pred_data[idx, :, :, 0], cmap, u_min, u_max),
            (y_data[idx, :, :, 0], cmap, u_min, u_max),
            (error, 'hot', 0, error_max),
        )
        for row, (image, row_cmap, vmin, vmax) in enumerate(panels):
            ax = axs[row, col]
            im = ax.imshow(image, cmap=row_cmap, vmin=vmin, vmax=vmax)
            if col == 0:
                ax.set_ylabel(row_labels[row], rotation=90, labelpad=15, va='center')
            ax.set_xticks([])
            ax.set_yticks([])
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        axs[0, col].set_title(f'Sample {idx}')
        # imshow puts columns on x and rows on y.
        axs[0, col].scatter(sensor_cols, sensor_rows, c='red', s=10, marker='o')

    fig.suptitle(f'{model}', fontsize=16, fontweight='bold')
    plt.show()

# Whole-image (U, V vector field) relative error
def calculate_whole_image_relative_l2_error(pred, true, epsilon=1e-8):
    """
    Relative error of a predicted U/V field over an entire image:

        sum((pred_u - true_u)^2 + (pred_v - true_v)^2) / (sum(true_u^2 + true_v^2) + epsilon)

    This is a ratio of *squared* norms (no square root is taken). Only channels
    0 (U) and 1 (V) of the last axis are used.

    Parameters:
        pred: Predicted values; last axis holds at least 2 channels
        true: True values with the same layout as ``pred``
        epsilon: Added to the denominator to avoid division by zero

    Returns:
        A single scalar error for the whole image
    """
    assert pred.shape[-1] >= 2 and true.shape[-1] >= 2, "Input must contain at least 2 channels"

    residual_sq = np.square(pred[..., 0] - true[..., 0]) + np.square(pred[..., 1] - true[..., 1])
    reference_sq = np.square(true[..., 0]) + np.square(true[..., 1])

    numerator = np.sum(residual_sq)
    denominator = np.sum(reference_sq) + epsilon
    return numerator / denominator

# Per-sample whole-image relative error, summarized and plotted
def compute_sample_whole_image_relative_l2_errors(pred_test_restored, y_test):
    """
    Compute the whole-image relative error of every test sample, print the mean
    and plot the distribution (errors above the 95th percentile are left out of
    the histogram, not out of the mean).

    Parameters:
        pred_test_restored: Predictions in physical units, indexed per sample
        y_test: Ground truth in physical units, same sample order

    Returns:
        List with one ``calculate_whole_image_relative_l2_error`` value per sample
    """
    per_sample = [
        calculate_whole_image_relative_l2_error(pred_test_restored[k], y_test[k])
        for k in range(len(pred_test_restored))
    ]

    print(f"Sample-level Whole Image Relative L2 Error: {np.mean(per_sample):.4f}")

    print("\nPlotting Sample-level Whole Image Relative L2 Error Distribution...")
    plot_error_histogram(per_sample, bins=50, error_type="Whole Image Relative L2", use_log=False, clip_percentile=95)

    return per_sample
# Calculate L2 error
# def calculate_l2_error(pred, true, epsilon=1e-8):
#     """Calculate relative L2 error: |pred-true|² / (|true|² + epsilon)"""
#     return np.square(pred - true) / (np.square(true) + epsilon)

def calculate_l2_error(pred, true, epsilon=1e-8):
    """
    Point-wise relative squared error of 2-component vectors.

    For every row r:
        ((pred[r,0]-true[r,0])^2 + (pred[r,1]-true[r,1])^2) / (true[r,0]^2 + true[r,1]^2 + epsilon)

    Parameters:
        pred, true: Arrays indexed as ``a[:, 0]`` (U) and ``a[:, 1]`` (V), i.e.
            vectors along axis 0 and components along axis 1
        epsilon: Added to the denominator to avoid division by zero

    Returns:
        Array with one relative error per row
    """
    du = pred[:, 0] - true[:, 0]
    dv = pred[:, 1] - true[:, 1]
    reference = np.square(true[:, 0]) + np.square(true[:, 1]) + epsilon
    return (np.square(du) + np.square(dv)) / reference

# Base-10 log transform for error values
def log_transform_errors(errors, epsilon=1e-10):
    """
    Return log10(errors + epsilon) element-wise.

    Parameters:
        errors: Sequence or array of error values
        epsilon: Shift that keeps log10 finite when an error is exactly 0

    Returns:
        NumPy array of log10-transformed errors
    """
    shifted = np.array(errors) + epsilon
    return np.log10(shifted)

# Plot error histogram
def plot_error_histogram(errors, bins=50, error_type="L2", use_log=True, clip_percentile=None):
    """
    Plot a histogram of per-sample errors with dashed mean and median markers.

    Parameters:
        errors: Sequence of per-sample error values (U and V already combined)
        bins: Number of histogram bins
        error_type: Error name shown in the title
        use_log: If True, plot log10(error + 1e-10) via ``log_transform_errors``
        clip_percentile: If given (e.g. 95), drop errors above that percentile
            before plotting so a few outliers do not squash the histogram.
            The dropped values also do not enter the plotted mean/median.
    """
    error_values = np.asarray(errors)
    if error_values.size == 0:
        # np.percentile / np.mean would fail or warn on an empty input.
        print(f"No {error_type} errors to plot.")
        return

    if clip_percentile is not None:
        error_values = error_values[error_values <= np.percentile(error_values, clip_percentile)]

    if use_log:
        plotted = log_transform_errors(error_values)
        transform_label = "Log"
    else:
        plotted = error_values
        transform_label = ""

    mean_value = np.mean(plotted)
    median_value = np.median(plotted)
    # Join only non-empty parts so no double space appears when transform_label is "".
    title = " ".join(part for part in ("Sample-level Combined U+V", error_type, transform_label, "Error") if part)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(plotted, bins=bins, alpha=0.7, color='blue', edgecolor='black')
    ax.axvline(mean_value, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {mean_value:.4f}')
    ax.axvline(median_value, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_value:.4f}')
    ax.set_title(title)
    ax.set_xlabel(f'Error Value{" (Log Scale)" if use_log else ""}')
    ax.set_ylabel('Sample Count')
    ax.legend()
    ax.grid(alpha=0.3)
    print(f"Max plotted {error_type} error{' (log10)' if use_log else ''}: {np.max(plotted):.4f}")

    fig.tight_layout()
    plt.show()

In [None]:
# Add additional evaluation metric functions
def calculate_fractional_bias(exp, pred, epsilon=1e-10):
    """
    Fractional Bias (FB):

        FB = (mean(exp) - mean(pred)) / (0.5 * (mean(exp) + mean(pred)))

    Parameters:
        exp: Experimental/true values
        pred: Model predicted values
        epsilon: If |0.5 * (mean(exp) + mean(pred))| is below this, 0.0 is returned

    Returns:
        FB value; for positive means it lies in [-2, 2] and 0 means no bias.
        Positive FB means the model under-predicts on average.
    """
    observed_mean = np.mean(exp)
    predicted_mean = np.mean(pred)
    half_sum = 0.5 * (observed_mean + predicted_mean)

    # Guard against a vanishing denominator instead of returning inf/nan.
    if abs(half_sum) >= epsilon:
        return (observed_mean - predicted_mean) / half_sum
    return 0.0

def calculate_geometric_mean_bias(exp, pred, epsilon=1e-10, max_value=1000.0):
    """
    Geometric Mean Bias (MG):

        MG = exp(mean(ln(exp) - ln(pred)))

    Only point pairs where both values exceed ``epsilon`` are used, since the
    logarithm is undefined for non-positive values.

    Parameters:
        exp: Experimental/true values (NumPy array)
        pred: Model predicted values (NumPy array, same shape)
        epsilon: Threshold below which a point pair is ignored
        max_value: The mean log-ratio is clipped to [-ln(max_value), ln(max_value)],
            so the result lies in [1/max_value, max_value]

    Returns:
        MG value; 1 means no bias. 1.0 is also returned when no pair is usable.
    """
    usable = (exp > epsilon) & (pred > epsilon)
    if not np.any(usable):
        return 1.0

    mean_log_ratio = np.mean(np.log(exp[usable] / pred[usable]))
    limit = np.log(max_value)
    return np.exp(np.clip(mean_log_ratio, -limit, limit))

def calculate_normalized_mean_square_error(exp, pred, epsilon=1e-10, max_value=1000.0):
    """
    Normalized Mean Square Error (NMSE):

        NMSE = mean((exp - pred)^2) / ((|mean(exp)| + epsilon) * (|mean(pred)| + epsilon))

    Absolute values of the means keep the result non-negative for signed data.

    Parameters:
        exp: Experimental/true values
        pred: Model predicted values
        epsilon: Added to each |mean| to avoid division by zero
        max_value: Results above this are capped to it

    Returns:
        NMSE in [0, max_value]; 0 means perfect prediction
    """
    squared_error = np.mean(np.square(exp - pred))

    scale = (np.abs(np.mean(exp)) + epsilon) * (np.abs(np.mean(pred)) + epsilon)
    nmse = np.abs(squared_error / scale)

    return min(nmse, max_value)

def calculate_geometric_variance(exp, pred, epsilon=1e-10, max_value=1000.0):
    """
    Geometric Variance (VG).

    VG = exp(mean((ln(exp) - ln(pred))^2)), computed over pairs where both values exceed epsilon.

    Parameters:
        exp: Experimental/true values (numpy array)
        pred: Model predicted values (numpy array, same shape as exp)
        epsilon: Pairs with either value <= epsilon are dropped before taking logs
        max_value: The exponent is capped at ln(max_value), so VG never exceeds max_value
    Returns:
        VG in [1, max_value]; 1 means perfect agreement. Returns 1.0 if no positive pairs exist.
    """
    positive = (exp > epsilon) & (pred > epsilon)
    if not positive.any():
        return 1.0  # Nothing to compare: report "no bias"

    log_gap = np.log(exp[positive]) - np.log(pred[positive])

    # Cap the exponent before exp() to avoid overflow for wildly mismatched values.
    exponent = min(np.mean(np.square(log_gap)), np.log(max_value))
    return np.exp(exponent)

def calculate_fac2(exp, pred, W=0.005, epsilon=1e-10):
    """
    Calculate the FAC2 evaluation metric: the fraction of points whose prediction
    lies within a factor of two of the observation.

    A point counts as a hit when 0.5 <= pred / obs <= 2, or when both |obs| <= W and
    |pred| <= W (near-zero values that the ratio test would otherwise penalize).
    Observations equal to zero are replaced by ``epsilon`` before dividing.

    Parameters:
        exp: Experimental/true values (any shape; flattened)
        pred: Model predicted values (same number of elements as exp)
        W: Allowed absolute error range for the near-zero test
        epsilon: Substitute for zero observations to avoid division by zero
    Returns:
        FAC2 value as a float in [0, 1]; 0.0 for empty input
    """
    obs = np.asarray(exp).ravel()
    prd = np.asarray(pred).ravel()

    n = obs.size
    if n == 0:
        # Previously raised ZeroDivisionError; an empty field has no hits.
        return 0.0

    # If an observation is zero, replace it with epsilon (same rule as before, vectorized).
    obs = np.where(obs == 0, epsilon, obs)

    # Ratios may overflow to inf or become nan for inf/nan inputs; those simply fail the test.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        ratio = prd / obs

    within_factor_two = (ratio >= 0.5) & (ratio <= 2)
    both_near_zero = (np.abs(obs) <= W) & (np.abs(prd) <= W)

    return np.count_nonzero(within_factor_two | both_near_zero) / n

def calculate_all_metrics(exp, pred, max_value=1000.0, W=0.005):
    """
    Compute the active set of evaluation metrics for one field.

    Parameters:
        exp: Experimental/true values (numpy array, any shape)
        pred: Model predicted values (same shape as exp)
        max_value: Cap forwarded to MG and NMSE to keep extreme values finite
        W: Near-zero threshold forwarded to FAC2
    Returns:
        dict with keys 'MG', 'NMSE' and 'FAC2' (in that order). The FB and VG
        entries are currently disabled (kept below as comments).
    """
    truth = exp.flatten()
    estimate = pred.flatten()

    return {
        # 'FB': calculate_fractional_bias(truth, estimate),
        'MG': calculate_geometric_mean_bias(truth, estimate, max_value=max_value),
        'NMSE': calculate_normalized_mean_square_error(truth, estimate, max_value=max_value),
        # 'VG': calculate_geometric_variance(truth, estimate, max_value=max_value),
        # FAC2 flattens its inputs itself, so it receives the original arrays.
        'FAC2': calculate_fac2(exp, pred, W=W),
    }

In [None]:
def apply_kriging_reconstruction(x_test, y_test, method='ordinary', correlation_length=10):
    """
    Reconstruct flow fields by Kriging the sparse sensor readings of each sample.

    For every sample, sensor locations are the cells where channel 2 of ``x_test`` equals 1.
    The channel-0 and channel-1 values at those cells are interpolated over the full grid
    with a Gaussian variogram (sill 1.0, nugget 0) whose range is the fixed
    ``correlation_length`` for both components (it is not tuned per dataset). Samples with
    fewer than 3 sensors, and components where Kriging raises, fall back to the
    corresponding ``x_test`` channel unchanged.

    Parameters:
        x_test: Input data, channels 0-1 = field values, channel 2 = sensor mask (N, H, W, 3)
        y_test: Used only as the shape/dtype template of the output (N, H, W, 2)
        method: 'ordinary' or 'universal' (case-insensitive), default is 'ordinary'
        correlation_length: Gaussian variogram range used for both components (default 10)

    Returns:
        pred_kriging: Kriging reconstruction results, same shape/dtype as y_test

    Raises:
        ValueError: if ``method`` is not 'ordinary' or 'universal'.
    """
    # Previously any unrecognized method silently fell through to UniversalKriging.
    method_key = method.lower()
    if method_key not in ('ordinary', 'universal'):
        raise ValueError(f"method must be 'ordinary' or 'universal', got {method!r}")

    print(f"Reconstructing flow field using {method} Kriging method...")

    # Grid coordinates in cell units: x runs along width, y along height.
    h, w = x_test.shape[1:3]
    grid_x = np.arange(0, w, 1.0)
    grid_y = np.arange(0, h, 1.0)
    variogram_parameters = {'range': correlation_length, 'sill': 1.0, 'nugget': 0}

    print(f"Predicting with fixed correlation length {correlation_length}...")
    pred_kriging = np.zeros_like(y_test)

    for i in range(len(x_test)):
        # np.nonzero on the 2D mask returns (row indices, column indices) = (y, x).
        y_coords, x_coords = np.nonzero(x_test[i, :, :, 2] == 1)

        if len(x_coords) < 3:
            # Too few sensors to fit a variogram: keep the input field unchanged.
            pred_kriging[i, :, :, :2] = x_test[i, :, :, :2]
            continue

        # Resolved only when Kriging is actually needed.
        kriging_cls = OrdinaryKriging if method_key == 'ordinary' else UniversalKriging

        for c in range(2):
            values = x_test[i, y_coords, x_coords, c]
            try:
                krig = kriging_cls(
                    x_coords, y_coords, values,
                    variogram_model='gaussian',
                    # Fresh dict per fit, as before, in case the library mutates it.
                    variogram_parameters=dict(variogram_parameters),
                    verbose=False,
                    enable_plotting=False
                )
                z, _ = krig.execute('grid', grid_x, grid_y)
                pred_kriging[i, :, :, c] = z
            except Exception as e:
                # Best-effort per component: report and fall back to the input channel.
                print(f"Kriging interpolation failed for sample {i}, component {c}: {e}")
                pred_kriging[i, :, :, c] = x_test[i, :, :, c]

    return pred_kriging

## Model Training

### CWGAN

#### Utility Functions

In [None]:
from keras.saving import register_keras_serializable

@register_keras_serializable()
def wasserstein_loss(y_true, y_pred):
    """Wasserstein (critic) loss: the batch mean of label * score.

    Registered as Keras-serializable so saved models that reference it can be reloaded.
    Note: CWGAN.train_step computes its critic loss inline rather than calling this.
    """
    signed_scores = y_true * y_pred
    return tf.reduce_mean(signed_scores)

@register_keras_serializable()
def gradient_penalty(discriminator, condition_data, real_samples, fake_samples, batch_size):
    """WGAN-GP gradient penalty term.

    Samples points uniformly on the segments between real and fake fields, and
    penalizes the squared deviation of the critic's per-sample input-gradient norm from 1.

    Parameters:
        discriminator: Critic model called as discriminator([condition_data, x]).
        condition_data: Conditioning batch passed through unchanged to the critic.
        real_samples, fake_samples: Batches of the same 4D shape (batch, H, W, C).
        batch_size: Batch dimension used to draw one interpolation weight per sample.
    Returns:
        Scalar tensor: mean over the batch of (||grad||_2 - 1)^2.
    """
    # One interpolation weight per sample, broadcast over H, W, C.
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
    interpolated = alpha * real_samples + (1 - alpha) * fake_samples

    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        critic_scores = discriminator([condition_data, interpolated], training=True)

    gradients = tape.gradient(critic_scores, interpolated)
    # The small epsilon keeps sqrt differentiable when a sample's gradient is exactly zero:
    # d/dx sqrt(x) is infinite at 0, which would otherwise inject NaNs into the critic update.
    gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean((gradients_norm - 1.0) ** 2)

def build_cwgan_generator(input_shape, noise_dim=100):
    """Build the conditional WGAN generator: a 3-level U-Net on [condition, noise map].

    The noise vector is projected by a Dense layer to a single (H, W) channel and
    concatenated with the condition. Because the encoder pools twice by 2x2, the merged
    tensor is zero-padded on the bottom/right up to the next multiple of 4 (15x15 -> 16x16
    for the default grid), and the decoder output is cropped back to (H, W).

    Parameters:
        input_shape: (H, W, C) shape of the condition input; H and W must be concrete ints.
        noise_dim: Length of the noise vector input.
    Returns:
        Keras Model named 'generator' mapping [condition (H, W, C), noise (noise_dim,)]
        to an (H, W, 2) field with tanh activation (values in [-1, 1]).
    """
    height, width = input_shape[0], input_shape[1]
    # Previously hard-coded to +1; computing it keeps skip connections aligned for any H, W.
    pad_h = (-height) % 4
    pad_w = (-width) % 4

    # Conditional input (sensor data) and noise input
    condition_input = Input(shape=input_shape, name='condition_input')
    noise_input = Input(shape=(noise_dim,), name='noise_input')

    # Project noise to one extra spatial channel with the condition's H x W
    noise_reshaped = Dense(height * width * 1)(noise_input)
    noise_reshaped = Reshape((height, width, 1))(noise_reshaped)

    merged = concatenate([condition_input, noise_reshaped])

    # Pad to a multiple of 4 (15x15 -> 16x16 for the default grid)
    padded_inputs = ZeroPadding2D(((0, pad_h), (0, pad_w)))(merged)

    # === Encoder ===
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(padded_inputs)
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
    p1 = MaxPooling2D((2, 2), padding="valid")(c1)  # Hp x Wp -> Hp/2 x Wp/2

    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    p2 = MaxPooling2D((2, 2), padding="valid")(c2)  # -> Hp/4 x Wp/4

    c3 = Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
    c3 = Conv2D(256, (3, 3), activation='relu', padding='same')(c3)

    # === Decoder with skip connections ===
    u4 = UpSampling2D((2, 2))(c3)  # -> Hp/2 x Wp/2
    u4 = concatenate([u4, c2])
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(u4)
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)

    u5 = UpSampling2D((2, 2))(c4)  # -> Hp x Wp
    u5 = concatenate([u5, c1])
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(u5)
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)

    # Remove the padding to restore H x W
    cropped_outputs = Cropping2D(((0, pad_h), (0, pad_w)))(c5)

    outputs = Conv2D(2, (1, 1), activation='tanh')(cropped_outputs)

    model = Model([condition_input, noise_input], outputs, name='generator')
    return model

def build_cwgan_discriminator(input_shape):
    """Build the conditional WGAN critic.

    The condition (H, W, C) and a 2-channel field (H, W, 2) are concatenated, zero-padded
    by one row/column on the bottom/right, and passed through four stride-2 4x4
    convolutions (64, 128, 256, 512 filters) with LeakyReLU(0.2) and Dropout(0.3).
    Blocks 2-4 also use BatchNormalization. Global average pooling and a single linear
    unit produce an unbounded score (no sigmoid, as required for a Wasserstein critic).

    NOTE(review): the WGAN-GP paper (Gulrajani et al., 2017) advises against batch norm in
    the critic, since the gradient penalty is defined per sample; consider
    LayerNormalization. TODO: evaluate before retraining.

    Parameters:
        input_shape: (H, W, C) shape of the condition input.
    Returns:
        Keras Model named 'discriminator' mapping [condition, field] to a (batch, 1) score.
    """
    condition_input = Input(shape=input_shape, name='condition_input')
    data_input = Input(shape=(input_shape[0], input_shape[1], 2), name='data_input')

    x = concatenate([condition_input, data_input])
    x = ZeroPadding2D(((0, 1), (0, 1)))(x)  # 15x15 -> 16x16 for the default grid

    # (filters, apply batch norm) per downsampling block; layer creation order matches
    # conv -> [batch norm] -> LeakyReLU -> dropout.
    conv_blocks = ((64, False), (128, True), (256, True), (512, True))
    for filters, use_batch_norm in conv_blocks:
        x = Conv2D(filters, (4, 4), strides=2, padding='same')(x)
        if use_batch_norm:
            x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = Dropout(0.3)(x)

    x = layers.GlobalAveragePooling2D()(x)
    score = Dense(1)(x)

    return Model([condition_input, data_input], score, name='discriminator')

class CWGAN(tf.keras.Model):
    """Conditional Wasserstein GAN with gradient penalty (WGAN-GP) plus an L1 term.

    Wraps a conditional generator G([condition, noise]) and a critic D([condition, field]).
    Each ``train_step`` performs ``n_critic`` critic updates followed by one generator update.

    Parameters:
        generator: Keras model taking [condition_data, noise] and returning a field.
        discriminator: Keras model taking [condition_data, field] and returning a score.
        noise_dim: Length of the Gaussian noise vector fed to the generator.
        gp_weight: Weight of the gradient penalty in the critic loss.
        n_critic: Critic updates per generator update (default 5, as before).
        l1_weight: Weight of the L1 reconstruction loss in the generator loss
            (default 100.0, as before).
    Raises:
        ValueError: if n_critic < 1 (the critic loss would otherwise be undefined).
    """

    def __init__(self, generator, discriminator, noise_dim=100, gp_weight=10.0,
                 n_critic=5, l1_weight=100.0):
        super().__init__()
        if n_critic < 1:
            raise ValueError(f"n_critic must be >= 1, got {n_critic}")
        self.generator = generator
        self.discriminator = discriminator
        self.noise_dim = noise_dim
        self.gp_weight = gp_weight
        self.n_critic = n_critic
        self.l1_weight = l1_weight

    def compile(self, g_optimizer, d_optimizer):
        """Store separate optimizers for the generator and the critic."""
        super().compile()
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer

    @tf.function
    def train_step(self, data):
        """Run one WGAN-GP update on a (condition_data, real_data) batch.

        Returns:
            dict with the critic loss from the last critic iteration ("d_loss"), the total
            generator loss ("g_loss" = adversarial + l1_weight * L1) and the unweighted
            L1 term ("l1_loss").
        """
        condition_data, real_data = data
        batch_size = tf.shape(condition_data)[0]

        # Critic updates (the same real batch is reused for every iteration)
        for _ in range(self.n_critic):
            noise = tf.random.normal([batch_size, self.noise_dim])

            with tf.GradientTape() as tape:
                fake_data = self.generator([condition_data, noise], training=True)

                real_pred = self.discriminator([condition_data, real_data], training=True)
                fake_pred = self.discriminator([condition_data, fake_data], training=True)

                # Wasserstein critic loss: push real scores up, fake scores down
                d_loss = tf.reduce_mean(fake_pred) - tf.reduce_mean(real_pred)

                gp = gradient_penalty(self.discriminator, condition_data, real_data, fake_data, batch_size)
                d_loss += self.gp_weight * gp

            d_gradients = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(zip(d_gradients, self.discriminator.trainable_variables))

        # Generator update
        noise = tf.random.normal([batch_size, self.noise_dim])

        with tf.GradientTape() as tape:
            fake_data = self.generator([condition_data, noise], training=True)
            fake_pred = self.discriminator([condition_data, fake_data], training=True)

            # Adversarial term: the generator wants high critic scores
            g_loss = -tf.reduce_mean(fake_pred)

            # L1 reconstruction term keeps outputs close to the ground-truth field
            l1_loss = tf.reduce_mean(tf.abs(real_data - fake_data))
            g_loss += self.l1_weight * l1_loss

        g_gradients = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_variables))

        return {"d_loss": d_loss, "g_loss": g_loss, "l1_loss": l1_loss}

#### Training and evaluation loop over sensor counts

In [None]:
# Per-sensor-count CWGAN driver.
# Live code: load each dataset, build a fresh CWGAN (used only by the commented-out
# training section), load the saved "<method>_CWGAN-Generator_<sensor_num>_best.keras"
# generator, and run inference on the clean (t == 0) and perturbed (t == 1) test inputs.
# Training and the metrics/save section are commented out, so as written the
# predictions are computed but not stored.
sensor_list = [5,10,15,20,25,30]
# sensor_list = [30]
from tqdm.notebook import tqdm
for s in range(len(sensor_list)):
    sensor_num = sensor_list[s]
    x_train_input, y_train_output, x_val_input, y_val_output, x_test_input, x_perturbed_test_input, y_test_output, min_vals, max_vals, y_test, x_test = load_my_data(sensor_num, method=method)
    print(f'Dataset for {sensor_num} is ready.')

    # Condition input: 15x15 grid with 3 channels; noise vector length 100.
    input_shape = (15, 15, 3)
    noise_dim = 100

    # Build model
    # NOTE(review): generator/discriminator/cwgan below are only used by the commented-out
    # training code; `generator` is replaced by load_model(...) further down.
    generator = build_cwgan_generator(input_shape, noise_dim)
    discriminator = build_cwgan_discriminator(input_shape)

    # Create CWGAN
    cwgan = CWGAN(generator, discriminator, noise_dim)

    # Compile model
    g_optimizer = optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
    d_optimizer = optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
    cwgan.compile(g_optimizer, d_optimizer)

    save_dir = model_dir

    # # Train
    # # Training parameters
    # batch_size = 8
    # epochs = 100

    # # Create dataset
    # train_dataset = tf.data.Dataset.from_tensor_slices((x_train_input, y_train_output))
    # train_dataset = train_dataset.shuffle(1000).batch(batch_size)

    # # Train model
    # print(f"Starting CWGAN training, sensor count: {sensor_num}")
    # # print(f"Total {epochs} epochs, approximately {len(x_train_input)//batch_size} batches per epoch")
    # # Early stopping parameters
    # best_val_loss = float('inf')
    # patience = 10  # Stop if no improvement for 10 epochs
    # patience_counter = 0
    # best_epoch = 0

    # # Create validation dataset
    # val_dataset = tf.data.Dataset.from_tensor_slices((x_val_input, y_val_output))
    # val_dataset = val_dataset.batch(batch_size)

    # for epoch in range(epochs):
    #     epoch_d_loss = 0
    #     epoch_g_loss = 0
    #     epoch_l1_loss = 0
    #     num_batches = 0

    #     # Use tqdm to show batch progress
    #     batch_iterator = tqdm(train_dataset, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

    #     for batch in batch_iterator:
    #         losses = cwgan.train_step(batch)
    #         epoch_d_loss += losses["d_loss"]
    #         epoch_g_loss += losses["g_loss"]
    #         epoch_l1_loss += losses["l1_loss"]
    #         num_batches += 1

    #         # Update progress bar to show current loss
    #         batch_iterator.set_postfix({
    #             'D_loss': f'{losses["d_loss"]:.4f}',
    #             'G_loss': f'{losses["g_loss"]:.4f}',
    #             'L1_loss': f'{losses["l1_loss"]:.4f}'
    #         })

    #     # Display average loss after each epoch
    #     avg_d_loss = epoch_d_loss/num_batches
    #     avg_g_loss = epoch_g_loss/num_batches
    #     avg_l1_loss = epoch_l1_loss/num_batches

    #     # Validation set evaluation and early stopping logic
    #     val_g_loss = 0
    #     val_l1_loss = 0
    #     val_batches = 0

    #     for val_batch in val_dataset:
    #         val_condition_data, val_real_data = val_batch
    #         val_noise = tf.random.normal([tf.shape(val_condition_data)[0], noise_dim])
    #         val_fake_data = generator([val_condition_data, val_noise], training=False)

    #         val_fake_pred = discriminator([val_condition_data, val_fake_data], training=False)
    #         val_g_loss_batch = -tf.reduce_mean(val_fake_pred)
    #         val_l1_loss_batch = tf.reduce_mean(tf.abs(val_real_data - val_fake_data))

    #         val_g_loss += val_g_loss_batch
    #         val_l1_loss += val_l1_loss_batch
    #         val_batches += 1

    #     avg_val_g_loss = val_g_loss / val_batches
    #     avg_val_l1_loss = val_l1_loss / val_batches
    #     val_total_loss = avg_val_g_loss + 100.0 * avg_val_l1_loss

    #     print(f"Epoch {epoch+1}/{epochs} - D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}, L1 Loss: {avg_l1_loss:.4f}")
    #     print(f"Validation - G Loss: {avg_val_g_loss:.4f}, L1 Loss: {avg_val_l1_loss:.4f}, Total: {val_total_loss:.4f}")

    #     # Early stopping check
    #     # NOTE(review): comparing np.abs(...) ranks a strongly negative total loss as "worse";
    #     # a plain `val_total_loss < best_val_loss` may be intended. Confirm before re-enabling.
    #     if np.abs(val_total_loss) < np.abs(best_val_loss):
    #         best_val_loss = val_total_loss
    #         best_epoch = epoch + 1
    #         patience_counter = 0
    #         generator.save(f'{model_dir}{method}_CWGAN-Generator_{sensor_num}_best.keras')
    #         print(f"🎉 New best model! Validation loss: {best_val_loss:.4f}")
    #     else:
    #         patience_counter += 1
    #         print(f"⏳ Validation loss not improving ({patience_counter}/{patience})")

    #         if patience_counter >= patience:
    #             print(f"🛑 Early stopping! Best epoch: {best_epoch}, Best validation loss: {best_val_loss:.4f}")
    #             break

    #     # Save checkpoint every 10 epochs
    #     if (epoch + 1) % 10 == 0:
    #         print(f"Saving checkpoint - Epoch {epoch+1}")
    #         generator.save(f'{model_dir}{method}_CWGAN-Generator_{sensor_num}_epoch_{epoch+1}.keras')

    # # Save generator
    # generator.save(f'{model_dir}{method}_CWGAN-Generator_{sensor_num}.keras')
    # print(f"🏁 Training completed! Best epoch: {best_epoch}, Best validation loss: {best_val_loss:.4f}")

    # Load
    generator_path = f'{model_dir}{method}_CWGAN-Generator_{sensor_num}_best.keras'
    print(f"Loading best model: {generator_path}")
    generator = load_model(generator_path)
    # Number of noise draws whose predictions are averaged per input.
    # NOTE(review): the commented-out save paths below are tagged "avg30", which does not
    # match num_generation = 1. Confirm before re-enabling them.
    num_generation = 1

    # t == 0: clean test inputs (timed); t == 1: perturbed test inputs.
    for t in range(2):
        if t == 0:
            # Predict test set
            # num_generation = 10
            pred_tests = []
            for _ in range(num_generation):
                test_noise = tf.random.normal([len(x_test_input), noise_dim])
                start_time = time.time()
                pred_single = generator.predict([x_test_input, test_noise])
                end_time = time.time()
                print(f"inference time: {end_time - start_time:.2f} seconds")
                pred_tests.append(pred_single)
            pred_test = np.mean(pred_tests, axis=0)
            # test_noise = tf.random.normal([len(x_test_input), noise_dim])
            # pred_test = generator.predict([x_test_input, test_noise])
        elif t == 1:
            # Predict perturbed test set
            # num_generation = 30
            pred_tests = []
            for _ in range(num_generation):
                test_noise = tf.random.normal([len(x_perturbed_test_input), noise_dim])
                pred_single = generator.predict([x_perturbed_test_input, test_noise])
                pred_tests.append(pred_single)
            pred_test = np.mean(pred_tests, axis=0)
            # test_noise = tf.random.normal([len(x_perturbed_test_input), noise_dim])
            # pred_test = generator.predict([x_perturbed_test_input, test_noise])

        # NOTE(review): with the evaluation block below commented out, pred_test is
        # computed and then discarded on each iteration.

        # # Denormalize (restore original data scale)
        # pred_test_restored = restore_original_scale(pred_test, min_vals, max_vals)

        # # Calculate SSIM and PSNR scores
        # ssim_scores = []
        # # psnr_scores = []
        # for i in range(len(pred_test_restored)):
        #     ssim_u = ssim(pred_test_restored[i, :, :, 0], y_test[i, :, :, 0], data_range=y_test[i, :, :, 0].max() - y_test[i, :, :, 0].min(), multichannel=False)
        #     ssim_v = ssim(pred_test_restored[i, :, :, 1], y_test[i, :, :, 1], data_range=y_test[i, :, :, 1].max() - y_test[i, :, :, 1].min(), multichannel=False)
        #     ssim_score = (ssim_u + ssim_v) / 2
        #     # psnr_score = psnr(pred_test_restored[i, :, :, 0], y_test[i, :, :, 0],
        #     #                 data_range=y_test[i, :, :, 0].max() - y_test[i, :, :, 0].min())
        #     ssim_scores.append(ssim_score)
        #     # psnr_scores.append(psnr_score)

        # # Calculate average SSIM and PSNR
        # average_ssim = np.mean(ssim_scores)
        # # average_psnr = np.mean(psnr_scores)

        # cwgan_metrics = []
        # for i in range(len(pred_test_restored)):
        #     # Calculate for U and V components separately
        #     u_metrics = calculate_all_metrics(y_test[i, :, :, 0], pred_test_restored[i, :, :, 0])
        #     v_metrics = calculate_all_metrics(y_test[i, :, :, 1], pred_test_restored[i, :, :, 1])

        #     # Combine U and V metrics (take average)
        #     combined_metrics = {}
        #     for key in u_metrics:
        #         combined_metrics[key] = (u_metrics[key] + v_metrics[key]) / 2

        #     cwgan_metrics.append(combined_metrics)

        # # Calculate average metrics
        # cwgan_avg_metrics = {}
        # for key in cwgan_metrics[0]:
        #     cwgan_avg_metrics[key] = np.mean([m[key] for m in cwgan_metrics])

        # cwgan_avg_metrics['SSIM'] = average_ssim
        # # cwgan_avg_metrics['PSNR'] = average_psnr

        # print("\n=== CWGAN Evaluation Metrics ===")
        # for key, value in cwgan_avg_metrics.items():
        #     print(f"{key}: {value:.4f}")

        # if t == 0:
        #     np.save(f'{metric_dir}{method}_avg30_cwgan_metrics_{sensor_num}.npy', cwgan_avg_metrics)
        #     np.save(f'{pred_dir}{method}_avg30_cwgan_base_{sensor_num}_y_predict.npy', pred_test_restored)
        #     print(f"Saved normal test set results: sensor_num={sensor_num}")
        # elif t == 1:
        #     np.save(f'{metric_dir}{method}_avg30_cwgan_perturbed_metrics_{sensor_num}.npy', cwgan_avg_metrics)
        #     np.save(f'{pred_dir}{method}_avg30_cwgan_base_{sensor_num}_y_perturbed_predict.npy', pred_test_restored)
        #     print(f"Saved perturbed test set results: sensor_num={sensor_num}")

### UNet

#### Utility Functions

In [None]:
def build_unet(input_shape):
    """Build and compile a 3-level U-Net mapping an (H, W, C) input to an (H, W, 2) field.

    The encoder pools twice by 2x2, so the input is zero-padded on the bottom/right up to
    the next multiple of 4 (15x15 -> 16x16 for the default grid). The decoder output is
    cropped back to (H, W) before the final linear 1x1 projection.

    Parameters:
        input_shape: (H, W, C) tuple; H and W must be concrete ints.
    Returns:
        A compiled Keras Model (optimizer 'adam', loss ``weighted_vector_loss``, which must
        be defined in the notebook before this is called).
    """
    height, width = input_shape[0], input_shape[1]
    # Previously hard-coded to +1; computing it keeps skip connections aligned for any H, W.
    pad_h = (-height) % 4
    pad_w = (-width) % 4

    inputs = Input(input_shape)

    # Pad to a multiple of 4 (15x15 -> 16x16 for the default grid)
    padded_inputs = ZeroPadding2D(((0, pad_h), (0, pad_w)))(inputs)

    # === Encoder ===
    c1 = Conv2D(32, (3, 3), activation='relu', padding='same')(padded_inputs)
    c1 = Conv2D(32, (3, 3), activation='relu', padding='same')(c1)
    p1 = MaxPooling2D((2, 2), padding="valid")(c1)  # Hp x Wp -> Hp/2 x Wp/2

    c2 = Conv2D(64, (3, 3), activation='relu', padding='same')(p1)
    c2 = Conv2D(64, (3, 3), activation='relu', padding='same')(c2)
    p2 = MaxPooling2D((2, 2), padding="valid")(c2)  # -> Hp/4 x Wp/4

    c3 = Conv2D(128, (3, 3), activation='relu', padding='same')(p2)
    c3 = Conv2D(128, (3, 3), activation='relu', padding='same')(c3)

    # === Decoder with skip connections ===
    u4 = UpSampling2D((2, 2))(c3)  # -> Hp/2 x Wp/2
    u4 = concatenate([u4, c2])
    c4 = Conv2D(64, (3, 3), activation='relu', padding='same')(u4)
    c4 = Conv2D(64, (3, 3), activation='relu', padding='same')(c4)

    u5 = UpSampling2D((2, 2))(c4)  # -> Hp x Wp
    u5 = concatenate([u5, c1])
    c5 = Conv2D(32, (3, 3), activation='relu', padding='same')(u5)
    c5 = Conv2D(32, (3, 3), activation='relu', padding='same')(c5)

    # Remove the padding to restore H x W
    cropped_outputs = Cropping2D(((0, pad_h), (0, pad_w)))(c5)

    outputs = Conv2D(2, (1, 1), activation='linear')(cropped_outputs)
    model = Model(inputs, outputs)
    # model.compile(optimizer='adam', loss='mean_squared_error')
    model.compile(optimizer='adam', loss=weighted_vector_loss)

    return model

#### Training and evaluation loop over sensor counts

In [None]:
# Per-sensor-count U-Net driver.
# Live code: load each dataset, build a fresh U-Net (only needed by the commented-out
# training section), then load the saved "<method>_U-Net-ImageReconstruction_<sensor_num>.keras"
# model and run inference on the clean (t == 0) and perturbed (t == 1) test inputs.
# The denormalization/metrics/save section is commented out, so predictions are
# currently computed but not stored.
sensor_list = [5, 10, 15, 20, 25, 30]
# sensor_list = [30]
for s in range(len(sensor_list)):
  sensor_num = sensor_list[s]
  x_train_input, y_train_output, x_val_input, y_val_output, x_test_input, x_perturbed_test_input, y_test_output, min_vals, max_vals, y_test, x_test = load_my_data(sensor_num, method=method)
  print(f'Dataset for {sensor_num} is ready.')
  input_shape = (15, 15, 3)
  # NOTE(review): this freshly built model is overwritten by load_model(...) below unless
  # the training section is re-enabled.
  unet_model = build_unet(input_shape)
  save_dir = model_dir
  # # Train
  # # Training parameters
  # batch_size = 8
  # epochs = 100

  # # Set callback functions (save best model)
  # model_checkpoint = ModelCheckpoint(
  #     filepath=f'{save_dir}{method}_U-Net-ImageReconstruction_{sensor_num}.keras',
  #     monitor='val_loss', save_best_only=True, verbose=1, mode='min'
  # )
  # reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-5)
  # early_stop = EarlyStopping(monitor='val_loss', patience=20, verbose=1, restore_best_weights=True)

  # # Train model
  # history = unet_model.fit(
  #     x_train_input, y_train_output,
  #     epochs=epochs,
  #     batch_size=batch_size,
  #     validation_data=(x_val_input, y_val_output),
  #     callbacks=[reduce_lr, early_stop, model_checkpoint],
  #     shuffle=True
  # )

  # Load
  model_path = f'{save_dir}{method}_U-Net-ImageReconstruction_{sensor_num}.keras'
  unet_model = load_model(model_path)

  # t == 0: clean test inputs (timed); t == 1: perturbed test inputs.
  for t in range(2):
    if t == 0:
      # Predict test set
      start_time = time.time()
      pred_test = unet_model.predict(x_test_input)
      end_time = time.time()
      print(f"inference time: {end_time - start_time:.2f} seconds")
    elif t == 1:
      pred_test = unet_model.predict(x_perturbed_test_input)

    # NOTE(review): with the evaluation block below commented out, pred_test is computed
    # and then discarded on each iteration.

    # Denormalize (restore original data scale)
    # pred_test_restored = restore_original_scale(pred_test, min_vals, max_vals)

    # # Calculate SSIM and PSNR scores
    # ssim_scores = []
    # # psnr_scores = []
    # for i in range(len(pred_test_restored)):
    #     # Calculate SSIM for U and V components separately
    #     ssim_u = ssim(pred_test_restored[i, :, :, 0], y_test[i, :, :, 0], data_range=y_test[i, :, :, 0].max() - y_test[i, :, :, 0].min(), multichannel=False)
    #     ssim_v = ssim(pred_test_restored[i, :, :, 1], y_test[i, :, :, 1], data_range=y_test[i, :, :, 1].max() - y_test[i, :, :, 1].min(), multichannel=False)
    #     ssim_score = (ssim_u + ssim_v) / 2
    #     # psnr_score = psnr(pred_test_restored[i, :, :, 0], y_test[i, :, :, 0], data_range=y_test[i, :, :, 0].max() - y_test[i, :, :, 0].min())
    #     ssim_scores.append(ssim_score)
    #     # psnr_scores.append(psnr_score)

    # # Calculate average SSIM and PSNR
    # average_ssim = np.mean(ssim_scores)
    # # average_psnr = np.mean(psnr_scores)

    # unet_metrics = []
    # for i in range(len(pred_test_restored)):
    #     # Calculate for U and V components separately
    #     u_metrics = calculate_all_metrics(y_test[i, :, :, 0], pred_test_restored[i, :, :, 0])
    #     v_metrics = calculate_all_metrics(y_test[i, :, :, 1], pred_test_restored[i, :, :, 1])

    #     # Combine U and V metrics (take average)
    #     combined_metrics = {}
    #     for key in u_metrics:
    #         combined_metrics[key] = (u_metrics[key] + v_metrics[key]) / 2

    #     unet_metrics.append(combined_metrics)

    # # Calculate average metrics
    # unet_avg_metrics = {}
    # for key in unet_metrics[0]:
    #     unet_avg_metrics[key] = np.mean([m[key] for m in unet_metrics])

    # unet_avg_metrics['SSIM'] = average_ssim
    # # unet_avg_metrics['PSNR'] = average_psnr

    # print("\n=== UNet Evaluation Metrics ===")
    # for key, value in unet_avg_metrics.items():
    #     print(f"{key}: {value:.4f}")

    # if t == 0:
    #   np.save(f'{metric_dir}{method}_unet_metrics_{sensor_num}.npy', unet_avg_metrics)
    #   np.save(f'{pred_dir}{method}_unet_base_{sensor_num}_y_predict.npy', pred_test_restored)
    #   print(f"Saved normal test set results: sensor_num={sensor_num}")
    # elif t == 1:
    #   np.save(f'{metric_dir}{method}_unet_perturbed_metrics_{sensor_num}.npy', unet_avg_metrics)
    #   np.save(f'{pred_dir}{method}_unet_base_{sensor_num}_y_perturbed_predict.npy', pred_test_restored)
    #   print(f"Saved perturbed test set results: sensor_num={sensor_num}")


### ViTAE

#### Utility Functions

In [None]:
from keras.saving import register_keras_serializable
# 2D position encoding function (converted from pos_embed.py)
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Fixed (non-learned) 2-D sine/cosine positional embedding for a patch grid.

    embed_dim: embedding width D. The two 1-D helpers each receive D/2 and require
               it to be even, so D must be divisible by 4.
    grid_size: int (square grid) or (grid_h, grid_w) tuple/list; numpy integers
               are accepted as well.
    cls_token: if True, prepend an all-zero row for a class token.
    return:
    pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim]

    Raises:
        TypeError: if grid_size is neither an integer nor a tuple/list.
    """
    if isinstance(grid_size, (int, np.integer)):
        grid_h_size = grid_w_size = int(grid_size)
    elif isinstance(grid_size, (tuple, list)):
        grid_h_size, grid_w_size = grid_size
    else:
        # Previously fell through and crashed with UnboundLocalError below.
        raise TypeError(
            f"grid_size must be an int or a (height, width) tuple/list, "
            f"got {type(grid_size).__name__}"
        )

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    # 'xy' meshgrid: grid[0] varies along the width axis, grid[1] along the height axis.
    grid = np.meshgrid(grid_w, grid_h)  # w first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Build a 2-D sin-cos embedding by concatenating two 1-D embeddings.

    embed_dim: total output width D. Each of the two coordinate arrays gets D/2
               channels, so D must be even (and D/2 even for the 1-D helper).
    grid: array whose first axis selects the coordinate array: grid[0] and
          grid[1] are each flattened to H*W positions.
    return: (H*W, D) array laid out as [embedding of grid[0] | embedding of grid[1]].

    Note: the original code named the grid[0] embedding "h". However,
    get_2d_sincos_pos_embed builds the grid with np.meshgrid(grid_w, grid_h), so
    grid[0] actually varies along the width axis.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_axis_embeddings = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis])  # (H*W, D/2) each
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis_embeddings, axis=1)  # (H*W, D)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Fixed 1-D sine/cosine embedding of an array of positions.

    With D = embed_dim, the frequencies are omega_k = 1 / 10000**(k / (D/2))
    for k = 0 .. D/2-1. Each position p becomes
    [sin(p*omega_0) .. sin(p*omega_{D/2-1}), cos(p*omega_0) .. cos(p*omega_{D/2-1})].

    embed_dim: output dimension D for each position (must be even)
    pos: array of positions of any shape; it is flattened to (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0

    exponents = np.arange(embed_dim // 2, dtype=np.float32)
    exponents /= embed_dim / 2.
    frequencies = 1. / 10000**exponents              # (D/2,)

    flat_positions = pos.reshape(-1)                 # (M,)
    angles = flat_positions[:, None] * frequencies[None, :]  # (M, D/2) outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)

# TensorFlow version of ViTAE components
@register_keras_serializable()
class PatchEmbed(layers.Layer):
    """Split an image into non-overlapping patches and linearly project each one.

    A Conv2D whose kernel and stride both equal the patch size maps every patch
    to an ``embed_dim`` vector. The resulting (grid_h, grid_w) map is flattened
    into a token sequence of length ``num_patches``.

    Args:
      img_size: int or (H, W) spatial size of the input.
      patch_size: int or (ph, pw) patch size. A remainder of H or W not covered
        by whole patches is dropped ('valid' padding).
      in_chans: number of input channels. It is kept for the config only; the
        Conv2D infers the channel count from its input.
      embed_dim: width of each patch token.
      **kwargs: standard Keras layer arguments (``name``, ``dtype``,
        ``trainable``, ...) forwarded to ``Layer``. They are needed so the layer
        can be rebuilt with ``from_config``.
    """
    def __init__(self, img_size, patch_size, in_chans, embed_dim, **kwargs):
        super().__init__(**kwargs)
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.embed_dim = embed_dim

        self.proj = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size, padding='valid')

    def get_config(self):
        config = super().get_config()
        config.update({
            "img_size": self.img_size,
            "patch_size": self.patch_size,
            "in_chans": self.in_chans,
            "embed_dim": self.embed_dim,
        })
        return config

    def call(self, x):
        """(B, H, W, C) image -> (B, num_patches, embed_dim) tokens."""
        # Batch size is dynamic, so read it with tf.shape at run time.
        batch_size = tf.shape(x)[0]
        x = self.proj(x)  # (B, grid_h, grid_w, embed_dim)
        # The token count is static; using the known grid keeps the output shape
        # fully defined apart from the batch dimension.
        hw = self.grid_size[0] * self.grid_size[1]
        x = tf.reshape(x, [batch_size, hw, self.embed_dim])
        return x

@register_keras_serializable()
class MLPBlock(layers.Layer):
    """Position-wise feed-forward (MLP) sub-layer of a Transformer block.

    Dense(hidden_dim * mlp_ratio) -> GELU (sigmoid approximation) -> Dropout
    -> Dense(hidden_dim) -> Dropout. The same weight-free Dropout layer is
    applied after both Dense layers.

    Args:
      hidden_dim: input and output width.
      mlp_ratio: expansion factor of the hidden layer.
      dropout_rate: dropout rate applied after each Dense layer.
      **kwargs: standard Keras layer arguments forwarded to ``Layer``. They are
        needed so the layer can be rebuilt with ``from_config``.
    """
    def __init__(self, hidden_dim, mlp_ratio=4.0, dropout_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.mlp_ratio = mlp_ratio
        self.dropout_rate = dropout_rate
        self.fc1 = layers.Dense(int(hidden_dim * mlp_ratio))
        self.gelu = lambda x: x * tf.sigmoid(1.702 * x)  # GELU approximation: x * sigmoid(1.702 x)
        self.fc2 = layers.Dense(hidden_dim)
        self.dropout = layers.Dropout(dropout_rate)

    def get_config(self):
        config = super().get_config()
        config.update({
            "hidden_dim": self.hidden_dim,
            "mlp_ratio": self.mlp_ratio,
            "dropout_rate": self.dropout_rate,
        })
        return config

    def call(self, x, training=False):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x, training=training)
        x = self.fc2(x)
        x = self.dropout(x, training=training)
        return x

@register_keras_serializable()
class TransformerBlock(layers.Layer):
    """Pre-norm Transformer encoder block.

    x -> x + MHA(LN(x)) -> x + MLP(LN(x)). The attention is self-attention
    (query, key and value are all the normalised tokens).

    Args:
      dim: token width.
      num_heads: number of attention heads; each uses key_dim = dim // num_heads.
      mlp_ratio: expansion factor of the MLP.
      qkv_bias: whether the attention projections use a bias.
      dropout_rate: dropout inside the MLP.
      attn_dropout_rate: ``dropout`` argument of MultiHeadAttention.
      **kwargs: standard Keras layer arguments forwarded to ``Layer``. They are
        needed so the layer can be rebuilt with ``from_config``.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, dropout_rate=0.0,
                 attn_dropout_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.dropout_rate = dropout_rate
        self.attn_dropout_rate = attn_dropout_rate

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=dim//num_heads,
            dropout=attn_dropout_rate, use_bias=qkv_bias
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = MLPBlock(dim, mlp_ratio, dropout_rate)

    def get_config(self):
        config = super().get_config()
        config.update({
            "dim": self.dim,
            "num_heads": self.num_heads,
            "mlp_ratio": self.mlp_ratio,
            "qkv_bias": self.qkv_bias,
            "dropout_rate": self.dropout_rate,
            "attn_dropout_rate": self.attn_dropout_rate,
        })
        return config

    def call(self, x, training=False):
        x_norm = self.norm1(x)
        # MultiHeadAttention(query, value, key): pass x_norm for all three (self-attention).
        attn_output = self.attn(x_norm, x_norm, x_norm, training=training)
        x = x + attn_output  # residual around attention
        x = x + self.mlp(self.norm2(x), training=training)  # residual around the MLP
        return x

@register_keras_serializable()
class CNNDecBlock(layers.Layer):
    """CNN decoder block: 3x3 'same' Conv2D -> optional normalisation -> LeakyReLU(0.02).

    Args:
      out_chans: number of output channels of the convolution.
      norm_layer: layer class instantiated with no arguments for normalisation
        (BatchNormalization by default); pass None to skip normalisation.
      **kwargs: standard Keras layer arguments (``name``, ``dtype``,
        ``trainable``, ...) forwarded to ``Layer``. Accepting them lets Keras'
        default ``from_config`` (which calls ``cls(**config)``) rebuild the layer.
    """
    def __init__(self, out_chans, norm_layer=layers.BatchNormalization, **kwargs):
        super().__init__(**kwargs)
        self.out_chans = out_chans
        self.conv = layers.Conv2D(out_chans, kernel_size=3, padding='same')
        self.norm = norm_layer() if norm_layer is not None else None
        self.act = layers.LeakyReLU(0.02)

    def call(self, x, training=False):
        x = self.conv(x)
        if self.norm is not None:
            # `training` switches BatchNormalization between batch and moving statistics.
            x = self.norm(x, training=training)
        x = self.act(x)
        return x

@register_keras_serializable()
class ViTAutoEncoder(tf.keras.Model):
    """Vision Transformer encoder with a CNN decoder (ViTAE).

    Forward pass (see ``call``):
      1. PatchEmbed turns non-overlapping ``patch_size`` patches into ``enc_dim`` tokens.
      2. A fixed (non-trainable) 2-D sin-cos positional embedding is added.
      3. ``enc_depth`` TransformerBlocks run, followed by LayerNorm.
      4. A Dense layer projects every token to ``patch_size*patch_size*enc_chans``
         values, which are reassembled into a (grid_h*patch_size,
         grid_w*patch_size, enc_chans) feature map.
      5. A stack of CNNDecBlocks and a 1x1 Conv2D produce 2 output channels.

    Args:
      input_size: int or (H, W) spatial size of the input.
      in_chans: number of input channels.
      patch_size: int side length of the square patches.
      enc_chans: channels of the unpatchified feature map fed to the decoder.
      enc_dim: token width (must be divisible by 4 for the sin-cos embedding).
      enc_depth: number of Transformer blocks.
      enc_num_heads: attention heads per block.
      enc_mlp_ratio: MLP expansion factor in each block.
      dec_dims: output channels of the successive CNN decoder blocks (any
        sequence of ints).
      **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """
    def __init__(
        self, input_size, in_chans, patch_size,
        enc_chans=1, enc_dim=128, enc_depth=8, enc_num_heads=8,
        enc_mlp_ratio=4., dec_dims=(16, 16, 16, 16, 16), **kwargs
    ):
        super().__init__(**kwargs)
        self.img_size = input_size
        self.in_chans = in_chans
        self.patch_size = patch_size
        self.enc_chans = enc_chans

        # ViT encoder part
        self.patch_embed = PatchEmbed(input_size, patch_size, in_chans, enc_dim)
        # Reuse PatchEmbed's grid (it accepts an int or an (H, W) pair) so the
        # encoder and the unpatchify step can never disagree.
        self.grid_size = self.patch_embed.grid_size

        # Fixed positional encoding, shape (1, num_patches, enc_dim); broadcast over the batch.
        num_patches = self.patch_embed.num_patches
        pos_embed = get_2d_sincos_pos_embed(enc_dim, self.grid_size, cls_token=False)
        self.pos_embed = tf.Variable(initial_value=pos_embed.reshape(1, num_patches, enc_dim),
                                     trainable=False, dtype=tf.float32, name='pos_embed')

        # Transformer blocks
        self.blocks = [
            TransformerBlock(
                dim=enc_dim, num_heads=enc_num_heads,
                mlp_ratio=enc_mlp_ratio, qkv_bias=True,
                dropout_rate=0.0, attn_dropout_rate=0.0
            ) for _ in range(enc_depth)
        ]
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.encoder_out = layers.Conv2D(1, kernel_size=1, padding='valid')

        # CNN decoder part: each token -> one (patch_size, patch_size, enc_chans) patch
        self.decoder_embed = layers.Dense(patch_size * patch_size * enc_chans)

        # Decoder CNN layers; list() copies so the caller's sequence is never aliased.
        dec_dims = [enc_chans] + list(dec_dims)
        self.decoder_cnn_blocks = [
            CNNDecBlock(dec_dims[i+1])
            for i in range(len(dec_dims)-1)
        ]
        self.decoder_out = layers.Conv2D(2, kernel_size=1, padding='valid')

    def get_config(self):
        config = super().get_config()
        config.update({
            "input_size": self.img_size,  # original input spatial size, as passed to __init__
            "in_chans": self.in_chans,
            "patch_size": self.patch_size,
            "enc_chans": self.enc_chans,
            "enc_dim": self.patch_embed.embed_dim,
            "enc_depth": len(self.blocks),
            "enc_num_heads": self.blocks[0].attn.num_heads,
            "enc_mlp_ratio": self.blocks[0].mlp.fc1.units / self.blocks[0].mlp.fc2.units,
            "dec_dims": [block.conv.filters for block in self.decoder_cnn_blocks]
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def unpatchify(self, x):
        """Convert per-token patches (B, h*w, ph*pw*C) back to an image (B, h*ph, w*pw, C)."""
        ph, pw = self.patch_size, self.patch_size
        h, w = self.grid_size
        batch_size = tf.shape(x)[0]

        # [B, h*w, ph*pw*C] -> [B, h, w, ph, pw, C]
        x = tf.reshape(x, [batch_size, h, w, ph, pw, self.enc_chans])
        # [B, h, w, ph, pw, C] -> [B, C, h, ph, w, pw]
        x = tf.transpose(x, [0, 5, 1, 3, 2, 4])
        # [B, C, h, ph, w, pw] -> [B, C, h*ph, w*pw]
        imgs = tf.reshape(x, [batch_size, self.enc_chans, h * ph, w * pw])
        # [B, C, H, W] -> [B, H, W, C]
        imgs = tf.transpose(imgs, [0, 2, 3, 1])
        return imgs

    def forward_encoder(self, x, training=False):
        """Image (B, H, W, in_chans) -> normalised tokens (B, num_patches, enc_dim)."""
        x = self.patch_embed(x)
        x = x + self.pos_embed  # (1, num_patches, dim) broadcasts over the batch
        for blk in self.blocks:
            x = blk(x, training=training)
        x = self.norm(x)
        return x

    def forward_decoder(self, x, training=False):
        """Tokens -> (pred_dec, pred_enc).

        pred_dec: 2-channel output of the CNN decoder.
        pred_enc: 1-channel 1x1-conv readout of the unpatchified feature map.
        """
        x = self.decoder_embed(x)
        x_enc = self.unpatchify(x)
        # NOTE(review): pred_enc is discarded by `call`, but applying encoder_out
        # keeps its weights built, so they appear in saved checkpoints. Confirm
        # before removing it, or old .keras files may no longer load.
        pred_enc = self.encoder_out(x_enc)

        x_dec = x_enc
        for block in self.decoder_cnn_blocks:
            x_dec = block(x_dec, training=training)
        pred_dec = self.decoder_out(x_dec)

        return pred_dec, pred_enc

    def call(self, inputs, training=False):
        latent = self.forward_encoder(inputs, training=training)
        pred_dec, _ = self.forward_decoder(latent, training=training)
        # Only the decoder output is returned, which keeps predict()/fit() single-output.
        return pred_dec

def build_vitae(input_shape, model_size='base', patch_size=3):
    """Build a ViTAutoEncoder of one of the predefined sizes.

    All sizes use 8 Transformer blocks with 8 heads and an MLP ratio of 4. They
    differ only in channel widths:

        size    enc_chans  enc_dim  dec_dims
        lite    16         32       [16] * 5
        base    32         64       [32] * 5
        large   64         128      [64] * 5

    Args:
      input_shape: (H, W, C) of a single sample, e.g. (15, 15, 3).
      model_size: 'lite', 'base' or 'large'.
      patch_size: side of the square ViT patches (default 3). H and W must be
        divisible by it; otherwise the patch grid drops the trailing rows and
        columns and the model output would be smaller than its input.

    Returns:
      A new (not yet built or compiled) ViTAutoEncoder.

    Raises:
      ValueError: for an unknown model_size, or a patch_size that does not
        divide H and W.
    """
    size_configs = {
        'lite': dict(enc_chans=16, enc_dim=32, dec_dims=[16, 16, 16, 16, 16]),
        'base': dict(enc_chans=32, enc_dim=64, dec_dims=[32, 32, 32, 32, 32]),
        'large': dict(enc_chans=64, enc_dim=128, dec_dims=[64, 64, 64, 64, 64]),
    }
    if model_size not in size_configs:
        raise ValueError(f"Unsupported model size: {model_size}")

    in_h, in_w, in_c = input_shape
    if in_h % patch_size or in_w % patch_size:
        raise ValueError(
            f"Input size ({in_h}, {in_w}) must be divisible by patch_size={patch_size}"
        )

    return ViTAutoEncoder(
        input_size=(in_h, in_w), in_chans=in_c, patch_size=patch_size,
        enc_depth=8, enc_num_heads=8, enc_mlp_ratio=4,
        **size_configs[model_size],
    )

#### Training / Evaluation Loop (one model per sensor count)

In [None]:
# Evaluate the saved ViTAE checkpoints, one per sensor count.
# Use the flags instead of (un)commenting code:
#   TRAIN_VITAE            - build, compile and train before loading; the best
#                            val_loss checkpoint is written to model_path.
#   SAVE_VITAE_PREDICTIONS - also predict the perturbed test set and save both
#                            de-normalised predictions to pred_dir. Their metrics
#                            are computed by the metrics cell that loads the
#                            saved predictions.
TRAIN_VITAE = False
SAVE_VITAE_PREDICTIONS = False
sensor_list = [5, 10, 15, 20, 25, 30]
input_shape = (15, 15, 3)

for sensor_num in sensor_list:
    (x_train_input, y_train_output, x_val_input, y_val_output, x_test_input,
     x_perturbed_test_input, y_test_output, min_vals, max_vals, y_test, x_test) = load_my_data(sensor_num, method=method)
    print(f'Dataset for {sensor_num} is ready.')
    model_path = f'{model_dir}{method}_ViTAE_base-ImageReconstruction_{sensor_num}.keras'

    if TRAIN_VITAE:
        # Building/compiling is only needed for training; evaluation always
        # loads the checkpoint below.
        vitae_model = build_vitae(input_shape, model_size='base')
        vitae_model.compile(optimizer='adam', loss='mean_squared_error')
        history = vitae_model.fit(
            x_train_input, y_train_output,
            epochs=100,
            batch_size=8,
            validation_data=(x_val_input, y_val_output),
            callbacks=[
                ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-5),
                EarlyStopping(monitor='val_loss', patience=20, verbose=1, restore_best_weights=True),
                ModelCheckpoint(filepath=model_path, monitor='val_loss',
                                save_best_only=True, verbose=1, mode='min'),
            ],
            shuffle=True,
        )

    vitae_model = load_model(model_path)

    splits = [('normal', '', x_test_input)]
    if SAVE_VITAE_PREDICTIONS:
        splits.append(('perturbed', '_perturbed', x_perturbed_test_input))

    for split, suffix, x_eval in splits:
        start_time = time.perf_counter()
        pred_test = vitae_model.predict(x_eval)
        if split == 'normal':
            # Only inference on the clean test set is timed.
            print(f"inference time: {time.perf_counter() - start_time:.2f} seconds")
        if SAVE_VITAE_PREDICTIONS:
            pred_test_restored = restore_original_scale(pred_test, min_vals, max_vals)
            np.save(f'{pred_dir}{method}_vitae_base_{sensor_num}_y{suffix}_predict.npy', pred_test_restored)
            print(f"Saved {split} test set results: sensor_num={sensor_num}")

In [None]:
# Train one 'base' ViTAE per sensor count on the TRAIN_METHOD dataset variant and
# keep the best (lowest val_loss) checkpoint in model_dir.
# TRAIN_METHOD is used both to load the data and to name the checkpoint, so the
# two cannot drift apart. The notebook-wide `method` is deliberately not used here.
TRAIN_METHOD = 1
batch_size = 8
epochs = 100
sensor_list = [5, 10, 15, 20, 25, 30]
input_shape = (15, 15, 3)

for sensor_num in sensor_list:
    (x_train_input, y_train_output, x_val_input, y_val_output, x_test_input,
     x_perturbed_test_input, y_test_output, min_vals, max_vals, y_test, x_test) = load_my_data(sensor_num, method=TRAIN_METHOD)
    print(f'Dataset for {sensor_num} is ready.')
    vitae_model = build_vitae(input_shape, model_size='base')
    vitae_model.compile(optimizer='adam', loss='mean_squared_error')

    # Save only the best model seen so far (by validation loss).
    model_checkpoint = ModelCheckpoint(
        filepath=f'{model_dir}{TRAIN_METHOD}_ViTAE_base-ImageReconstruction_{sensor_num}.keras',
        monitor='val_loss', save_best_only=True, verbose=1, mode='min'
    )
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-5)
    early_stop = EarlyStopping(monitor='val_loss', patience=20, verbose=1, restore_best_weights=True)

    history = vitae_model.fit(
        x_train_input, y_train_output,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_val_input, y_val_output),
        callbacks=[reduce_lr, early_stop, model_checkpoint],
        shuffle=True
    )

In [None]:
# ViTAE metrics from the predictions saved on disk. The trained models were lost
# because of an incorrect model dir during training, but the saved predictions
# are intact.
# y_test is reloaded for every sensor count with the same `method` as the
# predictions. This keeps the cell from silently reusing whatever `y_test` an
# earlier cell left in the kernel (e.g. the method-1 training cell).
sensor_list = [5, 10, 15, 20, 25, 30]
for sensor_num in sensor_list:
    *_, y_test, _ = load_my_data(sensor_num, method=method)  # y_test is the 10th of 11 return values
    for split, suffix in (('normal', ''), ('perturbed', '_perturbed')):
        pred_test_restored = np.load(f'{pred_dir}{method}_vitae_base_{sensor_num}_y{suffix}_predict.npy')
        if len(pred_test_restored) != len(y_test):
            raise ValueError(
                f"{split} ViTAE predictions for sensor_num={sensor_num} have "
                f"{len(pred_test_restored)} samples but y_test has {len(y_test)}"
            )

        ssim_scores = []
        vitae_metrics = []
        for pred_i, true_i in zip(pred_test_restored, y_test):
            # U (channel 0) and V (channel 1) are scored separately and averaged.
            # Single-channel 2-D input is SSIM's default, so the deprecated
            # `multichannel=False` argument is not needed.
            ssim_u, ssim_v = (
                ssim(pred_i[:, :, c], true_i[:, :, c],
                     data_range=true_i[:, :, c].max() - true_i[:, :, c].min())
                for c in (0, 1)
            )
            ssim_scores.append((ssim_u + ssim_v) / 2)

            u_metrics = calculate_all_metrics(true_i[:, :, 0], pred_i[:, :, 0])
            v_metrics = calculate_all_metrics(true_i[:, :, 1], pred_i[:, :, 1])
            vitae_metrics.append({key: (u_metrics[key] + v_metrics[key]) / 2 for key in u_metrics})

        vitae_avg_metrics = {key: np.mean([m[key] for m in vitae_metrics]) for key in vitae_metrics[0]}
        vitae_avg_metrics['SSIM'] = np.mean(ssim_scores)

        print("\n=== ViTAE Additional Evaluation Metrics ===")
        for key, value in vitae_avg_metrics.items():
            print(f"{key}: {value:.4f}")

        np.save(f'{metric_dir}{method}_vitae{suffix}_metrics_{sensor_num}.npy', vitae_avg_metrics)
        print(f"Saved {split} test set results: sensor_num={sensor_num}")

### Kriging

#### Reconstruction Loop (one run per sensor count)

In [None]:
# Kriging baseline: time ordinary-kriging reconstruction of the clean test set
# for each sensor count.
# SAVE_KRIGING_PREDICTIONS: also reconstruct the perturbed test set and save both
# de-normalised predictions to pred_dir. Their metrics are computed by the
# Kriging metrics cell that loads the saved predictions.
SAVE_KRIGING_PREDICTIONS = False
sensor_list = [5, 10, 15, 20, 25, 30]
for sensor_num in sensor_list:
    (x_train_input, y_train_output, x_val_input, y_val_output, x_test_input,
     x_perturbed_test_input, y_test_output, min_vals, max_vals, y_test, x_test) = load_my_data(sensor_num, method=method)
    print(f'Dataset for {sensor_num} is ready.')

    splits = [('normal', '', x_test_input)]
    if SAVE_KRIGING_PREDICTIONS:
        splits.append(('perturbed', '_perturbed', x_perturbed_test_input))

    for split, suffix, x_eval in splits:
        start_time = time.perf_counter()
        # NOTE(review): the ground truth y_test_output is passed to the
        # reconstruction. Confirm it is only used for output shape/bookkeeping
        # and not as information leaked into the kriging estimate.
        pred_test = apply_kriging_reconstruction(x_eval, y_test_output, method='ordinary')
        if split == 'normal':
            print(f"inference time: {time.perf_counter() - start_time:.2f} seconds")
        if SAVE_KRIGING_PREDICTIONS:
            pred_kriging = restore_original_scale(pred_test, min_vals, max_vals)
            np.save(f'{pred_dir}{method}_kriging_base_{sensor_num}_y{suffix}_predict.npy', pred_kriging)
            print(f"Saved {split} test set results: sensor_num={sensor_num}")

In [None]:
# Kriging metrics from the predictions saved on disk. This cell (originally
# labelled "Temporarily calculate fac2") recomputes the metrics without
# re-running the kriging itself.
# y_test is reloaded for every sensor count with the same `method` as the
# predictions, instead of reusing whatever `y_test` an earlier cell left in the kernel.
sensor_list = [5, 10, 15, 20, 25, 30]
for sensor_num in sensor_list:
    *_, y_test, _ = load_my_data(sensor_num, method=method)  # y_test is the 10th of 11 return values
    for split, suffix in (('normal', ''), ('perturbed', '_perturbed')):
        pred_kriging = np.load(f'{pred_dir}{method}_kriging_base_{sensor_num}_y{suffix}_predict.npy')
        if len(pred_kriging) != len(y_test):
            raise ValueError(
                f"{split} Kriging predictions for sensor_num={sensor_num} have "
                f"{len(pred_kriging)} samples but y_test has {len(y_test)}"
            )

        kriging_ssim_scores = []
        kriging_metrics = []
        for pred_i, true_i in zip(pred_kriging, y_test):
            # U (channel 0) and V (channel 1) are scored separately and averaged.
            # Single-channel 2-D input is SSIM's default, so the deprecated
            # `multichannel=False` argument is not needed.
            ssim_u, ssim_v = (
                ssim(pred_i[:, :, c], true_i[:, :, c],
                     data_range=true_i[:, :, c].max() - true_i[:, :, c].min())
                for c in (0, 1)
            )
            kriging_ssim_scores.append((ssim_u + ssim_v) / 2)

            u_metrics = calculate_all_metrics(true_i[:, :, 0], pred_i[:, :, 0])
            v_metrics = calculate_all_metrics(true_i[:, :, 1], pred_i[:, :, 1])
            kriging_metrics.append({key: (u_metrics[key] + v_metrics[key]) / 2 for key in u_metrics})

        kriging_avg_metrics = {key: np.mean([m[key] for m in kriging_metrics]) for key in kriging_metrics[0]}
        kriging_avg_metrics['SSIM'] = np.mean(kriging_ssim_scores)

        print("\n=== Kriging Additional Evaluation Metrics ===")
        for key, value in kriging_avg_metrics.items():
            print(f"{key}: {value:.4f}")

        np.save(f'{metric_dir}{method}_kriging{suffix}_metrics_{sensor_num}.npy', kriging_avg_metrics)
        print(f"Saved {split} test set results: sensor_num={sensor_num}")

# QR-Pivoted Sensor Placement (POD basis + column-pivoted QR)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import qr
from sklearn.model_selection import train_test_split

def build_pod_basis(y_data, r=10):
    """
    Build a POD (proper orthogonal decomposition) basis from snapshot data.

    Each snapshot y_data[n] is flattened in C order (the component axis varies
    fastest) into one column of a (H*W*C, N) data matrix. The mean snapshot is
    subtracted and the leading left singular vectors are kept.

    Parameters:
        y_data: (N, H, W, C) snapshots, e.g. (N, 15, 15, 2) u/v flow fields
        r: Number of POD modes to retain, 1 <= r <= min(H*W*C, N)

    Returns:
        Psi_r: (H*W*C, r) orthonormal POD basis matrix ((450, r) for 15x15x2 data)
        explained_variance: (r,) cumulative energy fraction captured by the first
            1..r modes (all zeros if the centred data has no energy)

    Raises:
        ValueError: if r is outside [1, min(H*W*C, N)].
    """
    N, H, W, C = y_data.shape

    # One row per (cell, component) degree of freedom, one column per snapshot.
    data_matrix = y_data.reshape(N, H * W * C).T  # (H*W*C, N)

    n_modes = min(data_matrix.shape)
    if not 1 <= r <= n_modes:
        # Previously r > n_modes returned a truncated basis and then crashed with
        # an IndexError in the energy printout.
        raise ValueError(f"r must be between 1 and {n_modes}, got {r}")

    data_centered = data_matrix - data_matrix.mean(axis=1, keepdims=True)

    # Left singular vectors of the centred data are the POD modes.
    U, S, _ = np.linalg.svd(data_centered, full_matrices=False)

    energy = S ** 2
    total_energy = energy.sum()
    if total_energy > 0:
        explained_variance = np.cumsum(energy) / total_energy
    else:
        # Constant snapshots: avoid 0/0 -> NaN.
        explained_variance = np.zeros_like(energy)

    Psi_r = U[:, :r]

    print(f"First {r} modes explain {explained_variance[r-1]:.4f} of total energy")

    return Psi_r, explained_variance[:r]

def qr_sensor_ranking(Psi_r, grid_width=15, component_names=('u', 'v')):
    """
    Rank all candidate measurements by QR decomposition with column pivoting.

    Each row of Psi_r is one candidate measurement (one grid cell and one
    velocity component, flattened in C order so the component index varies
    fastest). Pivoted QR of Psi_r.T chooses, at each step, the remaining column
    (= candidate) with the largest residual norm, so P[0] is the most
    informative candidate.

    NOTE(review): Psi_r.T has only r rows, so only the first r pivots are chosen
    greedily. The relative order of the remaining len(P) - r candidates is not a
    meaningful importance ranking.

    Parameters:
        Psi_r: (H*W*C, r) POD basis matrix, (450, r) for 15x15 cells x (u, v)
        grid_width: number of grid columns W (default 15)
        component_names: labels of the C interleaved components (default ('u', 'v'))

    Returns:
        ranking_results: dict with
            'global_ranking': pivot index array P (most important first),
            'spatial_info': one dict per rank with 'global_rank' (1-based),
                'position' (row, col), 'component', 'importance_score' and
                'percentile',
            'importance_scores': linearly decaying scores, 1.0 for rank 1.
    """
    # Only the pivot order is needed, so skip forming Q (mode='r' returns (R, P)).
    _, P = qr(Psi_r.T, mode='r', pivoting=True)

    # P is the importance ranking: P[0] is the index of the most important row of Psi_r
    print(f"QR decomposition completed, total {len(P)} positions")

    n_total = len(P)
    n_components = len(component_names)
    ranking_results = {
        'global_ranking': P,
        'spatial_info': [],
        'importance_scores': []
    }

    for rank, idx in enumerate(P):
        # Flattened index -> (cell, component) -> (row, col)
        spatial_idx, component = divmod(idx, n_components)
        row, col = divmod(spatial_idx, grid_width)

        # Linear importance score in (0, 1], 1.0 for the top pivot
        importance_score = (n_total - rank) / n_total

        ranking_results['spatial_info'].append({
            'global_rank': rank + 1,
            'position': (row, col),
            'component': component_names[component],
            'importance_score': importance_score,
            'percentile': (n_total - rank) / n_total * 100
        })
        ranking_results['importance_scores'].append(importance_score)

    return ranking_results

def analyze_current_sensors(current_sensor_positions, ranking_results):
    """
    Look up where each current sensor falls in the global QR ranking.

    Parameters:
        current_sensor_positions: [(row, col, component), ...] current sensor measurements
        ranking_results: output of qr_sensor_ranking

    Returns:
        sensor_analysis: one dict per sensor found in the ranking, with
            'sensor_position', 'global_rank', 'importance_score', 'percentile'
            and 'category', sorted by global_rank (most important first).
            Sensors missing from the ranking are reported with a warning and skipped.
    """
    spatial_info = ranking_results['spatial_info']

    # Index the ranking once by (row, col, component) instead of scanning every
    # entry for every sensor. setdefault keeps the first (best-ranked) match,
    # like the previous linear search with `break`.
    lookup = {}
    for info in spatial_info:
        lookup.setdefault((*info['position'], info['component']), info)

    # Categorise against the actual ranking size rather than a fixed 450.
    total = len(spatial_info)

    sensor_analysis = []
    for sensor_pos in current_sensor_positions:
        row, col, comp = sensor_pos
        info = lookup.get((row, col, comp))
        if info is None:
            print(f"Warning: Sensor position {sensor_pos} not found")
            continue

        sensor_analysis.append({
            'sensor_position': sensor_pos,
            'global_rank': info['global_rank'],
            'importance_score': info['importance_score'],
            'percentile': info['percentile'],
            'category': categorize_importance(info['global_rank'], total=total)
        })

    # Sort by importance
    sensor_analysis.sort(key=lambda x: x['global_rank'])

    return sensor_analysis

def categorize_importance(rank, total=450):
    """Map a 1-based rank to an importance bucket by its top percentile.

    rank: position in the ranking (1 = most important)
    total: number of ranked candidates (450 = 15*15 cells x 2 components by default)
    """
    percentile = rank / total * 100

    buckets = (
        (10, "Critical (Top 10%)"),
        (30, "Important (Top 30%)"),
        (60, "Moderate (Top 60%)"),
    )
    for upper_bound, label in buckets:
        if percentile <= upper_bound:
            return label
    return "Low (Bottom 40%)"

def extract_current_sensor_positions(x_data):
    """
    Read the sensor layout from the mask channel (index 2) of the first sample.

    Parameters:
        x_data: (N, H, W, C) input array with C >= 3; channel 2 equals 1 where a sensor sits

    Returns:
        A list with two entries per sensor cell, (row, col, 'u') followed by
        (row, col, 'v'), with cells in row-major order.
    """
    # Only sample 0 is inspected; the layout is assumed to be the same for all samples.
    first_mask = x_data[0, :, :, 2]
    rows, cols = np.nonzero(first_mask == 1)

    # Each sensor cell is treated as measuring both velocity components.
    return [
        (row, col, component)
        for row, col in zip(rows, cols)
        for component in ('u', 'v')
    ]

def visualize_sensor_importance(ranking_results, current_sensors=None, top_k=20, grid_shape=(15, 15)):
    """
    Visualize sensor importance distribution.

    Draws one heatmap per velocity component (u, v) built from
    ranking_results['spatial_info'], overlays current sensor positions as red
    crosses, then prints the top_k entries with the smallest 'global_rank'.

    Parameters:
        ranking_results: dict with a 'spatial_info' list; each entry must provide
            'position' (row, col), 'component' ('u'; any other value is drawn on
            the v panel), 'importance_score' and 'global_rank'.
        current_sensors: optional list of (row, col, component) tuples to mark.
        top_k: number of best-ranked positions to print.
        grid_shape: (rows, cols) of the heatmaps; defaults to the 15 x 15 grid.
    """
    spatial_info = ranking_results['spatial_info']

    # 1. Importance heatmaps, one per component
    importance_maps = {'u': np.zeros(grid_shape), 'v': np.zeros(grid_shape)}
    for info in spatial_info:
        row, col = info['position']
        # Two-way split: everything that is not 'u' goes to the v map
        comp = 'u' if info['component'] == 'u' else 'v'
        importance_maps[comp][row, col] = info['importance_score']

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    for ax, comp in zip(axes, ('u', 'v')):
        # NOTE(review): fixed [0, 1] colour range assumes normalised scores
        im = ax.imshow(importance_maps[comp], cmap='viridis', vmin=0, vmax=1)
        ax.set_title(f'{comp.upper()} Component Importance')
        ax.set_xlabel('Column')
        ax.set_ylabel('Row')
        plt.colorbar(im, ax=ax)

        # Mark current sensor positions measuring this component
        if current_sensors:
            marks = [(sensor[0], sensor[1]) for sensor in current_sensors if sensor[2] == comp]
            if marks:
                rows, cols = zip(*marks)
                # imshow puts columns on the x axis and rows on the y axis
                ax.scatter(cols, rows, c='red', s=100, marker='x')

    plt.tight_layout()
    plt.show()

    # 2. Top-k important positions list. Sort by rank so the listing is correct
    # even if spatial_info is not already in rank order; clamp negative top_k.
    top_entries = sorted(spatial_info, key=lambda info: info['global_rank'])[:max(top_k, 0)]
    print(f"\n=== Top {top_k} Most Important Sensor Positions ===")
    for info in top_entries:
        print(f"Rank {info['global_rank']:3d}: Position {info['position']}, "
              f"Component {info['component']}, Score {info['importance_score']:.4f}")

def main_qr_analysis(sensor_num=5, data_dir='/content/drive/MyDrive/TorchDA/dataset', r=40):
    """
    Main analysis function: QR-based ranking of candidate sensor positions.

    Loads the two training splits, builds a POD basis from the stacked targets,
    ranks candidate positions with qr_sensor_ranking, scores the sensors present
    in the input mask, and visualizes the result.

    Parameters:
        sensor_num: suffix of the files to load ({x1,y1,x2,y2}_data_{sensor_num}.npy).
        data_dir: directory holding those files. Defaults to the original Colab
            Drive path; pass e.g. the notebook-level out_dir to run elsewhere.
        r: value passed as r to build_pod_basis (presumably the number of
            retained POD modes). Default 40.

    Returns:
        (ranking_results, sensor_analysis): outputs of qr_sensor_ranking and
        analyze_current_sensors.
    """
    print(f"=== QR Sensor Position Analysis (sensor_num={sensor_num}) ===\n")

    def _load(prefix):
        # Memory-map so arrays are only read from disk when actually used
        return np.load(os.path.join(data_dir, f'{prefix}_data_{sensor_num}.npy'), mmap_mode='r')

    # 1. Load data
    x1_data = _load('x1')
    y1_data = _load('y1')
    x2_data = _load('x2')
    y2_data = _load('y2')

    # Combine training targets (only the targets are needed for the POD basis)
    y_train = np.concatenate([y1_data, y2_data], axis=0)

    print(f"Training data shape: {y_train.shape}")

    # 2. Build POD basis
    print("\n=== Step 1: Build POD Basis ===")
    Psi_r, _explained_var = build_pod_basis(y_train, r=r)

    # 3. QR ranking analysis
    print("\n=== Step 2: QR Sensor Position Ranking ===")
    ranking_results = qr_sensor_ranking(Psi_r)

    # 4. Analyze current sensor positions.
    # extract_current_sensor_positions reads only sample 0, which is the first
    # sample of split 1 (or of split 2 if split 1 is empty) — the same sample as
    # indexing the concatenated inputs, without copying every input into memory.
    print("\n=== Step 3: Analyze Current Sensor Positions ===")
    x_first_split = x1_data if len(x1_data) > 0 else x2_data
    current_sensors = extract_current_sensor_positions(x_first_split)
    sensor_analysis = analyze_current_sensors(current_sensors, ranking_results)

    print("Current sensor analysis results:")
    for analysis in sensor_analysis:
        print(f"Sensor {analysis['sensor_position']}: "
              f"Global rank #{analysis['global_rank']}, "
              f"Importance score {analysis['importance_score']:.4f}, "
              f"Category: {analysis['category']}")

    # 5. Visualization
    print("\n=== Step 4: Visualization Analysis Results ===")
    visualize_sensor_importance(ranking_results, current_sensors, top_k=20)

    return ranking_results, sensor_analysis

# # Run analysis
# if __name__ == "__main__":
#     # Analyze case with 5 sensors
#     ranking_results, sensor_analysis = main_qr_analysis(sensor_num=5)

    # Can also batch analyze different sensor numbers
    # for sensor_num in [5, 10, 15, 20, 25, 30]:
    #     print(f"\n{'='*50}")
    #     main_qr_analysis(sensor_num=sensor_num)

In [None]:
# Run the QR sensor-placement analysis for the 5-sensor dataset files.
ranking_results, sensor_analysis = main_qr_analysis(5)