Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples_tests
60 changes: 31 additions & 29 deletions include/nbl/builtin/hlsl/rwmc/resolve.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <nbl/builtin/hlsl/colorspace/encodeCIEXYZ.hlsl>
#include <nbl/builtin/hlsl/rwmc/ResolveParameters.hlsl>
#include <nbl/builtin/hlsl/concepts/accessors/loadable_image.hlsl>
#include <nbl/builtin/hlsl/colorspace.hlsl>
#include <nbl/builtin/hlsl/vector_utils/vector_traits.hlsl>

namespace nbl
{
Expand All @@ -19,23 +21,21 @@ namespace rwmc
// not the greatest syntax but works
#define NBL_CONCEPT_PARAM_0 (a,T)
#define NBL_CONCEPT_PARAM_1 (scalar,VectorScalarType)
#define NBL_CONCEPT_PARAM_2 (vec,vector<VectorScalarType, Dims>)
// start concept
NBL_CONCEPT_BEGIN(2)
// need to be defined AFTER the concept begins
#define a NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
#define scalar NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
#define vec NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
NBL_CONCEPT_END(
((NBL_CONCEPT_REQ_EXPR)((a.calcLuma(vec))))
((NBL_CONCEPT_REQ_EXPR)((a.calcLuma(vector<VectorScalarType, 3>(scalar, scalar, scalar)))))
);
#undef a
#undef vec
#undef scalar
#include <nbl/builtin/hlsl/concepts/__end.hlsl>

/* ResolveAccessor is required to:
* - satisfy `LoadableImage` concept requirements
* - implement function called `calcLuma` which calculates luma from a pixel value
* - implement function called `calcLuma` which calculates luma from a 3 component pixel value
*/

template<typename T, typename VectorScalarType, int32_t Dims>
Expand All @@ -50,9 +50,9 @@ struct ResolveAccessorAdaptor

RWTexture2DArray<float32_t4> cascade;

float32_t calcLuma(in float32_t3 col)
float32_t calcLuma(NBL_REF_ARG(float32_t3) col)
{
return hlsl::dot<float32_t3>(hlsl::transpose(colorspace::scRGBtoXYZ)[1], col);
return hlsl::dot<float32_t3>(colorspace::scRGB::ToXYZ()[1], col);
}

template<typename OutputScalarType, int32_t Dimension>
Expand All @@ -69,10 +69,11 @@ struct ResolveAccessorAdaptor
}
};

template<typename CascadeAccessor, typename OutputColorType> //NBL_PRIMARY_REQUIRES(ResolveAccessor<CascadeAccessor, typename CascadeAccessor::output_scalar_type, CascadeAccessor::image_dimension>)
template<typename CascadeAccessor, typename OutputColorTypeVec NBL_PRIMARY_REQUIRES(concepts::Vector<OutputColorTypeVec> && ResolveAccessor<CascadeAccessor, typename CascadeAccessor::output_scalar_type, CascadeAccessor::image_dimension>)
struct Resolver
{
using output_type = OutputColorType;
using output_type = OutputColorTypeVec;
using scalar_t = typename vector_traits<output_type>::scalar_type;

struct CascadeSample
{
Expand All @@ -91,13 +92,15 @@ struct Resolver

output_type operator()(NBL_REF_ARG(CascadeAccessor) acc, const int16_t2 coord)
{
float reciprocalBaseI = 1.f;
using scalar_t = typename vector_traits<output_type>::scalar_type;

scalar_t reciprocalBaseI = 1.f;
CascadeSample curr = __sampleCascade(acc, coord, 0u, reciprocalBaseI);

float32_t3 accumulation = float32_t3(0.0f, 0.0f, 0.0f);
float Emin = params.initialEmin;
output_type accumulation = output_type(0.0f, 0.0f, 0.0f);
scalar_t Emin = params.initialEmin;

float prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
scalar_t prevNormalizedCenterLuma, prevNormalizedNeighbourhoodAverageLuma;
for (int16_t i = 0u; i <= params.lastCascadeIndex; i++)
{
const bool notFirstCascade = i != 0;
Expand All @@ -110,13 +113,13 @@ struct Resolver
next = __sampleCascade(acc, coord, int16_t(i + 1), reciprocalBaseI);
}

float reliability = 1.f;
scalar_t reliability = 1.f;
// sample counting-based reliability estimation
if (params.reciprocalKappa <= 1.f)
{
float localReliability = curr.normalizedCenterLuma;
scalar_t localReliability = curr.normalizedCenterLuma;
// reliability in 3x3 pixel block (see robustness)
float globalReliability = curr.normalizedNeighbourhoodAverageLuma;
scalar_t globalReliability = curr.normalizedNeighbourhoodAverageLuma;
if (notFirstCascade)
{
localReliability += prevNormalizedCenterLuma;
Expand All @@ -130,11 +133,11 @@ struct Resolver
// check if above minimum sampling threshold (avg 9 sample occurences in 3x3 neighbourhood), then use per-pixel reliability (NOTE: tertiary op is in reverse)
reliability = globalReliability < params.reciprocalN ? globalReliability : localReliability;
{
const float accumLuma = acc.calcLuma(accumulation);
const scalar_t accumLuma = acc.calcLuma(accumulation);
if (accumLuma > Emin)
Emin = accumLuma;

const float colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;
const scalar_t colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;

reliability += colorReliability;
reliability *= params.NOverKappa;
Expand All @@ -156,19 +159,18 @@ struct Resolver

// pseudo private stuff:

CascadeSample __sampleCascade(NBL_REF_ARG(CascadeAccessor) acc, int16_t2 coord, uint16_t cascadeIndex, float reciprocalBaseI)
CascadeSample __sampleCascade(NBL_REF_ARG(CascadeAccessor) acc, int16_t2 coord, uint16_t cascadeIndex, scalar_t reciprocalBaseI)
{
typename CascadeAccessor::output_type tmp;
output_type neighbourhood[9];
neighbourhood[0] = acc.template get<float, 2>(coord + int16_t2(-1, -1), cascadeIndex);
neighbourhood[1] = acc.template get<float, 2>(coord + int16_t2(0, -1), cascadeIndex);
neighbourhood[2] = acc.template get<float, 2>(coord + int16_t2(1, -1), cascadeIndex);
neighbourhood[3] = acc.template get<float, 2>(coord + int16_t2(-1, 0), cascadeIndex);
neighbourhood[4] = acc.template get<float, 2>(coord + int16_t2(0, 0), cascadeIndex);
neighbourhood[5] = acc.template get<float, 2>(coord + int16_t2(1, 0), cascadeIndex);
neighbourhood[6] = acc.template get<float, 2>(coord + int16_t2(-1, 1), cascadeIndex);
neighbourhood[7] = acc.template get<float, 2>(coord + int16_t2(0, 1), cascadeIndex);
neighbourhood[8] = acc.template get<float, 2>(coord + int16_t2(1, 1), cascadeIndex);
neighbourhood[0] = acc.template get<scalar_t, 2>(coord + int16_t2(-1, -1), cascadeIndex).xyz;
neighbourhood[1] = acc.template get<scalar_t, 2>(coord + int16_t2(0, -1), cascadeIndex).xyz;
neighbourhood[2] = acc.template get<scalar_t, 2>(coord + int16_t2(1, -1), cascadeIndex).xyz;
neighbourhood[3] = acc.template get<scalar_t, 2>(coord + int16_t2(-1, 0), cascadeIndex).xyz;
neighbourhood[4] = acc.template get<scalar_t, 2>(coord + int16_t2(0, 0), cascadeIndex).xyz;
neighbourhood[5] = acc.template get<scalar_t, 2>(coord + int16_t2(1, 0), cascadeIndex).xyz;
neighbourhood[6] = acc.template get<scalar_t, 2>(coord + int16_t2(-1, 1), cascadeIndex).xyz;
neighbourhood[7] = acc.template get<scalar_t, 2>(coord + int16_t2(0, 1), cascadeIndex).xyz;
neighbourhood[8] = acc.template get<scalar_t, 2>(coord + int16_t2(1, 1), cascadeIndex).xyz;

// numerical robustness
float32_t3 excl_hood_sum = ((neighbourhood[0] + neighbourhood[1]) + (neighbourhood[2] + neighbourhood[3])) +
Expand Down
Loading