In [1]:
] activate .

In [2]:
using BenchmarkTools, StaticArrays, LinearAlgebra

Generic code

In [3]:
"""
    inner(v, A, w)

Return `dot(v, A*w)`. Only `*` and `dot` are used, so any matrix/vector types that
implement those (dense arrays, StaticArrays, custom `AbstractVector`s) work as-is.
"""
inner(v, A, w) = dot(v, A*w)

"""
    innersum(A, vs)

Return the sum of `inner(v, A, v)` over every vector `v` in `vs`, or the zero of
the accumulator type when `vs` is empty.

The accumulator is seeded with the zero of
`promote_type(eltype(A), eltype(eltype(vs)))`. For example, an `Int` matrix with
`Float64` vectors then accumulates in `Float64` from the first iteration instead
of switching from `Int` to `Float64` inside the loop. If the element type of `vs`
is not concrete (e.g. a generator or `Any[...]`), the seed falls back to
`zero(eltype(A))`.
"""
function innersum(A, vs)
    # Computed from types only, so the compiler can fold this choice away.
    T = promote_type(eltype(A), eltype(eltype(vs)))
    t = isconcretetype(T) ? zero(T) : zero(eltype(A)) # generic!
    for v in vs
        t += inner(v, A, v)
    end
    return t
end

innersum (generic function with 1 method)

In [4]:
# Plain dense Float64 inputs: a random 3×3 matrix and ten random length-3 vectors.
A = rand(3,3)
vs = [rand(3) for _ = 1:10]

innersum(A, vs)

14.684932824817619

Some other type

In [5]:
using StaticArrays

# Same generic `innersum`, now fed fixed-size SMatrix/SVector values from StaticArrays.
A = @SMatrix rand(3,3)

vs = [@SVector rand(3) for _ in 1:10]

innersum(A, vs)

16.11338843933083

Our own type

In [6]:
# One-hot vector of length `len` with a single `true` at position `ind`:
# v = (0, ..., 0, 1, 0, ..., 0)
struct OneHotVector <: AbstractVector{Bool}
    len::Int  # number of entries
    ind::Int  # position of the single `true`; must satisfy 1 <= ind <= len
    # Reject positions outside 1:len: such a value would not be one-hot (it
    # would have no `true` entry). `new` still converts the arguments to `Int`
    # exactly like the default constructor did.
    function OneHotVector(len, ind)
        1 <= ind <= len ||
            throw(ArgumentError("one-hot index $ind must lie in 1:$len"))
        return new(len, ind)
    end
end

In [7]:
# Construction itself succeeds, but displaying the result calls `size`, which
# OneHotVector does not implement yet — hence the MethodError below.
v = OneHotVector(5,2)

MethodError: MethodError: no method matching size(::OneHotVector)
Closest candidates are:
  size(::AbstractArray{T,N}, !Matched::Any) where {T, N} at abstractarray.jl:38
  size(!Matched::BitArray{1}) at bitarray.jl:77
  size(!Matched::BitArray{1}, !Matched::Any) at bitarray.jl:81
  ...

What functions do we need to implement for our type to actually be a vector?

https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array — the required methods are `size` and scalar `getindex`; everything else in that table is optional.

In [9]:
# Core of the AbstractArray interface for a 1-d array: `size` and scalar `getindex`.
Base.size(v::OneHotVector) = (v.len,)

# Entry `i` is `true` only at the one-hot position. `checkbounds` makes an
# out-of-range index throw a `BoundsError` instead of quietly returning `false`;
# wrapping it in `@boundscheck` (with `@inline`) lets `@inbounds` callers skip it.
@inline function Base.getindex(v::OneHotVector, i::Integer)
    @boundscheck checkbounds(v, i)
    return i == v.ind
end

In [10]:
# With `size` and `getindex` defined, the vector now displays like any other AbstractVector.
v = OneHotVector(5,2)

5-element OneHotVector:
 false
  true
 false
 false
 false

In [11]:
# Generic LinearAlgebra functions work through the AbstractVector interface alone.
norm(v)

1.0

In [12]:
# Only the two Int fields are stored; entries are computed on demand by `getindex`.
dump(v)

OneHotVector
  len: Int64 5
  ind: Int64 2


In [13]:
# Ten length-3 one-hot vectors, each with its `true` position drawn from 1:3.
vs = [OneHotVector(3, rand(1:3)) for _ in 1:10]

10-element Array{OneHotVector,1}:
 [false, false, true]
 [false, false, true]
 [false, false, true]
 [false, true, false]
 [true, false, false]
 [false, false, true]
 [true, false, false]
 [true, false, false]
 [false, true, false]
 [false, false, true]

In [14]:
# `A` is still the 3×3 SMatrix from the StaticArrays cell; `innersum` needed no changes.
innersum(A, vs) # just works!

6.2632725803179286

# Efficiency

In [15]:
# Benchmark inputs: a dense 1000×1000 matrix and a one-hot vector with its `true` at index 2.
A = rand(1000,1000);
v = OneHotVector(1000,2);

In [16]:
# Show which `*` method `A*v` dispatches to before any OneHotVector-specific method exists.
@which A*v

In [18]:
# `$` interpolates the globals so BenchmarkTools does not time global-variable lookup.
# Presumably this hits the generic dense matrix-vector product, which touches every entry of A.
@btime $A*$v;

  283.800 μs (1 allocation: 7.94 KiB)


In [19]:
import Base: *
# Multiplying by a one-hot vector just selects column `v.ind` of `A`, so copy that
# single column instead of running the full generic product. A copy (not a view)
# keeps the usual `*` contract of returning a fresh array that does not alias `A`.
function *(A::AbstractMatrix, v::OneHotVector)
    # Same requirement as the generic matrix-vector product: inner dimensions agree.
    size(A, 2) == v.len || throw(DimensionMismatch(
        "matrix A has $(size(A, 2)) columns, but vector v has length $(v.len)"))
    return A[:, v.ind]
end

* (generic function with 370 methods)

In [20]:
# Confirm that `A*v` now dispatches to the OneHotVector-specific `*` method.
@which A*v

In [21]:
# Benchmark again: the specialized method only copies column `v.ind` of A.
@btime $A * $v;

  544.845 ns (2 allocations: 7.97 KiB)
