In [None]:
import os
import sys

import numpy as np
import torch
from PIL import Image
from skimage.color import rgb2lab
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from models import Unet, Discriminator

# Constants that need to be set:

In [None]:
MODEL_SAVE_PATH = 'model/'  # Directory the trained generator's state_dict is written to
# Training images: keep exactly one DATASET_PATH uncommented.
# DATASET_PATH = 'dataset/cifar/train/' # Dir with 32x32 images
# DATASET_PATH = 'dataset/small/train/' # Dir with 32x32 images
DATASET_PATH = 'dataset/medium/train/' # Dir with 64x64 images
# DATASET_PATH = 'dataset/large/train/' # Dir with 256x256 images
EPOCHS = 30
BATCH_SIZE = 16
SEED = 42  # Seeds torch's RNGs (CPU and CUDA) so weight init and shuffling repeat across runs

# Seed before the models and DataLoader are created so both draw from a known RNG state.
# NOTE: some CUDA/cuDNN kernels are still non-deterministic; this makes runs
# repeatable, not bit-identical on GPU.
torch.manual_seed(SEED)

In [None]:
# Load the TensorBoard notebook extension and start an inline TensorBoard
# that reads event files from the ./runs directory.
%load_ext tensorboard
%tensorboard --logdir=runs

# Meters - A handy class from the PyTorch ImageNet tutorial

In [None]:
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a scalar metric.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted total of every value seen (each value times its ``n``).
        count: total weight seen so far.
        avg: ``sum / count`` after the last ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch-mean loss over ``n`` samples)."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count


# TensorBoard event writer for the per-epoch training losses. With no log_dir
# argument it writes to ./runs/<datetime>_<hostname>, i.e. under the ./runs
# directory the inline TensorBoard is pointed at.
writer = SummaryWriter()

# Dataset that loads images in the CIELAB (Lab) color space

In [None]:
class ColorizationDataset(Dataset):
    """Images from one directory, returned as normalized Lab-space tensors.

    Each item is a dict:
        'L':  lightness channel, shape (1, H, W); rgb2lab's [0, 100] range is
              mapped to [-1, 1] via ``L / 50 - 1``.
        'ab': a/b chroma channels, shape (2, H, W), divided by 110 so they
              land roughly in [-1, 1].

    Args:
        dataset_path: directory whose regular files are all readable images.
    """
    def __init__(self, dataset_path):
        # Sort so the index -> file mapping is deterministic (os.listdir order
        # depends on the filesystem) and skip subdirectories, which PIL cannot open.
        self.paths = sorted(
            os.path.join(dataset_path, name)
            for name in os.listdir(dataset_path)
            if os.path.isfile(os.path.join(dataset_path, name))
        )

    def __getitem__(self, idx):
        # `with` releases the file handle deterministically, even if decoding fails.
        with Image.open(self.paths[idx]) as img:
            rgb = np.array(img.convert("RGB"))  # (H, W, 3) uint8
        img_lab = rgb2lab(rgb).astype("float32")  # (H, W, 3)
        # ToTensor only rescales uint8 input; a float32 array is just moved to (3, H, W).
        img_lab = transforms.ToTensor()(img_lab)
        # The list index [0] keeps the channel dimension, giving (1, H, W).
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

# Prepare dataset loader

In [None]:
dataset = ColorizationDataset(DATASET_PATH)
# shuffle=True: without it the loader yields the same batches in the same order
# every epoch, which biases SGD updates and hurts GAN training.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

# Create models

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

generator = Unet().to(device)
discriminator = Discriminator().to(device)

# 0-dim float label tensors (1.0 and 0.0) allocated directly on `device`;
# they must be expanded to the prediction shape before being used as BCE targets.
discriminator_true_output = torch.tensor(1.0, device=device)
discriminator_false_output = torch.tensor(0.0, device=device)

# Create loss functions

In [None]:
# Adversarial loss on raw discriminator logits (the sigmoid is applied inside),
# plus an L1 reconstruction loss; both average over all elements.
loss_function, l1_loss = nn.BCEWithLogitsLoss(), nn.L1Loss()

# Create optimizers

In [None]:
# Both networks get identical Adam settings. beta1 = 0.5 (instead of Adam's
# default 0.9) reduces momentum, a common choice when training GANs.
generator_optimizer, discriminator_optimizer = (
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Train

In [None]:
# Per-epoch running averages. In this notebook "val"/"validator" means the
# discriminator: _loss_gen tracks the generator loss, _loss_val the
# discriminator loss. (Previously these were swapped, so the TensorBoard
# "Loss/generator" curve actually showed the discriminator loss.)
_loss_gen = AverageMeter()
_loss_val = AverageMeter()

L1_WEIGHT = 100  # Weight of the L1 reconstruction term relative to the adversarial term
PRINT_EVERY = 32  # Print progress every this many iterations

# Nothing in this notebook switches the models to eval mode, so set train mode once.
generator.train()
discriminator.train()

for epoch in range(EPOCHS):
    _loss_gen.reset()
    _loss_val.reset()
    for i, data in enumerate(data_loader, start=1):
        # L: lightness input the generator conditions on; ab: ground-truth colour channels.
        L = data['L'].to(device)
        ab = data['ab'].to(device)

        # Predict colours once; the graph is reused by the generator update below.
        fake_color = generator(L)
        fake_image = torch.cat([L, fake_color], dim=1)
        real_image = torch.cat([L, ab], dim=1)

        # --- Discriminator step ---
        # Label convention used consistently here: generated images -> 1.0
        # (discriminator_true_output), real images -> 0.0 (discriminator_false_output).
        discriminator.requires_grad_(True)
        discriminator_optimizer.zero_grad()
        # detach(): the discriminator loss must not backpropagate into the generator.
        fake_preds = discriminator(fake_image.detach())
        discriminator_loss_fake = loss_function(fake_preds, discriminator_true_output.expand_as(fake_preds))
        real_preds = discriminator(real_image)
        discriminator_loss_real = loss_function(real_preds, discriminator_false_output.expand_as(real_preds))
        discriminator_loss = (discriminator_loss_fake + discriminator_loss_real) * 0.5
        discriminator_loss.backward()
        discriminator_optimizer.step()
        _loss_val.update(discriminator_loss.item(), L.size(0))

        # --- Generator step ---
        # Freeze the discriminator so only generator weights receive gradients.
        discriminator.requires_grad_(False)
        generator_optimizer.zero_grad()
        fake_preds = discriminator(fake_image)
        # Fool the discriminator: push its output on generated images toward the
        # label it uses for real images (0.0).
        generator_adv_loss = loss_function(fake_preds, discriminator_false_output.expand_as(fake_preds))
        generator_l1_loss = l1_loss(fake_color, ab) * L1_WEIGHT
        generator_loss = generator_adv_loss + generator_l1_loss
        generator_loss.backward()
        generator_optimizer.step()
        _loss_gen.update(generator_loss.item(), L.size(0))

        if i % PRINT_EVERY == 0:
            print(f"\nEpoch {epoch+1}/{EPOCHS}")
            print(f"Iteration {i}/{len(data_loader)}")
            print(f"Loss gen {_loss_gen.val:.4f} ({_loss_gen.avg:.4f})")
            print(f"Loss val {_loss_val.val:.4f} ({_loss_val.avg:.4f})")

    # Log epoch-average losses (tag "validator" = discriminator).
    writer.add_scalar("Loss/generator", _loss_gen.avg, epoch)
    writer.add_scalar("Loss/validator", _loss_val.avg, epoch)
    writer.flush()

# Save model

In [None]:
# torch.save does not create missing parent directories, so ensure the target exists.
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
# Only the generator's weights (its state_dict) are saved, not the module itself.
torch.save(generator.state_dict(), os.path.join(MODEL_SAVE_PATH, 'model.pt'))