# Interactive plot to study the influence of individual principal components on leaf shape

In [37]:
import visualise
import leaf_encoding
%matplotlib notebook
from ipywidgets import *
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [50]:
# Load the datasets, their labels and the PCA object used by every cell below.
(train_ds, test_ds,
 train_labels, test_labels,
 pca, transformed) = leaf_encoding.get_encoding(
    train_split=0,
    location=True,
    rotation=False,
    scale=False,
    as_features=True,
)

# `ranges` is indexed further down as ranges[0][i] / ranges[1][i] to obtain the
# slider minimum / maximum for principal component i.
ranges = leaf_encoding.new_random_leaf_from_distribution(
    train_ds, train_labels, pca, draw=False
)

In [51]:
fig, axs = plt.subplots(1, 2)

# Encode the first training sample; its weights are the baseline that the
# sliders perturb.
starting_weights = leaf_encoding.compress(train_ds[0], pca.components_[:50], pca)
outlines = leaf_encoding.reshape_coordinates_and_additional_features(train_ds)[0]

# Left panel: columns 0 and 1 of the first outline. Right panel: columns 0 and 2.
# (The 3D cell of this notebook passes columns 0/1/2 as xs/ys/zs.)
outline1 = axs[0].scatter(outlines[0][:, 0], outlines[0][:, 1], s=1)
outline2 = axs[1].scatter(outlines[0][:, 0], outlines[0][:, 2], s=1)
axs[0].set_title("x-y projection")
axs[1].set_title("x-z projection")
# Explicit equivalent of the former `plt.ylim([-1, 1])`, which acted on the
# current axes, i.e. the right-hand panel.
axs[1].set_ylim(-1, 1)


def update(comp1, comp2, comp3):
    """Redraw both 2D panels after overriding the first three PCA weights.

    Parameters
    ----------
    comp1, comp2, comp3 : float
        Replacement values for weights 0, 1 and 2 of the starting sample.
        All remaining weights keep their value from ``starting_weights``.
    """
    # `.copy()` rather than `[:]`: if `compress` returns a NumPy array, `[:]`
    # is only a view and the assignments below would overwrite the baseline.
    updated_weights = starting_weights.copy()
    updated_weights[:3] = (comp1, comp2, comp3)

    reprojection = leaf_encoding.decompress(updated_weights, pca.components_[:50], pca)
    unrolled = leaf_encoding.reshape_coordinates_and_additional_features(reprojection)[0]

    outline1.set_offsets(unrolled[:, :2])
    # Colour the left panel by the third column.
    outline1.set_array(unrolled[:, 2])
    # Keep the right panel on the same (column 0, column 2) projection it was
    # created with; the previous `unrolled[:, 1:]` switched it to columns (1, 2).
    outline2.set_offsets(unrolled[:, [0, 2]])
    # Redraw through the artist's own figure: the global name `fig` is rebound
    # by the 3D cell further down, which would redirect this redraw.
    outline1.figure.canvas.draw_idle()


interact(
    update,
    comp1=widgets.FloatSlider(value=starting_weights[0], min=ranges[0][0], max=ranges[1][0]),
    comp2=widgets.FloatSlider(value=starting_weights[1], min=ranges[0][1], max=ranges[1][1]),
    comp3=widgets.FloatSlider(value=starting_weights[2], min=ranges[0][2], max=ranges[1][2]),
);




<IPython.core.display.Javascript object>

interactive(children=(FloatSlider(value=-34.65872566319614, description='comp1', max=140.36952313353908, min=-…

In [53]:
def set_axes_equal(ax):
    '''Give the x, y and z axes of a 3D plot the same scale.

    The current limits of each axis are read, and every axis is then reset to
    a window of identical width (the widest of the three current ranges)
    centred on that axis' current midpoint.

    Input
      ax: a 3D matplotlib axis providing get_/set_{x,y,z}lim3d.
    '''
    # One (low, high) row per axis, in x, y, z order.
    bounds = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])

    centres = bounds.mean(axis=1)
    spans = np.abs(bounds[:, 1] - bounds[:, 0])

    # Half of the largest span: the half-width shared by all three axes.
    half_width = 0.5 * spans.max()

    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    for apply_limits, centre in zip(setters, centres):
        apply_limits([centre - half_width, centre + half_width])

In [58]:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Encode the first training sample; its weights are the baseline that the
# sliders perturb.
starting_weights = leaf_encoding.compress(train_ds[0], pca.components_[:50], pca)
outlines, add_f = leaf_encoding.reshape_coordinates_and_additional_features(train_ds)

# Columns 0, 1 and 2 of the first outline are plotted as x, y and z.
scatterplot = ax.scatter(xs=outlines[0][:, 0], ys=outlines[0][:, 1], zs=outlines[0][:, 2], s=1)
set_axes_equal(ax)


def update(comp1, comp2, comp3, comp4, comp5, comp6, comp7, comp8, comp9, comp10):
    """Redraw the 3D scatter after overriding the first ten PCA weights.

    Parameters
    ----------
    comp1 ... comp10 : float
        Replacement values for weights 0 to 9 of the starting sample. All
        remaining weights keep their value from ``starting_weights``.
    """
    # `.copy()` rather than `[:]`: if `compress` returns a NumPy array, `[:]`
    # is only a view and the assignment below would overwrite the baseline.
    updated_weights = starting_weights.copy()
    updated_weights[:10] = (comp1, comp2, comp3, comp4, comp5,
                            comp6, comp7, comp8, comp9, comp10)

    reprojection = leaf_encoding.decompress(updated_weights, pca.components_[:50], pca)
    unrolled, _ = leaf_encoding.reshape_coordinates_and_additional_features(reprojection)

    # The 3D scatter positions are replaced by assigning the private
    # `_offsets3d` attribute, as the original cell did.
    scatterplot._offsets3d = (unrolled[:, 0], unrolled[:, 1], unrolled[:, 2])
    # Redraw through the artist's own figure so the callback keeps working even
    # if the global name `fig` is rebound by another cell.
    scatterplot.figure.canvas.draw_idle()


# One slider per overridden component: comp1 .. comp10 map to weights 0 .. 9,
# each starting at the baseline weight and bounded by the matching `ranges` entry.
component_sliders = {
    "comp{}".format(index + 1): widgets.FloatSlider(
        value=starting_weights[index],
        min=ranges[0][index],
        max=ranges[1][index],
    )
    for index in range(10)
}

# Trailing semicolon suppresses the `<function update ...>` repr in the cell output.
interact(update, **component_sliders);



<IPython.core.display.Javascript object>

interactive(children=(FloatSlider(value=-34.65872566319614, description='comp1', max=140.36952313353908, min=-…

<function __main__.update(comp1, comp2, comp3, comp4, comp5, comp6, comp7, comp8, comp9, comp10)>