In [None]:
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
from google.colab import drive

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional
import torchvision.datasets as datasets
import torch.optim as optim
from torch import Tensor, flatten
import torchvision.transforms as transforms
import torchvision.models as models
from skimage.color import rgb2lab, lab2rgb
from torchvision.io import read_image

In [None]:
import numpy as np
import pandas as pd 
import glob
import os
import ntpath
from PIL import Image
import nltk
import io

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Mount Google Drive at an absolute path: later cells load/save checkpoints under
# /content/my-drive/MyDrive/..., so the mountpoint must not depend on the cwd.
drive.mount("/content/my-drive")

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip 

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip

In [None]:
!unzip /content/Flickr8k_text.zip  -d /media/

In [None]:
!unzip /content/Flickr8k_Dataset.zip  -d /media/

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class Flickr8kDataset(Dataset):
    """Flickr8k images for one split (images only, no captions).

    Args:
        root_dir: directory holding the extracted split lists
            (``Flickr_8k.<split>Images.txt``, one image file name per line) and
            the ``Flicker8k_Dataset`` image folder.
        dist: split to load: ``"train"`` (default), ``"dev"`` (alias ``"val"``)
            or ``"test"``.
        transformations: optional callable applied to each RGB PIL image.

    Raises:
        ValueError: if ``dist`` is not one of the known splits.
    """

    # Split name -> list file shipped in Flickr8k_text.zip.
    _SPLIT_FILES = {
        "train": "Flickr_8k.trainImages.txt",
        "dev": "Flickr_8k.devImages.txt",
        "val": "Flickr_8k.devImages.txt",
        "test": "Flickr_8k.testImages.txt",
    }

    def __init__(self, root_dir, dist="train", transformations=None):
        if dist not in self._SPLIT_FILES:
            raise ValueError(
                f"unknown split {dist!r}; expected one of {sorted(self._SPLIT_FILES)}"
            )
        self.root_dir = root_dir
        self.dist = dist
        self.transformations = transformations

        # Single unnamed column (0) of image file names.
        self.imgs = pd.read_csv(os.path.join(root_dir, self._SPLIT_FILES[dist]), header=None)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs.iloc[idx, 0]
        # "Flicker8k_Dataset" (sic) is the folder name inside Flickr8k_Dataset.zip.
        img_location = os.path.join(self.root_dir, "Flicker8k_Dataset", img_name)
        # Close the file handle deterministically; convert() returns a new image.
        with Image.open(img_location) as raw:
            img = raw.convert("RGB")

        if self.transformations is not None:
            img = self.transformations(img)

        return img

In [None]:
class rgb2lab_t(object):
  """Transform: convert a (3, H, W) RGB tensor into a (3, H, W) CIE-Lab tensor.

  Used right after ``transforms.ToTensor()``; channels are moved to (H, W, 3)
  for ``skimage.color.rgb2lab`` and back to channel-first afterwards.
  """

  def __init__(self):
    pass

  def __call__(self, tensor):
    # skimage operates on (H, W, C) NumPy arrays.
    lab = rgb2lab(tensor.detach().cpu().permute(1, 2, 0).numpy())
    # Pin float32 so batches match the model's default parameter dtype
    # (NOTE(review): some skimage versions return float64 here).
    return torch.from_numpy(lab).permute(2, 0, 1).to(torch.float32)

  def __repr__(self):
    # The transform has no parameters; the previous repr referenced
    # non-existent self.mean / self.std and raised AttributeError.
    return self.__class__.__name__ + '()'

In [None]:
# Resize -> [0, 1] RGB tensor -> Lab -> per-channel Normalize(mean=0.5, std=0.5).
# NOTE(review): Normalize(0.5, 0.5) is the usual [0, 1] -> [-1, 1] mapping for RGB;
# on Lab (L in [0, 100], a/b roughly [-128, 127]) it just computes 2*v - 1, so
# anything converting back with lab2rgb must first undo it (v = 0.5*v' + 0.5).
train_transformations = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    rgb2lab_t(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 32

train_set = Flickr8kDataset(root_dir="/media", dist='train', transformations=train_transformations)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)


In [None]:
class ColorizationNet(nn.Module):
  """Predict the two Lab chrominance channels (a, b) from the 1-channel L input.

  Every "stem" is the first six children of a torchvision ResNet-18
  (conv1, bn1, relu, maxpool, layer1, layer2), giving MIDLEVEL_FEATURE_SIZE=128
  channels at reduced resolution; ``self.upsample`` upsamples by 2 three times
  and ends in a 2-channel conv.

  Forward pass:
    1. three 1-channel stems run on the input and their features are summed;
    2. ``self.upsample`` decodes the sum to a 2-channel map;
    3. a fourth stem (2-channel conv1) re-encodes that map, its features are
       added to the sum from step 1 and decoded again by the *same*
       ``self.upsample`` module (decoder weights are shared by both passes).

  Args:
    input_size: unused; kept so existing calls keep working.
    multiscale_kernels: if True, rebuild the residual-block convs of the second
      and third stems with 5x5 and 7x7 kernels. Merely reassigning
      ``conv.kernel_size`` (as this class used to) leaves the 3x3 weight tensors
      untouched, so all stems were effectively 3x3. The default False keeps
      that 3x3 architecture so previously saved state dicts still load.
  """

  def __init__(self, input_size=128, multiscale_kernels=False):
    super(ColorizationNet, self).__init__()

    MIDLEVEL_FEATURE_SIZE = 128  # output channels of ResNet-18 layer2

    resnet = models.resnet18(num_classes=365)
    resnet2 = models.resnet18(num_classes=365)
    resnet3 = models.resnet18(num_classes=365)
    resnet4 = models.resnet18(num_classes=365)

    # Make the gray stems accept single-channel input by summing conv1 over its
    # input channels. Each stem collapses its OWN weights; previously resnet2 and
    # resnet3 were (mistakenly) derived from resnet's already-collapsed weights.
    for net in (resnet, resnet2, resnet3):
      net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1, keepdim=True))
      net.conv1.in_channels = 1

    # The fourth stem consumes the 2-channel map produced by the decoder.
    resnet4.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)

    if multiscale_kernels:
      self._rebuild_block_kernels(resnet2, 5)
      self._rebuild_block_kernels(resnet3, 7)

    # Keep conv1 .. layer2 (children 0..5) of each ResNet.
    self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])
    self.midlevel_resnet2 = nn.Sequential(*list(resnet2.children())[0:6])
    self.midlevel_resnet3 = nn.Sequential(*list(resnet3.children())[0:6])
    self.midlevel_resnet4 = nn.Sequential(*list(resnet4.children())[0:6])

    self.upsample = nn.Sequential(
      nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
      nn.BatchNorm2d(128),
      nn.ReLU(),
      nn.Upsample(scale_factor=2),
      nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.Upsample(scale_factor=2),
      nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      nn.Upsample(scale_factor=2),
      nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
    )

  @staticmethod
  def _rebuild_block_kernels(net, kernel_size):
    """Replace conv1/conv2 of every residual block in layer1..layer4 with new
    ``kernel_size`` x ``kernel_size`` convs (same channels, stride and bias
    setting; padding ``kernel_size // 2`` so output sizes match the original
    3x3/padding-1 convs for odd kernel sizes)."""
    for layer in (net.layer1, net.layer2, net.layer3, net.layer4):
      for block in layer:
        for name in ("conv1", "conv2"):
          old = getattr(block, name)
          setattr(block, name, nn.Conv2d(
            old.in_channels, old.out_channels, kernel_size=kernel_size,
            stride=old.stride, padding=kernel_size // 2,
            bias=old.bias is not None))

  def forward(self, input):
    # First pass: sum of the three gray stems, decoded to a 2-channel map.
    midlevel_features = self.midlevel_resnet(input) + self.midlevel_resnet2(input) + self.midlevel_resnet3(input)
    output = self.upsample(midlevel_features)

    # Second pass: re-encode the first prediction, fuse with the gray features
    # and decode again with the same (shared) upsampling head.
    midlevel_features2 = self.midlevel_resnet4(output)
    output = self.upsample(midlevel_features2 + midlevel_features)
    return output

In [None]:
def train(network, epochs, optimizer, criterion, train_loader, log_every=10, device=None):
    """Train ``network`` to predict channels 1..2 of each batch from channel 0.

    Each batch from ``train_loader`` is assumed to be a tensor of shape
    (B, 3, H, W) holding (normalized) L, a, b channels: ``[:, 0:1]`` is the
    input and ``[:, 1:3]`` the regression target.

    Args:
        network: model mapping (B, 1, H, W) -> (B, 2, H, W).
        epochs: number of passes over ``train_loader``.
        optimizer: optimizer over ``network``'s parameters.
        criterion: loss ``criterion(prediction, target)``.
        train_loader: iterable of batches (e.g. a DataLoader).
        log_every: print the mean loss of the last ``log_every`` batches this
            often; 0 disables the per-batch log.
        device: where batches are moved; defaults to the device of the
            network's first parameter (CPU if it has none).

    Returns:
        List with the mean training loss of each epoch.
    """
    if device is None:
        first_param = next(network.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # BatchNorm/Dropout must be in training mode (a prior eval() may have switched it off).
    network.train()
    history = []

    for epoch in range(epochs):
        running_loss = 0.0
        epoch_loss = 0.0
        num_batches = 0

        for i, data in enumerate(train_loader):
            inputs = data[:, 0:1, :, :].to(device)   # channel 0 (L)
            targets = data[:, 1:3, :, :].to(device)  # channels 1..2 (a, b)

            optimizer.zero_grad()
            colorized = network(inputs)
            loss = criterion(colorized, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            num_batches += 1

            if log_every and (i + 1) % log_every == 0:
                # Average over the batches actually accumulated since the last print.
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.6f}')
                running_loss = 0.0

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss /= max(num_batches, 1)
        history.append(epoch_loss)
        # 1-based epoch number, consistent with the per-batch log above.
        print('Train Epoch: {} \tTrain Loss: {:.6f}'.format(epoch + 1, epoch_loss))

    return history



In [None]:
# Build the colorization network and move its parameters to the selected device.
model = ColorizationNet()
model = model.to(device)

In [None]:
# Resume from a saved checkpoint. map_location lets a checkpoint saved on a GPU
# load on a CPU-only runtime (and vice versa).
# NOTE(review): the optimizer is created in a later cell, so the checkpoint's
# 'optimizer_state_dict' is not restored here.
CHECKPOINT_PATH = '/content/my-drive/MyDrive/ImageColorization/models/model250.pth'
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Adam over all model parameters; mean-squared error on the predicted a/b channels.
# NOTE(review): the optimizer starts fresh here, so optimizer state saved in a
# checkpoint is not restored when resuming.
lr = 2e-2
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# One initial block of 10 epochs on the training split.
model.to(device)
train(model, epochs=10, optimizer=optimizer, criterion=criterion, train_loader=train_loader)

In [None]:
# Train in chunks of EPOCHS_PER_CHECKPOINT epochs, saving a checkpoint after each.
EPOCHS_PER_CHECKPOINT = 10
NUM_CHECKPOINTS = 26
CHECKPOINT_DIR = "/content/my-drive/MyDrive/ImageColorization/models"
# Epochs the model has already been trained for (e.g. the 'epoch' of a loaded
# checkpoint plus any epochs run since). TODO: set before running so checkpoint
# names and 'epoch' values are cumulative.
start_epoch = 0

checkpoint_paths = [
    os.path.join(CHECKPOINT_DIR, "model" + str(start_epoch + (k + 1) * EPOCHS_PER_CHECKPOINT) + ".pth")
    for k in range(NUM_CHECKPOINTS)
]
# Refuse to silently overwrite earlier checkpoints (including one that was just loaded).
existing = [p for p in checkpoint_paths if os.path.exists(p)]
if existing:
    raise FileExistsError(
        f"{len(existing)} checkpoint(s) would be overwritten, e.g. {existing[0]}; "
        "adjust start_epoch or CHECKPOINT_DIR"
    )

model.to(device)
for k, path in enumerate(checkpoint_paths):
    epoch_losses = train(model, EPOCHS_PER_CHECKPOINT, optimizer, criterion, train_loader)
    # Record the last epoch's loss when train() returns per-epoch losses,
    # instead of a hard-coded number.
    last_loss = epoch_losses[-1] if epoch_losses else None
    torch.save({
        'epoch': start_epoch + (k + 1) * EPOCHS_PER_CHECKPOINT,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': last_loss,
    }, path)

In [None]:
# Grab one (shuffled) batch of normalized Lab images, shape (batch, 3, 256, 256).
data = next(iter(train_loader))


In [None]:
# Split the sample batch the same way train() does: L channel as input, a/b
# channels as target. (transforms.Grayscale would mix L, a and b with RGB luma
# weights, which is meaningless for Lab data.)
inputs = data[:, 0:1, :, :].to(device)
outputs = data[:, 1:3, :, :].to(device)

In [None]:
# Inference: eval() so BatchNorm uses (and does not update) its running
# statistics; no_grad() to skip building the autograd graph. Restore the
# previous mode afterwards so later training is unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    res1 = model(data[:, 0:1, :, :].to(device))
model.train(was_training)

In [None]:
# Index of the sample within the batch `data` to visualize (must be smaller than the batch size).
x = 18

In [None]:
# Ground truth. `data` went through Normalize(mean=0.5, std=0.5), i.e. v' = 2v - 1;
# undo it (v = 0.5 * v' + 0.5, matching the transform cell) before lab2rgb.
target_lab = data[x] * 0.5 + 0.5
plt.imshow(lab2rgb(target_lab.permute(1, 2, 0).detach().cpu().numpy()))
plt.title("Original")
plt.axis("off");

In [None]:
# Prediction: the input L channel plus the predicted a/b channels. Both live in
# the Normalize(mean=0.5, std=0.5) space used for training, so undo it
# (v = 0.5 * v' + 0.5) before converting back to RGB.
predicted_lab = torch.cat((data[x][0:1, :, :], res1[x].detach().cpu()), dim=0) * 0.5 + 0.5
plt.imshow(lab2rgb(predicted_lab.permute(1, 2, 0).numpy()))
plt.title("Colorized")
plt.axis("off");