In [1]:
from memory.unsafe import DTypePointer
from memory import memset_zero

In [2]:
struct Field[DT: DType, VW: Int]:
    """A 3-D field of `DT` values stored in one flat, heap-allocated buffer.

    The buffer holds (gx+2)*(gy+2)*(gz+2) elements: the requested grid plus
    one extra layer on each side of every axis (the "BC" cells mentioned in
    `__add__`). Elements are laid out with x varying fastest, then y, then z.
    The element accessors index this padded buffer directly; no +1 offset is
    applied to x, y or z.

    `VW` is the SIMD width used by the bulk loops in `__copyinit__` and
    `__add__`, and the default width of `load` / `store`.
    """
    var data: DTypePointer[DT]
    var gx: Int
    var gy: Int
    var gz: Int
    var dsize: Int

    fn __init__(inout self, gx: Int, gy: Int, gz: Int):
        """Allocates a gx*gy*gz field plus a one-cell border on every side.

        The buffer is left uninitialized; call `zero()` to clear it.
        """
        self.dsize = (gx+2)*(gy+2)*(gz+2)
        self.data = DTypePointer[DT].alloc(self.dsize)
        self.gx = gx
        self.gy = gy
        self.gz = gz

    fn __copyinit__(inout self, other: Self):
        """Deep-copies `other` into a freshly allocated buffer."""
        self.gx = other.gx
        self.gy = other.gy
        self.gz = other.gz
        self.dsize = other.dsize
        self.data = DTypePointer[DT].alloc(self.dsize)
        # Largest multiple of VW that fits: the vector loop must not touch
        # elements past dsize, so the remainder is copied one by one.
        let vec_end = self.dsize - self.dsize % VW
        for i in range(0, vec_end, VW):
            # Load a full VW-wide vector. A scalar `load(i)` here would be
            # broadcast into all VW lanes and corrupt the copy.
            self.data.simd_store[VW](i, other.data.simd_load[VW](i))
        for i in range(vec_end, self.dsize):
            self.data.store(i, other.data.load(i))

    fn __add__(self, rhs: Field[DT,VW]) -> Field[DT,VW]:
        """Returns a new field holding the element-wise sum `self + rhs`.

        The whole buffer, border cells included, is summed. That is not
        required, but it is simpler than skipping the border.

        NOTE(review): the sizes are not checked. `rhs` must have the same
        gx, gy, gz as `self`, or `rhs.data` is read out of bounds.
        """
        let result = Field[DT,VW](self.gx, self.gy, self.gz)
        # Vectorize up to the last full VW-chunk, then finish the tail
        # element by element so no access runs past dsize.
        let vec_end = self.dsize - self.dsize % VW
        for i in range(0, vec_end, VW):
            result.data.simd_store[VW](i, self.data.simd_load[VW](i) + rhs.data.simd_load[VW](i))
        for i in range(vec_end, self.dsize):
            result.data.store(i, self.data.load(i) + rhs.data.load(i))
        return result

    fn __del__(owned self):
        """Frees the buffer."""
        self.data.free()
        # Debug trace of destructor calls; the notebook output relies on it.
        print("done del")

    fn zero(inout self):
        """Sets every element of the buffer, border included, to zero."""
        memset_zero[DT](self.data, self.dsize)

    @always_inline
    fn _offset(self, x: Int, y: Int, z: Int) -> Int:
        """Returns the flat buffer index of cell (x, y, z); x varies fastest."""
        return (z * (self.gy+2) + y) * (self.gx+2) + x

    @always_inline
    fn __getitem__(self, x: Int, y: Int, z: Int) -> SIMD[DT,1]:
        """Returns the scalar stored at cell (x, y, z)."""
        return self.data.load(self._offset(x, y, z))

    @always_inline
    fn load[nelts:Int=VW](self, x: Int, y: Int, z: Int) -> SIMD[DT,nelts]:
        """Loads `nelts` consecutive elements starting at (x, y, z).

        The caller must ensure that all `nelts` elements lie inside the buffer.
        """
        return self.data.simd_load[nelts](self._offset(x, y, z))

    @always_inline
    fn __setitem__(self, x: Int, y: Int, z: Int, val: SIMD[DT,1]):
        """Stores the scalar `val` at cell (x, y, z)."""
        self.data.store(self._offset(x, y, z), val)

    @always_inline
    fn store[nelts:Int=VW](self, x: Int, y: Int, z: Int, val: SIMD[DT, nelts]):
        """Stores `nelts` consecutive elements starting at (x, y, z).

        The caller must ensure that all `nelts` elements lie inside the buffer.
        """
        self.data.simd_store[nelts](self._offset(x, y, z), val)

In [7]:

# Smoke test for Field: put 2 and 3 into the first n cells of the x-row at
# y=0, z=0 of two fields, echo them, then add the fields and print the sums.
# `c` is allocated up front and then overwritten by the sum.
n=8
var a = Field[DType.float32,4](9,1,1)
var b = Field[DType.float32,4](9,1,1)
var c = Field[DType.float32,4](9,1,1)
for ix in range(0, n, 1):
    a[ix,0,0] = 2
    b[ix,0,0] = 3
    print(ix, a[ix,0,0], b[ix,0,0])

# Element-wise sum; only the cells written above are printed back.
c = a + b
for ix in range(0, n, 1):
    print(ix, c[ix,0,0])

0 2.0 3.0
1 2.0 3.0
2 2.0 3.0
3 2.0 3.0
4 2.0 3.0
5 2.0 3.0
6 2.0 3.0
7 2.0 3.0
done del
0 5.0
1 5.0
2 5.0
3 5.0
4 5.0
5 5.0
6 5.0
7 5.0
