# Create Visualizations for the Paper

## Experiment 1

In [None]:
cd ..

In [None]:
import glob
import os

import numpy as np
import torch
import yaml

import plotly.graph_objects as go
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display

from models.MLP import MLP
from models.GRU import GRU
from models.TCN import TCN
from models.TcnAe import TcnAe
from models.Transformer import Transformer
from data.data_module import ThreeTankDataModule

### Predictions

In [None]:
# Build the data module used by every plotting cell below. setup() is called
# before any dataset is read from dm.ds_dict (presumably it is what populates
# ds_dict -- confirm in data/data_module.py).
dm = ThreeTankDataModule()
dm.setup()

In [None]:
def plot_sample_data(ds, index, title=None):
    """Build the Plotly traces and layout for one sample of a dataset.

    Args:
        ds: indexable dataset; ``ds[index]`` must unpack into two arrays
            ``(x1, x2)`` that can be concatenated along the first axis and
            that expose at least three columns.
        index: position of the sample inside ``ds``.
        title: optional title placed on the layout.

    Returns:
        tuple: ``(traces, layout)`` where ``traces`` is a list of three
        ``go.Scatter`` lines named ``h1``, ``h2`` and ``h3`` and ``layout`` is
        a ``go.Layout`` with a dashed vertical line at the boundary between
        the two parts of the sample.
    """
    first_part, second_part = ds[index]
    series = np.concatenate((first_part, second_part))
    boundary = len(first_part)

    traces = []
    for column, label in enumerate(('h1', 'h2', 'h3')):
        traces.append(
            go.Scatter(
                x=np.array(range(series.shape[0])),
                y=series[:, column],
                name=label,
                mode="lines",
                opacity=1,
            )
        )

    layout = go.Layout(
        title_text=title,
        font_family="Serif",
        font_size=14,
        margin_l=5,
        margin_t=50,
        margin_b=5,
        margin_r=5,
        xaxis_title="time",
    )
    # Dashed marker spanning the full plot height (yref='paper') at the index
    # where the second part of the sample begins.
    boundary_marker = dict(
        type='line',
        x0=boundary,
        x1=boundary,
        y0=0,
        y1=1,
        yref='paper',
        xref='x',
        line=dict(dash='dash'),
    )
    layout.shapes = [boundary_marker]

    return traces, layout

def interactive_sample_plot(datamodule, sample_index=0):
    """Display a figure widget with dropdowns to browse datasets and samples.

    Args:
        datamodule: object exposing ``ds_dict`` (mapping from dataset name to
            dataset) and ``scenarios`` (sequence whose first entry is used as
            the initial figure title).
        sample_index: sample shown initially; must be one of the values
            offered by the sample dropdown (0-99).

    The function returns nothing; it displays a ``VBox`` holding both
    dropdowns and the figure.
    """
    datasets = datamodule.ds_dict
    scenarios = datamodule.scenarios

    # Dataset names offered in the dropdown (hard-coded, not read from ds_dict).
    dataset_choices = [
        "standard",
        "fault",
        "noise",
        "duration",
        "scale",
        "switch",
        "q1+v3",
        "q1+v3+rest",
        "v12+v23",
        "standard+",
        "standard++",
        "frequency",
    ]
    # Highest selectable sample index (hard-coded, not derived from a dataset).
    last_sample = 99

    dataset_picker = widgets.Dropdown(
        options=dataset_choices,
        value="standard",
        description='Dataset:',
    )
    sample_picker = widgets.Dropdown(
        options=list(range(last_sample + 1)),
        value=sample_index,
        description='Sample:',
    )

    initial_traces, initial_layout = plot_sample_data(
        datasets["standard"], sample_index, title=scenarios[0]
    )
    figure = go.FigureWidget(data=initial_traces, layout=initial_layout)

    def refresh(change):
        # The observer is registered for every trait, so ignore everything
        # except changes of the selected value.
        if change['type'] != 'change' or change['name'] != 'value':
            return
        chosen = dataset_picker.value
        fresh_traces, fresh_layout = plot_sample_data(
            datasets[chosen], sample_picker.value, title=chosen
        )
        # Update the existing traces in place instead of rebuilding the widget.
        for position, fresh in enumerate(fresh_traces):
            figure.data[position].update(fresh)
        figure.layout.update(fresh_layout)

    dataset_picker.observe(refresh)
    sample_picker.observe(refresh)

    display(widgets.VBox([dataset_picker, sample_picker, figure]))

In [None]:
# Show the interactive dataset/sample browser for the data module built above.
interactive_sample_plot(dm)

In [None]:
def plot_sample_forecast(sample, fcast, title=None, display=True):
    """Draw one sample together with a single forecast.

    Args:
        sample: pair ``(x1, x2)``; both parts are concatenated along the first
            axis and columns 0-2 are drawn as solid lines ``h1``-``h3``.
        fcast: forecast for the ``x2`` part; singleton axes are squeezed away
            and columns 0-2 are drawn as dotted lines ``pred_h1``-``pred_h3``
            over the x range covered by ``x2``.
        title: optional figure title.
        display: when True the figure is shown and nothing is returned; when
            False the figure is returned without being shown.

    Returns:
        The ``go.Figure`` if ``display`` is False, otherwise ``None``.
    """
    first_part, second_part = sample
    forecast = np.squeeze(fcast)
    observed = np.concatenate((first_part, second_part))

    observed_colors = (
        '#1f77b4',  # muted blue
        '#d62728',  # brick red
        '#2ca02c',  # cooked asparagus green
    )
    forecast_colors = (
        '#17becf',  # blue-teal
        '#ff7f0e',  # safety orange
        '#bcbd22',  # curry yellow-green
    )

    fig = go.Figure()

    observed_columns = [observed[:, 0], observed[:, 1], observed[:, 2]]
    for values, label, shade in zip(observed_columns, ('h1', 'h2', 'h3'), observed_colors):
        fig.add_trace(
            go.Scatter(
                x=np.array(range(observed.shape[0])),
                y=values,
                name=label,
                mode="lines",
                opacity=1,
                line=dict(color=shade),
            )
        )

    forecast_columns = [forecast[:, 0], forecast[:, 1], forecast[:, 2]]
    forecast_start = first_part.shape[0]
    forecast_stop = first_part.shape[0] + second_part.shape[0]
    for values, label, shade in zip(forecast_columns,
                                    ('pred_h1', 'pred_h2', 'pred_h3'),
                                    forecast_colors):
        fig.add_trace(
            go.Scatter(
                x=np.array(range(forecast_start, forecast_stop)),
                y=values,
                name=label,
                mode="lines",
                opacity=1,
                line=dict(color=shade, dash="dot"),
            )
        )

    # Dashed vertical line where the second part of the sample begins.
    fig.add_vline(x=len(first_part), line_dash="dash")
    fig.update_xaxes(tick0=0, dtick=50)
    fig.update_xaxes(title_text='time')
    fig.update_layout(
        width=800,
        height=500,
        font_family="Serif",
        font_size=14,
        margin_l=5,
        margin_t=50,
        margin_b=5,
        margin_r=5,
    )
    if title is not None:
        fig.update_layout(title=title)

    if not display:
        return fig
    fig.show()

In [None]:
model_path = "logs/comp/MLP/MLP_dhl[256, 512, 256]_bnFalse/"
# The run directory contains "[...]", which glob would treat as a character
# class, so the directory part is escaped before the pattern is appended.
ckpt_path = glob.glob(f"{glob.escape(model_path)}/checkpoints/*.ckpt")
if not ckpt_path:
    raise FileNotFoundError(f"No checkpoint found in {model_path}checkpoints/")
# glob returns matches in arbitrary order, so the latest checkpoint is picked
# by modification time instead of by its position in the list.
latest_ckpt = max(ckpt_path, key=os.path.getmtime)
# load_from_checkpoint is a classmethod that constructs the model itself, so no
# throwaway MLP(**hparams) instance is built first. The cells below feed CPU
# tensors and call .numpy() on the output, hence the model is mapped to the CPU
# and switched to evaluation mode.
model_name = MLP.load_from_checkpoint(latest_ckpt, map_location="cpu")
model_name.eval()



In [None]:
# Forecast one sample of the "fault" dataset with the MLP loaded above.
sample_idx = 22
sample = dm.ds_dict["fault"][sample_idx]

# unsqueeze(0) adds a leading axis for the model call (presumably the batch
# axis) and squeeze(0) removes it again from the output.
history_batch = torch.tensor(sample[0]).unsqueeze(0)
prediction = model_name(history_batch)
fcast = prediction.squeeze(0).detach().numpy()

plot_sample_forecast(sample, fcast, title="MLP Forecast")

In [None]:
def plot_forecast(sample, fcasts, title=None, display=True, sample_idx_start=0, model_names=None,
                  save_path="visualizations/forecast.pdf"):
    """Plot the forecasts of several models on top of one sample.

    Args:
        sample: pair ``(x1, x2)``; ``x1[sample_idx_start:]`` and ``x2`` are
            concatenated along the first axis and columns 0-2 are drawn as
            grey lines without legend entries.
        fcasts: dict mapping a model name to its forecast. Each forecast is
            squeezed and its columns 0-2 are drawn in the model's colour,
            starting right after the displayed part of ``x1``.
        title: optional figure title.
        display: when True the figure is shown and nothing is returned; when
            False the figure is returned without being shown.
        sample_idx_start: number of leading steps of ``x1`` that are cut off
            before plotting.
        model_names: names of the models in ``fcasts`` to draw; ``None`` draws
            every model in ``fcasts``.
        save_path: file the figure is exported to. The default is the
            location this function has always written to; pass ``None`` to
            skip the export.

    Returns:
        The ``go.Figure`` if ``display`` is False, otherwise ``None``.

    Raises:
        KeyError: if a plotted model name has no entry in the colour table.
    """
    x1, x2 = sample
    x1 = x1[sample_idx_start:]

    x = np.concatenate((x1, x2))
    if model_names is None:
        model_names = list(fcasts.keys())
    colors = {
        'sample': '#7f7f7f',  # muted grey
        'MLP': '#2ca02c',  # cooked asparagus green
        'GRU': '#1f77b4',  # muted blue
        'GRU-AR': '#17becf',  # blue-teal
        'TCN': '#bcbd22',  # curry yellow-green
        'TCN-AE': '#9467bd',  # muted purple
        'Transformer': '#d62728',  # brick red
        'Transf.-CE': '#ff7f0e',  # safety orange
    }

    fig = go.Figure()
    sample_steps = np.arange(x.shape[0])
    for column, name in enumerate(['h1', 'h2', 'h3']):
        fig.add_trace(go.Scatter(x=sample_steps, y=x[:, column], name=name,
                                 mode="lines", opacity=.7,
                                 line=dict(color=colors['sample']),
                                 showlegend=False))

    forecast_start = x1.shape[0]
    for model, fcast in fcasts.items():
        if model not in model_names:
            continue
        pred_x2 = np.squeeze(fcast)
        # Size the x positions from the forecast itself so x and y always have
        # the same length, even if the forecast horizon differs from len(x2).
        forecast_steps = np.arange(forecast_start, forecast_start + pred_x2.shape[0])
        for column in range(3):
            fig.add_trace(go.Scatter(x=forecast_steps, y=pred_x2[:, column],
                                     name=f'{model}', mode="lines", opacity=1,
                                     line=dict(color=colors[model]),
                                     # One legend entry per model: the three
                                     # lines share a group and only the first
                                     # one is listed.
                                     legendgroup=model,
                                     showlegend=column == 0,
                                     hovertemplate=f'{model}'))

    # Dashed vertical line where the forecast horizon begins.
    fig.add_vline(x=len(x1), line_dash="dash", opacity=.7)
    fig.update_xaxes(title_text="Time Step")
    fig.update_yaxes(title_text="Tank Level")
    fig.update_layout(
        autosize=False,
        width=800,
        height=400,
        font=dict(family="Serif", size=18),
        margin=dict(l=5, r=5, b=5, t=50, pad=4, autoexpand=True),
        legend=dict(
            title="Models",
            traceorder="normal",
            font=dict(
                family="sans-serif",
                size=14,
            ),
        ),
    )
    if title is not None:
        fig.update_layout(title=title)

    if save_path is not None:
        # Create the target directory on demand so the export does not fail on
        # a fresh checkout.
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        pio.write_image(fig, save_path)

    if display:
        fig.show()
    else:
        return fig


In [None]:
# Load all trained models, keyed by the display name used in the figures.
# model_paths and model_names are parallel lists: entry i of one belongs to
# entry i of the other.
model_paths = [
    "logs/comp/MLP/MLP_dhl[256, 512, 256]_bnFalse/",
    "logs/comp/GRU/GRU_dh256_nl1_aFalse/",
    "logs/comp/GRU/GRU_dh256_nl1_aTrue/",
    "logs/comp/TCN/TCN_ks9_nc(64, 128, 64)_d0/",
    "logs/comp/TcnAe/TcnAe_ld16_etid(3, 50, 40, 30)_etod(50, 40, 30, 10)_etid(10, 8, 6, 3)_etod(8, 6, 3, 1)_ks15/",
    "logs/comp/Transformer/Transformer_dm16_df256_nle4_nld4_nh4_d0_ceFalse/",
    "logs/comp/Transformer/Transformer_dm16_df256_nle4_nld4_nh4_d0_ceTrue_cks3/"
]
model_names = [
    "MLP",
    "GRU",
    "GRU-AR",
    "TCN",
    "TCN-AE",
    "Transformer",
    "Transf.-CE"
]
if len(model_names) != len(model_paths):
    raise ValueError("model_names and model_paths must have the same length")

# The third path component ("logs/comp/<architecture>/...") selects the class.
model_classes = {
    "MLP": MLP,
    "GRU": GRU,
    "TCN": TCN,
    "TcnAe": TcnAe,
    "Transformer": Transformer,
}

models = dict()
for model_name, model_path in zip(model_names, model_paths):
    architecture = model_path.split("/")[2]
    if architecture not in model_classes:
        # Fail loudly instead of silently reusing the previous iteration's model.
        raise ValueError(f"Unknown model architecture '{architecture}' in {model_path}")

    # The run directories contain "[...]" / "(...)"; escape them so glob does
    # not interpret the brackets as a character class.
    ckpt_path = glob.glob(f"{glob.escape(model_path)}/checkpoints/*.ckpt")
    if not ckpt_path:
        raise FileNotFoundError(f"No checkpoint found in {model_path}checkpoints/")
    # glob order is arbitrary; pick the latest checkpoint by modification time.
    latest_ckpt = max(ckpt_path, key=os.path.getmtime)

    # load_from_checkpoint is a classmethod: call it on the class (recent
    # Lightning releases reject calls on an instance) and skip the redundant
    # constructor call. The forecasts below use CPU tensors and .numpy(), so
    # the model is mapped to the CPU and put into evaluation mode.
    model = model_classes[architecture].load_from_checkpoint(latest_ckpt, map_location="cpu")
    model.eval()
    models[model_name] = model

# Sample of the "fault" dataset shown in the figure.
sample_idx = 49
# 35, 37 (presumably other candidate indices -- confirm with the author)
sample = dm.ds_dict["fault"][sample_idx]

# Run every loaded model on the first part of the sample. unsqueeze(0) adds a
# leading axis for the call and squeeze(0) removes it from the output.
fcasts = {}
for model_name, model in models.items():
    batched_history = torch.tensor(sample[0]).unsqueeze(0)
    batched_output = model(batched_history)
    fcast = batched_output.squeeze(0).detach().numpy()
    fcasts[model_name] = fcast

plot_forecast(
    sample,
    fcasts,
    title="Forecasts on Data with Faulty Sensors",
    sample_idx_start=150,
    model_names=["MLP", "Transformer"],
)

### Parameter Count

In [None]:
# Hard-coded number of trainable parameters per model (values are typed in
# here, not computed from the loaded models).
models = {
    "MLP": dict(trainable=493718),
    "GRU": dict(trainable=238998),
    "GRU-AR": dict(trainable=201219),
    "TCN": dict(trainable=776003),
    "TCN-AE": dict(trainable=420630),
    "Transformer": dict(trainable=81587),
    "Transf.-CE": dict(trainable=81683),
}

model_names = list(models)
trainable = [entry["trainable"] for entry in models.values()]

# One bar per model, in the insertion order of the table above.
fig = go.Figure()
fig.add_trace(go.Bar(name='Trainable Parameters', x=model_names, y=trainable))

fig.update_layout(
    title='Model Parameter Counts',
    yaxis_title='Number of Parameters',
    autosize=False,
    height=400,
    width=800,
    font={"family": "Serif", "size": 18},
    margin={"l": 5, "t": 50, "b": 5, "r": 5},
)

# Export for the paper, then show inline.
pio.write_image(fig, "visualizations/parameter_counts.pdf")

fig.show()


## Experiment 2

### Finetuning

In [None]:

# Define the data: hard-coded results for the fine-tuning figure.
# `models` fixes the order in which the traces are added by the plotting loop
# below (this rebinds the `models` name used by earlier cells).
models = ['MLP', 'GRU', 'GRU-AR', 'TCN', 'TCN-AE', 'Transformer', 'Transf.-CE']
# x values of every curve; each list in `mse_dict` has one entry per element
# of `epochs`, in the same order.
epochs = [0, 1, 5, 10, 20, 50]
# scenario key -> model name -> values plotted on the y axis (labelled
# 'Mean Squared Error (MSE)' by the plotting loop below).
# NOTE(review): 'scale'/'MLP' starts at 0.53 for epoch 0, the same value as at
# epoch 50, while all other models start much higher -- verify this entry.
mse_dict = {
    'scale': {
        'MLP': [0.53, 5.03, 5.03, 0.93, 0.93, 0.53],
        'GRU': [10.27, 5.17, 2.27, 1.69, 0.95, 0.53],
        'GRU-AR': [14.80, 6.27, 2.81, 1.93, 1.12, 0.53],
        'TCN': [1.52, 4.05, 3.29, 3.29, 2.60, 1.32],
        'TCN-AE': [1.28, 22.18, 1.78, 0.96, 0.96, 1.49],
        'Transformer': [13.06, 12.96, 9.85, 5.77, 4.58, 2.59],
        'Transf.-CE': [19.04, 13.3, 6.44, 4.94, 4.02, 1.66]
    },
    'q1+v3': {
        'MLP': [6.01, 6.13, 1.49, 1.19, 1.19, 0.44],
        'GRU': [4.79, 1.72, 0.83, 0.68, 0.53, 0.35],
        'GRU-AR': [2.56, 1.03, 1.03, 1.03, 1.03, 0.35],
        'TCN': [5.18, 4.02, 2.32, 2.32, 0.83, 0.57],
        'TCN-AE': [6.47, 10.36, 4.09, 2.93, 2.93, 2.93],
        'Transformer': [4.13, 8.85, 3.94, 2.21, 1.41, 0.47],
        'Transf.-CE': [4.97, 2.51, 2.51, 2.23, 1.78, 0.44]
    },
    'q1+v3+rest': {
        'MLP': [1.27, 4.30, 0.83, 0.83, 0.83, 0.43],
        'GRU': [1.28, 1.20, 0.91, 0.59, 0.40, 0.39],
        'GRU-AR': [1.32, 0.81, 0.81, 0.72, 0.56, 0.35],
        'TCN': [2.52, 8.08, 4.48, 1.68, 1.23, 0.50],
        'TCN-AE': [1.71, 1.40, 1.40, 1.40, 1.40, 1.40],
        'Transformer': [1.51, 9.64, 1.50, 1.04, 1.04, 0.46],
        'Transf.-CE': [1.57, 3.20, 1.32, 0.58, 0.55, 0.40]
    },
    'v12+v23':{
        'MLP': [6.02, 3.77, 0.92, 0.98, 0.67, 0.35],
        'GRU': [9.65, 2.67, 1.54, 0.90, 0.57, 0.37],
        'GRU-AR': [8.12, 4.62, 1.73, 1.08, 0.51, 0.36],
        'TCN': [8.28, 4.04, 2.77, 2.74, 1.36, 1.02],
        'TCN-AE': [7.21, 2.19, 2.19, 2.19, 2.19, 2.19],
        'Transformer': [9.45, 4.82, 3.81, 2.56, 1.79, 0.99],
        'Transf.-CE': [8.06, 3.50, 2.69, 2.60, 1.65, 0.79]
    }
}

# Human-readable scenario names inserted into the figure title.
# NOTE(review): the 'q1+v3' title carries literal single quotes and spaces,
# while the other titles are hyphenated without quotes -- confirm this is
# intended.
scenario_titles = {
    'scale': 'Scaled',
    'q1+v3': "'Independent Phase Merging'",
    'q1+v3+rest': 'Independent-Phase-with-Rest',
    'v12+v23': 'Dependent-Phases'
}
# Line colour per model. The 'sample' entry is not used by the plotting loop
# below, which only looks up the names listed in `models`.
# NOTE(review): 'TCN' and 'Transf.-CE' are swapped compared to the colour
# table inside plot_forecast (there TCN is '#bcbd22' and Transf.-CE is
# '#ff7f0e'), so the same model is drawn in different colours across figures.
colors = {
    'sample': '#7f7f7f',  # muted grey
    'MLP': '#2ca02c',  # cooked asparagus green
    'GRU': '#1f77b4',  # muted blue
    'GRU-AR': '#17becf',  # blue-teal
    'TCN': '#ff7f0e',  # safety orange
    'TCN-AE': '#9467bd',  # muted purple
    'Transformer': '#d62728',  # brick red
    'Transf.-CE': '#bcbd22',  # curry yellow-green
}
# One figure per listed scenario: a line with markers for every model, using
# the values and colours from the tables defined above. Only 'q1+v3' is
# rendered here.
for scenario in ('q1+v3',):
    fig = go.Figure()

    for model_name in models:
        curve = go.Scatter(
            x=epochs,
            y=mse_dict[scenario][model_name],
            mode='lines+markers',
            name=model_name,
            line={"color": colors[model_name]},
        )
        fig.add_trace(curve)

    fig.update_layout(
        title=f"Performance on {scenario_titles[scenario]} Scenario after Fine-Tuning",
        xaxis_title='Finetuning Epochs',
        yaxis_title='Mean Squared Error (MSE)',
        autosize=False,
        width=800,
        height=400,
        font={"family": "Serif", "size": 18},
        margin={"l": 5, "r": 5, "b": 5, "t": 50, "pad": 4, "autoexpand": True},
        legend={
            "title": "Models",
            "traceorder": "normal",
            "font": {"family": "sans-serif", "size": 14},
        },
    )

    # A logarithmic y axis was tried by the author and left disabled:
    # fig.update_yaxes(type="log")

    # Export for the paper, then show inline.
    pio.write_image(fig, "visualizations/ftune_" + scenario + ".pdf")

    fig.show()


