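-- Provenance (header captured from the GitHub file page: Permalink | Find file)
-- Last commit: 661a57e, Dec 7, 2015
-- Contributors: @stepelu @ViralBShah @jiahao @jakebolewski
-- Original size: 197 lines (172 sloc), 4.62 KB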
-- The warning text recommends a 64-bit (BIT=64) build for OpenBLAS; emit it
-- (and keep running) whenever LuaJIT reports any architecture other than x64.
do
  local is_x64 = (jit.arch == 'x64')
  if not is_x64 then
    print('WARNING: please use BIT=64 for optimal OpenBLAS performance')
  end
end
local ffi = require 'ffi'
local bit = require 'bit'
local time = require 'time'
local alg = require 'sci.alg'
local prng = require 'sci.prng'
local stat = require 'sci.stat'
local dist = require 'sci.dist'
local complex = require 'sci.complex'
local min, sqrt, random, abs = math.min, math.sqrt, math.random, math.abs
local cabs = complex.abs
local rshift = bit.rshift
local format = string.format
local nowutc = time.nowutc
local rng = prng.std()
local vec, mat, join = alg.vec, alg.mat, alg.join
local sum, trace = alg.sum, alg.trace
local var, mean = stat.var, stat.mean
--------------------------------------------------------------------------------
--- Call f once and measure how long the call took, using time.nowutc().
-- @param f zero-argument function to time
-- @return duration of the call in milliseconds, followed by f's first two
--   return values (any further results are discarded)
local function elapsed(f)
  local start = nowutc()
  local r1, r2 = f()
  local span = nowutc() - start
  return span:tomilliseconds(), r1, r2
end
--- Benchmark f and print its best time as a CSV row "lua,<name>,<ms>".
-- f is called repeatedly until it has run more than 5 times AND at least
-- 2 seconds have passed since timing started; the minimum per-call time
-- (in milliseconds) is reported.
-- @param f zero-argument function to benchmark
-- @param name label written in the output row
-- @param check optional function(val1, val2), called after every run with
--   f's first two results (e.g. to assert correctness)
local function timeit(f, name, check)
  local best = 1/0
  local runs = 0
  local started = nowutc()
  repeat
    runs = runs + 1
    local ms, val1, val2 = elapsed(f)
    best = min(best, ms)
    if check then check(val1, val2) end
  until runs > 5 and (nowutc() - started):toseconds() >= 2
  io.write(format('lua,%s,%g\n', name, best))
end
--------------------------------------------------------------------------------
--- Naive doubly-recursive Fibonacci, timed by the benchmark harness.
-- @param n integer index
-- @return F(n) with F(0) = 0 and F(1) = 1; any n < 2 is returned unchanged
local function fib(n)
  if n >= 2 then
    return fib(n - 1) + fib(n - 2)
  end
  return n
end
-- fib(20) must be 6765; the check is applied after every timed call.
timeit(function()
  return fib(20)
end, 'fib', function(result)
  assert(result == 6765)
end)
--- Round-trip random unsigned 32-bit integers through hex strings.
-- Performs 1000 iterations of: draw n in [0, 2^32 - 1], format it as "0x..",
-- parse it back with tonumber. Only the final pair is compared, once, after
-- the loop.
-- @return n the last generated integer, in [0, 2^32 - 1]
-- @return m the value parsed back from n's hex representation
local function parseint()
  local lmt = 2^32 - 1
  local n, m
  for _ = 1, 1000 do
    -- Two-argument math.random is inclusive on both ends, giving the full
    -- uint32_t range [0, 2^32 - 1]; the one-argument form random(lmt) would
    -- return [1, lmt] and never produce 0.
    n = random(0, lmt)
    local s = format('0x%x', n)
    m = tonumber(s)
  end
  assert(n == m) -- Done here to be even with Julia benchmark.
  return n, m
end
-- No check callback: parseint asserts its own hex round trip.
timeit(parseint, 'parse_int', nil)
--- Escape-time iteration count for the Mandelbrot map at point z.
-- Iterates z <- z*z + c with c fixed to the starting value of z.
-- @param z complex starting point (also used as the constant c)
-- @return the number of iterations completed before |z| exceeded 2,
--   or 80 if it never did
local function mandel(z)
  local c, limit = z, 80
  local n = 1
  while n <= limit do
    if cabs(z) > 2 then
      return n - 1
    end
    z = z*z + c
    n = n + 1
  end
  return limit
end
--- Escape counts on a 26x21 grid: real parts (row - 21)*0.1, i.e. -2.0..0.5,
-- and imaginary parts (col - 11)*0.1, i.e. -1.0..1.0.
-- Integer counters are mapped to coordinates because Lua's `for i = l, u, c`
-- does not match Julia's `l:c:u` ranges.
-- @return 26x21 sci.alg matrix of iteration counts
local function mandelperf()
  local counts = mat(26, 21)
  for row = 1, 26 do
    local re = (row - 21)*0.1
    for col = 1, 21 do
      local im = (col - 11)*0.1
      counts[{row, col}] = mandel(re + im*1i)
    end
  end
  return counts
end
timeit(mandelperf, 'mandel', function(counts)
  assert(sum(counts) == 14791)
end)
--- In-place quicksort of a[lo..hi] using a middle-element pivot.
-- The left partition is handled recursively; the right partition is handled
-- by looping again with lo moved forward.
-- @param a indexable array (Lua table or FFI array) of comparable values
-- @param lo first index of the range to sort
-- @param hi last index of the range to sort
-- @return a itself, sorted in place
local function qsort(a, lo, hi)
  local left, right = lo, hi
  while left < hi do
    -- rshift(lo + hi, 1) halves the index sum, selecting the middle slot.
    local pivot = a[rshift(lo + hi, 1)]
    while left <= right do
      while a[left] < pivot do
        left = left + 1
      end
      while a[right] > pivot do
        right = right - 1
      end
      if left <= right then
        local tmp = a[left]
        a[left] = a[right]
        a[right] = tmp
        left = left + 1
        right = right - 1
      end
    end
    if lo < right then
      qsort(a, lo, right)
    end
    -- Continue with the right partition [left, hi] without recursing.
    lo = left
    right = hi
  end
  return a
end
--- Fill an FFI double array with rng:sample() draws and sort it with qsort.
-- The array has count+1 slots so it can be indexed 1..count (slot 0 unused).
-- @return the sorted FFI double array
local function sortperf()
  local count = 5000
  local values = ffi.new('double[?]', count + 1)
  for k = 1, count do
    values[k] = rng:sample()
  end
  return qsort(values, 1, count)
end
timeit(sortperf, 'quicksort', function(sorted)
  -- Every adjacent pair in 1..5000 must be non-decreasing.
  for k = 1, 4999 do
    assert(sorted[k] <= sorted[k + 1])
  end
end)
--- Partial sum of 1/k^2 for k = 1..10000, recomputed from scratch 500 times.
-- Only the value from the final pass is returned.
-- @return the sum over k = 1..10000 of 1/(k*k)
local function pisum()
  local total
  local pass = 0
  repeat
    pass = pass + 1
    total = 0
    for k = 1, 10000 do
      total = total + 1 / (k * k)
    end
  until pass == 500
  return total
end
-- Compare against the reference value of the 10000-term partial sum.
timeit(pisum, 'pi_sum', function(result)
  assert(abs(result - 1.644834071848065) < 1e-12)
end)
--- r-by-c sci.alg matrix whose elements are rng:sample() draws.
-- Elements are filled by linear index 1..#matrix.
-- @param r number of rows
-- @param c number of columns
-- @return the filled matrix
local function rand(r, c)
  local out = mat(r, c)
  local count = #out
  for k = 1, count do
    out[k] = rng:sample()
  end
  return out
end
--- r-by-c sci.alg matrix of draws from dist.normal(0, 1), sampled with rng.
-- Elements are filled by linear index 1..#x.
-- @param r number of rows
-- @param c number of columns
-- @return the filled matrix
local function randn(r, c)
  local x = mat(r, c)
  -- The N(0, 1) distribution object does not depend on the loop index, so it
  -- is built once here instead of once per element; the loop then times
  -- sampling only, not distribution construction.
  local std_normal = dist.normal(0, 1)
  for i = 1, #x do
    x[i] = std_normal:sample(rng)
  end
  return x
end
--- Random-matrix statistic over t trials.
-- Each trial draws four 5x5 matrices a, b, c, d with randn and builds
-- P = join(a..b..c..d) and Q = join(a..b, c..d) (sci.alg block joins;
-- presumably [a b c d] and [a b; c d] — confirm against sci.alg docs).
-- It then records trace((M' M)^4) for M = P and M = Q, written in sci-lang
-- array syntax (presumably ` = transpose, ** = matrix product, ^^ = matrix power).
-- @param t number of trials
-- @return sqrt(var)/mean of the P-traces, then the same for the Q-traces
local function randmatstat(t)
  local n = 5
  local v = vec(t)
  local w = vec(t)
  for trial = 1, t do
    local a = randn(n, n)
    local b = randn(n, n)
    local c = randn(n, n)
    local d = randn(n, n)
    local P = join(a..b..c..d)
    local Q = join(a..b, c..d)
    v[trial] = trace((P[]`**P[])^^4)
    w[trial] = trace((Q[]`**Q[])^^4)
  end
  local function rel_spread(x)
    return sqrt(var(x))/mean(x)
  end
  return rel_spread(v), rel_spread(w)
end
timeit(function() return randmatstat(1000) end, 'rand_mat_stat',
  function(s1, s2)
    assert(0.5 < s1 and s1 < 1.0 and 0.5 < s2 and s2 < 1.0)
  end)
--- Product of two n-by-n matrices of rng:sample() draws, using sci-lang's
-- ** operator on array views (presumably a BLAS-backed matrix product).
-- @param n matrix dimension
-- @return the n-by-n product
local function randmatmult(n)
  local lhs = rand(n, n)
  local rhs = rand(n, n)
  return lhs[]**rhs[]
end
timeit(function()
  return randmatmult(1000)
end, 'rand_mat_mul')
-- Writes to /dev/null, so this benchmark is skipped when LuaJIT reports Windows.
if jit.os ~= 'Windows' then
  --- Write n lines of the form "<i> <i+1>" to /dev/null.
  -- @param n number of lines to write
  local function printfd(n)
    -- io.open returns nil, err on failure; assert reports that error directly
    -- instead of failing later with "attempt to index a nil value".
    local f = assert(io.open('/dev/null', 'w'))
    for i = 1, n do
      f:write(format('%d %d\n', i, i + 1))
    end
    f:close()
  end
  timeit(function() return printfd(100000) end, 'printfd')
end