In [6]:
import os
import re
import kelp_tools as kt
import time 
import rasterio
import cupy as cp
import cudf
import numpy as np
from rasterio.errors import RasterioIOError
import matplotlib.pyplot as plt
from cuml.ensemble import RandomForestClassifier as cuRF
from cuml.model_selection import train_test_split
from scipy.stats import randint
import pickle
import csv
from cupyx.scipy.ndimage import binary_dilation, convolve
from IPython.display import clear_output
import filing_tools as ft

In [7]:
# ----------------------------------------------------------------------------
# Run configuration. Everything a user is expected to tune lives in this cell.
# ----------------------------------------------------------------------------

# --- Behaviour switches ---
reclassify = True        # Re-process granules that already have a *_processed.tif output
show_image = False       # Plot diagnostic figures for every granule (slow)
sleep = True             # Pause periodically so the GPU can cool down
save_final_data = True   # Write the per-granule GeoTIFF
save_EMs = True          # Periodically back up the selected ocean endmembers

# --- What to process ---
version = 4                   # Suffix used in the output folder and endmember-log names
tile = '11SKU'                # Tile identifier; also the sub-folder name under the site folder
location = 'Isla_vista_kelp'  # Site folder name under the imagery root
cloud_cover_threshold = .7    # Granules whose reported cloud fraction is >= this are skipped
num_iterations = 2000         # Cap on the number of granules processed in one run
backup_frequency = 5          # Back up the endmember log every N processed granules

# --- Paths ---
# Input and output folders are both derived from `location`, so pointing the
# notebook at another site cannot leave outputs going to the previous site's folder.
imagery_root = r'/mnt/h/HLS_data/imagery'
path = os.path.join(imagery_root, location, tile)
save_to_path = rf'{imagery_root}/{location}/processed_v{version}'
endmember_path = rf'/mnt/c/Users/attic/HLS_Kelp/python_objects/EM_Dict_v{version}.pkl'
dem_path = r'/mnt/c/Users/attic/HLS_Kelp/imagery/Socal_DEM.tiff'

# --- Classifier ---
rf_model = 'cu_rf9'  # File name of the pickled random-forest model inside the folder below
rf_path = os.path.join(r'/mnt/c/users/attic/hls_kelp/random_forest',rf_model)

In [8]:
# List everything in the tile folder and let the project helper group the
# entries by date (max_days=2).
all_files = os.listdir(path)
grouped_files = ft.group_by_date(all_files, max_days=2)

# Flatten the groups back into one list, keeping first-seen order and
# dropping any entry that shows up in more than one group.
file_list = []
for group in grouped_files:
    # The entries of a group are stored at index 1.
    for name in group[1]:
        if name in file_list:
            continue
        file_list.append(name)
print(len(file_list))

329


In [9]:
# Optional debugging aid: uncomment to print current GPU memory usage.
# def print_memory_usage(message=""):
#     mempool = cp.get_default_memory_pool()
#     print(f"{message} - Used memory: {mempool.used_bytes()} bytes")

# --- Locate a reference GeoTIFF --------------------------------------------
# The land mask is built once from the first granule directory in file_list
# and then reused for every granule in the processing loop.
hls_path = None
for item in file_list:
    candidate = os.path.join(path, item)
    if os.path.isdir(candidate):
        hls_path = candidate
        granule = item
        break
if hls_path is None:
    # Without this guard the cell would fail later with a confusing NameError.
    raise FileNotFoundError(f'No granule directories found under {path}')

# Match on the literal ".tif" suffix; sorting makes the chosen reference file
# deterministic, since os.listdir returns entries in arbitrary order.
img_files = sorted(f for f in os.listdir(hls_path) if f.endswith('.tif'))
if not img_files:
    raise FileNotFoundError(f'No .tif files found in {hls_path}')
geotiff_path = os.path.join(hls_path, img_files[0])

land_mask = kt.create_land_mask(geotiff_path, dem_path)
# Only temporaries are removed. Configuration such as dem_path is kept so
# this cell can be re-run without re-running the configuration cell.
del img_files

# --- Load the trained classifier -------------------------------------------
# NOTE: pickle.load can execute arbitrary code contained in the file. Only
# load model files that come from a trusted source.
with open(rf_path, 'rb') as f:
    cu_rf = pickle.load(f)

# Number of granules processed so far; advanced by the processing loop.
iterations = 0

# Create the output folder (and any missing parents); no error if it exists.
os.makedirs(save_to_path, exist_ok=True)


# ----------------------------------------------------------------------------
# Main processing loop. For every granule directory in file_list:
#   1. build the QA (cloud) masks and skip granules that are too cloudy,
#   2. read the six band files onto the GPU with masked pixels zeroed,
#   3. classify every pixel with the random forest,
#   4. build the kelp / ocean masks, select ocean endmembers and run MESMA,
#   5. optionally write a 7-band GeoTIFF and periodically back up endmembers.
# ----------------------------------------------------------------------------

# Positional arguments forwarded to kt.create_mesma_mask (loop-invariant).
mesma_mask_params = [
    100,      # ocean_dilation_size
    6,        # kelp_neighborhood
    3,        # min_kelp_count
    15,       # kelp_dilation_size
    15,       # variance_window_size
    0.95      # variance_threshold
]
cooldown_every = 20      # pause after every N processed granules ...
cooldown_seconds = 240   # ... for this many seconds (only when `sleep` is set)

tile_out_dir = os.path.join(save_to_path, tile)
if save_final_data:
    os.makedirs(tile_out_dir, exist_ok=True)

# Endmembers collected since the last backup, keyed by granule name.
endmember_dict = {}
for item in file_list:
    start_time = time.time()
    #print_memory_usage("Before processing granule")

    # Process at most num_iterations granules in one run.
    if iterations >= num_iterations:
        break

    if not reclassify and os.path.isfile(os.path.join(tile_out_dir, f'{item}_processed.tif')):
        print(f'{item} already processed. Skipping.')
        continue

    print(f'Starting {item}')

    # Only directories are granules; skip stray files.
    img_path = os.path.join(path, item)
    if not os.path.isdir(img_path):
        continue

    try:
        sorted_files = kt.filter_and_sort_files(img_path, item)
    except Exception as e:
        # `Exception` rather than a bare except, so Ctrl-C still interrupts the run.
        print(f"{item} failed to sort filenames, skipping: {e}")
        continue

    if len(sorted_files) != 6:
        print(f'Incomplete file download: {item}')
        continue

    metadata = kt.get_metadata(img_path)

    cloud_land_mask, cloud_but_not_land_mask, percent_cloud_covered = kt.create_qa_mask(land_mask, img_path)
    if percent_cloud_covered >= cloud_cover_threshold:
        print(f'{item} Percent cloud covered: {percent_cloud_covered}')
        del cloud_land_mask, cloud_but_not_land_mask
        continue

    # Read each band onto the GPU, zeroing pixels flagged by cloud_land_mask.
    img_bands = []
    try:
        for file in sorted_files:
            with rasterio.open(os.path.join(img_path, file)) as src:
                img_bands.append(cp.where(cloud_land_mask, 0, cp.asarray(src.read(1))))
                # Grab the georeferencing while the dataset is still open
                # instead of reading it from a closed handle later on.
                src_crs, src_transform = src.crs, src.transform
    except RasterioIOError as e:
        print(f"Error reading file {file} in granule {item}: {e}")
        del img_bands, cloud_land_mask, cloud_but_not_land_mask
        continue
    del cloud_land_mask

    img = cp.stack(img_bands, axis=0)
    del img_bands
    n_bands, height, width = img.shape

    # Classify every pixel with the random forest.
    img_2D_normalized = kt.normalize_img(img, flatten=True)
    img_data = cudf.DataFrame(img_2D_normalized).astype(np.float32)
    kelp_pred = cu_rf.predict(img_data)
    del img_data, img_2D_normalized

    # Rows first, then columns, so the result lines up with the (height, width)
    # masks; the axes were previously swapped, which only worked for square tiles.
    # NOTE(review): assumes kt.normalize_img flattens pixels in row-major order - confirm.
    classified_img = cp.asarray(kelp_pred.values_host.reshape(height, width))
    del kelp_pred
    # Cloud pixels that are not land are labelled 2.
    classified_img = cp.where(cloud_but_not_land_mask, 2, classified_img)

    if show_image:
        plt.figure(figsize=(25, 25))
        plt.subplot(2, 1, 1)
        plt.imshow(cp.asnumpy(classified_img[2700:3400, 600:2000]))
        plt.colorbar()
        plt.title(file)
        # NOTE(review): the third channel uses band index 3; if the bands are
        # ordered Blue, Green, Red, NIR this is a false-colour composite rather
        # than true RGB - confirm the intent.
        rgb_nor = cp.asnumpy(cp.stack([img[2], img[1], img[3]], axis=-1))
        rgb_cropped = rgb_nor[2700:3400, 600:2000]
        plt.subplot(2, 1, 2)
        plt.imshow(rgb_cropped)
        plt.title("RGB Cropped Image")
        plt.show()
        del rgb_cropped, rgb_nor

    kelp_mask, ocean_mask = kt.create_mesma_mask(classified_img, img, land_mask, cloud_but_not_land_mask, *mesma_mask_params)
    # From here on the band stack and the classification are only needed on the host.
    img = cp.asnumpy(img)
    classified_img_cpu = cp.asnumpy(classified_img)
    del cloud_but_not_land_mask, classified_img

    if show_image:
        mesma_mask_vis = cp.stack([ocean_mask[2] / 600, ocean_mask[0] / 600, ocean_mask[1] / 600], axis=-1)
        image = kelp_mask[1]
        plt.figure(figsize=(30, 30), dpi=200)
        plt.imshow(cp.asnumpy(image), alpha=1)
        plt.imshow(cp.asnumpy(mesma_mask_vis), alpha=1)
        plt.colorbar()
        plt.show()
        del mesma_mask_vis, image

    ocean_EM = kt.select_ocean_endmembers(ocean_mask, print_average=False)
    if ocean_EM is None:
        # Nothing to unmix against: report it and release this granule's memory.
        print(f'{item}: no ocean endmembers selected, skipping')
        del img, classified_img_cpu, kelp_mask, ocean_mask
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
        continue

    Mes2, minVals = kt.run_mesma(kelp_mask, ocean_EM)
    endmember_dict[item] = cp.asnumpy(ocean_EM)
    del ocean_EM, ocean_mask
    min_vals = cp.asnumpy(minVals).T
    del minVals
    Mes_array = cp.asnumpy(Mes2).T
    del Mes2

    if show_image:
        kelp_img = cp.asnumpy(kelp_mask).astype(np.float32)
        Mes_array_vis = np.where(Mes_array == 0, np.nan, Mes_array)
        plt.figure(figsize=(20, 20), dpi=200)
        plt.imshow(kelp_img[1, 2800:3100, 800:1400], cmap='Greys', alpha=1)
        plt.imshow(Mes_array_vis[2800:3100, 800:1400], alpha=1)
        plt.colorbar()
        plt.show()
        del kelp_img, Mes_array_vis
    # Always release the GPU-side kelp mask, not only when plotting.
    del kelp_mask

    if save_final_data:
        data_type = rasterio.int16
        # Output band order: four image bands, classification, then the two MESMA outputs.
        output_bands = [img[0], img[1], img[2], img[3], classified_img_cpu, Mes_array, min_vals]
        profile = {
            'driver': 'GTiff',
            'width': width,
            'height': height,
            'count': len(output_bands),
            'dtype': data_type,
            'crs': src_crs,
            'transform': src_transform,
            'nodata': 0,  # 0 marks no-data
        }
        out_path = os.path.join(tile_out_dir, f'{item}_processed.tif')

        with rasterio.open(out_path, 'w', **profile) as dst:
            for band_index, band in enumerate(output_bands, start=1):
                dst.write(band.astype(data_type), band_index)
            # Dataset tags are written here; they are not a creation option,
            # so they do not belong in `profile`.
            dst.update_tags(TIMESTAMP=metadata['SENSING_TIME'], CLOUD_COVERAGE=percent_cloud_covered, RF_MODEL=rf_model, VIS_LINK=metadata['data_vis_url'])

    # Count every processed granule, even when nothing is written to disk, so
    # the backup / cool-down cadence below does not fire on every pass.
    iterations += 1
    print(f"File complete: {item} | Iteration: {iterations} | Time: {time.time() - start_time}")

    del img, classified_img_cpu, Mes_array, min_vals
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()

    #print_memory_usage("After processing granule")

    if save_EMs and iterations % backup_frequency == 0:
        # Append the endmembers gathered since the last backup to the on-disk log.
        # NOTE: pickle.load can execute arbitrary code; the log must be a trusted file.
        if os.path.isfile(endmember_path):
            with open(endmember_path, 'rb') as f:
                endmember_log = pickle.load(f)
        else:
            endmember_log = []
        endmember_log.append(endmember_dict)

        # Write to a temporary file first so an interrupted write cannot
        # corrupt the existing log.
        tmp_path = endmember_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(endmember_log, f)
        os.replace(tmp_path, endmember_path)
        endmember_dict = {}
        print(f'Endmembers Logged on iteration: {iterations}')

    if sleep and iterations % cooldown_every == 0:
        print("Cooling down GPU...")
        time.sleep(cooldown_seconds)
        clear_output(wait=True)


Starting HLS.L30.T11SKU.2024099T183405.v2.0
(3660, 3660, 30)
Running MESMA
File complete: HLS.L30.T11SKU.2024099T183405.v2.0 | Iteration: 241 | Time: 55.932284355163574
Starting HLS.S30.T11SKU.2024101T183921.v2.0
(3660, 3660, 30)
Running MESMA
File complete: HLS.S30.T11SKU.2024101T183921.v2.0 | Iteration: 242 | Time: 42.137094020843506
Starting HLS.S30.T11SKU.2024106T183919.v2.0
(3660, 3660, 30)
Running MESMA
File complete: HLS.S30.T11SKU.2024106T183919.v2.0 | Iteration: 243 | Time: 66.17373871803284
Starting HLS.L30.T11SKU.2024107T183337.v2.0
(3660, 3660, 30)
Running MESMA
File complete: HLS.L30.T11SKU.2024107T183337.v2.0 | Iteration: 244 | Time: 45.90866017341614
Starting HLS.S30.T11SKU.2024121T183921.v2.0
(3660, 3660, 30)
Running MESMA
File complete: HLS.S30.T11SKU.2024121T183921.v2.0 | Iteration: 245 | Time: 76.76202821731567
Endmembers Logged on iteration: 245
Starting HLS.L30.T11SKU.2024123T183326.v2.0
(3660, 3660, 30)
Running MESMA
File complete: HLS.L30.T11SKU.2024123T183326.v2

In [10]:
# Final flush: append the endmembers collected since the last periodic backup
# to the on-disk log.
# NOTE: pickle.load can execute arbitrary code; the log must be a trusted file.
if os.path.isfile(endmember_path):
    with open(endmember_path, 'rb') as f:
        endmember_log = pickle.load(f)
else:
    endmember_log = []

if endmember_dict:
    endmember_log.append(endmember_dict)

    # Write to a temporary file first so an interrupted write cannot corrupt
    # the existing log.
    tmp_path = endmember_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(endmember_log, f)
    os.replace(tmp_path, endmember_path)
else:
    # The processing loop empties endmember_dict after each backup, so there
    # may be nothing left to add; do not append an empty entry to the log.
    print('No new endmembers to log.')