# Profundizando en redes neuronales con Jax

* En este notebook vamos a enfocarnos en hacer una clase más completa que permita entrenar una red neuronal de tamaño más arbitrario.
* Seguiremos utilizando el enfoque de hacerlo como en Scikit-learn.

In [None]:
# Importamos libre
from jax import numpy as jnp
from jax import grad
from jax import jacobian
from jax import random
from jax import vmap
from jax import value_and_grad
from jax import jit
from jax import nn
from jax import lax
import sys
import numpy as np

In [None]:
class JaxMLPClassifier():
  def __init__(self, units, epochs=10, lr=0.01, binary=True, n_classes=2, activation='relu', seed=0, batch_size=-1):
    self.key = random.key(seed)
    self.epochs = epochs
    self.lr = lr
    self.W = dict()
    self.batch_size = batch_size
    if activation == 'relu':
      self.activation = nn.relu
    elif activation == 'sigmoid':
      self.activation = nn.sigmoid
    else:
      self.activation = nn.tanh

    if binary:
      self.units = units + [1]
    else:
      self.units = units + [n_classes]

  def forward(self, W, X):
    n_layers = len(self.units)
    output = X
    for i in range(n_layers-1):
      #W = W[i]
      z = jnp.dot(output, jnp.transpose(W[i]))
      output = self.activation(z)
    z = jnp.dot(output, jnp.transpose(W[n_layers-1]))
    output = nn.sigmoid(z)
    return output

  def predict(self, X):
    n_layers = len(self.units)
    output = X
    for i in range(n_layers-1):
      #W = W[i]
      z = jnp.dot(output, jnp.transpose(self.W[i]))
      output = self.activation(z)
    z = jnp.dot(output, jnp.transpose(self.W[n_layers-1]))
    output = nn.sigmoid(z)
    return output

  def loss(self, W, X, y):
    yp = self.forward(W, X)
    #print("yp: ", jnp.min(yp))
    #print("y: ", y)
    #l = jnp.sum(jnp.pow(yp - y, 2))/y.shape[0]
    l = jnp.log(yp) * y + jnp.log(1-yp) * (1 - y) # entropia cruzada
    #print("min losses:", jnp.max(l))
    #print("max losses:", jnp.min(l))
    ls = -jnp.sum(l)/y.shape[0]
    return ls


  def fit(self, X, y):
    n_features = X.shape[1]
    n_samples = X.shape[0]
    n_layers = len(self.units) # +1 para tomar en cuenta la capa de salida

    # Creamos las matrices de pesos.
    keys = random.split(self.key, n_layers)
    W = dict()
    for i in range(n_layers):
      n_units = self.units[i]
      W[i] = random.normal(keys[i], (n_units, n_features))
      n_features = self.units[i]

    # calculamos el gradiente y calculamos en cada batch
    #val_grad_loss = jit(value_and_grad(self.loss, 0))
    val_grad_loss = value_and_grad(self.loss, 0)

    # Entrenamos la red
    X_p = lax.slice(X, (0,0), (self.batch_size,X.shape[1]))
    y_p = lax.slice(y, (0,), (self.batch_size,))
    loss_val, grad_val = val_grad_loss(W, X_p, y_p)
    print(f"Initial loss: {loss_val}")
    for ep in range(self.epochs):
      if self.batch_size == -1:
        # Calcular gradiente
        loss_val, grad_w = val_grad_loss(W, X, y)
        print(f"Epoch: {ep} - Loss: {loss_val}")
        #print("Grad: ", grad_w)

        # Actualizar pesos
        for i in range(n_layers):
          W[i] = W[i] - self.lr*grad_w[i]
      else:
        n_batches = int(jnp.floor(y.shape[0]/self.batch_size))
        #print(n_batches)
        avg_loss = 0
        for n in range(n_batches-1):
          i0 = n*self.batch_size
          i1 = (n+1)*self.batch_size
          #print(i0)
          #print(i1)
          X_p = lax.slice(X, (i0,0), (i1, X.shape[1]))
          y_p = lax.slice(y, (i0,), (i1,))
          loss_val, grad_w = val_grad_loss(W, X_p, y_p)
          avg_loss += loss_val
          for i in range(n_layers):
            W[i] = W[i] - self.lr*grad_w[i]
        print(f"Epoch: {ep} - Loss: {avg_loss/n_batches}")
    # Guardamos pesos
    self.W = W

In [None]:
# Probamos un entrenamiento
X = jnp.array([[1.0, 3.5],[-0.5, 0.6], [-0.5, 1.3]])
#y = jnp.array([[1],[0],[0]])
y = jnp.array([[1.0],[0.0],[0.0]])


In [None]:
# Creamos y entrenamos modelo
model = JaxMLPClassifier([8,8], lr=0.01, epochs=1, activation='relu')
model.fit(X,y)

Initial loss: 2.5114800930023193
Epoch: 0 - Loss: 2.5114800930023193
Epoch: 1 - Loss: 2.1464388370513916
Epoch: 2 - Loss: 1.8001383543014526
Epoch: 3 - Loss: 1.4716453552246094
Epoch: 4 - Loss: 1.162377119064331
Epoch: 5 - Loss: 0.8781678676605225
Epoch: 6 - Loss: 0.6312857270240784
Epoch: 7 - Loss: 0.4375635087490082
Epoch: 8 - Loss: 0.3042278587818146
Epoch: 9 - Loss: 0.2213926464319229
Epoch: 10 - Loss: 0.17141039669513702
Epoch: 11 - Loss: 0.14035014808177948
Epoch: 12 - Loss: 0.1200738325715065
Epoch: 13 - Loss: 0.10616356879472733
Epoch: 14 - Loss: 0.09619425982236862
Epoch: 15 - Loss: 0.08877945691347122
Epoch: 16 - Loss: 0.08308900147676468
Epoch: 17 - Loss: 0.07860378175973892
Epoch: 18 - Loss: 0.07498599588871002
Epoch: 19 - Loss: 0.07200868427753448
Epoch: 20 - Loss: 0.06951439380645752
Epoch: 21 - Loss: 0.06739137321710587
Epoch: 22 - Loss: 0.06555844843387604
Epoch: 23 - Loss: 0.06395541876554489
Epoch: 24 - Loss: 0.0625370591878891
Epoch: 25 - Loss: 0.06126868724822998
Ep

In [None]:
yp = model.predict(X)
yp

Array([[0.956443  ],
       [0.0394823 ],
       [0.05127172]], dtype=float32)

In [None]:
np.log(1.5212246393114803e-11)

-24.908920328707698

## Comparación Scikit-learn vs jax

* Vamos a tomar un dataset existente y comparar en términos de error y tiempo de entrenamiento.


In [None]:
from sklearn.datasets import load_breast_cancer
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

In [None]:
# Cargamos el dataset
X, y = load_breast_cancer(return_X_y=True)

In [None]:
# normalizamos los datos
scaler = StandardScaler()
scaler.fit(X)
X_norm = scaler.transform(X)
#X_norm

In [None]:
# Convertimos a formato de Jax
X_jax = jnp.array(X_norm)
y_jax = jnp.array(y)

In [None]:
jax_model = JaxMLPClassifier([4,4], lr=0.005, epochs=100, activation='tanh', seed=12, batch_size=64)
jax_model.fit(X_jax,y_jax)

Initial loss: 82.90943908691406
Epoch: 0 - Loss: 57.80400466918945
Epoch: 1 - Loss: 43.465232849121094
Epoch: 2 - Loss: 40.94175338745117
Epoch: 3 - Loss: 39.86909484863281
Epoch: 4 - Loss: 39.35636901855469
Epoch: 5 - Loss: 39.08998107910156
Epoch: 6 - Loss: 38.936973571777344
Epoch: 7 - Loss: 38.83965301513672
Epoch: 8 - Loss: 38.77171325683594
Epoch: 9 - Loss: 38.72043228149414
Epoch: 10 - Loss: 38.679256439208984
Epoch: 11 - Loss: 38.64460754394531
Epoch: 12 - Loss: 38.61449432373047
Epoch: 13 - Loss: 38.587677001953125
Epoch: 14 - Loss: 38.56344223022461
Epoch: 15 - Loss: 38.54129409790039
Epoch: 16 - Loss: 38.520835876464844
Epoch: 17 - Loss: 38.501705169677734
Epoch: 18 - Loss: 38.48351287841797
Epoch: 19 - Loss: 38.46590805053711
Epoch: 20 - Loss: 38.4486083984375
Epoch: 21 - Loss: 38.43164825439453
Epoch: 22 - Loss: 38.41552734375
Epoch: 23 - Loss: 38.40092468261719
Epoch: 24 - Loss: 38.38808822631836
Epoch: 25 - Loss: 38.37671661376953
Epoch: 26 - Loss: 38.36634063720703
Epoc

In [None]:
yp_jax = jax_model.predict(X_jax)
yp = np.array(yp_jax.tolist())
preds = np.where(yp>0.5, 1, 0)
accuracy_score(y, preds)

0.6274165202108963

## Siguientes pasos:

* Parece que para problemas reales mi implementación ya está batallando.
* Podría ser por el batch size y el método de aprendizaje. Parece que el gradiente descendiente no es muy bueno para manejar estos casos. Tal vez haya que implementar algo más sofisticado.
* Vemos que también es muy sensible a los parámetros que elijamos.
* Aquí es donde cobra relevanvia métodos como Adam y RMSProp.