In [1]:
import numpy as np

import matplotlib.pyplot as plt
import gif
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

from pytorch_lightning import LightningModule
import torch
import torch.nn as nn
import math
from deepymod.data.base_lightning import Dataset
from torch.autograd import grad
import numpy as np
from neptune.new.types import File
import matplotlib.pyplot as plt
from datetime import datetime
from matplotlib.animation import FuncAnimation
from sklearn.linear_model import LassoCV,Lasso

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor,TQDMProgressBar,EarlyStopping
from neptune.new.integrations.pytorch_lightning import NeptuneLogger


def gpu_prints():
    """Print how many CUDA devices torch can see, then the name of each one."""
    device_count = torch.cuda.device_count()
    print("The total number of GPUs is:", device_count)
    for index in range(device_count):
        print("GPU number", index, "is", torch.cuda.get_device_name(index))


gpu_prints()
# Timestamp captured once when this cell runs; PINN.anim embeds now.minute/now.second in GIF file names.
now = datetime.now()
dt_string = f"{now:%d/%m/%Y %H:%M:%S}"

  from IPython.core.display import display, HTML


The total number of GPUs is: 2
GPU number 0 is NVIDIA GeForce GTX 1080 Ti
GPU number 1 is NVIDIA GeForce GTX 1080 Ti


In [2]:
class SinusoidalActivation(nn.Module):
    """Element-wise activation ``sin(2 * pi * input)``.

    ``pi`` is registered as a *non-persistent* buffer. It therefore follows the module through
    ``.to(device)`` / ``.cuda()``; a plain tensor attribute stayed on the CPU, and a shape-(1,)
    CPU tensor cannot be multiplied with CUDA inputs. Because it is non-persistent, it is kept
    out of ``state_dict``, so checkpoints saved before this change still load with strict key
    matching.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer(
            "pi",
            torch.tensor([math.pi], dtype=torch.float32, device="cpu"),
            persistent=False,
        )

    def forward(self,input):
        return torch.sin(2 * self.pi * input)
    

class NN(nn.Module):
    """Fully connected network built from ``nn.Linear`` layers.

    The first hidden activation is ``SinusoidalActivation`` and later ones are ``nn.Tanh``.
    The final Linear layer has no activation. ``forward`` returns ``(output, input_)``, where
    ``input_`` has had ``requires_grad`` switched on in place, so callers can differentiate
    ``output`` with respect to it.
    """

    def __init__(self, n_in: int, n_hidden, n_out: int, init: bool) -> None:
        super().__init__()
        # must be set before build_network, which reads it
        self.init = init
        self.network = self.build_network(n_in, n_hidden, n_out)

    def forward(self, input_: torch.Tensor):
        coords = input_.requires_grad_(True)
        output = self.network(coords)
        return output, coords

    def init_weights(self, m):
        """Xavier-uniform weights and zero biases for ``nn.Linear``; other modules are ignored."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight)
        torch.nn.init.zeros_(m.bias)

    def build_network(self, n_in, n_hidden, n_out):
        """Return an ``nn.Sequential`` with sizes ``[n_in] + n_hidden + [n_out]``."""
        sizes = [n_in] + n_hidden + [n_out]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(SinusoidalActivation() if depth == 0 else nn.Tanh())
        del layers[-1]  # the output layer stays linear
        network = nn.Sequential(*layers)
        apply_custom_init = self.init == True
        if apply_custom_init:
            network.apply(self.init_weights)
        return network
    
class PINN(LightningModule):
    """Physics-informed network u(t, x) for a vibrating string moving with horizontal speed v.

    Training combines a data-fit (MSE) term on a few fixed spatial positions with the squared
    residual of

        u_tt - (c**2 - v**2) * u_xx + 2 * v * u_xt = 0

    on every collocation point with t >= t_split. Heatmaps and GIFs of the prediction are
    logged through ``self.logger.experiment["images/"]``.
    """

    def __init__(self,filename,config):
        """
        Args:
            filename: stem of the ``string_data/<filename>.npz`` file to load.
            config: dict. Required keys are ``"lr"``, ``"k_pde"`` and ``"k_mse"``. Optional keys,
                whose defaults reproduce the original hard-coded behaviour:
                ``"t_max"`` (4): only rows with t <= t_max are used;
                ``"t_split"`` (2): X1/Y1 hold t <= t_split, X_validation/Y_validation t >= t_split;
                ``"fixed_x"`` ((0, 1)): x positions whose t >= t_split data feeds the MSE loss;
                ``"plot_every"`` (500): epoch interval for logging heatmaps and GIFs.
        """
        super().__init__()
        self.filename = filename
        self.data = np.load(f"string_data/{self.filename}.npz")

        t_max = config.get("t_max", 4)
        t_split = config.get("t_split", 2)

        # Column 0 of total_X is time, column 1 is position. One mask is applied to inputs
        # and labels so the rows stay paired.
        all_X = torch.tensor(self.data["total_X"], dtype=torch.float32)
        all_Y = torch.tensor(self.data["total_Y"], dtype=torch.float32)
        in_window = all_X[:, 0] <= t_max
        self.total_X = all_X[in_window]
        self.total_Y = all_Y[in_window]

        time = self.total_X[:, 0]
        train_mask = time <= t_split
        val_mask = time >= t_split

        self.X1 = self.total_X[train_mask]
        self.Y1 = self.total_Y[train_mask]

        # The labels use the same boolean mask as the inputs. The former positional slice
        # total_Y[len(X_validation):] only lined up with X_validation when exactly half of
        # the rows had t >= t_split.
        self.X_validation = self.total_X[val_mask]
        self.Y_validation = self.total_Y[val_mask]

        # Spatial positions whose t >= t_split time series is supervised by the MSE loss.
        # The default is x = 0 and x = 1. Pass e.g. [0, 0.1, ..., 0.9, 1] as config["fixed_x"]
        # to add interior points. Exact float equality is used, as before.
        x_val = self.X_validation[:, 1]
        fixed_mask = torch.zeros_like(x_val, dtype=torch.bool)
        for position in config.get("fixed_x", (0, 1)):
            fixed_mask |= x_val == position
        self.X_fixed_points = self.X_validation[fixed_mask]
        self.Y_fixed_points = self.Y_validation[fixed_mask]

        # To also fit the t <= t_split data, concatenate self.X1 / self.Y1 here with torch.cat.
        self.X_MSE = self.X_fixed_points
        self.Y_MSE = self.Y_fixed_points

        self.c = torch.tensor(self.data["c"], dtype=torch.float32)
        self.v = torch.tensor(self.data["v"], dtype=torch.float32)
        self.total_x = torch.tensor(self.data["X"], dtype=torch.float32)
        all_t = torch.tensor(self.data["t"], dtype=torch.float32)
        self.T = all_t[all_t <= t_max]

        # Reference wave; assumes its rows are ordered like "t" (TODO confirm against the data generator).
        self.U = np.array(self.data["wave"])[: self.T.shape[0]]
        self.coefs = self.data["coefs"]

        self.coiso = None  # becomes 1 once analytical_derivs has built its per-mode constants
        self.xi = None
        self.analytical_du2_dx2 = None
        self.lr = config["lr"]
        self.k_pde = config["k_pde"]
        self.k_mse = config["k_mse"]
        self.plot_every = config.get("plot_every", 500)
        self.network = NN(n_in=2, n_hidden=4 * [50], n_out=1, init=False)

    def forward(self,input_):
        """Return ``(u, coords)`` from the wrapped ``NN`` for inputs whose columns are (t, x)."""
        return self.network(input_)

    # ---- plotting helpers -------------------------------------------------------------

    def _grid_extent(self):
        """imshow extent (x_min, x_max, t_min, t_max) of the (time, x) grid."""
        x = self.total_x.detach().cpu().numpy()
        t = self.T.detach().cpu().numpy()
        return np.min(x), np.max(x), np.min(t), np.max(t)

    def _predict_grid(self):
        """Network prediction on all of ``total_X``, reshaped to (len(T), len(total_x))."""
        # detach() gives a fresh tensor, so NN.forward's in-place requires_grad_ does not
        # flag self.total_X itself (on CPU, .to(device) returns the same tensor). no_grad
        # avoids building a graph that is never used.
        with torch.no_grad():
            inputs = self.total_X.to(self.device).detach()
            prediction = self.forward(inputs)[0]
        return prediction.cpu().numpy().reshape(self.T.shape[0], self.total_x.shape[0])

    def _log_heatmap(self, field, title):
        """Log a (time, x) array as a magma heatmap to the logger's ``images/`` series."""
        fig, ax = plt.subplots(figsize=(20, 7))
        image = ax.imshow(field, origin="lower", cmap=plt.cm.magma, extent=self._grid_extent(),
                          aspect="auto", interpolation=None)
        fig.colorbar(image, ax=ax)
        ax.set_xlabel("X")
        ax.set_ylabel("Time")
        ax.set_title(title)
        self.logger.experiment["images/"].log(File.as_image(fig))
        plt.close(fig)

    # ---- Lightning hooks --------------------------------------------------------------

    def on_train_start(self):
        """Log the reference wave once before training starts."""
        self._log_heatmap(self.U, "Real wave")

    def training_step(self, batch,batch_idx):
        """One step on a data-fit batch ("mse") and a collocation batch ("pde").

        loss = k_mse * mean((u - target)**2) on "mse"
             + k_pde * mean((u_tt + term1 + term2)**2) on "pde".
        The least-squares coefficients of u_tt on [term1, term2] are logged for monitoring
        only. They should approach (-1, -1) as the residual vanishes.
        """
        x_mse, target_mse = batch["mse"]
        x_pde, _ = batch["pde"]  # labels of the collocation batch are not used

        prediction_mse, _ = self.forward(x_mse)
        prediction_pde, coords_pde = self.forward(x_pde)

        second_time_deriv, theta, term1, term2 = self.compute_derivatives(prediction_pde, coords_pde)

        # xi is never back-propagated, so keep the QR/inverse out of the autograd graph.
        xi = self.least_squares_QR(theta.detach(), second_time_deriv.detach())

        mse_loss = torch.mean((prediction_mse - target_mse) ** 2)
        pde_loss = torch.mean((second_time_deriv + term1 + term2) ** 2)
        total_loss = self.k_mse * mse_loss + self.k_pde * pde_loss

        for i, coefficient in enumerate(xi):
            self.log(f"Coefficient nr{i}", coefficient, logger=True, on_epoch=True, on_step=False)

        self.log("MSE Loss", mse_loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log("PDE Loss", pde_loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log("Total Loss", total_loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)

        return total_loss

    def validation_step(self,batch,batch_idx):
        """Log the MSE between the prediction and the labels of the validation (t >= t_split) split."""
        x, target = batch
        val_prediction, _ = self.forward(x)
        val_loss = torch.mean((val_prediction - target) ** 2)
        self.log("Validation Loss", val_loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)

    def compute_derivatives(self,prediction,coords,derivs=False):
        """Second-order autograd derivatives of ``prediction`` w.r.t. ``coords`` = (t, x).

        Returns:
            derivs=False: (u_tt, theta, term1, term2), where term1 = -(c**2 - v**2) * u_xx,
                term2 = 2 * v * u_xt and theta = [term1, term2] stacked column-wise.
            derivs=True: (u_tt, u_xt, u_xx).
        """
        ones = torch.ones_like(prediction)
        du = grad(outputs=prediction, inputs=coords, grad_outputs=ones, create_graph=True)[0]
        first_time_deriv = du[:, 0:1]
        du_dx = du[:, 1:2]

        du2 = grad(outputs=first_time_deriv, inputs=coords, grad_outputs=ones, create_graph=True)[0]
        second_time_deriv = du2[:, 0:1]
        du2_dtdx = du2[:, 1:2]

        du2_dx2 = grad(outputs=du_dx, inputs=coords, grad_outputs=ones, create_graph=True)[0][:, 1:2]

        term1 = -(self.c ** 2 - self.v ** 2) * du2_dx2
        term2 = +2 * self.v * du2_dtdx
        theta = torch.reshape(torch.cat((term1, term2), dim=1), (prediction.shape[0], -1))

        if derivs:
            return second_time_deriv, du2_dtdx, du2_dx2
        return second_time_deriv, theta, term1, term2

    def least_squares_QR(self,theta,second_deriv):
        """Least-squares xi of ``theta @ xi ~= second_deriv`` via a reduced QR factorisation."""
        Q, R = torch.linalg.qr(theta)
        xi = torch.inverse(R) @ Q.T @ second_deriv
        return xi

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-6)
        return optimizer

    def train_dataloader(self):
        """Full-batch loaders: "mse" = fixed-point data, "pde" = validation-split collocation points."""
        dataset = Dataset(data=self.X_MSE, labels=self.Y_MSE)
        dataset2 = Dataset(data=self.X_validation, labels=self.Y_validation)

        dataloader1 = torch.utils.data.DataLoader(dataset, batch_size=self.X_MSE.shape[0],
                                                  drop_last=True, shuffle=True, num_workers=0)
        dataloader2 = torch.utils.data.DataLoader(dataset2, batch_size=self.X_validation.shape[0],
                                                  drop_last=True, shuffle=True, num_workers=0)

        return {"mse": dataloader1, "pde": dataloader2}

    def val_dataloader(self):
        """Single full batch of the validation split, unshuffled."""
        val_dataset = Dataset(data=self.X_validation, labels=self.Y_validation)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=self.X_validation.shape[0],
                                                     drop_last=True, num_workers=0, shuffle=False)

        return val_dataloader

    def analytical_derivs(self,x,time):
        """Analytical u_xx, u_xt and u_tt of the modal solution on a (time, x) grid.

        The per-mode constants (omega, k_rev, k_fwd, phi) are built once, guarded by ``self.coiso``.
        The derivatives are those of
        sum_n coefs[n] * (sin(k_fwd[n] x - omega[n] t - phi[n]) + sin(k_rev[n] x + omega[n] t + phi[n])).

        Returns:
            (du2_dx2, du2_dxdt, du2_dt2), each a numpy array of shape (len(time), len(x)).
        """
        if self.coiso is None:

            self.omega = np.zeros(self.coefs.shape[0])
            self.k_rev = np.zeros(self.coefs.shape[0])
            self.k_fwd = np.zeros(self.coefs.shape[0])
            self.phi = np.zeros(self.coefs.shape[0])
            self.coiso = 1

            for i in range(self.coefs.shape[0]):

                n = i + 1
                self.omega[i] = n * np.pi * (self.c ** 2 - self.v ** 2) / self.c
                self.k_rev[i] = n * np.pi * (self.c + self.v) / self.c
                self.k_fwd[i] = n * np.pi * (self.c - self.v) / self.c
                self.phi[i] = -n * (np.pi * (self.c + self.v) / (2 * self.c) - np.pi / 2)

        du2_dx2 = np.zeros((len(time), len(x)))
        du2_dxdt = np.zeros((len(time), len(x)))
        du2_dt2 = np.zeros((len(time), len(x)))

        for b, t in enumerate(time):

            sum1 = 0
            sum2 = 0
            sum3 = 0

            for n in range(self.coefs.shape[0]):

                C1 = self.coefs[n] * np.sin(self.k_fwd[n] * x - self.omega[n] * t - self.phi[n])
                C2 = self.coefs[n] * np.sin(self.k_rev[n] * x + self.omega[n] * t + self.phi[n])

                sum1 += -self.k_fwd[n] ** 2 * C1 - self.k_rev[n] ** 2 * C2
                sum2 += +self.k_fwd[n] * self.omega[n] * C1 - self.k_rev[n] * self.omega[n] * C2
                sum3 += -self.omega[n] ** 2 * C1 - self.omega[n] ** 2 * C2

            du2_dx2[b, :] = sum1
            du2_dxdt[b, :] = sum2
            du2_dt2[b, :] = sum3

        # rows = time, columns = x
        return du2_dx2, du2_dxdt, du2_dt2

    def plot(self):
        """Log heatmaps of the predicted wave and of |prediction - U| for the current epoch."""
        prediction = self._predict_grid()
        self._log_heatmap(prediction, f"Full predicted wave at epoch {self.current_epoch}")
        self._log_heatmap(np.abs(prediction - self.U),
                          f"Difference between true wave and predicted wave at epoch {self.current_epoch}")

    def anim(self):
        """Save a GIF of predicted vs. real displacement over all time steps and log it."""
        # One forward pass for the whole animation instead of one per frame.
        prediction = self._predict_grid()

        fig, ax = plt.subplots(figsize=(15, 5))
        ax.set_xlim(0, 1)
        ax.set_ylim(-3, 3)
        ax.set_xlabel("X")
        ax.set_ylabel("Displacement")
        ax.set_title(f"Prediction for epoch {self.current_epoch}", fontsize=20)
        predicted_line, = ax.plot([], [], lw=3, label="Predicted")
        real_line, = ax.plot([], [], lw=3, label="Real")
        lines = [predicted_line, real_line]
        ax.legend()

        def init():
            for line in lines:
                line.set_data([], [])
            return lines

        animation = FuncAnimation(fig,
                                  self.update_plot,
                                  init_func=init,
                                  frames=len(self.T),
                                  fargs=(predicted_line, real_line, prediction),
                                  blit=True,
                                  interval=100,
                                  repeat=True)

        # `now` is the notebook-level timestamp taken when the first cell ran.
        gif_path = f"predictions/{self.filename}_epoch{self.current_epoch}_{now.minute}{now.second}" + ".gif"
        animation.save(gif_path, fps=30, writer="ffmpeg")
        plt.close(fig)
        self.logger.experiment["images/"].log(File(gif_path))

    def on_validation_epoch_end(self):
        """Every ``plot_every`` epochs, log prediction/error heatmaps and an animated GIF."""
        if self.current_epoch % self.plot_every == 0:
            self.plot()
            self.anim()

    def update_plot(self,i,line1,line2,prediction=None):
        """FuncAnimation callback: draw time step ``i`` of the prediction (line1) and of U (line2).

        ``prediction`` is the (len(T), len(total_x)) grid from ``_predict_grid``. When it is
        omitted, the grid is recomputed, which costs a full forward pass per call.
        """
        if prediction is None:
            prediction = self._predict_grid()
        line1.set_data(self.total_x, prediction[i, :])
        line2.set_data(self.total_x, self.U[i, :])
        return line1, line2

In [3]:
filename = "c=1_v=0.3c_noise=0.0_modes=0.39_0.26_t_final=10_nr=1000"

# Passed to PINN.__init__ as `config`; lr, k_pde and k_mse are read there.
config = dict(
    lr=1e-3,
    filename=filename,
    k_pde=0.01,
    k_mse=5,
)

# Rebuild the PINN (its __init__ reloads string_data/<filename>.npz) and restore the trained weights.
PATH = "models/PINN/Boundary_only_PIPE220.ckpt"
model = PINN.load_from_checkpoint(PATH, filename=filename, config=config)

In [4]:
# Reload the data file the model was built from and evaluate the trained network on the
# (t, x) grid restricted to t <= 4 (the same window PINN uses by default).
data = np.load(f"string_data/{filename}.npz")
t = data["t"]
t = t[t <= 4]
# Spatial grid. Do NOT truncate it to len(t): time and space are independent axes, and the
# reshape below needs the full spatial length.
x = data["X"]

train = data["X_train"]
validation = data["X_validation"]
total_input = torch.tensor(data["total_X"], dtype=torch.float32)
total_input = total_input[total_input[:, 0] <= 4]

# Inference only: no autograd graph is needed.
with torch.no_grad():
    data_array = model.forward(total_input.to("cpu"))[0].detach().cpu().numpy()
data_array = data_array.reshape(t.shape[0], x.shape[0])
print(data_array[:10])

real_wave = np.array(data["wave"])
real_wave = real_wave[: t.shape[0]]



[[ 0.01588874  0.03471116  0.05386645 ... -0.00357217 -0.00489361
  -0.00639827]
 [ 0.01331375  0.03309936  0.05322064 ... -0.00153685 -0.00343335
  -0.00552232]
 [ 0.01084834  0.03157524  0.05263724 ...  0.00052738 -0.00196231
  -0.00465251]
 ...
 [ 0.00110022  0.02602449  0.05123494 ...  0.01076978  0.00510537
  -0.0007906 ]
 [-0.00020228  0.02542342  0.05131581 ...  0.01271337  0.00638544
  -0.00017697]
 [-0.00127098  0.0249985   0.05151139 ...  0.01459729  0.00760132
   0.00036823]]


In [42]:
from matplotlib.offsetbox import AnchoredText

@gif.frame
def helper_plot(data_array,real_wave,x,i):
    """Draw frame ``i`` of the prediction-vs-ground-truth animation.

    Frames in the first half of the time axis (``i < data_array.shape[0] // 2``) are annotated as
    the training window. Later frames are annotated as extrapolation and also show the fixed
    training points at x=0 and x=1 as red crosses.
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    in_training_window = i < data_array.shape[0] // 2

    if in_training_window:
        truth_color, truth_label, truth_marker, truth_alpha = "red", "Ground truth", "x", 0.5
        phase_text = "Training t<2s"
        legend_order = [0, 1]
    else:
        truth_color, truth_label, truth_marker, truth_alpha = "blue", "Ground truth test data", None, 1
        phase_text = "Extrapolating t>2s "
        # Fixed training points; only the first carries a legend label.
        ax.plot([0], [0], marker="x", color="red", markersize=12, label="Ground truth training data")
        ax.plot([1], [0], marker="x", markersize=12, color="red")
        legend_order = [0, 2, 1]

    # Elapsed-time box (upper left), then the phase box (upper center)
    time_box = AnchoredText(f"t={i/100}s", prop=dict(size=15), frameon=True, loc='upper left')
    phase_box = AnchoredText(phase_text, prop=dict(size=15), frameon=True, loc='upper center')
    for box in (time_box, phase_box):
        box.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
        ax.add_artist(box)

    ax.plot(x, data_array[i, :], color="orange", linestyle="solid", lw=4, label="Model prediction")
    ax.plot(x, real_wave[i, :], color=truth_color, linestyle="dashed", lw=3, label=truth_label,
            marker=truth_marker, markersize=12, alpha=truth_alpha)

    ax.plot(0, 0)  # zero-length line kept from the original figure; it draws nothing visible
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-1.5, 1.5)
    ax.set_xlabel("X")
    ax.set_ylabel("Displacement")

    handles, labels = ax.get_legend_handles_labels()
    ax.legend([handles[k] for k in legend_order], [labels[k] for k in legend_order], fontsize=20)

    ax.grid()
    ax.set_title("Vibrating String with horizontal speed")
    

In [43]:
from functools import wraps
from io import BytesIO as Buffer
from matplotlib import pyplot as plt
from PIL import Image


# One PIL image per time step of the (time, x) prediction grid.
frames = [helper_plot(data_array, real_wave, x, i) for i in range(data_array.shape[0])]

# Fast preview: 10 ms per frame, looping forever (loop=0).
frames[0].save("test.gif", save_all=True, append_images=frames[1:], duration=10, loop=0)

# Slower copy: 50 ms per frame. Previously this cell tried to reopen test.gif into `f`, but
# that line was commented out, so `f.info[...]` raised NameError on a fresh kernel. Writing
# straight from `frames` also avoids re-encoding an already-quantised GIF.
frames[0].save("out.gif", save_all=True, append_images=frames[1:], duration=50, loop=0)