# How to model audio waveforms

* Which type of distribution is a good choice for modelling audio waveform frames?

* Are we correctly normalizing the discretized logistic distributions?

* Can we benefit from using a discretized Laplace distribution?

* What is the conditional distribution of the next audio frame given the previous frame?

* What is the conditional distribution of the next audio frame given the two previous frames?

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from glob import glob

import torch
import matplotlib.pyplot as plt
import numpy as np

from vseq.data.datapaths import TIMIT_TEST
from vseq.settings import DATA_DIRECTORY
from vseq.utils.log_likelihoods import gaussian_ll, discretized_logistic_ll, discretized_logistic_mixture_ll, discretized_laplace_ll, discretized_laplace_mixture_ll

## Waveform frame values

In [None]:
import math

import torchaudio
import numpy as np

from vseq.data.transforms import MuLawEncode, MuLawDecode, Quantize, Scale

In [None]:
# f1 = [os.path.join("/data/research/data/timit/test/DR1/FAKS0", f) for f in os.listdir("/data/research/data/timit/test/DR1/FAKS0") if f.endswith("WAV")]
# f2 = [os.path.join("/data/research/data/timit/test/DR1/FDAC1", f) for f in os.listdir("/data/research/data/timit/test/DR1/FDAC1") if f.endswith("WAV")]
# f3 = [os.path.join("/data/research/data/timit/test/DR1/FELC0", f) for f in os.listdir("/data/research/data/timit/test/DR1/FELC0") if f.endswith("WAV")]
# fs = [*f1, *f2, *f3]
# fs

In [None]:
# Display the dataset root directory imported from vseq.settings.
DATA_DIRECTORY

In [None]:
# Preview the recursive glob pattern for the TIMIT test recordings
# (every "*.wav" file below timit/test).
DATA_DIRECTORY + "/timit/test/**/*.wav"

In [None]:
# Collect the TIMIT test recordings. glob() returns matches in arbitrary order,
# so sort before truncating to keep the 100-file subset reproducible.
fs = sorted(glob(DATA_DIRECTORY + "/timit/test/**/*.wav", recursive=True))
fs = fs[:100]

In [None]:
# Load every selected file and join the waveforms along dim 1 into one tensor `a`.
# Fail early with a readable message: torch.cat on an empty list raises an obscure error.
if not fs:
    raise FileNotFoundError("No audio files found; check DATA_DIRECTORY and the glob pattern.")

audios = []
for f in fs:
    # `sr` keeps the sample rate reported for the most recently loaded file.
    waveform, sr = torchaudio.load(f, normalize=True)
    audios.append(waveform)

# Presumably dim 1 is the time axis (torchaudio returns (channel, time)) — TODO confirm.
a = torch.cat(audios, dim=1)
a.shape

In [None]:
# Summary of the concatenated waveform, displayed as a single tuple.
(
    a,
    a.shape,
    a.size(1) / 16000,  # length along dim 1 in seconds, assuming 16 kHz audio — TODO confirm
    a.min(),
    a.max(),
    a.abs().min(),
    a.unique().numel(),  # number of distinct sample values
)

## Waveform values with different transforms

In [None]:
fig, axes = plt.subplots(1, 4, sharey=False, figsize=(20, 5))

# Axes.hist treats each column of a 2-D array as a separate dataset, so the
# 2-D waveform is flattened to get one histogram over all samples.
axes[0].hist(np.ravel(a), bins=256, range=(-1, 1))

# Remaining panels: the same samples passed through MuLawEncode with bits=8, 10, 16.
# `me` is left bound to the bits=16 encoder once the loop finishes.
for ax, bits in zip(axes[1:], (8, 10, 16)):
    me = MuLawEncode(bits=bits)
    ax.hist(np.ravel(me(a)), bins="auto", range=(-1, 1))

In [None]:
fig, axes = plt.subplots(1, 4, sharey=False, figsize=(20, 5))

# Add uniform noise in [-0.5 / 2**16, 0.5 / 2**16) to every sample.
# NOTE(review): if the samples are 16-bit PCM scaled to [-1, 1), the quantization
# step is 2**-15, i.e. twice this noise width — confirm the intended step size.
a_dequantized = a + torch.rand_like(a) * 1/(2**16) - 0.5/(2**16)

# Axes.hist treats each column of a 2-D array as a separate dataset, so the
# 2-D waveform is flattened to get one histogram over all samples.
axes[0].hist(np.ravel(a_dequantized), bins=256, range=(-1, 1))

# Remaining panels: the noisy samples passed through MuLawEncode with bits=8, 10, 16.
# `me` is left bound to the bits=16 encoder once the loop finishes.
for ax, bits in zip(axes[1:], (8, 10, 16)):
    me = MuLawEncode(bits=bits)
    ax.hist(np.ravel(me(a_dequantized)), bins="auto", range=(-1, 1))

In [None]:
fig, axes = plt.subplots(1, 4, sharey=False, figsize=(20, 5))

# Axes.hist treats each column of a 2-D array as a separate dataset, so the
# 2-D waveform is flattened to get one histogram over all samples.
axes[0].hist(np.ravel(a), bins=256, range=(-1, 1))

# Remaining panels: MuLawEncode with bits=8, 10, 16, each followed by the same
# Quantize(bits=8, rescale=True). `me` and `q` keep their last values afterwards.
for ax, bits in zip(axes[1:], (8, 10, 16)):
    me = MuLawEncode(bits=bits)
    q = Quantize(bits=8, rescale=True)
    ax.hist(np.ravel(q(me(a))), bins=256, range=(-1, 1))

In [None]:
fig, axes = plt.subplots(1, 4, sharey=False, figsize=(20, 5))

# The Scale transform is configured with the symmetric input range [a.min(), -a.min()].
# NOTE(review): if a.max() > -a.min(), some samples lie outside that range — confirm
# how Scale treats them.
a_min = a.min()

# Axes.hist treats each column of a 2-D array as a separate dataset, so every
# input is flattened to get one histogram over all samples.
s = Scale(low=-1, high=1, min_val=a_min, max_val=-a_min)
axes[0].hist(np.ravel(s(a)), bins=256, range=(-1, 1))

# Remaining panels: the scaled samples passed through MuLawEncode with bits=8, 10, 16.
# `s` and `me` keep their last values afterwards.
for ax, bits in zip(axes[1:], (8, 10, 16)):
    s = Scale(low=-1, high=1, min_val=a_min, max_val=-a_min)
    me = MuLawEncode(bits=bits)
    ax.hist(np.ravel(me(s(a))), bins=256, range=(-1, 1))


In [None]:
# Convert q(me(a)) to integer codes: shift/scale by (v + 1) / 2 * 255, then cast to int.
# Assumes q(me(a)) lies in [-1, 1] so the codes land in [0, 255] — TODO confirm.
# NOTE(review): `q` and `me` are whatever earlier cells last assigned, and the
# cast truncates toward zero rather than rounding.
a_8bit = ((q(me(a)) + 1) / 2 * 255).to(int)

In [None]:
# Display the distinct integer codes present in a_8bit.
a_8bit.unique()

## Conditional waveform values

In [None]:
# Marginal histogram of the integer codes. The 2-D tensor is flattened because
# hist treats each column of a 2-D array as a separate dataset.
plt.hist(np.ravel(a_8bit), bins=256, range=(0, 255));

In [None]:
# p(x_t+1 | x_t): Given x_t what is x_t+1
c = 90
# Skip the final sample: it has no successor, so `idx + 1` would run out of bounds.
idx = a_8bit[:, :-1] == c
# squeeze(0) assumes a single row (one channel) — TODO confirm.
idx = torch.where(idx.squeeze(0))[0]
# Flatten before plotting: hist treats each column of a 2-D array as a separate dataset.
plt.hist(np.ravel(a_8bit[:, idx + 1]), bins=256, range=(0, 255));

In [None]:
# p(x_t+2 | x_t, x_t+1): Given x_t and a range for x_t+1, what is x_t+2
c1 = 90
c2 = (50, 70)  # alternative range tried: (160, 170)
# Only positions t with t + 2 still inside the signal are considered, so that
# `idx + 2` can never run out of bounds.
is_c1 = a_8bit[:, :-2] == c1
in_c2 = (c2[0] <= a_8bit[:, 1:-1]) & (a_8bit[:, 1:-1] <= c2[1])
# squeeze(0) assumes a single row (one channel) — TODO confirm.
idx = torch.where((is_c1 & in_c2).squeeze(0))[0]
# Flatten before plotting: hist treats each column of a 2-D array as a separate dataset.
plt.hist(np.ravel(a_8bit[:, idx + 2]), bins=256, range=(0, 255));

In [None]:
# Display the samples multiplied by 2**16.
a * 2 ** 16

In [None]:
# Ten smallest and ten largest distinct absolute sample values.
abs_values = sorted(a.abs().unique())
abs_values[:10], abs_values[-10:]

In [None]:
# Ten smallest and ten largest distinct absolute values of me(a), where `me`
# is whichever encoder was assigned last.
encoded_abs_values = sorted(me(a).abs().unique())
encoded_abs_values[:10], encoded_abs_values[-10:]

# Distributions

## Gaussian

In [None]:
# Evaluate gaussian_ll on a grid over [-5, 5] and check that its exponential
# integrates to approximately one.
x = torch.linspace(-5, 5, 200)
gauss_log_pdf = gaussian_ll(x, torch.zeros_like(x), torch.ones_like(x))
plt.plot(x, gauss_log_pdf.exp())
# Integrate against the actual grid: 200 points over [-5, 5] are spaced 10/199
# apart, so a hard-coded dx of 10/200 would underestimate the integral.
print(np.trapz(gauss_log_pdf.exp().numpy(), x=x.numpy()))

## Discretized Logistic

In [None]:
# Number of discretization bins used by the following cells.
NUM_BINS = 64

In [None]:
# Integer bin edges 0, 1, ..., NUM_BINS.
edges = torch.arange(NUM_BINS + 1)
print(edges)


In [None]:
# Shift every edge except the first down by 0.5; for unit-spaced edges this
# gives the bin midpoints.
centers = edges[1:] - 0.5
centers

In [None]:
# NUM_BINS + 1 equally spaced bin edges on [0, 1].
edges = torch.arange(NUM_BINS + 1) / NUM_BINS
edges

In [None]:
# Bin midpoints: each upper edge minus half of the bin width 1 / NUM_BINS.
half_bin = 0.5 / NUM_BINS
centers = edges[1:] - half_bin
centers

In [None]:
# NUM_BINS + 1 equally spaced bin edges on [-1, 1]: edges on [0, 1] stretched by 2
# and shifted by -1.
unit_edges = torch.arange(NUM_BINS + 1) / NUM_BINS
edges = unit_edges * 2 - 1
edges

In [None]:
# Bin midpoints: each upper edge minus half of the bin width 2 / NUM_BINS.
centers = edges[1:] - 1 / NUM_BINS
print(centers)

In [None]:
# Points lying 2 / NUM_BINS inside either end of [-1, 1].
2 / NUM_BINS - 1, 1 - 2 / NUM_BINS

In [None]:
# Points lying 0.5 / NUM_BINS inside either end of [-1, 1].
0.5 / NUM_BINS - 1, 1 - 0.5 / NUM_BINS

In [None]:
# Display the current value of `x` (whatever the most recently run cell assigned).
x

In [None]:
# Evaluation points: midpoints of NUM_BINS equal-width bins on [-1, 1].
edges = torch.arange(NUM_BINS + 1) / NUM_BINS * 2 - 1
centers = edges[1:] - 1 / NUM_BINS
x = centers
print(x)

# Constant parameters at every evaluation point: mean 0.8 and log(0.1) as log-scale.
mean = torch.full_like(x, 0.8)
log_scale = torch.full_like(x, 0.1).log()
discretized_logistic_log_pdf = discretized_logistic_ll(x, mean, log_scale, num_bins=NUM_BINS)
plt.plot(x, discretized_logistic_log_pdf.exp())
# Total probability mass over all bin centers.
print(discretized_logistic_log_pdf.exp().sum())

In [None]:
%timeit discretized_logistic_ll(x, mean, log_scale, num_bins=NUM_BINS)

## Discretized Logistic Mixture

In [None]:
# Number of discretization bins used by the following cells.
NUM_BINS = 64

In [None]:
# Display the current value of `x` (whatever the most recently run cell assigned).
x

In [None]:
# Evaluation points: midpoints of NUM_BINS equal-width bins on [-1, 1].
edges = torch.arange(0, NUM_BINS + 1) / NUM_BINS * 2 - 1
centers = edges[1:] - 0.5 * (2 / NUM_BINS)
x = centers

# Two-component mixture; per-component parameters are stacked along the last dimension.
# NOTE(review): both components share mean 0.6, and `logit_probs` holds values that
# look like probabilities rather than logits — confirm what the likelihood expects.
logit_probs = torch.tensor([0.3, 0.7])
mean = torch.stack([0.6 * torch.ones_like(x), 0.6 * torch.ones_like(x)], dim=-1)
log_scale = torch.stack([torch.log(0.1 * torch.ones_like(x)), torch.log(0.1 * torch.ones_like(x))], dim=-1)
discretized_logistic_mixture_log_pdf = discretized_logistic_mixture_ll(x, logit_probs, mean, log_scale, num_mix=2, num_bins=NUM_BINS)
plt.plot(x, discretized_logistic_mixture_log_pdf.exp())
# Total mass of the mixture just evaluated (not of the single-logistic result).
print(discretized_logistic_mixture_log_pdf.exp().sum())

In [None]:
%timeit discretized_logistic_mixture_ll(x, logit_probs, mean, log_scale, num_mix=2, num_bins=NUM_BINS)

## Discretized Laplace

In [None]:
# Number of discretization bins used by the following cells.
NUM_BINS = 64

In [None]:
# Evaluation points: midpoints of NUM_BINS equal-width bins on [-1, 1].
edges = torch.arange(NUM_BINS + 1) / NUM_BINS * 2 - 1
centers = edges[1:] - 1 / NUM_BINS
x = centers

# Constant parameters at every evaluation point: mean 0.8 and log(0.1) as log-scale.
mean = torch.full_like(x, 0.8)
log_scale = torch.full_like(x, 0.1).log()
# NOTE(review): the result keeps the historical name `discretized_logistic_log_pdf`
# although it holds the discretized Laplace log-likelihood.
discretized_logistic_log_pdf = discretized_laplace_ll(x, mean, log_scale, num_bins=NUM_BINS)
plt.plot(x, discretized_logistic_log_pdf.exp())
# Total probability mass over all bin centers.
print(discretized_logistic_log_pdf.exp().sum())

In [None]:
%timeit discretized_laplace_ll(x, mean, log_scale, num_bins=NUM_BINS)

## Discretized Mixture of Laplacians

In [None]:
# Number of discretization bins used by the following cells.
NUM_BINS = 64

In [None]:
# Evaluation points: midpoints of NUM_BINS equal-width bins on [-1, 1].
edges = torch.arange(0, NUM_BINS + 1) / NUM_BINS * 2 - 1
centers = edges[1:] - 0.5 * (2 / NUM_BINS)
x = centers

# Two-component mixture (means 0.2 and 0.6); per-component parameters are stacked
# along the last dimension.
# NOTE(review): `logit_probs` holds values that look like probabilities rather than
# logits — confirm what the likelihood expects.
logit_probs = torch.tensor([0.3, 0.7])
mean = torch.stack([0.2 * torch.ones_like(x), 0.6 * torch.ones_like(x)], dim=-1)
log_scale = torch.stack([torch.log(0.1 * torch.ones_like(x)), torch.log(0.1 * torch.ones_like(x))], dim=-1)
# The variable keeps its historical "logistic" name although it holds the Laplace mixture.
discretized_logistic_mixture_log_pdf = discretized_laplace_mixture_ll(x, logit_probs, mean, log_scale, num_mix=2, num_bins=NUM_BINS)
plt.plot(x, discretized_logistic_mixture_log_pdf.exp())
# Total mass of the mixture just evaluated (not of an earlier single-component result).
print(discretized_logistic_mixture_log_pdf.exp().sum())

In [None]:
# Display the shapes of the current x, mean, log_scale and logit_probs tensors.
x.shape, mean.shape, log_scale.shape, logit_probs.shape

In [None]:
%timeit discretized_laplace_mixture_ll(x, logit_probs, mean, log_scale, num_bins=NUM_BINS, num_mix=2)