# Impressionist StyleGAN - Training Loop
Run the cells below to train your own StyleGAN on the dataset of impressionist artworks. Make sure have a directory `impressionist` that contains the images from the dataset (you can find the dataset in the GitHub Release called `Impressionist Artworks v1.0`)

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import ImpressionistDataset as dataset
import matplotlib.pyplot as plt
import generator
import discriminator
import globals 
import math
import utils
from tqdm import tqdm
import importlib
from utils_generator import applyExponentialMovingAverage, g_loss_non_saturating
from utils_discriminator import d_loss_non_saturating_r1
from ADA import ADA
from torch.utils.data import ConcatDataset
from torchmetrics.image.fid import FrechetInceptionDistance
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
importlib.reload(generator)
importlib.reload(discriminator)
importlib.reload(globals)
importlib.reload(utils)
importlib.reload(dataset)

<module 'ImpressionistDataset' from 'c:\\Users\\alexa\\Desktop\\AppliedDeepLearning\\model\\ImpressionistDataset.py'>

In [3]:
cluster_ind = 0
datasets = {
    4: dataset.ImpressionistDataset(resolution=4, cluster_ind=cluster_ind),
    8: dataset.ImpressionistDataset(resolution=8, cluster_ind=cluster_ind),
    16: dataset.ImpressionistDataset(resolution=16, cluster_ind=cluster_ind),
    32: dataset.ImpressionistDataset(resolution=32, cluster_ind=cluster_ind),
    64: dataset.ImpressionistDataset(resolution=64, cluster_ind=cluster_ind),
    128: dataset.ImpressionistDataset(resolution=128, cluster_ind=cluster_ind)
}

In [4]:
plt.figure(figsize=(10, 5))

ada = ADA()
fid = FrechetInceptionDistance(feature=2048).to(globals.DEVICE)

G = generator.Generator()
D = discriminator.Discriminator()
G.to(globals.DEVICE), D.to(globals.DEVICE)

# we initialize our EMA Generator. We don't need gradients for it.
G_EMA = generator.Generator()
G_EMA.load_state_dict(G.state_dict())
G_EMA.train(False)
G_EMA.to(globals.DEVICE)

for param in G_EMA.parameters():
    param.requires_grad_(False)

mapping_params, generator_params = utils.get_generator_params(G)

adam_g = torch.optim.AdamW([
    {'params': mapping_params, 'lr': globals.LR_MAPPING_NETWORK, 'name': 'mapping'},
    {'params': generator_params, 'lr': globals.LR_MODEL, 'name': 'generator'}
], betas=(globals.ADAM_BETA1, globals.ADAM_BETA2))


adam_d = torch.optim.AdamW(D.parameters(), lr=globals.LR_MODEL, betas=(globals.ADAM_BETA1, globals.ADAM_BETA2))

res_list = [2**i for i in range(2, int(math.log2(globals.MAX_RES))+1)]

global_img_count = 0
for res in res_list:

    # we update the learning rate for each resolution
    g_lr = globals.LR_MODEL_PER_RES[res]
    d_lr = globals.LR_MODEL_PER_RES[res]
    mapping_lr = globals.LR_MAPPING_NETWORK_PER_RES[res]

    for param_group in adam_g.param_groups:
        if param_group.get('name') == 'mapping':
            param_group['lr'] = mapping_lr
        else:
            param_group['lr'] = g_lr

    for param_group in adam_d.param_groups:
        param_group["lr"] = d_lr

    print(f"RESOLUTION {res}x{res}:")
    if res > 4:
        G.fade_in(res)
        G_EMA.fade_in(res)
        D.fade_in(res)
    
    repeated_dataset = ConcatDataset([datasets[res]] * 3)
    loader = torch.utils.data.DataLoader(repeated_dataset, batch_size=globals.BATCH_SIZES_PER_RES[res], shuffle=True, num_workers=4, pin_memory=True)

    fade_in_imgs = int(globals.IMAGES_PER_RESOLUTION[res] * globals.FADE_IN_PERCENTAGE)
    count_until_fid = 50000
    count_until_grid = 100000
    imgs_this_phase = 0
    discriminator_steps = 0

    while imgs_this_phase < globals.IMAGES_PER_RESOLUTION[res]:
        for real in tqdm(loader):
            real = real.to(globals.DEVICE)

            batch_size = real.size(0)

            if imgs_this_phase < fade_in_imgs:
                layer_opacity = min(1.0, imgs_this_phase / max(1, fade_in_imgs))
                G.set_layer_opacity(layer_opacity)
                G_EMA.set_layer_opacity(layer_opacity)
                D.set_layer_opacity(layer_opacity)
            else:
                G.set_layer_opacity(1.0)
                G_EMA.set_layer_opacity(1.0)
                D.set_layer_opacity(1.0)

            # only one D step for logistic loss with R1
            for i in range(globals.DISCRIMINATOR_STEPS):
                # Discriminator step
                adam_d.zero_grad(set_to_none=True)
                discriminator_steps += 1
                z = torch.randn(batch_size, globals.Z_DIM, device=globals.DEVICE)
                with torch.no_grad():
                    fake = G(z)

                log = imgs_this_phase % len(loader) == 0 and i == 0
                D_loss = d_loss_non_saturating_r1(D, real, fake.detach(), discriminator_steps, ada, log=log)
                D_loss.backward()
                adam_d.step()

            # Generator step
            z = torch.randn(batch_size, globals.Z_DIM, device=globals.DEVICE)
            adam_g.zero_grad(set_to_none=True)

            fake = G(z)
            G_loss = g_loss_non_saturating(D, fake, ada)
            G_loss.backward()
            adam_g.step()
            applyExponentialMovingAverage(G, G_EMA)

            imgs_this_phase += batch_size
            global_img_count += batch_size
            count_until_fid -= batch_size
            count_until_grid -= batch_size
            if count_until_fid <= 0:
                count_until_fid = 50000
                percent_this_phase = 100*imgs_this_phase / (globals.IMAGES_PER_RESOLUTION[res])
                fid_score = utils.compute_fid(G, G_EMA, datasets[res], res, percent_this_phase, fid)

            if count_until_grid <= 0:
                utils.generate_grid_image(G, fid_score["G"], res, "training_imgs")
                count_until_grid = 100000

    G.set_layer_opacity(1.0)
    G_EMA.set_layer_opacity(1.0)
    D.set_layer_opacity(1.0)


RESOLUTION 4x4:


  0%|          | 0/147 [00:00<?, ?it/s]

REAL SCORES: tensor([[-0.2108],
        [ 0.0156],
        [ 0.1335],
        [ 0.3583],
        [ 0.5205],
        [ 0.0171],
        [ 0.2762],
        [ 0.2478],
        [-0.3006],
        [ 0.1438]], device='cuda:0', grad_fn=<SliceBackward0>)
FAKE SCORES: tensor([[-0.2274],
        [ 0.6772],
        [ 0.7035],
        [ 0.4699],
        [ 0.7292],
        [ 0.1080],
        [-0.4680],
        [ 0.2875],
        [ 0.0337],
        [ 0.2886]], device='cuda:0', grad_fn=<SliceBackward0>)
DISCRIMINATOR_LOSS: 1.545292854309082
ADA rt: 0.4609, p: 0.0000


100%|██████████| 147/147 [00:54<00:00,  2.70it/s]
 72%|███████▏  | 106/147 [00:43<00:09,  4.32it/s]

REAL SCORES: tensor([[ 0.6894],
        [ 0.3892],
        [ 0.4608],
        [ 0.5563],
        [-0.0143],
        [ 0.5273],
        [ 0.5628],
        [ 0.0198],
        [ 0.5393],
        [ 0.5047]], device='cuda:0', grad_fn=<SliceBackward0>)
FAKE SCORES: tensor([[-0.6370],
        [-0.8745],
        [-0.8191],
        [-0.5500],
        [-0.6001],
        [-0.7074],
        [-0.8241],
        [-0.6865],
        [-0.7831],
        [-0.7753]], device='cuda:0', grad_fn=<SliceBackward0>)
DISCRIMINATOR_LOSS: 0.9045981168746948
ADA rt: 0.9531, p: 0.0202


100%|██████████| 147/147 [00:53<00:00,  2.72it/s]
 43%|████▎     | 63/147 [00:32<00:17,  4.68it/s]

REAL SCORES: tensor([[ 0.2496],
        [ 0.1407],
        [ 0.0092],
        [ 0.7148],
        [ 0.6358],
        [ 0.5729],
        [-0.0145],
        [ 0.5311],
        [ 0.4615],
        [ 0.5736]], device='cuda:0', grad_fn=<SliceBackward0>)
FAKE SCORES: tensor([[-0.5944],
        [-0.5589],
        [-0.5212],
        [-0.6699],
        [-0.7701],
        [-0.9024],
        [-0.8811],
        [-0.5450],
        [-0.5243],
        [-0.7699]], device='cuda:0', grad_fn=<SliceBackward0>)
DISCRIMINATOR_LOSS: 0.9028184413909912
ADA rt: 0.9219, p: 0.0289


100%|██████████| 147/147 [03:11<00:00,  1.31s/it]
 16%|█▌        | 23/147 [00:22<00:28,  4.33it/s]

REAL SCORES: tensor([[0.6046],
        [0.5317],
        [0.3971],
        [0.5679],
        [0.3104],
        [0.5240],
        [0.4250],
        [0.3817],
        [0.4884],
        [0.4139]], device='cuda:0', grad_fn=<SliceBackward0>)
FAKE SCORES: tensor([[-0.2257],
        [-0.1814],
        [-0.2335],
        [-0.3167],
        [-0.4788],
        [-0.6122],
        [-0.5355],
        [-0.2731],
        [-0.4559],
        [-0.4437]], device='cuda:0', grad_fn=<SliceBackward0>)
DISCRIMINATOR_LOSS: 1.0638139247894287
ADA rt: 0.9766, p: 0.0375


100%|██████████| 147/147 [00:51<00:00,  2.85it/s]
 86%|████████▋ | 127/147 [00:49<00:04,  4.38it/s]

REAL SCORES: tensor([[ 0.1138],
        [ 0.2695],
        [ 0.2428],
        [-0.0033],
        [ 0.2473],
        [ 0.2405],
        [ 0.1553],
        [ 0.2168],
        [ 0.1855],
        [ 0.3418]], device='cuda:0', grad_fn=<SliceBackward0>)
FAKE SCORES: tensor([[-0.0551],
        [-0.0931],
        [ 0.0138],
        [-0.1597],
        [ 0.0835],
        [ 0.2239],
        [ 0.0962],
        [-0.0333],
        [ 0.0457],
        [ 0.0966]], device='cuda:0', grad_fn=<SliceBackward0>)
DISCRIMINATOR_LOSS: 1.3862230777740479
ADA rt: 0.7578, p: 0.0525


100%|██████████| 147/147 [00:56<00:00,  2.61it/s]
 33%|███▎      | 49/147 [01:11<02:22,  1.45s/it] 


KeyboardInterrupt: 

<Figure size 1000x500 with 0 Axes>

# Save Model Weights

Run the below cell to save the current state of the model i.e. the Generator/Discriminator weights and the Optimizer states

In [None]:
torch.save({
        'G_state_dict': G.state_dict(),
        'D_state_dict': D.state_dict(),
        'G_EMA_state_dict': G_EMA.state_dict(),
        'G_optimizer': adam_g.state_dict(),
        'D_optimizer': adam_d.state_dict()
}, "weights/ada_stylegan_64_more_channels.pth")

# Load Saved EMA Generator

Run the below cell to load a pretrained EMA Generator. You can find the weights file of my StyleGAN in GitHub in the `Impressionist Artworks v1.0` Release.
After is has been loaded, we generate some images with it.
The result will be saved in the directory `final_model_imgs`

In [None]:
G_EMA = generator.Generator().to(globals.DEVICE)
FINAL_MODEL_RESOLUTION = 64

checkpoint = torch.load("weights/ada_stylegan_64_more_channels.pth")
G_EMA.load_state_dict(checkpoint['G_EMA_state_dict'])
G_EMA.fade_in(FINAL_MODEL_RESOLUTION)
G_EMA.set_layer_opacity(1.0)

FINAL_MODEL_FID_SCORE = 40.38
utils.generate_grid_image(G_EMA, FINAL_MODEL_FID_SCORE, FINAL_MODEL_RESOLUTION, "final_model_imgs")

  checkpoint = torch.load(
