Skip to content

Commit

Permalink
consistent names for constants
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Jul 27, 2014
1 parent 5573440 commit 3ca9922
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
10 changes: 7 additions & 3 deletions src/TensorOperations.jl
Expand Up @@ -14,10 +14,14 @@ end

# Constants that define base case in recursive algorithms
#---------------------------------------------------------
const OBASELENGTH=16 # square root of total size of all open dimensions
const CBASELENGTH=24 # total size of all contraction dimensions
const PERMUTEBASELENGTH=1024
# for tensorcopy, tensoradd and tensortrace
const TBASELENGTH=512
# note: total number elements involved = 2*512 = 1024

# for tensorcontract
const OBASELENGTH=16 # total size of all open dimensions in one of the two contraction partners
const CBASELENGTH=24 # total size of all contraction dimensions
# note: total number elements involved = 16*24*2+16*16 = 1024

# Tensor Operations
#-------------------
Expand Down
6 changes: 3 additions & 3 deletions src/tensoradd.jl
Expand Up @@ -54,13 +54,13 @@ end
Clinear = C
end

if length(C)<=4*PERMUTEBASELENGTH
if length(C)<=4*TBASELENGTH
@stridedloops(N, i, dims, indA, startA, stridesA, indC, startC, stridesC, @inbounds Clinear[indC]=beta*Clinear[indC]+alpha*Alinear[indA])
else
@nexprs N d->(minstrides_{d} = min(stridesA_{d},stridesC_{d}))

# build recursive stack
depth=iceil(log2(length(C)/PERMUTEBASELENGTH))+2 # 2 levels safety margin
depth=iceil(log2(length(C)/TBASELENGTH))+2 # 2 levels safety margin
level=1 # level of recursion
stackpos=zeros(Int,depth) # record pos of algorithm at the different recursion level
stackpos[level]=0
Expand Down Expand Up @@ -88,7 +88,7 @@ end
bstartA=stackbstartA[level]
bstartC=stackbstartC[level]

if blength<=PERMUTEBASELENGTH || level==depth # base case
if blength<=TBASELENGTH || level==depth # base case
@stridedloops(N, i, bdims, indA, bstartA, stridesA, indC, bstartC, stridesC, @inbounds Clinear[indC]=beta*Clinear[indC]+alpha*Alinear[indA])
level-=1
elseif pos==0
Expand Down
6 changes: 3 additions & 3 deletions src/tensorcopy.jl
Expand Up @@ -53,13 +53,13 @@ const PERMUTEGENERATE=[1,2,3,4,5,6,7,8]
Clinear = C
end

if length(C)<=4*PERMUTEBASELENGTH
if length(C)<=4*TBASELENGTH
@stridedloops(N, i, dims, indA, startA, stridesA, indC, startC, stridesC, @inbounds Clinear[indC]=Alinear[indA])
else
@nexprs N d->(minstrides_{d} = min(stridesA_{d},stridesC_{d}))

# build recursive stack
depth=iceil(log2(length(C)/PERMUTEBASELENGTH))+2 # 2 levels safety margin
depth=iceil(log2(length(C)/TBASELENGTH))+2 # 2 levels safety margin
level=1 # level of recursion
stackpos=zeros(Int,depth) # record position of algorithm at the different recursion level
stackpos[level]=0
Expand Down Expand Up @@ -87,7 +87,7 @@ const PERMUTEGENERATE=[1,2,3,4,5,6,7,8]
bstartA=stackbstartA[level]
bstartC=stackbstartC[level]

if blength<=PERMUTEBASELENGTH || level==depth # base case
if blength<=TBASELENGTH || level==depth # base case
@stridedloops(N, i, bdims, indA, bstartA, stridesA, indC, bstartC, stridesC, @inbounds Clinear[indC]=Alinear[indA])
level-=1
elseif pos==0
Expand Down
6 changes: 3 additions & 3 deletions src/tensortrace.jl
Expand Up @@ -100,13 +100,13 @@ const TRACEGENERATE={(2,0),(3,1),(4,2),(4,0),(5,3),(5,1),(6,4),(6,2),(6,0)}
Clinear = C
end

if olength*(clength+1)<=8*PERMUTEBASELENGTH
if olength*(clength+1)<=8*TBASELENGTH
@gentracekernel(div(NA-NC,2),NC,order,alpha,Alinear,beta,Clinear,startA,startC,odims,cdims,ostridesA,cstridesA,ostridesC)
else
@nexprs NC d->(minostrides_{d} = min(ostridesA_{d},ostridesC_{d}))

# build recursive stack
depth=iceil(log2(olength*(clength+1)/2/PERMUTEBASELENGTH))+2 # 2 levels safety margin
depth=iceil(log2(olength*(clength+1)/2/TBASELENGTH))+2 # 2 levels safety margin
level=1 # level of recursion
stackpos=zeros(Int,depth) # record position of algorithm at the different recursion level
stackpos[level]=0
Expand Down Expand Up @@ -146,7 +146,7 @@ const TRACEGENERATE={(2,0),(3,1),(4,2),(4,0),(5,3),(5,1),(6,4),(6,2),(6,0)}
bstartC=stackbstartC[level]
gamma=stackgamma[level]

if oblength*(cblength+1)<=2*PERMUTEBASELENGTH || level==depth # base case
if oblength*(cblength+1)<=2*TBASELENGTH || level==depth # base case
@gentracekernel(div(NA-NC,2),NC,order,alpha,Alinear,gamma,Clinear,bstartA,bstartC,obdims,cbdims,ostridesA,cstridesA,ostridesC)
level-=1
elseif pos==0
Expand Down

0 comments on commit 3ca9922

Please sign in to comment.