# Tuning of the learning rate for the Adam optimizer using a PiNN on a chaotic Lorenz system

# Combined Bachelor Thesis (NS-320B), June 2022
## Mathematics & Physics and Astronomy
*'Applying Physics-informed Neural Networks to Chaotic Systems of Ordinary Differential Equations'*

**Author:** Martijn Sebastiaan Brouwer (6859488)

**Mathematics supervisor:** prof. dr. ir. C.W. Oosterlee

**Physics supervisor:** dr. J. de Graaf

**PhD supervisor:** B. Negyesi

In [None]:
# Expected runtime: 16:34:25.883144
# Create the output directory for the intermediate frames. Unlike the shell
# escape `!mkdir plots`, this is plain Python and stays silent when the
# directory already exists (e.g. when the notebook is re-run).
import os
os.makedirs('plots', exist_ok=True)

from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import warnings
from datetime import datetime
# Wall-clock reference for the elapsed-time report printed at each checkpoint.
# NOTE: this name shadows the stdlib `time` module should it ever be imported.
time = datetime.now()

warnings.simplefilter(action='ignore', category=FutureWarning)
# `np.warnings` was only an accidental alias of the stdlib module and no longer
# exists in recent NumPy, so use `warnings` directly. VisibleDeprecationWarning
# lives in `np.exceptions` on newer NumPy; fall back to the main namespace on
# older releases that lack that submodule.
warnings.filterwarnings('ignore', category=getattr(np, 'exceptions', np).VisibleDeprecationWarning)

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)

#=============================================================================
#============================  Definitions  ==================================
#=============================================================================
# Fully Connected Network ----------------------------------------------------
class FCN(nn.Module):
    """Fully connected tanh network used as the PiNN surrogate.

    Layout:
      * ``fc1``: Linear(N_INPUT, N_HIDDEN) followed by Tanh;
      * ``fc2``: (N_LAYERS - 1) blocks of Linear(N_HIDDEN, N_HIDDEN) + Tanh;
      * ``fc3``: Linear(N_HIDDEN, N_OUTPUT) with no activation.
    """

    def __init__(self, N_INPUT, N_OUTPUT, N_HIDDEN, N_LAYERS):
        super().__init__()
        act = nn.Tanh  # activation applied after every hidden affine layer

        def hidden_block(n_in, n_out):
            # Affine map followed by the activation, wrapped in a Sequential.
            return nn.Sequential(nn.Linear(n_in, n_out), act())

        self.fc1 = hidden_block(N_INPUT, N_HIDDEN)
        self.fc2 = nn.Sequential(*(hidden_block(N_HIDDEN, N_HIDDEN) for _ in range(N_LAYERS - 1)))
        self.fc3 = nn.Linear(N_HIDDEN, N_OUTPUT)

    def forward(self, *args):
        """Evaluate the network.

        * A single argument is used as the input unchanged.
        * Several arguments whose first has rank <= 1 (scalars or 1-D
          sequences) are packed with ``torch.FloatTensor([*args]).T`` so that
          each argument becomes one input feature (column).
        * Several higher-rank tensors are concatenated along dim 1.
        """
        head = args[0]
        if len(args) == 1:
            inputs = head
        elif len(np.shape(head)) > 1:
            inputs = torch.FloatTensor(torch.cat([*args], 1))
        else:
            inputs = torch.FloatTensor([*args]).T
        return self.fc3(self.fc2(self.fc1(inputs)))

# Runge Kutta fourth order method ---------------------------------------------
def RungeKutta(dxdt,dydt,dzdt, x0,y0,z0, ti,tf,n): # Specify derivatives, initial conditions and time
    """Integrate a 3-D autonomous ODE system with the classical RK4 scheme.

    Args:
        dxdt, dydt, dzdt: callables ``f(x, y, z)`` returning the time
            derivatives of x, y and z.
        x0, y0, z0: initial condition at ``t = ti``.
        ti, tf: start and end of the integration interval.
        n: number of output samples, including the initial condition.

    Returns:
        Three lists ``xl, yl, zl`` of length ``n``. Sample ``i`` approximates
        the solution at ``t = ti + i*(tf - ti)/(n - 1)``, i.e. on the same
        grid as ``np.linspace(ti, tf, n)`` / ``torch.linspace(ti, tf, n)``.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError("n must be at least 1 (got %r)" % (n,))
    xl, yl, zl = [x0], [y0], [z0]  # first sample is the initial condition
    if n == 1:
        return xl, yl, zl
    # n samples span n - 1 intervals. Dividing by n - 1 (and using tf - ti)
    # makes the last sample land exactly on tf, so the trajectory lines up
    # with the linspace time grids it is plotted against.
    h = (tf - ti)/(n - 1)
    derivs = (dxdt, dydt, dzdt)
    x, y, z = x0, y0, z0
    for _ in range(n - 1):
        # The four RK4 stages, each a list [k_x, k_y, k_z] already scaled by h.
        k1 = [h*f(x, y, z) for f in derivs]
        k2 = [h*f(x + 0.5*k1[0], y + 0.5*k1[1], z + 0.5*k1[2]) for f in derivs]
        k3 = [h*f(x + 0.5*k2[0], y + 0.5*k2[1], z + 0.5*k2[2]) for f in derivs]
        k4 = [h*f(x + k3[0], y + k3[1], z + k3[2]) for f in derivs]
        x, y, z = (r + (a + 2*b + 2*c + d)/6 for r, a, b, c, d in zip((x, y, z), k1, k2, k3, k4))
        xl.append(x)
        yl.append(y)
        zl.append(z)
    return xl, yl, zl

# Assemble the saved intermediate frames into a GIF animation of the convergence ---------------
def save_gif(outfile, files, fps=5, loop=0):
    """Combine image files into an animated GIF.

    Args:
        outfile: path of the GIF to write.
        files: ordered sequence of image paths; each one becomes a frame.
        fps: frames per second; each frame is shown for ``int(1000/fps)`` ms.
        loop: loop count passed to Pillow's GIF writer (0 = loop forever).

    Raises:
        ValueError: if ``files`` is empty.
    """
    if not files:
        raise ValueError("save_gif needs at least one frame")
    imgs = []
    try:
        for file in files:
            imgs.append(Image.open(file))
        imgs[0].save(fp=outfile, format='GIF', append_images=imgs[1:], save_all=True,
                     duration=int(1000/fps), loop=loop)
    finally:
        # Image.open keeps the file handle open; release every frame, even if
        # opening a later frame or writing the GIF failed.
        for img in imgs:
            img.close()


#=============================================================================
#============================  Parameters  ===================================
#=============================================================================
# Lorenz system parameters ---------------------------------------------------
x0, y0, z0 = 10, 10, 10          # Initial condition at t = ti
sigma, rho, beta = 10, 28, 8/3   # Lorenz system parameters
ti, tf, n = 0, 0.3, 1000         # Initial time, first final time (before any update), number of grid points on [ti, tf]
h = tf/n                         # Step size derived from tf and n
# NOTE: grids built with linspace(ti, tf, n) actually have spacing (tf - ti)/(n - 1).

# PiNN -----------------------------------------------------------------------
lr = 1e-3               # Learning rate (not referenced in this cell; the sweep uses `learning_rates`)
N_NODES_HID = 50        # Neurons per hidden layer
N_LAYERS = 4            # Number of hidden layers
alpha = 0.01            # Threshold on the dynamic loss below which tf is extended
update_t = 0.3          # Amount added to tf once the threshold is met
iterations = 100000     # Total number of training iterations
intermediate = 1000     # Save intermediate results every `intermediate` iterations
network_amount = 6      # Networks per seed (one per learning rate)
seed_amount = 5         # Number of random seeds

# Lorenz differential equation definitions -----------------------------------
def dxdt_def(x, y, z):
    """Lorenz x-equation: dx/dt = sigma*(y - x), using the module-level sigma."""
    return sigma*y - sigma*x
def dydt_def(x, y, z):
    """Lorenz y-equation: dy/dt = x*(rho - z) - y, using the module-level rho."""
    return rho*x - x*z - y
def dzdt_def(x, y, z):
    """Lorenz z-equation: dz/dt = x*y - beta*z, using the module-level beta."""
    coupling = x*y
    return coupling - beta*z


#=============================================================================
#================================  PiNN  =====================================
#=============================================================================

# Collocation grid for the dynamic loss; requires_grad lets autograd compute
# derivatives of the network output with respect to t.
t_physics = torch.linspace(ti, tf, n).view(-1, 1).requires_grad_(True)

files = []                                        # Paths of the saved intermediate frames (for the GIF)
color_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k']  # One plot colour per network
# tf_list[s][net]: history of final times for seed s and network net, starting at tf.
tf_list = [[[tf] for _ in range(network_amount)] for _ in range(seed_amount)]
# loss_list[s][net]: recorded training losses for seed s and network net.
loss_list = [[[] for _ in range(network_amount)] for _ in range(seed_amount)]

# Network `net` (within every seed) is trained with learning_rates[net].
learning_rates = [1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5]

# One full set of networks per seed, because the outcome depends on the random
# initialisation. Both the torch and numpy RNGs are re-seeded before each set
# is built, so the initial weights are reproducible.
networks_seeds, optimizers_seeds = [], []
for s in range(seed_amount):
    seed = 100*s
    torch.manual_seed(seed)
    np.random.seed(seed)

    networks = []
    for net in range(network_amount):
        model = FCN(1, 3, N_NODES_HID, N_LAYERS)  # input t -> outputs (x, y, z)
        networks.append(model)

    # Each optimizer sits in a one-element list; it is read back as optimizers_seeds[s][net][0].
    optimizers = [[torch.optim.Adam(m.parameters(), lr=learning_rates[k])] for k, m in enumerate(networks)]
    networks_seeds.append(networks)
    optimizers_seeds.append(optimizers)

for i in range(iterations):
    # ---- Training: one Adam step for every (seed, network) pair -------------
    for s in range(seed_amount):
        for net in range(network_amount):
            model = networks_seeds[s][net]
            model.train()  # FCN has only Linear/Tanh layers, so the mode has no numerical effect
            optimizer = optimizers_seeds[s][net][0]
            optimizer.zero_grad()  # Set gradients of all optimized tensors to zero

            # Current training horizon of this (seed, network) pair. It grows by
            # update_t whenever the dynamic loss drops below alpha (see below).
            tf = tf_list[s][net][-1]
            t_physics = torch.linspace(ti, tf, n).view(-1, 1).requires_grad_(True)

            # loss1: squared mismatch with the initial condition at t = ti.
            m = model(torch.FloatTensor([ti]))
            xh, yh, zh = m[0], m[1], m[2]
            loss1 = torch.mean((xh-x0)**2 + (yh-y0)**2 + (zh-z0)**2)

            # loss2: mean squared residual of the Lorenz equations on t_physics.
            p = model(t_physics)
            px, py, pz = (p[:, k].view(-1, 1) for k in range(3))  # (n, 1) each
            dxdt = torch.autograd.grad(px, t_physics, torch.ones_like(px), create_graph=True)[0]
            dydt = torch.autograd.grad(py, t_physics, torch.ones_like(py), create_graph=True)[0]
            dzdt = torch.autograd.grad(pz, t_physics, torch.ones_like(pz), create_graph=True)[0]
            physics_x = -sigma*px + sigma*py - dxdt
            physics_y = -px*pz + rho*px - py - dydt
            physics_z = px*py - beta*pz - dzdt
            loss2 = torch.mean(physics_x**2 + physics_y**2 + physics_z**2)

            # The initial condition is weighted 10x relative to the dynamics.
            loss = 10*loss1 + loss2
            loss.backward()
            optimizer.step()
            # Keep a plain float per step; it is far lighter than a 0-d tensor.
            loss_list[s][net].append(loss.item())
            if loss2.item() < alpha:  # Dynamics fitted well enough: extend the horizon
                tf_list[s][net].append(tf + update_t)

    # ---- Intermediate results --------------------------------------------------
    # This check sits at iteration level, so the plot is produced once per
    # checkpoint, after every (seed, network) pair has completed iteration i.
    if (i+1) % intermediate == 0:
        # Largest horizon reached by any run.
        tf_largest = max(v for seed_tfs in tf_list for run_tfs in seed_tfs for v in run_tfs)
        # Reference solution extends 0.3 beyond that. round(..., 2) strips float
        # noise (e.g. 0.8999999999999999) without the precision loss of
        # 2-significant-digit formatting once the horizon reaches 10.
        t_ref_end = round(tf_largest + 0.3, 2)
        x, y, z = torch.FloatTensor(RungeKutta(dxdt_def, dydt_def, dzdt_def, x0, y0, z0, ti, t_ref_end, n))
        print('Iteration:', i+1, 'tf_list:', tf_list, 'time:', datetime.now() - time)
        fig = plt.figure(figsize=(13, 4))

        ax1 = fig.add_subplot(1, 2, 1, projection='3d')
        ax1.scatter(x0, y0, z0, s=40, color="tab:orange", alpha=0.4)
        ax1.plot(x, y, z, color="black", linewidth=1, alpha=0.6)
        ax1.set_xlabel('$x$', fontsize="x-large")
        ax1.set_ylabel('$y$', fontsize="x-large")
        ax1.set_zlabel('$z$', fontsize="x-large")

        t = np.linspace(ti, t_ref_end, n)
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.scatter(t[0], x0, s=40, color="tab:orange", alpha=0.4)
        ax2.plot(t, x, color="black", linewidth=1, alpha=0.6)
        ax2.set_xlabel('$t$', fontsize="x-large")
        ax2.set_ylabel('$x$', fontsize="x-large", labelpad=0)

        for net in range(network_amount):
            # Average over seeds only makes sense when every seed is evaluated
            # at the same instants, so use one shared grid. It ends at the
            # smallest horizon any seed of this network has reached, i.e. the
            # interval on which all seeds have been trained.
            tf_common = min(tf_list[s][net][-1] for s in range(seed_amount))
            t_common = torch.linspace(ti, tf_common, n).view(-1, 1)
            with torch.no_grad():  # plotting only: no autograd graph needed
                preds = torch.stack([networks_seeds[s][net](t_common) for s in range(seed_amount)])
            xh, yh, zh = preds.mean(dim=0).unbind(dim=1)  # (seeds, n, 3) -> three (n,) tensors

            ax1.plot(xh.numpy(), yh.numpy(), zh.numpy(), color=color_list[net], linewidth=1.25, alpha=0.75)
            ax2.plot(t_common.squeeze(1).numpy(), xh.numpy(), color=color_list[net], linewidth=1.25, alpha=0.75,
                     label="PiNN with lr = {:.1e}".format(learning_rates[net]))

        plt.annotate("Training step: %i"%(i+1), xy=(1.05, 0.85), xycoords='axes fraction', fontsize="x-large", color="k")
        plt.legend(loc=(1.05, 0.20), frameon=True, fontsize="large")
        file = "plots/pinn_%.8i.png"%(i+1)
        plt.savefig(file, bbox_inches='tight', pad_inches=0.1, dpi=100, facecolor="white")
        files.append(file)
        plt.show()  # Show plots
        plt.close(fig)  # Release the figure; many checkpoints are plotted

# Create the GIF animation from the saved frames. No frames exist when
# iterations < intermediate, so skip with a message instead of crashing.
if files:
    save_gif("pinn.gif", files, fps=20, loop=0)
else:
    print('No intermediate frames were saved, so no GIF was created.')