# models

In [1]:
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


class ColorizationNet1(nn.Module):
    """Truncated ResNet-18 encoder followed by a small convolutional decoder.

    The decoder reduces the 128-channel encoder output to 2 channels and
    enlarges it by a factor of two three times with nearest-neighbour
    ``F.interpolate``.
    """

    def __init__(self, input_channels=1):
        """Create the encoder and the decoder layers.

        Args:
            input_channels: Not read by this constructor; the encoder stem is
                always collapsed to a single input channel.
        """
        super().__init__()

        # ResNet-18
        backbone = models.resnet18(num_classes=1000)
        # Sum the stem kernels over dim 1 (the input-channel axis) so the first
        # convolution consumes a one-channel image.
        stem_kernels = backbone.conv1.weight.sum(dim=1)
        backbone.conv1.weight = nn.Parameter(stem_kernels.unsqueeze(1))
        # Only the first six child modules of the backbone are kept.
        encoder_stages = list(backbone.children())[:6]
        self.midlevel_resnet = nn.Sequential(*encoder_stages)

        # upsampling layers
        self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(16)
        self.conv5 = nn.Conv2d(16, 2, kernel_size=3, padding=1)

    def forward(self, input):
        """Run the encoder and decoder on ``input`` and return the 2-channel map."""

        def conv_bn_relu(conv, norm, tensor):
            return F.relu(norm(conv(tensor)))

        def double_size(tensor):
            return F.interpolate(tensor, scale_factor=2, mode='nearest')

        # Mid-level features from ResNet
        features = self.midlevel_resnet(input)

        # upsample block
        out = double_size(conv_bn_relu(self.conv1, self.bn1, features))
        out = double_size(conv_bn_relu(self.conv2, self.bn2, out))
        out = conv_bn_relu(self.conv3, self.bn3, out)
        out = conv_bn_relu(self.conv4, self.bn4, out)
        # The last convolution has no normalisation or activation.
        return double_size(self.conv5(out))


In [71]:
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


class ColorizationNet2(nn.Module):
    """Truncated ResNet-18 encoder with a conv decoder that uses ``nn.Upsample``.

    The layer layout matches the five-convolution decoder (128 -> 64 -> 64 ->
    32 -> 16 -> 2 channels); the three 2x enlargements are ``nn.Upsample``
    modules registered on the network.
    """

    def __init__(self, input_channels=1):
        """Create encoder, decoder and upsampling modules.

        Args:
            input_channels: Not read by this constructor.
        """
        super().__init__()

        backbone = models.resnet18(num_classes=1000)
        # Collapse the stem kernels over dim 1 so the first conv takes one channel.
        single_channel_stem = backbone.conv1.weight.sum(dim=1).unsqueeze(1)
        backbone.conv1.weight = nn.Parameter(single_channel_stem)
        # Keep the first six child modules of the backbone as the encoder.
        self.midlevel_resnet = nn.Sequential(*list(backbone.children())[:6])

        self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(16)
        self.conv5 = nn.Conv2d(16, 2, kernel_size=3, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=2)
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.upsample3 = nn.Upsample(scale_factor=2)

    def forward(self, input):
        """Run the encoder and decoder on ``input`` and return the 2-channel map."""
        encoded = self.midlevel_resnet(input)

        stage1 = self.upsample1(F.relu(self.bn1(self.conv1(encoded))))
        stage2 = self.upsample2(F.relu(self.bn2(self.conv2(stage1))))
        stage3 = F.relu(self.bn3(self.conv3(stage2)))
        stage4 = F.relu(self.bn4(self.conv4(stage3)))
        # Final projection to 2 channels, then the last 2x enlargement.
        return self.upsample3(self.conv5(stage4))



In [80]:
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


class ColorizationNet3(nn.Module):
    """Truncated ResNet-18 encoder with a seven-convolution decoder.

    Decoder channels run 128 -> 128 -> 64 -> 64 -> 32 -> 32 -> 16 -> 2, with
    three ``nn.Upsample`` modules each doubling the spatial size.
    """

    def __init__(self, input_channels=1):
        """Create encoder, decoder and upsampling modules.

        Args:
            input_channels: Not read by this constructor.
        """
        super().__init__()

        backbone = models.resnet18(num_classes=1000)
        # Collapse the stem kernels over dim 1 so the first conv takes one channel.
        summed_stem = backbone.conv1.weight.sum(dim=1)
        backbone.conv1.weight = nn.Parameter(summed_stem.unsqueeze(1))
        # Keep the first six child modules of the backbone as the encoder.
        kept_children = list(backbone.children())[:6]
        self.midlevel_resnet = nn.Sequential(*kept_children)

        self.conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.conv6 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(16)
        self.conv7 = nn.Conv2d(16, 2, kernel_size=3, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=2)
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.upsample3 = nn.Upsample(scale_factor=2)

    def forward(self, input):
        """Run the encoder and decoder on ``input`` and return the 2-channel map."""

        def conv_bn_relu(conv, norm, tensor):
            return F.relu(norm(conv(tensor)))

        out = self.midlevel_resnet(input)

        # Two conv stages at encoder resolution, then the first enlargement.
        out = conv_bn_relu(self.conv1, self.bn1, out)
        out = self.upsample1(conv_bn_relu(self.conv2, self.bn2, out))
        # One conv stage, then the second enlargement.
        out = self.upsample2(conv_bn_relu(self.conv3, self.bn3, out))
        # Three conv stages narrowing the channels.
        out = conv_bn_relu(self.conv4, self.bn4, out)
        out = conv_bn_relu(self.conv5, self.bn5, out)
        out = conv_bn_relu(self.conv6, self.bn6, out)
        # Final projection to 2 channels, then the last enlargement.
        return self.upsample3(self.conv7(out))


# colorize data 

In [2]:
import torchvision.transforms as T
import torch
import numpy as np
from skimage.color import rgb2lab, rgb2gray, lab2rgb
from torchvision import datasets
from torchvision.datasets.folder import default_loader

class ColorizeData(datasets.ImageFolder):
    """Image-folder dataset that yields ``(grayscale, ab)`` tensor pairs.

    Each image is loaded with ``loader``, optionally passed through
    ``transform``, and converted with scikit-image: the a/b channels of the
    Lab conversion become the target and ``rgb2gray`` of the same image
    becomes the input. Folder labels and ``target_transform`` are not used by
    ``__getitem__``.
    """

    def __init__(self, root, transform=None, target_transform=None, loader=default_loader, is_train=True):
        """Index the images under ``root``.

        Args:
            root: Directory handed to ``ImageFolder``.
            transform: Optional callable applied to each loaded image.
            target_transform: Forwarded to ``ImageFolder``; unused by
                ``__getitem__``.
            loader: Callable that loads an image from a path.
            is_train: Stored on the instance; not read by this class.
        """
        super().__init__(root, transform=transform, target_transform=target_transform, loader=loader)
        self.is_train = is_train

    def __getitem__(self, index):
        """Return ``(gray, ab)`` for the image at ``index``.

        Returns:
            gray: float tensor with a leading channel axis of size 1, the
                ``rgb2gray`` luminance of the image.
            ab: float tensor with the two Lab colour channels first, each
                shifted by 128 and divided by 255.
        """
        path, _ = self.imgs[index]
        image = self.loader(path)
        # `transform` defaults to None, so it must not be called unconditionally.
        if self.transform is not None:
            image = self.transform(image)

        if torch.is_tensor(image):
            # Tensor transforms yield channels-first data; skimage expects
            # channels-last, so move the channel axis to the end.
            rgb = np.transpose(np.asarray(image), (1, 2, 0))
        else:
            # No tensor conversion happened (e.g. a PIL image): the array is
            # already channels-last, so transposing it would scramble the axes.
            rgb = np.asarray(image)

        lab = (rgb2lab(rgb) + 128) / 255
        # Drop the L channel and put the two colour channels first.
        ab = torch.from_numpy(lab[:, :, 1:3].transpose((2, 0, 1))).float()
        gray = torch.from_numpy(rgb2gray(rgb)).unsqueeze(0).float()
        return gray, ab


def get_colorize_data_transforms(is_train=True):
    """Return the torchvision transform pipeline for the requested split.

    Training uses a random 224 crop with random horizontal flip; otherwise the
    image is resized to 256 and centre-cropped to 224. Both pipelines end with
    ``ToTensor``.
    """
    if not is_train:
        eval_steps = [T.Resize(256), T.CenterCrop(224), T.ToTensor()]
        return T.Compose(eval_steps)

    train_steps = [T.RandomResizedCrop(224), T.RandomHorizontalFlip(), T.ToTensor()]
    return T.Compose(train_steps)

# train 

In [3]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb
import time
import torch.nn as nn
import argparse
import torchvision.transforms as T
import cv2
import glob
from torch.utils.tensorboard import SummaryWriter

class AverageMeter(object):
    """Tracks the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running average, weighted sum and sample count."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so an update with n=0 on a fresh meter does not raise.
        self.avg = self.sum / self.count if self.count else 0


class Trainer:
    """Runs training and validation epochs for a colorization model.

    Batches are moved to whatever device the model lives on, so the same code
    runs on GPU and CPU. Scalars are logged through ``writer.add_scalar``;
    passing ``writer=None`` disables logging.
    """

    # Output folders for the preview images written by ``validate``.
    SAVE_PATHS = {
        'grayscale': '/content/drive/MyDrive/585_project_f04/outputs/gray/',
        'colorized': '/content/drive/MyDrive/585_project_f04/outputs/color/',
        'ground_truth': '/content/drive/MyDrive/585_project_f04/outputs/ground_truth/'
    }

    def __init__(self, writer):
        self.writer = writer

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter (CUDA/CPU fallback if none)."""
        first_param = next(model.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _lab_to_rgb(grayscale_input, ab_input):
        """Stack the gray and ab tensors as a Lab image and convert it to RGB."""
        lab = torch.cat((grayscale_input, ab_input), 0).numpy()  # combine channels
        lab = lab.transpose((1, 2, 0))  # channels-last for skimage/matplotlib
        # Undo the scaling applied to the data: gray * 100 is used as L, and
        # ab * 255 - 128 restores the Lab colour channels.
        lab[:, :, 0:1] = lab[:, :, 0:1] * 100
        lab[:, :, 1:3] = lab[:, :, 1:3] * 255 - 128
        return lab2rgb(lab.astype(np.float64))

    def _log_scalar(self, tag, value, step):
        """Forward a scalar to the writer when one was supplied."""
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)

    def to_rgb(self, grayscale_input, ab_input, save_path=None, save_name=None):
        """Save the grayscale input and its colorized version as image files.

        Args:
            grayscale_input: CPU tensor holding one gray channel.
            ab_input: CPU tensor holding the two colour channels.
            save_path: Dict with ``'grayscale'`` and ``'colorized'`` folder
                prefixes; nothing is written when it is None.
            save_name: File name appended to each prefix; nothing is written
                when it is None.
        """
        plt.clf()  # clear matplotlib
        color_image = self._lab_to_rgb(grayscale_input, ab_input)

        grayscale_input = grayscale_input.squeeze().numpy()
        if save_path is not None and save_name is not None:
            plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
            plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

    def train(self, train_loader, epoch, model, criterion, optimizer, scheduler):
        """Train ``model`` for one pass over ``train_loader``.

        Args:
            train_loader: Sized iterable of ``(gray, ab)`` batches.
            epoch: Zero-based epoch index, used for log steps and messages.
            model: Module being optimised.
            criterion: Loss called as ``criterion(output, target)``.
            optimizer: Optimizer stepped once per batch.
            scheduler: Not used by this method.
        """
        print(f'Starting training epoch {epoch + 1}')
        model.train()
        device = self._model_device(model)
        losses = AverageMeter()

        for i, (input_gray, input_ab) in enumerate(train_loader):
            # Follow the model's device instead of assuming CUDA.
            input_gray, input_ab = input_gray.to(device), input_ab.to(device)

            # forward pass
            output_ab = model(input_gray)
            loss = criterion(output_ab, input_ab)
            losses.update(loss.item(), input_gray.size(0))

            # gradients and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            self._log_scalar("Training Loss", losses.val, epoch * len(train_loader) + i)

            if i % 2 == 0:
                print(f'Epoch: [{epoch + 1}][{i}/{len(train_loader)}]\t'
                      f'Loss {losses.val:.6f} ({losses.avg:.6f})\t')

        print(f'Finished training epoch {epoch + 1}')

    def _save_preview(self, gray, predicted_ab, true_ab, save_name):
        """Write the gray, colorized and ground-truth images for one sample."""
        self.to_rgb(gray, ab_input=predicted_ab, save_path=self.SAVE_PATHS, save_name=save_name)

        # Saving ground truth images
        gt_image = self._lab_to_rgb(gray, true_ab)
        plt.imsave(arr=gt_image, fname='{}{}'.format(self.SAVE_PATHS['ground_truth'], save_name))

    def validate(self, val_loader, epoch, save_images, model, criterion):
        """Evaluate ``model`` on ``val_loader`` and return the average loss.

        Autograd is disabled inside this method, so callers do not need their
        own ``torch.no_grad()`` (wrapping the call in one is still harmless).

        Args:
            val_loader: Sized iterable of ``(gray, ab)`` batches; its
                ``batch_size`` is read only when images are saved.
            epoch: Zero-based epoch index, used for log steps and file names.
            save_images: When true, preview images are written for the first
                sample of the first batch.
            model: Module being evaluated.
            criterion: Loss called as ``criterion(output, target)``.

        Returns:
            The sample-weighted average loss over the loader.
        """
        model.eval()
        device = self._model_device(model)
        losses = AverageMeter()
        already_saved_images = False

        with torch.no_grad():
            for i, (input_gray, input_ab) in enumerate(val_loader):
                input_gray, input_ab = input_gray.to(device), input_ab.to(device)

                # Run model and record loss
                output_ab = model(input_gray)
                loss = criterion(output_ab, input_ab)
                losses.update(loss.item(), input_gray.size(0))

                # Save images to file
                if save_images and not already_saved_images:
                    already_saved_images = True
                    for j in range(min(len(output_ab), 1)):  # save 1 image each epoch
                        save_name = f'img-{i * val_loader.batch_size + j}-epoch-{epoch + 1}.jpg'
                        self._save_preview(
                            input_gray[j].cpu(),
                            output_ab[j].detach().cpu(),
                            input_ab[j].cpu(),
                            save_name,
                        )

                self._log_scalar("Validation Loss", losses.val, epoch * len(val_loader) + i)

                if i % 2 == 0:
                    print(f'Validate: [{i}/{len(val_loader)}]\t'
                          f'Loss {losses.val:.6f} ({losses.avg:.6f})\t')

        print('Finished validation.')
        return losses.avg


In [4]:
# Mount Google Drive at /content/drive (Colab only); the dataset, output and
# model paths used in later cells all live under this mount point.
# force_remount=True requests a fresh mount even if one is already present.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [5]:
# Presumably the number of validation images; set it to match the dataset in
# use: 20 when the train split has 80 images, 1200 when it has 3800.
# NOTE(review): in the cells shown here this value is only stored on `args` — confirm it is read anywhere.
n_val = 20


In [6]:
# Run settings; they are bundled into `args` in a later cell.
epochs = 10  # number of train + validate passes in the training loop
save_images = True  # forwarded to Trainer.validate to write preview images
lr = 1e-3  # Adam learning rate
weight_decay = 1e-4  # Adam weight decay
save_model = True  # when true, the final model is written with torch.save
loss = 'mse'  # 'mse' selects nn.MSELoss; any other value selects nn.L1Loss
batch_size = 16  # batch size of both the train and validation DataLoaders

In [None]:
# args = argparse.Namespace(image_dir=image_dir, n_val=n_val, epochs=epochs, save_images=save_images, lr=lr, weight_decay=weight_decay, save_model=save_model, loss=loss, batch_size=batch_size)


In [7]:
# Bundle the settings from the cells above into one namespace so later cells
# read them as `args.<name>`.
args = argparse.Namespace( n_val=n_val, epochs=epochs, save_images=save_images, lr=lr, weight_decay=weight_decay, save_model=save_model, loss=loss, batch_size=batch_size)


In [8]:
# Remove preview images left over from earlier runs so the output folders only
# contain files from the run that follows.
OUTPUT_DIRS = (
    '/content/drive/MyDrive/585_project_f04/outputs/color/',
    '/content/drive/MyDrive/585_project_f04/outputs/gray/',
    '/content/drive/MyDrive/585_project_f04/outputs/ground_truth/',
)
for output_dir in OUTPUT_DIRS:
    for stale_path in glob.glob(output_dir + '*'):
        # os.remove raises on a real directory; skip those instead of
        # aborting the whole cleanup half-way through.
        if os.path.isdir(stale_path) and not os.path.islink(stale_path):
            continue
        os.remove(stale_path)


# model selection

In [None]:
# run one of these each time to pick models

In [9]:
# res18v1: ColorizationNet1 (decoder upsamples with F.interpolate), moved to the GPU.
model = ColorizationNet1().cuda()

In [77]:
# res18v2: ColorizationNet2 (same layer sizes as v1, upsampling via nn.Upsample modules), moved to the GPU.
model = ColorizationNet2().cuda()

In [81]:
# res18v3: ColorizationNet3 (deeper seven-convolution decoder), moved to the GPU.
model = ColorizationNet3().cuda()

In [10]:
# Build the logging, loss, optimizer and LR-schedule objects for the `model`
# chosen in the model-selection cells.
writer = SummaryWriter()
if args.loss == 'mse':  # Initialize loss according to choice
    criterion = nn.MSELoss().cuda()
else:
    criterion = nn.L1Loss().cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Halve the learning rate every 5 epochs. `verbose` is not passed: False is
# its default and the argument is deprecated in recent PyTorch releases.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Training
train_transforms = get_colorize_data_transforms(is_train=True)
train_imagefolder = ColorizeData('/content/drive/MyDrive/585_project_f04/images_vs/train', transform=train_transforms, is_train=True)
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=args.batch_size, shuffle=True)

# Validation
val_transforms = get_colorize_data_transforms(is_train=False)
val_imagefolder = ColorizeData('/content/drive/MyDrive/585_project_f04/images_vs/val', transform=val_transforms, is_train=False)
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=args.batch_size, shuffle=False)

print("Image preprocessing completed!")

# Train model
trainer = Trainer(writer=writer)
for epoch in range(args.epochs):
    # Train for one epoch, then validate
    trainer.train(train_loader, epoch, model, criterion, optimizer, scheduler)
    # The schedule advances once per epoch, after that epoch's optimizer steps.
    scheduler.step()
    with torch.no_grad():
        trainer.validate(val_loader, epoch, args.save_images, model, criterion)

# Push buffered scalars to disk so TensorBoard shows the complete run.
writer.flush()


Image preprocessing completed!
Starting training epoch 1
Epoch: [1][0/5]	Loss 0.276008 (0.276008)	
Epoch: [1][2/5]	Loss 0.123127 (0.201784)	
Epoch: [1][4/5]	Loss 0.048558 (0.145692)	
Finished training epoch 1
Validate: [0/2]	Loss 0.197326 (0.197326)	
Finished validation.
Starting training epoch 2
Epoch: [2][0/5]	Loss 0.022842 (0.022842)	
Epoch: [2][2/5]	Loss 0.011467 (0.015421)	
Epoch: [2][4/5]	Loss 0.014301 (0.015612)	
Finished training epoch 2
Validate: [0/2]	Loss 0.063808 (0.063808)	
Finished validation.
Starting training epoch 3
Epoch: [3][0/5]	Loss 0.016651 (0.016651)	
Epoch: [3][2/5]	Loss 0.022872 (0.021561)	
Epoch: [3][4/5]	Loss 0.015829 (0.019667)	
Finished training epoch 3
Validate: [0/2]	Loss 0.040937 (0.040937)	
Finished validation.
Starting training epoch 4
Epoch: [4][0/5]	Loss 0.010962 (0.010962)	
Epoch: [4][2/5]	Loss 0.008174 (0.011111)	
Epoch: [4][4/5]	Loss 0.006730 (0.009870)	
Finished training epoch 4
Validate: [0/2]	Loss 0.017059 (0.017059)	
Finished validation.
Start

<Figure size 640x480 with 0 Axes>

In [None]:
if args.save_model:  # Saving final model
    # This stores the whole module object (pickled), not just a state_dict.
    # NOTE(review): loading it back presumably needs the model class definition
    # to be available at load time — confirm against the inference code.
    torch.save(model, '/content/drive/MyDrive/585_project_f04/Models/saved_model_2.pth')

# inference

In [None]:
import argparse
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from skimage.color import lab2rgb, rgb2gray
import numpy as np
import os
import cv2

def to_rgb(grayscale_input, ab_input):
  """Convert a gray tensor plus ab tensor to RGB and save it as a JPEG.

  The two tensors are concatenated along the channel axis, rescaled
  (gray * 100 is used as L, ab * 255 - 128 as the colour channels),
  converted with ``lab2rgb`` and written to the fixed inference output path.
  """
  plt.clf()  # clear matplotlib
  stacked = torch.cat((grayscale_input, ab_input), 0)  # combine channels
  lab = stacked.numpy().transpose((1, 2, 0))  # channels-last for skimage/matplotlib
  lab[..., 0:1] = lab[..., 0:1] * 100
  lab[..., 1:3] = lab[..., 1:3] * 255 - 128
  rgb = lab2rgb(lab.astype(np.float64))
  plt.imsave(arr=rgb, fname='/content/drive/MyDrive/585_project_f04/inference/inference_output.jpg')
