This notebook shows, through simple examples, how to use the built-in library functions this project relies on.

### Installation instructions
- Use Python 3.5+
- Install `evo`, `autograd`, `jax` (used below via `jax.numpy` and `jacfwd`) and the related dependencies

## 0. Basics

In [2]:
import jax.numpy as jnp
from jax import jacfwd

# Use `jnp` wherever you would use `np` (NumPy). jax.numpy mirrors most of the
# NumPy API, so the functions needed for this project work unchanged.
#
# Be careful, though: JAX transformations such as `jacfwd` expect
# "functionally pure" Python functions, i.e. the output depends only on the
# inputs and the function has no side effects. JAX arrays are also immutable.
# For this project, that mostly means storing vectors and matrices as arrays
# (`jnp.array(...)`) rather than in other containers such as Python lists.
# Whenever you hit a datatype error in JAX, first try converting the offending
# value with `jnp.array()`.
#
# In my experience, jnp's error messages are harder to read than NumPy's.
# So write and debug the code with `np` first; once everything works and the
# only remaining step is checking the Jacobian, replace every `np` with `jnp`.
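The datatype and purity notes above can be illustrated with a small sketch. It assumes a JAX version recent enough to provide the `.at[...]` update syntax; `rows`, `M`, `M_new` and `inplace_worked` are illustrative names, not part of the project code.

```python
import jax.numpy as jnp

# Data held in a Python list of lists: convert it to a JAX array first.
rows = [[0.52, 1.12], [0.88, -1.08]]
M = jnp.array(rows)

# JAX arrays are immutable: NumPy-style in-place assignment raises a TypeError.
try:
    M[0, 0] = 10.0
    inplace_worked = True
except TypeError:
    inplace_worked = False

# The functional update returns a new array and leaves M untouched.
M_new = M.at[0, 0].set(10.0)

print(inplace_worked)  # False
print(M[0, 0], M_new[0, 0])
```

Code ported from NumPy that assigns into arrays in place is therefore a common source of errors after the `np` to `jnp` switch.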

## 1. Computing a Jacobian with the `jax` library

In [3]:
# Define a simple function: the logistic sigmoid, written via tanh.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# simpleJ returns a *vector*: sigmoid(inputs @ a + b) has one entry per row of
# `inputs`. We want its derivative with respect to the *vector* a, at a = a0.
# The derivative of a vector with respect to another vector is a matrix: the Jacobian.
def simpleJ(a, b, inputs):  # inputs: (m, n) matrix; a: (n,) vector; b: (m,) vector
    return sigmoid(jnp.dot(inputs, a) + b)

inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])

b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1, 0.7, 0.7])

# Separate the variable to differentiate (a) from the constant parameters (b, inputs).
f = lambda a: simpleJ(a, b, inputs)  # f is now a function of a alone

J = jacfwd(f)
# jacfwd(f) only builds a new function that computes the Jacobian;
# it still has to be evaluated at a0. The result has shape (4, 3).
J(a0)



DeviceArray([[ 0.07388726,  0.1591418 ,  0.10940998],
             [ 0.20861849, -0.2560318 ,  0.03555997],
             [ 0.12171669,  0.01404423, -0.3042917 ],
             [ 0.17407253, -0.58573055,  0.3269741 ]], dtype=float32)
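One way to check the Jacobian above, sketched under the setup of this section: since the derivative of the sigmoid is s(1 - s), the analytic Jacobian is J_ij = s_i (1 - s_i) * inputs_ij with s = sigmoid(inputs @ a0 + b). The sketch compares `jacfwd` against `jax.jacrev`, this analytic formula, and central finite differences computed in float64 NumPy (in the spirit of debugging with `np` first). `numerical_jacobian` and `f_np` are hypothetical helpers introduced only for this check.

```python
import numpy as np
import jax.numpy as jnp
from jax import jacfwd, jacrev

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def simpleJ(a, b, inputs):
    return sigmoid(jnp.dot(inputs, a) + b)

inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1, 0.7, 0.7])
f = lambda a: simpleJ(a, b, inputs)

# Forward-mode and reverse-mode Jacobians from JAX, both of shape (4, 3).
J_fwd = jacfwd(f)(a0)
J_rev = jacrev(f)(a0)

# Analytic Jacobian: row i of `inputs` scaled by s_i * (1 - s_i).
s = f(a0)
J_analytic = (s * (1 - s))[:, None] * inputs

# Hypothetical helper: central finite differences, one column per input coordinate.
def numerical_jacobian(g, x, eps=1e-6):
    x = np.asarray(x, dtype=np.float64)
    cols = []
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        cols.append((g(x + step) - g(x - step)) / (2 * eps))
    return np.stack(cols, axis=1)

# Hypothetical float64 NumPy version of f, used only for the finite-difference check.
def f_np(a):
    X = np.asarray(inputs, dtype=np.float64)
    z = X @ a + np.asarray(b, dtype=np.float64)
    return 0.5 * (np.tanh(z / 2) + 1)

J_fd = numerical_jacobian(f_np, a0)

print(np.allclose(J_fwd, J_rev, atol=1e-5))
print(np.allclose(J_fwd, J_analytic, atol=1e-5))
print(np.allclose(J_fwd, J_fd, atol=1e-5))
```

JAX's documentation suggests `jacfwd` for "tall" Jacobians (more outputs than inputs, as here) and `jacrev` for "wide" ones; at this size either choice is fine.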

## 2. Using EVO

## 3. Example: running the g2o binary

After you build ... add it to your `.bashrc`.

## 4. Misc Helper Functions

In [None]:
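import math
import matplotlib.pyplot as plt

def draw(X, Y, THETA):
    """Plot 2D poses: positions (X, Y) as red dots joined by a cyan line,
    and each heading THETA[i] (radians) as a short green segment."""
    ax = plt.subplot(111)
    ax.plot(X, Y, 'ro')  # pose positions
    ax.plot(X, Y, 'c-')  # trajectory through the poses

    for i in range(len(THETA)):
        # End point of a 0.25-unit segment pointing along the heading.
        x2 = 0.25 * math.cos(THETA[i]) + X[i]
        y2 = 0.25 * math.sin(THETA[i]) + Y[i]
        ax.plot([X[i], x2], [Y[i], y2], 'g->')

    plt.show()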