In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
import time

In [None]:
# Select the compute device and report what was found.
# CUDA is used whenever PyTorch can see it; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("No NVIDIA driver found. Using CPU")

In [None]:
# Load the CIFAR-10 dataset.
# Images are only converted to tensors; no normalization transform is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Training split: shuffled every epoch.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

# Test split: kept in dataset order.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

In [34]:
# Define the colorization model
class ColorizationNet(nn.Module):
    """Four-layer fully convolutional network mapping 1 input channel to 3 output channels.

    Every layer is a 5x5 convolution with dilation 2 and padding 4, so the
    spatial size of the input is preserved. The three hidden layers use ReLU;
    the output layer uses a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # All four convolutions share the same geometry; only channel counts differ.
        conv_opts = dict(kernel_size=5, stride=1, padding=4, dilation=2)
        self.conv1 = nn.Conv2d(1, 64, **conv_opts)
        self.conv2 = nn.Conv2d(64, 64, **conv_opts)
        self.conv3 = nn.Conv2d(64, 128, **conv_opts)
        self.conv4 = nn.Conv2d(128, 3, **conv_opts)

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated 3-channel output."""
        for hidden in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(hidden(x))
        return torch.sigmoid(self.conv4(x))

In [35]:
# Convert RGB image to grayscale
def rgb_to_gray(img):
    """Average ``img`` over dim 1 (unweighted), keeping that dim with size 1."""
    return torch.mean(img, dim=1, keepdim=True)


In [36]:
# Loss functions
class PerceptualLoss(nn.Module):
    """L1 distance between VGG19 feature maps of a prediction and a target.

    Both inputs are expected as RGB batches in [0, 1]. They are normalized with
    the ImageNet channel statistics before being passed through a frozen,
    ImageNet-pretrained VGG19; the L1 losses at the selected layers are summed.
    """

    # Channel statistics that torchvision's ImageNet-pretrained models expect.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        # Use VGG19 as feature extractor. Prefer the `weights=` API and fall back
        # to the legacy `pretrained=True` flag on torchvision versions without it.
        try:
            weights = torchvision.models.VGG19_Weights.IMAGENET1K_V1
        except AttributeError:
            vgg = torchvision.models.vgg19(pretrained=True).features
        else:
            vgg = torchvision.models.vgg19(weights=weights).features
        vgg = vgg.to(device)
        self.feature_extractor = nn.Sequential()

        # Use specific layers from VGG19
        self.layers = [0, 5, 10, 19, 28]  # Conv1_1, Conv2_1, Conv3_1, Conv4_1, Conv5_1

        # Freeze the network: it is a fixed feature extractor, never trained.
        for param in vgg.parameters():
            param.requires_grad = False

        # Keep only the layers up to the deepest one whose output is used.
        last_needed = max(self.layers)
        for i, layer in enumerate(vgg.children()):
            if i > last_needed:
                break
            self.feature_extractor.add_module(str(i), layer)

        # The extractor is only ever used for inference.
        self.feature_extractor.eval()

    def _normalize(self, x):
        """Map an RGB batch in [0, 1] to the ImageNet-normalized range."""
        # Built from `x` so the statistics follow its device and dtype.
        mean = x.new_tensor(self._IMAGENET_MEAN).view(1, -1, 1, 1)
        std = x.new_tensor(self._IMAGENET_STD).view(1, -1, 1, 1)
        return (x - mean) / std

    def extract_features(self, x):
        """Return the outputs of the layers listed in ``self.layers``, in order."""
        features = []
        for i, layer in enumerate(self.feature_extractor):
            x = layer(x)
            if i in self.layers:
                # NOTE(review): torchvision builds VGG with in-place ReLUs, so a
                # tensor captured here is presumably rectified in place by the
                # next layer - confirm if pre-activation features are required.
                features.append(x)
        return features

    def forward(self, pred, target):
        """Return the summed per-layer L1 feature loss between ``pred`` and ``target``."""
        # Normalize inputs with the statistics the pretrained VGG was trained on.
        pred = self._normalize(pred)
        target = self._normalize(target)

        pred_features = self.extract_features(pred)
        target_features = self.extract_features(target)

        # Calculate L1 loss for each feature layer
        loss = 0
        for pred_feature, target_feature in zip(pred_features, target_features):
            loss += nn.functional.l1_loss(pred_feature, target_feature)

        return loss


In [42]:
# Main training function
def train_model(loss_type, epochs=30):
    """Train a fresh ColorizationNet on ``train_loader`` with the chosen loss.

    Parameters:
    - loss_type: 'mse', 'mae', or 'perceptual'.
    - epochs: number of passes over ``train_loader``. Default is 30.

    Returns:
    - model: the trained ColorizationNet. Its state dict is also saved to
      ``colorization_model_{loss_type}.pth`` in the working directory.

    Raises:
    - ValueError: if ``loss_type`` is not one of the supported names.
    """
    # Initialize model
    model = ColorizationNet().to(device)

    # Set loss function based on type
    if loss_type == 'mse':
        criterion = nn.MSELoss()
        print("Using MSE Loss")
    elif loss_type == 'mae':
        criterion = nn.L1Loss()
        print("Using MAE Loss (L1)")
    elif loss_type == 'perceptual':
        criterion = PerceptualLoss()
        print("Using Perceptual Loss")
    else:
        raise ValueError("Invalid loss type. Choose 'mse', 'mae', or 'perceptual'")

    # Set optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    start_time = time.time()
    steps_per_epoch = len(train_loader)

    # Make the mode explicit in case the model is ever reused after evaluation.
    model.train()

    # Training loop
    for epoch in range(epochs):
        epoch_loss = 0.0
        batch_count = 0

        for i, (images, _) in enumerate(train_loader):
            # Move the batch once, then derive the grayscale input on the device.
            images = images.to(device)
            grayscale_images = rgb_to_gray(images)

            # Forward pass
            outputs = model(grayscale_images)
            loss = criterion(outputs, images)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate as a tensor so no device sync is forced on every step.
            epoch_loss += loss.detach()
            batch_count += 1

            # Print statistics
            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{steps_per_epoch}], Loss: {loss.item():.4f}")

        # Epoch summary (skipped when the loader yielded no batches).
        if batch_count:
            mean_loss = float(epoch_loss) / batch_count
            print(f"Epoch [{epoch+1}/{epochs}] average loss: {mean_loss:.4f}")

    elapsed = time.time() - start_time
    print(f"Finished Training with {loss_type} loss.")
    print(f"Training time: {elapsed:.1f}s")

    # Save model
    torch.save(model.state_dict(), f"colorization_model_{loss_type}.pth")

    return model

In [None]:
def test_on_custom_image(models_dict, image_path):
    """Colorize one image file with every model and show/save the results.

    Parameters:
    - models_dict: mapping of loss name -> trained model. Each model is put in
      eval mode.
    - image_path: path of the image to load.

    Side effects:
    - Displays a figure with the original, the grayscale input and one
      colorization per model.
    - Saves each colorization as ``colorized_{loss_type}.jpg`` in the working
      directory.
    """
    # Decode inside the context manager so the file handle is released, and
    # force RGB so palette/alpha/grayscale files are handled uniformly.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    # The image is not resized: ColorizationNet only contains size-preserving
    # convolutions, so it accepts inputs of any spatial size.
    rgb_tensor = transforms.ToTensor()(img).unsqueeze(0)

    # Build the grayscale input with the same channel mean used for training
    # batches, rather than PIL's weighted "L" conversion.
    gray_tensor = rgb_to_gray(rgb_tensor)
    img_tensor = gray_tensor.to(device)

    # Dictionary to store results
    results = {}

    # Process with each model
    for loss_type, model in models_dict.items():
        model.eval()
        with torch.no_grad():
            colorized_tensor = model(img_tensor)
            results[loss_type] = colorized_tensor.squeeze(0).cpu()

    # Convert once; the PIL images are used for both display and saving.
    to_pil = transforms.ToPILImage()
    colorized_images = {loss_type: to_pil(result) for loss_type, result in results.items()}

    # Visualize results
    num_panels = len(models_dict) + 2
    fig, axes = plt.subplots(1, num_panels, figsize=(4 * num_panels, 4))

    # Original color image
    axes[0].imshow(img)
    axes[0].set_title("Original Color")
    axes[0].axis('off')

    # Grayscale image exactly as fed to the models (fixed 0..1 intensity scale)
    axes[1].imshow(gray_tensor[0, 0].numpy(), cmap='gray', vmin=0.0, vmax=1.0)
    axes[1].set_title("Grayscale")
    axes[1].axis('off')

    # Colorized images
    for i, (loss_type, colorized_img) in enumerate(colorized_images.items()):
        axes[i + 2].imshow(colorized_img)
        axes[i + 2].set_title(f"{loss_type.upper()} Loss")
        axes[i + 2].axis('off')

    plt.tight_layout()
    plt.show()

    # Save colorized images
    for loss_type, colorized_img in colorized_images.items():
        colorized_img.save(f"colorized_{loss_type}.jpg")

# Train one colorization model per loss function and keep them for evaluation.
# Number of epochs passed to train_model for every loss type.
EPOCHS = 30
# Loss names accepted by train_model.
loss_types = ['mse', 'mae', 'perceptual']
# Maps each loss name to the model trained with it.
models = {}

for loss_type in loss_types:
    print(f"\nTraining model with {loss_type} loss")
    # train_model also saves the weights to colorization_model_{loss_type}.pth.
    model = train_model(loss_type, epochs=EPOCHS)
    models[loss_type] = model


In [43]:
def imshow(img, unnormalize=False):
    """Display an image tensor on the current matplotlib axes.

    Parameters:
    - img: tensor of shape (H, W) for grayscale or (C, H, W) for color, with
      values in [0, 1] (the range used by the data and the model output in
      this notebook).
    - unnormalize: set to True for tensors in [-1, 1]; they are mapped back to
      [0, 1] before display. Default is False.
    """
    # Detach and move to host memory so GPU / grad-tracking tensors also work.
    img = img.detach().cpu()
    if unnormalize:
        img = img / 2 + 0.5
    # Clip so small numerical overshoots do not trigger matplotlib range warnings.
    npimg = np.clip(img.numpy(), 0.0, 1.0)
    if npimg.ndim == 2:  # grayscale image
        # Fixed intensity scale so gray levels are comparable with the color panels.
        plt.imshow(npimg, cmap='gray', vmin=0.0, vmax=1.0)
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

def visualize_all_three(original_images, grayscale_images, colorized_images, n=5):
    """
    Show the first n samples of a batch as (original, grayscale, colorized) triplets.

    All panels are laid out in a single row of 3*n subplots.
    n: number of images to display from the batch
    """
    plt.figure(figsize=(3*n, 4))

    # Panel order within each triplet, paired with the batch it is drawn from.
    panels = (
        ("Original", original_images),
        ("Grayscale", grayscale_images),
        ("Colorized", colorized_images),
    )

    for sample in range(n):
        for offset, (title, batch) in enumerate(panels, start=1):
            ax = plt.subplot(1, 3*n, 3*sample + offset)
            imshow(batch[sample])
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()


def torch_rgb_to_hsv(rgb):
    """
    Convert a batch of RGB images to HSV.

    Parameters:
    - rgb: tensor of shape (batch_size, 3, height, width) in RGB format in the range [0, 1].

    Returns:
    - hsv: tensor of same shape in HSV format in the range [0, 1].
    """
    red, green, blue = rgb[:, 0], rgb[:, 1], rgb[:, 2]
    peak = rgb.max(dim=1).values
    trough = rgb.min(dim=1).values
    chroma = peak - trough
    zeros = torch.zeros_like(red)

    # Hue on a 0..6 scale. Candidates are applied in the order red, green, blue,
    # so on ties a later channel overrides an earlier one.
    red_hue = (green - blue) / chroma
    red_hue = torch.where(green < blue, red_hue + 6.0, red_hue)
    hue = torch.where(peak == red, red_hue, zeros)
    hue = torch.where(peak == green, (blue - red) / chroma + 2.0, hue)
    hue = torch.where(peak == blue, (red - green) / chroma + 4.0, hue)
    hue = hue / 6.0

    # Achromatic pixels (zero chroma) have no defined hue; use 0.
    hue = torch.where(chroma == 0.0, zeros, hue)

    # Saturation is chroma relative to the brightest channel, 0 for achromatic pixels.
    saturation = torch.where(chroma != 0.0, chroma / peak, zeros)

    # Value is the brightest channel.
    return torch.stack([hue, saturation, peak], dim=1)


def torch_hsv_to_rgb(hsv):
    """
    Convert a batch of HSV images to RGB.

    Parameters:
    - hsv: tensor of shape (batch_size, 3, height, width) in HSV format in the range [0, 1].

    Returns:
    - rgb: tensor of same shape in RGB format in the range [0, 1].
    """
    hue, sat, val = hsv[:, 0], hsv[:, 1], hsv[:, 2]

    # Split the hue circle into six sectors; `frac` is the position inside the sector.
    sector = (hue * 6.0).floor()
    frac = hue * 6.0 - sector
    p = val * (1.0 - sat)
    q = val * (1.0 - sat * frac)
    t = val * (1.0 - sat * (1.0 - frac))

    # Wrap so that hue == 1.0 falls back into sector 0.
    sector = sector % 6

    # (r, g, b) source for each sector, indexed by sector number.
    sector_table = (
        (val, t, p),
        (q, val, p),
        (p, val, t),
        (p, q, val),
        (t, p, val),
        (val, p, q),
    )

    red = torch.zeros_like(hue)
    green = torch.zeros_like(hue)
    blue = torch.zeros_like(hue)
    for index, (r_src, g_src, b_src) in enumerate(sector_table):
        in_sector = sector == index
        red = torch.where(in_sector, r_src, red)
        green = torch.where(in_sector, g_src, green)
        blue = torch.where(in_sector, b_src, blue)

    return torch.stack([red, green, blue], dim=1)

def exaggerate_colors(images, saturation_factor=1.5, value_factor=1.2, value_range=(0.0, 1.0)):
    """
    Exaggerate the colors of RGB images.

    Parameters:
    - images: tensor of shape (batch_size, 3, height, width) in RGB format,
      with values inside ``value_range``.
    - saturation_factor: factor by which to increase the saturation. Default is 1.5.
    - value_factor: factor by which to increase the value/brightness. Default is 1.2.
    - value_range: (low, high) range of ``images``. Defaults to (0.0, 1.0), the
      range of the sigmoid model output in this notebook; pass (-1.0, 1.0) for
      tensors normalized to [-1, 1].

    Returns:
    - color_exaggerated_images: tensor of same shape as input and in the same
      ``value_range``, with exaggerated colors.

    Raises:
    - ValueError: if ``value_range`` is empty or inverted.
    """
    low, high = value_range
    span = high - low
    if span <= 0:
        raise ValueError("value_range must satisfy low < high")

    # Bring the input to [0, 1], which is what the HSV helpers expect.
    unit_images = ((images - low) / span).clamp(0.0, 1.0)

    # Convert RGB images to HSV
    images_hsv = torch_rgb_to_hsv(unit_images)

    # Increase the saturation and value components, keeping them inside [0, 1]
    images_hsv[:, 1, :, :] = torch.clamp(images_hsv[:, 1, :, :] * saturation_factor, 0, 1)
    images_hsv[:, 2, :, :] = torch.clamp(images_hsv[:, 2, :, :] * value_factor, 0, 1)

    # Convert the modified HSV images back to RGB
    color_exaggerated_images = torch_hsv_to_rgb(images_hsv)

    # Return in the caller's original range
    return color_exaggerated_images * span + low



In [50]:
def test_model(model, loss_type, num_batches=5):
    """Colorize the first ``num_batches`` batches of ``test_loader`` and plot them.

    Parameters:
    - model: network to evaluate; it is switched to eval mode.
    - loss_type: accepted for the caller's convenience; not used in the body.
    - num_batches: number of test batches to process. Default is 5.

    Returns:
    - list with one dict per processed batch holding the CPU tensors
      'original', 'grayscale', 'colorized' and 'colorized_enhanced'.
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(test_loader):
            if batch_idx >= num_batches:
                break

            gray_on_device = rgb_to_gray(images).to(device)
            colorized_on_device = model(gray_on_device)

            # Bring everything to host memory for plotting; the grayscale
            # batch drops its channel axis.
            gray_cpu = gray_on_device.cpu().squeeze(1)
            colorized_cpu = colorized_on_device.cpu()
            originals_cpu = images.cpu()

            # The enhanced version is stored but not plotted here.
            enhanced_cpu = exaggerate_colors(colorized_cpu)

            visualize_all_three(originals_cpu, gray_cpu, colorized_cpu)

            batch_record = {
                'original': originals_cpu,
                'grayscale': gray_cpu,
                'colorized': colorized_cpu,
                'colorized_enhanced': enhanced_cpu,
            }
            collected.append(batch_record)

    return collected



# Visualize results from all models
def visualize_comparison(examples_dict, sample_index=0, image_index=0):
    """Show one test image next to its grayscale input and every model's colorization.

    Parameters:
    - examples_dict: mapping of loss name -> list of per-batch dicts with the
      keys 'original', 'grayscale' and 'colorized' (as returned by test_model).
    - sample_index: which batch record to use. Default is 0.
    - image_index: which image inside that batch to show. Default is 0.

    Raises:
    - ValueError: if ``examples_dict`` is empty.
    """
    if not examples_dict:
        raise ValueError("examples_dict is empty; nothing to compare")

    # Use whatever models were evaluated instead of a fixed list of loss names.
    loss_names = list(examples_dict)

    # Original and grayscale come from the first model's record; test_model
    # reads them from the data loader, not from the model.
    reference = examples_dict[loss_names[0]][sample_index]
    original = reference['original'][image_index]
    grayscale = reference['grayscale'][image_index]

    # One panel each for original and grayscale, plus one per model.
    num_panels = 2 + len(loss_names)
    fig, axes = plt.subplots(1, num_panels, figsize=(5 * num_panels, 5))

    # Plot original and grayscale
    axes[0].imshow(np.transpose(original.numpy(), (1, 2, 0)))
    axes[0].set_title("Original")
    axes[0].axis('off')

    axes[1].imshow(np.squeeze(grayscale.numpy()), cmap='gray', vmin=0.0, vmax=1.0)
    axes[1].set_title("Grayscale")
    axes[1].axis('off')

    # Plot colorized images for each loss function
    for ax, loss_type in zip(axes[2:], loss_names):
        colorized = examples_dict[loss_type][sample_index]['colorized'][image_index]
        ax.imshow(np.transpose(colorized.numpy(), (1, 2, 0)))
        ax.set_title(f"{loss_type.upper()} Loss")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [41]:
import os

# Get the current working directory of the kernel process.
# Relative paths such as './data' are resolved against this directory.
current_directory = os.getcwd()

# Print the current working directory
print("Current working directory:", current_directory)


Current working directory: /content


In [None]:
# Colab-only helper module; this cell does not run outside Google Colab.
from google.colab import files

# Upload files from your local machine
uploaded = files.upload()

# List the uploaded files
# NOTE: `filename` stays bound after the loop to the last key iterated, and
# later cells read it as the path of the custom image. If `uploaded` is empty
# the loop body never runs and `filename` is not assigned by this cell.
for filename in uploaded.keys():
    print("Uploaded file:", filename)


In [13]:
from PIL import Image

# Open the uploaded image. `filename` is the loop variable left over from the
# upload cell, so that cell must have been run first. (The author's example
# image was horse.jpg, kept in the current directory.)
img = Image.open(filename)

# Convert the image to single-channel grayscale (PIL mode "L").
# NOTE(review): test_on_custom_image opens the file again itself; nothing else
# in the visible notebook appears to read `img` / `gray_img` - confirm whether
# this cell is still needed.
gray_img = img.convert("L")

In [None]:
# Test models on the test dataset
# Maps each loss name to the list of per-batch result dicts returned by test_model.
test_examples = {}

for loss_type, model in models.items():
    print(f"\nTesting model with {loss_type} loss")
    # test_model also plots each processed batch as a side effect.
    examples= test_model(model, loss_type)
    test_examples[loss_type] = examples


# Visualize sample results from each model
# Shows the first image of up to three batches; the count is taken from the
# 'mse' entry, so a model trained with that loss must be present.
for i in range(min(3, len(test_examples['mse']))):
    visualize_comparison(test_examples, sample_index=i, image_index=0)

# Best-effort demo on a user-supplied image: any failure (including `filename`
# never having been set by the upload cell) is reported instead of raised.
try:
    # You can replace this with your own image path
    custom_image_path = filename
    test_on_custom_image(models, custom_image_path)
except Exception as e:
    print(f"Could not test on custom image: {e}")
    print("Please upload a custom image and specify the correct path to test.")