# VAE/GAN 구현

### 외부 파일 가져오기 & requirements 설치


In [1]:
from google.colab import drive
drive.mount("/content/drive")
import os
import sys
from datetime import datetime

drive_project_root = '/content/drive/MyDrive/#fastcampus'
sys.path.append(drive_project_root)
!pip install -r '/content/drive/MyDrive/#fastcampus/requirements.txt'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
gpu_info = !nvidia-smi
gpu_info = "\n".join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Select the Runtime > "change runtime type" menu to enable a GPU accelerator, ')
    print('and then re-execurte this cell.')
else:
    print(gpu_info)

Thu Dec 30 08:31:12 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   71C    P8    32W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
from typing import List
from typing import Dict
from typing import Union
from typing import Any
from typing import Optional
from typing import Iterable
from typing import Callable
from abc import abstractmethod
from abc import ABC
from datetime import datetime
from functools import partial
from collections import Counter
from collections import OrderedDict
import random
import math

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# data & models
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torchvision.utils as vutils

# For configuration
from omegaconf import DictConfig
from omegaconf import OmegaConf
import hydra
from hydra.core.config_store import ConfigStore

# For logger
from torch.utils.tensorboard import SummaryWriter
import wandb
os.environ["WANDB_START_METHOD"]="thread"

In [4]:
from data_utils import dataset_split
from config_utils import flatten_dict
from config_utils import register_config
from config_utils import configure_optimizers_from_cfg
from config_utils import configure_optimizer_element
from config_utils import get_loggers
from config_utils import get_callbacks
from custom_math import softmax

## Base 모델 정의

In [12]:
# Define model
class BaseGenerativeModel(pl.LightningModule):
    def __init__(self, cfg: DictConfig):
        pl.LightningModule.__init__(self)
        self.cfg = cfg

    @abstractmethod
    def forward(self ,x):
        raise NotImplementedError()

    def sample_generate(self, *args, **kwargs):
        raise NotImplementedError()

    def loss_function(self, *args, **kwargs):
        pass

    def configure_optimizers(self):
        self._optimizers, self._scheduler = configure_optimizers_from_cfg(
            self.cfg, self
        )
        return self._optimizers, self._scheduler

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

# Generative module 정의
class Generator(nn.Module):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.latent_dim = cfg.model.latent_dim

        def mlp_module(
            in_feat: int,
            out_feat: int,
            normalize: bool = True,
            activation: str = ""
        ):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            if activation == "LeakyReLU":
                layers.append(nn.LeakyReLU(inplace=True))
            elif activation == "Tanh":
                layers.append(nn.Tanh())
            elif activation == "Sigmoid":
                layers.append(nn.Sigmoid())
            else:
                raise NotImplementedError()
            return layers
        
        layers: List[nn.Module] = []
        for modules_cfg in cfg.model.generator.mlp_modules:
            layers.extend(mlp_module(**modules_cfg))
        self.model = nn.Sequential(*layers)
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(
            img.size(0),
            self.cfg.data.C,
            self.cfg.data.H,
            self.cfg.data.W,
        )
        return img

# Discriminator module 정의
class Discriminator(nn.Module):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg

        def mlp_module(
            in_feat: int,
            out_feat: int,
            normalize: bool = True,
            activation: str = ""
        ):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            if activation == "LeakyReLU":
                layers.append(nn.LeakyReLU(inplace=True))
            elif activation == "Tanh":
                layers.append(nn.Tanh())
            elif activation == "Sigmoid":
                layers.append(nn.Sigmoid())
            else:
                raise NotImplementedError()
            return layers

        layers: List[nn.Module] = []
        for modules_cfg in cfg.model.discriminator.mlp_modules:
            layers.extend(mlp_module(**modules_cfg))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        img_flattened = img.view(img.size(0), -1)
        validity = self.model(img_flattened)
        return validity
        
# VanillaGAN 정의
class VanillaGAN(BaseGenerativeModel):
    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)
        self.latent_dim = cfg.model.latent_dim

        self.generator = Generator(cfg)
        self.discriminator = Discriminator(cfg)

    def forward(self, z):
        return self.generator(z)

    def loss_function(self, y_hat, y):
        # adversarial loss
        # y_hat: from discriminator
        # y: label for discriminator
        return F.binary_cross_entropy(y_hat, y)

    def configure_optimizers(self):
        self._opt_g, self._scheduler_g = configure_optimizer_element(
            self.cfg.opt.generator.optimizer,
            self.cfg.opt.generator.lr_scheduler,
            self.generator
        )

        self._opt_d, self._scheduler_d = configure_optimizer_element(
            self.cfg.opt.discriminator.optimizer,
            self.cfg.opt.discriminator.lr_scheduler,
            self.discriminator
        )  
        optimizers = [self._opt_d, self._opt_g]
        schedulers = [self._scheduler_d, self._scheduler_g]
        schedulers = [sch for sch in schedulers if sch is not None]
        return optimizers, schedulers

    def _step(self, batch, batch_idx, optimizer_idx, mode="train"):
        # optimizer_idx를 넣는이유: loss를 따로 구하기 위해서
        assert mode in ["train", "val", "test"]
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], self.cfg.model.latent_dim)
        z = z.type_as(imgs)
        
        # ground truth_result
        valid = torch.ones(imgs.size(0), 1) # true
        valid = valid.type_as(imgs).to(self.device)

        # discriminator
        if optimizer_idx == 0:

            # fake loss
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs).to(self.device)
            
            fake_loss = self.loss_function(
                self.discriminator(self(z)), fake
            )

            # true loss
            real_loss = self.loss_function(
                self.discriminator(imgs), valid
            )

            # discriminator loss
            d_loss = (real_loss + fake_loss) / 2

            outputs = {
                f"{mode}_loss": d_loss,
                f"{mode}_d_loss":d_loss,
                f"{mode}_real_loss": real_loss,
                f"{mode}_fake_loss": fake_loss,
            }
            if mode == "train":
                outputs["loss"]=d_loss
            self.log_dict(outputs)

        # generator
        if optimizer_idx == 1:
            self.generated_imgs = self(z)
            
            # generative loss
            g_loss = self.loss_function(
                self.discriminator(self(z)), valid
            )

            outputs = {
                f"{mode}_loss": g_loss,
                f"{mode}_g_loss":g_loss,
            }
            if mode == "train":
                outputs["loss"]=g_loss
            else: # val, test
                sample_imgs = self.generated_imgs
                grid_samples = vutils.make_grid(sample_imgs)
                grid_imgs = vutils.make_grid(imgs)

                self.logger.experiment[0].log(
                    {
                        f"{mode}_generated_images": wandb.Image(grid_samples),
                        f"{mode}_orig_images": wandb.Image(grid_imgs)
                    }
                )
                self.logger.experiment[1].add_images(
                    f"{mode}_generated_images",
                    sample_imgs,
                    0,
                )
            self.log_dict(outputs)

        return outputs

    def training_step(self, batch, batch_idx, optimizer_idx):
        return self._step(
            batch, batch_idx, optimizer_idx, mode="train"
        )

    def validation_step(self, batch, batch_idx):
        self._step(
            batch, batch_idx, 0, mode="val"
        )
        self._step(
            batch, batch_idx, 1, mode="val"
        )

    def test_step(self, batch, batch_idx):
        self._step(
            batch, batch_idx, 0, mode="test"
        )
        self._step(
            batch, batch_idx, 1, mode="test"
        )

    def _on_epoch_end(self, mode):
        # random noise test
        assert mode in ["train", "val", "test"]
        self.validation_z = torch.randn(
            cfg.train.test_batch_size, self.latent_dim
        )
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid_samples = vutils.make_grid(sample_imgs)
        self.logger.experiment[0].log(
            {
                f"{mode}_generated_images": wandb.Image(grid_samples),
            }
        )

    def on_train_epoch_end(self, unused: Optional=None):
        self._on_epoch_end("train")

    def on_validation_epoch_end(self):
        self._on_epoch_end("val")

    def on_test_epoch_end(self,):
        self._on_epoch_end("test")

In [6]:
class VanillaVAE(BaseGenerativeModel):
    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)
        self.latent_dim = cfg.model.latent_dim

        # define posterior (encoder) modules
        posterior_mlp_modules_list = []
        prev_dim = cfg.model.posterior.hidden_dims[0]

        for h_dim in cfg.model.posterior.hidden_dims[1:]:
            posterior_mlp_modules_list.append(nn.Linear(prev_dim, h_dim))
            prev_dim = h_dim

        self.posterior_mlp_modules = nn.Sequential(
            *posterior_mlp_modules_list
            )

        # define latent encode
        self.posterior_mu = nn.Linear(
            cfg.model.posterior.hidden_dims[-1], self.latent_dim
        )
        self.posterior_log_var = nn.Linear(
            cfg.model.posterior.hidden_dims[-1], self.latent_dim
        )

        # define prior (decoder) modules
        prior_mlp_modules_list = []
        prev_dim = self.latent_dim

        for h_dim in cfg.model.prior.hidden_dims:
            prior_mlp_modules_list.append(nn.Linear(prev_dim, h_dim))
            prev_dim = h_dim

        self.prior_mlp_modules = nn.Sequential(
            *prior_mlp_modules_list
            )
    def encode(self, input):
        hidden = self.posterior_mlp_modules(input)
        mu = self.posterior_mu(hidden)
        log_var = self.posterior_log_var(hidden)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mu + (epsilon * std)

    def decode(self, z):
        result = self.prior_mlp_modules(z)
        return torch.sigmoid(result)

    def forward(self, x):
        input = x.view(-1, self.cfg.model.posterior.hidden_dims[0])
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var # return recons, mu, log_var 

    def sample_generate(
        self,
        num_samples: int = 64, # for training (=batch size) or random sampling
        z: Optional[torch.Tensor] = None, # for manual sample generation
    ):
        if z is None:
            z = torch.randn(num_samples, self.latent_dim)
        else:
            num_samples = z.shape[0]
        assert z.shape[-1] == self.latent_dim
        z = z.to(self.device)
        samples = self.decode(z)
        return samples.view(num_samples, self.cfg.data.C, self.cfg.data.H, self.cfg.data.W)
        

    def loss_function(
        self,
        recons,
        real_img,
        mu,
        log_var,
        kld_weight,
        mode="train"
    ) -> dict:
        assert mode in ["train", "val", "test"]
        
        # reconstruction loss
        recons_loss = F.binary_cross_entropy(
            recons,
            real_img.view(-1, self.cfg.model.prior.hidden_dims[-1]),
            reduction="sum"
        )
        # kld loss
        kld_loss = 0.5 * torch.sum(
            mu.pow(2) + log_var.exp() - log_var - 1
        )

        # summation
        loss = recons_loss + kld_weight * kld_loss
        loss_result = {
            f"{mode}_loss": loss / self.cfg.data[f"num_{mode}_imgs"],
            f"{mode}_reconstruction_loss": recons_loss / self.cfg.data[f"num_{mode}_imgs"],
            f"{mode}_kld_loss": kld_loss / self.cfg.data[f"num_{mode}_imgs"]
        }

        return loss_result

    def training_step(self, batch, batch_idx):
        real_img, labels = batch
        recons, mu, log_var = self.forward(real_img)
        loss_results = self.loss_function(
            recons,
            real_img,
            mu,
            log_var,
            kld_weight=self.cfg.model.kld_weight,
            mode="train"
        )
        loss_results["loss"] = loss_results["train_loss"]
        self.log_dict(loss_results)
        return loss_results

    def validation_step(self, batch, batch_idx):
        real_img, labels = batch
        recons, mu, log_var = self.forward(real_img)
        loss_results = self.loss_function(
            recons,
            real_img,
            mu,
            log_var,
            kld_weight=self.cfg.model.kld_weight,
            mode="val"
        )
        self.log_dict(loss_results)

        # random sample_generate
        sample_gens = self.sample_generate(real_img.shape[0])

        self.logger.experiment[0].log({
            "inputs": wandb.Image(real_img),
            "recons": wandb.Image(recons.view(-1, self.cfg.data.C, self.cfg.data.H, self.cfg.data.W)),
            "sample_gens": wandb.Image(sample_gens)
        })

        return loss_results
    

## Configuration Def

In [7]:
# data configs
data_fashion_mnist_cfg = {
    "name": 'fashion_mnist',
    'data_root': os.path.join(os.getcwd(), "data"),
    "transforms": [
        {
            "name": "ToTensor",
            "kwargs": {}
        }
    ],
    "W": 28,
    "H": 28,
    "C": 1,
    "n_class": 10,
}

# model configs
model_mnist_vanilla_vae_cfg = {
    "name": "VanillaVAE",
    "latent_dim": 4,
    "posterior": {
        "hidden_dims": [28*28, 512, 256],
    }, # encoder
    "prior": {
        "hidden_dims": [256, 512, 28*28]
    }, # decoder
    "kld_weight": 1,
}
model_mnist_vanilla_gan_cfg = {
    "name": "VanillaGAN",
    "latent_dim": 128,
    "generator": {
        "mlp_modules": [
            {
                "in_feat": 128,
                "out_feat": 256,
                "normalize": False,
                "activation": "LeakyReLU"
            },
            {
                "in_feat": 256,
                "out_feat": 512,
                "normalize": True,
                "activation": "LeakyReLU"
            },
            {
                "in_feat": 512,
                "out_feat": 1024,
                "normalize": True,
                "activation": "LeakyReLU"
            },
            {
                "in_feat": 1024,
                "out_feat": 1*28*28, # C x H x W
                "normalize": False,
                "activation": "Tanh"
            },
        ]
    },
    "discriminator": {
        "mlp_modules": [
            {
                "in_feat": 1*28*28, # C x H x W
                "out_feat": 512,
                "normalize": True,
                "activation": "LeakyReLU"
            },
            {
                "in_feat": 512,
                "out_feat": 256,
                "normalize": True,
                "activation": "LeakyReLU"
            },
            {
                "in_feat": 256,
                "out_feat": 1,
                "normalize": False,
                "activation": "Sigmoid"
            },
        ]
    },
}

# opt_config
opt_cfg = {
    "optimizers": [
        {
            "name": "RAdam",
            "kwargs": {
                "lr": 1e-3,
            }
        }
    ],
    "lr_schedulers": [
        {
        "name": None,
        "kwargs": {
            "warmup_end_steps": 1000
        }
    }
    ]
}
gan_opt_cfg = {
    "discriminator": {
        "optimizer": {
            "name": "RAdam",
            "kwargs": {
                "lr": 0.0002,
            }
        },
        "lr_scheduler": {
            "name": None,
        }
    },
    "generator": {
        "optimizer": {
            "name": "RAdam",
            "kwargs": {
                "lr": 0.0002,
            }
        },
        "lr_scheduler": {
            "name": None,
        }
    }
}

_merged_cfg_presets = {
    "vanilla_vae_fashion_mnist": {
        "opt": opt_cfg,
        "data": data_fashion_mnist_cfg,
        "model": model_mnist_vanilla_vae_cfg
    },
    "vanilla_gan_fashion_mnist": {
        "opt": gan_opt_cfg,
        "data": data_fashion_mnist_cfg,
        "model": model_mnist_vanilla_gan_cfg
    }
}

## hydra composition
# clear hydra instance first
hydra.core.global_hydra.GlobalHydra.instance().clear()

# register preset configs
register_config(_merged_cfg_presets)

# initialize & make config
## select mode here ##
# ...................#
hydra.initialize(config_path=None)
# cfg = hydra.compose('vanilla_vae_fashion_mnist')
cfg = hydra.compose('vanilla_gan_fashion_mnist')

# override some cfg
run_name = f"{datetime.now().isoformat(timespec='seconds')}-{cfg.model.name}-{cfg.data.name}"

# Define other train configs & log_configs
# Merge configs inot one & ergister it to Hydra
project_root_dir = os.path.join(
    drive_project_root, "runs", "generative-dnn-tutorial-fashion-mnist-runs"
)
save_dir = os.path.join(project_root_dir, run_name)
run_root_dir = os.path.join(project_root_dir, run_name)

# train configs
train_cfg = {
    "train_batch_size": 256,
    "val_batch_size": 64,
    "test_batch_size": 64,
    "train_val_split": [0.9, 0.1],
    "run_root_dir": run_root_dir,
    "trainer_kwargs": {
        "accelerator": "dp",
        "gpus": "0",
        "max_epochs": 50,
        "val_check_interval": 1.0,
        "log_every_n_steps": 100,
        "flush_logs_every_n_steps": 100,
    }
}

# logger configs
log_cfg = {
    "loggers": {
        "WandbLogger": {
            "project": "fastcampus_generative_fashion_mnist_tutorials",
            'name': run_name,
            "tags": ["fastcampus_generative_fashion_mnist_tutorials"],
            "save_dir": run_root_dir,
        },
        "TensorBoardLogger": {
            "save_dir": project_root_dir,
            "name": run_name,
        }
    },
    "callbacks": {
        "ModelCheckpoint": {
            "save_top_k": 3,
            "monitor": "val_loss",
            "mode": "min",
            "verbose": True,
            "dirpath": os.path.join(run_root_dir, "weights"),
            "filename": "{epoch}-{val_loss:.3f}",
        },
        "EarlyStopping": {
            "monitor": "val_loss",
            "mode": "min",
            "patience": 10,
            "verbose": True
        }
    }
}
# unlock config & set train, log cofig
OmegaConf.set_struct(cfg, False)
cfg.train = train_cfg
cfg.log = log_cfg

# print config
print(OmegaConf.to_yaml(cfg))

opt:
  discriminator:
    optimizer:
      name: RAdam
      kwargs:
        lr: 0.0002
    lr_scheduler:
      name: null
  generator:
    optimizer:
      name: RAdam
      kwargs:
        lr: 0.0002
    lr_scheduler:
      name: null
data:
  name: fashion_mnist
  data_root: /content/data
  transforms:
  - name: ToTensor
    kwargs: {}
  W: 28
  H: 28
  C: 1
  n_class: 10
model:
  name: VanillaGAN
  latent_dim: 128
  generator:
    mlp_modules:
    - in_feat: 128
      out_feat: 256
      normalize: false
      activation: LeakyReLU
    - in_feat: 256
      out_feat: 512
      normalize: true
      activation: LeakyReLU
    - in_feat: 512
      out_feat: 1024
      normalize: true
      activation: LeakyReLU
    - in_feat: 1024
      out_feat: 784
      normalize: false
      activation: Tanh
  discriminator:
    mlp_modules:
    - in_feat: 784
      out_feat: 512
      normalize: true
      activation: LeakyReLU
    - in_feat: 512
      out_feat: 256
      normalize: true
      acti

## Data & dataloader

In [8]:
# get transforms from torch.vision
def get_transforms(cfg: DictConfig):
    transforms_list = []
    for tfm in cfg.data.transforms:
        if hasattr(transforms, tfm.name):
            transforms_list.append(
                getattr(transforms, tfm.name)(**tfm.kwargs)
            )

        else:
            raise ValueError(
                f"Not supported transform {tfm} in torch.vision.transform"
            )
    return transforms.Compose(transforms_list)

transform = get_transforms(cfg)

def get_datasets(
    cfg: DictConfig, download: bool = True
) -> Dict[str, torch.utils.data.Dataset]:
    data_root = cfg.data.data_root
    datasets = {}
    if cfg.data.name == "fashion_mnist":
        fashion_mnist_dataset = FashionMNIST(data_root, download=download, train=True, transform=transform)
        datasets = dataset_split(
            fashion_mnist_dataset, split=cfg.train.train_val_split
        )
        datasets["test"] = FashionMNIST(data_root, download=download, train=False, transform=transforms.ToTensor())
    else:
        raise NotImplementedError("Not supported dataset yet")
    return datasets

datasets = get_datasets(cfg, download=True)   

train_dataset = datasets['train']
val_dataset = datasets['val']
test_dataset = datasets['val']

# save_dataset_N
cfg.data.num_train_imgs = len(datasets["train"])
cfg.data.num_val_imgs = len(datasets["val"])
cfg.data.num_test_imgs = len(datasets["test"])

# define dataloader
train_batch_size = cfg.train.train_batch_size
val_batch_size = cfg.train.val_batch_size
test_batch_size = cfg.train.test_batch_size

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=0
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=0
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0
)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [13]:
# model define
def get_pl_model(cfg: DictConfig, checkpoint_path: Optional[str] = None):
    if cfg.model.name == "VanillaVAE":
        model = VanillaVAE(cfg)
    elif cfg.model.name == "VanillaGAN":
        model = VanillaGAN(cfg)
    else:
        NotImplementedError("not implemented model")

    if checkpoint_path is not None:
        model = model.load_from_checkpoint(cfg=cfg, checkpoint_path=checkpoint_path)
    return model

model = None
model = get_pl_model(cfg)

print(model)


VanillaGAN(
  (generator): Generator(
    (model): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01, inplace=True)
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.01, inplace=True)
      (5): Linear(in_features=512, out_features=1024, bias=True)
      (6): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.01, inplace=True)
      (8): Linear(in_features=1024, out_features=784, bias=True)
      (9): Tanh()
    )
  )
  (discriminator): Discriminator(
    (model): Sequential(
      (0): Linear(in_features=784, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01, inplace=True)
      (3): Linear(in_featur

In [10]:
# pytorch-lightning trainer def
logger = get_loggers(cfg)
callbacks = get_callbacks(cfg)

trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    default_root_dir=cfg.train.run_root_dir,
    num_sanity_val_steps=2,
    **cfg.train.trainer_kwargs
)

  f"Parsing of the Trainer argument gpus='{s}' (string) will change in the future."
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [14]:
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/\#fastcampus/runs/dnn-tutorial-fashion-mnist-runs/

trainer.fit(model, train_dataloader, val_dataloader)
# trainer.test(model, test_dataloader)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 365), started 0:05:48 ago. (Use '!kill 365' to kill it.)

<IPython.core.display.Javascript object>

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 1.5 M 
1 | discriminator | Discriminator | 535 K 
------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.127     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.754
Epoch 0, global step 210: val_loss reached 0.75420 (best 0.75420), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-12-30T08:31:15-VanillaGAN-fashion_mnist/weights/epoch=0-val_loss=0.754.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 421: val_loss reached 0.82470 (best 0.75420), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-12-30T08:31:15-VanillaGAN-fashion_mnist/weights/epoch=1-val_loss=0.825.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 632: val_loss reached 1.22489 (best 0.75420), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-12-30T08:31:15-VanillaGAN-fashion_mnist/weights/epoch=2-val_loss=1.225.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 843: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 1054: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 1265: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 1476: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 1687: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 1898: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 2109: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 10 records. Best score: 0.754. Signaling Trainer to stop.
Epoch 10, global step 2320: val_loss was not in top 3


In [None]:
# test
ckpt_path = os.path.join(
    cfg.log.callbacks.ModelCheckpoint.dirpath,
    "epoch=6-val_loss=3.004.ckpt"
)
model = get_pl_model(cfg, ckpt_path).eval()
print(model)

VanillaVAE(
  (posterior_mlp_modules): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
  )
  (posterior_mu): Linear(in_features=256, out_features=4, bias=True)
  (posterior_log_var): Linear(in_features=256, out_features=4, bias=True)
  (prior_mlp_modules): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=784, bias=True)
  )
)


In [None]:
def create_interpolation_images(
    model,
    axis1=0,
    axis2=1,
    latent_dim=4,
    save_img_path=None,
    range1=np.arange(-2, 2, 0.2),
    range2=np.arange(-2, 2, 0.2)
):
    assert len(range1) == len(range2)
    z = []
    for i in range1:
        for j in range2:
            cur = [0. for _ in range(latent_dim)]
            cur[axis1] = i
            cur[axis2] = j
            z.append(cur)
    z = torch.Tensor(z)
    out = model.sample_generate(z=z)
    out = vutils.make_grid(out, nrow=len(range1))

    if save_img_path is None:
        save_img_path = f"interpolation_results_{axis1}vs{axis2}.png"
    vutils.save_image(out, save_img_path)

create_interpolation_images(model, 0, 1)
create_interpolation_images(model, 0, 2)
create_interpolation_images(model, 0, 3)
create_interpolation_images(model, 1, 2)
create_interpolation_images(model, 1, 3)
create_interpolation_images(model, 2, 3)