# Cloud Restoration

## Modules

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os

# Absolute path of the "src" directory that sits beside the current working directory
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../src"))

# Put that directory first on sys.path so the project modules below can be imported
sys.path.insert(0, parent_dir)


from process_composites._cloud_restoration_utils import (
    restore_cloud_shadow_xr,
    shadow_adjustment,
)

from scripting._process_composites import load_composites
from scripting._process_dems import load_dems, calculate_aspect_from_dems, calculate_slope_from_dems

import matplotlib.pyplot as plt
import dask.distributed
import numpy as np
import xarray as xr
import rioxarray
import yaml

def read_yaml(file_path: str) -> dict:
    """Parse the YAML file at ``file_path`` and return its content.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document
        (annotated as ``dict``).

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, "r", encoding="utf-8") as yaml_file:
        return yaml.safe_load(yaml_file)
    
def fix_paths_for_nb(input_dict, old_substring = "/home/hrlcuser/media", new_substring = "/media/datapart/lucazanolo"):
    """Return a new dict in which string values have ``old_substring`` swapped.

    Every occurrence of ``old_substring`` inside a top-level string value is
    replaced by ``new_substring``. Values that are not strings (including
    containers holding strings) are copied over untouched, and the input
    dictionary itself is not modified.

    Args:
        input_dict: Mapping whose string values may contain ``old_substring``.
        old_substring: Text to look for in each string value.
        new_substring: Text to put in its place.

    Returns:
        A new dict with the same keys, in the same order.
    """
    remapped = {}
    for key, value in input_dict.items():
        if isinstance(value, str):
            value = value.replace(old_substring, new_substring)
        remapped[key] = value
    return remapped


## Parameters

In [None]:
# Load the notebook configuration and remap the path prefixes it contains.
parameters = fix_paths_for_nb(read_yaml("/home/lucazanolo/luca-zanolo/scripts/config_files/3.cloud_restoration.yaml"))

tile_id = parameters["tile_id"]

# Name of the folder holding the input composites. normpath() drops a trailing
# slash first, so the folder name is found whether or not the configured path
# ends with "/" (split('/')[-2] returned the parent folder when it did not).
composites_dir_name = os.path.basename(os.path.normpath(parameters["composites_path"]))

# Output folders: one for the restored composites, one for the restored and
# shadow-adjusted composites.
composites_restored_path = f"{parameters['output_path']}/{composites_dir_name}_restored"
composites_restored_corr_path = f"{parameters['output_path']}/{composites_dir_name}_restored_adj"

os.makedirs(composites_restored_path, exist_ok=True)
os.makedirs(composites_restored_corr_path, exist_ok=True)

## Load data

In [None]:
with dask.distributed.Client(
    processes=False,
    threads_per_worker=(os.cpu_count() or 2),
) as client:

    print(f"Cloud Restoration for tile {tile_id}")
    print(f"Dask dashboard: {client.dashboard_link}")

    # Composites selected by the configured path, year and tile.
    print("Loading composites.")
    composites = load_composites(parameters['composites_path'], parameters["composites_year"], tile_id)
    if parameters["verbose"]:
        print(f"Composites raw:\n{composites}\n\n")

    # DEMs for the same tile, plus the slope and aspect derived from them.
    dems = load_dems(parameters['dems_path'], parameters["dems_year"], tile_id)
    slope = calculate_slope_from_dems(dems.band_data)
    aspect = calculate_aspect_from_dems(dems.band_data)

    # Attach the terrain layers to the composites dataset as extra variables.
    composites = composites.assign(
        dems=dems.band_data,
        slopes=slope,
        aspects=aspect,
    )

    if parameters["verbose"]:
        for summary in (
            f"Dems:\n{dems}\n\n",
            f"Aspect:\n{aspect}\n\n",
            f"Slope:\n{slope}\n\n",
            f"Composites Dataset:\n{composites}\n\n",
        ):
            print(summary)
    
    

## Restore composites

In [None]:
with dask.distributed.Client(
    processes=False,
    threads_per_worker=(os.cpu_count() or 2),
) as client:

    # Run the restoration routine on each "band" group of the composites and
    # store the result next to the original data.
    print("Restoring cloud/shadow pixels ...")
    composites["band_data_restored"] = composites.band_data.groupby("band").map(restore_cloud_shadow_xr)

    if parameters["verbose"]:
        print(f"Composites Restored:\n{composites.band_data_restored}\n\n")
        print(f"Composites with band data restored:\n{composites}\n\n")

    # Slope values of the processed tile, kept as a plain array for the
    # shadow-adjustment step.
    slopes = composites.slopes.sel(tile=tile_id).values

    # Build the output path of the restored and of the restored+adjusted file
    # for every time step. The source file name is split on "_" and its first
    # three pieces are reused, with a suffix appended to the first one.
    cr_paths = []
    crs_paths = []
    for time_idx in range(composites.sizes['time']):
        name_parts = str(composites.file_name.isel(time=time_idx).values).split('_')
        prefix, middle, suffix = name_parts[0], name_parts[1], name_parts[2]
        cr_paths.append(
            os.path.join(composites_restored_path, f"{prefix}Restored_{middle}_{suffix}")
        )
        crs_paths.append(
            os.path.join(composites_restored_corr_path, f"{prefix}RestoredAdj_{middle}_{suffix}")
        )

    # Write every restored composite that is not on disk yet.
    for time_idx, comp_rest_path in enumerate(cr_paths):

        if os.path.exists(comp_rest_path):
            continue

        composite_rest = composites.band_data_restored.isel(time=time_idx)

        print(f"Saving Restored Image at {comp_rest_path} ...")
        if parameters["verbose"]:
            print(f"Composite:\n{composite_rest}\n")

        composite_rest.rio.to_raster(comp_rest_path, compress="DEFLATE", num_threads="all_cpus")
        print("Saved Restored Image.")
        #break


## Apply shadow adjustment on restored composites

In [None]:
with dask.distributed.Client(
    processes=False,
    threads_per_worker=(os.cpu_count() or 2),
) as client:

    print(f"Dask dashboard: {client.dashboard_link}")

    for time_idx, comp_rest_path, comp_rest_corr_path in zip(range(composites.sizes['time']), cr_paths, crs_paths):

        # Make sure the restored composite is on disk: the adjustment below
        # reads it back from its file.
        if not os.path.exists(comp_rest_path):

            composite_rest = composites.band_data_restored.isel(time=time_idx)

            print(f"Saving Restored Image at {comp_rest_path} ...")
            if parameters["verbose"]:
                print(f"Composite:\n{composite_rest}\n")

            composite_rest.rio.to_raster(comp_rest_path, compress="DEFLATE", num_threads="all_cpus")
            print("Saved Restored Image.")

        # Nothing to do when the adjusted output is already there; checking
        # first avoids opening the restored raster for no reason.
        if os.path.exists(comp_rest_corr_path):
            print(f"Corrected Restored Composite already exists, skipping: {comp_rest_corr_path}")
            continue

        # The context manager closes the raster file once the adjusted image
        # has been written, so handles do not pile up across iterations.
        with rioxarray.open_rasterio(comp_rest_path) as restored_raster:

            composite_rest = restored_raster.squeeze()

            print("Correcting restored image ...")

            composite_rest_corr = shadow_adjustment(composite_rest.values, slopes)

            # Wrap the adjusted values with the dims and coords of the restored
            # composite so they can be written through the rio accessor.
            # NOTE(review): attrs of the restored raster are not copied over -
            # confirm that this is intended.
            composite_rest_corr_da = xr.DataArray(
                composite_rest_corr.squeeze(),
                dims=composite_rest.dims,
                coords=composite_rest.coords,
            )

            print(f"Saving Restored Corrected Image at {comp_rest_corr_path} ...")

            if parameters["verbose"]:
                print(f"Composite restored corrected:\n{composite_rest_corr_da}\n")

            composite_rest_corr_da.rio.to_raster(comp_rest_corr_path, compress="DEFLATE", num_threads="all_cpus")
            print("Saved Restored Corrected Image.")

        #break
        

## Inspect cloud restoration results - Generate reports

In [None]:
def calculate_zero_percentage(array):
    """
    Calculate the percentage of pixels that are zero in every band.

    The first axis of ``array`` is treated as the band axis; a pixel counts
    as zero only when its value is 0 along that whole axis.

    Args:
        array: Array-like with the bands on the first axis. Anything accepted
            by ``np.asarray`` works (NumPy arrays, nested lists, ...).

    Returns:
        The percentage (0-100) of all-zero pixels, or 0.0 when the input
        holds no pixels at all (instead of the NaN a 0/0 division would give).

    Raises:
        IndexError: If the input has no first band to measure.
    """
    data = np.asarray(array)
    total_pixels = data[0].size  # Total pixels per band
    if total_pixels == 0:
        # Empty raster: there is nothing to count.
        return 0.0
    zero_pixels = np.sum(np.all(data == 0, axis=0))  # Pixels that are zero in all bands
    return (zero_pixels / total_pixels) * 100

def generate_composite_masks(composite):
    """
    Build a uint8 mask flagging the pixels that are zero in every band.

    The first axis of ``composite`` is reduced: the output holds 1 where all
    values along that axis equal 0, and 0 wherever at least one value differs
    from 0.
    """
    # A pixel is "all zero" exactly when no band holds a non-zero value.
    has_nonzero_band = np.any(composite != 0, axis=0)
    all_zero = np.logical_not(has_nonzero_band)
    return all_zero.astype(np.uint8)

def create_composite_report(
    composite, composite_rest, composite_rest_corr, dem, slope, aspect, output_path, title = ""
):
    """
    Save a 3x2 figure comparing a composite with its restored versions.

    Each row shows one image on the left (bands moved to the last axis,
    divided by 2000 and clipped to [0, 1] for display) and, on the right, the
    mask of its pixels that are zero in all bands. The share of such pixels is
    written in the left panel's title.

    Args:
        composite: Original composite, bands on the first axis.
        composite_rest: Restored composite, same layout.
        composite_rest_corr: Restored composite after shadow correction, same
            layout.
        dem: Accepted for interface compatibility; not plotted.
        slope: Accepted for interface compatibility; not plotted.
        aspect: Accepted for interface compatibility; not plotted.
        output_path: File the figure is saved to.
        title: Optional figure-level title.
    """
    # One entry per figure row: image, left-panel label, right-panel title.
    rows = [
        (composite, "Composite", "Composite Mask (1 where Zero in All Bands)"),
        (composite_rest, "Composite Restored", "Zero pixels mask (1 where Zero in All Bands)"),
        (
            composite_rest_corr,
            "Composite Restored with Shadow Correction",
            "Zero pixels mask (1 where Zero in All Bands)",
        ),
    ]

    fig, axes = plt.subplots(3, 2, figsize=(12, 18))
    try:
        for (image_ax, mask_ax), (image, label, mask_title) in zip(axes, rows):
            zero_percentage = calculate_zero_percentage(image)
            zero_mask = generate_composite_masks(image)

            # Left panel: the image itself, normalized for RGB display.
            image_ax.imshow(np.clip(image.transpose(1, 2, 0) / 2000, 0, 1))
            image_ax.set_title(f"{label}\n{zero_percentage:.2f}% Zero Pixels", fontsize=18)
            image_ax.axis("off")

            # Right panel: mask of the pixels that are zero in all bands.
            mask_ax.imshow(zero_mask, cmap="gray")
            mask_ax.set_title(mask_title, fontsize=18)
            mask_ax.axis("off")

        # Work on this figure explicitly instead of pyplot's "current figure".
        fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
    finally:
        # Release the figure even when plotting or saving fails, so figures
        # do not accumulate when many reports are generated in a loop.
        plt.close(fig)

    print(f"Report saved as {output_path}")

def _load_rgb(raster_path):
    """Read bands 2, 1, 0 (in that order) of a raster into a NumPy array and close the file."""
    with rioxarray.open_rasterio(raster_path) as raster:
        return raster.isel(band = [2,1,0]).values


# NOTE(review): the tile and every path in this cell are hard-coded rather
# than taken from `parameters` - confirm they match the configuration in use.
tile = "21KUQ"
report_path = "/media/datapart/lucazanolo/S2_processed_composites/reports"
os.makedirs(report_path, exist_ok=True)

# Derive slope and aspect from the DEM, then keep all three as plain arrays.
# The values are materialized inside the `with` block so that the DEM file is
# closed afterwards. This rebinds the notebook-level `slope` and `aspect`.
with rioxarray.open_rasterio("/media/datapart/lucazanolo/data/DEMs/21KUQ_COP-DEM_GLO-30-DGED-v2022-1_UTM_EGM2008_10m_Bilinear.tif") as dem_raster:
    slope = calculate_slope_from_dems(dem_raster).squeeze().values
    aspect = calculate_aspect_from_dems(dem_raster).squeeze().values
    dem = dem_raster.squeeze().values

for timestamp in composites.time.values:
    #if not str(timestamp).startswith('2019-04'): continue

    year, month, _ = str(timestamp).split('-')
    date = f"{year}-{month}"

    print(f"Generating report for {date} ... ")

    curr_report_path = f"{report_path}/composite_{tile}_{date}.png"
    if os.path.exists(curr_report_path):
        print(f"Report for {date} already exists. Skipping.")
        continue

    # The composite file names below carry the month without a leading zero.
    if len(month) > 1 and month[0] == '0':
        month = month[1]

    composite = _load_rgb(f"/media/datapart/lucazanolo/data/composites/median_composites/monthlyCompositeMedianMasked_T{tile}{year}_{month}.tif")
    composite_rest = _load_rgb(f"/media/datapart/lucazanolo/S2_processed_composites/median_composites_restored/monthlyCompositeMedianMaskedRestored_T{tile}{year}_{month}.tif")
    composite_rest_corr = _load_rgb(f"/media/datapart/lucazanolo/S2_processed_composites/median_composites_restored_adj/monthlyCompositeMedianMaskedRestoredAdj_T{tile}{year}_{month}.tif")
    title = f"Composites for tile {tile} - Date: {year}-{month}"
    create_composite_report(composite, composite_rest, composite_rest_corr, dem, slope, aspect, output_path=curr_report_path, title = title)
        
    #break