# Lecture 3
Some interactive examples are shown below. Play around with them to get a better feel for how we approximate functions in FEM.

## Example curve fitting
This shows curve fitting using piecewise linear/quadratic interpolation. The main takeaway is that we can approximate complex functions using simple functions if we compensate with more elements. This is something we use when constructing our FE approximations.

In [1]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display

# =============== User function ===============
def f(x):
    """Example function to approximate: sin(3x) * sin(5x - 2).

    Works on scalars and NumPy arrays alike (evaluated element-wise).
    """
    slow_wave = np.sin(3*x)
    fast_wave = np.sin(5*x - 2)
    return slow_wave * fast_wave

# =============== Domain & sampling ===============
# Interval on which the function is plotted and approximated
x_min = 0.0
x_max = np.pi

# Fine sampling of the exact function, used only to draw the reference curve
x_dense = np.linspace(x_min, x_max, num=1000)
y_dense = f(x_dense)

# =============== Helpers ===============
def unique_preserve_order(arr, tol=0.0):
    """Drop repeated values from ``arr`` while keeping first-seen order.

    A value counts as a repeat when it lies within ``tol`` (absolute
    difference, inclusive) of any value that has already been kept.

    Returns:
        A NumPy array holding the kept values in their original order.
    """
    kept = []
    for candidate in arr:
        for seen in kept:
            if abs(candidate - seen) <= tol:
                break  # already represented; skip this value
        else:
            kept.append(candidate)
    return np.array(kept)

def build_linear(n_elements):
    """Sample ``f`` at the nodes of a uniform mesh on [x_min, x_max].

    Args:
        n_elements: Number of equally sized elements; the mesh has
            ``n_elements + 1`` nodes.

    Returns:
        ``(x_nodes, y_nodes)``: node positions and the values of ``f`` there.
        Joining them with straight segments gives the piecewise linear curve.
    """
    node_count = n_elements + 1
    x_nodes = np.linspace(x_min, x_max, node_count)
    return x_nodes, f(x_nodes)

def build_quadratic(n_elements, samples_per_element=25):
    """Interpolate ``f`` with one parabola per element on [x_min, x_max].

    Each element uses three nodes (both ends and the midpoint); the parabola
    through the values of ``f`` at those nodes is sampled for plotting.

    Args:
        n_elements: Number of equally sized elements.
        samples_per_element: Number of plot samples taken on each element,
            end points included.

    Returns:
        ``(x_curve, y_curve, x_qnodes, y_qnodes)``. The curve arrays hold the
        concatenated per-element samples (points on a shared element boundary
        appear once per adjacent element). The node arrays hold every
        interpolation node once, with ``f`` evaluated there.
    """
    span = x_max - x_min
    curve_x, curve_y = [], []
    element_nodes = []

    for idx in range(n_elements):
        left = x_min + idx * span / n_elements
        right = x_min + (idx + 1) * span / n_elements
        mid = 0.5 * (left + right)

        node_x = np.array([left, mid, right])
        node_y = f(node_x)

        # Fit in coordinates measured from the element's left end
        a, b, c = np.polyfit(node_x - left, node_y, deg=2)

        sample_x = np.linspace(left, right, samples_per_element)
        s = sample_x - left
        sample_y = a*s**2 + b*s + c

        curve_x.extend(sample_x.tolist())
        curve_y.extend(sample_y.tolist())
        element_nodes.extend([left, mid, right])

    # Neighbouring elements share an end node; keep each node only once
    x_qnodes = unique_preserve_order(element_nodes, tol=1e-12)
    y_qnodes = f(x_qnodes)
    return np.array(curve_x), np.array(curve_y), x_qnodes, y_qnodes

# =============== Figure skeleton ===============
# Build the figure with five traces. Their ORDER matters: set_visibility()
# returns one flag per trace in exactly this order, and update() applies the
# flags positionally. update() also finds the traces to refresh by their
# `name`, so the names below must stay in sync with the ones used there.
fig = go.FigureWidget(layout=dict(
    title="Piecewise Linear / Quadratic Approximation",
    xaxis_title="x",
    yaxis_title="f(x)",
    template="plotly_white",
    height=520,
    uirevision='keep'  # preserves zoom/pan when updating
))

# Trace 0: original function (dense sampling, always shown)
fig.add_scatter(x=x_dense, y=y_dense, mode="lines",
                name="Original Function", line=dict(color="black", width=2))

# Trace 1: linear line (blue), initial data for 4 elements
x_lin0, y_lin0 = build_linear(n_elements=4)
fig.add_scatter(x=x_lin0, y=y_lin0, mode="lines",
                name="Linear (P1)", line=dict(color="#1f77b4", width=3), visible=True)

# Trace 2: linear nodes (blue markers), same data as the line, hidden from the legend
fig.add_scatter(x=x_lin0, y=y_lin0, mode="markers",
                name="Linear nodes", marker=dict(color="#1f77b4", size=8, symbol="circle-open"),
                showlegend=False, visible=True)

# Trace 3: quadratic curve (red), initially hidden
xq0, yq0, xqnodes0, yqnodes0 = build_quadratic(n_elements=4, samples_per_element=25)
fig.add_scatter(x=xq0, y=yq0, mode="lines",
                name="Quadratic (P2, per-element)", line=dict(color="#d62728", width=3),
                visible=False)

# Trace 4: quadratic nodes (red markers), initially hidden, not in the legend
fig.add_scatter(x=xqnodes0, y=yqnodes0, mode="markers",
                name="Quadratic nodes", marker=dict(color="#d62728", size=8, symbol="circle-open"),
                showlegend=False, visible=False)

# =============== Controls ===============
# Number of elements used by both approximations (1 to 50, starts at 4)
elem_slider = widgets.IntSlider(
    value=4,
    min=1,
    max=50,
    step=1,
    description="Elements",
    continuous_update=True,
)

# Which approximation(s) to display; the option strings are the values
# that set_visibility() compares against
mode_toggle = widgets.ToggleButtons(
    options=["Linear", "Quadratic", "Both"],
    value="Linear",
    description="Mode",
    button_style="",
)

def set_visibility(mode):
    """Return the per-trace visibility flags for a display mode.

    The list has one entry per trace: original function, linear line,
    linear nodes, quadratic curve, quadratic nodes. The original function is
    always visible. ``"Linear"`` hides the quadratic traces, ``"Quadratic"``
    hides the linear traces, and any other value shows everything.
    """
    show_linear = mode != "Quadratic"
    show_quadratic = mode != "Linear"
    return [True, show_linear, show_linear, show_quadratic, show_quadratic]

def update(_=None):
    """Refresh the curve-fitting figure from the current widget state.

    Reads the element count from ``elem_slider`` and the display mode from
    ``mode_toggle``, rebuilds the linear and quadratic approximations, and
    pushes the new data and trace visibility to the figure in one batch.

    Args:
        _: Ignored. Present so the function can be registered as a widget
            observer (which passes a change notification) and also be called
            with no arguments for the initial render.
    """
    # Pin the figure on the first call, i.e. the initial render that runs
    # right after this cell has built it. A later cell rebinds the global
    # name `fig` to a different FigureWidget; looking `fig` up on every call
    # would then leave this figure frozen and apply the visibility flags to
    # the other one. Re-running this cell redefines `update`, so the pin is
    # renewed together with the figure.
    try:
        figure = update._figure
    except AttributeError:
        figure = update._figure = fig

    n = elem_slider.value
    mode = mode_toggle.value

    xl, yl = build_linear(n)
    xq, yq, xqn, yqn = build_quadratic(n)

    # New data per trace name; traces not listed here (the original
    # function) keep their data.
    new_data = {
        "Linear (P1)": (xl, yl),
        "Linear nodes": (xl, yl),
        "Quadratic (P2, per-element)": (xq, yq),
        "Quadratic nodes": (xqn, yqn),
    }

    with figure.batch_update():
        for trace in figure.data:
            if trace.name in new_data:
                trace.x, trace.y = new_data[trace.name]

        # Visibility flags are positional: one per trace, in creation order
        for trace, visible in zip(figure.data, set_visibility(mode)):
            trace.visible = visible

# Re-run update() whenever the value of either control changes
elem_slider.observe(update, names="value")
mode_toggle.observe(update, names="value")

# Initial render: bring the traces in line with the controls' starting
# values, then show the figure with the controls side by side below it
update()
display(fig, widgets.HBox([elem_slider, mode_toggle]))

FigureWidget({
    'data': [{'line': {'color': 'black', 'width': 2},
              'mode': 'lines',
              'name': 'Original Function',
              'type': 'scatter',
              'uid': 'b7c412ea-1372-417d-979d-1f55b8feaa21',
              'visible': True,
              'x': {'bdata': ('AAAAAAAAAAAo7vUH/sJpPyju9Qf+wn' ... 'DVGRUJQJwvwpSKGwlAGC1EVPshCUA='),
                    'dtype': 'f8'},
              'y': {'bdata': ('AAAAAAAAAIA3p8IEWbGBvxSJ/BHVz5' ... 'WqZk6RP/xNGgifcIE/5JCVHoMSuDw='),
                    'dtype': 'f8'}},
             {'line': {'color': '#1f77b4', 'width': 3},
              'mode': 'lines',
              'name': 'Linear (P1)',
              'type': 'scatter',
              'uid': '6b7ec144-7f17-4135-bf0f-77019d57f4fd',
              'visible': True,
              'x': {'bdata': 'AAAAAAAAAAAYLURU+yHpPxgtRFT7Ifk/0iEzf3zZAkAYLURU+yEJQA==', 'dtype': 'f8'},
              'y': {'bdata': 'AAAAAAAAAICltj0LBTXlPwpyU1cmoto/ZfZPfsePz7/kkJUegxK4PA==', 'dtype': 'f8

HBox(children=(IntSlider(value=4, description='Elements', max=50, min=1), ToggleButtons(description='Mode', op…

## Example showing how a piecewise linear approximation is constructed
$$
    u(x) \approx u_h(x) = \sum_{i=1}^5 N_i(x)a_i = N_1(x) a_1 + N_2(x) a_2 + N_3(x) a_3 + N_4(x) a_4 + N_5(x) a_5
$$
Use the sliders to see how each DOF $a_i$ influences the approximation. Note that the value of a DOF, for example $a_2$, is the actual value of the approximation $u_h$ at that node.

(The code below for the actual plotting is not relevant, only the plot.)

In [2]:
# --- Piecewise-linear FEM approximation with target function (Jupyter) ---

# Mesh: 4 elements on [0, 1] -> 5 nodes
nodes = np.linspace(0.0, 1.0, num=5)  # [0, 0.25, 0.5, 0.75, 1.0]
nn = nodes.size  # number of nodes

# Dense grid used only for drawing the curves
x = np.linspace(start=0.0, stop=1.0, num=1000)

# Target function: f(x) = sin(pi*x) * x
def f_target(x):
    """Example target function sin(pi*x) * x (scalar or element-wise on arrays)."""
    half_wave = np.sin(np.pi * x)
    return half_wave * x

# Define P1 hat basis
def hat_basis(i, x, nodes):
    """Evaluate the P1 "hat" basis function of node ``i`` at the points ``x``.

    The function is 1 at ``nodes[i]``, falls linearly to 0 at the two
    neighbouring nodes and is 0 everywhere else. For the first and last node
    only the one existing side of the hat is built.

    Args:
        i: Index of the node the basis function belongs to.
        x: Evaluation points (array or list of numbers).
        nodes: Increasing sequence of node coordinates.

    Returns:
        Float array with the basis-function value at each point of ``x``.
    """
    # Always work in floating point. Building the result with zeros_like on
    # an integer array would give an integer result and truncate the
    # fractional ramp values; the conversion also allows plain lists.
    x = np.asarray(x, dtype=float)
    N = np.zeros_like(x)
    xi = nodes[i]

    if i > 0:
        # Rising side: from the previous node (value 0) up to this node (value 1)
        x_prev = nodes[i - 1]
        left = (x >= x_prev) & (x <= xi)
        N[left] = (x[left] - x_prev) / (xi - x_prev)

    if i < len(nodes) - 1:
        # Falling side: from this node (value 1) down to the next node (value 0)
        x_next = nodes[i + 1]
        right = (x >= xi) & (x <= x_next)
        N[right] = (x_next - x[right]) / (x_next - xi)

    return N

# Evaluate every basis function once on the plotting grid; row k holds
# basis function k, so the matrix has shape (nn, len(x))
N = np.vstack([hat_basis(k, x, nodes) for k in range(nn)])

def assemble_u(a):
    """Return the approximation sum_i a[i] * N_i sampled on the plotting grid.

    ``a`` holds one coefficient (DOF) per basis function, i.e. per row of ``N``.
    """
    return np.matmul(a, N)

# Starting values of the DOFs, one per node
a0 = np.array([0.0, 0.4, -0.2, 0.6, 0.2])

# --- Build figure ---
# Traces are addressed by position later on: update_plot() writes to
# fig.data[1] (approximation) and fig.data[2] (nodal polyline), so keep the
# first three add_scatter calls in this order.
fig = go.FigureWidget(layout=dict(
    title=r"Piecewise linear approximation uₕ(x) = ∑ᵢ Nᵢaᵢ",
    xaxis_title="x",
    yaxis_title="u(x)",
    template="plotly_white",
    height=520
))

# Trace 0: target function
fig.add_scatter(x=x, y=f_target(x), name="example function f(x)=sin(πx)·x",
                mode="lines", line=dict(width=2, color="#444", dash="dash"))

# Trace 1: FEM approximation on the dense grid
u0 = assemble_u(a0)
fig.add_scatter(x=x, y=u0, name="u_h(x)", mode="lines",
                line=dict(width=3, color="#1f77b4"))

# Trace 2: nodal polyline. Each hat function is 1 at its own node and 0 at
# all other nodes, so u_h(nodes[i]) is exactly a0[i]. Plot the DOFs directly
# instead of resampling the dense curve: the interior nodes are not points of
# the dense grid, so interpolating across the kink gives a slightly wrong value.
fig.add_scatter(x=nodes, y=a0,
                mode="markers+lines", name="nodal interp",
                line=dict(color="#1f77b4", width=1, dash="dot"),
                marker=dict(size=8, color="#1f77b4", symbol="circle-open"))

# Optional basis functions (set show_basis = True to draw them)
show_basis = False
if show_basis:
    palette = ["#d62728", "#2ca02c", "#9467bd", "#8c564b", "#e377c2"]
    for i in range(nn):
        # Labels are 1-based to match N_1 ... N_5 in the formula in the text
        fig.add_scatter(x=x, y=N[i], mode="lines", name=f"N_{i + 1}(x)",
                        line=dict(width=1.5, color=palette[i % len(palette)], dash="dash"))

# Node markers
# fig.add_scatter(x=nodes, y=[0]*nn, mode="markers", name="nodes",
#                 marker=dict(size=8, color="black"))

# Lock y-limits to ±1.1 so the axis does not rescale while the sliders
# (range ±1.0) are moved
fig.update_yaxes(range=[-1.1, 1.1])

# --- Sliders ---
# One slider per DOF, in node order. The labels are 1-based (a1 ... a5) so
# that they match a_1 ... a_5 in the formula and text above the cell; with
# 0-based labels the text's "a_2" would be read as the wrong slider.
sliders = [
    widgets.FloatSlider(
        value=a0[i], min=-1.0, max=1.0, step=0.1,
        description=f"a{i + 1}", continuous_update=True, readout_format=".2f",
        layout=widgets.Layout(width='360px')
    )
    for i in range(nn)
]
grid = widgets.VBox(sliders)

# Callback
def update_plot(*args):
    """Redraw the approximation from the current slider values.

    Registered as the observer of every DOF slider; the arguments passed by
    the widget framework are ignored because all slider values are re-read.
    """
    a = np.array([s.value for s in sliders])
    uh = assemble_u(a)
    with fig.batch_update():
        # Trace 1 is the approximation curve on the dense grid
        fig.data[1].y = uh
        # Trace 2 is the nodal polyline. The value of u_h at node i is the
        # DOF a[i] itself (each hat function is 1 at its own node and 0 at
        # the others), so use the DOFs directly rather than interpolating
        # the sampled curve, which is slightly off at the kinks.
        fig.data[2].y = a

# Redraw whenever the value of any DOF slider changes
for s in sliders:
    s.observe(update_plot, names="value")

# Show the figure with the column of sliders below it
display(fig, grid)

FigureWidget({
    'data': [{'line': {'color': '#444', 'dash': 'dash', 'width': 2},
              'mode': 'lines',
              'name': 'example function f(x)=sin(πx)·x',
              'type': 'scatter',
              'uid': 'a3da9732-b8c5-4050-8796-dc9c7e546567',
              'x': {'bdata': ('AAAAAAAAAABoBgGkgGZQP2gGAaSAZm' ... 't/me/vP33/rb/M9+8/AAAAAAAA8D8='),
                    'dtype': 'f8'},
              'y': {'bdata': ('AAAAAAAAAACKQTRhBWjKPjRiltL8Z+' ... 'jovrV5P9NsPD5hvGk/B1wUMyamoTw='),
                    'dtype': 'f8'}},
             {'line': {'color': '#1f77b4', 'width': 3},
              'mode': 'lines',
              'name': 'u_h(x)',
              'type': 'scatter',
              'uid': 'ca90bc0c-c9b1-4663-9d39-b07df4e765db',
              'x': {'bdata': ('AAAAAAAAAABoBgGkgGZQP2gGAaSAZm' ... 't/me/vP33/rb/M9+8/AAAAAAAA8D8='),
                    'dtype': 'f8'},
              'y': {'bdata': ('AAAAAAAAAADaowGgmj1aP9qjAaCaPW' ... 'kEkALKP+Gc2c4Uzsk/mpmZmZmZyT8='),
    

VBox(children=(FloatSlider(value=0.0, description='a0', layout=Layout(width='360px'), max=1.0, min=-1.0), Floa…