In [29]:
import os

import numpy as np
import torch
from frechet_audio_distance import FrechetAudioDistance
from scipy.io import wavfile
from torch.utils.data import DataLoader

from src.NSynthDataset import NSynthDataset
from src.WaveGAN import WaveGANGenerator

In [30]:
z_size = 1000
sr = 8191
duration = 2
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'

In [31]:
os.chdir('../../..')
path = 'mnt/data/public/NSynth/'

In [32]:
train_set = NSynthDataset(data_path=path,  mel=False, pitched_z=True, 
                          sampling_rate=sr, duration=duration,
                          min_class_count=10000, max_class_count=3000,
                          z_size=z_size)

In [33]:
test_set = NSynthDataset(data_path=path, mel=False,
                          stage='test', pitched_z=True, 
                          sampling_rate=sr, duration=duration, 
                          cond_classes=train_set.cond_classes,
                          z_size=z_size)

In [34]:
testloader = DataLoader(test_set, batch_size=test_set.__len__())


In [35]:
gen = WaveGANGenerator(z_size, train_set.label_size, train_set.y_size, sr=sr, duration=duration).to(device)
gen.load_state_dict(torch.load(f'users/adcy353/GANs-Conditional-Audio-Synthesis/models/wavegan/G_0.0001-1-826.pt'))
gen.eval()

WaveGANGenerator(
  (fc): Linear(in_features=1006, out_features=16384, bias=True)
  (deconv): Sequential(
    (0): ConvTranspose1d(1024, 512, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (1): ReLU()
    (2): ConvTranspose1d(512, 256, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (3): ReLU()
    (4): ConvTranspose1d(256, 128, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (5): ReLU()
    (6): ConvTranspose1d(128, 64, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (7): ReLU()
    (8): ConvTranspose1d(64, 1, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
    (9): Tanh()
  )
)

In [36]:
ev_path = 'users/adcy353/GANs-Conditional-Audio-Synthesis/frechet/pitched/'
for i, (w, l, z) in enumerate(test_set):
    s = gen.forward(z.unsqueeze(0).to(device), l.unsqueeze(0).to(device))
    s.to(torch.device('cpu'))
    s = s.detach().cpu()
    
    
    output_wav_file = f"{ev_path}output_{i}.wav"
    # Normalize the audio data to the appropriate range (-32768 to 32767 for 16-bit PCM)
    normalized_audio = np.int16(s / max(np.abs(s)) * 32767)
    # Write the WAV file
    wavfile.write(output_wav_file, sr, normalized_audio)

In [37]:
# to use `vggish`
frechet = FrechetAudioDistance(
    model_name="vggish",
    use_pca=False, 
    use_activation=False,
    verbose=False
)

Using cache found in /users/adcy353/.cache/torch/hub/harritaylor_torchvggish_master


In [40]:
# Specify the paths to your saved embeddings
background_embds_path = "users/adcy353/GANs-Conditional-Audio-Synthesis/frechet/background/embeddings.npy"
eval_embds_path = "users/adcy353/GANs-Conditional-Audio-Synthesis/frechet/pitched/embeddings.npy"
test_path = 'mnt/data/public/NSynth/nsynth-test/audio'

# Compute FAD score while reusing the saved embeddings (or saving new ones if paths are provided and embeddings don't exist yet)
fad_score = frechet.score(
    test_path,
    ev_path,
    background_embds_path=background_embds_path,
    eval_embds_path=eval_embds_path,
    dtype="float32"
)

In [41]:
fad_score

58.32707437765022