In [None]:
import os
from datetime import datetime
from itertools import product
import rasterio
from rasterio import windows
from shapely.geometry import box
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.ticker as mticker
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import dask
from dask.distributed import Client
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Lambda, Conv2D, UpSampling2D, Cropping2D, MaxPooling2D, Dropout, BatchNormalization, Conv2DTranspose, concatenate, Flatten, Dense, UpSampling2D
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

# Functions

In [None]:
def get_labels(labelpath):
    """Collect the label GeoTIFF paths for every classification method.

    The ``otsu``, ``kmeans``, ``gmm`` and ``majority`` sub-directories of
    ``labelpath`` are listed, only ``.tif`` files are kept, and each list is
    ordered by the ``YYYY-MM-DD`` date found in characters ``[-14:-4]`` of the
    path (i.e. the ten characters right before the ``.tif`` extension).

    Parameters
    ----------
    labelpath : str
        Directory holding the four method sub-directories.

    Returns
    -------
    tuple of list of str
        ``(otsu_ims, kmeans_ims, gmm_ims, majority_ims)``, each sorted by date.
    """
    methods = ('otsu', 'kmeans', 'gmm', 'majority')

    def _date_key(path):
        return datetime.strptime(path[-14:-4], '%Y-%m-%d')

    # List every directory before sorting anything, so a missing folder is
    # reported ahead of any date-parsing problem.
    listings = []
    for method in methods:
        tif_paths = []
        for file in os.listdir(os.path.join(labelpath, method)):
            if file.endswith('.tif'):
                tif_paths.append(os.path.join(labelpath, f'{method}/{file}'))
        listings.append(tif_paths)

    return tuple(sorted(tif_paths, key=_date_key) for tif_paths in listings)

def get_grd(grdpath):
    """Return the ``.tif`` paths directly inside ``grdpath``, sorted by date.

    The sort key is the ``YYYY-MM-DD`` date parsed from characters ``[-14:-4]``
    of each path (the ten characters right before the ``.tif`` extension).

    Parameters
    ----------
    grdpath : str
        Directory to list.

    Returns
    -------
    list of str
        Full paths (``grdpath`` joined with the file name), oldest first.
    """
    tif_paths = []
    for file in os.listdir(grdpath):
        if file.endswith('.tif'):
            tif_paths.append(os.path.join(grdpath, file))

    tif_paths.sort(key=lambda path: datetime.strptime(path[-14:-4], '%Y-%m-%d'))
    return tif_paths

def get_glcm(glcmpath):
    """Return the ``.tif`` paths directly inside ``glcmpath``, sorted by date.

    Each path is ordered by the ``YYYY-MM-DD`` date parsed from its characters
    ``[-14:-4]`` (the ten characters right before the ``.tif`` extension).

    Parameters
    ----------
    glcmpath : str
        Directory to list.

    Returns
    -------
    list of str
        Full paths (``glcmpath`` joined with the file name), oldest first.
    """
    def _date_key(path):
        return datetime.strptime(path[-14:-4], '%Y-%m-%d')

    tif_names = filter(lambda file: file.endswith('.tif'), os.listdir(glcmpath))
    return sorted((os.path.join(glcmpath, file) for file in tif_names), key=_date_key)
def find_closest_dates(labels, backscatter_ims, glcm_ims, max_days=12):
    """Pair each label image with the backscatter/GLCM images closest in date.

    Dates are parsed as ``YYYY-MM-DD`` from characters ``[-14:-4]`` of each
    path. Backscatter and GLCM paths are paired by position (``zip``), and only
    the backscatter date is compared with the label date.
    NOTE(review): this assumes ``backscatter_ims`` and ``glcm_ims`` are aligned
    element-for-element; extra items in the longer list are ignored — confirm
    against the caller.

    Parameters
    ----------
    labels : iterable of str
        Label image paths.
    backscatter_ims : iterable of str
        Backscatter image paths.
    glcm_ims : iterable of str
        GLCM image paths, in the same order as ``backscatter_ims``.
    max_days : int, optional
        Largest accepted absolute difference, in days, between a label and its
        match. Defaults to 12.

    Returns
    -------
    list of tuple
        One ``(label, backscatter, glcm)`` tuple per label. ``backscatter`` and
        ``glcm`` are ``None`` when no image lies within ``max_days``. On a tie
        the earliest candidate in list order wins.

    Raises
    ------
    ValueError
        If a label or backscatter path does not carry a parseable date.
    """
    def _date_of(path):
        return datetime.strptime(path[-14:-4], '%Y-%m-%d')

    # Parse every candidate date once instead of once per label. The GLCM date
    # was never used for matching, so it is no longer parsed at all.
    candidates = [
        (_date_of(backscatter), backscatter, glcm)
        for backscatter, glcm in zip(backscatter_ims, glcm_ims)
    ]

    closest_dates = []
    for label in labels:
        label_date = _date_of(label)
        min_diff = max_days + 1  # sentinel: larger than any accepted difference
        closest_backscatter = None
        closest_glcm = None

        for backscatter_date, backscatter, glcm in candidates:
            day_difference = abs((backscatter_date - label_date).days)

            # Strict '<' keeps the first candidate when two are equally close.
            if day_difference <= max_days and day_difference < min_diff:
                min_diff = day_difference
                closest_backscatter = backscatter
                closest_glcm = glcm

        closest_dates.append((label, closest_backscatter, closest_glcm))

    return closest_dates

def _minmax_scale(band):
    """Rescale ``band`` to the range [0, 1] using its own minimum and maximum.

    A constant band has zero range; it is returned as all zeros instead of the
    NaNs that a 0/0 division would produce.
    """
    lowest = band.min()
    span = band.max() - lowest
    if span == 0:
        return np.zeros_like(band)
    return (band - lowest) / span


def _read_backscatter_features(path):
    """Read bands 1-4 of the raster at ``path`` and return four scaled arrays.

    Bands 1 and 2 (named vv and vh in this notebook) are converted from dB to
    linear scale before scaling; bands 3 and 4 (rvi, sdwi) are scaled as read.
    """
    with rasterio.open(path) as src:
        vv = src.read(1).astype(np.float32)
        vh = src.read(2).astype(np.float32)
        rvi = src.read(3).astype(np.float32)
        sdwi = src.read(4).astype(np.float32)

    # Convert from dB to linear scale
    vv_linear = 10 ** (vv / 10)
    vh_linear = 10 ** (vh / 10)

    return [_minmax_scale(band) for band in (vv_linear, vh_linear, rvi, sdwi)]


def _read_glcm_features(path):
    """Read bands 1-14 of the raster at ``path``, each scaled to [0, 1].

    This notebook names the bands contrast, asm, diss, idm, corr, var and ent
    for VV (bands 1-7), followed by the same seven for VH (bands 8-14).
    """
    with rasterio.open(path) as glcm_src:
        return [
            _minmax_scale(glcm_src.read(band).astype(np.float32))
            for band in range(1, 15)
        ]


def _stack_scene(s1_path, glcm_path):
    """Stack the 4 backscatter and 14 GLCM channels of one scene on the last axis."""
    channels = _read_backscatter_features(s1_path) + _read_glcm_features(glcm_path)
    return np.stack(channels, axis=-1)


def stack_data(filtered_data, unlab_s1, unlab_glcm):
    """Load labeled and unlabeled scenes as 18-channel feature arrays.

    Every channel is min-max scaled independently per image. Channel order is
    the 4 backscatter bands followed by the 14 GLCM bands.

    Parameters
    ----------
    filtered_data : iterable of sequence
        Entries indexed as ``(label_path, backscatter_path, glcm_path)``.
    unlab_s1 : sequence of str
        Backscatter rasters that have no label.
    unlab_glcm : sequence of str
        GLCM rasters; ``unlab_glcm[i]`` is stacked with ``unlab_s1[i]``.

    Returns
    -------
    X_train : list of numpy.ndarray
        One float32 array per labeled scene, features on the last axis.
    y_train : list of numpy.ndarray
        Band 1 of each label raster, cast to int32.
    X_unlabeled : list of numpy.ndarray
        One float32 array per unlabeled scene, same layout as ``X_train``.

    Raises
    ------
    IndexError
        If ``unlab_glcm`` is shorter than ``unlab_s1``.
    """
    X_train = []
    y_train = []

    for entry in filtered_data:
        label_path, s1_path, glcm_path = entry[0], entry[1], entry[2]
        s1_data = _stack_scene(s1_path, glcm_path)

        with rasterio.open(label_path) as src:
            s2_labels = src.read(1).astype(np.int32)

        X_train.append(s1_data)
        y_train.append(s2_labels)

    X_unlabeled = [
        _stack_scene(s1_path, unlab_glcm[i])
        for i, s1_path in enumerate(unlab_s1)
    ]

    return X_train, y_train, X_unlabeled

# Collect Imagery for model training

In [None]:
# Root of the SabineRS data tree. Set the SABINERS_ROOT environment variable to
# run on another machine; the default is the Linux workstation layout.
# The earlier WSL setup kept the data under /mnt/d/SabineRS, with the labels in
# <root>/s2classifications rather than <root>/MSI/s2classifications.
SABINERS_ROOT = os.environ.get('SABINERS_ROOT', '/home/wcc/Desktop/SabineRS')

# Label rasters, one date-sorted list per classification method
otsu_ims, kmeans_ims, gmm_ims, majority_ims = get_labels(
    os.path.join(SABINERS_ROOT, 'MSI', 's2classifications')
)

# Backscatter and GLCM rasters, each sorted by date
backscatter_ims = get_grd(os.path.join(SABINERS_ROOT, 'GRD', '3_ratio'))
glcm_ims = get_glcm(os.path.join(SABINERS_ROOT, 'GRD', '2_registered', 'glcm'))

In [None]:
# Match each `majority` label image with the backscatter/GLCM pair that is
# closest to it in date.
labeledPairs = find_closest_dates(
    labels=majority_ims,
    backscatter_ims=backscatter_ims,
    glcm_ims=glcm_ims,
)

# Labels without a close enough S1 image come back with None placeholders;
# keep only the complete (label, backscatter, glcm) tuples.
filtered_data = list(filter(lambda entry: None not in entry, labeledPairs))

In [None]:
# Backscatter and GLCM paths that were paired with a label
s1matches = [s1_path for _label, s1_path, _glcm in filtered_data]
glcmmatches = [glcm_path for _label, _s1, glcm_path in filtered_data]

# Everything that was not paired is kept as unlabeled S1 data for model training
s1_X = [path for path in backscatter_ims if path not in s1matches]
glcm_X = [path for path in glcm_ims if path not in glcmmatches]

# Data prep

In [None]:
# Read every raster into memory: labeled scenes become (X_train, y_train),
# the remaining S1/GLCM scenes become X_unlabeled.
X_train, y_train, X_unlabeled = stack_data(
    filtered_data=filtered_data,
    unlab_s1=s1_X,
    unlab_glcm=glcm_X,
)

In [None]:
# Tensor dimensions
num_channels = 18  # 4 bands from GRD and 14 from GLCM
num_classes = 3

# Height and width are read from the first image; all images are assumed to
# share the same dimensions.
first_image_shape = X_train[0].shape
img_height, img_width = first_image_shape[0], first_image_shape[1]

# Hold out 20 % of the labeled scenes; the fixed seed keeps the split repeatable.
X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(
    X_train,
    y_train,
    random_state=42,
    test_size=0.2,
)

# Custom NN
- name could be BUNet -- Beneficial Use Network given the focus on sediment enrichment in wetlands

In [None]:
# Fully convolutional per-pixel classifier with batch normalization and dropout.
# Every convolution uses 'same' padding and there is no pooling, so the output
# keeps the input height and width. This is a plain Sequential stack: it has no
# skip connections.
# The input size and number of classes come from the data-prep cell
# (img_height, img_width, num_channels, num_classes) instead of being fixed at
# 798 x 693 x 18 and 3.
model = Sequential([
    Input(shape=(img_height, img_width, num_channels)),

    Conv2D(32, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Dropout(0.3),

    Conv2D(64, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Dropout(0.3),

    Conv2D(128, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Dropout(0.4),

    Conv2D(64, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),

    # 1x1 convolution: one softmax probability per class at every pixel
    Conv2D(num_classes, (1, 1), activation='softmax', padding='same')
])

# Sparse loss: the targets are integer class ids, not one-hot vectors
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# Fit on the training split. Keras sets aside a further 20 % of that split for
# validation, and TensorBoard logs are written to the modellogs directory.
hist = model.fit(
    x=np.array(X_train_split),
    y=np.array(y_train_split),
    batch_size=5,
    epochs=20,
    validation_split=0.2,
    callbacks=[
        tf.keras.callbacks.TensorBoard(log_dir='/home/wcc/Desktop/SabineRS/modellogs'),
    ],
)

In [None]:
# Training vs. validation loss, one point per epoch
fig = plt.figure()
fig.gca().plot(hist.history['loss'], color='red', label='loss')
fig.gca().plot(hist.history['val_loss'], color='blue', label='val_loss')
fig.suptitle('CNN Training Loss', fontsize=20)
fig.gca().legend(loc='upper left')
plt.show()

In [None]:
# Training vs. validation accuracy, one point per epoch
fig = plt.figure()
fig.gca().plot(hist.history['accuracy'], color='red', label='accuracy')
fig.gca().plot(hist.history['val_accuracy'], color='blue', label='val_accuracy')
fig.suptitle('CNN Training accuracy', fontsize=20)
fig.gca().legend(loc='upper left')
plt.show()

# MobileNet
- transfer learning from MobileNet

# EfficientNet
- transfer learning from EfficientNet

# Check accuracy, precision, recall, F1 scores
- compare the shallow NN with ResNet and UNet results: does the custom network perform better?

In [None]:
# Make predictions on the held-out validation data
y_pred = model.predict(np.array(X_val_split))  # shape: (num_samples, height, width, num_classes)

# Convert predictions to class labels by taking the argmax along the class dimension
y_pred_labels = np.argmax(y_pred, axis=-1)  # shape: (num_samples, height, width)

# Flatten the arrays so every pixel counts as one sample in the metrics
y_val_flat = np.array(y_val_split).flatten()  # True labels
y_pred_flat = y_pred_labels.flatten()  # Predicted labels

# Class names in class-id order.
# NOTE(review): assumes the label rasters use ids 0, 1, 2 in this order — confirm.
class_names = ['open water', 'subaqueous land', 'subaerial land']
class_ids = list(range(len(class_names)))

# Passing `labels` explicitly keeps the report working (and the names aligned)
# when a class is missing from the validation pixels; without it
# classification_report raises because the number of classes found does not
# match the number of target names.
print("Accuracy:", accuracy_score(y_val_flat, y_pred_flat))
print(
    "Precision, Recall, F1 Score per class:\n",
    classification_report(y_val_flat, y_pred_flat, labels=class_ids, target_names=class_names),
)

In [None]:
# Index of the validation sample to display (-1 is the last sample)
i = -1
true_labels = y_val_split[i]          # ground-truth class map for that sample
predicted_labels = y_pred_labels[i]   # predicted class map for that sample

# Side-by-side comparison
plt.figure(figsize=(12, 4))

# Left panel: ground truth
plt.subplot(121).imshow(true_labels, cmap='viridis')
plt.gca().set_title('Ground Truth')
plt.gca().axis('off')

# Right panel: model prediction
plt.subplot(122).imshow(predicted_labels, cmap='viridis')
plt.gca().set_title('Predicted Labels')
plt.gca().axis('off')

plt.show()

# Apply Morphological Filters

In [None]:
# Morphological clean-up of the two-class label images (0 = subaqueous,
# 1 = subaerial), applied if needed.
# NOTE(review): `cv2`, `classification_methods` and `relabeled_images` are not
# imported or defined anywhere in the visible notebook; they must be provided
# by an earlier cell for this one to run.

cleaned_ims = {"otsu": [],
               "kmeans": [],
               "gmm": []
               }

# Square 3x3 structuring element shared by every image; adjust the size as needed
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

method_images = [relabeled_images['otsu'], relabeled_images["kmeans"], relabeled_images['gmm']]

for method, entry in zip(classification_methods, method_images):
    for im in entry:
        # One binary mask per class
        subaqueous = (im == 0).astype(np.uint8)
        subaerial = (im == 1).astype(np.uint8)

        # Apply opening to remove small isolated pixels
        subaerial_cleaned = cv2.morphologyEx(subaerial, cv2.MORPH_OPEN, kernel)
        subaqueous_cleaned = cv2.morphologyEx(subaqueous, cv2.MORPH_OPEN, kernel)

        # Apply closing to fill small holes
        subaerial_cleaned = cv2.morphologyEx(subaerial_cleaned, cv2.MORPH_CLOSE, kernel)
        subaqueous_cleaned = cv2.morphologyEx(subaqueous_cleaned, cv2.MORPH_CLOSE, kernel)

        # Reconstruct the classified image: each cleaned mask is weighted by its
        # own class value. The previous expression weighted the subaqueous mask
        # by the loop index and never used the subaerial mask.
        # NOTE(review): pixels that end up in neither cleaned mask become 0
        # (subaqueous) — confirm that this is the intended fallback.
        cleaned_classified_image = (subaqueous_cleaned * 0 +
                                    subaerial_cleaned * 1)

        # Add the cleaned image to the dictionary under its method name
        cleaned_ims[method].append(cleaned_classified_image)

# Check resulting classes again

# Ground truthing
- get water extent maps from various sources to serve as ground truth data for confirming the classification results

1. https://global-surface-water.appspot.com/download
2. USGS LandCover
3. Copernicus Water and Wetness Product?
4. Chesapeake Conservancy High-Resolution Land Cover Dataset
5. RAMSAR Wetlands Sites
6. MODIS Land Cover Type Product (MCD12Q1)
7. Sentinel-2 Labeled Datasets for Wetland Classification
8. OSM