In [1]:
%matplotlib widget
from dataclasses import InitVar, dataclass, field
from itertools import repeat
from typing import List, NamedTuple

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from IPython.display import Audio

# Apply seaborn's default theme to all Matplotlib figures. ``sns.set`` is only a
# backwards-compatible alias of ``set_theme`` (the documented API since seaborn 0.11).
sns.set_theme()
# NOTE(review): ``Audio._make_wav`` is a private IPython helper (not public API,
# may change between IPython versions). It is used below to encode a sample
# array as WAV bytes for the ipywidgets Audio player.
make_wav = Audio._make_wav


@dataclass
class FourierTerm:
    """One sinusoid ``a * sin(2π(f t + p))`` whose parameters are set by sliders.

    The dataclass fields are the *initial* values. ``__post_init__`` replaces each
    field with the widget that controls it. After construction,
    ``term.frequency``, ``term.amplitude`` and ``term.phase`` are ipywidgets
    sliders, and the current numbers are their ``.value``.

    Fields:
        frequency: initial frequency in cycles per unit of ``t`` (Hz when ``t``
            is in seconds). Slider range is 0-1000, widened to include this value.
        amplitude: initial amplitude. Slider range is 0-2, widened to include it.
        phase: initial phase in cycles (the waveform repeats every 1.0). Slider
            range is 0-1; values outside it are wrapped with ``% 1``.

    Extra attributes set in ``__post_init__``:
        text: Output widget that prints the term as a formula.
        out: HBox holding ``text`` and all controls; this is what gets displayed.
    """

    frequency: int
    amplitude: float = 1
    phase: float = 0

    def __post_init__(self):
        # Widen each slider's default range so it contains the requested starting
        # value. ipywidgets clamps out-of-range values silently, so e.g.
        # FourierTerm(1500) would otherwise play at 1000 Hz.
        self.frequency = widgets.IntSlider(
            value=self.frequency,
            min=min(0, self.frequency),
            max=max(1000, self.frequency),
            step=1,
            description="Frequency:",
            readout=False,
        )
        # Numeric entry box that the front end keeps in sync with the slider.
        # Start it at the slider's value and bounds so the kernel-side state
        # already agrees with the slider before the browser link runs.
        frequencybox = widgets.BoundedIntText(
            value=self.frequency.value,
            min=self.frequency.min,
            max=self.frequency.max,
            layout=widgets.Layout(width="80px"),
        )
        widgets.jslink((self.frequency, "value"), (frequencybox, "value"))
        self.amplitude = widgets.FloatSlider(
            value=self.amplitude,
            min=min(0, self.amplitude),
            max=max(2, self.amplitude),
            step=0.01,
            description="Amplitude:",
            readout=False,
        )
        # Phase is measured in cycles, so shifting it by a whole number leaves the
        # waveform unchanged. Wrap out-of-range values into the 0-1 slider instead
        # of letting the slider clamp them to the wrong phase.
        phase = self.phase if 0 <= self.phase <= 1 else self.phase % 1
        self.phase = widgets.FloatSlider(
            value=phase,
            min=0,
            max=1,
            step=0.01,
            description="Phase:",
            readout=False,
        )
        values = {"f": self.frequency, "a": self.amplitude, "p": self.phase}
        # Re-prints the formula whenever any slider value changes.
        self.text = widgets.interactive_output(
            lambda f, a, p: print(f"{a:.2f} sin( 2π({f:4d} t + {p:.2f}))"), values
        )
        self.out = widgets.HBox(
            [
                self.text,
                widgets.HBox(
                    [
                        widgets.HBox([self.frequency, frequencybox]),
                        self.amplitude,
                        self.phase,
                    ]
                ),
            ]
        )

    def _ipython_display_(self):
        """Render the controls (``self.out``) when the term is displayed in Jupyter."""
        display(self.out)

    def __call__(self, t):
        """Evaluate ``a * sin(2π(f t + p))`` at ``t`` using the current slider values.

        ``t`` may be a scalar or a NumPy array; NumPy broadcasting gives the result
        the same shape as ``t``.
        """
        f, a, p = self.frequency.value, self.amplitude.value, self.phase.value
        return a * np.sin(2 * np.pi * (f * t + p))

    def observe(self, *args, **kwargs):
        """Register a traitlets observer on all three sliders.

        Arguments are forwarded unchanged to each slider's ``observe``
        (e.g. ``term.observe(handler, names="value")``).
        """
        for widget in [self.frequency, self.amplitude, self.phase]:
            widget.observe(*args, **kwargs)


def list_join(separator, to_join):
    """Interleave ``separator`` between consecutive items of ``to_join``.

    This is the list analogue of ``str.join``: ``list_join(0, [1, 2, 3])``
    returns ``[1, 0, 2, 0, 3]``. The very same ``separator`` object fills every
    gap (it is not copied). An empty iterable gives ``[]``.
    """
    joined = []
    for item in to_join:
        # A separator goes before every item except the first one.
        if joined:
            joined.append(separator)
        joined.append(item)
    return joined


@dataclass
class AudioDemo:
    """Interactive player for a sum of sinusoids (``FourierTerm`` objects).

    Shows one control row per term, an "Add a term" button, the formula for the
    sum, an audio player for the summed signal, and a Matplotlib figure with
    every term (coloured) plus their sum (black). Changing any slider value
    recomputes that term, redraws the plot and re-encodes the audio.

    Fields:
        duration_seconds: length of the generated signal, in seconds.
        framerate: audio sample rate, in samples per second.
        autoplay: forwarded to the audio widget's ``autoplay``.
        terms: the terms being summed. A copy of the list passed in is kept, so
            adding terms from the UI does not mutate the caller's list.
    """

    duration_seconds: int = 3
    framerate: int = 44100
    autoplay: bool = True
    terms: List[FourierTerm] = field(
        default_factory=lambda: [FourierTerm(220), FourierTerm(228, 1, 0.5)]
    )

    def __post_init__(self):
        # Own the list, because "Add a term" appends to it.
        self.terms = list(self.terms)

        # Sample k sits at exactly k / framerate seconds. (np.linspace with its
        # default endpoint=True would space samples duration/(n-1) apart and so
        # slightly detune every term relative to the playback rate.)
        n_samples = int(round(self.framerate * self.duration_seconds))
        self.time = np.arange(n_samples) / self.framerate

        self.terms_out = widgets.Output()
        self.new_term_button = widgets.Button(
            description="Add a term", icon="wave-square"
        )
        self.new_term_button.on_click(lambda button: self._make_term())
        self._reload_terms()

        # Create the audio widget before plotting, because _replot writes to it.
        # The payload is WAV bytes, so declare that format (the widget defaults to
        # "mp3"); the sample rate travels inside the WAV header.
        self.audio = widgets.Audio(format="wav", autoplay=self.autoplay)

        self.fig, self.ax = plt.subplots()
        self.ax.set_xlim(0, self._initial_xmax())
        self.ax.set_xlabel("time(seconds)")
        self.ax.set_ylabel("amplitude(unitless)")
        self._replot()
        for index in range(len(self.terms)):
            self._watch(index)

        self.out = widgets.Output(layout={"border": "1px solid grey"})
        with self.out:
            display(self.terms_out)
            display(self.audio)
            plt.show()

    def _initial_xmax(self):
        """Right x-limit that shows about 20 periods of the lowest non-zero frequency.

        Zero-frequency (constant) terms are skipped, since 20 / 0 would raise.
        If no term has a positive frequency (including when there are no terms),
        the whole signal duration is shown.
        """
        positive = [t.frequency.value for t in self.terms if t.frequency.value > 0]
        return 20 / min(positive) if positive else self.duration_seconds

    def _reload_terms(self):
        """Rebuild the term controls, the add button and the ``f(t) = ...`` formula row."""
        self.terms_out.clear_output(wait=True)
        with self.terms_out:
            display(
                widgets.VBox(
                    [term.out for term in self.terms]
                    + [
                        self.new_term_button,
                        widgets.HBox(
                            [widgets.HTMLMath("f(t) =")]
                            + list_join(
                                widgets.HTMLMath("+"), [t.text for t in self.terms]
                            )
                        ),
                    ]
                )
            )

    def _total(self):
        """Sum of all term signals, or an all-zero array when there are no terms."""
        return sum(self.data, np.zeros_like(self.time))

    def _update_audio(self, total):
        """Encode ``total`` as WAV and load it into the audio widget."""
        # IPython's normalisation divides by the peak |sample|. For an all-zero
        # signal (every amplitude at 0) that is 0/0, so send silence un-normalised.
        self.audio.value = make_wav(
            total, rate=self.framerate, normalize=bool(np.any(total))
        )

    def _replot(self):
        """Redraw one line per term plus the black total, then refresh the audio."""
        # Use Artist.remove() rather than ax.lines.clear(): Axes.lines is a
        # read-only view in current Matplotlib.
        for line in list(self.ax.get_lines()):
            line.remove()
        # Restart the colour cycle so each term keeps its colour when terms are added.
        self.ax.set_prop_cycle(None)

        self.data = [term(self.time) for term in self.terms]
        # Plot through self.ax, not plt: pyplot's "current axes" may belong to a
        # different AudioDemo's figure.
        self.term_lines = [self.ax.plot(self.time, d)[0] for d in self.data]
        total = self._total()
        (self.total_line,) = self.ax.plot(self.time, total, color="black")
        self._update_audio(total)
        self.fig.canvas.draw_idle()

    def _watch(self, index):
        """Recompute term ``index`` whenever one of its slider values changes.

        One observer is registered per term, for its whole lifetime. The handler
        looks up the current line objects on ``self``, so replotting never leaves
        stale callbacks behind.
        """
        self.terms[index].observe(
            lambda change: self._update_term(index), names="value"
        )

    def _update_term(self, index):
        """Refresh term ``index`` in place: its data, its line, the total and the audio."""
        signal = self.data[index]
        signal[:] = self.terms[index](self.time)
        total = self._total()
        self.term_lines[index].set_ydata(signal)
        self.total_line.set_ydata(total)
        self._update_audio(total)
        self.fig.canvas.draw_idle()

    def _make_term(self, frequency=440, **kwargs):
        """Append a new ``FourierTerm``, rebuild controls and plot, and start watching it.

        Extra keyword arguments (e.g. ``amplitude``, ``phase``) go to ``FourierTerm``.
        """
        term = FourierTerm(frequency=frequency, **kwargs)
        self.terms.append(term)
        self._reload_terms()
        self._replot()
        self._watch(len(self.terms) - 1)

    def _ipython_display_(self):
        """Render the bordered output area (controls, audio, figure) in Jupyter."""
        display(self.out)

In [2]:
# Default demo: terms at 220 Hz and 228 Hz, the second shifted by half a cycle.
# Their sum beats at the 8 Hz difference frequency.
default_demo = AudioDemo()
default_demo

Output(layout=Layout(border='1px solid grey'))

In [3]:
# G major triad (approximately G3, B3, D4), each term at FourierTerm's default
# amplitude 1 and phase 0.
G_MAJOR_HZ = (196, 247, 294)
g_major = AudioDemo(terms=list(map(FourierTerm, G_MAJOR_HZ)))
display(g_major)

Output(layout=Layout(border='1px solid grey'))