In [None]:
import sys
import numpy as np
import torch
from torch import Tensor, ones, stack, load
from torch.autograd import grad
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.pyplot import figure
import pandas as pd
from torch.nn import Module
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from scipy import stats
from pathlib import Path
import wandb
import time
from tesladatano import TeslaDatasetNo

from tqdm import trange
from tqdm.autonotebook import tqdm

# Import PINNFramework etc.
# https://github.com/ComputationalRadiationPhysics/NeuralSolvers.git
# sys.path.append("...")# PINNFramework etc.
# sys.path.append("/home/hoffmnic/Code/NeuralSolvers")
sys.path.append("NeuralSolvers")  
import PINNFramework as pf

## Preprocessing


In [None]:
# Select the compute device: the first CUDA GPU when one is visible to
# PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Authenticate against Weights & Biases (prompts for an API key when no
# credentials are cached).
wandb.login()

# Metric the sweep optimises; the name must match the key logged by train().
metric = {
    'name': 'Validation Loss',
    'goal': 'minimize',
}

# Hyper-parameter grid. Only `alpha` (weight of the gradient regulariser in
# train()) has more than one value, so the grid sweep effectively scans alpha
# while every other setting stays fixed.
parameters_dict = {
    'alpha': {'values': [1e-3, 1e-2, 0.1, 1, 10, 100, 1000, 10**4, 10**5]},
    'normalize': {'values': [1000]},
    'batch_size': {'values': [4096]},
    'lr': {'values': [1e-4]},
    'input_size': {'values': [6]},
    'output_size': {'values': [1]},
    'hidden_size': {'values': [100]},
    'num_hidden': {'values': [4]},
    'epochs': {'values': [5]},
}

# Assemble the full sweep definition in one place instead of mutating it
# step by step.
sweep_config = {
    'method': 'grid',
    'metric': metric,
    'parameters': parameters_dict,
}

# Register the sweep; the returned id is what wandb.agent needs to pull runs.
sweep_id = wandb.sweep(sweep_config, project="Neural_Operator_project")

In [None]:
def derivative(x, u):
    """Return the gradient of ``u`` with respect to ``x``.

    A tensor of ones shaped like ``u`` (created on ``u``'s device) is used as
    ``grad_outputs``, so every element of ``u`` contributes with weight 1.
    The graph is kept (``create_graph=True``) so the result can itself be
    part of a loss that is back-propagated.
    """
    seed = ones(u.shape, device=u.device)
    (du_dx,) = grad(outputs=u, inputs=x, grad_outputs=seed, create_graph=True)
    return du_dx

## Function for saving a checkpoint during training

In [None]:
def write_checkpoint(checkpoint_path, epoch, min_mlp_loss, optimizer, model=None):
    """Save a training checkpoint to ``checkpoint_path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``"epoch"``,
    ``"minimum_pinn_loss"``, ``"optimizer"`` (optimizer state dict) and
    ``"mlp_model"`` (model state dict).

    Args:
        checkpoint_path: Destination file; missing parent directories are
            created.
        epoch: Epoch index stored in the checkpoint.
        min_mlp_loss: Best loss so far, stored under ``"minimum_pinn_loss"``.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        model: Module whose ``state_dict()`` is stored. When omitted, a
            module-level variable named ``model`` is used, which is what the
            function implicitly relied on before this parameter existed.

    Raises:
        NameError: If ``model`` is not passed and no module-level ``model``
            exists.
    """
    if model is None:
        # Backward-compatible fallback for callers that do not pass a model.
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError(
                "write_checkpoint needs a model: pass model=... or define a "
                "module-level variable named 'model'"
            ) from None

    checkpoint_path = Path(checkpoint_path)
    # torch.save does not create directories, so make sure the target exists.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "minimum_pinn_loss": min_mlp_loss,
        "optimizer": optimizer.state_dict(),
        "mlp_model": model.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)

## Function for evaluating the performance on the test data


In [None]:
def _error_metrics(pred, true):
    """Print and return (MAE, MSE, relative L2 error in %) of ``pred`` vs ``true``."""
    mae = np.abs(pred - true).mean()
    print('MAE:', mae)

    mse = ((true - pred)**2).mean()
    print('MSE:', mse)

    rel_error = np.linalg.norm(pred - true) / np.linalg.norm(true)*100
    print('Relative error (%):', rel_error)

    return mae, mse, rel_error


def evaluate(model,idd,rel_time,diff,normalize):
    """Evaluate ``model`` on the test drive ``idd`` in three different ways.

    1. Direct prediction: the model output for every test sample (divided by
       ``normalize``) is compared with ``ds_test.y``.
    2. Forward Euler with fixed inputs: starting from the first test sample,
       only column 4 (the integrated quantity, called temperature in this
       notebook) and, when ``rel_time`` is truthy, column 5 (relative time)
       are updated; steps are equidistant between 0 and the largest time.
    3. Forward Euler with per-step inputs: every step starts from the i-th
       test sample with column 4 replaced by the integrated value, using the
       recorded time differences as step sizes.

    Args:
        model: Callable network; expected to return a single value when
            called with one sample.
        idd: Drive id forwarded to ``TeslaDatasetNo`` as ``ID``.
        rel_time: Forwarded to the dataset; also enables the relative-time
            update in the fixed-input rollout.
        diff: Forwarded to the dataset unchanged.
        normalize: Divisor applied to every model output.

    Returns:
        Three numpy arrays ``(mae, mse, rel_err)``, each of length 3 and
        ordered as [direct prediction, fixed-input Euler, per-step Euler].
    """
    print('ID = ', idd)
    # Test data for the requested drive.
    ds_test = TeslaDatasetNo(device = device, ID = idd, data = "test",rel_time = rel_time, diff = diff)
    n_samples = ds_test.x.shape[0]

    # 1) Prediction accuracy of the Neural Operator on the derivative itself.
    # NOTE(review): assumes the model output and ds_test.y have matching
    # shapes; otherwise the subtraction below would broadcast -- TODO confirm.
    print('Prediction accuracy of the Neural Operator (NO)')
    begin = time.time()
    with torch.no_grad():  # inference only, no autograd graph needed
        pred_der = model(ds_test.x.to(device))
    pred_der = pred_der.detach().cpu().numpy()/normalize
    end = time.time()
    true_der = ds_test.y.cpu().numpy()
    print("time:", end - begin)

    mae_der, mse_der, rel_error_der = _error_metrics(pred_der, true_der)

    print('########################################################')

    # Time axis and ground-truth temperature shared by both Euler rollouts.
    # .cpu() is required before .numpy() when the dataset lives on the GPU;
    # flattening makes t[i] a scalar whatever the stored shape is.
    t = ds_test.t.cpu().numpy().reshape(-1)
    true_temp = ds_test.x[:,4].cpu().numpy()

    # 2) Forward Euler with fixed initial environmental conditions, updating
    # only the temperature (and relative time) from the previous prediction,
    # on temporally equidistant time steps.
    print('Forwad Euler method with fixed initial env conditions')
    rel_t = ds_test.rel_t

    # Step of an equidistant grid spanning [0, max time] with n_samples points.
    _, step_size = np.linspace(0, t.max(), n_samples, retstep=True)

    # Only index 0 acts as the initial value; all later entries are
    # overwritten by the rollout.
    pred_temp = true_temp.copy()

    # Fixed initial conditions for all environmental inputs.
    state = ds_test.x[0].detach().clone()

    begin = time.time()
    with torch.no_grad():
        for i in range(n_samples - 1):
            state[4] = float(pred_temp[i])
            if rel_time:
                state[5] = torch.tensor(rel_t[i])
            # .item() yields a plain float, so a size-1 array is never
            # assigned into a scalar slot (deprecated in recent NumPy).
            rate = model(state.to(device)).item() / normalize
            pred_temp[i + 1] = pred_temp[i] + rate*step_size
    end = time.time()

    print("time:", end - begin)

    mae, mse, rel_error = _error_metrics(pred_temp, true_temp)

    # 3) Forward Euler with environmental conditions taken from the dataset at
    # each step, temperature taken from the previous prediction, and the true
    # (recorded) step sizes.
    print('Forwad Euler method with updated env conditions from the dataset at each iteration with true step sizes')

    pred_temp = np.zeros(n_samples)
    pred_temp[0] = true_temp[0]

    begin = time.time()
    with torch.no_grad():
        for i in range(n_samples - 1):
            state = ds_test.x[i].detach().clone()
            state[4] = float(pred_temp[i])
            rate = model(state.to(device)).item() / normalize
            pred_temp[i + 1] = pred_temp[i] + rate*(t[i+1]-t[i])
    end = time.time()

    print("time:", end - begin)

    mae_upd, mse_upd, rel_error_upd = _error_metrics(pred_temp, true_temp)

    mae_arr = np.array([mae_der, mae, mae_upd])
    mse_arr = np.array([mse_der, mse, mse_upd])
    rel_arr = np.array([rel_error_der, rel_error, rel_error_upd])

    return mae_arr, mse_arr, rel_arr


## Training


In [None]:
def train(config=None):
    """Run one W&B sweep trial: train the MLP, keep the best weights, evaluate.

    The loss is ``MSE(pred, y) + alpha * mean((d pred / d x)**2)``, where the
    second term penalises the gradient of the output with respect to the
    inputs. After every epoch the mean validation loss is computed on the
    dataset built with ``data="test", ID=-1``; the best model is saved under
    ``nomodel/`` and a checkpoint is written every 200 epochs. When training
    ends, the best weights are reloaded, ``evaluate`` is run on a fixed list
    of test drives, and the averaged metrics are logged to W&B.

    Args:
        config: Optional default configuration passed to ``wandb.init``; when
            the function is launched by ``wandb.agent`` the sweep controller
            supplies the values.
    """
    seed = 1234
    checkpoint_every = 200                      # epochs between checkpoints
    test_ids = [16, 39, 47, 52, 72, 81, 88]     # drives used for the final test
    model_dir = Path('nomodel')

    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent this config is set by the sweep controller
        config = wandb.config

        # Seed before the model is built so the weight initialisation (and
        # not only the batch order) is reproducible.
        torch.manual_seed(seed)

        # torch.save does not create directories.
        model_dir.mkdir(parents=True, exist_ok=True)
        best_model_path = model_dir / 'best_model_{}_{}.pt'.format(wandb.run.id, wandb.run.name)

        # Datasets
        ds = TeslaDatasetNo(diff = "fwd_diff", device = device, data ='train', normalize = config.normalize, rel_time = True)
        ds_test = TeslaDatasetNo(device = device, ID = -1, data = "test",normalize = config.normalize, rel_time = True)

        # Input bounds handed to the MLP
        config.lb = ds.lb
        config.ub = ds.ub

        # Data loaders; validation is not shuffled so the reported loss does
        # not depend on the batch composition.
        train_loader = DataLoader(ds, batch_size=config.batch_size,shuffle=True)
        validloader = DataLoader(ds_test, batch_size=config.batch_size,shuffle=False)

        # Model
        model = pf.models.MLP(input_size=config.input_size,
                              output_size=config.output_size,
                              hidden_size=config.hidden_size,
                              num_hidden=config.num_hidden,
                              lb=config.lb,
                              ub=config.ub,
                              )
        model.to(device)

        # Log the network weight histograms (optional)
        wandb.watch(model)

        optimizer = torch.optim.Adam(model.parameters(),lr=config.lr)
        criterion = torch.nn.MSELoss()

        min_valid_loss = np.inf

        begin = time.time()
        for epoch in range(config.epochs):
            total_loss = 0.0
            total_loss1 = 0.0
            total_loss2 = 0.0
            n_train_batches = 0

            model.train()
            for x_batch, y_batch in train_loader:
                if config.batch_size == 1:
                    x_batch = torch.squeeze(x_batch)
                    y_batch = torch.squeeze(y_batch)

                optimizer.zero_grad()
                # The inputs need gradients for the regulariser below.
                x_batch.requires_grad = True
                pred = model(x_batch.to(device))

                u_deriv = derivative(x_batch, pred)
                loss1 = criterion(pred, y_batch.to(device))
                loss2 = torch.mean(u_deriv**2) * config.alpha
                loss = loss1 + loss2

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_loss1 += loss1.item()
                total_loss2 += loss2.item()
                n_train_batches += 1

            # Per-batch means for the epoch (guarded against an empty loader).
            train_loss = total_loss / max(n_train_batches, 1)
            epoch_loss1 = total_loss1 / max(n_train_batches, 1)
            epoch_loss2 = total_loss2 / max(n_train_batches, 1)

            # Validation: no autograd graph is needed here.
            model.eval()
            valid_total = 0.0
            n_valid_batches = 0
            with torch.no_grad():
                for x_batch, y_batch in validloader:
                    target = model(x_batch.to(device))
                    valid_total += criterion(target, y_batch.to(device)).item()
                    n_valid_batches += 1
            # Mean over batches; the same number is printed, compared and
            # logged (previously the sum was logged but the mean printed).
            valid_loss = valid_total / max(n_valid_batches, 1)

            # Single-line f-string: a backslash continuation inside a
            # replacement field is a SyntaxError before Python 3.12.
            print(f'Epoch {epoch} \t Training Loss: {train_loss:.5f} \t Loss 1: {epoch_loss1:.5f} \t Loss 2: {epoch_loss2:.5f} \t Validation Loss: {valid_loss:.5f}')

            # Save the best model seen so far
            if min_valid_loss > valid_loss:
                print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
                min_valid_loss = valid_loss
                torch.save(model.state_dict(), best_model_path)

            # Periodic checkpoint. Written inline because the model is a local
            # of this function and must be passed to torch.save explicitly.
            if (epoch + 1) % checkpoint_every == 0:
                checkpoint_path = model_dir / 'checkpoint_{}_{}_{}.pt'.format(wandb.run.id, wandb.run.name, epoch)
                torch.save({
                    "epoch": epoch,
                    "minimum_pinn_loss": min_valid_loss,
                    "optimizer": optimizer.state_dict(),
                    "mlp_model": model.state_dict(),
                }, checkpoint_path)

            # Log the loss values at the end of each epoch
            wandb.log({
                "Epoch": epoch,
                "Total Loss": train_loss,
                "Loss1 (temperature)": epoch_loss1,
                "Loss2 (regulariser)": epoch_loss2,
                "Validation Loss": valid_loss
                })
        end = time.time()
        print("time:", end - begin)

        # Reload the best model; if no epoch ever improved (e.g. NaN losses)
        # nothing was saved, so fall back to the final weights.
        if best_model_path.exists():
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        else:
            print('No best model was saved; evaluating the final weights instead')
        model.eval()

        # Final evaluation on the held-out drives
        rl = True
        diff = "fwd_diff"

        mae_sum = 0
        mse_sum = 0
        rel_err_sum = 0
        for idd in test_ids:
            mae, mse, rel_err = evaluate(model,idd=idd,rel_time=rl,diff=diff,normalize = config.normalize)
            print(mae, mse, rel_err)
            mae_sum = mae_sum + mae
            mse_sum = mse_sum + mse
            rel_err_sum = rel_err_sum + rel_err

        mae_avg = mae_sum / len(test_ids)
        mse_avg = mse_sum / len(test_ids)
        rel_err_avg = rel_err_sum / len(test_ids)

        # Log the test metrics averaged over all test drives. Index order
        # follows evaluate(): [derivative, fixed IC, updated env].
        wandb.log({
            "mae derivative": mae_avg[0],
            "mae fixed IC": mae_avg[1],
            "mae updated env": mae_avg[2],
            "mse derivative": mse_avg[0],
            "mse fixed IC": mse_avg[1],
            "mse updated env": mse_avg[2],
            "rel err derivative": rel_err_avg[0],
            "rel err fixed IC": rel_err_avg[1],
            "rel err updated env": rel_err_avg[2],
             })

In [None]:
# Start a sweep agent: it repeatedly asks the sweep identified by `sweep_id`
# for a configuration and calls train() once per configuration.
wandb.agent(sweep_id, function=train)

In [None]:
# NOTE: the averaged test metrics are logged from inside train(), while the
# W&B run is still active. `mae_avg`, `mse_avg` and `rel_err_avg` are locals
# of train(), so logging them from this cell raised NameError on a fresh
# kernel (Restart & Run All); the duplicate wandb.log call was removed.