# **Importing Functions**

In [33]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm

from torchvision.utils import make_grid



import numpy as np
import os
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import models, transforms, datasets

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# **Building Generator**

In [34]:
# Building UNET


class DownSample(nn.Module):
    def __init__(self, Input_Channels, Output_Channels):
        super(DownSample, self).__init__()
        # nn.Conv2d(in_channel, out_channel, kernel, stride, padding)
        self.model = nn.Sequential(
                     nn.Conv2d(Input_Channels, Output_Channels, 4, 2, 1, bias=False),
                     nn.LeakyReLU(0.2)
                     )

    def forward(self, x):
        down = self.model(x)
        print(down.shape)
        return down


class Upsample(nn.Module):
    def __init__(self, Input_Channels, Output_Channels):
        super(Upsample, self).__init__()
        
        self.model = nn.Sequential( 
                     nn.ConvTranspose2d(Input_Channels, Output_Channels, 4, 2, 1, bias=False),
                     nn.InstanceNorm2d(Output_Channels),
                     nn.ReLU(inplace=True),)

    def forward(self, x, skip_input):
        print(f'Input_to_block Shape :   {x.shape}')
        x = self.model(x)
        print(f'Output_to_block Shape:   {x.shape}')
        x = torch.cat((x, skip_input), 1)
        print(f'Skip_input Shape     : {skip_input.shape}')
        print(f'Output Shape         : {x.shape}\n')
        return x


class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(Generator, self).__init__()

        self.down1 = DownSample(in_channels, 64)
        self.down2 = DownSample(64, 128)
        self.down3 = DownSample(128, 256)
        self.down4 = DownSample(256, 512)
        self.down5 = DownSample(512, 512)
        self.down6 = DownSample(512, 512)
        self.down7 = DownSample(512, 512)
        self.down8 = DownSample(512, 512)

        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 512)
        self.up4 = Upsample(1024, 512)
        self.up5 = Upsample(1024, 256)
        self.up6 = Upsample(512, 128)
        self.up7 = Upsample(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, 3, 4, padding=1), # out_channels
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder

        print(f'Input_Shape: {x.shape} \n')
        print('Down Sampling Begins \n')

        print('Block 1')
        d1 = self.down1(x)
        print('\nBlock 2')
        d2 = self.down2(d1)
        print('\nBlock 3')
        d3 = self.down3(d2)
        print('\nBlock 4')
        d4 = self.down4(d3)
        print('\nBlock 5')
        d5 = self.down5(d4)
        print('\nBlock 6')
        d6 = self.down6(d5)
        print('\nBlock 7')
        d7 = self.down7(d6)
        print('\nBlock 8')
        d8 = self.down8(d7)
        print('\n')
        print('Upsampling Begins \n')
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        u8 = self.final(u7)
        return u8


# **Building Discriminator - PatchGAN**

In [35]:

# Building Discriminator

class Discriminator(nn.Module):
    def __init__(self, in_channels=3,):
        super(Discriminator, self).__init__()

        
        self.model = nn.Sequential(
                     nn.Conv2d(6, 64, 4, 2, 1, bias=False),
                     nn.LeakyReLU(0.2, inplace=True),

                     nn.Conv2d(64, 128, 4, 2, 1, bias=False),
                     nn.LeakyReLU(0.2, inplace=True),

                     nn.Conv2d(128, 256, 4, 2, 1, bias=False),
                     nn.LeakyReLU(0.2, inplace=True),

                     nn.Conv2d(256, 512, 4, 2, 1, bias=False),
                     nn.LeakyReLU(0.2, inplace=True),

                     nn.ZeroPad2d((1, 0, 1, 0)),
                     nn.Conv2d(512, 1, 4, padding=1, bias=False)
                     
                     )

        
    def forward(self, img_A, img_B):
        # Here we are concatenating the images on their channels
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

# **Checking Generator and Discriminator Models**

In [36]:
image = torch.rand((1,3,256,256))
out_channels = 3
generator = Generator()
k = generator(image)
print(k.shape)

Input_Shape: torch.Size([1, 3, 256, 256]) 

Down Sampling Begins 

Block 1
torch.Size([1, 64, 128, 128])

Block 2
torch.Size([1, 128, 64, 64])

Block 3
torch.Size([1, 256, 32, 32])

Block 4
torch.Size([1, 512, 16, 16])

Block 5
torch.Size([1, 512, 8, 8])

Block 6
torch.Size([1, 512, 4, 4])

Block 7
torch.Size([1, 512, 2, 2])

Block 8
torch.Size([1, 512, 1, 1])


Upsampling Begins 

Input_to_block Shape :   torch.Size([1, 512, 1, 1])
Output_to_block Shape:   torch.Size([1, 512, 2, 2])
Skip_input Shape     : torch.Size([1, 512, 2, 2])
Output Shape         : torch.Size([1, 1024, 2, 2])

Input_to_block Shape :   torch.Size([1, 1024, 2, 2])
Output_to_block Shape:   torch.Size([1, 512, 4, 4])
Skip_input Shape     : torch.Size([1, 512, 4, 4])
Output Shape         : torch.Size([1, 1024, 4, 4])

Input_to_block Shape :   torch.Size([1, 1024, 4, 4])
Output_to_block Shape:   torch.Size([1, 512, 8, 8])
Skip_input Shape     : torch.Size([1, 512, 8, 8])
Output Shape         : torch.Size([1, 1024, 8, 

In [37]:
image1 = torch.rand((1,3,256,256))
image2 = torch.rand((1,3,256,256))

out_channels = 3
discriminator = Discriminator()
k = discriminator(image1,image2)
print(k.shape)

torch.Size([1, 1, 16, 16])
