In [10]:
import torch
from torch import nn
import torch.utils as utills
import numpy as np
import torchsummary
import torch.nn.functional as F
from torchgan.layers import VirtualBatchNorm

In [26]:
class Discriminator(nn.Module):
    """Convolutional discriminator ending in a sigmoid score.

    Ten blocks of Conv1d(kernel 32, stride 2, padding 15) -> VirtualBatchNorm ->
    LeakyReLU(0.3). Each block halves the time axis. Channel widths are the
    entries of ``self.filters`` doubled, so the input has 2 channels and the last
    block emits 1024 channels.

    The head is a 1x1 Conv1d (1024 -> 1), then LeakyReLU, Linear(16, 1) and Sigmoid.
    The Linear acts on the time axis, so after ten halvings the length must be 16.
    That means the input length must be 16384.

    Input:  [batch, 2, 16384]
    Output: [batch, 1, 1], values in (0, 1)

    NOTE(review): VirtualBatchNorm comes from torchgan. Confirm its reference-batch
    behaviour matches how the training loop calls this module.
    """

    def __init__(self):
        super().__init__()

        self.layers = nn.ModuleList()
        self.filters = [1, 16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024]

        # Doubled widths; only the first 11 entries (10 in/out pairs) are used.
        widths = [f * 2 for f in self.filters]
        for in_ch, out_ch in zip(widths[:10], widths[1:11]):
            conv = nn.Conv1d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=32,
                stride=2,
                padding=15,
            )
            self.layers.append(
                nn.Sequential(conv, VirtualBatchNorm(out_ch), nn.LeakyReLU(0.3))
            )

        # Collapse channels to 1, then map the 16 remaining time steps to one score.
        self.flatten = nn.Sequential(
            nn.Conv1d(1024, 1, kernel_size=1, stride=1),
            nn.LeakyReLU(0.3),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the conv stack and the head; returns [batch, 1, 1] scores."""
        for block in self.layers:
            x = block(x)
        return self.flatten(x)
        
class NoiseCanceler(nn.Module):
    """U-Net style waveform generator: strided Conv1d encoder + ConvTranspose1d decoder.

    Every encoder layer (kernel 32, stride 2, padding 15) halves the time axis, and
    every decoder layer doubles it back. Each hidden decoder output is concatenated
    along the channel axis with the mirrored encoder activation (skip connection).
    The final decoder output goes through ``tanh``.

    Args:
        skip_z: If truthy, no latent ``z`` is used and the first decoder layer
            consumes the encoder output alone. If falsy, ``z`` is concatenated with
            the encoder output before decoding, so the first decoder layer takes
            twice the channels.
        filters: Encoder channel counts, input channels first. Defaults to
            ``DEFAULT_FILTERS`` (11 encoder / 11 decoder layers). The input length
            must be divisible by ``2 ** (len(filters) - 1)``.

    Raises:
        ValueError: if ``filters`` has fewer than two entries.
    """

    DEFAULT_FILTERS = (1, 16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024)

    def __init__(self, skip_z, filters=None):
        super(NoiseCanceler, self).__init__()

        self.enc_layers = nn.ModuleList()
        self.dec_layers = nn.ModuleList()
        self.filters = list(self.DEFAULT_FILTERS if filters is None else filters)
        self.skip_z = skip_z

        n_layers = len(self.filters) - 1
        if n_layers < 1:
            raise ValueError("filters must contain at least two channel counts")

        # NOTE(review): one PReLU (a single learnable slope) is shared by every
        # activation. Kept as-is so existing state_dicts still load.
        self.prelu = nn.PReLU()

        # Encoder: [batch, filters[i], L] -> [batch, filters[i+1], L/2]
        for i in range(n_layers):
            self.enc_layers.append(nn.Conv1d(
                in_channels=self.filters[i],
                out_channels=self.filters[i + 1],
                kernel_size=32,
                stride=2,
                padding=15))

        # Decoder mirrors the encoder. Each layer after the first receives
        # [previous decoder output ++ skip connection], hence filters[i] * 2 input
        # channels. The first layer receives the encoder output alone (skip_z) or
        # the encoder output ++ z (same channel count as the encoder output).
        for i in range(n_layers, 0, -1):
            if i == n_layers and skip_z:
                in_channels = self.filters[i]
            else:
                in_channels = self.filters[i] * 2
            self.dec_layers.append(nn.ConvTranspose1d(
                in_channels=in_channels,
                out_channels=self.filters[i - 1],
                kernel_size=32,
                stride=2,
                padding=15))
        self.dec_tanh = nn.Tanh()

    def forward(self, x, z=None):
        """Map ``x`` through the encoder/decoder.

        There is a single ``forward``: the notebook previously defined it twice, and
        the second definition silently shadowed the ``z``-aware one.

        Args:
            x: Tensor [batch, filters[0], length]; length divisible by
                2 ** (len(filters) - 1).
            z: Optional latent tensor, concatenated on the channel axis with the
                encoder output, so it must match that output's shape. Only allowed
                when ``skip_z`` is falsy. If omitted in that case, it is drawn from
                a standard normal.

        Returns:
            Tensor [batch, filters[0], length] with values in (-1, 1).

        Raises:
            ValueError: if ``z`` is passed while ``skip_z`` is truthy.
        """
        if self.skip_z and z is not None:
            raise ValueError("z was given but this NoiseCanceler was built with skip_z=True")

        # Encoding; keep every activation for the skip connections.
        skips = []
        for enc in self.enc_layers:
            x = self.prelu(enc(x))
            skips.append(x)

        if not self.skip_z:
            if z is None:
                z = torch.randn_like(x)
            x = torch.cat((x, z), dim=1)

        # skips[-1] is the bottleneck itself, so decoder layer k pairs with skips[-2 - k].
        last = len(self.dec_layers) - 1
        for idx, dec in enumerate(self.dec_layers):
            x = dec(x)
            if idx < last:
                # Activation + skip concat on hidden layers only. The last layer feeds
                # tanh directly, so the PReLU cannot shrink negative output samples.
                x = torch.cat((self.prelu(x), skips[-2 - idx]), dim=1)

        return self.dec_tanh(x)

In [27]:
# Generator built without the latent-z path (skip_z=True), plus the discriminator.
NC = NoiseCanceler(skip_z=True)
D = Discriminator()

In [28]:
torchsummary.summary(NC, input_size=(1, 16384), device='cpu')  # 1 channel x 16384 samples; batch dim excluded

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 16, 8192]             528
             PReLU-2             [-1, 16, 8192]               1
            Conv1d-3             [-1, 32, 4096]          16,416
             PReLU-4             [-1, 32, 4096]               1
            Conv1d-5             [-1, 32, 2048]          32,800
             PReLU-6             [-1, 32, 2048]               1
            Conv1d-7             [-1, 64, 1024]          65,600
             PReLU-8             [-1, 64, 1024]               1
            Conv1d-9              [-1, 64, 512]         131,136
            PReLU-10              [-1, 64, 512]               1
           Conv1d-11             [-1, 128, 256]         262,272
            PReLU-12             [-1, 128, 256]               1
           Conv1d-13             [-1, 128, 128]         524,416
            PReLU-14             [-1, 1

In [29]:
torchsummary.summary(D, input_size=(2, 16384), device='cpu')  # length fixed to 16384 by D's Linear(16, 1) head

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 32, 8192]           2,080
  VirtualBatchNorm-2             [-1, 32, 8192]              32
         LeakyReLU-3             [-1, 32, 8192]               0
            Conv1d-4             [-1, 64, 4096]          65,600
  VirtualBatchNorm-5             [-1, 64, 4096]              64
         LeakyReLU-6             [-1, 64, 4096]               0
            Conv1d-7             [-1, 64, 2048]         131,136
  VirtualBatchNorm-8             [-1, 64, 2048]              64
         LeakyReLU-9             [-1, 64, 2048]               0
           Conv1d-10            [-1, 128, 1024]         262,272
 VirtualBatchNorm-11            [-1, 128, 1024]             128
        LeakyReLU-12            [-1, 128, 1024]               0
           Conv1d-13             [-1, 128, 512]         524,416
 VirtualBatchNorm-14             [-1, 1