# Use this notebook to extract some of the data needed for training the model. It assumes the bathy data has already been downloaded with 01_get_data.ipynb and the imagery with 01b_get_s2_SAFE.ipynb.

In [None]:
import os
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
import numpy as np
import re
import shutil
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt

# Functions

In [None]:
def mask_acolite_images(rhow_paths, rrs_paths, spm_tur_paths):
    """Write NaN-masked float32 copies of merged ACOLITE GeoTIFFs.

    For every input raster, the maximum value of its first band is taken as
    the fill value, and each pixel equal to that value (in any band) is
    replaced with NaN. The result is written to the same path with 'merged'
    replaced by 'masked', creating the directory if needed.

    NOTE(review): this assumes the fill value is the largest value present in
    band 1 — confirm against the ACOLITE output.

    Args:
        rhow_paths: Paths of the rhow rasters to mask.
        rrs_paths: Paths of the Rrs rasters to mask.
        spm_tur_paths: Paths of the TUR/SPM rasters to mask.

    Returns:
        Tuple ``(masked_rhow_files, masked_rrs_files, masked_spm_files)``;
        each element is the list of written file paths, in input order.
    """
    def mask_image(image_path):
        """Return (masked bands, dataset meta, band descriptions) for one raster."""
        with rasterio.open(image_path) as src:
            meta = src.meta.copy()
            # Keep the descriptions out of `meta`: everything in `meta` is
            # forwarded to rasterio.open() when writing, where unknown keys
            # are treated as driver creation options.
            descriptions = src.descriptions
            bands = [src.read(i) for i in range(1, src.count + 1)]
        max_val = bands[0].max()
        masked_bands = [np.where(band == max_val, np.nan, band) for band in bands]
        return masked_bands, meta, descriptions

    def process_and_write(paths):
        """Mask each raster in `paths` and return the list of output paths."""
        masked_files = []
        for path in paths:
            masked_bands, meta, descriptions = mask_image(path)
            # The masked pixels are NaN, so advertise NaN as the nodata value
            # instead of keeping whatever the source file declared.
            meta.update(count=len(masked_bands), dtype='float32', nodata=np.nan)
            mask_path = path.replace('merged', 'masked')
            os.makedirs(os.path.dirname(mask_path), exist_ok=True)
            with rasterio.open(mask_path, 'w', **meta) as dst:
                for i, band in enumerate(masked_bands, start=1):
                    dst.write(band.astype(np.float32), i)
                    # Reassign original band description if available
                    if descriptions and len(descriptions) >= i:
                        dst.set_band_description(i, descriptions[i - 1])
            masked_files.append(mask_path)
        return masked_files

    masked_rhow_files = process_and_write(rhow_paths)
    masked_rrs_files = process_and_write(rrs_paths)
    masked_spm_files = process_and_write(spm_tur_paths)

    return masked_rhow_files, masked_rrs_files, masked_spm_files

def reproject_acolite(bathy_raster, rhow_results, rrs_results, spm_tur_results):
    """Reproject masked ACOLITE GeoTIFFs onto the grid of a bathymetry raster.

    Every band of every input raster is warped with bilinear resampling to
    the CRS, transform, width and height of `bathy_raster`. NaN is declared
    as the nodata value on both sides of the warp, so masked pixels are not
    used for interpolation and pixels outside the source footprint stay NaN
    instead of being filled with 0. Each result is written as float32 to the
    same path with 'masked' replaced by 'projected'.

    Args:
        bathy_raster: Path of the raster that defines the target grid.
        rhow_results: Paths of the masked rhow rasters.
        rrs_results: Paths of the masked Rrs rasters.
        spm_tur_results: Paths of the masked TUR/SPM rasters.

    Returns:
        Tuple ``(projected_rhow, projected_rrs, projected_spm_tur)``; each
        element is the list of written file paths, in input order.
    """
    # Open bathy_raster to get the target CRS, transform, and grid dimensions
    with rasterio.open(bathy_raster) as bathy_src:
        bathy_crs = bathy_src.crs
        dst_transform = bathy_src.transform
        dst_width = bathy_src.width
        dst_height = bathy_src.height

    def reproject_and_save(paths):
        """Warp each raster in `paths` and return the list of output paths."""
        projected_files = []
        for path in paths:
            # Open the masked file to obtain its data and metadata
            with rasterio.open(path) as src:
                meta = src.meta.copy()
                # Kept separate from `meta` so it is not forwarded to
                # rasterio.open() as a driver creation option.
                descriptions = src.descriptions
                bands = [src.read(i) for i in range(1, src.count + 1)]
                src_transform = src.transform
                src_crs = src.crs

            new_bands = []
            for band in bands:
                # Pre-fill with NaN rather than leaving the buffer uninitialised.
                dest = np.full((dst_height, dst_width), np.nan, dtype=np.float32)
                reproject(
                    source=band.astype(np.float32),
                    destination=dest,
                    src_transform=src_transform,
                    src_crs=src_crs,
                    src_nodata=np.nan,
                    dst_transform=dst_transform,
                    dst_crs=bathy_crs,
                    dst_nodata=np.nan,
                    resampling=Resampling.bilinear,
                )
                new_bands.append(dest)

            # Prepare new metadata using target parameters from the bathymetry raster
            new_meta = meta.copy()
            new_meta.update({
                "crs": bathy_crs,
                "transform": dst_transform,
                "width": dst_width,
                "height": dst_height,
                "dtype": "float32",
                "nodata": np.nan,
                "count": len(new_bands),
            })

            proj_path = path.replace('masked', 'projected')
            os.makedirs(os.path.dirname(proj_path), exist_ok=True)

            # Write the reprojected bands to the new file
            with rasterio.open(proj_path, 'w', **new_meta) as dst:
                for i, band in enumerate(new_bands, start=1):
                    dst.write(band, i)
                    if descriptions and len(descriptions) >= i:
                        dst.set_band_description(i, descriptions[i - 1])
            projected_files.append(proj_path)
        return projected_files

    projected_rhow = reproject_and_save(rhow_results)
    projected_rrs = reproject_and_save(rrs_results)
    projected_spm_tur = reproject_and_save(spm_tur_results)

    return projected_rhow, projected_rrs, projected_spm_tur

########################################################################################################################################################################

def resample_bathy_to_sentinel2(bathy_raster, sentinel_raster, resampled_bathy):
    """Resample band 1 of a bathymetry raster onto a Sentinel-2 raster's grid.

    The target CRS, transform, width and height are read from
    `sentinel_raster`. Band 1 of `bathy_raster` is warped to that grid with
    bilinear resampling and written to `resampled_bathy`, keeping the
    bathymetry raster's own dtype and remaining metadata.

    Args:
        bathy_raster: Path of the bathymetry raster to resample.
        sentinel_raster: Path of the raster that defines the target grid.
        resampled_bathy: Path of the output raster.
    """
    # Target grid definition, taken from the Sentinel-2 raster
    with rasterio.open(sentinel_raster) as s2_src:
        target_grid = {
            "transform": s2_src.transform,
            "crs": s2_src.crs,
            "width": s2_src.width,
            "height": s2_src.height,
        }

    with rasterio.open(bathy_raster) as bathy_src:
        bathy_dtype = bathy_src.dtypes[0]

        # Output metadata: bathymetry metadata overridden by the target grid
        out_meta = bathy_src.meta.copy()
        out_meta.update(target_grid)
        out_meta["dtype"] = bathy_dtype

        resampled = np.empty(
            (target_grid["height"], target_grid["width"]), dtype=bathy_dtype
        )

        # Warp while the bathymetry dataset is still open: the source is a
        # band handle, not an in-memory array.
        reproject(
            source=rasterio.band(bathy_src, 1),
            destination=resampled,
            src_transform=bathy_src.transform,
            src_crs=bathy_src.crs,
            dst_transform=target_grid["transform"],
            dst_crs=target_grid["crs"],
            dst_width=target_grid["width"],
            dst_height=target_grid["height"],
            resampling=Resampling.bilinear,
        )

    with rasterio.open(resampled_bathy, "w", **out_meta) as dst:
        dst.write(resampled, 1)

    print(f"Resampled bathymetry raster saved to: {resampled_bathy}")

def clip_sentinel_by_bathy(bathy_raster, sentinel_raster, output_sentinel):
    """Blank out Sentinel-2 pixels that have no valid bathymetry underneath.

    A pixel is kept when band 1 of `bathy_raster` is not NaN and band 1 of
    `sentinel_raster` is not 0. Every other pixel is set to NaN in all
    Sentinel-2 bands. The result is written to `output_sentinel` as float32
    with NaN as the nodata value.

    Args:
        bathy_raster: Path of the bathymetry raster (band 1 is used).
        sentinel_raster: Path of the Sentinel-2 raster on the same grid.
        output_sentinel: Path of the output raster.

    Raises:
        ValueError: If the two rasters differ in CRS, transform or
            dimensions.
    """
    # Open the bathymetry raster
    with rasterio.open(bathy_raster) as bathy_src:
        bathy_data = bathy_src.read(1)
        valid_bathy_mask = ~np.isnan(bathy_data)  # Non-NaN bathymetry pixels are valid
        bathy_transform = bathy_src.transform
        bathy_crs = bathy_src.crs
        bathy_shape = bathy_src.shape

    # Open the Sentinel-2 raster
    with rasterio.open(sentinel_raster) as sentinel_src:
        # Two rasters can share a CRS and origin yet differ in extent, so the
        # dimensions are checked too; otherwise the mask combination below
        # fails with an unrelated broadcasting error.
        if (
            sentinel_src.crs != bathy_crs
            or sentinel_src.transform != bathy_transform
            or sentinel_src.shape != bathy_shape
        ):
            raise ValueError("Sentinel-2 raster must already be aligned with bathymetry raster.")

        sentinel_data = sentinel_src.read()  # All bands: (bands, rows, cols)

        # Sentinel-2 pixels equal to 0 in the first band are treated as invalid
        valid_sentinel_mask = sentinel_data[0, :, :] != 0

        # Combine masks (valid bathymetry AND valid Sentinel-2)
        combined_mask = valid_bathy_mask & valid_sentinel_mask

        # Cast once to the output dtype and blank the rejected pixels in
        # place, avoiding a float64 copy of the whole stack.
        clipped_sentinel_data = sentinel_data.astype(np.float32)
        clipped_sentinel_data[:, ~combined_mask] = np.nan

        # Update metadata
        sentinel_meta = sentinel_src.meta.copy()
        sentinel_meta.update({
            "dtype": "float32",
            "nodata": np.nan
        })

    # Save the clipped Sentinel-2 raster
    with rasterio.open(output_sentinel, "w", **sentinel_meta) as dst:
        dst.write(clipped_sentinel_data)

    print(f"Clipped Sentinel-2 raster saved to: {output_sentinel}")

def clip_bathy_by_sentinel(bathy_raster, clipped_sentinel_raster, output_bathy):
    """Blank out bathymetry pixels where the clipped Sentinel-2 raster is NaN.

    Band 1 of `clipped_sentinel_raster` defines validity: wherever it is NaN,
    band 1 of `bathy_raster` is set to NaN. The result is written to
    `output_bathy` as float32 with NaN as the nodata value.

    Args:
        bathy_raster: Path of the bathymetry raster (band 1 is used).
        clipped_sentinel_raster: Path of the already-clipped Sentinel-2
            raster on the same grid.
        output_bathy: Path of the output raster.

    Raises:
        ValueError: If the two rasters differ in CRS, transform or
            dimensions.
    """
    # Open the clipped Sentinel-2 raster to create a valid mask
    with rasterio.open(clipped_sentinel_raster) as sentinel_src:
        sentinel_data = sentinel_src.read(1)  # Only the first band drives the mask
        valid_sentinel_mask = ~np.isnan(sentinel_data)  # Non-NaN pixels are valid
        sentinel_transform = sentinel_src.transform
        sentinel_crs = sentinel_src.crs
        sentinel_shape = sentinel_src.shape

    # Open the bathymetry raster
    with rasterio.open(bathy_raster) as bathy_src:
        # Dimensions are part of the alignment check: equal CRS and transform
        # do not guarantee equal extent, and a mismatch would otherwise show
        # up as an indexing error when the mask is applied.
        if (
            bathy_src.crs != sentinel_crs
            or bathy_src.transform != sentinel_transform
            or bathy_src.shape != sentinel_shape
        ):
            raise ValueError("Bathymetry raster must already be aligned with Sentinel-2 raster.")

        bathy_data = bathy_src.read(1)

        # Cast once to the output dtype and blank the rejected pixels in place
        clipped_bathy_data = bathy_data.astype(np.float32)
        clipped_bathy_data[~valid_sentinel_mask] = np.nan

        # Update metadata
        bathy_meta = bathy_src.meta.copy()
        bathy_meta.update({
            "dtype": "float32",
            "nodata": np.nan
        })

    # Save the clipped bathymetry raster
    with rasterio.open(output_bathy, "w", **bathy_meta) as dst:
        dst.write(clipped_bathy_data, 1)

    print(f"Clipped bathymetry raster saved to: {output_bathy}")

# Establish working directories

In [None]:
# Input locations: merged ACOLITE products and the bathymetry rasters
S2_PATH = '/media/clay/Crucial/SDB/CESWG/merged_acolite'
BATHY_PATH = '/media/clay/Crucial/SDB/CESWG/bathy_rasters'

# Intermediate locations, derived from the input paths by substring replacement
S2_MASK = S2_PATH.replace('merged', 'masked')
S2_PROJ = S2_PATH.replace('merged', 'projected')
BATHY_PROJ = BATHY_PATH.replace('rasters', 'proj')

# Final, fully processed outputs
FINAL_PATH = '/home/clay/Documents/SDB/CESWG/processed'
S2_FINAL = os.path.join(FINAL_PATH, 'S2')
BATHY_FINAL = os.path.join(FINAL_PATH, 'Bathy')

# Create every output directory up front; existing ones are left alone
for _workdir in (S2_MASK, S2_PROJ, BATHY_PROJ, FINAL_PATH, S2_FINAL, BATHY_FINAL):
    os.makedirs(_workdir, exist_ok=True)

In [None]:
# Survey names are the bathymetry GeoTIFF file names without the '.tif' suffix
surveynames = [os.path.splitext(tif)[0] for tif in os.listdir(BATHY_PATH) if tif.endswith('.tif')]

# Map each survey that has an ACOLITE folder to
# [bathymetry tif, rhow tifs, Rrs tifs, TUR_SPM tifs]
surveyinfo = {}
for f in surveynames:
    acolite_path = os.path.join(S2_PATH, f)
    if not os.path.exists(acolite_path):
        # No ACOLITE output for this survey: leave it out of surveyinfo
        continue

    tif_names = [file for file in os.listdir(acolite_path) if file.endswith('.tif')]
    # Products are told apart by a tag in the file name
    by_product = {
        tag: [os.path.join(acolite_path, file) for file in tif_names if tag in file]
        for tag in ('rhow', 'Rrs', 'TUR_SPM')
    }

    surveyinfo[f] = [
        os.path.join(BATHY_PATH, f"{f}.tif"),
        by_product['rhow'],
        by_product['Rrs'],
        by_product['TUR_SPM'],
    ]

# Adjust the ACOLITE tif NaN values

In [None]:
# Mask the ACOLITE products of every survey; each value is the tuple
# (masked rhow paths, masked Rrs paths, masked TUR_SPM paths).
masked_surveyinfo = {}

for survey, (_bathy_tif, rhow_tifs, rrs_tifs, spm_tur_tifs) in surveyinfo.items():
    masked_surveyinfo[survey] = mask_acolite_images(rhow_tifs, rrs_tifs, spm_tur_tifs)

In [None]:
# Visualize the masked results alongside the eHydro survey
i = 300

rhow_paths, rrs_paths, spm_tur_paths = masked_surveyinfo[surveynames[i]]
hydro_tif = surveyinfo[surveynames[i]][0]

# masked_surveyinfo holds file paths, not arrays, so read the first masked
# rhow raster back from disk to get its bands and band descriptions.
with rasterio.open(rhow_paths[0]) as src:
    masked_bands = [src.read(b) for b in range(1, src.count + 1)]
    band_descriptions = src.descriptions

# One panel per band plus one for the bathymetry, four panels per row.
# With 11 bands this gives a 3 x 4 grid.
n_panels = len(masked_bands) + 1
n_cols = 4
n_rows = -(-n_panels // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.ravel()

# Plot the masked bands
for idx, band in enumerate(masked_bands):
    im = axes[idx].imshow(band, cmap='viridis')
    if band_descriptions is not None and idx < len(band_descriptions) and band_descriptions[idx]:
        axes[idx].set_title(band_descriptions[idx])
    else:
        axes[idx].set_title(f"Band {idx + 1}")
    plt.colorbar(im, ax=axes[idx])

# Plot the hydro_tif in the panel right after the last band
bathy_ax = axes[len(masked_bands)]
with rasterio.open(hydro_tif) as src:
    hydro_band = src.read(1)
im_hydro = bathy_ax.imshow(hydro_band, cmap='viridis')
bathy_ax.set_title("Bathy")
plt.colorbar(im_hydro, ax=bathy_ax)

# Hide any panels left over in the last row
for unused_ax in axes[n_panels:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

# Remove the ACOLITE tifs that have no data (only NaNs)

# Reproject ACOLITE tif products to the eHydro CRS

In [None]:
# Reproject the masked ACOLITE products of every survey onto that survey's
# bathymetry grid; each value is the tuple
# (projected rhow paths, projected Rrs paths, projected TUR_SPM paths).
reprojected_surveyinfo = {}

for survey, (masked_rhow, masked_rrs, masked_spm_tur) in masked_surveyinfo.items():
    bathy_tif = surveyinfo[survey][0]
    reprojected_surveyinfo[survey] = reproject_acolite(
        bathy_tif, masked_rhow, masked_rrs, masked_spm_tur
    )

# Maybe get a median or mosaic raster here from the surveys that have multiple S2 images
- need to also mask out new NaNs introduced during the reprojection

# Resample bathy rasters from 10 ft resolution to same resolution as S2 rasters
- I think I resampled these to 32.8084 ft (10-meter) already, but can check just in case
- will just get the resolution from the metadata of the resampled acolite images

In [None]:
# Resample each survey's bathymetry raster onto the grid of its reprojected
# Sentinel-2 raster and store the result under BATHY_PROJ.
# NOTE(review): this expects a single file S2_PROJ/<name>.tif per survey;
# confirm that such a file is produced by an earlier step.
for name, rasters in surveyinfo.items():
    source_bathy = rasters[0]
    reprojected_s2 = os.path.join(S2_PROJ, f"{name}.tif")
    reprojected_bathy = os.path.join(BATHY_PROJ, f"{name}.tif")

    resample_bathy_to_sentinel2(source_bathy, reprojected_s2, reprojected_bathy)

In [None]:
# New dictionary with rasters of matching CRS and spatial resolution (~10 meters):
# survey name -> [resampled bathymetry path, reprojected Sentinel-2 path]
#
# The paths are built directly from BATHY_PROJ / S2_PROJ, the same locations
# the bathymetry resampling step reads and writes. surveyinfo values are
# [bathy tif, rhow paths, Rrs paths, TUR_SPM paths], so searching them for an
# 's2_rasters' substring never finds a Sentinel-2 path.
# NOTE(review): confirm that S2_PROJ/<name>.tif is the intended Sentinel-2
# raster for each survey.
reprojected_rasters = {}
for name in surveyinfo:
    bathypath = os.path.join(BATHY_PROJ, f"{name}.tif")
    s2path = os.path.join(S2_PROJ, f"{name}.tif")
    reprojected_rasters[name] = [bathypath, s2path]

# Clip the S2 rasters by the valid bathymetry pixels
- need bounds of non np.nan pixels for clipping

In [None]:
# Write a bathymetry-masked copy of every aligned Sentinel-2 raster to S2_FINAL
for name, (bathy_tif, s2_tif) in reprojected_rasters.items():
    reprojected_s2 = os.path.join(S2_FINAL, f"{name}.tif")
    clip_sentinel_by_bathy(bathy_tif, s2_tif, reprojected_s2)

# Clip the bathymetry rasters by the valid cloud-masked S2 pixels
- need bounds of valid pixels, seems like these will be values above 0.0 since no nan-value is applied in GEE

In [None]:
# Mask every aligned bathymetry raster by its final (clipped) Sentinel-2
# raster and write the result to BATHY_FINAL
for name, rasters in reprojected_rasters.items():
    aligned_bathy = rasters[0]
    final_s2 = os.path.join(S2_FINAL, f"{name}.tif")
    final_bathy = os.path.join(BATHY_FINAL, f"{name}.tif")

    clip_bathy_by_sentinel(aligned_bathy, final_s2, final_bathy)

The needed preprocessing should now be done, so the next step is training the model on the data
- will try traditional ML regression models, as well as CNN
- May try majority voting of multiple training set models like in Tan et al. 2022

# Thinking that exporting the data to a .csv or parquet can alleviate some of my memory issues. Going to try it

In [None]:

def prepare_train_data(surveynames):
    """Build a per-pixel training table from raster pairs and split it 70/15/15.

    For every survey name, the pair ``(S2_PATH/<name>.tif,
    BATHY_PATH/<name>.tif)`` is considered, and pairs whose first raster has
    zero rows are skipped. The remaining pairs are turned into one DataFrame
    per survey (nine feature columns, X/Y pixel positions, channel name and
    bathymetry), concatenated, and the channel name is label-encoded. The
    fitted encoder is pickled to
    ``WORK_DIR/Channel_Name_label_encoders.pkl``.

    NOTE(review): relies on `extract_raster_data`, `survey_name_type`,
    `get_pixel_positions`, `LabelEncoder`, `train_test_split`, `pickle` and
    `WORK_DIR`, none of which are defined or imported in the visible part of
    this notebook — confirm they are available before calling.

    Args:
        surveynames: Iterable of survey names (raster file stems).

    Returns:
        Tuple ``(combined_df, X_train, y_train, X_test, y_test, X_val,
        y_val)``. `combined_df` still contains the 'Bathymetry' target
        column; the X frames do not.

    Raises:
        ValueError: If no usable raster pair is found.
    """
    pairs = [
        (os.path.join(S2_PATH, f'{name}.tif'), os.path.join(BATHY_PATH, f'{name}.tif'))
        for name in surveynames
    ]

    good_pairs = []
    goodnames = []
    for name, pair in zip(surveynames, pairs):
        with rasterio.open(pair[0]) as src:
            # The dataset height equals the row count of band 1, so the
            # emptiness check does not need to load any pixels.
            if src.height != 0:
                good_pairs.append(pair)
                goodnames.append(name)

    if not goodnames:
        # pd.concat would raise a ValueError on an empty list anyway; fail
        # earlier with a message that names the actual problem.
        raise ValueError("No usable Sentinel-2/bathymetry raster pairs found for the given survey names.")

    images_data = [extract_raster_data(pair) for pair in good_pairs]
    ncf_channels, _survey_types = survey_name_type(goodnames)
    pixel_positions = [get_pixel_positions(pair[0]) for pair in good_pairs]

    # Column i of the feature array is stored under feature_columns[i]
    feature_columns = (
        "Blue", "Green", "Red", "NIR",
        "Blue/Green", "Green/Blue", "Stumpf", "NSMI", "TI",
    )

    data = {}
    for i, name in enumerate(goodnames):
        bands = images_data[i]
        # presumably bands[0] is a (features, bathymetry) pair with features
        # shaped (n_pixels, 9) — TODO confirm against extract_raster_data
        features = bands[0][0]
        bathymetry = bands[0][1]
        positions = pixel_positions[i]     # Indexed as (n_pixels, 2): X, Y

        columns = {col: features[:, j] for j, col in enumerate(feature_columns)}
        columns["X"] = positions[:, 0]
        columns["Y"] = positions[:, 1]
        columns["Channel_Name"] = [ncf_channels[i]] * len(features)
        columns["Bathymetry"] = bathymetry
        data[name] = pd.DataFrame(columns)

    combined_df = pd.concat(data.values(), ignore_index=True)

    encoder = LabelEncoder()
    combined_df['Channel_Name_Encoded'] = encoder.fit_transform(combined_df['Channel_Name'])

    # Context manager so the file is closed even if pickling fails
    with open(os.path.join(WORK_DIR, 'Channel_Name_label_encoders.pkl'), 'wb') as output:
        pickle.dump(encoder, output)

    # Drop original categorical columns
    combined_df.drop(columns=['Channel_Name'], inplace=True)

    X = combined_df.drop(columns=['Bathymetry'])
    y = combined_df['Bathymetry']

    # 70% train, 30% held out
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)

    # Split temp into validation (15%) and test (15%)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

    return combined_df, X_train, y_train, X_test, y_test, X_val, y_val