In [None]:
import os
import cv2
import math
import torch
import shutil 
import random
import numpy as np 
import pandas as pd 
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from torch import cuda
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from skimage.color import rgb2lab, lab2rgb
import torchvision.models as models

In [None]:
# Colab dependencies: fastai is pinned to 2.4; pillow and scikit-video (provides skvideo,
# used by make_video) are unpinned.
# NOTE(review): if installing fastai changes the installed torch version, restart the
# runtime — torch was already imported in the cell above.
!pip install fastai==2.4
!pip install pillow scikit-video

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
device

device(type='cuda')

In [None]:
import skvideo.io
import numpy as np
from PIL import Image
def make_video(read_path, video_save_path, count, fps):
  """Encode numbered frames ``0.png`` .. ``<count-1>.png`` from ``read_path`` into a video.

  Args:
    read_path: directory holding the numbered PNG frames (with or without a trailing slash).
    video_save_path: path of the video file to write.
    count: number of frames to read, starting at ``0.png``.
    fps: frame rate used for both the input and the output stream.
  """
  # NOTE(review): libx264 with yuv444p yields a High 4:4:4 stream that some players cannot
  # decode; switch to 'yuv420p' if the video must play everywhere.
  writer = skvideo.io.FFmpegWriter(video_save_path,
                        inputdict={'-r': str(fps)},
                        outputdict={'-r': str(fps), '-c:v': 'libx264', '-preset': 'ultrafast', '-pix_fmt': 'yuv444p'})
  try:
    for i in range(count):
      # os.path.join works whether or not read_path ends with a separator.
      image_name = os.path.join(read_path, "%d.png" % i)
      # Context manager closes each frame file immediately instead of leaving it to the GC.
      with Image.open(image_name) as image:
        # Force 3-channel RGB so every frame has the same shape in the stream.
        frame = np.array(image.convert("RGB"), dtype=np.uint8)
      writer.writeFrame(frame)
  finally:
    # Always finalise the ffmpeg process, even if a frame is missing or unreadable.
    writer.close()

In [None]:
from google.colab import drive
import shutil
def save_model(model):
  """Copy ``/content/<model>`` into the thesis model folder on Google Drive.

  Drive is mounted for the copy and always flushed/unmounted afterwards, even if
  the copy raises.

  Args:
    model: checkpoint file name inside ``/content/``.
  """
  dest_dir = "/content/drive/MyDrive/cv thesis/model/trained_models_colorization_net"
  drive.mount('/content/drive')
  try:
    # Ensure the folder exists and copy *into* it: shutil.copy to a missing directory
    # path would silently create a single file named like the folder.
    os.makedirs(dest_dir, exist_ok=True)
    shutil.copy("/content/" + model, os.path.join(dest_dir, model))
    print("Model Saved")
  finally:
    drive.flush_and_unmount()

def load_model(model):
  """Copy a checkpoint from the thesis model folder on Google Drive into ``/content/``.

  Drive is always flushed/unmounted afterwards, even if the copy raises
  (e.g. because the checkpoint name does not exist on Drive).

  Args:
    model: checkpoint file name inside the Drive model folder.
  """
  drive.mount('/content/drive')
  try:
    shutil.copy("/content/drive/MyDrive/cv thesis/model/trained_models_colorization_net/" + model, '/content/')
    print("Model Loaded")
  finally:
    drive.flush_and_unmount()

def save_video(vid):
  """Copy the video ``/content/<vid>`` into the ``cv thesis`` folder on Google Drive.

  Drive is always flushed/unmounted afterwards, even if the copy raises.

  Args:
    vid: video file name inside ``/content/``.
  """
  drive.mount('/content/drive')
  try:
    shutil.copy("/content/" + vid, "/content/drive/MyDrive/cv thesis/")
    print("Video Saved")
  finally:
    drive.flush_and_unmount()

<h3>Colorization Network</h3>

In [None]:
!pip install --upgrade --no-cache-dir gdown
!gdown --id 1bxoWFitjFk_eX9laOZhMQE_tjpLMOrDO #charlie

In [None]:
!unzip Charlie.zip

In [None]:
!gdown --id 1V68xX9CyfPPhgqAzwZTmnfaJhXK7DMdd #network.py
!gdown --id 1Mgvfb9BfIjmRH3G9ju0NHWhKOpvbF5mS #loss.py
!gdown --id 1Ta1BoQRk9GP0BoME86wro8kk3o6f86rb #dataset.py

In [None]:
from dataset import *
from network import *
from loss import *

In [None]:
# Copy both ColorNet checkpoints (evaluated below after 90 and 150 epochs) from Drive into /content.
for _checkpoint_name in ("netG_90.pt", "netG_150.pt"):
    load_model(_checkpoint_name)

Mounted at /content/drive
Model Loaded
Mounted at /content/drive
Model Loaded


In [None]:
# An empty Compose is an identity transform; the second transform converts to tensors.
# NOTE(review): the roles of the two transform arguments are defined by ImageDataset in
# dataset.py — confirm.
size_transform = transforms.Compose([])
transform = transforms.Compose([transforms.ToTensor()])

frames = ImageDataset("/content/Charlie-1", "/content/annotation.csv", size_transform, transform)
frames_loader = DataLoader(
    dataset=frames,
    batch_size=1,
    num_workers=0,
    shuffle=False,  # keep frame order so the output video is in sequence
    pin_memory=True,
    drop_last=False,
)
print(f"{len(frames_loader)} {len(frames_loader.dataset)}")

2064 2064


<h3>After 90 epochs</h3>

In [None]:
# Build ColorNet and load the epoch-90 generator weights copied into /content by load_model.
# NOTE(review): the meaning of the 'None' constructor argument is defined in network.py — confirm.
generator_model = 'netG_90.pt'
checkpoint_path = "/content/" + generator_model
net_G = ColorNet('None')
net_G.to(device)
print('Loaded model onto %s.' % device)
# Fail loudly rather than silently evaluating randomly initialised weights.
if not os.path.exists(checkpoint_path):
  raise FileNotFoundError("Checkpoint not found: " + checkpoint_path)
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime as well.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
net_G.load_state_dict(checkpoint)
print("Pretrained Model loaded")

Loaded colorization net.
Loaded model onto GPU.
Pretrained Model loaded


In [None]:
# Remove frames left over from a previous run, then count what remains (expected: 0).
# On a fresh runtime ls reports "cannot access" because the folder does not exist yet.
!rm -rf /content/colored_frames_epoch_90/*
!ls /content/colored_frames_epoch_90/* | wc -l

ls: cannot access '/content/colored_frames_epoch_90/*': No such file or directory
0


In [None]:
!mkdir colored_frames_epoch_90

In [None]:
# Inference setup: output folder for numbered frames, frame counter, and eval mode.
save_path = '/content/colored_frames_epoch_90/'
count = 0
net_G.train(False)  # identical to net_G.eval()

In [None]:
# Colorize every frame with the epoch-90 ColorNet, save it as <count>.png for make_video,
# and display it inline. The network predicts the ab channels; the L input is concatenated
# back on and both are mapped to Lab units before conversion to RGB.
# NOTE(review): the de-normalisation (L * 100, ab * 255 - 128) assumes dataset.py scales
# L to [0, 1] and ab as (ab + 128) / 255 — confirm against dataset.py.
with torch.no_grad():  # inference only: do not build an autograd graph
    for gray, _ in tqdm(frames_loader):
        L = gray.to(device=device, dtype=torch.float32)
        output = net_G(L)
        fake = torch.cat([L, output], dim=1).cpu().numpy()
        for i in range(fake.shape[0]):
            color_image = fake[i].transpose((1, 2, 0))
            color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
            color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
            color_image = lab2rgb(color_image.astype(np.float64)) * 255.0
            im = Image.fromarray(color_image.astype(np.uint8))
            im.save(save_path + "%d.png" % count)
            count += 1
            plt.axis(False)
            plt.imshow(im)
            plt.show()

In [None]:
# Encode the saved epoch-90 frames into an MP4 at 25 fps.
fps = 25
read_path = '/content/colored_frames_epoch_90/'
video_save_path = "/content/colorization_net_charlie_epoch_90.mp4"
make_video(read_path=read_path, video_save_path=video_save_path, count=count, fps=fps)

In [None]:
# Copy the epoch-90 video to Google Drive.
save_video(vid="colorization_net_charlie_epoch_90.mp4")

Mounted at /content/drive
Video Saved


In [None]:
# Display the epoch-90 colorization of every frame without saving it.
# NOTE(review): the de-normalisation (L * 100, ab * 255 - 128) assumes dataset.py scales
# L to [0, 1] and ab as (ab + 128) / 255 — confirm against dataset.py.
with torch.no_grad():  # inference only: do not build an autograd graph
    for gray, _ in tqdm(frames_loader):
        L = gray.to(device=device, dtype=torch.float32)
        output = net_G(L)
        fake = torch.cat([L, output], dim=1).cpu().numpy()
        for i in range(fake.shape[0]):
            color_image = fake[i].transpose((1, 2, 0))
            color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
            color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
            color_image = lab2rgb(color_image.astype(np.float64)) * 255.0
            im = Image.fromarray(color_image.astype(np.uint8))
            plt.axis(False)
            plt.imshow(im)
            plt.show()

<h3>After 150 epochs</h3>

In [None]:
# Build ColorNet and load the epoch-150 generator weights copied into /content by load_model.
# NOTE(review): the meaning of the 'None' constructor argument is defined in network.py — confirm.
generator_model = 'netG_150.pt'
checkpoint_path = "/content/" + generator_model
net_G = ColorNet('None')
net_G.to(device)
print('Loaded model onto %s.' % device)
# Fail loudly rather than silently evaluating randomly initialised weights.
if not os.path.exists(checkpoint_path):
  raise FileNotFoundError("Checkpoint not found: " + checkpoint_path)
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime as well.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
net_G.load_state_dict(checkpoint)
print("Pretrained Model loaded")

Loaded colorization net.
Loaded model onto GPU.
Pretrained Model loaded


In [None]:
net_G.train(False)  # same as net_G.eval(): switch to inference mode and show the model

ColorNet(
  (midlevel_resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [None]:
# Display the epoch-150 colorization of every frame without saving it.
# NOTE(review): the de-normalisation (L * 100, ab * 255 - 128) assumes dataset.py scales
# L to [0, 1] and ab as (ab + 128) / 255 — confirm against dataset.py.
with torch.no_grad():  # inference only: do not build an autograd graph
    for gray, _ in tqdm(frames_loader):
        L = gray.to(device=device, dtype=torch.float32)
        output = net_G(L)  # predicted ab channels
        fake = torch.cat([L, output], dim=1).cpu().numpy()
        for i in range(fake.shape[0]):
            color_image = fake[i].transpose((1, 2, 0))
            color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
            color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
            color_image = lab2rgb(color_image.astype(np.float64)) * 255.0
            im = Image.fromarray(color_image.astype(np.uint8))
            plt.axis(False)
            plt.imshow(im)
            plt.show()

<h1>Colorization with UNet</h1>

In [None]:
# !pip install --upgrade --no-cache-dir gdown
!gdown --id 1l_KBj-pF67Zhhh7s1hNsEKP29Piz9dPD  #color_netwrok
# !gdown --id 1bxoWFitjFk_eX9laOZhMQE_tjpLMOrDO #charlie
# !unzip Charlie.zip

Downloading...
From: https://drive.google.com/uc?id=1l_KBj-pF67Zhhh7s1hNsEKP29Piz9dPD
To: /content/color_network.py
100% 1.66k/1.66k [00:00<00:00, 2.90MB/s]


In [None]:
# !gdown --id 1r_2E4DNVnO5puqT3YmzM64rFOCtB69ZT  #final
# !gdown --id 1i_J_XOI8tGcavvB9xcsKK0_Z85oVg-VM #resnet
!gdown --id 1KL1lpiYLWkn5WGKNunJeL4bpBuTQN2JP #net_G final


Downloading...
From: https://drive.google.com/uc?id=1KL1lpiYLWkn5WGKNunJeL4bpBuTQN2JP
To: /content/net_G_unet_final.pt
100% 125M/125M [00:01<00:00, 105MB/s]


In [None]:
from color_network import *

In [None]:
import os
import numpy as np 
import pandas as pd 
from PIL import Image
from torch.utils.data import Dataset
from skimage.color import rgb2lab, rgb2gray


class ImageDataset(Dataset):
    """Frame dataset returning normalised Lab channels for colorization.

    File names are read from the ``image`` column of a CSV file and resolved
    against ``root``.

    Args:
        root: directory containing the image files.
        captions_file: CSV file with an ``image`` column of file names relative to ``root``.
        color_transform: optional transform applied to the RGB PIL image before Lab conversion.
        transform: optional transform applied separately to the L tensor and to the ab tensor.
    """

    def __init__(self, root, captions_file, color_transform = None, transform = None):
        self.df = pd.read_csv(captions_file, index_col=None)
        self.transform = transform
        self.color_transform = color_transform
        self.images = self.df["image"]
        self.root = root

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(L, ab)`` for the image at position ``index``.

        L (Lab lightness) is mapped to [-1, 1]; ab is divided by 110, which keeps it
        roughly within [-1, 1].
        """
        # Positional lookup, independent of whatever index labels the CSV produced.
        path = os.path.join(self.root, self.images.iloc[index])
        # Context manager closes the file handle deterministically.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if self.color_transform:
            img = self.color_transform(img)
        img = np.array(img)

        img_lab = rgb2lab(img).astype("float32")
        # HWC ndarray -> CHW tensor; float input is not rescaled by ToTensor.
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.  # Between -1 and 1
        ab = img_lab[[1, 2], ...] / 110.  # Roughly between -1 and 1

        # NOTE(review): L and ab go through separate calls, so a *random* transform
        # (e.g. a random flip or crop) would not be applied identically to both.
        if self.transform:
            L = self.transform(L)
            ab = self.transform(ab)

        return L, ab
 

In [None]:
# Build the UNet generator and load its trained weights. map_location only controls where
# the loaded tensors live, so the model itself is moved explicitly to match the inputs.
net_G = Generator(n_input=1, n_output=2, size=256).to(device)
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
net_G.load_state_dict(torch.load("net_G_unet_final.pt", map_location=device))

<All keys matched successfully>

In [None]:
# Frames for the UNet, using the inline ImageDataset (L in [-1, 1], ab / 110), in order.
frames = ImageDataset(root="/content/Charlie-1", captions_file="/content/annotation.csv")
frames_loader = DataLoader(frames, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
print(f"{len(frames_loader)} {len(frames_loader.dataset)}")

2064 2064


In [None]:
# Remove frames left over from a previous run, then count what remains (expected: 0).
# On a fresh runtime ls reports "cannot access" because the folder does not exist yet.
!rm -rf /content/colored_frames_unet/*
!ls /content/colored_frames_unet/* | wc -l

ls: cannot access '/content/colored_frames_unet/*': No such file or directory
0


In [None]:
!mkdir colored_frames_unet

In [None]:
net_G.eval()  # inference mode (BatchNorm uses its running statistics)
count = 0     # frame counter: file names and the number of frames passed to make_video
# Must be set here; otherwise save_path still points at the epoch-90 folder from earlier.
save_path = '/content/colored_frames_unet/'

In [None]:
# Colorize every frame with the UNet, save it as <count>.png for make_video, and display it.
# The dataset mapped L to [-1, 1] and divided ab by 110, so both are mapped back to Lab
# units before conversion to RGB.
# Set here so the frames never land in the epoch-90 folder save_path pointed to earlier;
# must match read_path in the make_video cell.
save_path = '/content/colored_frames_unet/'
with torch.no_grad():  # inference only: do not build an autograd graph
    for gray, _ in tqdm(frames_loader):
        L = gray.to(device=device, dtype=torch.float32)
        out_ab = net_G(L)
        fake = torch.cat([L, out_ab], dim=1).cpu().numpy()
        for i in range(fake.shape[0]):
            color_image = fake[i].transpose((1, 2, 0))
            color_image[:, :, 0:1] = (color_image[:, :, 0:1] + 1) * 50
            color_image[:, :, 1:3] = color_image[:, :, 1:3] * 110
            color_image = lab2rgb(color_image.astype(np.float64)) * 255.0
            im = Image.fromarray(color_image.astype(np.uint8))
            # Without this save, make_video finds no frames on a fresh run.
            im.save(save_path + "%d.png" % count)
            count += 1
            plt.axis(False)
            plt.imshow(im)
            plt.show()

In [None]:
# Encode the saved UNet frames into an MP4 at 25 fps.
fps = 25
read_path = '/content/colored_frames_unet/'
video_save_path = "/content/unet_charlie_colored.mp4"
make_video(read_path=read_path, video_save_path=video_save_path, count=count, fps=fps)

In [None]:
# Copy the UNet video to Google Drive.
save_video(vid="unet_charlie_colored.mp4")

Mounted at /content/drive
Video Saved


In [None]:
# Display the grayscale input frames (the normalised L channel) for comparison with the
# colorized output. imshow autoscales, so the [-1, 1] range needs no rescaling.
for gray, _ in tqdm(frames_loader):
    for gray_image in gray[:, 0]:  # (H, W) L channel of each frame in the batch
        plt.axis(False)
        plt.imshow(gray_image, "gray")
        plt.show()

<h1>Colorization network with Wasserstein loss</h1>

In [None]:
!pip install --upgrade --no-cache-dir gdown
!gdown --id 1oRV2Y0DZYDNBfTMkcEB6V_GL82u-YzFX #dataset
!gdown --id 1ACRzAIg_v64iZ6fwOO7CQmt4D0vxLczV #utils
!gdown --id 1TeY1q_vvZia4ryIKLtq9aIkG9vGjSEb- #patch_discriminator
!gdown --id 1FsLpLN612SZATZQhnq2h90V5HcKmF5SI #generator

from utils import *
from dataset import *
from discriminator import *
from encoder import *

Downloading...
From: https://drive.google.com/uc?id=1oRV2Y0DZYDNBfTMkcEB6V_GL82u-YzFX
To: /content/dataset.py
100% 818/818 [00:00<00:00, 1.66MB/s]
Downloading...
From: https://drive.google.com/uc?id=1ACRzAIg_v64iZ6fwOO7CQmt4D0vxLczV
To: /content/utils.py
100% 712/712 [00:00<00:00, 1.14MB/s]
Downloading...
From: https://drive.google.com/uc?id=1TeY1q_vvZia4ryIKLtq9aIkG9vGjSEb-
To: /content/discriminator.py
100% 1.44k/1.44k [00:00<00:00, 2.50MB/s]
Downloading...
From: https://drive.google.com/uc?id=1FsLpLN612SZATZQhnq2h90V5HcKmF5SI
To: /content/encoder.py
100% 563/563 [00:00<00:00, 1.06MB/s]


In [None]:
import importlib

import dataset as dataset_module

# `from dataset import *` reuses the module already cached in sys.modules by the ColorNet
# section, so the dataset.py downloaded for this section would otherwise be ignored.
dataset_module = importlib.reload(dataset_module)
ImageDataset = dataset_module.ImageDataset

frames = ImageDataset("/content/Charlie-1", "/content/annotation.csv")
frames_loader = DataLoader(dataset=frames, batch_size=1, num_workers=0, shuffle=False, pin_memory=True, drop_last=False)
print(len(frames_loader), len(frames_loader.dataset))

2064 2064


In [None]:
# Copy the Wasserstein-loss generator checkpoint from Drive and load it.
load_model("net_G-final.pth")
# map_location only controls where the loaded tensors live, so move the model explicitly
# to the device the inputs are sent to in the loop below.
net_G = Generator(n_input=1, n_output=2, size=256).to(device)
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
net_G.load_state_dict(torch.load("./net_G-final.pth", map_location=device))

Mounted at /content/drive
Model Loaded


<All keys matched successfully>

In [None]:
net_G.train(False)  # same as net_G.eval(): switch to inference mode and show the model

DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05

In [None]:
# Display the Wasserstein-loss generator's colorization of every frame (not saved).
# NOTE(review): assumes this section's dataset.py uses the same normalisation as the UNet
# section (L in [-1, 1], ab / 110) — confirm against dataset.py.
with torch.no_grad():  # inference only: do not build an autograd graph
    for gray, _ in tqdm(frames_loader):
        L = gray.to(device=device, dtype=torch.float32)
        out_ab = net_G(L)
        fake = torch.cat([L, out_ab], dim=1).cpu().numpy()
        for i in range(fake.shape[0]):
            color_image = fake[i].transpose((1, 2, 0))
            color_image[:, :, 0:1] = (color_image[:, :, 0:1] + 1) * 50
            color_image[:, :, 1:3] = color_image[:, :, 1:3] * 110
            color_image = lab2rgb(color_image.astype(np.float64)) * 255.0
            im = Image.fromarray(color_image.astype(np.uint8))
            plt.axis(False)
            plt.imshow(im)
            plt.show()