In [None]:
from __future__ import print_function
#%matplotlib inline
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from tqdm import tqdm

from scipy import linalg
# Reproducibility: every random number generator used by the notebook (Python,
# PyTorch, NumPy) is seeded with the same fixed value.
manualSeed = 69
# For a different run each time, use a random seed instead:
#manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)
random.seed(manualSeed)

# Hyper Parameters

In [None]:
# ---- Data -------------------------------------------------------------------
dataroot = "./data"   # root folder passed to torchvision's ImageFolder
workers = 4           # DataLoader worker processes
batch_size = 32       # images per training batch
image_size = 64       # images are resized and center-cropped to this size
nc = 3                # channels per training image (3 = colour)

# ---- Model sizes ------------------------------------------------------------
nz = 128              # length of the latent vector z fed to the generator
# NOTE(review): ngf / ndf are not passed to Generator() / Discriminator() in this
# notebook; presumably the imported model modules define their own widths.
ngf = 64              # generator feature-map size
ndf = 32              # discriminator feature-map size

# ---- Optimisation -----------------------------------------------------------
num_epochs = 100
disc_lr, gen_lr = 0.0002, 0.0002   # Adam learning rates for D and G
beta1 = 0.5                        # Adam beta1, shared by both optimizers

# ---- Checkpointing ----------------------------------------------------------
# True: continue from a saved checkpoint; False: train from scratch.
resume = True

# Visualize Dataset

In [None]:
# Training set: every image under `dataroot` (ImageFolder layout) is resized and
# center-cropped to image_size x image_size, converted to a tensor and normalised per
# channel with mean 0.5 / std 0.5, i.e. from [0, 1] to [-1, 1].
dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
# Shuffled mini-batches for training.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Run on the first GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Preview one batch (at most 64 images). The grid only feeds matplotlib, so it is
# built directly on the CPU instead of copying the batch to `device` and back.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1,2,0)))
plt.show()

# GENERATOR

In [None]:
from models.generator_3_0_2 import * 

In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style initialisation hook, meant to be passed to ``nn.Module.apply``.

    Dispatches on the module's class name:

    * names containing ``'Conv'`` (e.g. Conv2d, ConvTranspose2d): weight ~ N(0, 0.02);
      the bias, if any, is left untouched.
    * names containing ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.
    * names containing ``'Linear'``: weight ~ N(0, 0.02), bias = 0.

    Parameters that do not exist are skipped instead of raising ``AttributeError``.
    This covers ``bias=False`` layers, ``affine=False`` batch norms, and custom
    container modules whose class name merely contains one of the substrings.

    Args:
        m: the module currently visited by ``apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # affine=False batch norms register weight/bias as None.
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
    elif 'Linear' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
        # Linear(bias=False) has bias None.
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
# Instantiate the generator, move it to the selected device and initialise its
# weights with weights_init (Module.to and Module.apply both return the module).
netG = Generator().to(device).apply(weights_init)

# Show the architecture.
print(netG)

# DISCRIMINATOR

In [None]:
from models.discriminator_2 import *

In [None]:
# Create the Discriminator
netD = Discriminator().to(device)
# Apply weights_init to re-initialise the weights: Conv/Linear weights ~ N(0, 0.02),
#  BatchNorm weights ~ N(1, 0.02), BatchNorm/Linear biases = 0.
netD.apply(weights_init)
# Print the model
print(netD)

# Loss and Optimizer

In [None]:
# Binary cross-entropy between D's outputs and the targets.
# NOTE(review): nn.BCELoss expects probabilities, so this presumes the imported
# Discriminator ends in a sigmoid — confirm in models/discriminator_2.
criterion = nn.BCELoss()

# One fixed latent batch, reused during training to visualise G's progress.
# It is 2-D (batch_size, nz), matching the noise drawn in the training loop.
fixed_noise = torch.randn(batch_size, nz, device=device)

# Targets for D; 0.9 / 0.1 instead of 1 / 0 smooths the labels on both sides.
real_label, fake_label = 0.9, 0.1

# One Adam optimizer per network, sharing beta1 from the hyper-parameter cell.
optimizerD = optim.Adam(netD.parameters(), lr=disc_lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=gen_lr, betas=(beta1, 0.999))

# FID - Illustration2Vec

In [None]:
import i2v
# Illustration2Vec model whose extract_feature() provides the features for the FID metric.
illust2vec = i2v.make_i2v_with_chainer("illust2vec_ver200.caffemodel")  # the .caffemodel file must first be downloaded from the project's releases (see the README)

In [None]:
def calculate_activation_statistics(images, cuda=True, denormalize=True):
    """Mean and covariance of Illustration2Vec features for a batch of images.

    Each image is converted to a PIL image and passed to ``illust2vec.extract_feature``;
    the statistics are taken over the first axis of the returned features.

    Args:
        images: batch of image tensors, iterated over its first dimension
            (presumably (N, C, H, W) — TODO confirm against callers).
        cuda: accepted for backward compatibility but no longer used. Conversion to
            PIL happens on the CPU, so the former copy to the GPU was pure overhead.
            It also crashed on machines without CUDA.
        denormalize: if True (default), the images are assumed to be in [-1, 1].
            That is the range produced by the dataset's
            ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``. They are mapped back to
            [0, 1] before conversion.

    Returns:
        (mu, sigma): the feature mean over the batch and the feature covariance
        (``np.cov`` with one row per image).

    Note:
        With N images the covariance has rank at most N - 1. For a single training
        batch and a high-dimensional feature (4096-dim per the original comments),
        the estimate is singular and very noisy, so per-batch FID is only a trend signal.
    """
    batch = images.detach().cpu()
    if denormalize:
        batch = batch * 0.5 + 0.5
    # ToPILImage scales float tensors by 255 and casts to uint8 without clamping, so
    # out-of-range values (the raw [-1, 1] range, or added instance noise) would be
    # converted incorrectly instead of saturating.
    batch = batch.clamp(0.0, 1.0)

    to_pil = transforms.ToPILImage()
    pil_images = [to_pil(img) for img in batch]

    features = illust2vec.extract_feature(pil_images)

    mu = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between two multivariate Gaussians (NumPy/SciPy).

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Args:
        mu1, mu2: mean vectors (scalars are promoted to 1-D).
        sigma1, sigma2: covariance matrices (promoted to at least 2-D).
        eps: value added to the diagonal of both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The squared Frechet distance d^2.

    Raises:
        AssertionError: if the means or the covariances differ in shape.
        ValueError: if the square root keeps a non-negligible imaginary diagonal.
    """
    m1, m2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    s1, s2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert m1.shape == m2.shape, \
        'Training and test mean vectors have different lengths'
    assert s1.shape == s2.shape, \
        'Training and test covariances have different dimensions'

    mean_delta = m1 - m2

    root, _ = linalg.sqrtm(s1.dot(s2), disp=False)
    if not np.isfinite(root).all():
        # Near-singular product: regularise both covariances and retry.
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(s1.shape[0]) * eps
        root = linalg.sqrtm((s1 + jitter).dot(s2 + jitter))

    # Round-off can leave a tiny imaginary part; discard it only when negligible.
    if np.iscomplexobj(root):
        if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):
            largest = np.max(np.abs(root.imag))
            raise ValueError('Imaginary component {}'.format(largest))
        root = root.real

    return (mean_delta.dot(mean_delta) + np.trace(s1)
            + np.trace(s2) - 2 * np.trace(root))

def calculate_fretchet(images_real, images_fake):
    """Return the Frechet distance (FID) between two image batches.

    The Illustration2Vec feature statistics of each batch are computed with
    ``calculate_activation_statistics``. The two Gaussians are then compared with
    ``calculate_frechet_distance``.
    """
    # Request a GPU copy only when CUDA actually exists. The former hard-coded
    # cuda=True made `.cuda()` fail on CPU-only machines.
    use_cuda = torch.cuda.is_available()
    mu_real, sigma_real = calculate_activation_statistics(images_real, cuda=use_cuda)
    mu_fake, sigma_fake = calculate_activation_statistics(images_fake, cuda=use_cuda)
    return calculate_frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)

# Save & Load to keep training

In [None]:
def save_checkpoint(epoch_checkpoint, netG, netD, optimizerG, optimizerD, G_losses, D_losses,
                    save_dir='./saved_models'):
    """Write a training checkpoint to ``<save_dir>/last_saved_model_<epoch_checkpoint+1>.tar``.

    The archive holds the epoch index, the state dicts of both networks and both
    optimizers, and the two loss histories.

    Args:
        epoch_checkpoint: epoch index to store; the file name uses ``epoch_checkpoint + 1``.
        netG, netD: generator and discriminator modules.
        optimizerG, optimizerD: their optimizers.
        G_losses, D_losses: loss histories to persist.
        save_dir: output directory. It is created if missing, because ``torch.save``
            fails when the parent directory does not exist.

    Returns:
        The path of the written file.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f'last_saved_model_{epoch_checkpoint+1}.tar')
    torch.save({
                'epoch_checkpoint': epoch_checkpoint,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
                'optimizerG_state_dict': optimizerG.state_dict(),
                'optimizerD_state_dict': optimizerD.state_dict(),
                'G_losses': G_losses,
                'D_losses': D_losses,
                }, path)
    return path

In [None]:
# Restore networks, optimizers and loss histories from a checkpoint written by
# save_checkpoint, or start at epoch 0 when training from scratch.
if resume:
    # map_location places the loaded tensors on the current device, so a checkpoint
    # written on a GPU can also be restored on a CPU-only machine (and vice versa).
    checkpoint = torch.load('./saved_models/last_saved_model_100_good.tar',
                            map_location=device)
    # Index of the last epoch that had completed when the checkpoint was written
    # (save_checkpoint receives the training loop's `epoch`).
    epoch_checkpoint = checkpoint['epoch_checkpoint']
    netG.load_state_dict(checkpoint['netG_state_dict'])
    netD.load_state_dict(checkpoint['netD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    G_losses = checkpoint['G_losses']
    D_losses = checkpoint['D_losses']
else:
    epoch_checkpoint = 0

Move to device

In [None]:
# After a checkpoint restore, re-apply the device placement to both networks.
# They were already created with .to(device), and Module.to returns the module itself.
if resume:
    netG, netD = netG.to(device), netD.to(device)

# TRAINING

In [None]:
# Training Loop
# Each batch gets one discriminator update (real + fake batch) followed by one
# generator update. Instance noise is added to D's inputs, and the targets are
# real_label / fake_label.

# Lists to keep track of progress
img_list = []
# FID values are not stored in checkpoints, so this list always starts empty. It used
# to be created only when training from scratch. A resumed run then failed with a
# NameError at the first `fid_values.append`.
fid_values = []
if resume:
    print('---Resuming Training from saved checkpoint---')
    # `epoch_checkpoint` is the last epoch the checkpoint had completed (it is saved
    # with the loop's `epoch`), so continue with the next one instead of repeating it.
    start_epoch = epoch_checkpoint + 1
else:
    G_losses = []
    D_losses = []
    start_epoch = epoch_checkpoint
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(start_epoch, num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(tqdm(dataloader, total=len(dataloader))):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch; the clean images are kept for the FID computation below.
        real_images = data[0].to(device)
        b_size = real_images.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        # Instance noise: D only sees slightly corrupted real images.
        real_cpu = 0.9 * real_images + 0.1 * torch.randn(real_images.size(), device=device)

        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, device=device)
        # Generate fake image batch with G (clean output kept for FID)
        fake_images = netG(noise)
        label.fill_(fake_label)

        # Same instance noise on the fake batch.
        fake = 0.9 * fake_images + 0.1 * torch.randn(fake_images.size(), device=device)

        # Classify all fake batch with D; detach() keeps this loss from reaching G.
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake_display = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake_display, padding=2, normalize=True))

        iters += 1

    # Save losses for plotting later (last batch of each epoch)
    G_losses.append(errG.item())
    D_losses.append(errD.item())
    # FID between the epoch's last real batch and last generated batch, measured on the
    # clean images: instance noise is part of D's training, not of G's output.
    fretchet_dist = calculate_fretchet(real_images, fake_images.detach())
    fid_values.append(fretchet_dist)

    print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tFretchet_Distance: %.4f\t'
                      % (epoch+1, num_epochs,
                         errD.item(), errG.item(),fretchet_dist))

    # Show 10 images drawn (with replacement) from the latest fixed-noise snapshot.
    plt.figure(figsize=(8,8))
    plt.axis("off")
    pictures=vutils.make_grid(fake_display[torch.randint(len(fake_display), (10,))],nrow=5,padding=2, normalize=True)
    plt.imshow(np.transpose(pictures,(1,2,0)))
    plt.show()

    if ((epoch+1) % 10 == 0):
        save_checkpoint(epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses)
        print(f'model saved at epoch: {epoch+1}')

# If the checkpoint already covered every epoch, the loop never ran and `epoch`
# (and the losses) are undefined, so there is nothing new to save.
if start_epoch < num_epochs:
    save_checkpoint(epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses)
    print('Training finished, model saved')
else:
    print('Checkpoint already covers all epochs; nothing to train')

# Plotting the results

## FID plot

In [None]:
# FID is appended once per epoch by the training loop, so the x axis counts epochs
# trained in this session. The line is labelled so that legend() has an entry.
plt.figure(figsize=(10,5))
plt.title("FID during training")
plt.plot(fid_values, label="FID")
plt.xlabel("epoch")
plt.ylabel("FID")
plt.legend()
plt.show()

## Loss plot

In [None]:
# The training loop appends one G and one D loss per epoch (from the last batch),
# so the x axis counts epochs rather than iterations.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

## Images generated during training

In [None]:
#%%capture
# (Uncomment the magic above — it must stay the first line — to hide the static figure.)
# Animate the fixed-noise snapshots collected in img_list, one frame each.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
frames = [[plt.imshow(np.transpose(snapshot, (1, 2, 0)), animated=True)]
          for snapshot in img_list]
ani = animation.ArtistAnimation(fig, frames, interval=1000,
                                repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

## Compare real and generated images

In [None]:
# Side-by-side comparison: a fresh batch of real training images vs. the last
# fixed-noise snapshot recorded during training.
real_batch = next(iter(dataloader))

# Plot the real images. The grid only feeds matplotlib, so it is built on the CPU
# instead of copying the batch to `device` and back.
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True)
plt.imshow(np.transpose(real_grid,(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()