In [None]:
import sys
import geemap
import ee
from pathlib import Path
import pandas as pd
import geopandas as gpd
import numpy as np
import os
from pathlib import Path
import time
from random import randint
import json

In [None]:
# Project root, expressed relative to the working directory of this notebook
# (four directory levels up).
proj_dir = Path("..", "..", "..", "..")

In [None]:
# Put the project's shared ``utils`` directory first on the import path so
# that the ``sql`` package it contains can be imported.
utils = os.fspath(proj_dir / "utils")
sys.path.insert(0, utils)
from sql import connect  # helpers for connecting to the MySQL database

In [None]:
# Input layers: buffered river reaches and in-situ temperature gauges.
river_shp = proj_dir / "Data/GIS/shapefiles/flowlines_to_reaches/bufferedReaches.shp"
temperature_gauges_shp = proj_dir / "Data/GIS/shapefiles/temperature_gauges.geojson"

# Working directory for the Landsat extraction; the "reaches" sub-folder holds
# the per-river shapefile and the checkpoint file written further down.
data_dir = proj_dir / "Data/LandsatTemperature"
(data_dir / "reaches").mkdir(parents=True, exist_ok=True)

In [None]:
# Open the MySQL connection described by the project's config file.
# ``connection`` is the raw handle handed to the database helpers.
mysql_config = proj_dir / ".env/mysql_config.ini"
conn = connect.Connect(str(mysql_config))
connection = conn.conn

In [None]:
# Load the buffered reaches and reproject them to EPSG:4326.
gdf = gpd.read_file(river_shp).to_crs(epsg=4326)

In [None]:
# Optional interactive map, kept disabled for batch runs:
# Map = geemap.Map()
# Map
# ee.Authenticate()  # presumably only needed when no Earth Engine credentials are stored yet -- TODO confirm
# Initialise the Earth Engine client; every ee.* call below requires this.
ee.Initialize()

In [None]:
# reservoirs = geemap.shp_to_ee(data_dir/'rivers'/'rivers.shp')
# Map.addLayer(reservoirs, {}, "Reservoirs")
# Map.centerObject(reservoirs, 5)

In [None]:
# Extraction window for each sensor. Both sensors currently share the same
# window; full-record runs previously started at '2021-10-01' (Landsat 9) and
# '2013-03-01' (Landsat 8).
L9startDate = L8startDate = "2023-09-01"
L9endDate = L8endDate = "2023-10-31"

# NDWI value passed through to the extraction functions as the water threshold
ndwi_threshold = 0.2

In [None]:
def divideDates(startDate, endDate):
    """
    Divide the timeframe into consecutive chunks, one per calendar year

    Every chunk except the last one ends on January 1st of the following year
    instead of December 31st. The chunks are used as arguments of
    ``ee.ImageCollection.filterDate``, whose end date is exclusive, so a chunk
    ending on December 31st would silently drop the images acquired that day.

    Parameters:
    -----------
    startDate: str
        start date (anything ``pandas.to_datetime`` can parse, e.g. "2021-10-01");
        returned unchanged as the start of the first chunk
    endDate: str
        end date; returned unchanged as the end of the last chunk

    Returns:
    --------
    list
        list of (start, end) tuples of date strings, one per calendar year;
        empty when endDate falls in an earlier year than startDate
    """

    startYear = pd.to_datetime(startDate).year
    endYear = pd.to_datetime(endDate).year

    dates = []
    for year in range(startYear, endYear + 1):
        # Only the first/last chunk use the caller's bounds; inner boundaries
        # fall on January 1st so that consecutive chunks leave no gap.
        chunkStart = startDate if year == startYear else f"{year}-01-01"
        chunkEnd = endDate if year == endYear else f"{year + 1}-01-01"
        dates.append((chunkStart, chunkEnd))

    return dates


def prepL8(image):
    """
    Scale and mask a Landsat Level-2 image

    The surface reflectance bands (SR_B*) and the surface temperature band
    (ST_B10) are rescaled with the multiplicative / additive factors stored in
    the image properties, and unwanted or saturated pixels are masked out.

    Parameters:
    -----------
    image: ee.Image
        Landsat image carrying the QA_PIXEL and QA_RADSAT bands

    Returns:
    --------
    ee.Image
        the input image with its SR_B* / ST_B10 bands replaced by the scaled
        values and both masks applied
    """

    # A pixel is kept only when none of the six lowest QA_PIXEL bits is set
    # (fill, dilated cloud, cirrus, cloud, cloud shadow, snow).
    clearPixels = image.select("QA_PIXEL").bitwiseAnd(0b111111).eq(0)
    # ... and when QA_RADSAT reports no saturated band.
    unsaturated = image.select("QA_RADSAT").eq(0)

    def constantFromProperties(pattern):
        # Constant image built from the image properties matching `pattern`.
        matching = image.toDictionary().select([pattern]).values()
        return ee.Image.constant(matching)

    gain = constantFromProperties(
        "REFLECTANCE_MULT_BAND_.|TEMPERATURE_MULT_BAND_ST_B10"
    )
    bias = constantFromProperties(
        "REFLECTANCE_ADD_BAND_.|TEMPERATURE_ADD_BAND_ST_B10"
    )
    rescaled = image.select("SR_B.|ST_B10").multiply(gain).add(bias)

    # Swap the raw bands for the scaled ones, then apply both masks.
    prepared = image.addBands(rescaled, overwrite=True)
    prepared = prepared.updateMask(clearPixels)
    return prepared.updateMask(unsaturated)


def addNDWI(image):
    """
    Append an "NDWI" band computed from the green (SR_B3) and NIR (SR_B5) bands

    Parameters:
    -----------
    image: ee.Image
        Landsat image with the SR_B3 and SR_B5 bands

    Returns:
    --------
    ee.Image
        the input image with an additional "NDWI" band
    """

    bands = {"green": image.select("SR_B3"), "NIR": image.select("SR_B5")}
    ndwiBand = image.expression("NDWI = (green - NIR)/(green + NIR)", bands)

    return image.addBands(ndwiBand.rename("NDWI"))


def addNDVI(image):
    """
    Append an "NDVI" band: the normalized difference of NIR (SR_B5) and red (SR_B4)

    Parameters:
    -----------
    image: ee.Image
        Landsat image with the SR_B5 and SR_B4 bands

    Returns:
    --------
    ee.Image
        the input image with an additional "NDVI" band
    """

    nirAndRed = ["SR_B5", "SR_B4"]
    ndviBand = image.normalizedDifference(nirAndRed)

    return image.addBands(ndviBand.rename("NDVI"))


def addCelcius(image):
    """
    Append a "Celcius" band holding the ST_B10 band minus 273.15

    Parameters:
    -----------
    image: ee.Image
        Landsat image whose ST_B10 band has already been scaled

    Returns:
    --------
    ee.Image
        the input image with an additional "Celcius" band
    """
    surfaceTemp = image.select("ST_B10")
    celciusBand = surfaceTemp.subtract(273.15)

    return image.addBands(celciusBand.rename("Celcius"))


def extractTempSeries(
    reservoir,
    startDate,
    endDate,
    ndwi_threshold=0.2,
    imageCollection="LANDSAT/LC09/C02/T1_L2",
):
    """
    Extract water temperature, land temperature and NDVI time series for a reach

    For every distinct acquisition date of the filtered collection, the images
    of that day are prepared (prepL8, addCelcius, addNDWI, addNDVI), averaged,
    split into water / non-water pixels using the QA_PIXEL water bit and
    reduced to regional means at a 30 m scale.

    Parameters:
    -----------
    reservoir: ee.Feature or ee.FeatureCollection
        reach (or reservoir) used both to filter the images spatially and as
        the region over which the means are computed
    startDate: str
        start date (inclusive)
    endDate: str
        end date (exclusive, as in ee.ImageCollection.filterDate)
    ndwi_threshold: float
        currently unused: water pixels are identified with the QA_PIXEL water
        bit rather than an NDWI threshold. Kept for backward compatibility.
    imageCollection: str
        Earth Engine id of the Landsat Level-2 collection to use

    Returns:
    --------
    ee.FeatureCollection
        one geometry-less feature per acquisition date with the properties
        "date" (YYYY-MM-dd), "watertemp(C)", "landtemp(C)" and "NDVI". The
        last three hold the dictionaries returned by reduceRegion, keyed by
        "Celcius_mean", "Celcius_mean" and "NDVI_mean" respectively.
    """

    # QA_PIXEL bit used to tell water pixels from non-water pixels
    waterBit = int("10000000", 2)

    L8 = (
        ee.ImageCollection(imageCollection)
        .filterDate(startDate, endDate)
        .filterBounds(reservoir)
    )
    region = reservoir.geometry()

    def regionMean(image, band):
        # Mean of `band` over the reach; returns an ee.Dictionary keyed by band
        return image.select([band]).reduceRegion(
            reducer=ee.Reducer.mean(), geometry=region, scale=30
        )

    def extractData(date):
        date = ee.Date(date)
        # prepare the images of that day and add the Celcius, NDWI and NDVI bands
        processedL8 = (
            L8.filterDate(date, date.advance(1, "day"))
            .map(prepL8)
            .map(addCelcius)
            .map(addNDWI)
            .map(addNDVI)
        )

        # water / non-water masks from the QA_PIXEL band of the day's mosaic
        qaWater = processedL8.mosaic().select("QA_PIXEL").bitwiseAnd(waterBit)
        waterMask = qaWater.neq(0)
        nonWaterMask = qaWater.eq(0)

        # per-band mean of the day's images (bands are suffixed with "_mean");
        # computed once and masked twice
        meanL8 = processedL8.reduce(ee.Reducer.mean()).set(
            "system:time_start", date.millis()
        )
        meanL8water = meanL8.updateMask(waterMask)
        meanL8nonwater = meanL8.updateMask(nonWaterMask)

        return ee.Feature(
            None,
            {
                "date": date.format("YYYY-MM-dd"),
                "watertemp(C)": regionMean(meanL8water, "Celcius_mean"),
                "landtemp(C)": regionMean(meanL8nonwater, "Celcius_mean"),
                "NDVI": regionMean(meanL8nonwater, "NDVI_mean"),
            },
        )

    # distinct acquisition dates (YYYY-MM-dd) available for this reach
    dates = ee.List(
        L8.map(
            lambda image: ee.Feature(None, {"date": image.date().format("YYYY-MM-dd")})
        )
        .distinct("date")
        .aggregate_array("date")
    )

    return ee.FeatureCollection(dates.map(extractData))

In [None]:
def ee_to_df(featureCollection):
    """
    Convert an ee.FeatureCollection to a pandas.DataFrame

    The column names are the property names of the first feature; the
    "system:index" column is removed from the result.

    Parameters:
    -----------
    featureCollection: ee.FeatureCollection
        feature collection

    Returns:
    --------
    pandas.DataFrame
        one row per collected feature, one column per property
    """

    columns = featureCollection.first().propertyNames().getInfo()

    # Collect all properties in one request: a list of per-feature value lists.
    listReducer = ee.Reducer.toList(len(columns))
    reduced = featureCollection.reduceColumns(listReducer, columns)
    rows = reduced.values().get(0).getInfo()

    table = pd.DataFrame(rows, columns=columns)

    return table.drop(columns=["system:index"])


def download_ee_csv(downloadUrl):
    """
    Read an exported ee.FeatureCollection csv into a pandas.DataFrame

    Parameters:
    -----------
    downloadUrl: str
        location of the csv (anything ``pandas.read_csv`` accepts)

    Returns:
    --------
    pandas.DataFrame
        the csv content without the "system:index" and ".geo" columns
    """

    unwanted = ["system:index", ".geo"]

    return pd.read_csv(downloadUrl).drop(columns=unwanted)

In [None]:
def entryToDB(
    data, table_name, reach_name, connection, date_col="date", value_col="value"
):
    """
    Insert a time series for one reach into a database table

    Rows with a missing date or value are skipped, and a row is only inserted
    when the table holds no entry yet for that date and reach.

    Parameters:
    -----------
    data: pandas.DataFrame
        time series; must contain `date_col` and `value_col`
    table_name: str
        destination table (with the columns Date, ReachID and Value). It is
        interpolated into the statement because identifiers cannot be bound
        as parameters, so it must come from trusted code, never user input.
    reach_name: str
        value matched against Reaches.Name to look up the ReachID
    connection:
        open DB-API connection (providing cursor() and commit())
    date_col: str
        name of the date column in `data`
    value_col: str
        name of the value column in `data`

    Returns:
    --------
    None
    """
    data = data[[date_col, value_col]].copy()
    data[date_col] = pd.to_datetime(data[date_col])
    data = data.dropna().sort_values(by=date_col)

    if data.empty:
        return

    # Values are bound as parameters (%s is the placeholder style of the MySQL
    # drivers) instead of being formatted into the SQL text.
    query = f"""
        INSERT INTO {table_name} (Date, ReachID, Value)
        SELECT %s, (SELECT ReachID FROM Reaches WHERE Name = %s), %s
        WHERE NOT EXISTS (SELECT * FROM {table_name} WHERE Date = %s AND ReachID = (SELECT ReachID FROM Reaches WHERE Name = %s))
        """
    reach = str(reach_name)

    cursor = connection.cursor()
    try:
        for date, value in zip(data[date_col], data[value_col]):
            # Hand the driver plain Python types: it may not know how to
            # convert pandas.Timestamp / numpy scalars.
            stamp = str(date)
            cursor.execute(query, (stamp, reach, float(value), stamp, reach))
        # one commit for the whole series instead of one per row
        connection.commit()
    finally:
        cursor.close()

In [None]:
def reachwiseExtraction(
    reaches,
    reach_id,
    startDate,
    endDate,
    ndwi_threshold=0.2,
    imageCollection="LANDSAT/LC09/C02/T1_L2",
    checkpoint_path=None,
    connection=None,
):
    """
    Extract the Landsat time series of one reach and store them in the database

    The timeframe is split into yearly chunks (divideDates); for each chunk the
    series is computed on Earth Engine (extractTempSeries) and downloaded as a
    DataFrame. The combined series is then written to the tables
    ReachLandsatLandTemp, ReachLandsatWaterTemp and ReachNDVI (entryToDB).

    Parameters:
    -----------
    reaches: ee.FeatureCollection
        reaches carrying a "reach_id" property
    reach_id: str
        identifier of the reach to process; also used as the reach name when
        writing to the database
    startDate: str
        start date
    endDate: str
        end date
    ndwi_threshold: float
        forwarded to extractTempSeries
    imageCollection: str
        Earth Engine id of the Landsat collection to use
    checkpoint_path: path-like or None
        unused; accepted so that existing callers keep working
    connection:
        open database connection handed to entryToDB

    Returns:
    --------
    None
    """
    # (DataFrame column, key of the reduceRegion dictionary, destination table)
    outputs = (
        ("landtemp(C)", "Celcius_mean", "ReachLandsatLandTemp"),
        ("watertemp(C)", "Celcius_mean", "ReachLandsatWaterTemp"),
        ("NDVI", "NDVI_mean", "ReachNDVI"),
    )

    # the reach does not change between date chunks, so select it once
    reservoir = reaches.filter(ee.Filter.eq("reach_id", ee.String(reach_id)))

    dataSeriesList = []
    for startDate_, endDate_ in divideDates(startDate, endDate):
        dataSeries = geemap.ee_to_pandas(
            extractTempSeries(
                reservoir, startDate_, endDate_, ndwi_threshold, imageCollection
            )
        )

        dataSeries["date"] = pd.to_datetime(dataSeries["date"])

        # every value column arrives as a one-entry dictionary; unwrap it
        for column, key, _ in outputs:
            dataSeries[column] = (
                dataSeries[column].apply(lambda x, key=key: x[key]).astype(float)
            )

        dataSeriesList.append(dataSeries)

        # short random pause between Earth Engine requests
        time.sleep(randint(5, 10))

    # nothing to store when the timeframe produced no chunk
    # (pd.concat would raise on an empty list)
    if not dataSeriesList:
        return

    dataSeries_df = (
        pd.concat(dataSeriesList, ignore_index=True)
        .sort_values(by="date")
        .drop_duplicates(subset="date")
    )

    for column, _, table in outputs:
        entryToDB(
            dataSeries_df,
            table,
            reach_id,
            connection,
            date_col="date",
            value_col=column,
        )
In [None]:
# Reaches in EPSG:4326 and the distinct river names they belong to; both are
# read as globals by the extraction driver.
gdf = gpd.read_file(river_shp).to_crs(epsg=4326)
rivers = gdf["GNIS_Name"].unique()

In [None]:
def runExtraction(data_dir, checkpoint_path=None, connection=None):
    """
    Run the Landsat 8 and Landsat 9 extraction for every reach of every river

    Rivers and reaches are processed in the order of the global `rivers` array
    and `gdf` GeoDataFrame, starting at the position stored in the checkpoint.
    The checkpoint is rewritten after every reach and every river so that an
    interrupted run can resume where it stopped.

    Parameters:
    -----------
    data_dir: pathlib.Path
        working directory; the shapefile of the current river is written to
        data_dir / "reaches" / "rivers.shp"
    checkpoint_path: path-like or None
        json file holding {"river_index": int, "reach_index": int}. When None,
        the run starts from the first river and no checkpoint is written.
    connection:
        open database connection forwarded to reachwiseExtraction

    Returns:
    --------
    None
    """
    if checkpoint_path is None:
        checkpoint = {"river_index": 0, "reach_index": 0}
    else:
        with open(checkpoint_path, "r") as f:
            checkpoint = json.load(f)

    def saveCheckpoint():
        # Persist the progress; a no-op when the run is not checkpointed.
        if checkpoint_path is None:
            return
        with open(checkpoint_path, "w") as f:
            json.dump(checkpoint, f)

    # (image collection, start date, end date) for each sensor
    sensors = (
        ("LANDSAT/LC08/C02/T1_L2", L8startDate, L8endDate),
        ("LANDSAT/LC09/C02/T1_L2", L9startDate, L9endDate),
    )

    river_shp_path = data_dir / "reaches" / "rivers.shp"
    unique_rivers = rivers[checkpoint["river_index"] :]

    for river in unique_rivers:
        # upload the reaches of this river to Earth Engine via a shapefile
        river_gdf = gdf[gdf["GNIS_Name"] == river]
        river_gdf.to_file(river_shp_path)
        reaches = geemap.shp_to_ee(river_shp_path)

        # skip the reaches already done (only non-zero for a resumed river)
        reach_ids = river_gdf["reach_id"].tolist()[checkpoint["reach_index"] :]

        for reach_id in reach_ids:
            for collection, start, end in sensors:
                reachwiseExtraction(
                    reaches,
                    reach_id,
                    start,
                    end,
                    ndwi_threshold,
                    imageCollection=collection,
                    checkpoint_path=checkpoint_path,
                    connection=connection,
                )

            checkpoint["reach_index"] += 1
            saveCheckpoint()
            print(f"Reach {reach_id} done!")

        checkpoint["reach_index"] = 0
        checkpoint["river_index"] += 1
        saveCheckpoint()

        print(f"{river} done!")

In [None]:
# gdf["GNIS_Name"].unique()

In [None]:
checkpoint_file = data_dir / "reaches" / "checkpoint.json"


def _read_checkpoint(path):
    """Return the checkpoint dictionary stored in the json file at `path`."""
    with open(path, "r") as f:
        return json.load(f)


def _write_checkpoint(path, state):
    """Write the checkpoint dictionary `state` to the json file at `path`."""
    with open(path, "w") as f:
        json.dump(state, f)


# Resume from the saved position, or start over when the checkpoint file is
# missing or unreadable.
try:
    checkpoint = _read_checkpoint(checkpoint_file)
except (OSError, ValueError) as e:
    print(f"Error: {e}")
    print("Creating new checkpoint...")
    checkpoint = {"river_index": 0, "reach_index": 0}
    _write_checkpoint(checkpoint_file, checkpoint)

repeated_tries = 0
# (river_index, reach_index) at which the previous failure happened
last_failure = None

while checkpoint["river_index"] < len(rivers):
    try:
        runExtraction(data_dir, checkpoint_file, connection)
        repeated_tries = 0  # reset repeated_tries
        last_failure = None

    except Exception as e:
        print(f"Error: {e}")
        # sleep for 0.5 - 2 minutes
        s_time = randint(30, 120)
        print(f"Sleeping for {s_time} seconds...")
        time.sleep(s_time)
        print("Restarting from checkpoint...")  # restart from checkpoint

        # runExtraction advances the checkpoint on disk while it runs, so the
        # file (not the copy in memory) tells where the failure happened.
        checkpoint = _read_checkpoint(checkpoint_file)
        failure = (checkpoint["river_index"], checkpoint["reach_index"])

        # only failures on the same reach count as repeated tries
        repeated_tries = repeated_tries + 1 if failure == last_failure else 1
        last_failure = failure

        # after more than 3 failures on the same reach, skip it; move on to
        # the next river when it was the last reach of the current one
        if repeated_tries > 3 and checkpoint["river_index"] < len(rivers):
            checkpoint["reach_index"] += 1
            current_river = rivers[checkpoint["river_index"]]
            if checkpoint["reach_index"] >= len(gdf[gdf["GNIS_Name"] == current_river]):
                checkpoint["reach_index"] = 0
                checkpoint["river_index"] += 1
            repeated_tries = 0
            last_failure = None

            # save checkpoint
            _write_checkpoint(checkpoint_file, checkpoint)
    finally:
        # reload the checkpoint so the loop condition sees the latest progress
        checkpoint = _read_checkpoint(checkpoint_file)

# reset checkpoint if all rivers are done
if checkpoint["river_index"] >= len(rivers):
    checkpoint["river_index"] = 0
    checkpoint["reach_index"] = 0
    _write_checkpoint(checkpoint_file, checkpoint)

print("All done!")