In [1]:
# activate package, basic imports + load data
include("C:\\Users\\Federico\\Documents\\GitHub\\pysical_locomotion\\analysis\\behavior\\analysis_fixtures.jl")
include("C:\\Users\\Federico\\Documents\\GitHub\\pysical_locomotion\\analysis\\behavior\\analysis_visuals.jl")
using Dierckx

import jcontrol: State, movingaverage
import jcontrol: Δ
import OrderedCollections: OrderedDict
using StatsBase
using JSON: JSON
using DataFrames: DataFrame
import JSONTables: objecttable, jsontable
using Glob

# round to the nearest value, then convert the result to an `Int`
int = Int ∘ round

[1m[4m[38;2;66;165;245m@Info[22m[24m[39m [38;2;237;180;111m(Main):[39m  [38;2;138;190;255mLoaded 995 trials[39m 
  [1m[2m[38;2;66;165;245m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mMon, 06 Jun 2022[22m [1m[2m[4m10:41:01[22m[22m[24m 
[1m[4m[38;2;66;165;245m@Info[22m[24m[39m [38;2;237;180;111m(Main):[39m  [38;2;138;190;255mAfter tortuosity analysis, discarded 8.643% of trials | 909[0m 
  [2m[38;2;66;165;245m│[22m[39m                [38;2;138;190;255m[38;2;138;190;255mtrials left[39m[0m[39m 
  [1m[2m[38;2;66;165;245m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mMon, 06 Jun 2022[22m [1m[2m[4m10:41:04[22m[22m[24m 


[32m[1m  Activating[22m[39m project at `C:\Users\Federico\Documents\GitHub\pysical_locomotion\analysis\behavior`


Int64 ∘ round

# Generate RNN dataset

Take each trial, get kinematics in curvilinear coordinates, cut into chunks of equal length.
Also, at each frame, sample the curvature of the track that lies ahead of the mouse.

utility functions


In [2]:
"""
    upsample_framerate(x, Δt)

Upsample frame rate of tracking data.

The samples of `x` are placed on a 60 frames-per-second time axis starting at
zero and are re-evaluated every `Δt` seconds with a first-order (linear) spline.
"""
function upsample_framerate(x, Δt)
    n_frames = length(x)

    # build a slightly-too-long 60 fps time axis, then keep one stamp per sample
    frame_times = (0:1/60:(n_frames/60 + 1))[1:n_frames]

    # denser time axis spanning the same interval
    query_times = 0:Δt:last(frame_times)

    # k: order of interpolant; can be between 1-5
    interpolant = Spline1D(frame_times, x; k=1)
    return interpolant(query_times)
end



"""
    get_track_waypoints(track, sval; spacing=Δs, lookahead=S)

Get the indices (into `track.S`) of equally spaced waypoints along the track,
starting from the track position `sval`.

The first returned index is the sample of `track.S` closest to `sval`. The
remaining ones mark the places where the distance ahead of `sval` wraps past a
multiple of `spacing`, restricted to distances between `0` and `lookahead`. For
each wrap the index of the sample just *before* the wrap is returned, while the
distance limits are tested on the sample just after it.

# Keywords
- `spacing`: distance between consecutive waypoints. Defaults to the global `Δs`.
- `lookahead`: maximum distance ahead of `sval`. Defaults to the global `S`.

Pass `lookahead` explicitly whenever the lookahead distance is a local variable
(e.g. a loop variable named `S`): the default always reads the *global* `S`.
"""
function get_track_waypoints(track, sval; spacing=Δs, lookahead=S)
    ds = track.S .- sval

    # true at i when mod(ds, spacing) does not increase between samples i and i+1
    wraps = diff(mod.(ds, spacing)) .<= 0
    # distance limits, evaluated on sample i+1 so they line up with `diff`
    ahead = (ds .>= 0)[2:end]
    within = (ds .<= lookahead)[2:end]

    idxs = findall(wraps .& ahead .& within)
    return [argmin(ds .^ 2), idxs...]
end


"""
    get_average_ψ_along_track(ctrials)

Get the average ψ at each track position.

For every integer position `s` in `1:260`, take from each trial the ψ value at
the first sample whose `s` coordinate is at least `s` (trials that never get
there contribute nothing) and average those values across trials. Returns an
`OrderedDict` mapping each position to its mean ψ.
"""
function get_average_ψ_along_track(ctrials)
    positions = 1:260

    # one bucket of ψ samples per track position
    buckets = OrderedDict(pos => [] for pos in positions)
    for trial in ctrials
        for pos in positions
            first_hit = findfirst(>=(pos), trial.s)
            if !isnothing(first_hit)
                push!(buckets[pos], trial.ψ[first_hit])
            end
        end
    end

    return OrderedDict(pos => mean(buckets[pos]) for pos in positions)
end

"""
    reset_psi(trial, start, Ψ)

Fix a chunk's initial ψ value: return the entry of `Ψ` for the track position of
`trial` at sample `start`.

`Ψ` maps integer track positions to ψ values. The position `trial.s[start]` is
rounded to the nearest integer and then clamped to the smallest/largest key of
`Ψ`, so a chunk starting just outside the tabulated range (e.g. `s < 0.5`, which
rounds to `0`) uses the nearest tabulated position instead of raising a
`KeyError`.
"""
function reset_psi(trial, start, Ψ)
    lowest, highest = extrema(keys(Ψ))
    idx = clamp(round(Int, trial.s[start]), lowest, highest)
    return Ψ[idx]
end

reset_psi

types

In [3]:
"""
Represents a single trial but in curvilinear coordinates space and with
upsampled framerate.

Every field is a vector with one element per (upsampled) sample. `get_chunks`
reuses this type for fixed-length chunks of a trial; in that case the `V̇` and
`ω̇` slots are filled with the time-shifted ("future") `V` and `ω` values rather
than with accelerations.
"""
mutable struct CurvilinearCoordsTrials
    x::Vector  # upsampled `x` field of the per-frame states
    y::Vector  # upsampled `y` field of the per-frame states
    θ::Vector  # upsampled `θ` field of the per-frame states
    s::Vector  # upsampled `s` field; compared against `FULLTRACK.S` when sampling curvature
    n::Vector  # upsampled `n` field of the per-frame states
    ψ::Vector  # upsampled `ψ`, relative to the first frame and clamped to the outlier limit
    V::Vector  # velocity vector: V = √(v² + u²)
    ω::Vector  # smoothed `ω`, clamped to the outlier limit
    V̇::Vector  # acceleration
    ω̇::Vector  # angular acceleration
    t::Vector  # time
end


# One entry of the RNN dataset, built by `create_dbase` from a single chunk.
# The seven vectors are copied from the chunk's fields of the same name.
# NOTE(review): chunks produced by `get_chunks` carry the time-shifted V and ω in
# their `V̇` / `ω̇` slots, so here these two fields appear to hold the prediction
# targets rather than accelerations — confirm with the training code.
struct DatasetEntry
    n::Vector{Float64}
    ψ::Vector{Float64}
    s::Vector{Float64}
    V::Vector{Float64}
    ω::Vector{Float64}
    V̇::Vector{Float64}
    ω̇::Vector{Float64}
    # stores the curvature at each waypoint (rows) and each frame (columns);
    # `nothing` when no lookahead is requested (S not greater than 0)
    k::Union{Nothing, Matrix{Float64}}   # stores the curvature at each waypoint and each frame
end

### Prepare tracking data
Load tracking data, turn into track's curvilinear coordinates and upsample

In [4]:
"""
    CurvilinearCoordsTrials(states::Vector{State}, frames_idxs, Δt, speeds_smoothing_window)

Convert a vector of states to a `CurvilinearCoordsTrials`.

Each state field is upsampled to time step `Δt`. Speed (`√(v² + u²)`) and `ω`
are additionally smoothed with a moving average of width
`speeds_smoothing_window`, and their derivatives are obtained by differencing
the smoothed signals and dividing by `Δt`. `ψ` is expressed relative to the
first state. `ψ`, `ω` and both derivatives are clamped to the symmetric bounds
in the global `outliers_limits`. Time is `frames_idxs ./ 60`, upsampled.
"""
function CurvilinearCoordsTrials(states::Vector{State}, frames_idxs, Δt, speeds_smoothing_window)
    # one field from every state, resampled at Δt
    upsampled(name) = upsample_framerate(getfield.(states, name), Δt)
    # symmetric clamp at the outlier limit registered under `key`
    bounded(values, key) = clamp.(values, -outliers_limits[key], outliers_limits[key])

    raw_speed = sqrt.(getfield.(states, :v).^2 + getfield.(states, :u).^2)
    speed = movingaverage(upsample_framerate(raw_speed, Δt), speeds_smoothing_window)
    angular_speed = movingaverage(upsampled(:ω), speeds_smoothing_window)

    # heading relative to the heading at the first frame
    relative_ψ = upsample_framerate(
        getfield.(states, :ψ) .- getfield(states[1], :ψ), Δt
    )

    return CurvilinearCoordsTrials(
        upsampled(:x),
        upsampled(:y),
        upsampled(:θ),
        upsampled(:s),
        upsampled(:n),
        bounded(relative_ψ, :ψ),
        speed,
        bounded(angular_speed, :ω),
        bounded(Δ(speed) ./ Δt, :v̇),
        bounded(Δ(angular_speed) ./ Δt, :ω̇),
        upsample_framerate(frames_idxs ./ 60, Δt),
    )
end

CurvilinearCoordsTrials

### create curvilinear trials

In [5]:
"""
    get_ctrials(trials, Δt, speeds_smoothing_window)

Get trials in curvilinear coordinates space.

Only the first `SELECT_N_TRIALS` (global) trials are used. For each of them a
`State` is built for every frame against `FULLTRACK`, and the resulting states
are converted to a `CurvilinearCoordsTrials` upsampled to time step `Δt`.
"""
function get_ctrials(trials, Δt, speeds_smoothing_window)
    selected = trials[1:SELECT_N_TRIALS]

    return map(selected) do trial
        frames = 1:length(trial.x)
        states = map(frames) do frame
            State(trial, frame, FULLTRACK; v=trial.v[frame], smoothing_window=5)
        end
        CurvilinearCoordsTrials(states, frames, Δt, speeds_smoothing_window)
    end
end


get_ctrials (generic function with 1 method)

In [6]:


"""
    get_chunks(ctrials, n_samples_chunk, prediction_Δt_shift_samples)

Split each trial into chunks of equal length.

Chunks start every `n_samples_chunk` samples and cover `start:start+n_samples_chunk`.
A chunk is dropped when it reaches the end of the trial, when the speed `V` is
at or below 25 anywhere inside it, or when the window shifted forward by
`prediction_Δt_shift_samples` would run past the trial.

Each chunk is returned as a `CurvilinearCoordsTrials` whose `V̇` / `ω̇` slots hold
the *shifted* `V` / `ω` (the prediction targets), and whose `ψ` is offset so that
its first value equals the across-trial average ψ at that track position.
"""
function get_chunks(ctrials, n_samples_chunk, prediction_Δt_shift_samples)
    @info "Extracting chunks from trials" ctrials

    Ψ = get_average_ψ_along_track(ctrials)
    chunks = []
    for trial in ctrials
        n_frames = length(trial.x)
        chunk_starts = range(1, n_frames - n_samples_chunk, step=n_samples_chunk)
        isempty(chunk_starts) && continue

        for start in chunk_starts
            stop = int(start + n_samples_chunk)
            window = start:stop

            # chunk must end before the last sample of the trial
            if stop >= length(trial.n)
                continue
            end

            # if mouse slows down during chunk - ignore
            if any(<=(25), trial.V[window])
                continue
            end

            # the shifted window has to fit inside the trial as well
            shifted_stop = stop + prediction_Δt_shift_samples
            if shifted_stop > n_frames
                continue
            end
            shifted_window = (start + prediction_Δt_shift_samples):shifted_stop

            # amount by which this chunk's ψ is moved so that it starts at the average ψ
            ψ_offset = trial.ψ[start] - reset_psi(trial, start, Ψ)

            chunk = CurvilinearCoordsTrials(
                trial.x[window],
                trial.y[window],
                trial.θ[window],
                trial.s[window],
                trial.n[window],
                trial.ψ[window] .- ψ_offset,
                trial.V[window],
                trial.ω[window],
                trial.V[shifted_window],  # speed in the future
                trial.ω[shifted_window],  # angular velocity in the future
                trial.t[window],
            )
            push!(chunks, chunk)
        end
    end

    return chunks
end



get_chunks

### Step 4
for each chunk get the track's curvature ahead at each frame

In [7]:
# Indices into `track.S` of the waypoints ahead of `sval`: the sample closest to
# `sval`, followed by every sample after which the distance ahead wraps past a
# multiple of `Δs`, restricted to distances within [0, S].
function _waypoint_idxs_ahead(track, sval, S, Δs)
    ds = track.S .- sval
    wraps = diff(mod.(ds, Δs)) .<= 0
    in_range = (ds .>= 0)[2:end] .& (ds .<= S)[2:end]
    return [argmin(ds .^ 2); findall(wraps .& in_range)]
end

"""
    create_dbase(ctrials, S, n_samples_chunk, prediction_Δt_shift_samples)

Build the dataset entries: cut `ctrials` into chunks (see `get_chunks`) and, for
each chunk, sample the curvature of `FULLTRACK` ahead of the animal at every frame.

# Arguments
- `ctrials`: trials in curvilinear coordinates.
- `S`: lookahead distance. When `S > 0` each entry gets a curvature matrix with
  one row per waypoint (`length(0:Δs:S)` rows, `Δs` being the global waypoint
  spacing) and one column per frame; otherwise its `k` field is `nothing`.
- `n_samples_chunk`, `prediction_Δt_shift_samples`: forwarded to `get_chunks`.

Returns a vector of `DatasetEntry`.

The waypoint lookup uses the `S` passed to this function. It previously went
through a helper that read the *global* `S`, so for any other lookahead the
number of waypoints never matched and the curvature matrix was left all zeros.
"""
function create_dbase(ctrials, S, n_samples_chunk, prediction_Δt_shift_samples)
    chunks = get_chunks(ctrials, n_samples_chunk, prediction_Δt_shift_samples)

    # number of curvature samples stored per frame
    n_waypoints = S > 0 ? length(0:Δs:S) : 0

    entries = DatasetEntry[]
    for chunk in chunks
        if S > 0
            chunk_curv = zeros(n_waypoints, length(chunk.x))
            for frame in eachindex(chunk.x)
                idxs = _waypoint_idxs_ahead(FULLTRACK, chunk.s[frame], S, Δs)

                # frames without the full set of waypoints (e.g. too close to the
                # end of the track) keep a column of zeros
                length(idxs) != n_waypoints && continue
                chunk_curv[:, frame] = FULLTRACK.curvature[idxs]
            end
        else
            chunk_curv = nothing
        end

        push!(
            entries,
            DatasetEntry(
                chunk.n,
                chunk.ψ,
                chunk.s,
                chunk.V,
                chunk.ω,
                chunk.V̇,
                chunk.ω̇,
                chunk_curv,
            ),
        )
    end

    return entries
end


create_dbase (generic function with 1 method)

## Visual inspection

In [8]:
# plt = plot()

# for entry in entries[1:500]
#     plot!(entry.s, entry.ψ, label=nothing, alph=.33, color="black", alpha=.2)

#     # plot!(entry.s, entry.ω ./ 10, label=nothing, alph=.33, color="red", alpha=.2)
# end
# # plot!(FULLTRACK.S, Δ(FULLTRACK.θ) * 1000, lw=5, color="green")
# plt



### Save data
Save the data to JSON files (one file per dataset entry) in a format that can be loaded back in Python for RNN training.

In [9]:
"""
    save_dbase(savepath, entries, S, prediction_Δt_shift)

Write every entry to `savepath` as `<index>.json`, where `<index>` is the
1-based position of the entry in `entries`.

Each file is a table with the columns `n, ψ, s, V, ω, V̇, ω̇` and, when `S > 0`,
one extra column `k_<row>` per row of the entry's curvature matrix.
`prediction_Δt_shift` is accepted but not used. Returns `savepath`.
"""
function save_dbase(savepath, entries, S, prediction_Δt_shift)
    ColumnData = Union{Vector{Float64}, Matrix{Float64}}

    for (entry_number, entry) in enumerate(entries)
        columns = OrderedDict{Symbol, ColumnData}(
            :n => entry.n,
            :ψ => entry.ψ,
            :s => entry.s,
            :V => entry.V,
            :ω => entry.ω,
            :V̇ => entry.V̇,
            :ω̇ => entry.ω̇,
        )

        # curvature columns go after the kinematic ones, one per waypoint
        if S > 0
            for row in 1:size(entry.k, 1)
                columns[Symbol("k_", row)] = entry.k[row, :]
            end
        end

        table = DataFrame(columns)
        open(joinpath(savepath, "$(entry_number).json"), "w") do io
            write(io, objecttable(table))
        end
    end

    return savepath
end

save_dbase (generic function with 1 method)

## Params

In [10]:


# PARAMS
# Notebook-level settings. Several functions above read these as globals
# (`Δs`, `S`, `outliers_limits`, `SELECT_N_TRIALS`), so they must be defined
# before those functions are called.
Δt = 0.005  # time step duration of upsampled trial
# N = 5    # number of chunks per trial
T = 1.0  # duration of each trial "chunk" for training, in seconds
Δs = 5  # distance between samples of the track's curvature
S = 50   # lookahead distance for track's curvature (presumably cm, going by the dataset folder names)

prediction_Δt_shift = .75  # prediction targets are velocities at this offset in the future (in seconds)
# the same offset expressed in upsampled samples; computed once, from the value above
prediction_Δt_shift_samples = int(prediction_Δt_shift / Δt)


speeds_smoothing_window = 21  # width of the moving average applied to speed and ω
SELECT_N_TRIALS = length(trials)  # how many trials to convert (here: all of them)


# symmetric bounds: each variable is clamped to [-limit, +limit]
outliers_limits = Dict(
    :ψ => 1.5,
    :ω => 12,
    :v̇ => 1000,
    :ω̇ => 100,
)


# chunk duration expressed in upsampled samples
n_samples_chunk = int(T/Δt)
@info "n samples per trial" n_samples_chunk

[1m[4m[38;2;66;165;245m@Info[22m[24m[39m [38;2;237;180;111m(Main):[39m  [38;2;138;190;255mn samples per trial[39m 
  [2m[38;2;66;165;245m│[22m[39m 
  [2m[38;2;66;165;245m│[22m[39m [2m[38;2;206;147;216m(Int64)[22m[39m  [2m[38;2;66;165;245m▶[22m[39m  [1m[38;2;255;153;89mn_samples_chunk[22m[39m [1m[31m=[22m[39m [38;2;66;165;245m200[39m 
  [1m[2m[38;2;66;165;245m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mMon, 06 Jun 2022[22m [1m[2m[4m10:41:17[22m[22m[24m 


In [11]:
# Convert all trials once, then build and save one dataset per combination of
# curvature lookahead distance `S` and prediction offset `prediction_Δt_shift`.
ctrials = get_ctrials(trials, Δt, speeds_smoothing_window)

for S in (0, 5, 10, 25, 50)
    for prediction_Δt_shift in (0, .25, .5, .75, 1)

        savepath = "D:\\Dropbox (UCL)\\Rotation_vte\\Locomotion\\analysis\\RNN\\datasets\\$(S)cm_pred_$(prediction_Δt_shift)"

        # Never overwrite an existing dataset; delete its folder to regenerate it.
        # NOTE: folders written before the shift was computed per iteration all
        # used the same (global) offset and need regenerating.
        if isdir(savepath)
            @warn "skipping " savepath
            continue
        end
        mkpath(savepath)  # create it if it doesn't exist

        # Convert THIS iteration's offset to samples. The loop variable is local,
        # so the notebook-level `prediction_Δt_shift_samples` (computed once from
        # the global offset) does not follow it.
        shift_samples = int(prediction_Δt_shift / Δt)

        entries = create_dbase(ctrials, S, n_samples_chunk, shift_samples)
        save_dbase(savepath, entries, S, prediction_Δt_shift)
    end
end

[1m[4m[38;2;255;167;38m@Warn[22m[24m[39m [38;2;237;180;111m(Main.[4m[38;2;255;238;88mtop-level scope[24m[39m[38;2;237;180;111m):[39m  [38;2;138;190;255mskipping [39m 
  [2m[38;2;255;167;38m│[22m[39m 
  [2m[38;2;255;167;38m│[22m[39m [2m[38;2;206;147;216m(String)[22m[39m  [2m[38;2;255;167;38m▶[22m[39m  [1m[38;2;255;153;89msavepath[22m[39m [1m[31m=[22m[39m [38;2;165;214;167mD:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\RNN\datasets\0cm_pred_0[39m 
  [1m[2m[38;2;255;167;38m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mMon, 06 Jun 2022[22m [1m[2m[4m10:54:27[22m[22m[24m 
[1m[4m[38;2;255;167;38m@Warn[22m[24m[39m [38;2;237;180;111m(Main.[4m[38;2;255;238;88mtop-level scope[24m[39m[38;2;237;180;111m):[39m  [38;2;138;190;255mskipping [39m 
  [2m[38;2;255;167;38m│[22m[39m 
  [2m[38;2;255;167;38m│[22m[39m [2m[38;2;206;147;216m(String)[22m[39m  [2m[38;2;255;167;38m▶[22m

 
  [2m[38;2;255;167;38m│[22m[39m 
  [2m[38;2;255;167;38m│[22m[39m [2m[38;2;206;147;216m(String)[22m[39m  [2m[38;2;255;167;38m▶[22m[39m  [1m[38;2;255;153;89msavepath[22m[39m [1m[31m=[22m[39m [38;2;165;214;167mD:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\RNN\datasets\10cm_pred_0.75[39m 
  [1m[2m[38;2;255;167;38m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mMon, 06 Jun 2022[22m [1m[2m[4m10:54:28[22m[22m[24m 
[1m[4m[38;2;255;167;38m@Warn[22m[24m[39m [38;2;237;180;111m(Main.[4m[38;2;255;238;88mtop-level scope[24m[39m[38;2;237;180;111m):[39m  [38;2;138;190;255mskipping [39m 
  [2m[38;2;255;167;38m│[22m[39m 
  [2m[38;2;255;167;38m│[22m[39m [2m[38;2;206;147;216m(String)[22m[39m  [2m[38;2;255;167;38m▶[22m[39m  [1m[38;2;255;153;89msavepath[22m[39m [1m[31m=[22m[39m [38;2;165;214;167mD:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\RNN\datasets\10cm_pred_1[39m 
  [1m[