In [239]:
using BenchmarkTools
struct PermuteMultiply{Tv, Ti<:Integer} <: AbstractMatrix{Tv}
    perm::Vector{Ti}   # new orders
    vals::Vector{Tv}  # multiplied values.

    function PermuteMultiply{Tv, Ti}(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti<:Integer}
        if length(perm) != length(vals)
            throw(DimensionMismatch("permutation ($(length(perm))) and multiply ($(length(vals))) length mismatch."))
        end
        new{Tv, Ti}(perm, vals)
    end
end

function invperm(order)
    v = similar(order)
    @inbounds @simd for i=1:length(order)
        v[order[i]] = i
    end
    v
end

function PermuteMultiply(perm::Vector, vals::Vector)
    Tv = eltype(vals)
    Ti = eltype(perm)
    PermuteMultiply{Tv,Ti}(perm, vals)
end

import Base: size, show, eltype, getindex, full
size(M::PermuteMultiply) = (length(M.perm), length(M.perm))
function size(A::PermuteMultiply, d::Integer)
    if d < 1
        throw(ArgumentError("dimension must be ≥ 1, got $d"))
    elseif d<=2
        return length(A.perm)
    else
        return 1
    end
end
getindex(M::PermuteMultiply, i::Integer, j::Integer) = M.perm[i] == j ? M.vals[i] : 0

LoadError: [91merror in method definition: function Base.invperm must be explicitly imported to be extended[39m

In [240]:
@benchmark full(p1)
function Matrix{T}(M::PermuteMultiply) where T
    n = size(M, 1)
    Mf = zeros(T, n, n)
    @inbounds for i=1:n
        Mf[i, M.perm[i]] = M.vals[i]
    end
    return Mf
end
Matrix(M::PermuteMultiply{T}) where {T} = Matrix{T}(M)
Array(M::PermuteMultiply) = Matrix(M)
full(M::PermuteMultiply) = Matrix(M)

full (generic function with 25 methods)

In [241]:
import Base: show
function show(io::IO, M::PermuteMultiply)
    println("PermuteMultiply")
    for item in zip(M.perm, M.vals)
        i, p = item
        println("- ($i) * $p")
    end
end

println(p1)

PermuteMultiply
- (2) * 0.1 + 0.0im
- (1) * 0.2 + 0.0im
- (4) * 0.0 + 0.4im
- (3) * 0.5 + 0.0im



In [435]:
import Base: randn
randn(T::Type{Complex{F}}, n::Int...) where F = randn(F, n...) + im*randn(F, n...)
pmrand(T::Type, n::Int) = PermuteMultiply(randperm(n), randn(T, n))
pmrand(n::Int) = PermuteMultiply(randperm(n), randn(n))

pmrand (generic function with 2 methods)

In [437]:
p1 = PermuteMultiply([2,1,4,3],[0.1, 0.2, 0.4im, 0.5])
p2 = PermuteMultiply([2,1,4,3],[0.1, 0.2im, 0.4, 0.5])
p3 = pmrand(Complex128, 4)
sp = sprand(4,4,0.3)
ds = randn(Complex128, 4,4)
println(p1)
println(p3)
println(ds)
fp1, fp2, fp3 = full(p1), full(p2), full(p3)
sp1, sp2, sp3 = sparse(p1), sparse(p2), sparse(p3)
println(fp1 == p1 == sp1)
println(size(p1)==size(fp1) && size(p1,1) == size(fp1,1))
println(!isapprox(p1, p2))

PermuteMultiply
- (2) * 0.1 + 0.0im
- (1) * 0.2 + 0.0im
- (4) * 0.0 + 0.4im
- (3) * 0.5 + 0.0im

PermuteMultiply
- (1) * 0.274625711759987 + 0.8443762778750868im
- (3) * -1.0505635217492233 + 2.0240720084886794im
- (2) * 0.5371287098449212 - 0.5238685022371369im
- (4) * -0.17687353184013457 + 0.3817937452006296im

Complex{Float64}[0.494975-0.424953im -0.100426+1.21454im 0.866+0.0614908im -0.527156-0.383367im; 1.25135+0.321035im -2.0168-0.393942im -0.360625+0.431917im 0.102813+0.772808im; 2.05566+0.532734im 0.900861+1.30656im 1.73243+0.123618im 1.00657+0.523671im; 0.390475+0.305632im -0.926132+0.753056im 1.74789+0.191208im 0.0461557+0.758886im]
true
true
true


In [244]:
println(eltype(p1))

Complex{Float64}


In [245]:
function Matrix{T}(M::PermuteMultiply) where T
    n = size(M, 1)
    Mf = zeros(T, n, n)
    @inbounds for i=1:n
        Mf[i, M.perm[i]] = M.vals[i]
    end
    return Mf
end
Matrix(M::PermuteMultiply{T}) where {T} = Matrix{T}(M)
Array(M::PermuteMultiply) = Matrix(M)

Array (generic function with 1 method)

In [246]:
import Base: sparse, kron
function sparse(M::PermuteMultiply{T}) where {T}
    n = size(M, 1)
    sparse(collect(1:n), M.perm, M.vals, n, n)
end

sparse (generic function with 23 methods)

In [247]:
import Base: *, /, ==, copy, conj, real, imag
for func in (:conj, :real, :imag)
    @eval ($func)(M::PermuteMultiply) = PermuteMultiply(M.perm, ($func)(M.vals))
end
copy(M::PermuteMultiply) = PermuteMultiply(copy(M.perm), copy(M.vals))


import Base: transpose
function transpose(M::PermuteMultiply)
    new_perm = invperm(M.perm)
    return PermuteMultiply(new_perm, M.vals[new_perm])
end


adjoint(S::PermuteMultiply{<:Real}) = transpose(S)
adjoint(S::PermuteMultiply{<:Complex}) = conj(transpose(S))

adjoint (generic function with 2 methods)

In [248]:
println(conj(p1)==conj(fp1))
println(adjoint(p1)==transpose(conj(fp1)))
println(real(p1)==real(fp1))
println(imag(p1)==imag(fp1))

true
true
true
true


In [249]:
println(p1*transpose(p1)==fp1*transpose(fp1))
println(p1*2==fp1*2)
println(p1/2==fp1/2)

true
true
true


In [250]:
*(A::PermuteMultiply, B::Number) = PermuteMultiply(A.perm, A.vals*B)
*(B::Number, A::PermuteMultiply) = A*B
/(A::PermuteMultiply, B::Number) = PermuteMultiply(A.perm, A.vals/B)
==(A::PermuteMultiply, B::PermuteMultiply) = (A.perm==B.perm) && (A.vals==B.vals)

== (generic function with 133 methods)

In [251]:
import Base: nnz, nonzeros, inv
nnz(M::PermuteMultiply) = length(M.vals)
nonzeros(M::PermuteMultiply) = M.vals

function inv(M::PermuteMultiply)
    new_perm = invperm(M.perm)
    return PermuteMultiply(new_perm, 1.0 ./ M.vals[new_perm])
end

inv (generic function with 28 methods)

In [252]:
println(nnz(p1)==4)
println(inv(p1) * p1==eye(4))

true
true


In [418]:
function (*)(A::PermuteMultiply{Ta}, X::AbstractVector{Tx}) where {Ta, Tx}
    nX = length(X)
    nX == size(A, 2) || throw(DimensionMismatch())
    v = similar(X, promote_type(Ta, Tx))
    @simd for i = 1:nX
        @inbounds v[i] = A.vals[i]*X[A.perm[i]]
    end
    v
end

function (*)(X::RowVector{Tx}, A::PermuteMultiply{Ta}) where {Tx, Ta}
    nX = length(X)
    nX == size(A, 1) || throw(DimensionMismatch())
    v = similar(X, promote_type(Tx, Ta))
    @simd for i = 1:nX
        @inbounds v[A.perm[i]] = A.vals[i]*X[i]
    end
    v
end

function (*)(D::Diagonal{Td}, A::PermuteMultiply{Ta}) where {Td, Ta}
    T = Base.promote_op(*, Td, Ta)
    PermuteMultiply(A.perm, A.vals .* D.diag)
end

function (*)(A::PermuteMultiply{Ta}, D::Diagonal{Td}) where {Td, Ta}
    T = Base.promote_op(*, Td, Ta)
    PermuteMultiply(A.perm, A.vals .* view(D.diag, A.perm))
end

function (*)(A::PermuteMultiply, B::PermuteMultiply)
    size(A, 1) == size(B, 1) || throw(DimensionMismatch())
    PermuteMultiply(B.perm[A.perm], A.vals.*view(B.vals, A.perm))
end

* (generic function with 194 methods)

In [419]:
Ds = randn(1000,1000)
Pm = pmrand(1000)
v = randn(1000)
Dv = Diagonal(v)
sPm = sparse(Pm);

In [420]:
println(Pm*v == sparse(Pm)*v)
println(Pm*Dv == sparse(Pm)*Dv)
println(v'*Pm == v'*sparse(Pm))
println(Dv'*Pm == Dv'*sparse(Pm))

true
true
true
true


In [434]:
#@benchmark $(v')*$Pm
#@benchmark $Pm*$v
#@benchmark $Dv*$Pm
#@benchmark $Pm*$Dv
#@benchmark $Pm*$Pm

In [408]:
# to matrix
function (*)(A::PermuteMultiply, X::AbstractMatrix)
    size(X, 1) == size(A, 2) || throw(DimensionMismatch())
    return @views A.vals .* X[A.perm, :]   # this may be inefficient for sparse CSC matrix.
end

function (*)(X::AbstractMatrix, A::PermuteMultiply)
    mX, nX = size(X)
    nX == size(A, 1) || throw(DimensionMismatch())
    return @views (A.vals' .* X)[:, invperm(A.perm)]
end

* (generic function with 193 methods)

In [409]:
println(Pm*Ds == sparse(Pm)*Ds)
println(Ds*Pm == Ds*sparse(Pm))

true
true


In [404]:
#@benchmark $sPm*$Ds
#@benchmark $Pm*$Ds
#@benchmark $Ds*$sPm
#@benchmark $Ds*$Pm
@benchmark $v'*$Pm

BenchmarkTools.Trial: 
  memory estimate:  30.60 MiB
  allocs estimate:  2005004
  --------------
  minimum time:     40.862 ms (4.32% GC)
  median time:      42.362 ms (4.53% GC)
  mean time:        42.693 ms (5.95% GC)
  maximum time:     51.611 ms (4.82% GC)
  --------------
  samples:          118
  evals/sample:     1

In [141]:
function (*)(A::PermuteMultiply, X::SparseMatrixCSC)
    nA = size(A, 1)
    mX, nX = size(X)
    mX == nA || throw(DimensionMismatch())
    perm = invperm(A.perm)
    nzval = similar(X.nzval)
    rowval = similar(X.rowval)
    @inbounds for j = 1:nA
        @inbounds @simd for k = X.colptr[j]:X.colptr[j+1]-1
            r = perm[X.rowval[k]]
            nzval[k] = X.nzval[k]*A.vals[r]
            rowval[k] = r
        end
    end
    SparseMatrixCSC(mX, nX, X.colptr, rowval, nzval)
end

function (*)(X::SparseMatrixCSC, A::PermuteMultiply)
    nA = size(A, 1)
    mX, nX = size(X)
    nX == nA || throw(DimensionMismatch())
    perm = invperm(A.perm)
    nzval = similar(X.nzval)
    colptr = similar(X.colptr)
    rowval = similar(X.rowval)
    colptr[1] = 1
    z = 1
    @inbounds for j = 1:nA
        pk = perm[j]
        va = A.vals[pk]
        @inbounds @simd for k = X.colptr[pk]:X.colptr[pk+1]-1
            nzval[z] = X.nzval[k]*va
            rowval[z] = X.rowval[k]
            z+=1
        end
        colptr[j+1] = z
    end
    SparseMatrixCSC(mX, nX, colptr, rowval, nzval)
end

* (generic function with 193 methods)

In [135]:
println(Sp*Pm == Sp*sPm)
println((Pm*Sp)*Sp==(sPm*Sp)*Sp)
println(Pm*Sp*Sp == sPm*Sp*Sp)

true
true
true


In [145]:
#@benchmark $Sp*$Pm
#@benchmark $Sp*$sPm
#@benchmark $Pm*$Sp
#@benchmark $sPm*$Sp
#@benchmark invperm($Pm.perm)

BenchmarkTools.Trial: 
  memory estimate:  164.53 KiB
  allocs estimate:  7
  --------------
  minimum time:     64.573 μs (0.00% GC)
  median time:      67.606 μs (0.00% GC)
  mean time:        75.255 μs (7.67% GC)
  maximum time:     1.193 ms (84.53% GC)
  --------------
  samples:          10000
  evals/sample:     1

In [12]:
println(full(p3*p2) == full(p3)*full(p2))
println(full(sparse(p3)*p2) == full(p3)*full(p2))
println(full(p3)*p2 == full(p3)*full(p2))

v = [0.5, 0.3, 0.2, 1.0]
println(p3*v == full(p3)*v)

println(v'*p3==v'*full(p3))

Dv = diagm(v)
println(Dv*p3 == Dv*full(p3) == full(Dv)*full(p3))
println(p3*Dv == full(p3)*Dv == full(p3)*full(Dv))

true
true
true
true
true
true
true


In [18]:
p3==full(p3)==sparse(p3)

true

In [46]:
function kron(A::PermuteMultiply{Ta}, B::PermuteMultiply{Tb}) where {Ta, Tb}
    nA = size(A, 1)
    nB = size(B, 1)
    Tc = promote_type(Ta, Tb)
    vals = kron(A.vals, B.vals)
    perm = Vector{Int}(nB*nA)
    permA = (A.perm.-1)*nB
    @inbounds for i = 1:nA
        #perm[(i-1)*nB+1:i*nB] = permA[i] +B.perm
        start = (i-1)*nB
        permAi = permA[i]
        @inbounds @simd for j = 1:nB
            perm[start+j] = permAi +B.perm[j]
        end
    end
    PermuteMultiply(perm, vals)
end

kron (generic function with 14 methods)

In [53]:
Dv = Diagonal(randn(1<<8))
Pm = PermuteMultiply(randperm(1<<8), randn(1<<8));

In [54]:
sparse(kron(Pm, Pm)) == kron(sparse(Pm), sparse(Pm))

true

In [107]:
sPm = sparse(Pm);
#@benchmark kron($Pm, $Pm)

In [23]:
kron(A::PermuteMultiply, B::Diagonal) = kron(A, PermuteMultiply(B))
kron(A::Diagonal, B::PermuteMultiply) = kron(PermuteMultiply(A), B)
PermuteMultiply(dmat::Diagonal) = PermuteMultiply(collect(1:size(dmat, 1)), dmat.diag)

PermuteMultiply

In [24]:
println(kron(sparse(Pm), sparse(Dv)) == sparse(kron(Pm, Dv)))
println(sparse(kron(Pm, Dv)) == kron(sparse(Pm), sparse(Dv)))
println(sparse(kron(Dv, Pm)) == kron(sparse(Dv), sparse(Pm)))

true
true
true


In [25]:
kron(A::PermuteMultiply, B::SparseMatrixCSC) = kron(sparse(A), B)
kron(A::SparseMatrixCSC, B::PermuteMultiply) = kron(A, sparse(B))

kron (generic function with 18 methods)

In [66]:
function kron(A::PermuteMultiply{Ta}, B::SparseMatrixCSC{Tb}) where {Ta, Tb}
    nA = size(A, 1)
    mB, nB = size(B)
    nV = nnz(B)
    perm = invperm(A.perm)
    #nzval = kron(A.vals[perm], B.nzval)
    nzval = Vector{promote_type(Ta, Tb)}(nA*nV)
    rowval = Vector{Int}(nA*nV)
    colptr = Vector{Int}(nA*nB+1)
    colptr[1] = 1
    @inbounds @simd for i in 1:nA
        start_row = (i-1)*nV
        start_ri = (perm[i]-1)*mB
        v0 = A.vals[perm[i]]
        @inbounds @simd for j = 1:nV
            nzval[start_row+j] = B.nzval[j]*v0
            rowval[start_row+j] = B.rowval[j] + start_ri
        end
        start_col = (i-1)*nB+1
        start_ci = (i-1)*nV
        @inbounds @simd for j = 1:nB
            colptr[start_col+j] = B.colptr[j+1] + start_ci
        end
        #=
        rowval[(i-1)*nV+1:i*nV] = B.rowval+(perm[i]-1)*mB
        colptr[(i-1)*nB+1:i*nB+1] = B.colptr+(i-1)*nV
        =#
    end
    SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval)
end

kron (generic function with 18 methods)

In [67]:
kron(p1, sp) == kron(sparse(p1), sp)

true

In [68]:
Sp = sprand(1000, 1000, 0.01);
sp = sprand(4, 4, 0.3);

In [108]:
#@benchmark kron($Pm, $Sp)

In [128]:
function kron(A::SparseMatrixCSC{T}, B::PermuteMultiply{Tb}) where {T, Tb}
    nB = size(B, 1)
    mA, nA = size(A)
    nV = nnz(A)
    perm = invperm(B.perm)
    rowval = Vector{Int}(nB*nV)
    colptr = Vector{Int}(nA*nB+1)
    nzval = Vector{promote_type(T, Tb)}(nB*nV)
    z=1
    colptr[z] = 1
    @inbounds for i in 1:nA
        rstart = A.colptr[i]
        rend = A.colptr[i+1]-1
        @inbounds for k in 1:nB
            irow = perm[k]
            bval = B.vals[irow]
            irow_nB = irow - nB
            @inbounds @simd for r in rstart:rend
                rowval[z] = A.rowval[r]*nB+irow_nB
                nzval[z] = A.nzval[r]*bval
                z+=1
            end
            colptr[(i-1)*nB+k+1] = z
        end
    end
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

kron (generic function with 20 methods)

In [129]:
kron(sp, p1) == kron(sp, sparse(p1))

true

In [109]:
#@benchmark kron(Sp, Pm)
sPm = sparse(Pm);
#@benchmark kron(Sp, sPm)

In [163]:
function kron(A::StridedMatrix{Tv}, B::PermuteMultiply{Tb}) where {Tv, Tb}
    mA, nA = size(A)
    nB = size(B, 1)
    perm = invperm(B.perm)
    nzval = Vector{promote_type(Tv, Tb)}(mA*nA*nB)
    rowval = Vector{Int}(mA*nA*nB)
    colptr = collect(1:mA:nA*nB*mA+1)
    z = 1
    @inbounds for j = 1:nA
        @inbounds for j2 = 1:nB
            p2 = perm[j2]
            val2 = B.vals[p2]
            ir = p2
            @inbounds @simd for i = 1:mA
                nzval[z] = A[i, j]*val2  # merge
                rowval[z] = ir
                z += 1
                ir += nB
            end
        end
    end
    SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

function kron(A::PermuteMultiply{Ta}, B::StridedMatrix{Tb}) where {Tb, Ta}
    mB, nB = size(B)
    nA = size(A, 1)
    perm = invperm(A.perm)
    nzval = Vector{promote_type(Ta, Tb)}(mB*nA*nB)
    rowval = Vector{Int}(mB*nA*nB)
    colptr = collect(1:mB:nA*nB*mB+1)
    z = 1
    @inbounds for j = 1:nA
        colbase = (j-1)*nB
        p1 = perm[j]
        val2 = A.vals[p1]
        ir = (p1-1)*mB
        @inbounds for j2 = 1:nB
            @inbounds @simd for i2 = 1:mB
                nzval[z] = B[i2, j2]*val2  # merge
                rowval[z] = ir+i2
                z += 1
            end
        end
    end
    SparseMatrixCSC(nA*mB, nA*nB, colptr, rowval, nzval)
end

kron (generic function with 20 methods)

In [164]:
Dm = randn(100,100);
println(kron(Dm, p1) == kron(Dm, sparse(p1)))
println(kron(p1, Dm) == kron(sparse(p1), Dm))

true
true


In [111]:
sp1 = sparse(p1);
#@benchmark kron(Dm, sp1)
#@benchmark size(kron(Dm, p1))
#@benchmark size(kron(p1, Dm))
#@benchmark size(kron(sp1, Dm))

In [168]:
issubtype(typeof(p1), StridedMatrix)

false

In [444]:
findn(sp)

([2, 1, 4, 1, 4, 3], [1, 2, 2, 3, 3, 4])

In [455]:
function PermuteMultiply(ds::AbstractMatrix)
    i,j,v = findnz(ds)
    j == collect(1:size(ds, 2)) || throw(ArgumentError())
    order = invperm(i)
    PermuteMultiply(order, v[order])
end

PermuteMultiply

In [456]:
PermuteMultiply(full(p1)) == p1

true

In [460]:
import Base: convert
convert(::Type{PermuteMultiply{T}}, B::PermuteMultiply) where T = PermuteMultiply(B.perm, T.(B.vals))

convert (generic function with 722 methods)

In [462]:
convert(PermuteMultiply{Complex32}, p1)

4×4 PermuteMultiply{Complex{Float16},Int64}:
        0       0.099976+0.0im     0          0        
 0.19995+0.0im          0          0          0        
        0               0          0       0.0+0.3999im
        0               0       0.5+0.0im     0        