In [1]:
# Importing Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tkinter import Tk, filedialog, Label, Button, Radiobutton, IntVar
from PIL import Image, ImageTk
import cv2
import os

In [2]:
# Defining the Colorization Model
class ColorizationNet(nn.Module):
    def __init__(self):
        super(ColorizationNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=4, dilation=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=4, dilation=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=4, dilation=2)
        self.conv4 = nn.Conv2d(128, 3, kernel_size=5, stride=1, padding=4, dilation=2)
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
transform = transforms.Compose([
    transforms.ToTensor()
])

In [5]:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

Files already downloaded and verified


In [6]:
# Convert RGB Image to Grayscale
def rgb_to_gray(img):
    return img.mean(dim=1, keepdim=True)

# Convert RGB to Sketch using Canny Edge Detection
def rgb_to_sketch(img):
    img_np = np.transpose(img.numpy(), (1, 2, 0)) * 255  # Convert tensor to numpy
    img_np = img_np.astype(np.uint8)
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    sketch = cv2.Canny(gray, 100, 200)
    sketch_tensor = torch.from_numpy(sketch).unsqueeze(0).float() / 255  # Convert back to tensor
    return sketch_tensor

# Convert RGB to Infrared (Simulated by Converting RGB Channels)
def rgb_to_infrared(img):
    img_np = np.transpose(img.numpy(), (1, 2, 0)) * 255  # Convert tensor to numpy
    infrared = img_np[:, :, 0]  # Use the red channel as a proxy for infrared
    infrared_tensor = torch.from_numpy(infrared).unsqueeze(0).float() / 255  # Convert back to tensor
    return infrared_tensor

In [7]:
model = ColorizationNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [12]:
EPOCHS = 30
for epoch in range(EPOCHS):
    for i, (images, _) in enumerate(train_loader):
        grayscale_images = rgb_to_gray(images).to(device)
        images = images.to(device)

        # Forward Pass
        outputs = model(grayscale_images)
        loss = criterion(outputs, images)

        # Backward Pass and Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")


Epoch [1/30], Loss: 0.0049
Epoch [2/30], Loss: 0.0044
Epoch [3/30], Loss: 0.0049
Epoch [4/30], Loss: 0.0031
Epoch [5/30], Loss: 0.0046
Epoch [6/30], Loss: 0.0036
Epoch [7/30], Loss: 0.0043
Epoch [8/30], Loss: 0.0038
Epoch [9/30], Loss: 0.0045
Epoch [10/30], Loss: 0.0050
Epoch [11/30], Loss: 0.0044
Epoch [12/30], Loss: 0.0043
Epoch [13/30], Loss: 0.0057
Epoch [14/30], Loss: 0.0046
Epoch [15/30], Loss: 0.0035
Epoch [16/30], Loss: 0.0057
Epoch [17/30], Loss: 0.0027
Epoch [18/30], Loss: 0.0040
Epoch [19/30], Loss: 0.0042
Epoch [20/30], Loss: 0.0029
Epoch [21/30], Loss: 0.0047
Epoch [22/30], Loss: 0.0033
Epoch [23/30], Loss: 0.0022
Epoch [24/30], Loss: 0.0042
Epoch [25/30], Loss: 0.0025
Epoch [26/30], Loss: 0.0042
Epoch [27/30], Loss: 0.0055
Epoch [28/30], Loss: 0.0052
Epoch [29/30], Loss: 0.0036
Epoch [30/30], Loss: 0.0060


In [13]:
torch.save(model.state_dict(), 'colorization_model.pth')

In [14]:
class ImageColorizer:
    def __init__(self, root):
        self.root = root
        self.root.title("Cross-Domain Image Colorization")
        self.root.geometry("800x500")

        # Buttons and Labels
        self.label = Label(root, text="Upload an Image")
        self.label.pack()

        self.domain = IntVar()
        Radiobutton(root, text="RGB", variable=self.domain, value=1).pack(anchor="w")
        Radiobutton(root, text="Sketch", variable=self.domain, value=2).pack(anchor="w")
        Radiobutton(root, text="Infrared", variable=self.domain, value=3).pack(anchor="w")

        self.upload_button = Button(root, text="Upload Image", command=self.upload_image)
        self.upload_button.pack()
        self.colorize_button = Button(root, text="Colorize Image", command=self.colorize_image)
        self.colorize_button.pack()
        self.reset_button = Button(root, text="Reset", command=self.reset_image)
        self.reset_button.pack()

        self.original_image = None
        self.processed_image = None
        self.colorized_image = None

    def upload_image(self):
        file_path = filedialog.askopenfilename()
        img = Image.open(file_path)
        self.original_image = img

        # Convert to Grayscale/Sketch/Infrared based on selection
        domain_choice = self.domain.get()
        img_tensor = transforms.ToTensor()(img)
        if domain_choice == 1:
            processed_img = rgb_to_gray(img_tensor)
        elif domain_choice == 2:
            processed_img = rgb_to_sketch(img_tensor)
        else:
            processed_img = rgb_to_infrared(img_tensor)

        self.processed_image = transforms.ToPILImage()(processed_img)

        # Display Original and Processed Images
        self.display_images(self.original_image, self.processed_image)

    def colorize_image(self):
        if self.processed_image:
            # Transform Processed Image to Tensor
            img_tensor = transforms.ToTensor()(self.processed_image).unsqueeze(0).to(device)

            # Model Prediction
            model.eval()
            with torch.no_grad():
                colorized_tensor = model(img_tensor)
            
            # Convert to PIL Image
            self.colorized_image = transforms.ToPILImage()(colorized_tensor.squeeze(0).cpu())

            # Display Colorized Image
            self.display_images(self.original_image, self.colorized_image)

    def display_images(self, original, processed):
        original.thumbnail((200, 200))
        processed.thumbnail((200, 200))

        # Display Original Image
        original_photo = ImageTk.PhotoImage(original)
        original_label = Label(self.root, image=original_photo)
        original_label.image = original_photo
        original_label.pack(side="left")

        # Display Processed Image
        processed_photo = ImageTk.PhotoImage(processed)
        processed_label = Label(self.root, image=processed_photo)
        processed_label.image = processed_photo
        processed_label.pack(side="right")

    def reset_image(self):
        for widget in self.root.pack_slaves():
            widget.destroy()
        self.__init__(self.root)

# Running the GUI
if __name__ == "__main__":
    root = Tk()
    app = ImageColorizer(root)
    root.mainloop()

Exception in Tkinter callback
Traceback (most recent call last):
  File "c:\Program Files\Python310\lib\tkinter\__init__.py", line 1921, in __call__
    return self.func(*args)
  File "C:\Users\Rohit\AppData\Local\Temp\ipykernel_14724\1046018220.py", line 55, in colorize_image
    colorized_tensor = model(img_tensor)
  File "c:\Program Files\Python310\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "c:\Program Files\Python310\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Rohit\AppData\Local\Temp\ipykernel_14724\3503885128.py", line 11, in forward
    x = torch.relu(self.conv1(x))
  File "c:\Program Files\Python310\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "c:\Program Files\Python310\lib\site-packages\torch\nn\modules\module.py", line 

In [None]:
"""
So, above code executed successfully. And the output images can been seen in report and as well as folder named images .
Thank you
"""