In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision


In [10]:
class Discriminator(nn.Module):
  """DCGAN discriminator: strided-conv feature extractor plus a linear real/fake head.

  Every stage is ``Conv2d(kernel=4, stride=2, padding=1)``, which halves the
  spatial size (rounding down). After ``len(hidden_dim)`` stages a
  ``(B, img_ch, img_dim, img_dim)`` input therefore becomes a
  ``(B, hidden_dim[-1], s, s)`` map with ``s = img_dim // 2**len(hidden_dim)``.
  That map is flattened into a single sigmoid output per image.

  Args:
    img_dim: side length of the (square) input images.
    img_ch: number of input channels.
    hidden_dim: channel width of each conv stage, in order (narrowest first
      in the default). Any non-empty sequence is accepted.

  Raises:
    ValueError: if ``hidden_dim`` is empty, or ``img_dim`` is too small to be
      halved ``len(hidden_dim)`` times.
  """
  def __init__(self, img_dim=64, img_ch=3, hidden_dim=(128, 256, 512, 1024)):
    super().__init__()
    hidden_dim = list(hidden_dim)
    if not hidden_dim:
      raise ValueError("hidden_dim must contain at least one channel width")
    final_size = img_dim // 2 ** len(hidden_dim)
    if final_size < 1:
      raise ValueError(
        f"img_dim={img_dim} is too small for {len(hidden_dim)} stride-2 stages"
      )

    # The first stage has no BatchNorm. Each later stage is Conv -> LeakyReLU -> BN.
    # For 4 stages this yields exactly the original module order, so existing
    # state_dict keys (net.4.*, net.7.*, net.10.*) still line up.
    # NOTE(review): common DCGAN reference implementations apply BN before the
    # activation. The existing order is kept for checkpoint compatibility.
    layers = [
      nn.Conv2d(img_ch, hidden_dim[0], 4, 2, 1),
      nn.LeakyReLU(0.2, inplace=True),
    ]
    for in_ch, out_ch in zip(hidden_dim, hidden_dim[1:]):
      layers += [
        nn.Conv2d(in_ch, out_ch, 4, 2, 1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.BatchNorm2d(out_ch),
      ]
    self.net = nn.Sequential(*layers)

    self.classifier = nn.Sequential(
      nn.Flatten(),
      nn.Linear(hidden_dim[-1] * final_size ** 2, 1),
      nn.Sigmoid(),
    )

  def forward(self, x):
    """Return the per-image probability of being real, with shape ``(B,)``.

    Only the trailing size-1 dimension is removed. A batch of one therefore
    still yields shape ``(1,)``, whereas a bare ``squeeze()`` would return a
    0-d tensor that no longer matches a ``(1,)`` label tensor.
    """
    x = self.net(x)
    return self.classifier(x).squeeze(-1)

In [6]:
class Generator(nn.Module):
  """DCGAN generator: project a latent vector to a small feature map, then upsample.

  A Linear layer maps the latent vector to a ``(hidden_dim[0], base, base)``
  map with ``base = img_dim // 2**len(hidden_dim)``. Each
  ``ConvTranspose2d(kernel=4, stride=2, padding=1)`` stage then doubles the
  spatial size. The last stage emits ``img_ch`` channels, squashed to
  ``[-1, 1]`` by Tanh.

  Args:
    latent_dim: size of the input noise vector.
    img_dim: side length of the (square) output images. Must be a positive
      multiple of ``2**len(hidden_dim)``.
    img_ch: number of output channels.
    hidden_dim: channel width of each stage (widest first in the default).
      Any non-empty sequence is accepted.

  Raises:
    ValueError: if ``hidden_dim`` is empty, or ``img_dim`` is not a positive
      multiple of ``2**len(hidden_dim)``.
  """
  def __init__(self, latent_dim=100, img_dim=64, img_ch=3, hidden_dim=(1024, 512, 256, 128)):
    super().__init__()
    hidden_dim = list(hidden_dim)
    if not hidden_dim:
      raise ValueError("hidden_dim must contain at least one channel width")
    scale = 2 ** len(hidden_dim)
    if img_dim < scale or img_dim % scale:
      raise ValueError(f"img_dim={img_dim} must be a positive multiple of {scale}")
    base = img_dim // scale

    # nn.Sequential only accepts nn.Module children (a lambda raises TypeError
    # at construction), so the reshape is done with nn.Unflatten.
    self.project = nn.Sequential(
      nn.Linear(latent_dim, hidden_dim[0] * base ** 2),
      nn.Unflatten(1, (hidden_dim[0], base, base)),
      nn.BatchNorm2d(hidden_dim[0]),
      nn.ReLU(inplace=True),
    )

    # Each intermediate stage is (ConvT, BN, ReLU). A final ConvT to img_ch is
    # followed by Tanh. For 4 stages this reproduces the original module order.
    layers = []
    for in_ch, out_ch in zip(hidden_dim, hidden_dim[1:]):
      layers += [
        nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
      ]
    layers += [nn.ConvTranspose2d(hidden_dim[-1], img_ch, 4, 2, 1), nn.Tanh()]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Map latent vectors to images of shape ``(B, img_ch, img_dim, img_dim)``.

    Accepts ``(B, latent_dim)`` or the common ``(B, latent_dim, 1, 1)`` noise
    layout. Everything after the batch dimension is flattened before the
    projection.
    """
    x = self.project(x.flatten(1))
    return self.net(x)

In [11]:
class DCGAN(nn.Module):
  """Container pairing a Generator with a mirrored Discriminator.

  Args:
    latent_dim: size of the noise vector fed to the generator.
    img_dim: side length of the square images both networks handle.
    img_ch: number of image channels.
    hidden_dim: generator channel widths, widest first. The discriminator
      receives the same widths in reverse order (narrowest first), which is
      also the order of its own default.
  """
  def __init__(self, latent_dim=100, img_dim=64, img_ch=3, hidden_dim=(1024, 512, 256, 128)):
    super().__init__()
    hidden_dim = list(hidden_dim)
    # Passing the generator order unreversed would give the discriminator its
    # widest conv at full input resolution instead of mirroring the generator.
    self.disc = Discriminator(img_dim, img_ch, hidden_dim[::-1])
    self.gen = Generator(latent_dim, img_dim, img_ch, hidden_dim)
    