In [1]:
! git clone https://github.com/abnercorrea/machine-learning.git

import sys
sys.path += ['/content/machine-learning/src']

Cloning into 'machine-learning'...
remote: Enumerating objects: 157, done.[K
remote: Counting objects: 100% (157/157), done.[K
remote: Compressing objects: 100% (100/100), done.[K
remote: Total 157 (delta 60), reused 136 (delta 42), pack-reused 0[K
Receiving objects: 100% (157/157), 37.21 KiB | 1.16 MiB/s, done.
Resolving deltas: 100% (60/60), done.


# Real data: hand-written numbers (5 & 6), with each of the 64 features representing the number of pixels in 4x4 square in 256x256 image.

In [36]:
import numpy as np

from abnercorrea.numpy.util.data_prep import read_train_data, read_test_data, norm, prepend_col, to_binary_classes


# read data
xtrp, ytrp = read_train_data(num_partitions=10)
xtr, ytr = np.concatenate(xtrp), np.concatenate(ytrp)
xte, yte = read_test_data()

# yi will be used as a scalar
ytr, yte = ytr[:, 0], yte[:, 0]

# y denote the vector of yi values (yi = 1 for class 1 and yi = 0 for class 2)
ytrb, classes = to_binary_classes(ytr)
yteb, _ = to_binary_classes(yte)

n_samples, n_features = xtr.shape

xtr.shape, xte.shape, ytrb.shape, yteb.shape

((1000, 64), (110, 64), (1000,), (110,))

# Synthetic data

In [37]:
n_samples = 1110
n_classes = 2
n_features = 64
n_informative = 48
n_redundant = 16
test_size = 0.1

In [87]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


def make_classification_train_test(n_samples, n_classes, n_features, n_informative, n_redundant, test_size, random_state):
    # define dataset
    X, y = make_classification(
        n_samples=n_samples, 
        n_features=n_features, 
        n_informative=n_informative, 
        n_redundant=n_redundant, 
        n_classes=n_classes, 
        random_state=random_state
    )

    # split train and test data
    xtr, xte, ytr, yte = train_test_split(X, y, test_size=test_size, random_state=random_state)

    return xtr, xte, ytr, yte

In [97]:
from datetime import datetime

random_state = int(datetime.utcnow().timestamp())

xtr, xte, ytrb, yteb = make_classification_train_test(
    n_samples=n_samples, 
    n_features=n_features, 
    n_informative=n_informative, 
    n_redundant=n_redundant, 
    n_classes=n_classes,
    test_size=test_size,
    random_state=random_state
)

# summarize the dataset
xtr.shape, xte.shape, ytrb.shape, yteb.shape

((999, 64), (111, 64), (999,), (111,))

# Data Pipeline: predictors are all standardized to have mean zero and unit norm.

In [98]:
xtr, xte = norm(xtr - xtr.mean()), norm(xte - xte.mean())

In [101]:
xtr.mean(), xte.mean()

(0.00015373424967282415, 0.0004237032574344316)

# Util

In [4]:
from jax import random


def create_sample_batch(x, y, mini_batch_size, prng_key):
    if mini_batch_size is None:
        return x, y

    n = y.size
    sample_indices = random.choice(prng_key, n, shape=(mini_batch_size,), replace=False).tolist()
    x_batch = x[sample_indices]
    y_batch = y[sample_indices]

    return x_batch, y_batch

In [106]:
import plotly.graph_objects as go


def plot_2d(x, y, title, axes_title, max=True):
    minmax = y.argmax() if max else y.argmin()
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name=' x '.join(axes_title)))
    fig.add_trace(go.Scatter(x=[x[minmax]], y=[y[minmax]], mode='markers', name='Max' if max else 'Min'))
    fig.update_layout(title=title, autosize=True, width=500, height=500,)
    fig.update_xaxes(title_text=axes_title[0])
    fig.update_yaxes(title_text=axes_title[1])
    fig.show()


def plot_epoch_loss(loss):
    epoch = list(range(1, len(loss) + 1))
    
    plot_2d(epoch, np.array(loss), 'Epoch x Loss', ['Epoch', 'Loss'], max=False)


# Optimizers

In [6]:
from absl import logging

from abc import abstractmethod
from functools import partial

import jax
import jax.numpy as jnp
from jax import random


class Optimizer():
    @abstractmethod
    def calculate_update(self, grads, epoch):
        pass

    @abstractmethod
    def optimize(self, params, x, y, loss_f):
        pass


class MiniBatchSGD(Optimizer):
    """
    In both gradient descent (GD) and stochastic gradient descent (SGD), you update a set of parameters in an iterative manner to minimize an error function.

    While in GD, you have to run through ALL the samples in your training set to do a single update for a parameter in a particular iteration,
    in SGD, on the other hand, you use ONLY ONE or SUBSET of training sample from your training set to do the update for a parameter in a particular iteration.
    If you use SUBSET, it is called Minibatch Stochastic gradient Descent.

    IMPORTANT: it's common practice to choose the mini batch size to be a power of 2. (usually between 64 and 512)

    https://datascience.stackexchange.com/a/36451
    https://web.archive.org/web/20180618211933/http://cs229.stanford.edu/notes/cs229-notes1.pdf

    # TODO: implement learning rate decay:
        - lr = lr * (1 / 1 + decay_rate * epoch)
        - lr = lr * 0.95 ** epoch  (exponential decay)
        - lr = lr * (k / epoch ** .5)
        - discrete staircase.
    """

    def __init__(self, mini_batch_size=256, epochs=10000, learning_rate=1e-8, eps=1e-3, overshoot_decrease_rate=.5, stagnation_batch_size=2, prng_key=None):
        assert mini_batch_size is None or mini_batch_size > 0, 'mini_batch_size can be either greater than 0 or None (for ordinary GD).'
        assert stagnation_batch_size > 1, 'stagnation_batch_size has to be greater than 1.'

        self.mini_batch_size = mini_batch_size
        self.epochs = epochs
        self.epoch = 0
        self.learning_rate = learning_rate
        self.eps = eps
        self.overshoot_decrease_rate = overshoot_decrease_rate
        self.stagnation_batch_size = stagnation_batch_size
        self.prng_key = prng_key or random.PRNGKey(0)
        self.loss_hist = []

    @partial(jax.jit, static_argnames=('self',))
    def calculate_update(self, grads, epoch):
        # Ordinary SGD uses gradient for updates
        return grads

    @partial(jax.jit, static_argnames=('self', 'loss_f',))
    def sgd_update(self, params, x, y, loss_f):
        loss, grads = jax.value_and_grad(loss_f)(params, x, y)

        logging.debug(f'Epoch: {self.epoch}\nLoss: {loss}\nGrads: {grads}\nParams: {params}')

        update = self.calculate_update(grads, self.epoch)

        params = jax.tree_multimap(lambda p, g: p - self.learning_rate * g, params, update)

        logging.debug(f'Params updated: {params}')

        return params, loss, grads

    def optimize(self, params, x, y, loss_f):
        """
        Vanishing gradients make it difficult to know which direction the parameters should move to improve the cost function … (from Deep Learning)

        :param params:
        :param x:
        :param y:
        :param loss_f:
        :return:
        """
        mini_batch_size, epochs, eps, learning_rate, overshoot_decrease_rate, stagnation_batch_size = self.mini_batch_size, self.epochs, self.eps, self.learning_rate, self.overshoot_decrease_rate, self.stagnation_batch_size
        loss_hist = self.loss_hist

        for epoch in range(1, epochs + 1):
            self.epoch = epoch

            logging.debug(f'epoch: {epoch}')

            # Creates mini batch
            x_batch, y_batch = create_sample_batch(x, y, mini_batch_size, self.prng_key)
            # Applies SGD update
            params, loss, grads = self.sgd_update(params, x_batch, y_batch, loss_f)
            loss_hist.append(loss.item())

            logging.debug(f'Loss hist: {loss_hist}')

            # TODO: vanished gradient? diverged?
            if jnp.all(jnp.isnan(grads[0])).item():
                logging.error(f'Gradient vanished! - epoch: {epoch}, loss: {loss}, prev_loss: {loss_hist[-2] if epoch > 1 else None}')
                break

            # TODO: research and improve learning rate decreasing
            if epoch > 1:
                overshoot = loss_hist[-1] > loss_hist[-2]
                if overshoot:
                    learning_rate *= overshoot_decrease_rate
                    logging.info(f'Overshoot! Lowering learning rate to {learning_rate} - Epoch: {epoch} - Loss: {loss} - Prev loss: {loss_hist[-2]}')

            # TODO: research and improve convergence checking
            if epoch >= stagnation_batch_size:
                loss_delta = jnp.abs(loss_hist[-1] - loss_hist[-stagnation_batch_size])
                logging.debug(f'Loss delta: {loss_delta}')
                if loss_delta <= eps:
                    logging.info(f'SGD converged in {epoch} epochs! - Initial loss: {loss_hist[0]}, final loss: {loss}')
                    break

        return params


class GradientDescent(MiniBatchSGD):
    def __init__(self, **kwargs):
        """
        In ordinary gradient descent, the batch size is equal to the entire training set.
        """
        assert 'mini_batch_size' not in kwargs
        super().__init__(mini_batch_size=None, **kwargs)


class SGD(MiniBatchSGD):
    def __init__(self, **kwargs):
        """
        In stochastic gradient descent, the batch size is equal to 1.
        """
        assert 'mini_batch_size' not in kwargs
        super().__init__(mini_batch_size=1, **kwargs)


class SGDWithMomentum(MiniBatchSGD):
    """
    Uses an exponentialy weighted average of the gradients to update parameters.
    """
    def __init__(self, momentum=0.9, **kwargs):
        super().__init__(**kwargs)

        self.momentum = momentum
        # Exponentialy weighted average of the gradients
        self.vd = None

    @partial(jax.jit, static_argnames=('self',))
    def calculate_update(self, grads, epoch):
        vd = self.vd or jax.tree_multimap(lambda g: jnp.zeros_like(g), grads)
        # exponentialy weighted average of the gradients
        self.vd = jax.tree_multimap(lambda vd, g: self.momentum * vd + (1 - self.momentum) * g, vd, grads)
        return self.vd


class RMSProp(MiniBatchSGD):
    """
    Uses the exponentialy weighted average of the square of the gradients to update parameters.
    """
    def __init__(self, momentum=0.9, rms_eps=1e-8, **kwargs):
        super().__init__(**kwargs)

        self.momentum = momentum
        self.rms_eps = rms_eps
        # Exponentialy weighted average of the square of the gradients
        self.sd = None


    @partial(jax.jit, static_argnames=('self',))
    def calculate_update(self, grads, epoch):
        sd = self.sd or jax.tree_multimap(lambda g: jnp.zeros_like(g), grads)
        
        # exponentialy weighted average of the square of the gradients
        sd = jax.tree_multimap(lambda sd, g: self.momentum * sd + (1 - self.momentum) * g ** 2, sd, grads)
        
        # RMSProp update
        update = jax.tree_multimap(lambda sd, g: g / (sd ** 0.5 + self.rms_eps), sd, grads)

        self.sd = sd

        return update


class Adam(MiniBatchSGD):
    """
    Adaptive moment estimation. (Adam = Momentum + RMSProp)

    https://arxiv.org/pdf/1412.6980.pdf
    """
    def __init__(self, momentum=0.9, rms_momentum=0.999, rms_eps=1e-8, **kwargs):
        super().__init__(**kwargs)

        self.momentum = momentum
        self.rms_momentum = rms_momentum
        self.rms_eps = rms_eps
        self.vd = None
        self.sd = None

    @partial(jax.jit, static_argnames=('self',))
    def calculate_update(self, grads, epoch):
        zeroes = jax.tree_multimap(lambda g: jnp.zeros_like(g), grads)

        # Momentum
        vd = jax.tree_multimap(lambda vd, g: self.momentum * vd + (1 - self.momentum) * g, self.vd or zeroes, grads)
        vd_correction = 1 / (1 - self.momentum ** epoch)
        self.vd = jax.tree_multimap(lambda vd: vd * vd_correction, vd)

        # RMSProp
        sd = jax.tree_multimap(lambda sd, g: self.rms_momentum * sd + (1 - self.rms_momentum) * g ** 2, self.sd or zeroes, grads)
        sd_correction = 1 / (1 - self.rms_momentum ** epoch)
        self.sd = jax.tree_multimap(lambda sd: sd * sd_correction, sd)

        # ADAM update
        update = jax.tree_multimap(lambda vd, sd: vd / (sd ** 0.5 + self.rms_eps), self.vd, self.sd)

        logging.debug(f'vd: {vd}\nsd: {sd}')
        logging.debug(f'vd correction: {vd_correction} - sd correction: {sd_correction}')
        logging.debug(f'vd corrected: {self.vd}\nsd corrected: {self.sd}')
        logging.debug(f'update: {update}')

        return update


# Logistic Regression

In [7]:
class LogisticRegressionClassifierJax:
    def __init__(self, params, optimizer: Optimizer):
        self.params = params
        self.optimizer = optimizer

    @staticmethod
    @jax.jit
    def log_likelihood_loss(params, x, y):
        w, b = params
        n = y.size
        p = jax.nn.sigmoid(jnp.dot(x, w) + b)
        # likelihood = jnp.product(p ** y * (1 - p) ** (1 - y)) 
        log_likelihood = (1 / n) * jnp.sum(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))
        return -log_likelihood

    @partial(jax.jit, static_argnames=('self',))
    def predict(self, x):
        w, b = self.params
        p = jax.nn.sigmoid(jnp.dot(x, w) + b)
        return jnp.where(p >= .5, 1, 0)

    def fit(self, x, y):
        self.params = self.optimizer.optimize(self.params, x, y, self.log_likelihood_loss)

    @partial(jax.jit, static_argnames=('self',))
    def accuracy(self, x, y):
        yp = self.predict(x)
        accurate = jnp.count_nonzero(yp == y)
        n = y.size
        return accurate / n


#Testing

In [8]:
logging.set_verbosity(logging.INFO)
# logging.set_verbosity(logging.DEBUG)

In [9]:
jax.devices()

INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


[GpuDevice(id=0, process_index=0)]

## Multiple initial params for testing

In [12]:
import numpy as np

f = xtr.shape[1]

params = [
    [
        # Randomly initializes weights
        np.random.normal(size=(f,)) * np.sqrt(2/f),
        # Initializes biases to 1
        np.ones((1,))
    ]
    for _ in range(10)
]

# Ordinary GD

In [50]:
lrc = LogisticRegressionClassifierJax(
    params=params[0], 
    optimizer=GradientDescent(
        epochs=5000, 
        learning_rate=1e1, 
        eps=1e-5, 
        overshoot_decrease_rate=.5, 
        stagnation_batch_size=2
    )
)

In [51]:
lrc.fit(xtr, ytrb)

INFO:absl:Overshoot! Lowering learning rate to 5.0 - Epoch: 2 - Loss: 0.8857446312904358 - Prev loss: 0.8174636960029602
INFO:absl:Overshoot! Lowering learning rate to 2.5 - Epoch: 3 - Loss: 0.9512478709220886 - Prev loss: 0.8857446312904358
INFO:absl:Overshoot! Lowering learning rate to 1.25 - Epoch: 4 - Loss: 0.9864434599876404 - Prev loss: 0.9512478709220886
INFO:absl:SGD converged in 797 epochs! - Initial loss: 0.8174636960029602, final loss: 0.4212346374988556


In [52]:
lrc.accuracy(xte, yteb).item()

0.8018018007278442

In [53]:
plot_epoch_loss(lrc.optimizer.loss_hist)

# Stochastic Gradient Descent

In [58]:
lrc_sgd = LogisticRegressionClassifierJax(
    params=params[0], 
    optimizer=SGD(
        epochs=5000, 
        learning_rate=10, 
        eps=1e-8, 
        overshoot_decrease_rate=.5, 
        stagnation_batch_size=2
    )
)

In [62]:
lrc_sgd.fit(xtr, ytrb)

INFO:absl:SGD converged in 17 epochs! - Initial loss: 0.3196088969707489, final loss: 7.39124880055897e-05


In [63]:
lrc_sgd.accuracy(xte, yteb).item()

0.5405405759811401

In [64]:
plot_epoch_loss(lrc_sgd.optimizer.loss_hist)

# Mini Batch Stochastic Gradient Descent

In [68]:
lrc_batch_sgd = LogisticRegressionClassifierJax(
    params=params[0], 
    optimizer=MiniBatchSGD(
        mini_batch_size=512,
        epochs=5000, 
        learning_rate=10, 
        eps=1e-5, 
        overshoot_decrease_rate=.5, 
        stagnation_batch_size=2
    )
)

In [69]:
lrc_batch_sgd.fit(xtr, ytrb)

INFO:absl:Overshoot! Lowering learning rate to 5.0 - Epoch: 2 - Loss: 0.8605702519416809 - Prev loss: 0.8041672706604004
INFO:absl:Overshoot! Lowering learning rate to 2.5 - Epoch: 3 - Loss: 0.934516966342926 - Prev loss: 0.8605702519416809
INFO:absl:Overshoot! Lowering learning rate to 1.25 - Epoch: 4 - Loss: 0.9570765495300293 - Prev loss: 0.934516966342926
INFO:absl:Overshoot! Lowering learning rate to 0.625 - Epoch: 5 - Loss: 0.9626386761665344 - Prev loss: 0.9570765495300293
INFO:absl:SGD converged in 771 epochs! - Initial loss: 0.8041672706604004, final loss: 0.42609649896621704


In [70]:
lrc_batch_sgd.accuracy(xte, yteb).item()

0.7567567825317383

In [71]:
plot_epoch_loss(lrc_batch_sgd.optimizer.loss_hist)

# Momentum

In [72]:
lrc_momentum = LogisticRegressionClassifierJax(
    params=params[0], 
    optimizer=SGDWithMomentum(
        momentum=0.9,
        mini_batch_size=512, 
        epochs=5000, 
        learning_rate=160, 
        eps=1e-5, 
        overshoot_decrease_rate=.5, 
        stagnation_batch_size=2
    )
)

In [None]:
lrc_momentum.fit(xtr, ytrb)

In [74]:
lrc_momentum.accuracy(xte, yteb).item()

0.7297297716140747

In [75]:
plot_epoch_loss(lrc_momentum.optimizer.loss_hist)

# RMSProp

In [76]:
lrc_rmsprop = LogisticRegressionClassifierJax(
    params=params[0], 
    optimizer=RMSProp(
        momentum=0.9,
        rms_eps=1e-8,
        mini_batch_size=512, 
        epochs=5000, 
        learning_rate=2e-2, 
        eps=1e-4, 
        overshoot_decrease_rate=.5, 
        stagnation_batch_size=2
    )
)

In [77]:
lrc_rmsprop.fit(xtr, ytrb)

INFO:absl:SGD converged in 78 epochs! - Initial loss: 0.8041672706604004, final loss: 0.4348953664302826


In [78]:
lrc_rmsprop.accuracy(xte, yteb).item()

0.7567567825317383

In [79]:
plot_epoch_loss(lrc_rmsprop.optimizer.loss_hist)

# ADAM

In [102]:
lrc_adam = LogisticRegressionClassifierJax(
    params=params[0], 
    optimizer=Adam(
        momentum=0.9,
        rms_momentum=0.999,
        rms_eps=1e-8,
        mini_batch_size=512, 
        epochs=5000, 
        learning_rate=1e-1, 
        eps=2e-4, 
        overshoot_decrease_rate=.5, 
        stagnation_batch_size=2
    )
)

In [103]:
lrc_adam.fit(xtr, ytrb)

INFO:absl:SGD converged in 96 epochs! - Initial loss: 0.7999417185783386, final loss: 0.2699057459831238


In [108]:
lrc_adam.accuracy(xte, yteb).item()

0.8828828930854797

In [109]:
plot_epoch_loss(lrc_adam.optimizer.loss_hist)

# 🎸🍣

# Benchmarking

In [None]:
%%timeit 
LogisticRegressionClassifierJax(
    params=params, 
    optimizer=StochasticGradientDescent(
        mini_batch_size=None, 
        max_iter=5000, 
        learning_rate=1e-2, 
        tol=1e-2, 
        overshoot_decrease_rate=.75, 
        stagnation_batch_size=2
    )
).fit(xtr, ytrb)

In [None]:
%%timeit 
LogisticRegressionClassifierJax(
    params=params, 
    optimizer=StochasticGradientDescent(
        mini_batch_size=None, 
        max_iter=5000, 
        learning_rate=1e-2, 
        tol=1e-2, 
        overshoot_decrease_rate=.75, 
        stagnation_batch_size=2
    )
).fit(xtr, ytrb)

### Without jit

In [None]:
%timeit lrc.predict(xte)

### jit

In [None]:
%timeit lrc.predict(xte)

# Memory Profiling using pprof

In [None]:
# This will install it
! add-apt-repository ppa:longsleep/golang-backports -y
! apt update
! apt install golang-go

%env GOPATH=/root/go

path = ! echo $PATH
%env PATH={path[0]}:/root/go/bin

! apt-get install graphviz gv
! go install github.com/google/pprof@latest

In [None]:
! pprof -top memory/100.prof

In [None]:
# jax.profiler.save_device_memory_profile(f"memory/4984.prof")

In [None]:
! pprof -top memory/4984.prof

In [None]:
! pprof -top --diff_base memory/100.prof memory/4984.prof

# Development

# Optimize JIT version - WIP

In [None]:
    @partial(jax.jit, static_argnames=('self', 'loss_f',))
    def optimize_jit(self, params, x, y, loss_f):
        mini_batch_size, max_iter, tol, learning_rate, overshoot_decrease_rate, stagnation_batch_size = self.mini_batch_size, self.max_iter, self.tol, self.learning_rate, self.overshoot_decrease_rate, self.stagnation_batch_size

        def optimize_cond(args):
            params, loss_hist, epoch = args
            curr_loss = loss_hist[-1]
            loss_is_nan = jnp.isnan(curr_loss)
            converged = jax.lax.cond(
                epoch >= stagnation_batch_size, 
                lambda _: jnp.abs(curr_loss - loss_hist[-stagnation_batch_size]) < tol,
                lambda _: False,
                operand=None
            )
            return not(loss_is_nan or converged)

        def optimize_body(args):
            params, loss_hist, epoch = args
            # Creates mini batch
            x_batch, y_batch = create_sample_batch(x, y, mini_batch_size, self.prng_key)
            # Applies SGD update
            params, loss, grads = self.sgd_update(params, x_batch, y_batch, learning_rate, loss_f)
            loss_hist.append(loss)
            overshoot = loss_hist[-1] > loss_hist[-2]
            if overshoot:
                learning_rate *= overshoot_decrease_rate
                logging.info(f'Overshoot! Lowering learning rate to {learning_rate}')
            return params, loss_hist, epoch + 1

        #grad_f = jax.value_and_grad(loss_f)
        loss_hist = [loss_f(params, x, y)]

        params, loss_hist, epoch = jax.lax.while_loop(optimize_cond, optimize_body, init_val=[params, loss_hist, 1])
        
        return params

In [None]:
jnp.count_nonzero(jnp.array([True, False, True]))

In [None]:
jnp.count_nonzero(jnp.array([0,1,2]) == jnp.array([0,0,2]))

In [None]:
jnp.all(jnp.array(jax.tree_map(lambda t: jnp.all(t < 1e-3), [jnp.array([1e-4, 1e-5, 1e-6]), jnp.array([1e-2, 1e-4, 1e-3])])))

In [None]:
jnp.array(jax.tree_leaves([jnp.array([1e-4, 1e-5, 1e-6]), jnp.array([1e-2, 1e-4, 1e-3])]))

In [None]:
jax.nn.sigmoid(jnp.zeros(5))

In [None]:
jnp.arange(10) ** jnp.ones(10)

In [None]:
3 ** 2 * 3 ** 2

In [None]:
jnp.arange(10).item()

In [None]:
a = jnp.array([1,2,3])

In [None]:
a

In [None]:
a[:, None, None]

In [None]:
a[:, None, None][0][0][0]