In [11]:
from math import log2

import torch
import config
import torch.optim as optim
import ignite.distributed as idist

from train import get_loader
from model import Generator, Discriminator
from utils import gradient_penalty
from ignite.engine import Engine
from ignite.metrics import FID, InceptionScore, RunningAverage, SSIM

In [3]:
gen = Generator(
    config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(idist.device())

critic = Discriminator(
    config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(idist.device())

# initialize optimizers and scalers for FP16 training
opt_gen = optim.Adam(
    gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
)
opt_critic = optim.Adam(
    critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
)

In [4]:
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()



In [22]:
def initialize_fn(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# @trainer.on(Events.STARTED)
def init_weights():
    gen.apply(initialize_fn)
    critic.apply(initialize_fn)

G_losses = []
D_losses = []


@trainer.on(Events.ITERATION_COMPLETED)
def store_losses(engine):
    o = engine.state.output
    G_losses.append(o["Loss_G"])
    D_losses.append(o["Loss_D"])


img_list = []


@trainer.on(Events.EPOCH_COMPLETED)
def store_images(engine):
    with torch.no_grad():
        fake = gen(fixed_noise).cpu()
    img_list.append(fake)

    plt.title("Fake Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(img_list[-1], padding=2, normalize=True).cpu(),
            (1, 2, 0),
        )
    )
    plt.show()

In [19]:
fid_metric = FID(device=idist.device())
is_metric = InceptionScore(
    device=idist.device(), output_transform=lambda x: x[0]
)
ssim_metric = SSIM(data_range=1.0, device=idist.device())


def interpolate(batch):
    arr = []
    for img in batch:
        pil_img = transforms.ToPILImage()(img)
        resized_img = pil_img.resize((299, 299), Image.BILINEAR)
        arr.append(transforms.ToTensor()(resized_img))
    return torch.stack(arr)


def evaluation_step(engine, batch):
    global gReal, gFake
    gReal, _ = batch
    gReal = gReal.to(idist.device())
    with torch.no_grad():
        noise = torch.randn(batch_size, latent_dim, 1, 1, device=idist.device())
        gFake = netG(noise)
        fake = interpolate(gFake)
        real = interpolate(gReal)
        return fake, real


evaluator = Engine(evaluation_step)
fid_metric.attach(evaluator, "fid")
is_metric.attach(evaluator, "is")
# ssim_metric.attach(evaluator, "ssim")

previous_model = None

fid_values = []
is_values = []
ssim_values = []

# @trainer.on(Events.EPOCH_COMPLETED)
def save_model(engine):
    global previous_model, MODEL_PATH
    if engine.state.epoch < save_threshold:
        return
    if fid_values and (fid_values[-1] == min(fid_values[-(save_threshold-1):])):
        print("Saving new model")
        MODEL_PATH = f"./models/{model_name}_n_epoch_{engine.state.epoch}_G_losses_{G_losses[-1]:4f}_D_losses_{D_losses[-1]:4f}_fid_{fid_values[-1]:4f}_is_{is_values[-1]:4f}_ssim_{ssim_values[-1]:4f}.pth"
        torch.save(netG, MODEL_PATH)
        if previous_model:
            print("Removing previous model")
            os.remove(previous_model)
        previous_model = MODEL_PATH


# @trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    global gReal, gFake
    evaluator.run(test_dataloader, max_epochs=1)
    metrics = evaluator.state.metrics
    fid_score = metrics["fid"]
    is_score = metrics["is"]
    fid_values.append(fid_score)
    is_values.append(is_score)
    
    ssim_metric = SSIM(data_range=1.0, device=idist.device())
    ssim_metric.update((gFake, gReal))
    ssim_score = float(ssim_metric.compute())
    ssim_values.append(ssim_score)
    
    print(f"Epoch [{engine.state.epoch}/{n_epoch}] Metric Scores")
    print(f"*       FID : {fid_score:4f}")
    print(f"*        IS : {is_score:4f}")
    print(f"*      SSIM : {ssim_score:4f}")
    print(f"*  G_losses : {G_losses[-1]:4f}")
    print(f"*  D_losses : {D_losses[-1]:4f}")
    

Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /Users/mac/.cache/torch/hub/checkpoints/inception_v3_google-1a9a5a14.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

In [20]:
RunningAverage(output_transform=lambda x: x["Loss_G"]).attach(trainer, "Loss_G")
RunningAverage(output_transform=lambda x: x["Loss_D"]).attach(trainer, "Loss_D")

ProgressBar().attach(trainer, metric_names=["Loss_G", "Loss_D"])
ProgressBar().attach(evaluator)

In [23]:
def training_step(engine, batch):
    global alpha, step, dataset
    gen.train()
    critic.train()
    
    real, _ = batch
    
    real = real.to(idist.device())
    noise = torch.randn(batch_size, config.Z_DIM, 1, 1, device=idist.device())
    
    with torch.cuda.amp.autocast():
        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(
            critic, real, fake, alpha, step, device=idist.device()
        )
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + config.LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

    opt_critic.zero_grad()
    scaler_critic.scale(loss_critic).backward()
    scaler_critic.step(opt_critic)
    scaler_critic.update()

    # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
    with torch.cuda.amp.autocast():
        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

    opt_gen.zero_grad()
    scaler_gen.scale(loss_gen).backward()
    scaler_gen.step(opt_gen)
    scaler_gen.update()

    # Update alpha and ensure less than 1
    alpha += batch_size / (
        (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
    )
    alpha = min(alpha, 1)

    
trainer = Engine(training_step)

def training(*args):
    global alpha, step, dataset
    start_step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
    for step in range(start_step, 7):
        alpha = 1e-5
        image_size = 4 * 2 ** step
        loader, dataset = get_loader(image_size)
        trainer.run(loader, max_epochs=config.PROGRESSIVE_EPOCHS[step])

with idist.Parallel(backend="nccl") as parallel:
    parallel.run(training)

ValueError: Unknown backend 'nccl'. Available backends: ('gloo',)