# Training a Generative Adversarial Network (GAN)

### Imports

In [None]:
import matplotlib.pyplot as plt
import torch

from training import utils
from training import data
from training.generator import Generator
from training.discriminator import Discriminator
from training.training import Trainer
from training.settings import *

# NOTE(review): presumably re-seeds the random number generators so runs are
# reproducible — confirm against training.utils.reset_rand.
utils.reset_rand()

### Check GPU

In [None]:
# NOTE(review): presumably reports GPU availability before tensors and models are
# placed on DEVICE in the cells below — confirm against training.utils.check_gpu.
utils.check_gpu()

### Dataset

In [None]:
# Build the dataset object that the preview cell and the Trainer both consume.
dataset = data.import_dataset()

# Report how many samples were loaded, with thousands separators for readability.
dataset_size = dataset.size()
print(f'Dataset size: {dataset_size:,}')

In [None]:
# Pull a single batch onto the configured device purely to eyeball the data.
example_batch = dataset.next().to(DEVICE, non_blocking=True)
batch_shape = tuple(example_batch.shape)
print(f'Batch shape: {batch_shape}')

# Render the batch as one image grid; binding the result to `_` keeps the
# notebook from echoing the returned image object.
plt.figure(figsize=(5, 5))
plt.axis('off')
_ = plt.imshow(utils.create_grid(example_batch))

# Release the reference so the memory held by this preview batch can be reclaimed.
del example_batch

### Models

In [None]:
# Instantiate both networks directly on the configured device.
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# Print an architecture summary for each network; the leading newline on the
# second header separates the two summaries visually.
for header, net in (('Generator:', generator), ('\nDiscriminator:', discriminator)):
    print(header)
    net.summary()

### Training

In [None]:
# Autograd anomaly detection is a debugging aid. PyTorch documents that it slows
# execution, so it stays off for normal training runs. Flip this to True only when
# tracking down NaNs or a failing backward op.
DETECT_ANOMALY = False

trainer = Trainer(dataset, generator, discriminator)
# NOTE(review): presumably restores state from an earlier saved session if one
# exists — confirm against Trainer.find_previous_session.
trainer.find_previous_session()

# Set the mode explicitly (rather than only enabling it) so a value left over from
# an earlier run in the same kernel is also reset.
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
trainer.train()