In [None]:
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
from pyDOE import lhs
from matplotlib.cm import get_cmap
from matplotlib.patches import Polygon
import matplotlib.patches as patches
from pyDOE import lhs

%matplotlib inline
# Colormap kept available for shading (main() in this cell does not read it).
# matplotlib.cm.get_cmap is deprecated since Matplotlib 3.7 and removed in 3.9;
# the colormap registry is the supported lookup (available since 3.5).
# NOTE(review): the `from matplotlib.cm import get_cmap` line in this cell's
# import block will itself fail on Matplotlib >= 3.9.
color_map = matplotlib.colormaps['Blues']


def main():
    """Run the numerical-Gaussian-process time-stepping demo and plot it.

    The GP is first trained on initial data for u(0, x) and v(0, x), taken
    from ``Exact_solution`` / ``Exact_solution_derivative`` at Latin-hypercube
    locations in ``[lb, ub]``. Then, for each of ``nsteps`` steps of size
    ``dt``:

    1. the hyper-parameters ``hyp`` are updated with ``minimize3`` on
       ``likelihood1``;
    2. ``predictor`` evaluates u and v on a dense grid and the relative L2
       errors against the exact solution at ``t = i * dt`` are printed;
    3. the training data is replaced by "artificial" data: GP predictions at
       fresh Latin-hypercube locations, with the second ``predictor`` output
       stored as ``S0`` for the next step.

    ``Exact_solution``, ``Exact_solution_derivative``, ``minimize3``,
    ``likelihood1`` and ``predictor`` are defined elsewhere in the notebook.
    Figures are drawn only when ``plt_flag`` is true.
    """
    # Function-scope imports keep this cell runnable even if numpy was not
    # imported by an earlier cell.
    import numpy as np
    import seaborn as sns

    plt.rcParams.update({'font.size': 14})

    # ---- Setup -------------------------------------------------------------
    Ntr_u = 24                 # initial training points for u
    Ntr_v = 26                 # initial training points for v
    Ntr_u_artificial = 49      # artificial u points carried between steps
    Ntr_v_artificial = 51      # artificial v points carried between steps
    dim = 1
    lb = np.zeros(dim)
    ub = 1.0 * np.ones(dim)
    jitter = 1e-6
    # Standard deviations of zero-mean Gaussian noise added to the initial
    # data; set > 0 to experiment with noisy observations.
    noise_u = 0.0
    noise_v = 0.0

    plt_flag = True

    T = 1.5
    dt = 1e-2
    # round() guards against T / dt landing just below an integer.
    nsteps = int(round(T / dt))
    n_star_u = 400
    n_star_v = 400
    x_star_u = np.linspace(lb[0], ub[0], n_star_u).reshape(-1, 1)
    x_star_v = np.linspace(lb[0], ub[0], n_star_v).reshape(-1, 1)
    xlim = (lb[0], ub[0])

    num_plots = 4
    # Snapshot interval; max() avoids a modulo-by-zero when nsteps < num_plots.
    plot_every = max(1, nsteps // num_plots)

    ModelInfo1 = {}
    ModelInfo1['dt'] = dt
    # Boundary locations and values (presumably u = 0 at x = 0 and x = 1 —
    # confirm against likelihood1/predictor).
    ModelInfo1['x_b'] = np.array([[0], [1]])
    ModelInfo1['u_b'] = np.array([[0], [0]])
    ModelInfo1['jitter'] = jitter

    # Initial data: LHS is used to place the points in [lb, ub]; the noise is
    # zero-mean Gaussian (LHS samples lie in [0, 1], so using them as noise
    # would bias every observation upwards).
    ModelInfo1['x_u'] = lb + (ub - lb) * lhs(dim, samples=Ntr_u)
    ModelInfo1['u'] = (Exact_solution(0, ModelInfo1['x_u'])
                       + noise_u * np.random.randn(Ntr_u, dim))
    ModelInfo1['x_v'] = lb + (ub - lb) * lhs(dim, samples=Ntr_v)
    ModelInfo1['v'] = (Exact_solution_derivative(0, ModelInfo1['x_v'])
                       + noise_v * np.random.randn(Ntr_v, dim))
    ModelInfo1['S0'] = np.zeros((Ntr_u + Ntr_v, Ntr_u + Ntr_v))
    ModelInfo1['hyp'] = np.log([1, 1, 1, 1, np.exp(-6), np.exp(-6.5)])

    # ---- Plot the initial data ---------------------------------------------
    if plt_flag:
        matplotlib.rcParams['text.usetex'] = False
        fig = plt.figure(figsize=(16, 8))
        sns.set()
        _plot_initial_panel(
            fig.add_subplot(2, num_plots, 1),
            x_star_u, Exact_solution(0, x_star_u),
            ModelInfo1['x_u'], ModelInfo1['u'],
            ylabel='u(t,x)', ylim=[-0.3, 1], xlim=xlim,
            title=f"Time: 0.00\n{Ntr_u} initial data",
        )
        _plot_initial_panel(
            fig.add_subplot(2, num_plots, num_plots + 1),
            x_star_v, Exact_solution_derivative(0, x_star_v),
            ModelInfo1['x_v'], ModelInfo1['v'],
            ylabel='v(t,x)', ylim=[-5, 5], xlim=xlim,
            title=f"{Ntr_v} initial data",
        )
        plt.tight_layout()
        plt.show()

    # ---- Time stepping -----------------------------------------------------
    for i in range(1, nsteps + 1):
        t = i * dt
        print(ModelInfo1['hyp'])  # hyper-parameters before this step's fit
        ModelInfo1['hyp'], _, _ = minimize3(ModelInfo1['hyp'], likelihood1, nsteps, ModelInfo1)
        NLML, _ = likelihood1(ModelInfo1['hyp'], ModelInfo1)

        # Dense-grid prediction: the first n_star_u entries belong to u, the
        # remainder to v; only the diagonal of the second output is kept.
        Kpred, Kvar = predictor(x_star_u, x_star_v, ModelInfo1)
        Kvar = np.diag(Kvar)
        u_star_mean = Kpred[:n_star_u]
        v_star_mean = Kpred[n_star_u:]
        u_star_var = Kvar[:n_star_u]
        v_star_var = Kvar[n_star_u:]

        Exact = Exact_solution(t, x_star_u)
        Exact_derivative = Exact_solution_derivative(t, x_star_v)

        error_u = np.linalg.norm(u_star_mean - Exact) / np.linalg.norm(Exact)
        error_v = np.linalg.norm(v_star_mean - Exact_derivative) / np.linalg.norm(Exact_derivative)
        # np.asarray(...).item() accepts both Python floats and 1-element arrays.
        print(f"Step: {i}, Time = {t:.2f}, NLML = {np.asarray(NLML).item():.6e}, "
              f"error_u = {float(error_u):.6e}, error_v = {float(error_v):.6e}")

        # Artificial data for the next step: GP predictions at fresh LHS
        # locations in [lb, ub]; the second predictor output becomes S0.
        x_u = lb + (ub - lb) * lhs(dim, samples=Ntr_u_artificial)
        x_v = lb + (ub - lb) * lhs(dim, samples=Ntr_v_artificial)
        data, ModelInfo1['S0'] = predictor(x_u, x_v, ModelInfo1)
        ModelInfo1['u'] = data[:Ntr_u_artificial]
        ModelInfo1['v'] = data[Ntr_u_artificial:]
        ModelInfo1['x_u'] = x_u
        ModelInfo1['x_v'] = x_v

        if plt_flag and i % plot_every == 0 and i < nsteps:
            sns.set()
            _plot_snapshot(x_star_u, Exact, u_star_mean, u_star_var,
                           ylim=[-1, 1], xlim=xlim,
                           title=f"Time: {t:.2f}\n{Ntr_u_artificial} artificial data")
            _plot_snapshot(x_star_v, Exact_derivative, v_star_mean, v_star_var,
                           ylim=[-5, 5], xlim=xlim,
                           title=f"Time: {t:.2f}\n{Ntr_v_artificial} artificial data")

    # Each snapshot figure is laid out individually; only show when plotting
    # (otherwise an empty figure would be created and displayed).
    if plt_flag:
        plt.show()


def _plot_initial_panel(ax, x_star, exact, x_data, y_data, ylabel, ylim, xlim, title):
    """Plot an exact initial profile and its training points on *ax*.

    The training locations are also used as unlabelled x-ticks so their
    spread is visible along the axis.
    """
    import numpy as np

    ax.plot(x_star, exact, 'b', linewidth=3)
    ax.plot(x_data, y_data, 'r*', markersize=7, linewidth=3)
    ax.set_xlabel('0 <= x <= 1')
    ax.set_ylabel(ylabel)
    ax.set_aspect('auto')
    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    ax.set_xticks(np.sort(np.asarray(x_data).ravel()))
    ax.set_xticklabels([])
    ax.tick_params(length=5, width=2)
    ax.set_title(title)


def _plot_snapshot(x_star, exact, mean, var, ylim, xlim, title):
    """Open a new figure with the exact solution, GP mean and a +/-2 std band."""
    import numpy as np

    # Flatten everything to 1-D so the band arithmetic cannot broadcast to N x N.
    x = np.asarray(x_star).ravel()
    mean = np.asarray(mean).ravel()
    # Clip tiny negative variances (round-off) so sqrt does not yield NaN gaps.
    std = np.sqrt(np.clip(np.asarray(var).ravel(), 0.0, None))

    plt.figure()
    plt.plot(x, np.asarray(exact).ravel(), 'b', linewidth=3)
    plt.plot(x, mean, 'r--', linewidth=3)
    plt.fill_between(x, mean - 2.0 * std, mean + 2.0 * std,
                     alpha=0.2, color='orange', label='+/- 2 std')
    plt.xlabel('0 <= x <= 1')
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.gca().set_aspect('auto')
    plt.title(title)
    plt.tight_layout()
    





