In [6]:
using Pkg; Pkg.activate("/Users/work/Code/Repos/Temp/CausalityTools.jl")
using Revise, CausalityTools

[32m[1m  Activating[22m[39m project at `~/Code/Repos/Temp/CausalityTools.jl`


In [281]:
using Neighborhood: Euclidean, Chebyshev, KDTree, Theiler, NeighborNumber
using Neighborhood: bulksearch
using Distances: evaluate
using DelayEmbeddings.StateSpaceSets: SubStateSpaceSet
using LinearAlgebra: det, norm
using StateSpaceSets: StateSpaceSet
using StaticArrays: MVector, MMatrix, SVector, SMatrix

import Entropies: entropy

"""
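    Gao2017 <: EntropyEstimator
    Gao2017(; k = 1, w = 0, base = 2, metric = Euclidean())

A resubstitution estimator from Gao et al. (2017)[^Gao2017]. Can be used for entropy
estimation (TODO: the original sentence named a second use case but was left unfinished).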

[^Gao2017]: Gao, W., Oh, S., & Viswanath, P. (2017, June). Density functional estimators
    with k-nearest neighbor bandwidths. In 2017 IEEE International Symposium on Information
    Theory (ISIT) (pp. 1351-1355). IEEE.
"""
Base.@kwdef struct Gao2017{B, M} #<: CausalityTools.InformationEstimator
    # Number of nearest neighbors requested per point in the neighbor search.
    k::Int = 1
    # Theiler window handed to the neighbor search (`Theiler(w)`).
    w::Int = 0
    # Presumably the logarithm base of the estimate; it is destructured by `Î` but not
    # used by any code visible here — TODO confirm.
    base::B = 2
    # Metric used to build the search tree.
    metric::M = Euclidean()
end

# Work-in-progress estimator for `Gao2017`. As written it only builds a search tree over
# `x` and queries the `k` nearest neighbors of every point; no estimate is computed yet.
#
# NOTE(review): `q`, `N`, `base`, `Bk`, `d`, `α`, `K` and `D` are never used below.
# NOTE(review): no one-argument `bias` method is visible in this file (only
# `bias(q, est::LocalLikelihood, x)` and `multiplicative_bias(est::Gao2017)`), so the
# `bias(est)` call looks like it throws a `MethodError` unless such a method is defined
# elsewhere — TODO confirm.
function Î(q, est::Gao2017, x::AbstractStateSpaceSet{D}) where D
    (; k, w, base, metric) = est
    N = length(x)
    tree = KDTree(x, metric)
    # Expects `bias(est)` to return four values (see the review note above).
    Bk,d,α,K = bias(est)
    # Neighbor indices and distances for every point of `x`, excluding neighbors inside
    # the Theiler window `w`. This assignment is the last statement, so its right-hand
    # side is currently what the function evaluates to.
    idxs, ds = bulksearch(tree, x, NeighborNumber(k), Theiler(w))

end

"""
    multiplicative_bias(est::Gao2017)

Placeholder for the multiplicative bias correction of the `Gao2017` estimator.
Always returns `1.0`.
"""
function multiplicative_bias(est::Gao2017)
    # TODO: implement
    return 1.0
end

"""
    LocalLikelihood <: ProbabilitiesEstimator
    LocalLikelihood(; k = 5, w = 0, metric = Euclidean())

Settings for the local likelihood density estimator implemented by `point_densities`.

- `k`: the bandwidth of each point is the distance to its `k`-th nearest neighbor.
- `w`: Theiler window handed to the neighbor search (`Theiler(w)`).
- `metric`: metric used to evaluate the distance between a point and each of its
    neighbors when the neighbors are weighted.
"""
Base.@kwdef struct LocalLikelihood{M} <: ProbabilitiesEstimator
    k::Int = 5
    w::Int = 0
    metric::M = Euclidean()
end

"""
    point_densities(est::LocalLikelihood, x::AbstractStateSpaceSet{D}) → Vector{Float64}

Estimate a local density at every point of `x` and return the estimates in the same
order as the points. For each point, the neighborhood is found with a tree search, the
bandwidth is the distance to the `est.k`-th nearest neighbor, and the density itself is
computed by `point_density!`.
"""
function point_densities(est::LocalLikelihood, x::AbstractStateSpaceSet{D}) where D
    (; k, w, metric) = est
    N = length(x)
    # Modified heuristic from Gao et al. (2017): it is sufficient to consider the
    # `kmax = max(floor(Int, log(N)), k)` nearest neighbors of `x[i]` when estimating
    # the local density. A global point-search is pointless and expensive.
    kmax = max(floor(Int, log(N)), k)

    # Search in the estimator's own metric, so that the bandwidths below and the kernel
    # weights computed in `point_density!` are measured in the same metric.
    tree = KDTree(x, metric)
    idxs, ds = bulksearch(tree, x, NeighborNumber(kmax), Theiler(w))

    # The bandwidth `bws[i]` for the point `x[i]` is the distance to the `k`-th nearest
    # neighbor of `x[i]` (assumes each `ds[i]` is sorted by increasing distance —
    # TODO confirm).
    bws = [d[k] for d in ds]
    densities = zeros(N)

    # Work buffers shared by all calls to `point_density!`.
    S₁ = zeros(MVector{D, Float64})
    S₂ = zeros(MMatrix{D, D, Float64})

    for i in eachindex(densities)
        neighborsᵢ = @views x[idxs[i]]
        densities[i] = point_density!(S₁, S₂, est, x[i], bws[i], neighborsᵢ)
    end
    return densities
end

"""
    point_density!(S₁, S₂, est::LocalLikelihood, xᵢ, bwᵢ, 
        neighborsᵢ::AbstractStateSpaceSet{D}) where D

Estimate the density around point `xᵢ` using a local likelihood estimator, which is
a generalization of kernel density estimation. This is done by fitting a local Gaussian
distribution around `xᵢ` from its local neighborhood (represented by the points
`neighborsᵢ`). The bandwidth `bwᵢ` is given by the distance from `xᵢ` to its `k`-th
nearest neighbor.

`S₁` is a pre-allocated length-`D` vector and `S₂` is a pre-allocated `D`-by-`D` matrix.
They are intended as work buffers for the weighted first- and second-moment sums from
which the local mean and covariance are computed. Both `S₁` and `S₂` are zeroed every
time `point_density!` is called.
"""
function point_density!(S₁, S₂, est::LocalLikelihood, xᵢ, bwᵢ, 
        neighborsᵢ::AbstractStateSpaceSet{D}) where D
    # Local Gaussian (local likelihood) density estimate at `xᵢ`.
    #
    # Arguments:
    # - `S₁`, `S₂`: mutable work buffers (length `D` and `D`-by-`D`). They are zeroed
    #   here and then accumulate the weighted first- and second-moment sums in place.
    # - `est`: supplies the metric used for the kernel weights.
    # - `xᵢ`: the query point; `bwᵢ`: its bandwidth; `neighborsᵢ`: its neighbors.
    #
    # Returns the density estimate as a `Float64`, or `0.0` when the local covariance
    # is not positive definite or not finite (e.g. zero bandwidth, degenerate
    # neighborhood, or all kernel weights underflowing to zero).
    #
    # NOTE(review): `N` is the number of neighbors passed in. Gao et al. appear to
    # normalise by the total sample size instead; this only changes the result by a
    # constant factor, which would cancel if the densities are normalised afterwards —
    # TODO confirm.
    N = length(neighborsᵢ)
    S₀ = 0.0
    S₁ .= 0.0
    S₂ .= 0.0

    bwᵢ_sq = bwᵢ^2
    twice_bwᵢ_sq = 2 * bwᵢ_sq
    for nⱼ in neighborsᵢ
        dⱼ = evaluate(est.metric, nⱼ, xᵢ)
        # Gaussian kernel weight: the distance enters squared.
        eⱼ = exp(-dⱼ^2 / twice_bwᵢ_sq)
        Δⱼ = nⱼ - xᵢ
        S₀ += eⱼ
        # Accumulate in place; `S₁ += ...` would only rebind the local name and leave
        # the caller's buffers untouched.
        S₁ .+= eⱼ .* Δⱼ ./ bwᵢ
        S₂ .+= eⱼ .* (Δⱼ * transpose(Δⱼ)) ./ bwᵢ_sq
    end

    # Weighted sample mean and covariance (in units of the bandwidth).
    μ = S₁ / S₀
    Σ = S₂ / S₀ - μ * transpose(μ)

    # A singular, indefinite or non-finite covariance has no usable inverse or square
    # root of the determinant, so return zero density straight away. The negated form
    # also catches NaN.
    detΣ = det(Σ)
    if !(isfinite(detΣ) && detΣ > 0)
        return 0.0
    end

    # `μ` is already normalised by `S₀`, so the quadratic form carries only the factor
    # 1/2. Solve the linear system instead of forming the explicit inverse.
    num = S₀ * exp(-0.5 * (transpose(μ) * (Σ \ μ)))
    den = N * (2π)^(D / 2) * bwᵢ^D * sqrt(detΣ)
    return num / den
end

# Probabilities/outcomes interface for `LocalLikelihood`. The probabilities are built
# from the point-wise density estimates, and the outcomes are the input points.
#
# NOTE(review): `total_outcomes` takes `(x, est)` while the other methods take
# `(est, x)` — confirm this matches the interface being implemented.

# Probabilities built from the density estimated at every point of `x`.
function probabilities(est::LocalLikelihood, x)
    return Probabilities(point_densities(est, x))
end

# The probabilities together with their outcomes (the points of `x`).
probabilities_and_outcomes(est::LocalLikelihood, x) =
    (Probabilities(point_densities(est, x)), x)

# Every input point is its own outcome.
function outcomes(est::LocalLikelihood, x)
    return x
end

# One outcome per input point.
function total_outcomes(x, est::LocalLikelihood)
    return length(x)
end

"""
    bias(q, est::LocalLikelihood, x)

Placeholder for the bias correction of the `LocalLikelihood` estimator. Ignores all of
its arguments and always returns `1.0`.
"""
function bias(q, est::LocalLikelihood, x)
    # TODO: implement. not sure how, though. Gao (2017) is not very clear...
    return 1.0
end

bias (generic function with 1 method)

In [283]:
using CairoMakie

# Ad-hoc timing run of the `LocalLikelihood` estimator on random data.

# Alternative input (disabled): 30 two-dimensional points on an integer grid with a
# tiny amount of Gaussian jitter.
#x = StateSpaceSet(rand(1.0:1.0:5.0, 30, 2) .+ randn(30, 2)*0.0001 )
# 10000 points in 5 dimensions, each coordinate drawn uniformly from [0, 1).
x = StateSpaceSet(rand(10000, 5))
# Estimator with its default settings (`k = 5`, `w = 0`, Euclidean metric).
est = LocalLikelihood()
# Alternative input (disabled): a short one-dimensional series.
#x = StateSpaceSet([0.79, 0.5, 0.45, 0.46, 0.5, 0.46, 0.03, 0.11, 0.02, 0.2, 0.03, 0.5, 0.61])
# `information` and `Shannon` are not defined in this file; presumably they are provided
# by the packages loaded above — TODO confirm. `@time` reports elapsed time and
# allocations (the first call also includes compilation).
@time information(Shannon(), est, x)


MethodError: MethodError: no method matching bias()
Closest candidates are:
  bias(!Matched::Any, !Matched::LocalLikelihood, !Matched::Any) at ~/Code/Repos/Temp/CausalityTools.jl/src/methods/infomeasures/mutualinfo/estimators/nearest_neighbors/gao2017.ipynb:129