This notebook provides an minimal working example of the sky augmentation in the preprint paper "Castle in the Sky: Dynamic Sky Replacement and Harmonization in Videos, arXiv:2010.11800"

[Project Page](https://jiupinjia.github.io/skyar/) | [GitHub](https://github.com/jiupinjia/SkyAR) | [Preprint](https://arxiv.org/abs/2010.11800)

<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" /></a><span xmlns:dct="http://purl.org/dc/terms/" property="dct:title">   The project </a> is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import glob
import argparse
from networks import *
from skyboxengine import *
import utils
import torch

%matplotlib inline

# Decide which device we want to run on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Download pretrained sky matting model

In [2]:
# Define some helper functions for downloading pretrained model
# taken from this StackOverflow answer: https://stackoverflow.com/a/39225039
import requests

def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params = { 'id' : id }, stream = True)
    token = get_confirm_token(response)

    if token:
        params = { 'id' : id, 'confirm' : token }
        response = session.get(URL, params = params, stream = True)

    save_response_content(response, destination)    

def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None

def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk: # filter out keep-alive new chunks
                f.write(chunk)

In [3]:
# download and unzip...
file_id = '1COMROzwR4R_7mym6DL9LXhHQlJmJaV0J'
destination = './checkpoints_G_coord_resnet50.zip'
download_file_from_google_drive(file_id, destination)

Config your model...

In [3]:
args = utils.parse_config(path_to_json='./config/my_config.json')

In [4]:
class PhotoFilter:
  
  def __init__(self, args):

    self.ckptdir = args.ckptdir
    self.output_dir = args.output_dir
    self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
    self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h

    self.skyboxengine = SkyBox(args)

    self.net_G = define_G(netG=args.net_G).to(device)
    self.load_model()

    if os.path.exists(args.output_dir) is False:
        os.mkdir(args.output_dir)


  def load_model(self):
        checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'), map_location='cpu')
        self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
        self.net_G.to(device)
        self.net_G.eval()


  def synthesize(self, img_HD, img_HD_prev):

        h, w, c = img_HD.shape

        img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))

        img = np.array(img, dtype=np.float32)
        img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)

        
        print(img.shape)
        with torch.no_grad():
            G_pred = self.net_G(img.to(device))
            G_pred = torch.nn.functional.interpolate(G_pred, (h, w), mode='bicubic', align_corners=False)
            G_pred = G_pred[0, :].permute([1, 2, 0])
            G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
            G_pred = np.array(G_pred.detach().cpu())
            G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)

        skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
        syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)

        return syneth

    

  def cvtcolor_and_resize(self, img_HD):

        img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
        img_HD = np.array(img_HD / 255., dtype=np.float32)
        img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))

        return img_HD


  def process_img(self, datadir, background='jupiter.jpg', name='edited'):

        img = cv2.imread(datadir)
        x, y, _ = img.shape
        #self.skyboxengine.args.out_size_w = y
        #self.skyboxengine.args.out_size_h = x
        #self.out_size_w, self.out_size_h = y, x
        self.skyboxengine.args.skybox = background

        img_HD = self.cvtcolor_and_resize(img)
        img_HD_prev = img_HD
        
        syneth = self.synthesize(img_HD, img_HD_prev)

        img_res  = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)

        #print(x, y, img_res.shape)
        cv2.imwrite(self.output_dir + '/' + name + '.jpg', cv2.resize(img_res, (y, x)))
    


In [5]:
pf = PhotoFilter(args)

initialize network with normal


In [6]:
pf.process_img("./photos/2.jpg", name='2', background='jupiter.jpg')

torch.Size([1, 3, 384, 384])
initialize skybox...


In [29]:
args.skybox = "jupiter.jpg"
pf.process_img("./photos/3.jpg")

In [7]:
import tensorflow as tf
import onnx
from onnx_tf.backend import prepare

In [7]:
checkpoint = torch.load(os.path.join("./checkpoints_G_coord_resnet50", 'best_ckpt.pt'), map_location='cpu')

net = define_G(netG=args.net_G).to(device)
net.load_state_dict(checkpoint['model_G_state_dict'])

    
dummy_input = torch.rand(1, 3, 384, 384)

torch.onnx.export(net,               # model being run
                  dummy_input,                         # model input (or a tuple for multiple inputs)
                  './model.onnx',            # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'])
    

initialize network with normal


In [13]:
onnx_model = onnx.load('./model.onnx')


onnx.checker.check_model(onnx_model)



# Export model as .pb file
#tf_rep.export_graph('./model.pb')