In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
from sklearn.model_selection import train_test_split as tt
from keras.layers import Conv2D,Dense,MaxPooling2D,Flatten
import os
import warnings
import re

warnings.filterwarnings('ignore')

In [None]:
# Label encoding: 0 = circle, 1 = triangle, 2 = square (images are 28x28 pixels).
df = pd.read_csv("../MNIST Shapes/train_data.csv")

# Show the per-class row counts, then the (rows, columns) shape of the frame.
print(df['label'].value_counts(), df.shape, sep="\n")

In [None]:
# Target: the 'label' column. Features: every other column, as a NumPy array.
y = df['label']
x = np.asarray(df.drop(columns=['label']))

# Stratified 80/20 train/test split; the fixed random_state makes the
# partition repeatable across runs.
x_train, x_test, y_train, y_test = tt(
    x,
    y,
    train_size=0.8,
    stratify=y,
    random_state=0,
)

In [None]:
# Turn each flat pixel row into a 28x28 single-channel image and rescale the
# values by 1/255 as float32. The same transform is applied to both splits.
# NOTE: the cell rebinds x_train/x_test, so re-running it on already
# transformed data divides by 255 a second time.
x_train, x_test = (
    pixels.reshape(-1, 28, 28, 1).astype("float32") / 255.0
    for pixels in (x_train, x_test)
)

In [None]:
# CNN classifier for 28x28 single-channel images. Three convolution layers
# (two of them followed by max-pooling) feed a small dense head whose final
# layer emits one softmax probability per class (3 classes).
# All layers are taken from the same `keras` namespace as Sequential/Input/
# losses/optimizers so that a single Keras implementation is used throughout.
model = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        # Conv2D(number of filters, kernel size).
        # padding="valid" (the default) shrinks each spatial dimension by
        # kernel_size - 1 (28 -> 26 for a 3x3 kernel);
        # padding="same" keeps the spatial size unchanged (28 -> 28).
        keras.layers.Conv2D(256, 3, padding="same", activation="relu"),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(128, 3, activation="relu"),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(128, 3, activation="relu"),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ]
)

model.compile(
    # Sparse loss: targets are integer class ids, not one-hot vectors. The
    # default from_logits=False matches the softmax output layer above.
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # `learning_rate` is the supported keyword; the legacy `lr` alias is
    # deprecated and is rejected by newer Keras optimizer implementations.
    optimizer=keras.optimizers.Adam(learning_rate=0.003),
    metrics=["accuracy"],
)

In [None]:
# Options shared by training and evaluation: mini-batches of 64 samples and
# the verbose=2 logging mode.
_run_options = {"batch_size": 64, "verbose": 2}

# Train for 10 epochs on the training split, then measure loss/accuracy on
# the held-out split.
model.fit(x_train, y_train, epochs=10, **_run_options)
model.evaluate(x_test, y_test, **_run_options)

# Persist the trained model to an HDF5 file, then unbind the in-memory
# object; `model` is undefined after this cell until it is loaded again.
model.save("mnist_shapes_model.h5")
del model

In [None]:
# Reload the saved model from disk. The loader is taken from the same `keras`
# namespace (imported from tensorflow) that built and saved the model, rather
# than from the standalone `keras` package, so one implementation handles
# both the save and the load.
model = keras.models.load_model("mnist_shapes_model.h5")

In [None]:
ip = np.asarray([255,255,255,255,255,255,255,255,255,254,255,254,255,255,255,255,255,255,255,253,255,253,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,254,253,255,255,255,255,254,255,253,255,255,253,255,254,254,255,255,255,255,255,255,255,255,255,255,255,255,254,255,255,255,253,255,255,255,255,254,255,253,255,253,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,254,254,254,255,254,255,255,254,252,255,254,252,255,251,255,255,255,255,255,255,255,255,255,255,255,255,255,255,254,255,255,255,253,255,254,255,255,255,255,255,255,255,254,255,255,255,255,255,255,255,255,255,255,255,255,255,255,254,255,252,255,1,1,255,255,252,253,255,254,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,250,255,2,0,0,1,253,255,255,255,254,254,254,255,255,255,255,255,255,255,255,255,255,255,255,255,254,254,255,1,0,255,255,0,2,255,254,255,255,255,254,255,255,255,255,254,255,255,254,255,253,255,255,255,255,255,0,0,255,254,255,254,2,252,255,253,255,255,255,255,255,255,255,255,255,254,255,253,255,254,255,255,253,0,1,255,255,255,253,255,253,2,252,255,254,254,254,255,255,255,255,253,255,254,254,255,251,255,253,254,1,255,253,255,254,255,255,255,255,0,4,253,255,255,255,255,255,255,255,255,252,253,255,252,255,254,1,2,253,254,255,252,255,254,255,254,255,255,0,255,255,253,255,255,255,255,255,254,254,255,253,255,255,254,0,255,253,255,253,255,254,253,255,254,254,255,2,0,254,255,255,255,255,255,255,255,255,255,255,251,254,0,1,254,255,255,252,255,255,255,255,255,254,255,255,253,0,0,255,255,255,255,255,254,255,255,255,255,1,2,254,255,252,255,255,253,255,253,255,255,254,255,255,255,255,1,0,255,255,255,255,255,255,255,254,0,0,1,0,0,2,0,0,2,0,2,0,0,0,1,0,1,0,0,1,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,2
55,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255])

# Shape the flat pixel values into a batch of 28x28 single-channel images and
# rescale by 1/255 as float32, mirroring the preprocessing of the training data.
ip = ip.reshape((-1, 28, 28, 1)).astype("float32") / 255.0

# One row of class probabilities per input image.
pred = model.predict(ip)
print(pred)

# Predicted class id = index of the largest probability in each row.
op = pred.argmax(axis=1)
print(op)