In [None]:
import io
import os
import json
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import folium
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
from rasterio.plot import show
from shapely import Point

In [None]:
# Root of the shared data store. The default is the historical mount point;
# set the STORE_DAI_PATH environment variable to run the notebook on a machine
# where the store is mounted elsewhere.
path_store_dai = os.environ.get("STORE_DAI_PATH", "/media/store-dai/")

In [None]:
# Input locations, all resolved relative to the data store root.
path_parcelles_S1 = os.path.join(path_store_dai, "projets/pac/3str/EXP_2/Data_Raster/radar_dataset_parcelles")
path_parcelles_gpkg = os.path.join(path_store_dai, "projets/pac/3str/EXP_2/Data_Vecteur/CIRCA+_parcels_652K.gpkg")
path_circa_tiles_gpkg = os.path.join(path_store_dai, "projets/pac/3str/datasets/vecteurs/CIRCA/CIRCA_grid_mgrs22.gpkg")
path_parcelles_S2 = os.path.join(path_store_dai, "projets/pac/3str/EXP_2/Data_Raster/optique_dataset_parcelles")
path_dates_S2 = os.path.join(path_store_dai, "projets/pac/3str/EXP_1/auxiliary/CIRCA__mgrs_dates_crs.json")

# Fail fast when an input is missing. The message names the kind of path that
# was expected (directory vs. file) so the error is not misleading.
assert os.path.isdir(path_parcelles_S1), f"Le répertoire {path_parcelles_S1} n'existe pas"
assert os.path.isfile(path_parcelles_gpkg), f"Le fichier {path_parcelles_gpkg} n'existe pas"
assert os.path.isfile(path_circa_tiles_gpkg), f"Le fichier {path_circa_tiles_gpkg} n'existe pas"
assert os.path.isdir(path_parcelles_S2), f"Le répertoire {path_parcelles_S2} n'existe pas"
assert os.path.isfile(path_dates_S2), f"Le fichier {path_dates_S2} n'existe pas"

In [None]:
# Parcel table, reprojected to EPSG:4326 because the folium maps below are fed
# [latitude, longitude] positions computed from these geometries.
parcelles_gdf = gpd.read_file(path_parcelles_gpkg).to_crs("EPSG:4326")

In [None]:
# Tile grid as stored on disk, and the distinct values of the parcels' "Grid"
# column (the tile names that can be typed into the tile selector below).
tiles_gdf = gpd.read_file(path_circa_tiles_gpkg)
tiles_names = {grid_name for grid_name in parcelles_gdf["Grid"]}
print(tiles_names)

In [None]:
class MapMGRS:
    """Static folium overview of the tile grid read from a vector file."""

    def __init__(self, path_circa_tiles_gpkg):
        """Read the tile grid and reproject it to EPSG:4326.

        Args:
            path_circa_tiles_gpkg: path of the vector file holding the tiles;
                its ``"Name"`` column is used for the map popups.
        """
        self.path_circa_tiles_gpkg = path_circa_tiles_gpkg
        self.tiles_gdf = gpd.read_file(path_circa_tiles_gpkg).to_crs(epsg=4326)

    def create_map(self):
        """Build the overview map.

        Returns:
            An ``IPython.display.HTML`` object wrapping the folium map in a
            75 %-wide ``<div>``.
        """
        # ``dissolve().centroid`` is a one-element GeoSeries, not a Point.
        # Extract the single geometry so that ``.y`` / ``.x`` are plain floats
        # instead of Series that folium would have to coerce with ``float()``.
        centroid: Point = self.tiles_gdf.dissolve().centroid.iloc[0]
        location = [centroid.y, centroid.x]  # folium expects [lat, lon]
        aFoliumMap = folium.Map(location=location, zoom_start=6)

        popup = folium.GeoJsonPopup(fields=["Name"])
        tiles_json = folium.GeoJson(
            data=self.tiles_gdf.to_json(),
            style_function=lambda x: {"fillColor": "orange"},
            popup=popup,
        )

        tiles_json.add_to(aFoliumMap)
        aHTMLbalise = HTML("<div style='width: 75%'>" + aFoliumMap._repr_html_() + "</div>")
        return aHTMLbalise
    

class MapParcelles:
    """Folium map of the parcels of one tile, with a tile-name selector.

    The map is rendered as HTML inside an ``ipywidgets.Output`` so it can be
    replaced each time another tile or parcel is selected.
    """

    def __init__(self, parcelles_gdf):
        """Store the parcel table and create the widgets the methods write to.

        Args:
            parcelles_gdf: GeoDataFrame of parcels. The ``"Grid"``,
                ``"ID-CIRCA"`` and ``"geometry"`` columns are read here.
                Map positions assume lon/lat coordinates (EPSG:4326).
        """
        self.parcelles_gdf = parcelles_gdf
        # Created here rather than only in create_widget(), so that
        # display_parcelle()/display() never fail with AttributeError when
        # they are called before the widget has been built.
        self.input_tile_name = widgets.Text(description="Tile name :")
        self.out_folium_map = widgets.Output()

    def createFoliumMap(self):
        """Draw the default map (no parcels) into the output area."""
        aFoliumMap = folium.Map(location=[47.5, 0], zoom_start=6)
        aHTMLbalise = HTML("<div style='width: 75%'>" + aFoliumMap._repr_html_() + "</div>")
        with self.out_folium_map:
            # Replace instead of append, so building the widget a second time
            # does not stack several maps in the same output.
            clear_output()
            display(aHTMLbalise)

    def display_parcelle(self, parcelle_gdf):
        """Show the tile of a parcel, centred on that parcel.

        Args:
            parcelle_gdf: GeoDataFrame whose first row is the parcel to show.
        """
        grid_selected = parcelle_gdf.iloc[0]["Grid"]
        # bounds is (minx, miny, maxx, maxy); folium wants [lat, lon].
        bounds = parcelle_gdf.iloc[0]["geometry"].bounds
        location = [(bounds[1] + bounds[3]) / 2, (bounds[0] + bounds[2]) / 2]
        self.display(grid_selected, location)

    def display_tile_event(self, e):
        """Button callback: show the tile whose name is in the text box."""
        grid_selected = self.input_tile_name.value
        self.display(grid_selected, None)

    def display(self, grid_selected, location):
        """Render every parcel of one tile.

        Args:
            grid_selected: value of the ``"Grid"`` column to select.
            location: ``[lat, lon]`` map centre (zoom 17), or ``None`` to
                centre on the extent of the selected parcels (zoom 11).
        """
        extract_gdf = self.parcelles_gdf[self.parcelles_gdf["Grid"] == grid_selected]
        if extract_gdf.empty:
            # An unknown tile name gives an empty selection whose bounds are
            # NaN, which folium rejects; report it instead of failing inside
            # the button callback.
            with self.out_folium_map:
                clear_output()
                print(f"Aucune parcelle trouvée pour la tuile {grid_selected!r}")
            return
        zoom_start = 17
        if location is None:
            bounds = extract_gdf.total_bounds
            location = [(bounds[1] + bounds[3]) / 2, (bounds[0] + bounds[2]) / 2]
            zoom_start = 11
        aFoliumMap = folium.Map(location=location, zoom_start=zoom_start)
        popup = folium.GeoJsonPopup(fields=["ID-CIRCA"])
        tiles_json = folium.GeoJson(
            data=extract_gdf.to_json(),
            style_function=lambda x: {"fillColor": "orange"},
            popup=popup,
        )
        tiles_json.add_to(aFoliumMap)
        aHTMLbalise = HTML("<div style='width: 75%'>" + aFoliumMap._repr_html_() + "</div>")
        with self.out_folium_map:
            clear_output()
            display(aHTMLbalise)

    def create_widget(self):
        """Assemble the selector (text box + button) above the map output.

        Returns:
            An ``ipywidgets.VBox`` ready to be displayed.
        """
        self.createFoliumMap()
        validate_button_parcel = widgets.Button(description='Valider')
        validate_button_parcel.on_click(self.display_tile_event)
        aWidget = widgets.VBox([widgets.HBox([self.input_tile_name, validate_button_parcel]), self.out_folium_map])
        return aWidget


class SerieTemporelleRadar:
    """Display of the radar time series of one parcel.

    Arrays handled here are indexed as ``(date, band, row, col)``. According
    to the plot labels below, bands 0-3 are sigma VV, sigma VH, coherence VV
    and coherence VH. A value of 0 is treated as "no data".
    """

    def __init__(self, path_parcelles_S1, parcelles_gdf,):
        """Remember the dataset root and create the output area.

        Args:
            path_parcelles_S1: root directory of the per-parcel radar arrays.
            parcelles_gdf: parcel table (stored, not read by this class).
        """
        self.path_parcelles_S1 = path_parcelles_S1
        self.parcelles_gdf = parcelles_gdf
        self.output_img = widgets.Output(layout={'width': '100%'})
        # Filled by load_parcelle(); None means "no data for that orbit".
        self.radar_data_asc = None
        self.radar_data_desc = None
        self.dates = None

    def normalize(self, data):
        """Return a min-max stretched copy of ``data``, band by band.

        The input array is left untouched: the raw values are needed again
        afterwards to compute the mean curves.

        Args:
            data: array indexed ``(date, band, row, col)``.

        Returns:
            A floating-point array of the same shape. Pixels that are 0 in
            band 0 of the first date are set to 0 for every date and band.
        """
        data = np.asarray(data)
        # Floating-point copy: avoids writing into the caller's array and
        # avoids truncating the stretched values when the input is integer.
        normalized = data.astype(np.result_type(data.dtype, np.float32), copy=True)
        # NOTE(review): the mask comes from the first date only, so a first
        # date made entirely of zeros masks the whole series. Confirm intent.
        mask = data[0, 0, :, :] != 0
        for i in range(normalized.shape[1]):
            band = normalized[:, i, :, :]
            # Hack: replace masked (zero) pixels by the band mean so that they
            # do not influence the min/max used for the stretch.
            band_modif = np.where(band == 0, np.mean(band), band)
            lowest = np.min(band_modif)
            span = np.max(band_modif) - lowest
            if span == 0:
                # Constant band: nothing to stretch, and dividing would give NaN.
                normalized[:, i, :, :] = 0
                continue
            normalized[:, i, :, :] = (band - lowest) / span
        return normalized * mask[np.newaxis, np.newaxis, :, :]

    def is_null(self, data):
        """Tell whether the first band of ``data`` (band, row, col) is constant."""
        band_0 = data[0, :, :]
        return bool(np.min(band_0) == np.max(band_0))

    def remove_null_date(self, data, dates):
        """Drop the dates whose first band is constant.

        Args:
            data: array indexed ``(date, band, row, col)``.
            dates: sequence of labels, one per date of ``data``.

        Returns:
            ``(data_kept, dates_kept)`` restricted to the non-null dates.
        """
        indices_keep = [i for i in range(data.shape[0]) if not self.is_null(data[i, :, :, :])]
        dates_keep = [dates[i] for i in indices_keep]
        return data[indices_keep, :, :, :], dates_keep

    def display_radar_une_orbite(self, data, dates):
        """Render one thumbnail per non-null date for one orbit.

        Args:
            data: array indexed ``(date, band, row, col)``; the first three
                bands are shown as a colour composite.
            dates: labels used as thumbnail titles, one per date.

        Returns:
            List of ``ipywidgets.Image`` (PNG), in date order.
        """
        data_imgs_IO = []
        plt.rcParams.update({'figure.max_open_warning': 100})
        data_normalized = self.normalize(data)
        for i in range(data_normalized.shape[0]):
            if self.is_null(data_normalized[i, :3, :, :]):
                continue
            fig = plt.figure(figsize=(2, 4))
            ax = fig.add_subplot(1, 1, 1, xticklabels=[], yticklabels=[])
            ax.set_title(dates[i])
            show(data_normalized[i, :3, :, :], ax=ax)
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            data_imgs_IO.append(widgets.Image(value=buf.getvalue(), layout=widgets.Layout(padding='0 0 0 0', margin='0 3 0 3')))
            # Close this figure explicitly so figures do not pile up in memory.
            plt.close(fig)
        return data_imgs_IO

    def compute_mean(self, data, dates):
        """Plot the per-date mean of each band over the non-zero pixels.

        Sigma (bands 0 and 1) uses the left axis, coherence (bands 2 and 3)
        the right axis.

        Args:
            data: array indexed ``(date, band, row, col)``, raw values.
            dates: labels, one per date of ``data``.

        Returns:
            An ``ipywidgets.Image`` holding the plot as PNG.
        """
        data, dates = self.remove_null_date(data, dates)
        mask = data != 0
        means = np.mean(data, axis=(2, 3), where=mask)
        abscisses = np.arange(means.shape[0])

        fig, ax1 = plt.subplots(figsize=(20, 1))
        main_color = (0.0, 0.0, 1.0)
        second_color = (0.3, 0.0, 0.7)
        ax1.plot(abscisses, means[:, 0], color=main_color, label="sigma vv")
        ax1.plot(abscisses, means[:, 1], color=second_color, label="sigma vh")
        ax1.set_ylabel("sigma", color=main_color)
        ax1.tick_params(axis='y', labelcolor=main_color)
        ax1.legend()

        ax2 = ax1.twinx()
        main_color = (1.0, 0.0, 0.0)
        second_color = (0.7, 0.3, 0.0)
        ax2.plot(abscisses, means[:, 2], color=main_color, label="cohérence vv")
        ax2.plot(abscisses, means[:, 3], color=second_color, label="cohérence vh")
        ax2.set_ylabel("cohérence", color=main_color)
        ax2.tick_params(axis='y', labelcolor=main_color)
        ax2.legend()

        ax1.set_xticks(abscisses, dates, rotation="vertical")

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        aWidget = widgets.Image(value=buf.getvalue(), layout=widgets.Layout(padding='0 0 0 0', margin='0 3 0 3'))
        plt.close(fig)

        return aWidget

    def read_dates_radar(self, path_dir_grid:str):
        """Load ``dates.json`` from ``path_dir_grid`` and return its content."""
        with open(os.path.join(path_dir_grid, "dates.json"), "r") as f:
            dates = json.load(f)
        return dates

    @staticmethod
    def _scroll_layout():
        """Return a fresh layout for a horizontally scrollable container."""
        return widgets.Layout(overflow_x='auto', height='auto', margin='0 0 8px 0', padding='0 0 0 0')

    def display_radar(self):
        """Build the widget for the currently loaded parcel.

        For each orbit that has data (ascending first, then descending) the
        result holds a row of thumbnails followed by the mean curves.

        Returns:
            An ``ipywidgets.VBox``; empty when no orbit is loaded.
        """
        orbits = (
            (self.radar_data_asc, "dates_asc"),
            (self.radar_data_desc, "dates_desc"),
        )
        children = []
        for radar_data, dates_key in orbits:
            if radar_data is None:
                continue
            dates = self.dates[dates_key]
            thumbnails = self.display_radar_une_orbite(radar_data, dates)
            mean_plot = self.compute_mean(radar_data, dates)
            children.append(widgets.HBox(thumbnails, layout=self._scroll_layout()))
            children.append(mean_plot)
        return widgets.VBox(children, layout=self._scroll_layout())

    def _load_orbit(self, path_dir_grid, npy_filename):
        """Return the array stored in ``npy_filename``, or None if it is absent."""
        npy_path = os.path.join(path_dir_grid, npy_filename)
        if not os.path.isfile(npy_path):
            return None
        return np.load(npy_path)

    def load_parcelle(self, parcelle_gdf):
        """Load and display the radar series of a parcel.

        Args:
            parcelle_gdf: GeoDataFrame whose first row provides ``Grid`` and
                ``"ID-CIRCA-FULL"``, used to locate the ``.npy`` files.
        """
        with self.output_img:
            clear_output()
        parcelle = parcelle_gdf.iloc[0]
        grid = parcelle.Grid
        grid_group = grid.split("_")[0]
        id_circa_full = parcelle["ID-CIRCA-FULL"]
        path_dir_grid = os.path.join(self.path_parcelles_S1, grid_group, grid)

        # Always reassign both orbits: a missing file must yield None, never
        # the array left over from the previously selected parcel.
        self.radar_data_asc = self._load_orbit(path_dir_grid, f"{grid}__{id_circa_full}__SEN1-ASC-data-parcel.npy")
        self.radar_data_desc = self._load_orbit(path_dir_grid, f"{grid}__{id_circa_full}__SEN1-DESC-data-parcel.npy")
        if self.radar_data_asc is None and self.radar_data_desc is None:
            # Nothing to show for this parcel; the output stays cleared.
            return

        self.dates = self.read_dates_radar(path_dir_grid)
        widget = self.display_radar()
        with self.output_img:
            display(widget)

    def create_widget(self):
        """Return the output area that load_parcelle() draws into."""
        return self.output_img
    

class SerieTemporelleOptique:
    """Display of the optical time series of one parcel.

    ``optical_data`` is indexed ``(date, band, row, col)``. The code below
    reads bands 0-2 for the colour thumbnails, bands 2 and 6 as red and
    near-infrared for the NDVI, and bands 10 and 11 as cloud and snow masks.
    A value of 0 is treated as "no data".
    """

    def __init__(self, parcelles_gdf, path_parcelles_S2, path_dates_S2):
        """Store the dataset locations, load the dates file, create the output.

        Args:
            parcelles_gdf: parcel table (stored, not read by this class).
            path_parcelles_S2: root directory of the per-parcel optical arrays.
            path_dates_S2: JSON file loaded into ``dates_S2_json``.
        """
        self.path_parcelles_S2 = path_parcelles_S2
        self.parcelles_gdf = parcelles_gdf
        self.path_dates_S2 = path_dates_S2
        with open(path_dates_S2, "r") as f:
            self.dates_S2_json = json.load(f)
        self.output_img = widgets.Output(layout={'width': '100%'})
        self.clouds_removed = False
        # Filled by load_parcelle(); None means "no parcel loaded".
        self.optical_data = None
        self.dates_S2 = None

    def is_cloudy(self, data, max_cloud_value=10, max_snow_value=10, max_fraction_covered=0.05):
        """Tell whether too large a share of one date is cloud or snow covered.

        ``data`` is only read. It is normally a view into ``optical_data``,
        so modifying it would corrupt the stored array.

        Args:
            data: 3-D array; ``data[0]`` is the cloud mask and ``data[1]``
                the snow mask.
            max_cloud_value: highest cloud value still counted as clear.
            max_snow_value: highest snow value still counted as clear.
            max_fraction_covered: largest tolerated covered fraction.

        Returns:
            ``True`` when fewer than ``1 - max_fraction_covered`` of the
            pixels are clear, ``False`` otherwise.
        """
        nodata = 65535
        # The nodata value is read as 0, i.e. as a clear pixel.
        # NOTE(review): the previous implementation also built a mask meant to
        # exclude nodata pixels from the count, but only after they had been
        # overwritten with 0, so it never applied. The resulting decision is
        # kept as is; confirm whether nodata pixels should be excluded.
        cloud_mask = np.where(data[0] == nodata, 0, data[0])
        snow_mask = np.where(data[1] == nodata, 0, data[1])

        clear = (cloud_mask <= max_cloud_value) & (snow_mask <= max_snow_value)
        threshold = (1 - max_fraction_covered) * clear.size
        return not (np.sum(clear) >= threshold)

    def display_optique_images(self):
        """Render one thumbnail per usable date.

        Dates whose band 0 is entirely 0 are skipped, as are cloudy dates
        when ``clouds_removed`` is set.

        Returns:
            List of ``ipywidgets.Image`` (PNG), in date order.
        """
        data_imgs_IO = []
        plt.rcParams.update({'figure.max_open_warning': 100})
        # Linear stretch: values of 3000 and above saturate.
        optical_data = np.clip(self.optical_data / 3000, 0, 1)
        for i in range(self.optical_data.shape[0]):
            if self.clouds_removed and self.is_cloudy(self.optical_data[i, 10:12, :, :]):
                continue
            if np.max(self.optical_data[i, 0, :, :]) == 0:
                continue
            fig = plt.figure(figsize=(2, 4))
            ax = fig.add_subplot(1, 1, 1, xticklabels=[], yticklabels=[])
            ax.set_title(self.dates_S2[i])
            # Bands 2, 1, 0 in that order - presumably red, green, blue.
            show(optical_data[i, [2, 1, 0], :, :], ax=ax)
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            data_imgs_IO.append(widgets.Image(value=buf.getvalue(), layout=widgets.Layout(padding='0 0 0 0', margin='0 3 0 3')))
            plt.close(fig)
        return data_imgs_IO

    def display_ndvi(self):
        """Plot the mean NDVI of each date, green if clear and red if cloudy.

        The mean is taken over the pixels whose red band is non-zero and is
        clipped to [0, 1]. Dates whose band 0 is entirely 0 are not plotted.

        Returns:
            An ``ipywidgets.Image`` holding the plot as PNG.
        """
        idx_nir = 6
        idx_red = 2
        # Signed copy so that nir - red cannot wrap around.
        data = self.optical_data.astype(np.int32)
        dates = self.dates_S2

        nir = data[:, idx_nir, :, :]
        red = data[:, idx_red, :, :]
        NDVI = (nir - red) / (nir + red + np.finfo(float).eps)
        NDVI_mean = np.mean(NDVI, axis=(1, 2), where=(red != 0))
        NDVI_mean = np.clip(NDVI_mean, 0, 1)

        ndvi_nc = []
        ndvi_c = []
        abscisses_nc = []
        abscisses_c = []
        for i in range(NDVI.shape[0]):
            if np.max(self.optical_data[i, 0, :, :]) == 0:
                continue
            if self.is_cloudy(data[i, 10:12, :, :]):
                ndvi_c.append(NDVI_mean[i])
                abscisses_c.append(i)
            else:
                ndvi_nc.append(NDVI_mean[i])
                abscisses_nc.append(i)

        fig, ax = plt.subplots(figsize=(20, 1))
        ax.scatter(abscisses_nc, ndvi_nc, color="g", marker="x")
        ax.scatter(abscisses_c, ndvi_c, color="r", marker="x")
        ax.set_xticks(np.arange(len(dates)), dates, rotation="vertical")
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        aNDVIWidget = widgets.Image(value=buf.getvalue(), layout=widgets.Layout(padding='0 0 0 0', margin='0 3 0 3'))
        plt.close(fig)
        return aNDVIWidget

    def remove_clouds(self, removeClouds):
        """Set whether cloudy dates are hidden, then redraw."""
        self.clouds_removed = removeClouds
        self.display_optique()

    def display_optique(self):
        """Redraw the output area from the currently loaded parcel.

        The area is only cleared when no parcel is loaded.
        """
        with self.output_img:
            clear_output()
        if self.optical_data is None:
            return
        widget_optique = self.display_optique_images()
        widget_ndvi = self.display_ndvi()
        widgets_optique_hbox = widgets.HBox(widget_optique, layout=widgets.Layout(overflow_x='auto', height='auto', margin='0 0 8px 0', padding='0 0 0 0'))
        widget_optique_ndvi = widgets.VBox([widgets_optique_hbox, widget_ndvi], layout=widgets.Layout(overflow_x='auto', height='auto', margin='0 0 8px 0', padding='0 0 0 0'))
        with self.output_img:
            display(widget_optique_ndvi)

    def create_widget(self):
        """Return the output area that display_optique() draws into."""
        return self.output_img

    def load_parcelle(self, parcelle_gdf):
        """Load and display the optical series of a parcel.

        Args:
            parcelle_gdf: GeoDataFrame whose first row provides ``Grid`` and
                ``"ID-CIRCA-FULL"``, used to locate the ``.npy`` file.
        """
        parcelle = parcelle_gdf.iloc[0]
        grid = parcelle.Grid
        grid_group = grid.split("_")[0]
        id_circa_full = parcelle["ID-CIRCA-FULL"]
        path_dir_grid = os.path.join(self.path_parcelles_S2, grid_group, grid)
        npy_path = os.path.join(path_dir_grid, f"{grid}__{id_circa_full}__SEN2-data-parcel.npy")
        json_path = os.path.join(path_dir_grid, "dates.json")

        # Forget the previous parcel first: when the file is missing the view
        # must go blank rather than keep showing another parcel's images.
        self.optical_data = None
        self.dates_S2 = None
        if os.path.isfile(npy_path):
            optical_data = np.load(npy_path)
            with open(json_path, "r") as f:
                dates_S2 = json.load(f)
            # Assigned together, only once both reads have succeeded.
            self.optical_data = optical_data
            self.dates_S2 = dates_S2
        self.display_optique()
        

class Parameters:
    """Control panel: parcel selector and cloud filter.

    Selecting a parcel forwards it to the radar view, the optical view and
    the parcel map; the checkbox toggles the cloud filter of the optical view.
    """

    def __init__(self, aMapParcelles:MapParcelles, aSerieTemporelleRadar:SerieTemporelleRadar, aSerieTemporelleOptique:SerieTemporelleOptique, parcelles_gdf):
        """Keep the three views to drive and build the input widgets.

        Args:
            aMapParcelles: map view, updated through ``display_parcelle``.
            aSerieTemporelleRadar: radar view, updated through ``load_parcelle``.
            aSerieTemporelleOptique: optical view, updated through
                ``load_parcelle`` and ``remove_clouds``.
            parcelles_gdf: parcel table searched on its ``"ID-CIRCA"`` column.
        """
        # Views driven by this panel.
        self.aMapParcelles = aMapParcelles
        self.aSerieTemporelleRadar = aSerieTemporelleRadar
        self.aSerieTemporelleOptique = aSerieTemporelleOptique
        self.parcelles_gdf = parcelles_gdf

        # Parcel selector: identifier text box and its validation button.
        self.input_tile_name = widgets.Text(
            description="Parcelle name :",
            value="058021503-5-3",
        )
        self.validate_button_parcel = widgets.Button(description='Valider')
        self.validate_button_parcel.on_click(self.display_parcelle)

        # Cloud filter toggle; the handler itself filters on the trait name.
        self.cb_s2_cloud = widgets.Checkbox(
            value=False,
            description='Remove clouds',
            disabled=False,
        )
        self.cb_s2_cloud.observe(self.remove_s2_clouds)

    def display_parcelle(self, e):
        """Button callback: show the parcel whose identifier is in the text box.

        Nothing happens unless exactly one row of the parcel table matches.
        """
        wanted_id = self.input_tile_name.value
        is_wanted = self.parcelles_gdf["ID-CIRCA"] == wanted_id
        matches = self.parcelles_gdf[is_wanted]
        if matches.shape[0] != 1:
            return
        self.aSerieTemporelleRadar.load_parcelle(matches)
        self.aSerieTemporelleOptique.load_parcelle(matches)
        self.aMapParcelles.display_parcelle(matches)

    def remove_s2_clouds(self, e):
        """Checkbox callback: forward the new value to the optical view.

        Change notifications for traits other than ``"value"`` are ignored.
        """
        if e["name"] != "value":
            return
        self.aSerieTemporelleOptique.remove_clouds(e["new"])

    def create_widgets(self):
        """Return the panel: selector row on top, checkbox underneath."""
        selector_row = widgets.HBox([self.input_tile_name, self.validate_button_parcel])
        return widgets.VBox([selector_row, self.cb_s2_cloud])


# Overview map of the tile grid.
aMapMGRS = MapMGRS(path_circa_tiles_gpkg=path_circa_tiles_gpkg)
aMapMGRSHTMLbalise = aMapMGRS.create_map()
display(aMapMGRSHTMLbalise)

# Parcel map with its tile selector. Despite its name, aMapParcellesHTMLbalise
# is the MapParcelles object itself, not an HTML fragment.
aMapParcellesHTMLbalise = MapParcelles(parcelles_gdf=parcelles_gdf)
display(aMapParcellesHTMLbalise.create_widget())

# Time-series views and the control panel that drives them.
aSerieTemporelleRadar = SerieTemporelleRadar(
    path_parcelles_S1=path_parcelles_S1,
    parcelles_gdf=parcelles_gdf,
)
aSerieTemporelleOptique = SerieTemporelleOptique(
    parcelles_gdf=parcelles_gdf,
    path_parcelles_S2=path_parcelles_S2,
    path_dates_S2=path_dates_S2,
)
aParameters = Parameters(
    aMapParcelles=aMapParcellesHTMLbalise,
    aSerieTemporelleRadar=aSerieTemporelleRadar,
    aSerieTemporelleOptique=aSerieTemporelleOptique,
    parcelles_gdf=parcelles_gdf,
)

# Shown in this order: controls, radar series, optical series.
display(
    aParameters.create_widgets(),
    aSerieTemporelleRadar.create_widget(),
    aSerieTemporelleOptique.create_widget(),
)
