In [32]:
import os
from datetime import datetime
from itertools import product
import rasterio
from rasterio import windows
from shapely.geometry import box
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.ticker as mticker

# Functions

In [2]:
def get_labels(labelpath):
    combinedlabels = [os.path.join(labelpath, f'combined/{file}') for file in os.listdir(os.path.join(labelpath, f'combined')) if file.endswith('.vrt')]

    return combinedlabels

def get_grd(grdpath):
    orig_ims = [os.path.join(grdpath, file) for file in os.listdir(grdpath) if file.endswith('.tif')]

    return orig_ims

def get_glcm(glcmpath):
    orig_glcms = [os.path.join(glcmpath, file) for file in os.listdir(glcmpath) if file.endswith('.tif')]

    return orig_glcms

def find_closest_dates(labels, backscatter_ims, glcm_ims, max_days=12):
    closest_dates = []  # To store the closest matches for each label

    # Iterate through each label
    for label in labels:
        label_date = datetime.strptime(label[-14:-4], '%Y-%m-%d')  # Extract date from label
        min_diff = max_days + 1  # Initialize minimum difference as larger than max_days
        closest_backscatter = None  # To store the closest backscatter match
        closest_glcm = None  # To store the closest GLCM match

        # Iterate through both backscatter and GLCM images
        for backscatter, glcm in zip(backscatter_ims, glcm_ims):
            backscatter_date = datetime.strptime(backscatter[-14:-4], '%Y-%m-%d')  # Extract date from backscatter
            glcm_date = datetime.strptime(glcm[-14:-4], '%Y-%m-%d')  # Extract date from GLCM

            # Calculate the absolute difference in days
            day_difference = abs((backscatter_date - label_date).days)

            # Check if the difference is within max_days and closer than the current minimum
            if day_difference <= max_days and day_difference < min_diff:
                min_diff = day_difference
                closest_backscatter = backscatter
                closest_glcm = glcm

        # Store the closest matches for the current label
        closest_dates.append((label, closest_backscatter, closest_glcm))

    return closest_dates

def plot_unsup_labels(filtered_pairs):
    for image_path in filtered_pairs:
        with rasterio.open(image_path[0]) as src:
            kmeans_class = src.read(3)
            gmm_class = src.read(4)
            transform = src.transform
            height, width = kmeans_class.shape
            top_left = rasterio.transform.xy(transform, 0, 0, offset='center')
            bottom_right = rasterio.transform.xy(transform, height-1, width-1, offset='center')

        min_easting, max_northing = top_left
        max_easting, min_northing = bottom_right

        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        # Manual classification visualization
        ax[0].imshow(kmeans_class, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[0].set_title(f'{image_path[0][-14:-4]} (KMeans)')
        ax[0].set_xlabel('Easting (meters)')
        ax[0].set_ylabel('Northing (meters)')
        ax[0].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
        # Otsu classification visualization
        ax[1].imshow(gmm_class, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[1].set_title(f'{image_path[0][-14:-4]} (GMM)')
        ax[1].set_xlabel('Easting (meters)')
        ax[1].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks

        # Show the plot with layout adjustments
        plt.tight_layout()
        plt.show()

def plot_sup_labels(filtered_pairs):
    for image_path in filtered_pairs:
        with rasterio.open(image_path[0]) as src:
            man_class = src.read(1)
            otsu_class = src.read(2)
            transform = src.transform
            height, width = man_class.shape
            top_left = rasterio.transform.xy(transform, 0, 0, offset='center')
            bottom_right = rasterio.transform.xy(transform, height-1, width-1, offset='center')

        min_easting, max_northing = top_left
        max_easting, min_northing = bottom_right

        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        # Manual classification visualization
        ax[0].imshow(man_class, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[0].set_title(f'{image_path[0][-14:-4]} (Manual)')
        ax[0].set_xlabel('Easting (meters)')
        ax[0].set_ylabel('Northing (meters)')
        ax[0].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
        # Otsu classification visualization
        ax[1].imshow(otsu_class, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[1].set_title(f'{image_path[0][-14:-4]} (Otsu)')
        ax[1].set_xlabel('Easting (meters)')
        ax[1].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks

        # Show the plot with layout adjustments
        plt.tight_layout()
        plt.show()

def plot_class_with_sar(filtered_pairs):
    for image_path in filtered_pairs:
        with rasterio.open(image_path[0]) as label_src:
            manual = label_src.read(1)  
            otsu = label_src.read(2)
            kmeans = label_src.read(3)
            gmm = label_src.read(4)

        with rasterio.open(image_path[1]) as sar_src:
            vv = sar_src.read(1)  
            vh = sar_src.read(2)

            transform = sar_src.transform
            height, width = vh.shape[:2]
            top_left = rasterio.transform.xy(transform, 0, 0, offset='center')
            bottom_right = rasterio.transform.xy(transform, height-1, width-1, offset='center')

        # Extract easting and northing from the corners
        min_easting, max_northing = top_left
        max_easting, min_northing = bottom_right
    
        fig, ax = plt.subplots(1, 6, figsize=(50, 10))  # 5 subplots for RGB, manual, Otsu, K-Means, and GMM
    
        ax[0].imshow(vv,cmap = 'gray', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[0].set_title(f'{image_path[0][-14:-4]} VV Backscatter')
        ax[0].set_xlabel('Easting (meters)')
        ax[0].set_ylabel('Northing (meters)')
        ax[0].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
    
        ax[1].imshow(vh,cmap = 'gray', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[1].set_title(f'{image_path[0][-14:-4]} VH Backscatter')
        ax[1].set_xlabel('Easting (meters)')
        ax[1].set_ylabel('Northing (meters)')
        ax[1].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
    
        # Custom legends for classification plots
        red_patch = mpatches.Patch(color='red', label='Land')
        blue_patch = mpatches.Patch(color='blue', label='Water')
    
        # Manual classification visualization
        ax[2].imshow(manual, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[2].set_title('Manual')
        ax[2].set_xlabel('Easting (meters)')
        ax[2].legend(handles=[red_patch, blue_patch], loc='lower right', title="Classification")
        ax[2].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
    
        # Otsu classification visualization
        ax[3].imshow(otsu, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[3].set_title('Otsu')
        ax[3].set_xlabel('Easting (meters)')
        ax[3].legend(handles=[red_patch, blue_patch], loc='lower right', title="Classification")
        ax[3].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
    
        # K-Means classification visualization
        ax[4].imshow(kmeans, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[4].set_title('KMeans')
        ax[4].set_xlabel('Easting (meters)')
        ax[4].legend(handles=[red_patch, blue_patch], loc='lower right', title="Classification")
        ax[4].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
    
        # GMM classification visualization
        ax[5].imshow(gmm, cmap = 'coolwarm', extent=[min_easting, max_easting, min_northing, max_northing])
        ax[5].set_title('GMM')
        ax[5].set_xlabel('Easting (meters)')
        ax[5].legend(handles=[red_patch, blue_patch], loc='lower right', title="Classification")
        ax[5].xaxis.set_major_locator(mticker.MaxNLocator(5))  # Reduce x-axis ticks
    
        # Show the plot with layout adjustments
        plt.tight_layout()
        plt.show()

# Collect Imagery for model training

In [3]:
labels = get_labels('/mnt/d/SabineRS/s2classifications')
backscatter_ims = get_grd('/mnt/d/SabineRS/GRD/2_registered/backscatter')
glcm_ims = get_glcm('/mnt/d/SabineRS/GRD/2_registered/glcm')

In [4]:
# pair the Sentinel-1 backscatter and glcm  with labels according to date
labeledPairs = find_closest_dates(labels, backscatter_ims, glcm_ims)

# Filter out tuples that contain any None entries
# no close matches between S2 labels and S1 images
filtered_data = [entry for entry in labeledPairs if None not in entry]
filtered_data

[('/mnt/d/SabineRS/s2classifications/combined/labels_2019-09-06.vrt',
  '/mnt/d/SabineRS/GRD/2_registered/backscatter/s1_2019-09-07.tif',
  '/mnt/d/SabineRS/GRD/2_registered/glcm/s1_2019-09-07.tif'),
 ('/mnt/d/SabineRS/s2classifications/combined/labels_2019-11-15.vrt',
  '/mnt/d/SabineRS/GRD/2_registered/backscatter/s1_2019-11-18.tif',
  '/mnt/d/SabineRS/GRD/2_registered/glcm/s1_2019-11-18.tif'),
 ('/mnt/d/SabineRS/s2classifications/combined/labels_2020-01-24.vrt',
  '/mnt/d/SabineRS/GRD/2_registered/backscatter/s1_2020-01-29.tif',
  '/mnt/d/SabineRS/GRD/2_registered/glcm/s1_2020-01-29.tif'),
 ('/mnt/d/SabineRS/s2classifications/combined/labels_2020-09-30.vrt',
  '/mnt/d/SabineRS/GRD/2_registered/backscatter/s1_2020-09-25.tif',
  '/mnt/d/SabineRS/GRD/2_registered/glcm/s1_2020-09-25.tif'),
 ('/mnt/d/SabineRS/s2classifications/combined/labels_2020-10-10.vrt',
  '/mnt/d/SabineRS/GRD/2_registered/backscatter/s1_2020-10-07.tif',
  '/mnt/d/SabineRS/GRD/2_registered/glcm/s1_2020-10-07.tif'),


In [33]:
# plot_sup_labels(filtered_data)
# plot_unsup_labels(filtered_data)

# Data augmentation for training
- segmented training (2x2 grid each image, each individual cell from grids used as training data)
- rotation, flipping, affine
- multiple sites?

In [33]:
out_path = '/mnt/d'
output_filename = 'tile_{}-{}.tif'

In [35]:
def get_tiles(ds, width, height):
    nols, nrows = ds.meta['width'], ds.meta['height']
    offsets = product(range(0, nols, width), range(0, nrows, height))
    big_window = windows.Window(col_off=0, row_off=0, width=nols, height=nrows)
    for col_off, row_off in  offsets:
        window =windows.Window(col_off=col_off, row_off=row_off, width=width, height=height).intersection(big_window)
        transform = windows.transform(window, ds.transform)
        yield window, transform


with rasterio.open(filtered_data[0][0]) as inds:
    combined_classes = inds.read(1)
    height, width = combined_classes.shape
    tile_width, tile_height = width // 2, height //2


    meta = inds.meta.copy()

    for window, transform in get_tiles(inds, tile_width, tile_height):
        print(window)
        meta['transform'] = transform
        meta['width'], meta['height'] = window.width, window.height
        outpath = os.path.join(out_path,output_filename.format(int(window.col_off), int(window.row_off)))
        with rasterio.open(outpath, 'w', **meta) as outds:
            outds.write(inds.read(window=window))

Window(col_off=0, row_off=0, width=346, height=399)


RasterioIOError: Write failed. See previous exception for details.

In [24]:
with rasterio.open(filtered_data[0][1]) as testsrc:
    combined_classes = testsrc.read(1)
    h, w = combined_classes.shape
    print(h, w)

798 693


In [31]:
pairs = []
grids = {}

for i, im in enumerate(filtered_data):
    with rasterio.open(im[0]) as testsrc:
        combined_classes = testsrc.read()
        kmeans = testsrc.read(1)
        h, w = kmeans.shape
        classgridh, classgridw = h // 2, w // 2 
    
    with rasterio.open(im[1]) as testsrc:
        grd = testsrc.read()
        vv = testsrc.read(1)
        h, w = vv.shape
        grdgridh , grdgridw = h // 2, w // 2 

    with rasterio.open(im[2]) as testsrc:
        glcm = testsrc.read()
        band1 = testsrc.read(1)
        h, w = band1.shape
        glcmgridh , glcmgridw = h // 2, w // 2 

    for i in range(2):
        for j in range(2):
            classgrid = combined_classes[i * classgridh:(i + 1) * classgridh, j * classgridw:(j + 1) * classgridw, :]
            grdgrid = grd[i * grdgridh:(i + 1) * grdgridh, j * grdgridw:(j + 1) * grdgridw, :]
            glcmgrid = glcm[i * glcmgridh:(i + 1) * glcmgridh, j * glcmgridw:(j + 1) * glcmgridw, :]

            grids[f'{im[0][-14:-4]}'] = (classgrid, grdgrid, glcmgrid)
            pairs.append(grids)


pairs

[{'2019-09-06': (array([], shape=(0, 346, 693), dtype=uint8),
   array([], shape=(0, 346, 693), dtype=float32),
   array([], shape=(0, 346, 693), dtype=float32)),
  '2019-11-15': (array([], shape=(0, 346, 693), dtype=uint8),
   array([], shape=(0, 346, 693), dtype=float32),
   array([], shape=(0, 346, 693), dtype=float32)),
  '2020-01-24': (array([], shape=(0, 346, 693), dtype=uint8),
   array([], shape=(0, 346, 693), dtype=float32),
   array([], shape=(0, 346, 693), dtype=float32)),
  '2020-09-30': (array([], shape=(0, 346, 693), dtype=uint8),
   array([], shape=(0, 346, 693), dtype=float32),
   array([], shape=(0, 346, 693), dtype=float32)),
  '2020-10-10': (array([], shape=(0, 346, 693), dtype=uint8),
   array([], shape=(0, 346, 693), dtype=float32),
   array([], shape=(0, 346, 693), dtype=float32)),
  '2020-10-30': (array([], shape=(0, 346, 693), dtype=uint8),
   array([], shape=(0, 346, 693), dtype=float32),
   array([], shape=(0, 346, 693), dtype=float32)),
  '2020-11-04': (array

In [None]:
# Assuming image is a numpy array with shape (height, width, channels)
def split_image_into_grids(image, grid_size=(2, 2)):
    h, w, c = image.shape
    grid_h, grid_w = h // grid_size[0], w // grid_size[1]
    grids = []
    for i in range(grid_size[0]):
        for j in range(grid_size[1]):


# Example usage
grids = split_image_into_grids(image, grid_size=(2, 2))


In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Create an image data generator with augmentation
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.2,
    fill_mode='nearest'
)

# Example: apply augmentation to a single image (expand_dims adds a batch dimension)
augmented_image = datagen.flow(np.expand_dims(image, axis=0), batch_size=1)


# 5. Train NN

In [None]:


# Define the model
model = models.Sequential()

# Input layer (shape depends on the number of features: VV, VH, GLCM features)
model.add(layers.InputLayer(input_shape=(10,)))  # Assuming 10 features

# Hidden layers
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(32, activation='relu'))

# Output layer for binary classification (water vs land)
model.add(layers.Dense(1, activation='sigmoid'))

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)


# 6. Evaluate the accuracy of the NN