In [0]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn.parallel as prl
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms,utils,datasets
from torch.autograd import Variable
from torch.utils.data import Dataset,DataLoader

In [0]:
class ResidualCell(nn.Module):
  """Shape-preserving residual block computing ``x + conv_o(conv_m(conv_i(x)))``.

  All three convolutions are 3x3 with stride 1 and padding 1, so the spatial
  size is unchanged. The middle layer squeezes the channel count to half of
  ``n_filters`` and the last one restores it, so the sum with ``x`` is valid.

  NOTE(review): there is no activation between the convolutions, so the whole
  branch is a composition of affine maps, i.e. a single affine map. ``conv_i``
  therefore adds parameters but no expressive power on its own. If a
  nonlinearity was intended, it is missing — TODO confirm.
  """

  def __init__(self, n_filters):
    super().__init__()
    bottleneck = int(n_filters / 2)

    # Attribute names and creation order are kept so that state_dict keys
    # (conv_i/conv_m/conv_o and res_net.*) and seeded initialisation match.
    self.conv_i = nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1)
    self.conv_m = nn.Conv2d(n_filters, bottleneck, kernel_size=3, stride=1, padding=1)
    self.conv_o = nn.Conv2d(bottleneck, n_filters, kernel_size=3, stride=1, padding=1)
    self.res_net = nn.Sequential(self.conv_i, self.conv_m, self.conv_o)

  def forward(self, x):
    residual = self.res_net(x)
    return x + residual

In [0]:
class EncoderCell(nn.Module):
  """Downsampling stage: strided conv -> ResidualCell -> conv.

  Maps ``n_in`` channels to ``n_res`` channels. Neither plain convolution is
  padded, so with 3x3 kernels a spatial size ``H`` becomes
  ``(H - 3) // 2 + 1`` after ``conv_a`` (stride 2). ``conv_b`` then removes
  two more pixels. The residual cell (3x3, padding 1, stride 1) keeps the
  size unchanged.
  """

  def __init__(self, n_in, n_res):
    super().__init__()
    self.conv_a = nn.Conv2d(n_in, n_res, kernel_size=3, stride=2)
    self.res = ResidualCell(n_res)
    self.conv_b = nn.Conv2d(n_res, n_res, kernel_size=3, stride=1)

    # The same three modules chained in order. They are also registered under
    # `encoder`, so their parameters appear under both key prefixes in state_dict.
    self.encoder = nn.Sequential(self.conv_a, self.res, self.conv_b)

  def forward(self, x):
    downsampled = self.encoder(x)
    return downsampled

In [0]:
class DecoderCell(nn.Module):
  """Upsampling stage: transposed conv -> ResidualCell -> strided transposed conv.

  Maps ``n_in`` channels to ``n_res`` channels. With 3x3 kernels and no
  padding, ``deconv_a`` grows a spatial size ``H`` to ``H + 2``.
  ``deconv_b`` (stride 2) then maps a size ``H'`` to ``2 * H' + 1``. The
  residual cell keeps the size unchanged.
  """

  def __init__(self, n_in, n_res):
    super().__init__()
    self.deconv_a = nn.ConvTranspose2d(n_in, n_res, kernel_size=3, stride=1)
    self.res = ResidualCell(n_res)
    self.deconv_b = nn.ConvTranspose2d(n_res, n_res, kernel_size=3, stride=2)

    # The same three modules chained in order. They are also registered under
    # `decoder`, so their parameters appear under both key prefixes in state_dict.
    self.decoder = nn.Sequential(self.deconv_a, self.res, self.deconv_b)

  def forward(self, x):
    upsampled = self.decoder(x)
    return upsampled

In [0]:
class SubGenerator(nn.Module):
  """Encoder/decoder with additive skip connections; one stage of Generator.

  Input: a 2-channel tensor. It is presumably the zero-filled reconstruction
  (inverse FFT of undersampled k-space) split into real and imaginary parts —
  TODO confirm.
  Output: a 2-channel tensor squashed into [-1, 1] by tanh. It is intended to
  approximate a member of the set of plausible full reconstructions.

  Args:
    n: base channel width. The encoder uses n, 2n, 4n and 8n channels, and the
      decoder mirrors them back down to n.
  """

  def __init__(self, n):
    # `n` must be a constructor argument (Generator calls SubGenerator(n)).
    # nn.Module.__init__ itself takes no arguments.
    super().__init__()

    # Encoder: 2 input channels -> n -> 2n -> 4n -> 8n.
    self.ec0 = EncoderCell(2, n)
    self.ec1 = EncoderCell(n, 2 * n)
    self.ec2 = EncoderCell(2 * n, 4 * n)
    self.ec3 = EncoderCell(4 * n, 8 * n)

    # Decoder: 8n -> 4n -> 2n -> n -> n.
    self.dc3 = DecoderCell(8 * n, 4 * n)
    self.dc2 = DecoderCell(4 * n, 2 * n)
    self.dc1 = DecoderCell(2 * n, n)
    self.dc0 = DecoderCell(n, n)

    # 3x3 conv with padding 1 (size-preserving) back to 2 output channels.
    self.out = nn.Conv2d(n, 2, (3, 3), 1, 1)

  def forward(self, x):
    """Run the encoder, then the decoder with additive skips, then tanh."""
    e0 = self.ec0(x)
    e1 = self.ec1(e0)
    e2 = self.ec2(e1)
    e3 = self.ec3(e2)

    # Skip connections are element-wise sums, so each decoder output must have
    # exactly the shape of the matching encoder output.
    # NOTE(review): with the unpadded convs in EncoderCell/DecoderCell this only
    # holds for some input sizes. For example, a 256x256 input yields a 60x60 e1
    # but a 59x59 d2, so `d2 + e1` fails. Verify for the intended resolution.
    d3 = self.dc3(e3)
    d2 = self.dc2(d3 + e2)
    d1 = self.dc1(d2 + e1)
    d0 = self.dc0(d1 + e0)

    # Apply tanh as a function. `nn.Tanh(tensor)` would try to construct a
    # module from the tensor and raise TypeError.
    return torch.tanh(self.out(d0))


In [0]:
class Generator(nn.Module):
  """Two-stage generator: a reconstruction pass followed by a refinement pass.

  Each stage is a SubGenerator whose output is added to that stage's input.

  Args:
    n: base channel width passed to both SubGenerator stages.

  Returns (from forward):
    A tuple ``(y, z)``. ``y = x + reconGAN(x)`` is the first-stage estimate,
    and ``z = y + refineGAN(y)`` is the refined estimate.
  """

  def __init__(self, n):
    super().__init__()
    # Same width for both stages, but independent weights.
    self.reconGAN = SubGenerator(n)
    self.refineGAN = SubGenerator(n)

  def forward(self, x):
    coarse = x + self.reconGAN(x)
    # NOTE(review): the author asked whether the refinement stage should also
    # be residual. It currently is; TODO confirm against the intended design.
    refined = coarse + self.refineGAN(coarse)
    return coarse, refined

In [0]:
class Discriminator(nn.Module):
  """Convolutional discriminator: four EncoderCells, then a sigmoid score.

  Args:
    H, W: kernel size of the final convolution. Presumably these equal the
      spatial size of the feature map after the four encoder cells, so the
      output is (batch, 1, 1, 1) — TODO confirm. If they are smaller, the
      output has one score per (H, W) window.
    n_in: number of input channels.
    n: base channel width. The encoder uses n, 2n, 4n and 8n channels.
  """

  def __init__(self, H, W, n_in, n):
    super().__init__()

    # Downsample layers: n_in -> n -> 2n -> 4n -> 8n channels.
    self.ec0 = EncoderCell(n_in, n)
    self.ec1 = EncoderCell(n, 2 * n)
    self.ec2 = EncoderCell(2 * n, 4 * n)
    self.ec3 = EncoderCell(4 * n, 8 * n)
    self.downsampler = nn.Sequential(self.ec0, self.ec1, self.ec2, self.ec3)

    # Unpadded (H, W) conv collapsing 8n channels to a single score map.
    self.out = nn.Conv2d(8 * n, 1, (H, W), 1)

  def forward(self, x):
    features = self.downsampler(x)
    # Apply sigmoid as a function. `nn.Sigmoid(tensor)` would try to construct
    # a module from the tensor and raise TypeError.
    return torch.sigmoid(self.out(features))