In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from mvf_bto.data_loading import load_data
from mvf_bto.constants import * 
from mvf_bto.models.baseline_lstm import BaselineLSTM
from mvf_bto.preprocessing import create_discharge_inputs
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.metrics import MeanSquaredError
from scipy.interpolate import interp1d

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go

## Loading Data

In [3]:
# data_path = "/Users/anoushkabhutani/PycharmProjects/10701-mvf-bto/data/2017-05-12_batchdata_updated_struct_errorcorrect.mat"
data_path = "/Users/mac/Desktop/CMU/10701MachineLearning/project/10701-mvf-bto-backup/data/2017-05-12_batchdata_updated_struct_errorcorrect.mat"

In [42]:
data = load_data(file_path=data_path, num_cells=4)

100%|██████████| 4/4 [00:32<00:00,  8.13s/it]


## Preprocessing to create model inputs and targets

In [6]:
train_split = 0.7
test_split = 0.2
# by default uses validation_split = 1 - (train_split + test_split)


In [45]:
datasets = create_discharge_inputs(data, train_split, test_split, forecast_horizon=2, history_window=4)

 Data for cell b1c3 is corrupted. Skipping cell.


100%|██████████| 1189/1189 [00:07<00:00, 162.18it/s]
100%|██████████| 1178/1178 [00:06<00:00, 173.91it/s]
100%|██████████| 1176/1176 [00:06<00:00, 169.82it/s]


In [33]:
pd.concat(datasets['original_train'])['Cell'].unique()
# datasets.keys()

array(['b1c1'], dtype=object)

## Train Model

In [13]:
window_length = datasets["X_train"].shape[1]
n_features = datasets["X_train"].shape[2]
batch_input_shape = (datasets["batch_size"], window_length, n_features)
n_outputs = datasets["y_train"].shape[-1]
nf_steps = datasets["y_train"].shape[1]
y = datasets["y_train"][:, 0, 0]
idx = y < 2.9
weights = np.ones_like(y) * 1
weights[idx] = 2

In [14]:
model = BaselineLSTM(batch_input_shape=batch_input_shape, n_outputs=n_outputs, nf_steps=nf_steps)

2022-11-05 18:55:52.296805: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [15]:
model.compile(optimizer="adam", loss="mse", metrics=[MeanSquaredError()])

es = EarlyStopping(
    monitor="val_mean_squared_error",
    min_delta=0,
    patience=10,
    verbose=1,
    mode="auto",
    restore_best_weights=True,
)

history = model.fit(
    datasets["X_train"],
    datasets["y_train"],
    validation_data=(datasets["X_val"], datasets["y_val"]),
    epochs=250,
    batch_size=datasets["batch_size"],
    shuffle=False,
    callbacks=[es],
    verbose=1,
    sample_weight=weights

)

Epoch 1/250
Epoch 2/250
Epoch 3/250
Epoch 4/250
Epoch 5/250
Epoch 6/250
Epoch 7/250
Epoch 8/250
Epoch 9/250
Epoch 10/250
Epoch 11/250
Epoch 12/250
Epoch 13/250
Epoch 14/250
Epoch 15/250
Epoch 16/250
Epoch 17/250
Epoch 18/250
Epoch 19/250
Epoch 20/250
Epoch 21/250
Epoch 22/250
Epoch 23/250
Epoch 24/250
Epoch 25/250
Epoch 26/250
Epoch 27/250
Epoch 28/250
Epoch 29/250
Epoch 30/250
Epoch 31/250
Epoch 32/250
Epoch 33/250
Epoch 34/250
Epoch 35/250
Epoch 36/250
Epoch 37/250
Epoch 38/250
Epoch 39/250
Epoch 40/250
Epoch 41/250
Epoch 42/250
Epoch 43/250
Epoch 44/250
Epoch 45/250
Epoch 46/250
Epoch 47/250
Epoch 48/250
Epoch 49/250
Epoch 50/250
Epoch 51/250
Epoch 52/250
Epoch 53/250
Epoch 54/250
Epoch 55/250
Epoch 56/250
Epoch 57/250
Epoch 58/250
Epoch 59/250
Epoch 60/250
Epoch 61/250
Epoch 62/250
Epoch 63/250
Epoch 64/250
Epoch 65/250
Epoch 65: early stopping


In [16]:
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=np.linspace(1, 50),
        y=history.history["loss"],
        showlegend=False,
        mode="markers+lines",
    )
)
fig.update_xaxes(title="Epochs")
fig.update_yaxes(title="Loss (MSE)")

## Parity Plot of Training Error

In [17]:
# random plotting traing error at some interval = skip to not make the plot rendering too slow
batch_size = datasets["batch_size"]
skip = 70

fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_train"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_train"][i : i + batch_size], verbose=0)[:, :, 0])
    df_train = pd.DataFrame(datasets["y_train"][i : i + batch_size][:, :, 0])
    fig.add_trace(
        go.Scatter(
            x=df_pred[0].values,
            y=df_train[0].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

fig.update_yaxes(title="Normalized Voltage Target")
fig.update_xaxes(title="Normalized Voltage Prediction")
fig.update_layout(template="simple_white")

In [18]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_train"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_train"][i : i + batch_size], verbose=0)[:, :, 0])
    df_train = pd.DataFrame(datasets["y_train"][i : i + batch_size][:, :, 0])
    fig.add_trace(
        go.Scatter(
            x=df_pred[0].values,
            y=df_train[0].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

fig.update_yaxes(title="Normalized Temperature Target")
fig.update_xaxes(title="Normalized Temperature Prediction")
fig.update_layout(template="simple_white")

## Parity Plot of Test Error

In [19]:
skip = 20

fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_test"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], verbose=0)[:, :, 0])
    df_train = pd.DataFrame(datasets["y_test"][i : i + batch_size][:, :, 0])
    fig.add_trace(
        go.Scatter(
            x=df_pred[0].values,
            y=df_train[0].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

fig.update_yaxes(title="Normalized Voltage Target")
fig.update_xaxes(title="Normalized Voltage Prediction")
fig.update_layout(template="simple_white")

In [20]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_test"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], verbose=0)[:, :, 1])
    df_train = pd.DataFrame(datasets["y_test"][i : i + batch_size][:, :, 1])
    fig.add_trace(
        go.Scatter(
            x=df_pred[0].values,
            y=df_train[0].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

fig.update_yaxes(title="Normalized Temperature Target")
fig.update_xaxes(title="Normalized Temperature Prediction")
fig.update_layout(template="simple_white")

## True vs Predicted Traces (Train Set)

In [35]:
print(datasets.keys())
train_cell_ids = pd.concat(datasets['original_train'])['Cell'].unique()


dict_keys(['X_train', 'X_test', 'X_val', 'y_train', 'y_test', 'y_val', 'original_train', 'original_test', 'original_val', 'batch_size'])


In [59]:
datasets['original_train']

Unnamed: 0,t,V,temp,I,Qd,Cycle,Cell
0,0.000000,2.026416,31.632496,0.000000,0.000000,1,b1c0
1,0.002417,2.039388,31.632496,0.215908,0.000000,1,b1c0
2,0.002912,2.051660,31.632496,0.359831,0.000000,1,b1c0
3,0.003212,2.063070,31.632496,0.467846,0.000000,1,b1c0
4,0.003498,2.076204,31.632496,0.575877,0.000000,1,b1c0
...,...,...,...,...,...,...,...
1029,49.942828,2.000220,32.254002,-0.020152,1.026198,1188,b1c0
1030,49.973342,2.000116,32.254002,-0.019978,1.026210,1188,b1c0
1031,49.974353,2.001144,32.254002,-0.019978,1.026210,1188,b1c0
1032,49.985110,2.011194,32.254002,-0.019978,1.026210,1188,b1c0


In [36]:
symbol_list = ["circle-open", "circle", "triangle-up"]
pallete = plotly.colors.qualitative.Dark24 + plotly.colors.qualitative.T10
pallete = pallete*70000

In [72]:
datasets["original_train"][datasets["original_train"]['Cell']==train_cell_ids[0]]

Unnamed: 0,t,V,temp,I,Qd,Cycle,Cell


In [70]:
skip = 500
train_cell_id_idx = 0
last_cycle = 1
current_cycle = 1
opacity_list = [1, 0.6, 0.3]
fig = go.Figure()
print(len(datasets["X_train"]),batch_size*skip)
for i in range(0, len(datasets["X_train"]), batch_size * skip):
    df_true = pd.DataFrame(datasets["y_train"][i : i + batch_size][:, 0, 0])
    
    last_cycle = current_cycle
    current_cycle = int(datasets["X_train"][i : i + batch_size][0][0][-1]*MAX_CYCLE)
    if current_cycle>740:
        continue
    if last_cycle> current_cycle:
        train_cell_id_idx += 1
    print(f"train_cell_id_idx {train_cell_id_idx}")
    cell_id = train_cell_ids[train_cell_id_idx]
    print(f"cell_id {cell_id}")
    original_df = datasets["original_train"]
    print(f"1 {len(original_df)}")
    original_df= original_df[original_df.Cycle==current_cycle]
    print(f"2 {original_df.Cycle==current_cycle}")
    original_df= original_df[cell_id]
    print(f"cell_id {cell_id}")
    original_df = original_df[original_df.I < MAX_DISCHARGE_CURRENT]
    print(f"4 {len(original_df)}")
    original_df = original_df[original_df.I > MIN_DISCHARGE_CURRENT]
    print(len(original_df))

    original_df['Qd'] = (original_df['Qd']-original_df['Qd'].min())/(original_df['Qd'].max()-original_df['Qd'].min())

    print(nf_steps)
    if len(original_df)< 2:
        continue
    q_new = original_df.Qd.values
    t_new = original_df['t'].values
    time_interpolator = interp1d(x = q_new,  y = t_new, fill_value="extrapolate")
    
    for j in range(nf_steps):
        print("test")
        df_pred = pd.DataFrame(model.predict(datasets["X_train"][i : i + batch_size], verbose=0)[:, j, 0])
    
        ref_capacities_wrt_nf = REFERENCE_DISCHARGE_CAPACITIES[window_length+j:-nf_steps+j-1]

        prediction_interpolator = interp1d(x = ref_capacities_wrt_nf, 
                                           y = df_pred[0].values*(VOLTAGE_MAX - VOLTAGE_MIN) + VOLTAGE_MIN,
                                           
                                          )
        
        q_j = q_new[q_new>= min(REFERENCE_DISCHARGE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        q_j = q_j[q_j<= max(REFERENCE_DISCHARGE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        t_interp = time_interpolator(REFERENCE_DISCHARGE_CAPACITIES[window_length:-nf_steps-1])
        if j==0:
            q_j0 = q_j
        V_pred = prediction_interpolator(q_j)
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(q_j),
                y=V_pred,
                showlegend=True,
                mode="lines",
                line_dash="dash",
                name = "Interpolated Predictions",
                marker_color=pallete[i],
                opacity=opacity_list[j]
            )
        )
        
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(ref_capacities_wrt_nf),
                y=df_pred[0].values*(VOLTAGE_MAX - VOLTAGE_MIN) + VOLTAGE_MIN,
                showlegend=True,
                mode="markers",
                name = f"Predictions Forecast Horizon {j+1}",
                marker_color=pallete[i],
                marker_symbol = symbol_list[j],
                marker_size=8,
#                 marker_opacity=opacity_list[j]
            )
        )
        print(df_pred[0].values)

    odf = original_df[original_df.Qd > q_j0.min()]
    odf = odf[odf.Qd < q_j.max()]
    fig.add_trace(
        go.Scatter(
            x=odf['t'],
            y=odf['V'],
            showlegend=True,
            mode="lines",
            name = f"{cell_id} Cycle {current_cycle}",
            line_color=pallete[i]
        )
    )


fig.update_yaxes(title="Voltage [V]", showgrid=True)
fig.update_xaxes(title="Times (min)", showgrid=True)
fig.update_layout(template="simple_white")

16632 7000
train_cell_id_idx 0
cell_id b1c1
1 1295874
2 0       True
1       True
2       True
3       True
4       True
        ... 
1082    True
1083    True
1084    True
1085    True
1086    True
Name: Cycle, Length: 1087, dtype: bool


KeyError: 'b1c1'

In [64]:
datasets["original_train"]

Unnamed: 0,t,V,temp,I,Qd,Cycle,Cell
0,0.000000,2.026416,31.632496,0.000000,0.000000,1,b1c0
1,0.002417,2.039388,31.632496,0.215908,0.000000,1,b1c0
2,0.002912,2.051660,31.632496,0.359831,0.000000,1,b1c0
3,0.003212,2.063070,31.632496,0.467846,0.000000,1,b1c0
4,0.003498,2.076204,31.632496,0.575877,0.000000,1,b1c0
...,...,...,...,...,...,...,...
1029,49.942828,2.000220,32.254002,-0.020152,1.026198,1188,b1c0
1030,49.973342,2.000116,32.254002,-0.019978,1.026210,1188,b1c0
1031,49.974353,2.001144,32.254002,-0.019978,1.026210,1188,b1c0
1032,49.985110,2.011194,32.254002,-0.019978,1.026210,1188,b1c0


In [None]:
train_cell_id_idx = 0
last_cycle = 1
current_cycle = 1

fig = go.Figure()
for i in range(0, len(datasets["X_train"]), batch_size * skip):
    df_true = pd.DataFrame(datasets["y_train"][i : i + batch_size][:, 0, 1])
    
    last_cycle = current_cycle
    current_cycle = int(datasets["X_train"][i : i + batch_size][0][0][-1]*MAX_CYCLE)
    if current_cycle>740:
        continue
    if last_cycle> current_cycle:
        train_cell_id_idx += 1
        
    cell_id = train_cell_ids[train_cell_id_idx]
    original_df = datasets["original_train"]

    original_df= original_df[original_df.Cycle==current_cycle]
    original_df= original_df[original_df.Cell==cell_id]
    original_df = original_df[original_df.I < MAX_DISCHARGE_CURRENT]
    original_df = original_df[original_df.I > MIN_DISCHARGE_CURRENT]

    original_df['Qd'] = (original_df['Qd']-original_df['Qd'].min())/(original_df['Qd'].max()-original_df['Qd'].min())

    if len(original_df)< 2:
        continue
    q_new = original_df.Qd.values
    t_new = original_df['t'].values
    time_interpolator = interp1d(x = q_new,  y = t_new, fill_value="extrapolate")
    
    
    for j in range(nf_steps):
        df_pred = pd.DataFrame(model.predict(datasets["X_train"][i : i + batch_size], verbose=0)[:, j, 1])
    
        ref_capacities_wrt_nf = REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1]

        prediction_interpolator = interp1d(x = ref_capacities_wrt_nf, 
                                           y = df_pred[0].values*(TEMPERATURE_MAX - TEMPERATURE_MIN) + TEMPERATURE_MIN,
                                           
                                          )
        
        q_j = q_new[q_new>= min(REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        q_j = q_j[q_j<= max(REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        t_interp = time_interpolator(REFERENCE_CAPACITIES[window_length:-nf_steps-1])
        if j==0:
            q_j0 = q_j
        V_pred = prediction_interpolator(q_j)
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(q_j),
                y=V_pred,
                showlegend=True,
                mode="lines",
                line_dash="dash",
                name = "Interpolated Predictions",
                marker_color=pallete[i//skip],
                opacity=opacity_list[j]
            )
        )
        
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(ref_capacities_wrt_nf),
                y=df_pred[0].values*(TEMPERATURE_MAX - TEMPERATURE_MIN) + TEMPERATURE_MIN,
                showlegend=True,
                mode="markers",
                name = f"Predictions Forecast Horizon {j+1}",
                marker_color=pallete[i//skip],
                marker_symbol = symbol_list[j],
                marker_size=8,
            )
        )


    odf = original_df[original_df.Qd > q_j0.min()]
    odf = odf[odf.Qd < q_j.max()]
    fig.add_trace(
        go.Scatter(
            x=odf['t'],
            y=odf['temp'],
            showlegend=True,
            mode="lines",
            name = f"{cell_id} Cycle {current_cycle}",
            line_color=pallete[i//skip]
        )
    )


fig.update_yaxes(title="Temperature [°C]", showgrid=True)
fig.update_xaxes(title="Times (min)", showgrid=True)
fig.update_layout(template="simple_white")

## True vs Predicted Traces (Test Set)

In [None]:
test_cell_ids = datasets['original_test']['Cell'].unique()
test_cell_ids
len(datasets["X_test"])

In [None]:
skip = 300
test_cell_id_idx = 0
last_cycle = 1
current_cycle = 1
opacity_list = [1, 0.6, 0.3]
fig = go.Figure()
for i in range(0, len(datasets["X_test"]), batch_size * skip):
    df_true = pd.DataFrame(datasets["y_test"][i : i + batch_size][:, 0, 0])
    
    last_cycle = current_cycle
    current_cycle = int(datasets["X_test"][i : i + batch_size][0][0][-1]*MAX_CYCLE)

    if last_cycle> current_cycle:
        test_cell_id_idx += 1
        
    cell_id = test_cell_ids[test_cell_id_idx]
    original_df = datasets["original_test"]

    original_df= original_df[original_df.Cycle==current_cycle]
    original_df= original_df[original_df.Cell==cell_id]
    original_df = original_df[original_df.I < MAX_DISCHARGE_CURRENT]
    original_df = original_df[original_df.I > MIN_DISCHARGE_CURRENT]

    original_df['Qd'] = (original_df['Qd']-original_df['Qd'].min())/(original_df['Qd'].max()-original_df['Qd'].min())

    if len(original_df)< 2:
        continue
    q_new = original_df.Qd.values
    t_new = original_df['t'].values
    time_interpolator = interp1d(x = q_new,  y = t_new, fill_value="extrapolate")
    
    
    for j in range(nf_steps):
        df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], verbose=0)[:, j, 0])
    
        ref_capacities_wrt_nf = REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1]

        prediction_interpolator = interp1d(x = ref_capacities_wrt_nf, 
                                           y = df_pred[0].values*(VOLTAGE_MAX - VOLTAGE_MIN) + VOLTAGE_MIN,
                                           
                                          )
        
        q_j = q_new[q_new>= min(REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        q_j = q_j[q_j<= max(REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        t_interp = time_interpolator(REFERENCE_CAPACITIES[window_length:-nf_steps-1])
        if j==0:
            q_j0 = q_j
        V_pred = prediction_interpolator(q_j)
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(q_j),
                y=V_pred,
                showlegend=True,
                mode="lines",
                line_dash="dash",
                name = "Interpolated Predictions",
                marker_color=pallete[i//skip],
                opacity=opacity_list[j]
            )
        )
        
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(ref_capacities_wrt_nf),
                y=df_pred[0].values*(VOLTAGE_MAX - VOLTAGE_MIN) + VOLTAGE_MIN,
                showlegend=True,
                mode="markers",
                name = f"Predictions Forecast Horizon {j+1}",
                marker_color=pallete[i//skip],
                marker_symbol = symbol_list[j],
                marker_size=8,
#                 marker_opacity=opacity_list[j]
            )
        )


    odf = original_df[original_df.Qd > q_j0.min()]
    odf = odf[odf.Qd < q_j.max()]
    fig.add_trace(
        go.Scatter(
            x=odf['t'],
            y=odf['V'],
            showlegend=True,
            mode="lines",
            name = f"{cell_id} Cycle {current_cycle}",
            line_color=pallete[i//skip]
        )
    )


fig.update_yaxes(title="Voltage [V]", showgrid=True)
fig.update_xaxes(title="Times (min)", showgrid=True)
fig.update_layout(template="simple_white")

In [None]:
test_cell_id_idx = 0
last_cycle = 1
current_cycle = 1

fig = go.Figure()
for i in range(0, len(datasets["X_test"]), batch_size * skip):
    
    last_cycle = current_cycle
    current_cycle = int(datasets["X_test"][i : i + batch_size][0][0][-1]*MAX_CYCLE)
    
    if last_cycle> current_cycle:
        test_cell_id_idx += 1
        
    cell_id = test_cell_ids[test_cell_id_idx]
    original_df = datasets["original_test"]

    original_df= original_df[original_df.Cycle==current_cycle]
    original_df= original_df[original_df.Cell==cell_id]
    original_df = original_df[original_df.I < MAX_DISCHARGE_CURRENT]
    original_df = original_df[original_df.I > MIN_DISCHARGE_CURRENT]

    original_df['Qd'] = (original_df['Qd']-original_df['Qd'].min())/(original_df['Qd'].max()-original_df['Qd'].min())

    if len(original_df)< 2:
        continue
    q_new = original_df.Qd.values
    t_new = original_df['t'].values
    time_interpolator = interp1d(x = q_new,  y = t_new, fill_value="extrapolate")
    
    
    for j in range(nf_steps):
        df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], verbose=0)[:, j, 1])
    
        ref_capacities_wrt_nf = REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1]

        prediction_interpolator = interp1d(x = ref_capacities_wrt_nf, 
                                           y = df_pred[0].values*(TEMPERATURE_MAX - TEMPERATURE_MIN) + TEMPERATURE_MIN,
                                           
                                          )
        
        q_j = q_new[q_new>= min(REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        q_j = q_j[q_j<= max(REFERENCE_CAPACITIES[window_length+j:-nf_steps+j-1])]
        t_interp = time_interpolator(REFERENCE_CAPACITIES[window_length:-nf_steps-1])
        if j==0:
            q_j0 = q_j
        V_pred = prediction_interpolator(q_j)
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(q_j),
                y=V_pred,
                showlegend=True,
                mode="lines",
                line_dash="dash",
                name = "Interpolated Predictions",
                marker_color=pallete[i//skip],
                opacity=opacity_list[j]
            )
        )
        
        fig.add_trace(
            go.Scatter(
                x=time_interpolator(ref_capacities_wrt_nf),
                y=df_pred[0].values*(TEMPERATURE_MAX - TEMPERATURE_MIN) + TEMPERATURE_MIN,
                showlegend=True,
                mode="markers",
                name = f"Predictions Forecast Horizon {j+1}",
                marker_color=pallete[i//skip],
                marker_symbol = symbol_list[j],
                marker_size=8,
#                 marker_opacity=opacity_list[j]
            )
        )


    odf = original_df[original_df.Qd > q_j0.min()]
    odf = odf[odf.Qd < q_j.max()]
    fig.add_trace(
        go.Scatter(
            x=odf['t'],
            y=odf['temp'],
            showlegend=True,
            mode="lines",
            name = f"{cell_id} Cycle {current_cycle}",
            line_color=pallete[i//skip]
        )
    )


fig.update_yaxes(title="Temperature [°C]", showgrid=True)
fig.update_xaxes(title="Times (min)", showgrid=True)
fig.update_layout(template="simple_white")

## Metrics

In [None]:
def error_calculation(model, datasets, error_function):
    results = {}
    batch_size = datasets["batch_size"]
    n_outputs = datasets["y_train"].shape[-1]
    nf_steps = datasets["y_train"].shape[1]
    for dset in ["train", "test"]:
        skip = 200 if dset == "train"  else 50
        for output in range(n_outputs):
            for step in range(nf_steps):
                collector = []
                for i in range(0, len(datasets[f"X_{dset}"]), batch_size*skip):
                    true = datasets[f"y_{dset}"][i : i + batch_size][:, j, output]
                    pred = model.predict(datasets[f"X_{dset}"][i : i + batch_size], verbose=0)[:, j, output]
                    error = error_function(true, pred)
                    collector.append(error)
                results[f"{dset}_output{output}_forecasthorizon{step}"] =sum(collector)/len(collector)
                         
    return results

In [None]:
root_mean_square_error = lambda y_true, y_pred : np.sqrt(((y_true - y_pred)**2).sum()/len(y_true))
mean_absolute_error = lambda y_true, y_pred : abs(y_true - y_pred).sum()/len(y_true)
mean_absolute_percent_error = lambda y_true, y_pred : (abs(y_true - y_pred)/y_true).sum()/len(y_true)

In [None]:
error_calculation(model, datasets, error_function=mean_absolute_error)