# MNIST training

## Data loading

In [11]:
import tensorflow as tf
import tensorflow_datasets as tfds

# When True, the pipelines below also add colour-inverted copies of every image.
INVERT = True

# Load MNIST as (image, label) pairs together with its dataset metadata.
(ds_train, ds_test), ds_info = tfds.load(
    name="mnist",
    split=["train", "test"],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)

ds_train

<_PrefetchDataset element_spec=(TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))>

## Training dataset setup

In [12]:
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`.

    Pixel values are cast to float32 and divided by 255, mapping [0, 255]
    to [0.0, 1.0]. The label is returned unchanged.
    """
    return tf.cast(image, tf.float32) / 255., label


def invert(image, label):
    """Colour-inverts an image already scaled to [0, 1] (`x -> 1 - x`).

    Meant to be applied after `normalize_img`, turning white-on-black digits
    into black-on-white ones. Defined unconditionally (not inside the
    `if INVERT:` branch) because the test-set cell maps it as well.
    """
    return 1.0 - tf.cast(image, tf.float32), label


ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)

if INVERT:
    # Append an inverted copy of every training image so the model also sees
    # black-on-white digits.
    inverted = ds_train.map(invert, num_parallel_calls=tf.data.AUTOTUNE)
    ds_train = ds_train.concatenate(inverted)

ds_train = ds_train.cache()
# Buffer covers the whole (possibly doubled) training set -> a full shuffle.
ds_train = ds_train.shuffle(
    ds_info.splits["train"].num_examples * (2 if INVERT else 1))
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

## Testing dataset setup

In [13]:
# Same preprocessing as the training set: scale to [0, 1] and, if enabled,
# append colour-inverted copies so evaluation covers both polarities.
ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)

if INVERT:
    inverted = ds_test.map(invert, num_parallel_calls=tf.data.AUTOTUNE)
    ds_test = ds_test.concatenate(inverted)

# No shuffle here, so every epoch yields identical batches; batching before
# caching stores the ready-made batches.
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

## Model creation

In [14]:
# Small CNN: two conv + max-pool stages, then a dense head with dropout.
# Each pooling stage gets its OWN MaxPool2D instance. Reusing a single layer
# object twice in a Sequential model shares it between positions, which
# can trip layer-uniqueness checks in some Keras versions and makes
# model.summary() / layer names ambiguous.
model = tf.keras.models.Sequential([
    # Matches ds_train's element_spec: (28, 28, 1) images.
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(
        filters=16,
        kernel_size=5,
        padding="same",
        activation="relu"),
    tf.keras.layers.MaxPool2D((2, 2), (2, 2), padding="same"),
    tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=5,
        padding="same",
        activation="relu"),
    tf.keras.layers.MaxPool2D((2, 2), (2, 2), padding="same"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.4),
    # One softmax probability per digit class 0-9.
    tf.keras.layers.Dense(10, activation="softmax"),
])

# Labels are integer scalars (int64 per element_spec), hence the sparse loss.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

## Model fitting

In [15]:
# One pass over the training set; ds_test is evaluated at the end of the epoch.
model.fit(ds_train, validation_data=ds_test, epochs=1)

[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m88s[0m 84ms/step - accuracy: 0.8120 - loss: 0.5856 - val_accuracy: 0.9783 - val_loss: 0.0644


<keras.src.callbacks.history.History at 0x7f13203f94b0>

## Prediction image loading

In [16]:
# Download the repository archive and extract it flat (-j) and quietly into the
# working directory, overwriting existing files (-o). Presumably this provides
# the 1.png ... 9.png and dog.png files read by the next cell.
!wget -q https://github.com/milliams/machine_learning/archive/master.zip
!unzip -q -j -o master.zip

In [17]:
import numpy as np
from skimage.io import imread

# Read 1.png ... 9.png and dog.png, divide by 255 (presumably 8-bit pixels,
# to match the training normalization), then stack them and add a channel
# axis at position 3.
images = np.array([
    (imread(f"{name}.png") / 255.0).astype(np.float32)
    for name in [*range(1, 10), "dog"]
])[:, :, :, np.newaxis]
test_data = tf.convert_to_tensor(images)

## Make the predictions

In [18]:
# One row of 10 softmax probabilities per image in test_data.
probabilities = model.predict(x=test_data)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 76ms/step


## Output stats on the predictions

In [19]:
truths = [*range(1, 10), "dog"]

# Pair each expected label with its probability row, then report every pair.
table = list(zip(truths, probabilities))
for truth, probs in table:
    best = probs.argmax()
    best_pct = probs[best] * 100
    if truth != 'dog':
        # Digit images: also show the probability assigned to the true class.
        print(f"{truth} at {probs[truth]*100:4.1f}%. CNN thinks it's a {best} ({best_pct:4.1f}%)")
    else:
        print(f"{truth}. CNN thinks it's a {best} ({best_pct:.1f}%)")

1 at 68.9%. CNN thinks it's a 1 (68.9%)
2 at 96.2%. CNN thinks it's a 2 (96.2%)
3 at 98.9%. CNN thinks it's a 3 (98.9%)
4 at 99.9%. CNN thinks it's a 4 (99.9%)
5 at 97.8%. CNN thinks it's a 5 (97.8%)
6 at 99.9%. CNN thinks it's a 6 (99.9%)
7 at 38.5%. CNN thinks it's a 1 (41.9%)
8 at 98.9%. CNN thinks it's a 8 (98.9%)
9 at  2.1%. CNN thinks it's a 0 (60.3%)
dog. CNN thinks it's a 8 (33.7%)


In [20]:
def print_table(table, n_classes=None):
    """Print an HTML table of per-class prediction probabilities to stdout.

    Each row shows the image ``{truth}.png`` followed by one cell per class
    probability, rounded to a whole percentage. The highest-probability cell
    is coloured green when its index equals ``truth`` and red otherwise (so a
    non-digit label such as ``"dog"`` is always red).

    The HTML is printed as plain text; it is not sent to a rich display hook.

    Args:
        table: iterable of ``(truth, probs)`` pairs, where ``truth`` is the
            image label (an int class index or a string like ``"dog"``) and
            ``probs`` is a 1-D array of class probabilities with ``argmax``.
        n_classes: number of class columns in the header. Defaults to the
            length of the first ``probs`` vector (10 if ``table`` is empty),
            so the header always lines up with the row cells.
    """
    # Materialize once so generators work and the first row can be inspected.
    rows = list(table)
    if n_classes is None:
        n_classes = len(rows[0][1]) if rows else 10

    header = [
        '<table cellpadding="0" style="border-collapse: collapse; border-style: hidden;">',
        "    <thead>",
        "    <tr>",
        "    <td><b>Image</b></td>",
        *(f"    <td><b>{k}</b></td>" for k in range(n_classes)),
        "    </tr>",
        "    </thead>",
        "    <tbody>",
    ]
    print("\n".join(header))

    for truth, probs in rows:
        print("<tr>")
        print(f'<td><img src="{truth}.png" style="margin: 1px 0px"></td>')
        highest_prob = probs.argmax()
        colour = "green" if highest_prob == truth else "red"
        for j, p in enumerate(probs):
            pct = int(round(p * 100))
            if j == highest_prob:
                print(f'<td style="color:{colour};">{pct}%</td>')
            else:
                print(f"<td>{pct}%</td>")
        print("</tr>")

    print("</tbody>\n    </table>")

# Emit the per-image probability table built in the stats cell.
print_table(table=table)

<table cellpadding="0" style="border-collapse: collapse; border-style: hidden;">
    <thead>
    <tr>
    <td><b>Image</b></td>
    <td><b>0</b></td>
    <td><b>1</b></td>
    <td><b>2</b></td>
    <td><b>3</b></td>
    <td><b>4</b></td>
    <td><b>5</b></td>
    <td><b>6</b></td>
    <td><b>7</b></td>
    <td><b>8</b></td>
    <td><b>9</b></td>
    </tr>
    </thead>
    <tbody>
<tr>
<td><img src="1.png" style="margin: 1px 0px"></td>
<td>2%</td>
<td style="color:green;">69%</td>
<td>5%</td>
<td>2%</td>
<td>6%</td>
<td>3%</td>
<td>1%</td>
<td>6%</td>
<td>4%</td>
<td>1%</td>
</tr>
<tr>
<td><img src="2.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">96%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>4%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="3.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">99%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>1%</td>
<td>0%</td>
</tr>
<tr>


<table cellpadding="0" style="border-collapse: collapse; border-style: hidden;">
    <thead>
    <tr>
    <td><b>Image</b></td>
    <td><b>0</b></td>
    <td><b>1</b></td>
    <td><b>2</b></td>
    <td><b>3</b></td>
    <td><b>4</b></td>
    <td><b>5</b></td>
    <td><b>6</b></td>
    <td><b>7</b></td>
    <td><b>8</b></td>
    <td><b>9</b></td>
    </tr>
    </thead>
    <tbody>
<tr>
<td><img src="1.png" style="margin: 1px 0px"></td>
<td>2%</td>
<td style="color:green;">69%</td>
<td>5%</td>
<td>2%</td>
<td>6%</td>
<td>3%</td>
<td>1%</td>
<td>6%</td>
<td>4%</td>
<td>1%</td>
</tr>
<tr>
<td><img src="2.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">96%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>4%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="3.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">99%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>1%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="4.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">100%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="5.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">98%</td>
<td>0%</td>
<td>0%</td>
<td>2%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="6.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">100%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="7.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td style="color:red;">42%</td>
<td>18%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>38%</td>
<td>1%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="8.png" style="margin: 1px 0px"></td>
<td>0%</td>
<td>0%</td>
<td>1%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td>0%</td>
<td style="color:green;">99%</td>
<td>0%</td>
</tr>
<tr>
<td><img src="9.png" style="margin: 1px 0px"></td>
<td style="color:red;">60%</td>
<td>0%</td>
<td>2%</td>
<td>1%</td>
<td>1%</td>
<td>0%</td>
<td>2%</td>
<td>0%</td>
<td>31%</td>
<td>2%</td>
</tr>
<tr>
<td><img src="dog.png" style="margin: 1px 0px"></td>
<td>3%</td>
<td>4%</td>
<td>13%</td>
<td>5%</td>
<td>10%</td>
<td>6%</td>
<td>9%</td>
<td>15%</td>
<td style="color:red;">34%</td>
<td>1%</td>
</tr>
</tbody>
    </table>


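=> MNIST digits are white on a black background. With `INVERT = True`, the training data is also augmented with colour-inverted copies (black digits on a white background).
=> Even so, some predictions on our own white-background images are wrong: the 7 is classified as a 1 and the 9 as a 0.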
### The model always gives an answer