In [1]:
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
from skimage.io import imsave
from sklearn.metrics import mean_squared_error
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms

In [2]:
# Get images
X = []
for filename in os.listdir('Train/'):
    X.append(np.array(Image.open('Train/'+filename).convert('RGB')))
X = np.array(X, dtype=float)

split = int(0.95*len(X))
Xtrain = X[:split]
Xtrain = 1.0/255*X

In [3]:
class ColorizationNet_Beta(nn.Module):
    def __init__(self):
        super(ColorizationNet_Beta, self).__init__()
        #fix this
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv10 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv11 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(32, 2, kernel_size=3, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        x = F.relu(self.conv9(x))
        x = self.upsample1(x)
        x = F.relu(self.conv10(x))
        x = self.upsample2(x)
        x = F.relu(self.conv11(x))
        x = torch.tanh(self.conv12(x))
        x = self.upsample3(x)

        return x




In [4]:
lab_train = rgb2lab(Xtrain)
X_train = lab_train[:,:,:,0]
Y_train = lab_train[:,:,:,1:] / 128

In [5]:
# Assuming Xtrain is a NumPy array containing training data

# Convert data to PyTorch tensors
lab_train = rgb2lab(Xtrain)
X_train = lab_train[:,:,:,0]
Y_train = lab_train[:,:,:,1:] / 128
X_train = X_train.reshape(10, 256, 256, 1)
Y_train = Y_train.reshape(10, 256, 256, 2)
Xtrain_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order
Ytrain_tensor = torch.tensor(Y_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

# Define the PyTorch model, loss function, and optimizer
model = ColorizationNet_Beta()
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9)


# Image transformer (data augmentation)
transform = transforms.Compose([
    transforms.RandomAffine(degrees=20, shear=[-5, 5], scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
])

# Custom dataset class for PyTorch
class ColorizationDataset(torch.utils.data.Dataset):
    def __init__(self, X, Y, transform=None):
        self.X = X
        self.Y = Y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = {'image': self.X[idx], 'target': self.Y[idx]}

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

# Create DataLoader for batch training
batch_size=2
dataset = ColorizationDataset(Xtrain_tensor, Ytrain_tensor, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in dataloader:
        inputs, targets = batch['image'], batch['target']
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    average_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}")


Epoch [1/10], Loss: 0.6697
Epoch [2/10], Loss: 0.0084
Epoch [3/10], Loss: 0.0075
Epoch [4/10], Loss: 0.0061
Epoch [5/10], Loss: 0.0059
Epoch [6/10], Loss: 0.0058
Epoch [7/10], Loss: 0.0058
Epoch [8/10], Loss: 0.0059
Epoch [9/10], Loss: 0.0058
Epoch [10/10], Loss: 0.0056


In [6]:
# Convert data to PyTorch tensors
Xtest = X[split:]
Xtest = 1.0/255*X
lab_test = rgb2lab(Xtest)
X_test = lab_test[:,:,:,0]
Y_test = lab_test[:,:,:,1:] / 128
X_test = X_test.reshape(10, 256, 256, 1)
Y_test = Y_test.reshape(10, 256, 256, 2)
Xtest_tensor = torch.tensor(X_test, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order
Ytest_tensor = torch.tensor(Y_test, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

# Set the model to evaluation mode
model.eval()

# Make predictions on the test data
with torch.no_grad():
    predictions_tensor = model(Xtest_tensor)

# Flatten tensors
Ytest_flat = torch.reshape(Ytest_tensor, (-1,)).numpy()
predictions_flat = torch.reshape(predictions_tensor, (-1,)).numpy()

# Calculate mean squared error
mse = mean_squared_error(Ytest_flat, predictions_flat)
print(f"Mean Squared Error: {mse:.4f}")

Mean Squared Error: 0.0057


In [9]:
# Get images
color_me = []
for filename in os.listdir('Test/'):
    color_me.append(np.array(Image.open('Test/'+filename).convert('RGB')))
color_me = np.array(color_me, dtype=float)
color_me = 1.0/255*color_me
lab_cm = rgb2lab(color_me)
X_in = lab_cm[:,:,:,0]
X_in = X_in.reshape(8, 256, 256, 1)
Xin_tensor = torch.tensor(X_in, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

# Set the model to evaluation mode
model.eval()

# Make predictions on the test data
with torch.no_grad():
    Y_tensor = model(Xin_tensor)
    output_numpy = Y_tensor.cpu().numpy()

# Post-process and save images
for i in range(8):
    cur = np.zeros((output_numpy.shape[2], output_numpy.shape[3], 3))
    cur[:,:,0] = X_in[i][:,:,0]
    cur[:,:,1:] = (output_numpy[i] * 128).transpose(1, 2, 0)
    output_rgb = (lab2rgb(cur) * 255).astype(np.uint8)
    #input_rgb = (rgb2gray(lab2rgb(cur)) * 255).astype(np.uint8)
    imsave("result/img_"+str(i)+".png", output_rgb)
