In [None]:
from typing import *
import os

import librosa
import numpy as np
from numpy.fft import rfft
from numpy import pi
from matplotlib import pyplot as plt
from IPython.display import Audio
import cmath
import scipy
from scipy.signal import stft
import torch
from torch import nn
import torch.nn.functional as F

from cache_no_hash import cache
from blindDescend import blindDescend
from yin import yin
from harmonicSynth import HarmonicSynth, Harmonic

# Radians in one full cycle (2 * pi).
TWO_PI = 2 * np.pi

In [None]:
# `shared`, `lobe` and `manual_fc` are imported with the parent directory as
# the working directory (presumably that is where they live -- confirm).
# Remember where we started and always go back there, even when an import
# raises; otherwise the kernel is left one level up and re-running this cell
# would climb yet another level.
_notebook_dir = os.getcwd()
os.chdir('..')
try:
    from shared import *
    from lobe import getLobe
    from manual_fc import ManualFC
finally:
    os.chdir(_notebook_dir)

In [None]:
def sino(freq, length):
    """Return `length` samples of sin(2*pi*freq*n / SR) for n = 0 .. length-1.

    With `SR` being the sample rate (it is passed as `rate` to `Audio`
    elsewhere in this notebook), this is a sine wave of `freq` Hz that starts
    at phase zero.
    """
    phase = np.arange(length) * freq * TWO_PI / SR
    return np.sin(phase)

def playHard(data):
    """Wrap `data`, unmodified, in an IPython `Audio` player at rate `SR`."""
    player = Audio(data, rate = SR)
    return player
def play(data, soft = .1):
    """Play `data` with a linear fade-in and fade-out to avoid clicks.

    Parameters
    ----------
    data : 1-D array-like of samples.
    soft : duration of each fade in seconds (converted to samples via `SR`).
        `soft = 0` (or anything that rounds to zero samples) disables the
        fades and plays `data` as-is.

    Returns
    -------
    The `Audio` player produced by `playHard`.
    """
    length = round(soft * SR)
    if length <= 0:
        # `t[-0:]` would select the WHOLE array and the multiply by an empty
        # ramp would raise, so a zero-length fade has to be special-cased.
        return playHard(data)
    # Work on a copy with one extra trailing sample of value 1 appended (kept
    # from the original implementation; the fade-out below scales it down).
    t = np.concatenate([data, [1]])
    # A ramp longer than the clip cannot be applied; clamp it to the clip.
    length = min(length, len(t))
    t[:length ] = np.multiply(t[:length ], np.linspace(0, 1, length))
    t[-length:] = np.multiply(t[-length:], np.linspace(1, 0, length))
    return playHard(t)

def findPeaks(energy):
    """Return the indices of the `N_HARMONICS` strongest local maxima.

    A local maximum is an interior sample whose left slope is strictly
    positive and whose right slope is strictly negative. Non-peak samples are
    scored 0, so they only get selected when there are fewer than
    `N_HARMONICS` real peaks. The returned indices are in no particular order
    (that is what `argpartition` guarantees).
    """
    slope_sign = np.sign(energy[1:] - energy[:-1])
    # +1 followed by -1 gives -2: the slope flips from rising to falling.
    turning = slope_sign[1:] - slope_sign[:-1]
    is_peak = turning == -2
    peak_energy = is_peak * energy[1:-1]
    strongest = np.argpartition(peak_energy, - N_HARMONICS)[- N_HARMONICS:]
    # `peak_energy` starts at sample 1 of `energy`, hence the shift.
    return strongest + 1

def sft(signal, freq_bin):
    """Slow Fourier Transform: magnitude of `signal` at one (fractional) bin.

    Multiplies `signal` by exp(IMAGINARY_LADDER * freq_bin), sums, takes the
    absolute value and divides by `PAGE_LEN`.
    NOTE(review): `IMAGINARY_LADDER` is defined elsewhere; presumably it is
    the per-sample complex phase step of a `PAGE_LEN`-point DFT -- confirm.
    """
    probe = np.exp(IMAGINARY_LADDER * freq_bin)
    projection = np.sum(signal * probe)
    return np.abs(projection) / PAGE_LEN

def refineGuess(guess, signal):
    """Refine a frequency-bin guess by maximising `sft(signal, bin)`.

    `blindDescend` is given the negated magnitude as the objective, starting
    from `guess`; negative bins are scored 0 so they never beat a real peak.
    NOTE(review): the meaning of the `.01` and `.4` arguments depends on
    `blindDescend`, which is not visible here.

    Returns
    -------
    (frequency, magnitude): the refined bin scaled by `SR / PAGE_LEN`, and the
    `sft` magnitude found there.
    """
    def negMagnitude(x):
        return 0 if x < 0 else - sft(signal, x)
    best_bin, best_value = blindDescend(negMagnitude, .01, .4, guess)
    return best_bin * SR / PAGE_LEN, - best_value

def widePlot(h = 3, w = 12):
    """Resize the current matplotlib figure to `w` x `h` inches."""
    figure = plt.gcf()
    figure.set_size_inches(w, h)
    
def spectrum(signal, do_wide = True, trim = 130):
    """Plot the magnitude spectrum of one `HANN`-windowed page.

    Parameters
    ----------
    signal : samples; must be multipliable element-wise with `HANN`.
    do_wide : when true, resize the figure with `widePlot()`.
    trim : number of lowest bins to show.
    """
    magnitude = np.abs(rfft(signal * HANN))
    # rfft bins are spread evenly from 0 to SR / 2.
    hz = np.linspace(0, SR / 2, len(magnitude))
    shown = slice(None, trim)
    plt.plot(hz[shown], magnitude[shown])
    plt.xlabel('freq (Hz)')
    if do_wide:
        widePlot()

def spectrogram(signal, **kw):
    """Draw the STFT magnitude of `signal` as a time/frequency colour mesh.

    Extra keyword arguments are forwarded to `scipy.signal.stft`.
    """
    freq_axis, time_axis, Zxx = stft(signal, fs=SR, **kw)
    magnitude = np.abs(Zxx)
    plt.pcolormesh(time_axis, freq_axis, magnitude)

def concatSynth(synth, harmonics, n):
    """Render `n` consecutive pages from `synth` and join them end to end.

    For every page, `synth.eat(harmonics)` is called first and the output of
    `synth.mix()` is collected; the collected pages are concatenated.
    """
    def renderPage():
        synth.eat(harmonics)
        return synth.mix()
    return np.concatenate([renderPage() for _ in range(n)])

def pitch2freq(pitch):
    """Convert a pitch number (in semitones) to a frequency.

    Inverse of `freq2pitch`. The scale constant equals ln(2) / 12 and the
    offset places pitch 69 at 440 (the MIDI convention).
    """
    offset = 36.37631656229591
    semitone_log = 0.0577622650466621
    return np.exp((pitch + offset) * semitone_log)

def freq2pitch(f):
    """Convert a frequency to a pitch number (in semitones).

    Inverse of `pitch2freq`. The scale constant equals 12 / ln(2) and the
    offset places 440 at pitch 69 (the MIDI convention).
    """
    semitones_per_log = 17.312340490667562
    offset = 36.37631656229591
    return np.log(f) * semitones_per_log - offset

def pagesOf(signal):
    """Yield consecutive, non-overlapping `PAGE_LEN`-sample slices of `signal`.

    Only whole pages are produced; a trailing remainder shorter than
    `PAGE_LEN` is dropped.
    """
    n_whole_pages = signal.size // PAGE_LEN
    for page_index in range(n_whole_pages):
        start = page_index * PAGE_LEN
        yield signal[start : start + PAGE_LEN]

def plotUnstretchedPartials(f0, n_partials = 14, color = 'r', alpha = .3):
    """Mark the exact integer multiples of `f0` with vertical lines.

    Draws one line at x = f0 * k for k = 1 .. n_partials on the current axes.
    """
    for harmonic_number in range(1, n_partials + 1):
        plt.axvline(x = f0 * harmonic_number, color = color, alpha = alpha)


In [None]:
# Hand-picked breakpoints of the timbre curve, as (x, y) pairs. The curve is
# later evaluated at frequencies and its value used as a magnitude, so these
# are presumably (frequency, magnitude) pairs.
timbre_points = [
    ( 700, .2 ),
    ( 950, .05),
    (2400, .05),
    (2800, 0  ),
    (3000, 0  ),
]
_timbre_xs, _timbre_ys = zip(*timbre_points)
timbre = ManualFC(torch.tensor(_timbre_xs), torch.tensor(_timbre_ys))
# Preview the curve from 0 up to NYQUIST.
X = torch.linspace(0, NYQUIST, 1000)
plt.plot(X, timbre(X))

In [None]:
# Build a reference tone: N_HARMONICS sine partials at integer multiples of
# f0, each scaled by the timbre curve, SR samples long.
f0 = 300
lobe = getLobe()
_partials = []
for i in range(N_HARMONICS):
    freq = f0 * (i + 1)
    mag = timbre(torch.tensor(freq)).numpy()
    _partials.append(sino(freq, SR) * mag)
y_long = np.sum(np.stack(_partials), axis=0)
# Analyse only the first page; the long version is kept for listening.
y = y_long[:PAGE_LEN]
# HANN-windowed magnitude spectrum, divided by PAGE_LEN / 2, as a float tensor.
energy = torch.tensor(
    np.abs(rfft(y * HANN)) / (PAGE_LEN / 2)
).float()
plt.plot(energy)
play(y_long)

In [None]:
# SPECTRUM_SIZE evenly spaced points from 0 to NYQUIST; since the grid starts
# at 0, its second entry is the spacing between neighbouring points.
freqs = np.linspace(0, NYQUIST, SPECTRUM_SIZE)
freq_bin: float = freqs[1]
# Reciprocal kept as a float32 tensor so it can multiply torch frequencies.
one_over_freq_bin = torch.tensor(1 / freq_bin, dtype=torch.float32)

In [None]:
# Bin indices 0 .. SPECTRUM_SIZE-1, repeated on N_HARMONICS rows:
# shape (N_HARMONICS, SPECTRUM_SIZE).
freqCube = (
    torch.arange(0, SPECTRUM_SIZE)
    .float()
    .unsqueeze(0)
    .repeat(N_HARMONICS, 1)
)
freqCube.shape

In [None]:
# Harmonic numbers 1, 2, ..., N_HARMONICS as a float tensor.
LADDER = torch.arange(1, N_HARMONICS + 1).float().contiguous()

In [None]:
def forward(f0):
    """Model the magnitude spectrum of a harmonic tone with fundamental `f0`.

    Each of the `LADDER` partials contributes `lobe(...)` evaluated at the
    distance (in bins) between every spectrum bin and the partial's position,
    scaled by the `timbre` value at the partial's frequency. The partials are
    summed, giving one value per spectrum bin.
    """
    partial_freqs = f0 * LADDER
    partial_mags = timbre(partial_freqs)
    # Column vector of partial positions in bins, broadcast against freqCube.
    partial_bins = (partial_freqs * one_over_freq_bin).unsqueeze(1)
    lobes = lobe(freqCube - partial_bins)
    weighted = lobes * partial_mags.unsqueeze(1)
    return weighted.sum(dim=0)

# Overlay the modelled spectrum (thick) on the analysed one, at the true f0.
# `.detach()` matches the later comparison plot and keeps this working if the
# model output ever carries gradients; the labels let the figure stand alone.
plt.plot(forward(f0).detach(), linewidth=4, label='forward(f0)')
plt.plot(energy, label='energy')
plt.xlabel('spectrum bin')
plt.legend();

In [None]:
from itertools import count
from time import sleep

In [None]:
# Fit a single scalar f0 by gradient descent so that forward(f0) matches the
# analysed spectrum `energy`, starting from 160.
# The loop is bounded: an endless `count()` loop can only be stopped by
# interrupting the kernel, which breaks "Restart & Run All" and means the
# cells below never run on their own.
MAX_EPOCHS = 100
EPOCH_DELAY = .5  # seconds between steps, so the printed trace can be followed
latent_f0 = torch.tensor(160.0, requires_grad=True)
optim = torch.optim.Adam([latent_f0], lr=1)
for epoch in range(MAX_EPOCHS):
    spec_hat = forward(latent_f0)
    loss = F.mse_loss(spec_hat, energy)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(epoch, loss.item(), latent_f0.item())
    sleep(EPOCH_DELAY)

In [None]:
# Overlay the spectrum at the fitted f0 (thick) on the analysed one; the
# labels let the figure stand alone.
plt.plot(forward(latent_f0).detach(), linewidth=4, label='forward(latent_f0)')
plt.plot(energy, label='energy')
plt.xlabel('spectrum bin')
plt.legend();

Damn, optimizing a single scalar `f0` is not enough here: we need a multi-hot octave!