In [None]:
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
import numpy as np
import matplotlib.pyplot as plt
import os
import re
#from scipy.ndimage import binary_dilation, uniform_filter
from rasterio.errors import RasterioIOError
import csv
from skimage import io
import requests
from PIL import Image
from io import BytesIO
import pandas as pd
import gc
import cudf
import cuml
from cuml.ensemble import RandomForestClassifier as cuRF
from cuml.model_selection import train_test_split
from scipy.stats import randint
from cupyx.scipy.ndimage import binary_dilation, convolve
import time

import sys
import pickle
import cupy as cp
import random

# ---- Run switches -------------------------------------------------------
retrain = False          # not referenced by the cells visible in this notebook
reclassify = False       # reclassify previously classified images
show_image = True        # display diagnostic figures while processing
save_final_data = True   # write one processed GeoTIFF per granule

# ---- Scene selection and limits ----------------------------------------
tile = '11SKU'
location = 'Isla_Vista_Kelp_2018'
cloud_cover_threshold = 0.8   # granules whose ocean cloud fraction exceeds this are skipped
num_iterations = 1000         # cap on granules handled by the processing loop

# ---- Input / output locations ------------------------------------------
# NOTE(review): absolute, machine-specific paths; consider a configurable base directory.
path = os.path.join(r'/mnt/c/Users/attic/HLS_Kelp/imagery', location, tile)
save_to_path = '/mnt/c/Users/attic/HLS_Kelp/imagery/Isla_vista_kelp_2018/processed'
dem_path = r'/mnt/c/Users/attic/HLS_Kelp/imagery/Socal_DEM.tiff'
rf_path = r'/mnt/c/users/attic/hls_kelp/random_forest/cu_rf7_S30'


In [None]:
def load_granule_data(path):
    """Locate the first granule directory under ``path`` and list its files.

    Entries are examined in the order returned by ``os.listdir``; the first
    entry that is a directory is used.

    Args:
        path: Directory that contains one sub-directory per granule.

    Returns:
        Tuple ``(img_path, granule, files)``: the full path of the granule
        directory, its name, and the names of the files inside it.

    Raises:
        FileNotFoundError: If ``path`` contains no sub-directory. (Previously
            this surfaced as an UnboundLocalError on ``img_path``.)
    """
    for item in os.listdir(path):
        candidate = os.path.join(path, item)
        if os.path.isdir(candidate):
            return candidate, item, os.listdir(candidate)
    raise FileNotFoundError(f"No granule directories found in {path}")


def get_sensor_bands(granule):
    """Return the six band identifiers to load for a granule.

    The sensor code is the second dot-separated field of the granule name.
    ``'L30'`` selects B02-B07; any other code selects B02, B03, B04, B8A,
    B11 and B12.

    Args:
        granule: Granule name containing at least one ``.``.

    Returns:
        List of six band-name strings.

    Raises:
        IndexError: If ``granule`` contains no ``.`` separator.
    """
    sensor_code = granule.split('.')[1]
    l30_bands = ['B02', 'B03', 'B04', 'B05', 'B06', 'B07']
    other_bands = ['B02', 'B03', 'B04', 'B8A', 'B11', 'B12']
    return l30_bands if sensor_code == 'L30' else other_bands
    
def load_rf_classifier(rf_path):
    """Unpickle and return the object stored at ``rf_path``.

    Args:
        rf_path: Path to a pickle file (opened in binary mode).

    Returns:
        Whatever object the pickle contains.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files that come from a trusted source.
    with open(rf_path, 'rb') as model_file:
        return pickle.load(model_file)

In [None]:
def valid_granule(img_files, sensor_bands, files, item):
    """Check that a granule has an Fmask file and one image file per band.

    Args:
        img_files: Band image file names selected for the granule.
        sensor_bands: Band identifiers expected for the sensor.
        files: All file names in the granule directory.
        item: Granule name, used only in the diagnostic message.

    Returns:
        True when some name in ``files`` ends with ``Fmask.tif`` and
        ``img_files`` has as many entries as ``sensor_bands``; otherwise
        prints a message and returns False.
    """
    has_fmask = any(re.search(r'Fmask\.tif$', name) for name in files)
    if has_fmask and len(img_files) == len(sensor_bands):
        return True
    print(f"Incomplete or invalid granule: {item}")
    return False

In [None]:
def reproject_dem_to_hls(hls_path, dem_path='/mnt/c/Users/attic/HLS_Kelp/imagery/Socal_DEM.tiff'):
    """Resample the DEM onto the grid of an HLS raster.

    The original signature placed the defaulted ``dem_path`` before the
    required ``hls_path``, which is a SyntaxError; the required argument now
    comes first, matching how ``create_land_mask`` calls this function.

    Args:
        hls_path: Path of the HLS raster whose CRS, transform and shape define
            the output grid.
        dem_path: Path of the DEM raster.

    Returns:
        Tuple ``(reprojected_dem, hls)``: the DEM resampled (bilinear) to the
        HLS grid as a 2-D array with the HLS dtype, and the full HLS array as
        returned by ``read()``.
    """
    # Both datasets are opened in one ``with`` so the DEM handle is closed too.
    with rasterio.open(hls_path) as dst, rasterio.open(dem_path) as dem:
        hls = dst.read()
        reprojected_dem = np.zeros((hls.shape[1], hls.shape[2]), dtype=hls.dtype)
        # Always resample: even when the CRSs match the DEM still has to be
        # brought onto the HLS grid. Skipping this step previously left
        # ``reprojected_dem`` undefined.
        reproject(
            source=dem.read(1),
            destination=reprojected_dem,
            src_transform=dem.transform,
            src_crs=dem.crs,
            dst_transform=dst.transform,
            dst_crs=dst.crs,
            resampling=Resampling.bilinear)
    return reprojected_dem, hls

def generate_land_mask(reprojected_dem, land_dilation=5, show_image=False):
    """Build a dilated land mask from a DEM already on the image grid.

    Pixels whose DEM value is greater than zero are treated as land, and the
    result is grown with a ``land_dilation`` x ``land_dilation`` structuring
    element on the GPU.

    Args:
        reprojected_dem: 2-D DEM array.
        land_dilation: Side length of the square structuring element.
        show_image: When True, display the mask with matplotlib.

    Returns:
        Boolean NumPy array, True on (dilated) land.

    Note:
        If the DEM has no non-zero value a message is printed and
        ``sys.exit()`` is called.
    """
    if not reprojected_dem.any():
        print("Something failed, you better go check...")
        sys.exit()

    footprint = cp.ones((land_dilation, land_dilation))
    above_zero = cp.asarray(reprojected_dem) > 0
    land_mask = cp.asnumpy(binary_dilation(above_zero, structure=footprint))

    if show_image:
        plt.figure(figsize=(6, 6))
        plt.imshow(land_mask, cmap='gray')
        plt.show()
    return land_mask

def create_land_mask(hls_path, dem_path='/mnt/c/Users/attic/HLS_Kelp/imagery/Socal_DEM.tiff'):
    """Create a dilated land mask on the grid of the HLS raster at ``hls_path``.

    Args:
        hls_path: Path of the HLS raster defining the output grid.
        dem_path: Path of the DEM raster. It is now actually forwarded; the
            original ignored this argument and always used the default path.

    Returns:
        The array returned by ``generate_land_mask`` for the resampled DEM.
    """
    # reproject_dem_to_hls returns (dem, hls); only the DEM is needed here.
    # Keyword arguments keep the call independent of parameter order.
    reprojected_dem, _hls = reproject_dem_to_hls(hls_path=hls_path, dem_path=dem_path)
    return generate_land_mask(reprojected_dem)

In [None]:
def create_mesma_mask(
    classified_img,
    band_data,
    ocean_dilation_size=100,  # Struct size for dilation
    kelp_neighborhood=5,
    min_kelp_count=4,
    kelp_dilation_size=15,
    variance_window_size=15,
    variance_threshold=0.95,
    n_bands=4,
):
    """Build per-band kelp and ocean masks from a classified image.

    The masking logic treats class 0 as kelp, class 2 as cloud and class 3 as
    land (per the comparisons and variable names used here).

    Kelp mask: class-0 pixels with more than ``min_kelp_count`` class-0 pixels
    in their ``kelp_neighborhood`` window (the pixel itself included), then
    dilated by ``kelp_dilation_size``. Band values are kept inside that region
    and set to NaN elsewhere.

    Ocean mask: band values are set to NaN where the pixel is cloud, lies
    within the dilated land class, or has local variance above the
    ``variance_threshold`` quantile of that band's variance.

    Args:
        classified_img: 2-D array of class labels.
        band_data: Band stack indexed as ``band_data[i]`` -> 2-D band. The
            original ignored this argument and read the notebook global
            ``img`` instead.
        ocean_dilation_size: Side of the square element used to dilate land.
        kelp_neighborhood: Side of the window used to count neighbouring kelp.
        min_kelp_count: Kelp pixels with a count at or below this are dropped.
        kelp_dilation_size: Side of the square element used to dilate kelp.
        variance_window_size: Window side for the local-variance filter.
        variance_threshold: Quantile in [0, 1] above which variance is masked.
        n_bands: Number of leading bands of ``band_data`` to process.

    Returns:
        Tuple ``(kelp_mask, ocean_mask)`` of NumPy arrays, each stacked over
        ``n_bands`` bands, with NaN marking excluded pixels.
    """
    ocean_struct = cp.ones((ocean_dilation_size, ocean_dilation_size))
    kelp_struct = cp.ones((kelp_dilation_size, kelp_dilation_size))
    classified_gpu = cp.asarray(classified_img)

    # Everything that must not be used as an ocean sample: clouds plus a wide
    # buffer around the land class.
    time_val = time.time()
    land_dilated_gpu = binary_dilation(classified_gpu == 3, structure=ocean_struct)
    print(f'land finished: {time.time()-time_val}')
    not_ocean_gpu = land_dilated_gpu | (classified_gpu == 2)

    # Count kelp pixels in each neighbourhood to drop isolated detections.
    is_kelp_gpu = classified_gpu == 0
    kernel = cp.ones((kelp_neighborhood, kelp_neighborhood), dtype=cp.int32)
    time_val = time.time()
    kelp_count_gpu = convolve(is_kelp_gpu.astype(cp.int32), kernel, mode='constant', cval=0.0)
    print(f'kelp moving average finished: {time.time()-time_val}')

    kelp_region_gpu = is_kelp_gpu & (kelp_count_gpu > min_kelp_count)
    time_val = time.time()
    kelp_region_gpu = binary_dilation(kelp_region_gpu, structure=kelp_struct)
    print(f'kelp dilation finished: {time.time()-time_val}')

    kelp_mask = []
    ocean_mask = []
    time_val = time.time()
    for i in range(n_bands):
        band_gpu = cp.asarray(band_data[i])

        kmask_gpu = cp.where(kelp_region_gpu, band_gpu, cp.nan)

        # Pixels in the top (1 - variance_threshold) share of local variance
        # are rejected as ocean samples.
        local_variance_gpu = calculate_local_variance(band_gpu, variance_window_size)
        max_local_variance = cp.percentile(local_variance_gpu, 100 * variance_threshold)
        rejected_gpu = not_ocean_gpu | (local_variance_gpu > max_local_variance)
        omask_gpu = cp.where(rejected_gpu, cp.nan, band_gpu)

        kelp_mask.append(cp.asnumpy(kmask_gpu))
        ocean_mask.append(cp.asnumpy(omask_gpu))
    print(f'kBand masking and variance masking complete: {time.time()-time_val}')
    return np.array(kelp_mask), np.array(ocean_mask)


In [None]:
def select_ocean_endmembers(ocean_mask, print_average=False, n=30, min_pixels=500):
    """Randomly pick ``n`` ocean spectra to use as MESMA endmembers.

    Pixels with NaN in any band are discarded first. Columns are then drawn
    with replacement using ``random.randint``.

    Args:
        ocean_mask: Array of shape (bands, ...) with NaN marking excluded
            pixels; trailing dimensions are flattened.
        print_average: When True, print the mean of the selected endmembers
            and the mean of all valid ocean pixels.
        n: Number of endmembers to draw.
        min_pixels: Minimum number of valid pixels required.

    Returns:
        Array of shape (bands, n), or ``False`` when fewer than ``min_pixels``
        valid pixels (or none at all) are available.
    """
    ocean_data = ocean_mask.reshape(ocean_mask.shape[0], -1)
    valid_columns = ~np.isnan(ocean_data).any(axis=0)
    filtered_ocean = ocean_data[:, valid_columns]
    n_valid = filtered_ocean.shape[1]

    # n_valid == 0 is checked separately so that min_pixels=0 cannot lead to
    # random.randint(0, -1).
    if n_valid < min_pixels or n_valid == 0:
        print("Too few valid ocean end-members")
        # The original message interpolated the notebook global ``file``.
        print(f"Skipping img: {n_valid} valid ocean pixels, {min_pixels} required...  :(")
        return False

    # Indices are collected locally; the original appended to a list that was
    # never defined inside the function.
    picks = [random.randint(0, n_valid - 1) for _ in range(n)]
    ocean_EM = filtered_ocean[:, picks]

    if print_average:
        average_val = np.nanmean(filtered_ocean, axis=1)
        average_endmember = np.nanmean(ocean_EM, axis=1)
        print(f"average EM Val: {average_endmember}")
        print(f"average    Val: {average_val}")
    return ocean_EM

In [None]:
def run_mesma(kelp_mask, ocean_EM, n=30, kelp_EM=None):
    """Unmix every pixel as ocean + kelp and keep the best-fitting model.

    For each of the first ``n`` ocean endmembers a two-endmember least-squares
    model (ocean, kelp) is solved via the SVD pseudo-inverse. Per pixel, the
    kelp fraction of the model with the lowest RMSE is kept, passed through
    the quadratic correction ``-0.229*f**2 + 1.449*f - 0.018``, clipped at 0
    and stored as integer percent.

    The original body relied on the notebook globals ``kelp_EM``,
    ``kelp_data``, ``height`` and ``width``; they are now derived from the
    arguments. It also allocated three (H, W, n) cubes (one of them unused);
    a running best per pixel gives the same selection without them.

    Args:
        kelp_mask: Array of shape (bands, height, width); NaN marks pixels to
            ignore.
        ocean_EM: Array of shape (bands, >= n), one ocean endmember per column.
        n: Number of ocean endmembers to try.
        kelp_EM: Kelp endmember spectrum, one value per band. Defaults to
            ``[459, 556, 437, 1227]``, the value used by the processing loop
            in this notebook.

    Returns:
        Tuple ``(Mes2, minVals)``. ``Mes2`` is an int16 array of kelp percent
        with shape (width, height) - transposed, as in the original - with 0
        where no model could be fitted. ``minVals`` is the per-pixel minimum
        RMSE with shape (height, width), NaN where no model could be fitted.
    """
    if kelp_EM is None:
        kelp_EM = [459, 556, 437, 1227]
    kelp_mask = cp.asarray(kelp_mask)
    ocean_EM = cp.asarray(ocean_EM)
    kelp_EM = cp.asarray(kelp_EM)

    n_bands, height, width = kelp_mask.shape
    kelp_data = kelp_mask.reshape(n_bands, -1)
    n_pixels = kelp_data.shape[1]

    best_rmse = cp.full(n_pixels, cp.inf)
    best_frac = cp.full(n_pixels, cp.nan)

    print("Running MESMA")
    for k in range(n):
        B = cp.column_stack((ocean_EM[:, k], kelp_EM))      # (bands, 2) endmember matrix
        U, S, Vt = cp.linalg.svd(B, full_matrices=False)
        em_inv = (Vt.T / S) @ U.T                           # pseudo-inverse of B
        F = em_inv @ kelp_data                              # (2, pixels): ocean, kelp fractions
        resids = (kelp_data - B @ F) / 10000
        rmse_k = cp.sqrt(cp.mean(resids**2, axis=0))
        # Strict '<' keeps the earliest endmember on ties; NaN RMSE never wins.
        improved = rmse_k < best_rmse
        best_rmse = cp.where(improved, rmse_k, best_rmse)
        best_frac = cp.where(improved, F[1, :], best_frac)

    minVals = cp.where(cp.isinf(best_rmse), cp.nan, best_rmse).reshape(height, width)
    Mes2 = best_frac.reshape(height, width).T
    Mes2 = -0.229 * Mes2**2 + 1.449 * Mes2 - 0.018  # Landsat mesma corrections
    Mes2 = cp.clip(Mes2, 0, None)                   # Ensure no negative values
    Mes2 = cp.where(cp.isnan(Mes2), 0, Mes2)        # unfitted pixels -> 0 before the integer cast
    Mes2 = cp.round(Mes2 * 100).astype(cp.int16)
    return Mes2, minVals


In [None]:
   
def create_qa_mask(land_mask, files, granule_dir=None):
    """Build cloud/land masks from a granule's Fmask band.

    A pixel is flagged when any of Fmask bits 1, 2, 3 or 4 is set, or when
    bits 6 and 7 are both set (cloud, cloud-adjacent, cloud shadow, ice and
    aerosol respectively, following the names used in the processing loop).

    Args:
        land_mask: 2-D boolean array, True on land, on the Fmask grid.
        files: File names in the granule directory.
        granule_dir: Directory containing ``files``. When omitted, the
            notebook-level global ``img_path`` is used, which is what the
            original body did implicitly.

    Returns:
        ``False`` if no Fmask file is listed or it cannot be read. Otherwise a
        tuple ``(cloud_land_mask, cloud_but_not_land_mask,
        percent_cloud_covered)``: two boolean arrays and the fraction (0-1) of
        non-land pixels that are flagged. When there are no non-land pixels
        the fraction is reported as 1.0, so a threshold check skips the
        granule instead of dividing by zero.
    """
    if granule_dir is None:
        granule_dir = img_path  # notebook global set by the granule-selection code

    f_mask = [f for f in files if re.search(r'Fmask\.tif$', f)]
    if not f_mask:
        print(f"Invalid granule: {granule_dir}")
        return False

    ##==========Fmask Cloud mask==========##
    try:
        with rasterio.open(os.path.join(granule_dir, f_mask[0])) as fmask:
            qa_band = fmask.read(1)
    except RasterioIOError as e:
        print(f"Error reading file {f_mask[0]} in granule {granule_dir}: {e}")
        return False

    # Each term isolates a single bit (0 or 1); the result is cast to bool so
    # that '~' and '&' below behave as logical operators, not integer ones.
    cloud_mask = (
        ((qa_band >> 1) & 1)
        | ((qa_band >> 2) & 1)
        | ((qa_band >> 3) & 1)
        | ((qa_band >> 4) & 1)
        | (((qa_band >> 6) & 1) & ((qa_band >> 7) & 1))
    ).astype(bool)
    land_mask = np.asarray(land_mask, dtype=bool)

    ##========== Determine percentage of ocean covered by clouds ==========##
    cloud_land_mask = cloud_mask | land_mask
    cloud_but_not_land_mask = cloud_mask & ~land_mask
    num_pixels_cloud_not_land = np.count_nonzero(cloud_but_not_land_mask)
    num_pixels_not_land = np.count_nonzero(~land_mask)
    if num_pixels_not_land == 0:
        percent_cloud_covered = 1.0
    else:
        percent_cloud_covered = num_pixels_cloud_not_land / num_pixels_not_land
    return cloud_land_mask, cloud_but_not_land_mask, percent_cloud_covered


In [None]:
def normalize_img(img, flatten=False):
    """Scale each pixel's band values by the pixel's band sum, to 0-255.

    Args:
        img: Array of shape (bands, ...); trailing dimensions are flattened
            into a pixel axis.
        flatten: When True, return the (pixels, bands) layout. When False
            (default), return the transposed (bands, pixels) layout, which is
            what the original returned for this case.

    Returns:
        uint8 array holding ``value / (band_sum + 1e-10) * 255``, truncated.
        Pixels whose band sum is 0 are 0. ``flatten=True`` previously fell
        through and returned ``None``.
    """
    img_2D = img.reshape(img.shape[0], -1).T
    img_sum = img_2D.sum(axis=1)
    epsilon = 1e-10
    # ``out`` is required with ``where``: without it, entries where the
    # condition is False are left uninitialised.
    img_2D_nor = np.divide(
        img_2D,
        img_sum[:, None] + epsilon,
        out=np.zeros(img_2D.shape, dtype=np.float64),
        where=(img_sum[:, None] != 0),
    )
    img_2D_nor = (img_2D_nor * 255).astype(np.uint8)
    if flatten:
        return img_2D_nor
    return img_2D_nor.T
   

In [None]:



### ==================== Create DEM mask ====================  ####
# The first granule directory found under `path` provides the reference grid
# onto which the DEM is resampled.
granules = os.listdir(path)
img_path = None
granule = None
for item in granules:
    candidate_dir = os.path.join(path, item)
    if os.path.isdir(candidate_dir):
        img_path = candidate_dir
        granule = item
        break
if img_path is None:
    # Previously this surfaced later as a NameError on `img_path`.
    raise FileNotFoundError(f"No granule directories found in {path}")

files = os.listdir(img_path)
file_data = granule.split('.')
sensor = file_data[1]
if sensor == 'L30':
    sensor_bands = ['B02','B03','B04','B05','B06','B07'] #2,3,4,5,6,7]
else:
    # 'B8A' (not 'B08A') matches the band name used everywhere else in this
    # notebook and by the filename regex in the processing loop.
    sensor_bands = ['B02','B03','B04','B8A','B11','B12']

pattern = re.compile(r'\.(' + '|'.join(sensor_bands) + r')\.tif$')
img_files = [f for f in files if re.search(pattern, f)]
if not img_files:
    raise FileNotFoundError(f"No band files found in reference granule {img_path}")
geotiff_path = os.path.join(img_path, img_files[0])

# Both rasters are opened in one `with` so the DEM handle is closed as well.
with rasterio.open(geotiff_path) as dst, rasterio.open(dem_path) as dem:
    hls = dst.read()
    reprojected_dem = np.zeros((hls.shape[1], hls.shape[2]), dtype=hls.dtype)
    # Always resample onto the HLS grid. The original only did so when the
    # CRSs differed, leaving `reprojected_dem` undefined otherwise.
    reproject(
        source=dem.read(1),
        destination=reprojected_dem,
        src_transform=dem.transform,
        src_crs=dem.crs,
        dst_transform=dst.transform,
        dst_crs=dst.crs,
        resampling=Resampling.bilinear)
hls_flat = np.squeeze(hls, axis=0)

if reprojected_dem.any():
    # DEM > 0 is treated as land, grown with a 5x5 structuring element.
    struct = cp.ones((5,5))
    reprojected_dem_gpu = cp.asarray(reprojected_dem)
    land_mask = binary_dilation(reprojected_dem_gpu > 0, structure = struct)
    land_mask = cp.asnumpy(land_mask)

    if show_image:
        plt.figure(figsize=(6, 6))
        plt.imshow(land_mask, cmap='gray')
        plt.show()
else:
    print("Something failed, you better go check...")
    sys.exit()

### ==================== Load RF Classifier ================== ###
# NOTE(review): pickle.load can execute arbitrary code; only load model files
# from a trusted source.
with open(rf_path, 'rb') as f:
    cu_rf = pickle.load(f)

def calculate_local_variance(image_gpu, window_size):
    """Per-pixel variance over a square window, as E[x^2] - (E[x])^2.

    Both means are box-filter convolutions in float32 using zero padding
    (``mode='constant', cval=0.0``), so values outside the image count as 0.

    Args:
        image_gpu: 2-D array on the GPU.
        window_size: Side length of the square averaging window.

    Returns:
        float32 array with the same shape as ``image_gpu``.
    """
    box_kernel = cp.ones((window_size, window_size), dtype=cp.float32) / (window_size * window_size)
    pixels = image_gpu.astype(cp.float32)

    mean_of_squares = convolve(cp.square(pixels), box_kernel, mode='constant', cval=0.0)
    mean = convolve(pixels, box_kernel, mode='constant', cval=0.0)
    return mean_of_squares - cp.square(mean)

### ====================  Begin processing each file  ====================  ####
# For every granule directory under `path`:
#   1. build the cloud mask from Fmask and skip granules that are too cloudy,
#   2. stack the six bands with cloud/land pixels zeroed and classify each
#      pixel with the random forest,
#   3. derive kelp and ocean masks from the classification,
#   4. draw random ocean endmembers and run MESMA on the kelp pixels,
#   5. write bands 1-4, the classification and the kelp percentage to a GeoTIFF.
iterations = 0
print(f"Granules to process: {len(granules)}")
print("Img process beginning...")

# Kelp and ocean filtering coefficients (loop-invariant, so defined once)
ocean_dilation = cp.ones((100, 100))  # Struct for dilation (increase to enlarge non-ocean mask) larger --> takes longer
kelp_neighborhood = 5
min_kelp_count = 4
kelp_dilation_size = 15
num_EM = 30
variance_window_size = 15
variance_threshold = .95
structuring_element = cp.ones((kelp_dilation_size, kelp_dilation_size))
kelp_EM = cp.asarray([459, 556, 437, 1227])  # kelp endmember, one value per MESMA band
n_mesma_bands = 4                            # first four bands are masked, unmixed and saved
preview_window = (slice(2800, 3100), slice(800, 1400))  # rows, cols shown in the detail figure
output_dir = os.path.join(save_to_path, tile)

for item in granules:
    if os.path.isfile(os.path.join(output_dir, f'{item}_processed.tif')):
        print(f'{item} already processed. Skipping.')
        continue
    # '>=' so that at most num_iterations granules are processed ('>' allowed one extra).
    if iterations >= num_iterations:
        break
    start_time = time.time()

    ##==========Select Granule and Get File Names==========##
    img_path = os.path.join(path, item)
    if not os.path.isdir(img_path):
        continue
    files = os.listdir(img_path)
    file_data = item.split('.')
    if len(file_data) < 2:
        print(f"Invalid granule: {item}")
        continue
    # Check sensor type and define bands for L8 and S2
    sensor = file_data[1]
    if sensor == 'L30':
        sensor_bands = ['B02','B03','B04','B05','B06','B07'] #2,3,4,5,6,7]
    else:
        sensor_bands = ['B02','B03','B04','B8A','B11','B12']

    # The pattern must be built from THIS granule's bands before it is used.
    # The original filtered with the pattern left over from the previous
    # granule, so a sensor change selected the wrong files.
    pattern = re.compile(r'\.(' + '|'.join(sensor_bands) + r')\.tif$')
    band_order = {band: position for position, band in enumerate(sensor_bands)}
    img_files = sorted(
        (f for f in files if pattern.search(f)),
        key=lambda name: band_order[pattern.search(name).group(1)],
    )

    f_mask = [f for f in files if re.search(r'Fmask\.tif$', f)]
    if not f_mask:
        print(f"Invalid granule: {item}")
        continue
    if len(img_files) != len(sensor_bands):
        print(f"incomplete file download: {item}")
        continue

    metadata = {}
    metadata_file = [f for f in files if re.search(r'metadata\.csv$', f)]
    if metadata_file:
        with open(os.path.join(img_path, metadata_file[0]), mode='r') as metadata_handle:
            csv_reader = csv.reader(metadata_handle)
            keys = next(csv_reader)
            values = next(csv_reader)
        metadata = dict(zip(keys, values)) #Load metadata into dictionary

    ##==========Fmask Cloud mask==========##
    try:
        with rasterio.open(os.path.join(img_path, f_mask[0])) as fmask:
            qa_band = fmask.read(1)
    except RasterioIOError as e:
        # Report the Fmask file itself; the original printed a stale `file`.
        print(f"Error reading file {f_mask[0]} in granule {item}: {e}")
        continue  # Skip to the next granule if a file cannot be read
    qa_cloud_mask = ((qa_band >> 1) & 1) == 1              # Bit 1 for cloud
    qa_adjacent_to_cloud_mask = ((qa_band >> 2) & 1) == 1  # Bit 2 for cloud adjacent
    qa_cloud_shadow = ((qa_band >> 3) & 1) == 1
    qa_ice = ((qa_band >> 4) & 1) == 1
    qa_aerosol = (((qa_band >> 6) & 1) == 1) & (((qa_band >> 7) & 1) == 1)
    # may not be necessary to mask out the cloud-adjacent pixels
    cloud_mask = qa_cloud_mask | qa_adjacent_to_cloud_mask | qa_cloud_shadow | qa_ice | qa_aerosol

    ##========== Determine percentage of ocean covered by clouds ==========##
    cloud_land_mask = cloud_mask | land_mask
    cloud_but_not_land_mask = cloud_mask & ~land_mask
    num_pixels_cloud_not_land = np.count_nonzero(cloud_but_not_land_mask)
    num_pixels_not_land = np.count_nonzero(~land_mask)
    if num_pixels_not_land == 0:
        print(f'Granule {item} has no non-land pixels. Skipping.')
        continue
    percent_cloud_covered = num_pixels_cloud_not_land/num_pixels_not_land
    if percent_cloud_covered > cloud_cover_threshold:
        print(f'Granule {item} cloud cover greater than threshold. Skipping: {percent_cloud_covered}')
        continue
    print(f'{item} Percent cloud covered: {percent_cloud_covered}')

    ##==========Create stacked np array, Apply landmask==========##
    img_bands = []
    out_crs = None
    out_transform = None
    try:
        for file in img_files:
            with rasterio.open(os.path.join(img_path, file)) as src:
                img_bands.append(np.where(cloud_land_mask, 0, src.read(1)))  # Create image with the various bands
                # Capture georeferencing while the dataset is still open.
                out_crs, out_transform = src.crs, src.transform
    except RasterioIOError as e:
        print(f"Error reading file {file} in granule {item}: {e}")
        continue  # Skip to the next granule if a file cannot be read
    img = np.stack(img_bands, axis=0)
    n_bands, height, width = img.shape
    img_2D = img.reshape(n_bands, -1).T #classifier takes 2D array of band values for each pixel

    ##========== Normalize multi-spectral data ==========##
    img_sum = img_2D.sum(axis=1)
    epsilon = 1e-10
    # `out` is required with `where`; otherwise masked-out entries are uninitialised.
    img_2D_nor = np.divide(
        img_2D,
        img_sum[:, None] + epsilon,
        out=np.zeros(img_2D.shape, dtype=np.float64),
        where=(img_sum[:, None] != 0),
    )
    img_2D_nor = (img_2D_nor * 255).astype(np.uint8)

    print(img_2D_nor.shape)
    img_data = cudf.DataFrame(img_2D_nor).astype(np.float32)

    ##========== Classify ==========##
    kelp_pred = cu_rf.predict(img_data)
    # Pixels were flattened row-major from (height, width); reshape the same way.
    classified_img = kelp_pred.values_host.reshape(height, width)

    ##========== Kelp and ocean masks ==========##
    # The comparisons below treat class 0 as kelp, 2 as cloud and 3 as land.
    classified_img_gpu = cp.asarray(classified_img)

    time_val = time.time()
    land_dilated_gpu = binary_dilation(classified_img_gpu == 3, structure=ocean_dilation)
    clouds_dilated_gpu = classified_img_gpu == 2
    print(f'land finished: {time.time()-time_val}')
    ocean_dilated_gpu = land_dilated_gpu | clouds_dilated_gpu   # True where NOT usable as ocean

    is_kelp_gpu = classified_img_gpu == 0
    kernel = cp.ones((kelp_neighborhood, kelp_neighborhood), dtype=cp.int32)
    time_val = time.time()
    kelp_count_gpu = convolve(is_kelp_gpu.astype(cp.int32), kernel, mode='constant', cval=0.0)
    print(f'kelp moving average finished: {time.time()-time_val}')

    # Keep kelp pixels with more than min_kelp_count kelp pixels in their window, then grow the region.
    kelp_dilated_gpu = is_kelp_gpu & (kelp_count_gpu > min_kelp_count)
    time_val = time.time()
    kelp_dilated_gpu = binary_dilation(kelp_dilated_gpu, structure=structuring_element)  # I may not want to do this. we'll see
    print(f'kelp dilation finished: {time.time()-time_val}')

    kelp_mask = []
    ocean_mask = []
    time_val = time.time()
    for i in range(n_mesma_bands):
        band_data_gpu = cp.asarray(img[i])

        kmask_gpu = cp.where(kelp_dilated_gpu, band_data_gpu, cp.nan)

        # Mask pixels with high variance
        local_variance_gpu = calculate_local_variance(band_data_gpu, variance_window_size)
        max_local_variance = cp.percentile(local_variance_gpu, 100 * variance_threshold)  # threshold variance
        rejected_gpu = ocean_dilated_gpu | (local_variance_gpu > max_local_variance)
        final_omask_gpu = cp.where(rejected_gpu, cp.nan, band_data_gpu)

        # The GPU arrays are appended directly; the original appended the
        # undefined names `kmask` / `omask`.
        kelp_mask.append(kmask_gpu)
        ocean_mask.append(final_omask_gpu)
    print(f'kBand masking and variance masking complete: {time.time()-time_val}')
    kelp_mask = cp.stack(kelp_mask, axis=0)
    ocean_mask = cp.stack(ocean_mask, axis=0)

    if show_image:
        # NOTE(review): channel order is bands [2, 0, 1]; confirm that green and
        # blue are not swapped (true colour would be [2, 1, 0]).
        rgb_nor = cp.stack([ocean_mask[2]/600, ocean_mask[0]/600, ocean_mask[1]/600], axis=-1)
        # Masking and plotting happen on the CPU: cupy has no `ma` module and
        # matplotlib cannot consume GPU arrays.
        rgb_cpu = cp.asnumpy(rgb_nor)
        rgb_cpu = np.ma.masked_where(np.isnan(rgb_cpu), rgb_cpu)
        del rgb_nor

        plt.figure(figsize=(30, 30), dpi=200)
        plt.imshow(cp.asnumpy(kelp_mask[1]), alpha=1)
        plt.imshow(rgb_cpu, alpha=1)
        plt.colorbar()
        plt.show()
        plt.close('all')

    ##========== Select ocean endmembers ==========##
    ocean_data = ocean_mask.reshape(ocean_mask.shape[0], -1)
    kelp_data = kelp_mask.reshape(kelp_mask.shape[0], -1)

    nan_columns = cp.isnan(ocean_data).any(axis=0)  # Remove columns with nan
    filtered_ocean = ocean_data[:, ~nan_columns]
    n_valid_ocean = int(filtered_ocean.shape[1])
    if n_valid_ocean == 0:
        # random.randint(0, -1) would raise; nothing to unmix against.
        print("Too few valid ocean end-members")
        print(f"Skipping img {item}...  :(")
        continue
    ocean_EM_stack = [
        filtered_ocean[:, random.randint(0, n_valid_ocean - 1)]
        for _ in range(num_EM)
    ]
    ocean_EM = cp.stack(ocean_EM_stack, axis=1)

    average_val = cp.nanmean(filtered_ocean, axis=1)
    average_endmember = cp.nanmean(ocean_EM, axis=1)
    print(f"average EM Val: {average_endmember}")
    print(f"average    Val: {average_val}")

    ##========== MESMA ==========##
    # For each ocean endmember solve pixel = f_ocean*ocean + f_kelp*kelp by
    # least squares and keep, per pixel, the kelp fraction of the model with
    # the lowest RMSE. A running best replaces three (H, W, num_EM) cubes.
    print("Running MESMA")
    best_rmse = cp.full(kelp_data.shape[1], cp.inf)
    best_frac = cp.full(kelp_data.shape[1], cp.nan)
    for k in range(num_EM):
        B = cp.column_stack((ocean_EM[:, k], kelp_EM))
        U, S, Vt = cp.linalg.svd(B, full_matrices=False)
        em_inv = (Vt.T / S) @ U.T          # pseudo-inverse of B
        F = em_inv @ kelp_data             # row 0: ocean fraction, row 1: kelp fraction
        resids = (kelp_data - B @ F) / 10000
        rmse_k = cp.sqrt(cp.mean(resids**2, axis=0))
        improved = rmse_k < best_rmse      # NaN never wins; ties keep the earlier endmember
        best_rmse = cp.where(improved, rmse_k, best_rmse)
        best_frac = cp.where(improved, F[1, :], best_frac)

    minVals = cp.where(cp.isinf(best_rmse), cp.nan, best_rmse).reshape(height, width)
    Mes2 = best_frac.reshape(height, width)
    Mes2 = -0.229 * Mes2**2 + 1.449 * Mes2 - 0.018 #Landsat mesma corrections
    Mes2 = cp.clip(Mes2, 0, None)  # Ensure no negative values
    Mes2 = cp.where(cp.isnan(Mes2), 0, Mes2)  # non-kelp pixels -> 0 before the integer cast
    Mes2 = cp.round(Mes2 * 100).astype(cp.int16)
    Mes_array = cp.asnumpy(Mes2)

    if show_image:
        win_rows, win_cols = preview_window
        kelp_img = cp.asnumpy(kelp_mask).astype(np.float32)
        Mes_array_vis = np.where(Mes_array == 0, np.nan, Mes_array)
        plt.figure(figsize=(20, 20), dpi=200)
        # `rgb_cpu` is (rows, cols, 3); the original indexed a deleted array
        # as if it were band-first.
        plt.imshow(rgb_cpu[win_rows, win_cols], alpha=1)
        plt.imshow(kelp_img[1, win_rows, win_cols], cmap='Greys', alpha=1)
        plt.imshow(Mes_array_vis[win_rows, win_cols], alpha=1)
        plt.colorbar()
        plt.show()
        plt.close('all')

    ##========== Save masked file ==========##
    if save_final_data:
        data_type = rasterio.int16
        profile = {
            'driver': 'GTiff',
            'width': width,
            'height': height,
            'count': 6,  # bands 1-4: first four image bands, 5: classification, 6: kelp percent
            'dtype': data_type,
            'crs': out_crs,
            'transform': out_transform,
            # NOTE(review): nodata=0 coincides with class 0 (kelp) in band 5 - confirm intended.
            'nodata': 0
        }
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, f'{item}_processed.tif')

        with rasterio.open(out_path, 'w', **profile) as dst:
            for band_idx in range(n_mesma_bands):
                dst.write(img[band_idx].astype(data_type), band_idx + 1)
            dst.write(classified_img.astype(data_type), 5)
            dst.write(Mes_array.astype(data_type), 6)

    # Counted for every processed granule so the num_iterations cap also
    # applies when save_final_data is False.
    iterations = iterations + 1
    print(f"Iteration: {iterations}")
    gc.collect()
    print(f"{item} Processing complete. Duration: {time.time()-start_time}")
