In [1]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

import gans
from gans import WGAN_GP, MNISTDataModule

In [None]:
# Train the WGAN-GP on MNIST, logging metrics to TensorBoard under gans_logs/WGAN_GP.
from pytorch_lightning import seed_everything

SEED = 42              # fix python/numpy/torch RNGs so training runs are reproducible
BATCH_SIZE = 128
MAX_EPOCHS = 200
VAL_EVERY_N_EPOCHS = 5
LOG_DIR = "gans_logs"  # keep in sync with the --logdir given to %tensorboard

# Seed before building the model so the initial weights are deterministic too.
seed_everything(SEED)

# define the logger object
logger = TensorBoardLogger(LOG_DIR, name="WGAN_GP")

dm = MNISTDataModule(batch_size=BATCH_SIZE)

model = WGAN_GP()
trainer = Trainer(
    accelerator="auto",
    # Use exactly one GPU when available; limiting to a single device keeps
    # interactive IPython/notebook runs simple.
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=MAX_EPOCHS,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    logger=logger,
    check_val_every_n_epoch=VAL_EVERY_N_EPOCHS,
)
trainer.fit(model, dm)

  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
----------------------------------------
0 | generator | generator | 623 K 
1 | critic    | critic    | 134 K 
----------------------------------------
758 K     Trainable params
0         Non-trainable params
758 K     Total params
3.033     Total estimated model params size (MB)


Epoch 3:  74%|██████████████████████████         | 320/430 [04:10<01:26,  1.28it/s, loss=-1.02, v_num=2, loss/g_loss=0.0127, loss/d_loss=-2.12]

In [None]:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Plot the first ``num_images`` images of ``image_tensor`` in a 5-column grid.

    Parameters
    ----------
    image_tensor : torch.Tensor
        Batch of images. Each image must hold ``prod(size)`` elements, so both
        flat vectors and already-shaped ``(C, H, W)`` images are accepted.
    num_images : int
        Maximum number of images to display.
    size : tuple of int
        Per-image shape ``(C, H, W)`` used to reshape each image before tiling.
    '''
    # Plot the tensor that was passed in (previously this read the global
    # `fake`). Detach and move to CPU because matplotlib converts to numpy,
    # which fails for CUDA tensors or tensors that require grad.
    images = image_tensor.detach().cpu()
    image_unflat = images[:num_images].reshape(-1, *size)
    # NOTE(review): if the generator outputs values in [-1, 1] (e.g. tanh),
    # matplotlib will clip them; rescale with (x + 1) / 2 if so — confirm in gans.py.
    image_grid = make_grid(image_unflat, nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
# Produce a batch of sample images from the trained model for visual inspection.
# NOTE(review): this calls a Lightning hook directly and relies on WGAN_GP's
# on_validation_epoch_end returning the generated images; Lightning's base hook
# returns nothing, so confirm the override in gans.py.
fake = model.on_validation_epoch_end()

In [None]:
# Display a 5x5 grid of generated samples, each reshaped to 1x28x28 (MNIST size).
show_tensor_images(image_tensor=fake, num_images=25, size=(1, 28, 28))

In [6]:
# Start tensorboard.
os.environ['TENSORBOARD_BINARY'] = '/.../anaconda3/envs/pytorch/bin/tensorboard'
%load_ext tensorboard
%tensorboard --logdir gans_logs/ --port 8889 --bind_all