In [None]:
using MakieCore
import MakieCore: convert_arguments 

# This part would be the extension package (calls like lines, scatter, etc.)
using CairoMakie

# To match previous signature
using OrderedCollections

# Some nice themes that we can decide to use or not
using MakiePublication

import NamedTrajectories as NT

- Can `convert_arguments` plot multiple lines? Seemingly not without Makie (see notebook 03). Should we stick with `PointBased` (MakieCore) or use `Series` (Makie)?

Examples to follow
- https://github.com/JuliaGaussianProcesses/AbstractGPsMakie.jl/blob/main/src/recipes/gpsample.jl
- https://github.com/MakieOrg/Makie.jl/issues/837#issuecomment-846052033

In [None]:
# Turn a (trajectory, data-row index) pair into a vector of `(time, value)` points so
# that point-based plot types (lines, scatter, ...) accept a NamedTrajectory directly.
function convert_arguments(
    P::MakieCore.PointBased,
    traj::NT.NamedTrajectory{<:Any},
    component::Int,
)
    ts = NT.get_times(traj)
    points = [(t, traj.data[component, k]) for (k, t) in enumerate(ts)]
    return convert_arguments(P, points)
end

# Plot type used when calling the generic `plot(traj, component)` (draws lines).
MakieCore.plottype(::NT.NamedTrajectory, ::Int) = MakieCore.Lines


In [None]:
# Random NamedTrajectory used as test data by the cells below.
# NOTE(review): `10` is presumably the number of time steps (`traj.T`) — confirm
# against NamedTrajectories' `rand` method.
traj = rand(NT.NamedTrajectory, 10);

In [None]:
# Draw every `:u` data row of `traj` as its own line on one axis, with colour running
# along the trajectory (via the `convert_arguments` method defined above).
f = Figure()
ax = Axis(f[1, 1])
for i in traj.components[:u]
    lines!(ax, traj, i; color=range(0, 1; length=traj.T), colormap=:Reds)
end
f

Notes
- Unsure why the theme is not being applied by the recipe; as a workaround, set the attributes manually.

Little notes
- Makie defaults to Makie.Axis as the preferred axis.
- Can you set relational features of the figure (e.g. colgaps) in the recipe?
- How should multiple symbols be handled? (See notebook 03 for `Series`.) A recipe can only draw into one axis, and the SpecApi is still experimental.

In [None]:
# `PlotComponents` recipe: defines the `plotcomponents` / `plotcomponents!` functions
# used below. The first two positional arguments are named `traj` and `name`; the
# transformation methods read their extra positional arguments by index (P[3], P[4]).
# Style attributes default to the values in the scene's theme.
@recipe(PlotComponents, traj, name) do scene
    Attributes(
        linestyle=theme(scene, :linestyle),
        linewidth=theme(scene, :linewidth),
        marker=theme(scene, :marker),
        markersize=theme(scene, :markersize),
        # NOTE(review): `merge` is declared but no `plot!` method in this notebook
        # reads it.
        merge=false
    )
end

# Adds the ability to recall plot labels for a legend
# Expose the child plots created by the recipe so `Legend(..., axis)` can collect
# their labels.
function Makie.get_plots(P::PlotComponents)
    return P.plots
end

# Dealing with a default plot of an existing component
# Default `plotcomponents(traj, name)`: draw each data row belonging to component
# `name` as a scatter-line against time, labelled "<name><index>" for legends.
function Makie.plot!(
    P::PlotComponents{<:Tuple{<:NT.NamedTrajectory, Symbol}};
    kwargs...
)
    # Style observables forwarded unchanged to every child plot.
    style = (
        linewidth = P[:linewidth],
        linestyle = P[:linestyle],
        marker = P[:marker],
        markersize = P[:markersize],
    )

    lift(P[:traj], P[:name]) do traj, name
        ts = NT.get_times(traj)
        foreach(enumerate(traj.components[name])) do (i, row)
            scatterlines!(P, ts, traj.data[row, :]; label="$name$i", style...)
        end
    end
    return P
end

# Dealing with a transformation to a new component
# `plotcomponents(traj, input_name, output_name, transformations)`: the j-th function
# in `transformations` is applied to the j-th time column of `traj[input_name]`; the
# outputs are stacked as columns and each resulting row is drawn as a scatter-line
# labelled "<output_name><row>".
#
# Throws `ArgumentError` when the number of transformations differs from the number of
# time steps, or when applying/stacking the transformations fails.
function Makie.plot!(
    P::PlotComponents{<:Tuple{<:NT.NamedTrajectory, Symbol, Symbol, <:AbstractVector{<:Function}}};
    kwargs...
)
    lift(P[:traj], P[:name], P[3], P[4]) do traj, input_name, output_name, transformations
        times = NT.get_times(traj)

        # This is user input, so validate it with an exception: `@assert` is meant for
        # internal invariants and may be disabled.
        if length(transformations) != length(times)
            throw(ArgumentError(
                "Expected one transformation per time step ($(length(times))), " *
                "got $(length(transformations))."
            ))
        end

        # Descriptive error message
        output_data = try
            stack([f(input) for (f, input) in zip(transformations, eachcol(traj[input_name]))])
        catch
            throw(ArgumentError("Transformation $(input_name) -> $(output_name) failed."))
        end

        for (j, output_row) in enumerate(eachrow(output_data))
            scatterlines!(
                P, times, output_row,
                label="$output_name$j",
                linewidth = P[:linewidth],
                linestyle = P[:linestyle],
                marker = P[:marker],
                markersize = P[:markersize],
            )
        end
    end
    return P
end

# `plotcomponents(traj, input_name, output_name, transformation)`: apply the single
# `transformation` to every time column of `traj[input_name]`, stack the outputs as
# columns, and draw each resulting row as a scatter-line labelled
# "<output_name><row>". Failures while transforming are re-thrown as ArgumentError.
function Makie.plot!(
    P::PlotComponents{<:Tuple{<:NT.NamedTrajectory, Symbol, Symbol, <:Function}};
    kwargs...
)
    # Style observables forwarded unchanged to every child plot.
    style = (
        linewidth = P[:linewidth],
        linestyle = P[:linestyle],
        marker = P[:marker],
        markersize = P[:markersize],
    )

    lift(P[:traj], P[:name], P[3], P[4]) do traj, input_name, output_name, transformation
        ts = NT.get_times(traj)

        transformed = try
            columns = eachcol(traj[input_name])
            stack(map(transformation, columns))
        catch
            throw(ArgumentError("Transformation $(input_name) -> $(output_name) failed."))
        end

        foreach(enumerate(eachrow(transformed))) do (j, row_values)
            scatterlines!(P, ts, row_values; label="$output_name$j", style...)
        end
    end
    return P
end

- I'm not convinced that `with_theme` behaves correctly: the theme doesn't seem to reset when it is called in a block, but it does when a function is passed.

In [None]:
# Build a three-row test figure from the global `traj`:
#   row 1: `:x -> :y` through a single transformation, with legend;
#   row 2: `:x -> :y` through one identity function per time step, with legend;
#   row 3: styled `:u` components. With `merge=true` all `:u` plots share a single
#          "u" legend entry; otherwise each child plot gets its own entry.
function test_plot(; merge::Bool=false)
    f = Figure()

    ax = Axis(f[1, 1])
    plotcomponents!(ax, traj, :x, :y, x -> x .^ 2)
    Legend(f[1, 2], ax)

    ax = Axis(f[2, 1])
    plotcomponents!(
        ax, traj, :x, :y, [x -> x for _ in 1:traj.T]
    )
    Legend(f[2, 2], ax)

    # Plotting into a grid position returns an AxisPlot (`p.axis`, `p.plot`).
    p = plotcomponents(f[3, 1], traj, :u,
        linewidth=5, linestyle=:dash, marker=:+
    )
    if merge
        # A vector of plots as one legend entry overlays them under a single label.
        Legend(f[3, 2], [p.plot.plots], ["u"])
    else
        # Use the row-3 axis, not `ax` (which still refers to row 2).
        Legend(f[3, 2], p.axis)
    end

    return f
end

# Build the test figure inside `with_theme` using `theme`, with the extra theme
# setting `rowgap = 0.0`.
test_plot(theme::MakieCore.Theme) = with_theme(() -> test_plot(), theme; rowgap=0.0)

In [None]:
# Inspect the child plots of a `plotcomponents` result.
# NOTE(review): `p` is not bound at this point in top-to-bottom order (the `p` in
# `test_plot` is local); run the following cell, which assigns `p`, first.
Makie.get_plots(p.plot)

In [None]:
# Draw the :u components with custom styling and collapse them into one "u" legend
# entry (`merge=true` merges entries that share a label).
f = Figure()
p = plotcomponents(f[3, 1], traj, :u, linewidth=5, linestyle=:dash, marker=:+)
u_plots = Makie.get_plots(p.plot)
# One label per child plot, so this works for any number of :u components.
Legend(f[3, 2], u_plots, fill("u", length(u_plots)), merge=true, unique=false)
f

In [None]:
# Render the test figure under Makie's `theme_minimal()`.
test_plot(theme_minimal())

In [None]:
# Render the test figure under `theme_web()` (presumably from MakiePublication).
test_plot(theme_web())


In [None]:
# Add one axis per component of `comps`, in rows `row_offset + 1 ...`, each with a
# legend in column 2. Only the bottom axis shows x tick labels, and the gaps between
# these consecutive rows are set to zero.
function _stack_component_axes!(fig, traj::NT.NamedTrajectory, comps, row_offset::Int)
    n = length(comps)
    for (k, name) in enumerate(comps)
        row = row_offset + k
        ax = Axis(fig[row, 1]; xticklabelsvisible = k == n, xtickalign = 1)
        plotcomponents!(ax, traj, name)
        Legend(fig[row, 2], ax)
    end
    for k in 1:n-1
        rowgap!(fig.layout, row_offset + k, 0.0)
    end
    return fig
end

# Plot the components `comps` of `traj` as vertically stacked axes and return the
# figure.
function plot_trajectory(
    traj::NT.NamedTrajectory,
    comps::Union{AbstractVector{Symbol}, Tuple{Vararg{Symbol}}}=traj.names;
)
    fig = Figure()
    # NOTE(review): the same stack is drawn twice, the second copy directly below
    # the first — confirm this duplication is intentional.
    _stack_component_axes!(fig, traj, comps, 0)
    _stack_component_axes!(fig, traj, comps, length(comps))
    return fig
end

# Themed entry point: forward all arguments to `plot_trajectory` inside `with_theme`.
plot_trajectory(theme::MakieCore.Theme, args...; kwargs...) =
    with_theme(() -> plot_trajectory(args...; kwargs...), theme)

In [None]:
# Themed variant: a Theme as the first argument dispatches to the
# `plot_trajectory(theme::MakieCore.Theme, args...)` method, which wraps the call in
# `with_theme`.
# plot_trajectory(theme_minimal(), traj, [:x, :u])
# Unthemed: stack the :x and :u components.
plot_trajectory(traj, [:x, :u])


In [None]:
# y-limits for component `name`: a NamedTuple maps component names to limits (unlisted
# names get automatic limits); any other value is used for every axis unchanged.
_old_plot_ylims(ylims::NamedTuple, name::Symbol) = get(ylims, name, (nothing, nothing))
_old_plot_ylims(ylims, ::Symbol) = ylims

# Whether the `j`-th function of a transformation gets legend labels: a single Bool
# applies to every function, a vector holds one flag per function.
_old_plot_label_flag(flag::Bool, ::Integer) = flag
_old_plot_label_flag(flags::AbstractVector, j::Integer) = flags[j]

# Label stem for the `j`-th function of a transformation (`nothing` = use the default).
_old_plot_label_stem(::Nothing, ::Integer) = nothing
_old_plot_label_stem(stem::AbstractString, ::Integer) = stem
_old_plot_label_stem(stems::AbstractVector, j::Integer) = stems[j]

# Create the axis in row `row` of `fig`, draw `data` with `series!` against `ts`, and
# add a legend in column 2 when `labels` is not `nothing`.
function _old_plot_row!(
    fig, row::Integer, ts, data, title, labels, limits;
    titlesize, series_kwargs...
)
    ax = Axis(
        fig[row, 1];
        title=title,
        titlesize=titlesize,
        xlabel=L"t",
        limits=limits,
    )
    series!(ax, ts, data; labels=labels, series_kwargs...)
    isnothing(labels) || Legend(fig[row, 2], ax)
    return ax
end

"""
    old_plot(traj, comps=traj.names; kwargs...)

Plot the components `comps` of `traj` against time. Each entry of `transformations`
is drawn first, one row per function. Then each component gets its own row, or all
components share a single axis when `merge_components=true`. Legends go in column 2.

# Keywords
- `ignored_labels`: components drawn without labels or legend (a Symbol or a collection).
- `ignore_timestep`: skip the trajectory's timestep component when it is a Symbol.
- `merge_components`: draw all components on one shared axis with a single legend.
- `use_latex`: build titles and labels with `latexstring` instead of `string`.
- `transformations`: component name => function (or vector of functions), applied to
  each time column of that component with `mapslices(...; dims=1)`.
- `transformation_labels`: legend-label stems for transformed plots.
- `include_transformation_labels`: whether transformed plots get labels; either one
  Bool for everything, or one entry per transformation (a Bool, or a vector of Bools
  for a vector of functions).
- `transformation_titles`: axis titles for transformed plots. Names that are missing
  fall back to a generated title.
- `fig_size`, `titlesize`, `markersize`, `series_color`, `xlims`, `ylims`: styling.
  `ylims` may be a NamedTuple of per-component limits.
- Remaining `kwargs` are forwarded to every `series!` call.

# Throws
`ArgumentError` in any of these cases:
- an unknown component name;
- a non-function inside a vector transformation;
- a mismatched `include_transformation_labels` length;
- a NamedTuple `ylims` combined with `merge_components=true`.
"""
function old_plot(
    traj::NT.NamedTrajectory,
    comps::Union{Symbol, Vector{Symbol}, Tuple{Vararg{Symbol}}} = traj.names;

    # ---------------------------------------------------------------------------
    # component specification keyword arguments
    # ---------------------------------------------------------------------------

    ignored_labels::Union{Symbol, Vector{Symbol}, Tuple{Vararg{Symbol}}}=(),
    ignore_timestep::Bool=true,
    merge_components::Bool=false,
    use_latex::Bool=true,

    # ---------------------------------------------------------------------------
    # transformation keyword arguments
    # ---------------------------------------------------------------------------

    # transformations
    transformations::OrderedDict{Symbol, <:Union{Function, Vector}} =
        OrderedDict{Symbol, Union{Function, Vector{Function}}}(),

    # labels for transformed components
    transformation_labels::Union{
        Nothing,
        OrderedDict{
            Symbol,
            <:Union{
                Nothing,
                String,
                Vector{<:Union{Nothing, <:AbstractString}}
            }
        }
    } = OrderedDict{Symbol, Union{Nothing, Vector{Nothing}}}(
        name => f isa Vector ? fill(nothing, length(f)) : nothing
            for (name, f) ∈ transformations
    ),

    # whether or not to include labels for transformed components
    include_transformation_labels::Union{
        Bool,
        Vector{<:Union{Bool, <:AbstractVector{Bool}}}
    } = false,

    # titles for transformations
    transformation_titles::Union{
        Nothing,
        OrderedDict{Symbol, <:Union{<:AbstractString, <:AbstractVector{<:AbstractString}}}
    } = nothing,

    # ---------------------------------------------------------------------------
    # style keyword arguments
    # ---------------------------------------------------------------------------

    fig_size::Tuple{Int, Int}=(1200, 800),
    titlesize::Int=25,
    markersize=5,
    series_color::Symbol=:glasbey_bw_minc_20_n256,
    xlims::Union{Nothing, Tuple{Real, Real}}=nothing,
    ylims::Union{Nothing, NamedTuple, Tuple{Real, Real}}=nothing,

    # ---------------------------------------------------------------------------
    # CairoMakie.series! keyword arguments
    # ---------------------------------------------------------------------------
    kwargs...
)
    parse = use_latex ? latexstring : string

    # Normalize single symbols so `∈` and iteration work uniformly.
    if comps isa Symbol
        comps = [comps]
    end
    ignored = ignored_labels isa Symbol ? (ignored_labels,) : ignored_labels

    for key in comps
        key ∈ keys(traj.components) ||
            throw(ArgumentError("$(key) is not a component of the trajectory"))
    end
    for key in keys(transformations)
        key ∈ keys(traj.components) ||
            throw(ArgumentError("cannot transform $(key): not a component of the trajectory"))
    end

    # One label flag per transformation. A bare Bool would otherwise be zipped as a
    # one-element iterable and silently drop every transformation after the first.
    label_flags = if include_transformation_labels isa Bool
        fill(include_transformation_labels, length(transformations))
    else
        include_transformation_labels
    end
    if length(label_flags) != length(transformations)
        throw(ArgumentError(
            "include_transformation_labels has $(length(label_flags)) entries but " *
            "there are $(length(transformations)) transformations"
        ))
    end

    ts = NT.get_times(traj)
    fig = Figure(size=fig_size)
    row = 0

    # ---- transformed components (one row per transformation function) ----
    for ((name, f), flag) in zip(transformations, label_flags)
        stems = isnothing(transformation_labels) ? nothing :
            get(transformation_labels, name, nothing)
        has_title = !isnothing(transformation_titles) && haskey(transformation_titles, name)
        limits = (xlims, _old_plot_ylims(ylims, name))

        if f isa Vector
            all(fⱼ -> fⱼ isa Function, f) ||
                throw(ArgumentError("every transformation of $(name) must be a Function"))

            for (j, fⱼ) in enumerate(f)
                # apply fⱼ to each column (time step); each row of the result is a series
                transformed = mapslices(fⱼ, traj[name]; dims=1)
                n_series = size(transformed, 1)

                title = if has_title
                    transformation_titles[name][j]
                else
                    parse(name, "(t)", "\\text{ transformation } $j")
                end

                labels = if _old_plot_label_flag(flag, j)
                    stem = _old_plot_label_stem(stems, j)
                    isnothing(stem) ? string.(1:n_series) :
                        [parse(stem, "_{$i}") for i in 1:n_series]
                else
                    nothing
                end

                row += 1
                _old_plot_row!(
                    fig, row, ts, transformed, title, labels, limits;
                    titlesize=titlesize, color=series_color, markersize=markersize,
                    kwargs...
                )
            end
        else
            transformed = mapslices(f, traj[name]; dims=1)

            title = if has_title
                transformation_titles[name]
            else
                parse(name, "(t)", "\\text{ transformation}")
            end

            labels = if _old_plot_label_flag(flag, 1)
                stem = something(_old_plot_label_stem(stems, 1), name)
                [parse(stem, "_{$i}") for i in 1:size(transformed, 1)]
            else
                nothing
            end

            row += 1
            _old_plot_row!(
                fig, row, ts, transformed, title, labels, limits;
                titlesize=titlesize, color=series_color, markersize=markersize,
                kwargs...
            )
        end
    end

    # ---- normal components ----
    shared_ax = nothing
    if merge_components
        if ylims isa NamedTuple
            throw(ArgumentError("plot_ylims must be a tuple if merge_components is true"))
        end
        shared_ax = Axis(
            fig[row + 1, 1];
            title=parse("Trajectory"),
            titlesize=titlesize,
            xlabel=L"t",
            limits=(xlims, ylims),
        )
    end

    any_labelled = false
    for name in comps
        if ignore_timestep && traj.timestep isa Symbol && name == traj.timestep
            continue
        end

        data = traj[name]
        labels = name ∈ ignored ? nothing :
            [parse(name, "_{$i}") for i in 1:size(data, 1)]

        if merge_components
            series!(
                shared_ax, ts, data;
                color=series_color, markersize=markersize, labels=labels, kwargs...
            )
            any_labelled |= !isnothing(labels)
        else
            row += 1
            _old_plot_row!(
                fig, row, ts, data, parse(name, "(t)"), labels,
                (xlims, _old_plot_ylims(ylims, name));
                titlesize=titlesize, color=series_color, markersize=markersize,
                kwargs...
            )
        end
    end

    # A single legend for the shared axis (instead of one per component in one cell).
    if merge_components && any_labelled
        Legend(fig[row + 1, 2], shared_ax)
    end

    return fig
end