In [1]:
import sys,os,imageio,lpips
root = '/mnt/new_disk2/anpei/code/MVS-NeRF'
os.chdir(root)
sys.path.append(root)

from opt import config_parser
from data import dataset_dict
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


# models
from models import *
from renderer import *
from data.ray_utils import get_rays

from tqdm import tqdm


from skimage.metrics import structural_similarity

# pytorch-lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer, loggers


from data.ray_utils import ray_marcher

%load_ext autoreload
%autoreload 2

# Pin this kernel to GPU index 2.
# NOTE(review): CUDA_VISIBLE_DEVICES is assigned *after* torch.cuda.set_device(2). The variable is
# normally read when CUDA is initialised, so setting it here may have no effect on this process
# (and if it were set first, only one device would be visible) — confirm the intended order.
torch.cuda.set_device(2)
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
def decode_batch(batch):
    """Split a dataset sample into its ray bundle and target colours.

    Args:
        batch: mapping that provides the keys 'rays' and 'rgbs'.

    Returns:
        Tuple ``(rays, rgbs)`` holding the two entries unchanged
        (the original notes shapes (B, 8) and (B, 3) respectively).
    """
    return batch['rays'], batch['rgbs']

def unpreprocess(data, shape=(1,1,3,1,1)):
    """Undo the per-channel normalisation of an image tensor for visualisation.

    Algebraically this is ``data * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]``
    applied along the channel axis.

    Args:
        data: tensor laid out so that the 3-element channel axis lines up with
            ``shape`` (default layout N V C H W).
        shape: broadcast shape given to the per-channel constants.

    Returns:
        Tensor of the same shape as ``data``, on the same device.
    """
    target = data.device

    # Constants are expressed as the inverse transform's (offset, scale) pair.
    offset = torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225])
    scale = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225])

    offset = offset.view(*shape).to(target)
    scale = scale.view(*shape).to(target)

    return (data - offset) / scale

# LPIPS perceptual metric with the VGG trunk; the evaluation cells call it on
# image tensors rescaled to [-1, 1].
loss_fn_vgg = lpips.LPIPS(net='vgg') 
# Convert a mean-squared error into PSNR: -10 * log10(mse), i.e. assuming a peak value of 1.
mse2psnr = lambda x : -10. * np.log(x) / np.log(10.)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /home/anpei/anaconda3/lib/python3.7/site-packages/lpips/weights/v0.1/vgg.pth


# llff no fine tuning

## rendering novel views with nearest 3 source views

In [None]:
# Evaluate the pretrained model (no fine-tuning) on LLFF scenes. For every validation
# view the cost volume is rebuilt from the 3 views whose camera positions are closest
# to it; PSNR / SSIM / LPIPS are accumulated per scene and over all scenes.
# NOTE(review): `np`, `torch` and `device` are not bound explicitly in the visible cells;
# presumably they arrive through the star imports — verify.
psnr_all,ssim_all,LPIPS_vgg_all = [],[],[]
for i_scene, scene in enumerate(['room']):#'fortress','flower','orchids', 'room','leaves','horns','trex','fern'
    psnr,ssim,LPIPS_vgg = [],[],[]
    # Frames for the optional video export. This list was never initialised before,
    # so switching save_as_image to False raised NameError.
    rgbs = []
    cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/{scene}  \
     --dataset_name llff \
     --ckpt ./ckpts/mvsnerf-v0.tar \
     --net_type v0 --netwidth 128 --netdepth 6'

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    args.feat_dim = 8+3*4 

    # create models (built for the first scene only and reused afterwards)
    if 0==i_scene:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        # The MVS network is called separately below, so it is removed from the render kwargs.
        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'val'
    pad = 24  # forwarded to MVSNet and (scaled by imgScale_test) to get_ndc_coordinate
    args.chunk = 5120  # number of rays rendered per pass


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx
    
    save_as_image = True
    save_dir = f'results/test3'
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): the network is left in train mode while evaluating — confirm this is intentional.
    MVSNet.train()
    MVSNet = MVSNet.cuda()
    
    with torch.no_grad():
        
        # Best-effort reset of stale tqdm bars left by earlier cell runs.
        try:
            tqdm._instances.clear() 
        except Exception:     
            pass
        
        for i, batch in enumerate(tqdm(dataset)):
            torch.cuda.empty_cache()
            
            rays, img = decode_batch(batch)
            rays = rays.squeeze().to(device)  # one row per pixel ray (decode_batch notes 8 channels — TODO confirm)
            img = img.squeeze().cpu().numpy()  # (H, W, 3)
        
            # find nearest image idx: L1 distance between camera positions
            positions = dataset.poses[:,:3,3]
            dis = np.sum(np.abs(positions - dataset.poses[val_idx[i],:3,3]), axis=-1)
            # Rank 0 is skipped — presumably the target view itself, since `positions`
            # is taken from the same pose array — and the next three are used as sources.
            pair_idx = np.argsort(dis)[1:4]
            
            imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(pair_idx=pair_idx,device=device)
            volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)
            imgs_source = unpreprocess(imgs_source)
        
            # Render the image chunk by chunk; the range rounds the chunk count up.
            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            for chunk_idx in range(N_rays_all//args.chunk + int(N_rays_all%args.chunk>0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[chunk_idx*args.chunk:(chunk_idx+1)*args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate (relative to source view 0)
                H, W = img.shape[:2]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref, intrinsic_ref = pose_source['w2cs'][0], pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)
    
                
                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)


            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)
            
            # Side-by-side panel: ground truth | prediction | depth visualisation.
            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            img_vis = np.concatenate((img*255,rgb_rays*255,depth_rays_preds),axis=1)
            
            if save_as_image:
                imageio.imwrite(f'{save_dir}/{scene}_{val_idx[i]:03d}.png', img_vis.astype('uint8'))
            else:
                rgbs.append(img_vis.astype('uint8'))
                
            # quantity
            # center crop 0.8 ratio: a tenth of each dimension is trimmed from every border
            H_crop, W_crop = np.array(rgb_rays.shape[:2])//10
            img = img[H_crop:-H_crop,W_crop:-W_crop]
            rgb_rays = rgb_rays[H_crop:-H_crop,W_crop:-W_crop]
            
            psnr.append( mse2psnr(np.mean((rgb_rays-img)**2)))
            # NOTE(review): newer scikit-image releases replace `multichannel=` with `channel_axis=` — check the installed version.
            ssim.append( structural_similarity(rgb_rays, img, multichannel=True))
            
            img_tensor = torch.from_numpy(rgb_rays)[None].permute(0,3,1,2).float()*2-1.0 # image should be RGB, IMPORTANT: normalized to [-1,1]
            img_gt_tensor = torch.from_numpy(img)[None].permute(0,3,1,2).float()*2-1.0
            LPIPS_vgg.append( loss_fn_vgg(img_tensor, img_gt_tensor).item())
            
        print(f'=====> scene: {scene} mean psnr {np.mean(psnr)} ssim: {np.mean(ssim)} lpips: {np.mean(LPIPS_vgg)}')   
        psnr_all.append(psnr);ssim_all.append(ssim);LPIPS_vgg_all.append(LPIPS_vgg)
    
    if not save_as_image:
        imageio.mimwrite(f'{save_dir}/{scene}_spiral.mp4', np.stack(rgbs), fps=20, quality=10)
print(f'=====> all mean psnr {np.mean(psnr_all)} ssim: {np.mean(ssim_all)} lpips: {np.mean(LPIPS_vgg_all)}') 

## rendering novel views with fixed 3 source views

In [28]:
# Evaluate the pretrained model (no fine-tuning) on LLFF scenes using one fixed set of
# source views per scene: the cost volume is built once and reused for every validation view.
# NOTE(review): `np`, `torch` and `device` are not bound explicitly in the visible cells;
# presumably they arrive through the star imports — verify.
psnr_all,ssim_all,LPIPS_vgg_all = [],[],[]
for i_scene, scene in enumerate(['room']):#'flower','orchids', 'room','leaves','fern','horns','trex','fortress'
    psnr,ssim,LPIPS_vgg = [],[],[]
    # Frames for the optional video export. This list was never initialised before,
    # so switching save_as_image to False raised NameError.
    rgbs = []
    cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/{scene}  \
     --dataset_name llff \
     --ckpt ./ckpts/mvsnerf-v0.tar  \
     --net_type v0 --netwidth 128 --netdepth 6'

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    args.feat_dim =  8+3*4

    # create models (built for the first scene only and reused afterwards)
    if 0==i_scene:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        # The MVS network is called separately below, so it is removed from the render kwargs.
        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'val'
    pad = 24  # forwarded to MVSNet and (scaled by imgScale_test) to get_ndc_coordinate
    args.chunk = 5120  # number of rays rendered per pass


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx
    
    save_as_image = True
    save_dir = f'results/test3'
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): the network is left in train mode while evaluating — confirm this is intentional.
    MVSNet.train()
    MVSNet = MVSNet.cuda()
    
    with torch.no_grad():

        # Source views come from the dataset's default selection (no pair_idx given),
        # so the feature volume is computed a single time per scene.
        imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
        volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)
        imgs_source = unpreprocess(imgs_source)

        # Best-effort reset of stale tqdm bars left by earlier cell runs.
        try:
            tqdm._instances.clear() 
        except Exception:     
            pass
        
        for i, batch in enumerate(tqdm(dataset)):
            torch.cuda.empty_cache()
            
            rays, img = decode_batch(batch)
            rays = rays.squeeze().to(device)  # one row per pixel ray (decode_batch notes 8 channels — TODO confirm)
            img = img.squeeze().cpu().numpy()  # (H, W, 3)
        
            # Render the image chunk by chunk; the range rounds the chunk count up.
            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            for chunk_idx in range(N_rays_all//args.chunk + int(N_rays_all%args.chunk>0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[chunk_idx*args.chunk:(chunk_idx+1)*args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate (relative to source view 0)
                H, W = img.shape[:2]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref, intrinsic_ref = pose_source['w2cs'][0], pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)
    
                
                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)

            
            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)
            
            # Side-by-side panel: ground truth | prediction | depth visualisation.
            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            img_vis = np.concatenate((img*255,rgb_rays*255,depth_rays_preds),axis=1)
            
            if save_as_image:
                imageio.imwrite(f'{save_dir}/{scene}_{val_idx[i]:03d}.png', img_vis.astype('uint8'))
            else:
                rgbs.append(img_vis.astype('uint8'))
                
            # quantity
            # center crop 0.8 ratio: a tenth of each dimension is trimmed from every border
            H_crop, W_crop = np.array(rgb_rays.shape[:2])//10
            img = img[H_crop:-H_crop,W_crop:-W_crop]
            rgb_rays = rgb_rays[H_crop:-H_crop,W_crop:-W_crop]
            
            psnr.append( mse2psnr(np.mean((rgb_rays-img)**2)))
            # NOTE(review): newer scikit-image releases replace `multichannel=` with `channel_axis=` — check the installed version.
            ssim.append( structural_similarity(rgb_rays, img, multichannel=True))
            
            img_tensor = torch.from_numpy(rgb_rays)[None].permute(0,3,1,2).float()*2-1.0 # image should be RGB, IMPORTANT: normalized to [-1,1]
            img_gt_tensor = torch.from_numpy(img)[None].permute(0,3,1,2).float()*2-1.0
            LPIPS_vgg.append( loss_fn_vgg(img_tensor, img_gt_tensor).item())
        
        print(f'=====> scene: {scene} mean psnr {np.mean(psnr)} ssim: {np.mean(ssim)} lpips: {np.mean(LPIPS_vgg)}')   
        psnr_all.append(psnr);ssim_all.append(ssim);LPIPS_vgg_all.append(LPIPS_vgg)
        
    if not save_as_image:
        imageio.mimwrite(f'{save_dir}/{scene}_spiral.mp4', np.stack(rgbs), fps=20, quality=10)
print(f'=====> all mean psnr {np.mean(psnr_all)} ssim: {np.mean(ssim_all)} lpips: {np.mean(LPIPS_vgg_all)}') 

Found ckpts ['/mnt/new_disk2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-net-v0-no-color-skip/ckpts/139999.tar']
Reloading from /mnt/new_disk2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-net-v0-no-color-skip/ckpts/139999.tar
41 41 /mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/room
===> valing index: [35 15 38 21]
====> ref idx: [14 39 34]


 25%|██▌       | 1/4 [00:49<02:29, 50.00s/it]
100%|██████████| 4/4 [00:53<00:00, 13.35s/it]

=====> scene: room mean psnr 23.934366149970703 ssim: 0.9214743734999353 lpips: 0.2191867232322693
=====> all mean psnr 23.934366149970703 ssim: 0.9214743734999353 lpips: 0.2191867232322693





# nerf no fine tuning

## rendering novel views with nearest 3 views

In [5]:
# Evaluate the pretrained model (no fine-tuning) on the synthetic Blender scenes. For every
# validation view the cost volume is rebuilt from the 3 training views whose camera
# positions are closest to it.
# NOTE(review): `np`, `torch` and `device` are not bound explicitly in the visible cells;
# presumably they arrive through the star imports — verify.
psnr_all,ssim_all,LPIPS_vgg_all = [],[],[]
for i_scene, scene in enumerate(['lego']):#'ship','mic','chair','lego','drums','ficus','materials','hotdog'
    psnr,ssim,LPIPS_vgg = [],[],[]
    # Frames for the optional video export. This list was never initialised before,
    # so switching save_as_image to False raised NameError.
    rgbs = []
    cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/nerf_synthetic/{scene}  \
     --dataset_name blender --white_bkgd \
    --ckpt ./ckpts//mvsnerf-v0.tar'

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    args.feat_dim =  8+3*4

    # create models (built for the first scene only and reused afterwards)
    if 0==i_scene:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        # The MVS network is called separately below, so it is removed from the render kwargs.
        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'train'
    pad = 0  # forwarded to MVSNet and (scaled by imgScale_test) to get_ndc_coordinate
    args.chunk = 5120  # number of rays rendered per pass


    print('============> rendering dataset <===================')
    # Source views are drawn from the train split, targets from the val split.
    dataset_train = dataset_dict[args.dataset_name](args, split='train')
    dataset_val = dataset_dict[args.dataset_name](args, split='val')
    val_idx = dataset_val.img_idx
    
    save_as_image = True
    save_dir = f'results/test3'
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): the network is left in train mode while evaluating — confirm this is intentional.
    MVSNet.train()
    MVSNet = MVSNet.cuda()
    
    with torch.no_grad():

        # Best-effort reset of stale tqdm bars left by earlier cell runs.
        try:
            tqdm._instances.clear() 
        except Exception:     
            pass
        
        for i, batch in enumerate(tqdm(dataset_val)):
            torch.cuda.empty_cache()

            
            rays, img = decode_batch(batch)
            rays = rays.squeeze().to(device)  # one row per pixel ray (decode_batch notes 8 channels — TODO confirm)
            img = img.squeeze().cpu().numpy()  # (H, W, 3)
        
            # find nearest image idx from training views (L1 distance between camera positions)
            positions = dataset_train.poses[:,:3,3]
            dis = np.sum(np.abs(positions - dataset_val.poses[[i],:3,3]), axis=-1)
            pair_idx = np.argsort(dis)[:3]
            # Translate positions within the train split into the dataset's image ids.
            pair_idx = [dataset_train.img_idx[item] for item in pair_idx]
            
            imgs_source, proj_mats, near_far_source, pose_source = dataset_train.read_source_views(pair_idx=pair_idx,device=device)
            volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)
            imgs_source = unpreprocess(imgs_source)
        
            # Render the image chunk by chunk; the range rounds the chunk count up.
            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            for chunk_idx in range(N_rays_all//args.chunk + int(N_rays_all%args.chunk>0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[chunk_idx*args.chunk:(chunk_idx+1)*args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate (relative to source view 0)
                H, W = img.shape[:2]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref, intrinsic_ref = pose_source['w2cs'][0], pose_source['intrinsics'][0].clone()
                # Rescale the first two intrinsic rows by the test/train image-scale ratio
                # (applied to the clone, so pose_source itself is untouched).
                intrinsic_ref[:2] *= args.imgScale_test/args.imgScale_train
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)
    
                
                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)

            
            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)
            
            # Side-by-side panel: ground truth | prediction | depth visualisation.
            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            img_vis = np.concatenate((img*255,rgb_rays*255,depth_rays_preds),axis=1)
            
            # Prepend the three source images (laid out next to each other) to the panel.
            img_vis = np.concatenate((torch.cat(torch.split(imgs_source*255, [1,1,1], dim=1),-1).squeeze().permute(1,2,0).cpu().numpy(),img_vis),axis=1)
            
            if save_as_image:
                imageio.imwrite(f'{save_dir}/{scene}_{val_idx[i]:03d}.png', img_vis.astype('uint8'))
            else:
                rgbs.append(img_vis.astype('uint8'))
                
            # quantity
            psnr.append( mse2psnr(np.mean((rgb_rays-img)**2)))
            # NOTE(review): newer scikit-image releases replace `multichannel=` with `channel_axis=` — check the installed version.
            ssim.append( structural_similarity(rgb_rays, img, multichannel=True))
            
            img_tensor = torch.from_numpy(rgb_rays)[None].permute(0,3,1,2).float()*2-1.0 # image should be RGB, IMPORTANT: normalized to [-1,1]
            img_gt_tensor = torch.from_numpy(img)[None].permute(0,3,1,2).float()*2-1.0
            LPIPS_vgg.append( loss_fn_vgg(img_tensor, img_gt_tensor).item())

        print(f'=====> scene: {scene} mean psnr {np.mean(psnr)} ssim: {np.mean(ssim)} lpips: {np.mean(LPIPS_vgg)}')   
        psnr_all.append(psnr);ssim_all.append(ssim);LPIPS_vgg_all.append(LPIPS_vgg)

    if not save_as_image:
        imageio.mimwrite(f'{save_dir}/{scene}_spiral.mp4', np.stack(rgbs), fps=20, quality=10)

print(f'=====> all mean psnr {np.mean(psnr_all)} ssim: {np.mean(ssim_all)} lpips: {np.mean(LPIPS_vgg_all)}') 

Found ckpts ['/mnt/new_disk_2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-color-fusion-attention-128-sofmax-angle/ckpts//latest.tar']
Reloading from /mnt/new_disk_2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-color-fusion-attention-128-sofmax-angle/ckpts//latest.tar
Not ndc!
100 [6, 43, 33, 13, 17, 19, 20, 25, 30, 37, 46, 48, 49, 55, 59, 65]
===> training index: [6, 43, 33, 13, 17, 19, 20, 25, 30, 37, 46, 48, 49, 55, 59, 65]
100 [28, 63, 70, 18]
===> valing index: [28, 63, 70, 18]


100%|██████████| 4/4 [01:12<00:00, 18.20s/it]

=====> scene: lego mean psnr 25.38069234029279 ssim: 0.9428600022362317 lpips: 0.1353445127606392
=====> all mean psnr 25.38069234029279 ssim: 0.9428600022362317 lpips: 0.1353445127606392





## rendering novel views with fixed 3 source views

In [6]:
# Evaluate the pretrained model (no fine-tuning) on the synthetic Blender scenes using one
# fixed set of source views per scene: the cost volume is built once and reused for every
# validation view.
# NOTE(review): `np`, `torch` and `device` are not bound explicitly in the visible cells;
# presumably they arrive through the star imports — verify.
psnr_all,ssim_all,LPIPS_vgg_all = [],[],[]
for i_scene, scene in enumerate(['lego']):#'ship','mic','chair','lego','drums','ficus','materials','hotdog'
    psnr,ssim,LPIPS_vgg = [],[],[]
    # Frame buffers for the optional video export. Neither list was initialised (and
    # depths_vis was never filled) before, so save_as_image=False raised NameError.
    rgbs, depths_vis = [], []
    cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/nerf_synthetic/{scene}  \
     --dataset_name blender --white_bkgd \
    --ckpt ./ckpts//mvsnerf-v0.tar'

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    args.feat_dim = 8+3*4

    # create models (built for the first scene only and reused afterwards)
    if 0==i_scene:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        # The MVS network is called separately below, so it is removed from the render kwargs.
        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'val'
    pad = 16  # forwarded to MVSNet and (scaled by imgScale_test) to get_ndc_coordinate
    args.chunk = 5120  # number of rays rendered per pass


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx
    
    save_as_image = True
    save_dir = f'results/test3'
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): the network is left in train mode while evaluating — confirm this is intentional.
    MVSNet.train()
    MVSNet = MVSNet.cuda()
    
    with torch.no_grad():

        # Source views come from the dataset's default selection (no pair_idx given),
        # so the feature volume is computed a single time per scene.
        imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
        volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)
        imgs_source = unpreprocess(imgs_source)

        # Best-effort reset of stale tqdm bars left by earlier cell runs.
        try:
            tqdm._instances.clear() 
        except Exception:     
            pass
        
        for i, batch in enumerate(tqdm(dataset)):
            torch.cuda.empty_cache()
            
            rays, img = decode_batch(batch)
            rays = rays.squeeze().to(device)  # one row per pixel ray (decode_batch notes 8 channels — TODO confirm)
            img = img.squeeze().cpu().numpy()  # (H, W, 3)
        
            # Render the image chunk by chunk; the range rounds the chunk count up.
            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            for chunk_idx in range(N_rays_all//args.chunk + int(N_rays_all%args.chunk>0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[chunk_idx*args.chunk:(chunk_idx+1)*args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate (relative to source view 0)
                H, W = img.shape[:2]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref, intrinsic_ref = pose_source['w2cs'][0], pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)
    
                
                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)

            
            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)
            
            # Side-by-side panel: ground truth | prediction | depth visualisation.
            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            img_vis = np.concatenate((img*255,rgb_rays*255,depth_rays_preds),axis=1)
            
            if save_as_image:
                imageio.imwrite(f'{save_dir}/{scene}_{val_idx[i]:03d}.png', img_vis.astype('uint8'))
            else:
                rgbs.append(img_vis.astype('uint8'))
                # Keep the depth visualisation too, so the depth video below has frames.
                depths_vis.append(depth_rays_preds.astype('uint8'))
                
            # quantity
            psnr.append( mse2psnr(np.mean((rgb_rays-img)**2)))
            # NOTE(review): newer scikit-image releases replace `multichannel=` with `channel_axis=` — check the installed version.
            ssim.append( structural_similarity(rgb_rays, img, multichannel=True))
            
            img_tensor = torch.from_numpy(rgb_rays)[None].permute(0,3,1,2).float()*2-1.0 # image should be RGB, IMPORTANT: normalized to [-1,1]
            img_gt_tensor = torch.from_numpy(img)[None].permute(0,3,1,2).float()*2-1.0
            LPIPS_vgg.append( loss_fn_vgg(img_tensor, img_gt_tensor).item())

        print(f'=====> scene: {scene} mean psnr {np.mean(psnr)} ssim: {np.mean(ssim)} lpips: {np.mean(LPIPS_vgg)}')   
        psnr_all.append(psnr);ssim_all.append(ssim);LPIPS_vgg_all.append(LPIPS_vgg)

    if not save_as_image:
        imageio.mimwrite(f'{save_dir}/{scene}_depth_spiral.mp4', np.stack(depths_vis), fps=10, quality=10)
        imageio.mimwrite(f'{save_dir}/{scene}_spiral.mp4', np.stack(rgbs), fps=20, quality=10)
print(f'=====> all mean psnr {np.mean(psnr_all)} ssim: {np.mean(ssim_all)} lpips: {np.mean(LPIPS_vgg_all)}') 

Found ckpts ['/mnt/new_disk_2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-color-fusion-attention-128-sofmax-angle/ckpts//latest.tar']
Reloading from /mnt/new_disk_2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-color-fusion-attention-128-sofmax-angle/ckpts//latest.tar
Not ndc!
100 [28, 63, 70, 18]
===> valing index: [28, 63, 70, 18]
====> ref idx: [6, 43, 33]


100%|██████████| 4/4 [01:08<00:00, 17.25s/it]

=====> scene: lego mean psnr 17.98688916232244 ssim: 0.8662780820428587 lpips: 0.26200321689248085
=====> all mean psnr 17.98688916232244 ssim: 0.8662780820428587 lpips: 0.26200321689248085





# DTU no fine tuning

## rendering novel views with nearest 3 views

In [3]:
# Evaluate the pretrained model (no fine-tuning) on DTU scans. For every validation view
# the cost volume is rebuilt from the 3 training views whose camera positions are closest
# to it; pixels whose ground-truth depth is 0 are masked out before the metrics.
# NOTE(review): `np`, `torch` and `device` are not bound explicitly in the visible cells;
# presumably they arrive through the star imports — verify.
psnr_all,ssim_all,LPIPS_vgg_all = [],[],[]
for i_scene, scene in enumerate([1]):#,8,21,103,114
    psnr,ssim,LPIPS_vgg = [],[],[]
    # Frames for the optional video export. This list was never initialised before,
    # so switching save_as_image to False raised NameError.
    rgbs = []
    cmd = f'--datadir /mnt/data/new_disk/sungx/data/mvs_dataset/DTU/mvs_training/dtu/scan{scene}  \
     --dataset_name dtu_ft  \
    --ckpt ./ckpts//mvsnerf-v0.tar'

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    args.feat_dim =  8+3*4

    # create models (built for the first scene only and reused afterwards)
    if 0==i_scene:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        # The MVS network is called separately below, so it is removed from the render kwargs.
        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'train'
    pad = 24  # forwarded to MVSNet and (scaled by imgScale_test) to get_ndc_coordinate
    args.chunk = 5120  # number of rays rendered per pass


    print('============> rendering dataset <===================')
    # Source views are drawn from the train split, targets from the val split.
    dataset_train = dataset_dict[args.dataset_name](args, split='train')
    dataset_val = dataset_dict[args.dataset_name](args, split='val')
    val_idx = dataset_val.img_idx
    
    save_as_image = True
    save_dir = f'results/test3'
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): the network is left in train mode while evaluating — confirm this is intentional.
    MVSNet.train()
    MVSNet = MVSNet.cuda()
    
    with torch.no_grad():

        # Best-effort reset of stale tqdm bars left by earlier cell runs.
        try:
            tqdm._instances.clear() 
        except Exception:     
            pass
        
        for i, batch in enumerate(tqdm(dataset_val)):
            torch.cuda.empty_cache()
            
            rays, img = decode_batch(batch)
            rays = rays.squeeze().to(device)  # one row per pixel ray (decode_batch notes 8 channels — TODO confirm)
            img = img.squeeze().cpu().numpy()  # (H, W, 3)
            depth = batch['depth'].squeeze().numpy()  # (H, W)
        
            # find nearest image idx from training views (L1 distance between camera positions)
            positions = dataset_train.poses[:,:3,3]
            dis = np.sum(np.abs(positions - dataset_val.poses[[i],:3,3]), axis=-1)
            pair_idx = np.argsort(dis)[:3]
            # Translate positions within the train split into the dataset's image ids.
            pair_idx = [dataset_train.img_idx[item] for item in pair_idx]
            
            imgs_source, proj_mats, near_far_source, pose_source = dataset_train.read_source_views(pair_idx=pair_idx,device=device)
            volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)
            imgs_source = unpreprocess(imgs_source)
        
            # Render the image chunk by chunk; the range rounds the chunk count up.
            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            for chunk_idx in range(N_rays_all//args.chunk + int(N_rays_all%args.chunk>0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[chunk_idx*args.chunk:(chunk_idx+1)*args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate (relative to source view 0)
                H, W = img.shape[:2]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref, intrinsic_ref = pose_source['w2cs'][0], pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)
    
                
                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)

            
            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)
            
            # Side-by-side panel: ground truth | prediction | depth visualisation.
            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            img_vis = np.concatenate((img*255,rgb_rays*255,depth_rays_preds),axis=1)
            
            if save_as_image:
                imageio.imwrite(f'{save_dir}/scan{scene}_{val_idx[i]:03d}.png', img_vis.astype('uint8'))
            else:
                rgbs.append(img_vis.astype('uint8'))
                
            # quantity
            # mask the background, since it lies outside the far bound
            mask = depth==0
            # Zero the masked pixels in both images so SSIM / LPIPS see identical background;
            # PSNR is computed over the unmasked pixels only.
            rgb_rays[mask],img[mask] = 0.0,0.0
            psnr.append( mse2psnr(np.mean((rgb_rays[~mask]-img[~mask])**2)))
            # NOTE(review): newer scikit-image releases replace `multichannel=` with `channel_axis=` — check the installed version.
            ssim.append( structural_similarity(rgb_rays, img, multichannel=True))
            
            img_tensor = torch.from_numpy(rgb_rays)[None].permute(0,3,1,2).float()*2-1.0 # image should be RGB, IMPORTANT: normalized to [-1,1]
            img_gt_tensor = torch.from_numpy(img)[None].permute(0,3,1,2).float()*2-1.0
            LPIPS_vgg.append( loss_fn_vgg(img_tensor, img_gt_tensor).item())

        print(f'=====> scene: {scene} mean psnr {np.mean(psnr)} ssim: {np.mean(ssim)} lpips: {np.mean(LPIPS_vgg)}')   
        psnr_all.append(psnr);ssim_all.append(ssim);LPIPS_vgg_all.append(LPIPS_vgg)

    if not save_as_image:
        imageio.mimwrite(f'{save_dir}/{scene}_spiral.mp4', np.stack(rgbs), fps=20, quality=10)

print(f'=====> all mean psnr {np.mean(psnr_all)} ssim: {np.mean(ssim_all)} lpips: {np.mean(LPIPS_vgg_all)}') 

Found ckpts ['./ckpts//mvsnerf-v0.tar']
Reloading from ./ckpts//mvsnerf-v0.tar
==> image down scale: 1.0
===> training index: [25, 21, 33, 22, 14, 15, 26, 30, 31, 35, 34, 43, 46, 29, 16, 36]
==> image down scale: 1.0
===> valing index: [32, 24, 23, 44]


100%|██████████| 4/4 [00:28<00:00,  7.02s/it]

=====> scene: 1 mean psnr 26.860625765567647 ssim: 0.936532216389585 lpips: 0.1556858941912651
=====> all mean psnr 26.860625765567647 ssim: 0.936532216389585 lpips: 0.1556858941912651





## rendering novel views with fixed 3 source views

In [11]:
psnr_all,ssim_all,LPIPS_vgg_all = [],[],[]
for i_scene, scene in enumerate([1]):#,8,21,103,114
    psnr,ssim,LPIPS_vgg = [],[],[]
    cmd = f'--datadir /mnt/data/new_disk/sungx/data/mvs_dataset/DTU/mvs_training/dtu/scan{scene}  \
    --dataset_name dtu_ft  \
    --ckpt ./ckpts//mvsnerf-v0.tar'

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    # Same 8+3*4 feature width as the other evaluation cells; the stray `else`
    # left in this assignment made the whole cell a SyntaxError.
    args.feat_dim = 8+3*4

    # create models
    if 0==i_scene:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'val'
    pad = 24
    args.chunk = 5120


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx
    
    save_as_image = True
    save_dir = f'results/test3'
    os.makedirs(save_dir, exist_ok=True)
    MVSNet.train()
    MVSNet = MVSNet.cuda()
    
    with torch.no_grad():

        imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
        volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)
        imgs_source = unpreprocess(imgs_source)

        try:
            tqdm._instances.clear() 
        except Exception:     
            pass
        
        for i, batch in enumerate(tqdm(dataset)):
            torch.cuda.empty_cache()
            
            rays, img = decode_batch(batch)
            rays = rays.squeeze().to(device)  # (H*W, 3)
            img = img.squeeze().cpu().numpy()  # (H, W, 3)
            depth = batch['depth'].squeeze().numpy()  # (H, W)
        
            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            for chunk_idx in range(N_rays_all//args.chunk + int(N_rays_all%args.chunk>0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[chunk_idx*args.chunk:(chunk_idx+1)*args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate
                H, W = img.shape[:2]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref, intrinsic_ref = pose_source['w2cs'][0], pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)
    
                
                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)

            
            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)
            
            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            img_vis = np.concatenate((img*255,rgb_rays*255,depth_rays_preds),axis=1)
            
            if save_as_image:
                imageio.imwrite(f'{save_dir}/scan{scene}_{val_idx[i]:03d}.png', img_vis.astype('uint8'))
            else:
                rgbs.append(img_vis.astype('uint8'))
                
            # quantity
            # mask background since they are outside the far boundle
            mask = depth==0
            rgb_rays[mask],img[mask] = 0.0,0.0
            psnr.append( mse2psnr(np.mean((rgb_rays[~mask]-img[~mask])**2)))
            ssim.append( structural_similarity(rgb_rays, img, multichannel=True))
            
            img_tensor = torch.from_numpy(rgb_rays)[None].permute(0,3,1,2).float()*2-1.0 # image should be RGB, IMPORTANT: normalized to [-1,1]
            img_gt_tensor = torch.from_numpy(img)[None].permute(0,3,1,2).float()*2-1.0
            LPIPS_vgg.append( loss_fn_vgg(img_tensor, img_gt_tensor).item())

        print(f'=====> scene: {scene} mean psnr {np.mean(psnr)} ssim: {np.mean(ssim)} lpips: {np.mean(LPIPS_vgg)}')   
        psnr_all.append(psnr);ssim_all.append(ssim);LPIPS_vgg_all.append(LPIPS_vgg)

    if not save_as_image:
        imageio.mimwrite(f'{save_dir}/{scene}_spiral.mp4', np.stack(rgbs), fps=20, quality=10)
print(f'=====> all mean psnr {np.mean(psnr_all)} ssim: {np.mean(ssim_all)} lpips: {np.mean(LPIPS_vgg_all)}') 

Found ckpts ['/mnt/new_disk_2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-color-fusion-iccv-no-travel_all/ckpts//latest.tar']
Reloading from /mnt/new_disk_2/anpei/code/MVS-NeRF/runs_new/mvs-nerf-color-fusion-iccv-no-travel_all/ckpts//latest.tar
Not ndc!
==> image down scale: 1.0
===> valing index: [32, 24, 23, 44]
====> ref idx: [25, 21, 33]




  0%|          | 0/4 [00:00<?, ?it/s][A[A

 25%|██▌       | 1/4 [00:11<00:35, 11.84s/it][A[A

 50%|█████     | 2/4 [00:23<00:23, 11.80s/it][A[A

 75%|███████▌  | 3/4 [00:35<00:11, 11.79s/it][A[A

KeyboardInterrupt: 

# Pairs generation

In [1]:
import json,torch
import sys,os
import numpy as np
# Build the train / val / test view-index lists for the LLFF and synthetic
# scenes and store them in the shared pair table. The table is loaded first,
# so keys that are not reassigned below are written back unchanged.
root = '/mnt/new_disk2/anpei/code/MVS-NeRF'
os.chdir(root)
sys.path.append(root)
pairs = torch.load('./configs/pairs.th')

# llff
root_dir = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/'
for scene in ['horns','leaves', 'room', 'fortress', 'trex', 'orchids','fern']:#
    poses_bounds = np.load(os.path.join(root_dir, scene, 'poses_bounds.npy'))  # (N_images, 17)
    poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
    # Swap the first two columns (negating the original first one) and drop
    # the fifth column, leaving a (N_images, 3, 4) array.
    poses = np.concatenate([poses[..., 1:2], - poses[..., :1], poses[..., 2:4]], -1)

    # Rank the views by L1 distance of their last column to the mean of that
    # column over all views, and keep the 20 closest.
    ref_position = np.mean(poses[..., 3],axis=0, keepdims=True)
    dist = np.sum(np.abs(poses[..., 3] - ref_position), axis=-1)
    pair_idx = np.argsort(dist)[:20]
#     pair_idx = torch.randperm(len(poses))[:20].tolist()

    # Every 6th ranked view is held out for val/test; the rest is used for
    # training. The held-out positions are derived from the actual number of
    # selected views, so a scene with fewer than 20 images no longer raises an
    # IndexError in np.delete (identical result when 20 views are available).
    held_out = range(0, len(pair_idx), 6)
    pairs[f'{scene}_test'] = pair_idx[::6]
    pairs[f'{scene}_val'] = pair_idx[::6]
    pairs[f'{scene}_train'] = np.delete(pair_idx, held_out)


# nerf
# Hand-picked reference view per synthetic scene.
center_view = {'lego':6,'ship':2,'drums':22,'mic':20,'chair':8,'materials':36,'hotdog':61,'ficus':8}
blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
for scene in ['ship','drums','mic','chair','materials','lego','hotdog','ficus']:
    with open(f'/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/{scene}/transforms_train.json', 'r') as f:
        meta = json.load(f)

    # The reference view always comes from `center_view`; the random draw that
    # used to precede this line was overwritten immediately and is removed.
    ref_idx = center_view[scene]
    poses = np.stack([np.array(frame['transform_matrix']) @ blender2opencv
                      for frame in meta['frames']])

    # find nearest image idx
    # Views are ranked by the dot product between their third rotation column
    # and that of the reference view, largest first; the top 20 are kept.
    viewing_dir = poses[:,:3,2]
    dis = np.sum(viewing_dir * poses[[ref_idx],:3,2], axis=-1)
    pair_idx = np.argsort(dis)[::-1][:20]

    # NOTE(review): going through `set` drops the nearest-first ordering that
    # the llff branch keeps with np.delete -- confirm whether consumers of the
    # train list rely on that order before changing it.
    pairs[f'{scene}_train'] = list(set(pair_idx) - set(pair_idx[::6]))
    pairs[f'{scene}_test'] = pair_idx[::6]
    pairs[f'{scene}_val'] = pair_idx[::6]



# dtu
#      0-4
#    10 - 5 
#   11  -  18
#  27  xx   19
# 28    x    38
#48     -     39
# pairs[f'dtu_train'] = [25,21,33,22,14,15,26,30,31,35,34,43,46,29,16,36]
# pairs[f'dtu_val'] = [32,24,23,44]
# pairs[f'dtu_test'] = [32,24,23,44]

# NOTE(review): the table is loaded relative to `root` (/mnt/new_disk2/...) but
# saved under /mnt/new_disk_2/... -- confirm both refer to the same location.
torch.save(pairs,'/mnt/new_disk_2/anpei/code/MVS-NeRF/configs/pairs.th')