<a href="https://colab.research.google.com/github/alexsg4/colorize/blob/main/Colorize.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from scipy import io
from tqdm.auto import tqdm

import time
import timeit 
import pickle
import multiprocessing as mp
import gc
from collections import OrderedDict
import os

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms as tvt
from torchvision import utils as tvu
from PIL import Image
import cv2

%matplotlib inline

### Helper

In [None]:
import IPython
# Helper
def ring_bell(message=''):
  """
  Print `message`, then play a short notification sound in the browser.

  Meant to flag the end of long-running steps (e.g. model training). The sound
  is produced by sending a JavaScript snippet to the notebook front-end, so it
  only has an effect inside an IPython/Jupyter environment.
  """
  # Snippet sent verbatim to the front-end: start playback and schedule a pause
  # 4300 ms after the audio reports it can play through.
  snippet = '''
  const audio = new Audio("https://www.myinstants.com/media/sounds/epic.mp3");
  audio.pause();
  audio.addEventListener("canplaythrough", function () {
        setTimeout(function(){
            audio.pause();
        },
        4300);
  }, false); 
  audio.play();
  '''
  print(message)
  display(IPython.display.Javascript(snippet))

### Env

In [None]:
# Root folder of all data; can be overridden with the COLORIZE_HOME env variable.
HOME = os.environ.get('COLORIZE_HOME', '/usr/local/bin/ml-docker-data')
DATA_PATH_SUN = os.path.join(HOME, 'SUN397')
DATA_PATH_FLICKR = os.path.join(HOME, 'flickr10k-aug')

CPUS = mp.cpu_count()
# leave one core to the notebook itself, but never drop to zero workers
WORKERS = max(CPUS - 1, 1)
CUDA = torch.cuda.is_available()
# always a torch.device (the CPU fallback used to be a bare string)
DEVICE = torch.device('cuda' if CUDA else 'cpu')

VERBOSE = False #@param{type:'boolean'}

print(f'HOME: {HOME}')
print(f'\nSUN397: {DATA_PATH_SUN}')
print(f'flickr: {DATA_PATH_FLICKR}')
print(f'\nCPUS: {CPUS} | WORKERS: {WORKERS}')
print(f'DEVICE: {DEVICE}')

ring_bell('\ntest beep')

# Data gathering



## 1) Gather color info from SUN

### Get the SUN urls file

In [None]:
urls_path = os.path.join(DATA_PATH_SUN, 'SUN397_urls.mat')
if not os.path.isfile(urls_path):
  ! curl -s https://vision.cs.princeton.edu/projects/2010/SUN/urls/SUN397_urls.mat > {urls_path}
SUN = io.loadmat(urls_path)['SUN']

### Choose the categories to keep

In [None]:
DATA_PATH_SUN = os.path.join(HOME, 'SUN397')

# One category per line. Lines containing '#' are treated as comments, and blank
# lines are skipped: an empty category string would match every image path later.
kept_cats = []
with open(os.path.join(DATA_PATH_SUN, 'kept_cats-small.txt'), 'r') as fp:
  for line in fp:
    cat = line.strip()
    if '#' in line or not cat:
      continue
    kept_cats.append(cat)

print(f'categories to keep: {len(kept_cats)}')

### Query helper

In [None]:
def query_sun_filepaths(data_path, cats_to_keep, max_img_per_cat, verbose=None):
  """
  Collect up to `max_img_per_cat` existing image paths for each kept category.

  Reads `<data_path>/fpath.txt`, where each line is an image path whose first '.'
  stands for `<data_path>/images`. A path is kept when it contains one of
  `cats_to_keep` as a substring; the first such category that still has quota
  left is credited.

  :param data_path: root folder of the SUN397 dataset.
  :param cats_to_keep: category substrings to match against each path.
  :param max_img_per_cat: maximum number of paths kept per category.
  :param verbose: print per-category counts; None falls back to the global VERBOSE.
  :return: list of image paths in file order.
  """
  if verbose is None:
    verbose = VERBOSE

  images_root = os.path.join(data_path, 'images')
  fpaths = []
  # number of samples kept so far for every category
  imgs_per_cat = {cat: 0 for cat in cats_to_keep}

  with open(os.path.join(data_path, 'fpath.txt'), 'r') as pathfile:
    for line in pathfile:
      # strip only the line ending: the last line may not end with '\n'
      img_path = line.rstrip('\r\n').replace('.', images_root, 1)

      if not os.path.isfile(img_path):
        continue

      # use the caller's categories (not a global list) so the quota dict always matches
      kept_cat = next((cat for cat in cats_to_keep
                       if cat in img_path and imgs_per_cat[cat] < max_img_per_cat),
                      None)
      if kept_cat is None:
        continue

      fpaths.append(img_path)
      imgs_per_cat[kept_cat] += 1

  # sanity check
  if verbose:
    for cat, num in imgs_per_cat.items():
      print(f'{cat}: {num}')
    print('\n')

  print(f'total images: {len(fpaths)}')
  return fpaths

### Gather UV data from the images

#### Helpers

In [None]:
def compute_uvs_for_imgs_local(fpaths, max_imgs, resize, id, disp_every=50):
  """
  Gather the chroma (U = Cr, V = Cb) value of every pixel of up to `max_imgs` images.

  Each image is read with OpenCV (BGR), optionally downscaled so its short edge
  is `resize` pixels, scaled to float32 in [0, 1] and converted to YCrCb.
  Unreadable files and non 3-channel images are skipped.

  :param fpaths: iterable of image paths.
  :param max_imgs: stop once this many images have been processed.
  :param resize: target short-edge size in pixels; <= 0 disables resizing.
  :param id: worker id, only used to prefix log messages.
  :param disp_every: log progress every `disp_every` processed images.
  :return: (Us, Vs, id, duration): flat lists of the U / V values of all
           processed pixels, the worker id, and the wall time in seconds.
  """
  start_t = time.time()
  num_imgs = 0
  Us = []
  Vs = []

  for fpath in fpaths:
    # cv2.imread does not raise for missing/corrupt files, it returns None
    image = cv2.imread(fpath)
    if image is None:
      print(f'#{id}: could not read image | path: {fpath}')
      continue

    h, w = image.shape[:2]
    short_edge = min(w, h)
    if resize > 0 and short_edge > resize:
      p = resize * 100 / short_edge
      new_size = (int(w * p / 100), int(h * p / 100))
      image = cv2.resize(image, new_size, interpolation=cv2.INTER_LANCZOS4)
      if VERBOSE:
        print(f'#{id}: resized image to:', image.shape)

    # only 3-channel images carry chroma information
    if image.ndim != 3 or image.shape[-1] != 3:
      if VERBOSE:
        print(f'#{id}: skipping grayscale image:', fpath)
      continue

    if VERBOSE:
      print(f'#{id}:\t read image from:', fpath)
      print(f'#{id}:\t num_imgs', num_imgs)
      print(f'#{id}:\t max_imgs', max_imgs)

    # normalize to [0, 1] float32 before converting BGR -> YCrCb
    image = np.float32(image * 1/255.)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)

    # channel 1 is Cr, channel 2 is Cb (called U / V throughout this notebook)
    Us.extend(image[..., 1].flatten())
    Vs.extend(image[..., 2].flatten())
    num_imgs += 1

    if num_imgs % disp_every == 0:
      print(f'#{id}:\t computed UVs for {num_imgs} images')

    if num_imgs == max_imgs:
      print(f'#{id}:\t computed UVs for {num_imgs}(max) images')
      return Us, Vs, id, time.time() - start_t

  print(f'#{id}:\t computed uvs for {num_imgs}(done) imgs, exiting...')
  return Us, Vs, id, time.time() - start_t

#### Get the uv data

In [None]:
VERBOSE = False
MAX_IMG_PER_CAT = 1000  #@param{type:'integer'}

# candidate SUN image paths, at most MAX_IMG_PER_CAT per kept category
fpath = query_sun_filepaths(DATA_PATH_SUN, kept_cats, MAX_IMG_PER_CAT)

In [None]:
MAX_IMG =  100#@param{type:'integer'}
MAX_SIZE = 512 #@param{type:'integer'}

print('computing uv pixel data...')
# the returned worker id is not needed here; binding it to `id` (as before)
# shadowed the builtin for the rest of the notebook
Us, Vs, _worker_id, duration = compute_uvs_for_imgs_local(
    fpath, max_imgs=MAX_IMG, resize=MAX_SIZE, id=0, disp_every=50)
print(f'duration: {duration:.2f} sec')

In [None]:
# value range of each chroma channel (sanity check before binning); generator
# expressions keep no module-level reference to the big lists, so they can be freed later
print(*(reduce_fn(Us) for reduce_fn in (np.min, np.max)))
print(*(reduce_fn(Vs) for reduce_fn in (np.min, np.max)))

### Compute weights based on the probabilities for each bin

In [None]:
# the U/V plane is split into a NUM_BINS x NUM_BINS grid
NUM_BINS = 10

# grid edges: [-0.1, 1.1], a margin around the expected [0, 1] chroma range
x_edge = np.linspace(-0.1, 1.1, NUM_BINS + 1)
y_edge = x_edge.copy()

# joint U/V histogram, normalized to a probability mass per bin
hist = np.histogram2d(Us, Vs, bins=[x_edge, y_edge])[0]
hist = hist / hist.sum()
print('computed normalized histogram')
print(np.min(hist))
print(np.max(hist))

# the per-pixel lists are large; release them now that they have been binned
del Us, Vs
gc.collect()

In [None]:
THRESHOLD = 1e-7
LAMBDA = 0.4
MAX_CATEGORIES = 6

cat2id = {}
id2cat = {}
cat_freq = {}

# bin probabilities, largest first
sorted_values = np.sort(np.ravel(hist))[::-1]
# number of bins that pass the initial threshold (same strict `>` test as below)
NUM_CAT = int(np.count_nonzero(sorted_values > THRESHOLD))
print('num_cat initial', NUM_CAT)

# adjust threshold to only keep the MAX_CATEGORIES most frequent bins
if NUM_CAT > MAX_CATEGORIES:
  THRESHOLD = (sorted_values[MAX_CATEGORIES - 1] + sorted_values[MAX_CATEGORIES]) * 0.5

# first pass: every frequent bin becomes a category with a consecutive id
next_id = 0
for xcat in range(NUM_BINS):
  for ycat, score in enumerate(hist[xcat, :]):
    if score > THRESHOLD:
      cat2id[(xcat, ycat)] = next_id
      id2cat[next_id] = (xcat, ycat)
      cat_freq[next_id] = score
      next_id += 1
NUM_CAT = next_id
assert NUM_CAT > 0, 'no histogram bin is above THRESHOLD'

# second pass: map each rare bin to the closest category (squared bin distance)
# and add its probability mass to that category
for xcat in range(NUM_BINS):
  for ycat, score in enumerate(hist[xcat, :]):
    if not score > THRESHOLD:
      closest_class = min(range(NUM_CAT),
                          key=lambda k: (id2cat[k][0] - xcat) ** 2 +
                                        (id2cat[k][1] - ycat) ** 2)
      cat_freq[closest_class] += score
      cat2id[(xcat, ycat)] = closest_class

# rebalancing weight per category: 1 / ((1 - LAMBDA) * p + LAMBDA / NUM_CAT)
weights = {k: 1. / (proba * (1. - LAMBDA) + LAMBDA / NUM_CAT)
           for k, proba in cat_freq.items()}
# normalize so that the frequency-weighted mean weight is 1
normalization_factor = sum(weights[k] * cat_freq[k] for k in weights)
weights = {k: weight / normalization_factor for k, weight in weights.items()}

print(f'number of weights: {len(weights)}')

# chroma at the centre of every category's bin; starts at 0 so category 0 is
# filled too (the loop used to start at 1 and left category 0 at (0, 0))
categories_mean_pixels = np.zeros([NUM_CAT, 2], dtype=np.float32)
for index in range(NUM_CAT):
  xcat, ycat = id2cat[index]
  categories_mean_pixels[index, :] = [(x_edge[xcat] + x_edge[xcat + 1]) / 2,
                                      (y_edge[ycat] + y_edge[ycat + 1]) / 2]

In [None]:
print('\n =================STATS======================== \n')
print('id id2cat \t cat2mean \t cat_freq \t\t weight: \n')

# one row per category: id and bin, mean chroma, frequency, rebalancing weight
for row_id in range(NUM_CAT):
  print(f'{row_id} {id2cat[row_id]}',
        categories_mean_pixels[row_id],
        cat_freq[row_id],
        weights[row_id],
        sep=' \t ')

### Plots

In [None]:
# bins below the threshold are zeroed so the log-scale maps only show kept bins;
# log10(0) = -inf is intended here (imshow masks it), so silence the warning
hm = np.copy(hist)
hm[hm < THRESHOLD] = 0
with np.errstate(divide='ignore'):
  logheatmap = np.log10(hm)
# the histogram edges are already in chroma units: use them directly as the
# axes extent (the former +/-10 padding came from a 0-255 scale)
extent = [x_edge[0], x_edge[-1], y_edge[0], y_edge[-1]]

plt.figure(figsize=(15, 10))
plt.subplot(222)
plt.imshow(logheatmap.T, extent=extent, origin='lower')
plt.colorbar()
plt.title("Frequency map (log-scale)", fontsize = 15, color='white')

plt.subplot(224)
# weight of each category drawn at its representative bin; other bins stay 0
weights_matrix = np.zeros([NUM_BINS, NUM_BINS])
for k in weights:
  weights_matrix[id2cat[k]] = weights[k]
with np.errstate(divide='ignore'):
  logweights_matrix = np.log10(weights_matrix)
plt.imshow(logweights_matrix.T, extent=extent, origin='lower')
plt.colorbar()
plt.title('Weight map (log-scale)', fontsize = 15, color='white')

plt.subplot(221)
color_matrix = np.ones([NUM_BINS, NUM_BINS, 3], dtype=np.float32)
for k in weights:
  # 50% luma combined with the category's mean chroma
  yuv = np.full([1, 1, 3], 0.5, dtype=np.float32)
  yuv[..., 1:] = categories_mean_pixels[k]

  if VERBOSE:
    print('yuv')
    print(np.min(yuv[..., 1]))  # u
    print(np.min(yuv[..., 2]))  # v

  rgb = cv2.cvtColor(yuv, cv2.COLOR_YCrCb2RGB)

  if VERBOSE:
    print('rgb')
    print(np.min(rgb[..., 0]))  # r
    print(np.min(rgb[..., 1]))  # g
    print(np.max(rgb[..., 2]))  # b

  # out-of-gamut chroma converts to RGB outside [0, 1]; clip for display
  color_matrix[id2cat[k][1], id2cat[k][0], :] = np.clip(rgb, 0., 1.)
plt.imshow(color_matrix, extent=extent, origin='lower')
plt.title("Color map, 50% luminance", fontsize = 15, color='white')

plt.subplot(223)
plt.imshow(-logheatmap.T, extent=extent, origin='lower')
plt.colorbar()
plt.title("Inverse-frequency map (log-scale)", fontsize = 15, color='white')

plt.tight_layout()

### Weights save

In [None]:
# the loss expects entry k to be the weight of category k: build the array from
# the ids explicitly instead of relying on the dict's insertion order
assert set(weights) == set(range(NUM_CAT)), 'weights must cover ids 0..NUM_CAT-1'
weights_np = np.array([weights[k] for k in range(NUM_CAT)], dtype=np.float32)
suffix = 'small'#@param{type:'string'}

weights_file = f'sun-w-{MAX_IMG}-{NUM_CAT}-{LAMBDA}.npy'
if len(suffix):
  weights_file = weights_file.replace('.npy', f'-{suffix}.npy')
np.save(os.path.join(HOME, weights_file), weights_np)

# sanity checks
print(weights_np[:5])

print('min:', np.min(weights_np))
print('max:', np.max(weights_np))

print(f'\nsaved weights for {NUM_CAT} classes as: \n{weights_file}')

### Conversion/categorization functions

In [None]:
def categorize_uv_pixels(uv_px, x_edge, y_edge, cat2id):
  """
  Map every (U, V) pixel to the id of its colour category.

  Pixels are binned with the same convention as `np.histogram2d` (bin i is
  [edge_i, edge_{i+1}), the last bin is closed). Values outside the edges are
  clamped to the first / last bin.

  :param uv_px: array of shape (H, W, 2) with U in [..., 0] and V in [..., 1].
  :param x_edge: U bin edges (number of U bins + 1 values, increasing).
  :param y_edge: V bin edges (number of V bins + 1 values, increasing).
  :param cat2id: dict (u_bin, v_bin) -> category id covering every bin.
  :return: integer array of shape (H, W) with the category id of each pixel.
  :raises KeyError: if a pixel falls in a bin missing from `cat2id`.
  """
  u_px = uv_px[:, :, 0]
  v_px = uv_px[:, :, 1]

  n_x = len(x_edge) - 1
  n_y = len(y_edge) - 1
  # side='right' puts a value equal to edge i into bin i, like np.histogram2d
  u_cat = np.clip(np.searchsorted(x_edge, np.ravel(u_px), side='right') - 1, 0, n_x - 1)
  v_cat = np.clip(np.searchsorted(y_edge, np.ravel(v_px), side='right') - 1, 0, n_y - 1)

  # dense (u_bin, v_bin) -> id table: one vectorized gather instead of a Python
  # dict lookup for every pixel
  lut = np.full((n_x, n_y), -1, dtype=np.int64)
  for (xc, yc), cat_id in cat2id.items():
    lut[xc, yc] = cat_id

  cats = lut[u_cat, v_cat]
  if np.any(cats < 0):
    first = int(np.argmax(cats < 0))
    raise KeyError((int(u_cat[first]), int(v_cat[first])))

  return cats.reshape(u_px.shape)
px_to_uvcat = lambda uvs: categorize_uv_pixels(uvs, x_edge, y_edge, cat2id)

In [None]:
 def UVpixels_from_distribution(distribution, temperature, cat_to_mean):
  """
  Returns mean pixels from Npixels distributions over the color categories.
  :param temperature: temperature of the annealed probability distribution.
  :param distribution: matrix of size Npixels * Mpixels * n_categories.
  """
  temp_distribution = np.exp(np.log(distribution + 1e-8) / temperature)
  newshape = list(distribution.shape)
  newshape[-1] = 1
  temp_distribution /= np.sum(temp_distribution, axis=-1).reshape(newshape)

  return np.dot(temp_distribution, cat_to_mean)

uv_px_from_z = lambda z, temp: UVpixels_from_distribution(z, temp, categories_mean_pixels)

## 2) Query image urls

In [None]:
def is_within(x, range):
  """Return True when range[0] <= x <= range[1] (both bounds inclusive)."""
  return range[0] <= x <= range[1]

# accepted (min, max) pixel size, applied to both image dimensions
IMG_SIZE = (512, 800)

### 2.1) flickr

#### flickr query helper and params

In [None]:
def query_images(image_urls, params, size, num_photos, page = 1, max_taken = u''):
  """
  Fetch one page of a flickr photo search and record downloadable photos of a suitable size.

  For every photo on the result page, the available sizes are requested and the
  first size whose width and height both lie within `size` is stored as
  `image_urls[photo_id] = source_url`. Photos already present are skipped.
  Uses the module-level `flickr` client and `is_within` helper.

  :param image_urls: dict-like id -> url, updated in place.
  :param params: base keyword arguments for `flickr.photos.search` (not modified).
  :param size: (min, max) accepted pixel size for both dimensions.
  :param num_photos: stop once `image_urls` holds this many entries.
  :param page: result page to request.
  :param max_taken: if non-empty, passed as `max_taken_date` to the search.
  :return: number of urls added by this call.
  """
  # work on a copy so the caller's query template is not mutated between calls
  params = dict(params)
  params['page'] = str(page)
  if len(max_taken) > 0:
    params['max_taken_date'] = max_taken

  resp = flickr.photos.search(**params)
  assert resp.attrib['stat'] == 'ok'

  # Element.iter() yields the element itself first, so this is the <photos> node
  photos = next(resp.iter('photos'))
  print(photos.attrib)

  kept_photos = 0
  for p in photos.iter('photo'):
    if len(image_urls) >= num_photos:
      print('got all the necessary photos')
      break

    photo_id = p.attrib['id']
    if photo_id in image_urls:
      continue  # already recorded: skip the getSizes API call

    size_info = flickr.photos.getSizes(photo_id = photo_id)
    if size_info.attrib['stat'] != u'ok':
      continue
    sizes = next(size_info.iter('sizes'))
    if sizes.attrib['candownload'] != u'1':
      continue

    for img_size in sizes.iter('size'):
      size_attr = img_size.attrib
      if is_within(int(size_attr['width']), size) and is_within(int(size_attr['height']), size):
        image_urls[photo_id] = size_attr['source']
        kept_photos += 1
        break

  return kept_photos

# a photo matches the search when it carries any of these tags
tags = ['landscape', 'beach', 'mountains', 'nature', 'sunset', 'sunrise', 'desert']

# the gathering loop restarts its date window after MAX_QUERY_SIZE results
# (presumably flickr's per-query result cap)
MAX_QUERY_SIZE = 4000
PER_PAGE = 500

query_params = dict(
    tags=u','.join(tags),
    tag_type=u'any',
    license=u'7,9,10',   # no copyright, public domain
    safe_search=u'1',    # safeSearch on
    content_type=u'1',   # photos only
    media=u'photo',
    per_page=str(PER_PAGE),
)

#### gather urls using the query

In [None]:
# init the flickr api
import getpass
import flickrapi

# Credentials come from the environment (or an interactive prompt) instead of
# being committed to the notebook.
# NOTE(review): the key/secret previously hard-coded here remain in the repository
# history and should be revoked / rotated.
api_key = os.environ.get('FLICKR_API_KEY') or getpass.getpass('flickr api key: ')
api_secret = os.environ.get('FLICKR_API_SECRET') or getpass.getpass('flickr api secret: ')

flickr = flickrapi.FlickrAPI(api_key, api_secret)

In [None]:
PHOTOS_TO_GET =  10000#@param {type:'integer'}

# number of result pages in one date window, after which the query is re-anchored
pages_per_window = MAX_QUERY_SIZE // PER_PAGE
# give up when this many consecutive pages bring no new urls (results exhausted)
MAX_STALLED_PAGES = 2 * pages_per_window

num_photo_urls = 0
id_to_url_flickr = OrderedDict()
initial_len = 0

# resume from the urls gathered by a previous run (file written by this cell)
query_results_fpath = os.path.join(DATA_PATH_FLICKR, 'id2url.pickle')
if os.path.isfile(query_results_fpath):
  with open(query_results_fpath, 'rb') as fp:
    id_to_url_flickr = pickle.load(fp)
  num_photo_urls = len(id_to_url_flickr)
  initial_len = num_photo_urls

  print(f'Already have {num_photo_urls}/{PHOTOS_TO_GET} urls')

page = 1
last_taken_date = u''
stalled_pages = 0

while num_photo_urls < PHOTOS_TO_GET:
  prev_len = len(id_to_url_flickr)

  query_images(id_to_url_flickr, query_params, IMG_SIZE, PHOTOS_TO_GET,
               page, last_taken_date)

  num_photo_urls = len(id_to_url_flickr)
  print(f'page {page}, got {num_photo_urls - prev_len} urls')

  # without this guard the loop never ends once the search runs out of photos
  stalled_pages = 0 if num_photo_urls > prev_len else stalled_pages + 1
  if stalled_pages >= MAX_STALLED_PAGES:
    print(f'no new urls for {stalled_pages} pages, stopping early')
    break

  page += 1

  # after each window of pages, restart from page 1 restricted to photos taken
  # no later than the most recently added one (needs at least one url)
  if (page - 1) % pages_per_window == 0 and id_to_url_flickr:
    page = 1

    last_img_id = next(reversed(id_to_url_flickr))

    resp = flickr.photos.getInfo(photo_id = last_img_id)
    assert resp.attrib['stat'] == u'ok'

    dates = next(iter(resp.iter('dates')))
    last_taken_date = dates.attrib['taken']

with open(query_results_fpath, 'wb') as fp:
  pickle.dump(id_to_url_flickr, fp)

ring_bell(f'\nGot {len(id_to_url_flickr) - initial_len} additional urls')

### 2.2) SUN-local

#### Gather file-paths based on kept categories

In [None]:
def get_i2u_sun(id_to_url, fpaths, size, num_imgs):
  """
  Register local SUN images that are JPEG, 3-channel and large enough.

  The image id is the file name stem with 'sun_' removed. `id_to_url` (id -> path)
  is updated in place until it holds `num_imgs` entries; unreadable files are skipped.

  :param id_to_url: dict-like id -> file path, updated in place.
  :param fpaths: candidate image paths.
  :param size: minimum accepted length of the shorter image side, in pixels.
  :param num_imgs: stop once `id_to_url` holds this many entries.
  """
  for fp in tqdm(fpaths):
    if len(id_to_url) >= num_imgs:
      print('got all the necessary photos')
      break

    # cheap file-name checks first, so files that would be rejected are never decoded
    image_name, extension = os.path.splitext(os.path.basename(fp))
    if extension.lower() not in ('.jpg', '.jpeg'):
      continue

    img_id = image_name.replace('sun_', '')
    if img_id in id_to_url:
      continue

    try:
      image = plt.imread(fp)
    except Exception:  # unreadable/corrupt file; unlike a bare except, Ctrl-C still works
      continue

    if image.ndim != 3 or image.shape[-1] != 3 or min(image.shape[:2]) < size:
      continue

    assert len(img_id)
    id_to_url[img_id] = fp


In [None]:
PHOTOS_TO_GET = 2000 #@param{type:'number'}
MIN_SIZE = 260

num_photo_urls = 0
id_to_url_sun = OrderedDict()
initial_len = 0
query_results_fpath = os.path.join(DATA_PATH_SUN, 'id2url.pickle')

# resume from the ids gathered by a previous run (file written by this cell)
if os.path.isfile(query_results_fpath):
  with open(query_results_fpath, 'rb') as fp:
    id_to_url_sun = pickle.load(fp)
  num_photo_urls = len(id_to_url_sun)
  initial_len = num_photo_urls

  print(f'Already have {num_photo_urls}/{PHOTOS_TO_GET} urls')

assert kept_cats, 'no categories to keep'
# start with ~2x the per-category share, since some images are filtered out
max_img_per_cat = 2 * PHOTOS_TO_GET // len(kept_cats) + 1

iteration = 1
prev_num_candidates = -1
while num_photo_urls < PHOTOS_TO_GET:
  prev_len = len(id_to_url_sun)

  print(f'iteration {iteration}:')
  print(f'max images per category: {max_img_per_cat}')

  fpaths = query_sun_filepaths(DATA_PATH_SUN, kept_cats, max_img_per_cat)
  get_i2u_sun(id_to_url_sun, fpaths, MIN_SIZE, PHOTOS_TO_GET)

  num_photo_urls = len(id_to_url_sun)
  print(f'got {num_photo_urls - prev_len} urls')

  # raising the per-category cap no longer yields new candidates: every kept
  # category is exhausted and the loop would otherwise spin forever
  if len(fpaths) <= prev_num_candidates:
    print('no more candidate images in the kept categories, stopping')
    break
  prev_num_candidates = len(fpaths)

  max_img_per_cat += 3
  iteration += 1

with open(query_results_fpath, 'wb') as fp:
  pickle.dump(id_to_url_sun, fp)
ring_bell(f'\nGot {len(id_to_url_sun) - initial_len} additional urls')

In [None]:
# TODO TEMP DELETE
# NOTE(review): scratch cell. It rebuilds `id_to_url_sun` from scratch with a
# 256 px minimum size and a 1500-image cap, then OVERWRITES the id2url pickle
# written by the previous cell. It also reuses `max_img_per_cat` left over from
# that cell.
MIN_SIZE = 256
query_results_fpath = os.path.join(DATA_PATH_SUN, 'id2url.pickle')
fpaths = query_sun_filepaths(DATA_PATH_SUN, kept_cats, max_img_per_cat)
id_to_url_sun = OrderedDict()

get_i2u_sun(id_to_url_sun, fpaths, MIN_SIZE, 1500)

with open(query_results_fpath, 'wb') as fp:
  pickle.dump(id_to_url_sun, fp)

## 3) Download and process images

- Number of images to get: the paper used ~14k training and ~3k test images.

### Helper functions

In [None]:
def _download_file(url, fpath):
  """Download `url` to `fpath` with curl; return True on success, leaving no partial file on failure."""
  import subprocess
  # argument list, no shell: the url comes from an external API and is never
  # interpreted by a shell. -f fails on HTTP errors instead of saving the error page.
  result = subprocess.run(['curl', '-sfL', '-o', fpath, url], check=False)
  if result.returncode == 0 and os.path.isfile(fpath) and os.path.getsize(fpath) > 0:
    return True
  # a failed download must not leave a 0-byte file behind (it would be cached as valid)
  if os.path.isfile(fpath):
    os.remove(fpath)
  return False


def download_process_images(image_urls, img_dir, aug_processors, blacklist, download=True):
  """
  Download (or locate) images and run the augmentation processors on each of them.

  :param image_urls: sequence of (id, url_or_path) pairs. With `download=True` the
                     second item is a URL saved to `<img_dir>/<id>.jpg`; otherwise it
                     is the path of an existing local image.
  :param img_dir: directory for downloaded images (created if needed).
  :param aug_processors: callables `proc(id, path) -> (new_id, new_path)`; a result
                         containing None is ignored. When the list is empty, the image
                         itself is registered under its own id.
  :param blacklist: set of paths to skip. Downloaded non-RGB images are added to it,
                    and every blacklisted path is deleted from disk at the end.
  :param download: download from the web (True) or use local files (False).
  :return: (uid_to_path, id_to_uid): OrderedDict uid -> image path, and a dict
           mapping consecutive ints 0..N-1 to those uids.
  """
  print(f'Max images to get: {len(image_urls)}\n')
  print('starting download...' if download else 'processing local images...')

  os.makedirs(img_dir, exist_ok=True)

  uid_to_path = OrderedDict()

  MAX_DL_SIZE_B = 5e10  # stop downloading after ~50 GB of source files
  total_size_b = 0

  for u in tqdm(image_urls):
    img_id = u[0]
    url = u[-1]

    fpath = os.path.join(img_dir, f'{img_id}.jpg') if download else url
    assert fpath is not None

    if fpath in blacklist:
      if VERBOSE:
        print(f'skipping blacklisted file: {fpath}...')
      continue

    if not os.path.isfile(fpath):
      if not download:
        continue
      if VERBOSE:
        print(f'could not find file {fpath}, downloading...')
      if not _download_file(url, fpath):
        continue

    try:
      dl_size = os.path.getsize(fpath)

      image = plt.imread(fpath)

      if download and (image.ndim != 3 or image.shape[-1] != 3):
        blacklist.add(fpath)
        if VERBOSE:
          print(f'skipping non rgb image {fpath}...')
        continue

      if len(aug_processors) == 0:
        uid_to_path[img_id] = fpath

      for proc in aug_processors:
        n_id, n_path = proc(img_id, fpath)
        if n_id is not None and n_path is not None:
          uid_to_path[n_id] = n_path

    except (OSError, ValueError):
      # unreadable / corrupt image: skip it
      continue

    total_size_b += dl_size

    if total_size_b >= MAX_DL_SIZE_B and download:
      print('exceeded max download size')
      break

  id_to_uid = dict(enumerate(uid_to_path.keys()))
  # rough estimate: every processor writes about one more copy of each source image
  size_factor = max(len(aug_processors), 1)
  approx_size_k = int(total_size_b/1000 * size_factor)

  if download:
    print(f'download complete, got ~{approx_size_k}K')

  # delete rejected files from disk (best effort)
  for img_path in blacklist:
    try:
      os.remove(img_path)
    except OSError:
      continue

  return uid_to_path, id_to_uid
  

### Image processors

In [None]:
class ImageProcessor():
  """
  Callable that writes a transformed copy of an image next to the original.

  `proc(img_id, img_path)` returns `(new_id, new_path)`. `new_id` is `img_id`
  suffixed with `-<label>` (unchanged when `label` is empty), and `new_path` is the
  original path with the id replaced in the file name. The transform is only
  applied when `new_path` does not exist yet; an existing file is reused.
  Returns `(None, None)` when the source cannot be opened, fails the EXIF probe,
  or the output file is empty.
  """
  # extensions for which the file is additionally probed with `_getexif()`
  _EXIF_EXTENSIONS = ('jpg', 'jpeg', 'jpe', 'jif', 'jfif', 'jfi', 'tif', 'tiff')

  def __init__(self, transform, label=''):
    self.transform = transform
    self.label = label

  def __call__(self, img_id, img_path):
    assert img_id is not None and img_path is not None, 'No id or path'
    assert os.path.getsize(img_path) > 0, 'empty file'

    try:
      image = Image.open(img_path)
    except (IOError, ValueError):
      return None, None

    new_id = img_id
    if len(self.label):
      new_id += f'-{self.label}'
    # rename only the file name, never a directory that happens to contain the id
    img_dir, img_name = os.path.split(img_path)
    new_path = os.path.join(img_dir, img_name.replace(img_id, new_id))

    # `with` closes the underlying file handle once we are done with the image
    with image:
      try:
        if img_name.split('.')[-1].lower() in self._EXIF_EXTENSIONS:
          # probe only: files whose decoder has no EXIF accessor raise AttributeError
          image._getexif()
      except (AttributeError, IOError, ValueError):
        return None, None

      if not os.path.isfile(new_path):
        self.transform(image).save(new_path)

    if os.path.getsize(new_path) > 0:
      return new_id, new_path

    return None, None


In [None]:
def add_gaussian_noise(img, mean = 0, std = 1, rng = None):
  """
  Return a copy of `img` with additive Gaussian noise.

  Pixels are scaled to [0, 1] before the noise is added, so `mean` and `std`
  are in that normalized intensity scale (e.g. std=0.05 is 5% of the full range).
  The result is clipped to [0, 1] and converted back to uint8. The previous
  version added the noise on the 0-255 scale (making std=0.05 invisible) and
  divided by the image max, which rescaled the brightness instead.

  :param img: PIL image (anything `np.asarray` accepts); assumes 8-bit pixels.
  :param mean: mean of the noise, in [0, 1] intensity units.
  :param std: standard deviation of the noise, in [0, 1] intensity units.
  :param rng: optional `np.random.Generator`; defaults to numpy's global RNG.
  :return: PIL image with the same shape as `img`.
  """
  img_np = np.asarray(img, dtype=np.float32) / 255.
  sampler = np.random if rng is None else rng
  noise = sampler.normal(mean, std, img_np.shape)
  noisy = np.clip(img_np + noise, 0., 1.)

  return Image.fromarray(np.uint8(np.round(noisy * 255.)))

In [None]:
SQ_SIZE = 256

# resize the short edge to SQ_SIZE, then take a random SQ_SIZE x SQ_SIZE crop.
# InterpolationMode enum instead of the PIL integer constant, which torchvision
# has deprecated as an `interpolation` value.
random_crop256 = tvt.Compose([
    tvt.Resize(SQ_SIZE, interpolation=tvt.InterpolationMode.LANCZOS),
    tvt.RandomCrop(size=SQ_SIZE)
])
# empty label: output path equals the input path, so the original is returned as-is
base_transform = ImageProcessor(random_crop256)
# independent random crops of the same image
proc_rc1 = ImageProcessor(random_crop256, 'rc1')
proc_rc2 = ImageProcessor(random_crop256, 'rc2')
proc_rc3 = ImageProcessor(random_crop256, 'rc3')

# random crop, then always mirror horizontally
flipX = tvt.Compose([
    random_crop256,
    tvt.RandomHorizontalFlip(p = 1.)
])
proc_flip_x = ImageProcessor(flipX, 'flip-x')

# random crop, then always mirror vertically
flipY = tvt.Compose([
    random_crop256,
    tvt.RandomVerticalFlip(p = 1.)
])
proc_flip_y = ImageProcessor(flipY, 'flip-y')

# random crop, then additive Gaussian noise
noise_l = lambda img: add_gaussian_noise(img, mean=0.1, std=0.05)
add_noise = tvt.Compose([
    random_crop256,
    tvt.Lambda(noise_l)
])
proc_noise = ImageProcessor(add_noise, 'n')

### Train/test/val split the urls and download/process the images

In [None]:
VERBOSE = False
# either use flickr or SUN
USE_FLICKR = False #@param{type:'boolean'}

PHOTOS_TO_USE =  2000#@param{type:'integer'}

id_to_url = id_to_url_flickr if USE_FLICKR else id_to_url_sun

# deterministic (unshuffled) split: 10% of the kept photos go to test, then
# 10% of the remainder go to validation
train_imgs, test_imgs = train_test_split(
    list(id_to_url.items())[:PHOTOS_TO_USE], test_size=0.1, shuffle=False)
train_imgs, valid_imgs = train_test_split(
    train_imgs, test_size=0.1, shuffle=False)

print(f'train samples orig.:\t {len(train_imgs)}')
print(f'val. samples orig.:\t {len(valid_imgs)}')
print(f'test samples orig.:\t {len(test_imgs)}')
print('\n')

# ====================================================

data_path = DATA_PATH_FLICKR if USE_FLICKR else DATA_PATH_SUN
IMAGE_PATH = os.path.join(data_path, 'images')

# paths rejected in earlier runs are remembered across sessions
BLACKLIST = set()
blacklist_path = os.path.join(data_path, 'blacklist.pickle')
if os.path.isfile(blacklist_path):
  with open(blacklist_path, 'rb') as fp:
    initial_blacklist = pickle.load(fp)
  BLACKLIST.update(initial_blacklist)

start_time = time.time()
# NOTE(review): the validation split is augmented with the training processors too
train_procs = [proc_rc1, proc_rc2, proc_flip_x, proc_noise]
test_procs = []

u2p_train, i2u_train = download_process_images(
    train_imgs, IMAGE_PATH, train_procs, BLACKLIST, download=USE_FLICKR)
u2p_val, i2u_val = download_process_images(
    valid_imgs, IMAGE_PATH, train_procs, BLACKLIST, download=USE_FLICKR)
u2p_test, i2u_test = download_process_images(
    test_imgs, IMAGE_PATH, test_procs, BLACKLIST, download=USE_FLICKR)

# persist everything rejected so far for the next run
with open(blacklist_path, 'wb') as fp:
  pickle.dump(BLACKLIST, fp)

print(f'\ntrain samples aug.:\t {len(i2u_train)}')
print(f'val. samples aug.:\t {len(i2u_val)}')
print(f'test samples aug.:\t {len(i2u_test)}')

stat_str = 'data downloaded and processed' if USE_FLICKR else 'local data processed'
stat_str = f'\n{stat_str} \n\ntook {time.time()-start_time:.2f}s\n'

ring_bell(stat_str)

### Sanity check and cleanup

In [None]:
# zero-byte files are left behind by failed downloads / writes
bad_image_paths = [
    os.path.join(dirpath, fname)
    for dirpath, _, fnames in os.walk(IMAGE_PATH)
    for fname in fnames
    if os.path.getsize(os.path.join(dirpath, fname)) == 0
]
num_bad_images = len(bad_image_paths)

PURGE = True #@param{type:'boolean'}
if PURGE and num_bad_images:
  for bad_path in bad_image_paths:
    os.remove(bad_path)
  print('deleted bad images')
  num_bad_images = 0

# check only after the optional purge: asserting first made the purge unreachable
assert num_bad_images == 0, f'found {num_bad_images} bad images'

### Custom Dataset

In [None]:
class ColorizeDataset(Dataset):
  """
  Dataset that turns image files into luma inputs and chroma-category targets.

  Every image is read with OpenCV, optionally resized to `resize` x `resize`,
  scaled to [0, 1] and converted to YCrCb. Channel 0 (luma) is the network
  input; channels 1-2 (chroma) go through `cat_fn` when one is given.

  :param uid2path: mapping uid -> image path.
  :param id2uid: mapping dataset index -> uid; must have the same length as `uid2path`.
  :param cat_fn: optional callable turning the chroma array into category ids.
  :param is_test: when True, samples are `(luma, path)`; otherwise `(luma, chroma, path)`.
  :param resize: optional square side length to resize every image to.
  """

  def __init__(self, uid2path, id2uid, cat_fn = None, is_test = False, resize = None):
    super().__init__()
    assert len(uid2path) == len(id2uid), 'dataset maps length should match'

    self.uid2path, self.id2uid = uid2path, id2uid
    self.cat_fn, self.is_test, self.resize = cat_fn, is_test, resize

  def __getitem__(self, id):
    # unknown indices raise IndexError, which also ends plain iteration
    if id not in self.id2uid:
      raise IndexError

    path = self.uid2path[self.id2uid[id]]
    bgr = cv2.imread(path)
    assert bgr is not None

    if self.resize is not None:
      bgr = cv2.resize(bgr, (self.resize, self.resize), interpolation = cv2.INTER_LANCZOS4)

    # scale to [0, 1] float32 before converting BGR -> YCrCb
    ycrcb = cv2.cvtColor(np.float32(bgr * 1./255), cv2.COLOR_BGR2YCrCb)

    luma = torch.from_numpy(ycrcb[..., 0])
    chroma = None if self.cat_fn is None else torch.from_numpy(self.cat_fn(ycrcb[..., 1:]))

    # at test time only the grayscale input and the original's path are needed
    return (luma, path) if self.is_test else (luma, chroma, path)

  def __len__(self):
    return len(self.id2uid)

# Network architecture

Implementation based on the architecture described in https://arxiv.org/pdf/1811.03120.pdf

### Cells

In [None]:
class DownConvCell(nn.Module):
  """
  Encoder stage: two 3x3 conv + ReLU layers, 2x2 max-pool, then layer norm.

  Maps (B, ich, H, W) -> (B, och, H // 2, W // 2). The layer norm has no learnable
  parameters and normalizes each sample over its (C, H, W) values.
  """
  def __init__(self, ich, och):
    super().__init__()

    self.conv1 = nn.Conv2d(ich, och, 3, 1, 1)
    self.conv2 = nn.Conv2d(och, och, 3, 1, 1)

    self.nl = nn.ReLU()

    self.mp = nn.MaxPool2d(2, 2)

  # implement forward (not __call__) so nn.Module.__call__ still runs hooks
  def forward(self, x):
    x = self.nl(self.conv1(x))
    x = self.nl(self.conv2(x))

    x = self.mp(x)

    return F.layer_norm(x, x.shape[1:])

In [None]:
class UpConvCell(nn.Module):
  """
  Decoder stage: 4x4 stride-2 transposed conv (2x upsampling), two 3x3 conv
  layers, each followed by ReLU, then a parameter-free layer norm.

  Maps (B, ich, H, W) -> (B, och, 2H, 2W).
  """
  def __init__(self, ich, och):
    super().__init__()

    self.upconv = nn.ConvTranspose2d(ich, och, 4, 2, 1)
    self.conv1 = nn.Conv2d(och, och, 3, 1, 1)
    self.conv2 = nn.Conv2d(och, och, 3, 1, 1)

    self.nl = nn.ReLU()

  # implement forward (not __call__) so nn.Module.__call__ still runs hooks
  def forward(self, x):
    x = self.nl(self.upconv(x))
    x = self.nl(self.conv1(x))
    x = self.nl(self.conv2(x))

    return F.layer_norm(x, x.shape[1:])

In [None]:
class OutputCell(nn.Module):
  """
  Final stage: 2x transposed-conv upsampling and a 3x3 conv (both with ReLU),
  then a 3x3 conv producing `och` raw (un-normalized) scores per pixel.

  Maps (B, ich, H, W) -> (B, och, 2H, 2W).
  """
  def __init__(self, ich, och):
    super().__init__()

    self.upconv = nn.ConvTranspose2d(ich, ich, 4, 2, 1)
    self.conv1 = nn.Conv2d(ich, ich, 3, 1, 1)
    self.conv2 = nn.Conv2d(ich, och, 3, 1, 1)

    self.nl = nn.ReLU()

  # implement forward (not __call__) so nn.Module.__call__ still runs hooks
  def forward(self, x):
    x = self.nl(self.upconv(x))
    x = self.nl(self.conv1(x))

    return self.conv2(x)

### Model

In [None]:
class ColorUnet(nn.Module):
  """
  Small U-Net that predicts, for every pixel, scores over `num_cls` colour categories.

  Input: luma of shape (B, H, W) or (B, input_ch, H, W). H and W must be divisible
  by 8 (three 2x down-samplings) so the skip connections line up.
  Output: raw class scores of shape (B, H, W, num_cls).
  """
  def __init__(self, input_ch, num_cls):
    super().__init__()

    # keep the real channel count (it used to be hard-coded to 1)
    self.input_ch = input_ch
    self.num_cls = num_cls

    half = num_cls // 2
    self.downConv1 = DownConvCell(input_ch, half)
    self.downConv2 = DownConvCell(half, num_cls)
    self.downConv3 = DownConvCell(num_cls, 2 * num_cls)

    self.upConv1 = UpConvCell(2 * num_cls, num_cls)
    self.upConv2 = UpConvCell(2 * num_cls, half)

    # upConv2 output (half) concatenated with downConv1 output (half) gives
    # 2 * half channels, which is num_cls - 1 when num_cls is odd
    self.outCell = OutputCell(2 * half, num_cls)

  # implement forward (not __call__) so nn.Module.__call__ still runs hooks
  def forward(self, x):
    # (B, H, W) -> (B, 1, H, W), out-of-place: the old in-place unsqueeze_
    # modified the caller's tensor, so calling twice on it failed
    if x.dim() == 3:
      x = x.unsqueeze(1)

    # encoder; keep the first two stages' outputs for the skip connections
    x_1 = self.downConv1(x)
    x_2 = self.downConv2(x_1)
    x = self.downConv3(x_2)

    # decoder with channel-wise skip concatenations
    x = self.upConv1(x)
    x = torch.cat((x, x_2), dim=1)
    x = self.upConv2(x)
    x = torch.cat((x, x_1), dim=1)
    x = self.outCell(x)

    # (B, C, H, W) -> (B, H, W, C)
    return x.permute(0, 2, 3, 1)

# Train/test

### Helpers

In [None]:
def train_step(model, criterion, optimizer, loader, device):
  """
  Train `model` for one epoch over `loader`.

  Each batch is `(luma, chroma_cat, ...)`. The model output (B, H, W, C) is
  permuted to (B, C, H, W), the layout `criterion` expects, and compared with
  the integer category targets.

  :return: mean batch loss over the epoch.
  """
  epoch_loss = 0.
  n_iter = 0

  model.train()
  for data in tqdm(loader, desc='train', total=len(loader)):
    luma = data[0].float().to(device)
    chroma_cat = data[1].long().to(device)
    # targets should be (B, H, W): drop a singleton channel dim if present, but
    # never the batch dim (a bare .squeeze() breaks batches of size 1)
    if chroma_cat.dim() == 4:
      chroma_cat = chroma_cat.squeeze(1)

    # .grad accumulates across backward() calls: clear it for every batch
    optimizer.zero_grad()
    out = model(luma)
    loss = criterion(out.permute(0, 3, 1, 2), chroma_cat)
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()
    n_iter += 1

  return epoch_loss/n_iter

In [None]:
def valid_step(model, criterion, loader, device):
  """
  Evaluate `model` on `loader` without gradient tracking.

  Batches have the same layout as in `train_step`.

  :return: mean batch loss over the loader.
  """
  epoch_loss = 0.
  n_iter = 0

  model.eval()
  with torch.no_grad():
    for data in tqdm(loader, desc='validate', total=len(loader)):
      luma = data[0].float().to(device)
      chroma_cat = data[1].long().to(device)
      # keep the batch dim even for batches of size 1
      if chroma_cat.dim() == 4:
        chroma_cat = chroma_cat.squeeze(1)

      out = model(luma)

      # output is (B, H, W, C); the loss expects (B, C, H, W)
      loss = criterion(out.permute(0, 3, 1, 2), chroma_cat)

      epoch_loss += loss.item()
      n_iter += 1

  return epoch_loss/n_iter

In [None]:
def gen_model_name(ds_size, num_epochs, prefix='m', suffix=''):
  """
  Build a model name such as 'm-25e-2k-suffix'.

  The dataset size is written in thousands: with two decimals below 1000
  samples, rounded to an integer otherwise. `suffix` is appended only when
  it is non-empty.
  """
  decimals = 0 if ds_size >= 1000 else 2
  parts = [prefix, f'{num_epochs}e', f'{ds_size/1000:.{decimals}f}k']
  if len(suffix) > 0:
    parts.append(suffix)
  return '-'.join(parts)

## Dataloader init

In [None]:
BATCH_SIZE = 48

# shared DataLoader keyword arguments; pinned memory only pays off with a GPU
dl_args = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=3,
    pin_memory=CUDA,
    drop_last=True,
)

In [None]:
train_dataset = ColorizeDataset(u2p_train, i2u_train, cat_fn=px_to_uvcat)
train_loader = DataLoader(
    train_dataset,
    **dl_args
)

val_dataset = ColorizeDataset(u2p_val, i2u_val, cat_fn=px_to_uvcat)
# evaluate every validation sample in a fixed order: no shuffling and no dropped
# last batch (with drop_last, a val set smaller than one batch yields 0 batches
# and valid_step divides by zero)
val_loader = DataLoader(
    val_dataset,
    **{**dl_args, 'shuffle': False, 'drop_last': False}
)

test_dataset = ColorizeDataset(u2p_test, i2u_test, is_test=True, resize=SQ_SIZE)

## Sanity checks

In [None]:
# read the first test image directly with OpenCV, bypassing the dataset
path = u2p_test[i2u_test[0]]
image = cv2.imread(path)
print(f'read img from {path}')
print(image.shape, type(image), sep='\n')

In [None]:
# training samples are (luma, chroma categories, path) triples; unpacking only
# two values raised "too many values to unpack"
Ys, UV_cats, batch_paths = next(iter(train_loader))

print(Ys.shape)
print(UV_cats.shape)

# luma value range of the first sample
print(torch.min(Ys[0]).item())
print(torch.max(Ys[0]).item())

# category-id range of the first sample (expected within 0..NUM_CAT-1)
print(torch.min(UV_cats[0]).item())
print(torch.max(UV_cats[0]).item())

In [None]:
# Inspect the first test sample: luma shape, dtype, value range, then its source path.
luma, p = test_dataset[0]

for info in (luma.shape, luma.dtype, luma.min().item(), luma.max().item(), p):
  print(info)

In [None]:
# Shape check: run one test sample through a throw-away, untrained ColorUnet.
test_model = ColorUnet(1, 32)
test_model.eval()

luma, _ = test_dataset[0]

# out-of-place unsqueeze so the tensor returned by the dataset is not mutated
luma = luma.unsqueeze(0)
print(luma.shape)

# only the output shape is of interest, so no autograd graph is built
with torch.no_grad():
  y = test_model(luma)
print(y.shape)

del test_model

## Training

### Model init

In [None]:
# Model / optimizer / loss setup. The `#@param` lines are Colab form fields.
# Trained weights are saved under HOME/models (created if missing).
MODELS_PATH = os.path.join(HOME, 'models')
os.makedirs(MODELS_PATH, exist_ok = True)

# training hyper-parameters; log_every = print stats every N epochs
num_epochs =  25#@param{type:'integer'}
lr =  1e-5 #@param{type:'number'}
log_every = 1 #@param{type:'integer'}

# early stopping
# the training cell stops once `patience` bad epochs (val. loss more than
# `delta_tol` above the best so far) accumulate without a new best
stop_early = True #@param{type:'boolean'}
patience = 2 #@param{type:'integer'}
delta_tol =  1e-3#@param{type:'number'}

# model init
# input_ch: channels fed to ColorUnet (presumably 1 = the luma/Y channel — confirm with the dataset)
input_ch = 1 #@param{type:'integer'}
model = ColorUnet(input_ch, NUM_CAT).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr)

rebalance = False #@param{type:'boolean'}
# load weights for rebalancing
# per-class CrossEntropy weights are read from HOME/sun-w-<weights_fp>.npy
weights_fp = '100-6-0.4-small'#@param{type:'string'}

if rebalance and len(weights_fp):
  # NOTE(review): allow_pickle=True lets np.load unpickle (and so run) arbitrary code;
  # only load trusted files. A plain numeric array should load without it.
  weights = np.load(os.path.join(HOME, f'sun-w-{weights_fp}.npy'), allow_pickle=True)
  weights = torch.from_numpy(weights).float().to(DEVICE)
  criterion = nn.CrossEntropyLoss(weight=weights)
  print(f'using {weights_fp} weights for rebalancing')
else:
  criterion = nn.CrossEntropyLoss()

# name used for the saved weights and the prediction folders
suffix = 'sun-6cls-no-reb'#@param{type:'string'}
model_name = gen_model_name(len(train_dataset), num_epochs, suffix=suffix)
print(f'model name initial: {model_name}')

### Actual training

In [None]:
# Train for up to num_epochs with optional early stopping, then save the weights.
# restore_best: save the weights with the lowest val. loss (False = last epoch's weights)
restore_best = True #@param{type:'boolean'}

train_loss = []
val_loss = []

best_loss = float('inf')
best_state = None
num_bad_epochs = 0
for e in range(num_epochs):
  print(f'epoch {e+1}/{num_epochs}:\n')
  epoch_start = time.time()

  t_loss = train_step(model, criterion, optimizer, train_loader, DEVICE)
  v_loss = valid_step(model, criterion, val_loader, DEVICE)

  train_loss.append(t_loss)
  val_loss.append(v_loss)

  # A new best resets the patience counter; only an epoch worse than the best by
  # more than delta_tol counts as bad. Epochs within the tolerance are neutral.
  found_best_loss = v_loss <= best_loss
  if found_best_loss:
    best_loss = v_loss
    num_bad_epochs = 0
    # snapshot on CPU: later (worse) epochs keep updating `model` in place
    best_state = {name: tensor.detach().cpu().clone()
                  for name, tensor in model.state_dict().items()}
  elif v_loss - best_loss > delta_tol:
    num_bad_epochs += 1

  if e % log_every == 0:
    print(f'train. loss:\t{t_loss}')
    print(f'valid. loss:\t{v_loss}')
    if found_best_loss:
      print('encountered lowest val. loss')

    print(f'\ntook {(time.time() - epoch_start):.3f}s')
    print('---------------------------------\n')

  if stop_early and num_bad_epochs >= patience:
    print(f'early stopping after {patience} bad epochs. trained for {e+1} epochs')
    # rewrite only the first '<n>e-' (the epoch segment from gen_model_name)
    model_name = model_name.replace(f'{num_epochs}e-', f'{e+1}e-', 1)
    break

if restore_best and best_state is not None:
  model.load_state_dict(best_state)
  print(f'restored weights with lowest val. loss: {best_loss}')

model_path = os.path.join(MODELS_PATH, f'{model_name}.pt')
print(f'model will be saved at: {model_path}')
torch.save(model.state_dict(), model_path)

ring_bell('training over')

In [None]:
# Mean per-batch loss for each epoch of the run above.
epochs_axis = range(1, len(train_loss) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_axis, train_loss, label='train', color='red')
ax.plot(epochs_axis, val_loss, label='valid.', color='cyan')
ax.set(title=f'Loss per epoch: {model_name}', xlabel='epoch', ylabel='mean batch loss')
ax.legend();

## Inference

### Optional: load model state from disk

In [None]:
# Leave empty to keep the current `model_name` (e.g. the model trained above);
# fill in to load another checkpoint from MODELS_PATH. Overwriting model_name with ''
# would make the prediction cells write to (and purge) the shared parent folders.
load_model_name = ''#@param{type:'string'}
if load_model_name:
  model_name = load_model_name
m_path = os.path.join(MODELS_PATH, f'{model_name}.pt')

try:
  # map_location lets a checkpoint saved on GPU load on a CPU-only runtime
  model_state = torch.load(m_path, map_location=DEVICE)

  model.load_state_dict(model_state)
  print(f'loaded model state for predict: {m_path}')

except FileNotFoundError:
  model_state = None
  print(f'could not load model state from {m_path}')

### Prediction for TEST

In [None]:
# Colorize up to MAX_SAVED TEST images at several sampling temperatures and save
# the ground-truth thumbnail, grayscale input and predictions under PRED_PATH.
import shutil

VERBOSE = False #@param{type:'boolean'}

PRED_PATH = os.path.join(HOME, 'predict', model_name)
MAX_SAVED =  25#@param{type:'integer'}
# cap by the split iterated below (TEST)
MAX_SAVED = min(MAX_SAVED, len(test_dataset))

print(f'predicted images will be saved at {PRED_PATH}')
PURGE = True #@param{type:'boolean'}
if PURGE:
  # an empty model_name would make PRED_PATH the shared folder of every model
  assert model_name, 'model_name is empty: refusing to purge the shared prediction folder'
  shutil.rmtree(PRED_PATH, ignore_errors=True)
  print('deleted prev. predictions')
os.makedirs(PRED_PATH, exist_ok=True)

# temperatures passed to uv_px_from_z; one prediction image per value
temps = [0.05, 0.1, 0.3, 0.6, 1]
print('starting prediction...\n')

model.eval()
start_time = time.time()
n_done = 0
with torch.no_grad():
  # index directly so no sample beyond MAX_SAVED is ever loaded
  for i in tqdm(range(MAX_SAVED), desc='predict'):
    Y, p = test_dataset[i]

    Y = Y.float().to(DEVICE).unsqueeze(0)
    Z = torch.softmax(model(Y), -1).cpu().numpy()

    if VERBOSE:
      print('\nY')
      print(Y.min().item())
      print(Y.max().item())

      print('\nZ')
      print(np.min(Z))
      print(np.max(Z))

    orig_img = cv2.imread(p)
    assert orig_img is not None, f'Could not find original image at: {p}'
    # outputs are named after the source file and always written as .jpg
    stem = os.path.splitext(os.path.basename(p))[0]

    # save a half-size thumbnail of the ground truth image
    new_size = (orig_img.shape[1]//2, orig_img.shape[0]//2)
    orig_img_thumb = cv2.resize(orig_img, new_size)
    cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-gt-thumb.jpg'), orig_img_thumb)

    # normalize gt image to [0, 1] float32 prior to conversion from BGR
    orig_img = np.float32(orig_img * 1./255)
    # channels become (Y, Cr, Cb)
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2YCrCb)

    # keep the full-resolution luminance of the original image
    Y_orig = orig_img[..., 0]

    # save the grayscale image as well
    cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-gray.jpg'), Y_orig * 255)

    for t in temps:
      UV = uv_px_from_z(Z, t).squeeze()

      if VERBOSE:
        print(f'\ntemp: {t}')
        print('-------------')
        for ch, label in enumerate('UV'):
          print(f'\n{label}:')
          print(np.min(UV[..., ch]))
          print(np.max(UV[..., ch]))
        print('\nUV shape:', UV.shape)

      # upsample the predicted chroma to the original image size (cv2 takes (w, h))
      UV = cv2.resize(UV, Y_orig.shape[:2][::-1], interpolation = cv2.INTER_LANCZOS4)

      pred_img = np.concatenate((np.expand_dims(Y_orig, -1), UV), axis=-1)

      if VERBOSE:
        for ch, label in enumerate('YUV'):
          print(f'\n{label} after resize and concat')
          print(np.min(pred_img[..., ch]))
          print(np.max(pred_img[..., ch]))

      # cv2.imwrite expects BGR; an RGB array would be saved with red/blue swapped
      pred_img_cv = cv2.cvtColor(pred_img, cv2.COLOR_YCrCb2BGR)
      pred_img_cv = np.clip(pred_img_cv, 0, 1) * 255

      if VERBOSE:
        print('\n\n')
        for ch, label in enumerate('BGR'):
          print(f'{label} after bgr conversion')
          print(np.min(pred_img_cv[..., ch]))
          print(np.max(pred_img_cv[..., ch]))

      cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-pred-{t}.jpg'), pred_img_cv)

    n_done += 1

ring_bell(f'\nprediction done for {n_done} images')
print(f'took {(time.time() - start_time):.2f}s')

### Prediction for TRAIN

In [None]:
# Colorize up to MAX_SAVED TRAIN images (sanity check for fitting) and save the
# ground truth, grayscale input and predictions under PRED_PATH.
import shutil

PRED_PATH = os.path.join(HOME, 'predict-TRAIN', model_name)
MAX_SAVED =  25#@param{type:'integer'}
MAX_SAVED = min(MAX_SAVED, len(train_dataset))

print(f'predicted images will be saved at {PRED_PATH}')
PURGE = True #@param{type:'boolean'}
if PURGE:
  # an empty model_name would make PRED_PATH the shared folder of every model
  assert model_name, 'model_name is empty: refusing to purge the shared prediction folder'
  shutil.rmtree(PRED_PATH, ignore_errors=True)
  print('deleted prev. predictions')
os.makedirs(PRED_PATH, exist_ok=True)

temps = [0.3, 1]
print('starting prediction...\n')

model.eval()
start_time = time.time()
n_done = 0
with torch.no_grad():
  # index directly so no sample beyond MAX_SAVED is ever loaded
  for i in tqdm(range(MAX_SAVED), desc='predict-T'):
    # the second item (chroma target) is not needed for prediction
    Y, _, p = train_dataset[i]
    orig_img = cv2.imread(p)
    assert orig_img is not None, f'Could not find original image at: {p}'
    stem = os.path.splitext(os.path.basename(p))[0]

    # save the orig image (full size, despite the '-gt-thumb' file name)
    cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-gt-thumb.jpg'), orig_img)

    # normalize gt image to [0, 1] float32 prior to conversion from BGR
    orig_img = np.float32(orig_img * 1./255)
    # channels become (Y, Cr, Cb)
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2YCrCb)

    # keep the full-resolution luminance of the original image
    Y_orig = orig_img[..., 0]

    # save the grayscale image as well
    cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-gray.jpg'), Y_orig * 255)

    Y = Y.float().to(DEVICE).unsqueeze(0)
    Z = torch.softmax(model(Y), -1).cpu().numpy()

    for t in temps:
      UV = uv_px_from_z(Z, t).squeeze()
      # match the original resolution before stacking with Y_orig (as in the TEST cell);
      # a no-op copy when the sizes already agree
      UV = cv2.resize(UV, Y_orig.shape[:2][::-1], interpolation = cv2.INTER_LANCZOS4)

      pred_img = np.concatenate((np.expand_dims(Y_orig, -1), UV), axis=-1)

      # cv2.imwrite expects BGR; an RGB array would be saved with red/blue swapped
      pred_img_cv = cv2.cvtColor(pred_img, cv2.COLOR_YCrCb2BGR)
      pred_img_cv = np.clip(pred_img_cv, 0, 1) * 255

      cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-pred-{t}.jpg'), pred_img_cv)

    n_done += 1

ring_bell(f'\nprediction done for {n_done} images')
print(f'took {(time.time() - start_time):.2f}s')

### Prediction for VAL

In [None]:
# Colorize up to MAX_SAVED VAL images and save the ground truth, grayscale
# input and predictions under PRED_PATH.
import shutil

PRED_PATH = os.path.join(HOME, 'predict-VAL', model_name)
MAX_SAVED =  25#@param{type:'integer'}
# cap by the split iterated below (VAL)
MAX_SAVED = min(MAX_SAVED, len(val_dataset))

print(f'predicted images will be saved at {PRED_PATH}')
PURGE = True #@param{type:'boolean'}
if PURGE:
  # an empty model_name would make PRED_PATH the shared folder of every model
  assert model_name, 'model_name is empty: refusing to purge the shared prediction folder'
  shutil.rmtree(PRED_PATH, ignore_errors=True)
  print('deleted prev. predictions')
os.makedirs(PRED_PATH, exist_ok=True)

temps = [0.3, 1]
print('starting prediction...\n')

model.eval()
start_time = time.time()
n_done = 0
with torch.no_grad():
  # index directly so no sample beyond MAX_SAVED is ever loaded
  for i in tqdm(range(MAX_SAVED), desc='predict-V'):
    # the second item (chroma target) is not needed for prediction
    Y, _, p = val_dataset[i]
    orig_img = cv2.imread(p)
    assert orig_img is not None, f'Could not find original image at: {p}'
    stem = os.path.splitext(os.path.basename(p))[0]

    # save the orig image (full size, despite the '-gt-thumb' file name)
    cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-gt-thumb.jpg'), orig_img)

    # normalize gt image to [0, 1] float32 prior to conversion from BGR
    orig_img = np.float32(orig_img * 1./255)
    # channels become (Y, Cr, Cb)
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2YCrCb)

    # keep the full-resolution luminance of the original image
    Y_orig = orig_img[..., 0]

    # save the grayscale image as well
    cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-gray.jpg'), Y_orig * 255)

    Y = Y.float().to(DEVICE).unsqueeze(0)
    Z = torch.softmax(model(Y), -1).cpu().numpy()

    for t in temps:
      UV = uv_px_from_z(Z, t).squeeze()
      # match the original resolution before stacking with Y_orig (as in the TEST cell);
      # a no-op copy when the sizes already agree
      UV = cv2.resize(UV, Y_orig.shape[:2][::-1], interpolation = cv2.INTER_LANCZOS4)

      pred_img = np.concatenate((np.expand_dims(Y_orig, -1), UV), axis=-1)

      # cv2.imwrite expects BGR; an RGB array would be saved with red/blue swapped
      pred_img_cv = cv2.cvtColor(pred_img, cv2.COLOR_YCrCb2BGR)
      pred_img_cv = np.clip(pred_img_cv, 0, 1) * 255

      cv2.imwrite(os.path.join(PRED_PATH, f'{stem}-pred-{t}.jpg'), pred_img_cv)

    n_done += 1

ring_bell(f'\nprediction done for {n_done} images')
print(f'took {(time.time() - start_time):.2f}s')

## Clean up GPU memory

In [None]:
try:
  del model
  del optimizer
  del criterion
  
except NameError:
  pass

torch.cuda.empty_cache()
gc.collect()

! nvidia-smi | grep MiB