# Image Colorization using Autoencoder
A deep-learning model written in PyTorch that converts grayscale images into RGB/LAB color images.

## Setup
Import Libraries
Convert images to grayscale and sort them for training

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os
from PIL import Image
import numpy as np
from keras import layers

import wandb
import datetime
import time
import copy

from tqdm import tqdm


In [None]:
# Define Input and Output folder for color transformation
# input_folder: color source images; the conversion step reads from here and
# the pairing step later uses the same folder as the color side of each pair.
input_folder = './color_train/'
# output_folder: grayscale copies of the source images are written here.
output_folder = './gray_train/'
# original_bw_folder: source images that are already stored in mode 'L' are
# saved here (and removed from input_folder) by the conversion step.
original_bw_folder = './original_bw_images'
# Upper bound on the number of images the conversion step handles per run.
num_img_to_process = 100

## Load Functions

In [None]:
# Function to load and preprocess images
def load_image(image_path, color_mode='rgb', target_size=(160, 160)):
    """Load an image file, resize it and return it as a tensor in [0, 1].

    Args:
        image_path: Path of the image file to open.
        color_mode: 'rgb' converts the image to RGB; 'grayscale' converts it
            to single-channel 'L' and then repeats that channel three times
            so both modes yield three channels. Any other value leaves the
            stored mode of the file untouched.
        target_size: Size forwarded to ``transforms.Resize``.

    Returns:
        The tensor produced by ``transforms.ToTensor`` (values scaled to
        [0, 1]); for 'grayscale' it is a 3-channel expanded view of the
        single-channel tensor.
    """
    # The context manager guarantees the file handle is released even when
    # decoding or resizing raises, and for formats PIL keeps open after load.
    with Image.open(image_path) as source:
        if color_mode == 'rgb':
            picture = source.convert('RGB')
        elif color_mode == 'grayscale':
            picture = source.convert('L')
        else:
            picture = source

        # Resize and tensor conversion happen while the file is still open so
        # that lazily-decoded pixel data is available in every branch.
        picture = transforms.Resize(target_size)(picture)
        tensor = transforms.ToTensor()(picture)  # Converts to [0, 1] range

    if color_mode == 'rgb':
        # mean 0 / std 1 leaves the values unchanged; kept as the single place
        # to plug in real channel statistics later.
        tensor = transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])(tensor)
    elif color_mode == 'grayscale':
        # Repeat the single channel so the result has three channels.
        # expand() returns a view and does not copy the data.
        tensor = tensor.expand(3, -1, -1)

    return tensor

## Defining Classes for sorting the input data

In [None]:
# Defining class of Image DataSet:
class ImageDataset(Dataset):
    """Dataset that yields one preprocessed image tensor per file path.

    Each item is produced by ``load_image`` with the dataset's ``color_mode``
    and ``target_size``. Before loading, the stored PIL mode of the file is
    compared with the requested colour mode and a warning is printed when
    they disagree (the image is still loaded and returned).
    """

    def __init__(self, image_paths, color_mode='rgb', target_size=(160, 160)):
        """Store the file paths and the loading options used for every item."""
        self.target_size = target_size
        self.color_mode = color_mode
        self.image_paths = image_paths

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the tensor for the image at position ``idx``."""
        image_path = self.image_paths[idx]

        # Peek at the stored mode only; the handle is released right away.
        with Image.open(image_path) as probe:
            img_mode = probe.mode

        # Decide whether the file's mode contradicts the requested colour mode.
        if self.color_mode == 'grayscale':
            warning_flag = img_mode != 'L'
            if warning_flag:
                print(f"Warning: {image_path} is expected to be grayscale (mode 'L'), but it is {img_mode}.")
        elif self.color_mode == 'rgb':
            warning_flag = img_mode != 'RGB'
            if warning_flag:
                print(f"Warning: {image_path} is expected to be RGB (mode 'RGB'), but it is {img_mode}.")
        else:
            warning_flag = False

        # The actual load converts to the requested mode regardless of the check.
        tensor = load_image(
            image_path,
            color_mode=self.color_mode,
            target_size=self.target_size,
        )

        if warning_flag:
            print(f"Warning - Images located incorrectly: {image_path}")

        return tensor

class PairedDataset(Dataset):
    """Zip two equally long datasets into (grayscale, color) item pairs.

    Item ``idx`` is the tuple ``(grayscale_dataset[idx], color_dataset[idx])``,
    so the two inputs must already be aligned index by index.
    """

    def __init__(self, grayscale_dataset, color_dataset):
        """Keep references to both datasets after checking their lengths match."""
        lengths_match = len(grayscale_dataset) == len(color_dataset)
        assert lengths_match, "Datasets must have the same length."
        self.grayscale_dataset, self.color_dataset = grayscale_dataset, color_dataset

    def __len__(self):
        """Return the number of pairs (the length of the grayscale dataset)."""
        return len(self.grayscale_dataset)

    def __getitem__(self, idx):
        """Return the ``(grayscale, color)`` pair stored at ``idx``."""
        return self.grayscale_dataset[idx], self.color_dataset[idx]

## Preprocessing
### Loading images
Convert the RGB images in the input folder to grayscale and save them in the output folder. 
Additionally, images that are already grayscale in the source data are filtered out and moved into the "original_bw_images" folder so they can be used for evaluation later on (no color ground truth exists for these images).

In [None]:

# Create the output folders if they do not exist yet
os.makedirs(output_folder, exist_ok=True)
os.makedirs(original_bw_folder, exist_ok=True)

# os.listdir returns entries in arbitrary order; sorting makes the selected
# subset the same on every run.
input_all = sorted(os.listdir(input_folder))

# Filter by extension *before* truncating, so non-image entries do not use up
# part of the num_img_to_process budget. num_images is computed before it is
# used (the previous order raised NameError on a fresh kernel).
candidate_files = [f for f in input_all if f.endswith(('jpg', 'jpeg', 'png'))]
num_images = min(num_img_to_process, len(candidate_files))
image_files = candidate_files[:num_images]


print(f"Collected {len(image_files)} images in folder {input_folder}.")

# Iterate over every selected image, convert it to grayscale and save it
for image_file in image_files:
    input_path = os.path.join(input_folder, image_file)
    output_path = os.path.join(output_folder, image_file)
    originals_path = os.path.join(original_bw_folder, image_file)

    # Open the image and either set it aside (already grayscale) or convert it
    with Image.open(input_path) as img:
        # Some source images (e.g. in the COCO dataset) are stored as grayscale
        is_original_grayscale = img.mode == 'L'
        if is_original_grayscale:
            img.save(originals_path)
        else:
            img.convert("L").save(output_path)  # Convert to grayscale

    if is_original_grayscale:
        # Delete the source only after the file handle is closed; removing a
        # file that is still open fails on Windows.
        os.remove(input_path)
        print(input_path)

print(f"Images converted into grayscale and saved in {output_folder}.")


Collected 100 images in folder ./color_train/.
Images converted into grayscale and saved in ./gray_train/.


## Sorting and pairing dataset

Sorting dataset to train and test on

In [None]:
# Define paths to rgb and gray scale images
color_images_path = input_folder
grayscale_images_path = output_folder

# Get sorted lists of image file paths
color_image_files = sorted(
    os.path.join(color_images_path, f)
    for f in os.listdir(color_images_path)
    if f.endswith(('jpg', 'png', 'jpeg'))
)
grayscale_image_files = sorted(
    os.path.join(grayscale_images_path, f)
    for f in os.listdir(grayscale_images_path)
    if f.endswith(('jpg', 'png', 'jpeg'))
)

# Index both folders by bare filename so pairs are matched by name, not by position
color_filenames = {os.path.basename(f): f for f in color_image_files}
grayscale_filenames = {os.path.basename(f): f for f in grayscale_image_files}

# Match pairs based on filenames (grayscale files without a color twin are skipped)
paired_filenames = [
    (grayscale_filenames[f], color_filenames[f])
    for f in grayscale_filenames
    if f in color_filenames
]

# Check if pairs were created correctly
print(f"Number of pairs created: {len(paired_filenames)}")

# Fail with a readable message instead of the opaque unpacking error that
# zip(*[]) would produce when no pair exists.
if not paired_filenames:
    raise ValueError(
        f"No grayscale/color pairs found in {grayscale_images_path} and {color_images_path}."
    )

# Split the pairs into two aligned path lists
grayscale_paths = [gray_path for gray_path, _ in paired_filenames]
color_paths = [color_path for _, color_path in paired_filenames]

# Split the dataset into train and test sets (train = 85%, test = remaining ~15%).
# The test size is the remainder so that no sample is dropped by rounding.
train_size = int(len(grayscale_paths) * 0.85)
test_size = len(grayscale_paths) - train_size
print(f"train len: {train_size}, test len:{test_size}")

# The test split starts where the train split ends. Slicing from `test_size`
# instead made the test set overlap with most of the training data.
train_grayscale_paths = grayscale_paths[:train_size]
test_grayscale_paths = grayscale_paths[train_size:]

train_color_paths = color_paths[:train_size]
test_color_paths = color_paths[train_size:]

assert not set(train_grayscale_paths) & set(test_grayscale_paths), "Train and test splits overlap."
assert len(train_grayscale_paths) + len(test_grayscale_paths) == len(grayscale_paths)

# Creating datasets
## train dataset
train_grayscale_ds = ImageDataset(train_grayscale_paths, color_mode='grayscale')
train_color_ds = ImageDataset(train_color_paths, color_mode='rgb')

## test dataset
test_grayscale_ds = ImageDataset(test_grayscale_paths, color_mode='grayscale')
test_color_ds = ImageDataset(test_color_paths, color_mode='rgb')

# Ensure there are no mix-ups in dataset - rgb/grayscale separation.
# Reading every item triggers the mode warnings printed by ImageDataset.
for dataset in (train_grayscale_ds, train_color_ds, test_grayscale_ds, test_color_ds):
    for idx in range(len(dataset)):
        image = dataset[idx]

# DataLoader configurations
batch_size = 32
num_workers = 4

# Creating per-modality data loaders.
# NOTE(review): the two train loaders shuffle independently, so their batches
# are NOT aligned with each other; use train_loader below for paired training.
train_grayscale_loader = DataLoader(train_grayscale_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_color_loader = DataLoader(train_color_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_grayscale_loader = DataLoader(test_grayscale_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_color_loader = DataLoader(test_color_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Combine grayscale and color datasets into paired datasets
train_dataset = PairedDataset(train_grayscale_ds, train_color_ds)
test_dataset = PairedDataset(test_grayscale_ds, test_color_ds)

# Paired loaders: each batch holds matching (grayscale, color) tensors
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# # Train the model using the train dataset
# model.fit(train_dataset, epochs=50, validation_data=test_dataset)


Number of pairs created: 3552
train len: 3019, test len:532
