In [None]:
import numpy as np
import random
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten


In [None]:
data_dir = "./Dataset"  # root folder containing one subfolder per class
TEST_FRACTION = 0.2     # fraction of images held out for evaluation
SEED = 1337             # fixes the file-level train/test partition

# Split at the file level, before batching, with a fixed seed.
# The previous approach (take()/skip() on one shuffled dataset) was unsafe:
# the loader shuffles by default and re-shuffles on every iteration, so
# take() and skip() each saw a different ordering and images near the split
# point could appear in both train and test (leakage) or in neither.
loader_kwargs = dict(
    image_size=(256, 256),
    batch_size=32,
    label_mode="categorical",  # labels are one-hot vectors
    validation_split=TEST_FRACTION,
    seed=SEED,  # must be identical for both calls so the subsets are disjoint
)
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="training", **loader_kwargs
)
test_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="validation", **loader_kwargs
)

# Class names are inferred from the subfolder names, in sorted order.
class_names = train_ds.class_names

In [None]:
# Let tf.data load upcoming batches in the background while the
# current ones are being consumed (the splits are iterated just below).
train_ds, test_ds = (
    split.prefetch(buffer_size=tf.data.AUTOTUNE) for split in (train_ds, test_ds)
)

def dataset_to_numpy(dataset):
    """Materialize a batched ``(images, labels)`` dataset into two NumPy arrays.

    Args:
        dataset: Iterable of ``(batch_images, batch_labels)`` pairs whose
            elements expose ``.numpy()`` (e.g. a batched ``tf.data.Dataset``).

    Returns:
        Tuple ``(images, labels)``: the per-batch arrays concatenated along axis 0.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    image_batches, label_batches = [], []
    for batch_images, batch_labels in dataset:
        image_batches.append(batch_images.numpy())
        label_batches.append(batch_labels.numpy())

    if not image_batches:
        # np.concatenate([]) fails with an opaque message; an empty split
        # usually means the split fraction left no batches for this subset.
        raise ValueError("dataset_to_numpy: dataset is empty (no batches to convert)")

    return np.concatenate(image_batches), np.concatenate(label_batches)


In [None]:
# Materialize both splits as NumPy arrays; model.fit/evaluate below consume arrays.
x_train, y_train = dataset_to_numpy(train_ds)
x_test, y_test = dataset_to_numpy(test_ds)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)

# Rescale pixel values from [0, 255] to [0, 1].
# Divide in place: `x = x / 255.0` would briefly hold a second full-size copy
# of each image array. In-place division relies on the arrays being floating
# point, which the Keras directory loader guarantees (it yields float32 images).
x_train /= 255.0
x_test /= 255.0

In [None]:
# Model architecture: two conv/pool stages extract features, then a dense head classifies.
model = Sequential([
    # Take the input shape from the data so it stays in sync with the loader's image_size.
    Conv2D(32, (3, 3), activation='relu', input_shape=x_train.shape[1:]),
    MaxPooling2D((2, 2)),

    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),

    Flatten(),
    Dense(64, activation='relu'),
    # One output unit per class folder. A hard-coded count breaks with a shape
    # mismatch against the one-hot labels whenever the number of classes differs.
    Dense(len(class_names), activation='softmax'),
])


In [None]:
# Adam with a fixed learning rate; categorical cross-entropy pairs with the
# one-hot labels (label_mode="categorical") and the softmax output layer.
optim = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=optim,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
model.fit(x=x_train, y=y_train, batch_size=64, epochs=5)  # train on the in-memory arrays

In [None]:
model.evaluate(x=x_test, y=y_test)  # returns [loss, accuracy] on the held-out split

In [None]:
# Spot-check: pick one random test image, show it, and compare prediction vs. label.
# randrange excludes the upper bound. The old randint(0, len(y_test)) could
# return len(y_test) and raise IndexError.
idx = random.randrange(len(y_test))
true_class_name = class_names[np.argmax(y_test[idx])]  # y_test rows are one-hot

# Slicing keeps the batch axis, giving shape (1, H, W, C) without hard-coding the image size.
y_pred = model.predict(x_test[idx:idx + 1])

predicted_class_index = np.argmax(y_pred)
predicted_class_name = class_names[predicted_class_index]

fig, ax = plt.subplots()
ax.imshow(x_test[idx])
ax.set_title(f"Predicted: {predicted_class_name} | True: {true_class_name}")
ax.axis('off')
plt.show()

print(f"Predicted class: {predicted_class_name}")