### Imports

In [6]:
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import os
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from skimage import color # For rgb2la & lab2rgb

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

Device: cuda


### Data loading

In [7]:
# Loader configuration -- the knobs a reader is most likely to tune.
DATA_DIR = './data'   # root holding the train/ and val/ image folders
IMG_SIZE = 224        # every image is resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
NUM_WORKERS = 4       # DataLoader worker processes

In [36]:
class ColorizationDataset(Dataset):
    """Image dataset that yields normalised L*a*b* tensors for colorization.

    Each item is a dict with:
      * 'L'     -- lightness channel, shape (1, IMG_SIZE, IMG_SIZE), mapped to [-1, 1]
      * 'ab'    -- a*/b* chroma channels, shape (2, IMG_SIZE, IMG_SIZE), divided by 110
                   (keeps sRGB-derived values within roughly [-1, 1])
      * 'class' -- name of the image file's parent directory

    Args:
        paths: sequence of image file paths (str or os.PathLike).
        split: 'train' (bicubic resize + random horizontal flip) or
            'val' (bicubic resize only).

    Raises:
        ValueError: if ``split`` is neither 'train' nor 'val'.
    """

    def __init__(self, paths, split='train'):
        resize = transforms.Resize(
            (IMG_SIZE, IMG_SIZE),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                # (Train-only) data augmentation
                # TODO: Try different variants? E.g.
                # · random cropping (difficult to classify if subject is out of frame...)
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Without this, an unknown split left self.transforms unset and only
            # failed later, inside __getitem__, with an AttributeError.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.paths = paths

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        # The context manager closes the file handle. .convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as img:
            img_rgb = img.convert("RGB")
        img_rgb = self.transforms(img_rgb)
        img_rgb = np.array(img_rgb)
        img_lab = color.rgb2lab(img_rgb).astype("float32")
        # ToTensor only permutes HWC -> CHW here. It rescales to [0, 1] only for
        # uint8 input, and this array is float32 Lab, so the values are untouched.
        img_lab = transforms.ToTensor()(img_lab)
        L  = img_lab[[0], ...] / 50. - 1. # L* in [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2], ...] / 110. # a*, b* -> roughly [-1, 1]

        # Class label = parent directory name (.../<split>/<class>/<file>).
        # os.path handles both separators on Windows and accepts PathLike input.
        label = os.path.basename(os.path.dirname(img_path))

        # Store L and ab channels, as well as class
        return { 'L': L, 'ab': ab, 'class': label }

    def __len__(self):
        return len(self.paths)

In [37]:
def collect_jpeg_paths(root):
    """Return the sorted paths of every ``.JPEG`` file anywhere under ``root``.

    Sorting makes the order independent of the filesystem's os.walk order, so
    the path lists (and thus seeded shuffles over them) are reproducible.
    """
    return sorted(
        os.path.join(subdir, file)
        for subdir, _dirs, files in os.walk(root)
        for file in files
        if file.endswith('.JPEG')
    )


train_root = os.path.join(DATA_DIR, "train")
val_root = os.path.join(DATA_DIR, "val")
train_paths = collect_jpeg_paths(train_root)
val_paths = collect_jpeg_paths(val_root)
# Fail early with a clear message (e.g. a wrong DATA_DIR) instead of a less
# obvious error further down when building or iterating the loaders.
assert train_paths, f"no .JPEG files found under {train_root}"
assert val_paths, f"no .JPEG files found under {val_root}"

train_dataset = ColorizationDataset(train_paths, split='train')
val_dataset   = ColorizationDataset(val_paths,   split='val')
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True, shuffle=True)
# NOTE(review): shuffling the val loader is unusual. Keep it only if a later cell
# relies on the first val batch being a random sample (e.g. for visualisation).
val_dataloader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True, shuffle=True)

print("Train:", len(train_dataset), "samples ·", len(train_dataloader), "batches")
print("Val:", len(val_dataset), "samples ·", len(val_dataloader), "batches")

Train: 6500 samples · 204 batches
Val: 250 samples · 8 batches




In [43]:
# Sanity check: fetch one training batch and verify its layout with assertions
# instead of eyeballing the printed shapes. next() raises if the loader is empty,
# whereas a for/break loop would silently print nothing.
batch = next(iter(train_dataloader))
print(batch['L'].shape, batch['ab'].shape, batch['class'])

n_items = batch['L'].shape[0]
assert 0 < n_items <= BATCH_SIZE
assert batch['L'].shape == (n_items, 1, IMG_SIZE, IMG_SIZE)
assert batch['ab'].shape == (n_items, 2, IMG_SIZE, IMG_SIZE)
assert len(batch['class']) == n_items
# L was mapped from [0, 100] to [-1, 1] in the dataset; allow a little float slack.
assert batch['L'].min() >= -1 - 1e-3 and batch['L'].max() <= 1 + 1e-3

torch.Size([32, 1, 224, 224]) torch.Size([32, 2, 224, 224]) ['lionfish', 'sea_anemone', 'sea_cucumber', 'sea_anemone', 'pufferfish', 'sea_cucumber', 'lionfish', 'sea_cucumber', 'sea_anemone', 'sea_anemone', 'pufferfish', 'sea_anemone', 'sea_snake', 'pufferfish', 'sea_snake', 'sea_snake', 'sea_cucumber', 'pufferfish', 'pufferfish', 'sea_anemone', 'sea_cucumber', 'sea_anemone', 'pufferfish', 'sea_snake', 'sea_snake', 'sea_cucumber', 'lionfish', 'sea_snake', 'pufferfish', 'pufferfish', 'lionfish', 'lionfish']
