In [13]:
import numpy as np
import torch
from skimage.color import rgb2lab, rgb2gray
from torchvision import datasets
import glob
import os
from shutil import copy2
import buildDataset
from Colorize_deep import Colorize_deep
from Constants import Constants
from utils import Utils

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset

def AugmentImageDataset(ImageFolder,actv_function):
    global img_a, img_b, img_gray
    path, target = imgs[index]
    img = loader(path)
    if transform is not None:
        img = transform(img)

    original_image = np.asarray(img)

    img_lab = rgb2lab(original_image)
    if actv_function=='tanh' or actv_function=='sigmoid':            
        img_lab = img_lab + 128
        img_lab = img_lab / 255
    img_a = img_lab[:, :, 1:2]
    img_a = torch.from_numpy(img_a.transpose((2, 0, 1))).float()  # To match the channel dimensions
    img_b = img_lab[:, :, 2:3]
    img_b = torch.from_numpy(img_b.transpose((2, 0, 1))).float()

    img_gray = rgb2gray(original_image)
    img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()

    return img_gray, img_a, img_b


def load_data():
    image_list = glob.glob('face_images/*.jpg')
    print("Length of given Imge List")
    print(len(image_list))

    #Utils().train_test_split()
    
    os.makedirs('data/train_image/class/', exist_ok=True)
    os.makedirs('data/test_image/class/', exist_ok=True)
    number_of_images = len(next(os.walk('face_images'))[2])
    print("Number of images - ", number_of_images)
    print(len(next(os.walk('face_images'))[2]))
    for i, file in enumerate(os.listdir('face_images')):
        if i < (0.1 * number_of_images):  # first 10% will be val
            copy2('face_images/' + file, 'data/test_image/class/' + file)
            continue
        else:  # others will be train
            copy2('face_images/' + file, 'data/train_image/class/' + file)

        print("Training Set Length : ", len(next(os.walk('data/train_image/class/'))[2]))

        print("Test Set Length : ", len(next(os.walk('data/test_image/class/'))[2]))
    training_image_list = glob.glob('data/train_image/class/*.jpg')
    test_image_list = glob.glob('data/test_image/class/*.jpg')
    print("Length of training Image List", len(training_image_list))
    print("Length of testing Image List", len(test_image_list))

def build_dataset(cuda=False, num_workers=1,
                  activation_function='sigmoid'):
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(128)
    ])

    train_datasets = []
    train_datasets.append(buildDataset.AugmentImageDataset('data/train_image'))

    for i in range(9):
        train_datasets.append(buildDataset.AugmentImageDataset('data/train_image',transform))


    augmented_dataset = ConcatDataset(train_datasets)
    print("Length of Augmented Dataset", len(augmented_dataset))

    train_loader_args = dict(shuffle=True,
                             batch_size=16,
                             num_workers=num_workers, pin_memory=True) \
        if cuda else dict(shuffle=True, batch_size=32)

    augmented_dataset_batch_train = DataLoader(dataset=augmented_dataset, **train_loader_args)
    augmented_dataset_batch_test = DataLoader(dataset=buildDataset.AugmentImageDataset('data/test_image'))

    return augmented_dataset_batch_train, augmented_dataset_batch_test


def execute_colorizer_tanh():
    activation_function = "tanh"
    save_path = {'grayscale': 'outputs_tanh/gray/', 'colorized': 'outputs_tanh/color/'}
    device, is_cuda_present, num_workers = Utils.get_device()
    model_name = Constants.COLORIZER_SAVED_MODEL_PATH_TANH

    print("Device: {0}".format(device))
    augmented_dataset_batch_train, \
    augmented_dataset_batch_test = build_dataset(is_cuda_present, num_workers,
                                                 activation_function)

    colorizer_deep = Colorize_deep()
    colorizer_deep.train_colorizer(augmented_dataset_batch_train,
                                    activation_function, model_name, device)

    colorizer_deep.test_colorizer(augmented_dataset_batch_test, activation_function,
                                  save_path, model_name, device)

    colorizer_deep.train_regressor(augmented_dataset_batch_train, device)
    colorizer_deep.test_regressor(augmented_dataset_batch_test, device)


def execute_colorizer_sigmoid():
    activation_function = Constants.SIGMOID
    save_path = {'grayscale': 'outputs_sigmoid/gray/', 'colorized': 'outputs_sigmoid/color/'}
    device, is_cuda_present, num_workers = Utils.get_device()
    model_name = Constants.COLORIZER_SAVED_MODEL_PATH_SIGMOID

    print("Device: {0}".format(device))
    augmented_dataset_batch_train, \
    augmented_dataset_batch_test = build_dataset(is_cuda_present, num_workers,
                                                 activation_function)

    colorizer_deep = Colorize_deep()
    colorizer_deep.train_colorizer(augmented_dataset_batch_train,
                                   activation_function, model_name, device)

    colorizer_deep.test_colorizer(augmented_dataset_batch_test, activation_function,
                                  save_path, model_name, device)

    colorizer_deep.train_regressor(augmented_dataset_batch_train, device)
    colorizer_deep.test_regressor(augmented_dataset_batch_test, device)


if __name__ == '__main__':
    load_data()

    print("Normal Credit - Sigmoid")
    execute_colorizer_sigmoid()

    #print("Extra Credit - Tanh")
    #execute_colorizer_tanh()

Length of given Imge List
750
Number of images -  751
751
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Set Length :  675
Test Set Length :  76
Training Se

KeyboardInterrupt: 