In [None]:
# Colab setup

# try:
#     from dlroms import *
# except:
#     !pip install git+https://github.com/NicolaRFranco/dlroms.git
#     from dlroms import *

# TODO: add an install step for dlroms_bayesian (imported by the next cell) to this Colab setup

In [1]:
import numpy as np
import torch
import os
import time
import random
import matplotlib.pyplot as plt
import gmsh

from dlroms import *
from dlroms_bayesian.bayesian import Bayesian
from dlroms_bayesian.svgd import SVGD
from dlroms_bayesian.utils import *

from IPython.display import clear_output as clc

In [None]:
# Setup: select the compute device and fix the seed

# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
	device = torch.device('cuda')
else:
	device = torch.device('cpu')

# set_seeds is provided by one of the star imports above; called with seed 0
set_seeds(0)

In [3]:
# Domain definition

def loop(v):
	"""Return `v` with its first entry appended again at the end."""
	return np.concatenate((v, v[[0]]))

brain = np.load(os.path.join('brain_meshes', 'brainshape.npz'))

# Start from the 'main' outline, then subtract the two holes one after the other.
# Each outline is subsampled (every 9th / 9th / 8th point) before building its polygon.
domain = fe.polygon(loop(brain['main'][::9]))
domain = domain - fe.polygon(loop(brain['hole1'][::9]))
domain = domain - fe.polygon(loop(brain['hole2'][::8]))

# Mesh and function space definition

# "H" mesh (brain-mesh40) with a CG-1 space; presumably the finer of the two -- confirm
mesh_H = fe.loadmesh(os.path.join('brain_meshes', 'brain-mesh40.xml'))
Vh_H = fe.space(mesh_H, 'CG', 1)
h_H = mesh_H.hmax()
nh_H = Vh_H.dim()

# "C" mesh (brain-mesh15) with a CG-1 space; presumably the coarser one -- confirm
mesh_C = fe.loadmesh(os.path.join('brain_meshes', 'brain-mesh15.xml'))
Vh_C = fe.space(mesh_C, 'CG', 1)
h_C = mesh_C.hmax()
nh_C = Vh_C.dim()

# Clear the cell output produced by the calls above
clc()

In [4]:
# Load snapshots

def _load_snapshots(path, label):
	"""Load one snapshot archive and move its arrays to `device`.

	Parameters
	----------
	path : str
		Location of the .npz archive; it must contain the keys 'mu' and 'u'.
	label : str
		Name used in the error message (e.g. "Training", "Test").

	Returns
	-------
	tuple
		(archive, N, mu, u): the object returned by np.load, the number of
		samples (first dimension of 'mu'), and the 'mu' / 'u' arrays as
		float32 torch tensors on `device`.

	Raises
	------
	FileNotFoundError
		If `path` does not exist. Previously the message was only printed and
		np.load then failed on the same missing file.
	"""
	if not os.path.exists(path):
		raise FileNotFoundError(f"{label} snapshots not found at {path}.")

	archive = np.load(path)
	mu_np = archive['mu']
	u_np = archive['u']
	n_samples = mu_np.shape[0]

	# Cast to float32 before building the tensors, then move them to the selected device
	mu = torch.tensor(mu_np.astype(np.float32)).to(device)
	u = torch.tensor(u_np.astype(np.float32)).to(device)
	return archive, n_samples, mu, u

path_train = os.path.join('snapshots', 'snapshots_train.npz')
data_train, N_train, mu_train, u_train = _load_snapshots(path_train, "Training")

path_test = os.path.join('snapshots', 'snapshots_test.npz')
data_test, N_test, mu_test, u_test = _load_snapshots(path_test, "Test")

In [5]:
# Bayesian network definition

# Three Geodesic layers over the same domain, chained Vh_H -> Vh_C -> Vh_C -> Vh_H.
# The last layer is built with activation=None; the first two use the default.
layer_1 = Geodesic(domain, Vh_H, Vh_C, support=0.05)
layer_2 = Geodesic(domain, Vh_C, Vh_C, support=0.1)
layer_3 = Geodesic(domain, Vh_C, Vh_H, support=0.05, activation=None)

# L2 object on Vh_H; it is passed to mse(...) and mre(...) in later cells
l2 = L2(Vh_H)
# Clear the cell output produced while building the layers and l2
clc()

# ROM model
model = DFNN(layer_1, layer_2, layer_3)

# Bayesian model
model_bayes = Bayesian(model)

# Move the Bayesian model and the l2 object to the GPU when CUDA is available
if torch.cuda.is_available():
	model_bayes.cuda()
	l2.cuda()

In [6]:
# SVGD trainer definition

# Number of SVGD particles; also reused as n_samples at evaluation time
# and in the checkpoint file name
N_particles = 30

trainer = SVGD(model_bayes, n_samples=N_particles)
# NOTE(review): presumably He initialization of the particles -- confirm against dlroms_bayesian
trainer.He()

model_bayes.set_trainer(trainer) # assign trainer to Bayesian model

In [None]:
# Bayesian network training

# Train on the training snapshots with the mse(l2) loss
model_bayes.train(
	mu_train,
	u_train,
	ntrain=N_train,
	loss=mse(l2),
	lr=0.02,
	epochs=3000,
)

In [None]:
# Bayesian network evaluation

# Sample the model on the training and test parameters with gradients disabled;
# each call's result is unpacked here as a (mean, variance) pair
with torch.no_grad():
	u_pred_bayes_train_mean, u_pred_bayes_train_var = model_bayes.sample(
		mu_train, n_samples=N_particles
	)
	u_pred_bayes_mean, u_pred_bayes_var = model_bayes.sample(
		mu_test, n_samples=N_particles
	)

# Build the mre(l2) metric once and apply it to both splits
relative_error = mre(l2)
error_train_mean = relative_error(u_train, u_pred_bayes_train_mean)
error_test_mean = relative_error(u_test, u_pred_bayes_mean)

# Report both errors as percentages
print(f"Relative train error: {100 * torch.mean(error_train_mean):.2f}%")
print(f"Relative test error: {100 * torch.mean(error_test_mean):.2f}%")

In [None]:
# Plot one test snapshot (fixed index, so the figure is the same on every run)

idx = 30

plt.figure(figsize=(16, 3))

# Panel 1: input parameter field
plt.subplot(1, 4, 1)
plt.title("Brain damage")
# A constant field of ones is drawn first, then the parameter field on top of it
fe.plot(1 + 0 * mu_test[idx], Vh_H, cmap='jet', vmin=0, vmax=1)
fe.plot(mu_test[idx], Vh_H, cmap='jet', colorbar=True)

# Panel 2: reference solution
plt.subplot(1, 4, 2)
plt.title("True time to recovery")
fe.plot(u_test[idx], Vh_H, vmin=0, vmax=1, cmap='jet', colorbar=True)

# Panel 3: predictive mean
plt.subplot(1, 4, 3)
plt.title("Predicted time to recovery (mean)")
fe.plot(u_pred_bayes_mean[idx], Vh_H, vmin=0, vmax=1, cmap='jet', colorbar=True)

# Panel 4: predictive variance, with the colour range fitted to this snapshot
plt.subplot(1, 4, 4)
plt.title("Predicted time to recovery (variance)")
# .item() turns the 0-dim tensors into plain Python floats, matching the plain
# numbers passed as vmin/vmax in the other panels and not tied to the tensor's device
vmin = torch.min(u_pred_bayes_var[idx]).item()
vmax = torch.max(u_pred_bayes_var[idx]).item()
fe.plot(u_pred_bayes_var[idx], Vh_H, vmin=vmin, vmax=vmax, cmap='magma', colorbar=True)

plt.tight_layout()

In [None]:
# Save trainer state

checkpoint_dir = 'checkpoints'
# exist_ok=True creates the directory only when it is missing, with no separate
# check-then-create step that could race with another process
os.makedirs(checkpoint_dir, exist_ok=True)

# The file name encodes the number of particles, e.g. particles_30.pth
trainer.save_particles(os.path.join(checkpoint_dir, f'particles_{N_particles}.pth'))