### Deep Learning Homework 6

Taking inspiration from the last 2 pictures within the notebook (07-convnets.ipynb), implement a U-Net-style CNN with the following specs:

1. All convolutions must use a 3 x 3 kernel and leave the spatial dimensions (i.e. height, width) of the input untouched.
2. Downsampling in the contracting part is performed via maxpooling with a 2 x 2 kernel and stride of 2.
3. Upsampling is operated by a deconvolution with a 2 x 2 kernel and stride of 2. The PyTorch module that implements the deconvolution is `nn.ConvTranspose2d`
4. The final layer of the expanding part has only 1 channel 

* between how many classes are we discriminating?

Create a network class with (at least) a `__init__` and a `forward` method. Please resort to additional structures (e.g., `nn.Module`s, private methods...) if you believe it helps readability of your code.

Test, at least with random data, that the network is doing the correct tensor operations and that the output has the correct shape (e.g., use `assert`s in your code to see if the byproduct is of the expected shape).

Note: the overall organization of your work can greatly improve readability and understanding of your code by others. Please consider preparing your notebook in an organized fashion so that we can better understand (and correct) your implementation.

In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pylab as pl
from IPython.display import clear_output

In [9]:
x = torch.zeros((10, ))
a = []
layer = nn.Linear(10, 10)
out = layer(x)
a.append(out)

print(a)
out = layer(out)
print(a)

[tensor([-0.2921, -0.1395, -0.2395,  0.0859, -0.1852, -0.0917, -0.1712, -0.1405,
         0.2245,  0.1828], grad_fn=<AddBackward0>)]
[tensor([-0.2921, -0.1395, -0.2395,  0.0859, -0.1852, -0.0917, -0.1712, -0.1405,
         0.2245,  0.1828], grad_fn=<AddBackward0>)]


In [10]:
out

tensor([-0.2387, -0.2231, -0.2372,  0.1754, -0.2415, -0.0832, -0.1299, -0.2551,
         0.3153,  0.1280], grad_fn=<AddBackward0>)

In [75]:
class VGG_block(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=2, maxpool=False, activation=nn.ReLU):
        super().__init__()
        
        layers = []
        
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(activation())
        for i in range(num_layers-1):
            layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
            layers.append(activation())
        
        if maxpool:
            layers.append(nn.MaxPool2d(2))
             
        self.layers = nn.Sequential(*layers)
    
    def forward(self, X):
        return self.layers(X)
    
class upsampling_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, num_mid_layers=2, activation=nn.ReLU):
        super().__init__()
        
        self.layers = nn.Sequential(
            VGG_block(in_channels, mid_channels, num_layers=num_mid_layers, activation=activation),
            nn.ConvTranspose2d(mid_channels, out_channels, kernel_size=2, stride=2),
            activation()
        )
        
    def forward(self, X):
        return self.layers(X)

class U_net(nn.Module):
    """
    Implements a Unet
    """

    def __init__(self, h=572, w=572, channels=3, depth=4, num_classes=10):
        super().__init__()
        
        # Downsampling layers
        downsampling_layers = []
        in_channels, out_channels = channels, 64
        for i in range(depth):
            downsampling_layers.append(VGG_block(in_channels, out_channels))
            in_channels = out_channels
            out_channels *= 2
        
        self.downsampling_layers = nn.Sequential(*downsampling_layers)
        self.maxpool = nn.MaxPool2d(2)
        
        # Deepest layer
        mid_channels = out_channels
        out_channels = in_channels
        self.deep_layer = upsampling_block(in_channels, mid_channels, out_channels)
        
        # Upsampling layers
        upsampling_layers = []
        in_channels, mid_channels, out_channels = mid_channels, out_channels, out_channels//2
        for i in range(depth-1):
            upsampling_layers.append(upsampling_block(in_channels, mid_channels, out_channels))
            in_channels = in_channels//2
            mid_channels = mid_channels//2
            out_channels = out_channels//2
            
        self.upsampling_layers = nn.Sequential(*upsampling_layers)
        
        # Clasifier or last layer
        self.classifier = nn.Sequential(
            VGG_block(in_channels, mid_channels),
            nn.Conv2d(mid_channels, num_classes, kernel_size=1)
        )
        
    def forward(self, X):        
        skips = []
        out = X
        for layer in self.downsampling_layers:
            out = layer(out)
            skips.append(out)
            out = self.maxpool(out)
        
        out = self.deep_layer(out)
            
        for i, layer in enumerate(self.upsampling_layers, start=1):
            print(skips[-i].size())
            out = torch.cat((skips[-i], out), dim=1)
            out = layer(out)
        
        print(skips[0].size())
        out = torch.cat((skips[0], out), dim=1)
        return self.classifier(out)

In [78]:
net = U_net(h=64, w=64, channels=3, depth=4, num_classes=10)

x = torch.zeros((4, 3, 64, 64))

print(net(x).size())
from torchsummary import summary

_=summary(U_net())
U_net()

torch.Size([4, 512, 8, 8])
torch.Size([4, 256, 16, 16])
torch.Size([4, 128, 32, 32])
torch.Size([4, 64, 64, 64])
torch.Size([4, 10, 64, 64])
Layer (type:depth-idx)                        Param #
├─Sequential: 1-1                             --
|    └─VGG_block: 2-1                         --
|    |    └─Sequential: 3-1                   38,720
|    └─VGG_block: 2-2                         --
|    |    └─Sequential: 3-2                   221,440
|    └─VGG_block: 2-3                         --
|    |    └─Sequential: 3-3                   885,248
|    └─VGG_block: 2-4                         --
|    |    └─Sequential: 3-4                   3,539,968
├─MaxPool2d: 1-2                              --
├─upsampling_block: 1-3                       --
|    └─Sequential: 2-5                        --
|    |    └─VGG_block: 3-5                    14,157,824
|    |    └─ConvTranspose2d: 3-6              2,097,664
|    |    └─ReLU: 3-7                         --
├─Sequential: 1-4                 

U_net(
  (downsampling_layers): Sequential(
    (0): VGG_block(
      (layers): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (1): VGG_block(
      (layers): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (2): VGG_block(
      (layers): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (3): VGG_block(
      (layers): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(5