In [None]:
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision.datasets as dsets
import torchvision.transforms as T
from torchvision.utils import make_grid

from ignite.engine import Engine, Events
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.metrics import Average

import matplotlib.pyplot as plt

from modules import Generator, Discriminator
from models import DeepConvolutionGAN
from loss import bce_loss

# Data

## Dataloader

In [None]:
# CIFAR-10 training split, downloaded into the working directory on first run.
# ToTensor() converts each image to a float tensor with values in [0, 1].
BATCH_SIZE = 256
train_data = dsets.CIFAR10(root="./", train=True, download=True, transform=T.ToTensor())
train_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=BATCH_SIZE)

## Data arguments

In [None]:
input_size: int = 32  # CIFAR-10 images are 32x32; shared by Generator and Discriminator below

# Model

## Generator arguments

In [None]:
latent_dim: int = 100          # length of each noise vector z sampled for the generator
G_hidden_channel: int = 128    # passed as Generator(hidden_channel=...)
G_last_act: str = "sigmoid"    # passed as Generator(last_act=...); training images are in [0, 1]

## Discriminator arguments

In [None]:
D_hidden_channel: int = 128    # passed as Discriminator(hidden_channel=...)
D_last_act: str = "sigmoid"    # passed as Discriminator(last_act=...); presumably yields probabilities for bce_loss
loss_fn = bce_loss             # adversarial loss handed to DeepConvolutionGAN

## Make model

In [None]:
# Build both networks for `input_size` x `input_size` images using the arguments above.
generator = Generator(
    input_size=input_size,
    latent_dim=latent_dim,
    hidden_channel=G_hidden_channel,
    last_act=G_last_act,
)
discriminator = Discriminator(
    input_size=input_size,
    hidden_channel=D_hidden_channel,
    last_act=D_last_act,
)

In [None]:
# PyTorch's Module.cuda() documentation says a module should be moved to the GPU
# *before* its optimizer is constructed. Put both networks on the target device
# first; this is a no-op when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# Both networks use Adam with lr=1e-4 and PyTorch's default betas (0.9, 0.999).
# NOTE(review): the DCGAN paper used lr=2e-4, betas=(0.5, 0.999); consider those if training is unstable.
model = DeepConvolutionGAN(
    generator=generator,
    generator_opt=torch.optim.Adam(generator.parameters(), 1e-4),
    discriminator=discriminator,
    discriminator_opt=torch.optim.Adam(discriminator.parameters(), 1e-4),
    loss_fn=loss_fn,
)

In [None]:
# Move the GAN to the GPU when one is available; otherwise the model stays on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    _ = model.cuda()

# Trainer

## Set Ignite Engine

In [None]:
# Ignite calls the process function as fit_batch(engine, batch) once per batch;
# its return value becomes engine.state.output, which the metrics' output_transform reads.
train_step = model.fit_batch
trainer = Engine(train_step)

## Set metrics

In [None]:
def output_transform(output, key):
    """Select one entry from the per-batch output of ``model.fit_batch``.

    It is bound with ``functools.partial`` below, so each ``Average`` metric
    only sees the loss it is responsible for.
    """
    selected = output[key]
    return selected


# One epoch-wise running mean per loss. Ignite writes the result to
# engine.state.metrics[key] when the epoch completes.
for key in ("G_loss", "D_loss"):
    pick_loss = partial(output_transform, key=key)
    average = Average(output_transform=pick_loss)
    average.attach(trainer, name=key)

## Save history

In [None]:
# One metrics snapshot (dict) per completed epoch, in epoch order.
trainer.history = []


@trainer.on(Events.EPOCH_COMPLETED)
def log_metric(engine):
    """Append a snapshot of this epoch's metrics to ``engine.history``.

    Ignite keeps a single ``engine.state.metrics`` dict and overwrites its
    entries at the end of every epoch. Appending the dict itself would make
    every history entry alias the same object, so all of them would show the
    final epoch's values. A shallow copy freezes this epoch's values.

    Ignite runs EPOCH_COMPLETED handlers in registration order. The Average
    metrics are attached in an earlier cell, so they have already been
    computed by the time this runs.
    """
    engine.history.append(dict(engine.state.metrics))

## Print progress: metrics and generated samples every 5 epochs

In [None]:
@trainer.on(Events.EPOCH_COMPLETED(every=5))
def print_every(engine, num_img=64):
    """Every 5th epoch, print the epoch's metrics and show a grid of generated samples.

    Args:
        engine: the Ignite trainer that fired EPOCH_COMPLETED.
        num_img: number of latent vectors to sample and display.
    """
    import math

    # print metrics
    state = f"Epoch {engine.state.epoch} - "
    for key, value in engine.state.metrics.items():
        state += f"{key}: {value:.4f}, "
    print(state)

    # Images per grid row: ceil(sqrt(n)) gives a near-square layout
    # (8 per row for the default 64). make_grid works out the number of rows.
    nrow = math.ceil(math.sqrt(num_img))

    generator = model.generator
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            device = next(generator.parameters()).device
            z = torch.randn((num_img, generator.latent_dim)).to(device)
            fake_data = generator(z).cpu()
    finally:
        # Restore the previous mode. Left in eval(), the generator would run any
        # BatchNorm/Dropout layers in inference mode for the rest of training
        # (unless fit_batch re-enables train mode itself, which is not visible here).
        generator.train(was_training)

    fake_data_grid = make_grid(fake_data, nrow=nrow)
    fig, ax = plt.subplots()
    ax.imshow(fake_data_grid.permute(1, 2, 0).numpy())
    ax.set_title(f"Generated samples - epoch {engine.state.epoch}")
    ax.axis("off")
    plt.show()

In [None]:
# Train for 50 epochs over the CIFAR-10 loader; Ignite returns (and the cell displays) the final State.
trainer.run(train_loader, max_epochs=50)