In [None]:
# Status: beta. This code was used to create the land-masked flat grid.nc for SMOS IC v.105 ASC.
# It can be used as a blueprint to create land-masked grid files for future SMOS IC versions.

# CORRECT GRID CODE 1/2

import xarray as xr
import numpy as np

# 1. Open both datasets.
# Load each one fully into memory and close its file handle straight away, so
# the notebook does not keep the source files open after this cell has run.
# NOTE(review): assumes both files are small grid/mask definitions that fit
# comfortably in memory.
with xr.open_dataset('/path_to/grid_ic_Geo2D.nc') as src:  # produced when reshuffling image to ts
    smos_data = src.load()
with xr.open_dataset('/path_to/Grid_Point_Mask_USGS.nc') as src:  # can be found in src/smos
    mask_data = src.load()

# The latitude axis of the mask file is stored in reversed order. Sort it into
# ascending order instead of reversing it unconditionally: for a reversed file
# the result is identical to a plain flip, but a mask file whose latitudes are
# already ascending is left untouched instead of being turned upside down.
mask_data = mask_data.sortby('lat')

# 2. Interpolate the mask onto the SMOS grid (nearest neighbour).
# The regridded mask carries exactly the SMOS lat/lon labels, which avoids
# floating-point mismatches between the two grids' coordinate values.
regridded_mask = mask_data.interp(lat=smos_data.lat, lon=smos_data.lon, method='nearest')

# 3. USGS land flag resampled onto the SMOS grid.
land_mask = regridded_mask['USGS_Land_Flag']

# 4-9. Mask every lat/lon variable to land points, carry over coordinates and
# metadata, then write the 2D land-masked grid with compression.

# Preferred fill value for masked cells of integer variables.
INT_FILL_VALUE = -9999


def _strip_fill_value(attrs):
    """Return a copy of ``attrs`` without a ``_FillValue`` entry.

    Fill values are set through the netCDF ``encoding`` when saving; keeping a
    copy in the attributes as well would conflict with it.
    """
    cleaned = dict(attrs)
    cleaned.pop('_FillValue', None)
    return cleaned


def _integer_fill_value(dtype):
    """Return a fill value representable in the integer ``dtype``.

    ``INT_FILL_VALUE`` is used whenever it fits. Otherwise (unsigned types or
    int8) the dtype's maximum (unsigned) or minimum (signed) value is used, so
    the fill does not wrap around to an ordinary-looking in-range number and
    the value written to the data matches the declared ``_FillValue``.
    """
    dtype = np.dtype(dtype)
    info = np.iinfo(dtype)
    if info.min <= INT_FILL_VALUE <= info.max:
        return dtype.type(INT_FILL_VALUE)
    return dtype.type(info.max if info.min == 0 else info.min)


# 4. Create a new dataset
masked_ds = xr.Dataset()
# Fill value actually used for each masked integer variable; reused for the
# encoding below so data and _FillValue always agree.
int_fill_values = {}

# 5. Copy variables while preserving original data types and units
for var_name, var in smos_data.data_vars.items():
    dtype = var.dtype

    if not ('lat' in var.dims and 'lon' in var.dims):
        # Variables not laid out on the lat/lon grid are copied unchanged.
        masked = var.copy()
    elif np.issubdtype(dtype, np.integer):
        # Integers cannot hold NaN, so non-land cells get an explicit fill
        # value; cast back afterwards because xr.where may promote the dtype.
        fill_value = _integer_fill_value(dtype)
        int_fill_values[var_name] = fill_value
        masked = xr.where(land_mask == 1, var, fill_value).astype(dtype)
    else:
        # Non-integer variables: non-land cells become NaN.
        masked = var.where(land_mask == 1)

    masked_ds[var_name] = masked
    # Keep all attributes (units etc.) except _FillValue, which is handled
    # through the encoding at save time.
    masked_ds[var_name].attrs = _strip_fill_value(var.attrs)

# 6. Copy coordinates not already brought in by the variables above
for coord_name, coord in smos_data.coords.items():
    if coord_name in masked_ds.coords:
        continue
    masked_ds[coord_name] = coord.copy()
    masked_ds[coord_name].attrs = _strip_fill_value(coord.attrs)

# Explicitly set critical lat/lon attributes (hardcoded)
cf_coord_attrs = {
    'lat': {'long_name': "Latitude", 'units': "degree_north", 'standard_name': "latitude"},
    'lon': {'long_name': "Longitude", 'units': "degree_east", 'standard_name': "longitude"},
}
for coord_name, coord_attrs in cf_coord_attrs.items():
    if coord_name in masked_ds.coords:
        masked_ds[coord_name].attrs.update(coord_attrs)

# 7. Copy global attributes and record how the mask was applied
masked_ds.attrs.update(smos_data.attrs)
masked_ds.attrs['masking_applied'] = 'USGS_Land_Flag from Grid_Point_Mask_USGS.nc (interpolated to match grid)'
masked_ds.attrs['masking_description'] = 'Only land grid points (USGS_Land_Flag=1) are included'
masked_ds.attrs['interpolation_method'] = 'Nearest-neighbor interpolation used to regrid mask to SMOS coordinates'

# 8. Check data types and units before saving
print("Variable data types and units before saving:")
for var_name, da in masked_ds.data_vars.items():
    unit_info = f"units: {da.attrs['units']}" if 'units' in da.attrs else "No units attribute"
    print(f"{var_name}: {da.dtype}, {unit_info}")

# 9. Save with compression, pinning each variable's dtype explicitly
encoding = {}
for var_name, da in masked_ds.data_vars.items():
    encoding[var_name] = {'zlib': True, 'complevel': 4, 'dtype': da.dtype}
    if var_name in int_fill_values:
        # Masked integer variables: declare the fill value actually written.
        encoding[var_name]['_FillValue'] = int_fill_values[var_name]

for coord_name in masked_ds.coords:
    encoding[coord_name] = {'dtype': masked_ds[coord_name].dtype, 'zlib': True, 'complevel': 4}

masked_ds.to_netcdf('./smos_ic_land_Geo2D.nc', encoding=encoding)

# 10. Statistics: share of SMOS grid points flagged as land (flag == 1)
is_land = land_mask == 1
total_points = land_mask.size
land_points = is_land.sum().item()
percentage_land = (land_points / total_points) * 100

for line in (
    "Successfully created land-masked NetCDF file: smos_ic_land_Geo2D.nc",
    f"Total grid points: {total_points}",
    f"Land points (USGS_Land_Flag=1): {land_points}",
    f"Percentage of land points: {percentage_land:.2f}%",
):
    print(line)


In [None]:
# CORRECT GRID CODE 2/2 - flatten it!!

import xarray as xr
import numpy as np

# Open the 2D land-masked NetCDF file. Load it into memory and close the file
# handle immediately, so the notebook does not keep the file open (an open
# handle can get in the way of regenerating that file).
with xr.open_dataset('/path_to/smos_ic_land_Geo2D.nc') as src:
    ds = src.load()

# Axis coordinates of the regular grid. Despite the *_2d names these are the
# 1D lat/lon axes; they are expanded to 2D with meshgrid below.
lat_2d = ds.lat.values
lon_2d = ds.lon.values
# Orient gpi/cell as (lat, lon) so they line up with the (n_lat, n_lon)
# meshgrid, whatever dimension order is stored in the file.
gpi_2d = ds.gpi.transpose('lat', 'lon').values
cell_2d = ds.cell.transpose('lat', 'lon').values

# Meshgrid of lat/lon values, each of shape (n_lat, n_lon)
lon_mesh, lat_mesh = np.meshgrid(lon_2d, lat_2d)

# Valid grid points: gpi and cell both present, i.e. neither equal to the
# -9999 integer fill nor NaN (xarray decodes a declared _FillValue to NaN).
gpi_valid = (gpi_2d != -9999) & (~np.isnan(gpi_2d))
cell_valid = (cell_2d != -9999) & (~np.isnan(cell_2d))
valid_mask = gpi_valid & cell_valid

# Check how well the invalid points of gpi and cell coincide
gpi_invalid = ~gpi_valid
cell_invalid = ~cell_valid
overlap_invalid = np.logical_and(gpi_invalid, cell_invalid)
n_gpi_invalid = np.sum(gpi_invalid)
if n_gpi_invalid > 0:
    # Share of invalid-gpi points where cell is invalid as well
    overlap_percentage = np.sum(overlap_invalid) / n_gpi_invalid * 100
    print(f"Invalid value overlap: {overlap_percentage:.2f}%")
else:
    # Every gpi is valid: avoid a 0/0 division (NumPy would warn and the
    # report would read "nan%").
    overlap_percentage = float('nan')
    print("Invalid value overlap: n/a (no invalid GPI points)")

# Count points where only one is invalid but not the other
gpi_only_invalid = np.logical_and(gpi_invalid, ~cell_invalid)
cell_only_invalid = np.logical_and(~gpi_invalid, cell_invalid)
print(f"Points where only GPI is invalid: {np.sum(gpi_only_invalid)}")
print(f"Points where only CELL is invalid: {np.sum(cell_only_invalid)}")

# Flatten: keep only the valid points, in row-major (lat-outer, lon-inner)
# order of the 2D grid.
lat_1d = lat_mesh[valid_mask]
lon_1d = lon_mesh[valid_mask]
# gpi/cell may have been decoded as float (NaN fill); store them as int32
gpi_1d, cell_1d = (grid[valid_mask].astype(np.int32) for grid in (gpi_2d, cell_2d))

# Report how many grid points survived the masking
num_valid_points = len(lat_1d)
total_grid_points = np.prod(gpi_2d.shape)
print(f"Number of valid grid points: {num_valid_points}")
print(f"Original number of grid points: {total_grid_points}")
print(f"Percentage of valid points: {(num_valid_points / total_grid_points) * 100:.2f}%")

# Build the flat (1D) grid dataset; 'gp' is only a dimension and gets no
# coordinate variable of its own.
flat_columns = {'lat': lat_1d, 'lon': lon_1d, 'gpi': gpi_1d, 'cell': cell_1d}
ds_1d = xr.Dataset(
    data_vars={
        **{name: (['gp'], values) for name, values in flat_columns.items()},
        'crs': ds.crs,
    }
)

# Carry over per-variable attributes from the 2D file
for name in ('lat', 'lon', 'gpi', 'cell', 'crs'):
    ds_1d[name].attrs = ds[name].attrs.copy()

# Adapt lat/lon attributes to the flat layout: drop the 2D chunk-size
# attribute and declare the full geographic valid range.
for name, valid_range in (('lat', [-90.0, 90.0]), ('lon', [-180.0, 180.0])):
    axis_attrs = ds_1d[name].attrs
    axis_attrs.pop('_ChunkSizes', None)
    axis_attrs['valid_range'] = valid_range

# Set the _ChunkSizes attribute of every 1D variable to the number of valid
# points. NOTE(review): this is metadata only; the chunk layout actually
# written is governed by the encoding passed to to_netcdf.
for name in ('lat', 'lon', 'gpi', 'cell'):
    ds_1d[name].attrs['_ChunkSizes'] = num_valid_points

# Inherit all global attributes, then override 'shape' with the number of
# grid points stored as a long (64-bit) integer instead of a string.
ds_1d.attrs.update(ds.attrs)
ds_1d.attrs['shape'] = np.int64(num_valid_points)



# Write the flat grid: every variable zlib-compressed, gpi/cell forced to int32
output_path = '/output_path/smos_ic_land_flat.nc'
compression = {'zlib': True, 'complevel': 4}
encoding = {
    name: ({**compression, 'dtype': 'int32'} if name in ('gpi', 'cell') else dict(compression))
    for name in ('lat', 'lon', 'gpi', 'cell', 'crs')
}
ds_1d.to_netcdf(output_path, encoding=encoding)

# Summarise what was written
shape_attr = ds_1d.attrs.get('shape')
print(f"1D grid file saved to: {output_path}")
print(f"Global attributes set with shape as long integer: {shape_attr} (type: {type(shape_attr)})")
print(f"Dimensions: {list(ds_1d.dims)}")
print(f"Coordinates: {list(ds_1d.coords)}")
for label, column in (("GPI", ds_1d.gpi), ("CELL", ds_1d.cell)):
    print(f"{label} data type: {column.dtype}")
