# Imports and Configuration

- Before running, select the `voice_isolate_env` kernel from the kernel picker in the top-right corner of the screen.

In [None]:
from data import CleanDataset, DataTransformer, NoiseGenerator
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import torch.nn.functional as F
import torch
from pathlib import Path
from autoencoder import AttnParams, CustomVAE

# Where the listening/visual outputs of this playground are written.
output_folder = Path('../playground_outputs')
output_folder.mkdir(parents=True, exist_ok=True)

# Trained weights. Set VOICE_ISOLATION_MODEL_PATH to run on another machine; the
# default keeps the original author's absolute path.
model_path = Path(os.environ.get(
    'VOICE_ISOLATION_MODEL_PATH',
    '/Users/marcusbluestone/Desktop/MIT/Fall (25-26)/Voice_Isolation/outputs/model.pth',
))

# Prefer Apple-silicon MPS, but fall back to CPU so the notebook still runs elsewhere.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

# Seed torch so any stochastic step in the model gives repeatable results across runs.
SEED = 0
torch.manual_seed(SEED)

# Download Dataset

In [None]:
# Clean-speech dataset from the 'train-clean-100' split, cut into chunks of 50_000
# (presumably samples — confirm in data.CleanDataset); count=1 presumably caps how
# many items are loaded — TODO confirm.
dataset = CleanDataset(split='train-clean-100', chunk_size=50_000, count=1)

# Spectrogram/audio helper and noise generator (the generator is not used in the
# cells shown here).
dt, ng = DataTransformer(), NoiseGenerator()

# shuffle=False keeps dataset order, so next(iter(dataloader)) always yields the
# first batch of 4 items.
dataloader = DataLoader(dataset, shuffle=False, batch_size=4)

# Test Model Results Audibly

In [None]:
# Take the first batch and save it unmodified so it can be listened to.
waveform, sample_rate = next(iter(dataloader))
dt.waveform_to_audio(waveform, sample_rate, fname=output_folder / 'original')

# Clean target spectrogram. W and H are its last two dims; the model output is
# cropped back to [:W, :H] after inference to remove the padding added for the model.
amp_clean, phase_clean, _ = dt.waveform_to_spectrogram(waveform)
_, W, H = amp_clean.shape

# NOTE(review): no noise is applied — `noisy_waveform` is the clean batch itself and
# `ng` is never used, so this currently measures pure reconstruction. Apply the
# NoiseGenerator here to evaluate denoising.
noisy_waveform = waveform
amp_noisy, phase_noisy, minmax_info = dt.waveform_to_spectrogram(noisy_waveform)
dt.waveform_to_audio(noisy_waveform, sample_rate, fname=output_folder / 'noisy')


In [None]:
# KL weight: passed to the model and used by vae_loss (0 removes the KL term there).
beta = 0

# Architecture must match the checkpoint exactly, or load_state_dict fails on
# missing/unexpected keys.
attn_params = AttnParams(num_heads=4, window_size=None, use_rel_pos_bias=False, dim_head=64)
model = CustomVAE(
    in_channels=2,  # amplitude + phase, stacked as channels before the forward pass
    spatial_dims=2,
    use_attn=False,
    vae_latent_channels=16,
    attn_params=attn_params,
    vae_use_log_var=True,
    beta=beta,
    dropout_prob=0,
    blocks_down=(1,),
    blocks_up=[],
)
state = torch.load(model_path, map_location=device)
model.load_state_dict(state)

# This notebook only runs inference: eval() disables any train-only behaviour the
# model implements. eval() returns the module, so the cell still displays the model.
model.to(device).eval()

CustomVAE(
  (act_mod): ReLU(inplace=True)
  (convInit): Convolution(
    (conv): Conv2d(2, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (down_layers): ModuleList(
    (0): Sequential(
      (0): Identity()
      (1): ResBlock(
        (norm1): GroupNorm(8, 8, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 8, eps=1e-05, affine=True)
        (act): ReLU(inplace=True)
        (conv1): Convolution(
          (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (conv2): Convolution(
          (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
  )
  (up_layers): ModuleList()
  (up_samples): ModuleList()
  (conv_final): Sequential(
    (0): GroupNorm(8, 8, eps=1e-05, affine=True)
    (1): ReLU(inplace=True)
    (2): Convolution(
      (conv): Conv2d(8, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (attn_layer): SelfAttentionND(
    (to_qkv): Linear(i

In [None]:
def vae_loss(output, target, z_mean, log_var, beta):
    """Reconstruction + KL loss for a VAE with a diagonal-Gaussian latent.

    Args:
        output: Model reconstruction; batch on dim 0, any trailing shape.
        target: Ground truth with exactly the same shape as ``output``.
        z_mean: Latent means; latent channels on dim 1.
        log_var: Latent log-variances, same shape as ``z_mean``.
        beta: Weight of the KL term in the combined loss.

    Returns:
        Tuple ``(recon_loss, kl, recon_loss + beta * kl)`` of scalar tensors.

    Raises:
        AssertionError: if ``output`` and ``target`` shapes differ.
    """
    assert output.shape == target.shape, (
        f"output shape {tuple(output.shape)} != target shape {tuple(target.shape)}"
    )

    # Closed-form KL(N(mu, exp(log_var)) || N(0, 1)): summed over latent channels
    # (dim 1), averaged over everything else (batch and any spatial dims).
    kl = -0.5 * torch.mean(
        torch.sum(1 + log_var - z_mean.pow(2) - log_var.exp(), dim=1)
    )

    # Per-sample mean squared error, then averaged over the batch. reshape (not view)
    # so non-contiguous inputs — e.g. a cropped/sliced model output — also work.
    recon_loss = F.mse_loss(output, target, reduction="none")
    recon_loss = recon_loss.reshape(recon_loss.size(0), -1).mean(dim=1)
    recon_loss = recon_loss.mean()

    return recon_loss, kl, recon_loss + beta * kl


In [None]:
# Prepare for Model Input: pad amplitude and phase, then stack them as the two input
# channels (dim=1) the model was built with (in_channels=2).
amp_inp, phase_inp = dt.add_padding(amp_noisy), dt.add_padding(phase_noisy)
model_input = torch.stack((amp_inp, phase_inp), dim=1).to(device)

# Run model in inference mode: eval() disables train-only behaviour and no_grad()
# skips autograd bookkeeping that is not needed here.
model.eval()
with torch.no_grad():
    output, z_mean, log_var = model(model_input)

# Reconstruct: move results to the CPU and crop the padding back off (W, H are the
# unpadded spectrogram dims). Latents also go to the CPU so every loss input lives
# on one device.
output = output.cpu()[:, :, :W, :H]
z_mean, log_var = z_mean.cpu(), log_var.cpu()
amp_recon = output[:, 0, :, :]
phase_recon = output[:, 1, :, :]

# Loss against the clean amplitude/phase target, using the same KL weight as the model.
target = torch.stack((amp_clean, phase_clean), dim=1).cpu()
loss = vae_loss(output, target, z_mean, log_var, beta=beta)
recon_loss, kl_loss, total_loss = loss

print("Loss", recon_loss.item())

print("Amp")
print(amp_inp.min(), amp_inp.max())
print(amp_recon.min(), amp_recon.max())

print("Phase")
print(phase_inp.min(), phase_inp.max())
print(phase_recon.min(), phase_recon.max())

# Save spectrogram images (clean target vs. reconstruction) and the reconstructed
# audio, de-normalized with the input spectrogram's min/max info.
dt.save_spectrogram(amp_clean, phase_clean, output_folder / 'input')
dt.save_spectrogram(amp_recon, phase_recon, output_folder / 'recon')

waveforms_reconstr = dt.spectrogram_to_waveform(amp_recon, phase_recon, *minmax_info)
dt.waveform_to_audio(waveforms_reconstr, sample_rate=sample_rate, fname=output_folder / 'reconstr')

Loss 0.24413937330245972
Amp
tensor(-1.) tensor(1.)
tensor(-0.8511) tensor(0.4655)
Phase
tensor(-1.0000) tensor(1.0000)
tensor(-0.4098) tensor(0.1488)
