# Notebook to test intake and labeling of unseen data
- will try new surveys to continuously improve model accuracy
- will try new S2 and Planet images to test accuracy on unseen data. Need to get outlines of the NCF channels for these image searches

In [None]:
import os
import pickle
import re

import ee
import geemap
import numpy as np
import pandas as pd
import rasterio

In [None]:
# Initialize the Earth Engine client for this session.
# NOTE(review): `project` is an empty string here — presumably a placeholder
# for the Cloud project ID; confirm and fill in before running.
ee.Initialize(project='')

In [None]:
# Function to normalize an array
def normalize(array):
    """
    Min-max scale an array to the range [0, 1].

    The minimum and maximum are taken with np.nanmin / np.nanmax, so NaN
    entries do not influence the scaling and stay NaN in the result.

    array: array-like
        Values to rescale.

    Returns an array of the same shape. A constant input (max == min) is
    returned as all zeros instead of dividing by zero.
    """
    lowest = np.nanmin(array)
    span = np.nanmax(array) - lowest

    # A constant array has zero span; dividing by it would yield NaN/inf.
    # Dividing by 1 keeps the (all-zero) numerator and the usual float result.
    if span == 0:
        span = 1

    return (array - lowest) / span

def _invalid_pixels(values, nodata):
    """Return a boolean array marking pixels that are NaN or equal to *nodata*."""
    invalid = np.isnan(values)
    # A missing (None) or NaN nodata value adds nothing beyond the NaN test.
    if nodata is not None and not np.isnan(nodata):
        invalid |= values == nodata
    return invalid


def extract_raster_data(pair_tuple):
    """
    Read Sentinel-2 / bathymetry raster pairs and return, for each pair, the
    normalized band values and bathymetry values of the pixels that are valid
    in both rasters.

    pair_tuple: iterable
        Each item is indexable; item[0] is the Sentinel-2 raster path and
        item[1] is the bathymetry raster path.

    Returns a list with one (features, targets) tuple per pair:
        features: array of shape (n_valid_pixels, 4); columns are the
            normalized rasters read from Sentinel-2 bands 3, 2, 1, 4
            (labelled red, green, blue, nir here), in that order.
        targets: 1-D array of the matching band-1 bathymetry values.

    Raises ValueError if the two rasters of a pair differ in shape or
    transform.
    """
    # (label, 1-based band index) in output column order.
    # NOTE(review): create_composite_bands_with_existing treats column 0 as
    # blue and column 1 as green, while column 0 here is labelled red —
    # confirm the band order of the input rasters.
    band_reads = (("red", 3), ("green", 2), ("blue", 1), ("nir", 4))

    images_data = []

    for paths in pair_tuple:
        bathy_path = paths[1]
        s2_path = paths[0]

        # --- Step 1: Open Bathymetry Raster ---
        with rasterio.open(bathy_path) as bathy:
            bathy_data = bathy.read(1)  # Bathymetry data (band 1)
            bathy_nodata = bathy.nodata  # NoData value
            bathy_transform = bathy.transform
            bathy_shape = bathy.shape

        # Start from the pixels that carry a real bathymetry value
        valid_mask = ~_invalid_pixels(bathy_data, bathy_nodata)

        # --- Step 2: Open Sentinel-2 Raster ---
        scaled_bands = []
        with rasterio.open(s2_path) as s2:
            if s2.shape != bathy_shape or s2.transform != bathy_transform:
                raise ValueError(
                    f"Inconsistent shapes or transforms:\n"
                    f"Bathymetry Shape: {bathy_shape}, Sentinel-2 Shape: {s2.shape}.\n"
                    f"Bathymetry Transform: {bathy_transform}, Sentinel-2 Transform: {s2.transform}.\n"
                    f"Ensure rasters have identical extents and resolutions."
                )

            s2_nodata = s2.nodata  # Sentinel-2 NoData value

            for _label, band_index in band_reads:
                raw_band = s2.read(band_index)

                # The NoData test must run on the RAW values: once a band is
                # normalized its NoData pixels no longer equal s2_nodata.
                valid_mask &= ~_invalid_pixels(raw_band, s2_nodata)

                # NOTE(review): normalization still spans the whole band,
                # NoData pixels included, exactly as before — confirm whether
                # NoData should be excluded from the min/max.
                scaled_bands.append(normalize(raw_band))

        # --- Step 3: Keep only the valid pixels ---
        # Boolean indexing of a 2-D array yields the selected values in
        # row-major order, i.e. the same order as flattening first.
        valid_features = np.column_stack([band[valid_mask] for band in scaled_bands])
        valid_bathy = bathy_data[valid_mask]

        # --- Step 4: Store features and targets for this pair ---
        images_data.append((valid_features, valid_bathy))

    return images_data

def survey_name_type(surveynames):
    """
    Extract the NCF channel ID and the survey type from each survey name.

    surveynames: iterable of str
        Survey names to parse.

    Returns (channel_ids, isolated_survey_types):
        channel_ids: for every name that contains an underscore followed by
            8 digits, the text before that underscore, with everything up to
            and including the first "_DIS_" removed (if present).
        isolated_survey_types: for every name that contains one of the known
            survey-type codes as a substring, the first code (in the order of
            the code list below) that is found.

    Names that do not match are skipped, so either list can be shorter than
    `surveynames` and the two lists are not guaranteed to line up.
    """
    surveytypes = ['AD', 'BD', 'CS', 'PA', 'PR', 'XA', 'XB', 'XC', 'OT', 'DS']
    prefix_pattern = re.compile(r'^(.*?)_\d{8}')

    channel_ids = []
    isolated_survey_types = []

    for name in surveynames:
        # Channel ID: prefix before the date stamp, minus any "..._DIS_" lead-in
        prefix_match = prefix_pattern.match(name)
        if prefix_match:
            channel_ids.append(re.sub(r'^.*?_DIS_', '', prefix_match.group(1)))

        # Survey type: first known code that occurs anywhere in the name
        survey_code = next((code for code in surveytypes if code in name), None)
        if survey_code is not None:
            isolated_survey_types.append(survey_code)

    return channel_ids, isolated_survey_types

def create_composite_bands_with_existing(flattened_s2):
    """
    Append three normalized composite bands to a 4-column band array.

    flattened_s2: np.ndarray of shape (n_pixels, 4)
        Column 0 is used as blue and column 1 as green.

    Returns an array of shape (n_pixels, 7): the 4 input columns followed by
    blue/green, green/blue and the log(blue)/log(green) ratio, each passed
    through normalize().

    Raises ValueError if the input does not have exactly 4 columns.
    """
    n_columns = flattened_s2.shape[1]
    if n_columns != 4:
        raise ValueError("Input array must have 4 columns representing B, G, R, NIR bands.")

    b_col, g_col = flattened_s2[:, 0], flattened_s2[:, 1]

    # Pixels excluded by `where` keep the 0 from the pre-filled `out` array;
    # floating-point warnings are silenced for the duration of the block.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio_bg = np.divide(b_col, g_col, out=np.zeros_like(b_col), where=g_col != 0)
        ratio_gb = np.divide(g_col, b_col, out=np.zeros_like(g_col), where=b_col != 0)
        log_ratio = np.divide(
            np.log(b_col + 1e-6),
            np.log(g_col + 1e-6),
            out=np.zeros_like(b_col),
            where=(g_col > 0) & (b_col > 0),
        )

    # Rescale every composite and turn it into a column vector for stacking
    composite_columns = [
        normalize(composite)[:, None]
        for composite in (ratio_bg, ratio_gb, log_ratio)
    ]

    return np.hstack((flattened_s2, *composite_columns))

def get_pixel_positions(raster_path):
    """
    Return the x, y coordinates of the centre of every unmasked pixel in the
    first band of a raster.

    raster_path: str
        Path to a raster readable by rasterio.

    Returns an array of shape (n_valid_pixels, 2) holding (x, y) pairs in
    row-major pixel order. Validity comes from the mask of the masked read of
    band 1; an empty (0, 2) array is returned when no pixel is valid.
    """
    with rasterio.open(raster_path) as src:
        # Affine transformation from pixel indices to map coordinates
        transform = src.transform

        # Masked read of the first band; the mask marks invalid pixels
        band_data = src.read(1, masked=True)

    # getmaskarray always returns a full-shape boolean array, even when the
    # masked array carries only a scalar "nothing masked" mask.
    valid_mask = ~np.ma.getmaskarray(band_data)

    # 1-D row/column indices of the valid pixels, in row-major order. Only
    # these are transformed, and 1-D inputs are accepted by every rasterio
    # version (2-D index grids are not).
    row_indices, col_indices = np.nonzero(valid_mask)
    if row_indices.size == 0:
        return np.empty((0, 2))

    # Compute x, y positions of the pixel centres using the affine transform
    xs, ys = rasterio.transform.xy(transform, row_indices, col_indices, offset='center')

    return np.column_stack((np.asarray(xs), np.asarray(ys)))



In [1]:
# NEED TO INCLUDE PICKLE FILE PATHS FOR ENCODERS
# NEED TO INCLUDE A SIMILAR NAMING SYSTEM AS THE EHYDRO SURVEYS 
# 


def prepare_new_data(surveynames, pickle_files):
    """
    Build one dataframe per survey, in the column layout used as model input,
    from the rasters named `<NEW_DIR>/<surveyname>.tif`.

    surveynames: list
        list of strings; each is a raster file name without the `.tif`
        extension and is also parsed for the channel ID and survey type.

    pickle_files: list
        list of strings. Currently unused: the label encoders are read from
        `<WORK_DIR>/<column>_label_encoders.pkl`.

    Returns a dict mapping each survey name that produced at least one valid
    pixel to its dataframe (columns: Blue, Green, Red, NIR, Blue/Green,
    Green/Blue, Stumpf, X, Y, Channel_Name, Survey_Type, Bathymetry).

    Raises ValueError if a channel ID or survey type could not be extracted
    for every retained survey name.

    Relies on the module-level names NEW_DIR and WORK_DIR being defined.
    """
    images = [os.path.join(NEW_DIR, f'{name}.tif') for name in surveynames]

    # NOTE(review): extract_raster_data indexes every item as item[0]
    # (Sentinel-2 path) and item[1] (bathymetry path), but `images` holds
    # plain path strings — TODO confirm where the bathymetry paths come from
    # and pass (s2_path, bathy_path) pairs here.
    images_data = extract_raster_data(images)

    # Keep only the surveys that yielded at least one valid pixel, carrying
    # the raster path along so pixel positions can be read from the file.
    goodnames = []
    good_paths = []
    good_pairs = []
    for name, image_path, pair in zip(surveynames, images, images_data):
        if pair[0].shape[0] != 0:
            goodnames.append(name)
            good_paths.append(image_path)
            good_pairs.append(pair)

    if not goodnames:
        return {}

    ncf_channels, survey_types = survey_name_type(goodnames)

    # survey_name_type silently skips names it cannot parse; a shorter list
    # would attach channel IDs / survey types to the wrong survey.
    if len(ncf_channels) != len(goodnames) or len(survey_types) != len(goodnames):
        raise ValueError(
            "Could not extract a channel ID and survey type for every survey name."
        )

    all_bands = [create_composite_bands_with_existing(pair[0]) for pair in good_pairs]

    # get_pixel_positions needs a raster path, not the extracted feature array.
    # NOTE(review): positions follow the raster's own band-1 mask; confirm the
    # count matches the pixels kept by extract_raster_data, otherwise the
    # dataframe construction below fails on unequal column lengths.
    pixel_positions = [get_pixel_positions(image_path) for image_path in good_paths]

    # Load each label encoder once rather than once per survey.
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # encoder files that this project produced itself.
    encoders = {}
    for col in ('Channel_Name', 'Survey_Type'):
        with open(os.path.join(WORK_DIR, f'{col}_label_encoders.pkl'), 'rb') as pkl_file:
            encoders[col] = pickle.load(pkl_file)

    data = {}

    # Iterate through the names and corresponding data
    for i, name in enumerate(goodnames):
        bands = all_bands[i]               # Shape (n_pixels, 7)
        positions = pixel_positions[i]     # Shape (n_pixels, 2)
        bathymetry = good_pairs[i][1]      # Shape (n_pixels,)

        frame = pd.DataFrame({
            "Blue": bands[:, 0],
            "Green": bands[:, 1],
            "Red": bands[:, 2],
            "NIR": bands[:, 3],
            "Blue/Green": bands[:, 4],
            "Green/Blue": bands[:, 5],
            "Stumpf": bands[:, 6],
            "X": positions[:, 0],
            "Y": positions[:, 1],
            "Channel_Name": [ncf_channels[i]] * len(bands),  # Same value on every row
            "Survey_Type": [survey_types[i]] * len(bands),   # Same value on every row
            "Bathymetry": bathymetry
        })

        # Replace the string labels with their encoded values
        for col, encoder in encoders.items():
            frame[col] = encoder.transform(frame[col])

        data[name] = frame

    return data

In [None]:
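# Path to the new S2 rasters to label.
# TODO: the assignments below are unfinished placeholders — fill in NEW_DIR,
# surveynames and filepaths before running this cell.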

NEW_DIR = 

surveynames =
filepaths = 