# TODO List

1. Run the `conditional/record_script.py` script with `num_episodes=100`. Also try setting `sample_params=False` and check whether this still produces randomized positions; if it does, use that setting.
2. In `shape_sorting_task.py`, adjust the `shape_center` and `shape_d` variables so that the objects are nearly always visible. The initial-frame view in this notebook can be used to check this.
3. For each episode, apply the conditional servoing selection function to find the closest demonstration. Plot these so that it is possible to scroll through the closest demonstrations.


In [None]:
import os
import numpy as np
import logging
import copy
import json
from glob import glob
from tqdm import tqdm

from flow_control.recombination.record_multi import get_configurations
from flow_control.servoing.playback_env_servo import PlaybackEnvServo

root_dir = "/tmp/flow_experiments3"

# Gather the save directory of every "demo" configuration. Each configuration
# unpacks as a 5-tuple whose last field is the recording directory.
demo_cfgs = get_configurations(prefix="demo")
recordings = [demo_dir for _, _, _, _, demo_dir in demo_cfgs]
print("number of recordings:", len(recordings))
print(recordings)

In [None]:
# One PlaybackEnvServo per recorded demonstration directory.
playbacks = list(map(PlaybackEnvServo, recordings))

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')
from ipywidgets import widgets, interact, Layout
import matplotlib.pyplot as plt

# Single-axes figure; `image_h` is the image artist that `update` refreshes
# in place. It starts out showing the image of the first recording.
fig, ax = plt.subplots(nrows=1)
ax.set_axis_off()
fig.suptitle("Initial Frame")
image_h = ax.imshow(playbacks[0].cam.get_image()[0])

def update(w):
    """Display recording ``w`` in the figure and report its foreground share.

    Used as the ``interact`` callback of the recording slider.

    Args:
        w: index into ``playbacks``.
    """
    playback = playbacks[w]
    # Element 0 of get_image() is what gets shown (presumably the RGB image
    # -- TODO confirm against the camera API).
    image_h.set_data(playback.cam.get_image()[0])
    fig.canvas.draw_idle()

    fg_mask = playback.get_fg_mask()
    if fg_mask is None:
        # Some recordings have no foreground mask; nothing to report.
        return
    print("percent fg:", np.mean(fg_mask)*100)

    
    
# Slider over recording indices; moving it calls `update` with the new index.
slider_w = widgets.IntSlider(
    min=0,
    max=len(playbacks) - 1,
    step=1,
    value=0,
    layout=Layout(width='70%'),
)
interact(update, w=slider_w)


In [None]:
# Per-episode summary: index, len() of the playback (presumably the number of
# recorded steps) and the "rew" value stored in its final step.
print("ep\t len\t rew")
for ep_idx, playback in enumerate(playbacks):
    print(ep_idx, len(playback), playback[-1].data["rew"], sep=" \t ")

In [None]:
# Last field of every configuration tuple, i.e. its save directory.
# NOTE(review): the other calls pass `prefix=`; confirm that `root_dir` really
# is the first positional parameter of get_configurations.
files = [cfg[-1] for cfg in get_configurations(root_dir)]
print(files)

In [None]:
from pathlib import Path
import json
import numpy as np

def hacky_get_reward(save_dir, demo_seed):
    """Return the reward stored in the last frame file of one servoing run.

    The run's frames are read from ``f"{save_dir}_{demo_seed:03d}"``. Each
    frame is a ``frame_*.npz`` archive containing a ``"rew"`` entry.

    Args:
        save_dir: base save directory of the run configuration.
        demo_seed: seed of the demonstration; appended as a zero-padded suffix.

    Returns:
        The ``"rew"`` value of the last frame, as a Python scalar.

    Raises:
        FileNotFoundError: if the run directory contains no frame files.
    """
    run_dir = f"{save_dir}_{demo_seed:03d}"
    # Lexicographic sort; assumes frame indices are zero-padded so that the
    # last name is the final frame -- TODO confirm the recorder's naming.
    frame_names = sorted(glob(f"{run_dir}/frame_*.npz"))
    if not frame_names:
        raise FileNotFoundError(f"no frame_*.npz files found in {run_dir}")
    # np.load on an .npz returns an NpzFile holding an open file handle;
    # close it deterministically instead of relying on garbage collection.
    with np.load(frame_names[-1]) as frame:
        return frame["rew"].item()
        
    
# Reward matrix: row i holds the final reward of every "run" configuration
# evaluated under the seed of demo i (read via hacky_get_reward).
row_rewards = []

demo_cfgs = get_configurations(prefix="demo")
# The run configurations do not depend on the demo, so fetch them once rather
# than once per demo. `list` keeps this correct even if get_configurations
# returns a one-shot iterator.
run_cfgs = list(get_configurations(prefix="run"))

for _, _, _, demo_seed, demo_dir in demo_cfgs:
    rewards = [
        hacky_get_reward(save_dir, demo_seed)
        for _, _, _, _, save_dir in run_cfgs
    ]
    row_rewards.append(rewards)
    print(len(rewards))

print("done!")
# Rows index demos, columns index runs (see how row_rewards is built).
tmp = np.array(row_rewards)

print(tmp)
print(np.mean(tmp))          # mean over all (demo, run) pairs
print(np.diag(tmp))          # entries where run index == demo index
print(np.mean(tmp, axis=1))  # mean reward per demo


def _fg_percent(pb):
    """Return the foreground-pixel percentage of ``pb``, or NaN without a mask."""
    # get_fg_mask() may return None (update() guards for this too);
    # np.mean(None) would raise TypeError.
    fg_mask = pb.get_fg_mask()
    return np.nan if fg_mask is None else np.mean(fg_mask) * 100


fg_pixels = [_fg_percent(pb) for pb in playbacks]
for i, fg in enumerate(fg_pixels):
    print(i, round(fg, 2))

# One point per (demo, run) pair. The run axis uses the FG share of the
# playback with the same index as the run -- assumes run i and demo i share
# the same scene; TODO confirm.
X, Y, c = [], [], []
for demo_idx, run_idx in np.ndindex(tmp.shape):
    X.append(fg_pixels[demo_idx])
    Y.append(fg_pixels[run_idx])
    c.append(tmp[demo_idx, run_idx])

fig, ax = plt.subplots()
points = ax.scatter(X, Y, c=c)
fig.colorbar(points, ax=ax, label="reward")
ax.set_xlabel("demo %FG pixels")
ax.set_ylabel("run %FG pixels")
plt.show()

In [None]:
# Scratch check: mean of hand-copied reward values (source not shown here --
# presumably one row of the printed reward matrix; verify).
np.array([0.2, 0.4, 0.0, 0.3, 0.2, 0.2, 0.3, 0.2, 0.3, 0.4]).mean()