In [17]:
import pytorch_lightning as pl
from skimage import io
from skimage.io import ImageCollection

import torch

from typing import Optional
from torch.utils.data import random_split, DataLoader, Dataset, TensorDataset, ConcatDataset

import torchvision
from torchvision.datasets import ImageFolder

import os
from typing import Any, Callable, List, Optional, Tuple

import torchvision.transforms as transforms
import torchvision.transforms.functional as TVF

from networks import *

In [18]:
# Reload edited Python modules (e.g. the local `networks` module) automatically
# before each cell executes, so changes to .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
# Seed the global RNGs through Lightning so random operations below
# (e.g. the train/valid random_split) are reproducible.
SEED = 42
pl.seed_everything(SEED)

Global seed set to 42


42

In [3]:
def show_image(image: torch.Tensor):
    """Display a single channels-first image tensor with ``skimage.io.imshow``.

    The tensor is detached from the autograd graph and copied to the CPU before
    conversion. ``Tensor.numpy()`` refuses tensors that require grad or live on
    a GPU, which is exactly what model outputs usually are.

    Args:
        image: 3-D tensor laid out as (channels, height, width); it is permuted
            to (height, width, channels) for plotting.
    """
    io.imshow(image.detach().cpu().permute(1, 2, 0).numpy())

In [4]:
class UnpairedDataSet(Dataset):
    """Dataset of unaligned image pairs drawn from two directories (domains A and B).

    Images are collected lazily from ``{data_dir}/{dirA}`` and ``{data_dir}/{dirB}``
    (``*.jpg`` and ``*.png`` files). The two domains may contain different numbers
    of images. The dataset length is that of the larger domain, and the smaller
    domain is cycled, so each index still yields one image from each domain.

    Items are dicts ``{'A': tensor, 'B': tensor}``.
    """

    def __init__(self, data_dir: str, dirA: str, dirB: str):
        super().__init__()

        self.imagesA = ImageCollection([f'{data_dir}/{dirA}/*.jpg', f'{data_dir}/{dirA}/*.png'])
        self.imagesB = ImageCollection([f'{data_dir}/{dirB}/*.jpg', f'{data_dir}/{dirB}/*.png'])

        # ToTensor turns HWC image arrays into CHW tensors
        # (scaled to [0, 1] for uint8 input).
        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        # Length of the larger domain; the smaller one is wrapped in __getitem__.
        return max(len(self.imagesA), len(self.imagesB))

    def __getitem__(self, index: int):
        # Wrap each domain by its *own* size. Using the combined length here
        # would index past the end of the smaller domain.
        return {
            'A': self.transforms(self.imagesA[index % len(self.imagesA)]),
            'B': self.transforms(self.imagesB[index % len(self.imagesB)]),
        }

In [5]:
class UnpairedDataModule(pl.LightningDataModule):
    """Lightning data module serving unpaired A/B images.

    ``trainA``/``trainB`` (relative to ``data_dir``) are combined into one
    ``UnpairedDataSet``. That set is randomly split into train and validation
    subsets, with a ``train_frac`` fraction going to training.
    ``testA``/``testB`` form the test set.

    Args:
        data_dir: root directory holding the domain sub-directories.
        trainA, trainB: training sub-directories for domains A and B.
        testA, testB: test sub-directories for domains A and B.
        train_frac: fraction of the training samples used for training; the
            rest is used for validation.
        batch_size: batch size for every DataLoader.
        num_workers: worker processes for every DataLoader.
    """

    def __init__(self, data_dir: str = 'data', trainA: str = 'trainA', trainB: str = 'trainB', testA: str = 'testA', testB: str = 'testB', train_frac=0.9, batch_size=4, num_workers=4):
        super().__init__()

        self.data_dir = data_dir

        self.trainA = trainA
        self.trainB = trainB
        self.testA = testA
        self.testB = testB

        self.train_frac = train_frac
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transforms = None

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage`` ('fit', 'test', or None for both)."""
        if stage in (None, 'fit'):
            fit_dataset = UnpairedDataSet(self.data_dir, self.trainA, self.trainB)

            train_size = int(len(fit_dataset) * self.train_frac)
            valid_size = len(fit_dataset) - train_size

            self.train, self.valid = random_split(fit_dataset, [train_size, valid_size])

        if stage in (None, 'test'):
            # Build the test set so test_dataloader has a dataset to load.
            self.test = UnpairedDataSet(self.data_dir, self.testA, self.testB)

    def _make_dataloader(self, dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        """Build a DataLoader over ``dataset`` with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        # drop_last=True discards a trailing incomplete batch during training.
        return self._make_dataloader(self.train, shuffle=True, drop_last=True)

    def val_dataloader(self):
        # Loads the held-out split, keeping the final partial batch. Otherwise a
        # validation split smaller than batch_size would yield no batches at all.
        return self._make_dataloader(self.valid, shuffle=False, drop_last=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test, shuffle=False, drop_last=False)

In [6]:
# Build the data module with its default layout (data/trainA, data/trainB, ...);
# stage=None runs every branch of setup().
dm = UnpairedDataModule(data_dir='data', batch_size=4)
dm.setup(stage=None)

In [7]:
# Pull a single training batch to sanity-check what the loader yields.
train_loader = dm.train_dataloader()
data = next(iter(train_loader))

In [9]:
# Shape of the domain-A images in the batch.
data['A'].size()

torch.Size([4, 3, 256, 256])

In [19]:
# Instantiate the two generators: netG_A maps domain A -> B, netG_B maps B -> A.
# Plain constants replace `self`/`opt`, which do not exist at notebook top level.
# `networks` is imported explicitly because `from networks import *` does not
# bind the module name itself.
import networks

INPUT_NC = 3          # channels of domain-A images (data['A'] above has 3 channels)
OUTPUT_NC = 3         # channels of domain-B images
NGF = 64              # TODO(review): assumed value of opt.ngf - confirm
NET_G = 'resnet_9blocks'
NORM = 'instance'
USE_DROPOUT = False   # stands in for `not opt.no_dropout`; TODO(review): confirm
INIT_TYPE = 'normal'  # TODO(review): assumed value of opt.init_type - confirm
INIT_GAIN = 0.02      # TODO(review): assumed value of opt.init_gain - confirm
GPU_IDS = []          # NOTE(review): assumes device placement is left to Lightning

# Both generators share one architecture; only the channel direction is swapped.
netG_A = networks.define_G(INPUT_NC, OUTPUT_NC, NGF, NET_G, NORM,
                           USE_DROPOUT, INIT_TYPE, INIT_GAIN, GPU_IDS)
netG_B = networks.define_G(OUTPUT_NC, INPUT_NC, NGF, NET_G, NORM,
                           USE_DROPOUT, INIT_TYPE, INIT_GAIN, GPU_IDS)

In [None]:
class CycleGAN(pl.LightningModule):
    """Skeleton CycleGAN LightningModule.

    Every step hook is still a stub that raises NotImplementedError.
    """

    def __init__(self):
        super(CycleGAN, self).__init__()

    def training_step(self, batch, batch_idx):
        """Training hook (stub)."""
        raise NotImplementedError()

    def validation_step(self, batch, batch_idx):
        """Validation hook (stub)."""
        raise NotImplementedError()

    def test_step(self, batch, batch_idx):
        """Test hook (stub)."""
        raise NotImplementedError()

    def predict_step(self, batch, batch_idx):
        """Prediction hook (stub)."""
        raise NotImplementedError()

    def shared_step(self, batch):
        """Presumably logic to be shared by the step hooks (stub; not yet called anywhere)."""
        raise NotImplementedError()

