<a href="https://colab.research.google.com/github/JoaoVitorSantiagoNogueira/deepLearning2023/blob/main/T3/DL_Task3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Task 3

Test different auto-encoder architectures

## Theory

### references:

[1]
[2]
...

## Code

### Initialization

In [1]:
import numpy as np
import torch
import torch.nn as nn

### model definition

The auto-encoder used in a [previous work](https://github.com/JoaoVitorSantiagoNogueira/tcc-testes), with the middle layers removed because no intermediate processing needs to be done. This will serve as our baseline. TODO: reuse the model concatenation from Task 2.

In [None]:

class BaseNetwork(nn.Module):
    '''
    nn.Module base class that adds an in-place weight initialisation helper.
    '''

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights in place
        init_type: normal | xavier | kaiming | orthogonal
        gain: std of the normal distribution for 'normal' (and for
              BatchNorm2d weights), scaling factor for 'xavier' and
              'orthogonal'; not used by 'kaiming'
        raises ValueError when init_type is not one of the names above
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        initializers = {
            'normal': lambda w: nn.init.normal_(w, 0.0, gain),
            'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
            'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
            'orthogonal': lambda w: nn.init.orthogonal_(w, gain=gain),
        }
        # Fail loudly on a typo instead of silently keeping PyTorch's default
        # weights while still zeroing the biases.
        if init_type not in initializers:
            raise ValueError(
                'initialization method [%s] is not implemented' % init_type)
        init_weight = initializers[init_type]

        def init_func(m):
            classname = m.__class__.__name__
            # Parameters may be registered as None (e.g. bias=False or
            # affine=False), so fetch them defensively.
            weight = getattr(m, 'weight', None)
            bias = getattr(m, 'bias', None)

            if weight is not None and ('Conv' in classname or 'Linear' in classname):
                init_weight(weight.data)
                if bias is not None:
                    nn.init.constant_(bias.data, 0.0)

            elif 'BatchNorm2d' in classname:
                if weight is not None:
                    nn.init.normal_(weight.data, 1.0, gain)
                if bias is not None:
                    nn.init.constant_(bias.data, 0.0)

        # nn.Module.apply visits every submodule and then the module itself.
        self.apply(init_func)


In [None]:
class InpaintGenerator(BaseNetwork):
    '''
    Convolutional encoder/decoder used as the baseline auto-encoder.

    The encoder reduces the spatial size by 4 (two stride-2 convolutions) and
    the decoder restores it, so the output has the input's spatial size when
    H and W are multiples of 4. The residual middle section of the original
    generator is not part of this model.

    residual_blocks: kept for backward compatibility; currently unused
                     because the model has no residual middle section
    init_weights:    when True, call self.init_weights() after building
    in_channels:     number of channels of the input tensor (default 4)
    out_channels:    number of channels of the output tensor (default 3)
    '''

    def __init__(self, residual_blocks=8, init_weights=True,
                 in_channels=4, out_channels=3):
        super(InpaintGenerator, self).__init__()

        # Layer order is unchanged so the nn.Sequential indices (and therefore
        # the state_dict keys) stay the same.
        self.encoder = nn.Sequential(
            # [batch_size, in_channels, H, W] -> [batch_size, 64, H, W]
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            # -> [batch_size, 128, H/2, W/2]
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            # -> [batch_size, 256, H/4, W/4]
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True)
        )

        self.decoder = nn.Sequential(
            # Upsample + Conv2d is used instead of ConvTranspose2d to avoid
            # checkerboard patterns. x4 upsample, 1px reflection pad and a
            # 4x4 stride-2 conv give a net x2 in spatial size, the same as
            # ConvTranspose2d(kernel_size=4, stride=2, padding=1).
            nn.Upsample(scale_factor=4, mode='bilinear'),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=0),
            # -> [batch_size, 128, H/2, W/2]
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            nn.Upsample(scale_factor=4, mode='bilinear'),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=0),
            # -> [batch_size, 64, H, W]
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            # -> [batch_size, 64, H+6, W+6]
            nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0),
            # -> [batch_size, out_channels, H, W]
        )

        if init_weights:
            self.init_weights()

    def forward(self, x):
        '''
        x: tensor of shape [batch_size, in_channels, H, W]
        returns: tensor of shape [batch_size, out_channels, H, W] (for H, W
                 multiples of 4) with values in [0, 1]
        '''
        x = self.encoder(x)
        x = self.decoder(x)
        # tanh yields [-1, 1]; shift and scale to [0, 1].
        x = (torch.tanh(x) + 1) / 2

        return x

### Training

### Usage

### Visualization
