## Fragen

1. Analyse Allgemein
    - Lässt sich ein Zusammenhang zwischen dem Wasserpegel von den Beobachtungsdaten und den Modelldaten erkennen?
    - Wie hoch ist die Korrelation zwischen dem Wasserpegel des Modells und dem Wasserpegel der Beobachtung?

2. Analyse Sturmfluten
    - Wie verhalten sich Wind und Windrichtung bei den unterschiedlichen Sturmfluten?
    - Wie verhalten sich Wassergeschwindigkeit und -richtung bei den unterschiedlichen Sturmfluten?
    - Lassen sich Korrelationen zwischen den Features und dem Wasserpegel (sla) erkennen? 

## Import Libraries

In [None]:
# import all necessary libraries
import os
import warnings
from pathlib import Path

import cartopy.feature as cfeature
import geodatasets
import geopandas as gpd
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shapely.geometry
import xarray as xr
from joblib import Parallel, delayed
from mpl_toolkits.basemap import Basemap
from scipy.interpolate import griddata
from statsmodels.graphics.tsaplots import plot_acf
from tqdm import tqdm
from utils.eda_helper_functions import plot_histogram

from utils.eda_helper_functions import (
    check_missing_times,
    group_data_hourly,
    load_insitu_data,
    load_ocean_data,
    load_weather_data,
    plot_water_level_anomalies,
    process_df,
    process_flensburg_data,
    show_df,
)
from utils.config import (
    LAT_FLENSBURG,
    LON_FLENSBURG,
    SUB_BOX,
    OCEAN_DICT,
    WEATHER_DICT,
    INSITU_DICT,
    OCEAN_POINTS,
    WEATHER_POINTS,
    )

# Silence pandas' SettingWithCopyWarning for the whole notebook.
warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)

# Never truncate the column list when a DataFrame is displayed.
pd.set_option("display.max_columns", None)

# Global matplotlib defaults: one common size for every text element
# (base font, axes titles, axis labels, tick labels, legend, figure title and
# figure-level labels), plus the DPI used for saving and for on-screen display.
_TEXT_SIZE = 18
_TEXT_SIZE_KEYS = (
    "font.size",
    "axes.titlesize",
    "axes.labelsize",
    "xtick.labelsize",
    "ytick.labelsize",
    "legend.fontsize",
    "figure.titlesize",
    "figure.labelsize",
)
plt.rcParams.update(dict.fromkeys(_TEXT_SIZE_KEYS, _TEXT_SIZE))
plt.rcParams.update({"savefig.dpi": 300, "figure.dpi": 72})


# Input folders; the folder name is derived from the configured point counts.
ocean_data_path = Path(f"../data/numerical_data/points{OCEAN_POINTS}")
weather_data_path = Path(f"../data/numerical_data/points{WEATHER_POINTS}")
print(ocean_data_path)
print(weather_data_path)

In [None]:

def plot_feature_distribution(df:pd.DataFrame, features:list, bins:int=50, save_png:bool=False, one_plot:bool=True, prefix:str="") -> None:
    """
    Plot a histogram for every requested feature of a DataFrame.

    Args:
        df (pd.DataFrame): DataFrame containing the features.
        features (list): Feature (column) names to plot. Any sized iterable of
            names is accepted, e.g. ``df.columns``.
        bins (int): Number of bins for the histogram.
        save_png (bool): Whether to save the figure(s) as PNG file(s).
        one_plot (bool): If True, all features are drawn into one figure with a
            two-column grid; otherwise one figure is created per feature.
        prefix (str): Text prepended to the file name of every saved PNG.

    Returns:
        None. Figures are shown with ``plt.show()`` and optionally saved.
    """
    features = list(features)
    if not features:
        # Nothing to draw. An empty grid would make GridSpec raise, and the
        # save path would have no feature name to refer to.
        return

    if one_plot:
        n_cols = 2
        n_rows = (len(features) + n_cols - 1) // n_cols  # ceiling division

        fig = plt.figure(figsize=(8 * n_cols, 5 * n_rows))
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)

        fig.suptitle("Feature Distribution", y=0.98)

        for idx, feature in tqdm(enumerate(features), total=len(features), desc="Plotting features"):
            ax = fig.add_subplot(gs[idx])
            plot_histogram(df, column=feature, bins=bins, ax=ax, show_stats=True)

        # Manual spacing instead of tight_layout; top=0.92 leaves room above
        # the grid for the suptitle placed at y=0.98.
        fig.subplots_adjust(top=0.92, hspace=0.5, wspace=0.3)
        if save_png:
            # The combined figure gets one fixed name. Previously the path was
            # built from the loop variable, so the file was named after
            # whichever feature happened to be last and `prefix` was ignored.
            fig.savefig(f"../figures/{prefix}feature_distribution.png", dpi=300, bbox_inches='tight')
    else:
        # Slightly larger fonts for the stand-alone, per-feature figures.
        single_plot_rc = {
            "font.size": 20,
            "axes.titlesize": 20,
            "axes.labelsize": 20,
            "xtick.labelsize": 20,
            "ytick.labelsize": 20,
            "legend.fontsize": 20,
            "figure.titlesize": 20,
            "figure.labelsize": 20,
            "savefig.dpi": 300,
            "figure.dpi": 72,
        }
        with plt.rc_context(single_plot_rc):
            for feature in tqdm(features, desc="Plotting features"):
                fig, ax = plt.subplots(figsize=(10, 6))

                # title=" " is passed through to the helper unchanged.
                plot_histogram(df, column=feature, bins=bins, ax=ax, show_stats=True, title=" ")
                if save_png:
                    fig.savefig(f"../thesis_plots/{prefix}distribution_{feature}.png", dpi=300, bbox_inches='tight')

    plt.show()
    plt.clf()

# Load The Data & IDA

## Ocean Data

| Feature | Unit | Description | Explanation |
|:---|:---|:---|:---|
| bottomT | °C | Sea water potential temperature at sea floor | Temperature of seawater at the ocean floor, accounting for pressure effects. |
| mlotst | m | Ocean mixed layer thickness defined by sigma theta | Depth of the ocean's surface layer where temperature and salinity are relatively uniform. |
| siconc | - | Sea ice area fraction | Fractional coverage of sea ice in a given area (0 = no ice, 1 = full coverage). |
| sithick | m | Sea ice thickness | Thickness of sea ice from surface to bottom. |
| sla | m | Sea surface height above sea level | Deviation of the ocean surface from the mean sea level, can indicate currents or tides. |
| so | $1 / 10^3$ | Sea water salinity | Salinity of seawater (measured dimensionless, typically expressed in parts per thousand or PSU). |
| sob | $1 / 10^3$| Sea water salinity at sea floor | Salinity of seawater at the ocean floor, normalized (0.001 units). |
| thetao | °C | Sea water potential temperature | Potential temperature of seawater, referenced to sea surface pressure. |
| uo | m/s | Eastward sea water velocity | Velocity component of seawater flow towards the east. |
| vo | m/s | Northward sea water velocity | Velocity component of seawater flow towards the north. |
| wo | m/s | Upward sea water velocity | Vertical velocity of seawater, positive upward. |


In [None]:
# Load the ocean data from the configured folder. The summary statistics are
# the last expression of the cell and therefore its displayed output.
df_ocean = load_ocean_data(ocean_data_path, OCEAN_POINTS, verbose=True)
df_ocean.describe()

In [None]:
# One histogram figure per ocean column (one_plot=False), each saved as PNG
# with the "ocean_" file-name prefix.
plot_feature_distribution(df_ocean, df_ocean.columns, bins=50, save_png=True, one_plot=False, prefix="ocean_")

In [None]:
# Run the shared preprocessing helper on the ocean data; "depth" is handed over
# as a column to drop (presumably removed from the frame - see process_df).
df_ocean = process_df(df_ocean, drop_cols=["depth"], verbose=True)

## Weather Data

| Feature | Unit | Description | Explanation |
|:---|:---|:---|:---|
| temperature_2m | °C | Temperature (2 m) | Air temperature at 2 meters above ground. |
| relative_humidity_2m | % | Relative Humidity (2 m) | Percentage of humidity at 2 meters height. |
| dew_point_2m | °C | Dewpoint (2 m) | Temperature at which air moisture condenses (dew point) at 2 meters height. |
| apparent_temperature | °C | Apparent Temperature | Perceived temperature considering wind and humidity. |
| precipitation_probability | % | Precipitation Probability | Probability of precipitation. |
| precipitation | mm | Precipitation (rain + showers + snow) | Total precipitation amount (rain, showers, snow). |
| rain | mm | Rain | Precipitation amount due to rain. |
| showers | mm | Showers | Precipitation amount due to showers. |
| snowfall | cm | Snowfall | Precipitation amount due to snow. |
| snow_depth | cm | Snow Depth | Total snow depth on the ground. |
| weather_code | - | Weather code | Classification of weather conditions by a code (e.g., sunny, cloudy). |
| pressure_msl | hPa | Sealevel Pressure | Atmospheric pressure reduced to sea level. |
| surface_pressure | hPa | Surface Pressure | Actual atmospheric pressure at the surface. |
| cloud_cover | % | Cloud cover Total | Total cloud coverage. |
| cloud_cover_low | % | Cloud cover Low | Cloud coverage by low-level clouds. |
| cloud_cover_mid | % | Cloud cover Mid | Cloud coverage by mid-level clouds. |
| cloud_cover_high | % | Cloud cover High | Cloud coverage by high-level clouds. |
| visibility | m | Visibility | Visibility distance. |
| evapotranspiration | mm | Evapotranspiration | Water loss through evaporation and plant transpiration. |
| et0_fao_evapotranspiration | mm | Reference Evapotranspiration (ET₀) | Standardized reference evapotranspiration according to FAO. |
| vapour_pressure_deficit | hPa | Vapour Pressure Deficit | Difference between saturation and actual vapor pressure. |
| wind_speed_10m | km/h | Wind Speed (10 m) | Wind speed at 10 meters above ground. |
| wind_speed_80m | km/h | Wind Speed (80 m) | Wind speed at 80 meters above ground. |
| wind_speed_120m | km/h | Wind Speed (120 m) | Wind speed at 120 meters above ground. |
| wind_speed_180m | km/h | Wind Speed (180 m) | Wind speed at 180 meters above ground. |
| wind_direction_10m | ° | Wind Direction (10 m) | Wind direction in degrees at 10 meters height (0° = North). |
| wind_direction_80m | ° | Wind Direction (80 m) | Wind direction in degrees at 80 meters height. |
| wind_direction_120m | ° | Wind Direction (120 m) | Wind direction in degrees at 120 meters height. |
| wind_direction_180m | ° | Wind Direction (180 m) | Wind direction in degrees at 180 meters height. |
| wind_gusts_10m | km/h | Wind Gusts (10 m) | Maximum gust wind speed at 10 meters height. |
| temperature_80m | °C | Temperature (80 m) | Air temperature at 80 meters above ground. |
| temperature_120m | °C | Temperature (120 m) | Air temperature at 120 meters above ground. |
| temperature_180m | °C | Temperature (180 m) | Air temperature at 180 meters above ground. |


In [None]:
# Load the weather data from the configured folder and plot one histogram
# figure per column, each saved as PNG with the "weather_" file-name prefix.
df_weather = load_weather_data(weather_data_path, WEATHER_POINTS, verbose=True)
plot_feature_distribution(df_weather, df_weather.columns, bins=50, save_png=True, one_plot=False, prefix="weather_")

In [None]:
# Run the shared preprocessing helper on the weather data (no columns dropped).
df_weather = process_df(df_weather, verbose=True)

## In Situ Data

| Feature   | Unit    | Description                     | Explanation                                                                 |
|:----------|:--------|:---------------------------------|:---------------------------------------------------------------------------|
| time      | -       | Timestamp                        | Date and time of the observation (UTC).                                    |
| depth     | m       | Measurement depth                | Depth below sea surface where the measurement was taken.                   |
| time_qc   | -       | Time quality control flag        | Quality control indicator for the timestamp (e.g., 1 = good).              |
| deph      | m       | Nominal depth                    | Nominal (intended) depth of the measurement, could differ from actual depth.|
| latitude  | degrees | Latitude                         | Geographic coordinate specifying north-south position.                     |
| longitude | degrees | Longitude                        | Geographic coordinate specifying east-west position.                       |
| slev      | m       | Sea level                        | Measured sea surface height relative to a reference Datum |
| slev_qc   | -       | Sea level quality control flag   | Quality control indicator for sea level measurement (e.g., 1 = good).       |


In [None]:
# NOTE: load_insitu_data is already imported in the notebook's first import
# cell; the import is repeated here so this cell can run on its own.
from utils.eda_helper_functions import load_insitu_data

# Load the in-situ observations and plot one histogram figure per column,
# each saved as PNG with the "insitu_" file-name prefix.
df_insitu = load_insitu_data(verbose=True)
plot_feature_distribution(df_insitu, df_insitu.columns, bins=50, save_png=True, one_plot=False, prefix="insitu_")

In [None]:
from utils.eda_helper_functions import process_flensburg_data

# Process the in-situ series using the minimum and maximum of df_ocean['time']
# as start and end time. `order=3` is forwarded to the helper; its meaning is
# not visible in this notebook (presumably an interpolation order - confirm).
df_insitu = process_flensburg_data(df_insitu, 
                                      start_time=df_ocean['time'].min(),
                                      end_time=df_ocean['time'].max(),
                                      verbose=True, order=3)

In [None]:
from utils.eda_helper_functions import group_data_hourly

# Aggregate the in-situ data with the hourly grouping helper, then run the
# shared preprocessing helper with "deph" handed over as a column to drop.
df_insitu = group_data_hourly(df_insitu)
df_insitu = process_df(df_insitu, drop_cols=["deph"], verbose=True)


In [None]:
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called directly; wrapping it in print() would append a stray "None".
df_insitu.info()
print(df_insitu.describe())

In [None]:
# NOTE(review): `dd` is not defined anywhere in the visible part of the
# notebook, so running this cell raises NameError. Presumably a stray keystroke
# or a deliberate "stop here" marker for Run All - confirm and remove if unintended.
dd

# EDA

### Plot Flensburg Observation Waterlevel Data

In [None]:
import datetime

from utils.eda_helper_functions import plot_water_level_anomalies

# Whole in-situ series first.
fig, ax = plot_water_level_anomalies(df_insitu)
plt.show()

# Timestamps around which the storm-surge windows are centred,
# given as (year, month, day, hour, minute).
sturm_surge_list = [
    datetime.datetime(*parts)
    for parts in (
        (2023, 2, 25, 17, 0),
        (2023, 4, 1, 12, 0),
        (2023, 10, 7, 20, 0),
        (2023, 10, 20, 0, 0),
        (2024, 1, 3, 9, 0),
        (2024, 2, 9, 18, 0),
        (2024, 12, 9, 16, 0),
    )
]

# Plot each event with three days of data on either side of its timestamp.
half_window = datetime.timedelta(days=3)
for time in sturm_surge_list:
    start_time, end_time = time - half_window, time + half_window
    df_insitu_sturm = df_insitu[df_insitu["time"].between(start_time, end_time)]
    plot_water_level_anomalies(df_insitu_sturm, start_date=start_time, end_date=end_time)
    

## Display Ocean and Weather Data

In [None]:
from utils.eda_helper_functions import plot_coordinates

# Draw the coordinates of the three data sets with the project helper;
# save_png=False, so no PNG is requested.
plot_coordinates(df_ocean, df_weather, df_insitu, save_png=False)

## Distributions of the Features

## Cluster df_ocean into K = 3

In [None]:
# Reload and reprocess all three data sets; this overwrites the df_ocean,
# df_weather and df_insitu objects produced by earlier cells.
df_ocean = load_ocean_data(ocean_data_path, OCEAN_POINTS, verbose=False)
df_ocean = process_df(df_ocean, drop_cols=["depth"], verbose=False)

df_weather = load_weather_data(weather_data_path, WEATHER_POINTS, verbose=False)
df_weather = process_df(df_weather, verbose=False)

# The in-situ series is processed with the minimum and maximum of
# df_ocean['time'] as start and end time.
# NOTE(review): no `order` argument is passed here, whereas the first in-situ
# processing cell passes order=3 - confirm the helper's default is intended.
df_insitu = load_insitu_data(verbose=False)
df_insitu = process_flensburg_data(df_insitu, 
                                      start_time=df_ocean['time'].min(),
                                      end_time=df_ocean['time'].max(),
                                      verbose=False)

df_insitu = group_data_hourly(df_insitu)
df_insitu = process_df(df_insitu, drop_cols=["deph"], verbose=False)

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap

def cluster_df(df:pd.DataFrame, columns:list, n_clusters:int=3, display:bool=False, start_date=None, end_date=None):
    """
    Cluster the grid points of a DataFrame with KMeans.

    The rows are optionally restricted to a time window, averaged per
    (latitude, longitude) pair, and the selected feature columns are min-max
    scaled before clustering.

    Args:
        df (pd.DataFrame): Data to cluster. Must contain 'latitude' and
            'longitude'; a 'time' column is required when a date filter is used.
        columns (list): Columns used as clustering features. "time" is always
            ignored. If None or empty, all columns of ``df`` (except "time")
            are used.
        n_clusters (int): Number of clusters.
        display (bool): If True, print the cluster sizes and draw the clustered
            points on a Basemap map.
        start_date: Optional lower bound (inclusive) for the 'time' column.
        end_date: Optional upper bound (inclusive) for the 'time' column.

    Returns:
        pd.DataFrame: One row per (latitude, longitude) pair holding the
        averaged values plus a 'cluster' column. Rows with a missing value in
        any feature column are left out.
    """
    df = df.copy()

    if start_date is not None:
        df = df[df['time'] >= start_date]
    if end_date is not None:
        df = df[df['time'] <= end_date]

    # Honour the caller's feature selection. Previously `columns` was
    # overwritten with every column of the frame, so the argument had no effect.
    if columns is None or len(columns) == 0:
        candidate_cols = df.columns.tolist()
    else:
        candidate_cols = list(columns)  # copy: never modify the caller's list
    # KMeans cannot work with timestamps, so "time" is never a feature.
    feature_cols = [col for col in candidate_cols if col != "time"]

    # Average over time so that every grid point is represented by one row.
    df = df.groupby(['latitude', 'longitude']).mean().reset_index()

    # KMeans rejects NaN; keep only rows that are complete in the features.
    X = df[feature_cols].dropna()

    # Scale every feature to [0, 1].
    scaler = MinMaxScaler()
    X_scaled = scaler.fit_transform(X)

    # KMeans clustering (k = n_clusters) with a fixed seed.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(X_scaled)

    # Attach the labels to the rows that were actually clustered.
    df_clustered = df.loc[X.index].copy()
    df_clustered['cluster'] = clusters

    if display:
        # Centre the map on the mean coordinate of the clustered points.
        mean_lat = df_clustered['latitude'].mean()
        mean_lon = df_clustered['longitude'].mean()

        plt.figure(figsize=(10, 8))
        m = Basemap(
            projection='lcc',
            resolution='i',
            lat_0=mean_lat,
            lon_0=mean_lon,
            width=1.2e6,
            height=1.2e6,
        )

        m.drawcoastlines()
        m.drawcountries()
        m.drawmapboundary(fill_color='lightblue')
        m.fillcontinents(color='beige', lake_color='lightblue')

        # Colour palette; reused cyclically when n_clusters exceeds its length.
        colors = ['red', 'green', 'orange', 'purple', 'pink']

        print(df_clustered['cluster'].value_counts())
        for cluster_id in range(n_clusters):
            cluster_data = df_clustered[df_clustered['cluster'] == cluster_id]
            print(f"Cluster {cluster_id}: {len(cluster_data)} points")
            # Convert (lon, lat) to map coordinates before plotting.
            x, y = m(cluster_data['longitude'].values, cluster_data['latitude'].values)
            m.scatter(x, y, s=10, c=colors[cluster_id % len(colors)], label=f'Cluster {cluster_id}', alpha=0.6)

        plt.legend(loc='upper left')
        plt.title(f'KMeans Clustering (k={n_clusters}) der Ozeandaten')
        plt.show()

    return df_clustered


# Cluster the ocean grid points into three groups and draw them on a map
# (display=True); the labelled frame is kept in df_clustered.
df_clustered = cluster_df(df_ocean, ['latitude', 'longitude', 'sla'], n_clusters=3, display=True)




## Comparison of SLEV and SLA



In [None]:
# Reload and reprocess all three data sets; this overwrites the df_ocean,
# df_weather and df_insitu objects produced by earlier cells.
df_ocean = load_ocean_data(ocean_data_path, OCEAN_POINTS, verbose=False)
df_ocean = process_df(df_ocean, drop_cols=["depth"], verbose=False)

df_weather = load_weather_data(weather_data_path, WEATHER_POINTS, verbose=False)
df_weather = process_df(df_weather, verbose=False)

# The in-situ series is processed with the minimum and maximum of
# df_ocean['time'] as start and end time.
# NOTE(review): no `order` argument is passed here, whereas the first in-situ
# processing cell passes order=3 - confirm the helper's default is intended.
df_insitu = load_insitu_data(verbose=False)
df_insitu = process_flensburg_data(df_insitu, 
                                      start_time=df_ocean['time'].min(),
                                      end_time=df_ocean['time'].max(),
                                      verbose=False)

df_insitu = group_data_hourly(df_insitu)
df_insitu = process_df(df_insitu, drop_cols=["deph"], verbose=False)

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import pandas as pd


# find the closest location in df_ocean to the target location
def find_closest_location(df: pd.DataFrame, target_lat: float, target_lon: float) -> pd.Series:
    """
    Find the row of the DataFrame that is closest to a target coordinate.

    The distance is the plain Euclidean distance in degrees between
    (latitude, longitude) pairs. It is not a geodesic distance: one degree of
    longitude is treated like one degree of latitude.

    The input DataFrame is left untouched; the distance is attached only to
    the returned row.

    Parameters:
        df (pd.DataFrame): DataFrame containing the data with 'latitude' and 'longitude'.
        target_lat (float): Target latitude.
        target_lon (float): Target longitude.

    Returns:
        pd.Series: The row of the closest location, with an additional
        'distance' entry holding its distance to the target.
    """
    # Computed as a separate Series so the caller's frame gets no new column.
    distance = np.sqrt((df["latitude"] - target_lat) ** 2 + (df["longitude"] - target_lon) ** 2)
    closest_idx = distance.idxmin()

    closest = df.loc[closest_idx].copy()
    closest["distance"] = distance.loc[closest_idx]
    return closest


In [None]:
import matplotlib.pyplot as plt
import pandas as pd


def plot_closest_location(df_ocean: pd.DataFrame, sub_box:dict, target_lat: float, target_lon: float, save:bool=False) -> None:
    """
    Plot the ocean data points, the target point and the points inside a
    bounding box on a Basemap map.

    Args:
        df_ocean (pd.DataFrame): Ocean data with 'latitude' and 'longitude' columns.
        sub_box (dict): Bounding box with the keys 'lat_min', 'lat_max',
            'lon_min' and 'lon_max' (bounds are inclusive). Points inside the
            box are highlighted as "Closest Points".
        target_lat (float): Latitude of the target point (drawn as a star).
        target_lon (float): Longitude of the target point.
        save (bool): If True, save the figure to
            ../thesis_plots/closest_location_ocean_data.png.
    """
    # The same coordinate can occur in many rows (presumably once per time
    # step). Drawing each distinct coordinate once gives the same picture with
    # far fewer markers to render and to store in the saved file.
    coords = df_ocean[["latitude", "longitude"]].drop_duplicates()

    inside_box = (
        coords["latitude"].between(sub_box["lat_min"], sub_box["lat_max"])
        & coords["longitude"].between(sub_box["lon_min"], sub_box["lon_max"])
    )
    coords_in_box = coords[inside_box]

    # Map centre: mean coordinate over all rows of the ocean data.
    mean_lat = df_ocean["latitude"].mean()
    mean_lon = df_ocean["longitude"].mean()

    plt.figure(figsize=(12, 10))

    m = Basemap(
        projection='lcc',
        resolution='i',
        lat_0=mean_lat,
        lon_0=mean_lon,
        width=1.2e6,
        height=1.2e6,
    )

    # Background map elements.
    m.drawcoastlines()
    m.drawcountries()
    m.drawstates()
    m.drawmapboundary(fill_color="lightblue")
    m.fillcontinents(color="lightgray", lake_color="lightblue")

    # Convert (lon, lat) to map coordinates (x, y).
    x_ocean, y_ocean = m(coords["longitude"].values, coords["latitude"].values)
    x_target, y_target = m(target_lon, target_lat)
    x_closest, y_closest = m(coords_in_box["longitude"].values, coords_in_box["latitude"].values)

    # Draw order matters: box points are drawn last so they cover the blue ones.
    m.scatter(x_ocean, y_ocean, color="blue", label="Ocean Data", s=50)
    m.scatter(x_target, y_target, color="green", marker="*", label="Target Point", s=250)
    m.scatter(x_closest, y_closest, color="orange", marker="o", label="Closest Points", s=50)

    # Optional: outline the bounding box (currently disabled).
    # x_box = [sub_box["lon_min"], sub_box["lon_max"], sub_box["lon_max"], sub_box["lon_min"], sub_box["lon_min"]]
    # y_box = [sub_box["lat_min"], sub_box["lat_min"], sub_box["lat_max"], sub_box["lat_max"], sub_box["lat_min"]]
    # x_box, y_box = m(x_box, y_box)
    # m.plot(x_box, y_box, color="green", linestyle="--", linewidth=2, label="Region of Interest")

    plt.legend(loc="upper right")
    # Graticule every 2 degrees; parallels labelled on the left, meridians at the bottom.
    m.drawparallels(np.arange(-360, 360, 2), labels=[1, 0, 0, 0])
    m.drawmeridians(np.arange(-360, 360, 2), labels=[0, 0, 0, 1])
    if save:
        plt.savefig("../thesis_plots/closest_location_ocean_data.png", dpi=300, bbox_inches='tight')
    plt.show()

# Bounding box (degrees) whose ocean points are highlighted on the map; the
# trailing comments hold alternative bounds (presumably earlier settings).
sub_box = {
        "lat_min": 54.5, # 54.4
        "lat_max": 55.1, # 55.5
        "lon_min": 9.9,
        "lon_max": 10.1 # 10.5
        }
# Draw the map around the Flensburg coordinates and save it as PNG.
plot_closest_location(df_ocean=df_ocean, sub_box=sub_box, target_lat=LAT_FLENSBURG, target_lon=LON_FLENSBURG, save=True)



In [None]:
def comparison_slev_sla(df_ocean: pd.DataFrame, df_insitu: pd.DataFrame, sub_box: dict, target_lat: float = 54.5, target_lon: float = 10.0, save: bool = False) -> None:
    """
    Compare observed sea level (SLEV) with modelled sea level anomaly (SLA).

    The ocean data is restricted to the bounding box and averaged per time
    step. Pearson correlation and RMSE are computed on the timestamps both
    series share. Both raw series and their rolling means are plotted.

    Args:
        df_ocean (pd.DataFrame): Ocean data with 'time', 'latitude',
            'longitude' and 'sla' columns.
        df_insitu (pd.DataFrame): In-situ data with 'time' and 'slev' columns.
        sub_box (dict): Bounding box with the keys 'lat_min', 'lat_max',
            'lon_min' and 'lon_max' (bounds are inclusive).
        target_lat (float): Accepted for interface compatibility; not used.
        target_lon (float): Accepted for interface compatibility; not used.
        save (bool): If True, save the figure to
            ../data/plots/comparison_slev_sla.png.
    """
    # Keep only the ocean points inside the bounding box.
    inside_box = (
        (df_ocean["latitude"] >= sub_box["lat_min"]) &
        (df_ocean["latitude"] <= sub_box["lat_max"]) &
        (df_ocean["longitude"] >= sub_box["lon_min"]) &
        (df_ocean["longitude"] <= sub_box["lon_max"])
    )
    df_ocean_target = df_ocean[inside_box].reset_index(drop=True)

    # Spatial mean over the box for every time step.
    df_ocean_target = df_ocean_target.groupby("time").mean().reset_index()

    # Pearson correlation and RMSE, aligned on the timestamp. The previous
    # version converted the positional RangeIndex to datetimes, which paired
    # the rows by position and gave wrong metrics whenever the two series did
    # not start at the same time or had gaps.
    sla_by_time = df_ocean_target.set_index("time")[["sla"]]
    slev_by_time = df_insitu.set_index("time")[["slev"]]
    sla_by_time.index = pd.to_datetime(sla_by_time.index)
    slev_by_time.index = pd.to_datetime(slev_by_time.index)
    df_corr = pd.merge(sla_by_time, slev_by_time, left_index=True, right_index=True, how='inner')
    corr = df_corr['slev'].corr(df_corr['sla'])
    rmse = np.sqrt(np.mean((df_corr['slev'] - df_corr['sla']) ** 2))

    # Plot settings: opacity of the smoothed lines and rolling window in rows
    # (24 rows correspond to one day if the data is hourly).
    alpha = 0.8
    window_size = 24

    rolling_mean_slev = df_insitu["slev"].rolling(window=window_size).mean()
    rolling_mean_sla = df_ocean_target["sla"].rolling(window=window_size).mean()
    plt.figure(figsize=(12, 6))
    # Raw series, faint in the background.
    plt.plot(df_insitu["time"], df_insitu["slev"], label="water level Flensburg (SLEV)", color="#0072B2", alpha=0.1)
    # The raw SLA line shows the box mean, i.e. the same series the metrics and
    # the rolling mean use, not every grid point of the full ocean data set.
    plt.plot(df_ocean_target["time"], df_ocean_target["sla"], label="water level - closests point (SLA)", color="#E69F00", alpha=0.1)
    # Smoothed series on top.
    plt.plot(df_insitu["time"], rolling_mean_slev, label="water level Flensburg (SLEV) Rolling Mean", color="#0072B2", linestyle="-", alpha=alpha)
    plt.plot(df_ocean_target["time"], rolling_mean_sla, label="water level - closest points (SLA) Rolling Mean", color="#E69F00", linestyle="-", alpha=alpha)

    # The metrics are reported in the title.
    plt.title(f"Comparison of SLEV and SLA \n Correlation: {corr:.2f}, RMSE: {rmse:.2f}")
    plt.xlabel("Time")
    plt.ylabel("Water Level [m]")
    plt.legend()
    plt.grid(True)
    if save:
        plt.savefig("../data/plots/comparison_slev_sla.png", dpi=300, bbox_inches='tight')
    plt.show()


    # Autocorrelation of SLEV and SLA for different lags (currently disabled).

    # df_corr_insitu = df_corr_insitu.loc[(df_corr_insitu['time'] >= df_corr_ocean['time'].min()) & (df_corr_insitu['time'] <= df_corr_ocean['time'].max())]
    # fig, ax = plt.subplots(2, 1, figsize=(12, 8))
    # plot_acf(df_insitu["slev"], lags=100, ax=ax[0])
    # ax[0].set_title("Autocorrelation of SLEV")
    # plot_acf(df_ocean_target["sla"], lags=100, ax=ax[1])
    # ax[1].set_title("Autocorrelation of SLA")
    # plt.tight_layout()
    # plt.show()


# Compare observed and modelled water level for the points inside sub_box;
# save=False, so the figure is only shown.
comparison_slev_sla(df_ocean=df_ocean, df_insitu=df_insitu, target_lat=LAT_FLENSBURG, target_lon=LON_FLENSBURG, sub_box=sub_box, save=False)

Beide Zeitreihen zeigen eine starke Autokorrelation, was typisch für ozeanografische Zeitreihen ist. Die realen Daten (SLEV) wirken jedoch etwas unregelmäßiger und zeigen vermutlich tidenbedingte Oszillationen, die das Modell (SLA) nur gedämpft oder geglättet abbildet.

## Wie verhalten sich Wind und Windrichtung bei den unterschiedlichen Sturmfluten?

In [None]:
# Reload and reprocess all three data sets; this overwrites the df_ocean,
# df_weather and df_insitu objects produced by earlier cells.
df_ocean = load_ocean_data(ocean_data_path, OCEAN_POINTS, verbose=False)
df_ocean = process_df(df_ocean, drop_cols=["depth"], verbose=False)

df_weather = load_weather_data(weather_data_path, WEATHER_POINTS, verbose=False)
df_weather = process_df(df_weather, verbose=False)

# The in-situ series is processed with the minimum and maximum of
# df_ocean['time'] as start and end time.
# NOTE(review): no `order` argument is passed here, whereas the first in-situ
# processing cell passes order=3 - confirm the helper's default is intended.
df_insitu = load_insitu_data(verbose=False)
df_insitu = process_flensburg_data(df_insitu, 
                                      start_time=df_ocean['time'].min(),
                                      end_time=df_ocean['time'].max(),
                                      verbose=False)

df_insitu = group_data_hourly(df_insitu)
df_insitu = process_df(df_insitu, drop_cols=["deph"], verbose=False)


In [None]:
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shapely.geometry
import xarray as xr
from joblib import Parallel, delayed
from scipy.interpolate import griddata
from tqdm import tqdm


# Land check based on Cartopy's Natural Earth land polygons
def is_on_land(lon, lat):
    """Return True if (lon, lat) lies inside any Natural Earth 10m land geometry."""
    land_geometries = cfeature.NaturalEarthFeature("physical", "land", "10m").geometries()
    return any(
        geometry.contains(shapely.geometry.Point(lon, lat))
        for geometry in land_geometries
    )


# Build the land mask for a regular lon/lat grid
def create_land_mask(lon_grid, lat_grid):
    """
    Return a boolean array of shape (len(lat_grid), len(lon_grid)) that is
    True where the grid point is NOT on land.

    The points are checked in parallel on all available cores (n_jobs=-1).
    """
    def _not_on_land(point):
        return not is_on_land(*point)

    # Row-major order: latitude is the outer loop, longitude the inner one,
    # which matches the reshape below.
    grid_points = [(lon, lat) for lat in lat_grid for lon in lon_grid]
    jobs = (delayed(_not_on_land)(point) for point in tqdm(grid_points))
    flags = Parallel(n_jobs=-1)(jobs)

    n_rows, n_cols = len(lat_grid), len(lon_grid)
    return np.array(flags).reshape(n_rows, n_cols)




In [None]:
import matplotlib.dates as mdates
import matplotlib.pyplot as plt


def _draw_horizontal_bands(ax, bands, alpha=0.3):
    """Shade one labelled horizontal band per ``(y_low, y_high, color, label)`` tuple."""
    for y_low, y_high, color, label in bands:
        ax.axhspan(y_low, y_high, facecolor=color, alpha=alpha, label=label)


def lineplot_storm_surge(df, column_name, timepoints, 
                         ax:plt.Axes=None, display=True, 
                         calculation:str='mean', show_number_lines:bool=True, 
                         title:str=None, display_sturm_surge:bool=False,
                         show_legend:bool=False,
                         save:bool=False,
                         legend_outside:bool=True
                         ):
    """Plot one column as a time series around a set of storm surge timepoints.

    The data is restricted to 62 hours before the earliest and 62 hours after
    the latest entry of ``timepoints``. Columns listed in ``WEATHER_DICT`` or
    ``OCEAN_DICT`` are additionally restricted to ``SUB_BOX`` and aggregated
    per timestamp.

    Parameters
    ----------
    df : DataFrame with a "time" column and ``column_name`` (plus
        "latitude"/"longitude" for weather and ocean columns).
    column_name : name of the column to plot.
    timepoints : sequence of timestamps; each is marked with a numbered
        vertical line when ``show_number_lines`` is True.
    ax : axes to draw on; a new 20x5 figure is created when None.
    display : call ``plt.show()`` at the end.
    calculation : 'mean' or 'max' aggregation over the sub box. Any other
        value leaves the data un-aggregated.
    show_number_lines : draw the numbered vertical markers.
    title : accepted for backward compatibility; no axes title is drawn.
    display_sturm_surge : shade the time spans in which the module-level
        ``df_insitu["slev"]`` exceeds 1 between the first and last timepoint.
    show_legend : draw a legend.
    save : save the figure below ``../data/plots/``.
    legend_outside : anchor the legend outside the axes (right side).

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 5))

    timepoints_array = np.array(timepoints)
    timepoint_min = timepoints_array.min()
    timepoint_max = timepoints_array.max()

    # Restrict the data to the event window plus padding on both sides
    padding = pd.Timedelta(hours=62)
    df_plot = df.loc[
        (df["time"] >= timepoint_min - padding) &
        (df["time"] <= timepoint_max + padding)
    ].reset_index(drop=True)

    if column_name in WEATHER_DICT or column_name in OCEAN_DICT:
        # Gridded data: keep only points inside SUB_BOX, then collapse to one value per time
        in_sub_box = (
            (df_plot["latitude"] >= SUB_BOX["lat_min"]) &
            (df_plot["latitude"] <= SUB_BOX["lat_max"]) &
            (df_plot["longitude"] >= SUB_BOX["lon_min"]) &
            (df_plot["longitude"] <= SUB_BOX["lon_max"])
        )
        df_plot = df_plot[in_sub_box].reset_index(drop=True)

        if calculation == 'mean':
            df_plot = df_plot.groupby("time").mean().reset_index()
        elif calculation == 'max':
            df_plot = df_plot.groupby("time").max().reset_index()

    # Highlight the time spans where the in-situ water level is above 1
    if display_sturm_surge:
        df_insitu_plot = df_insitu.loc[
            df_insitu["time"].between(timepoint_min, timepoint_max)
        ].reset_index(drop=True)
        ax.fill_between(df_insitu_plot["time"], df_plot[column_name].min(), df_plot[column_name].max() , where=(df_insitu_plot["slev"] > 1), color='lightgrey', alpha=0.5, label='storm surge event')

    if column_name in INSITU_DICT:
        # Storm surge classes as colored bands
        _draw_horizontal_bands(ax, [
            (1.0, 1.25, 'yellow', 'storm surge'),
            (1.25, 1.5, 'orange', 'medium storm surge'),
            (1.5, 2.0, 'red', 'heavy storm surge'),
            (2.0, 3.5, 'darkred', 'very heavy storm surge'),
        ])
        ax.set_ylim(df_plot['slev'].min() - 0.1, df_plot['slev'].max() + 0.1)

    if column_name == 'wind_direction_10m':
        # Compass sectors as colored bands (north wraps around 0/360)
        _draw_horizontal_bands(ax, [
            (0, 45, '#0000FF', 'north'),
            (45, 135, '#FFFF00', 'east'),
            (135, 225, '#FF0000', 'south'),
            (225, 315, '#FFA500', 'west'),
            (315, 360, '#0000FF', 'north')
        ])

    # Water velocity components: one band for positive, one for negative values
    current_bands = {
        "vo": (('#0000FF', 'north direction'), ('#FF0000', 'south direction')),
        "uo": (('#FFFF00', 'east direction'), ('#FFA500', 'west direction')),
        "wo": (('#00CED1', 'upward direction'), ('#2F4F4F', 'downward direction')),
    }
    if column_name in current_bands:
        (pos_color, pos_label), (neg_color, neg_label) = current_bands[column_name]
        ax.axhspan(0, df_plot[column_name].max(), facecolor=pos_color, alpha=0.3, label=pos_label)
        ax.axhspan(df_plot[column_name].min(), 0, facecolor=neg_color, alpha=0.3, label=neg_label)

    if column_name == "wind_speed_10m" or column_name == "wind_gusts_10m":
        # Classification based on Beaufort scale (km/h) with color bands
        _draw_horizontal_bands(ax, [
            (0, 5, '#00FF00', 'calm (0)'),                   # bright green
            (5, 11, '#32CD32', 'light air (1)'),             # lime green
            (11, 19, '#7FFF00', 'light breeze (2)'),         # chartreuse
            (19, 28, '#ADFF2F', 'gentle breeze (3)'),        # green-yellow
            (28, 38, '#FFFF00', 'moderate breeze (4)'),      # yellow
            (38, 49, '#FFD700', 'fresh breeze (5)'),         # gold
            (49, 61, '#FFA500', 'strong breeze (6)'),        # orange
            (61, 74, '#FF8C00', 'near gale (7)'),            # dark orange
            (74, 88, '#FF4500', 'gale (8)'),                 # orange red
            (88, 102, '#FF0000', 'strong gale (9)'),         # red
            (102, 117, '#B22222', 'storm (10)'),             # firebrick
            (117, 150, '#8B0000', 'violent storm (11–12+)')  # dark red
        ])
        ax.set_ylim(df_plot[column_name].min() - 0.5, df_plot[column_name].max() + 0.5)

    # The actual time series
    x = df_plot["time"]
    y = df_plot[column_name]
    ax.plot(x, y, label=column_name, color="royalblue", linewidth=2)

    # Numbered vertical markers at the requested timepoints
    if show_number_lines:
        y_mean = y.mean()
        for i, t in enumerate(timepoints):
            ax.axvline(x=t, ymin=0, ymax=1, color="crimson", linestyle="--", alpha=0.6)
            ax.annotate(f"{i+1}",
                    xy=(t, y_mean),
                    xytext=(t, y_mean),
                    xycoords="data", 
                    ha="left", 
                    va="top", 
                    color="black", 
                    bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="white"),
                    zorder=10
                    )

    # X axis: one labelled tick per day, rotated. The tick labels of *this*
    # axes are styled explicitly instead of relying on pyplot's current axes.
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d %H:%M"))
    ax.xaxis.set_major_locator(mdates.HourLocator(interval=24)) 
    plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

    # Y label with the unit from whichever metadata dict knows the column
    metadata = next(
        (d for d in (INSITU_DICT, OCEAN_DICT, WEATHER_DICT) if column_name in d),
        None,
    )
    if metadata is not None:
        y_label = f"{column_name} [{metadata[column_name]['unit']}]"
    else:
        y_label = f"{column_name}"
    ax.set_ylabel(f"{y_label}")

    # Grid and legend
    ax.grid(True, linestyle="--", alpha=0.5)
    if show_legend and legend_outside:
        ax.legend(
                framealpha=0.0,
                facecolor='white',
                ncol=1,
                fancybox=True,
                shadow=False,
                borderaxespad=0.3,
                bbox_to_anchor=(1.01, 1),
            )
    if show_legend and not legend_outside:
        ax.legend()

    if save:
        # Build the path once so the log message matches the file that is written
        plot_path = (
            f"../data/plots/storm_surge_line_plot_{column_name}_"
            f"{timepoints[0].strftime('%Y%m%d')}-{timepoints[-1].strftime('%Y%m%d')}.png"
        )
        ax.figure.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot as {plot_path}")

    if display:
        plt.show()

    return ax




In [None]:
# df_ocean_target = df_ocean[
#             (df_ocean["latitude"] >= sub_box["lat_min"]) &
#             (df_ocean["latitude"] <= sub_box["lat_max"]) &
#             (df_ocean["longitude"] >= sub_box["lon_min"]) &
#             (df_ocean["longitude"] <= sub_box["lon_max"])
#         ].reset_index(drop=True)

# # Calculate the mean SLA 
# df_ocean_target = df_ocean_target.groupby("time").mean().reset_index()

# # Calculating Pearson correlation of SLEV and SLA
# df_corr_ocean = df_ocean_target.copy()
# df_corr_insitu = df_insitu.copy()

# df_corr_ocean.index = pd.to_datetime(df_corr_ocean.index)
# df_corr_insitu.index = pd.to_datetime(df_corr_insitu.index)
# df_corr = pd.merge(df_corr_ocean[['sla']], df_corr_insitu[['slev']], left_index=True, right_index=True, how='inner')
# corr = df_corr['slev'].corr(df_corr['sla'])
# rmse = np.sqrt(np.mean((df_corr['slev'] - df_corr['sla']) ** 2))

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import scipy

def grid_to_xarray(key, grid, lon_grid, lat_grid):
    """Wrap a 2-D grid in an ``xarray.Dataset``.

    The dataset holds a single variable named after ``key`` with dimensions
    ("latitude", "longitude"); ``lat_grid`` and ``lon_grid`` become the
    coordinates of those dimensions.
    """
    data_vars = {f"{key}": (("latitude", "longitude"), grid)}
    coords = {"latitude": lat_grid, "longitude": lon_grid}
    return xr.Dataset(data_vars, coords=coords)


def _quiver_stride(grid_size):
    """Return the arrow sub-sampling step for an ocean grid with ``grid_size`` points per axis."""
    stride_table = (
        (100, 1), (200, 4), (300, 6), (400, 8),
        (500, 12), (600, 16), (700, 20), (800, 24),
    )
    for upper_bound, stride in stride_table:
        if grid_size <= upper_bound:
            return stride
    # Grids beyond the table used to leave the stride undefined (NameError);
    # continue with roughly the same arrow density as the last table entry.
    return max(24, round(grid_size * 0.03))


def plot_for_timepoint(timepoint, ax=None, grid_size_ocean=50, 
                       wind_grid_size=20, vmin=-1.0, vmax=1.5, 
                       plot_water_velocity_data=True, plot_wind_data=True, 
                       number=None, cluster:bool=False, timepoints:list=None,
                       show_cluster:bool=False,
                       ):
    """Draw a map of sea level anomaly, currents, wind and pressure for one timestamp.

    Reads the module-level ``df_ocean`` and ``df_weather`` frames, interpolates
    their values for ``timepoint`` onto regular grids and draws them on a
    Basemap: SLA as a heatmap (land masked out), water velocity and wind as
    arrows, and ``pressure_msl`` (if that column exists) as contour lines.

    Parameters
    ----------
    timepoint : timestamp to select from the "time" columns.
    ax : axes to draw on; a new 12x10 figure is created when None.
    grid_size_ocean : number of grid points per axis for SLA and currents.
    wind_grid_size : number of grid points per axis for wind and pressure.
    vmin, vmax : color range of the SLA heatmap.
    plot_water_velocity_data : draw current arrows and report their mean speed.
    plot_wind_data : draw wind arrows and report the mean wind speed.
    number : zero-based panel index; ``number + 1`` is shown in the top-left corner.
    cluster : derive the box of interest from ``cluster_df`` instead of ``SUB_BOX``.
    timepoints : event timestamps; required when ``cluster`` is True.
    show_cluster : forwarded to ``cluster_df`` as ``display``.

    Returns
    -------
    The object returned by ``Basemap.pcolormesh`` for the SLA heatmap.

    Raises
    ------
    ValueError
        If ``cluster`` is True without ``timepoints`` or if there is no ocean
        or weather data for ``timepoint``.
    """
    if cluster and timepoints is None:
        raise ValueError("cluster=True requires 'timepoints' to be given.")

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 10))
    df_weather_time = df_weather[df_weather["time"] == timepoint]
    df_ocean_time = df_ocean[df_ocean["time"] == timepoint]

    if df_ocean_time.empty:
        raise ValueError(f"No ocean data available for timepoint {timepoint}.")
    if df_weather_time.empty:
        raise ValueError(f"No weather data available for timepoint {timepoint}.")

    # Box-of-interest statistics; they stay None when the corresponding data is not requested
    mean_wind_speed = None
    mean_ocean_velocity = None
    df_cluster = None

    if cluster:
        timepoints = sorted(timepoints)
        print(f"Clustering for timepoints: {timepoints}")
        df_cluster = cluster_df(df_ocean, ['latitude', 'longitude', 'sla'], n_clusters=3, display=show_cluster, start_date=timepoints[0], end_date=timepoints[-1])

        # Keep only the cluster that contains the point closest to (54.8, 10).
        # NOTE(review): these coordinates are hard-coded rather than taken from
        # LAT_FLENSBURG / LON_FLENSBURG - confirm this is intended.
        closest_point = find_closest_location(df_cluster, target_lat=54.8, target_lon=10)
        df_cluster = df_cluster[df_cluster["cluster"] == closest_point["cluster"]].reset_index(drop=True)

        # Bounding box of the selected cluster
        cluster_lat_min, cluster_lat_max = df_cluster["latitude"].min(), df_cluster["latitude"].max()
        cluster_lon_min, cluster_lon_max = df_cluster["longitude"].min(), df_cluster["longitude"].max()

        df_weather_time_cluster = df_weather_time[
            (df_weather_time["latitude"] >= cluster_lat_min) &
            (df_weather_time["latitude"] <= cluster_lat_max) &
            (df_weather_time["longitude"] >= cluster_lon_min) &
            (df_weather_time["longitude"] <= cluster_lon_max)
        ].reset_index(drop=True)

        mean_wind_speed = df_weather_time_cluster["wind_speed_10m"].mean()
        print(f"Mean wind speed in cluster: {mean_wind_speed}")

        df_ocean_time_cluster = df_ocean_time[
            (df_ocean_time["latitude"] >= cluster_lat_min) &
            (df_ocean_time["latitude"] <= cluster_lat_max) &
            (df_ocean_time["longitude"] >= cluster_lon_min) &
            (df_ocean_time["longitude"] <= cluster_lon_max)
        ].reset_index(drop=True)

        # Horizontal current speed from its two components
        mean_ocean_velocity = np.sqrt(
            df_ocean_time_cluster["uo"]**2 + df_ocean_time_cluster["vo"]**2
        ).mean()
        print(f"Mean ocean velocity in cluster: {mean_ocean_velocity}")

    # Regular ocean grid spanning the available data
    lon_grid = np.linspace(df_ocean_time["longitude"].min(), df_ocean_time["longitude"].max(), grid_size_ocean)
    lat_grid = np.linspace(df_ocean_time["latitude"].min(), df_ocean_time["latitude"].max(), grid_size_ocean)
    lon_mesh, lat_mesh = np.meshgrid(lon_grid, lat_grid)
    ocean_points = (df_ocean_time["longitude"], df_ocean_time["latitude"])

    sla_grid = griddata(ocean_points, df_ocean_time["sla"], (lon_mesh, lat_mesh), method="linear")

    m = Basemap(
        projection="cyl",
        resolution="i",
        llcrnrlon=lon_grid.min(),
        urcrnrlon=lon_grid.max(),
        llcrnrlat=lat_grid.min(),
        urcrnrlat=lat_grid.max(),
        ax=ax,
    )
    m.fillcontinents(color="grey", lake_color="white", alpha=0.5)
    m.drawcoastlines()
    m.drawcountries()

    # Target point (argument order is longitude, latitude)
    x_target, y_target = m(LON_FLENSBURG, LAT_FLENSBURG)
    m.scatter(x_target, y_target, color="green", marker="*", label="Flensburg", s=200) 

    # Outline of the box of interest
    if cluster:
        box_lon_min, box_lon_max = df_cluster["longitude"].min(), df_cluster["longitude"].max()
        # The lower edge is drawn 0.1 degrees above the cluster's minimum latitude
        box_lat_min, box_lat_max = df_cluster["latitude"].min() + 0.1, df_cluster["latitude"].max()
    else:
        box_lon_min, box_lon_max = SUB_BOX["lon_min"], SUB_BOX["lon_max"]
        box_lat_min, box_lat_max = SUB_BOX["lat_min"], SUB_BOX["lat_max"]
    x_box = [box_lon_min, box_lon_max, box_lon_max, box_lon_min, box_lon_min]
    y_box = [box_lat_min, box_lat_min, box_lat_max, box_lat_max, box_lat_min]
    x_box, y_box = m(x_box, y_box)
    ax.plot(x_box, y_box, color="green", linestyle="--", linewidth=2, label="box of interest ")

    # Mask out land (mask is True over water) and draw the SLA heatmap
    mask = create_land_mask(lon_grid, lat_grid)
    sla_grid[~mask] = np.nan

    x_mesh, y_mesh = m(lon_mesh, lat_mesh)
    heatmap = m.pcolormesh(x_mesh, y_mesh, sla_grid, cmap="magma", shading="auto", vmin=vmin, vmax=vmax)

    # Water velocity: eastward (uo) and northward (vo) components
    if plot_water_velocity_data:
        water_uo = griddata(ocean_points, df_ocean_time["uo"], (lon_mesh, lat_mesh), method="linear")
        water_vo = griddata(ocean_points, df_ocean_time["vo"], (lon_mesh, lat_mesh), method="linear")
        water_uo[~mask] = np.nan
        water_vo[~mask] = np.nan

        # Thin out the arrows on dense grids
        stride = _quiver_stride(grid_size_ocean)
        m.quiver(
            x_mesh[::stride, ::stride],
            y_mesh[::stride, ::stride],
            water_uo[::stride, ::stride],
            water_vo[::stride, ::stride],
            scale=20,
            color='grey',
            width=0.002,
            alpha=0.99,
            label="water direction and velocity",
        )

        if not cluster:
            # Mean current speed inside SUB_BOX, taken from the interpolated grid
            ds_ocean_uo_vo = xr.merge([
                grid_to_xarray("uo", water_uo, lon_grid, lat_grid),
                grid_to_xarray("vo", water_vo, lon_grid, lat_grid),
            ])
            ds_ocean_uo_vo['velocity'] = np.sqrt(ds_ocean_uo_vo["uo"]**2 + ds_ocean_uo_vo["vo"]**2)
            ds_ocean_uo_vo_bbox = ds_ocean_uo_vo.sel(
                latitude=slice(SUB_BOX["lat_min"], SUB_BOX["lat_max"]),
                longitude=slice(SUB_BOX["lon_min"], SUB_BOX["lon_max"]),
            )
            mean_ocean_velocity = float(ds_ocean_uo_vo_bbox["velocity"].mean().values)

    # Weather grid; shared by the wind arrows and the pressure contours, so it
    # is built regardless of plot_wind_data
    weather_points = (df_weather_time["longitude"], df_weather_time["latitude"])
    lon_grid_wind = np.linspace(df_weather_time["longitude"].min(), df_weather_time["longitude"].max(), wind_grid_size)
    lat_grid_wind = np.linspace(df_weather_time["latitude"].min(), df_weather_time["latitude"].max(), wind_grid_size)
    lon_mesh_wind, lat_mesh_wind = np.meshgrid(lon_grid_wind, lat_grid_wind)
    x_wind, y_wind = m(lon_mesh_wind, lat_mesh_wind)

    if plot_wind_data:
        wind_speed = df_weather_time["wind_speed_10m"]
        wind_dir_rad = np.deg2rad(df_weather_time["wind_direction_10m"])

        # The direction is a compass bearing (0 deg = north, 90 deg = east), so the
        # eastward component uses sin and the northward one cos. The negation of
        # the original code is kept, i.e. the bearing is treated as the direction
        # the wind comes from. The components are computed per data point and
        # interpolated afterwards, because interpolating the angle itself breaks
        # across the 0/360 degree wrap-around.
        u = griddata(weather_points, -wind_speed * np.sin(wind_dir_rad), (lon_mesh_wind, lat_mesh_wind), method="linear")
        v = griddata(weather_points, -wind_speed * np.cos(wind_dir_rad), (lon_mesh_wind, lat_mesh_wind), method="linear")
        m.quiver(x_wind, y_wind, u, v, scale=1500, color="black")

        if not cluster:
            # Mean wind speed inside SUB_BOX, taken from the interpolated grid
            wind_speed_grid = griddata(weather_points, wind_speed, (lon_mesh_wind, lat_mesh_wind), method="linear")
            ds_wind_speed = grid_to_xarray("wind_speed", wind_speed_grid, lon_grid_wind, lat_grid_wind)
            ds_wind_speed_bbox = ds_wind_speed.sel(
                latitude=slice(SUB_BOX["lat_min"], SUB_BOX["lat_max"]),
                longitude=slice(SUB_BOX["lon_min"], SUB_BOX["lon_max"]),
            )
            mean_wind_speed = float(ds_wind_speed_bbox["wind_speed"].mean().values)

    # Contour lines for mean sea level pressure
    if "pressure_msl" in df_weather_time.columns:
        pressure_msl_grid = griddata(
            weather_points,
            df_weather_time["pressure_msl"],
            (lon_mesh_wind, lat_mesh_wind),
            method="linear",
        )
        if np.isfinite(pressure_msl_grid).any():
            # One contour level every 2 units between the grid minimum and maximum
            levels = np.arange(np.nanmin(pressure_msl_grid), np.nanmax(pressure_msl_grid), 2)
            if levels.size > 0:
                cs = m.contour(x_wind, y_wind, pressure_msl_grid, 
                            levels=levels, 
                            cmap=plt.get_cmap('coolwarm'), 
                            linewidths=1.0)
                ax.clabel(cs, inline=True, fmt="%.0f hPa")

    # Title: timestamp plus whichever box statistics were computed
    title_lines = [f"Time: {pd.to_datetime(timepoint).strftime('%Y-%m-%d %H:%M')}"]
    if mean_wind_speed is not None:
        title_lines.append(f" Mean wind speed in box of interest  {mean_wind_speed:.3f} km/h ")
    if mean_ocean_velocity is not None:
        title_lines.append(f" Mean water velocity in box of interest {mean_ocean_velocity:.3f} m/s")
    ax.set_title("\n".join(title_lines), pad=15)

    m.drawparallels(np.arange(0, 360, 2), labels=[1, 0, 0, 0])
    m.drawmeridians(np.arange(0, 350, 2), labels=[0, 0, 0, 1])

    # Panel number in the top-left corner (one-based)
    if number is not None:
        number = str(int(number) + 1)
        ax.annotate(f"{number}", 
                    xy=(0.01, 0.98),
                    xytext=(0.01, 0.98),
                    xycoords="axes fraction", 
                    ha="left", 
                    va="top", 
                    color="black", 
                    bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="white"),
                    zorder=10
                    )
    return heatmap


def analyse_storm_surges(title: str = "Storm Surges Analysis", timepoints: list = None, grid_size_ocean=50, wind_grid_size=20, plot_water_velocity_data=True, plot_wind_data=True, save=False):
    """Draw one map per timepoint in a two-column grid with a shared colorbar.

    Each panel is produced by ``plot_for_timepoint``. The color range is shared
    across panels and taken from the SLA values of the module-level ``df_ocean``
    within 62 hours before the earliest and after the latest timepoint
    (rounded to one decimal). A horizontal colorbar and a figure-level legend
    are placed below the panels. With ``save=True`` the figure is written to
    ``../data/plots/`` using a file name derived from ``title``.
    """
    fig = plt.figure(figsize=(20, 12))

    # Panel layout: two columns, as many rows as needed (ceil division)
    n_cols = 2
    n_rows = -(-len(timepoints) // n_cols)
    gs = gridspec.GridSpec(n_rows, n_cols, hspace=0.1, wspace=0.1)

    # Shared color range from the padded event window
    stamps = np.array(timepoints)
    window_start = stamps.min() - pd.Timedelta(hours=62)
    window_end = stamps.max() + pd.Timedelta(hours=62)
    sla_in_window = df_ocean.loc[
        (df_ocean["time"] >= window_start) & (df_ocean["time"] <= window_end),
        "sla",
    ]
    min_sla = sla_in_window.min()
    max_sla = sla_in_window.max()
    print(f"Min SLA: {min_sla}, Max SLA: {max_sla}")

    min_sla_round = np.round(min_sla, 1)
    max_sla_round = np.round(max_sla, 1)

    # One map per timepoint, filled row by row
    heatmaps = []
    for i, timepoint in enumerate(timepoints):
        row, col = divmod(i, n_cols)
        panel_ax = fig.add_subplot(gs[row, col])
        heatmaps.append(
            plot_for_timepoint(
                timepoint=timepoint,
                ax=panel_ax,
                grid_size_ocean=grid_size_ocean,
                wind_grid_size=wind_grid_size,
                vmin=min_sla_round,
                vmax=max_sla_round,
                plot_water_velocity_data=plot_water_velocity_data,
                plot_wind_data=plot_wind_data,
                number=i,
                cluster=False,
                timepoints=timepoints,
            )
        )

    # Shared horizontal colorbar below the maps; [left, bottom, width, height]
    cbar_ax = fig.add_axes([0.32, 0.05, 0.55, 0.03])
    cbar = fig.colorbar(heatmaps[0], cax=cbar_ax, orientation="horizontal", use_gridspec=True)
    cbar.set_ticks(np.linspace(min_sla_round, max_sla_round, num=5))
    cbar.set_label("Water Level (m)")

    # Figure-level legend next to the colorbar
    legend_elements = [
        Line2D([0], [0], color='green', lw=2, label='box of interest ', linestyle='--'),
        Line2D([0], [0], color='green', marker='*', markersize=15, linestyle='', label='Flensburg'),
    ]
    if plot_water_velocity_data:
        legend_elements.append(Line2D([0], [0], color='grey', lw=4, marker=r'$\rightarrow$', label='ocean current', linestyle=''))
    if plot_wind_data:
        legend_elements.append(Line2D([0], [0], color='black', lw=4, marker=r'$\rightarrow$', label='wind direction', linestyle=''))
    fig.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(0.13, 0.05), frameon=True, title="")

    plt.tight_layout()

    if save:
        # File name: title with spaces replaced by underscores and colons removed
        fig_name = f"{title.replace(' ', '_').replace(':', '')}.png"
        fig.savefig(f'../data/plots/{fig_name}', bbox_inches='tight', dpi=300)
        print(f"Figure saved as {fig_name}")

    plt.show()


In [None]:
# Single-map check for one timestamp, drawn on a new figure with the SUB_BOX outline.
plot_for_timepoint(
    timepoint="2023-01-01 00:00:00",
    ax=None,
    grid_size_ocean=50,
    wind_grid_size=20,
    vmin=-1.0,
    vmax=1.5,
    plot_water_velocity_data=True,
    plot_wind_data=True,
    cluster=False,
)

In [None]:
# Grid resolution (points per axis) for the ocean maps and for the wind/pressure overlay
grid_size_ocean, wind_grid_size = 200, 25

In [None]:
# Show which columns the processed ocean frame provides.
df_ocean.columns

In [None]:
# Show which columns the processed weather frame provides.
df_weather.columns

In [None]:
# Show the first and last timestamp of the in-situ frame.
df_insitu['time'].min(), df_insitu['time'].max()

### Plots

In [None]:
# Storm surge events keyed 1..7; each entry holds four sorted timestamps
# that are used as the map snapshots for that event.
sturm_surge_dict = {
    event_number: {"timepoints": sorted(pd.Timestamp(stamp) for stamp in stamps)}
    for event_number, stamps in enumerate(
        [
            ("2023-02-24 05:00", "2023-02-24 15:00", "2023-02-25 16:00", "2023-02-26 15:00"),
            ("2023-03-30 21:00", "2023-03-31 12:00", "2023-04-01 12:00", "2023-04-03 00:00"),
            ("2023-10-06 00:00", "2023-10-06 22:00", "2023-10-07 19:00", "2023-10-09 07:00"),
            ("2023-10-17 22:00", "2023-10-19 21:00", "2023-10-20 21:00", "2023-10-21 17:00"),
            ("2024-01-02 05:00", "2024-01-03 01:00", "2024-01-04 12:00", "2024-01-05 05:00"),
            ("2024-02-08 18:00", "2024-02-09 15:00", "2024-02-11 13:00", "2024-02-13 11:00"),
            ("2024-12-07 17:00", "2024-12-09 05:00", "2024-12-09 12:00", "2024-12-10 17:00"),
        ],
        start=1,
    )
}

In [None]:
# Analysis from 2023-10-19 to 2023-10-21 (kept for reference, now part of sturm_surge_dict)
# timepoints = sorted([
#     pd.Timestamp("2023-10-17 22:00"), # top left
#     pd.Timestamp("2023-10-19 21:00"), # bottom left
#     pd.Timestamp("2023-10-20 21:00"), # top right
#     pd.Timestamp("2023-10-21 17:00"), # bottom right
# ])

# Alternative box of interest around the closest grid point (disabled)
# sub_box = {
#         "lat_min": closest_point["latitude"] - 0.1, # 54.4
#         "lat_max": closest_point["latitude"] + 0.1, # 55.5
#         "lon_min": closest_point["longitude"] - 0.1,
#         "lon_max": closest_point["longitude"] + 0.1 # 10.5
#         }

# For every storm surge event: water level time series first, then the map grid.
# Both figures are saved.
for number in sturm_surge_dict:

    print(f"\n ################# Analyse Sturmflut {number} ################# \n")
    timepoints = sturm_surge_dict[number]["timepoints"]
    # plot_closest_location(df_ocean, sub_box=sub_box, target_lat=LAT_FLENSBURG, target_lon=LON_FLENSBURG)

    lineplot_storm_surge(
        df=df_insitu,
        column_name='slev',
        timepoints=timepoints,
        ax=None,
        show_legend=True,
        save=True,
        legend_outside=False,
    )

    analyse_storm_surges(
        title=f"Storm surge analysis from {timepoints[0].strftime('%Y-%m-%d')} to {timepoints[-1].strftime('%Y-%m-%d')}",
        timepoints=timepoints,
        grid_size_ocean=grid_size_ocean,
        wind_grid_size=wind_grid_size,
        plot_water_velocity_data=True,
        plot_wind_data=True,
        save=True,
    )

In [None]:
def compare_features(column_name:str, sturm_surge_dict:dict, calculation:str):
    """
    Plot one feature as a grid of line plots, one panel per storm surge, and save it.

    The source DataFrame is chosen by the dictionary that lists ``column_name``
    (``INSITU_DICT`` -> ``df_insitu``, ``OCEAN_DICT`` -> ``df_ocean``,
    ``WEATHER_DICT`` -> ``df_weather``; these are module-level objects).
    Each panel is drawn by ``lineplot_storm_surge`` and the figure is written to
    ``../data/plots/datacompare_{column_name}.png``.

    Args:
        column_name: Name of the feature column to plot.
        sturm_surge_dict: Mapping of storm-surge id to a dict whose
            ``"timepoints"`` entry is an ordered list of timestamps; the first
            and last entry are used in the panel title.
        calculation: Passed through unchanged to ``lineplot_storm_surge``.

    Returns:
        None. Nothing is plotted when ``column_name`` is unknown or the
        dictionary is empty.
    """
    # Resolve the data source before creating any figure, so an unknown column
    # does not leave an empty, never-shown figure behind.
    if column_name in INSITU_DICT:
        df_plot = df_insitu.copy()
    elif column_name in OCEAN_DICT:
        df_plot = df_ocean.copy()
    elif column_name in WEATHER_DICT:
        df_plot = df_weather.copy()
    else:
        return None

    keys = list(sturm_surge_dict.keys())
    if not keys:
        # A grid with zero rows would give the figure a height of zero.
        return None

    # Panels are laid out row by row in a grid with three columns.
    n_cols = 3
    n_rows = (len(keys) + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(8*n_cols, 5*n_rows), constrained_layout=True)
    gs = gridspec.GridSpec(n_rows, n_cols,
                           hspace=0.7, wspace=0.2)

    ax_lines = []
    for i, key in enumerate(keys):
        ax = fig.add_subplot(gs[i // n_cols, i % n_cols])
        timepoints = sturm_surge_dict[key]["timepoints"]

        # Panels are labelled A, B, C, ... without a fixed upper limit.
        panel_letter = chr(ord('A') + i)

        ax_line = lineplot_storm_surge(df=df_plot, column_name=column_name, timepoints=timepoints,
                                       ax=ax, calculation=calculation, show_number_lines=False, display=False,
                                       title=f'Figure {panel_letter}: {timepoints[0].strftime("%Y-%m-%d")} - {timepoints[-1].strftime("%Y-%m-%d")}',
                                       display_sturm_surge=True)
        ax_lines.append(ax_line)

    # The legend is anchored to the third panel (end of the first row); with
    # fewer than three panels it falls back to the last one that exists.
    legend_ax = ax_lines[min(2, len(ax_lines) - 1)]
    legend_ax.legend(
        loc="upper right",
        framealpha=0.0,
        facecolor='white',
        ncol=1,
        fancybox=True,
        shadow=False,
        bbox_to_anchor=(1.5, 1),
    )

    # column_name is guaranteed to be in one of the three dictionaries here.
    if column_name in INSITU_DICT:
        title = f"Observation of {INSITU_DICT[column_name]['description'].lower()} ({column_name}) for different storm surges"
    elif column_name in WEATHER_DICT:
        title = f"Mean value of {WEATHER_DICT[column_name]['description'].lower()} ({column_name}) in target box for different storm surges\nDescription: {WEATHER_DICT[column_name]['explanation']}"
    else:
        title = f"Mean value of {OCEAN_DICT[column_name]['description'].lower()} ({column_name}) in target box for different storm surges\nDescription: {OCEAN_DICT[column_name]['explanation']}"
        if column_name == 'sla':
            # The sla title is shown without the explanation line.
            title = f"Mean value of {OCEAN_DICT[column_name]['description'].lower()} ({column_name}) in target box for different storm surges"

    fig.suptitle(title, y=0.99)
    plt.tight_layout()  # Adjust layout to prevent overlap

    # Save the figure
    fig_name = f"../data/plots/datacompare_{column_name}.png"
    fig.savefig(fig_name, bbox_inches='tight', dpi=300)
    print(f"Figure saved as {fig_name}")
    plt.show()

### All Line Plots

In [None]:
# Features compared across all storm surges. The list is curated by hand; the
# previous "all columns" assignment was overwritten immediately and has been
# dropped as dead code.
features_all = ['slev', 'sla', 'uo', 'vo', 'wo', 'wind_speed_10m', 'wind_direction_10m', 'pressure_msl']

for feature in features_all:
    compare_features(column_name=feature, sturm_surge_dict=sturm_surge_dict, calculation='mean')

## Korrelation zwischen SLEV und features on map

In [None]:
# Reload the three data sources from scratch for the correlation maps.
# The loaders and process_df come from utils.eda_helper_functions; what they do
# internally is not visible in this notebook.

# Ocean model data at the configured OCEAN_POINTS; the "depth" column is dropped.
df_ocean = load_ocean_data(ocean_data_path, OCEAN_POINTS, verbose=False)
df_ocean = process_df(df_ocean, drop_cols=["depth"], verbose=False)

# Weather data at the configured WEATHER_POINTS.
df_weather = load_weather_data(weather_data_path, WEATHER_POINTS, verbose=False)
df_weather = process_df(df_weather, verbose=False)

# In-situ observations, limited to the time span covered by the ocean data.
df_insitu = load_insitu_data(verbose=False)
df_insitu = process_flensburg_data(df_insitu, 
                                      start_time=df_ocean['time'].min(),
                                      end_time=df_ocean['time'].max(),
                                      verbose=False)

# Presumably resamples the observations to hourly values - confirm in the helper.
df_insitu = group_data_hourly(df_insitu)
# NOTE(review): the dropped column is spelled "deph" here (vs. "depth" for the
# ocean data); presumably that is the real in-situ column name - confirm.
df_insitu = process_df(df_insitu, drop_cols=["deph"], verbose=False)


In [None]:


resolution = 0.25  # grid spacing in degrees used for the regridding below

def interpolate_xarray(ds, resolution=0.25, make_fine_grid=True, interpolate_nan=True):
    """
    Fill gaps in an xarray dataset and optionally regrid it to a regular spacing.

    Parameters:
        ds (xarray.Dataset): Input dataset with ``time``, ``latitude`` and
            ``longitude`` dimensions.
        resolution (float): Spacing of the new latitude/longitude grid.
        make_fine_grid (bool): If True, interpolate linearly onto a new grid
            built with ``np.arange(min, max, resolution)``; the stop value is
            exclusive, so the maximum coordinate is generally not part of it.
        interpolate_nan (bool): If True, fill NaNs linearly along time, then
            latitude, then longitude.

    Returns:
        xarray.Dataset: The processed dataset (the input object itself when
        both options are disabled).
    """
    if interpolate_nan:
        # Gap filling runs one dimension at a time, in this fixed order.
        for dim in ("time", "latitude", "longitude"):
            ds = ds.interpolate_na(dim=dim, method="linear")

    if not make_fine_grid:
        return ds

    # Build the target grid from the extent of the (gap-filled) dataset.
    lat_axis = ds.latitude
    target_lats = np.arange(lat_axis.min(), lat_axis.max(), resolution)
    lon_axis = ds.longitude
    target_lons = np.arange(lon_axis.min(), lon_axis.max(), resolution)

    return ds.interp(latitude=target_lats, longitude=target_lons, method="linear")

# Create xarray datasets from the DataFrames (indexed by time/latitude/longitude)
# and interpolate each one onto a regular grid with the chosen resolution.
ds_ocean = df_ocean.set_index(["time", "latitude", "longitude"]).to_xarray()
ds_ocean_interp = interpolate_xarray(ds_ocean, resolution=resolution, make_fine_grid=True, interpolate_nan=True)

ds_weather = df_weather.set_index(["time", "latitude", "longitude"]).to_xarray()
ds_weather_interp = interpolate_xarray(ds_weather, resolution=resolution, make_fine_grid=True, interpolate_nan=True)

# Align the time axes: keep only time steps present in both datasets
common_time = np.intersect1d(ds_ocean_interp.time.values, ds_weather_interp.time.values) # find common time points
ds_ocean_interp = ds_ocean_interp.sel(time=common_time)
ds_weather_interp = ds_weather_interp.sel(time=common_time)

# Align the lat and lon axes: put the weather data on the ocean grid
ds_weather_interp = ds_weather_interp.interp(
    latitude=ds_ocean_interp.latitude,
    longitude=ds_ocean_interp.longitude
)

# Second gap-filling pass without regridding (make_fine_grid=False): fills NaNs
# that remain after the alignment steps above.
ds_weather_interp = interpolate_xarray(ds_weather_interp, resolution=resolution, make_fine_grid=False, interpolate_nan=True)
ds_ocean_interp = interpolate_xarray(ds_ocean_interp, resolution=resolution, make_fine_grid=False, interpolate_nan=True)




def create_ocean_mask(ds, land):
    """
    Build a boolean latitude/longitude mask that is True for grid points not on land.

    Args:
        ds: xarray object with 1-D ``latitude`` and ``longitude`` coordinates.
        land: GeoDataFrame with land polygons. The grid points are tagged with
            ``land.crs`` without reprojection, so the coordinates of ``ds`` are
            assumed to already be in that CRS.

    Returns:
        xarray.DataArray: Boolean mask with dims ``("latitude", "longitude")``;
        True where the point lies within no land polygon.
    """
    lon, lat = np.meshgrid(ds.longitude.values, ds.latitude.values)

    # One point per grid cell, in row-major (latitude, longitude) order.
    points_gdf = gpd.GeoDataFrame(
        geometry=gpd.points_from_xy(lon.ravel(), lat.ravel()),
        crs=land.crs,
    )
    joined = gpd.sjoin(points_gdf, land, predicate="within", how="left")

    # A left join emits one row per (point, polygon) match, so a point inside
    # several polygons would appear more than once and break the reshape below.
    # Collapse to exactly one flag per point; the integer point index keeps the
    # original order.
    on_land = joined.index_right.notna().groupby(level=0).any()

    mask_land = on_land.to_numpy().reshape(lat.shape)
    mask_ocean = ~mask_land

    ocean_mask_xr = xr.DataArray(
        mask_ocean,
        coords={"latitude": ds.latitude, "longitude": ds.longitude},
        dims=["latitude", "longitude"]
    )
    return ocean_mask_xr

# Land polygons from the "naturalearth.land" dataset shipped via geodatasets.
land = gpd.read_file(geodatasets.get_path("naturalearth.land"))
# The mask is built on the weather grid, which was aligned to the ocean grid above.
ocean_mask = create_ocean_mask(ds_weather_interp, land)
# .where(mask) keeps values where the mask is True and sets the rest to NaN.
ds_weather_ocean_only = ds_weather_interp.where(ocean_mask)
ds_ocean_ocean_only = ds_ocean_interp.where(ocean_mask)

# NOTE(review): the merge uses the masked ocean data but the UNmasked weather
# data; ds_weather_ocean_only is not used in this cell. Presumably intentional
# (weather is also meaningful over land) - confirm.
ds_ocean_weather_interp = xr.merge([ds_ocean_ocean_only, ds_weather_interp])


In [None]:
ds_ocean_weather_interp

In [None]:


def calculate_correlation_temporal_spatial(ds, df_ref, variable='sla', start_date=None, end_date=None, ax=None, title=None):
    """
    Map the Pearson correlation between a gridded variable and the reference water level.

    For every latitude/longitude cell of ``ds[variable]`` the time series is
    correlated with ``df_ref['slev']`` on the timestamps both share. The
    resulting field is drawn as a heatmap on a Basemap, together with the
    position(s) given in ``df_ref``.

    Args:
        ds (xarray.Dataset): Gridded data with ``time``, ``latitude`` and
            ``longitude`` dimensions.
        df_ref (pandas.DataFrame): Reference data with the columns ``time``,
            ``slev``, ``latitude`` and ``longitude``.
        variable (str): Name of the data variable in ``ds`` to correlate.
        start_date: Optional lower bound (inclusive) of the analysis window.
        end_date: Optional upper bound (inclusive) of the analysis window.
            Each bound is applied independently; a missing bound falls back to
            the first / last time step of ``ds`` for the default title.
        ax (matplotlib.axes.Axes, optional): Axes to draw on. If omitted, a new
            figure is created and shown.
        title (str, optional): Axes title. Defaults to the analysed time range.

    Returns:
        The ``pcolormesh`` artist of the correlation heatmap.
    """
    # Variables from the ocean model. For these the land fill is drawn on top
    # of the heatmap; every other variable is treated as a weather variable,
    # whose heatmap stays visible over land.
    ocean_variables = {'bottomT', 'mlotst', 'siconc', 'sithick',
                       'sla', 'so', 'sob', 'thetao', 'uo', 'vo', 'wo'}

    # A grid cell is skipped when more than this share of its values is NaN ...
    max_nan_fraction = 0.3
    # ... or when fewer than this many value pairs remain.
    min_samples = 10

    # Restrict both inputs to the requested window. Filtering returns new
    # objects, so the caller's data is never modified.
    df_ref_filtered = df_ref
    ds_filtered = ds
    if start_date is not None:
        df_ref_filtered = df_ref_filtered[df_ref_filtered['time'] >= start_date]
    if end_date is not None:
        df_ref_filtered = df_ref_filtered[df_ref_filtered['time'] <= end_date]
    if start_date is not None or end_date is not None:
        ds_filtered = ds_filtered.sel(time=slice(start_date, end_date))
    if start_date is None:
        start_date = ds_filtered['time'].min().values
    if end_date is None:
        end_date = ds_filtered['time'].max().values

    # Positions of the timestamps that exist in both series.
    slev_times = pd.to_datetime(df_ref_filtered['time'])
    slev_values = df_ref_filtered['slev'].values
    ocean_times = pd.to_datetime(ds_filtered['time'].values)
    _, idx_slev, idx_ocean = np.intersect1d(slev_times, ocean_times, return_indices=True)

    slev_values_common = slev_values[idx_slev]
    slev_is_nan = np.isnan(slev_values_common)

    n_lat = len(ds_filtered.latitude)
    n_lon = len(ds_filtered.longitude)
    correlations = np.full((n_lat, n_lon), np.nan)

    if len(idx_ocean) >= min_samples:
        # Extract the variable once as a (time, lat, lon) array instead of
        # selecting from the dataset for every single grid cell.
        field = ds_filtered[variable].transpose("time", "latitude", "longitude").values[idx_ocean]

        for i in range(n_lat):
            for j in range(n_lon):
                ts_common = field[:, i, j]
                ts_is_nan = np.isnan(ts_common)

                # Too many gaps (this also covers all-NaN land cells).
                if ts_is_nan.mean() > max_nan_fraction:
                    continue

                # Use only time steps where both series have a value.
                mask = ~(ts_is_nan | slev_is_nan)
                if mask.sum() < min_samples:
                    continue

                correlations[i, j] = np.corrcoef(ts_common[mask], slev_values_common[mask])[0, 1]

    # ---- Plot ----
    created_fig = ax is None
    if created_fig:
        _, ax = plt.subplots(figsize=(14, 12))

    lons = ds_filtered.longitude
    lats = ds_filtered.latitude
    m = Basemap(projection='cyl', resolution='i',
                llcrnrlon=float(lons.min()), urcrnrlon=float(lons.max()),
                llcrnrlat=float(lats.min()), urcrnrlat=float(lats.max()), ax=ax)

    def draw_land():
        m.drawcoastlines(ax=ax)
        m.drawcountries(ax=ax)
        m.drawmapboundary(fill_color='white', ax=ax)
        m.fillcontinents(color='lightgrey', lake_color='white', ax=ax)

    def draw_heatmap():
        mesh = ax.pcolormesh(lons, lats, correlations, cmap='coolwarm', vmin=-1, vmax=1)
        ax.figure.colorbar(mesh, label='Pearson correlation coefficient',
                           orientation='horizontal', pad=0.1, ax=ax, shrink=0.9)
        return mesh

    def draw_graticule():
        m.drawparallels(np.arange(-360., 360, 2.), labels=[1, 0, 0, 0])
        m.drawmeridians(np.arange(-360., 360, 2.), labels=[0, 0, 0, 1])

    def mark_reference(size):
        ax.scatter(df_ref_filtered['longitude'], df_ref_filtered['latitude'],
                   c='green', s=size, label='Flensburg', marker='*')

    # The drawing order decides what ends up on top, so it differs per source.
    if variable in ocean_variables:
        draw_graticule()
        heatmap = draw_heatmap()
        draw_land()
        mark_reference(250)
    else:
        draw_land()
        heatmap = draw_heatmap()
        mark_reference(200)
        draw_graticule()

    if title is None:
        start_date_str = pd.to_datetime(start_date).strftime('%Y-%m-%d %H:%M')
        end_date_str = pd.to_datetime(end_date).strftime('%Y-%m-%d %H:%M')
        title = f'{start_date_str} to {end_date_str}'

    # Address the target axes explicitly instead of pyplot's "current" axes.
    ax.set_title(title, pad=15)
    ax.legend(loc='upper right')
    ax.get_figure().tight_layout()

    if created_fig:
        plt.show()

    return heatmap





In [None]:
# variables = ['sla', 'uo', 'vo', 'wind_speed_10m', 'wind_direction_10m', 'surface_pressure']
# #variables = ds_ocean_weather_interp.data_vars.keys()
# #variables = ['wind_speed_10m']

# for variable in variables:
#     print(f"Calculating correlation for {variable}")
#     calculate_correlation_temporal_spatial(ds_ocean_weather_interp, df_insitu, variable=variable, 
#                                            #start_date="2023-10-19", end_date="2023-10-21"
#                                            )
#     break


In [None]:


# start_date = ds_ocean_weather_interp['time'].min().values
# end_date = ds_ocean_weather_interp['time'].max().values

# # round datetime to YYYY-MM-DD HH:MM
# start_date = pd.to_datetime(start_date).strftime('%Y-%m-%d %H:%M')
# end_date = pd.to_datetime(end_date).strftime('%Y-%m-%d %H:%M')

# variables = ['sla', 'uo', 'vo', 'wind_speed_10m', 'wind_direction_10m', 'surface_pressure']
# variables

def plot_correlations(variables, start_date, end_date, title:str = None, subplot_title:bool = False,save:bool = False, one_plot:bool = False):
    """
    Plot correlation maps between gridded variables and the Flensburg water level (slev).

    Uses the module-level ``ds_ocean_weather_interp`` and ``df_insitu`` and
    delegates the computation and drawing of each map to
    ``calculate_correlation_temporal_spatial``.

    Args:
        variables: Iterable of variable names to plot.
        start_date: Start of the analysis window.
        end_date: End of the analysis window.
        title (str, optional): Figure title (``one_plot=True``) or title of
            every single map (``one_plot=False``). A default is generated when
            omitted.
        subplot_title (bool): Currently unused; kept for backward compatibility.
        save (bool): If True, write the figure(s) to disk
            (``../data/plots/correlation_all.png`` for the combined figure,
            ``../thesis_plots/all_correlation_{variable}_map.png`` otherwise).
        one_plot (bool): If True, draw all variables as panels of one figure
            with three columns; otherwise one figure per variable.
    """
    # Accept any iterable (e.g. a dict keys view) and allow len().
    variables = list(variables)
    if not variables:
        # Nothing to draw; an empty grid would have zero height.
        return

    if one_plot:
        n_cols = 3
        n_rows = (len(variables) + n_cols - 1) // n_cols

        fig = plt.figure(figsize=(8 * n_cols, 5 * n_rows))
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)

        if title is None:
            title = f"Correlation between Ocean Variables and Flensburg SLEV from {start_date} to {end_date}"
        fig.suptitle(title, y=0.98)

        for idx, variable in enumerate(variables):
            print(f"Calculating correlation for {variable}")
            ax = fig.add_subplot(gs[idx])
            # Panels are labelled A, B, C, ... without a fixed upper limit.
            panel_letter = chr(ord('A') + idx)
            calculate_correlation_temporal_spatial(
                ds_ocean_weather_interp,
                df_insitu,
                variable=variable,
                ax=ax,
                start_date=start_date,
                end_date=end_date,
                title=f'Figure {panel_letter}: {variable}',
            )

        # Manual spacing instead of tight_layout so the suptitle keeps its room.
        fig.subplots_adjust(top=0.90, hspace=0.2, wspace=0.1)
        if save:
            fig_name = '../data/plots/correlation_all.png'
            fig.savefig(fig_name, bbox_inches='tight', dpi=300)
        plt.show()
        return

    for variable in variables:
        print(f"Calculating correlation for {variable}")
        fig, ax = plt.subplots(figsize=(14, 12))
        calculate_correlation_temporal_spatial(
            ds_ocean_weather_interp,
            df_insitu,
            variable=variable,
            ax=ax,
            start_date=start_date,
            end_date=end_date,
        )

        # Replace the date-range title set by the helper; address the axes
        # directly rather than relying on pyplot's current axes.
        if title is not None:
            ax.set_title(title, pad=15)
        else:
            ax.set_title(f'Correlation between {variable} and Flensburg SLEV from {start_date} to {end_date}', pad=15)
        fig.tight_layout()

        if save:
            fig_name = f'../thesis_plots/all_correlation_{variable}_map.png'
            fig.savefig(fig_name, bbox_inches='tight', dpi=300)
        plt.show()


def compare_correlations(timepoints:list, variable:str,save:bool = False, window_days:float = 3):
    """
    Compare the correlation map of one variable across several storm surges.

    For every entry of ``timepoints`` a panel shows the correlation between
    ``variable`` and the Flensburg water level (slev) in a window of
    ``window_days`` before and after that time. Uses the module-level
    ``ds_ocean_weather_interp`` and ``df_insitu``.

    Args:
        timepoints (list): Reference times (e.g. storm-surge peaks), one panel each.
        variable (str): Name of the data variable to correlate with slev.
        save (bool): If True, write the figure to
            ``../data/plots/correlation_{variable}.png``.
        window_days (float): Half-width of the analysis window in days.
    """
    if not timepoints:
        # Nothing to draw; an empty grid would have zero height.
        return

    n_cols = 3
    n_rows = (len(timepoints) + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(8 * n_cols, 5 * n_rows), dpi=300)
    gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)

    # Human-readable variable name for the figure title, if one is configured.
    if variable in WEATHER_DICT:
        var_name = WEATHER_DICT[variable]['description'].lower()
    elif variable in OCEAN_DICT:
        var_name = OCEAN_DICT[variable]['description'].lower()
    else:
        var_name = variable

    fig.suptitle(f"Correlation between water level in Flensburg (slev) and {var_name} ({variable})",
                 y=0.98)

    half_window = pd.Timedelta(days=window_days)

    for idx, timepoint in tqdm(enumerate(timepoints), total=len(timepoints)):
        start_date = timepoint - half_window
        end_date = timepoint + half_window
        start_date_str = pd.to_datetime(start_date).strftime('%Y-%m-%d %H:%M')
        end_date_str = pd.to_datetime(end_date).strftime('%Y-%m-%d %H:%M')

        # Panels are labelled A, B, C, ... without a fixed upper limit.
        panel_letter = chr(ord('A') + idx)

        ax = fig.add_subplot(gs[idx])
        calculate_correlation_temporal_spatial(
            ds_ocean_weather_interp,
            df_insitu,
            variable=variable,
            ax=ax,
            start_date=start_date,
            end_date=end_date,
            title=f"Figure {panel_letter}: {start_date_str} - {end_date_str}",
        )

    # Manual spacing so the suptitle keeps its room above the panels.
    fig.subplots_adjust(top=0.92, hspace=0.25, wspace=0.1)

    if save:
        fig_name = f'../data/plots/correlation_{variable}.png'
        fig.savefig(fig_name, bbox_inches='tight', dpi=300)

    plt.show()



### Plots

In [None]:
import datetime
# Time steps present in both the ocean and the weather data.
# NOTE(review): not used further in this cell - confirm whether a later cell needs it.
timepoints_all = sorted(set(df_ocean["time"]) & set(df_weather["time"]))

# Reference times of the storm surges; compare_correlations (defined above)
# analyses a window of +/- 3 days around each of them.
sturm_surge_list = sorted([
                datetime.datetime(2023, 2, 25, 17, 0),
                datetime.datetime(2023, 4, 1, 12, 0),
                datetime.datetime(2023, 10, 7, 20, 0),
                datetime.datetime(2023, 10, 20, 0, 0),
                datetime.datetime(2024, 1, 3, 9, 0),
                datetime.datetime(2024, 2, 9, 18, 0),
                datetime.datetime(2024, 12, 9, 16, 0),
                ])
# The "all data variables" assignment is overwritten on the next line; only the
# hand-picked list below is actually used.
variables = ds_ocean_weather_interp.data_vars.keys()
variables = ["sla", "uo", "vo", "wind_speed_10m", "wind_direction_10m", "pressure_msl"]
# One figure per variable, one panel per storm surge; figures are not saved.
for variable in variables:
    print(f"Calculating correlation for {variable}")
    compare_correlations(timepoints=sturm_surge_list, variable=variable, save=False, )
    

In [None]:
ds_ocean_weather_interp

In [None]:
# NOTE(review): `dd` is not defined anywhere in the visible part of this
# notebook, so this cell presumably raises NameError - possibly a deliberate
# stop for "Run All". Confirm, then remove it or replace it with an explicit raise.
dd

In [None]:
# Correlation maps over the complete time span of the merged dataset.
start_date = ds_ocean_weather_interp['time'].min().values
end_date = ds_ocean_weather_interp['time'].max().values

# Format as "YYYY-MM-DD HH:MM" strings (seconds are dropped)
start_date = pd.to_datetime(start_date).strftime('%Y-%m-%d %H:%M')
end_date = pd.to_datetime(end_date).strftime('%Y-%m-%d %H:%M')

print(f"Start date: {start_date}, End date: {end_date}")

# The "all data variables" assignment is overwritten on the next line; only the
# hand-picked list below is actually used.
variables = ds_ocean_weather_interp.data_vars.keys()
variables = ["sla", "uo", "vo", "wind_speed_10m", "wind_direction_10m", "pressure_msl"]

# One figure per variable (one_plot=False), each saved to disk. title=" " is a
# single space, i.e. the maps get an effectively blank title.
# NOTE(review): subplot_title is accepted by plot_correlations but only
# referenced in commented-out code there, so it has no effect.
plot_correlations(variables, start_date, end_date,
                  title=" ",
                  subplot_title=True,
                  save=True,
                  one_plot=False)

## Archive

In [None]:
import datetime

# Archive: fixed analysis window 2023-10-17 to 2023-10-23, as formatted strings.
start_date = datetime.datetime(2023, 10, 17, 0, 0).strftime("%Y-%m-%d %H:%M")
end_date = datetime.datetime(2023, 10, 23, 0, 0).strftime("%Y-%m-%d %H:%M")

# All data variables of the merged dataset (not just the hand-picked subset).
variables = ds_ocean_weather_interp.data_vars.keys()

# Water-level plot (helper from utils.eda_helper_functions) followed by one
# correlation map per variable for the same window; nothing is saved.
plot_water_level_anomalies(df_insitu, start_date=start_date, end_date=end_date)
plot_correlations(
    variables,
    start_date=start_date,
    end_date=end_date
)

In [None]:
import datetime

# Archive: reference times of the storm surges analysed one by one below.
sturm_surge_list = [datetime.datetime(2023, 2, 25, 17, 0),
                    datetime.datetime(2023, 4, 1, 12, 0),
                    datetime.datetime(2023, 10, 7, 20, 0),
                    datetime.datetime(2023, 10, 20, 0, 0),
                    datetime.datetime(2024, 2, 9, 18, 0),
                    datetime.datetime(2024, 12, 9, 16, 0),
                    ]

# `variables` is taken from a previous cell.
for time in sturm_surge_list:
    # Analysis window: one day before to one day after the reference time.
    start_time = time - datetime.timedelta(days=1)
    end_time = time + datetime.timedelta(days=1)
    df_insitu_sturm = df_insitu[(df_insitu["time"] >= start_time) & (df_insitu["time"] <= end_time)]

    # Use this iteration's window. The previous version passed the stale
    # start_date/end_date left over from an earlier cell, so every surge was
    # plotted against the same unrelated range. The bounds are formatted as
    # strings to match the other calls of plot_water_level_anomalies here.
    plot_water_level_anomalies(
        df_insitu_sturm,
        start_date=start_time.strftime("%Y-%m-%d %H:%M"),
        end_date=end_time.strftime("%Y-%m-%d %H:%M"),
    )
    plot_correlations(
        variables,
        start_date=start_time,
        end_date=end_time
    )
    

In [None]:
## Correlation just at storm surges

# Keep only observations with slev >= 0.5 (presumably metres - confirm unit).
df_surge = df_insitu.copy()
df_surge = df_surge.loc[df_surge['slev'] >= 0.5]
start_date = df_surge['time'].min()
end_date = df_surge['time'].max()
# Format as "YYYY-MM-DD HH:MM" strings (seconds are dropped)
start_date = pd.to_datetime(start_date).strftime('%Y-%m-%d %H:%M')
end_date = pd.to_datetime(end_date).strftime('%Y-%m-%d %H:%M')
variables = ds_ocean_weather_interp.data_vars.keys()
plot_water_level_anomalies(df_surge, start_date=start_date, end_date=end_date)
# NOTE(review): plot_correlations reads the global df_insitu, not df_surge, so
# the correlation covers EVERY hour between the first and the last high-water
# observation, not only the hours with slev >= 0.5.
plot_correlations(
    variables,
    start_date=start_date,
    end_date=end_date
)


In [None]:
df_weather.showers

### Lassen sich Korrelationen zwischen den Features und dem Wasserpegel (sla) erkennen?

In [None]:
# Reload the three data sources from scratch for this section.
# The loaders and process_df come from utils.eda_helper_functions; what they do
# internally is not visible in this notebook.

# Ocean model data at the configured OCEAN_POINTS; the "depth" column is dropped.
df_ocean = load_ocean_data(ocean_data_path, OCEAN_POINTS, verbose=False)
df_ocean = process_df(df_ocean, drop_cols=["depth"], verbose=False)

# Weather data at the configured WEATHER_POINTS.
df_weather = load_weather_data(weather_data_path, WEATHER_POINTS, verbose=False)
df_weather = process_df(df_weather, verbose=False)

# In-situ observations, limited to the time span covered by the ocean data.
df_insitu = load_insitu_data(verbose=False)
df_insitu = process_flensburg_data(df_insitu, 
                                      start_time=df_ocean['time'].min(),
                                      end_time=df_ocean['time'].max(),
                                      verbose=False)

# Presumably resamples the observations to hourly values - confirm in the helper.
df_insitu = group_data_hourly(df_insitu)
# NOTE(review): the dropped column is spelled "deph" here (vs. "depth" for the
# ocean data); presumably that is the real in-situ column name - confirm.
df_insitu = process_df(df_insitu, drop_cols=["deph"], verbose=False)


In [None]:
import geodatasets
import geopandas as gpd
import numpy as np
import xarray as xr

resolution = 0.25  # grid spacing in degrees used for the regridding below

def interpolate_xarray(ds, resolution=0.25, make_fine_grid=True, interpolate_nan=True):
    """
    Fill gaps in an xarray dataset and optionally regrid it to a regular spacing.

    Parameters:
        ds (xarray.Dataset): Input dataset with ``time``, ``latitude`` and
            ``longitude`` dimensions.
        resolution (float): Spacing of the new latitude/longitude grid.
        make_fine_grid (bool): If True, interpolate linearly onto a new grid
            built with ``np.arange(min, max, resolution)``; the stop value is
            exclusive, so the maximum coordinate is generally not part of it.
        interpolate_nan (bool): If True, fill NaNs linearly along time, then
            latitude, then longitude.

    Returns:
        xarray.Dataset: The processed dataset (the input object itself when
        both options are disabled).
    """
    if interpolate_nan:
        # Gap filling runs one dimension at a time, in this fixed order.
        for dim in ("time", "latitude", "longitude"):
            ds = ds.interpolate_na(dim=dim, method="linear")

    if not make_fine_grid:
        return ds

    # Build the target grid from the extent of the (gap-filled) dataset.
    lat_axis = ds.latitude
    target_lats = np.arange(lat_axis.min(), lat_axis.max(), resolution)
    lon_axis = ds.longitude
    target_lons = np.arange(lon_axis.min(), lon_axis.max(), resolution)

    return ds.interp(latitude=target_lats, longitude=target_lons, method="linear")

# Create xarray datasets from the DataFrames (indexed by time/latitude/longitude)
# and interpolate each one onto a regular grid with the chosen resolution.
ds_ocean = df_ocean.set_index(["time", "latitude", "longitude"]).to_xarray()
ds_ocean_interp = interpolate_xarray(ds_ocean, resolution=resolution, make_fine_grid=True, interpolate_nan=True)

ds_weather = df_weather.set_index(["time", "latitude", "longitude"]).to_xarray()
ds_weather_interp = interpolate_xarray(ds_weather, resolution=resolution, make_fine_grid=True, interpolate_nan=True)

# Align the time axes: keep only time steps present in both datasets
common_time = np.intersect1d(ds_ocean_interp.time.values, ds_weather_interp.time.values) # find common time points
ds_ocean_interp = ds_ocean_interp.sel(time=common_time)
ds_weather_interp = ds_weather_interp.sel(time=common_time)

# Align the lat and lon axes: put the weather data on the ocean grid
ds_weather_interp = ds_weather_interp.interp(
    latitude=ds_ocean_interp.latitude,
    longitude=ds_ocean_interp.longitude
)

# Second gap-filling pass without regridding (make_fine_grid=False): fills NaNs
# that remain after the alignment steps above.
ds_weather_interp = interpolate_xarray(ds_weather_interp, resolution=resolution, make_fine_grid=False, interpolate_nan=True)
ds_ocean_interp = interpolate_xarray(ds_ocean_interp, resolution=resolution, make_fine_grid=False, interpolate_nan=True)

def create_ocean_mask(ds, land):
    """
    Build a boolean latitude/longitude mask that is True for grid points not on land.

    Args:
        ds: xarray object with 1-D ``latitude`` and ``longitude`` coordinates.
        land: GeoDataFrame with land polygons. The grid points are tagged with
            ``land.crs`` without reprojection, so the coordinates of ``ds`` are
            assumed to already be in that CRS.

    Returns:
        xarray.DataArray: Boolean mask with dims ``("latitude", "longitude")``;
        True where the point lies within no land polygon.
    """
    lon, lat = np.meshgrid(ds.longitude.values, ds.latitude.values)

    # One point per grid cell, in row-major (latitude, longitude) order.
    points_gdf = gpd.GeoDataFrame(
        geometry=gpd.points_from_xy(lon.ravel(), lat.ravel()),
        crs=land.crs,
    )
    joined = gpd.sjoin(points_gdf, land, predicate="within", how="left")

    # A left join emits one row per (point, polygon) match, so a point inside
    # several polygons would appear more than once and break the reshape below.
    # Collapse to exactly one flag per point; the integer point index keeps the
    # original order.
    on_land = joined.index_right.notna().groupby(level=0).any()

    mask_land = on_land.to_numpy().reshape(lat.shape)
    mask_ocean = ~mask_land

    ocean_mask_xr = xr.DataArray(
        mask_ocean,
        coords={"latitude": ds.latitude, "longitude": ds.longitude},
        dims=["latitude", "longitude"]
    )
    return ocean_mask_xr

# Land polygons from the "naturalearth.land" dataset shipped via geodatasets.
land = gpd.read_file(geodatasets.get_path("naturalearth.land"))
# The mask is built on the weather grid, which was aligned to the ocean grid above.
ocean_mask = create_ocean_mask(ds_weather_interp, land)
# .where(mask) keeps values where the mask is True and sets the rest to NaN.
ds_weather_ocean_only = ds_weather_interp.where(ocean_mask)
ds_ocean_ocean_only = ds_ocean_interp.where(ocean_mask)

# NOTE(review): the merge uses the masked ocean data but the UNmasked weather
# data; ds_weather_ocean_only is not used in this cell. Presumably intentional
# (weather is also meaningful over land) - confirm.
ds_ocean_weather_interp = xr.merge([ds_ocean_ocean_only, ds_weather_interp])


---

#### Calculating Bivariate Moran’s I 

**Erklärung:** 

**Abgrenzung zu Korrelation und zeitlicher Autokorrelation**
Bei der Berechnung der statistischen Korrelation werden zwei Variablen (x,y) bei zwei oder mehr Beobachtungen betrachtet; bei der räumlichen Autokorrelation hingegen eine Variable x an zwei oder mehr Orten.[3]

Während die zeitliche Autokorrelation die Beziehungen der Ausprägungen einer Variablen mit sich selbst über die Zeit beschreibt, beschreibt die räumliche Autokorrelation die Ausprägungen einer Variablen mit sich selbst im Raum. 

**Berechnung**
*Positive räumliche Autokorrelation* liegt dann vor, wenn nahe beieinander liegende Orte einander mit höherer Wahrscheinlichkeit ähnlich sind als weiter voneinander entfernte Orte. Das heißt: Positive räumliche Autokorrelation liegt vor, wenn Orte dazu tendieren, im Hinblick auf eine Eigenschaft Cluster zu bilden. Positive räumliche Autokorrelation ist eine empirische Manifestation von [Toblers](https://de.wikipedia.org/wiki/Erstes_Gesetz_der_Geographie) „Erstem Gesetz der Geographie“.

*Negative räumliche Autokorrelation* liegt dann vor, wenn benachbarte Orte im Vergleich zu zufälliger Anordnung[5] unterschiedliche Eigenschaftswerte aufweisen. Bei Phänomenen, die mit Lebewesen (Tieren, Pflanzen) verbunden sind, wird negative Autokorrelation häufig durch Wettbewerb und Verdrängung verursacht.

*Keine räumliche Autokorrelation* liegt vor, wenn die Orte im Hinblick auf eine Eigenschaft zufällig angeordnet sind, also keine ausgeprägten Cluster aufweisen.

Quelle: [Wikipedia – Räumliche Autokorrelation](https://de.wikipedia.org/wiki/Räumliche_Autokorrelation)

**Was ist räumliche Autokorrelation (spatial autocorrelation)?**

Räumliche Autokorrelation beschreibt, wie stark der Wert einer Variablen an einem Ort mit den Werten derselben Variablen an anderen Orten zusammenhängt.

**Vergleich mit klassischer Korrelation:**
Normale (statistische) Korrelation:
→ Fragt: "Wie hängen zwei verschiedene Variablen miteinander zusammen?"
Beispiel: Wenn der Luftdruck steigt, sinkt vielleicht die Regenwahrscheinlichkeit.
Räumliche Autokorrelation:
→ Fragt: "Wie hängt der Wert einer einzigen Variable an einem Ort mit den Werten derselben Variable an benachbarten Orten zusammen?"
Beispiel: Ist die Wasserhöhe an Punkt A ähnlich wie an den umliegenden Punkten?

**Und wie unterscheidet sich das von zeitlicher Autokorrelation?**
Zeitliche Autokorrelation:
→ Fragt: "Wie hängt ein Wert heute mit den Werten derselben Variablen in der Vergangenheit oder Zukunft zusammen?"
Beispiel: Die Temperatur heute hängt oft mit der Temperatur gestern zusammen.
Räumliche Autokorrelation:
→ Fragt: "Wie hängt der Wert an einem Ort mit Werten an anderen Orten zusammen?"

**Achtung!!**
Verwende den bivariaten Moran’s I, wenn du den räumlichen Zusammenhang zwischen zwei Variablen in geografischen Daten untersuchen möchtest. Dieser berücksichtigt, wie benachbarte Gebiete miteinander korrelieren und wie sich räumliche Muster bilden.
Beispiel: Du analysierst, ob hohe Wasserhöhen in einer Region mit hohen Wasserhöhen in den benachbarten Regionen zusammenhängen und ob der Luftdruck in benachbarten Regionen ebenfalls ähnliche Werte aufweist.
Verwende den Pearson’s R, wenn du den linearen Zusammenhang zwischen zwei Variablen messen möchtest, ohne die räumliche Abhängigkeit zu berücksichtigen. Dies ist besonders nützlich, wenn du den direkten Zusammenhang zwischen zwei Variablen untersuchen möchtest, ohne auf ihre geografische Lage zu achten.
Beispiel: Du möchtest wissen, ob es einen linearen Zusammenhang zwischen der Wasserhöhe und dem Luftdruck gibt, unabhängig von den geografischen Standorten der Daten.





In [None]:
import geopandas as gpd

#from splot.esda import moran_bv_plot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from esda.moran import Moran_BV
from libpysal.weights import KNN
from shapely.geometry import Point
from tqdm import tqdm  # Fortschrittsbalken für test_k


def calculating_morans_I(ds, var1: str, var2: str, time: str, k=3, plot=False, test_k=False):
    """
    Compute the bivariate Moran's I between two gridded variables at one time step.

    Every grid cell where both variables are non-NaN becomes one observation.
    Spatial weights are row-standardised k-nearest-neighbour weights built on
    the grid coordinates.

    Parameters:
        ds (xr.Dataset): Dataset with `time`, `latitude` and `longitude` coordinates.
        var1 (str): Name of the first variable (x).
        var2 (str): Name of the second variable (y, used through its spatial lag).
        time (str): Label passed to `ds.sel(time=...)`; must select a single time step.
        k (int): Number of nearest neighbours for the spatial weights.
        plot (bool): If True, print the statistic and draw a Moran scatter plot.
        test_k (bool): If True, also print Moran's I and p-value for k = 3, 6, ... (< 50).

    Returns:
        tuple: (Moran's I, simulated p-value `p_sim`) for the requested `k`.

    Raises:
        ValueError: If no more than `k` valid grid cells remain.
    """
    # 1. Select the time step
    snapshot = ds.sel(time=time)

    # 2. One observation per grid cell. The variables are forced into
    #    (latitude, longitude) order so that the flattened values line up with
    #    the 'ij' meshgrid no matter how the dataset stores its dimensions.
    lat, lon = np.meshgrid(snapshot.latitude.values, snapshot.longitude.values, indexing='ij')
    coords = np.column_stack([lat.ravel(), lon.ravel()])
    gdf = gpd.GeoDataFrame({
        var1: snapshot[var1].transpose("latitude", "longitude").values.ravel(),
        var2: snapshot[var2].transpose("latitude", "longitude").values.ravel(),
    }, geometry=[Point(xy) for xy in coords])

    # 3. Drop cells where either variable is missing
    gdf = gdf.dropna(subset=[var1, var2]).reset_index(drop=True)

    # 4. Row-standardised KNN spatial weights
    if len(gdf) <= k:
        raise ValueError(f"k={k} is too large for the dataset size of {len(gdf)}")
    w = KNN.from_dataframe(gdf, k=k)
    w.transform = "r"

    # 5. Bivariate Moran's I
    x = gdf[var1]
    y = gdf[var2]
    moran_bv = Moran_BV(x, y, w)

    # 6. Optional: sensitivity of the statistic to the choice of k
    if test_k:
        results = []
        # The stop value is exclusive, so every candidate satisfies k_ < len(gdf).
        for k_ in tqdm(np.arange(3, min(50, len(gdf)), 3), desc="Testing k values"):
            try:
                w_ = KNN.from_dataframe(gdf, k=k_)
                w_.transform = "r"
                m_bv = Moran_BV(x, y, w_)
            except Exception as e:
                # Best effort: report the failing k and continue with the next one.
                print(f"Fehler bei k={k_}: {e}")
            else:
                results.append({"k": k_, "Moran's I": m_bv.I, "p-value": m_bv.p_sim})

        df_results = pd.DataFrame(results)
        print(df_results)

    # 7. Optional: Moran scatter plot (standardised x against the spatial lag of standardised y)
    if plot:
        print(f"Bivariate Moran’s I: {moran_bv.I:.4f}")
        print(f"P-Wert (Monte Carlo): {moran_bv.p_sim:.4f}")

        z_x = (x - x.mean()) / x.std()
        z_y = (y - y.mean()) / y.std()
        wz_y = w.sparse @ z_y

        plt.figure(figsize=(8, 6))
        plt.scatter(z_x, wz_y, alpha=0.5)
        plt.axhline(0, color='red', linestyle='--')
        plt.axvline(0, color='red', linestyle='--')
        plt.title("Bivariate Moran's I Scatter Plot")
        plt.xlabel("Variable X (Standardized)")
        plt.ylabel("Spatial Lag of Y (Standardized)")
        plt.grid(True, linestyle='--', alpha=0.3)
        plt.tight_layout()
        plt.show()

    return moran_bv.I, moran_bv.p_sim


# Usage: bivariate Moran's I between `sla` and `pressure_msl` at one time step.
# test_k=True additionally prints the statistic for a range of k values.
time = "2024-10-20 19:00"
var1 = "sla"
var2 = "pressure_msl"

morans_I, morans_p = calculating_morans_I(ds_ocean_weather_interp, var1=var1, var2=var2, time=time, k=8, plot=False, test_k=True)





In [None]:
# Hourly time steps for which the bivariate Moran's I is evaluated.
# (The previous df_ocean-based start/end were overwritten immediately and have been removed.)
start = "2023-10-01 00:00"
end = "2023-11-01 00:00"
timepoints = pd.date_range(start=start, end=end, freq="1h")

var1 = "sla"
var2 = "wind_speed_10m"

# Evaluate the statistic once per time step and collect value and p-value.
list_morans_I = []
list_morans_p = []
for time in timepoints:
    morans_I, morans_p = calculating_morans_I(ds_ocean_weather_interp, var1=var1, var2=var2, time=time, k=5, plot=False)
    list_morans_I.append(morans_I)
    list_morans_p.append(morans_p)

array_morans_I = np.array(list_morans_I)
array_morans_p = np.array(list_morans_p)

# Plot the statistic as a time series; the dashed line marks I = 0.
plt.figure(figsize=(12, 6))
plt.plot(timepoints, array_morans_I, label="Moran's I", color="blue")
plt.axhline(0, color="red", linestyle="--")
plt.title("Moran's I over Time")
plt.xlabel("Time")
plt.ylabel("Moran's I")
plt.xticks(rotation=45)
plt.grid()
plt.legend()
plt.tight_layout()
plt.show()


Ziel: Berechnet die Pearson-Korrelation zwischen zwei Xarray-Daten (x und y), wobei dim='time' spezifiziert, dass die Korrelation entlang der Zeitdimension durchgeführt wird.

xr.corr: Dies ist eine Funktion von Xarray, die den Pearson-Korrelationskoeffizienten zwischen zwei Variablen berechnet, indem sie ihre Werte entlang einer bestimmten Dimension (in diesem Fall time) vergleicht.

x und y sind die beiden Xarray-Datenarrays (z. B. Zeitserien von sla und pressure_msl).
dim='time' bedeutet, dass die Korrelation für jede räumliche Position über die Zeit hinweg berechnet wird. Zum Beispiel, wenn du Zeitserien für verschiedene geografische Punkte hast, wird der Korrelationswert für jeden Punkt berechnet.

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt

# Alternative (disabled): select a contiguous date range instead of storm-surge time steps.
# start_time = "2023-01-19"
# end_time = "2023-10-21"

# Time steps at which the in-situ water level (df_insitu["slev"]) exceeds 1.0
df_stormsurge = df_insitu[df_insitu["slev"] > 1.0]
timepoints = df_stormsurge["time"].values

# Label-based selection: every entry of `timepoints` must exist on the dataset's time axis.
ds_ocean_weather_interp_sub = ds_ocean_weather_interp.sel(time=timepoints)

# Text appended after "... from" in the plot titles of the following cells.
#time_string = f"{start_time.strftime('%Y-%m-%d')} to {end_time.strftime('%Y-%m-%d')}"
time_string = 'all storm surge timepoints'

#ds_ocean_weather_interp_sub = ds_ocean_weather_interp.sel(time=slice(start_time, end_time))

def pearson_r(x, y):
    """Return the Pearson correlation of `x` and `y`, computed along the `time` dimension."""
    correlation = xr.corr(x, y, dim='time')
    return correlation

def plot_correlation(x, y, title="Correlation Map (Contour)"):
    """
    Draw the time-wise Pearson correlation of two fields as a filled contour map.

    Parameters:
        x, y (xarray.DataArray): Fields handed to `pearson_r`.
        title (str): Title placed above the map.
    """
    correlation_map = pearson_r(x, y)

    colorbar_options = {
        'label': 'Pearson Correlation Coefficient',
        'orientation': 'horizontal',
        'shrink': 0.8,
        'pad': 0.05,
    }
    gridline_style = dict(linewidth=0.5, color='gray', alpha=0.5, linestyle='--')

    fig = plt.figure(figsize=(10, 6))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Filled contours on a fixed [-1, 1] colour range
    correlation_map.plot.contourf(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap='coolwarm',
        levels=50,
        vmin=-1,
        vmax=1,
        cbar_kwargs=colorbar_options,
    )

    # Map decoration: coastlines, borders, shaded land, labelled graticule
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.LAND, facecolor='lightgrey', alpha=0.9)
    ax.gridlines(draw_labels=True, **gridline_style)
    ax.set_title(title)
    plt.show()

# Suppress RuntimeWarning from here on; the filter is process-wide, so it also
# applies to every cell executed afterwards.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Correlation map of `sla` against `pressure_msl` over the selected time steps.
plot_correlation(ds_ocean_weather_interp_sub['sla'], ds_ocean_weather_interp_sub['pressure_msl'], title=f"Correlation between SLA and Pressure from {time_string}")


In [None]:
def plot_correlation_with_currents(x, y, u, v, title="Correlation Map (Contour + Currents)", step=1, scale=10):
    """
    Plot the correlation map as filled contours with current arrows on top.

    Parameters:
        x, y (xarray.DataArray): Variables to compute the correlation from.
        u, v (xarray.DataArray): Zonal (uo) and meridional (vo) current components
            on a latitude/longitude grid (no time dimension).
        title (str): Title of the plot.
        step (int): Draw every `step`-th arrow along latitude and longitude (1 = all).
        scale (float): Passed to `quiver(scale=...)`; tune to the magnitude of the data.
    """
    correlation_map = pearson_r(x, y)

    plt.figure(figsize=(12, 6))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Filled contours on a fixed [-1, 1] colour range
    correlation_map.plot.contourf(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap='coolwarm',
        levels=21,
        vmin=-1,
        vmax=1,
        cbar_kwargs={'label': 'Pearson Correlation Coefficient',
                     'orientation': 'horizontal', 'shrink': 0.8, 'pad': 0.05},
    )

    # Thin out the arrows. The components are forced into (latitude, longitude)
    # order so that their 2D arrays match the meshgrid built below.
    thinning = slice(None, None, step)
    u_sub = u.isel(latitude=thinning, longitude=thinning).transpose("latitude", "longitude")
    v_sub = v.isel(latitude=thinning, longitude=thinning).transpose("latitude", "longitude")

    # 2D coordinate grids with shape (n_lat, n_lon)
    lon2d, lat2d = np.meshgrid(u_sub.longitude.values, u_sub.latitude.values)

    ax.quiver(
        lon2d,
        lat2d,
        u_sub.values,
        v_sub.values,
        transform=ccrs.PlateCarree(),
        scale=scale,
        width=0.002,
        color='black'
    )

    # Additional map features
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.LAND, facecolor='lightgrey', alpha=0.9)
    ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    ax.set_title(title)
    plt.show()

# Time-averaged current components over the selected time steps.
u_mean = ds_ocean_weather_interp_sub['uo'].mean(dim='time')
v_mean = ds_ocean_weather_interp_sub['vo'].mean(dim='time')


# Correlation of `sla` and `pressure_msl`, overlaid with the mean current arrows.
plot_correlation_with_currents(
    ds_ocean_weather_interp_sub['sla'],
    ds_ocean_weather_interp_sub['pressure_msl'],
    u_mean,
    v_mean,
    title=f"Correlation between SLA and Pressure with Currents from {time_string}"
)

In [None]:
import numpy as np
import xarray as xr


def select_nearest_valid_point(ds, variable_name, target_lat, target_lon):
    """
    Select the grid point closest to a target location that has valid data.

    A grid point counts as valid when `variable_name` is non-NaN for at least
    one time step.

    Parameters:
        ds (xr.Dataset): Dataset with `latitude`, `longitude` and `time` dimensions.
        variable_name (str): Variable used for the validity check (e.g. 'wind_speed_10m').
        target_lat (float): Target latitude in degrees.
        target_lon (float): Target longitude in degrees.

    Returns:
        xr.Dataset: Subset of `ds` at the nearest valid grid point.
        float: Latitude of that grid point.
        float: Longitude of that grid point.

    Raises:
        ValueError: If the variable has no valid grid point at all.
    """
    # Validity mask in (latitude, longitude) order so it matches the 'ij' meshgrid below.
    valid_mask = (
        ds[variable_name].notnull().any(dim='time').transpose('latitude', 'longitude').values
    )
    if not valid_mask.any():
        raise ValueError(f"No valid (non-NaN) grid point found for variable '{variable_name}'")

    # 2D coordinate grids, reduced to the valid cells
    lat2d, lon2d = np.meshgrid(ds.latitude.values, ds.longitude.values, indexing='ij')
    valid_lat_points = lat2d[valid_mask]
    valid_lon_points = lon2d[valid_mask]

    # Equirectangular approximation: a degree of longitude covers only
    # cos(latitude) times the ground distance of a degree of latitude, so
    # plain degree differences would favour east-west neighbours.
    lon_scale = np.cos(np.deg2rad(target_lat))
    distances = np.hypot(valid_lat_points - target_lat, (valid_lon_points - target_lon) * lon_scale)

    # Closest valid grid point
    min_idx = np.argmin(distances)
    nearest_lat = valid_lat_points[min_idx]
    nearest_lon = valid_lon_points[min_idx]

    print(f"Nächstgelegener gültiger Punkt: lat={nearest_lat:.6f}, lon={nearest_lon:.6f}")

    return ds.sel(latitude=nearest_lat, longitude=nearest_lon), nearest_lat, nearest_lon



In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from mpl_toolkits.basemap import Basemap
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler

# === 1. Load in-situ data ===
# Earlier inline loader, kept for reference:
# def load_insitu_data(filepath):
#     ds = xr.open_dataset(filepath)
#     df = ds.to_dataframe().reset_index()
#     return df.rename(columns={
#         "TIME": "time",
#         "SLEV": "slev",
#         "LATITUDE": "latitude",
#         "LONGITUDE": "longitude"
#     })

# NOTE(review): `flensburg_data_processing` and `interpolate_missing_times` are not
# among the helpers imported at the top of the notebook (which imports
# `process_flensburg_data` and `check_missing_times`) — confirm they are defined elsewhere.
df_insitu = load_insitu_data()
df_insitu = flensburg_data_processing(df_insitu)
df_insitu = interpolate_missing_times(df_insitu)

# Position of the gauge, taken from the first row of the in-situ frame.
insitu_location = (df_insitu.latitude.iloc[0], df_insitu.longitude.iloc[0])

# Nearest grid point with valid wind data around the target location.
# NOTE(review): the target is hard-coded to (54.5, 10.0) although LAT_FLENSBURG /
# LON_FLENSBURG are imported from utils.config — confirm this is the intended location.
target_lat = 54.5
target_lon = 10.0
flensburg_ds, nearest_lat, nearest_lon = select_nearest_valid_point(ds_ocean_weather_interp_sub, 'wind_speed_10m', target_lat, target_lon)



# === 2. Prepare weather data ===
# Requires flensburg_ds from the step above; the frame is indexed by time.
wetter_df = flensburg_ds.to_dataframe().reset_index().set_index("time")
wetter_location = (wetter_df.latitude.iloc[0], wetter_df.longitude.iloc[0])


# === 3. Karte mit Basemap zeichnen ===
def plot_locations(insitu_loc, weather_loc):
    """
    Show the in-situ station and the selected grid point on a Mercator map.

    Parameters:
        insitu_loc (tuple): (latitude, longitude) of the in-situ station.
        weather_loc (tuple): (latitude, longitude) of the selected grid point.
    """
    insitu_lat, insitu_lon = insitu_loc[0], insitu_loc[1]
    weather_lat, weather_lon = weather_loc[0], weather_loc[1]

    # Map window: +/- 2 degrees around the midpoint of both locations
    center_lat = (insitu_lat + weather_lat) / 2
    center_lon = (insitu_lon + weather_lon) / 2
    half_extent = 2

    plt.figure(figsize=(10, 8))
    m = Basemap(
        projection='merc',
        llcrnrlat=center_lat - half_extent,
        urcrnrlat=center_lat + half_extent,
        llcrnrlon=center_lon - half_extent,
        urcrnrlon=center_lon + half_extent,
        resolution='i',
    )

    # Base map with a 0.5 degree graticule
    m.drawcoastlines()
    m.drawcountries()
    m.drawmapboundary(fill_color='white')
    m.fillcontinents(color='lightgray', lake_color='white')
    m.drawparallels(np.arange(0., 90., 0.5), labels=[1,0,0,0])
    m.drawmeridians(np.arange(0., 180., 0.5), labels=[0,0,0,1])

    # Project both locations into map coordinates
    x_insitu, y_insitu = m(insitu_lon, insitu_lat)
    x_weather, y_weather = m(weather_lon, weather_lat)

    markers = (
        (x_insitu, y_insitu, 'bo', 'In-situ Location'),
        (x_weather, y_weather, 'ro', 'Weather Location'),
    )
    for x_pos, y_pos, marker_style, legend_label in markers:
        m.plot(x_pos, y_pos, marker_style, markersize=8, label=legend_label)

    # Annotations, offset by 10000 map units so they do not cover the markers
    annotations = (
        (x_insitu, y_insitu, 'Flensburg In-Situ', 'black'),
        (x_weather, y_weather, 'Closest Point', 'red'),
    )
    for x_pos, y_pos, text, text_color in annotations:
        plt.text(x_pos+10000, y_pos+10000, text, color=text_color)

    plt.legend(loc='upper left')
    plt.title('In-situ vs. Wetterpunkt')
    plt.show()

plot_locations(insitu_location, wetter_location)

# === 4. Temporal alignment: hourly means on a time index ===
df_insitu = df_insitu.set_index("time").resample("h").mean(numeric_only=True)
wetter_df = wetter_df.resample("h").mean()

# === 5. Merge both frames on their shared time stamps ===
merged_df = pd.merge(df_insitu, wetter_df, left_index=True, right_index=True, how="inner")
display(merged_df.head(2))
display(merged_df.tail(2))

# === 6. Exploratory analysis ===
features = ['sla', 'slev', 'wind_speed_10m', 'surface_pressure', 'precipitation', 'wind_direction_10m', 'vo']
#features = merged_df.columns.tolist()


sns.pairplot(merged_df[features], diag_kind='kde')
plt.show()

# === 7. Correlations with the in-situ water level ===
correlation = merged_df.corr(numeric_only=True)
print("Korrelationsmatrix:")
display(correlation["slev"].sort_values(ascending=False))

# 10 largest and 10 smallest correlations with slev. With fewer than 20 usable
# columns the two selections overlap, so duplicates are removed (order kept) to
# avoid duplicate columns in the plots and in the regression.
features = list(dict.fromkeys(
    correlation["slev"].nlargest(10).index.tolist()
    + correlation["slev"].nsmallest(10).index.tolist()
))
print(features)

# === 8. Time series of the selected features ===
merged_df[features].plot(
    subplots=True, figsize=(10, 17), title=f"Zeitreihen from {time_string}")
plt.tight_layout()
plt.legend(loc='upper left')
plt.xlabel("Time")
plt.show()

# === 9. Lineare Regression ===
def run_regression(df, target_col, feature_cols):
    """
    Fit a linear regression on min-max scaled features and print the result.

    Rows with a missing value in the target or in any feature are dropped
    first, because LinearRegression rejects NaN input.

    Parameters:
        df (pd.DataFrame): Data containing the target and feature columns.
        target_col (str): Name of the target column.
        feature_cols (list): Names of the feature columns.

    Raises:
        ValueError: If no complete row remains after dropping missing values.
    """
    feature_cols = list(feature_cols)
    complete_rows = df.dropna(subset=[target_col, *feature_cols])
    if complete_rows.empty:
        raise ValueError("No rows without missing values available for the regression")

    X = complete_rows[feature_cols]
    y = complete_rows[target_col]

    # Scale every feature to [0, 1] so the coefficients are comparable in size
    scaler = MinMaxScaler()
    X_scaled = scaler.fit_transform(X)

    model = LinearRegression()
    model.fit(X_scaled, y)

    print("\nRegressionskoeffizienten:")
    for feat, coef in zip(feature_cols, model.coef_):
        print(f"{feat}: {coef:.4f}")
    # R² is evaluated on the training data itself (no hold-out set)
    print(f"R² Score: {model.score(X_scaled, y):.4f}")

# features[0] is the column with the largest correlation to slev — presumably slev
# itself (r = 1) — so it is skipped as a predictor. NOTE(review): confirm.
run_regression(merged_df, "slev", features[1:])


In [None]:
# Notebook display: show the current feature list.
features

In [None]:
# NOTE(review): `ds_merged_sub_sub` is not defined in the visible part of the
# notebook — confirm it is created upstream, otherwise this cell raises NameError.
ds_merged_sub_sub

In [None]:
# NOTE(review): `dd` is not defined in the visible source; presumably a deliberate
# NameError that stops "Run All" at this point — confirm.
dd

In [None]:
def plot_scatter(ds_merged, x_col, y_col, c, title, xlabel, ylabel, dim=None):
    """
    Scatter two variables against each other after averaging over `dim`.

    Parameters:
        ds_merged (xr.Dataset): Dataset containing the three variables.
        x_col (str): Variable shown on the x axis.
        y_col (str): Variable shown on the y axis.
        c (str): Variable used for the point colour (and the colorbar label).
        title (str): Plot title.
        xlabel (str): Label of the x axis.
        ylabel (str): Label of the y axis.
        dim (list | None): Dimensions to average over; defaults to
            ['latitude', 'longitude'].
    """
    # A fresh list per call instead of a shared mutable default argument
    mean_dims = ['latitude', 'longitude'] if dim is None else dim

    plt.figure(figsize=(10, 6))
    plt.scatter(
        ds_merged[x_col].mean(dim=mean_dims),
        ds_merged[y_col].mean(dim=mean_dims),
        c=ds_merged[c].mean(dim=mean_dims),
        cmap='viridis',
        alpha=0.5,
    )
    plt.colorbar(label=c)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

# x axis: spatial mean of wind_speed_10m, y axis and colour: spatial mean of sla.
# Title and axis labels now describe the variables that are actually plotted.
plot_scatter(ds_merged=ds_merged_sub,
             x_col='wind_speed_10m',
             y_col='sla',
             c='sla',
             title='SLA vs Wind Speed 10m',
             xlabel='mean Wind Speed 10m',
             ylabel='mean SLA [m]',
             )


# col1 = ds_merged['sla'].mean(dim=['latitude', 'longitude'])
# col2 = ds_merged['wind_speed_10m'].mean(dim=['latitude', 'longitude'])

# plt.scatter(col1, col2, alpha=0.5, marker='o', cmap='viridis', c=ds_merged['wind_direction_10m'].mean(dim=['latitude', 'longitude']))
# plt.xlabel('mean SLA [m]')
# plt.ylabel('mean Wind Speed [m/s]')
# plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr


def plot_contour(ds_merged, x_col, y_col, c, title, xlabel, ylabel, bins=500, dim=None):
    """
    Draw a filled contour plot of `c` over the plane spanned by two other variables.

    All three variables are first averaged over `dim`; the remaining values are
    treated as scattered (x, y, z) samples and linearly interpolated onto a
    regular `bins` x `bins` grid.

    Parameters:
        ds_merged (xr.Dataset): Dataset containing the three variables.
        x_col (str): Variable on the x axis.
        y_col (str): Variable on the y axis.
        c (str): Variable shown as colour (and used as colorbar label).
        title (str): Plot title.
        xlabel (str): Label of the x axis.
        ylabel (str): Label of the y axis.
        bins (int): Number of grid points per axis.
        dim (list | None): Dimensions to average over; defaults to
            ['latitude', 'longitude'].

    Raises:
        ValueError: If no sample with finite x, y and z values remains.
    """
    from scipy.interpolate import griddata

    # A fresh list per call instead of a shared mutable default argument
    mean_dims = ['latitude', 'longitude'] if dim is None else dim

    # Averages over the requested dimensions, flattened to 1D samples
    x = ds_merged[x_col].mean(dim=mean_dims).values.flatten()
    y = ds_merged[y_col].mean(dim=mean_dims).values.flatten()
    z = ds_merged[c].mean(dim=mean_dims).values.flatten()

    # Keep only samples where all three values are present
    mask = ~np.isnan(x) & ~np.isnan(y) & ~np.isnan(z)
    x, y, z = x[mask], y[mask], z[mask]
    if x.size == 0:
        raise ValueError("No samples without NaN available for the contour plot")

    # Regular target grid spanning the observed x and y ranges
    xi = np.linspace(np.min(x), np.max(x), bins)
    yi = np.linspace(np.min(y), np.max(y), bins)
    Xi, Yi = np.meshgrid(xi, yi)

    # Linear interpolation of the scattered samples onto the grid
    Zi = griddata((x, y), z, (Xi, Yi), method='linear')

    plt.figure(figsize=(10, 6))
    contour = plt.contourf(Xi, Yi, Zi, levels=50, cmap='viridis')
    cbar = plt.colorbar(contour)
    cbar.set_label(c)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.show()

# Colour: spatial mean of sla over the plane of mean wind speed and mean wind direction.
# Title and axis labels now describe the variables that are actually plotted.
plot_contour(ds_merged=ds_merged,
             x_col='wind_speed_10m',
             y_col='wind_direction_10m',
             c='sla',
             title='Contourplot: SLA by Wind Speed and Wind Direction',
             xlabel='mean Wind Speed 10m',
             ylabel='mean Wind Direction 10m [°]',
             bins=100,
            )


In [None]:
# NOTE(review): `dd` is not defined in the visible source; presumably a deliberate
# NameError that stops "Run All" at this point — confirm.
dd

In [None]:
# Fill gaps in the weather grid by linear interpolation,
# first along latitude and then along longitude.
ds_interp = (
    ds_weather
    .interpolate_na(dim="latitude", method="linear")
    .interpolate_na(dim="longitude", method="linear")
)
ds_interp

In [None]:
import numpy as np

# Existing grid
old_lats = ds_interp.latitude
old_lons = ds_interp.longitude

# Finer target grid (0.25 degree spacing). np.arange excludes its stop value, so
# a tiny fraction of the step is added to keep the upper boundary when it falls
# on the new grid. Clipping guards against floating-point overshoot beyond the
# original extent, which would otherwise interpolate to NaN.
grid_step = 0.25
lat_min, lat_max = float(old_lats.min()), float(old_lats.max())
lon_min, lon_max = float(old_lons.min()), float(old_lons.max())
new_lats = np.clip(np.arange(lat_min, lat_max + grid_step * 1e-6, grid_step), lat_min, lat_max)
new_lons = np.clip(np.arange(lon_min, lon_max + grid_step * 1e-6, grid_step), lon_min, lon_max)

# Linear interpolation onto the finer grid
ds_interp = ds_interp.interp(latitude=new_lats, longitude=new_lons, method='linear')


In [None]:
# Notebook display: inspect the interpolated dataset.
ds_interp

In [None]:
import matplotlib.pyplot as plt

# Select the mean-sea-level pressure field at one time step.
# NOTE: the variable keeps its historical name `temperature`, but it holds pressure_msl.
temperature = ds_interp['pressure_msl'].sel(time='2023-10-18T12:00:00')

# Plot
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.PlateCarree())  # or e.g. ccrs.Mercator()

# Colorbar label and title now match the plotted variable (pressure, not temperature)
temperature.plot(ax=ax, transform=ccrs.PlateCarree(), cmap='coolwarm', cbar_kwargs={'label': 'Pressure MSL [hPa]'})

# Extras: coastlines etc.
ax.coastlines()
ax.add_feature(cfeature.BORDERS)
ax.set_title("Luftdruckkarte")

plt.show()


In [None]:
# NOTE: despite its name, `temperature` holds the pressure_msl field.
temperature = ds_interp['pressure_msl'].sel(time='2023-10-18T12:00:00')

fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()})
ax.coastlines()
ax.add_feature(cfeature.BORDERS)

# Filled contours. `linewidths` was removed because it only applies to line
# contours (`contour`), not to `contourf`.
cs = ax.contourf(temperature.longitude, temperature.latitude, temperature,
                 levels=20, cmap='coolwarm',
                 transform=ccrs.PlateCarree())
# Colour scale for the filled contours
fig.colorbar(cs, ax=ax, label='Pressure MSL [hPa]')

ax.set_title("Konturlinien des Luftdrucks")
plt.show()


In [None]:
# Notebook display: inspect the interpolated dataset.
ds_interp

In [None]:
# NOTE(review): a later cell derives wind components from wind_speed_10m and
# wind_direction_10m — confirm that ds_interp actually contains 'u10' / 'v10'.
u = ds_interp['u10'].sel(time='2023-10-18T12:00:00')  # eastward component
v = ds_interp['v10'].sel(time='2023-10-18T12:00:00')  # northward component

fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()})
ax.coastlines()
ax.add_feature(cfeature.BORDERS)

# Wind arrows at every 5th grid point. The coordinates of ds_interp are named
# `longitude` / `latitude` (not `lon` / `lat`).
q = ax.quiver(u.longitude[::5], u.latitude[::5], u[::5, ::5], v[::5, ::5], transform=ccrs.PlateCarree(), scale=700)

ax.set_title("Windvektoren")
plt.show()


In [None]:
# Streamlines of the wind field held in the notebook globals `u` and `v`.
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()})
ax.coastlines()

# The coordinates are named `longitude` / `latitude` (not `lon` / `lat`).
ax.streamplot(u.longitude.values, u.latitude.values, u.values, v.values, transform=ccrs.PlateCarree(), color='blue', density=1.5)
ax.set_title("Stromlinien der Windrichtung")
plt.show()


In [None]:
# Spatial means per time step of temperature_2m and relative_humidity_2m.
temp = ds_interp['temperature_2m'].mean(dim=['latitude', 'longitude'])
humidity = ds_interp['relative_humidity_2m'].mean(dim=['latitude', 'longitude'])

# One point per time step: mean temperature against mean relative humidity.
plt.scatter(temp, humidity, alpha=0.5)
plt.xlabel('Temperatur [°C]')
plt.ylabel('relative Luftfeuchtigkeit [%]')
plt.show()


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Example time step
time_sel = '2023-10-20T12:00'

# Select the fields.
# NOTE: the variable keeps its historical name `temp`, but it holds relative_humidity_2m.
temp = ds_interp['relative_humidity_2m'].sel(time=time_sel)
wind_speed = ds_interp['wind_speed_10m'].sel(time=time_sel)
wind_dir = ds_interp['wind_direction_10m'].sel(time=time_sel)

# Convert speed and direction (degrees) into u/v components. The minus signs
# turn the direction the wind comes from into the direction it blows towards.
wind_u = wind_speed * -np.sin(np.deg2rad(wind_dir))
wind_v = wind_speed * -np.cos(np.deg2rad(wind_dir))

# Create the plot
fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'projection': ccrs.PlateCarree()})

# Background field; colorbar label and title now match the plotted variable
temp.plot(ax=ax, transform=ccrs.PlateCarree(), cmap='coolwarm', cbar_kwargs={'label': 'relative Luftfeuchtigkeit [%]'})

# Coastlines, borders, shaded land
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linestyle=':')
ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3)

# Wind arrows, thinned out for readability (every `step`-th grid point)
step = 2
lat = ds_interp.latitude[::step]
lon = ds_interp.longitude[::step]
u = wind_u[::step, ::step]
v = wind_v[::step, ::step]

ax.quiver(lon, lat, u, v, transform=ccrs.PlateCarree(), color='black', scale=700)

# Title
ax.set_title(f"Relative Luftfeuchtigkeit und Wind am {time_sel}",
             #fontsize=14
             )
plt.tight_layout()
plt.show()


In [None]:
# NOTE(review): `dd` is not defined in the visible source; presumably a deliberate
# NameError that stops "Run All" at this point — confirm.
dd

### Merge Dataframes

In [None]:
import pandas as pd


def process_coord(df: pd.DataFrame, coord: tuple) -> pd.DataFrame:
    """
    Extract the rows recorded at `coord` and tag each value column with it.

    The `latitude`, `longitude` and `position` columns are removed; every
    remaining column except `time` is renamed to ``"<column>_<coord>"``.
    """
    at_coord = df['position'] == coord
    subset = df.loc[at_coord].drop(columns=['latitude', 'longitude', 'position'])

    # Suffix every value column with the coordinate, leave 'time' untouched
    renamed = {col: f"{col}_{coord}" for col in subset.columns if col != 'time'}
    return subset.rename(columns=renamed)

def convert_df_joblib(df: pd.DataFrame, n_jobs: int = -1) -> pd.DataFrame:
    """
    Pivot a long-format frame into one column per variable and coordinate.

    Each unique (latitude, longitude) pair is processed by `process_coord` in
    parallel via joblib; the partial frames are then left-joined on `time`.

    Parameters:
        df (pd.DataFrame): Frame with `time`, `latitude`, `longitude` and value columns.
        n_jobs (int): Number of joblib workers (-1 = all cores).

    Returns:
        pd.DataFrame: One row per unique time, columns named ``"<variable>_<coord>"``.
    """
    # assign() returns a new frame, so the caller's DataFrame is not modified
    # by the helper column.
    df = df.assign(position=list(zip(df['latitude'], df['longitude'])))
    unique_coords = df['position'].unique()
    print(f"Number of unique coordinates: {len(unique_coords)}")

    df_merged = pd.DataFrame({'time': df['time'].unique()})

    # Parallel processing, one task per coordinate
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_coord)(df, coord) for coord in tqdm(unique_coords)
    )

    # Merge all partial DataFrames onto the common time axis
    for df_sub in results:
        df_merged = df_merged.merge(df_sub, on='time', how='left')

    return df_merged

# Usage: wide-format versions (one column per variable and coordinate) of the
# ocean and weather frames.
df_ocean_converted = convert_df_joblib(df_ocean)
df_weather_converted = convert_df_joblib(df_weather)



In [None]:
# Inner-join the wide ocean and weather frames with the in-situ data on `time`,
# keeping only time stamps present in all three.
df_merged = (
    df_ocean_converted
    .merge(df_weather_converted, on='time', how='inner')
    .merge(df_insitu, on='time', how='inner')
)

In [None]:
# Notebook display: (rows, columns) of the merged frame.
df_merged.shape

In [None]:
# Notebook display: column labels of the merged frame.
df_merged.columns

In [None]:
# correlation matrix
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Pairwise correlations of all numeric columns. numeric_only=True keeps the
# non-numeric `time` column out of the computation, consistent with the
# merged_df.corr(numeric_only=True) call used earlier in the notebook.
corr = df_merged.corr(numeric_only=True)
# Optional heatmap of the lower triangle (disabled):
# mask = np.triu(np.ones_like(corr, dtype=bool))
# sns.heatmap(corr, mask=mask, cmap="coolwarm", annot=False, fmt=".2f", square=True, cbar_kws={"shrink": .8})
# plt.title("Correlation Matrix")
# plt.show()





In [None]:
# The 100 columns with the largest correlation to slev
corr_slev = corr["slev"].nlargest(100)


# Display the correlation values in a bar plot (one bar per column)
plt.figure(figsize=(12, 8))
sns.barplot(x=corr_slev.index, y=corr_slev.values)
plt.xticks(rotation=90)
plt.title("Correlation with slev")
plt.xlabel("Features")
plt.ylabel("Correlation")
plt.tight_layout()
plt.show()

In [None]:
# Notebook display: names of the columns currently held in corr_slev.
corr_slev.index

In [None]:
# The 100 columns with the smallest correlation to slev.
# NOTE: this overwrites `corr_slev`, so later lookups in corr_slev only see these columns.
corr_slev = corr["slev"].nsmallest(100)


# Display the correlation values in a bar plot (one bar per column)
plt.figure(figsize=(12, 8))
sns.barplot(x=corr_slev.index, y=corr_slev.values)
plt.xticks(rotation=90)
plt.title("Correlation with slev")
plt.xlabel("Features")
plt.ylabel("Correlation")
plt.tight_layout()
plt.show()

In [None]:
def get_index_name(corr_slev: pd.Series, name: str) -> str:
    """
    Return the first index label of `corr_slev` that contains `name`.

    Labels are checked in index order; None is returned when no label matches.
    """
    matching_labels = (label for label in corr_slev.index if name in label)
    return next(matching_labels, None)

# First column in corr_slev whose name contains the given variable name.
# A lookup yields None when corr_slev holds no matching column.
sla_col = get_index_name(corr_slev, "sla")
wo_col = get_index_name(corr_slev, "wo")
pressure_msl_col = get_index_name(corr_slev, "pressure_msl")
surface_pressure_col = get_index_name(corr_slev, "surface_pressure")
# The printed label now names the variable that is actually shown.
print(f"SLA column: {sla_col}")


In [None]:
import pandas as pd


def scale_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Scale all columns except `time` to the range [0, 1] with MinMaxScaler.

    Parameters:
        df (pd.DataFrame): Frame with a `time` column and numeric value columns.

    Returns:
        pd.DataFrame: Scaled value columns followed by the original `time`
        column, carrying the same index as `df`.
    """
    scaler = MinMaxScaler()

    # Everything except 'time' is scaled
    data_to_scale = df.drop(columns=["time"])
    scaled_values = scaler.fit_transform(data_to_scale)

    # Reuse the input index: the 'time' assignment below aligns on index labels,
    # so a fresh RangeIndex would scramble or blank the time values whenever
    # `df` does not carry a default 0..n-1 index.
    df_scaled = pd.DataFrame(scaled_values, columns=data_to_scale.columns, index=df.index)
    df_scaled["time"] = df["time"]

    return df_scaled

# Example usage: min-max scaled copy of the merged frame
df_scaled = scale_df(df_merged)


In [None]:
def get_position_of_column(df: pd.DataFrame, col_name: str) -> tuple:
    """
    Extract the (latitude, longitude) encoded at the end of a column name.

    Column names are expected to look like ``"<variable>_(<lat>, <lon>)"``.
    The coordinate part is parsed with `ast.literal_eval`, which only accepts
    Python literals, instead of `eval`, which would execute arbitrary code.

    Parameters:
        df (pd.DataFrame): Unused; kept for interface compatibility.
        col_name (str): Column name ending in a coordinate tuple.

    Returns:
        tuple: (latitude, longitude) as floats.

    Raises:
        ValueError: If the column name does not end in a parseable coordinate pair.
    """
    import ast  # local import keeps this notebook cell self-contained

    # Everything after the last underscore is the textual coordinate tuple
    coord_text = col_name.split("_")[-1]
    try:
        coord = ast.literal_eval(coord_text)
        lat = float(coord[-2])
        lon = float(coord[-1])
    except (ValueError, SyntaxError, TypeError, IndexError) as err:
        raise ValueError(
            f"Column name {col_name!r} does not end in a '(lat, lon)' tuple"
        ) from err
    return lat, lon

In [None]:
# Notebook display: the column name found for pressure_msl (None if there was no match).
pressure_msl_col

In [None]:
# Compare the scaled water level with one scaled model column and show where
# that column's grid point lies on the map.

fig, ax = plt.subplots(figsize=(12, 6))

# Scaled in-situ water level
ax.plot(df_scaled['time'], df_scaled['slev'], label='Wasserstand', color='blue')

# Other candidate columns (disabled):
#ax.plot(df_scaled['time'], df_scaled[sla_col], label=sla_col, color='red')

ax.plot(df_scaled['time'], df_scaled[wo_col], label=wo_col, color='green')

#ax.plot(df_scaled['time'], df_scaled[pressure_msl_col], label=pressure_msl_col, color='orange')

# The lines carry labels, so show them
ax.legend()

# Position encoded in the name of the plotted column
lat, lon = get_position_of_column(df_scaled, wo_col)
print(f"Position of {wo_col}: {lat}, {lon}")
# Create a Basemap covering the full grid extent
fig_map, ax_map = plt.subplots(figsize=(12, 10))
m = Basemap(
    projection="cyl",
    resolution="i",
    llcrnrlon=lon_grid.min(),
    urcrnrlon=lon_grid.max(),
    llcrnrlat=lat_grid.min(),
    urcrnrlat=lat_grid.max(),
    ax=ax_map,
)
# Draw map features
m.drawcoastlines()
m.drawcountries()
m.fillcontinents(color="0.8")
m.drawstates()
m.drawmapboundary(fill_color="aqua")
m.fillcontinents(color="coral", lake_color="aqua", alpha=0.2)
# Marker at the grid point of the plotted column
x, y = m(lon, lat)
m.scatter(x, y, color="blue", label="Ocean Data", zorder=5)
# Title names the column whose position is shown (wo_col, not pressure_msl_col)
plt.title(f"Position of {wo_col} on map")
plt.legend(loc="upper left")
plt.show()

