In [1]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from skimage import color # For rgb2la & lab2rgb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_colab = None

In [2]:
# !pip install fastai==2.4

In [3]:
DATA_DIR = './data'
BATCH_SIZE = 16
IMG_SIZE = 256
NUM_WORKERS = 4

Although we are using the same dataset and number of training samples, the exact 8000 images that you train your model on may vary (although we are seeding!) because the dataset here has only 20000 images with different ordering while I sampled 10000 images from the complete dataset.

### 1.2- Making Datasets and DataLoaders

I hope the code is self-explanatory. I'm resizing the images and flipping horizontally (flipping only if it is training set) and then I read an RGB image, convert it to Lab color space and separate the first (grayscale) channel and the color channels as my inputs and targets for the models  respectively. Then I'm making the data loaders.

In [4]:
class ColorizationDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE),  Image.BICUBIC),
                # (Train-only) data augmentation
                # TODO: Try different variants?
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((IMG_SIZE, IMG_SIZE),  Image.BICUBIC)

        self.split = split
        self.paths = paths
    
    def __getitem__(self, idx):
        img_path = self.paths[idx]
        img_rgb = Image.open(img_path).convert("RGB")
        img_rgb = self.transforms(img_rgb)
        img_rgb = np.array(img_rgb)
        img_lab = color.rgb2lab(img_rgb).astype("float32")
        img_lab = transforms.ToTensor()(img_lab)
        # TODO: Understand this (why not use a Torch transform?)
        L  = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
        ab = img_lab[[1, 2], ...] / 110. # Between -1 and 1
        
        # Store L and ab channels, as well as class
        return { 'L': L, 'ab': ab, 'class': self.map_class(img_path.split('/')[-2]) }
    
    def __len__(self):
        return len(self.paths)
        
    def map_class(self, cls):
        if cls=='sea_anemone':
            return 0
        if cls=='pufferfish':
            return 1
        if cls=='sea_cucumber':
            return 2
        if cls=='sea_snake':
            return 3
        if cls=='lionfish':
            return 4

In [5]:
train_paths = []
val_paths = []

for subdir, dirs, files in os.walk(DATA_DIR + "/train"):
    if len(files) == 0: continue
    for file in files:
        if file.endswith('.JPEG'):
            train_paths.append(subdir + "/" + file)

for subdir, dirs, files in os.walk(DATA_DIR + "/val"):
    if len(files) == 0: continue
    for file in files:
        if file.endswith('.JPEG'):
            val_paths.append(subdir + "/" + file)

train_dataset = ColorizationDataset(train_paths, split='train')
val_dataset   = ColorizationDataset(val_paths,   split='val')
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True, shuffle=True)
val_dataloader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True, shuffle=True)

print("Train:", len(train_dataset), "samples ·", len(train_dataloader), "batches")
print("Val:", len(val_dataset), "samples ·", len(val_dataloader), "batches")

Train: 6500 samples · 407 batches
Val: 250 samples · 16 batches




In [6]:
# Sanity check: looks good!
for batch in train_dataloader:
    print(batch['L'].shape, batch['ab'].shape, batch['class'])
    break

torch.Size([16, 1, 256, 256]) torch.Size([16, 2, 256, 256]) tensor([4, 4, 2, 0, 1, 4, 2, 0, 0, 3, 2, 4, 4, 4, 1, 2])


### 1.4- Discriminator

The architecture of our discriminator is rather straight forward. This code implements a model by stacking blocks of Conv-BatchNorm-LeackyReLU to decide whether the input image is fake or real. Notice that the first and last blocks do not use normalization and the last block has no activation function (it is embedded in the loss function we will use).

In [7]:
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2)
                          for i in range(n_down)] # the 'if' statement is taking care of not using
                                                  # stride of 2 for the last block in this loop
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)] # Make sure to not use normalization or
                                                                                             # activation for the last layer of the model
        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True): # when needing to make some repeatitive blocks of layers,
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]          # it's always helpful to make a separate method for that purpose
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

Let's take a look at its blocks:

In [8]:
PatchDiscriminator(3)

PatchDiscriminator(
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (

And its output shape:

In [9]:
discriminator = PatchDiscriminator(3)
dummy_input = torch.randn(16, 3, 256, 256) # batch_size, channels, size, size
out = discriminator(dummy_input)
out.shape

torch.Size([16, 1, 30, 30])

### 1.5- GAN Loss

This is a handy class we can use to calculate the GAN loss of our final model. In the __init__ we decide which kind of loss we're going to use (which will be "vanilla" in our project) and register some constant tensors as the "real" and "fake" labels. Then when we call this module, it makes an appropriate tensor full of zeros or ones (according to what we need at the stage) and computes the loss.

In [10]:
class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()

    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss

### 1.x Model Initialization

In the TowardsDataScince article, I didn't explain this function. Here is our logic to initialize our models. We are going to initialize the weights of our model with a mean of 0.0 and standard deviation of 0.02 which are the proposed hyperparameters in the article:

In [11]:
def init_weights(net, init='norm', gain=0.02):

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model

### 1.6- Putting everything together

This class brings together all the previous parts and implements a few methods to take care of training our complete model. Let's investigate it.

In the __init__ we define our generator and discriminator using the previous functions and classes we defined and we also initialize them with init_model function which I didn't explain here but you can refer to my GitHub repository to see how it works. Then we define our two loss functions and the optimizers of the generator and discriminator.

The whole work is being done in optimize method of this class. First and only once per iteration (batch of training set) we call the module's forward method and store the outputs in fake_color variable of the class.

Then, we first train the discriminator by using backward_D method in which we feed the fake images produced by generator to the discriminator (make sure to detach them from the generator's graph so that they act as a constant to the discriminator, like normal images) and label them as fake. Then we feed a batch of real images from training set to the discriminator and label them as real. We add up the two losses for fake and real and take the average and then call the backward on the final loss.
Now, we can train the generator. In backward_G method we feed the discriminator the fake image and try to fool it by assigning real labels to them and calculating the adversarial loss. As I mentioned earlier, we use L1 loss as well and compute the distance between the predicted two channels and the target two channels and multiply this loss by a coefficient (which is 100 in our case) to balance the two losses and then add this loss to the adversarial loss. Then we call the backward method of the loss.

In [None]:
class ClassifierHead(nn.Module):
  def __init__(self, input_size=512*8*8, output_size=5):
        super().__init__()

        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(input_size, 1000)
        self.fc2 = nn.Linear(1000, output_size)

  def forward(self, x):
      x = self.fc1(x)
      return self.fc2(self.relu(x))

class ColorizationHead(nn.Module):
  def __init__(self, input_c=1, output_c=2, n_down=5, num_filters=512):
        super().__init__()
        self.model_layers = []
        self.model_layers += (self.get_layer(num_filters, num_filters))
        for _ in range(n_down - 5):
            self.model_layers += (self.get_layer(num_filters, num_filters, dropout=True))
        out_filters = num_filters
        for _ in range(3):
            self.model_layers += (self.get_layer(out_filters // 2, out_filters))
            out_filters //= 2
        self.model_layers += (self.get_layer(output_c, out_filters, outermost=True))
        self.model = nn.Sequential(*self.model_layers)

  def forward(self, x):
      return self.model(x)

  def get_layer(self, nf, ni, bias=False, dropout=False, outermost=False):
        layer = []
        conv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                              stride=2, padding=1, bias=bias)
        relu = nn.ReLU(False)

        layer.append(relu)
        layer.append(conv)


        if outermost:
          layer.append(nn.Tanh())
        else:
          layer.append(nn.BatchNorm2d(nf))

        if dropout:
          layer.append(nn.Dropout(0.5))

        return layer

class MLTModel(nn.Module):
  def __init__(self, head="multitask"):
          super().__init__()
          self.resnet_body = create_body(resnet18, pretrained=True, n_in=1, cut=-2)
          for param in self.resnet_body.parameters():
              param.requires_grad = False
          self.heads = []
          self.head = head
          if head=="multitask":
            self.classifier_head = ClassifierHead()
            self.colorization_head = ColorizationHead()
          elif head=="classifier":
            self.classifier_head = ClassifierHead()
          elif self.head=="colorization":
            self.colorization_head = ColorizationHead()
          else:
              raise Exception('Invalid head')
  def forward(self, x):
      if self.head=="multitask":
          x = self.resnet_body(x)
          out_classifier = self.classifier_head(x.view(-1, 512*8*8))
          out_colorization_head = self.colorization_head(x)
          return out_classifier, out_colorization_head
      elif self.head=="classifier":
          x = self.resnet_body(x)
          out_classifier = self.classifier_head(x.view(-1, 512*8*8))
          return out_classifier
      elif self.head=="colorization":
          x = self.resnet_body(x)
          out_colorization_head = self.colorization_head(x)
          return out_colorization_head

In [13]:
class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(MLTModel(), self.device)
            print(self.net_G.parameters())
        else:
            self.net_G = net_G.to(self.device)
        self.head = self.net_G.head
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.class_criterion = nn.CrossEntropyLoss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)
        self.labels = data['class'].to(self.device)
            
            
    def forward(self):
        g_out = self.net_G(self.L)
        if self.head=="multitask":
            self.fake_color = g_out[1]
            self.class_pred = g_out[0]
        elif self.head=="classifier":
            self.class_pred = g_out
        elif self.head=="colorization":
            self.fake_color = g_out
        # print(f"Class preds: {torch.argmax(self.class_pred, dim=1)}")

    def backward_D(self):
        # print(f"L shape: {self.L.shape}")
        # print(f"self.fake_color shape: {self.fake_color.shape}")
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        losses_list = []
        if self.head in ["multitask",  "colorization"]:
            fake_image = torch.cat([self.L, self.fake_color], dim=1)
            fake_preds = self.net_D(fake_image)
            self.loss_G_GAN = self.GANcriterion(fake_preds, True)
            self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
            losses_list.append(self.loss_G_GAN)
            losses_list.append(self.loss_G_L1)

        if self.head in ["multitask",  "classifier"]:
            self.class_loss = self.class_criterion(self.class_pred, self.labels)
            losses_list.append(self.class_loss)
            max_pred = self.class_pred.argmax(dim=1)
            self.class_accuracy = (max_pred == self.labels).sum()
            # print(max_pred)
            # print(self.labels)
        
        # self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.class_loss
        self.loss_G = sum(losses_list)
        self.loss_G.backward()

    def optimize(self):
        self.forward()
        if self.head!="classifier":
            self.net_D.train()
            self.set_requires_grad(self.net_D, True)
            self.opt_D.zero_grad()
            self.backward_D()
            self.opt_D.step()

        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

### 1.xx Utility functions

These functions were nor included in the explanations of the TDS article. These are just some utility functions to log the losses of our network and also visualize the results during training. So here you can check them out:

In [20]:
class AverageMeter:
    def __init__(self):
        self.reset()
        self.history_losses = []

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1, acc=False):
        if acc:
            self.count += count
            self.sum += val
            self.avg = self.sum / self.count
        else:
            self.count += count
            self.sum += count * val
            self.avg = self.sum / self.count
        self.history_losses.append(val)

def clear_losses(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss_meter.reset()
        
def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    class_loss = AverageMeter()
    class_accuracy = AverageMeter()

    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G,
            'class_loss': class_loss,
            'class_accuracy': class_accuracy}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        try:
            if loss_name == 'class_accuracy':
                acc = True
            else:
                acc=False
            loss = getattr(model, loss_name)
            loss_meter.update(loss.item(), count=count, acc=acc)
        except AttributeError:
            pass

def lab_to_rgb(L, ab):
    """
    Takes a batch of images
    """

    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)

def visualize(model, data, save=True):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 8))
    for i in range(5):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")

def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        if loss_name == 'class_accuracy':
            print(f"{loss_name}: {loss_meter.avg:.5f}%")
        else:
            print(f"{loss_name}: {loss_meter.avg:.5f}")

I hope this code is self-explanatory. Every epoch takes about 4 minutes on not a powerful GPU as Nvidia P5000. So if you are using 1080Ti or higher, it will be much faster.

In [21]:
def train_model(model, train_dl, epochs, display_every=200):
    data = next(iter(val_dataloader)) # getting a batch for visualizing the model output after fixed intrvals
    loss_meter_dict = create_loss_meters() # function returing a dictionary of objects to
    for e in range(epochs):
        i = 0                                  # log the losses of the complete network
        clear_losses(loss_meter_dict)
        for data in tqdm(train_dl):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0)) # function updating the log objects
            i += 1
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict) # function to print out the losses
                if model.head!="classifier":
                    visualize(model, data, save=False) # function displaying the model's outputs
    return model, loss_meter_dict

model = MainModel(net_G=MLTModel(head="classifier"))

model, loss_meter_dict = train_model(model, train_dataloader, 1, display_every=20)

# Saving
torch.save(net_G.state_dict(), "res18-unet.pt")
torch.save(loss_meter_dict, "./loss_meter_dict.ckpt")


model initialized with norm initialization


  0%|          | 0/407 [00:00<?, ?it/s]


Epoch 1/1
Iteration 20/407
loss_D_fake: 0.00000
loss_D_real: 0.00000
loss_D: 0.00000
loss_G_GAN: 0.00000
loss_G_L1: 0.00000
loss_G: 2.86961
class_loss: 2.86961
class_accuracy: 0.57188%

Epoch 1/1
Iteration 40/407
loss_D_fake: 0.00000
loss_D_real: 0.00000
loss_D: 0.00000
loss_G_GAN: 0.00000
loss_G_L1: 0.00000
loss_G: 1.78711
class_loss: 1.78711
class_accuracy: 0.67031%

Epoch 1/1
Iteration 60/407
loss_D_fake: 0.00000
loss_D_real: 0.00000
loss_D: 0.00000
loss_G_GAN: 0.00000
loss_G_L1: 0.00000
loss_G: 1.39069
class_loss: 1.39069
class_accuracy: 0.71667%

Epoch 1/1
Iteration 80/407
loss_D_fake: 0.00000
loss_D_real: 0.00000
loss_D: 0.00000
loss_G_GAN: 0.00000
loss_G_L1: 0.00000
loss_G: 1.16857
class_loss: 1.16857
class_accuracy: 0.74844%

Epoch 1/1
Iteration 100/407
loss_D_fake: 0.00000
loss_D_real: 0.00000
loss_D: 0.00000
loss_G_GAN: 0.00000
loss_G_L1: 0.00000
loss_G: 1.01923
class_loss: 1.01923
class_accuracy: 0.76750%

Epoch 1/1
Iteration 120/407
loss_D_fake: 0.00000
loss_D_real: 0.0000

NameError: name 'net_G' is not defined

Every epoch takes about 3 to 4 minutes on Colab. After about 20 epochs you should see some reasonable results.