In [1]:
import numpy as np

from model import HiFiGAN, MelSpectrogram

In [2]:
# Instantiate the mel-to-audio model with training=False; its generator weights
# are loaded from disk in the next cell.
model = HiFiGAN(training=False)
# Mel-spectrogram transform called by preprocess_audio.
mel_spectrogram = MelSpectrogram()


In [3]:
import torch
# Load the generator checkpoint, remapping any stored tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted
# checkpoint; consider weights_only=True if the installed torch supports it.
cp = torch.load('../weights/generator.pth', map_location='cpu')
# Last expression of the cell, so the key-matching report is shown as output.
model.generator.load_state_dict(cp)

<All keys matched successfully>

In [4]:
def generate_audio(mel):
    """Run the global ``model`` on ``mel`` and return the result as a NumPy array.

    The forward pass happens with gradient tracking disabled, and every
    size-1 dimension of the model output is dropped before conversion.
    """
    with torch.no_grad():
        output = model(mel)
    flattened = output.squeeze()
    return flattened.numpy()

def preprocess_audio(audio, sr=None):
    """Turn raw audio into a mel-spectrogram using the global ``mel_spectrogram``.

    Args:
        audio: NumPy array or torch tensor of samples. A 1-D input gets a
            leading dimension of size 1 added in front.
        sr: Forwarded unchanged to ``mel_spectrogram`` as its ``sr`` keyword.

    Returns:
        Whatever ``mel_spectrogram`` returns for the float32 version of the input.
    """
    tensor = torch.from_numpy(audio) if isinstance(audio, np.ndarray) else audio
    batched = tensor.unsqueeze(0) if tensor.dim() == 1 else tensor
    return mel_spectrogram(batched.float(), sr=sr)

In [5]:
from datasets import load_dataset
# Pull the validation split of the "clean" config of the dummy LibriSpeech set.
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
    trust_remote_code=True,
)
# Use the first record's audio entry as the reference utterance.
sample = ds[0]["audio"]
sample_rate = sample['sampling_rate']
waveform = sample['array']

In [6]:
from torch import nn


class Denoiser(nn.Module):
    """Subtract a model's zero-input output spectrum from generated audio.

    At construction time ``hifigan`` is run on an all-zero mel input and the
    first STFT frame of what it produces is stored as ``bias_spec``.
    ``forward`` subtracts a scaled copy of that frame from every STFT frame of
    the audio it is given and transforms the result back to a waveform.

    Args:
        hifigan: Callable invoked once with a ``torch.zeros(1, 80, 100)`` input;
            its output is treated as audio.
        filter_length: FFT size passed to ``torch.stft`` / ``torch.istft``.
        hop_size: Hop length between STFT frames.
        win_length: Length of the Hann window.
    """

    def __init__(self, hifigan, filter_length=1024, hop_size=256, win_length=1024):
        super(Denoiser, self).__init__()
        self.filter_length = filter_length
        self.hop_size = hop_size
        self.win_length = win_length
        # Registered as a buffer (not a plain attribute) so .to()/.cuda() moves
        # the window together with bias_spec; non-persistent keeps it out of
        # state_dict, so existing checkpoints are unaffected.
        self.register_buffer('window', torch.hann_window(self.win_length), persistent=False)
        self.register_buffer('bias_spec', self.compute_bias_spec(hifigan), persistent=False)

    def compute_bias_spec(self, hifigan):
        """Return the first complex STFT frame of ``hifigan``'s zero-input output.

        The result keeps a trailing frame axis of size 1 so that it broadcasts
        over every frame of a spectrogram in ``forward``.
        """
        with torch.no_grad():
            mel_input = torch.zeros(1, 80, 100)
            bias_audio = hifigan(mel_input).float().squeeze()
            bias_spec = torch.stft(bias_audio, self.filter_length, self.hop_size, self.win_length,
                                   window=self.window, return_complex=True)[:, 0][:, None]
        return bias_spec

    def forward(self, audio, strength=0.05):
        """Return ``audio`` with ``strength * bias_spec`` removed from each STFT frame.

        Args:
            audio: NumPy array or tensor of samples; a 1-D input gets a leading
                dimension of size 1 added in front.
            strength: Scale applied to the stored bias before subtraction.

        Returns:
            Float tensor with the same number of samples as the input.
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        # Align with the buffers: NumPy input is often float64, and the module
        # may have been moved to another device after construction.
        audio = audio.to(device=self.window.device, dtype=self.window.dtype)

        audio_stft = torch.stft(audio, self.filter_length, self.hop_size, self.win_length,
                                window=self.window, return_complex=True)

        audio_spec_denoised = audio_stft - self.bias_spec * strength
        # Without `length`, istft returns hop_size * (n_frames - 1) samples and
        # silently drops the tail of any input that is not a multiple of hop_size.
        audio_denoised = torch.istft(audio_spec_denoised, self.filter_length, self.hop_size,
                                     self.win_length, window=self.window, return_complex=False,
                                     length=audio.shape[-1])
        return audio_denoised

# Build the denoiser once; its constructor runs `model` on an all-zero mel input
# to capture the spectrum that is later subtracted.
d = Denoiser(model)

In [7]:
import IPython.display as ipd
# Play the untouched dataset utterance as the listening reference.
ipd.Audio(waveform, rate=sample_rate)

In [8]:
# Sanity check: shape of the utterance once a leading dimension is added.
torch.from_numpy(waveform).unsqueeze(0).shape

torch.Size([1, 93680])

In [9]:
# Vocode the utterance in 16000-sample slices, then stitch the pieces together.
full_audio = []
for i in range(0, len(waveform), 16000):
    audio = generate_audio(preprocess_audio(waveform[i:i+16000]))
    # Keep each slice as an array; list.extend() would box every sample into a
    # separate Python object before the array is rebuilt.
    full_audio.append(audio)
full_audio = np.concatenate(full_audio)
de_noised_audio = d(full_audio)

In [10]:
# Listen to the stitched model output before denoising.
ipd.Audio(full_audio, rate=sample_rate)

In [11]:
# Listen to the same output after it has passed through the denoiser.
ipd.Audio(de_noised_audio, rate=sample_rate)

In [12]:
import torch
import torchaudio
# Load a local clip; this rebinds `waveform` and `sample_rate`, replacing the
# dataset utterance used in the earlier cells.
waveform, sample_rate = torchaudio.load("../data/1170331873109626940.ogg")
# Keep only index 0 of the first axis (presumably the first channel — confirm
# against the torchaudio.load layout).
waveform = waveform[0]
ipd.Audio(waveform, rate=sample_rate)

In [13]:
# Zero-pad the clip up to a whole number of 16000-sample chunks.  `(-n) % 16000`
# is 0 when n is already a multiple, whereas `16000 - n % 16000` would append a
# full extra chunk of silence in that case.
waveform = torch.nn.functional.pad(waveform, (0, (-waveform.shape[0]) % 16000))
# NOTE(review): `waveform` is rebound to the 2-D chunked tensor here, so this
# cell cannot be re-run without first re-running the cell that loads the clip.
waveform = waveform.reshape(-1, 16000)
# Run all chunks through the model as one batch, then flatten back to one signal.
preprocessed = preprocess_audio(waveform, sr=sample_rate)
generated = generate_audio(preprocessed).reshape(-1)
ipd.Audio(generated, rate=sample_rate)

min value is  tensor(-1.0775)
max value is  tensor(1.1709)


In [14]:
# Apply the denoiser (default strength) to the flattened model output and play it.
denoised = d(generated)
ipd.Audio(denoised, rate=sample_rate)