# Start

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import panel as pn

pn.extension("ipywidgets", "katex", "mathjax")
import sys
import timeit
from datetime import datetime
from inspect import signature
from random import shuffle, uniform

import ipywidgets as ipw
from matplotlib.animation import FuncAnimation
from matplotlib.figure import Figure
from matplotlib.ticker import MultipleLocator

print("Packages successfully loaded ")

# Functions from cookbook

In [None]:
try:
    %run Initialize/init_cookbook.ipynb # valid when running the cookbook in the main file
except:
    %run init_cookbook.ipynb # valid when running the cookbook from this file.

# Commonly used functions

In [None]:
def dispersion(k, h):
    """Return the angular frequency omega [rad/s] from the linear dispersion relation.

    Evaluates omega = sqrt(g * k * tanh(k * h)) with g = 9.81 m/s^2.

    Parameters
    ----------
    k : float or array
        Wave number (2 pi / L).
    h : float or array
        Water depth [m].
    """
    omega_squared = 9.81 * k * np.tanh(k * h)
    return omega_squared ** 0.5

In [None]:
def wave_length(T, h):
    """Return the wavelength L [m] for wave period T [s] in water depth h [m].

    Uses an explicit (non-iterative) approximation of the linear dispersion
    relation, based on waveNumber_Fenton(T, d) from Jaime in the computer lab.
    All operations are NumPy element-wise operations, so T and/or h may also
    be arrays.
    """
    gravity = 9.81
    depth = h

    omega = 2 * np.pi / T
    k0 = omega * omega / gravity  # deep-water wave number
    alpha = k0 * depth
    beta = alpha * (np.tanh(alpha)) ** -0.5
    sech_sq = np.cosh(beta) ** -2  # sech^2(beta), used twice below

    kh = (alpha + beta**2 * sech_sq) / (np.tanh(beta) + beta * sech_sq)
    k = kh / depth
    return 2 * np.pi / k


# wave_length(T= 5*60, h=4000)

In [None]:
def group_stats(k1, k2, w1, w2):
    """Return the length, period and velocity of the group formed by two waves.

    Parameters
    ----------
    k1, k2 : wave numbers of the two components.
    w1, w2 : angular frequencies of the two components.

    Returns
    -------
    (L, T, cg) : tuple
        Group length 2 pi / |dk|, group period 2 pi / |dw| and group
        velocity |dw| / |dk|.
    """
    delta_k, delta_w = np.abs(k2 - k1), np.abs(w2 - w1)
    return 2 * np.pi / delta_k, 2 * np.pi / delta_w, delta_w / delta_k

# Questions

### Q6

In [None]:
from matplotlib.figure import Figure

In [None]:
def check_answers_W2_Q6(
    id,
    answers,
    unit,
    FB_G,
    FB_W,
    num_widgets,
    feedback_widgets,
    attempt,
    final_score_widget,
    plots,
    panes,
    figures,
    par,
):
    """Build the "Check" button callback for one numeric sub-question of W2_Q6.

    Parameters
    ----------
    id : int
        1-based sub-question number. 1 redraws the wavelength graph, 3 and 4
        redraw the celerity graph, and 4 additionally redraws the n graph.
    answers : list of float
        Correct values, aligned with ``num_widgets``.
    unit : str
        Unit appended to the correct answer in the feedback text.
    FB_G, FB_W : str
        Custom feedback for a correct / wrong answer; an empty string selects
        the default message.
    num_widgets, feedback_widgets : list
        Input widgets (read via ``.value``) and feedback widgets (written via
        ``.value``), one per answer.
    attempt : widget
        Counter whose ``value`` is incremented on every click. From the 3rd
        attempt on, a wrong answer reveals the correct value.
    final_score_widget : widget
        Receives the score as ``"<correct>/<total>"``.
    plots, panes, figures : list
        Matplotlib axes, Panel panes and figures, indexed by graph number.
    par : dict
        Shared question data with the keys ``omega_serie``, ``L_serie``,
        ``c_serie``, ``n_serie``, ``cg_serie`` (read now) and
        ``omega_points``, ``titles``, ``attempts``, ``all_answers`` (read
        when the button is clicked).

    Returns
    -------
    callable
        ``button_callback(event)``, to be passed to ``Button.on_click``.
    """
    omega_serie = par["omega_serie"]
    L_serie = par["L_serie"]
    c_serie = par["c_serie"]
    n_serie = par["n_serie"]
    cg_serie = par["cg_serie"]

    # one colour per point of interest (points 1-4)
    colors = ["red", "orange", "blue", "green"]
    # raw string: "\o" is not a valid Python escape sequence
    omega_label = r"$\omega $ [rad/s]"
    legend_kwargs = dict(bbox_to_anchor=(1.01, 0.5), loc="center left")

    def draw_wavelength(responses):
        """Graph 0: L(omega) curve plus the answers (given in km, plotted in m)."""
        ax = plots[0]
        ax.clear()
        ax.plot(omega_serie, L_serie, label="L")
        for i, (x, y, color) in enumerate(
            zip(par["omega_points"], responses, colors), start=1
        ):
            ax.scatter(x, y * 1000, s=25, c=color, label=f"$L_{i}$")
        ax.set_title(par["titles"][0])
        ax.legend(**legend_kwargs)
        ax.set_xlabel(omega_label)
        ax.set_ylabel("L [m]")
        panes[0].object = figures[0]

    def draw_celerity():
        """Graph 2: c and cg curves, each shown once its question was checked."""
        fig_id = 2
        # par["attempts"] / par["all_answers"] are built in sub-question order,
        # so index 2 is the celerity question and index 3 the cg question.
        attempts_c = par["attempts"][2].value
        attempts_cg = par["attempts"][3].value
        responses_c = return_answers_widgets(par["all_answers"][2])
        responses_cg = return_answers_widgets(par["all_answers"][3])

        ax = plots[fig_id]
        ax.clear()
        ax.set_ylabel("celerity [m/s]")
        ax.set_xlabel(omega_label)
        ax.set_title(par["titles"][fig_id])

        if attempts_c > 0:
            ax.plot(omega_serie, c_serie, label="c", color="#006AB5")
            for i, (x, y, color) in enumerate(
                zip(par["omega_points"], responses_c, colors), start=1
            ):
                ax.scatter(x, y, s=25, c=color, label=f"c$_{i}$")

        if attempts_cg > 0:
            ax.plot(omega_serie, cg_serie, label="$c_g$", color="#003f6c")
            for x, y, color in zip(par["omega_points"], responses_cg, colors):
                ax.scatter(x, y, s=25, c=color, marker="s")
            # empty scatter: adds a single legend entry for the square markers
            ax.scatter([], [], c="#003f6c", s=25, label="$c_{g.n}$", marker="s")

        ax.legend(**legend_kwargs)
        panes[fig_id].object = figures[fig_id]

    def draw_n():
        """Graph 3: the ratio n(omega)."""
        fig_id = 3
        ax = plots[fig_id]
        ax.clear()
        ax.plot(omega_serie, n_serie, label="n [-]")
        ax.set_title(par["titles"][fig_id])
        ax.legend(**legend_kwargs)
        ax.set_xlabel(omega_label)
        panes[fig_id].object = figures[fig_id]

    def button_callback(b):
        attempt.value += 1

        # store responses in a list for plotting
        responses = []
        score = 0
        for answer, num_widget, feedback_widget in zip(
            answers, num_widgets, feedback_widgets
        ):
            response = num_widget.value
            responses.append(response)

            # limit_answer (defined elsewhere) gives the accepted absolute deviation
            if np.abs(answer - response) < limit_answer(answer):
                score += 1
                feedback_widget.value = FB_G if FB_G else "Well done, this is correct!"
            elif attempt.value < 3:
                # first attempts: hint only, do not reveal the answer yet
                feedback_widget.value = (
                    FB_W if FB_W else "Oops, there seems to be a mistake"
                )
            else:
                feedback_widget.value = (
                    "The correct answer is " + str(answer) + str(unit) + "."
                )

        final_score_widget.value = str(score) + "/" + str(len(answers))

        # Update the graphs
        if id == 1:
            draw_wavelength(responses)
        if id in (3, 4):
            draw_celerity()
        if id == 4:
            draw_n()

    # the returned callable is what on_click needs (returning None raises TypeError)
    return button_callback

In [None]:
def return_answers_widgets(widgets):
    """Return a list with the current ``.value`` of each widget, in order."""
    return [widget.value for widget in widgets]

In [None]:
def W2_Q6():
    """Build the Panel quiz on tsunami wave characteristics (week 2, Q6).

    The quiz starts with a question on the period range of intermediate-depth
    tsunami waves. It is followed by four numeric sub-questions (wavelength,
    period, celerity, group velocity) at four points:
      1. T = 5 min,
      2. the deep/intermediate boundary (h/L = 0.5),
      3. the shallow/intermediate boundary (h/L = 0.05),
      4. T = 60 min.
    Each sub-question gets a "Check" button wired to ``check_answers_W2_Q6``
    and a graph pane next to it.

    Returns
    -------
    panel.Column
        The complete quiz layout.
    """
    h = 4000  # water depth [m]

    T1 = 5 * 60  # 5 min
    T4 = 60 * 60  # 60 min

    text_general = (
        "We are going to analyze the characteristics of tsunami waves at a depth of "
        + str(h)
        + " m. It will be analyzed for waves with a period of "
        + str(T1 / 60)
        + " minutes (point 1), at the deep/intermediate water boundary (point 2), the shallow/intermediate water boundary (point 3), and for waves with a period of "
        + str(T4 / 60)
        + " minutes (point 4). Use the dispersion relationship when assessing these 4 points."
    )

    # Q0: period range of intermediate depth (0.05 < h/L < 0.5)
    Q0 = [
        "Between which periods are tsunami waves at intermediate depth, according to the dispersion relationship?"
    ]
    L_lower = h / 0.5  # deep/intermediate boundary -> shortest period
    T_lower = 2 * np.pi / dispersion(k=2 * np.pi / L_lower, h=h)
    L_up = h / 0.05  # shallow/intermediate boundary -> longest period
    T_up = 2 * np.pi / dispersion(k=2 * np.pi / L_up, h=h)
    answer0 = [[round(T_up, 1), round(T_lower, 1)]]
    subquestions_0 = [["Longest period", "Shortest period"]]
    Q0_unit = "s"

    Q0_widget = nummerical_subquestions(Q0, answer0, subquestions_0, Q0_unit)

    # Q1: wavelength [m] at the four points (answers in km)
    Q1 = "What is the wavelength of the tsunami wave at those 4 points?"
    Q1_unit = " km"
    L1 = wave_length(T=T1, h=h)
    L2 = h / 0.5
    L3 = h / 0.05
    L4 = wave_length(T=T4, h=h)
    Q1_answers = [round(L1 / 1000, 1), L2 / 1000, L3 / 1000, round(L4 / 1000, 1)]

    # angular frequencies of the 4 points of interest
    w1 = 2 * np.pi / T1
    w2 = dispersion(k=2 * np.pi / L2, h=h)
    w3 = dispersion(k=2 * np.pi / L3, h=h)
    w4 = 2 * np.pi / T4
    omega_points = np.array([w1, w2, w3, w4])

    # Q2: wave period [s]
    Q2 = "What is the wave period of the tsunami wave at those 4 points?"
    Q2_unit = "s"
    T2 = 2 * np.pi / w2
    T3 = 2 * np.pi / w3
    Q2_answers = [T1, round(T2, 1), round(T3, 1), T4]

    # Q3: celerity c = L / T [m/s]
    Q3 = "What is the wave celerity (c) of the tsunami wave at those 4 points?"
    Q3_unit = " m/s"
    c1 = L1 / T1
    c2 = L2 / T2
    c3 = L3 / T3
    c4 = L4 / T4
    Q3_answers = [round(c1, 1), round(c2, 1), round(c3, 1), round(c4, 1)]

    # Q4: group velocity cg = n c [m/s]
    Q4 = "What is the wave group velocity (cg) of the tsunami at those 4 points?"
    Q4_unit = " m/s"

    def group_velocity(L, c):
        """Return cg = n * c with n = 1/2 + kh / sinh(2kh) and k = 2 pi / L."""
        k = 2 * np.pi / L
        n = 0.5 + k * h / np.sinh(2 * k * h)
        return c * n

    Q4_answers = [
        round(group_velocity(L, c), 1)
        for L, c in zip((L1, L2, L3, L4), (c1, c2, c3, c4))
    ]

    # store the questions in a list
    Questions = [Q1, Q2, Q3, Q4]
    Unit_question = [Q1_unit, Q2_unit, Q3_unit, Q4_unit]
    answer_question = [Q1_answers, Q2_answers, Q3_answers, Q4_answers]

    # no feedback provided
    FB_G, FB_W = "", ""

    # continuous curves for the graphs
    omega_serie = np.linspace(w4, np.max([w1, w2, w3, w4]) * 1.3, 100)
    T_serie = [2 * np.pi / w for w in omega_serie]
    L_serie = [wave_length(T, h) for T in T_serie]
    k_serie = [2 * np.pi / L for L in L_serie]
    c_serie = [L / T for L, T in zip(L_serie, T_serie)]
    n_serie = [0.5 + k * h / np.sinh(2 * k * h) for k in k_serie]
    cg_serie = [c * n for c, n in zip(c_serie, n_serie)]

    # one (initially empty) figure per sub-question
    titles = ["Wavelength", "None", "Celerity", "n"]
    figures, plots, panes = [], [], []
    for _ in titles:
        fig = Figure((5, 2.5))
        ax = fig.subplots()
        ax.set_axis_off()
        pane = pn.pane.Matplotlib(fig, dpi=96)
        fig.subplots_adjust(hspace=0)
        fig.subplots_adjust(bottom=0.25)  # extra space for the x-axis label
        fig.subplots_adjust(left=0.2, right=0.8)  # space for the labels/legend
        figures.append(fig)
        plots.append(ax)
        panes.append(pane)

    all_question_widgets = []
    all_answers = []  # per sub-question: its list of numeric input widgets
    attempts = []  # per sub-question: widget counting "Check" presses

    # Data shared with the button callbacks. `attempts` and `all_answers` are
    # filled by the loop below; the callbacks only read them when clicked, so
    # by then they hold the entries of all four sub-questions.
    par = {
        "omega_serie": omega_serie,
        "T_serie": T_serie,
        "L_serie": L_serie,
        "k_serie": k_serie,
        "c_serie": c_serie,
        "n_serie": n_serie,
        "cg_serie": cg_serie,
        "omega_points": omega_points,
        "titles": titles,
        "attempts": attempts,
        "all_answers": all_answers,
    }

    for i, (question, unit, answers) in enumerate(
        zip(Questions, Unit_question, answer_question)
    ):
        question_id = i + 1  # 1-based sub-question number for the callback
        question_widget = pn.widgets.StaticText(value=question, width=750)

        rows_answer = []
        num_widgets = []
        feedback_widgets = []
        for number in range(len(answers)):
            number_widget = pn.widgets.StaticText(value=f"{number + 1})", width=10)
            unit_widget = pn.widgets.StaticText(value=unit, width=50)
            num_widget = pn.widgets.FloatInput(value=0, step=0.01, width=100)
            feedback_widget = pn.widgets.StaticText(value="", name="", width=250)

            num_widgets.append(num_widget)
            feedback_widgets.append(feedback_widget)
            rows_answer.append(
                pn.Row(number_widget, num_widget, unit_widget, feedback_widget)
            )
        all_answers.append(num_widgets)

        # Add a submit button with widget that returns the final score
        submit_button = pn.widgets.Button(name="Check")
        submit_text_widget = pn.widgets.StaticText(value="Final score:", width=70)
        final_score_widget = pn.widgets.TextInput(value="", name="", width=130)
        row_submit = pn.Row(submit_button, submit_text_widget, final_score_widget)

        attempt = pn.widgets.FloatInput(value=0)
        attempts.append(attempt)

        submit_button.on_click(
            check_answers_W2_Q6(
                question_id,
                answers,
                unit,
                FB_G,
                FB_W,
                num_widgets,
                feedback_widgets,
                attempt,
                final_score_widget,
                plots,
                panes,
                figures,
                par,
            )
        )

        # Structure the widgets: answers on the left, graph on the right
        vbox_answer = pn.Column(*rows_answer, row_submit)
        hbox_plot = pn.Row(vbox_answer, panes[i])
        all_question_widgets.append(pn.Column(question_widget, hbox_plot))

    text_widget = pn.widgets.StaticText(value=text_general)
    quiz_widget = pn.Column(text_widget, Q0_widget, *all_question_widgets)

    return quiz_widget


# W2_Q6()

### Q7

In [None]:
from ipywidgets import HBox, VBox, interact

In [None]:
def dispersion(k, h):
    """Return the angular frequency omega [rad/s] from the linear dispersion relation.

    Evaluates omega = sqrt(g * k * tanh(k * h)) with g = 9.81 m/s^2.
    (Redefines the identical function from the "Commonly used functions" cell.)

    Parameters
    ----------
    k : float or array
        Wave number (2 pi / L).
    h : float or array
        Water depth [m].
    """
    omega_squared = 9.81 * k * np.tanh(k * h)
    return omega_squared ** 0.5


def W2_Q7_graph(T1, T2, T3, slope_in, d0):
    %matplotlib inline
    # bed profile

    slope = 1.0 / slope_in  # bed slope [-]
    d0  # offshore water depth [m]
    x_max = round((d0 + 2) / slope)
    x = np.arange(0, x_max + 1, 1)  # cross-shore coordinate [m]
    # x = np.linspace(0, x_max + x_max/100, 100)# <-- should be used to reduce computational demands, influences xticks
    zbed = -(d0 - slope * x)  # bed elevation [m]
    h = -zbed  # still water depth [m]
    h[h < 0] = 0  # no negative depths

    # w is zero when h is 0, causing a divide by zero.
    # shorten the lists if a water depth of 0 is reached.
    x0_id = np.argwhere(h == 0)[0][0]  # first location where water depth = 0
    h_water = h[0:x0_id]
    x_water = x[0:x0_id]

    # wavelength through profile
    L1 = [wave_length(T1, h) for h in h_water]
    L2 = [wave_length(T2, h) for h in h_water]
    L3 = [wave_length(T3, h) for h in h_water]

    # velocity profile
    def calc_c(T, h):
        L = wave_length(T, h)
        return 9.81 * T / (2 * np.pi) * np.tanh(2 * np.pi * h / L)

    c1 = [calc_c(T1, h) for h in h_water]
    c2 = [calc_c(T2, h) for h in h_water]
    c3 = [calc_c(T3, h) for h in h_water]

    fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(9, 6), sharex=True, sharey=False)
    fig.subplots_adjust(hspace=0)
    fig.subplots_adjust(wspace=0.1)

    # bathymetry
    y_max = np.max(zbed) * 1.1
    y_min = np.min(zbed) * 1.1
    axs[0].plot(x, zbed, label="Bed (1:" + str(round(slope_in, 2)) + ")", color="k")
    axs[0].plot([0, x[x0_id]], [0, 0], color="gray", label="Still water surface")
    axs[0].plot(
        [x[x0_id], x[x0_id]], [y_min, y_max], color="grey", linestyle="--", label="x=0"
    )
    axs[0].set_ylim(y_min, y_max)
    axs[0].set_ylabel("y [m]")
    axs[0].legend(loc="lower right")

    # wavelength
    y_max = np.max(([L1], [L2], [L3])) * 1.1
    axs[1].plot(x_water, L1, label="wave 1")
    axs[1].plot(x_water, L2, label="wave 2")
    axs[1].plot(x_water, L3, label="wave 3")
    axs[1].plot(
        [x[x0_id], x[x0_id]], [0, y_max], color="grey", linestyle="--", label="x=0"
    )
    axs[1].set_ylim(0, y_max)
    axs[1].set_ylabel("Wavelength (L) [m]")
    axs[1].legend(loc="upper right")

    # wave celerity
    y_max = np.max(([c1], [c2], [c3])) * 1.1
    axs[2].plot(x_water, c1, label="wave 1")
    axs[2].plot(x_water, c2, label="wave 2")
    axs[2].plot(x_water, c3, label="wave 3")
    axs[2].plot(
        [x[x0_id], x[x0_id]], [0, y_max], color="grey", linestyle="--", label="x=0"
    )
    axs[2].set_xlim(0, np.max(x))
    axs[2].set_ylim(0, y_max)
    axs[2].set_ylabel("Celerity (c) [m/s]")
    axs[2].legend(loc="upper right")
    axs[2].set_xlabel("cross-shore location (x) [m]")
    # axs[2].set_xticks('cross-shore location (x) [m]')

    # remove the lines related to the x-ticks
    axs[0].xaxis.set_visible(False)
    axs[1].xaxis.set_visible(False)

    # set title
    axs[0].set_title("Wave characteristics in cross-shore direction")

    # get values of xticks, change them to get x=0 at water boundary and positive direction offshore
    xticks = axs[2].get_xticks()
    new_ticks = np.ones(len(xticks)) * x_water[-1] - xticks + 1
    axs[2].set_xticks(xticks, new_ticks)

    # Plot dots at the transition with intermediate water depth
    # for L,c in zip((L1,L2,L3), (c1,c2,c3)):
    #    h_L = h_water/L
    #    if h_L[0] > 0.5:
    #       id_deep = np.argwhere(h_L <= 0.5)[0][0]  # first location where water depth = 0
    #       axs[1].plot(x_water[id_deep], L[id_deep], 'ro', label="wave 3")
    #
    #    if h_L[0] > 0.05:
    #        id_shl = np.argwhere(h_L < 0.05)[0][0]  # first location where water depth = 0
    #        axs[1].plot(x_water[id_shl], L[id_shl], 'ro', label="wave 3")


def W2_Q7_set_graph(FV, FV2):
    """Return a callback that loads the slider preset belonging to a question.

    Parameters
    ----------
    FV : dict
        Slider widgets under the keys "T1", "T2", "T3", "slope" and "d0";
        their ``.value`` is overwritten.
    FV2 : dict
        Question data; ``FV2["id"]`` (1, 2 or 3) selects the preset when the
        button is clicked. Any other id leaves the sliders unchanged.
    """
    slider_keys = ("T1", "T2", "T3", "slope", "d0")
    # presets per question id, in the order of slider_keys
    presets = {
        1: (10, 15, 20, 30, 250),
        2: (4, 6, 8, 30, 50),
        3: (4, 6, 8, 30, 20),
    }

    def button_callback(b):
        preset = presets.get(FV2["id"])
        if preset is None:
            return
        for key, value in zip(slider_keys, preset):
            FV[key].value = value

    return button_callback


def W2_Q7_check_answers(FV, FV2):
    """Return a callback that grades the checkbox selection of one W2_Q7 question.

    Parameters
    ----------
    FV : dict
        Unused; accepted so all W2_Q7 callback factories share one signature.
    FV2 : dict
        Question data with "checkbutton_group" (selection widget), "answer"
        (list of correct options), "feedback_widget", "FG" (feedback when
        correct) and "FW" (feedback when wrong).
    """

    def button_callback(b):
        response = FV2["checkbutton_group"].value or []
        # The options are distinct labels, so compare as sets: the order in
        # which the options appear in the selection must not matter.
        if set(response) == set(FV2["answer"]):
            FV2["feedback_widget"].value = FV2["FG"]
        else:
            FV2["feedback_widget"].value = FV2["FW"]

    return button_callback


def W2_Q7_questions(FV):
    """Build the three checkbox questions that accompany the W2_Q7 graph.

    Parameters
    ----------
    FV : dict
        Slider widgets of the graph (keys "T1", "T2", "T3", "slope", "d0"),
        handed to the "Reset graph" callback of each question.

    Returns
    -------
    list of panel.Column
        One layout per question: text, reset/selection/check row, feedback.
    """
    all_answers = ["wave 1", "wave 2", "wave 3"]

    # (question, correct selection, feedback if correct, feedback if wrong)
    question_data = [
        (
            "Which waves are in deep water at x = 6000?",
            [all_answers[0], all_answers[1]],
            "Indeed, these two waves are in deep water. Wave 3, with a period of 20 seconds, is already influenced by the water depth.",
            "Try another reasoning, what are the spatial changes for waves in deep water?",
        ),
        (
            "Which of the waves are in intermediate water at x = 500?",
            [all_answers[1], all_answers[2]],
            "Good job, these two waves are in intermediate water depth.",
            "Almost, which waves are affected by the changing water depth?",
        ),
        (
            "Which of the waves are in shallow water at x = 200?",
            [],
            "Indeed, all of the waves are in intermediate water, the waves are in shallow water around x = 50 m, where the celerity is only related to the water depth.",
            "Which parameters are relevant for waves in shallow water?",
        ),
    ]

    all_question_widgets = []
    for question_id, (question, answer, FG, FW) in enumerate(question_data, start=1):
        question_widget = pn.widgets.StaticText(value=question)
        set_button = pn.widgets.Button(name="Reset graph")
        checkbutton_group = pn.widgets.CheckButtonGroup(
            name="Check Button Group", value=[], options=all_answers
        )
        submit_button = pn.widgets.Button(name="Check")
        feedback_widget = pn.widgets.StaticText(value="", name="")

        # Per-question state read by the callbacks when the buttons are clicked
        # ("id" selects the slider preset in W2_Q7_set_graph).
        FV2 = {
            "id": question_id,
            "all_answers": all_answers,
            "question": question,
            "answer": answer,
            "FG": FG,
            "FW": FW,
            "checkbutton_group": checkbutton_group,
            "feedback_widget": feedback_widget,
        }
        set_button.on_click(W2_Q7_set_graph(FV, FV2))
        submit_button.on_click(W2_Q7_check_answers(FV, FV2))

        question_row = pn.Row(set_button, checkbutton_group, submit_button)
        all_question_widgets.append(
            pn.Column(question_widget, question_row, feedback_widget)
        )

    return all_question_widgets


def W2_Q7():
    """Display the interactive cross-shore wave graph of W2_Q7 with its questions.

    Sliders set the three wave periods, the bed slope and the offshore depth,
    and ``W2_Q7_graph`` is re-run through ``ipw.interactive_output`` when they
    change. Uses ``display``, so it must run inside a Jupyter/IPython session.
    """
    # Interactive widgets for the graph. These must be ipywidgets because they
    # feed ipw.interactive_output; Panel widgets do not work here.
    T1 = ipw.FloatSlider(value=4, min=1, max=20, step=0.01, description="T1 [s]")
    T2 = ipw.FloatSlider(value=7, min=1, max=20, step=0.01, description="T2 [s]")
    # initial value equals max: a value above max would be clamped to it anyway
    T3 = ipw.FloatSlider(value=20, min=1, max=20, step=0.01, description="T3 [s]")

    slope = ipw.FloatSlider(
        value=75, min=50, max=200, step=0.1, description="slope 1:..."
    )
    d0 = ipw.FloatSlider(value=50, min=0.1, max=500, step=0.1, description="depth [m]")

    # Setup widget layout (User Interface) for the graph input
    vbox1 = ipw.VBox(
        [
            ipw.Label("Wave components", layout=ipw.Layout(align_self="center")),
            T1,
            T2,
            T3,
        ]
    )
    vbox2 = ipw.VBox(
        [ipw.Label("General", layout=ipw.Layout(align_self="center")), slope, d0]
    )
    UI = ipw.HBox([vbox1, vbox2])

    # Use the interactive function to update the plot
    graph = ipw.interactive_output(
        W2_Q7_graph, {"T1": T1, "T2": T2, "T3": T3, "slope_in": slope, "d0": d0}
    )

    # sliders that the "Reset graph" buttons of the questions overwrite
    FV = {"T1": T1, "T2": T2, "T3": T3, "slope": slope, "d0": d0}
    questions = W2_Q7_questions(FV)

    intro = 'Can you find the waves that are in deep, intermediate, or shallow water? You should first set the graph by pressing the "Reset graph" button below each question. After that, when the kernel is idle, you can select the asked waves and then press the "check" button.  Good luck!'
    intro_widget = pn.widgets.StaticText(value=intro)

    display(UI, graph, intro_widget, *questions)


# W2_Q7()

### Q8

No longer used; the code has been moved to harmonische_componenten_reserve.ipynb.

## Wave groups

### Q9 Coding

In [None]:
def W2_Q9():
    """Build the text widget for the Q9 coding exercise and store its parameters.

    Side effect: defines the global ``W2_Q9_param`` tuple
    ``(a1, a2, T1, T2, L1, L2, x_max, xp, t_max, tp, L_group, T_group)``,
    which the Show_Q9* functions unpack, so this must run before them.

    Returns
    -------
    panel.widgets.StaticText
        The question text.
    """
    question_start = "Can you use formula 3 to complete the functions below to calculate the water elevation, the varying amplitude, and the wave envelope? "
    # Unrandomized to match presentation/slides
    T1 = 7
    T2 = 6.2
    a1 = a2 = 1.5
    h = 20

    # wavelengths, wave numbers and angular frequencies of both components
    L1 = wave_length(T1, h)
    L2 = wave_length(T2, h)
    k1 = 2 * np.pi / L1
    k2 = 2 * np.pi / L2
    w1 = 2 * np.pi / T1
    w2 = 2 * np.pi / T2

    L_group, T_group, _ = group_stats(k1, k2, w1, w2)

    xp = 0  # spatial point of interest [m]
    tp = 0  # time point of interest [s]

    # 3 group lengths, rounded to the next multiple of 25 m (strictly above)
    x_max = (3 * L_group) // 25 * 25 + 25
    # 3 group periods, rounded to the next multiple of 10 s (strictly above)
    t_max = (3 * T_group) // 10 * 10 + 10

    # The question that is asked
    question_data = (
        "The wave periods T1 and T2 are "
        + str(T1)
        + " and "
        + str(T2)
        + " seconds, respectively, both with an amplitude of "
        + str(a1)
        + " meter. The water depth is "
        + str(h)
        + " meter. "
    )
    question_end = " The answer you coded will be plotted when you run the cell and the function is valid."
    question_boundaries = (
        "One function is used to plot time series at x = "
        + str(xp)
        + " m from t = 0 to t = "
        + str(t_max)
        + " seconds, and one to plot the wave conditions from x=0 to x="
        + str(x_max)
        + " m at t="
        + str(tp)
        + " seconds. Define the function such that it will calculate the water elevation at a specific time or space."
    )
    question = question_start + question_data + question_boundaries + question_end

    question_widget = pn.widgets.StaticText(name="", value=question)

    # global with a unique name (week 2, question 9), read by the Show_Q9* functions
    global W2_Q9_param
    W2_Q9_param = a1, a2, T1, T2, L1, L2, x_max, xp, t_max, tp, L_group, T_group

    return question_widget


# W2_Q9()

In [None]:
def Show_Q9A1():
    """Check the student's ``eta_t`` function against a reference implementation.

    Unpacks the global ``W2_Q9_param`` (set by ``W2_Q9()``, which therefore
    has to run first) and hands the reference function, the time axis and a
    new Figure to ``check_code_function``. That helper is not defined in this
    file; presumably it comes from init_cookbook.ipynb.
    """
    # define the name of the function that the students will make
    function_name = "eta_t"

    # define the name of the parameter plotted on the horizontal axis
    parameter_x_axis = "t"

    # set the horizontal axis of the graph
    a1, a2, T1, T2, L1, L2, x_max, xp, t_max, tp, L_group, T_group = W2_Q9_param
    # time step: 1/30 of the shorter of the two wave periods
    Delta_T = np.min([T1, T2]) / 30
    horizontal_axis = np.arange(0, t_max + Delta_T, Delta_T)

    # define the correct function and its values along the x-axis.
    # NOTE(review): check_code_function presumably matches these parameter
    # names (the file imports inspect.signature) - do not rename them without
    # checking that helper.
    def correct_function(a1, a2, T1, T2, L1, L2, t, xp):
        # sum of two sine waves, evaluated at the fixed position xp
        eta1 = a1 * np.sin(2 * np.pi / T1 * t - 2 * np.pi / L1 * xp)
        eta2 = a2 * np.sin(2 * np.pi / T2 * t - 2 * np.pi / L2 * xp)
        eta = eta1 + eta2
        return eta

    # set the acceptable computational error (ratio)
    f_margin = 0.001  # 0.001 = 0.1%

    fig = Figure((5, 2.5))
    check_code_function(
        fig,
        horizontal_axis,
        function_name,
        correct_function,
        parameter_x_axis,
        f_margin,
        xlabel="time [s]",
        ylabel="y coordinate [m]",
    )

In [None]:
def Show_Q9B1():
    """Check the student's ``eta_x`` function against a reference implementation.

    Unpacks the global ``W2_Q9_param`` (set by ``W2_Q9()``, which therefore
    has to run first) and hands the reference function, the x axis and a new
    Figure to ``check_code_function``. That helper is not defined in this
    file; presumably it comes from init_cookbook.ipynb.
    """
    # define the name of the function that the students will make
    function_name = "eta_x"

    # define the name of the parameter plotted on the horizontal axis
    parameter_x_axis = "x"

    # set the horizontal axis of the graph
    a1, a2, T1, T2, L1, L2, x_max, xp, t_max, tp, L_group, T_group = W2_Q9_param
    # step of 1/60 group length; the stop value runs 1/30 group length past x_max
    horizontal_axis = np.arange(0, x_max + L_group / 30, L_group / 60)

    # define the correct function and its values along the x-axis.
    # NOTE(review): check_code_function presumably matches these parameter
    # names (the file imports inspect.signature) - do not rename them without
    # checking that helper.
    def correct_function(a1, a2, T1, T2, L1, L2, tp, x):
        # sum of two sine waves, evaluated at the fixed time tp
        eta1 = a1 * np.sin(2 * np.pi / T1 * tp - 2 * np.pi / L1 * x)
        eta2 = a2 * np.sin(2 * np.pi / T2 * tp - 2 * np.pi / L2 * x)
        eta = eta1 + eta2
        return eta

    # set the acceptable computational error (ratio)
    f_margin = 0.001  # 0.001 = 0.1%

    fig = Figure((5, 2.5))
    check_code_function(
        fig,
        horizontal_axis,
        function_name,
        correct_function,
        parameter_x_axis,
        f_margin,
        xlabel="x coordinate [m]",
        ylabel="y coordinate [m]",
    )

In [None]:
def Show_Q9A2():
    """Check ``varying_amplitude_t``, ``envelope_t`` and ``eta_t`` on one shared graph.

    All three checks draw into the same Figure/axes/pane (``new_graph=False``
    with the same ``ax`` and ``pane``). The pane is displayed at the end.
    Requires the global ``W2_Q9_param`` set by ``W2_Q9()``.
    ``check_code_function`` is not defined in this file; presumably it comes
    from init_cookbook.ipynb.
    """

    # set the acceptable computational error (ratio)
    f_margin = 0.005  # 0.005 = 0.5%

    # define the name of the parameter plotted on the horizontal axis
    parameter_x_axis = "t"

    # set the horizontal axis of the graph
    a1, a2, T1, T2, L1, L2, x_max, xp, t_max, tp, L_group, T_group = W2_Q9_param
    Delta_T = np.min([T1, T2]) / 30
    # NOTE(review): the axis starts at t = 0.01 s instead of 0; the reason is not visible here
    horizontal_axis = np.arange(0.01, t_max + Delta_T, Delta_T)

    # one shared figure/axes/pane for all three checks below
    fig = Figure((5, 2.5))
    ax = fig.subplots()
    fig.subplots_adjust(top=0.7)
    pane = pn.pane.Matplotlib(fig, dpi=100)

    # define the name of the function that the students will make
    function_name = "varying_amplitude_t"

    # NOTE(review): parameter names presumably matter to check_code_function
    # (inspect.signature is imported) - keep them as they are.
    def correct_function(a1, a2, T1, T2, L1, L2, t, xp):
        w1 = 2 * np.pi / T1
        w2 = 2 * np.pi / T2
        k1 = 2 * np.pi / L1
        k2 = 2 * np.pi / L2
        Delta_w = np.abs(w1 - w2)
        Delta_k = k2 - k1
        # modulating amplitude (a1 + a2) cos(dw/2 t - dk/2 x)
        var_amp = (a1 + a2) * np.cos(0.5 * Delta_w * t - 0.5 * Delta_k * xp)
        return var_amp

    pane, ax = check_code_function(
        fig,
        horizontal_axis,
        function_name,
        correct_function,
        parameter_x_axis,
        f_margin,
        new_graph=False,
        ax=ax,
        pane=pane,
    )

    # should be defined below the check_code_function, the function will be improved to make it possible to have the check_code_function grouped at the bottom
    function_name2 = "envelope_t"

    # envelope = |varying amplitude|; this reference takes no L1, L2, xp
    # parameters (it only depends on t).
    def correct_function(a1, a2, T1, T2, t):
        w1 = 2 * np.pi / T1
        w2 = 2 * np.pi / T2
        Delta_w = np.abs(w1 - w2)
        car_wave = (a1 + a2) * np.cos(0.5 * Delta_w * t)
        y = np.abs(car_wave)
        return y

    pane, ax = check_code_function(
        fig,
        horizontal_axis,
        function_name2,
        correct_function,
        parameter_x_axis,
        f_margin,
        new_graph=False,
        ax=ax,
        pane=pane,
    )
    # should be defined below the check_code_function, the function will be improved to make it possible to have the check_code_function grouped at the bottom

    function_name3 = "eta_t"

    # sum of the two sine waves at the fixed position xp
    def correct_function(a1, a2, T1, T2, L1, L2, t, xp):
        eta1 = a1 * np.sin(2 * np.pi / T1 * t - 2 * np.pi / L1 * xp)
        eta2 = a2 * np.sin(2 * np.pi / T2 * t - 2 * np.pi / L2 * xp)
        eta = eta1 + eta2
        return eta

    # the last check also sets the axis labels
    pane, ax = check_code_function(
        fig,
        horizontal_axis,
        function_name3,
        correct_function,
        parameter_x_axis,
        f_margin,
        new_graph=False,
        ax=ax,
        pane=pane,
        xlabel="time [s]",
        ylabel="y coordinate [m]",
    )
    display(pane)

In [None]:
def Show_Q9B2():
    """Check ``varying_amplitude_x``, ``envelope_x`` and ``eta_x`` on one shared graph.

    All three checks draw into the same Figure/axes/pane (``new_graph=False``
    with the same ``ax`` and ``pane``). The pane is displayed at the end.
    Requires the global ``W2_Q9_param`` set by ``W2_Q9()``.
    ``check_code_function`` is not defined in this file; presumably it comes
    from init_cookbook.ipynb.
    """

    # set the acceptable computational error (ratio)
    f_margin = 0.005  # 0.005 = 0.5%

    # define the name of the parameter plotted on the horizontal axis
    parameter_x_axis = "x"

    # set the horizontal axis of the graph
    a1, a2, T1, T2, L1, L2, x_max, xp, t_max, tp, L_group, T_group = W2_Q9_param
    # step of 1/60 group length; the stop value runs 1/30 group length past x_max
    horizontal_axis = np.arange(0, x_max + L_group / 30, L_group / 60)

    # one shared figure/axes/pane for all three checks below
    fig = Figure((5, 2.5))
    ax = fig.subplots()
    fig.subplots_adjust(top=0.7)
    pane = pn.pane.Matplotlib(fig, dpi=100)

    # define the name of the function that the students will make
    function_name = "varying_amplitude_x"

    # define the correct function and its values along the x-axis.
    # NOTE(review): parameter names presumably matter to check_code_function
    # (inspect.signature is imported) - keep them as they are.
    def correct_function(a1, a2, T1, T2, L1, L2, tp, x):
        w1 = 2 * np.pi / T1
        w2 = 2 * np.pi / T2
        k1 = 2 * np.pi / L1
        k2 = 2 * np.pi / L2
        Delta_w = np.abs(w1 - w2)
        Delta_k = k2 - k1
        # modulating amplitude (a1 + a2) cos(dw/2 t - dk/2 x) at the fixed time tp
        var_amp = (a1 + a2) * np.cos(0.5 * Delta_w * tp - 0.5 * Delta_k * x)
        return var_amp

    pane, ax = check_code_function(
        fig,
        horizontal_axis,
        function_name,
        correct_function,
        parameter_x_axis,
        f_margin,
        new_graph=False,
        ax=ax,
        pane=pane,
    )

    # should be defined below the check_code_function, the function will be improved to make it possible to have the check_code_function grouped at the bottom
    function_name2 = "envelope_x"

    # envelope = |varying amplitude|
    def correct_function(a1, a2, T1, T2, L1, L2, tp, x):
        w1 = 2 * np.pi / T1
        w2 = 2 * np.pi / T2
        k1 = 2 * np.pi / L1
        k2 = 2 * np.pi / L2
        Delta_w = np.abs(w1 - w2)
        Delta_k = k2 - k1
        var_amp = (a1 + a2) * np.cos(0.5 * Delta_w * tp - 0.5 * Delta_k * x)
        y = np.abs(var_amp)
        return y

    pane, ax = check_code_function(
        fig,
        horizontal_axis,
        function_name2,
        correct_function,
        parameter_x_axis,
        f_margin,
        new_graph=False,
        ax=ax,
        pane=pane,
    )
    # should be defined below the check_code_function, the function will be improved to make it possible to have the check_code_function grouped at the bottom

    function_name3 = "eta_x"

    # sum of the two sine waves at the fixed time tp
    def correct_function(a1, a2, T1, T2, L1, L2, tp, x):
        eta1 = a1 * np.sin(2 * np.pi / T1 * tp - 2 * np.pi / L1 * x)
        eta2 = a2 * np.sin(2 * np.pi / T2 * tp - 2 * np.pi / L2 * x)
        eta = eta1 + eta2
        return eta

    # the last check also sets the axis labels
    pane, ax = check_code_function(
        fig,
        horizontal_axis,
        function_name3,
        correct_function,
        parameter_x_axis,
        f_margin,
        new_graph=False,
        ax=ax,
        pane=pane,
        xlabel="x coordinate [m]",
        ylabel="y coordinate [m]",
    )
    display(pane)

In [None]:
def Show_Q9A():
    """Display question 9A: first part A1, then part A2."""
    # Each part renders itself; they are defined elsewhere in the notebook.
    Show_Q9A1()  # part 1
    Show_Q9A2()  # part 2


# Show_Q9A()

In [None]:
def Show_Q9B():
    """Display question 9B: first part B1, then part B2."""
    # Each part renders itself; they are defined elsewhere in the notebook.
    Show_Q9B1()  # part 1
    Show_Q9B2()  # part 2


# Show_Q9B()

### Q10 Multiple Choice

In [None]:
def Q10():
    """Show the Q10 multiple-choice question on the speed of the carrier wave.

    Delegates to ``single_multiple_choice`` (defined elsewhere in the notebook)
    with the question text, the list of choices, the correct choice and two
    feedback texts.
    """
    question_1 = "How can the speed of the carrier wave be expressed?"
    choices_1 = [
        "omega_{average}/k_{average}",
        "The group velocity (cg)",
        "The average celerity of the wave components (c1+c2)/2",
        "The speed corresponding to omega = 2 * pi/(T_{average})",
    ]
    # The correct answer is passed as the choice text itself (not as an index).
    answer_1 = choices_1[0]
    # Feedback texts; presumably shown for a wrong and a correct answer respectively.
    hint_1 = "Unfortunately not, try another reasoning."
    comment_1 = "That is correct!"

    single_multiple_choice(question_1, choices_1, answer_1, hint_1, comment_1)


# Q10()

### Widget

In [None]:
def W2_wave_groups():
    """Display an interactive widget that superposes two regular waves into a wave group.

    The user sets amplitude and period of two wave components, the water depth,
    and the observation position ``xp`` and time ``tp``. The figure shows two
    columns:

    * Left: the surface elevation against time at ``x = xp``.
    * Right: the surface elevation against space at ``t = tp``.

    Each column contains wave 1, wave 2, and their sum with its Hilbert-transform
    envelope. The horizontal axes are scaled to the group period and the group
    length.
    """
    from scipy.signal import hilbert

    # define widgets
    a1 = ipw.FloatText(value=1, min=0, max=20, step=0.01, description="a [m]")
    a2 = ipw.FloatText(value=1.5, min=0, max=20, step=0.01, description="a [m]")

    T1 = ipw.FloatText(value=7, min=0.01, max=250000, step=0.01, description="T [s]")
    T2 = ipw.FloatText(value=6.2, min=0.01, max=250000, step=0.01, description="T [s]")

    # phase shifts, in units of 2*pi rad (LaTeX label idea: r"$\phi$ [2 $\pi$ rad]")
    phi_1 = ipw.FloatText(
        value=0, min=-1, max=1, step=0.01, description="phi [2 pi rad]"
    )
    phi_2 = ipw.FloatText(
        value=0, min=-1, max=1, step=0.01, description="phi [2 pi rad]"
    )

    n_waves = ipw.FloatText(value=3, min=0.1, max=10, step=0.1, description="n_{waves}")

    depth = ipw.FloatText(value=20, min=1, max=250, step=0.01, description="h [m]")
    xp = ipw.FloatText(value=0, min=0, step=0.1, description="x [m]")
    tp = ipw.FloatText(value=0, min=1, step=0.1, description="t [s]")

    # Derived, read-only quantities. c1, c2, h_L1, h_L2 and Lgroup are refreshed
    # by calc_eta on every redraw so the displayed values follow the inputs.
    L1 = ipw.FloatText(
        value=wave_length(T1.value, depth.value), description="L [m]", disabled=True
    )
    L2 = ipw.FloatText(
        value=wave_length(T2.value, depth.value), description="L [m]", disabled=True
    )
    Lgroup = ipw.FloatText(
        value=group_stats(
            k1=2 * np.pi / L1.value,
            k2=2 * np.pi / L2.value,
            w1=2 * np.pi / T1.value,
            w2=2 * np.pi / T2.value,
        )[0],
        description="L_{group} [m]",
        disabled=True,
    )

    c1 = ipw.FloatText(value=L1.value / T1.value, description="c [m/s]", disabled=True)
    c2 = ipw.FloatText(value=L2.value / T2.value, description="c [m/s]", disabled=True)

    h_L1 = ipw.FloatText(
        value=depth.value / L1.value, description="h/L [m]", disabled=True
    )
    h_L2 = ipw.FloatText(
        value=depth.value / L2.value, description="h/L [m]", disabled=True
    )

    # Setup widget layout (User Interface)
    vbox1 = ipw.VBox(
        [ipw.Label("Wave 1", layout=ipw.Layout(align_self="center")), a1, T1, c1, h_L1]
    )
    vbox2 = ipw.VBox(
        [ipw.Label("Wave 2", layout=ipw.Layout(align_self="center")), a2, T2, c2, h_L2]
    )
    vbox3 = ipw.VBox(
        [
            ipw.Label("Wave group", layout=ipw.Layout(align_self="center")),
            n_waves,
            Lgroup,
        ]
    )
    vbox4 = ipw.VBox(
        [ipw.Label("General", layout=ipw.Layout(align_self="center")), depth, xp, tp]
    )

    # NOTE(review): vbox3 (n_waves, Lgroup) and the phase widgets phi_1/phi_2 are
    # built but not placed in the UI, so they keep their default values.
    # Confirm whether this is intentional.
    ui = ipw.HBox([vbox1, vbox2, vbox4])

    def calc_eta(a1, T1, phi_1, a2, T2, phi_2, L1, L2, n_waves, xp, tp, depth):
        # Equal periods give Delta_k = Delta_w = 0 (no beating, infinite group).
        # Non-positive inputs make the dispersion relation or the plot range
        # undefined. Report this instead of raising inside the widget callback.
        if min(T1, T2, depth, n_waves) <= 0 or T1 == T2:
            print(
                "Please choose two different, positive periods and a positive "
                "water depth and number of wave groups."
            )
            return

        # L1 and L2 are recomputed from T1, T2 and depth; the incoming widget
        # values are not used.
        L1 = wave_length(T1, depth)
        L2 = wave_length(T2, depth)

        L_group, T_group, _ = group_stats(
            k1=2 * np.pi / L1,
            k2=2 * np.pi / L2,
            w1=2 * np.pi / T1,
            w2=2 * np.pi / T2,
        )

        # Keep the displayed read-only widgets in sync. interactive_output does
        # not observe them, so updating them does not re-trigger calc_eta.
        c1.value = L1 / T1
        c2.value = L2 / T2
        h_L1.value = depth / L1
        h_L2.value = depth / L2
        Lgroup.value = L_group

        T = np.min([T1, T2])
        L = np.max([L1, L2])
        # Extend t and x by half the plotted range on both sides. This gives a
        # correct Hilbert transformation at the boundaries of the visible window.
        t = np.arange(-0.5 * n_waves * T_group, (n_waves + 0.5) * T_group, T / 30)
        x = np.arange(-0.5 * n_waves * L_group, (n_waves + 0.5) * L_group, L / 30)

        fig, axs = plt.subplots(
            nrows=3, ncols=2, figsize=(9, 5), sharex=False, sharey=False
        )
        fig.subplots_adjust(hspace=0)
        fig.subplots_adjust(wspace=0.06)

        # time based (left column)
        ax1, ax2, ax3 = axs[:, 0]
        # space based (right column)
        ax4, ax5, ax6 = axs[:, 1]

        # surface elevation including the phase shift
        eta1_T = a1 * np.sin(
            2 * np.pi / T1 * t - 2 * np.pi / L1 * xp - phi_1 * (2 * np.pi)
        )
        eta2_T = a2 * np.sin(
            2 * np.pi / T2 * t - 2 * np.pi / L2 * xp - phi_2 * (2 * np.pi)
        )
        eta_T = eta1_T + eta2_T

        eta1_x = a1 * np.sin(
            2 * np.pi / T1 * tp - 2 * np.pi / L1 * x - phi_1 * (2 * np.pi)
        )
        eta2_x = a2 * np.sin(
            2 * np.pi / T2 * tp - 2 * np.pi / L2 * x - phi_2 * (2 * np.pi)
        )
        eta_x = eta1_x + eta2_x

        # surface elevation without the phase shift
        eta_T_basic = a1 * np.sin(2 * np.pi / T1 * t - 2 * np.pi / L1 * xp) + a2 * np.sin(
            2 * np.pi / T2 * t - 2 * np.pi / L2 * xp
        )
        eta_x_basic = a1 * np.sin(2 * np.pi / T1 * tp - 2 * np.pi / L1 * x) + a2 * np.sin(
            2 * np.pi / T2 * tp - 2 * np.pi / L2 * x
        )

        # envelope from the Hilbert transform
        eta_T_envelope = np.abs(hilbert(eta_T_basic))
        eta_x_envelope = np.abs(hilbert(eta_x_basic))

        # Carrier wave (mean wave number and mean radian frequency). Currently
        # not plotted; kept for the optional overlays commented out below.
        k_bar = (2 * np.pi / L1 + 2 * np.pi / L2) / 2
        w_bar = (2 * np.pi / T1 + 2 * np.pi / T2) / 2
        car_wave_t = (a1 + a2) * np.sin(w_bar * t - k_bar * xp)
        car_wave_x = (a1 + a2) * np.sin(w_bar * tp - k_bar * x)

        # variable amplitude (also only used by the optional overlays)
        Delta_k = 2 * np.pi / L1 - 2 * np.pi / L2
        Delta_w = 2 * np.pi / T1 - 2 * np.pi / T2
        var_amp_t = (a1 + a2) * np.cos(Delta_w / 2 * t - Delta_k / 2 * xp)
        var_amp_x = (a1 + a2) * np.cos(Delta_w / 2 * tp - Delta_k / 2 * x)

        # reference: summed surface without the phase shift
        ax3.plot(t, eta_T_basic, color="grey", linestyle="--", label=r"$\eta_1$")

        # surface including the phase shift
        ax1.plot(t, eta1_T, label=r"$\eta_1$")
        ax2.plot(t, eta2_T, label=r"$\eta_2$")
        ax3.plot(t, eta_T, label=r"$\eta_{1+2}$")
        ax3.plot(t, eta_T_envelope, label="Envelope")
        # ax3.plot(t, np.abs(var_amp_t), label="abs: car wave")

        ax4.plot(x, eta1_x, label=r"$\eta_1$")
        ax5.plot(x, eta2_x, label=r"$\eta_2$")
        ax6.plot(x, eta_x, label=r"$\eta_{1+2}$")
        ax6.plot(x, eta_x_envelope, label="Envelope")
        # ax6.plot(x, np.abs(var_amp_x), label="abs: car wave")

        # same vertical range on every axis
        amp = (a1 + a2) * 1.1
        for ax in axs.flat:
            ax.set_ylim(-amp, amp)

        # horizontal range: n_waves groups in time (left) and space (right)
        for ax in axs[:, 0]:
            ax.set_xlim(0, n_waves * T_group)
        for ax in axs[:, 1]:
            ax.set_xlim(0, n_waves * L_group)

        # set labels
        ax1.set_ylabel(r"$\eta_1$ [m]")
        ax2.set_ylabel(r"$\eta_2$ [m]")
        ax3.set_ylabel(r"$\eta_{1+2}$ [m]")

        ax3.set_xlabel("t/T_{group}")
        ax6.set_xlabel("x/L_{group}")

        # hide inner axes so the panels stack cleanly
        ax1.xaxis.set_visible(False)
        ax2.xaxis.set_visible(False)
        ax4.yaxis.set_visible(False)
        ax5.yaxis.set_visible(False)
        ax6.yaxis.set_visible(False)

        # ticks in units of the group period / group length
        if n_waves >= 1:
            ticks = np.arange(0, n_waves // 1 + 1, 1)
            ax3.set_xticks(ticks * T_group)
            ax3.set_xticklabels(ticks)
            ax6.set_xticks(ticks * L_group)
            ax6.set_xticklabels(ticks)
        else:  # less than one group shown: label start, middle and end
            ax3.set_xticks([0, 0.5 * n_waves * T_group, n_waves * T_group])
            ax3.set_xticklabels([0, 0.5 * n_waves, n_waves])
            ax6.set_xticks([0, 0.5 * n_waves * L_group, n_waves * L_group])
            ax6.set_xticklabels([0, 0.5 * n_waves, n_waves])

        # remove x and y tick labels on the inner right-hand axes
        ax4.set_xticklabels([], fontsize=0)
        ax5.set_xticklabels([], fontsize=0)

        ax4.set_yticklabels([], fontsize=0)
        ax5.set_yticklabels([], fontsize=0)
        ax6.set_yticklabels([], fontsize=0)

        # set title
        ax1.set_title("Time-based (x =" + str(xp) + " m)")
        ax4.set_title("Space-based (t =" + str(tp) + " s)")

        # legends to the right of the space-based column
        for ax in (ax4, ax5, ax6):
            ax.legend(loc="center left", bbox_to_anchor=(1.001, 0.5))

    # redraw whenever one of these widgets changes
    out = ipw.interactive_output(
        calc_eta,
        {
            "a1": a1,
            "T1": T1,
            "phi_1": phi_1,
            "a2": a2,
            "T2": T2,
            "phi_2": phi_2,
            "L1": L1,
            "L2": L2,
            "n_waves": n_waves,
            "xp": xp,
            "tp": tp,
            "depth": depth,
        },
    )

    display(ui, out)


# W2_wave_groups()

### Q11) Theory questions - Unequal amplitudes

In [None]:
# Load the cookbook notebook again so its helpers are available in this cell.
# This repeats the loading done in the "Functions from cookbook" cell at the
# top. Q11 below calls helpers that are not defined in this cell, presumably
# provided by the cookbook (e.g. multiple_selection, nummerical_subquestions).
# TODO: confirm.
# The first path applies when the notebook is run from the main folder; the
# fallback applies when it is run from the Initialize folder itself.
# NOTE(review): the bare `except:` also catches KeyboardInterrupt/SystemExit.
try:
    %run Initialize/init_cookbook.ipynb # valid when running the cookbook in the main file
except:
    %run init_cookbook.ipynb # valid when running the cookbook from this file.

def Q11():
    """Build the Q11 theory questions for two wave components with unequal amplitudes.

    Returns:
        pn.Column: two multiple-selection questions followed by the widgets of
        the numerical sub-question on the envelope maximum and minimum.
    """
    # Part 1: consequences of making the amplitudes unequal
    first = multiple_selection(
        'If now a1 = 1 m instead of 1.5 m and a2 remains 1.5 m, what happens?',
        ["The group period is unaffected", "The minimum of the envelope is nonzero"],
        ["The group length becomes longer ", "The carrier wave has a smaller phase velocity"],
    )

    # Part 2: which 10% parameter change lengthens the group period
    second = multiple_selection(
        'Suppose that one of the parameters is changed by 10%. Which change will increase the group period?',
        ["Increasing T2"],
        ["Increasing a1", "Reducing the water depth", "Increasing T1"],
    )

    # Part 3: envelope extremes a1 + a2 = 2.5 m and a2 - a1 = 0.5 m
    numerical_parts = nummerical_subquestions(
        ['How large is the maximum and minimum of the envelope if a1 = 1 m and a2 = 1.5 m?'],
        [[2.5, 0.5]],
        [['Maximum', 'Minimum']],
        'm',
    )

    return pn.Column(first, second, *numerical_parts)

#Q11()

### Q12) Coding question

In [None]:
def W2_Q12():
    """Display the W2 Q12 coding question and store its reference answers.

    The wave periods and the water depth are fixed (T1 = 7 s, T2 = 6.2 s,
    h = 20 m) to match the presentation. The reference values and two attempt
    counters are stored in the global ``W2_Q12_param``. ``Check_W2_Q12A`` and
    ``Check_W2_Q12B`` unpack that tuple positionally, so its order is part of
    the contract.
    """
    # The question-related parameters, unrandomized to match the presentation.
    # (Randomised alternative: T1 = round(uniform(5, 7), 1); T2 = round(uniform(7, 10), 1).)
    T1 = 7
    T2 = 6.2
    h = 20

    # The question that is asked
    question = (
        "Can you calculate the difference in wave number (Delta_k) and radial "
        f"frequency (Delta_w) when the periods of wave 1 and 2 are {T1} and {T2} "
        f"seconds and the water depth is {h} m by completing the code below? "
        'You can (re)check your computation by running the code, then press the "load" button '
        '(or fill in the answer manually), and finally press the button "check loaded values". '
        "\n You can calculate the wavelength based on the ratio h/L defined in the graph above "
        "or use your previously defined function."
    )

    # Reference calculation
    L1 = wave_length(T1, h)
    L2 = wave_length(T2, h)

    k1 = 2 * np.pi / L1
    k2 = 2 * np.pi / L2
    w1 = 2 * np.pi / T1
    w2 = 2 * np.pi / T2

    Delta_k = np.abs(k2 - k1)
    Delta_w = np.abs(w2 - w1)
    # No longer asked, but kept because Check_W2_Q12A/B unpack them from W2_Q12_param.
    k_average = np.average([k1, k2])
    w_average = np.average([w1, w2])

    c_average = np.average([L1 / T1, L2 / T2])  # mean phase speed of both components
    L_group, T_group, c_group = group_stats(k1, k2, w1, w2)

    # Attempt counters required by the checking functions; do not change.
    attempt_A = pn.widgets.FloatInput(value=0)
    attempt_B = pn.widgets.FloatInput(value=0)

    question_widget = pn.widgets.StaticText(name="", value=question)
    display(question_widget)

    # Store the parameters in a global instead of returning them, which avoids
    # the notebook echoing the return value. The name and tuple order are read
    # by Check_W2_Q12A/B.
    global W2_Q12_param
    W2_Q12_param = (
        T1,
        T2,
        h,
        Delta_k,
        Delta_w,
        k_average,
        w_average,
        c_average,
        L_group,
        T_group,
        c_group,
        attempt_A,
        attempt_B,
    )


def Check_W2_Q12A():
    """Check the Delta_k and Delta_w answers of question W2 Q12 (part A).

    Reads the reference values from the global ``W2_Q12_param``. That global is
    set by ``W2_Q12()``, so that function must have run first. This function's
    local variables are then handed to ``classify_variables`` and the result to
    ``check_code_values``.

    NOTE(review): ``locals()`` is passed to ``classify_variables``, so the local
    variable names below appear to be part of its contract. Presumably it looks
    them up by name; do not rename, add or remove locals without checking that
    helper.
    """
    # unpack the reference values stored by W2_Q12 (order must match that tuple)
    (
        T1,
        T2,
        h,
        Delta_k,
        Delta_w,
        k_average,
        w_average,
        c_average,
        L_group,
        T_group,
        c_group,
        attempt_A,
        attempt_B,
    ) = W2_Q12_param
    attempt = attempt_A  # part A uses its own attempt counter

    # define the parameter names that have to be checked
    check_parameters = ["Delta_k", "Delta_w"]

    # define the names of the parameters as they are displayed
    name_parameters = ["Delta k", "Delta w"]

    # additional settings
    n_decimals = 3  # decimals of the answer shown when 3 wrong answers
    f_margin = 0.01  # (1%) the allowed error margin

    # build the question from all locals defined above
    FV = classify_variables(locals())
    check_code_values(FV)


def Check_W2_Q12B():
    """Check the group velocity, group length and group period answers of W2 Q12 (part B).

    Reads the reference values from the global ``W2_Q12_param``. That global is
    set by ``W2_Q12()``, so that function must have run first. This function's
    local variables are then handed to ``classify_variables`` and the result to
    ``check_code_values``.

    NOTE(review): ``locals()`` is passed to ``classify_variables``, so the local
    variable names below appear to be part of its contract. Presumably it looks
    them up by name; do not rename, add or remove locals without checking that
    helper.
    """
    # unpack the reference values stored by W2_Q12 (order must match that tuple)
    (
        T1,
        T2,
        h,
        Delta_k,
        Delta_w,
        k_average,
        w_average,
        c_average,
        L_group,
        T_group,
        c_group,
        attempt_A,
        attempt_B,
    ) = W2_Q12_param
    attempt = attempt_B  # part B uses its own attempt counter

    # define the parameter names that have to be checked
    check_parameters = ["c_group", "L_group", "T_group"]

    # define the names of the parameters as they are displayed
    name_parameters = ["cg [m/s]", "L group [m]", "T group [s]"]

    # additional settings
    n_decimals = 2  # decimals of the answer shown when 3 wrong answers
    f_margin = 0.01  # 1.0% the allowed error margin

    # build the question from all locals defined above
    FV = classify_variables(locals())
    check_code_values(FV)

### W2_Wave_animation

In the description, add:
- n_waves: the number of wave groups shown

A factor to change the playback speed (time) should be included.

In [None]:
def W2_Wave_animation():
    from scipy.signal import hilbert

    # adjusting the graph while it is on pause can be achieved by widget.observe(). See old Week_3_initialize.ipynb
    %matplotlib widget
    a1 = pn.widgets.FloatInput(
        name="a1 [m]", start=0, end=3, step=0.01, value=1, width=75
    )
    a2 = pn.widgets.FloatInput(
        name="a2 [m]", start=0, end=3, step=0.01, value=1, width=75
    )

    T1 = pn.widgets.FloatInput(
        name="T1 [s]", start=0, step=0.01, value=5, width=75
    )  # 3
    T2 = pn.widgets.FloatInput(
        name="T2 [s]", start=0, step=0.01, value=3, width=75
    )  # 5

    depth = pn.widgets.FloatInput(
        name="h [m]", start=0.01, end=250, step=0.01, value=2, width=75
    )
    n_waves = pn.widgets.FloatInput(
        name="n_waves group", start=0.02, end=20, step=0.01, value=10, width=75
    )

    time = pn.widgets.FloatInput(name="time [s]", start=0, value=0, width=75)
    f_time = pn.widgets.FloatInput(
        name="play speed", start=0.1, value=1, width=75, step=0.1
    )

    L1 = pn.widgets.FloatInput(
        name="L1 [m]",
        start=0,
        step=0.01,
        value=wave_length(T1.value, depth.value),
        width=75,
        disabled=True,
    )
    L2 = pn.widgets.FloatInput(
        name="L2 [m]",
        start=0,
        step=0.01,
        value=wave_length(T2.value, depth.value),
        width=75,
        disabled=True,
    )

    L_group, T_group, c_g = group_stats(
        k1=2 * np.pi / L1.value,
        k2=2 * np.pi / L2.value,
        w1=2 * np.pi / T1.value,
        w2=2 * np.pi / T2.value,
    )

    # Setup widget layout (User Interface) and display
    vbox1 = pn.Column("Wave 1", a1, T1, L1)
    vbox2 = pn.Column("Wave 2", a2, T2, L2)
    vbox3 = pn.Column(
        "General",
        depth,
        n_waves,
    )
    vbox4 = pn.Column("Play settings", f_time, time)

    # Define and display User Interface (UI)
    ui = pn.Row(vbox1, vbox2, vbox3, vbox4)
    display(ui)

    # Setup linear mesh (x) and duration before time (t) is reset
    x_max = L_group * n_waves.value
    x = np.linspace(0, x_max, 500)

    # Create figure, set structure and initial layout
    fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(9, 5), sharex=True)
    plt.tight_layout()
    fig.subplots_adjust(hspace=0)
    fig.subplots_adjust(bottom=0.2)  # Add some extra space for the axis at the bottom
    fig.subplots_adjust(left=0.1)  # Add some space for the labels

    # set grid distance
    grid_x = MultipleLocator(base=25)
    grid_y = MultipleLocator(base=5)

    # Set grid for multiple axes
    for ax in axs:
        ax.xaxis.set_minor_locator(grid_x)
        ax.yaxis.set_minor_locator(grid_y)
        ax.grid(which="both", linestyle="-", linewidth="0.5", color="gray", alpha=0.6)

    # Compute initial displacement
    eta = a1.value * np.sin(2 * np.pi / L1.value * x) + a2.value * np.sin(
        2 * np.pi / L2.value * x
    )
    eta1 = a1.value * np.sin(2 * np.pi / L1.value * x)
    eta2 = a2.value * np.sin(2 * np.pi / L2.value * x)

    k_bar = (2 * np.pi / L1.value + 2 * np.pi / L2.value) / 2
    car_wave = (a1.value + a2.value) * np.sin(k_bar * x)  # carrier wave

    k_dif = 2 * np.pi / L1.value - 2 * np.pi / L2.value  # Delta k
    var_amp = np.cos(-k_dif / 2 * x)  # variable amplitude
    var_amp_scaled = (a1.value + a2.value) * np.cos(
        -k_dif / 2 * x
    )  # variable amplitude

    # calculate hilbert
    eta_hilbert = np.abs(hilbert(eta))

    # Plot initial wave
    (line1,) = axs[0].plot(
        x, eta1, label="$\u03B7_1$ (wave 1)", linewidth=0.75, color="#0b5394"
    )
    (line2,) = axs[1].plot(
        x, eta2, label="$\u03B7_2$ (wave 2)", linewidth=0.75, color="#0b5394"
    )  # 03396c
    (line,) = axs[2].plot(
        x, eta, label="$\u03B7$ (wave 1+2)", linewidth=0.75, color="k"
    )
    (line_var,) = axs[2].plot(
        x, var_amp_scaled, label="Variable amplitude", linewidth=0.75, color="gray"
    )

    # set initial layout and make legends
    amp = a1.value + a2.value
    for ax in axs:
        ax.set_xlim(0, x_max)
        ax.set_ylim(-amp * 1.15, amp * 1.15)
        legend = ax.legend(loc="lower right")

        for text in legend.get_texts():
            text.set_fontsize(7)  # Set individual legend item text size

    start_time = timeit.default_timer()

    # adjust the graph when the animation is running
    def update_line(change):
        t = change.new
        t = (timeit.default_timer() - start_time) * f_time.value

        # adjust the widget
        time.value = t

        L1.value = wave_length(T1.value, h=depth.value)
        L2.value = wave_length(T2.value, h=depth.value)
        L_group, T_group, c_g = group_stats(
            k1=2 * np.pi / L1.value,
            k2=2 * np.pi / L2.value,
            w1=2 * np.pi / T1.value,
            w2=2 * np.pi / T2.value,
        )
        x_max = L_group * n_waves.value
        for ax in axs:
            ax.set_xlim(0, x_max)

        x = np.linspace(0, x_max, 750)

        # calculate sea surface elevation
        eta = a1.value * np.sin(
            2 * np.pi / T1.value * t - 2 * np.pi / L1.value * x
        ) + a2.value * np.sin(2 * np.pi / T2.value * t - 2 * np.pi / L2.value * x)
        eta1 = a1.value * np.sin(2 * np.pi / T1.value * t - 2 * np.pi / L1.value * x)
        eta2 = a2.value * np.sin(2 * np.pi / T2.value * t - 2 * np.pi / L2.value * x)

        omega_bar = (2 * np.pi / T1.value + 2 * np.pi / T2.value) / 2
        k_bar = (2 * np.pi / L1.value + 2 * np.pi / L2.value) / 2

        # calculate hilbert
        eta_hilbert = np.abs(hilbert(eta))

        # adjust sea surface elevation (line) in plot
        line.set_ydata(eta)
        line1.set_ydata(eta1)
        line2.set_ydata(eta2)
        line_var.set_ydata(eta_hilbert)

        line.set_xdata(x)
        line1.set_xdata(x)
        line2.set_xdata(x)
        line_var.set_xdata(x)

        # adjust the line thickness, to 0 (not visible) if the amplitude is zero
        # if a1.value > 0:
        #    line1.set_linewidth(0.75)
        # else:
        #    line1.set_linewidth(0)
        #
        # if a2.value > 0:
        #    line2.set_linewidth(0.75)
        # else:
        #    line2.set_linewidth(0)

        amp = a1.value + a2.value
        for ax in axs:
            ax.set_ylim(-amp * 1.15, amp * 1.15)

        fig.canvas.draw()

    delta_t = 0.0005  # s
    discrete_player = pn.widgets.DiscretePlayer(
        name="Discrete Player",
        options=np.arange(0, 500, delta_t * f_time.value).tolist(),
        value=0,
        loop_policy="loop",
        interval=int(delta_t * 1000),
    )

    discrete_player.param.watch(update_line, "value")
    display(discrete_player)
    # pn.panel(discrete_player)


# W2_Wave_animation()