In [None]:
!pip install pytorch_lightning



In [None]:
!pip install comet_ml



In [None]:
# Used for TPU setup
import collections
from datetime import datetime, timedelta
import os
import requests
import threading

# Each selectable VERSION maps to the wheel tag to install and the XRT version
# to request from the TPU server. The nightly entry requests yesterday's dev build.
_VersionConfig = collections.namedtuple('_VersionConfig', ['wheels', 'server'])
VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
    'xrt==1.15.0': _VersionConfig(wheels='1.15', server='1.15.0'),
    'torch_xla==nightly': _VersionConfig(
        wheels='nightly',
        server='XRT-dev' + (datetime.today() - timedelta(days=1)).strftime('%Y%m%d'),
    ),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = f'torch-{CONFIG.wheels}-cp36-cp36m-linux_x86_64.whl'
TORCH_XLA_WHEEL = f'torch_xla-{CONFIG.wheels}-cp36-cp36m-linux_x86_64.whl'
TORCHVISION_WHEEL = f'torchvision-{CONFIG.wheels}-cp36-cp36m-linux_x86_64.whl'

# Update TPU XRT version (runs in a background thread started at the end of this cell).
def update_server_xrt():
  """Ask the Colab TPU host to switch its server-side XRT to CONFIG.server.

  Skips with a message when COLAB_TPU_ADDR is unset or empty (i.e. the runtime
  has no Colab TPU). The previous version raised KeyError inside the thread in
  that case.
  """
  tpu_addr = os.environ.get('COLAB_TPU_ADDR')
  if not tpu_addr:
    print('COLAB_TPU_ADDR is not set (no Colab TPU runtime); skipping server-side XRT update.')
    return
  print('Updating server-side XRT to {} ...'.format(CONFIG.server))
  # COLAB_TPU_ADDR is "host:port"; the version endpoint lives on port 8475 of the same host.
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=tpu_addr.split(':')[0],
      XRT_VERSION=CONFIG.server,
  )
  print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

Updating server-side XRT to 1.15.0 ...


Exception in thread Thread-4:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-3-947a1487c49e>", line 23, in update_server_xrt
    TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
  File "/usr/lib/python3.6/os.py", line 669, in __getitem__
    raise KeyError(key) from None
KeyError: 'COLAB_TPU_ADDR'



In [None]:
import numpy as np
import os
from pathlib import Path
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
from pytorch_lightning import loggers

In [None]:
class Discriminator(pl.LightningModule):
  """Fully-connected discriminator that scores each image with a value in (0, 1).

  Inputs are flattened from dim 1 onward and must yield 28 * 28 = 784 features.
  Widths: 784 -> 1568 -> 3920 -> 784 -> 1. Each hidden stage is
  Linear -> LeakyReLU(0.2) -> Dropout(0.3), and the output stage ends in a Sigmoid.
  """

  def __init__(self):
    super().__init__()
    pixels = 28 * 28

    def hidden_stage(n_in, n_out):
      # Submodule order (0: Linear, 1: LeakyReLU, 2: Dropout) keeps state_dict keys
      # such as "fc1.0.weight" stable, so existing checkpoints still load.
      return nn.Sequential(nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3))

    self.fc1 = hidden_stage(pixels, pixels * 2)
    self.fc2 = hidden_stage(pixels * 2, pixels * 5)
    self.fc3 = hidden_stage(pixels * 5, pixels)
    self.fc4 = nn.Sequential(nn.Linear(pixels, 1), nn.Sigmoid())

  def forward(self, x):
    """Return a (batch, 1) tensor of sigmoid scores for the batch ``x``."""
    out = x.flatten(start_dim=1)
    for stage in (self.fc1, self.fc2, self.fc3, self.fc4):
      out = stage(out)
    return out

In [None]:
class Generator(pl.LightningModule):
    """Fully-connected generator: latent noise -> flattened 28x28 image in (-1, 1).

    Widths: latent_dim -> 256 -> 512 -> 1024 -> 784. There is a LeakyReLU(0.2)
    after each hidden Linear and a Dropout(0.3) after the 2nd and 3rd stages.
    The output stage ends in Tanh.

    Args:
        latent_dim: size of the input noise vector. Defaults to 128, the value
            that used to be hard-coded here, so ``Generator()`` is unchanged.
            Pass the model's ``hparams.latent_dim`` to keep the two in sync.
    """

    def __init__(self, latent_dim=128):
        super().__init__()

        # Attribute names and Sequential indices are unchanged, so the
        # state_dict keys (e.g. "fc1.0.weight") and old checkpoints stay compatible.
        self.fc1 = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.fc3 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.fc4 = nn.Sequential(
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )

    def forward(self, x):
        """Map noise to images.

        Args:
            x: noise batch whose dims after the first flatten to ``latent_dim`` features.
        Returns:
            Tensor of shape (batch, 28 * 28) with values in (-1, 1).
        """
        x = torch.flatten(x, start_dim=1, end_dim=-1)

        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)

        return x

In [None]:
class GAN(pl.LightningModule):
  """Fully-connected MNIST GAN with soft labels, instance noise and experience replay.

  ``hparams`` must provide: latent_dim, level_of_noise, r_frequent, lr, b1, b2,
  batch_size.

  NOTE(review): this uses the older PyTorch Lightning API:
    - assigning ``self.hparams``,
    - ``optimizer_idx`` in ``training_step``,
    - "log"/"progress_bar" result dicts,
    - ``on_epoch_end``.
  The pip cell installs an unpinned pytorch_lightning; pin a compatible version.
  """
  def __init__(self, hparams):
    super().__init__()
    self.hparams = hparams

    self.generator = Generator()
    self.discriminator = Discriminator()

    # cache for generated images (kept for compatibility; not read in this class)
    self.generated_imgs = None
    self.last_imgs = None
    # Latest noisy generator output. The generator step writes it and the
    # discriminator step reads it.
    self.generated_img = None

    # For experience replay: detached past generator samples. Empty until the
    # first discriminator step.
    self.exp_replay = torch.tensor([])

  def binary_cross_entropy(self, y_hat, y):
    """Plain BCE between discriminator probabilities ``y_hat`` and targets ``y``."""
    return F.binary_cross_entropy(y_hat, y)

  def forward(self, x):
    """Generate images from noise ``x``."""
    return self.generator(x)

  @staticmethod
  def _soft_labels(n, low, high, like):
    """Return (n, 1) labels drawn uniformly from [low, high) on ``like``'s device/dtype.

    The previous code used ``torch.randn`` with the uniform-interval formula,
    which produced Gaussian labels far outside the intended range.
    NOTE(review): real labels up to 1.1 are outside BCE's documented [0, 1]
    target range. This is kept as the author's chosen smoothing interval.
    """
    return low + (high - low) * torch.rand(n, 1, device=like.device, dtype=like.dtype)

  def training_step(self, batch, batch_nb, optimizer_idx):
    """Generator step for optimizer 0, discriminator step for optimizer 1.

    Returns the Lightning result dict with "loss", "log" and "progress_bar".
    Returns None for any other ``optimizer_idx``.
    """
    img, _ = batch
    self.last_imgs = img

    # Train Generator
    if optimizer_idx == 0:
      return self._generator_step(img)

    # Train Discriminator
    if optimizer_idx == 1:
      return self._discriminator_step(img)

  def _generator_step(self, img):
    """Generate a noisy fake batch and score it against soft "real" labels."""
    batch_size = img.shape[0]
    noise = torch.randn(batch_size, self.hparams.latent_dim, device=img.device, dtype=img.dtype)

    generated = self(noise)
    # Instance noise: perturb the generator output before the discriminator sees it.
    self.generated_img = generated + self.hparams.level_of_noise * torch.randn_like(generated)

    valid_lbl = self._soft_labels(batch_size, 0.8, 1.1, img)
    g_loss = self.binary_cross_entropy(self.discriminator(self.generated_img), valid_lbl)

    logs = {"g_loss": g_loss}
    return {"loss": g_loss, "log": logs, "progress_bar": dict(logs)}

  def _discriminator_step(self, img):
    """Update the replay buffer, then compute the discriminator loss.

    When the buffer holds at least a batch of samples, the loss is computed on
    the replayed fakes only and the buffer is emptied. Otherwise the loss is the
    mean of the real loss and the current-fake loss.
    """
    batch_size = img.shape[0]
    fake = self.generated_img

    # Experience replay: store a random 1/r_frequent of this fake batch (at least one sample).
    n_keep = max(1, int(batch_size / self.hparams.r_frequent))
    r_idx = torch.randperm(fake.size(0), device=fake.device)[:n_keep]
    new_samples = fake[r_idx].detach()
    if self.exp_replay.numel() == 0:
      # Avoid concatenating with the CPU placeholder, which would fail for GPU samples.
      self.exp_replay = new_samples
    else:
      self.exp_replay = torch.cat((self.exp_replay, new_samples), 0)

    if self.exp_replay.size(0) >= batch_size:  # enough past examples: train on them
      unvalid_lbl = self._soft_labels(self.exp_replay.size(0), 0.0, 0.3, img)
      d_loss = self.binary_cross_entropy(self.discriminator(self.exp_replay), unvalid_lbl)
      self.exp_replay = torch.tensor([])
    else:
      valid_lbl = self._soft_labels(batch_size, 0.8, 1.1, img)
      unvalid_lbl = self._soft_labels(batch_size, 0.0, 0.3, img)
      real_loss = self.binary_cross_entropy(self.discriminator(img), valid_lbl)
      fake_loss = self.binary_cross_entropy(self.discriminator(fake.detach()), unvalid_lbl)
      d_loss = (real_loss + fake_loss) / 2

    logs = {"d_loss": d_loss}
    return {"loss": d_loss, "log": logs, "progress_bar": dict(logs)}

  def configure_optimizers(self):
    """Two Adam optimizers with the same lr/betas: index 0 = generator, 1 = discriminator."""
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.lr, betas=(self.hparams.b1, self.hparams.b2))
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr, betas=(self.hparams.b1, self.hparams.b2))

    return opt_g, opt_d

  def train_dataloader(self):
    """MNIST train split, normalized to [-1, 1] to match the generator's Tanh, shuffled."""
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize([0.5], [0.5])])
    dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    # DataLoader defaults to shuffle=False, which would replay identical batches every epoch.
    return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True)

  def on_epoch_end(self):
    """Log a 2x2 grid of samples; every 5th epoch also save and upload a checkpoint.

    Uses ``self.logger`` / ``self.trainer`` instead of the notebook globals.
    Still reads the module-level ``checkpoint_folder`` and ``experiment_name``.
    """
    experiment = self.logger.experiment
    device = next(self.generator.parameters()).device
    noise = torch.randn(4, self.hparams.latent_dim, device=device)

    # Sample with dropout disabled and without building an autograd graph,
    # then restore the generator's previous mode.
    was_training = self.generator.training
    self.generator.eval()
    try:
      with torch.no_grad():
        sample_img = self.generator(noise).view(-1, 1, 28, 28)
    finally:
      self.generator.train(was_training)

    # The generator ends in Tanh, so samples lie in (-1, 1); shift them to (0, 1)
    # so the logged image does not depend on the logger's own rescaling.
    sample_img = (sample_img + 1) / 2
    grid = torchvision.utils.make_grid(sample_img, nrow=2).permute(1, 2, 0).cpu()
    experiment.log_image(grid, f'generated_images_epoch{self.current_epoch}', step=self.current_epoch)

    if self.current_epoch % 5 == 0:
      ckpt_path = os.path.join(checkpoint_folder, f"{experiment_name}_epoch_{self.current_epoch}.ckpt")
      self.trainer.save_checkpoint(ckpt_path)
      experiment.log_asset_folder(checkpoint_folder)

      # Empty the local checkpoint folder so the next upload contains only the newest file.
      # NOTE(review): if Comet uploads assets in the background, deleting the file
      # right after log_asset_folder may race with the upload. Confirm before relying on it.
      ckpt_dir = Path(checkpoint_folder)
      if ckpt_dir.is_dir():
        shutil.rmtree(ckpt_dir)
      os.makedirs(checkpoint_folder, 0o755, exist_ok=True)

In [None]:
# Parameters: run identity and where checkpoints are written.
experiment_name = "MNIST_GAN_V2"
dataset_name = "MNIST"
checkpoint_folder = f"GAN/{experiment_name}/checkpoints"
tags = ["GAN", dataset_name]

In [None]:
# Hyperparameters. `args` (a plain dict) is logged to Comet; `hparams` gives
# attribute access (hparams.lr, ...) for the model.
from argparse import Namespace

args = dict(
    batch_size=64,       # DataLoader batch size
    lr=0.0002,           # Adam learning rate for both optimizers
    b1=0.5,              # Adam beta1
    b2=0.999,            # Adam beta2
    latent_dim=128,      # length of the generator's input noise vector
    level_of_noise=0.1,  # scale of the Gaussian noise added to generated images
    epochs=100,          # Trainer max_epochs
    r_frequent=10,       # 1/r_frequent of each fake batch goes into the replay buffer
)

hparams = Namespace(**args)

In [None]:
# init logger
# Credentials come from the environment rather than being typed into the notebook,
# where saved .ipynb files and shared outputs would expose them. The empty-string
# defaults match the previous literal values when the variables are unset.
comet_logger = loggers.CometLogger(
    api_key=os.environ.get("COMET_API_KEY", ""),
    rest_api_key=os.environ.get("COMET_REST_API_KEY", ""),
    project_name="GAN",
    experiment_name=experiment_name,
)

# defining net
net = GAN(hparams)

# logging: hyperparameters, a text summary of the model, tags and dataset name
comet_logger.experiment.log_parameters(args)
comet_logger.experiment.set_model_graph(str(net))
comet_logger.experiment.add_tags(tags=tags)
comet_logger.experiment.log_dataset_info(dataset_name)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/luposx/gan/a73cc45ce37144b09139472616c6497c



In [None]:
# Start from a clean slate: remove the whole "GAN" output tree if it exists,
# then recreate the (empty) checkpoint folder with mode 0o755 (subject to umask).
dirpath = Path("GAN")
access_rights = 0o755
if dirpath.is_dir():
    shutil.rmtree(dirpath)
os.makedirs(checkpoint_folder, mode=access_rights)

In [None]:
# Start training
# Resume only if the checkpoint file is actually present in the working directory.
# On a fresh runtime it may be absent, in which case training starts from scratch
# with a message instead of depending on a missing file.
RESUME_CHECKPOINT = "MNIST_GAN_V2_epoch_15.ckpt"
resume_path = RESUME_CHECKPOINT if os.path.isfile(RESUME_CHECKPOINT) else None
if resume_path is None:
    print(f"No checkpoint found at {RESUME_CHECKPOINT!r}; training from scratch.")

trainer = pl.Trainer(resume_from_checkpoint=resume_path,
                     logger=comet_logger,
                     max_epochs=args["epochs"]
                     )
trainer.fit(net)

# Save and upload the final weights alongside the periodic checkpoints.
final_ckpt = os.path.join(checkpoint_folder, f"{experiment_name}_epoch_{args['epochs']}.ckpt")
trainer.save_checkpoint(final_ckpt)
comet_logger.experiment.log_asset_folder(checkpoint_folder)

HBox(children=(IntProgress(value=1, bar_style='info', layout=Layout(flex='2'), max=1), HTML(value='')), layout…