In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from skimage.color import rgb2lab, rgb2gray, lab2rgb
from skimage import io,transform

import os

import numpy as np

use_gpu = torch.cuda.is_available()

In [2]:
class PeopleDataSet(torch.utils.data.Dataset):
    def __init__(self,root_directory,transform=None):
        self.file_names = os.listdir(root_directory)
        self.root_dir = root_directory
        self.transform = transform
        
    def __getitem__(self,index):
        image = io.imread(self.root_dir + self.file_names[index])
            
        image_lab = self.get_lab(image)
        image_ab = self.get_ab(image_lab)
        image_l = self.get_l(image_lab)
        
        return image_l,image_ab
        
    def __len__(self):
        return len(self.file_names)
        
    def get_ab(self,lab_image):
        img_ab = lab_image[:, :, 1:3]
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
        return img_ab
        
    def get_l(self,image):
        img_l = rgb2gray(image)
        img_l = torch.from_numpy(img_l).unsqueeze(0).float()
        return img_l
        
    def get_lab(self,rgb_image):
        rgb_image = np.asarray(rgb_image)
        img_lab = rgb2lab(rgb_image)
        img_lab = (img_lab + 128) / 255
        
        img_lab_resized = transform.resize(img_lab,(512,512))
        return img_lab_resized
        

In [3]:
train_data = PeopleDataSet(root_directory='/home/andrew/Pictures/Datasets/INRIAPerson/Train/pos/')

train_loader = torch.utils.data.DataLoader(train_data,shuffle=True,batch_size=20)

In [4]:
class ColorizationNet(nn.Module):
    def __init__(self, midlevel_input_size=128, global_input_size=512):
        super(ColorizationNet, self).__init__()
        # Fusion layer to combine midlevel and global features
        self.midlevel_input_size = midlevel_input_size
        self.global_input_size = global_input_size
        self.fusion = nn.Linear(midlevel_input_size + global_input_size, midlevel_input_size)
        self.bn1 = nn.BatchNorm1d(midlevel_input_size)

        # Convolutional layers and upsampling
        self.deconv1_new = nn.ConvTranspose2d(midlevel_input_size, 128, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(midlevel_input_size, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
#         self.upsample = nn.Upsample(scale_factor=2)

        print('Loaded colorization net.')

    def forward(self, midlevel_input):
        
        x = self.conv1(midlevel_input)
        x = self.bn2(x)
        x = torch.relu(x)
        
        x = F.interpolate(x,scale_factor=2)
        
        x = self.conv2(x)
        x = self.bn3(x)
        x = torch.relu(x)
        
        x = self.conv3(x)
        x = torch.relu(x)
        
        x = F.interpolate(x,scale_factor=2)
        
        x = self.conv4(x)
        x = torch.sigmoid(x)
        
        x = self.conv5(x)
        x = F.interpolate(x,scale_factor=2)
        return x


class ColorNet(nn.Module):
    def __init__(self):
        super(ColorNet, self).__init__()
        
        # Build ResNet and change first conv layer to accept single-channel input
        resnet_gray_model = models.resnet18(num_classes=365)
        resnet_gray_model.conv1.weight = nn.Parameter(resnet_gray_model.conv1.weight.sum(dim=1).unsqueeze(1).data)

        # Extract midlevel and global features from ResNet-gray
        self.midlevel_resnet = nn.Sequential(*list(resnet_gray_model.children())[0:6])
        self.fusion_and_colorization_net = ColorizationNet()

    def forward(self, input_image):

        # Pass input through ResNet-gray to extract features
        midlevel_output = self.midlevel_resnet(input_image)

        # Combine features in fusion layer and upsample
        output = self.fusion_and_colorization_net(midlevel_output) #, global_output)
        return output

In [5]:
net = ColorNet().to('cuda') if use_gpu else ColorNet() 

Loaded colorization net.


In [6]:
optimizer = optim.Adam(net.parameters(),lr=0.01)
criterion = nn.MSELoss()

In [7]:
for epoch in range(5):
    total_loss = 0
    for image_l,image_ab in train_loader:
        if use_gpu:
            image_l,image_ab = image_l.to('cuda'),image_ab.to('cuda')

        optimizer.zero_grad()
        output = net(image_l)
        
        loss = criterion(output, image_ab)
        total_loss += loss
        loss.backward()
        optimizer.step()
    print('Epoch',epoch,'Loss',total_loss.item())

Epoch 0 Loss 4.646118640899658
Epoch 1 Loss 0.1637272983789444
Epoch 2 Loss 0.07120449841022491
Epoch 3 Loss 0.06606699526309967
Epoch 4 Loss 0.06375345587730408
