In [51]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import jax.numpy as jnp

In [8]:
# Record the NumPy version this notebook was run with; displayed outputs may differ on other versions.
np.__version__

'1.26.4'

# Introduction to Numpy

NumPy is the fundamental package for scientific computing in Python. It is a Python library that provides a multidimensional array object, various derived objects (such as masked arrays and matrices), and an assortment of routines for fast operations on arrays, including mathematical, logical, shape manipulation, sorting, selecting, I/O, discrete Fourier transforms, basic linear algebra, basic statistical operations, random simulation and much more.

## Ndarrays: the bread and butter of Numpy

A NumPy `ndarray` is an *N-dimensional array*. Like arrays in lower-level languages, it stores elements of a single data type in a fixed-size block of memory, which makes retrieving elements and operating on the whole array fast. A 1-D array can be viewed as a single vector, a 2-D array as a table, and a 3-D array as a stack of tables.

In [13]:
# CREATING 1-D ARRAYS
PyL = [1, 2, 30, 400, -5]                    # an ordinary Python list
NpA1 = np.array(PyL)                         # copy the list's elements into a new ndarray
NpA2 = np.arange(start=0, stop=10, step=2)   # arithmetic progression 0, 2, 4, 6, 8 (stop is excluded)
NpA3 = np.linspace(start=0, stop=10, num=5)  # 5 evenly spaced floats from 0 to 10, both ends included
NpAz = np.zeros(shape=10)                    # ten 0.0 values (float64 by default)
NpAo = np.ones(shape=10)                     # ten 1.0 values (float64 by default)
NpAe = np.empty(shape=10)                    # ten UNINITIALISED slots: contents are whatever was in memory

In [20]:
# np.empty does not initialise its buffer, so these values are arbitrary leftovers;
# the 1.0s displayed here are not guaranteed and may differ between runs.
NpAe

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [21]:
# CREATING N-D ARRAYS
# Note: this cell rebinds NpA1..NpA3, replacing the 1-D arrays of the same names.
Py2dL = [[1, 2, 3], [4, 5, 6]]               # nested list -> 2 rows x 3 columns
NpA1 = np.array(Py2dL)
NpA2 = np.arange(0, 12, 2).reshape((2, 3))   # 0, 2, ..., 10 laid out row by row
NpA3 = np.zeros_like(NpA2)                   # same shape and dtype as NpA2, filled with 0
NpA3

array([[0, 0, 0],
       [0, 0, 0]])

In [23]:
# Display the 2x3 array built by reshaping np.arange(0, 12, 2).
NpA2

array([[ 0,  2,  4],
       [ 6,  8, 10]])

# Basic Operations
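Arithmetic on ndarrays is element-wise. This is where they differ from regular Python lists, where `3 * [1, 2]` repeats the list (`[1, 2, 1, 2, 1, 2]`) and `+` concatenates two lists.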

In [29]:
# Scalar multiplication and array addition act on every element at once.
a = np.arange(0, 12, 2)   # 0, 2, 4, 6, 8, 10
b = np.arange(24, 6, -3)  # 24, 21, 18, 15, 12, 9
print(a * 3)
print(b * 2)
print(a * 3 + b * 2)      # every entry comes out as 48

[ 0  6 12 18 24 30]
[48 42 36 30 24 18]
[48 48 48 48 48 48]


In [31]:
# Comparisons are element-wise too; a holds only even numbers, so no entry equals 3.
a != 3

array([ True,  True,  True,  True,  True,  True])

In [35]:
# A is (4, 3) and B is (3, 4), so B.T is (4, 3) and can be added to A element-wise.
A = np.arange(12).reshape(4, 3)
B = np.arange(45, 33, -1).reshape(3, 4)
print(A)
A += B.T          # updates A in place; B itself is unchanged
print(B.T)
print(A)
print(A.shape)

[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
[[45 41 37]
 [44 40 36]
 [43 39 35]
 [42 38 34]]
[[45 42 39]
 [47 44 41]
 [49 46 43]
 [51 48 45]]
(4, 3)


In [36]:
# Largest element over the whole array (no axis given).
np.max(A)

51

In [38]:
# Running total along each row (axis=1 walks across the columns).
np.cumsum(A, axis=1)

array([[ 45,  87, 126],
       [ 47,  91, 132],
       [ 49,  95, 138],
       [ 51,  99, 144]])

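# Universal Functions (ufuncs)

Universal functions such as `np.sin` apply an operation element by element to whole arrays.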

In [39]:
# np.sin is a ufunc: it is applied to each element and returns a float array shaped like A.
np.sin(A)

array([[ 0.85090352, -0.91652155,  0.96379539],
       [ 0.12357312,  0.01770193, -0.15862267],
       [-0.95375265,  0.90178835, -0.83177474],
       [ 0.67022918, -0.76825466,  0.85090352]])

In [40]:
# Keep the entries of A that are below A's mean and put 0 everywhere else.
np.where(A < A.mean(), A, 0)

array([[ 0, 42, 39],
       [ 0, 44, 41],
       [ 0,  0, 43],
       [ 0,  0,  0]])

In [41]:
# Boolean mask with A's (4, 3) shape: True in columns 0 and 2, False in column 1.
Bool_A = np.tile([True, False, True], (4, 1))

In [42]:
# Where the mask is True take A's value, otherwise 0.
np.where(Bool_A, A, 0)

array([[45,  0, 39],
       [47,  0, 41],
       [49,  0, 43],
       [51,  0, 45]])

In [43]:
# Slice assignment writes into A itself: rows 1-2 and columns 1-2 are set to 0.
# This mutates A in place, so earlier cells that displayed A now show stale values.
A[1:3, 1:3] = 0

In [44]:
# Show A after the in-place slice assignment zeroed its central 2x2 block.
A

array([[45, 42, 39],
       [47,  0,  0],
       [49,  0,  0],
       [51, 48, 45]])

In [45]:
# A.flat iterates over A as if it were 1-D, in row-major (C) order.
# Unpacking into print (one value per line) avoids leaving a loop variable named
# `x` in the notebook namespace, where `x` is later reused for unrelated arrays.
print(*A.flat, sep="\n")

45
42
39
47
0
0
49
0
0
51
48
45


In [48]:
# False: the slice assignment introduced zeros, and 0 > 0 does not hold.
(A > 0).all()

False

# JAX: NumPy with JIT Compilation, Autodiff, and GPU/TPU Support

In [52]:
# JAX provides a NumPy-like API (jax.numpy) that can run on CPU, GPU or TPU
from jax import random as jrand
from tqdm import tqdm
import numpy as np

# Draw one million standard-normal samples with each library.
# JAX uses explicit PRNG keys. NumPy's Generator is seeded with the same number
# only so reruns are reproducible; the two sample streams still differ.
jkey = jrand.PRNGKey(1701)
jx = jrand.normal(jkey, (1_000_000,))
# `size` must be given by keyword: the first positional argument of `normal`
# is `loc` (the mean). Passing (1_000_000,) there would yield a single sample near 1e6.
nx = np.random.default_rng(1701).normal(size=1_000_000)

def JP(x, coeffs):
    """Evaluate sum_{k=1..n} coeffs[k-1] * x**k with jax.numpy (no constant term).

    The per-power terms are stacked into one array and reduced with a single
    ``jnp.sum``. For an array ``x`` the result is therefore a scalar that also
    sums over the elements of ``x``. An empty ``coeffs`` gives 0.0.
    """
    terms = [coef * x**power for power, coef in enumerate(coeffs, start=1)]
    return jnp.sum(jnp.array(terms))

def NP(x, coeffs):
    """NumPy counterpart of ``JP``: sum of coeffs[k-1] * x**k for k = 1..n.

    ``np.sum`` reduces all terms to a single value. For an array ``x`` this
    includes all of the array's elements. An empty ``coeffs`` gives 0.0.
    """
    terms = []
    for power, coef in enumerate(coeffs, start=1):
        terms.append(coef * x**power)
    return np.sum(terms)


In [53]:
from jax import jit

def f(x):
    """Cubic -4x^3 + 9x^2 + 6x - 3, evaluated element-wise (works for NumPy and JAX inputs)."""
    cubic = -4 * x**3
    quadratic = 9 * x**2
    linear = 6 * x
    # Same left-to-right summation order as the one-line form.
    return cubic + quadratic + linear - 3

# Benchmark input: an N x N matrix of standard-normal samples.
N = 20_000
# Generate float32 directly so NumPy and JAX do comparable work. With JAX's
# default settings (64-bit disabled) jnp.array would turn a float64 input into
# float32, which would make this a float64-vs-float32 comparison.
x = np.random.default_rng(0).standard_normal((N, N), dtype=np.float32)

func = jit(f)

y = jnp.array(x)
# Warm-up call: triggers tracing/compilation so %timeit measures execution only.
# JAX dispatches asynchronously, so wait for the result before timing starts.
_ = func(y).block_until_ready()

%timeit f(x)
%timeit func(y)

31.1 s ± 121 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
6.32 ms ± 856 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Gradients using JAX

In [54]:
def sigmoid(x):
    """Sum of the logistic sigmoid 1 / (1 + exp(-x)) over every element of ``x``.

    Reducing to a scalar is what allows ``jax.grad`` to differentiate it. Each
    term depends on one element only, so the gradient is the element-wise
    sigmoid derivative.
    """
    activations = 1 / (1 + jnp.exp(-x))
    return activations.sum()

from jax import grad

# grad needs a scalar-valued function, which is why `sigmoid` sums its outputs.
derivative_fn = grad(sigmoid)
x = jnp.arange(3, dtype=jnp.float32)  # [0., 1., 2.]; grad requires floating-point inputs

In [55]:
# Entry i is sigma(x_i) * (1 - sigma(x_i)), e.g. 0.25 at x_i = 0.
derivative_fn(x)

Array([0.25      , 0.19661194, 0.10499357], dtype=float32)

In [58]:
# Forward-difference check of d(sigmoid)/dx_0 with step h.
h = 1e-2  # step size 0.01
# JAX arrays are immutable: .at[...].set() returns a modified copy and leaves x
# untouched, so there is no need to copy x first.
x_plus_h = x.at[0].set(x[0] + h)
x_plus_h

Array([0.01, 1.  , 2.  ], dtype=float32)

In [60]:
# Forward difference (f(x + h*e_0) - f(x)) / h; compare with derivative_fn(x)[0] == 0.25.
(sigmoid(x_plus_h) - sigmoid(x)) / h

Array(0.24998188, dtype=float32)

In [61]:
# Identity matrix with one row/column per element of x.
jnp.identity(len(x))

Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

In [62]:
def Sigmoid(x):
    """Element-wise logistic sigmoid 1 / (1 + exp(-x)); unlike `sigmoid`, no sum is taken."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

from jax import jacobian

# Full Jacobian of the element-wise Sigmoid at x. Output i depends only on input i,
# so the matrix is diagonal, and its diagonal matches derivative_fn(x).
Sigmoid_Jacobian = jacobian(Sigmoid)
Sigmoid_Jacobian(x)

Array([[0.25      , 0.        , 0.        ],
       [0.        , 0.19661194, 0.        ],
       [0.        , 0.        , 0.10499357]], dtype=float32)