# B&W Colorization

This notebook is an experiment to see how well a model can re-colorize historical B&W images and videos. I have seen some advanced GAN examples and wondered whether a simple U-Net built on a standard ResNet could be enough to give OK-ish results.

In [None]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import fastai
import fastai.vision
from fastai.vision.models import DynamicUnet
from fastai.vision.models import resnet34
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

In [None]:
import gc
import torch

# Free unreachable Python objects first, then hand any unused cached CUDA memory back to the driver.
gc.collect()
torch.cuda.empty_cache()

## Preparing the data

In [None]:
import os

# Root of the Imagenette dataset; it must contain `train/` and `val/` ImageFolder subdirectories.
# Set COLORIZATION_DATA_PATH to point elsewhere instead of editing this absolute path.
data_path = os.environ.get("COLORIZATION_DATA_PATH", "/home/jupyter/.fastai/data/imagenette2-320")

In [None]:
batch_size = 32  # images per batch for both the training and validation loaders

In [None]:
# Every sample is resized/cropped to 224x224, the size the U-Net below is built for
# (img_size=(224, 224)). Images in an ImageFolder need not share a size, and the default
# collate function can only stack equally-sized tensors into a batch.
# Augmentation (e.g. RandomResizedCrop(224), RandomHorizontalFlip()) could replace
# Resize/CenterCrop for the training split.
_data_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

_train_ds = datasets.ImageFolder(root=data_path + '/train', transform=_data_transform)
_val_ds   = datasets.ImageFolder(root=data_path + '/val', transform=_data_transform)

# Number of leading samples to keep from each split; None uses the entire dataset.
# A tiny value (e.g. 1 or batch_size) is handy for checking that the model can overfit.
subset_size = 1


def _take_subset(dataset, size):
    """Return the first `size` samples of `dataset` (all of it when `size` is None).

    `size` is capped at len(dataset) so the Subset never holds out-of-range indices.
    """
    n = len(dataset) if size is None else min(size, len(dataset))
    return torch.utils.data.Subset(dataset, list(range(n)))


_train_ds = _take_subset(_train_ds, subset_size)
# Validation must be drawn from the validation split, not from the training images.
_val_ds = _take_subset(_val_ds, subset_size)

training_dataset_loader   = torch.utils.data.DataLoader(_train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
# No need to shuffle for evaluation; a fixed order makes validation runs comparable.
validation_dataset_loader = torch.utils.data.DataLoader(_val_ds, batch_size=batch_size, shuffle=False, num_workers=4)




## Setting up the model
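We use a pretrained ResNet-34 as the encoder of a U-Net. Its first convolution is replaced so it accepts a single grayscale channel, its classification head is chopped off, and the remaining layers are passed to fastai's `DynamicUnet`, which builds the matching decoder.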

In [None]:
_resnet = resnet34(pretrained=True)

# The encoder receives single-channel grayscale input, so conv1 must accept 1 channel.
# Instead of discarding the pretrained first layer, initialise the new kernel with the
# pretrained RGB kernel summed over its input channels: for an input whose three channels
# are identical, this reproduces the original layer's pre-activation exactly.
_pretrained_conv1 = _resnet.conv1
_resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
with torch.no_grad():
    _resnet.conv1.weight.copy_(_pretrained_conv1.weight.sum(dim=1, keepdim=True))

# Drop the last two children (the pooling + classification head) to keep only the conv encoder.
_resnet = nn.Sequential(*list(_resnet.children())[:-2])

model = DynamicUnet(_resnet, img_size=(224, 224), n_classes=3); model  # three output channels = RGB colour

## Checking the data

In [None]:
# One sample from the training loader. `img` stays ImageNet-normalised because later cells
# pass it to `grayscale`, which expects normalised input.
img = next(iter(training_dataset_loader))[0][0, :, :, :]

# imshow expects RGB floats in [0, 1]; undo the normalisation for display only.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
plt.imshow((img * _std + _mean).clamp(0, 1).permute(1, 2, 0))

In [None]:
# Inverse of the ImageNet Normalize used by the data pipeline. Normalize computes
# (x - m) / s per channel, so mapping back needs mean = -m/s and std = 1/s.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_IMAGENET_MEAN, _IMAGENET_STD)],
    std=[1 / s for s in _IMAGENET_STD],
)
# RGB tensor in [0, 1] -> 1-channel luminance tensor in [0, 1] (via PIL's grayscale conversion).
_rgb_to_gray = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.ToTensor(),
])


def grayscale(x):
    """Convert one ImageNet-normalised RGB image tensor into a grayscale tensor.

    Args:
        x: a single normalised 3-channel image tensor (no batch dimension), as produced by
            the notebook's data transform.

    Returns:
        A 1-channel float tensor with values in [0, 1] (not normalised).
    """
    # Clamp first: ToPILImage scales floats by 255 and casts to uint8 without saturating,
    # so values pushed slightly outside [0, 1] by rounding must not reach that cast.
    return _rgb_to_gray(_denormalize(x).clamp(0, 1))

grayscale_img = grayscale(img)
plt.imshow(
    grayscale_img.repeat(3, 1, 1)  # tile the single luminance channel into RGB for imshow
    .permute(1, 2, 0)              # CHW -> HWC, the layout matplotlib expects
    .detach()
    .cpu()
)

In [None]:
# Preview the model's colourisation of `grayscale_img` (run in order, this is before `fit`).
# eval() + no_grad() stop the preview from building an autograd graph or updating the
# BatchNorm running statistics; the model's previous train/eval mode is restored afterwards.
_was_training = model.training
model.eval()
with torch.no_grad():
    # Put the input wherever the model lives (CPU initially, the GPU once fit() has moved it).
    output = model(grayscale_img.unsqueeze(0).to(next(model.parameters()).device))
model.train(_was_training)

# Training targets are ImageNet-normalised colour images, so undo that for display.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
plt.imshow((output.squeeze(0).cpu() * _std + _mean).clamp(0, 1).permute(1, 2, 0))

## Training

In [None]:
from datetime import datetime

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def fit(model=model, training_dataset_loader=training_dataset_loader, validation_dataset_loader=validation_dataset_loader, epochs=10, max_learning_rate=1e-6):
    """Train `model` to reconstruct normalised colour images from their grayscale versions.

    Each batch's colour images (as produced by the data loaders) are the targets. The inputs
    are the same images passed through `grayscale` and stacked. The loss is the mean squared
    error between the model output and the colour targets.

    Note: the default arguments are bound when this cell runs. Re-running the model/data
    cells without re-running this one leaves the defaults pointing at the old objects.

    Args:
        model: the network to train; it is moved to the module-level `device`.
        training_dataset_loader: loader yielding (images, labels) batches for training.
        validation_dataset_loader: loader yielding (images, labels) batches for validation.
        epochs: number of passes over the training loader.
        max_learning_rate: Adam learning rate. It stays constant; no schedule is applied.

    Returns:
        A list with one (training_loss, validation_loss) tuple per epoch, each averaged per batch.
    """
    # Move the model before building the optimizer so it references the final parameters.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=max_learning_rate)
    criterion = torch.nn.MSELoss(reduction='mean')
    history = []

    for e in range(epochs):
        # Train: BatchNorm uses batch statistics and updates its running estimates.
        model.train()
        training_loss = 0.0
        for batch in tqdm(training_dataset_loader):
            color_imgs = batch[0].to(device)
            grayscale_imgs = torch.stack([grayscale(img) for img in batch[0]]).to(device)
            optimizer.zero_grad()
            loss = criterion(model(grayscale_imgs), color_imgs)
            loss.backward()
            optimizer.step()
            training_loss += loss.item()

        # Validate: eval mode uses the stored BatchNorm statistics and leaves them untouched.
        model.eval()
        validation_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(validation_dataset_loader):
                color_imgs = batch[0].to(device)
                grayscale_imgs = torch.stack([grayscale(img) for img in batch[0]]).to(device)
                loss = criterion(model(grayscale_imgs), color_imgs)
                validation_loss += loss.item()

        # Per-batch averages; max(..., 1) avoids a ZeroDivisionError on an empty loader.
        training_loss /= max(len(training_dataset_loader), 1)
        validation_loss /= max(len(validation_dataset_loader), 1)
        history.append((training_loss, validation_loss))

        print(f"Epoch {e+1} finished. \t Training loss: {training_loss} \t validation loss: {validation_loss}")
        # torch.save(model, f'./models/colorizer_model_{datetime.now()}-lr_{max_learning_rate}_e_{epochs}')

    return history

In [None]:
# NOTE(review): 3e-2 is a high Adam learning rate for a pretrained encoder; if training
# diverges, consider something around 1e-3 or lower.
fit(max_learning_rate=3e-2, epochs=50);

## Testing the model

First, let's test the model on an image from the dataset.

In [None]:
from PIL import Image
import numpy as np

img = next(iter(training_dataset_loader))[0][0, :, :, :]


def reset(x):
    """Undo the ImageNet normalisation of a 3-channel image tensor (inverse Normalize)."""
    return transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
    )(x)


# Inference: eval() + no_grad(), on whatever device the model lives on (no hardcoded .cuda()),
# restoring the model's previous train/eval mode afterwards.
_was_training = model.training
model.eval()
with torch.no_grad():
    _param_device = next(model.parameters()).device
    output = model(grayscale(img).unsqueeze(0).to(_param_device)).cpu().squeeze(0)
model.train(_was_training)

# The model is trained against normalised colour targets, so de-normalise first and only
# then clamp to the [0, 1] range imshow expects. Clamping in normalised space would
# discard every below-mean value.
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(reset(img).clamp(0, 1).permute(1, 2, 0))
axarr[0].set_title("Original")
axarr[1].imshow(reset(output).clamp(0, 1).permute(1, 2, 0))
axarr[1].set_title("Colorized")

In [None]:
# Load the test image at the 224x224 resolution the model was built for.
# convert("RGB") guarantees three channels. A genuinely B&W (mode "L") or RGBA file would
# otherwise give a tensor whose channel count does not match the 3-channel Normalize below.
with Image.open("./test_img.jpg") as _raw_test_img:
    test_img = _raw_test_img.convert("RGB").resize((224, 224))

test_img_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])(test_img)

# Show the resized PIL image itself; the normalised tensor lies outside [0, 1] and would be clipped.
plt.imshow(test_img)

In [None]:
# The model's first convolution takes a single grayscale channel, so convert the normalised
# RGB tensor with `grayscale` first. Feeding the 3-channel tensor directly fails.
_was_training = model.training
model.eval()
with torch.no_grad():
    _param_device = next(model.parameters()).device
    output_img = model(grayscale(test_img_tensor).unsqueeze(0).to(_param_device))
model.train(_was_training)

# Training targets are ImageNet-normalised colour images, so undo that before displaying.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
plt.imshow((output_img.squeeze(0).cpu() * _std + _mean).clamp(0, 1).permute(1, 2, 0))