In [None]:
import sys
import numpy as np
import pandas as pd
from scipy.interpolate import griddata
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Root folder of the CSV datasets used by the cells below (note: shadows the built-in `dir`).
dir = "dataset/"

# Utilities

In [None]:
def generate_perturbed_test_data(X, y, grid_shape, max_perturbation=1, seed=None):
    """
    Generate a copy of a dataset whose sensor layout is randomly shifted.

    The sensor layout is read from the mask channel of the FIRST sample
    (``X[0, :, :, 2] == 1``) and assumed to be shared by all samples. Each sensor
    is moved independently by a random integer offset in
    ``[-max_perturbation, max_perturbation]`` along each axis, clipped to the grid.
    The same perturbed layout is applied to every sample: U/V values are re-read
    from the ground truth ``y`` at the new positions and spread over the grid with
    nearest-neighbour interpolation.

    Parameters:
      X: Original data, shape (N, n_rows, n_cols, 3); channel 2 is the sensor mask
      y: Ground truth, shape (N, n_rows, n_cols, 2) holding U and V
      grid_shape: Grid shape (n_rows, n_cols), e.g. (15, 15)
      max_perturbation: Maximum shift per axis in grid cells, default is 1
      seed: Optional int seed or np.random.Generator for reproducible perturbations.
            When None, the legacy global ``np.random`` state is used, so an earlier
            ``np.random.seed(...)`` call still controls the result.

    Returns:
      X_perturbed: Array shaped like X; channels 0/1 hold the interpolated U/V
                   fields and channel 2 holds the shared perturbed mask.

    Raises:
      ValueError: if max_perturbation is negative or the template mask has no sensors.

    Notes:
      - Two sensors can be shifted onto the same cell; the mask then has fewer
        ones than the template (a note is printed when this happens).
      - Interpolation uses normalized [0, 1] coordinates on both axes, whereas
        generate_data_csv uses the physical X/Y coordinates from the CSV. The
        nearest-sensor assignment may therefore differ from generate_data_csv
        unless the physical grid is uniformly spaced with the same scale on
        both axes.
    """
    if max_perturbation < 0:
        raise ValueError(f"max_perturbation must be non-negative, got {max_perturbation}")

    n_rows, n_cols = grid_shape

    # Normalized grid coordinates in [0, 1] (see Notes above).
    xv, yv = np.meshgrid(np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))

    # Sensor layout template taken from the first sample's mask channel.
    template_rows, template_cols = np.where(X[0, :, :, 2] == 1)
    n_sensors = template_rows.size
    if n_sensors == 0:
        raise ValueError("No sensors found in the template mask X[0, :, :, 2]")

    # Both callables take (low, high) with an exclusive upper bound.
    draw = np.random.randint if seed is None else np.random.default_rng(seed).integers

    new_rows = np.empty(n_sensors, dtype=int)
    new_cols = np.empty(n_sensors, dtype=int)
    for j, (row, col) in enumerate(zip(template_rows, template_cols)):
        # Draw order (row offset, then column offset, per sensor) is kept so
        # results under a given global np.random seed are unchanged.
        d_row = draw(-max_perturbation, max_perturbation + 1)
        d_col = draw(-max_perturbation, max_perturbation + 1)
        new_rows[j] = np.clip(row + d_row, 0, n_rows - 1)
        new_cols[j] = np.clip(col + d_col, 0, n_cols - 1)

    new_mask = np.zeros(grid_shape, dtype=np.float32)
    new_mask[new_rows, new_cols] = 1
    n_distinct = int(new_mask.sum())
    if n_distinct < n_sensors:
        print(f"Note: {n_sensors - n_distinct} perturbed sensor(s) collided; "
              f"mask has {n_distinct} distinct positions")

    print("Generated unified perturbed sensor mask, all test samples will share this mask")

    # The layout is identical for every sample, so the nearest-sensor lookup is
    # computed once: interpolating the sensor *indices* gives, for each grid cell,
    # the sensor whose value it copies. This uses the same points and query grid as
    # interpolating the values directly, so the assignment is identical.
    sensor_coords = np.column_stack((xv[new_rows, new_cols], yv[new_rows, new_cols]))
    nearest = griddata(sensor_coords, np.arange(n_sensors), (xv, yv), method='nearest').astype(int)

    # Ground-truth U/V at the perturbed sensors: shape (N, n_sensors, 2).
    sensor_values = y[:, new_rows, new_cols, :]

    X_perturbed = np.zeros_like(X)
    # Indexing the sensor axis with the (n_rows, n_cols) lookup yields (N, n_rows, n_cols, 2).
    X_perturbed[..., :2] = sensor_values[:, nearest, :]
    X_perturbed[..., 2] = new_mask

    return X_perturbed

def visualize_perturbation(X, X_perturbed, y, sample_index=0):
    """
    Plot one sample's ground truth next to its original and perturbed sensor data.

    Layout (2 x 3):
      top row    -- ground-truth U, U interpolated from the original sensors,
                    U interpolated from the perturbed sensors (each with a colorbar)
      bottom row -- ground-truth V, original sensor mask, perturbed sensor mask
                    (masks drawn semi-transparent, without colorbars)

    Axes span normalized [0, 1] coordinates on both dimensions, derived from the
    mask's shape rather than from physical CSV coordinates.

    Parameters:
      X: Original data, shape (N, n_rows, n_cols, 3)
      X_perturbed: Perturbed data, same shape as X
      y: Ground truth data, shape (N, n_rows, n_cols, 2)
      sample_index: Index of sample to visualize, default is 0
    """
    sample = X[sample_index]
    sample_perturbed = X_perturbed[sample_index]
    truth = y[sample_index]

    height, width = sample[:, :, 2].shape
    xs = np.linspace(0, 1, width)
    ys = np.linspace(0, 1, height)
    extent = (xs.min(), xs.max(), ys.min(), ys.max())

    # (row, col, image, title, extra imshow kwargs, draw colorbar?)
    panels = [
        (0, 0, truth[:, :, 0], "Ground Truth U", {}, True),
        (0, 1, sample[:, :, 0], "Original Interpolated U", {}, True),
        (0, 2, sample_perturbed[:, :, 0], "Perturbed Interpolated U", {}, True),
        (1, 0, truth[:, :, 1], "Ground Truth V", {}, True),
        (1, 1, sample[:, :, 2], "Original Sensor Positions", {'alpha': 0.5}, False),
        (1, 2, sample_perturbed[:, :, 2], "Perturbed Sensor Positions", {'alpha': 0.5}, False),
    ]

    fig, axs = plt.subplots(2, 3, figsize=(18, 10))
    for r, c, image, title, extra, with_colorbar in panels:
        ax = axs[r, c]
        handle = ax.imshow(image, origin='lower', extent=extent, **extra)
        ax.set_title(title)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        if with_colorbar:
            fig.colorbar(handle, ax=ax)

    plt.tight_layout()
    plt.show()

In [None]:
def get_uniform_sensor_positions_center(grid_shape, sensor_num):
    """
    Pick roughly uniformly spread sensor positions on a 2D grid.

    The grid is split into m x n rectangular regions with m = floor(sqrt(sensor_num))
    and n = ceil(sensor_num / m), so m * n >= sensor_num. The (integer-truncated)
    center of each region is a candidate. If there are more candidates than
    sensor_num, sensor_num of them are taken at evenly spaced indices of the
    row-major candidate list.

    Parameters:
      grid_shape: Shape of the 2D grid (n_rows, n_cols), e.g. (15, 15)
      sensor_num: Required number of sensors (positive integer)

    Returns:
      sensor_positions: Integer array of shape (sensor_num, 2); each row is
                        [row_index, col_index]

    Raises:
      ValueError: if sensor_num < 1.

    Note:
      When a region is narrower than one cell (m > n_rows or n > n_cols), several
      region centers truncate to the same index, so returned positions can repeat.
    """
    n_rows, n_cols = grid_shape
    if sensor_num < 1:
        # Previously this surfaced as a ZeroDivisionError (m == 0).
        raise ValueError(f"sensor_num must be a positive integer, got {sensor_num}")

    m = int(np.floor(np.sqrt(sensor_num)))  # regions along rows
    n = int(np.ceil(sensor_num / m))        # regions along columns

    # Region i along an axis of length L with k regions spans [i*L/k, (i+1)*L/k);
    # its center index is int((i + 0.5) * L / k).
    region_h = n_rows / m
    region_w = n_cols / n
    row_centers = [int((i + 0.5) * region_h) for i in range(m)]
    col_centers = [int((j + 0.5) * region_w) for j in range(n)]

    # Cartesian product in row-major order.
    candidate_positions = np.array([[r, c] for r in row_centers for c in col_centers])

    # m * n >= sensor_num always holds, so only the "too many" case needs thinning.
    total_candidates = candidate_positions.shape[0]
    if total_candidates > sensor_num:
        indices = np.linspace(0, total_candidates - 1, sensor_num, dtype=int)
        return candidate_positions[indices]
    return candidate_positions

In [None]:
def load_optimal_sensors(method, sensor_num, sensors_dir='./optimal_sensors/'):
    """
    Load saved sensor positions for a given method and sensor count.

    Looks for ``optimal_sensors_method{method}_num{sensor_num}.npy`` inside
    ``sensors_dir`` and returns its contents unchanged.

    Parameters:
        method: Method number (e.g. 0 or 1); part of the filename
        sensor_num: Number of sensors; part of the filename
        sensors_dir: Directory where the .npy files are saved

    Returns:
        positions: The stored array. Callers index grids with positions[:, 0]
                   (row) and positions[:, 1] (col), so it is expected to be an
                   integer array of shape (sensor_num, 2).

    Raises:
        FileNotFoundError: if the file does not exist. Callers catch this to
                           fall back to uniform positions or to skip a combination.
    """
    # Local import: the module-level `import os` lives in a later notebook cell,
    # so without this a fresh-kernel "Run All" fails with NameError on first use.
    import os

    filename = f'optimal_sensors_method{method}_num{sensor_num}.npy'
    filepath = os.path.join(sensors_dir, filename)

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Optimal sensor file not found: {filepath}")

    positions = np.load(filepath)  # shape: (sensor_num, 2)
    print(f"Loaded optimal sensor positions: {filename}, shape: {positions.shape}")

    return positions

In [None]:
def generate_data_csv(csv_file, sensor_num=10, method=0, use_optimal=True, sensors_dir='./optimal_sensors/'):
    """
    Read CSV data and generate interpolated input and ground truth data.

    Every distinct 'Time' value becomes one sample. The U/V fields of each time
    step are pivoted onto the (Y, X) grid and sampled at a fixed set of sensor
    positions. Those samples are then spread back over the grid with
    nearest-neighbour interpolation in physical (X, Y) coordinates.

    Parameters:
      csv_file: CSV file path (e.g., '0deg_1.csv'); must contain the columns
                'Time', 'X', 'Y', 'U', 'V'
      sensor_num: Number of sensors per sample (default is 10)
      method: Method number, used to select the optimal sensor position file
      use_optimal: Whether to use optimal sensor positions. False, or a missing
                   optimal file, uses get_uniform_sensor_positions_center instead.
      sensors_dir: Directory for optimal sensor position files

    Returns:
      X: float32 array (N, n_y, n_x, 3); a (15, 15) grid for the bundled datasets (assumed)
         Channel 0: Interpolated field of U component
         Channel 1: Interpolated field of V component
         Channel 2: Sensor position mask (1 at sensor points, 0 elsewhere)
      y: float32 array (N, n_y, n_x, 2) with the complete U, V fields

    Notes:
      Each time step is assumed to cover every (X, Y) combination. Otherwise
      the pivoted grid contains NaNs or has a mismatched shape.
    """
    df = pd.read_csv(csv_file)

    # groupby yields each time step once, with keys sorted like np.sort(unique),
    # instead of re-scanning the whole frame with a boolean mask per step.
    groups = df.groupby('Time', sort=True)
    num_samples = groups.ngroups

    # Grid axes (assuming all samples share the same X and Y coordinates).
    x_coords = np.sort(df['X'].unique())
    y_coords = np.sort(df['Y'].unique())

    # Rows follow sorted Y and columns follow sorted X, matching pivot(index='Y', columns='X').
    xv, yv = np.meshgrid(x_coords, y_coords)
    grid_shape = xv.shape  # expected (15, 15)

    X = np.zeros((num_samples, grid_shape[0], grid_shape[1], 3), dtype=np.float32)
    y = np.zeros((num_samples, grid_shape[0], grid_shape[1], 2), dtype=np.float32)

    # Select sensor position strategy
    if use_optimal:
        try:
            sensor_positions = load_optimal_sensors(method, sensor_num, sensors_dir)
            print(f"Using optimal sensor positions (Method {method})")
        except FileNotFoundError as e:
            print(f"Warning: {e}")
            print("Falling back to uniform distribution sensor positions")
            sensor_positions = get_uniform_sensor_positions_center(grid_shape, sensor_num)
    else:
        sensor_positions = get_uniform_sensor_positions_center(grid_shape, sensor_num)
        print("Using uniform distribution sensor positions")

    print(f"Sensor positions: {sensor_positions[:5]}...")  # Show first 5 positions

    rows = sensor_positions[:, 0]
    cols = sensor_positions[:, 1]

    # The sensor layout is the same for every time step, so the mask and the
    # nearest-sensor lookup are built once. Interpolating the sensor *indices*
    # over the same points/query grid gives, per cell, the sensor whose value it
    # copies; this is identical to interpolating the values directly.
    sensor_coords = np.column_stack((xv[rows, cols], yv[rows, cols]))
    nearest = griddata(sensor_coords, np.arange(len(rows)), (xv, yv), method='nearest').astype(int)

    mask = np.zeros(grid_shape, dtype=np.float32)
    mask[rows, cols] = 1

    for i, (_, df_t) in enumerate(tqdm(groups, total=num_samples, desc="Processing samples")):
        # Scattered (X, Y, value) rows -> 2D grid indexed [Y, X].
        U_grid = df_t.pivot(index='Y', columns='X', values='U').values
        V_grid = df_t.pivot(index='Y', columns='X', values='V').values

        # Ground truth
        y[i, :, :, 0] = U_grid
        y[i, :, :, 1] = V_grid

        # Sensor readings spread to every cell via the precomputed nearest map.
        X[i, :, :, 0] = U_grid[rows, cols][nearest]
        X[i, :, :, 1] = V_grid[rows, cols][nearest]
        X[i, :, :, 2] = mask

    return X, y

In [None]:
def visualize_interpolation(X, y, csv_file, sample_index=0):
    """
    Show one sample's ground truth, its sensor-based interpolation and its sensors.

    Layout (2 x 3):
      top row    -- ground-truth U, interpolated U, sensor mask
      bottom row -- ground-truth V, interpolated V, ground-truth U with the
                    sensor cells (mask == 1) marked as red points

    Axis extents use the physical X/Y coordinates read from ``csv_file``.

    Parameters:
      X: Interpolated data and sensor mask (shape [num_samples, n_y, n_x, 3])
      y: Ground truth data (shape [num_samples, n_y, n_x, 2])
      csv_file: Original CSV file path, used to extract grid coordinates
      sample_index: Index of sample to visualize, default is 0
    """
    frame = pd.read_csv(csv_file)
    xs = np.sort(frame['X'].unique())
    ys = np.sort(frame['Y'].unique())
    extent = (xs.min(), xs.max(), ys.min(), ys.max())

    inputs = X[sample_index]
    truth = y[sample_index]
    mask = inputs[:, :, 2]

    # Panels that are a plain image plus a colorbar, keyed by subplot position.
    panels = {
        (0, 0): (truth[:, :, 0], "Ground Truth U"),
        (0, 1): (inputs[:, :, 0], "Interpolated U"),
        (0, 2): (mask, "Sensor Mask"),
        (1, 0): (truth[:, :, 1], "Ground Truth V"),
        (1, 1): (inputs[:, :, 1], "Interpolated V"),
    }

    fig, axs = plt.subplots(2, 3, figsize=(18, 10))
    for (r, c), (image, title) in panels.items():
        ax = axs[r, c]
        handle = ax.imshow(image, origin='lower', extent=extent)
        ax.set_title(title)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        fig.colorbar(handle, ax=ax)

    # Last panel: sensor cells drawn over ground-truth U. Column index maps to X,
    # row index maps to Y.
    overlay = axs[1, 2]
    overlay.imshow(truth[:, :, 0], origin='lower', extent=extent)
    hit_rows, hit_cols = np.where(mask == 1)
    overlay.scatter(xs[hit_cols], ys[hit_rows], color='red', label="Sensors")
    overlay.set_title("Sensors on Ground Truth U")
    overlay.set_xlabel("X")
    overlay.set_ylabel("Y")
    overlay.legend()

    plt.tight_layout()
    plt.show()

# Generate Data

In [None]:
method = 0  # sensor-placement method id for this section

In [None]:
# Method 0 on the 45deg cases: for each sensor count, build the interpolated
# inputs, the ground truth and a sensor-position-perturbed copy of each dataset.
sensors_dir = './optimal_sensors/'   # location of optimal_sensors_method{m}_num{n}.npy; adjust as needed
sensor_num = [5, 10, 15, 20, 25, 30]
save_outputs = False                 # set True to write the arrays below into `dir`

for n_sensors in sensor_num:
  X1, y1 = generate_data_csv(f'{dir}45deg_1.csv', sensor_num=n_sensors, method=0, use_optimal=True, sensors_dir=sensors_dir)
  X2, y2 = generate_data_csv(f'{dir}45deg_2.csv', sensor_num=n_sensors, method=0, use_optimal=True, sensors_dir=sensors_dir)
  X3, y3 = generate_data_csv(f'{dir}45deg_3.csv', sensor_num=n_sensors, method=0, use_optimal=True, sensors_dir=sensors_dir)

  # Perturbed copies on each dataset's own grid (expected (15, 15)).
  X1_perturbed = generate_perturbed_test_data(X1, y1, X1.shape[1:3], max_perturbation=1)
  X2_perturbed = generate_perturbed_test_data(X2, y2, X2.shape[1:3], max_perturbation=1)
  X3_perturbed = generate_perturbed_test_data(X3, y3, X3.shape[1:3], max_perturbation=1)

  # Optional visual check:
  # visualize_perturbation(X3, X3_perturbed, y3, sample_index=1)

  if save_outputs:
    datasets = ((X1, y1, X1_perturbed), (X2, y2, X2_perturbed), (X3, y3, X3_perturbed))
    for k, (X_k, y_k, X_k_perturbed) in enumerate(datasets, start=1):
      np.save(f'{dir}45deg_x{k}_data_{n_sensors}.npy', X_k)
      np.save(f'{dir}45deg_y{k}_data_{n_sensors}.npy', y_k)
      np.save(f'{dir}45deg_x{k}_perturbed_data_{n_sensors}.npy', X_k_perturbed)
    print(f'num sensor = {n_sensors}: Data and perturbed data saved')
  else:
    print(f'num sensor = {n_sensors}: Data generated (not saved; set save_outputs=True to save)')

method=1

In [None]:
# Method 1 on the 22deg cases (quick check with a single sensor count).
# Full sweep: [5, 10, 15, 20, 25, 30]
sensor_num = [5]
save_outputs = False   # set True to write the arrays below into `dir`

for n_sensors in sensor_num:
  # method=1 is passed explicitly to match this section; the default would be method 0.
  X1, y1 = generate_data_csv(f'{dir}22deg_1.csv', sensor_num=n_sensors, method=1)
  X2, y2 = generate_data_csv(f'{dir}22deg_2.csv', sensor_num=n_sensors, method=1)

  # Uses the sample_index variant of visualize_interpolation; a later cell
  # redefines that name with a num_samples signature, so re-running this cell
  # after that one raises TypeError.
  visualize_interpolation(X1, y1, f'{dir}22deg_1.csv', sample_index=1)

  if save_outputs:
    np.save(f'{dir}22deg_x1_data_{n_sensors}.npy', X1)
    np.save(f'{dir}22deg_y1_data_{n_sensors}.npy', y1)
    np.save(f'{dir}22deg_x2_data_{n_sensors}.npy', X2)
    np.save(f'{dir}22deg_y2_data_{n_sensors}.npy', y2)
    print(f'num sensor = {n_sensors}: Data saved')

In [None]:
def visualize_interpolation(X, y, csv_file, num_samples=1):
    """
    Plot ground-truth U (top row) and V (bottom row) for the first samples.

    NOTE: this redefines, and from here on shadows, the earlier
    ``visualize_interpolation(X, y, csv_file, sample_index=0)`` in this notebook.

    Parameters:
      X: Unused; kept so the call signature matches the earlier helper.
      y: Ground truth data (shape [num_samples, n_y, n_x, 2])
      csv_file: Original CSV file path, used to extract the physical X/Y extent
      num_samples: Number of leading samples to display, default 1. It is
                   clamped to the number of samples in ``y``.

    Raises:
      ValueError: if there is nothing to plot (num_samples < 1 or empty ``y``).
    """
    df = pd.read_csv(csv_file)
    x_coords = np.sort(df['X'].unique())
    y_coords = np.sort(df['Y'].unique())
    extent = (x_coords.min(), x_coords.max(), y_coords.min(), y_coords.max())

    num_samples = min(num_samples, y.shape[0])
    if num_samples < 1:
        raise ValueError("num_samples must be >= 1 and y must contain at least one sample")

    # squeeze=False always returns a 2-D axes array, even for a single column.
    fig, axs = plt.subplots(2, num_samples, figsize=(6 * num_samples, 10), squeeze=False)

    for i in range(num_samples):
        for row, (component, name) in enumerate(((0, "U"), (1, "V"))):
            ax = axs[row, i]
            im = ax.imshow(y[i, :, :, component], origin='lower', extent=extent)
            ax.set_title(f"Sample {i+1} - Ground Truth {name}")
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            fig.colorbar(im, ax=ax)

    plt.tight_layout()
    plt.show()

In [None]:
# Preview ground-truth U/V for the first 20 time steps of 22deg_1
# (5 sensors, default method/sensor settings of generate_data_csv).
sensor_num = [5]

for n_sensors in sensor_num:
  X1, y1 = generate_data_csv(f'{dir}22deg_1.csv', sensor_num=n_sensors)
  visualize_interpolation(X1, y1, f'{dir}22deg_1.csv', num_samples=20)

In [None]:
# Preview ground-truth U/V for the first 20 time steps of 0deg_1
# (5 sensors, default method/sensor settings of generate_data_csv).
sensor_num = [5]

for n_sensors in sensor_num:
  X1, y1 = generate_data_csv(f'{dir}0deg_1.csv', sensor_num=n_sensors)
  visualize_interpolation(X1, y1, f'{dir}0deg_1.csv', num_samples=20)

# Batch Generate Optimal

In [None]:
import numpy as np
import os
from tqdm import tqdm


def batch_generate_all_data(base_dir='./dataset/',
                           sensors_dir='./optimal_sensors/',
                           output_dir='./generated_data/',
                           csv_files=None,
                           sensor_nums=None,
                           methods=None,
                           max_perturbation=1):
    """
    Generate and save X / y / X_perturbed arrays for every (method, sensor_num, CSV) combination.

    A (method, sensor_num) pair is processed only if its optimal-sensor file exists
    in ``sensors_dir``. Pairs without a file are skipped; there is no uniform
    fallback here. For each processed pair, every CSV is converted with
    generate_data_csv and perturbed with generate_perturbed_test_data. Three files
    are written per combination::

        {csv_name}_method{method}_sensor{sensor_num}_X.npy
        {csv_name}_method{method}_sensor{sensor_num}_y.npy
        {csv_name}_method{method}_sensor{sensor_num}_X_perturbed.npy

    A failure on one CSV is reported and the batch continues (best effort).

    Parameters:
        base_dir: Directory where CSV files are located
        sensors_dir: Directory for optimal sensor position files
        output_dir: Directory for output npy files (created if missing)
        csv_files: CSV filenames to process; None uses the default list below
        sensor_nums: Sensor counts; None -> [5, 10, 15, 20, 25, 30]
        methods: Method numbers; None -> [0, 1]
        max_perturbation: Forwarded to generate_perturbed_test_data (default 1)
    """
    if csv_files is None:
        csv_files = [
            '45deg_1.csv', '45deg_2.csv', '45deg_3.csv',
            '0deg_1.csv', '0deg_2.csv', '0deg_3.csv',
            '22deg_1.csv', '22deg_2.csv'
        ]
    if sensor_nums is None:
        sensor_nums = [5, 10, 15, 20, 25, 30]
    if methods is None:
        methods = [0, 1]

    os.makedirs(output_dir, exist_ok=True)

    total_tasks = len(methods) * len(sensor_nums) * len(csv_files)
    print(f"Total tasks to process: {total_tasks}")

    # task_count also advances for skipped tasks so "[k/total]" stays consistent.
    task_count = 0
    n_saved = n_failed = n_skipped = 0

    for method in methods:
        print(f"\n{'='*50}")
        print(f"Starting to process Method {method}")
        print(f"{'='*50}")

        for sensor_num in sensor_nums:
            print(f"\n--- Method {method}, Sensor_num {sensor_num} ---")

            # Existence check only; generate_data_csv reloads the positions itself.
            try:
                load_optimal_sensors(method, sensor_num, sensors_dir)
                print(f"Successfully loaded optimal sensor positions: Method {method}, Sensor_num {sensor_num}")
            except FileNotFoundError as e:
                print(f"Warning: {e}")
                print(f"Skipping Method {method}, Sensor_num {sensor_num}")
                task_count += len(csv_files)
                n_skipped += len(csv_files)
                continue

            for csv_file in csv_files:
                task_count += 1
                csv_path = os.path.join(base_dir, csv_file)
                csv_name = csv_file.replace('.csv', '')

                print(f"[{task_count}/{total_tasks}] Processing: {csv_file}")

                try:
                    X, y = generate_data_csv(
                        csv_path,
                        sensor_num=sensor_num,
                        method=method,
                        use_optimal=True,
                        sensors_dir=sensors_dir
                    )
                    X_perturbed = generate_perturbed_test_data(
                        X, y, X.shape[1:3], max_perturbation=max_perturbation
                    )

                    base_filename = f"{csv_name}_method{method}_sensor{sensor_num}"
                    outputs = {
                        f"{base_filename}_X.npy": X,
                        f"{base_filename}_y.npy": y,
                        f"{base_filename}_X_perturbed.npy": X_perturbed,
                    }
                    for filename, array in outputs.items():
                        np.save(os.path.join(output_dir, filename), array)

                except Exception as e:
                    # Best effort: report this CSV and continue with the next one.
                    n_failed += 1
                    print(f"  ❌ Processing failed: {csv_file} - {e}")
                    continue

                n_saved += 1
                print("  ✅ Successfully saved:")
                for filename in outputs:
                    print(f"     - {filename}")
                print(f"     Data shape: X={X.shape}, y={y.shape}")

    print(f"\n{'='*50}")
    print(f"Batch generation completed! Total tasks processed: {task_count}")
    print(f"  saved: {n_saved}, failed: {n_failed}, skipped (no sensor file): {n_skipped}")
    print(f"Files saved in: {output_dir}")
    print(f"{'='*50}")

def _parse_generated_filename(filename):
    """
    Split a name written by batch_generate_all_data into its parts.

    Expected form: ``{csv_name}_method{M}_sensor{S}[_{type}].npy`` with integer M, S.
    The first ``_method{M}_sensor{S}`` occurrence is used.

    Returns:
        (method, sensor, csv_name, file_type) with method/sensor as ints and
        file_type defaulting to 'X' when nothing follows the sensor part, or
        None when the name does not match (such files are ignored by callers).
    """
    import re

    match = re.fullmatch(
        r'(?P<csv>.+?)_method(?P<method>\d+)_sensor(?P<sensor>\d+)(?:_(?P<type>.+))?\.npy',
        filename,
    )
    if match is None:
        return None
    return int(match['method']), int(match['sensor']), match['csv'], match['type'] or 'X'


def list_generated_files(output_dir='./generated_data/', show_details=True):
    """
    Print the .npy files in ``output_dir`` grouped by method, sensor count and CSV.

    Only names matching ``{csv}_method{M}_sensor{S}[_{type}].npy`` are grouped.
    Other .npy files count toward the printed total but are not listed. Methods
    and sensor counts are ordered numerically (Sensor5 before Sensor10).

    Parameters:
        output_dir: File directory
        show_details: Whether to also print each file with its size in MB
    """
    from collections import defaultdict

    if not os.path.exists(output_dir):
        print(f"Directory does not exist: {output_dir}")
        return

    files = sorted(f for f in os.listdir(output_dir) if f.endswith('.npy'))
    if not files:
        print(f"No .npy files found in directory: {output_dir}")
        return

    print(f"\nGenerated file list (total {len(files)} files):")
    print(f"{'='*60}")

    # method -> sensor -> csv_name -> [(file_type, filename), ...]
    grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for file in files:
        parsed = _parse_generated_filename(file)
        if parsed is None:
            continue
        method, sensor, csv_name, file_type = parsed
        grouped[method][sensor][csv_name].append((file_type, file))

    for method in sorted(grouped):
        print(f"\n📁 Method{method}:")
        for sensor in sorted(grouped[method]):
            print(f"  📁 Sensor{sensor}:")
            for csv_name in sorted(grouped[method][sensor]):
                entries = grouped[method][sensor][csv_name]
                types = sorted(file_type for file_type, _ in entries)
                print(f"    📄 {csv_name}: {', '.join(types)}")

                if show_details:
                    for _, filename in entries:
                        filepath = os.path.join(output_dir, filename)
                        if os.path.exists(filepath):
                            size_mb = os.path.getsize(filepath) / (1024 * 1024)
                            print(f"       - {filename} ({size_mb:.2f} MB)")


def generate_file_mapping(output_dir='./generated_data/'):
    """
    Map (method, sensor_num, csv_name, data_type) -> filename for the files in output_dir.

    method and sensor_num are ints; data_type is e.g. 'X', 'y' or 'X_perturbed'.
    Files whose names do not follow the batch_generate_all_data pattern are skipped.
    """
    mapping = {}
    for file in os.listdir(output_dir):
        if not file.endswith('.npy'):
            continue
        parsed = _parse_generated_filename(file)
        if parsed is not None:
            mapping[parsed] = file
    return mapping


def load_data_by_key(method, sensor_num, csv_name, data_type='X', output_dir='./generated_data/'):
    """
    Load corresponding data based on key values

    Parameters:
        method: Method number (0 or 1)
        sensor_num: Number of sensors
        csv_name: CSV name (without .csv suffix)
        data_type: Data type ('X', 'y', 'X_perturbed')
        output_dir: Data directory

    Returns:
        data: Loaded numpy array

    Raises:
        FileNotFoundError: if the corresponding file does not exist.
    """
    filename = f"{csv_name}_method{method}_sensor{sensor_num}_{data_type}.npy"
    filepath = os.path.join(output_dir, filename)

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"File does not exist: {filepath}")

    return np.load(filepath)

In [None]:
"""
Main function: demonstration of how to use
"""
print("Starting batch data generation...")

# 1. Batch generate all data
batch_generate_all_data(
    base_dir='/content/drive/MyDrive/TorchDA/dataset',  # Adjust to your CSV file directory
    sensors_dir='/content/drive/MyDrive/TorchDA/position',
    output_dir='/content/drive/MyDrive/TorchDA/optimal_dataset'
)

# 2. List generated files
print("\nViewing generated files:")
list_generated_files('./generated_data/')