In [1]:
import os
from typing import Optional

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

import pytorch_lightning as pl

from skimage import io

In [2]:
# Folder holding the training images. Built with os.path.join so the path
# uses the platform's separator and avoids the invalid "\A" escape sequence
# that a plain 'data\Abstract_gallery' literal contains.
DATA_DIR = os.path.join('data', 'Abstract_gallery')
# Edge length (pixels) that the data module's transform pipeline resizes to.
IMAGE_SIZE = 64
# Default batch size for the data loaders.
BATCH_SIZE = 32
# Number of channels produced by the generator's final layer.
IMAGE_CHANNELS = 3
# NOTE(review): presumably the generator's latent vector length; not used in
# the cells visible here - confirm against the training cell.
LATENT_SIZE = 256
# NOTE(review): presumably the number of training epochs - confirm.
EPOCHS = 20

# Seed torch's global RNG so runs are repeatable.
random_seed = 42
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1d0b29646d0>

In [3]:
class AbstractGalleryDataset(Dataset):
    """Dataset that serves the image files stored directly inside one folder.

    The folder is listed once at construction time. Only regular files are
    kept (sub-directories are ignored) and the names are sorted, so a given
    index always refers to the same file and ``__len__`` and ``__getitem__``
    agree on what the dataset contains.

    Args:
        root_dir: Path of the folder containing the images.
        transform: Optional callable applied to each loaded image before it
            is returned.

    Raises:
        OSError: If ``root_dir`` cannot be listed.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Snapshot the file names once instead of re-listing the directory on
        # every __len__/__getitem__ call; sorting makes the order stable.
        self.image_files = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it, transformed if requested.

        Raises:
            IndexError: If ``index`` is outside the range of known files.
        """
        # os.listdir yields bare names, so the folder has to be prepended
        # before the file can be opened.
        image_path = os.path.join(self.root_dir, self.image_files[index])
        image = io.imread(image_path)

        if self.transform is not None:
            # Call the transform; assigning it would return the transform
            # object itself instead of the processed image.
            image = self.transform(image)

        return image

In [4]:
class AbstractGalleryDataModule(pl.LightningDataModule):
    """Lightning data module that splits the gallery into train/val/test sets.

    Args:
        data_dir: Folder containing the images (defaults to ``DATA_DIR``).
        batch_size: Batch size used by all three data loaders.
    """

    def __init__(self, data_dir=DATA_DIR, batch_size=BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [
                # ToTensor must come first: the dataset yields NumPy arrays,
                # and Resize only accepts PIL images or tensors.
                transforms.ToTensor(),
                # Resize(int) only fixes the shorter edge; the centre crop
                # then makes every image IMAGE_SIZE x IMAGE_SIZE so samples
                # can be stacked into a batch.
                transforms.Resize(IMAGE_SIZE),
                transforms.CenterCrop(IMAGE_SIZE),
                # Maps [0, 1] values to [-1, 1].
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def setup(self, stage: Optional[str] = None):
        """Create the 70% / 15% / 15% train / test / val subsets.

        The split is done on indices and wrapped in ``Subset`` objects, so
        images are loaded lazily by the data loaders rather than all being
        read into memory here.
        """
        # Honour the directory passed to the constructor.
        dataset = AbstractGalleryDataset(self.data_dir, transform=self.transform)
        indices = list(range(len(dataset)))

        # Fixed random_state keeps the partition identical across runs.
        train_idx, holdout_idx = train_test_split(indices, test_size=0.3, random_state=42)
        test_idx, val_idx = train_test_split(holdout_idx, test_size=0.5, random_state=42)

        self.train = torch.utils.data.Subset(dataset, train_idx)
        self.test = torch.utils.data.Subset(dataset, test_idx)
        self.val = torch.utils.data.Subset(dataset, val_idx)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.test, batch_size=self.batch_size)

In [5]:
def weights_init(m):
    """Re-initialise one module in place, chosen by its class name.

    Modules whose class name contains ``Conv`` get weights drawn from a
    normal distribution with mean 0.0 and std 0.02. Modules whose class name
    contains ``BatchNorm`` get weights drawn from a normal distribution with
    mean 1.0 and std 0.02 and a bias of zero. Any other module is left
    untouched.

    Args:
        m: The module to initialise.
    """
    layer_name = m.__class__.__name__

    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator that maps a latent code to an image.

    Given an input of shape ``(N, latent_dim, 1, 1)`` the layer stack
    upsamples 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels per side and ends with
    ``Tanh``, so the output has shape ``(N, IMAGE_CHANNELS, 64, 64)`` with
    values in [-1, 1].

    Args:
        latent_dim: Number of channels of the latent input.
    """

    def __init__(self, latent_dim):
        super().__init__()

        self.model = nn.Sequential(
            # latent_dim x 1 x 1 -> 512 x 4 x 4
            nn.ConvTranspose2d(latent_dim, 64 * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            # 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(64 * 8, 64 * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            # 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(64 * 4, 64 * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            # 128 x 16 x 16 -> 64 x 32 x 32. The input channels must equal the
            # previous layer's 128 outputs, and the batch norm must match the
            # 64 channels produced here.
            nn.ConvTranspose2d(64 * 2, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 32 x 32 -> IMAGE_CHANNELS x 64 x 64
            nn.ConvTranspose2d(64, IMAGE_CHANNELS, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

        # Initialise only after the layers exist, and pass the function
        # itself: apply() calls it once for every sub-module.
        self.apply(weights_init)

    def forward(self, z):
        """Run the latent batch ``z`` through the layer stack."""
        return self.model(z)