In [2]:
import os
import math
import pathlib
from collections import defaultdict
from collections.abc import Mapping

import numpy as np
import matplotlib.pyplot as plt
import librosa as lb

import torch

plt.style.use('default')


In [12]:
class SubLinear(torch.nn.Linear):
    """``torch.nn.Linear`` that also carries a mel filterbank as a reference target.

    The forward pass is exactly that of ``torch.nn.Linear``. The filterbank is
    only stored (as ``self.filterbank``) so a regularizer can compare the learned
    weights against it. Per librosa's documented return shapes, it has the same
    shape as ``self.weight``: ``(out_features, in_features)``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        sr: sample rate passed to ``librosa.filters.mel``.
        bias: whether the layer learns an additive bias.
    """

    def __init__(self, in_features, out_features, sr=44100, bias=True):
        # The base class already stores in_features / out_features.
        super(SubLinear, self).__init__(in_features, out_features, bias=bias)
        self.sample_rate = sr
        if out_features > in_features:
            # Expanding layer: n_fft is chosen so n_fft // 2 + 1 == out_features.
            # Solve mel_basis @ X ~= I (X >= 0) for a non-negative pseudo-inverse.
            mel_basis = lb.filters.mel(sr=self.sample_rate,
                                       n_fft=(out_features - 1) * 2,
                                       n_mels=in_features)
            filterbank = lb.util.nnls(mel_basis, np.eye(in_features))
        else:
            # Compressing layer: n_fft is chosen so n_fft // 2 + 1 == in_features.
            filterbank = lb.filters.mel(sr=self.sample_rate,
                                        n_fft=(in_features - 1) * 2,
                                        n_mels=out_features)
        # Non-persistent buffer: follows .to()/.cuda()/.double() together with
        # the weights, but is NOT added to state_dict (checkpoint keys unchanged).
        # Cast to the weight dtype so losses comparing the two do not mix dtypes.
        self.register_buffer(
            'filterbank',
            torch.as_tensor(filterbank, dtype=self.weight.dtype),
            persistent=False,
        )

In [13]:
class Denoiser(torch.nn.Module):
    """Fully connected bottleneck network with a ReLU after every layer.

    Layout: ``in_features -> hidden_features -> hidden_features -> out_features``.
    The defaults reproduce the original fixed sizes (513 -> 40 -> 40 -> 514).
    Because the last layer is followed by ReLU, outputs are non-negative but
    not bounded above.

    NOTE(review): ``out_features`` defaults to 514, one more than
    ``in_features`` (513) -- confirm this is intended.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the two hidden layers.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features=513, hidden_features=40, out_features=514):
        super(Denoiser, self).__init__()
        # Layer order is kept identical so state_dict keys (network.0/2/4.*) match.
        self.network = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden_features),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_features, hidden_features),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_features, out_features),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        """Apply the network; the last dimension of ``x`` must be ``in_features``."""
        return self.network(x)

In [14]:
class LogarithmicDenoiser(torch.nn.Module):
    """Bottleneck network whose first and last layers are ``SubLinear``.

    Layout: ``SubLinear(in -> hidden)``, ReLU, ``Linear(hidden -> hidden)``,
    ReLU, ``SubLinear(hidden -> out)``, ReLU. The ``SubLinear`` layers carry mel
    filterbanks built with sample rate ``sr``. The defaults reproduce the
    original fixed sizes (513 -> 40 -> 40 -> 514) and sample rate 44100.

    NOTE(review): ``out_features`` defaults to 514, one more than
    ``in_features`` (513) -- confirm this is intended.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the two hidden layers.
        out_features: size of the last dimension of the output.
        sr: sample rate forwarded to both ``SubLinear`` layers.
    """

    def __init__(self, in_features=513, hidden_features=40, out_features=514,
                 sr=44100):
        super(LogarithmicDenoiser, self).__init__()
        # Layer order is kept identical so state_dict keys (network.0/2/4.*) match.
        self.network = torch.nn.Sequential(
            SubLinear(in_features, hidden_features, sr=sr),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_features, hidden_features),
            torch.nn.ReLU(),
            SubLinear(hidden_features, out_features, sr=sr),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        """Apply the network; the last dimension of ``x`` must be ``in_features``."""
        return self.network(x)

In [37]:
# Build both model variants: plain Linear layers (net1) and the variant whose
# outer layers are SubLinear with mel filterbanks attached (net2).
net1, net2 = Denoiser(), LogarithmicDenoiser()

def cross_entropy(o, t):
    """Summed binary cross-entropy between predictions ``o`` and targets ``t``.

    Computes ``sum(-t * log(o) - (1 - t) * log(1 - o))``, with a tiny ``eps``
    added inside each log to avoid ``log(0)``.

    NOTE(review): ``o`` is expected to lie in [0, 1]. Values above 1 (possible
    with an unbounded ReLU output) make ``log(1 - o + eps)`` NaN.
    """
    eps = 1e-30
    positive_term = t * torch.log(o + eps)
    negative_term = (1 - t) * torch.log(1 - o + eps)
    return (-positive_term - negative_term).sum()

def calculate_snr(s, r):
    """Signal-to-noise ratio (dB) of reconstruction ``r`` against reference ``s``.

    SNR = 10 * log10(sum(s**2) / sum((s - r)**2)).

    Both inputs are converted to float64 first. Integer audio such as int16
    PCM would otherwise silently wrap around when squared. A tiny ``eps`` keeps
    the result finite for a perfect reconstruction (zero error) or a silent
    reference (zero signal energy).

    Args:
        s: reference signal (array-like).
        r: estimated signal, broadcastable against ``s``.

    Returns:
        The SNR in decibels as a numpy float64 scalar.
    """
    eps = 1e-30
    s = np.asarray(s, dtype=np.float64)
    r = np.asarray(r, dtype=np.float64)
    signal_energy = np.sum(s ** 2)
    noise_energy = np.sum((s - r) ** 2)
    return 10 * np.log10(signal_energy / (noise_energy + eps) + eps)

def regularize(network):
    """Penalty pulling every ``SubLinear`` layer's weights toward its filterbank.

    For each ``SubLinear`` module in ``network`` (including ``network`` itself):
    - the column sums (over dim 0) of the weight and of the filterbank are each
      divided by their own maximum;
    - the two normalized profiles are compared with a mean ``SmoothL1Loss``.

    Returns:
        The sum of the per-layer losses as a tensor. Returns the int ``0`` when
        the network contains no ``SubLinear`` layers, because ``sum`` of an
        empty list is 0.

    NOTE(review): if a profile's maximum column sum is <= 0, the normalization
    divides by zero or flips the sign.
    """
    criterion = torch.nn.SmoothL1Loss(reduction='mean')
    penalties = []
    for layer in network.modules():
        if not isinstance(layer, SubLinear):
            continue
        weight = layer.weight
        # The filterbank may be a plain tensor attribute that .to()/.cuda() does
        # not move; align its device and dtype with the weight before comparing.
        filterbank = torch.as_tensor(layer.filterbank).to(device=weight.device,
                                                          dtype=weight.dtype)
        weight_cols = weight.sum(dim=0)
        filterbank_cols = filterbank.sum(dim=0)
        penalties.append(criterion(weight_cols / weight_cols.max(),
                                   filterbank_cols / filterbank_cols.max()))
    return sum(penalties)

In [None]:
# Load the saved TIMIT array (presumably training-set spectrograms -- confirm).
# Set the TIMIT_DIR environment variable to override the default location;
# np.load accepts a pathlib.Path directly.
TIMIT_DIR = pathlib.Path(os.environ.get('TIMIT_DIR', '~/Datasets/timit')).expanduser()
trS = np.load(TIMIT_DIR / 'trS.npy')

In [None]:
# Display the dimensions of the loaded array.
np.shape(trS)