In [None]:
# original
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Open the source dataset and take a private copy of the SSRD variable.
file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

# Every cell equal to 0 is replaced with NaN before plotting.
ssrd = ds["ssrd"].copy()
ssrd = ssrd.where(ssrd != 0, np.nan)

# Figures are written below this folder, which is created on demand.
output_dir = "ssrd_0"
os.makedirs(output_dir, exist_ok=True)

# Five-stop linear colour ramp used for every map in this cell.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap",
    ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"],
)

time_steps = [100, 1000, 2000]
n_steps = len(ds["valid_time"])

for time_index in time_steps:
    # Guard clause: indices past the end of the time axis are reported and skipped.
    if time_index >= n_steps:
        print(f"Time step {time_index} is out of range. Skipping.")
        continue

    ssrd_slice = ssrd.isel(valid_time=time_index)

    fig, ax = plt.subplots(figsize=(8, 5))
    img = ax.imshow(ssrd_slice, cmap=custom_cmap)

    cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
    cbar.set_label("SSRD (W/m²)")

    ax.set(
        title=f"SSRD Map (Time Step {time_index})",
        xlabel="Longitude",
        ylabel="Latitude",
    )

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)

    output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Figure saved: {output_path}")

    plt.show()

In [None]:
# original
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

# Copy SSRD and turn every cell equal to 0 into NaN.
ssrd = ds["ssrd"].copy()
ssrd = ssrd.where(ssrd != 0, np.nan)

# Pull the raw array once; the NaN-aware reductions below all read from it.
ssrd_values = ssrd.values
mean_value = np.nanmean(ssrd_values)
std_value = np.nanstd(ssrd_values)

# Count of NaN cells and their share of the whole array, in percent.
missing_values = np.isnan(ssrd_values).sum()
total_values = ssrd.size
missing_ratio = (missing_values / total_values) * 100

report_lines = [
    f"Mean after setting missing values: {mean_value:.2f}",
    f"Standard deviation after setting missing values: {std_value:.2f}",
    f"Number of missing values: {missing_values:,}",
    f"Percentage of missing values: {missing_ratio:.2f}%",
]
print("\n".join(report_lines))

In [None]:
# Nearest Neighbor
import numpy as np
import xarray as xr
from scipy.ndimage import generic_filter

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

# Work on a copy of SSRD in which every cell equal to 0 becomes NaN.
ssrd_nearest = ds["ssrd"].copy()
is_nonzero = ssrd_nearest != 0
ssrd_nearest = ssrd_nearest.where(is_nonzero, np.nan)

def nearest_neighbor(values):
    """Window reducer for ``generic_filter`` that fills a NaN centre cell.

    ``generic_filter`` calls this for *every* pixel, so a valid centre must be
    returned unchanged; otherwise observed data would be overwritten by a
    neighbouring value.

    :param values: Flattened filter window (1-D array of floats). The centre
        is taken as the middle element, which is correct for odd-sized
        windows such as ``size=3``.
    :return: The centre value when it is not NaN; otherwise the first non-NaN
        value in the window's flattened order (not necessarily the spatially
        closest one); NaN when the window holds no valid value at all.
    """
    window = np.asarray(values, dtype=float).ravel()
    if window.size == 0:
        return np.nan

    center = window[window.size // 2]
    if not np.isnan(center):
        # Observed cell: keep it as is.
        return center

    valid_values = window[~np.isnan(window)]
    return valid_values[0] if len(valid_values) > 0 else np.nan

time_dim = ds["valid_time"].values

# Fill NaN cells of each time slice from their 3x3 neighbourhood (single pass,
# so gaps wider than the window can remain NaN).
for t in range(len(time_dim)):
    data_slice = ssrd_nearest.isel(valid_time=t).values
    nan_mask = np.isnan(data_slice)

    if nan_mask.any():
        window_fill = generic_filter(data_slice, nearest_neighbor, size=3, mode='nearest')
        # Only write into cells that were missing; the filter output for
        # observed cells must not replace the original values.
        data_slice = np.where(nan_mask, window_fill, data_slice)

    ssrd_nearest[t, :, :] = data_slice

remaining_missing_values_nearest = np.isnan(ssrd_nearest).sum().item()
print(f"remaining missing values: {remaining_missing_values_nearest}")

output_path = "ssrd_nearest_filled.nc"
ssrd_nearest.to_netcdf(output_path)
print(f"file saved: {output_path}")

In [None]:
# Nearest Neighbor
import xarray as xr

filled_file_path = "ssrd_nearest_filled.nc"
filled_ds = xr.open_dataset(filled_file_path)

# NaN-skipping mean and standard deviation over the whole filled field.
filled_ssrd = filled_ds["ssrd"]
filled_mean = filled_ssrd.mean(skipna=True).item()
filled_std = filled_ssrd.std(skipna=True).item()
print("mean and standard deviation:", filled_mean, filled_std)

In [None]:
# Nearest Neighbor
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Plot the nearest-neighbour-filled field. The dataset opened here is the
# filled file, and slices are taken from it directly, so the cell no longer
# depends on `file_path` / `ssrd` left over from other cells.
filled_file_path = "ssrd_nearest_filled.nc"
ds = xr.open_dataset(filled_file_path)

output_dir = "ssrd_near"
os.makedirs(output_dir, exist_ok=True)

custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap", ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"]
)

time_steps = [100, 1000, 2000]

for time_index in time_steps:
    if time_index < len(ds["valid_time"]):
        ssrd_slice = ds["ssrd"].isel(valid_time=time_index)

        fig, ax = plt.subplots(figsize=(8, 5))
        img = ax.imshow(ssrd_slice, cmap=custom_cmap)

        cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
        cbar.set_label("SSRD (W/m²)")

        ax.set_title(f"SSRD Map (Time Step {time_index})")
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")

        fig.tight_layout()
        plt.subplots_adjust(bottom=0.15)

        output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        print(f"Figure saved: {output_path}")

        plt.show()
    else:
        print(f"Time step {time_index} is out of range. Skipping.")

In [None]:
# Linear Interpolation
import numpy as np
import xarray as xr
from scipy.interpolate import griddata

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

ssrd_linear = ds["ssrd"].copy()

# Cells equal to 0 are replaced with NaN and then interpolated below.
ssrd_linear = ssrd_linear.where(ssrd_linear != 0, np.nan)

time_dim = ds["valid_time"].values

# Fill NaN cells slice by slice with linear interpolation in index space
# (row, column). The whole body below belongs to the loop.
for t in range(len(time_dim)):
    data_slice = ssrd_linear.isel(valid_time=t).values

    valid_mask = ~np.isnan(data_slice)
    valid_points = np.column_stack(np.where(valid_mask))
    valid_values = data_slice[valid_mask]

    nan_mask = np.isnan(data_slice)
    nan_points = np.column_stack(np.where(nan_mask))

    if len(valid_points) < 4:  # At least 4 points are required for Delaunay triangulation
        print(f"Skipping time step {t}, valid points: {len(valid_points)}, linear interpolation not possible.")
        continue

    if np.ptp(valid_points[:, 1]) == 0:
        print(f"Skipping time step {t}, valid points are coplanar on the y-axis, linear interpolation not possible.")
        continue

    if len(nan_points) == 0:
        # Nothing is missing in this slice; skip the triangulation entirely.
        continue

    try:
        # NOTE: griddata(method='linear') returns NaN (its default fill_value)
        # for targets outside the convex hull of the valid points, so such
        # cells stay missing.
        data_slice[nan_mask] = griddata(valid_points, valid_values, nan_points, method='linear')
    except Exception as e:
        print(f"Linear interpolation failed at time step {t}, error: {e}")
        continue

    ssrd_linear[t, :, :] = data_slice

remaining_missing_values_linear = np.isnan(ssrd_linear).sum().item()
print(f"Remaining missing values: {remaining_missing_values_linear}")

output_path = "ssrd_linear_filled.nc"
ssrd_linear.to_netcdf(output_path)
print(f"Interpolated data saved to: {output_path}")

In [None]:
# Linear
import xarray as xr

filled_file_path = "ssrd_linear_filled.nc"
filled_ds = xr.open_dataset(filled_file_path)

# NaN-skipping mean and standard deviation over the whole filled field.
filled_ssrd = filled_ds["ssrd"]
filled_mean = filled_ssrd.mean(skipna=True).item()
filled_std = filled_ssrd.std(skipna=True).item()
print("mean and standard deviation:", filled_mean, filled_std)

In [None]:
# Linear
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Plot the linearly interpolated field. The dataset opened here is the filled
# file, and slices are taken from it directly, so the cell no longer depends
# on `file_path` / `ssrd` left over from other cells.
filled_file_path = "ssrd_linear_filled.nc"
ds = xr.open_dataset(filled_file_path)

# Separate folder so these figures do not overwrite the nearest-neighbour
# ones, which use identical file names.
output_dir = "ssrd_linear"
os.makedirs(output_dir, exist_ok=True)

custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap", ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"]
)

time_steps = [100, 1000, 2000]

for time_index in time_steps:
    if time_index < len(ds["valid_time"]):
        ssrd_slice = ds["ssrd"].isel(valid_time=time_index)

        fig, ax = plt.subplots(figsize=(8, 5))
        img = ax.imshow(ssrd_slice, cmap=custom_cmap)

        cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
        cbar.set_label("SSRD (W/m²)")

        ax.set_title(f"SSRD Map (Time Step {time_index})")
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")

        fig.tight_layout()
        plt.subplots_adjust(bottom=0.15)

        output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        print(f"Figure saved: {output_path}")

        plt.show()
    else:
        print(f"Time step {time_index} is out of range. Skipping.")

In [None]:
# Linear Interpolation + Nearest Neighbor
import numpy as np
import xarray as xr
from scipy.interpolate import griddata

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

ssrd_linear = ds["ssrd"].copy()

# Cells equal to 0 are replaced with NaN and then interpolated below.
ssrd_linear = ssrd_linear.where(ssrd_linear != 0, np.nan)

time_dim = ds["valid_time"].values
for t in range(len(time_dim)):
    data_slice = ssrd_linear.isel(valid_time=t).values

    valid_mask = ~np.isnan(data_slice)

    if np.sum(valid_mask) < 4:  # Skip if fewer than 4 valid points
        print(f"Skipping time step {t}, valid points: {np.sum(valid_mask)}, linear interpolation not possible.")
        continue

    nan_mask = ~valid_mask
    if not nan_mask.any():
        # Nothing is missing in this slice; skip the triangulation entirely.
        continue

    # Interpolation runs in index space: points are (row, column) pairs.
    valid_points = np.column_stack(np.where(valid_mask))
    valid_values = data_slice[valid_mask]
    nan_points = np.column_stack(np.where(nan_mask))

    # Perform linear interpolation (fallback to nearest neighbor if it fails).
    # NOTE(review): griddata(method='linear') returns NaN (its default
    # fill_value) for targets outside the convex hull of the valid points
    # without raising, so those cells are NOT reached by the fallback below.
    try:
        data_slice[nan_mask] = griddata(valid_points, valid_values, nan_points, method='linear')
    except Exception:
        # `Exception` rather than a bare except, so Ctrl-C still interrupts.
        print(f"Linear interpolation failed at time step {t}, switching to nearest neighbor interpolation.")
        data_slice[nan_mask] = griddata(valid_points, valid_values, nan_points, method='nearest')

    ssrd_linear[t, :, :] = data_slice

remaining_missing_values_linear = np.isnan(ssrd_linear).sum().item()
print(f"Remaining missing values: {remaining_missing_values_linear}")

output_path = "ssrd_linear_filled2.nc"
ssrd_linear.to_netcdf(output_path)
print(f"Interpolated data saved to: {output_path}")

In [None]:
# Linear Interpolation + Nearest Neighbor
import xarray as xr

filled_file_path = "ssrd_linear_filled2.nc"
filled_ds = xr.open_dataset(filled_file_path)

# NaN-skipping mean and standard deviation over the whole filled field.
filled_ssrd = filled_ds["ssrd"]
filled_mean = filled_ssrd.mean(skipna=True).item()
filled_std = filled_ssrd.std(skipna=True).item()
print("mean and standard deviation:", filled_mean, filled_std)

In [None]:
# Linear Interpolation + Nearest Neighbor
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Open the file written by the linear + nearest-neighbour cell.
filled_file_path = "ssrd_linear_filled2.nc"
ds = xr.open_dataset(filled_file_path)

# Figures are written below this folder, which is created on demand.
output_dir = "ssrd_li_ne"
os.makedirs(output_dir, exist_ok=True)

# Five-stop linear colour ramp used for every map in this cell.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap",
    ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"],
)

time_steps = [100, 1000, 2000]
n_steps = len(ds["valid_time"])

for time_index in time_steps:
    # Guard clause: indices past the end of the time axis are reported and skipped.
    if time_index >= n_steps:
        print(f"Time step {time_index} is out of range. Skipping.")
        continue

    ssrd_slice = ds["ssrd"].isel(valid_time=time_index)

    fig, ax = plt.subplots(figsize=(8, 5))
    img = ax.imshow(ssrd_slice, cmap=custom_cmap)

    cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
    cbar.set_label("SSRD (W/m²)")

    ax.set(
        title=f"SSRD Map (Time Step {time_index})",
        xlabel="Longitude",
        ylabel="Latitude",
    )

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)

    output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Figure saved: {output_path}")

    plt.show()

In [None]:
# Bicubic Interpolation
import numpy as np
import xarray as xr
from scipy.interpolate import griddata

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

ssrd_cubic = ds["ssrd"].copy()

# Cells equal to 0 are replaced with NaN and then interpolated below.
ssrd_cubic = ssrd_cubic.where(ssrd_cubic != 0, np.nan)

time_dim = ds["valid_time"].values
for t in range(len(time_dim)):
    data_slice = ssrd_cubic.isel(valid_time=t).values

    valid_mask = ~np.isnan(data_slice)

    if np.sum(valid_mask) < 16:  # At least 16 points are required for bicubic interpolation
        print(f"Skipping time step {t}, valid points: {np.sum(valid_mask)}, bicubic interpolation not possible.")
        continue

    nan_mask = ~valid_mask
    if not nan_mask.any():
        # Nothing is missing in this slice; skip the triangulation entirely.
        continue

    # Interpolation runs in index space: points are (row, column) pairs.
    valid_points = np.column_stack(np.where(valid_mask))
    valid_values = data_slice[valid_mask]
    nan_points = np.column_stack(np.where(nan_mask))

    # NOTE(review): griddata(method='cubic') returns NaN (its default
    # fill_value) for targets outside the convex hull of the valid points
    # without raising, so those cells are NOT reached by the fallback below.
    try:
        data_slice[nan_mask] = griddata(valid_points, valid_values, nan_points, method='cubic')
    except Exception:
        # `Exception` rather than a bare except, so Ctrl-C still interrupts.
        print(f"Bicubic interpolation failed at time step {t}, switching to nearest neighbor interpolation.")
        data_slice[nan_mask] = griddata(valid_points, valid_values, nan_points, method='nearest')

    ssrd_cubic[t, :, :] = data_slice

remaining_missing_values_cubic = np.isnan(ssrd_cubic).sum().item()
print(f"Remaining missing values: {remaining_missing_values_cubic}")

output_path = "ssrd_cubic_filled.nc"
ssrd_cubic.to_netcdf(output_path)
print(f"Interpolated data saved to: {output_path}")

In [None]:
# Bicubic
import xarray as xr

filled_file_path = "ssrd_cubic_filled.nc"
filled_ds = xr.open_dataset(filled_file_path)

# NaN-skipping mean and standard deviation over the whole filled field.
filled_ssrd = filled_ds["ssrd"]
filled_mean = filled_ssrd.mean(skipna=True).item()
filled_std = filled_ssrd.std(skipna=True).item()
print("mean and standard deviation:", filled_mean, filled_std)

In [None]:
# Bicubic
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Open the file written by the bicubic interpolation cell.
filled_file_path = "ssrd_cubic_filled.nc"
ds = xr.open_dataset(filled_file_path)

# Figures are written below this folder, which is created on demand.
output_dir = "ssrd_cubic"
os.makedirs(output_dir, exist_ok=True)

# Five-stop linear colour ramp used for every map in this cell.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap",
    ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"],
)

time_steps = [100, 1000, 2000]
n_steps = len(ds["valid_time"])

for time_index in time_steps:
    # Guard clause: indices past the end of the time axis are reported and skipped.
    if time_index >= n_steps:
        print(f"Time step {time_index} is out of range. Skipping.")
        continue

    ssrd_slice = ds["ssrd"].isel(valid_time=time_index)

    fig, ax = plt.subplots(figsize=(8, 5))
    img = ax.imshow(ssrd_slice, cmap=custom_cmap)

    cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
    cbar.set_label("SSRD (W/m²)")

    ax.set(
        title=f"SSRD Map (Time Step {time_index})",
        xlabel="Longitude",
        ylabel="Latitude",
    )

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)

    output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Figure saved: {output_path}")

    plt.show()

In [None]:
# Kriging Interpolation
import numpy as np
import xarray as xr
from pykrige.ok import OrdinaryKriging

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

ssrd_kriging = ds["ssrd"].copy()

# Cells equal to 0 are replaced with NaN and then interpolated below.
ssrd_kriging = ssrd_kriging.where(ssrd_kriging != 0, np.nan)

lat = ds["latitude"].values
lon = ds["longitude"].values
time_dim = ds["valid_time"].values

# 2-D coordinate grids with the same (lat, lon) shape as one time slice.
lon_2d, lat_2d = np.meshgrid(lon, lat)

for t in range(len(time_dim)):
    data_slice = ssrd_kriging.isel(valid_time=t).values

    valid_mask = ~np.isnan(data_slice)

    if np.sum(valid_mask) < 10:  # At least 10 points are required for Kriging
        print(f"Skipping time step {t}, valid points: {np.sum(valid_mask)}, Kriging not possible.")
        continue

    nan_mask = ~valid_mask
    if not nan_mask.any():
        # Nothing is missing: skip the variogram fit and the call to
        # execute() with empty target arrays.
        continue

    # valid_values is selected through valid_mask, so it cannot contain NaN;
    # no separate NaN check is needed before fitting.
    valid_x = lon_2d[valid_mask].astype(np.float64)
    valid_y = lat_2d[valid_mask].astype(np.float64)
    valid_values = data_slice[valid_mask].astype(np.float64)

    try:
        OK = OrdinaryKriging(valid_x, valid_y, valid_values,
                             variogram_model="spherical", verbose=False, enable_plotting=False)

        nan_x = lon_2d[nan_mask].astype(np.float64)
        nan_y = lat_2d[nan_mask].astype(np.float64)

        # "points" style: one prediction per (nan_x[i], nan_y[i]) pair; the
        # second return value is discarded.
        z_kriged, _ = OK.execute("points", nan_x, nan_y)

        data_slice[nan_mask] = z_kriged

    except Exception as e:
        print(f"Kriging interpolation failed at time step {t}, error: {e}")

    ssrd_kriging[t, :, :] = data_slice

remaining_missing_values_kriging = np.isnan(ssrd_kriging).sum().item()
print(f"Remaining missing values: {remaining_missing_values_kriging}")

output_path = "ssrd_kriging_filled.nc"
ssrd_kriging.to_netcdf(output_path)
print(f"Interpolated data saved to: {output_path}")

In [None]:
# Kriging
import xarray as xr

filled_file_path = "ssrd_kriging_filled.nc"
filled_ds = xr.open_dataset(filled_file_path)

# NaN-skipping mean and standard deviation over the whole filled field.
filled_ssrd = filled_ds["ssrd"]
mean_value, std_value = (
    filled_ssrd.mean(skipna=True).item(),
    filled_ssrd.std(skipna=True).item(),
)

print("Mean and standard deviation:", mean_value, std_value)

In [None]:
# Kriging
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Open the file written by the Kriging cell.
file_path = "ssrd_kriging_filled.nc"
ds = xr.open_dataset(file_path)

# Figures are written below this folder, which is created on demand.
output_dir = "ssrd_Kri"
os.makedirs(output_dir, exist_ok=True)

# Five-stop linear colour ramp used for every map in this cell.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap",
    ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"],
)

time_steps = [100, 1000, 2000]
n_steps = len(ds["valid_time"])

for time_index in time_steps:
    # Guard clause: indices past the end of the time axis are reported and skipped.
    if time_index >= n_steps:
        print(f"Time step {time_index} is out of range. Skipping.")
        continue

    ssrd_slice = ds["ssrd"].isel(valid_time=time_index)

    fig, ax = plt.subplots(figsize=(8, 5))
    img = ax.imshow(ssrd_slice, cmap=custom_cmap)

    cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
    cbar.set_label("SSRD (W/m²)")

    ax.set(
        title=f"SSRD Map (Time Step {time_index})",
        xlabel="Longitude",
        ylabel="Latitude",
    )

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)

    # Write the figure to disk before showing it.
    output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Figure saved: {output_path}")

    plt.show()

In [None]:
# IDW: Inverse Distance Weighting Interpolation
import numpy as np
import xarray as xr
from scipy.spatial import cKDTree

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

# Work on a copy of SSRD in which every cell equal to 0 becomes NaN.
ssrd_idw = ds["ssrd"].copy()
ssrd_idw = ssrd_idw.where(ssrd_idw != 0, np.nan)

# Coordinate vectors as plain NumPy arrays.
time_dim, lat, lon = (
    ds[name].values for name in ("valid_time", "latitude", "longitude")
)

# 2-D coordinate grids, one entry per (lat, lon) cell.
lon_2d, lat_2d = np.meshgrid(lon, lat)

def idw_interpolation(valid_points, valid_values, target_points, power=2, k=5):
    """
    Inverse Distance Weighting (IDW) interpolation.

    :param valid_points: Known data point coordinates (N, 2)
    :param valid_values: Known data point values (N,)
    :param target_points: Target points for interpolation (M, 2)
    :param power: Distance weighting exponent (default=2 for inverse square weighting)
    :param k: Maximum number of nearest known points used per target
        (default=5). It is clamped to N, because querying more neighbours than
        exist yields infinite distances and out-of-range indices.
    :return: Interpolated values at target points, shape (M,)
    :raises ValueError: If there are no known points.
    """
    valid_points = np.asarray(valid_points, dtype=float)
    valid_values = np.asarray(valid_values)
    target_points = np.atleast_2d(np.asarray(target_points, dtype=float))

    n_targets = target_points.shape[0]
    if n_targets == 0 or target_points.size == 0:
        # Nothing to interpolate.
        return np.empty(0, dtype=float)

    n_valid = valid_points.shape[0]
    if n_valid == 0:
        raise ValueError("IDW interpolation needs at least one known point.")

    n_neighbors = max(1, min(int(k), n_valid))

    tree = cKDTree(valid_points)  # Build KD-Tree
    distances, idx = tree.query(target_points, k=n_neighbors)

    # query() drops the neighbour axis when k == 1; restore a uniform (M, k) shape.
    distances = np.asarray(distances, dtype=float).reshape(n_targets, n_neighbors)
    idx = np.asarray(idx).reshape(n_targets, n_neighbors)

    # Prevent division by zero by replacing zero distances with a small value
    distances = np.where(distances == 0, 1e-10, distances)

    # Compute normalised weights (inverse distance raised to `power`)
    weights = 1.0 / (distances ** power)
    weights /= weights.sum(axis=1, keepdims=True)

    interpolated_values = np.sum(weights * valid_values[idx], axis=1)
    return interpolated_values

# Fill the NaN cells of each time slice with IDW, using lon/lat coordinates.
for t, _ in enumerate(time_dim):
    data_slice = ssrd_idw.isel(valid_time=t).values

    valid_mask = ~np.isnan(data_slice)

    # Slices with fewer than 10 known cells are reported and left untouched.
    if np.sum(valid_mask) < 10:
        print(f"Skipping time step {t}, valid points: {np.sum(valid_mask)}, IDW interpolation not possible.")
        continue

    # Known cells: coordinates and values as float64.
    valid_x = lon_2d[valid_mask].astype(np.float64)
    valid_y = lat_2d[valid_mask].astype(np.float64)
    valid_values = data_slice[valid_mask].astype(np.float64)

    # Missing cells: coordinates only.
    nan_mask = ~valid_mask
    nan_x = lon_2d[nan_mask].astype(np.float64)
    nan_y = lat_2d[nan_mask].astype(np.float64)

    try:
        nan_points = np.column_stack((nan_x, nan_y))
        valid_points = np.column_stack((valid_x, valid_y))
        data_slice[nan_mask] = idw_interpolation(valid_points, valid_values, nan_points)
    except Exception as e:
        # A failed slice is reported; whatever it holds is still written back below.
        print(f"IDW interpolation failed at time step {t}, error: {e}")

    ssrd_idw[t, :, :] = data_slice

remaining_missing_values_idw = np.isnan(ssrd_idw).sum().item()
print(f"Remaining missing values after interpolation: {remaining_missing_values_idw}")

output_path = "ssrd_idw_filled.nc"
ssrd_idw.to_netcdf(output_path)
print(f"Interpolated data saved to: {output_path}")

In [None]:
# IDW
import xarray as xr

filled_file_path = "ssrd_idw_filled.nc"
filled_ds = xr.open_dataset(filled_file_path)

# NaN-skipping mean and standard deviation over the whole filled field.
filled_ssrd = filled_ds["ssrd"]
mean_value, std_value = (
    filled_ssrd.mean(skipna=True).item(),
    filled_ssrd.std(skipna=True).item(),
)

print("Mean and standard deviation:", mean_value, std_value)

In [None]:
# IDW
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Open the file written by the IDW cell.
file_path = "ssrd_idw_filled.nc"
ds = xr.open_dataset(file_path)

# Figures are written below this folder, which is created on demand.
output_dir = "ssrd_IDW"
os.makedirs(output_dir, exist_ok=True)

# Five-stop linear colour ramp used for every map in this cell.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap",
    ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"],
)

time_steps = [100, 1000, 2000]
n_steps = len(ds["valid_time"])

for time_index in time_steps:
    # Guard clause: indices past the end of the time axis are reported and skipped.
    if time_index >= n_steps:
        print(f"Time step {time_index} is out of range. Skipping.")
        continue

    ssrd_slice = ds["ssrd"].isel(valid_time=time_index)

    fig, ax = plt.subplots(figsize=(8, 5))
    img = ax.imshow(ssrd_slice, cmap=custom_cmap)

    cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
    cbar.set_label("SSRD (W/m²)")

    ax.set(
        title=f"SSRD Map (Time Step {time_index})",
        xlabel="Longitude",
        ylabel="Latitude",
    )

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)

    output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Figure saved: {output_path}")

    plt.show()

In [None]:
# IDW + Nearest Neighbor
import numpy as np
import xarray as xr
from scipy.spatial import cKDTree
from scipy.ndimage import generic_filter

file_path = "data_stream-enda_stepType-accum.nc"
ds = xr.open_dataset(file_path)

# Work on a copy of SSRD in which every cell equal to 0 becomes NaN.
ssrd_idw = ds["ssrd"].copy()
ssrd_idw = ssrd_idw.where(ssrd_idw != 0, np.nan)

# Coordinate vectors as plain NumPy arrays.
time_dim, lat, lon = (
    ds[name].values for name in ("valid_time", "latitude", "longitude")
)

# 2-D coordinate grids, one entry per (lat, lon) cell.
lon_2d, lat_2d = np.meshgrid(lon, lat)

def idw_interpolation(valid_points, valid_values, target_points, power=2, k=5):
    """Perform Inverse Distance Weighting (IDW) interpolation.

    :param valid_points: Known data point coordinates (N, 2)
    :param valid_values: Known data point values (N,)
    :param target_points: Target points for interpolation (M, 2)
    :param power: Distance weighting exponent (default=2)
    :param k: Maximum number of nearest known points used per target
        (default=5). It is clamped to N, because querying more neighbours than
        exist yields infinite distances and out-of-range indices.
    :return: Interpolated values at target points, shape (M,)
    :raises ValueError: If there are no known points.
    """
    valid_points = np.asarray(valid_points, dtype=float)
    valid_values = np.asarray(valid_values)
    target_points = np.atleast_2d(np.asarray(target_points, dtype=float))

    n_targets = target_points.shape[0]
    if n_targets == 0 or target_points.size == 0:
        # Nothing to interpolate.
        return np.empty(0, dtype=float)

    n_valid = valid_points.shape[0]
    if n_valid == 0:
        raise ValueError("IDW interpolation needs at least one known point.")

    n_neighbors = max(1, min(int(k), n_valid))

    tree = cKDTree(valid_points)
    distances, idx = tree.query(target_points, k=n_neighbors)

    # query() drops the neighbour axis when k == 1; restore a uniform (M, k) shape.
    distances = np.asarray(distances, dtype=float).reshape(n_targets, n_neighbors)
    idx = np.asarray(idx).reshape(n_targets, n_neighbors)

    # A target that coincides with a known point would divide by zero.
    distances = np.where(distances == 0, 1e-10, distances)

    weights = 1.0 / (distances ** power)
    weights /= weights.sum(axis=1, keepdims=True)

    interpolated_values = np.sum(weights * valid_values[idx], axis=1)
    return interpolated_values

def nearest_neighbor(values):
    """Window reducer for ``generic_filter`` that fills a NaN centre cell.

    ``generic_filter`` calls this for *every* pixel, so a valid centre must be
    returned unchanged; otherwise observed data would be overwritten by a
    neighbouring value.

    :param values: Flattened filter window (1-D array of floats). The centre
        is taken as the middle element, which is correct for odd-sized
        windows such as ``size=3``.
    :return: The centre value when it is not NaN; otherwise the first non-NaN
        value in the window's flattened order (not necessarily the spatially
        closest one); NaN when the window holds no valid value at all.
    """
    window = np.asarray(values, dtype=float).ravel()
    if window.size == 0:
        return np.nan

    center = window[window.size // 2]
    if not np.isnan(center):
        # Observed cell: keep it as is.
        return center

    valid_values = window[~np.isnan(window)]
    return valid_values[0] if len(valid_values) > 0 else np.nan

# Fill NaN cells per time slice: IDW when at least 5 known cells exist,
# otherwise a 3x3 window fill; fully missing slices are left untouched.
for t in range(len(time_dim)):
    data_slice = ssrd_idw.isel(valid_time=t).values

    nan_mask = np.isnan(data_slice)

    if nan_mask.all():
        print(f"Time step {t} is completely missing. Skipping interpolation.")
        continue

    valid_mask = ~nan_mask

    if np.sum(valid_mask) < 5:
        print(f"Time step {t}, valid points: {np.sum(valid_mask)}, switching to nearest neighbor interpolation.")
        window_fill = generic_filter(data_slice, nearest_neighbor, size=3, mode='nearest')
        # Only write into cells that were missing; the filter output for
        # observed cells must not replace the original values.
        data_slice = np.where(nan_mask, window_fill, data_slice)
        ssrd_idw[t, :, :] = data_slice
        continue

    if not nan_mask.any():
        # Nothing is missing in this slice; skip building the KD-tree.
        continue

    valid_x = lon_2d[valid_mask].astype(np.float64)
    valid_y = lat_2d[valid_mask].astype(np.float64)
    valid_values = data_slice[valid_mask].astype(np.float64)

    nan_x = lon_2d[nan_mask].astype(np.float64)
    nan_y = lat_2d[nan_mask].astype(np.float64)

    try:
        nan_points = np.column_stack((nan_x, nan_y))
        valid_points = np.column_stack((valid_x, valid_y))
        interpolated_values = idw_interpolation(valid_points, valid_values, nan_points)

        data_slice[nan_mask] = interpolated_values

    except Exception as e:
        print(f"IDW interpolation failed at time step {t}, error: {e}")

    ssrd_idw[t, :, :] = data_slice

remaining_missing_values_idw = np.isnan(ssrd_idw).sum().item()
print(f"Remaining missing values after interpolation: {remaining_missing_values_idw}")

output_path = "ssrd_idw_filled2.nc"
ssrd_idw.to_netcdf(output_path)
print(f"Interpolated data saved to: {output_path}")

In [None]:
# IDW + Nearest Neighbor
import xarray as xr

filled_file_path = "ssrd_idw_filled2.nc"
filled_ds = xr.open_dataset(filled_file_path)

# NaN-skipping mean and standard deviation over the whole filled field.
filled_ssrd = filled_ds["ssrd"]
mean_value, std_value = (
    filled_ssrd.mean(skipna=True).item(),
    filled_ssrd.std(skipna=True).item(),
)

print("Mean and standard deviation:", mean_value, std_value)

In [None]:
# IDW + Nearest Neighbor
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Open the file written by the IDW + nearest-neighbour cell.
file_path = "ssrd_idw_filled2.nc"
ds = xr.open_dataset(file_path)

# Figures are written below this folder, which is created on demand.
output_dir = "ssrd_IDW_ne"
os.makedirs(output_dir, exist_ok=True)

# Five-stop linear colour ramp used for every map in this cell.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap",
    ["#3882a4", "#a8d7da", "#acc872", "#fff1d6", "#f3bc42"],
)

time_steps = [100, 1000, 2000]
n_steps = len(ds["valid_time"])

for time_index in time_steps:
    # Guard clause: indices past the end of the time axis are reported and skipped.
    if time_index >= n_steps:
        print(f"Time step {time_index} is out of range. Skipping.")
        continue

    ssrd_slice = ds["ssrd"].isel(valid_time=time_index)

    fig, ax = plt.subplots(figsize=(8, 5))
    img = ax.imshow(ssrd_slice, cmap=custom_cmap)

    cbar = fig.colorbar(img, ax=ax, orientation="horizontal", pad=0.15)
    cbar.set_label("SSRD (W/m²)")

    ax.set(
        title=f"SSRD Map (Time Step {time_index})",
        xlabel="Longitude",
        ylabel="Latitude",
    )

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)

    output_path = os.path.join(output_dir, f"ssrd_timestep_{time_index}.png")
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Figure saved: {output_path}")

    plt.show()

In [None]:
# Final Selection: Linear Interpolation + Nearest Neighbor (Mean Aggregation per Point, excluding NaN)
import xarray as xr
import numpy as np
import rasterio
from rasterio.transform import from_origin

# Average the filled SSRD over time and export it as NetCDF and GeoTIFF.
file_path = "ssrd_linear_filled2.nc"
ds = xr.open_dataset(file_path)

if "ssrd" not in ds.variables:
    raise KeyError("'ssrd' variable not found. Check the dataset!")

# Per-cell mean over the time axis, ignoring NaN.
ssrd_mean = ds["ssrd"].mean(dim="valid_time", skipna=True).rename("ssrd_mean")

# NOTE(review): absolute, machine-specific path; the GeoTIFF below is written
# relative to the working directory instead.
nc_output = "E:/Download/ssrd_mean.nc"
ssrd_mean.to_netcdf(nc_output)
print(f"Mean data saved as NetCDF: {nc_output}")

# Re-open only long enough to verify the variable; closing the handle lets the
# file be overwritten when this cell is re-run.
with xr.open_dataset(nc_output) as ds_check:
    if "ssrd_mean" not in ds_check.variables:
        raise KeyError("'ssrd_mean' variable not saved correctly!")

print(" 'ssrd_mean' successfully stored in NetCDF.")

lat, lon = ds["latitude"].values, ds["longitude"].values

# The transform below anchors the raster at its top-left (north-west) corner,
# so row 0 must be the northernmost row. Flip only when latitude is stored
# south-to-north; north-to-south data is already in raster order.
if lat[0] < lat[-1]:
    lat = lat[::-1]
    ssrd_mean = ssrd_mean.isel(latitude=slice(None, None, -1))

ssrd_mean_data = ssrd_mean.values

output_tif = "ssrd_mean.tif"

pixel_size_x = abs(lon[1] - lon[0])
pixel_size_y = abs(lat[1] - lat[0])
# from_origin() expects the outer corner of the upper-left pixel. Assuming the
# coordinate values are cell centres (TODO confirm for this dataset), the
# corner lies half a pixel west and north of the first centre.
transform = from_origin(
    lon.min() - pixel_size_x / 2,
    lat.max() + pixel_size_y / 2,
    pixel_size_x,
    pixel_size_y,
)

dtype = rasterio.float32 if np.issubdtype(ssrd_mean_data.dtype, np.floating) else rasterio.int32

with rasterio.open(
    output_tif, "w", driver="GTiff",
    height=ssrd_mean_data.shape[0], width=ssrd_mean_data.shape[1],
    count=1, dtype=dtype, crs="EPSG:4326", transform=transform
) as dst:
    dst.write(ssrd_mean_data.astype(dtype), 1)

print(f"Mean raster saved to: {output_tif}")

with rasterio.open(output_tif) as tif:
    print(f"GeoTIFF successfully read. Size: {tif.width} x {tif.height}, Bands: {tif.count}")

# Inspect the GeoTIFF through xarray; the handle is closed afterwards.
with xr.open_dataset(output_tif, engine="rasterio") as ds_tif:
    print("Variables in GeoTIFF:", list(ds_tif.variables))
    if "band_data" in ds_tif.variables:
        print("Warning: GeoTIFF variable name may be 'band_data' instead of 'ssrd_mean'!")