<a href="https://colab.research.google.com/github/annasajkh/DC-GAN/blob/main/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.optim import Adam
import numpy as np
from functools import reduce
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
class Generator(nn.Module):
  """DCGAN generator: maps a latent vector to a ``channels`` x 64 x 64 image.

  Args:
    latent_dim: length of the latent vector Z; the input shape is
      (N, latent_dim, 1, 1).
    feature_maps: base width of the network; the hidden layers use
      8x, 4x, 2x and 1x this many channels.
    channels: number of channels of the produced image.

  The last layer is a Tanh, so every output value lies in [-1, 1].
  The defaults (100, 64, 3) reproduce the original fixed architecture exactly,
  including the ``network.<i>`` parameter names.
  """

  def __init__(self, latent_dim=100, feature_maps=64, channels=3):
    super().__init__()
    fm = feature_maps

    self.network = nn.Sequential(
      # Block 1: Z (latent_dim x 1 x 1) -> (fm * 8) x 4 x 4
      nn.ConvTranspose2d(latent_dim, fm * 8, 4, 1, 0, bias=False),
      nn.BatchNorm2d(fm * 8),
      nn.ReLU(True),

      # Block 2: (fm * 8) x 4 x 4 -> (fm * 4) x 8 x 8
      nn.ConvTranspose2d(fm * 8, fm * 4, 4, 2, 1, bias=False),
      nn.BatchNorm2d(fm * 4),
      nn.ReLU(True),

      # Block 3: (fm * 4) x 8 x 8 -> (fm * 2) x 16 x 16
      nn.ConvTranspose2d(fm * 4, fm * 2, 4, 2, 1, bias=False),
      nn.BatchNorm2d(fm * 2),
      nn.ReLU(True),

      # Block 4: (fm * 2) x 16 x 16 -> fm x 32 x 32
      nn.ConvTranspose2d(fm * 2, fm, 4, 2, 1, bias=False),
      nn.BatchNorm2d(fm),
      nn.ReLU(True),

      # Block 5: fm x 32 x 32 -> channels x 64 x 64, squashed to [-1, 1]
      nn.ConvTranspose2d(fm, channels, 4, 2, 1, bias=False),
      nn.Tanh(),
    )

  def forward(self, x):
    """Return images of shape (N, channels, 64, 64) for latents x of shape (N, latent_dim, 1, 1)."""
    return self.network(x)

In [None]:
class Discriminator(nn.Module):
  """DCGAN discriminator: scores a ``channels`` x 64 x 64 image as real (1) or fake (0).

  Args:
    feature_maps: base width of the network; the hidden layers use
      1x, 2x, 4x and 8x this many channels.
    channels: number of channels of the input image.

  The final Sigmoid makes every score a probability in [0, 1], which pairs with
  ``nn.BCELoss``. For 64 x 64 inputs the output has shape (N, 1); other
  spatial sizes change the flattened length.
  The defaults (64, 3) reproduce the original fixed architecture exactly,
  including the ``network.<i>`` parameter names.
  """

  def __init__(self, feature_maps=64, channels=3):
    super().__init__()
    fm = feature_maps

    self.network = nn.Sequential(
        # Block 1: channels x 64 x 64 -> fm x 32 x 32 (no BatchNorm on the input layer)
        nn.Conv2d(channels, fm, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),

        # Block 2: fm x 32 x 32 -> (fm * 2) x 16 x 16
        nn.Conv2d(fm, fm * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(fm * 2),
        nn.LeakyReLU(0.2, inplace=True),

        # Block 3: (fm * 2) x 16 x 16 -> (fm * 4) x 8 x 8
        nn.Conv2d(fm * 2, fm * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(fm * 4),
        nn.LeakyReLU(0.2, inplace=True),

        # Block 4: (fm * 4) x 8 x 8 -> (fm * 8) x 4 x 4
        nn.Conv2d(fm * 4, fm * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(fm * 8),
        nn.LeakyReLU(0.2, inplace=True),

        # Block 5: (fm * 8) x 4 x 4 -> 1 x 1 x 1 probability, flattened to (N, 1)
        nn.Conv2d(fm * 8, 1, 4, 1, 0, bias=False),
        nn.Sigmoid(),
        nn.Flatten(),
    )

  def forward(self, x):
    """Return real/fake probabilities of shape (N, 1) for images x of shape (N, channels, 64, 64)."""
    return self.network(x)

def weights_init(m):
    """Apply the DCGAN weight initialisation to a single module ``m``.

    Intended for ``model.apply(weights_init)``, which calls it on every
    submodule.

    - Modules whose class name contains 'Conv' get weights ~ N(0, 0.02).
      Their bias, if any, is left as-is.
    - Modules whose class name contains 'BatchNorm' get weights ~ N(1, 0.02)
      and a zero bias.
    - All other modules are untouched.

    A parameter that is absent (``None``), e.g. on ``BatchNorm2d(affine=False)``,
    is skipped instead of raising AttributeError.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    # nn.init.* already run under torch.no_grad(), so the parameters can be
    # passed directly rather than through .data.
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
generator = Generator()
discriminator = Discriminator()

# DCGAN initialisation: N(0, 0.02) conv weights, N(1, 0.02) BatchNorm scales.
generator.apply(weights_init)
discriminator.apply(weights_init)

# Move the models to the GPU *before* building their optimizers, as the
# torch.optim documentation requires, so each optimizer is constructed over
# the parameters the model will actually train.
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()

# Adam with lr=2e-4, beta1=0.5: the settings used for DCGAN training.
generator_optimizer = Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# BCELoss has no parameters or buffers, so it needs no device move.
loss_function = nn.BCELoss()

In [None]:
import glob
from PIL import Image
import matplotlib.pyplot as plt

IMAGE_SIZE = 64

real_data = []

# sorted(): glob order is filesystem-dependent; sorting makes the dataset
# order reproducible between runs.
for path in tqdm(sorted(glob.glob("drive/MyDrive/dataset/niko/*"))):
  # `with` closes the file handle. convert("RGB") forces exactly 3 channels
  # even for grayscale, palette or RGBA files.
  with Image.open(path) as src:
    img = src.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
  # uint8 [0, 255] -> float32 [-1, 1], the same range as the generator's Tanh output.
  pixels = np.asarray(img, dtype=np.float32) / 255 * 2 - 1
  # PIL/NumPy images are (H, W, C) but PyTorch convolutions expect (C, H, W).
  # A reshape would scramble the pixels, so permute the axes instead.
  real_data.append(pixels.transpose(2, 0, 1))

if real_data:
  real_data = torch.from_numpy(np.stack(real_data))
else:
  # Keep the (N, C, H, W) layout even when nothing was found, and say so.
  print("Warning: no images found in drive/MyDrive/dataset/niko/")
  real_data = torch.empty(0, 3, IMAGE_SIZE, IMAGE_SIZE)

if torch.cuda.is_available():
  real_data = real_data.cuda()

print(real_data.shape)

100%|██████████| 1194/1194 [04:09<00:00,  4.79it/s]

torch.Size([1194, 3, 64, 64])





In [None]:
# Keep only the first 1120 images (presumably chosen as a multiple of the
# training batch size of 16) and display the resulting shape.
NUM_TRAIN_IMAGES = 1120
real_data = real_data[:NUM_TRAIN_IMAGES]
real_data.size()

torch.Size([1120, 3, 64, 64])

In [None]:
# Cast to float32, the dtype of the models' parameters.
real_data = real_data.to(dtype=torch.float32)

In [None]:
def efficient_zero_grad(model):
  """Reset every parameter gradient of ``model`` to ``None`` instead of zero-filling it.

  Setting ``.grad`` to ``None`` avoids writing zeros into each gradient
  tensor. PyTorch documents this as generally lowering memory use and modestly
  improving speed. The built-in ``Module.zero_grad(set_to_none=True)`` does
  exactly this, so delegate to it rather than looping by hand.
  Source: https://betterprogramming.pub/how-to-make-your-pytorch-code-run-faster-93079f3c1f7b
  """
  model.zero_grad(set_to_none=True)

In [None]:
batch_size = 16

epoch = 200

# Noise and BCE targets are created on the same device as the models, so the
# loss never mixes CPU and CUDA tensors.
device = next(generator.parameters()).device

# One fixed latent vector for the periodic snapshots, so successive
# "<epoch>.png" files show the same sample evolving over training.
fixed_noise = torch.randn(1, 100, 1, 1, device=device)


def _train_step(real):
  """Run one discriminator update and one generator update on the image batch ``real``.

  Returns (discriminator loss on real + fake, generator loss) as Python floats.
  """
  noise = torch.randn(real.shape[0], 100, 1, 1, device=device)
  fake = generator(noise)

  # Discriminator: accumulate gradients from the real and the fake batch, then
  # take a single optimizer step on their sum.
  efficient_zero_grad(discriminator)
  real_prediction = discriminator(real)
  real_loss = loss_function(real_prediction, torch.ones_like(real_prediction))
  real_loss.backward()

  # detach(): the discriminator's loss must not propagate into the generator.
  fake_prediction = discriminator(fake.detach())
  fake_loss = loss_function(fake_prediction, torch.zeros_like(fake_prediction))
  fake_loss.backward()
  discriminator_optimizer.step()

  # Generator: use the NON-detached fake so gradients reach the generator's
  # weights; it is rewarded when the discriminator labels its output as real.
  # (This backward also fills discriminator grads; they are cleared by the
  # efficient_zero_grad(discriminator) call at the start of the next step.)
  efficient_zero_grad(generator)
  discriminator_fake_out = discriminator(fake)
  generator_error = loss_function(discriminator_fake_out, torch.ones_like(discriminator_fake_out))
  generator_error.backward()
  generator_optimizer.step()

  return (real_loss + fake_loss).item(), generator_error.item()


def _save_sample(epoch_index):
  """Render ``fixed_noise`` with the generator and save it as '<epoch_index>.png' (200x200)."""
  generator.eval()
  with torch.no_grad():
    sample = generator(fixed_noise)[0]  # (3, 64, 64), Tanh output in [-1, 1]
  generator.train()
  # CHW -> HWC for PIL, then map [-1, 1] -> [0, 255] (truncating, like astype(uint8)).
  pixels = ((sample.permute(1, 2, 0).clamp(-1, 1) + 1) / 2 * 255).to(torch.uint8)
  Image.fromarray(pixels.cpu().numpy()).resize((200, 200)).save(f"{epoch_index}.png")


for e in range(epoch):
  print(f"Epoch: {e}")

  generator.train()
  # New sample order every epoch. Every full batch is used; a trailing
  # partial batch (if any) is dropped.
  order = torch.randperm(len(real_data), device=real_data.device)
  progress = tqdm(range(len(real_data) // batch_size))
  for i in progress:
    batch_indices = order[i * batch_size:(i + 1) * batch_size]
    real = real_data[batch_indices].to(device)
    d_loss, g_loss = _train_step(real)
    progress.set_postfix(d_loss=f"{d_loss:.3f}", g_loss=f"{g_loss:.3f}")

  if e % 10 == 0:
    _save_sample(e)


