# Iris Dataset Classification

### $\textbf{Task: Build a logistic regression classifier for sepal petals using the Iris dataset}$

In [130]:
from typing import Tuple
import pandas as pd
import numpy as np

import jax
import jax.numpy as jnp
import flax

In [131]:
df = pd.read_csv("datasets/iris_csv.csv")
df.head(5)

Unnamed: 0,sepallength,sepalwidth,petallength,petalwidth,class
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa


In [132]:
df.describe()

Unnamed: 0,sepallength,sepalwidth,petallength,petalwidth
count,150.0,150.0,150.0,150.0
mean,5.843333,3.054,3.758667,1.198667
std,0.828066,0.433594,1.76442,0.763161
min,4.3,2.0,1.0,0.1
25%,5.1,2.8,1.6,0.3
50%,5.8,3.0,4.35,1.3
75%,6.4,3.3,5.1,1.8
max,7.9,4.4,6.9,2.5


In [133]:
df = df.dropna(subset=["petallength", "petalwidth", "sepallength", "petalwidth"])

In [134]:
df['class_index'], classes = pd.factorize(df['class'])
df = df.drop('class', axis=1)
class_map = {class_name: index for (index, class_name) in enumerate(classes)}
print(class_map)

{'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}


In [135]:
def train_test(df: pd.DataFrame, num_classes, proportion = 0.1) -> Tuple[pd.DataFrame, pd.DataFrame]:
    samples = []
    for class_index in range(num_classes):
        class_df = df[df["class_index"] == class_index]
        sample = class_df.sample(n = int(len(class_df) * proportion))
        samples.append(sample)
    test_df = pd.concat(samples)
    train_df = df[~df.index.isin(test_df.index)]
    test_df.reset_index(drop=True, inplace=True)
    train_df.reset_index(drop=True, inplace=True)
    return train_df, test_df

train_df, test_df = train_test(df, 3)

In [136]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size = 0.1, stratify = df["class_index"])

print(f"Training set size: {len(train_df)}")
print(f"Test set size: {len(test_df)}")

Training set size: 135
Test set size: 15


# Model architecture

We are going to first just train a linear regression model for logits and then pass that through a softmax function. Our model will take in the $(batch, 4)$ features and pass them through a (4, 3) weight matrix, the outputs will then be softmax'd and we will have our probabilities for each class.

In [146]:
import flax.linen as nn
from jax.nn.initializers import lecun_normal
from jax.nn import softmax, one_hot
import optax

rng_key = jax.random.PRNGKey(42)
rng_key, rng_w, rng_b = jax.random.split(rng_key, 3)
initializer = jax.nn.initializers.lecun_normal()

weights = initializer(rng_w, (4, 3), jnp.float32)
bias = initializer(rng_b, (1, 3), jnp.float32)
bias = jnp.reshape(bias, (3, ))
print(weights)
print(bias)

def fwd(weights_, bias_, batch_in):
    return softmax(batch_in @ weights_ + bias_)

def cross_entropy_loss(weights_, bias_, batch_inputs, batch_outputs):
    """
    Batch outputs are one hot encoded vector of true output class
    """
    model_output = fwd(weights_, bias_, batch_inputs)
    return -jnp.mean(jnp.sum(batch_outputs * jnp.log(model_output)))

optimizer = optax.adam(learning_rate = 0.00001)
optimizer_state = optimizer.init((weights, bias))

def process_batch(weights, bias, inputs, outputs, optimizer, optimizer_state):
    loss, grad = jax.value_and_grad(cross_entropy_loss, argnums=(0, 1))(weights, bias, batch_inputs, batch_outputs)
    updates, optimizer_state = optimizer.update(grad, optimizer_state)
    (weights, biases) = optax.apply_updates((weights, bias), updates)
    return loss, (weights, biases), optimizer_state

[[-0.18463236 -0.3223093  -1.022951  ]
 [ 0.83705264  0.6459639   0.62568444]
 [ 0.03572375 -0.37918803 -0.25625595]
 [-1.0434582   0.33855742 -0.15648586]]
[-0.7610422  1.3522058 -0.516474 ]


In [149]:
num_epochs = 300
batch_size = 28

for epoch in range(num_epochs):
    batch_losses = []
    for batch_no in range(len(train_df) // batch_size):
        batch_samples = df.sample(batch_size)
        batch_inputs  = jnp.array(batch_samples.drop("class_index", axis=1))
        batch_outputs = one_hot(batch_samples["class_index"].values, num_classes = 3)
        loss, (weights, biases), optimizer_state = process_batch(weights, bias, batch_inputs, batch_outputs, optimizer, optimizer_state)
        batch_losses.append(loss)
    if ((epoch + 1) % 10 == 0):
        print(np.array(batch_losses).mean())

-9.86907
-10.297173
-10.773974
-9.232189
-10.735574
-10.526245
-10.314912
-9.883543
-11.274291
-11.200541
-9.527762
-10.671847
-9.680386
-10.203789
-10.515488
-10.549166
-9.483685
-8.82962
-10.49518
-10.404421
-10.017877
-10.159428
-11.247931
-10.854046
-10.824501
-9.740387
-11.676072
-10.15776
-10.1600485
-10.416632
