## Hints of MTM

Look at the trials' kinematics for hints that mice might be doing something like MTM.
Fit linear regressions to the speed traces before the approach to a curve, looking for evidence of bang-bang control.
Look at Δs and Δt between the start of slowing down and the start of turning.

At the end I fit a GLM to predict deceleration from position and velocity, but it doesn't seem to work yet.

In [2]:
cd("/Users/federicoclaudi/Documents/Github/LocomotionControl/analysis/behavior")
import Pkg; 
Pkg.activate(".")

import Plots as plt
using Term
import Term: install_term_logger
import MyterialColors: salmon, green_dark
using Statistics
using AlgebraOfGraphics, CairoMakie
import AlgebraOfGraphics: density
import AlgebraOfGraphics as AOG
using Printf
using HypothesisTests
import DataFrames: DataFrame

using jcontrol
import jcontrol.comparisons: ComparisonPoints
import jcontrol.visuals as viz
import jcontrol: movingaverage

install_term_logger();

ArgumentError: ArgumentError: Package Term [22787eb5-b846-44ae-b979-8e399b8463ab] is required but does not seem to be installed:
 - Run `Pkg.instantiate()` to install all recorded dependencies.


Load tracking data

In [None]:
# Load the trials through the project loader (`keep_n = nothing` is passed through as is).
trials = load_cached_trials(; keep_n = nothing);
@info "Loaded $(length(trials)) trials"


# Pull each per-trial field out into its own collection, one entry per trial.
S = map(t -> getfield(t, :s), trials)
X = map(t -> getfield(t, :x), trials)
Y = map(t -> getfield(t, :y), trials)
U = map(t -> getfield(t, :u), trials)
Ω = map(t -> getfield(t, :ω), trials)

# finish on `nothing` so the notebook does not echo the last collection
nothing

Define Curve type and curves positions + visualize

In [None]:
# Define curves S values

"""
    Curve(s, s0, maxdur, name, turns0, direction)

One curve of the track, described by positions along the track coordinate `s`.

# Fields
- `s`: track position of the curve; the approach window of each trial ends here.
- `s0`: a second reference position for the curve (in this notebook it is only drawn
  on the overview plot).
- `maxdur`: longest approach, measured as `n_samples / 60`, for a trial to be analysed.
- `name`: human-readable label.
- `turns0`: track position after which the start of turning is searched for.
- `direction`: `1` for left turns, `-1` for right turns.
"""
struct Curve
    s::Float64
    s0::Float64
    maxdur::Float64
    name::String
    turns0::Float64  # previously untyped (`Any`); now concrete like the other positions
    direction::Int   # 1 for left turns, -1 for right turns
end


# One entry per curve, in track order:
#   Curve(s, s0, maxdur, name, turns0, direction)
curves = [
    Curve(35, 15, 1.5, "first", 20, -1),
    Curve(86, 42, 1.5, "second", 55, -1),
    Curve(137, 98, 1.5, "third", 110, 1),
    Curve(188, 144, 1.5, "fourth", 155, 1),
]

# Comparison points on the track at every curve's `s` and `s0` position.
cpoints = ComparisonPoints(FULLTRACK, map(c -> getfield(c, :s), curves))
cpoints2 = ComparisonPoints(FULLTRACK, map(c -> getfield(c, :s0), curves))

# Overview figure: every trial's XY trace with both sets of comparison points on top.
plot = plt.plot()

for (x, y) in zip(X, Y)
    plt.plot!(
        x, y;
        color="black", alpha=.1, label=nothing, aspect_ratio=:equal,
        xlim=[-5, 45],
        ylim=[-5, 65],
    )
end
viz.draw!.(cpoints.points; color="red")
viz.draw!.(cpoints2.points; color="green")

display(plot)

### Get slowing down/turning states
For each curve get for each trial where the mouse is when it starts turning vs when it starts slowing down, and what its state is.

In [None]:
"""
    get_curve_indices(trial, s0, s1)

Select the samples of `trial` that lie between track positions `s0` and `s1`.

The window starts two samples after the last sample with `trial.s < s0` (at sample 3 when
no sample is below `s0`). It ends on the sample just before the first later sample with
`trial.s >= s1`; when `s1` is never reached it ends two samples before the end of the trial.

# Returns
- `idxs`: `Vector{Bool}` mask over `trial.s`, `true` inside the window.
- `idx`: index of the first `true` entry of `idxs`, or `nothing` when the window is empty.
"""
function get_curve_indices(trial, s0, s1)
    n = length(trial.s)

    start = something(findlast(trial.s .< s0), 1) + 2

    # `findfirst` answers relative to `start`, or `nothing` when the trial never reaches
    # `s1`. Resolve the `nothing` *before* doing arithmetic on it: the former
    # `findfirst(...) + start` raised a MethodError and made the fallback unreachable.
    hit = findfirst(>=(s1), @view trial.s[start:end])
    stop = (isnothing(hit) ? n : hit + start) - 2

    idxs = fill(false, n)
    idxs[start:stop] .= true  # an empty range (start > stop) leaves the mask all false
    return idxs, findfirst(idxs)
end


In [None]:
# Column-oriented store: one growable, initially empty vector per column name.
# "class" records the kind of event a row describes (slowing down / turning).
data_df = Dict(
    column => [] for column in (
        "curve_id", "class",
        "x", "y", "s", "u", "ω", "u̇",
        "t0", "tpeak", "t1",
        "trial_num",
    )
)


"""
    append_data(data_df, trial, I, cidx, class, u̇, t0, tpeak, t1, trial_num)

Append one row to the column store `data_df`, which is mutated in place.

The kinematic columns (`x`, `y`, `s`, `u`, `ω`) are read from `trial` at sample `I - 1`;
every other column receives the matching argument unchanged. Returns the `"trial_num"`
column, i.e. the result of the last `push!`.
"""
function append_data(data_df, trial, I, cidx, class, u̇, t0, tpeak, t1, trial_num)
    push!(data_df["curve_id"], cidx)
    push!(data_df["class"], class)

    # state of the animal at the event sample
    sample = I - 1
    for field in (:x, :y, :s, :u, :ω)
        push!(data_df[String(field)], getproperty(trial, field)[sample])
    end

    for (column, value) in ("u̇" => u̇, "t0" => t0, "tpeak" => tpeak, "t1" => t1)
        push!(data_df[column], value)
    end

    return push!(data_df["trial_num"], trial_num)
end


### Get start of slowing down

In [None]:
# Find, for every curve and every trial, when the mouse starts slowing down.
#
# The approach window of a curve runs from the previous curve's `s` (or s = 1 for the
# first curve) to the curve's own `s`. Inside it the smoothed speed trace is plotted and
# its maximum is taken as the start of slowing down; the state at that sample is stored
# in `data_df` with class "slowing".
plots = []

for cidx in eachindex(curves)
    curve = curves[cidx]
    s0 = cidx == 1 ? 1 : curves[cidx-1].s
    s1 = curve.s

    plot = plt.plot()

    peaks_s, peaks_values = [], []
    for tn in eachindex(trials)
        trial = trials[tn]
        idxs, idx = get_curve_indices(trial, s0, s1)

        u = movingaverage(trial.u, 3)[idxs]

        plt.plot!(trial.s[idxs], u, color="black", alpha=.1, label=nothing)

        # keep only approaches lasting between 0.3 and `curve.maxdur` (samples / 60)
        duration = length(u) / 60
        (duration > curve.maxdur || duration < .3) && continue

        # `switchpoint` indexes the window, so the matching trial sample is
        # `idx + switchpoint - 1`, i.e. `I - 1`.
        # NOTE(review): `argmax` returns the position of a NaN when `u` contains one;
        # confirm the smoothed trace is NaN-free inside the window.
        switchpoint = argmax(u)
        I = switchpoint + idx

        # mean per-sample change in speed after the peak, ignoring NaNs
        # (an unused `info` Dict that recomputed this for every trial was removed)
        u̇ = mean(diff(filter(!isnan, u[switchpoint:end])))
        append_data(data_df, trial, I, cidx, "slowing", u̇, idx, I-1, idx+length(u), tn)
        push!(peaks_s, trial.s[I-1])
        push!(peaks_values, trial.u[I-1])
    end

    # mark the speed peaks (start of slowing down) found for this curve
    plt.scatter!(
        peaks_s, peaks_values, color="red", label=nothing, alpha=1
    )
    # vertical guides at the two ends of the approach window
    plt.plot!([s0, s0], [-10, 10], lw=2, color="blue", label="s0", title="Curve $cidx")
    plt.plot!([curve.s, curve.s], [-10, 10], lw=2, color="green", label="curve")
    push!(plots, plot)
end


plt.plot(
    plots...
)


### Get start of turning

In [None]:
# Find, for every curve and every trial, when the mouse starts turning.
#
# Same approach window as for the slowing-down detection. The start of turning is the
# first sample past `curve.turns0` whose smoothed angular velocity exceeds a threshold;
# the state at that sample is stored in `data_df` with class "turning".
plots = []

for cidx in eachindex(curves)
    curve = curves[cidx]
    s0 = cidx == 1 ? 1 : curves[cidx-1].s
    s1 = curve.s

    plot = plt.plot()

    peaks_s, peaks_values = [], []

    for tn in eachindex(trials)
        trial = trials[tn]
        idxs, idx = get_curve_indices(trial, s0, s1)

        s = movingaverage(trial.s[idxs], 5)
        # multiply by the turn direction (1 left, -1 right); presumably so that turning
        # into the curve is positive for every curve
        ω = movingaverage(trial.ω[idxs], 5) .* curve.direction
        plt.plot!(trial.s[idxs], ω, color="black", alpha=.1, label=nothing)

        # keep only approaches lasting between 0.3 and `curve.maxdur` (samples / 60)
        duration = length(ω) / 60
        (duration > curve.maxdur || duration < .3) && continue

        # first window sample past the position where the search for turning begins
        t0 = findfirst(s .> curve.turns0)
        isnothing(t0) && error(
            "curve $cidx, trial $tn: s never exceeds turns0 = $(curve.turns0) in the window",
        )

        # first sample from `t0` on whose angular velocity exceeds the threshold; the
        # threshold is relaxed in steps of 0.1 until a sample qualifies
        th = 1.5
        switchpoint = findfirst(ω[t0:end] .> th)

        while isnothing(switchpoint) && th > 0
            th -= .1
            switchpoint = findfirst(ω[t0:end] .> th)
        end

        if isnothing(switchpoint)
            # No threshold worked: fall back to the last sample with |ω| < 1. This used
            # to be unreachable because `switchpoint += t0` ran first and threw a
            # MethodError on `nothing`.
            switchpoint = findlast(abs.(ω) .< 1.0)
            isnothing(switchpoint) && error(
                "curve $cidx, trial $tn: could not find the start of turning",
            )
        else
            # NOTE(review): `switchpoint` is relative to `t0`, so the matching window
            # position is `t0 + switchpoint - 1`; adding `t0` directly lands one sample
            # later than the convention of the fallback branch. Left as is so existing
            # results do not shift; confirm which one is intended.
            switchpoint += t0
        end
        I = switchpoint + idx

        append_data(data_df, trial, I, cidx, "turning", 0.0, idx+t0, I-1, idx+length(ω), tn)
        push!(peaks_s, trial.s[I-1])
        push!(peaks_values, trial.ω[I-1] * curve.direction)
    end

    # mark the detected start of turning for every trial of this curve
    plt.scatter!(
        peaks_s, peaks_values, color="red", label=nothing, alpha=1
    )
    plt.plot!([s0, s0], [-10, 10], lw=2, color="blue", label="s0")

    plt.plot!([curve.s, curve.s], [-10, 10], lw=2, color="green", label="curve", title="Curve $cidx", xlim=[s0, s1], ylim=[-12, 12])
    plt.plot!([curve.turns0, curve.turns0], [-10, 10], lw=2, color="red", label="turn s0")

    push!(plots, plot)
end


plt.plot(
    plots..., size=(1200, 1000)
)


In [None]:
# Assemble the collected rows into a DataFrame and give every column a concrete element
# type (the store is built from untyped `[]` vectors).

df = DataFrame(data_df)

for (T, cols) in (
    String => (:class,),
    Int => (:curve_id, :t0, :t1, :tpeak, :trial_num),
    Float64 => (:x, :y, :s, :u, :ω, :u̇),
)
    for col in cols
        df[!, col] = convert.(T, df[!, col])
    end
end

"""
    filterdf(df, col, val)
    filterdf(df, col1, val1, col2, val2)

Return the rows of `df` whose column `col` equals `val`. The five-argument method keeps
the rows where both `col1 == val1` and `col2 == val2` hold.
"""
function filterdf(df, col, val)
    keep = df[:, col] .== val
    return df[keep, :]
end

function filterdf(df, col1, val1, col2, val2)
    first_match = df[:, col1] .== val1
    second_match = df[:, col2] .== val2
    # element-wise product of the two Bool masks acts as a logical AND
    return df[first_match .* second_match, :]
end

df

## Accelerations analysis

We take the time before/after peak velocity for each trial approaching each curve. We fit a linear regression to the before/after speed trace and we look at distributions of slopes (compared to shuffled data). The idea is that, as per MTM predictions, mice go from a period of max acceleration to a period of max deceleration.

In [None]:
using GLM, Random


"""
    slope(x, y)

Fit `y` against `x` plus an intercept with `lm` and return the coefficient on `x`
multiplied by 60.
"""
function slope(x, y)
    # design matrix: the regressor in the first column, a column of ones (intercept) second
    design = hcat(x, ones(length(x)))
    fitted = lm(design, y)
    return 60 * first(coef(fitted))
end

"""
    append_slopes_data(slope, class, curvidx)

Append one `(slope, class, curve_id)` row to the global `slopes_data` store and return
its `"curve_id"` column.
"""
function append_slopes_data(slope, class, curvidx)
    for (column, value) in ("slope" => slope, "class" => class)
        push!(slopes_data[column], value)
    end
    return push!(slopes_data["curve_id"], curvidx)
end

# Store for the fitted slopes: one growable, initially empty vector per column.
slopes_data = Dict(column => [] for column in ("slope", "class", "curve_id"))



# For every curve and every "slowing" event, regress speed on track position before the
# (shifted) speed peak ("pre"), after it ("post"), and over the whole window with the
# speed samples in random order ("shuffled").
for curvidx in 1:4
    slowing = filterdf(df, :curve_id, curvidx, :class, "slowing")

    for event in eachrow(slowing)
        tn = event.trial_num
        t0, tpeak, t1 = event.t0, event.tpeak - 2, event.t1
        pos, spd = S[tn], U[tn]

        # fits on either side of the peak
        for (label, span) in ("pre" => t0:tpeak, "post" => tpeak+1:t1)
            append_slopes_data(slope(pos[span], spd[span]), label, curvidx)
        end

        # control: same window, shuffled speed samples
        U_shuffle = shuffle(spd[t0:t1])
        append_slopes_data(slope(pos[t0:t1], U_shuffle), "shuffled", curvidx)
    end
end

# One histogram panel per curve, coloured by class.
slopes = DataFrame(slopes_data)
graph = data(slopes) *
    histogram(bins=range(-200, 200, 50)) *
    mapping(:slope, color=:class, layout = :curve_id => nonnumeric) *
    visual(alpha=.9)

draw(graph)


## Slowing/turning positions analysis

We look at where the mice are when they start slowing down before a turn compared to when they start turning to confirm that mice slow down first and turn only later.

In [None]:
# Density contours of where the mice are when they start slowing down and when they
# start turning: one panel per curve, coloured by event class.
graph = data(df) * mapping(layout = :curve_id => nonnumeric) *
    (
        mapping(color=:class) * (AOG.density() * visual(Contour))
    ) * mapping(:x, :y)
fg = draw(graph; figure = (resolution = (800, 800),), colorbar=(position=:top, size=25))

# Overlay a few trajectories in every panel for spatial context.
# NOTE(review): assumes the first four entries of `fg.figure.content` are the four
# curve panels; confirm against the AlgebraOfGraphics version in use.
for ax in fg.figure.content[1:4]
    # never ask for more traces than there are trials: the fixed `1:20` raised a
    # BoundsError whenever fewer than 20 trials were loaded
    for i in 1:min(20, length(X))
        lines!(ax, X[i], Y[i], color = (:black, 0.05),)
    end
    xlims!(ax, [0, 40])
    ylims!(ax, (0.0, 60.0))
    hidespines!(ax)
    hidedecorations!(ax)
    ax.aspect = DataAspect()
end

fg



We can also look at Δs: the distance between where the mice are when they start slowing down and where they are when they start turning. Values > 0 mean that the mice start slowing down first and only turn later.

In [None]:
# Δs per curve: track position at the start of turning minus track position at the
# start of slowing down, paired row by row.
plots = []
for curve_idx in 1:4
    slowing = filterdf(df, "curve_id", curve_idx, "class", "slowing")
    turning = filterdf(df, "curve_id", curve_idx, "class", "turning")

    # the row-wise subtraction below needs one turning row per slowing row
    @assert size(slowing) == size(turning)

    Δs = turning.s .- slowing.s
    histo = plt.histogram(Δs; bins=25, color="black", label="Δs", title="Curve $curve_idx")

    # vertical markers at zero (green) and at the mean (red)
    H = 100
    mean_Δs = mean(Δs)
    plt.plot!(histo, [0, 0], [0, H]; color="green", lw=5, label=nothing)
    plt.plot!(histo, [mean_Δs, mean_Δs], [0, H]; color="red", lw=5, label="mean Δs")

    push!(plots, histo)
end

plt.plot(plots...)

In [None]:
# plot Δt: time between the start of slowing down and the start of turning
plots = []
for curve_idx in 1:4
    slowing = filterdf(df, "curve_id", curve_idx, "class", "slowing")
    turning = filterdf(df, "curve_id", curve_idx, "class", "turning")

    # the row-wise subtraction below needs one turning row per slowing row
    @assert size(slowing) == size(turning)

    # `tpeak` is the trial sample of the detected event (speed peak / start of turning).
    # `t0` only marks where each search window begins, so the previous
    # `turning.t0 .- slowing.t0` did not measure the delay between the two events.
    # Sample indices are divided by 60 as everywhere else in this notebook.
    Δt = (turning.tpeak .- slowing.tpeak) ./ 60
    # symmetric bins: a negative Δt (turning before slowing down) must stay visible
    histo = plt.histogram(
        Δt, bins=range(-1, 1, 50), color="black", label="Δt", title="Curve $curve_idx",
    )

    # vertical markers at zero (green) and at the mean (red)
    H = 100
    plt.plot!([0, 0], [0, H], color="green", lw=5, label=nothing)
    plt.plot!([mean(Δt), mean(Δt)], [0, H], color="red", lw=5, label="mean Δt")
    push!(plots, histo)

end

plt.plot(plots...)

We can also plot the speed at which the mice are going when they start turning compared to when they start slowing down

In [None]:
# Speed at the start of slowing down (black) against speed at the start of turning
# (red), one panel per curve.
plots = []
for curve_idx in 1:4
    slowing = filterdf(df, "curve_id", curve_idx, "class", "slowing")
    turning = filterdf(df, "curve_id", curve_idx, "class", "turning")

    histo = plt.histogram(slowing.u; alpha=.5, color="black", label="slowing")
    plt.histogram!(
        histo, turning.u;
        alpha=.5, color="red", label="turning", title="Curve $curve_idx", xlim=[0, 100],
    )

    push!(plots, histo)
end

plt.plot(plots...)


Finally, we can look at how the amount of deceleration relates to two other important variables: the speed of the mice when they start slowing down and how close they are to where they need to turn.

In [None]:
# Keep only the rows recorded at the start of slowing down.
slowing = filterdf(df, "class", "slowing")

# Per curve: u̇ against u, points coloured by track position `s`, with a linear fit
# drawn on top in red.
graph = data(slowing) * mapping(:u, :u̇,  color=:s, layout = :curve_id => nonnumeric) *
    (
        (visual(Scatter, colormap=:viridis) + linear() * visual(color="red", lw=3))
    )
fg = draw(graph; figure = (resolution = (800, 800),), colorbar=(position=:top, size=25))


### GLM analysis
To quantify the effect of various variables on the degree of deceleration, we fit a GLM.

In [None]:
using GLM, MLDataUtils, StatsBase
import DataFrames: mapcols, dropmissing

Normalize and split the dataset.

In [None]:
# TODO: should this use a linear-regression fit to the speed trace instead of u̇? Are they the same?

# Keep the "slowing" events and only the columns needed below.
# (Rows with a NaN u̇ are not dropped: a former filter on that column is disabled.)
slowing = filterdf(df, "class", "slowing")[:, [:x, :y, :u, :s, :ω, :trial_num, :tpeak, :t1]]


# Deceleration slope of every event: regression of speed on track position from two
# samples before the stored peak to the stored end of the window.
u̇ = []

for event in eachrow(slowing)
    tn = event.trial_num
    window = (event.tpeak - 2):event.t1
    push!(u̇, slope(S[tn][window], U[tn][window]))
end

# predictors plus the freshly computed target
slowing = slowing[:, [:x, :y, :u, :s, :ω]]
slowing[:, :u̇] = Vector{Float64}(u̇)


# z-score every column, then split into train and test sets
slowing = mapcols(zscore, slowing)
trainset, testset = splitobs(slowing)
trainset

In [None]:
# Split into train and test rows and fit on the training rows only. The model used to
# be fit on `slowing` (all rows), so the test rows were not actually held out.
# NOTE(review): `splitobs` is called without shuffling and the rows of `slowing` are
# ordered by curve, so the test set is dominated by the last curves. Consider
# `splitobs(shuffleobs(slowing))` once the intended protocol is settled.
trainset, testset = splitobs(slowing)

# Gaussian GLM: deceleration as a function of track position and speed.
model = glm(
    @formula(u̇ ~ s + u),
    trainset,
    Normal(),
)


TODO: evaluate on the test set.
TODO: make plots out of it.

In [None]:
# Stack observed and predicted deceleration for the full table and for the train / test
# subsets, tagging every value with the subset it came from.
u̇ = vcat(slowing.u̇, trainset.u̇, testset.u̇)
û = [predict(model, slowing)..., predict(model, trainset)..., predict(model, testset)...]
label = vcat(
    fill("all", size(slowing, 1)),
    fill("train", size(trainset, 1)),
    fill("test", size(testset, 1)),
)

# Axis with the identity line: perfect predictions would fall on it.
fig = Figure()
ax = Axis(fig[1, 1], xlabel="u̇ actual", ylabel="u̇ pred")
lines!(ax, [-3, 3], [-3, 3], color=:black, linewidth=4)

# Scatter of predicted against actual values with a linear fit, coloured by subset.
graph = data((u̇ = u̇, û = û, label=label)) *
            (visual(Scatter, alpha=.25) + linear() * visual(color="red")) *
            mapping(:u̇, :û, color=:label)

draw!(ax, graph)
xlims!(ax, [-3, 3])
ylims!(ax, (-3.0, 3.0))
fig