# In [None]:
## Same as above, but with the Wasserstein loss and RMSprop as the optimizer.
## The results weren't really better, nor was training any faster.
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms

# --- Generator ---------------------------------------------------------
# Build the generator and move it to the target device.
netG = Generator(ngpu).to(device)

# Spread the generator across GPUs when more than one was requested.
if device.type == 'cuda' and ngpu > 1:
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

# Re-initialise every submodule's weights with weights_init
# (see weights_init for the exact mean / standard deviation used).
netG.apply(weights_init)

# --- Discriminator (critic) --------------------------------------------
# Build the discriminator and move it to the target device.
netD = Discriminator(ngpu).to(device)

# Spread the discriminator across GPUs when more than one was requested.
if device.type == 'cuda' and ngpu > 1:
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))

# Same initialisation scheme as the generator.
netD.apply(weights_init)

# Wasserstein loss for the discriminator (critic)
def discriminator_loss(real_output, fake_output):
    """Return mean(fake_output) - mean(real_output) as a scalar tensor.

    Minimizing this value pushes the critic's scores up on real samples
    and down on generated samples.
    """
    real_score = torch.mean(real_output)
    fake_score = torch.mean(fake_output)
    return fake_score - real_score

# Wasserstein loss for the generator
def generator_loss(fake_output):
    """Return the negated mean of the critic's scores on generated samples.

    Minimizing this value pushes the critic's scores on fakes upward.
    """
    mean_score = torch.mean(fake_output)
    return -mean_score

# Fixed batch of 64 latent vectors; the training loop feeds this same batch
# to the generator whenever it records a progress snapshot
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Real / fake label convention (not referenced by the Wasserstein losses
# or the training loop in this cell)
real_label, fake_label = 1.0, 0.0

# One RMSprop optimizer per network, both using the learning rate `lr`
optimizerD = optim.RMSprop(params=netD.parameters(), lr=lr)
optimizerG = optim.RMSprop(params=netG.parameters(), lr=lr)

# Training Loop (Wasserstein loss with weight clipping)
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

# After every discriminator update its weights are clamped to
# [-weight_clip, weight_clip]
weight_clip = 0.01

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: minimize -mean(D(x)) + mean(D(G(z)))
        ###########################
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # Forward pass real batch through D
        output_real = netD(real_cpu).view(-1)
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        # Score the fake batch with D; detach() keeps this backward pass
        # from propagating into G
        output_fake = netD(fake.detach()).view(-1)
        # D's Wasserstein loss over the real and fake batches
        errD = discriminator_loss(output_real, output_fake)
        # Gradients for this batch only (D's grads were zeroed above)
        errD.backward()
        # Update D
        optimizerD.step()
        # Weight clipping: clamp the parameters in place, outside autograd
        with torch.no_grad():
            for p in netD.parameters():
                p.clamp_(-weight_clip, weight_clip)

        ############################
        # (2) Update G network: minimize -mean(D(G(z)))
        ###########################
        netG.zero_grad()
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = generator_loss(output)
        # Calculate gradients for G (D also receives gradients here, but they
        # are zeroed at the start of the next iteration before D is updated)
        errG.backward()
        # Update G
        optimizerG.step()

        # Output training stats: D(x) is D's mean score on the real batch;
        # D(G(z)) is its mean score on the fake batch before / after the D update
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), output_real.mean().item(), output_fake.mean().item(), output.mean().item()))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        # (every 500 iterations and on the very last batch)
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
# Save the trained generator and discriminator weights.
# NOTE: when a network is wrapped in nn.DataParallel, its state_dict keys
# carry a "module." prefix.
torch.save(netG.state_dict(), "saved_model_example3.pth")
torch.save(netD.state_dict(), "saved_model_example_d3.pth")

# Build an animation from the generator snapshots collected during training
fig = plt.figure(figsize=(8,8))
plt.axis("off")
# Each snapshot is a CxHxW grid; transpose to HxWxC for imshow
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# NOTE(review): a notebook only auto-displays the last expression of a cell,
# so this value is presumably discarded here — confirm whether an explicit
# display call is wanted.
HTML(ani.to_jshtml())

# Make sure the output directory exists before anything is written to it
os.makedirs('images3', exist_ok=True)
# Save the animation to a file and store it to images directory
ani.save('images3/dcgan_ani.gif', writer='imagemagick', fps=4)
# also save the files used to create the animation to the images directory under the name dcgan_ani
for i, grid in enumerate(img_list):
    vutils.save_image(grid, f'images3/dcgan_ani_{i}.png')


# Use the saved model to generate new images
# Rebuild the generator exactly as it was built for training. The checkpoint
# was written from netG, which is wrapped in nn.DataParallel when several
# GPUs are in use, so the same wrapping is needed for the state_dict keys
# to match.
model = Generator(ngpu).to(device)
if (device.type == 'cuda') and (ngpu > 1):
    model = nn.DataParallel(model, list(range(ngpu)))
# map_location places the loaded tensors on the current device
model.load_state_dict(torch.load("saved_model_example3.pth", map_location=device))
# NOTE(review): the model stays in training mode, as it was for the snapshots
# taken during training; call model.eval() first if Generator contains
# BatchNorm/Dropout layers and inference-mode behaviour is wanted.
# Generate new images (no autograd graph is needed)
with torch.no_grad():
    fake = model(fixed_noise).cpu()
# Plot the generated images
plt.figure(figsize=(10,10))
plt.axis("off")
plt.title("Generated Images")
# make_grid returns CxHxW; transpose to HxWxC for imshow
plt.imshow(np.transpose(vutils.make_grid(fake, padding=2, normalize=True),(1,2,0)))