# 2 Assignment: MNIST math
We are revisiting MNIST handwritten digits, but with a twist: We are not classifying handwritten digits, but rather asking our network to do some simple math on them: In subtask (1) we want to know whether a + b ≥ 5, in subtask (2) we want to predict y = a − b, where a and b are respective MNIST digits!
    In the following we ask you to write code (specifically your model and training data preparation) in a parameterized way, which allows you to reuse nearly all of the code for both subtasks, to minimize the effort required for this homework! We will try and give some hints on how to easily achieve this in the following. If there are any open questions, make sure to ask away in the QnA
sessions.

## 2.1 Preparing MNIST math dataset
Remember from last week:he MNIST dataset consists of seventy thousand labelled images, each depicting a single handwritten digit. This may sound like alot of data, but as you’ll see for yourselves in a bit, the images are rather small.
    The MNIST dataset is included in TensorFlow, so you getting access to it is actually pretty easy. You can load it directly into your code like this:

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from keras.layers import Dense
import numpy as np
import matplotlib.pyplot as plt
import datetime
import tqdm

# Magic line only needed in jupyter notebooks!
%load_ext tensorboard

# Load mnist from tensorflow_datasets
train_ds, test_ds = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

## 2.2 Two MNIST math datasets
Remember that for the following you want to create a preprocessing function.
As subtask (1) an substask (2) require very similar inputs, we will create a parameterized preprocessing function - mapping a dataset and the ’subtraction’ or ’larger than five’ condition onto respective datasets. We will generally have three different steps to our preprocssing - general MNIST preprocessing (datatype, normalization, etc.), pairing data tuples onto respective parameterized targets, and finally batching and prefetching.
    As a first step, for general preprocessing steps we follow last weeks ideas closely. You may follow the example code from the lecture (i.e. L03) again, but to summarize the steps:
    The MNIST handwritten digits images come in uint8 datatype. This refers to unsigned 8-bit integers (think numbers 0-255). As the network requires float values (think continuous variables) as input rather than integers (whole numbers), we need to change the datatype: (map in combination with lambda expressions can be really useful here). In your first lambda mapping you want to change the datatype from uint8 to tf.float values. To feed your network the 28x28 images also need to be flattened. Check out the reshape function, and if you want to minimize your work, try and understand how it interacts with size elements set to the value -1 (infering the remainder shape). In order to improve the performance you should also normalize your image values. Generally this means bringing the input close to the standart normal (gaussian) distribution with µ = 0 and σ = 1, however we can make a quick approximation as that: Knowing the inputs are in the 0-255 interval, we can simply divide all numbers by 128 (bringing them between 0-2), and finally subtracting one (bringing them into -1 to 1 range).
    Second Step is figuring out the parameterized targets: Think about the following: What kind of supervised learning task and respective loss-function and target type are required for our two subtasks respectively? It makes sense to work yourself through this backwards: First understand what loss function is required for each. Then check the respective TensorFlow documentation to find out how your targets should look. Once you know what target data-representation is needed for each, we can move to the second step: Each datapoint fed into the network requires two MNIST images as input, with one target being created from their combination. We will discuss many approaches to combining dataset elements - but the dataset zip functionality 5 probably is the most direct solution (as alternatives you might check out the respective window, scan or even
batch functions!). After pairing two dataset elements, you can now make sure each dataset element gets two input mnist images, but you still have to calculate the respective target as the third element.
    Before finalizing the dataset in the third step we recommend checking your dataset at this point - you expect each element to be a triple (tuple with three elements), the first two of which are images, the last one the respective target. Finally we shuffle and (mini-)batch the dataset, and can use the apply method to create datasets for both the a + b ≥ 5 and a − b = y problem!

In [2]:
# Create our dataset
def prepare_math_minst_data(math_mnist, batch_size, subtask):
    
    # math_mnist = math_mnist.map(lambda img_x, target: (tf.cast(img_x, tf.float32) / 128. - 1, target))
    # flatten the images into vector
    math_mnist = math_mnist.map(lambda img_x, target: (tf.reshape(img_x, (-1,)), target))
    # Convert data from uint8 to float32
    math_mnist = math_mnist.map(lambda img_x, target: (tf.cast(img_x, tf.float32), target))
    # Sloppy input normalization, just bringing image values from range [0, 255] to [-1, 1]
    math_mnist = math_mnist.map(lambda img_x, target: ((img_x/128.)-1., target))
    # Shuffle data
    # we want to have two mnist images in each example
    # this leads to a single example being ((img_x1,img_y1),(img_x2,img_y2))
    zipped_dataset = tf.data.Dataset.zip((math_mnist.shuffle(2000),
                                        math_mnist.shuffle(2000)))
    # Subtask 1: a + b ≥ 5
    if subtask == 1:
        # Create targets
        zipped_dataset = zipped_dataset.map(lambda img_x1, img_x2: (img_x1[0], img_x2[0],
                                        tf.cast((img_x1[1] + img_x2[1] >= 5), tf.int32)))
    # Subtask 2: y == a - b 
    elif subtask == 2:
        # Create targets
        zipped_dataset = zipped_dataset.map(lambda img_x1, img_x2: (img_x1[0], img_x2[0],
                                        tf.cast((img_x1[1] - img_x2[1]), tf.int32)))
        zipped_dataset = zipped_dataset.map(lambda img_x1, img_x2, target: (img_x1, img_x2, 
                                        tf.one_hot(target, depth=19)))
    # Cache this progress in memory
    zipped_dataset.cache()
    # Shuffle data
    zipped_dataset = zipped_dataset.shuffle(2000)
    # Batch data
    zipped_dataset = zipped_dataset.batch(batch_size)
    # Prefetch data
    zipped_dataset = zipped_dataset.prefetch(tf.data.AUTOTUNE)
    # Return preprocessed dataset
    return zipped_dataset

# 3 Building shared weight models
In this step we have to build a network, that can solve the math task: This requires a network that takes two inputs (i.e. two input images), and outputs, i.e. predicts the math result. Technically you could just combine (e.g. concatenate) the two input images into one vector, however this would be rather inefficient. Instead we want you to learn about weight-sharing in a hands-on way: One intuitive way to solve this task would be to feed both inputs seperatedly into a layer, then combine (e.g. concatenate) the results from this layer. But both of these input layers would have to solve essentially the same problem for their respective input images: Extract the information, what digit is depicted! As both are basically tasked with the same problem, and have the same input and output shapes, we can improve our network by using the same layer for both inputs.  Finally you have to make sure to parameterize your model to use the correct activation functions for subtasks (1) and subtask (2) respectively.

In [3]:
class MyNNmodel(tf.keras.Model):
    def __init__(self, optimizer, subtask):
        # Inherit functionality from parent class
        super(MyNNmodel, self).__init__()
        # self.flatten = tf.keras.layers.Flatten()
        # Optimizer
        self.optimizer = optimizer
        self.subtask = subtask
        # layers to encode the images (both layers used for both images)
        self.dense1 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.dense3 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        
        if self.subtask == 1:
            # loss function
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
            # Metrix
            self.metrics_list = [
                tf.keras.metrics.BinaryAccuracy(name='accuracy'), 
                tf.keras.metrics.Mean(name='loss')
                ]
            self.out_layer = tf.keras.layers.Dense(1,activation=tf.nn.sigmoid)

        elif self.subtask == 2:
            # loss function
            self.loss_function = tf.keras.losses.CategoricalCrossentropy()
            # Metrix
            self.metrics_list = [
                tf.keras.metrics.CategoricalAccuracy(name='accuracy'), 
                tf.keras.metrics.Mean(name='loss')
                ]
            self.out_layer = tf.keras.layers.Dense(19,activation=tf.nn.softmax)
            
    # Call method(forward computation)
    @tf.function
    def __call__(self, images, training=False):
        img1, img2 = images
        
        # img1_x = self.flatten(img1)
        # img2_x = self.flatten(img2)
        
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)
        
        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)
        
        combined_x = tf.concat([img1_x, img2_x ], axis=1)
        combined_x = self.dense3(combined_x)
        return self.out_layer(combined_x)


    # Metrics property
    @property
    def metrics(self):
        return self.metrics_list
        # return a list with all metrics in the model

    # Reset all metrics objects
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()

    # Train step method
    @tf.function
    def train_step(self, data):
        img1, img2, label = data

        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)

        # Return a dictionary with metric names as keys and metric results as values
        return {m.name: m.result() for m in self.metrics}

    # Test step method
    @tf.function
    def test_step(self, data):
        img1, img2, label = data

        # The same as train step(without parameter updates)
        output = self((img1, img2), training=True)
        loss = self.loss_function(label, output)

        # Update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)

        # Return a dictionary with metric names as keys and metric results as values
        return {m.name: m.result() for m in self.metrics}

# 4 Training the networks
Create a training function (with an internal training loop function), which is able to run either of the two subtasks. The function takes two inputs: The subtask to solve and the optimizer to use for it, both specified via inputs: Write the
function, such that it creates respective models and datasets based on the parameterized preprocessing function you have written before. Choose the correct
loss-function for the task and run training.

In [4]:
def create_summary_writers(config_name):
    # Define where to save the logs
    # along with this, you may want to save a config file with the same name so you know what the hyperparameters were used
    # Alternatively make a copy of the code that is used for later reference
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    train_log_path = f"logs/{config_name}/{current_time}/train"
    test_log_path = f"logs/{config_name}/{current_time}/train"

    # Log writer for training metrics
    train_summary_writer = tf.summary.create_file_writer(train_log_path)
    # log writer for validation metrics
    test_summary_writer = tf.summary.create_file_writer(test_log_path)
    
    return train_summary_writer, test_summary_writer

In [5]:
def train_loop(model, train_ds, test_ds, start_epoch, 
                epochs, train_summary_writer, 
                test_summary_writer, save_path):

    #1. Iterate over epochs
    for epoch in range(start_epoch, epochs):
        # 2. Train steps on all batches in the training data
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            metrics = model.train_step(data)
        
        # Print the Epoch number
        print(f'Epoch: {epoch}')
        
        # 3.log and print training metrics 
        with train_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                tf.summary.scalar(f'{metric.name}', metric.result(), step=epoch)
            # Alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)
            # e.g. tf.summary.image(name="mean_activation_layer3", data = metrics["mean_activation_layer3"],step=e)
        
        # Print the metrics
        print([f'{key}: {value.numpy()}' for (key, value) in metrics.items()])

        # 4. Reset metric objects
        model.reset_metrics()

        # 5. Evaluate on validation data
        for data in test_ds:
            metrics = model.test_step(data)

        # 6. Log validationmetric
        with test_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                tf.summary.scalar(f'{metric.name}', metric.result(), step=epoch)
            # Alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)

        print([f'test_{key}: {value.numpy()}' for (key, value) in metrics.items()])
        # 7. Reset metric objects
        model.reset_metrics()

    # 8. Save the model weghts if save_path is given
    if save_path:
        model.save_weights(save_path)

# 5  Experiments
Run training with a classic SGD optimizer (without momentum) and an Adam Optimizer. For an outstanding submission, you are required to also add SGD with Momentum, RMSProp and AdaGrad optimizers and compare the training results by plotting them cleanly side-by-side.

In [6]:
def train(optimizer, save_path, batch_size, subtask):
    
    train_summary_writer, test_summary_writer = create_summary_writers(config_name='RUN1')

    train_dataset = prepare_math_minst_data(train_ds, batch_size, subtask)
    test_dataset = prepare_math_minst_data(test_ds, batch_size, subtask)
    
    # Сheck the contents of the dataset
    for img1, img2, label in train_dataset:
        print(img1.shape, img2.shape, label.shape)
        break

    print('\n\n')

    # Instance model with Adam optimizer
    model = MyNNmodel(optimizer, subtask)

    # Pass arguments to training loop function
    train_loop(model=model,
                train_ds=train_dataset,
                test_ds=test_dataset,
                start_epoch=0,
                epochs=10,
                train_summary_writer=train_summary_writer,
                test_summary_writer=test_summary_writer,
                save_path=save_path)


In [7]:
def main():
    # Instance optimizers
    optimizer_Adam = tf.keras.optimizers.Adam(learning_rate=0.01)
    optimizer_SGD = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
    optimizer_SGD_no_momentum = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0) 
    optimizer_RMSprop = tf.keras.optimizers.RMSprop(learning_rate=0.01)
    optimizer_AdaGrad = tf.keras.optimizers.Adagrad(learning_rate=0.01)
    # Choose a path to save the weights
    save_path = "../Homework4/"

    subtask_number = np.array([1,2])
    batch_size = 32

    optimizers_list = [
        optimizer_Adam, 
        optimizer_SGD, 
        optimizer_SGD_no_momentum,
        optimizer_RMSprop,
        optimizer_AdaGrad
    ]

    for optimizer in optimizers_list:
        # Train model for the first subtask
        train(optimizer, save_path, batch_size, subtask_number[0])
        # Train model for the second subtask
        train(optimizer, save_path, batch_size, subtask_number[1])

    
if __name__ == '__main__':
    main()    


(32, 784) (32, 784) (32,)





100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 103.75it/s]


Epoch: 0
['accuracy: 0.8417999744415283', 'loss: 2.4104807376861572']
['test_accuracy: 0.8388000130653381', 'test_loss: 2.4588167667388916']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:12<00:00, 153.31it/s]


Epoch: 1
['accuracy: 0.8427666425704956', 'loss: 2.3977105617523193']
['test_accuracy: 0.8385000228881836', 'test_loss: 2.4633843898773193']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:11<00:00, 160.21it/s]


Epoch: 2
['accuracy: 0.8395500183105469', 'loss: 2.4467613697052']
['test_accuracy: 0.8413000106811523', 'test_loss: 2.420754909515381']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 114.29it/s]


Epoch: 3
['accuracy: 0.8413500189781189', 'loss: 2.4193131923675537']
['test_accuracy: 0.8418999910354614', 'test_loss: 2.410097599029541']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 115.48it/s]


Epoch: 4
['accuracy: 0.8421333432197571', 'loss: 2.4073686599731445']
['test_accuracy: 0.8360000252723694', 'test_loss: 2.502969741821289']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 114.99it/s]


Epoch: 5
['accuracy: 0.8430333137512207', 'loss: 2.393643856048584']
['test_accuracy: 0.8424999713897705', 'test_loss: 2.4040074348449707']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 116.23it/s]


Epoch: 6
['accuracy: 0.8425166606903076', 'loss: 2.401521682739258']
['test_accuracy: 0.8381999731063843', 'test_loss: 2.472519636154175']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 116.28it/s]


Epoch: 7
['accuracy: 0.8399333357810974', 'loss: 2.440917730331421']
['test_accuracy: 0.8373000025749207', 'test_loss: 2.4770870208740234']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 111.99it/s]


Epoch: 8
['accuracy: 0.8397499918937683', 'loss: 2.4437127113342285']
['test_accuracy: 0.8360000252723694', 'test_loss: 2.501446485519409']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 111.97it/s]


Epoch: 9
['accuracy: 0.8414499759674072', 'loss: 2.4177873134613037']
['test_accuracy: 0.8379999995231628', 'test_loss: 2.4664289951324463']
(32, 784) (32, 784) (32, 19)





100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 88.06it/s]


Epoch: 0
['accuracy: 0.5507166385650635', 'loss: 7.235142707824707']
['test_accuracy: 0.5515000224113464', 'test_loss: 7.227071762084961']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 106.35it/s]


Epoch: 1
['accuracy: 0.5504500269889832', 'loss: 7.245877265930176']
['test_accuracy: 0.5530999898910522', 'test_loss: 7.202935695648193']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 109.33it/s]


Epoch: 2
['accuracy: 0.5503000020980835', 'loss: 7.248290061950684']
['test_accuracy: 0.5478000044822693', 'test_loss: 7.288224697113037']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 105.58it/s]


Epoch: 3
['accuracy: 0.553600013256073', 'loss: 7.19509744644165']
['test_accuracy: 0.5493000149726868', 'test_loss: 7.2640862464904785']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 105.57it/s]


Epoch: 4
['accuracy: 0.5512499809265137', 'loss: 7.232974529266357']
['test_accuracy: 0.553600013256073', 'test_loss: 7.194888114929199']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 107.92it/s]


Epoch: 5
['accuracy: 0.5494833588600159', 'loss: 7.261453628540039']
['test_accuracy: 0.5501000285148621', 'test_loss: 7.252821445465088']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 100.39it/s]


Epoch: 6
['accuracy: 0.5512166619300842', 'loss: 7.233517646789551']
['test_accuracy: 0.5461999773979187', 'test_loss: 7.310751914978027']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 103.82it/s]


Epoch: 7
['accuracy: 0.5490999817848206', 'loss: 7.267635345458984']
['test_accuracy: 0.5439000129699707', 'test_loss: 7.354203701019287']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 102.60it/s]


Epoch: 8
['accuracy: 0.5511833429336548', 'loss: 7.234059810638428']
['test_accuracy: 0.5476999878883362', 'test_loss: 7.296270370483398']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 105.21it/s]


Epoch: 9
['accuracy: 0.5502166748046875', 'loss: 7.24963903427124']
['test_accuracy: 0.5497000217437744', 'test_loss: 7.256041049957275']
(32, 784) (32, 784) (32,)





100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 113.83it/s]


Epoch: 0
['accuracy: 0.9330833554267883', 'loss: 0.1711537390947342']
['test_accuracy: 0.9621000289916992', 'test_loss: 0.10583990812301636']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 115.36it/s]


Epoch: 1
['accuracy: 0.9689666628837585', 'loss: 0.08643074333667755']
['test_accuracy: 0.9729999899864197', 'test_loss: 0.07259371131658554']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 107.78it/s]


Epoch: 2
['accuracy: 0.9773833155632019', 'loss: 0.06312954425811768']
['test_accuracy: 0.9736999869346619', 'test_loss: 0.07058718800544739']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 113.97it/s]


Epoch: 3
['accuracy: 0.9804333448410034', 'loss: 0.05541922152042389']
['test_accuracy: 0.9825000166893005', 'test_loss: 0.05327495187520981']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 112.12it/s]


Epoch: 4
['accuracy: 0.9837833046913147', 'loss: 0.04497397318482399']
['test_accuracy: 0.9836999773979187', 'test_loss: 0.05066705867648125']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 97.80it/s]


Epoch: 5
['accuracy: 0.9851333498954773', 'loss: 0.04359132796525955']
['test_accuracy: 0.9842000007629395', 'test_loss: 0.048253558576107025']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 101.17it/s]


Epoch: 6
['accuracy: 0.9861999750137329', 'loss: 0.04099387302994728']
['test_accuracy: 0.9836999773979187', 'test_loss: 0.048742517828941345']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 113.46it/s]


Epoch: 7
['accuracy: 0.9866999983787537', 'loss: 0.03751760348677635']
['test_accuracy: 0.9861999750137329', 'test_loss: 0.04232737421989441']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 112.53it/s]


Epoch: 8
['accuracy: 0.9879166483879089', 'loss: 0.03367907181382179']
['test_accuracy: 0.9854999780654907', 'test_loss: 0.04603441059589386']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 106.15it/s]


Epoch: 9
['accuracy: 0.9891166687011719', 'loss: 0.0313931368291378']
['test_accuracy: 0.987500011920929', 'test_loss: 0.039195358753204346']
(32, 784) (32, 784) (32, 19)





100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 99.53it/s]


Epoch: 0
['accuracy: 0.5501000285148621', 'loss: 0.5825722813606262']
['test_accuracy: 0.5989000201225281', 'test_loss: 0.24784448742866516']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 105.78it/s]


Epoch: 1
['accuracy: 0.673383355140686', 'loss: 0.2108975201845169']
['test_accuracy: 0.6363000273704529', 'test_loss: 0.17516496777534485']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 92.14it/s]


Epoch: 2
['accuracy: 0.683733344078064', 'loss: 0.15745672583580017']
['test_accuracy: 0.6977999806404114', 'test_loss: 0.1371011734008789']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 103.30it/s]


Epoch: 3
['accuracy: 0.6783166527748108', 'loss: 0.13102123141288757']
['test_accuracy: 0.6406000256538391', 'test_loss: 0.15283028781414032']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 88.70it/s]


Epoch: 4
['accuracy: 0.6814833283424377', 'loss: 0.11169968545436859']
['test_accuracy: 0.6466000080108643', 'test_loss: 0.10558338463306427']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 98.68it/s]


Epoch: 5
['accuracy: 0.6875', 'loss: 0.09817928820848465']
['test_accuracy: 0.6708999872207642', 'test_loss: 0.14764678478240967']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 99.41it/s]


Epoch: 6
['accuracy: 0.6855000257492065', 'loss: 0.09073080867528915']
['test_accuracy: 0.6841999888420105', 'test_loss: 0.09866274148225784']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 94.57it/s]


Epoch: 7
['accuracy: 0.6819999814033508', 'loss: 0.08466527611017227']
['test_accuracy: 0.6657000184059143', 'test_loss: 0.09648886322975159']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 101.51it/s]


Epoch: 8
['accuracy: 0.6994166374206543', 'loss: 0.0754498690366745']
['test_accuracy: 0.6973000168800354', 'test_loss: 0.09504928439855576']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 97.69it/s]


Epoch: 9
['accuracy: 0.6896499991416931', 'loss: 0.06924464553594589']
['test_accuracy: 0.6920999884605408', 'test_loss: 0.0870642140507698']
(32, 784) (32, 784) (32,)





100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 102.20it/s]


Epoch: 0
['accuracy: 0.9004499912261963', 'loss: 0.24527877569198608']
['test_accuracy: 0.928600013256073', 'test_loss: 0.1796085089445114']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 106.12it/s]


Epoch: 1
['accuracy: 0.9330833554267883', 'loss: 0.1685965657234192']
['test_accuracy: 0.946399986743927', 'test_loss: 0.14036235213279724']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 106.00it/s]


Epoch: 2
['accuracy: 0.9495499730110168', 'loss: 0.13129392266273499']
['test_accuracy: 0.9598000049591064', 'test_loss: 0.10591281205415726']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 96.24it/s]


Epoch: 3
['accuracy: 0.9612500071525574', 'loss: 0.10627786815166473']
['test_accuracy: 0.9563000202178955', 'test_loss: 0.12077032774686813']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 105.59it/s]


Epoch: 4
['accuracy: 0.9674999713897705', 'loss: 0.09156633168458939']
['test_accuracy: 0.9717000126838684', 'test_loss: 0.08356288075447083']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 114.66it/s]


Epoch: 5
['accuracy: 0.9713666439056396', 'loss: 0.08112399280071259']
['test_accuracy: 0.9731000065803528', 'test_loss: 0.07524095475673676']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 100.90it/s]


Epoch: 6
['accuracy: 0.9744333624839783', 'loss: 0.07334161549806595']
['test_accuracy: 0.9742000102996826', 'test_loss: 0.07397184520959854']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 106.31it/s]


Epoch: 7
['accuracy: 0.9765666723251343', 'loss: 0.06729013472795486']
['test_accuracy: 0.9746000170707703', 'test_loss: 0.06940219551324844']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 111.69it/s]


Epoch: 8
['accuracy: 0.9776999950408936', 'loss: 0.06453868001699448']
['test_accuracy: 0.9781000018119812', 'test_loss: 0.06556100398302078']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 109.15it/s]


Epoch: 9
['accuracy: 0.9805166721343994', 'loss: 0.05725764483213425']
['test_accuracy: 0.973800003528595', 'test_loss: 0.0745934545993805']
(32, 784) (32, 784) (32, 19)





100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 81.71it/s]


Epoch: 0
['accuracy: 0.4363666772842407', 'loss: 1.0760776996612549']
['test_accuracy: 0.4731999933719635', 'test_loss: 0.9283393621444702']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:23<00:00, 79.39it/s]


Epoch: 1
['accuracy: 0.5156999826431274', 'loss: 0.7377575635910034']
['test_accuracy: 0.5857999920845032', 'test_loss: 0.5397356152534485']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 83.50it/s]


Epoch: 2
['accuracy: 0.5755666494369507', 'loss: 0.4423416256904602']
['test_accuracy: 0.5895000100135803', 'test_loss: 0.3638933598995209']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 86.62it/s]


Epoch: 3
['accuracy: 0.6159499883651733', 'loss: 0.33663642406463623']
['test_accuracy: 0.6380000114440918', 'test_loss: 0.3017280399799347']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 90.51it/s]


Epoch: 4
['accuracy: 0.6360999941825867', 'loss: 0.2795460522174835']
['test_accuracy: 0.656000018119812', 'test_loss: 0.25545284152030945']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 95.33it/s]


Epoch: 5
['accuracy: 0.652733325958252', 'loss: 0.23712243139743805']
['test_accuracy: 0.664900004863739', 'test_loss: 0.2237178087234497']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:23<00:00, 78.37it/s]


Epoch: 6
['accuracy: 0.6614000201225281', 'loss: 0.2150939255952835']
['test_accuracy: 0.6304000020027161', 'test_loss: 0.20436984300613403']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:23<00:00, 80.50it/s]


Epoch: 7
['accuracy: 0.6689333319664001', 'loss: 0.1929323822259903']
['test_accuracy: 0.6872000098228455', 'test_loss: 0.17247003316879272']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 87.39it/s]


Epoch: 8
['accuracy: 0.6758833527565002', 'loss: 0.17131829261779785']
['test_accuracy: 0.6593000292778015', 'test_loss: 0.16055749356746674']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 94.15it/s]


Epoch: 9
['accuracy: 0.6762166619300842', 'loss: 0.15728099644184113']
['test_accuracy: 0.6807000041007996', 'test_loss: 0.15656085312366486']
(32, 784) (32, 784) (32,)





100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:27<00:00, 68.70it/s]


Epoch: 0
['accuracy: 0.8413333296775818', 'loss: 2.4146828651428223']
['test_accuracy: 0.8416000008583069', 'test_loss: 2.4146647453308105']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 94.55it/s]


Epoch: 1
['accuracy: 0.8427666425704956', 'loss: 2.397709369659424']
['test_accuracy: 0.8367000222206116', 'test_loss: 2.487743377685547']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 92.39it/s]


Epoch: 2
['accuracy: 0.8404333591461182', 'loss: 2.433290719985962']
['test_accuracy: 0.8421000242233276', 'test_loss: 2.4085750579833984']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 86.89it/s]


Epoch: 3
['accuracy: 0.8394333124160767', 'loss: 2.4485397338867188']
['test_accuracy: 0.8364999890327454', 'test_loss: 2.4923112392425537']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 94.23it/s]


Epoch: 4
['accuracy: 0.840399980545044', 'loss: 2.433800220489502']
['test_accuracy: 0.8416000008583069', 'test_loss: 2.41923189163208']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 92.89it/s]


Epoch: 5
['accuracy: 0.8430166840553284', 'loss: 2.393895149230957']
['test_accuracy: 0.8382999897003174', 'test_loss: 2.4618616104125977']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 87.68it/s]


Epoch: 6
['accuracy: 0.8392333388328552', 'loss: 2.4515907764434814']
['test_accuracy: 0.8324999809265137', 'test_loss: 2.5532116889953613']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 87.26it/s]


Epoch: 7
['accuracy: 0.8430166840553284', 'loss: 2.3938965797424316']
['test_accuracy: 0.8416000008583069', 'test_loss: 2.417710542678833']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 91.64it/s]


Epoch: 8
['accuracy: 0.8397166728973389', 'loss: 2.4442203044891357']
['test_accuracy: 0.8389999866485596', 'test_loss: 2.454249382019043']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 91.06it/s]


Epoch: 9
['accuracy: 0.8421000242233276', 'loss: 2.4078762531280518']
['test_accuracy: 0.8407999873161316', 'test_loss: 2.4237985610961914']
(32, 784) (32, 784) (32, 19)





100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:25<00:00, 74.06it/s]


Epoch: 0
['accuracy: 0.0687333345413208', 'loss: 7.745543956756592']
['test_accuracy: 0.06520000100135803', 'test_loss: 7.830536842346191']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 85.05it/s]


Epoch: 1
['accuracy: 0.06963333487510681', 'loss: 7.742843151092529']
['test_accuracy: 0.07000000029802322', 'test_loss: 7.6953606605529785']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 86.95it/s]


Epoch: 2
['accuracy: 0.07053333520889282', 'loss: 7.783666133880615']
['test_accuracy: 0.0674000009894371', 'test_loss: 7.7854766845703125']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 87.49it/s]


Epoch: 3
['accuracy: 0.06975000351667404', 'loss: 7.754667282104492']
['test_accuracy: 0.0681999996304512', 'test_loss: 7.880422592163086']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:21<00:00, 86.88it/s]


Epoch: 4
['accuracy: 0.0681999996304512', 'loss: 7.779109477996826']
['test_accuracy: 0.06989999860525131', 'test_loss: 7.8096184730529785']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 83.59it/s]


Epoch: 5
['accuracy: 0.06724999845027924', 'loss: 7.7785725593566895']
['test_accuracy: 0.06629999727010727', 'test_loss: 7.803175926208496']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 85.20it/s]


Epoch: 6
['accuracy: 0.06814999878406525', 'loss: 7.767555236816406']
['test_accuracy: 0.06599999964237213', 'test_loss: 7.812832832336426']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:23<00:00, 78.21it/s]


Epoch: 7
['accuracy: 0.06726666539907455', 'loss: 7.77829122543335']
['test_accuracy: 0.06729999929666519', 'test_loss: 7.746856689453125']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:23<00:00, 79.54it/s]


Epoch: 8
['accuracy: 0.06896666437387466', 'loss: 7.762460231781006']
['test_accuracy: 0.06499999761581421', 'test_loss: 7.796741485595703']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 85.22it/s]


Epoch: 9
['accuracy: 0.06893333047628403', 'loss: 7.736397743225098']
['test_accuracy: 0.0689999982714653', 'test_loss: 7.742030620574951']
(32, 784) (32, 784) (32,)





100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 103.20it/s]


Epoch: 0
['accuracy: 0.9174333214759827', 'loss: 0.20474468171596527']
['test_accuracy: 0.9171000123023987', 'test_loss: 0.19058795273303986']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 107.02it/s]


Epoch: 1
['accuracy: 0.9542499780654907', 'loss: 0.12069132179021835']
['test_accuracy: 0.9648000001907349', 'test_loss: 0.09921427816152573']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 109.67it/s]


Epoch: 2
['accuracy: 0.9676833152770996', 'loss: 0.09203530848026276']
['test_accuracy: 0.9715999960899353', 'test_loss: 0.08163541555404663']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 114.25it/s]


Epoch: 3
['accuracy: 0.9710500240325928', 'loss: 0.07980525493621826']
['test_accuracy: 0.9736999869346619', 'test_loss: 0.07457786053419113']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 110.15it/s]


Epoch: 4
['accuracy: 0.9765499830245972', 'loss: 0.06827632337808609']
['test_accuracy: 0.9728000164031982', 'test_loss: 0.07394083589315414']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 104.49it/s]


Epoch: 5
['accuracy: 0.9789000153541565', 'loss: 0.06071745976805687']
['test_accuracy: 0.9793000221252441', 'test_loss: 0.06427133083343506']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 105.62it/s]


Epoch: 6
['accuracy: 0.98089998960495', 'loss: 0.05591895803809166']
['test_accuracy: 0.9812999963760376', 'test_loss: 0.05390532687306404']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:17<00:00, 104.50it/s]


Epoch: 7
['accuracy: 0.9809499979019165', 'loss: 0.055124107748270035']
['test_accuracy: 0.9807999730110168', 'test_loss: 0.058696068823337555']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 103.33it/s]


Epoch: 8
['accuracy: 0.9832333326339722', 'loss: 0.05065905302762985']
['test_accuracy: 0.983299970626831', 'test_loss: 0.050210993736982346']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 98.60it/s]


Epoch: 9
['accuracy: 0.9839833378791809', 'loss: 0.04591310769319534']
['test_accuracy: 0.9815000295639038', 'test_loss: 0.054033003747463226']
(32, 784) (32, 784) (32, 19)





100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 102.49it/s]


Epoch: 0
['accuracy: 0.49281665682792664', 'loss: 0.8581560254096985']
['test_accuracy: 0.574999988079071', 'test_loss: 0.4854243993759155']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 100.70it/s]


Epoch: 1
['accuracy: 0.6211666464805603', 'loss: 0.3758014440536499']
['test_accuracy: 0.6284000277519226', 'test_loss: 0.2930162847042084']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 84.32it/s]


Epoch: 2
['accuracy: 0.6541666388511658', 'loss: 0.26223254203796387']
['test_accuracy: 0.6420000195503235', 'test_loss: 0.23537114262580872']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 98.15it/s]


Epoch: 3
['accuracy: 0.6666666865348816', 'loss: 0.2154625803232193']
['test_accuracy: 0.6632999777793884', 'test_loss: 0.19142764806747437']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 85.10it/s]


Epoch: 4
['accuracy: 0.6678666472434998', 'loss: 0.18663738667964935']
['test_accuracy: 0.6879000067710876', 'test_loss: 0.16768456995487213']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:41<00:00, 45.23it/s]


Epoch: 5
['accuracy: 0.6802833080291748', 'loss: 0.1593310385942459']
['test_accuracy: 0.6873000264167786', 'test_loss: 0.15134327113628387']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:23<00:00, 78.53it/s]


Epoch: 6
['accuracy: 0.6857166886329651', 'loss: 0.1458113044500351']
['test_accuracy: 0.6761999726295471', 'test_loss: 0.14838922023773193']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:24<00:00, 77.04it/s]


Epoch: 7
['accuracy: 0.6888333559036255', 'loss: 0.13331729173660278']
['test_accuracy: 0.6859999895095825', 'test_loss: 0.13199804723262787']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:24<00:00, 76.16it/s]


Epoch: 8
['accuracy: 0.6964499950408936', 'loss: 0.1201329454779625']
['test_accuracy: 0.6823999881744385', 'test_loss: 0.12201423943042755']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 89.98it/s]


Epoch: 9
['accuracy: 0.6996666789054871', 'loss: 0.11450948566198349']
['test_accuracy: 0.6894999742507935', 'test_loss: 0.11077160388231277']
