## Iris dataset from sklearn

In [1]:
import pandas as pd

In [2]:
from sklearn.datasets import load_iris

# Load the bundled Iris data and wrap the feature matrix in a DataFrame.
# No column names are passed, so the feature columns are labelled 0..3;
# the class label is appended as a "target" column.
iris = load_iris()
df = pd.DataFrame(iris.data).assign(target=iris.target)

In [3]:
# Preview the first rows: feature columns 0-3 plus the "target" label column
df.head()

Unnamed: 0,0,1,2,3,target
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0


In [4]:
# Distinct class labels present in the target vector (a set, so duplicates collapse)
target_ids = set(iris.target)
target_ids

{0, 1, 2}

In [5]:
# Split the frame into the feature matrix (columns 0-3) and the label vector,
# converting each one straight to a NumPy array.
X = df[[0, 1, 2, 3]].to_numpy()
y = df["target"].to_numpy()

In [6]:
# Hold out 20% of the samples as a test set (fixed random_state so the split
# is the same on every run). No separate validation set is carved out.
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
)

## ANN for classification

In [7]:
import numpy as np
import jax as J

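This is going to be a single-layer ANN: one affine (linear) layer followed by a softmax, with no hidden layers — in other words, multinomial logistic regression.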

In [8]:
def init_params(n_features=4, n_classes=3, seed=None):
    """Draw initial parameters for the single-layer classifier.

    Args:
        n_features: Number of input features (rows of ``W``). Defaults to 4.
        n_classes: Number of output classes (columns of ``W`` and ``b``).
            Defaults to 3.
        seed: Optional seed for the random generator. Pass an integer to get
            the same parameters on every call; ``None`` (the default) draws
            fresh, non-reproducible values.

    Returns:
        A list ``[W, b]`` of standard-normal arrays, where ``W`` has shape
        ``(n_features, n_classes)`` and ``b`` has shape ``(1, n_classes)``.
    """
    # A local generator keeps the draw independent of any global NumPy state
    # and makes seeding explicit.
    rng = np.random.default_rng(seed)

    W = rng.normal(size=(n_features, n_classes))
    b = rng.normal(size=(1, n_classes))

    return [W, b]

In [9]:
# Draw the initial weights and bias; they are random, so each fresh run starts from different values
params = init_params()

In [10]:
# Sanity-check the shape of every parameter array.
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
# alias is deprecated and no longer exists in current JAX releases.
J.tree_util.tree_map(lambda p: p.shape, params)

[(4, 3), (1, 3)]

In [11]:
def forward(params, x):
    """Run the single-layer network: an affine map followed by a softmax.

    Args:
        params: Pair ``[W, b]`` holding the weight matrix and the bias.
        x: Input features; multiplied on the left of ``W``.

    Returns:
        Class probabilities (the softmax of ``x @ W + b``). Note these are
        probabilities, not raw logits.
    """
    weights, bias = params

    # affine transform: raw, unnormalised class scores
    scores = x @ weights + bias

    # softmax turns the scores into probabilities (normalised over the last axis)
    return J.nn.softmax(scores)

# Smoke test: class probabilities for the first training sample with the untrained parameters
forward(params, X_train[0])

DeviceArray([[0.01013626, 0.03795071, 0.95191306]], dtype=float32)

In [12]:
# cross entropy loss
def ce_loss(params, x, y):
    """Cross-entropy loss of a single sample under the single-layer model.

    Args:
        params: Pair ``[W, b]`` holding the weight matrix and the bias.
        x: Feature vector of one sample.
        y: Integer index of the true class of that sample.

    Returns:
        Scalar ``-log p(y | x)``, the negative log-probability the model
        assigns to the true class.
    """
    W, b = params

    # Compute log-probabilities directly with log_softmax instead of taking
    # log(softmax(.)): when one class score dominates, the other softmax
    # entries underflow to 0 and log(0) yields an infinite loss and NaN
    # gradients. log_softmax stays finite in that case.
    log_probs = J.nn.log_softmax(x @ W + b)

    # The bias is kept 2-D, so the scores form a single row; pick the entry
    # of the true class from that row.
    return -log_probs[0][y]

In [13]:
# gradient descent
@J.jit
def update(params, x, y, lr=0.1):
    """Take one gradient-descent step on a single sample.

    Args:
        params: Current parameters (the structure accepted by ``ce_loss``).
        x: Feature vector of one sample.
        y: Integer index of the true class of that sample.
        lr: Learning rate (step size). Defaults to 0.1.

    Returns:
        New parameters with the same structure as ``params``, each array
        moved by ``-lr`` times its gradient.
    """
    # gradient of the loss with respect to the first argument, i.e. the parameters
    grads = J.grad(ce_loss)(params, x, y)

    # `jax.tree_util.tree_map` rather than the top-level `jax.tree_map` alias,
    # which is deprecated and no longer exists in current JAX releases.
    return J.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

In [14]:
from tqdm import tqdm as T

# number of full passes over the training set
N_EPOCHS = 100

# Stochastic gradient descent: one parameter update per training sample.
# The loop variables are deliberately NOT named `x` / `y`: a `for` loop leaks
# its variables into the notebook namespace, and `y` would overwrite the full
# label array defined earlier, breaking any re-run of the cells that use it.
for epoch in T(range(N_EPOCHS)):
    for x_i, y_i in zip(X_train, y_train):
        params = update(params, x_i, y_i)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 99.58it/s]


In [15]:
@J.jit
def infer(params, x):
    """Predict the class index for input ``x``.

    Args:
        params: Model parameters, passed through to ``forward``.
        x: Input features, passed through to ``forward``.

    Returns:
        Array with the index of the highest-probability class along axis 1
        of the ``forward`` output.
    """
    probs = forward(params, x)
    # the predicted class is the column holding the largest probability
    return J.numpy.argmax(probs, axis=1)

In [16]:
# Predict one test sample at a time. `infer` returns a length-1 array per
# sample, so take its single element to build a flat list of integer class ids.
# A comprehension also avoids leaking a loop variable into the notebook namespace.
preds = [int(infer(params, x)[0]) for x in X_test]

## Evaluation

In [17]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the held-out test set.
# The report comes back as a preformatted string, so print() it to keep the table layout.
print(classification_report(y_true=y_test, y_pred=preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        11
           1       0.93      1.00      0.96        13
           2       1.00      0.83      0.91         6

    accuracy                           0.97        30
   macro avg       0.98      0.94      0.96        30
weighted avg       0.97      0.97      0.97        30

