In [None]:
import os
import re
import kelp_tools as kt
import time 
import rasterio
import cupy as cp
import cudf
import numpy as np
from rasterio.errors import RasterioIOError
import matplotlib.pyplot as plt
from cuml.ensemble import RandomForestClassifier as cuRF
from cuml.model_selection import train_test_split
from scipy.stats import randint
import pickle
from cupyx.scipy.ndimage import binary_dilation, convolve

In [None]:

# ---------------------------------------------------------------------------
# Run configuration: every value a user is expected to tune lives in this cell.
# ---------------------------------------------------------------------------

# --- Behaviour switches ---
reclassify = True        # False -> granules that already have a *_processed.tif are skipped
show_image = True        # draw the diagnostic matplotlib figures for each granule
save_final_data = True   # write the per-granule result GeoTIFF

# --- Scene selection and limits ---
location = 'Isla_Vista_Kelp'
tile = '11SKU'
cloud_cover_threshold = .7   # granules whose cloud value from kt.create_qa_mask is >= this are skipped
num_iterations = 1000        # bound on how many granules one run may process

# --- Filesystem locations ---
path = os.path.join(r'/mnt/h/HLS_data/imagery',location,tile)   # directory holding one sub-directory per granule
save_to_path = '/mnt/h/HLS_data/imagery/Isla_vista_kelp/processed'
dem_path = r'/mnt/c/Users/attic/HLS_Kelp/imagery/Socal_DEM.tiff'
rf_path = r'/mnt/c/users/attic/hls_kelp/random_forest/cu_rf8'   # pickled classifier

# Retired settings, kept for reference only:
#classified_path = r'/mnt/c/users/attic/hls_kelp/imagery/rf_classified_cuML'
#save_mask = True
#save_classification = True
#remask = False

#unclassified_path = r'/mnt/c/users/attic/hls_kelp/imagery/rf_prepped_v2'
#unclassified_files = os.listdir(unclassified_path)

In [None]:
# ---------------------------------------------------------------------------
# One-time setup: pick a reference granule, build the land mask from it and the
# DEM, load the pickled classifier, and make sure the output directory exists.
# ---------------------------------------------------------------------------

# os.listdir returns entries in arbitrary order; sort so the reference granule
# and the processing order are the same on every run.
granules = sorted(os.listdir(path))

# Use the first granule directory as the reference image for the land mask.
for item in granules:
    if os.path.isdir(os.path.join(path, item)):
        hls_path = os.path.join(path, item)
        granule = item
        break
else:
    # Without this, hls_path would be undefined and fail later with a NameError.
    raise FileNotFoundError(f'No granule directories found in {path}')

# Any .tif inside the reference granule is handed to kt.create_land_mask.
# endswith() replaces the regex '.tif$', whose unescaped dot matched any character.
img_files = sorted(f for f in os.listdir(hls_path) if f.endswith('.tif'))
if not img_files:
    raise FileNotFoundError(f'No .tif files found in {hls_path}')
geotiff_path = os.path.join(hls_path, img_files[0])

land_mask = kt.create_land_mask(geotiff_path, dem_path)
# dem_path is deliberately NOT deleted: it is defined in the config cell, and
# removing it made this cell fail with NameError when run a second time.
del img_files

# SECURITY NOTE: pickle.load can execute arbitrary code while deserialising.
# Only point rf_path at model files you created or otherwise trust.
with open(rf_path, 'rb') as f:
    cu_rf = pickle.load(f)

# Counter of granules processed so far; the processing loop compares it with num_iterations.
iterations = 0

# makedirs also creates missing parent directories and is a no-op if the path exists.
os.makedirs(save_to_path, exist_ok=True)

# ---------------------------------------------------------------------------
# Main loop: for every granule directory
#   1. build the QA (cloud/land) masks and skip cloudy or incomplete granules,
#   2. read the band files onto the GPU with masked pixels zeroed,
#   3. classify every pixel with the loaded random forest,
#   4. build the MESMA masks, select ocean endmembers and run MESMA,
#   5. optionally plot diagnostics and write a 6-band GeoTIFF.
# ---------------------------------------------------------------------------
for item in granules:

    # Destination file for this granule; also used to detect earlier runs.
    out_file = os.path.join(save_to_path, tile, f'{item}_processed.tif')
    if not reclassify and os.path.isfile(out_file):
        print(f'{item} already processed. Skipping.')
        continue

    # Stop once the cap is reached ('>' let one extra granule through).
    if iterations >= num_iterations:
        break

    print(f'Starting {item}')

    # Only directories are granules; stray files in `path` are ignored.
    img_path = os.path.join(path, item)
    if not os.path.isdir(img_path):
        continue

    # Single timer for the whole granule, reported in the completion message.
    granule_start = time.time()

    try:
        sorted_files = kt.filter_and_sort_files(img_path, item)
    except Exception as e:  # was a bare `except:`, which also swallowed KeyboardInterrupt
        print(f"{item} failed to sort filenames, skipping: {e}")
        continue

    # The pipeline expects exactly six band files per granule.
    if len(sorted_files) != 6:
        print(f'Incomplete file download: {item}')
        continue

    cloud_land_mask, cloud_but_not_land_mask, percent_cloud_covered = kt.create_qa_mask(land_mask, img_path)
    if percent_cloud_covered >= cloud_cover_threshold:
        print(f'{item} Percent cloud covered: {percent_cloud_covered}')
        continue

    # Read each band onto the GPU, zeroing pixels flagged by cloud_land_mask.
    img_bands = []
    out_crs = None
    out_transform = None
    try:
        for file in sorted_files:
            with rasterio.open(os.path.join(img_path, file)) as src:
                img_bands.append(cp.where(cloud_land_mask, 0, cp.asarray(src.read(1))))
                # Capture georeferencing while the dataset is still open
                # instead of reading it from a closed handle later.
                out_crs = src.crs
                out_transform = src.transform
    except RasterioIOError as e:
        print(f"Error reading file {file} in granule {item}: {e}")
        continue

    img = cp.stack(img_bands, axis=0)
    del img_bands
    n_bands, height, width = img.shape
    img_2D_normalized = kt.normalize_img(img, flatten=True)

    # Classify every pixel with the random forest.
    img_data = cudf.DataFrame(img_2D_normalized).astype(np.float32)
    kelp_pred = cu_rf.predict(img_data)
    # Rows first: the result is combined with the (height, width) mask below and
    # written as a (height, width) band. The original reshape(width, height)
    # only worked for square tiles.
    # NOTE(review): assumes kt.normalize_img flattens pixels in row-major order - confirm.
    classified_img = cp.asarray(kelp_pred.values_host.reshape(height, width))
    # Pixels that are cloud but not land get class value 2.
    classified_img = cp.where(cloud_but_not_land_mask, 2, classified_img)

    # Release prediction intermediates as early as possible.
    del img_data, img_2D_normalized, kelp_pred

    if show_image:
        print('Displaying images...')
        plt.figure(figsize=(25, 25))
        plt.subplot(2, 1, 1)
        plt.imshow(cp.asnumpy(classified_img[2700:3400, 600:2000]))
        plt.colorbar()
        plt.title(file)
        # Band indices 2, 1, 3 are kept exactly as in the original composite.
        # NOTE(review): index 3 feeds the blue channel - confirm this is intended.
        rgb_nor = cp.asnumpy(cp.stack([img[2], img[1], img[3]], axis=-1))
        plt.subplot(2, 1, 2)
        plt.imshow(rgb_nor[2700:3400, 600:2000])
        plt.title("RGB Cropped Image")
        plt.show()
        plt.close()  # avoid accumulating open figures across iterations

    mesma_mask_params = [
        100,      # ocean_dilation_size
        5,        # kelp_neighborhood
        4,        # min_kelp_count
        15,       # kelp_dilation_size
        15,       # variance_window_size
        0.95      # variance_threshold
    ]
    kelp_mask, ocean_mask = kt.create_mesma_mask(classified_img, img, land_mask, cloud_but_not_land_mask, *mesma_mask_params)

    # Host copies for plotting and writing; the GPU originals are dropped.
    img_cpu = cp.asnumpy(img)
    classified_img_cpu = cp.asnumpy(classified_img)  # was `cp.asumpy` (AttributeError)
    del img, classified_img

    if show_image:
        print('Displaying MESMA masks...')
        mesma_mask_vis = cp.stack([ocean_mask[2] / 600, ocean_mask[0] / 600, ocean_mask[1] / 600], axis=-1)
        image = kelp_mask[1]
        plt.figure(figsize=(30, 30), dpi=200)
        plt.imshow(cp.asnumpy(image), alpha=1)
        plt.imshow(cp.asnumpy(mesma_mask_vis), alpha=1)
        plt.colorbar()
        plt.show()
        plt.close()
        del mesma_mask_vis, image

    ocean_EM = kt.select_ocean_endmembers(ocean_mask, print_average=False)
    del ocean_mask
    if ocean_EM is None:
        # No ocean endmembers were returned, so MESMA cannot run for this granule.
        continue
    Mes2, minVals = kt.run_mesma(kelp_mask, ocean_EM)
    Mes_array = cp.asnumpy(Mes2).T

    if show_image:
        print('Displaying final results...')
        kelp_img = cp.asnumpy(kelp_mask).astype(np.float32)
        # Zeros become NaN so they render transparent over the RGB base layer.
        Mes_array_vis = np.where(Mes_array == 0, np.nan, Mes_array)
        kelp_vis = np.where(kelp_img == 0, np.nan, kelp_img)
        plt.figure(figsize=(20, 20), dpi=200)
        # rgb_nor is (height, width, 3): crop rows/cols only. The original
        # indexed it as [1, rows, cols], which yields an empty slice.
        plt.imshow(rgb_nor[2800:3100, 800:1400], alpha=1)
        # Draw the NaN-masked kelp layer; the original computed kelp_vis but
        # plotted the unmasked kelp_img, which hid the RGB layer.
        plt.imshow(kelp_vis[1, 2800:3100, 800:1400], cmap='Greys', alpha=1)
        plt.imshow(Mes_array_vis[2800:3100, 800:1400], alpha=1)
        plt.colorbar()
        plt.show()
        plt.close()
        # These names exist only when show_image is true, so delete them here.
        del kelp_img, kelp_vis, Mes_array_vis, rgb_nor

    if save_final_data:
        num_bands = 6
        data_type = rasterio.int16
        profile = {
            'driver': 'GTiff',
            'width': width,
            'height': height,
            # Bands 1-4: first four image bands (the original comment lists them as
            # B02-B05); band 5: classification; band 6: MESMA result.
            'count': num_bands,
            'dtype': data_type,
            'crs': out_crs,
            'transform': out_transform,
            'nodata': 0  # 0 is treated as no-data
        }
        os.makedirs(os.path.join(save_to_path, tile), exist_ok=True)

        # Write the image bands, classification and MESMA output to one GeoTIFF.
        with rasterio.open(out_file, 'w', **profile) as dst:
            for band_idx in range(4):
                dst.write(img_cpu[band_idx].astype(data_type), band_idx + 1)
            dst.write(classified_img_cpu.astype(data_type), 5)
            dst.write(Mes_array.astype(data_type), 6)

    # Count every fully processed granule so the num_iterations cap also works
    # when save_final_data is False.
    iterations += 1
    print(f"File complete: {item} | Iteration: {iterations} | Time: {time.time()-granule_start}")

    # Drop the remaining per-granule arrays (all guaranteed to exist here) and
    # hand cached GPU / pinned memory back to the device.
    del img_cpu, classified_img_cpu, kelp_mask, ocean_EM, Mes2, minVals, Mes_array
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()

  
  