In [1]:
from lightning_gans.models.resnet import conv_downsample
from lightning_gans.models.discriminators import Discriminator
import timm
import pytorch_lightning as pl
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from lightning_gans.data.dataloader import MonetDataset, AbstractArtDataset
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
import tqdm

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
#manualSeed = 999
#manualSeed = random.randint(1,if norm:
        layers.append(norms[norm](num_features=out_c, affine=True))
    if dropout:
        layers.append(torch.nn.Dropout2d(p=0.5, inplace=True))
    if activation:
        layers.append(activations[activation])
    return torch.nn.Sequential(*layers) 10000) # use if you want new results
#print("Random Seed: ", manualSeed)
#random.seed(manualSeed)
#torch.manual_seed(manualSeed)

In [2]:
activations = {
    "leaky_relu": torch.nn.LeakyReLU(0.2,inplace=True),
    "relu": torch.nn.ReLU(inplace=True),
    "tanh": torch.nn.Tanh()
}
norms = {"batch": torch.nn.BatchNorm2d, "instance": torch.nn.InstanceNorm2d}


def weights_init(m):
    """Initializes weights with a 0 mean and 0.02 stdev

    Args:
        m (torch.nn.Module): torch module whose weights will be initialized
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if classname.find('InstanceNorm') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

        
def conv_upsample(ks, in_c, out_c, st=2, activation="leaky_relu", norm="batch", dropout=False):
    padding = {
        2:[0,0],
        3:[1,1],
        4:[1,0],
        5:[2,1]
    }
    if st==1:
        padding[4][0]=0
    layers = [
        torch.nn.ConvTranspose2d(in_c, out_c, ks, st, padding=padding[ks][0], output_padding=padding[ks][1],
                                 bias=False),
        #torch.nn.UpsamplingBilinear2d(scale_factor=2),
        #torch.nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False)
    ]
    if norm:
        layers.append(norms[norm](num_features=out_c, affine=True))
    if dropout:
        layers.append(torch.nn.Dropout2d(p=0.5, inplace=True))
    if activation:
        layers.append(activations[activation])
    return torch.nn.Sequential(*layers)

In [13]:
class GAN(pl.LightningModule):
    def __init__(self,latent_size):
        super().__init__()
        self.latent_size = latent_size
        
        self.generator = None
        self.discriminator = None
        self._create_model()
        self.generator.apply(weights_init)
        self.discriminator.apply(weights_init)
        
        self.generator_optimizer=None
        self.discriminator_optimizer=None
        self.automatic_optimization=False
        
        self.loss = torch.nn.BCELoss()
        self.fixed_noise = torch.randn(64, self.latent_size, 1, 1)
        
    def _create_model(self):
        self.generator = torch.nn.Sequential(
            # input = Bx500x1x1
            conv_upsample(4,in_c=self.latent_size,out_c=1024,st=1,dropout=True),
            # shape = Bx1024x4x4
            conv_upsample(4,in_c=1024,out_c=512,dropout=True),
            # shape = Bx512x8x8
            conv_upsample(4,in_c=512,out_c=256,dropout=True),
            # shape = Bx256x16x16
            conv_upsample(4,in_c=256,out_c=128,dropout=True),
            # shape = Bx128x32x32
            conv_upsample(4,in_c=128,out_c=3,activation="tanh",norm=None),
            # shape = Bx3x64x64
        )
        """
        self.discriminator = torch.nn.Sequential(
            timm.create_model("resnet18",num_classes=1,pretrained=True),
            torch.nn.Sigmoid()
        )"""
        nc = 3
        ndf = 64
        self.discriminator = torch.nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        
    def configure_optimizers(self):
        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(),
                                                        lr=2e-4,
                                                        betas=(0.5, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                        lr=2e-4,
                                                        betas=(0.5, 0.999))
        return [self.generator_optimizer, self.discriminator_optimizer]
    
    def training_step(self, train_batch, batch_idx):
        g_opt, d_opt = self.optimizers()
        real_images,_ = train_batch
        batch_size = real_images.shape[0]
        
        r1, r2 = 0.75, 0.83
        real_label = (r1 - r2) * torch.randn(1) + r2
        real_label = real_label.item()
        fake_label = 0.1
        
        targets = torch.full((batch_size,),fill_value=real_label,dtype=torch.float,device=self.device)
        
        ####### 1 - UPDATE DISCRIMINATOR #######
        d_opt.zero_grad()
        self.discriminator.zero_grad()
        # REAL IMAGE BATCH PASSED THROUGH D
        preds = self.discriminator(real_images)
        preds = preds.squeeze()
        D_x = preds.mean().item()
        
        real_loss = self.loss(preds,targets)
        self.manual_backward(real_loss)
        # FAKE IMAGE BATCH PASSED THROUGH D
        generator_inputs = torch.randn(batch_size,self.latent_size,1,1,device=self.device)
        fake_images = self.generator(generator_inputs)
        # PREDICTIONS ON  FAKE IMAGES
        preds = self.discriminator(fake_images.detach())
        preds = preds.squeeze()
        D_G_z1 = preds.mean().item()
        
        targets.fill_(fake_label)
        fake_loss = self.loss(preds,targets)
        self.manual_backward(fake_loss)
        d_opt.step()
        loss_d = real_loss+fake_loss
        #self.manual_backward(loss_d)
        #d_opt.step()
        
        ####### 2 - UPDATE GENERATOR #######
        g_opt.zero_grad()
        self.generator.zero_grad()
        targets.fill_(real_label)
        # PREDICTIONS ON FAKE IMAGES
        preds = self.discriminator(fake_images)
        preds = preds.squeeze()
        D_G_z2 = preds.mean().item()
        generator_loss = self.loss(preds,targets)
        self.manual_backward(generator_loss)
        g_opt.step()
        
        self.log_dict({
            "loss_d":loss_d.item(),
            "loss_g":generator_loss.item(),
            "D_x":D_x,
            "D_G_z1":D_G_z1,
            "D_G_z2":D_G_z2,
        }, prog_bar=True)
        
    def validation_step(self, batch, batch_idx):
        if batch_idx==0:
            fake = self.generator(self.fixed_noise.to(self.device)).detach().cpu()
            grid = vutils.make_grid(fake, padding=2, normalize=True)
            self.logger.experiment.add_image('fixed_noise_outputs', grid, self.global_step)

In [15]:
batch_size = 128
workers =16
version = "64x64Images_NoNormGenLast_1"
name="DCGAN_Abstract"
logger = TensorBoardLogger(
            "../data/interim/tboard/",
            name=name,
            version=version,
        )
epochs = 1000
checkpoint_callback = ModelCheckpoint(
            save_top_k=3,
            verbose=True,
            monitor="loss_g",
            mode="min",
            filename="model-epoch{epoch:02d}-val_loss_g_{loss_g:.2f}",
            auto_insert_metric_name=False,
            dirpath="../data/interim/tboard/{}/{}/models/".format(name,version),
        )


transforms = A.Compose([
            A.Resize(64, 64),
            A.Normalize(mean=0, std=1, max_pixel_value=255),
            ToTensorV2(),
        ])
#train_dataset = MonetDataset(dataroot="../data/raw/monet/",transforms=transforms,shuffle=True,complete_dataset=False)

train_dataset = AbstractArtDataset(dataroot="../data/raw/art/",
                             transforms=transforms,
                             size=(64,64))

In [16]:
# Create the dataloader
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=workers)
# Create the generator
model = GAN(100)
trainer = pl.Trainer(
    logger=logger,
    max_epochs=epochs,
    log_every_n_steps=1,
    gpus=1,
    callbacks=[checkpoint_callback],
    enable_checkpointing=True,
    enable_progress_bar=True,
    val_check_interval=10,
    #resume_from_checkpoint="../data/interim/tboard/DCGAN/64x64Images_NoNormGenLast/models/model-epoch81-val_loss_g_1.06.ckpt"
)
trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=train_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params
---------------------------------------------
0 | generator     | Sequential | 12.7 M
1 | discriminator | Sequential | 2.8 M 
2 | loss          | BCELoss    | 0     
---------------------------------------------
15.4 M    Trainable params
0         Non-trainable params
15.4 M    Total params
61.696    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9: loss_g reached 2.43019 (best 2.43019), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch00-val_loss_g_2.43.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 19: loss_g reached 3.11348 (best 2.43019), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch00-val_loss_g_3.11.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 32: loss_g reached 3.23749 (best 2.43019), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch01-val_loss_g_3.24.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 42: loss_g reached 2.40476 (best 2.40476), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch01-val_loss_g_2.40.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 55: loss_g reached 2.68225 (best 2.40476), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch02-val_loss_g_2.68.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 65: loss_g reached 1.61375 (best 1.61375), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch02-val_loss_g_1.61.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 78: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 88: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 101: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 111: loss_g reached 1.82553 (best 1.61375), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch04-val_loss_g_1.83.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 124: loss_g reached 2.19613 (best 1.61375), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch05-val_loss_g_2.20.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 134: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 147: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 157: loss_g reached 1.45907 (best 1.45907), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch06-val_loss_g_1.46.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 170: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 180: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 193: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 203: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 216: loss_g reached 1.34105 (best 1.34105), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch09-val_loss_g_1.34.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 226: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 239: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 249: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 11, global step 262: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 11, global step 272: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 285: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 295: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 13, global step 308: loss_g reached 1.39314 (best 1.34105), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch13-val_loss_g_1.39.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 13, global step 318: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 331: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 341: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 354: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 364: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 16, global step 377: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 16, global step 387: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 17, global step 400: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 17, global step 410: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 18, global step 423: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 18, global step 433: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 446: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 456: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 20, global step 469: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 20, global step 479: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 21, global step 492: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 21, global step 502: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 22, global step 515: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 22, global step 525: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 23, global step 538: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 23, global step 548: loss_g reached 1.44023 (best 1.34105), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch23-val_loss_g_1.44.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 24, global step 561: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 24, global step 571: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 25, global step 584: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 25, global step 594: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 26, global step 607: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 26, global step 617: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 27, global step 630: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 27, global step 640: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 28, global step 653: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 28, global step 663: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 29, global step 676: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 29, global step 686: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 30, global step 699: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 30, global step 709: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 31, global step 722: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 31, global step 732: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 32, global step 745: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 32, global step 755: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 33, global step 768: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 33, global step 778: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 34, global step 791: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 34, global step 801: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 35, global step 814: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 35, global step 824: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 36, global step 837: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 36, global step 847: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 37, global step 860: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 37, global step 870: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 38, global step 883: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 38, global step 893: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 39, global step 906: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 39, global step 916: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 40, global step 929: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 40, global step 939: loss_g reached 0.93683 (best 0.93683), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch40-val_loss_g_0.94.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 41, global step 952: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 41, global step 962: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 42, global step 975: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 42, global step 985: loss_g reached 1.20408 (best 0.93683), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch42-val_loss_g_1.20.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 43, global step 998: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 43, global step 1008: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 44, global step 1021: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 44, global step 1031: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 45, global step 1044: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 45, global step 1054: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 46, global step 1067: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 46, global step 1077: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 47, global step 1090: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 47, global step 1100: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 48, global step 1113: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 48, global step 1123: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 49, global step 1136: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 49, global step 1146: loss_g reached 1.25900 (best 0.93683), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch49-val_loss_g_1.26.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 50, global step 1159: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 50, global step 1169: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 51, global step 1182: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 51, global step 1192: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 52, global step 1205: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 52, global step 1215: loss_g reached 0.87925 (best 0.87925), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch52-val_loss_g_0.88.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 53, global step 1228: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 53, global step 1238: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 54, global step 1251: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 54, global step 1261: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 55, global step 1274: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 55, global step 1284: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 56, global step 1297: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 56, global step 1307: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 57, global step 1320: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 57, global step 1330: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 58, global step 1343: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 58, global step 1353: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 59, global step 1366: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 59, global step 1376: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 60, global step 1389: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 60, global step 1399: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 61, global step 1412: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 61, global step 1422: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 62, global step 1435: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 62, global step 1445: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 63, global step 1458: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 63, global step 1468: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 64, global step 1481: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 64, global step 1491: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 65, global step 1504: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 65, global step 1514: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 66, global step 1527: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 66, global step 1537: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 67, global step 1550: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 67, global step 1560: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 68, global step 1573: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 68, global step 1583: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 69, global step 1596: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 69, global step 1606: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 70, global step 1619: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 70, global step 1629: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 71, global step 1642: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 71, global step 1652: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 72, global step 1665: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 72, global step 1675: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 73, global step 1688: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 73, global step 1698: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 74, global step 1711: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 74, global step 1721: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 75, global step 1734: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 75, global step 1744: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 76, global step 1757: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 76, global step 1767: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 77, global step 1780: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 77, global step 1790: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 78, global step 1803: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 78, global step 1813: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 79, global step 1826: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 79, global step 1836: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 80, global step 1849: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 80, global step 1859: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 81, global step 1872: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 81, global step 1882: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 82, global step 1895: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 82, global step 1905: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 83, global step 1918: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 83, global step 1928: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 84, global step 1941: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 84, global step 1951: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 85, global step 1964: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 85, global step 1974: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 86, global step 1987: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 86, global step 1997: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 87, global step 2010: loss_g reached 1.20315 (best 0.87925), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch87-val_loss_g_1.20.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 87, global step 2020: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 88, global step 2033: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 88, global step 2043: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 89, global step 2056: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 89, global step 2066: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 90, global step 2079: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 90, global step 2089: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 91, global step 2102: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 91, global step 2112: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 92, global step 2125: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 92, global step 2135: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 93, global step 2148: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 93, global step 2158: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 94, global step 2171: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 94, global step 2181: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 95, global step 2194: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 95, global step 2204: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 96, global step 2217: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 96, global step 2227: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 97, global step 2240: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 97, global step 2250: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 98, global step 2263: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 98, global step 2273: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 99, global step 2286: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 99, global step 2296: loss_g reached 1.14887 (best 0.87925), saving model to "/home/aahan/git_repos/LightningGan/data/interim/tboard/DCGAN_Abstract/64x64Images_NoNormGenLast_1/models/model-epoch99-val_loss_g_1.15.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 100, global step 2309: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 100, global step 2319: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 101, global step 2332: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 101, global step 2342: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 102, global step 2355: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 102, global step 2365: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 103, global step 2378: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 103, global step 2388: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 104, global step 2401: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 104, global step 2411: loss_g was not in top 3


Validating: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
m = GAN.load_from_checkpoint("../data/interim/tboard/DCGAN_Monet/64x64Images_NoNormGenLast_4/models/model-epoch140-val_loss_g_1.75.ckpt",latent_size=100)

In [None]:
tmp = torch.randn(1,100,1,1,requires_grad=False,device=model.device)
o = m.generator(tmp.detach())
img = o * 127.5 + 127.5
img = img.squeeze().permute(1, 2, 0).cpu().detach().numpy().astype(np.uint8)
plt.imshow(img)
plt.show()

In [None]:
# Root directory for dataset
dataroot = "../data/raw/monet/"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 1000

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5# Create the dataloader

transforms = A.Compose([
            A.Resize(64, 64),
            A.Normalize(mean=0, std=1, max_pixel_value=255),
            ToTensorV2(),
        ])
train_dataset = MonetDataset(dataroot="../data/raw/monet/",
                         transforms=transforms,complete_dataset=False)


dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=workers)

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

In [None]:
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)

In [None]:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(weights_init)

# Print the model
print(netD)

In [None]:
# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 50 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()