In [None]:
import sys
import geemap
import ee
from pathlib import Path
import pandas as pd
import geopandas as gpd
import numpy as np
import os
from pathlib import Path
import time
from random import randint
import json

In [None]:
# Project root: four directory levels above the notebook's working directory.
proj_dir = Path("..", "..", "..", "..")

In [None]:
# Put the project's shared helper directory first on the import path so the
# local `sql` package can be imported.
utils = os.fspath(proj_dir / "utils")
sys.path.insert(0, utils)
from sql import connect  # utility functions for connecting to MySQL

In [None]:
# Input layers. (An older reservoir layer lived at
# Data/GIS/shapefiles/CRBReservoirs.shp.)
reservoirs_shp = proj_dir / "Data" / "GIS" / "raw" / "gee_basin_params" / "basin_reservoirs.shp"
# NOTE: despite the `_shp` suffix this variable points at a GeoJSON file.
temperature_gauges_shp = proj_dir / "Data" / "GIS" / "shapefiles" / "temperature_gauges.geojson"

# Output location; per-reservoir CSVs are written under data_dir/reservoirs.
data_dir = proj_dir / "Data" / "LandsatTemperature"
(data_dir / "reservoirs").mkdir(parents=True, exist_ok=True)

In [None]:
# Open the MySQL connection that the extraction loop writes results into.
# (Previous config location: Methods/2.Data/DBManagement/mysql_config.ini.)
mysql_config = proj_dir / ".env" / "mysql_config.ini"
conn = connect.Connect(str(mysql_config))
connection = conn.conn

In [None]:
# Start the Earth Engine session used by every ee.* call below.
# Credentials are assumed to be cached already; on a machine that has never
# been authenticated, uncomment `ee.Authenticate()` and run it once first.
# The geemap.Map lines are optional, for interactive map inspection only.
# Map = geemap.Map()
# Map
# ee.Authenticate()
ee.Initialize()

In [None]:
# Convert the local reservoir shapefile into an Earth Engine FeatureCollection.
reservoirs = geemap.shp_to_ee(str(reservoirs_shp))

In [None]:
# reservoir = reservoirs.filter(ee.Filter.eq("DAM_NAME", "Dworshak Dam")) 

In [None]:
# Timeframe for the temperature extraction, one window per sensor
# (YYYY-MM-DD strings).
L8startDate, L8endDate = "2013-03-01", "2023-11-30"
L9startDate, L9endDate = "2021-10-01", "2023-11-30"

# This run processes the Landsat 9 window.
startDate, endDate = L9startDate, L9endDate

# NDWI cut-off handed to extractTempSeries.
ndwi_threshold = 0.2

In [None]:
def divideDates(startDate, endDate):
    """
    Divide a timeframe into consecutive per-calendar-year date ranges

    The ranges are half-open, ``[start, end)``, to match
    ``ee.ImageCollection.filterDate``, whose end date is exclusive. Each
    year's range therefore ends on January 1 of the following year, so the
    last day of every year is not dropped. Consecutive ranges share a
    boundary and together cover exactly ``[startDate, endDate)``.

    Parameters:
    -----------
    startDate: str
        start date (inclusive), any format accepted by ``pd.to_datetime``
    endDate: str
        end date (exclusive)

    Returns:
    --------
    list
        list of ``(start, end)`` string tuples, one per calendar year touched.
        The first and last tuples reuse ``startDate`` / ``endDate`` verbatim.
        Zero-length ranges (e.g. when ``endDate <= startDate``) are omitted.
    """
    start_ts = pd.to_datetime(startDate)
    end_ts = pd.to_datetime(endDate)

    dates = []
    for year in range(start_ts.year, end_ts.year + 1):
        range_start = startDate if year == start_ts.year else f"{year}-01-01"
        # Next Jan 1, not Dec 31: the consumer's end bound is exclusive.
        range_end = endDate if year == end_ts.year else f"{year + 1}-01-01"
        # Skip empty ranges, e.g. ("2023-01-01", "2023-01-01") when endDate is Jan 1.
        if pd.to_datetime(range_start) < pd.to_datetime(range_end):
            dates.append((range_start, range_end))

    return dates

def prepL8(image):
    """
    Scale and mask a Landsat Collection 2 Level-2 image (Landsat 8 or 9)

    Surface-reflectance bands (SR_B*) and the thermal band ST_B10 are
    rescaled with the multiplicative / additive factors stored in the image's
    own metadata. Pixels with any of QA_PIXEL bits 0-4 set (fill / cloud /
    shadow flags), or with any radiometric saturation flagged in QA_RADSAT,
    are masked out.

    Parameters:
    -----------
    image: ee.Image
        Landsat Collection 2 Level-2 image

    Returns:
    --------
    ee.Image
        the image with scaled SR_B*/ST_B10 bands and the masks applied
    """
    qa_bits = int("11111", 2)
    clear = image.select("QA_PIXEL").bitwiseAnd(qa_bits).eq(0)
    unsaturated = image.select("QA_RADSAT").eq(0)

    # Per-band scale and offset factors are read from the image metadata.
    props = image.toDictionary()
    scale = ee.Image.constant(
        props.select(["REFLECTANCE_MULT_BAND_.|TEMPERATURE_MULT_BAND_ST_B10"]).values()
    )
    offset = ee.Image.constant(
        props.select(["REFLECTANCE_ADD_BAND_.|TEMPERATURE_ADD_BAND_ST_B10"]).values()
    )
    scaled_bands = image.select("SR_B.|ST_B10").multiply(scale).add(offset)

    # Overwrite the raw bands with their scaled versions, then mask.
    prepared = image.addBands(scaled_bands, overwrite=True)
    prepared = prepared.updateMask(clear)
    return prepared.updateMask(unsaturated)

def addNDWI(image):
    """
    Append an ``NDWI`` band computed as (green - NIR) / (green + NIR)

    Green is SR_B3 and NIR is SR_B5. The function is meant to run after
    prepL8, so that both bands are already scaled.

    Parameters:
    -----------
    image: ee.Image
        Landsat 8/9 Level-2 image

    Returns:
    --------
    ee.Image
        the input image plus an ``NDWI`` band
    """
    bands = {"green": image.select("SR_B3"), "NIR": image.select("SR_B5")}
    ndwi = image.expression("NDWI = (green - NIR)/(green + NIR)", bands)
    return image.addBands(ndwi.rename("NDWI"))

def addCelcius(image):
    """
    Append a ``Celcius`` band equal to ST_B10 minus 273.15

    ST_B10 is treated as a Kelvin temperature, i.e. the function is meant to
    run after prepL8 has scaled it. The band name keeps the original
    spelling, because downstream code reads ``Celcius_mean``.

    Parameters:
    -----------
    image: ee.Image
        Landsat 8/9 Level-2 image with a scaled ST_B10 band

    Returns:
    --------
    ee.Image
        the input image plus a ``Celcius`` band
    """
    kelvin = image.select("ST_B10")
    return image.addBands(kelvin.subtract(273.15).rename("Celcius"))

def extractTempSeries(
    reservoir,
    startDate,
    endDate,
    ndwi_threshold=0.2,
    imageCollection="LANDSAT/LC08/C02/T1_L2",
):
    """
    Extract a per-acquisition-date mean water temperature series for a reservoir

    For each distinct acquisition date in the filtered collection, the day's
    scenes are scaled/masked (prepL8), given Celcius and NDWI bands, averaged,
    restricted to pixels whose QA_PIXEL bit 7 (water flag) is set, and
    averaged over the reservoir geometry at 30 m.

    Parameters:
    -----------
    reservoir: ee.FeatureCollection or ee.Feature
        reservoir geometry. The notebook passes ``reservoirs.filter(...)``,
        i.e. an ee.FeatureCollection, whose ``geometry()`` is used as region
    startDate: str
        start date (inclusive)
    endDate: str
        end date (exclusive, as in ``ee.ImageCollection.filterDate``)
    ndwi_threshold: float
        NDWI cut-off. NOTE: currently not used for masking; water pixels are
        selected with the QA_PIXEL water bit instead. Kept so existing
        callers keep working.
    imageCollection: str
        Earth Engine ID of a Landsat Collection 2 Level-2 collection

    Returns:
    --------
    ee.FeatureCollection
        one geometry-less feature per acquisition date with properties
        ``date`` ("YYYY-MM-dd") and ``temp(C)`` (the reduceRegion
        dictionary, ``{"Celcius_mean": value}``)
    """
    collection = (
        ee.ImageCollection(imageCollection)
        .filterDate(startDate, endDate)
        .filterBounds(reservoir)
    )
    water_bit = int("10000000", 2)  # QA_PIXEL bit 7

    def extractTemp(date):
        date = ee.Date(date)
        # All scenes acquired on this date, scaled/masked, with Celcius and NDWI bands.
        processed = (
            collection.filterDate(date, date.advance(1, "day"))
            .map(prepL8)
            .map(addCelcius)
            .map(addNDWI)
        )

        # Per pixel, NDWI from the scene with the highest NDWI.
        ndwi = processed.qualityMosaic("NDWI").select("NDWI")

        # Water mask from the QA_PIXEL water flag of the day's mosaic.
        waterMask = processed.mosaic().select("QA_PIXEL").bitwiseAnd(water_bit).neq(0)

        # Mean of the day's scenes (bands get a "_mean" suffix), restricted to water.
        meanImage = (
            processed.reduce(ee.Reducer.mean())
            .addBands(ndwi, ["NDWI"], True)
            .updateMask(waterMask)
            .set("system:time_start", date.millis())
        )

        # Spatial mean of the water temperature over the reservoir.
        temp = meanImage.select(["Celcius_mean"]).reduceRegion(
            reducer=ee.Reducer.mean(), geometry=reservoir.geometry(), scale=30
        )

        return ee.Feature(None, {"date": date.format("YYYY-MM-dd"), "temp(C)": temp})

    # Distinct acquisition dates, fetched to the client and wrapped back as an ee.List.
    dates = ee.List(
        collection.map(
            lambda image: ee.Feature(None, {"date": image.date().format("YYYY-MM-dd")})
        )
        .distinct("date")
        .aggregate_array("date")
        .getInfo()
    )

    return ee.FeatureCollection(dates.map(extractTemp))

In [None]:
# Fetch DAM_NAME and uniq_id for every reservoir in a single request, so both
# lists are built from the same feature listing and stay index-aligned.
ee_reservoir_props = reservoirs.select(
    ["DAM_NAME", "uniq_id"], retainGeometry=False
).getInfo()
# Backward-compatible aliases for the previous per-property results.
ee_dam_names = ee_uniq_ids = ee_reservoir_props
dam_names = [f["properties"]["DAM_NAME"] for f in ee_reservoir_props["features"]]
uniq_ids = [f["properties"]["uniq_id"] for f in ee_reservoir_props["features"]]


In [None]:
# Resume point: drop the first reservoirs from this run (presumably already
# processed in an earlier run — adjust when restarting).
resume_index = 49
uniq_ids, dam_names = uniq_ids[resume_index:], dam_names[resume_index:]

In [None]:
# extract temperature time series for each reservoir
for dam_name, uniq_id in zip(dam_names, uniq_ids):
    dates = divideDates(startDate, endDate)
    tempSeriesList = []
    if os.path.isfile(data_dir / "reservoirs" / f"{uniq_id}.csv"):
        existing_df = pd.read_csv(data_dir / "reservoirs" / f"{uniq_id}.csv")
        existing_df["date"] = pd.to_datetime(existing_df["date"])
        tempSeriesList.append(existing_df)
        # print("File exists!")

    for date in dates:
        startDate_ = date[0]
        endDate_ = date[1]

        reservoir = reservoirs.filter(ee.Filter.eq("DAM_NAME", ee.String(dam_name)))
        tempSeries = extractTempSeries(reservoir, startDate_, endDate_, ndwi_threshold, imageCollection="LANDSAT/LC09/C02/T1_L2")
        tempSeries = geemap.ee_to_pandas(tempSeries)

        # convert date column to datetime
        tempSeries["date"] = pd.to_datetime(tempSeries["date"])
        tempSeries["temp(C)"] = (
            tempSeries["temp(C)"].apply(lambda x: x["Celcius_mean"]).astype(float)
        )

        # append time series to list
        tempSeriesList.append(tempSeries)

    # concatenate all time series
    tempSeries_df = pd.concat(tempSeriesList, ignore_index=True)

    # sort by date
    tempSeries_df.sort_values(by="date", inplace=True)
    # remove duplicates
    tempSeries_df.drop_duplicates(subset="date", inplace=True)

    # save time series to csv
    tempSeries_df.to_csv(data_dir / "reservoirs" / f"{uniq_id}.csv", index=False)

    cursor = connection.cursor()

    data = tempSeries_df.dropna().copy()
    # convert the date column to datetime YYYY-MM-DD
    data['date'] = pd.to_datetime(data['date'])
    data['date'] = data['date'].dt.date

    for i, row in data.iterrows():
        query = f"""
        INSERT INTO DamLandsatWaterTemp (Date, DamID, Value)
        SELECT '{row['date']}', (SELECT DamID FROM Dams WHERE Name = "{dam_name}"), {row['temp(C)']}
        WHERE NOT EXISTS (SELECT * FROM DamLandsatWaterTemp WHERE Date = '{row['date']}' AND DamID = (SELECT DamID FROM Dams WHERE Name = "{dam_name}"))
        """
        try:
            cursor.execute(query)
            connection.commit()
        except:
            print(query)
            raise Exception("Error!")
        
        # print(i, dam_name, row["date"], row["temp(C)"])

        # cursor.execute(query)
        # connection.commit()

        # print(dam_name, row["date"], row["temp(C)"])

    print(dam_name, "Done!")
    s_time = randint(60, 120)
    # print(f"Sleeping for {s_time} seconds...")
    time.sleep(s_time)
    # print("Restarting from checkpoint...")  # restart from checkpoint
print("Done!")