In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import librosa
import numpy as np
from torch.utils.data import DataLoader, Dataset

In [None]:
# Reference dataset/dataloader settings in YAML form. This is only a bare
# string literal: nothing in this notebook parses it, so it merely shows up as
# the cell's output.
# NOTE(review): several values differ from what the code below actually uses
# (e.g. batch_size 32 here vs. hparams['batch_size'] = 16 and the load_data
# default of 4); treat it as documentation, not configuration.
'''seed: 42
dataset:
  shift: 4000
  sample_len: 16384
  sample_rate: 16000
  train:
    ann_path: 'dataset/train.txt'
    dataloader:
      batch_size: 32
      num_workers: 10
  val:
    ann_path: 'dataset/val.txt'
    dataloader:
      batch_size: 32
      num_workers: 10
  test:
    ann_path: 'dataset/test.txt'
    dataloader:
      batch_size: 1
      num_workers: 10
'''

In [None]:
import matplotlib.pyplot as plt
import librosa.display

def make_plot_spectrogram(stftaudio_magnitude_db,sample_rate, hop_length_fft) :
    """Draw a dB-magnitude spectrogram in a new 12x6 figure with a colour bar.

    The title reports the hop length, the number of time frames (axis 1),
    the number of frequency bins (axis 0) and the full array shape.
    """
    spec = stftaudio_magnitude_db
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(spec, x_axis='time', y_axis='linear', sr=sample_rate, hop_length=hop_length_fft)
    plt.colorbar()
    n_bins, n_frames = spec.shape[0], spec.shape[1]
    template = 'hop_length={},  time_steps={},  fft_bins={}  (2D resulting shape: {})'
    plt.title(template.format(hop_length_fft, n_frames, n_bins, spec.shape))

def make_plot_phase(stft_phase,sample_rate,hop_length_fft) :
    """Plot the phase angle (radians) of an STFT in a new 12x6 figure.

    ``stft_phase`` is expected to be the complex STFT; ``np.angle`` extracts
    its phase. The title reports the hop length, time frames (axis 1),
    frequency bins (axis 0) and the full shape.
    """
    plt.figure(figsize=(12, 6))
    phase_rad = np.angle(stft_phase)
    librosa.display.specshow(phase_rad, x_axis='time', y_axis='linear', sr=sample_rate, hop_length=hop_length_fft)
    plt.colorbar()
    n_bins, n_frames = stft_phase.shape[0], stft_phase.shape[1]
    caption = 'hop_length={},  time_steps={},  fft_bins={}  (2D resulting shape: {})'
    plt.title(caption.format(hop_length_fft, n_frames, n_bins, stft_phase.shape))

def make_plot_time_serie(audio,sample_rate):
    """Plot ``audio`` as an amplitude-vs-time waveform in a new 12x6 figure.

    Args:
        audio: waveform samples.
        sample_rate: sampling rate passed to librosa for the time axis.

    ``librosa.display.waveplot`` was deprecated in librosa 0.8.1 and removed in
    0.10, so ``waveshow`` is used when present and ``waveplot`` only as a
    fallback for older installs.
    """
    draw_wave = getattr(librosa.display, 'waveshow', None) or librosa.display.waveplot
    plt.figure(figsize=(12, 6))
    plt.title('Audio')
    plt.ylabel('Amplitude')
    plt.xlabel('Time(s)')
    draw_wave(audio, sr=sample_rate)


def make_3plots_spec_voice_noise(stftvoicenoise_mag_db,stftnoise_mag_db,stftvoice_mag_db,sample_rate, hop_length_fft):
    """Show three stacked dB spectrograms: noisy input, predicted voice, true voice.

    ``stftnoise_mag_db`` is drawn under the title 'Spectrogram predicted voice'
    despite its parameter name. Each panel gets its own colour bar.
    """
    panels = [
        ('Spectrogram voice + noise', stftvoicenoise_mag_db),
        ('Spectrogram predicted voice', stftnoise_mag_db),
        ('Spectrogram true voice', stftvoice_mag_db),
    ]
    plt.figure(figsize=(8, 12))
    for row, (caption, spec_db) in enumerate(panels, start=1):
        plt.subplot(3, 1, row)
        plt.title(caption)
        librosa.display.specshow(spec_db, x_axis='time', y_axis='linear', sr=sample_rate, hop_length=hop_length_fft)
        plt.colorbar()
    plt.tight_layout()


def make_3plots_phase_voice_noise(stftvoicenoise_phase,stftnoise_phase,stftvoice_phase,sample_rate, hop_length_fft):
    """Show the phase (radians, via ``np.angle``) of three STFTs as stacked panels.

    Panels, top to bottom: noisy input, predicted voice (``stftnoise_phase``
    despite its name), true voice. Each panel gets its own colour bar.
    """
    panels = [
        ('Phase (radian) voice + noise', stftvoicenoise_phase),
        ('Phase (radian) predicted voice', stftnoise_phase),
        ('Phase (radian) true voice', stftvoice_phase),
    ]
    plt.figure(figsize=(8, 12))
    for row, (caption, stft_values) in enumerate(panels, start=1):
        plt.subplot(3, 1, row)
        plt.title(caption)
        librosa.display.specshow(np.angle(stft_values), x_axis='time', y_axis='linear', sr=sample_rate, hop_length=hop_length_fft)
        plt.colorbar()
    plt.tight_layout()


def make_3plots_timeseries_voice_noise(clipvoicenoise,clipnoise,clipvoice, sample_rate, ylim=(-0.05, 0.05)) :
    """Plot noisy, predicted and true voice waveforms as three stacked panels.

    Args:
        clipvoicenoise: noisy input waveform (voice + noise).
        clipnoise: predicted voice waveform; despite the parameter name it is
            drawn under the title 'Audio predicted voice'.
        clipvoice: ground-truth clean voice waveform.
        sample_rate: sampling rate passed to librosa for the time axis.
        ylim: ``(bottom, top)`` amplitude limits applied to every panel, or
            ``None`` to let matplotlib autoscale. Defaults to the previously
            hard-coded +/-0.05.

    ``librosa.display.waveplot`` was removed in librosa 0.10, so ``waveshow`` is
    used when available and ``waveplot`` only as a fallback.
    """
    draw_wave = getattr(librosa.display, 'waveshow', None) or librosa.display.waveplot
    panels = (
        ('Audio voice + noise', clipvoicenoise),
        ('Audio predicted voice', clipnoise),
        ('Audio true voice', clipvoice),
    )

    plt.figure(figsize=(18, 12))
    plt.subplots_adjust(hspace=0.35)
    for position, (title, clip) in enumerate(panels, start=1):
        plt.subplot(3, 1, position)
        plt.title(title)
        plt.ylabel('Amplitude')
        plt.xlabel('Time(s)')
        draw_wave(clip, sr=sample_rate)
        if ylim is not None:
            plt.ylim(*ylim)

In [None]:
from pathlib import Path
# Cap on the number of training mixtures used (keeps experiments fast).
N_TRAIN_FILES = 1000

# Noisy training mixtures, sorted for a deterministic order. The old leftover
# `str(paths[2])[10:]` expressions had no effect and raised IndexError whenever
# fewer than three files were found, so they are gone.
paths = sorted(Path("../input/ms-snsd/train/mix").glob('*'))
train_noisy_paths = [str(p) for p in paths[:N_TRAIN_FILES]]

# Validation ("dev") mixtures. The misspelled name `tast_noisy_paths` is kept
# because later cells refer to it.
paths1 = sorted(Path("../input/ms-snsd/dev/mix").glob('*'))
tast_noisy_paths = [str(p) for p in paths1]

In [None]:
# Sanity check: how many training mixtures were kept.
len(train_noisy_paths)

In [None]:
# Peek at one validation path (raises IndexError if fewer than 8 dev files exist).
tast_noisy_paths[7]

In [None]:
# Shows how the clean-file name is recovered from a noisy-mix path: it is the
# last '_'-separated token ('clnsp1180.wav' for this example).
string =  '../input/ms-snsd/train/mix/noisy1180_SNRdb_5.0_clnsp1180.wav'
string.split('_')[-1]

In [None]:
import os

# Clean reference for each noisy training mixture:
#   <split>/mix/noisy<N>_SNRdb_<x>_clnsp<N>.wav -> <split>/clean/clnsp<N>.wav
# Derived from the directory structure rather than the hard-coded 23-character
# prefix of '../input/ms-snsd/train/', which broke for any other location.
clean_paths = [
    os.path.join(os.path.dirname(os.path.dirname(noisy)), 'clean', os.path.basename(noisy).split('_')[-1])
    for noisy in train_noisy_paths
]
 
import os


def get_clean_file(path):
    """Return the clean-reference path for a noisy MS-SNSD mixture path.

    ``<split>/mix/noisy<N>_SNRdb_<x>_clnsp<N>.wav`` maps to
    ``<split>/clean/clnsp<N>.wav``: the clean file name is the last
    ``'_'``-separated token of the noisy file name. This works for any split
    directory. The previous ``path[:23]`` slice only fit
    '../input/ms-snsd/train/', so dev paths became '.../dev/miclean/...'.
    """
    split_dir = os.path.dirname(os.path.dirname(path))
    clean_name = os.path.basename(path).split('_')[-1]
    return os.path.join(split_dir, 'clean', clean_name)

In [None]:
from glob import glob
import os, pickle
import numpy as np

In [None]:
import librosa

In [None]:
import numpy as np
import pickle
import librosa
from pydub import AudioSegment
import IPython
import scipy


def inverse_stft_transform(stft_features, window_length, overlap):
    """Invert a complex STFT with librosa; ``overlap`` is passed as the hop length."""
    hop = overlap
    return librosa.istft(stft_features, win_length=window_length, hop_length=hop)


def play(audio, sample_rate):
    """Render an inline audio player for ``audio`` in the notebook output.

    The previous version constructed ``IPython.display.Audio`` and then
    discarded it, so nothing was shown. Inside a function the player has to be
    handed to ``display`` explicitly.
    """
    from IPython.display import Audio, display

    display(Audio(data=audio, rate=sample_rate))


def read_audio(filepath, sample_rate, normalize=True):
    """Load an audio file at ``sample_rate`` and optionally peak-normalise it.

    Args:
        filepath: path of the audio file.
        sample_rate: target rate passed to ``librosa.load``.
        normalize: when ``True``, scale so the absolute peak becomes 1/3.

    Returns:
        ``(audio, sr)`` as returned by ``librosa.load``, after optional scaling.

    Silent or empty signals are returned unscaled. Previously a zero peak
    produced NaNs (division by zero) and an empty signal made ``np.max`` raise.
    """
    audio, sr = librosa.load(filepath, sr=sample_rate)
    if normalize is True:
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio * (1 / peak / 3.0)
    return audio, sr

class FeatureExtractor:
    """STFT / mel-spectrogram analysis and resynthesis for one audio clip.

    ``windowLength`` is used both as the FFT size and as the window length.
    ``overlap`` is passed to librosa as the *hop length* (samples between
    successive frames). Every transform uses the same periodic Hamming window.
    """

    def __init__(self, audio, *, windowLength, overlap, sample_rate):
        # ``scipy.signal.hamming`` was removed in SciPy 1.13; the window lives in
        # ``scipy.signal.windows`` (available since SciPy 1.1).
        from scipy.signal.windows import hamming

        self.audio = audio
        self.ffT_length = windowLength
        self.window_length = windowLength
        self.overlap = overlap
        self.sample_rate = sample_rate
        # sym=False produces the periodic window intended for spectral analysis.
        self.window = hamming(self.window_length, sym=False)

    def get_stft_spectrogram(self):
        """Complex STFT of ``self.audio`` using this extractor's window and hop."""
        return librosa.stft(self.audio, n_fft=self.ffT_length, win_length=self.window_length, hop_length=self.overlap,
                            window=self.window, center=True)

    def get_audio_from_stft_spectrogram(self, stft_features):
        """Inverse of :meth:`get_stft_spectrogram` (same window and hop)."""
        return librosa.istft(stft_features, win_length=self.window_length, hop_length=self.overlap,
                             window=self.window, center=True)

    def get_mel_spectrogram(self):
        """Power (``power=2.0``) mel spectrogram of ``self.audio``.

        Analysis uses the same window and window length that
        :meth:`get_audio_from_mel_spectrogram` assumes for the inverse.
        Previously librosa's default Hann window was used for analysis while
        the inverse assumed Hamming.
        """
        # ``y`` is keyword-only in librosa >= 0.10; the keyword works on older releases too.
        return librosa.feature.melspectrogram(y=self.audio, sr=self.sample_rate, power=2.0, pad_mode='reflect',
                                              n_fft=self.ffT_length, hop_length=self.overlap,
                                              win_length=self.window_length, window=self.window, center=True)

    def get_audio_from_mel_spectrogram(self, M):
        """Reconstruct audio from a power mel spectrogram ``M`` (32 Griffin-Lim iterations)."""
        return librosa.feature.inverse.mel_to_audio(M, sr=self.sample_rate, n_fft=self.ffT_length,
                                                    hop_length=self.overlap,
                                                    win_length=self.window_length, window=self.window,
                                                    center=True, pad_mode='reflect', power=2.0, n_iter=32, length=None)

In [None]:
# Load clean reference #9 at 16 kHz; read_audio peak-normalises it by default.
audio, _ = read_audio(clean_paths[9], 16000)
audio

In [None]:
def plot(file, fname, n_fft=None, hop_length=None, win_length=None, sr=16000):
    """Plot, save and show the dB-scaled STFT magnitude of an audio file.

    Args:
        file: path of the audio file (loaded and resampled to ``sr``).
        fname: output image path without extension; ``.jpg`` is appended.
        n_fft, hop_length, win_length: STFT settings. Each one left as ``None``
            falls back to ``hparams['n_fft_den']``, ``hparams['hop_size_den']``
            or ``hparams['win_size_den']``, so ``hparams`` must already be
            defined in that case.
        sr: sampling rate used for loading and for the time axis.
    """
    if n_fft is None:
        n_fft = hparams['n_fft_den']
    if hop_length is None:
        hop_length = hparams['hop_size_den']
    if win_length is None:
        win_length = hparams['win_size_den']

    wav = librosa.load(file, sr=sr)[0]
    stft = librosa.stft(y=wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    print("STFT: ", stft.shape)

    fig = plt.figure()
    magnitude_db = librosa.amplitude_to_db(np.abs(stft), ref=np.max)
    # Pass sr/hop_length so the time axis is in real seconds; without them
    # specshow assumes librosa's defaults (22050 Hz, hop 512).
    librosa.display.specshow(magnitude_db, sr=sr, hop_length=hop_length, y_axis='log', x_axis='time')
    plt.title('Power spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    # Save *before* showing. With Jupyter's inline backend plt.show() closes the
    # figure, so the old savefig-after-show wrote an empty image.
    fig.savefig(fname + ".jpg")
    plt.show()
    plt.close(fig)

In [None]:
# Plot and save the spectrogram of clean file #9 ("Аудио № 9" is Russian for
# "Audio No. 9" and becomes the image file name).
# NOTE: plot() reads `hparams`, which is defined in a later cell; on a fresh
# kernel run that cell first.
plot(clean_paths[9], "Аудио № 9")

In [None]:
# Inline player for the clean clip loaded above (last expression, so it renders).
IPython.display.Audio(data=audio, rate=16000)

In [None]:
# Feature extractor over the clip: 256-sample window (16 ms at 16 kHz) and a
# 64-sample hop (75% overlap between frames).
superAudio = FeatureExtractor(audio, windowLength=256, overlap=round(256/4), sample_rate=16000)

In [None]:
# Actually call the method (the bare attribute only displayed the bound method)
# and show the mel spectrogram's shape instead of dumping the whole array.
mel_spec = superAudio.get_mel_spectrogram()
mel_spec.shape

In [None]:
# Hyper-parameters shared by the audio helpers, the dataset and training code.
# Much of the audio section appears inherited from a Tacotron-2 / WaveNet-vocoder
# style config; the *_den entries configure the STFT used by the denoiser
# (DenoisingDataset.get_spec and plot()).
hparams = dict(
    num_mels=80,  # Number of mel channels (n_mels of the mel filterbank in _build_mel_basis)
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value

    # For cases of OOM (not really recommended; only use if facing unsolvable OOM errors,
    # also consider clipping your samples to smaller chunks).
    # NOTE(review): the original comment ties this to `clip_mels_length` and
    # `outputs_per_step`, neither of which exists in this dict.
    max_mel_frames=900,

    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
    # instead of librosa + Griffin-Lim (see _stft / inv_*_spectrogram).
    # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
    # Does not work if n_fft is not a multiple of hop_size!
    use_lws=False,

    n_fft=800,  # FFT size; the window is zero-padded up to this length
    hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000Hz, 800 = 50 ms (0.05 * sample_rate); librosa uses n_fft if None
    sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)

    # STFT settings for the denoising model's linear spectrograms.
    n_fft_den=512,
    hop_size_den=160,
    win_size_den=400,

    frame_shift_ms=None,  # Only used when hop_size is None (see get_hop_size). (Recommended: 12.5)

    # Mel and linear spectrogram normalization/scaling and clipping.
    # Whether to normalize spectrograms to a predefined range (parameters below).
    signal_normalization=False,
    allow_clipping_in_normalization=False,  # Only relevant if signal_normalization = True
    # Whether to scale the data to be symmetric around 0 (also multiplies the output
    # range by 2; faster and cleaner convergence).
    symmetric_mels=False,
    # Max absolute value of the data. If symmetric, data will be in [-max, max], else
    # [0, max] (must not be too big to avoid gradient explosion, nor too small for
    # fast convergence).
    max_abs_value=4.,
    # Whether to rescale to [0, 1] for wavenet (better audio quality).
    normalize_for_wavenet=False,
    # Whether to clip to [-max, max] before training/synthesizing with wavenet (better audio quality).
    clip_for_wavenet=False,

    # Contribution by @begeekmyfriend
    # Spectrogram pre-emphasis (lfilter: reduces spectrogram noise and helps model
    # certitude levels; also allows for better G&L phase reconstruction).
    preemphasize=False,  # whether to apply the filter
    preemphasis=0.97,  # filter coefficient

    # Limits
    min_level_db=-100,  # dB floor used by _amp_to_db and the (de)normalisation
    ref_level_db=20,  # subtracted after dB conversion, added back on inversion
    # Set fmin to 55 if your speaker is male; if female, 95 should help take off noise.
    # (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    fmin=55,
    fmax=7600,  # To be increased/reduced depending on data; must be <= sample_rate / 2.

    # Griffin-Lim
    # Exponent applied to the magnitude before phase reconstruction (only used in
    # G&L/LWS inversion); values between 1.2 and 1.5 are usually a good choice.
    power=1.5,
    # Number of G&L iterations; typically 30 is enough but 60 helps ensure convergence.
    griffin_lim_iters=60,
    ###########################################################################################################################################

    # NOTE(review): N / img_size / fps look like video (lip-sync) settings; they are
    # not read by the code visible here -- TODO confirm.
    N=25,
    img_size=96,
    fps=25,

    # Training settings.
    n_gpu=1,
    batch_size=16,
    num_workers=32,
    initial_learning_rate=1e-3,
    reduced_learning_rate=None,
    nepochs=200,
    ckpt_freq=1,
    validation_interval=3,

    wav_step_size=16000,  # samples per training crop (1 s at 16 kHz), see crop_audio_window
    mel_step_size=16,  # mel frames required per mel window, see get_segmented_mels
    spec_step_size=100,
    wav_step_overlap=3200
)


In [None]:
import librosa
import librosa.filters
import numpy as np
from scipy import signal

def load_wav(path, sr):
    """Load an audio file resampled to ``sr`` and return only the sample array."""
    samples, _rate = librosa.load(path, sr=sr)
    return samples

def preemphasis(wav, k, preemphasize=True):
    """Apply the pre-emphasis filter y[n] = x[n] - k * x[n-1]; return ``wav`` untouched when disabled."""
    if not preemphasize:
        return wav
    numerator, denominator = [1, -k], [1]
    return signal.lfilter(numerator, denominator, wav)

def inv_preemphasis(wav, k, inv_preemphasize=True):
    """Undo ``preemphasis`` with the IIR filter y[n] = x[n] + k * y[n-1]; pass-through when disabled."""
    if not inv_preemphasize:
        return wav
    numerator, denominator = [1], [1, -k]
    return signal.lfilter(numerator, denominator, wav)

def get_hop_size(hparams):
    """Return the STFT hop in samples.

    Uses ``hparams['hop_size']`` when set; otherwise derives it from
    ``frame_shift_ms`` and ``sample_rate`` (which must then be set).
    """
    explicit_hop = hparams['hop_size']
    if explicit_hop is not None:
        return explicit_hop
    assert hparams['frame_shift_ms'] is not None
    return int(hparams['frame_shift_ms'] / 1000 * hparams['sample_rate'])

def linearspectrogram(wav, hparams):
    """Linear-frequency magnitude spectrogram in dB relative to ``ref_level_db``.

    Applies optional pre-emphasis, the STFT, dB conversion, and (when
    ``signal_normalization`` is set) range normalisation.
    """
    emphasized = preemphasis(wav, hparams['preemphasis'], hparams['preemphasize'])
    magnitude = np.abs(_stft(emphasized, hparams))
    spec_db = _amp_to_db(magnitude, hparams) - hparams['ref_level_db']
    return _normalize(spec_db, hparams) if hparams['signal_normalization'] else spec_db

def melspectrogram(wav, hparams):
    """Mel spectrogram in dB relative to ``ref_level_db``.

    Same pipeline as ``linearspectrogram``, with the STFT magnitude projected
    onto the mel basis before dB conversion.
    """
    emphasized = preemphasis(wav, hparams['preemphasis'], hparams['preemphasize'])
    magnitude = np.abs(_stft(emphasized, hparams))
    mel_db = _amp_to_db(_linear_to_mel(magnitude, hparams), hparams) - hparams['ref_level_db']
    return _normalize(mel_db, hparams) if hparams['signal_normalization'] else mel_db

def inv_linear_spectrogram(linear_spectrogram, hparams):
    """Convert a linear spectrogram (as produced by ``linearspectrogram``) back to a waveform.

    Steps: optional de-normalisation, dB -> linear amplitude, phase
    reconstruction with LWS (``use_lws``) or Griffin-Lim, then undoing pre-emphasis.

    Args:
        linear_spectrogram: dB spectrogram, normalised if ``signal_normalization``.
        hparams: hyper-parameter dict (see the ``hparams`` cell).
    Returns:
        The reconstructed waveform.
    """
    # ``hparams`` is a plain dict, so it must be indexed. The previous attribute
    # access (``hparams.signal_normalization``) raised AttributeError.
    if hparams['signal_normalization']:
        D = _denormalize(linear_spectrogram, hparams)
    else:
        D = linear_spectrogram

    S = _db_to_amp(D + hparams['ref_level_db'])  # back to linear amplitude

    if hparams['use_lws']:
        processor = _lws_processor(hparams)
        D = processor.run_lws(S.astype(np.float64).T ** hparams['power'])
        y = processor.istft(D).astype(np.float32)
    else:
        y = _griffin_lim(S ** hparams['power'], hparams)
    return inv_preemphasis(y, hparams['preemphasis'], hparams['preemphasize'])

def inv_mel_spectrogram(mel_spectrogram, hparams):
    """Convert a mel spectrogram (as produced by ``melspectrogram``) back to a waveform.

    Steps: optional de-normalisation, dB -> amplitude, mel -> linear via the
    pseudo-inverse basis, phase reconstruction with LWS or Griffin-Lim, then
    undoing pre-emphasis.

    Args:
        mel_spectrogram: dB mel spectrogram, normalised if ``signal_normalization``.
        hparams: hyper-parameter dict (see the ``hparams`` cell).
    Returns:
        The reconstructed waveform.
    """
    # ``hparams`` is a plain dict; the previous attribute access raised AttributeError.
    if hparams['signal_normalization']:
        D = _denormalize(mel_spectrogram, hparams)
    else:
        D = mel_spectrogram

    S = _mel_to_linear(_db_to_amp(D + hparams['ref_level_db']), hparams)  # back to linear

    if hparams['use_lws']:
        processor = _lws_processor(hparams)
        D = processor.run_lws(S.astype(np.float64).T ** hparams['power'])
        y = processor.istft(D).astype(np.float32)
    else:
        y = _griffin_lim(S ** hparams['power'], hparams)
    return inv_preemphasis(y, hparams['preemphasis'], hparams['preemphasize'])

def _lws_processor(hparams):
    """Build an ``lws`` processor from ``hparams``; ``lws`` is imported lazily (only needed when use_lws is on)."""
    import lws
    # NOTE(review): n_fft goes into the first (window-size) argument and win_size
    # into ``fftsize``. This is harmless while both are 800, but confirm it
    # against the lws API before changing either value.
    hop = get_hop_size(hparams)
    return lws.lws(hparams['n_fft'], hop, fftsize=hparams['win_size'], mode="speech")

def _griffin_lim(S, hparams):
    """librosa implementation of Griffin-Lim
    Based on https://github.com/librosa/librosa/issues/434
    """
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(np.complex)
    y = _istft(S_complex * angles, hparams)
    for i in range(hparams['griffin_lim_iters']):
        angles = np.exp(1j * np.angle(_stft(y, hparams)))
        y = _istft(S_complex * angles, hparams)
    return y

def _stft(y, hparams):
    """Complex STFT of ``y`` via librosa, or via LWS (transposed) when ``use_lws`` is set."""
    if not hparams['use_lws']:
        return librosa.stft(y=y, n_fft=hparams['n_fft'], hop_length=get_hop_size(hparams), win_length=hparams['win_size'])
    # Presumably transposed to match librosa's (freq, time) layout -- TODO confirm.
    return _lws_processor(hparams).stft(y).T

def _istft(y, hparams):
    """Inverse STFT with librosa using the hop and window length from ``hparams``."""
    hop = get_hop_size(hparams)
    return librosa.istft(y, hop_length=hop, win_length=hparams['win_size'])

# Conversions
# Module-level caches for the mel filterbank and its pseudo-inverse. They are
# built lazily from the hparams passed on first use and never rebuilt, so reset
# them to None if the mel settings change.
_mel_basis = None
_inv_mel_basis = None

def _linear_to_mel(spectogram, hparams):
    """Project a linear-frequency spectrogram onto the (cached) mel basis.

    NOTE: the basis is built once from the first ``hparams`` seen; later calls
    with different settings reuse the stale basis.
    """
    global _mel_basis
    basis = _mel_basis
    if basis is None:
        basis = _mel_basis = _build_mel_basis(hparams)
    return np.dot(basis, spectogram)

def _mel_to_linear(mel_spectrogram, hparams):
    """Approximate the linear spectrogram from a mel one via the (cached) pseudo-inverse basis."""
    global _inv_mel_basis
    inv_basis = _inv_mel_basis
    if inv_basis is None:
        inv_basis = _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
    # The pseudo-inverse can produce tiny/negative values; floor them at 1e-10.
    return np.maximum(1e-10, np.dot(inv_basis, mel_spectrogram))

def _build_mel_basis(hparams):
    """Build the librosa mel filterbank described by ``hparams`` (fmax must not exceed Nyquist)."""
    rate = hparams['sample_rate']
    assert hparams['fmax'] <= rate // 2
    return librosa.filters.mel(sr=rate, n_fft=hparams['n_fft'], n_mels=hparams['num_mels'],
                               fmin=hparams['fmin'], fmax=hparams['fmax'])

def _amp_to_db(x, hparams):
    min_level = np.exp(hparams['min_level_db'] / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S, hparams):
    if hparams['allow_clipping_in_normalization']:
        if hparams['symmetric_mels']:
            return np.clip((2 * hparams['max_abs_value']) * ((S - hparams['min_level_db']) / (-hparams['min_level_db'])) - hparams['max_abs_value'],
                           -hparams['max_abs_value'], hparams['max_abs_value'])
        else:
            return np.clip(hparams['max_abs_value'] * ((S - hparams['min_level_db']) / (-hparams['min_level_db'])), 0, hparams['max_abs_value'])
    
    assert S.max() <= 0 and S.min() - hparams['min_level_db'] >= 0
    if hparams['symmetric_mels']:
        return (2 * hparams['max_abs_value']) * ((S - hparams['min_level_db']) / (-hparams['min_level_db'])) - hparams['max_abs_value']
    else:
        return hparams['max_abs_value'] * ((S - hparams['min_level_db']) / (-hparams['min_level_db']))

def _denormalize(D, hparams):
    if hparams['allow_clipping_in_normalization']:
        if hparams['symmetric_mels']:
            return (((np.clip(D, -hparams['max_abs_value'],
                              hparams['max_abs_value']) + hparams['max_abs_value']) * -hparams['min_level_db'] / (2 * hparams['max_abs_value']))
                    + hparams.min_level_db)
        else:
            return ((np.clip(D, 0, hparams['max_abs_value']) * -hparams['min_level_db'] / hparams['max_abs_value']) + hparams['min_level_db'])
    
    if hparams['symmetric_mels']:
        return (((D + hparams['max_abs_value']) * -hparams['min_level_db'] / (2 * hparams['max_abs_value'])) + hparams.min_level_db)
    else:
        return ((D * -hparams['min_level_db'] / hparams['max_abs_value']) + hparams['min_level_db'])


def db_from_amp(x):
    """Amplitude to dB with a 1e-5 floor (i.e. a -100 dB minimum)."""
    floored = np.maximum(1e-5, x)
    return 20. * np.log10(floored)

def amp_from_db(x):
    """dB to amplitude: 10 ** (x / 20)."""
    exponent = x / 20.
    return np.power(10., exponent)

def angle(z):
    """Element-wise phase of ``z`` in radians, computed as arctan2(imag, real)."""
    real_part = np.real(z)
    imag_part = np.imag(z)
    return np.arctan2(imag_part, real_part)

def cast_complex(x):
    """Return a complex64 copy of array ``x``."""
    return x.astype(np.complex64)

def make_complex(mag, phase):
    """Combine a magnitude and a phase (radians) into mag * (cos(phase) + i*sin(phase))."""
    mag_c = cast_complex(mag)
    phase_c = cast_complex(phase)
    real_part = np.cos(phase_c)
    imag_part = np.sin(phase_c)
    return mag_c * (real_part + 1j * imag_part)

def normalize_mag(x, min_val=-100, max_val=80):
    """Linearly map dB magnitudes from [min_val, max_val] to [0, 1]."""
    span = float(max_val - min_val)
    return (x - min_val) / span

def normalize_phase(x, min_val=-np.pi, max_val=np.pi):
    """Linearly map phases from [min_val, max_val] (default [-pi, pi]) to [0, 1]."""
    span = float(max_val - min_val)
    return (x - min_val) / span

def unnormalize_mag(y, min_val=-100, max_val=80):
    """Inverse of ``normalize_mag``: map [0, 1] back to dB in [min_val, max_val]."""
    span = float(max_val - min_val)
    return span * y + min_val

def unnormalize_phase(y, min_val=-np.pi, max_val=np.pi):
    """Inverse of ``normalize_phase``: map [0, 1] back to radians in [min_val, max_val]."""
    span = float(max_val - min_val)
    return span * y + min_val

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
# import audio.audio_utils as audio
# import audio.hparams as hparams
import random
import os
import librosa
     
class DenoisingDataset(Dataset):
    """Random-crop speech-denoising dataset built from (noisy mix, clean) file pairs.

    Each item comes from a randomly chosen noisy mixture in ``train_path`` and
    its clean counterpart (found with ``get_clean_file``). A random window of
    ``hparams['wav_step_size']`` samples is cropped, and one of three
    augmentations is applied to a short random span of the noisy window.

    Items are ``(inp_mel, inp_stft, gt_stft)`` float tensors:
      * ``inp_mel``  -- overlapping mel windows of the noisy mixture, with a
        singleton axis inserted at dim 1;
      * ``inp_stft`` -- normalised magnitude|phase STFT of the noisy window;
      * ``gt_stft``  -- the same representation of the clean window.

    NOTE: ``__getitem__`` ignores the requested index and samples a random file,
    so DataLoader shuffling/ordering does not change which audio is returned.
    """

    # Upper bound on resampling attempts in ``__getitem__``. Previously the loop
    # spun forever if no file could produce a valid crop (e.g. all too short).
    max_retries = 1000

    def __init__(self, train_path, sampling_rate):
        """
        Args:
            train_path: list of noisy-mixture wav paths.
            sampling_rate: rate every file is resampled to when loaded.
        """
        self.files = train_path
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Return one random training example (``index`` is ignored; see class doc).

        Raises:
            RuntimeError: if ``max_retries`` consecutive random files fail to
                load or cannot be cropped.
        """
        for _ in range(self.max_retries):
            fname = self.files[random.randint(0, len(self.files) - 1)]

            mel, stft, y = self.process_audio(fname)
            if mel is None or stft is None or y is None:
                continue

            inp_mel = torch.FloatTensor(np.array(mel)).unsqueeze(1)
            inp_stft = torch.FloatTensor(np.array(stft))
            gt_stft = torch.FloatTensor(np.array(y))
            return inp_mel, inp_stft, gt_stft

        raise RuntimeError(
            f'DenoisingDataset: no usable sample after {self.max_retries} attempts; '
            'check that the wav files exist and are long enough to crop.'
        )

    def process_audio(self, file):
        """Load a noisy mix and its clean reference and build model inputs/targets.

        Returns:
            ``(inp_mel, inp_stft, gt_stft)`` numpy arrays, or ``(None, None, None)``
            when either file cannot be loaded or no valid crop exists.
        """
        # Unreadable files are skipped. Only ordinary exceptions are caught; a
        # bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        try:
            mix_wav = load_wav(file, self.sampling_rate)
            clean_wav = load_wav(get_clean_file(file), self.sampling_rate)
        except Exception:
            return None, None, None

        # The pair may differ by a few samples; align them instead of letting the
        # subtraction below raise a broadcasting error.
        n_samples = min(len(mix_wav), len(clean_wav))
        mix_wav, clean_wav = mix_wav[:n_samples], clean_wav[:n_samples]

        # Additive noise component of the mixture (used by 'increase_noise').
        noise_wav = mix_wav - clean_wav

        # Random crop + augmentation. NOTE: the augmentation writes into a view of
        # ``mix_wav``, so ``mix_wav`` carries the augmented span afterwards.
        start_idx, gt_seg_wav, noisy_seg_wav = self.crop_audio_window(clean_wav, mix_wav, noise_wav)
        if start_idx is None or gt_seg_wav is None or noisy_seg_wav is None:
            return None, None, None

        # Normalised magnitude|phase STFTs of the clean target and noisy input.
        gt_spec = self.get_spec(gt_seg_wav)
        noisy_spec = self.get_spec(noisy_seg_wav)

        # Mel windows must come from the noisy *mixture*, the same signal as the
        # noisy STFT input. The previous code passed the noise-only residual
        # ``mix - clean`` here, so the mel branch never saw any speech.
        noisy_mels = self.get_segmented_mels(start_idx, mix_wav)
        if noisy_mels is None:
            return None, None, None

        inp_mel = np.array(noisy_mels)
        inp_stft = np.array(noisy_spec)
        gt_stft = np.array(gt_spec)
        return inp_mel, inp_stft, gt_stft

    def crop_audio_window(self, gt_wav, noisy_wav, random_wav):
        """Pick a random ``hparams['wav_step_size']``-sample window and augment it.

        Args:
            gt_wav: clean reference waveform.
            noisy_wav: noisy mixture waveform (same length as ``gt_wav``).
            random_wav: noise waveform used by the 'increase_noise' augmentation.

        Returns:
            ``(start_idx, gt_seg_wav, noisy_seg_wav)``, or ``(None, None, None)``
            if the clip is too short. ``start_idx`` is at least 1280, leaving room
            for the mel context that ``get_segmented_mels`` reads before it.

        NOTE: the returned segments are slices (views), so the augmentation also
        modifies ``noisy_wav`` in place.
        """
        if gt_wav.shape[0] - hparams['wav_step_size'] <= 1280:
            return None, None, None

        start_idx = np.random.randint(low=1280, high=gt_wav.shape[0] - hparams['wav_step_size'])
        end_idx = start_idx + hparams['wav_step_size']
        gt_seg_wav = gt_wav[start_idx : end_idx]
        if len(gt_seg_wav) != hparams['wav_step_size']:
            return None, None, None

        noisy_seg_wav = noisy_wav[start_idx : end_idx]
        if len(noisy_seg_wav) != hparams['wav_step_size']:
            return None, None, None

        # Data augmentation on a random span of up to 3199 samples.
        aug_steps = np.random.randint(low=0, high=3200)
        aug_start_idx = np.random.randint(low=0, high=hparams['wav_step_size'] - aug_steps)
        aug_end_idx = aug_start_idx + aug_steps

        aug = random.choice(['zero_speech', 'reduce_speech', 'increase_noise'])
        if aug == 'zero_speech':
            # Silence the span completely.
            noisy_seg_wav[aug_start_idx:aug_end_idx] = 0.0
        elif aug == 'reduce_speech':
            # Replace the span with clean speech at 10% amplitude.
            noisy_seg_wav[aug_start_idx:aug_end_idx] = 0.1 * gt_seg_wav[aug_start_idx:aug_end_idx]
        elif aug == 'increase_noise':
            # Clean speech plus twice the noise over the span.
            random_seg_wav = random_wav[start_idx : end_idx]
            noisy_seg_wav[aug_start_idx:aug_end_idx] = gt_seg_wav[aug_start_idx:aug_end_idx] + (2 * random_seg_wav[aug_start_idx:aug_end_idx])

        return start_idx, gt_seg_wav, noisy_seg_wav

    def crop_mels(self, start_idx, noisy_wav):
        """Mel spectrogram of the 3200-sample window starting at ``start_idx``.

        Returns the transposed mel spectrogram (frames first) with its last
        frame dropped, or ``None`` if the window runs past the end of ``noisy_wav``.
        """
        end_idx = start_idx + 3200
        noisy_seg_wav = noisy_wav[start_idx : end_idx]
        if len(noisy_seg_wav) != 3200:
            return None

        spec = melspectrogram(noisy_seg_wav, hparams).T
        return spec[:-1]

    def get_segmented_mels(self, start_idx, noisy_wav):
        """Stack overlapping mel windows over the crop.

        One window every 40 samples across ``hparams['wav_step_size']`` samples
        from ``start_idx``. Each window starts 1280 samples earlier and must have
        exactly ``hparams['mel_step_size']`` frames. Returns the stacked windows
        (each transposed back to mels x frames), or ``None`` if any is incomplete.
        """
        if start_idx - 1280 < 0:
            return None

        mels = []
        for i in range(start_idx, start_idx + hparams['wav_step_size'], 40):
            m = self.crop_mels(i - 1280, noisy_wav)
            if m is None or m.shape[0] != hparams['mel_step_size']:
                return None
            mels.append(m.T)

        return np.asarray(mels)

    def get_spec(self, wav):
        """Normalised dB magnitude and phase of the STFT, concatenated per frame.

        Uses the denoiser STFT settings (``n_fft_den``/``hop_size_den``/
        ``win_size_den``) and drops the final frame. Magnitude columns come
        first, phase columns second (257 + 257 = 514 columns for n_fft_den=512).
        """
        stft = librosa.stft(y=wav, n_fft=hparams['n_fft_den'],
                            hop_length=hparams['hop_size_den'], win_length=hparams['win_size_den']).T
        stft = stft[:-1]

        mag = db_from_amp(np.abs(stft))
        phase = angle(stft)

        norm_mag = normalize_mag(mag)
        norm_phase = normalize_phase(phase)
        return np.concatenate((norm_mag, norm_phase), axis=1)


def load_data(train_path, num_workers, batch_size=4, sampling_rate=16000, shuffle=False):
    """Wrap ``DenoisingDataset(train_path, sampling_rate)`` in a ``DataLoader``.

    Args:
        train_path: list of noisy-mixture wav paths.
        num_workers: DataLoader worker processes.
        batch_size: examples per batch.
        sampling_rate: rate audio is resampled to.
        shuffle: passed to the DataLoader (affects index order only).
    """
    return DataLoader(
        DenoisingDataset(train_path, sampling_rate),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )

In [None]:
# dataloader=load_data(train_noisy_paths, 4, batch_size=4, sampling_rate=16000, shuffle=False)

In [None]:
# for idx, batch in enumerate(dataloader):
#     print(f'{idx}: {batch[1].size()}')

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import math
import numpy as np

class MyModel2D(nn.Module):
    """Conv encoder / transposed-conv decoder with additive skip connections.

    Takes a single-channel 4-D tensor (``conv1`` has 1 input channel) and
    returns a single-channel tensor (``conv10`` has 1 output channel).
    NOTE(review): the input height/width this was designed for is not stated.
    The fixed paddings (``pad``, ``pad2``, ``pad3``, ``pad4``) only make the
    skip additions line up for particular sizes -- TODO document the expected
    input shape.
    """
    def __init__(self):
        super(MyModel2D, self).__init__()
        # ZeroPad2d takes (left, right, top, bottom): 3 rows above, 4 below, no width padding.
        self.pad = nn.ZeroPad2d((0, 0, 3, 4))
        # 9x16 kernel without padding: height ends up H-1 (given the pad above),
        # width shrinks by 15.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(9, 16), stride=(1, 1), padding=(0, 0), bias=False)
        # NOTE: this single PReLU instance is reused after every layer in
        # forward(), so all activation sites share the same learnable slope.
        self.relu = nn.PReLU()
        self.batchnorm1 = nn.BatchNorm2d(16)
        # Encoder: 3x1 convs with stride 2 along height (width unchanged).
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        self.batchnorm2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        self.batchnorm3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        self.batchnorm4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        # NOTE(review): batchnorm5 is registered (its parameters/buffers are in the
        # state_dict) but its use in forward() is commented out.
        self.batchnorm5 = nn.BatchNorm2d(256)
        
        # Bottleneck: stride-8 conv down along height and the matching stride-8
        # transposed conv back up.
        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=(3, 1), stride=(8, 1), padding=(1, 0), bias=False)
        self.batchnorm5_1 = nn.BatchNorm2d(512)
        self.conv6_1 = nn.ConvTranspose2d(512, 256, kernel_size=(3, 1), stride=(8, 1), padding=(1, 0), bias=False)
        self.batchnorm6_1 = nn.BatchNorm2d(256)
        
        # Decoder: 3x1 transposed convs with stride 2 along height.
        self.conv6 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        self.batchnorm6 = nn.BatchNorm2d(128)
        self.conv7 = nn.ConvTranspose2d(128, 64, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        self.batchnorm7 = nn.BatchNorm2d(64)
        self.conv8 = nn.ConvTranspose2d(64, 32, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        self.batchnorm8 = nn.BatchNorm2d(32)
        self.conv9 = nn.ConvTranspose2d(32, 16, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), bias=False)
        self.batchnorm9 = nn.BatchNorm2d(16)
        self.spatialdropout = nn.Dropout2d(0.2)
        # Final projection to one channel (the only conv here that keeps its bias).
        self.conv10 = nn.Conv2d(16, 1, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))

        # Height-only zero pads used in forward(): pad2/pad3 align decoder outputs
        # with skip tensors, pad4 appends one row at the bottom of the output.
        self.pad2 = nn.ZeroPad2d((0, 0, 2, 1))
        self.pad3 = nn.ZeroPad2d((0, 0, 1, 2))
        self.pad4 = nn.ZeroPad2d((0, 0, 0, 1))
    def forward(self, x):
        """Run the encoder-decoder.

        Encoder blocks are conv -> PReLU -> BatchNorm; the pre-activation conv
        outputs are kept as skip tensors. Only skip9 (conv1), skip7 (conv3) and
        skip6_1 (conv5) are added back in the decoder; skip8 and skip6 are
        computed but their additions are commented out.
        """
        x = self.pad(x)
        skip9 = self.conv1(x)
        x = self.relu(skip9)
        x = self.batchnorm1(x)
        skip8 = self.conv2(x)
        x = self.relu(skip8)
        x = self.batchnorm2(x)
        skip7 = self.conv3(x)
        x = self.relu(skip7)
        x = self.batchnorm3(x)
        skip6 = self.conv4(x)
        x = self.relu(skip6)
        x = self.batchnorm4(x)
        skip6_1 = self.conv5(x)
        x = self.relu(skip6_1)
        # batchnorm5 intentionally(?) skipped here:
        #x = self.batchnorm5(x)
        
        # Bottleneck down/up, then add the conv5 skip.
        x = self.conv5_1(x)
        x = self.relu(x)
        x = self.batchnorm5_1(x)
        x = self.conv6_1(x)
        x = x + skip6_1
        
        
        
        # Decoder stage 1 (skip6 addition disabled).
        x = self.conv6(x)
        #x = x + skip6
        x = self.relu(x)
        x = self.batchnorm6(x)
        # Decoder stage 2: pad height (2 top, 1 bottom) so it matches skip7.
        x = self.conv7(x)
        x = self.pad2(x)
        x = x + skip7
        x = self.relu(x)
        x = self.batchnorm7(x)
        # Decoder stage 3 (skip8 addition disabled).
        x = self.conv8(x)
        #x = x + skip8
        x = self.relu(x)
        x = self.batchnorm8(x)
        # Decoder stage 4: pad height (1 top, 2 bottom) so it matches skip9.
        x = self.conv9(x)
        x = self.pad3(x)
        x = x + skip9
        x = self.relu(x)
        x = self.batchnorm9(x)
        x = self.spatialdropout(x)
        x = self.conv10(x)
        
        # Append one zero row at the bottom of the output.
        x = self.pad4(x)
        return x


In [None]:
class encoder(nn.Module):
    """Single-channel 2-D conv network: encoder, repeated bottleneck, decoder.

    ``forward`` inserts a channel axis at dim 1 and the first conv expects one
    input channel, so the input should be 3-D (batch, H, W). The output is
    (batch, 1, H, W) after a final LeakyReLU.

    All 3x3 convolutions (and transposed convolutions) use stride 1 and
    padding 1, so they keep the spatial size. ``conv0_1`` is 1x1 with stride 2
    but always follows a 2x bicubic upsample, so the spatial size is
    unchanged end to end.

    ``norm1`` ... ``norm9`` are constructed but never applied in ``forward``;
    they are kept so existing state_dicts still load.

    Args:
        verbose: if True, print stage markers and intermediate tensor sizes
            (the behaviour of the original debugging version). Defaults to
            False so training loops are not flooded with per-batch output.
    """

    # Encoder convs in application order; each is followed by LeakyReLU.
    _ENCODER_LAYERS = tuple('conv{}_{}'.format(i, j) for i in range(1, 10) for j in (1, 2))
    # Rounds of (upsample -> conv0_1 -> conv0_2); weights are shared across rounds.
    _BOTTLENECK_ROUNDS = 4
    # Decoder stages deconv1_* ... deconv5_*, followed by the 1x1 deconv6 projection.
    _DECODER_STAGES = 5

    def __init__(self, verbose=False):
        super().__init__()
        self.verbose = verbose

        # Layers are created in the original order so parameter initialisation
        # (and therefore seeded runs) and state_dict keys are unchanged.
        self.conv1_1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.BatchNorm2d(16)
        self.act = nn.LeakyReLU()
        self.conv1_2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm2d(16)
        self.conv2_2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv3_1 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.norm3 = nn.BatchNorm2d(32)
        self.conv3_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.norm4 = nn.BatchNorm2d(32)
        self.conv4_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv5_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.norm5 = nn.BatchNorm2d(64)
        self.conv5_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv6_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.norm6 = nn.BatchNorm2d(64)
        self.conv6_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv7_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.norm7 = nn.BatchNorm2d(128)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv8_1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.norm8 = nn.BatchNorm2d(128)
        self.conv8_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv9_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.norm9 = nn.BatchNorm2d(256)
        self.conv9_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        # Bottleneck: 2x upsample followed by a stride-2 1x1 conv keeps H and W.
        self.conv0_1 = nn.Conv2d(256, 256, kernel_size=1, stride=2, padding=0)
        self.conv0_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.up = nn.Upsample(scale_factor=2, mode='bicubic')

        self.deconv1_0 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.deconv1_1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.deconv1_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.deconv2_0 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.deconv2_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.deconv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.deconv3_0 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.deconv3_1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.deconv3_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.deconv4_0 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=1)
        self.deconv4_1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.deconv4_2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.deconv5_0 = nn.ConvTranspose2d(16, 4, kernel_size=3, stride=1, padding=1)
        self.deconv5_1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.deconv5_2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.deconv6 = nn.Conv2d(4, 1, kernel_size=1)

    def _debug_print(self, *items):
        """Print ``items`` only when the model was built with ``verbose=True``."""
        if self.verbose:
            print(*items)

    def _encode(self, x):
        """Apply conv1_1 ... conv9_2 in order, each followed by LeakyReLU."""
        for name in self._ENCODER_LAYERS:
            x = self.act(getattr(self, name)(x))
            self._debug_print(x.size())
        return x

    def _bottleneck(self, x):
        """Repeat (upsample -> conv0_1 -> act -> conv0_2 -> act), sharing weights."""
        for _ in range(self._BOTTLENECK_ROUNDS):
            x = self.up(x)
            x = self.act(self.conv0_1(x))
            x = self.act(self.conv0_2(x))
            self._debug_print(x.size())
        return x

    def _decode(self, x):
        """Apply each deconv stage (transposed conv, then two conv+act) and the 1x1 projection."""
        for stage in range(1, self._DECODER_STAGES + 1):
            # The transposed conv of each stage is not followed by an activation.
            x = getattr(self, 'deconv{}_0'.format(stage))(x)
            x = self.act(getattr(self, 'deconv{}_1'.format(stage))(x))
            x = self.act(getattr(self, 'deconv{}_2'.format(stage))(x))
            self._debug_print(x.size())
        x = self.deconv6(x)
        self._debug_print(x.size())
        return x

    def forward(self, audio_sequences):
        """Map a (batch, H, W) tensor to a (batch, 1, H, W) tensor."""
        self._debug_print('Start!!!')
        x = torch.unsqueeze(audio_sequences, 1)  # add the single input channel
        self._debug_print(x.size())

        x = self._encode(x)
        self._debug_print('Encoder finish')
        x = self._bottleneck(x)
        self._debug_print('Up finish')
        x = self._decode(x)
        self._debug_print('Decoder finish')

        return self.act(x)
        

In [None]:
import gc
# Free memory before training: first collect unreachable Python objects (which
# drops their tensor references), then ask PyTorch's CUDA caching allocator to
# release its unoccupied cached blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import numpy as np
import os
import argparse
# from models import *
# import audio.hparams as hparams 
# from scripts.data_loader import *
from tqdm import tqdm
import librosa
import torch
import torch.optim as optim
import cv2
import subprocess

# Global compute-device selection, read by the training / checkpoint code below.
use_cuda = torch.cuda.is_available()
print(f'use_cuda: {use_cuda}')
device = torch.device('cuda') if use_cuda else torch.device('cpu')


# Function to reconstruct the wav from the magnitude and phase representations
def reconstruct_wav(stft, n_freq=None):
    """Invert a stacked [magnitude; phase] spectrogram back to a waveform.

    Args:
        stft: 2-D array whose first ``n_freq`` rows hold the normalised
            magnitude (in dB) and whose remaining rows hold the normalised
            phase; columns are frames. Callers in this notebook pass the
            transpose of a (T, 514) array.
        n_freq: number of magnitude rows. Defaults to half the row count
            (257 for the 514-row layout used here), i.e. magnitude and phase
            are assumed to have the same number of frequency bins.

    Returns:
        The waveform produced by ``librosa.istft`` with the denoiser hop and
        window sizes from ``hparams``.
    """
    if n_freq is None:
        n_freq = stft.shape[0] // 2

    mag = stft[:n_freq, :]
    phase = stft[n_freq:, :]

    denorm_mag = unnormalize_mag(mag)
    denorm_phase = unnormalize_phase(phase)
    recon_mag = amp_from_db(denorm_mag)
    complex_arr = make_complex(recon_mag, denorm_phase)
    wav = librosa.istft(complex_arr, hop_length=hparams['hop_size_den'], win_length=hparams['win_size_den'])

    return wav
 
# Function to generate and save the sample audio/video files
def save_samples(gt_stft, inp_stft, output_stft, faces, epoch, checkpoint_dir, sample_rate=16000):
    """Write ground-truth, input and predicted audio (and optionally video) per batch item.

    Files go to ``checkpoint_dir/samples_step{epoch:04d}/`` as
    ``{i}_gt.wav``, ``{i}_inp.wav`` and ``{i}_pred.wav``; when ``faces`` is
    given, ``generate_video`` is also called with the predicted audio.

    Args:
        gt_stft, inp_stft, output_stft: batched spectrogram tensors; each item
            is transposed before being passed to ``reconstruct_wav``.
        faces: 5-D tensor permuted with (0, 2, 3, 4, 1) before being passed
            per item to ``generate_video`` (layout not visible here —
            presumably (B, C, T, H, W); TODO confirm). ``None`` skips video.
        epoch: used in the output folder name.
        checkpoint_dir: parent directory; created if missing.
        sample_rate: sample rate written into the WAV headers.
    """
    # librosa.output.write_wav was removed in librosa 0.8; it was a thin
    # wrapper over scipy.io.wavfile.write, which is used directly here.
    from scipy.io import wavfile

    gt_stft = gt_stft.detach().cpu().numpy()
    inp_stft = inp_stft.detach().cpu().numpy()
    output_stft = output_stft.detach().cpu().numpy()
    if faces is not None:
        faces = faces.permute(0, 2, 3, 4, 1).detach().cpu().numpy()

    folder = os.path.join(checkpoint_dir, "samples_step{:04d}".format(epoch))
    os.makedirs(folder, exist_ok=True)

    for step in range(gt_stft.shape[0]):
        for suffix, batch in (('gt', gt_stft), ('inp', inp_stft), ('pred', output_stft)):
            wav = reconstruct_wav(batch[step].T)
            wavfile.write(os.path.join(folder, '{}_{}.wav'.format(step, suffix)), sample_rate, wav)

        if faces is not None:
            generated_aud_fname = os.path.join(folder, str(step) + '_pred.wav')
            generated_vid_fname = os.path.join(folder, str(step) + '_pred')
            generate_video(faces[step], generated_aud_fname, generated_vid_fname)

    print("Saved samples:", folder)


def train(device, model, train_loader, test_loader, optimizer, epoch_resume, total_epochs, checkpoint_dir, args):
    """Train the denoising model and return audio reconstructed from the last batch.

    Args:
        device: torch device each batch is moved to.
        model: network mapping the noisy STFT batch to a clean STFT batch.
        train_loader: iterable of ``(inp_mel, inp_stft, gt_stft)`` batches;
            only the two STFT tensors are used.
        test_loader, checkpoint_dir, args: only needed by the (currently
            disabled) checkpoint / validation steps.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch_resume: last completed epoch; training starts at ``epoch_resume + 1``.
        total_epochs: last epoch number to run (inclusive).

    Returns:
        ``(input_wav, predicted_wav, ground_truth_wav)`` reconstructed with
        ``reconstruct_wav`` from the first item of the last batch processed,
        or ``None`` if no batch was processed. Training stops after the first
        epoch numbered > 1 (a quick-listen shortcut kept from the original).
    """
    # NOTE(review): this objective is MSE, while validate() reports L1 —
    # confirm which loss is intended for training.
    loss_fn = nn.MSELoss()
    last_batch = None

    for epoch in range(epoch_resume + 1, total_epochs + 1):
        print("Epoch %d" % epoch)
        model.train()

        total_loss = 0.0
        step = 0
        progress_bar = tqdm(train_loader)

        for (_inp_mel, inp_stft, gt_stft) in progress_bar:
            optimizer.zero_grad()

            # Only the STFTs feed the model, so the mel batch is not copied to the device.
            inp_stft = inp_stft.to(device)                                      # BxTx514
            gt_stft = gt_stft.to(device)                                        # BxTx514

            # Generate the clean stft and back-propagate the reconstruction loss.
            output_stft = model(inp_stft)                                       # BxTx514
            loss = loss_fn(output_stft, gt_stft)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            step += 1
            progress_bar.set_description('Loss: {}'.format(total_loss / step))

            # Keep (detached) device tensors and convert to numpy only once at
            # the end instead of forcing a device->host copy on every step.
            last_batch = (inp_stft.detach(), output_stft.detach(), gt_stft.detach())

        # Checkpointing / validation are disabled. To re-enable them (`args` is
        # a dict, so index it with ['...']):
        #   train_loss = total_loss / max(step, 1)
        #   if epoch % args['ckpt_freq'] == 0:
        #       save_checkpoint(model, optimizer, train_loss, checkpoint_dir, epoch)
        #   if epoch % args['validation_interval'] == 0:
        #       with torch.no_grad():
        #           validate(device, model, test_loader, epoch, checkpoint_dir)

        if epoch > 1:
            break

    if last_batch is None:
        return None
    return tuple(reconstruct_wav(np.transpose(t.cpu().numpy()[0])) for t in last_batch)

    
def validate(device, model, test_loader, epoch, checkpoint_dir, faces=None):
    """Compute the mean L1 loss of ``model`` over ``test_loader`` without gradients.

    Args:
        device: torch device each batch is moved to.
        model: denoising model; it is switched to eval mode.
        test_loader: sized iterable of ``(inp_mel, inp_stft, gt_stft)`` batches.
        epoch, checkpoint_dir: forwarded to ``save_samples``.
        faces: optional tensor forwarded to ``save_samples``. The loader does
            not yield faces, so sample export is skipped when this is None.

    Returns:
        The average per-batch L1 loss.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    print('\nEvaluating for {} steps'.format(len(test_loader)))

    l1_loss = nn.L1Loss()
    losses = []

    model.eval()
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for (_inp_mel, inp_stft, gt_stft) in test_loader:
            inp_stft = inp_stft.to(device)
            gt_stft = gt_stft.to(device)

            output_stft = model(inp_stft)
            losses.append(l1_loss(output_stft, gt_stft).item())

    if not losses:
        raise ValueError('test_loader yielded no batches; cannot compute a validation loss')

    averaged_loss = sum(losses) / len(losses)
    print("Validation loss: ", averaged_loss)

    # Save the GT / input / denoised samples of the last batch when possible.
    if faces is not None:
        save_samples(gt_stft, inp_stft, output_stft, faces, epoch, checkpoint_dir)
    else:
        print('Skipping sample export: no `faces` tensor was provided.')

    return averaged_loss

def save_checkpoint(model, optimizer, train_loss, checkpoint_dir, epoch):
    """Save model/optimizer state, loss and epoch to ``checkpoint_dir``.

    The file is named ``checkpoint_step{epoch:04d}.pt``; ``checkpoint_dir`` is
    created if it does not exist yet. The dict keys (``state_dict``,
    ``optimizer``, ``loss``, ``epoch``) are the ones ``load_checkpoint`` reads.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_step{:04d}.pt".format(epoch))

    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "loss": train_loss,
        "epoch": epoch,
    }, checkpoint_path)

    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    """Deserialise a ``torch.save`` checkpoint file.

    Reads the notebook-level ``use_cuda`` flag: without CUDA, every storage is
    kept where it is deserialised (CPU), so CUDA-saved checkpoints still load.
    """
    if not use_cuda:
        return torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)


def load_checkpoint(path, model, optimizer, reset_optimizer=False):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    State-dict keys are adapted to the current setup: with
    ``hparams['n_gpu'] > 1`` every key gets a ``module.`` prefix (unless it
    already starts with one); otherwise every ``module.`` occurrence is removed.

    Args:
        path: checkpoint file, read through ``_load``.
        model: module receiving the (renamed) state dict.
        optimizer: optimizer whose state is restored unless ``reset_optimizer``.
        reset_optimizer: if True, skip optimizer state and report epoch 0.

    Returns:
        ``(model, epoch_resume)`` where ``epoch_resume`` is the stored epoch,
        or 0 when ``reset_optimizer`` is True.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)

    def _rename(key):
        if hparams['n_gpu'] > 1:
            return key if key.startswith('module.') else 'module.' + key
        return key.replace('module.', '')

    renamed = {_rename(key): value for key, value in checkpoint["state_dict"].items()}
    model.load_state_dict(renamed)

    if reset_optimizer:
        return model, 0

    optimizer_state = checkpoint["optimizer"]
    if optimizer_state is not None:
        print("Load optimizer state from {}".format(path))
        optimizer.load_state_dict(optimizer_state)

    epoch_resume, loss = checkpoint['epoch'], checkpoint['loss']

    print("Model resumed for training...")
    print("Epoch: ", epoch_resume)
    print("Loss: ", loss)

    return model, epoch_resume

# Build the data loaders (load_data and the path lists come from earlier cells).
train_loader = load_data(train_path=train_noisy_paths,
                         num_workers=4,                 # hparams['num_workers'],
                         batch_size=14,                 # hparams['batch_size'],
                         shuffle=True)

total_batch = len(train_loader)
print("Total train batch: ", total_batch)

# NOTE(review): `tast_noisy_paths` looks like a typo of `test_noisy_paths`; kept
# as-is because the defining cell is not visible here — confirm.
test_loader = load_data(train_path=tast_noisy_paths, num_workers=hparams['num_workers'],
                        batch_size=hparams['batch_size'], shuffle=False)

# Initialize the denoising model.
model = MyModel2D()
# model = encoder()
if hparams['n_gpu'] > 1:
    print("Using", hparams['n_gpu'], "GPUs for the denoising model!")
    model = nn.DataParallel(model)
else:
    print("Using single {} for the denoising model!".format('GPU' if device.type == 'cuda' else 'CPU'))
model.to(device)

print('Total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Set the learning rate: prefer the reduced rate when one is configured.
if hparams['reduced_learning_rate'] is not None:
    lr = hparams['reduced_learning_rate']
else:
    lr = hparams['initial_learning_rate']

# Set the optimizer.
optimizer = optim.Adam(list(model.parameters()), lr=lr)

# Run options. `args` is a plain dict, so it must be indexed with ['key']
# (attribute access such as args.checkpoint_path raises AttributeError).
args = dict(checkpoint_path=None,
            continue_epoch=True,
            checkpoint_dir='/kaggle/working')

# Resume the denoising model for training if a checkpoint path is provided.
epoch_resume = 0
if args['checkpoint_path'] is not None:
    model, epoch_resume = load_checkpoint(args['checkpoint_path'], model, optimizer, reset_optimizer=False)

epoch = epoch_resume if args['continue_epoch'] else 0

# Create the folder to save checkpoints.
checkpoint_dir = args['checkpoint_dir']
os.makedirs(checkpoint_dir, exist_ok=True)

# Train! `tr` holds (input, prediction, ground-truth) waveforms; they are played
# in the following cells (a mid-cell Audio(...) expression would not render).
tr = train(device, model, train_loader, test_loader, optimizer, epoch, hparams['nepochs'], checkpoint_dir, args)
print("Finished")

In [None]:
# Play the reconstructed noisy input (tr[0]) at 16 kHz.
IPython.display.Audio(tr[0], rate=16000)

In [None]:
# Play the model's denoised prediction (tr[1]) at 16 kHz.
IPython.display.Audio(tr[1], rate=16000)

In [None]:
# Play the clean ground-truth reference (tr[2]) at 16 kHz.
IPython.display.Audio(tr[2], rate=16000)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from IPython.display import clear_output
from PIL import Image
from matplotlib import cm
from time import perf_counter
from torch.utils.data import DataLoader
from tqdm import tqdm

# Rebuild the training loader with the batch size / worker count from hparams
# for the second training loop below.
train_loader = load_data(
    train_path=train_noisy_paths,
    num_workers=hparams['num_workers'],
    batch_size=hparams['batch_size'],
    shuffle=True,
)

def train(model: nn.Module) -> float:
    """Run one training epoch over the global ``train_loader``; return the mean batch loss.

    Uses the notebook-level ``optimizer`` and ``loss_fn``. A batch may be an
    ``(x, y)`` pair or a longer tuple such as the ``(inp_mel, inp_stft,
    gt_stft)`` batches produced by ``load_data``. In every case the last two
    items are used as input and target, matching the earlier training loop
    (``model(inp_stft)`` vs ``gt_stft``).

    Returns 0.0 if the loader yields no batches.
    """
    model.train()

    train_loss = 0.0
    n_batches = 0

    for batch in tqdm(train_loader, desc='Train'):
        *_, x, y = batch  # take (input, target) from the end of the batch tuple

        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        n_batches += 1

    return train_loss / max(n_batches, 1)

@torch.inference_mode()
def evaluate(model: nn.Module, loader: DataLoader):
    """Compute the mean loss (and, for class-score outputs, accuracy) over ``loader``.

    Uses the notebook-level ``loss_fn``. A batch may be an ``(x, y)`` pair or
    a longer tuple such as ``(inp_mel, inp_stft, gt_stft)``; the last two
    items are used as input and target.

    Returns:
        ``(mean_loss, accuracy)``. ``accuracy`` is argmax-over-dim-1 accuracy
        when ``output`` has exactly one more dimension than integer labels
        ``y`` (class scores). For any other target, e.g. spectrogram
        regression, it is ``float('nan')`` because argmax accuracy is
        undefined there. Both values are ``nan`` for an empty loader.
    """
    model.eval()

    total_loss = 0.0
    n_batches = 0
    total = 0
    correct = 0

    for batch in tqdm(loader, desc='Evaluation'):
        *_, x, y = batch  # take (input, target) from the end of the batch tuple
        output = model(x)

        total_loss += loss_fn(output, y).item()
        n_batches += 1

        is_classification = output.dim() == y.dim() + 1 and not torch.is_floating_point(y)
        if is_classification:
            _, y_pred = torch.max(output, 1)
            total += y.size(0)
            correct += (y_pred == y).sum().item()

    mean_loss = total_loss / n_batches if n_batches else float('nan')
    accuracy = correct / total if total else float('nan')

    return mean_loss, accuracy

def plot_stats(
    train_loss,
    valid_loss,
    valid_accuracy,
    title: str
):
    """Show two figures: train/valid loss curves, then the validation-accuracy curve.

    Args:
        train_loss: per-epoch training losses.
        valid_loss: per-epoch validation losses.
        valid_accuracy: per-epoch validation accuracies.
        title: prefix for both figure titles (" loss" / " accuracy" is appended).
    """
    _, loss_ax = plt.subplots(figsize=(16, 8))
    loss_ax.set_title(title + ' loss')
    loss_ax.plot(train_loss, label='Train loss')
    loss_ax.plot(valid_loss, label='Valid loss')
    loss_ax.legend()
    loss_ax.grid()
    plt.show()

    _, acc_ax = plt.subplots(figsize=(16, 8))
    acc_ax.set_title(title + ' accuracy')
    acc_ax.plot(valid_accuracy)
    acc_ax.grid()
    plt.show()

# Objective, model and optimizer for the epoch loop in the next cell.
loss_fn = nn.L1Loss()
model = UNetLikeModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
num_epochs = 15

train_loss_history, valid_loss_history = [], []
valid_accuracy_history = []

# NOTE(review): the held-out test loader doubles as the validation set here.
valid_loader = test_loader

start = perf_counter()

for epoch_idx in range(num_epochs):
    train_loss = train(model)
    valid_loss, valid_accuracy = evaluate(model, valid_loader)

    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)
    valid_accuracy_history.append(valid_accuracy)

    # Redraw in place instead of stacking a new pair of figures every epoch.
    clear_output(wait=True)
    plot_stats(train_loss_history, valid_loss_history, valid_accuracy_history, 'UNetLikeModel')

print(f'Training finished in {perf_counter() - start:.1f} s')