In [9]:
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab, gray2rgb
from skimage.io import imsave
from sklearn.metrics import mean_squared_error
from torchvision import transforms
import torchvision.models as models
#from torchvision.models import inception_v2
from PIL import Image
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
import timm

In [10]:
# Get images
X = []
for filename in os.listdir('Train/'):
     X.append(np.array(Image.open('Train/'+filename).convert('RGB')))
X = np.array(X, dtype=float)
Xtrain = 1.0/255*X

inception_resnet_v2 = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
inception_resnet_v2.eval()

Using cache found in C:\Users\Arohan/.cache\torch\hub\pytorch_vision_v0.10.0


Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [11]:
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(512, 256, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        return x

class Fusion(nn.Module):
    def __init__(self):
        super(Fusion, self).__init__()

    def forward(self, x, embed_input):
        embed_input = embed_input.view(-1, 1000, 1, 1)
        embed_input = embed_input.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, embed_input], dim=1)
        x = F.relu(nn.Conv2d(256 + 1000, 256, kernel_size=1)(x))
        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(16, 2, kernel_size=3, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.upsample1(x)
        x = F.relu(self.conv2(x))
        x = self.upsample2(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.tanh(self.conv5(x))
        x = self.upsample3(x)
        return x

class ColorizationModel(nn.Module):
    def __init__(self):
        super(ColorizationModel, self).__init__()
        self.encoder = Encoder()
        self.fusion = Fusion()
        self.decoder = Decoder()

    def forward(self, encoder_input, embed_input):
        encoder_output = self.encoder(encoder_input)
        fusion_output = self.fusion(encoder_output, embed_input)
        decoder_output = self.decoder(fusion_output)
        return decoder_output


In [12]:
def preprocess_batch(images):
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return preprocess(images)

def create_inception_embedding(grayscaled_rgb):
    grayscaled_rgb_resized = preprocess_batch(grayscaled_rgb_resized)
    input_tensor = torch.tensor(grayscaled_rgb_resized, dtype=torch.float32).permute(0, 3, 1, 2)
    with torch.no_grad():
        output = inception_resnet_v2(input_tensor)
    return output


In [13]:
#inception_resnet_v2 = models.inception_resnet_v2(pretrained=True, aux_logits=False)
#inception_resnet_v2.eval()

# Convert data to PyTorch tensors
lab_train = rgb2lab(Xtrain)
X_train = lab_train[:,:,:,0]
Y_train = lab_train[:,:,:,1:] / 128
X_train = X_train.reshape(10, 256, 256, 1)
Y_train = Y_train.reshape(10, 256, 256, 2)
Xtrain_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order
Ytrain_tensor = torch.tensor(Y_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

# Instantiate the model and set up the optimizer and loss function
model = ColorizationModel()
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(),lr=0.001, alpha=0.9)


# Data augmentation using torchvision.transforms
transform = transforms.Compose([
    transforms.RandomAffine(degrees=20, shear=[-5,5], scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
])

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, X, embed, Y, transform=None):
        self.X = X
        self.Y = Y
        self.embed = embed
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
      #  x = self.X[idx]
      #  y = self.Y[idx]
      #  embed = self.embed[idx]

      #  if self.transform:
      #      #Convert to PIL Image before applying transformations
      #      x = transforms.ToPILImage()(x)
      #      y = transforms.ToPILImage()(y)
      #      embed = transforms.ToPILImage()(embed)

      #      x = self.transform(x)
      #      y = self.transform(y)
      #      embed = self.transform(embed)

            # Convert back to PyTorch tensors
      #      x = transforms.ToTensor()(x)
      #      y = transforms.ToTensor()(y)
      #      embed = transforms.ToTensor()(embed)

      #  return x, y, embed
        sample = {'image': self.X[idx], 'embed': self.embed[idx], 'target': self.Y[idx]}

        if self.transform:
            sample['image'] = self.transform(sample['image'])
            

        return sample

# Create DataLoader for batch training
batch_size=2
grayscaled_rgb = gray2rgb(rgb2gray(Xtrain))
grayscaled_rgb = torch.tensor(grayscaled_rgb, dtype=torch.float32).permute(0, 3, 1, 2)
transformed_batch = torch.stack([preprocess_batch(image) for image in grayscaled_rgb])
with torch.no_grad():
    embed = inception_resnet_v2(transformed_batch)
dataset = ImageDataset(Xtrain_tensor, embed, Ytrain_tensor, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#print(embed.shape)


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in dataloader:
        inputs, targets, em = batch['image'], batch['target'], batch['embed']
        optimizer.zero_grad()
        outputs = model(inputs, em)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    average_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}")

Epoch [1/10], Loss: 0.1882
Epoch [2/10], Loss: 0.0058
Epoch [3/10], Loss: 0.0060
Epoch [4/10], Loss: 0.0057
Epoch [5/10], Loss: 0.0057
Epoch [6/10], Loss: 0.0057
Epoch [7/10], Loss: 0.0057
Epoch [8/10], Loss: 0.0057
Epoch [9/10], Loss: 0.0057
Epoch [10/10], Loss: 0.0058


In [14]:
# Load and preprocess test images
color_me = []
for filename in os.listdir('Test/'):
    img = Image.open('Test/' + filename).convert('RGB')
    img_array = np.array(img, dtype=float)
    color_me.append(img_array)
color_me = np.array(color_me, dtype=float)
x=color_me
color_me = rgb2lab(1.0/255*color_me)[:, :, :, 0]
color_me = color_me.reshape(8, 256, 256, 1)

# Convert numpy array to PyTorch tensor
color_me_tensor = torch.tensor(color_me, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

grayscaled_rgb_test = torch.tensor(x, dtype=torch.float32).permute(0, 3, 1, 2)
transformed_test = torch.stack([preprocess_batch(image) for image in grayscaled_rgb_test])
with torch.no_grad():
    embed_test = inception_resnet_v2(transformed_test)
# Set the model to evaluation mode

# Test model
model.eval()
with torch.no_grad():
    output = model(color_me_tensor, embed_test)
    output = output.cpu().numpy()
# Convert PyTorch tensor to numpy array

output = output * 128

# Output colorizations
for i in range(len(output)):
    cur = np.zeros((256, 256, 3))
    cur[:, :, 0] = color_me[i][:, :, 0]
    cur[:, :, 1:] = output[i].transpose(1, 2, 0)
    output_img = lab2rgb(cur)
    output_img = (output_img * 255).astype(np.uint8)
    # Save the output image
    imsave("result/img_"+str(i)+".png", output_img)
   