In [2]:
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
import librosa
import soundfile as sf

DEVICE = T.device('cuda' if T.cuda.is_available() else 'cpu')


class DSP:

    def __init__(self, n_fft=254, hop_len=None):
        """ signal processing utils using torchaudio
        """
        self.n_fft = n_fft
        self.hop_len = n_fft//2 if hop_len is None else hop_len
        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=self.hop_len,
            win_length=n_fft,
            power=None
        )
        self.amplitude_to_db = ta.transforms.AmplitudeToDB()
        self.db_to_amplitude = lambda x: T.pow(T.pow(10.0, 0.1 * x), 1.)

    def sig_to_db_phase(self, sig):
        """ get dB and phase spectrograms of signal
            example usage:
                >>> sig, sr = torchaudio.load('sound.wav')
                >>> db, phase = chvoice.sig_to_db_phase(sig)
        """
        # represent input signal in time-frequency domain
        stft = self.stft(sig)
        # magnitude = amount of power/volume for each phase = frequency
        mag, phase = ta.functional.magphase(stft)
        # put magnitudes on log scale
        db = self.amplitude_to_db(mag)

        return db, phase

    def db_phase_to_sig(self, db, phase):
        """ get wav signal from db and phase spectrograms.
            example usage:
                >>> sig, sr = torchaudio.load('sound.wav')
                >>> db, phase = chvoice.sig_to_db_phase(sig)
                    ... do stuff to db ...
                >>> recovered_sig = chvoice.spec_to_sig(db, phase)
        """
        # go from log scale back to linear
        mag = self.db_to_amplitude(db)
        # recover full fourier transform of signal
        real = mag * T.cos(phase)
        imaginary = mag * T.sin(phase)
        complex = T.stack((real, imaginary), dim=-1)
        # inverse fourier transform to get signal
        sig = complex.istft(
            n_fft=self.n_fft,
            hop_length=self.hop_len
        )

        return sig


### CREDIT : https://github.com/milesial/Pytorch-UNet

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(
                scale_factor=2,
                mode='bilinear',
                align_corners=True
            )
            self.conv = DoubleConv(
                in_channels,
                out_channels,
                in_channels // 2
            )
        else:
            self.up = nn.ConvTranspose2d(
                in_channels,
                in_channels // 2,
                kernel_size=2,
                stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = T.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


In [34]:



model = UNet(1, 1).to(DEVICE)
model.load_state_dict(T.load('static/model_weights.pkl', map_location=DEVICE))
model = model.to(DEVICE)
dsp = DSP(254)

noisy_sig, _ = librosa.load('vidx/p232_001.wav', sr=16000)
noisy_sig = T.from_numpy(noisy_sig)
db, phase = dsp.sig_to_db_phase(noisy_sig)



In [35]:
db.shape

torch.Size([128, 220])

In [36]:
chunks = db.unsqueeze(0).unfold(2, 128, 128).squeeze(0).movedim(1, 0)


In [37]:
chunks.shape

torch.Size([1, 128, 128])

In [38]:
mean = T.mean(chunks)
std = T.std(chunks)
chunks = (chunks - 32) / 18
chunks = chunks.unsqueeze(1).to(DEVICE)


    

In [39]:
proc = T.empty_like(chunks)

In [40]:
with T.no_grad():
    for idx in range(0, len(chunks), 64):
        print(f'batch {idx}/{len(chunks)}')
        proc[idx:idx+64] = model(chunks[idx:idx+64])

batch 0/1


In [42]:
proc

tensor([[[[ 1.0595,  1.2238,  1.2269,  ...,  1.6178,  1.4083,  1.2772],
          [ 1.0330,  1.3350,  1.3827,  ...,  1.8369,  1.6164,  1.5232],
          [ 0.9786,  1.0791,  1.2370,  ...,  2.0669,  1.8612,  1.6736],
          ...,
          [ 0.4908, -0.6945, -0.3538,  ..., -0.1278, -0.0754, -0.0190],
          [ 0.6045, -0.7678, -0.6551,  ..., -0.2654, -0.3965,  0.0403],
          [ 0.5626, -0.5024, -0.8177,  ..., -0.5210, -0.4075,  0.0448]]]])

In [43]:

proc = (proc * 18) + 32

db_out = T.cat([c for c in proc.squeeze(1)], dim=1)
phase_clipped = phase[:, :db_out.size(1)]
sig = dsp.db_phase_to_sig(db_out, phase_clipped)
sf.write('vidx/testpath.wav', sig, 16000)
