In [1]:
# -------------------- #
# Dispatcher functions #
# -------------------- #

# Basis family symbols understood by the dispatchers below; `base_exists`
# checks membership in this list before `fundef` calls `basedef`.
const BASE_TYPES = [:spli, :cheb, :lin]
# Maps format symbols (as used by `funbasex`, `funbconv` and `bm_from_dict`)
# to BasisMatrices representation instances. `:none` is treated as `Direct()`.
# NOTE(review): requires `Direct`, `Tensor` and `Expanded` to be in scope
# (BasisMatrices loaded); the recorded run of this cell failed with
# `UndefVarError: Direct not defined`.
const ABSR_MAP = Dict(
    :none => Direct(),
    :direct => Direct(),
    :tensor => Tensor(),
    :expanded => Expanded(),
)
# Report the format symbol (one of the non-`:none` keys of `ABSR_MAP`) that
# corresponds to the representation type parameter of a BasisMatrix.
get_bformat(::BasisMatrix{Direct}) = :direct
get_bformat(::BasisMatrix{Expanded}) = :expanded
get_bformat(::BasisMatrix{Tensor}) = :tensor

# Pack a BasisMatrix into a plain `Dict{Symbol,Any}` with keys `:order`,
# `:format` and `:vals`; `bm_from_dict` rebuilds the matrix from it.
function to_dict(bm::BasisMatrix)
    return Dict{Symbol,Any}(
        :order => bm.order,
        :format => get_bformat(bm),
        :vals => bm.vals,
    )
end

# Rebuild a BasisMatrix from a Dict produced by `to_dict`. The representation
# type is looked up from `B[:format]` via `ABSR_MAP`; the second type
# parameter is the element type of `B[:vals]`.
function bm_from_dict(B::Dict)
    vals = B[:vals]
    rep = typeof(ABSR_MAP[B[:format]])
    return BasisMatrix{rep,eltype(vals)}(B[:order], vals)
end

# True when `s` is one of the supported family symbols listed in `BASE_TYPES`.
base_exists(s::Symbol) = any(==(s), BASE_TYPES)

# Forward `args` to the family-specific definition routine for basis family
# `s`: `:spli` -> `splidef`, `:cheb` -> `chebdef`, `:lin` -> `lindef`.
# Any other symbol raises an ErrorException that names the bad symbol
# (previously the message gave no hint of which value was rejected).
function basedef(s::Symbol, args...)
    if s == :spli
        return splidef(args...)
    elseif s == :cheb
        return chebdef(args...)
    elseif s == :lin
        return lindef(args...)
    else
        error("Unknown basis type $(repr(s)); expected one of :spli, :cheb, :lin")
    end
end

# Forward `args` to the family-specific node routine for basis family `s`:
# `:spli` -> `splinode`, `:cheb` -> `chebnode`, `:lin` -> `linnode`.
# Any other symbol raises an ErrorException that names the bad symbol.
function basenode(s::Symbol, args...)
    if s == :spli
        return splinode(args...)
    elseif s == :cheb
        return chebnode(args...)
    elseif s == :lin
        return linnode(args...)
    else
        error("Unknown basis type $(repr(s)); expected one of :spli, :cheb, :lin")
    end
end

# Symbol-keyed method of `BasisMatrices.evalbase`: forward `args` to the
# family-specific evaluation routine for `s` (`:spli` -> `splibase`,
# `:cheb` -> `chebbase`, `:lin` -> `linbase`).
# Any other symbol raises an ErrorException that names the bad symbol.
function BasisMatrices.evalbase(s::Symbol, args...)
    if s == :spli
        return splibase(args...)
    elseif s == :cheb
        return chebbase(args...)
    elseif s == :lin
        return linbase(args...)
    else
        error("Unknown basis type $(repr(s)); expected one of :spli, :cheb, :lin")
    end
end

# Helper function
# Drop the trailing singleton dimensions of `x`, e.g. size (3, 1, 1) -> (3,)
# and (1, 3, 1) -> (1, 3). Leading and interior singleton dimensions are kept.
# `Base.squeeze` no longer exists in Julia 1.x; `dropdims` replaces it.
function squeeze_trail(x::AbstractArray)
    # `keep` ends as the index of the last non-singleton dimension
    # (0 when every dimension has size 1).
    keep = ndims(x)
    while keep > 0 && size(x, keep) == 1
        keep -= 1
    end
    trailing = ntuple(i -> keep + i, ndims(x) - keep)
    return dropdims(x; dims=trailing)
end


# ---------------------------- #
# Generic translated functions #
# ---------------------------- #

# from fundef.m  -- DONE
# Build a basis Dict from one spec per dimension. Each spec is indexable: its
# first entry is a family Symbol (validated with `base_exists`) and the rest
# are forwarded to `basedef`, which returns (n, a, b, params, BasisParams).
# The returned Dict also carries the combined BasisMatrices `Basis` under
# `:_basis` and the per-dimension BasisParams under `:_basis_params`.
function fundef(foo...)
    ndim = length(foo)
    n = zeros(Int, ndim)
    a = zeros(ndim)
    b = zeros(ndim)
    p = Array{Any}(undef, ndim)
    bparams = Array{BasisMatrices.BasisParams}(undef, ndim)
    basetype = Array{Symbol}(undef, ndim)

    for (j, spec) in enumerate(foo)
        basetype[j] = spec[1]
        base_exists(basetype[j]) || error("Unknown basis $(spec[1])")
        n[j], a[j], b[j], p[j], bparams[j] = basedef(basetype[j], spec[2:end]...)
    end

    return Dict{Symbol,Any}(
        :d => ndim,
        :n => n,
        :a => a,
        :b => b,
        :basetype => basetype,
        :params => p,
        :_basis_params => bparams,
        :_basis => Basis(bparams...),
    )
end

# fundefn.m
# Build a `d`-dimensional basis of a single family (`:cheb`, `:spli` or
# `:lin`) on the box [a, b] with `n[i]` functions per dimension, then hand the
# per-dimension specs to `fundef`.
# `order` is only used for `:spli`, whose spec is
# `[:spli, [a[i], b[i]], n[i] - order + 1, order]`.
# Throws an ErrorException on mismatched lengths, a > b, n < 2, or an unknown
# `basistype`.
function fundefn(basistype::Symbol, n, a, b, order=3)
    d = length(n)
    length(a) != d && error("a must be same dimension as n")
    length(b) != d && error("b must be same dimension as n")
    any(a .> b) && error("left endpoints must be less than right endpoints")
    any(n .< 2) && error("n(i) must be greater than 1")

    params = Array{Any}(undef, 1, d)
    if basistype == :cheb
        for i in 1:d
            params[i] = Any[:cheb, n[i], a[i], b[i]]
        end
    elseif basistype == :spli
        for i in 1:d
            params[i] = Any[:spli, [a[i], b[i]], n[i]-order+1, order]
        end
    elseif basistype == :lin
        for i in 1:d
            params[i] = Any[:lin, [a[i], b[i]], n[i]]
        end
    else
        # Without this branch `params` stays #undef and the splat below
        # fails with an opaque UndefRefError.
        error("Unknown basis type $(repr(basistype)); expected one of :cheb, :spli, :lin")
    end

    return fundef(params...)
end

# funnode.m -- DONE
# Nodes of the BasisMatrices basis stored under `basis[:_basis]`; returns
# whatever `nodes` returns for it.
function funnode(basis::Dict)
    b = basis[:_basis]
    return nodes(b)
end

# funbase.m -- DONE
# Evaluate `basis[:_basis]` at `x` (default: the first element returned by
# `funnode`) in Expanded form, forwarding `order` (default: a 1×d zero
# matrix), and return the first entry of the resulting `vals`.
function funbase(basis::Dict, x=funnode(basis)[1], order=fill(0, 1, basis[:d]))
    bm = BasisMatrix(basis[:_basis], Expanded(), x, order)
    return bm.vals[1]
end

# funbasex.m -- DONE
# Build the basis matrix of `basis[:_basis]` at `x` with `order`, in the
# representation selected by `bformat` (a key of `ABSR_MAP`), and return it
# packed by `to_dict`.
function funbasex(basis::Dict{Symbol}, x=funnode(basis)[1], order=0,
                  bformat::Symbol=:none)
    b = basis[:_basis]
    rep = ABSR_MAP[bformat]
    bm = BasisMatrix(b, rep, x, order)
    return to_dict(bm)
end

# funfitf.m -- DONE
# Dict-based front end: forward to `funfitf` with the stored BasisMatrices
# basis in place of the Dict.
# NOTE(review): this assumes `funfitf` here extends the BasisMatrices export
# (e.g. via `import BasisMatrices: funfitf`). If it is a fresh function, the
# inner call re-enters this method with a `Basis` and fails. Confirm.
function funfitf(basis, f::Function, args...)
    inner = basis[:_basis]
    return funfitf(inner, f, args...)
end

# funfitxy.m -- DONE
# Dict-based front end: fit on the stored BasisMatrices basis and return the
# coefficients with the basis matrix packed by `to_dict`.
# NOTE(review): assumes `funfitxy` extends the BasisMatrices export; see
# the analogous note on `funfitf`.
function funfitxy(basis, x, y)
    coefs, bm = funfitxy(basis[:_basis], x, y)
    bdict = to_dict(bm)
    return coefs, bdict
end

# funeval.m -- DONE
# Evaluate the approximant with coefficients `c`. `B` is either a Dict made by
# `to_dict` (reused and returned as-is) or evaluation input, from which a
# BasisMatrix is built on `basis[:_basis]`. Returns `(values, basis_dict)`.
# `_order` is normalised by `BasisMatrices._check_order` against `basis[:d]`.
function funeval(c, basis::Dict, B, _order=0)
    isempty(c) && error("missing basis coefficients")
    order = BasisMatrices._check_order(basis[:d], _order)

    # Previously packed basis structure: evaluate with it directly.
    if B isa Dict
        return funeval(c, bm_from_dict(B), order), B
    end

    # Otherwise construct the basis matrix from `B` and pack it for reuse.
    bm = BasisMatrix(basis[:_basis], B, order)
    return funeval(c, bm, bm.order), to_dict(bm)
end

# fund.m
# Placeholder for fund.m: not implemented yet; always returns `nothing`,
# whatever the arguments.
# TODO: come back when I need this. I think I should probably do something
#       like what optim does and write functions `f`, `fg!` and `fgh!` to
#       replicate this for the type of basis instead of this function
fund(c, basis, x, hess_opt) = nothing

# funbconv.m  -- DONE
# Convert a basis Dict (from `to_dict`) to the representation named by
# `format` (a key of `ABSR_MAP`, default `:expanded`), forwarding `order`
# (default: a zero row with as many columns as `b[:order]`), and repack it.
function funbconv(b::Dict, order=fill(0, 1, size(b[:order], 2)),
                  format::Symbol=:expanded)
    src = bm_from_dict(b)
    rep = typeof(ABSR_MAP[format])
    return to_dict(convert(rep, src, order))
end

LoadError: UndefVarError: Direct not defined

In [6]:
using Pkg
# Install both dependencies in one call so the resolver runs only once
# (the duplicate `import Pkg` was redundant after `using Pkg`).
Pkg.add(["QuantEcon", "Combinatorics"])
module BasisMatrices

# TODO: still need to write fund, minterp

#=
Note that each subtype of `BT<:BasisFamily` (with associated `PT<:BasisParams`)
will define the following constructor methods:
```julia
# basis constructors
Basis(::BT, args...)
Basis(::PT)
# node constructor
nodes(::PT)
```
=#

import Base: ==, *, \
using Base.Cartesian

using QuantEcon: gridmake, gridmake!, ckron, fix, fix!

using Combinatorics: with_replacement_combinations
using Base.Iterators: product

# types
export BasisFamily, Cheb, Lin, Spline, Basis, Smolyak,
       BasisParams, ChebParams, LinParams, SplineParams, SmolyakParams,
       AbstractBasisMatrixRep, Tensor, Expanded, Direct,
       BasisMatrix, Interpoland, SplineSparse, RowKron

# functions
export nodes, get_coefs, funfitxy, funfitf, funeval, evalbase,
       derivative_op, row_kron, evaluate, fit!, update_coefs!,
       complete_polynomial, complete_polynomial!, n_complete

#re-exports
export gridmake, gridmake!, ckron

# stdlib
using SparseArrays, LinearAlgebra, Statistics

# Root abstract types: a basis family, and the parameter object describing a
# single (one-dimensional, see `ndims` below) basis of some family.
abstract type BasisFamily end
abstract type BasisParams end
# `Int` or `AbstractVector{Int}`.
const IntSorV = Union{Int, AbstractVector{Int}}
# A tuple of vectors or a vector of vectors -- presumably one coordinate vector
# per dimension for tensor-style input; confirm against util.jl/basis.jl.
const TensorX = Union{Tuple{Vararg{AbstractVector}},AbstractVector{<:AbstractVector}}

# NOTE(review): these relative `include` paths appear to resolve against the
# notebook's working directory; the recorded run failed with "Permission
# denied" opening `...\Files_replicated_Python\util.jl`. Make sure the
# package source files are present and readable there.
include("util.jl")
include("spline_sparse.jl")

# include the families

# BasisParams interface
# Defaults: params types are not sparse and each params object is
# one-dimensional. The loop adds instance methods that forward to the
# type-level method, e.g. `family(p)` calls `family(typeof(p))`.
SparseArrays.issparse(::Type{T}) where {T<:BasisParams} = false
Base.ndims(::BasisParams) = 1
for f in [:family, :family_name, :(SparseArrays.issparse), :(Base.eltype)]
    @eval $(f)(::T) where {T<:BasisParams} = $(f)(T)
end
include("cheb.jl")
include("lin.jl")
include("spline.jl")
include("complete.jl")
include("smolyak.jl")

# Scalar `x`: wrap it in a one-element vector and defer to the vector method.
evalbase(p::BasisParams, x::Number, args...) = evalbase(p, [x], args...)

# now some more interface methods that only make sense once we have defined
# the subtypes
basis_eltype(::TP, x) where {TP<:BasisParams} = promote_type(eltype(TP), eltype(x))
basis_eltype(::Type{TP}, x) where {TP<:BasisParams} = promote_type(eltype(TP), eltype(x))
"""
    basis_eltype(p::Union{BasisParams,Type{<:BasisParams}}, x)

Return the eltype of the Basis matrix that would be obtained by calling
`evalbase(p, x)`, computed as `promote_type(eltype(p), eltype(x))`.
"""
basis_eltype

# give the type of the `vals` field based on the family type parameter of the
# corresponding basis. `Spline` and `Lin` use sparse, `Cheb` uses dense
# a hybrid must fall back to a generic AbstractMatrix{Float64}
# the default is just a dense matrix
# because there is only one dense version, we will start with the sparse
# case and overload for Cheb
bmat_type(::Type{TP}, x) where {TP<:BasisParams} = SparseMatrixCSC{basis_eltype(TP, x),Int}
bmat_type(::Type{T2}, ::Type{TP}, x) where {TP<:BasisParams,T2} = bmat_type(TP, x)
function bmat_type(::Type{T2}, ::Type{TP}, x) where {TP<:BasisParams,T2<:SplineSparse}
    SplineSparse{basis_eltype(TP, x),Int}
end

bmat_type(::Type{T}, x) where {T<:ChebParams} = Matrix{basis_eltype(T, x)}
function bmat_type(::Type{T2}, ::Type{TP}, x) where {TP<:ChebParams,T2<:SplineSparse}
    bmat_type(TP, x)
end

# version where there isn't an x passed: `one(eltype(TP))` stands in for x
bmat_type(::Type{TP}) where {TP<:BasisParams} = bmat_type(TP, one(eltype(TP)))
function bmat_type(::Type{T2}, ::Type{TP}) where {TP<:BasisParams,T2}
    bmat_type(T2, TP, one(eltype(TP)))
end

# add methods to instances
bmat_type(::T) where {T<:BasisParams} = bmat_type(T)
bmat_type(ss::Type{T2}, ::TF) where {TF<:BasisParams,T2} = bmat_type(T2, TF)
bmat_type(ss::Type{T2}, ::TF, x) where {TF<:BasisParams,T2} = bmat_type(T2, TF, x)

# default method for evalbase with extra type hint is to just ignore the extra
# type hint
evalbase(::Type{T}, bp::BasisParams, x, order) where {T} = evalbase(bp, x, order)


# include other
include("basis.jl")
include("basis_structure.jl")
include("interp.jl")


# deprecations
@deprecate BasisStructure BasisMatrix

end # module


[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `C:\Users\Giorgia\.julia\environments\v1.7\Project.toml`
[32m[1m  No Changes[22m[39m to `C:\Users\Giorgia\.julia\environments\v1.7\Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `C:\Users\Giorgia\.julia\environments\v1.7\Project.toml`
[32m[1m  No Changes[22m[39m to `C:\Users\Giorgia\.julia\environments\v1.7\Manifest.toml`


LoadError: SystemError: opening file "C:\\Users\\Giorgia\\Dropbox\\ReplicationProject\\DairuchPaper\\Replication\\Files_replicated_Python\\util.jl": Permission denied

In [3]:
# Map each basis family (given as an instance or as the type itself) to the
# params type that describes one dimension of that family.
_param(::Cheb) = ChebParams
_param(::Type{Cheb}) = ChebParams
_param(::Lin) = LinParams
_param(::Type{Lin}) = LinParams
_param(::Spline) = SplineParams
_param(::Type{Spline}) = SplineParams

# params of same type are equal if all their fields are equal
function (==)(p1::T, p2::T) where {T<:BasisParams}
    return all(nm -> getfield(p1, nm) == getfield(p2, nm), fieldnames(T))::Bool
end

# Params of different concrete types are never equal.
(==)(::T1, ::T2) where {T1<:BasisParams,T2<:BasisParams} = false

# ---------- #
# Basis Type #
# ---------- #

# An `N`-dimensional basis: `params` holds one BasisParams object per
# dimension and `TP` is the tuple type of those params.
struct Basis{N,TP<:Tuple}
    params::TP     # params to construct basis
end

# Per-dimension lower and upper bounds (the corners of the hypercube shown by
# `show`), one entry per params object.
Base.min(b::Basis) = map(min, b.params)
Base.max(b::Basis) = map(max, b.params)
# Number of dimensions, read from the `N` type parameter.
function Base.ndims(::Basis{N}) where {N}
    return N
end
function Base.ndims(::Type{Basis{N,TP}}) where {N,TP}
    return N
end

# Params tuple type `TP` of a Basis; the instance method defers to the type one.
_get_TP(b::Basis) = _get_TP(typeof(b))
_get_TP(::Type{Basis{N,TP}}) where {N,TP} = TP

# Two-line summary: dimension, bounding hypercube, and the family names.
function Base.show(io::IO, b::Basis{N}) where {N}
    lo = min(b)
    hi = max(b)
    fams = join(string.(family_name.(b.params)), " × ")
    msg = "$N dimensional Basis on the hypercube formed by $lo × $hi.\n" *
          "Basis families are $fams\n"
    print(io, msg)
end

# Build a Basis from its per-dimension params, given either as separate
# arguments or as a single tuple.
Basis(params::BasisParams...) = _Basis(params)
Basis(params::Tuple) = _Basis(params)

# `N` is a static parameter taken from the tuple length in the argument's
# type, so the result type is inferable without an `@generated` function.
_Basis(params::Tuple{Vararg{Any,N}}) where {N} = Basis{N,typeof(params)}(params)

# combining basis -- fundef-esque method
Basis(bs::Basis...) = _Basis2(bs)

# Concatenate several bases into one. Generated so that the total dimension
# `N` and the concatenated params tuple type `TP` are computed from the
# argument types at compile time, keeping the constructor type-stable.
@generated function _Basis2(bs)
    N = sum(ndims, bs.parameters)

    # Each input contributes a params tuple *type*; collect their element
    # types and build a single `Tuple{...}` type (not a tuple object).
    basis_types = Any[]
    for tp in map(_get_TP, bs.parameters)
        push!(basis_types, tp.parameters...)
    end
    TP = Tuple{basis_types...}

    # Splat every input's params into one tuple literal,
    # `(bs[1].params..., bs[2].params..., ...)`, instead of collecting them
    # into a `Vector{Any}` at run time.
    splats = [:(bs[$i].params...) for i in 1:length(bs.parameters)]
    return :(Basis{$N,$TP}(($(splats...),)))
end

# fundefn type method
function Basis(bt::BasisFamily, n::Int, a, b)
    P = _param(bt)
    return Basis(P(n, a, b))
end

function Basis(::Type{T}, n::Int, a, b) where {T<:BasisFamily}
    return Basis(T(), n, a, b)
end

# One params object per dimension, built elementwise from the vectors.
function Basis(bt::T, n::Vector, a::Vector, b::Vector) where {T<:BasisFamily}
    per_dim = map(_param(T), n, a, b)
    return Basis(per_dim...)
end

# Spline-specific methods that also take `k` (per dimension in the vector form,
# defaulting to all ones there).
function Basis(::Spline, n::Int, a, b, k)
    return Basis(SplineParams(n, a, b, k))
end

function Basis(::Spline, n::Vector, a::Vector, b::Vector,
               k::Vector=ones(Int, length(n)))
    per_dim = map(SplineParams, n, a, b, k)
    return Basis(per_dim...)::Basis{length(n)}
end

# ----------------- #
# Basis API methods #
# ----------------- #

# separating Basis -- just re construct it from the nth set of params
# Return the one-dimensional Basis for dimension `n`, where 1 ≤ n ≤ N.
# Throws an ErrorException for any other `n`.
function Base.getindex(basis::Basis{N}, n::Int) where N
    # Check the full range explicitly: `&&` binds tighter than `||`, so
    # `n < 0 || n > N && error(...)` never errors for n ≤ 0.
    1 <= n <= N || error("n must be between 1 and $N")
    return Basis(basis.params[n])::Basis{1}
end

# True unless some per-dimension params type reports itself as non-sparse.
function _all_sparse(b::Basis{N,TP}) where {N,TP}
    return !any(T -> !issparse(T), TP.parameters)
end

# other AbstractArray like methods for Basis
# Product of the per-dimension params lengths.
function Base.length(b::Basis)
    return prod(length, b.params)
end
# Size along dimension `i`: length of the one-dimensional Basis `b[i]`.
function Base.size(b::Basis, i::Int)
    return length(b[i])
end
# Size tuple: one params length per dimension.
function Base.size(b::Basis{N}) where {N}
    return length.(b.params)
end

# Bases of different dimension can't be equal
(==)(::Basis{N}, ::Basis{M}) where {N,M} = false

# Bases of the same dimension are equal when every field compares `==`.
function (==)(b1::Basis{N}, b2::Basis{N}) where {N}
    return all(nm -> getfield(b1, nm) == getfield(b2, nm), fieldnames(Basis))::Bool
end

# One-dimensional case: the node vector is also the only coordinate vector.
function nodes(b::Basis{1})
    pts = nodes(first(b.params))
    return pts, (pts,)
end

function nodes(b::Basis)  # funnode method
    # Per-dimension node vectors, then their grid via `gridmake`.
    coords = map(nodes, b.params)
    return gridmake(coords...), coords
end

@generated function bmat_type(::Type{TO}, bm::Basis{N,TP}, x=1.0) where {N,TP,TO}
    # Inside the generator `x` is the *type* of the evaluation input.
    # Start from the first dimension's matrix type; whenever another dimension
    # disagrees, widen to an AbstractMatrix of the promoted eltype.
    ptypes = TP.parameters
    out = bmat_type(TO, ptypes[1], x)
    N == 1 && return :($out)
    for this_TP in ptypes[2:end]
        this_out = bmat_type(TO, this_TP, x)
        this_out == out && continue
        out = AbstractMatrix{promote_type(eltype(out), eltype(this_out))}
    end
    return :($out)
end

# No output-type hint: use `Nothing`, with `x` defaulting to 1.0 as in the
# generated method.
bmat_type(b::Basis, x=1.0) = bmat_type(Nothing, b, x)


LoadError: UndefVarError: Cheb not defined