In [89]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [90]:
# CONSTANTS -- every value a reader might want to tune lives in this cell.

# --- Runtime / filesystem ---
CUDA = True              # request the GPU; the setup cell still checks availability
SEED = 1                 # seed passed to the torch random number generators
DATA_PATH = 'data'       # root folder for the MNIST download
OUT_PATH = 'out'         # folder for sample images and checkpoints

# --- Data ---
IMAGE_SIZE = 64          # images are resized to IMAGE_SIZE x IMAGE_SIZE
IMAGE_CHANNELS = 1       # channel count of the images (generator output / discriminator input)
BATCH_SIZE = 128

# --- Model ---
RANDOM_DIM = 100         # number of channels of the latent noise fed to the generator
G_HIDDEN = 64            # base channel width of the generator
D_HIDDEN = 64            # base channel width of the discriminator

# --- Training ---
NUM_EPOCHS = 20
LEARNING_RATE = 2e-4     # Adam learning rate, used for both networks
REAL_LABEL = 1           # BCE target for real images
FAKE_LABEL = 0           # BCE target for generated images

In [91]:
# Download (if needed) and load the raw MNIST training split without any transform.
# The root comes from DATA_PATH so this cell and the dataset cell used for
# training always point at the same folder.
train = dset.MNIST(root=DATA_PATH, train=True, download=True, transform=None)

In [92]:
# Make sure the output folder exists, then clear files left over from a previous run.
# exist_ok=True replaces the separate exists() check.
os.makedirs(OUT_PATH, exist_ok=True)
for f in os.listdir(OUT_PATH):
    stale_path = os.path.join(OUT_PATH, f)
    # os.remove() raises on directories (e.g. a stray .ipynb_checkpoints folder),
    # so only regular files and symlinks are deleted; sub-directories are left alone.
    if os.path.isfile(stale_path) or os.path.islink(stale_path):
        os.remove(stale_path)

# Only use CUDA when it was requested AND a device is actually available.
CUDA = CUDA and torch.cuda.is_available()
print("CUDA Available: ", CUDA, ". Pytorch version: ", torch.__version__)

# Seed the torch RNGs so weight initialisation and noise sampling are repeatable.
print("Seed:", SEED)
torch.manual_seed(SEED)
if CUDA:
    torch.cuda.manual_seed(SEED)

# Device on which the models and all tensors are placed.
device = torch.device("cuda:0" if CUDA else "cpu")


CUDA Available:  True . Pytorch version:  1.13.0
Seed: 1


In [93]:
class Generator(nn.Module):
    """DCGAN-style generator built from five transposed-convolution stages.

    Given a latent tensor of shape (N, RANDOM_DIM, 1, 1), the spatial size
    grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 while the channel count shrinks
    G_HIDDEN*8 -> G_HIDDEN*4 -> G_HIDDEN*2 -> G_HIDDEN -> IMAGE_CHANNELS.
    The final Tanh keeps every output value in [-1, 1].
    """

    @staticmethod
    def _upsample_block(in_channels, out_channels, stride, padding):
        """Return ConvTranspose2d (kernel 4, no bias) -> BatchNorm2d -> in-place ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def __init__(self):
        super().__init__()
        # Stage 1: stride 1, no padding -> (1 - 1) * 1 + 4 = 4, i.e. 1x1 becomes 4x4.
        self.layer1 = self._upsample_block(RANDOM_DIM, G_HIDDEN * 8, stride=1, padding=0)
        # Stages 2-4: kernel 4 / stride 2 / padding 1 doubles the spatial size each time
        # (4x4 -> 8x8 -> 16x16 -> 32x32) while halving the channel count.
        self.layer2 = self._upsample_block(G_HIDDEN * 8, G_HIDDEN * 4, stride=2, padding=1)
        self.layer3 = self._upsample_block(G_HIDDEN * 4, G_HIDDEN * 2, stride=2, padding=1)
        self.layer4 = self._upsample_block(G_HIDDEN * 2, G_HIDDEN, stride=2, padding=1)
        # Output stage: 32x32 -> 64x64 with IMAGE_CHANNELS channels; no batch norm here.
        self.out_layer = nn.Sequential(
            nn.ConvTranspose2d(G_HIDDEN, IMAGE_CHANNELS, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the latent batch ``x`` through all stages and return the generated images."""
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.out_layer)
        for stage in stages:
            x = stage(x)
        return x

In [94]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator built from five convolution stages.

    For a 64x64 input with IMAGE_CHANNELS channels, the four stride-2 stages
    reduce the spatial size 64 -> 32 -> 16 -> 8 -> 4 while the channel count
    grows D_HIDDEN -> D_HIDDEN*2 -> D_HIDDEN*4 -> D_HIDDEN*8. The last 4x4
    convolution collapses the map to a single value per image, and the
    Sigmoid maps it into (0, 1).
    """

    @staticmethod
    def _downsample_block(in_channels, out_channels, batch_norm=True):
        """Return Conv2d (kernel 4, stride 2, padding 1, no bias) -> [BatchNorm2d] -> LeakyReLU(0.2)."""
        modules = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if batch_norm:
            modules.append(nn.BatchNorm2d(out_channels))
        modules.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*modules)

    def __init__(self):
        super().__init__()
        # Stage 1 (64x64 -> 32x32): the only downsampling stage without batch norm.
        self.layer1 = self._downsample_block(IMAGE_CHANNELS, D_HIDDEN, batch_norm=False)
        # Stages 2-4: each halves the spatial size and doubles the channels
        # (32x32 -> 16x16 -> 8x8 -> 4x4).
        self.layer2 = self._downsample_block(D_HIDDEN, D_HIDDEN * 2)
        self.layer3 = self._downsample_block(D_HIDDEN * 2, D_HIDDEN * 4)
        self.layer4 = self._downsample_block(D_HIDDEN * 4, D_HIDDEN * 8)
        # Output stage: kernel 4, stride 1, no padding turns 4x4 into 1x1 with one channel.
        self.out_layer = nn.Sequential(
            nn.Conv2d(D_HIDDEN * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the image batch ``x`` through all stages and return the sigmoid scores."""
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.out_layer)
        for stage in stages:
            x = stage(x)
        return x

In [95]:
# Build both networks on the selected device. The generator is constructed
# first so the seeded weight initialisation order stays the same.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy between the discriminator's sigmoid output and the 0/1 labels.
loss_func = nn.BCELoss()

# Both optimizers share the same Adam hyper-parameters, so they are written once.
adam_settings = {"lr": LEARNING_RATE, "betas": (0.5, 0.999)}
optimizerG = optim.Adam(generator.parameters(), **adam_settings)
optimizerD = optim.Adam(discriminator.parameters(), **adam_settings)

In [96]:
# Preprocessing applied to every image: resize to IMAGE_SIZE x IMAGE_SIZE,
# convert to a tensor, then normalise with mean 0.5 / std 0.5 so the values
# fall in the same [-1, 1] range as the generator's Tanh output.
preprocess = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST dataset (downloaded into DATA_PATH when missing) with the preprocessing attached.
dataset = dset.MNIST(root=DATA_PATH, download=True, transform=preprocess)

# Fail fast if the dataset is falsy (e.g. empty).
assert dataset

# Shuffled mini-batches, loaded by 4 worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

In [98]:
#Train loop

# Fixed latent batch reused for every snapshot, so the saved sample grids show
# how the generator's output for the SAME noise changes during training.
viz_noise = torch.randn(BATCH_SIZE, RANDOM_DIM, 1, 1, device=device)

for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(dataloader):
        # data is (images, digit labels); only the images are used.
        x_real = data[0].to(device)
        n_samples = x_real.size(0)  # the last batch of an epoch may be smaller than BATCH_SIZE

        # Build the BCE targets directly as float tensors. Without an explicit dtype,
        # torch.full infers the dtype from the integer fill value (the result differs
        # between PyTorch versions) and an extra cast is needed.
        real_label = torch.full((n_samples,), REAL_LABEL, dtype=torch.float, device=device)
        fake_label = torch.full((n_samples,), FAKE_LABEL, dtype=torch.float, device=device)

        # Train the discriminator.
        # First on real data: it should output REAL_LABEL.
        discriminator.zero_grad()
        y_real = discriminator(x_real).view(-1)
        loss_D_real = loss_func(y_real, real_label)
        loss_D_real.backward()

        # Then on generated data: it should output FAKE_LABEL.
        # detach() stops this backward pass from reaching the generator.
        z_noise = torch.randn(n_samples, RANDOM_DIM, 1, 1, device=device)
        x_fake = generator(z_noise)
        y_fake = discriminator(x_fake.detach()).view(-1)
        loss_D_fake = loss_func(y_fake, fake_label)
        loss_D_fake.backward()
        # Gradients of the real and fake passes have accumulated; apply them in one step.
        optimizerD.step()

        # Train the generator: re-score the same fakes with the just-updated
        # discriminator and push its output towards REAL_LABEL.
        generator.zero_grad()
        y_fake = discriminator(x_fake).view(-1)
        loss_G = loss_func(y_fake, real_label)
        loss_G.backward()
        optimizerG.step()

        if i % 100 == 0:
            # The losses are already scalars (BCELoss averages over the batch by
            # default), so .item() alone is enough.
            print('Epoch {} [{}/{}] loss_D_real: {:.4f} loss_D_fake:{:.4f} loss_G: {:.4f}'.format(
                epoch, i, len(dataloader),
                loss_D_real.item(),
                loss_D_fake.item(),
                loss_G.item()))
            # Snapshot images and checkpoints. File names are keyed by epoch only, so
            # each snapshot within an epoch overwrites the previous one and the files
            # always hold the most recent state for that epoch.
            vutils.save_image(x_real, '{}/real_samples.png'.format(OUT_PATH), normalize=True)
            with torch.no_grad():
                viz_sample = generator(viz_noise)
            vutils.save_image(viz_sample, '{}/fake_samples_epoch_{:03d}.png'.format(OUT_PATH, epoch), normalize=True)
            torch.save(generator.state_dict(), '{}/generator_epoch_{:03d}.pth'.format(OUT_PATH, epoch))
            torch.save(discriminator.state_dict(), '{}/discriminator_epoch_{:03d}.pth'.format(OUT_PATH, epoch))
                

Epoch 0 [0/469] loss_D_real: 0.3184 loss_D_fake:0.1684 loss_G: 2.4040
Epoch 0 [100/469] loss_D_real: 0.3656 loss_D_fake:0.0713 loss_G: 1.5529
Epoch 0 [200/469] loss_D_real: 0.1442 loss_D_fake:0.6806 loss_G: 1.6267
Epoch 0 [300/469] loss_D_real: 1.0420 loss_D_fake:0.0517 loss_G: 0.3723
Epoch 0 [400/469] loss_D_real: 0.6705 loss_D_fake:0.0648 loss_G: 0.9549
Epoch 1 [0/469] loss_D_real: 0.1472 loss_D_fake:0.1057 loss_G: 2.2515
Epoch 1 [100/469] loss_D_real: 0.3162 loss_D_fake:0.1620 loss_G: 1.8530
Epoch 1 [200/469] loss_D_real: 0.4851 loss_D_fake:0.2231 loss_G: 1.7614
Epoch 1 [300/469] loss_D_real: 0.0448 loss_D_fake:0.5529 loss_G: 4.5356
Epoch 1 [400/469] loss_D_real: 0.2436 loss_D_fake:0.2128 loss_G: 2.1988
Epoch 2 [0/469] loss_D_real: 0.6014 loss_D_fake:0.1712 loss_G: 1.1329
Epoch 2 [100/469] loss_D_real: 0.0798 loss_D_fake:0.3599 loss_G: 3.2430
Epoch 2 [200/469] loss_D_real: 0.0852 loss_D_fake:0.0249 loss_G: 4.0708
Epoch 2 [300/469] loss_D_real: 0.0848 loss_D_fake:0.2175 loss_G: 3.718

KeyboardInterrupt: 