In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import pathlib
import imageio

%matplotlib inline

# Seed every RNG the notebook uses so that Restart & Run All is reproducible.
torch.manual_seed(0)
np.random.seed(0)

# Preferred GPU is cuda:3; fall back to the default GPU, then CPU, so the
# notebook still runs on machines that do not have a 4th CUDA device.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    device = 'cuda:3'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

noise = 1.5  # scale of the Gaussian noise added to the initial condition
exp_name = 'noise_15_heat_eq_gp'  # output directory for figures and arrays
pathlib.Path(f"./{exp_name}/").mkdir(parents=True, exist_ok=True)

In [2]:
class NumericalGP(torch.nn.Module):
    """Gaussian-process time stepper for the 1-D heat equation u_t = alpha * u_xx.

    Each step is a backward-Euler update. The current data ``u`` (the solution
    at step n, sampled at ``xu``) is modelled as a noisy observation of
    ``(I - alpha*dt*d^2/dx^2) u^{n+1}``, while ``u^{n+1}`` itself is observed at
    the boundary points ``xb`` with values ``ub``. A squared-exponential prior
    on ``u^{n+1}`` is pushed through that linear operator to build the joint
    covariance; the posterior mean/covariance at ``xu`` then becomes the data
    (and propagated covariance ``S0``) for the next step.

    The kernel-derivative formulas assume scalar (1-D) inputs, i.e. ``xb`` and
    ``xu`` are column tensors of shape (n, 1) as built in this notebook.

    Args:
        xb: boundary inputs.
        xu: collocation inputs where the solution is tracked.
        ub: boundary values (column tensor matching ``xb``).
        u: initial solution values at ``xu`` (column tensor).
        dt: time step.
        alpha: diffusion coefficient.
        device: torch device holding data and hyperparameters.
    """

    def __init__(self, xb, xu, ub, u, dt=0.01, alpha=0.2, device=device):
        super().__init__()
        self.device = device
        self.reset_params()

        self.xb = xb.to(device)
        self.xu = xu.to(device)
        self.ub = ub.to(device)
        self.u = u.to(device)

        # Covariance propagated alongside the current data ``u``. It is zero at
        # the start, so initial-data noise enters only through sigma_n.
        self.S0 = torch.zeros(self.xu.shape[0], self.xu.shape[0], device=device)

        self.dt = dt
        self.alpha = alpha

        self.to(device)

    def reset_params(self):
        """(Re)initialise the log-hyperparameters to 0 on ``self.device``.

        The hyperparameters are stored in log space and exponentiated on use.
        Creating them directly on ``self.device`` keeps them on the same device
        as the data even when this is called mid-training from ``train_step``.
        """
        self.sigma_n = torch.nn.Parameter(torch.tensor(0., device=self.device))
        self.scale = torch.nn.Parameter(torch.tensor(0., device=self.device))
        self.length_scale = torch.nn.Parameter(torch.tensor(0., device=self.device))

    def kernel_helper(self, x, x_prime):
        """Return ``(sq_dists, k)`` for the squared-exponential kernel.

        ``sq_dists[i, j] = ||x[i] - x_prime[j]||^2`` and
        ``k = exp(scale) * exp(-0.5 * sq_dists / l)`` with ``l = exp(length_scale)``,
        so ``l`` plays the role of the *squared* length scale.
        """
        l = torch.exp(self.length_scale)
        # Direct pairwise differences: exact and never negative, unlike the
        # |x|^2 + |x'|^2 - 2 x.x' expansion which suffers from cancellation.
        diff = x.unsqueeze(1) - x_prime.unsqueeze(0)
        sq_dists = torch.square(diff).sum(dim=-1)
        return sq_dists, torch.exp(self.scale) * torch.exp(-0.5 * sq_dists / l)

    def kernel_nn(self, x, x_prime):
        """Prior covariance cov(u^{n+1}(x), u^{n+1}(x'))."""
        return self.kernel_helper(x, x_prime)[1]

    def kernel_nn1(self, x, x_prime):
        """Cross-covariance cov(u^{n+1}(x), [(I - alpha*dt*d^2/dx'^2) u^{n+1}](x'))."""
        l = torch.exp(self.length_scale)
        sq_dists, k = self.kernel_helper(x, x_prime)
        # d^2 k / dx'^2 for the SE kernel with 1-D inputs.
        d2k_dxp2 = k * (sq_dists - l) / l ** 2
        return k - self.alpha * self.dt * d2k_dxp2

    def kernel_n1n1(self, x, x_prime):
        """Covariance of the transformed process (I - alpha*dt*d^2/dx^2) u^{n+1}."""
        l = torch.exp(self.length_scale)
        sq_dists, k = self.kernel_helper(x, x_prime)
        c = self.alpha * self.dt
        # d^2 k / dx^2 (equal to d^2 k / dx'^2 for a stationary kernel).
        d2k = k * (sq_dists - l) / l ** 2
        # d^4 k / dx^2 dx'^2 for the SE kernel with 1-D inputs.
        d4k = k * (3 * l ** 2 - 6 * l * sq_dists + sq_dists ** 2) / l ** 4
        return k - 2 * c * d2k + c ** 2 * d4k

    def neg_log_likelihood(self):
        """Negative log marginal likelihood of y = [ub; u] under the joint GP prior.

        Side effect: caches the Cholesky factor of the joint covariance in
        ``self.L`` for use by ``posterior``. Returns a (1, 1) tensor.
        """
        knn = self.kernel_nn(self.xb, self.xb)
        knn1 = self.kernel_nn1(self.xb, self.xu)
        # Learned observation noise exp(sigma_n), applied to the u rows only.
        noise_var = torch.exp(self.sigma_n) * torch.eye(self.xu.shape[0], device=self.device)
        kn1n1 = self.kernel_n1n1(self.xu, self.xu) + noise_var

        K = torch.cat([torch.cat([knn, knn1], dim=1), torch.cat([knn1.T, kn1n1], dim=1)], dim=0)
        # Small jitter for numerical stability of the Cholesky factorisation.
        K = K + 1e-6 * torch.eye(K.shape[0], device=self.device)

        # NOTE(review): cholesky_ex does not raise when K is not positive
        # definite; inspect its `.info` field if the factor looks wrong.
        L = torch.linalg.cholesky_ex(K).L
        self.L = L

        y = torch.cat([self.ub, self.u], dim=0)
        n = K.shape[0]
        return (n / 2. * np.log(2 * np.pi)
                + torch.sum(torch.log(torch.diag(L)))
                + 0.5 * y.T @ torch.cholesky_solve(y, L))

    def posterior(self, x_star, u=None, S0=None):
        """Posterior mean and covariance of u^{n+1} at ``x_star``.

        Uses the Cholesky factor cached by the latest ``neg_log_likelihood``.

        Args:
            x_star: query inputs (column tensor on ``self.device``).
            u: data to condition on; defaults to ``self.u``.
            S0: covariance propagated with ``u``; defaults to ``self.S0``.

        Returns:
            ``(f, var)``: posterior mean (column) and full posterior covariance.
        """
        u = self.u if u is None else u
        S0 = self.S0 if S0 is None else S0
        y = torch.cat([self.ub, u], dim=0)
        n_b = self.ub.shape[0]

        # Boundary values carry no propagated uncertainty; only the u rows do.
        S = torch.block_diag(torch.zeros(n_b, n_b, device=self.device), S0)

        psi = torch.cat([self.kernel_nn(x_star, self.xb),
                         self.kernel_nn1(x_star, self.xu)], dim=1)

        f = psi @ torch.cholesky_solve(y, self.L)

        # w = K^{-1} psi^T, shared by the variance reduction and the S term.
        w = torch.cholesky_solve(psi.T, self.L)
        var = self.kernel_nn(x_star, x_star) - psi @ w + w.T @ S @ w

        return f, var

    def train_step(self, t, epochs=500):
        """Fit the hyperparameters and advance the solution by one time step.

        Args:
            t: physical time reached after this step (used for the error only).
            epochs: number of optimiser iterations.

        Returns:
            ``(pred, var, error)``: posterior mean at ``xu``, absolute diagonal
            of the posterior covariance, and MSE against the exact solution
            exp(-alpha*pi^2*t) * sin(pi*x).
        """
        self.reset_params()
        optim = torch.optim.SGD(self.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=50, gamma=0.1)

        for _ in tqdm(range(epochs)):
            optim.zero_grad()
            loss = self.neg_log_likelihood()
            loss.backward()
            optim.step()
            scheduler.step()

        # The loop's last cached factor predates the final optimiser step;
        # refresh it so the posterior uses one consistent set of hyperparameters
        # (this also defines self.L when epochs == 0).
        with torch.no_grad():
            self.neg_log_likelihood()

        pred, var = self.posterior(self.xu)

        # Remember the data this step was conditioned on, so plot_step can
        # reproduce the prediction at time t rather than one step further.
        self.u_prev, self.S0_prev = self.u, self.S0
        self.u = pred.detach().clone()
        self.S0 = var.detach().clone()

        var = torch.abs(torch.diag(var))

        exact = np.exp(-self.alpha * torch.pi * torch.pi * t) * torch.sin(torch.pi * self.xu)
        error = torch.mean((pred - exact) ** 2)

        return pred, var, error

    def plot_step(self, t, res=101):
        """Plot the GP solution at time ``t`` against the exact solution.

        Conditions on the data used by the latest ``train_step`` so the curve
        is the step-``t`` prediction (falls back to the current data if no
        step has been taken). Draws on a new current pyplot figure; the caller
        is responsible for saving/closing it.
        """
        x_star = torch.linspace(0, 1, res).unsqueeze(1)
        u_cond = getattr(self, "u_prev", None)
        S_cond = getattr(self, "S0_prev", None)
        pred, var = self.posterior(x_star.to(self.device), u=u_cond, S0=S_cond)
        pred = pred.detach().cpu().numpy().squeeze()
        std = np.sqrt(np.abs(np.diag(var.detach().cpu().numpy())))

        x_np = x_star.numpy().squeeze()
        exact = np.exp(-self.alpha * np.pi * np.pi * t) * np.sin(np.pi * x_np)
        mse = np.mean((pred - exact) ** 2)

        plt.figure(figsize=(12, 8))
        plt.plot(x_np, exact, label="Exact", linewidth=4)
        plt.plot(x_np, pred, label="GP", linestyle="--", linewidth=4, color="tab:red")
        plt.fill_between(x_np, pred - 2 * std, pred + 2 * std, alpha=0.2, color="tab:orange", linewidth=2)
        plt.title(f"GP Approximation ($\\sigma={noise}$)\n$t = {t:.3f}$\nError (MSE) = ${mse:.4f}$")
        plt.xlabel('$x$')
        plt.ylabel('$u(t, x)$')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.legend()

    def train(self, nsteps, plot=True):
        """Run ``nsteps`` time steps, optionally saving a plot every 10 steps.

        Returns:
            ``(preds, vars, errors, figs)``: per-step numpy arrays of the
            prediction, variance and error, plus the saved figure paths.
        """
        preds, variances, errors, figs = [], [], [], []

        # round() so a float step count such as T/dt is not truncated.
        for i in range(1, int(round(nsteps)) + 1):
            t = i * self.dt
            pred, var, error = self.train_step(t)
            preds.append(pred.detach().cpu().numpy().squeeze())
            variances.append(var.detach().cpu().numpy().squeeze())
            errors.append(error.detach().cpu().numpy().squeeze())

            if plot and i % 10 == 0:
                self.plot_step(t)
                fname = f"./{exp_name}/{i}.png"
                plt.savefig(fname)
                plt.close()
                figs.append(fname)

        return preds, variances, errors, figs

In [3]:
t_domain = [0, 1]
x_domain = [0, 1]

npoints = 51  # number of collocation points in x (endpoints included)

T = 1      # final time
dt = 1e-3  # time step
# Integer step count; round() guards against float error in T/dt
# (e.g. int(0.3 / 0.1) == 2).
nsteps = int(round(T / dt))

x_u = torch.linspace(x_domain[0], x_domain[1], npoints)
# Initial condition u(0, x) = sin(pi x) corrupted by Gaussian noise scaled by `noise`.
u = torch.sin(np.pi*x_u) + noise*torch.randn_like(x_u)

# Dirichlet boundary data: u(t, 0) = u(t, 1) = 0.
x_b = torch.tensor([0., 1.])
u_b = torch.tensor([0., 0.])

# NOTE(review): S0 and xstar are not read by the visible cells (NumericalGP
# builds its own S0 matrix); kept in case other code depends on them.
S0 = torch.zeros(npoints)

xstar = torch.linspace(x_domain[0], x_domain[1], npoints)

In [4]:
# Plot the noisy initial data against the exact initial condition sin(pi x).
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(x_u.numpy(), u.numpy(), facecolors='none', edgecolors='r', s=100)
ax.plot(x_u.numpy(), torch.sin(np.pi*x_u).numpy(), linestyle="--", linewidth=4, color="tab:blue")
# Escaped backslash: "\s" is an invalid escape sequence in a normal string.
ax.set_title(f"Initial Data ($\\sigma={noise}$)\n$t = {0:.3f}$")
ax.set_xlabel('$x$')
ax.set_ylabel('$u(t, x)$')
ax.set_xlim([0, 1])
# NOTE(review): with large `noise` many data points fall outside this y-range.
ax.set_ylim([0, 1.2])
fig.savefig(f"./{exp_name}/init.png")
plt.close(fig)

In [5]:
# Build the GP stepper; all inputs/targets are passed as 2-D column tensors.
model = NumericalGP(
    xb=x_b.unsqueeze(1),
    xu=x_u.unsqueeze(1),
    ub=u_b.unsqueeze(1),
    u=u.unsqueeze(1),
    dt=dt,
)

In [None]:
# Run every time step. NOTE: `vars` shadows the builtin, but later cells read it by that name.
preds, vars, errors, figs = model.train(nsteps=nsteps, plot=True)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:19<00:00, 25.98it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:09<00:00, 50.02it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:10<00:00, 48.61it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:09<00:00, 50.80it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:10<00:00, 49.84it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:10<00:0

In [None]:
# Persist the per-step results (one row per time step) as .npy files.
for result_name, result_values in (("preds", preds), ("vars", vars), ("errors", errors)):
    np.save(f"./{exp_name}/{result_name}.npy", np.array(result_values))

In [None]:
# Stitch the saved 2-D snapshots into an animated GIF.
# imageio.v2.imread is used because top-level imageio.imread is deprecated.
with imageio.get_writer(f"./{exp_name}/gp_progression.gif", mode="I", duration=0.05, loop=0) as w:
    for fname in tqdm(figs):
        w.append_data(imageio.v2.imread(fname))

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Surface of all stored predictions over (t, x), with the initial data at t = 0.
outs = np.array(preds)  # one row per time step
n_t, n_x = outs.shape
# Times of the stored predictions: dt, 2*dt, ..., n_t*dt (t = 0 is the initial data).
tspace = np.linspace(0, n_t * dt, n_t + 1)[1:]
xspace = np.linspace(x_domain[0], x_domain[1], n_x)
T, X = np.meshgrid(tspace, xspace)

fig = plt.figure(figsize=(16,24))
ax1 = fig.add_subplot(1, 1, 1, projection='3d')
ax1.plot_surface(T.T, X.T, outs, cmap='viridis', alpha=0.9)
ax1.view_init(35,35)

eps = 0.025
ax1.set_xlim3d(left=0.-eps, right=n_t * dt + eps)
ax1.set_ylim3d(bottom=0.-eps, top=1+eps)
ax1.set_zlim3d(bottom=0.-eps, top=1.2+eps)

# Initial noisy data and exact initial condition, drawn in the t = 0 plane.
x0 = x_u.numpy()
t0 = np.zeros_like(x0)
ax1.scatter(t0, x0, u.numpy(), facecolors='none', edgecolors='r', s=100)
ax1.plot(t0, x0, torch.sin(np.pi*x_u).numpy(), linestyle='--', linewidth=4, color='r')

ax1.set_xlabel('$t$')
ax1.set_ylabel('$x$')
ax1.set_zlabel('$u(t,x)$')
ax1.set_title(f'GP Approximation ($\\sigma={noise}$)\nFinal MSE: {np.mean(errors):.6e}\nGP Variance: {np.mean(vars):.6e}')
plt.savefig(f"./{exp_name}/final.png", bbox_inches='tight')
plt.show()

In [None]:
# Render the growing (t, x) surface every 10 steps for the 3-D progression GIF.
outs = np.array(preds)  # loop-invariant: convert once
n_t, n_x = outs.shape
t_all = np.linspace(0, n_t * dt, n_t + 1)  # t_all[k] == k * dt
xspace = np.linspace(x_domain[0], x_domain[1], n_x)
x0 = x_u.numpy()
t0 = np.zeros_like(x0)
u_exact0 = torch.sin(np.pi*x_u).numpy()
eps = 0.025

d3_figs = []
for i in tqdm(range(9, n_t, 10)):
    # Rows 0..i hold t = dt .. (i+1)*dt, so the surface ends at the time in the title.
    T, X = np.meshgrid(t_all[1:i + 2], xspace)

    fig = plt.figure(figsize=(16,24))
    ax1 = fig.add_subplot(1, 1, 1, projection='3d')
    ax1.plot_surface(T.T, X.T, outs[:i + 1, :], cmap='viridis', alpha=0.9)
    ax1.view_init(35,35)

    ax1.set_xlim3d(left=0.-eps, right=n_t * dt + eps)
    ax1.set_ylim3d(bottom=0.-eps, top=1+eps)
    ax1.set_zlim3d(bottom=0.-eps, top=1.2+eps)

    # Initial noisy data and exact initial condition in the t = 0 plane.
    ax1.scatter(t0, x0, u.numpy(), facecolors='none', edgecolors='r', s=100)
    ax1.plot(t0, x0, u_exact0, linestyle='--', linewidth=4, color='r')

    ax1.set_xlabel('$t$')
    ax1.set_ylabel('$x$')
    ax1.set_zlabel('$u(t,x)$')
    ax1.set_title(f'GP Approximation ($\\sigma={noise}$)\n$t={(i+1)*dt:.3f}$')
    fname = f"./{exp_name}/d3_{i}.png"
    fig.savefig(fname, bbox_inches='tight')
    d3_figs.append(fname)
    plt.close(fig)

In [16]:
# Stitch the 3-D progression frames into an animated GIF.
# imageio.v2.imread is used because top-level imageio.imread is deprecated.
with imageio.get_writer(f"./{exp_name}/d3_gp_progression.gif", mode="I", duration=0.05, loop=0) as w:
    for fname in tqdm(d3_figs):
        w.append_data(imageio.v2.imread(fname))

  w.append_data(imageio.imread(fname))
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:41<00:00,  1.01s/it]
