In [1]:
import pandas as pd

In [2]:
from sklearn.datasets import load_iris
# Load the Iris dataset and put it in a DataFrame: the four measurement
# columns keep their default integer labels (0-3) and the class id goes into
# a "target" column.
iris = load_iris()
df = pd.DataFrame(iris.data).assign(target=iris.target)
df.head()

Unnamed: 0,0,1,2,3,target
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0


In [3]:
# Feature matrix = the four integer-labelled columns; labels = "target".
# Both are converted straight to NumPy arrays for the rest of the notebook.
X = df[list(range(4))].to_numpy()
y = df["target"].to_numpy()

In [4]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for evaluation; the fixed random_state makes the
# split reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
)

In [5]:
import numpy as np
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

In [6]:
# Show which platform JAX will execute on (for example "cpu", "gpu" or "tpu").
# jax.default_backend() is the public API for this query; the previous
# `jax.lib.xla_bridge.get_backend().platform` went through an internal module
# that recent JAX releases deprecate.
jax.default_backend()

'gpu'

In [7]:
# Root PRNG key for the notebook, built from a fixed seed so that parameter
# initialisation is reproducible.
masterkey = jax.random.PRNGKey(seed=0)
masterkey

DeviceArray([0, 0], dtype=uint32)

In [8]:
class Ann(nn.Module):
    """Single dense layer followed by a softmax.

    Calling the module maps an input vector to 3 values produced by
    ``jax.nn.softmax``, i.e. class probabilities rather than raw logits.
    """

    def setup(self):
        # The attribute name is also the key of this layer in the params tree.
        self.dense = nn.Dense(features=3)

    def __call__(self, x):
        """Return the softmax of the dense layer's output for ``x``."""
        return jax.nn.softmax(self.dense(x))

In [9]:
# Instantiate the network and create its initial parameters. `masterkey`
# seeds the random initialisation and the first training row serves as the
# example input. The displayed tree holds a 4x3 kernel and a zero bias for
# the single dense layer.
model = Ann()
params = model.init(masterkey, X_train[0])
params

FrozenDict({
    params: {
        dense: {
            kernel: DeviceArray([[ 0.38065028,  0.5594941 ,  0.37281904],
                         [ 0.4069356 , -0.28853688,  0.18296593],
                         [-0.7483908 , -0.73975974, -0.4646266 ],
                         [ 0.5796048 ,  0.8613224 , -0.471414  ]], dtype=float32),
            bias: DeviceArray([0., 0., 0.], dtype=float32),
        },
    },
})

In [10]:
# Sanity-check a forward pass on one training row with the freshly
# initialised parameters. The result gets its own name: assigning it to `y`
# would overwrite the label array built earlier, so re-running the
# train/test split cell afterwards would no longer see the real targets.
sample_probs = model.apply(params, X_train[0])
sample_probs

DeviceArray([0.50647944, 0.3207565 , 0.17276412], dtype=float32)

In [11]:
def ce_loss(params, xs, ys):
    """Mean cross-entropy of the global ``model`` over a batch.

    Args:
        params: parameter tree passed to ``model.apply``.
        xs: batch of inputs; ``jax.vmap`` maps over its leading axis.
        ys: integer class index for each input, aligned with ``xs``.

    Returns:
        The mean over the batch of ``-log(p[label])``, where ``p`` is the
        model output for the corresponding input.
    """
    def example_nll(features, label):
        # The model already applies softmax, so its output is a probability
        # vector; pick out the probability assigned to the true class.
        probs = model.apply(params, features)
        return -jnp.log(probs[label])

    per_example = jax.vmap(example_nll)(xs, ys)
    return jnp.mean(per_example, axis=0)

In [12]:
import optax

# Plain SGD with a fixed learning rate; the optimizer state is initialised
# from the current parameter tree.
optim = optax.sgd(learning_rate=0.1)
optim_state = optim.init(params)
# One call returns (loss, gradient with respect to params). Wrapping it in
# jax.jit compiles the computation once and reuses it on every training step
# instead of dispatching each operation eagerly on each call.
loss_grad_fn = jax.jit(jax.value_and_grad(ce_loss))

In [13]:
# Training: full-batch gradient descent, every step uses the whole training set.
for step in range(1001):
    loss_val, grads = loss_grad_fn(params, X_train, y_train)
    updates, optim_state = optim.update(grads, optim_state)
    params = optax.apply_updates(params, updates)

    # Log every 100th step; the value is the loss measured before this
    # step's update was applied.
    if step % 100 != 0:
        continue
    print(f"Step {step} :: {loss_val}")

Step 0 :: 1.2861980199813843
Step 100 :: 0.44869717955589294
Step 200 :: 0.2791523039340973
Step 300 :: 0.23242047429084778
Step 400 :: 0.20332759618759155
Step 500 :: 0.18320882320404053
Step 600 :: 0.1684349775314331
Step 700 :: 0.15710097551345825
Step 800 :: 0.14811217784881592
Step 900 :: 0.14079517126083374
Step 1000 :: 0.1347130537033081


In [14]:
@jax.jit
def infer(params, xs):
    """Predict a class index for every input in ``xs``.

    Args:
        params: parameter tree passed to ``model.apply``.
        xs: batch of inputs; ``jax.vmap`` maps over its leading axis.

    Returns:
        For each input, the index of the largest model output.
    """
    # Evaluate the model per example, then take the arg-max across each
    # example's outputs (axis 1 of the batched result).
    class_probs = jax.vmap(lambda sample: model.apply(params, sample))(xs)
    return jnp.argmax(class_probs, axis=1)

In [15]:
# Predict a class index for every held-out test row with the trained params.
preds = infer(params, X_test)

In [16]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the held-out test split.
# classification_report returns plain text, so it is printed rather than
# left as the cell's last expression.
print(
    classification_report(
        y_true=y_test,
        y_pred=preds,
    )
)

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        11
           1       1.00      1.00      1.00        13
           2       1.00      1.00      1.00         6

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

