# Imports and Configuration

- Before running any cells, select the `voice_isolate_env` kernel from the kernel picker in the top-right corner of the screen.

In [31]:
import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import CleanDataset, DataTransformer, NoiseGenerator
from autoencoder import AttnParams, CustomVAE

# Where the .wav comparisons are written; relative to the notebook's working directory.
output_folder = Path('../outputs2')
# Create it up front so the waveform_to_audio calls below never hit a missing directory.
output_folder.mkdir(parents=True, exist_ok=True)

# Prefer Apple's MPS backend (the original target), but fall back so the notebook
# also runs on CUDA machines or plain CPU instead of failing on `.to('mps')`.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Download Dataset and Build the DataLoader

In [18]:
# Seed torch's global RNG and give the loader its own seeded generator so the
# shuffled batch drawn in the next section is identical on every Restart & Run All.
# NOTE(review): if NoiseGenerator draws from torch's global RNG, the global seed also
# makes the noisy example reproducible — confirm in data.py.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): chunk_size is presumably the length of each audio chunk in samples — confirm in data.py.
dataset = CleanDataset(chunk_size = 50_000, split = 'train-clean-100')

dt = DataTransformer()
ng = NoiseGenerator()
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

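# Listen to the Model's Results

This section saves the original, noisy, and reconstructed audio to `output_folder` so they can be compared by ear.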

In [32]:
# Take one batch from the loader and save the clean reference audio.
waveform, sample_rate = next(iter(dataloader))
dt.waveform_to_audio(waveform, sample_rate, fname = output_folder / 'original')

# Clean spectrogram of the same batch; W and H are its last two dimensions
# (the leading dimension is discarded).
amp_clean, phase_clean, _ = dt.waveform_to_spectrogram(waveform)
_, W,H = amp_clean.shape

# Corrupt the batch with Gaussian noise and build the model's input spectrogram.
# minmax_info is kept because the reconstruction step passes it back to
# dt.spectrogram_to_waveform (presumably min/max scaling stats — confirm in data.py).
noisy_waveform = ng.add_gaussian(waveform)
amp_noisy, phase_noisy, minmax_info = dt.waveform_to_spectrogram(noisy_waveform)
# Pass the path as `fname=` like the other waveform_to_audio calls, so it cannot
# bind to a different positional parameter.
dt.waveform_to_audio(noisy_waveform, sample_rate, fname = output_folder / 'noisy')


In [33]:
# Architecture must match the checkpoint exactly; otherwise load_state_dict
# (strict by default) raises on missing/unexpected keys or shape mismatches.
attn_params = AttnParams(num_heads=4, window_size=None, use_rel_pos_bias=False, dim_head=64)
model = CustomVAE(in_channels=2, spatial_dims=2, use_attn=True, attn_params=attn_params, vae_use_log_var = True)

# Checkpoint location. Set VOICE_ISOLATION_MODEL_PATH to point elsewhere instead of
# editing this machine-specific absolute default.
model_path = Path(os.environ.get(
    'VOICE_ISOLATION_MODEL_PATH',
    '/Users/marcusbluestone/Desktop/MIT/Fall (25-26)/Voice_Isolation/src/model.pth',
))

# The file is consumed only as a state_dict, so refuse arbitrary pickled objects.
state = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state)

# Move to the target device and switch every submodule to evaluation mode for inference.
# NOTE(review): whether CustomVAE's latent sampling depends on `self.training` is not
# visible here — confirm in autoencoder.py. Last expression -> the model repr is displayed.
model.to(device).eval()

CustomVAE(
  (act_mod): ReLU(inplace=True)
  (convInit): Convolution(
    (conv): Conv2d(2, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (down_layers): ModuleList(
    (0): Sequential(
      (0): Identity()
      (1): ResBlock(
        (norm1): GroupNorm(8, 8, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 8, eps=1e-05, affine=True)
        (act): ReLU(inplace=True)
        (conv1): Convolution(
          (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (conv2): Convolution(
          (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (1): Sequential(
      (0): Convolution(
        (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      )
      (1): ResBlock(
        (norm1): GroupNorm(8, 16, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 16, eps=1e-05, affine=True)
        (act): ReLU(inplace=True

In [34]:
# Prepare for Model Input: pad both spectrogram channels and stack them along a new
# dimension 1, giving (batch, 2, ...) with amplitude in channel 0 and phase in channel 1.
amp_inp, phase_inp = dt.add_padding(amp_noisy), dt.add_padding(phase_noisy)
# Named model_input (not `input`) to avoid shadowing the builtin in the kernel namespace.
model_input = torch.stack((amp_inp, phase_inp), dim=1).to(device)

# Run model: eval mode + no_grad so no autograd graph is built and the output does not
# require grad (tensors that require grad cannot be converted with .numpy()).
model.eval()
with torch.no_grad():
    output, _, _ = model(model_input)

# Reconstruct Waveform: read channels back in the same (amp, phase) order they were stacked.
output = output.cpu()
amp_recon = output[:, 0, :, :]
phase_recon = output[:, 1, :, :]

# NOTE(review): add_padding was applied above but nothing here crops the output back,
# and W, H from the earlier cell are unused — confirm spectrogram_to_waveform accepts
# the padded size (or crop, e.g. output[..., :W, :H], if padding is appended at the end).
waveforms_reconstr = dt.spectrogram_to_waveform(amp_recon, phase_recon, *minmax_info)
dt.waveform_to_audio(waveforms_reconstr, sample_rate = sample_rate, fname = output_folder / 'reconstr')