# Torch MDCT


## Imports


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
from IPython.display import Audio
from torchaudio import functional as F

from torch_mdct import MDCT, InverseMDCT


## Helper Functions


In [None]:
def plot_waveform(waveform, sample_rate, title):
    """Plot every channel of a waveform against time, one subplot per channel.

    The signal is decimated (one sample in every 100 is kept) before plotting
    so that long clips stay cheap to draw. Channels are laid out row-major on
    a grid with at most two rows; grid cells without a channel are hidden.

    Args:
        waveform: 2-D array-like of shape ``(channels, n_frames)``.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        title: Title placed above the whole figure.
    """
    channels, _ = waveform.shape

    # Fixed decimation factor. Deriving it from the frame count
    # (n / (0.01 * n)) only reproduces this constant and divides by zero on
    # an empty waveform.
    skip = 100
    waveform = waveform[..., ::skip]

    n_frames = waveform.shape[-1]
    # Decimated sample k is original sample k * skip, i.e. it sits at
    # k * skip / sample_rate seconds.
    time_axis = torch.arange(n_frames) * skip / sample_rate

    # Round the column count up so odd channel counts still get enough axes,
    # and use a single row for mono instead of leaving an empty subplot.
    n_rows = 2 if channels > 1 else 1
    n_cols = max((channels + 1) // 2, 1)
    fig, axes = plt.subplots(n_rows,
                             n_cols,
                             squeeze=False,
                             constrained_layout=True)
    axes = axes.flatten()

    for c in range(channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)

        if channels > 1:
            axes[c].set_ylabel(f"Channel {c}")

        # Label the time axis on the lowest used subplot of each column
        # rather than on whichever axes happens to be "current".
        if c + n_cols >= channels:
            axes[c].set_xlabel("Time (s)")

    # Hide grid cells that have no channel to show.
    for unused in axes[channels:]:
        unused.set_visible(False)

    fig.suptitle(title)
    plt.show(block=False)


def plot_spectrogram(spectrogram, title):
    """Show the log-magnitude of each channel of a spectrogram as an image.

    Every channel ``spectrogram[c]`` is drawn as ``log(|x| + 1e-5)`` with its
    own colorbar. Channels are laid out row-major on a grid with at most two
    rows; grid cells without a channel are hidden.

    Args:
        spectrogram: Tensor whose first dimension indexes channels; each
            ``spectrogram[c]`` must be displayable by ``imshow`` (presumably
            ``(frequency, time)`` -- confirm against the transform used).
        title: Title placed above the whole figure.
    """
    channels = spectrogram.shape[0]

    # Round the column count up so odd channel counts still get enough axes,
    # and use a single row for mono instead of leaving an empty subplot.
    n_rows = 2 if channels > 1 else 1
    n_cols = max((channels + 1) // 2, 1)
    fig, axes = plt.subplots(n_rows,
                             n_cols,
                             squeeze=False,
                             constrained_layout=True)
    axes = axes.flatten()

    for c in range(channels):
        # The small offset keeps log() finite where the magnitude is zero.
        # detach()/cpu() make the tensor safe to hand to matplotlib even if it
        # tracks gradients or lives on an accelerator.
        log_magnitude = torch.log(spectrogram[c].abs() + 1e-5).detach().cpu()
        im = axes[c].imshow(log_magnitude, origin="lower", aspect="auto")
        fig.colorbar(im, ax=axes[c])

        if channels > 1:
            axes[c].set_ylabel(f"Channel {c}")

        # Label the time axis on the lowest used subplot of each column
        # rather than on whichever axes happens to be "current".
        if c + n_cols >= channels:
            axes[c].set_xlabel("Time")

    # Hide grid cells that have no channel to show.
    for unused in axes[channels:]:
        unused.set_visible(False)

    fig.suptitle(title)
    plt.show(block=False)


def stats(x):
    """Print a one-line summary of ``x``: shape, min, max, mean and std.

    Args:
        x: Array-like exposing ``shape``, ``min()``, ``max()``, ``mean()`` and
            ``std()``; the four reductions must be formattable with ``.4f``.
    """
    reductions = {
        "Min": x.min(),
        "Max": x.max(),
        "Mean": x.mean(),
        "Std": x.std(),
    }

    fields = [f"Shape: {x.shape}"]
    for label, value in reductions.items():
        fields.append(f"{label}: {value:.4f}")

    print(" ".join(fields))

## Data Loading & Transforms


In [None]:
# Load the sample clip (path is relative to the working directory); the call
# yields the audio data together with its sample rate.
waveform, sample_rate = torchaudio.load("sample_audio.ogg")
# Print shape, min, max, mean and std of the loaded audio as a sanity check.
stats(waveform)

In [None]:
# Forward and inverse transforms, both constructed with win_length=2048.
# NOTE(review): presumably the two must share the same win_length for the
# round trip below to reconstruct the input -- confirm against torch_mdct.
mdct = MDCT(win_length=2048)
imdct = InverseMDCT(win_length=2048)

## Experiments


In [None]:
# Transform the audio into its MDCT spectrogram, then print summary statistics
# (shape, min, max, mean, std) of the result.
specgram = mdct(waveform)
stats(specgram)

In [None]:
# Visualize the MDCT spectrogram. plot_spectrogram draws the log of the
# absolute value of each channel, hence "Log Absolute" in the title.
plot_spectrogram(specgram, "Log Absolute MDCT Spectrogram")

In [None]:
# Convert the MDCT spectrogram back to audio. The original number of samples is
# passed as `length`, presumably so the reconstruction has the same size as the
# input (needed for the element-wise comparison in the L1 cell) -- confirm
# against the InverseMDCT documentation.
waveform_reconst = imdct(specgram, length=waveform.shape[-1])
# Print summary statistics of the reconstruction for comparison with the input.
stats(waveform_reconst)

In [None]:
# Plot the original audio, one subplot per channel, against time in seconds.
plot_waveform(waveform, sample_rate, "Original Audio")

In [None]:
# Listen to the original audio. As the last expression in the cell, the Audio
# object is rendered as the cell's output.
Audio(waveform, rate=sample_rate)

In [None]:
# Plot the reconstructed audio with the same helper and sample rate as the
# original so the two figures can be compared directly.
plot_waveform(waveform_reconst, sample_rate, "Reconstructed Audio")

In [None]:
# Listen to the reconstructed audio. As the last expression in the cell, the
# Audio object is rendered as the cell's output.
Audio(waveform_reconst, rate=sample_rate)

In [None]:
# L1 distance between the two audio samples: the mean absolute element-wise
# difference between the original and the reconstructed waveform.
print(f"L1 Loss: {(waveform - waveform_reconst).abs().mean()}")