In [1]:
# Numpy
import numpy as np

# MatLab Plot
import matplotlib.pyplot as plt

# Torch
import torch
import torch.cuda
import torch.utils.data

import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn

from torchsummary import summary

# Color
from skimage import color

# Colorizer Modules
import lab_dataloader
import architecture


In [2]:

# Data configuration shared by the training and validation loaders.
DATA_DIR = "PetImages/"
IMAGE_SIZE = 256  # images are resized (shorter side) then center-cropped to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32

# One preprocessing pipeline reused by both loaders (was duplicated verbatim).
image_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Dataloader for the training set
train_transforms = image_transforms
train_dataset = lab_dataloader.LABImageFolder(DATA_DIR, transform=train_transforms)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Dataloader for the validation set.
# TODO: point this at a held-out split. It currently reuses the training images,
# so any "validation" metric will not measure generalization.
validate_transforms = image_transforms
validate_dataset = lab_dataloader.LABImageFolder(DATA_DIR, transform=validate_transforms)
# Evaluation results do not depend on sample order, so keep validation deterministic.
validate_dataloader = torch.utils.data.DataLoader(validate_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [3]:

color_model = architecture.ColorizerModel()
# Input shape is (channels, height, width): a single-channel 256x256 image
# (presumably the L channel at the dataloaders' crop size; confirm against architecture.py).
# `summary` prints the table itself; its return value is also rendered by Jupyter
# as the cell's last expression, which duplicated the table. The trailing `;`
# suppresses that second copy.
summary(color_model, (1, 256, 256));

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 128, 128]        --
|    └─Conv2d: 2-1                       [-1, 64, 256, 256]        640
|    └─ReLU: 2-2                         [-1, 64, 256, 256]        --
|    └─Conv2d: 2-3                       [-1, 64, 128, 128]        36,928
|    └─ReLU: 2-4                         [-1, 64, 128, 128]        --
|    └─BatchNorm2d: 2-5                  [-1, 64, 128, 128]        128
├─Sequential: 1-2                        [-1, 128, 64, 64]         --
|    └─Conv2d: 2-6                       [-1, 128, 128, 128]       73,856
|    └─ReLU: 2-7                         [-1, 128, 128, 128]       --
|    └─Conv2d: 2-8                       [-1, 128, 64, 64]         147,584
|    └─ReLU: 2-9                         [-1, 128, 64, 64]         --
|    └─BatchNorm2d: 2-10                 [-1, 128, 64, 64]         256
├─Sequential: 1-3                        [-1, 256, 32, 32]         --

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 128, 128]        --
|    └─Conv2d: 2-1                       [-1, 64, 256, 256]        640
|    └─ReLU: 2-2                         [-1, 64, 256, 256]        --
|    └─Conv2d: 2-3                       [-1, 64, 128, 128]        36,928
|    └─ReLU: 2-4                         [-1, 64, 128, 128]        --
|    └─BatchNorm2d: 2-5                  [-1, 64, 128, 128]        128
├─Sequential: 1-2                        [-1, 128, 64, 64]         --
|    └─Conv2d: 2-6                       [-1, 128, 128, 128]       73,856
|    └─ReLU: 2-7                         [-1, 128, 128, 128]       --
|    └─Conv2d: 2-8                       [-1, 128, 64, 64]         147,584
|    └─ReLU: 2-9                         [-1, 128, 64, 64]         --
|    └─BatchNorm2d: 2-10                 [-1, 128, 64, 64]         256
├─Sequential: 1-3                        [-1, 256, 32, 32]         --

In [None]:
# Preview a few training batches: for each, recombine the first image's L and ab
# channels and display it after converting Lab -> RGB.
# Bounded on purpose: iterating the whole loader would open one figure per batch.
N_PREVIEW_BATCHES = 4

for batch_idx, (images_l, images_ab, _label) in enumerate(train_dataloader):
    if batch_idx >= N_PREVIEW_BATCHES:
        break
    # NOTE(review): concatenating on dim 2 assumes each sample is channel-last
    # (H, W, C), which is the layout `color.lab2rgb` expects. If LABImageFolder
    # yields (C, H, W) tensors (the model summary's (1, 256, 256) input hints it may),
    # this should be `torch.cat(..., 0).permute(1, 2, 0)` instead. Confirm.
    lab_image = torch.cat((images_l[0], images_ab[0]), 2)
    fig, ax = plt.subplots()
    ax.imshow(color.lab2rgb(lab_image.numpy()))
    ax.set_title(f"Training batch {batch_idx}: first image")
    ax.axis("off")
    plt.show()
    plt.close(fig)