Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Permalink
CGen support 2D vcat()
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehsan Totoni committed Feb 9, 2017
1 parent 9bac479 commit 193081e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
19 changes: 14 additions & 5 deletions src/cgen-pattern-match.jl
Original file line number Diff line number Diff line change
Expand Up @@ -711,23 +711,32 @@ function from_assignment_match_vcat(lhs, rhs::Expr, linfo)
args = getCallArguments(rhs)
for a in args
atyp = getType(a, linfo)
@assert atyp<:Array && ndims(atyp)==1 "CGen only supports vcat of 1D arrays"
@assert atyp<:Array && (ndims(atyp)==1 || ndims(atyp)==2) "CGen only supports vcat of 1D and 2D arrays"
end
num_dims = ndims(getType(args[1], linfo))
typ = eltype(getType(args[1], linfo))
ctyp = toCtype(typ)
clhs = from_expr(lhs,linfo)
# get total size of array: size(a1)+size(a2)+...
csize = "("* mapfoldl(a->from_arraysize(a,1,linfo),(a,b)->"$a+$b",args) *")"
c_num_cols = "1"
if num_dims==2
c_num_cols = from_arraysize(args[1],2,linfo)
csize *= ", $c_num_cols"
end
s *= "{\n"
s *= "$clhs = j2c_array<$ctyp>::new_j2c_array_1d(NULL, $csize);\n"
s *= " int64_t __cgen_curr_ind = 0;\n"
s *= "$clhs = j2c_array<$ctyp>::new_j2c_array_$(num_dims)d(NULL, $csize);\n"
s *= "int64_t __cgen_curr_ind = 0;\n"
s *= "for(int64_t j=0; j<$c_num_cols; j++){\n"
for arr in args
carr = from_expr(arr, linfo)
s *= "for(int64_t i=0; i<$(from_arraysize(arr,1,linfo)); i++){\n"
s *= " $clhs.data[__cgen_curr_ind++] = $carr.data[i];\n"
col_size = from_arraysize(arr,1,linfo)
s *= "for(int64_t i=0; i<$col_size; i++){\n"
s *= " $clhs.data[__cgen_curr_ind++] = $carr.data[j*$col_size+i];\n"
s *= "}\n"
end
s *= "}\n"
s *= "}\n"
end
return s
end
Expand Down
8 changes: 7 additions & 1 deletion test/vcat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,24 @@ using ParallelAccelerator
#using CompilerTools
#CompilerTools.OptFramework.set_debug_level(4)

@acc function cat1(a::Array{Float64,1},b::Array{Float64,1})
@acc function cat1(a,b)
C = vcat(a,b)
return C
end

function test1()
return cat1([1.,2.,3.],[4.,5.,6.])
end

function test2()
return cat1([1. 2. 3.; 6. 7. 8.],[4. 5. 6.])
end

end

using Base.Test
println("Testing vcat...")
@test_approx_eq VCatTest.test1() [1.,2.,3.,4.,5.,6.]
@test_approx_eq VCatTest.test2() [1. 2. 3.; 6. 7. 8.; 4. 5. 6.]
println("Done testing vcat.")

0 comments on commit 193081e

Please sign in to comment.