In [1]:
%load_ext autoreload

In [2]:
%autoreload
import numpy as np
import jax.numpy as jnp
from dctkit.mesh import util, simplex
from dctkit.math.opt import optctrl as oc
import dctkit.dec.cochain as C
import dctkit as dt
import pygmsh
import pyvista as pv
from pyvista import themes

In [3]:
dt.config()
pv.set_jupyter_backend('trame')
pv.global_theme = themes.ParaViewTheme()

In [4]:
lc = 0.2
with pygmsh.geo.Geometry() as geom:
        poly = geom.add_polygon([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], lc)

        top, volume, lat = geom.extrude(poly, [0, 0, 1.0])
        
        geom.add_physical(poly, label="bottom")
        geom.add_physical(top, label="top")
        geom.add_physical(volume, label="volume")
        geom.add_physical(lat, label="lat")
        mesh = geom.generate_mesh()

#pv.plot(mesh)

In [5]:
bottom_faces_ids = mesh.cell_sets_dict['bottom']['triangle']

In [6]:
node_coords = mesh.points
num_nodes = node_coords.shape[0]
tet_node_tags = mesh.cells[2].data
print("number of nodes = ", num_nodes)
print("number of tets = ", tet_node_tags.shape[0])

number of nodes =  237
number of tets =  741


In [7]:
spx = simplex.SimplicialComplex(tet_node_tags, node_coords, is_well_centered=True)
spx.get_circumcenters()
spx.get_primal_volumes()
spx.get_dual_volumes()
spx.get_hodge_star()
# S.get_dual_edge_vectors()
# S.get_flat_weights()

In [8]:
# boundary conditions
# bottom_nodes = spx.S[2][bottom_faces_ids]
# bottom_nodes = np.unique(bottom_nodes.flatten())
bottom_nodes = np.argwhere(spx.node_coord[:,2]<1e-6).flatten()
top_nodes = np.argwhere(abs(spx.node_coord[:,2]-1.)<1e-6).flatten()
values = np.zeros(len(bottom_nodes)+len(top_nodes), dtype=dt.float_dtype)
boundary_values = (np.hstack((bottom_nodes,top_nodes)), values)
boundary_values

(array([  0,   1,   2,   3,   8,   9,  10,  11,  12,  13,  14,  15,  16,
         17,  18,  19,  20,  21,  22,  23,  56,  57,  58,  59,  60,  61,
         62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
         75,  76,  77,  78,  79,   4,   5,   6,   7,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39, 178,
        179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
        192, 193, 194, 195, 196, 197, 198, 199, 200, 201]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]))

In [9]:
from functools import partial
from dctkit.physics import poisson as p

def disspot(u, u_prev, deltat):
    u_coch = C.CochainP0(spx, u)
    u_prev_coch = C.CochainP0(spx, u_prev)
    u_diff = C.sub(u_coch, u_prev_coch)
    return (1/2)*C.inner_product(u_diff, u_diff)/deltat

energy = partial(p.energy_poisson, S=spx)

def obj(u, u_prev, f, k, boundary_values, gamma, deltat):
    en = energy(x=u, f=f, k=k, boundary_values=boundary_values, gamma=gamma)
    return en + disspot(u, u_prev, deltat)

k = 1.
f_vec = np.ones(num_nodes, dtype=dt.float_dtype)
gamma = 1000.
deltat = 0.1

u_0 = np.zeros(num_nodes, dt.float_dtype)
u_prev = u_0

In [10]:
sols = []
prb = oc.OptimizationProblem(dim=num_nodes, state_dim=num_nodes, objfun=obj)
for i in range(10):
    print("t = ", (i+1)*deltat)
    args = {'u_prev': u_prev, 'f': f_vec, 'k': k, 'boundary_values': boundary_values,
        'gamma': gamma, 'deltat': deltat}
    prb.set_obj_args(args)
    u = prb.run(u_prev, ftol_abs=1e-8, ftol_rel=1e-8)
    u_prev = u.__array__()
    sols.append(u)
prb.last_opt_result

t =  0.1
t =  0.2
t =  0.30000000000000004
t =  0.4
t =  0.5
t =  0.6000000000000001
t =  0.7000000000000001
t =  0.8
t =  0.9
t =  1.0


1

In [13]:
p = pv.Plotter()
p.add_mesh(mesh, scalars=sols[-1])
p.show()

Widget(value="<iframe src='http://localhost:44599/index.html?ui=P_0x7fdafd33d120_2&reconnect=auto' style='widt…

In [54]:
import meshio
filename = "timedata.xdmf"
points = mesh.points
cells = {"tetra": mesh.cells_dict["tetra"]}
with meshio.xdmf.TimeSeriesWriter(filename) as writer:
    writer.write_points_cells(points, cells)
    for i in range(10):
        writer.write_data(i, point_data={"u": sols[i]})