In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from tensorflow.keras.datasets import mnist
import pandas as pd
import os

In [2]:
# Hyperparameters
BATCH_SIZE = 64        # samples per batch in both the training and test pipelines
WEIGHT_DECAY = 0.001   # L2 coefficient passed to the conv layers' kernel_regularizer
LEARNING_RATE = 0.001  # step size handed to the Adam optimizer

# Ask TensorFlow to grow GPU memory on demand.
# Looping over the device list (instead of indexing [0]) means this cell no
# longer raises IndexError on a CPU-only machine, and it applies the same
# setting to every GPU when more than one is present.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
# Locations of the CSV files that list each image file name with its two digit labels.
train_csv_path = 'multi_digits_mnist/train.csv'
test_csv_path = 'multi_digits_mnist/test.csv'

train_df = pd.read_csv(train_csv_path)
test_df = pd.read_csv(test_csv_path)

In [4]:
# Preview the first five rows of the training table.
train_df.head(n=5)

Unnamed: 0,Image,first_num,second_num
0,0_00.png,0,0
1,100_00.png,0,0
2,101_00.png,0,0
3,102_00.png,0,0
4,103_00.png,0,0


In [5]:
# Column 0 (all rows) holds the image file names; show it as a NumPy array.
image_column = train_df.iloc[:, 0]
image_column.values

array(['0_00.png', '100_00.png', '101_00.png', ..., '999_98.png',
       '99_98.png', '9_98.png'], dtype=object)

In [6]:
# File names come from the first column of each table.
train_filenames = train_df.iloc[:, 0].values
test_filenames = test_df.iloc[:, 0].values

# Prefix every file name with its image directory (element-wise string concatenation).
train_images = 'multi_digits_mnist/train_images/' + train_filenames
test_images = 'multi_digits_mnist/test_images/' + test_filenames

In [7]:
# Display the training image paths (directory prefix + file name from the table's first column).
train_images

array(['multi_digits_mnist/train_images/0_00.png',
       'multi_digits_mnist/train_images/100_00.png',
       'multi_digits_mnist/train_images/101_00.png', ...,
       'multi_digits_mnist/train_images/999_98.png',
       'multi_digits_mnist/train_images/99_98.png',
       'multi_digits_mnist/train_images/9_98.png'], dtype=object)

In [8]:
# Every column after the first one is a label column.
label_columns = slice(1, None)

train_labels = train_df.iloc[:, label_columns].values
test_labels = test_df.iloc[:, label_columns].values

In [9]:
# Display the training label array: one row per image, taken from every table column after the first.
train_labels

array([[0, 0],
       [0, 0],
       [0, 0],
       ...,
       [9, 8],
       [9, 8],
       [9, 8]], dtype=int64)

In [10]:
def read_image(image_path, label):
    """Load one image from disk and pair it with its two digit labels.

    Args:
        image_path: Path of the image file to read.
        label: Indexable value whose entries 0 and 1 are the labels of the
            first and second digit.

    Returns:
        A tuple ``(image, labels)``: ``image`` is the file decoded as a
        single-channel ``tf.float32`` tensor, and ``labels`` is a dict with the
        keys ``'first_num'`` and ``'second_num'``.
    """
    raw_bytes = tf.io.read_file(image_path)
    pixels = tf.image.decode_image(raw_bytes, channels=1, dtype=tf.float32)

    targets = {'first_num': label[0], 'second_num': label[1]}
    return pixels, targets

In [11]:
# Sentinel handed to tf.data so it can tune buffer sizes on its own.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# One dataset element per (image path, label row) pair.
train_pairs = (train_images, train_labels)
train_dataset = tf.data.Dataset.from_tensor_slices(train_pairs)
train_dataset

<TensorSliceDataset shapes: ((), (2,)), types: (tf.string, tf.int64)>

In [12]:
# Build the training pipeline from the source arrays rather than from the
# previous value of `train_dataset`, so re-running this cell never stacks a
# second shuffle/map/batch on top of an already-batched dataset.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    # Shuffle while elements are still (path, label) pairs, i.e. before any
    # image is decoded; the buffer spans the whole training set.
    .shuffle(buffer_size=len(train_labels))
    # Decode images in parallel; element order is still deterministic by default.
    .map(read_image, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=BATCH_SIZE)
    # Prepare upcoming batches while the current one is being consumed.
    .prefetch(buffer_size=AUTOTUNE)
)

In [13]:
# Evaluation pipeline: same decoding and batching as training, but no shuffle.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    # Decode images in parallel; element order is still deterministic by default.
    .map(read_image, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=BATCH_SIZE)
    # Prepare upcoming batches while the current one is being consumed.
    .prefetch(buffer_size=AUTOTUNE)
)

In [14]:
# Two-headed CNN: a shared convolutional trunk followed by one 10-way softmax
# per digit. The head names ('first_num', 'second_num') match the keys of the
# label dict returned by `read_image`.
inputs = keras.Input(shape=(64, 64, 1))

# --- Stage 1: (Conv -> BatchNorm -> ReLU) x 2, then max-pooling -------------
# ReLU is applied through `layers.ReLU()` instead of calling
# `keras.activations.relu` on the symbolic tensor, so each step is an explicit
# Keras layer in the functional graph. The default ReLU layer computes the same
# max(x, 0) as the default activation function.
x = layers.Conv2D(
    filters=32,
    kernel_size=3,
    padding='same',
    kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
)(inputs)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)

x = layers.Conv2D(
    filters=64,
    kernel_size=3,
    padding='same',
    kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)

x = layers.MaxPooling2D()(x)

# --- Stage 2: two ReLU convolutions (no batch norm), then max-pooling -------
x = layers.Conv2D(
    filters=64,
    kernel_size=3,
    padding='same',
    activation='relu',
    kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
)(x)

# NOTE(review): unlike the three convolutions above, this one has no
# kernel_regularizer. Left unchanged; confirm that this is intentional.
x = layers.Conv2D(
    filters=128,
    kernel_size=3,
    padding='same',
    activation='relu',
)(x)

x = layers.MaxPooling2D()(x)

# --- Classifier trunk shared by both heads ----------------------------------
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(64, activation='relu')(x)

# --- One softmax head per digit ---------------------------------------------
output1 = layers.Dense(10, activation='softmax', name='first_num')(x)
output2 = layers.Dense(10, activation='softmax', name='second_num')(x)

model = keras.Model(inputs=inputs, outputs=[output1, output2])

In [15]:
# Adam optimizer with the learning rate from the hyperparameter cell.
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# One sparse categorical cross-entropy per output head, listed in the same
# order as the model outputs. (A single loss object would also work here,
# since both heads use the same loss.)
first_num_loss = keras.losses.SparseCategoricalCrossentropy()
second_num_loss = keras.losses.SparseCategoricalCrossentropy()

model.compile(
    optimizer=optimizer,
    loss=[first_num_loss, second_num_loss],
    metrics=['accuracy'],
)

In [16]:
# Train for five passes over the training pipeline; keep the returned History
# object so per-epoch metrics can be inspected later.
history = model.fit(x=train_dataset, epochs=5)

# Last expression of the cell, so the notebook displays the returned list of
# losses and metrics measured on the test pipeline.
model.evaluate(x=test_dataset, verbose=1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[1.1679807901382446,
 0.5625738501548767,
 0.5585225820541382,
 0.8358500003814697,
 0.8399500250816345]