<a href="https://colab.research.google.com/github/advaitkumar3107/Speech-Denoising-Using-Deep-Learning/blob/master/VGG-16_with_wavelet_transform.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as datasets
import os
import glob
import sys
import scipy
import random
from PIL import Image
from torch.nn import init
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import librosa
import librosa.display
from tqdm import tqdm_notebook
from scipy import signal
from scipy.io.wavfile import read, write
from numpy.fft import fft, ifft
from google.colab import drive
from torch.autograd import Variable
from IPython.display import Audio
import pywt

drive.mount('/content/gdrive')
%cd /content/gdrive/My\ Drive/sample_audio_dataset
torch.cuda.manual_seed(7)
torch.manual_seed(7)
np.random.seed(7)
torch.cuda.empty_cache()

In [None]:
def weights_init_normal(m):
    """Initialise a single module in place from a narrow normal distribution.

    Meant to be used as the callback of ``nn.Module.apply``.

    * ``nn.Conv2d`` / ``nn.Linear``: weights are drawn from N(0, 0.02);
      biases are left as they are.
    * ``nn.BatchNorm2d``: the scale is drawn from N(1, 0.02) and the shift is
      set to 0.
    * Any other module is left untouched.

    :param m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False the layer has no learnable scale/shift at all.
        if m.weight is not None:
            init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    """Initialise a single module in place with Xavier (Glorot) normal weights.

    Meant to be used as the callback of ``nn.Module.apply``.

    * ``nn.Conv2d`` / ``nn.Linear``: weights use Xavier normal with gain 1;
      biases are left as they are.
    * ``nn.BatchNorm2d``: the scale is drawn from N(1, 0.02) and the shift is
      set to 0.
    * Any other module is left untouched.

    :param m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.xavier_normal_(m.weight.data, gain=1)
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False the layer has no learnable scale/shift at all.
        if m.weight is not None:
            init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)

def weights_init_kaiming(m):
    """Initialise a single module in place with Kaiming (He) normal weights.

    Meant to be used as the callback of ``nn.Module.apply``.

    * ``nn.Conv2d`` / ``nn.Linear``: weights use Kaiming normal with
      ``a=0`` and ``mode='fan_in'``; biases are left as they are.
    * ``nn.BatchNorm2d``: the scale is drawn from N(1, 0.02) and the shift is
      set to 0.
    * Any other module is left untouched.

    The type checks are plain ``isinstance`` tests. Comparing their boolean
    result with ``-1`` is always true, which would send every non-convolution
    module (activations, containers, batch norm) down the Linear branch.

    :param m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False the layer has no learnable scale/shift at all.
        if m.weight is not None:
            init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    """Initialise a single module in place with orthogonal weights.

    Meant to be used as the callback of ``nn.Module.apply``. The layer kind
    is detected from the class name, so every class whose name contains
    ``Conv``, ``Linear`` or ``BatchNorm`` is covered.

    * ``Conv*`` / ``Linear*``: weights become (semi-)orthogonal with gain 1;
      biases are left as they are.
    * ``BatchNorm*``: the scale is drawn from N(1, 0.02) and the shift is
      set to 0.
    * Any other module, and any matching module that has no ``weight``
      tensor, is left untouched.

    :param m: the module to initialise.
    """
    classname = m.__class__.__name__
    # Name matching can also hit wrappers or affine=False batch norms that
    # carry no parameters of their own, so look the tensors up defensively.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        if weight is not None:
            init.orthogonal_(weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def init_weights(net, init_type='normal'):
    """Run the named weight-initialisation scheme over ``net``.

    The chosen per-module initialiser is handed to ``net.apply``.

    :param net: object exposing ``apply`` (an ``nn.Module``).
    :param init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'`` or
        ``'orthogonal'``.
    :raises NotImplementedError: if ``init_type`` is none of the above.
    """
    if init_type == 'normal':
        initialiser = weights_init_normal
    elif init_type == 'xavier':
        initialiser = weights_init_xavier
    elif init_type == 'kaiming':
        initialiser = weights_init_kaiming
    elif init_type == 'orthogonal':
        initialiser = weights_init_orthogonal
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

    net.apply(initialiser)


class _GridAttentionBlockND(nn.Module):
    """Additive grid attention gate for 2-D or 3-D feature maps.

    A feature map ``x`` and a gating feature map ``g`` are projected to a
    common channel count, summed, passed through ReLU and squashed to a
    single-channel sigmoid map. That map is resized to the size of ``x`` and
    multiplied onto it; the gated features then go through a 1x1 convolution
    followed by batch normalisation.

    :param in_channels: channel count of ``x``.
    :param gating_channels: channel count of ``g``.
    :param inter_channels: channel count of the intermediate projection;
        defaults to ``in_channels // 2`` (at least 1).
    :param dimension: 2 or 3, selecting the Conv/BatchNorm flavour and the
        bilinear / trilinear resampling mode.
    :param sub_sample_factor: kernel size and stride of the convolution that
        down-samples ``x``; a tuple, a list, or a single value that is
        repeated ``dimension`` times.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3,
                 sub_sample_factor=(2,2,2)):
        super(_GridAttentionBlockND, self).__init__()

        assert dimension in [2, 3]

        # Downsampling rate for the input featuremap
        if isinstance(sub_sample_factor, tuple):
            self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list):
            self.sub_sample_factor = tuple(sub_sample_factor)
        else:
            self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            # Only reachable when assertions are disabled (python -O).
            # NotImplemented is a comparison sentinel, not an exception, so
            # raising it would surface as a confusing TypeError.
            raise NotImplementedError('dimension must be 2 or 3, got %r' % (dimension,))

        # Output transform
        self.W = nn.Sequential(
            conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='xavier')

        # Define the operation
        self.operation_function = self._concatenation

    def forward(self, x, g):
        '''
        Gate ``x`` with the attention map computed from ``x`` and ``g``.

        :param x: feature map, (b, in_channels, *spatial)
        :param g: gating feature map, (b, gating_channels, *spatial'); it is
                  resized to the down-sampled size of ``x`` internally
        :return: tuple ``(W_y, attention)`` where ``W_y`` has the shape of
                 ``x`` and ``attention`` is the single-channel sigmoid map
                 resized to the spatial size of ``x``
        '''
        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        """Additive attention: sigmoid(psi(relu(theta(x) + phi(g)))) applied to x."""
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # theta => (b, c, *spatial) -> (b, i_c, *spatial / sub_sample_factor)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        # g (b, g_c, *spatial') -> phi_g (b, i_c, *spatial'), resized to match theta_x.
        # align_corners=False is what the deprecated F.upsample defaulted to,
        # so the numerical result is unchanged.
        phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode,
                              align_corners=False)
        f = F.relu(theta_x + phi_g, inplace=True)

        #  psi^T * f -> (b, 1, *spatial / sub_sample_factor)
        sigm_psi_f = torch.sigmoid(self.psi(f))

        # upsample the attentions and multiply
        sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode,
                                   align_corners=False)
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f


class GridAttentionBlock2D(_GridAttentionBlockND):
    """Grid attention gate fixed to two spatial dimensions.

    Thin convenience wrapper: it forwards its arguments to the N-D base
    class with ``dimension=2`` and a 2-D default sub-sampling factor.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None,
                 sub_sample_factor=(2,2)):
        super().__init__(
            in_channels,
            gating_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample_factor=sub_sample_factor,
        )


In [None]:
### Defining UNET Architecture ###
class unet(nn.Module):
    """Three-level U-Net whose skip connections pass through attention gates.

    The encoder applies three contracting blocks (64, 128, 256 channels),
    each followed by 2x2 max pooling. The bottleneck and the two decoder
    blocks each end in a stride-2 transposed convolution. Before an encoder
    feature map is concatenated with the decoder signal it is gated by a
    ``GridAttentionBlock2D`` and the gated result is added back onto it.

    :param in_channel: number of channels of the input image.
    :param out_channel: number of channels of the output; the final layer is
        a sigmoid, so outputs lie in (0, 1).
    """

    def contracting_block(self, in_channels, out_channels):
        """Two 3x3 conv -> ReLU -> BatchNorm stages; spatial size is kept."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels):
        """Two conv stages followed by a stride-2 transposed convolution.

        With kernel 3, stride 2, padding 1 and output_padding 1 the
        transposed convolution doubles both spatial dimensions.
        """
        block = nn.Sequential(
            nn.Conv2d(in_channels, mid_channel, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(mid_channel),
            nn.Conv2d(mid_channel, mid_channel, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(mid_channel),
            nn.ConvTranspose2d(mid_channel, out_channels, 3, 2, padding=1, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels):
        """Two conv stages, then a 3x3 conv to ``out_channels`` and a sigmoid."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, mid_channel, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(mid_channel),
            nn.Conv2d(mid_channel, mid_channel, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(mid_channel),
            nn.Conv2d(mid_channel, out_channels, 3, padding=1), nn.Sigmoid(),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(unet, self).__init__()

        self.encode1 = self.contracting_block(in_channel, 64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.encode2 = self.contracting_block(64, 128)
        self.maxpool2 = nn.MaxPool2d(2)
        self.encode3 = self.contracting_block(128, 256)
        self.maxpool3 = nn.MaxPool2d(2)

        self.bottleneck = self.expansive_block(256, 512, 256)

        # Decoder inputs are twice as wide as the blocks' outputs because the
        # gated skip connection is concatenated onto the up-sampled signal.
        self.decode3 = self.expansive_block(512, 256, 128)
        self.decode2 = self.expansive_block(256, 128, 64)

        # One attention gate per skip connection, deepest first.
        self.ag1 = GridAttentionBlock2D(256, 256)
        self.ag2 = GridAttentionBlock2D(128, 128)
        self.ag3 = GridAttentionBlock2D(64, 64)

        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass):
        """Fit ``bypass`` to the spatial size of ``upsampled`` and concatenate.

        ``bypass`` is first cropped (or zero-padded, when it is the smaller
        one) symmetrically on all four sides by half the *height* difference,
        then bilinearly resampled so that both dimensions match exactly.

        :param upsampled: decoder feature map, (b, c1, H, W).
        :param bypass: skip-connection feature map, (b, c2, H', W').
        :return: tensor of shape (b, c1 + c2, H, W) with ``upsampled`` first.
        """
        c = (bypass.size()[2] - upsampled.size()[2]) // 2
        # Negative padding crops; a negative c therefore zero-pads instead.
        bypass = F.pad(bypass, (-c, -c, -c, -c))
        # align_corners=False is what the deprecated F.upsample defaulted to,
        # so the numerical result is unchanged.
        bypass = F.interpolate(bypass, size=(upsampled.size(2), upsampled.size(3)),
                               mode='bilinear', align_corners=False)
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the network on ``x`` of shape (b, in_channel, H, W)."""
        encode1 = self.encode1(x)
        maxpool1 = self.maxpool1(encode1)
        encode2 = self.encode2(maxpool1)
        maxpool2 = self.maxpool2(encode2)
        encode3 = self.encode3(maxpool2)
        maxpool3 = self.maxpool3(encode3)

        bottleneck = self.bottleneck(maxpool3)

        # Each gate returns (gated_features, attention_map); only the gated
        # features are used and they are added residually onto the skip path.
        gate1 = self.ag1(encode3, bottleneck)
        encode3 = encode3 + gate1[0]

        decode3 = self.crop_and_concat(bottleneck, encode3)
        cat_layer2 = self.decode3(decode3)

        gate2 = self.ag2(encode2, cat_layer2)
        encode2 = encode2 + gate2[0]

        decode2 = self.crop_and_concat(cat_layer2, encode2)
        cat_layer1 = self.decode2(decode2)

        gate3 = self.ag3(encode1, cat_layer1)
        encode1 = encode1 + gate3[0]

        decode1 = self.crop_and_concat(cat_layer1, encode1)
        final = self.final_layer(decode1)

        return final

In [None]:
class AdaptiveBatchNorm2d(nn.Module):
    """Learned blend of the identity and 2-D batch normalisation.

    Computes ``a * x + b * BatchNorm2d(x)`` where ``a`` and ``b`` are
    learnable scalars of shape (1, 1, 1, 1).

    The scalars start at ``a = 1`` and ``b = 0``, i.e. the layer is the
    identity at initialisation. Allocating them with
    ``torch.FloatTensor(1, 1, 1, 1)`` would leave them holding whatever
    happened to be in memory, which makes training start from arbitrary and
    possibly non-finite mixing weights.

    :param num_features: channel count passed to ``nn.BatchNorm2d``.
    :param eps: passed to ``nn.BatchNorm2d``.
    :param momentum: passed to ``nn.BatchNorm2d``.
    :param affine: passed to ``nn.BatchNorm2d``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(AdaptiveBatchNorm2d, self).__init__()
        self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine)
        self.a = nn.Parameter(torch.ones(1, 1, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, 1, 1, 1))

    def forward(self, x):
        """Return ``a * x + b * bn(x)`` for ``x`` of shape (b, c, h, w)."""
        return self.a * x + self.b * self.bn(x)


class vgg_type(nn.Module):
    """Stack of fourteen 1x3 convolutions with exponentially growing dilation.

    Layout (all hidden layers have 64 channels and no bias):

    * ``conv1``: 1 -> 64 channels, followed by ``norm1`` and ReLU.
    * ``conv2`` .. ``conv13``: dilation 1, 2, 4, ..., 2048 with padding equal
      to the dilation, so the width of the signal is preserved.
    * ``conv14``: undilated 1x3 convolution.
    * ``final``: 1x1 convolution down to a single channel.

    ``conv2`` .. ``conv14`` are each followed by their ``normN`` layer and a
    LeakyReLU with slope 0.2. Layers are created in the same order and under
    the same attribute names as a hand-written version would use, so
    parameter names and ordering are unaffected by the loops below.
    """

    def __init__(self):
        super().__init__()
        width = 64
        kernel = [1, 3]

        self.conv1 = nn.Conv2d(1, width, kernel, padding=[0, 1], bias=False)
        self.norm1 = AdaptiveBatchNorm2d(width)
        self.relu = nn.ReLU(inplace=True)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

        # conv2 .. conv13: the dilation doubles from 1 up to 2048.
        for idx in range(2, 14):
            rate = 2 ** (idx - 2)
            setattr(self, 'conv%d' % idx,
                    nn.Conv2d(width, width, kernel, padding=[0, rate], dilation=rate, bias=False))
        self.conv14 = nn.Conv2d(width, width, kernel, padding=[0, 1], bias=False)

        # norm2 .. norm14, one per convolution created above.
        for idx in range(2, 15):
            setattr(self, 'norm%d' % idx, AdaptiveBatchNorm2d(width))

        self.final = nn.Conv2d(width, 1, [1, 1])

    def forward(self, x):
        """Apply the full stack to ``x`` and return the single-channel output."""
        x = self.relu(self.norm1(self.conv1(x)))

        for idx in range(2, 15):
            conv = getattr(self, 'conv%d' % idx)
            norm = getattr(self, 'norm%d' % idx)
            x = self.lrelu(norm(conv(x)))

        return self.final(x)

In [None]:
class AudioDataset(torch.utils.data.Dataset):
    """In-memory dataset of wavelet coefficients for noisy/clean audio pairs.

    A random sixteenth of ``ids`` is selected with ``random.sample``. For
    each selected file name the noisy recording (under
    ``noisy_dataset/noisy_trainset_56spk_wav/``) and the clean recording
    (under ``clean_dataset/``) are loaded with ``librosa.load`` and split by
    a single-level ``'db1'`` discrete wavelet transform. Everything is
    loaded eagerly in the constructor.

    Each item is the tuple
    ``(approx_input, detailed_input, approx_target, detailed_target)``;
    every element is a tensor with a leading dimension of size 1.
    """

    def __init__(self, ids):
        self.ids = ids
        self.length = len(self.ids) // 16

        self.approx_inputs, self.approx_targets = [], []
        self.detailed_inputs, self.detailed_targets = [], []

        self.random_ids = random.sample(self.ids, self.length)

        for file_name in self.random_ids:
            coarse, fine = self._wavelet_halves('noisy_dataset/noisy_trainset_56spk_wav/' + file_name)
            self.approx_inputs.append(coarse)
            self.detailed_inputs.append(fine)

            coarse, fine = self._wavelet_halves('clean_dataset/' + file_name)
            self.approx_targets.append(coarse)
            self.detailed_targets.append(fine)

    @staticmethod
    def _wavelet_halves(path):
        """Load ``path`` and return its two DWT outputs as (1, n) tensors."""
        samples, _ = librosa.load(path)
        coeffs = pywt.dwt(samples, 'db1')
        first, second = (torch.from_numpy(part).unsqueeze(0) for part in coeffs)
        return first, second

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return (
            self.approx_inputs[index],
            self.detailed_inputs[index],
            self.approx_targets[index],
            self.detailed_targets[index],
        )

In [None]:
# os.listdir returns names in arbitrary order; sorting them makes the seeded
# sampling inside AudioDataset and the random split below pick the same files
# on every machine.
ids = sorted(os.listdir('clean_dataset'))
dataset = AudioDataset(ids)

# 80 / 20 split between training and validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# One recording per batch: the recordings are not padded to a common length
# here, so they are fed to the models one at a time.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 1, shuffle = True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle = False)
torch.cuda.empty_cache()

In [None]:
### Training Functions ###

def train(dataloader, approx_model, detailed_model, approx_optimizer, detailed_optimizer, criterion):
    """Train both wavelet-band models for one epoch.

    The two models are independent: each has its own optimizer and loss and
    is updated on every batch. The mean loss per batch of the epoch is
    appended (as a Python float) to the module-level lists
    ``approx_train_losses`` and ``detailed_train_losses``, which must exist
    before this function is called.

    :param dataloader: yields ``(approx_input, detailed_input,
        approx_target, detailed_target)`` batches.
    :param approx_model: model fed with the approximation coefficients.
    :param detailed_model: model fed with the detail coefficients.
    :param approx_optimizer: optimizer stepping ``approx_model``.
    :param detailed_optimizer: optimizer stepping ``detailed_model``.
    :param criterion: loss called as ``criterion(output, target)``.
    """
    detailed_model.train()
    approx_model.train()

    # Send each batch to wherever its model lives instead of assuming CUDA.
    approx_device = next(approx_model.parameters()).device
    detailed_device = next(detailed_model.parameters()).device

    num_batches = len(dataloader)
    approx_train_losses.append(0)
    detailed_train_losses.append(0)
    progbar = tqdm_notebook(total = num_batches, desc = 'Approx_train')
    progbar1 = tqdm_notebook(total = num_batches, desc = 'Detailed_train')

    try:
        for i, (approx_input, detailed_input, approx_target, detailed_target) in enumerate(dataloader):
            detailed_optimizer.zero_grad()
            approx_optimizer.zero_grad()

            # Add one leading dimension to every tensor of the batch before
            # it is handed to the models and the loss.
            approx_input = approx_input.unsqueeze(0).to(approx_device)
            approx_target = approx_target.unsqueeze(0).to(approx_device)
            detailed_input = detailed_input.unsqueeze(0).to(detailed_device)
            detailed_target = detailed_target.unsqueeze(0).to(detailed_device)

            approx_output = approx_model(approx_input)
            detailed_output = detailed_model(detailed_input)

            approx_error = criterion(approx_output, approx_target)
            detailed_error = criterion(detailed_output, detailed_target)

            approx_error.backward()
            detailed_error.backward()

            approx_optimizer.step()
            detailed_optimizer.step()

            # .item() stores plain floats, so the history lists do not keep
            # device tensors alive.
            approx_train_losses[-1] = approx_train_losses[-1] + approx_error.item()
            detailed_train_losses[-1] = detailed_train_losses[-1] + detailed_error.item()

            progbar.set_description('Approx Train (loss=%.4f)' % (approx_train_losses[-1]/(i+1)))
            progbar.update(1)

            progbar1.set_description('Detailed Train (loss = %.4f)' % (detailed_train_losses[-1]/(i+1)))
            progbar1.update(1)
    finally:
        progbar.close()
        progbar1.close()

    # An empty loader leaves the recorded loss at 0 instead of dividing by zero.
    if num_batches:
        approx_train_losses[-1] = approx_train_losses[-1]/num_batches
        detailed_train_losses[-1] = detailed_train_losses[-1]/num_batches

def val(dataloader, approx_model, detailed_model, criterion):
    """Evaluate both models on ``dataloader`` and checkpoint any improvement.

    The mean loss per batch is appended (as a Python float) to the
    module-level lists ``approx_val_losses`` and ``detailed_val_losses``.
    When a model beats its best loss so far, the module-level
    ``approx_best_loss`` / ``detailed_best_loss`` is updated and the whole
    model object is saved to ``dwt_approx_model_best.ckpt.t7`` /
    ``dwt_detailed_model_best.ckpt.t7``.

    :param dataloader: yields ``(approx_input, detailed_input,
        approx_target, detailed_target)`` batches.
    :param approx_model: model fed with the approximation coefficients.
    :param detailed_model: model fed with the detail coefficients.
    :param criterion: loss called as ``criterion(output, target)``.
    """
    global approx_best_loss, detailed_best_loss

    approx_model.eval()
    detailed_model.eval()

    # Send each batch to wherever its model lives instead of assuming CUDA.
    approx_device = next(approx_model.parameters()).device
    detailed_device = next(detailed_model.parameters()).device

    num_batches = len(dataloader)
    progbar = tqdm_notebook(total = num_batches, desc = 'approx val')
    progbar1 = tqdm_notebook(total = num_batches, desc = 'detailed val')

    approx_val_losses.append(0)
    detailed_val_losses.append(0)

    try:
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for i, (approx_input, detailed_input, approx_target, detailed_target) in enumerate(dataloader):
                approx_input = approx_input.unsqueeze(0).to(approx_device)
                approx_target = approx_target.unsqueeze(0).to(approx_device)
                detailed_input = detailed_input.unsqueeze(0).to(detailed_device)
                detailed_target = detailed_target.unsqueeze(0).to(detailed_device)

                approx_output = approx_model(approx_input)
                approx_error = criterion(approx_output, approx_target)

                detailed_output = detailed_model(detailed_input)
                detailed_error = criterion(detailed_output, detailed_target)

                approx_val_losses[-1] = approx_val_losses[-1] + approx_error.item()
                detailed_val_losses[-1] = detailed_val_losses[-1] + detailed_error.item()

                progbar.set_description('Approx Val (loss = %.4f)' % (approx_val_losses[-1] /(i+1)))
                progbar.update(1)

                progbar1.set_description('Detailed Val (loss = %.4f)' % (detailed_val_losses[-1] /(i+1)))
                progbar1.update(1)
    finally:
        progbar.close()
        progbar1.close()

    # Nothing was evaluated: a loss of 0 must not be mistaken for a new best.
    if num_batches == 0:
        return

    approx_val_losses[-1] = approx_val_losses[-1]/num_batches
    detailed_val_losses[-1] = detailed_val_losses[-1]/num_batches

    # The complete model objects are pickled (not just their state_dict)
    # because the loading code in this notebook expects that layout.
    if approx_best_loss > approx_val_losses[-1]:
        approx_best_loss = approx_val_losses[-1]
        print('Approx Model SAVING....')
        state = {'model' : approx_model}

        torch.save(state, 'dwt_approx_model_best' + '.ckpt.t7')

    if detailed_best_loss > detailed_val_losses[-1]:
        detailed_best_loss = detailed_val_losses[-1]
        print('Detailed Model SAVING....')
        state = {'model' : detailed_model}

        torch.save(state, 'dwt_detailed_model_best' + '.ckpt.t7')

In [None]:
# Resume from the rolling checkpoint. To start from scratch instead, build
# fresh networks with the two commented-out lines below.
#approx_model = vgg_type().cuda()
#detailed_model = vgg_type().cuda()
# NOTE(review): the checkpoint holds whole pickled model objects; newer
# PyTorch releases may need torch.load(..., weights_only=False) - confirm.
checkpoints = torch.load('audio_load.ckpt.t7')
approx_model, detailed_model = checkpoints['approx_model'], checkpoints['detailed_model']
num = checkpoints['epoch']

# Summed (not averaged) absolute error over all coefficients.
criterion = torch.nn.L1Loss(reduction = 'sum')

# Per-epoch loss histories, filled in by train() and val().
approx_train_losses, approx_val_losses = [], []
detailed_train_losses, detailed_val_losses = [], []

# Epochs still to run, counted against a target of 1816.
epochs = 1816 - num

lrs = [1e-4] * 5

In [None]:
# Best validation losses so far, carried over from the checkpoint so that a
# resumed run does not overwrite a better model saved earlier.
approx_best_loss = checkpoints['approx_best_loss']
detailed_best_loss = checkpoints['detailed_best_loss']

approx_optimizer = torch.optim.Adam(approx_model.parameters(), lr = lrs[4])
detailed_optimizer = torch.optim.Adam(detailed_model.parameters(), lr = lrs[4])

for epoch in range(epochs):
    train(train_loader, approx_model, detailed_model, approx_optimizer, detailed_optimizer, criterion)
    val(val_loader, approx_model, detailed_model, criterion)

    # Store the cumulative number of finished epochs. `epoch` restarts at 0
    # on every run, so saving it directly would make the next resume
    # (`epochs = 1816 - num`) forget the epochs done in earlier runs.
    checkpoints = {
        'approx_model' : approx_model,
        'detailed_model' : detailed_model,
        'epoch' : num + epoch + 1,
        'approx_best_loss' : approx_best_loss,
        'detailed_best_loss' : detailed_best_loss,
    }
    torch.save(checkpoints, 'audio_load.ckpt.t7')

    print('Epoch : %d/%d' % (epoch+1, epochs))
