Permalink
-- Warn when this LuaJIT build is not x64; the message asks for a BIT=64
-- build for optimal OpenBLAS performance.
do
  local is_x64 = (jit.arch == 'x64')
  if not is_x64 then
    print('WARNING: please use BIT=64 for optimal OpenBLAS performance')
  end
end
| local ffi = require 'ffi' | |
| local bit = require 'bit' | |
| local time = require 'time' | |
| local alg = require 'sci.alg' | |
| local prng = require 'sci.prng' | |
| local stat = require 'sci.stat' | |
| local dist = require 'sci.dist' | |
| local complex = require 'sci.complex' | |
| local min, sqrt, random, abs = math.min, math.sqrt, math.random, math.abs | |
| local cabs = complex.abs | |
| local rshift = bit.rshift | |
| local format = string.format | |
| local nowutc = time.nowutc | |
| local rng = prng.std() | |
| local vec, mat, join = alg.vec, alg.mat, alg.join | |
| local sum, trace = alg.sum, alg.trace | |
| local var, mean = stat.var, stat.mean | |
| -------------------------------------------------------------------------------- | |
--- Call f once and time it.
-- @tparam function f zero-argument function to run
-- @treturn number milliseconds between two time.nowutc() readings around the call
-- @return f's first return value
-- @return f's second return value
local function elapsed(f)
  local start = nowutc()
  local r1, r2 = f()
  local stop = nowutc()
  local span = stop - start
  return span:tomilliseconds(), r1, r2
end
--- Benchmark f and print the best time as a CSV line "lua,<name>,<ms>".
-- f is run at least 6 times, and repeatedly until at least 2 seconds have
-- passed since the first run started. The reported value is the minimum of
-- the per-run times in milliseconds.
-- @tparam function f zero-argument function to benchmark
-- @tparam string name benchmark label used in the output line
-- @tparam[opt] function check if given, called with f's first two results
--   after every run (outside the timed region)
local function timeit(f, name, check)
  local best = 1/0
  local runs = 0
  local started = nowutc()
  repeat
    runs = runs + 1
    local ms, r1, r2 = elapsed(f)
    best = min(best, ms)
    if check then check(r1, r2) end
  until runs > 5 and (nowutc() - started):toseconds() >= 2
  io.write(format('lua,%s,%g\n', name, best))
end
| -------------------------------------------------------------------------------- | |
--- Naive doubly-recursive Fibonacci, benchmarked for function-call cost.
-- @tparam number n index; values below 2 are returned unchanged
-- @treturn number fib(n)
local function fib(n)
  if n >= 2 then
    return fib(n - 1) + fib(n - 2)
  end
  return n
end
-- fib benchmark: fib(20) must equal 6765 on every run.
do
  local function fib20() return fib(20) end
  timeit(fib20, 'fib', function(result)
    assert(result == 6765)
  end)
end
--- parse_int benchmark: 1000 times, draw a random uint32 value, format it as
-- a "0x..." hex string and parse it back with tonumber.
-- Raises (via assert) if the last value does not survive the round trip.
-- @treturn number the last generated value
-- @treturn number its parsed-back copy
local function parseint()
  local lmt = 2^32 - 1
  local n, m
  for _ = 1, 1000 do
    -- random(0, lmt) yields an integer in [0, 2^32 - 1], i.e. the full
    -- uint32_t range. random(lmt) alone returns [1, lmt] and never yields 0.
    n = random(0, lmt)
    local s = format('0x%x', n)
    m = tonumber(s)
  end
  assert(n == m) -- Done here to be even with Julia benchmark.
  return n, m
end
-- parseint asserts its own round trip, so no separate check is passed.
timeit(parseint, 'parse_int', nil)
--- Escape-time count for the Mandelbrot iteration z <- z*z + c, where c is
-- the starting point z.
-- @param z starting point (a complex number)
-- @treturn number how many iterations completed before |z| exceeded 2,
--   or 80 if it never did within 80 iterations
local function mandel(z)
  local c = z
  local maxiter = 80
  for done = 0, maxiter - 1 do
    if cabs(z) > 2 then
      return done
    end
    z = z*z + c
  end
  return maxiter
end
--- Escape counts over a 26x21 grid of points: real parts -2.0..0.5 and
-- imaginary parts -1.0..1.0, both in steps of 0.1.
-- @return 26x21 matrix of mandel() results
local function mandelperf()
  local counts = mat(26, 21)
  -- Iterate over the signed grid offsets and shift them to 1-based indices.
  for x = -20, 5 do
    for y = -10, 10 do
      local re, im = x*0.1, y*0.1
      counts[{x + 21, y + 11}] = mandel(re + im*1i)
    end
  end
  return counts
end
timeit(mandelperf, 'mandel', function(grid)
  -- Expected checksum: total of all escape counts on the grid.
  assert(sum(grid) == 14791)
end)
--- Sort a[lo..hi] in place in ascending order using quicksort.
-- Elements are partitioned with two indices moving inward around the middle
-- element. The left part is handled by recursion and the right part by
-- looping again.
-- @param a indexable container of numbers (a Lua table or FFI array)
-- @tparam number lo first index of the range
-- @tparam number hi last index of the range
-- @return a (the same container, now sorted in that range)
local function qsort(a, lo, hi)
  local left, right = lo, hi
  while left < hi do
    local pivot = a[rshift(lo + hi, 1)]
    while left <= right do
      while a[left] < pivot do left = left + 1 end
      while a[right] > pivot do right = right - 1 end
      if left <= right then
        a[left], a[right] = a[right], a[left]
        left, right = left + 1, right - 1
      end
    end
    if lo < right then qsort(a, lo, right) end
    -- Continue with the right part [left, hi] in this loop.
    lo, right = left, hi
  end
  return a
end
--- Fill an FFI double buffer with 5000 draws from rng:sample() and quicksort it.
-- @return the sorted buffer; data lives at indices 1..5000
local function sortperf()
  local count = 5000
  -- One extra slot so the data can be addressed 1..count; slot 0 is unused.
  local buf = ffi.new('double[?]', count + 1)
  for k = 1, count do buf[k] = rng:sample() end
  return qsort(buf, 1, count)
end
timeit(sortperf, 'quicksort', function(sorted)
  -- Every adjacent pair in 1..5000 must be in non-decreasing order.
  for k = 1, 4999 do
    assert(sorted[k] <= sorted[k + 1])
  end
end)
--- Partial sum of 1/k^2 for k = 1..10000 (which tends to pi^2/6).
-- The sum is recomputed from scratch 500 times to give the benchmark enough
-- work, and the last result is returned.
-- @treturn number the partial sum
local function pisum()
  local total
  for _ = 1, 500 do
    local acc = 0
    for k = 1, 10000 do
      acc = acc + 1 / (k*k)
    end
    total = acc
  end
  return total
end
do
  local expected = 1.644834071848065
  timeit(pisum, 'pi_sum', function(x)
    assert(abs(x - expected) < 1e-12)
  end)
end
--- r-by-c matrix whose slots 1..#x are filled with rng:sample() draws.
-- NOTE(review): relies on #x and x[k] giving linear access to every element
-- of an alg.mat; confirm against sci.alg.
local function rand(r, c)
  local m = mat(r, c)
  local count = #m
  for k = 1, count do m[k] = rng:sample() end
  return m
end
--- r-by-c matrix filled with draws from dist.normal(0, 1) (standard normal),
-- using rng as the generator.
-- NOTE(review): relies on #x and x[i] giving linear access to every element
-- of an alg.mat; confirm against sci.alg.
local function randn(r, c)
  local x = mat(r, c)
  -- The distribution object is loop-invariant, so build it once instead of
  -- once per element. randmatstat calls randn 4000 times per run.
  local normal = dist.normal(0, 1)
  for i = 1, #x do
    x[i] = normal:sample(rng)
  end
  return x
end
--- rand_mat_stat benchmark.
-- For each of t trials, draw four 5x5 normal matrices a, b, c, d and build
-- P = join(a..b..c..d) and Q = join(a..b, c..d). Record
-- trace((P[]`**P[])^^4) and trace((Q[]`**Q[])^^4).
-- NOTE(review): presumably `..`/join build [a b c d] and [a b; c d],
-- `` ` `` is transpose, `**` matrix product and `^^` matrix power
-- (sci-lang); confirm.
-- @tparam number t number of trials
-- @treturn number sqrt(var)/mean of the P-based traces
-- @treturn number sqrt(var)/mean of the Q-based traces
local function randmatstat(t)
  local n = 5
  local v = vec(t)
  local w = vec(t)
  for trial = 1, t do
    local a = randn(n, n)
    local b = randn(n, n)
    local c = randn(n, n)
    local d = randn(n, n)
    local P = join(a..b..c..d)
    local Q = join(a..b, c..d)
    v[trial] = trace((P[]`**P[])^^4)
    w[trial] = trace((Q[]`**Q[])^^4)
  end
  return sqrt(var(v))/mean(v), sqrt(var(w))/mean(w)
end
do
  local function run() return randmatstat(1000) end
  local function check(s1, s2)
    -- Both returned ratios must lie strictly between 0.5 and 1.0.
    assert( 0.5 < s1 and s1 < 1.0 and 0.5 < s2 and s2 < 1.0 )
  end
  timeit(run, 'rand_mat_stat', check)
end
--- Multiply two n-by-n matrices filled by rand().
-- NOTE(review): `**` is assumed to be sci-lang matrix multiplication.
local function randmatmult(n)
  local lhs = rand(n, n)
  local rhs = rand(n, n)
  return lhs[]**rhs[]
end
do
  local size = 1000
  timeit(function() return randmatmult(size) end, 'rand_mat_mul')
end
-- printfd benchmark: formatted writes to /dev/null. Skipped on Windows,
-- which has no /dev/null.
if jit.os ~= 'Windows' then
  --- Write n lines of the form "<i> <i+1>" to /dev/null.
  -- @tparam number n number of lines to write
  local function printfd(n)
    -- assert surfaces io.open's own error message. Unchecked, a failed open
    -- would fail later as "attempt to index a nil value".
    local f = assert(io.open('/dev/null', 'w'))
    for i = 1, n do
      f:write(format('%d %d\n', i, i + 1))
    end
    f:close()
  end
  timeit(function() return printfd(100000) end, 'printfd')
end