## Import Libraries

In [9]:
import os
import torch
import torch.nn as nn
import ConditionalGAN as cGAN
from torchvision import datasets, transforms
from tqdm import tqdm
from torchvision.utils import save_image
from PIL import Image, ImageFont, ImageDraw

In [23]:
# --- Run configuration ---------------------------------------------------
# Training schedule
num_epoch = 100      # number of passes over the training set
batch_size = 100     # samples per mini-batch

# Optimizer settings (consumed where the optimizers are built)
lr = 0.1
decay = 1.00004

# Directory where the generated sample grids are written
dir_name = "CGAN_results"

In [24]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

## Load Dataset

In [25]:
# Fetch the MNIST train / test splits into ./data/ (downloaded on first use).
# Both splits share the same image-to-tensor transform.
mnist_root = './data/'
to_tensor = transforms.ToTensor()

train_data = datasets.MNIST(mnist_root, train=True, transform=to_tensor, download=True)
test_data = datasets.MNIST(mnist_root, train=False, transform=to_tensor, download=True)

# Report the size of each split.
print('number of training data : ', len(train_data))
print('number of test data : ', len(test_data))

number of training data :  60000
number of test data :  10000


In [26]:
# Wrap both splits in shuffled mini-batch loaders that share the same settings.
loader_options = dict(batch_size=batch_size, shuffle=True)
train_loader = torch.utils.data.DataLoader(train_data, **loader_options)
test_loader = torch.utils.data.DataLoader(test_data, **loader_options)

## Model, Loss, and Optimizer Setup

In [27]:
# Instantiate the two networks, then place each one on the selected device.
generator = cGAN.Generator()
generator = generator.to(device)

discriminator = cGAN.Discriminator()
discriminator = discriminator.to(device)

In [28]:
# Binary cross-entropy between the discriminator's output and the real/fake target.
criterion = nn.BCELoss().to(device)

# Each optimizer must own the parameters of the network it updates. The
# discriminator optimizer was previously built over generator.parameters(),
# so DisOptimizer.step() never touched the discriminator's weights.
# NOTE(review): `decay` is passed as SGD's L2 `weight_decay` coefficient. If it was
# meant to be a learning-rate decay factor, a torch.optim.lr_scheduler should be
# used instead -- confirm the intent.
DisOptimizer = torch.optim.SGD(discriminator.parameters(), lr=lr, momentum=0.5, weight_decay=decay)
GenOptimizer = torch.optim.SGD(generator.parameters(), lr=lr, momentum=0.5, weight_decay=decay)

## Training

In [29]:
# Train the conditional GAN. Every step updates the generator first and then the
# discriminator on the same mini-batch of (image, label) pairs. After each epoch
# a grid of the last generated batch is written to `dir_name`, captioned with
# the first ten conditioning labels.

# save_image does not create missing directories, so make sure the target exists.
os.makedirs(dir_name, exist_ok=True)

# Load the caption font once instead of once per epoch. The font file is not
# available on every system; when it is missing the grids are still saved,
# just without the caption, rather than aborting the training run.
try:
    font = ImageFont.truetype("arial.ttf", 17)
except OSError:
    font = None
    print("Font 'arial.ttf' not found; sample grids will be saved without label captions.")

for epoch in range(num_epoch):
    bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(bar):
        images = images.to(device)
        labels = labels.to(device)

        # Size everything from the batch actually received, so a final batch
        # shorter than `batch_size` does not break the reshape or the targets.
        cur_batch = images.size(0)

        # Targets for the BCE loss: 1 = "real", 0 = "fake"
        real_label = torch.ones(cur_batch, 1, dtype=torch.float32, device=device)
        fake_label = torch.zeros(cur_batch, 1, dtype=torch.float32, device=device)

        # Flatten each MNIST image into a vector
        real_images = images.reshape(cur_batch, -1)

        # +---------------------+
        # |   train Generator   |
        # +---------------------+

        # Initialize gradient
        GenOptimizer.zero_grad()
        DisOptimizer.zero_grad()

        # Create Noise Vector z
        z = torch.randn(cur_batch, 100, device=device)

        # Generate fake images conditioned on the batch labels
        fake_images = generator(z, labels)

        # Generator loss: the fakes are scored against the "real" target, so the
        # loss is low when the discriminator is fooled.
        genLoss = criterion(discriminator(fake_images, labels), real_label)

        genLoss.backward()
        GenOptimizer.step()

        # +---------------------+
        # | train Discriminator |
        # +---------------------+

        # Initialize gradient (also clears what the generator step left behind)
        GenOptimizer.zero_grad()
        DisOptimizer.zero_grad()

        # Fresh fakes for the discriminator step. They are produced without
        # autograd tracking so the discriminator loss does not backpropagate
        # into the generator.
        z = torch.randn(cur_batch, 100, device=device)
        with torch.no_grad():
            fake_images = generator(z, labels)

        fake_loss = criterion(discriminator(fake_images, labels), fake_label)
        real_loss = criterion(discriminator(real_images, labels), real_label)
        disLoss = (fake_loss + real_loss) / 2

        disLoss.backward()
        DisOptimizer.step()

        if (i + 1) % 150 == 0:
            print("Epoch [ {}/{} ]  Step [ {}/{} ]  d_loss : {:.5f}  g_loss : {:.5f}".format(epoch + 1, num_epoch, i+1, len(train_loader), disLoss.item(), genLoss.item()))

    # Mean discriminator output on the last batch of the epoch: on the fake
    # images (d_performance) and on the real images (g_performance). Only the
    # last values are reported, so they are evaluated once here, without
    # building an autograd graph.
    with torch.no_grad():
        d_performance = discriminator(fake_images, labels).mean().item()
        g_performance = discriminator(real_images, labels).mean().item()

    # print discriminator & generator's performance
    print(" Epoch {}'s discriminator performance : {:.2f}  generator performance : {:.2f}".format(epoch + 1, d_performance, g_performance))

    # Save fake images in each epoch (last batch, reshaped back to 1x28x28)
    samples = fake_images.reshape(-1, 1, 28, 28)
    sample_path = os.path.join(dir_name, 'CGAN_fake_samples{}.png'.format(epoch + 1))
    save_image(samples, sample_path)

    # Draw the conditioning labels on the saved grid
    if font is not None:
        # `labels` is the label tensor of the last batch, i.e. the labels the
        # saved fake images were generated from.
        first_labels = [str(l) for l in labels[:10].tolist()]
        label_text = "Conditional GAN -\n" \
                     "first 10 labels in this image :\n" + ", ".join(first_labels)

        # The context manager releases the image once it has been re-saved.
        with Image.open(sample_path) as fake_sample_image:
            image_edit = ImageDraw.Draw(fake_sample_image)
            image_edit.multiline_text(xy=(15, 300),
                                      text=label_text,
                                      fill=(0, 255, 255),
                                      font=font,
                                      stroke_width=4,
                                      stroke_fill=(0, 0, 0))
            fake_sample_image.save(sample_path)

  5%|▍         | 29/600 [00:05<01:42,  5.57it/s]


KeyboardInterrupt: 