In [None]:
import inspect
import logging
import os
import random
import time
from types import SimpleNamespace


In [None]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
import torchvision.utils as vutils
from torch import cuda

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
from DCGAN import DCGAN
from utils import get_data_loader, generate_images, save_gif
from LeNet import Classifier

In [None]:
# Seed every RNG library imported by this notebook so re-runs are as reproducible as possible.
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)
# numpy is imported as well; seed it too so any numpy-based randomness repeats across runs.
# (Being the last statement, this also keeps the torch Generator repr out of the cell output.)
np.random.seed(manualSeed)

In [5]:
# Hyperparameter settings.
# A bare `object()` has no __dict__, so assigning `args.num_epochs` raised AttributeError.
# SimpleNamespace is a plain attribute container that supports the same `args.x` access.
args = SimpleNamespace(
    num_epochs=10,            # passed to dcgan.train
    ngpu=1,                   # 0 forces the CPU device (see the device cell)
    ndf=128,                  # presumably discriminator feature width — TODO confirm in DCGAN
    ngf=128,                  # presumably generator feature width — TODO confirm in DCGAN
    nz=100,                   # latent length; noise tensors are (N, nz, 1, 1)
    lr=0.0002,                # optimizer learning rate passed to DCGAN
    beta=0.5,                 # passed to DCGAN as beta1
    nc=1,                     # presumably image channels (MNIST is grayscale) — TODO confirm
    batch_size=64,            # DataLoader batch size
    image_size=64,            # images are resized/center-cropped to this size
    num_test_samples=64,      # number of fixed latent vectors
    output_path="./results/", # directory for the generated GIF
    fps=5,                    # intended frame rate of the output GIF
    use_fixed=True,           # NOTE(review): not read anywhere in this notebook — confirm purpose
    plot=True,                # passed to dcgan.train
)

In [None]:
# Gather MNIST Dataset
# Each image is resized and center-cropped to args.image_size, converted to a tensor and then
# normalized with mean 0.1307 / std 0.3081 (presumably the MNIST training-set statistics).
# NOTE(review): DCGAN generators commonly end in tanh (outputs in [-1, 1]); if this one does,
# Normalize(mean=(0.5,), std=(0.5,)) would put real images in the same range — TODO confirm.
mnist_steps = [
    transforms.Resize(args.image_size),
    transforms.CenterCrop(args.image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307, ), std=(0.3081, )),
]
transform = transforms.Compose(mnist_steps)
dataset = dset.MNIST(root='./mnist_data/', transform=transform, download=True)

# Create the dataloader: shuffled mini-batches of args.batch_size images.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
)

In [None]:
 # Device configuration
# Use the first GPU only when CUDA is available and GPUs were requested (args.ngpu > 0).
device = torch.device('cuda:0' if (torch.cuda.is_available() and args.ngpu > 0) else 'cpu')
# cuda.get_device_name raises when there is no CUDA device, so only query it on the GPU path.
device_name = cuda.get_device_name(0) if device.type == "cuda" else "CPU"
print("Using", device_name)

In [None]:
# Plot some training images as a quick sanity check of the input pipeline.
# Only the image tensor (index 0 of the batch) is shown; labels are ignored.
# These lines were previously indented under a top-level comment, which is an IndentationError.
real_batch = next(iter(dataloader))
# The batch is already on the CPU, so skip the device round-trip; normalize=True asks
# make_grid to rescale pixel values for display.
training_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
fig, ax = plt.subplots(figsize=(8, 8))
ax.axis("off")
ax.set_title("Training Images")
# Reorder the channels-first grid to channels-last for imshow.
ax.imshow(np.transpose(training_grid, (1, 2, 0)))
plt.show()

In [None]:
# Build the DCGAN wrapper from the hyperparameters collected in `args`.
# Keyword meanings (nc/nz/ngf/ndf) follow the usual DCGAN naming — TODO confirm in the DCGAN module.
dcgan = DCGAN(
    ngpu=args.ngpu,
    device=device,
    lr=args.lr,
    beta1=args.beta,
    nc=args.nc,
    nz=args.nz,
    ngf=args.ngf,
    ndf=args.ndf,
)

In [None]:
# Fixed latent batch of shape (num_test_samples, nz, 1, 1), allocated directly on `device`.
# NOTE(review): neither `fixed_noise` nor `num_batches` is passed to dcgan.train in this notebook
# (and args.use_fixed is never read) — confirm whether DCGAN creates its own fixed noise.
noise_shape = (args.num_test_samples, args.nz, 1, 1)
fixed_noise = torch.randn(noise_shape, device=device)
num_batches = len(dataloader)

In [None]:
# Run the full training loop. Later cells treat the return value as a list of channels-first
# generated-image grids (the last one is shown as "Fake Images") — TODO confirm in DCGAN.train.
img_list = dcgan.train(
    num_epochs=args.num_epochs,
    dataloader=dataloader,
    plot=args.plot,
)

In [None]:
# Animate the generator snapshots returned by training and save them as a GIF.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
# One single-artist frame per snapshot; each is reordered from channels-first to channels-last
# for imshow (assumes channels-first snapshots, as the transpose implies — TODO confirm).
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# Honor the configured frame rate instead of a hard-coded 30 fps (args.fps was otherwise unused).
writergif = animation.PillowWriter(fps=args.fps)
# ani.save does not create missing directories.
os.makedirs(args.output_path, exist_ok=True)
ani.save(os.path.join(args.output_path, "fake_dcgan.gif"), writer=writergif)

In [None]:
# Save the whole DCGAN object (torch.save pickles the object itself, not only a state_dict).
model_path = "./models/"
filename = "dcgan_Q1.pt"
# torch.save fails if the target directory does not exist yet.
os.makedirs(model_path, exist_ok=True)
torch.save(dcgan, os.path.join(model_path, filename))

In [None]:
# Load the DCGAN object written by the save cell.
model_path = "./models/"
filename = "dcgan_Q1.pt"
# Map tensors onto the current device so a GPU-saved checkpoint can still be opened on CPU.
load_kwargs = {"map_location": device}
# Recent PyTorch releases default torch.load to weights_only=True, which refuses to unpickle a
# full object like this one. Opt out explicitly where the argument exists (older releases lack it).
# Only do this for checkpoints you produced yourself: unpickling can execute arbitrary code.
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
dcgan = torch.load(os.path.join(model_path, filename), **load_kwargs)

In [None]:
# Side-by-side comparison: a batch of real images vs. the last generated snapshot.
# NOTE(review): `img_list` only exists if the training cell ran in this kernel session; the
# load-model cell restores `dcgan` but not `img_list` — regenerate it after a reload (TODO).
real_batch = next(iter(dataloader))
# The batch is already on the CPU, so build the grid there instead of round-tripping via `device`.
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True)

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15, 15))

# Plot the real images (channels-first grid reordered to channels-last for imshow).
ax_real.axis("off")
ax_real.set_title("Real Images")
ax_real.imshow(np.transpose(real_grid, (1, 2, 0)))

# Plot the fake images from the last snapshot (assumed channels-first, as the transpose implies).
ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(np.transpose(img_list[-1], (1, 2, 0)))

plt.show()