<a href="https://colab.research.google.com/github/JumanaRahim/Nullclass-Internship/blob/main/Gen_ai_model_for_image_colorization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


In [2]:
# Pick the compute device once; the model and input batches are moved to it in later cells.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
# Report the active CUDA device index, or explain the CPU fallback.
print(torch.cuda.current_device() if _cuda_ok else "No NVIDIA driver found. Using CPU")


No NVIDIA driver found. Using CPU


In [3]:
import torchvision.transforms as transforms
# Images are only converted to float tensors in [0, 1]; no mean/std normalization is applied.
transform = transforms.Compose([transforms.ToTensor()])


def _make_cifar10_split(is_train, shuffle):
    """Build one CIFAR-10 split (downloading it into ./data if needed) and its DataLoader."""
    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=shuffle, num_workers=2
    )
    return dataset, loader


# Training data is shuffled every epoch; the test order stays fixed for repeatable visualisation.
train_dataset, train_loader = _make_cifar10_split(is_train=True, shuffle=True)
test_dataset, test_loader = _make_cifar10_split(is_train=False, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 53.4MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
class ColorizationNet(nn.Module):
    """Fully convolutional colorizer: 1-channel grayscale in, 3-channel RGB in [0, 1] out.

    All four layers use a 5x5 kernel dilated by 2 (a 9x9 footprint) with padding 4,
    so the spatial size of the input is preserved through the whole network.
    """

    def __init__(self):
        super().__init__()

        def dilated_conv(in_channels, out_channels):
            # Same hyper-parameters for every layer; only the channel counts differ.
            return nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=4, dilation=2)

        self.conv1 = dilated_conv(1, 64)
        self.conv2 = dilated_conv(64, 64)
        self.conv3 = dilated_conv(64, 128)
        self.conv4 = dilated_conv(128, 3)  # three output channels: R, G, B

    def forward(self, x):
        # Three ReLU feature layers, then a sigmoid head that squashes predictions into [0, 1].
        for hidden in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(hidden(x))
        return torch.sigmoid(self.conv4(x))



In [5]:
# Seed PyTorch's RNGs (all devices) so weight initialisation and DataLoader shuffling
# are reproducible across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

LEARNING_RATE = 1e-3  # Adam step size

# Initialize the model and move it to the selected device (GPU or CPU)
model = ColorizationNet().to(device)

# Pixel-wise mean squared error between predicted and true RGB values
criterion = nn.MSELoss()

# Adam optimizer over all model parameters
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Convert RGB image to grayscale
def rgb_to_gray(img, weights=None):
    """
    Converts an RGB image tensor to a single-channel grayscale tensor.

    By default the channels are averaged with equal weight; this is the grayscale
    the colorization model is trained on. Pass ``weights`` (for example ITU-R BT.601
    luma coefficients ``(0.299, 0.587, 0.114)``) for a weighted conversion instead.

    The channel axis is the third-from-last dimension, so both batched
    ``(batch_size, C, height, width)`` and unbatched ``(C, height, width)`` tensors work.

    Args:
        img (torch.Tensor): Input tensor with shape (..., C, height, width), typically C == 3
        weights (sequence of float, optional): One weight per channel; default None averages.

    Returns:
        torch.Tensor: Grayscale tensor with shape (..., 1, height, width)
    """
    if weights is None:
        # dim=-3 equals dim=1 for 4-D batches and still finds the channels of a 3-D image.
        return img.mean(dim=-3, keepdim=True)
    channel_weights = torch.as_tensor(weights, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return (img * channel_weights).sum(dim=-3, keepdim=True)


In [None]:
EPOCHS = 10     # Number of full passes over the training set
LOG_EVERY = 50  # Print the current batch loss every this many steps

for epoch in range(EPOCHS):
    # A later cell calls model.eval(); switch back to training mode every epoch so
    # re-running this cell afterwards does not train in eval mode (no effect for the
    # current conv/ReLU-only model, but it would silently break dropout/batch-norm).
    model.train()
    total_loss = 0.0  # Sum of per-batch losses for this epoch

    for step, (images, _) in enumerate(train_loader):
        # Transfer the RGB batch once; derive the grayscale input on the device.
        images = images.to(device)               # Original RGB images as targets
        grayscale_images = rgb_to_gray(images)   # (batch, 1, H, W) network input

        # Forward pass
        outputs = model(grayscale_images)

        # Compute the loss
        loss = criterion(outputs, images)

        # Backward pass and optimization
        optimizer.zero_grad()  # Clear gradients
        loss.backward()        # Compute gradients
        optimizer.step()       # Update model parameters

        # Accumulate the loss for this step
        total_loss += loss.item()

        # Print training status at regular intervals
        if step % LOG_EVERY == 0:
            print(f"Step [{step}/{len(train_loader)}], Loss: {loss.item():.4f}")

    # Print average loss for the epoch
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}] completed. Average Loss: {avg_loss:.4f}")


Step [0/782], Loss: 0.0616
Step [50/782], Loss: 0.0156
Step [100/782], Loss: 0.0091
Step [150/782], Loss: 0.0072
Step [200/782], Loss: 0.0071
Step [250/782], Loss: 0.0047
Step [300/782], Loss: 0.0045
Step [350/782], Loss: 0.0065
Step [400/782], Loss: 0.0047
Step [450/782], Loss: 0.0087
Step [500/782], Loss: 0.0052
Step [550/782], Loss: 0.0050
Step [600/782], Loss: 0.0060
Step [650/782], Loss: 0.0051
Step [700/782], Loss: 0.0043
Step [750/782], Loss: 0.0046
Epoch [1/10] completed. Average Loss: 0.0078
Step [0/782], Loss: 0.0061
Step [50/782], Loss: 0.0058
Step [100/782], Loss: 0.0043
Step [150/782], Loss: 0.0050
Step [200/782], Loss: 0.0046


In [None]:
# EPOCHS = 10  # Number of epochs

# for epoch in range(EPOCHS):
#     for step, (images, _) in enumerate(train_loader):  # `images` is the input, `_` ignores labels (if any)
#         # Convert RGB images to grayscale and move to the device
#         grayscale_images = rgb_to_gray(images).to(device)
#         images = images.to(device)  # Original RGB images as targets

#         # Forward pass
#         outputs = model(grayscale_images)

#         # Compute the loss
#         loss = criterion(outputs, images)

#         # Backward pass and optimization
#         optimizer.zero_grad()  # Clear gradients from the previous step
#         loss.backward()        # Compute gradients
#         optimizer.step()       # Update model parameters

#         # Print training status at regular intervals
#         if step % 10 == 0:
#             print(f"Epoch [{epoch+1}/{EPOCHS}], Step [{step}/{len(train_loader)}], Loss: {loss.item():.4f}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Display helpers: draw image tensors into matplotlib axes.
def show(img, ax=None, unnormalize=False):
    """
    Draw one image with matplotlib.

    Args:
        img: tensor (any device) or array shaped (H, W) / (1, H, W) for grayscale or
            (C, H, W) for colour, with values in [0, 1] -- the range produced by
            ``transforms.ToTensor()`` and by the model's sigmoid output.
        ax: matplotlib Axes to draw into. When omitted, the image is drawn on the
            current axes and displayed immediately with ``plt.show()``.
        unnormalize: set True for data normalized to [-1, 1]; it is mapped back with
            ``img / 2 + 0.5``. This notebook never normalizes, so the default is False
            (unconditionally unnormalizing squeezed every image into [0.5, 1]).
    """
    array = img.detach().cpu().numpy() if hasattr(img, "detach") else np.asarray(img)
    if unnormalize:
        array = array / 2 + 0.5
    # imshow expects floats in [0, 1]; clip instead of letting matplotlib warn and clip.
    array = np.clip(array, 0.0, 1.0)
    if array.ndim == 3 and array.shape[0] == 1:
        array = array[0]  # (1, H, W) -> (H, W)

    target = plt.gca() if ax is None else ax
    if array.ndim == 2:
        # Fixed vmin/vmax so the grayscale is not contrast-stretched per image.
        target.imshow(array, cmap='gray', vmin=0.0, vmax=1.0)
    else:
        target.imshow(np.transpose(array, (1, 2, 0)))  # CHW -> HWC
    target.axis("off")

    if ax is None:
        plt.show()


# Function to display original, grayscale, and colorized images side by side
def visualize_all_three(original_images, grayscale_images, colorized_images, n=5):
    """
    Plot up to ``n`` rows of (original, grayscale, colorized) images in a single figure.

    Every image is drawn into its own subplot and the figure is shown once at the end
    (calling ``plt.show()`` per image flushes the figure and breaks the grid/titles).
    ``n`` is capped at the smallest batch length so short final batches do not crash.
    """
    n = min(n, len(original_images), len(grayscale_images), len(colorized_images))
    if n == 0:
        return

    # squeeze=False keeps a 2-D axes array even when n == 1.
    fig, axes = plt.subplots(n, 3, figsize=(9, 3 * n), squeeze=False)
    columns = (
        ("Original", original_images),
        ("Grayscale", grayscale_images),
        ("Colorized", colorized_images),
    )
    for row in range(n):
        for col, (title, batch) in enumerate(columns):
            show(batch[row], ax=axes[row, col])
            axes[row, col].set_title(title)

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # free the figure when called repeatedly in a loop


In [None]:
import torch

def rgb_to_hsv(image):
    """
    Convert channels-last RGB values to HSV.

    Args:
        image (torch.Tensor): floating-point tensor whose last dimension holds (R, G, B).

    Returns:
        torch.Tensor: same shape, last dimension (H, S, V). Hue is scaled so a full turn
        is 1.0, S = (max - min) / max, and V = max. Achromatic pixels (max == min) get
        H = 0 and S = 0.
    """
    red, green, blue = image[..., 0], image[..., 1], image[..., 2]
    value = image.max(dim=-1).values
    chroma = value - image.min(dim=-1).values
    zero = torch.zeros_like(value)

    # Hue sector formulas (in sixths of a turn). Later rules win where several channels
    # tie for the maximum: R first, then G, then B.
    red_is_max = value == red
    hue = torch.where(red_is_max & (green >= blue), (green - blue) / chroma, zero)
    hue = torch.where(red_is_max & (green < blue), (green - blue) / chroma + 6.0, hue)
    hue = torch.where(value == green, (blue - red) / chroma + 2.0, hue)
    hue = torch.where(value == blue, (red - green) / chroma + 4.0, hue)
    # Scale to [0, 1) and discard the 0/0 results of gray pixels.
    hue = torch.where(chroma == 0.0, zero, hue / 6.0)

    saturation = torch.where(chroma != 0.0, chroma / value, zero)
    saturation = torch.where(value == 0.0, zero, saturation)

    return torch.stack((hue, saturation, value), dim=-1)


In [None]:
import torch

def hsv_to_rgb(hsv):
    """
    Convert channels-last HSV values back to RGB.

    Args:
        hsv (torch.Tensor): floating-point tensor whose last dimension holds (H, S, V),
            with hue scaled so a full turn is 1.0 (values outside [0, 1) wrap around).

    Returns:
        torch.Tensor: same shape, last dimension (R, G, B).
    """
    hue, sat, val = hsv[..., 0], hsv[..., 1], hsv[..., 2]

    # Split the colour wheel into six sectors; `frac` is the position inside one.
    scaled = hue * 6.0
    sector = torch.floor(scaled).int()
    frac = scaled - sector

    low = val * (1.0 - sat)                    # smallest channel
    falling = val * (1.0 - frac * sat)         # channel ramping down within the sector
    rising = val * (1.0 - (1.0 - frac) * sat)  # channel ramping up within the sector

    # Sector index 0..5 for every pixel, shaped for gather along a trailing axis.
    index = (sector % 6).long().unsqueeze(-1)

    def pick(*per_sector):
        # Select, per pixel, the candidate belonging to that pixel's sector.
        options = torch.stack(per_sector, dim=-1)
        return options.gather(-1, index).squeeze(-1)

    red = pick(val, falling, low, low, rising, val)
    green = pick(rising, val, val, falling, low, low)
    blue = pick(low, low, rising, val, val, falling)

    return torch.stack([red, green, blue], dim=-1)


In [None]:
def exaggerate_colors(images, saturation_factor=1.5, value_factor=1.2, value_range=(0.0, 1.0)):
    """
    Boost the saturation and brightness of channels-first RGB images.

    Args:
        images (torch.Tensor): (..., 3, H, W) RGB tensor, e.g. a (batch, 3, H, W) batch.
        saturation_factor (float): multiplier applied to HSV saturation (clamped to [0, 1]).
        value_factor (float): multiplier applied to HSV value/brightness (clamped to [0, 1]).
        value_range (tuple): (low, high) range of ``images``. Defaults to (0, 1), the range
            of ``ToTensor()`` data and of the model's sigmoid output; pass (-1, 1) for
            images normalized to [-1, 1].

    Returns:
        torch.Tensor: tensor of the same shape and value range as ``images``.
    """
    low, high = value_range
    span = high - low
    # Map to [0, 1] RGB, the domain of the HSV conversion helpers.
    unit_rgb = ((images - low) / span).clamp(0.0, 1.0)

    # rgb_to_hsv / hsv_to_rgb work on channels-last data: move C from dim -3 to dim -1.
    hsv = rgb_to_hsv(unit_rgb.movedim(-3, -1))
    hue = hsv[..., 0]
    saturation = torch.clamp(hsv[..., 1] * saturation_factor, 0.0, 1.0)
    # V must stay in [0, 1]; clamping to 2 (as before) produced out-of-range RGB values.
    value = torch.clamp(hsv[..., 2] * value_factor, 0.0, 1.0)

    boosted = hsv_to_rgb(torch.stack((hue, saturation, value), dim=-1)).movedim(-1, -3)

    # Return in the caller's value range.
    return boosted * span + low


In [None]:
# Inference mode: makes dropout/batch-norm layers deterministic if they are ever added
# (ColorizationNet has none today, but the training cell leaves the model in train mode).
model.eval()

with torch.no_grad():
    for i, (images, _) in enumerate(test_loader):
        grayscale_images = rgb_to_gray(images).to(device)
        colorized_images = model(grayscale_images)

        # Back to the CPU for matplotlib; drop the singleton channel of the gray batch.
        grayscale_images_cpu = grayscale_images.cpu().squeeze(1)
        colorized_images_cpu = exaggerate_colors(colorized_images.cpu())
        original_images_cpu = images.cpu()

        visualize_all_three(original_images_cpu, grayscale_images_cpu, colorized_images_cpu)

        if i == 10:  # stop after 11 batches (indices 0..10)
            break


In [None]:
import torchvision.transforms as transforms

# Colorize a single grayscale image with the trained model and save the result.
# NOTE(review): `gray_img` (and `img`, used by the plotting cell below) are not defined
# anywhere in this notebook -- presumably a deleted cell loaded a PIL image and converted
# it to grayscale. Restore that cell above this one so "Restart & Run All" works.
transform = transforms.Compose([transforms.ToTensor()])

img_tensor = transform(gray_img).unsqueeze(0)  # (1, C, H, W), values in [0, 1]
if img_tensor.shape[1] >= 3:
    # The network takes a single input channel (conv1 is Conv2d(1, ...)) and was trained
    # on channel-mean grayscale, so reduce an RGB(A) input the same way instead of crashing.
    img_tensor = rgb_to_gray(img_tensor[:, :3])

model.eval()
img_tensor = img_tensor.to(device)

with torch.no_grad():
    colorized_tensor = model(img_tensor)

# Sigmoid output is already in [0, 1], which ToPILImage maps to 8-bit RGB.
colorized_img = transforms.ToPILImage()(colorized_tensor.squeeze(0).cpu())
colorized_img.save("Colorized_Image.png")


In [None]:
fig, ax = plt.subplots(1, 3, figsize=(18, 6))

# One panel per image: (image, title, extra imshow keyword arguments).
# NOTE(review): `img` and `gray_img` are not created anywhere in this notebook; confirm
# the cell that loads them is restored above.
panels = (
    (img, "Original Color Image", {}),
    (gray_img, "Grayscale Image", {"cmap": "gray"}),
    (colorized_img, "Colorized Image", {}),
)
for panel_ax, (picture, title, imshow_kwargs) in zip(ax, panels):
    panel_ax.imshow(picture, **imshow_kwargs)
    panel_ax.set_title(title)
    panel_ax.axis("off")  # Hide axes

plt.tight_layout()
