# Packages and Functions

In [7]:
import math
import os
import sys; sys.path.append('..')
import dgl
import torch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from fenics import *
from dolfin import *
from mshr import *
from mpl_toolkits.axes_grid1 import make_axes_locatable
from dgl.data.utils import load_graphs
from src.utils.to_dgl import fenics_to_graph as to_dgl
from src.utils.gif import gif_generator

In [2]:
def plot_graph(graph, vmin=0, vmax=1):
    """Draw a mesh graph (left panel) and a filled contour of its node values (right panel).

    Args:
        graph: DGL graph whose ``ndata`` holds 'x', 'y' and 'value' tensors
            (each flattened to 1-D here); its edges are drawn as line segments.
        vmin, vmax: bounds of the color normalization used by both panels.

    Returns:
        matplotlib.figure.Figure: the newly created figure. It is also left as
        pyplot's current figure, so ``plt.savefig`` / ``plt.close`` act on it.
    """
    x = graph.ndata['x'].view(-1).numpy()
    y = graph.ndata['y'].view(-1).numpy()
    value = graph.ndata['value'].view(-1).numpy()
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)

    figure = plt.figure(figsize=(11, 5))

    # Left panel: mesh edges plus nodes colored by value.
    # The third positional argument of scatter is the marker *size*, so the
    # value must be passed as `c=` for it to show up as color.
    ax_mesh = figure.add_subplot(1, 2, 1)
    src, dst = graph.edges()
    src = src.numpy()
    dst = dst.numpy()
    # segments has shape (num_edges, 2 endpoints, 2 coords).
    segments = np.stack([np.column_stack((x[src], y[src])),
                         np.column_stack((x[dst], y[dst]))], axis=1)
    # A single LineCollection instead of one plot() call per edge.
    ax_mesh.add_collection(matplotlib.collections.LineCollection(
        segments, colors='black', alpha=0.9, linewidths=0.6))
    ax_mesh.scatter(x, y, c=value, norm=norm)

    # Right panel: interpolated filled contour with a colorbar matched to the axes height.
    # NOTE(review): with levels=30 the contour levels follow the data range, not
    # [vmin, vmax]; pass explicit levels if frames must share one color scale.
    ax_contour = figure.add_subplot(1, 2, 2)
    contour = ax_contour.tricontourf(x, y, value, levels=30, norm=norm)
    cax = make_axes_locatable(ax_contour).append_axes("right", size="5%", pad=0.05)
    figure.colorbar(contour, cax=cax)
    return figure

In [3]:
def mark_boundary(graph, lower=-2.0, upper=2.0, error=1e-2):
    """Flag nodes that lie on the edges of the square [lower, upper] x [lower, upper].

    A node is flagged when its x or y coordinate is within ``error`` of either
    ``lower`` or ``upper``. The defaults reproduce the original [-2, 2] domain.

    Args:
        graph: DGL-like graph whose ``ndata`` holds 'x' and 'y' tensors with one
            entry per node (any shape that flattens to length N).
        lower, upper: the square's bounds along both axes.
        error: absolute tolerance for "on the boundary".

    Side effects:
        Sets ``graph.ndata['is_bdd']`` to a float tensor of shape (N, 1) holding
        1 for boundary nodes and 0 otherwise.
    """
    # Compare in float64 so the tolerance test matches a Python-float comparison.
    x = graph.ndata['x'].reshape(-1).double()
    y = graph.ndata['y'].reshape(-1).double()
    on_boundary = (((x - upper).abs() < error) | ((x - lower).abs() < error)
                   | ((y - upper).abs() < error) | ((y - lower).abs() < error))
    is_bdd = torch.zeros(graph.number_of_nodes(), 1)
    is_bdd[on_boundary] = 1
    graph.ndata['is_bdd'] = is_bdd

In [28]:
def adaptive_graph(graph1, graph2, update_rate=0.5, dist_threshold=1e-1):
    """Compute a displacement vector for each node from the change between two solution graphs.

    The per-node weight is ``|graph2.value - graph1.value|``. Each interior node i
    is pulled toward every out-neighbour j whose weight is at least its own:

        update_rate * (p_j - p_i) * (w_j - w_i) / |p_j - p_i|

    where ``p`` is the (x, y) position and ``|.|`` the Euclidean distance. Nodes
    flagged by ``mark_boundary`` get a zero vector. Neighbours closer than
    ``dist_threshold`` are skipped.

    Args:
        graph1, graph2: graphs whose ``ndata`` hold 'x', 'y' and 'value';
            presumably they share structure and node order (graph1's is used) — confirm.
        update_rate: scale applied to every contribution.
        dist_threshold: minimum neighbour distance that contributes a pull.

    Returns:
        list of float tensors of shape (2,), one (dx, dy) per node, in node order.
    """
    update_graph = graph1.clone()
    update_graph.ndata['value'] = abs(graph2.ndata['value'] - graph1.ndata['value'])
    mark_boundary(update_graph)

    # Pull features into Python lists once instead of calling .item() inside the double loop.
    xs = update_graph.ndata['x'].reshape(-1).tolist()
    ys = update_graph.ndata['y'].reshape(-1).tolist()
    values = update_graph.ndata['value'].reshape(-1).tolist()
    is_bdd = update_graph.ndata['is_bdd'].reshape(-1).tolist()

    vector_list = []
    for i in range(update_graph.number_of_nodes()):
        vector = torch.zeros(2, dtype=torch.float)
        if is_bdd[i] > 0.5:
            vector_list.append(vector)  # boundary nodes are not moved
            continue

        _, neighbors = update_graph.out_edges([i])
        for j in neighbors.tolist():
            diff = values[j] - values[i]
            if diff < 0:
                continue  # only pull toward neighbours with a larger change
            dx = xs[j] - xs[i]
            dy = ys[j] - ys[i]
            dist = math.hypot(dx, dy)  # full 2-D Euclidean distance
            if dist < dist_threshold:
                continue  # nodes too close: avoid dividing by a tiny distance
            vector[0] += update_rate * dx * diff / dist
            vector[1] += update_rate * dy * diff / dist

        vector_list.append(vector)

    return vector_list

In [4]:
def solve_function(mesh, stop=1, steps=60, dy=False):
    """Time-step a heat problem on ``mesh`` and return one DGL graph per step.

    The variational form below is backward Euler for u_t = Δu + f on the square
    [-2, 2]^2, with f = 0, u(0) = 0 and Dirichlet data '1' on the top side and
    '0' on the bottom, left and right sides.

    Args:
        mesh: dolfin mesh covering [-2, 2]^2.
        stop: final time; the step size is ``stop / steps``.
        steps: number of time steps, which is also the number of returned graphs.
        dy: if True, rebuild the boundary conditions every step with the current
            time ``t`` (only useful for time-dependent strings such as 'sin(t)').

    Returns:
        list of ``steps`` graphs, produced by ``to_dgl`` after each solve.
    """
    x0, xn, y0, yn = -2, 2, -2, 2
    f_code = '0'
    ud_top = '1'  # e.g. 'sin(t)' for time-dependent data (then pass dy=True)
    ud_bottom = '0'
    ud_left = '0'
    ud_right = '0'
    u0_code = '0'
    cell_size = 3  # not used by rectangle; kept for its signature
    tol = 1e-2

    # rectangle returns (function_space, bc).
    function_space, bc = rectangle(x0, xn, y0, yn,
                                   ud_top, ud_bottom, ud_left, ud_right,
                                   cell_size, tol, mesh)

    dt = stop / steps
    un = interpolate(Expression(u0_code, degree=2), function_space)
    f = Expression(f_code, degree=2)

    u = TrialFunction(function_space)
    v = TestFunction(function_space)
    F = u * v * dx + dt * dot(grad(u), grad(v)) * dx - (un + dt * f) * v * dx
    a, L = lhs(F), rhs(F)

    u = Function(function_space)
    t = 0
    graphs = []
    for _ in range(steps):
        t += dt
        if dy:
            # Rebuild the BCs at time t on the same function space the forms use.
            _, bc = rectangle(x0, xn, y0, yn,
                              ud_top, ud_bottom, ud_left, ud_right,
                              cell_size, tol, mesh,
                              t=t, fs=function_space)
        solve(a == L, u, bc)
        un.assign(u)
        graphs.append(to_dgl(function=u, mesh=mesh))

    return graphs

In [10]:
def rectangle(x0,
              xn,
              y0,
              yn,
              ud_top,
              ud_bottom,
              ud_left,
              ud_right,
              cell_size,
              tol,
              mesh=None,
              t=-1,
              ms=None,
              fs=None):
    """Build a P1 function space and Dirichlet BCs on the four sides of [x0, xn] x [y0, yn].

    Args:
        x0, xn, y0, yn: x of the left/right sides and y of the bottom/top sides.
        ud_top, ud_bottom, ud_left, ud_right: expression strings for the boundary
            values, each compiled with ``Expression(..., degree=2)``.
        cell_size: unused; kept for call compatibility.
        tol: tolerance passed to ``near`` when marking boundary facets.
        mesh: the mesh. May be omitted if ``ms`` or ``fs`` is given.
        t: if >= 0, every boundary expression gets a parameter ``t`` with this
            value (needed for strings such as 'sin(t)').
        ms: alternative way to pass the mesh; used only when ``mesh`` is None.
        fs: existing function space to reuse instead of creating a new one, so
            the BCs live on the same space as the caller's forms.

    Returns:
        (function_space, bc): the space and a list of four DirichletBC objects,
        in the order top, bottom, left, right.

    Raises:
        ValueError: if no mesh can be determined from ``mesh``, ``ms`` or ``fs``.
    """
    if mesh is None:
        mesh = ms if ms is not None else (fs.mesh() if fs is not None else None)
    if mesh is None:
        raise ValueError("rectangle() needs a mesh: pass `mesh`, `ms` or `fs`")

    function_space = fs if fs is not None else FunctionSpace(mesh, 'P', 1)

    # Facet markers: 1 = top, 2 = bottom, 3 = left, 4 = right, 0 = everything else.
    boundaries = MeshFunction('size_t', mesh, mesh.topology().dim() - 1)
    boundaries.set_all(0)
    sides = [
        (YBoundary(yn, tol), 1),
        (YBoundary(y0, tol), 2),
        (XBoundary(x0, tol), 3),
        (XBoundary(xn, tol), 4),
    ]
    for subdomain, marker in sides:
        subdomain.mark(boundaries, marker)

    # Only time-dependent runs supply `t` to the compiled expressions.
    expr_kwargs = {'t': t} if t >= 0 else {}
    side_values = (ud_top, ud_bottom, ud_left, ud_right)
    # NOTE(review): corner dofs touch two sides; presumably the BC applied last
    # (left/right) wins there — confirm against the DOLFIN solve semantics.
    bc = [DirichletBC(function_space, Expression(code, degree=2, **expr_kwargs), boundaries, marker)
          for code, marker in zip(side_values, (1, 2, 3, 4))]

    return function_space, bc


def boundary(x, on_boundary):
    """Return the ``on_boundary`` flag unchanged; the point ``x`` is ignored."""
    is_on_boundary = on_boundary
    return is_on_boundary


class CircleBoundary(SubDomain):
    """Sub-domain of points whose distance from centre (x, y) is within ``tol`` of ``r``."""

    def __init__(self, x, y, r, tol):
        super().__init__()
        self.x, self.y = x, y
        self.r = r
        self.tol = tol

    def inside(self, x, on_boundary):
        # `on_boundary` is ignored: any point near the circle qualifies.
        centre = np.array([self.x, self.y])
        return near(np.linalg.norm(x - centre), self.r, self.tol)


class XBoundary(SubDomain):
    """Sub-domain of points whose x-coordinate is within ``tol`` of ``value`` (the line x = value)."""

    def __init__(self, value, tol):
        super().__init__()
        self.value, self.tol = value, tol

    def inside(self, x, on_boundary):
        x_coord = x[0]
        return near(x_coord, self.value, self.tol)


class YBoundary(SubDomain):
    """Sub-domain of points whose y-coordinate is within ``tol`` of ``value`` (the line y = value)."""

    def __init__(self, value, tol):
        super().__init__()
        self.value, self.tol = value, tol

    def inside(self, x, on_boundary):
        y_coord = x[1]
        return near(y_coord, self.value, self.tol)

# Main Process

In [8]:
# Load Data: load_graphs returns (graphs, labels); only the graphs are used here.
graph_list, _labels = load_graphs('../data/gsi2.bin')
print(len(graph_list))

50


In [None]:
# Epoch Test: move mesh vertices with adaptive_graph, re-solving on the moved mesh each epoch.
# Cost: every epoch re-runs the full solve_function time series.
# NOTE(review): assumes to_dgl numbers graph nodes in mesh.coordinates() vertex order — confirm.
mesh = generate_mesh(Rectangle(Point(-2, -2), Point(2, 2)), 4)
mesh_list = []
plot_graph_list = []
adaptive_rate = 0.1
steps = 60  # should not exceed the number of graphs solve_function returns
for epoch in range(steps):
    # Separate name so the graphs loaded from disk earlier are not overwritten.
    solved_graphs = solve_function(mesh)
    # The last time step has no successor to compare against, so the mesh is left as is.
    if epoch + 1 < len(solved_graphs):
        vector_list = adaptive_graph(solved_graphs[epoch], solved_graphs[epoch + 1], adaptive_rate)
        displacement = torch.stack(vector_list).numpy()
        coordinates = mesh.coordinates()
        assert displacement.shape == coordinates.shape, (displacement.shape, coordinates.shape)
        # coordinates() exposes the mesh's vertex array; adding in place moves the mesh.
        coordinates += displacement
    plot_graph_list.append(solved_graphs[epoch])
    # Store a copy: appending `mesh` itself would keep the same, still-moving object every epoch.
    mesh_list.append(Mesh(mesh))
    adaptive_rate += 0.1

In [None]:
# Plot Graphs: save one PNG per epoch as frames for the GIF cell.
frames_dir = '../fig/gaussian_side'
os.makedirs(frames_dir, exist_ok=True)
for i, graph in enumerate(plot_graph_list):
    plot_graph(graph, vmax=0.9)
    plt.savefig(f'{frames_dir}/{i}.png')
    # plot_graph opens a new figure per call; close it so dozens of figures don't pile up.
    plt.close()

In [37]:
# Generate GIF from the numbered PNG frames saved above.
frames_dir = '../fig/gaussian_side/'
gif_path = '../fig/gaussian_side.gif'
gif_generator(frames_dir, gif_path)

100%|███████████████████████████████████████████| 30/30 [00:00<00:00, 36.44it/s]


In [10]:
f'{1:.2f}'

'1.00'

# Trash Bin (scratch cells, not part of the main pipeline)

In [None]:
# Solve on a fixed (never moved) mesh — presumably the non-adaptive baseline.
domain = Rectangle(Point(-2, -2), Point(2, 2))
mesh = generate_mesh(domain, 4)
graph_list = solve_function(mesh)

In [None]:
# Save one PNG per time step of the fixed-mesh run.
frames_dir = '../fig/gaussian_side_noadmesh'
os.makedirs(frames_dir, exist_ok=True)
for i, graph in enumerate(graph_list):
    plot_graph(graph, vmax=0.9)
    plt.savefig(f'{frames_dir}/{i}.png')
    # Close each figure so they don't accumulate in memory.
    plt.close()

In [14]:
frames_dir = '../fig/gaussian_side_noadmesh/'
gif_path = '../fig/gaussian_side_noadmesh.gif'
gif_generator(frames_dir, gif_path)

100%|███████████████████████████████████████████| 60/60 [00:01<00:00, 35.90it/s]


In [23]:
# Standalone time-stepping run of the same heat problem as solve_function, written inline.
# NOTE(review): uses the global `mesh` from an earlier cell (e.g. the Trash Bin mesh cell); run that first.
x0 = -2
xn = 2
y0 = -2
yn = 2
stop = 1
steps = 60
f = '0'
# ud_top = 'sin(t)'  # time-dependent alternative; set dy = True to use it
ud_top = '1'
ud_bottom = '0'
ud_left = '0'
ud_right = '0'
u0 = '0'
cell_size = 5  # not used by rectangle
tol = 1e-2
dy = False

# rectangle returns (function_space, bc); with dy the boundary expressions get a `t` parameter.
bc_time = {'t': 0} if dy else {}
function_space, bc = rectangle(x0, xn, y0, yn,
                               ud_top, ud_bottom, ud_left, ud_right,
                               cell_size, tol, mesh, **bc_time)

dt = stop / steps
# New names for the Expressions keep u0 / f as strings, so re-running this cell still works.
u0_expr = Expression(u0, degree=2)
f_expr = Expression(f, degree=2)
un = interpolate(u0_expr, function_space)

u = TrialFunction(function_space)
v = TestFunction(function_space)
F = u * v * dx + dt * dot(grad(u), grad(v)) * dx - (un + dt * f_expr) * v * dx
a, L = lhs(F), rhs(F)

u = Function(function_space)
t = 0
graphs = []
for _ in range(steps):
    t += dt
    if dy:
        # Rebuild the BCs at time t on the same function space as the forms.
        _, bc = rectangle(x0, xn, y0, yn,
                          ud_top, ud_bottom, ud_left, ud_right,
                          cell_size, tol, mesh,
                          t=t, fs=function_space)
    solve(a == L, u, bc)
    un.assign(u)
    graphs.append(to_dgl(function=u, mesh=mesh))

Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational p



Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
