## What this is about
The basic idea of Deep InfoMax (DIM) is a training setup that lets you combine multiple training goals for one encoder.
This is combined with maximizing mutual information (MI), so that unsupervised learning becomes as effective as possible.
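
To make the term MI concrete, here is a minimal sketch for a small discrete joint distribution. It is not part of the DIM code: the function name `mutual_information` and the two example tables are hypothetical and chosen only for illustration. Independent variables give an MI of 0, while two perfectly coupled binary variables give log 2.

```python
import math


def mutual_information(joint):
    """MI in nats for a joint probability table given as a list of rows."""
    row_marginals = [sum(row) for row in joint]
    col_marginals = [sum(col) for col in zip(*joint)]
    mi = 0.0
    for i, row in enumerate(joint):
        for j, p_xy in enumerate(row):
            if p_xy > 0.0:  # cells with zero probability contribute nothing
                mi += p_xy * math.log(p_xy / (row_marginals[i] * col_marginals[j]))
    return mi


# Hypothetical example tables
independent = [[0.25, 0.25], [0.25, 0.25]]
coupled = [[0.5, 0.0], [0.0, 0.5]]

mi_independent = mutual_information(independent)
mi_coupled = mutual_information(coupled)
```

DIM cannot compute MI exactly like this, because the distributions of images and features are unknown. That is why it relies on estimators based on discriminator scores.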

I have tried to recreate this approach in **TensorFlow**.

In doing so, I ran into some issues that made it very difficult to show my results within a single ipynb file.
Therefore I split the work into 4 different files.

### How the Files are structured
File 0_ deals with training an encoder using the DIM approach.

To evaluate whether the results are usable, I built several classifiers that try to match the images of the MNIST dataset to the correct digits, based on the features learned by the encoder.
In file 1_ I build a simple classifier without my encoder.
I got the code for it from: [Tensorflow: writing training loop tutorial](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)
In file 2_ I have adapted the code to run the images through the encoder before passing them through the same network.

In file 3_ I briefly present the resulting data and discuss it.

### To run this on your own
I coded this on a Windows 10 machine with Python **3.10.1**.
You can install the requirements by executing `pip install -r ./requirements.txt`.
To build multiple encoders, you need to re-run every cell, because I'm using `@tf.function` and could not find out how to reuse the function.

### About Sources:
I needed to learn TensorFlow first. I relied heavily on these tutorials in the beginning and started from scratch later.
- [Install Tensorflow](https://www.tensorflow.org/install/pip)
- [Tensorflow: writing training loop tutorial](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)
- [Tensorflow: dcgan tutorial](https://www.tensorflow.org/tutorials/generative/dcgan)
- [Tensorflow: functional api tutorial](https://www.tensorflow.org/guide/keras/functional)
- [Tensorflow: save and load models](https://www.tensorflow.org/guide/keras/save_and_serialize)
- [Tensorflow docs: KLDivergence](https://www.tensorflow.org/api_docs/python/tf/keras/losses/KLDivergence)
- [Tensorflow docs: Math](https://www.tensorflow.org/api_docs/python/tf/math)

I also used the paper and some follow-up works as resources:
- [Learning deep representations by mutual information estimation and maximization](https://arxiv.org/abs/1808.06670)
- [Deep InfoMax: Learning good representations through mutual information maximization](https://www.microsoft.com/en-us/research/blog/deep-infomax-learning-good-representations-through-mutual-information-maximization/)
- [Jehill Parikh: Deep InfoMax Tensorflow-Keras Implementation](https://jehillparikh.medium.com/deep-info-max-tensorflow-keras-implementation-b1faeffb0260)

I marked every cell where I copied code.

In [1]:
import tensorflow as tf
import time
from IPython.display import clear_output

# Workaround for Pylance
keras = tf.keras
from keras import layers, models, losses

kl = tf.keras.losses.KLDivergence()

## Loading training data
I used the MNIST dataset. As mentioned in [Tensorflow: writing training loop tutorial](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch), I reserved some samples for validation, which I do in files 1_ and 2_.

In [2]:
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype("float32")

train_images = train_images[:-10000]

DIM separates its training into 3 tasks:
- Global DIM task
- Local DIM task
- Prior matching

Each task needs its own discriminator.

So I need 4 models:
1. Encoder I want to train
2. Global-Discriminator
3. Local-Discriminator
4. Prior-Discriminator

The encoder needs to extract the local features of different kernels into an MxM feature map.
In the examples provided in the paper, the authors use 4 convolutional layers, the last one being the final local feature map.
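
The size M follows from the layer settings. As a minimal sketch, assuming convolutions without padding (the Keras default, "valid"), the helper below reproduces the arithmetic for the encoder in this notebook: one layer with kernel size 6 and stride 5, followed by three 1x1 layers. The helper name `conv_output_size` is hypothetical.

```python
def conv_output_size(input_size, kernel_size, stride):
    """Output length along one axis of a convolution without padding."""
    return (input_size - kernel_size) // stride + 1


# First layer: 28 pixel input, kernel 6, stride 5
m = conv_output_size(28, kernel_size=6, stride=5)

# The three following layers use 1x1 kernels with stride 1 and keep the size
for _ in range(3):
    m = conv_output_size(m, kernel_size=1, stride=1)
```

So `m` ends up as 5, which is why the discriminators expect an input of shape `(5, 5, 256)`.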

In [3]:
"""
Factory Function returns the Encoder.

Model Output:
conv4: The local MxM feature map (M=5 in this example)
fc: Global Feature Vector
"""


def make_encoder_model():
    encoder_input = layers.Input(shape=(28, 28, 1), name="image")

    conv_1 = layers.Conv2D(filters=32, kernel_size=6, strides=5, activation="relu")(
        encoder_input
    )
    conv_2 = layers.Conv2D(filters=64, kernel_size=1, strides=1, activation="relu")(
        conv_1
    )
    conv_3 = layers.Conv2D(filters=128, kernel_size=1, strides=1, activation="relu")(
        conv_2
    )
    conv_4 = layers.Conv2D(filters=256, kernel_size=1, strides=1, activation="relu")(
        conv_3
    )

    fc = layers.Flatten()(conv_4)
    fc = layers.Dense(256, activation="relu")(fc)
    fc = layers.Dense(32)(fc)

    return models.Model(inputs=encoder_input, outputs=[conv_4, fc])

In [4]:
"""
Factory Function returns a discriminator for the global dim task.

The task of this discriminator is to give a high score if the inputs belong together.

It does that by concatenating both inputs into one 1D layer and reducing it with dense layers to a single float.

Model Output:
Score (float)
"""


def make_global_discriminator_model():
    conv_4 = layers.Input(shape=(5, 5, 256), name="local_feature_map")
    fc = layers.Input(shape=(32,), name="global_feature_vector")

    flattend_map = layers.Flatten()(conv_4)
    combined = layers.Concatenate()([flattend_map, fc])

    densed = layers.Dense(128, activation="relu")(combined)
    densed = layers.Dense(16, activation="relu")(densed)

    score = layers.Dense(1)(densed)
    return models.Model(inputs=[conv_4, fc], outputs=score)

In [5]:
"""
Factory Function returns a discriminator for the local dim task.

Task of this discriminator is to give a high score for each local feature vector
if the global vector belongs to the same source

This is done by tiling the global vector into a map with the same spatial dimensions as the conv_4 of the encoder.
After that both maps can be concatenated, which appends the global vector to each of the MxM local vectors.

I renamed the conv_4 input to f_map here, because this model has its own conv layers and I wanted to avoid confusion.

Model Output:
Map out of MxM scores (floats)
"""


def make_local_discriminator_model():
    f_map = layers.Input(shape=(5, 5, 256), name="local_feature_map")
    fc = layers.Input(shape=(32,), name="global_feature_vector")

    # Tile the global vector 25 times and reshape to 5x5, so that every
    # spatial position holds a copy of it (I tested that this works)
    repeated = layers.RepeatVector(25)(fc)
    resh = layers.Reshape((5, 5, 32))(repeated)

    combined_map = layers.Concatenate(axis=3)([f_map, resh])

    # again 4 convolutional layers to boil it down to a usable size
    conv_1 = layers.Conv2D(filters=128, kernel_size=1, strides=1, activation="relu")(
        combined_map
    )
    conv_2 = layers.Conv2D(filters=64, kernel_size=1, strides=1, activation="relu")(
        conv_1
    )
    conv_3 = layers.Conv2D(filters=32, kernel_size=1, strides=1, activation="relu")(
        conv_2
    )
    conv_4 = layers.Conv2D(filters=16, kernel_size=1, strides=1, activation="relu")(
        conv_3
    )

    score_map = layers.Conv2D(filters=1, kernel_size=1, strides=1)(conv_4)
    return models.Model(inputs=[f_map, fc], outputs=score_map)
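
To illustrate what the tiling and concatenation do to the shapes, here is a minimal sketch with plain nested lists instead of tensors. The function name `tile_and_concat` and the dummy values are hypothetical; only the sizes (5x5 map, 256 local channels, 32 global features) come from the models above.

```python
def tile_and_concat(feature_map, global_vector):
    """Append the global vector to the channel list of every spatial cell."""
    return [[cell + global_vector for cell in row] for row in feature_map]


M, LOCAL_CHANNELS, GLOBAL_SIZE = 5, 256, 32

# Dummy data: zeros for the local features, ones for the global vector
feature_map = [[[0.0] * LOCAL_CHANNELS for _ in range(M)] for _ in range(M)]
global_vector = [1.0] * GLOBAL_SIZE

combined = tile_and_concat(feature_map, global_vector)
combined_shape = (len(combined), len(combined[0]), len(combined[0][0]))
```

Every one of the 25 positions now carries 256 + 32 = 288 channels, so the first 1x1 convolution of the local discriminator scores each local vector together with the same global vector.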

For the prior matching I used the implementation of [Jehill Parikh](https://jehillparikh.medium.com/deep-info-max-tensorflow-keras-implementation-b1faeffb0260).

In [6]:
"""
Factory Function returns a discriminator for the prior matching part.
Task of this is to predict if the global feature input belongs to an encoded prior
High-Score = encoded prior

Model Output:
Score (float)
"""


def make_prior_matching_discriminator_model():
    fc = layers.Input(shape=(32,), name="global_feature_vector")
    p1 = layers.Dense(32, use_bias=False)(fc)
    p1 = layers.BatchNormalization()(p1)
    p1 = layers.Activation("relu")(p1)
    p1 = layers.Dense(200, use_bias=False)(p1)
    p1 = layers.BatchNormalization()(p1)
    p1 = layers.Activation("relu")(p1)
    p1 = layers.Dense(1)(p1)
    return models.Model(inputs=fc, outputs=p1)

In the next part I define the loss functions. The one for the encoder is especially important for DIM, because it should be easy to swap MI estimators to keep the setup flexible.
For the MI estimator, it is important that a higher score means a higher MI for the encoder.

The discriminators, on the other hand, only need to be simple discriminators.
The authors of the paper describe that they used a DCGAN implementation as the basis for theirs.
So I also used the sample implementation from the [Tensorflow: dcgan tutorial](https://www.tensorflow.org/tutorials/generative/dcgan).

In [7]:
cross_entropy = losses.BinaryCrossentropy(from_logits=True)


# from a tensorflow dcgan tutorial
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
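
As a minimal sketch of what this loss rewards, here is a scalar version in plain Python. It assumes a single real and a single fake score and leaves out the averaging over the batch that `BinaryCrossentropy` performs. The names `softplus`, `bce_from_logits` and `discriminator_loss_scalar` are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_from_logits(label, logit):
    """Binary cross-entropy for one raw score (logit)."""
    # label 1: -log(sigmoid(logit)) = softplus(-logit)
    # label 0: -log(1 - sigmoid(logit)) = softplus(logit)
    return label * softplus(-logit) + (1.0 - label) * softplus(logit)


def discriminator_loss_scalar(real_logit, fake_logit):
    return bce_from_logits(1.0, real_logit) + bce_from_logits(0.0, fake_logit)


confident = discriminator_loss_scalar(real_logit=4.0, fake_logit=-4.0)
undecided = discriminator_loss_scalar(real_logit=0.0, fake_logit=0.0)
wrong = discriminator_loss_scalar(real_logit=-4.0, fake_logit=4.0)
```

The loss is smallest when real pairs get a high score and fake pairs a low one, equals 2 log 2 for an undecided discriminator, and grows when the scores are the wrong way round.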

I used the JSD-based MI estimator mentioned by the authors, as it seems to provide stable results with few epochs.

In [8]:
"""
The result can never be positive; a higher value means the estimated MI is higher.
If the goal is to maximize MI, this needs to be negated for loss functions.
"""


def tensor_jsd(joint_scored, marginals_scored):
    neg_joint_softplus = tf.math.negative(
        tf.math.softplus(tf.math.negative(joint_scored))
    )
    marginals_softplus = tf.math.softplus(marginals_scored)
    return tf.subtract(neg_joint_softplus, marginals_softplus)
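
The following plain-Python sketch mirrors `tensor_jsd` for single scores, to check the two properties stated in the docstring. The names `softplus` and `jsd_estimate` and the example scores are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def jsd_estimate(joint_score, marginal_score):
    """Scalar version of tensor_jsd: -softplus(-joint) - softplus(marginal)."""
    return -softplus(-joint_score) - softplus(marginal_score)


separated = jsd_estimate(joint_score=5.0, marginal_score=-5.0)
undecided = jsd_estimate(joint_score=0.0, marginal_score=0.0)
swapped = jsd_estimate(joint_score=-5.0, marginal_score=5.0)
```

Both terms are negated softplus values, so the estimate stays at or below zero and approaches zero when matching pairs score high and mismatched pairs score low. For a single pair of scores it equals the negated binary cross-entropy that `discriminator_loss` computes, because -log(sigmoid(x)) = softplus(-x).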

In [9]:
"""
The encoder has a different goal for each DIM task.
For the global and local tasks it maximizes the MI.
So it takes the scores from the discriminators and feeds them into an MI estimator.
The authors show in the paper that it's possible to use different estimators.
The only important thing is that a higher score means higher MI.

We want to be able to give the tasks weights, so that we can prioritize one task over another.

We want to maximize the global and local scores, which means minimizing their negation as the loss,
as it becomes minimal for the wanted behaviour.

For the prior matching, on the other hand, we just need to calculate the deviation between the prior and the real results.
I use the KL divergence for that.
I need some extra calculations to make sure the TensorFlow implementation provides stable results.
If the input were negative, the divergence would not work, so watch out.
Also, I'm only interested in the absolute deviation.
A higher deviation means the model is less accurate, so it indicates a higher loss.

Assuming that the MI scores are never positive (see JSD implementation):
global and local higher MI -> less loss
higher KL divergence between both p scores -> more loss
The local scores are summed and normalized so they have the same dimension as the global score.
"""


def get_encoder_loss(
    g_weight,
    g_pos_scored,
    g_neg_scored,
    l_weight,
    l_pos_scored,
    l_neg_scored,
    p_weight,
    p_pos_scored,
    p_neg_scored,
):
    global_loss = tensor_jsd(g_pos_scored, g_neg_scored) * g_weight

    local_reduced = tf.math.reduce_sum(tensor_jsd(l_pos_scored, l_neg_scored), axis=1)
    local_reduced = tf.math.reduce_sum(local_reduced, axis=1)
    local_divided = tf.math.divide(local_reduced, tf.constant([25.0]))
    local_loss = local_divided * l_weight

    prior_shape_orientation = local_loss.shape

    prior_pos_match_dimension = tf.repeat(
        p_pos_scored, prior_shape_orientation[0], axis=0
    )
    prior_pos_match_dimension = tf.reshape(
        prior_pos_match_dimension, prior_shape_orientation
    )
    prior_pos_abs = tf.math.abs(prior_pos_match_dimension)
    prior_neg_match_dimension = tf.reshape(p_neg_scored, prior_shape_orientation)
    prior_neg_abs = tf.math.abs(prior_neg_match_dimension)

    divergence = tf.math.abs(kl(prior_pos_abs, prior_neg_abs))
    prior_loss = tf.ones_like(local_loss) * divergence * p_weight
    return tf.negative(global_loss + local_loss) + prior_loss
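
The prior part depends on how `KLDivergence` treats its inputs. As far as I understand the Keras implementation, both inputs are clipped to the range [epsilon, 1] before `y_true * log(y_true / y_pred)` is summed over the last axis and averaged over the batch. The sketch below reproduces that behaviour in plain Python for one score per sample. The clipping range, the epsilon value and all names are assumptions that should be checked against the installed TensorFlow version.

```python
import math

EPSILON = 1e-7  # assumed to match the Keras backend default


def clip(value, low, high):
    return max(low, min(high, value))


def kl_divergence_sketch(y_true, y_pred):
    """Batch mean of y_true * log(y_true / y_pred), one score per sample."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        t = clip(t, EPSILON, 1.0)
        p = clip(p, EPSILON, 1.0)
        total += t * math.log(t / p)
    return total / len(y_true)


# Hypothetical absolute scores inside the range (0, 1]
inside_range = kl_divergence_sketch([0.8, 0.8, 0.8], [0.2, 0.4, 0.8])

# Magnitudes above 1 are clipped to 1, so these inputs become identical
saturated = kl_divergence_sketch([3.0, 3.0], [7.0, 12.0])
```

If the clipping assumption holds, negative scores would all collapse to epsilon, which explains why the absolute value is taken first. It also means that score magnitudes above 1 no longer contribute to the prior loss.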

This helper function creates negative samples.
A sample is a combination of a local feature map and a global feature vector.
The outputs of the encoder are the positive samples.
To get a negative sample, I need to combine a global feature vector and a feature map from different sources.

In [10]:
def offset_tensor_by_one(tensor):
    """Rotate the batch by one position; the input itself is not changed"""
    # Avoid the name `slice`, which would shadow the Python builtin
    first = tensor[:1]
    rest = tensor[1:]
    return tf.concat([rest, first], axis=0)


def create_negative_samples(real_features):
    """
    Return a pair where the global feature vector
    is not the one that belongs to the feature map
    """
    [convs, fcs] = real_features
    fake_convs = offset_tensor_by_one(convs)
    return [fake_convs, fcs]
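
The same rotation can be sketched with plain lists to see which pairs come out. The helper name `offset_by_one` and the string labels are hypothetical stand-ins for the tensors of one batch.

```python
def offset_by_one(items):
    """Rotate a list by one position without modifying the input."""
    return items[1:] + items[:1]


feature_maps = ["map_0", "map_1", "map_2", "map_3"]
global_vectors = ["vec_0", "vec_1", "vec_2", "vec_3"]

# Each global vector is paired with the feature map of the next image
negative_pairs = list(zip(offset_by_one(feature_maps), global_vectors))
```

Every global vector ends up next to the feature map of another image in the batch. With a batch of size 1 the rotation returns the original pairing, so this trick needs at least two images per batch.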

### The training
Each task has its own goal, but each discriminator receives the same encoded sources.
Prioritization of tasks is done by specifying weights.
These are used by the `get_encoder_loss` function.

It would be easy to adapt each step here, to add goals, or even to alter how they work.
But for that the `get_encoder_loss` function must be adapted as well.

Here is a sketch of the train step:

![DIM Train Step Sketch](./DIM_task_v3.png)

In [11]:
# I want to be able to train multiple models with different settings, that's why I'm defining all of these here
encoder = make_encoder_model()
encoder_optimizer = tf.keras.optimizers.Adam(1e-4)
global_discriminator = make_global_discriminator_model()
global_discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
local_discriminator = make_local_discriminator_model()
local_discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
prior_discriminator = make_prior_matching_discriminator_model()
prior_discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)


@tf.function
def train_step(images, prior_image, global_weight, local_weight, prior_weight):
    # Doing training and calculating loss
    with tf.GradientTape() as enc_tape:
        positive_samples = encoder(images, training=True)
        negative_samples = create_negative_samples(positive_samples)
        prior_sampled = encoder(prior_image, training=True)

        # this is the global DIM task
        with tf.GradientTape() as global_discriminator_tape:
            global_real_output = global_discriminator(positive_samples, training=True)
            global_fake_output = global_discriminator(negative_samples, training=True)
            global_loss = discriminator_loss(global_real_output, global_fake_output)

        # this is the local DIM task
        with tf.GradientTape() as local_discriminator_tape:
            local_real_output = local_discriminator(positive_samples, training=True)
            local_fake_output = local_discriminator(negative_samples, training=True)
            local_loss = discriminator_loss(local_real_output, local_fake_output)

        # this is the prior matching task
        with tf.GradientTape() as prior_discriminator_tape:
            prior_real_output = prior_discriminator(prior_sampled[1], training=True)
            prior_fake_output = prior_discriminator(positive_samples[1], training=True)
            prior_loss = discriminator_loss(prior_real_output, prior_fake_output)

        enc_loss = get_encoder_loss(
            global_weight,
            global_real_output,
            global_fake_output,
            local_weight,
            local_real_output,
            local_fake_output,
            prior_weight,
            prior_real_output,
            prior_fake_output,
        )

    # calculating and applying gradients
    gradients_of_encoder = enc_tape.gradient(enc_loss, encoder.trainable_variables)
    gradients_of_global = global_discriminator_tape.gradient(
        global_loss, global_discriminator.trainable_variables
    )
    gradients_of_local = local_discriminator_tape.gradient(
        local_loss, local_discriminator.trainable_variables
    )
    gradients_of_prior = prior_discriminator_tape.gradient(
        prior_loss, prior_discriminator.trainable_variables
    )

    encoder_optimizer.apply_gradients(
        zip(gradients_of_encoder, encoder.trainable_variables)
    )
    global_discriminator_optimizer.apply_gradients(
        zip(gradients_of_global, global_discriminator.trainable_variables)
    )
    local_discriminator_optimizer.apply_gradients(
        zip(gradients_of_local, local_discriminator.trainable_variables)
    )
    prior_discriminator_optimizer.apply_gradients(
        zip(gradients_of_prior, prior_discriminator.trainable_variables)
    )
    return tf.math.reduce_mean(enc_loss)

### Creating multiple encoders

The training step is kept flexible, which allowed me to create multiple models.
They are saved with their weights in the name, so it's easier to tell them apart.
Sadly, it's not easy to run the training with different encoders within one run, because I used `@tf.function`.

So the approach to run it multiple times is to wait until the first run is finished, then change the weights and run all cells again.
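
One possible way around this limitation, sketched under assumptions and not tested with the DIM models in this notebook, is to create the model, the optimizer and the `@tf.function` together inside a factory function. Every call then returns a fresh traced function that is bound to fresh variables. The names `make_toy_model` and `make_train_step` and the toy regression loss are hypothetical stand-ins for the encoder, the discriminators and their losses.

```python
import tensorflow as tf


def make_toy_model():
    # Stand-in for make_encoder_model(); built eagerly via the Input layer
    return tf.keras.Sequential(
        [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)]
    )


def make_train_step(model, optimizer):
    """Return a new tf.function bound to this model and optimizer."""

    @tf.function
    def train_step(inputs, targets):
        with tf.GradientTape() as tape:
            predictions = model(inputs, training=True)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    return train_step


inputs = tf.ones((8, 4))
targets = tf.zeros((8, 1))

final_losses = []
for learning_rate in (1e-2, 1e-3):
    # Fresh model, optimizer and traced function for every configuration
    model = make_toy_model()
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    train_step = make_train_step(model, optimizer)
    for _ in range(3):
        loss = train_step(inputs, targets)
    final_losses.append(float(loss))
```

Applied to this notebook, the factory would have to create all four models and optimizers and return a `train_step` that closes over them, and `train` would take that function as an argument.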

In [12]:
def train(
    global_weight,
    local_weight,
    prior_weight,
    prior_image,
    train_images,
    batch_size,
    epochs,
):
    train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))
        start_time = time.time()

        # Iterate over the batches of the dataset.
        for step, (image_batch) in enumerate(train_dataset):
            enc_loss_average = train_step(
                global_weight=global_weight,
                local_weight=local_weight,
                prior_weight=prior_weight,
                images=image_batch,
                prior_image=prior_image,
            )

            # Log every 200 batches.
            if step % 200 == 0:
                print(
                    "Training loss (for one batch) at step {}: {}".format(
                        step, float(enc_loss_average)
                    )
                )

        clear_output(wait=True)
        print("------------------------------------")
        print("End of epoch:{}".format(epoch))
        print("Time taken: {}s".format(time.time() - start_time))
        print("------------------------------------")

In [13]:
path_prefix = "./models/dim_encoder-"
saved_models_pathes = []


def save_model(g_weight, l_weight, p_weight):
    model_name = "g{}-l{}-p{}".format(g_weight, l_weight, p_weight).replace(".", "_")
    rel_path = path_prefix + model_name
    encoder.compile()
    encoder.save(rel_path)
    print("done saved as {}".format(rel_path))

In [14]:
prior_image = tf.ones_like(train_images[0:1])
prior_image = tf.math.divide(prior_image, tf.constant([28.0 * 28.0]))

I train multiple encoders to get a better understanding of the influence of the different tasks:

- global_dim: g=1, l=0, p=1
- local_dim: l=1, p=0.1
- mixed_dim: g=0.6, l=0.4, p=0
- complete_dim: g=0.6, l=0.4, p=0.2

In [15]:
g = 1.0
l = 0.0
p = 1.0

"""
For optimization I enabled tf.function for the train step, but because of that it's not possible to train everything at once.
So I build the DIMs one by one, store them and evaluate them in another file.
"""
train(
    global_weight=g,
    local_weight=l,
    prior_weight=p,
    prior_image=prior_image,
    train_images=train_images,
    epochs=5,
    batch_size=128,
)


Start of epoch 0
Training loss (for one batch) at step 0: 1.8513067960739136


In [None]:
save_model(g, l, p)



INFO:tensorflow:Assets written to: ./models/dim_encoder-g1_0-l0_0-p1_0\assets


done saved as ./models/dim_encoder-g1_0-l0_0-p1_0
