In [1]:
# activate package, basic imports + load data
include("/Users/federicoclaudi/Documents/Github/LocomotionControl/analysis/behavior/analysis_fixtures.jl")
include("/Users/federicoclaudi/Documents/Github/LocomotionControl/analysis/behavior/analysis_visuals.jl")
using Dierckx

import jcontrol: State, movingaverage

[32m[1m  Activating[22m[39m project at `~/Documents/Github/LocomotionControl/analysis/behavior`


[1m[4m[38;2;66;165;245m@Info[22m[24m[39m [38;2;237;180;111m(Main):[39m  [38;2;138;190;255mLoaded 1021 trials[39m 
  [1m[2m[38;2;66;165;245m╰────────────────────────────────────────────────[22m[22m[39m 


                      [2mFri, 13 May 2022[22m [1m[2m[4m22:43:07[22m[22m[24m 
[1m[4m[38;2;66;165;245m@Info[22m[24m[39m [38;2;237;180;111m(Main):[39m  [38;2;138;190;255mAfter tortuosity analysis, discarded 9.109% of trials | 928[0m

 
  [2m[38;2;66;165;245m│[22m[39m                [38;2;138;190;255m[38;2;138;190;255mtrials left[39m[0m[39m 
  [1m[2m[38;2;66;165;245m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mFri, 13 May 2022[22m [1m[2m[4m22:43:09[22m[22m[24m 


## Generate dataset for RNN training
It takes the tracking data processed through the Julia behavior analysis code and creates a dataset for training an RNN on.

The frist step is to get the trial's trajectory in the track's curvilinear coordinates system and to upsample the data so that they are
expressed at an arbitrary framerate $\Delta t$.

Next, we sample `N` random time points from each trial and take the next `T` seconds of the trial as a training sample (to avoid training over an entire trial).
Then, at each frame in the selected chunk of trial we get the track's curvature at each $\Delta s$ point from the mouse's current position to `S` cm later.

In [2]:
# PARAMS
Δt = 0.005  # time step duration of upsampled trial
N = 5    # number of chunks per trial
T = .1  # duration of each trial "chunk" for training
Δs = 5  # distance between sampling's of the track's curvature
S = 50   # lookahead distance for track's curvature.

50

### Prepare tracking data
Load tracking data, turn into track's curvilinear coordinates and upsample

In [3]:

function upsample_framerate(x, Δt)
    t_original = 0:1/60:(length(x)/60 + 1)
    t_original = t_original[1:length(x)]
    t_upsample = 1:Δt:(t_original[end])
    spl = Spline1D(t_original, x; k=2) # k: order of interpolant; can be between 1-5
    return spl(t_upsample)
end

mutable struct CurvilinearCoordsTrials
    x::Vector
    y::Vector
    s::Vector
    n::Vector
    ψ::Vector
    V::Vector  # velocity vector: V = √(v² + u²)
    ω::Vector
    V̇::Vector  # acceleration
    ω̇::Vector  # angular acceleration
end

function CurvilinearCoordsTrials(states::Vector{State})
    v = upsample_framerate(sqrt.(getfield.(states, :v).^2 + getfield.(states, :u).^2), Δt)
    ω = upsample_framerate(getfield.(states, :ω), Δt)

    v̇ = movingaverage(diff(v), 21) / Δt
    ω̇ = movingaverage(diff(ω), 21) / Δt

    return CurvilinearCoordsTrials(
        upsample_framerate(getfield.(states, :x), Δt),
        upsample_framerate(getfield.(states, :y), Δt),    
        upsample_framerate(getfield.(states, :s), Δt),
        upsample_framerate(getfield.(states, :n), Δt),
        upsample_framerate(getfield.(states, :ψ) .- getfield(states[1], :ψ), Δt),
        upsample_framerate(v, Δt),
        upsample_framerate(ω, Δt),
        v̇,
        ω̇,
    )

end

CurvilinearCoordsTrials

In [4]:
trial_states(trial) = map(frame->State(trial, frame, FULLTRACK; v=trial.v[frame]), 1:length(trial.x))
ctrials = map(trial -> CurvilinearCoordsTrials(trial_states(trial)), trials)
nothing

### Step 2
Cut the trials into chunks.

In doing so, make sure not to take chunks that are too close of the trial to avoid having problems.

In [5]:
chunks = []
for trial in ctrials
    # timesteps = 1:findlast(trial.s .< (trial.s[end]-(S+15)))
    # timesteps = timesteps[1:end - (Int ∘ round)(T * 1000)]
    timesteps = 1:(length(trial.s) - (Int ∘ round)(T * 1000))
    if length(timesteps) == 0
        continue
    end

    starts = (Int ∘ round).(rand(timesteps, N))
    stops = (Int ∘ round).(starts .+ (T * 1000))


    for (start, stop) in zip(starts, stops)
        if stop <= start || stop >= length(trial.n)
            continue
        end
        push!(chunks, 
            CurvilinearCoordsTrials(
                trial.x[start:stop],
                trial.y[start:stop],
                trial.s[start:stop],
                trial.n[start:stop],
                trial.ψ[start:stop],
                trial.V[start:stop],
                trial.ω[start:stop],
                trial.V̇[start:stop],
                trial.ω̇[start:stop],
            )
        )
    end
end

@info "Got $(length(chunks)) chunks out of $(length(trials)) trials"
(length(trials) * 3 - length(chunks))/3

[1m[4m[38;2;66;165;245m@Info[22m[24m[39m [38;2;237;180;111m(Main):[39m  [38;2;138;190;255mGot 4636 chunks out of 928 trials[39m 
  [1m[2m[38;2;66;165;245m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mFri, 13 May 2022[22m [1m[2m[4m22:46:54[22m[22m[24m 


-617.3333333333334

### Step 3
for each chunk get the track's curvature ahead at each frame

In [6]:
struct DatasetEntry
    n::Vector{Float64}
    ψ::Vector{Float64}
    s::Vector{Float64}
    V::Vector{Float64}
    ω::Vector{Float64}
    V̇::Vector{Float64}
    ω̇::Vector{Float64}
    k::Matrix{Float64}   # stores the curvature at each waypoint and each frame
end

In [7]:
waypoints = collect(0:Δs:S)
waypoints_idxs = 1:length(waypoints)

entries = []
for chunk in chunks
    # get curvature
    chunk_curv = zeros(length(waypoints_idxs), length(chunk.s))

    

    for frame in 1:length(chunk.x)
        # get relevant S vector from track
        s̄ = FULLTRACK.S .- chunk.s[frame]
        start = argmin(s̄ .^ 2)
        stop = findlast(s̄ .<= S)
        # @assert stop > start
        # @assert FULLTRACK.S[stop] <= (chunk.s[frame]+S)

        svec = FULLTRACK.S[start:stop]
        for (I, sval) in zip(waypoints_idxs, waypoints)
            idx = findfirst((svec .- svec[1]) .> sval)
            idx = isnothing(idx) ? 1 : idx
            # push!(chunk_curv[I], svec[idx])
            chunk_curv[I, frame] = FULLTRACK.κ(svec[idx])
        end
    end

    push!(entries, DatasetEntry(
        chunk.n,
        chunk.ψ,
        chunk.s,
        chunk.V,
        chunk.ω,
        chunk.V̇,
        chunk.ω̇,
        chunk_curv,
    )
    )

end

entries[1].k

11×101 Matrix{Float64}:
  0.128197      0.117093      0.109053     …  -0.000229285  -0.000245414
  0.0468121     0.0453496     0.0442564        6.35898e-5    6.54237e-5
  0.00531312    0.00196971   -0.0013026       -1.9439e-5    -2.33865e-5
 -0.00392382   -0.00100927    0.00164133      -0.00022907   -0.000245198
  0.0106179     0.00874865    0.00694454      -0.00022907   -0.000245198
 -0.00310657   -0.00257541   -0.00206314   …  -0.00022907   -0.000245198
  0.000815265   0.000725415   0.000595535     -0.00022907   -0.000245198
 -0.000223693  -0.000192624  -0.000162369     -0.00022907   -0.000245198
  6.42378e-5    5.47944e-5    4.57873e-5      -0.00022907   -0.000245198
 -3.53179e-7    0.117179      0.109116        -0.00022907   -0.000245198
  0.128309      0.117179      0.109116     …  -0.00022907   -0.000245198

In [8]:
@info "Got $(length(entries)) entries from $(length(chunks)) chunks"

[1m[4m[38;2;66;165;245m@Info[22m[24m[39m [38;2;237;180;111m(Main):[39m  [38;2;138;190;255mGot 4636 entries from 4636 chunks[39m 
  [1m[2m[38;2;66;165;245m╰────────────────────────────────────────────────[22m[22m[39m 
                      [2mFri, 13 May 2022[22m [1m[2m[4m22:56:51[22m[22m[24m 


### Save data
save data to file in a format which can be loaded back in python for RNN training

In [9]:
using JSON: JSON
using DataFrames: DataFrame
import JSONTables: objecttable, jsontable
import OrderedCollections: OrderedDict

savepath = "/Users/federicoclaudi/Dropbox (UCL)/Rotation_vte/Locomotion/analysis/RNN/dataset"

for (i, entry) in enumerate(entries)
    ks = collect(Symbol("k_", i)=>entry.k[i, :] for i in 1:size(entry.k,1 ))
    dict = OrderedDict{Symbol, Union{Vector{Float64}, Matrix{Float64}}}(
        :n=>entry.n,
        :ψ=>entry.ψ,
        :s=>entry.s,
        :V=>entry.V,
        :ω=>entry.ω,
        :v̇=>entry.V̇,
        :ω̇=>entry.ω̇,
        ks...
    )
    df = DataFrame(dict)

    open(joinpath(savepath, "$(i).json"), "w") do f
        write(f, objecttable(df))
    end
end