In [1]:
import tensorflow as tf

# Workaround for Pylance
keras = tf.keras

In [2]:
model_paths = [
    "./models/standard_classifier",
    "./models/global_dim_classifier",
    "./models/local_dim_classifier",
    "./models/mixed_dim_classifier",
    "./models/complete_dim_classifier",
]

In [3]:
standard_classifier = keras.models.load_model(model_paths[0])
global_classifier = keras.models.load_model(model_paths[1])
local_classifier = keras.models.load_model(model_paths[2])
mixed_classifier = keras.models.load_model(model_paths[3])
completed_classifier = keras.models.load_model(model_paths[4])

In [4]:
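# Prepare the test dataset; the training split is not needed for evaluation.
batch_size = 64  # not used below: all 10,000 test images are passed in one call
(_, _), (x_test, y_test) = keras.datasets.mnist.load_data()
# Add a channel dimension: (10000, 28, 28) -> (10000, 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype("float32")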

In [5]:
standard_acc_metric = keras.metrics.SparseCategoricalAccuracy()
global_acc_metric = keras.metrics.SparseCategoricalAccuracy()
local_acc_metric = keras.metrics.SparseCategoricalAccuracy()
mixed_acc_metric = keras.metrics.SparseCategoricalAccuracy()
completed_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [6]:
standard_logits = standard_classifier(x_test, training=False)
standard_acc_metric.update_state(y_test, standard_logits)

global_logits = global_classifier(x_test, training=False)
global_acc_metric.update_state(y_test, global_logits)

local_logits = local_classifier(x_test, training=False)
local_acc_metric.update_state(y_test, local_logits)

mixed_logits = mixed_classifier(x_test, training=False)
mixed_acc_metric.update_state(y_test, mixed_logits)

completed_logits = completed_classifier(x_test, training=False)
completed_acc_metric.update_state(y_test, completed_logits)


In [7]:
accuracy = [
    standard_acc_metric.result().numpy(),
    global_acc_metric.result().numpy(),
    local_acc_metric.result().numpy(),
    mixed_acc_metric.result().numpy(),
    completed_acc_metric.result().numpy(),
]

labels = ["standard", "global", "local", "mixed", "complete"]
weights = [
    "",
    "g=1.0 l=0.0 p=1.0",
    "g=0.0 l=1.0 p=0.1",
    "g=0.6 l=0.4 p=0.0",
    "g=0.6 l=0.4 p=0.2",
]

In [8]:
# Print a fixed-width table; accuracy is rounded to four decimals.
print("{:<18} | {:<17} | {:<15}".format("Used Classifier", "Encoder Weights", "Accuracy"))
print("-" * 19 + "|" + "-" * 19 + "|" + "-" * 15)
for label, weight, acc in zip(labels, weights, accuracy):
    print("{:<18} | {:<17} | {:<15.4f}".format(label, weight, acc))

Used Classifier    | Encoder Weights   | Accuracy
-------------------|-------------------|---------------
standard           |                   | 0.8999
global             | g=1.0 l=0.0 p=1.0 | 0.9216
local              | g=0.0 l=1.0 p=0.1 | 0.9632
mixed              | g=0.6 l=0.4 p=0.0 | 0.9198
complete           | g=0.6 l=0.4 p=0.2 | 0.9472


## Analysing the Data
The results show that every model trained on top of a specialized DIM encoder reaches a higher test accuracy than the standard classifier (0.8999). However, the global objective on its own (0.9216) appears to be less effective than the local one (0.9632).
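Absolute accuracy differences can hide how large the gains are near the top of the scale, so it may help to also look at how many of the baseline's errors each model removes. The short sketch below only does arithmetic on the accuracies from the table (rounded to four decimals); the names `reported_accuracy` and `relative_error_reduction` are hypothetical helpers, not part of the notebook.

```python
# Test accuracies as printed in the table above, rounded to four decimals.
reported_accuracy = {
    "standard": 0.8999,
    "global": 0.9216,
    "local": 0.9632,
    "mixed": 0.9198,
    "complete": 0.9472,
}


def relative_error_reduction(accuracy, baseline):
    # Fraction of the baseline's test errors that the model no longer makes.
    baseline_error = 1.0 - baseline
    model_error = 1.0 - accuracy
    return (baseline_error - model_error) / baseline_error


baseline = reported_accuracy["standard"]
for name, acc in reported_accuracy.items():
    gain_pp = (acc - baseline) * 100  # gain in percentage points
    reduction = relative_error_reduction(acc, baseline)
    print(f"{name:<9} {gain_pp:+6.2f} pp   error {1 - acc:6.2%}   reduction {reduction:6.1%}")
```

Measured this way, the local model removes roughly 63% of the standard classifier's errors (0.0633 / 0.1001).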

The authors of the original paper suggest that the local DIM objective helps the encoder ignore unimportant features.
Since MNIST images consist largely of "background", this could explain why the local objective performs better here.
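To make the local objective more concrete, here is a minimal NumPy sketch of a local DIM score, following the description in the DIM paper: every location of a convolutional feature map is scored against the image's global encoding, and the JSD-based estimator rewards high scores for pairs from the same image and low scores for pairs from different images. As a simplification, the score is a plain dot product (the paper passes both vectors through small networks first); the function and argument names are hypothetical, and this is not the implementation used to train the models above.

```python
import numpy as np


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return np.logaddexp(0.0, z)


def local_dim_jsd(local_features, global_features):
    """JSD-style local DIM score (higher means more shared information).

    local_features:  (batch, locations, dim) feature-map vectors C^(l)(x)
    global_features: (batch, dim) global encodings E(x)
    """
    # Dot-product score between every global vector and every local vector:
    # scores[i, j, l] = <E(x_i), C^(l)(x_j)>
    scores = np.einsum("id,jld->ijl", global_features, local_features)
    same_image = np.eye(global_features.shape[0], dtype=bool)
    positive = scores[same_image]   # same-image pairs, shape (batch, locations)
    negative = scores[~same_image]  # different-image pairs
    # E_P[-sp(-T)] - E_{P x P~}[sp(T)], averaged over all locations.
    # Both terms are <= 0, so the score is bounded above by 0.
    return np.mean(-softplus(-positive)) - np.mean(softplus(negative))


rng = np.random.default_rng(0)
batch, locations, dim = 16, 49, 64
global_features = rng.normal(size=(batch, dim))
# Informative case: every location carries its own image's global code.
informative = np.repeat(global_features[:, None, :], locations, axis=1)
# Uninformative case: local features are independent of the global code.
uninformative = rng.normal(size=(batch, locations, dim))
print(local_dim_jsd(informative, global_features))
print(local_dim_jsd(uninformative, global_features))
```

Because the objective is averaged over all locations, the encoder is pushed towards information that is shared across the whole image rather than details specific to a few pixels.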

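Turning off prior matching, which is meant to counteract possible mode collapse, also lowers accuracy: the mixed and complete models use the same global/local weights (g=0.6, l=0.4) and differ only in the prior weight, yet accuracy drops from 0.9472 (p=0.2) to 0.9198 (p=0.0).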
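Prior matching in the DIM paper is adversarial: a discriminator D tries to tell samples v drawn from a chosen prior apart from encoder outputs E(x), and the encoder is trained to make that hard. The minimal NumPy sketch below only evaluates that objective, E_V[log D(v)] + E_P[log(1 - D(E(x)))], from given discriminator logits; the function names are hypothetical, and reading the `p` weight in the table as the scale of this term is an assumption.

```python
import numpy as np


def log_sigmoid(z):
    # log(sigmoid(z)) = -log(1 + exp(-z)), computed stably.
    return -np.logaddexp(0.0, -z)


def prior_matching_value(d_logits_prior, d_logits_encoded):
    """GAN-style prior-matching objective (hypothetical helper).

    d_logits_prior:   discriminator logits on samples v from the prior
    d_logits_encoded: discriminator logits on encoder outputs E(x)
    The discriminator maximises this value; the encoder minimises it.
    """
    # E_V[log D(v)] + E_P[log(1 - D(E(x)))], using 1 - sigmoid(z) = sigmoid(-z).
    return np.mean(log_sigmoid(d_logits_prior)) + np.mean(log_sigmoid(-d_logits_encoded))


# A discriminator that cannot tell prior samples from encodings (logit 0 everywhere).
confused = prior_matching_value(np.zeros(100), np.zeros(100))
# A discriminator that separates them almost perfectly.
separated = prior_matching_value(np.full(100, 10.0), np.full(100, -10.0))
print(confused)   # 2 * log(0.5), about -1.386
print(separated)  # close to 0
```

When the encoder's outputs are indistinguishable from the prior, the discriminator is stuck near the confused value; a collapsed encoding would be easy to separate and drive the value towards 0, which the encoder's loss penalises.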

## Conclusion
In their paper, the authors also evaluated their approach on standard datasets, mostly with very strong results that almost always surpassed simple supervised learning.

That made me sceptical, but my results tend to support it.

However, this was a very simple task.
The results in the paper also show a slight trend suggesting that this advantage could decrease with more complex tasks or larger images.

A good next step would be to make my implementation more flexible and to run more thorough evaluations on more complex datasets.
The paper also introduces several additional measurement tools, which would be interesting to try out, as they would help to understand how different parameter settings affect mutual information (MI) maximization in the models.

In [9]:
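# Clear the accumulated counts so the evaluation cell can be re-run cleanly.
# reset_state() is the current name; reset_states() is a deprecated alias.
standard_acc_metric.reset_state()
global_acc_metric.reset_state()
local_acc_metric.reset_state()
mixed_acc_metric.reset_state()
completed_acc_metric.reset_state()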