# Denoising with generative models

## Pytorch tests

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
class Downsampling(nn.Module):
    
    def __init__(self, in_ch, out_ch, size):
        super(Downsampling, self).__init__()
        
        self.conv = nn.Sequential(
            # We need a padding mode = same to get the same behavior as the paper
            # But it doesn't seem to exist in pytorch
            nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=size, stride=2, padding_mode='zeros'), 
            nn.LeakyReLU(0.2)
        )
        
    def forward(self, x):
        y = self.conv(x)
        return y

In [16]:
class Bottleneck(nn.Module):
    def __init__(self, ch, size):
        super(Bottleneck, self).__init__()
        self.conv = nn.Sequential(
            # We need a padding mode = same to get the same behavior as the paper
            # But it doesn't seem to exist in pytorch
            nn.Conv1d(in_channels=ch, out_channels=ch, kernel_size=size, stride=2, padding_mode='zeros'), 
            nn.Dropout(0.5),
            nn.LeakyReLU(0.2)
        )
        
    def forward(self, x):
        y = self.conv(x)
        return y

In [17]:
class Upsampling(nn.Module):
    def __init__(self, in_ch, out_ch, size):
        super(Upsampling, self).__init__()
        self.conv = nn.Sequential(
            # We need a padding mode = same to get the same behavior as the paper
            # But it doesn't seem to exist in pytorch
            nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=size, stride=1, padding_mode='zeros'),
            nn.Dropout(p=0.5),
            nn.ReLU()
        )
    def forward(self, x1, x2):
        y = self.conv(x1)
        y = subpixel(y)
        y = concat(y, x2)
        return y

In [18]:
class LastConv(nn.Module):
    def __init__(self, in_ch, size):
        super(LastConv, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_ch, out_channels=2, kernel_size=9, stride=1, padding_mode='zeros')
           
        
    def forward(self, x1, x2):
        y = self.conv(x1)
        y = subpixel(y)
        return y

In [19]:
class Net(nn.Module):

    def __init__(self, depth):
        super(Net, self).__init__()
        # input = 1024
        B = 4
        n_channels= [128, 256, 512, 512] # max(2^(6+b), 512) 
        size_filters = [65, 33, 17, 9] # min(2^(7−b) + 1, 9)
        
        
        # Manual definition
        self.down1 = Downsampling(1, 128, 63)
        self.down2 = Downsampling(128, 256, 33)
        self.down3 = Downsampling(256, 512, 17)
        self.down4 = Downsampling(512, 512, 9)
        self.bottleneck = Bottleneck(512, 9)
        self.up1 = Upsampling(512, 512*2, 9)
        self.up2 = Upsampling(512*2, 512*2, 17)
        self.up3 = Upsampling(512*2, 256*2, 33)
        self.up4 = Upsampling(256*2, 128*2, 63)
        self.last = LastConv(128*2, 9)

        
        # Automatic definition, seems to work :
        
        B = depth
        
        n_channels, size_filters = get_sizes_for_layers(B)
        
        self.down = []
        # The input channel count is equal to the the output channel count of the previous layer
        # Input will be all the channel counts, shifted to the right with a 1 before
        for n_ch_in, n_ch_out, size in zip([1] + n_channels[:-1], n_channels, size_filters):
            self.down.append(Downsampling(n_ch_in, n_ch_out, size))
            
        self.bottleneck = Bottleneckleneck(n_channels[-1], size_filters[-1])
        
        self.up = []
        # Input filter count is the size of the bottlneck for the first up layer
        # And then it will be the count of the previous up layer, which is equal to twice the count of the down layer
        # (since we do some stacking with the skip connections)
        
        # Output filter count  will be twice the count of the down layer 
        # so that after the subpixel we get the same count as in the down layer
        # and we can stack them together
        for n_ch_in, n_ch_out, size in zip([n_channels[-1]/2] + n_channels[::-1][:-1], size_filters[::-1]):
            self.up.append(Upsampling(n_ch_in*2, n_ch_out*2, size))
                                            
        self.out = LastConv(n_channels[0]*2, 9)
        
        
        

    def forward(self, x):

        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        b = bottleneck(x4)
        y = self.up1(b, x4)
        y = self.up2(y, x3)
        y = self.up3(y, x2)
        y = self.up4(y, x1)
        y = self.last(y)
        return y
    
net = Net(4)
#print(net)

NameError: name 'get_sizes_for_layers' is not defined

In [40]:
# this is what they implemented in their git project
B = 4
n_channels = []
size_filters = []
for b in range(1, B+1):
    n_channels.append(min(2**(6 + b), 512))
    size_filters.append(max(2**(7-b) + 1, 9))

In [41]:
n_channels

[128, 256, 512, 512]

In [42]:
size_filters

[65, 33, 17, 9]

In [43]:
# We can check here that we indeed generate the correct values in our nn

for n_ch_in, n_ch_out, size in zip([1] + n_channels[:-1], n_channels, size_filters):
    print(n_ch_in, n_ch_out, size)
print(n_channels[-1], size_filters[-1])
            
for n_ch_in, n_ch_out, size in zip([n_channels[-1]/2] + n_channels[::-1][:-1], n_channels[::-1], size_filters[::-1]):
    print(n_ch_in*2, n_ch_out*2, size)
                                            
print(n_channels[0]*2, 9)

1 128 65
128 256 33
256 512 17
512 512 9
512 9
512.0 1024 9
1024 1024 17
1024 512 33
512 256 65
256 9


In [20]:
# this is what they did in the paper (inverse min and max). It looks wrong
B = 8
n_channels = []
size_filters = []
for b in range(1, B+1):
    n_channels.append(max(2**(6 + b), 512))
    size_filters.append(min(2**(7-b) + 1, 9))

In [21]:
n_channels

[512, 512, 512, 1024, 2048, 4096, 8192, 16384]

In [46]:
size_filters

[9, 9, 9, 9, 5, 3, 2, 1.5]

In [47]:
# Let's put the corect one as a function
def get_sizes_for_layers(B):
    n_channels = []
    size_filters = []
    for b in range(1, B+1):
        n_channels.append(min(2**(6 + b), 512))
        size_filters.append(max(2**(7-b) + 1, 9))
    return n_channels, size_filters
    