Skip to content

Commit

Permalink
Add init/read timing for Julia
Browse files Browse the repository at this point in the history
  • Loading branch information
tom91136 committed Oct 7, 2023
1 parent e7774c1 commit 3cb01e7
Showing 1 changed file with 50 additions and 19 deletions.
69 changes: 50 additions & 19 deletions src/julia/JuliaStream.jl/src/Stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ end

@enum Benchmark All Triad Nstream


function run_init_arrays!(data::StreamData{T,C}, context, init::Tuple{T,T,T})::Float64 where {T,C}
return @elapsed init_arrays!(data, context, init)
end

function run_read_data(data::StreamData{T,C}, context)::Tuple{Float64,VectorData{T}} where {T,C}
elapsed = @elapsed begin
result = read_data(data, context)
end
return (elapsed, result)
end

function run_all!(data::StreamData{T,C}, context, times::Int)::Tuple{Timings,T} where {T,C}
timings = Timings(times)
lastSum::T = 0
Expand All @@ -39,11 +51,7 @@ function run_triad!(data::StreamData{T,C}, context, times::Int)::Float64 where {
end
end

function run_nstream!(
data::StreamData{T,C},
context,
times::Int,
)::Vector{Float64} where {T,C}
function run_nstream!(data::StreamData{T,C}, context, times::Int)::Vector{Float64} where {T,C}
timings::Vector{Float64} = zeros(times)
for i = 1:times
@inbounds timings[i] = @elapsed nstream!(data, context)
Expand Down Expand Up @@ -93,9 +101,7 @@ function check_solutions(
error = abs((dot - gold_sum) / gold_sum)
failed = error > 1.0e-8
if failed
println(
"Validation failed on sum. Error $error \nSum was $dot but should be $gold_sum",
)
println("Validation failed on sum. Error $error \nSum was $dot but should be $gold_sum")
end
!failed
end : true
Expand Down Expand Up @@ -166,7 +172,7 @@ function main()
parse_options(config)

if config.list
for (i, (_,repr, impl)) in enumerate(devices())
for (i, (_, repr, impl)) in enumerate(devices())
println("[$i] ($impl) $repr")
end
exit(0)
Expand All @@ -175,9 +181,7 @@ function main()
ds = devices()
# TODO implement substring device match
if config.device < 1 || config.device > length(ds)
error(
"Device $(config.device) out of range (1..$(length(ds))), NOTE: Julia is 1-indexed",
)
error("Device $(config.device) out of range (1..$(length(ds))), NOTE: Julia is 1-indexed")
else
device = ds[config.device]
end
Expand Down Expand Up @@ -257,16 +261,42 @@ function main()
end
end

function show_init(init::Float64, read::Float64)
setup = [("Init", init, 3 * array_bytes), ("Read", read, 3 * array_bytes)]
if config.csv
tabulate(
map(
x -> [
("phase", x[1]),
("n_elements", config.arraysize),
("sizeof", x[3]),
("max_m$(config.mibibytes ? "i" : "")bytes_per_sec", mega_scale * total_bytes / x[2]),
("runtime", x[2]),
],
setup,
)...,
)
else
for (name, elapsed, total_bytes) in setup
println(
"$name: $(round(elapsed; digits=5)) s (=$(round(( mega_scale * total_bytes) / elapsed; digits = 5)) M$(config.mibibytes ? "i" : "")Bytes/sec)",
)
end
end
end

init::Tuple{type,type,type} = DefaultInit
scalar::type = DefaultScalar

GC.enable(false)

(data, context) = make_stream(config.arraysize, scalar, device, config.csv)
init_arrays!(data, context, init)
tInit = run_init_arrays!(data, context, init)
if benchmark == All
(timings, sum) = run_all!(data, context, config.numtimes)
valid = check_solutions(read_data(data, context), config.numtimes, init, benchmark, sum)
(tRead, result) = run_read_data(data, context)
show_init(tInit, tRead)
valid = check_solutions(result, config.numtimes, init, benchmark, sum)
tabulate(
mk_row(timings.copy, "Copy", 2 * array_bytes),
mk_row(timings.mul, "Mul", 2 * array_bytes),
Expand All @@ -276,21 +306,22 @@ function main()
)
elseif benchmark == Nstream
timings = run_nstream!(data, context, config.numtimes)
valid =
check_solutions(read_data(data, context), config.numtimes, init, benchmark, nothing)
(tRead, result) = run_read_data(data, context)
show_init(tInit, tRead)
valid = check_solutions(result, config.numtimes, init, benchmark, nothing)
tabulate(mk_row(timings, "Nstream", 4 * array_bytes))
elseif benchmark == Triad
elapsed = run_triad!(data, context, config.numtimes)
valid =
check_solutions(read_data(data, context), config.numtimes, init, benchmark, nothing)
(tRead, result) = run_read_data(data, context)
show_init(tInit, tRead)
valid = check_solutions(result, config.numtimes, init, benchmark, nothing)
total_bytes = 3 * array_bytes * config.numtimes
bandwidth = mega_scale * (total_bytes / elapsed)
println("Runtime (seconds): $(round(elapsed; digits=5))")
println("Bandwidth ($giga_suffix/s): $(round(bandwidth; digits=3)) ")
else
error("Bad benchmark $(benchmark)")
end

GC.enable(true)

if !valid
Expand Down

0 comments on commit 3cb01e7

Please sign in to comment.