- Estimate probability densities at all the points $p(x_i)$ in the sample by finding distances to the $k$ nearest neighbors of each $x_i$.
- Average logarithm of these densities are found and adjusted by bias correction terms.

## Conditional TE for point processes.

Originally bivariate by Spinney.

$$
\bf \dot{T}_{Y \to X} = \lim_{t \to \infty} \dfrac{1}{\tau} \sum_{i=1}^{N_X} 
\log{
    \dfrac{\lambda_{x | {\bf x_{<t}}, {\bf y_{<t}} } [{\bf x_{<x_i}}, {\bf y_{<x_i}}]}{\lambda_{x | {\bf x_{<t}}} [\bf x_{<x_i>}]}
}
$$

- $X$ and $Y$ are series of *time points* $x_i$ and $y_j$ of the events $i$ and $j$ in the target and source, respectively.
- $N_X$ is the number of events in the target process.
- $\tau$ is the length in time of this process.
- $\lambda_{x | {\bf x_{<t}}, {\bf y_{<t}} } [{\bf x_{<x_i}}, {\bf y_{<x_i}}]$ is the instantaneous firing rate of the target, conditioned on the histories of the target $\bf x_{<x_i}$ and source $\bf y_{<x_i}$ at time points $x_i$ of the target process.
- $[\bf x_{<x_i>}]$ is the instantaneous firing rate of the target, conditioned on on its history alone.


## Continuous-time TE as sum of differential entropies


Assumptions:
- $X \in \mathbb{R}^{N_X}$ and $Y \in \mathbb{R}^{N_X}$ are two point processes, where each element represents the time of an event. 
- The set of extra conditioning processes are $\bf \mathcal{Z} = \{ Z_1, Z_2, \ldots, Z_{n_\mathcal{Z}} \}$.
- Define a *counting process* $\bf N_X(t)$ on $X$.
-  $N \in \mathbb{N}$ represents the 'state' of the process, incremented by one at the occurrence of an event.
- The instantaneous firing rate is $$ \lambda_X(t) = \lim_{\Delta t \to 0} \dfrac{p(\bf N_X(T + \Delta t) - \bf N_{X}(t) = 1)}{\Delta t}$$

- Assume an unknown probability distribution $\mu(\bf x)$ with $\bf x \in \mathbb{R}^d$.

Instantaneous firing rate:




$$
\bf \dot{T}_{Y \to X | \mathcal{Z}} = \bar{\lambda}_X \lim_{\Delta t \to 0} \mathbb{E}_{P_X} 
\left[
    \log{
        \dfrac{p_U(\bf N_X(x + \Delta t) - \bf N_X(x) = 1 | \bf x_{<x}, \bf y_{<x}, \bf \mathcal{Z}_{<x}) }{p_U(\bf N_X(x + \Delta t) - \bf{N}_X(x) = 1 | \bf x_{<x}, \bf \mathcal{z}_{<x})}
    }
\right]
$$


In [1]:
using Revise
using Pkg;
Pkg.activate("../")
#Pkg.free("DelayEmbeddings")
#Pkg.add(url = "https://github.com/JuliaDynamics/DelayEmbeddings.jl", 
#    rev = "hcat_multiple_datasets")
#Pkg.add(url = "https://github.com/JuliaDynamics/Entropies.jl")

using DelayEmbeddings, Entropies


[32m[1m  Activating[22m[39m project at `~/Code/Repos/Temp/TransferEntropy.jl`


In [2]:
using StaticArrays, DelayEmbeddings

"""
    EventIdentifier

An abstract type indicating some sort of event.
"""
abstract type EventIdentifier end

"""
    OneSpike <: EventIdentifier
    OneSpike()

Events are indicated by the presence of `1`s or `true`s.

## Example

```julia
# Events occur at indices 2, 4, and 6.
x = [0, 1, 0, 1, 0, 1]

# Events again occur at indices 2, 4, and 6.
x = [-2, 1, 5, 1, 2, 1]
```
"""
struct OneSpike <: EventIdentifier end

""" 
    first_index(x::AbstractVector{T}, m::Int) where T

Find the first time index for which a continuous-time embedding with history 
length `m` can be constructed.
"""
function first_index(x::AbstractVector{T}, event::EventIdentifier = OneSpike(); 
        m::Int = 2) where T

    spike = one(T)

    i::Int = 1
    ct::Int = 0
    while ct < m && i < length(y)
        if x[i] == spike
            ct += 1
        end
        i += 1
    end

    ct >= m || throw(ErrorException("Could not find m=$m spikes in `x`"))

    return i
end

function first_index(x::AbstractDataset{D, T}, event::EventIdentifier = OneSpike(); 
        m::Int = 2) where {D, T}
    
    spike = one(T)
    L = length(x)
    cts::MVector{D, Int} = zeros(MVector{D, Int})
    
    i::Int = 1
    while any(cts .< m) && i < L
        for d = 1:D
            if x[i][d] == spike
                cts[d] += 1
            end
        end
        i += 1
    end
    
    all(cts .>= m) || 
        throw(ErrorException("Couldn't find m=$m spikes in one or more columns of `x`"))
    return i
end

first_index(xs; kwargs...) = map(v -> first_index.(v; kwargs...), xs)

using Test
using DelayEmbeddings
z = [1, 0, 1, 0, 0, 1, 1, 0, 1, 0]
x = [0, 1, 0, 1, 0, 1, 0, 0, 0, 1]
y = [1, 1, 0, 0, 1, 0, 1, 0, 0, 1]
D = Dataset(x, y, z)
@test first_index(z, m = 2) == 4
@test first_index(z, m = 3) == 7
@test first_index(z, m = 4) == 8
@test first_index(D, m = 2) == 5
@test first_index(D, m = 3) == 7
@test_throws ErrorException first_index(D, m = 4)

[32m[1mTest Passed[22m[39m
      Thrown: ErrorException

In [31]:
"""
    spikeembed(i, iₜmin, x; spike = one(eltype(x)), m::Int = 2) where T <: Number

Given scalar-valued `x`, and history length `m`, make a *continuous-time* embedding,
by counting the duration between spikes relative to time index `iₜ`. 

`iₜmin` is the  minimum index for which to start checking, to ensure that a 
length-`m` embedding vector is actually possible to construct. This must have been 
checked beforehand by using [`first_index`](@ref), otherwise errors will be thrown.

See Fig 10 in Shorten et al. (2021) for details.

!!! note "Availability of length-`m` embedding vectors" 
    This function assumes that there are at least `m` spikes available before time index 
    `iₜ`. This function does not check for the existence of such spikes. Pre-check by using 
    [`first_index`](@ref). If no spikes exist to form an embedding, an error will be thrown.

[^Shorten2021]: Shorten, D. P., Spinney, R. E., & Lizier, J. T. (2021). Estimating transfer
    entropy in continuous time between neural spike trains or other event-based data. PLoS 
    computational biology, 17(4), e1008054.
"""
function spikeembed(target::AbstractVector{T}, event::EventIdentifier = OneSpike(); 
            m::Int = 2, # The same embedding for all timeseries processes for now.
            ) where T
    spike_identifier = one(T)

    # We can't know a priori how many elements this vector will contain,
    # because we're counting time *intervals* relative to some tᵢ, not 
    # *values at particular times relative to tᵢ*.
    emb = Vector{SVector{m, T}}(undef, 0)
    iₜmin = first_index(target; m)

    # Pre-allocate single vector to avoid excessive allocations, and convert
    # back to SVector inside the inner function. If the compiler is feeling fine, then 
    # this *could* incur no extra cost.
    p = zeros(MVector{m, T})

    L = length(target)
    for iₜ in iₜmin:L
        push!(emb, embed_at_iₜ_event!(p, iₜ, target, spike_identifier; m))
    end
    return emb
end

using BenchmarkTools

# Embeddings vectiors constructed at *all* times (independent of target spiking events).
function embed_at_alltimes(x::AbstractDataset{D, T}, event::EventIdentifier = OneSpike(); 
        m::Int = 2) where {D, T}
    spike_identifier = one(T)

    # Here, we don't care about the timing of spikes in the target time series, 
    # so construct embedding points for all possible time indices `iₜ`
    iₜmin = first_index(x; m)
    L = length(x)

    # After finding the minimum required index for `m`-length history vectors to
    # exist, we can also know the size of the embeddings.
    embeddings = [Vector{SVector{m, Int}}(undef, (L - iₜmin + 1)) for d = 1:D]

    # Pre-allocate single vector to avoid excessive allocations, and convert
    # back to SVector inside the inner function. If the compiler is feeling fine, then 
    # this *could* incur no extra cost.
    p = zeros(MVector{m, T})

    for d = 1:D
        ts = x[:, SVector{1, Int}(d)]
        for (i, iₜ) in enumerate(iₜmin:L)
            embeddings[d][i] = embed_at_iₜ_event!(p, iₜ, ts, spike_identifier; m)
        end
    end
    
    return Dataset.(embeddings)
end


# Embeddings constructed only *at* target spikes.
function embed_at_targetspikes(x::AbstractDataset{D, T}, event::EventIdentifier = OneSpike(); 
        m::Int = 2) where {D, T}
    spike_identifier = one(T)

    # Spikes in the *target* time series control which time indices for which
    # we construct embeddings.
    target = x[:, 1]
    inds_spikes = findall(target .== spike_identifier)

    # Ensure that all time series have enough spikes to be embedded.
    iₜmin = first_index(x; m)
    inds_spikes_valid = inds_spikes[inds_spikes .> iₜmin]
    L = length(inds_spikes_valid)
    
    # After finding the minimum required index for `m`-length history vectors to
    # exist, we can also know the size of the embeddings.
    embeddings = [Vector{SVector{m, Int}}(undef, L) for d = 1:D]

    # Pre-allocate single vector to avoid excessive allocations, and convert
    # back to SVector inside the inner function. If the compiler is feeling fine, then 
    # this *could* incur no extra cost.
    p = zeros(MVector{m, T})

    for d in 1:D
        ts = x[:, SVector{1, Int}(d)]
        for (i, iₜ) in enumerate(inds_spikes_valid)
            embeddings[d][i] = embed_at_iₜ_event!(p, iₜ, ts, spike_identifier; m)
        end
    end
    return Dataset.(embeddings)
end

function spike_embeddings(x::AbstractDataset{D, T}, event::EventIdentifier = OneSpike(); 
        m::Int = 2) where {D, T}
    Eₜ = embed_at_targetspikes(x, event; m)
    Eᵤ = embed_at_alltimes(x, event; m)
    if D <= 2
        vars_cond = 1:2
    else
        vars_cond = [1; 3:length(Eₜ)]
    end
    # Joint embedding vector sets J include all variables, while conditional embedding sets
    # vector sets C excludes the source variable, which is assumed to always be 
    # in position 2.
    Jₓ = hcat(Eₜ...)
    Cₓ = hcat(Eₜ[vars_cond]...)
    Jᵤ = hcat(Eᵤ...)
    Cᵤ = hcat(Eᵤ[vars_cond]...)

    return Jₓ, Cₓ, Jᵤ, Cᵤ
end

# Given a time index iₜ, find the `m` first intra-spike time intervals between
# t = t(iₜ) and t = t(1). Store the result in the pre-allocated MVector `v`.
# Spikes are identified using `spike_identifier` (usually `1`), but can
# be any identifier. Assumes `x` is a view of the [1:iₜ] first elements of x (Vector{StaticVector}).
function embed_at_iₜ_event!(v, iₜ, x, event_identifier; m = 2)
    n_found::Int = 0 # The number of embedding entries found.
    n_checked::Int = 0 # How many checked since last spike was found?
    k::Int = iₜ
    
    @inbounds while n_found < m
        if x[k - 1][] == event_identifier
            n_found += 1
            v[n_found] = n_checked
            n_checked = 0 # Restart counting from the current spike.
        else
            n_checked += 1
        end
        k -= 1
    end

    return SVector(v)
end

using Test
using DelayEmbeddings
zt = [1, 0, 1, 0, 0, 1, 1, 0, 1, 0]
xt = [0, 1, 0, 1, 0, 1, 0, 0, 0, 1]
yt = [1, 1, 0, 0, 1, 0, 1, 0, 0, 1]
Dt = Dataset(xt, yt, zt)

# If we encounter two successive spikes, then that is counted as a 0 time interval.
expected_m2 = [(0, 1), (1, 1), (2, 1), (0, 2), (0, 0), (1, 0), (0, 1)]
expected_m3 = [(0, 2, 1), (0, 0, 2), (1, 0, 2), (0, 1, 0)]
expected_m4 = [(0, 0, 2, 1), (1, 0, 2, 1), (0, 1, 0, 2)]
@test spikeembed(zt, m = 2) == SVector{2, Int}.(expected_m2)
@test spikeembed(zt, m = 3) == SVector{3, Int}.(expected_m3)
@test spikeembed(zt, m = 4) == SVector{4, Int}.(expected_m4)

# At all time_points
@test embed_at_alltimes(Dt; m = 2) == Dataset.([
    SVector{2, Int32}.([[0, 1], [1, 1], [0, 1], [1, 1], [2, 1], [3, 1]]),
    SVector{2, Int32}.([[2, 0], [0, 2], [1, 2], [0, 1], [1, 1], [2, 1]]),
    SVector{2, Int32}.([[1, 1], [2, 1], [0, 2], [0, 0], [1, 0], [0, 1]]),
])

# At events for specific time series
@test embed_at_targetspikes(Dt; m = 2) == Dataset.([
    SVector{2, Int32}.([[1, 1], [3, 1]]),
    SVector{2, Int32}.([[0, 2], [2, 1]]),
    SVector{2, Int32}.([[2, 1], [0, 1]]),
])

[32m[1mTest Passed[22m[39m

In [32]:
# x = rand([0, 1], 100)
# y = rand([0, 1], 100)
# z = rand([0, 1], 100)
# w = rand([0, 1], 100);
# D = Dataset(x, y, z, w);

450.708 μs (6506 allocations: 343.36 KiB)


In [33]:
# 117.083 μs (1887 allocations: 101.22 KiB)
# 138.125 μs (2007 allocations: 173.72 KiB)
# 160.166 μs (2587 allocations: 219.38 KiB)
# 139.208 μs (2199 allocations: 110.89 KiB)
# 127.417 μs (1911 allocations: 97.00 KiB)
# 24.708 μs (655 allocations: 24.06 KiB)

embed_at_targetspikes(D)
@btime embed_at_targetspikes($D)

  25.625 μs (655 allocations: 24.06 KiB)


4-element Vector{Dataset{2, Int64}}:
 2-dimensional Dataset{Int64} with 49 points
 2-dimensional Dataset{Int64} with 49 points
 2-dimensional Dataset{Int64} with 49 points
 2-dimensional Dataset{Int64} with 49 points

In [34]:
embed_at_alltimes(D)
@btime embed_at_alltimes($D)

  40.500 μs (1176 allocations: 35.94 KiB)


4-element Vector{Dataset{2, Int64}}:
 2-dimensional Dataset{Int64} with 93 points
 2-dimensional Dataset{Int64} with 93 points
 2-dimensional Dataset{Int64} with 93 points
 2-dimensional Dataset{Int64} with 93 points

In [35]:
# 430.875 μs (6414 allocations: 579.16 KiB)
# 437.625 μs (6942 allocations: 357.48 KiB)
# 73.209 μs (1886 allocations: 77.47 KiB)

spike_embeddings(D, m = 2)
@btime spike_embeddings($D; m = 2)

  71.708 μs (1886 allocations: 77.47 KiB)


(8-dimensional Dataset{Int64} with 49 points, 6-dimensional Dataset{Int64} with 49 points, 8-dimensional Dataset{Int64} with 93 points, 6-dimensional Dataset{Int64} with 93 points)

In [36]:
a,b,c,d = spike_embeddings(D)

(8-dimensional Dataset{Int64} with 49 points, 6-dimensional Dataset{Int64} with 49 points, 8-dimensional Dataset{Int64} with 93 points, 6-dimensional Dataset{Int64} with 93 points)

In [22]:
eᵤ = embed_at_alltimes(D; m = 2)


3-element Vector{Dataset{2, Int32}}:
 2-dimensional Dataset{Int32} with 6 points
 2-dimensional Dataset{Int32} with 6 points
 2-dimensional Dataset{Int32} with 6 points

In [23]:
eₓ = embed_at_targetspikes(D; m = 2)

3-element Vector{Dataset{2, Int32}}:
 2-dimensional Dataset{Int32} with 2 points
 2-dimensional Dataset{Int32} with 2 points
 2-dimensional Dataset{Int32} with 2 points

In [None]:
Jₓ, Cₓ, Jᵤ, Cᵤ = spike_embeddings(D)

(6-dimensional Dataset{Int32} with 2 points, 4-dimensional Dataset{Int32} with 2 points, 6-dimensional Dataset{Int32} with 6 points, 4-dimensional Dataset{Int32} with 6 points)

In [45]:

using TransferEntropy
function transferentropy_event(e::Entropy, 
        event::EventIdentifier, 
        est::ProbabilitiesEstimator,
        target::V, source::V, cond::Vararg{V};
        m = 2) where V <: AbstractVector
    
    D = Dataset(target, source, cond...)

    # Construct embeddings, given that the events are identified by `event`.
    Jₓ, Cₓ, Jᵤ, Cᵤ = spike_embeddings(D, event; m)
    return entropy(e, Cₓ, est) +
        entropy(e, Jₓ, est) +
        entropy(e, Jᵤ, est) -
        entropy(e, Cᵤ, est)
end

transferentropy_event(event::EventIdentifier, est::ProbabilitiesEstimator,
    target::V, source::V, cond::Vararg{V};
    m = 2, base = 2) where {V <: AbstractVector} = 
        transferentropy_event(Shannon(; base), event, est, target, source, cond; m)

est = CountOccurrences()


CountOccurrences()

In [49]:
# 93.125 μs (1921 allocations: 73.53 KiB)
# 77.541 μs (1869 allocations: 99.50 KiB)
# 68.875 μs (1833 allocations: 116.69 KiB)
using BenchmarkTools
transferentropy_event(Shannon(), OneSpike(), est, y, x, z, w, m = 4)
@btime te = transferentropy_event(Shannon(), OneSpike(), $est, $y, $x, $z, $w, m = 4)

  73.375 μs (1773 allocations: 139.50 KiB)


10.715104009236162

In [53]:
source = 1e3*rand(Int(1e3));
target = 1e3*rand(Int(1e3));
transferentropy_event(Shannon(), OneSpike(), CountOccurrences(), target, source, m = 2)

ErrorException: Couldn't find m=2 spikes in one or more columns of `x`

In [60]:
source = 1e3*rand(Int(1e3));
target = 1e3*rand(Int(1e3));
sort!(source);
sort!(target);

In [62]:
using Statistics
source


1000-element Vector{Float64}:
   0.9380651354707892
   1.762270777116215
   3.0441187211794896
   3.493147575191924
   4.527487081727566
   5.089478075968379
   5.699608150263935
   7.256668560836954
   8.298347408803973
   9.69650134824429
   ⋮
 990.7641153494266
 992.1778901771531
 993.2092476329424
 993.4613753140692
 994.122249662727
 994.408453970347
 994.6669998126573
 995.0014753162362
 997.5872382887875

In [63]:
function thin_target(source, target, target_rate)
    start_index = 1
    while target[start_index] < source[1]
         start_index += 1
    end
    target = target[start_index:end]
    new_target = Float64[]
    index_of_last_source = 1
    for event in target
        while index_of_last_source < length(source) && source[index_of_last_source + 1] < event
                 index_of_last_source += 1
        end
        distance_to_last_source = event - source[index_of_last_source]
        lambda = 0.5 + 5exp(-50(distance_to_last_source - 0.5)^2) - 5exp(-50(-0.5)^2)
        if rand() < lambda/target_rate
              push!(new_target, event)
        end
    end
    return new_target
end

thin_target (generic function with 1 method)

In [76]:
source = sort(1e4*rand(Int(1e4)));
target = sort(1e4*rand(Int(1e5)));
target_thinned = thin_target(source, target, 10)

spikeembed(target)


ErrorException: Could not find m=2 spikes in `x`