In [None]:
"""Parse every unique task in the datasets and save the parses for later analyses.

This notebook:
(1) computes parses for each unique task in each dataset and saves them
    (together with the parameters used) next to the dataset;
(2) loads those saved parses, converts them to motor programs (MPs), then saves those.
"""


In [1]:
# Load datasets
from pythonlib.dataset.dataset import Dataset
from pythonlib.drawmodel.parsing import *
import numpy as np
import pandas as pd

In [None]:
# Parser settings. Active choice: the "OK" version, a good balance between
# speed and variety of parses.
params_parse = dict(
    configs_per=10,
    trials_per=50,
    max_ntrials=75,
    max_nwalk=75,
    max_nstroke=100,
)
# Passed as `k` to get_parses_from_strokes (appears to cap the number of parses
# kept per task).
kparses = 50

# Alternatives kept for reference (swap in to trade parse variety for speed):
#
# FAST:
#   params_parse = dict(configs_per=5, trials_per=20, max_ntrials=50,
#                       max_nwalk=50, max_nstroke=50)
#
# VERY FAST (NOTE: this is the version originally used):
#   params_parse = dict(configs_per=5, trials_per=10, max_ntrials=25,
#                       max_nwalk=25, max_nstroke=25)
#   kparses = 10

# Passed through to get_parses_from_strokes; presumably makes it return parses
# in the task's own stroke coordinates - confirm in pythonlib.drawmodel.parsing.
return_in_strokes_coords = True


In [2]:
# Datasets to parse: every (animal, expt) combination below is loaded with
# Dataset.load_dataset_helper by the parsing loop below.
#
# Database folders that were previously listed by hand (for reference):
#   Red:    /data2/analyses/database/Red-lines5-formodeling-210329_005719
#           /data2/analyses/database/Red-arc2-formodeling-210329_005550
#           /data2/analyses/database/Red-shapes3-formodeling-210329_005200
#           /data2/analyses/database/Red-figures89-formodeling-210329_005443
#   Pancho: /data2/analyses/database/Pancho-lines5-formodeling-210329_014835
#           /data2/analyses/database/Pancho-arc2-formodeling-210329_014648
#           /data2/analyses/database/Pancho-shapes3-formodeling-210329_002448
#           /data2/analyses/database/Pancho-figures89-formodeling-210329_000418
expt_list = ["lines5"]
animal_list = ["Pancho"]

In [None]:
import os
import pickle

# Parsing / scoring options, recorded in the saved params file.
use_extra_junctions = True
score_ver = "travel"
score_norm = "negative"
# NOTE(review): score_ver / score_norm are saved with the parses but are not
# passed to get_parses_from_strokes below - confirm which scorer it actually uses.

# Set True to parse and plot one random trial per dataset (sanity check).
DEBUG_PARSE_RANDOM_TRIAL = False


def parse_task_strokes(strokes, canvas_max_WH, plot=False):
    """Parse one task's strokes with the settings from the parameter cell.

    Single place for the get_parses_from_strokes arguments, so the debug path
    and the main loop are guaranteed to use identical settings.

    Args:
        strokes: the task strokes of one row (D.Dat["strokes_task"] entry).
        canvas_max_WH: largest absolute sketchpad-edge coordinate of the dataset.
        plot: forwarded to get_parses_from_strokes.

    Returns:
        (parses, log_probs) exactly as returned by get_parses_from_strokes.
    """
    return get_parses_from_strokes(
        strokes, canvas_max_WH,
        use_extra_junctions=use_extra_junctions,
        plot=plot,
        return_in_strokes_coords=return_in_strokes_coords,
        k=kparses,
        configs_per=params_parse["configs_per"],
        trials_per=params_parse["trials_per"],
        max_ntrials=params_parse["max_ntrials"],
        max_nstroke=params_parse["max_nstroke"],
        max_nwalk=params_parse["max_nwalk"],
    )


for expt in expt_list:
    for animal in animal_list:

        # Load single dataset. No need to preprocess: any preprocessing can be
        # applied afterwards to both task and behavior.
        D = Dataset([])
        D.load_dataset_helper(animal, expt)

        # Largest absolute sketchpad-edge coordinate across all metadata entries.
        canvas_max_WH = max(
            np.max(np.abs(v["sketchpad_edges"])) for v in D.Metadats.values()
        )

        if DEBUG_PARSE_RANDOM_TRIAL:
            import random
            ind = random.randrange(len(D.Dat))
            parses, log_probs = parse_task_strokes(
                D.Dat["strokes_task"].values[ind], canvas_max_WH, plot=True)

        # Params saved next to the parses. Built as a copy so the config dict
        # from the parameter cell is not mutated from one dataset to the next.
        params_save = {
            **params_parse,
            "canvas_max_WH": canvas_max_WH,
            "use_extra_junctions": use_extra_junctions,
            "return_in_strokes_coords": return_in_strokes_coords,
            "score_ver": score_ver,
            "score_norm": score_norm,
            "kparses": kparses,  # k passed to get_parses_from_strokes
        }

        # Parse each unique task only once, using the first row where it occurs
        # (tasks processed in sorted order of their name).
        first_rows = (
            D.Dat.drop_duplicates(subset="unique_task_name", keep="first")
            .sort_values("unique_task_name")
        )
        records = []
        for i, (task, strokes) in enumerate(
                zip(first_rows["unique_task_name"], first_rows["strokes_task"])):
            print(i, "-", task)
            parses, log_probs = parse_task_strokes(strokes, canvas_max_WH)
            assert len(parses) > 0, f"get_parses_from_strokes returned no parses for task {task}"
            records.append(
                {"strokes_task": strokes,
                 "unique_task_name": task,
                 "parses": parses,
                 "parses_log_probs": log_probs}
            )

        # === Save as dataframe (inside the loop, so EVERY dataset is saved,
        # not just the last one processed).
        paththis = f"{D.Metadats[0]['path']}/parses_good"
        os.makedirs(paththis, exist_ok=True)

        fname_parse = f"{paththis}/parses.pkl"
        print("Saving at:")
        print(fname_parse)

        parses_df = pd.DataFrame(records)
        parses_df.to_pickle(fname_parse)

        fname_params = f"{paththis}/parses_params.pkl"
        with open(fname_params, "wb") as f:
            pickle.dump(params_save, f)

## CONVERT SAVED PARSES TO MOTOR PROGRAMS

In [3]:
##### Load dataset (use updated method, helper function), then convert its
##### saved parses to motor programs.
import random

expt = "lines5"

# Debug switches (both off for the production run).
PLOT_EXAMPLE_TRIAL = False  # plot task / behavior / chosen parse for one random trial
FIT_EXAMPLE_TRIAL = False   # fit MPs to one trial's parses; NOTE: also recenters D (see below)

for animal in ["Pancho"]:
    D = Dataset([])
    D.load_dataset_helper(animal, expt)

    # Load pre-extracted parses (presumably the parses_good/ files written by
    # the parsing cell above - confirm in Dataset.parsesLoadAndExtract).
    D.parsesLoadAndExtract(parses_good=True)

    ind = None  # random example row, shared by both debug branches

    # Plot parse and behavior for comparison.
    if PLOT_EXAMPLE_TRIAL:
        D.parsesChooseSingle()
        # randrange excludes the upper bound; randint(0, len) could go out of range.
        ind = random.randrange(len(D.Dat))
        D.plotSingleTrial(ind, ["task", "beh", "parse"], sharex=True, sharey=True);

    if FIT_EXAMPLE_TRIAL:
        # Fit motor program to parse (i.e, to arbitrary strokes/splines).
        from pythonlib.bpl.strokesToProgram import infer_MPs_from_strokes

        if ind is None:
            ind = random.randrange(len(D.Dat))
        strokes_list = D.Dat["parses"].values[ind]
        print(len(strokes_list))

        sketchpad_edges = D.Metadats[0]["sketchpad_edges"].T
        MPlist, score_all = infer_MPs_from_strokes(strokes_list, [0, 1], {}, sketchpad_edges)

        # Leftover scratch column; only delete it when present.
        if "tmp3" in D.Dat.columns:
            del D.Dat["tmp3"]

        # NOTE(review): this recenters D in place BEFORE the motor programs
        # below are extracted and saved - confirm that is intended.
        D.recenter(method="each_beh_center", apply_to="all")

    # Fit motor programs to all parses (and then save).
    D.bpl_extract_and_save_motorprograms_parses(parses_good=True)

Did not load data!!!
Searching using this string:
/data2/analyses/database//*Pancho*lines5*/*dat*.pkl
-- Splitting off dir from fname
Found this many paths:
1
---
/data2/analyses/database/Pancho-lines5-formodeling-210329_014835
Searching using this string:
/data2/analyses/database/BEH/*Pancho*lines5*/*dat*.pkl
-- Splitting off dir from fname
Found this many paths:
0
----------------
Currently loading: /data2/analyses/database/Pancho-lines5-formodeling-210329_014835
Loaded metadat:
{'sketchpad_edges': array([[-327.2, -327.2],
       [ 327.2,  429.6]])}
----
Resetting index
=== CLEANING UP self.Dat ===== 
ORIGINAL: online abort values
Series([], Name: online_abort, dtype: int64)
kept 6949 out of 6949
removed all cases with online abort not None
Deleted unused columns from self.Dat
num parses -- num cases
50    5623
20     315
48     218
44     195
47     142
49     121
45      94
42      73
41      49
35      41
40      24
46      16
43       8
39       6
11       4
38       3
33       2

Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving che

Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving che

Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving che

Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving che

Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving che

Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving che

Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving che

### SCRATCH

In [None]:
from gns.inference.parsing.top_k import search_parse
import torch

# Smoke test of search_parse on random data: 15 random candidate "parses"
# (each a 10 x 2 tensor) scored by a random score function.
# Uses its own names so it does not overwrite the notebook's `parses` /
# `score_fn` that other cells rely on.
torch.manual_seed(0)  # reproducible random parses and scores
candidate_parses = [torch.rand(10, 2) for _ in range(15)]


def random_score_fn(x):
    """Return one random score per candidate in ``x`` (stand-in for a real scorer)."""
    return torch.rand((len(x),))


top_parses, top_probs = search_parse(candidate_parses, random_score_fn,
                                     configs_per=10, trials_per=20)
print(len(top_parses))
print(top_probs)
print(top_parses[0])

### TESTING: (1) how the parse parameters affect processing time and results (parse diversity and quality), and (2) whether parses overlaid after conversion back to the original coordinates line up accurately with the task

In [None]:
# Parse one task with k=20 and plot, WITHOUT converting back to stroke
# coordinates (return_in_strokes_coords=False, unlike the main parsing loop).
# NOTE(review): `strokes` and `canvas_max_WH` are whatever the parsing loop
# left behind (the last task parsed) - hidden kernel state.
k = 20
parses, log_probs_k = get_parses_from_strokes(
    strokes,
    canvas_max_WH,
    use_extra_junctions=use_extra_junctions,
    plot=True,
    return_in_strokes_coords=False,
    k=k,
)


In [None]:
# Summarize the parses produced by the k-test cell above.
summarizeParses(parses)

In [None]:
# NOTE(review): identical to the previous cell - one of the two can be deleted.
summarizeParses(parses)

In [None]:
# === PLOT RESULTS, OVERLAY WITH ORIGINAL
import matplotlib.pyplot as plt
from pythonlib.drawmodel.strokePlots import plotDatStrokes

# Show the first few parses, each on its own axes on top of the original task
# strokes, so every parse can be checked against the task it came from.
# NOTE(review): the overlay only lines up if `parses` were returned in stroke
# coordinates (return_in_strokes_coords=True); the k-test cell uses False.
n_show = min(5, len(parses))
assert n_show > 0, "no parses to plot"

fig, axes = plt.subplots(1, n_show, figsize=(4 * n_show, 4), squeeze=False)
for i, (ax, strokes_out) in enumerate(zip(axes[0], parses[:n_show])):
    plotDatStrokes(strokes, ax=ax)  # original task strokes
    for s in strokes_out:
        ax.plot(s[:, 0], s[:, 1], 'xr')  # parse points
    ax.set_title(f"parse {i}")

