<a href="https://colab.research.google.com/github/HagerDakroury/handwritten-digits-classification/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

import matplotlib.pyplot as plt
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
import numpy as np





#mounting my google drive where the dataset is saved
#can be done like that or directly from google colab UI

# from google.colab import drive
# drive.mount('/content/drive')

#extracting the dataset
!tar xvzf /content/drive/MyDrive/trainingSet.tar.gz




In [22]:
def Load_data(directory, image_size=(28, 28), batch_size=32,
              validation_split=0.2, seed=387, color_mode='rgb'):
  """Build the training and validation datasets from an image directory.

  Labels are inferred from the sub-directory names (one folder per class) and
  returned as integer class indices. Both subsets are read from the same
  directory with the same ``validation_split`` and fixed ``seed``, so the
  80/20 (by default) split is reproducible and the two subsets do not overlap.

  Args:
    directory: root folder containing one sub-folder per class.
    image_size: (height, width) every image is resized to. The CNN built in
      ``train`` expects (28, 28).
    batch_size: number of images per batch.
    validation_split: fraction of images reserved for validation.
    seed: shuffling/split seed; must be identical for both subsets.
    color_mode: Keras color mode. The CNN expects 3 channels, i.e. 'rgb'.

  Returns:
    (t_dataset, v_dataset): batched datasets of (images, int labels). Pixel
    values are not rescaled here; ``train`` applies the 1/255 rescaling.
  """
  # Shared loader settings; only `subset` differs between the two calls.
  common = dict(
      labels='inferred', label_mode='int', color_mode=color_mode,
      batch_size=batch_size, image_size=image_size,
      validation_split=validation_split, seed=seed,
  )

  t_dataset = tf.keras.preprocessing.image_dataset_from_directory(
      directory, subset="training", **common)
  v_dataset = tf.keras.preprocessing.image_dataset_from_directory(
      directory, subset="validation", **common)

  return t_dataset, v_dataset


In [13]:
def _build_model(input_shape=(28, 28, 3), num_classes=10):
  """Build and compile the small CNN used for digit classification.

  One 3x3 conv + 2x2 max-pool feature extractor, a 100-unit dense layer, and a
  linear output layer that emits one logit per class.
  """
  model = Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform',
                  input_shape=input_shape),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(100, activation='relu', kernel_initializer='he_uniform'),
    # No softmax here: the loss below is configured with from_logits=True.
    layers.Dense(num_classes)
  ])
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
  return model


def _plot_sample_predictions(model, dataset):
  """Plot up to 9 images from one batch of `dataset` with true and predicted labels."""
  class_names = dataset.class_names
  plt.figure(figsize=(9, 9))
  for images, labels in dataset.take(1):
    # One batched forward pass with the in-memory model, using the same 1/255
    # rescaling as training, instead of reloading the saved .h5 per image.
    predicted = np.argmax(model.predict(images / 255.0), axis=-1)
    for i in range(min(9, images.shape[0])):
      plt.subplot(3, 3, i + 1)
      plt.imshow(images[i].numpy().astype("uint8"))
      # predicted[i:i + 1] keeps the "[k]" formatting of a 1-element array.
      plt.title("true:" + "[" + class_names[int(labels[i])] + "]  "
                + "predicted:" + np.array_str(predicted[i:i + 1]))
      plt.axis("off")
  plt.show()


def train(t_dataset, v_dataset, epochs=10,
          save_path='/content/My Drive/cse440_project/model.h5'):
  """Train the digit CNN, report validation accuracy, save it, and plot samples.

  Args:
    t_dataset: training dataset as produced by ``Load_data``: batched
      (images, int labels) with unscaled pixel values and a ``class_names``
      attribute.
    v_dataset: validation dataset in the same format.
    epochs: number of training epochs.
    save_path: where ``model.save`` writes the trained model. The default is
      also the path ``Predict`` loads from by default.
      NOTE(review): the dataset cell reads from /content/drive/MyDrive (the
      usual Colab Drive mount), but this default is /content/My Drive, which is
      not under that mount. Confirm the intended location.

  Returns:
    (model, history): the trained Keras model and the ``fit`` History.
  """
  AUTOTUNE = tf.data.AUTOTUNE
  normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

  def _rescale(x, y):
    return normalization_layer(x), y

  # Rescale before cache() so the division runs once rather than every epoch.
  # Caching keeps decoded images in memory to avoid repeated disk I/O, and
  # prefetch overlaps input preparation with training.
  t_ds = (t_dataset
          .map(_rescale, num_parallel_calls=AUTOTUNE)
          .cache()
          .shuffle(1000)
          .prefetch(buffer_size=AUTOTUNE))
  v_ds = (v_dataset
          .map(_rescale, num_parallel_calls=AUTOTUNE)
          .cache()
          .prefetch(buffer_size=AUTOTUNE))

  model = _build_model()
  history = model.fit(t_ds, validation_data=v_ds, epochs=epochs)

  # Final-epoch validation accuracy. [-1] rather than a fixed index, so any
  # `epochs` value works.
  val_acc = history.history['val_accuracy']
  print('> %.3f' % (val_acc[-1] * 100.0))

  model.save(save_path)

  # Use the unmapped dataset: it still carries `class_names` and raw pixels.
  _plot_sample_predictions(model, t_dataset)
  return model, history



In [23]:
from keras.models import load_model

def Predict(isPath, image_path=None, image=None,
            model_path='/content/My Drive/cse440_project/model.h5', model=None):
  """Classify a single 28x28 RGB image with the trained CNN.

  The image can be supplied in one of two ways:
    * isPath=True: the image is read from ``image_path`` and resized to 28x28.
    * isPath=False: ``image`` is used directly. It must hold exactly
      28*28*3 pixel values in the 0-255 range (PIL image, array or tensor).

  Args:
    isPath: whether to load the image from ``image_path``.
    image_path: path to an image file; used when ``isPath`` is True.
    image: in-memory image; used when ``isPath`` is False.
    model_path: location of the saved Keras model; only read when ``model``
      is not given.
    model: an already-loaded model. Pass it when predicting many images, so
      the model file is not re-read from disk on every call.

  Returns:
    A numpy array of shape (1,) holding the predicted class index.

  Raises:
    ValueError: if the input selected by ``isPath`` was not provided.
  """
  if isPath:
    if image_path is None:
      raise ValueError("image_path is required when isPath is True")
    img = load_img(image_path, target_size=(28, 28))
  else:
    if image is None:
      raise ValueError("image is required when isPath is False")
    img = image

  # Same preprocessing as training: a batch of one image, float32, scaled to [0, 1].
  img = np.asarray(img, dtype='float32').reshape(1, 28, 28, 3) / 255.0

  if model is None:
    model = load_model(model_path)

  # The network outputs logits; argmax over the class axis gives the label.
  return np.argmax(model.predict(img), axis=-1)

  




In [None]:
# Sanity check: classify one image from the "7" folder of the extracted
# training set using the model saved by `train`.
sample_path = "/content/trainingSet/7/img_10028.jpg"
sample_prediction = Predict(isPath=True, image_path=sample_path)
print(sample_prediction)

# To (re)train, uncomment the two lines below. Training saves the model to
# the same default path `Predict` loads from, overwriting any earlier model.
# t_dataset,v_dataset=Load_data("/content/trainingSet")
# train(t_dataset,v_dataset)