Skip to content

Commit

Permalink
loop over types
Browse files Browse the repository at this point in the history
  • Loading branch information
lpawela committed Aug 25, 2019
1 parent a95c835 commit 235400f
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions curandommatrices/src/circular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ end
function _qr_fix!(z::CuMatrix)
q, r = CuArrays.qr!(z)
ph = diag(r)
len = min(length(ph), 1024)
len = min(length(ph), 1024) #hack, warpsize() segfaults
@cuda threads=len blocks=16 cplx_phase!(ph)
q = CuMatrix(q)
idim = size(r, 1)
for i=1:idim
for i = 1:idim
q[:, i] .*= ph[i]
end
q[:, 1:idim]
Expand All @@ -33,29 +33,18 @@ function curand(c::COE)
transpose(u)*u
end

function curand(c::CUE)
z = curand(c.g)
_qr_fix!(z)
end

function curand(c::CSE)
z = curand(c.g)
u = _qr_fix!(z)
ur = cat([CuMatrix{Float32}([0 -1; 1 0]) for _=1:c.d÷2]..., dims=[1,2])
ur*u*ur'*transpose(u)
end

function curand(c::CircularRealEnsemble)
z = curand(c.g)
_qr_fix!(z)
end

function curand(c::CircularQuaternionEnsemble)
z = curand(c.g)
_qr_fix!(z)
end

function curand(c::HaarIsometry)
z = curand(c.g)
_qr_fix!(z)
for T in (CUE, CircularRealEnsemble, CircularQuaternionEnsemble, HaarIsometry)
@eval begin
function curand(c::$T)
z = curand(c.g)
_qr_fix!(z)
end
end
end

0 comments on commit 235400f

Please sign in to comment.