In [69]:
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
from skimage.io import imsave
from sklearn.metrics import mean_squared_error
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms

In [115]:
# Get images
X = []
for filename in os.listdir('Train/'):
    X.append(np.array(Image.open('Train/'+filename).convert('RGB')))
X = np.array(X, dtype=float)

# Set up train and test data
split = int(0.95*len(X))
Xtrain = X[:split]
Xtrain = 1.0/255*X

In [116]:
import torch.nn as nn
import torch.nn.functional as F

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 2, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )

    def forward(self, x):
        return self.model(x)


# Instantiate the model
#model = CustomModel()

In [118]:
# Convert data to PyTorch tensors
lab_train = rgb2lab(Xtrain)
X_train = lab_train[:,:,:,0]
Y_train = lab_train[:,:,:,1:] / 128
X_train = X_train.reshape(10, 256, 256, 1)
Y_train = Y_train.reshape(10, 256, 256, 2)
Xtrain_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order
Ytrain_tensor = torch.tensor(Y_train, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

# Instantiate the model and set up the optimizer and loss function
model = CustomModel()
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(),lr=0.001, alpha=0.9)


# Data augmentation using torchvision.transforms
transform = transforms.Compose([
    transforms.RandomAffine(degrees=20, shear=[-5,5], scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
])

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, X, Y, transform=None):
        self.X = X
        self.Y = Y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        y = self.Y[idx]

        if self.transform:
            #Convert to PIL Image before applying transformations
            x = transforms.ToPILImage()(x)
            y = transforms.ToPILImage()(y)

            x = self.transform(x)
            y = self.transform(y)

            # Convert back to PyTorch tensors
            x = transforms.ToTensor()(x)
            y = transforms.ToTensor()(y)
        

        return x, y

# Assuming Xtrain and Ytrain are PyTorch tensors
batch_size=2
train_dataset = ImageDataset(Xtrain_tensor, Ytrain_tensor, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


# Set up TensorBoard
#tensorboard_writer = SummaryWriter(log_dir="output/first_run")

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for batch in train_loader:
        inputs, targets = batch
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Print the loss every 100 epochs
    if epoch % 1 == 0:
        print(f'Epoch {epoch+1}/{10}, Loss: {loss.item()}')


    # Logging to TensorBoard
    #tensorboard_writer.add_scalar('Loss', loss.item(), epoch)

#tensorboard_writer.close()

Epoch 1/10, Loss: 0.5737428665161133
Epoch 2/10, Loss: 0.45848554372787476
Epoch 3/10, Loss: 0.17949651181697845
Epoch 4/10, Loss: 0.05962227284908295
Epoch 5/10, Loss: 0.14653341472148895
Epoch 6/10, Loss: 0.14656485617160797
Epoch 7/10, Loss: 0.13210293650627136
Epoch 8/10, Loss: 0.12684974074363708
Epoch 9/10, Loss: 0.11299820244312286
Epoch 10/10, Loss: 0.11017961800098419


In [119]:
# Save model
#model_json = model.to_json()
#with open("model.json", "w") as json_file:
#    json_file.write(model_json)
#model.save_weights("model.h5")

In [120]:

# Test images
X_test = rgb2lab(1.0/255*X)[:, :, :, 0]
#X_test = Xtest.reshape(Xtest.shape + (1,))
Y_test = rgb2lab(1.0/255*X)[:, :, :, 1:]
Y_test = Y_test / 128

# Convert numpy arrays to PyTorch tensors
X_test = X_test.reshape(10, 256, 256, 1)
Y_test = Y_test.reshape(10, 256, 256, 2)
Xtest = torch.tensor(X_test, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order
Ytest = torch.tensor(Y_test, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order



# Evaluate the model
model.eval()
with torch.no_grad():
    predictions = model(Xtest)




# Load and preprocess test images
color_me = []
for filename in os.listdir('Test/'):
    img = Image.open('Test/' + filename).convert('RGB')
    img_array = np.array(img, dtype=float)
    color_me.append(img_array)
color_me = np.array(color_me, dtype=float)
color_me = rgb2lab(1.0/255*color_me)[:, :, :, 0]
color_me = color_me.reshape(8, 256, 256, 1)

# Convert numpy array to PyTorch tensor
color_me_tensor = torch.tensor(color_me, dtype=torch.float32).permute(0, 3, 1, 2)  # Adjust dimension order

# Test model
model.eval()
with torch.no_grad():
    output = model(color_me_tensor)
    output = output.cpu().numpy()
# Convert PyTorch tensor to numpy array

output = output * 128

# Output colorizations
for i in range(len(output)):
    cur = np.zeros((256, 256, 3))
    cur[:, :, 0] = color_me[i][:, :, 0]
    cur[:, :, 1:] = output[i].transpose(1, 2, 0)
    output_img = lab2rgb(cur)
    output_img = (output_img * 255).astype(np.uint8)
    # Save the output image
    imsave("result/img_"+str(i)+".png", output_img)
   
