In [None]:
try:
    !wget "https://fem-on-colab.github.io/releases/firedrake-install-release-real.sh" -O "/tmp/firedrake-install.sh"
    !bash "/tmp/firedrake-install.sh"
    from firedrake import *  # noqa: F401
except:
    from firedrake import *  # noqa: F401

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib.animation import FuncAnimation

try:
    from firedrake.pyplot import tripcolor
except ImportError:  # older Firedrake releases exported plotting helpers at top level
    from firedrake import tripcolor

In [None]:
def navier_stokes(h=2**-10, degree=2, nu=0.01, timestep=2**-5, end_time=1.0,
                  nx=50, ny=50):
    """
    Simulate 2D incompressible Navier-Stokes flow past a cylinder.

    The channel [0, 2] x [0, 1] is discretised with a structured
    ``RectangleMesh``. The cylinder (centre (0.5, 0.5), radius 0.1) is
    imposed by Brinkman volume penalisation: a large drag term drives the
    velocity towards zero inside the disk. This avoids needing an external
    constructive-geometry mesher. Velocity and pressure use a Taylor-Hood
    pair (CG ``degree`` / CG ``degree - 1``), and time is advanced with
    backward Euler.

    Boundary conditions: uniform inflow u = (1, 0) on x = 0, no-slip on
    y = 0 and y = 1, and natural (do-nothing) outflow on x = 2.

    Parameters
    ----------
    h : float
        Unused; accepted only for backward compatibility. Use ``nx`` and
        ``ny`` to control the mesh resolution.
    degree : int
        Polynomial degree of the velocity space (>= 2). Pressure uses
        ``degree - 1``.
    nu : float
        Kinematic viscosity.
    timestep : float
        Time-step size (> 0).
    end_time : float
        Final simulation time (>= 0).
    nx, ny : int
        Number of mesh cells in x and y.

    Returns
    -------
    IPython.display.HTML
        A JavaScript animation of the velocity magnitude. As a side effect,
        an energy-versus-time plot is shown.

    Raises
    ------
    ValueError
        If ``degree < 2``, ``timestep <= 0`` or ``end_time < 0``.
    """
    if degree < 2:
        raise ValueError("degree must be >= 2 for a Taylor-Hood pair")
    if timestep <= 0:
        raise ValueError("timestep must be positive")
    if end_time < 0:
        raise ValueError("end_time must be non-negative")

    length, height = 2.0, 1.0
    mesh = RectangleMesh(nx, ny, length, height)
    x, y = SpatialCoordinate(mesh)

    V = VectorFunctionSpace(mesh, 'CG', degree)
    Q = FunctionSpace(mesh, 'CG', degree - 1)
    W = V * Q

    w = Function(W)      # current (velocity, pressure); also the Newton initial guess
    u, p = split(w)
    v, q = TestFunctions(W)
    u_n = Function(V)    # velocity at the previous time level

    # RectangleMesh boundary ids: 1 = {x = 0}, 2 = {x = Lx}, 3 = {y = 0}, 4 = {y = Ly}.
    # Boundary 2 (outflow) gets no Dirichlet condition, so mass can leave the channel.
    bc_inlet = DirichletBC(W.sub(0), Constant((1.0, 0.0)), 1)
    bc_walls = DirichletBC(W.sub(0), Constant((0.0, 0.0)), (3, 4))
    bcs = [bc_inlet, bc_walls]

    # Brinkman penalisation marking the cylinder region.
    cx, cy, radius = 0.5, 0.5, 0.1
    inside = conditional(lt((x - cx)**2 + (y - cy)**2, radius**2), 1.0, 0.0)
    penalty = 1.0e5

    F = (inner((u - u_n) / timestep, v) * dx
         + inner(dot(grad(u), u), v) * dx          # convection (u . grad) u
         + nu * inner(grad(u), grad(v)) * dx       # viscous diffusion
         - p * div(v) * dx
         - q * div(u) * dx                         # incompressibility
         + penalty * inside * inner(u, v) * dx)    # obstacle

    sp = {
        'snes_max_it': 100,
        'mat_type': 'aij',
        'ksp_type': 'preonly',
        'pc_type': 'lu',
        'pc_factor_mat_solver_type': 'mumps',
    }
    # Build the solver once and reuse it every time step.
    problem = NonlinearVariationalProblem(F, w, bcs=bcs)
    solver = NonlinearVariationalSolver(problem, solver_parameters=sp)

    fig, ax = plt.subplots(figsize=(10, 6))
    times = []
    energy = []
    state = {'t': 0.0}

    def update(frame):
        # Frame 0 only draws the initial state. It may be drawn more than once
        # (matplotlib also uses it as the init frame), so time advances only
        # for frame > 0.
        if frame > 0:
            state['t'] += timestep
            print(f'Solving for time t = {state["t"]:.4f}:')
            solver.solve()
            u_n.interpolate(w.sub(0))

            # Record kinetic energy at the new time level.
            times.append(state['t'])
            energy.append(0.5 * assemble(inner(u_n, u_n) * dx))

        ax.clear()
        tripcolor(u_n, axes=ax)  # colours show |u| for a vector field
        ax.set_title(f'Navier-Stokes (t = {state["t"]:.2f})')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_xlim(0.0, length)
        ax.set_ylim(0.0, height)
        ax.set_aspect('equal')

    # round() avoids losing a step to floating-point truncation (e.g. 0.3 / 0.1).
    num_frames = int(round(end_time / timestep)) + 1
    anim = FuncAnimation(fig, update, frames=num_frames, interval=100)

    # The time stepping happens inside `update`, i.e. only while the animation
    # is rendered. Render it BEFORE plotting the energy history.
    animation_html = anim.to_jshtml()
    plt.close(fig)

    # Plot energy
    _, energy_ax = plt.subplots()
    energy_ax.plot(times, energy)
    energy_ax.set_title('Energy over time')
    energy_ax.set_xlabel('Time')
    energy_ax.set_ylabel('Energy')
    energy_ax.grid(True)
    plt.show()

    return HTML(animation_html)
