In [1]:
import os

from torch.utils.data import DataLoader, Dataset
from imutils.paths import list_images
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import numpy as np
import torch
import cv2

from misc.model import Generator, Discriminator

In [2]:
SEP        = os.path.sep
# Repository root, assumed to sit three directory levels above the notebook's
# working directory — TODO confirm if the notebook is moved or run elsewhere.
ROOT_PATH  = SEP.join(os.getcwd().split(SEP)[:-3])
DATA_PATH  = f'{ROOT_PATH}/Dataset/apple2orange'

DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 8
EPOCHS     = 100
LR         = 2e-4
SEED       = 42

# Seed torch's global RNG so random transforms (e.g. RandomCrop), loader
# shuffling and weight initialisation are reproducible across runs.
torch.manual_seed(SEED);

In [3]:
def preprocess_image(image):
    """Rescale 8-bit pixel values from [0, 255] to [-1, 1].

    Operates element-wise on anything supporting ``/`` and ``-`` (e.g. a NumPy
    array): 0 maps to -1, 127.5 to 0 and 255 to 1.
    """
    scaled = image / 127.5
    return scaled - 1


def build_dataset(path):
    """Load every image found by ``list_images(path)`` into one float32 array.

    Each file is read with OpenCV, converted from BGR to RGB and scaled to
    [-1, 1] by ``preprocess_image``. Paths are sorted so the sample order does
    not depend on the order in which ``list_images`` yields them.

    Parameters
    ----------
    path : str
        Directory handed to ``imutils.paths.list_images``.

    Returns
    -------
    np.ndarray
        float32 array stacking the preprocessed RGB images — presumably
        (N, H, W, 3) when all images share one size (TODO confirm).

    Raises
    ------
    ValueError
        If OpenCV cannot decode one of the listed files.
    """
    images = []

    for image_path in sorted(list_images(path)):

        bgr = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising;
        # catch it here rather than letting cvtColor fail with a cryptic error.
        if bgr is None:
            raise ValueError(f'Could not read image: {image_path}')

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        images.append(preprocess_image(rgb))

    return np.array(images, dtype = 'float32')
        
        

In [4]:
train_A = build_dataset(f'{DATA_PATH}/trainA')
train_B = build_dataset(f'{DATA_PATH}/trainB')

test_A  = build_dataset(f'{DATA_PATH}/testA')
test_B  = build_dataset(f'{DATA_PATH}/testB')

# DATA_PATH is derived from the working directory, so a wrong launch location
# can yield empty splits; fail here instead of deep inside training.
assert len(train_A) > 0 and len(train_B) > 0, f'No training images found under {DATA_PATH}'
assert len(test_A) > 0 and len(test_B) > 0, f'No test images found under {DATA_PATH}'

In [5]:
class AppleOrangeDataset(Dataset):
    """Map-style dataset yielding normalised image tensors for one domain.

    Parameters
    ----------
    images : sequence of H x W x 3 RGB arrays
        Either uint8 images in [0, 255] or floating-point images already
        scaled to [-1, 1] (what ``preprocess_image`` produces).
    dtype : {'train', 'test'}, default 'train'
        Selects the transform pipeline: 'train' resizes to 286x286 and takes a
        random 256x256 crop; 'test' resizes directly to 256x256.

    Each item is a float tensor of shape (3, 256, 256) with values in [-1, 1].
    """

    _SPLITS = ('train', 'test')

    def __init__(self, images, dtype = 'train'):

        # Fail at construction rather than with a KeyError on first access.
        if dtype not in self._SPLITS:
            raise ValueError(f"dtype must be one of {self._SPLITS}, got {dtype!r}")

        self.images = images
        self.dtype  = dtype

        # ToTensor maps uint8 [0, 255] to [0, 1]; Normalize(0.5, 0.5) then maps
        # that to [-1, 1], the range preprocess_image targets.
        normalize = transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])

        self.transforms          = {}
        self.transforms['train'] = transforms.Compose([
                                        transforms.ToPILImage(),
                                        transforms.Resize([286, 286]),
                                        transforms.RandomCrop(256),
                                        transforms.ToTensor(),
                                        normalize
                                    ])
        self.transforms['test']  = transforms.Compose([
                                        transforms.ToPILImage(),
                                        transforms.Resize([256, 256]),
                                        transforms.ToTensor(),
                                        normalize
                                    ])


    @staticmethod
    def _to_uint8(image):
        """Return ``image`` as a uint8 array that ``ToPILImage`` accepts.

        ``ToPILImage`` rejects 3-channel float ndarrays, so floating-point
        input is assumed to be in [-1, 1] and mapped back to [0, 255]; integer
        input is assumed to already be in [0, 255] and is cast to uint8.
        """
        image = np.asarray(image)
        if np.issubdtype(image.dtype, np.floating):
            image = np.clip(np.rint((image + 1.0) * 127.5), 0, 255)
        return image.astype(np.uint8, copy = False)


    def __getitem__(self, idx):

        image = self._to_uint8(self.images[idx])
        return self.transforms[self.dtype](image)


    def __len__(self):
        return len(self.images)

In [6]:
# Wrap each split's image array in a Dataset: training splits get the default
# random-crop 'train' pipeline, test splits the fixed-resize 'test' pipeline.
train_A, train_B = AppleOrangeDataset(train_A), AppleOrangeDataset(train_B)

test_A, test_B = (AppleOrangeDataset(images, dtype = 'test') for images in (test_A, test_B))

In [7]:
# One training loader per domain. shuffle=True re-orders each domain every
# epoch; without it, iterating the two loaders together would present the same
# fixed A/B pairings on every pass.
train_loader_A = DataLoader(train_A, batch_size = BATCH_SIZE, shuffle = True)
train_loader_B = DataLoader(train_B, batch_size = BATCH_SIZE, shuffle = True)