In [None]:
!pip install pytorch_lightning



In [None]:
!pip install comet_ml



In [None]:
# Used for TPU support.
# Uncomment the rest of this cell if you want to train on a TPU.
# import collections
# from datetime import datetime, timedelta
# import os
# import requests
# import threading

# _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
# VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
# CONFIG = {
#     'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
#     'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
#         (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
# }[VERSION]
# DIST_BUCKET = 'gs://tpu-pytorch/wheels'
# TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
# TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
# TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# # Update TPU XRT version
# def update_server_xrt():
#   print('Updating server-side XRT to {} ...'.format(CONFIG.server))
#   url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
#       TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
#       XRT_VERSION=CONFIG.server,
#   )
#   print('Done updating server-side XRT: {}'.format(requests.post(url)))

# update = threading.Thread(target=update_server_xrt)
#update.start()

In [None]:
import numpy as np
from numpy.random import choice

import os
from pathlib import Path
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision  
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
from pytorch_lightning import loggers

from PIL import Image

In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise the parameters of ``m`` in place, based on its class name.

    * class name contains ``Conv``      -> weights drawn from N(0.0, 0.02)
    * class name contains ``BatchNorm`` -> weights drawn from N(1.0, 0.02), bias set to 0
    * anything else                     -> left untouched

    Written to be passed to ``Module.apply`` so that it visits every sub-module.
    Returns ``None``.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class AddGaussianNoise(object):
    """Callable transform that adds Gaussian noise to a tensor.

    The result is ``tensor + N(0, 1) * std + mean``, with one independent
    sample per element of ``tensor``.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal sample.
    """

    def __init__(self, mean=0.0, std=0.1):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return a new, noisy tensor with the same shape as ``tensor``."""
        # Sample on the input's own device: a noise tensor created on the CPU
        # cannot be added to a tensor that lives on an accelerator.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
# randomly flip some labels
def noisy_labels(y, p_flip=0.05):
    """Flip a random fraction of the labels in ``y`` in place (``y -> 1 - y``).

    Args:
        y: array-like or tensor of labels indexed along its first axis.
            It is modified in place.
        p_flip: fraction of the entries along the first axis to flip
            (default 5%). The resulting count is truncated to an integer and
            capped at the number of labels.

    Returns:
        The same object ``y``, after the selected entries were replaced by
        ``1 - y``.
    """
    n_labels = y.shape[0]
    # determine the number of labels to flip
    n_select = min(int(p_flip * n_labels), n_labels)
    if n_select <= 0:
        return y
    # Draw *distinct* positions. Sampling with replacement (the default of
    # `choice`) can pick the same position several times, which silently
    # flips fewer labels than requested.
    flip_ix = choice(n_labels, size=n_select, replace=False)
    # invert the labels in place
    y[flip_ix] = 1 - y[flip_ix]
    return y

In [None]:
def get_valid_labels(img):
  """Build noisy soft "real" labels, one per sample in ``img``.

  Each label is drawn uniformly from (0.8, 1.1] and a small random subset is
  then flipped to ``1 - label`` by ``noisy_labels``.

  Args:
      img: batch tensor; only ``img.shape[0]`` and ``img.device`` are used.

  Returns:
      Tensor of shape ``(img.shape[0], 1, 1, 1)`` on the same device as ``img``.
  """
  # Create the labels next to the images so that the loss never mixes devices.
  soft = (0.8 - 1.1) * torch.rand(img.shape[0], 1, 1, 1, device=img.device) + 1.1  # soft labels
  return noisy_labels(soft)

In [None]:
def get_unvalid_labels(img):
  """Build noisy soft "fake" labels, one per sample in ``img``.

  Each label is drawn uniformly from (0.0, 0.3] and a small random subset is
  then flipped to ``1 - label`` by ``noisy_labels``.

  Args:
      img: batch tensor; only ``img.shape[0]`` and ``img.device`` are used.

  Returns:
      Tensor of shape ``(img.shape[0], 1, 1, 1)`` on the same device as ``img``.
  """
  # Create the labels next to the images so that the loss never mixes devices.
  soft = (0.0 - 0.3) * torch.rand(img.shape[0], 1, 1, 1, device=img.device) + 0.3  # soft labels
  return noisy_labels(soft)

In [None]:
class Discriminator(pl.LightningModule):
    """Convolutional discriminator producing one sigmoid score per image.

    Four stride-2 convolutions each halve the spatial size; a final 4x4
    convolution followed by a sigmoid collapses the remaining map.
    The shape comments below assume a 64 x 64 input.

    Args:
        ndf: number of feature maps after the first convolution; later stages
            use 2x, 4x and 8x this width.
        nc: number of channels of the input images.
    """

    def __init__(self, ndf, nc):
        super().__init__()
        self.ndf = ndf
        self.nc = nc

        # (nc) x 64 x 64 -> (ndf) x 32 x 32; the input stage has no batch norm or dropout
        self.fc1 = nn.Sequential(
            self._strided_conv(self.nc, self.ndf),
            nn.LeakyReLU(0.2),
        )
        # (ndf) x 32 x 32 -> (ndf*2) x 16 x 16
        self.fc2 = self._regularised_stage(self.ndf, self.ndf * 2)
        # (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8
        self.fc3 = self._regularised_stage(self.ndf * 2, self.ndf * 4)
        # (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        self.fc4 = self._regularised_stage(self.ndf * 4, self.ndf * 8)
        # (ndf*8) x 4 x 4 -> 1 x 1 x 1
        self.fc5 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    @staticmethod
    def _strided_conv(c_in, c_out):
        """Bias-free 4x4 convolution with stride 2 and padding 1."""
        return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1, bias=False)

    @classmethod
    def _regularised_stage(cls, c_in, c_out):
        """Strided convolution followed by batch norm, LeakyReLU(0.2) and Dropout(0.3)."""
        return nn.Sequential(
            cls._strided_conv(c_in, c_out),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        for stage in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = stage(x)
        return x

In [None]:
class Generator(pl.LightningModule):
    """Transposed-convolution generator mapping a latent vector to an image.

    A ``(latent_dim) x 1 x 1`` input is expanded to 4 x 4 by the first stage and
    doubled in size by each of the four following stages, giving a
    ``(nc) x 64 x 64`` output squashed to [-1, 1] by ``tanh``.

    Args:
        latent_dim: number of channels of the latent input.
        ngf: number of feature maps before the output stage; earlier stages use
            8x, 4x and 2x this width.
        nc: number of channels of the generated images.
    """

    def __init__(self, latent_dim, ngf, nc):
        super().__init__()
        # Honour the constructor argument; it used to be silently replaced by a
        # hard-coded 16, so changing the `ngf` hyperparameter had no effect.
        self.ngf = ngf
        self.n_features = latent_dim
        self.nc = nc

        # input is Z: (latent_dim) x 1 x 1 -> (ngf*8) x 4 x 4
        self.fc1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features, out_channels=self.ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        # NOTE(review): fc2-fc4 used to be written as
        # `nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, False)`. The sixth positional
        # parameter of ConvTranspose2d is `output_padding`, not `bias`, so these
        # layers have always had a bias term. That behaviour is spelled out
        # explicitly below and kept so that existing checkpoints still load;
        # switch to bias=False if checkpoint compatibility is not needed.

        # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8
        self.fc2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True),
            nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2)
        )

        # (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16
        self.fc3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2)
        )

        # (ngf*2) x 16 x 16 -> (ngf) x 32 x 32
        self.fc4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2)
        )

        # (ngf) x 32 x 32 -> (nc) x 64 x 64; output stage, no batch norm
        self.fc5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=self.nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

        # A sixth stage producing (nc) x 128 x 128 images was sketched here
        # before; it is not part of the current 64 x 64 model.

    def forward(self, x):
        """Run the latent batch ``x`` through the five stages and return the images."""
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.fc5(x)

        return x

In [None]:
class DCGAN(pl.LightningModule):
    """DCGAN training module wiring a ``Generator`` and a ``Discriminator`` together.

    Training alternates between two optimizers (index 0: generator, index 1:
    discriminator). A few generated images per batch are kept in an experience
    replay buffer; once the buffer holds ``experience_batch_size`` images the
    discriminator is trained on that buffer instead of on the current batch.

    Args:
        hparams: namespace with the fields read in this class (``latent_dim``,
            ``ngf``, ``ndf``, ``nc``, ``lr``, ``b1``, ``b2``, ``epochs``,
            ``level_of_noise``, ``batch_size``, ``num_workers``, ``image_size``,
            ``experience_batch_size``, ``experience_save_per_batch`` and the
            three ``save_*_every_epoch`` intervals).

    Note:
        ``on_epoch_end`` relies on the notebook globals ``comet_logger``,
        ``trainer``, ``experiment_name``, ``checkpoint_folder`` and ``dirpath``,
        which must exist before training starts.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.generator = Generator(hparams.latent_dim, hparams.ngf, hparams.nc)
        self.discriminator = Discriminator(hparams.ndf, hparams.nc)
        self.generator.apply(weights_init)  # custom weight init
        self.discriminator.apply(weights_init)

        # Cache for the latest generated / real batch. The attribute is named
        # `generated_img` everywhere else in the class, so it is initialised
        # under that name (it used to be initialised as `generated_imgs`).
        self.generated_img = None
        self.last_imgs = None

        # For experience replay
        self.exp_replay = torch.tensor([])

    def binary_cross_entropy(self, y_hat, y):
        """Binary cross entropy between predictions ``y_hat`` and targets ``y``."""
        return F.binary_cross_entropy(y_hat, y)

    def forward(self, x):
        """Generate images from the latent batch ``x``."""
        return self.generator(x)

    def training_step(self, batch, batch_nb, optimizer_idx):
        """Compute the loss for one optimizer on one batch.

        Args:
            batch: sequence whose first element is the image tensor.
            batch_nb: index of the batch (unused).
            optimizer_idx: 0 trains the generator, 1 trains the discriminator.

        Returns:
            Dict with the keys ``loss``, ``log`` and ``progress_bar``.
        """
        # Linearly decaying noise level that reaches 0 at half of the configured
        # epochs. The instance noise itself is currently disabled, so this value
        # is only logged.
        std_gaussian = max(0, self.hparams.level_of_noise - ((self.hparams.level_of_noise * 2) * (self.current_epoch / self.hparams.epochs)))

        img = batch[0]
        self.last_imgs = img

        # Train Generator maximize log(D(G(z)))
        if optimizer_idx == 0:
            # Sample the latent vectors on the batch's device.
            noise = torch.randn(img.shape[0], self.hparams.latent_dim, 1, 1, device=img.device)

            self.generated_img = self(noise)

            # The generator wants its fakes to be scored as real, hence valid labels.
            d_out = self.discriminator(self.generated_img)
            g_loss = self.binary_cross_entropy(d_out, get_valid_labels(self.generated_img).type_as(d_out))

            tqdm_dict = {'g_loss': g_loss}
            logs = {"g_loss": g_loss, "std_gaussian": std_gaussian}
            return {"loss": g_loss, "log": logs, 'progress_bar': tqdm_dict}

        # Train Discriminator maximize log(D(x)) + log(1 - D(G(z)))
        if optimizer_idx == 1:
            fake = self.generated_img.detach()

            # Experience replay: keep a few random fakes of this batch for later.
            perm = torch.randperm(fake.size(0))  # Shuffling
            r_idx = perm[:max(1, self.hparams.experience_save_per_batch)]  # Getting the index
            # Move the buffer next to the fakes before concatenating.
            self.exp_replay = torch.cat((self.exp_replay.to(fake.device), fake[r_idx]), 0).detach()

            if self.exp_replay.size(0) >= self.hparams.experience_batch_size:
                # Enough examples from the past: train on them, then start a new buffer.
                replay_out = self.discriminator(self.exp_replay)
                d_loss = self.binary_cross_entropy(replay_out, get_unvalid_labels(self.exp_replay).type_as(replay_out))
                self.exp_replay = torch.tensor([])

                tqdm_dict = {'d_loss': d_loss}
                # This is the experience-replay loss. The "d_exp_loss" and
                # "d_loss" log keys used to be swapped between the two branches.
                logs = {"d_exp_loss": d_loss, "std_gaussian": std_gaussian}
                return {"loss": d_loss, "log": logs, 'progress_bar': tqdm_dict}

            else:
                real_out = self.discriminator(img)
                real_loss = self.binary_cross_entropy(real_out, get_valid_labels(img).type_as(real_out))

                fake_out = self.discriminator(fake)
                fake_loss = self.binary_cross_entropy(fake_out, get_unvalid_labels(fake).type_as(fake_out))

                d_loss = (real_loss + fake_loss) / 2

                tqdm_dict = {'d_loss': d_loss}
                logs = {"d_loss": d_loss, "std_gaussian": std_gaussian}
                return {"loss": d_loss, "log": logs, 'progress_bar': tqdm_dict}

    def configure_optimizers(self):
        """Return the generator and discriminator Adam optimizers, in that order."""
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.lr,
                                 betas=(self.hparams.b1, self.hparams.b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr,
                                 betas=(self.hparams.b1, self.hparams.b2))

        return opt_g, opt_d

    def train_dataloader(self):
        """Return a shuffled loader over the image folder, resized and scaled to [-1, 1]."""
        transform = transforms.Compose([transforms.Resize((self.hparams.image_size, self.hparams.image_size)),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.5], [0.5])
                                        ])

        train_dataset = torchvision.datasets.ImageFolder(
            root="./drive/My Drive/datasets/ghibli_dataset_small_overfit/",
            transform=transform
        )
        return DataLoader(train_dataset, num_workers=self.hparams.num_workers, shuffle=True, batch_size=self.hparams.batch_size)
        # Alternative data source (a random MNIST subset of one batch), kept for reference:
        # transform = transforms.Compose([transforms.Resize((64, 64)),
        #                             transforms.ToTensor(),
        #                             transforms.Normalize([0.5], [0.5]),
        #                             ])
        # dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)

        # rand_idx = torch.randperm(len(dataset)).tolist()
        # rand_idx = rand_idx[:self.hparams.batch_size]

        # subset_loader = torch.utils.data.DataLoader(dataset, batch_size=self.hparams.batch_size, sampler=torch.utils.data.SubsetRandomSampler(rand_idx), num_workers=self.hparams.num_workers)
        # return subset_loader

    def on_epoch_end(self):
        """Periodically log a training image, log/save generated samples and save a checkpoint.

        Uses the notebook globals ``comet_logger``, ``trainer``,
        ``experiment_name``, ``checkpoint_folder`` and ``dirpath``.
        """
        if self.current_epoch % self.hparams.save_train_image_every_epoch == 0:
            perm = torch.randperm(self.last_imgs.size(0))
            idx = perm[:1]
            # Channels-last copy on the CPU for the logger.
            train_img = self.last_imgs[idx].squeeze(dim=0).permute(1, 2, 0).detach().cpu()
            comet_logger.experiment.log_image(train_img, f'trained_on_this_image{self.current_epoch}', step=self.current_epoch)

        if self.current_epoch % self.hparams.save_image_every_epoch == 0:
            # Sample on the generator's device and without building an autograd
            # graph; these images are only used for logging.
            device = next(self.generator.parameters()).device
            with torch.no_grad():
                noise = torch.randn(4, self.hparams.latent_dim, 1, 1, device=device)
                sample_img = self.generator(noise)

            sample_img = sample_img.view(-1, self.hparams.nc, self.hparams.image_size, self.hparams.image_size)
            grid = torchvision.utils.make_grid(sample_img, nrow=2).cpu()
            comet_logger.experiment.log_image(grid.permute(1, 2, 0), f'generated_images_epoch{self.current_epoch}', step=self.current_epoch)
            torchvision.utils.save_image(grid, f'{experiment_name}/images/generated_images_epoch{self.current_epoch}.png')

        if self.current_epoch % self.hparams.save_model_every_epoch == 0:
            trainer.save_checkpoint(checkpoint_folder + "/" + experiment_name + "_epoch_" + str(self.current_epoch) + ".ckpt")
            comet_logger.experiment.log_asset_folder(checkpoint_folder)

            # The checkpoint has been handed to the logger; empty the local folder.
            if dirpath.exists() and dirpath.is_dir():
                shutil.rmtree(dirpath)

            # creating checkpoint folder
            access_rights = 0o755
            os.makedirs(checkpoint_folder, access_rights)

In [None]:
# Parameters
# Run-level constants shared by the cells below and by DCGAN.on_epoch_end.

# Name of the comet experiment; also used as the root of the generated-image
# folder and as the prefix of checkpoint file names.
experiment_name = "GHIBLI_DCGAN_OVERFIT_64px"
# Dataset label that is logged to the experiment.
dataset_name = "GHIBLI_OVERFIT"
# Local folder that checkpoints are written to before being logged as assets.
checkpoint_folder = "DCGAN/"
# Tags attached to the comet experiment.
tags = ["DCGAN", "GHIBLI", "OVERFIT", "64x64"]
# Path object for the checkpoint folder, used for the exists/delete checks.
dirpath = Path(checkpoint_folder)

In [None]:
# Hyperparameters
from argparse import Namespace

# All tunable settings of the run. `args` is the plain-dict view (used when
# creating the trainer); `hparams` exposes the same values as attributes for
# the model.
args = dict(
    # optimisation
    batch_size=10,
    lr=0.0002,
    b1=0.5,
    b2=0.999,
    # model / training schedule
    latent_dim=125,
    level_of_noise=0.1,
    epochs=10000,
    ndf=10,  # the discriminator is slightly weaker so that it doesn't become too good
    ngf=16,
    # experience replay
    experience_batch_size=200,
    experience_save_per_batch=2,
    # logging / checkpoint intervals (in epochs)
    save_model_every_epoch=500,
    save_image_every_epoch=2,
    # data
    num_workers=3,
    nc=3,
    image_size=64,
    save_train_image_every_epoch=100,
)

hparams = Namespace(**args)

In [None]:
# init logger
# Credentials are read from the environment so that they never end up in the
# saved notebook. When a variable is not set, the previous placeholder (an
# empty string) is passed on unchanged.
comet_logger = loggers.CometLogger(
    api_key=os.environ.get("COMET_API_KEY", ""),
    rest_api_key=os.environ.get("COMET_REST_API_KEY", ""),
    project_name="ghibli-gan",
    experiment_name=experiment_name,
    #experiment_key="f23d00c0fe3448ee884bfbe3fc3923fd"  # used for resuming trained id can be found in comet.ml
)

INFO:lightning:CometLogger will be initialized in online mode
COMET INFO: old comet version (3.1.4) detected. current: 3.1.5 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/luposx/ghibli-gan/8ed4bf0b88ab4e9594a001439c653b99



In [None]:
# Build the DCGAN module from the hyperparameters defined above.
net = DCGAN(hparams)

# Attach run metadata to the comet experiment: the textual module structure,
# the run tags and the dataset label.
comet_logger.experiment.set_model_graph(str(net))
comet_logger.experiment.add_tags(tags=tags)
comet_logger.experiment.log_dataset_info(dataset_name)

In [None]:
# Start the run with empty output folders: the checkpoint folder and the
# folder that receives the generated sample images. Each one is deleted if it
# already exists and then created again with the given access rights.
access_rights = 0o755
dirpath2 = Path(f'{experiment_name}/images')

for run_dir in (dirpath, dirpath2):
    if run_dir.is_dir():
        shutil.rmtree(run_dir)
    os.makedirs(run_dir, access_rights)

In [None]:
# Create the trainer with the comet logger and run for the configured number
# of epochs. The commented-out `resume_from_checkpoint` argument can be
# restored to continue from a saved checkpoint.
trainer = pl.Trainer(#resume_from_checkpoint="GHIBLI_DCGAN_OVERFIT_64px_epoch_6000.ckpt",
                     logger=comet_logger,
                     max_epochs=args["epochs"]
                     )
trainer.fit(net)
# Save a final checkpoint and log the checkpoint folder as experiment assets.
# NOTE: the file name always carries the configured epoch count
# (args["epochs"]), not the epoch that training actually reached.
trainer.save_checkpoint(checkpoint_folder + "/" + experiment_name + "_epoch_" + str(args["epochs"]) + ".ckpt")
comet_logger.experiment.log_asset_folder(checkpoint_folder)

INFO:lightning:
   | Name                | Type            | Params
----------------------------------------------------
0  | generator           | Generator       | 429 K 
1  | generator.fc1       | Sequential      | 256 K 
2  | generator.fc1.0     | ConvTranspose2d | 256 K 
3  | generator.fc1.1     | BatchNorm2d     | 256   
4  | generator.fc1.2     | LeakyReLU       | 0     
5  | generator.fc2       | Sequential      | 131 K 
6  | generator.fc2.0     | ConvTranspose2d | 131 K 
7  | generator.fc2.1     | BatchNorm2d     | 128   
8  | generator.fc2.2     | LeakyReLU       | 0     
9  | generator.fc3       | Sequential      | 32 K  
10 | generator.fc3.0     | ConvTranspose2d | 32 K  
11 | generator.fc3.1     | BatchNorm2d     | 64    
12 | generator.fc3.2     | LeakyReLU       | 0     
13 | generator.fc4       | Sequential      | 8 K   
14 | generator.fc4.0     | ConvTranspose2d | 8 K   
15 | generator.fc4.1     | BatchNorm2d     | 32    
16 | generator.fc4.2     | LeakyReLU       | 0 

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

COMET ERROR: File could not be uploaded
INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/luposx/ghibli-gan/8ed4bf0b88ab4e9594a001439c653b99
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     d_loss [8790]       : (0.15617361664772034, 1.1096616983413696)
COMET INFO:     g_loss [8790]       : (0.5527005195617676, 3.2761874198913574)
COMET INFO:     std_gaussian [8790] : (0, 0.1)
COMET INFO:   Others:
COMET INFO:     Name         : GHIBLI_DCGAN_OVERFIT_64px
COMET INFO:     dataset_info : GHIBLI_OVERFIT
COMET INFO:   Parameters:
COMET INFO:     b1                           : 0.5
COMET INFO:     b2                           : 0.999
COMET INFO:     batch_size                   : 10
COMET INFO:     epochs                




COMET INFO: Uploading stats to Comet before program termination (may take several seconds)
COMET INFO: old comet version (3.1.4) detected. current: 3.1.5 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/luposx/ghibli-gan/8ed4bf0b88ab4e9594a001439c653b99



[('GHIBLI_DCGAN_OVERFIT_64px_epoch_10000.ckpt',
  {'api': 'https://www.comet.ml/api/rest/v2/experiment/asset/get-asset?assetId=62325b5f4d0b4cae8be2d915c56654cb&experimentKey=8ed4bf0b88ab4e9594a001439c653b99',
   'assetId': '62325b5f4d0b4cae8be2d915c56654cb',
   'web': 'https://www.comet.ml/api/asset/download?assetId=62325b5f4d0b4cae8be2d915c56654cb&experimentKey=8ed4bf0b88ab4e9594a001439c653b99'})]