Skip to content

Commit

Permalink
Replace function pointers with branching
Browse files Browse the repository at this point in the history
  • Loading branch information
KYovchevski committed Jul 15, 2022
1 parent e3c2e89 commit d364550
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 70 deletions.
4 changes: 2 additions & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn compile_bindings() {
Config::new()
.file("src/ispc/kernels/lanczos3.ispc")
.opt_level(2)

.woff()
.target_isas(vec![
TargetISA::SSE2i32x4,
Expand All @@ -21,8 +22,7 @@ fn compile_bindings() {
.bindgen_options(BindgenOptions {
allowlist_functions: vec![
"resample".into(),
"resample_with_cache_3".into(),
"resample_with_cache_4".into(),
"resample_with_cache".into(),
"calculate_weights".into(),
"calculate_weight_variables".into(),
],
Expand Down
15 changes: 2 additions & 13 deletions src/ispc/downsample_ispc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,12 @@ extern "C" {
);
}
extern "C" {
pub fn resample_with_cache_3(
src_width: u32,
src_height: u32,
target_width: u32,
target_height: u32,
cache: *const Cache,
scratch_space: *mut u8,
src_data: *const u8,
out_data: *mut u8,
);
}
extern "C" {
pub fn resample_with_cache_4(
pub fn resample_with_cache(
src_width: u32,
src_height: u32,
target_width: u32,
target_height: u32,
num_channels: u8,
cache: *const Cache,
scratch_space: *mut u8,
src_data: *const u8,
Expand Down
55 changes: 23 additions & 32 deletions src/ispc/kernels/lanczos3.ispc
Original file line number Diff line number Diff line change
Expand Up @@ -222,22 +222,9 @@ void clean_and_write_4_channels(varying float<4> color, varying uint64 write_add
}

/// scratch_space must be at least src_height * target_width pixels big
void resample_with_cache(uniform uint src_width, uniform uint src_height, uniform uint target_width, uniform uint target_height, uniform uint8 num_channels,
export void resample_with_cache(uniform uint src_width, uniform uint src_height, uniform uint target_width, uniform uint target_height, uniform uint8 num_channels,
uniform const Cache * uniform cache, uniform uint8 scratch_space[], uniform const uint8 src_data[], uniform uint8 out_data[]) {

// Quick way to swap between sampling for 3 channels and sampling for 4 channels.
// TODO[#Kamen]: This is slow because of the function pointers. Branching yields the same results, but is more difficult to maintain.
// Ideally, we should split this function into two versions depending channel count, and branch only once in Rust than twice per sample.
varying uint8<4> (*read_fn)(varying uint64, const uniform uint8*);
void (*write_fn)(varying float<4>, varying uint64, uniform uint8*);

if (num_channels == 3) {
read_fn = sample_3_channels;
write_fn = clean_and_write_3_channels;
} else {
read_fn = sample_4_channels;
write_fn = clean_and_write_4_channels;
}
// TODO[#Kamen]: Ideally, we should split this function into two versions depending channel count, and branch only once in Rust than twice per sample.

uniform WeightCollection * uniform vertical_weight_collection = &cache->vertical_weights;
uniform WeightCollection * uniform horizontal_weight_collection = &cache->horizontal_weights;
Expand All @@ -255,12 +242,20 @@ void resample_with_cache(uniform uint src_width, uniform uint src_height, unifor
uint src_x = src_width_start + i;
uint64 src_read_address = (y * src_width + src_x) * num_channels;

color += read_fn(src_read_address, src_data) * weight;
if (num_channels == 3) {
color += sample_3_channels(src_read_address, src_data) * weight;
} else {
color += sample_4_channels(src_read_address, src_data) * weight;
}
}

uint64 scratch_write_address = (y * target_width + x) * num_channels;

write_fn(color, scratch_write_address, scratch_space);
if (num_channels == 3) {
clean_and_write_3_channels(color, scratch_write_address, scratch_space);
} else {
clean_and_write_4_channels(color, scratch_write_address, scratch_space);
}
}

// Accumulate the scratch space data along the height
Expand All @@ -280,23 +275,19 @@ void resample_with_cache(uniform uint src_width, uniform uint src_height, unifor
uniform uint8<3>* varying scratch_pixel_ptr = (uniform uint8<3>* varying)(scratch_space + scratch_read_address);
uint8<3> scratch_color = *scratch_pixel_ptr;

color += read_fn(scratch_read_address, scratch_space) * weight;
if (num_channels == 3) {
color += sample_3_channels(scratch_read_address, scratch_space) * weight;
} else {
color += sample_4_channels(scratch_read_address, scratch_space) * weight;
}
}

uint64 out_write_address = (y * target_height + x) * num_channels;
write_fn(color, out_write_address, out_data);
}
}

export void resample_with_cache_3(uniform uint src_width, uniform uint src_height, uniform uint target_width, uniform uint target_height,
uniform const Cache * uniform cache, uniform uint8 scratch_space[], uniform const uint8 src_data[], uniform uint8 out_data[]) {

resample_with_cache(src_width, src_height, target_width, target_height, 3, cache, scratch_space, src_data, out_data);
}


export void resample_with_cache_4(uniform uint src_width, uniform uint src_height, uniform uint target_width, uniform uint target_height,
uniform const Cache * uniform cache, uniform uint8 scratch_space[], uniform const uint8 src_data[], uniform uint8 out_data[]) {

resample_with_cache(src_width, src_height, target_width, target_height, 4, cache, scratch_space, src_data, out_data);
if (num_channels == 3) {
clean_and_write_3_channels(color, out_write_address, out_data);
} else {
clean_and_write_4_channels(color, out_write_address, out_data);
}
}
}
48 changes: 25 additions & 23 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,29 +178,31 @@ pub fn downsample(src: &Image, target_width: u32, target_height: u32) -> Vec<u8>
horizontal_weights: height_weights.ispc_representation(),
};
unsafe {
if src.format.num_channels() == 3 {
ispc::downsample_ispc::resample_with_cache_3(
src.width,
src.height,
target_width,
target_height,
&weight_cache as *const _,
scratch_space.as_mut_ptr(),
src.pixels.as_ptr(),
output.as_mut_ptr(),
);
} else {
ispc::downsample_ispc::resample_with_cache_4(
src.width,
src.height,
target_width,
target_height,
&weight_cache as *const _,
scratch_space.as_mut_ptr(),
src.pixels.as_ptr(),
output.as_mut_ptr(),
);
}
ispc::downsample_ispc::resample_with_cache(
src.width,
src.height,
target_width,
target_height,
src.format.num_channels(),
&weight_cache as *const _,
scratch_space.as_mut_ptr(),
src.pixels.as_ptr(),
output.as_mut_ptr(),
);
// if src.format.num_channels() == 3 {

// } else {
// ispc::downsample_ispc::resample_with_cache_4(
// src.width,
// src.height,
// target_width,
// target_height,
// &weight_cache as *const _,
// scratch_space.as_mut_ptr(),
// src.pixels.as_ptr(),
// output.as_mut_ptr(),
// );
// }
}

output
Expand Down

0 comments on commit d364550

Please sign in to comment.