Skip to content

Commit 946b285

Browse files
committed
Provisional gfor() iterator, calling gforGet() in binary operators and math ops
Use in rainfall sample (but need array proxies to finish) Stub for index objects Sprinkling of notes about how to add array proxies, tentative idea being to integrate them as arrays with different GetHandle() and SetHandle() behavior, and various operations disabled
1 parent 8f2979e commit 946b285

File tree

7 files changed

+113
-11
lines changed

7 files changed

+113
-11
lines changed

scripts/lib/af_lib.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ for _, v in ipairs{
1717
"funcs.statistics",
1818
"funcs.util",
1919
"funcs.vector",
20+
"funcs.gfor",
2021
"graphics.window",
2122
"methods.constructors",
2223
"methods.device",

scripts/lib/funcs/gfor.lua

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
--- gfor() mechanism.
2+
3+
-- Standard library imports --
4+
local assert = assert
5+
6+
-- Modules --
7+
local array = require("lib.impl.array")
8+
9+
-- Imports --
10+
local GetLib = array.GetLib
11+
12+
-- Exports --
13+
local M = {}
14+
15+
--- See also: https://github.com/arrayfire/arrayfire/blob/devel/include/af/gfor.h
16+
-- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/gfor.cpp
17+
18+
-- --
19+
local Status = false
20+
21+
--
22+
local function AuxGfor (seq)
23+
Status = not Status
24+
25+
if Status then
26+
return seq
27+
end
28+
end
29+
30+
--
31+
function M.Add (into)
32+
for k, v in pairs{
33+
--
34+
batchFunc = function(lhs, rhs, func)
35+
assert(not Status, "batchFunc can not be used inside GFOR") -- TODO: AF_ERR_ARG
36+
37+
Status = true
38+
39+
local res = func(lhs, rhs)
40+
41+
Status = false
42+
43+
return res
44+
end,
45+
46+
--
47+
gfor = function(...)
48+
local lib = GetLib()
49+
50+
return AuxGfor, lib.seq(lib.seq(...), true)
51+
end,
52+
53+
--
54+
gforGet = function()
55+
return Status
56+
end,
57+
58+
--
59+
gforSet = function(val)
60+
Status = not not val
61+
end
62+
} do
63+
into[k] = v
64+
end
65+
end
66+
67+
-- Export the module.
68+
return M

scripts/lib/funcs/mathematics.lua

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ local array = require("lib.impl.array")
66

77
-- Imports --
88
local CallWrap = array.CallWrap
9+
local GetLib = array.GetLib
910
local IsArray = array.IsArray
1011
local TwoArrays = array.TwoArrays
1112

@@ -17,7 +18,7 @@ local M = {}
1718
--
1819
local function Binary (name)
1920
return function(a, b)
20-
return TwoArrays(name, a, b--[[TODO: IsArray(a) and IsArray(b) and gfor_get]])
21+
return TwoArrays(name, a, b, IsArray(a) and IsArray(b) and GetLib().gforGet())
2122
end
2223
end
2324

scripts/lib/impl/array.lua

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ end
142142
-- @bool remove
143143
-- @treturn ?|af_array|nil X
144144
function M.GetHandle (arr, remove)
145+
-- TODO: If proxy, add reference?
145146
local ha = arr.m_handle
146147

147148
if remove then
@@ -191,6 +192,7 @@ end
191192
-- @tparam LuaArray arr
192193
-- @tparam ?|af_array|nil handle
193194
function M.SetHandle (arr, handle)
195+
-- TODO: disable for proxies
194196
local cur = arr.m_handle
195197

196198
if cur ~= nil then
@@ -298,6 +300,7 @@ for _, v in ipairs{
298300
"lib.impl.ephemeral",
299301
"lib.impl.operators",
300302
"lib.impl.seq",
303+
"lib.impl.index", -- depends on seq
301304
"lib.methods.methods"
302305
} do
303306
require(v).Add(M, ArrayMethodsAndMetatable)

scripts/lib/impl/index.lua

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
--- Core index module.
2+
3+
-- Modules --
4+
local af = require("arrayfire")
5+
6+
-- Forward declarations --
7+
8+
-- Exports --
9+
local M = {}
10+
11+
--
12+
function M.Add (array_module)
13+
-- Import these here since the array module is not yet registered.
14+
15+
16+
end
17+
18+
-- Export the module.
19+
return M

scripts/lib/impl/operators.lua

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
local af = require("arrayfire")
55

66
-- Forward declarations --
7+
local GetLib
78
local TwoArrays
89

910
-- Exports --
@@ -17,9 +18,10 @@ local function Binary (name, cmp)
1718
name = "af_" .. name
1819

1920
return function(a, b)
21+
-- TODO: disable for proxies?
2022
Result = nil
2123

22-
local arr = TwoArrays(name, a, b, true) -- TODO: gforGet()
24+
local arr = TwoArrays(name, a, b, GetLib().gforGet())
2325

2426
if cmp then
2527
Result = arr
@@ -32,6 +34,7 @@ end
3234
--
3335
function M.Add (array_module, meta)
3436
-- Import these here since the array module is not yet registered.
37+
GetLib = array_module.GetLib
3538
TwoArrays = array_module.TwoArrays
3639

3740
--
@@ -55,6 +58,15 @@ function M.Add (array_module, meta)
5558
__le = Binary("le", true),
5659
__mod = Binary("mod"),
5760
__mul = Binary("mul"),
61+
--[[
62+
__newindex = function(arr, k, v)
63+
-- TODO: disable for non-proxies?
64+
65+
if k == "_" then
66+
-- lvalue assign of v
67+
end
68+
end
69+
]]
5870
__pow = Binary("pow"),
5971
__sub = Binary("sub"),
6072
__unm = function(a)

scripts/tests/getting_started/rainfall.lua

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,20 @@ AF.main(function()
3636
local site = AF.array(n,site_)
3737
local measurement = AF.array(n,measurement_)
3838
local rainfall = AF.constant(0, sites)
39-
--[[
40-
gfor (seq s, sites) {
41-
rainfall(s) = sum(measurement * (site == s));
42-
}
43-
]]
39+
40+
for s in AF.gfor(sites) do
41+
-- rainfall(s) = AF.sum(measurement * COMP(site == AF.array(s)))
42+
end
4443
print("total rainfall at each site:")
4544
AF.print("rainfall", rainfall)
4645
local is_between = AF["and"](Comp(WC(1) <= day), Comp(day <= WC(5))) -- days 1 and 5
4746
local rain_between = AF.sum("f32", measurement * is_between)
4847
AF.printf("rain between days: %g", rain_between)
4948
AF.printf("number of days with rain: %g", AF.sum("f32", Comp(AF.diff1(day) > WC(0))) + 1)
5049
local per_day = AF.constant(0, days)
51-
--[[
52-
gfor (seq d, days)
53-
per_day(d) = sum(measurement * (day == d))
54-
]]
50+
for d in AF.gfor(days) do
51+
-- per_day(d) = AF.sum(measurement * COMP(day == AF.array(d)))
52+
end
5553
print("total rainfall each day:")
5654
AF.print("per_day", per_day)
5755
AF.printf("number of days over five: %g", AF.sum("f32", Comp(per_day > WC(5))))

0 commit comments

Comments
 (0)