# Redes Neuronales con JAX

* En este notebook vamos a crear una red neuronal multicapa con JAX.

In [None]:
from jax import numpy as jnp
from jax import grad
from jax import jacobian
from jax import random
from jax import vmap
from jax import value_and_grad
from jax import jit
import sys

## Implementando una red neuronal básica

* Vamos a tratar de implementar la red directamente sin ningún tipo de abstracción.
* El objetivo es implementar las funciones necesarias para calcular el gradiente de la pérdida y, con él, actualizar los pesos.

In [None]:
def predict(W,x):
  """Run the two-layer network forward and return the output activations.

  W is a mapping holding the hidden-layer weights under key '1' and the
  output-layer weights under key '2'. Each matrix is transposed before the
  product, so the inputs are multiplied against its rows. There are no bias
  terms, and tanh is applied after both layers.
  """
  activation = x
  # Both layers are the same operation: linear map followed by tanh.
  for layer_key in ('1', '2'):
    weights = W[layer_key]
    activation = jnp.tanh(jnp.dot(activation, jnp.transpose(weights)))
  return activation

def loss(W, x, y):
  """Squared-error loss of the network on (x, y).

  Sums the squared differences between predict(W, x) and y over every entry
  and divides the total by the number of rows of y.
  """
  # Forward pass, then the element-wise error against the targets
  residual = predict(W, x) - y

  # Total squared error divided by the number of rows in y
  return jnp.sum(residual ** 2) / y.shape[0]

In [None]:
# Define the toy dataset and the network parameters
x = jnp.array([[1.0, 3.5],[-0.5, 0.6], [-0.5, 1.3]])
# Targets use -1/+1 (the earlier 0/1 variant was [[1],[0],[0]])
y = jnp.array([[1],[-1],[-1]])
n_neurons = 4  # units in the hidden layer
n_features = x.shape[1]
key = random.key(10)

# Split the key so each weight matrix is drawn from its own subkey
key, key_w1, key_w2 = random.split(key, 3)
W1 = random.normal(key_w1, (n_neurons, n_features))  # hidden layer weights
W2 = random.normal(key_w2, (1, n_neurons))  # output layer weights (one output unit)
W = {'1': W1, '2': W2}

# Loss and predictions of the untrained network, for reference
initial_loss = loss(W, x,y)
print("Loss inicial: ", initial_loss)

yp = predict(W,x)
print("Predicciones iniciales: ", yp)

Loss inicial:  2.505769
Predicciones iniciales:  [[0.9933537 ]
 [0.91059923]
 [0.96643686]]


In [None]:
# Build a function that returns the loss and its gradient with respect to
# the first argument (the weights W) in a single call
val_grad_loss = value_and_grad(loss, argnums=0)

In [None]:
# Training loop: gradient descent on the whole dataset at every step
epochs = 30
learning_rate = 0.1
for i in range(epochs):
  print(f"Epoca {i}")
  # Loss and gradient at the current weights
  loss_val, grad_w = val_grad_loss(W, x, y)
  print("Loss: ", loss_val)

  # Gradient-descent step on each weight matrix
  for layer_key in ('1', '2'):
    W[layer_key] = W[layer_key] - learning_rate * grad_w[layer_key]

Epoca 0
Loss:  2.505769
Epoca 1
Loss:  2.3822467
Epoca 2
Loss:  1.9204383
Epoca 3
Loss:  1.3888327
Epoca 4
Loss:  0.34902894
Epoca 5
Loss:  0.14577721
Epoca 6
Loss:  0.100438915
Epoca 7
Loss:  0.08674523
Epoca 8
Loss:  0.035762027
Epoca 9
Loss:  0.029983804
Epoca 10
Loss:  0.026187697
Epoca 11
Loss:  0.023223206
Epoca 12
Loss:  0.020836145
Epoca 13
Loss:  0.018873692
Epoca 14
Loss:  0.017232819
Epoca 15
Loss:  0.015841307
Epoca 16
Loss:  0.014647047
Epoca 17
Loss:  0.0136114685
Epoca 18
Loss:  0.012705368
Epoca 19
Loss:  0.011906266
Epoca 20
Loss:  0.011196614
Epoca 21
Loss:  0.0105624525
Epoca 22
Loss:  0.009992548
Epoca 23
Loss:  0.009477801
Epoca 24
Loss:  0.009010741
Epoca 25
Loss:  0.0085851485
Epoca 26
Loss:  0.008195867
Epoca 27
Loss:  0.007838518
Epoca 28
Loss:  0.0075094122
Epoca 29
Loss:  0.007205428


In [None]:
# Compute the predictions with the current (trained) weights
yp = predict(W, x)
print("Predicciones: ", yp)

Predicciones:  [[ 0.93151844]
 [-0.9052495 ]
 [-0.9157144 ]]


## Implementar red neuronal tipo Scikit-learn

* Vamos a tratar de abstraer la construcción de una red neuronal para que podamos definir tamaños arbitrarios.
* La red se va a definir de manera similar a como se realiza en Scikit-learn.

In [None]:
# We encapsulate all of the logic in a Python class.
class MLPClassifier():
  """Bias-free multilayer perceptron with tanh activations.

  Every layer, including the output one, computes ``tanh(input @ W.T)``.
  Training is full-batch gradient descent with a fixed learning rate on the
  summed squared error divided by the number of rows of ``y``.

  Parameters
  ----------
  units : sequence of int
    Number of units of each hidden layer. Any sequence (list, tuple, ...) is
    accepted; the output layer is appended automatically.
  epochs : int
    Number of gradient-descent steps performed by ``fit``.
  lr : float
    Learning rate.
  binary : bool
    If True the output layer has one unit, otherwise ``n_classes`` units.
  n_classes : int
    Number of output units; only used when ``binary`` is False.
  seed : int
    Seed of the PRNG key used to draw the initial weights.
  verbose : bool
    If True (default) ``fit`` prints the initial loss and the loss of every
    epoch; if False it prints nothing.
  """

  def __init__(self, units, epochs=10, lr=0.01, binary=True, n_classes=2,seed=0, verbose=True):
    self.key = random.key(seed)
    self.epochs = epochs
    self.lr = lr
    self.verbose = verbose
    # Filled by fit(): maps the layer index to its weight matrix.
    self.W = dict()
    n_outputs = 1 if binary else n_classes
    # list(...) so that tuples and other sequences work, not only lists.
    self.units = list(units) + [n_outputs]

  def forward(self, W, X):
    """Propagate X through the network using the weights in W.

    W maps each layer index (0 .. len(self.units) - 1) to a weight matrix;
    it is an explicit argument so the pass can be differentiated w.r.t. it.
    """
    output = X
    for i in range(len(self.units)):
      output = jnp.tanh(jnp.dot(output, jnp.transpose(W[i])))
    return output

  def predict(self, X):
    """Return the network output for X using the weights stored by fit().

    fit() must be called first: before that self.W is empty and looking up
    the first layer fails.
    """
    return self.forward(self.W, X)

  def loss(self, W, X, y):
    """Summed squared error between forward(W, X) and y, divided by y.shape[0]."""
    residual = self.forward(W, X) - y
    return jnp.sum(residual ** 2)/y.shape[0]

  def fit(self, X, y):
    """Initialise the weights and train them with gradient descent.

    The weights are always drawn afresh from self.key, so calling fit again
    restarts training from the same initial weights. The trained weights are
    stored in self.W. X is expected to be 2-D (its second dimension is used
    as the number of input features).
    """
    # self.units already contains the output layer.
    n_layers = len(self.units)

    # Create the weight matrices: layer i has shape (units of i, inputs of i).
    keys = random.split(self.key, n_layers)
    W = dict()
    n_inputs = X.shape[1]
    for i, n_units in enumerate(self.units):
      W[i] = random.normal(keys[i], (n_units, n_inputs))
      # The next layer consumes the outputs of this one.
      n_inputs = n_units

    # Loss and gradient w.r.t. the weights in one compiled call.
    val_grad_loss = jit(value_and_grad(self.loss, 0))

    # Train the network
    if self.verbose:
      loss_val, _ = val_grad_loss(W, X, y)
      print(f"Initial loss: {loss_val}")
    for ep in range(self.epochs):
      # Loss and gradient at the current weights
      loss_val, grad_w = val_grad_loss(W, X, y)
      if self.verbose:
        print(f"Epoch: {ep} - Loss: {loss_val}")

      # Gradient-descent step on every layer
      W = {i: W[i] - self.lr*grad_w[i] for i in range(n_layers)}

    # Store the trained weights
    self.W = W


In [None]:
# Prepare the input data
X = jnp.array([[1.0, 3.5],[-0.5, 0.6], [-0.5, 1.3]])
# NOTE(review): the saved predictions further down are close to -1 for the last
# two samples, which looks like the model was trained with -1/+1 targets rather
# than these 0/1 ones — confirm which labels are intended and re-run the notebook.
y = jnp.array([[1],[0],[0]])


In [None]:
# Create and train the model: three hidden layers of 16, 16 and 8 units
# (the class appends the output layer itself)
model = MLPClassifier([16,16,8], lr=0.1, epochs=30)
model.fit(X,y)

Initial loss: 2.2084829807281494
Epoch: 0 - Loss: 2.2084829807281494
Epoch: 1 - Loss: 0.40662333369255066
Epoch: 2 - Loss: 0.15469112992286682
Epoch: 3 - Loss: 0.0693286806344986
Epoch: 4 - Loss: 0.04397426173090935
Epoch: 5 - Loss: 0.0320630706846714
Epoch: 6 - Loss: 0.025168560445308685
Epoch: 7 - Loss: 0.020678821951150894
Epoch: 8 - Loss: 0.017525091767311096
Epoch: 9 - Loss: 0.015190052799880505
Epoch: 10 - Loss: 0.013392775319516659
Epoch: 11 - Loss: 0.011967720463871956
Epoch: 12 - Loss: 0.010810835286974907
Epoch: 13 - Loss: 0.009853530675172806
Epoch: 14 - Loss: 0.009048725478351116
Epoch: 15 - Loss: 0.008362996391952038
Epoch: 16 - Loss: 0.0077720386907458305
Epoch: 17 - Loss: 0.00725766085088253
Epoch: 18 - Loss: 0.006806101184338331
Epoch: 19 - Loss: 0.0064066024497151375
Epoch: 20 - Loss: 0.0060507929883897305
Epoch: 21 - Loss: 0.005731978453695774
Epoch: 22 - Loss: 0.005444720387458801
Epoch: 23 - Loss: 0.005184633191674948
Epoch: 24 - Loss: 0.004948069341480732
Epoch: 25

In [None]:
# Predictions of the trained model on the training inputs
yp = model.predict(X)
yp

Array([[ 0.9507567 ],
       [-0.9378166 ],
       [-0.92679554]], dtype=float32)

In [None]:
# Report the memory used by each weight matrix. sys.getsizeof only measures the
# Python wrapper object (it printed the same number for every layer regardless
# of shape), so use the array's own buffer size (nbytes) together with its shape.
for i in range(len(model.W)):
  layer = model.W[i]
  print(f"Size {i}th layer: {layer.nbytes} bytes, shape {layer.shape}")


Size 0th layer: 720
Size 1th layer: 720
Size 2th layer: 720
Size 3th layer: 720


## Siguientes pasos

* Agregarle bias.
* Agregarle otras funciones de activación.
* Agregarle otras funciones de pérdida.
* Probar con un problema multiclase real (por ejemplo, MNIST).
* Comparar el tiempo de entrenamiento con el de scikit-learn.