In [1]:
# import the packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [13]:
# create dataset
from torch.utils.data import Dataset
from PIL import Image
import os

# real pizza dataset
class PizzaData(Dataset):
    """Real pizza images paired with their label rows.

    Images are read from ``<data_dir>/images`` and labels from
    ``<data_dir>/imageLabels.txt``. Every line of the label file is split on
    whitespace, so each label is a list of string tokens.

    The first ``train_size`` samples form the training split and the
    remaining ones form the test split.

    Args:
        train: if True use the training split, otherwise the test split.
        transform: optional callable applied to each opened PIL image.
        data_dir: dataset root containing ``images/`` and ``imageLabels.txt``.
        train_size: number of leading samples assigned to the training split.
    """

    def __init__(self, train=True, transform=None,
                 data_dir="pizzaGANdata/pizzaGANdata", train_size=8000):
        self.data_dir = data_dir
        self.image_path = os.path.join(self.data_dir, "images")
        # os.listdir returns entries in arbitrary order; sort so that the
        # image/label pairing and the split are reproducible.
        # NOTE(review): assumes the label rows follow sorted-filename order
        # -- confirm against the dataset.
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_path = os.path.join(self.data_dir, "imageLabels.txt")
        with open(self.label_path, "r") as f:
            self.label_list = [line.split() for line in f]
        self.transform = transform
        # split train and test sets
        if train:
            self.image_list = self.image_list[:train_size]
            self.label_list = self.label_list[:train_size]
        else:
            self.image_list = self.image_list[train_size:]
            self.label_list = self.label_list[train_size:]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened with PIL and passed through ``self.transform``
        when one was supplied; the label is the list of string tokens from
        the matching row of the label file.
        """
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        img = Image.open(img_item_path)
        if self.transform:
            # Use the instance's transform, not a global that happens to exist.
            img = self.transform(img)
        label = self.label_list[idx]

        return img, label

    def __len__(self):
        """Return the number of samples in this split."""
        # Every image must have exactly one label row.
        assert len(self.image_list) == len(self.label_list)
        return len(self.image_list)

# synthetic pizza dataset
# known issue: the synthetic data provides only 10 labels, whereas the real data provides 13
class PizzaSynData(Dataset):
    """Synthetic pizza images paired with their label rows.

    The split is selected by sub-directory: ``<data_dir>/train`` with
    ``trainLabels.txt`` or ``<data_dir>/test`` with ``testLabels.txt``.
    Images are read from the ``images`` folder inside the split directory.
    Every line of the label file is split on whitespace, so each label is a
    list of string tokens.

    Args:
        train: if True use the ``train`` split, otherwise the ``test`` split.
        transform: optional callable applied to each opened PIL image.
        data_dir: dataset root containing the ``train`` and ``test`` folders.
    """

    def __init__(self, train=True, transform=None,
                 data_dir="pizzaGANsyntheticdata/syntheticDataset"):
        if train:
            split_name, f_label_name = "train", "trainLabels.txt"
        else:
            split_name, f_label_name = "test", "testLabels.txt"
        # As before, self.data_dir points at the selected split directory.
        self.data_dir = os.path.join(data_dir, split_name)
        self.image_path = os.path.join(self.data_dir, "images")
        # os.listdir returns entries in arbitrary order; sort so that the
        # image/label pairing is reproducible.
        # NOTE(review): assumes the label rows follow sorted-filename order
        # -- confirm against the dataset.
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_path = os.path.join(self.data_dir, f_label_name)
        with open(self.label_path, "r") as f:
            self.label_list = [line.split() for line in f]
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened with PIL and passed through ``self.transform``
        when one was supplied; the label is the list of string tokens from
        the matching row of the label file.
        """
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        img = Image.open(img_item_path)
        if self.transform:
            # Use the instance's transform, not a global that happens to exist.
            img = self.transform(img)
        label = self.label_list[idx]

        return img, label

    def __len__(self):
        """Return the number of samples in this split."""
        # Every image must have exactly one label row.
        assert len(self.image_list) == len(self.label_list)
        return len(self.image_list)

In [14]:
# Preprocessing pipeline: resize every image to 32x32 so all samples share
# one (small) size, then convert it to a tensor.
transform_pizza = torchvision.transforms.Compose(
    [torchvision.transforms.Resize((32,32)), torchvision.transforms.ToTensor()]
)

# Train/test splits of the real and of the synthetic dataset.
pizzatrain = PizzaData(transform=transform_pizza, train=True)
pizzatest = PizzaData(transform=transform_pizza, train=False)
pizzasyntrain = PizzaSynData(transform=transform_pizza, train=True)
pizzasyntest = PizzaSynData(transform=transform_pizza, train=False)

# Data loaders: `real` chooses between the real and the synthetic datasets.
real = True
trainloader = torch.utils.data.DataLoader(
    pizzatrain if real else pizzasyntrain,
    batch_size=64, shuffle=True, num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    pizzatest if real else pizzasyntest,
    batch_size=64, shuffle=True, num_workers=2,
)

In [None]:
# view some images loaded
import matplotlib.pyplot as plt
import numpy as np
def show(img):
    """Draw an image tensor on the current matplotlib axes.

    ``img`` is converted with ``img.numpy()`` and its first axis is moved to
    the last position before plotting (presumably channels-first to
    channels-last -- confirm with the caller). Nearest-neighbour
    interpolation is used.
    """
    axes_last = img.numpy().transpose((1, 2, 0))
    plt.imshow(axes_last, interpolation='nearest')

# Take the images of the first training batch, show them as a grid and
# print the batch shape.
sample = next(iter(trainloader))[0]
show(torchvision.utils.make_grid(sample))
print(sample.shape)  ## dim 0 (64) is the batch size
                        ## dim 1 is the channel count: 1 would mean grayscale, 3 means RGB
                        ## the last two dims (32x32) are the image size (small here)