In [19]:
using Pkg; Pkg.activate("../../../../")
using Revise
using CairoMakie
using CausalityTools
using Distributions: MvNormal
using Statistics
using LinearAlgebra
using StateSpaceSets


[32m[1m  Activating[22m[39m project at `~/Code/Repos/Temp/CausalityTools.jl`


In [20]:
# Vejmelka-Paluš example


# Example from the Vejmelka & Paluš paper.
using Distributions: Normal
using Statistics: cor
using LinearAlgebra: eigvals, diagm
"""
    sunspot_model(n::Int; Ttr = 100, ϵ, base = 2, α₁, α₂, β₁, β₂) -> (z1, z2)

Simulate two ARMA(2, 2) processes driven by independent standard-normal noise, where
`z1` evolves autonomously and `z2` is driven by `z1` through its lag-1 term:

    z1[i] = α₁*z1[i-1] + α₂*z1[i-2] + a1[i] - β₁*a1[i-1] - β₂*a1[i-2]
    z2[i] = α₁*(ϵ*z1[i-1] + (1-ϵ)*z2[i-1]) + α₂*z2[i-2] + a2[i] - β₁*a2[i-1] - β₂*a2[i-2]

# Arguments
- `n`: sets the simulated length (`n + Ttr` samples are generated).

# Keywords
- `Ttr`: number of initial (transient) samples to discard. Must be non-negative.
- `ϵ`: coupling weight; `ϵ = 0` makes `z2` independent of `z1`.
- `base`: accepted for backward compatibility; it is not used by the simulation.
- `α₁`, `α₂`: autoregressive coefficients.
- `β₁`, `β₂`: moving-average coefficients.

# Returns
A tuple `(z1, z2)` of two vectors, each holding the last `n - 3` simulated samples
(i.e. everything after the first `Ttr + 3` samples).

# Throws
- `ArgumentError` if `Ttr < 0` or `n + Ttr < 2`.
"""
function sunspot_model(n::Int; Ttr = 100, ϵ, base = 2,
        α₁ = 1.90694, α₂ = -0.98751,
        β₁ = 0.78512, β₂ = -0.40662)
    Ttr ≥ 0 || throw(ArgumentError("Ttr must be non-negative, got $Ttr"))
    total = n + Ttr
    # Two initial conditions are written unconditionally below.
    total ≥ 2 || throw(ArgumentError("n + Ttr must be at least 2, got $total"))

    # Independent standard-normal innovations for each process.
    a1 = randn(total)
    a2 = randn(total)

    z1 = zeros(total)
    z2 = zeros(total)
    # Random initial conditions for the two-step recursion.
    z1[1], z1[2] = randn(), randn()
    z2[1], z2[2] = randn(), randn()

    for i in 3:total
        z1[i] = α₁*z1[i-1] + α₂*z1[i-2] + a1[i] - β₁*a1[i-1] - β₂*a1[i-2]
        z2[i] = α₁*(ϵ*z1[i-1] + (1 - ϵ)*z2[i-1]) +
            α₂*z2[i-2] + a2[i] - β₁*a2[i-1] - β₂*a2[i-2]
    end

    # Drop the transient from the START of the series (plus the first three samples,
    # as before), so the output length stays `n - 3` for every `Ttr`.
    keep = (Ttr + 4):total
    return z1[keep], z2[keep]
end

sunspot_model (generic function with 1 method)

In [398]:
# Render figures as SVG in the notebook.
CairoMakie.activate!(type = "svg")

# Simulate the two coupled series with coupling weight ϵ = 0.1 and plot them together.
x, y = sunspot_model(10000, ϵ = 0.1)
f = Figure()
ax = Axis(f[1, 1], xlabel = "Time", ylabel = "Value")
lines!(ax, x, label = "x")
lines!(ax, y, label = "y")
# The `label`s above are only displayed once a legend is attached to the axis.
axislegend(ax)
f

In [408]:
# Simulate a more strongly coupled pair (ϵ = 0.3) and estimate a Shannon conditional
# mutual information (natural-log base) with the `VejmelkaPalus` estimator, k = 5.
x, y = sunspot_model(5000, ϵ = 0.3)
let measure = Shannon(; base = ℯ), estimator = VejmelkaPalus(k = 5)
    # Inputs are, in order: x without its last sample, y without its last sample,
    # and y shifted one step ahead.
    # NOTE(review): if the third argument is the conditioning variable, this is
    # CMI(x_t; y_t | y_{t+1}) rather than CMI(x_t; y_{t+1} | y_t) — confirm intent.
    condmutualinfo(measure, estimator, x[begin:end-1], y[begin:end-1], y[begin+1:end])
end

0.11368398853393624

In [393]:
# Benchmark: compare conditional mutual information (CMI) estimators against the
# analytical CMI of a 3-dimensional Gaussian with a random covariance matrix.
D = 3
μ = zeros(D)
# A fixed covariance that can be used instead of the random one below:
#Σ = [1.496 0.453 0.88; 0.453 0.221 0.166; 0.88 0.166 0.639]
# Random covariance A*Aᵀ. The scratch matrix is kept local so that the global `x`
# (the time series from the earlier cells) is not overwritten.
Σ = let A = rand(D, D)
    A * transpose(A)
end
N = MvNormal(μ, Σ)
ix, iy, iz = 1, 2, 3
iyz = [iy, iz]
ixz = [ix, iz]

# Gaussian mutual informations from (sub-)covariance determinants, in nats:
#   I(X; YZ) = ½ log(|Σx| |Σyz| / |Σ|),   I(X; Z) = ½ log(|Σx| |Σz| / |Σxz|)
# and, by the chain rule, I(X; Y | Z) = I(X; YZ) - I(X; Z).
iX_YZ = 0.5*log(det(Σ[ix, ix]) * det(Σ[iyz, iyz]) / det(Σ))
iX_Z = 0.5*log(det(Σ[ix, ix]) * det(Σ[iz, iz]) / det(Σ[ixz, ixz]))
true_condmutualinfo = iX_YZ - iX_Z

n = 1000                      # samples per replicate
k = max(isqrt(n) ÷ 3, 10)     # neighbor count: ⌊√n / 3⌋, but at least 10
nreps = 10                    # replicates per estimator

estimators = [
    Kraskov(; k),
    KozachenkoLeonenko(),
    Zhu(; k),
    ZhuSingh(; k),
    GaoNaive(; k),
    GaoNaiveCorrected(; k),
    Lord(; k),
    KSG1(; k),
    KSG2(; k),
    Gao2018(; k),
]

# Unqualified type name of each estimator, for labelling results.
estimator_names = [string(nameof(typeof(est))) for est in estimators]

# estimates[j, i] is the relative error of estimator i on replicate j.
estimates = zeros(nreps, length(estimators))
for (i, est) in enumerate(estimators)
    for j in 1:nreps
        # A fresh sample of n points is drawn for every (estimator, replicate) pair.
        data = Dataset([rand(N) for _ in 1:n])
        X, Y, Z = data[:, ix], data[:, iy], data[:, iz]
        # NOTE(review): another cell in this file calls `condmutualinfo` with a leading
        # `Shannon(; base = ℯ)` argument instead of a `base` keyword — confirm which
        # signature the installed CausalityTools version expects.
        cmi = condmutualinfo(est, X, Y, Z; base = ℯ)
        estimates[j, i] = (cmi - true_condmutualinfo) / true_condmutualinfo
    end
end

In [394]:
# Box plot of the relative CMI estimation error, one box per estimator.
fig = Figure()
ax = Axis(fig[1, 1];
    ylabel = "Relative error",
    xticks = (1:length(estimators), estimator_names),
    xticklabelrotation = π / 4,
)
for i in eachindex(estimators)
    # One position entry per observation, so all replicates fall into the same box.
    boxplot!(ax, fill(i, size(estimates, 1)), estimates[:, i])
end
# `estimates` holds (estimate - true) / true, so the true CMI corresponds to 0 here,
# not to the value of `true_condmutualinfo`.
hlines!(ax, [0.0]; label = "True CMI (zero relative error)")
axislegend(ax)
fig