In [None]:
import random
import uuid

import geopandas as gpd
import pandas as pd
import sdv
from shapely.geometry import Point

In [None]:
# Read the "tl_rd22_28_cousub" archive into a GeoDataFrame.
# NOTE(review): absolute, user-specific path — this will not resolve on
# another machine; consider a configurable data directory.
ms = gpd.read_file("/Users/sarvapulla/Downloads/tl_rd22_28_cousub.zip")

In [None]:
# Read the "tl_2019_us_county" archive (same user-specific Downloads folder).
us = gpd.read_file("/Users/sarvapulla/Downloads/tl_2019_us_county.zip")

In [None]:
# Merge all rows that share a STATEFP value into a single geometry, then turn
# STATEFP back into a regular column so it can be filtered on below.
us = us.dissolve(by="STATEFP").reset_index()

In [None]:
us[us.STATEFP == "22"].to_file("la.shp", driver="ESRI Shapefile")

In [None]:
us[us.STATEFP == "01"].to_file("al.shp", driver="ESRI Shapefile")

In [None]:
us[us.STATEFP == "28"].to_file("ms.shp", driver="ESRI Shapefile")

In [None]:
al = gpd.read_file("al.zip")

In [None]:
def wells_output(roi, min_size=10, data_dir="../synthetic_wells"):
    """Load one region's time series, keeping only wells with enough records.

    Reads ``{data_dir}/{roi}_1900-2023_wells.csv`` and
    ``{data_dir}/{roi}_1900-2023_TS.csv``, drops every time-series row whose
    ``Well_ID`` occurs fewer than ``min_size`` times, and joins the surviving
    rows to the wells table.

    Parameters
    ----------
    roi : str
        File-name prefix identifying the region (e.g. ``"AL"``).
    min_size : int, default 10
        Minimum number of time-series rows a well must have to be kept.
    data_dir : str or path-like, default "../synthetic_wells"
        Directory that holds the two CSV files.

    Returns
    -------
    pandas.DataFrame
        The filtered time-series rows merged with the wells table on the
        columns the two files have in common.
    """
    wells = pd.read_csv(f"{data_dir}/{roi}_1900-2023_wells.csv")
    ts = pd.read_csv(f"{data_dir}/{roi}_1900-2023_TS.csv")
    # Number of rows recorded for each row's own well. This is a vectorised
    # stand-in for groupby().filter(lambda ...), which runs Python code once
    # per well. Rows with a missing Well_ID map to NaN and are dropped, just
    # as groupby would drop them.
    rows_per_well = ts["Well_ID"].map(ts["Well_ID"].value_counts())
    ts_filtered = ts[rows_per_well >= min_size].merge(wells)
    return ts_filtered

In [None]:
region_ts = pd.concat([wells_output(roi) for roi in ["AL", "MS", "LA"]])

In [None]:
region_ts.to_csv("region_ts.csv", index=False)

In [None]:
df = region_ts.copy()

In [None]:
df["Date"] = pd.to_datetime(df["Date"], format="mixed")
df = df.sort_values(["Well_ID", "Date"])
df["date_diff"] = df.groupby("Well_ID")["Date"].diff()

In [None]:
df["is_gap"] = df["date_diff"] > pd.Timedelta(
    "31 days"
)  # Identify gaps larger than 31 days
df["chunk"] = df.groupby("Well_ID")["is_gap"].cumsum()  # Create continuous chunks

In [None]:
largest_chunk = df.groupby(["Well_ID", "chunk"]).size().reset_index(name="size")
largest_chunk = largest_chunk.loc[largest_chunk.groupby("Well_ID")["size"].idxmax()]

# Merge with the original DataFrame to filter the largest continuous chunks
df_filtered = pd.merge(df, largest_chunk[["Well_ID", "chunk"]], on=["Well_ID", "chunk"])

In [None]:
wells_with_enough_data = df_filtered.groupby("Well_ID").size() >= 10
valid_well_ids = wells_with_enough_data[wells_with_enough_data].index

df_final = df_filtered[df_filtered["Well_ID"].isin(valid_well_ids)].reset_index(
    drop=True
)

In [None]:
len(df_final)

In [None]:
# Exponentially weighted mean (alpha=0.9) of GW_measurement, computed
# separately for every Well_ID in the current row order of df_final.
# A single grouped transform is used so the frame is scanned once rather than
# being boolean-masked again for every well.
df_final["value_smoothed"] = df_final.groupby("Well_ID")["GW_measurement"].transform(
    lambda measurements: measurements.ewm(alpha=0.9).mean()
)

In [None]:
# Resample every well's smoothed series to month-start ("MS") frequency and
# linearly interpolate months that have no record.
#
# * Only value_smoothed is averaged. Well_ID is attached afterwards from the
#   group key, so it keeps its dtype and is never NaN for an empty month.
# * interpolate() only fills months between a well's first and last record,
#   because resample() does not create bins outside that range.
# * df_final is split once with groupby instead of being filtered per well.

# One resampled frame per well
resampled_data = []

# Unique well IDs, in order of first appearance
well_ids = df_final["Well_ID"].unique()

# sort=False keeps the wells in that same first-appearance order
for well_id, well_rows in df_final.groupby("Well_ID", sort=False):
    monthly_values = (
        well_rows.set_index("Date")["value_smoothed"]
        .resample("MS")
        .mean()
        .interpolate()
    )
    well_data_resampled = monthly_values.reset_index()
    # Resulting column order: Date, Well_ID, value_smoothed
    well_data_resampled.insert(1, "Well_ID", well_id)
    resampled_data.append(well_data_resampled)

# Concatenate the resampled data for all wells under a fresh 0..n-1 index
df_resampled = pd.concat(resampled_data, ignore_index=True)

In [None]:
region_wells = (
    region_ts[["Well_ID", "lat_dec", "long_dec"]]
    .drop_duplicates()
    .reset_index(drop=True)
)

region_wells = gpd.GeoDataFrame(
    region_wells,
    geometry=gpd.points_from_xy(region_wells.long_dec, region_wells.lat_dec),
)

In [None]:
new_model_input = (
    df_resampled.reset_index(drop=True).merge(region_wells).drop(columns="geometry")
)

In [None]:
new_model_input.to_csv("cleaned_model_input.csv", index=False)

In [None]:
# NOTE(review): reads "final_cleaned_data.csv"; no cell in this notebook writes
# a file of that name (the export above is "cleaned_model_input.csv") —
# confirm where this file comes from.
model_input = pd.read_csv("final_cleaned_data.csv")

In [None]:
# Parse the Date column into datetimes.
model_input["Date"] = pd.to_datetime(model_input.Date)

In [None]:
# Number of distinct wells.
model_input.Well_ID.nunique()

In [None]:
# Summary statistics of the number of rows per well.
model_input.Well_ID.value_counts().describe()

In [None]:
# String key of the form "ID_<n>", with Well_ID cast to int first.
# NOTE(review): astype(int) fails on missing values — presumably the file has
# no NaN Well_ID; verify.
model_input["Well_UUID"] = "ID_" + model_input["Well_ID"].astype(int).astype(str)

In [None]:
# model_input.loc[:, "Well_UUID"] = 1
# model_input.loc[:, "Well_UUID"] = model_input.groupby("Well_ID").Well_UUID.transform(lambda g: uuid.uuid4())

In [None]:
# Lookup from each Well_UUID back to its Well_ID, built from the distinct
# (Well_UUID, Well_ID) pairs.
wells_uuid_mapping = dict(
    model_input[["Well_UUID", "Well_ID"]].drop_duplicates().values
)

In [None]:
# Column names of model_input.
list(model_input)

In [None]:
# Three-column table later passed to the synthesizer: key, timestamp, value.
# NOTE(review): expects a "GW_measurement_smoothed" column, whereas the
# smoothing step earlier in this notebook writes "value_smoothed" — confirm
# the schema of the CSV that model_input was read from.
final_input = model_input[["Well_UUID", "Date", "GW_measurement_smoothed"]]

In [None]:
# NOTE(review): grace_input is not defined in any cell visible here; on a
# fresh kernel this cell raises NameError.
final_grace_input = grace_input[["Well_UUID", "Date", "GW_measurement_smoothed"]]

In [None]:
# Removes the three columns from model_input in place. Re-running this cell
# raises KeyError because the columns are already gone.
model_input.drop(columns=["Well_ID", "lat_dec", "long_dec"], inplace=True)

In [None]:
# Remaining columns after the drop.
model_input.columns

In [None]:
# Number of rows per Well_UUID.
model_input.Well_UUID.value_counts()

In [None]:
# First ten rows only, written without the index.
final_input[0:10].to_csv("meta_input.csv", index=False)

In [None]:
# Full table, written without the index.
final_input.to_csv("final_input.csv", index=False)

In [None]:
def random_points_in_polygon(number, polygon):
    """Draw ``number`` random points that lie inside ``polygon``.

    Uses rejection sampling: candidates are drawn uniformly from the
    polygon's bounding box with the global ``random`` module and kept only
    when ``candidate.within(polygon)`` is true.

    Parameters
    ----------
    number : int
        How many points to return. Zero or a negative value yields ``[]``.
    polygon : shapely geometry
        Geometry exposing ``bounds``, ``is_empty`` and ``area``.

    Returns
    -------
    list
        ``number`` ``Point`` objects, in the order they were accepted.

    Raises
    ------
    ValueError
        If ``polygon`` is empty or has zero area. No candidate could be
        accepted for such a geometry, so sampling would never finish.
    """
    if number <= 0:
        return []
    if polygon.is_empty or polygon.area == 0:
        raise ValueError(
            "cannot sample points inside an empty or zero-area geometry"
        )

    min_x, min_y, max_x, max_y = polygon.bounds
    points = []
    while len(points) < number:
        # x is drawn before y, so a seeded `random` reproduces the same points
        random_point = Point(
            [random.uniform(min_x, max_x), random.uniform(min_y, max_y)]
        )
        if random_point.within(polygon):
            points.append(random_point)
    return points

In [None]:
sample.Well_UUID.nunique()

In [None]:
all_points = []
for _, row in delta_gdf.iterrows():
    all_points.extend(
        random_points_in_polygon(sample.Well_UUID.nunique(), row["geometry"])
    )
while len(all_points) > sample.Well_UUID.nunique():
    all_points.pop()

In [None]:
point_data = pd.DataFrame(
    {
        "Well_UUID": sample["Well_UUID"].unique().tolist(),
        "geometry": all_points,
    }
)
geosynth_data = pd.merge(sample, point_data, on="Well_UUID", how="inner")
geosynth_data = gpd.GeoDataFrame(geosynth_data, geometry="geometry")
geosynth_data = geosynth_data.sort_values(by=["Well_UUID", "Date"])

In [None]:
geosynth_data["latitude"] = geosynth_data.geometry.y
geosynth_data["longitude"] = geosynth_data.geometry.x

In [None]:
geosynth_data.to_csv("ms_delta_synthetic_wells.csv", index=False)

In [None]:
from sdv.metadata import SingleTableMetadata

# Start from empty single-table metadata.
metadata = SingleTableMetadata()

# Let SDV infer the column types from the training table.
metadata.detect_from_dataframe(final_input)

# Override the inferred type: Well_UUID is an identifier column.
metadata.update_column(column_name="Well_UUID", sdtype="id")

# Well_UUID identifies which sequence a row belongs to.
metadata.set_sequence_key("Well_UUID")

# Date is the column that orders rows within a sequence.
metadata.set_sequence_index("Date")

In [None]:
%%time
from sdv.sequential import PARSynthesizer

# Step 1: Create the synthesizer (100 training epochs, min/max enforcement on,
# rounding enforcement off, progress output on)
synthesizer = PARSynthesizer(
    metadata,
    epochs=100,
    enforce_min_max_values=True,
    enforce_rounding=False,
    verbose=True,
)

# Step 2: Train the synthesizer on the three-column input table
synthesizer.fit(final_input)

In [None]:
# Draw synthetic data: num_sequences=100, sequence_length=35.
sample = synthesizer.sample(num_sequences=100, sequence_length=35)

In [None]:
# Persist the fitted synthesizer to "grace.pkl" in the working directory.
synthesizer.save(filepath="grace.pkl")

In [None]:
# NOTE(review): repeated import; the only other content of this cell is
# commented out.
from sdv.sequential import PARSynthesizer

# synthesizer = PARSynthesizer(metadata)

In [None]:
# NOTE(review): repeated import of PARSynthesizer.
from sdv.sequential import PARSynthesizer

# Replace the in-memory synthesizer with the one stored on disk.
# NOTE(review): ".pkl" suggests a pickle-based format — only load files from a
# trusted source.
synthesizer = PARSynthesizer.load(filepath="grace.pkl")

In [None]:
from sdv.evaluation.single_table import evaluate_quality

In [None]:
from sdmetrics.reports.single_table import DiagnosticReport

report = DiagnosticReport()

In [None]:
report.generate(real_data=final_input, synthetic_data=sample, metadata=metadata)

In [None]:
quality_report = evaluate_quality(
    real_data=final_input, synthetic_data=sample, metadata=metadata
)

In [None]:
round(quality_report.get_score() * 100)

In [None]:
# input_metadata = {
#     "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
#     "sequence_key": "Well_ID",
#     "sequence_index": "Date",
#     "columns": {
#         "Well_ID": {"sdtype": "id"},
#         "Date": {"sdtype": "datetime", "datetime_format": "%m-%d-%Y"},
#         "value_smoothed": {"sdtype": "float"},
#         "lat_dec": {"sdtype": "float"},
#         "long_dec": {"sdtype": "float"},
#     },
# }