# Imports

- Select the `voice_isolate_env` kernel using the kernel picker at the top right of the screen.

In [1]:
from data import CleanDataset, DataTransformer, NoiseGenerator
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import torch.nn.functional as F
import torch
from pathlib import Path
from autoencoder import AttnParams, CustomVAE

# Folder (relative to the notebook's working directory) that receives the audio and
# spectrogram files written by the cells below.
output_folder = Path('../playground_outputs')
output_folder.mkdir(parents=True, exist_ok=True)

# Checkpoint to evaluate. Set the VOICE_ISOLATION_MODEL_PATH environment variable to
# use a checkpoint stored elsewhere; the fallback is the original local path.
model_path = Path(os.environ.get(
    'VOICE_ISOLATION_MODEL_PATH',
    '/Users/marcusbluestone/Desktop/MIT/Fall (25-26)/Voice_Isolation/outputs/model.pth',
))

# Prefer Apple's MPS backend (the original setting); fall back to CUDA and then CPU so
# the notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


# Download Dataset

In [2]:
# Evaluation settings: KL weight handed to the model / loss, and the sigma passed to
# the noise generator.
sigma_noise = 0.01
beta = 0

# Clean-speech dataset and a fixed-order loader over it.
# NOTE(review): `count = 10` presumably limits how many items are used -- confirm in data.py.
dataset = CleanDataset(split='dev-clean', count=10, chunk_size=50_000)
dataloader = DataLoader(dataset, shuffle=False, batch_size=4)

# Helper objects: `dt` converts between waveforms, spectrograms and audio files,
# `ng` adds noise to waveforms.
dt = DataTransformer()
ng = NoiseGenerator()

# Test Model Results Audibly

In [3]:
# Take the first batch from the loader and write the clean audio out for listening.
first_batch = next(iter(dataloader))
waveform, sample_rate = first_batch
dt.waveform_to_audio(waveform, sample_rate, fname=output_folder / 'original')

# Clean spectrogram. The two trailing sizes are remembered so the model output can be
# cropped back to them in the inference cell.
amp_clean, phase_clean, _ = dt.waveform_to_spectrogram(waveform)
W, H = amp_clean.shape[1:]

# Noisy copy of the batch, produced by the noise generator at the configured sigma.
noisy_waveform = ng.add_gaussian(waveform, sigma=sigma_noise)

# Noisy spectrogram; the third return value is kept because it is needed to convert
# a spectrogram back to a waveform later on.
amp_noisy, phase_noisy, minmax_info = dt.waveform_to_spectrogram(noisy_waveform)
dt.waveform_to_audio(noisy_waveform, sample_rate, output_folder / 'noisy')


In [4]:
# Rebuild the network with explicit hyper-parameters.
# NOTE(review): these are assumed to match the configuration the checkpoint was trained
# with; load_state_dict below raises on missing/unexpected keys if they do not.
attn_params = AttnParams(num_heads=4, window_size=None, use_rel_pos_bias=False, dim_head=64)
model = CustomVAE(
    in_channels=1,
    spatial_dims=2,
    use_attn=False,
    vae_latent_channels=16,
    attn_params=attn_params,
    vae_use_log_var=True,
    beta=beta,
    dropout_prob=0,
    blocks_down=(1,),
    blocks_up=[],
)

# weights_only=True restricts unpickling to tensors and plain containers, so loading
# the checkpoint file cannot execute arbitrary pickled code. The file is only used as
# a state dict, so nothing else is needed from it.
state = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state)

# Last expression of the cell: moves the model to the device and displays its structure.
model.to(device)

CustomVAE(
  (act_mod): ReLU(inplace=True)
  (convInit): Convolution(
    (conv): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (down_layers): ModuleList(
    (0): Sequential(
      (0): Identity()
      (1): ResBlock(
        (norm1): GroupNorm(8, 8, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 8, eps=1e-05, affine=True)
        (act): ReLU(inplace=True)
        (conv1): Convolution(
          (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (conv2): Convolution(
          (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
  )
  (up_layers): ModuleList()
  (up_samples): ModuleList()
  (conv_final): Sequential(
    (0): GroupNorm(8, 8, eps=1e-05, affine=True)
    (1): ReLU(inplace=True)
    (2): Convolution(
      (conv): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (attn_layer): SelfAttentionND(
    (to_qkv): Linear(i

In [5]:
def vae_loss(output, target, z_mean, log_var, beta):
    """Compute the reconstruction, KL and combined loss terms for a VAE.

    Args:
        output: Reconstructed tensor; must have exactly the same shape as ``target``.
        target: Reference tensor the reconstruction is compared against.
        z_mean: Latent means.
        log_var: Latent log-variances (combined element-wise with ``z_mean``).
        beta: Weight applied to the KL term in the combined loss.

    Returns:
        Tuple ``(recon_loss, kl, recon_loss + beta * kl)`` of scalar tensors.

    Raises:
        AssertionError: If ``output`` and ``target`` differ in shape.
    """
    assert output.shape == target.shape

    # KL term: summed over dim 1, then averaged over every remaining dimension.
    kl = -0.5 * torch.mean(
        torch.sum(1 + log_var - z_mean.pow(2) - log_var.exp(), dim=1)
    )

    # Reconstruction term: element-wise squared error, averaged per sample and then
    # across the batch.
    recon_loss = F.mse_loss(output, target, reduction="none")
    # reshape (not view): the element-wise loss can inherit a non-contiguous layout
    # from its inputs (e.g. transposed tensors), on which view() raises.
    recon_loss = recon_loss.reshape(recon_loss.size(0), -1).mean(dim=1)  # per-sample mean
    recon_loss = recon_loss.mean()

    return recon_loss, kl, recon_loss + beta * kl


In [6]:
# Pure inference: no_grad() stops autograd from recording the forward pass, so the
# results carry no graph and can be moved to the CPU and saved directly.
with torch.no_grad():
    # NOTE(review): `input` shadows the Python builtin; the name is kept so that any
    # other cell reading it keeps working.
    input = dt.add_padding(amp_noisy).unsqueeze(1).to(device)

    output, z_mean, log_var = model(input)
    output = torch.tanh(output)
    # Crop the padded model output back to the clean spectrogram's size.
    amp_recon = output[:, :, :W, :H]

    # Loss against the clean amplitudes, weighted with the notebook-level `beta`.
    target = amp_clean.to(device).unsqueeze(1)
    loss = vae_loss(amp_recon, target, z_mean, log_var, beta = beta)
    # Drop the channel axis that was added for the model.
    amp_recon = amp_recon[:, 0, :, :]

# vae_loss returns (reconstruction, KL, combined); index 0 is the reconstruction term.
print("Loss", loss[0].item())

# The reconstruction reuses the noisy phase: the model only predicts amplitudes.
dt.save_spectrograms(
    amps =   [amp_clean.detach(), amp_noisy.detach(), amp_recon.detach(), ],
    phases = [phase_clean.detach(), phase_noisy.detach(), phase_noisy.detach()],
    names = ['original', 'noisy', 'recon'], 
    out_dir = output_folder / 'spectrograms'
)

# Convert the reconstructed amplitudes (with the noisy phase) back to audio.
waveforms_reconstr = dt.spectrogram_to_waveform(amp_recon.cpu(), phase_noisy.cpu(), *minmax_info)
dt.waveform_to_audio(waveforms_reconstr, sample_rate = sample_rate, fname = output_folder / 'reconstr')

Loss 0.10235114395618439
