Use SWOT elevations instead of SAR areas for trend correction in the TMS-OS algorithm.

In [1]:
import pandas as pd
import numpy as np
from scipy.interpolate import interp1d
from scipy.stats import sigmaclip, zscore
import warnings
import os
from scipy.signal import savgol_filter
from pathlib import Path
import geopandas as gpd
import hvplot.pandas
import geoviews as gv
import holoviews as hv

# Activate the Bokeh plotting backend used by HoloViews / hvplot in this notebook
hv.extension("bokeh")

helper functions

# Select the reservoir

In [2]:
RESERVOIR = '0505'
ALG_VERSION = '0.1' # remove temporal resampling

# Root of the project on disk; results and input data live in sub-directories of it.
PROJECT_DIR = Path('/tiger1/pdas47/tmsosPP')
RESULTS_DIR = PROJECT_DIR / 'results'
DATA_DIR = PROJECT_DIR / 'data'

In [3]:
# Read the validation reservoirs (all 100) as points and polygons.
val_pts = gpd.read_file(DATA_DIR / 'validation-locations' / '100-validation-reservoirs-grand-pts.geojson')
val_polys = gpd.read_file(DATA_DIR / 'validation-locations' / '100-validation-reservoirs-grand-polys.geojson')

selected_reservoirs = val_pts['tmsos_id'].tolist()  # select all 100 reservoirs
# tmsos_id -> reservoir name lookup
res_names = val_pts.set_index('tmsos_id')['name'].to_dict()

RESERVOIR_NAME = res_names[RESERVOIR]

val_res_pt = val_pts.loc[val_pts['tmsos_id'].isin(selected_reservoirs)]
val_res_poly = val_polys.loc[val_polys['tmsos_id'].isin(selected_reservoirs)]

# Filter the polygon table once and read every attribute from that selection.
reservoir_poly = val_res_poly.loc[val_res_poly['tmsos_id'] == RESERVOIR]
if reservoir_poly.empty:
    raise KeyError(f"Reservoir {RESERVOIR} not found in the validation polygons")

nominal_area = reservoir_poly['AREA_SKM'].iloc[0]
nominal_area_poly = reservoir_poly['AREA_POLY'].iloc[0]
# -99 is treated as "missing" for the min/max areas below.
max_area = reservoir_poly['AREA_MAX'].iloc[0]
max_area = np.nan if max_area == -99 else max_area
min_area = reservoir_poly['AREA_MIN'].iloc[0]
min_area = 0 if min_area == -99 else min_area
area_rep = reservoir_poly['AREA_REP'].iloc[0]
dam_height = float(reservoir_poly['DAM_HGT_M'].iloc[0])
elev_msl = float(reservoir_poly['ELEV_MASL'].iloc[0])
depth = float(reservoir_poly['DEPTH_M'].iloc[0])
capacity = float(reservoir_poly['CAP_MCM'].iloc[0])
db = reservoir_poly['db'].iloc[0]

global_map = (
    val_res_pt.hvplot(
        geo=True, tiles='OSM'
    ) * val_res_pt[val_res_pt['tmsos_id'] == RESERVOIR].hvplot(
        geo=True, color='red', size=100, 
    )
).opts(
    title=f"Locations of validation reservoirs. {RESERVOIR_NAME}, highlighted in red"
)

global_map

In [4]:
# Outline of the selected reservoir drawn over OpenStreetMap tiles
val_res_poly.loc[val_res_poly['tmsos_id'] == RESERVOIR].hvplot(
    geo=True,
    tiles='OSM',
    shared_axes=False,
).opts(title=str(RESERVOIR_NAME))

In [5]:
def clip_ts(*tss, which='left'):
    """Clip multiple time-series to align them temporally.

    The common window is [latest start, earliest end] over all inputs.

    Args:
        *tss (pd.Series | pd.DataFrame): Time-series indexed by comparable time stamps.
        which (str, optional): Defines which direction the clipping will be performed.
                               'left' will clip the time-series only on the left side of the
                               unaligned time-serieses, and leave the right-side untouched, and
                               _vice versa_. Defaults to 'left'. Options can be: 'left', 'right'
                               or 'both'

    Returns:
        list: the clipped time-series, in the same order that they were passed.

    Raises:
        ValueError: If no series (or an empty series) is passed, if the series do not overlap
            in time, or if `which` is not one of the supported options.
    """
    if not tss or any(len(ts.index) == 0 for ts in tss):
        raise ValueError('clip_ts requires at least one non-empty time series.')

    # Index.min()/max() run in C instead of iterating the index in Python.
    mint = max(ts.index.min() for ts in tss)
    maxt = min(ts.index.max() for ts in tss)

    if mint > maxt:
        raise ValueError(f'No overlapping time period between the time series. Minimum T: {mint}, Maximum T: {maxt}')

    if which == 'both':
        clipped_tss = [ts.loc[(ts.index >= mint) & (ts.index <= maxt)] for ts in tss]
    elif which == 'left':
        clipped_tss = [ts.loc[ts.index >= mint] for ts in tss]
    elif which == 'right':
        clipped_tss = [ts.loc[ts.index <= maxt] for ts in tss]
    else:
        raise ValueError(f'Unknown option passed: {which}, expected "left", "right" or "both".')

    return clipped_tss


def weighted_moving_average(data, weights, window_size):
    """Centred weighted moving average whose window shrinks at the series ends.

    Args:
        data (array-like): Values to smooth.
        weights (array-like): Weight for each value; must have the same shape as `data`.
        window_size (int): Odd, positive number of points in the window.

    Returns:
        np.ndarray: Smoothed values with the same shape as `data`. Integer input yields a
        float64 result, so averages are not truncated.

    Raises:
        ValueError: If `window_size` is not an odd positive integer, or if the shapes of
            `data` and `weights` differ.
    """
    if window_size % 2 == 0 or window_size < 1:
        raise ValueError("Window size must be an odd positive integer.")

    data = np.array(data)
    weights = np.array(weights)

    if data.shape != weights.shape:
        raise ValueError("Data and weights must have the same shape.")

    half_window = window_size // 2
    # Keep floating/complex dtypes; promote anything else (e.g. ints) to float64, which
    # `np.zeros_like(data)` did not do, silently truncating the averages.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    smoothed_data = np.zeros(data.shape, dtype=out_dtype)

    n = len(data)
    for i in range(n):
        # Window is clipped to the available points near the edges.
        start = max(0, i - half_window)
        end = min(n, i + half_window + 1)

        window_weights = weights[start:end]
        smoothed_data[i] = np.sum(data[start:end] * window_weights) / np.sum(window_weights)

    return smoothed_data

In [6]:
from pathlib import Path

class TMS:
    """Prepare the satellite time-series used by the TMS-OS methodology for one reservoir.

    Stores the reservoir identity and nominal area, and reads and cleans the optical
    (Landsat-8, Landsat-9, Sentinel-2), Sentinel-1 SAR and SWOT inputs.
    """

    # Columns kept from every optical surface-area table (after Landsat column renaming).
    _OPTICAL_COLUMNS = ['water_area_uncorrected', 'non_water_area', 'cloud_area', 'water_area_corrected']
    # Area columns blanked for observations at or above the cloud threshold.
    _CLOUD_MASKED_COLUMNS = ['water_area_uncorrected', 'non_water_area', 'water_area_corrected']

    def __init__(self, reservoir_name, area=None, AREA_DEVIATION_THRESHOLD_PCNT=5):
        """Store the reservoir and derive the absolute area-deviation threshold.

        Args:
            reservoir_name: Name/identifier of the reservoir (stored as-is).
            area (float): Nominal reservoir surface area in sq. km. Defaults to None only for
                signature compatibility; a numeric value is required.
            AREA_DEVIATION_THRESHOLD_PCNT (float, optional): Threshold as a percentage of `area`,
                used for reservoirs of at least 100 sq. km. Smaller reservoirs override it:
                25% when area < 10 sq. km and 10% when 10 <= area < 100 sq. km. Defaults to 5.

        Raises:
            TypeError: If `area` is None.
        """
        if area is None:
            raise TypeError("TMS requires the nominal reservoir `area` (sq. km).")
        self.reservoir_name = reservoir_name
        self.area = area
        # Check the smallest tier first: testing `< 100` before `< 10` made the 25% tier unreachable.
        if self.area < 10:
            AREA_DEVIATION_THRESHOLD_PCNT = 25
        elif self.area < 100:
            AREA_DEVIATION_THRESHOLD_PCNT = 10

        # Absolute threshold, in the same units as `area`.
        self.AREA_DEVIATION_THRESHOLD = self.area * AREA_DEVIATION_THRESHOLD_PCNT / 100

    @staticmethod
    def _read_landsat(dfpath):
        """Read a Landsat surface-area CSV, renaming its columns to the common optical schema."""
        return pd.read_csv(dfpath, parse_dates=['mosaic_enddate']).rename({
            'mosaic_enddate': 'date',
            'water_area_cordeiro': 'water_area_uncorrected',
            'non_water_area_cordeiro': 'non_water_area',
            'corrected_area_cordeiro': 'water_area_corrected',
        }, axis=1).set_index('date')

    @staticmethod
    def _read_sentinel2(dfpath):
        """Read a Sentinel-2 surface-area CSV (already in the common optical schema)."""
        return pd.read_csv(dfpath, parse_dates=['date']).set_index('date')

    @classmethod
    def _prepare_optical(cls, df, sat, temporal_resolution, cloud_threshold):
        """Cloud-mask, de-duplicate and regrid one optical surface-area time-series.

        Args:
            df (pd.DataFrame): Date-indexed table containing `_OPTICAL_COLUMNS`.
            sat (str): Sensor label written into the `sat` column.
            temporal_resolution (int): Revisit period in days used for the regular date grid.
            cloud_threshold (float): Cloud percentage at/above which areas are discarded.

        Returns:
            pd.DataFrame: Table on a regular `temporal_resolution`-day grid. Missing or cloudy
            corrected areas are linearly interpolated (forward direction).
        """
        df = df[cls._OPTICAL_COLUMNS].copy()
        df['cloud_percent'] = df['cloud_area'] * 100 / (
            df['water_area_uncorrected'] + df['non_water_area'] + df['cloud_area']
        )
        df = df.replace(-1, np.nan)

        # QUALITY_DESCRIPTION
        #   0: Good, not interpolated either due to missing data or high clouds
        #   1: Poor, interpolated either due to high clouds
        #   2: Poor, interpolated either due to missing data
        cloudy = df['cloud_percent'] >= cloud_threshold
        df['QUALITY_DESCRIPTION'] = 0
        df.loc[cloudy, cls._CLOUD_MASKED_COLUMNS] = np.nan
        df.loc[cloudy, 'QUALITY_DESCRIPTION'] = 1

        # in some cases the dataframe may have duplicated rows (with same values) that have to be removed
        if df.index.duplicated().any():
            print("Duplicated labels, deleting")
            df = df[~df.index.duplicated(keep='last')]
        # The regular grid below spans first..last row, so rows must be in date order.
        df = df.sort_index()

        # Fill in the gaps created due to high cloud cover / missing acquisitions with NaN rows
        grid = pd.date_range(df.index[0], df.index[-1], freq=f'{temporal_resolution}D')
        regridded = df.reindex(grid).fillna({
            'QUALITY_DESCRIPTION': 2,
            'cloud_area': df['cloud_area'].max(),  # Series.max skips NaN, unlike builtin max
            'cloud_percent': 100,
            'non_water_area': 0,
            'water_area_uncorrected': 0,
        })

        # Interpolate bad data
        regridded['water_area_corrected'] = regridded['water_area_corrected'].interpolate(
            method="linear", limit_direction="forward"
        )
        regridded['sat'] = sat
        return regridded

    def _prepare_sar(self, s1_dfpath, min_date, temporal_resolution):
        """Read, clean and extend the Sentinel-1 surface-area time-series.

        Args:
            s1_dfpath: Path of the SAR CSV (columns `time` and `sarea`).
            min_date (pd.Timestamp): Observations before this date are dropped.
            temporal_resolution (int): Days after the last observation at which one
                extrapolated point is appended.

        Returns:
            pd.DataFrame | None: Date-indexed SAR areas in column `area`, or None if the file is
            missing or has too few points.
        """
        if not os.path.isfile(s1_dfpath):
            print("Sentinel-1 SAR file does not exist.")
            return None

        sar = pd.read_csv(s1_dfpath, parse_dates=['time']).rename({'time': 'date'}, axis=1)
        if len(sar) < 3:
            print("Sentinel-1 SAR has less than 3 data points.")
            return None

        # Truncate time stamps to calendar days.
        sar['date'] = sar['date'].apply(lambda d: np.datetime64(d.strftime('%Y-%m-%d')))
        sar = sar.set_index('date').sort_index()

        # apply area change filter (15% of the nominal area)
        sar = sar_data_statistical_fix(sar, self.area, 15)

        # Blank statistical outliers (|z-score| > 3), then fill them by interpolation.
        SAR_ZSCORE_LIM = 3
        z = zscore(sar['sarea'])
        sar.loc[(z > SAR_ZSCORE_LIM) | (z < -SAR_ZSCORE_LIM), 'sarea'] = np.nan
        sar['sarea'] = sar['sarea'].interpolate()
        sar = sar.loc[min_date:, :].copy()

        # in some cases sar may have duplicated rows (with same values) that have to be removed
        if sar.index.duplicated().any():
            print("Duplicated labels, deleting")
            sar = sar[~sar.index.duplicated(keep='last')]

        # interp1d needs at least two points to extrapolate.
        if len(sar) < 2:
            print("Sentinel-1 SAR has less than 2 data points after MIN_DATE.")
            return None

        def to_unix_seconds(t):
            return (t - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')

        # Extrapolate one revisit (temporal_resolution days) past the last observation using a
        # linear interpolant over the last 7 points (extrapolation continues the final segment).
        tail = sar['sarea'].iloc[-7:]
        extrapolated_date = sar.index[-1] + pd.Timedelta(days=temporal_resolution)
        extrapolated_value = interp1d(
            to_unix_seconds(tail.index), tail.to_numpy(), kind='linear', fill_value="extrapolate"
        )(to_unix_seconds(extrapolated_date))
        sar.loc[extrapolated_date, "sarea"] = extrapolated_value

        return sar.rename({'sarea': 'area'}, axis=1)

    def tms_os(self,
            l8_dfpath: str = "", 
            s2_dfpath: str = "", 
            l9_dfpath: str = "", 
            s1_dfpath: str = "", 
            swot_dfpath: str = "",
            CLOUD_THRESHOLD: float = 90.0,
            MIN_DATE: str = '2019-01-01'
        ):
        """Implements the TMS-OS methodology (input preparation stage).

        Args:
            l8_dfpath (string): Path of the surface area dataframe obtained using `sarea_cli_l8.py` - Landsat-8 derived surface areas
            s2_dfpath (string): Path of the surface area dataframe obtained using `sarea_cli_s2.py` - Sentinel-2 derived surface areas
            l9_dfpath (string): Path of the Landsat-9 derived surface area dataframe (same format as Landsat-8)
            s1_dfpath (string): Path of the surface area dataframe obtained using `sarea_cli_sar.py` - Sentinel-1 derived surface areas
            swot_dfpath (string): Path of the SWOT CSV; read as-is when the file exists
            CLOUD_THRESHOLD (float): Threshold to use for cloud-masking in % (default: 90.0)
            MIN_DATE (str): Minimum date in YYYY-MM-DD (%Y-%m-%d) format (default: 2019-01-01).
                Currently applied to the Sentinel-1 series only.

        Returns:
            tuple: (swot_df or None, optical). `optical` merges all available optical sensors on
            their regular grids; on duplicate dates the later sensor in l8, l9, s2 order wins.

        Raises:
            ValueError: If none of the optical files exist.
        """
        ## TODO: add conditional, S1 required, any one of optical datasets required
        MIN_DATE = pd.to_datetime(MIN_DATE, format='%Y-%m-%d')
        S2_TEMPORAL_RESOLUTION = 5
        S1_TEMPORAL_RESOLUTION = 12
        L8_TEMPORAL_RESOLUTION = 16
        L9_TEMPORAL_RESOLUTION = 16

        # One shared pipeline for every optical sensor (the Landsat-9 copy previously
        # interpolated the Landsat-8 dataframe by mistake).
        optical_sources = (
            ('l8', l8_dfpath, L8_TEMPORAL_RESOLUTION, self._read_landsat),
            ('l9', l9_dfpath, L9_TEMPORAL_RESOLUTION, self._read_landsat),
            ('s2', s2_dfpath, S2_TEMPORAL_RESOLUTION, self._read_sentinel2),
        )
        TO_MERGE = [
            self._prepare_optical(reader(dfpath), sat, resolution, CLOUD_THRESHOLD)
            for sat, dfpath, resolution, reader in optical_sources
            if os.path.isfile(dfpath)
        ]

        # Not returned yet; consumed by the commented-out trend correction below.
        sar = self._prepare_sar(s1_dfpath, MIN_DATE, S1_TEMPORAL_RESOLUTION)

        # isfile (not Path.exists) so the default "" (== current directory) is not read as a CSV.
        if os.path.isfile(swot_dfpath):
            swot_df = pd.read_csv(swot_dfpath)
        else:
            swot_df = None
            print("SWOT file does not exist.")

        if not TO_MERGE:
            raise ValueError("At least one optical (Landsat-8, Landsat-9 or Sentinel-2) surface-area file is required.")

        # combine opticals into one dataframe; a stable sort keeps the concat order (l8, l9, s2)
        # among equal dates so keep='last' really keeps s2 when both s2 and l8 are present.
        optical = pd.concat(TO_MERGE).sort_index(kind='mergesort')
        optical = optical.loc[~optical.index.duplicated(keep='last')]
        optical = optical.rename({'water_area_corrected': 'area'}, axis=1)

        return swot_df, optical

        # # Apply the trend based corrections
        # if(sar is not None):
        #     # If Optical begins before SAR and has a difference of more than 15 days
        #     if(sar.index[0]-optical.index[0]>pd.Timedelta(days=15)):
        #         # Optical without SAR
        #         optical_with_no_sar = optical[optical.index[0]:sar.index[0]].copy()
        #         optical_with_no_sar['non-smoothened optical area'] = optical_with_no_sar['area']
        #         optical_with_no_sar.loc[:, 'days_passed'] = optical.index.to_series().diff().dt.days.fillna(0)
        #         # Calculate smoothed values with moving weighted average method if more than 7 values; weights are calculated using cloud percent.
        #         if len(optical_with_no_sar)>7:
        #             optical_with_no_sar['filled_area'] = weighted_moving_average(optical_with_no_sar['non-smoothened optical area'], weights = (101-optical_with_no_sar['cloud_percent']),window_size=3)
        #         # Drop 'area' column from optical_with_no_sar
        #         optical_with_no_sar = optical_with_no_sar.drop('area',axis=1)
        #         # Optical with SAR
        #         optical_with_sar = trend_based_correction(optical.copy(), sar.copy(), self.AREA_DEVIATION_THRESHOLD)
        #         # Merge both
        #         result = pd.concat([optical_with_no_sar,optical_with_sar],axis=0)
        #         # Smoothen the combined surface area estimates to avoid noise or peaks using savgol_filter if more than 9 values (to increase smoothness and include more points as we have both TMS-OS and Optical)
        #         if len(result)>9:    
        #             result['filled_area'] = savgol_filter(result['filled_area'], window_length=7, polyorder=3)
        #         method = 'Combine'
        #     # If SAR begins before Optical
        #     else:
        #         result = trend_based_correction(optical.copy(), sar.copy(), self.AREA_DEVIATION_THRESHOLD)
        #         method = 'TMS-OS'
        # else:
        #     result = optical.copy()
        #     result['non-smoothened optical area'] = result['area']
        #     result.loc[:, 'days_passed'] = optical.index.to_series().diff().dt.days.fillna(0)
        #     # Calculate smoothed values with Savitzky-Golay method if more than 7 values
        #     if len(result)>7:
        #         result['filled_area'] = weighted_moving_average(result['non-smoothened optical area'], weights = (101-result['cloud_percent']),window_size=3)
        #         result['filled_area'] = savgol_filter(result['filled_area'], window_length=7, polyorder=3)
        #     method = 'Optical'
        # # Returning method used for surface area estimation
        # return result,method

def area_change(df, date, n=14):
    """Calculate the change in `sarea` over the last `n` days ending at `date`.

    Args:
        df (pd.DataFrame): Date-indexed (sorted) table with an `sarea` column.
        date (pd.Timestamp): End date; expected to be a label in `df.index`.
        n (int, optional): Look-back window in days. Defaults to 14.

    Returns:
        float: Area at `date` minus the first area within [date - n days, date], or NaN when
        the window starts before the first date in `df`.
    """
    start = date - pd.Timedelta(days=n)

    # if start date is before the first date in the df, return nan
    if start < df.index[0]:
        return np.nan

    start_area = df.loc[start:date, "sarea"].iloc[0]
    end_area = df.loc[date, "sarea"]
    # Duplicate dates (seen in SAR area dataframes) make .loc return a Series; take the first.
    # TODO: fix the cause of the duplicates upstream (same area is returned twice).
    if isinstance(end_area, pd.Series):
        end_area = end_area.iloc[0]

    return end_area - start_area

def sar_data_statistical_fix(sar_df, nominal_area, threshold_percentage=15):
    """Blank and re-interpolate SAR areas that change too quickly.

    For every date, `area_change` gives the change in `sarea` over its default look-back
    window. Points whose change exceeds `threshold_percentage` percent of `nominal_area`
    (in either direction) are set to NaN, then refilled by time-weighted interpolation.

    Args:
        sar_df (pd.DataFrame): Date-indexed SAR table with an `sarea` column.
        nominal_area (float): Reference area used to scale the threshold.
        threshold_percentage (float, optional): Allowed change as % of `nominal_area`. Defaults to 15.

    Returns:
        pd.DataFrame: A cleaned copy of `sar_df` with the same columns.
    """
    max_jump = (threshold_percentage / 100) * nominal_area

    cleaned = sar_df.copy()
    cleaned['date'] = cleaned.index.to_series()
    cleaned['area_change'] = cleaned['date'].apply(lambda day: area_change(cleaned, day))

    is_jump = (cleaned['area_change'] < -max_jump) | (cleaned['area_change'] > max_jump)
    cleaned.loc[is_jump, 'sarea'] = np.nan
    cleaned['sarea'] = cleaned['sarea'].interpolate(method='time')

    # Remove the helper columns again.
    return cleaned.drop(columns=['area_change', 'date'])

def deviation_from_sar(optical_areas, sar_areas, DEVIATION_THRESHOLD = 20, LOW_STD_LIM=2, HIGH_STD_LIM=2):
    """Filter out points based on deviations from SAR reported areas after correcting for bias in SAR water areas.

    SAR areas are linearly interpolated to the optical dates. The bias is the median of the
    sigma-clipped optical-minus-SAR deviations. Optical points whose bias-corrected deviation
    exceeds `DEVIATION_THRESHOLD` are set to NaN. Optical dates outside the SAR time span have
    no interpolated SAR value and are never flagged.

    Args:
        optical_areas (pd.Series): Time-series of areas obtained using an optical sensor (S2, L8, etc) on which the filtering will be applied. Must have a `pd.DatetimeIndex`; the series name is not used.
        sar_areas (pd.Series): Time-series of S1 surface areas with a `pd.DatetimeIndex`; needs at least two points.
        DEVIATION_THRESHOLD (number): (Default: 20 [sq. km.]) Threshold of deviation from bias corrected SAR reported areas.
        LOW_STD_LIM (number): (Default: 2) Lower limit of standard deviations to use for clipping the deviations, required for calculating the bias.
        HIGH_STD_LIM (number): (Default: 2) Upper limit of standard deviations to use for clipping the deviations, required for calculating the bias.

    Returns:
        pd.Series: The optical areas, named `area`, with flagged points set to NaN.
    """
    def to_epoch_seconds(index):
        # Cast through datetime64[s] so the result is in seconds whatever the index's
        # storage resolution (ns, us, ms or s).
        return index.values.astype('datetime64[s]').astype(np.int64)

    # convert to dataframes under the hood; force the column name so any series name works
    optical = optical_areas.to_frame(name='area')
    sar = sar_areas.to_frame(name='area')

    sar_area_func = interp1d(to_epoch_seconds(sar.index), sar['area'], bounds_error=False)

    # Interpolate and calculate the sar reported areas according to the optical sensor's observation dates
    sar_sarea_interpolated = sar_area_func(to_epoch_seconds(optical.index))
    deviations = optical['area'] - sar_sarea_interpolated

    clipped = sigmaclip(deviations.dropna(), low=LOW_STD_LIM, high=HIGH_STD_LIM)
    bias = np.median(clipped.clipped)

    optical['normalized_dev'] = deviations - bias
    optical['flagged'] = np.abs(optical['normalized_dev']) > DEVIATION_THRESHOLD
    optical.loc[optical['flagged'], 'area'] = np.nan

    return optical['area']

# helper functions
def sar_trend(d1, d2, sar):
    """Average daily change of the SAR area between two dates.

    The SAR `area` column is resampled to daily values with linear interpolation and then
    restricted to [d1, d2]. The difference between the last and first values is divided by
    the number of days from d1 to d2.

    Returns:
        float: Area change per day, or NaN when no daily SAR value falls in [d1, d2].
    """
    daily = sar['area'].resample('1D').interpolate('linear').loc[d1:d2]
    if daily.empty:
        return np.nan
    span_days = (np.datetime64(d2) - np.datetime64(d1)) / np.timedelta64(1, 'D')
    return (daily.iloc[-1] - daily.iloc[0]) / span_days

def backcalculate(areas, trends, who_needs_correcting):
    """Rebuild an area series by replacing flagged points with previous value + trend.

    All three inputs are used positionally and must have the same length.

    Args:
        areas (pd.Series): Area values, one per observation.
        trends (pd.Series): Trend for each observation.
        who_needs_correcting (pd.Series): Boolean flags; True marks a point to be replaced.

    Returns:
        list: One value per observation.
            - Points before the first unflagged (reliable) point are NaN.
            - The first reliable point keeps its own area.
            - After it, unflagged points keep their area and flagged points become the previous
              (corrected) value plus their trend.
            - If every point is flagged, the first point is used as the anchor.
            - Empty input returns [].
    """
    flags = np.asarray(who_needs_correcting, dtype=bool)
    if flags.size == 0:
        return []
    values = np.asarray(areas, dtype=float)
    slopes = np.asarray(trends, dtype=float)

    # Position of the first reliable point (argmin of booleans = first False; 0 if none).
    # Positional, so the result no longer depends on the index type, and the anchor keeps
    # its own area (previously areas.iloc[k + 1] was used: an off-by-one).
    anchor = int(np.argmin(flags))
    corrected_areas = [np.nan] * anchor
    corrected_areas.append(values[anchor])

    # NOTE(review): trends are per-day elsewhere in this file (diff / days_passed), but here a
    # single trend value is added per observation, not multiplied by the days in between —
    # confirm this is intended.
    for value, flagged, slope in zip(values[anchor + 1:], flags[anchor + 1:], slopes[anchor + 1:]):
        corrected_areas.append(corrected_areas[-1] + slope if flagged else value)

    return corrected_areas

def deviation_correction(area, DEVIATION_THRESHOLD, AREA_COL_NAME='area'):
    """Replace optical trends that disagree with the SAR trend and rebuild the areas.

    A row is `erroneous` when |trend - sar_trend| exceeds `DEVIATION_THRESHOLD`. Erroneous
    rows take the SAR trend as `corrected_trend`, and `AREA_COL_NAME` is recomputed with
    `backcalculate`.

    Args:
        area (pd.DataFrame): Table with `trend`, `sar_trend` and `AREA_COL_NAME` columns.
        DEVIATION_THRESHOLD (float): Maximum allowed absolute trend difference.
        AREA_COL_NAME (str, optional): Column holding the areas to rebuild. Defaults to 'area'.

    Returns:
        pd.DataFrame: A copy of `area` with `deviation`, `erroneous` and `corrected_trend`
        added and `AREA_COL_NAME` rebuilt.
    """
    result = area.copy()

    result.loc[:, 'deviation'] = (result['trend'] - result['sar_trend']).abs()
    result.loc[:, 'erroneous'] = result['deviation'] > DEVIATION_THRESHOLD

    result.loc[:, 'corrected_trend'] = result['trend']
    result.loc[result['erroneous'], 'corrected_trend'] = result['sar_trend']

    if result['erroneous'].empty:
        return result

    result[AREA_COL_NAME] = backcalculate(
        result[AREA_COL_NAME], result['corrected_trend'], result['erroneous']
    )
    return result

def sign_based_correction(area, AREA_COL_NAME='corrected_areas_1', TREND_COL_NAME='corrected_trend_1'):
    """Use the SAR trend wherever the optical and SAR trends have opposite signs.

    Args:
        area (pd.DataFrame): Table with `trend`, `sar_trend`, `AREA_COL_NAME` and
            `TREND_COL_NAME` columns.
        AREA_COL_NAME (str, optional): Areas to rebuild. Defaults to 'corrected_areas_1'.
        TREND_COL_NAME (str, optional): Trend used where signs agree. Defaults to 'corrected_trend_1'.

    Returns:
        pd.DataFrame: A copy of `area` with `sign_based_correction_reqd`, `corrected_trend` and
        the rebuilt `area` column.
    """
    result = area.copy()

    optical_trend = result['trend']
    radar_trend = result['sar_trend']
    optical_falls_radar_rises = (optical_trend < 0) & (radar_trend > 0)
    optical_rises_radar_falls = (optical_trend > 0) & (radar_trend < 0)
    result['sign_based_correction_reqd'] = optical_falls_radar_rises | optical_rises_radar_falls

    result.loc[:, 'corrected_trend'] = result[TREND_COL_NAME]
    result.loc[result['sign_based_correction_reqd'], 'corrected_trend'] = result['sar_trend']

    result['area'] = backcalculate(
        result[AREA_COL_NAME], result['corrected_trend'], result['sign_based_correction_reqd']
    )
    return result

def filled_by_trend(filtered_area, sar_trend, days_passed) -> pd.Series:
    """Fills in `np.nan` values of optically obtained surface area time series using SAR based time-series.

    Each NaN becomes the previous (possibly filled) value plus `sar_trend * days_passed` at that
    position. The first value is kept as-is, so a leading NaN stays NaN and propagates forward.

    Args:
        filtered_area (pd.Series): Optical sensor based surface areas containing `np.nan` values that will be filled in.
        sar_trend (pd.Series): SAR based surface area trends (aligned positionally).
        days_passed (pd.Series): Days since last observation of optical sensor observed surface areas (aligned positionally).

    Returns:
        pd.Series: Filled series named `filled_area` with the index of `filtered_area`
        (empty input gives an empty series).
    """
    filled = np.asarray(filtered_area, dtype=float).copy()
    slopes = np.asarray(sar_trend, dtype=float)
    gaps = np.asarray(days_passed, dtype=float)

    # Sequential: each fill depends on the value just before it.
    for i in range(1, len(filled)):
        if np.isnan(filled[i]):
            filled[i] = filled[i - 1] + slopes[i] * gaps[i]

    return pd.Series(filled, dtype=float, name='filled_area', index=filtered_area.index)

# Trend based correction function
def trend_based_correction(area, sar, AREA_DEVIATION_THRESHOLD=25, TREND_DEVIATION_THRESHOLD = 10):
    """Apply trend based correction on a time-series.

    Steps:
      1. Blank optical areas that deviate from bias-corrected SAR areas
         (`deviation_from_sar`).
      2. Compute the optical trend between retained points. Where it differs from the SAR
         trend by more than `TREND_DEVIATION_THRESHOLD`, substitute the SAR trend and rebuild
         the areas (`deviation_correction`).
      3. Fill the remaining gaps using SAR trends (`filled_by_trend`).

    The input `area` is not modified.

    Args:
        area (pd.DataFrame): Pandas dataframe containing date as pd.DatetimeIndex and areas in column named `area`
        sar (pd.DataFrame): Pandas dataframe containing surface area time-series obtained from Sentinel-1 (SAR). Same format as `area`
        AREA_DEVIATION_THRESHOLD (number): (Default: 25) Threshold value of deviation of optically derived areas from SAR derived areas for filtering.
        TREND_DEVIATION_THRESHOLD (number): (Default: 10) Threshold value of deviation in trend above which the observation is marked as erroneous and the correction step is applied

    Returns:
        pd.DataFrame: A copy of `area` with these changes:
            - `area` is renamed to `unfiltered_area`.
            - Added columns: `filtered_area`, `corrected_areas_1`, `corrected_trend_1`,
              `sar_trend`, `days_passed` and `filled_area`.
            - Rows before the SAR series starts are dropped.
            - Rows before the first corrected value are dropped.
    """
    # Work on a copy so the caller's dataframe is not renamed or extended in place.
    area = area.copy()
    area['filtered_area'] = deviation_from_sar(area['area'], sar['area'], AREA_DEVIATION_THRESHOLD)
    area = area.rename({'area': 'unfiltered_area'}, axis=1)

    # Trends (area change per day) between consecutive retained optical points.
    area_filtered = area.dropna(subset=['filtered_area']).copy()
    area_filtered['days_passed'] = area_filtered.index.to_series().diff().dt.days
    area_filtered['trend'] = area_filtered['filtered_area'].diff() / area_filtered['days_passed']

    sar, area_filtered = clip_ts(sar, area_filtered, which='left')
    area_filtered = area_filtered.copy()
    # sometimes the sar time-series has duplicate values which have to be removed
    sar = sar[~sar.index.duplicated(keep='first')]   # https://stackoverflow.com/a/34297689/4091712

    def trend_generator(window):
        # SAR trend between the first and last dates of a (2-point) rolling window.
        return sar_trend(window.index[0], window.index[-1], sar)

    area_filtered['sar_trend'] = area_filtered['filtered_area'].rolling(2, min_periods=0).apply(trend_generator)

    deviation_correction_results = deviation_correction(area_filtered, TREND_DEVIATION_THRESHOLD, AREA_COL_NAME='filtered_area')
    area_filtered['corrected_areas_1'] = deviation_correction_results['filtered_area']
    area_filtered['corrected_trend_1'] = deviation_correction_results['corrected_trend']

    area['corrected_areas_1'] = area_filtered['corrected_areas_1']
    area['corrected_trend_1'] = area_filtered['corrected_trend_1']

    area['sar_trend'] = area['unfiltered_area'].rolling(2, min_periods=0).apply(trend_generator)
    area['days_passed'] = area.index.to_series().diff().dt.days

    area, sar = clip_ts(area, sar, which="left")
    first_non_nan = area['corrected_areas_1'].first_valid_index()
    area = area.loc[first_non_nan:, :].copy()

    # fill na based on trends
    area['filled_area'] = filled_by_trend(area['corrected_areas_1'], area['sar_trend'], area['days_passed'])

    return area


# NOTE(review): the nominal area is hard-coded to 40 here, while `nominal_area` is read from the
# validation polygons earlier in the notebook and left unused — confirm which value is intended.
tmsos = TMS(RESERVOIR, area=40, AREA_DEVIATION_THRESHOLD_PCNT=5)

# Per-sensor input directories
l8_dir, l9_dir, s2_dir, sar_dir = (
    Path('/tiger1/pdas47/tmsosPP/data/tmsos') / sensor for sensor in ('l8', 'l9', 's2', 'sar')
)
swot_dir = Path('/tiger1/pdas47/tmsosPP/data/storage/swot_karin_poseidon/v0.1')

# Input files for the selected reservoir
l8_fp, l9_fp, s2_fp = (folder / f"{RESERVOIR}.csv" for folder in (l8_dir, l9_dir, s2_dir))
sar_fp = sar_dir / f"{RESERVOIR}_12d_sar.csv"
swot_fp = swot_dir / f"{RESERVOIR}_{RESERVOIR_NAME.split(',')[0].replace(' ', '_')}_storage.csv"

swot_df, optical_data = tmsos.tms_os(
    l8_dfpath=l8_fp,
    s2_dfpath=s2_fp,
    l9_dfpath=l9_fp,
    s1_dfpath=sar_fp,
    swot_dfpath=swot_fp,
)

Duplicated labels, deleting


In [7]:
# Area/elevation lookup table for the reservoir (columns such as CumArea and Elevation are used
# below); text after '#' in the CSV is ignored.
aev_dir = Path("/tiger1/pdas47/tmsosPP/data") / "aec" / "aev"
aev = pd.read_csv(aev_dir / f"{RESERVOIR}.csv", comment='#')
aev

Unnamed: 0,CumArea,Elevation,Storage,Storage (mil. m3),Elevation_Observed
0,0.000,89.739359,0.000000e+00,0.000000,
1,0.250,89.739592,2.913741e+01,0.000029,
2,0.500,89.740292,1.748244e+02,0.000175,
3,0.750,89.741457,5.536104e+02,0.000554,
4,1.000,89.743089,1.282045e+03,0.001282,
...,...,...,...,...,...
337,78.000,112.430142,5.899633e+08,589.963260,
338,78.018,112.440616,5.903718e+08,590.371788,117.0
339,78.204,112.548987,5.946043e+08,594.604300,118.0
340,78.250,112.575829,5.956542e+08,595.654161,


In [8]:
# Convert each optical surface area to an elevation by linear interpolation in the AEV table.
# np.interp clamps areas outside the table to the end elevations and assumes CumArea is
# increasing (it does not check) — TODO confirm the table is sorted by CumArea.
optical_data['elevation'] = np.interp(
    x=optical_data['area'],
    xp=aev['CumArea'],
    fp=aev['Elevation'],
)
optical_data

Unnamed: 0,water_area_uncorrected,non_water_area,cloud_area,area,cloud_percent,QUALITY_DESCRIPTION,sat,elevation
2019-01-02,29.447958,40.264334,0.165392,29.447958,0.236687,0.0,s2,92.973626
2019-01-07,27.014625,42.697668,0.165392,27.014625,0.236687,0.0,s2,92.461184
2019-01-12,24.664847,45.047445,0.165392,24.664847,0.236687,0.0,s2,92.008321
2019-01-17,26.511816,43.200477,0.165392,26.511816,0.236687,0.0,s2,92.360805
2019-01-22,24.068360,45.643932,0.165392,24.068360,0.236687,0.0,s2,91.899901
...,...,...,...,...,...,...,...,...
2024-08-08,0.000000,0.000000,62.624415,40.056309,100.000000,0.0,s2,95.723545
2024-08-13,8.000908,14.995229,46.881547,38.829317,67.090871,0.0,s2,95.362561
2024-08-18,0.000000,0.000000,62.624415,38.829317,100.000000,0.0,s2,95.362561
2024-08-23,0.000000,0.000000,62.624415,38.829317,100.000000,0.0,s2,95.362561


In [9]:
def _epoch_seconds(index):
    """Return integer seconds since the Unix epoch for each timestamp of a DatetimeIndex.

    Independent of the index's datetime64 resolution (ns/us/ms/s). By contrast,
    ``index.view(np.int64) // 10**9`` only yields seconds for nanosecond indexes.
    """
    epoch = pd.Timestamp(0, tz='UTC') if index.tz is not None else pd.Timestamp(0)
    return np.asarray((index - epoch) // pd.Timedelta(seconds=1))


def deviation_from_swot(optical_elevations, swot_elevations, DEVIATION_THRESHOLD = 1, LOW_STD_LIM=2, HIGH_STD_LIM=2):
    """Filter out points based on deviations from bias-corrected SWOT elevations. Remove NaNs beforehand.

    The SWOT series is linearly interpolated in time onto the optical observation
    dates. The bias between the two series is the median of the sigma-clipped
    deviations. Optical points whose bias-corrected deviation exceeds
    ``DEVIATION_THRESHOLD`` in absolute value are flagged, and their elevation is set to NaN.
    Optical dates outside the SWOT time span get a NaN deviation and are never flagged.

    Args:
        optical_elevations (pd.Series): Time-series of elevations from an optical sensor (S2, L8, etc.)
            on which the filtering is applied. Must have a `pd.DatetimeIndex`.
        swot_elevations (pd.Series): Time-series of SWOT elevations. Must have a `pd.DatetimeIndex`.
        DEVIATION_THRESHOLD (number): (Default: 1 [m]) Threshold of deviation from bias-corrected SWOT elevations.
        LOW_STD_LIM (number): (Default: 2) Lower limit, in standard deviations, for clipping the deviations
            before computing the bias.
        HIGH_STD_LIM (number): (Default: 2) Upper limit, in standard deviations, for clipping the deviations
            before computing the bias.

    Returns:
        tuple: ``(filtered, original, bias)``
            filtered (pd.DataFrame): optical data as returned by ``clip_ts``, with columns `elevation`
                (NaN where flagged), `normalized_dev` (deviation minus bias) and `flagged` (bool).
            original (pd.DataFrame): copy of the ``clip_ts`` output before any filtering.
            bias (float): median of the sigma-clipped deviations; NaN when no optical date falls
                inside the SWOT time span.
    """
    # Convert to dataframes under the hood. Name the column explicitly so the
    # result does not depend on the incoming Series' name.
    optical_elevations = optical_elevations.to_frame(name='elevation')
    swot_elevations = swot_elevations.to_frame(name='elevation')

    optical_elevations, swot_elevations = clip_ts(optical_elevations, swot_elevations)
    temp_original = optical_elevations.copy()

    # Linear SWOT elevation vs. time (seconds from epoch); NaN outside the SWOT time span.
    swot_elevation_func = interp1d(
        _epoch_seconds(swot_elevations.index), swot_elevations['elevation'], bounds_error=False
    )

    # SWOT elevation evaluated at each optical observation date.
    swot_elevation_interpolated = swot_elevation_func(_epoch_seconds(optical_elevations.index))
    deviations = optical_elevations['elevation'] - swot_elevation_interpolated

    # Robust bias estimate: sigma-clip outlying deviations, then take the median.
    valid_deviations = deviations.dropna()
    if valid_deviations.empty:
        bias = np.nan  # nothing to compare against; avoids empty-array warnings
    else:
        clipped = sigmaclip(valid_deviations, low=LOW_STD_LIM, high=HIGH_STD_LIM)
        bias = np.nanmedian(clipped.clipped)

    optical_elevations['normalized_dev'] = deviations - bias
    # NaN deviations compare False, so points without a SWOT counterpart stay unflagged.
    optical_elevations['flagged'] = optical_elevations['normalized_dev'].abs() > DEVIATION_THRESHOLD
    optical_elevations.loc[optical_elevations['flagged'], 'elevation'] = np.nan

    return optical_elevations, temp_original, bias

# Single source of truth for the threshold used in the filter AND shown in the plot title.
DEVIATION_THRESHOLD = 1  # [m] max |bias-corrected deviation| from SWOT

optical_elevations = optical_data['elevation']
swot_elevations = swot_df[['date', 'elevation']].set_index('date')['elevation']
swot_elevations.index = pd.to_datetime(swot_elevations.index)

deviations, original, bias = deviation_from_swot(
    optical_elevations, swot_elevations, DEVIATION_THRESHOLD=DEVIATION_THRESHOLD
)

deviations_hv = deviations.hvplot(y='elevation', kind='scatter', label='Retained Elevations')
original_hv = original.hvplot(y='elevation', kind='scatter', label='Unfiltered non-swot').opts(size=2, color='k')
swot_hv = swot_elevations.hvplot(y='elevation', kind='scatter', label='SWOT Elevations')

# Percentage of (clipped) optical observations that survived filtering; NaN if there were none.
n_total = len(original['elevation'])
n_retained = deviations['elevation'].count()
retained_pct = n_retained * 100 / n_total if n_total else float('nan')

first_filtering_plot = (deviations_hv * swot_hv * original_hv).opts(
    title=f"{RESERVOIR}: {RESERVOIR_NAME}. Bias = {bias:.3f} m\nDeviation threshold: {DEVIATION_THRESHOLD} m. Retained {retained_pct:.1f}%"
)
first_filtering_plot

In [None]:
def trend_based_correction_elevation(elevation, swot, ELEVATION_DEVIATION_THRESHOLD=25, TREND_DEVIATION_THRESHOLD = 10):
    """Apply trend based correction on an optical elevation time-series using SWOT elevations.

    Args:
        elevation (pd.DataFrame): Dataframe with a pd.DatetimeIndex and elevations in a column named `elevation`.
            It is not modified; a corrected copy is returned.
        swot (pd.DataFrame): Dataframe with the SWOT surface elevation time-series, in the same format as `elevation`.
        ELEVATION_DEVIATION_THRESHOLD (number): (Default: 25) Threshold [m] of deviation of optically derived
            elevations from bias-corrected SWOT elevations used for the initial filtering.
        TREND_DEVIATION_THRESHOLD (number): (Default: 10) Threshold of deviation in trend above which an
            observation is marked as erroneous and the correction step is applied.

    Returns:
        pd.DataFrame: copy of `elevation` with `elevation` renamed to `unfiltered_elevation` and the added columns
        `filtered_elevation`, `corrected_elevations_1`, `corrected_trend_1`, `sar_trend`, `days_passed` and
        `filled_elevation`. Rows before the first valid `corrected_elevations_1` are dropped.
    """
    # Work on a copy: the steps below add and rename columns, which previously
    # modified the caller's DataFrame (e.g. removed its `elevation` column).
    elevation = elevation.copy()

    # deviation_from_swot returns (filtered_df, original_df, bias); only the filtered
    # elevations are needed. Assignment aligns on the date index, so dates absent
    # from the filtered frame become NaN.
    filtered, _, _ = deviation_from_swot(elevation['elevation'], swot['elevation'], ELEVATION_DEVIATION_THRESHOLD)
    elevation['filtered_elevation'] = filtered['elevation']
    elevation = elevation.rename(columns={'elevation': 'unfiltered_elevation'})

    # Explicit copy so that adding columns does not write into a view of `elevation`.
    elevation_filtered = elevation.dropna(subset=['filtered_elevation']).copy()

    # Gap (days) and elevation change per day between consecutive retained observations.
    elevation_filtered['days_passed'] = elevation_filtered.index.to_series().diff().dt.days
    elevation_filtered['trend'] = elevation_filtered['filtered_elevation'].diff() / elevation_filtered['days_passed']

    swot, elevation_filtered = clip_ts(swot, elevation_filtered, which='left')
    swot = swot[~swot.index.duplicated(keep='first')]

    # Bind the deduplicated SWOT frame explicitly instead of relying on a late-bound closure.
    def trend_generator(window, _swot=swot):
        # rolling(...).apply(raw=False) passes a Series, so the window's dates are available.
        return sar_trend(window.index[0], window.index[-1], _swot)

    # Column name kept for downstream compatibility; the reference trend comes from `swot` via sar_trend().
    elevation_filtered['sar_trend'] = elevation_filtered['filtered_elevation'].rolling(2, min_periods=0).apply(trend_generator, raw=False)

    deviation_correction_results = deviation_correction(elevation_filtered, TREND_DEVIATION_THRESHOLD, AREA_COL_NAME='filtered_elevation')
    elevation_filtered['corrected_elevations_1'] = deviation_correction_results['filtered_elevation']
    elevation_filtered['corrected_trend_1'] = deviation_correction_results['corrected_trend']

    # Bring the corrected values back onto the full (unfiltered) date index.
    elevation['corrected_elevations_1'] = elevation_filtered['corrected_elevations_1']
    elevation['corrected_trend_1'] = elevation_filtered['corrected_trend_1']

    elevation['sar_trend'] = elevation['unfiltered_elevation'].rolling(2, min_periods=0).apply(trend_generator, raw=False)
    elevation['days_passed'] = elevation.index.to_series().diff().dt.days

    elevation, _ = clip_ts(elevation, swot, which="left")
    first_non_nan = elevation['corrected_elevations_1'].first_valid_index()
    # .copy() so the column assignment below does not target a slice.
    elevation = elevation.loc[first_non_nan:, :].copy()

    # fill na based on trends
    elevation['filled_elevation'] = filled_by_trend(elevation['corrected_elevations_1'], elevation['sar_trend'], elevation['days_passed'])

    return elevation

In [26]:
# Restrict the optical record to the analysis window, then compare against SWOT.
start_date, end_date = '2023-07-21', '2024-10-14'
optical_data = optical_data.loc[start_date:end_date]

optical_plot = optical_data.hvplot.scatter(y='elevation')
swot_plot = swot_df.hvplot.scatter(x='date', y='elevation')
optical_plot + swot_plot

KeyError: 'elevation'