# Agregando Adam


* En el notebook pasado vimos como se batalla para ajustar los hiperparametros de una red neuronal para problemas más complejos.
* En este notebook nos enfocaremos en agregar el método de optimización de pesos, Adam.

In [1]:
# Importamos librerias
from jax import numpy as jnp
from jax import grad
from jax import jacobian
from jax import random
from jax import vmap
from jax import value_and_grad
from jax import jit
from jax import nn
from jax import lax
import sys
import numpy as np

In [30]:
class JaxMLPClassifier():
  def __init__(self, units, epochs=10, lr=0.01, mr=0.9, binary=True, n_classes=2, activation='relu', seed=0, batch_size=-1):
    self.key = random.key(seed)
    self.epochs = epochs
    self.lr = lr
    self.mr = mr # Momentum rate
    self.W = dict()
    self.batch_size = batch_size
    if activation == 'relu':
      self.activation = nn.relu
    elif activation == 'sigmoid':
      self.activation = nn.sigmoid
    else:
      self.activation = nn.tanh

    if binary:
      self.units = units + [1]
    else:
      self.units = units + [n_classes]

  def forward(self, W, X):
    n_layers = len(self.units)
    output = X
    for i in range(n_layers-1):
      #W = W[i]
      z = jnp.dot(output, jnp.transpose(W[i]))
      output = self.activation(z)
    z = jnp.dot(output, jnp.transpose(W[n_layers-1]))
    output = nn.sigmoid(z)
    return output

  def predict(self, X):
    n_layers = len(self.units)
    output = X
    for i in range(n_layers-1):
      #W = W[i]
      z = jnp.dot(output, jnp.transpose(self.W[i]))
      output = self.activation(z)
    z = jnp.dot(output, jnp.transpose(self.W[n_layers-1]))
    output = nn.sigmoid(z)
    return output

  def loss(self, W, X, y):
    yp = self.forward(W, X)
    #print("yp: ", jnp.min(yp))
    #print("y: ", y)
    #l = jnp.sum(jnp.pow(yp - y, 2))/y.shape[0]
    l = jnp.log(yp) * y + jnp.log(1-yp) * (1 - y) # entropia cruzada
    #print("min losses:", jnp.max(l))
    #print("max losses:", jnp.min(l))
    ls = -jnp.sum(l)/y.shape[0]
    return ls

  def basic_update_step(self, W, grad_w):
    n_layers = len(self.units)
    for i in range(n_layers):
          W[i] = W[i] - self.lr*grad_w[i]
    return W

  def momentum_update_step(self, W, grad_w, momentum_w):
    n_layers = len(self.units)
    for i in range(n_layers):
      momentum_w[i] = self.mr*momentum_w[i] + self.lr*W[i]
      W[i] = W[i] - momentum_w[i]
    return W, momentum_w

  def rmsprop_update_step(self, W, grad_w, momentum_w):
    n_layers = len(self.units)
    for i in range(n_layers):
      momentum_w[i] = (1-self.mr)*W[i]*W[i] + self.mr*momentum_w[i]*momentum_w[i]
      W[i] = W[i] - self.lr / jnp.sqrt(momentum_w[i] + 0.001)*W[i]
    return W, momentum_w

  def fit(self, X, y):
    n_features = X.shape[1]
    n_samples = X.shape[0]
    n_layers = len(self.units) # +1 para tomar en cuenta la capa de salida

    # Creamos las matrices de pesos.
    keys = random.split(self.key, n_layers)
    W = dict()
    momentum = dict()
    for i in range(n_layers):
      n_units = self.units[i]
      W[i] = random.normal(keys[i], (n_units, n_features))
      momentum[i] = jnp.zeros_like(W[i])
      n_features = self.units[i]

    # calculamos el gradiente y calculamos en cada batch
    #val_grad_loss = jit(value_and_grad(self.loss, 0))
    val_grad_loss = value_and_grad(self.loss, 0)

    # Entrenamos la red
    X_p = lax.slice(X, (0,0), (self.batch_size,X.shape[1]))
    y_p = lax.slice(y, (0,), (self.batch_size,))
    loss_val, grad_val = val_grad_loss(W, X_p, y_p)
    print(f"Initial loss: {loss_val}")
    for ep in range(self.epochs):
      if self.batch_size == -1:
        # Calcular gradiente
        loss_val, grad_w = val_grad_loss(W, X, y)
        print(f"Epoch: {ep} - Loss: {loss_val}")
        #print("Grad: ", grad_w)

        # Actualizar pesos
        #W = self.basic_update_step(W, grad_w)
        #W, momentum = self.momentum_update_step(W, grad_w, momentum)
        W, momentum = self.rmsprop_update_step(W, grad_w, momentum)
        #for i in range(n_layers):
          #W[i] = W[i] - self.lr*grad_w[i]
      else:
        n_batches = int(jnp.floor(y.shape[0]/self.batch_size))
        #print(n_batches)
        avg_loss = 0
        for n in range(n_batches-1):
          i0 = n*self.batch_size
          i1 = (n+1)*self.batch_size
          #print(i0)
          #print(i1)
          X_p = lax.slice(X, (i0,0), (i1, X.shape[1]))
          y_p = lax.slice(y, (i0,), (i1,))
          loss_val, grad_w = val_grad_loss(W, X_p, y_p)
          avg_loss += loss_val
          W, momentum = self.rmsprop_update_step(W, grad_w, momentum)
          #W, momentum = self.momentum_update_step(W, grad_w, momentum)
          #W = self.basic_update_step(W, grad_w)
          #for i in range(n_layers):
            #W[i] = W[i] - self.lr*grad_w[i]

        print(f"Epoch: {ep} - Loss: {avg_loss/n_batches}")
    # Guardamos pesos
    self.W = W

In [3]:
from sklearn.datasets import load_breast_cancer
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

In [31]:
# Cargamos el dataset
X, y = load_breast_cancer(return_X_y=True)
scaler = StandardScaler()
scaler.fit(X)
X_norm = scaler.transform(X)
X_jax = jnp.array(X_norm)
y_jax = jnp.array(y)

In [38]:
jax_model = JaxMLPClassifier([4,8,4], lr=0.001, mr=0.95, epochs=15, activation='tanh', seed=1, batch_size=32)
jax_model.fit(X_jax,y_jax)

Initial loss: 22.615942001342773
Epoch: 0 - Loss: 25.5786190032959
Epoch: 1 - Loss: 24.96795082092285
Epoch: 2 - Loss: 24.306941986083984
Epoch: 3 - Loss: 23.592178344726562
Epoch: 4 - Loss: 22.898921966552734
Epoch: 5 - Loss: 22.32330894470215
Epoch: 6 - Loss: 21.851078033447266
Epoch: 7 - Loss: 21.459163665771484
Epoch: 8 - Loss: 21.157258987426758
Epoch: 9 - Loss: 20.94194984436035
Epoch: 10 - Loss: 20.80376434326172
Epoch: 11 - Loss: 20.734617233276367
Epoch: 12 - Loss: 20.72124671936035
Epoch: 13 - Loss: 20.74492645263672
Epoch: 14 - Loss: 20.784385681152344


In [39]:
yp_jax = jax_model.predict(X_jax)
yp = np.array(yp_jax.tolist())
preds = np.where(yp>0.5, 1, 0)
accuracy_score(y, preds)

0.7434094903339191

## Siguientes pasos:

* Haciendo pruebas con RMSProp y Momentum vimos que la red converge más rápido al valor de error mínimo que estaba sacando.
* Faltaría ver que tal se comporta Adam.