Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precalc some splat data each frame #6

Merged
merged 3 commits into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 91 additions & 15 deletions Assets/GaussianSplatting/Scripts/GaussianSplatRenderer.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using Unity.Collections.LowLevel.Unsafe;
using Unity.Profiling;
using Unity.Profiling.LowLevel;
using UnityEngine;
using UnityEngine.Experimental.Rendering;
using UnityEngine.Rendering;
Expand Down Expand Up @@ -47,7 +49,7 @@ public enum DisplayDataMode
public RenderMode m_RenderMode = RenderMode.Splats;
[Range(1.0f,15.0f)] public float m_PointDisplaySize = 3.0f;
public DisplayDataMode m_DisplayData = DisplayDataMode.None;
[Range(1, 8)] public int m_DisplayDataScale = 1;
[Range(1, 8)] public int m_DisplayDataScale = 2;
public bool m_RenderInSceneView = true;
[Tooltip("Use AMD FidelityFX sorting when available, instead of the slower bitonic sort")]
public bool m_PreferFfxSort = true; // use AMD FidelityFX sort if available (currently: DX12, Vulkan, Metal, but *not* DX11)
Expand All @@ -70,6 +72,7 @@ public enum DisplayDataMode
GraphicsBuffer m_GpuSortDistances;
GraphicsBuffer m_GpuSortKeys;
GraphicsBuffer m_GpuChunks;
GraphicsBuffer m_GpuView;

IslandGPUSort m_SorterIsland;
IslandGPUSort.Args m_SorterIslandArgs;
Expand All @@ -89,6 +92,11 @@ public enum DisplayDataMode
GaussianSplatAsset m_PrevAsset;
Hash128 m_PrevHash;

static ProfilerMarker s_ProfSort = new ProfilerMarker(ProfilerCategory.Render, "GaussianSplat.Sort", MarkerFlags.SampleGPU);
static ProfilerMarker s_ProfView = new ProfilerMarker(ProfilerCategory.Render, "GaussianSplat.View", MarkerFlags.SampleGPU);
static ProfilerMarker s_ProfDraw = new ProfilerMarker(ProfilerCategory.Render, "GaussianSplat.Draw", MarkerFlags.SampleGPU);
static ProfilerMarker s_ProfCompose = new ProfilerMarker(ProfilerCategory.Render, "GaussianSplat.Compose", MarkerFlags.SampleGPU);

public GaussianSplatAsset asset => m_Asset;

public bool HasValidAsset => m_Asset != null && m_Asset.m_SplatCount > 0;
Expand All @@ -101,6 +109,8 @@ void CreateResourcesForAsset()
m_GpuChunks = new GraphicsBuffer(GraphicsBuffer.Target.Structured, asset.m_Chunks.Length, UnsafeUtility.SizeOf<GaussianSplatAsset.ChunkInfo>()) { name = "GaussianChunkData" };
m_GpuChunks.SetData(asset.m_Chunks);

m_GpuView = new GraphicsBuffer(GraphicsBuffer.Target.Structured, m_Asset.m_SplatCount, 40);

int splatCountNextPot = Mathf.NextPowerOfTwo(m_Asset.m_SplatCount);
m_GpuSortDistances = new GraphicsBuffer(GraphicsBuffer.Target.Structured, splatCountNextPot, 4) { name = "GaussianSplatSortDistances" };
m_GpuSortKeys = new GraphicsBuffer(GraphicsBuffer.Target.Structured, splatCountNextPot, 4) { name = "GaussianSplatSortIndices" };
Expand Down Expand Up @@ -175,6 +185,8 @@ void OnPreCullCamera(Camera cam)
displayMat.SetBuffer("_SplatChunks", m_GpuChunks);
displayMat.SetInteger("_SplatChunkCount", m_GpuChunks.count);

displayMat.SetBuffer("_SplatViewData", m_GpuView);

displayMat.SetBuffer("_OrderBuffer", m_GpuSortKeys);
displayMat.SetFloat("_SplatScale", m_SplatScale);
displayMat.SetFloat("_SplatSize", m_PointDisplaySize);
Expand All @@ -190,6 +202,8 @@ void OnPreCullCamera(Camera cam)
SortPoints(cam, matrix);
++m_FrameCounter;

CalcViewData(cam, matrix);

int vertexCount = 6;
int instanceCount = m_Asset.m_SplatCount;
MeshTopology topology = MeshTopology.Triangles;
Expand All @@ -209,12 +223,15 @@ void OnPreCullCamera(Camera cam)
{
m_RenderCommandBuffer.GetTemporaryRT(rtNameID, -1, -1, 0, FilterMode.Point,
GraphicsFormat.R16G16B16A16_SFloat);
m_RenderCommandBuffer.BeginSample(s_ProfDraw);
m_RenderCommandBuffer.SetRenderTarget(rtNameID, BuiltinRenderTextureType.CurrentActive);
m_RenderCommandBuffer.ClearRenderTarget(RTClearFlags.Color, new Color(0, 0, 0, 0), 0, 0);
m_RenderCommandBuffer.DrawProcedural(matrix, displayMat, 0, topology, vertexCount,
instanceCount);
m_RenderCommandBuffer.DrawProcedural(matrix, displayMat, 0, topology, vertexCount, instanceCount);
m_RenderCommandBuffer.EndSample(s_ProfDraw);
m_RenderCommandBuffer.BeginSample(s_ProfCompose);
m_RenderCommandBuffer.SetRenderTarget(BuiltinRenderTextureType.CameraTarget);
m_RenderCommandBuffer.DrawProcedural(Matrix4x4.identity, m_MatComposite, 0, MeshTopology.Triangles, 6, 1);
m_RenderCommandBuffer.EndSample(s_ProfCompose);
m_RenderCommandBuffer.ReleaseTemporaryRT(rtNameID);
}

Expand Down Expand Up @@ -256,6 +273,19 @@ static string TextureTypeToPropertyName(GaussianSplatAsset.TexType type)
};
}

void SetAssetTexturesOnCS(ComputeShader cs, int kernelIndex)
{
uint texFlags = 0;
for (var t = GaussianSplatAsset.TexType.Pos; t < GaussianSplatAsset.TexType.TypeCount; ++t)
{
var tex = m_Asset.GetTex(t);
if (tex.graphicsFormat == GraphicsFormat.R32_SFloat) // so that a shader knows it needs to interpret R32F as packed integer
texFlags |= (1u << (int) t);
cs.SetTexture(kernelIndex, TextureTypeToPropertyName(t), tex);
}
cs.SetInt("_TexFlagBits", (int)texFlags);
}

void SetAssetTexturesOnMaterial(Material displayMat)
{
uint texFlags = 0;
Expand Down Expand Up @@ -286,11 +316,13 @@ void DisposeResourcesForAsset()
}

m_GpuChunks?.Dispose();
m_GpuView?.Dispose();
m_GpuSortDistances?.Dispose();
m_GpuSortKeys?.Dispose();
m_SorterFfxArgs.resources.Dispose();

m_GpuChunks = null;
m_GpuView = null;
m_GpuSortDistances = null;
m_GpuSortKeys = null;
}
Expand All @@ -311,11 +343,52 @@ public void OnDisable()
DestroyImmediate(m_MatDebugData);
}

void SortPoints(Camera cam, Matrix4x4 matrix)
void CalcViewData(Camera cam, Matrix4x4 matrix)
{
if (cam.cameraType == CameraType.Preview || !m_RenderInSceneView && cam.cameraType == CameraType.SceneView)
return;

using var prof = s_ProfView.Auto();

var tr = transform;

Matrix4x4 matView = cam.worldToCameraMatrix;
Matrix4x4 matProj = GL.GetGPUProjectionMatrix(cam.projectionMatrix, true);
Matrix4x4 matO2W = tr.localToWorldMatrix;
Matrix4x4 matW2O = tr.worldToLocalMatrix;
int screenW = cam.pixelWidth, screenH = cam.pixelHeight;
Vector4 screenPar = new Vector4(screenW, screenH, 0, 0);
Vector4 camPos = cam.transform.position;

// calculate view dependent data for each splat
const int kernelIdx = 2;
SetAssetTexturesOnCS(m_CSSplatUtilities, kernelIdx);

m_CSSplatUtilities.SetInt("_SplatCount", m_GpuView.count);
m_CSSplatUtilities.SetBuffer(kernelIdx, "_SplatViewData", m_GpuView);
m_CSSplatUtilities.SetBuffer(kernelIdx, "_OrderBuffer", m_GpuSortKeys);
m_CSSplatUtilities.SetBuffer(kernelIdx, "_SplatChunks", m_GpuChunks);

m_CSSplatUtilities.SetMatrix("_MatrixVP", matProj * matView);
m_CSSplatUtilities.SetMatrix("_MatrixV", matView);
m_CSSplatUtilities.SetMatrix("_MatrixP", matProj);
m_CSSplatUtilities.SetMatrix("_MatrixObjectToWorld", matO2W);
m_CSSplatUtilities.SetMatrix("_MatrixWorldToObject", matW2O);

m_CSSplatUtilities.SetVector("_VecScreenParams", screenPar);
m_CSSplatUtilities.SetVector("_VecWorldSpaceCameraPos", camPos);
m_CSSplatUtilities.SetFloat("_SplatScale", m_SplatScale);
m_CSSplatUtilities.SetInt("_SHOrder", m_SHOrder);

m_CSSplatUtilities.GetKernelThreadGroupSizes(kernelIdx, out uint gsX, out uint gsY, out uint gsZ);
m_CSSplatUtilities.Dispatch(kernelIdx, (m_GpuView.count + (int)gsX - 1)/(int)gsX, 1, 1);
}

void SortPoints(Camera cam, Matrix4x4 matrix)
{
if (cam.cameraType == CameraType.Preview || !m_RenderInSceneView && cam.cameraType == CameraType.SceneView)
return;

bool useFfx = m_PreferFfxSort && m_SorterFfx.Valid;
Matrix4x4 worldToCamMatrix = cam.worldToCameraMatrix;
if (useFfx)
Expand All @@ -326,24 +399,27 @@ void SortPoints(Camera cam, Matrix4x4 matrix)
}

// calculate distance to the camera for each splat
int kernelIdx = 1;
m_RenderCommandBuffer.BeginSample(s_ProfSort);
var texPos = m_Asset.GetTex(GaussianSplatAsset.TexType.Pos);
m_CSSplatUtilities.SetTexture(1, "_TexPos", texPos);
m_CSSplatUtilities.SetInt("_TexFlagBits", texPos.graphicsFormat == GraphicsFormat.R32_SFloat ? 1 : 0);
m_CSSplatUtilities.SetBuffer(1, "_SplatSortDistances", m_GpuSortDistances);
m_CSSplatUtilities.SetBuffer(1, "_SplatSortKeys", m_GpuSortKeys);
m_CSSplatUtilities.SetBuffer(1, "_SplatChunks", m_GpuChunks);
m_CSSplatUtilities.SetMatrix("_LocalToWorldMatrix", matrix);
m_CSSplatUtilities.SetMatrix("_WorldToCameraMatrix", worldToCamMatrix);
m_CSSplatUtilities.SetInt("_SplatCount", m_Asset.m_SplatCount);
m_CSSplatUtilities.SetInt("_SplatCountPOT", m_GpuSortDistances.count);
m_CSSplatUtilities.GetKernelThreadGroupSizes(1, out uint gsX, out uint gsY, out uint gsZ);
m_CSSplatUtilities.Dispatch(1, (m_GpuSortDistances.count + (int)gsX - 1)/(int)gsX, 1, 1);
m_RenderCommandBuffer.SetComputeTextureParam(m_CSSplatUtilities, kernelIdx, "_TexPos", texPos);
m_RenderCommandBuffer.SetComputeIntParam(m_CSSplatUtilities, "_TexFlagBits", texPos.graphicsFormat == GraphicsFormat.R32_SFloat ? 1 : 0);
m_RenderCommandBuffer.SetComputeBufferParam(m_CSSplatUtilities, kernelIdx, "_SplatSortDistances", m_GpuSortDistances);
m_RenderCommandBuffer.SetComputeBufferParam(m_CSSplatUtilities, kernelIdx, "_SplatSortKeys", m_GpuSortKeys);
m_RenderCommandBuffer.SetComputeBufferParam(m_CSSplatUtilities, kernelIdx, "_SplatChunks", m_GpuChunks);
m_RenderCommandBuffer.SetComputeMatrixParam(m_CSSplatUtilities, "_LocalToWorldMatrix", matrix);
m_RenderCommandBuffer.SetComputeMatrixParam(m_CSSplatUtilities, "_WorldToCameraMatrix", worldToCamMatrix);
m_RenderCommandBuffer.SetComputeIntParam(m_CSSplatUtilities, "_SplatCount", m_Asset.m_SplatCount);
m_RenderCommandBuffer.SetComputeIntParam(m_CSSplatUtilities, "_SplatCountPOT", m_GpuSortDistances.count);
m_CSSplatUtilities.GetKernelThreadGroupSizes(kernelIdx, out uint gsX, out _, out _);
m_RenderCommandBuffer.DispatchCompute(m_CSSplatUtilities, kernelIdx, (m_GpuSortDistances.count + (int)gsX - 1)/(int)gsX, 1, 1);

// sort the splats
if (useFfx)
m_SorterFfx.Dispatch(m_RenderCommandBuffer, m_SorterFfxArgs);
else
m_SorterIsland.Dispatch(m_RenderCommandBuffer, m_SorterIslandArgs);
m_RenderCommandBuffer.EndSample(s_ProfSort);
}

public void Update()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
boxSize *= _SplatScale;

float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);
splatRotScaleMat = mul((float3x3)unity_ObjectToWorld, splatRotScaleMat);

centerWorldPos = splat.pos * float3(1,1,-1);
centerWorldPos = splat.pos;
centerWorldPos = mul(unity_ObjectToWorld, float4(centerWorldPos,1)).xyz;

o.col.rgb = saturate(splat.sh.col);
o.col.a = saturate(splat.opacity);
Expand All @@ -84,6 +86,7 @@ v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
localPos = localPos * 0.5 + 0.5;
SplatChunkInfo chunk = _SplatChunks[instID];
localPos = lerp(chunk.boundsMin.pos, chunk.boundsMax.pos, localPos);
localPos = mul(unity_ObjectToWorld, float4(localPos,1)).xyz;

o.col.rgb = palette((float)instID / (float)_SplatChunkCount, half3(0.5,0.5,0.5), half3(0.5,0.5,0.5), half3(1,1,1), half3(0.0, 0.33, 0.67));
o.col.a = 0.1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)

SplatData splat = LoadSplatData(splatIndex);

float3 centerWorldPos = splat.pos * float3(1,1,-1);
float3 centerWorldPos = splat.pos;
centerWorldPos = mul(unity_ObjectToWorld, float4(centerWorldPos,1)).xyz;

float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));

Expand Down
7 changes: 7 additions & 0 deletions Assets/GaussianSplatting/Shaders/GaussianSplatting.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -365,4 +365,11 @@ SplatData LoadSplatDataRaw(uint2 coord2)
return s;
}

struct SplatViewData
{
float4 pos;
float4 conicRadius;
uint2 color; // 4xFP16
};

#endif // GAUSSIAN_SPLATTING_HLSL
55 changes: 13 additions & 42 deletions Assets/GaussianSplatting/Shaders/RenderGaussianSplats.shader
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ CGPROGRAM
#pragma require compute
#pragma use_dxc metal vulkan

#include "UnityCG.cginc"
#include "GaussianSplatting.hlsl"

StructuredBuffer<uint> _OrderBuffer;
Expand All @@ -29,67 +28,39 @@ struct v2f
float4 vertex : SV_POSITION;
};

float _SplatScale;
uint _SHOrder;
StructuredBuffer<SplatViewData> _SplatViewData;

v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
{
v2f o;
instID = _OrderBuffer[instID];
SplatData splat = LoadSplatData(instID);

float4 boxRot = splat.rot;
float3 boxSize = splat.scale;
boxSize *= _SplatScale;

float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);
splatRotScaleMat = mul((float3x3)unity_ObjectToWorld, splatRotScaleMat);

float3 centerWorldPos = splat.pos;
centerWorldPos = mul(unity_ObjectToWorld, float4(centerWorldPos,1)).xyz;

float3 worldViewDir = _WorldSpaceCameraPos.xyz - centerWorldPos;
float3 objViewDir = mul((float3x3)unity_WorldToObject, worldViewDir);
objViewDir = normalize(objViewDir);

o.col.rgb = ShadeSH(splat.sh, objViewDir, _SHOrder);
o.col.a = splat.opacity;

float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));

SplatViewData view = _SplatViewData[instID];
o.col.r = f16tof32(view.color.x >> 16);
o.col.g = f16tof32(view.color.x);
o.col.b = f16tof32(view.color.y >> 16);
o.col.a = f16tof32(view.color.y);
o.conic = view.conicRadius.xyz;

float4 centerClipPos = view.pos;
bool behindCam = centerClipPos.w <= 0;
o.centerScreenPos = (centerClipPos.xy / centerClipPos.w * float2(0.5, 0.5*_ProjectionParams.x) + 0.5) * _ScreenParams.xy;

float3 cov3d0, cov3d1;
CalcCovariance3D(splatRotScaleMat, cov3d0, cov3d1);
float3 cov2d = CalcCovariance2D(centerWorldPos, cov3d0, cov3d1, UNITY_MATRIX_V, UNITY_MATRIX_P, _ScreenParams);

// conic
float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
float3 conic = float3(cov2d.z, -cov2d.y, cov2d.x) * rcp(det);
o.conic = conic;

// make the quad in screenspace the required size to cover the extents
// of the 2D splat.
//@TODO: should be possible to orient the quad to cover an elongated
// splat tighter

// two bits per vertex index to result in 0,1,2,1,3,2 from lowest:
// 0b1011'0110'0100
uint quadIndices = 0xB64;
uint idx = quadIndices >> (vtxID * 2);
float2 quadPos = float2(idx&1, (idx>>1)&1) * 2.0 - 1.0;

float mid = 0.5f * (cov2d.x + cov2d.z);
float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
float radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
float radius = view.conicRadius.w;

float2 deltaScreenPos = quadPos * radius * 2 / _ScreenParams.xy;
o.vertex = centerClipPos;
o.vertex.xy += deltaScreenPos * centerClipPos.w;

if (behindCam)
o.vertex = 0.0 / 0.0;
o.vertex = 0.0 / 0.0;

return o;
}

Expand Down
Loading