# Denoising with generative models

## Pytorch tests

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Subpixel(nn.Module):
    """Placeholder for the sub-pixel layer.

    The layer is not implemented yet: ``forward`` currently hands its input
    back untouched.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged (TODO: implement the actual layer)."""
        return x

In [3]:
class Concat(nn.Module):
    """Stack two feature maps along the channel axis.

    Used by ``Upsampling`` to merge its output with the matching
    downsampling output (skip connection).
    """

    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` on dimension 1.

        Args:
            x1: first tensor; assumed to be laid out as
                (batch, channels, length), the layout ``nn.Conv1d`` uses.
            x2: second tensor; must match ``x1`` on every dimension except
                dimension 1.

        Returns:
            A tensor holding the channels of ``x1`` followed by those of
            ``x2``.

        Raises:
            RuntimeError: raised by ``torch.cat`` when the non-channel
                dimensions of the two tensors differ.
        """
        # The previous stub read an undefined name `x`, so every call failed
        # with NameError before any concatenation happened.
        return torch.cat((x1, x2), dim=1)

In [4]:
class Downsampling(nn.Module):
    """Stride-2 1-D convolution followed by a leaky ReLU (slope 0.2).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        size: convolution kernel size.
    """

    def __init__(self, in_ch, out_ch, size):
        super().__init__()

        # NOTE: the paper relies on "same" padding, for which no PyTorch
        # equivalent was found here, so the convolution runs unpadded.
        strided_conv = nn.Conv1d(in_ch, out_ch, size, stride=2, padding_mode='zeros')
        activation = nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(strided_conv, activation)

    def forward(self, x):
        """Apply the convolution and the activation to ``x``."""
        return self.conv(x)

In [5]:
class Bottleneck(nn.Module):
    """Stride-2 1-D convolution, dropout (p=0.5), then leaky ReLU (slope 0.2).

    The channel count is the same on input and output.

    Args:
        ch: number of input and output channels.
        size: convolution kernel size.
    """

    def __init__(self, ch, size):
        super().__init__()

        # NOTE: the paper relies on "same" padding, for which no PyTorch
        # equivalent was found here, so the convolution runs unpadded.
        layers = [
            nn.Conv1d(ch, ch, size, stride=2, padding_mode='zeros'),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.2),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through convolution, dropout and activation."""
        return self.conv(x)

In [6]:
class Upsampling(nn.Module):
    """Stride-1 convolution block followed by ``Subpixel`` and ``Concat``.

    The convolution block is Conv1d -> Dropout(p=0.5) -> ReLU.

    Args:
        in_ch: number of input channels of the convolution.
        out_ch: number of output channels of the convolution.
        size: convolution kernel size.
    """

    def __init__(self, in_ch, out_ch, size):
        super().__init__()

        # NOTE: the paper relies on "same" padding, for which no PyTorch
        # equivalent was found here, so the convolution runs unpadded.
        conv = nn.Conv1d(in_ch, out_ch, size, stride=1, padding_mode='zeros')
        self.conv = nn.Sequential(conv, nn.Dropout(p=0.5), nn.ReLU())
        self.subpixel = Subpixel()
        self.concat = Concat()

    def forward(self, x1, x2):
        """Process ``x1`` and merge the result with ``x2``.

        Args:
            x1: input passed through the convolution block and ``Subpixel``.
            x2: second input handed to ``Concat`` together with the
                processed ``x1``.
        """
        features = self.subpixel(self.conv(x1))
        return self.concat(features, x2)

In [7]:
class LastConv(nn.Module):
    """Final stride-1 convolution down to 2 channels, followed by ``Subpixel``.

    Args:
        in_ch: number of input channels.
        size: convolution kernel size. It used to be ignored in favour of a
            hard-coded 9; the callers in this notebook pass 9, so they are
            unaffected.
    """

    def __init__(self, in_ch, size):
        super(LastConv, self).__init__()
        # No padding is applied; see the "same" padding remark on the other
        # convolution blocks of this notebook.
        self.conv = nn.Conv1d(in_channels=in_ch, out_channels=2, kernel_size=size, stride=1, padding_mode='zeros')
        self.subpixel = Subpixel()

    def forward(self, x1, x2=None):
        """Apply the convolution and ``Subpixel`` to ``x1``.

        Args:
            x1: input tensor.
            x2: unused. It is kept, now optional, so that both
                ``layer(x)`` and the older ``layer(x, other)`` call styles
                work.
        """
        y = self.conv(x1)
        y = self.subpixel(y)
        return y

In [12]:
class Net(nn.Module):
    """Encoder / bottleneck / decoder network with skip connections.

    The layer arguments are derived from ``depth`` through
    ``get_sizes_for_layers``, ``args_down`` and ``args_up``. For ``depth=4``
    this gives channel counts [128, 256, 512, 512] and filter sizes
    [65, 33, 17, 9].

    NOTE(review): whether the tensors handed to the skip connections have
    compatible lengths depends on the padding of the convolution blocks and
    on ``Subpixel``, both of which are still marked as unfinished in this
    notebook -- confirm before training.

    Args:
        depth: number of downsampling levels (and of upsampling levels).
    """

    def __init__(self, depth):
        super(Net, self).__init__()

        n_channels, size_filters = get_sizes_for_layers(depth)

        # nn.ModuleList (rather than a plain list) registers the sub-layers,
        # so their parameters show up in .parameters(), .state_dict(), .to().
        self.down = nn.ModuleList([
            Downsampling(n_ch_in, n_ch_out, size)
            for n_ch_in, n_ch_out, size in args_down(n_channels, size_filters)
        ])

        self.bottleneck = Bottleneck(n_channels[-1], size_filters[-1])

        # Channel counts are doubled on the way up because each level's
        # output is stacked with the matching downsampling output.
        self.up = nn.ModuleList([
            Upsampling(n_ch_in * 2, n_ch_out * 2, size)
            for n_ch_in, n_ch_out, size in args_up(n_channels, size_filters)
        ])

        self.last = LastConv(n_channels[0] * 2, 9)

    def forward(self, x):
        """Run ``x`` through the down path, bottleneck, up path and last layer.

        Args:
            x: input tensor; the first downsampling layer is built with a
                single input channel.

        Returns:
            The output of the final ``LastConv`` layer.
        """
        # Keep every downsampling output for the skip connections.
        down_out = []
        for layer in self.down:
            x = layer(x)
            down_out.append(x)

        y = self.bottleneck(x)

        # The first upsampling layer is paired with the deepest downsampling
        # output, the second with the one before it, and so on. Each layer
        # receives one tensor, not a slice of the list.
        for layer, skip in zip(self.up, reversed(down_out)):
            y = layer(y, skip)

        # LastConv.forward declares a second argument that it never reads;
        # passing None keeps this call valid whether or not that argument
        # has a default.
        return self.last(y, None)
    
# Build a network with depth 4.
# NOTE: Net's constructor calls get_sizes_for_layers, args_down and args_up,
# which are defined in the "Utils" cells further down; those cells must have
# been executed before this one.
net = Net(4)
#print(net)

## Utils

### Generate base sizes and count for each level

In [9]:
# this is what they implemented in their git project

def get_sizes_for_layers(B):
    """Compute the channel count and filter size of each of the ``B`` levels.

    For level b = 1..B:
      * channel count = min(2**(6 + b), 512)  (the paper says max, their
        code uses min)
      * filter size   = max(2**(7 - b) + 1, 9)  (the paper says min, their
        code uses max)

    Args:
        B: number of levels.

    Returns:
        A pair ``(n_channels, size_filters)`` of lists, each of length ``B``.
    """
    levels = range(1, B + 1)
    n_channels = [min(2 ** (6 + level), 512) for level in levels]
    size_filters = [max(2 ** (7 - level) + 1, 9) for level in levels]
    return n_channels, size_filters

In [48]:
# Notebook-level example values (4 levels); the cells below reuse
# n_channels and size_filters.
B = 4
n_channels, size_filters = get_sizes_for_layers(B)

In [49]:
# Display the per-level channel counts and filter sizes computed above.
n_channels, size_filters

([128, 256, 512, 512], [65, 33, 17, 9])

### Generate correct parameters for up and down

In [10]:
# The input channel count is equal to the output channel count of the previous layer
# Input will be all the channel counts, shifted to the right with a 1 before
def args_down(n_channels, size_filters):
    """Yield ``(in_channels, out_channels, filter_size)`` per down layer.

    The first layer takes 1 input channel; every following layer takes as
    input the output channel count of the layer before it.

    Args:
        n_channels: list of output channel counts, one per level.
        size_filters: filter sizes, one per level.

    Returns:
        A ``zip`` iterator over the per-layer argument triples.
    """
    in_channels = [1] + n_channels[:-1]
    return zip(in_channels, n_channels, size_filters)

# Input filter count is the size of the bottleneck for the first up layer
# And then it will be the count of the previous up layer, which is equal to twice the count of the down layer
# (since we do some stacking with the skip connections)

# Output filter count  will be twice the count of the down layer 
# so that after the subpixel we get the same count as in the down layer
# and we can stack them together
def args_up(n_channels, size_filters):
    """Yield ``(in_channels, out_channels, filter_size)`` per up layer.

    Levels are visited deepest first. The first input count is half of the
    deepest channel count; each following input count is the previous
    level's channel count. The caller doubles the channel counts.

    Args:
        n_channels: list of channel counts, one per level (shallowest first).
        size_filters: list of filter sizes, one per level (shallowest first).

    Returns:
        A ``zip`` iterator over the per-layer argument triples.
    """
    deepest_first = n_channels[::-1]
    in_channels = [int(n_channels[-1] / 2)] + deepest_first[:-1]
    return zip(in_channels, deepest_first, size_filters[::-1])

In [60]:
# Materialise the zip iterator so the notebook displays the triples.
list(args_down(n_channels, size_filters))

[(1, 128, 65), (128, 256, 33), (256, 512, 17), (512, 512, 9)]

In [61]:
# Materialise the zip iterator so the notebook displays the triples.
list(args_up(n_channels, size_filters))

[(256, 512, 9), (512, 512, 17), (512, 256, 33), (256, 128, 65)]

#### And if we use the arguments, we get the same parameters as the manually defined ones (except for the outermost filter size: 65 here, versus the 63 written in the manual definition)

In [69]:
# Downsampling layers: arguments are used as generated.
for down_args in args_down(n_channels, size_filters):
    print(*down_args)

# Bottleneck: deepest channel count and filter size.
print(n_channels[-1], size_filters[-1])

# Upsampling layers: channel counts are doubled.
for n_ch_in, n_ch_out, size in args_up(n_channels, size_filters):
    print(2 * n_ch_in, 2 * n_ch_out, size)

# Final layer.
print(2 * n_channels[0], 9)

1 128 65
128 256 33
256 512 17
512 512 9
512 9
512 1024 9
1024 1024 17
1024 512 33
512 256 65
256 9


In [None]:
# Compare with this
# Manual definition
# self.down1 = Downsampling(1, 128, 63)
# self.down2 = Downsampling(128, 256, 33)
# self.down3 = Downsampling(256, 512, 17)
# self.down4 = Downsampling(512, 512, 9)
# self.bottleneck = Bottleneck(512, 9)
# self.up1 = Upsampling(int((512/2)*2), 512*2, 9)
# self.up2 = Upsampling(512*2, 512*2, 17)
# self.up3 = Upsampling(512*2, 256*2, 33)
# self.up4 = Upsampling(256*2, 128*2, 63)
# self.last = LastConv(128*2, 9)