In [15]:
# Simple Multi-Class Classification using a Neural Network (Python)

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [34]:
# Load dataset
X, y = load_iris(return_X_y=True)

In [37]:
X[:5]

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])

In [41]:
y

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [17]:
# One-hot encode labels (3 classes)
y = to_categorical(y)

In [18]:
# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

In [19]:
# Feature scaling
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

In [20]:
# Build neural network
model = Sequential([
    Dense(16, activation='relu', input_shape=(X.shape[1],)),
    Dense(16, activation='relu'),
    Dense(3, activation='softmax')   # number of classes
])

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [21]:
# Compile model
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

In [22]:
# Train model
model.fit(X_train, y_train, epochs=50, batch_size=8, verbose=1)

Epoch 1/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - accuracy: 0.1083 - loss: 1.2028      
Epoch 2/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.3418 - loss: 1.0629 
Epoch 3/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.5674 - loss: 0.9713 
Epoch 4/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.6810 - loss: 0.8928 
Epoch 5/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.7239 - loss: 0.8338 
Epoch 6/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.7410 - loss: 0.7948 
Epoch 7/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.7843 - loss: 0.6861 
Epoch 8/50
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.7981 - loss: 0.6174 
Epoch 9/50
[1m15/15[0m [32m━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x20c69eecec0>

In [23]:
# Evaluate model
loss, accuracy = model.evaluate(X_test, y_test)
print("Test Accuracy:", accuracy)
print("Categorical Cross entropy : ",loss)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 161ms/step - accuracy: 0.9667 - loss: 0.0969
Test Accuracy: 0.9666666388511658
Categorical Cross entropy :  0.09685187041759491


In [24]:
print(X_test[:1])

[[ 0.35451684 -0.58505976  0.55777524  0.02224751]]


In [25]:
print(y_test[:1])

[[0. 1. 0.]]


In [26]:
# Predict a sample
pred = model.predict(X_test[:1])
print("Predicted class:", pred.argmax())

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step
Predicted class: 1


In [27]:
pred = model.predict(X_test)
print("Predicted class:", pred)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step
Predicted class: [[6.6526788e-03 9.2716050e-01 6.6186816e-02]
 [9.9134666e-01 8.6472221e-03 6.2003846e-06]
 [3.1445634e-07 2.1555249e-03 9.9784410e-01]
 [9.6852854e-03 6.7015725e-01 3.2015747e-01]
 [2.1998575e-03 8.5987598e-01 1.3792410e-01]
 [9.9015170e-01 9.8255351e-03 2.2729571e-05]
 [6.7861065e-02 9.1201568e-01 2.0123260e-02]
 [5.7304776e-05 1.6830638e-02 9.8311204e-01]
 [7.9072872e-04 6.2888485e-01 3.7032437e-01]
 [1.5603975e-02 9.4455171e-01 3.9844327e-02]
 [7.4919610e-04 1.2406178e-01 8.7518901e-01]
 [9.9600405e-01 3.9583310e-03 3.7633683e-05]
 [9.9402171e-01 5.9664794e-03 1.1852596e-05]
 [9.9619675e-01 3.7707507e-03 3.2528635e-05]
 [9.9919325e-01 8.0631761e-04 4.3393345e-07]
 [1.0074133e-02 8.1554818e-01 1.7437766e-01]
 [2.9058470e-05 5.8276565e-03 9.9414331e-01]
 [1.0808641e-02 9.6251470e-01 2.6676690e-02]
 [1.2592027e-02 8.6765712e-01 1.1975081e-01]
 [2.1605680e-05 4.4759172e-03 9.9550253e-01]
 [9.9797207