Resources:  
http://jmlr.org/papers/volume17/15-239/15-239.pdf  
https://github.com/michetonu/DA-RNN_manoeuver_anticipation  
https://github.com/michetonu/gradient_reversal_keras_tf

In [None]:
import os
import sys
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import inspect
import shutil

import datetime
from sklearn.metrics import confusion_matrix

import tensorflow as tf
import keras
from keras.applications import MobileNetV2
from keras import backend as K, Model
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.layers import Conv2D, pooling, Dropout, Dense, Input, Lambda
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.optimizers import SGD
from keras.engine import Layer

# import stuff in this project
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
if parentdir not in sys.path:
    sys.path.insert(0,parentdir)
from colored_MNIST import colors as all_colors_rgb
    
%load_ext autoreload
%autoreload 2

cv2.__version__

# Parameters

In [None]:
# Folder that receives a small JPEG sample of the colored dataset.
colored_MNIST_output_folder = './colored_MNIST'
# Digit class -> name of the color tied to that digit when bias is applied.
colors = {0: 'dark red', 
          1: 'navy',
          2: 'gold',
          3: 'aqua',
          4: 'indigo',
          5: 'deep pink',
          6: 'chocolate',
          7: 'honeydew',
          8: 'dark violet',
          9: 'beige'
         }
# Reverse lookup: color name -> class index.
colors_inv = {v:k for k,v in colors.items()}
# Class index -> color value taken from the project's color table
# (presumably an RGB triple; verify against colored_MNIST.colors).
colors_rgb = {k: all_colors_rgb[v] for k, v in colors.items()}

In [None]:
# Learning rate handed to compile_model in every experiment below.
INITIAL_LR = 1e-4
# Number of epochs used by every model.fit call below.
EPOCHS = 4
# Probability with which color_MNIST forces the color tied to the digit label
# (otherwise a color is drawn uniformly at random).
BIAS = 0.8

# Dataset creation
We will color the MNIST dataset, and will add the color label to the labels as well.

In [None]:
# Load MNIST through tf.keras and one-hot encode the digit labels.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
y_train_onehot = to_categorical(y_train)
y_test_onehot = to_categorical(y_test)

In [None]:
def color_MNIST(gray_dataset, bias=0, color_noise=False):
    '''Tint grayscale digit images, optionally tying the tint to the digit label.

    The goal is to introduce a color bias into the MNIST dataset.

    gray_dataset: iterable of (image, label) pairs, e.g. zip(images, labels).
        Images are expected to be single-channel np.uint8 0-255 arrays and
        labels must be keys of the module-level `colors` mapping.
    bias: probability with which the color tied to the label is forced.
        Otherwise one of the 10 colors is drawn uniformly at random (which may
        still coincide with the label's color).
    color_noise: accepted for interface compatibility; not used by this function.

    Returns a tuple (images, color_names) of numpy arrays: the tinted images
    cast to np.uint8 and the name of the color applied to each image.
    NOTE(review): the output value range depends on the scale of `colors_rgb`
    (presumably 0-255 components) -- confirm against the color table.
    '''
    tinted_images = []
    color_names = []

    for gray, digit in gray_dataset:
        # Flip the bias coin first; only draw a random color when it fails.
        if np.random.rand() < bias:
            color_idx = digit
        else:
            color_idx = random.choice(range(10))

        color_name = colors[color_idx]
        rgb = colors_rgb[color_idx]

        # Replicate the gray channel into 3 channels and rescale by 1/255.
        tinted = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
        tinted = tinted.astype(np.float32) / 255.

        # Scale every channel by the matching component of the chosen color.
        for channel in range(3):
            tinted[..., channel] = tinted[..., channel] * rgb[channel]

        tinted_images.append(tinted.astype(np.uint8))
        color_names.append(color_name)

    return np.array(tinted_images), np.array(color_names)

In [None]:
# Training set: the digit's own color is forced with probability BIAS.
x_train_color, y_train_color = color_MNIST(zip(x_train, y_train), bias=BIAS)
x_test_color, y_test_color = color_MNIST(zip(x_test, y_test), bias=0)  # here, our real-world test set won't have the bias
x_test_color_ref, y_test_color_ref = color_MNIST(zip(x_test, y_test), bias=BIAS)  # reference test set with the training bias, to check what happens if it does


In [None]:
# Map each color name back to its class index and one-hot encode it.
y_train_color_onehot = to_categorical([colors_inv[c] for c in y_train_color])
# Targets for the two-headed models, keyed by the output layer names
# ('color' and 'number') used when the heads are built.
y_train_multi = {
    'color': y_train_color_onehot,
    'number': y_train_onehot
}

### inspect

In [None]:
# Pick a random index among the first 1000 test samples.
i = np.random.choice(range(1000))

# Show that sample from the unbiased test set with its digit label and color name.
plt.imshow(x_test_color[i])
print(y_test[i], y_test_color[i])

In [None]:
# Same index `i`, but taken from the reference test set colored with the training bias.
plt.imshow(x_test_color_ref[i])
print(y_test[i], y_test_color_ref[i])

In [None]:
# Write a small sample of the colored images to disk as JPEGs,
# laid out as <output_folder>/<train|test>/<digit>/<index>.jpg.

# Start from a clean folder. ignore_errors avoids a FileNotFoundError on the
# very first run, when the folder does not exist yet.
shutil.rmtree(colored_MNIST_output_folder, ignore_errors=True)
for f in ['train', 'test']:
    for label in range(10):
        os.makedirs(os.path.join(colored_MNIST_output_folder, f, str(label)), exist_ok=True)

# Number of images exported per split.
N = 200

for split, images, labels in (('train', x_train_color, y_train),
                              ('test', x_test_color, y_test)):
    for i in range(N):
        path = os.path.join(colored_MNIST_output_folder, split, str(labels[i]), str(i) + '.jpg')
        # cv2.imwrite expects BGR channel order while our images are RGB;
        # convert so the colors on disk match what plt.imshow displays.
        ret = cv2.imwrite(path, cv2.cvtColor(images[i], cv2.COLOR_RGB2BGR))
        if not ret:
            raise IOError('could not write ' + path)

# Modeling

### Network creation

In [None]:
def simple_base_classifier_2(input_shape=(None,None,3), n_conv=3,
                             init_filter_size=10, dropout_rate=0.10,
                             conv1x1_filters=None, include_top=False,
                             hidden_units=None, n_classes=None):
    """Build a small convolutional feature extractor, optionally with a dense top.

    input_shape: shape passed to the Keras Input layer.
    n_conv: number of 3x3 ReLU convolutions; layer i uses
        init_filter_size * 2**i filters.
    init_filter_size: number of filters of the first convolution.
    dropout_rate: if truthy, a Dropout with this rate follows every convolution.
    conv1x1_filters: if truthy, a 1x1 ReLU convolution with this many filters
        is appended after the convolutional stack.
    include_top: if True, dense layers are appended after the global average
        pooling.
    hidden_units: a single unit count, a list/tuple of unit counts, or None for
        no hidden dense layers. Only used when include_top is True.
    n_classes: width of the final sigmoid Dense layer; required when
        include_top is True.

    Returns a keras Model mapping the input to the pooled features (or to the
    top layer's output when include_top is True).

    Raises ValueError if include_top is True and n_classes is None.
    """
    if include_top and n_classes is None:
        raise ValueError('n_classes must be given when include_top=True')

    x = inp = Input(shape=input_shape)

    for i in range(n_conv):
        x = Conv2D(filters=init_filter_size * (2**i), kernel_size=3, activation='relu')(x)
        if dropout_rate:
            x = Dropout(dropout_rate)(x)
        # Downsample between convolutions, but not after the last one.
        if i < n_conv - 1:
            x = pooling.MaxPool2D()(x)
    if conv1x1_filters:
        x = Conv2D(filters=conv1x1_filters, kernel_size=1, activation='relu')(x)

    x = pooling.GlobalAvgPool2D()(x)

    if include_top:
        # Normalize hidden_units to a sequence. None (the default) means
        # "no hidden layers" instead of building Dense(None), which fails.
        if hidden_units is None:
            hidden_units = ()
        elif not isinstance(hidden_units, (list, tuple)):
            hidden_units = (hidden_units,)
        for units in hidden_units:
            x = Dense(units, activation='relu')(x)
        x = Dense(n_classes, activation="sigmoid")(x)
    return Model(inp, x)


def get_multihead_branch(inputs, num_classes, final_act, l2_norm=False,
                         branch_name=None, dense=True):
    """Attach one classification head to `inputs` and return its output tensor.

    inputs: tensor the head is built on.
    num_classes: number of units of the final Dense layer.
    final_act: activation of the final Dense layer; also used as the layer
        name when branch_name is falsy.
    l2_norm: accepted for interface compatibility; not used by this function.
    branch_name: name given to the final Dense layer.
    dense: if True, a 20-unit ReLU Dense layer precedes the final layer.
    """
    features = Dense(20, activation='relu')(inputs) if dense else inputs
    head_name = branch_name or final_act
    return Dense(num_classes, activation=final_act, name=head_name)(features)


def get_multitask_network(backbone=MobileNetV2, num_classes=10):
    """Build a two-headed model ('number' and 'color') on top of `backbone`.

    backbone: object whose `.input` and `.output` are read directly.
        NOTE(review): the default is the MobileNetV2 factory itself, not a
        built model -- callers presumably have to pass an instantiated model.
    num_classes: number of classes of both softmax heads.

    The 'number' head has a hidden Dense layer, the 'color' head does not.
    The model's outputs are listed in the order [number, color].
    """
    embedding = backbone.output
    heads = [
        get_multihead_branch(embedding, num_classes, final_act='softmax',
                             branch_name='number', l2_norm=False),
        get_multihead_branch(embedding, num_classes, final_act='softmax',
                             branch_name='color', dense=False),
    ]
    return Model(backbone.input, heads, name='number_color_model')

def get_numberonly_network(backbone=MobileNetV2, num_classes=10):
    """Build a single-headed model that only predicts the digit ('number').

    backbone: object whose `.input` and `.output` are read directly.
        NOTE(review): the default is the MobileNetV2 factory itself, not a
        built model -- callers presumably have to pass an instantiated model.
    num_classes: number of classes of the softmax head.
    """
    number_head = get_multihead_branch(backbone.output, num_classes, final_act='softmax',
                                       branch_name='number', l2_norm=False)
    return Model(backbone.input, [number_head], name='number_model')

### training logistics

In [None]:
def log_save_callbacks(name, log=True, save=True):
    """Return the list of Keras callbacks for checkpointing and/or TensorBoard logging.

    name: run name; used as the checkpoint file stem and the TensorBoard sub-folder.
    log: if True, add a TensorBoard callback writing to
        ../model_training/tensorboard/<name> (the folder is created).
    save: if True, add a ModelCheckpoint (save_best_only) writing to
        ../model_training/modelcheckpoints/<name>.hdf5 (the folder is created).

    A fresh list is returned on every call; it is empty when both flags are False.
    """
    root = "../model_training/"
    result = []

    if save:
        checkpoint_dir = os.path.join(root, "modelcheckpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_file = os.path.join(checkpoint_dir, name + ".hdf5")
        result.append(ModelCheckpoint(checkpoint_file, save_best_only=True))

    if log:
        tensorboard_dir = os.path.join(root, "tensorboard", name)
        os.makedirs(tensorboard_dir, exist_ok=True)
        result.append(TensorBoard(tensorboard_dir))

    return result


def compile_model(model, name, loss, loss_weights=None, initial_lr=1e-5, monitor='val_loss'):
    """Compile `model` with SGD and build the callbacks to pass to fit.

    model: Keras model to compile (compiled in place and also returned).
    name: run name; a timestamp is appended for the TensorBoard log folder.
    loss: loss passed to model.compile.
    loss_weights: optional loss weights passed to model.compile.
    initial_lr: learning rate of the SGD optimizer.
    monitor: metric watched by EarlyStopping and ReduceLROnPlateau. The default
        'val_loss' matches the Keras default and is only available when fit
        receives validation data; pass e.g. 'loss' when training without it.

    Returns (model, callbacks).
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    # Only TensorBoard logging is enabled here; checkpoints are not written.
    callbacks = log_save_callbacks(name=name + '_' + timestamp, log=True, save=False)
    callbacks += [EarlyStopping(monitor=monitor, patience=20, verbose=1, restore_best_weights=True)]
    callbacks += [ReduceLROnPlateau(monitor=monitor, verbose=1, factor=0.2, patience=10)]

#     optimizer = Adam(lr=initial_lr)
    optimizer = SGD(initial_lr)
    model.compile(optimizer, loss=loss, metrics=['acc'], loss_weights=loss_weights)
    return model, callbacks

### evaluation

In [None]:
def evaluate_results(y_true, y_pred, all_labels):
    """Plot the confusion matrix and print accuracy plus per-class precision/recall.

    y_true: true class labels (array-like).
    y_pred: predicted class labels (array-like, same length as y_true).
    all_labels: iterable of all class labels; fixes the row/column order of
        the confusion matrix and the tick labels.

    Precision or recall is reported as nan for a class that was never
    predicted or never occurs, respectively.
    """
    # Work on arrays so the element-wise comparison below also holds for lists.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    all_labels = list(all_labels)

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=all_labels)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=all_labels, yticklabels=all_labels,
           ylim=(len(all_labels)-0.5, -0.5),
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate every cell with its count; switch to white text on dark cells.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    plt.show()

    # Compute accuracy
    acc = np.mean(y_true == y_pred)
    print(f'Accuracy: {acc*100.:.2f}%')
    print('Per class statistics:')
    # Compute statistics per class
    for i, label in enumerate(all_labels):
        predicted_total = np.sum(cm[:, i])  # column sum: times this class was predicted
        actual_total = np.sum(cm[i])        # row sum: times this class actually occurs
        tp = cm[i][i]
        # Guard against 0/0 for classes that were never predicted / never occur.
        prec = tp / predicted_total if predicted_total else float('nan')
        rec = tp / actual_total if actual_total else float('nan')
        print(f'  - {label}: precision {prec*100.:.2f}%, recall {rec*100.:.2f}%')

### train: number only, confirm bias

In [None]:
# Baseline: a small conv backbone (no dropout) with a single 'number' head,
# trained on the color-biased training set.
simplenet = simple_base_classifier_2(input_shape=(28, 28, 3), n_conv=3, init_filter_size=10, dropout_rate=0)
model = get_numberonly_network(backbone=simplenet)

loss = 'categorical_crossentropy'
loss_weights = {"number": 1}
model, callbacks = compile_model(model, 'mnist_simplenet_number',
                                 loss, loss_weights=loss_weights,  # or loss_weights=None
                                 initial_lr=INITIAL_LR)

# No validation data is passed to fit here.
model.fit(x_train_color, y_train_onehot, epochs=EPOCHS, callbacks=callbacks)    

In [None]:
# consistent colored as in dataset
# Evaluate on the reference test set, colored with the same bias as the training set.
y_pred_ref = np.argmax(model.predict(x_test_color_ref), axis=1)
# y_pred_ref = [colors[i] for i in y_pred_ref]
evaluate_results(y_test, y_pred_ref, range(10))

In [None]:
# no consistent coloring in the testset, contrary to the trainingset
# Evaluate on the test set colored without bias (bias=0), unlike the training set.
y_pred = np.argmax(model.predict(x_test_color), axis=1)
evaluate_results(y_test, y_pred, range(10))

Our model has become worse at predicting the number.

### train: both number and color, flip color loss negative

In [None]:
# Fresh backbone with two heads: 'number' and 'color'.
simplenet = simple_base_classifier_2(input_shape=(28, 28, 3), n_conv=3, init_filter_size=10, dropout_rate=0)
model_debiased = get_multitask_network(backbone=simplenet)

In [None]:
# Same loss for both heads; the 'color' loss gets a negative weight so that
# minimizing the total loss pushes the color loss up.
loss = 'categorical_crossentropy'
negative_color_weights = {
    "number": 1, 
    "color": -.1
}

model_debiased, callbacks = compile_model(model_debiased, 'mnist_simplenet_number_color', loss, 
                                          loss_weights=negative_color_weights, initial_lr=INITIAL_LR)
# Targets are passed as a dict keyed by output layer name.
model_debiased.fit(x_train_color, y_train_multi, epochs=EPOCHS, callbacks=callbacks)

In [None]:
# consistent colored testset
# Evaluate the 'number' head on the reference test set, colored with the training bias.
y_pred_number_ref, y_pred_color_ref = model_debiased.predict(x_test_color_ref)
y_pred_number_ref = np.argmax(y_pred_number_ref, axis=1)
evaluate_results(y_test, y_pred_number_ref, range(10))

We can predict numbers pretty well on the reference dataset, except for 4 and 8, which are confused with each other (indigo and dark violet are close to each other, so the color bias hurts results on the test set even if we color it consistently).

Now, let's predict on the inconsistently colored test set.

In [None]:
# no consistent coloring in the testset, contrary to the trainingset
# Evaluate the 'number' head on the test set colored without bias, unlike the training set.
# The model was built with outputs [number, color] (see get_multitask_network), and the
# predictions are unpacked in that order -- presumably Keras keeps the output list order; confirm.
y_pred_number, y_pred_color = model_debiased.predict(x_test_color)
y_pred_number = np.argmax(y_pred_number, axis=1)
evaluate_results(y_test, y_pred_number, range(10))

It seems a negative loss weight for 'color' is not enough... We are probably just scrambling our last layer.

### train: Gradient Reversal
In https://github.com/feidfoe/learning-not-to-learn/blob/master/trainer.py we see the authors train with a minimax game and gradient reversal.  
  
In essence, this means the head for color still tries to extract color info from the shared embedding, but during backprop we flip the gradient between the start of the color head and the embedding layer, meaning the shared weights move away from allowing encoding color information.  
  
Let's try the gradient reversal.

In [None]:
def reverse_gradient(X, hp_lambda):
    '''Return an identity of X whose gradient is negated and scaled by hp_lambda.

    Every call registers a new gradient function under a unique name
    ("GradientReversal<n>", numbered by a counter stored on this function) and
    creates a tf.identity op inside a gradient override for 'Identity' on the
    graph of the current Keras session.
    '''
    # Bump the per-function call counter; it starts at 1 on the first call.
    call_index = getattr(reverse_gradient, 'num_calls', 0) + 1
    reverse_gradient.num_calls = call_index
    override_name = "GradientReversal%d" % call_index

    @tf.RegisterGradient(override_name)
    def _flip_gradients(op, grad):
        return [tf.negative(grad) * hp_lambda]

    graph = K.get_session().graph
    with graph.gradient_override_map({'Identity': override_name}):
        flipped = tf.identity(X)

    return flipped

class GradientReversal(Layer):
    '''Keras layer that flips the sign of the gradient during training.

    The layer's output is produced by reverse_gradient(x, hp_lambda): an
    identity of the input whose gradient is negated and multiplied by hp_lambda.
    The layer has no trainable weights.
    '''
    def __init__(self, hp_lambda, **kwargs):
        '''hp_lambda: factor applied to the negated gradient.'''
        super(GradientReversal, self).__init__(**kwargs)
        self.supports_masking = False
        self.hp_lambda = hp_lambda

    def build(self, input_shape):
        self.trainable_weights = []

    def call(self, x, mask=None):
        return reverse_gradient(x, self.hp_lambda)

    def compute_output_shape(self, input_shape):
        # Keras 2 name of the shape-inference hook; the output shape equals the input shape.
        return input_shape

    def get_output_shape_for(self, input_shape):
        # Keras 1 name of the same hook, kept for backward compatibility.
        return self.compute_output_shape(input_shape)

    def get_config(self):
        # Add hp_lambda to the base layer config.
        config = {'hp_lambda': self.hp_lambda}
        base_config = super(GradientReversal, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [None]:
def get_multihead_branch_gradrev(inputs, num_classes, final_act, l2_norm=False,
                         branch_name=None, dense=True, reverse_grad=False, hp_lambda=.1):
    """Attach one classification head to `inputs`, optionally behind a gradient reversal.

    inputs: tensor the head is built on.
    num_classes: number of units of the final Dense layer.
    final_act: activation of the final Dense layer; also used as the layer
        name when branch_name is falsy.
    l2_norm: accepted for interface compatibility; not used by this function.
    branch_name: name given to the final Dense layer.
    dense: if True, a 20-unit ReLU Dense layer precedes the final layer.
    reverse_grad: if True, a GradientReversal layer is inserted first, so the
        layers producing `inputs` are penalized for encoding information this
        head can use to predict its output.
    hp_lambda: factor handed to the GradientReversal layer.
    """
    features = inputs
    if reverse_grad:
        features = GradientReversal(hp_lambda)(features)
        print(branch_name, ' reverse grad from here')

    if dense:
        features = Dense(20, activation='relu')(features)

    head_name = branch_name or final_act
    return Dense(num_classes, activation=final_act, name=head_name)(features)


def get_multitask_network_gradflip(backbone=MobileNetV2, num_classes=10):
    """Build a two-headed model whose 'color' head sits behind a gradient reversal.

    backbone: object whose `.input` and `.output` are read directly.
        NOTE(review): the default is the MobileNetV2 factory itself, not a
        built model -- callers presumably have to pass an instantiated model.
    num_classes: number of classes of both softmax heads.

    The 'number' head is a single Dense layer without gradient reversal. The
    'color' head uses gradient reversal (hp_lambda=.1) followed by the default
    hidden Dense layer. The model's outputs are listed in the order [number, color].
    """
    embedding = backbone.output

    number_head = get_multihead_branch_gradrev(
        embedding, num_classes, final_act='softmax',
        branch_name='number', reverse_grad=False, dense=False)
    color_head = get_multihead_branch_gradrev(
        embedding, num_classes, final_act='softmax',
        branch_name='color', reverse_grad=True, hp_lambda=.1)

    return Model(backbone.input, [number_head, color_head], name='number_color_gradflip')

In [None]:
# Fresh backbone with a plain 'number' head and a gradient-reversed 'color' head.
simplenet = simple_base_classifier_2(input_shape=(28, 28, 3), n_conv=3, init_filter_size=10, dropout_rate=0)
model_gradflip = get_multitask_network_gradflip(simplenet)

# Both losses get a positive weight here; the sign flip happens inside the
# GradientReversal layer rather than in the loss weights.
loss = 'categorical_crossentropy'
loss_weights = {
    "number": 1, 
    "color": 1
}

# NOTE(review): the run name is the same as in the negative-loss-weight experiment;
# only the timestamp appended by compile_model tells the TensorBoard logs apart.
model_gradflip, callbacks = compile_model(model_gradflip, 'mnist_simplenet_number_color', loss, 
                                          loss_weights=loss_weights, initial_lr=INITIAL_LR)
model_gradflip.fit(x_train_color, y_train_multi, epochs=EPOCHS, callbacks=callbacks)

In [None]:
# consistent colored testset
# Evaluate the 'number' head on the reference test set, colored with the training bias.
y_pred_number_ref, y_pred_color_ref = model_gradflip.predict(x_test_color_ref)
y_pred_number_ref = np.argmax(y_pred_number_ref, axis=1)
evaluate_results(y_test, y_pred_number_ref, range(10))

With hp_lambda=1 we see that 4 is always predicted as 8, so there is prediction based on colors going on.

Now, let's predict on the inconsistently colored test set.

In [None]:
# no consistent coloring in the testset, contrary to the trainingset
# Evaluate the 'number' head on the test set colored without bias, unlike the training set.
# The model was built with outputs [number, color] (see get_multitask_network_gradflip), and the
# predictions are unpacked in that order -- presumably Keras keeps the output list order; confirm.
y_pred_number, y_pred_color = model_gradflip.predict(x_test_color)
y_pred_number = np.argmax(y_pred_number, axis=1)
evaluate_results(y_test, y_pred_number, range(10))

Similar results, we are getting somewhere!

# Todo: 
* The bias must contain noise. Without noise, the number itself of course encodes the color, so the gradflip layer doesn't matter: the color branch by definition has the same info as the number branch, which is trained to predict the number. A noisy bias ensures that there is indeed color info in the embedding, **separate** from the number info.
* Confirm simplenet can train on structure alone by training it on the original MNIST.  
* Use a more powerful color branch (more dense layers) than the number branch, so that simply hiding the color info in the embedding isn't enough.