## Load data

In [None]:
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from datetime import timedelta
import matplotlib.pyplot as plt
from IPython.display import display, Markdown
import sys
import importlib
import os
import io
import time
import datetime
from sqlalchemy import create_engine, text, DateTime
import mysql.connector

sys.path.append("..")
sys.path.append('../src')
sys.path.append('../src/utils')
from src.utils import data_loading_dwd, data_loading_wasserportal, helpers

# Reload so edits to the helper modules are picked up without restarting the kernel
importlib.reload(data_loading_wasserportal)
importlib.reload(data_loading_dwd)

from urllib.parse import quote_plus

from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()
# Inside container: DB_HOST is set via docker-compose
# On host: fallback to LOCAL_DB_HOST
DB_HOST = os.getenv("DB_HOST", os.getenv("LOCAL_DB_HOST", "localhost"))
DB_USER = os.getenv("DB_USER", "root")
# NOTE(review): hard-coded fallback credentials are only suitable for local development
DB_PASSWORD = os.getenv("DB_PASSWORD", "mysecretpassword")
DB_NAME = os.getenv("DB_NAME", "mydatabase")

# Build SQLAlchemy connection string.
# User and password are percent-encoded so characters such as '@', ':' or '/'
# inside them do not break URL parsing.
DATABASE_URL = (
    f"mysql+mysqlconnector://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}"
    f"@{DB_HOST}/{DB_NAME}"
)

# Create engine; pool_pre_ping checks pooled connections before handing them out
engine = create_engine(DATABASE_URL, pool_pre_ping=True)

def skip_cell():
    """Render a "skipped" notice in the notebook output.

    This only displays a Markdown message and returns ``None``; it does NOT
    stop execution of the calling cell. Callers must guard the rest of the
    cell themselves (for example with an ``if``/``else``).
    """
    notice = Markdown("**⏭️ Skipped this cell**")
    display(notice)
    return None
# --------------------------
# CONFIGURATION
# --------------------------

# Analysis period (also embedded in the data file names below)
date_start = '2022-01-01'
date_end = '2025-04-30'

# Input files
GROUNDWATER_PATH = f'../data/wasserportal/processed/gw_data_{date_start}_{date_end}.parquet'
STATIONS_PATH = '../data/wasserportal/stations_groundwater.csv'
PRECIP_ZARR_PATH = f'../data/dwd/processed/radolan_berlin_{date_start}_{date_end}.zarr'

# Target station ID, kept as a string (change to your target station)
STATION_ID = '9931'
# STATION_ID = '100'

# Further settings
N_GW_LAGS = 4
N_PRCP_LAGS = 4
INCLUDE_PRCP_T_PLUS_1 = True
SEASONALITY = True

In [None]:
# Query parameters passed to the Wasserportal loader
config_groundwater = dict(
    thema='gws',
    exportthema='gw',
    sreihe='ew',
    anzeige='d',
    smode='c',
)

# Run the multi-station loader for the analysis period (dates converted with
# helpers.convert_date_format); the return value is not used here.
data_loading_wasserportal.get_multi_station_data(
    helpers.convert_date_format(date_start),
    helpers.convert_date_format(date_end),
    config_groundwater,
    num_stations=20,
)

In [None]:
# Run the recent and historical RADOLAN imports for the analysis period into
# ../data/dwd/. Each call's return value is kept in its own variable so the
# first one is not silently overwritten by the second.

# keep_raw=True: keep raw files (recommended for production)
results_recent = data_loading_dwd.import_radolan_recent(date_start,
                                                        date_end,
                                                        '../data/dwd/',
                                                        keep_raw=True)
# time_to_keep=1250: extract only the 12:50 files
results_historical = data_loading_dwd.import_radolan_historical(date_start,
                                                                date_end,
                                                                '../data/dwd/',
                                                                time_to_keep=1250,
                                                                keep_raw=True)
# Backward compatibility: `results` still refers to the last (historical) import
results = results_historical

In [None]:
# The setup cell only imports data_loading_dwd, data_loading_wasserportal and
# helpers, so import the processing module explicitly to avoid a NameError.
# NOTE(review): assumes the module lives in src/utils next to data_loading_dwd — confirm
from src.utils import data_processing_dwd
importlib.reload(data_processing_dwd)  # pick up edits, like the other helper modules

# Example usage with Berlin bounds
config = {
    'bounds': {
        'min_lat': 52.4,
        'max_lat': 52.65,
        'min_lon': 13.15,
        'max_lon': 13.6
    },
    'date_range': {
        'start_date': date_start,
        'end_date': date_end
    },
    'region_name': 'berlin',
    'data_directory': '../data/dwd/extracted',
    'output_directory': '../data/dwd/processed'
}

data_processing_dwd.create_radolan_timeseries(config)

# Data Loading and Preprocessing

In this section, we load the datasets required for groundwater level prediction:
1. **Groundwater levels**: Historical measurements from monitoring stations
2. **Precipitation data**: Radar-based precipitation data from DWD (German Weather Service)
3. **Station metadata**: Geographic coordinates and station information

Our goal is to predict groundwater levels using historical groundwater data and precipitation patterns.

In [None]:
# --------------------------
# LOAD DATA
# --------------------------
import xarray as xr

# Groundwater levels: convert the index to datetimes once. The station series
# selected afterwards shares the already-converted index.
gw_df = pd.read_parquet(GROUNDWATER_PATH)
gw_df.index = pd.to_datetime(gw_df.index)
gw_series = gw_df[f"value_{STATION_ID}"].dropna()

# Timestamps with a non-missing value for the target station
dates = gw_series.index
display(dates)

# Precipitation grid from zarr
precip_ds = xr.open_zarr(PRECIP_ZARR_PATH)
# Loads the full array into memory; original note: shape (time, 30, 30)
precip_array = precip_ds['precipitation'].values

# Grid coordinates
lats = precip_ds['lat'].values
lons = precip_ds['lon'].values

# Station metadata; fail with a clear message if the target station is absent
stations_df = pd.read_csv(STATIONS_PATH)
station_rows = stations_df[stations_df['ID'] == int(STATION_ID)]
if station_rows.empty:
    raise ValueError(f"Station {STATION_ID} not found in {STATIONS_PATH}")
station = station_rows.iloc[0]


## Connect to the MySQL server

In [None]:
# Smoke-test the database connection with a trivial query
try:
    with engine.connect() as conn:
        probe = conn.execute(text("SELECT 1")).scalar()
except Exception as e:
    print("❌ Connection failed:", e)
else:
    print("✅ Connected to MySQL, test query result:", probe)

print(engine)


## Write data to database

In [None]:
# Write gw_df (groundwater data) to "gw_table", replacing any existing table.
# The index is written as a column named "date" with an SQL DateTime type.
gw_df.to_sql(
    name="gw_table",
    con=engine,
    if_exists="replace",
    index=True,
    index_label="date",
    dtype={"date": DateTime()},
)


In [None]:
# Sanity check: print the first rows straight from the database
with engine.connect() as conn:
    result = conn.execute(text("SELECT * FROM gw_table LIMIT 5"))
    for row in result:
        print(row)

# Round-trip check: read the table back; df_tmp should be identical to gw_df.
# Restore "date" as a parsed datetime index so the comparison is like-for-like.
df_tmp = pd.read_sql(
    "SELECT * FROM gw_table",
    engine,
    index_col="date",
    parse_dates=["date"],
)
display(df_tmp.head())
df_tmp.equals(gw_df)  # should be True

In [None]:
# Write precipitation data to "precip_df".
# The xarray Dataset is first flattened into a long-format DataFrame
# (one row per coordinate combination).
tmp_df = precip_ds.to_dataframe().reset_index()
tmp_df["time"] = tmp_df["time"].dt.date  # keep only the calendar date
tmp_df = tmp_df[['time', 'x', 'y', 'lat', 'lon', 'precipitation']]
print(tmp_df.shape)
display(tmp_df.head())

# Drop pooled connections so the write starts on a fresh connection
engine.dispose()
# chunksize=10000 writes in 10k-row batches; method="multi" sends multiple rows per INSERT
tmp_df.to_sql('precip_df', engine, if_exists='replace', index=False, chunksize=10000, method="multi")

In [None]:
# Round-trip check: read precip_df back and rebuild an xarray Dataset;
# the result should match tmp_df / precip_ds.
df_tmp2 = pd.read_sql("SELECT * FROM precip_df ", engine, parse_dates=['time'])

tmp_ds = df_tmp2.set_index(["time", "x", "y"]).to_xarray()

# Variant 1: attach the 2-D lat/lon grids taken from the original dataset
tmp_ds2 = tmp_ds.assign_coords(
    lat=(("x", "y"), precip_ds.lat.values),
    lon=(("x", "y"), precip_ds.lon.values),
)
# this should be similar to precip_ds
print(tmp_ds2)

##############################################################################
# Variant 2: rebuild the 2-D lat/lon grids purely from the values stored in SQL
# (one row per (x, y) pair, pivoted into an x-by-y array)
lat_grid = df_tmp2.drop_duplicates(subset=["x", "y"]).pivot(index="x", columns="y", values="lat").values
lon_grid = df_tmp2.drop_duplicates(subset=["x", "y"]).pivot(index="x", columns="y", values="lon").values

tmp_ds3 = tmp_ds.assign_coords(
    lat=(("x", "y"), lat_grid),
    lon=(("x", "y"), lon_grid),
)
# this should be similar to precip_ds
print(tmp_ds3)

# Both variants must agree. assert_allclose raises AssertionError on mismatch
# and returns None, so call it directly instead of printing its result.
xr.testing.assert_allclose(tmp_ds2, tmp_ds3)
print("tmp_ds2 and tmp_ds3 match")
tmp_ds3 = tmp_ds3.drop_vars(['x', 'y'])


In [None]:
# Optional full comparison of the SQL round-trip against the original precip_ds.
# Set should_skip = False to run it.
should_skip = True
if should_skip:
    skip_cell()
else:
    print("This will run if should_skip is False")
    print(tmp_ds3)
    # The SQL table stores dates only, so normalize the original time axis to
    # midnight. Work on a copy so precip_ds itself is left unchanged.
    precip_cmp = precip_ds.assign_coords(
        time=pd.to_datetime(precip_ds["time"].values).normalize()
    )
    # Raises AssertionError if the datasets differ beyond tolerance
    xr.testing.assert_allclose(precip_cmp, tmp_ds3)

    # Direct difference of precipitation values
    diff = precip_cmp["precipitation"] - tmp_ds3["precipitation"]

    print("Max difference:", float(diff.max()))
    print("Min difference:", float(diff.min()))

    # Check if all values are exactly equal
    print("All equal:", bool((diff == 0).all()))


# Clean up all read-back variables to release memory
del df_tmp, df_tmp2, tmp_ds, tmp_ds2, tmp_ds3, lat_grid, lon_grid, tmp_df



In [None]:
# Write station metadata to "stations_meta", replacing any existing table
print(stations_df.head())

engine.dispose()  # start from a fresh connection pool
stations_df.to_sql(
    'stations_meta',
    con=engine,
    if_exists='replace',
    index=False,
)

# No read-back check: stations_df is small and easy to verify by eye

## Load new data and append it to the existing tables (mainly for testing)


### Load DWD data

In [None]:
import load_new_data
importlib.reload(load_new_data)  # pick up edits to load_new_data without restarting

# Import only the functions this notebook uses. A star import could silently
# overwrite notebook globals (e.g. `engine`, `pd`) with same-named module attributes.
from load_new_data import load_new_data_from_dwd, load_new_data_from_wasserportal

# These two variables need to be set by the user (format: YYYY-MM-DD)
start_date = '2025-07-01'
end_date = '2025-07-31'

load_new_data_from_dwd(start_date, end_date)

### Load groundwater station data

In [None]:


# Date range to load, set by the user. Format is DD.MM.YYYY, which differs
# from the YYYY-MM-DD format used in the DWD cell.
start_date, end_date = '01.06.2025', '30.06.2025'
load_new_data_from_wasserportal(start_date, end_date)
