<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [8]:
import tensorflow
import os
from tensorflow.keras.optimizers import Adam
import sklearn
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from tensorflow.keras.datasets import mnist
from models import sudokunet
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

In [9]:
# Notebook configuration (stand-in for CLI arguments): path the trained
# digit-classifier model is serialized to by the training cell.
args=dict(model="models/digit_classifier1.h5")

In [None]:
mod_path=args["model"]

# Load MNIST as (images, integer labels) for the train and test splits.
((trainX,trainY),(testX,testY))=mnist.load_data()

print(trainX.shape[0])

# Reshape data from (h,w) to (h,w,depth); depth is 1 because the images are grayscale.
trainX=trainX.reshape(trainX.shape[0],28,28,1)
testX=testX.reshape(testX.shape[0],28,28,1)

# Scale pixel values into [0, 1] as float32.
trainX=trainX.astype("float32")/255.0
testX=testX.astype("float32")/255.0

# One-hot encode the labels. The binarizer is fitted on the TRAINING labels
# only and that same mapping is reused for the test labels, so both splits
# share one column order (re-fitting on test data could silently produce a
# different encoding, e.g. if a class were missing from the test split).
le=LabelBinarizer()
trainY=le.fit_transform(trainY)
testY=le.transform(testY)

print("Building model....")
# Random on-the-fly augmentation applied to training batches only.
aug=ImageDataGenerator(rotation_range=30,width_shift_range=0.2,height_shift_range=0.2,zoom_range=0.3,fill_mode="nearest",shear_range=0.25)

# 28x28x1 inputs, 10 output classes.
model=sudokunet.SudokuNet.build(28,28,1,10)

EPOCHS=50
BATCH_SIZE=128

print("Compiling model...")
model.compile(optimizer=Adam(),
              metrics=["accuracy"],
              loss="categorical_crossentropy")

# Save the model whenever validation accuracy improves.
# NOTE(review): the validation data passed to fit() is the MNIST test split,
# so this checkpoint is selected on test data and the evaluation below is
# optimistically biased; a held-out validation split would avoid that.
save_mod=tensorflow.keras.callbacks.ModelCheckpoint("models/digit_classifier1_bestval.h5",save_best_only=True,monitor="val_accuracy")

print("Fitting model...")

history=model.fit(aug.flow(trainX,trainY,batch_size=BATCH_SIZE),
                  validation_data=(testX,testY),
                  epochs=EPOCHS,
                  verbose=1,
                  callbacks=[save_mod]
                  )
predictions=model.predict(testX)

print("Evaluating model")
# argmax recovers the class index from one-hot / softmax rows; name each row of
# the report after the binarizer's classes so the labels match the encoding.
print(classification_report(testY.argmax(axis=1),
                            predictions.argmax(axis=1),
                            target_names=[str(c) for c in le.classes_]
                            ))
print("Serializing model...")
model.save(mod_path)

# Use the number of epochs actually recorded in the history (not EPOCHS) so
# the x-axis stays correct if training ever stops early.
epochs_run=range(1,len(history.history["loss"])+1)

# Draw each metric on its own figure so plot_2.png does not also contain the
# loss curves drawn for plot_1.png.
for metric,out_path in (("loss","plot_1.png"),("accuracy","plot_2.png")):
    fig,ax=plt.subplots()
    ax.plot(epochs_run,history.history[metric],label=f"train {metric}")
    ax.plot(epochs_run,history.history[f"val_{metric}"],label=f"val {metric}")
    ax.set(title=f"Training vs. validation {metric}",xlabel="epoch",ylabel=metric)
    ax.legend()
    fig.savefig(out_path)






