In [1]:
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torchvision.transforms as transforms

In [None]:
# NOTE(review): empty placeholder — presumably the dataset location/name to be
# filled in before any data-loading code runs; confirm against later cells.
dataset = ''


In [None]:
def weights_init(m):
  """Initialise one module in place; meant to be used as ``model.apply(weights_init)``.

  * ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights are drawn from N(0, 0.02).
  * ``nn.BatchNorm2d`` scale is drawn from N(1, 0.02) and its shift is set to 0,
    but only when the layer actually owns those parameters.
  * Every other module type is left untouched.

  Args:
    m: the module visited by ``nn.Module.apply``.
  """
  # ConvTranspose2d is NOT a subclass of Conv2d, so it has to be listed
  # explicitly or transposed-convolution layers would be skipped silently.
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    nn.init.normal_(m.weight.data, 0.0, 0.02)
  elif isinstance(m, nn.BatchNorm2d):
    # BatchNorm2d(affine=False) has weight/bias set to None; dereferencing
    # `.data` on them would raise AttributeError.
    if m.weight is not None:
      nn.init.normal_(m.weight.data, 1.0, 0.02)
    if m.bias is not None:
      nn.init.constant_(m.bias.data, 0)


In [33]:
class Generator(nn.Module):
  """DCGAN-style generator built from transposed convolutions.

  A latent tensor of shape ``(N, latent_size, 1, 1)`` is projected to a
  ``1024 x 4 x 4`` feature map and then upsampled by a factor of two four
  times, giving a ``(N, 3, 64, 64)`` output squashed into ``[-1, 1]`` by tanh.

  Args:
    latent_size: number of channels of the latent input tensor.
  """

  def __init__(self, latent_size):
    super().__init__()
    # Projection: 1x1 latent -> 4x4 feature map (kernel 4, stride 1, no padding).
    self.proj_layer = nn.ConvTranspose2d(in_channels=latent_size, out_channels=1024, kernel_size=4, stride=1, padding=0, bias=False)
    self.proj_norm = nn.BatchNorm2d(1024, affine=False)

    # Each kernel-4 / stride-2 / padding-1 transposed conv doubles the spatial size.
    self.conv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False)
    self.norm1 = nn.BatchNorm2d(512, affine=False)

    self.conv2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False)
    self.norm2 = nn.BatchNorm2d(256, affine=False)

    self.conv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)
    self.norm3 = nn.BatchNorm2d(128, affine=False)

    # Output layer: 3 channels, no normalisation.
    self.conv4 = nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False)

    self.relu = nn.ReLU(inplace=True)
    self.tanh = nn.Tanh()

  def forward(self, z):
    """Generate a batch of images from latent tensors.

    Args:
      z: tensor of shape ``(N, latent_size, 1, 1)``.

    Returns:
      Tensor of shape ``(N, 3, 64, 64)`` with values in ``[-1, 1]``.
    """
    # The projection stage gets the same BatchNorm -> ReLU treatment as the
    # other hidden stages; previously its activation was missing, leaving two
    # transposed convolutions separated only by a normalisation.
    x = self.relu(self.proj_norm(self.proj_layer(z)))
    x = self.relu(self.norm1(self.conv1(x)))
    x = self.relu(self.norm2(self.conv2(x)))
    x = self.relu(self.norm3(self.conv3(x)))
    x = self.tanh(self.conv4(x))
    return x

    

In [34]:
gen = Generator(100)
z = torch.randn(32, 100, 1, 1)
# Shape smoke test only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
  fake_batch = gen(z)
# Fail loudly if the architecture stops producing 64x64 RGB images.
assert fake_batch.shape == (32, 3, 64, 64), fake_batch.shape
print(fake_batch.shape)

torch.Size([32, 3, 64, 64])


In [31]:
class Discriminator(nn.Module):
  """DCGAN-style convolutional discriminator.

  Four kernel-4 / stride-2 / padding-1 convolutions each halve the spatial
  size while widening the channels (128 -> 256 -> 512 -> 1024); a final
  kernel-4 / stride-1 / no-padding convolution reduces to a single channel
  that is passed through a sigmoid. For 64x64 inputs the result has shape
  ``(N, 1, 1, 1)``.

  Args:
    input_channels: number of channels of the input images.
  """

  def __init__(self, input_channels):
    super().__init__()

    def halving_conv(c_in, c_out):
      # kernel 4, stride 2, padding 1, no bias
      return nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)

    def plain_norm(channels):
      # batch norm without learnable scale/shift
      return nn.BatchNorm2d(channels, affine=False)

    # First stage has no normalisation.
    self.proj_layer = halving_conv(input_channels, 128)

    self.conv1 = halving_conv(128, 256)
    self.norm1 = plain_norm(256)

    self.conv2 = halving_conv(256, 512)
    self.norm2 = plain_norm(512)

    self.conv3 = halving_conv(512, 1024)
    self.norm3 = plain_norm(1024)

    # Final reduction to one channel: kernel 4, stride 1, padding 0.
    self.conv4 = nn.Conv2d(1024, 1, 4, 1, 0, bias=False)

    self.l_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    self.sigmoid = nn.Sigmoid()

  def forward(self, z):
    """Score a batch of images.

    Args:
      z: image tensor with ``input_channels`` channels.

    Returns:
      Sigmoid output of the last convolution, values in ``[0, 1]``.
    """
    features = self.l_relu(self.proj_layer(z))
    hidden_stages = (
      (self.conv1, self.norm1),
      (self.conv2, self.norm2),
      (self.conv3, self.norm3),
    )
    for conv, norm in hidden_stages:
      features = self.l_relu(norm(conv(features)))
    return self.sigmoid(self.conv4(features))
