In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
tf.enable_eager_execution()

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
import skimage

from IPython import display
%matplotlib inline

# Short aliases for the Keras API and its layers namespace.
tfk = tf.keras
tfkl = tfk.layers

# Load Data

In [None]:
from softlearning.models.state_estimation import (
    get_dumped_pkl_data
)
# Alternative dataset locations, kept for reference:
# images_path = '/home/justinvyu/dev/softlearning-vice/goal_classifier/free_screw_state_estimator_data_invisible_claw/more_data.pkl'
# images_path = '/root/nfs/kun1/users/justinvyu/data/fixed_data.pkl'
# images_path = '/root/nfs/kun1/users/justinvyu/data/invisible_claw_antialiased_data.pkl'
# NOTE(review): this is the same file that a later cell in this notebook
# writes (gzip + pickle of a {'pixels', 'states'} dict), so that file must
# already exist for this cell to succeed on a fresh setup.
images_path = '/root/nfs/kun1/users/justinvyu/data/fixed_data_with_states.pkl'

# The helper's return value is unpacked into two arrays; presumably the
# image pixels and their matching state labels — verify against the helper.
images, states = get_dumped_pkl_data(images_path)

In [None]:
# Display both array shapes as a quick sanity check on the loaded data.
images.shape, states.shape

In [None]:
# Imported in this cell because it is the first place gzip/pickle are used;
# without these a fresh-kernel "Run All" fails here with a NameError.
import gzip
import pickle

# Read the gzip-compressed pickle holding the fixed dataset.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with gzip.open('/root/nfs/kun1/users/justinvyu/data/fixed_data.pkl', 'rb') as f:
    fixed_data = pickle.load(f)

In [None]:
# Bundle the pixel data with the state labels and persist the pair as a
# gzip-compressed pickle.
fixed_data_with_states = dict(pixels=fixed_data, states=states)

with gzip.open(
        '/root/nfs/kun1/users/justinvyu/data/fixed_data_with_states.pkl',
        'wb') as f:
    pickle.dump(fixed_data_with_states, f)

In [None]:
# Pull the two arrays back out of the bundled dict.
images = fixed_data_with_states['pixels']
states = fixed_data_with_states['states']

# Train state estimator on top of the VAE outputs

In [None]:
from softlearning.preprocessors.utils import get_vae_preprocessor
from softlearning.models.state_estimation import state_estimator_model

# Earlier weight checkpoints, kept for reference:
# encoder_weights_fn = '/root/softlearning/softlearning/models/vae_weights/invisible_claw_encoder_weights_4_final.h5'
# decoder_weights_fn = '/root/softlearning/softlearning/models/vae_weights/invisible_claw_decoder_weights_4_final.h5'

encoder_weights_fn = '/root/softlearning/softlearning/models/vae_weights/invisible_claw_encoder_weights.h5'
decoder_weights_fn = '/root/softlearning/softlearning/models/vae_weights/invisible_claw_decoder_weights.h5'

# Configuration describing the VAE preprocessor; the nested `kwargs` are
# what gets forwarded to the preprocessor factory.
vae_preprocessor_params = dict(
    type='VAEPreprocessor',
    kwargs=dict(
        encoder_path=encoder_weights_fn,
        decoder_path=decoder_weights_fn,
        trainable=True,
        image_shape=(32, 32, 3),
        latent_dim=16,
        include_decoder=False,
    ),
)

# The estimator is currently built from the input shape alone; the variant
# that also passes the VAE preprocessor config is left disabled below.
# state_estimator = state_estimator_model(
#     input_shape=(32, 32, 3),
#     preprocessor_params=vae_preprocessor_params)
state_estimator = state_estimator_model(input_shape=(32, 32, 3))

In [None]:
# Print the layer-by-layer summary of the state estimator.
state_estimator.summary()

In [None]:
from softlearning.preprocessors.utils import get_vae_preprocessor
# Build a standalone VAE preprocessor from the `kwargs` section of the
# preprocessor config defined earlier.
vae = get_vae_preprocessor(**vae_preprocessor_params['kwargs'])

In [None]:
# Print summaries of the encoder and decoder sub-models.
# NOTE(review): the preprocessor config sets 'include_decoder' to False —
# confirm that `vae.decoder` exists in that configuration.
vae.encoder.summary(), vae.decoder.summary()

In [None]:
# Break the images into 50 batches along the first axis so the VAE can be
# run batch by batch. `np.array_split` is used instead of `np.split` because
# the latter raises a ValueError whenever the number of images is not an
# exact multiple of 50; for exact multiples both produce the same batches.
split_data = np.array_split(images, 50)

In [None]:
# Run the VAE preprocessor over every batch, keeping the outputs in order.
reconstructions = [vae(batch) for batch in split_data]

In [None]:
# Stitch the per-batch outputs back together along the first (sample) axis.
reconstruct = np.concatenate(reconstructions, axis=0)

In [None]:
import skimage
# Convert the VAE outputs to unsigned 8-bit values for storage.
# NOTE(review): presumably `reconstruct` holds floats in the range that
# img_as_ubyte accepts — confirm against the VAE's output activation.
reconstruct_int = skimage.util.img_as_ubyte(reconstruct)

In [None]:
import gzip
import pickle

# Pair the 8-bit reconstructions with the state labels and write them out
# as a gzip-compressed pickle.
dump = dict(pixels=reconstruct_int, states=states)

with gzip.open(
        '/root/nfs/kun1/users/justinvyu/data/reconstructions_with_state',
        'wb') as f:
    pickle.dump(dump, f)

In [None]:
# Compile: Adam optimizer with a mean-squared-error regression loss.
state_estimator.compile(optimizer='adam', loss='mse')

In [None]:
# Number of full passes over the training data.
N_EPOCHS = 15

# Train the estimator to regress `states` from `images`.
# Per the Keras docs, `validation_split` holds out the LAST 10% of the
# provided samples (taken before shuffling) as validation data.
history = state_estimator.fit(
    x=images,
    y=states,
    batch_size=256,
    epochs=N_EPOCHS,
    validation_split=0.1
)

# Show estimation errors

In [None]:
# Get samples to calculate metrics on.
# Sampling is done WITHOUT replacement so no image is counted more than once
# in the error metrics or shown twice among the worst cases, and the sample
# size is capped at the dataset size so that small datasets do not raise.
# NOTE(review): indices are drawn from the same arrays that were passed to
# `fit`, so this sample is not held out from training — confirm that is
# the intended evaluation.
n_eval_samples = min(10000, images.shape[0])
random_indices = np.random.choice(
    images.shape[0], size=n_eval_samples, replace=False)
test_images = images[random_indices]
test_labels = states[random_indices]
preds = state_estimator.predict(test_images)

In [None]:
# Per-sample error accumulators, filled in by the evaluation loop below.
pos_errors, angle_errors = [], []


def degrees(x):
    """Convert an angle from radians to degrees."""
    return x * 180 / np.pi
def angle_distance(deg1, deg2):
    """Return the smallest absolute difference between two angles, in degrees.

    The result accounts for wrap-around, so it always lies in [0, 180]
    (e.g. 10 and 350 are 20 degrees apart, not 340).

    Args:
        deg1: Angle(s) in degrees; a scalar or an array-like.
        deg2: Angle(s) in degrees; a scalar or an array-like that
            broadcasts against `deg1`.

    Returns:
        A NumPy scalar for scalar inputs, or an array of element-wise
        distances for array inputs.
    """
    phi = np.abs(np.asarray(deg1) - np.asarray(deg2)) % 360
    # Going "the other way round" the circle is 360 - phi; take whichever
    # direction is shorter. Unlike an `if phi > 180` branch, this also works
    # element-wise on arrays.
    return np.minimum(phi, 360 - phi)

for label, pred in zip(test_labels, preds):
    # Position error: Euclidean distance between the first two components.
    # free box is 30 cm, 15 on each side (-1 -> 1 --> -15 -> 15)
    xy_offset = np.abs(label[:2] - pred[:2])
    pos_errors.append(15 * np.linalg.norm(xy_offset))

    # Angle error: components 2 and 3 are passed to arctan2 as (x, y), the
    # resulting angles are converted to degrees and compared on the circle.
    true_angle = degrees(np.arctan2(label[3], label[2]))
    pred_angle = degrees(np.arctan2(pred[3], pred[2]))
    angle_errors.append(angle_distance(true_angle, pred_angle))

mean_pos_error = np.mean(pos_errors)
mean_angle_error = np.mean(angle_errors)
print('MEAN POS ERROR (CM):', mean_pos_error)
print('MEAN ANGLE ERROR (degrees):', mean_angle_error)

In [None]:
def display_top_errors(errors, label_str="", top_k=20):
    """Show the evaluation samples with the largest errors, worst first.

    Reads the notebook globals `test_images`, `test_labels` and `preds`;
    `errors[i]` must be the error of sample `i` in those arrays.

    Args:
        errors: Sequence of per-sample error values.
        label_str: Text printed in front of each error value.
        top_k: Maximum number of samples to show (default 20). Clamped to
            the number of errors supplied, so short inputs no longer make
            `np.argpartition` raise.
    """
    errors = np.asarray(errors)
    top_k = min(top_k, errors.size)
    if top_k <= 0:
        return

    # Select the top_k largest errors without a full sort, then order just
    # those and reverse so the worst sample comes first.
    ind = np.argpartition(errors, -top_k)[-top_k:]
    ind = ind[np.argsort(errors[ind])][::-1]
    print(ind)

    for i, sample_idx in enumerate(ind):
        print('\n========== IMAGE #', i, '=========')
        plt.axis('off')
        plt.imshow(test_images[sample_idx])
        print('{} ERROR: {}\n\ntrue: {}\npred: {}'.format(
            label_str,
            errors[sample_idx],
            test_labels[sample_idx],
            preds[sample_idx]))
        plt.show()

In [None]:
# Show the worst-case samples for position error, then for angle error.
display_top_errors(pos_errors, label_str="POS (cm)")
display_top_errors(angle_errors, label_str="ANGLE (degrees)")

In [None]:
def plot_histograms(pos_errors, angle_errors):
    """Draw side-by-side 30-bin histograms of position and angle errors."""
    panels = (
        ('Position Errors (cm)', pos_errors),
        ('Angle Errors (deg)', angle_errors),
    )
    plt.figure(figsize=(10, 5))
    for panel_number, (panel_title, values) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_number)
        plt.title(panel_title)
        plt.hist(values, bins=30)
    plt.show()

plot_histograms(pos_errors, angle_errors)

In [None]:
def plot_pos_support():
    """Scatter the first two label components of the evaluation sample.

    Reads the notebook global `test_labels`.
    """
    xs = test_labels[:, 0]
    ys = test_labels[:, 1]
    plt.figure(figsize=(5, 5))
    plt.scatter(xs, ys, alpha=0.1, s=5)
    plt.show()

plot_pos_support()

def plot_angle_support():
    """Histogram (50 bins) of the label angles, in radians.

    Angles are recovered with arctan2 from label components 3 and 2 of the
    notebook global `test_labels`.
    """
    y_component = test_labels[:, 3]
    x_component = test_labels[:, 2]
    angles = np.arctan2(y_component, x_component)
    plt.figure(figsize=(5,5))
    plt.hist(angles, bins=50)
    plt.show()

plot_angle_support()

In [None]:
def get_noise(size, loc=0, scale=0.02):
    """Draw Gaussian noise of shape `size` from NumPy's global RNG.

    Args:
        size: Output shape (int or tuple of ints).
        loc: Mean of the distribution.
        scale: Standard deviation of the distribution.

    Returns:
        An array of samples with the requested shape.
    """
    noise = np.random.normal(loc, scale, size)
    return noise

# Label and predicted positions, plus the per-sample displacement between them.
labels_x = test_labels[:, 0]
labels_y = test_labels[:, 1]
preds_x = preds[:, 0]
preds_y = preds[:, 1]
dxs = preds_x - labels_x
dys = preds_y - labels_y

# Figure 1: labels and predictions in the plane, with an arrow from each
# label to its prediction.
plt.figure(figsize=(20, 20))
plt.title('State estimator errors (xy)')
plt.scatter(labels_x, labels_y, c='blue', s=2, label='labels (blue)')
plt.scatter(preds_x, preds_y, c='green', s=2, label='preds (green)')
plt.legend()
plt.quiver(
    labels_x, labels_y, dxs, dys,
    angles='xy', scale_units='xy', scale=1, width=0.001, alpha=0.6)

# Figure 2: predicted vs. label value for each coordinate; a perfect
# estimator would put every point on the diagonal.
plt.figure(figsize=(10, 10))
plt.title('Position errors')
plt.xlabel('label')
plt.ylabel('predicted')
plt.scatter(labels_x, preds_x, label='x', s=0.5, alpha=0.2)
plt.scatter(labels_y, preds_y, label='y', s=0.5, alpha=0.2)
plt.legend()

In [None]:
# Components 2 and 3 of each label / prediction are treated as the cosine
# and sine of the angle.
labels_z_cos = test_labels[:, 2]
labels_z_sin = test_labels[:, 3]
preds_z_cos = preds[:, 2]
preds_z_sin = preds[:, 3]
dzs_cos = preds_z_cos - labels_z_cos
dzs_sin = preds_z_sin - labels_z_sin

# Recover the angles (radians) from their sine/cosine components.
labels_angle = np.arctan2(labels_z_sin, labels_z_cos)
preds_angle = np.arctan2(preds_z_sin, preds_z_cos)

# Predicted vs. label angle; a perfect estimator would lie on the diagonal.
plt.figure(figsize=(10, 10))
plt.title('Angle errors')
plt.xlabel('label angle (radians)')
plt.ylabel('predicted angle (radians)')
plt.scatter(labels_angle, preds_angle, s=0.4, alpha=0.25)

In [None]:
# Fetch the VAE preprocessor layer from the estimator and run it on a single
# image (the `[None]` index adds a leading batch axis).
# NOTE(review): `state_estimator` is constructed earlier without
# `preprocessor_params` (that call is commented out) — confirm the model
# really contains a layer named 'vae_preprocessor'. This also rebinds `vae`,
# replacing the standalone preprocessor created earlier.
vae = state_estimator.get_layer('vae_preprocessor')
vae(test_images[0][None])