In [None]:
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
from torchvision import models
from torchsummary import summary
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import image
from IPython.display import HTML
import csv

# save images
from torchvision.utils import save_image
# Folder for saved image files (presumably written with save_image, imported
# just above); exist_ok=True makes re-running this cell harmless.
img_save_path = './CAN/plots/'
os.makedirs(name=img_save_path, exist_ok=True)

from CAN.parameters import *
from CAN.dataloader_wikiart import *
from CAN.model_CAN_16_9 import * 

In [None]:
# Set random seeds for reproducibility: Python's `random`, NumPy's global RNG
# and torch (torch.manual_seed seeds the RNG on every device, CUDA included).
manualSeed = 3
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Create the dataloader (get_dataset presumably comes from CAN.dataloader_wikiart).
dataloader = get_dataset()

# Run on the first GPU when CUDA is available and `ngpu` (presumably from
# CAN.parameters) requests at least one GPU; otherwise fall back to the CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [None]:
# Plot some training images.
# The grid is built where the batch already lives; `.cpu()` is a no-op for a
# CPU tensor, so there is no need to ship the batch to `device` just to plot it.
real_batch = next(iter(dataloader))
training_grid = vutils.make_grid(real_batch[0], padding=2, normalize=True).cpu()

plt.figure(figsize=(8,6))
plt.axis("off")
plt.title("Training Images")
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(training_grid, (1, 2, 0)))
plt.show()

In [None]:
# Create the generator on the selected device.
# NOTE(review): multi-GPU wrapping (torch.nn.DataParallel) was commented out
# here. If it is re-enabled, make sure the state-dict keys still match the
# checkpoint loaded into netG later in this notebook (DataParallel prefixes
# them with "module.").
netG = Generator(ngpu).to(device)

# Initialise all weights with the project's `weights_init` (defined outside
# this notebook; check it for the exact mean/std it uses). The returned module
# is displayed as the cell output.
netG.apply(weights_init)

In [None]:
# Create the discriminator on the selected device.
# NOTE(review): multi-GPU wrapping (torch.nn.DataParallel) was commented out
# here; re-add it with the fully qualified torch.nn.DataParallel if needed.
netD = Discriminator(ngpu).to(device)

# Initialise all weights with the project's `weights_init` (defined outside
# this notebook; check it for the exact mean/std it uses). The returned module
# is displayed as the cell output.
netD.apply(weights_init)

In [None]:
# Overview Generator: its input is a latent tensor of shape (nz, 1, 1), the
# same shape used when sampling noise for the generated-image grid.
# torchsummary defaults to device="cuda", which feeds CUDA tensors to a CPU
# model whenever CUDA is available but the model was placed on the CPU
# (e.g. ngpu == 0), so pass the actual device type explicitly.
summary(netG, (nz, 1, 1), device=device.type)

In [None]:
# Overview Discriminator: input is a 3-channel 144x256 (16:9) image.
# torchsummary defaults to device="cuda"; pass the real device type so the
# summary also works when the model lives on the CPU while CUDA is available.
summary(netD, (3, 144, 256), device=device.type)

In [None]:
# Restore the trained generator from its checkpoint and switch it to inference mode.

# Checkpoint file written during training.
PATH_GEN = './CAN/models/GEN_16_9.pth'

# Adam optimizers for D and G with identical hyper-parameters. Only
# optimizerG's state is restored below; optimizerD starts fresh.
optimizerD = optim.Adam(netD.parameters(), betas=(beta1, 0.999), lr=lr)
optimizerG = optim.Adam(netG.parameters(), betas=(beta1, 0.999), lr=lr)

# Deserialize onto the CPU; load_state_dict then copies the tensors into the
# already-placed parameters/state. torch.load unpickles the file, so only
# load trusted checkpoints.
checkpoint = torch.load(PATH_GEN, map_location='cpu')
netG.load_state_dict(checkpoint['model_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizer_state_dict'])

# Inference mode (affects layers such as BatchNorm/Dropout); the module is
# displayed as the cell output.
netG.eval()

In [None]:
# Load the loss / entropy history recorded during training.
# Columns read: G_losses, D_losses, entropies. The `with` block guarantees
# the file is closed, and newline='' is what the csv module expects.
with open('./CAN/results.csv', 'r', newline='') as results_file:
    rows = list(csv.DictReader(results_file))

# Convert the CSV strings to floats. np.float64 keeps the same element type
# as before while avoiding np.float_, which was removed in NumPy 2.0.
G_loss = [np.float64(row['G_losses']) for row in rows]
D_loss = [np.float64(row['D_losses']) for row in rows]
Entropy = [np.float64(row['entropies']) for row in rows]

In [None]:
# Plot Losses: generator (G) vs. discriminator (D) loss curves from results.csv.
fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
ax_loss.set_title("Generator and Discriminator Loss During Training")
ax_loss.plot(G_loss, label="G")
ax_loss.plot(D_loss, label="D")
ax_loss.set_xlabel("Epochs")
ax_loss.set_ylabel("Loss")
ax_loss.legend()
plt.show()

In [None]:
# Plot the recorded entropy series (one curve, read from results.csv).
fig_entropy, ax_entropy = plt.subplots(figsize=(10, 5))
ax_entropy.set_title("Entropy of Discriminator's Classifications of Fake Artwork")
ax_entropy.plot(Entropy)
ax_entropy.set(xlabel="Epochs", ylabel="Entropy")
plt.show()

In [None]:
# Load the 100 snapshot images saved during training (0.png ... 99.png), in order.
# NOTE(review): this path differs from img_save_path ('./CAN/plots/') and
# looks machine-specific; confirm it is the intended location.
img_list = [image.imread('./ripper/plots_final/%d.png' % frame) for frame in range(100)]


In [None]:
# Visualize G's progress as an embedded JS/HTML animation of the snapshots.
# animation.embed_limit is measured in MB; this huge value effectively removes
# the cap so every frame is embedded (the saved notebook can get large).
matplotlib.rcParams['animation.embed_limit'] = 2**128
fig = plt.figure(figsize=(8,6))
plt.axis("off")
ims = [[plt.imshow(frame, animated=True)] for frame in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Close the figure so Jupyter does not also render a static copy of the last
# frame beneath the animation; to_jshtml still renders from `fig`.
plt.close(fig)

HTML(ani.to_jshtml())

In [None]:
# Some example images from the generator: sample 64 latent vectors of shape
# (nz, 1, 1) and render the outputs as a single image grid.
genes = torch.randn(64, nz, 1, 1, device=device)

# no_grad already prevents graph construction, so .detach() is unnecessary;
# move the result to the CPU once for plotting.
with torch.no_grad():
    images = netG(genes).cpu()

# Plot the generated images (already on the CPU, so no further transfers).
plt.figure(figsize=(12,14))
plt.axis("off")
plt.title("Generated Images")
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(vutils.make_grid(images, padding=2, normalize=True), (1, 2, 0)))
plt.show()