In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler
import pandas as pd
import numpy as np

In [2]:
class AudioClassifyDataset(Dataset):
    def __init__(self, root="./UrbanSound8K/"):
        
        self.data = []
        self.label = []
        
        FRAME_SIZE = 512
        HOP_SIZE = 512
        target_length = 88200
        
        meta = pd.read_csv(root+"/metadata/UrbanSound8K.csv")
        
        for fold in os.listdir(root+"audio/"):
            if fold[0] == '.':
                continue
            print(fold)

            for file in tqdm(os.listdir(root+"audio/"+fold+"/")):
                if file[0] == '.':
                    continue

                PATH = root+"audio/"+fold+"/"+file
                sig, sr = librosa.load(PATH)
                sig = librosa.util.fix_length(sig, target_length)

                S_scale = librosa.stft(sig, win_length=HOP_SIZE, n_fft=FRAME_SIZE, hop_length=HOP_SIZE)
                Y_scale = abs(S_scale)**2
                Y_db = librosa.power_to_db(Y_scale)
                
                
                Y_db = np.expand_dims(Y_db, axis=0)
                
                #plot_spectrogram(sig, Y_db, sr, HOP_SIZE)
                
                idx = np.where(meta['slice_file_name'] == file)[0][0]
                label = meta.iloc[[idx]]['classID'][idx]
                
                self.data.append(Y_db)
                self.label.append(label)

    def __len__(self):
        assert(len(self.data) == len(self.label))
        return len(self.label)

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]


In [3]:
class VarationalCnnAutoEncoder(nn.Module):
    def __init__(self):
        super(VarationalCnnAutoEncoder, self).__init__()
        
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=7, padding=3)
        self.conv_2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=5, padding=2)
        self.conv_3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        
        self.fc_1 = nn.Linear(16 * 32 * 21, 512)
        
        self.fc2_mu = nn.Linear(512, 10)
        self.fc2_var = nn.Linear(512, 10)
        
        self.fc1_decode = nn.Linear(10, 512)
        self.fc2_decode = nn.Linear(512, 257*173)
        
        self.relu = nn.LeakyReLU(negative_slope=0.05)
        self.pool = nn.MaxPool2d(kernel_size=(2,2))
        
    def encode(self, x):
        x = self.pool(self.conv_1(x))
        x = self.pool(self.conv_2(x))
        x = self.pool(self.conv_3(x))
        
        x = x.view(-1, 16 * 32 *21)
        x = self.fc_1(x)
        
        mu = self.fc2_mu(x)
        var = self.fc2_var(x)
        
        return mu, var
    
    def reparameterize(self, mu, var):
        std = torch.exp(var/2)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        z = self.relu(self.fc1_decode(z))
        z = self.fc2_decode(z)
        return z
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

In [4]:
def train(model, optimizer, epoch):
    train_loss = 0
    print_every = 100
    model.train()
    epoch_loss = 0
    for idx, (images, label) in enumerate(data_loader):
        if idx == len(data_loader) -1:
            continue
            
        images = images.to(device)
        
        reconstructed, mu, logvar = model(images)
        loss = loss_function(reconstructed_image=reconstructed, original_image=images, mu=mu, logvar=logvar)
        
        #print(loss)
        
        optimizer.zero_grad()
        loss.backward()
        train_loss += loss.item()
        epoch_loss += loss.item()
        optimizer.step()
        
            
    print("===> Epoch {}, Average loss: {:.3f}".format(epoch, epoch_loss/len(data_loader.dataset)))

In [8]:
batch_size = 32

try:
    print(str_data_set)
except:
    print("Data set loaded")
    audio_dataset = torch.load("./audio_dataset")
    str_data_set = "Data set already defined"

data_loader = DataLoader(dataset=audio_dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = VarationalCnnAutoEncoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

Data set already defined
cuda


In [6]:
def loss_function(reconstructed_image, original_image, mu, logvar):
    
    lined = original_image.view(-1, 1 * 257 * 173)
    
    #bce = F.binary_cross_entropy(reconstructed_image, lined, reduction = 'sum')
        
    bce = F.mse_loss(reconstructed_image, lined, reduction='sum')
    
    #kld = torch.sum(0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar, 1))
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    
    return bce + kld

In [9]:
epochs = 5
for i in range(epochs):
    train(model=model, epoch=i, optimizer=optimizer)

===> Epoch 0, Average loss: 1415944209634295296.000
===> Epoch 1, Average loss: 3719014063.626
===> Epoch 2, Average loss: 1132870131.247
===> Epoch 3, Average loss: 57443231.853
===> Epoch 4, Average loss: 2563703130.620
