In [None]:
import torch
from meas_get_data import *
from meas_NN_classes import *
from meas_dataset import *
import os 
from tqdm import tqdm
import logging
from meas_test_func_fs import *
from meas_dataloader import *
from meas_train_funcs import *
from model_params import get_model_params
torch.set_default_dtype(torch.float32)


In [None]:
parameter_set = get_model_params(testing_mode=False)
model_paths = {
    "OR_LSTM": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_OR_LSTM_expnumb_640.pth",
    "OR_MLP": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_OR_MLP_expnumb_123.pth",
    "OR_TCN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_OR_TCN_expnumb_644.pth",
    "LSTM": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_LSTM_expnumb_564.pth",
    "MLP": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_MLP_expnumb_367.pth",
    "TCN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_TCN_expnumb_150.pth",
    "OR_RNN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_OR_RNN_expnumb_123.pth",
    "RNN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_RNN_expnumb_123.pth",
    "OR_GRU": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_OR_GRU_expnumb_123.pth",
    "GRU": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_GRU_expnumb_123.pth"}
# NOTE(review): all checkpoint paths above are Windows paths, so on other
# operating systems os.path.exists() is False and every model is skipped below.


if os.name == "nt":
    path_test_data = r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\messdaten\TEST-Messdaten_30traj_7times500steps.csv"
else:
    # This branch previously assigned `path_train_data`, so the get_data()
    # call below raised NameError on non-Windows systems.
    path_test_data = r"/home/rdpusr/Documents/Code_4_paper/messdaten/TEST-Messdaten_30traj_7times500steps.csv"
test_data = get_data(path_test_data, num_inits=0)

for parameters in parameter_set:
    model_flag = parameters["model_flag"]

    # Only evaluate models whose checkpoint is available on this machine.
    if not os.path.exists(model_paths[model_flag]):
        continue

    inp_size = 4
    output_size = 2
    nn_folder = "Trained_networks"

    # Build the network that matches the flag; the flag is passed on so the
    # model can report which architecture it is (checked after loading).
    if "LSTM" in model_flag:
        model = OR_LSTM(input_size=inp_size, hidden_size=parameters["h_size"], out_size=output_size,
                        layers=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    elif "MLP" in model_flag:
        # The MLP sees the whole window flattened, hence input_size * window_size.
        model = OR_MLP(input_size=inp_size*parameters["window_size"], hidden_size=parameters["h_size"],
                       output_size=output_size, l_num=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    elif "TCN" in model_flag:
        num_channels = [parameters["n_hidden"]] * parameters["levels"]
        model = OR_TCN(input_size=inp_size, output_size=output_size, num_channels=num_channels,
                       kernel_size=parameters["kernel_size"], dropout=parameters["dropout"], windowsize=parameters["window_size"], flag=model_flag).to(device)
    elif "RNN" in model_flag:
        model = OR_RNN(input_size=inp_size, hidden_size=parameters["h_size"], out_size=output_size,
                       layers=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    elif "GRU" in model_flag:
        model = OR_GRU(input_size=inp_size, hidden_size=parameters["h_size"], out_size=output_size,
                       layers=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    else:
        # Without this, `model` would silently be the one from the previous iteration.
        raise ValueError(f"Unsupported model flag: {model_flag!r}")

    model.load_state_dict(torch.load(model_paths[model.get_flag()], map_location=device))
    # Inference only: switch off dropout (the TCN is built with a dropout rate).
    model.eval()

    if model_flag != model.get_flag():
        print("Parameter list model flag does not match the model flag used!")
    print(model.get_flag())

    test_error = test(data=test_data, model=model, window_size=parameters["window_size"], specific_index=1, display_plots=1)


In [None]:
# ROBOT ARM

parameter_set = get_model_params(testing_mode=False, robot_mode=True)
# NOTE(review): the OR_MLP, OR_TCN, LSTM and TCN entries point into
# "Trained_networks" (the folder used for the 4-input/2-output models above)
# instead of "Trained_networks_robot". If those files exist, load_state_dict()
# below will most likely fail with a shape mismatch for the 12-input/6-output
# robot models - confirm the intended checkpoints.
model_paths = {
    "OR_LSTM": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks_robot\modeltype_OR_LSTM_robot.pth",
    "OR_MLP": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_OR_MLP_expnumb_123.pth",
    "OR_TCN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_OR_TCN_expnumb_644.pth",
    "OR_GRU": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks_robot\modeltype_OR_GRU_robot_v3.pth",
    "OR_RNN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks_robot\modeltype_OR_RNN_robot_v3.pth",
    "LSTM": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_LSTM_expnumb_564.pth",
    "MLP": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks_robot\modeltype_MLP_robot_v3.pth",
    "TCN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks\modeltype_TCN_expnumb_150.pth",
    "RNN": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks_robot\modeltype_RNN_robot_v3.pth",
    "GRU": r"C:\Users\StrasserP\Documents\NN_Paper\Code_4_paper\Trained_networks_robot\modeltype_GRU_robot_v3.pth"}


train_data, test_data = get_data_robot()

for parameters in parameter_set:
    model_flag = parameters["model_flag"]

    # Only evaluate models whose checkpoint is available on this machine.
    if not os.path.exists(model_paths[model_flag]):
        continue

    inp_size = 12
    output_size = 6
    nn_folder = "Trained_networks_robot"

    # Build the network that matches the flag; the flag is passed on so the
    # model can report which architecture it is (checked after loading).
    if "LSTM" in model_flag:
        model = OR_LSTM(input_size=inp_size, hidden_size=parameters["h_size"], out_size=output_size,
                        layers=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    elif "MLP" in model_flag:
        # The MLP sees the whole window flattened, hence input_size * window_size.
        model = OR_MLP(input_size=inp_size*parameters["window_size"], hidden_size=parameters["h_size"],
                       output_size=output_size, l_num=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    elif "TCN" in model_flag:
        num_channels = [parameters["n_hidden"]] * parameters["levels"]
        model = OR_TCN(input_size=inp_size, output_size=output_size, num_channels=num_channels,
                       kernel_size=parameters["kernel_size"], dropout=parameters["dropout"], windowsize=parameters["window_size"], flag=model_flag).to(device)
    elif "RNN" in model_flag:
        model = OR_RNN(input_size=inp_size, hidden_size=parameters["h_size"], out_size=output_size,
                       layers=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    elif "GRU" in model_flag:
        model = OR_GRU(input_size=inp_size, hidden_size=parameters["h_size"], out_size=output_size,
                       layers=parameters["l_num"], window_size=parameters["window_size"], flag=model_flag).to(device)
    else:
        # Without this, `model` would silently be the one from the previous iteration.
        raise ValueError(f"Unsupported model flag: {model_flag!r}")

    model.load_state_dict(torch.load(model_paths[model.get_flag()], map_location=device))
    # Inference only: switch off dropout (the TCN is built with a dropout rate).
    model.eval()

    if model_flag != model.get_flag():
        print("Parameter list model flag does not match the model flag used!")
    print(model.get_flag())

    test_error = test(data=test_data, model=model, window_size=parameters["window_size"], specific_index=1, display_plots=1, rescale=False)


In [None]:

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# New tensors default to single precision.
torch.set_default_dtype(torch.float32)


def plot_results(x, pred, rescale=False, window_size=1, sim_data_index=0, features=4):
    """Plot a measured trajectory against the network prediction.

    The subplot layout is inferred from the number of columns of ``x``:

    * 4 columns: valve layout - input voltage 1, input voltage 2 (data only),
      then pressure and position (prediction vs. data).
    * ``2*n`` columns otherwise: the first ``n`` columns are plotted as inputs
      ("tau", [Nm]) and the last ``n`` as prediction vs. data ("q", [deg°]);
      for 12 columns this is the robot-arm layout.

    Args:
        x: Measured trajectory, shape (timesteps, columns) or
            (1, timesteps, columns).
        pred: Prediction, same shape convention as ``x``.
        rescale: If True, undo the normalisation via ``normalize_invert``
            first and, for the 4-column layout, fix the y-limits to the
            physical ranges.
        window_size: Index of the first predicted time step; marked with a
            vertical line on the output axes when it lies inside the series.
        sim_data_index: Currently unused; kept for call compatibility.
        features: Kept for call compatibility only. The layout used to be
            selected with ``features == 4`` while ``test`` passes half the
            column count (2 for 4-column data), which picked the 12-subplot
            branch and indexed out of range; the layout is now taken from
            ``x`` itself.
    """
    if rescale:
        x = normalize_invert(x)
        pred = normalize_invert(pred)

    # Drop a leading batch dimension -> (timesteps, columns).
    if x.dim() == 3:
        x = x.view(x.size(dim=1), x.size(dim=2))
    if pred.dim() == 3:
        pred = pred.view(pred.size(dim=1), pred.size(dim=2))

    x_np = x.detach().cpu().numpy()
    pred_np = pred.detach().cpu().numpy()
    n_steps, n_cols = x_np.shape

    if n_cols == 4:
        figure, axs = plt.subplots(4, 1, figsize=(9, 9))
        figure.tight_layout(pad=2.0)

        stepsize = 2 * 2e-5
        time_s = np.linspace(0, n_steps * stepsize, n_steps)

        axs[0].plot(time_s, x_np[:, 0], color="darkgreen", label="data")
        axs[0].set_title("NC : Input Voltage 1")
        axs[0].set_ylabel("[V]")

        axs[1].plot(time_s, x_np[:, 1], color="darkgreen", label="data")
        axs[1].set_title("NO : Input Voltage 2")
        axs[1].set_ylabel("[V]")

        axs[2].plot(time_s, pred_np[:, 2], color="red", label="pred")
        axs[2].plot(time_s, x_np[:, 2], color="blue", label="data", linestyle="dashed")
        axs[2].set_title("pressure")
        axs[2].set_ylabel("[Pa]")

        axs[3].plot(time_s, pred_np[:, 3], color="red", label="pred")
        axs[3].plot(time_s, x_np[:, 3], color="blue", label="data", linestyle="dashed")
        axs[3].set_title("position")
        axs[3].set_ylabel("[m]")
        axs[3].set_xlabel("time [s]")

        output_axes = axs[2:]

        if rescale:
            u1_max = 200  # voltage [V], range [0, 200]
            u1_min = 0
            u2_max = 200
            u2_min = 0
            p_max = 3.5*1e5  # pressure [Pa] (3.5 bar)
            p_min = 1.0*1e5  # ambient pressure [Pa] (1 bar)
            s_max = 0.605*1e-3  # position [m], range [0, 0.000605]
            s_min = 0.0

            axs[0].set_ylim(u1_min-10, u1_max+10)
            axs[1].set_ylim(u2_min-10, u2_max+10)
            axs[2].set_ylim(p_min-0.1*p_max, p_max+0.1*p_max)
            axs[3].set_ylim(s_min-0.1*s_max, s_max+0.1*s_max)
    else:
        n_in = n_cols // 2
        # squeeze=False keeps `axs` indexable even for a single subplot.
        figure, axs = plt.subplots(n_cols, 1, figsize=(9, 9), squeeze=False)
        axs = axs[:, 0]

        stepsize = 0.1  # NOTE(review): presumably 100 ms sample time - confirm
        time_s = np.linspace(0, n_steps * stepsize, n_steps)

        for i in range(n_in):
            axs[i].plot(time_s, x_np[:, i], color="darkgreen", label="data")
            axs[i].set_title(f"tau {i}")
            axs[i].set_ylabel("[Nm]")
        for i in range(n_in, n_cols):
            axs[i].plot(time_s, pred_np[:, i], color="red", label="pred")
            axs[i].plot(time_s, x_np[:, i], color="blue", label="data")
            # Number outputs from 0 so they line up with the input titles.
            axs[i].set_title(f"q {i - n_in}")
            axs[i].set_ylabel("[deg°]")

        output_axes = axs[n_in:]

    # Mark where the model starts predicting (inputs before it are ground truth).
    if 0 <= window_size < n_steps:
        for ax in output_axes:
            ax.axvline(x=time_s[window_size], color='black', linestyle='--', label='start of prediction')

    for ax in axs:
        ax.grid(True)
        ax.legend()

    plt.show()

def plot_histogramm(error_position : list,
                    error_pressure : list,
                    error_position_simulink : list,
                    error_pressure_simulink : list)->None:
    """Show neural-net vs. Simulink error distributions as two histograms.

    Top panel: position MSE values; bottom panel: pressure MSE values. Each
    panel also draws a vertical line at the mean of both series. The font
    sizes are written into the global matplotlib rcParams, so they persist
    for figures created afterwards.
    """
    plt.rcParams.update({
        "font.size": 15,
        "axes.titlesize": 20,
        "axes.labelsize": 25,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
        "legend.fontsize": 15,
    })

    _, (ax_position, ax_pressure) = plt.subplots(2, 1)

    # NOTE(review): only the pressure panel passes an explicit bin count (5);
    # the position panel uses matplotlib's default - confirm this is intended.
    panels = (
        (ax_position, error_position, error_position_simulink, "position", {}),
        (ax_pressure, error_pressure, error_pressure_simulink, "pressure", {"bins": 5}),
    )

    for ax, nn_errors, sim_errors, quantity, hist_kwargs in panels:
        ax.hist(nn_errors, alpha=0.5, label=f"{quantity} neural net", color="red",
                edgecolor="black", hatch="//", linewidth=2, **hist_kwargs)
        ax.hist(sim_errors, alpha=0.3, label=f"{quantity} simulink", color="green",
                edgecolor="black", hatch="||", linewidth=2, **hist_kwargs)

        nn_mean = np.array(nn_errors).mean()
        sim_mean = np.array(sim_errors).mean()
        ax.axvline(nn_mean, color="orange", linewidth=2,
                   label=f"Mean {np.round(nn_mean, 5)}: MSE {quantity}")
        ax.axvline(sim_mean, color="blue", linewidth=2,
                   label=f"Mean {np.round(sim_mean, 5)}: MSE {quantity} simulink")

        ax.set_xlabel(f"MSE {quantity}")
        ax.set_ylabel("frequency", fontsize=15)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

    return

def test(
    data, 
    model, 
    window_size: int = 1, 
    display_plots: bool = False, 
    numb_of_inits: int = 1, 
    fix_random: bool = True, 
    rescale: bool = False, 
    specific_index: int = -1, 
    error_histogramm: bool = False
) -> float:

    if window_size==1:
        warnings.warn("Window size 1 is not supported!", UserWarning)

    if fix_random:
     np.random.seed(1234)
    else:
     np.random.seed(seed=None)
      
    test_inits = data.size(dim=0)
    ids = np.random.choice(test_inits, min([numb_of_inits, test_inits]), replace=False)
    ids = np.unique(ids)

    

    if specific_index >= 0:
        data = data[specific_index:specific_index+1,:, :]
    else :
        if not plot_histogramm: #use all data for histogramm
         data = data[ids,:, :]   

    loss_fn = nn.MSELoss()
    timesteps = data.size(dim=1)
    features = int(data.size(dim=2) / 2)

    total_loss = 0

    error_position = []
    error_pressure = []

    error_position_simulink = []
    error_pressure_simulink = []

    for i, x in enumerate(data):

        with torch.inference_mode():
            


            x=x.to(device)        
            x = x.view(1,x.size(dim=0), x.size(dim=1))                
            pred = torch.zeros((timesteps, 2*features), device=device)

            pred[0:window_size, :] = x[0, 0:window_size, :]
            pred[:, 0] = x[0, :, 0]

            x_test = x.clone()
            x_test[:,window_size:,features:] = 0
            x_test = x_test.to(device)

            # note that both OR and regular NNs use the same function during inferece
            # which is just the forward pass of the OR model
            if model.get_flag() in ["OR_LSTM", "LSTM", "OR_RNN", "RNN", "OR_GRU", "GRU"]:
                out, _ = model(x_test)
                pred[window_size:, features:] = out
 
            if model.get_flag() in ["OR_MLP", "MLP"]:
                out = model(x_test)
                pred[window_size:, features:] = out

            if model.get_flag() in ["OR_TCN", "TCN"]:
                out = model(x_test.transpose(1,2))
                pred[window_size:, features:] = out.squeeze(0).transpose(0,1)
            
            total_loss += loss_fn(pred[window_size:, features:], x[0, window_size:, features:]).detach().cpu().numpy()

            #error_position.append(loss_fn(pred[window_size:, 2:3], x[0, window_size:, 2:3]).detach().cpu().numpy())
            #error_pressure.append(loss_fn(pred[window_size:, 3:4], x[0, window_size:, 3:4]).detach().cpu().numpy())

            if display_plots:
                if specific_index>=0:
                    plot_results(x, pred, rescale=rescale, window_size=window_size, sim_data_index=specific_index, features=features)
                else:
                    plot_results(x, pred, rescale=rescale, window_size=window_size, sim_data_index=ids[i], features=features)
    
    #if error_histogramm:
        
        #plot_histogramm(error_position, error_pressure, error_position_simulink, error_pressure_simulink)

    return total_loss/data.size(0)