require 'functional'
functional.add_to_env(getfenv())
--- Raise `message` as an error unless `condition` is truthy.
-- @param condition value tested for truthiness (only nil and false fail)
-- @param message error message raised on failure
-- @param depth optional level forwarded unchanged to `error`
function do_assert(condition, message, depth)
  if condition then
    return
  end
  error(message, depth)
end
--- Fail unless `expected == actual`.
-- The comparison uses `==`, so tables match only when they are the same
-- object (or share an __eq metamethod that says so).
-- @param expected value the test wants
-- @param actual value produced by the code under test
-- @param depth error level for the report; defaults to 3, which points at
--   the caller of assert_equals (do_assert -> assert_equals -> caller)
function assert_equals(expected, actual, depth)
  local level = depth or 3
  local message = "Values do not match. Expected: " .. tostring(expected)
    .. " Actual: " .. tostring(actual)
  do_assert(expected == actual, message, level)
end
--- Assert that the array parts of two tables are equal.
-- Only the elements `ipairs` visits (indices 1, 2, ... up to the first nil)
-- take part in the comparison; hash keys such as `n` are ignored.
-- @param expected table holding the expected sequence
-- @param actual table holding the sequence under test
function assert_array_equals(expected, actual)
  -- Level 4 reports a type mismatch at the caller of assert_array_equals.
  assert_equals("table", type(expected), 4)
  assert_equals("table", type(actual), 4)
  -- Fresh copies containing only the ipairs-visible prefix of each table.
  local expected_seq, actual_seq = {}, {}
  for i, v in ipairs(expected) do expected_seq[i] = v end
  for i, v in ipairs(actual) do actual_seq[i] = v end
  assert_table_equals(expected_seq, actual_seq)
end
--- Assert that two tables hold exactly the same key/value pairs (shallow).
-- Values are compared with `==`, so nested tables match only when they are
-- the same object (or share an __eq metamethod that says so).
-- @param expected table of expected key/value pairs
-- @param actual table under test
-- @param depth optional error level for mismatches; defaults to 3, which
--   reports the failure at the caller of assert_table_equals
--   (do_assert -> assert_table_equals -> caller)
function assert_table_equals(expected, actual, depth)
  depth = depth or 3
  -- One extra level because assert_equals adds its own stack frame.
  assert_equals("table", type(expected), depth + 1)
  assert_equals("table", type(actual), depth + 1)
  -- Every expected key must be present in actual with an equal value.
  -- tostring(k) keeps the message buildable for boolean or table keys.
  for k, v in pairs(expected) do
    do_assert(v == actual[k], "Values do not match. Key: " .. tostring(k)
      .. " Expected: " .. tostring(v) .. " Actual: " .. tostring(actual[k]), depth)
  end
  -- Catch keys that exist only in actual. Here v is the actual value, so it
  -- belongs after "Actual:" and expected[k] after "Expected:".
  for k, v in pairs(actual) do
    do_assert(v == expected[k], "Values do not match. Key: " .. tostring(k)
      .. " Expected: " .. tostring(expected[k]) .. " Actual: " .. tostring(v), depth)
  end
end
--- nipairs visits integer indices including the nil hole at 3, yields the
-- stored value (or nil) for each, and never yields the hash key `a`
-- (expected.a is nil, so visiting it would fail the equality check).
function test_nipairs()
  local input = { 1, 2, nil, 4, a = 25 }
  local expected = { 1, 2, nil, 4 }
  local visited = 0
  for k, v in nipairs(input) do
    assert_equals(expected[k], v)
    visited = visited + 1
  end
  -- Without checking the count, an iterator that stopped at the hole (or
  -- yielded nothing at all) would pass silently.
  assert_equals(4, visited)
end
--- map applies the function to each array element and also to the `n`
-- field (4 + 1 == 5).
function test_map()
  -- `local` keeps this test from creating a global: the runner walks the
  -- global table with `pairs`, and adding a key mid-traversal is undefined.
  local mapped = map(function(x) return x + 1 end, { 1, 2, 3, n = 4 })
  assert_array_equals({ 2, 3, 4 }, mapped)
  assert_equals(5, mapped.n)
end
--- Mapping over an empty table yields an empty table.
function test_map_empty()
  local increment = function(x) return x + 1 end
  local result = map(increment, {})
  assert_array_equals({}, result)
end
--- reduce combines the initial value with every element: 10 + 1 + 2 + 3 == 16.
function test_reduce()
  local function add(total, value) return total + value end
  assert_equals(16, reduce(add, 10, { 1, 2, 3 }))
end
--- Reducing an empty table returns the initial value untouched.
function test_reduce_empty()
  local function add(total, value) return total + value end
  assert_equals(5, reduce(add, 5, {}))
end
--- reduce accepts a wrap_iter(nipairs(...)) sequence, which also yields the
-- nil hole: the running sum resets to 0 there, so the result is 0 + 5.
function test_reduce_with_nipairs()
  local values = { 1, nil, 5 }
  local function add_or_reset(total, value)
    if value then
      return total + value
    end
    return 0
  end
  assert_equals(5, reduce(add_or_reset, 0, wrap_iter(nipairs(values))))
end
--- sort_by with the identity key orders numbers ascending.
function test_sort_by()
  local identity = function(x) return x end
  local sorted = sort_by(identity, { 1, 3, 2 })
  assert_array_equals({ 1, 2, 3 }, sorted)
end
--- Sorting an empty table yields an empty table.
function test_sort_by_empty()
  local identity = function(x) return x end
  local sorted = sort_by(identity, {})
  assert_array_equals({}, sorted)
end
--- partial fixes 'b' as the middle argument; the PLACEHOLDER slots can then
-- be filled one call at a time or all in a single call.
function test_partial()
  local function make_triple(x, y, z) return { x, y, z } end
  -- Named `bound` so the global `partial` is not shadowed.
  local bound = partial(make_triple, PLACEHOLDER, 'b', PLACEHOLDER)
  assert_equals("function", type(bound))
  assert_equals("function", type(bound('a')))
  assert_array_equals({ 'a', 'b', 'c' }, bound('a')('c'))
  assert_array_equals({ 'a', 'b', 'c' }, bound('a', 'c'))
end
--- An explicit nil bound by partial is passed through as nil, not treated
-- as a placeholder.
function test_partial_with_nil()
  local function make_pair(x, y) return { x, y } end
  local bound = partial(make_pair, nil, PLACEHOLDER)
  assert_equals("function", type(bound))
  local result = bound('a')
  assert_table_equals({ nil, 'a' }, result)
end
--- curry with arity 3 and one pre-supplied argument waits for the other two,
-- which may arrive separately or together.
function test_curry()
  local function make_triple(x, y, z) return { x, y, z } end
  local with_a = curry(make_triple, 3, 'a')
  assert_equals("function", type(with_a))
  assert_equals("function", type(with_a('b')))
  assert_array_equals({ 'a', 'b', 'c' }, with_a('b')('c'))
  assert_array_equals({ 'a', 'b', 'c' }, with_a('b', 'c'))
end
--- With arity 0 there is nothing to wait for, so curry returns the result
-- of calling the function directly.
function test_curry_with_zero_arity()
  local five = function() return 5 end
  assert_equals(5, curry(five, 0))
end
--- A predicate that rejects everything produces an empty result.
function test_filter_none()
  local reject_all = function() return false end
  local kept = filter(reject_all, { 1, 2, 3, 4 })
  assert_array_equals({}, kept)
end
--- A predicate that accepts everything keeps every element, in order.
function test_filter_all()
  local accept_all = function() return true end
  local kept = filter(accept_all, { 1, 2, 3, 4 })
  assert_array_equals({ 1, 2, 3, 4 }, kept)
end
--- filter keeps only the elements for which the predicate is true.
function test_filter_even()
  local function is_even(n) return n % 2 == 0 end
  local kept = filter(is_even, { 1, 2, 3, 4 })
  assert_array_equals({ 2, 4 }, kept)
end
--- copy returns a different table object holding the same elements.
function test_copy()
  local array = { 'a', 'b', 'c' }
  -- Named `duplicate` so the global `copy` function is not shadowed.
  local duplicate = copy(array)
  -- Level 2 reports a failure at this test's line, matching the other helpers
  -- (plain `assert` gives no test location).
  do_assert(array ~= duplicate, "array and its copy should not be the same object", 2)
  assert_array_equals(array, duplicate)
end
--- wrap_iter turns the nipairs iterator into a sequence reduce can consume.
-- The nil hole is yielded too, and the accumulator simply skips it.
function test_wrap_iter_example()
  local letters = { 'a', 'b', 'c', nil, 'e' }
  local function append(acc, piece)
    if piece then
      return acc .. piece
    end
    return acc
  end
  local joined = reduce(append, "", wrap_iter(nipairs(letters)))
  assert_equals("abce", joined)
end
-- Run the tests.
-- Every global function whose name starts with "test_" is run in alphabetical
-- order. Names are collected before any test runs because a test may create
-- new globals, and adding keys to a table during a `pairs` traversal is
-- undefined. A failing test raises an error, which aborts the run.
do
  local env = getfenv()
  local test_names = {}
  for name, value in pairs(env) do
    -- Anchored pattern: only names that *start* with "test_" are tests.
    if type(name) == "string" and type(value) == "function"
        and string.match(name, "^test_") then
      test_names[#test_names + 1] = name
    end
  end
  table.sort(test_names)
  for _, name in ipairs(test_names) do
    print(name)
    env[name]()
  end
end