In [None]:
import torch
import os
from matplotlib import pyplot as plt
import time
from PIL import Image
from Model.triplanelite import triplane_fea
from Dataloader.LLFF import LLFFDataset
from Processing.rendering import novel_views_LLFF, creat_video, render_img
from Processing.vis import load_settings, calc_query_emb, calc_feature_dist
from Processing.editing import get_comb_img, mkdir_ifnoexit, save_img, edit_color_list
from Processing.trainer import SimpleSampler

In [None]:
# 0. general settings
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
scene = 'flower'
load_features = False
show_selection = False  # set True to also render the with/without-selection videos in section 1
# Image downsampling factor; defined before the datasets so both splits use the same value.
downsample = 8

# 1. dataset settings
datadir = os.path.join('./Dataset/nerf_llff_data', scene)
fea_dir = None
LLFF_training = LLFFDataset(datadir, fea_path=fea_dir, split='train',
                            load_features=load_features, downsample=downsample)
LLFF_test = LLFFDataset(datadir, fea_path=fea_dir, split='test',
                        load_features=load_features, downsample=downsample)
# NOTE(review): 4096 is presumably the number of rays sampled per training batch — confirm in SimpleSampler.
trainingSampler = SimpleSampler(len(LLFF_training), 4096)

# 2. model settings
pretrained_model_path = './pre_trained_models/' + scene + '.pth'
# NOTE(review): presumably the scene bounding-box limits used by the triplanes — confirm in triplane_fea.
aabb = torch.tensor([-1.7, 1.7])
nerf_model = triplane_fea(aabb=aabb)
# torch.load unpickles the checkpoint — only load model files from a trusted source.
nerf_model.load_state_dict(torch.load(pretrained_model_path, map_location=torch.device(device)))
nerf_model.to(device)

# 1. Query a patch in a rendered view

In [None]:
# Render test image 0 with the pre-trained model. Besides the RGB image we keep the
# feature embeddings (used for patch queries below), depth and mask.
# NOTE(review): hn / hf / nb_bins are presumably near bound, far bound and samples per ray.
rgb_flower, emb_flower, depth_flower, mask_flower = render_img(
    nerf_model=nerf_model,
    device=device,
    Dataset=LLFF_test,
    img_index=0,
    hn=0,
    hf=1,
    nb_bins=96,
    req_others=True,
)

In [None]:
# Select a query patch from the rendered image and compute the feature distance of
# every rendered embedding to that query.
# Modified from the N3F official code, https://github.com/dichotomies/N3F
settings = load_settings()[scene]
factor = downsample
# NOTE(review): the per-scene patch position ('rc') and size ('sz') appear to be given
# for downsample factor 8 (the N3F convention) — confirm against load_settings. They are
# rescaled here to the downsample factor actually used for rendering.
SETTINGS_DOWNSAMPLE = 8
r, c = settings['rc']
extent = settings['sz']
r = int(r * SETTINGS_DOWNSAMPLE / factor)
c = int(c * SETTINGS_DOWNSAMPLE / factor)
extent = int(extent * SETTINGS_DOWNSAMPLE / factor)
img_w, img_h = LLFF_test.img_wh
embq, dir_q = calc_query_emb(emb_flower, r, c, extent, rgb=rgb_flower, vis=True)
dist = calc_feature_dist(embq, emb_flower)

# Distribution of distances to the query; useful to sanity-check settings['thr'],
# the selection threshold used in the following cells.
fig, ax = plt.subplots(figsize=(4, 3))
ax.hist(dist.view(-1).cpu().numpy(), bins=20, density=True, alpha=0.5,
        label='Distilled Triplanes')
ax.set(xlabel='feature distance to query patch', ylabel='density')
ax.legend()
plt.show()
plt.close(fig)

In [None]:
# Show an example selection: render the same test view using the distance threshold
# thr + margin against the query embedding, then plot RGB, depth and mask side by side.
# Modified from the N3F official code, https://github.com/dichotomies/N3F
selection_thr = settings['thr'] + settings['margin']
rgb_j_fg, emb_j_fg, depth_j_fg, mask_j_fg = render_img(
    nerf_model=nerf_model, device=device, Dataset=LLFF_test, img_index=0,
    hn=0, hf=1, nb_bins=96, req_others=True, embq=embq,
    dis_thr=selection_thr, foreground=False, show_selection=True)

fig, axes = plt.subplots(1, 3, figsize=(20, 8))
panels = [
    (rgb_j_fg, None, 'RGB (with selection)'),
    (depth_j_fg, 'gray', 'depth'),
    (mask_j_fg, 'gray', 'selection mask'),
]
for ax, (image, cmap, title) in zip(axes, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Optional: render novel-view videos of the test trajectory without and with the
# feature-based selection. Enable by setting show_selection = True in the settings cell.
if show_selection:
    foldername = time.strftime("%Y%m%d-%H%M%S")
    folderpath = os.path.join('./render_results', foldername)
    # makedirs also creates ./render_results; exist_ok avoids crashing on a quick re-run
    os.makedirs(folderpath, exist_ok=True)

    # One entry per video: output name -> extra keyword arguments for novel_views_LLFF.
    render_configs = {
        'wo_selection': {},
        'w_selection': dict(dis_thr=settings['thr'] + settings['margin'],
                            embq=embq, dist_less=False, show_selection=True),
    }
    for name, extra_kwargs in render_configs.items():
        novel_views_path = novel_views_LLFF(folderpath, name, nerf_model, device, LLFF_test,
                                            hn=0, hf=1, nb_bins=96, req_others=True,
                                            **extra_kwargs)
        creat_video(novel_views_path, folderpath, name, req_others=False)

# 2. Save and edit 3D-aware image context

In [None]:
# Switch save_data to True to render and store your own images for editing.
save_data = False
save_data_path = f'./Dataset/edited_imgs/{scene}'

In [None]:
# Render the combined training images (the *_4 outputs of get_comb_img) without and
# with the selection, then save them together with the returned ray tensors. These
# files are the inputs for manual 2D editing.
if save_data:
    # save_img / torch.save need the target folder to exist on a fresh checkout
    os.makedirs(save_data_path, exist_ok=True)
    selection_thr = settings['thr'] + settings['margin']

    gt_4, rgb_4, depth_4, mask_4, ray_4 = get_comb_img(nerf_model, LLFF_training, device)
    gt_4_s, rgb_4_s, depth_4_s, mask_4_s, ray_4_s = get_comb_img(
        nerf_model, LLFF_training, device,
        embq=embq, dis_thr=selection_thr, show_selection=True)

    # file stem -> image; the mask is only saved for the selection render
    images_to_save = {
        'gt_4_s': gt_4_s, 'rgb_4_s': rgb_4_s, 'depth_4_s': depth_4_s, 'mask_4_s': mask_4_s,
        'gt_4': gt_4, 'rgb_4': rgb_4, 'depth_4': depth_4,
    }
    for stem, image in images_to_save.items():
        save_img(os.path.join(save_data_path, stem + '.png'), image)

    torch.save(ray_4_s, os.path.join(save_data_path, 'ray_4_s.pt'))
    torch.save(ray_4, os.path.join(save_data_path, 'ray_4.pt'))


Then you can use any image-editing tool to edit **rgb_4.png** and save the edited image to a folder (e.g. `color_changed/` inside `save_data_path`). Load an edited image as follows.

In [None]:
# Load one hand-edited version of rgb_4.png stored in the color_changed/ folder.
edited_example_file = os.path.join(save_data_path, 'color_changed', 'rgb_4_highblue.png')
example_edited_rgb = Image.open(edited_example_file)

In [None]:
# Last expression of the cell, so Jupyter renders the PIL image inline.
example_edited_rgb

# 3. Edit Appearance 

In [None]:
# Create a timestamped folder (under ./editing_results) for this section's editing results.
results_root = "./editing_results"
foldername = time.strftime("%Y%m%d-%H%M%S")
savepath_father = os.path.join(results_root, foldername)
for required_dir in (results_root, savepath_father):
    mkdir_ifnoexit(required_dir)

In [None]:
# Editing inputs — example folders of hand-edited images inside save_data_path.
edit_datapath_mlp1, edit_datapath_mlp2 = (
    os.path.join(save_data_path, sub) for sub in ('color_changed', 'others')
)
# True  -> 36KB setting phi_edit
# False -> 4KB setting phi_edit_{c2c}
isf2c_noc2c = True
# Single file in edit_datapath_mlp1 to edit; set to None to edit every image in the folder.
edit_one_image = 'rgb_4_highblue.png'

In [None]:
# Learn the appearance edit from the image(s) in edit_datapath_mlp1, writing outputs to
# savepath_father. The result is a dict of models (iterated as name -> model in section 4).
c_model_dict1 = edit_color_list(
    edit_datapath_mlp1,
    save_data_path=save_data_path,
    nerf_model=nerf_model,
    savepath_father=savepath_father,
    device=device,
    LLFF_test=LLFF_test,
    settings=settings,
    embq=embq,
    isf2c_noc2c=isf2c_noc2c,
    edit_one_image=edit_one_image,
)

# 4. Layered Editing

In [None]:
# Create a separate timestamped folder (under ./editing_results) for the layered-editing results.
editing_results_dir = "./editing_results"
foldername = time.strftime("%Y%m%d-%H%M%S")
savepath_father = os.path.join(editing_results_dir, foldername)
mkdir_ifnoexit(editing_results_dir)
mkdir_ifnoexit(savepath_father)

In [None]:
# Layered-editing inputs: layer 1 is taken from color_changed/, layer 2 from others/.
edit_datapath_mlp1, edit_datapath_mlp2 = (
    os.path.join(save_data_path, subdir) for subdir in ('color_changed', 'others')
)
# True -> 36KB setting phi_edit; False -> 4KB setting phi_edit_{c2c}
isf2c_noc2c = True
# One example edited image per layer (file names inside the two folders above).
edit_layer1_image, edit_layer2_image = 'rgb_4_orange.png', 'rgb_4_tone.png'

In [None]:
# Arguments shared by both layers; only the input folder and the edited image differ.
shared_edit_kwargs = dict(
    save_data_path=save_data_path,
    nerf_model=nerf_model,
    savepath_father=savepath_father,
    device=device,
    LLFF_test=LLFF_test,
    settings=settings,
    embq=embq,
    isf2c_noc2c=isf2c_noc2c,
)
c_model_dict1 = edit_color_list(edit_datapath_mlp1, edit_one_image=edit_layer1_image, **shared_edit_kwargs)
c_model_dict2 = edit_color_list(edit_datapath_mlp2, edit_one_image=edit_layer2_image, **shared_edit_kwargs)



In [None]:
# For every pair (layer-1 model, layer-2 model), render the test trajectory with both
# models passed to novel_views_LLFF, then write one video per pair.
selection_thr = settings['thr'] + settings['margin']
# isf2c_noc2c selects the keyword under which the model list is passed.
models_kwarg = 'f2c_models' if isf2c_noc2c else 'c2c_models'
for name1, c_model1 in c_model_dict1.items():
    for name2, c_model2 in c_model_dict2.items():
        mix_model = [c_model1, c_model2]
        name = name1 + '_+_' + name2
        print(name)
        novel_views_path = novel_views_LLFF(
            savepath_father, name, nerf_model, device, LLFF_test, hn=0, hf=1, nb_bins=96,
            dis_thr=selection_thr, embq=embq, dist_less=True, req_others=False,
            **{models_kwarg: mix_model})
        creat_video(novel_views_path, savepath_father, name, req_others=False)