In [1]:
from scipy.io import loadmat
import torch as tr
import torch.nn as nn
from time import time
from matplotlib import pyplot as plt
import os
import cv2
from math import ceil
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [3]:
class ResBlock(nn.Module):
  """Conv -> BatchNorm branch joined to the block input by channel concatenation.

  Despite the name, the skip connection concatenates rather than adds, so an
  input of shape (N, C, H, W) produces an output of shape (N, 2*C, H, W): the
  first C channels are ReLU(input), the last C are ReLU(BN(conv(input))).

  Args:
    channels: number of channels C of the input (and of the conv branch).
  """
  def __init__(self, channels):
    super().__init__()
    # 3x3 kernel, stride 1, padding 1 keeps H and W unchanged; the bias is
    # redundant because BatchNorm's own shift follows immediately.
    self.conv = nn.Conv2d(
      in_channels=channels,
      out_channels=channels,
      kernel_size=3,
      stride=1,
      padding=1,
      bias=False,
      padding_mode='reflect',
    )
    self.bn = nn.BatchNorm2d(num_features=channels)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    branch = self.conv(x)
    branch = self.bn(branch)
    # torch.cat allocates a fresh tensor, so the in-place ReLU never touches x.
    merged = tr.cat((x, branch), dim=1)
    return self.relu(merged)

In [10]:
# Generator Model
class Generator(nn.Module):
  """U-Net style encoder/decoder generator.

  The encoder (conv1-conv4) and the bottleneck are stride-2 convolutions, so the
  spatial size shrinks by 32x before five stride-2 transposed convolutions
  (tconv1-tconv4, final) bring it back. Every decoder stage after tconv1 takes
  the previous decoder output concatenated along channels with the encoder
  feature map of the same resolution.

  Args:
    in_channels: channels of the input image (default 3).
    out_channels: channels of the generated image (default 3).

  Shape:
    input:  (N, in_channels, H, W); H and W must both be multiples of 32.
    output: (N, out_channels, H, W), passed through tanh so values lie in [-1, 1].

  Raises:
    ValueError: if the input is not 4D or H/W is not a multiple of 32
      (checked eagerly; skipped while being traced by torch.jit).
  """
  def __init__(self, in_channels=3, out_channels=3):
    super(Generator,self).__init__()
    self.leaky1 = nn.LeakyReLU(0.2)
    self.leaky2 = nn.LeakyReLU(0.2)
    self.leaky3 = nn.LeakyReLU(0.2)
    self.leaky4 = nn.LeakyReLU(0.2)

    self.relu1 = nn.ReLU(inplace=True)
    self.relu2 = nn.ReLU(inplace=True)
    self.relu3 = nn.ReLU(inplace=True)
    self.relu4 = nn.ReLU(inplace=True)
    self.relu5 = nn.ReLU(inplace=True)

    # NOTE(review): bn2_64, bn3_512 and bn4_512 are never used in forward().
    # They are kept so existing state_dict checkpoints still load with
    # strict=True, but their parameters never receive gradients.
    self.bn1_64 = nn.BatchNorm2d(64)
    self.bn2_64 = nn.BatchNorm2d(64)

    self.bn1_128 = nn.BatchNorm2d(128)
    self.bn2_128 = nn.BatchNorm2d(128)

    self.bn1_256 = nn.BatchNorm2d(256)
    self.bn2_256 = nn.BatchNorm2d(256)

    self.bn1_512 = nn.BatchNorm2d(512)
    self.bn2_512 = nn.BatchNorm2d(512)
    self.bn3_512 = nn.BatchNorm2d(512)
    self.bn4_512 = nn.BatchNorm2d(512)

    self.drop1 = nn.Dropout(0.5)
    self.drop2 = nn.Dropout(0.5)

    self.tanh = nn.Tanh()

    # Encoder: kernel 4 / stride 2 / padding 1 halves H and W.
    self.conv1 = nn.Conv2d(in_channels,64,kernel_size=4,stride=2,padding=1)
    self.conv2 = nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1)
    self.conv3 = nn.Conv2d(128,256,kernel_size=4,stride=2,padding=1)
    self.conv4 = nn.Conv2d(256,512,kernel_size=4,stride=2,padding=1)

    self.bottleneck = nn.Conv2d(512,512,kernel_size=4,stride=2,padding=1)

    # Decoder: kernel 4 / stride 2 / padding 1 transposed convs double H and W.
    # Input channel counts include the concatenated skip connection.
    self.tconv1 = nn.ConvTranspose2d(512,512,kernel_size=4,stride=2,padding=1)   # bottleneck
    self.tconv2 = nn.ConvTranspose2d(1024,256,kernel_size=4,stride=2,padding=1)  # up1 (512) + d4 (512)
    self.tconv3 = nn.ConvTranspose2d(512,128,kernel_size=4,stride=2,padding=1)   # up2 (256) + d3 (256)
    self.tconv4 = nn.ConvTranspose2d(256,64,kernel_size=4,stride=2,padding=1)    # up3 (128) + d2 (128)

    self.final = nn.ConvTranspose2d(128,out_channels,kernel_size=4,stride=2,padding=1)  # up4 (64) + d1 (64)

  @staticmethod
  def _check_input(x):
    """Raise ValueError unless x is 4D with H and W divisible by 32.

    Without this, a bad spatial size only surfaces later as a torch.cat size
    mismatch between a decoder output and its skip connection.
    """
    if x.dim() != 4:
      raise ValueError(f"Generator expects a 4D (N, C, H, W) tensor, got {x.dim()}D")
    height, width = x.shape[-2:]
    if height % 32 or width % 32:
      raise ValueError(
        f"Generator needs H and W divisible by 32 (five stride-2 stages), got {height}x{width}")

  def forward(self,x):
    # Skipped under torch.jit tracing (e.g. SummaryWriter.add_graph) so the
    # traced graph contains no Python branching on input shapes.
    if not tr.jit.is_tracing():
      self._check_input(x)

    d1 = self.leaky1(self.conv1(x))                          # (N,  64, H/2,  W/2)
    d2 = self.leaky2(self.bn1_128(self.conv2(d1)))           # (N, 128, H/4,  W/4)
    d3 = self.leaky3(self.bn1_256(self.conv3(d2)))           # (N, 256, H/8,  W/8)
    d4 = self.leaky4(self.bn1_512(self.conv4(d3)))           # (N, 512, H/16, W/16)

    bottleneck = self.drop1(self.relu1(self.bottleneck(d4)))  # (N, 512, H/32, W/32)

    # Skip connections are concatenated along the channel dimension (dim 1).
    up1 = self.drop2(self.relu2(self.bn2_512(self.tconv1(bottleneck))))  # (N, 512, H/16, W/16)
    up2 = self.relu3(self.bn2_256(self.tconv2(tr.cat([up1,d4],1))))      # (N, 256, H/8,  W/8)
    up3 = self.relu4(self.bn2_128(self.tconv3(tr.cat([up2,d3],1))))      # (N, 128, H/4,  W/4)
    up4 = self.relu5(self.bn1_64(self.tconv4(tr.cat([up3,d2],1))))       # (N,  64, H/2,  W/2)

    final = self.tanh(self.final(tr.cat([up4,d1],1)))                    # (N, out_channels, H, W)

    return final


In [11]:
# Build the generator and log its traced graph to TensorBoard.
# `_input` is a dummy (1, 3, 256, 256) batch used only for tracing and shape checks.
netG = Generator()
_input = tr.randn((1,3,256,256))
# The context manager closes the writer even if tracing inside add_graph raises,
# instead of leaving it open in the long-lived kernel.
with SummaryWriter() as writer:
  writer.add_graph(netG,_input)

In [12]:
# Shape sanity check on the dummy batch. `no_grad` alone leaves BatchNorm in
# training mode (it would fold this random batch into its running statistics)
# and keeps Dropout active, so run the check in eval mode and restore the
# previous mode afterwards.
was_training = netG.training
netG.eval()
try:
  with tr.no_grad():
    g_out = netG(_input)
finally:
  netG.train(was_training)
print(g_out.shape)
assert g_out.shape == _input.shape, "generator output should match the input's shape"

torch.Size([1, 3, 256, 256])


In [13]:
class Discriminator(nn.Module):
  """Convolutional discriminator that outputs a spatial map of probabilities.

  Presumably a PatchGAN-style conditional discriminator whose 6 default input
  channels are two 3-channel images stacked along channels -- TODO confirm
  against the training loop.

  Args:
    in_channels: channels of the (stacked) input (default 6).

  Shape:
    input:  (N, in_channels, H, W).
    output: (N, 1, H', W') after a sigmoid, so values lie in [0, 1]; a 256x256
      input yields a 30x30 map (255 -> 127 -> 63 -> 31 -> 30).
  """
  def __init__(self, in_channels=6):
    super(Discriminator,self).__init__()
    self.leaky1 = nn.LeakyReLU(0.2)
    self.leaky2 = nn.LeakyReLU(0.2)
    self.leaky3 = nn.LeakyReLU(0.2)
    self.leaky4 = nn.LeakyReLU(0.2)

    self.bn128 = nn.BatchNorm2d(128)
    self.bn256 = nn.BatchNorm2d(256)
    self.bn512 = nn.BatchNorm2d(512)

    # NOTE(review): conv1 is the only stride-1 conv here, while conv2-conv4 use
    # stride 2; the common pix2pix PatchGAN layout strides the first three convs
    # instead. Kept as-is -- confirm which layout is intended.
    self.conv1 = nn.Conv2d(in_channels,64,kernel_size=4,stride=1,padding=1)
    self.conv2 = nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1)
    self.conv3 = nn.Conv2d(128,256,kernel_size=4,stride=2,padding=1)
    self.conv4 = nn.Conv2d(256,512,kernel_size=4,stride=2,padding=1)
    self.final = nn.Conv2d(512,1,kernel_size=4,stride=1,padding=1)

    # Registered once here rather than constructing nn.Sigmoid() on every
    # forward call; it has no parameters, so state_dict keys are unchanged.
    self.sigmoid = nn.Sigmoid()

  def forward(self,x):
    d1 = self.leaky1(self.conv1(x))
    d2 = self.leaky2(self.bn128(self.conv2(d1)))
    d3 = self.leaky3(self.bn256(self.conv3(d2)))
    d4 = self.leaky4(self.bn512(self.conv4(d3)))
    final = self.sigmoid(self.final(d4))
    return final

In [14]:
# Build the discriminator and log its traced graph to TensorBoard.
netD = Discriminator()
# Dummy 6-channel batch for tracing. Note: this rebinds `_input`, which the
# generator cells above also use, so those cells expect to run before this one.
_input = tr.randn((1,6,256,256))
# The context manager closes the writer even if tracing inside add_graph raises.
with SummaryWriter() as writer:
  writer.add_graph(netD,_input)