In [None]:
import numpy as np
from scipy.stats import mode
import rasterio
import os
from osgeo import gdal, osr
from shapely.geometry import box
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
from shapely.geometry import box
from datetime import datetime
import cv2
import copy
from sklearn.decomposition import PCA
from skimage.filters import threshold_otsu
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

%matplotlib widget 
%matplotlib inline

# Functions

In [None]:
def get_labels(labelpath):
    """Collect the per-method classification GeoTIFFs stored under ``labelpath``.

    ``labelpath`` must contain the sub-folders ``otsu``, ``kmeans``, ``gmm`` and
    ``majority``. For each one, the ``.tif`` files are listed and ordered by
    the ``YYYY-MM-DD`` date that ends every file name (just before ``.tif``).

    Returns
    -------
    tuple of list of str
        ``(otsu_ims, kmeans_ims, gmm_ims, majority_ims)``.
    """
    methods = ('otsu', 'kmeans', 'gmm', 'majority')

    def by_date(path):
        # File names are expected to end in 'YYYY-MM-DD.tif'
        return datetime.strptime(path[-14:-4], '%Y-%m-%d')

    # List every folder first, then sort each listing chronologically
    listings = {
        method: [os.path.join(labelpath, method + '/' + name)
                 for name in os.listdir(os.path.join(labelpath, method))
                 if name.endswith('.tif')]
        for method in methods
    }
    otsu_ims, kmeans_ims, gmm_ims, majority_ims = (
        sorted(listings[method], key=by_date) for method in methods
    )
    return otsu_ims, kmeans_ims, gmm_ims, majority_ims

def get_grd(grdpath):
    """Return the ``.tif`` files directly inside ``grdpath``, oldest first.

    The ordering key is the ``YYYY-MM-DD`` date that must sit immediately
    before the ``.tif`` extension of every file name.
    """
    tif_names = (name for name in os.listdir(grdpath) if name.endswith('.tif'))
    tif_paths = [os.path.join(grdpath, name) for name in tif_names]
    tif_paths.sort(key=lambda p: datetime.strptime(p[-14:-4], '%Y-%m-%d'))
    return tif_paths

def get_glcm(glcmpath):
    """List the GLCM texture ``.tif`` files in ``glcmpath`` in date order.

    Each file name must end with ``YYYY-MM-DD.tif``; that date is the sort key.
    """
    def _acquisition_date(path):
        return datetime.strptime(path[-14:-4], '%Y-%m-%d')

    orig_glcms = []
    for entry in os.listdir(glcmpath):
        if entry.endswith('.tif'):
            orig_glcms.append(os.path.join(glcmpath, entry))
    return sorted(orig_glcms, key=_acquisition_date)

# Clip each Sentinel-1 image
def clip_sentinel1_image(s1_path, output_path, s2_bounds):
    """Clip a Sentinel-1 raster to a bounding box and write it to ``output_path``.

    Parameters
    ----------
    s1_path : str
        Path of the raster to clip.
    output_path : str
        Destination path for the clipped raster (written with the source's
        driver/dtype/CRS metadata).
    s2_bounds : sequence of float
        ``(left, bottom, right, top)`` in the CRS of ``s1_path``.
    """
    with rasterio.open(s1_path) as src:
        # Window covering the bounding box; from_bounds can return fractional
        # offsets and lengths.
        window = rasterio.windows.from_bounds(*s2_bounds, transform=src.transform)

        # Read the clipped pixels: (bands, rows, cols)
        clipped_image = src.read(window=window)

        # Take height/width from the array actually read instead of the
        # (possibly fractional) window, so the metadata always matches the data.
        out_meta = src.meta.copy()
        out_meta.update({
            'height': clipped_image.shape[1],
            'width': clipped_image.shape[2],
            'transform': src.window_transform(window)
        })

    # Write the clipped image to a new file
    with rasterio.open(output_path, 'w', **out_meta) as dst:
        dst.write(clipped_image)

def export_s1images(coll, s2_bounds, type):
    """Clip every Sentinel-1 image in ``coll`` to ``s2_bounds``.

    Parameters
    ----------
    coll : iterable of str
        Paths of the Sentinel-1 images to clip.
    s2_bounds : sequence of float
        ``(left, bottom, right, top)`` forwarded to ``clip_sentinel1_image``.
    type : str
        'original' or the filter type used in sentinel_one_two.ipynb; selects
        the ``Clipped/<type>/`` output sub-folder. (The name shadows the
        builtin but is kept for caller compatibility.)

    Notes
    -----
    The output path is built from fixed-width slices of each input path: the
    first 16 characters form the base directory and the last 17 the file
    name. NOTE(review): this assumes a specific path layout — confirm.
    """
    # Iterate over the paths themselves. The previous ``enumerate(coll)``
    # yielded (index, path) tuples, which cannot be sliced into a path.
    for s1_path in coll:
        output_path = os.path.join(s1_path[:16], f'Clipped/{type}/{s1_path[-17:]}')
        clip_sentinel1_image(s1_path, output_path, s2_bounds)

def plot_vv_vh_with_bbox(image_path, bbox):
    """Show the VV and VH bands of ``image_path`` side by side with ``bbox`` outlined.

    Parameters
    ----------
    image_path : str
        Raster whose first two bands are VV and VH.
    bbox : sequence of float
        ``[min_x, min_y, max_x, max_y]`` in the raster's CRS, drawn in red.
    """
    with rasterio.open(image_path) as src:
        vv = src.read(1)  # VV is in the first band
        vh = src.read(2)  # VH is in the second band
        # imshow's extent describes the outer edges of the pixel grid, so use
        # the dataset bounds (pixel edges). Corner-pixel centres, as used
        # before, shrink the image by half a pixel per side relative to bbox.
        bounds = src.bounds

    extent = [bounds.left, bounds.right, bounds.bottom, bounds.top]

    # Coordinates of the bounding-box outline
    x, y = box(*bbox).exterior.xy

    fig, ax = plt.subplots(1, 2, figsize=(14, 6))  # Two subplots for VV and VH bands

    for axis, band, name in zip(ax, (vv, vh), ('VV', 'VH')):
        axis.imshow(band, cmap='gray', extent=extent)
        axis.plot(x, y, color='red', linewidth=2, label="Sentinel-2 Coverage")
        axis.set_title(f'{name} Band with Bounding Box')
        axis.set_xlabel('Easting (meters)')
        axis.set_ylabel('Northing (meters)')
        axis.xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
        axis.yaxis.set_major_locator(mticker.MaxNLocator(5))
        axis.legend(loc='lower right')

    # Show the plot with layout adjustments
    plt.tight_layout()
    plt.show()

def get_EPSG(im):
    """Print and return the EPSG code of raster ``im``'s projection.

    The code is the value of the AUTHORITY node in the dataset's WKT, as
    returned by osr (a string; presumably None if the WKT has no such node).
    """
    wkt = gdal.Open(im).GetProjection()
    srs = osr.SpatialReference()
    srs.ImportFromWkt(wkt)
    code = srs.GetAttrValue('AUTHORITY', 1)
    print(code)
    return code

def reproject_raster(input_raster, output_raster, target_crs='EPSG:32615'):
    """Warp ``input_raster`` into ``target_crs`` and save it as ``output_raster``.

    ``target_crs`` is any SRS string accepted by GDAL's ``dstSRS`` option.
    """
    gdal.Warp(output_raster, input_raster,
              options=gdal.WarpOptions(dstSRS=target_crs))

def clip_raster_by_bbox(input_raster, output_raster, bbox):
    """Cut ``input_raster`` to ``bbox`` and write the result to ``output_raster``.

    ``bbox`` is ``(min_x, min_y, max_x, max_y)`` in the raster's CRS. GDAL's
    ``projWin`` instead wants ``[ulx, uly, lrx, lry]`` (upper-left corner then
    lower-right corner), hence the reordering.
    """
    left, bottom, right, top = bbox
    upper_left_lower_right = [left, top, right, bottom]
    gdal.Translate(output_raster, input_raster,
                   options=gdal.TranslateOptions(projWin=upper_left_lower_right))


def perform_pca(image_path, output_pca_path):
    """Reduce bands 1 and 2 of ``image_path`` to one PCA component and save it.

    NOTE: a second ``def perform_pca`` further down this cell (which uses all
    bands) rebinds this name, so this two-band version is overridden once the
    cell has finished running.

    The first principal component is min-max stretched to 0-255, cast to
    uint8, and written as a single-band Float32 GeoTIFF that carries the
    input's projection and geotransform.
    """
    dataset = gdal.Open(image_path)

    # (bands, rows, cols) stack of bands 1 and 2
    stacked = np.stack([dataset.GetRasterBand(b).ReadAsArray() for b in (1, 2)], axis=0)
    _, n_rows, n_cols = stacked.shape

    # (pixels, bands) in float32 for cv2.PCACompute
    samples = stacked.reshape(stacked.shape[0], -1).T.astype(np.float32)

    # Keep a single principal component
    mean, eigenvectors = cv2.PCACompute(samples, mean=None, maxComponents=1)
    first_pc = cv2.PCAProject(samples, mean, eigenvectors).reshape(n_rows, n_cols)

    # Stretch to 0-255 and quantise to uint8
    stretched = cv2.normalize(first_pc, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    out_ds = gdal.GetDriverByName('GTiff').Create(
        output_pca_path, dataset.RasterXSize, dataset.RasterYSize, 1, gdal.GDT_Float32)
    out_ds.SetProjection(dataset.GetProjection())
    out_ds.SetGeoTransform(dataset.GetGeoTransform())
    out_ds.GetRasterBand(1).WriteArray(stretched)
    out_ds.FlushCache()  # push pending writes to disk
    out_ds = None

def register_images(reference_image, target_image, fill_value):
    """Estimate the translation that aligns ``target_image`` to ``reference_image``.

    Uses OpenCV's ECC maximisation with a translation-only motion model.

    Parameters
    ----------
    reference_image, target_image : numpy.ndarray
        Single-band 2-D images (presumably of equal shape); may contain NaN.
    fill_value : float
        Value substituted for NaN pixels in both images before ECC runs.

    Returns
    -------
    numpy.ndarray or None
        The 2x3 float32 warp matrix, or None if ECC raises ``cv2.error``
        (the error is printed).
    """
    # Translation-only model (cv2.MOTION_EUCLIDEAN / MOTION_AFFINE are alternatives)
    warp_mode = cv2.MOTION_TRANSLATION

    # Identity 2x3 matrix as the starting estimate
    warp_matrix = np.eye(2, 3, dtype=np.float32)

    # Stop after 5000 iterations or once the update falls below 1e-6
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 5000, 1e-6)

    # Replace NaN with fill_value. Previously the filled pixels were then
    # multiplied by (1 - nan_mask), forcing them back to 0 and silently
    # discarding fill_value. Cast to float32 because ECC only accepts 8-bit or
    # 32-bit float images.
    reference_image = np.nan_to_num(reference_image, nan=fill_value).astype(np.float32)
    target_image = np.nan_to_num(target_image, nan=fill_value).astype(np.float32)

    # Perform the ECC algorithm to find the transformation matrix
    try:
        _, warp_matrix = cv2.findTransformECC(reference_image, target_image, warp_matrix, warp_mode, criteria)
    except cv2.error as e:
        print(f"Error in ECC: {e}")
        return None

    return warp_matrix

def apply_transformation_to_all_bands(target_bands, warp_matrix, image_shape, output_dtype=np.float32):
    """Warp every band in ``target_bands`` with the same 2x3 ``warp_matrix``.

    Each band is cast to float32 and warped with bilinear interpolation. The
    matrix is treated as an inverse map (``cv2.WARP_INVERSE_MAP``). The output
    size is ``image_shape`` (rows, cols), and each result is cast to
    ``output_dtype``.

    Returns
    -------
    list of numpy.ndarray
        One warped array per input band, in input order.
    """
    def _warp(band):
        warped = cv2.warpAffine(band.astype(np.float32), warp_matrix,
                                (image_shape[1], image_shape[0]),
                                flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)
        return warped.astype(output_dtype)

    return [_warp(band) for band in target_bands]

def save_multiband_image_as_tiff(output_path, transformed_bands, reference_dataset, gdal_dtype=gdal.GDT_Float32):
    """Write ``transformed_bands`` as one GeoTIFF georeferenced like ``reference_dataset``.

    Parameters
    ----------
    output_path : str
        Destination file.
    transformed_bands : sequence of 2-D arrays
        Element i (0-based) is written to raster band i + 1.
    reference_dataset : gdal.Dataset
        Supplies the raster size, projection and geotransform.
    gdal_dtype : int
        GDAL pixel type of the output (default Float32).
    """
    out_ds = gdal.GetDriverByName('GTiff').Create(
        output_path,
        reference_dataset.RasterXSize,
        reference_dataset.RasterYSize,
        len(transformed_bands),
        gdal_dtype,
    )
    out_ds.SetProjection(reference_dataset.GetProjection())
    out_ds.SetGeoTransform(reference_dataset.GetGeoTransform())

    for band_number, band_array in enumerate(transformed_bands, start=1):
        out_ds.GetRasterBand(band_number).WriteArray(band_array)

    out_ds.FlushCache()
    out_ds = None  # drop the reference so GDAL closes the file

def sar_ratio(impathlist, outpath):
    """Write VV, VH (dB) and two polarimetric indices for every input image.

    Each input raster must hold VV (dB) in band 1 and VH (dB) in band 2. A
    4-band Float32 GeoTIFF is written to ``outpath``, named after the last 14
    characters of the input path (``YYYY-MM-DD.tif``), with bands:

    1. VV (dB)
    2. VH (dB)
    3. 4*VH / (VH + VV) on linear values  (Sentinel-1 radar vegetation index)
    4. ln(10 * VV * VH) on linear values  (Sentinel-1 dual-pol water index)

    Non-finite index values (NaN, +inf and -inf, e.g. ln(0)) are stored as NaN.
    """
    for im in impathlist:
        testds = gdal.Open(im)

        vv_dB = testds.GetRasterBand(1).ReadAsArray()
        vh_dB = testds.GetRasterBand(2).ReadAsArray()

        # Convert from dB to linear scale
        vv_linear = 10 ** (vv_dB / 10)
        vh_linear = 10 ** (vh_dB / 10)

        with np.errstate(divide='ignore', invalid='ignore'):
            ratio1 = (4 * vh_linear) / (vh_linear + vv_linear)  # sentinel-1 radar veg index
            ratio2 = np.log(10 * vv_linear * vh_linear)  # sentinel-1 dual-polarization water index

        # Previously only +inf was masked, so -inf (log of 0) leaked through
        # and later poisoned nanmin-based scaling. Mask every non-finite value.
        ratio1[~np.isfinite(ratio1)] = np.nan
        ratio2[~np.isfinite(ratio2)] = np.nan

        save_multiband_image_as_tiff(os.path.join(outpath, im[-14:]), [vv_dB, vh_dB, ratio1, ratio2], testds)

def get_grd_avg(grd_avg_path, combined_ims):
    """Average each band across the rasters in ``combined_ims`` and save the result.

    All inputs must share band count and raster size (required by
    ``np.stack``). Output band k is the pixel-wise ``np.mean`` of band k over
    every input, so NaN propagates. The result is a Float32 GeoTIFF at
    ``grd_avg_path`` with the projection/geotransform of ``combined_ims[0]``.

    Returns
    -------
    str
        ``grd_avg_path``.
    """
    # per_image[n][k] holds band k + 1 of image n
    per_image = []
    for path in combined_ims:
        ds = gdal.Open(path)
        per_image.append([ds.GetRasterBand(k).ReadAsArray() for k in range(1, ds.RasterCount + 1)])

    band_count = len(per_image[0])
    mean_bands = [
        np.mean(np.stack([bands[k] for bands in per_image], axis=0), axis=0)
        for k in range(band_count)
    ]

    driver = gdal.GetDriverByName('GTiff')

    # The first input supplies size, CRS and geotransform
    template = gdal.Open(combined_ims[0])
    result = driver.Create(grd_avg_path, template.RasterXSize, template.RasterYSize, len(mean_bands), gdal.GDT_Float32)
    result.SetProjection(template.GetProjection())
    result.SetGeoTransform(template.GetGeoTransform())

    for k, mean_band in enumerate(mean_bands, start=1):
        result.GetRasterBand(k).WriteArray(mean_band)

    result = None  # closing the dataset flushes it to disk
    return grd_avg_path

def perform_pca(image_path, output_pca_path, band_indices=None):
    """Project a multi-band raster onto its first principal component and save it.

    Parameters
    ----------
    image_path : str
        Input raster readable by GDAL.
    output_pca_path : str
        Destination single-band Float32 GeoTIFF (same size, projection and
        geotransform as the input).
    band_indices : sequence of int, optional
        1-based band numbers to include. Defaults to every band in the file.
        Passing ``(1, 2)`` reproduces the VV/VH-only variant defined earlier in
        this cell, which this definition overrides.

    Notes
    -----
    The component is min-max stretched to 0-255 and cast to uint8 before it is
    written, so the Float32 output only holds whole numbers in [0, 255].
    NOTE(review): NaN inputs presumably propagate through cv2.PCACompute —
    mask them beforehand if the raster has no-data pixels.
    """
    dataset = gdal.Open(image_path)

    if band_indices is None:
        band_indices = range(1, dataset.RasterCount + 1)

    # (bands, rows, cols)
    bands_array = np.stack([dataset.GetRasterBand(i).ReadAsArray() for i in band_indices], axis=0)
    bands_count, n_rows, n_cols = bands_array.shape

    # (pixels, bands) in float32, as cv2.PCACompute expects
    flattened_image = bands_array.reshape(bands_count, -1).T.astype(np.float32)

    # Reduce to one principal component
    mean, eigenvectors = cv2.PCACompute(flattened_image, mean=None, maxComponents=1)
    pca_result = cv2.PCAProject(flattened_image, mean, eigenvectors)
    pca_image = pca_result.reshape(n_rows, n_cols)

    # Stretch to 0-255 and quantise to uint8
    pca_image_normalized = cv2.normalize(pca_image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Save the PCA-reduced image
    output = gdal.GetDriverByName('GTiff').Create(output_pca_path, dataset.RasterXSize, dataset.RasterYSize, 1, gdal.GDT_Float32)
    output.SetProjection(dataset.GetProjection())
    output.SetGeoTransform(dataset.GetGeoTransform())
    output.GetRasterBand(1).WriteArray(pca_image_normalized)
    output.FlushCache()  # Ensure data is written to disk
    output = None

def plot_class_with_grd(grd_path, otsu_class, kmeans_class, gmm_class):
    """Plot VV, VH and the Otsu / K-Means / GMM label maps of one scene in a row.

    Parameters
    ----------
    grd_path : str
        Raster whose bands 1 and 2 are VV and VH. The 10 characters before
        ``.tif`` (the date) are used in the panel titles.
    otsu_class, kmeans_class, gmm_class : numpy.ndarray
        2-D label maps on the same grid as ``grd_path``.
    """
    with rasterio.open(grd_path) as rat_src:
        # Only VV and VH are displayed; bands 3-4 (RVI/SDWI) are not needed here
        vv = rat_src.read(1)
        vh = rat_src.read(2)
        # Pixel-edge bounds: imshow's extent is the outer edge of the grid,
        # not the centres of the corner pixels.
        bounds = rat_src.bounds

    vv = min_max_scale(vv, np.nanmin(vv), np.nanmax(vv))
    vh = min_max_scale(vh, np.nanmin(vh), np.nanmax(vh))

    extent = [bounds.left, bounds.right, bounds.bottom, bounds.top]
    date = grd_path[-14:-4]

    # NOTE(review): imshow stretches the colormap over each map's own value
    # range, so a 0/1 map renders as blue/red. Green only shows up if a third
    # class value is present. Pass vmin=0, vmax=2 if colours must stay fixed.
    cmap = ListedColormap(['blue', 'green', 'red'])

    # Custom legend for the classification panels
    legend_handles = [
        mpatches.Patch(color='blue', label='Open Water'),
        mpatches.Patch(color='green', label='Subaqueous Land'),
        mpatches.Patch(color='red', label='Subaerial Land'),
    ]

    fig, ax = plt.subplots(1, 5, figsize=(30, 6))  # VV, VH, Otsu, K-Means, GMM

    for axis, band, label in zip(ax[:2], (vv, vh), ('VV', 'VH')):
        axis.imshow(band, extent=extent)
        axis.set_title(f'{date} {label}')
        axis.set_xlabel('Easting (meters)')
        axis.set_ylabel('Northing (meters)')
        axis.xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks

    panels = zip(ax[2:], (otsu_class, kmeans_class, gmm_class), ('Otsu', 'KMeans', 'GMM'))
    for axis, labels, title in panels:
        axis.imshow(labels, cmap=cmap, extent=extent)
        axis.set_title(title)
        axis.set_xlabel('Easting (meters)')
        axis.legend(handles=legend_handles, loc='lower right', title="Classification")
        axis.xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks

    # Show the plot with layout adjustments
    plt.tight_layout()
    plt.show()

def min_max_scale(band, band_min, band_max):
    """Normalize band to [0, 1] range.

    Computes ``(band - band_min) / (band_max - band_min)``. ``band_min`` and
    ``band_max`` may be scalars or arrays that broadcast against ``band``.
    Where the range is zero (a constant band), the divisor is treated as 1, so
    such a band maps to 0 instead of 0/0 = NaN. NaN values in ``band`` stay
    NaN, and a NaN range (e.g. an all-NaN band) still yields NaN.
    """
    span = band_max - band_min
    # Avoid 0/0 for constant bands; leave non-zero (and NaN) ranges untouched
    safe_span = np.where(span == 0, 1, span)
    return (band - band_min) / safe_span

def kmeans_land_water(image_path1, n_clusters):
    """Cluster the pixels of a multi-band raster with K-Means.

    Every band is scaled to [0, 1] on its own (matching
    ``process_image_pca_otsu``), so bands with wide numeric ranges such as dB
    backscatter do not swamp narrow-range index bands. The previous joint
    scaling used one min/max for all bands. Pixels with any non-finite band
    value (NaN or ±inf) are excluded from the fit.

    Parameters
    ----------
    image_path1 : str
        Raster to classify; every band is used as a feature.
    n_clusters : int
        Number of K-Means clusters.

    Returns
    -------
    numpy.ndarray
        (rows, cols) float array of cluster labels, NaN for excluded pixels.
        Label numbers are arbitrary (not ordered by any band value).
    """
    with rasterio.open(image_path1) as src1:
        img1 = src1.read().astype(float)  # (bands, rows, cols)
    n_bands, height, width = img1.shape

    # Treat ±inf (e.g. ln(0) in an index band) as missing before scaling,
    # otherwise nanmin/nanmax become infinite and the band turns all-NaN.
    img1[~np.isfinite(img1)] = np.nan

    # Per-band min-max scaling
    img1 = np.stack([min_max_scale(band, np.nanmin(band), np.nanmax(band)) for band in img1])

    # Reshape to (num_pixels, num_bands)
    img_flat = img1.reshape((n_bands, height * width)).T

    # Keep only pixels whose every feature is finite
    mask = np.isfinite(img_flat).all(axis=1)
    img_no_nan = img_flat[mask]

    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(img_no_nan)

    # NaN everywhere except the fitted pixels
    classified_img = np.full((height * width), np.nan)
    classified_img[mask] = kmeans.labels_

    return classified_img.reshape((height, width))

def gmm_land_water(image_path1, n_components):
    """Cluster the pixels of a multi-band raster with a Gaussian Mixture Model.

    Every band is scaled to [0, 1] on its own (matching
    ``process_image_pca_otsu``) instead of with one min/max shared by all
    bands. Pixels with any non-finite band value (NaN or ±inf) are excluded
    from fitting and prediction.

    Parameters
    ----------
    image_path1 : str
        Raster to classify; every band is used as a feature.
    n_components : int
        Number of mixture components.

    Returns
    -------
    numpy.ndarray
        (rows, cols) float array of component labels, NaN for excluded
        pixels. Label numbers are arbitrary.
    """
    with rasterio.open(image_path1) as src1:
        img1 = src1.read().astype(float)  # (bands, rows, cols)
    n_bands, height, width = img1.shape

    # Treat ±inf as missing so it cannot poison nanmin/nanmax
    img1[~np.isfinite(img1)] = np.nan

    # Per-band min-max scaling
    img1 = np.stack([min_max_scale(band, np.nanmin(band), np.nanmax(band)) for band in img1])

    # Reshape to (num_pixels, num_bands)
    img_flat = img1.reshape((n_bands, height * width)).T

    # Keep only pixels whose every feature is finite
    mask = np.isfinite(img_flat).all(axis=1)
    img_no_nan = img_flat[mask]

    gmm = GaussianMixture(n_components=n_components, random_state=42).fit(img_no_nan)

    # NaN everywhere except the predicted pixels
    gmm_labels = np.full((height * width), np.nan)
    gmm_labels[mask] = gmm.predict(img_no_nan)

    return gmm_labels.reshape((height, width))


def process_image_pca_otsu(imagepath):
    """Split a 4-band SAR stack into two classes with PCA + Otsu thresholding.

    Bands are read as VV (dB), VH (dB), RVI and SDWI. VV/VH are converted to
    linear power, and each band is min-max scaled to [0, 1] independently.
    The stack is reduced to its first principal component, which is then
    thresholded with Otsu's method.

    Returns
    -------
    pca_2d : numpy.ndarray
        (rows, cols) first principal component, NaN where any band is non-finite.
    classified_image : numpy.ndarray
        (rows, cols) float array: 1.0 above the Otsu threshold, 0.0 at or
        below it, and NaN where ``pca_2d`` is NaN. Previously those no-data
        pixels were labelled 0. A principal component's sign is arbitrary, so
        0/1 is not tied to a land-cover class here; the notebook relabels the
        classes afterwards by mean VH.
    """
    with rasterio.open(imagepath) as src:
        vv = src.read(1)  # VV band
        vh = src.read(2)  # VH band
        rvi = src.read(3)  # RVI
        sdwi = src.read(4)  # SDWI

    # Convert VV and VH from dB to linear scale
    vv = 10 ** (vv / 10)
    vh = 10 ** (vh / 10)

    # Treat ±inf as missing, then normalize each band to [0, 1]
    scaled = []
    for band in (vv, vh, rvi, sdwi):
        band = np.where(np.isfinite(band), band, np.nan)
        scaled.append(min_max_scale(band, np.nanmin(band), np.nanmax(band)))

    # Stack the bands and reshape to (pixels, bands) for PCA
    bands_stack = np.stack(scaled, axis=-1)
    n_rows, n_cols, n_bands = bands_stack.shape
    pixels_2d = bands_stack.reshape(-1, n_bands)

    # Remove NaN or invalid values
    valid_mask = np.all(np.isfinite(pixels_2d), axis=1)
    valid_pixels = pixels_2d[valid_mask]

    # Apply PCA and put the first component back on the 2-D grid
    pca = PCA(n_components=1)
    pca_band = pca.fit_transform(valid_pixels)
    pca_2d = np.full((n_rows * n_cols), np.nan)
    pca_2d[valid_mask] = pca_band.flatten()
    pca_2d = pca_2d.reshape(n_rows, n_cols)

    # Otsu threshold computed and applied on valid pixels only
    valid = ~np.isnan(pca_2d)
    threshold = threshold_otsu(pca_2d[valid])
    classified_image = np.full(pca_2d.shape, np.nan)
    classified_image[valid] = (pca_2d[valid] > threshold).astype(float)

    return pca_2d, classified_image
    
def plot_majvote_with_grd(grd_path, majority_vote):
    """Plot VV, VH and the majority-vote label map of one scene side by side.

    Parameters
    ----------
    grd_path : str
        Raster whose bands 1 and 2 are VV and VH. The 10 characters before
        ``.tif`` (the date) are used in the panel titles.
    majority_vote : numpy.ndarray
        2-D label map on the same grid as ``grd_path``.
    """
    with rasterio.open(grd_path) as rat_src:
        # Only VV and VH are displayed; bands 3-4 (RVI/SDWI) are not needed here
        vv = rat_src.read(1)
        vh = rat_src.read(2)
        # Pixel-edge bounds: imshow's extent is the outer edge of the grid,
        # not the centres of the corner pixels.
        bounds = rat_src.bounds

    vv = min_max_scale(vv, np.nanmin(vv), np.nanmax(vv))
    vh = min_max_scale(vh, np.nanmin(vh), np.nanmax(vh))

    extent = [bounds.left, bounds.right, bounds.bottom, bounds.top]
    date = grd_path[-14:-4]

    # NOTE(review): imshow stretches the colormap over the map's own value
    # range, so a 0/1 map renders as blue/red. Pass vmin=0, vmax=2 if colours
    # must stay fixed.
    cmap = ListedColormap(['blue', 'green', 'red'])

    # Custom legend for the classification panel
    legend_handles = [
        mpatches.Patch(color='blue', label='Open Water'),
        mpatches.Patch(color='green', label='Subaqueous Land'),
        mpatches.Patch(color='red', label='Subaerial Land'),
    ]

    fig, ax = plt.subplots(1, 3, figsize=(15, 6))  # VV, VH, majority vote

    for axis, band, label in zip(ax[:2], (vv, vh), ('VV', 'VH')):
        axis.imshow(band, extent=extent)
        axis.set_title(f'{date} {label}')
        axis.set_xlabel('Easting (meters)')
        axis.set_ylabel('Northing (meters)')
        axis.xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks

    # Majority vote classification visualization
    ax[2].imshow(majority_vote, cmap=cmap, extent=extent)
    ax[2].set_title('Majority Vote')
    ax[2].set_xlabel('Easting (meters)')
    ax[2].legend(handles=legend_handles, loc='lower right', title="Classification")
    ax[2].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks

    # Show the plot with layout adjustments
    plt.tight_layout()
    plt.show()

def calculate_mean_vh_for_classes(ratio_im_path, labels_list):
    """Mean VH (band 2 of ``ratio_im_path``) inside label 0 and label 1 of each label map.

    Parameters
    ----------
    ratio_im_path : str
        Raster whose second band is VH.
    labels_list : sequence of numpy.ndarray
        Label maps on the same grid as the raster.

    Returns
    -------
    (list, list)
        NaN-ignoring VH means for label 0 and for label 1, one entry per label
        map in ``labels_list`` order (NaN when a label has no valid pixels).
    """
    with rasterio.open(ratio_im_path) as src:
        vh = src.read(2)  # Assuming the vh band is in second position

    means_label0 = [np.nanmean(vh[labels == 0]) for labels in labels_list]
    means_label1 = [np.nanmean(vh[labels == 1]) for labels in labels_list]
    return means_label0, means_label1

# 1. Get the original Sentinel-1 GRD VV+VH (dB) and GLCM texture data downloaded from Google Earth Engine

In [None]:
# Project root. Defaults to the WSL path; set the SABINE_WORK_DIR environment
# variable to override it (e.g. '/home/wcc/Desktop/SabineRS/' on native Linux)
# without editing the notebook.
work_dir = os.environ.get('SABINE_WORK_DIR', '/mnt/d/SabineRS')

In [None]:
# Original GRD backscatter (VV+VH, dB) and GLCM texture images, each sorted by
# acquisition date. Paths are built from `work_dir` so the project root is set
# in one place only.
orig_ims = get_grd(os.path.join(work_dir, 'GRD/0_initial/backscatter'))
orig_glcms = get_glcm(os.path.join(work_dir, 'GRD/0_initial/glcm'))

# 1. Clip all images to the extent of the first image in the time series

In [None]:
# Check the CRS of the earliest backscatter scene (get_EPSG also prints the code)
first_scene_path = orig_ims[0]
orig_epsg = get_EPSG(first_scene_path)

In [None]:
# Bounding box [min_x, min_y, max_x, max_y] of the first scene, used for plotting
# and for clipping every image to a common extent. Only the origin and pixel
# size terms of the geotransform are used (rotation terms 2 and 4 are ignored).
src = gdal.Open(orig_ims[0])
geo_transform = src.GetGeoTransform()

left = geo_transform[0]
top = geo_transform[3]
right = left + (src.RasterXSize * geo_transform[1])
bottom = top + (src.RasterYSize * geo_transform[5])

coords = [left, right, bottom, top]
bbox = [left, bottom, right, top]
bbox

In [None]:
# Clip every backscatter/GLCM pair to the common bbox. Both lists are sorted by
# date, so element i of each must be the same acquisition; check that rather
# than silently pairing mismatched scenes.
assert len(orig_ims) == len(orig_glcms), 'backscatter and GLCM collections differ in length'

clip_backscatter_dir = os.path.join(work_dir, 'GRD/1_clipped/backscatter')
clip_glcm_dir = os.path.join(work_dir, 'GRD/1_clipped/glcm')
os.makedirs(clip_backscatter_dir, exist_ok=True)
os.makedirs(clip_glcm_dir, exist_ok=True)

for im, glcm in zip(orig_ims, orig_glcms):
    assert im[-14:-4] == glcm[-14:-4], f'date mismatch: {im} vs {glcm}'

    # Keep each input's full file name rather than a fixed-width slice of its
    # path, which breaks as soon as the naming scheme changes length.
    clip_raster_by_bbox(im, os.path.join(clip_backscatter_dir, os.path.basename(im)), bbox)
    clip_raster_by_bbox(glcm, os.path.join(clip_glcm_dir, os.path.basename(glcm)), bbox)

In [None]:
# Reload the clipped backscatter and GLCM images. get_grd/get_glcm already
# return them sorted by acquisition date, so no extra sort is needed.
clip_orig = get_grd(os.path.join(work_dir, 'GRD/1_clipped/backscatter'))
clip_orig_glcms = get_glcm(os.path.join(work_dir, 'GRD/1_clipped/glcm'))

# 2. Generate VV/VH-derived index bands (radar vegetation index and dual-polarization water index)

In [None]:
# Write one 4-band raster per date (VV, VH, Sentinel-1 radar vegetation index,
# Sentinel-1 dual-pol water index) under `work_dir`, then list them in date
# order. get_grd applies the same .tif filter and date sort as before.
ratio_dir = os.path.join(work_dir, 'GRD/3_ratio')
os.makedirs(ratio_dir, exist_ok=True)

sar_ratio(clip_orig, ratio_dir)
ratio_ims = get_grd(ratio_dir)

# 3. Apply unsupervised classification methods to (hopefully) classify subaerial land, subaqueous land, and open water from SAR imagery

In [None]:
# Two-class K-Means, GMM and PCA+Otsu classification of every index image
otsuclasses = []
kmeansclasses = []
gmmclasses = []

for ratio_path in ratio_ims:
    kmeansclasses.append(kmeans_land_water(ratio_path, 2))
    gmmclasses.append(gmm_land_water(ratio_path, 2))

    # Only the label map is kept; the PCA component itself is discarded
    _, otsu_labels = process_image_pca_otsu(ratio_path)
    otsuclasses.append(otsu_labels)

In [None]:
# Mean VH backscatter inside label 0 and label 1 of every classification.
# The next cell uses these means to reorder labels so that the lower-VH class
# becomes 0 and the higher-VH class becomes 1. Only two classes exist here:
# every method above produces a 2-class split.
vh_means_otsu = []
vh_means_kmeans = []
vh_means_gmm = []

for vh_im, otsu_labels, km_labels, gmm_labels in zip(ratio_ims, otsuclasses, kmeansclasses, gmmclasses):
    # One call per image: the raster is read once for all three methods
    # instead of once per method.
    means_0, means_1 = calculate_mean_vh_for_classes(vh_im, [otsu_labels, km_labels, gmm_labels])

    # Store (mean VH of label 0, mean VH of label 1) per method
    vh_means_otsu.append((means_0[0], means_1[0]))
    vh_means_kmeans.append((means_0[1], means_1[1]))
    vh_means_gmm.append((means_0[2], means_1[2]))

In [None]:
# Reorder every classification's labels by mean VH so that label k is the class
# with the k-th lowest mean VH (0 = lowest). NaN pixels stay NaN because they
# never compare equal to a label.
relabeled_images = {"otsu": [], "kmeans": [], "gmm": []}

# List of classifier methods for looping
classification_methods = ["otsu", "kmeans", "gmm"]

method_inputs = {
    "otsu": (vh_means_otsu, otsuclasses),
    "kmeans": (vh_means_kmeans, kmeansclasses),
    "gmm": (vh_means_gmm, gmmclasses),
}

for method in classification_methods:
    vh_means, label_maps = method_inputs[method]
    for class_means, label_map in zip(vh_means, label_maps):
        # labelsort[rank] is the original label whose mean VH has that rank
        labelsort = np.argsort(class_means)
        relabel_map = {labelsort[rank]: float(rank) for rank in range(len(labelsort))}

        # astype(float) returns a copy, so the source label map is untouched,
        # and float dtype can hold NaN. (The former `== None` replacement was
        # a no-op on numeric arrays and has been dropped.)
        image_copy = label_map.astype(float)

        # Match against the unmodified copy so remapped values are not remapped again
        relabeled_image = image_copy.copy()
        for original_label, new_label in relabel_map.items():
            relabeled_image[image_copy == original_label] = new_label

        # Add the processed relabeled image to the dictionary
        relabeled_images[method].append(relabeled_image)

In [None]:
cleaned_ims = {"otsu": [],
               "kmeans": [],
               "gmm": []
               }

# Square structuring element; adjust the size as needed
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

for method in classification_methods:
    for im in relabeled_images[method]:
        # In a two-class map, cleaning the label-1 mask fully determines the
        # result: label 0 is its complement.
        subaerial = (im == 1).astype(np.uint8)

        # Opening removes small isolated pixels; closing then fills small holes
        subaerial_cleaned = cv2.morphologyEx(subaerial, cv2.MORPH_OPEN, kernel)
        subaerial_cleaned = cv2.morphologyEx(subaerial_cleaned, cv2.MORPH_CLOSE, kernel)

        # Rebuild the label map as subaqueous*0 + subaerial*1. The previous
        # expression `subaqueous_cleaned * i + subaqueous_cleaned * 1` used the
        # subaqueous mask twice plus the method index `i`, so subaerial pixels
        # were lost. Pixels that were NaN in the input are restored to NaN.
        cleaned_classified_image = np.where(subaerial_cleaned == 1, 1.0, 0.0)
        cleaned_classified_image[np.isnan(im)] = np.nan

        # Add the processed image to the dictionary
        cleaned_ims[method].append(cleaned_classified_image)

In [None]:
# Compare the cleaned classifications for three sample scenes. Labels are
# indexed with the same scene index as the image. The former enumerate() loop
# paired ratio_ims[69] and ratio_ims[137] with label maps 1 and 2.
sample_scenes = [0, 69, 137]

for idx in sample_scenes:
    plot_class_with_grd(
        ratio_ims[idx],
        cleaned_ims['otsu'][idx],
        cleaned_ims['kmeans'][idx],
        cleaned_ims['gmm'][idx]
    )

# 5. Use majority voting via the mode of the three classifications (Otsu, K-Means, GMM) for each time stamp.
- The final class of a pixel is the class assigned to it by at least 2 of the 3 methods.

In [None]:
# Per-pixel majority vote across the three methods for every scene. The votes
# use the VH-ordered labels in `relabeled_images`. Raw K-Means/GMM cluster ids
# are arbitrary, so voting on them directly compared unrelated labels.
majority_vote_images = []

for otsu_im, km_im, gmm_im in zip(relabeled_images['otsu'],
                                  relabeled_images['kmeans'],
                                  relabeled_images['gmm']):
    # Shape (methods, height, width) -> (3, height, width)
    stacked_classes = np.stack([otsu_im, km_im, gmm_im], axis=0)

    # Mode along the method axis, ignoring NaN votes
    majority_vote = mode(stacked_classes, axis=0, nan_policy='omit')[0].squeeze()

    # Store the result in the majority vote list
    majority_vote_images.append(majority_vote)

In [None]:
# Show the majority vote for the same sample scenes, indexing the votes by
# scene rather than by position in this short list (the former loop paired
# ratio_ims[69] and ratio_ims[137] with votes 1 and 2).
for idx in [0, 69, 137]:
    plot_majvote_with_grd(
        ratio_ims[idx],
        majority_vote_images[idx]
    )