to do:
- `add_section` method
- enhance `getindex` so noninstantiated bases can be pulled on-the-fly
- rm `bases_to_dict` and `HALBasesDict` if performance is worse

In [3]:
using Pkg, Revise
Pkg.activate(".")

[32m[1m  Activating[22m[39m environment at `~/Desktop/SelectivelyAdaptiveLasso/Project.toml`


In [5]:
import SelectivelyAdaptiveLasso as SAL

┌ Info: Precompiling SelectivelyAdaptiveLasso [52f33b2b-83b2-4cfe-8aa4-4763af86ef4f]
└ @ Base loading.jl:1342


In [6]:
import Statistics: mean
import Distributions as Dist
import LinearAlgebra: diagm

n, p = 1000, 50

X_dist = Dist.reshape(
    Dist.MvNormal(zeros(p), 0.1ones(p,p) + 0.9diagm(ones(p))), 
    (1,p)
)

X = vcat(rand(X_dist, n)...);
# Y = rand(n)
Y = 5X[:,1] + X[:,2].^2 - X[:,3].*X[:,4]
Y = Y .- mean(Y)
Y = Y / max(abs.(Y)...)

bases = SAL.Bases(X);
mean((Y .- mean(Y)).^2)

0.060447726144800214

In [11]:
# @time β, cycle_loss, iter_loss = SAL.lasso_coordinate_descent(hcat(X,ones(n)), Y, λ=0.02)
# length(iter_loss), length(cycle_loss), cycle_loss[end], length(β[β.>1e-14])

└ @ Revise /Users/aschuler/.julia/packages/Revise/3RMhb/src/packagedef.jl:657


In [7]:
@time β, cycle_loss = SAL.coordinate_descent(bases, Y, λ=0.001)
length(cycle_loss), cycle_loss[end], length(β)

  2.153010 seconds (7.21 M allocations: 996.537 MiB, 10.78% gc time, 23.15% compilation time)


(102, 0.0027546638846404643, 600)

In [75]:
function section_search(
        bases::Bases,
        Y::Vector{Float64},
    )::Tuple{BitArray{2}, Vector{Int}}
    #=
    Find the section that will be most useful to linearly predict Y.
    
    Greedily searches through sections, starting at one-way. Given a k-way section, this function creates all 
    of the interactions between that section and all one-way sections. It then searches through all of the 
    bases within the p newly-created candidate sections and finds the basis with the maximum univariate 
    regression coefficient on the outcome. If this is greater than that of the previous iteration, it 
    replaces the k-waysection with the (k+1)-way section corresponding to that found basis. 
    
    This "top-down" approach makes sense heuristically because we expect that low-order interactions are 
    what's important for most real-world data-generating processes. There is also the fact that realizations
    of higher-order sections must have greater and greater sparsity since these are products of h∈{0,1}ⁿ.
    
    I also tried choosing the section using the sum of the univariate coefficients across all bases within 
    each candidate section. This is slightly slower because of the sum and doesn't appear to increase 
    "performance", but perhaps worth revisting later. I also tried a recursive implementation but looping 
    is faster.
    =# 
    
    section = []
    max_β = 0.0
    bases = Bases.one_way
    
    while true
        print("$max_β\n")
        new_max_β, idx = findmax(  
            sum(abs.(
                sum(Y .* bases, dims=1) ./
                sum(bases, dims=1)
            ), dims=2)
        )
        new_sectional_component = idx[3]

        if new_max_β ≤ max_β
            return bases[:,:,new_sectional_component], section
        else
            bases .= bases[:,:,new_sectional_component] .* Bases.one_way
            max_β = new_max_β
            push!(section, new_sectional_component)
        end
    end
end

select_next_bases (generic function with 2 methods)

In [None]:
mutable struct SAL
    # an instantiated selectively adaptive lasso
    
    # parameters
    knots::Matrix{Float64} # (m x p) matrix, each row is a knot
    β::Dict{Tuple{Int, Section}, Float64} # maps (knot, section) => coefficient
    
    # hyperparameters
    λ::Number = 1, # desired regularization strength (minimum if also using early stopping)
    B::Int = 500, # number of sections to adaptively select and use (max if also using early stopping)
    m::Union{Int,Nothing} = nothing, # number of data points from training data to use as knots
end

function construct_bases_dict(
        X::Matrix{Float64},
        knots::Matrix{Float64},
        sections::Set{Set{Int}}
    )
    m_knots, p = size(knots)
    bases = one_way_bases(X, knots)

    bases_dict = Dict{Tuple{Int,Int}, BitVector}()
    
    for section in sections, knot in 1:m_knots
        bases_dict[(knot, section)] = prod(bases[:, knot, [section...]], dims=3)
    end
    
    return bases_dict
end

function predict(sal, X; bases_dict)
end

In [None]:
function fit(
    # model object and hyperparameters
        sal::SAL
    # training data
        X::Matrix{Float64}, 
        Y::Matrix{Flloat64};
    # validation data for early stopping
        Xᵥ::Union{Nothing, Matrix{Float64}} = nothing, 
        Yᵥ::Union{Nothing, Matrix{Float64}} = nothing,
    )::SAL
    #= 
    Fits the selectively adaptive lasso
    =#
    
    # λ_max = n makes sense if y ∈ [-1,1] (center and then scale)
    
end

accept train/val data, max lambda

    create initial bases for train/val (use train knots for val)
    
    pick first section using train
    add these bases to the fitting set, construct them for the validation set (use same dict structure)

    loop:
        fit lasso on training set
        calculate val mse and add to record (train mse, val mse, section, lambda)
        decide: 
            1. new section (repeat above block)
            2. decrease lambda (unless min or 2nd time in a row)

In [65]:
@time next_bases, section = bases_tree_loop(Y, bases);
section

0.0
392.256850549578
505.06678751022395
  1.414116 seconds (240 allocations: 1.121 GiB, 3.95% gc time)


2-element Vector{Int64}:
 2
 3

In [38]:
@time β = coordinate_descent(bases_dict, Y, λ=1)

  9.732577 seconds (2.43 M allocations: 18.409 GiB, 36.88% gc time)


DefaultDict{Tuple{Int64, Int64}, Float64, Int64} with 6634 entries:
  (423, 30) => 0.00602341
  (547, 32) => 0.0
  (411, 1)  => 0.0
  (990, 9)  => -9.26322e-5
  (141, 20) => -0.0829121
  (332, 33) => -0.000118015
  (210, 30) => -0.000636054
  (229, 33) => 0.0126294
  (820, 33) => 0.0
  (981, 33) => 0.00241622
  (605, 42) => -0.00156275
  (779, 8)  => 0.0
  (743, 1)  => 0.000993445
  (442, 22) => 0.0
  (245, 28) => 0.00925831
  (52, 32)  => -0.0104064
  (506, 39) => 0.00117442
  (369, 32) => 0.00062154
  (147, 18) => 0.0
  (752, 10) => 0.0
  (824, 22) => 0.0
  (91, 30)  => 0.0269337
  (100, 10) => 0.0
  (357, 48) => 0.0
  (724, 23) => 0.0
  ⋮         => ⋮

In [104]:
3345.29*2 + (3345.3+5057.58) + (3347.38+3559.22) + 3695.27*2 + 3842.43*2 + 3761.67*2 + 3761.67*2 + 2068.33

54190.47

In [14]:
β

DefaultDict{Tuple{Int64, Int64}, Float64, Int64} with 14087 entries:
  (423, 30) => 0.00458607
  (141, 20) => -0.00115403
  (266, 14) => 0.0
  (229, 33) => -0.0176731
  (719, 11) => -0.000859735
  (40, 28)  => -0.000991317
  (779, 8)  => -0.000860537
  (107, 3)  => 0.0
  (447, 15) => 0.000345992
  (442, 22) => -0.00356439
  (91, 30)  => 0.000592763
  (100, 10) => 0.00426158
  (403, 48) => 0.00129259
  (866, 30) => -0.00016538
  (978, 9)  => 0.0
  (411, 6)  => -0.00503935
  (725, 40) => 0.00087678
  (333, 5)  => 0.000260004
  (278, 2)  => 0.00291205
  (309, 9)  => 0.00275902
  (153, 29) => 0.000497848
  (264, 21) => 0.0
  (625, 13) => 0.000532021
  (826, 40) => -0.00286204
  (757, 48) => 0.00251887
  ⋮         => ⋮

In [153]:
mse(Y,Ŷ)/length(Y)

0.06984108046984915

In [146]:
Stat.mean(Y)

2.0539125955565396e-17

In [147]:
Stat.mean(Ŷ)

0.001115331680277078

In [98]:
max(Ŷ...)

1.6765763880704447

In [27]:
βvec = [beta for beta in values(β)]
max(βvec...)

0.2038664658255553