# Imitation Minimal Viable Program

This notebook tests an imitation system using other demonstrations as pseudo live views.


## 1. Load Existing Demonstrations

In [None]:
from pathlib import Path

import numpy as np
from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from ipywidgets import widgets, interact, Layout
from flow_control.demo.simple_loader_nick import SimpleLoader

# Directory holding the recordings of view_1; each run lives in its own
# sub-directory addressed by an integer index.
root_dir = Path("/home/argusm/CLUSTER/robot_recordings/hand_recordings/view_1/")
# Count only sub-directories: a stray file next to the run folders (log,
# .DS_Store, ...) must not be mistaken for an additional run.
num_runs = sum(1 for entry in root_dir.iterdir() if entry.is_dir())
# One loader per run, in run-index order.
loaders = [SimpleLoader(root_dir, run=r) for r in range(num_runs)]

### 1.1 Visualize Demonstration

In [None]:
%matplotlib notebook
fig, ax = plt.subplots(1,figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
image_h = ax.imshow(loaders[0].get_image(0))

def update(demo_index, frame_index):
    demo_len = loaders[demo_index].get_len()
    if frame_index >= demo_len:
        print(f"invalid frame index: {frame_index}, demo length: {demo_len}")
        frame_index = demo_len -1
    image = loaders[demo_index].get_image(frame_index)
    image_h.set_data(image)
    fig.canvas.draw_idle()
    
slider_w = widgets.IntSlider(min=0, max=num_runs-1, step=1, value=0,
                             layout=Layout(width='70%'))
slider_i = widgets.IntSlider(min=0, max=200-1, step=1, value=0,
                             layout=Layout(width='70%'))

interact(update, demo_index=slider_w, frame_index=slider_i)

In [None]:
# Annotate data: clip demonstrations to frames where manipulation happens
# Keys are demo indices, values are (first_frame, last_frame) pairs.
motion_bounds = dict(enumerate([
    (25, 52),
    (30, 120),
    (38, 100),
    (36, 71),
]))

# 2. Find Object Trajectories (2D)

## 2.1 Find Object Segmentation Mask (using motion between demos)

This finds the object segmentation for the initial frame using the motion of the object between different demos. This works if the camera position is static. Instead of this we could also use:
1. Manual annotation
2. Hands23 (Nick is planning to do this)
3. Segment-Anything

In [None]:
# Load Servoing Module
from flow_control.flow.module_raft import FlowModule
from flow_control.flow.flow_plot import FlowPlot

# Optical-flow estimator shared by the cells below (used through
# flow_module.step and flow_module.warp_mask).
# NOTE(review): `size` is presumably the (width, height) the module works at
# - confirm against FlowModule.
flow_module = FlowModule(size=(640,480))

In [None]:
# Pair demo `live_index` with the next demo (wrapping around at the end).
live_index = 0
demo_index = (live_index + 1) % len(loaders)

# First frame of both recordings.
live_rgb, demo_rgb = (loaders[idx].get_image(0) for idx in (live_index, demo_index))

# Pixels whose flow magnitude exceeds the threshold are taken as the object.
flow_threshold = 5
flow = flow_module.step(live_rgb, demo_rgb)
fn = np.linalg.norm(flow, axis=2)
object_segmentation = np.greater(fn, flow_threshold)

In [None]:
# plot flow image: live frame, demo frame and the rendered flow, side by side
fp = FlowPlot()
flow_image = fp.compute_image(flow)
fig, ax = plt.subplots(1,3, figsize=(24, 6))
for panel, picture in zip(ax, (live_rgb, demo_rgb, flow_image)):
    panel.set_axis_off()
    panel.imshow(picture)
plt.show()

## 2.2 Find Object Keypoints

Currently just use the center of object segmentation.

Find the keypoints for all demonstrations by using optical flow to warp between demonstrations.

In [None]:
# (row, col) coordinates of every pixel inside the segmentation mask.
mask_pixels = np.argwhere(object_segmentation)
if mask_pixels.size == 0:
    # Without this guard the mean is NaN and int() fails with a cryptic error.
    raise ValueError(
        f"object segmentation is empty: no pixel exceeds flow_threshold={flow_threshold}"
    )
# Mean position in (row, col) order, kept as a float array.
centroid = np.mean(mask_pixels, axis=0)
# Integer pixel position in image (x, y) order: x is the column, y is the row.
centroid_x, centroid_y = int(centroid[1]), int(centroid[0])
print(f"centroid for demo {live_index} is ({centroid_x}, {centroid_y})")

In [None]:
# plot centroid on top of the thresholded flow magnitude
mask_image = fn > flow_threshold
mask_artist = plt.imshow(mask_image)
# flow vector at the centroid, then the centroid itself
print(flow[centroid_y,centroid_x])
print(centroid_x, centroid_y)
plt.scatter(centroid_x, centroid_y, marker='x', color='red')
plt.show()

In [None]:
# Reference keypoint in integer (x, y) pixel order, the same convention as
# the entries appended below. (`centroid` itself is a float (row, col) array
# and can neither be used as an index nor compared against an (x, y) tuple.)
reference_xy = (centroid_x, centroid_y)

# Soft sanity check: the expected value belongs to one specific recording
# and threshold, so report a deviation instead of aborting the notebook.
expected_xy = (411, 181)
if reference_xy != expected_xy:
    print(f"note: centroid {reference_xy} differs from reference value {expected_xy}")

warped_centroids = [reference_xy]

plot_warped_centroids = False
# The reference frame is the same for every pair, so load it once.
live_rgb = loaders[0].get_image(0)
for other in range(1, len(loaders)):
    demo_rgb = loaders[other].get_image(0)
    flow = flow_module.step(live_rgb, demo_rgb)

    # flow is indexed [row, col] = [y, x]; the vector holds the (x, y) offset.
    change = flow[centroid_y, centroid_x]
    centroid_x_n, centroid_y_n = int(round(centroid_x+change[0])), int(round(centroid_y+change[1]))
    warped_centroids.append((centroid_x_n, centroid_y_n))

    if plot_warped_centroids:
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))
        for panel in ax:
            panel.set_axis_off()
        ax[0].imshow(live_rgb)
        ax[1].imshow(demo_rgb)
        ax[0].scatter(centroid_x, centroid_y, marker='x', color='red')
        ax[1].scatter(centroid_x_n, centroid_y_n, marker='x', color='red')
        plt.show()

print("warped centroids", warped_centroids)
# should be roughly warped_centroids = [(411, 181), (387, 203), (395, 198), (381, 196)]

## 2.3 Track Keypoint Trajectories

by following the optical flow between demonstration frames.

In [None]:
def get_trajectory(demo_index, centroid, frames):
    """Track one keypoint through a demonstration by chaining optical flow.

    For each consecutive pair in `frames` the flow between the two images is
    computed and the keypoint is moved by the flow vector found at its
    current position.

    Args:
        demo_index: index into the global `loaders` list.
        centroid: start position as an (x, y) pixel pair.
        frames: sequence of frame indices to step through, in order.

    Returns:
        Integer array of shape (len(frames), 2) with one (x, y) position per
        frame; just the start position if `frames` has fewer than two entries.
    """
    centroid_x, centroid_y = (int(round(float(c))) for c in centroid)
    centroid_list = [(centroid_x, centroid_y)]
    for frame_a, frame_b in zip(frames[:-1], frames[1:]):
        image_a = loaders[demo_index].get_image(frame_a)
        image_b = loaders[demo_index].get_image(frame_b)
        flow = flow_module.step(image_a, image_b)

        # flow is indexed [row, col] = [y, x]; the vector holds the (x, y) offset.
        change = flow[centroid_y, centroid_x]
        centroid_x = int(round(centroid_x + change[0]))
        centroid_y = int(round(centroid_y + change[1]))

        # Keep the keypoint inside the image: otherwise a drifting track
        # raises IndexError on the next step or, for negative values,
        # silently reads flow from the opposite image border.
        height, width = flow.shape[:2]
        centroid_x = min(max(centroid_x, 0), width - 1)
        centroid_y = min(max(centroid_y, 0), height - 1)
        centroid_list.append((centroid_x, centroid_y))

    return np.array(centroid_list)

# Keypoint trajectory for every demo, sampled at 10 evenly spaced frames
# inside the annotated motion window.
trajectories = {}
for demo_index in range(len(loaders)):
    centroid = warped_centroids[demo_index]
    first_frame, last_frame = motion_bounds[demo_index]
    frames = np.linspace(first_frame, last_frame, 10).astype(int)
    trajectories[demo_index] = trajectory = get_trajectory(demo_index, centroid, frames)
    print("done with demo_index", demo_index)

In [None]:
# plot live and demo trajectories with smoothing
%matplotlib inline
from scipy import interpolate

def smooth_line(arr, samples=100):
    """Fit an interpolating spline through 2D points and resample it.

    Args:
        arr: sequence of (x, y) points, e.g. an (N, 2) array.
        samples: number of points to evaluate along the spline.

    Returns:
        Float array of shape (samples, 2), evaluated at evenly spaced spline
        parameter values from 0 to 1 (first to last input point).

    Raises:
        ValueError: if `arr` is empty or not a sequence of 2D points.
    """
    pts = np.asarray(arr, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2 or len(pts) == 0:
        raise ValueError("smooth_line expects a non-empty sequence of (x, y) points")

    # splprep rejects repeated consecutive points (zero-length segments),
    # which integer-rounded tracks produce whenever the object rests.
    keep = np.ones(len(pts), dtype=bool)
    keep[1:] = np.any(np.diff(pts, axis=0) != 0, axis=1)
    pts = pts[keep]

    if len(pts) == 1:
        # A stationary track: every sample is that single point.
        return np.repeat(pts, samples, axis=0)

    # The spline degree must stay below the number of points (cubic needs 4).
    degree = min(3, len(pts) - 1)
    tck, _ = interpolate.splprep([pts[:, 0], pts[:, 1]], s=0, k=degree)
    xint, yint = interpolate.splev(np.linspace(0, 1, samples), tck)
    return np.stack((xint, yint), axis=1)

demo_index = 0
trajectory = trajectories[demo_index]

fig, ax = plt.subplots(1)
ax.set_axis_off()
# Background: first frame of the annotated motion window of the chosen demo.
image_a = loaders[demo_index].get_image(motion_bounds[demo_index][0])
ax.imshow(image_a)
smoothed = smooth_line(trajectory)
ax.plot(smoothed[:, 0], smoothed[:, 1], color='lime')

# Trajectories of the other demos, shifted to start at the chosen demo's
# start point. The list was previously created but never filled.
other_trajectories = []
for i in range(len(loaders)):
    if i == demo_index:
        continue
    trj = trajectories[i].copy()
    trj += trajectory[0] - trj[0]
    other_trajectories.append(trj)
    trj_smooth = smooth_line(trj)
    ax.plot(trj_smooth[:, 0], trj_smooth[:, 1], marker='', color='blue')

other_trajectories = np.array(other_trajectories)
plt.show()

In [None]:
# plot the live trajectory on top of a blend of three frames
demo_index = 0
trajectory = trajectories[demo_index]
frames = 12, 37, 52
# Take the frames from the same demo the trajectory belongs to (this was
# hard-coded to loaders[0]) and average them into a single image.
frame_stack = [loaders[demo_index].get_image(f) for f in frames]
blended = np.mean(frame_stack, axis=0).round().astype(np.uint8)
fig_frames = Image.fromarray(blended)
print(fig_frames.size)
# Shrink to fit within 640x480 and rescale the trajectory by the same factor.
fig_frames_small = ImageOps.contain(fig_frames, (640,480))
rel_size = np.array(fig_frames.size) / np.array(fig_frames_small.size)
plt.imshow(fig_frames_small)
plt.plot(trajectory[:, 0]/rel_size[0], trajectory[:, 1]/rel_size[1],marker='.', c='lime')
plt.axis('off')
plt.show()

## 2.4 Model Demo Trajectories (2D)

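Create a 2D model of the demo distributions by fitting a Gaussian for each point.
Each trajectory is shifted to a common start point, smoothed with a spline, and resampled at a fixed number of evenly spaced spline-parameter values; a Gaussian is then fitted across the demos at each resampled point.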

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import scipy.optimize as opt
import numpy as np
from scipy import interpolate
import scipy.stats as st

samples = 8  # number of gauss curves
gauss_n = 50  # margin (pixels) around the sample points for the pdf grid
min_variance = 1.0  # variance floor (pixels^2), see below

# start plot
fig, ax = plt.subplots(1)
ax.set_axis_off()
image_a = loaders[demo_index].get_image(motion_bounds[demo_index][0])
ax.imshow(image_a)
# Do not call this `st`: that name is the scipy.stats module imported above
# and is needed for the Gaussian further down.
smoothed = smooth_line(trajectory)
ax.plot(smoothed[:, 0], smoothed[:, 1], color='lime')

# Shift every trajectory to the start of the displayed one and resample all
# of them at the same `samples` spline positions.
sample_list = []
for i in trajectories:
    trj = trajectories[i].copy()
    trj += trajectory[0] - trj[0]
    sample_list.append(smooth_line(trj, samples=samples))
sample_list = np.array(sample_list)  # (num_demos, samples, 2)

for s in range(samples):
    points = sample_list[:, s, :]
    xmin, xmax = max(points[:, 0].min()-gauss_n, 0), points[:, 0].max()+gauss_n
    ymin, ymax = max(points[:, 1].min()-gauss_n, 0), points[:, 1].max()+gauss_n
    mean = np.mean(points, axis=0)
    cov = np.cov(points, rowvar=0)
    cov *= 3  # looks nicer
    # All trajectories share the same start point, so the first covariance is
    # (numerically) zero; the floor keeps the Gaussian well defined there.
    cov = cov + min_variance * np.eye(2)
    grid_x, grid_y = np.mgrid[xmin:xmax:25j, ymin:ymax:25j]
    # `multivariate_normal` was referenced without ever being imported.
    rv = st.multivariate_normal(mean, cov)
    z = rv.pdf(np.dstack((grid_x, grid_y)))
    ax.contour(grid_x, grid_y, z, levels=3, cmap='jet')

plt.show()

## 2.5 Find 3D Trajectories

...maybe start by just using depth observation.


# 3. Find Object Masks

Here we try to find the segmentation of a live image based on a demo image segmentation.

We only have to do this for the initial images, as it is only needed for grasp computation.

Here we use optical flow; several alternative strategies are possible:
1. SIFT features
2. Segment-Anything + Dino features

## 3.1 Same Perspective

In [None]:
# demo_index must be the demo that `object_segmentation` was computed for
# (section 2.1 computed it with live_index = 0); live_index is the demo
# the mask gets warped onto.
demo_index = 0
live_index = 1

def plot_mask_edge(mask, image):
    """Paint the outline of `mask` onto `image` in green.

    `image` is modified in place and nothing is returned. A pixel counts as
    outline when the mask gradient along either axis is non-zero there.
    """
    grad_rows, grad_cols = np.gradient(mask.astype(float))
    outline = (grad_rows != 0) | (grad_cols != 0)
    image[outline] = (0, 255, 0)
    
# Demo side: first frame plus a copy of its mask with a trailing axis added.
demo_rgb = loaders[demo_index].get_image(0)
demo_seg = np.expand_dims(object_segmentation.copy(), axis=2)
print(demo_seg.shape)

# Live side: first frame, and the demo mask warped onto it.
live_rgb = loaders[live_index].get_image(0)
live_seg = flow_module.warp_mask(demo_seg, demo_rgb, live_rgb)

In [None]:
# clean the image up a bit
from skimage.morphology import dilation, erosion
# Erode the warped mask first, then dilate the eroded result.
eroded_seg = erosion(live_seg)
live_seg_2 = dilation(eroded_seg)

In [None]:
fig, ax = plt.subplots(1,2, figsize=(12, 6))

# Draw on copies: plot_mask_edge paints into the image it is given.
demo_rgb_plot, live_rgb_plot = np.copy(demo_rgb), np.copy(live_rgb)
overlays = ((demo_seg, demo_rgb_plot), (live_seg_2, live_rgb_plot))
for panel, (seg, canvas) in zip(ax, overlays):
    panel.set_axis_off()
    plot_mask_edge(seg[:, :, 0], canvas)
    panel.imshow(canvas)
plt.show()

## 3.2 Different Perspective

Perspective `view_2` recording `0` seems to have poor depth values. Use recording `1` instead.

In [None]:
# Directory holding the recordings of view_2.
root2_dir = Path("/home/argusm/CLUSTER/robot_recordings/hand_recordings/view_2/")
live_index = 1

# Count only sub-directories so stray files are not mistaken for runs.
num_runs2 = sum(1 for entry in root2_dir.iterdir() if entry.is_dir())
loaders2 = [SimpleLoader(root2_dir, run=r) for r in range(num_runs2)]
# Copy of the demo mask with a trailing axis added.
demo_seg = object_segmentation.copy()
demo_seg = demo_seg[:, :, np.newaxis]

# Warp the demo mask onto the first frame of the view_2 recording; `demo_rgb`
# is whatever an earlier cell left in that variable.
live_rgb = loaders2[live_index].get_image(0)
live_seg = flow_module.warp_mask(demo_seg, demo_rgb, live_rgb)

In [None]:
# Same clean-up as for the same-perspective case: erode, then dilate.
eroded_seg = erosion(live_seg)
live_seg_2 = dilation(eroded_seg)

In [None]:
fig, ax = plt.subplots(1,2, figsize=(12, 6))
for panel in ax:
    panel.set_axis_off()

# Outline the masks on copies so the source images stay untouched.
demo_rgb_plot = np.copy(demo_rgb)
plot_mask_edge(demo_seg[:, :, 0], demo_rgb_plot)
live_rgb_plot = np.copy(live_rgb)
plot_mask_edge(live_seg_2[:, :, 0], live_rgb_plot)

for panel, canvas in zip(ax, (demo_rgb_plot, live_rgb_plot)):
    panel.imshow(canvas)
plt.show()

## 3.3 Clean-Up Segmentation (incomplete)

1. Use Pointcloud distances (nearest neighbors to the warped mask)
2. Use Segment-Anything to clean up segmentation.

In [None]:
import open3d as o3d

def get_depth(demo_dir, run, frame_index) -> np.ndarray:
    """Return the depth map of one frame of a recording.

    Reads the "depths" array from `<demo_dir>/<run>/images.np.npz` and
    returns the entry at `frame_index`.

    Args:
        demo_dir: pathlib.Path of the directory containing the run folders.
        run: name (index) of the run folder.
        frame_index: index of the frame inside the "depths" array.

    Returns:
        depth: numpy array for that frame. The original note gives the shape
            as (w,h) and the range as (0, ~12m); neither is checked here.
    """
    depth_path = demo_dir / "{0}/images.np.npz".format(run)
    # Close the archive once the array has been read; a bare np.load keeps
    # the file handle open.
    with np.load(depth_path) as archive:
        depths = archive["depths"]
    return depths[frame_index]

# Colour and depth must come from the same run and frame. `live_rgb` is
# frame 0 of run `live_index` of view_2, so load the matching depth map (this
# previously read run 0, whose depth is noted as poor in section 3.2).
live_frame = 0
rgb = o3d.geometry.Image(live_rgb)
live_depth = get_depth(root2_dir, live_index, live_frame)
depth = o3d.geometry.Image(live_depth)
rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(rgb, depth,
                                      depth_scale=1.0, depth_trunc=1.0,
                                      convert_rgb_to_intensity=False)

# Image size taken from the depth map (rows = height, columns = width).
height, width = np.asarray(depth).shape[:2]
# Pinhole intrinsics, hard-coded.
# NOTE(review): presumably from the recording camera's calibration - confirm.
fx=700.819
fy=700.819
cx=665.465
cy=371.953
K_o3d = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy)
pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, K_o3d)

In [None]:
fig, ax = plt.subplots(1,2, figsize=(12, 6))

# Colour image and depth map of the same view_2 run and frame, side by side.
demo_index, demo_frame = 1, 0
live_rgb = loaders2[demo_index].get_image(demo_frame)
live_depth = get_depth(root2_dir, demo_index, demo_frame)
for panel, picture in zip(ax, (live_rgb, live_depth)):
    panel.set_axis_off()
    panel.imshow(picture)
plt.show()

In [None]:
# Display the point cloud built from the view_2 RGB-D frame.
# NOTE(review): draw_geometries presumably opens a separate viewer window
# and blocks until it is closed - confirm before running headless.
o3d.visualization.draw_geometries([pcd])

# 4. Generate Grasps

Send the segmented pointcloud from the live view to a grasp generation system.

## 4.1 Filter Grasps

Find a grasp close to the hand pose, in case we generate too many candidates.