In [1]:
!pip install -U pytorch-lightning
!pip install comet-ml

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.7.7-py3-none-any.whl (708 kB)
[K     |████████████████████████████████| 708 kB 22.1 MB/s eta 0:00:01
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.10.0-py3-none-any.whl (529 kB)
[K     |████████████████████████████████| 529 kB 57.0 MB/s eta 0:00:01
Collecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting tensorboard>=2.9.1
  Downloading tensorboard-2.10.1-py3-none-any.whl (5.9 MB)
[K     |████████████████████████████████| 5.9 MB 62.2 MB/s eta 0:00:01
Collecting aiohttp
  Downloading aiohttp-3.8.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[K     |████████████████████████████████| 1.0 MB 90.4 MB/s eta 0:00:01
Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.3.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (161 kB)


In [2]:
import os
os.chdir("/notebooks")
from pytorch_lightning import LightningModule
import torch
import torch.nn as nn
import math
from base_lightning import Dataset
from torch.autograd import grad
import numpy as np
#from neptune.new.types import File
import matplotlib.pyplot as plt
from datetime import datetime
from matplotlib.animation import FuncAnimation, PillowWriter
from sklearn.linear_model import LassoCV,Lasso, LinearRegression

import matplotlib.animation as animation
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor,TQDMProgressBar,EarlyStopping
from pytorch_lightning.loggers import CometLogger
from PIL import Image
import io


def gpu_prints():
    print("The total number of GPUs is:",torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        print("GPU number",i,"is",torch.cuda.get_device_name(i))
        
gpu_prints()
now = datetime.now()
dt_string = now.strftime("%d/%m/%Y %H:%M:%S")


The total number of GPUs is: 1
GPU number 0 is Quadro P5000


In [3]:
class SinusoidalActivation(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.pi = torch.tensor([math.pi],dtype=torch.float32,device="cuda")
    
    def forward(self,input): 
        
        sinusoid = torch.sin(2*self.pi*input)
        return sinusoid

class NN(nn.Module):
    
    def __init__(self,init=False):
        
        super().__init__()
        self.init = init
        self.network = self.build_network(sigma=1,hidden=50)
        
    def forward(self, input_: torch.Tensor):
        
        input_ = input_.requires_grad_(True)
        return self.network(input_),input_

    def build_network(self,sigma,hidden):

        network= []
        first = nn.Linear(2,hidden)
        if self.init:
            print("Initing NN weights and baises")
            nn.init.normal_(first.weight,0,sigma**2)
            nn.init.zeros_(first.bias)
        network.append(first)
        network.append(SinusoidalActivation())
        
        second = nn.Linear(hidden,hidden)
        if self.init:
            torch.nn.init.xavier_uniform_(second.weight)
            torch.nn.init.zeros_(second.bias)
        network.append(second)
        network.append(nn.Tanh())
        
        third = nn.Linear(hidden,hidden)
        if self.init:
            torch.nn.init.xavier_uniform_(third.weight)
            torch.nn.init.zeros_(third.bias)
        network.append(third)
        network.append(nn.Tanh())
        
        fourth = nn.Linear(hidden,hidden)
        if self.init:
            torch.nn.init.xavier_uniform_(fourth.weight)
            torch.nn.init.zeros_(fourth.bias)
        network.append(fourth)
        network.append(nn.Tanh())
        
        fith = nn.Linear(hidden,hidden)
        if self.init:
            torch.nn.init.xavier_uniform_(fith.weight)
            torch.nn.init.zeros_(fith.bias)
        network.append(fith)
        network.append(nn.Tanh())
        
        sixth = nn.Linear(hidden,1)
        if self.init:
            torch.nn.init.xavier_uniform_(sixth.weight)
            torch.nn.init.zeros_(sixth.bias)
        network.append(sixth)
        network.append(nn.Tanh())
        
        network.pop()
        network = nn.Sequential(*network)
        
        return network 

In [4]:
class RegularNN(LightningModule):

    def __init__(self,filename,config,init_NN):

        super().__init__()
        
        self.data            = np.load(filename)
        a = filename.split('/')
        self.dir = a[0]
        self.filename = a[-1].split('.npz')[0]
        
        self.total_X         = torch.tensor(self.data["total_X"],dtype=torch.float32)
        self.total_X         = self.total_X[self.total_X[:,0]<=4]
        self.total_Y         = torch.tensor(self.data["total_Y"],dtype=torch.float32)
        self.total_Y         = self.total_Y[:self.total_X.shape[0]]


        self.X_train         = self.total_X[self.total_X[:,0]<=2]
        self.X_validation    = self.total_X[self.total_X[:,0]>=2]

        self.Y_train         = self.total_Y[self.total_X[:,0]<=2]
        self.Y_validation    = self.total_Y[self.total_X[:,0]>=2]

        self.c               = torch.tensor(self.data["c"],dtype=torch.float32)
        self.v               = torch.tensor(self.data["v"],dtype=torch.float32)
        self.total_x         = torch.tensor(self.data["X"],dtype=torch.float32)
        self.T               = torch.tensor(self.data["t"],dtype=torch.float32)
        self.T               = self.T[self.T<=4]

        self.U               = np.array(self.data["wave"])
        self.U               = self.U[:self.T.shape[0]]
        self.coefs           = self.data["coefs"]


        self.coiso              = None
        self.xi                 = None
        self.analytical_du2_dx2 = None
        self.lr                 = config["lr"]
        self.k_pde              = config["k_pde"]
        self.network            = custom_NN(n_in=2,n_hidden=6*[60],n_out=1,init=init_NN)
        
        self.first_coeff  = -(self.c**2-self.v**2)
        self.second_coeff = +2*self.v
        
        print("1st coef:",self.first_coeff)
        print("2nd coef:",self.second_coeff)
    
    def fig2img(self,fig):
        """Convert a Matplotlib figure to a PIL Image and return it"""
        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        img = Image.open(buf)
        return img


    def forward(self,input_):

        return self.network(input_)

    def on_train_start(self):

        print("Device is:",self.device)
        fig,ax1 = plt.subplots(figsize=(20,7))
        extent = np.min(self.total_x.detach().cpu().numpy()), np.max(self.total_x.detach().cpu().numpy()), np.min(self.T.detach().cpu().numpy()), np.max(self.T.detach().cpu().numpy())
        im1 = plt.imshow(self.U,origin="lower",cmap=plt.cm.magma,extent=extent,aspect="auto",interpolation=None)
        plt.colorbar()
        plt.xlabel("X")
        plt.ylabel("Time")
        plt.title(f"Real wave")
        image = self.fig2img(fig)
        self.logger.experiment.log_image(image,name="Real Wave")
        plt.close()

    def training_step(self, batch,batch_idx):

        #forward pass
        x,target = batch
        prediction,coordinates = self.forward(x)
        second_time_deriv,theta,term1,term2 = self.compute_derivatives(prediction,coordinates) # derivatives

        #self.xi = self.least_squares_QR(theta,second_time_deriv) #sparse vector computed using least squares
        self.xi = self.least_squares_SK(theta,second_time_deriv)
        
        #losses
        mse_loss                            = torch.mean((prediction-target)**2) # scalar
        pde_loss                            = torch.mean((second_time_deriv+term1+term2)**2)#should be scalar as well
        total_loss                          = mse_loss + self.k_pde*pde_loss

        #for i,j in enumerate(xi):
        #   self.log(f"Coefficient nr{i}",j,logger=True,on_epoch=True,on_step=False)
        
        self.log(f"Coefficient nr1",-self.xi[0],logger=True,on_epoch=True,on_step=False)
        self.log(f"Coefficient nr2",-self.xi[1],logger=True,on_epoch=True,on_step=False)
        self.log(f"Error in coefficient nr1",torch.abs((-self.xi[0]-self.first_coeff)/self.first_coeff)*100,logger=True,on_epoch=True,on_step=False)
        self.log(f"Error in coefficient nr2",torch.abs((-self.xi[1]-self.second_coeff)/self.second_coeff)*100,logger=True,on_epoch=True,on_step=False)
    
        self.log("MSE Loss",mse_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)
        self.log("PDE Loss",pde_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)
        self.log("Total Loss",total_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)

        return total_loss

    def validation_step(self,batch,batch_idx):

        x,target = batch
        val_prediction,val_coordinates = self.forward(x)
        val_loss = torch.mean((val_prediction-target)**2)
        self.log("Validation Loss",val_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)

    def compute_derivatives(self,prediction,coords,derivs=False):

        du = grad(outputs=prediction, inputs=coords, grad_outputs=torch.ones_like(prediction), create_graph=True)[0]
        first_time_deriv = du[:,0:1]
        du_dx            = du[:,1:2]
        
        du2 = grad(outputs=first_time_deriv,inputs=coords,grad_outputs=torch.ones_like(prediction),create_graph=True)[0]
        second_time_deriv = du2[:,0:1]
        du2_dtdx          = du2[:,1:2]
        
        
        du2_dx2 = grad(outputs=du_dx,inputs=coords,grad_outputs=torch.ones_like(prediction),create_graph=True)[0][:,1:2]

        #term1 = -(self.c**2-self.v**2)*du2_dx2
        #term2 = +2*self.v*du2_dtdx

        term1 = du2_dx2
        term2 = du2_dtdx
        

        theta = torch.reshape(torch.cat((term1,term2),dim=1),(prediction.shape[0],-1))

        if not derivs:

            return second_time_deriv,theta,term1,term2

        else:

            return second_time_deriv,du2_dtdx,du2_dx2

    def least_squares_QR(self,theta,second_deriv):

        Q,R = torch.linalg.qr(theta)
        xi  = torch.inverse(R) @ Q.T @ second_deriv
        return xi
    
    def least_squares_SK(self,theta,second_time_deriv):
        
        x,y = theta.detach().cpu().numpy(), second_time_deriv.detach().cpu().detach()
        coefs = LinearRegression().fit(x,y).coef_
        return coefs[0]
        
    
    def configure_optimizers(self):

        optimizer = torch.optim.AdamW(self.parameters(),lr=self.lr,amsgrad=True,weight_decay=1e-8)
        #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(mode="min",factor=0.5,patience=500,threshold_mode="rel",threshold=1e-5)
        return optimizer

    def train_dataloader(self):

        dataset    = Dataset(data=self.X_train, labels=self.Y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=int(self.X_train.shape[0]/6),drop_last=True, num_workers=0, shuffle=True)

        return dataloader

    def val_dataloader(self):

        val_dataset = Dataset(data=self.X_validation, labels=self.Y_validation)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=int(self.X_validation.shape[0]/6),
                                                     drop_last=True, num_workers=0, shuffle=False)
        
        return val_dataloader

    def analytical_derivs(self,x,time):

        if self.coiso is None:

            self.omega = np.zeros(self.coefs.shape[0])
            self.k_rev = np.zeros(self.coefs.shape[0])
            self.k_fwd = np.zeros(self.coefs.shape[0])
            self.phi   = np.zeros(self.coefs.shape[0])
            self.coiso = 1

            for i in range(self.coefs.shape[0]):

                n = i+1
                self.omega[i] = n*np.pi*(self.c**2-self.v**2)/self.c
                self.k_rev[i] = n*np.pi*(self.c+self.v)/self.c
                self.k_fwd[i] = n*np.pi*(self.c-self.v)/self.c
                self.phi[i]   = -n*(np.pi*(self.c+self.v)/(2*self.c)-np.pi/2)



        du2_dx2  = np.zeros((len(time),len(x)))
        du2_dxdt = np.zeros((len(time),len(x)))
        du2_dt2  = np.zeros((len(time),len(x)))

        for b,t in enumerate(time):

            sum1  = 0
            sum2  = 0
            sum3  = 0

            for n in range(self.coefs.shape[0]):

                C1     = self.coefs[n] * np.sin(self.k_fwd[n]*x - self.omega[n]*t - self.phi[n])
                C2     = self.coefs[n] * np.sin(self.k_rev[n]*x + self.omega[n]*t + self.phi[n])

                sum1 += -self.k_fwd[n]**2*C1 - self.k_rev[n]**2*C2
                sum2 += +self.k_fwd[n]*self.omega[n]*C1  - self.k_rev[n]*self.omega[n]*C2
                sum3 += -self.omega[n]**2*C1 - self.omega[n]**2*C2

            du2_dx2[b,:]  = sum1
            du2_dxdt[b,:] = sum2
            du2_dt2[b,:]  = sum3

        #this analytical derivatives should be [T,X]
        return du2_dx2,du2_dxdt,du2_dt2

    def plot(self):


        total_output = self.forward(self.total_X.to(self.device))[0].detach().cpu().numpy().reshape(self.T.shape[0],self.total_x.shape[0])
        fig,ax1 = plt.subplots(figsize=(20,7))
        extent = np.min(self.total_x.detach().cpu().numpy()), np.max(self.total_x.detach().cpu().numpy()), np.min(self.T.detach().cpu().numpy()), np.max(self.T.detach().cpu().numpy())
        im1 = plt.imshow(total_output,origin="lower",cmap=plt.cm.magma,extent=extent,aspect="auto",interpolation=None)
        plt.colorbar()
        plt.xlabel("X")
        plt.ylabel("Time")
        plt.title(f"Full predicted wave at epoch {self.current_epoch}")
        image = self.fig2img(fig)
        self.logger.experiment.log_image(image,name=f"Full predicted wave")
        plt.close()


        fig,ax1 = plt.subplots(figsize=(20,7))
        extent = np.min(self.total_x.detach().cpu().numpy()), np.max(self.total_x.detach().cpu().numpy()), np.min(self.T.detach().cpu().numpy()), np.max(self.T.detach().cpu().numpy())
        im1 = plt.imshow(np.abs(total_output-self.U),origin="lower",cmap=plt.cm.magma,extent=extent,aspect="auto",interpolation=None)
        plt.colorbar()
        plt.xlabel("X")
        plt.ylabel("Time")
        plt.title(f"Difference between true wave and predicted wave at epoch {self.current_epoch}")
        image = self.fig2img(fig)
        self.logger.experiment.log_image(image,name=f"Difference between true wave and predicted wave")
        plt.close()

    def anim(self):

        def init():
            for line in lines:
                line.set_data([],[])
            return lines

        fig = plt.figure(figsize=(15,5))
        ax = plt.axes(xlim=(0,1),ylim=(-3,3))
        line, = ax.plot([], [], lw=3)
        plt.xlim(0,1)
        plt.ylim(-3,3)
        plt.xlabel("X")
        plt.ylabel("Displacement")
        plt.title(f"Prediction for epoch {self.current_epoch}",fontsize=20)
        lines = []
        lobj1 = ax.plot([],[],lw=3,label="Predicted")[0]
        lobj2 = ax.plot([],[],lw=3,label="Real")[0]
        lines.append(lobj1)
        lines.append(lobj2)
        plt.legend()

        anim = FuncAnimation(fig,
                    self.update_plot,
                    init_func=init,
                    frames=int(len(self.T)),
                    fargs=(lines),
                     blit=True,
                    interval=100,
                    repeat=True)

        plt.legend()
        anim.save(f"{self.dir}/predictions/{self.filename}_epoch{self.current_epoch}_{now.minute}{now.second}" + ".gif",writer=PillowWriter(fps=24))
        #image = self.fig2img(anim)
        #self.logger.experiment.log_image(image,name=f"Prediction fig at {self.current_epoch}")
        plt.close()
        self.logger.experiment.log_image(f"{self.dir}/predictions/{self.filename}_epoch{self.current_epoch}_{now.minute}{now.second}" + ".gif",name=f"Prediction gif at {self.current_epoch}")

    def on_validation_epoch_end(self):
        
        self.plot()
        self.anim()
      
        """with torch.enable_grad():
            predictions,coordinates = self.forward(self.X_train.to(self.device))

            #analytical derivatives--> stay constant only need to compute 1 time
            if self.analytical_du2_dx2 == None:
                self.analy_du2_dx2,self.analy_du2_dxdt,self.analy_du2_dt2 = self.analytical_derivs(self.total_x,self.T[self.T<=2])

            #auto diff derivatives
            NN_du2_dt2,NN_du2_dxdt,NN_du2_dx2          = self.compute_derivatives(predictions,coordinates,derivs=True)

        NN_du2_dt2 = NN_du2_dt2.detach().cpu().numpy().reshape(-1,self.total_x.shape[0])
        NN_du2_dxdt= NN_du2_dxdt.detach().cpu().numpy().reshape(-1,self.total_x.shape[0])
        NN_du2_dx2 = NN_du2_dx2.detach().cpu().numpy().reshape(-1,self.total_x.shape[0])

        dif1 = 100 * np.median(np.abs((NN_du2_dt2[self.analy_du2_dt2>1e-5]  - self.analy_du2_dt2[self.analy_du2_dt2>1e-5])/self.analy_du2_dt2[self.analy_du2_dt2>1e-5]))
        dif2 = 100 * np.median(np.abs((NN_du2_dx2[self.analy_du2_dx2>1e-5]  - self.analy_du2_dx2[self.analy_du2_dx2>1e-5])/self.analy_du2_dx2[self.analy_du2_dx2>1e-5]))
        dif3 = 100 * np.median(np.abs((NN_du2_dxdt[self.analy_du2_dxdt>1e-5] - self.analy_du2_dxdt[self.analy_du2_dxdt>1e-5])/self.analy_du2_dxdt[self.analy_du2_dxdt>1e-5]))

        abs_dif1 = np.mean(np.abs(NN_du2_dt2 - self.analy_du2_dt2))
        abs_dif2 = np.mean(np.abs(NN_du2_dx2 - self.analy_du2_dx2))
        abs_dif3 = np.mean(np.abs(NN_du2_dxdt- self.analy_du2_dxdt))

        self.log("Absolute Diff1",abs_dif1,on_step=False,on_epoch=True,prog_bar=False,logger=True)
        self.log("Relative Diff1",dif1,on_step=False,on_epoch=True,prog_bar=False,logger=True)
        self.log("NN_du2_dt2",np.mean(np.abs(NN_du2_dt2)),on_step=False,on_epoch=True,logger=True,prog_bar=False)
        self.log("Analy_du2_dt2",np.mean(np.abs(self.analy_du2_dt2)),on_step=False,on_epoch=True,logger=True,prog_bar=False)

        self.log("Absolute Diff2",abs_dif2,on_step=False,on_epoch=True,prog_bar=False,logger=True)
        self.log("Relative Diff2",dif2,on_step=False,on_epoch=True,prog_bar=False,logger=True)
        self.log("NN_du2_dx2",np.mean(np.abs(NN_du2_dx2)),on_step=False,on_epoch=True,logger=True,prog_bar=False)
        self.log("Analy_du2_dx2",np.mean(np.abs(self.analy_du2_dx2)),on_step=False,on_epoch=True,logger=True,prog_bar=False)

        self.log("Absolute Diff3",abs_dif3,on_step=False,on_epoch=True,prog_bar=False,logger=True)
        self.log("Relative Diff3",dif3,on_step=False,on_epoch=True,prog_bar=False,logger=True)
        self.log("NN_du2_dxdt",np.mean(np.abs(NN_du2_dxdt)),on_step=False,on_epoch=True,logger=True,prog_bar=False)
        self.log("Analy_du2_dxdt",np.mean(np.abs(self.analy_du2_dxdt)),on_step=False,on_epoch=True,logger=True,prog_bar=False)"""

 

    def update_plot(self,i,line1,line2):

        a = self.forward(self.total_X.to(self.device))[0].detach().cpu().numpy()
        b   = a.reshape(self.T.shape[0],self.total_x.shape[0])
        line1.set_data(self.total_x,b[i,:])
        line2.set_data(self.total_x,self.U[i,:])
        lines = (line1,line2)

In [24]:
class PINN(LightningModule):

    def __init__(self,filename,config):

        super().__init__()
        
        self.data            = np.load(filename)
        a = filename.split('/')
        self.dir = a[0]
        self.filename = a[-1].split('.npz')[0]
        
        
        
        
        
        self.total_X         = torch.tensor(self.data["total_X"],dtype=torch.float32)
        self.total_X         = self.total_X[self.total_X[:,0]<=4]
        self.total_Y         = torch.tensor(self.data["total_Y"],dtype=torch.float32)
        self.total_Y         = self.total_Y[:self.total_X.shape[0]]
        
        self.X1              = self.total_X[self.total_X[:,0]<=2]
        self.Y1              = self.total_Y[self.total_X[:,0]<=2]
        
        self.X_validation    = self.total_X[self.total_X[:,0]>=2]
        self.Y_validation    = self.total_Y[self.X_validation.shape[0]:]
        
        self.X_fixed_points  = self.X_validation[(self.X_validation[:,1]== 0)  |
                                                 #(self.X_validation[:,1]== 0.1)|
                                                 #(self.X_validation[:,1]== 0.2)|
                                                 #(self.X_validation[:,1]== 0.3)|
                                                 #(self.X_validation[:,1]== 0.4)|
                                                 #(self.X_validation[:,1]== 0.5)|
                                                 #(self.X_validation[:,1]== 0.6)|
                                                 #(self.X_validation[:,1]== 0.7)|
                                                 #(self.X_validation[:,1]== 0.8)|
                                                 #(self.X_validation[:,1]== 0.9)|
                                                 (self.X_validation[:,1]== 1)
                                                ]
        
        self.Y_fixed_points  = self.Y_validation[(self.X_validation[:,1]== 0) |
                                                 #(self.X_validation[:,1]== 0.1)|
                                                 #(self.X_validation[:,1]== 0.2)|
                                                 #(self.X_validation[:,1]== 0.3)|
                                                 #(self.X_validation[:,1]== 0.4)|
                                                 #(self.X_validation[:,1]== 0.5)|
                                                 #(self.X_validation[:,1]== 0.6)|
                                                 #(self.X_validation[:,1]== 0.7)|
                                                 #(self.X_validation[:,1]== 0.8)|
                                                 #(self.X_validation[:,1]== 0.9)|
                                                 (self.X_validation[:,1]== 1)
                                                 ]
        
        self.X_MSE           = self.X_fixed_points
        self.Y_MSE           = self.Y_fixed_points
        
        self.c               = torch.tensor(self.data["c"],dtype=torch.float32)
        self.v               = torch.tensor(self.data["v"],dtype=torch.float32)
        self.total_x         = torch.tensor(self.data["X"],dtype=torch.float32)
        self.T               = torch.tensor(self.data["t"],dtype=torch.float32)
        self.T               = self.T[self.T<=4]
        
        self.U               = np.array(self.data["wave"])
        self.U               = self.U[:self.T.shape[0]]
        self.coefs           = self.data["coefs"]
        
        self.first_coeff  = -(self.c**2-self.v**2)
        self.second_coeff = +2*self.v
        

        
        self.coiso              = None
        self.xi                 = None
        self.analytical_du2_dx2 = None
        self.lr                 = config["lr"]
        self.k_pde              = config["k_pde"]
        self.k_mse              = config["k_mse"]
        self.network            = NN(init=False)
        

        
    def fig2img(self,fig):
        """Convert a Matplotlib figure to a PIL Image and return it"""
        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        img = Image.open(buf)
        return img
        
    def forward(self,input_):
                
        return self.network(input_)
    
    def on_train_start(self):
        
        fig,ax1 = plt.subplots(figsize=(20,7))
        extent = np.min(self.total_x.detach().cpu().numpy()), np.max(self.total_x.detach().cpu().numpy()), np.min(self.T.detach().cpu().numpy()), np.max(self.T.detach().cpu().numpy())
        im1 = plt.imshow(self.U,origin="lower",cmap=plt.cm.magma,extent=extent,aspect="auto",interpolation=None)
        plt.colorbar()
        plt.xlabel("X")
        plt.ylabel("Time")
        plt.title(f"Real wave")
        image=self.fig2img(fig)
        self.logger.experiment.log_image(image,name="Real Wave")
        plt.close()
        self.anim()

    def training_step(self, batch,batch_idx):
        
        
        #forward pass
        batch1 = batch["mse"]
        batch2 = batch["pde"]

        x1,target1 = batch1
        x2,target2 = batch2

        prediction1,coordinates1 = self.forward(x1)
        prediction2,coordinates2 = self.forward(x2)

        second_time_deriv,theta,term1,term2 = self.compute_derivatives(prediction2,coordinates2)

        self.xi = self.least_squares_QR(theta,second_time_deriv)

        #losses
        mse_loss                            = torch.mean((prediction1-target1)**2) # scalar
        pde_loss                            = torch.mean((second_time_deriv+term1+term2)**2)#should be scalar as well
        total_loss                          = self.k_mse*mse_loss  + self.k_pde*pde_loss


        self.log(f"Coefficient nr1",self.xi[0],logger=True,on_epoch=True,on_step=False)
        self.log(f"Coefficient nr2",self.xi[1],logger=True,on_epoch=True,on_step=False)
        self.log(f"Error in coefficient nr1",torch.abs((-self.xi[0]-self.first_coeff)/self.first_coeff)*100,logger=True,on_epoch=True,on_step=False)
        self.log(f"Error in coefficient nr2",torch.abs((-self.xi[1]-self.second_coeff)/self.second_coeff)*100,logger=True,on_epoch=True,on_step=False)
    
        self.log("MSE Loss",mse_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)
        self.log("PDE Loss",pde_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)
        self.log("Total Loss",total_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)      
        
        return total_loss

    def validation_step(self,batch,batch_idx):

        x,target = batch
        val_prediction,val_coordinates = self.forward(x)
        val_loss = torch.mean((val_prediction-target)**2)
        self.log("Validation Loss",val_loss,on_step=False,on_epoch=True,logger=True,prog_bar=True)


    def compute_derivatives(self,prediction,coords,derivs=False):
        
        du = grad(outputs=prediction, inputs=coords, grad_outputs=torch.ones_like(prediction), create_graph=True)[0]
        first_time_deriv = du[:,0:1]
        du_dx            = du[:,1:2]

        du2 = grad(outputs=first_time_deriv,inputs=coords,grad_outputs=torch.ones_like(prediction),create_graph=True)[0]
        second_time_deriv = du2[:,0:1]
        du2_dtdx          = du2[:,1:2]

        du2_dx2 = grad(outputs=du_dx,inputs=coords,grad_outputs=torch.ones_like(prediction),create_graph=True)[0][:,1:2]

        term1 = -(self.c**2-self.v**2)*du2_dx2
        term2 = +2*self.v*du2_dtdx

        theta = torch.reshape(torch.cat((term1,term2),dim=1),(prediction.shape[0],-1))

        if not derivs:

            return second_time_deriv,theta,term1,term2

        else:

            return second_time_deriv,du2_dtdx,du2_dx2


        
    def least_squares_QR(self,theta,second_deriv):

        Q,R = torch.linalg.qr(theta)
        xi  = torch.inverse(R) @ Q.T @ second_deriv
        return xi
    
    def configure_optimizers(self):

        optimizer = torch.optim.AdamW(self.parameters(),lr=self.lr,amsgrad=True,weight_decay=1e-8)

        return optimizer

    def train_dataloader(self):

        dataset1    = Dataset(data=self.X_fixed_points, labels=self.Y_fixed_points)
        dataset2   = Dataset(data=self.X_validation,labels = self.Y_validation)

        dataloader1 = torch.utils.data.DataLoader(dataset1, batch_size=int(self.X_fixed_points.shape[0]/4),drop_last=True, shuffle=True,num_workers=0)
        dataloader2 = torch.utils.data.DataLoader(dataset2, batch_size=int(self.X_validation.shape[0]/4),drop_last=True,shuffle=True,num_workers=0)



        return {"mse":dataloader1,"pde":dataloader2}

    def val_dataloader(self):

        val_dataset = Dataset(data=self.X_validation, labels=self.Y_validation)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=int(self.X_validation.shape[0]/4),
                                                     drop_last=True, num_workers=0, shuffle=False)

        return val_dataloader

    def plot(self):

        
        total_output = self.forward(self.total_X.to(self.device))[0].detach().cpu().numpy().reshape(self.T.shape[0],self.total_x.shape[0])
        fig,ax1 = plt.subplots(figsize=(20,7))
        extent = np.min(self.total_x.detach().cpu().numpy()), np.max(self.total_x.detach().cpu().numpy()), np.min(self.T.detach().cpu().numpy()), np.max(self.T.detach().cpu().numpy())
        im1 = plt.imshow(total_output,origin="lower",cmap=plt.cm.magma,extent=extent,aspect="auto",interpolation=None)
        plt.colorbar()
        plt.xlabel("X")
        plt.ylabel("Time")
        image = self.fig2img(fig)
        self.logger.experiment.log_image(image,name=f"Full predicted wave")
        plt.close()
        
        
        fig,ax1 = plt.subplots(figsize=(20,7))
        extent = np.min(self.total_x.detach().cpu().numpy()), np.max(self.total_x.detach().cpu().numpy()), np.min(self.T.detach().cpu().numpy()), np.max(self.T.detach().cpu().numpy())
        im1 = plt.imshow(np.abs(total_output-self.U),origin="lower",cmap=plt.cm.magma,extent=extent,aspect="auto",interpolation=None)
        plt.colorbar()
        plt.xlabel("X")
        plt.ylabel("Time")
        plt.title(f"Difference between true wave and predicted wave at epoch {self.current_epoch}")
        image = self.fig2img(fig)
        self.logger.experiment.log_image(image,name=f"Difference between true wave and predicted wave")
        plt.close()
    
    def anim(self):
        
        def init():
            for line in lines:
                line.set_data([],[])
            return lines

        fig = plt.figure(figsize=(15,5))
        ax = plt.axes(xlim=(0,1),ylim=(-3,3))
        line, = ax.plot([], [], lw=3)
        plt.xlim(0,1)
        plt.ylim(-3,3)
        plt.xlabel("X")
        plt.ylabel("Displacement")
        plt.title(f"Prediction for epoch {self.current_epoch}",fontsize=20)
        lines = []
        lobj1 = ax.plot([],[],lw=3,label="Predicted")[0]
        lobj2 = ax.plot([],[],lw=3,label="Real")[0]
        lines.append(lobj1)
        lines.append(lobj2)
        plt.legend()
        
        
        anim = FuncAnimation(fig,
                    self.update_plot,
                    init_func=init,
                    frames=int(len(self.T)),
                    fargs=(lines),
                     blit=True,
                    interval=100,
                    repeat=True)
        
        plt.legend()
        writergif = animation.PillowWriter(fps=30)

        anim.save(f"predictions/{self.filename}_epoch{self.current_epoch}_{now.minute}{now.second}" + ".gif",writer="ffmpeg")
        plt.close()
        self.logger.experiment.log_image(f"predictions/{self.filename}_epoch{self.current_epoch}_{now.minute}{now.second}" + ".gif",name=f"Prediction gif at {self.current_epoch}")
    
    def update_plot(self,i,line1,line2):
        
        a = self.forward(self.total_X.to(self.device))[0].detach().cpu().numpy()
        b   = a.reshape(self.T.shape[0],self.total_x.shape[0])
        line1.set_data(self.total_x,b[i,:])
        line2.set_data(self.total_x,self.U[i,:])
        lines = (line1,line2)
        
        return lines    
    
    def on_validation_epoch_end(self):
        
        self.plot()
        self.anim()
    
        

In [25]:
from pytorch_lightning.callbacks import ModelCheckpoint


def training(filename,PATH):
    
    print(filename)
    save_name = filename.split("/")[-1]
    save_name = save_name.split(".npz")[0]

    config = {
        "lr": 1e-3,
        "filename":filename,
        "k_pde":0.01,
        "k_mse":5,
        "PATH":PATH
    } 
    
    model = PINN(filename=filename,config=config)
    model.load_from_checkpoint(filename=filename,config=config,checkpoint_path=config["PATH"])
    model.to("cuda:0")
    print(model)
    

    comet_logger = CometLogger(
        api_key="kZhkiprqabfgQqOTbHNHpOJvf",
        workspace="jose-bastos",
        project_name="resolution2"  # Optional
    )

    comet_logger.log_hyperparams(config)
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    kwargs = {"max_epochs": 25*10**3+10,
              "accelerator": "gpu",
              "devices":1,
              "num_sanity_val_steps": 0,
              "logger": comet_logger,
              "check_val_every_n_epoch":1000,
              "enable_checkpointing":False,
              "enable_progress_bar":False
             }

    trainer = Trainer(**kwargs,resume_from_checkpoint=config["PATH"])
    trainer.fit(model)
    torch.save(model,f"resolution/models/PINNs/{filename}.ckpt")
    
    comet_logger.experiment.end()

In [None]:
training(filename="resolution/data/spatial/c=1_v=0.5c_noise=0.0_modes=0.08_0.49_0.85_0.35_0.15_dt=0.010_dx=0.010.npz",PATH="resolution/models/NN/spatial/c=1_v=0.5c_noise=0.0_modes=0.08_0.49_0.85_0.35_0.15_dt=0.010_dx=0.100_15kepochs.ckpt")

CometLogger will be initialized in online mode
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/jose-bastos/resolution2/03003880ea0042eba4c5a83a43cf7f35
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     Coefficient nr1 [2998]          : (-1.0002986192703247, -0.4868721663951874)
COMET INFO:     Coefficient nr2 [2998]          : (-1.0003142356872559, -0.04561415687203407)
COMET INFO:     Error in coefficient nr1 [2998] : (164.91629028320312, 233.37315368652344)
COMET INFO:     Error in coefficient nr2 [2998] : (0.000667572021484375, 95.43858337402344)
COMET INFO:     MSE Loss [2998]                 : (0.00012791056360583752, 1.559272289276123)
COMET INFO:     PDE Loss [2998]                 : (0.07254322618246078, 4531.80029296875)
COMET INFO:     Total Loss [2998]               : (0.001

resolution/data/spatial/c=1_v=0.5c_noise=0.0_modes=0.08_0.49_0.85_0.35_0.15_dt=0.010_dx=0.010.npz
PINN(
  (network): NN(
    (network): Sequential(
      (0): Linear(in_features=2, out_features=50, bias=True)
      (1): SinusoidalActivation()
      (2): Linear(in_features=50, out_features=50, bias=True)
      (3): Tanh()
      (4): Linear(in_features=50, out_features=50, bias=True)
      (5): Tanh()
      (6): Linear(in_features=50, out_features=50, bias=True)
      (7): Tanh()
      (8): Linear(in_features=50, out_features=50, bias=True)
      (9): Tanh()
      (10): Linear(in_features=50, out_features=1, bias=True)
    )
  )
)


COMET INFO: Experiment is live on comet.com https://www.comet.com/jose-bastos/resolution2/0070d4384c1145d993736f54a9b47666

  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  ckpt_path = ckpt_path or self.resume_from_checkpoint
Restoring states from the checkpoint path at resolution/models/NN/spatial/c=1_v=0.5c_noise=0.0_modes=0.08_0.49_0.85_0.35_0.15_dt=0.010_dx=0.100_15kepochs.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type | Params
---------------------------------
0 | network | NN   | 10.4 K
---------------------------------
10.4 K    Trainable params
0         Non-trainable params
10.4 K    Total params
0.042     Total estimated model params size (MB)
Restored all states from the checkpoint file at resolution/models/NN/spatial/c=1_v=0.5c_noise=0.0_modes=0.08_0.49_0.85_0.35_0.15_dt=0.010_dx=0.100_15kepochs.ckpt
  rank_zero_warn(
  rank_z