In [1]:
import numpy as onp
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In [82]:
import jax 
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax.nn import relu
import optax 

from jax import random

# JAX randomness is explicit: build a root PRNG key from seed 1.
# Note: the data cell below re-creates `key` (seed 42) before sampling.
key = jax.random.PRNGKey(1)

In [135]:
class pcann:
    """Small ReLU MLP regressor trained with Adam on a mean L2 loss.

    Each example goes through ReLU hidden layers and then a final affine
    layer. The final layer's outputs are summed into one scalar prediction.
    """

    def __init__(self, layer_sizes=(5, 5, 2), key=None, lr=1e-1):
        """Build parameters, the batched forward pass and the optimizer.

        Args:
            layer_sizes: Layer widths from input to final layer. For example,
                ``(5, 5, 2)`` means 5 inputs, one ReLU hidden layer of 5 units
                and a final affine layer of 2 units. At least two entries are
                required.
            key: ``jax.random`` PRNG key used for weight initialization.
                Defaults to ``random.PRNGKey(1)``.
            lr: Adam learning rate, forwarded to ``initialize_opt``.
        """
        if key is None:
            # Built per call rather than as a default argument, which would
            # be evaluated once when the class body is executed.
            key = random.PRNGKey(1)
        self.params = self.initialize_mlp(layer_sizes, key)
        # Map the per-example forward pass over axis 0 of x. The params are
        # shared across the batch (in_axes=None).
        self.batch_forward = vmap(self.forward_pass, in_axes=(None, 0))
        # Compile loss and gradient once so each training epoch runs compiled
        # code instead of dispatching every op eagerly.
        self._loss_jit = jit(self.loss)
        self._grad_jit = jit(grad(self.loss))
        self.initialize_opt(lr)

    def train(self, xs, ys, epochs=100, verbose=True):
        """Run ``epochs`` full-batch optimizer steps on ``(xs, ys)``.

        Args:
            xs: Inputs, batched along axis 0.
            ys: Targets, one per row of ``xs``.
            epochs: Number of gradient steps to take.
            verbose: If True, print the loss (after the update) every step.
        """
        for _ in range(epochs):
            grads = self._grad_jit(self.params, xs, ys)
            # The optimizer update stays outside jit, so a later call to
            # initialize_opt takes effect immediately.
            updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
            self.params = optax.apply_updates(self.params, updates)
            if verbose:
                print(f'Loss is {self._loss_jit(self.params, xs, ys)}')

    def initialize_opt(self, lr=1e-1):
        """(Re)create the Adam optimizer and its state for the current params."""
        self.optimizer = optax.adam(lr)
        self.opt_state = self.optimizer.init(self.params)

    def loss(self, params, x, y):
        """Mean of ``optax.l2_loss`` (0.5 * squared error) over the batch."""
        preds = self.batch_forward(params, x)
        return np.mean(optax.l2_loss(preds, y))

    def forward_pass(self, params, x):
        """Compute the scalar prediction for a single example ``x``.

        Each hidden layer applies ``relu(w @ a + b)``. The final affine
        layer's outputs are then summed into one scalar.
        """
        activations = x
        for w, b in params[:-1]:
            activations = relu(np.dot(w, activations) + b)

        final_w, final_b = params[-1]
        # Summing collapses the final layer's outputs, so the prediction is a
        # scalar per example (this is not a vector of logits).
        return np.sum(np.dot(final_w, activations) + final_b)

    def initialize_mlp(self, sizes, key):
        """Initialize one ``(w, b)`` pair per consecutive pair of ``sizes``.

        Layer i maps ``sizes[i]`` inputs to ``sizes[i + 1]`` outputs. So
        ``w`` has shape ``(sizes[i + 1], sizes[i])`` and ``b`` has shape
        ``(sizes[i + 1],)``. Both are Gaussian samples scaled by 1e-2.

        Raises:
            ValueError: if fewer than two sizes are given, which would
                produce a model with no layers.
        """
        if len(sizes) < 2:
            raise ValueError(f'need at least 2 layer sizes, got {list(sizes)}')
        # len(sizes) keys are split although only len(sizes) - 1 are used.
        # Kept as-is so initial weights stay reproducible for a given key.
        keys = random.split(key, len(sizes))

        def initialize_layer(m, n, key, scale=1e-2):
            """Gaussian (n, m) weight matrix and (n,) bias, both times ``scale``."""
            w_key, b_key = random.split(key)
            return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

        return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [136]:
# Synthetic linear-regression data: 100 samples of 5 standard-normal
# features. Each target is the feature sum scaled by `target_params`.
# The 5 features match the default input width of `pcann`.
key = jax.random.PRNGKey(42)
target_params = 0.5
xs = jax.random.normal(key, shape=(100, 5))
ys = (xs * target_params).sum(axis=-1)

In [137]:
# Fit the default 5-5-2 network (Adam, lr 0.1) with 100 full-batch epochs.
nn = pcann()
nn.train(xs, ys, epochs=100)

Loss is 0.4158563017845154
Loss is 0.17924073338508606
Loss is 0.037840694189071655
Loss is 0.13168497383594513
Loss is 0.13468925654888153
Loss is 0.04993642121553421
Loss is 0.01779324933886528
Loss is 0.04430382326245308
Loss is 0.07623517513275146
Loss is 0.07979195564985275
Loss is 0.057037800550460815
Loss is 0.027858078479766846
Loss is 0.013929243199527264
Loss is 0.02245417609810829
Loss is 0.03732893243432045
Loss is 0.0383341908454895
Loss is 0.02478916384279728
Loss is 0.011566661298274994
Loss is 0.009864678606390953
Loss is 0.016465405002236366
Loss is 0.021581336855888367
Loss is 0.02024519443511963
Loss is 0.01443567592650652
Loss is 0.009171809069812298
Loss is 0.007160389330238104
Loss is 0.007956251502037048
Loss is 0.009236621670424938
Loss is 0.00870217103511095
Loss is 0.006757930386811495
Loss is 0.00557346548885107
Loss is 0.005530070513486862
Loss is 0.0055364458821713924
Loss is 0.0046552554704248905
Loss is 0.003405258059501648
Loss is 0.0029534483328461647
L