Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 117 additions & 3 deletions docs/examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ begin
viewdir = normalize(ax.scene.camera.view_direction[])
end

hitpoints, centroid = RayCaster.get_centroid(bvh, viewdir)


begin
@time "hitpoints" hitpoints, centroid = RayCaster.get_centroid(bvh, viewdir)
@time "illum" illum = RayCaster.get_illumination(bvh, viewdir)
Expand All @@ -33,10 +36,11 @@ begin
f, ax, pl = mesh(world_mesh, color=:blue)
per_face_vf = FaceView((viewfacts), [GLTriangleFace(i) for i in 1:N])
viewfact_mesh = GeometryBasics.mesh(world_mesh, color=per_face_vf)
pl = Makie.mesh(f[1, 2],
pl = Makie.mesh(
f[1, 2],
viewfact_mesh, colormap=[:black, :red], axis=(; show_axis=false),
shading=false, highclip=:red, lowclip=:black, colorscale=sqrt,)

shading=false, highclip=:red, lowclip=:black, colorscale=sqrt,
)
# Centroid
cax, pl = Makie.mesh(f[2, 1], world_mesh, color=(:blue, 0.5), axis=(; show_axis=false), transparency=true)

Expand All @@ -58,3 +62,113 @@ begin

f
end


using KernelAbstractions, Atomix

function random_scatter_kernel!(bvh, triangle, u, v, normal)
point = RayCaster.random_triangle_point(triangle)
o = point .+ (normal .* 0.01f0) # Offset so it doesn't self intersect
dir = RayCaster.random_hemisphere_uniform(normal, u, v)
ray = RayCaster.Ray(; o=o, d=dir)
hit, prim, _ = RayCaster.intersect!(bvh, ray)
return hit, prim
end

import GeometryBasics as GB

@kernel function viewfact_ka_kernel!(result, bvh, primitives, rays_per_triangle)
idx = @index(Global)
prim_idx = ((UInt32(idx) - UInt32(1)) ÷ rays_per_triangle) + UInt32(1)
if prim_idx <= length(primitives)
triangle, u, v, normal = primitives[prim_idx]
hit, prim = random_scatter_kernel!(bvh, triangle, u, v, normal)
if hit && prim.material_idx !== triangle.material_idx
# weigh by angle?
Atomix.@atomic result[triangle.material_idx, prim.material_idx] += 1
end
end
end

function view_factors!(result, bvh, prim_info, rays_per_triangle=10000)

backend = get_backend(result)
workgroup = 256
total_rays = length(bvh.primitives) * rays_per_triangle
per_workgroup = total_rays ÷ workgroup
final_rays = per_workgroup * workgroup
per_triangle = final_rays ÷ length(bvh.primitives)

kernel = viewfact_ka_kernel!(backend, 256)
kernel(result, bvh, prim_info, UInt32(per_triangle); ndrange = final_rays)
return result
end

result = zeros(UInt32, length(bvh.primitives), length(bvh.primitives))
using AMDGPU
prim_info = map(bvh.primitives) do triangle
n = GB.orthogonal_vector(Vec3f, GB.Triangle(triangle.vertices...))
normal = normalize(Vec3f(n))
u, v = RayCaster.get_orthogonal_basis(normal)
return triangle, u, v, normal
end
bvh_gpu = RayCaster.to_gpu(ROCArray, bvh)
result_gpu = ROCArray(result)
prim_info_gpu = ROCArray(prim_info)
@time begin
view_factors!(result_gpu, bvh_gpu, prim_info_gpu, 10000)
KernelAbstractions.synchronize(get_backend(result_gpu))
end;



@kernel function viewfact_ka_kernel2!(result, bvh, primitives, rays_per_triangle)
idx = @index(Global)
prim_idx = ((UInt32(idx) - UInt32(1)) ÷ rays_per_triangle) + UInt32(1)
if prim_idx <= length(primitives)
triangle, u, v, normal = primitives[prim_idx]
hit, prim = random_scatter_kernel!(bvh, triangle, u, v, normal)
if hit && prim.material_idx !== triangle.material_idx
# weigh by angle?
@inbounds result[idx] = UInt32(1)
end
end
end


function view_factors2!(result, bvh, prim_info, per_triangle)
backend = get_backend(result)
kernel = viewfact_ka_kernel2!(backend, 256)
kernel(result, bvh, prim_info, UInt32(per_triangle); ndrange = length(result))
return result
end


using AMDGPU
workgroup = 256
rays_per_triangle = 10000
total_rays = length(bvh.primitives) * rays_per_triangle
per_workgroup = total_rays ÷ workgroup
final_rays = per_workgroup * workgroup
per_triangle = final_rays ÷ length(bvh.primitives)
result = zeros(UInt32, final_rays)

final_rays / 10^6

prim_info = map(bvh.primitives) do triangle
n = GB.orthogonal_vector(Vec3f, GB.Triangle(triangle.vertices...))
normal = normalize(Vec3f(n))
u, v = RayCaster.get_orthogonal_basis(normal)
return triangle, u, v, normal
end

bvh_gpu = RayCaster.to_gpu(ROCArray, bvh)
result_gpu = ROCArray(result)
prim_info_gpu = ROCArray(prim_info)
@time begin
view_factors2!(result_gpu, bvh_gpu, prim_info_gpu, per_triangle)
KernelAbstractions.synchronize(get_backend(result_gpu))
end;

@time view_factors2!(result, bvh, prim_info, per_triangle)
@code_warntype random_scatter_kernel!(bvh, prim_info[1]...)
187 changes: 119 additions & 68 deletions src/bvh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,100 +239,151 @@ end
length(bvh.nodes) > Int32(0) ? bvh.nodes[1].bounds : Bounds3()
end

@inline function intersect!(bvh::BVHAccel{P}, ray::AbstractRay) where {P}
hit = false
interaction = SurfaceInteraction()
"""
_traverse_bvh(bvh::BVHAccel{P}, ray::AbstractRay, hit_callback::F) where {P, F<:Function}

Internal function that traverses the BVH to find ray-primitive intersections.
Uses a callback pattern to handle different intersection behaviors.

Arguments:
- `bvh`: The BVH acceleration structure
- `ray`: The ray to test for intersections
- `hit_callback`: Function called when primitive is tested. Signature:
hit_callback(primitive, ray) -> (continue_traversal::Bool, ray::AbstractRay, results::Any)

Returns:
- The final result from the hit_callback
"""
@inline function traverse_bvh(hit_callback::F, bvh::BVHAccel{P}, ray::AbstractRay) where {P, F<:Function}
# Early return if BVH is empty
if length(bvh.nodes) == 0
return false, ray, nothing
end

# Prepare ray for traversal
ray = check_direction(ray)
inv_dir = 1f0 ./ ray.d
dir_is_neg = is_dir_negative(ray.d)

to_visit_offset, current_node_i = Int32(1), Int32(1)
# Initialize traversal stack
to_visit_offset = Int32(1)
current_node_idx = Int32(1)
nodes_to_visit = zeros(MVector{64,Int32})
primitives = bvh.primitives
@_inbounds primitive = primitives[1]
nodes = bvh.nodes

# State variables to hold callback results
continue_search = true
prim1 = primitives[1]
result = hit_callback(prim1, ray, nothing)

# Traverse BVH
@_inbounds while true
ln = nodes[current_node_i]
if intersect_p(ln.bounds, ray, inv_dir, dir_is_neg)
if !ln.is_interior && ln.n_primitives > Int32(0)
# Intersect ray with primitives in node.
for i in Int32(0):ln.n_primitives - Int32(1)
offset = ln.offset % Int32
tmp_primitive = primitives[offset+i]
tmp_hit, ray, tmp_interaction = intersect_p!(
tmp_primitive, ray,
)
if tmp_hit
hit = tmp_hit
interaction = tmp_interaction
primitive = tmp_primitive
current_node = nodes[current_node_idx]

# Test ray against current node's bounding box
if intersect_p(current_node.bounds, ray, inv_dir, dir_is_neg)
if !current_node.is_interior && current_node.n_primitives > Int32(0)
# Leaf node - test all primitives
offset = current_node.offset % Int32

for i in Int32(0):(current_node.n_primitives - Int32(1))
primitive = primitives[offset + i]

# Call the callback for this primitive
continue_search, ray, result = hit_callback(primitive, ray, result)

# Early exit if callback requests it
if !continue_search
return false, ray, result
end
end
to_visit_offset == Int32(1) && break

# Done with leaf, pop next node from stack
if to_visit_offset == Int32(1)
break
end
to_visit_offset -= Int32(1)
current_node_i = nodes_to_visit[to_visit_offset]
current_node_idx = nodes_to_visit[to_visit_offset]
else
if dir_is_neg[ln.split_axis] == Int32(2)
nodes_to_visit[to_visit_offset] = current_node_i + Int32(1)
current_node_i = ln.offset % Int32
# Interior node - push children to stack
if dir_is_neg[current_node.split_axis] == Int32(2)
nodes_to_visit[to_visit_offset] = current_node_idx + Int32(1)
current_node_idx = current_node.offset % Int32
else
nodes_to_visit[to_visit_offset] = ln.offset % Int32
current_node_i += Int32(1)
nodes_to_visit[to_visit_offset] = current_node.offset % Int32
current_node_idx += Int32(1)
end
to_visit_offset += Int32(1)
end
else
to_visit_offset == 1 && break
# Miss - pop next node from stack
if to_visit_offset == Int32(1)
break
end
to_visit_offset -= Int32(1)
current_node_i = nodes_to_visit[to_visit_offset]
current_node_idx = nodes_to_visit[to_visit_offset]
end
end
return hit, primitive, interaction

# Return final state
return continue_search, ray, result
end

@inline function intersect_p(bvh::BVHAccel, ray::AbstractRay)
# Initialization
closest_hit_callback(primitive, ray, ::Nothing) = (false, primitive, Point3f(0.0))

length(bvh.nodes) == Int32(0) && return false
function closest_hit_callback(primitive, ray, prev_result::Tuple{Bool, P, Point3f}) where {P}
# Test intersection and update if closer
tmp_hit, ray, tmp_bary = intersect_p!(primitive, ray)
# Always continue search to find closest
return true, ray, ifelse(tmp_hit, (true, primitive, tmp_bary), prev_result)
end

ray = check_direction(ray)
inv_dir = 1f0 ./ ray.d
dir_is_neg = is_dir_negative(ray.d)
"""
intersect!(bvh::BVHAccel{P}, ray::AbstractRay) where {P}

to_visit_offset, current_node_i = Int32(1), Int32(1)
nodes_to_visit = zeros(MVector{64,Int32})
primitives = bvh.primitives
@_inbounds while true
ln = bvh.nodes[current_node_i]
if intersect_p(ln.bounds, ray, inv_dir, dir_is_neg)
if !ln.is_interior && ln.n_primitives > Int32(0)
for i in Int32(0):ln.n_primitives-Int32(1)
offset = ln.offset % Int32
intersect_p(
primitives[offset + i], ray,
) && return true
end
to_visit_offset == 1 && break
to_visit_offset -= Int32(1)
current_node_i = nodes_to_visit[to_visit_offset]
else
if dir_is_neg[ln.split_axis] == Int32(2)
# @setindex 64 nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
nodes_to_visit[to_visit_offset] = current_node_i + Int32(1)
current_node_i = ln.offset % Int32
else
# @setindex 64 nodes_to_visit[to_visit_offset] = Int32(ln.offset)
nodes_to_visit[to_visit_offset] = ln.offset % Int32
current_node_i += Int32(1)
end
to_visit_offset += Int32(1)
end
else
to_visit_offset == Int32(1) && break
to_visit_offset -= Int32(1)
current_node_i = Int32(nodes_to_visit[to_visit_offset])
end
Find the closest intersection between a ray and the primitives stored in a BVH.

Returns:
- `hit_found`: Boolean indicating if an intersection was found
- `hit_primitive`: The primitive that was hit (if any)
- `barycentric_coords`: Barycentric coordinates of the hit point
"""
@inline function intersect!(bvh::BVHAccel{P}, ray::AbstractRay) where {P}
# Traverse BVH with closest-hit callback
_, _, result = traverse_bvh(closest_hit_callback, bvh, ray)
return result::Tuple{Bool, Triangle, Point3f}
end


any_hit_callback(primitive, current_ray, result::Nothing) = ()

# Define any-hit callback
function any_hit_callback(primitive, current_ray, ::Tuple{})
# Test for intersection
if intersect_p(primitive, current_ray)
# Stop traversal on first hit
return false, current_ray, true
end
false
# Continue search if no hit
return true, current_ray, false
end

"""
intersect_p(bvh::BVHAccel, ray::AbstractRay)

Test if a ray intersects any primitive in the BVH (without finding the closest hit).

Returns:
- `hit_found`: Boolean indicating if any intersection was found
"""
@inline function intersect_p(bvh::BVHAccel, ray::AbstractRay)
# Traverse BVH with any-hit callback
continue_search, _, result = traverse_bvh(any_hit_callback, bvh, ray)
# If traversal completed without finding a hit, return false
# Otherwise return the hit result (true)
return !continue_search ? result : false
end

function calculate_ray_grid_bounds(bounds::GeometryBasics.Rect, ray_direction::Vec3f)
Expand Down
3 changes: 1 addition & 2 deletions src/kernel-abstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ end
function to_gpu(ArrayType, bvh::RayCaster.BVHAccel; preserve=[])
primitives = to_gpu(ArrayType, bvh.primitives; preserve=preserve)
nodes = to_gpu(ArrayType, bvh.nodes; preserve=preserve)
materials = to_gpu(ArrayType, to_gpu.((ArrayType,), bvh.materials; preserve=preserve); preserve=preserve)
return RayCaster.BVHAccel(primitives, materials, bvh.max_node_primitives, nodes)
return RayCaster.BVHAccel(primitives, bvh.max_node_primitives, nodes)
end
7 changes: 4 additions & 3 deletions src/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ function hits_from_grid(bvh, viewdir; grid_size=32)
Threads.@threads for idx in CartesianIndices(ray_origins)
o = ray_origins[idx]
ray = RayCaster.Ray(; o=o, d=ray_direction)
hit, prim, si = RayCaster.intersect!(bvh, ray)
@inbounds result[idx] = RayHit(hit, si.core.p, prim.material_idx)
hit, prim, bary = RayCaster.intersect!(bvh, ray)
hitpoint = sum_mul(bary, prim.vertices)
@inbounds result[idx] = RayHit(hit, hitpoint, prim.material_idx)
end
return result
end
Expand All @@ -34,7 +35,7 @@ function view_factors!(result, bvh, rays_per_triangle=10000)
point_on_triangle = random_triangle_point(triangle)
o = point_on_triangle .+ (normal .* 0.01f0) # Offset so it doesn't self intersect
ray = Ray(; o=o, d=random_hemisphere_uniform(normal, u, v))
hit, prim, si = intersect!(bvh, ray)
hit, prim, _ = intersect!(bvh, ray)
if hit && prim.material_idx != triangle.material_idx
# weigh by angle?
result[triangle.material_idx, prim.material_idx] += 1
Expand Down
Loading