In [None]:
from a import *

In [None]:
from IPython.display import Audio

In [None]:
def playHard(data):
    """Wrap `data` in an IPython `Audio` widget at the sample rate `SR`.

    The samples are passed through untouched: no fade and no padding are
    applied here.
    """
    widget = Audio(data, rate=SR)
    return widget
def play(data, soft = .1):
    """Play `data` with a linear fade-in and fade-out to avoid clicks.

    Parameters
    ----------
    data : 1-D sequence of samples
        The signal to play. It is copied, never modified in place.
    soft : float, default 0.1
        Length of each fade in seconds. The fade is clamped to the clip
        length, and a value that rounds to zero samples (or is negative)
        disables fading instead of raising.

    Returns
    -------
    Whatever `playHard` returns for the faded copy.
    """
    # One extra sample of value 1 is appended, exactly as before.
    # NOTE(review): presumably a reference level for the player's
    # normalisation - confirm. With a fade of two or more samples the
    # fade-out ramp ends at 0 and silences it again.
    # Using 1.0 (not 1) forces a float array, so integer input is not
    # truncated when the fractional ramp is multiplied in.
    t = np.concatenate([data, [1.0]])

    # `t[-0:]` would select the WHOLE array and a fade longer than the clip
    # would not match the slice shape, so clamp to [0, t.size].
    length = max(0, min(round(soft * SR), t.size))
    if length > 0:
        t[:length ] *= np.linspace(0, 1, length)
        t[-length:] *= np.linspace(1, 0, length)
    return playHard(t)


In [None]:
# Every recording loaded by this cell is collected in `raw`.
raw = []

# Load the recording with sr=SR and check the returned rate matches.
# `y` is left bound to these samples; the STFT and timbre-extraction cells
# below read `y` directly rather than `raw`.
y, sr = librosa.load('dan.wav', sr=SR)
assert sr == SR
raw.append(y)

# Disabled: a second recording. Enabling it would rebind `y` to that file.
# y, sr = librosa.load('yanhe.wav', sr=SR)
# assert sr == SR
# raw.append(y)


In [None]:
print('stft...')
# STFT of `y` with segments of PAGE_LEN samples; every other option of
# `stft` is left at its default. `times` is reused by the timbre-extraction
# cell to pair frames with pages.
freqs, times, Zxx = stft(
    y, fs=SR, nperseg=PAGE_LEN, 
)
# Magnitude of the complex STFT coefficients.
spectrogram = np.abs(Zxx)


In [None]:
# Per-page analysis results, one entry per analysed page:
#   f0s     - fundamental frequency estimate returned by `yin`
#   amps    - square root of the page's total squared spectrum
#   timbres - exactly N_HARMONICS `Harmonic(freq, mag)` entries
f0s = []
amps = []
timbres: List[List[Harmonic]] = []

# Number of FFT bins summed on each side of a harmonic's centre bin.
half_width = 2

for page_i, (t, page) in tqdm(
    [*enumerate(zip(times, pagesOf(y)))],
    desc='extract timbre',
):
    # The spectrum is recomputed from the Hann-windowed page rather than
    # taken from the STFT column (`spectrogram[:, page_i]`).
    spectrum = np.abs(rfft(page * HANN)) / PAGE_LEN
    spectrum_2 = np.square(spectrum)
    n_bins = len(spectrum_2)

    f0 = yin(
        page, SR, PAGE_LEN,
        fmin=pitch2freq(36),
        fmax=pitch2freq(84),
    )
    harmonics_f = np.arange(f0, NYQUIST, f0)
    assert harmonics_f.size < N_HARMONICS

    # Energy of each harmonic = sum of the squared spectrum in a window of
    # 2 * half_width + 1 bins around the harmonic's nearest bin.
    harmonics_a_2 = np.zeros((harmonics_f.size, ))
    # Marks every bin that belongs to at least one harmonic window, so that
    # overlapping windows are not counted twice in the noise estimate.
    is_harmonic_bin = np.zeros(n_bins, dtype=bool)
    for partial_i, freq in enumerate(harmonics_f):
        mid_f_bin = int(round(freq * PAGE_LEN / SR))
        # Clamp explicitly at both edges: a negative index would not raise
        # IndexError, NumPy would silently read from the top of the
        # spectrum instead.
        lo = max(mid_f_bin - half_width, 0)
        hi = min(mid_f_bin + half_width + 1, n_bins)
        if lo < hi:
            harmonics_a_2[partial_i] = spectrum_2[lo:hi].sum()
            is_harmonic_bin[lo:hi] = True

    # Noise floor = mean squared magnitude of the bins outside every
    # harmonic window. If no such bin exists, no harmonic is gated.
    noise_bins = spectrum_2[~is_harmonic_bin]
    mean_bin_noise = noise_bins.mean() if noise_bins.size else 0.0
    # Harmonics weaker than twice the noise floor are treated as absent.
    harmonics_a_2[harmonics_a_2 < 2 * mean_bin_noise] = 0
    harmonics_a = np.sqrt(harmonics_a_2)

    harmonics = [
        Harmonic(f, a) for (f, a) in zip(
            harmonics_f,
            harmonics_a,
        )
    ]
    # Pad with zero-magnitude harmonics, continuing the f0 spacing, so every
    # page carries exactly N_HARMONICS entries.
    freq = harmonics_f[-1]
    for _ in range(len(harmonics), N_HARMONICS):
        freq += f0
        harmonics.append(Harmonic(freq, 0))
    f0s.append(f0)
    timbres.append(harmonics)
    amps.append(np.sqrt(spectrum_2.sum()))


In [None]:
# Width of the per-page vowel embedding: it sizes `vowel_embs` in Train()
# and is passed as the third argument to NITF.
n_vowel_dims = 2


In [None]:
class MyDataset(Dataset):
    """Per-harmonic regression samples built from the analysed pages.

    Each harmonic of each page becomes one sample:

    * features - (harmonic frequency, page f0, page amplitude)
    * target   - the harmonic's magnitude
    * page id  - index of the page the harmonic came from

    Features and targets are standardised. The statistics are kept on the
    instance as `X_mean`, `X_std`, `Y_mean` and `Y_std`.
    """

    def __init__(self) -> None:
        """Build the tensors from the module-level `f0s`, `timbres`, `amps`."""
        super().__init__()

        page_ids = []
        targets = []
        feature_blocks = []
        pages = [*enumerate(zip(f0s, timbres, amps))]
        for page_i, (f0, harmonics, amp) in tqdm(pages, desc='prep data'):
            rows = []
            for harmonic in harmonics:
                rows.append(torch.tensor((harmonic.freq, f0, amp)))
                targets.append(harmonic.mag)
                page_ids.append(page_i)
            feature_blocks.append(torch.stack(rows))

        features = torch.concat(feature_blocks, dim=0).float()
        magnitudes = torch.tensor(targets).float()

        # Centre first, then scale by the std of the centred values.
        self.X_mean = features.mean(dim=0)
        features = features - self.X_mean
        self.X_std = features.std(dim=0)
        self.X = features / self.X_std

        self.Y_mean = magnitudes.mean(dim=0)
        magnitudes = magnitudes - self.Y_mean
        self.Y_std = magnitudes.std(dim=0)
        self.Y = magnitudes / self.Y_std

        self.I = torch.tensor(page_ids, dtype=torch.long)

    def transformX(self, x):
        """Standardise raw features `x` with the training-set statistics."""
        centred = x - self.X_mean
        return centred / self.X_std

    def __len__(self):
        """Number of samples, i.e. rows of `X`."""
        return self.X.size(0)

    def __getitem__(self, index):
        """Return `(features, target, page id)` for `index`."""
        features = self.X[index, :]
        target = self.Y[index]
        page_id = self.I[index]
        return features, target, page_id

# Module-level dataset instance; Train() and the evaluation/synthesis cells
# read it by this name.
dataset = MyDataset()

In [None]:
def Train(nitf: NITF, batch_size):
    """Endless training generator: each `next()` runs one full epoch.

    A zero-initialised embedding row per page (`len(f0s)` rows of
    `n_vowel_dims` columns) is optimised jointly with the network by Adam at
    learning rate `LR`. For every batch the rows belonging to the samples'
    pages are appended to the features before the forward pass, and the
    first output column is fitted to the target with MSE.

    Parameters
    ----------
    nitf : NITF
        The network to train (updated in place).
    batch_size : int
        Batch size for the shuffled `DataLoader` over `dataset`.

    Yields
    ------
    tuple
        `(nitf, vowel_embs, mean_loss)` after every epoch, where `mean_loss`
        is the mean of that epoch's batch losses.
    """
    loader = DataLoader(dataset, batch_size, shuffle=True)

    vowel_embs = torch.zeros(
        (len(f0s), n_vowel_dims), requires_grad=True,
    )
    trainable = [*nitf.parameters(), vowel_embs]
    optim = torch.optim.Adam(trainable, LR)

    while True:
        nitf.train()
        epoch_losses = []
        for features, target, page_i in loader:
            # Append each sample's page embedding to its features.
            net_input = torch.concat((features, vowel_embs[page_i]), dim=1)
            prediction = nitf.forward(net_input)
            loss = F.mse_loss(prediction[:, 0], target)

            optim.zero_grad()
            loss.backward()
            optim.step()

            epoch_losses.append(loss.detach())
        yield nitf, vowel_embs, torch.tensor(epoch_losses).mean()


In [None]:
# Each entry pairs a Train() generator with the list that collects its
# per-epoch mean losses.
trainers = [
    (Train(NITF(128, 6, n_vowel_dims), batch_size=2 ** 12), []), 
]

# Trains until the kernel is interrupted: every pass of the outer loop
# advances each trainer by one epoch and prints its loss.
# NOTE(review): `count` is presumably itertools.count (unbounded) - confirm.
try:
    for epoch in count():
        print(f'{epoch = }', end=', ')
        for trainer, losses in trainers:
            # `nitf` and `vowel_embs` are rebound at module level here; the
            # cells below use whatever the last completed epoch left behind.
            nitf, vowel_embs, loss = next(trainer)
            losses.append(loss)
            print(loss.item(), end=', ')
        print()
except KeyboardInterrupt:
    # An interrupt is the only exit from the loop above; swallow it so the
    # cell ends without a traceback.
    pass

In [None]:
# One loss curve per trainer, one point per completed epoch.
for trainer, losses in trainers:
    plt.plot(losses)
plt.show()

In [None]:
# Per-dimension mean (`vm`) and standard deviation (`vs`) of the learned
# vowel embeddings, taken over all pages.
v = vowel_embs.detach()
vm = v.mean(dim=0)
vs = v.std(dim=0)

In [None]:
# Distribution of the per-page amplitudes collected during timbre extraction.
plt.hist(amps)
plt.show()

In [None]:
# Distribution of the per-page f0 estimates collected during timbre extraction.
plt.hist(f0s)
plt.show()

In [None]:
import random

In [None]:
# Spot check: push eight randomly chosen samples through the trained network
# and print the exact input vector each one was given.
with torch.no_grad():
    nitf.eval()

    for _ in range(8):
        # randrange excludes its upper bound. randint(0, len(dataset)) could
        # return len(dataset) itself and index one past the last sample.
        x, y, page_i = dataset[random.randrange(len(dataset))]
        # Same layout as in training: features followed by the embedding of
        # the page the sample came from.
        x_vowel = torch.concat((
            x, vowel_embs[page_i, :],
        ), dim=0)
        mag = nitf.forward(x_vowel)
        # print(mag.item())
        # print(  y.item())
        # print()
        print(x_vowel)

In [None]:
# Resynthesis demo: sweep the vowel embedding (mean -/+ 2 std per dimension),
# f0 (220 -> 880) and amplitude (.01 -> .04) linearly over the pages, ask the
# network for every harmonic's magnitude and render the result to `audio`.
with torch.no_grad():
    nitf.eval()

    hS = HarmonicSynth(
        N_HARMONICS, SR, PAGE_LEN, DTYPE, True, True,
    )
    buffer = []

    # The network was fitted to standardised targets (MyDataset stores
    # (Y - Y_mean) / Y_std), so its output has to be mapped back to a
    # magnitude before use. Without this, every below-average harmonic would
    # be clamped to zero and the overall scale would be wrong.
    y_mean = dataset.Y_mean.item()
    y_std = dataset.Y_std.item()

    n_pages = 2 * SR // PAGE_LEN
    # The loop variable is `vowel`, not `v`, so the module-level `v` (the
    # detached embedding table) is not clobbered by running this cell.
    for vowel, f0, amp in zip(
        np.linspace(vm - 2 * vs, vm + 2 * vs, n_pages),
        np.linspace(220, 880, n_pages),
        np.linspace(.01, .04, n_pages),
    ):
        vowel_t = torch.tensor(vowel)
        harmonics = []
        for partial_i in range(N_HARMONICS):
            freq = f0 * (partial_i + 1)
            # Standardise the features exactly as the training data was.
            x = dataset.transformX(torch.tensor([
                freq, f0, amp,
            ]).unsqueeze(0))
            x_vowel = torch.concat((
                x[0, :], vowel_t,
            ))
            # print(x_vowel)
            mag = nitf.forward(x_vowel.float()).item()
            # Undo the target standardisation, then forbid negative
            # magnitudes.
            mag = mag * y_std + y_mean
            mag = max(0, mag)
            harmonics.append(Harmonic(freq, mag))
        hS.eat(harmonics)
        buffer.append(hS.mix())

    audio = np.concatenate(buffer)

In [None]:
# Listen to the synthesised sweep, with the default fade applied at both ends.
play(audio)