# Impressionist StyleGAN - Training Loop
Run the cells below to train your own StyleGAN on the dataset of impressionist artworks. Make sure you have a directory `impressionist` that contains the images from the dataset (you can find the dataset in the GitHub Release called `Impressionist Artworks v1.0`).
The training loop calculates FID scores every 50,000 images shown to the discriminator and saves a grid of generated images every 100,000 images shown to the discriminator.

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import ImpressionistDataset as dataset
import matplotlib.pyplot as plt
import generator
import discriminator
import globals 
import math
import utils
from tqdm import tqdm
import importlib
from utils_generator import apply_exponential_moving_average, g_loss_non_saturating
from utils_discriminator import d_loss_non_saturating_r1
from ADA import ADA
from torch.utils.data import ConcatDataset
from torchmetrics.image.fid import FrechetInceptionDistance
import numpy as np
import os

In [None]:
# Re-execute the project modules in place so that edits made on disk are picked
# up without restarting the kernel (same modules, same order as before).
for _project_module in (generator, discriminator, globals, utils, dataset):
    importlib.reload(_project_module)

In [None]:
cluster_ind = 0

# One ImpressionistDataset per training resolution (4x4 up to 128x128), all
# restricted to the same cluster; keys are the resolutions themselves.
datasets = {
    resolution: dataset.ImpressionistDataset(
        resolution=resolution, cluster_ind=cluster_ind
    )
    for resolution in (4, 8, 16, 32, 64, 128)
}

In [None]:
plt.figure(figsize=(10, 5))

ada = ADA()
fid = FrechetInceptionDistance(feature=2048).to(globals.DEVICE)

# How often (counted in real images drawn from the loader) to compute FID and
# to save a grid of generated images.
FID_EVERY_IMGS = 50_000
GRID_EVERY_IMGS = 100_000

G = generator.Generator()
D = discriminator.Discriminator()
G.to(globals.DEVICE)
D.to(globals.DEVICE)

# we initialize our EMA Generator. It starts as a copy of G and is only updated
# through apply_exponential_moving_average, so it doesn't need gradients.
G_EMA = generator.Generator()
G_EMA.load_state_dict(G.state_dict())
G_EMA.train(False)
G_EMA.to(globals.DEVICE)

for param in G_EMA.parameters():
    param.requires_grad_(False)

# Separate parameter groups so the mapping network gets its own learning rate.
mapping_params, generator_params = utils.get_generator_params(G)

adam_g = torch.optim.AdamW(
    [
        {"params": mapping_params, "lr": globals.LR_MAPPING_NETWORK, "name": "mapping"},
        {"params": generator_params, "lr": globals.LR_MODEL, "name": "generator"},
    ],
    betas=(globals.ADAM_BETA1, globals.ADAM_BETA2),
)


adam_d = torch.optim.AdamW(
    D.parameters(), lr=globals.LR_MODEL, betas=(globals.ADAM_BETA1, globals.ADAM_BETA2)
)

# Progressive growing: 4x4, 8x8, ..., MAX_RES x MAX_RES.
res_list = [2**i for i in range(2, int(math.log2(globals.MAX_RES)) + 1)]
# Models whose fade-in state must be kept in lock-step (same call order as G, G_EMA, D).
progressive_models = (G, G_EMA, D)

global_img_count = 0
for res in res_list:

    # we update the learning rate for each resolution
    g_lr = globals.LR_MODEL_PER_RES[res]
    d_lr = globals.LR_MODEL_PER_RES[res]
    mapping_lr = globals.LR_MAPPING_NETWORK_PER_RES[res]

    for param_group in adam_g.param_groups:
        if param_group.get("name") == "mapping":
            param_group["lr"] = mapping_lr
        else:
            param_group["lr"] = g_lr

    for param_group in adam_d.param_groups:
        param_group["lr"] = d_lr

    print(f"RESOLUTION {res}x{res}:")
    if res > 4:
        for model in progressive_models:
            model.fade_in(res)

    # The dataset is repeated 3x, so one pass over the loader covers 3 copies of it.
    repeated_dataset = ConcatDataset([datasets[res]] * 3)
    loader = torch.utils.data.DataLoader(
        repeated_dataset,
        batch_size=globals.BATCH_SIZES_PER_RES[res],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    imgs_budget = globals.IMAGES_PER_RESOLUTION[res]
    fade_in_imgs = int(imgs_budget * globals.FADE_IN_PERCENTAGE)
    count_until_fid = FID_EVERY_IMGS
    count_until_grid = GRID_EVERY_IMGS
    imgs_this_phase = 0
    discriminator_steps = 0

    while imgs_this_phase < imgs_budget:
        print(f"Images this phase: {imgs_this_phase}/{imgs_budget}")
        for real in tqdm(loader):
            real = real.to(globals.DEVICE)

            batch_size = real.size(0)

            # Opacity of the newest layers ramps linearly from 0 to 1 over the
            # first fade_in_imgs images of the phase, then stays at 1.
            if imgs_this_phase < fade_in_imgs:
                layer_opacity = min(1.0, imgs_this_phase / max(1, fade_in_imgs))
            else:
                layer_opacity = 1.0
            for model in progressive_models:
                model.set_layer_opacity(layer_opacity)

            # Discriminator update(s); the original setup uses a single D step for
            # the logistic loss with R1. discriminator_steps is passed to the loss,
            # presumably to schedule the (lazy) R1 penalty — TODO confirm.
            for _ in range(globals.DISCRIMINATOR_STEPS):
                adam_d.zero_grad(set_to_none=True)
                discriminator_steps += 1
                z = torch.randn(batch_size, globals.Z_DIM, device=globals.DEVICE)
                with torch.no_grad():
                    fake = G(z)

                D_loss = d_loss_non_saturating_r1(
                    D, real, fake.detach(), discriminator_steps, ada
                )
                D_loss.backward()
                adam_d.step()

            # Generator step, followed by the EMA update of G_EMA.
            z = torch.randn(batch_size, globals.Z_DIM, device=globals.DEVICE)
            adam_g.zero_grad(set_to_none=True)

            fake = G(z)
            G_loss = g_loss_non_saturating(D, fake, ada)
            G_loss.backward()
            adam_g.step()
            apply_exponential_moving_average(G, G_EMA)

            imgs_this_phase += batch_size
            global_img_count += batch_size
            count_until_fid -= batch_size
            count_until_grid -= batch_size
            if count_until_fid <= 0:
                # "+=" keeps evaluations on exact multiples of FID_EVERY_IMGS
                # instead of drifting by each batch's overshoot.
                count_until_fid += FID_EVERY_IMGS
                percent_this_phase = 100 * imgs_this_phase / imgs_budget
                print("Calculating FID...")
                fid_score = utils.compute_fid(
                    G, G_EMA, datasets[res], res, percent_this_phase, fid
                )

            # The FID check above always fires before this one (its interval is
            # half as long), so fid_score is defined here.
            if count_until_grid <= 0:
                print("Generating Image Grid...")
                utils.generate_grid_image(G, fid_score["G"], res, "training_imgs")
                count_until_grid += GRID_EVERY_IMGS

            # Stop as soon as this phase's image budget is spent instead of
            # finishing the whole (3x repeated) epoch, which could overshoot
            # IMAGES_PER_RESOLUTION[res] by up to three dataset passes.
            if imgs_this_phase >= imgs_budget:
                break

    for model in progressive_models:
        model.set_layer_opacity(1.0)

# Save Model Weights

Run the cell below to save the current state of the model, i.e. the Generator, Discriminator and EMA Generator weights as well as the optimizer states.

In [None]:
YOUR_MODEL_NAME = "Write_a_name_for_your_model_here"

# Collect the model weights (G, D and the EMA generator) together with both
# optimizer states into a single dict, then write it to weights/<name>.pth.
save_payload = dict(
    G_state_dict=G.state_dict(),
    D_state_dict=D.state_dict(),
    G_EMA_state_dict=G_EMA.state_dict(),
    G_optimizer=adam_g.state_dict(),
    D_optimizer=adam_d.state_dict(),
)
save_path = f"weights/{YOUR_MODEL_NAME}.pth"

os.makedirs("weights", exist_ok=True)
torch.save(save_payload, save_path)

# Load Saved EMA Generator

Run the cell below to load a pretrained EMA Generator. You can find the weights file of my StyleGAN in the `Impressionist Artworks v1.0` GitHub Release; place it at `weights/ada_stylegan_64_more_channels.pth`.
After it has been loaded, we generate some images with it.
The result will be saved in the directory `final_model_imgs`.

In [None]:
FINAL_MODEL_RESOLUTION = 64
FINAL_MODEL_FID_SCORE = 40.38

G_EMA = generator.Generator().to(globals.DEVICE)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
# (and vice versa). weights_only=True restricts unpickling to tensors and plain
# containers, which matters because this file is downloaded from the internet.
checkpoint = torch.load(
    "weights/ada_stylegan_64_more_channels.pth",
    map_location=globals.DEVICE,
    weights_only=True,
)
G_EMA.load_state_dict(checkpoint["G_EMA_state_dict"])
G_EMA.fade_in(FINAL_MODEL_RESOLUTION)
G_EMA.set_layer_opacity(1.0)
# Inference only: keep the EMA generator in eval mode, as during training.
G_EMA.train(False)

utils.generate_grid_image(
    G_EMA, FINAL_MODEL_FID_SCORE, FINAL_MODEL_RESOLUTION, "final_model_imgs"
)