In [1]:
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision.datasets as dsets
import torchvision.transforms as T

from ignite.engine import Engine, Events
from ignite.metrics import Average

import matplotlib.pyplot as plt

from modules import Generator, AutoEncoder
from models import EnergyBasedGAN

ModuleNotFoundError: No module named 'models.energy_based_gam'

# Data

## Dataloader

In [None]:
# CIFAR-10 training split, stored in (and downloaded to) the working directory;
# ToTensor is applied to every image as it is loaded.
train_data = dsets.CIFAR10(
    root="./",
    train=True,
    download=True,
    transform=T.ToTensor(),
)
# Shuffled mini-batches of 256 samples for the training engine.
train_loader = DataLoader(dataset=train_data, batch_size=256, shuffle=True)

## Data arguments

In [None]:
# Passed as `input_size` to both Generator and AutoEncoder below.
# Presumably the image side length in pixels (CIFAR-10 images are 32x32) —
# confirm against the `modules` implementations.
input_size: int = 32

# Model

## Model arguments

In [None]:
# Hyper-parameters shared by the generator and the auto-encoder discriminator.
latent_dim: int = 100      # forwarded as `latent_dim` to both networks
hidden_channel: int = 128  # forwarded as `hidden_channel` to both networks

# Forwarded as `last_act` to the Generator only.
# NOTE(review): presumably the name of its final activation; a sigmoid would
# keep outputs in [0, 1], the range ToTensor produces — confirm in `modules`.
G_last_act: str = "sigmoid"

## Make model

In [None]:
# Keyword arguments that both networks receive unchanged.
shared_net_kwargs = dict(
    input_size=input_size,
    latent_dim=latent_dim,
    hidden_channel=hidden_channel,
)

# The generator is built first, then the auto-encoder that serves as the
# discriminator; only the generator takes the `last_act` option.
generator = Generator(**shared_net_kwargs, last_act=G_last_act)
discriminator = AutoEncoder(**shared_net_kwargs)

In [None]:
# Both networks are optimised with Adam at the same learning rate.
learning_rate = 2e-4
generator_opt = torch.optim.Adam(generator.parameters(), lr=learning_rate)
discriminator_opt = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

# Bundle networks and optimizers into the EBGAN wrapper.
# `margin` is forwarded as-is; its exact use is defined in `models.EnergyBasedGAN`.
model = EnergyBasedGAN(
    generator=generator,
    generator_opt=generator_opt,
    discriminator=discriminator,
    discriminator_opt=discriminator_opt,
    margin=20,
)

In [None]:
# Move the model to a GPU when CUDA is available; otherwise it stays on the CPU.
if torch.cuda.is_available():
    # Keep the original preference for GPU 1, but fall back to GPU 0 on
    # single-GPU machines, where index 1 would be an invalid device ordinal.
    cuda_index = 1 if torch.cuda.device_count() > 1 else 0
    # The return value is not needed; statements inside an `if` body are not
    # echoed by the notebook, so there is no output to suppress.
    model.cuda(cuda_index)

# Trainer

## Set Ignite Engine

In [None]:
# The engine calls the model's `fit_batch` once per iteration as its
# process function.
process_batch = model.fit_batch
trainer = Engine(process_batch)

## Set metrics

In [None]:
from utils import output_transform

In [None]:
# Attach one Average metric per loss, registered on the trainer under the
# loss's own name. `output_transform` is bound to that name via `key=`
# (presumably to pick the matching entry out of each step's output — see `utils`).
for loss_name in ("G_loss", "D_loss"):
    pick_loss = partial(output_transform, key=loss_name)
    Average(output_transform=pick_loss).attach(trainer, loss_name)

## Add event handlers

In [None]:
from utils.event_handlers import log_metric, print_img, print_metric

In [None]:
# Register the three end-of-epoch handlers, in the same order as before.
# `print_img` additionally receives the model as an extra handler argument.
on_epoch_end = Events.EPOCH_COMPLETED
trainer.add_event_handler(on_epoch_end, log_metric)
trainer.add_event_handler(on_epoch_end(every=1), print_metric)
trainer.add_event_handler(on_epoch_end(every=1), print_img, model)

## Run

In [None]:
# Train for 50 epochs over the CIFAR-10 loader.
trainer.run(data=train_loader, max_epochs=50)