In [None]:
import json
from typing import Dict, Tuple

import librosa
import numpy as np
import soundfile
import torch
import torchaudio
from IPython.display import display, Audio

from utils.misc import get_spectrograms_helper
from vqvae.vqvae import VQVAE


def extend_duration_like(this: torch.Tensor, reference: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Trim or zero-pad `this` along `dim` so that it is as long as `reference`.

    Args:
        this: tensor to adjust.
        reference: tensor whose size along `dim` is the target size.
        dim: dimension along which to trim or pad (default: last).

    Returns:
        The leading `reference.shape[dim]` entries of `this` along `dim` if
        `this` is longer; otherwise `this` followed by zeros (same dtype and
        device as `this`) up to the target size.
    """
    target_duration = reference.shape[dim]
    padding_duration = target_duration - this.shape[dim]
    if padding_duration < 0:
        return this.narrow(start=0, length=target_duration, dim=dim)

    padding_shape = list(this.shape)
    padding_shape[dim] = padding_duration
    # Allocate the padding directly instead of expanding `this.flatten()[0]`,
    # which raises an IndexError when `this` contains no elements.
    padding = torch.zeros(padding_shape, dtype=this.dtype, device=this.device)
    return torch.cat([this, padding], dim=dim)

device = 'cuda'

# downsampling factor 8 then 4
model_large_model_parameters_json_path = '/home/theis/bachbach/code/vq-vae-2-pytorch/checkpoints/20200310-105659-1b55c7/model_parameters.json'
model_large_model_weights_path = '/home/theis/bachbach/code/vq-vae-2-pytorch/checkpoints/20200310-105659-1b55c7/vqvae_nsynth_223.pt'

model_large = VQVAE.from_parameters_and_weights(
    model_large_model_parameters_json_path,
    model_large_model_weights_path
)

# downsampling factor 16 then 2
model_small_model_parameters_json_path = '/home/theis/code/vq-vae-2-pytorch/data/checkpoints/20200117-205357-5b1d22/model_parameters.json'
model_small_model_weights_path = '/home/theis/code/vq-vae-2-pytorch/data/checkpoints/20200117-205357-5b1d22/vqvae_nsynth_471.pt'

model_small = VQVAE.from_parameters_and_weights(
    model_small_model_parameters_json_path,
    model_small_model_weights_path
)

# Registry of the loaded models, keyed by size, together with the files they
# were loaded from (the paths are copied next to the processed audio later on).
models = {
    'small': {
        'model': model_small,
        'model_parameters_json_path': model_small_model_parameters_json_path,
        'model_weights_path': model_small_model_weights_path,
    },
    'large': {
        'model': model_large,
        'model_parameters_json_path': model_large_model_parameters_json_path,
        'model_weights_path': model_large_model_weights_path,
    },
}

for model_name, model_dict in models.items():
    model = model_dict['model']
    model.adapt_quantized_durations = True
    # This notebook only runs inference, so put the model in evaluation mode.
    # NOTE(review): this is a no-op if the loader already calls `eval()`.
    model = model.to(device).eval()
    model_dict['model'] = model
    print(model_name + ' model uses GANSynth normalization:',
          model.use_gansynth_normalization)

# Every registered model must have the flag set, not only the small one.
assert all(model_dict['model'].adapt_quantized_durations
           for model_dict in models.values())

# The spectrogram helper is built from the training parameters stored with the
# small model's checkpoint.
VQVAE_TRAINING_PARAMETERS_PATH = '/home/theis/code/vq-vae-2-pytorch/data/checkpoints/20200117-205357-5b1d22/command_line_parameters.json'
with open(VQVAE_TRAINING_PARAMETERS_PATH, 'r') as f:
    vqvae_training_parameters = json.load(f)
spectrograms_helper = get_spectrograms_helper(
    device=device, **vqvae_training_parameters)
spectrograms_helper.to(device)

In [None]:
@torch.no_grad()
def process(audio: torch.Tensor, sr: int, model: VQVAE,
            device: str = 'cuda') -> Tuple[torch.Tensor, int]:
    """Reconstruct `audio` through `model` and return it at its original rate.

    The audio is resampled to 16 kHz if needed, converted to a spectrogram
    with the module-level `spectrograms_helper`, passed through `model`,
    converted back to audio, resampled back to `sr`, and finally trimmed or
    zero-padded to the duration of the input.

    Args:
        audio: input audio; a 3-D input has its first dimension squeezed.
        sr: sample rate of `audio`, in Hz.
        model: callable whose first output is the reconstructed spectrogram.
        device: device on which the computation is run.

    Returns:
        A `(reconstruction, sr)` tuple, where `reconstruction` has the same
        size as the input along its last dimension.
    """
    model_sr = 16000

    if audio.ndim == 3:
        audio = audio.squeeze(0)
    audio_original_sr = audio.to(device)

    if sr != model_sr:
        # The resamplers are modules; move them to `device` so that any
        # buffer they hold lives on the same device as the audio.
        resampler_to_model = torchaudio.transforms.Resample(
            orig_freq=sr, new_freq=model_sr).to(device)
        resampler_from_model = torchaudio.transforms.Resample(
            orig_freq=model_sr, new_freq=sr).to(device)
    else:
        # Already at the model's rate: resampling is the identity.
        resampler_to_model = resampler_from_model = lambda x: x

    audio_resampled = resampler_to_model(audio_original_sr)
    spec_resampled = spectrograms_helper.to_spectrogram(audio_resampled)
    spec_reconstruction_resampled = model(spec_resampled)[0]
    audio_reconstruction_resampled = spectrograms_helper.to_audio(
        spec_reconstruction_resampled)
    audio_reconstruction_original_sr = resampler_from_model(
        audio_reconstruction_resampled)
    # The round trip may change the number of samples; match the input again.
    audio_reconstruction_original_sr_same_duration = extend_duration_like(
        audio_reconstruction_original_sr, audio_original_sr)
    return audio_reconstruction_original_sr_same_duration, sr

def process_all(audio: torch.Tensor, sr: int) -> Dict[str, Tuple[torch.Tensor, int]]:
    """Run `process` on `audio` with every model registered in `models`.

    Args:
        audio: input audio, forwarded unchanged to `process`.
        sr: sample rate of `audio`, in Hz.

    Returns:
        A dict mapping each model name to the `(reconstruction, sr)` tuple
        returned by `process` for that model.
    """
    return {model_name: process(audio, sr, model_dict['model'])
            for model_name, model_dict in models.items()}

In [None]:
def make_player(audio: torch.Tensor, rate: int = 16000):
    """Display an inline, normalized audio player for `audio[0]`.

    Args:
        audio: tensor or array-like; only its first entry along the leading
            axis is played (the first channel for channels-first audio).
        rate: sample rate used for playback, in Hz.
    """
    # Convert tensors explicitly rather than catching AttributeError, which
    # could also hide unrelated errors; `detach` allows tensors that
    # require grad.
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    display(Audio(data=audio[0], rate=rate, normalize=True))

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch import nn
from glob import glob
import os
from pathlib import Path
from tqdm import tqdm
import shutil
from typing import Optional

class AudioSet(Dataset):
    """Dataset of the audio files found under a directory.

    Each item is an `(audio, filename)` pair, where `audio` is a
    channels-first tensor resampled to `fs_hz` and, if a minimum duration was
    requested, zero-padded at the end up to that duration.
    """

    def __init__(self, directory: Path, extension: str = 'wav',
                 fs_hz: int = 16000, min_duration_s: Optional[float] = None,
                 recursive: bool = False):
        """Collect the files matching `<directory>/**/*.<extension>`.

        Args:
            directory: root directory to search.
            extension: file extension to look for, without the leading dot.
            fs_hz: sample rate, in Hz, to which items are resampled.
            min_duration_s: if given, shorter items are zero-padded to this
                duration, in seconds.
            recursive: passed to `glob`. With the default `False`, `**`
                behaves like `*`, so only files located exactly one
                directory below `directory` are found; with `True`, files at
                any depth (including directly in `directory`) are found.
        """
        self.directory = Path(directory)
        self.extension = extension
        self.recursive = recursive
        pattern = str(self.directory / f'**/*.{self.extension}')
        self.filenames = [filename
                          for filename in glob(pattern, recursive=self.recursive)
                          if os.path.isfile(filename)]
        self.fs_hz = fs_hz
        self.min_duration_s = min_duration_s
        self.min_duration_n = None
        if self.min_duration_s is not None:
            # Must be an integer number of samples: it is used as a tensor
            # size, which a float (e.g. 0.5 s * 16000 Hz = 8000.0) is not.
            self.min_duration_n = int(round(self.min_duration_s * self.fs_hz))

    def __getitem__(self, index):
        """Load, resample and pad the `index`-th file.

        Returns:
            `(audio_tensor, filename)`, with `audio_tensor` of shape
            `(channels, samples)`.
        """
        filename = self.filenames[index]

        audio_original_sr, original_sr = librosa.load(filename, sr=None, mono=False)
        if audio_original_sr.ndim == 1:
            # Mono files are loaded as 1-D arrays: add the channel axis.
            audio_original_sr = audio_original_sr[None, :]

        # Pass the rates by keyword: recent librosa versions reject them as
        # positional arguments.
        audio_stereo = librosa.resample(audio_original_sr,
                                        orig_sr=original_sr,
                                        target_sr=self.fs_hz)
        audio_tensor = torch.as_tensor(audio_stereo)

        duration_n = audio_tensor.shape[-1]
        if self.min_duration_n is not None and duration_n < self.min_duration_n:
            padding_shape = list(audio_tensor.shape)
            padding_shape[-1] = self.min_duration_n - duration_n
            padding = torch.zeros(padding_shape, dtype=audio_tensor.dtype,
                                  layout=audio_tensor.layout,
                                  device=audio_tensor.device)
            audio_tensor = torch.cat([audio_tensor, padding], dim=-1)
        return audio_tensor, filename

    def __len__(self):
        """Number of audio files found."""
        return len(self.filenames)

In [None]:
root_dir = Path('/home/theis/sounds_for_processing')
sample_pack_names = [
    'Gentleman'
]
output_dir_name = 'VQ-VAE-PROCESSED-stereo/'
DEVICE = 'cuda'

# For each model, reconstruct every file of every sample pack and write the
# result under <root_dir>/<output_dir_name>/<model_size>_model/<pack>/...,
# mirroring the directory layout of the pack.
for model_size, model_dict in models.items():
    parallel_model = nn.DataParallel(model_dict['model'])
    # Moving the model does not depend on the sample pack: do it once per model.
    parallel_model.to(DEVICE)

    root_output_dir = root_dir / output_dir_name / (model_size + '_model')
    root_output_dir.mkdir(exist_ok=True, parents=True)

    # Keep a copy of the model's parameters and weights next to its outputs.
    shutil.copy2(model_dict['model_parameters_json_path'], root_output_dir)
    shutil.copy2(model_dict['model_weights_path'], root_output_dir)

    for sample_pack_name in sample_pack_names:
        samples_dir = root_dir / sample_pack_name
        print(samples_dir)
        output_dir = root_output_dir / (sample_pack_name + '/')
        output_dir.mkdir(parents=True, exist_ok=True)
        print(output_dir)
        audio_set = AudioSet(samples_dir, min_duration_s=1)
        audio_loader = DataLoader(audio_set, num_workers=8, batch_size=1, shuffle=False)

        with torch.no_grad():
            for samples, filenames in tqdm(audio_loader):
                # batch_size is 1: unwrap the single item of the batch.
                sample = samples[0]
                filename = filenames[0]
                output_file_path = output_dir / Path(filename).relative_to(samples_dir)
                output_file_path.parent.mkdir(exist_ok=True, parents=True)
                # The dataset already resampled the audio to `audio_set.fs_hz`.
                processed_channels, processed_sr = process(
                    sample, audio_set.fs_hz, parallel_model, DEVICE)

                # Re-open the source file only to reuse its container format
                # and sample subtype for the output file.
                with soundfile.SoundFile(filename) as sf:
                    soundfile.write(
                        file=str(output_file_path), data=processed_channels.cpu().numpy().T,
                        samplerate=processed_sr, format=sf.format, subtype=sf.subtype)

In [None]:
# Audition a single file: play the (mono-summed) input, then, for each model,
# play its reconstruction and write it to disk, raw and peak-normalized.
# NOTE(review): `sample_path` is not defined in any visible cell; it must be
# set to the path of the audio file to process before running this cell.
with soundfile.SoundFile(sample_path) as sf:
    print(sf)

    # get the audio in channels-first format; `always_2d` keeps a channel
    # axis for mono files, so that the transpose is always (channels, frames)
    audio_stereo = sf.read(dtype='float32', always_2d=True).T
    # resample to the model's sample-rate (rates passed by keyword: recent
    # librosa versions reject them as positional arguments)
    audio_stereo_model_sr = librosa.resample(audio_stereo,
                                             orig_sr=sf.samplerate,
                                             target_sr=16000)

    # sum the channels into a single one, keeping the channel axis
    audio_tensor_model_sr = torch.as_tensor(audio_stereo_model_sr).sum(0, keepdim=True)
    make_player(audio_tensor_model_sr, rate=16000)

    for model_name in models:
        print(model_name)
        model = models[model_name]['model']
        processed_tensor, processed_sr = process(audio_tensor_model_sr, 16000,
                                                 model=model)
        make_player(processed_tensor, rate=processed_sr)

        # resample to the original sample-rate
        processed_original_sr = librosa.resample(processed_tensor.cpu().numpy(),
                                                 orig_sr=processed_sr,
                                                 target_sr=sf.samplerate)

        # soundfile expects (frames, channels)
        processed_soundfile = processed_original_sr.T
        soundfile.write(file=f'./soundfile_output-{model_name}.wav',
                        data=processed_soundfile,
                        samplerate=sf.samplerate, format=sf.format, subtype=sf.subtype)

        # Normalize by the absolute peak: dividing by the signed maximum
        # leaves a larger negative peak outside [-1, 1]. Silence is left
        # untouched to avoid dividing by zero.
        peak = np.abs(processed_soundfile).max()
        if peak > 0:
            processed_soundfile_normalized = processed_soundfile / peak
        else:
            processed_soundfile_normalized = processed_soundfile
        soundfile.write(file=f'./soundfile_output_normalized-{model_name}.wav',
                        data=processed_soundfile_normalized,
                        samplerate=sf.samplerate, format=sf.format, subtype=sf.subtype)