In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import glob

import pyrootutils

# Resolve the project root, starting the search at the current working
# directory and using the ".project-root" file as the marker.
# NOTE(review): pythonpath=True / dotenv=True presumably add the root to
# sys.path (so `src` is importable) and load a .env file — confirm against
# the pyrootutils documentation.
root = pyrootutils.setup_root(
    search_from=os.getcwd(),
    indicator=".project-root",
    pythonpath=True,
    dotenv=True,
)

In [None]:
%matplotlib inline

import hydra
import omegaconf

import src.eval
import src.utils
import src.utils.plotting

import pandas as pd

import seaborn as sns

import numpy as np
from sklearn.metrics import mean_squared_error as skmse
from darts.metrics import mse
from darts import concatenate

from src.datamodules.veas_pilot_datamodule import get_input_dataframe

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.patches as patches
# RGB tuples used as line colours. Plots below either index this list by model
# position (colors[idx]) or pick a fixed entry for an overlaid input signal.
colors = [(0,0.4,1),(1,0.7,0.3),(0.2,0.7,0.2),(0.8,0,0.2),(0.5,0.3,.9),(0.9,0.7,.5),(.3,0.7,1)]

# Configuration

In [None]:
# Run directories of the trained models to evaluate; one entry per model.
model_dir = [
    "logs/train/runs/2601to2802/2024-09-02_11-00-02",
]
# Pass every entry through the project helper.
# NOTE(review): presumably this resolves paths relative to the project root
# into absolute paths — confirm in src.utils.hydra.
model_dir = [src.utils.hydra.get_absolute_project_path(md) for md in model_dir]

In [None]:
# Path to the evaluation config.
# NB: must be a relative path, interpreted relative to <project_root>/src/utils.
config_path = os.path.join(
    "..", "..", "configs", "eval.yaml"
)

# Overrides in the same dot notation as command-line overrides. Useful for
# changing whole modules, e.g. which datamodule file is loaded.
config_overrides_dot = [
    "++extras.disable_pytorch_lightning_output=True",
    "++eval.kwargs.show_warnings=False",
    #"datamodule=",
]

In [None]:
# Build one config per model directory; `cfg` is index-aligned with `model_dir`.
cfg = []
for md in model_dir:
    # Dictionary with overrides. Useful for larger changes/additions/deletions
    # that do not exist as entire files. Here it only sets the model directory.
    config_overrides_dict = dict(
        model_dir=md
    )

    cfg.append(src.utils.initialize_hydra(
        config_path,
        config_overrides_dot,
        config_overrides_dict,
        return_hydra_config=True,
        print_config=False,  # switch to True to print the config and inspect the settings
    ))

In [None]:
# Restore the saved objects of every run. The five lists below are
# index-aligned with `cfg`; trainer and logger may be None when the saved run
# does not provide them (they are read with dict.get).
model_name, model, datamodule, trainer, logger = [], [], [], [], []
for idx, c in enumerate(cfg):
    object_dict = src.utils.instantiate.instantiate_saved_objects(c)
    model.append(object_dict["model"])
    datamodule.append(object_dict["datamodule"])
    trainer.append(object_dict.get("trainer"))
    logger.append(object_dict.get("logger"))
    # The model's class name doubles as its label in printouts and legends.
    model_name.append(str(object_dict["model"].__class__.__name__))

In [None]:
# Adjust the evaluation settings of every config in place.
for c in cfg:
    # NOTE(review): open_dict presumably lifts OmegaConf struct mode so that
    # keys missing from the loaded config can be added — confirm.
    with omegaconf.open_dict(c):
        c.eval.kwargs.forecast_horizon = 1
        c.eval.kwargs.stride = 1
        c.eval.plot.every_n_prediction = 1
        # Ask the evaluation to return the prediction data; later cells read
        # it from eval_object_dict[...]['predictions_data'].
        c.eval.predictions = {"return": {"data": True}}
        c.eval.plot.presenter = [
            "show",
            "savefig",
        ]  # set presenter to "show" to show figures in output, and "savefig" to save them to the model_dir

# Evaluate
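The `src.eval.run` function returns the configured metrics over the evaluated split, together with a dictionary of evaluation objects (the predictions and the data they were computed from).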

In [None]:
# Evaluate every model; eval_object_dict is index-aligned with `cfg`.
eval_object_dict = [None] * len(cfg)
for idx, c in enumerate(cfg):
    # NOTE: metric_dict is overwritten on every iteration, so after the loop it
    # holds only the metrics of the last model.
    metric_dict, eval_object_dict[idx] = src.eval.run(c, datamodule[idx], model[idx], trainer[idx], logger[idx])

In [None]:
# Overview: measured target (black) against every model's predictions.
measured_series = eval_object_dict[0]['predictions_data']['series']

fig, ax = plt.subplots(figsize=(12, 4.5))
ax.plot(
    measured_series.time_index,
    measured_series.all_values().squeeze(),
    color='k',
    label='Measured nitrate',
)
for idx, obj in enumerate(eval_object_dict):
    predicted_series = obj['predictions']
    ax.plot(
        predicted_series.time_index,
        predicted_series.all_values().squeeze(),
        color=colors[idx],
        label=model_name[idx],
    )
# Legend sits below the axes so it never hides the curves.
ax.legend(fontsize=12, loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=3)
plt.show()

In [None]:
# Build a constant baseline: predict the mean of the training targets everywhere.
# Measured target values of the evaluated split, as a squeezed numpy array.
val_targets = eval_object_dict[0]['predictions_data']['series'].all_values().squeeze()
# Untransformed training targets of the first datamodule (a list of series,
# since it is passed to darts.concatenate below).
time_series_list = datamodule[0].get_data(['target'], main_split='train', transform=False)['target']
# Join the training series along time; ignore_time_axis=True is passed so the
# series need not be contiguous in time.
concatenated_series = concatenate(time_series_list, axis='time', ignore_time_axis=True)
train_mean = concatenated_series.all_values().squeeze().mean()
# Same shape as val_targets, filled with the training mean.
val_train_mean = train_mean*np.ones_like(val_targets)

In [None]:
# Compare each model's MSE against a constant baseline that always predicts the
# training-set mean of the target.
measured_series = eval_object_dict[0]['predictions_data']['series']

# Baseline error: computed once and reused for both the printout and the chart.
# NOTE: the baseline is scored on all of val_targets, whereas each model below
# is scored only on the time span it shares with the measurements.
train_error = skmse(val_targets, val_train_mean)
errors = [train_error]
labels = ['Train set mean']
print('MSEs:')
print('Train set mean: {:.2f}'.format(train_error))

for idx, obj in enumerate(eval_object_dict):
    predicted_series = obj['predictions']
    # Restrict targets and predictions to their common time span before scoring.
    val_common = measured_series.slice_intersect(predicted_series)
    common_start, common_end = val_common.start_time(), val_common.end_time()
    val_targets_aligned = measured_series.slice(common_start, common_end)
    val_predictions_aligned = predicted_series.slice(common_start, common_end)
    model_error = mse(val_targets_aligned, val_predictions_aligned)
    print('{}: {:.2f}'.format(model_name[idx], model_error))
    errors.append(model_error)
    labels.append(model_name[idx])

plt.figure(figsize=(5, 3))
plt.bar(labels, errors, color='skyblue')
plt.ylabel('MSE')
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

Plotting for the paper:

In [None]:
def _add_dashed_box(ax, t_start, t_end, y_bottom, y_top):
    """Draw an unfilled, red, dashed rectangle on ``ax``.

    The box spans ``t_start``..``t_end`` on the date axis and
    ``y_bottom``..``y_top`` on the value axis.
    """
    x_left = mdates.date2num(pd.Timestamp(t_start))
    x_right = mdates.date2num(pd.Timestamp(t_end))
    ax.add_patch(
        patches.Rectangle(
            (x_left, y_bottom),  # bottom-left corner
            x_right - x_left,  # width in matplotlib date units
            y_top - y_bottom,  # height
            linewidth=1,
            edgecolor='r',
            facecolor='none',
            linestyle='--',
        )
    )


def plot_interval(start_time, end_time, remove=None):
    """Plot measured nitrate against the model nowcasts and save the paper figure.

    Two fixed time windows (hard-coded dates in February 2024) are highlighted
    with red dashed rectangles, regardless of the requested interval. The figure
    is saved to "../figures/pilot/analyzing_tcn" and then shown.

    NOTE: a later cell in this notebook defines another function with the same
    name; after that cell has run, this version is shadowed.

    Args:
        start_time: Start of the plotted window; anything ``pd.Timestamp`` accepts.
        end_time: End of the plotted window; anything ``pd.Timestamp`` accepts.
        remove: Optional collection of model names to leave out of the plot.
            Defaults to None (plot every model).
    """
    # None instead of a mutable default list, which would be shared across calls.
    excluded = () if remove is None else remove
    window = (pd.Timestamp(start_time), pd.Timestamp(end_time))

    src.utils.plotting.set_matplotlib_attributes()
    sliced_measured = eval_object_dict[0]['predictions_data']['series'].slice(*window)

    fig, ax = plt.subplots()
    ax.plot(sliced_measured.time_index, sliced_measured.all_values().squeeze(), label='Measured nitrate', linewidth=1)
    # Pair names with evaluation results directly. A dict keyed by model name
    # would merge two runs whose models share a class name.
    for name, obj in zip(model_name, eval_object_dict):
        if name in excluded:
            continue
        sliced_prediction = obj['predictions'].slice(*window)
        ax.plot(sliced_prediction.time_index, sliced_prediction.all_values().squeeze(), label='Nowcast w/ TCN', linewidth=1)

    # Highlight the two periods discussed in the paper.
    _add_dashed_box(ax, '2024-02-04 02:00:00', '2024-02-04 14:00:00', 2.5, 11.5)
    _add_dashed_box(ax, '2024-02-25 12:00:00', '2024-02-26 03:00:00', 1.0, 9.0)

    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    fig.autofmt_xdate()
    ax.legend()
    ax.set_ylabel("Nitrate concentration [mg/l]")
    src.utils.plotting.set_figure_size(fig, column_span="double", height=6)
    src.utils.plotting.save_figure(fig, "../figures/pilot/analyzing_tcn")
    plt.show()

In [None]:
# Window for the paper figure. start_time / end_time stay bound as globals and
# are reused by the input-signal plots in the following cells.
start_time = '2024-01-31 00:00:00'
end_time = '2024-02-29 00:00:00'
plot_interval(start_time, end_time)

In [None]:
# Load the raw pilot data as separate input and output frames.
input_df, output_df = get_input_dataframe(source='data/denit_pilot_b_240527.csv', hour_average=False, include_controlled=True)

# Keep only the rows flagged as operational. The mask is taken once, before
# either frame is filtered. output_df is copied explicitly because a column is
# assigned on it below; assigning on a filtered frame triggers pandas'
# SettingWithCopyWarning.
operational_mask = input_df['operational']
output_df = output_df[operational_mask].copy()
input_df = input_df[operational_mask].drop(columns=['operational'])

# Index both frames by time.
input_df.index = input_df.time
input_df = input_df.drop(columns='time')
# The output timestamps are parsed day-first.
output_df['time'] = pd.to_datetime(output_df['time'], dayfirst=True)
output_df.index = output_df.time
output_df = output_df.drop(columns='time')

# Inputs and outputs side by side, aligned on the time index.
all_df = pd.concat([input_df, output_df], axis=1)

In [None]:
# Outgoing vs. incoming nitrate over the selected window.
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(input_df['nitrate_in'].loc[start_time:end_time], color=colors[3], label='nitrate_in')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Ratio of outgoing to incoming nitrate over the selected window.
nitrate_ratio = all_df['nitrate_out'] / input_df['nitrate_in']
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(nitrate_ratio.loc[start_time:end_time], color=colors[0], label='nitrate_out/nitrate_in')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with methanol overlaid (methanol multiplied by 0.2 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(.2 * input_df['methanol'].loc[start_time:end_time], color=colors[3], label='methanol')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Incoming nitrate with methanol overlaid (methanol multiplied by 1.3 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(input_df['nitrate_in'].loc[start_time:end_time], color=colors[1], label='nitrate_in')
ax.plot(1.3 * input_df['methanol'].loc[start_time:end_time], color=colors[3], label='methanol')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Methanol relative to incoming nitrate over the selected window.
methanol_ratio = input_df['methanol'] / input_df['nitrate_in']
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(methanol_ratio.loc[start_time:end_time], color=colors[0], label='methanol/nitrate_in')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with oxygen overlaid (unscaled).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(input_df['oxygen'].loc[start_time:end_time], color=colors[3], label='oxygen')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with ammonium overlaid (unscaled).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(input_df['ammonium'].loc[start_time:end_time], color=colors[3], label='ammonium')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with orto-p overlaid (orto-p multiplied by 10 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(10 * input_df['orto-p'].loc[start_time:end_time], color=colors[3], label='orto-p')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with temperature overlaid (temperature shifted down by 10 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(input_df['temp'].loc[start_time:end_time] - 10, color=colors[3], label='temp')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with turbidity overlaid (turbidity multiplied by 0.2 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(.2 * input_df['turb'].loc[start_time:end_time], color=colors[3], label='turb')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with tunnelwater overlaid (tunnelwater multiplied by 0.001 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(.001 * input_df['tunnelwater'].loc[start_time:end_time], color=colors[3], label='tunnelwater')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with filterpressure_1 overlaid (multiplied by 30 and shifted by -15 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(30 * input_df['filterpressure_1'].loc[start_time:end_time] - 15, color=colors[3], label='filterpressure_1')
ax.legend(fontsize=12)
plt.show()

In [None]:
# Outgoing nitrate with filterpressure_8 overlaid (multiplied by 30 for display).
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='nitrate_out')
ax.plot(30 * input_df['filterpressure_8'].loc[start_time:end_time], color=colors[3], label='filterpressure_8')
ax.legend(fontsize=12)
plt.show()

## Analyzing the input data in the interesting periods

In [None]:
def plot_interval(start_time, end_time, remove=None):
    """Plot measured nitrate (black) against each model's predictions in a window.

    NOTE: this redefines ``plot_interval``. An earlier cell in this notebook
    defines a function of the same name (the paper-figure variant), which is
    shadowed once this cell has run.

    Args:
        start_time: Start of the plotted window; anything ``pd.Timestamp`` accepts.
        end_time: End of the plotted window; anything ``pd.Timestamp`` accepts.
        remove: Optional collection of model names to leave out of the plot.
            Defaults to None (plot every model).
    """
    # None instead of a mutable default list, which would be shared across calls.
    excluded = () if remove is None else remove
    window = (pd.Timestamp(start_time), pd.Timestamp(end_time))

    sliced_measured = eval_object_dict[0]['predictions_data']['series'].slice(*window)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(sliced_measured.time_index, sliced_measured.all_values().squeeze(), color='k', label='Measured nitrate')
    # Pair names with evaluation results directly. A dict keyed by model name
    # would merge two runs whose models share a class name.
    for idx, (name, obj) in enumerate(zip(model_name, eval_object_dict)):
        if name in excluded:
            continue
        sliced_prediction = obj['predictions'].slice(*window)
        # Cycle through the palette so more models than colours does not raise.
        ax.plot(
            sliced_prediction.time_index,
            sliced_prediction.all_values().squeeze(),
            color=colors[idx % len(colors)],
            label=name,
        )
    # Legend sits below the axes so it never hides the curves.
    ax.legend(fontsize=12, loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=3)
    plt.show()

In [None]:
# Plot measurements and predictions over the window below, using the
# plot_interval variant defined in the preceding cell.
start_time = '2024-01-26 00:00:00'
end_time = '2024-02-29 00:00:00'
plot_interval(start_time, end_time)

In [None]:
# Window for the methanol figure drawn in the next cell.
start_time = '2024-02-20 00:00:00'
end_time = '2024-02-29 00:00:00'

In [None]:
# Figure: measured nitrate together with the (scaled) methanol dosage.
src.utils.plotting.set_matplotlib_attributes()
plt.rcParams['legend.edgecolor'] = 'black'
plt.rcParams['legend.facecolor'] = 'white'

# Measurements and per-model predictions restricted to the window. They are
# not drawn in this figure but are kept available as globals.
window_start, window_end = pd.Timestamp(start_time), pd.Timestamp(end_time)
sliced_measured = eval_object_dict[0]['predictions_data']['series'].slice(window_start, window_end)
sliced_dict = {}
for idx, obj in enumerate(eval_object_dict):
    sliced_dict[model_name[idx]] = obj['predictions'].slice(pd.Timestamp(start_time), pd.Timestamp(end_time))

# Figure size in inches (width is 17/34 of 5.9488... in).
fig_width = 17/34*5.9488188976377945
fig_height = 2.3622047244094486
fig, ax = plt.subplots(figsize=(fig_width, fig_height))

ax.plot(all_df['nitrate_out'].loc[start_time:end_time], label='Measured nitrate', linewidth=1)
ax.plot(0.2 * input_df['methanol'].loc[start_time:end_time], color=colors[2], label='Added methanol', linewidth=1)

# One tick every second day, labelled month-day only, rotated for readability.
ax.xaxis.set_major_locator(mdates.DayLocator(interval=2))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
fig.autofmt_xdate(rotation=45)

# The two curves use different scalings, so the y-axis carries no ticks or label.
ax.set_yticks([])
ax.set_ylabel('')

# Opaque, square-cornered legend in the upper-left corner with a thin frame.
legend = ax.legend(
    loc='upper left',
    fancybox=False,
    framealpha=1,
    shadow=False,
    borderpad=.4,
)
legend.get_frame().set_linewidth(0.5)

src.utils.plotting.save_figure(fig, "../figures/pilot/methanol")
plt.show()



In [None]:
# Window for the turbidity / pressure figure drawn in the next cell.
start_time = '2024-02-02 00:00:00'
end_time = '2024-02-11 00:00:00'

In [None]:
# Figure: measured nitrate together with (scaled) turbidity and filter pressure.
src.utils.plotting.set_matplotlib_attributes()
plt.rcParams['legend.edgecolor'] = 'black'
plt.rcParams['legend.facecolor'] = 'white'
# No legend styling here: the legend of this figure is created and styled
# further down. Touching `legend` at this point would rely on a variable left
# over from an earlier cell and would fail with NameError otherwise.

# Measurements and per-model predictions restricted to the window. They are
# not drawn in this figure but are kept available as globals.
sliced_measured = eval_object_dict[0]['predictions_data']['series'].slice(pd.Timestamp(start_time), pd.Timestamp(end_time))

sliced_dict = {}

for idx, obj in enumerate(eval_object_dict):
    sliced_dict[model_name[idx]] = obj['predictions'].slice(pd.Timestamp(start_time), pd.Timestamp(end_time))

fig, ax = plt.subplots(figsize=(17/34*5.9488188976377945, 2.3622047244094486))

# Turbidity and pressure are rescaled only so the curves are visible on one axis.
ax.plot(all_df['nitrate_out'][start_time:end_time], label='Measured nitrate', linewidth=1)
ax.plot(1+.2*input_df['turb'][start_time:end_time], color=colors[3], label='Measured turbidity', linewidth=1)
ax.plot(30*input_df['filterpressure_8'][start_time:end_time], color=colors[1], label='Pressure above reactor', linewidth=1)

# Custom x-axis formatting to increase spacing between labels and remove the year
ax.xaxis.set_major_locator(mdates.DayLocator(interval=2))  # Adjust interval as necessary
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))  # Month and day only
fig.autofmt_xdate(rotation=45)  # Rotate labels for better readability

# Remove the y-axis label and ticks
ax.set_yticks([])
ax.set_ylabel('')

# Place the legend in the upper-right corner with a white background and black border
legend = ax.legend(
    loc='upper right',
    fancybox=False,
    framealpha=1,
    shadow=False,
    borderpad=.4
)
legend.get_frame().set_linewidth(0.5)

# Save the figure
src.utils.plotting.save_figure(fig, "../figures/pilot/turb_pressure")
plt.show()