In [1]:
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
from skimage.io import imsave
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
## Load the source photo with PIL and keep it as an RGB NumPy array
image_path = 'woman.jpg'

# Open inside a context manager so the underlying file handle is released
# deterministically; convert('RGB') returns an independent, fully loaded copy,
# so image_pil stays usable after the file is closed.
with Image.open(image_path) as source_file:
    image_pil = source_file.convert('RGB')

# Array of RGB pixel values with shape (height, width, 3).
image = np.array(image_pil)

In [3]:
# Convert the RGB image (scaled to [0, 1]) to CIE Lab once and reuse the result
# for both the network input and the target.
lab_image = rgb2lab(1.0/255*image)

# X: lightness channel, shape (1, H, W, 1).
# Y: a/b colour channels divided by 128, shape (1, H, W, 2).
# The leading batch axis and the spatial size are taken from the image itself
# rather than being hard-coded to 400x400. Note that ColorizationNet applies
# three stride-2 convolutions followed by three 2x upsamplings, so H and W
# should be multiples of 8 for the prediction to match the target shape.
X = lab_image[np.newaxis, :, :, :1]
Y = lab_image[np.newaxis, :, :, 1:] / 128

In [4]:
class ColorizationNet(nn.Module):
    """Small fully convolutional network mapping 1 input channel to 2 output channels.

    The first six 3x3 convolutions include three with stride 2 (conv1, conv4,
    conv6); the remaining three convolutions are each preceded by a 2x bilinear
    upsampling. The final activation is tanh, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_ch, out_ch, stride=1):
            # Every convolution in this network is 3x3 with padding 1.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)

        def upsample2x():
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        # Layers are registered in the same order as they are applied in forward().
        self.conv1 = conv3x3(1, 8, stride=2)
        self.conv2 = conv3x3(8, 8)
        self.conv3 = conv3x3(8, 16)
        self.conv4 = conv3x3(16, 16, stride=2)
        self.conv5 = conv3x3(16, 32)
        self.conv6 = conv3x3(32, 32, stride=2)
        self.upsample1 = upsample2x()
        self.conv7 = conv3x3(32, 32)
        self.upsample2 = upsample2x()
        self.conv8 = conv3x3(32, 16)
        self.upsample3 = upsample2x()
        self.conv9 = conv3x3(16, 2)

    def forward(self, x):
        """Run the network on a (N, 1, H, W) tensor and return a (N, 2, H', W') tensor."""
        # conv1..conv6, each followed by ReLU.
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            x = F.relu(conv(x))

        # Two upsample -> conv -> ReLU stages.
        for upsample, conv in ((self.upsample1, self.conv7),
                               (self.upsample2, self.conv8)):
            x = F.relu(conv(upsample(x)))

        # Final upsample and convolution, squashed with tanh.
        return torch.tanh(self.conv9(self.upsample3(x)))




In [5]:
# Build the network together with an MSE loss and an RMSprop optimiser
# (library-default hyperparameters) over all of its parameters.
model = ColorizationNet()
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters())

# Show the layer stack.
print(model)

ColorizationNet(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv5): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (upsample1): Upsample(scale_factor=2.0, mode='bilinear')
  (conv7): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample2): Upsample(scale_factor=2.0, mode='bilinear')
  (conv8): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample3): Upsample(scale_factor=2.0, mode='bilinear')
  (conv9): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [13]:
# X and Y are the NumPy arrays built in the preprocessing cell, stored
# channels-last as (N, H, W, C). Conv2d expects channels-first (N, C, H, W),
# hence the permute.
X_tensor = torch.tensor(X, dtype=torch.float32).permute(0, 3, 1, 2)
Y_tensor = torch.tensor(Y, dtype=torch.float32).permute(0, 3, 1, 2)

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_tensor = X_tensor.to(device)
Y_tensor = Y_tensor.to(device)

# Create a DataLoader for batch training
dataset = TensorDataset(X_tensor, Y_tensor)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Instantiate a fresh model and move it to the same device as the data;
# without this the forward pass fails with a device mismatch on CUDA machines.
# The optimizer is created afterwards so it tracks the moved parameters.
model = ColorizationNet().to(device)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9)

# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    total_loss = 0.0
    # The dataset tensors already live on `device`, so batches need no transfer.
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(dataloader)
    # Report every 100th epoch (1-based in the message).
    if epoch % 100 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}")


Epoch [1/1000], Loss: 0.0156
Epoch [101/1000], Loss: 0.0024
Epoch [201/1000], Loss: 0.0018
Epoch [301/1000], Loss: 0.0014
Epoch [401/1000], Loss: 0.0008
Epoch [501/1000], Loss: 0.0006
Epoch [601/1000], Loss: 0.0009
Epoch [701/1000], Loss: 0.0007
Epoch [801/1000], Loss: 0.0004
Epoch [901/1000], Loss: 0.0004


In [14]:
# Predict the two colour channels for the input image without tracking gradients.
model.eval()
with torch.no_grad():
    output_tensor = model(X_tensor)
output_numpy = output_tensor.cpu().numpy()  # shape (N, 2, H, W)

# Reassemble a Lab image: channel 0 is the original lightness from X,
# channels 1-2 are the predictions scaled back by 128 (undoing the division
# applied to Y) and moved to channels-last layout.
cur = np.zeros((output_numpy.shape[2], output_numpy.shape[3], 3))
cur[:,:,0] = X[0][:,:,0]
cur[:,:,1:] = (output_numpy[0] * 128).transpose(1, 2, 0)

# Convert Lab -> RGB once and reuse it for both saved images.
colorized_rgb = lab2rgb(cur)
output_rgb = (colorized_rgb * 255).astype(np.uint8)

# Note: this grey image is derived from the colourised result (rgb2gray of the
# predicted RGB image), not read back from the original input file.
input_rgb = (rgb2gray(colorized_rgb) * 255).astype(np.uint8)

imsave("img_result.png", output_rgb)
imsave("img_gray_version.png", input_rgb)