Skip to content

Commit

Permalink
Merge branch 'yaobin_dldnMR/restirReflection' into 'main'
Browse files Browse the repository at this point in the history
[REMIX-2894] Adjust ReSTIR GI Search Radius.
Increasing search radius can reduce sample coherency and make ReSTIR GI's output more friendly to denoisers.

* Increase search radius when camera is quickly approaching an object.
* Increase search radius when reflection reprojection is applied.

See merge request lightspeedrtx/dxvk-remix-nv!762
  • Loading branch information
Yaobin Ouyang committed Mar 27, 2024
2 parents a435dfc + 9662692 commit 3afe623
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/dxvk/rtx_render/rtx_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ namespace dxvk {
constants.permutationSamplingSize = restirGI.permutationSamplingSize();
constants.enableReSTIRGITemporalBiasCorrection = restirGI.useTemporalBiasCorrection();
constants.enableReSTIRGIDiscardEnlargedPixels = restirGI.useDiscardEnlargedPixels();
constants.reSTIRGIHistoryDiscardStrength = restirGI.historyDiscardStrength();
constants.enableReSTIRGITemporalJacobian = restirGI.useTemporalJacobian();
constants.reSTIRGIFireflyThreshold = restirGI.fireflyThreshold();
constants.reSTIRGIRoughnessClamp = restirGI.roughnessClamp();
Expand Down
1 change: 1 addition & 0 deletions src/dxvk/rtx_render/rtx_restir_gi_rayquery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ namespace dxvk {
ImGui::DragInt("Temporal History Length (frame)", &temporalFixedHistoryLengthObject(), 1.f, 1, 500, "%d", ImGuiSliderFlags_AlwaysClamp);
ImGui::DragInt("Permutation Sampling Size", &permutationSamplingSizeObject(), 0.1f, 1, 8, "%d", ImGuiSliderFlags_AlwaysClamp);
ImGui::Checkbox("Discard Enlarged Pixels", &useDiscardEnlargedPixelsObject());
ImGui::DragFloat("History Discard Strength", &historyDiscardStrengthObject(), 0.01f, 0.f, 50.f, "%.1f");
ImGui::DragFloat("Firefly Threshold", &fireflyThresholdObject(), 0.01f, 1.f, 5000.f, "%.1f");
ImGui::DragFloat("Roughness Clamp", &roughnessClampObject(), 0.001f, 0.f, 1.f, "%.3f");

Expand Down
1 change: 1 addition & 0 deletions src/dxvk/rtx_render/rtx_restir_gi_rayquery.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ namespace dxvk {
RTX_OPTION("rtx.restirGI", float, sampleStealingJitter, 0.0, "Jitter samples by k pixels to avoid aliasing.");
RTX_OPTION("rtx.restirGI", bool, stealBoundaryPixelSamplesWhenOutsideOfScreen , true, "Steals ReSTIR GI samples even a hit point is outside the screen. This will further improve highly specular samples at the cost of some bias.");
RTX_OPTION("rtx.restirGI", bool, useDiscardEnlargedPixels, true, "Discards enlarged samples when the camera is moving towards an object.");
RTX_OPTION("rtx.restirGI", float, historyDiscardStrength, 0.0, "The sensitivity of discarding history. Higher values discard more history.");
RTX_OPTION("rtx.restirGI", bool, useTemporalJacobian, true, "Calculates Jacobian determinant in temporal reprojection.");
RW_RTX_OPTION("rtx.restirGI", bool, useReflectionReprojection, true, "Uses reflection reprojection for reflective objects to achieve stable result when the camera is moving.");
RTX_OPTION("rtx.restirGI", float, reflectionMinParallax, 3.0, "When the parallax between normal and reflection reprojection is greater than this threshold, randomly choose one reprojected position and reuse the sample on it. Otherwise, get a sample between the two positions.");
Expand Down
1 change: 1 addition & 0 deletions src/dxvk/shaders/rtx/pass/raytrace_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ struct RaytraceArgs {
uint32_t permutationSamplingSize;
uint enableReSTIRGITemporalBiasCorrection;
uint enableReSTIRGIDiscardEnlargedPixels;
float reSTIRGIHistoryDiscardStrength;
uint enableReSTIRGITemporalJacobian;
float reSTIRGIFireflyThreshold;
float reSTIRGIRoughnessClamp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@
#include "rtx/algorithm/resolve.slangh"
#include "rtx/algorithm/rtxdi/rtxdi.slangh"

float getTemporalSearchRadius(RAB_Surface surface, float3 virtualMotionVector, float reprojectionDistance)
{
if (cb.teleportationPortalIndex != 0 || surface.portalSpace == PORTAL_SPACE_PORTAL_COMBINED)
return 0;

// Increasing search radius can reduce sample coherency and make ReSTIR GI's output more friendly to denoisers.
// If the camera is quickly approaching a surface, increase the search radius to avoid artifacts
const float maxRadius = 30;
const float3 cameraPositon = cameraGetWorldPosition(cb.camera);
const float3 prevVirtualWorldPosition = surface.virtualWorldPosition + virtualMotionVector;
float prevDistance = length(cameraGetPreviousWorldPosition(cb.camera).xyz - prevVirtualWorldPosition);
float currDistance = length(cameraPositon - surface.virtualWorldPosition);
float radius = (prevDistance / currDistance - 1) * cb.reSTIRGIHistoryDiscardStrength * maxRadius;

// If reflection reprojection is applied, also increase the search radius to avoid artifacts
if (reprojectionDistance > 2)
{
radius += 20;
}

// Ignore small search radius to avoid low frequency artifacts
return radius < 10 ? 0 : min(radius, maxRadius);
}

void validateSample(ReSTIRGI_Reservoir inputReservoir, inout ReSTIRGI_Reservoir resultReservoir, inout bool isGBufferSimilar)
{
// Calculate screen space position of the sample point
Expand Down Expand Up @@ -177,14 +201,14 @@ void main(int2 thread_id : SV_DispatchThreadID)
vec4 prevNDC = mul(camera.prevWorldToProjection, vec4(prevVirtualWorldPosition, 1.0f));
prevNDC.xyz /= prevNDC.w;
float2 prevPixelCenter = cameraNDCToScreenUV(camera, prevNDC.xy) * vec2(camera.resolution);
float expectedPrevHitDistance = length(prevVirtualWorldPosition - cameraGetPreviousWorldPosition(camera).xyz);
float2 expectedPrevHitDistance = length(prevVirtualWorldPosition - cameraGetPreviousWorldPosition(camera).xyz);
float viewDirectionDotTriangleNormal = abs(dot(surface.minimalRayInteraction.viewDirection, surface.minimalSurfaceInteraction.triangleNormal));
float2 depthThreshold = 0.01 / max(viewDirectionDotTriangleNormal, 0.01);
float normalThreshold = lerp(0.995, 0.5, surface.opaqueSurfaceMaterialInteraction.isotropicRoughness);

// Calculate reflection sample weight
float reflectionWeight = 0;
if (cb.enableReSTIRGIReflectionReprojection > 0)
if (cb.enableReSTIRGIReflectionReprojection > 0 && surface.portalSpace == PORTAL_SPACE_NONE)
{
float roughWeight = 1.0 - surface.opaqueSurfaceMaterialInteraction.isotropicRoughness;
const float16_t normalDotOutputDirection = dot(surface.opaqueSurfaceMaterialInteraction.shadingNormal, surface.minimalRayInteraction.viewDirection);
Expand All @@ -200,11 +224,13 @@ void main(int2 thread_id : SV_DispatchThreadID)
float2 currentPixelCenter = thread_id + 0.5;
bool discardEnlargedPixels = cb.enableReSTIRGIDiscardEnlargedPixels;
float reflectionReprojectionWeight = 0;
float reprojectionDistance = 0;
if (cb.enableReSTIRGIReflectionReprojection > 0 && reflectionWeight > 0.05 && length(prevPixelCenter - currentPixelCenter) > 1)
{
float3 worldPos = surface.minimalSurfaceInteraction.position;
float3 viewVector = -surface.minimalRayInteraction.viewDirection;
float3 reflectionVector = reflect(viewVector, surface.opaqueSurfaceMaterialInteraction.shadingNormal);
f16vec3 surfaceNormal = surface.opaqueSurfaceMaterialInteraction.shadingNormal;
float3 reflectionVector = reflect(viewVector, surfaceNormal);

// Calculate reflection hit T
uint8_t backupPortalID = RTXDI_INVALID_PORTAL_INDEX;
Expand All @@ -224,7 +250,7 @@ void main(int2 thread_id : SV_DispatchThreadID)

// Intepolate reprojection position, or randomly choose one type of reprojection
const float interpolateDistance = cb.restirGIReflectionMinParallax;
float reprojectionDistance = length(prevReflectionCenter - prevPixelCenter);
reprojectionDistance = length(prevReflectionCenter - prevPixelCenter);
reflectionReprojectionWeight = reflectionWeight;
prevBackupPixelCenter = prevPixelCenter; // Set original reprojection position as a backup position
if (reprojectionDistance < interpolateDistance)
Expand All @@ -238,13 +264,27 @@ void main(int2 thread_id : SV_DispatchThreadID)
// To get significant parallax, the camera must be close to reflective surface.
// Losen depth check and don't discard enlarged pixels to get stable reprojected result.
prevPixelCenter = prevReflectionCenter;
depthThreshold.x = 0.25;
float3 prevCameraPosition = cameraGetPreviousWorldPosition(cb.camera);
float3 prevCameraVector = normalize(reflectPosition - prevCameraPosition);
float prevReflectionHitDistance = dot(worldPos - prevCameraPosition, surfaceNormal) / dot(prevCameraVector, surfaceNormal);
if (prevReflectionHitDistance > 0)
{
expectedPrevHitDistance.x = prevReflectionHitDistance;
}
discardEnlargedPixels = false;
}
}

float initialSearchRadius = getTemporalSearchRadius(surface, virtualMotionVector, reprojectionDistance);
int temporalHistoryLength = cb.temporalHistoryLength;
if (initialSearchRadius > 0)
{
depthThreshold.x = 0.28;
normalThreshold.x = 0.8;
}

ReSTIRGI_TemporalResamplingParameters tparams = {};
tparams.temporalHistoryLength = cb.temporalHistoryLength;
tparams.temporalHistoryLength = temporalHistoryLength;
tparams.enableJacobian = cb.enableReSTIRGITemporalJacobian;
tparams.enableBiasCorrection = cb.enableReSTIRGITemporalBiasCorrection;
tparams.prevPixelCenter = prevPixelCenter;
Expand All @@ -260,7 +300,7 @@ void main(int2 thread_id : SV_DispatchThreadID)
tparams.teleportationPortalIndex = cb.teleportationPortalIndex;
tparams.discardEnlargedPixels = discardEnlargedPixels;
tparams.sourceBufferIndex = ReSTIRGI_GetTemporalInputPage();
tparams.initialSearchRadius = 0;
tparams.initialSearchRadius = initialSearchRadius;

bool isGBufferSimilar;
bool isInitialSample;
Expand Down
2 changes: 1 addition & 1 deletion submodules/rtxdi

0 comments on commit 3afe623

Please sign in to comment.