In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the Kaggle input directory.
for root, _, files in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in files):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax import jit,grad,vmap,value_and_grad,jacfwd,jacrev, random
from jax.example_libraries.optimizers import adam
from functools import partial
import numpy as np
import numpy.random as npr
import math
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import sklearn.metrics

initializer = jnn.initializers.glorot_normal()

def init_glorot_params(layer_sizes, key = random.PRNGKey(10)):
    """Create Glorot-normal weights and zero biases for a fully connected MLP.

    Args:
        layer_sizes: sequence of layer widths, e.g. ``[2, 128, ..., 1]``.
        key: JAX PRNG key. It is split into one sub-key per layer, so layers
            with the same shape do not receive identical weight matrices
            (reusing a single key for every layer made them all equal).

    Returns:
        List of ``(w, b)`` tuples, one per consecutive pair of widths
        ``(m, n)``: ``w`` has shape ``(m, n)`` (float32) and ``b`` is
        ``zeros(n)``. An empty list when fewer than two widths are given.
    """
    n_layers = max(len(layer_sizes) - 1, 0)
    if n_layers == 0:
        return []
    layer_keys = random.split(key, n_layers)
    return [(initializer(k, (m, n), jnp.float32), jnp.zeros(n))
            for k, m, n in zip(layer_keys, layer_sizes[:-1], layer_sizes[1:])]

# Fully connected network shared by all sub-domains; each sub-domain calls it
# with its own adaptive activation slope `adap`.
def f(params, x, y, adap):
    """Evaluate the MLP at a single point (x, y).

    Args:
        params: list of ``(w, b)`` layer tuples.
        x, y: point coordinates (callers ``vmap`` this over arrays of points,
            so presumably scalars here).
        adap: adaptive activation slope. Every hidden pre-activation is
            multiplied by ``nn * adap`` before the swish nonlinearity, where
            ``nn`` is a module-level constant.

    Returns:
        Output of the last layer, which is linear (no activation).
    """
    *hidden_layers, (w_out, b_out) = params
    h = jnp.array([x, y])
    for w, b in hidden_layers:
        h = jnn.swish((jnp.dot(h, w) + b) * nn * adap)
    return jnp.dot(h, w_out) + b_out

# Hyperparameters
layer_size = [2] + [128] * 5 + [1]  # input (x, y) -> 5 hidden layers of 128 -> 1 output
step_size = 1e-3                    # Adam learning rate
train_iters = 28_000                # number of optimisation steps
nn = 10                             # fixed factor multiplying the adaptive slope inside f
adap = [0.1] * 3                    # initial adaptive activation slope, one per sub-domain

# Initial guess of the trainable parameters: Glorot-normal MLP weights/biases
# paired with the per-sub-domain adaptive activation slopes.
params_fwd = init_glorot_params(layer_sizes=layer_size)
params = params_fwd, adap


# Domain parameters for training: a uniform nx x ny grid over the rectangle.
xmin, xmax = 0, 3
ymin, ymax = 0, 3
nx, ny = 100, 100
x = jnp.linspace(xmin, xmax, nx)
y = jnp.linspace(ymin, ymax, ny)
xp, yp = jnp.meshgrid(x, y)
X = jnp.column_stack((xp.ravel(), yp.ravel()))

# Split the grid by y into three stacked sub-domains and two interfaces.
# The sub-domain sets are inclusive (>= / <=), so interface points also appear
# in both neighbouring sub-domains. All results are float64 NumPy arrays whose
# rows keep the order of X.
train_pts = np.asarray(X, dtype=np.float64)
train_y = train_pts[:, 1]
inputs1 = train_pts[train_y >= 2]                      # top sub-domain    (y >= 2)
inputs2 = train_pts[(train_y >= 1) & (train_y <= 2)]   # middle sub-domain (1 <= y <= 2)
inputs3 = train_pts[train_y <= 1]                      # bottom sub-domain (y <= 1)
inputs_int1 = train_pts[train_y == 2]                  # top/middle interface
inputs_int2 = train_pts[train_y == 1]                  # middle/bottom interface

# Boundary points of the rectangle, one (n, 2) array of (x, y) rows per edge.
xl = jnp.column_stack((jnp.full_like(y, xmin), y))   # left edge   (x = xmin)
xr = jnp.column_stack((jnp.full_like(y, xmax), y))   # right edge  (x = xmax)
xb = jnp.column_stack((x, jnp.full_like(x, ymin)))   # bottom edge (y = ymin)
xt = jnp.column_stack((x, jnp.full_like(x, ymax)))   # top edge    (y = ymax)

# Split the left edge by sub-domain (inclusive, so the interface points y == 2
# and y == 1 also belong to both neighbours) plus the two interface points.
left_pts = np.asarray(xl, dtype=np.float64)
left_y = left_pts[:, 1]
xl1 = left_pts[left_y >= 2]
xl2 = left_pts[(left_y >= 1) & (left_y <= 2)]
xl3 = left_pts[left_y <= 1]
xl_int1 = left_pts[left_y == 2]
xl_int2 = left_pts[left_y == 1]

# Same split for the right edge.
right_pts = np.asarray(xr, dtype=np.float64)
right_y = right_pts[:, 1]
xr1 = right_pts[right_y >= 2]
xr2 = right_pts[(right_y >= 1) & (right_y <= 2)]
xr3 = right_pts[right_y <= 1]
xr_int1 = right_pts[right_y == 2]
xr_int2 = right_pts[right_y == 1]

# Plotting collocation points: interior sets, interfaces, then every boundary
# edge segment (all edges in pink circles).
plt.figure(figsize=(5, 5))
point_styles = (
    [(inputs1, 'green', '.'), (inputs2, 'black', '.'), (inputs3, 'blue', '.'),
     (inputs_int1, 'red', '*'), (inputs_int2, 'red', '*')]
    + [(edge, 'pink', 'o') for edge in (xl1, xl2, xl3, xl_int1, xl_int2,
                                        xr1, xr2, xr3, xr_int1, xr_int2,
                                        xt, xb)]
)
for pts, color, marker in point_styles:
    plt.scatter(pts[:, 0], pts[:, 1], color=color, marker=marker, label='')
plt.title('Collocation points')

# Material coefficient k of the top (k1), middle (k2) and bottom (k3) sub-domains.
k1, k2, k3 = 1, 5, 1

# Boundary data.
# Left/right edges: prescribed normal-flux (Neumann) values.
BC_l_n, BC_r_n = 0, 0
# Top/bottom edges: prescribed solution (Dirichlet) values.
BC_t, BC_b = 10, 5

def objective(params,inputs1,inputs2,inputs3,inputs_int1,inputs_int2):
    """Physics-informed loss for the three-layer problem.

    The sub-domains are stacked in y:
      * domain 1: y >= 2, coefficient k1
      * domain 2: 1 <= y <= 2, coefficient k2
      * domain 3: y <= 1, coefficient k3
    All three sub-domain networks share the MLP weights ``params_fwd`` and
    differ only in their adaptive activation slope ``adap[0..2]``.

    Each sub-domain contributes:
      * the residual of ``k * (p_xx + p_yy) = 0`` at its collocation points;
      * the boundary conditions of the outer edges it touches (Dirichlet on
        top/bottom, Neumann on left/right);
      * for domains 1 and 2, the interface conditions with the domain below.
        These enforce continuity of p and of the normal flux ``k * dp/dy`` and
        are weighted by ``interface_weight``.

    Args:
        params: tuple ``(params_fwd, adap)``.
        inputs1, inputs2, inputs3: collocation points of domains 1, 2, 3
            (column 0 is x, column 1 is y).
        inputs_int1: points on the domain-1/domain-2 interface (y == 2).
        inputs_int2: points on the domain-2/domain-3 interface (y == 1).

    Returns:
        Scalar total loss.

    The boundary point sets (xt, xb, xl1..xl3, xr1..xr3), the coefficients
    k1..k3 and the boundary values BC_t, BC_b, BC_l_n, BC_r_n are read from
    module-level globals.
    """
    params_fwd, adap = params
    interface_weight = 30

    x1, y1 = inputs1[:, 0].reshape(-1), inputs1[:, 1].reshape(-1)
    x2, y2 = inputs2[:, 0].reshape(-1), inputs2[:, 1].reshape(-1)
    x3, y3 = inputs3[:, 0].reshape(-1), inputs3[:, 1].reshape(-1)
    xint1, yint1 = inputs_int1[:, 0].reshape(-1), inputs_int1[:, 1].reshape(-1)
    xint2, yint2 = inputs_int2[:, 0].reshape(-1), inputs_int2[:, 1].reshape(-1)

    # Per-sub-domain networks: shared weights, individual activation slope.
    p1 = lambda x, y: f(params_fwd, x, y, adap[0])
    p2 = lambda x, y: f(params_fwd, x, y, adap[1])
    p3 = lambda x, y: f(params_fwd, x, y, adap[2])

    # First derivatives, batched over arrays of points.
    p1_x = vmap(jacfwd(p1, argnums=0))
    p1_y = vmap(jacfwd(p1, argnums=1))
    p2_x = vmap(jacfwd(p2, argnums=0))
    p2_y = vmap(jacfwd(p2, argnums=1))
    p3_x = vmap(jacfwd(p3, argnums=0))
    p3_y = vmap(jacfwd(p3, argnums=1))

    # Second derivatives for the Laplacian.
    p1_xx = vmap(jacfwd(jacrev(p1, argnums=0), argnums=0))
    p1_yy = vmap(jacfwd(jacrev(p1, argnums=1), argnums=1))
    p2_xx = vmap(jacfwd(jacrev(p2, argnums=0), argnums=0))
    p2_yy = vmap(jacfwd(jacrev(p2, argnums=1), argnums=1))
    p3_xx = vmap(jacfwd(jacrev(p3, argnums=0), argnums=0))
    p3_yy = vmap(jacfwd(jacrev(p3, argnums=1), argnums=1))

    # ---------------- Domain 1 (top) ----------------
    bc_top = (vmap(p1)(xt[:, 0], xt[:, 1]).reshape(-1) - BC_t)**2
    bc_left1 = (-k1*p1_x(xl1[:, 0], xl1[:, 1]).reshape(-1) - BC_l_n)**2
    bc_right1 = (k1*p1_x(xr1[:, 0], xr1[:, 1]).reshape(-1) - BC_r_n)**2
    loss_bc1 = jnp.mean(bc_top) + jnp.mean(bc_left1) + jnp.mean(bc_right1)

    eq1 = (k1*p1_xx(x1, y1).reshape(-1) + k1*p1_yy(x1, y1).reshape(-1))**2
    loss_eq1 = jnp.mean(eq1)

    # Interface y == 2. Flux continuity across a horizontal interface means
    # k1 * dp1/dy == k2 * dp2/dy, so the residual is their difference
    # (a sum would force the two fluxes to have opposite signs).
    int12_p = (vmap(p2)(xint1, yint1).reshape(-1) - vmap(p1)(xint1, yint1).reshape(-1))**2
    int12_flux_y = (k2*p2_y(xint1, yint1).reshape(-1) - k1*p1_y(xint1, yint1).reshape(-1))**2
    loss_int1 = jnp.mean(int12_p) + jnp.mean(int12_flux_y)

    LOSS1 = loss_bc1 + loss_eq1 + interface_weight*loss_int1

    # ---------------- Domain 2 (middle) ----------------
    # Both Neumann edges use the domain-2 network.
    bc_left2 = (-k2*p2_x(xl2[:, 0], xl2[:, 1]).reshape(-1) - BC_l_n)**2
    bc_right2 = (k2*p2_x(xr2[:, 0], xr2[:, 1]).reshape(-1) - BC_r_n)**2
    loss_bc2 = jnp.mean(bc_left2) + jnp.mean(bc_right2)

    eq2 = (k2*p2_xx(x2, y2).reshape(-1) + k2*p2_yy(x2, y2).reshape(-1))**2
    loss_eq2 = jnp.mean(eq2)

    # Interface y == 1: continuity of p and of k * dp/dy (difference, see above).
    int23_p = (vmap(p3)(xint2, yint2).reshape(-1) - vmap(p2)(xint2, yint2).reshape(-1))**2
    int23_flux_y = (k3*p3_y(xint2, yint2).reshape(-1) - k2*p2_y(xint2, yint2).reshape(-1))**2
    loss_int2 = jnp.mean(int23_p) + jnp.mean(int23_flux_y)

    LOSS2 = loss_bc2 + loss_eq2 + interface_weight*loss_int2

    # ---------------- Domain 3 (bottom) ----------------
    bc_bottom = (vmap(p3)(xb[:, 0], xb[:, 1]).reshape(-1) - BC_b)**2
    # Left edge of domain 3: left-edge points and left Neumann value.
    bc_left3 = (-k3*p3_x(xl3[:, 0], xl3[:, 1]).reshape(-1) - BC_l_n)**2
    bc_right3 = (k3*p3_x(xr3[:, 0], xr3[:, 1]).reshape(-1) - BC_r_n)**2
    loss_bc3 = jnp.mean(bc_bottom) + jnp.mean(bc_left3) + jnp.mean(bc_right3)

    eq3 = (k3*p3_xx(x3, y3).reshape(-1) + k3*p3_yy(x3, y3).reshape(-1))**2
    loss_eq3 = jnp.mean(eq3)

    LOSS3 = loss_bc3 + loss_eq3

    return LOSS1 + LOSS2 + LOSS3

# Adam optimizer (jax.example_libraries.optimizers)
opt_init, opt_update, get_params = adam(step_size, b1=0.9, b2=0.999, eps=1e-08)


@jit
def resnet_update(params, inputs1,inputs2,inputs3,inputs_int1,inputs_int2,opt_state, step=0):
    """Compute the gradient for a batch and apply one Adam update.

    Args:
        params: current ``(params_fwd, adap)`` pytree.
        inputs1, inputs2, inputs3, inputs_int1, inputs_int2: collocation
            point sets, forwarded to ``objective``.
        opt_state: optimizer state from ``opt_init`` / ``opt_update``.
        step: iteration index passed to ``opt_update``. JAX's ``adam`` uses it
            for bias correction, so it must increase from call to call; a
            constant 0 distorts the effective step size. Defaults to 0 for
            backward compatibility.

    Returns:
        ``(new_params, new_opt_state, loss)``, where ``loss`` is ``objective``
        evaluated at the incoming ``params`` (before the update).
    """
    value, grads = value_and_grad(objective)(params,inputs1,inputs2,inputs3,inputs_int1,inputs_int2)
    opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state), opt_state, value


opt_state = opt_init(params)

for i in range(train_iters):
    params, opt_state, value = resnet_update(params, inputs1, inputs2, inputs3,
                                             inputs_int1, inputs_int2, opt_state, i)
    if i % 1000 == 0:
        # `value` is the loss at the parameters before this step's update;
        # reusing it avoids a separate, un-jitted evaluation of the objective.
        print("Iteration {0:3d} objective {1}".format(i, value))

# Trained parameters: shared MLP weights and per-sub-domain activation slopes.
params_fwd, adap = params

# Testing grid generation: a coarser nxR x nyR grid over the training rectangle.
nxR, nyR = 40, 40
xR, yR = jnp.linspace(xmin, xmax, nxR), jnp.linspace(ymin, ymax, nyR)
xpR, ypR = jnp.meshgrid(xR, yR)
XR = jnp.column_stack((xpR.ravel(), ypR.ravel()))

# Split the test grid by y into the three stacked sub-domains (inclusive of
# their bounding interfaces) and the two interfaces y == 2 and y == 1.
# Results are float64 NumPy arrays whose rows keep the order of XR.
grid_pts = np.asarray(XR, dtype=np.float64)
grid_y = grid_pts[:, 1]
inputs1R = grid_pts[grid_y >= 2]                     # top sub-domain
inputs2R = grid_pts[(grid_y >= 1) & (grid_y <= 2)]   # middle sub-domain
inputs3R = grid_pts[grid_y <= 1]                     # bottom sub-domain
inputs_int1R = grid_pts[grid_y == 2]                 # top/middle interface
inputs_int2R = grid_pts[grid_y == 1]                 # middle/bottom interface

# Separate coordinate vectors for each set.
x1R, y1R = inputs1R[:, 0], inputs1R[:, 1]
x2R, y2R = inputs2R[:, 0], inputs2R[:, 1]
x3R, y3R = inputs3R[:, 0], inputs3R[:, 1]
xint1R, yint1R = inputs_int1R[:, 0], inputs_int1R[:, 1]
xint2R, yint2R = inputs_int2R[:, 0], inputs_int2R[:, 1]

# Plotting test points: sub-domain interiors and both interfaces.
plt.figure(figsize=(5, 5))
for pts, color, marker in ((inputs1R, 'green', '.'), (inputs2R, 'black', '.'),
                           (inputs3R, 'blue', '.'), (inputs_int1R, 'red', '*'),
                           (inputs_int2R, 'red', '*')):
    plt.scatter(pts[:, 0], pts[:, 1], color=color, marker=marker, label='')
plt.title('Testing points')

# Getting I-PINNs based solution on the test grid.
# Each sub-domain uses the shared trained weights with its own activation slope.
p1_approx = lambda x,y: f(params_fwd,x,y,adap[0])
p2_approx = lambda x,y: f(params_fwd,x,y,adap[1])
p3_approx = lambda x,y: f(params_fwd,x,y,adap[2])

P1_approx = vmap(p1_approx)(x1R,y1R).reshape(-1)
P2_approx = vmap(p2_approx)(x2R,y2R).reshape(-1)
P3_approx = vmap(p3_approx)(x3R,y3R).reshape(-1)

# Interface values: average of the two sub-networks that meet there.
P1_int1_approx = vmap(p1_approx)(xint1R,yint1R).reshape(-1)
P2_int1_approx = vmap(p2_approx)(xint1R,yint1R).reshape(-1)
P_int1 = (P1_int1_approx + P2_int1_approx)/2
P2_int2_approx = vmap(p2_approx)(xint2R,yint2R).reshape(-1)
P3_int2_approx = vmap(p3_approx)(xint2R,yint2R).reshape(-1)
P_int2 = (P2_int2_approx + P3_int2_approx)/2

# The test-grid sub-domain sets include their bounding interfaces (>= / <=
# split), so each interface node is present in two sub-domains, generally with
# two different predicted values. Drop those rows from the sub-domain sets and
# insert the averaged interface values once, so every plotted coordinate is
# unique and the interface averages are actually used.
y1R_np, y2R_np, y3R_np = np.asarray(y1R), np.asarray(y2R), np.asarray(y3R)
keep1 = y1R_np != 2
keep2 = (y2R_np != 2) & (y2R_np != 1)
keep3 = y3R_np != 1

# Combining solutions and inputs from the three sub-domains and two interfaces
P_approx = np.concatenate((np.asarray(P1_approx)[keep1], np.asarray(P_int1),
                           np.asarray(P2_approx)[keep2], np.asarray(P_int2),
                           np.asarray(P3_approx)[keep3]), 0).reshape(-1)
INPUTS_x = np.concatenate((np.asarray(x1R)[keep1], np.asarray(xint1R),
                           np.asarray(x2R)[keep2], np.asarray(xint2R),
                           np.asarray(x3R)[keep3]), 0).reshape(-1)
INPUTS_y = np.concatenate((y1R_np[keep1], np.asarray(yint1R),
                           y2R_np[keep2], np.asarray(yint2R),
                           y3R_np[keep3]), 0).reshape(-1)

# Reference solution exported from COMSOL. Columns 0, 1 and 2 are used below
# as x, y and the solution value at each node.
COMSOL_data = np.loadtxt('/kaggle/input/2-d-plot/2 D plot data COMSOL Theisis 2D problem.txt')
L = len(COMSOL_data)

# Split the COMSOL nodes by y. Unlike the training/test grids, the sub-domains
# use strict inequalities, so no node is assigned to two groups. The groups are
# the interiors of the three sub-domains and the interfaces y == 2 and y == 1.
comsol_y = COMSOL_data[:, 1]
rows1 = COMSOL_data[comsol_y > 2]
rows2 = COMSOL_data[(comsol_y > 1) & (comsol_y < 2)]
rows3 = COMSOL_data[comsol_y < 1]
rows_int1 = COMSOL_data[comsol_y == 2]
rows_int2 = COMSOL_data[comsol_y == 1]

# Coordinates (first two columns) and reference values (third column) per group.
inputs1c, COMSOL_P1c = rows1[:, 0:2], rows1[:, 2]
inputs2c, COMSOL_P2c = rows2[:, 0:2], rows2[:, 2]
inputs3c, COMSOL_P3c = rows3[:, 0:2], rows3[:, 2]
inputs_int1c, COMSOL_P_int1c = rows_int1[:, 0:2], rows_int1[:, 2]
inputs_int2c, COMSOL_P_int2c = rows_int2[:, 0:2], rows_int2[:, 2]

x1c, y1c = inputs1c[:, 0], inputs1c[:, 1]
x2c, y2c = inputs2c[:, 0], inputs2c[:, 1]
x3c, y3c = inputs3c[:, 0], inputs3c[:, 1]
xint1c, yint1c = inputs_int1c[:, 0], inputs_int1c[:, 1]
xint2c, yint2c = inputs_int2c[:, 0], inputs_int2c[:, 1]

# Getting I-PINNs based solution at points where COMSOL solution is available.
# Shared trained weights, one adaptive activation slope per sub-domain.
pc1 = lambda x, y: f(params_fwd, x, y, adap[0])
pc2 = lambda x, y: f(params_fwd, x, y, adap[1])
pc3 = lambda x, y: f(params_fwd, x, y, adap[2])


def _interface_average(net_a, net_b, xs, ys):
    """Flattened mean of two sub-network predictions at the same interface points."""
    return (vmap(net_a)(xs, ys).reshape(-1) + vmap(net_b)(xs, ys).reshape(-1)) / 2


Pc1_approx = vmap(pc1)(x1c, y1c).reshape(-1)
Pc2_approx = vmap(pc2)(x2c, y2c).reshape(-1)
Pc3_approx = vmap(pc3)(x3c, y3c).reshape(-1)
Pc_int1 = _interface_average(pc1, pc2, xint1c, yint1c)
Pc_int2 = _interface_average(pc2, pc3, xint2c, yint2c)

# Stack every group top-to-bottom: sub-domain 1, interface y == 2,
# sub-domain 2, interface y == 1, sub-domain 3.
Pc_approx = np.concatenate((Pc1_approx, Pc_int1, Pc2_approx, Pc_int2, Pc3_approx), 0)
COMSOL_P = np.concatenate((COMSOL_P1c, COMSOL_P_int1c, COMSOL_P2c, COMSOL_P_int2c, COMSOL_P3c), 0)
INPUTS_xc = np.concatenate((x1c, xint1c, x2c, xint2c, x3c), 0)
INPUTS_yc = np.concatenate((y1c, yint1c, y2c, yint2c, y3c), 0)

def RMSE(actual, predicted):
    """Root-mean-square error between ``actual`` and ``predicted``, as a Python float."""
    residual = jnp.subtract(actual, predicted)
    return math.sqrt(jnp.mean(jnp.square(residual)))
rmse_value = RMSE(COMSOL_P, Pc_approx)
print("Root Mean Square Error1 = ", str(rmse_value))

# Relative L2 error: ||COMSOL_P - Pc_approx||_2 / ||COMSOL_P||_2
error_l2_norm, reference_l2_norm = (
    np.linalg.norm(vec, ord=2) for vec in (COMSOL_P - Pc_approx, COMSOL_P)
)
relative_l2_error = error_l2_norm / reference_l2_norm
print("Relative L2 Error:", relative_l2_error)

cmap = plt.get_cmap('rainbow')
fontsize = 15


def _plot_pressure_field(x_pts, y_pts, values, title, filename):
    """Filled-contour plot of a scattered pressure field with a side colorbar.

    Args:
        x_pts, y_pts: 1-D coordinates of the scattered points.
        values: pressure value at each point (same length as ``x_pts``).
        title: figure title.
        filename: output path passed to ``plt.savefig``. It must differ per
            figure; with a shared name each figure overwrites the previous one.
    """
    plt.figure()
    ax = plt.subplot()
    ax.set_aspect('equal', adjustable='box')
    im = plt.tricontourf(x_pts, y_pts, values, 100, cmap=cmap)
    ax.tick_params(axis='both', labelsize=fontsize)
    plt.title(title, fontsize=fontsize)
    plt.xlabel("X", fontsize=fontsize)
    # The vertical axis is the y coordinate; pressure is shown by the colorbar.
    plt.ylabel("Y", fontsize=fontsize)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(im, cax=cax)
    cbar.set_label("Pressure", fontsize=fontsize)
    cbar.ax.tick_params(labelsize=fontsize)
    plt.savefig(filename, bbox_inches='tight', dpi=1200)
    plt.show()


# AdaI-PINNs solution on the test grid
_plot_pressure_field(INPUTS_x, INPUTS_y, P_approx,
                     "AdaI-PINNs solution ", 'pinns_u_inv')
# COMSOL reference solution
_plot_pressure_field(INPUTS_xc, INPUTS_yc, COMSOL_P,
                     "COMSOL solution ", 'comsol_u')
# AdaI-PINNs solution at the COMSOL nodes
_plot_pressure_field(INPUTS_xc, INPUTS_yc, Pc_approx,
                     "AdaI-PINNs solution at COMSOL data points ", 'pinns_u_inv_comsol_points')