In [None]:
#import mesma 
import rasterio
import matplotlib.pyplot as plt
import os
import numpy as np
import sys
import cupy as cp
import random
from cupyx.scipy.ndimage import binary_dilation, convolve
import time




In [None]:
def calculate_local_variance(image_gpu, window_size):
    """Return the per-pixel variance of ``image_gpu`` over a square moving window.

    The variance is computed as ``E[x^2] - E[x]^2``, where both expectations are
    box-filter means obtained by convolving with a uniform
    ``window_size x window_size`` kernel.

    Parameters
    ----------
    image_gpu : array
        Image to analyse. It is cast to float32 before any arithmetic.
    window_size : int
        Side length of the square averaging window, in pixels.

    Returns
    -------
    array
        float32 array with the same shape as ``image_gpu``. Values are never
        negative.

    Notes
    -----
    The convolution pads with zeros (``mode='constant', cval=0.0``), so windows
    that overlap the image border average in zeros for the missing pixels.
    """
    # Cast once and reuse for both the mean and the mean-of-squares.
    image_f32 = image_gpu.astype(cp.float32)

    mean_kernel = cp.ones((window_size, window_size), dtype=cp.float32) / (window_size * window_size)
    local_mean_gpu = convolve(image_f32, mean_kernel, mode='constant', cval=0.0)
    mean_squared_gpu = convolve(cp.square(image_f32), mean_kernel, mode='constant', cval=0.0)

    local_variance_gpu = mean_squared_gpu - cp.square(local_mean_gpu)

    # The subtraction of two nearly equal float32 numbers can round to a small
    # negative value; a variance cannot be negative, so clamp at zero.
    return cp.maximum(local_variance_gpu, 0)

In [None]:


# ---------------------------------------------------------------------------
# Build per-band kelp and ocean masks from one classified scene.
#
# Outputs consumed by later cells:
#   kelp_mask  - (4, H, W) float array, band values inside the dilated kelp mask, NaN elsewhere
#   ocean_mask - (4, H, W) float array, band values over open water, NaN elsewhere
# ---------------------------------------------------------------------------

# NOTE(review): absolute, machine-specific path -- consider deriving it from a configurable data directory.
path = r'/mnt/c/Users/attic/HLS_Kelp/imagery/rf_classified_S30/HLS.S30.T11SKU.2022006T184749.v2.0_kelp_classified.tif'
ocean_dilation = cp.ones((100, 100))  # Struct for dilation (increase to enlarge non-ocean mask) larger --> takes longer

kelp_neighborhood = 5    # side length of the square window used to count neighbouring kelp pixels
min_kelp_count = 4       # a kelp pixel survives only if its window holds MORE than this many kelp pixels
kelp_dilation_size = 15  # side length of the square structuring element that grows the kelp mask
num_EM = 30              # number of ocean endmembers drawn by the MESMA cell

variance_window_size = 15  # side length of the local-variance window
variance_threshold = .95   # quantile of local variance above which a pixel is dropped from the ocean mask

structuring_element = cp.ones((kelp_dilation_size, kelp_dilation_size))

time_st = time.time()

with rasterio.open(path) as imagery:
    # Band 5 is read as the classification raster. Judging by the variable names
    # below, class 0 is kelp, 2 is cloud and 3 is land.
    classified_img = imagery.read(5)

    # Transfer data to GPU
    classified_img_gpu = cp.array(classified_img)

    kelp_mask = []
    ocean_mask = []

    # --- Non-ocean mask: land grown by `ocean_dilation`, plus (undilated) clouds.
    time_val = time.time()
    land_dilated_gpu = classified_img_gpu == 3
    clouds_dilated_gpu = classified_img_gpu == 2
    land_dilated_gpu = binary_dilation(land_dilated_gpu, structure=ocean_dilation)
    print(f'land finished: {time.time()-time_val}')

    # Despite the name, True here marks pixels that are NOT usable ocean.
    ocean_dilated_gpu = land_dilated_gpu | clouds_dilated_gpu

    # --- Kelp mask: drop isolated detections, then grow what is left.
    is_kelp_gpu = classified_img_gpu == 0
    kernel = cp.ones((kelp_neighborhood, kelp_neighborhood), dtype=cp.int32)

    time_val = time.time()
    # Number of kelp pixels in each window (the centre pixel is included in the count).
    kelp_count_gpu = convolve(is_kelp_gpu.astype(cp.int32), kernel, mode='constant', cval=0.0)
    print(f'kelp moving average finished: {time.time()-time_val}')

    # Keep a pixel only if it is kelp AND has more than `min_kelp_count` kelp pixels around it.
    kelp_dilated_gpu = is_kelp_gpu & (kelp_count_gpu > min_kelp_count)
    time_val = time.time()
    kelp_dilated_gpu = binary_dilation(kelp_dilated_gpu, structure=structuring_element)  # I may not want to do this. we'll see
    print(f'kelp dilation finished: {time.time()-time_val}')

    # --- Apply both masks to spectral bands 1-4.
    time_val = time.time()
    for i in range(4):
        band_data = imagery.read(i + 1)
        band_data_gpu = cp.array(band_data)

        # Band values inside the kelp mask, NaN everywhere else.
        kmask_gpu = cp.where(kelp_dilated_gpu, band_data_gpu, cp.nan)

        # Pixels whose local variance exceeds the chosen quantile are excluded
        # from the ocean mask. The quantile is taken over the whole band.
        local_variance_gpu = calculate_local_variance(band_data_gpu, variance_window_size)
        max_local_variance = cp.percentile(local_variance_gpu, 100 * variance_threshold)  # threshold variance

        # Ocean = not land/cloud and not high-variance.
        final_omask_gpu = cp.where(
            ocean_dilated_gpu | (local_variance_gpu > max_local_variance),
            cp.nan,
            band_data_gpu,
        )

        kmask = cp.asnumpy(kmask_gpu)
        omask = cp.asnumpy(final_omask_gpu)

        kelp_mask.append(kmask)
        ocean_mask.append(omask)
    print(f'kBand masking and variance masking complete: {time.time()-time_val}')
    kelp_mask = np.array(kelp_mask)
    ocean_mask = np.array(ocean_mask)

    # Quick-look composite of the ocean mask, scaled by a fixed divisor of 600.
    # NOTE(review): the channels are stacked as bands [2], [0], [1]. kelp_EM in the
    # MESMA cell peaks at index 1 among the first three bands, which would suggest
    # index 1 is green; if so, green and blue are swapped here -- confirm band order.
    rgb_nor = np.stack([ocean_mask[2] / 600, ocean_mask[0] / 600, ocean_mask[1] / 600], axis=-1)
    rgb_nor_cropped = rgb_nor
    rgb_nor_cropped = np.ma.masked_where(np.isnan(rgb_nor_cropped), rgb_nor_cropped)

    # Second band of the kelp mask drawn first, ocean composite drawn on top.
    image = kelp_mask[1]
    plt.figure(figsize=(30, 30), dpi=200)
    plt.imshow(image, alpha=1)
    plt.imshow(rgb_nor_cropped, alpha=1)
    plt.colorbar()
    plt.show()

In [None]:
# ---------------------------------------------------------------------------
# Two-endmember spectral unmixing (ocean + kelp) repeated for `num_EM` randomly
# drawn ocean endmembers; for every pixel the kelp fraction of the best-fitting
# (lowest RMSE) model is kept.
#
# Output consumed by the next cell:
#   Mes2 - int16 array, kelp fraction in percent, stored TRANSPOSED (width, height)
# ---------------------------------------------------------------------------
ocean_EM_stack = []
kelp_EM = [459, 556, 437, 1227]  # fixed kelp endmember, one value per band
n_bands, height, width = kelp_mask.shape
ocean_EM_n = 0  # not used in this cell
# Flatten the spatial axes: one column per pixel, one row per band.
ocean_data = ocean_mask.reshape(ocean_mask.shape[0], -1)
kelp_data = kelp_mask.reshape(kelp_mask.shape[0], -1)

# Keep only ocean pixels that are valid (non-NaN) in every band.
nan_columns = np.isnan(ocean_data).any(axis=0)
filtered_ocean = ocean_data[:, ~nan_columns]
print(filtered_ocean.shape)
n_ocean_pixels = filtered_ocean.shape[1]
if n_ocean_pixels < 1000:
    # Raise instead of sys.exit(): a bare exit ends without an error status and
    # is easy to miss in a notebook run.
    raise RuntimeError("Insufficient number of ocean pixels")

# Draw `num_EM` ocean spectra at random (with replacement).
# NOTE(review): no random seed is set, so the endmembers differ between runs.
for _ in range(num_EM):
    index = random.randint(0, n_ocean_pixels - 1)
    ocean_EM_stack.append(filtered_ocean[:, index])
ocean_EM = np.stack(ocean_EM_stack, axis=1)  # (bands, num_EM)

# Sanity check: the sampled endmembers should average close to the full ocean population.
average_val = np.nanmean(filtered_ocean, axis=1)
average_endmember = np.nanmean(ocean_EM, axis=1)
print(f"average EM Val: {average_endmember}")
print(f"average    Val: {average_val}")

# Move everything needed by the solver to the GPU.
kelp_mask = cp.asarray(kelp_mask)
ocean_EM = cp.asarray(ocean_EM)
kelp_EM = cp.asarray(kelp_EM)
kelp_data = cp.asarray(kelp_data)

# One (height, width) slice per endmember model.
frac1 = cp.full((height, width, num_EM), cp.nan)  # ocean fraction
frac2 = cp.full((height, width, num_EM), cp.nan)  # kelp fraction
rmse = cp.full((height, width, num_EM), cp.nan)
print(rmse.shape)

for k in range(num_EM):
    # Endmember matrix: column 0 = k-th ocean spectrum, column 1 = kelp spectrum.
    B = cp.column_stack((ocean_EM[:, k], kelp_EM))
    # Pseudo-inverse of B through its SVD: V * diag(1/S) * U^T.
    U, S, Vt = cp.linalg.svd(B, full_matrices=False)
    IS = Vt.T / S
    em_inv = IS @ U.T
    # Least-squares fractions for every pixel at once; row 0 = ocean, row 1 = kelp.
    F = em_inv @ kelp_data
    model = B @ F
    resids = (kelp_data - model) / 10000
    rmse[:, :, k] = cp.sqrt(cp.mean(resids**2, axis=0)).reshape(height, width)
    frac1[:, :, k] = F[0, :].reshape(height, width)
    frac2[:, :, k] = F[1, :].reshape(height, width)

# Pick, per pixel, the model with the smallest RMSE.
# NOTE(review): pixels outside the kelp mask are NaN in kelp_data, so their RMSE is
# NaN for every k and they have no valid argmin. Confirm how cp.nanargmin and
# cp.ravel_multi_index treat that case in the installed CuPy version. Whatever index
# is chosen, frac2 is NaN there for every k, and that NaN is zeroed below.
minVals = cp.nanmin(rmse, axis=2)
PageIdx = cp.nanargmin(rmse, axis=2)

rows, cols = cp.meshgrid(cp.arange(rmse.shape[0]), cp.arange(rmse.shape[1]), indexing='ij')

# Flat index of (row, col, best model) into the raveled frac2 cube.
Zindex = cp.ravel_multi_index((rows, cols, PageIdx), dims=rmse.shape)

Mes2 = frac2.ravel()[Zindex]
Mes2 = Mes2.T  # the next cell transposes back before plotting
Mes2 = -0.229 * Mes2**2 + 1.449 * Mes2 - 0.018 #Landsat mesma corrections 
Mes2 = cp.clip(Mes2, 0, None)  # Ensure no negative values
# Unmodelled pixels are still NaN here. Casting NaN to an integer type has no
# defined result, so set them to 0 explicitly before the conversion.
Mes2 = cp.where(cp.isnan(Mes2), 0, Mes2)
Mes2 = cp.round(Mes2 * 100).astype(cp.int16)

In [None]:
# Pull the results back to host memory. Mes2 is stored transposed by the MESMA
# cell, so flip it back to image orientation.
kelp_img = cp.asnumpy(kelp_mask)
Mes_array = cp.asnumpy(Mes2).T

# Turn zero-valued pixels into NaN so they are left blank when drawn.
Mes_vis = np.where(Mes_array == 0, np.nan, Mes_array)
kelp_vis = np.where(kelp_img == 0, np.nan, kelp_img)

# Fixed crop window shown in the figure.
crop_rows = slice(2800, 3100)
crop_cols = slice(800, 1400)

fig, ax = plt.subplots(figsize=(20, 20), dpi=200)
# Background: second band of the kelp mask in grey.
ax.imshow(kelp_vis[1][crop_rows, crop_cols], cmap='Greys')
# Foreground: MESMA kelp fraction; the colorbar refers to this layer.
mes_layer = ax.imshow(Mes_vis[crop_rows, crop_cols], alpha=1)
fig.colorbar(mes_layer)
plt.show()