*NOTE: The model was trained only on urban images, so results on other kinds of scenes may be poor.*

IMPORTING THE LIBRARIES

In [112]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch import nn

DEVICE

In [113]:
# Use the GPU when PyTorch was built with CUDA support and a CUDA device is
# visible; otherwise fall back to the CPU so the notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

COLORIZATION CLASS

In [114]:
class ColorizationAutoencoder(nn.Module):
    """Encoder/decoder with skip connections mapping 1-channel input to 3 channels.

    Encoder: four 3x3 stride-2 convolutions (1 -> 64 -> 128 -> 256 -> 512
    channels), each halving height and width. Decoder: four 3x3 stride-2
    transposed convolutions, each doubling height and width. Before up2, up3
    and up4 the decoder activation is concatenated along the channel axis with
    the encoder activation of the same resolution, which is why those layers
    take twice as many input channels as the previous decoder layer produced.

    Input is expected as (batch, 1, H, W); output is (batch, 3, H', W') with
    values in [0, 1] from the final Sigmoid. With H and W multiples of 16
    (e.g. the 256x256 produced by preprocess_image) the skip tensors line up
    and H' == H, W' == W; other sizes may fail at the concatenations.

    The layer attribute names (down1..down4, up1..up4) are the state_dict keys
    used when a saved checkpoint is loaded, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Encoder: each stride-2 convolution halves the spatial size.
        self.down1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.down2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.down3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.down4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)

        # Decoder: each stride-2 transposed convolution doubles the spatial size.
        # in_channels of up2..up4 = previous decoder output + matching skip map.
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up2 = nn.ConvTranspose2d(256 + 256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up3 = nn.ConvTranspose2d(128 + 128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up4 = nn.ConvTranspose2d(64 + 64, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder pass: keep every activation; all but the last become skips.
        encoded = []
        feat = x
        for conv in (self.down1, self.down2, self.down3, self.down4):
            feat = self.relu(conv(feat))
            encoded.append(feat)
        bottleneck = encoded.pop()  # deepest map feeds the decoder directly

        # Decoder pass: upsample, then fuse with the encoder map of equal size
        # (popping from the end yields the deepest remaining skip first).
        feat = self.relu(self.up1(bottleneck))
        for deconv in (self.up2, self.up3):
            feat = self.relu(deconv(torch.cat((feat, encoded.pop()), dim=1)))
        return self.sigmoid(self.up4(torch.cat((feat, encoded.pop()), dim=1)))

FUNCTION TO LOAD AND PREPROCESS THE GRAYSCALE IMAGE

In [115]:
def preprocess_image(image_path, size=(256, 256)):
    """Load an image file as a normalized, batched single-channel tensor.

    Args:
        image_path: Path of the image to read. It is converted to 8-bit
            grayscale (PIL mode 'L') whatever its original mode.
        size: (height, width) the image is resized to. Defaults to (256, 256).
            The resize does not preserve the original aspect ratio.

    Returns:
        A float tensor of shape (1, 1, height, width). ToTensor scales pixels
        to [0, 1] and Normalize(0.5, 0.5) then maps them to [-1, 1].
    """
    # Image.open keeps the file handle open for lazy decoding; the context
    # manager releases it once the pixels have been converted.
    with Image.open(image_path) as img:
        grayscale = img.convert('L')

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),                 # HxW uint8 -> 1xHxW float in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
    ])
    return transform(grayscale).unsqueeze(0)  # add the leading batch dimension

FUNCTION TO POST-PROCESS THE MODEL OUTPUT, SAVE IT, AND DISPLAY THE IMAGE

In [116]:
def postprocess_and_save(output, save_path):
    """Convert a model output tensor to an 8-bit RGB image, save it, and show it.

    Args:
        output: Channels-first image tensor, either (1, 3, H, W) or (3, H, W),
            with values assumed to lie in [0, 1] (the colorization model ends
            in a Sigmoid). Only a single image is supported: squeeze(0) does
            not remove a batch dimension larger than 1.
        save_path: Destination file; PIL picks the format from the extension.
    """
    # CHW -> HWC, the layout PIL and matplotlib expect.
    image = output.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    # Round rather than truncate, and clip first so any value outside [0, 1]
    # saturates instead of wrapping around when cast to uint8.
    image = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)

    Image.fromarray(image).save(save_path)

    # Draw on a fresh figure so the image never lands on a previously open axes.
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    plt.show()

CREATING THE INSTANCE OF THE MODEL

In [117]:
model: nn.Module = ColorizationAutoencoder()  # randomly initialised here; the saved weights are loaded in the next cell

LOAD THE TRAINED MODEL WEIGHTS

In [None]:
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model is moved to `device` in a later cell.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (newer PyTorch versions also accept weights_only=True for state dicts).
model.load_state_dict(torch.load(r'path_to_the_save_Model', map_location='cpu')) # Enter the path where your model is saved

EVALUATION MODE

In [None]:
model.train(False)  # same as model.eval(): switch the module to inference mode; returns the model

MOVING MODEL TO GPU/CPU

In [120]:
model.to(device);  # nn.Module.to moves the parameters in place and returns the same module

FUNCTION TO COLORIZE THE IMAGE

In [121]:
def colorize_image(image_path, save_path, net=None):
    """Colorize one grayscale image file, then save and display the result.

    Args:
        image_path: Path of the input image; preprocess_image converts it to
            grayscale, resizes it and normalizes it.
        save_path: Destination file passed on to postprocess_and_save.
        net: Model to run. Defaults to the notebook-level ``model`` so existing
            two-argument calls keep working.
    """
    if net is None:
        net = model  # notebook-global instance created and loaded in earlier cells
    # Put the input on the same device as the model's weights, so a model that
    # was never moved to `device` (or was moved elsewhere) still works.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    grayscale_img = preprocess_image(image_path).to(target_device)
    with torch.no_grad():  # inference only: do not build an autograd graph
        colorized_img = net(grayscale_img)
    postprocess_and_save(colorized_img, save_path)

SPECIFYING THE INPUT AND OUTPUT PATHS AND RUNNING THE COLORIZATION

In [None]:
# Enter the path of the image which you want to color.
image_path = r"path_to_your_grayscale_image"
# The colorized result is written here (and also displayed inline).
save_path = 'colorized_image.png'
colorize_image(image_path=image_path, save_path=save_path)