In [1]:
using BenchmarkTools
struct PermuteMultiply{Tv, Ti<:Integer} <: AbstractMatrix{Tv}
    perm::Vector{Ti}   # new orders
    vals::Vector{Tv}  # multiplied values.

    function PermuteMultiply{Tv, Ti}(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti<:Integer}
        if length(perm) != length(vals)
            throw(DimensionMismatch("permutation ($(length(perm))) and multiply ($(length(vals))) length mismatch."))
        end
        new{Tv, Ti}(perm, vals)
    end
end
function PermuteMultiply(perm::Vector, vals::Vector)
    Tv = eltype(vals)
    Ti = eltype(perm)
    PermuteMultiply{Tv,Ti}(perm, vals)
end

import Base: size, show, eltype, getindex, full
size(M::PermuteMultiply) = (length(M.perm), length(M.perm))
function size(A::PermuteMultiply, d::Integer)
    if d < 1
        throw(ArgumentError("dimension must be ≥ 1, got $d"))
    elseif d<=2
        return length(A.perm)
    else
        return 1
    end
end
getindex(M::PermuteMultiply, i::Integer, j::Integer) = M.perm[i] == j ? M.vals[i] : 0

getindex (generic function with 191 methods)

In [2]:
p1 = PermuteMultiply([2,1,4,3],[0.1, 0.2, 0.4im, 0.5])
p2 = PermuteMultiply([2,1,4,3],[0.1, 0.2im, 0.4, 0.5])
p3 = PermuteMultiply([4,1,2,3],[0.5, 0.4im, 0.3, 0.2])
println(p1)
println(p2)
println(size(p1)," " ,size(p1,1))
println(isapprox(p1, p2))

Complex{Float64}[0 0.1+0.0im 0 0; 0.2+0.0im 0 0 0; 0 0 0 0.0+0.4im; 0 0 0.5+0.0im 0]
Complex{Float64}[0 0.1+0.0im 0 0; 0.0+0.2im 0 0 0; 0 0 0 0.4+0.0im; 0 0 0.5+0.0im 0]
(4, 4) 4
false


In [3]:
println(eltype(p1))

Complex{Float64}


In [4]:
@benchmark full(p1)
function Matrix{T}(M::PermuteMultiply) where T
    n = size(M, 1)
    Mf = zeros(T, n, n)
    @inbounds for i=1:n
        Mf[i, M.perm[i]] = M.vals[i]
    end
    return Mf
end
Matrix(M::PermuteMultiply{T}) where {T} = Matrix{T}(M)
Array(M::PermuteMultiply) = Matrix(M)
full(M::PermuteMultiply) = Matrix(M)

print(full(p1))

Complex{Float64}[0.0+0.0im 0.1+0.0im 0.0+0.0im 0.0+0.0im; 0.2+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.4im; 0.0+0.0im 0.0+0.0im 0.5+0.0im 0.0+0.0im]

In [5]:
function Matrix{T}(M::PermuteMultiply) where T
    n = size(M, 1)
    Mf = zeros(T, n, n)
    @inbounds for i=1:n
        Mf[i, M.perm[i]] = M.vals[i]
    end
    return Mf
end
Matrix(M::PermuteMultiply{T}) where {T} = Matrix{T}(M)
Array(M::PermuteMultiply) = Matrix(M)

print(full(p1))
print(sparse(p1))

Complex{Float64}[0.0+0.0im 0.1+0.0im 0.0+0.0im 0.0+0.0im; 0.2+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.4im; 0.0+0.0im 0.0+0.0im 0.5+0.0im 0.0+0.0im]
  [2, 1]  =  0.2+0.0im
  [1, 2]  =  0.1+0.0im
  [4, 3]  =  0.5+0.0im
  [3, 4]  =  0.0+0.4im

In [6]:
import Base: sparse, kron
function sparse(M::PermuteMultiply{T}) where {T}
    n = size(M, 1)
    sparse(collect(1:n), M.perm, M.vals, n, n)
end

sparse (generic function with 23 methods)

In [7]:
print(sparse(p2))


  [2, 1]  =  0.0+0.2im
  [1, 2]  =  0.1+0.0im
  [4, 3]  =  0.5+0.0im
  [3, 4]  =  0.4+0.0im

In [8]:
import Base: show
function show(io::IO, M::PermuteMultiply)
    println("PermuteMultiply")
    for item in zip(M.perm, M.vals)
        i, p = item
        println("- ($i) * $p")
    end
end

println(p1)

PermuteMultiply
- (2) * 0.1 + 0.0im
- (1) * 0.2 + 0.0im
- (4) * 0.0 + 0.4im
- (3) * 0.5 + 0.0im



In [9]:
import Base: *, /, ==, copy, conj, real, imag
for func in (:conj, :real, :imag)
    @eval ($func)(M::PermuteMultiply) = PermuteMultiply(M.perm, ($func)(M.vals))
end
copy(M::PermuteMultiply) = PermuteMultiply(copy(M.perm), copy(M.vals))


import Base: transpose
function transpose(M::PermuteMultiply)
    new_perm = sortperm(M.perm)
    return PermuteMultiply(new_perm, M.vals[new_perm])
end


adjoint(S::PermuteMultiply{<:Real}) = transpose(S)
adjoint(S::PermuteMultiply{<:Complex}) = conj(transpose(S))

adjoint (generic function with 2 methods)

In [10]:
println(conj(p1)-conj(full(p1)))
println(adjoint(p1)-transpose(conj(full(p1))))
println(real(p1)-real(full(p1)))
println(imag(p1)-imag(full(p1)))

Complex{Float64}[0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im]
Complex{Float64}[0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im]
[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]


In [11]:
println(p1*transpose(p1))
println(p1*2)
println(p1/2)
println(p1==p2)

Complex{Float64}[0.01+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.04+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im -0.16+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.25+0.0im]
Complex{Float64}[0.0+0.0im 0.2+0.0im 0.0+0.0im 0.0+0.0im; 0.4+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.8im; 0.0+0.0im 0.0+0.0im 1.0+0.0im 0.0+0.0im]
Complex{Float64}[0.0+0.0im 0.05+0.0im 0.0+0.0im 0.0+0.0im; 0.1+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.2im; 0.0+0.0im 0.0+0.0im 0.25+0.0im 0.0+0.0im]
false


In [12]:
*(A::PermuteMultiply, B::Number) = PermuteMultiply(A.perm, A.vals*B)
*(B::Number, A::PermuteMultiply) = A*B
/(A::PermuteMultiply, B::Number) = PermuteMultiply(A.perm, A.vals/B)
==(A::PermuteMultiply, B::PermuteMultiply) = (A.perm==B.perm) && (A.vals==B.vals)

== (generic function with 133 methods)

In [13]:
println(p1*transpose(p1))
println(p1*2)
println(p1/2)
println(p1==p2)

Complex{Float64}[0.01+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.04+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im -0.16+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.25+0.0im]
PermuteMultiply
- (2) * 0.2 + 0.0im
- (1) * 0.4 + 0.0im
- (4) * 0.0 + 0.8im
- (3) * 1.0 + 0.0im

PermuteMultiply
- (2) * 0.05 + 0.0im
- (1) * 0.1 + 0.0im
- (4) * 0.0 + 0.2im
- (3) * 0.25 + 0.0im

false


In [14]:
import Base: nnz, nonzeros, inv
nnz(M::PermuteMultiply) = length(M.vals)
nonzeros(M::PermuteMultiply) = M.vals

function inv(M::PermuteMultiply)
    new_perm = sortperm(M.perm)
    return PermuteMultiply(new_perm, 1.0 ./ M.vals[new_perm])
end

inv (generic function with 28 methods)

In [15]:
println(nnz(p1))
println(inv(p1) * p1)

4
Complex{Float64}[1.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 1.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 1.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 1.0+0.0im]


In [16]:
function (*)(A::PermuteMultiply, X::AbstractVector)
    length(X) == size(A, 2) || throw(DimensionMismatch())
    return A.vals .* X[A.perm]
end

function (*)(X::AbstractVector, A::PermuteMultiply)
    length(X) == size(A, 1) || throw(DimensionMismatch())
    return (A.vals .* X)[sortperm(A.perm)]
end

function (*)(A::PermuteMultiply, X::AbstractMatrix)
    size(X, 1) == size(A, 2) || throw(DimensionMismatch())
    return A.vals .* X[A.perm, :]   # this may be inefficient for sparse CSC matrix.
end

function (*)(X::AbstractMatrix, A::PermuteMultiply)
    size(X, 2) == size(A, 1) || throw(DimensionMismatch())
    return (A.vals' .* X)[:, sortperm(A.perm)] # how can we lazy evaluate and cache this sort order?
end
function (*)(D::Diagonal, A::PermuteMultiply)
    T = Base.promote_op(*, eltype(D), eltype(A))
    B = copy(A)
    B.vals *= D.diag
    return B
end

function (*)(A::PermuteMultiply, D::Diagonal)
    T = Base.promote_op(*, eltype(D), eltype(A))
    B = copy(A)
    B.vals *= D.diag[B.perm]
    return B
end

function (*)(A::PermuteMultiply, B::PermuteMultiply)
    size(A, 1) == size(B, 1) || throw(DimensionMismatch())
    PermuteMultiply(B.perm[A.perm], A.vals.*B.vals[A.perm])
end

* (generic function with 191 methods)

In [17]:
println(full(p3*p2))
println(full(sparse(p3)*p2))
println(full(p3)*p2)

v = [0.5, 0.3, 0.2, 1.0]
println(p3*v)
println(full(p3)*v)

println(v'*p3)
println(v'*full(p3))

Dv = diagm(v)
println(Dv*p3)
println(Dv*full(p3))
println(p3*Dv)
println(full(p3)*Dv)

Complex{Float64}[0.0+0.0im 0.0+0.0im 0.25+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.04im 0.0+0.0im 0.0+0.0im; 0.0+0.06im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.08+0.0im]
Complex{Float64}[0.0+0.0im 0.0+0.0im 0.25+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.04im 0.0+0.0im 0.0+0.0im; 0.0-0.06im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.08+0.0im]
Complex{Float64}[0.0+0.0im 0.0+0.0im 0.25+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.04im 0.0+0.0im 0.0+0.0im; 0.0-0.06im 0.0+0.0im 0.0+0.0im 0.0+0.0im; 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.08+0.0im]
Complex{Float64}[0.5+0.0im, 0.0+0.2im, 0.09+0.0im, 0.04+0.0im]
Complex{Float64}[0.5+0.0im, 0.0+0.2im, 0.09+0.0im, 0.04+0.0im]
Complex{Float64}[0.0+0.12im 0.06+0.0im 0.2+0.0im 0.25-0.0im]
Complex{Float64}[0.0+0.12im 0.06-0.0im 0.2-0.0im 0.25-0.0im]
Complex{Float64}[0.0-0.0im 0.0-0.0im 0.0-0.0im 0.25-0.0im; 0.0-0.12im 0.0-0.0im 0.0-0.0im 0.0-0.0im; 0.0-0.0im 0.06-0.0im 0.0-0.0im 0.0-0.0im; 0.0-0.0im 0.0-0.0im 0.2-0.0im 0.0-0.0im]
Complex{Float

In [18]:
p3==full(p3)==sparse(p3)

true

In [19]:
function kron(A::PermuteMultiply{Ta}, B::PermuteMultiply{Tb}) where {Ta, Tb}
    nA = size(A, 1)
    nB = size(B, 1)
    Tc = promote_type(Ta, Tb)
    vals = kron(A.vals, B.vals)
    perm = Vector{Int}(nB*nA)
    permA = (A.perm.-1)*nB
    @inbounds for i = 1:nA
        #perm[(i-1)*nB+1:i*nB] = permA[i] +B.perm
        start = (i-1)*nB
        permAi = permA[i]
        @simd for j = 1:nB
            perm[start+j] = permAi +B.perm[j]
        end
    end
    PermuteMultiply(perm, vals)
end

kron (generic function with 14 methods)

In [20]:
Dv = Diagonal(randn(1<<8))
Pm = PermuteMultiply(randperm(1<<8), randn(1<<8));

In [21]:
sparse(kron(Pm, Pm)) == kron(sparse(Pm), sparse(Pm))

true

In [22]:
sPm = sparse(Pm)
@benchmark kron($Pm, $Pm)

BenchmarkTools.Trial: 
  memory estimate:  1.00 MiB
  allocs estimate:  13
  --------------
  minimum time:     153.579 μs (0.00% GC)
  median time:      228.585 μs (0.00% GC)
  mean time:        269.033 μs (11.00% GC)
  maximum time:     1.754 ms (71.18% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [23]:
kron(A::PermuteMultiply, B::Diagonal) = kron(A, PermuteMultiply(B))
kron(A::Diagonal, B::PermuteMultiply) = kron(PermuteMultiply(A), B)
PermuteMultiply(dmat::Diagonal) = PermuteMultiply(collect(1:size(dmat, 1)), dmat.diag)

PermuteMultiply

In [24]:
println(kron(sparse(Pm), sparse(Dv)) == sparse(kron(Pm, Dv)))
println(sparse(kron(Pm, Dv)) == kron(sparse(Pm), sparse(Dv)))
println(sparse(kron(Dv, Pm)) == kron(sparse(Dv), sparse(Pm)))

true
true
true


In [25]:
kron(A::PermuteMultiply, B::SparseMatrixCSC) = kron(sparse(A), B)
kron(A::SparseMatrixCSC, B::PermuteMultiply) = kron(A, sparse(B))

kron (generic function with 18 methods)

In [26]:
function kron(A::PermuteMultiply{Ta}, B::SparseMatrixCSC{Tb}) where {Ta, Tb}
    nA = size(A, 1)
    mB, nB = size(B)
    nV = nnz(B)
    perm = sortperm(A.perm)
    #nzval = kron(A.vals[perm], B.nzval)
    nzval = Vector{promote_type(Ta, Tb)}(nA*nV)
    rowval = Vector{Int}(nA*nV)
    colptr = Vector{Int}(nA*nB+1)
    colptr[1] = 1
    @inbounds @simd for i in 1:nA
        start_row = (i-1)*nV
        start_ri = (perm[i]-1)*mB
        v0 = A.vals[perm[i]]
        for j = 1:nV
            nzval[start_row+j] = B.nzval[j]*v0
            rowval[start_row+j] = B.rowval[j] + start_ri
        end
        start_col = (i-1)*nB+1
        start_ci = (i-1)*nV
        for j = 1:nB
            colptr[start_col+j] = B.colptr[j+1] + start_ci
        end
        #rowval[(i-1)*nV+1:i*nV] = B.rowval+(perm[i]-1)*mB
        #colptr[(i-1)*nB+1:i*nB+1] = B.colptr+(i-1)*nV
    end
    SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval)
end

kron (generic function with 18 methods)

In [27]:
kron(p1, sp) == kron(sparse(p1), sp)

LoadError: [91mUndefVarError: sp not defined[39m

In [28]:
Sp = sprand(1000, 1000, 0.01);
sp = sprand(4, 4, 0.3);

In [29]:
@benchmark kron(Pm, Sp)

BenchmarkTools.Trial: 
  memory estimate:  41.46 MiB
  allocs estimate:  9
  --------------
  minimum time:     7.008 ms (0.00% GC)
  median time:      7.625 ms (6.77% GC)
  mean time:        9.420 ms (10.76% GC)
  maximum time:     81.228 ms (85.35% GC)
  --------------
  samples:          530
  evals/sample:     1

In [43]:
function kron(A::SparseMatrixCSC{T}, B::PermuteMultiply{Tb}) where {T, Tb}
    nB = size(B, 1)
    mA, nA = size(A)
    nV = nnz(A)
    perm = sortperm(B.perm)
    rowval = Vector{Int}(nB*nV)
    colptr = Vector{Int}(nA*nB+1)
    nzval = Vector{promote_type(T, Tb)}(nB*nV)
    z=1
    colptr[z] = 1
    @inbounds for i in 1:nA
        rstart = A.colptr[i]
        rend = A.colptr[i+1]-1
        @simd for k in 1:nB
            irow = perm[k]
            for r in rstart:rend
                rowval[z] = (A.rowval[r]-1)*nB+irow
                nzval[z] = A.nzval[r]*B.vals[irow]
                z+=1
            end
            colptr[(i-1)*nB+k+1] = z
        end
    end
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

kron (generic function with 18 methods)

In [41]:
kron(sp, p1) == kron(sp, sparse(p1))

true

In [44]:
@benchmark kron(Sp, Pm)

BenchmarkTools.Trial: 
  memory estimate:  41.46 MiB
  allocs estimate:  9
  --------------
  minimum time:     7.924 ms (0.00% GC)
  median time:      15.129 ms (3.72% GC)
  mean time:        13.032 ms (11.79% GC)
  maximum time:     86.256 ms (82.21% GC)
  --------------
  samples:          384
  evals/sample:     1

In [33]:
sPm = sparse(Pm)
@benchmark kron(Sp, sPm)

BenchmarkTools.Trial: 
  memory estimate:  41.45 MiB
  allocs estimate:  7
  --------------
  minimum time:     46.951 ms (1.13% GC)
  median time:      54.182 ms (1.24% GC)
  mean time:        57.958 ms (8.62% GC)
  maximum time:     199.432 ms (64.27% GC)
  --------------
  samples:          87
  evals/sample:     1