In [180]:
import os
from typing import Optional

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl

from skimage import io
from PIL import Image

In [181]:
DATA_DIR = 'data\Abstract_gallery'
CHECKPOINTS_DIR ='checkpoints'
IMAGE_SIZE = (64, 64)
BATCH_SIZE = 64
IMAGE_CHANNELS = 3
LATENT_SIZE = 256
EPOCHS = 20
NUM_WORKERS = 0
random_seed = 42
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1716a66e3f0>

In [182]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [183]:
class AbstractGalleryDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len([image for image in os.listdir(self.root_dir) if os.path.isfile(os.path.join(self.root_dir, image))])

    def __getitem__(self, index):
        image_path = os.path.join(self.root_dir, os.listdir(self.root_dir)[index])
        image = Image.open(image_path)
        # image = io.imread(image_path)

        if self.transform:
            image = self.transform(image)

        return image

In [184]:
class AbstractGalleryDataModule(pl.LightningDataModule):
    def __init__(self, data_dir=DATA_DIR, batch_size=BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [
                transforms.Resize(IMAGE_SIZE),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        )

    def setup(self, stage: Optional[str] = None):
        df = AbstractGalleryDataset(DATA_DIR, transform=self.transform)
        self.train, self.test = random_split(df, [1950, 832])
        # self.test, self.val = train_test_split(self.test, test_size=0.5, random_state=42)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=NUM_WORKERS)

    # def val_dataloader(self):
        # return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=NUM_WORKERS)

In [185]:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()

        self.model = nn.Sequential(
            # n x latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 64 * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            # n x 64*8 x 4 x 4
            nn.ConvTranspose2d(64 * 8, 64 * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            # n x 64*4 x 8 x 8
            nn.ConvTranspose2d(64 * 4, 64 * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            # n x 64*2 x 16 x 16
            nn.ConvTranspose2d(64 * 2, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # n x 64 x 32 x 32
            nn.ConvTranspose2d(64, IMAGE_CHANNELS, kernel_size=4, stride=2, padding=1),
            # n x IMAGE_CHANNELS x 64 x 64
            nn.Tanh()
        )

        self.apply(self._init_weights)


    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        if isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)


    def forward(self, x):
        return self.model(x)

In [186]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            # n x IMAGE_CHANNELS x 64 x 64
            nn.Conv2d(IMAGE_CHANNELS, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # n x 64 x 32 x 32
            nn.Conv2d(64, 64 * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2),
            # n x 64*2 x 16 x 16
            nn.Conv2d(64 * 2, 64 * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2),
            # n x 64*4 x 8 x 8
            nn.Conv2d(64 * 4, 64 * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2),
            # n x 64*8 x 4 x 4
            nn.Conv2d(64 * 8, 1, kernel_size=4, stride=1, padding=0),
            # n x 1 x 1 x 1
            nn.Sigmoid()
        )

        self.apply(self._init_weights)


    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        if isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, x):
        return self.model(x)

In [187]:
class DCGAN(pl.LightningModule):
    def __init__(self, lr=0.0002, latent_dim=100, b1=0.5, b2=0.999):
        super().__init__()
        self.save_hyperparameters()

        self.generator = Generator(self.hparams.latent_dim)
        self.discriminator = Discriminator()

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

    def forward(self, z):
        return  self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        real_imgs = batch

        # sample noise
        z = torch.randn(real_imgs.shape[0], self.hparams.latent_dim, 1, 1)
        z = z.type_as(real_imgs)

        # train generator max log(D(G(z)))
        if optimizer_idx == 0:
            fake_imgs = self(z)
            y_hat = self.discriminator(fake_imgs).reshape(-1)

            sample_imgs = fake_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image("generated_images", grid, 0)


            # y = torch.ones(real_imgs.size(0), 1)
            y = torch.ones_like(y_hat)
            y = y.type_as(real_imgs)

            generator_loss = self.adversarial_loss(y_hat, y)

            # log_dict = {'generator_loss': generator_loss}
            self.log('generator_loss', generator_loss, prog_bar=True, sync_dist=True)
            return generator_loss

        # train discriminator max log(D(x)) + log(1 - D(G(z)))
        if optimizer_idx == 1:
            y_hat_real = self.discriminator(real_imgs).reshape(-1)

            # y_real = torch.ones(real_imgs.size(0), 1)
            y_real = torch.ones_like(y_hat_real)
            y_real = y_real.type_as(real_imgs)

            real_loss = self.adversarial_loss(y_hat_real, y_real)

            y_hat_fake = self.discriminator(self(z).detach()).reshape(-1)

            # y_fake = torch.zeros(real_imgs.size(0), 1)
            y_fake = torch.zeros_like(y_hat_fake)
            y_fake = y_fake.type_as(real_imgs)


            fake_loss = self.adversarial_loss(y_hat_fake, y_fake)

            discriminator_loss = (real_loss + fake_loss) / 2

            # log_dict = {'discriminator_loss': discriminator_loss}
            self.log('discriminator_loss', discriminator_loss, prog_bar=True, sync_dist=True)
            return discriminator_loss

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        generator_optimizer = optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        return [generator_optimizer, discriminator_optimizer], []

    def on_epoch_end(self):
        z = self.validation_z.to(self.generator.model[0].weight)

        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

In [188]:
dm = AbstractGalleryDataModule()
model = DCGAN()

In [189]:
trainer = pl.Trainer(max_epochs=EPOCHS,
                     accelerator='gpu',
                     devices=1,
                     default_root_dir=CHECKPOINTS_DIR)
trainer.fit(model, dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 3.6 M 
1 | discriminator | Discriminator | 2.8 M 
------------------------------------------------
6.3 M     Trainable params
0         Non-trainable params
6.3 M     Total params
25.377    Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv_transpose2d, but got input of size: [8, 100]