In [59]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch
from torchsummary import summary


from scripts.models import *
from scripts.utility import *

from pl_bolts.models.gans import GAN
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler

In [2]:
%load_ext autoreload
%autoreload 2

In [108]:
class MNISTDataModule(pl.LightningModule):
    def __init__(self, img_size=(32, 32), data_dir = "./data/mnist", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.has_setup_fit= True
        self.dataset = None
        
    def setup(self, state=None):
        self.dataset = datasets.MNIST(self.data_dir,
                    train=True, download=True, 
                    transform=transforms.Compose([transforms.Resize(self.img_size), transforms.ToTensor(),
                                                 transforms.Normalize([0.5], [0.5])]))
    
    def train_dataloader(self):
        dataloader = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size, shuffle=True)
        return dataloader
    
    def tets_dataloader(self):
        pass
        

class BasicGAN(GAN):
    """BasicGAN. Input parameters are same as GAN class."""
    def __init__(
        self,
        input_channels: int = 1,
        input_height: int = 32,
        input_width: int = 32,
        latent_dim: int = 100,
        learning_rate: float = 0.0002,
        **kwargs
    ):
        """
        Constructor.
        
        Args:
            input_channels: number of channels of an image
            input_height: image height
            input_width: image width
            latent_dim: emb dim for encoder
            learning_rate: the learning rate
        """
        super(GAN, self).__init__()
        # makes self.hparams under the hood and saves to ckpt
        self.save_hyperparameters()
        self.img_dim = (input_channels, input_height, input_width)
        
        # networks
        self.generator = self.init_generator(self.img_dim).cuda()
        self.discriminator = self.init_discriminator(self.img_dim).cuda()
        
    def init_generator(self, img_dim):
        generator = BasicGenerator(img_dim, self.hparams.latent_dim)
        generator.apply(weights_init_normal)
        return generator
    
    def init_discriminator(self, img_dim):
        discriminator = BasicDiscriminator(img_dim)
        discriminator.apply(weights_init_normal)
        return discriminator
    
    

In [111]:
model = BasicGAN()
mnist = MNISTDataModule()
mnist.setup()

In [112]:
callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
trainer = pl.Trainer(default_root_dir="models", gpus=1, callbacks=callbacks, progress_bar_refresh_rate=20)
trainer.fit(model, datamodule=mnist)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | generator     | BasicGenerator     | 1.0 M 
1 | discriminator | BasicDiscriminator | 97.9 K
-----------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.591     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




RuntimeError: Expected 4-dimensional input for 4-dimensional weight [16, 1, 3, 3], but got 2-dimensional input of size [64, 1024] instead

In [92]:
x, _ = next(iter(mnist.train_dataloader()))

In [93]:
x.shape

torch.Size([64, 1, 32, 32])

In [113]:
model.hparams

"input_channels": 1
"input_height":   32
"input_width":    32
"latent_dim":     100
"learning_rate":  0.0002

In [115]:
model.__dict__.keys()

dict_keys(['training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_hooks', '_forward_hooks', '_forward_pre_hooks', '_state_dict_hooks', '_load_state_dict_pre_hooks', '_modules', '_dtype', '_device', 'exp_save_path', 'loaded_optimizer_states_dict', 'trainer', '_distrib_type', '_device_type', 'use_amp', 'precision', '_example_input_array', '_datamodule', '_results', '_current_fx_name', '_running_manual_backward', '_current_hook_fx_name', '_current_dataloader_idx', 'running_stage', '_automatic_optimization', '_param_requires_grad_state', '_hparams_name', '_hparams', '_hparams_initial', 'img_dim', 'testing', 'train_dataloader'])