# Import Required Libraries & Load Data

In [None]:
# Mount Google Drive so the image folder under /content/drive/MyDrive can be read.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Hyperparameters

In [None]:

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data / image geometry
IMAGE_SIZE = 64      # intended square image side length, in pixels
CHANNELS_IMG = 3     # colour channels per image (RGB)
BATCH_SIZE = 64

# Network sizes
NOISE_DIM = 100      # length of the generator's latent vector
FEATURES_DISC = 64   # base channel width of the critic
FEATURES_GEN = 64    # base channel width of the generator

# WGAN optimisation
LEARNING_RATE = 5e-5     # RMSprop learning rate for both networks
NUM_EPOCHS = 20
CRITIC_ITERATIONS = 5    # intended critic updates per generator update (n_critic)
WEIGHT_CLIP = 0.02       # critic weights are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

In [None]:
# Display the torch.device selected in the hyperparameter cell (cuda if available, else cpu).
device

In [None]:
transforms = transforms.Compose([
    transforms.Resize([64,64]),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for i in range(CHANNELS_IMG)], [0.5 for i in range(CHANNELS_IMG)])
])

In [None]:
IMAGE_PATH = "/content/drive/MyDrive/cc/img_align_celeba"
IMAGE_PATH

In [None]:
dataset = datasets.ImageFolder(IMAGE_PATH, transform = transforms)

In [None]:
def split_indices(n, val_per, seed = 0):
    """Randomly split the indices ``0 .. n-1`` into training and validation sets.

    Args:
        n: total number of samples.
        val_per: fraction of samples, in ``[0, 1]``, assigned to validation;
            ``int(n * val_per)`` indices (rounded down) go to validation.
        seed: seed for the permutation so the split is reproducible.

    Returns:
        ``(train_idx, val_idx)``: two disjoint NumPy integer arrays whose
        union is ``range(n)``.

    Raises:
        ValueError: if ``val_per`` is outside ``[0, 1]``.
    """
    # Local import keeps the function self-contained; the notebook imports
    # numpy only in a later cell.
    import numpy as np

    if not 0 <= val_per <= 1:
        raise ValueError(f"val_per must be in [0, 1], got {val_per!r}")
    n_val = int(n * val_per)
    # A private RandomState produces the same permutation as
    # np.random.seed(seed); np.random.permutation(n), but leaves NumPy's
    # global RNG untouched.
    idx = np.random.RandomState(seed).permutation(n)
    return idx[n_val:], idx[:n_val]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Hold out 40% of the images; the fixed seed makes the split reproducible.
rand_seed = 42
val_per = 0.4

train_indices, val_indices = split_indices(len(dataset), val_per, seed=rand_seed)

print(*map(len, (train_indices, val_indices)))

In [None]:
# Peek at the first 20 (shuffled) indices of each split.
for caption, indices in (("Validation Indices: ", val_indices), ("Training Indices: ", train_indices)):
    print(caption, indices[:20])

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader

In [None]:

# Batches are drawn only from the training indices, in random order.
train_sampler = SubsetRandomSampler(train_indices)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)

In [None]:
# Number of images in the whole dataset, and number of training batches in `loader`.
len(dataset), len(loader)

# Building the Discriminator (Critic) Model

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style critic mapping an image batch to one raw score per image.

    Four 4x4 / stride-2 / padding-1 convolutions each halve height and width
    while the channel count grows ``features_d -> 2x -> 4x -> 8x``. A final
    4x4 convolution with no padding collapses a 4x4 map to a single value, so
    a 64x64 input yields shape ``(N, 1, 1, 1)``. There is no output
    activation: the score is unbounded, which is what the Wasserstein loss
    (a difference of mean scores) uses.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()

        def down(in_ch, out_ch, batch_norm=True):
            # Conv that halves H and W, optional BatchNorm, then LeakyReLU(0.2).
            stage = [nn.Conv2d(in_ch, out_ch, kernel_size = 4, stride = 2, padding = 1)]
            if batch_norm:
                stage.append(nn.BatchNorm2d(out_ch))
            stage.append(nn.LeakyReLU(0.2))
            return stage

        # The first stage sees raw pixels and has no BatchNorm.
        layers = down(channels_img, features_d, batch_norm=False)
        width = features_d
        for _ in range(3):
            layers += down(width, width * 2)
            width *= 2
        layers.append(nn.Conv2d(width, 1, kernel_size = 4, stride = 2, padding = 0))

        # One flat Sequential named `disc`, so state_dict keys ("disc.0.weight",
        # ...) and module construction order match the hand-listed layout.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        return self.disc(x)

# Define Generator Neural Network

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise ``(N, z_dim, 1, 1)`` to images.

    The first transposed convolution (4x4, stride 1, no padding) expands a
    1x1 latent to a 4x4 map with ``features_g * 16`` channels. Each following
    4x4 / stride-2 / padding-1 transposed convolution doubles H and W while
    halving channels. The last one emits ``channels_img`` channels at 64x64.
    Tanh bounds the output to [-1, 1].
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()

        def up(in_ch, out_ch, stride, padding):
            # Transposed conv, then BatchNorm, then ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size = 4, stride = stride, padding = padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        layers = up(z_dim, features_g * 16, stride=1, padding=0)   # 1x1 -> 4x4
        width = features_g * 16
        for _ in range(3):                                          # 4 -> 8 -> 16 -> 32
            layers += up(width, width // 2, stride=2, padding=1)
            width //= 2
        layers += [
            nn.ConvTranspose2d(width, channels_img, kernel_size = 4, stride = 2, padding = 1),  # 32 -> 64
            nn.Tanh(),
        ]

        # Flat Sequential named `net`: same state_dict keys and construction order.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [None]:
def initialize_weights(model, std=0.02):
    """Apply DCGAN-style initialisation to every conv / batch-norm layer in *model*.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, std).
    * ``BatchNorm2d`` scales are drawn from N(1, std) and shifts set to 0, so
      each normalisation layer starts close to the identity. Drawing the scale
      from N(0, std) would multiply every normalised activation by ~0.
    * Convolution biases keep PyTorch's default initialisation.

    Args:
        model: any ``nn.Module``; sub-modules are visited via ``model.modules()``.
        std: standard deviation of the normal draws (0.02 in the DCGAN recipe).
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, std)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # `weight` is None when the layer was built with affine=False.
            nn.init.normal_(m.weight, 1.0, std)
            nn.init.zeros_(m.bias)

In [None]:
def test():
    """Smoke-test critic and generator output shapes on random 64x64 RGB data.

    Both networks are built with a small base width of 8. Raises
    AssertionError with a descriptive message if either shape is wrong;
    prints 'Success' otherwise.
    """
    batch, channels, height, width = 8, 3, 64, 64
    z_dim = 100

    images = torch.randn((batch, channels, height, width))
    critic = Discriminator(channels, 8)
    expected_scores = (batch, 1, 1, 1)
    assert critic(images).shape == expected_scores, "Discriminator test failed"

    generator = Generator(z_dim, channels, 8)
    noise = torch.randn((batch, z_dim, 1, 1))
    expected_images = (batch, channels, height, width)
    assert generator(noise).shape == expected_images, "Generator test failed"

    print('Success')

In [None]:
# Run the shape smoke test for both networks; prints 'Success' when both asserts pass.
test()

# Generator and Discriminator (Critic) Initialization

In [None]:
# NOTE(review): `dataloader` (whole dataset, shuffled) is not used by the
# training loop, which iterates `loader` (training indices only).
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

gen = Generator(z_dim=NOISE_DIM, channels_img=CHANNELS_IMG, features_g=FEATURES_GEN).to(device)
disc = Discriminator(channels_img=CHANNELS_IMG, features_d=FEATURES_DISC).to(device)

# Replace PyTorch's default parameter initialisation for both networks.
for net in (gen, disc):
    initialize_weights(net)

In [None]:
# RMSprop (default momentum 0) with the same learning rate for both networks.
opt_gen, opt_disc = (optim.RMSprop(net.parameters(), lr=LEARNING_RATE) for net in (gen, disc))

In [None]:
def reset_grad():
    """Zero the gradients held by both the critic and generator optimizers."""
    for optimizer in (opt_disc, opt_gen):
        optimizer.zero_grad()

# Training the Discriminator (Critic)

In [None]:
def train_discriminator(images):
    """Run CRITIC_ITERATIONS WGAN critic updates against one batch of real images.

    Each iteration:
      1. scores the real batch,
      2. generates a fake batch of the same size and scores it,
      3. minimises ``mean(D(fake)) - mean(D(real))`` with ``opt_disc``,
      4. clamps every critic parameter to ``[-WEIGHT_CLIP, WEIGHT_CLIP]``
         (weight clipping, the WGAN Lipschitz constraint).

    NOTE: the same real batch is reused for every iteration. The WGAN paper's
    algorithm samples a fresh real batch per critic iteration.

    Args:
        images: batch of real images already on ``device``.

    Returns:
        ``(loss_disc, real_score, fake_score)`` from the final iteration:
        the scalar critic loss, and the per-image critic scores (flattened)
        for the real and fake batches.

    Raises:
        ValueError: if CRITIC_ITERATIONS is less than 1.
    """
    if CRITIC_ITERATIONS < 1:
        raise ValueError(f"CRITIC_ITERATIONS must be >= 1, got {CRITIC_ITERATIONS}")

    # Match the fake batch to the real one (the last loader batch may be short).
    batch_size = images.size(0)
    for _ in range(CRITIC_ITERATIONS):
        disc_real = disc(images).reshape(-1)

        z = torch.randn(batch_size, NOISE_DIM, 1, 1, device=device)
        # Only the critic is updated here, so no graph is needed through the
        # generator (its BatchNorm running stats still update in train mode).
        with torch.no_grad():
            fake_images = gen(z)
        disc_fake = disc(fake_images).reshape(-1)

        loss_disc = -(torch.mean(disc_real) - torch.mean(disc_fake))

        reset_grad()
        loss_disc.backward()
        opt_disc.step()

        for p in disc.parameters():
            p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

    # Return only after all critic iterations have run.
    return loss_disc, disc_real, disc_fake

# Training the Generator

In [None]:
def train_generator():
    """Run one generator update against the current critic.

    Samples ``BATCH_SIZE`` latent vectors, generates images, and minimises
    ``-mean(critic(fake))``, i.e. pushes the critic's score for generated
    images up. Only ``opt_gen`` is stepped.

    Returns:
        ``(g_loss, fake_images)``: the scalar generator loss and the generated batch.
    """
    z = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1, device=device)
    fake_images = gen(z)
    output = disc(fake_images).reshape(-1)
    g_loss = -torch.mean(output)
    # Clear both optimizers so stale gradients on either network don't leak
    # into this step. The critic gradients produced by this backward are
    # cleared again by the next critic step's reset_grad().
    reset_grad()
    g_loss.backward()
    opt_gen.step()
    return g_loss, fake_images

In [None]:
import os

# Output folder for the real/fake sample grids saved later in the notebook.
sample_dir = 'samples'
# exist_ok=True: one race-free call that also succeeds when the cell is re-run.
os.makedirs(sample_dir, exist_ok=True)

In [None]:
def show_img(img, label):
    """Print *label*, then draw *img* with ``plt.imshow``.

    *img* is a 3-D tensor that is permuted ``(0, 1, 2) -> (1, 2, 0)``, i.e.
    channels-first to channels-last, the layout matplotlib expects. No
    denormalisation is applied; matplotlib clips float RGB data to [0, 1]
    and ignores ``cmap`` for RGB(A) input.
    """
    print('Label: ', label)
    channels_last = img.permute(1, 2, 0)
    plt.imshow(channels_last, cmap='gray')

In [None]:
def denorm(x):
    """Map a tensor from the normalised range [-1, 1] back to pixel range [0, 1].

    This inverts ``Normalize(mean=0.5, std=0.5)``. The result is clamped so
    values outside [-1, 1] still land in [0, 1].
    """
    shifted = x + 1
    return (shifted / 2).clamp(0, 1)

# View a Batch of Real Images

In [None]:
from IPython.display import Image
from torchvision.utils import save_image

# Save one training batch, denormalised back to [0, 1], as a 16-column grid.
# Batches from `loader` are already (N, C, H, W), so no reshape is needed
# (a hard-coded 3x64x64 reshape would break if IMAGE_SIZE/CHANNELS_IMG change).
images, _ = next(iter(loader))
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'), nrow=16)

Image(os.path.join(sample_dir, 'real_images.png'))

# Saving Generated Sample Images

In [None]:
# Fixed latent vectors, reused for every snapshot, so successive grids show how
# the generator's output for the same inputs changes during training.
sample_vectors = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)

def save_fake_images(index):
    """Generate images from ``sample_vectors`` and save them as a 16-column PNG grid.

    The file is written to ``sample_dir`` as ``fake_images-XXXX.png``, where
    ``XXXX`` is *index* zero-padded to four digits.

    NOTE: ``gen`` is used in whatever mode it is in. In train mode its
    BatchNorm layers use batch statistics and update their running stats.
    """
    # Sampling only: don't build an autograd graph through the generator.
    with torch.no_grad():
        fake_images = gen(sample_vectors)
    fake_fname = 'fake_images-{0:0=4d}.png'.format(index)
    print('Saving', fake_fname)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=16)
save_fake_images(0)
Image(os.path.join(sample_dir, 'fake_images-0000.png'))

# Training the Model

In [None]:
%%time


total_step = len(loader)
d_losses, g_losses, real_scores, fake_scores = [], [], [], []

for epoch in range(NUM_EPOCHS):
    for i, (images, _) in enumerate(loader):
        images = images.to(device)
        d_loss, real_score, fake_score = train_discriminator(images)
        g_loss, fake_images = train_generator()
        if (i+1) / 100 == 0:
            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())
            real_scores.append(real_score.mean().item())
            fake_scores.append(fake_score.mean().item())
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    save_fake_images(epoch+1)

# Save Model

In [None]:

# Persist only the learned parameters (state_dict) of each network, generator first.
for net, ckpt_path in ((gen, 'G.ckpt'), (disc, 'D.ckpt')):
    torch.save(net.state_dict(), ckpt_path)

# Inspect Generated Samples Across Epochs

In [None]:
# Grids generated from the fixed `sample_vectors`: index 0000 was saved before
# training started, and index k right after epoch k finished.
Image('./samples/fake_images-0000.png')

In [None]:
Image('./samples/fake_images-0002.png')  # after epoch 2

In [None]:
Image('./samples/fake_images-0004.png')  # after epoch 4

In [None]:
Image('./samples/fake_images-0006.png')  # after epoch 6

In [None]:
Image('./samples/fake_images-0008.png')  # after epoch 8

In [None]:
Image('./samples/fake_images-0010.png')  # after epoch 10

In [None]:
Image('./samples/fake_images-0011.png')  # after epoch 11

In [None]:
Image('./samples/fake_images-0012.png')  # after epoch 12

In [None]:
Image('./samples/fake_images-0013.png')  # after epoch 13

In [None]:
Image('./samples/fake_images-0014.png')  # after epoch 14

In [None]:
Image('./samples/fake_images-0018.png')  # after epoch 18

In [None]:
Image('./samples/fake_images-0020.png')  # after the final epoch (NUM_EPOCHS = 20)