# 0. Settings

In [None]:
from matplotlib import pyplot as plt
import time
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers import DDIMScheduler
from transformers import pipeline
from Model.triplanelite import triplane_fea
from Dataloader.LLFF import LLFFDataset
from Processing.trainer import *
from Processing.rendering import novel_views_LLFF, creat_video, render_img
from Processing.vis import load_settings, calc_query_emb, calc_feature_dist
from Processing.editing import get_comb_img, save_history_fig, cp, \
    make_inpaint_condition, rgb2canny, mkdir_ifnoexit, context_iter

In [None]:
# 0. general settings
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
scene = 'flower'
load_features = False  # forwarded to LLFFDataset(load_features=...)
show_selection = False  # set True to render the selection videos in section 1

# 1. dataset settings
# The downsample factor is defined once and shared by both splits and by the
# patch-query rescaling in section 1 (factor = downsample), so they cannot drift apart.
downsample = 8
datadir = os.path.join('./Dataset/nerf_llff_data', scene)
fea_dir = None
LLFF_training = LLFFDataset(datadir, fea_path=fea_dir, split='train',
                            load_features=load_features, downsample=downsample)
LLFF_test = LLFFDataset(datadir, fea_path=fea_dir, split='test',
                        load_features=load_features, downsample=downsample)

# 2. model settings
pretrained_model_path = './pre_trained_models/' + scene + '.pth'
aabb = torch.tensor([-1.7, 1.7])  # bounding range handed to triplane_fea (presumably the scene bounds)
nerf_model = triplane_fea(aabb=aabb)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
nerf_model.load_state_dict(torch.load(pretrained_model_path, map_location=torch.device(device)))
nerf_model.to(device)

# 3. other settings
# transformers' default depth-estimation pipeline; passed to context_iter in section 3.
depth_estimator = pipeline('depth-estimation')

# 1. Query a patch and select the matching region in the rendered view

In [None]:
# Render view 0 of the test split; the result is unpacked as
# rgb / feature embedding / depth / mask (hn, hf presumably near/far bounds).
rgb_flower, emb_flower, depth_flower, mask_flower = render_img(
    nerf_model=nerf_model,
    device=device,
    Dataset=LLFF_test,
    img_index=0,
    hn=0,
    hf=1,
    nb_bins=96,
    req_others=True,
)

In [None]:
# select a patch from the rendered image and calculate the feature distance
# modified from the N3F official code, https://github.com/dichotomies/N3F
settings = load_settings()[scene]
factor = downsample
# The stored query patch (row/col 'rc' and size 'sz') is rescaled by
# settings_downsample / factor, i.e. the stored values appear to be given
# for a downsample factor of 8.
settings_downsample = 8
r, c = settings['rc']
extent = settings['sz']
r = int(r * settings_downsample / factor)
c = int(c * settings_downsample / factor)
extent = int(extent * settings_downsample / factor)
img_w, img_h = LLFF_test.img_wh
embq, dir_q = calc_query_emb(emb_flower, r, c, extent, rgb=rgb_flower, vis=True)
# presumably: distance of every rendered feature to the query embedding
dist = calc_feature_dist(embq, emb_flower)
plt.figure(figsize=(4, 3))
plt.hist(dist.view(-1).cpu().numpy(), bins=20, density=True, alpha=0.5, label='Distilled Triplanes')
plt.legend()  # without a legend the histogram label is never displayed
plt.show()
plt.close()

In [None]:
# Show example: threshold the feature distance (thr + margin) to get the mask of the selected region.
# Modified from the N3F official code, https://github.com/dichotomies/N3F
rgb_j_fg, emb_j_fg, depth_j_fg, mask_j_fg = render_img(
    nerf_model=nerf_model, device=device, Dataset=LLFF_test, img_index=0,
    hn=0, hf=1, nb_bins=96, req_others=True, embq=embq,
    dis_thr=settings['thr'] + settings['margin'],
    foreground=False, show_selection=True,
)
# One panel each for colour, depth and selection mask (cmap None = matplotlib default).
panels = [(rgb_j_fg, None), (depth_j_fg, 'gray'), (mask_j_fg, 'gray')]
plt.figure(figsize=(20, 8))
for pos, (panel, panel_cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    plt.imshow(panel, cmap=panel_cmap)
plt.show()

In [None]:
# if you would like to see selection rendering results (videos) you can set show_selection to True:
if show_selection:
    # Each run gets its own timestamped folder under ./render_results.
    foldername = time.strftime("%Y%m%d-%H%M%S")
    folderpath = os.path.join('./render_results', foldername)
    # makedirs also creates ./render_results and, unlike os.mkdir, does not
    # fail if the run folder already exists (e.g. two runs within the same second).
    os.makedirs(folderpath, exist_ok=True)

    # Novel views of the scene without any selection applied.
    name = 'wo_selection'
    novel_views_path = novel_views_LLFF(folderpath, name, nerf_model, device, LLFF_test, hn=0, hf=1, nb_bins=96,
                                        req_others=True)
    creat_video(novel_views_path, folderpath, name, req_others=False)
    # Novel views with the queried selection (embq / dis_thr) applied.
    name = 'w_selection'
    novel_views_path = novel_views_LLFF(folderpath, name, nerf_model, device, LLFF_test, hn=0, hf=1, nb_bins=96,
                                        req_others=True, dis_thr=settings['thr'] + settings['margin'],
                                        embq=embq, dist_less=False, show_selection=True)
    creat_video(novel_views_path, folderpath, name, req_others=False)

# 2. Preview 3D-aware image-context editing

## 2.1 Prepare 3D-aware image context

In [None]:
# Create a timestamped folder under ./editing_results for this run's outputs.
foldername = time.strftime("%Y%m%d-%H%M%S")
savepath_father = os.path.join("./editing_results", foldername)
for _target_dir in ("./editing_results", savepath_father):
    mkdir_ifnoexit(_target_dir)  # parent first, then the run folder

In [None]:
# get multiple contexts
W, H = LLFF_training.img_wh
# The training set appears to hold one entry per ray, hence images = len / (W * H).
n_imgs = len(LLFF_training) // W // H
need_imgs = 8  # as described in the paper, we need 8 rendered images to edit.
max_iter = need_imgs // 2 - 1  # number of 4-view contexts built below
indices = np.random.choice(n_imgs, need_imgs, replace=False)
# Context 0 is the first 4 sampled views. Every later context re-uses the first
# 2 views (repeat_indices) and adds 2 fresh ones, giving max_iter * 4 indices in total.
repeat_indices = indices[:2]
contexts = [indices[:4]]
for i in range(1, max_iter):
    contexts.append(np.concatenate([repeat_indices, indices[2 + 2 * i: 4 + 2 * i]]))
# Use a 64-bit integer type: int16 storage could overflow for large datasets
# or in downstream index arithmetic.
indices = np.concatenate(contexts).astype(np.int64)

selection_thr = settings['thr'] + settings['margin']

mask_4_s_list = []
gt_4_s_list = []
rgb_4_s_list = []
depth_4_s_list = []
ray_4_s_list = []
gt_4_list = []
rgb_4_list = []
depth_4_list = []
ray_4_list = []

for edit_iter in range(max_iter):
    view_ids = indices[edit_iter * 4: (edit_iter + 1) * 4]
    # Render the 4 views as-is, and again with the queried selection (embq / dis_thr) applied.
    gt_4, rgb_4, depth_4, mask_4, ray_4 = get_comb_img(nerf_model, LLFF_training, device, view_ids)
    gt_4_s, rgb_4_s, depth_4_s, mask_4_s, ray_4_s = get_comb_img(nerf_model, LLFF_training, device, view_ids,
                                                                 embq=embq, dis_thr=selection_thr,
                                                                 show_selection=True)

    mask_4_s_list.append(mask_4_s)
    gt_4_s_list.append(gt_4_s)
    rgb_4_s_list.append(rgb_4_s)
    depth_4_s_list.append(depth_4_s)
    ray_4_s_list.append(ray_4_s)
    gt_4_list.append(gt_4)
    rgb_4_list.append(rgb_4)
    depth_4_list.append(depth_4)
    ray_4_list.append(ray_4)

## 2.2 Load the generative image-editing method (ControlNet)

In [None]:
strength = 0.9  # how strongly the diffusion pipeline may change the masked region (pipeline `strength`, in [0, 1])
# True: drop the Canny-edge ControlNet and its edge condition, so the object's
# shape is less constrained and the edit can change it more.
wo_edge = False
prompt = "purple flower made of origami paper, high quality, photo realistic"  # have fun changing it; more examples below
# Fail fast here, before the (slow) ControlNet / Stable Diffusion weights are loaded.
if not 0.0 <= strength <= 1.0:
    raise ValueError(f"strength must be in [0, 1], got {strength}")
#--------------------------------------------------------------- prompt examples ---------------------------------------------------------------
# prompt = "flower made of 24k gold, shiny, high quality, photo realistic"
# prompt = "pink zinnia flower made of origami paper, high quality, photo realistic"
# prompt = "flower made of paper, high quality, photo realistic"
# prompt = "pink flower made of origami paper, high quality, photo realistic"
# prompt = "orange zinnia flower with green pistil"
# prompt = "purple flower made of origami paper, high quality, photo realistic"
# prompt = "fire dragon, glowing lava"
# prompt = "paper dragon"
# prompt = "glowing ice dragon, blue"
# prompt = "snow weather"
# prompt = "vibrant green fern"
# prompt = "horns of a dragon made of 24k shiny gold"

In [None]:
# load the generative image editor [ControlNet]
# One ControlNet per conditioning signal. Their order must match the order of
# `control_image` passed to the pipeline: inpaint, depth[, canny].
controlnet_checkpoints = [
    "lllyasviel/control_v11p_sd15_inpaint",
    "lllyasviel/control_v11f1p_sd15_depth",
]
if not wo_edge:
    controlnet_checkpoints.append("lllyasviel/control_v11p_sd15_canny")
controlnet = [
    ControlNetModel.from_pretrained(ckpt, torch_dtype=torch.float16)
    for ckpt in controlnet_checkpoints
]

# NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo may no longer be
# available; if loading fails, try the "stable-diffusion-v1-5/stable-diffusion-v1-5" mirror.
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# Keep sub-models on CPU and move each to the GPU only while it runs (saves GPU memory).
pipe.enable_model_cpu_offload()
generator = torch.Generator(device=device).manual_seed(1)  # fixed seed for reproducible edits

In [None]:
# show editing preview, if you are not satisfied with the result, you can change the prompt and run the following code again
rgb_0 = rgb_4_list[0]
mask_0 = mask_4_s_list[0]
depth_0 = depth_4_list[0]
ray_0 = ray_4_list[0]

# canny_0 is also passed to context_iter later, so it is computed even when wo_edge is set.
canny_0 = Image.fromarray(np.uint8(rgb2canny(cp(rgb_0))))
# Control images follow the ControlNet order: inpaint, depth[, canny].
control_image = [make_inpaint_condition(cp(rgb_0), cp(mask_0)), cp(depth_0)]
if not wo_edge:
    control_image.append(canny_0)

edited_0 = pipe(
    prompt,
    num_inference_steps=20,
    generator=generator,
    eta=1.0,
    strength=strength,
    image=cp(rgb_0),
    mask_image=cp(mask_0),
    control_image=control_image,
).images[0]
# PIL's resize takes (width, height); rgb_0 is presumably (H, W, C).
edited_0 = edited_0.resize((rgb_0.shape[1], rgb_0.shape[0]))

In [None]:
# Display the preview edit (as the last expression of a notebook cell, it is shown as the cell output).
edited_0

# 3. Edit the object in the NeRF

In [None]:
# Edit object parameters
num_iters = 3  # editing iterations; requires need_imgs >= 2 * num_iters + 2 (checked below)
lambda_depth = 1e-4  # depth loss weight: larger values remove more floaters but make the density sparser
mask_trick = False  # True reduces floaters outside the object at extra time cost (optional); reported 59 s -> 68 s on a 4090 GPU
lr = 2e-4  # learning rate of the editing training
nb_epochs = 1  # number of editing training epochs
record = True  # keep the editing history for plotting; False is faster

# Enforce the requirement stated on num_iters: with need_imgs sampled views there
# are only need_imgs // 2 - 1 contexts, presumably one consumed per iteration.
if 2 * num_iters + 2 > need_imgs:
    raise ValueError(
        f"num_iters={num_iters} requires need_imgs >= {2 * num_iters + 2}, got need_imgs={need_imgs}"
    )

In [None]:
# Run the iterative context editing. perf_counter is a monotonic clock, suited to measuring durations.
T1 = time.perf_counter()
context_kwargs = dict(record=record, lr=lr, strength=strength, num_iters=num_iters,
                      nb_epochs=nb_epochs, lambda_depth=lambda_depth, woedge=wo_edge)
if mask_trick:
    # The mask trick also needs the original NeRF and the selection query/threshold.
    context_kwargs.update(mask_trick=mask_trick, original_nerf=nerf_model, embq=embq,
                          dis_thr=settings['thr'] + settings['margin'], foreground=True)
history, edit_nerfmodel = context_iter(nerf_model, savepath_father, LLFF_training, indices, edited_0, canny_0, rgb_4_list,
                                       mask_4_s_list, depth_4_list, ray_4_list, pipe, prompt, generator, H, W,
                                       depth_estimator, **context_kwargs)
T2 = time.perf_counter()
print('Editing processing cost seconds: ', T2 - T1)

In [None]:
if record:
    # Show and save the edited photos and NeRF novel views for each editing iteration.
    # NOTE(review): the prompt is used in the file names; a prompt containing a path
    # separator would produce an invalid path.

    fig = plt.figure(figsize=(20, 10))
    # suptitle titles the whole figure. plt.title here would instead create a
    # separate full-figure Axes underneath the subplots.
    fig.suptitle('photo editing with iterations')
    for i in range(num_iters):
        plt.subplot(1, num_iters, i + 1)
        plt.imshow(history['photo'][i])
        plt.axis('off')
        plt.title(str(i + 1))
    plt.savefig(os.path.join(savepath_father, prompt + '_photo_iter.png'))

    fig = plt.figure(figsize=(20, 8))
    fig.suptitle('nerf novel view with iterations')
    for i in range(num_iters):
        plt.subplot(1, num_iters, i + 1)
        plt.imshow(history['rendering'][i][1])  # element [1] of each recorded rendering entry
        plt.axis('off')
        plt.title(str(i + 1))
    plt.savefig(os.path.join(savepath_father, prompt + '_nerf_iter.png'))


In [None]:
# render novel view results of the edited NeRF and assemble them into a video
method = '_W_context'
if record:
    save_history_fig(savepath_father, history, prompt, method, num_iters)

novel_views_path = novel_views_LLFF(savepath_father, method, edit_nerfmodel, device, LLFF_test, hn=0, hf=1, nb_bins=96,
                                    req_others=False)
creat_video(novel_views_path, savepath_father, method, req_others=False)