Colorspace and architecture: https://arxiv.org/abs/2204.02850  
Fully Convolutional Networks: https://arxiv.org/pdf/1411.4038.pdf

Jetson Nano: 0.236 TFLOPS (fp32)

Image colorization video on YouTube: https://youtu.be/WXyeQeHUxpc?si=jQfcU8Ra4StxFOwT

Code accompanying the YouTube video: https://colab.research.google.com/drive/1BsqM7GBTtsyBixy2jsLGJNiSvp1ocpV7?usp=sharing


In [14]:
! pip install torchprofile 1>/dev/null
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

from torchprofile import profile_macs # Helps us to obtain mac calculations

import torch.nn.functional as F
import os
from torch.utils.data import Dataset
from PIL import Image

# Select the compute device once; later cells move models and batches to `device`.
print("Testing if GPU is available.")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and being used.")
else:
    device = torch.device("cpu")
    print("GPU is not available, Falling back to CPU.")

# Project root. Defaults to the original hard-coded location but can be
# overridden with the IMAGE_COLORIZATION_ROOT environment variable so the
# notebook also runs on other machines.
PROJECT_ROOT = os.environ.get("IMAGE_COLORIZATION_ROOT", "/home/aaron/git/ImageColorization/")
os.chdir(PROJECT_ROOT)
BASEADDR = os.getcwd()
print(f"Current working directory: {BASEADDR}")

# Mini-batch size handed to get_imagenet_data() when the loaders are built.
BATCH_SIZE = 32


Testing if GPU is available.
GPU is available and being used.
Current working directory: /home/aaron/git/ImageColorization


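# Download and unpack the ImageNet64 dataset
ImageNet64 is the ImageNet dataset with every image downscaled to 64 by 64 pixels. Note: the cells below currently read the full-resolution ILSVRC `CLS-LOC` folders and resize each image to 640×480 instead.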

In [15]:

# Directory that receives the checkpoints written by the training loop.
checkpoints = f"{BASEADDR}/checkpoints/VGG13based"
# Root of the ImageNet CLS-LOC image folders; get_imagenet_data() reads
# `{content}/train/` and `{content}/val/`. Defaults to the original location,
# overridable through the IMAGENET_ROOT environment variable.
content = os.environ.get("IMAGENET_ROOT", "/media/aaron/Storage/imagenet/ILSVRC/Data/CLS-LOC")
if not os.path.isdir(content):
    # FileNotFoundError is a subclass of Exception, so broader handlers still match.
    raise FileNotFoundError(
        f"ImageNet directory not found: {content!r}. "
        "Mount the storage drive or set IMAGENET_ROOT."
    )


In [16]:
class ColorizationDataset(Dataset):
    """Turn an ``(image, label)`` dataset into ``(grayscale, colour)`` pairs.

    Each item of the wrapped dataset is expected to be an ``(image, label)``
    tuple; the label is discarded. The image is assumed to be a channel-first
    tensor (e.g. produced by ``ToTensor``) because the grayscale copy is
    stacked with ``torch.cat`` -- TODO confirm for other wrapped datasets.
    """

    def __init__(self, dataset):
        """Store the wrapped dataset and build the grayscale transform once."""
        self.dataset = dataset
        # Built a single time here instead of on every __getitem__ call.
        self._to_gray = transforms.Grayscale()

    def __len__(self):
        """Number of samples, identical to the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(grayscale_img, color_img)`` for sample ``idx``.

        ``grayscale_img`` is the grayscale version of the colour image
        repeated three times along the first dimension, so the network input
        has the same number of channels as its colour target.
        """
        # Load color image; the class label is not needed for colorization.
        color_img, _ = self.dataset[idx]

        # Convert to grayscale, then replicate it to three channels.
        gray_single = self._to_gray(color_img)
        grayscale_img = torch.cat((gray_single, gray_single, gray_single))

        return grayscale_img, color_img


In [17]:
def get_imagenet_data(batchSize):
    """Build the training and test loaders of (grayscale, colour) image pairs.

    Images are read with ``ImageFolder`` from ``{content}/train/`` and
    ``{content}/val/``. ``ImageFolder`` needs one sub-directory per class under
    each root; the recorded run of this notebook failed on ``val/`` because no
    class folders were found there.

    Args:
        batchSize: Number of samples per batch for both loaders.

    Returns:
        dict with keys ``'train'`` (shuffled) and ``'test'`` (not shuffled),
        each a ``DataLoader`` over a ``ColorizationDataset``.
    """
    # Crop/resize every image to 640x480 (height x width as passed to torchvision).
    resize_crop = transforms.RandomResizedCrop((640, 480), scale=(1.0, 1.0), ratio=(1., 1.))

    # Training keeps the random horizontal flip as augmentation.
    train_transform = transforms.Compose([
        resize_crop,
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    # Evaluation uses the same resize but no flip, so the reported test loss
    # is not perturbed by random augmentation.
    eval_transform = transforms.Compose([
        resize_crop,
        transforms.ToTensor(),
    ])

    # Training split (read from disk; nothing is downloaded here).
    trainSet = torchvision.datasets.ImageFolder(root=f'{content}/train/', transform=train_transform)
    colorizationTrainSet = ColorizationDataset(trainSet)
    trainLoader = DataLoader(colorizationTrainSet, batch_size=batchSize, shuffle=True, num_workers=2, pin_memory=True)

    # Test split.
    testSet = torchvision.datasets.ImageFolder(root=f'{content}/val/', transform=eval_transform)
    colorizationTestSet = ColorizationDataset(testSet)
    testLoader = DataLoader(colorizationTestSet, batch_size=batchSize, shuffle=False, num_workers=2, pin_memory=True)

    return {'train': trainLoader, 'test': testLoader}

In [18]:
# Adjusted network to match VGG13
class GrayNet(nn.Module):
    """Encoder/decoder colorization network with a VGG13-style encoder.

    The encoder is five blocks of two 3x3 conv + batch-norm + ReLU layers with
    2x2 max pooling after the first four. The decoder upsamples with 2x2
    stride-2 transposed convolutions and concatenates the encoder features of
    blocks 4, 3 and 2 (block 1 is not used as a skip connection).

    Input is ``(N, 3, H, W)`` and the output has the same shape. ``H`` and
    ``W`` must be multiples of 16, otherwise the upsampled tensors do not match
    the skip tensors when concatenated.
    """

    def __init__(self):
        super().__init__()
        base = 64
        c1, c2, c4, c8, c16 = base, base * 2, base * 4, base * 8, base * 16

        # ---- encoder: block 1 ----
        self.conv1 = nn.Conv2d(3, c1, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(c1)
        self.conv2 = nn.Conv2d(c1, c1, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(c1)

        # ---- encoder: block 2 ----
        self.conv3 = nn.Conv2d(c1, c2, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(c2, c2, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(c2)
        self.bn4 = nn.BatchNorm2d(c2)

        # ---- encoder: block 3 ----
        self.conv5 = nn.Conv2d(c2, c4, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(c4, c4, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(c4)
        self.bn6 = nn.BatchNorm2d(c4)

        # ---- encoder: block 4 ----
        self.conv7 = nn.Conv2d(c4, c8, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(c8, c8, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(c8)
        self.bn8 = nn.BatchNorm2d(c8)

        # ---- bottleneck: block 5, ends with the first upsampling ----
        self.conv9 = nn.Conv2d(c8, c8, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(c8, c8, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(c8)
        self.bn10 = nn.BatchNorm2d(c8)
        self.conv11 = nn.ConvTranspose2d(c8, c8, kernel_size=2, stride=2)

        # ---- decoder: block 6 (1x1 conv fuses upsampled + block-4 features) ----
        self.conv12 = nn.Conv2d(c16, c4, kernel_size=1, padding=0)
        self.conv13 = nn.Conv2d(c4, c4, kernel_size=3, padding=1)
        self.conv14 = nn.Conv2d(c4, c4, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(c4)
        self.bn13 = nn.BatchNorm2d(c4)
        self.bn14 = nn.BatchNorm2d(c4)
        self.conv15 = nn.ConvTranspose2d(c4, c4, kernel_size=2, stride=2)

        # ---- decoder: block 7 (fuses with block-3 features) ----
        self.conv16 = nn.Conv2d(c8, c2, kernel_size=1, padding=0)
        self.conv17 = nn.Conv2d(c2, c2, kernel_size=3, padding=1)
        self.conv18 = nn.Conv2d(c2, c2, kernel_size=3, padding=1)
        self.bn16 = nn.BatchNorm2d(c2)
        self.bn17 = nn.BatchNorm2d(c2)
        self.bn18 = nn.BatchNorm2d(c2)
        self.conv19 = nn.ConvTranspose2d(c2, c2, kernel_size=2, stride=2)

        # ---- decoder: block 8 (fuses with block-2 features) ----
        self.conv20 = nn.Conv2d(c4, c1, kernel_size=1, padding=0)
        self.conv21 = nn.Conv2d(c1, c1, kernel_size=3, padding=1)
        self.conv22 = nn.Conv2d(c1, c1, kernel_size=3, padding=1)
        self.bn20 = nn.BatchNorm2d(c1)
        self.bn21 = nn.BatchNorm2d(c1)
        self.bn22 = nn.BatchNorm2d(c1)
        self.conv23 = nn.ConvTranspose2d(c1, c1, kernel_size=2, stride=2)

        # ---- output head: block 9, maps back to 3 channels ----
        self.conv24 = nn.Conv2d(c1, c1, kernel_size=3, padding=1)
        self.conv25 = nn.Conv2d(c1, c1, kernel_size=3, padding=1)
        self.conv26 = nn.Conv2d(c1, 3, kernel_size=3, padding=1)
        self.bn24 = nn.BatchNorm2d(c1)
        self.bn25 = nn.BatchNorm2d(c1)
        self.bn26 = nn.BatchNorm2d(3)

    @staticmethod
    def _conv_bn_relu(conv, bn, x):
        """Apply ``conv``, then ``bn``, then ReLU to ``x``."""
        return F.relu(bn(conv(x)))

    def forward(self, x):
        """Colorize a batch; ``x`` is ``(N, 3, H, W)`` and so is the result."""
        step = self._conv_bn_relu

        # Block 1: (N, 64, H, W), then pooled to H/2 x W/2.
        x = step(self.conv1, self.bn1, x)
        x = step(self.conv2, self.bn2, x)
        x = F.max_pool2d(x, 2)

        # Block 2: (N, 128, H/2, W/2), kept for the block-8 skip connection.
        x = step(self.conv3, self.bn3, x)
        skip2 = step(self.conv4, self.bn4, x)
        x = F.max_pool2d(skip2, 2)

        # Block 3: (N, 256, H/4, W/4), kept for the block-7 skip connection.
        x = step(self.conv5, self.bn5, x)
        skip3 = step(self.conv6, self.bn6, x)
        x = F.max_pool2d(skip3, 2)

        # Block 4: (N, 512, H/8, W/8), kept for the block-6 skip connection.
        x = step(self.conv7, self.bn7, x)
        skip4 = step(self.conv8, self.bn8, x)
        x = F.max_pool2d(skip4, 2)

        # Block 5: bottleneck at H/16 x W/16, upsampled back to H/8 x W/8.
        x = step(self.conv9, self.bn9, x)
        x = step(self.conv10, self.bn10, x)
        x = self.conv11(x)

        # Block 6: concat -> 1024 channels, reduce to 256, upsample to H/4 x W/4.
        x = torch.cat((x, skip4), 1)
        x = step(self.conv12, self.bn12, x)
        x = step(self.conv13, self.bn13, x)
        x = step(self.conv14, self.bn14, x)
        x = self.conv15(x)

        # Block 7: concat -> 512 channels, reduce to 128, upsample to H/2 x W/2.
        x = torch.cat((x, skip3), 1)
        x = step(self.conv16, self.bn16, x)
        x = step(self.conv17, self.bn17, x)
        x = step(self.conv18, self.bn18, x)
        x = self.conv19(x)

        # Block 8: concat -> 256 channels, reduce to 64, upsample to H x W.
        x = torch.cat((x, skip2), 1)
        x = step(self.conv20, self.bn20, x)
        x = step(self.conv21, self.bn21, x)
        x = step(self.conv22, self.bn22, x)
        x = self.conv23(x)

        # Block 9: output head producing the 3-channel image.
        x = step(self.conv24, self.bn24, x)
        x = step(self.conv25, self.bn25, x)
        x = step(self.conv26, self.bn26, x)
        return x
# Smoke test: push one random 64x64 RGB-shaped batch through a fresh network
# and print both the architecture and the raw output tensor.
testInput = torch.randn((1, 3, 64, 64)).to(device)
testNetwork = GrayNet()
testNetwork = testNetwork.to(device)
print(testNetwork)
print(testNetwork(testInput))

GrayNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_runni

In [19]:
def trainEvaluate(architecture, data_loader, test_loader, num_epochs=5):
    """Train ``architecture`` with L1 loss, then score it on ``test_loader``.

    Args:
        architecture: Network to train in place; it is moved to ``device``.
        data_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` evaluation batches.
        num_epochs: Number of passes over ``data_loader``.

    Returns:
        The negated sample-weighted mean L1 loss on ``test_loader``, so that a
        larger value means a better model.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    criterion = nn.L1Loss()
    optimizer = Adam(architecture.parameters(), lr=0.001)

    architecture.to(device, non_blocking=True)
    # Remember the caller's mode so it can be restored afterwards.
    was_training = architecture.training
    try:
        # Training mode: batch norm uses batch statistics and updates its running stats.
        architecture.train()
        for epoch in range(num_epochs):
            print("\tepoch: ", epoch+1, "/", num_epochs)
            for inputs, targets in tqdm(data_loader, leave=False):
                optimizer.zero_grad()
                outputs = architecture(inputs.to(device, non_blocking=True))
                loss = criterion(outputs, targets.to(device, non_blocking=True))
                loss.backward()
                optimizer.step()

        # Evaluation mode: batch norm uses its running statistics and no longer
        # updates them, so the score does not depend on test batch composition.
        architecture.eval()
        total_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for inputs, targets in tqdm(test_loader, leave=False):
                outputs = architecture(inputs.to(device, non_blocking=True))
                loss = criterion(outputs, targets.to(device, non_blocking=True))
                # criterion returns the batch mean; weight it by the batch size.
                total_loss += loss.item() * len(targets)
                total_samples += len(targets)
    finally:
        architecture.train(was_training)

    if total_samples == 0:
        raise ValueError("test_loader produced no samples; cannot compute an evaluation loss.")

    average_loss = total_loss / total_samples
    return -average_loss  # Return the negative loss as a fitness score for maximization

# Initialize model and dataset

In [20]:
# Build the loaders once; the cells below reuse trainLoader / testLoader.
data = get_imagenet_data(BATCH_SIZE)
trainLoader = data['train']
testLoader = data['test']

model = GrayNet().to(device)
# Load model from file.
loadPath = f"{BASEADDR}/checkpoints/finaltrainedweights.pt"
# map_location=device lets a checkpoint saved on a GPU load when the notebook
# has fallen back to the CPU (and vice versa).
model.load_state_dict(torch.load(loadPath, map_location=device))
# Resets `model` to the weights stored at loadPath; the training loop may
# later rebind this name to restore its best state instead.
recover_model = lambda: model.load_state_dict(torch.load(loadPath, map_location=device))
# # Load pretrained weights into the decoder portion of the network
# pre = torchvision.models.vgg13_bn(weights = torchvision.models.VGG13_BN_Weights).to(device)
# params = pre.features[0].state_dict()
# model.conv1.load_state_dict(params)
# model.conv1.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[3].state_dict()
# model.conv2.load_state_dict(params)
# model.conv2.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[7].state_dict()
# model.conv3.load_state_dict(params)
# model.conv3.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[10].state_dict()
# model.conv4.load_state_dict(params)
# model.conv4.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[14].state_dict()
# model.conv5.load_state_dict(params)
# model.conv5.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[17].state_dict()
# model.conv6.load_state_dict(params)
# model.conv6.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[21].state_dict()
# model.conv7.load_state_dict(params)
# model.conv7.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[24].state_dict()
# model.conv8.load_state_dict(params)
# model.conv8.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[28].state_dict()
# model.conv9.load_state_dict(params)
# model.conv9.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later
# params = pre.features[31].state_dict()
# model.conv10.load_state_dict(params)
# model.conv10.requires_grad_ = False  #Freeze layer for transfer learning. May want to unfreeze later

FileNotFoundError: Couldn't find any class folder in /media/aaron/Storage/imagenet/ILSVRC/Data/CLS-LOC/val/.

### Save model to file

In [None]:
# Save current model to file
def saveModel(state, savePath):
    """Serialize ``state`` to ``savePath`` with ``torch.save``.

    Missing parent directories are created first. A ``savePath`` without a
    directory component (a bare file name) is written to the current working
    directory.

    Args:
        state: Object to save, e.g. a model ``state_dict()``.
        savePath: Destination file path.
    """
    saveDir = os.path.dirname(savePath)
    # dirname is "" for a bare file name, and os.makedirs("") would raise.
    if saveDir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(saveDir, exist_ok=True)
    torch.save(state, savePath)

: 

### Load existing model

### Visualize the data fed into the network

In [None]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: Tensor laid out as (channels, height, width), e.g. the result of
            ``torchvision.utils.make_grid``. It may live on any device and may
            be attached to an autograd graph.
    """
    # .numpy() raises for tensors that require grad or are on the GPU, so
    # detach and move to host memory first.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects (height, width, channels).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

: 

In [None]:
def visualizeData(dataLoader: DataLoader):
    """Plot the first batch of ``dataLoader``: grayscale inputs, then colour targets."""
    grayBatch, colorBatch = next(iter(dataLoader))

    # One image grid per tensor: network inputs first, ground truth second.
    for batch in (grayBatch, colorBatch):
        imshow(torchvision.utils.make_grid(batch))

visualizeData(trainLoader)

: 

In [None]:
def visualizeNetwork(model: nn.Module, dataLoader: DataLoader):
    """Plot grayscale inputs, the model's colorizations and the ground truth.

    Uses the first batch produced by ``dataLoader``. The model is run in
    evaluation mode without gradient tracking, and its previous
    training/eval mode is restored afterwards.

    Args:
        model: Colorization network, already on ``device``.
        dataLoader: Loader yielding ``(grayscale, colour)`` batches.
    """
    inputs, targets = next(iter(dataLoader))

    # show grayscale images.
    imshow(torchvision.utils.make_grid(inputs))

    # show colorized images.
    was_training = model.training
    model.eval()  # keep batch-norm running statistics untouched by this preview
    try:
        # Without no_grad the output would require grad and could not be
        # converted to a NumPy array for plotting.
        with torch.no_grad():
            outputs = model(inputs.to(device))
    finally:
        model.train(was_training)
    imshow(torchvision.utils.make_grid(outputs.cpu()))

    # show original images.
    imshow(torchvision.utils.make_grid(targets))


: 

### Train the network.

In [None]:
iteration = 0

# trainEvaluate returns a negated loss (always <= 0), so the starting "best"
# must be below every reachable score or no checkpoint would ever be saved.
bestAccuracy = float("-inf")
# state_dict() returns references to the live parameters; deep-copy so the
# snapshot does not change as training continues.
bestState = copy.deepcopy(model.state_dict())

# Runs until interrupted: each round restarts from the best known weights,
# trains, and keeps the result only if it scores better.
while True:
    recover_model()
    accuracy = trainEvaluate(model, trainLoader, testLoader)
    if accuracy > bestAccuracy:
        bestAccuracy = accuracy
        save_path = f"{checkpoints}/{iteration}.pt"
        saveModel(model.state_dict(), save_path)
        bestState = copy.deepcopy(model.state_dict())
        # The lambda looks up bestState at call time, so it always restores the latest snapshot.
        recover_model = lambda: model.load_state_dict(bestState)
    print("Iteration: ", iteration, " had accuracy of ", accuracy)
    visualizeNetwork(model, trainLoader)
    # Advance the counter so checkpoints and log lines are numbered per round.
    iteration += 1


: 

: 

In [None]:
def compareNetworks(model: nn.Module, dataLoader: DataLoader, net1: str, net2: str):
    """Show colorizations from two checkpoints next to the ground truth.

    One batch of 8 training samples is drawn from a freshly built loader. The
    weights at ``net1`` and ``net2`` are loaded into ``model`` in turn, and the
    two outputs plus the targets are concatenated along the height axis and
    displayed. The model's original weights and training/eval mode are
    restored afterwards, even if loading or inference fails.

    Args:
        model: Network whose architecture matches both checkpoints.
        dataLoader: Currently unused; the batch comes from ``get_imagenet_data(8)``.
        net1: Path to the first checkpoint (a saved ``state_dict``).
        net2: Path to the second checkpoint (a saved ``state_dict``).
    """
    data = get_imagenet_data(8)
    inputs, targets = next(iter(data['train']))

    # Save original state of model. state_dict() only references the live
    # tensors, so a deep copy is required for the restore below to work.
    originalState = copy.deepcopy(model.state_dict())
    was_training = model.training

    try:
        model.eval()
        # get colorized images from networks; no_grad keeps the outputs
        # convertible to NumPy for plotting.
        with torch.no_grad():
            model.load_state_dict(torch.load(net1, map_location=device))
            outputs1 = model(inputs.to(device))
            model.load_state_dict(torch.load(net2, map_location=device))
            outputs2 = model(inputs.to(device))
    finally:
        # Load original state back into model.
        model.load_state_dict(originalState)
        model.train(was_training)

    # Display colorized images alongside true image.
    imshow(torchvision.utils.make_grid(torch.cat((outputs1.cpu(), outputs2.cpu(), targets),2)))

compareNetworks(model, trainLoader, f"{checkpoints}0/0.pt", f"{checkpoints}0/16.pt")

: 