<a href="https://colab.research.google.com/github/JeppeLL/deep-learning-course/blob/master/SE_FFTnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install torchaudio
!pip install pystoi
!pip install https://github.com/ludlows/python-pesq/archive/master.zip
import numpy as np
import glob
import librosa
import sys
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import Audio
import torch
import torch.nn.functional as F
import torch.nn as nn
#import torch.optim as optim
import torchaudio as ta
if torch.cuda.is_available():
  cuda=True
else:
  cuda=False
#import datetime as dt
import os
from torch.autograd import Variable
from IPython.display import Image, display, clear_output
import datetime
from pystoi.stoi import stoi
from pesq import pesq

from google.colab import drive
drive.mount('/content/drive')

Collecting https://github.com/ludlows/python-pesq/archive/master.zip
[?25l  Downloading https://github.com/ludlows/python-pesq/archive/master.zip
[K     | 399kB 268kB/s
Building wheels for collected packages: pesq
  Building wheel for pesq (setup.py) ... [?25l[?25hdone
  Created wheel for pesq: filename=pesq-0.0.1-cp36-cp36m-linux_x86_64.whl size=162016 sha256=0b661fd7a3918c174fd47f565cda73af0f8325c587791ae897782f4a633ff633
  Stored in directory: /tmp/pip-ephem-wheel-cache-2nd1anyj/wheels/85/91/09/5ae7677a054a05d49111dc8f3b282e886b3852348384893a32
Successfully built pesq
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# DataLoader class:

In [0]:
def get_filepaths(data_directory, dataset):
    """
    Return the sorted list of ``.wav`` file paths of one dataset folder.

    Parameters
    ----------
    data_directory : str
        Root directory that contains one sub-folder per dataset.
    dataset : str
        One of ``'clean-speech'``, ``'noise'`` or ``'impulse-responses'``.

    Returns
    -------
    list of str
        Every path matching ``<data_directory>/<dataset>/*.wav``, in
        lexicographic order.

    Raises
    ------
    AssertionError
        If ``dataset`` is not one of the supported names.
    Exception
        If no file matches the pattern.
    """
    assert dataset in {'clean-speech', 'noise', 'impulse-responses'}

    pattern = '/'.join((data_directory, dataset, '*.wav'))
    # glob yields files in arbitrary (filesystem-dependent) order. Callers slice
    # this list by index to build train/test splits, so the order has to be
    # reproducible between runs and machines.
    filepaths = sorted(glob.glob(pattern))

    if len(filepaths) == 0:
        raise Exception('No files were found in the specified dataset!')

    return filepaths

class DataLoader():
  """
  In-memory loader and preprocessor for the reverberation data set.

  Reads clean speech ("GT", ground truth), impulse responses ("IR") and noise
  ("N") wav files and offers the preprocessing steps used by the pipeline:
  cropping/padding to fixed lengths, convolving every GT signal with every IR
  (``add_reverb``), peak normalisation and STFT / spectrogram features.

  Parameters
  ----------
  data_directory : str
      Root folder containing the 'clean-speech', 'impulse-responses' and
      'noise' sub-folders.
  sample_rate : int
      Target sample rate; files with another rate are resampled while loading.
  GT_range, IR_range, N_range : [start, stop] or falsy
      Index slice applied to the file list of each data set. A falsy value
      (e.g. ``None``) skips that data set.
  sequence_length_GT, sequence_length_IR : int
      Number of samples every GT / IR signal is cropped or zero-padded to.
  """
  def __init__(self, data_directory='/content/drive/My Drive/reverberation'
               , sample_rate=16000
               ,GT_range=[0,3121], IR_range=[0,27], N_range=[0,15]
               ,sequence_length_GT=32000, sequence_length_IR=16000):
    
    self.sample_rate=sample_rate
    self.seq_len_GT=sequence_length_GT
    self.seq_len_IR=sequence_length_IR
    self.marc=None
    self.GT_is_padded = False

    self.paths_GT=self._select_paths(data_directory, 'clean-speech', GT_range)
    self.paths_IR=self._select_paths(data_directory, 'impulse-responses', IR_range)
    self.paths_N=self._select_paths(data_directory, 'noise', N_range)
    # A skipped data set has paths == None; report it as 0 files instead of
    # crashing on len(None).
    print(f"Found {len(self.paths_GT or [])} ground truth files in data set")
    print(f"Found {len(self.paths_IR or [])} impulse response files in data set")
    print(f"Found {len(self.paths_N or [])} noise files in data set\n")

    self.use_cuda = torch.cuda.is_available()
    if self.use_cuda:
      print("Running GPU.")
    else:
      print("No GPU available.")

    ###Load every selected file into memory
    self.GT, self.GT_lengths, self.GT_filenames = self._load_all("GT", self.paths_GT)
    self.IR, self.IR_lengths, self.IR_filenames = self._load_all("IR", self.paths_IR)
    self.N, self.N_lengths, self.N_filenames = self._load_all("N", self.paths_N)

  @staticmethod
  def _select_paths(data_directory, dataset, index_range):
    """Return the ``[start:stop]`` slice of a data set's files, or None if the range is falsy."""
    if not index_range:
      return None
    return get_filepaths(data_directory, dataset)[index_range[0]:index_range[1]]

  def _load_all(self, label, paths):
    """
    Load every file in ``paths`` and return ``(signals, lengths, filenames)``.

    Returns three empty lists (and prints nothing) when ``paths`` is None.
    """
    signals, lengths, filenames = [], [], []
    if paths is None:
      return signals, lengths, filenames

    print(label)
    # Use the real number of files: the requested range may be larger than
    # what is available on disk.
    total=len(paths)
    for count, file in enumerate(paths, start=1):
      if count % 100 == 1:
        print(f"Processing file {count} of {total}...")
      d = self.load(file)
      signals.append(d)
      lengths.append(d.shape[1])
      filenames.append(file)
    return signals, lengths, filenames

  def load(self, file):
    """
    Load one wav file as a float tensor of shape (1, samples).

    With CUDA the file is read through torchaudio, reduced to its first
    channel, resampled to ``self.sample_rate`` if necessary and moved to the
    GPU. Without CUDA the file is read through librosa, mean-removed and
    peak-normalised, then converted to a CPU tensor of the same (1, samples)
    layout so the rest of the class can treat both paths alike.
    """
    # NOTE(review): only the CPU branch normalises the signal; the GPU branch
    # keeps the raw amplitudes. Kept as-is - confirm which one is intended.
    if self.use_cuda:
      d, sr = ta.load(file)
      if d.shape[0]>1:
        d=d[0,:].unsqueeze(0)  # keep the first channel only
      if sr!=self.sample_rate:
        d = ta.transforms.Resample(sr, self.sample_rate)(d)
      d  = d.cuda()
    else:
      d, sr = librosa.load(file, sr=self.sample_rate)
      #normalize
      d_minus_mean=d-np.mean(d)
      d = d_minus_mean/np.max(np.abs(d_minus_mean))
      # librosa returns a 1-D numpy array; every consumer in this class
      # expects a (1, samples) torch tensor.
      d = torch.from_numpy(np.ascontiguousarray(d, dtype=np.float32)).unsqueeze(0)
    return d

  @staticmethod
  def _fit_length(signal, length):
    """Crop ``signal`` (channels, samples) to ``length`` samples or zero-pad it at the end."""
    missing = max(0, length-signal.shape[1])
    # Allocate the padding on the signal's own device/dtype so this works with
    # and without a GPU.
    padding = torch.zeros(signal.shape[0], missing, dtype=signal.dtype, device=signal.device)
    return torch.cat((signal[:,:length], padding),dim=1)

  def cropAndPadIR(self,matrix_ops=True):
    """
    Cut every IR so it starts at its largest absolute sample, then crop/pad it
    to ``seq_len_IR`` samples. With ``matrix_ops`` the list is stacked into one
    tensor afterwards.
    """
    for i in range(len(self.IR)):
      idx = torch.argmax(torch.abs(self.IR[i]))
      self.IR[i]=self._fit_length(self.IR[i][:,idx:], self.seq_len_IR)
    if matrix_ops:
      self.IR=torch.stack(self.IR)

  def cropAndPadGT(self,matrix_ops=True):
    """
    Crop/pad every GT signal to ``seq_len_GT`` samples and mark the GT data as
    padded. With ``matrix_ops`` the list is stacked into one tensor afterwards.
    """
    for i in range(len(self.GT)):
      self.GT[i]=self._fit_length(self.GT[i], self.seq_len_GT)
    self.GT_is_padded=True
    if matrix_ops:
      self.GT=torch.stack(self.GT)

  @staticmethod
  def _convolve(signal, kernels):
    """
    Convolve ``signal`` (batch, 1, samples) with ``kernels`` (n, 1, taps).

    conv1d computes a cross-correlation, so the kernels are flipped first; the
    padded result is trimmed back to the input length.
    """
    taps = kernels.shape[2]
    full = torch.nn.functional.conv1d(signal,torch.flip(kernels,[2]),padding=taps)
    return full[:,:,1:-taps]

  def add_reverb(self,matrix_ops=True):
    """
    Convolve the ground truth signals with the impulse responses and store the
    result in ``self.GT_IR``.

    Requires ``cropAndPadGT`` to have been called first; otherwise only an
    error message is printed. With ``matrix_ops`` all GT/IR pairs are computed
    in one conv1d call on the stacked tensors. Without it the pairs are
    computed one by one from the lists, and GT, IR and GT_IR are stacked at the
    end.
    """
    if not self.GT_is_padded:
      print("ERROR: you should pad GT first")
      return

    # conv1d runs on whatever device the tensors live on, so no CUDA check is
    # needed here (previously the method silently did nothing without a GPU).
    if matrix_ops:
      self.GT_IR = self._convolve(self.GT, self.IR).contiguous()
    else:
      self.GT_IR = []
      len_GT = len(self.GT)
      for gt in range(len_GT):
        if gt % 100 == 1:
          print(f"Adding reverb to GT {gt} of {len_GT}...")
        for ir in range(len(self.IR)):
          gt_ir = self._convolve(self.GT[gt].unsqueeze(0), self.IR[ir].unsqueeze(0))
          self.GT_IR.append(gt_ir)
      print("stacking GT_IR")
      self.GT_IR = torch.stack(self.GT_IR) 
      print("stacking GT")
      self.GT = torch.stack(self.GT)
      print("stacking IR")
      self.IR = torch.stack(self.IR)

  def normalize(self):
    """Scale every reverberant signal in ``GT_IR`` by its peak absolute value."""
    # Reduce over the sample axis (the last one) and keep it as size 1 so the
    # division broadcasts. For the (GT, IR, samples) tensor produced with
    # matrix_ops this is the same as reducing over dim 2.
    max_abs = torch.max(torch.abs(self.GT_IR),dim=-1,keepdim=True)[0]
    eps = 1e-12  # avoids division by zero for all-zero signals
    self.GT_IR = self.GT_IR / (max_abs + eps)

  def cpu(self):
    """Move IR, GT and GT_IR to CPU memory."""
    self.IR=self.IR.cpu()
    self.GT=self.GT.cpu()
    self.GT_IR=self.GT_IR.cpu()

  def spectrogram(self,n_fft=160, melScale=False):
    """
    Compute spectrograms of GT and GT_IR with torchaudio and store them in
    ``GT_spectrogram`` / ``GT_IR_spectrogram``.

    ``melScale=True`` uses ``MelSpectrogram`` at ``self.sample_rate``,
    otherwise a plain ``Spectrogram`` is used.
    """
    if melScale:
      # Was the undefined global SAMPLE_RATE; the loader's own rate is meant.
      spect = ta.transforms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=n_fft)
    else:
      spect = ta.transforms.Spectrogram(n_fft=n_fft)
    self.GT_spectrogram=spect(self.GT.squeeze(1))
    self.GT_IR_spectrogram=spect(self.GT_IR.view(-1,1,self.GT_IR.shape[2]).squeeze(1))

  @staticmethod
  def _stft_real(signal, n_fft):
    """
    STFT with hop ``n_fft//2`` returned as a real tensor (..., freq, frames, 2).

    Current PyTorch requires ``return_complex``; older releases do not know the
    keyword and already return the real layout.
    """
    try:
      spec = torch.stft(signal, n_fft, hop_length=n_fft//2, return_complex=True)
    except TypeError:
      return torch.stft(signal, n_fft, hop_length=n_fft//2)
    return torch.view_as_real(spec)

  @staticmethod
  def _magphase(spec):
    """Split a real-layout STFT (..., 2) into magnitude and phase."""
    mag = spec.pow(2).sum(-1).pow(0.5)
    phase = torch.atan2(spec[...,1], spec[...,0])
    return mag, phase

  def stft(self, n_fft=160):
    """
    Compute the STFT of GT and GT_IR (hop ``n_fft//2``) and store the raw
    transforms (``GT_stft``, ``GT_IR_stft``) together with their magnitude and
    phase (``GT_mag``/``GT_phase``, ``GT_IR_mag``/``GT_IR_phase``).
    """
    self.GT_stft = self._stft_real(self.GT.squeeze(1), n_fft)
    self.GT_IR_stft = self._stft_real(self.GT_IR.view(-1,1,self.GT_IR.shape[2]).squeeze(1), n_fft)
    # Computed with plain torch ops: torchaudio.functional.magphase is not
    # available in recent torchaudio releases.
    self.GT_mag,self.GT_phase = self._magphase(self.GT_stft)
    self.GT_IR_mag,self.GT_IR_phase = self._magphase(self.GT_IR_stft)

# Iterator class:

In [0]:
class Iterator():
  """
  Batch iterator over the tensors prepared by a ``DataLoader``.

  Each iteration yields a 4-tuple ``(x, y, x_phase, y_phase)`` where ``x`` is
  taken from the reverberant data (GT_IR) and ``y`` from the clean data (GT).
  The phase entries are only filled in 'stft' mode; the other modes return 0
  as placeholders so the tuple always has four elements.

  Parameters
  ----------
  object : DataLoader-like
      Must provide ``paths_GT``, ``paths_IR`` and the tensors required by the
      chosen mode (``GT``/``GT_IR``, ``GT_spectrogram``/``GT_IR_spectrogram`` or
      ``GT_mag``/``GT_IR_mag``/``GT_phase``/``GT_IR_phase``).
  mode : str
      One of 'waveform', 'FFTnet', 'spectrogram' or 'stft'.

  Raises
  ------
  ValueError
      If ``mode`` is not one of the supported values.
  """
  MODES = ('spectrogram', 'FFTnet', 'waveform', 'stft')

  def __init__(self, object, mode='waveform'):
    # Fail early: an unknown mode previously produced an object without GT
    # data whose __next__ silently returned None.
    if mode not in self.MODES:
      raise ValueError(f"Unknown mode {mode!r}; expected one of {self.MODES}")

    self.paths_GT = object.paths_GT
    self.paths_IR = object.paths_IR
    self.mode=mode
    # Every clean signal is paired with each impulse response, so the clean
    # data is repeated once per IR to line up with GT_IR.
    n_ir = len(self.paths_IR)
    if self.mode=='spectrogram':
      self.GT = object.GT_spectrogram.repeat_interleave(n_ir,dim=0).unsqueeze(1)
      self.GT_IR = object.GT_IR_spectrogram.unsqueeze(1)
    elif self.mode in ('FFTnet', 'waveform'):
      self.GT = object.GT.repeat_interleave(n_ir,dim=0)
      self.GT_IR = object.GT_IR
    elif self.mode=='stft':
      self.GT = object.GT_mag.repeat_interleave(n_ir,dim=0).unsqueeze(1)
      self.GT_IR = object.GT_IR_mag.unsqueeze(1)
      self.GT_phase = object.GT_phase.repeat_interleave(n_ir,dim=0).unsqueeze(1)
      self.GT_IR_phase = object.GT_IR_phase.unsqueeze(1)

  def setChunkSize(self,k):
    """Set the chunk length used by ``chunkify`` and by FFTnet-mode iteration."""
    self.chunkSize=k

  def setBatchSize(self,k):
    """Set the number of items returned per iteration."""
    self.batchSize=k

  def chunkify(self):
    """
    Reshape the stored tensors for the current mode.

    'spectrogram' / 'stft': split the last axis into chunks of ``chunkSize``.
    'waveform': split the signals into chunks of ``chunkSize`` samples.
    'FFTnet': only reshape GT to the (GT file, IR file, samples) layout of
    GT_IR; the chunks themselves are cut on the fly in ``__next__``.
    """
    if self.mode=='spectrogram':
      self.GT_IR = self.GT_IR.view(-1,1,self.GT_IR.shape[2],self.chunkSize)
      self.GT = self.GT.view(-1,1,self.GT_IR.shape[2],self.chunkSize)
    elif self.mode=='FFTnet':
      self.GT = self.GT.view(-1,self.GT_IR.shape[1],self.GT_IR.shape[2])
    elif self.mode=='waveform':
      self.GT_IR = self.GT_IR.view(-1,1,self.chunkSize)
      self.GT = self.GT.view(-1,1,self.chunkSize)
    elif self.mode=='stft':
      self.GT_IR = self.GT_IR.view(-1,1,self.GT_IR.shape[2],self.chunkSize)
      self.GT = self.GT.view(-1,1,self.GT_IR.shape[2],self.chunkSize)
      self.GT_IR_phase = self.GT_IR_phase.view(-1,1,self.GT_IR_phase.shape[2],self.chunkSize)
      self.GT_phase = self.GT_phase.view(-1,1,self.GT_IR_phase.shape[2],self.chunkSize)
  
  def __iter__(self):
    """Restart the iteration from the first item."""
    self.n=0
    return self

  def _next_fftnet(self):
    """
    Build one FFTnet batch: ``batchSize`` sliding windows of ``chunkSize``
    reverberant samples, each paired with the clean sample at the window
    centre. ``self.n`` counts windows; an incomplete final batch is dropped.
    """
    positions_per_signal = self.GT_IR.shape[2]-self.chunkSize
    n_ir = self.GT_IR.shape[1]
    centre = (self.chunkSize-1)//2
    windows, targets = [], []
    for _ in range(self.batchSize):
      # Decompose the running counter into (GT file, IR file, window start).
      # Integer arithmetic avoids float rounding for large counters.
      start = self.n%positions_per_signal
      signal_number = self.n//positions_per_signal
      ir_file = signal_number%n_ir
      gt_file = signal_number//n_ir
      if gt_file >= self.GT_IR.shape[0]:
        raise StopIteration
      windows.append(self.GT_IR[gt_file,ir_file,start:start+self.chunkSize].view(1,1,-1))
      targets.append(self.GT[gt_file,ir_file,start+centre].view(1,1))
      self.n+=1
    # Concatenate once instead of growing the batch tensor item by item.
    return torch.cat(windows,0), torch.cat(targets,0), 0, 0 #last 2 zeros are placeholders, so the iterator is consistent

  def __next__(self):
    """Return the next ``(x, y, x_phase, y_phase)`` batch or raise StopIteration."""
    if self.mode =='FFTnet':
      return self._next_fftnet()

    b1=self.n*self.batchSize
    b2=(self.n+1)*self.batchSize
    # Only full batches are returned.
    if b2>self.GT_IR.shape[0]:
      raise StopIteration

    x = self.GT_IR[b1:b2]
    y = self.GT[b1:b2]
    if self.mode == 'stft':
      x_phase = self.GT_IR_phase[b1:b2]
      y_phase = self.GT_phase[b1:b2]
    else:
      x_phase, y_phase = 0, 0 #placeholders, so the iterator is consistent
    self.n+=1
    return x,y, x_phase, y_phase


# Custom functions (loss, weight init and RAdam)

In [0]:
#init xavier weights
def init_weights(m):
      """
      Xavier-normal initialisation for 1-D (transposed) convolution layers.

      Meant to be passed to ``net.apply``: for modules whose type is exactly
      ``nn.Conv1d`` or ``nn.ConvTranspose1d`` the weight is re-initialised and
      the bias (if the layer has one) is set to zero. Other modules are left
      untouched.
      """
      if type(m) in (nn.Conv1d, nn.ConvTranspose1d):
          torch.nn.init.xavier_normal_(m.weight)
          # Layers created with bias=False have m.bias == None.
          if m.bias is not None:
              m.bias.data.fill_(0)

#Custom STFT loss functions for waveforms
def _stft_as_real(signal, n_fft, normalized=False):
  """
  STFT with hop ``n_fft//2`` returned as a real tensor (..., freq, frames, 2).

  ``signal`` may be (batch, 1, samples), (batch, samples) or (samples,); a
  3-D input has its channel axis squeezed away because torch.stft only accepts
  1-D or 2-D input. Current PyTorch requires ``return_complex``; older releases
  do not know the keyword and already return the real layout, so both are
  supported.
  """
  if signal.dim() == 3:
    signal = signal.squeeze(1)
  hop = n_fft//2
  try:
    spec = torch.stft(signal, n_fft, hop_length=hop, normalized=normalized, return_complex=True)
  except TypeError:
    return torch.stft(signal, n_fft, hop_length=hop, normalized=normalized)
  return torch.view_as_real(spec)

def stftMAE(y_pred,y_true,n_fft=160):
  """
  Mean absolute error between the normalised STFTs of two waveforms.

  The absolute value is taken of the real and imaginary parts separately
  (not the complex magnitude) before the L1 loss is computed.
  NOTE(review): ``magMAE`` compares true magnitudes - confirm that the
  component-wise form here is intended.
  """
  pred_stft = torch.abs(_stft_as_real(y_pred, n_fft, normalized=True))
  true_stft = torch.abs(_stft_as_real(y_true, n_fft, normalized=True))
  return F.l1_loss(pred_stft,true_stft)

def stftHuber(y_pred,y_true,n_fft=160):
  """
  Smooth-L1 (Huber) loss between the normalised STFTs of two waveforms, using
  the absolute values of the real and imaginary parts as in ``stftMAE``.
  """
  pred_stft = torch.abs(_stft_as_real(y_pred, n_fft, normalized=True))
  true_stft = torch.abs(_stft_as_real(y_true, n_fft, normalized=True))
  huber = F.smooth_l1_loss(pred_stft,true_stft)
  return huber
  
def stftMSE(y_pred,y_true,n_fft=160):
  """
  Mean squared error between the (unnormalised) real/imaginary STFT
  components of two waveforms. Accepts (batch, 1, samples) input like the
  other STFT losses.
  """
  pred_stft = _stft_as_real(y_pred, n_fft)
  true_stft = _stft_as_real(y_true, n_fft)
  return F.mse_loss(pred_stft,true_stft)

def magMAE(y_pred,y_true,n_fft=160):
  """
  Mean absolute error between the STFT magnitudes of two waveforms.
  """
  # Magnitude is computed with plain torch ops (sqrt(re^2 + im^2)) so the loss
  # does not depend on torchaudio.functional.magphase, which recent torchaudio
  # releases no longer provide.
  pred_mag = _stft_as_real(y_pred, n_fft).pow(2).sum(-1).pow(0.5)
  true_mag = _stft_as_real(y_true, n_fft).pow(2).sum(-1).pow(0.5)
  return F.l1_loss(pred_mag,true_mag)


In [0]:
import math
import torch
from torch.optim.optimizer import Optimizer, required

class RAdam(Optimizer):
    """
    RAdam optimizer with a small cache for the per-step rectification term.

    Keeps exponential moving averages of the gradient and of its square like
    Adam. The adaptive (variance-scaled) update is only applied once the
    quantity ``N_sma`` computed from ``beta2`` and the step count reaches 5;
    before that the update falls back to a momentum-only step when
    ``degenerated_to_sgd`` is True and is skipped otherwise.

    Parameters
    ----------
    params : iterable
        Parameters or parameter-group dicts to optimise.
    lr : float
        Learning rate (must be >= 0).
    betas : (float, float)
        Decay rates of the two moving averages, each in [0, 1).
    eps : float
        Added to the denominator of the adaptive update (must be >= 0).
    weight_decay : float
        Multiplicative weight decay applied as ``p -= weight_decay * lr * p``.
    degenerated_to_sgd : bool
        Whether to take a momentum-only step while ``N_sma < 5``.

    Raises
    ------
    ValueError
        If ``lr``, ``eps`` or ``betas`` are outside their valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        
        self.degenerated_to_sgd = degenerated_to_sgd
        # The cached step sizes depend on the betas, so a parameter group with
        # its own betas needs its own cache instead of the shared default one.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        # buffer: 10 slots of [step, N_sma, step_size], indexed by step % 10.
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """
        Perform one optimisation step.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss.

        Returns
        -------
        The value returned by ``closure``, or None if no closure was given.

        Raises
        ------
        RuntimeError
            If a parameter has a sparse gradient.
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # All arithmetic is done on an fp32 copy that is written back
                # to the parameter at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword scale factors: the positional-scalar overloads
                # (add_(alpha, tensor) etc.) are deprecated in PyTorch.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                # N_sma / step_size only depend on the step count and betas,
                # so they are cached and shared by all parameters of a group.
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1  # marks "skip the update"
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    # Adaptive update.
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    # Momentum-only update (degenerated_to_sgd).
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss

class PlainRAdam(Optimizer):
    """
    RAdam optimizer without the step-size cache used by ``RAdam``.

    Keeps exponential moving averages of the gradient and of its square. The
    adaptive (variance-scaled) update is applied once ``N_sma`` reaches 5;
    before that a momentum-only step is taken when ``degenerated_to_sgd`` is
    True and the update is skipped otherwise.

    Parameters
    ----------
    params : iterable
        Parameters or parameter-group dicts to optimise.
    lr : float
        Learning rate (must be >= 0).
    betas : (float, float)
        Decay rates of the two moving averages, each in [0, 1).
    eps : float
        Added to the denominator of the adaptive update (must be >= 0).
    weight_decay : float
        Multiplicative weight decay applied as ``p -= weight_decay * lr * p``.
    degenerated_to_sgd : bool
        Whether to take a momentum-only step while ``N_sma < 5``.

    Raises
    ------
    ValueError
        If ``lr``, ``eps`` or ``betas`` are outside their valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
                    
        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """
        Perform one optimisation step.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss.

        Returns
        -------
        The value returned by ``closure``, or None if no closure was given.

        Raises
        ------
        RuntimeError
            If a parameter has a sparse gradient.
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # All arithmetic is done on an fp32 copy that is written back
                # to the parameter at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword scale factors: the positional-scalar overloads
                # (add_(alpha, tensor) etc.) are deprecated in PyTorch.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)


                # more conservative since it's an approximated value
                if N_sma >= 5:
                    # Adaptive update.
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    # Momentum-only update.
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
                    p.data.copy_(p_data_fp32)

        return loss

# Define neural network:

## FFTnet

In [0]:
# Create SE-FFTnet Block
        
        
class FFTNet_block(nn.Module):
  """
  One layer of the FFTnet model.

  The input (batch, in_channels, time) is cut into three overlapping windows of
  ``block_size`` samples starting at ``start_idx``. Each window goes through
  its own 1x1 convolution, the three results are summed, passed through ReLU,
  a further 1x1 convolution and a final ReLU. An input of ``receptive_field``
  samples therefore produces ``block_size`` output samples.

  Parameters
  ----------
  in_channels, out_channels, hid_channels : int
      Channel counts of the input, the output and the summed hidden signal.
  layer_id : int
      Layer index (>= 0); fixes the dilation ``2**layer_id`` and the window
      geometry.
  std_f : float
      Scale of the normal initialisation: ``std = sqrt(std_f / in_channels)``.
  """
  def __init__(self, in_channels, out_channels, hid_channels, layer_id, std_f=0.5):
    super(FFTNet_block, self).__init__()
    self.layer_id = layer_id
    self.dialation = 2**layer_id
    # Same value as 4*(2**(layer_id-1))-1, but written so that it stays an int
    # for layer_id == 0 (2**-1 is the float 0.5) and can be used for slicing.
    self.block_size = 2**(layer_id+1)-1
    self.start_idx = [0,0+self.dialation,0+self.block_size+1]
    self.receptive_field =  4*self.dialation-1
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.hid_channels = hid_channels
    self.conv1_1 = nn.Conv1d(in_channels, hid_channels, 1, stride=1)
    self.conv1_2 = nn.Conv1d(in_channels, hid_channels, 1, stride=1)
    self.conv1_3 = nn.Conv1d(in_channels, hid_channels, 1, stride=1)
    self.conv2 = nn.Conv1d(hid_channels, out_channels, 1)
    self.relu = nn.ReLU()
    self.init_weights(std_f)
    self.buffer = None
    self.cond_buffer = None
    # inference params for linear operations
    self.w1_1 = None
    self.w1_2 = None
    self.w2 = None

  def init_weights(self, std_f):
    """
    Initialise the three input convolutions with N(0, std_f / in_channels)
    weights and zero biases. ``conv2`` keeps PyTorch's default initialisation.
    """
    std = float(np.sqrt(std_f / self.in_channels))
    for conv in (self.conv1_1, self.conv1_2, self.conv1_3):
      conv.weight.data.normal_(mean=0, std=std)
      conv.bias.data.zero_()

  def forward(self, x, cx=None):
    """
    Apply the block to ``x`` of shape (batch, in_channels, time).

    ``cx`` is accepted for interface compatibility but not used.
    Returns a tensor of shape (batch, out_channels, block_size).
    """
    # With an integer block_size the first layer (receptive field 3, windows
    # of one sample at 0, 1 and 2) needs no special case any more.
    x1, x2, x3 = (x[:, :, start:start + self.block_size] for start in self.start_idx)
    z = self.conv1_1(x1) + self.conv1_2(x2) + self.conv1_3(x3)

    out = self.relu(z)
    out = self.conv2(out)
    out = self.relu(out)
    return out



class FFTNetModel(nn.Module):
  """
  Stack of ``FFTNet_block`` layers followed by a linear output layer.

  The blocks are created from layer id ``n_layers-1`` down to 0. Only the
  first block takes a single input channel; all others map ``hid_channels``
  to ``hid_channels``. ``cond_channels`` is stored but not used by the layers.
  """
  def __init__(self, hid_channels=256, out_channels=1, n_layers=4,
    cond_channels=None):
    super(FFTNetModel, self).__init__()
    self.cond_channels = cond_channels
    self.hid_channels = hid_channels
    self.out_channels = out_channels
    self.n_layers = n_layers
    self.receptive_field = 4*(2**(n_layers-1))-1

    # Highest layer id first: it is the block that sees the raw 1-channel input.
    first_layer_id = n_layers-1
    blocks = [
      FFTNet_block(1 if layer_id == first_layer_id else hid_channels,
                   hid_channels, hid_channels, layer_id=layer_id)
      for layer_id in range(first_layer_id, -1, -1)
    ]
    self.layers = nn.ModuleList(blocks)
    self.fc = nn.Linear(hid_channels, out_channels)
    print('Receptive Field =  ',self.receptive_field, ' samples')

  def forward(self, x, cx=None):
    """
    Run ``x`` through every block in order, then apply the linear layer along
    the channel axis. ``cx`` is accepted but not used.
    """
    hidden = x
    for block in self.layers:
      hidden = block(hidden)
    # (batch, channels, time) -> (batch, time, channels) for nn.Linear
    return self.fc(hidden.transpose(1, 2))





In [0]:
def applyToFile(net,x):
  """
  Slide ``net`` over a whole signal, one output sample per window.

  Parameters
  ----------
  net : callable with a ``receptive_field`` attribute
      Applied to windows of ``receptive_field`` samples; each call is expected
      to return a single value.
  x : torch.Tensor
      Input of shape (1, 1, samples).

  Returns
  -------
  numpy.ndarray of shape (samples, 1)
      ``output[start]`` holds the prediction for the window beginning at
      ``start``; positions without a window stay 0.
  """
  # NOTE(review): predictions are stored at the window START, whereas the
  # FFTnet iterator trains against the sample at the window centre - confirm
  # whether the output should be shifted accordingly.
  # NOTE(review): the loop condition skips the last full window
  # (finish == length), matching the iterator's convention - confirm.
  block_size = net.receptive_field
  input_lenght = (x.shape[2])
  print(input_lenght)
  output = np.zeros((input_lenght,1))

  # Run on the device the network lives on instead of assuming CUDA; fall back
  # to the input's device for networks without parameters.
  first_param = next(iter(net.parameters()), None) if hasattr(net, 'parameters') else None
  device = first_param.device if first_param is not None else x.device
  x = x.to(device)

  start = 0
  finish = start + block_size
  # Inference only: no autograd graph is needed for the windows.
  with torch.no_grad():
    while finish < input_lenght:
      chunk = x[:,:,start:finish]
      out = net(chunk).cpu().detach().numpy()
      output[start] = out.reshape(-1)
      start += 1
      finish = start + block_size 
  
  return output

# Pipeline functions
These helper functions wrap each stage of an experiment (data, model, fitting), so that different hyperparameters can be tested in a loop.

Hyperparameters we want to test:

* Data type: waveform(Conv1D) or spectrogram(Conv2D) (CNN='1d' or CNN='2d')
* Input size/no of samples (chunk_size)
* Loss function (MAE, MSE, Huber, cosh. Measured on waveform/stft real+imaginary/stft mag/stft mag+phase)
* Regularization
* Optimizer (always Adam? Try AdamW (weight decay))
* Optimizer learning rate (optim.lr_scheduler)
* Batch size ?
* Epochs ?



## Data pipeline

In [0]:
#DataLoader
def PipelineData(SAMPLE_RATE=16000, CNN='2d',
                       batch_size=64, chunk_size=None,
                       train_GT_range=[0,200],
                       train_IR_range=[0,24], train_N_range=[0,1],
                       train_sequence_length_GT=16000*2,
                       train_sequence_length_IR=16000*1,
                       test_GT_range=[4030,4050],
                       test_IR_range=[30,36], test_N_range=[0,1],
                       test_sequence_length_GT=16000*2,
                       test_sequence_length_IR=8000*1,
                       matrix_ops=True):
  """
  Build the training and test iterators.

  For each split a ``DataLoader`` is created, the impulse responses and clean
  signals are cropped/padded, reverb is added, the result is normalised and
  moved to the CPU; for ``CNN='2d'`` the STFT is computed as well. The loaders
  are then wrapped in ``Iterator`` objects whose mode follows ``CNN``:
  '2d' -> 'stft', 'FFTnet' -> 'FFTnet', anything else -> 'waveform'.

  Parameters
  ----------
  SAMPLE_RATE : int
      Sample rate passed to both loaders.
  CNN : str
      Model family the data is prepared for (see above).
  batch_size : int
      Batch size set on both iterators.
  chunk_size : int or None
      If given, set on both iterators, which are then chunkified.
  train_* / test_* :
      File index ranges and sequence lengths forwarded to the ``DataLoader``
      of the respective split.
  matrix_ops : bool
      Forwarded to the crop/pad and reverb steps.

  Returns
  -------
  (trainIterator, testIterator)
  """
  #total GT files: 4162
  #total IR files: 36

  def _prepare_loader(GT_range, IR_range, N_range, sequence_length_GT, sequence_length_IR):
    """Create one DataLoader and run the shared preprocessing steps on it."""
    loader=DataLoader(sample_rate=SAMPLE_RATE,  # was ignored before
                      GT_range=GT_range,
                      IR_range=IR_range,
                      N_range=N_range,
                      sequence_length_GT=sequence_length_GT,
                      sequence_length_IR=sequence_length_IR)
    loader.cropAndPadIR(matrix_ops=matrix_ops)
    loader.cropAndPadGT(matrix_ops=matrix_ops)
    loader.add_reverb(matrix_ops=matrix_ops)
    loader.normalize()
    loader.cpu()
    if CNN=='2d':
      loader.stft()
    return loader

  trainLoader=_prepare_loader(train_GT_range, train_IR_range, train_N_range,
                              train_sequence_length_GT, train_sequence_length_IR)
  testLoader=_prepare_loader(test_GT_range, test_IR_range, test_N_range,
                             test_sequence_length_GT, test_sequence_length_IR)

  #Iterator:
  if CNN=='2d':
    mode='stft'
  elif CNN=='FFTnet':
    mode='FFTnet'
  else:
    mode='waveform'
  trainIterator=Iterator(trainLoader,mode=mode)
  testIterator=Iterator(testLoader,mode=mode)

  for iterator in (trainIterator, testIterator):
    iterator.setBatchSize(batch_size)

  if chunk_size:
    trainIterator.setChunkSize(chunk_size)
    testIterator.setChunkSize(chunk_size)
    trainIterator.chunkify()
    testIterator.chunkify()
  
  print(f"Shape of train.GT: {trainIterator.GT.shape}")
  print(f"Shape of test.GT: {testIterator.GT.shape}")
  print(f"Shape of train.GT_IR: {trainIterator.GT_IR.shape}")
  print(f"Shape of test.GT_IR: {testIterator.GT_IR.shape}")

  return trainIterator, testIterator

## Model pipeline

In [0]:

def PipelineModel(CNN='2d', kernel_size=11, optim='Adam',learning_rate=None,loss='Huber',use_cuda=True, layers = 10):
  """
  Create the network, its optimizer and the loss function.

  Parameters
  ----------
  CNN : str
      '2d', 'Baseline', 'Baseline_BN' or 'FFTnet'.
      NOTE(review): ``CNN2d``, ``Baseline`` and ``Baseline_BN`` are not defined
      in the visible part of the notebook - they are assumed to be defined
      elsewhere.
  kernel_size : int
      Passed to the '2d' / 'Baseline' / 'Baseline_BN' models.
  optim : str
      'Adam', 'AdamW' or 'RAdam'.
  learning_rate : float or None
      Learning rate; None selects 0.001.
  loss : str
      'L2', 'L1', 'Huber', 'stftMAE', 'stftHuber' or 'magMAE'.
  use_cuda : bool
      Move the network to the GPU.
  layers : int
      Number of layers of the FFTnet model.

  Returns
  -------
  (net, optimizer, loss_function)

  Raises
  ------
  AssertionError
      If ``CNN`` is not a supported model name.
  ValueError
      If ``optim`` or ``loss`` is not a supported name.
  """
  assert CNN in ('2d', 'Baseline', 'Baseline_BN','FFTnet')

  # Validate the remaining choices before anything is built: an unknown name
  # used to end in an UnboundLocalError at the return statement.
  if optim not in ('Adam', 'AdamW', 'RAdam'):
    raise ValueError(f"Invalid optimizer: {optim!r}")
  module_losses = {'L2': nn.MSELoss, 'L1': nn.L1Loss, 'Huber': nn.SmoothL1Loss}
  function_losses = {'stftMAE': stftMAE, 'stftHuber': stftHuber, 'magMAE': magMAE}
  if loss not in module_losses and loss not in function_losses:
    raise ValueError(f"Invalid loss function: {loss!r}")

  #Define the model:
  if CNN=='2d':
    net = CNN2d(kernel_size)
    net.apply(init_weights)
  elif CNN=='Baseline_BN':
    net = Baseline_BN(kernel_size)
    net.apply(init_weights)
  elif CNN=='Baseline':
    net = Baseline(kernel_size)
    net.apply(init_weights)
  elif CNN == 'FFTnet':
    # FFTnet blocks initialise their own weights.
    net = FFTNetModel(n_layers=layers)

  if use_cuda:
    net = net.cuda()

  #Optimizer:
  if learning_rate is None:
    learning_rate=0.001
  if optim=='Adam':
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
  elif optim=='AdamW':
    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate)
  else:
    optimizer = RAdam(net.parameters(), lr=learning_rate)

  #Loss:
  if loss in module_losses:
    loss_function = module_losses[loss]()
  else:
    loss_function = function_losses[loss]

  return net, optimizer, loss_function

## Fitting pipeline

In [0]:
def save_checkpoint(state, old_loss, loss, is_best, filename='/content/drive/My Drive/reverberation/Models/test.pth.tar'):
    """
    Write ``state`` to ``filename`` when ``is_best`` is true.

    Prints the improvement from ``old_loss`` to ``loss`` before saving;
    otherwise only a "did not improve" message is printed and nothing is
    written.
    """
    if not is_best:
        print ("=> Validation Accuracy did not improve")
        return

    print (f"=> Saving a new best loss improved from {old_loss} to {loss}")
    torch.save(state, filename)  # save checkpoint

In [0]:
# TO DO: 
## Return results
## Iterators can return 2 or 4 arguments
def PipelineFit(net, optimizer, loss_function,
                trainIterator, testIterator,
                num_epochs, save_name, 
                scheduler_step_size=1,
                scheduler_gamma=0.9,
                plotting=False, return_results=True, 
                use_cuda=True):
  """Train `net`, validate after every epoch and checkpoint the best epoch.

  Args:
    net: model to train.
    optimizer: optimizer already bound to net's parameters.
    loss_function: callable invoked as loss_function(target, output).
    trainIterator: iterable yielding (x, target, x_phase, target_phase).
    testIterator: iterable with the same item layout, used for validation.
    num_epochs: number of passes over trainIterator.
    save_name: checkpoint path without extension; '.pth.tar' is appended.
    scheduler_step_size: StepLR step size, in epochs.
    scheduler_gamma: StepLR multiplicative decay factor.
    plotting: when True, redraw the loss curves after every epoch but the first.
    return_results: when True, return the loss histories.
    use_cuda: when True, move every batch to the GPU.

  Returns:
    (train_loss, valid_loss), two lists holding one mean loss per epoch, when
    return_results is True; otherwise None.
  """
  print("Training network...")
  tmp_img = "tmp_ae_out.png"
  train_loss = []
  valid_loss = []
  Scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)
  # Start from +inf so the first epoch always counts as an improvement and is
  # checkpointed. Seeding with the first epoch's own loss made the strict "<"
  # test fail on epoch 1, and no file was written at all if no later epoch
  # improved on it.
  best_loss = float('inf')

  for epoch in range(num_epochs):
    # ---- Training pass ----
    batch_loss = []
    net.train()
    for x,target,x_phase,target_phase in trainIterator:
      optimizer.zero_grad()
      # The input is a plain tensor: only the parameters need gradients.
      if use_cuda:
        x = x.cuda()
        target=target.cuda()
      output = net(x)
      # Argument order (target, output) is kept as-is because the custom loss
      # functions selectable in this notebook may depend on it.
      loss = loss_function(target.squeeze(), output.squeeze())
      batch_loss.append(loss.item())
      loss.backward()
      optimizer.step()
    train_loss.append(np.mean(batch_loss))

    # ---- Validation pass, no gradient tracking ----
    net.eval()
    batch_loss = []
    with torch.no_grad():
      for x,target,x_phase,target_phase in testIterator:
        if use_cuda:
          x = x.cuda()
          target = target.cuda()
        output = net(x)
        loss = loss_function(target.squeeze(), output.squeeze())
        batch_loss.append(loss.item())
    valid_loss.append(np.mean(batch_loss))

    Scheduler.step()

    # ---- Checkpointing ----
    current_loss = valid_loss[epoch]
    print(current_loss)
    print(best_loss)
    is_best = bool(current_loss < best_loss)
    # old_loss only feeds the progress message printed by save_checkpoint.
    old_loss = best_loss if is_best else 0
    best_loss = min(current_loss, best_loss)

    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict':  net.state_dict(),
        'best_loss': best_loss
    }, old_loss, best_loss, is_best, filename=str(save_name)+'.pth.tar')

    # ---- Plotting (skipped after the first epoch: a single point is no curve) ----
    if plotting and epoch > 0:
      f, ax1 = plt.subplots(figsize=(8,8))

      ax1.set_title("Loss")
      ax1.set_xlabel('Epoch')
      ax1.set_ylabel('Loss')

      ax1.plot(np.arange(epoch+1), train_loss, color="blue", linestyle="-.")
      ax1.plot(np.arange(epoch+1), valid_loss, color="red", linestyle="-.")
      ax1.legend(['Training','Validation'])
      plt.tight_layout()
      # Render through a temporary PNG so the figure can be shown as an image
      # and replaced in place on the next epoch.
      plt.savefig(tmp_img)
      plt.close(f)
      display(Image(filename=tmp_img))
      clear_output(wait=True)

      os.remove(tmp_img)

  if return_results:
    return train_loss,valid_loss


## Evaluation pipeline - ONLY for waveform so far

In [0]:
def evaluationMetrics(net,testIterator,mode,cuda=True,cutoff=100):
  """Print PESQ statistics for the network output and for the unprocessed input.

  Every item drawn from testIterator is flattened to one 1-D signal (the whole
  batch is concatenated by view(-1)) and scored against the flattened target.

  Args:
    net: model to evaluate; it is switched to eval mode.
    testIterator: iterable yielding (x, target, x_phase, target_phase).
    mode: 'waveform' or 'stft'. Only validated; the waveform path is the only
      one implemented.
    cuda: when True the input batch is moved to the GPU before inference.
    cutoff: stop after this many items have been scored.

  Returns:
    None. Results are printed.
  """
  assert mode in ('waveform','stft')

  ### ONLY waveform IMPLEMENTATION ###

  SAMPLE_RATE=16000
  net.eval()

  def _print_stats(name, values):
    # Skip the statistics for metrics that were never computed; np.mean([])
    # would print "nan" and raise a RuntimeWarning.
    if len(values) == 0:
      print('No ' + name + ' values were computed')
      return
    print('Mean loss of ' + name + ': ' + str(np.mean(values)))
    print('Variance loss of ' + name + ': '+ str(np.var(values)))

  # Metric values for the network output ("loss*") and for the raw input ("dummy*").
  lossPESQ = []
  lossSTOI = []
  dummyPESQ = []
  dummySTOI = []
  count=0
  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    for x,target,x_phase,target_phase in testIterator:
      if cuda:
        x=x.cuda()
      y=net(x)
      y=y.view(-1).cpu().detach().numpy()
      x=x.view(-1).cpu().detach().numpy()
      target=target.view(-1).cpu().detach().numpy()

      #metrics:
      lossPESQ.append(pesq(SAMPLE_RATE, target, y, 'wb'))
      dummyPESQ.append(pesq(SAMPLE_RATE, target, x, 'wb'))
      # STOI is currently disabled:
      # lossSTOI.append(stoi(target, y, SAMPLE_RATE, extended=False))
      # dummySTOI.append(stoi(target, x, SAMPLE_RATE, extended=False))

      count+=1
      if count==cutoff:
        break

  print(count)
  # Report the number of items actually scored rather than a fixed figure.
  print(f'The results below is based on the {count} reconstructed sound files')
  _print_stats('PESQ', lossPESQ)
  _print_stats('STOI', lossSTOI)


  print(f'The results below is based on the {count} noisy sound files')
  _print_stats('PESQ', dummyPESQ)
  _print_stats('STOI', dummySTOI)

# Scheduler

## Model step

In [0]:
# Build the FFTNet model together with an Adam optimizer (lr=0.001) and an L1 loss.
# kernel_size is accepted by PipelineModel but is not forwarded to FFTNetModel,
# so it has no effect when CNN='FFTnet'.
net,optimizer,loss_function=PipelineModel(CNN='FFTnet',kernel_size=10,learning_rate=0.001, optim='Adam',loss='L1',use_cuda=cuda, layers = 10)

Receptive Field =   2047  samples


## Data step

In [0]:
 ## Get training and test iterators:
# Build the train/test iterators; the chunk size is tied to the model's receptive field.
# NOTE(review): the logged file counts (30 for train, 9 for test) suggest the ranges are
# half-open, in which case ground-truth index 30 is used by neither split - confirm.
trainIterator, testIterator=PipelineData(SAMPLE_RATE=16000, CNN='FFTnet',
                       train_GT_range=[0,30],
                       train_IR_range=[0,4],
                       test_GT_range=[31,40],
                       test_IR_range=[0,4],
                       batch_size=128, chunk_size=net.receptive_field,
                       matrix_ops=True) #Set matrix_ops to False, if cuda does not have enough memory

Found 30 ground truth files in data set
Found 4 impulse response files in data set
Found 1 noise files in data set

Running GPU.
GT
Processing file 1 of 30...
IR
Processing file 1 of 4...
N
Processing file 1 of 1...
Found 9 ground truth files in data set
Found 4 impulse response files in data set
Found 1 noise files in data set

Running GPU.
GT
Processing file 1 of 9...
IR
Processing file 1 of 4...
N
Processing file 1 of 1...
Shape of train.GT: torch.Size([30, 4, 32000])
Shape of test.GT: torch.Size([9, 4, 32000])
Shape of train.GT_IR: torch.Size([30, 4, 32000])
Shape of test.GT_IR: torch.Size([9, 4, 32000])


In [0]:
# Restore previously saved weights (the file holds a bare state_dict).
# On a CPU-only runtime, map_location remaps tensors that were saved from a GPU
# session; with a GPU available the default placement is kept.
cp = torch.load('/content/drive/My Drive/reverberation/Models/SEFFT_net_short.pth.tar',
                map_location=None if cuda else 'cpu')

net.load_state_dict(cp)
net.eval()

<All keys matched successfully>

In [0]:
 ## See how many parameters are in the model (Optional)
# Tally element counts: every parameter tensor first, then only those that receive gradients.
pytorch_total_params = sum(map(torch.numel, net.parameters()))
pytorch_trainable_params = sum(
    map(torch.numel, filter(lambda p: p.requires_grad, net.parameters())))

print("Total parameters in the model: {}".format(pytorch_total_params))
print("Total trainable parameters in the model: {}".format(pytorch_trainable_params))

NameError: ignored

## Training step

In [0]:
 ## Perform the training loop:
# Train for 30 epochs; StepLR with step size 1 and gamma 0.9 multiplies the
# learning rate by 0.9 after every epoch.
# PipelineFit appends '.pth.tar' to save_name, so the best-epoch checkpoint goes
# to 'FFTNet.pth.tar' relative to the current working directory.
train_loss,valid_loss=PipelineFit(net,optimizer,loss_function,
                                  trainIterator,testIterator,
                                  num_epochs=30,
                                  save_name="FFTNet",
                                  scheduler_step_size=1,
                                  scheduler_gamma=0.9,
                                  plotting=True,return_results=True,use_cuda=cuda)

Training network...


In [0]:
 # Persist the network's current weights to Drive as a bare state_dict.
 torch.save(net.state_dict(), '/content/drive/My Drive/reverberation/Models/SEFFT_net_short.pth.tar')

In [0]:
# net.load_state_dict(torch.load('/content/drive/My Drive/reverberation/Models/1d_2000_BN_DO.p'))
# Switch the network to evaluation mode; as the last expression of the cell,
# the returned module is also displayed.
net.eval()

FFTNetModel(
  (layers): ModuleList(
    (0): FFTNet_block(
      (conv1_1): Conv1d(1, 256, kernel_size=(1,), stride=(1,))
      (conv1_2): Conv1d(1, 256, kernel_size=(1,), stride=(1,))
      (conv1_3): Conv1d(1, 256, kernel_size=(1,), stride=(1,))
      (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (relu): ReLU()
    )
    (1): FFTNet_block(
      (conv1_1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (conv1_2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (conv1_3): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (relu): ReLU()
    )
    (2): FFTNet_block(
      (conv1_1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (conv1_2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (conv1_3): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (relu): ReLU()
    )
    (3): FFTNet_block(
      (conv1_1): Conv1

## Evaluation step

In [0]:
SAMPLE_RATE = 16000
# Take the entry at index [0, 0] of the test set, both as network input and as reference.
# NOTE(review): GT_IR is presumably the ground truth convolved with an impulse
# response - confirm against the iterator class.
input_file = testIterator.GT_IR[0,0,:]
GT = testIterator.GT[0,0,:]
# applyToFile is defined elsewhere in the notebook; it receives the input reshaped to (1, 1, samples).
result = applyToFile(net,input_file.view(1,1,-1))
# Flatten the reference to a 1-D NumPy array for the metric functions.
target=GT.view(-1).cpu().detach().numpy()
# Collapse result to 1-D using the length of its first axis
# (assumes any remaining axes have size 1 - TODO confirm).
result = np.reshape(result,result.shape[0])


8000


In [0]:
# STOI between the clean reference and the network output. Uses the SAMPLE_RATE
# constant defined alongside target/result instead of repeating the literal,
# matching the PESQ cell.
st = stoi(target, result, SAMPLE_RATE, extended=False)
print(st)

In [0]:
# Wide-band ('wb') PESQ, with `target` passed before `result`.
# NOTE(review): python-pesq is understood to expect (fs, reference, degraded, mode) -
# confirm against the installed version.
psq = pesq(SAMPLE_RATE, target, result, 'wb')
print(psq)

1.0542824268341064
