In [None]:
# Reload edited project modules (e.g. mvf_bto) automatically before running
# code, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from mvf_bto.data_loading import load_data
from mvf_bto.constants import * 
from mvf_bto.models.one_shot import OneShot
from mvf_bto.preprocessing.one_shot import create_discharge_inputs, REFERENCE_DISCHARGE_CAPACITIES

from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.metrics import MeanSquaredError
import tensorflow as tf
from scipy.interpolate import interp1d

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go

## Loading Data

In [None]:
import os

# Location of the 2017-05-12 batch .mat file. Set the MVF_BTO_DATA_PATH
# environment variable to point at your own copy instead of editing a
# user-specific absolute path here; the default preserves the original location.
DEFAULT_DATA_PATH = "/Users/anoushkabhutani/PycharmProjects/10701-mvf-bto/data/2017-05-12_batchdata_updated_struct_errorcorrect.mat"
data_path = os.environ.get("MVF_BTO_DATA_PATH", DEFAULT_DATA_PATH)

In [None]:
# Number of cells requested from load_data.
NUM_CELLS = 40
data = load_data(file_path=data_path, num_cells=NUM_CELLS)

## Preprocess Datasets

In [None]:
# Train/test fractions passed to create_discharge_inputs; presumably the
# remaining 0.1 becomes the validation split (X_val/y_val) — confirm.
train_split, test_split = 0.7, 0.2
# Length of the input history window; also the first dimension of the model's
# input shape.
history_window = 16

datasets = create_discharge_inputs(
    data,
    train_split,
    test_split,
    history_window=history_window,
)

## Model Training

In [None]:
# Per-sample model input: `history_window` steps, each with as many features as
# the last axis of X_train. The output width follows the last axis of y_train.
input_shape = (history_window, datasets['X_train'].shape[-1])
n_features = input_shape[-1]
output_dimension = datasets['y_train'].shape[-1]

In [None]:
def custom_loss_function(y_true, y_pred, second_half_weight=10.0):
    """Weighted mean squared error over two equal halves of the target vector.

    The last axis is split in half. The first half (plotted later in this
    notebook as voltage) has weight 1. The second half (plotted as temperature)
    is scaled by ``second_half_weight``.

    Args:
        y_true: Target tensor; the last axis holds both halves.
        y_pred: Prediction tensor with the same shape as ``y_true``.
        second_half_weight: Multiplier on the squared error of the second half.
            Defaults to 10, the originally hard-coded weight.

    Returns:
        Tensor of per-sample losses (mean over the last axis).
    """
    # Guard against a float dtype mismatch between targets and model output.
    y_true = tf.cast(y_true, y_pred.dtype)
    # Dynamic shape keeps the split valid even if the static last dim is unknown.
    split = tf.shape(y_true)[-1] // 2
    first_sq_err = tf.square(y_true[..., :split] - y_pred[..., :split])
    second_sq_err = tf.square(y_true[..., split:] - y_pred[..., split:])
    return tf.reduce_mean(first_sq_err + second_half_weight * second_sq_err, axis=-1)


# Seed NumPy and TensorFlow so repeated runs start from the same random state
# (some GPU kernels may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

EPOCHS = 500
BATCH_SIZE = 128

model = OneShot(input_shape=input_shape, n_outputs=output_dimension)

# Stop when the unweighted validation MSE metric (not the weighted training
# loss) has not decreased for 80 epochs, then restore the best weights seen.
es = EarlyStopping(
    monitor="val_mean_squared_error",
    min_delta=0,
    patience=80,
    verbose=1,
    mode="min",
    restore_best_weights=True,
)

# Learning rate starts at 0.01 and decays by a factor of 0.96 per 1000
# optimizer steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01,
    decay_steps=1000,
    decay_rate=0.96,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(
    optimizer=optimizer,
    loss=custom_loss_function,
    metrics=[MeanSquaredError()],
)

# NOTE(review): shuffle=False keeps training batches in dataset order every
# epoch; confirm this is intentional.
history = model.fit(
    datasets["X_train"],
    datasets["y_train"],
    validation_data=(datasets["X_val"], datasets["y_val"]),
    epochs=EPOCHS,
    callbacks=[es],
    batch_size=BATCH_SIZE,
    shuffle=False,
    verbose=1,
)

## Parity Plots

In [None]:
# Plot only every `skip`-th point so the scatter stays responsive.
skip = 1000
yhat = model.predict(datasets['X_test'])
y_test = datasets['y_test']
# First half of each target/prediction vector (the voltage half).
n_half = y_test.shape[-1] // 2
fig = go.Figure()
fig.add_trace(go.Scatter(x=y_test[:, :n_half].flatten()[::skip],
                         y=yhat[:, :n_half].flatten()[::skip], mode="markers", showlegend=False))
# y = x reference: points on this line are perfect predictions.
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", showlegend=False))
fig.update_layout(title='Voltage parity (test set)')
fig.update_xaxes(title='Normalized Voltage Target')
fig.update_yaxes(title='Normalized Voltage Prediction')

In [None]:
# Parity plot for the second (temperature) half of each target vector.
# Reuses `skip` from the voltage parity cell above.
yhat = model.predict(datasets['X_test'])
y_test = datasets['y_test']
n_half = y_test.shape[-1] // 2
fig = go.Figure()
fig.add_trace(go.Scatter(x=y_test[:, n_half:].flatten()[::skip],
                         y=yhat[:, n_half:].flatten()[::skip], mode="markers", showlegend=False))
# y = x reference: points on this line are perfect predictions.
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", showlegend=False))
fig.update_layout(title='Temperature parity (test set)')
fig.update_xaxes(title='Normalized Temperature Target')
fig.update_yaxes(title='Normalized Temperature Prediction')

## Predicted vs Actual Traces

In [None]:
# Qualitative colours, tiled 70000 times so later cells can index the list
# directly with a test-sample index; colours repeat cyclically.
pallete = (plotly.colors.qualitative.Dark24 + plotly.colors.qualitative.T10) * 70000

In [None]:

# Overlay measured (lines) and predicted (markers) voltage curves for every
# 1000th test sample starting at index 10. Both traces of a pair share one
# colour and one legend group.
yhat = model.predict(datasets['X_test'])
y_test = datasets['y_test']
n_half = y_test.shape[-1] // 2
# NOTE(review): assumes the sorted (Cycle, Cell) group order of original_test
# matches the row order of y_test / yhat — confirm in create_discharge_inputs.
labels = list(datasets['original_test'].groupby(["Cycle", "Cell"]).groups.keys())
fig = go.Figure()
for i in range(10, len(yhat), 1000):
    cycle, cell = labels[i]
    colour = pallete[i]
    group = f"Cell {cell} Cycle {cycle}"
    # A mode="lines" trace is coloured by line.color (marker.color is ignored).
    fig.add_trace(go.Scatter(x=datasets['q_eval_test'][i], y=y_test[i][:n_half],
                             mode="lines", line_color=colour,
                             name=group, legendgroup=group))
    # A mode="markers" trace is coloured by marker.color (line.color is ignored).
    fig.add_trace(go.Scatter(x=datasets['q_eval_test'][i], y=yhat[i][:n_half],
                             mode="markers", marker_color=colour,
                             name="Prediction", legendgroup=group))

In [None]:
# Capacity is measured in amp-hours, not amps.
# NOTE(review): assumes q_eval_test is un-normalized capacity in Ah — confirm.
fig.update_xaxes(title='Capacity [Ah]')
# The plotted values are the same arrays the parity plot treats as normalized.
fig.update_yaxes(title='Normalized Voltage')

In [None]:
# Overlay measured (lines) and predicted (markers) temperature curves for every
# 1000th test sample starting at index 10. Both traces of a pair share one
# colour and one legend group.
yhat = model.predict(datasets['X_test'])
y_test = datasets['y_test']
n_half = y_test.shape[-1] // 2
# NOTE(review): assumes the sorted (Cycle, Cell) group order of original_test
# matches the row order of y_test / yhat — confirm in create_discharge_inputs.
labels = list(datasets['original_test'].groupby(["Cycle", "Cell"]).groups.keys())
fig = go.Figure()
for i in range(10, len(yhat), 1000):
    cycle, cell = labels[i]
    colour = pallete[i]
    group = f"Cell {cell} Cycle {cycle}"
    # A mode="lines" trace is coloured by line.color (marker.color is ignored).
    fig.add_trace(go.Scatter(x=datasets['q_eval_test'][i], y=y_test[i][n_half:],
                             mode="lines", line_color=colour,
                             name=group, legendgroup=group))
    # A mode="markers" trace is coloured by marker.color (line.color is ignored).
    fig.add_trace(go.Scatter(x=datasets['q_eval_test'][i], y=yhat[i][n_half:],
                             mode="markers", marker_color=colour,
                             name="Prediction", legendgroup=group))

In [None]:
# Capacity is measured in amp-hours, not amps.
# NOTE(review): assumes q_eval_test is un-normalized capacity in Ah — confirm.
fig.update_xaxes(title='Capacity [Ah]')
# The plotted values are the same arrays the parity plot treats as normalized.
fig.update_yaxes(title='Normalized Temperature')