# Relighting a real image with user selection
This notebook loads an interactive demo of our user-selective relighting method.

In [None]:
import os, sys, inspect

# Put the parent of this notebook's folder at the front of sys.path, presumably
# so the project packages imported below (options, data, models) resolve.
try:
    # Running as a script (e.g. the notebook exported to a .py file).
    currentdir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Running in Jupyter: there is no __file__, and recent ipykernel versions
    # report each cell as a temporary file (e.g. /tmp/ipykernel_<pid>/<hash>.py),
    # so inspecting the current frame does not locate the notebook. Jupyter
    # starts the kernel in the notebook's directory by default, so use the cwd.
    currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model

from importlib import reload
from rewrite_utils import zdataset, show, labwidget, paintwidget, renormalize, nethook, imgviz, pbar, smoothing
from PIL import Image
import numpy as np
import torch

First, load the relighting model. Its pretrained weights are downloaded into the appropriate folder by running `setup.sh`.


In [None]:
# Load the pretrained "selective" relighting model. The options are given as an
# argv-style list, exactly as if they had been passed on the command line.
cmd = [
    '--name', 'selective',
    '--netG', 'modulated',
    '--input_nc', '3',
    '--label_nc', '0',
    '--dataroot', '/datasets/lsun_bedrooms/',
    '--which_epoch', '200',
]

opt = TestOptions().parse(save=False, cmd=cmd)
model = create_model(opt)

Load the image to be relit (set `img_path` below to use your own image).

In [None]:
# Load the image to relight and convert it into the model's input tensor.
import torchvision.transforms as transforms
from data.base_dataset import __scale_width

# Named `preprocess` (not `transforms`) so it does not shadow the torchvision
# module imported above. Pipeline: rescale via __scale_width (presumably to
# width opt.loadSize), convert to a CHW float tensor in [0, 1], then
# normalize each channel to [-1, 1].
preprocess = transforms.Compose([
    transforms.Lambda(lambda img: __scale_width(img, opt.loadSize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Insert the path to your image here.
img_path = 'test_images/lsunbedroom.webp'

dims = (256, 256)
# Convert to RGB *before* resizing: PIL silently falls back to nearest-neighbour
# resampling for palette ('P') images, and the network expects 3 channels anyway.
image = Image.open(img_path).convert('RGB').resize(dims, Image.BILINEAR)
show(image)
baseline = preprocess(image)

Paint over the region containing the light source you would like to relight. Click **reset** to clear the mask and start over.

In [None]:
def do_reset():
    """Clear the painted mask so the selection can be redrawn from scratch."""
    pw.mask = ''

# Canvas showing the input image; the user paints the light-source mask on it.
pw = paintwidget.PaintWidget(image=renormalize.as_url(baseline))
reset_btn = labwidget.Button('reset').on('click', do_reset)
show([pw, reset_btn])

In [None]:
# Read the painted mask and threshold it into a hard 0/1 selection.
# An empty string is what the reset button stores, i.e. nothing painted.
if not pw.mask:
    raise ValueError('No mask painted: paint over the light source in the widget above, then rerun this cell.')
painted = renormalize.from_url(pw.mask, target='pt', size=baseline.shape[1:])[0]
# Non-mutating threshold keeps the original dtype and makes the cell idempotent.
mask = (painted > 0.5).to(painted.dtype)
show(renormalize.as_image(mask[None]))

In [None]:
# Blur the hard mask: a soft falloff at the selection boundary makes the edit
# blend into the rest of the scene instead of ending at a visible seam.
sigma = 1024 / 16.0               # Gaussian standard deviation (64.0)
kernel_size = 2 * int(sigma) - 1  # odd kernel width (127)
# NOTE(review): sigma is a fixed constant rather than derived from the image
# size -- confirm this is intended for inputs at other resolutions.
blur = smoothing.GaussianSmoothing(1, kernel_size, sigma=sigma)
# Add two leading singleton axes, presumably (N, C, H, W) for the smoother.
mask_ = blur(mask.unsqueeze(0).unsqueeze(0))
# Preview the blurred mask as a grayscale (channel-repeated) RGB image.
show(renormalize.as_image(mask_[0].repeat(3, 1, 1)))

In [None]:
# Visualize the selection on the original image: dim pixels under the soft
# mask and tint blue every pixel the blurred mask reaches.
# Flatten the blurred mask to (H, W) so it broadcasts channel-wise against the
# (3, H, W) image; any leading singleton axes it carries (e.g. (1, 1, H, W))
# would otherwise broadcast the result into a higher-rank tensor. reshape also
# fails loudly if the blur changed the spatial size.
mask_hw = mask_.detach().reshape(baseline.shape[1:])
tint = torch.zeros_like(baseline)
tint[2][mask_hw > 0] = torch.max(baseline)  # blue channel only
overlay = baseline * (1 - mask_hw) + tint
show(renormalize.as_image(overlay))

Control the relighting intensity with the slider and the light color with the color picker. The widgets were originally implemented in [davidbau/rewriting](https://github.com/davidbau/rewriting).

In [None]:
# Widgets for controlling the selected light source: an intensity slider,
# a color picker for the light's tint, and an image that displays the result.
# labwidget is intentionally not reloaded here: reloading would redefine its
# classes while pw and reset_btn (created earlier) still use the old ones.
lamp = labwidget.Range()
im = labwidget.Image()
lc = labwidget.ColorPicker('#ffffff', desc='lamp light color: ')

show([['lamp intensity', lamp, lc, im]])


def get_lit_scene(image, frac):
    """Relight the selected light source in a single image.

    Args:
        image: an unbatched image tensor such as `baseline`; a leading batch
            axis is added before calling the model.
        frac: relighting amount, forwarded unchanged as `amount=`.

    Returns:
        The output of model.inference for the batched image and the blurred
        selection mask `mask_`.
    """
    batched = image.unsqueeze(0)
    return model.inference(batched, mask_, amount=frac)

# Show the unmodified image until the first widget change triggers a re-render.
im.render(renormalize.as_image(baseline))

def readcolor(value):
    """Parse a hex color string into an RGB float tensor with values in [0, 1].

    Accepts '#rrggbb' (as emitted by the color picker), the CSS shorthand
    '#rgb', and '#rrggbbaa' (alpha is ignored). The leading '#' is optional.

    Args:
        value: the color string, e.g. ``lc.value``.

    Returns:
        A float32 tensor of shape (3,). Falls back to white ([1., 1., 1.]) when
        `value` is not a valid hex color, so a malformed value never breaks
        the live preview.
    """
    white = torch.tensor([1.0, 1.0, 1.0]).float()
    if not isinstance(value, str):
        return white
    digits = value.strip()
    if digits.startswith('#'):
        digits = digits[1:]
    if len(digits) == 3:
        # CSS shorthand: '#f80' means '#ff8800'.
        digits = ''.join(ch * 2 for ch in digits)
    # Validate explicitly instead of relying on int() to fail: int() would
    # accept signs and whitespace inside a two-character slice.
    hex_chars = set('0123456789abcdefABCDEF')
    if len(digits) not in (6, 8) or not set(digits) <= hex_chars:
        return white
    channels = [int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4)]
    return torch.tensor(channels).float()

def newimage():
    """Re-render the relit image from the current slider and color values.

    Takes no arguments; it is registered as the widgets' 'value' callback.
    """
    # Map the slider value to [-1, 1] (assumes a 0..100 slider -- confirm
    # against labwidget.Range); passed to the model as a 1-element array.
    amount = np.array([(float(v) * 2 - 100) / 100.0 for v in [lamp.value]])
    relit = get_lit_scene(baseline, amount).cpu()
    # The model's change to the image is the lamp's light; tint it per channel
    # with the picked color and add it back onto the original.
    tint = readcolor(lc.value)[:, None, None]
    result = baseline + (relit - baseline) * tint
    im.render(renormalize.as_image(result[0]))

# Re-render whenever the intensity slider or the color picker changes.
# (A for-loop as the last statement leaves no cell output to suppress.)
for control in (lamp, lc):
    control.on('value', newimage)