In [1]:
from get_data import *
from dataloader import *
from NN_classes import *
from experiments import *

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.set_default_dtype(torch.float64)

In [2]:

# Importing necessary libraries
import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy as np
import torchcde
import time


def plot_results(x, pred, pred_next_step=None, physics_rescaling=None, additional_data=None):

    if x.dim() == 3:
        x = x.view(x.size(dim=1), x.size(dim=2))
    if pred.dim() == 3:
        pred = pred.view(pred.size(dim=1), pred.size(dim=2))
    if pred_next_step != None:
        if pred_next_step.dim() == 3:
            pred_next_step = pred_next_step.view(pred_next_step.size(dim=1), pred_next_step.size(dim=2))

        #scale back:    
    if physics_rescaling != None:

        # we invert:
        # x = (x - xmin)/(xmax - xmin)
        # x * (xmax - xmin) + xmin

        pred[:,0] = pred[:,0]*(physics_rescaling[0] - physics_rescaling[3]) + physics_rescaling[3]
       # pred[:,0] = pred[:,0]/1e5
        pred[:,1] = pred[:,1]*(physics_rescaling[1] - physics_rescaling[4]) + physics_rescaling[4]
        pred[:,2] = pred[:,2]*(physics_rescaling[2] - physics_rescaling[5]) + physics_rescaling[5]
        x[:,0] = x[:,0]*(physics_rescaling[0] - physics_rescaling[3]) + physics_rescaling[3]
       # x[:,0] = x[:,0]/1e5
        x[:,1] = x[:,1]*(physics_rescaling[1] - physics_rescaling[4]) + physics_rescaling[4]
        x[:,2] = x[:,2]*(physics_rescaling[2] - physics_rescaling[5]) + physics_rescaling[5]

        if additional_data != None:
            for i in range(additional_data.size(dim=0)):
                additional_data[i,:,0] = additional_data[i,:,0]*(physics_rescaling[0] - physics_rescaling[3]) + physics_rescaling[3]
                additional_data[i,:,1] = additional_data[i,:,1]*(physics_rescaling[1] - physics_rescaling[4]) + physics_rescaling[4]
                additional_data[i,:,2] = additional_data[i,:,2]*(physics_rescaling[2] - physics_rescaling[5]) + physics_rescaling[5]

    #figure , axs = plt.subplots(1,3,figsize=(20,8))
    figure , axs = plt.subplots(3,1, figsize=(16,9))
    figure.tight_layout(pad=5.0)

    greek_letterz=[chr(code) for code in range(945,970)]
    mu = greek_letterz[11]

    stepsize = 2e-5
    time = np.linspace(0,x.size(dim=0)* stepsize, x.size(dim=0))

    if pred_next_step != None:
        axs[0].plot(time, pred_next_step.detach().cpu().numpy()[:, 1], color="green", label="next step from data")
        axs[1].plot(time, pred_next_step.detach().cpu().numpy()[:, 2], color="green", label="next step from data")

    axs[0].plot(time, pred.detach().cpu().numpy()[:, 1], color="red", label="pred")
    axs[0].plot(time, x.detach().cpu().numpy()[:, 1], color="blue", label="true", linestyle="dashed")

    if additional_data != None:
        for i in range(additional_data.size(dim=0)):
           names = ["simulink", "Hub im Regler"]
           axs[0].plot(time, additional_data[i, :, 1], label=names[i])

    axs[0].set_title("position")
    axs[0].set_ylabel("[m]")
    axs[0].set_xlabel(f"time [s]")
    axs[0].grid()
    axs[0].legend()


    axs[1].plot(time, pred.detach().cpu().numpy()[:, 2], color="red", label="pred")
    axs[1].plot(time, x.detach().cpu().numpy()[:, 2], color="blue", label="true", linestyle="dashed")
    if additional_data != None:
        for i in range(additional_data.size(dim=0)):
           names = ["simulink", "Hub im Regler"]
           axs[1].plot(time, additional_data[i, :, 2], label=names[i])
    axs[1].set_title("speed")
    axs[1].set_ylabel("[m/s]")
    axs[1].set_xlabel(f"time [s]")
    axs[1].grid()
    axs[1].legend()

    axs[2].plot(time, x.detach().cpu().numpy()[:,0], label="pressure")
    if additional_data != None:
       for i in range(additional_data.size(dim=0)):
           names = ["simulink", "Hub im Regler"]
           axs[2].plot(time, additional_data[i, :, 0], label=names[i])
    axs[2].set_title("pressure")
    axs[2].set_ylabel("[Pa]")
    axs[2].set_xlabel(f"time [s]")
    axs[2].grid()
    axs[2].legend()


   


    plt.grid(True)
    plt.legend()
    plt.show()


def test(data, model, model_type = "or_lstm", window_size=10, display_plots=False, num_of_inits = 5, set_rand_seed=True, physics_rescaling = 0, additional_data=None):

    if model_type not in ["or_lstm", "lstm", "mlp", "gru", "tcn", "or_tcn", "neural_cde", "or_mlp"]:
        print("Error: model_type = ", model_type, "available options are: [or_lstm, lstm, mlp, gru, tcm]")
        return 0

    
    device = "cpu" if data.get_device() == -1 else "cuda:0"
    
    if data.dim() != 3:
        print("data tensor has unexpected dimension", data.dim(), "expected", 3 )
        return 0
    
    timesteps = data.size(dim=1)

    model.eval()
    loss_fn = nn.MSELoss()
    test_loss = 0
    test_loss_deriv = 0
    total_loss = 0
    total_firsthalf = 0
    total_secondhalf = 0
   
    if set_rand_seed:
     np.random.seed(1234)

    test_inits = data.size(dim=0)
    ids = np.random.choice(test_inits, min([num_of_inits, test_inits]), replace=False)
    ids = np.unique(ids)
    

    if model_type in ["or_lstm", "gru"]:
        for i, x in enumerate(data):

            x=x.to(device)        
            x = x.view(1,x.size(dim=0), x.size(dim=1))

            if i not in ids:
                continue
    
            with torch.inference_mode():
    
                pred = torch.zeros((timesteps, 3), device=device)
    
                if window_size > 1:
                    pred[0:window_size, :] = x[0, 0:window_size, :]
                    pred[:, 0] = x[0, :, 0]
    
                else:
                    pred[0, :] = x[0, 0, :]
                    pred[:, 0] = x[0, :, 0]
    

                out, _ = model(x)
                pred[window_size:,1:] = out

                print(x.size(), pred.size())
                test_loss += loss_fn(pred[window_size:, 1], x[0, window_size:, 1]).detach().cpu().numpy()
                test_loss_deriv += loss_fn(pred[window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                total_loss += loss_fn(pred[window_size:, 1:], x[0, window_size:, 1:]).detach().cpu().numpy()

                total_firsthalf += loss_fn(pred[window_size:int((timesteps-window_size)/2), 1:], 
                        x[0, window_size:int((timesteps-window_size)/2), 1:]).detach().cpu().numpy()  
                total_secondhalf += loss_fn(pred[int((timesteps-window_size)/2):, 1:],
                            x[0, int((timesteps-window_size)/2):, 1:]).detach().cpu().numpy()

                if display_plots:
                    plot_results(x, pred, pred_next_step=None, physics_rescaling=physics_rescaling, additional_data=additional_data)

    if model_type == "mlp":
        for i, x in enumerate(data):

                x=x.to(device)
                
                if i not in ids:
                    continue
        
                with torch.inference_mode():
        
                    pred = torch.zeros((timesteps, 3), device=device)
        
                    if window_size > 1:
                        pred[0:window_size, :] = x[0:window_size, :]
                        pred[:, 0] = x[ :, 0]
        
                    else:
                        pred[0, :] = x[0, :]
                        pred[:, 0] = x[:, 0]

                    inp = torch.cat((x[:window_size,0], x[:window_size,1], x[:window_size,2]))

                    for t in range(1,timesteps - window_size + 1 ): 

                        out = model(inp)
                        pred[window_size+(t-1):window_size+t,1:] =  pred[window_size+(t-2):window_size+(t-1):,1:] + out
                        new_p = pred[t:t+window_size,0]
                        new_s = pred[t:t+window_size,1]
                        new_v = pred[t:t+window_size,2]
                        
                        inp = torch.cat((new_p, new_s, new_v))

                    test_loss += loss_fn(pred[window_size:, 1], x[window_size:, 1]).detach().cpu().numpy()
                    test_loss_deriv += loss_fn(pred[window_size:, 2], x[window_size:, 2]).detach().cpu().numpy()
                    total_loss += loss_fn(pred[window_size:, 1:], x[window_size:, 1:]).detach().cpu().numpy()

                    
                    total_firsthalf += loss_fn(pred[window_size:int((pred.size(dim=0)-window_size)/2), 1:], 
                                            x[window_size:int((pred.size(dim=0)-window_size)/2), 1:]).detach().cpu().numpy()  
                    total_secondhalf += loss_fn(pred[int((pred.size(dim=0)-window_size)/2):, 1:],
                                                x[int((pred.size(dim=0)-window_size)/2):, 1:]).detach().cpu().numpy()  
                    #print("Error first half: ", total_firsthalf)
                    #print("Error second half: ", total_secondhalf)

                    if display_plots:
                        plot_results(x, pred, pred_next_step=None, physics_rescaling=physics_rescaling, additional_data=additional_data)

    if model_type == "lstm":
        for i, x in enumerate(data):
            x=x.to(device)
            if i not in ids:
                continue

            with torch.inference_mode():

                pred = torch.zeros((timesteps, 3), device=device)
                pred_next_step = torch.zeros((timesteps, 3), device=device)

                if window_size > 1:
                    pred[0:window_size, :] = x[0:window_size, :]
                    pred[:, 0] = x[:, 0]
                    pred_next_step[0:window_size, :] = x[0:window_size, :]
                    pred_next_step[:, 0] = x[:, 0]
                else:
                    pred[0, :] = x[0, :]
                    pred[:, 0] = x[:, 0]
                    pred_next_step[0, :] = x[0, :]
                    pred_next_step[:, 0] = x[:, 0]

                for i in range(len(x) - window_size):

                    out, _ = model(pred[i:i+window_size, :])
                    pred[i+window_size, 1:] = pred[i+window_size-1, 1:] + out[-1, :]
                    pred_next_step[i+window_size, 1:] = x[i+window_size-1, 1:] + out[-1, :]
                
                test_loss += loss_fn(pred[:, 1], x[:, 1]).detach().cpu().numpy()
                test_loss_deriv += loss_fn(pred[:, 2], x[:, 2]).detach().cpu().numpy()

                total_loss += loss_fn(pred[:, 1:], x[:, 1:]).detach().cpu().numpy()

                if display_plots:
                    plot_results(x, pred, pred_next_step=None, physics_rescaling=physics_rescaling , additional_data=additional_data)

    if model_type == "tcn" :
         for i, x in enumerate(data):
            
            if i not in ids:
                continue

            with torch.inference_mode():

                x=x.to(device)        
                x = x.view(1,x.size(dim=0), x.size(dim=1))

                pred = torch.zeros_like(x, device=device)  
                pred_next_step = torch.zeros_like(x, device=device)               

                pred[:, 0:window_size, :] = x[0, 0:window_size, :]
                pred[:, :, 0] = x[0, :, 0]

                for i in range(1,x.size(1) - window_size + 1):

                    pred[:, window_size+(i-1):window_size+i,1:] =  pred[:, window_size+(i-2):window_size+(i-1):,1:] + model(pred[:,i:window_size+(i-1),:].transpose(1,2))    

                test_loss += loss_fn(pred[0, :, 1], x[0, :, 1]).detach().cpu().numpy()
                test_loss_deriv += loss_fn(pred[0, :, 2], x[0, :, 2]).detach().cpu().numpy()

                total_loss += loss_fn(pred[0, :, 1:], x[0, :, 1:]).detach().cpu().numpy()

                if display_plots:
                    plot_results(x, pred, pred_next_step=None, physics_rescaling=physics_rescaling , additional_data=additional_data)

    if model_type == "or_tcn" :
         for i, x in enumerate(data):
            
            if i not in ids:
                continue

            with torch.inference_mode():
                x=x.to(device)        
                x = x.view(1,x.size(dim=0), x.size(dim=1))                
                pred = torch.zeros((timesteps, 3), device=device)
    
                if window_size > 1:
                    pred[0:window_size, :] = x[0, 0:window_size, :]
                    pred[:, 0] = x[0, :, 0]
    
                else:
                    pred[0, :] = x[0, 0, :]
                    pred[:, 0] = x[0, :, 0]
    
                x_test = x.clone()
                x_test[:,window_size:,1:] = 0
                x_test = x_test.to(device)
                #print("Data passed to the model, all 0 after the initial window to prove that the forward pass is correct and doesnt access information it shouldnt.",x_test[:,0:10,:])

                out = model(x_test.transpose(1,2))
                
                pred[window_size:,1:] = out.squeeze(0).transpose(0,1)

                test_loss += loss_fn(pred[window_size:, 1], x[0, window_size:, 1]).detach().cpu().numpy()
                test_loss_deriv += loss_fn(pred[window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                total_loss += loss_fn(pred[window_size:, 1:], x[0, window_size:, 1:]).detach().cpu().numpy()

                if display_plots:
                    plot_results(x, pred, pred_next_step=None, physics_rescaling=physics_rescaling , additional_data=additional_data)

    if model_type == "neural_cde" :
         for i, x in enumerate(data):
            
            if i not in ids:
                continue

            with torch.inference_mode():

                x=x.to(device)        
                x = x.view(1,x.size(dim=0), x.size(dim=1))

                pred = torch.zeros_like(x, device=device)
        
                pred[:, 0:window_size, :] = x[0:1, 0:window_size, :]
                pred[:, :, 0:2] = x[0:1, :, 0:2] # time, pressure

                #start_total=time.time()

                for i in range(x.size(1) - window_size):
                    
                    #start_coeffs=time.time()
                    train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(pred[0:1, i:i+window_size, :]) 
                    #train_coeffs = torchcde.linear_interpolation_coeffs(pred[0:1, i:i+window_size, :])
   
                    #stop_coeffs=time.time()
                    #print(stop_coeffs-start_coeffs, "time: coeff calc one step")
                    if (i+1)%100==0:
                     print(i, " timessteps done")
                    #start=time.time()

                    out = model(train_coeffs)
                    pred[0:1, i+window_size, 2:] = pred[0:1, i+window_size-1, 2:] + out.unsqueeze(1)

                    #pred[0:1, i+window_size, 2:] = out
                    #stop=time.time()
                    #print(stop-start, "time: model calc step")

                test_loss += loss_fn(pred[0, window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                test_loss_deriv += loss_fn(pred[0, window_size:, 3], x[0, window_size:, 3]).detach().cpu().numpy()
                total_loss += loss_fn(pred[0, window_size:, 2:], x[0, window_size:, 2:]).detach().cpu().numpy()

                total_firsthalf += loss_fn(pred[0, window_size:int((pred.size(dim=1)-window_size)/2), 2:], 
                                           x[0, window_size:int((pred.size(dim=1)-window_size)/2), 2:]).detach().cpu().numpy()  
                total_secondhalf += loss_fn(pred[0, int((pred.size(dim=1)-window_size)/2):, 2:],
                                             x[0, int((pred.size(dim=1)-window_size)/2):, 2:]).detach().cpu().numpy()  
                #print("Error first half: ", total_firsthalf)
                #print("Error second half: ", total_secondhalf)

                #stop_total=time.time()
               # print(stop_total-start_total, "time: model calc step")

                if display_plots:
                    plot_results(x[:,:,1:], pred[:,:,1:], pred_next_step=None, physics_rescaling=physics_rescaling , additional_data=additional_data)

    if model_type == "or_mlp" :
         for i, x in enumerate(data):
            
            if i not in ids:
                continue

            with torch.inference_mode():
                x=x.to(device)        
                x = x.view(1,x.size(dim=0), x.size(dim=1))                
                pred = torch.zeros((timesteps, 3), device=device)
    
                if window_size > 1:
                    pred[0:window_size, :] = x[0, 0:window_size, :]
                    pred[:, 0] = x[0, :, 0]
    
                else:
                    pred[0, :] = x[0, 0, :]
                    pred[:, 0] = x[0, :, 0]
    
                x_test = x.clone()
                x_test[:,window_size:,1:] = 0
                x_test = x_test.to(device)
                #print("Data passed to the model, all 0 after the initial window to prove that the forward pass is correct and doesnt access information it shouldnt.",x_test[:,0:10,:])

                out = model(x_test)
                
                pred[window_size:,1:] = out

                test_loss += loss_fn(pred[window_size:, 1], x[0, window_size:, 1]).detach().cpu().numpy()
                test_loss_deriv += loss_fn(pred[window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                total_loss += loss_fn(pred[window_size:, 1:], x[0, window_size:, 1:]).detach().cpu().numpy()

                total_firsthalf += loss_fn(pred[window_size:int((pred.size(dim=0)-window_size)/2), 1:], 
                                            x[0, window_size:int((pred.size(dim=0)-window_size)/2), 1:]).detach().cpu().numpy()  
                total_secondhalf += loss_fn(pred[int((pred.size(dim=0)-window_size)/2):, 1:],
                                                x[0, int((pred.size(dim=0)-window_size)/2):, 1:]).detach().cpu().numpy()  

                if display_plots:
                    plot_results(x, pred, pred_next_step=None, physics_rescaling=physics_rescaling , additional_data=additional_data)


    print("Error first half: ", np.mean(total_firsthalf))
    print("Error second half: ", np.mean(total_secondhalf))
    print("total loss full traj: ", np.mean(total_loss))

    return np.mean(test_loss), np.mean(test_loss_deriv), np.mean(total_loss)

In [3]:
params_lstm =   {
                        "window_size" : 16,
                        "h_size" : 8,
                        "l_num" : 3,
                        "learning_rate" : 0.0008,
                        "batch_size" : 20,
                }

params_mlp =    {
                        "window_size" : 5,
                        "h_size" : 8,
                        "l_num" : 1,
                        "learning_rate" : 0.001,
                        "batch_size" : 20,
                        "act_fn" : "relu",
                        "nonlin_at_out" : None #None if no nonlinearity at the end
                }

params_tcn =    {
                    "window_size" : 30,
                    "learning_rate" : 0.001,
                    "batch_size" : 20,
                    "n_hidden" : 5,
                    "levels" : 4,
                    "kernel_size" : 7,
                    "dropout" : 0,
                    "input_channels" : 3,
                    "output" : 2,
                    "drop_half_timesteps" : True,
                    "cut_off_timesteps" : 100,
                    "part_of_data" : 0,
                    "percentage_of_data" : 0.8

                }
    
parameter_configs  = [params_lstm, params_mlp, params_tcn] 

path = "data\save_data_test_5xlonger_dyndyn.csv"
#path = "data\save_data_test_revised.csv"
input_data, PSW_max = get_data(path = path, 
                        timesteps_from_data=0, 
                        skip_steps_start = 0,
                        skip_steps_end = 0, 
                        drop_half_timesteps = params_tcn["drop_half_timesteps"],
                        normalise_s_w="minmax",
                        rescale_p=False,
                        num_inits=params_tcn["part_of_data"])

print(input_data.size())

np.random.seed(1234)
print("input_data size", input_data.size())
num_of_inits_train = int(len(input_data)*params_tcn["percentage_of_data"])
train_inits = np.random.choice(np.arange(len(input_data)),num_of_inits_train,replace=False)
test_inits = np.array([x for x in range(len(input_data)) if x not in train_inits])
np.random.shuffle(train_inits)
np.random.shuffle(test_inits)
test_data = input_data[test_inits,:,:]
np.random.seed()

#load models
# Initialize the LSTM model
model_lstm = LSTMmodel(input_size=3, hidden_size=params_lstm["h_size"], out_size=2, layers=params_lstm["l_num"], window_size=params_lstm["window_size"]).to(device)
path_lstm = "Ventil_trained_NNs\OR_LSTM0.pth"#
path_lstm = "First_experiment_run_22_07_2024\LSTM_or_nextstep_exp0.pth"
#path_lstm = "working_networks\OR_lstm_16_8_3_best_V2.pth"
model_lstm.load_state_dict(torch.load(path_lstm, map_location=torch.device(device)))
# Initialize the MLP model
model_or_mlp = OR_MLP(input_size=3*params_mlp["window_size"], hidden_size = params_mlp["h_size"], l_num=params_mlp["l_num"],
                output_size=2, act_fn = params_mlp["act_fn"], act_at_end = params_mlp["nonlin_at_out"], timesteps=params_mlp["window_size"]).to(device)
path_or_mlp = "Ventil_trained_NNs\OR_MLP1.pth"
#path_mlp = "working_networks\MLP_5_8_1.pth"
model_or_mlp.load_state_dict(torch.load(path_or_mlp, map_location=torch.device(device)))

model_mlp = MLP(input_size=3*params_mlp["window_size"], hidden_size = params_mlp["h_size"], l_num=params_mlp["l_num"], output_size=2, act_fn = "relu", act_at_end = None).to(device)
path_mlp = "working_networks\MLP_5_8_1.pth"
#path_mlp = "working_networks\MLP_5_8_1.pth"
model_mlp.load_state_dict(torch.load(path_mlp, map_location=torch.device(device)))

# Initialize the TCN model
input_channels = params_tcn["input_channels"]
output = params_tcn["output"]
num_channels = [params_tcn["n_hidden"]] * params_tcn["levels"]
kernel_size = params_tcn["kernel_size"]
dropout = params_tcn["dropout"]
model_tcn = OR_TCN(input_channels, output, num_channels, kernel_size=kernel_size, dropout=dropout, windowsize=params_tcn["window_size"]).to(device)
#path_tcn = "Ventil_trained_NNs\OR_TCN2.pth"
path_tcn = "working_networks\or_tcn_547.pth"
model_tcn.load_state_dict(torch.load(path_tcn, map_location=torch.device(device)))

# TODO Neural CDE  
#model = NeuralCDE(input_channels=4, hidden_channels=params["h_size"], hidden_width = params["h_width"], output_channels=2).to(device)
#model.load_state_dict(torch.load(path, map_location=torch.device(device)))

models = {"or_lstm" : model_lstm,
          #"or_mlp" : model_or_mlp,
          "mlp" : model_mlp,
          "or_tcn" : model_tcn
          }

window_sizes = {"or_lstm" : params_lstm["window_size"],
                "or_mlp" : params_mlp["window_size"],
                "mlp" : params_mlp["window_size"],
                "or_tcn" : params_tcn["window_size"]
                }
    


exp1(models, input_data[:,:,:], window_sizes)


torch.Size([50, 2750, 3])
input_data size torch.Size([50, 2750, 3])


FileNotFoundError: [Errno 2] No such file or directory: 'Ventil_trained_NNs\\OR_MLP1.pth'

In [None]:
%matplotlib qt

In [None]:
# This file contains all the experiments to compare the trained models 
import torch
import torchcde
import numpy as np
import matplotlib.pyplot as plt

def exp1(models: dict, data, window_sizes, plot_errs=False, set_random=False):
    
    if set_random:
         np.seed(1234)
    
    if plot_errs==False:
        index = np.random.randint(0, data.size(dim=0),1)[0]
        data = data[index:index+1,:,:]

    print(data.size())
    
    test_loss = 0
    test_loss_deriv = 0
    total_loss = 0
    total_firsthalf = 0
    total_secondhalf = 0

    device = "cpu" if data.get_device() == -1 else "cuda:0"
    timesteps = data.size(dim=1)
    loss_fn = torch.nn.MSELoss()

    predictionary = {}
    error_dict = {}
    with torch.inference_mode():
        for model_type, model in models.items():

            window_size = window_sizes[model_type]
            model.eval()
                
            if model_type == "or_lstm":
                    for i, x in enumerate(data):

                        x=x.to(device)        
                        x = x.view(1,x.size(dim=0), x.size(dim=1))
                
                        with torch.inference_mode():
                
                            pred = torch.zeros((timesteps, 3), device=device)
                
                            if window_size > 1:
                                pred[0:window_size, :] = x[0, 0:window_size, :]
                                pred[:, 0] = x[0, :, 0]
                
                            else:
                                pred[0, :] = x[0, 0, :]
                                pred[:, 0] = x[0, :, 0]
                

                            out, _ = model(x)
                            pred[window_size:,1:] = out

                            print(x.size(), pred.size())
                            test_loss += loss_fn(pred[window_size:, 1], x[0, window_size:, 1]).detach().cpu().numpy()
                            test_loss_deriv += loss_fn(pred[window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                            total_loss += loss_fn(pred[window_size:, 1:], x[0, window_size:, 1:]).detach().cpu().numpy()

                            total_firsthalf += loss_fn(pred[window_size:int((timesteps-window_size)/2), 1:], 
                                    x[0, window_size:int((timesteps-window_size)/2), 1:]).detach().cpu().numpy()  
                            total_secondhalf += loss_fn(pred[int((timesteps-window_size)/2):, 1:],
                                        x[0, int((timesteps-window_size)/2):, 1:]).detach().cpu().numpy()
                            
                    predictionary[model_type] = pred
                    error_dict[model_type] = total_loss
            
            if model_type == "mlp":
                for i, x in enumerate(data):

                        x=x.to(device)
                        
                        with torch.inference_mode():
                
                            pred = torch.zeros((timesteps, 3), device=device)
                
                            if window_size > 1:
                                pred[0:window_size, :] = x[0:window_size, :]
                                pred[:, 0] = x[ :, 0]
                
                            else:
                                pred[0, :] = x[0, :]
                                pred[:, 0] = x[:, 0]

                            inp = torch.cat((x[:window_size,0], x[:window_size,1], x[:window_size,2]))

                            for t in range(1,timesteps - window_size + 1 ): 

                                out = model(inp)
                                pred[window_size+(t-1):window_size+t,1:] =  pred[window_size+(t-2):window_size+(t-1):,1:] + out
                                new_p = pred[t:t+window_size,0]
                                new_s = pred[t:t+window_size,1]
                                new_v = pred[t:t+window_size,2]
                                
                                inp = torch.cat((new_p, new_s, new_v))

                            test_loss += loss_fn(pred[window_size:, 1], x[window_size:, 1]).detach().cpu().numpy()
                            test_loss_deriv += loss_fn(pred[window_size:, 2], x[window_size:, 2]).detach().cpu().numpy()
                            total_loss += loss_fn(pred[window_size:, 1:], x[window_size:, 1:]).detach().cpu().numpy()

                            
                            total_firsthalf += loss_fn(pred[window_size:int((pred.size(dim=0)-window_size)/2), 1:], 
                                                    x[window_size:int((pred.size(dim=0)-window_size)/2), 1:]).detach().cpu().numpy()  
                            total_secondhalf += loss_fn(pred[int((pred.size(dim=0)-window_size)/2):, 1:],
                                                        x[int((pred.size(dim=0)-window_size)/2):, 1:]).detach().cpu().numpy()  
                            
                predictionary[model_type] = pred
                error_dict[model_type] = total_loss

            if model_type == "lstm":
                for i, x in enumerate(data):
                    x=x.to(device)
                
                    with torch.inference_mode():

                        pred = torch.zeros((timesteps, 3), device=device)
                        pred_next_step = torch.zeros((timesteps, 3), device=device)

                        if window_size > 1:
                            pred[0:window_size, :] = x[0:window_size, :]
                            pred[:, 0] = x[:, 0]
                            pred_next_step[0:window_size, :] = x[0:window_size, :]
                            pred_next_step[:, 0] = x[:, 0]
                        else:
                            pred[0, :] = x[0, :]
                            pred[:, 0] = x[:, 0]
                            pred_next_step[0, :] = x[0, :]
                            pred_next_step[:, 0] = x[:, 0]

                        for i in range(len(x) - window_size):

                            out, _ = model(pred[i:i+window_size, :])
                            pred[i+window_size, 1:] = pred[i+window_size-1, 1:] + out[-1, :]
                            pred_next_step[i+window_size, 1:] = x[i+window_size-1, 1:] + out[-1, :]
                        
                        test_loss += loss_fn(pred[:, 1], x[:, 1]).detach().cpu().numpy()
                        test_loss_deriv += loss_fn(pred[:, 2], x[:, 2]).detach().cpu().numpy()

                        total_loss += loss_fn(pred[:, 1:], x[:, 1:]).detach().cpu().numpy()
                    
                predictionary[model_type] = pred
                error_dict[model_type] = total_loss

            if model_type == "tcn" :
                    for i, x in enumerate(data):

                        with torch.inference_mode():

                            x=x.to(device)        
                            x = x.view(1,x.size(dim=0), x.size(dim=1))

                            pred = torch.zeros_like(x, device=device)  
                            pred_next_step = torch.zeros_like(x, device=device)               

                            pred[:, 0:window_size, :] = x[0, 0:window_size, :]
                            pred[:, :, 0] = x[0, :, 0]

                            for i in range(1,x.size(1) - window_size + 1):

                                pred[:, window_size+(i-1):window_size+i,1:] =  pred[:, window_size+(i-2):window_size+(i-1):,1:] + model(pred[:,i:window_size+(i-1),:].transpose(1,2))    

                            test_loss += loss_fn(pred[0, :, 1], x[0, :, 1]).detach().cpu().numpy()
                            test_loss_deriv += loss_fn(pred[0, :, 2], x[0, :, 2]).detach().cpu().numpy()

                            total_loss += loss_fn(pred[0, :, 1:], x[0, :, 1:]).detach().cpu().numpy()

                        predictionary[model_type] = pred
                        error_dict[model_type] = total_loss

            if model_type == "or_tcn" :
                    for i, x in enumerate(data):
                    

                        with torch.inference_mode():
                            x=x.to(device)        
                            x = x.view(1,x.size(dim=0), x.size(dim=1))                
                            pred = torch.zeros((timesteps, 3), device=device)

                            if window_size > 1:
                                pred[0:window_size, :] = x[0, 0:window_size, :]
                                pred[:, 0] = x[0, :, 0]

                            else:
                                pred[0, :] = x[0, 0, :]
                                pred[:, 0] = x[0, :, 0]

                            x_test = x.clone()
                            x_test[:,window_size:,1:] = 0
                            x_test = x_test.to(device)
                            #print("Data passed to the model, all 0 after the initial window to prove that the forward pass is correct and doesnt access information it shouldnt.",x_test[:,0:10,:])

                            out = model(x_test.transpose(1,2))
                            
                            pred[window_size:,1:] = out.squeeze(0).transpose(0,1)

                            test_loss += loss_fn(pred[window_size:, 1], x[0, window_size:, 1]).detach().cpu().numpy()
                            test_loss_deriv += loss_fn(pred[window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                            total_loss += loss_fn(pred[window_size:, 1:], x[0, window_size:, 1:]).detach().cpu().numpy()

                        predictionary[model_type] = pred
                        error_dict[model_type] = total_loss

            if model_type == "neural_cde" :
                    for i, x in enumerate(data):

                        with torch.inference_mode():

                            x=x.to(device)        
                            x = x.view(1,x.size(dim=0), x.size(dim=1))

                            pred = torch.zeros_like(x, device=device)
                    
                            pred[:, 0:window_size, :] = x[0:1, 0:window_size, :]
                            pred[:, :, 0:2] = x[0:1, :, 0:2] # time, pressure

                            #start_total=time.time()

                            for i in range(x.size(1) - window_size):
                                
                                #start_coeffs=time.time()
                                train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(pred[0:1, i:i+window_size, :]) 
                                #train_coeffs = torchcde.linear_interpolation_coeffs(pred[0:1, i:i+window_size, :])

                                #stop_coeffs=time.time()
                                #print(stop_coeffs-start_coeffs, "time: coeff calc one step")
                                if (i+1)%100==0:
                                    print(i, " timessteps done")
                                #start=time.time()

                                out = model(train_coeffs)
                                pred[0:1, i+window_size, 2:] = pred[0:1, i+window_size-1, 2:] + out.unsqueeze(1)

                                #pred[0:1, i+window_size, 2:] = out
                                #stop=time.time()
                                #print(stop-start, "time: model calc step")

                            test_loss += loss_fn(pred[0, window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                            test_loss_deriv += loss_fn(pred[0, window_size:, 3], x[0, window_size:, 3]).detach().cpu().numpy()
                            total_loss += loss_fn(pred[0, window_size:, 2:], x[0, window_size:, 2:]).detach().cpu().numpy()

                            total_firsthalf += loss_fn(pred[0, window_size:int((pred.size(dim=1)-window_size)/2), 2:], 
                                                        x[0, window_size:int((pred.size(dim=1)-window_size)/2), 2:]).detach().cpu().numpy()  
                            total_secondhalf += loss_fn(pred[0, int((pred.size(dim=1)-window_size)/2):, 2:],
                                                            x[0, int((pred.size(dim=1)-window_size)/2):, 2:]).detach().cpu().numpy()  
                    
                        predictionary[model_type] = pred
                        error_dict[model_type] = total_loss

            if model_type == "or_mlp" :
                    
                    for i, x in enumerate(data):
                    

                        with torch.inference_mode():
                            x=x.to(device)        
                            x = x.view(1,x.size(dim=0), x.size(dim=1))                
                            pred = torch.zeros((timesteps, 3), device=device)

                            if window_size > 1:
                                pred[0:window_size, :] = x[0, 0:window_size, :]
                                pred[:, 0] = x[0, :, 0]

                            else:
                                pred[0, :] = x[0, 0, :]
                                pred[:, 0] = x[0, :, 0]

                            x_test = x.clone()
                            x_test[:,window_size:,1:] = 0
                            x_test = x_test.to(device)

                            out = model(x_test)
                            
                            pred[window_size:,1:] = out

                            test_loss += loss_fn(pred[window_size:, 1], x[0, window_size:, 1]).detach().cpu().numpy()
                            test_loss_deriv += loss_fn(pred[window_size:, 2], x[0, window_size:, 2]).detach().cpu().numpy()
                            total_loss += loss_fn(pred[window_size:, 1:], x[0, window_size:, 1:]).detach().cpu().numpy()

                            total_firsthalf += loss_fn(pred[window_size:int((pred.size(dim=0)-window_size)/2), 1:], 
                                                        x[0, window_size:int((pred.size(dim=0)-window_size)/2), 1:]).detach().cpu().numpy()  
                            total_secondhalf += loss_fn(pred[int((pred.size(dim=0)-window_size)/2):, 1:],
                                                        x[0, int((pred.size(dim=0)-window_size)/2):, 1:]).detach().cpu().numpy()
                            
                        predictionary[model_type] = pred
                        error_dict[model_type] = total_loss  
        
        if not plot_errs:
            #phase_plot_predictions(predictionary, data)
            plot_predictions(predictionary, data)
        else:
            plot_errors(error_dict)

def plot_predictions(predictionary, x):



    p_max = 3.5*1e5 #Druck in [bar]         ... [1 , 3.5]
    s_max = 0.6*1e-3 #Position [m]          ... [0, 0.0006]
    w_max = 1.7 #Geschwindigkeit in [m/s]   ... [-1.7, 1.7]
    p_min = 1.0
    s_min = 0.0
    w_min = -1.7
    physics_rescaling = [p_max, s_max, w_max, p_min, s_min, w_min]

    colors = {"or_lstm" : "red",
              "or_mlp" : "green",
              "mlp" : "green",
              "or_tcn" : "purple",
              "neural_cde"  : "brown"}

    figure , axs = plt.subplots(3,1, figsize=(16,9))
    figure.tight_layout(pad=5.0)    

    if x.dim() == 3:
        x = x.view(x.size(dim=1), x.size(dim=2))
    
        #scale back:    
    if physics_rescaling != None:

        x[:,0] = x[:,0]*(physics_rescaling[0] - physics_rescaling[3]) + physics_rescaling[3]
        x[:,1] = x[:,1]*(physics_rescaling[1] - physics_rescaling[4]) + physics_rescaling[4]
        x[:,2] = x[:,2]*(physics_rescaling[2] - physics_rescaling[5]) + physics_rescaling[5]

    greek_letterz=[chr(code) for code in range(945,970)]
    mu = greek_letterz[11]

    stepsize = 2e-5
    time = np.linspace(0,x.size(dim=0)* stepsize, x.size(dim=0))

    #data
    axs[0].plot(time, x.detach().cpu().numpy()[:, 1], color="blue", label="true", linestyle="dashed")
    axs[1].plot(time, x.detach().cpu().numpy()[:, 2], color="blue", label="true", linestyle="dashed")
    axs[2].plot(time, x.detach().cpu().numpy()[:,0], label="pressure")

    #predictions
    for key, pred in predictionary.items():

        if physics_rescaling != None:

            # we invert:
            # x = (x - xmin)/(xmax - xmin)
            # x * (xmax - xmin) + xmin

            pred[:,0] = pred[:,0]*(physics_rescaling[0] - physics_rescaling[3]) + physics_rescaling[3]
            pred[:,1] = pred[:,1]*(physics_rescaling[1] - physics_rescaling[4]) + physics_rescaling[4]
            pred[:,2] = pred[:,2]*(physics_rescaling[2] - physics_rescaling[5]) + physics_rescaling[5]

        if pred.dim() == 3:
            pred = pred.view(pred.size(dim=1), pred.size(dim=2))

        axs[0].plot(time, pred.detach().cpu().numpy()[:, 1], color=colors[key], label=f"{key}-prediciton", alpha=0.5)
        axs[1].plot(time, pred.detach().cpu().numpy()[:, 2], color=colors[key], label=f"{key}-prediciton", alpha=0.5)


    axs[0].set_title("position")
    axs[0].set_ylabel("[m]")
    axs[0].set_xlabel(f"time [s]")
    axs[0].grid()
    axs[0].legend()   
    axs[1].set_title("speed")
    axs[1].set_ylabel("[m/s]")
    axs[1].set_xlabel(f"time [s]")
    axs[1].grid()
    axs[1].legend()
    axs[2].set_title("pressure")
    axs[2].set_ylabel("[Pa]")
    axs[2].set_xlabel(f"time [s]")
    axs[2].grid()
    axs[2].legend()

    plt.grid(True)
    plt.legend()
    plt.show()



def phase_plot_predictions(predictionary, x):



    p_max = 3.5*1e5 #Druck in [bar]         ... [1 , 3.5]
    s_max = 0.6*1e-3 #Position [m]          ... [0, 0.0006]
    w_max = 1.7 #Geschwindigkeit in [m/s]   ... [-1.7, 1.7]
    p_min = 1.0
    s_min = 0.0
    w_min = -1.7
    physics_rescaling = [p_max, s_max, w_max, p_min, s_min, w_min]

    colors = {"or_lstm" : "red",
              "or_mlp" : "green",
              "or_tcn" : "purple",
              "neural_cde"  : "brown"}

    figure , axs = plt.subplots(1,1, figsize=(16,16))
    figure.tight_layout(pad=5.0)    

    if x.dim() == 3:
        x = x.view(x.size(dim=1), x.size(dim=2))
    
        #scale back:    
    if physics_rescaling != None:

        x[:,0] = x[:,0]*(physics_rescaling[0] - physics_rescaling[3]) + physics_rescaling[3]
        x[:,1] = x[:,1]*(physics_rescaling[1] - physics_rescaling[4]) + physics_rescaling[4]
        x[:,2] = x[:,2]*(physics_rescaling[2] - physics_rescaling[5]) + physics_rescaling[5]

    greek_letterz=[chr(code) for code in range(945,970)]
    mu = greek_letterz[11]

    stepsize = 2e-5
    time = np.linspace(0,x.size(dim=0)* stepsize, x.size(dim=0))

    #data
    axs.plot(x.detach().cpu().numpy()[:, 1],x.detach().cpu().numpy()[:, 2], color="blue", label="true", linestyle="dashed")

    #predictions
    for key, pred in predictionary.items():

        if physics_rescaling != None:

            # we invert:
            # x = (x - xmin)/(xmax - xmin)
            # x * (xmax - xmin) + xmin

            pred[:,0] = pred[:,0]*(physics_rescaling[0] - physics_rescaling[3]) + physics_rescaling[3]
            pred[:,1] = pred[:,1]*(physics_rescaling[1] - physics_rescaling[4]) + physics_rescaling[4]
            pred[:,2] = pred[:,2]*(physics_rescaling[2] - physics_rescaling[5]) + physics_rescaling[5]

        if pred.dim() == 3:
            pred = pred.view(pred.size(dim=1), pred.size(dim=2))

        axs.plot(pred.detach().cpu().numpy()[:, 1], pred.detach().cpu().numpy()[:, 2], color=colors[key], label=f"{key}-prediciton", alpha=0.5)

    axs.set_title("position")
    axs.set_ylabel("[m]")
    axs.set_xlabel(f"time [s]")
    axs.grid()
    axs.legend()   


    plt.grid(True)
    plt.legend()
    plt.show()


def plot_errors():
     return