<a href="https://colab.research.google.com/github/Gajanan-Sapsod/Image_Colorization/blob/main/Image_colorization_Using_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
from skimage.color import lab2rgb, rgb2lab, rgb2gray

In [None]:
import torchvision.models as models

In [None]:
import random
import numpy as np

In [None]:
# Root directory for the dataset files; the Flowers102 "train" split is
# fetched into it when it is not already present.
data_dir = "flower"
train_dataset_color = torchvision.datasets.Flowers102(
    root=data_dir,
    split="train",
    download=True,
)

In [None]:
class GrayscaleImageFolder(torchvision.datasets.ImageFolder):
  '''ImageFolder variant that yields (grayscale, ab channels, label) per image.

  Each item is loaded with the folder's loader, optionally transformed, and
  then split into:
    * a 1 x H x W float tensor holding the rgb2gray version of the image, and
    * a 2 x H x W float tensor holding the Lab a/b channels encoded as
      (value + 128) / 255 (to_rgb in this file applies the inverse).

  The transform, when given, must return something np.asarray can turn into
  an H x W x 3 RGB array (e.g. a PIL image), not a tensor.
  '''
  def __getitem__(self, index):
    path, target = self.imgs[index]
    img = self.loader(path)
    # The colour-space conversion used to live inside the `transform is not
    # None` branch, so a dataset built without a transform failed with
    # UnboundLocalError at the return statement. Only the transform itself
    # is optional; the conversion always runs.
    if self.transform is not None:
      img = self.transform(img)
    img_rgb = np.asarray(img)

    # Shift/scale every Lab channel, then keep only a and b (channels 1-2).
    img_lab = (rgb2lab(img_rgb) + 128) / 255
    img_ab = img_lab[:, :, 1:3]
    # HWC -> CHW so the tensor matches PyTorch's channel-first layout.
    img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()

    # Single-channel network input with an explicit channel axis.
    img_gray = torch.from_numpy(rgb2gray(img_rgb)).unsqueeze(0).float()

    if self.target_transform is not None:
      target = self.target_transform(target)
    return img_gray, img_ab, target

In [None]:
def to_rgb(grayscale_input, ab_input):
  '''Display the RGB image rebuilt from a grayscale channel and ab channels.

  Args:
    grayscale_input: CPU tensor with one channel first (1 x H x W); its
      values are multiplied by 100 and used as the Lab L channel.
    ab_input: CPU tensor with two channels first (2 x H x W) holding a/b
      encoded as (value + 128) / 255; it must not require grad, because it
      is converted with .numpy().

  The image is only shown with matplotlib; nothing is saved or returned.
  '''
  plt.clf()  # clear whatever the current matplotlib figure holds
  # Stack L and ab along the channel axis; torch.cat allocates a new tensor,
  # so the in-place edits below never touch the caller's tensors.
  lab_image = torch.cat((grayscale_input, ab_input), 0).numpy()
  lab_image = lab_image.transpose((1, 2, 0))  # CHW -> HWC for skimage/matplotlib
  # Undo the normalisation: the first channel is scaled by 100, a/b back
  # through the inverse of (value + 128) / 255.
  lab_image[:, :, 0:1] = lab_image[:, :, 0:1] * 100
  lab_image[:, :, 1:3] = lab_image[:, :, 1:3] * 255 - 128
  color_image = lab2rgb(lab_image.astype(np.float64))
  # The grayscale image used to be drawn here too, but it was immediately
  # covered by the colour image on the same axes, so only the latter is drawn.
  plt.imshow(color_image)
  plt.show()

In [None]:
# Random crops/flips augment the training images only. Validation uses a
# deterministic resize + centre crop so its loss is comparable between runs.
train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])

# Two views of the same folder that differ only in the transform applied.
# Both scan the same directory, so an index refers to the same file in each.
train_imagefolder = GrayscaleImageFolder(data_dir, train_transforms)
val_imagefolder = GrayscaleImageFolder(data_dir, val_transforms)

# Hold out a fixed 10% of the images for validation. Previously val_loader
# wrapped the very same dataset as train_loader, so "validation" was measured
# on augmented training images. The permutation uses its own seeded generator
# so the split is reproducible and does not disturb the global RNG state.
val_fraction = 0.1
num_images = len(train_imagefolder)
num_val = max(1, int(val_fraction * num_images))
split_generator = torch.Generator().manual_seed(0)
shuffled_indices = torch.randperm(num_images, generator=split_generator).tolist()
val_indices = shuffled_indices[:num_val]
train_indices = shuffled_indices[num_val:]

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_imagefolder, train_indices),
    batch_size=128,
    shuffle=True,
)
# No shuffling for validation: the same images are inspected every time.
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(val_imagefolder, val_indices),
    batch_size=128,
    shuffle=False,
)

# Number of batches per loader.
print(len(train_loader))
print(len(val_loader))


64
64


In [None]:
class ColorizationNet(nn.Module):
  '''Predicts two ab colour channels from a single-channel grayscale image.

  The network is a truncated, pretrained ResNet-18 (its first six child
  modules) followed by a convolutional decoder containing three 2x
  upsampling stages.
  '''

  def __init__(self, input_size=128):
    # NOTE: input_size is accepted but never read by this constructor.
    super().__init__()
    midlevel_channels = 128  # channel count the decoder expects from the encoder

    ## Encoder: pretrained ResNet-18 fetched through torch.hub
    backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

    # Sum the first conv's kernels over the input-channel axis so the layer
    # accepts a one-channel (grayscale) input instead of RGB.
    gray_kernels = backbone.conv1.weight.sum(dim=1).unsqueeze(1)
    backbone.conv1.weight = nn.Parameter(gray_kernels)

    # Keep only the first six child modules as the feature extractor.
    encoder_layers = list(backbone.children())[:6]
    self.midlevel_resnet = nn.Sequential(*encoder_layers)

    ## Decoder: 3x3 convolutions interleaved with nearest-neighbour upsampling
    def conv3x3(in_channels, out_channels):
      # Same-padding 3x3 convolution used throughout the decoder.
      return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    # Layers are created in the same order as they are applied, and kept in a
    # single flat Sequential so parameter names stay upsample.<index>.*.
    decoder_layers = [
      conv3x3(midlevel_channels, 128),
      nn.BatchNorm2d(128),
      nn.ReLU(),
      nn.Upsample(scale_factor=2),
      conv3x3(128, 64),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      conv3x3(64, 64),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.Upsample(scale_factor=2),
      conv3x3(64, 32),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      conv3x3(32, 2),  # two output channels: a and b
      nn.Upsample(scale_factor=2),
    ]
    self.upsample = nn.Sequential(*decoder_layers)

  def forward(self, input):
    '''Run the encoder, then the decoder, and return the decoder output.'''
    return self.upsample(self.midlevel_resnet(input))

In [None]:
### Seed PyTorch's RNG before the model is built so its randomly
### initialised layers start from the same weights on every run
torch.manual_seed(0)

Net = ColorizationNet()

# Freeze the ResNet feature extractor: its parameters receive no gradients,
# so only the upsampling half of the network is trained.
for params in Net.midlevel_resnet.parameters():
  params.requires_grad_(False)


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


In [None]:
### Loss: mean squared error
loss_fn = torch.nn.MSELoss()

### Optimizer: Adam over every parameter of Net
# Alternative settings left by the author: lr = 0.0008, optionally with
# weight_decay=6e-05.
lr = 0.01

# NOTE(review): this rebinds the name `optim`, hiding the `torch.optim`
# module imported under that name at the top of the notebook. The name is
# kept because the training cell passes `optim` to train_epoch.
optim = torch.optim.Adam(Net.parameters(), lr=lr)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

# Move the model to the selected device
Net.to(device)

In [None]:
### Training function
def train_epoch(Net, device, train_loader, loss_fn, optimizer, *,
                preview_every=10, num_previews=5):
    """Train ``Net`` for one pass over ``train_loader`` and return the mean loss.

    Args:
        Net: model called on the grayscale batch; put into train mode here.
        device: device the input and target batches are moved to.
        train_loader: iterable of ``(input_gray, input_ab, target)`` batches;
            ``target`` (the class label) is ignored.
        loss_fn: callable ``loss_fn(prediction, input_ab)`` returning a scalar
            tensor.
        optimizer: optimizer stepped once per batch.
        preview_every: show previews on every batch whose index is a multiple
            of this value; ``0`` or ``None`` disables previews.
        num_previews: maximum number of images shown per preview; capped at
            the batch size.

    Returns:
        Mean of the per-batch losses, or ``nan`` when the loader yields no
        batches.
    """
    Net.train()
    batch_losses = []

    # enumerate already supplies the batch index; no manual counter is needed.
    for batch_idx, (input_gray, input_ab, _target) in enumerate(train_loader):
        image_batch = input_gray.to(device)
        color_batch = input_ab.to(device)

        # Forward pass and loss against the reference ab channels
        predicted_ab = Net(image_batch)
        loss = loss_fn(predicted_ab, color_batch)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float (replaces the deprecated .data)
        batch_loss = loss.item()
        print('\t partial train loss (single batch): %f' % batch_loss)
        batch_losses.append(batch_loss)
        print(batch_idx + 1)

        if preview_every and batch_idx % preview_every == 0:
            # Cap at the batch size so a short batch cannot raise IndexError.
            for j in range(min(num_previews, input_gray.shape[0])):
                # squeeze(0) drops the channel axis without assuming 224x224.
                plt.imshow(input_gray[j].squeeze(0).cpu(), cmap='gray')
                plt.show()
                to_rgb(input_gray[j].cpu(), predicted_ab[j].detach().cpu())

    if not batch_losses:
        # np.mean([]) would also give nan, but with a RuntimeWarning.
        return float('nan')
    return np.mean(batch_losses)

In [None]:
# Train for a fixed number of epochs, recording each epoch's mean training
# loss. history['val_loss'] is created here but not filled by this loop.
num_epochs = 50
history = dict(train_loss=[], val_loss=[])
for epoch in range(num_epochs):
    train_loss = train_epoch(
        Net,
        device,
        train_loader,
        loss_fn,
        optim,
    )
    history['train_loss'].append(train_loss)
    print('\n EPOCH {}/{} \t train loss {:.3f} \t '.format(epoch + 1, num_epochs,train_loss))
  

In [None]:
# Losses collected by validate(), one entry per call.
val_loss = []
def validate(val_loader, Net, loss_fn):
  '''Evaluate Net on the first batch of val_loader and preview the results.

  Only the first batch is used. The loop used to keep iterating (and
  loading) every remaining batch without doing anything with it.

  Args:
    val_loader: iterable of (input_gray, input_ab, target) batches.
    Net: model to evaluate; switched to eval mode and left in it.
    loss_fn: callable loss_fn(prediction, input_ab) returning a scalar tensor.

  Returns:
    The loss tensor for the evaluated batch (also appended to the
    module-level val_loss list).

  Raises:
    ValueError: if val_loader yields no batches.

  Batches are moved to the module-level `device`.
  '''
  Net.eval()

  first_batch = next(iter(val_loader), None)
  if first_batch is None:
    # Previously this case surfaced as UnboundLocalError at `return loss`.
    raise ValueError('val_loader produced no batches')
  input_gray, input_ab, _target = first_batch

  # Gradients are never needed here, whether or not the caller disabled them.
  with torch.no_grad():
    image_batch = input_gray.to(device)
    color_batch = input_ab.to(device)
    decoded_data = Net(image_batch)
    loss = loss_fn(decoded_data, color_batch)
  val_loss.append(loss)

  # Show at most 5 images; fewer when the batch is smaller than that.
  for j in range(min(5, input_gray.shape[0])):
    # squeeze(0) drops the channel axis without assuming 224x224 images.
    plt.imshow(input_gray[j].squeeze(0).cpu(), cmap='gray')
    plt.show()
    to_rgb(input_gray[j].cpu(), decoded_data[j].detach().cpu())
  return loss

In [None]:
# Run the validation preview twice with autograd disabled. No training and
# no checkpoint saving happens in this cell; `losses` keeps the value
# returned by the last validate() call.
with torch.no_grad():
  for epoch in range(2):
    losses = validate(val_loader, Net, loss_fn)
 
    