In [None]:
import os

# The notebook may be launched from the folder that *contains* the repository.
# If so, step into the repository root so the relative imports below
# (config, ./src) resolve; JUPYTER_ROOT remembers the launch folder.
print(f"当前目录: {os.getcwd()}")
if 'arch-gaussian' in os.listdir('.'):
    JUPYTER_ROOT = os.getcwd()
    os.chdir(os.path.join(JUPYTER_ROOT, 'arch-gaussian'))
    print(f"更改后的目录: {os.getcwd()}")

In [None]:
import config

# --- Basic settings ---
# Windows-style relative paths. Raw strings keep the backslashes literal; in a
# normal string "\S" is an invalid escape (SyntaxWarning on Python 3.12+).
config.scene_name = r"Hongkong\ShenShuiBu\ShenShuiBuComposite"  # input folder name
config.output_name = r"Hongkong\ShenShuiBu\ShenShuiBuComposite"  # output folder name
config.sh_degree = 3  # 0-3; 0 shrinks the file but the existing Unity tool cannot read it, so keep the default 3
config.epochs = 7000  # training iterations, e.g. 3000, 7000, 15000, 30000
config.first_iter = 0
config.loaded_iter = 7000
# --- Advanced settings ---
config.resolution = 1  # image downscaling: 1 = original size, -1 = auto-downscale to at most 1600 px; powers of 2 are also accepted
config.densify_until_iter = 15000  # iteration at which densification stops
config.densify_grad_threshold = 0.0002  # densification threshold; smaller values add points faster

# Configuration done: push the values into the argument objects.
config.update_colmap_args()
config.update_args()

# Print the parameters that were changed.
config.print_updated_colmap_args()
config.print_updated_args()


In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("./src")
import os
import torch
import numpy as np
np.set_printoptions(suppress=True)
from config import args  # 正式导入args



## 创建scene info 与修复

In [None]:
# Build the scene info from the configured input folder, with the
# project's fix-ups applied by load_and_fix_scene.
from manager.scene_manager import load_and_fix_scene

scene_info = load_and_fix_scene(args)


## 导入相机

In [None]:
from manager.camera_manager import CameraManager

cm = CameraManager()
cm.create_cameras(args, scene_info)  # build cameras from the loaded scene

# Optional: drop the last camera (skip this call to keep every camera).
cm.remove_last_camera()

In [None]:
# Sanity check: how many cameras are available at resolution scale 1.0.
full_res_cameras = cm.sorted_cameras[1.0]
print(len(full_res_cameras))

## 创建gaussian对象

In [None]:
from manager.gaussian_manager import GaussianManager

# Initialise the Gaussians from the loaded scene and report how many there are.
gm = GaussianManager(args, scene_info)
n_points = gm.gaussians.get_xyz.shape[0]
print(f"num points: {n_points}")

## 使用封装好的Gaussian Manager管理gaussian

In [None]:
from manager.gaussian_manager import GaussianManager

# Re-wrap the Gaussians already in memory in a fresh manager (no reload),
# then render a preview from the first camera.
gm = GaussianManager(args, gm.gaussians)
cam = cm.pick_camera(0)
image = gm.render(cam, convert_to_pil=True)
image

In [None]:
# Highlight the Gaussians inside an axis-aligned box given as (min corner,
# max corner). Inside gm.virtual() the edits are temporary; the original
# data is restored when the block exits.
lower = torch.tensor([-0.5,0,-100])
upper = torch.tensor([0.5,0.5,100])
position_range = (lower, upper)
with gm.virtual():
    mask = gm.position_mask(lower, upper)
    gm.paint_by_mask(mask)
    image = gm.render(cam, convert_to_pil= True)
image

In [None]:
# Load a selection exported as JSON and hide it in a temporary copy
# (alpha -100; presumably a logit, so effectively transparent).
# NOTE(review): hardcoded absolute path; only valid on the author's machine.
selection_json = r"D:\FengYiheng\Projects\arch-gaussian\output\Hongkong\ShenShuiBu\ShenShuiBuComposite\point_cloud\iteration_7000\ShenShuiBuComposite-point_cloud-iteration_7000-point_cloud.json"
mask = gm.mask_from_json(selection_json)
with gm.virtual():
    gm.set_alpha(-100, mask)
    image = gm.render(cam, convert_to_pil= True)
image

In [None]:
from tqdm.auto import tqdm
from manager.train_manager import init_snapshot, take_snapshot, SnapshotCameraMode, SnapshotFilenameMode
init_snapshot(106)  # presumably the starting snapshot counter for this sequence
with gm.virtual():
    snapshot_options = dict(
        _camera_mode=SnapshotCameraMode.SLOW_ROTATE,
        _slow_ratio=5,
        _filename_mode=SnapshotFilenameMode.BY_SNAP_COUNT,
        _folder_name="snapshots2",
        _iteration_gap=1,
        _first_period_iteration_gap=1,
        _first_period_end=0,
        args=args,
        image=None,
    )
    # 150 frames of a slowly rotating camera over the unmodified model.
    for frame in tqdm(range(150)):
        take_snapshot(cm, gm, iteration=frame, **snapshot_options)


In [None]:
init_snapshot(106 + 150)
with gm.virtual():
    # NOTE(review): hardcoded absolute path; only valid on the author's machine.
    bbox_json = r"D:\FengYiheng\Projects\arch-gaussian\output\Hongkong\ShenShuiBu\ShenShuiBuComposite\point_cloud\iteration_7000\ShenShuiBuComposite-point_cloud-iteration_7000-point_cloud.json"
    bboxes = gm.bboxes_from_json(bbox_json,-2, 0)
    # For every box in every region: clear the "rest" features (presumably
    # the higher-order SH coefficients), set a flat colour of 0.5 and add
    # position noise (presumably bounded by the box).
    for bbox in (bb for region in bboxes for bb in region):
        mask = gm.position_mask(*bbox)
        gm.clear_features_rest(mask)
        gm.set_color(0.5, mask)
        gm.noise_position(mask, bbox)
    # `mask` is left holding the last box's selection; the next cell reuses it.

    snapshot_options = dict(
        _camera_mode=SnapshotCameraMode.SLOW_ROTATE,
        _slow_ratio=5,
        _filename_mode=SnapshotFilenameMode.BY_SNAP_COUNT,
        _folder_name="snapshots3",
        _iteration_gap=1,
        _first_period_iteration_gap=1,
        _first_period_end=0,
        args=args,
        image=None,
    )
    for frame in tqdm(range(10)):
        take_snapshot(cm, gm, iteration=frame, **snapshot_options)


In [None]:
# TODO(review): the previous cell takes only 10 snapshots but this offset
# assumes 150; confirm the intended numbering.
init_snapshot(106 + 150 + 150)
with gm.virtual():
    # NOTE(review): hidden state. `mask` is whatever the last-run cell left
    # behind (the last bounding box of the previous cell when run in order).
    gm.set_alpha(-100, mask)
    snapshot_options = {
        "_camera_mode": SnapshotCameraMode.SLOW_ROTATE,
        "_slow_ratio": 5,
        "_filename_mode": SnapshotFilenameMode.BY_SNAP_COUNT,
        "_folder_name": "snapshots4",
        "_iteration_gap": 1,
        "_first_period_iteration_gap": 1,
        "_first_period_end": 0,
        "args": args,
        "image": None,
    }
    for frame in tqdm(range(150)):
        take_snapshot(cm, gm, iteration=frame, **snapshot_options)


In [None]:
# 离开virtual环境后，会恢复原有的数据
# image = gm.render(cam, convert_to_pil= True)
# image

In [None]:
from utils.image_utils import get_pil_image, save_pil_image

import torchvision.transforms.functional as TF

def gt_socket(**kwargs):
    """Ground-truth hook: return a hue-shifted image in place of the real GT.

    Every 100 iterations the untouched ground truth is also saved as
    ``<args.model_path>/snap_shots/<iteration:05d>_gt_org.jpg`` for comparison
    (``args`` here is the notebook-global config object).

    Keyword Args:
        iteration: current training iteration.
        viewpoint_cam: camera whose ``original_image`` is the real GT.

    Returns:
        ``TF.adjust_hue(original_image, 0.4)``.
    """
    step = kwargs['iteration']
    original_gt = kwargs['viewpoint_cam'].original_image

    if not step % 100:
        pil_original = get_pil_image(original_gt)
        save_pil_image(pil_original, os.path.join(args.model_path, "snap_shots", f"{step:05d}_gt_org.jpg"))

    return TF.adjust_hue(original_gt, 0.4)

In [None]:
def loss_socket(**kwargs):
    """Loss hook: only let Gaussians inside ``position_range`` be optimised.

    Called by ``train`` after the loss is computed. It clears the gradients
    of every Gaussian outside the box (via ``gm.clear_grads``), so only the
    selected region is updated.

    The selection mask is cached on this function instead of the
    notebook-wide ``mask`` variable. Other cells (JSON selection, bbox
    editing, drawing) reassign ``mask``, and reusing their leftover value
    restricted training to the wrong region. The cache is rebuilt whenever
    the number of Gaussians changes (e.g. densification or pruning), because
    the old mask then no longer lines up with the point array. Re-running
    this cell resets the cache.

    Reads the notebook globals ``gm`` and ``position_range``. The latter
    should be the torch-tensor range; the later box-drawing cell rebinds it
    to numpy arrays.

    Keyword Args:
        iteration: current training iteration (used for logging only).
    """
    iteration = kwargs['iteration']
    n_points = gm.gaussians._xyz.shape[0]
    region_mask = getattr(loss_socket, "_region_mask", None)
    if region_mask is None or region_mask.shape[0] != n_points:
        print(f"mask updated at iter {iteration}")
        region_mask = gm.position_mask(*position_range)
        loss_socket._region_mask = region_mask
    gm.clear_grads(~region_mask)



In [None]:
from utils.image_utils import save_pil_image


def post_socket(**kwargs):
    """Post-iteration hook: periodically dump images and report the point count.

    Every 100 iterations it saves the rendered image as
    ``<model_path>/snap_shots/<iteration:05d>.jpg`` and the GT image passed by
    the trainer as ``..._gt.jpg``. Every 1000 iterations it prints the shape
    of the Gaussian position tensor.

    Keyword Args:
        args: run arguments; only ``model_path`` is used.
        iteration: current training iteration.
        image: rendered image for this iteration.
        gt_image: ground-truth image used for this iteration.
        gaussians: Gaussian model; ``get_xyz`` is inspected.
    """
    run_args = kwargs['args']
    iteration = kwargs['iteration']
    rendered = kwargs['image']
    gt = kwargs['gt_image']
    gaussians = kwargs['gaussians']

    if iteration % 100 == 0:
        snap_dir = os.path.join(run_args.model_path, "snap_shots")
        for picture, suffix in ((rendered, ""), (gt, "_gt")):
            pil_picture = get_pil_image(picture)
            save_pil_image(pil_picture, os.path.join(snap_dir, f"{iteration:05d}{suffix}.jpg"))

    if iteration % 1000 == 0:
        print(gaussians.get_xyz.shape)

In [None]:
from manager.train_manager import train, init_output_folder

init_output_folder(args, scene_info)

# Train with all three hooks: hue-shifted GT (gt_socket), region-restricted
# gradients (loss_socket) and periodic snapshots (post_socket).
train(
    args,
    scene_info,
    gm.gaussians,
    cm.train_cameras,
    gt_socket=gt_socket,
    loss_socket=loss_socket,
    post_socket=post_socket,
)

# Alternative, snapshots only:
# train(args, scene_info, gm.gaussians, cm.train_cameras, post_socket=post_socket)

## 在模型上叠加绘图

In [None]:
import matplotlib.pyplot as plt
# Axis-aligned box to overlay, as (min corner, max corner) in world space.
# NOTE(review): this rebinds `position_range` to numpy arrays, replacing the
# torch tensors from the earlier highlight cell that loss_socket reads.
position_range = (np.array([-0.5,0,-2]), np.array([0.5,0.5,-1.9]))
a, b = position_range

# The 8 corners, indexed as 4*z_bit + 2*x_bit + y_bit (bit 0 -> `a`,
# bit 1 -> `b`). Corners 0..3 lie on the z = a[2] face, 4..7 on z = b[2].
cube_points = np.array([
    [x, y, z]
    for z in (a[2], b[2])
    for x in (a[0], b[0])
    for y in (a[1], b[1])
])

# The 12 edges as index pairs into cube_points.
edges = [
    (0, 1), (1, 3), (3, 2), (2, 0),  # face z = a[2]
    (4, 5), (5, 7), (7, 6), (6, 4),  # face z = b[2]
    (0, 4), (1, 5), (2, 6), (3, 7),  # edges joining the two faces
]

print(cube_points)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from manager.display_manager import Geometry, Drawer

# Overlay the box outline on renders from every training camera and save
# them as ./cache/<image_name>.jpg.
box = Geometry(cube_points,edges)
drawer = Drawer()
drawer.add_geometry(box)

cache_dir = os.path.join(os.getcwd(), "cache")
os.makedirs(cache_dir, exist_ok=True)

with gm.virtual():
    # The highlighted region does not depend on the camera: select and paint once.
    mask = gm.position_mask(a, b)
    gm.paint_by_mask(mask)

    # `train_cameras` was never defined in this notebook (NameError on a fresh
    # kernel). Use the camera manager's training cameras at scale 1.
    # NOTE(review): assumes cm.train_cameras is keyed by resolution scale,
    # like cm.sorted_cameras; confirm.
    for cam in cm.train_cameras[1]:
        image = gm.render(cam, convert_to_pil=True)
        drawer.draw(cam, image)
        save_path = os.path.join(cache_dir, f"{cam.image_name}.jpg")
        image.save(save_path)

## View Online

In [None]:
import os
import numpy as np
from plyfile import PlyData

# Entry point of the "View Online" section: when started from the folder
# containing the repository, switch into the repository root.
print(f"当前目录: {os.getcwd()}")
if 'arch-gaussian' in os.listdir('.'):
    JUPYTER_ROOT = os.getcwd()
    os.chdir(os.path.join(JUPYTER_ROOT, 'arch-gaussian'))
    print(f"更改后的目录: {os.getcwd()}")

In [None]:
# Trained point cloud for the configured iteration count.
point_cloud_path = os.path.join(
    args.model_path,
    f"point_cloud/iteration_{args.iterations}",
    "point_cloud.ply",
)
if not os.path.exists(point_cloud_path):
    print("file not found")

In [None]:
def splat2np(_path):
    """Read a ``.splat`` file back into a structured numpy array.

    The file is a headerless sequence of 32-byte records in the layout built
    by ``ply2np`` and written by ``np2splat``: float32 position[3],
    float32 scale[3], uint8 RGBA[4], uint8 quaternion IJKL[4].

    Args:
        _path: path to the ``.splat`` file.

    Returns:
        Read-only structured array viewing the file's bytes.

    Raises:
        ValueError: if the file size is not a multiple of the 32-byte record.
    """
    _dt = np.dtype([
        ('position', np.float32, 3),
        ('scale', np.float32, 3),
        ('RGBA', np.uint8, 4),
        ('IJKL', np.uint8, 4)
    ])
    # Binary mode is required. Text mode returns `str` (and may fail to decode
    # arbitrary bytes), which np.frombuffer cannot interpret.
    with open(_path, 'rb') as f:
        _b = f.read()
    return np.frombuffer(_b, _dt)

def np2splat(_data, _save_path):
    """Write a structured splat array to ``_save_path`` as raw bytes.

    No header is written; the file is simply the array's records in C order,
    which is what ``splat2np`` reads back.
    """
    with open(_save_path, 'wb') as fh:
        fh.write(_data.tobytes())
    print(f"data saved to {_save_path}")

def ply2np(_ply_path):
    """Convert a trained Gaussian-splatting PLY into packed 32-byte splat records.

    Reads per-Gaussian properties from the first PLY element and converts them
    to the layout written by ``np2splat``:

    * ``position``: float32 x, y, z, copied as-is.
    * ``scale``: float32, ``exp`` of the stored ``scale_*`` values.
    * ``RGBA``: uint8. RGB comes from the degree-0 SH coefficients
      (``0.5 + SH_C0 * f_dc``) and alpha is ``sigmoid(opacity)``; both are
      scaled to 0-255 and clipped. Higher-order SH coefficients are not read.
    * ``IJKL``: uint8 rotation quaternion, normalized to unit length and
      mapped from [-1, 1] to [0, 255] via ``q * 128 + 128``.

    Args:
        _ply_path: path to the ``point_cloud.ply`` file.

    Returns:
        numpy structured array with one record per Gaussian.
    """
    _dt = np.dtype([
        ('position', np.float32, 3),
        ('scale', np.float32, 3),
        ('RGBA', np.uint8, 4),
        ('IJKL', np.uint8, 4)
    ])

    vertex = PlyData.read(_ply_path).elements[0]

    def _stack(*names):
        # Gather the named per-vertex properties as columns of an (N, k) array.
        return np.stack([np.array(vertex[name]) for name in names], axis=1)

    position = _stack('x', 'y', 'z')
    scales = _stack('scale_0', 'scale_1', 'scale_2')
    rots = _stack('rot_0', 'rot_1', 'rot_2', 'rot_3')
    rgba = _stack('f_dc_0', 'f_dc_1', 'f_dc_2', 'opacity')

    # Normalize by the Euclidean length, not the squared length: the two only
    # agree for unit quaternions, and trained rotations are not normalized.
    # A zero-length quaternion would divide by zero, so it maps to 128 (= 0).
    qlen = np.linalg.norm(rots, axis=1)
    qlen[qlen == 0] = 1
    rots = rots / qlen[:, np.newaxis] * 128 + 128
    rots = np.clip(rots, 0, 255)

    scales = np.exp(scales)

    SH_C0 = 0.28209479177387814  # degree-0 real SH basis constant, 1 / (2 * sqrt(pi))
    rgba[:, 0:3] = (0.5 + SH_C0 * rgba[:, 0:3]) * 255
    rgba[:, 3] = (1 / (1 + np.exp(-rgba[:, 3]))) * 255  # sigmoid(opacity)
    rgba = np.clip(rgba, 0, 255)

    merged = np.empty((position.shape[0]), dtype=_dt)
    merged['position'] = position
    merged['scale'] = scales
    merged['RGBA'] = rgba.astype(np.uint8)
    merged['IJKL'] = rots.astype(np.uint8)
    return merged


In [None]:
# Pack the trained point cloud and write it next to the PLY as output.splat.
data = ply2np(point_cloud_path)
print(data[0])  # first packed record, as a quick sanity check
splat_dir = os.path.dirname(point_cloud_path)
save_path = f"{splat_dir}/output.splat"
np2splat(data, save_path)

In [None]:
import functools
import http.server
import socketserver

# Folder to serve, relative to the repository root this notebook chdir'ed into.
root_directory = './web/WebGLViewer'

# HTTP port.
port = 8000

# SimpleHTTPRequestHandler takes its folder from the `directory` constructor
# argument and falls back to the current working directory in __init__. A
# class attribute assigned beforehand is therefore overwritten, and the cwd
# was served. Bind the folder with functools.partial instead.
Handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=root_directory)

# ThreadingHTTPServer sets allow_reuse_address, so re-running this cell right
# after stopping the server doesn't fail with "address already in use". It
# also serves the viewer's parallel requests concurrently.
with http.server.ThreadingHTTPServer(("", port), Handler) as httpd:
    print(f"Serving at port {port}")
    try:
        # Blocks this kernel; use "Interrupt Kernel" to stop the server.
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("Server stopped")
