In [None]:
# Automatically re-import edited modules (e.g. the local mvf_bto package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

from mvf_bto.data_loading import load_data
from mvf_bto.constants import * 
from mvf_bto.models.baseline_lstm import BaselineLSTM
from mvf_bto.preprocessing import create_discharge_inputs, create_charge_inputs

from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.metrics import MeanSquaredError

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go

## Loading Data

In [None]:
# Location of the 2017-05-12 batch .mat file. Set the MVF_BTO_DATA_PATH environment
# variable to point at your local copy instead of editing a user-specific absolute path.
data_path = os.environ.get(
    "MVF_BTO_DATA_PATH",
    "/Users/anoushkabhutani/PycharmProjects/10701-mvf-bto/data/2017-05-12_batchdata_updated_struct_errorcorrect.mat",
)

In [None]:
NUM_CELLS = 4  # presumably caps how many battery cells load_data reads — confirm in mvf_bto.data_loading
data = load_data(num_cells=NUM_CELLS, file_path=data_path)

## Preprocessing to create model inputs and targets

In [None]:
train_split = 0.7
test_split = 0.2
# by default uses validation_split = 1 - (train_split + test_split)
# The training cell evaluates on X_val/y_val and early-stops on validation MSE,
# so the two fractions must leave a non-empty validation split.
assert train_split > 0 and test_split > 0 and train_split + test_split < 1, (
    "train_split + test_split must leave a non-empty validation split"
)

In [None]:
# Build model inputs/targets with the charge-curve preprocessing
# (create_discharge_inputs is imported but not used in this notebook).
# NOTE(review): the result is indexed by string keys below — X_train, y_train,
# X_val, y_val, X_test, y_test and batch_size; confirm the full schema in mvf_bto.preprocessing.
datasets = create_charge_inputs(data, train_split, test_split)

## Train Model

In [None]:
# Derive the model's input/output geometry from the training tensors.
# Assumes X_train is (n_samples, window_length, n_features) — TODO confirm.
window_length = datasets["X_train"].shape[1]
n_features = datasets["X_train"].shape[2]
batch_size = datasets["batch_size"]
batch_input_shape = (batch_size, window_length, n_features)
n_outputs = datasets["y_train"].shape[-1]
n_train_batches = datasets["X_train"].shape[0] // batch_size

# Label every value so the printed summary is readable on its own.
print(
    f"window_length={window_length}, n_features={n_features}, "
    f"batch_input_shape={batch_input_shape}, batch_size={batch_size}, "
    f"n_outputs={n_outputs}, n_train_batches={n_train_batches}"
)

In [None]:
# Instantiate the baseline LSTM with a fixed batch dimension taken from batch_input_shape.
# The training cell further down builds a fresh model again before fitting.
model = BaselineLSTM(n_outputs=n_outputs, batch_input_shape=batch_input_shape)

In [None]:
# Adam optimiser with MSE loss; MeanSquaredError is also tracked as an explicit metric.
model.compile(loss="mse", metrics=[MeanSquaredError()], optimizer="adam")

In [None]:
# Sanity check of the preprocessed training targets: plot every `skip`-th training
# curve as charge voltage (de-normalized back to volts) against the reference capacities.
skip = 100
palette = plotly.colors.qualitative.Dark24

fig = go.Figure()
for k, i in enumerate(range(0, len(datasets["X_train"]), batch_size * skip)):
    df_true = pd.DataFrame(datasets["y_train"][i : i + batch_size, 0])

    fig.add_trace(
        go.Scatter(
            x=REFERENCE_CHARGE_CAPACITIES[-batch_size:],
            y=df_true[0].values * (VOLTAGE_MAX - VOLTAGE_MIN) + VOLTAGE_MIN,
            showlegend=True,
            mode="lines+markers",
            name=f"Curve {i // batch_size + 1}",
            # Colour by curve counter k. Indexing by i // skip (= k * batch_size) only
            # visits a few palette entries whenever batch_size shares factors with 24.
            line_color=palette[k % len(palette)],
        )
    )

fig.update_yaxes(title="Voltage [V]")
fig.update_xaxes(title="State of Charge (Normalized Capacity)")

In [None]:
# Rebuild and compile a fresh model so that re-running this cell trains from newly
# initialised weights instead of continuing from a previous fit.
model = BaselineLSTM(n_outputs=n_outputs, batch_input_shape=batch_input_shape)
model.compile(loss="mse", metrics=[MeanSquaredError()], optimizer="adam")

# Stop when validation MSE has not improved for 30 epochs and roll back to the best
# weights seen. "val_mean_squared_error" is the validation name Keras gives the
# MeanSquaredError metric compiled above.
es = EarlyStopping(
    monitor="val_mean_squared_error",
    patience=30,
    min_delta=0,
    mode="auto",
    restore_best_weights=True,
    verbose=1,
)

# shuffle=False keeps the windows in their stored order — presumably required by the
# fixed-batch LSTM design (TODO confirm against BaselineLSTM).
validation_pair = (datasets["X_val"], datasets["y_val"])
history = model.fit(
    datasets["X_train"],
    datasets["y_train"],
    batch_size=datasets["batch_size"],
    epochs=150,
    validation_data=validation_pair,
    shuffle=False,
    callbacks=[es],
    verbose=1,
)

In [None]:
# Loss curves, one point per epoch actually run. Early stopping can end training
# before the 150-epoch cap, so the x-axis comes from the recorded history length
# rather than a fixed range.
train_loss = history.history["loss"]
epochs_run = np.arange(1, len(train_loss) + 1)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=epochs_run,
        y=train_loss,
        name="Training loss",
        mode="markers+lines",
    )
)
# Validation loss is recorded because fit() was given validation_data.
val_loss = history.history.get("val_loss")
if val_loss is not None:
    fig.add_trace(
        go.Scatter(
            x=epochs_run,
            y=val_loss,
            name="Validation loss",
            mode="markers+lines",
        )
    )
fig.update_xaxes(title="Epochs")
fig.update_yaxes(title="Loss (MSE)")

## Parity Plot of Training Error

In [None]:
# First-timestep training targets as a table; the parity plots below label
# column 0 as normalized voltage and column 1 as normalized temperature.
pd.DataFrame(datasets["y_train"][:, 0])

In [None]:
# Training-set parity plot for normalized voltage (output column 0). Only every
# `skip`-th curve (presumably one curve = `batch_size` consecutive windows) is drawn
# to keep plot rendering fast; points on the diagonal are perfect predictions.
batch_size = datasets["batch_size"]
skip = 600

fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1.5], y=[0, 1.5], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_train"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_train"][i : i + batch_size], verbose=0, batch_size=batch_size)[:, 0, :])
    # Slice the targets exactly like the inputs so each prediction is paired with its own
    # target (taking the whole array mismatched lengths and paired the wrong rows).
    df_true = pd.DataFrame(datasets["y_train"][i : i + batch_size][:, 0, :])
    fig.add_trace(
        go.Scatter(
            x=df_pred[0].values,
            y=df_true[0].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

fig.update_yaxes(title="Normalized Voltage Target")
fig.update_xaxes(title="Normalized Voltage Prediction")

In [None]:
# Training-set parity plot for normalized temperature (output column 1).
# Reuses `batch_size` and `skip` set in the previous cell.
fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1.5], y=[0, 1.5], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_train"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_train"][i : i + batch_size], verbose=0, batch_size=batch_size)[:, 0, :])
    # Targets sliced like the inputs so predictions and targets refer to the same windows.
    df_true = pd.DataFrame(datasets["y_train"][i : i + batch_size][:, 0, :])
    fig.add_trace(
        go.Scatter(
            x=df_pred[1].values,
            y=df_true[1].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

fig.update_yaxes(title="Normalized Temperature Target")
fig.update_xaxes(title="Normalized Temperature Prediction")

## Parity Plot of Test Error

In [None]:
# Test-set parity plot for normalized voltage (output column 0), drawing every
# `skip`-th curve. `batch_size` comes from an earlier cell.
skip = 50

fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1.5], y=[0, 1.5], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_test"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], verbose=0, batch_size=batch_size)[:, 0, :])
    # Targets sliced like the inputs so each prediction is paired with its own target.
    df_true = pd.DataFrame(datasets["y_test"][i : i + batch_size][:, 0, :])
    fig.add_trace(
        go.Scatter(
            x=df_pred[0].values,
            y=df_true[0].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

fig.update_yaxes(title="Normalized Voltage Target")
fig.update_xaxes(title="Normalized Voltage Prediction")

In [None]:
# Test-set parity plot for normalized temperature (output column 1).
# Reuses `batch_size` and `skip` from earlier cells.
fig = go.Figure()
fig.add_trace(go.Scatter(x=[0, 1.5], y=[0, 1.5], showlegend=False, mode="markers+lines"))
for i in range(0, len(datasets["X_test"]), batch_size * skip):
    df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], verbose=0, batch_size=batch_size)[:, 0, :])
    # Targets sliced like the inputs so each prediction is paired with its own target.
    df_true = pd.DataFrame(datasets["y_test"][i : i + batch_size][:, 0, :])
    fig.add_trace(
        go.Scatter(
            x=df_pred[1].values,
            y=df_true[1].values,
            showlegend=False,
            mode="markers+lines",
        )
    )

# Column 1 is temperature, so label the axes accordingly.
fig.update_yaxes(title="Normalized Temperature Target")
fig.update_xaxes(title="Normalized Temperature Prediction")

## True vs Predicted Traces (Test Set)

In [None]:
# Overlay predicted (markers) and true (lines) charge-voltage curves for every
# `skip`-th test curve, de-normalized back to volts. A curve's predicted and true
# traces share one colour.
skip = 20

# Repeated Dark24 palette; max(1, ...) keeps it non-empty when the test set is shorter than `skip`.
pallete = plotly.colors.qualitative.Dark24 * max(1, len(datasets["X_test"]) // skip)

fig = go.Figure()
for k, i in enumerate(range(0, len(datasets["X_test"]), batch_size * skip)):
    # Colour by curve counter k. Indexing by i // skip (= k * batch_size) only visits
    # a few palette entries whenever batch_size shares factors with 24.
    color = pallete[k % len(pallete)]
    # Curve number, matching the "Curve N" naming used for the training-target plot.
    curve_no = i // batch_size + 1

    df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], verbose=0, batch_size=batch_size)[:, 0, :])
    df_true = pd.DataFrame(datasets["y_test"][i : i + batch_size][:, 0, :])

    fig.add_trace(
        go.Scatter(
            x=REFERENCE_CHARGE_CAPACITIES[-batch_size:],
            y=df_pred[0].values * (VOLTAGE_MAX - VOLTAGE_MIN) + VOLTAGE_MIN,
            showlegend=True,
            mode="markers",
            name=f"Predicted Curve {curve_no}",
            marker_color=color,
        )
    )

    fig.add_trace(
        go.Scatter(
            x=REFERENCE_CHARGE_CAPACITIES[-batch_size:],
            y=df_true[0].values * (VOLTAGE_MAX - VOLTAGE_MIN) + VOLTAGE_MIN,
            showlegend=True,
            mode="lines",
            name=f"True Curve {curve_no}",
            line_color=color,
        )
    )

fig.update_yaxes(title="Voltage [V]")
fig.update_xaxes(title="State of Charge (Normalized Capacity)")

In [None]:
# Same overlay for the temperature output (column 1), de-normalized to °C.
# Reuses `skip`, `pallete` and `batch_size` defined in earlier cells.
fig = go.Figure()
for k, i in enumerate(range(0, len(datasets["X_test"]), batch_size * skip)):
    # Colour by curve counter k. Indexing by i // skip (= k * batch_size) only visits
    # a few palette entries whenever batch_size shares factors with 24.
    color = pallete[k % len(pallete)]
    curve_no = i // batch_size + 1

    df_pred = pd.DataFrame(model.predict(datasets["X_test"][i : i + batch_size], batch_size=batch_size, verbose=0)[:, 0, :])
    df_true = pd.DataFrame(datasets["y_test"][i : i + batch_size][:, 0, :])

    fig.add_trace(
        go.Scatter(
            x=REFERENCE_CHARGE_CAPACITIES[-batch_size:],
            y=df_pred[1].values * (TEMPERATURE_MAX - TEMPERATURE_MIN) + TEMPERATURE_MIN,
            showlegend=True,
            mode="markers",
            name=f"Predicted Curve {curve_no}",
            marker_color=color,
        )
    )

    fig.add_trace(
        go.Scatter(
            x=REFERENCE_CHARGE_CAPACITIES[-batch_size:],
            y=df_true[1].values * (TEMPERATURE_MAX - TEMPERATURE_MIN) + TEMPERATURE_MIN,
            showlegend=True,
            mode="lines",
            name=f"True Curve {curve_no}",
            line_color=color,
        )
    )

fig.update_yaxes(title="Temperature [°C]")
fig.update_xaxes(title="State of Charge (Normalized Capacity)")