Unexpected type returned from multiplication of adjoint AxisArray #189

Open
BSnelling opened this issue Dec 8, 2020 · 1 comment

@BSnelling

Minimal example (using AxisArrays 0.4.3 and Julia 1.5):

julia> v = rand(3);

julia> v_ax = AxisArray(v);

julia> typeof(v')
LinearAlgebra.Adjoint{Float64,Array{Float64,1}}

julia> typeof(v_ax')
AxisArray{Float64,2,LinearAlgebra.Adjoint{Float64,Array{Float64,1}},Tuple{Axis{:transpose,Base.OneTo{Int64}},Axis{:row,Base.OneTo{Int64}}}}

This quadratic form should return a scalar, and it does for a plain Vector:

julia> typeof(v' * rand(3, 3) * v)
Float64

julia> typeof(v_ax' * rand(3, 3) * v_ax)
Array{Float64,1}

With the AxisArray, however, the same expression returns a one-element Array rather than a scalar. Looking into the methods involved:

julia> y = rand(3, 3) * v_ax;

julia> @which v_ax' * y
*(A::AbstractArray{T,2}, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /Applications/Julia-1.5.app/Contents/Resources/julia/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:49

That is the generic matrix-vector method from LinearAlgebra's matmul.jl:

function (*)(A::AbstractMatrix{T}, x::AbstractVector{S}) where {T,S}
    TS = promote_op(matprod, T, S)
    mul!(similar(x,TS,axes(A,1)),A,x)
end

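This explains the Array result: v_ax' is an AxisArray, i.e. an ordinary AbstractMatrix rather than an Adjoint, so dispatch falls through to this generic method, which allocates its output with similar(x, TS, axes(A,1)), a vector of length size(v_ax', 1) == 1. A new method is probably needed so that v_ax' * y returns the expected scalar.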
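A minimal sketch using only the LinearAlgebra standard library illustrates the distinction: an Adjoint row vector times a vector gives a scalar, while any other 1×n AbstractMatrix goes through matrix-vector multiplication and gives a vector of length size(A, 1) == 1. Here reshape(v, 1, :) stands in for v_ax' (an assumption: both are 1×3 AbstractMatrix values that are not Adjoint). A plain Matrix actually takes LinearAlgebra's BLAS-backed method rather than the generic one quoted above, but the output shape follows the same rule.

```julia
using LinearAlgebra

v = rand(3)
y = rand(3, 3) * v

row_adjoint = v'               # Adjoint{Float64,Vector{Float64}}: a true "row vector"
row_matrix = reshape(v, 1, :)  # ordinary 1×3 Matrix{Float64}, standing in for v_ax'

s = row_adjoint * y  # Adjoint row vector times vector: a scalar
r = row_matrix * y   # matrix-vector product: a vector of length size(row_matrix, 1)

println(typeof(s))            # Float64
println(typeof(r))            # Vector{Float64} (shown as Array{Float64,1} on Julia 1.5)
println(axes(row_matrix, 1))  # Base.OneTo(1), hence the one-element result
```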

@nickrobinson251

nickrobinson251 commented Dec 8, 2020

I think this can be solved by adding a method to AxisArrays that delegates to *(::Adjoint{<:Any,<:AbstractVector}, ::AbstractVector), e.g.

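# Inside AxisArrays (Adjoint comes from LinearAlgebra): when the wrapped data is the
# adjoint of a vector, unwrap it so LinearAlgebra's row-vector method returns a scalar.
function Base.:*(a::AxisArray{T,2,<:Adjoint{T,<:AbstractVector{T}}}, b::AbstractVector{T}) where {T}
    return *(parent(a), b)
end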

We probably need the same set of methods that NamedDims defines for this:
https://github.com/invenia/NamedDims.jl/blob/24b9839091ec4b6091e38c19c9f36e111ca4fbad/src/functions_math.jl#L38-L55
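The delegation idea can be tried out without AxisArrays. The sketch below assumes a hypothetical wrapper type Wrapped, standing in for AxisArray (it stores an array and forwards size and indexing, but has no axis names). Before the extra method exists, Wrapped(v') * y hits the generic matrix-vector fallback and returns a one-element vector; after defining a method that unwraps an adjoint vector with parent, the same call returns a scalar. This only illustrates the pattern proposed above; it is not AxisArrays' actual implementation.

```julia
using LinearAlgebra

# Hypothetical wrapper standing in for AxisArray: an AbstractMatrix that stores
# another array and forwards size and indexing to it.
struct Wrapped{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T}
    data::A
end
Wrapped(data::AbstractMatrix{T}) where {T} = Wrapped{T,typeof(data)}(data)

Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int, j::Int) = w.data[i, j]
Base.parent(w::Wrapped) = w.data

v = rand(3)
y = rand(3, 3) * v
w = Wrapped(v')      # plays the role of v_ax'

before = w * y       # generic *(::AbstractMatrix, ::AbstractVector): one-element Vector

# Delegating method: if the wrapped array is the adjoint of a vector, unwrap it
# and let LinearAlgebra's adjoint-vector times vector method produce a scalar.
function Base.:*(a::Wrapped{T,<:Adjoint{T,<:AbstractVector}}, b::AbstractVector) where {T}
    return parent(a) * b
end

after = w * y        # now dispatches to the method above and returns a scalar

println(typeof(before), " ", typeof(after))
```

Relaxing the second argument to an untyped AbstractVector (unlike the AbstractVector{T} in the proposal) is a choice made for the sketch, so that row vectors and vectors with different element types are also covered.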
