# Basic GAN on MNIST with PyTorch Lightning
> Training a simple fully-connected GAN on MNIST with PyTorch Lightning and logging the run to Weights & Biases.

- toc: true 
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png

# Import Packages

## Install

In [None]:
!pip install pytorch-lightning
!pip install wandb
!pip install omegaconf

Collecting pytorch-lightning
  Downloading pytorch_lightning-1.4.1-py3-none-any.whl (915 kB)
[K     |████████████████████████████████| 915 kB 7.7 MB/s 
Collecting torchmetrics>=0.4.0
  Downloading torchmetrics-0.4.1-py3-none-any.whl (234 kB)
[K     |████████████████████████████████| 234 kB 52.5 MB/s 
[?25hCollecting PyYAML>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 48.3 MB/s 
Collecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.7.0-py3-none-any.whl (118 kB)
[K     |████████████████████████████████| 118 kB 53.0 MB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 43.3 MB/s 
[?25hCollecting tensorboard!=2.5.0,>=2.2.0
  Downloading tensorboard-2.4.1-py3-none-any.whl (10.6 MB)
[K     |████████████████████████████████| 10.6 MB 2

## Import

In [None]:
import wandb
import numpy as np
import torch
import torchvision
import pytorch_lightning as pl

from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from omegaconf import OmegaConf
from datetime import datetime

# Set Parameters

In [None]:
# Hyper-parameters grouped by consumer:
#   dataset -> unpacked into MNIST_DataModule
#   lt      -> read by GAN_Module as `params.lt`
#   trainer -> unpacked into pl.Trainer
params = OmegaConf.create(dict(
    dataset=dict(batch_size=256),
    lt=dict(
        latent_dim=100,
        lr=0.0002,
        b1=0.5,   # Adam beta1
        b2=0.999, # Adam beta2
    ),
    trainer=dict(gpus=-1, max_epochs=3),
))

# Data Module

In [None]:
class MNIST_DataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST for the GAN.

    Images are converted to tensors and normalized as ``(x - 0.5) / 0.5``
    (presumably so real images share the generator's Tanh output range).
    The MNIST *test* split (``train=False``) is used as the validation set.

    Args:
        data_dir: directory MNIST is downloaded to and read from.
        batch_size: batch size of the shuffled training loader.
        val_batch_size: batch size of the validation loader. Defaults to 144,
            the value that used to be hard-coded.
    """

    def __init__(self, data_dir='./', batch_size=128, val_batch_size=144):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        # Per-sample shape (C, H, W); presumably what `dm.size()` returns in PL 1.x,
        # which the notebook uses to build the model.
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download both MNIST splits into `data_dir`."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train and validation datasets (`stage` is ignored; both are always built)."""
        self.train_dataset = MNIST(self.data_dir, train=True, transform=self.transform)
        self.val_dataset = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Shuffled loader over the MNIST training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the MNIST test split."""
        return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False)

# LightningModule

## Components

In [1]:
class Generator(nn.Module):
    """MLP generator that maps a latent noise batch to image-shaped tensors.

    The hidden widths are 128 -> 256 -> 512 -> 1024. Every hidden layer except
    the first is batch-normalized. The output passes through Tanh and is
    reshaped to ``(batch, *x_shape)``.

    Args:
        latent_dim: number of input features (size of each noise vector).
        x_shape: per-sample output shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim, x_shape):
        super().__init__()
        self.x_shape = x_shape

        widths = [latent_dim, 128, 256, 512, 1024]
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx > 0:  # the first hidden layer is left un-normalized
                # NOTE(review): the 2nd positional argument of BatchNorm1d is `eps`,
                # so this is eps=0.8 (not momentum). Kept as-is to preserve behavior.
                layers.append(nn.BatchNorm1d(fan_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Linear(widths[-1], int(np.prod(x_shape))))
        layers.append(nn.Tanh())

        # Indices inside `net` matter: GAN_Module reads `net[0].weight`.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from a 2-D noise batch ``(batch, latent_dim)``."""
        flat = self.net(z)
        return flat.view(flat.size(0), *self.x_shape)

NameError: ignored

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator that scores each input sample with a value in (0, 1).

    Args:
        x_shape: per-sample input shape; inputs are flattened to
            ``prod(x_shape)`` features before the first linear layer.
    """

    def __init__(self, x_shape):
        super().__init__()

        dims = (int(np.prod(x_shape)), 512, 256)
        layers = []
        for fan_in, fan_out in zip(dims, dims[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        # Final score squashed by a sigmoid so it can feed binary cross-entropy.
        layers += [nn.Linear(dims[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real/fake scores for the batch ``x``."""
        batch = x.size(0)
        return self.model(x.view(batch, -1))

## Module

In [None]:
class GAN_Module(pl.LightningModule):
    """Vanilla GAN (MLP generator + MLP discriminator) trained with two Adam optimizers.

    Args:
        data_shape: per-sample image shape (the notebook passes ``dm.size()``).
        params: OmegaConf config. Only ``params.lt`` (latent_dim, lr, b1, b2) is read here.
    """

    def __init__(self, data_shape, params):
        super().__init__()

        self.save_hyperparameters()
        self.hp = params.lt

        self.generator = Generator(latent_dim=self.hp.latent_dim, x_shape=data_shape)
        self.discriminator = Discriminator(x_shape=data_shape)

        # Fixed noise, so the images logged at every epoch end are comparable.
        self.validation_z = torch.randn(25, self.hp.latent_dim)

        # Latest generator batch (detached), used only for logging in `on_epoch_end`.
        self.generated_imgs = None

    def forward(self, z):
        """Generate images from the noise batch ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator scores and 0/1 targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one GAN update for the optimizer selected by ``optimizer_idx``.

        Index 0 updates the generator and index 1 updates the discriminator.
        This follows the order of the list returned by `configure_optimizers`.
        """
        imgs, _ = batch
        batch_size = imgs.size(0)

        # sample noise on the same device / dtype as the real images
        z = torch.randn(batch_size, self.hp.latent_dim).type_as(imgs)

        # train generator
        if optimizer_idx == 0:
            # Single generator forward pass: the same batch is scored and (detached) logged.
            # Running the generator twice wasted compute and updated BatchNorm stats twice.
            fake_imgs = self(z)
            self.generated_imgs = fake_imgs.detach()

            # Target is "real" (1): the generator wants the discriminator to misjudge.
            valid = torch.ones(batch_size, 1).type_as(imgs)

            # generator loss
            g_loss = self.adversarial_loss(self.discriminator(fake_imgs), valid)
            self.log('train/g_loss', g_loss)

            return {'loss': g_loss}

        # train discriminator
        if optimizer_idx == 1:

            # real images should be scored as 1
            valid = torch.ones(batch_size, 1).type_as(imgs)
            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # Generated images should be scored as 0. Detach so this step does not
            # back-propagate into the generator.
            fake = torch.zeros(batch_size, 1).type_as(imgs)
            fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

            # discriminator loss
            d_loss = (real_loss + fake_loss) / 2
            self.log('train/d_loss', d_loss)

            return {'loss': d_loss}

    def configure_optimizers(self):
        """Return ``[generator_opt, discriminator_opt]`` (no LR schedulers)."""
        betas = (self.hp.b1, self.hp.b2)
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.hp.lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hp.lr, betas=betas)
        return [opt_g, opt_d], []

    def on_epoch_end(self):
        """Log grids of the latest training samples and of samples from `validation_z`."""
        # Nothing has been generated yet (no generator step ran).
        if self.generated_imgs is None:
            return
        # `self.log` only accepts scalars/tensors/metrics, so images are sent
        # directly to the underlying W&B run.
        if not isinstance(self.logger, pl.loggers.WandbLogger):
            return

        test_grid = torchvision.utils.make_grid(self.generated_imgs[:25], nrow=5)

        # Sample from the fixed noise without building an autograd graph and
        # without letting BatchNorm update its running statistics. Restore the
        # generator's previous train/eval mode afterwards.
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = self.validation_z.type_as(self.generator.net[0].weight)
                val_grid = torchvision.utils.make_grid(self(z), nrow=5)
        finally:
            self.generator.train(was_training)

        self.logger.experiment.log({
            'test_generated_images': wandb.Image(test_grid),
            'val_generated_images': wandb.Image(val_grid),
        })

# Run Training

In [None]:
# Logging setup: one W&B run per execution, named by its start timestamp
wandb.login()
wandb_logger = pl.loggers.WandbLogger(name=datetime.now().strftime('%y%m%d-%H%M%S'),
                                      project='Basic GAN',
                                      tags=['gan', 'notebook'])

# Seed the Python / NumPy / torch RNGs before building anything random (model
# weights, `validation_z`, shuffling). `deterministic=True` alone does not
# make runs reproducible.
SEED = 42
pl.seed_everything(SEED)

# Dataset
dm = MNIST_DataModule(**params.dataset)

# Model
model = GAN_Module(dm.size(), params)

# Training configuration
trainer = pl.Trainer(logger=wandb_logger, deterministic=True, **params.trainer)

# Run training; close the W&B run even if training raises
try:
    trainer.fit(model, dm)
finally:
    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
  rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw




  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 1.5 M 
1 | discriminator | Discriminator | 533 K 
------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.174     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




VBox(children=(Label(value=' 0.33MB of 0.33MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/g_loss,1.12359
train/d_loss,0.54438
epoch,2.0
trainer/global_step,704.0
_runtime,33.0
_timestamp,1626235989.0
_step,16.0


0,1
train/g_loss,▁▁▂▃▃▅▃▃▃▄▄█▁▄
train/d_loss,▅█▆▅▃▄▁▄▃▃▂▂▆▃
epoch,▁▁▁▁▁▅▅▅▅▅▅██████
trainer/global_step,▁▂▂▃▃▃▄▄▅▅▅▆▆▇▇██
_runtime,▁▁▂▃▃▃▄▄▅▅▅▆▆▇▇██
_timestamp,▁▁▂▃▃▃▄▄▅▅▅▆▆▇▇██
_step,▁▁▂▂▃▃▄▄▅▅▅▆▆▇▇██
