In [None]:
import os
import sys

# Make the parent directory importable; the membership check keeps re-runs
# of this cell from appending the same entry again.
if os.path.abspath(os.pardir) not in sys.path:
    sys.path.append(os.path.abspath(os.pardir))

In [None]:
import IPython.display

def play_audio(audio, sample_rate=22100):
    """Wrap a raw waveform in an autoplaying Jupyter audio widget.

    Parameters
    ----------
    audio : torch.Tensor or numpy.ndarray, shape == (1, t)
        Raw waveform, e.g. the output of a `Vocoder`.
    sample_rate : int
        Playback rate in Hz. NOTE(review): 22100 is an unusual default
        (22050 Hz is the common rate) -- confirm it is intended.

    Returns
    -------
    IPython.display.Audio
        Notebook widget created with autoplay enabled.
    """
    widget = IPython.display.Audio(audio, rate=sample_rate, autoplay=True)
    return widget

In [None]:
import torchaudio
import wandb
from speech_distances.models import load_model
from ss_models.synthesis_utils import make_preprocessor_trainable
from torch import nn
import torch

In [None]:
# Reference clip: the reconstruction target for every experiment below.
wav, sample_rate = torchaudio.load('../../audios_val_14780_1.wav')
# Optimisation steps per experiment run.
num_iters = 1_000

In [None]:
def calc_l2(signal_a, signal_b):
    """Root-mean-square difference between two (broadcast-compatible) tensors.

    Despite the name, this is the L2 norm of the difference scaled by
    1/sqrt(numel), i.e. an RMS, so it does not grow with tensor size.
    Returns a 0-d tensor.
    """
    return torch.sqrt(torch.mean((signal_a - signal_b) ** 2))


class LpipsL2(nn.Module):
    """LPIPS-style distance computed from the blocks of a pretrained encoder.

    Both signals are pushed through every block of `model.encoder.encoder`.
    After each block, the RMS difference (`calc_l2`) between the two block
    outputs is recorded and optionally down-weighted for deeper blocks. The
    mean over blocks is returned.

    model:
        Must expose an iterable `model.encoder.encoder` whose elements are called
        as `layer((signals, lengths)) -> (signals, lengths)`, where `signals` is
        a list whose last element is the current activation.
        NOTE(review): this matches Jasper/QuartzNet-style blocks presumably;
        confirm for every architecture passed in.
    decay: float or None
        If given, the distance after block i (0-based) is multiplied by
        (i + 1) ** (-1 / decay), so deeper blocks contribute less.
    """

    def __init__(self, model, decay=None):
        super().__init__()
        self.features = model.encoder.encoder
        self.decay = decay

    def forward(self, signal_left, signal_left_len, signal_right, signal_right_len):
        """Return the mean (optionally decayed) per-block RMS distance as a 0-d tensor."""
        # Blocks consume and return a list of tensors; start from a one-element list.
        signal_left = [signal_left]
        signal_right = [signal_right]
        dists = []
        for i, layer in enumerate(self.features):
            signal_left, signal_left_len = layer((signal_left, signal_left_len))
            signal_right, signal_right_len = layer((signal_right, signal_right_len))
            dist = calc_l2(signal_left[-1], signal_right[-1])
            # Use the configured decay from the instance, and scale out-of-place:
            # sqrt saves its output for backward, so an in-place multiply on it
            # would make loss.backward() fail.
            if self.decay is not None:
                dist = dist * (1.0 / (i + 1) ** (1.0 / self.decay))
            dists.append(dist)
        return torch.stack(dists).mean()

In [None]:
def train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_fn=lambda signal_left, seq_len_left, signal_right, seq_len_right: calc_l2(signal_left, signal_right)):
    """Optimise `audio_right` so that its features match those of `audio_left`.

    Each iteration runs both waveforms through `stt.preprocessor.get_features`,
    scores them with `loss_fn`, back-propagates into `audio_right`, clips the
    gradient and takes one optimizer step. Every 10 iterations it logs the loss,
    both feature maps and both waveforms to the active wandb run.

    num_iters: int
        Number of optimisation steps.
    stt:
        Model whose `preprocessor.get_features(audio, lengths)` returns
        `(features, feature_lengths)`.
    audio_left: torch.Tensor
        Target waveform. It is indexed `[0]` for logging, so presumably it has
        shape (1, t); confirm.
    audio_right: torch.Tensor
        Waveform being optimised. It must require grad and be registered in
        `optimizer`.
    optimizer: torch.optim.Optimizer
    loss_fn: callable(signal_left, seq_len_left, signal_right, seq_len_right)
        Must return a scalar tensor. Defaults to the RMS difference of the
        two feature tensors.

    Reads the notebook-global `sample_rate` for audio logging.
    """
    for i in range(num_iters):
        optimizer.zero_grad()
        # Length tensors follow the inputs' device instead of a hard-coded .cuda().
        len_left = torch.tensor([audio_left.shape[-1]], device=audio_left.device)
        len_right = torch.tensor([audio_right.shape[-1]], device=audio_right.device)
        signal_left, seq_len_left = stt.preprocessor.get_features(audio_left, len_left)
        # `+ 0.` passes a non-leaf copy of the optimised leaf to the preprocessor.
        # NOTE(review): presumably this guards against in-place ops on the leaf; confirm.
        signal_right, seq_len_right = stt.preprocessor.get_features(audio_right + 0., len_right)
        # Keep each side's own feature lengths (previously the right lengths were passed twice).
        loss = loss_fn(signal_left, seq_len_left, signal_right, seq_len_right)
        loss.backward()
        # Clip BEFORE stepping. Clipping after step() had no effect, because the
        # gradients are zeroed at the start of the next iteration.
        torch.nn.utils.clip_grad_value_([audio_right], 5)
        optimizer.step()
        if i % 10 == 0:
            wandb.log({
                "loss": loss.item(),
                "mels": [
                    wandb.Image(signal_right.detach().cpu().numpy(), caption='predicted_mel'),
                    wandb.Image(signal_left.detach().cpu().numpy(), caption='target_mel'),
                ],
            }, step=i)
            wandb.log({
                "audios": [
                    wandb.Audio(audio_right.detach().cpu()[0], caption='reconstructed_wav', sample_rate=sample_rate),
                    wandb.Audio(audio_left.detach().cpu()[0], caption='target_wav', sample_rate=sample_rate),
                ],
            }, step=i)

In [None]:
from scipy.ndimage import gaussian_filter1d

def make_train_data(waveform=None, sigma=6, lr=1e-3, device='cuda'):
    """Build the (target, blurred starting point, optimizer) triple for one run.

    waveform: torch.Tensor or None
        Clip to reconstruct. Defaults to the notebook-global `wav`. The blur is
        applied along the last axis.
    sigma: float
        Standard deviation, in samples, of the Gaussian blur that produces the
        degraded starting point.
    lr: float
        Adam learning rate for the optimised waveform.
    device: str or torch.device
        Device on which both returned tensors are placed.

    Returns (audio_left, audio_right, optimizer):
        - audio_left: a copy of the clip.
        - audio_right: a blurred leaf tensor with requires_grad=True.
        - optimizer: an Adam optimizer over audio_right.
    """
    if waveform is None:
        waveform = wav  # notebook-global reference clip
    audio_left = waveform.clone().to(device)
    blurred = gaussian_filter1d(waveform.detach().cpu().numpy(), sigma)
    # from_numpy gives a fresh tensor, so no extra clone is needed before
    # turning it into an optimisable leaf.
    audio_right = torch.from_numpy(blurred).to(device).requires_grad_()

    optimizer = torch.optim.Adam([audio_right], lr=lr)
    return audio_left, audio_right, optimizer

In [None]:
# Baseline: plain RMS loss on QuartzNet preprocessor features (train_loop's default loss_fn).
model_name = 'quartznet'
# A distinct run name/config keeps the experiments apart in the dashboard.
wandb.init(project='lpips_audio', name=f'{model_name}-l2',
           config={'model': model_name, 'loss': 'l2', 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
# LPIPS-style loss over QuartzNet encoder blocks, no depth decay.
model_name = 'quartznet'
wandb.init(project='lpips_audio', name=f'{model_name}-lpips',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': None, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
# NOTE(review): if load_model returns the model in training mode, dropout/BatchNorm
# stay active while features are computed; consider stt.eval(). Confirm.
loss_l2_pips = LpipsL2(stt)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
# LPIPS-style loss over QuartzNet blocks; deeper blocks are weighted by (i + 1) ** (-1 / 2).
model_name = 'quartznet'
decay = 2
wandb.init(project='lpips_audio', name=f'{model_name}-lpips-decay{decay}',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': decay, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
loss_l2_pips = LpipsL2(stt, decay=decay)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
# LPIPS-style loss using the speaker-verification SpeakerNet encoder.
model_name = 'speakerverification_speakernet'
wandb.init(project='lpips_audio', name=f'{model_name}-lpips',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': None, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
loss_l2_pips = LpipsL2(stt)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
# LPIPS-style loss using the speaker-recognition SpeakerNet encoder.
model_name = 'speakerrecognition_speakernet'
wandb.init(project='lpips_audio', name=f'{model_name}-lpips',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': None, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
loss_l2_pips = LpipsL2(stt)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
# LPIPS-style loss using the Jasper encoder.
model_name = 'jasper'
wandb.init(project='lpips_audio', name=f'{model_name}-lpips',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': None, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
loss_l2_pips = LpipsL2(stt)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
# LPIPS-style loss using the German QuartzNet encoder.
model_name = 'quartznet_de'
wandb.init(project='lpips_audio', name=f'{model_name}-lpips',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': None, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
loss_l2_pips = LpipsL2(stt)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
## TODO: wav2vec

In [None]:
# LPIPS-style loss using the wav2vec2 model.
model_name = 'wav2vec2'
wandb.init(project='lpips_audio', name=f'{model_name}-lpips',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': None, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
# NOTE(review): LpipsL2 iterates `stt.encoder.encoder` and calls each block as
# `layer((signals, lengths))`; confirm the wav2vec2 model from load_model exposes that.
loss_l2_pips = LpipsL2(stt)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so the next experiment starts a fresh run at step 0.
wandb.finish()

In [None]:
# LPIPS-style loss using the convolutional wav2vec2 variant.
model_name = 'wav2vec2_conv'
wandb.init(project='lpips_audio', name=f'{model_name}-lpips',
           config={'model': model_name, 'loss': 'lpips_l2', 'decay': None, 'num_iters': num_iters})

stt = load_model(model_name).cuda()
stt = make_preprocessor_trainable(stt).cuda()
# NOTE(review): LpipsL2 iterates `stt.encoder.encoder` and calls each block as
# `layer((signals, lengths))`; confirm the wav2vec2_conv model exposes that.
loss_l2_pips = LpipsL2(stt)
audio_left, audio_right, optimizer = make_train_data()
train_loop(num_iters, stt, audio_left, audio_right, optimizer, loss_l2_pips)
# Close the run so any later experiment starts a fresh run at step 0.
wandb.finish()