In [1]:
import Base: getindex, size, println
function println(x)
    show(STDOUT, "text/plain", x)
    print("\n")
end

println (generic function with 4 methods)

In [2]:
struct Identity{Tv} <: AbstractMatrix{Tv}
    n::Int
end
Identity(n::Int) = Identity{Int}(n)

size(A::Identity, i::Int) = A.n
size(A::Identity) = (A.n, A.n)
getindex(A::Identity{T}, i::Integer, j::Integer) where T = T(i==j)

getindex (generic function with 185 methods)

In [3]:
id = Identity{Float64}(4)

4×4 Identity{Float64}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0

In [4]:
import Base: sparse, full
sparse(A::Identity{T}) where T = speye(T, A.n)
full(A::Identity{T}) where T = eye(T, A.n)

full (generic function with 25 methods)

In [5]:
import Base: transpose, conj, copy, real, imag, ctranspose
for func in (:conj, :real, :imag, :ctranspose, :transpose, :copy)
    @eval ($func)(M::Identity{T}) where T = Identity{T}(M.n)
end

In [6]:
import Base: *, /, ==
*(A::Identity{T}, B::Number) where T = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
*(B::Number, A::Identity{T}) where T = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
/(A::Identity{T}, B::Number) where T = Diagonal(fill(promote_type(T, eltype(B))(1/B), A.n))
==(A::Identity, B::Identity) = A.n == B.n

== (generic function with 127 methods)

In [7]:
println(id*2im)
println(2*id)
println(id/3im)
println(id==id)

4×4 Diagonal{Complex{Float64}}:
 0.0+2.0im      ⋅          ⋅          ⋅    
     ⋅      0.0+2.0im      ⋅          ⋅    
     ⋅          ⋅      0.0+2.0im      ⋅    
     ⋅          ⋅          ⋅      0.0+2.0im
4×4 Diagonal{Float64}:
 2.0   ⋅    ⋅    ⋅ 
  ⋅   2.0   ⋅    ⋅ 
  ⋅    ⋅   2.0   ⋅ 
  ⋅    ⋅    ⋅   2.0
4×4 Diagonal{Complex{Float64}}:
 0.0-0.333333im      ⋅               ⋅               ⋅         
     ⋅           0.0-0.333333im      ⋅               ⋅         
     ⋅               ⋅           0.0-0.333333im      ⋅         
     ⋅               ⋅               ⋅           0.0-0.333333im
true


In [8]:
import Base: nnz, nonzeros, inv
nnz(M::Identity) = M.n
nonzeros(M::Identity{T}) where T = ones(T, M.n)
inv(M::Identity) = M

inv (generic function with 28 methods)

In [9]:
println(nnz(id))
println(nonzeros(id))
println(inv(id))

4
4-element Array{Float64,1}:
 1.0
 1.0
 1.0
 1.0
4×4 Identity{Float64}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0


In [26]:
####### multiply ###########
Mats = Union{SparseMatrixCSC, StridedVecOrMat}
function (*)(A::Identity, B::Mats)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    B
end
function (*)(A::Mats, B::Identity)
    size(A, ndims(A)) == size(B, 1) || throw(DimensionMismatch())
    A
end
#### why can not take a union here?
function (*)(A::Identity, B::Diagonal)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    B
end
function (*)(A::Diagonal, B::Identity)
    size(A, ndims(A)) == size(B, 1) || throw(DimensionMismatch())
    A
end
function (*)(A::Identity, B::Identity)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    B
end

* (generic function with 191 methods)

In [11]:
typeintersect(Identity, AbstractArray)

Identity

In [12]:
mat = rand(5,4)
id*mat'

4×5 Array{Float64,2}:
 0.0200548  0.280313  0.119009  0.7276    0.707122
 0.412546   0.994307  0.981615  0.966783  0.679972
 0.732867   0.450539  0.276721  0.571208  0.211329
 0.376325   0.993609  0.662932  0.830309  0.234954

In [13]:
function setdiag!(A::AbstractMatrix, B::Number)
    n = size(A, 1)
    A[1:n+1:end] = B
    A
end
function adddiag!(A::AbstractMatrix, B::Number)
    n = size(A, 1)
    A[1:n+1:end] += B
    A
end

adddiag! (generic function with 1 method)

In [14]:
A = rand(4,4)
offsetdiag!(A, 3)
A

LoadError: [91mUndefVarError: offsetdiag! not defined[39m

In [15]:
import Base: kron

kron(A::Identity{Ta}, B::Identity{Tb}) where {Ta, Tb}= Identity{promote_type(Ta, Tb)}(A.n*B.n)

function kron(A::AbstractMatrix{Tv}, B::Identity) where Tv
    mA, nA = size(A)
    nB = B.n
    C = zeros(Tv, mA*nB, nA*nB)
    @inbounds for j = 1:nA
        for i = 1:mA
            val = A[i,j]
            @inbounds @simd for k = 1:nB
                C[(i-1)*nB+k, (j-1)*nB+k] = val
            end
        end
    end
    C
end
function kron(A::Identity, B::AbstractMatrix{Tv}) where Tv
    mB, nB = size(B)
    nA = A.n
    C = zeros(Tv, mB*nA, nB*nA)
    @inbounds @simd for i = 1:nA
        C[(i-1)*mB+1:i*mB, (i-1)*nB+1:i*nB] = B
    end
    C
end

kron (generic function with 16 methods)

In [16]:
using BenchmarkTools
@benchmark kron(A, id) == kron(A, full(id))

BenchmarkTools.Trial: 
  memory estimate:  4.45 KiB
  allocs estimate:  3
  --------------
  minimum time:     1.611 μs (0.00% GC)
  median time:      1.750 μs (0.00% GC)
  mean time:        2.069 μs (12.29% GC)
  maximum time:     184.557 μs (93.45% GC)
  --------------
  samples:          10000
  evals/sample:     10

In [17]:
@benchmark kron(id, A) == kron(full(id), A)

BenchmarkTools.Trial: 
  memory estimate:  4.70 KiB
  allocs estimate:  11
  --------------
  minimum time:     1.785 μs (0.00% GC)
  median time:      1.911 μs (0.00% GC)
  mean time:        2.221 μs (12.19% GC)
  maximum time:     177.861 μs (93.35% GC)
  --------------
  samples:          10000
  evals/sample:     10

In [129]:
function kron(A::Identity, B::SparseMatrixCSC)
    nA = A.n
    mB, nB = size(B)
    nV = nnz(B)
    nzval = repeat(B.nzval, outer=nA)
    rowval = vcat([B.rowval+(i-1)*mB for i in 1:nA]...)
    colptr = vcat([B.colptr[(i==1?1:2):end]+(i-1)*nV for i in 1:nA]...)
    SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval)
end
function kron(A::SparseMatrixCSC{T}, B::Identity) where T
    nB = B.n
    mA, nA = size(A)
    nV = nnz(A)
    #rowval = reshape(hcat([A.rowval+(i-1)*nB for i in 1:nB]...), nB*nV)
    #rowval = vcat([(r-1)*nB+1:r*nB for r in A.rowval]...)
    rowval = Vector{Int}(nB*nV)
    nzval = Vector{T}(nB*nV)
    z=0
    @inbounds for i in 1:nA
        rstart = A.colptr[i]
        rend = A.colptr[i+1]-1
        row = A.rowval[rstart:rend]
        nzv = A.nzval[rstart:rend]
        nrow = length(row)
        @inbounds @simd for k in 1:nB
            rowval[z+1:z+nrow] = (row-1)*nB+k
            nzval[z+1:z+nrow] = nzv
            z+=nrow
        end
    end
    nl = diff(A.colptr)
    colptr = prepend!(cumsum(repeat(nl, inner=nB))+1, 1)
    #println(nzval)
    #println(rowval)
    #println(colptr)
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

kron (generic function with 24 methods)

In [130]:
sa = sprand(4,4,0.2)
println(sa.colptr)
println(sa.rowval)

5-element Array{Int64,1}:
 1
 1
 1
 2
 2
1-element Array{Int64,1}:
 2


In [131]:
println(kron(id, sa) == kron(id, full(sa)))

true


In [132]:
println(full(kron(sa, id))[1:4,:])

4×16 Array{Float64,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0


In [133]:
kron(full(sa), id) == full(kron(sa, id))

true

In [134]:
p1 = Identity(4)
p2 = sprand(Complex128, 4,4, 0.5)
p3 = rand(Complex128, 4,4)
v = [0.5, 0.3im, 0.2, 1.0]
Dv = Diagonal(v)

4×4 Diagonal{Complex{Float64}}:
 0.5+0.0im      ⋅          ⋅          ⋅    
     ⋅      0.0+0.3im      ⋅          ⋅    
     ⋅          ⋅      0.2+0.0im      ⋅    
     ⋅          ⋅          ⋅      1.0+0.0im

In [135]:
for target in [p1, p2, p3, Dv]
    lres = p1*target
    rres = (target')*p1
    println(typeof(lres), typeof(rres))
    println(lres == target)
    println(rres == target')
    println(typeof(target), typeof(lres), typeof(rres))
    println(typeof(lres) == typeof(target) == typeof(rres))
end

Identity{Int64}Identity{Int64}
true
true
Identity{Int64}Identity{Int64}Identity{Int64}
true
SparseMatrixCSC{Complex{Float64},Int64}SparseMatrixCSC{Complex{Float64},Int64}
true
false
SparseMatrixCSC{Complex{Float64},Int64}SparseMatrixCSC{Complex{Float64},Int64}SparseMatrixCSC{Complex{Float64},Int64}
true
Array{Complex{Float64},2}Array{Complex{Float64},2}
true
false
Array{Complex{Float64},2}Array{Complex{Float64},2}Array{Complex{Float64},2}
true
Diagonal{Complex{Float64}}Diagonal{Complex{Float64}}
true
false
Diagonal{Complex{Float64}}Diagonal{Complex{Float64}}Diagonal{Complex{Float64}}
true


In [136]:
# see the definition of diagonal

In [137]:
kron(A::Identity, B::Diagonal{Tb}) where Tb = Diagonal{Tb}(repeat(B.diag, outer=A.n))
kron(B::Diagonal{Tb}, A::Identity) where Tb = Diagonal{Tb}(repeat(B.diag, inner=A.n))

kron (generic function with 24 methods)

In [138]:
d = Diagonal([1,2,3])
kron(d, p1)

12×12 Diagonal{Int64}:
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  2  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  3

In [139]:
####### diagonal kron ########
function kron(A::StridedMatrix{Tv}, B::Diagonal{Tb}) where {Tv, Tb}
    mA, nA = size(A)
    nB = size(B, 1)
    C = zeros(promote_type(Tv, Tb), mA*nB, nA*nB)
    @inbounds for j = 1:nA
        for i = 1:mA
            val = A[i,j]
            @inbounds @simd for k = 1:nB
                C[(i-1)*nB+k, (j-1)*nB+k] = val*B.diag[k]  # merge
            end
        end
    end
    C
end

function kron(A::Diagonal{Ta}, B::StridedMatrix{Tv}) where {Tv, Ta}
    mB, nB = size(B)
    nA = size(A, 1)
    C = zeros(promote_type(Tv, Ta), mB*nA, nB*nA)
    @inbounds @simd for i = 1:nA
        C[(i-1)*mB+1:i*mB, (i-1)*nB+1:i*nB] = B*A.diag[i]
    end
    C
end

kron (generic function with 24 methods)

In [140]:
ds = rand(3,3)
println(kron(d, ds)==kron(full(d), ds))
println(kron(ds, d)==kron(ds, full(d)))

true
true


In [141]:
function kron(A::Diagonal, B::SparseMatrixCSC)
    nA = size(A, 1)
    mB, nB = size(B)
    nV = nnz(B)
    nzval = vcat([B.nzval*A.diag[i] for i in 1:nA]...)
    rowval = vcat([B.rowval+(i-1)*mB for i in 1:nA]...)
    colptr = vcat([B.colptr[(i==1?1:2):end]+(i-1)*nV for i in 1:nA]...)
    SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval)
end

function kron(A::SparseMatrixCSC{T}, B::Diagonal{Tb}) where {T, Tb}
    nB = size(B, 1)
    mA, nA = size(A)
    nV = nnz(A)
    rowval = Vector{Int}(nB*nV)
    nzval = Vector{promote_type(T, Tb)}(nB*nV)
    z=0
    @inbounds for i in 1:nA
        rstart = A.colptr[i]
        rend = A.colptr[i+1]-1
        row = A.rowval[rstart:rend]
        nzv = A.nzval[rstart:rend]
        nrow = length(row)
        @inbounds @simd for k in 1:nB
            rowval[z+1:z+nrow] = (row-1)*nB+k
            nzval[z+1:z+nrow] = nzv*B.diag[k]
            z+=nrow
        end
    end
    nl = diff(A.colptr)
    colptr = prepend!(cumsum(repeat(nl, inner=nB))+1, 1)
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

kron (generic function with 24 methods)

In [142]:
kron(d, p2) == kron(full(d), p2)

true

In [143]:
full(kron(p2, d)) == kron(full(p2), full(d))

true

In [144]:
kron(p2, d) == kron(full(p2), full(d))

true