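# Example Usage

This notebook demonstrates the `torch_mdct` package. It loads an audio clip, computes its MDCT spectrogram, reconstructs the waveform with the IMDCT, and inspects the difference between the original and the reconstruction.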

## Setup

In [13]:
# Uncomment to install the torch_mdct package
# %pip install torch_mdct

In [14]:
# Comment out the following lines if you have installed the torch_mdct package
import sys

# Make the in-repo sources importable (the relative path is resolved against the
# kernel's working directory). The membership check keeps re-runs of this cell
# from appending duplicate entries to sys.path.
if "../src" not in sys.path:
    sys.path.append("../src")

## Imports

In [15]:
import torch
import torchaudio
from IPython.display import Audio
from matplotlib import pyplot as plt

from torch_mdct import IMDCT, MDCT

## Utils

In [16]:
def plot_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    title: str,
    downsample: int = 100,
) -> None:
    """Plot every channel of a waveform against time, one panel per channel.

    Args:
        waveform: 2-D tensor unpacked as ``(channels, n_frames)``.
        sample_rate: Sampling rate of ``waveform`` in Hz; used to convert sample
            indices to seconds on the x-axis.
        title: Figure title.
        downsample: Plot only every ``downsample``-th sample to keep the figure
            cheap to draw. Defaults to 100 (about 1% of the samples are drawn).

    Raises:
        ValueError: If ``waveform`` is not 2-D or ``downsample`` is not positive.
    """
    if waveform.dim() != 2:
        raise ValueError(
            f"waveform must be 2-D (channels, n_frames), got shape {tuple(waveform.shape)}"
        )
    if downsample < 1:
        raise ValueError(f"downsample must be >= 1, got {downsample}")

    channels = waveform.shape[0]

    # Strided slice keeps samples 0, downsample, 2 * downsample, ...
    decimated = waveform[..., ::downsample]
    # Decimated sample i sits at original index i * downsample, i.e. at
    # i * downsample / sample_rate seconds.
    time_axis = torch.arange(decimated.shape[-1]) * (downsample / sample_rate)

    # At most two rows; extra channels spill into additional columns.
    # squeeze=False guarantees a 2-D array of Axes even for a single panel.
    nrows = min(max(channels, 1), 2)
    ncols = max((channels + 1) // 2, 1)
    fig, axes = plt.subplots(nrows, ncols, squeeze=False, constrained_layout=True)
    axes = axes.flatten()

    for c in range(channels):
        ax = axes[c]
        ax.plot(time_axis, decimated[c], linewidth=1)
        ax.grid(True)
        ax.set_xlabel("Time (s)")

        if channels > 1:
            ax.set_ylabel(f"Channel {c}")

    # Hide grid cells left over when there are fewer channels than panels.
    for ax in axes[channels:]:
        ax.set_visible(False)

    fig.suptitle(title)
    plt.show(block=False)


def plot_spectrogram(
    spectrogram: torch.Tensor, title: str, eps: float = 1e-5
) -> None:
    """Show the log-magnitude of each channel of a spectrogram as an image.

    Args:
        spectrogram: Tensor whose first dimension indexes channels. Each
            ``spectrogram[c]`` is drawn with ``imshow`` and must therefore be 2-D.
            Presumably ``(frequency_bins, frames)``, given ``origin="lower"`` and
            the "Time" x-label; TODO confirm against the MDCT output layout.
        title: Figure title.
        eps: Added to the magnitude before the log so zero coefficients stay finite.
    """
    channels = spectrogram.shape[0]

    # At most two rows; extra channels spill into additional columns.
    # squeeze=False guarantees a 2-D array of Axes even for a single panel.
    nrows = min(max(channels, 1), 2)
    ncols = max((channels + 1) // 2, 1)
    fig, axes = plt.subplots(nrows, ncols, squeeze=False, constrained_layout=True)
    axes = axes.flatten()

    for c in range(channels):
        ax = axes[c]
        im = ax.imshow(
            torch.log(spectrogram[c].abs() + eps), origin="lower", aspect="auto"
        )
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Time")

        if channels > 1:
            ax.set_ylabel(f"Channel {c}")

    # Hide grid cells left over when there are fewer channels than panels.
    for ax in axes[channels:]:
        ax.set_visible(False)

    fig.suptitle(title)
    plt.show(block=False)


def stats(x: torch.Tensor) -> str:
    """Return a one-line summary of ``x``: shape, min, max, mean and std.

    Each statistic is formatted with two decimals; the std is ``x.std()`` with
    torch's default settings. Used to annotate the figure titles in this notebook.
    """
    pieces = [f"Shape: {tuple(x.shape)}"]
    pieces += [
        f"{label}: {value:.2f}"
        for label, value in (
            ("Min", x.min()),
            ("Max", x.max()),
            ("Mean", x.mean()),
            ("Std", x.std()),
        )
    ]
    return " ".join(pieces)

## Data Loading

In [None]:
# Decode the example clip; torchaudio hands back the samples and the file's sample rate.
waveform, sample_rate = torchaudio.load(
    "audio_samples/sample.ogg"
)
# Last expression renders an inline audio player.
Audio(data=waveform, rate=sample_rate)

In [None]:
# Visual check of the input signal, with summary statistics in the title.
plot_waveform(
    waveform,
    sample_rate,
    title=f"Waveform: \n({stats(waveform)})",
)

## Transforms

In [19]:
# One window length shared by both transforms, so the IMDCT round trip below
# is configured identically to the forward MDCT.
WIN_LENGTH = 2048

mdct = MDCT(win_length=WIN_LENGTH)
imdct = IMDCT(win_length=WIN_LENGTH)

## MDCT Experiments

In [None]:
# Forward transform; plot_spectrogram treats dim 0 of the result as the channel axis.
spectrogram = mdct(waveform)
plot_spectrogram(
    spectrogram,
    title=f"Log Absolute Spectrogram: \n({stats(spectrogram)})",
)

## IMDCT Experiments

In [None]:
# Inverse transform. `length` is given the original sample count, presumably so the
# output lines up sample-for-sample with `waveform` for the difference cell below.
reconst_waveform = imdct(
    spectrogram,
    length=waveform.shape[-1],
)
Audio(data=reconst_waveform, rate=sample_rate)

In [None]:
# Should visually match the original waveform plot above.
plot_waveform(
    waveform=reconst_waveform,
    sample_rate=sample_rate,
    title=f"Reconstructed Waveform: \n({stats(reconst_waveform)})",
)

## Waveform Difference

In [None]:
# Residual of the MDCT -> IMDCT round trip.
waveform_diff = waveform.sub(reconst_waveform)
# NOTE: IPython's Audio rescales to full range by default (normalize=True), so the
# playback loudness does not reflect how small the residual actually is.
Audio(data=waveform_diff, rate=sample_rate)

In [None]:
# The title's Min/Max/Std quantify the reconstruction error.
plot_waveform(
    waveform=waveform_diff,
    sample_rate=sample_rate,
    title=f"Waveform Difference: \n({stats(waveform_diff)})",
)