# Denoising Auto Encoder



## Import

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(123456)

from scipy.io import wavfile

import os
# from audio_dataset import AudioDataset

## AudioDataset class

In [41]:
sliced_dataset = "short_audio_dataset"
sliced_dataset_lenght = 16050
# sliced_dataset = "shorter_audio_dataset"
# sliced_dataset_lenght = 4013
original_dataset = "audio_dataset"
original_dataset_lenght = 80249

class AudioDataset(Dataset):
    def __init__(self, root_path="./data/", drop_both=False, use_short=False, normalize=False):
        root_folder = root_path + original_dataset if not use_short else root_path + sliced_dataset
        self.max_length = original_dataset_lenght if not use_short else sliced_dataset_lenght
        self.class_map = {"esben" : 0, "peter": 1, "both": 2}
        self.data = []
        self.wavs = []
        self.labels = []
        self.min_val = 10e10
        self.max_val = 0
        print("Start reading files")
        for subdir, dirs, files in os.walk(root_folder):
            for file_name in files:
                if drop_both and "both" in subdir:
                   continue
        
                file_path = os.path.join(subdir, file_name)
                _, wav = wavfile.read(file_path)
                wav = wav.astype(np.float32)
                
                if wav.shape[0] > self.max_length:
                    self.max_length = wav.shape[0]
                    print("Found wav with more length than specified max one, new max is:", wav.shape[0])
                
                wav = np.pad(wav, (0, self.max_length-wav.shape[0]))
                label_str = file_path.split('/')[-3][2:]
                label = (np.int64(self.class_map[label_str]))
                
                self.max_val = np.max(wav) if np.max(wav) > self.max_val else self.max_val
                self.min_val = np.min(wav) if np.min(wav) < self.min_val else self.min_val
                
                self.wavs.append(wav)
                self.labels.append(label)
               
        self.wavs = np.array(self.wavs)
        self.mu  = self.wavs.mean()
        self.std = np.std(self.wavs)
        self.wavs = torch.Tensor(self.wavs)
        if normalize:
            self.wavs = (self.wavs + np.abs(self.min_val)) / (np.abs(self.min_val) + self.max_val)
            # self.wavs = torch.nn.functional.normalize(self.wavs, dim=1)
        
        print("="*40)
        print("Loaded DATABASE from {}\n{} total file\nLongest file is {} long\nMean: {}\nStandard deviation: {}\nNormalization: {}".
              format(root_folder, len(self.wavs), self.max_length, self.mu, self.std, normalize))
        print("="*40)
    
    def __len__(self):
        return len(self.wavs)
    
    def __getitem__(self, idx):
        wav_tensor = self.wavs[idx]
        label = self.labels[idx]
        # wav_tensor = torch.from_numpy(wav)
        label_tensor = torch.tensor(label)
        return wav_tensor, label_tensor


## Loading data for training

In [42]:
audio_dataset = AudioDataset(root_path="../data/", drop_both=True, use_short=True, normalize=True)
dataset_len = len(audio_dataset)
train_size, test_size, valid_size = int(dataset_len * 0.7), int(dataset_len * 0.2), int(dataset_len * 0.1)

dataset_train, dataset_test, dataset_valid = torch.utils.data.random_split(audio_dataset, (train_size, test_size, valid_size))
print(audio_dataset.max_val)
kwargs = {'batch_size': 1, 'num_workers': 2}
loader_train = torch.utils.data.DataLoader(dataset_train, **kwargs, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset_test, **kwargs, shuffle=True)
loader_valid = torch.utils.data.DataLoader(dataset_valid, **kwargs, shuffle=True)

Start reading files
Loaded DATABASE from ../data/short_audio_dataset
1000 total file
Longest file is 16050 long
Mean: -0.6988561153411865
Standard deviation: 2332.389404296875
Normalization: True
32767.0


In [43]:
max = 0
min = 10e10
for x, _ in loader_train:
    # print(torch.max(x), torch.min(x))
    if torch.max(x) > max:
        max = torch.max(x)
    if torch.min(x) < min:
        min = torch.min(x)
print(max, min)

    # plt.plot(np.arange(audio_dataset.max_length), x.flatten())
    # plt.show()
    

tensor(0.0158) tensor(7.4709e-05)


## Autoencoder class

In [44]:
class AutoEncoder(nn.Module):
    def __init__(self, wav_len):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(wav_len, 6000),
            nn.ReLU(),
            nn.Linear(6000, 1000),
            nn.ReLU(),
            # nn.Linear(4000, 1000),
            # nn.ReLU(),
            # nn.Linear(100, 10),
            # nn.ReLU(),
            # nn.Linear(10, 2),
            # nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(1000, 6000),
            nn.ReLU(),
            # nn.Linear(10, 100),
            # nn.ReLU(),
            # nn.Linear(100, 2000),
            # nn.ReLU(),
            # nn.Linear(1000, 5000),
            # nn.ReLU(),
            nn.Linear(6000, wav_len),
            nn.Sigmoid()
        )

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)
        
    def forward(self, x):
        return self.decode(self.encode(x))

## Training

In [45]:
model = AutoEncoder(audio_dataset.max_length)
opt = torch.optim.Adam(model.parameters())

for epoch in range(25):
    print(f'Epoch {epoch+1}/25', end=' ')
    for x, _ in loader_train:   
        
        x_rec = model(x)
        loss = F.binary_cross_entropy(x_rec, x)

        opt.zero_grad()
        loss.backward()
        opt.step()

        mse = F.mse_loss(x_rec, x)
        mae = F.l1_loss(x_rec, x)
        
    print(f'loss: {loss.item():.4f} - rmse: {np.sqrt(mse.item()):.4f} - mae: {mae.item():.4f}')

Epoch 1/25 loss: 0.0460 - rmse: 0.0009 - mae: 0.0007
Epoch 2/25 loss: 0.0459 - rmse: 0.0009 - mae: 0.0007
Epoch 3/25 loss: 0.0460 - rmse: 0.0008 - mae: 0.0007
Epoch 4/25 loss: 0.0460 - rmse: 0.0010 - mae: 0.0008
Epoch 5/25 loss: 0.0461 - rmse: 0.0007 - mae: 0.0005
Epoch 6/25 loss: 0.0459 - rmse: 0.0010 - mae: 0.0007
Epoch 7/25 

KeyboardInterrupt: 