In [None]:
import numpy as np
import xarray as xr
import xesmf as xe

In [None]:
# Open the Oslofjord grid file (per its path). The lat_rho / lon_rho / h
# variables used below suggest a ROMS-style grid — TODO confirm.
# NOTE(review): machine-specific absolute path; consider a configurable DATA_DIR.
ds_in = xr.open_dataset("/home/sia/FjordsSim_data/oslofjord/Grid/OF160_grid_v2.nc")

In [None]:
# Display the source dataset for inspection.
ds_in

In [None]:
# Report the lat/lon extent of the source grid; useful when choosing the
# bounds of the regular target grid.
for label, coord_name in (("Lat", "lat_rho"), ("Lon", "lon_rho")):
    coord_values = getattr(ds_in, coord_name).values
    print(f"{label} min: {coord_values.min()}")
    print(f"{label} max: {coord_values.max()}")

In [None]:
# Regular lat/lon target grid. The output file name at the end of the notebook
# ("OF_inner_100to200") matches N_LON x N_LAT = 100 x 200. The bounds are
# presumably chosen to lie inside the source extent printed above — confirm.
LAT_MIN, LAT_MAX, N_LAT = 59.35, 59.98, 200
LON_MIN, LON_MAX, N_LON = 10.2, 10.85, 100

ds_out = xr.Dataset(
    {
        "lat": (["lat"], np.linspace(LAT_MIN, LAT_MAX, num=N_LAT), {"units": "degrees_north"}),
        "lon": (["lon"], np.linspace(LON_MIN, LON_MAX, num=N_LON), {"units": "degrees_east"}),
    }
)

In [None]:
# Display the target grid for inspection.
ds_out

In [None]:
# xESMF looks up grid coordinates through variables named `lat`/`lon` (the
# names ds_out uses), so expose the source *_rho coordinates under those names.
src_grid = ds_in.rename({"lat_rho": "lat", "lon_rho": "lon"})
# Target points outside the source grid come back as NaN rather than 0.
regridder = xe.Regridder(src_grid, ds_out, method="bilinear", unmapped_to_nan=True)

In [None]:
# Interpolate the source field `h` (bathymetry, per the output file name) onto
# the regular lat/lon grid; target points the regridder cannot map become NaN.
da = regridder(ds_in["h"])

In [None]:
# Source-grid `h`, kept for a before/after comparison with the regridded `da`.
da_h = ds_in["h"]

In [None]:
# Plot `h` on the source grid.
da_h.plot()

In [None]:
# Mask cells no deeper than MIN_DEPTH as NaN. Cells the regridder left unmapped
# are already NaN, and `NaN > MIN_DEPTH` is False, so they stay NaN.
# MIN_DEPTH shares the units of `h` — presumably metres; TODO confirm.
MIN_DEPTH = 5
da = da.where(da > MIN_DEPTH, np.nan)
# NumPy view of the masked field, passed to the cleanup functions below.
np_depth = da.values

In [None]:
# Plot the regridded field after masking shallow cells.
da.plot(figsize=(7, 10))

In [None]:
def replace_surrounded_values(arr):
    """Return a copy of a 2-D array with nearly-isolated cells set to NaN.

    A non-NaN cell is replaced with NaN when at least 3 of its 4 edge-adjacent
    neighbours (top, bottom, left, right) are NaN. All decisions are made
    against the *input* array, so a cell removed here never triggers further
    removals within the same call. Cells on the outer border are never
    examined or modified.

    Parameters
    ----------
    arr : array_like, 2-D
        Gridded values in which NaN marks missing/masked cells.

    Returns
    -------
    numpy.ndarray
        A new array; ``arr`` itself is not modified.

    Raises
    ------
    ValueError
        If ``arr`` is not 2-D.
    """
    new_arr = np.array(arr, copy=True)
    if new_arr.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {new_arr.shape}")

    is_nan = np.isnan(new_arr)
    # Count the NaN neighbours of every interior cell by taking the mask
    # shifted one step in each direction. Border cells lack a full
    # neighbourhood and are skipped, as in the original loop over 1..n-2.
    nan_neighbours = np.stack(
        [
            is_nan[:-2, 1:-1],  # top
            is_nan[2:, 1:-1],   # bottom
            is_nan[1:-1, :-2],  # left
            is_nan[1:-1, 2:],   # right
        ]
    ).sum(axis=0)
    isolated = ~is_nan[1:-1, 1:-1] & (nan_neighbours >= 3)
    # Only write when needed, so NaN-free integer input is returned unchanged
    # instead of failing on a float NaN assignment.
    if isolated.any():
        # Basic slicing yields a view, so the masked write lands in new_arr.
        new_arr[1:-1, 1:-1][isolated] = np.nan
    return new_arr

In [None]:
def _nan_isolated_cells(arr):
    """In place: NaN each interior cell with >= 3 of its 4 edge neighbours NaN."""
    is_nan = np.isnan(arr)
    # Neighbour NaN counts come from the mask as it was on entry, so removals
    # made here do not cascade within this pass.
    nan_neighbours = np.stack(
        [
            is_nan[:-2, 1:-1],  # top
            is_nan[2:, 1:-1],   # bottom
            is_nan[1:-1, :-2],  # left
            is_nan[1:-1, 2:],   # right
        ]
    ).sum(axis=0)
    isolated = ~is_nan[1:-1, 1:-1] & (nan_neighbours >= 3)
    if isolated.any():
        # Basic slicing yields a view, so the masked write lands in arr.
        arr[1:-1, 1:-1][isolated] = np.nan
    return arr


def _nan_short_runs(arr, max_len):
    """In place: NaN every maximal non-NaN run of length <= max_len in each row.

    Iterating a 2-D ndarray yields row views, so passing ``x.T`` processes the
    columns of ``x`` and still writes through to ``x``.
    """
    for line in arr:
        # Pad with False so every run has a detectable start (+1) and an
        # exclusive end (-1) in the diff, including runs touching the edges.
        padded = np.concatenate(([False], ~np.isnan(line), [False]))
        steps = np.diff(padded.astype(np.int8))
        starts = np.flatnonzero(steps == 1)
        ends = np.flatnonzero(steps == -1)
        for start, end in zip(starts, ends):
            if end - start <= max_len:
                line[start:end] = np.nan
    return arr


def replace_surrounded_and_clusters(arr, max_cluster_len=5):
    """Return a cleaned copy of a 2-D array with isolated cells and short runs set to NaN.

    The two passes run in this order:

    1. Any interior non-NaN cell with at least 3 of its 4 edge neighbours NaN
       (judged on the input) becomes NaN. Border cells are not examined.
    2. Every maximal run of consecutive non-NaN cells of length at most
       ``max_cluster_len`` becomes NaN. This runs first along each row, then
       along each column of the row-cleaned result. A maximal run is always
       bounded by NaN or by the array edge, so runs touching the edge count too.

    Parameters
    ----------
    arr : array_like, 2-D
        Gridded values in which NaN marks missing/masked cells.
    max_cluster_len : int, default 5
        Longest run of non-NaN cells that is removed in pass 2. Runs strictly
        longer than this are kept.

    Returns
    -------
    numpy.ndarray
        A new array; ``arr`` itself is not modified.

    Raises
    ------
    ValueError
        If ``arr`` is not 2-D.
    """
    new_arr = np.array(arr, copy=True)
    if new_arr.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {new_arr.shape}")

    _nan_isolated_cells(new_arr)
    _nan_short_runs(new_arr, max_cluster_len)    # horizontal: along each row
    _nan_short_runs(new_arr.T, max_cluster_len)  # vertical: along each column
    return new_arr

In [None]:
# Remove isolated cells and short runs of cells from the masked depth field
# (returns a new array; np_depth is left unchanged).
np_depth_new = replace_surrounded_and_clusters(np_depth)

In [None]:
# Wrap the cleaned values in a new DataArray that keeps da's coordinates,
# name and attributes, instead of overwriting `da.values` in place on an
# object shared with earlier cells.
da = da.copy(data=np_depth_new)

In [None]:
# Plot the cleaned field, using the same figure size as the pre-cleanup plot
# so the two are easy to compare.
da.plot(figsize=(7, 10))

In [None]:
# Display the cleaned DataArray for inspection.
da

In [None]:
# NaN cells (masked as too shallow, unmapped, or removed during cleanup) are
# written as h = 0 in the output file.
bathymetry_out = xr.Dataset({"h": da.fillna(0)})
bathymetry_out.to_netcdf("/home/sia/FjordsSim_data/oslofjord/OF_inner_100to200_bathymetry.nc")