In [None]:
# Enable IPython's autoreload extension so edits to imported project modules are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import pyrootutils

# Locate the project root by searching from the current working directory for the
# ".project-root" indicator.
# NOTE(review): pythonpath=True presumably puts the root on sys.path (the `import src.utils`
# further down relies on something doing so) and dotenv=True presumably loads a .env file —
# confirm against the pyrootutils documentation.
root = pyrootutils.setup_root(
    search_from=os.getcwd(),
    indicator=".project-root",
    pythonpath=True,
    dotenv=True,
)

In [None]:
import darts.dataprocessing.encoders
import darts.utils.statistics
import hydra.utils
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import src.utils
import src.utils.plotting

# %matplotlib inline
# %matplotlib notebook

# Configuration

In [None]:
def get_config(overrides=None):
    """Compose the training configuration through `src.utils.initialize_hydra`.

    Args:
        overrides: Optional iterable of extra overrides in the same dot notation as
            command-line overrides (e.g. ``["experiment=veas_pilot_test"]``). They are
            appended after the built-in ``"datamodule=veas_pilot"`` override.

    Returns:
        The value returned by `src.utils.initialize_hydra` for ``configs/train.yaml``
        with the overrides applied.
    """
    # NB: must be a relative path; it is interpreted relative to <project_root>/src/utils.
    train_config_path = os.path.join("..", "..", "configs", "train.yaml")

    # Dot-notation overrides are useful for swapping whole modules, e.g. which datamodule
    # file is loaded.
    extra_overrides = [] if overrides is None else overrides
    dot_overrides = ["datamodule=veas_pilot", *extra_overrides]

    # Dictionary overrides are meant for larger changes/additions/deletions that do not
    # exist as entire config files. None are applied here.
    dict_overrides = {}

    return src.utils.initialize_hydra(
        train_config_path,
        dot_overrides,
        dict_overrides,
        return_hydra_config=False,
        # Printing is switched off; presumably True prints the composed config for inspection.
        print_config=False,
    )

In [None]:
# Build the datamodule for the "veas_pilot_test" experiment and set it up for the "fit" stage.
show_encoders = False
cfg = get_config(["experiment=veas_pilot_test"])
datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
datamodule.setup("fit")

# Encoders are only built on request, and only when the model config defines `add_encoders`.
encoders = None
if show_encoders and cfg.model.get("add_encoders") is not None:
    encoder_config = hydra.utils.instantiate(cfg.model.add_encoders)
    encoders = darts.dataprocessing.encoders.SequentialEncoder(
        encoder_config,
        takes_past_covariates=True,
        takes_future_covariates=True,
    )

# Data Exploration

## Plot Nitrate

In [None]:
src.utils.plotting.set_matplotlib_attributes(font_size=8)

# Datamodule built from the notebook's default config (no experiment override).
nitrate_cfg = get_config()
nitrate_datamodule = hydra.utils.instantiate(nitrate_cfg.datamodule, _convert_="partial")
nitrate_datamodule.setup("fit")

# None plots all data; a (start_time, end_time) tuple restricts the plot to that window,
# e.g. (pd.Timestamp("1966"), pd.Timestamp("1975")).
# The value lives under `time_slice` so the built-in `slice` is not shadowed for the rest
# of the kernel session.
time_slice = None
fig = nitrate_datamodule.plot_data(presenter=None, slice=time_slice)

# Reset so the figure-styling cell captures the default legend again on its next run.
og_legend = None

In [None]:
# Inspect a few of the patches drawn on the first axes of the nitrate figure.
# NOTE(review): this looks at patches 0, 2 and 11, whereas the legend in the following cell is
# built from patches 0, 5 and 11 — confirm which indices correspond to Train/Val/Test.
inspected_patch_indices = (0, 2, 11)
[
    patch
    for patch_index, patch in enumerate(fig[0].axes[0].patches)
    if patch_index in inspected_patch_indices
]

In [None]:
src.utils.plotting.set_matplotlib_attributes(font_size=8)

# Output folder (relative to the notebook) and figure height passed to set_figure_size.
data_fig_path = os.path.join("..", "figures", "pilot")
data_plot_height = 6

ax = fig[0].axes[0]
# Capture the default legend once; `og_legend` is reset to None by the cell creating the figure.
if og_legend is None:
    og_legend = ax.legend()

src.utils.plotting.set_figure_size(fig[0], "double", height=data_plot_height)
ax.set_xlabel("")
ax.set_ylabel("Nitrate concentration [mg/l]")
ax.set_title("Nitrate out")

# Presumably one representative patch per split; the indices are hard-coded.
split_patch_indices = (0, 5, 11)
legend_handles = [
    patch
    for patch_index, patch in enumerate(fig[0].axes[0].patches)
    if patch_index in split_patch_indices
]
ax.legend(handles=legend_handles, labels=["Train", "Val", "Test"])

src.utils.plotting.save_figure(fig[0], os.path.join(data_fig_path, "nitrate_out"))
fig[0]

## Nitrate in and removal

In [None]:
# Same base config, but with the datamodule's future covariates overridden to [waterflow].
# NOTE(review): `flow_datamodule` is not referenced again in the visible part of this notebook;
# the cell below uses a constant flow rate instead.
cfg_with_waterflow = get_config(["datamodule.data_variables.future_covariates=[waterflow]"])
flow_datamodule = hydra.utils.instantiate(cfg_with_waterflow.datamodule, _convert_="partial")
flow_datamodule.setup("fit")

In [None]:
# Three stacked panels on a shared x-axis: loading rate, reduction rate, degree of reduction.
f, axs = plt.subplots(3, 1, sharex=True)

# NOTE(review): presumably a constant 3.3 l/s flow converted to l/h — confirm.
flow_rate_l_hour = 3.3 * 3600
mg_per_kg = 1e6  # (kg / 1e6 mg)

pd_data = nitrate_datamodule.data[0].pd_dataframe()
nitrate_removed = pd_data["nitrate_in"] - pd_data["nitrate_out"]

# --- Panel 1: loading rate from the inlet concentration ---
ax = axs[0]
input_load = nitrate_datamodule.data[0]["nitrate_in"] * flow_rate_l_hour / mg_per_kg
input_load.plot(label="_nolegend_", ax=ax)
ax.set_xlabel("")
ax.set_ylabel("Loading rate [kg N/h]")
ax.set_title("Nitrate nitrogen loading rate")

# --- Panel 2: reduction rate, with negative values set to zero ---
ax = axs[1]
reduced_load_values = nitrate_removed * flow_rate_l_hour / mg_per_kg
reduced_load_values = reduced_load_values.mask(reduced_load_values < 0, 0)
reduced_load = darts.timeseries.TimeSeries.from_dataframe(
    pd.DataFrame({"reduced_load": reduced_load_values}),
)
reduced_load.plot(label="_nolegend_", ax=ax)
ax.set_xlabel("")
ax.set_ylabel("Reduction rate [kg N/h]")
ax.set_title("Nitrate nitrogen reduction rate")

# --- Panel 3: fraction of inlet nitrate removed, limited to [0, 1] ---
ax = axs[2]
conversion_values = nitrate_removed / pd_data["nitrate_in"]
conversion_values = conversion_values.mask(conversion_values > 1, 1)
conversion_values = conversion_values.mask(conversion_values < 0, 0)
nitrate_conversion = darts.timeseries.TimeSeries.from_dataframe(
    pd.DataFrame({"nitrate_conversion": conversion_values}),
)
nitrate_conversion.plot(label="_nolegend_", ax=ax)
ax.set_xlabel("")
ax.set_title("Degree of nitrate reduction")
ax.set_ylabel("Nitrate reduction [-]")

f.align_ylabels()
src.utils.plotting.set_figure_size(f, column_span="double", height=15)
src.utils.plotting.save_figure(f, os.path.join(data_fig_path, "nitrate_rates_and_conversion"))

In [None]:
# Stand-alone view of the conversion series (limited to [0, 1]) built in the previous cell.
nitrate_conversion.plot()

In [None]:
# Raw difference nitrate_in - nitrate_out, without the zero/one limits applied above.
(pd_data["nitrate_in"] - pd_data["nitrate_out"]).plot()

### Temperature Histogram for paper

In [None]:
src.utils.plotting.set_matplotlib_attributes()

# "hist": seaborn step histogram of the raw values; "bar": grouped bar chart of pre-binned counts.
plot_type = "hist"
# Only used by "hist": plot per-split densities (each normalised separately) instead of raw counts.
density = False

fig_folder = "../figures/pilot/eda/"

split_names = ["train", "val", "test"]

# Untransformed future covariates for each split.
data = {
    split: datamodule.get_data(["future_covariates"], main_split=split, transform=False)
    for split in split_names
}

bins = np.linspace(6, 16, 50)
bin_centers = 0.5 * (bins[:-1] + bins[1:])

# Flat array of temperature values per split.
temp_data = dict()
for split in split_names:
    split_covariates = data[split]["future_covariates"]
    if split == "test":
        # The test covariates are handled as a sequence of series; join their values.
        split_temp_data = np.concatenate(
            [series["temp"].all_values().squeeze() for series in split_covariates]
        )
    else:
        split_temp_data = split_covariates["temp"].all_values().squeeze()
    temp_data[split] = split_temp_data

if plot_type == "bar":
    # Bin each split's OWN values. Histogramming the leftover `split_temp_data` would count the
    # last split three times. Counts go in a separate dict so the raw values in `temp_data`
    # are not overwritten.
    temp_counts = {
        split_name: np.histogram(temp_data[split_name], bins=bins)[0]
        for split_name in split_names
    }

    df = pd.DataFrame(
        {
            "Bin": np.tile(bin_centers, len(split_names)),
            "Count": np.concatenate([temp_counts[split_name] for split_name in split_names]),
            "Dataset": ["train"] * len(temp_counts["train"])
            + ["val"] * len(temp_counts["val"])
            + ["test"] * len(temp_counts["test"]),
        }
    )
    sns.barplot(x="Bin", y="Count", hue="Dataset", data=df)
    # Set the label inside the branch so a shared label call cannot overwrite it.
    plt.xlabel("Temperature °C")
    xticks, xticklabels = plt.xticks()
    # Label every third bar with its bin centre.
    plt.xticks(xticks[::3], [f"{tick:.1f}" for tick in bin_centers[::3]])
elif plot_type == "hist":
    df = pd.DataFrame(
        {
            "Temp": np.concatenate([temp_data[split_name] for split_name in split_names]),
            "Dataset": ["train"] * len(temp_data["train"])
            + ["val"] * len(temp_data["val"])
            + ["test"] * len(temp_data["test"]),
        }
    )
    # density and common norm for equal sized distributions
    if density:
        ax = sns.histplot(
            x="Temp", hue="Dataset", data=df, element="step", stat="density", common_norm=False
        )
    else:
        ax = sns.histplot(x="Temp", hue="Dataset", data=df, element="step")
    ax.legend_.set_title("")
    sns.move_legend(ax, loc=(0.5, 0.6))
    plt.xlabel("Temperature")
else:
    # Fail loudly instead of saving an empty figure.
    raise ValueError(f"Unknown plot_type {plot_type!r}; expected 'bar' or 'hist'.")

fig = plt.gcf()
src.utils.plotting.set_figure_size(fig, column_span="single", height=5)
src.utils.plotting.save_figure(fig, os.path.join(fig_folder, "temperature_distribution"))
plt.show()

## Autocorrelation of the target per dataset split (for paper)

In [None]:
# `import statsmodels` alone does not load the `tsa.stattools` submodule; import it explicitly
# instead of relying on another package having imported it first.
import statsmodels.tsa.stattools

nlags = 144
# NOTE(review): the tick labels divide the lag by 6, i.e. they assume 6 samples per hour — confirm.
samples_per_hour = 6
tick_step = 12  # one tick every 12 lags

fig = plt.figure(figsize=(8, 4))

for split in ["train", "val", "test"]:
    split_data = datamodule.get_data(["target"], main_split=split, transform=False)["target"]

    # A split may come back as a list of series; join their values into one array.
    if isinstance(split_data, list):
        split_data = np.concatenate([sd.all_values().squeeze() for sd in split_data])
    else:
        split_data = split_data.all_values().squeeze()

    split_ac = statsmodels.tsa.stattools.acf(split_data, nlags=nlags)
    plt.plot(split_ac, label=split)

plt.legend()
# Derive the tick range from `nlags` so it follows any change to the number of lags.
xticks = list(range(0, nlags + 1, tick_step))
plt.xticks(xticks, [f"{tick // samples_per_hour:.0f}" for tick in xticks])
plt.xlabel("Length of input (hours)")
plt.ylabel("Autocorrelation")
src.utils.plotting.save_figure(fig, os.path.join(fig_folder, "autocorrelation"))
plt.show()

In [None]:
# Scratch check: recompute the ACF for `split_data` as left behind by the loop in the previous
# cell, i.e. the last split processed there ("test").
split_ac = statsmodels.tsa.stattools.acf(split_data, nlags=144)

In [None]:
# Quick look at the ACF computed in the previous cell.
plt.plot(split_ac)

### Cross Correlation Matrix (aggregates over time)

In [None]:
src.utils.plotting.set_matplotlib_attributes(font_size=8)

covariate_names = src.utils.plotting.get_covariate_plot_names()

fig_path = "../figures/pilot/eda/feature_correlations"
# True: one shared figure with a panel per split; False: one separate figure per split.
plot_together = False
splits = ["train", "val", "test"]

# The colour map is the same for every split, so build it once.
cmap = sns.diverging_palette(230, 20, as_cmap=True)

if plot_together:
    fig, axs = plt.subplots(nrows=1, ncols=len(splits), figsize=(len(splits) * 6, 6), sharey=True)
    axs = axs.ravel()

for split_i, split in enumerate(splits):
    split_data = datamodule._get_split_data_raw(split)

    # A split may come back as a list of series; stack them into one frame.
    if isinstance(split_data, list):
        df = pd.concat(series.pd_dataframe() for series in split_data)
    else:
        df = split_data.pd_dataframe()

    # Use the plot names and order the columns alphabetically.
    df = df.rename(columns=covariate_names)
    df = df[sorted(df.columns)]

    if plot_together:
        ax = axs[split_i]
    else:
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    corr = df.corr()
    # The matrix is symmetric: mask the upper triangle (diagonal included) to show each pair once.
    corr_mask = np.triu(np.ones_like(corr, dtype=bool))

    # NOTE(review): the theme is applied after the axes were created and is followed by
    # set_matplotlib_attributes below; the ordering is kept as-is because the resulting style
    # depends on it.
    sns.set_theme(style="white")
    sns.heatmap(
        corr,
        annot=True,
        mask=corr_mask,
        cmap=cmap,
        center=0,
        square=True,
        linewidths=0.5,
        cbar_kws={"shrink": 0.5},
        vmax=1,
        vmin=-1,
        ax=ax,
        cbar=False,  # not plot_together or split_i == len(splits) - 1,
        fmt=".2f",
        annot_kws={"fontsize": 8},
    )
    src.utils.plotting.set_matplotlib_attributes(font_size=8)
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title(split.capitalize(), fontsize=8)

    is_last_split = split_i == len(splits) - 1
    if not plot_together or is_last_split:
        # The combined figure holds every split, so it must not be filed under the name of
        # whichever split happens to be last.
        fig_suffix = "_all" if plot_together else f"_{split}"
        src.utils.plotting.set_figure_size(fig, column_span=12, height="same")
        src.utils.plotting.save_figure(fig, fig_path + fig_suffix)
        plt.show()

## Correlation of each variable with the target, per dataset split

In [None]:
src.utils.plotting.set_matplotlib_attributes(font_size=8)
covariate_names = src.utils.plotting.get_covariate_plot_names()
target_variable = "Nitrate out"  # plot name of the target, i.e. the column name after renaming
landscape = True  # True: splits as rows and variables as columns

fig_path = "../figures/pilot/eda/feature_correlations_target"
splits = ["train", "val", "test"]

# One Series per split: correlation of every other variable with the target.
corr_matrices = {}

for split in splits:
    split_data = datamodule._get_split_data_raw(split)

    # A split may come back as a list of series; stack them into one frame.
    if isinstance(split_data, list):
        df = pd.concat(series.pd_dataframe() for series in split_data)
    else:
        df = split_data.pd_dataframe()

    df = df.rename(columns=covariate_names)

    # Drop the target's self-correlation by label rather than by position, and order the rows
    # with sort_index() so each value keeps its own label. Assigning a sorted index to the
    # Series would only relabel the values and attribute them to the wrong variables.
    target_corr = df.corr()[target_variable].drop(labels=target_variable)
    corr_matrices[split] = target_corr.sort_index()

corr = pd.DataFrame(corr_matrices)
if landscape:
    corr = corr.transpose()
    fig_path = fig_path + "_landscape"

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
cmap = sns.diverging_palette(230, 20, as_cmap=True)
sns.heatmap(
    corr,
    annot=True,
    cmap=cmap,
    center=0,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.5},
    vmax=1,
    vmin=-1,
    ax=ax,
    cbar=False,
    fmt=".2f",
    annot_kws={"fontsize": 8},
)
src.utils.plotting.set_matplotlib_attributes(font_size=8)
ax.tick_params(axis="x", labelrotation=45)
ax.tick_params(axis="y", labelrotation=0)
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_title(f"{target_variable} correlations across datasets", fontsize=8)

src.utils.plotting.set_figure_size(fig, column_span="double", height=6)
src.utils.plotting.save_figure(fig, fig_path)
plt.show()

### Cross Correlation in time

In [None]:
# Settings for the pairwise cross-correlation plots.
series_to_plot = {"train": "all"}
max_lag = 16
highlight_lag = 12


def cross_correlation_func(ts1, ts2):
    """Plot the cross-correlation function of two series via `darts.utils.statistics.plot_ccf`.

    The notebook-level `max_lag` and `highlight_lag` are read at call time; the latter is
    passed to `plot_ccf` as `m`. Returns whatever `plot_ccf` returns.
    """
    ccf_options = dict(max_lag=max_lag, m=highlight_lag, alpha=0.05, fig_size=(10, 5))
    return darts.utils.statistics.plot_ccf(ts1, ts2, **ccf_options)

In [None]:
# Apply the cross-correlation plot to pairs of univariate series selected by `series_to_plot`;
# the return value is discarded.
_ = datamodule.call_function_on_pairs_of_univariate_series(
    cross_correlation_func, series_to_plot, presenter="show"
)

### Cross Correlation in time

In [None]:
# NOTE(review): this cell repeats the settings and the `cross_correlation_func` definition from
# the preceding "Cross Correlation in time" section with identical values; the later definition
# silently replaces the earlier one. Consider removing one of the two copies.
series_to_plot = {"train": "all"}
max_lag = 16
highlight_lag = 12


def cross_correlation_func(ts1, ts2):
    """Plot the cross-correlation function of two series via `darts.utils.statistics.plot_ccf`.

    Reads the notebook-level `max_lag` and `highlight_lag` at call time and returns whatever
    `plot_ccf` returns.
    """
    return darts.utils.statistics.plot_ccf(
        ts1, ts2, max_lag=max_lag, m=highlight_lag, alpha=0.05, fig_size=(10, 5)
    )  # cross-correlation function

In [None]:
# NOTE(review): same call as in the preceding "Cross Correlation in time" section, so the same
# plots are produced a second time on a full run.
_ = datamodule.call_function_on_pairs_of_univariate_series(
    cross_correlation_func, series_to_plot, presenter="show"
)