In [2]:
import Base: getindex, size, println
function println(x)
    show(STDOUT, "text/plain", x)
    print("\n")
end

println (generic function with 4 methods)

In [3]:
struct Identity{Tv} <: AbstractMatrix{Tv}
    n::Int
end
Identity(n::Int) = Identity{Int}(n)

size(A::Identity, i::Int) = A.n
size(A::Identity) = (A.n, A.n)
getindex(A::Identity{T}, i::Integer, j::Integer) where T = T(i==j)

getindex (generic function with 185 methods)

In [133]:
id = Identity{Float64}(4)

4×4 Identity{Float64}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0

In [5]:
import Base: sparse, full
sparse(A::Identity{T}) where T = speye(T, A.n)
full(A::Identity{T}) where T = eye(T, A.n)

full (generic function with 25 methods)

In [6]:
import Base: transpose, conj, copy, real, imag, ctranspose, kron
for func in (:conj, :real, :imag, :ctranspose, :transpose, :copy)
    @eval ($func)(M::Identity{T}) where T = Identity{T}(M.n)
end

In [7]:
import Base: *, /, ==
*(A::Identity{T}, B::Number) where T = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
*(B::Number, A::Identity{T}) where T = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
/(A::Identity{T}, B::Number) where T = Diagonal(fill(promote_type(T, eltype(B))(1/B), A.n))
==(A::Identity, B::Identity) = A.n == B.n

== (generic function with 127 methods)

In [8]:
println(id*2im)
println(2*id)
println(id/3im)
println(id==id)

4×4 Diagonal{Complex{Float64}}:
 0.0+2.0im      ⋅          ⋅          ⋅    
     ⋅      0.0+2.0im      ⋅          ⋅    
     ⋅          ⋅      0.0+2.0im      ⋅    
     ⋅          ⋅          ⋅      0.0+2.0im
4×4 Diagonal{Float64}:
 2.0   ⋅    ⋅    ⋅ 
  ⋅   2.0   ⋅    ⋅ 
  ⋅    ⋅   2.0   ⋅ 
  ⋅    ⋅    ⋅   2.0
4×4 Diagonal{Complex{Float64}}:
 0.0-0.333333im      ⋅               ⋅               ⋅         
     ⋅           0.0-0.333333im      ⋅               ⋅         
     ⋅               ⋅           0.0-0.333333im      ⋅         
     ⋅               ⋅               ⋅           0.0-0.333333im
true


In [9]:
import Base: nnz, nonzeros, inv
nnz(M::Identity) = M.n
nonzeros(M::Identity{T}) where T = ones(T, M.n)
inv(M::Identity) = M

inv (generic function with 28 methods)

In [10]:
println(nnz(id))
println(nonzeros(id))
println(inv(id))

4
4-element Array{Float64,1}:
 1.0
 1.0
 1.0
 1.0
4×4 Identity{Float64}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0


In [11]:
####### multiply ###########
Mats = Union{SparseMatrixCSC, StridedVecOrMat}
function (*)(A::Identity, B::Mats)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    B
end
function (*)(A::Mats, B::Identity)
    size(A, ndims(A)) == size(B, 1) || throw(DimensionMismatch())
    A
end
#### why can not take a union here?
function (*)(A::Identity, B::Diagonal)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    B
end
function (*)(A::Diagonal, B::Identity)
    size(A, ndims(A)) == size(B, 1) || throw(DimensionMismatch())
    A
end
function (*)(A::Identity, B::Identity)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    B
end

* (generic function with 189 methods)

In [12]:
typeintersect(Identity, AbstractArray)

Identity

In [13]:
mat = rand(5,4)
id*mat'

4×5 Array{Float64,2}:
 0.786424  0.945411   0.58803   0.637621  0.558157
 0.339167  0.400816   0.146227  0.647954  0.745609
 0.901472  0.955759   0.392278  0.385646  0.413835
 0.189123  0.0767857  0.24016   0.755861  0.277056

In [14]:
function setdiag!(A::AbstractMatrix, B::Number)
    n = size(A, 1)
    A[1:n+1:end] = B
    A
end
function adddiag!(A::AbstractMatrix, B::Number)
    n = size(A, 1)
    A[1:n+1:end] += B
    A
end

adddiag! (generic function with 1 method)

In [15]:
A = rand(4,4)
offsetdiag!(A, 3)
A

LoadError: [91mUndefVarError: offsetdiag! not defined[39m

In [16]:
import Base: kron

kron(A::Identity{Ta}, B::Identity{Tb}) where {Ta, Tb}= Identity{promote_type(Ta, Tb)}(A.n*B.n)

function kron(A::AbstractMatrix{Tv}, B::Identity) where Tv
    mA, nA = size(A)
    nB = B.n
    nzval = Vector{Tv}(nB*mA*nA)
    rowval = Vector{Int}(nB*mA*nA)
    colptr = collect(1:mA:nB*mA*nA+1)
    @inbounds for j = 1:nA
        source = A[:,j]
        startbase = (j-1)*nB*mA - mA
        @inbounds for j2 = 1:nB
            start = startbase + j2*mA
            row = j2-nB
            @inbounds @simd for i = 1:mA
                nzval[start+i] = source[i]
                rowval[start+i] = row+nB*i
            end
        end
    end
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end
function kron(A::Identity, B::AbstractMatrix{Tv}) where Tv
    nA = A.n
    mB, nB = size(B)
    rowval = Vector{Int}(nB*mB*nA)
    nzval = Vector{Tv}(nB*mB*nA)
    @inbounds for j in 1:nA
        r0 = (j-1)*mB
        @inbounds for j2 in 1:nB
            start = ((j-1)*nB+j2-1)*mB
            @inbounds @simd for i in 1:mB
                rowval[start+i] = r0+i
                nzval[start+i] = B[i,j2]
            end
        end
    end
    colptr = collect(1:mB:nB*mB*nA+1)
    SparseMatrixCSC(mB*nA, nA*nB, colptr, rowval, nzval)
end

kron (generic function with 16 methods)

In [17]:
include("../permmul.jl")
PermuteMultiply(A::Identity{T}) where T = PermuteMultiply(collect(1:A.n), ones(T, A.n))

PermuteMultiply

In [18]:
using BenchmarkTools
using ProfileView
id = Identity(1000)
target = randn(4,4)
@benchmark kron(id, target)

BenchmarkTools.Trial: 
  memory estimate:  281.63 KiB
  allocs estimate:  8
  --------------
  minimum time:     31.070 μs (0.00% GC)
  median time:      44.507 μs (0.00% GC)
  mean time:        50.842 μs (11.63% GC)
  maximum time:     703.356 μs (83.69% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [18]:
kron(A, id) == kron(A, sparse(id))

true

In [19]:
kron(id, A) == kron(sparse(id), A)

true

In [58]:
sid = PermuteMultiply(id)
@benchmark kron(sid, target)

BenchmarkTools.Trial: 
  memory estimate:  289.58 KiB
  allocs estimate:  10
  --------------
  minimum time:     59.136 μs (0.00% GC)
  median time:      61.897 μs (0.00% GC)
  mean time:        69.675 μs (9.58% GC)
  maximum time:     815.035 μs (83.09% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [21]:
@benchmark kron(target, id)

BenchmarkTools.Trial: 
  memory estimate:  282.13 KiB
  allocs estimate:  16
  --------------
  minimum time:     32.058 μs (0.00% GC)
  median time:      45.875 μs (0.00% GC)
  mean time:        64.422 μs (12.17% GC)
  maximum time:     1.994 ms (84.77% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [22]:
@benchmark kron(target, sid)

BenchmarkTools.Trial: 
  memory estimate:  289.58 KiB
  allocs estimate:  10
  --------------
  minimum time:     69.803 μs (0.00% GC)
  median time:      71.497 μs (0.00% GC)
  mean time:        89.201 μs (8.20% GC)
  maximum time:     1.801 ms (81.05% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [55]:
function kron(A::Identity, B::SparseMatrixCSC{T}) where T
    nA = A.n
    mB, nB = size(B)
    nV = nnz(B)
    nzval = Vector{T}(nA*nV)
    rowval = Vector{Int}(nA*nV)
    colptr = Vector{Int}(nB*nA+1)
    nzval = Vector{T}(nA*nV)
    colptr[1] = 1
    @inbounds for i = 1:nA
        r0 = (i-1)*mB
        start = nV*(i-1)
        @inbounds @simd for k = 1:nV
            rowval[start+k] = B.rowval[k]+r0
            nzval[start+k] = B.nzval[k]
        end
        colbase = (i-1)*nB
        @inbounds @simd for j=2:nB+1
            colptr[colbase+j] = B.colptr[j]+start
        end
    end
    SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval)
end
function kron(A::SparseMatrixCSC{T}, B::Identity) where T
    nB = B.n
    mA, nA = size(A)
    nV = nnz(A)
    rowval = Vector{Int}(nB*nV)
    colptr = Vector{Int}(nA*nB+1)
    nzval = Vector{T}(nB*nV)
    z=1
    colptr[1] = 1
    @inbounds for i in 1:nA
        rstart = A.colptr[i]
        rend = A.colptr[i+1]-1
        colbase = (i-1)*nB+1
        @inbounds for k in 1:nB
            irow_nB = k - nB
            @inbounds @simd for r in rstart:rend
                rowval[z] = A.rowval[r]*nB+irow_nB
                nzval[z] = A.nzval[r]
                z+=1
            end
            colptr[colbase+k] = z
        end
    end
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

kron (generic function with 30 methods)

In [20]:
sa = sprand(4,4,0.2)
Sa = sprand(16,16,0.2);

In [53]:
println(kron(id, sa) == kron(sparse(id), sa))

true


In [54]:
println(kron(sa, id) == kron(full(sa), full(id)))

true


In [None]:
kron(full(sa), id) == full(kron(sa, id))

In [None]:
@benchmark kron(Sa, id)

In [None]:
@benchmark kron(Sa, sid)

In [56]:
@benchmark kron(id, Sa)

BenchmarkTools.Trial: 
  memory estimate:  1.31 MiB
  allocs estimate:  9
  --------------
  minimum time:     127.899 μs (0.00% GC)
  median time:      149.976 μs (0.00% GC)
  mean time:        173.867 μs (13.62% GC)
  maximum time:     962.296 μs (76.84% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [59]:
@benchmark kron(sid, Sa)

BenchmarkTools.Trial: 
  memory estimate:  945.80 KiB
  allocs estimate:  9
  --------------
  minimum time:     142.348 μs (0.00% GC)
  median time:      165.237 μs (0.00% GC)
  mean time:        185.693 μs (10.77% GC)
  maximum time:     1.049 ms (74.70% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [138]:
p1 = Identity(4)
p2 = sprand(Complex128, 4,4, 0.5)
p3 = rand(Complex128, 4,4)
v = [0.5, 0.3im, 0.2, 1.0]
Dv = Diagonal(v)

4×4 Diagonal{Complex{Float64}}:
 0.5+0.0im      ⋅          ⋅          ⋅    
     ⋅      0.0+0.3im      ⋅          ⋅    
     ⋅          ⋅      0.2+0.0im      ⋅    
     ⋅          ⋅          ⋅      1.0+0.0im

In [139]:
for target in [p1, p2, p3, Dv]
    lres = p1*target
    rres = (target')*p1
    println(typeof(lres), typeof(rres))
    println(lres == target)
    println(rres == target')
    println(typeof(target), typeof(lres), typeof(rres))
    println(typeof(lres) == typeof(target) == typeof(rres))
end

Identity{Int64}Array{Int64,2}
true
true
Identity{Int64}Identity{Int64}Array{Int64,2}
false
SparseMatrixCSC{Complex{Float64},Int64}Array{Complex{Float64},2}
true
true
SparseMatrixCSC{Complex{Float64},Int64}SparseMatrixCSC{Complex{Float64},Int64}Array{Complex{Float64},2}
false
Array{Complex{Float64},2}Array{Complex{Float64},2}
true
true
Array{Complex{Float64},2}Array{Complex{Float64},2}Array{Complex{Float64},2}
true
Diagonal{Complex{Float64}}Array{Complex{Float64},2}
true
true
Diagonal{Complex{Float64}}Diagonal{Complex{Float64}}Array{Complex{Float64},2}
false


In [140]:
# see the definition of diagonal

In [141]:
kron(A::Identity, B::Diagonal{Tb}) where Tb = Diagonal{Tb}(repeat(B.diag, outer=A.n))
kron(B::Diagonal{Tb}, A::Identity) where Tb = Diagonal{Tb}(repeat(B.diag, inner=A.n))

kron (generic function with 34 methods)

In [142]:
d = Diagonal([1,2,3])
kron(d, p1)

12×12 Diagonal{Int64}:
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3

In [143]:
####### diagonal kron ########
function kron(A::StridedMatrix{Tv}, B::Diagonal{Tb}) where {Tv, Tb}
    mA, nA = size(A)
    nB = size(B, 1)
    C = zeros(promote_type(Tv, Tb), mA*nB, nA*nB)
    @inbounds for j = 1:nA
        for i = 1:mA
            val = A[i,j]
            @inbounds @simd for k = 1:nB
                C[(i-1)*nB+k, (j-1)*nB+k] = val*B.diag[k]  # merge
            end
        end
    end
    C
end

function kron(A::Diagonal{Ta}, B::StridedMatrix{Tv}) where {Tv, Ta}
    mB, nB = size(B)
    nA = size(A, 1)
    C = zeros(promote_type(Tv, Ta), mB*nA, nB*nA)
    @inbounds @simd for i = 1:nA
        C[(i-1)*mB+1:i*mB, (i-1)*nB+1:i*nB] = B*A.diag[i]
    end
    C
end

kron (generic function with 34 methods)

In [None]:
ds = rand(3,3)
println(kron(d, ds)==kron(full(d), ds))
println(kron(ds, d)==kron(ds, full(d)))

In [None]:
function kron(A::Diagonal, B::SparseMatrixCSC)
    nA = size(A, 1)
    mB, nB = size(B)
    nV = nnz(B)
    nzval = vcat([B.nzval*A.diag[i] for i in 1:nA]...)
    rowval = vcat([B.rowval+(i-1)*mB for i in 1:nA]...)
    colptr = vcat([B.colptr[(i==1?1:2):end]+(i-1)*nV for i in 1:nA]...)
    SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval)
end

function kron(A::SparseMatrixCSC{T}, B::Diagonal{Tb}) where {T, Tb}
    nB = size(B, 1)
    mA, nA = size(A)
    nV = nnz(A)
    rowval = Vector{Int}(nB*nV)
    nzval = Vector{promote_type(T, Tb)}(nB*nV)
    z=0
    @inbounds for i in 1:nA
        rstart = A.colptr[i]
        rend = A.colptr[i+1]-1
        row = A.rowval[rstart:rend]
        nzv = A.nzval[rstart:rend]
        nrow = length(row)
        @inbounds @simd for k in 1:nB
            rowval[z+1:z+nrow] = (row-1)*nB+k
            nzval[z+1:z+nrow] = nzv*B.diag[k]
            z+=nrow
        end
    end
    nl = diff(A.colptr)
    colptr = prepend!(cumsum(repeat(nl, inner=nB))+1, 1)
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

In [None]:
kron(d, p2) == kron(full(d), p2)

In [None]:
full(kron(p2, d)) == kron(full(p2), full(d))

In [None]:
kron(p2, d) == kron(full(p2), full(d))

In [119]:
function kron(A::Identity, B::PermuteMultiply{T}) where T
    nA = size(A, 1)
    nB = size(B, 1)
    perm = Vector{Int}(nB*nA)
    vals = Vector{T}(nB*nA)
    @inbounds for i = 1:nA
        start = (i-1)*nB
        @inbounds @simd for j = 1:nB
            perm[start+j] = start +B.perm[j]
            vals[start+j] = B.vals[j]
        end
    end
    PermuteMultiply(perm, vals)
end

kron (generic function with 31 methods)

In [120]:
@benchmark kron(id, sid)

BenchmarkTools.Trial: 
  memory estimate:  15.26 MiB
  allocs estimate:  5
  --------------
  minimum time:     2.492 ms (0.00% GC)
  median time:      3.080 ms (16.55% GC)
  mean time:        3.018 ms (14.05% GC)
  maximum time:     78.043 ms (95.65% GC)
  --------------
  samples:          1654
  evals/sample:     1

In [115]:
sparse(kron(id, sid)) == kron(sparse(id), sparse(sid))

true

In [130]:
function kron(A::PermuteMultiply{T}, B::Identity) where T
    nA = size(A, 1)
    nB = size(B, 1)
    vals = Vector{T}(nB*nA)
    perm = Vector{Int}(nB*nA)
    @inbounds for i = 1:nA
        start = (i-1)*nB
        permAi = (A.perm[i]-1)*nB
        val = A.vals[i]
        @inbounds @simd for j = 1:nB
            perm[start+j] = permAi + j
            vals[start+j] = val
        end
    end
    PermuteMultiply(perm, vals)
end

kron (generic function with 32 methods)

In [131]:
sparse(kron(sid, id)) == kron(sparse(sid), sparse(id))

true

In [132]:
@benchmark kron(sid, id)

BenchmarkTools.Trial: 
  memory estimate:  15.26 MiB
  allocs estimate:  5
  --------------
  minimum time:     2.494 ms (0.00% GC)
  median time:      3.096 ms (17.81% GC)
  mean time:        2.998 ms (14.42% GC)
  maximum time:     72.131 ms (96.21% GC)
  --------------
  samples:          1665
  evals/sample:     1

In [134]:
kron(A::Identity, B::Diagonal{Tb}) where Tb = Diagonal{Tb}(repeat(B.diag, outer=A.n))
kron(B::Diagonal{Tb}, A::Identity) where Tb = Diagonal{Tb}(repeat(B.diag, inner=A.n))

kron (generic function with 34 methods)

In [145]:
id = Identity(100)
Dv = Diagonal(randn(100))
@benchmark kron(id, Dv)

BenchmarkTools.Trial: 
  memory estimate:  85.22 KiB
  allocs estimate:  222
  --------------
  minimum time:     41.256 μs (0.00% GC)
  median time:      42.734 μs (0.00% GC)
  mean time:        52.742 μs (10.70% GC)
  maximum time:     2.078 ms (93.89% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [169]:
function irepv(v::AbstractArray{Tv}, n::Int) where Tv
    nV = length(v)
    res = Vector{Tv}(nV*n)
    @inbounds for j = 1:nV
        vj = v[j]
        base = (j-1)*n
        @inbounds @simd for i = 1:n
            res[base+i] = vj
        end
    end
    res
end
function orepv(v::AbstractArray{Tv}, n::Int) where Tv
    nV = length(v)
    res = Vector{Tv}(nV*n)
    @inbounds for i = 1:n
        base = (i-1)*nV
        @inbounds @simd for j = 1:nV
            res[base+j] = v[j]
        end
    end
    res
end
kron(A::Identity, B::Diagonal) = Diagonal(orepv(B.diag, A.n))
kron(B::Diagonal, A::Identity) = Diagonal(irepv(B.diag, A.n))

kron (generic function with 34 methods)

In [170]:
println(sparse(kron(id, Dv)) == kron(sparse(id), sparse(Dv)))
println(sparse(kron(Dv, id)) == kron(sparse(Dv), sparse(id)))

true
true


In [164]:
@benchmark kron($id, $Dv)

BenchmarkTools.Trial: 
  memory estimate:  78.22 KiB
  allocs estimate:  3
  --------------
  minimum time:     6.868 μs (0.00% GC)
  median time:      12.466 μs (0.00% GC)
  mean time:        15.213 μs (18.21% GC)
  maximum time:     330.777 μs (90.46% GC)
  --------------
  samples:          10000
  evals/sample:     3

In [171]:
@benchmark kron($Dv, $id)

BenchmarkTools.Trial: 
  memory estimate:  78.22 KiB
  allocs estimate:  3
  --------------
  minimum time:     7.563 μs (0.00% GC)
  median time:      12.879 μs (0.00% GC)
  mean time:        20.268 μs (22.63% GC)
  maximum time:     969.589 μs (91.32% GC)
  --------------
  samples:          10000
  evals/sample:     3