# helper functions

In [1]:
def get_valid_data(eopatch, cloud_threshold=0.2):
    """Average NDVI per superpixel and date, masking cloudy observations.

    Every pixel belongs to the superpixel given by
    ``eopatch.mask_timeless['SUPER_PIXELS']``. For each (date, superpixel)
    pair the mean of ``eopatch.mask['CLM']`` and of
    ``eopatch.data['NDVI_STANDARD']`` are computed. Pairs whose mean cloud
    mask is not below ``cloud_threshold`` are replaced by NaN.

    Parameters
    ----------
    eopatch : EOPatch-like object
        Must expose ``mask_timeless['SUPER_PIXELS']``, ``mask['CLM']``,
        ``data['NDVI_STANDARD']`` and ``timestamp`` (objects with
        ``strftime``). The CLM/NDVI arrays must have shape
        ``(n_timestamps,) + SUPER_PIXELS.shape``.
    cloud_threshold : float, default 0.2
        A superpixel is valid on a date only if its mean cloud-mask value is
        strictly below this threshold.

    Returns
    -------
    valid_data : ndarray of shape (n_timestamps, n_superpixels)
        Mean NDVI per date (rows) and superpixel (columns, in ascending label
        order), NaN for cloudy observations.
    dates : ndarray of str
        The timestamps formatted as ``'%Y-%m-%d'``.
    """
    # Superpixel label of every pixel. int64 prevents overflow of the
    # per-date offsets below when the mask has a small integer dtype.
    superpixels = np.asarray(eopatch.mask_timeless['SUPER_PIXELS']).astype(np.int64)
    labels = np.unique(superpixels)
    n_superpixels = labels.size
    n_timestamps = len(eopatch.timestamp)

    # Give each (date, superpixel) pair its own label by shifting each date's
    # labels by a stride wider than the label range. Using the label *count*
    # as the stride made labels of different dates collide whenever the
    # superpixel ids were not exactly 0..n-1.
    low, high = (int(labels.min()), int(labels.max())) if n_superpixels else (0, 0)
    stride = high - low + 1
    shifted = superpixels - low
    temporal_superpixels = np.array([t * stride + shifted for t in range(n_timestamps)])

    # Labels to average over, ordered date-major then by superpixel label, so
    # the flat result reshapes directly to (n_timestamps, n_superpixels).
    index = (np.arange(n_timestamps)[:, np.newaxis] * stride + (labels - low)[np.newaxis, :]).ravel()
    shape = (n_timestamps, n_superpixels)

    mean_clm_superpixels = np.asarray(
        ndimage.mean(eopatch.mask['CLM'], labels=temporal_superpixels, index=index)
    ).reshape(shape)
    mean_ndvi_superpixels = np.asarray(
        ndimage.mean(eopatch.data['NDVI_STANDARD'], labels=temporal_superpixels, index=index)
    ).reshape(shape)

    # Drop (set to NaN) every observation with too much cloud coverage.
    valid_data = np.where(mean_clm_superpixels < cloud_threshold, mean_ndvi_superpixels, np.nan)

    dates = np.array([date.strftime('%Y-%m-%d') for date in eopatch.timestamp])

    return valid_data, dates

In [2]:
def get_forest_time_series(eopatch, shapefile, filter_percentage=0.8):
    """Select the superpixels of an EOPatch that are mostly covered by forest.

    Parameters
    ----------
    eopatch : EOPatch-like object
        Must expose ``bbox`` (iterable of bounds, with ``crs.epsg``),
        ``vector_timeless['SUPER_PIXELS']`` (a GeoDataFrame with a ``VALUE``
        column holding superpixel labels) and ``mask_timeless['SUPER_PIXELS']``.
    shapefile : str or path
        Vector file of forest polygons, read with ``geopandas.read_file``.
    filter_percentage : float, default 0.8
        A superpixel is kept only if the fraction of its area intersecting the
        forest is strictly greater than this value.

    Returns
    -------
    time_series_forest : GeoDataFrame
        The retained rows of ``vector_timeless['SUPER_PIXELS']`` (empty when
        no forest falls inside the patch).
    ts_indices_preserved : ndarray of int64
        Sorted unique ``VALUE`` labels of the retained superpixels.

    Also prints the percentage of superpixels kept.
    """
    forets = gpd.read_file(shapefile)

    # Polygon covering the EOPatch area of interest.
    poly = shp.box(*eopatch.bbox)

    # Reproject the forest polygons into the EOPatch CRS (the bbox CRS, which
    # is not necessarily EPSG:4326), then clip them to the area of interest.
    forets['geometry'] = forets['geometry'].to_crs(eopatch.bbox.crs.epsg)
    forets['geometry'] = forets['geometry'].intersection(poly)
    aoi_forest = forets[~forets['geometry'].is_empty]

    time_series_polygons = eopatch.vector_timeless['SUPER_PIXELS']

    if aoi_forest.empty:
        # No forest inside the patch: keep nothing instead of failing with an
        # IndexError on the (empty) dissolved geometry.
        time_series_forest = time_series_polygons.iloc[0:0]
    else:
        # Merge all forest polygons into a single geometry.
        forest_geometry = aoi_forest.dissolve().geometry.iloc[0]
        intersection = time_series_polygons.intersection(forest_geometry)
        forest_ratio = intersection.area / time_series_polygons.area
        time_series_forest = time_series_polygons[forest_ratio > filter_percentage]

    ts_indices_preserved = np.unique(time_series_forest['VALUE']).astype('int64')

    n_superpixels = np.unique(eopatch.mask_timeless['SUPER_PIXELS']).size
    print(round(len(ts_indices_preserved)/n_superpixels * 100, 2), "% of time series preserved")

    return time_series_forest, ts_indices_preserved

In [4]:
def get_town_time_series(time_series, town, filter_percentage=0.8):
    """Keep the superpixels lying mostly inside a town polygon.

    Parameters
    ----------
    time_series : GeoDataFrame
        Superpixel polygons with a ``VALUE`` label column.
    town : iterable of geometries (e.g. a GeoSeries)
        Only its first geometry is used.
    filter_percentage : float, default 0.8
        A superpixel is kept when the fraction of its area inside the town is
        strictly greater than this value.

    Returns
    -------
    tuple
        The filtered rows of ``time_series`` and the sorted unique ``VALUE``
        labels of those rows as int64.
    """
    town_geometry = list(town)[0]
    covered_ratio = time_series.intersection(town_geometry).area / time_series.area
    kept = time_series[covered_ratio > filter_percentage]
    kept_labels = np.unique(kept['VALUE']).astype('int64')
    return kept, kept_labels

In [5]:
def set_bfast_params(valid_data, dates, ts_indices_preserved, end_training, start_monitor, end_monitor, k=3, freq=365, trend=False, hfrac=0.25, level=0.05):
    """Build a BFASTMonitor model and the int16 data cube it is fitted on.

    Parameters
    ----------
    valid_data : ndarray of shape (n_dates, n_superpixels)
        Mean NDVI per date and superpixel, NaN where invalid (see
        ``get_valid_data``). The array is NOT modified.
    dates : sequence
        One date per row of ``valid_data``; ``str(date)`` must be parseable
        by ``datetime.fromisoformat``.
    ts_indices_preserved : array of int
        Column indices (superpixel labels) to keep.
    end_training : datetime
        Dates from the first one >= ``end_training`` up to (excluding) the
        first one >= ``start_monitor`` are removed, leaving a gap between the
        history and the monitoring period.
    start_monitor, end_monitor : datetime
        Monitoring period; data after ``end_monitor`` is cropped.
    k, freq, trend, hfrac, level
        Forwarded to ``BFASTMonitor``.

    Returns
    -------
    model : BFASTMonitor
        Unfitted model using the python backend.
    valid_data_f2 : ndarray of int16, shape (n_kept_dates, len(ts_indices_preserved), 1)
        Scaled data in which 0 encodes a missing (NaN) value.
    dates_f : list of datetime
        Dates matching the first axis of ``valid_data_f2``.
    """
    dates = [datetime.fromisoformat(str(date)) for date in dates]

    # Work on a float copy: the previous version overwrote NaNs in the
    # caller's array. 0 is the missing-value marker passed to fit().
    filled = np.array(valid_data, dtype=float)
    filled[np.isnan(filled)] = 0

    # BFASTMonitor model (python backend)
    model = BFASTMonitor(
                start_monitor,
                freq=freq,
                k=k,
                hfrac=hfrac,
                trend=trend,
                level=level,
                backend='python',
                verbose=1,
                device_id=0,
            )

    # Scale into the int16 range. Using 32767 and the absolute maximum keeps
    # the extreme value representable: mapping the maximum to 32768 overflowed
    # int16, and a non-positive maximum divided by zero or flipped the sign.
    peak = np.abs(filled).max() if filled.size else 0.0
    scale = 32767 / peak if peak > 0 else 0.0
    valid_data_int = (filled * scale).astype(np.int16)

    # Add a third axis so the array looks like an (N, W, H) image stack.
    valid_data_int = valid_data_int[..., np.newaxis]

    # Crop from (just before) the first date to the end of monitoring.
    start_hist = dates[0]
    valid_data_f, dates_f = crop_data_dates(valid_data_int, dates, start=start_hist - timedelta(days=1), end=end_monitor)

    # Keep only the selected superpixels.
    valid_data_f2 = valid_data_f[:, ts_indices_preserved, :]

    # First cropped date >= end_training, then first date (not before it)
    # >= start_monitor. Computed on dates_f, the list actually sliced below,
    # and bounded by its length instead of raising IndexError.
    n_dates = len(dates_f)
    ind_end_train = next((idx for idx in range(n_dates) if dates_f[idx] >= end_training), n_dates)
    ind_start_monitor = next((idx for idx in range(ind_end_train, n_dates) if dates_f[idx] >= start_monitor), n_dates)

    # Remove the gap between training and monitoring.
    valid_data_f2 = np.delete(valid_data_f2, list(range(ind_end_train, ind_start_monitor)), 0)
    del dates_f[ind_end_train:ind_start_monitor]

    return model, valid_data_f2, dates_f

In [6]:
def execute_bfast(model, data, dates, n_chunks=5, nan_values=0):
    """Fit a BFAST model with its console output silenced and return its results.

    Parameters
    ----------
    model : BFASTMonitor-like object
        Must provide ``fit(data, dates, n_chunks=..., nan_value=...)`` and,
        after fitting, ``breaks``, ``magnitudes``, ``means`` and ``valids``.
    data, dates
        Passed to ``model.fit``.
    n_chunks : int, default 5
        Forwarded to ``model.fit`` (it was previously ignored).
    nan_values : int, default 0
        Forwarded to ``model.fit`` as ``nan_value`` (previously ignored).

    Returns
    -------
    tuple
        ``(breaks, magnitudes, means, valids)`` read from the fitted model.
    """
    from contextlib import redirect_stdout

    # Silence anything fit() prints. The context managers close the devnull
    # handle and restore sys.stdout even if fit() raises.
    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        model.fit(data, dates, n_chunks=n_chunks, nan_value=nan_values)

    return model.breaks, model.magnitudes, model.means, model.valids

In [7]:
def organise_results(time_series, dates, start_monitor, breaks, magnitudes):
    """Attach BFAST breakpoints and magnitudes to the superpixel table.

    Parameters
    ----------
    time_series : DataFrame / GeoDataFrame
        Superpixel table with a ``VALUE`` label column. The rows of ``breaks``
        and ``magnitudes`` must follow the sorted unique ``VALUE`` labels,
        which is the order used for ``ts_indices_preserved``.
    dates : sequence
        Dates used for fitting (comparable with ``start_monitor``).
    start_monitor
        First date of the monitoring period.
    breaks : array of shape (n, 1) or (n,)
        Per superpixel: a negative value means no break (BFASTMonitor uses -1
        for "no break" and -2 for "no data"); otherwise the index of the
        first break counted from the first monitoring date.
    magnitudes : array of shape (n, 1) or (n,)

    Returns
    -------
    results : same type as ``time_series``
        ``time_series`` inner-joined on ``VALUE`` with the columns
        ``breakpoint`` (date, or NaT when there is no break), ``magnitude``
        and ``norm_mag``. ``norm_mag`` holds positive magnitudes divided by
        the largest one and negative magnitudes divided by the absolute value
        of the smallest one, so it lies in [-1, 1]; zero stays 0.
    """
    # Index of the first date inside the monitoring period.
    start_monitor_index = next((idx for idx, d in enumerate(dates) if d >= start_monitor), len(dates))

    def to_date(breakpoint):
        # Negative values are BFASTMonitor's "no break"/"no data" markers.
        # Index 0 is a real break on the first monitoring date; it used to be
        # discarded by a `<= 0` test.
        bp_index = int(np.ravel(breakpoint)[0])
        if bp_index < 0:
            return np.datetime64("NaT")
        return dates[start_monitor_index + bp_index]

    output_df = pd.DataFrame({
        'VALUE': np.unique(time_series['VALUE']).astype('int64'),
        'breakpoint': [to_date(b) for b in breaks],
        'magnitude': np.asarray(magnitudes).reshape(-1),
    })

    results = time_series.merge(output_df, on='VALUE')

    magnitude = results['magnitude']
    min_mag = magnitude.min()
    max_mag = magnitude.max()
    positive = magnitude > 0
    negative = magnitude < 0

    # Use .loc: chained assignment (results['norm_mag'][mask] = ...) is a
    # silent no-op under pandas Copy-on-Write.
    results['norm_mag'] = 0.0
    results.loc[positive, 'norm_mag'] = magnitude[positive] / max_mag
    results.loc[negative, 'norm_mag'] = magnitude[negative] / abs(min_mag)

    return results

In [8]:
def group_by_breakpoints(results):
    """Summarise superpixel results per breakpoint date.

    Groups ``results`` by its ``breakpoint`` column and reports, per date,
    the number of ``VALUE`` entries and the min / max / mean / median of
    ``magnitude``. The two-level column index produced by the aggregation is
    flattened to ``breakpoint``, ``VALUE_count``, ``magnitude_min``, etc.
    """
    aggregations = {'VALUE': 'count', 'magnitude': ['min', 'max', 'mean', 'median']}
    summary = results.groupby(by='breakpoint', as_index=False).agg(aggregations)

    flat_names = []
    for outer, inner in summary.columns.to_flat_index():
        flat_names.append(f"{outer}_{inner}" if inner else outer)
    summary.columns = flat_names

    return summary

In [9]:
def bfast_dynamic(valid_data, dates, ts_indices_preserved, end_train, first_window, time_series=None):
    """Run BFAST repeatedly on a monitoring window that slides one date at a time.

    The first window starts at the first date >= ``first_window[0]`` and ends
    at the last date <= ``first_window[1]``. At each step both ends advance
    by one date, until the window end reaches the last date. Each window goes
    through ``set_bfast_params``, ``execute_bfast`` and ``organise_results``.

    Parameters
    ----------
    valid_data : ndarray (n_dates, n_superpixels)
        As returned by ``get_valid_data``.
    dates : sequence of 'YYYY-MM-DD' strings
        Compared as strings with the window bounds.
    ts_indices_preserved : array of int
        Superpixel labels to monitor.
    end_train : datetime
        End of the training (history) period.
    first_window : (datetime, datetime)
        Start and end of the first monitoring window.
    time_series : GeoDataFrame, optional
        Superpixel table passed to ``organise_results``. When omitted, the
        notebook-global ``time_series_forest`` is used; the original
        implementation always read that global implicitly.

    Returns
    -------
    breaks_list, magnitudes_list, results : list
        One entry per window.
    window_dates : sequence
        The end date of each window.
    """
    if time_series is None:
        # Backward compatibility with callers relying on the notebook global.
        time_series = time_series_forest

    first_date_window = first_window[0]
    last_date_window = first_window[1]
    last_date_index = len(dates) - 1

    # First date inside the first window.
    first_key = first_date_window.strftime('%Y-%m-%d')
    first_date_window_index = 0
    while dates[first_date_window_index] < first_key:
        first_date_window_index += 1

    # Last date inside the first window.
    last_key = last_date_window.strftime('%Y-%m-%d')
    last_date_window_index = last_date_index
    while dates[last_date_window_index] > last_key:
        last_date_window_index -= 1

    window_dates = dates[last_date_window_index:]
    # One run per window end from last_date_window_index to last_date_index
    # inclusive; the progress bar total used to be one short.
    n_windows = last_date_index - last_date_window_index + 1

    breaks_list = []
    magnitudes_list = []
    results = []
    with tqdm(total=n_windows) as pbar:
        for step in range(n_windows):
            start_monitor = datetime.fromisoformat(str(dates[first_date_window_index + step]))
            end_monitor = datetime.fromisoformat(str(dates[last_date_window_index + step]))
            bfast_model, valid_data_f, dates_f = set_bfast_params(valid_data, dates, ts_indices_preserved, end_train, start_monitor, end_monitor)
            breaks, magnitudes, _, _ = execute_bfast(bfast_model, valid_data_f, dates_f)
            breaks_list.append(breaks)
            magnitudes_list.append(magnitudes)
            results.append(organise_results(time_series, dates_f, start_monitor, breaks, magnitudes))
            pbar.update(1)

    return breaks_list, magnitudes_list, results, window_dates

### Display (plotting helpers)

In [10]:
def plot_dep(aoi, nom_dep, basemap='OSM', shapefile=''):
    """Draw an area of interest in red over the outline of a département.

    Parameters
    ----------
    aoi : geometry object
        Must expose ``geometry`` and ``crs`` (with ``epsg`` and
        ``pyproj_crs()``); presumably a sentinelhub Geometry — confirm.
    nom_dep : str
        Value matched against the ``nom`` column of the shapefile; only the
        first matching row is drawn.
    basemap : {'OSM', 'GP'}, default 'OSM'
        'OSM' adds OpenStreetMap Mapnik tiles, 'GP' Géoportail orthophotos;
        any other value draws no basemap.
    shapefile : str
        Path of the départements vector file.
    """
    _, axis = plt.subplots(figsize=(10,10))

    departements = gpd.read_file(shapefile)
    departements.geometry = departements.geometry.to_crs(aoi.crs.epsg)
    selected = departements[departements.nom == nom_dep].iloc[[0]]
    selected.plot(ax=axis, alpha=0.3, color=None, edgecolor='k', linewidth = 1)

    aoi_frame = gpd.GeoDataFrame(geometry=[aoi.geometry], crs=aoi.crs.pyproj_crs())
    aoi_frame.plot(ax=axis, alpha=0.3, color='red', edgecolor='r', linewidth=3)

    if basemap == 'GP':
        tiles = cx.providers.GeoportailFrance.orthos
    elif basemap == 'OSM':
        tiles = cx.providers.OpenStreetMap.Mapnik
    else:
        return
    cx.add_basemap(ax=axis, crs=aoi.crs.epsg, source=tiles)

In [1]:
def plot_forest_sp(sp):
    """Draw superpixel polygons in translucent red over Géoportail orthophotos.

    ``sp`` is a GeoDataFrame; its CRS is passed to the basemap.
    """
    _, axis = plt.subplots(figsize=(15, 10))
    polygons = sp.geometry
    polygons.plot(ax=axis, alpha=0.5, color='red', edgecolor='k', linewidth=1)
    orthophotos = cx.providers.GeoportailFrance.orthos
    cx.add_basemap(ax=axis, crs=sp.crs, source=orthophotos)

In [5]:
def plot_magnitudes(results, time_series):
    """Map the normalised BFAST magnitude of each superpixel.

    Parameters
    ----------
    results : GeoDataFrame
        Must have a ``norm_mag`` column (see ``organise_results``), coloured
        with the RdYlGn colormap and a colour bar on the right.
    time_series : GeoDataFrame
        Superpixels whose boundaries are drawn in grey; its CRS is used for
        the Géoportail orthophoto basemap.
    """
    fig, ax = plt.subplots(ncols=1, figsize=(15,10))
    # A single divider supplies the narrow colour-bar axis (previously a
    # second, unused divider was also created).
    cax = make_axes_locatable(ax).append_axes('right', size='2%', pad=0.1)
    results.plot(ax=ax, column='norm_mag', cmap=cmaps.RdYlGn, legend=True, cax=cax)
    time_series.geometry.boundary.plot(ax=ax, color=None, edgecolor='grey', linewidth=0.2)
    cx.add_basemap(ax=ax, crs=time_series.crs, source=cx.providers.GeoportailFrance.orthos)

In [4]:
def plot_breakpoints(breakpoint_df, func='median'):
    """Bar chart of the number of breakpoints per date, shaded by magnitude.

    Parameters
    ----------
    breakpoint_df : DataFrame
        Output of ``group_by_breakpoints``: needs ``breakpoint``,
        ``VALUE_count`` and ``magnitude_<func>`` columns.
    func : str, default 'median'
        Which aggregated magnitude colours the bars ('min', 'max', 'mean'
        or 'median'). Colours use Blues_r after min-max scaling, so the
        lowest magnitude gets the darkest blue.
    """
    magnitude = breakpoint_df['magnitude_' + func]
    min_ = magnitude.min()
    max_ = magnitude.max()
    span = max_ - min_
    # With one bar or identical magnitudes the span is 0 and (x - min) / 0
    # would give NaN colours; use a uniform shade instead.
    colormap = (magnitude - min_) / span if span > 0 else np.zeros(len(magnitude))

    fig, ax = plt.subplots(figsize=(12, 8))
    sn.barplot(ax=ax, x=breakpoint_df['breakpoint'], y=breakpoint_df['VALUE_count'], palette=plt.cm.Blues_r(colormap))
    ax.set(xlabel='Dates', ylabel='Number of breakpoints')
    plt.xticks(rotation=90)
    # Lay out after the labels exist so tight_layout reserves room for them.
    plt.tight_layout()

In [1]:
def plot_high_changing_sectors(eopatch, breakpoint_df, start_date, end_date, filenames, i, path, results_df=None):
    """Save one frame showing the superpixels whose breakpoint lies in a date range.

    Parameters
    ----------
    eopatch : EOPatch-like object
        Supplies the superpixel outlines (``vector_timeless['SUPER_PIXELS']``)
        and the basemap CRS (``bbox.crs.epsg``).
    breakpoint_df : DataFrame
        Output of ``group_by_breakpoints`` (``breakpoint`` column).
    start_date, end_date : datetime
        Inclusive range; compared as '%Y-%m-%d' strings with ``breakpoint``.
    filenames : list
        The frame path is appended to it (the list is mutated).
    i : int
        Frame number; the file is written to ``path + f'{i}.png'``.
    path : str
        Filename prefix (concatenated, not joined).
    results_df : GeoDataFrame, optional
        Superpixel results with a ``breakpoint`` column. Defaults to the
        notebook-global ``results``, which the original implementation read
        implicitly.

    Returns
    -------
    list
        ``filenames`` with the new frame path appended.
    """
    if results_df is None:
        # Backward compatibility with callers relying on the notebook global.
        results_df = results

    # One red shade per breakpoint date, spread evenly over the Reds colormap.
    # The previous step arithmetic divided by zero for a single date and
    # produced a zero range step for many dates.
    reds = plt.cm.Reds
    palette_size = max(len(breakpoint_df.breakpoint), 1)
    shade_indices = np.linspace(0, reds.N - 1, num=palette_size).round().astype(int)
    red_palette = [clr.rgb2hex(reds(int(k))) for k in shade_indices]

    start_key = start_date.strftime('%Y-%m-%d')
    end_key = end_date.strftime('%Y-%m-%d')
    breakpoints = breakpoint_df.query("@start_key <= breakpoint <= @end_key")

    fig, ax = plt.subplots(figsize=(12, 8))
    eopatch.vector_timeless['SUPER_PIXELS'].geometry.boundary.plot(ax=ax, color=None, edgecolor='black', linewidth=0.1)
    list_bp = list(breakpoints['breakpoint'])
    if list_bp:
        # A ListedColormap needs at least one colour; skip the layer when no
        # breakpoint falls inside the range.
        cmap = clr.ListedColormap(red_palette[:len(list_bp)])
        results_df.query('breakpoint in @list_bp').plot(ax=ax, column='breakpoint', categorical=True, cmap=cmap, legend=True)
    cx.add_basemap(ax=ax, crs=eopatch.bbox.crs.epsg, source=cx.providers.GeoportailFrance.orthos)

    filename = path + f'{i}.png'
    filenames.append(filename)

    # save frame
    fig.savefig(filename)
    plt.close(fig)

    return filenames

In [2]:
def plot_sectors_by_bp_date(time_series_forest, results, date):
    """Outline in red the superpixels whose breakpoint equals ``date``.

    ``date`` is formatted as '%Y-%m-%d' before being compared with
    ``results.breakpoint``. All superpixels of ``time_series_forest`` are
    first drawn with zero line width — presumably so the map extent covers
    the whole forest area (TODO confirm intent). The basemap uses the
    Géoportail orthophotos in ``time_series_forest.crs``.
    """
    day = date.strftime('%Y-%m-%d')
    _, axis = plt.subplots(figsize=(20, 12))
    time_series_forest.geometry.boundary.plot(ax=axis, color=None, edgecolor='red', linewidth=0)
    matching = results[results.breakpoint == day]
    matching.geometry.boundary.plot(ax=axis, color=None, edgecolor='red', linewidth=1)
    orthophotos = cx.providers.GeoportailFrance.orthos
    cx.add_basemap(ax=axis, crs=time_series_forest.crs, source=orthophotos)

In [None]:
def live_breaks(data, date, filenames, i, path, figsize=(7,5), title='Breaks detected over time', y_max=None):
    """Save one animation frame plotting the number of breaks detected so far.

    Parameters
    ----------
    data : sequence of numbers
        Break counts to plot.
    date : str
        Appended to the title in parentheses.
    filenames : list
        The frame path is appended to it (the list is mutated).
    i : int
        Frame number; the file is written to ``path + f'{i}.png'``.
    path : str
        Filename prefix (concatenated, not joined).
    figsize, title
        Figure size and title text.
    y_max : number, optional
        Upper bound of the y axis, rounded up to the next multiple of 1000
        (at least 1000). Defaults to the notebook-global ``max_bp``, which
        the original implementation read implicitly.

    Returns
    -------
    list
        ``filenames`` with the new frame path appended.
    """
    if y_max is None:
        # Backward compatibility with callers relying on the notebook global.
        y_max = max_bp

    clear_output(wait=True)
    filename = path + f'{i}.png'
    filenames.append(filename)

    fig = plt.figure(figsize=figsize)
    plt.plot(data)
    plt.title(title + ' (' + date + ')')
    plt.grid(True)
    # Keep the top strictly above 0 so bottom and top never coincide.
    plt.ylim(top=max(ceil(y_max/1000), 1)*1000)
    plt.xlabel('Time')
    plt.ylabel('Number of breaks detected')
    fig.savefig(filename)
    plt.close(fig)

    return filenames

In [None]:
def live_mag(data, date, filenames, i, path, figsize=(7,5), title='', time_series=None):
    """Save one animation frame mapping normalised break magnitudes.

    Parameters
    ----------
    data : GeoDataFrame
        Must have a ``norm_mag`` column (see ``organise_results``).
    date : str
        Shown in the frame title.
    filenames : list
        The frame path is appended to it (the list is mutated).
    i : int
        Frame number; the file is written to ``path + f'{i}.png'``.
    path : str
        Filename prefix (concatenated, not joined).
    figsize : tuple
        Figure size.
    title : str, default ''
        Optional prefix for the frame title. The original implementation
        ignored this argument; it is now prepended when non-empty.
    time_series : GeoDataFrame, optional
        Superpixels drawn as grey outlines; its CRS is used for the basemap.
        Defaults to the notebook-global ``time_series_forest``, which the
        original implementation read implicitly.

    Returns
    -------
    list
        ``filenames`` with the new frame path appended.
    """
    if time_series is None:
        # Backward compatibility with callers relying on the notebook global.
        time_series = time_series_forest

    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=figsize)
    cax = make_axes_locatable(ax).append_axes('right', size='2%', pad=0.1)
    time_series.geometry.boundary.plot(ax=ax, color=None, edgecolor='grey', linewidth=0.2)
    cx.add_basemap(ax=ax, crs=time_series.crs, source=cx.providers.GeoportailFrance.orthos)
    data.plot(ax=ax, column='norm_mag', cmap=cmaps.RdYlGn, legend=True, cax=cax)

    filename = path + f'{i}.png'
    filenames.append(filename)

    heading = 'New date =' + date
    if title:
        heading = title + ' - ' + heading
    # Title the map axes explicitly: append_axes adds cax through
    # fig.add_axes, which makes it the current axes, so plt.title() put the
    # title above the narrow colour bar.
    ax.set_title(heading)

    # save frame
    fig.savefig(filename)
    plt.close(fig)

    return filenames