In [None]:
import torch                                             # Deep learning library
from torch import nn                                     # Neural network functions
from torchvision import transforms                       # Data transformation functions
from torchvision.transforms import InterpolationMode     # Image interpolation methods

import numpy as np                              # Numerical computation library
import matplotlib.pyplot as plt                 # Plotting library
import matplotlib.image as mpimg                # Image operations library
from PIL import Image                           # Image processing library
from skimage.color import rgb2lab, lab2rgb      # Color space conversion functions

In [None]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)

def to_device(data, device):
    """
    Move a tensor/module, or a (possibly nested) list/tuple of them, to `device`.

    Lists and tuples are walked recursively and always come back as lists;
    anything else is moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

# Generator

In [None]:
class Block(nn.Module):
    """
    Resampling unit of the generator: (transposed) convolution -> BatchNorm -> activation.

    The convolution uses kernel size 4, stride 2 and padding 1, so it halves
    the spatial size when ``down`` is True and doubles it when ``down`` is False.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu"):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            down (bool): True -> nn.Conv2d (downsample), False -> nn.ConvTranspose2d (upsample).
            act (str): "relu" selects nn.ReLU; any other value selects nn.LeakyReLU(0.2).
        """
        super().__init__()

        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        resample = conv_cls(in_channels, out_channels, 4, 2, 1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)

        # Keep the attribute name and layer order: state_dict keys are conv.0 / conv.1.
        self.conv = nn.Sequential(resample, norm, activation)
        self.down = down

    def forward(self, x):
        """Apply the conv -> norm -> activation stack to `x` and return the result."""
        return self.conv(x)


class Generator(nn.Module):
    """
    Encoder/decoder generator with residual additions in the encoder and
    concatenated skip connections in the decoder.

    The network maps an ``in_channels``-channel image to a 2-channel image
    squashed by Tanh into [-1, 1]. It halves the spatial size eight times
    (including the bottleneck), so input height and width are expected to be
    multiples of 256 for the skip tensors to line up.
    """

    def __init__(self, in_channels=1, features=64):
        """
        Args:
            in_channels (int): Number of input channels (default: 1).
            features (int): Base channel width; deeper stages use 2x, 4x and 8x this value (default: 64).
        """
        super().__init__()
        c1, c2, c4, c8 = features, features * 2, features * 4, features * 8

        # NOTE: attribute names and their assignment order are intentionally
        # unchanged so that state_dict keys keep matching saved checkpoints.

        # Stem: first stride-2 convolution (no normalization).
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, c1, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
        )

        # Two stacked conv bodies at base width; forward() adds the stem output back.
        self.res1 = self._residual_body(c1)
        self.res2 = self._residual_body(c1)

        # Strided stage c1 -> c2 with its 1x1 stride-2 shortcut.
        self.res3 = self._strided_body(c1, c2)
        self.downsample1 = self._projection(c1, c2)

        # Residual body at c2.
        self.res4 = self._residual_body(c2)

        # Strided stage c2 -> c4 with its shortcut.
        self.down2 = self._strided_body(c2, c4)
        self.downsample2 = self._projection(c2, c4)

        # Residual body at c4.
        self.res5 = self._residual_body(c4)

        # Shortcut paired with down4 in forward().
        self.downsample3 = self._projection(c8, c8)
        self.down3 = Block(c4, c8, down=True, act="leaky")

        # Shortcut paired with down6 in forward().
        self.downsample4 = self._projection(c8, c8)
        self.down4 = Block(c8, c8, down=True, act="leaky")

        # Shortcut paired with down3 in forward().
        self.downsample5 = self._projection(c4, c8)

        self.down5 = Block(c8, c8, down=True, act="leaky")
        self.down6 = Block(c8, c8, down=True, act="leaky")

        # Bottleneck: last stride-2 convolution (with bias, no normalization).
        self.bottleneck = nn.Sequential(
            nn.Conv2d(c8, c8, 4, 2, 1),
            nn.ReLU(),
        )

        # Decoder. Every stage after up1 receives [previous output, encoder skip]
        # concatenated on the channel axis, hence the doubled input widths.
        self.up1 = Block(c8, c8, down=False, act="relu")
        self.up2 = self._up_with_dropout(c8 * 2, c8)
        self.up3 = self._up_with_dropout(c8 * 2, c8)
        self.up4 = self._up_with_dropout(c8 * 2, c8)
        self.up5 = Block(c8 * 2, c4, down=False, act="relu")
        self.up6 = Block(c4 * 2, c2, down=False, act="relu")
        self.up7 = Block(c2 * 2, c1, down=False, act="relu")

        # Output head: upsample once more to 2 channels, squashed by Tanh.
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(c1 * 2, 2, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _residual_body(channels):
        """conv3x3 -> BN -> LeakyReLU -> conv3x3 -> BN, keeping channels and size."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

    @staticmethod
    def _strided_body(c_in, c_out):
        """Downsampling Block followed by conv3x3 -> BN at the new width."""
        return nn.Sequential(
            Block(c_in, c_out, down=True, act="leaky"),
            nn.Conv2d(c_out, c_out, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_out),
        )

    @staticmethod
    def _projection(c_in, c_out):
        """1x1 stride-2 convolution + BN used as the shortcut of a strided stage."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 1, 2, bias=False),
            nn.BatchNorm2d(c_out),
        )

    @staticmethod
    def _up_with_dropout(c_in, c_out):
        """Upsampling Block followed by Dropout(0.5)."""
        return nn.Sequential(
            Block(c_in, c_out, down=False, act="relu"),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """
        Run the generator.

        Args:
            x (torch.Tensor): Input batch with ``in_channels`` channels.

        Returns:
            torch.Tensor: 2-channel output in [-1, 1] at the input resolution
            (for inputs whose height/width are multiples of 256).
        """
        # Encoder: each stage is "body(x) + shortcut(x)"; only down5 has no shortcut.
        stem = self.initial_down(x)
        e1 = self.res2(self.res1(stem)) + stem
        e2_in = self.res3(e1) + self.downsample1(e1)
        e2 = self.res4(e2_in) + e2_in
        e3_in = self.down2(e2) + self.downsample2(e2)
        e3 = self.res5(e3_in) + e3_in
        e4 = self.down3(e3) + self.downsample5(e3)
        e5 = self.down4(e4) + self.downsample3(e4)
        e6 = self.down5(e5)
        e7 = self.down6(e6) + self.downsample4(e6)

        # Bottleneck and first upsampling step (no skip input).
        out = self.up1(self.bottleneck(e7))

        # Decoder: walk back up, concatenating the matching encoder feature map.
        decoder_path = (
            (self.up2, e7),
            (self.up3, e6),
            (self.up4, e5),
            (self.up5, e4),
            (self.up6, e3),
            (self.up7, e2),
            (self.final_up, e1),
        )
        for stage, skip in decoder_path:
            out = stage(torch.cat([out, skip], 1))
        return out

# Converting LAB images to RGB

In [None]:
def lab_to_rgb(L, ab):
    """
    Converts a batch of LAB images to RGB format.

    Args:
        L (torch.Tensor): L channel of LAB images with shape (batch_size, 1, height, width),
            normalized to [-1, 1].
        ab (torch.Tensor): ab channels of LAB images with shape (batch_size, 2, height, width),
            normalized to [-1, 1].

    Returns:
        np.ndarray: Array of RGB images with shape (batch_size, height, width, 3).

    Note:
        - Input L values are scaled from [-1, 1] to [0, 100] range.
        - Input ab values are scaled from [-1, 1] to [-128, 128] range.
        - Each image is converted with skimage.color.lab2rgb.
        - Tensors are detached first, so inputs that still track gradients
          (e.g. generator output outside torch.no_grad) are accepted.
    """
    # Undo the normalization applied before the network.
    L = (L + 1.) * 50.
    ab = ab * 128.

    # (B, 3, H, W) -> (B, H, W, 3). `.detach()` is required because
    # Tensor.numpy() raises on tensors that require grad.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()

    # Convert image by image and re-assemble the batch.
    return np.stack([lab2rgb(img) for img in Lab], axis=0)


# Image Plotting

In [None]:
def visualize_side(data, fake_imgs, path = None, save=False):
    """
    Visualize grayscale and colorized images side by side and optionally save the figure.

    Only the first item of each batch is shown: ``data[0][0]`` (first image,
    first channel) on the left and ``fake_imgs[0]`` on the right.

    Args:
        data (torch.Tensor or list): Batch of grayscale images; ``data[0][0]`` must be a
            2-D tensor (``.cpu()`` is called on it).
        fake_imgs (list or np.ndarray): Batch of colorized images accepted by ``imshow``.
        path (str, optional): Path prefix for the saved figure; the file written is
            ``path + "_side.png"``. Required when ``save`` is True.
        save (bool, optional): Flag indicating whether to save the figure. Defaults to False.

    Saving is best-effort: failures are reported with the real reason instead of raising.
    """
    fig, (ax_gray, ax_color) = plt.subplots(1, 2, figsize=(15, 10))

    # Left panel: grayscale input.
    ax_gray.imshow(data[0][0].cpu(), cmap='gray')
    ax_gray.axis("off")

    # Right panel: colorized output.
    ax_color.imshow(fake_imgs[0])
    ax_color.axis("off")

    # Save before showing: the figure is guaranteed to still hold its content here.
    if save:
        if path is None:
            print("Error occurred while saving the figure: Path not provided")
        else:
            try:
                fig.savefig(path + "_side.png")
                print("Figure saved successfully.")
            except Exception as e:
                # Report what actually went wrong (missing directory, permissions, bad path type, ...).
                print(f"Error occurred while saving the figure: {e}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
def visualize(fake_imgs):
    """
    Display the first colorized image of a batch on a 16x16 inch figure.

    Args:
        fake_imgs (list or np.ndarray): Batch of colorized images; only
            ``fake_imgs[0]`` is shown.
    """
    _, ax = plt.subplots(figsize=(16, 16))
    ax.imshow(fake_imgs[0])
    ax.axis("off")   # hide ticks and frame
    plt.show()

# Transforming Grayscale Images to Colorized Versions

In [None]:

def predict_image(image_path,model,image_save_path=None,save=False,plot=False,plot_side_by_side=False):
    """
    Predicts the colorized version of an input greyscale image using the provided model.

    The image is resized to 512x512, converted to L*a*b, and its normalized L
    channel is passed through ``model``, which predicts the two ab channels.
    L and ab are then resized back to the original resolution and converted to RGB.

    Args:
        image_path (str): The path to the input image file.
        model (nn.Module): The generator model. Inference runs on the device of its
            parameters. The model is left in training mode after the call.
        image_save_path (str, optional): Path prefix for saved files; the colorized image is
            written to ``image_save_path + '.png'``. Defaults to None.
        save (bool, optional): Whether to save the colorized image. Defaults to False.
        plot (bool, optional): Whether to plot the colorized image. Defaults to False.
        plot_side_by_side (bool, optional): Whether to plot the input image and colorized
            image side by side. Defaults to False.

    Returns:
        None. Results are shown and/or written to disk depending on the flags.
    """
    SIZE = 512   # side length the image is resized to before it is fed to the generator
    torch.cuda.empty_cache()

    # Run on whatever device the model lives on instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Load the input image; convert() returns an independent copy, so the file can be closed.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    og_size = img.size     # PIL reports (width, height)

    # Resize, convert RGB -> L*a*b and keep only the L channel as network input.
    img = transforms.Resize((SIZE, SIZE), interpolation=InterpolationMode.BICUBIC)(img)
    img_lab = rgb2lab(np.array(img)).astype("float32")
    img_lab = transforms.ToTensor()(img_lab)          # HWC float array -> CHW tensor
    L = img_lab[[0], ...] / 50. - 1.                  # Normalize L channel to range [-1, 1]
    L_tensor = L.unsqueeze(0).to(run_device)          # Add batch dimension and move to device

    # Deliberately left in training mode (as in the original notebook): dropout stays
    # active and BatchNorm uses the statistics of this batch.
    model.train()
    with torch.no_grad():
        fake_color = model(L_tensor)       # Generate the ab channels using the generator

    # Resize L and predicted ab back to the original image size (Resize expects (height, width)).
    transform_resized = transforms.Resize((og_size[1], og_size[0]), interpolation=InterpolationMode.BICUBIC)
    L_tensor = transform_resized(L_tensor.squeeze(0)).unsqueeze(0)
    fake_color = transform_resized(fake_color.squeeze(0)).unsqueeze(0)

    fake_imgs = lab_to_rgb(L_tensor, fake_color)      # Combine L and ab channels into RGB images
    if plot:
        visualize(fake_imgs)                          # Display colorized image

    # Save the colorized image (best-effort: failures are reported, not raised).
    if save:
        if image_save_path is None:
            print("Error occurred while saving the figure: Path not provided")
        else:
            try:
                # Index the batch instead of np.squeeze so size-1 spatial dims survive.
                rgb_uint8 = (fake_imgs[0] * 255).astype(np.uint8)
                Image.fromarray(rgb_uint8).save(image_save_path+'.png')
                print("Colorized Image saved successfully.")
            except Exception as e:
                # Surface the real reason (missing directory, permissions, ...).
                print(f"Error occurred while saving the colorized image: {e}")

    if plot_side_by_side:
        visualize_side(L_tensor,fake_imgs,path=image_save_path,save=save)       # Display original and colorized images side by side

    torch.cuda.empty_cache()     # Clear CUDA cache


In [None]:
# Select the inference device once: CUDA when available, otherwise CPU.
device = get_default_device()
# Bare expression so the notebook displays which device was chosen.
device

device(type='cuda')

In [None]:
# Build the colorization generator on the selected device.
G = to_device(Generator(features=64), device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a
# GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
G.load_state_dict(torch.load('/content/drive/MyDrive/trained/G_1_2000000_gan.pth', map_location=device))

<All keys matched successfully>

In [None]:
# Colorize the Beatles photo; show it alone and side by side, and save both outputs.
predict_image(
    image_path='/content/The-Beatles-Bruce-McBroom-©-Apple-Corps-Ltd-696x442.jpg',
    model=G,
    image_save_path='/content/sample_data/beatles',
    save=True,
    plot=True,
    plot_side_by_side=True,
)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Colorize the ox-cart photo; show it alone and side by side, and save both outputs.
predict_image(
    image_path='/content/oxcart-trudge-1953--getty_1413458061.jpg',
    model=G,
    image_save_path='/content/sample_data/howrah',
    save=True,
    plot=True,
    plot_side_by_side=True,
)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Colorize the bus-stand photo; show it alone and side by side, and save both outputs.
predict_image(
    image_path='/content/calcutta-bus-stand_1413447185_725x725.jpg',
    model=G,
    image_save_path='/content/sample_data/howrahbus',
    save=True,
    plot=True,
    plot_side_by_side=True,
)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Colorize the Chowringhee photo; show it alone and side by side, and save both outputs.
predict_image(
    image_path='/content/chowranghee2-ignouofkolkata_1413451627_725x725.jpg',
    model=G,
    image_save_path='/content/sample_data/chowringhee2',
    save=True,
    plot=True,
    plot_side_by_side=True,
)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Colorize the .webp sample; show it alone and side by side, and save both outputs.
predict_image(
    image_path='/content/1593664440_untitled-design-2020-07-02t100133.575.webp',
    model=G,
    image_save_path='/content/sample_data/lonkolbus',
    save=True,
    plot=True,
    plot_side_by_side=True,
)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Colorize the street-scene photo; show it alone and side by side, and save both outputs.
predict_image(
    image_path='/content/rsz_sttreet_scene_1413447327_725x725.jpg',
    model=G,
    image_save_path='/content/sample_data/kolkatastreet',
    save=True,
    plot=True,
    plot_side_by_side=True,
)

Output hidden; open in https://colab.research.google.com to view.