In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np

from torchvision.utils import save_image

from ipynb.fs.full.definitions import DEVICE, ARCH, IMAGE_SIZE, IMAGE_PATH, MODELS_PATH, OUTPUT_PATH, n_gpu, learning_rate_G, learning_rate_D, \
    beta_adam_1, beta_adam_2, loader_workers, number_channels, gen_feature_map, dis_feature_map, n_epochs, decay_epoch, batch_size, latent_vector
from ipynb.fs.full.image_preprocessing import dataLoader, scaleImages
from ipynb.fs.full.utils import loss_plot, image_grid
from ipynb.fs.full.metrics import compute_metrics
from ipynb.fs.full.train import weights_init, training_loop, LambdaLR
from ipynb.fs.full.dcgan import Generator_256, Discriminator_256, Discriminator_SN_256
from ipynb.fs.full.validation import plot_fake_images, plot_real_images

In [2]:
# Create every output directory this notebook writes into; existing ones are left untouched.
for required_dir in (OUTPUT_PATH, MODELS_PATH, os.path.join(OUTPUT_PATH, 'synthetics')):
    os.makedirs(required_dir, exist_ok=True)

In [3]:
def contains_subfolders(directory):
    """Return True if ``directory`` has at least one immediate subdirectory.

    Only the direct entries returned by ``os.listdir`` are inspected; the
    search is not recursive and stops at the first subdirectory found.
    Any error raised by ``os.listdir`` (e.g. a missing directory) propagates.
    """
    # os.path.isdir is already False for paths that do not exist, so no
    # separate existence check is needed.
    return any(
        os.path.isdir(os.path.join(directory, entry))
        for entry in os.listdir(directory)
    )

In [4]:
# Run the resizing step only when IMAGE_PATH has no subdirectories yet.
image_dir_has_subfolders = contains_subfolders(IMAGE_PATH)
if not image_dir_has_subfolders:
    scaleImages()  # resizes images to the size set up in "definitions"

In [5]:
def loadModel(model_name, quantity, dataloader, img_size):
    """Load a trained generator checkpoint and write synthetic images to disk.

    Parameters
    ----------
    model_name : str
        File name of the saved generator ``state_dict`` inside ``MODELS_PATH``.
    quantity : int
        Number of noise batches to draw and generate from.
    dataloader :
        Only forwarded to ``plot_real_images``.
    img_size : int
        Used as the number of noise vectors per batch (the first dimension of
        the noise tensor), i.e. how many images each batch produces -- it is
        not a spatial size here.

    Every generated image is written to
    ``OUTPUT_PATH/synthetics/<batch>_<index>_synthetic.png``.
    """
    netG = Generator_256(ngpu=n_gpu, nz=latent_vector, ngf=gen_feature_map, nc=number_channels).to(DEVICE)

    # Any initialization done here is replaced by the checkpoint weights loaded below.
    netG.apply(weights_init)

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(os.path.join(MODELS_PATH, model_name), map_location=DEVICE)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # "module."; strip it so they load into the bare generator.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    netG.load_state_dict(state_dict)

    plot_real_images(dataloader, _show=False)

    plot_fake_images(netG, _show=False)

    out_dir = os.path.join(OUTPUT_PATH, 'synthetics')
    # Pure inference: no autograd graph is needed for the generated images.
    with torch.no_grad():
        for i in range(quantity):
            noise = torch.randn(img_size, latent_vector, 1, 1, device=DEVICE)
            fakes = netG(noise)

            for j, fake in enumerate(fakes):
                save_image(fake, os.path.join(out_dir, f'{i}_{j}_synthetic.png'))


In [None]:
dataloader = dataLoader(path=IMAGE_PATH, image_size=IMAGE_SIZE, batch_size=batch_size, workers=loader_workers)

# Both supported architectures use the same generator and differ only in the
# discriminator class (Discriminator_SN_256 presumably adds spectral norm).
discriminator_classes = {
    'DCGAN': Discriminator_256,
    'SNGAN': Discriminator_SN_256,
}
if ARCH not in discriminator_classes:
    # Fail early with a clear message instead of a NameError on netG/netD below.
    raise ValueError(f"Unsupported ARCH {ARCH!r}; expected one of {sorted(discriminator_classes)}")

netG = Generator_256(ngpu=n_gpu, nz=latent_vector, ngf=gen_feature_map, nc=number_channels).to(DEVICE)

netD = discriminator_classes[ARCH](ngpu=n_gpu, nc=number_channels, ndf=dis_feature_map).to(DEVICE)

# Spread both networks over all requested GPUs.
if (DEVICE.type == 'cuda') and (n_gpu > 1):
    gpu_ids = list(range(n_gpu))
    netG = nn.DataParallel(netG, gpu_ids)
    netD = nn.DataParallel(netD, gpu_ids)

netG.apply(weights_init)
netD.apply(weights_init)

criterion = nn.BCELoss()

# NOTE(review): IMAGE_SIZE is used here as the number of fixed noise vectors
# (the batch dimension), not as a spatial size -- confirm this is intended.
fixed_noise = torch.randn(IMAGE_SIZE, latent_vector, 1, 1, device=DEVICE)

optimizerD = optim.Adam(netD.parameters(), lr=learning_rate_D, betas=(beta_adam_1, beta_adam_2))
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate_G, betas=(beta_adam_1, beta_adam_2))

# Variable learning rate: each optimizer gets its own LambdaLR scheduler.
lr_schedulerG = optim.lr_scheduler.LambdaLR(
    optimizerG, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
)

lr_schedulerD = optim.lr_scheduler.LambdaLR(
    optimizerD, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
)

# The discriminator must receive its own scheduler (previously the generator's
# scheduler was passed for both, leaving lr_schedulerD unused).
G_losses, D_losses, img_list, img_list_only = training_loop(num_epochs=n_epochs, dataloader=dataloader,
                                                            netG=netG, netD=netD, device=DEVICE, criterion=criterion, nz=latent_vector,
                                                            optimizerG=optimizerG, optimizerD=optimizerD, schedulerG=lr_schedulerG, schedulerD=lr_schedulerD,
                                                            fixed_noise=fixed_noise, out=OUTPUT_PATH)

# Save the bare modules rather than DataParallel wrappers, so checkpoint keys
# carry no "module." prefix and load directly into the unwrapped networks.
generator_to_save = netG.module if isinstance(netG, nn.DataParallel) else netG
discriminator_to_save = netD.module if isinstance(netD, nn.DataParallel) else netD
size_tag = str(IMAGE_SIZE) + 'x' + str(IMAGE_SIZE)

# save the trained generator
torch.save(generator_to_save.state_dict(), os.path.join(MODELS_PATH, "generator_" + size_tag + ".pth"))

# as well as the trained discriminator
torch.save(discriminator_to_save.state_dict(), os.path.join(MODELS_PATH, "discriminator_" + size_tag + ".pth"))

# save losses (np.save appends the ".npy" extension)
np.save(os.path.join(OUTPUT_PATH, "G_loss"), G_losses)

np.save(os.path.join(OUTPUT_PATH, "D_loss"), D_losses)

loss_plot(G_losses=G_losses, D_losses=D_losses, out=OUTPUT_PATH)

image_grid(dataloader=dataloader, img_list=img_list, device=DEVICE, out=OUTPUT_PATH)

compute_metrics(real=next(iter(dataloader)), fakes=img_list_only, size=IMAGE_SIZE, out=OUTPUT_PATH)