In [1]:
import os
import json
from functools import reduce, partial
import numpy as np

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import functional as F

from utils.graphics_utils import focal2fov, getProjectionMatrix
from scene.cameras import MiniCam
from utils.loss_utils import ssim
from utils.image_utils import psnr
from lpipsPyTorch import lpips
from gaussian_renderer import render
from scene import Scene, GaussianModel

# Notebook-level configuration: edit these four values to switch scene / payload.
item = 'lego'                  # scene name, interpolated into the per-scene file names
msg_len = 32                   # number of watermark bits (N_L); also selects the decoder checkpoint
source_dir = 'eval_examples'   # directory that holds the evaluation assets
save_map = False # you can visualize the views if set this flag as "True"

class Args:
    """Plain settings container shared by the evaluation helpers and the renderer.

    The evaluation fields mirror the module-level configuration values above,
    so changing those values and re-instantiating ``Args`` is enough to retarget
    the notebook. The remaining fields are fixed renderer/model flags.
    """

    def __init__(self):
        # evaluation args
        self.source_dir = source_dir
        self.item = item
        # Single source of truth: taken from the module-level `msg_len`
        # (the former hard-coded `self.msg_len = 32` was dead and misleading).
        self.msg_len = msg_len
        self.batch_size = 32          # number of views processed per metric/decoder batch
        # basic 3DGS args
        self.sh_degree = 3            # passed to GaussianModel(...)
        self.white_background = True  # selects a white vs. black background color
        self.data_device = "cuda"
        # NOTE(review): the three flags below are presumably read by `render`
        # through the args object - confirm against gaussian_renderer.
        self.convert_SHs_python = False
        self.compute_cov3D_python = False
        self.debug = False
        self.save_map = save_map      # when True, side-by-side comparison images are written

args = Args()

In [2]:
# Utils
def normalize(output):
    """Binarize decoder scores with a 0.5 threshold.

    Args:
        output: tensor of per-bit scores (any shape).

    Returns:
        A new bool tensor of the same shape that is True where
        ``output > 0.5`` and False elsewhere.

    The input tensor is left untouched. The previous implementation wrote the
    0/1 values back into ``output`` before converting, silently modifying the
    caller's data. NaN scores compare as False.
    """
    return output > 0.5

def accuarcy(output, target):
    """Return the percentage of bits in ``output`` that agree with ``target``.

    ``output`` is first binarized with ``normalize``; the result is compared
    element-wise against ``target`` with a logical XOR, so any non-zero target
    entry counts as a set bit. The returned value is a 0-dim tensor in [0, 100].
    """
    predicted_bits = normalize(output)
    mismatches = torch.logical_xor(predicted_bits, target).sum()
    error_rate = mismatches / target.numel()
    return (1 - error_rate) * 100.

def getSE3(r, t):
    """Assemble a 4x4 homogeneous transform from a rotation and a translation.

    Args:
        r: 3x3 rotation (array-like), written to the upper-left block.
        t: length-3 translation (array-like), written to the last column.

    Returns:
        A 4x4 float numpy array whose bottom row is [0, 0, 0, 1].
    """
    rotation = np.array(r)
    translation = np.array(t)
    transform = np.eye(4)
    transform[:3, 3] = translation
    transform[:3, :3] = rotation
    return transform

def getTestCameras(source_dir, item):
    """Build one ``MiniCam`` per entry of ``cameras-{item}.json`` in ``source_dir``.

    Every JSON entry supplies a ``rotation`` and a ``position``; these are packed
    into a 4x4 matrix, inverted and transposed to form that camera's world-view
    transform. Image size and focal lengths are read from the FIRST entry only
    and reused for every camera.

    Returns:
        A list of ``MiniCam`` objects whose transforms live on the CUDA device.
    """
    camera_file = os.path.join(source_dir, f'cameras-{item}.json')
    with open(camera_file) as jf:
        entries = json.load(jf)

    # Per-camera world-view matrices: inverse of the stored pose, transposed.
    world_views = []
    for entry in entries:
        pose = getSE3(entry['rotation'], entry['position'])
        world_views.append(np.linalg.inv(pose).transpose())
    world_views = torch.from_numpy(np.array(world_views)).cuda().float()

    # Shared intrinsics and clipping planes, taken from the first camera.
    first = entries[0]
    params = {
        'width'  : first['width'],
        'height' : first['height'],
        'fovx'   : focal2fov(first['fx'], first['width']),
        'fovy'   : focal2fov(first['fy'], first['height']),
        'znear'  : 0.01,
        'zfar'   : 100.0
    }

    projection = getProjectionMatrix(znear=params['znear'], zfar=params['zfar'], fovX=params['fovx'], fovY=params['fovy'])
    projection = projection.transpose(0, 1).cuda()
    # Full projection = world-view @ projection, batched over all cameras.
    full_projs = world_views.bmm(projection.repeat(len(world_views), 1, 1))

    cameras = []
    for w2c, proj in zip(world_views, full_projs):
        cameras.append(MiniCam(world_view_transform=w2c, full_proj_transform=proj, **params))
    return cameras

def extract_rendered_views_and_gts(gaussians, guardsplat, cameras, renderArgs):
    """Render every camera with both models and stack the results.

    Args:
        gaussians: the reference model; its renders become the ground truth.
        guardsplat: the watermarked model; its renders become the predictions.
        cameras: iterable of viewpoints handed to ``render``.
        renderArgs: extra positional arguments forwarded to ``render``.

    Returns:
        ``(pds, gts)``: two tensors with one leading entry per camera, each
        clamped to [0, 1]. Rendering runs with gradients disabled.
    """
    def _render_clamped(viewpoint, model):
        # Take the "render" entry, clamp it, and add a leading batch axis.
        image = render(viewpoint, model, *renderArgs)["render"]
        return torch.clamp(image, 0.0, 1.0).unsqueeze(0)

    watermarked_views, original_views = [], []
    with torch.no_grad():
        for viewpoint in cameras:
            watermarked_views.append(_render_clamped(viewpoint, guardsplat))
            original_views.append(_render_clamped(viewpoint, gaussians))

        return torch.cat(watermarked_views), torch.cat(original_views)

@torch.no_grad()
def eval_image_similarity(pds, gts, args):
    """Compute PSNR, SSIM and LPIPS between predicted and reference views.

    The views are processed in chunks of ``args.batch_size``. PSNR and SSIM
    outputs are collected per chunk, concatenated and averaged; LPIPS outputs
    are accumulated through ``.item()`` and divided by ``len(pds)``.
    NOTE(review): the LPIPS average assumes ``lpips`` returns a value summed
    over the batch - confirm against lpipsPyTorch.

    Returns:
        dict with float entries ``'PSNR'``, ``'SSIM'`` and ``'LPIPS'``.
    """
    step = args.batch_size
    num_batches = (len(gts) + step - 1) // step  # ceil division: last chunk may be short

    psnr_chunks, ssim_chunks, lpips_total = [], [], 0
    for batch_idx in range(num_batches):
        window = slice(batch_idx * step, (batch_idx + 1) * step)
        pd_batch, gt_batch = pds[window], gts[window]
        psnr_chunks.append(psnr(pd_batch, gt_batch).cpu())
        ssim_chunks.append(ssim(pd_batch, gt_batch, size_average=False).cpu())
        lpips_total += lpips(pd_batch, gt_batch).item()

    return {
        'PSNR'  : torch.cat(psnr_chunks).mean().item(),
        'SSIM'  : torch.cat(ssim_chunks).mean().item(),
        'LPIPS' : lpips_total / len(pds)
    }

@torch.no_grad()
def eval_bit_accuracy(pds, message, model, args):
    """Decode the message from every view and score it against ``message``.

    Views are fed to ``model`` in chunks of ``args.batch_size`` after being cast
    to half precision. The decoded chunks are concatenated and compared, via
    ``accuarcy``, against ``message`` repeated once per view.

    Returns:
        The bit accuracy as a Python float (percentage).
    """
    step = args.batch_size
    num_batches = (len(pds) + step - 1) // step  # ceil division: last chunk may be short

    decoded_chunks = [
        model(pds[batch_idx * step : (batch_idx + 1) * step].half())
        for batch_idx in range(num_batches)
    ]
    decoded = torch.cat(decoded_chunks)
    expected = message.repeat(len(pds), 1)
    return accuarcy(decoded, expected).item()

def getSize(model):
    """Return the total number of scalar elements over ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def save_maps(imglist, method_names, source_dir):
    """Write side-by-side comparison images, one PNG per view.

    Args:
        imglist: sequence of image stacks, one per method, each shaped
            (N, H, W, C). All stacks are laid out using the H, W, C of the
            first one. NOTE(review): expected to be uint8 RGB, as produced by
            the caller - confirm if used elsewhere.
        method_names: caption drawn under each method's column, in the same
            order as ``imglist``.
        source_dir: base directory; images are written to
            ``<source_dir>/results/0000.png``, ``0001.png``, ...

    Fixes over the previous version:
      * ``source_dir`` is now honoured (it was ignored in favour of the global
        ``args.source_dir``).
      * The canvas is as wide as the number of methods instead of being
        hard-coded to two columns.
    """
    # Imported lazily so the notebook only needs PIL / OpenCV when saving is enabled.
    from PIL import Image
    import cv2
    save_dir = os.path.join(source_dir, 'results')
    os.makedirs(save_dir, exist_ok=True)

    _, H, W, C = imglist[0].shape
    font, font_scale, thickness = cv2.FONT_HERSHEY_DUPLEX, 2, 2
    # Text height of a reference glyph sets the caption band height (2 * th).
    (_, th), _ = cv2.getTextSize('f', font, font_scale, thickness)
    # Caption widths, used to centre each caption under its column.
    tws = [cv2.getTextSize(method_name, font, font_scale, thickness)[0][0] for method_name in method_names]

    for idx, imgs in enumerate(zip(*imglist)):
        # White canvas: one column per method plus a caption band at the bottom.
        merge = np.full((H + 2 * th, W * len(imgs), C), 255, dtype=np.uint8)
        for col, (img, method_name, tw) in enumerate(zip(imgs, method_names, tws)):
            left = col * W
            merge[:H, left : left + W] = img
            merge = cv2.putText(merge, method_name, (left + W // 2 - tw // 2, H + int(1.5 * th)), font, font_scale, (0, 0, 0), thickness)
        Image.fromarray(np.uint8(merge)).save(os.path.join(save_dir, f'{str(idx).zfill(4)}.png'))

In [3]:
# Loading Everything

# Loading 3DGS models
gaussians = GaussianModel(args.sh_degree)
gaussians.load_ply(os.path.join(args.source_dir, f'original-{args.item}.ply'))

# Loading Watermarked 3DGS models
guardsplat = GaussianModel(args.sh_degree)
guardsplat.load_ply(os.path.join(args.source_dir, f'watermarked-{args.item}.ply'))

# Background color handed to the renderer: white or black depending on the config flag.
bg_color = [1, 1, 1] if args.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

# Loading message decoder (TorchScript checkpoint selected by message length)
message_decoder = torch.jit.load(os.path.join(args.source_dir, f'CLIP_visual+msg_decoder-{args.msg_len}.pt')).eval().cuda()

# Loading Message
with open(os.path.join(args.source_dir, f'message-{args.item}.txt'), 'r') as txtfile:
    # Strip surrounding whitespace so a trailing newline in the file does not
    # reach float() below and abort the cell with a ValueError.
    message_text = txtfile.read().strip()

# Fail early with a readable error if the message does not match the configured length.
assert len(message_text) == args.msg_len, (
    f'message-{args.item}.txt holds {len(message_text)} symbols, expected msg_len={args.msg_len}'
)
# One float per character of the message string.
message = torch.tensor([float(x) for x in message_text]).cuda()

In [5]:
# Evaluation

# Parameter counts of the two decoder sub-modules, divided by 1024 * 1024.
print('The Parameter Size of Different Modules:')
clip_size = getSize(message_decoder.CLIP_visual) / 1024 / 1024
print(f'CLIP Visual Encoder : {clip_size:.2f}M')
decoder_size = getSize(message_decoder.msg_decoder) / 1024 / 1024
print(f'Our Msg Decoder     : {decoder_size:.2f}M')

print(f'\nMessage : {message_text}')

# Render every test camera with both models: pds = watermarked, gts = original.
cameras = getTestCameras(args.source_dir, args.item)
pds, gts = extract_rendered_views_and_gts(gaussians, guardsplat, cameras, (args, background))

scene_name = args.item.capitalize()
print(f'\nEvaluating {len(cameras)} views on {scene_name} scene w.r.t N_L={args.msg_len} bits')

# Bit accuracy of the decoded message on watermarked renders vs. original renders.
acc_ours = eval_bit_accuracy(pds, message, message_decoder, args)
acc_3dgs = eval_bit_accuracy(gts, message, message_decoder, args)
print('\nBit Accuracy between Original and Watermarked Models:')
print(f'[GuardSplat (Ours)] Bit Acc : {acc_ours}')
print(f'[Original 3DGS]     Bit Acc : {acc_3dgs}')

# Visual similarity of the watermarked renders, using the original model's renders as reference.
ans = eval_image_similarity(pds, gts, args)
atext = ' | '.join(f'{k} : {v:.4f}' for k, v in ans.items())
print('\nImage Similarity against Original Model:')
print(atext)

if args.save_map:
    # Convert each (N, C, H, W) float stack to (N, H, W, C) uint8 before writing the comparison images.
    views_uint8 = []
    for stack in (gts, pds):
        views_uint8.append(np.uint8(stack.permute(0, 2, 3, 1).cpu().numpy() * 255.))
    save_maps(views_uint8, ['Original 3DGS', 'GuardSplat (Ours)'], args.source_dir)

The Parameter Size of Different Modules:
CLIP Visual Encoder : 83.78M
Our Msg Decoder     : 0.20M

Message : 11001000001100000110111001010111

Evaluating 200 views on Lego scene w.r.t N_L=32 bits

Bit Accuracy between Original and Watermarked Models:
[GuardSplat (Ours)] Bit Acc : 98.453125
[Original 3DGS]     Bit Acc : 65.40625

Image Similarity against Original Model:
PSNR : 41.3957 | SSIM : 0.9962 | LPIPS : 0.0013
