In [28]:
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab, gray2rgb
from skimage.io import imsave
from sklearn.metrics import mean_squared_error
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms

In [59]:
# Get images
X = []
for filename in os.listdir('Train/'):
    X.append(np.array(Image.open('Train/'+filename).convert('RGB')))
X = np.array(X, dtype=float)
Xtrain = 1.0/255*X
print(Xtrain.shape)
model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
model.eval()

(10, 256, 256, 3)


Using cache found in C:\Users\samar/.cache\torch\hub\pytorch_vision_v0.10.0


Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [53]:
# Encoder
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        return x

# Fusion
class Fusion(nn.Module):
    def __init__(self):
        super(Fusion, self).__init__()

    def forward(self, x, embed):
        embed = embed.view(-1, 1000, 1, 1)
        embed = embed.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat((x, embed), dim=1)
        x = F.relu(nn.Conv2d(256+1000, 256, kernel_size=1, stride=1, padding=0)(x))
        return x

# Decoder
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1)
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.up1(x)
        x = F.relu(self.conv2(x))
        x = self.up2(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.tanh(self.conv5(x))
        x = self.up3(x)
        return x

# Model
class ColorizationModel_Full(nn.Module):
    def __init__(self):
        super(ColorizationModel_Full, self).__init__()
        self.encoder = Encoder()
        self.fusion = Fusion()
        self.decoder = Decoder()

    def forward(self, x, embed):
        x = self.encoder(x)
        x = self.fusion(x, embed)
        x = self.decoder(x)
        return x


In [39]:
def preprocess_batch(images):
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return preprocess(images)

def create_inception_embedding(grayscaled_rgb):
    grayscaled_rgb_resized = preprocess_batch(grayscaled_rgb_resized)
    input_tensor = torch.tensor(grayscaled_rgb_resized, dtype=torch.float32).permute(0, 3, 1, 2)
    with torch.no_grad():
        output = inception_resnet_v2(input_tensor)
    return output



In [61]:
# Assuming Xtrain is a NumPy array containing training data

# Convert data to PyTorch tensors
lab_train = rgb2lab(Xtrain)
X_train = lab_train[:,:,:,0]
Y_train = lab_train[:,:,:,1:] / 128
X_train = X_train.reshape(10, 256, 256, 1)
Y_train = Y_train.reshape(10, 256, 256, 2)
Xtrain_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order
Ytrain_tensor = torch.tensor(Y_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

# Define the PyTorch model, loss function, and optimizer
model = ColorizationModel_Full()
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9)


# Image transformer (data augmentation)
transform = transforms.Compose([
    transforms.RandomAffine(degrees=20, shear=[-5, 5], scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
])

# Custom dataset class for PyTorch
class ColorizationDataset(torch.utils.data.Dataset):
    def __init__(self, X, embed, Y, transform=None):
        self.X = X
        self.Y = Y
        self.embed = embed
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = {'image': self.X[idx], 'embed': self.embed[idx], 'target': self.Y[idx]}

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

# Create DataLoader for batch training
batch_size=2
grayscaled_rgb = gray2rgb(rgb2gray(Xtrain))
grayscaled_rgb = torch.tensor(grayscaled_rgb, dtype=torch.float32).permute(0, 3, 1, 2)
transformed_batch = torch.stack([preprocess_batch(image) for image in grayscaled_rgb])
with torch.no_grad():
    embed = inception_resnet_v2(transformed_batch)
dataset = ColorizationDataset(Xtrain_tensor, embed, Ytrain_tensor, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in dataloader:
        inputs, em, targets = batch['image'], batch['embed'], batch['target']
        optimizer.zero_grad()
        outputs = model(inputs, em)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    average_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}")


torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
Epoch [1/10], Loss: 0.3095
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
Epoch [2/10], Loss: 0.0077
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
Epoch [3/10], Loss: 0.0218
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
torch.Size([2, 1, 256, 256])
torch.Size([2, 1000])
t

In [64]:
# Get images
color_me = []
for filename in os.listdir('Test/'):
    color_me.append(np.array(Image.open('Test/'+filename).convert('RGB')))
color_me = np.array(color_me, dtype=float)
color_me = 1.0/255*color_me
lab_cm = rgb2lab(color_me)
X_in = lab_cm[:,:,:,0]
X_in = X_in.reshape(8, 256, 256, 1)
Xin_tensor = torch.tensor(X_in, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order
grayscaled_rgb_test = torch.tensor(color_me, dtype=torch.float32).permute(0, 3, 1, 2)
transformed_test = torch.stack([preprocess_batch(image) for image in grayscaled_rgb_test])
with torch.no_grad():
    embed_test = inception_resnet_v2(transformed_test)
# Set the model to evaluation mode
model.eval()

# Make predictions on the test data
with torch.no_grad():
    Y_tensor = model(Xin_tensor, embed_test)
    output_numpy = Y_tensor.cpu().numpy()

# Post-process and save images
for i in range(8):
    cur = np.zeros((output_numpy.shape[2], output_numpy.shape[3], 3))
    cur[:,:,0] = X_in[i][:,:,0]
    cur[:,:,1:] = (output_numpy[i] * 128).transpose(1, 2, 0)
    output_rgb = (lab2rgb(cur) * 255).astype(np.uint8)
    #input_rgb = (rgb2gray(lab2rgb(cur)) * 255).astype(np.uint8)
    imsave("result/img_"+str(i)+".png", output_rgb)