In [None]:
import gc

import os
import sys
import glob
import random

import h5py

import librosa
from librosa.core import audio

import numpy as np
from numpy.core.fromnumeric import shape

import torch
import torch.nn as nn
import torch.nn.functional as functional

from torch import FloatTensor, unsqueeze
from torch.utils.data import Dataset

from torch.nn.modules.activation import ReLU, Softplus

from torch.nn.modules import activation, padding
from torch.autograd import Variable

from torch.nn.modules import loss

# Display the name of CUDA device 0 as this cell's output (confirms a GPU
# runtime is attached); this call fails when no CUDA device is available.
torch.cuda.get_device_name(0)


In [None]:
# Mount Google Drive (Colab only) so the '/content/drive/My Drive/...' data and
# checkpoint paths used by later cells are reachable.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
# datafetcher.py

def tensor_size(x):
    """Return the dimensions reported by ``x.size()`` as a plain list of ints."""
    return [dim for dim in x.size()]

class DataFetcher:
    """Load paired (input, target) audio examples from the project's HDF5 files.

    Each file holds one dataset ('train' or 'test') whose rows alternate between
    network input and expected output: even rows are inputs, odd rows targets.
    """

    def __init__(self):
        self.wavelist_dict = {}
        self.train_dict = {}
        self.test_dict = {}
        self.num_samples = 0

        self.train_tensor_list = []
        self.test_tensor_list = []

        # Replace with path to your train/test data
        self.tensor_train_path = '/content/drive/My Drive/data/processed_train.hdf5'
        self.tensor_test_path = '/content/drive/My Drive/data/processed_test.hdf5'

        # 'gpu' is not a valid torch device string; store a real device identifier.
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    def get_train_data(self):
        """Return the h5py 'train' dataset.

        The file stays open for as long as the returned dataset is referenced;
        get_train_test_data() reads the data and closes its files instead.
        """
        hf = h5py.File(self.tensor_train_path, 'r')
        return hf['train']

    def get_test_data(self):
        """Return the h5py 'test' dataset (file stays open while referenced)."""
        hf = h5py.File(self.tensor_test_path, 'r')
        return hf['test']

    @staticmethod
    def _read_pairs(path, key):
        """Read dataset ``key`` from ``path`` and split it into (inputs, targets) float32 tensors."""
        # Read everything in one call and close the file right away.
        with h5py.File(path, 'r') as hf:
            rows = np.asarray(hf[key][()])
        # Even rows are inputs, odd rows targets. One vectorised conversion
        # replaces building a tensor from a Python list of arrays.
        inputs = torch.from_numpy(np.ascontiguousarray(rows[0::2], dtype=np.float32))
        targets = torch.from_numpy(np.ascontiguousarray(rows[1::2], dtype=np.float32))
        return inputs, targets

    def get_train_test_data(self):
        """Return ``(X_train, X_test, y_train, y_test)`` as float32 CPU tensors."""
        X_train, y_train = self._read_pairs(self.tensor_train_path, 'train')
        X_test, y_test = self._read_pairs(self.tensor_test_path, 'test')
        return X_train, X_test, y_train, y_test


In [None]:
# utils.py

def tensor_size(x):
    """Give the shape of ``x`` (via ``x.size()``) as a Python list."""
    dims = tuple(x.size())
    return list(dims)

class DataGenerator():
    """Turn whole (input, target) clips into batches of overlapping windows.

    ``gen[idx]`` builds one batch from clip ``idx``:
    1. Zero-pad the clip by half a window on each side.
    2. Cut it into ``window_length``-sample windows with 50% overlap.
    3. Return the first ``batch_size`` windows.

    The class defines ``__getitem__`` but no ``__iter__``, so ``for batch in
    gen`` uses Python's sequence protocol and stops at the first IndexError.
    That error comes from indexing ``x``/``y`` past the last clip.
    """

    def __init__(self, x, y, batch_size=16, window_length=4096):
        self.x, self.y = x, y
        self.window_length = window_length  # samples per window, e.g. 4096
        self.hop_length = window_length // 2  # 50% overlap, e.g. 2048
        self.batch_size = batch_size

    def __len__(self):
        """Number of clips; one batch is produced per clip."""
        return len(self.x)

    def window_time_series(self, x, frame_length, hop_length):
        """Split 1-D tensor ``x`` into overlapping frames.

        Returns a tensor of shape ``(n_frames, frame_length)``. Frame ``k`` is
        ``x[k * hop_length : k * hop_length + frame_length]``, and a trailing
        partial frame is dropped.
        """
        # unfold produces exactly the frames of the slice-and-stack loop. The
        # copy keeps the result independent of x, as torch.stack did.
        return x.unfold(0, frame_length, hop_length).clone(memory_format=torch.contiguous_format)

    def slicing(self, x):
        """Zero-pad 1-D ``x`` by half a window per side, then window it."""
        x = functional.pad(x, (self.window_length // 2, self.window_length // 2), mode='constant')
        return self.window_time_series(x, self.window_length, self.hop_length)

    def ts(self, x):
        """Return ``x.size()`` as a list."""
        return list(x.size())

    def __getitem__(self, idx):
        """Return ``(batch_x, batch_y)``, each of shape ``(batch_size, window_length, 1)``.

        A clip that yields fewer than ``batch_size`` windows leaves its unused
        rows as zeros. Raising IndexError here instead would be read by the
        iteration protocol as "no more data", silently skipping all later clips.

        Raises:
            IndexError: if ``idx`` is past the last clip (ends iteration).
        """
        # Index first so an out-of-range idx raises IndexError before any work.
        x_clip = self.x[idx]
        y_clip = self.y[idx]

        x_w = self.slicing(x_clip.reshape(len(x_clip)))
        y_w = self.slicing(y_clip.reshape(len(y_clip)))

        batch_x = torch.zeros((self.batch_size, self.window_length, 1))
        batch_y = torch.zeros((self.batch_size, self.window_length, 1))

        n_x = min(self.batch_size, x_w.shape[0])
        n_y = min(self.batch_size, y_w.shape[0])
        batch_x[:n_x] = x_w[:n_x].unsqueeze(-1)
        batch_y[:n_y] = y_w[:n_y].unsqueeze(-1)

        return batch_x, batch_y


In [None]:
# layers.py

class SAAF(nn.Module):
    """Smooth adaptive activation function built from second-order spline pieces.

    ``forward(x) = x * kernel[-1] + sum_k kernel[k] * basisf(x, bp[2k], bp[2k + 1])``
    - ``bp``: ``break_points`` evenly spaced values in ``[-break_range, break_range]``.
    - ``kernel``: shape ``(num_segs + 1, num_features)``, one row per segment plus
      a final linear row, applied along the last input axis.

    The kernel is a registered, trainable parameter, zero-initialised. It is
    created once and then reused:
    - in ``__init__`` when ``num_features`` is given;
    - otherwise on the first forward call, using the input's last dimension.
    """

    def __init__(self, break_points, break_range, magnitude, num_features=None):
        super(SAAF, self).__init__()
        self.break_range = break_range
        self.magnitude = magnitude
        # A buffer (instead of a CUDA-pinned list) follows the module across
        # .to(device). It is non-persistent, so state_dict keys are unchanged.
        self.register_buffer(
            'break_points',
            torch.linspace(-break_range, break_range, break_points, dtype=torch.float32),
            persistent=False)
        self.num_segs = int(len(self.break_points) / 2)
        self.register_parameter('kernel', None)
        if num_features is not None:
            self._init_kernel(num_features, device=None, dtype=None)

    def _init_kernel(self, num_features, device, dtype):
        """Create the zero-initialised ``(num_segs + 1, num_features)`` trainable kernel."""
        self.kernel = nn.Parameter(
            torch.zeros((self.num_segs + 1, num_features), device=device, dtype=dtype))

    def basisf(self, x, s, e):
        """Spline basis for segment ``[s, e)``, scaled by ``magnitude``.

        Piecewise value:
        - 0 for ``x < s``;
        - ``0.5 * (x - s)**2`` for ``s <= x < e``;
        - ``(e - s) * (x - e) + 0.5 * (e - s)**2`` for ``x >= e``.
        """
        inside = torch.less_equal(s, x).float()
        before_end = torch.greater(e, x).float()
        quadratic = 0.5 * (x - s) ** 2 * inside * before_end
        linear = ((e - s) * (x - e) + 0.5 * (e - s) ** 2) * (1 - before_end)
        return (self.magnitude * (quadratic + linear)).to(x.dtype)

    def forward(self, x):
        """Apply the activation element-wise; the output has the same shape as ``x``.

        Raises:
            ValueError: if ``x``'s last dimension differs from the kernel's feature count.
        """
        num_features = x.shape[-1]
        if self.kernel is None:
            self._init_kernel(num_features, device=x.device, dtype=x.dtype)
        elif self.kernel.shape[1] != num_features:
            raise ValueError(
                "SAAF kernel has %d features but input has %d" % (self.kernel.shape[1], num_features))

        output = x * self.kernel[-1]
        for segment in range(self.num_segs):
            start = self.break_points[segment * 2]
            end = self.break_points[segment * 2 + 1]
            output = output + self.basisf(x, start, end) * self.kernel[segment]
        return output

class Deconvolution(nn.Module):
    """Output layer whose weights are tied to ``conv_layer`` (a Conv1d).

    Maps a ``(batch, time, channels)`` tensor to ``(batch, time, 1)``. It
    correlates the input with ``conv_layer.weight`` (``(channels, 1, k)``),
    reshaped to a single-output 2-D kernel of size ``k x 1``, with no bias.
    """

    def __init__(self, filters, kernel_size, conv_layer,
                    strides=1,
                    padding='valid'):
        super(Deconvolution, self).__init__()

        # Kept for backward compatibility; forward runs on the input's device.
        self.device = torch.device("cuda:0")

        self.filters = filters
        self.strides = (strides,)
        self.padding = padding
        self.input_dim = None
        self.input_length = None

        self.kernel_size = kernel_size
        self.conv_layer = conv_layer

    def forward(self, x):
        """Apply the tied correlation; returns ``(batch, time', 1)``.

        ``time' == time`` when ``padding='same'``.
        """
        input_dim = x.shape[-1]
        self.kernel_shape = (self.kernel_size, input_dim, self.filters)

        # (batch, time, channels) -> (batch, channels, time, 1)
        x = torch.unsqueeze(x, -1).permute(0, 2, 1, 3)

        # (channels, 1, k) -> (1, channels, k, 1): one output map summed over channels.
        W = torch.unsqueeze(self.conv_layer.weight, -1).permute(1, 0, 2, 3)

        # Use the tied weight directly. Building a fresh Conv2d would add a
        # randomly initialised bias on every call, and wrapping W in a new
        # nn.Parameter would stop gradients from reaching conv_layer.weight.
        output = functional.conv2d(x, W, bias=None, padding=self.padding)
        output = torch.squeeze(output, 3)

        return output.permute(0, 2, 1)


class Convolution1D_Locally_Connected(nn.Module):
    """Per-channel 1-D convolution with 'same' padding and no bias.

    Channel ``i`` of the ``(batch, channels, time)`` input is convolved with its
    own kernel ``kernel[i]``, and channels are never mixed. Only the first
    ``filters`` channels are used. The result has the filter axis first:
    ``(filters, batch, 1, time)``.
    """

    def __init__(self, filters, kernel_size,
                    strides=1,
                    padding='valid',
                    dilation_rate=1):
        super(Convolution1D_Locally_Connected, self).__init__()

        # Kept for backward compatibility; forward runs on the input's device.
        self.device = torch.device("cuda:0")
        self.filters = filters
        self.strides = (strides,)
        self.padding = padding
        self.data_format = 'channels_last'
        self.dilation_rate = (dilation_rate,)
        self.activation = 'linear'
        self.input_dim = None
        self.input_length = None

        self.kernel_size = kernel_size
        self.kernel_shape = (filters, 1, kernel_size)

        # Created on the default device; calling .to(device) on the module moves it.
        self.kernel = nn.Parameter(data=torch.zeros(self.kernel_shape), requires_grad=True)
        nn.init.xavier_uniform_(self.kernel)

    def forward(self, x):
        """Convolve each channel with its own kernel; returns ``(filters, batch, 1, time)``."""
        # A grouped conv (groups=filters) applies kernel[i] to channel i only,
        # matching the per-channel loop in one call. Using self.kernel directly
        # lets gradients reach the registered parameter.
        per_channel = functional.conv1d(
            x[:, :self.filters], self.kernel, bias=None, padding='same', groups=self.filters)
        # (batch, filters, time) -> (filters, batch, 1, time)
        return per_channel.permute(1, 0, 2).unsqueeze(2)


class BatchMaxPooling1d(nn.Module):
    """Non-overlapping max pooling over the last axis of an N-D tensor.

    All leading axes are kept. A last axis of length ``L`` becomes
    ``L // kernel_size``, and a trailing partial window is dropped.
    """

    def __init__(self, kernel_size):
        super(BatchMaxPooling1d, self).__init__()

        # Kept for backward compatibility; forward runs on the input's device.
        self.device = torch.device("cuda:0")
        self.kernel_size = kernel_size

    def forward(self, x):
        """Return ``x`` max-pooled along its last axis with window and stride ``kernel_size``."""
        in_shape = list(x.shape)
        out_shape = in_shape[:-1] + [in_shape[-1] // self.kernel_size]
        # Fold every leading axis into the channel axis of a single 3-D batch,
        # so one max_pool1d call covers all rows.
        pooled = functional.max_pool1d(x.reshape(1, -1, in_shape[-1]), self.kernel_size)
        return pooled.reshape(out_shape)


class LatentSpace_DNN_LocallyConnected_Dense(nn.Module):
    """Locally connected dense layer: one independent Dense(units) per slice along axis 1.

    The input ``x`` has shape ``(batch, split, ...)``. Its trailing axes are
    flattened to ``input_dim`` values per slice. Slice ``s`` is mapped with its
    own weight ``kernel[s]`` (``input_dim x units``) and bias ``bias[s]``. The
    output has shape ``(batch, split, units)``.

    Parameters are registered on the module and created once:
    - in ``__init__`` when both ``split`` and ``input_dim`` are given;
    - otherwise on the first forward call.
    Weights use Xavier-uniform initialisation; biases start at zero.

    ``activation`` may be a callable (e.g. ``torch.relu``) or an ``nn.Module``
    class such as ``nn.ReLU``, which is instantiated once.
    """

    def __init__(self, units,
                    activation=None,
                    use_bias=True,
                    split=None,
                    input_dim=None):
        super(LatentSpace_DNN_LocallyConnected_Dense, self).__init__()

        # Kept for backward compatibility; forward runs on the input's device.
        self.device = torch.device("cuda:0")

        self.units = units
        # Passing a module class (e.g. nn.ReLU) used to call the class with the
        # tensor; instantiate it instead.
        if isinstance(activation, type) and issubclass(activation, nn.Module):
            activation = activation()
        self.activation = activation
        self.use_bias = use_bias
        self.supports_masking = True
        self.split = split

        self.register_parameter('kernel', None)
        self.register_parameter('bias', None)
        if split is not None and input_dim is not None:
            self._init_parameters(split, input_dim, device=None, dtype=None)

    def _init_parameters(self, split, input_dim, device, dtype):
        """Create per-slice weights ``(split, input_dim, units)`` and biases ``(split, units)``."""
        self.split = split
        bound = (6.0 / (input_dim + self.units)) ** 0.5  # Xavier/Glorot uniform bound
        kernel = torch.empty((split, input_dim, self.units), device=device, dtype=dtype)
        kernel.uniform_(-bound, bound)
        self.kernel = nn.Parameter(kernel)
        if self.use_bias:
            self.bias = nn.Parameter(torch.zeros((split, self.units), device=device, dtype=dtype))

    def forward(self, x):
        """Apply the per-slice dense maps; returns ``(batch, split, units)``.

        Raises:
            ValueError: if the input's slice count or per-slice size differs
                from the existing parameters.
        """
        batch, split = x.shape[0], x.shape[1]
        # Flatten everything after the slice axis to (batch, split, input_dim).
        # Unlike squeeze(), this keeps the batch axis when batch == 1.
        slices = x.reshape(batch, split, -1)
        input_dim = slices.shape[-1]

        if self.kernel is None:
            self._init_parameters(split, input_dim, device=x.device, dtype=x.dtype)
        elif tuple(self.kernel.shape[:2]) != (split, input_dim):
            raise ValueError(
                "expected input with %d slices of %d values, got %d slices of %d"
                % (self.kernel.shape[0], self.kernel.shape[1], split, input_dim))

        # Batched per-slice matmul: output[b, s] = slices[b, s] @ kernel[s]
        output = torch.einsum('bsi,siu->bsu', slices, self.kernel)
        if self.bias is not None:
            output = output + self.bias
        if self.activation is not None:
            output = self.activation(output)
        return output

class TimeDistributed(nn.Module):
    """Adapter that forwards to a functional layer builder chosen by ``layer_name``.

    Only ``layer_name='Dense'`` is supported. ``forward(x)`` returns
    ``layer(x, in_features, out_features, activation)``, using the constructor's
    keyword arguments; ``activation`` defaults to None.
    """

    def __init__(self, layer, batch_first=False, **kwargs):
        super(TimeDistributed, self).__init__()
        self.layer = layer
        self.batch_first = batch_first
        self.kwargs = kwargs

    def forward(self, x):
        ''' x size: (batch_size, channels, 1)
            Dense:  < x, in_features, out_features, activation >

        Raises:
            NotImplementedError: if ``layer_name`` is missing or not 'Dense'.
        '''
        layer_name = self.kwargs.get('layer_name')
        if layer_name != 'Dense':
            # Fail here rather than returning None, which would only error later,
            # far from the cause.
            raise NotImplementedError(
                "TimeDistributed only supports layer_name='Dense', got %r" % (layer_name,))
        return self.layer(x, self.kwargs['in_features'], self.kwargs['out_features'],
                          self.kwargs.get('activation'))

class UpSampling1D(nn.Module):
    """Linearly resample the last axis of a ``(batch, channels, length)`` tensor to ``size`` samples."""

    def __init__(self, size):
        super(UpSampling1D, self).__init__()
        # Target output length, not a repeat factor.
        self.size = size

    def forward(self, x):
        """Return ``x`` interpolated to ``self.size`` points along its last axis."""
        return functional.interpolate(
            x, size=self.size, scale_factor=None, mode='linear', align_corners=None)

class Multiply(nn.Module):
    """Element-wise product of a list of tensors."""

    def __init__(self):
        super(Multiply, self).__init__()

    def forward(self, tensors):
        """Return ``tensors[0] * tensors[1] * ...``, with standard broadcasting.

        The result is on the inputs' device.

        Raises:
            ValueError: if ``tensors`` is empty.
        """
        if len(tensors) == 0:
            raise ValueError("Multiply expects at least one tensor")
        # Return the accumulated product; returning the loop variable would
        # give back only the last factor.
        result = tensors[0]
        for t in tensors[1:]:
            result = result * t
        return result


In [None]:
# network.py

class DistortionNetwork(nn.Module):
    """End-to-end audio effect model with three stages:
    1. front-end convolution;
    2. latent-space DNN producing a mask;
    3. back-end dense stack + SAAF + tied deconvolution.

    ``forward`` expects ``(batch, window_length, 1)`` audio frames and returns
    ``(batch, window_length, 1)`` frames.

    Every layer is built once in ``__init__`` and registered as a submodule, so
    its weights appear in ``parameters()`` and ``state_dict()`` and are trained.
    Layers constructed inside ``forward`` would get fresh, untracked weights on
    each call.
    """

    def __init__(self, window_length, filters, kernel_size, learning_rate):
        super(DistortionNetwork, self).__init__()
        self.window_length = window_length
        self.filters = filters
        self.kernel_size = kernel_size
        # Stored for reference; the caller builds the optimizer.
        self.learning_rate = learning_rate

        latent_dim = self.window_length // 64

        # Front-end: learned filterbank plus per-channel smoothing of its magnitude.
        self.conv = nn.Conv1d(1, self.filters, self.kernel_size, stride=1, padding='same', padding_mode='zeros')
        self.convolution1d_locally_connected = Convolution1D_Locally_Connected(self.filters, 2 * self.kernel_size)
        self.pooling = BatchMaxPooling1d(latent_dim)

        # Latent-space DNN producing a mask that is upsampled back to full length.
        self.dense_layer = LatentSpace_DNN_LocallyConnected_Dense(latent_dim)
        self.latent_dense = nn.Linear(latent_dim, latent_dim)
        self.upsampling = UpSampling1D(self.window_length)
        self.multiply = Multiply()

        # Back-end dense stack applied at every time step, then the SAAF nonlinearity.
        self.back_end = nn.Sequential(
            nn.Linear(self.filters, self.filters), nn.ReLU(),
            nn.Linear(self.filters, self.filters // 2), nn.ReLU(),
            nn.Linear(self.filters // 2, self.filters // 2), nn.ReLU(),
            nn.Linear(self.filters // 2, self.filters // 2), nn.ReLU(),
            nn.Linear(self.filters // 2, self.filters),
        )
        self.saaf = SAAF(break_points=25, break_range=0.2, magnitude=100)

        # Synthesis layer whose weights are tied to the front-end convolution.
        self.deconvolution = Deconvolution(self.filters, self.kernel_size, self.conv, padding='same')

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.to(device)
        # One dummy batch makes any sub-layer that creates its weights on first
        # use have them before the caller builds an optimizer or loads a
        # state_dict. A batch of 2 (not 1) avoids layers that squeeze size-1 axes.
        with torch.no_grad():
            self(torch.zeros(2, self.window_length, 1, device=device))

    def permute_dims(self, x):
        """Swap the last two axes: ``(batch, A, B) -> (batch, B, A)``."""
        return x.permute(0, 2, 1)

    def Dense(self, x, in_features, out_features, activation=None):
        """Apply a freshly initialised (untrained) nn.Linear, plus optional activation class.

        Kept for backward compatibility; forward() uses the registered layers instead.
        """
        if activation is not None:
            return activation()(nn.Linear(in_features, out_features, device=x.device)(x))
        return nn.Linear(in_features, out_features, device=x.device)(x)

    def forward(self, x):
        """Map ``(batch, window_length, 1)`` input frames to output frames of the same shape."""
        # (batch, time, 1) -> (batch, 1, time) for Conv1d.
        x = self.permute_dims(x)
        x_conv = self.conv(x)

        M = self.convolution1d_locally_connected(torch.abs(x_conv))
        M = functional.softplus(M)
        # Filter axis first -> batch axis first.
        M = M.permute(1, 0, 2, 3)

        P = x_conv

        Z = self.pooling(M)
        Z = self.dense_layer(Z)
        Z = functional.softplus(self.latent_dense(Z))

        M_ = self.upsampling(Z)
        Y_ = self.multiply([P, M_])

        # (batch, filters, time) -> (batch, time, filters) for the per-step dense stack.
        Y_ = self.permute_dims(Y_)
        Y_ = self.back_end(Y_)
        Y_ = self.saaf(Y_)

        return self.deconvolution(Y_)


In [None]:
# main.py
def save_model_checkpoint(path, epoch, model_state_dict, optimizer_state_dict):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The file holds a dict with keys 'epoch', 'model_state_dict' and
    'optimizer_state_dict', in that order.
    """
    payload = {}
    payload['epoch'] = epoch
    payload['model_state_dict'] = model_state_dict
    payload['optimizer_state_dict'] = optimizer_state_dict
    torch.save(payload, path)

if __name__ == '__main__':
    # Model Checkpoint Path
    CHK_PATH = '/content/drive/My Drive/checkpoints/distortion_dnn_model.pt'
    BACKUP_CHK_PATH = '/content/drive/My Drive/checkpoints/distortion_dnn_model_backup.pt'

    # Parameters
    epochs = 2000
    window_length = 4096
    filters = 128
    kernel_size = 64
    learning_rate = 0.0001
    batch_size = 16
    log_every = 10  # print the mean training loss every this many batches
    SEED = 0

    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    checkpoint = None
    try:
        # map_location lets a GPU-saved checkpoint load on whatever device is present.
        checkpoint = torch.load(CHK_PATH, map_location=device)
    except FileNotFoundError:
        print("No model checkpoint found. Starting afresh....")

    # Get train, test data
    df = DataFetcher()
    X_train, X_test, y_train, y_test = df.get_train_test_data()

    # Load into neural network
    distortion_network = DistortionNetwork(window_length, filters, kernel_size, learning_rate)
    distortion_network.to(device)

    # Loss function and Optimizer
    loss_function = nn.L1Loss()  # Mean Absolute Error
    optimizer = torch.optim.SGD(distortion_network.parameters(), lr=learning_rate)

    train_generator = DataGenerator(X_train, y_train, batch_size=batch_size, window_length=window_length)
    test_generator = DataGenerator(X_test, y_test, batch_size=batch_size, window_length=window_length)

    # If a checkpoint exists, resume from it. The saved 'epoch' is the last
    # *completed* epoch, so continue with the next one rather than repeat it.
    epoch = 0
    if checkpoint is not None:
        epoch = checkpoint['epoch'] + 1
        distortion_network.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(CHK_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(BACKUP_CHK_PATH), exist_ok=True)

    while epoch < epochs:
        distortion_network.train()
        running_loss = 0.0
        for i, (input_audio, output_audio) in enumerate(train_generator):
            input_audio, output_audio = input_audio.to(device), output_audio.to(device)

            optimizer.zero_grad()
            network_output_audio = distortion_network(input_audio)
            mae_loss = loss_function(network_output_audio, output_audio)
            mae_loss.backward()
            optimizer.step()

            running_loss += mae_loss.item()
            if (i + 1) % log_every == 0:
                print('[%d, %5d] loss: %.6f' % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0
        if epoch % 10 == 0:
            save_model_checkpoint(BACKUP_CHK_PATH, epoch, distortion_network.state_dict(), optimizer.state_dict())
            print("Saved backup checkpoint with epochs ", (epoch + 1), " completed.")
        save_model_checkpoint(CHK_PATH, epoch, distortion_network.state_dict(), optimizer.state_dict())
        print("Saved checkpoint with epochs ", (epoch + 1), " completed.")
        epoch += 1

    print("Finished training....")

    # Evaluate on held-out data. If the network trained well, the mean test
    # loss should be small. No gradients are needed, and the training
    # running_loss is not reused.
    distortion_network.eval()
    test_loss = 0.0
    n_test_batches = 0
    with torch.no_grad():
        for input_audio, output_audio in test_generator:
            input_audio, output_audio = input_audio.to(device), output_audio.to(device)
            network_output_audio = distortion_network(input_audio)
            test_loss += loss_function(network_output_audio, output_audio).item()
            n_test_batches += 1

    if n_test_batches:
        print('Mean loss on test data over %d batches: %.6f' % (n_test_batches, test_loss / n_test_batches))
    else:
        print('No test batches available.')
            running_loss = 0.0

