In [3]:
import torch.nn as nn
import torch.nn.functional as F

In [4]:
class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus an identity skip.

    The convolutional branch keeps the channel count at ``in_features``.
    Reflection padding of 1 before each unpadded 3x3 convolution keeps the
    spatial size unchanged (H + 2 - 3 + 1 = H). Together these let the branch
    output be summed with the block input.
    """

    def __init__(self, in_features):
        super().__init__()
        channels = in_features

        # Two "pad -> conv -> instance norm" stages with a ReLU between them.
        # InstanceNorm2d (not BatchNorm) normalizes each sample/channel separately.
        # Reflection padding is used instead of zero padding at the borders.
        layers = []
        for stage in range(2):
            layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                nn.InstanceNorm2d(channels),
            ]
            if stage == 0:
                layers.append(nn.ReLU(inplace=True))

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return the convolutional branch output added to the input ``x``."""
        branch = self.conv_block(x)
        return branch + x

In [6]:
class Generator(nn.Module):
    """Encoder / residual bottleneck / decoder image-to-image generator.

    Layer stack:
      * reflection pad 3 + 7x7 conv: ``input_nc`` -> 64 channels (size kept);
      * two stride-2 3x3 convs, each halving H and W and doubling channels
        (64 -> 128 -> 256);
      * ``n_residual_blocks`` ``ResidualBlock``s at 256 channels (shape kept);
      * two stride-2 transposed convs, each doubling H and W and halving
        channels (256 -> 128 -> 64);
      * reflection pad 3 + 7x7 conv to ``output_nc`` channels, then ``tanh``.

    When H and W are divisible by 4 the output has the same spatial size as
    the input.
    """

    def __init__(self, input_nc, output_nc, n_residual=None, n_residual_blocks=9):
        """Build the layer stack.

        Args:
            input_nc: number of channels of the input image.
            output_nc: number of channels of the generated image.
            n_residual: unused. It is kept only so existing calls of the form
                ``Generator(input_nc, output_nc, x)`` keep working. The number
                of residual blocks is controlled by ``n_residual_blocks``.
                NOTE(review): this was probably meant to be the block count;
                confirm with callers.
            n_residual_blocks: number of ``ResidualBlock``s in the bottleneck.
        """
        super(Generator, self).__init__()

        base_features = 64
        n_sampling = 2  # number of down-sampling (and matching up-sampling) stages

        # Initial convolution block: padding 3 + 7x7 kernel keeps the size
        # (I + 2*3 - 7 + 1 = I).
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, base_features, 7),
                 nn.InstanceNorm2d(base_features),
                 nn.ReLU(True)]
        in_features = base_features
        out_features = in_features * 2

        # Encoding: each stage maps I -> I/2 and doubles the channel count.
        for _ in range(n_sampling):
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(True)]
            in_features = out_features
            out_features = in_features * 2

        # Residual transformations at constant channel count and resolution.
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Decoding: each stage maps I -> 2I and halves the channel count.
        # Integer division: layer constructors require int channel counts.
        out_features = in_features // 2
        for _ in range(n_sampling):
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                                         padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(True)]
            in_features = out_features
            out_features = in_features // 2

        # Output layer, appended once after the decoder; in_features is back
        # to base_features here.
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(in_features, output_nc, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the generator.

        Args:
            x: image batch, presumably shaped (N, input_nc, H, W) — confirm
               against callers.

        Returns:
            Tensor with ``output_nc`` channels and values in [-1, 1] (tanh).
        """
        return self.model(x)