In [None]:
# Imports
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from enum import Enum
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.backends.cudnn as cudnn
from sklearn.preprocessing import StandardScaler, MinMaxScaler, QuantileTransformer
from sklearn.model_selection import ParameterGrid
import mlflow
import time
import gc
import multiprocessing as mp
import optuna
from optuna.trial import TrialState

%load_ext autoreload
%autoreload 2

from forecaster.scalers import *
from forecaster.preprocessing import TimeseriesDataSet, Granularity
from forecaster.models import *
from forecaster.training import EarlyStopper, ModelTrainer, TimeseriesForecaster
from forecaster.evaluation import evaluate_series, top_n_by_metric, get_full_results_dict

In [None]:
# Environment setup: log directories, GPU cache, device selection, seeding and
# multiprocessing start method.
SEED = 42  # fixed seed so weight init / loader shuffling repeat across Restart & Run All

# Also creates ../logs, which later cells write figures and artifacts into.
os.makedirs("../logs/mlruns/.trash", exist_ok=True)
torch.cuda.empty_cache()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# cudnn autotuning picks the fastest kernels; note it makes GPU runs not bit-for-bit repeatable.
cudnn.benchmark = True
# force=True lets this cell be re-run without "context has already been set" errors.
mp.set_start_method("spawn", force=True)

In [None]:
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
PATH = "../data/train_data.csv"
GRANULARITY = Granularity.HOURLY
N_SERIES = 1  # how many parallel value series (columns after the timestamp) to keep

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
USE_MLFLOW = True
SAVE_MODEL = False

# ---------------------------------------------------------------------------
# Fixed (non-tuned) training parameters
# ---------------------------------------------------------------------------
BATCH_SIZE = 512
N_EPOCHS = 100
OUTPUT_SIZE = 24  # forecast length produced by the model (passed as output_dim)
TRAIN_VALIDATION_SPLIT = 0.7
USE_TIME_COVARIATES = True
LR = 0.001

# Defaults for values the Optuna search also samples per trial
# (scaler, dropout and input window length).
SCALER = StandardScaler
DROPOUT = 0.1
INPUT_SIZE = 48

In [None]:
# Data loading: read the raw CSV and print a column / dtype overview.
pandas_df = pd.read_csv(filepath_or_buffer=PATH)
pandas_df.info()

In [None]:
pandas_df = pandas_df.iloc[:, : 1 + N_SERIES]  # timestamp column (assumed first) + first N_SERIES value columns

In [None]:
# Quick look at the raw data: one line per value series, capped so the figure stays readable.
# Iterate over the real column names instead of assuming they are called value_1, value_2, ...
MAX_PLOTTED_SERIES = 10
value_columns = [c for c in pandas_df.columns if c != "deviceTimestamp"][:MAX_PLOTTED_SERIES]

fig = go.Figure()
for col in value_columns:
    fig.add_trace(go.Scatter(
        x=pandas_df["deviceTimestamp"],
        y=pandas_df[col],
        mode="lines",
        name=col
    ))
fig.update_layout(
    title="All Series over Time",
    width=1200,
    height=400,
    xaxis_title="deviceTimestamp",
    yaxis_title="Value"
)
fig.show()

In [None]:
# Lookup tables translating the string choices Optuna suggests into concrete
# model classes, scaler classes and (shared, pre-instantiated) loss modules.
MODEL_MAP = {
    "LSTMAttention": LSTMAttention,
    "LSTM": LSTM,
    "GRUAttention": GRUAttention,
    "GRU": GRU,
}
SCALER_MAP = dict(
    StandardScaler=StandardScaler,
    MinMax=MinMaxScaler,
    LogStandardScaler=LogStandardScaler,
)
LOSS_MAP = dict(
    MSELoss=nn.MSELoss(),
    L1Loss=nn.L1Loss(),
    HuberLoss=nn.HuberLoss(),
)

def objective(trial):
    """Optuna objective: train one candidate model and return its mean validation MAE.

    Hyper-parameters sampled from ``trial``: scaler, loss function, hidden size,
    number of recurrent layers, input window length and dropout. The model class is
    pinned to GRU. Batch size, epoch budget, learning rate and forecast horizon come
    from the config cell.

    Args:
        trial: Optuna trial supplied by ``study.optimize``.

    Returns:
        Mean of the per-series ``val_mae`` values reported by ``evaluate_series``
        (the study minimizes this).
    """
    # Model type is pinned; use
    # trial.suggest_categorical("model_type", list(MODEL_MAP.keys())) to tune it.
    model_name = "GRU"
    scaler_name = trial.suggest_categorical("scaler", list(SCALER_MAP.keys()))
    loss_name = trial.suggest_categorical("loss_fn", list(LOSS_MAP.keys()))
    hidden_dim = trial.suggest_categorical("hidden_dim", [16, 32, 64])
    num_layers = trial.suggest_int("num_layers", 2, 8, step=2)
    input_size = trial.suggest_categorical("input_size", [24, 48, 72])
    dropout = trial.suggest_float("dropout", 0.1, 0.5, step=0.1)

    def log_epoch(epoch, metrics):
        """Print one progress line per epoch (passed to ModelTrainer.fit as on_epoch_end)."""
        # NOTE(review): these values are not sent to trial.report(). They are presumably
        # in scaled units, and the scaler differs between trials, so comparing them
        # across trials (as MedianPruner does) would not be meaningful.
        print(
            f"Epoch {epoch}: Train RMSE {metrics['train_rmse']:.4f} | Train MAE {metrics['train_mae']:.4f} | "
            f"Val RMSE {metrics['val_rmse']:.4f} | Val MAE {metrics['val_mae']:.4f} | "
            f"LR {metrics['learning_rate']:.6f} | Epoch time {metrics['epoch_time']:.2f}s"
        )

    # Pre-bind everything released in `finally` so cleanup also runs when a step fails.
    ds = model = opt = scheduler = stopper = tr = forecaster = None
    predictions_dict = unscaled_summary = None
    try:
        ds = TimeseriesDataSet(
            pandas_df, GRANULARITY, USE_TIME_COVARIATES,
            "deviceTimestamp", TRAIN_VALIDATION_SPLIT,
            SCALER_MAP[scaler_name], input_size, OUTPUT_SIZE
        )
        n_cov = ds.get_time_feature_count()

        model = MODEL_MAP[model_name](
            input_dim=1 + n_cov,  # one value channel plus the time-covariate features
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            output_dim=OUTPUT_SIZE,
            dropout=dropout
        ).to(DEVICE)

        opt = optim.Adam(model.parameters(), lr=LR)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=N_EPOCHS)
        stopper = EarlyStopper(patience=20, min_delta=1e-4)

        tr = ModelTrainer(
            model, DEVICE, opt, LOSS_MAP[loss_name], scheduler,
            ds.get_train_dataloader(BATCH_SIZE, True),
            ds.get_validation_dataloader(BATCH_SIZE),
            early_stopper=stopper
        )
        tr.fit(N_EPOCHS, on_epoch_end=log_epoch)

        forecaster = TimeseriesForecaster(model, DEVICE, input_size, OUTPUT_SIZE)
        start_time = time.time()
        predictions_dict = forecaster.predict_all_series(ds, BATCH_SIZE)
        print(f"Prediction Time: {time.time() - start_time:.2f} seconds")

        start_time = time.time()
        # Evaluate with the same window length this trial trained and predicted with,
        # not the global INPUT_SIZE default.
        unscaled_summary = evaluate_series(
            dataset=ds,
            preds=predictions_dict,
            input_size=input_size,
            output_size=OUTPUT_SIZE
        )
        print(f"Evaluation Time: {time.time() - start_time:.2f} seconds")

        return np.mean([m['val_mae'] for m in unscaled_summary.values()])
    finally:
        # Drop references to the large objects and release cached GPU memory so the
        # next trial starts clean, even if this one raised.
        del ds, model, opt, scheduler, stopper, tr, forecaster, predictions_dict, unscaled_summary
        gc.collect()
        torch.cuda.empty_cache()

# Minimize mean validation MAE over the search space defined in `objective`.
# NOTE(review): MedianPruner only acts on values passed to trial.report(), and
# `objective` reports none, so no trial is actually pruned.
study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(),
)

print("Starting Optimization...")
study.optimize(objective, n_trials=20)

print("Best MAE:", study.best_value)
print("Best Params:", study.best_params)


In [None]:

# Retrain a fresh model from scratch using the best hyper-parameters found by the study.
best_params = study.best_params

print("\n" + "="*30)
print(f"Retraining best model with params: {best_params}")
print("="*30)

# The search pins the model to GRU; pick up "model_type" if it is ever tuned again.
ModelClass = MODEL_MAP[best_params.get("model_type", "GRU")]
scaler_class = SCALER_MAP[best_params["scaler"]]
loss_fn = LOSS_MAP[best_params["loss_fn"]]

dataset = TimeseriesDataSet(
    pandas_df, GRANULARITY, USE_TIME_COVARIATES,
    "deviceTimestamp", TRAIN_VALIDATION_SPLIT,
    scaler_class, best_params["input_size"], OUTPUT_SIZE
)
N_COV = dataset.get_time_feature_count()

model = ModelClass(
    input_dim=1 + N_COV,
    hidden_dim=best_params["hidden_dim"],
    num_layers=best_params["num_layers"],
    output_dim=OUTPUT_SIZE,
    # Dropout is tuned by the study; the config DROPOUT is only a fallback.
    dropout=best_params.get("dropout", DROPOUT),
).to(DEVICE)

opt = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=N_EPOCHS)
# NOTE(review): the tuning trials used patience=20. This shorter patience may stop
# earlier than the best trial did; confirm that is intended.
stopper = EarlyStopper(patience=10, min_delta=1e-4)

trainer = ModelTrainer(
    model, DEVICE, opt, loss_fn, scheduler,
    dataset.get_train_dataloader(BATCH_SIZE, True),
    dataset.get_validation_dataloader(BATCH_SIZE),
    early_stopper=stopper
)

history = trainer.fit(N_EPOCHS)


In [None]:
# Persist a human-readable architecture dump and the trained weights.
architecture_txt = str(model)
with open("../logs/model_architecture.txt", "w") as arch_file:
    arch_file.write(architecture_txt)
torch.save(obj=model.state_dict(), f="../logs/model.pt")

In [None]:
# Attention heatmap data over the validation set for the first series:
# M has rows = rolling validation windows and cols = validation timestamps, with
# cells holding attention weights. Only attention models (GRUAttention / LSTMAttention)
# expose weights; for other model types this cell does nothing beyond model.eval().
model.eval()
key = pandas_df.columns[1]

if isinstance(model, (GRUAttention, LSTMAttention)):
    with torch.no_grad():
        val_dataloader = dataset.get_single_series_dataloader(key, "validation", BATCH_SIZE, shuffle=False)
        seq_len = dataset.input_size
        A_list = []
        for X_batch, _ in val_dataloader:
            _, attn_weights = model(X_batch.to(DEVICE), return_attention=True)
            A_list.append(attn_weights.cpu().numpy())
        # assumes attn_weights is (batch, seq_len), giving A of shape (n_windows, seq_len) — TODO confirm
        # np.vstack([]) raises, so fall back to an empty matrix when there are no windows.
        A = np.vstack(A_list) if A_list else np.empty((0, seq_len))

        # Timestamps of the validation range.
        val_timestamps = dataset.get_resampled_data()[dataset.timestamp_column].iloc[dataset.n_train:].reset_index(drop=True)

        # Window w covers columns w .. w+seq_len-1, clipped to the end of the validation range.
        n_windows = A.shape[0]
        M = np.full((n_windows, len(val_timestamps)), np.nan)
        for w in range(n_windows):
            end_col = min(w + seq_len, len(val_timestamps))
            M[w, w:end_col] = A[w, : end_col - w]

    # Limit the view to the most recent stretch of the validation range.
    latest = val_timestamps.max()
    if GRANULARITY in (Granularity.HOURLY, Granularity.DAILY):
        recent_mask = val_timestamps >= (latest - pd.Timedelta(days=7))
    elif GRANULARITY == Granularity.MONTHLY:
        recent_mask = val_timestamps >= (latest - pd.DateOffset(months=1))
    else:
        recent_mask = np.ones(len(val_timestamps), dtype=bool)
    # np.asarray accepts both the boolean Series and the plain ndarray fallback
    # (an ndarray has no `.values` attribute).
    last_week_mask = np.asarray(recent_mask, dtype=bool)

    val_timestamps = val_timestamps[last_week_mask].reset_index(drop=True)
    M = M[:, last_week_mask]
    # Drop windows that have no attention weight inside the displayed range.
    M = M[~np.all(np.isnan(M), axis=1), :]

    # Scaled values of the series over the same range, for the overlay line.
    series_values = dataset.get_scaled_data(key)["validation"]
    series_values = series_values[last_week_mask].flatten()

In [None]:
# Render the attention matrix as a heatmap with the series' scaled values overlaid
# on a secondary y-axis; also saved to ../logs/attention_heatmap.html.
if isinstance(model, (GRUAttention, LSTMAttention)):
    fig_heat = go.Figure()

    # Heatmap
    fig_heat.add_trace(go.Heatmap(
        z=M,
        colorscale='Viridis',
        colorbar=dict(title='Attention'),
        x=val_timestamps,
        y=list(range(M.shape[0])),
        name='attention'
    ))

    # Series line on a secondary y-axis so it uses its own value scale
    fig_heat.add_trace(go.Scatter(
        x=val_timestamps,
        y=series_values,
        mode='lines',
        name=f'{key} values',
        line=dict(color='black', width=2),
        yaxis='y2'
    ))

    # Layout: add yaxis2 that overlays the heatmap y-axis.
    # The title names the series actually plotted (`key`), not a fixed column.
    fig_heat.update_layout(
        title=f'Attention heatmap for {key} (timestamps on x-axis, rows=rolling windows)',
        xaxis=dict(title='Timestamp (validation range)'),
        yaxis=dict(title='Validation window index (rolling)'),
        yaxis2=dict(
            title=f'Scaled values for {key}',
            overlaying='y',
            side='right'
        ),
        width=1200,
        height=600,
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1)
    )

    fig_heat.show()
    fig_heat.write_html("../logs/attention_heatmap.html", include_plotlyjs='cdn')

In [None]:
# Train vs. validation RMSE per epoch, shown side by side on a linear and a log y-axis.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = (
    ('RMSE over Epochs (linear)', 'RMSE', False),
    ('RMSE over Epochs (log scale)', 'RMSE (log)', True),
)
for ax, (panel_title, panel_ylabel, use_log) in zip(axes, panels):
    ax.plot(history["train_rmse"], marker='o', label='Train RMSE')
    ax.plot(history["val_rmse"], marker='x', label='Val RMSE')
    if use_log:
        ax.set_yscale('log')
    ax.set_title(panel_title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(panel_ylabel)
    ax.legend()
    ax.grid(True, which='both', alpha=0.3)

fig.tight_layout()
fig.savefig("../logs/rmse.png")

In [None]:
# Wall-clock time per training epoch, with the mean drawn as a dashed reference line.
epoch_times = history["epoch_time"]
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_times, marker='o')
ax.axhline(y=np.mean(epoch_times), color='r', linestyle='--', label='Avg Epoch Time')
ax.legend()
ax.set_title('Epoch Times')
ax.set_xlabel('Epoch')
ax.set_ylabel('Time (s)')

fig.tight_layout()
fig.savefig("../logs/epoch_times.png", bbox_inches='tight', dpi=150)

In [None]:
# Forecast every series with the retrained model and score the predictions.
# Use the window length the dataset (and therefore the model) was built with, i.e.
# the tuned input size, rather than the INPUT_SIZE config default.
model_input_size = dataset.input_size
forecaster = TimeseriesForecaster(model, DEVICE, model_input_size, OUTPUT_SIZE)

print("Generating predictions for all series...")
start_time = time.time()
predictions_dict = forecaster.predict_all_series(dataset, BATCH_SIZE)
print(f"Predictions generated in {time.time() - start_time:.2f}s")

summary = evaluate_series(
    dataset=dataset,
    preds=predictions_dict,
    input_size=model_input_size,
    output_size=OUTPUT_SIZE
)

# Five best (lowest) and five worst (highest) series by validation MAE.
top_5 = top_n_by_metric(summary, n=5, metric='val_mae', reverse=False)
bottom_5 = top_n_by_metric(summary, n=5, metric='val_mae', reverse=True)

In [None]:
# Mean validation MAE, then mean validation RMSE, across all series.
for metric_name in ('val_mae', 'val_rmse'):
    print(np.mean([series_metrics[metric_name] for series_metrics in summary.values()]))


In [None]:

# Best/worst series by validation MAE, per-series full result frames, and a sorted
# bar chart of validation MAE for every series.
for header, ranked_keys in (("Top 5 series by val_mae:", top_5), ("Bottom 5 series by val_mae:", bottom_5)):
    print(header)
    for k in ranked_keys:
        val = summary[k].get('val_mae', np.nan)
        print(k, f"{float(val):.2f}")

# Use the window length the dataset/model were built with (tuned), not the INPUT_SIZE default.
full_results_dict = get_full_results_dict(dataset, predictions_dict, dataset.input_size, OUTPUT_SIZE)

# Plot MAE for all series (validation set), sorted ascending
series_names = list(summary.keys())
mae_vals = [float(summary[k]['val_mae']) for k in series_names]
sorted_indices = np.argsort(mae_vals)
sorted_mae = [mae_vals[i] for i in sorted_indices]
sorted_series = [series_names[i] for i in sorted_indices]

fig = px.bar(
    x=sorted_series,
    y=sorted_mae,
    labels={"x": "Series", "y": "Validation MAE"},
    title="Validation MAE for All Series (Sorted)",
    width=1200,
    height=400
)
fig.update_layout(xaxis_tickangle=90)
fig.write_html("../logs/all_series_mae_sorted.html", include_plotlyjs='cdn')
fig.show()


In [None]:
# Unscaled actual vs. predicted curves (train and validation) for the best and worst
# series, written to one HTML file per series under ../logs/.
top_flop = {**top_5, **bottom_5}
# Label figures with the model actually trained rather than a fixed architecture name.
model_label = type(model).__name__

for series_key in top_flop:
    full_results_df = full_results_dict[series_key]

    train_mask = ~pd.isna(full_results_df['train_predicted'])
    validation_mask = ~pd.isna(full_results_df['validation_predicted'])

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=full_results_df['timestamp'],
        y=full_results_df['actual'],
        mode='lines',
        name='Actual consumption',
        line=dict(color='blue', width=1)
    ))
    fig.add_trace(go.Scatter(
        x=full_results_df.loc[train_mask, 'timestamp'],
        y=full_results_df.loc[train_mask, 'train_predicted'],
        mode='lines',
        name='Training predictions',
        line=dict(color='orange', width=1)
    ))
    fig.add_trace(go.Scatter(
        x=full_results_df.loc[validation_mask, 'timestamp'],
        y=full_results_df.loc[validation_mask, 'validation_predicted'],
        mode='lines',
        name='Validation predictions',
        line=dict(color='red', width=2)
    ))
    fig.update_layout(
        title=f'{model_label} Energy Consumption Forecast - Series {series_key}',
        xaxis_title='Date',
        yaxis_title='Energy Consumption',
        width=1200,
        height=600,
        hovermode='x unified'
    )
    fig.write_html(f"../logs/train_validation_truth_{series_key}.html", include_plotlyjs='cdn')

In [None]:
# Per-series validation results (timestamp, actual, predicted) for the best/worst
# series, enriched with the calendar features that make sense for GRANULARITY.
# These frames feed the error-by-period breakdown in the next cell.
calendar_features_by_granularity = {
    Granularity.HOURLY: ("hour_of_day", "day_of_week", "month_of_year"),
    Granularity.DAILY: ("day_of_week", "month_of_year"),
    Granularity.MONTHLY: ("month_of_year",),
}
calendar_extractors = {
    "hour_of_day": lambda ts: ts.dt.hour,
    "day_of_week": lambda ts: ts.dt.day_name(),
    "month_of_year": lambda ts: ts.dt.month,
}
# Loop-invariant: the resampled timestamp column is the same for every series.
all_timestamps = dataset.get_resampled_data()[dataset.timestamp_column]

results_dict = {}
for series_key in top_flop:
    df = full_results_dict[series_key]
    # Keep only rows that actually carry a validation prediction. A contiguous
    # min..max slice would pull in NaN predictions from any gaps and crash when
    # there are no validation rows at all.
    validation_mask = ~pd.isna(df['validation_predicted'])
    # NOTE(review): assumes df's row labels are positions in the resampled data — TODO confirm
    validation_indices = df.index[validation_mask]

    results_df = pd.DataFrame({
        'timestamp': all_timestamps.iloc[validation_indices].reset_index(drop=True),
        'actual': df.loc[validation_mask, 'actual'].to_numpy(),
        'predicted': df.loc[validation_mask, 'validation_predicted'].to_numpy()
    })

    for feature in calendar_features_by_granularity.get(GRANULARITY, ()):
        results_df[feature] = calendar_extractors[feature](results_df['timestamp'])

    results_dict[series_key] = results_df

In [None]:
# Validation RMSE broken down by calendar period (hour / weekday / month) for each
# best/worst series. One figure per period type present for this GRANULARITY,
# saved to ../logs/rmse_by_<period>.png.
metrics = [
    ("hour_of_day", list(range(24)), "Hour of Day"),
    ("day_of_week", ['Monday','Tuesday','Wednesday','Thursday','Friday','Saturday','Sunday'], "Day of Week"),
    ("month_of_year", list(range(1, 13)), "Month of Year")
]
# Every result frame gets the same columns (they depend only on GRANULARITY), so inspect the first.
available_columns = next(iter(results_dict.values())).columns if results_dict else ()

for col_name, labels, xlab in metrics:
    if col_name not in available_columns:
        continue

    n = len(top_flop)
    cols = min(5, n)
    rows = int(np.ceil(n / cols))  # as many rows as needed, not a fixed maximum of two
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3*rows), squeeze=False)
    flat_axes = axes.ravel()  # row-major, same order as axes[idx // cols][idx % cols]

    for ax, series_key in zip(flat_axes, top_flop):
        series_df = results_dict[series_key]

        vals = []
        for lab in labels:
            # Named `period_rows` (not `data`) so the `torch.utils.data` import is not clobbered.
            period_rows = series_df[series_df[col_name] == lab]
            if not period_rows.empty:
                rmse = np.sqrt(np.mean((period_rows["actual"] - period_rows["predicted"])**2))
            else:
                rmse = np.nan
            vals.append(rmse)
        ax.bar(labels, vals, alpha=0.7)
        ax.set_title(f"Series {series_key}")
        ax.set_xlabel(xlab)
        ax.set_ylabel("RMSE")
        ax.tick_params(axis="x", rotation=45)

    # Hide empty subplot slots in the last row.
    for ax in flat_axes[n:]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(f"../logs/rmse_by_{col_name}.png", dpi=150)
    plt.close(fig)

In [None]:
def get_naive_predictions_dict(dataset, input_size, output_size):
    """Build "repeat the previous window" baseline forecasts for every series.

    Split points start at ``input_size`` and advance by ``output_size``, up to and
    including ``len(values)``. The naive forecast at split point ``end`` is the
    ``output_size`` values immediately preceding it.

    Args:
        dataset: object exposing ``split_unscaled_dict`` (iterated for series keys)
            and ``get_unscaled_data(key)``, which returns a mapping with 'train' and
            'validation' sequences.
        input_size: index of the first split point.
        output_size: forecast horizon; also the stride between split points.

    Returns:
        dict: ``{series_key: {'train': ndarray, 'validation': ndarray}}`` holding the
        concatenated, flattened windows (an empty array when no window fits).
    """
    def _naive_windows(values):
        collected = []
        for end in range(input_size, len(values) + 1, output_size):
            start = end - output_size
            if start < 0:
                continue
            window = values[start:end]
            if len(window) == output_size:
                collected.append(window)
        if not collected:
            return np.array([])
        return np.concatenate(collected).flatten()

    return {
        series_key: {
            split_name: _naive_windows(dataset.get_unscaled_data(series_key)[split_name])
            for split_name in ('train', 'validation')
        }
        for series_key in dataset.split_unscaled_dict
    }

# Baseline comparison: score the "repeat last OUTPUT_SIZE values" forecast with the same
# evaluation as the model, then compare per-series validation errors.
# Use the trained model's input window so both forecasts are evaluated on the same split
# points (the INPUT_SIZE config default may differ from the tuned value).
eval_input_size = dataset.input_size
naive_predictions_dict = get_naive_predictions_dict(dataset, eval_input_size, OUTPUT_SIZE)

naive_summary = evaluate_series(
    dataset=dataset,
    preds=naive_predictions_dict,
    input_size=eval_input_size,
    output_size=OUTPUT_SIZE
)

metric_comparison = {}
model_vs_naive = []

for series_key in predictions_dict:
    if series_key not in naive_summary:
        continue

    model_metrics = summary[series_key]
    naive_metrics = naive_summary[series_key]

    metric_comparison[series_key] = {
        "naive_rmse": float(naive_metrics["val_rmse"]),
        "naive_mae": float(naive_metrics["val_mae"]),
        "model_rmse": float(model_metrics["val_rmse"]),
        "model_mae": float(model_metrics["val_mae"])
    }

    model_vs_naive.append({
        "series": series_key,
        "naive_rmse": naive_metrics["val_rmse"],
        "model_rmse": model_metrics["val_rmse"],
        "improved": model_metrics["val_rmse"] < naive_metrics["val_rmse"]
    })

naive_df = pd.DataFrame.from_dict(metric_comparison, orient="index")
naive_df = naive_df.sort_values("naive_rmse")

print("Overall averages (Validation Set):")
print(f" Naive RMSE: {naive_df['naive_rmse'].mean():.4f}, Model RMSE: {naive_df['model_rmse'].mean():.4f}")
print(f" Naive MAE:  {naive_df['naive_mae'].mean():.4f}, Model MAE:  {naive_df['model_mae'].mean():.4f}")

# Ratio of mean model MAE to mean naive MAE (a relative MAE against the naive baseline).
val_mase = naive_df['model_mae'].mean() / naive_df['naive_mae'].mean()
print(f"Model MASE: {val_mase:.2f} (Value < 1.0 indicates model is better than naive)")

better_count = (naive_df['model_mae'] < naive_df['naive_mae']).sum()
print(f"Model improved over Naive on {better_count} / {len(naive_df)} series.")

fig, ax = plt.subplots(figsize=(6,6))
ax.scatter(naive_df["naive_mae"], naive_df["model_mae"], alpha=0.6)
ax.plot([naive_df["naive_mae"].min(), naive_df["naive_mae"].max()],
        [naive_df["naive_mae"].min(), naive_df["naive_mae"].max()], 'r--', label="y=x (Parity)")
x_vals = np.array([naive_df["naive_mae"].min(), naive_df["naive_mae"].max()])
ax.plot(x_vals, val_mase * x_vals, 'g--', label=f"Avg Improvement (slope={val_mase:.2f})")

ax.set_xlabel("Naive MAE")
ax.set_ylabel("Model MAE")
ax.set_title(f"Model vs Naive (Repeat Last {OUTPUT_SIZE})")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
# Save before show(): the MLflow cell logs this file as an artifact.
fig.savefig("../logs/naive_vs_model_mae.png", bbox_inches='tight', dpi=150)
plt.show()

In [None]:
# Logging with MLFlow: parameters, final metrics and the plots produced above are
# written to a local file store under ../logs/mlruns.
if USE_MLFLOW:
    mlflow.set_tracking_uri("file:../logs/mlruns")
    # set_experiment creates the experiment if it is missing AND makes it active.
    # create_experiment alone does not activate it, so runs would land in "Default".
    mlflow.set_experiment("Energy_Consumption_Forecast")

    def _log_artifact_if_exists(path):
        """Log `path` as a run artifact; skip (with a message) files this session did not produce."""
        if os.path.exists(path):
            mlflow.log_artifact(path)
        else:
            print(f"Skipping missing artifact: {path}")

    with mlflow.start_run():
        # Model Architecture
        if SAVE_MODEL:
            mlflow.pytorch.log_model(model, artifact_path="LSTM forecaster")

        _log_artifact_if_exists("../logs/model_architecture.txt")
        mlflow.log_params({
            "granularity": GRANULARITY.value,
            # Window length the dataset/model were built with (tuned), not the config default.
            "input_size": dataset.input_size,
            "output_size": OUTPUT_SIZE,
            "learning_rate": LR,
            "loss_function": best_params["loss_fn"],
            "scaler": best_params["scaler"],
            "hidden_dim": best_params["hidden_dim"],
            "num_layers": best_params["num_layers"],
            "batch_size": BATCH_SIZE,
            "n_epochs": N_EPOCHS,
            "model_type": type(model).__name__,
            "N_SERIES": N_SERIES,
            "USE_TIME_COVARIATES": USE_TIME_COVARIATES,
        })

        # Model training results: last-epoch curves plus per-series averages.
        mlflow.log_metric("train_rmse", history["train_rmse"][-1])
        mlflow.log_metric("train_mae", history["train_mae"][-1])
        mlflow.log_metric("train_mae_avg", np.mean([m['train_mae'] for m in summary.values() if not np.isnan(m['train_mae'])]))
        mlflow.log_metric("val_mase", val_mase)
        v1 = summary.get("value_1", {})
        v1_train_mae = v1.get("train_mae", np.nan)
        v1_val_mae = v1.get("val_mae", np.nan)
        if not np.isnan(v1_train_mae):
            mlflow.log_metric("value_1_train_mae", float(v1_train_mae))
        if not np.isnan(v1_val_mae):
            mlflow.log_metric("value_1_val_mae", float(v1_val_mae))
        mlflow.log_metric("val_rmse", history["val_rmse"][-1])
        mlflow.log_metric("val_mae", history["val_mae"][-1])
        mlflow.log_metric("val_mae_avg", np.mean([m['val_mae'] for m in summary.values() if not np.isnan(m['val_mae'])]))
        mlflow.log_metric("total_training_time_sec", sum(history["epoch_time"]))

        _log_artifact_if_exists("../logs/rmse.png")
        _log_artifact_if_exists("../logs/epoch_times.png")
        # The heatmap is only produced for attention models; avoid logging a stale file otherwise.
        if isinstance(model, (GRUAttention, LSTMAttention)):
            _log_artifact_if_exists("../logs/attention_heatmap.html")

        # Error-by-period plots exist only for the calendar features present in the results.
        breakdown_columns = next(iter(results_dict.values())).columns if results_dict else ()
        for col_name in ("hour_of_day", "day_of_week", "month_of_year"):
            if col_name in breakdown_columns:
                _log_artifact_if_exists(f"../logs/rmse_by_{col_name}.png")

        # Log all series prediction results
        for series_key in top_flop:
            _log_artifact_if_exists(f"../logs/train_validation_truth_{series_key}.html")

        _log_artifact_if_exists("../logs/all_series_mae_sorted.html")
        # NOTE(review): no cell in this notebook writes the MAPE chart; it is logged only if present.
        _log_artifact_if_exists("../logs/all_series_mape_sorted.html")
        _log_artifact_if_exists("../logs/naive_vs_model_mae.png")