In [2]:
import datetime
import os

import numpy as np
import tensorflow as tf
#%load_ext tensorboard

In [3]:
#@title Mutual Information Functions

# These functions take LOG-probabilities

# Samples have shape (n, C) where n is the number of samples and C is the number of classes
def joint(samples_x, samples_y, samples_z):
    """Estimate the joint log-distribution of three variables from samples.

    Each argument holds per-sample LOG-probabilities with the sample index on
    axis 0 and the class index on axis 1. For every sample the three
    log-probability vectors are combined by broadcasting (an outer sum in log
    space), then averaged over samples with a log-sum-exp.

    Returns:
        A rank-3 tensor of log-probabilities indexed as [x_class, y_class, z_class].
    """
    # Place each variable's class axis on its own dimension so the sum broadcasts
    # to (n, Cx, Cy, Cz).
    x_term = samples_x[:, :, tf.newaxis, tf.newaxis]
    y_term = samples_y[:, tf.newaxis, :, tf.newaxis]
    z_term = samples_z[:, tf.newaxis, tf.newaxis, :]
    per_sample = x_term + y_term + z_term

    # log(mean over samples) = logsumexp over samples - log(n)
    log_total = tf.math.reduce_logsumexp(per_sample, axis=0)
    log_n = tf.math.log(tf.cast(samples_x.shape[0], tf.float32))
    return log_total - log_n

# Computes I(X ; Y)
# p_xy has shape (C, C)
def I_XY(p_xy):
    """Mutual information I(X ; Y) in bits from a joint LOG-probability table.

    Args:
        p_xy: rank-2 tensor of log-probabilities, X on axis 0 and Y on axis 1.

    Returns:
        Scalar tensor: sum over all cells of p(x, y) * log2(p(x, y) / (p(x) p(y))).
    """
    # Log-marginals, kept 2-D so they broadcast against the joint table.
    log_px = tf.math.reduce_logsumexp(p_xy, 1, keepdims=True)
    log_py = tf.math.reduce_logsumexp(p_xy, 0, keepdims=True)

    # Pointwise log-ratio, clipped so infinite values cannot propagate.
    log_ratio = tf.clip_by_value(p_xy - log_px - log_py, -1e20, 1e20)

    # Dividing by ln(2) converts nats to bits.
    weighted = tf.math.exp(p_xy) * log_ratio / np.log(2)
    return tf.reduce_sum(weighted)

# Computes I(X ; Y | Z)
# p_xyz has shape (C, C, C)
def I_XY_Z(p_xyz):
    """Conditional mutual information I(X ; Y | Z) in bits.

    Args:
        p_xyz: rank-3 tensor of joint LOG-probabilities indexed [x, y, z].

    Returns:
        Scalar tensor: sum over all cells of
        p(x, y, z) * log2(p(x, y | z) / (p(x | z) p(y | z))).
    """
    # log p(z), kept 3-D so it broadcasts against the joint table.
    log_pz = tf.math.reduce_logsumexp(p_xyz, [0, 1], keepdims=True)

    # Conditionals given z: subtract log p(z) from the matching joint/marginal.
    log_pxy_given_z = p_xyz - log_pz
    log_px_given_z = tf.math.reduce_logsumexp(p_xyz, 1, keepdims=True) - log_pz
    log_py_given_z = tf.math.reduce_logsumexp(p_xyz, 0, keepdims=True) - log_pz

    # Pointwise log-ratio, clipped so infinite values cannot propagate.
    log_ratio = tf.clip_by_value(
        log_pxy_given_z - log_px_given_z - log_py_given_z, -1e10, 1e10)

    # Dividing by ln(2) converts nats to bits.
    weighted = tf.math.exp(p_xyz) * log_ratio / np.log(2)
    return tf.reduce_sum(weighted)
    

In [4]:
#@title Network Architectures

# Pre-activation version of ResNet
def resnetBlock(inputs, filters, stride):
    """Build one pre-activation residual block (BN -> ReLU -> conv, twice).

    Args:
        inputs: Keras tensor with channels on axis 1 (BatchNormalization uses axis=1).
        filters: number of output channels of both 3x3 convolutions.
        stride: stride of the first 3x3 convolution and of the shortcut projection.

    Returns:
        The sum of the residual branch and the shortcut.
    """
    layers = tf.keras.layers

    pre_activated = layers.BatchNormalization(axis=1)(inputs)
    pre_activated = layers.Activation('relu')(pre_activated)

    # Identity shortcut unless the block downsamples; then a strided 1x1
    # convolution of the pre-activated tensor is used instead.
    if stride == 1:
        shortcut = inputs
    else:
        shortcut = layers.Conv2D(filters, 1, stride, 'same')(pre_activated)

    residual = layers.Conv2D(filters, 3, stride, 'same')(pre_activated)
    residual = layers.BatchNormalization(axis=1)(residual)
    residual = layers.Activation('relu')(residual)
    residual = layers.Conv2D(filters, 3, 1, 'same')(residual)

    return layers.add([residual, shortcut])

def resnet18(inputs, C, n_per_block=2, n_layers=4):
    """Build a ResNet-style classifier head-to-tail and return its logits.

    Args:
        inputs: Keras input tensor.
        C: number of output units of the final Dense layer.
        n_per_block: number of residual blocks per stage.
        n_layers: how many of the four stages (64, 128, 256, 512 filters) to use.

    Returns:
        Keras tensor with C outputs (no softmax applied).
    """
    layers = tf.keras.layers

    # (stride of the first block, number of blocks, filters) for each stage.
    stage_specs = [
        (1, n_per_block, 64),
        (2, n_per_block, 128),
        (2, n_per_block, 256),
        (2, n_per_block, 512),
    ]

    features = layers.Conv2D(64, 3, 1, 'same')(inputs)
    for first_stride, block_count, width in stage_specs[:n_layers]:
        # Only the first block of a stage may downsample; the rest use stride 1.
        block_strides = [first_stride] + [1] * (block_count - 1)
        for block_stride in block_strides:
            features = resnetBlock(features, width, block_stride)

    # With no residual stages there is no activation after the stem conv, so add one.
    if n_layers == 0:
        features = layers.Activation('relu')(features)

    #features = layers.GlobalAveragePooling2D()(features)
    features = layers.AveragePooling2D(4, padding='valid')(features)
    features = layers.Flatten()(features)
    logits = layers.Dense(C)(features)
    return logits
    

def simpleCNN(inputs, C, n_layers=3):
    """Build a small CNN of `n_layers` stacked conv stages and return its logits.

    Each stage is Conv(32, 3x3, stride 1) -> BN -> ReLU -> Conv(64, 3x3, stride 2)
    -> BN -> ReLU, so every stage halves the spatial resolution. The stages are
    followed by 4x4 average pooling, flattening and a Dense(C) layer.

    Previously every stage was applied to `inputs` rather than to the output of
    the preceding stage, so only the last stage was connected to the result and
    `n_layers` had no effect (and `n_layers=0` raised NameError). The stages are
    now chained.

    Args:
        inputs: Keras input tensor with channels on axis 1.
        C: number of output units of the final Dense layer.
        n_layers: number of conv stages to stack.

    Returns:
        Keras tensor with C outputs (no softmax applied).
    """
    x = inputs
    for _ in range(n_layers):
        x = tf.keras.layers.Conv2D(32, 3, 1, 'same')(x)
        x = tf.keras.layers.BatchNormalization(axis=1)(x)
        x = tf.keras.layers.Activation('relu')(x)
        x = tf.keras.layers.Conv2D(64, 3, 2, 'same')(x)
        x = tf.keras.layers.BatchNormalization(axis=1)(x)
        x = tf.keras.layers.Activation('relu')(x)
    # NOTE(review): the 4x4 'valid' pooling needs a feature map of at least 4x4,
    # which bounds n_layers for a given input size.
    x = tf.keras.layers.AveragePooling2D(4, padding='valid')(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(C)(x)
    return x


def simpleCNNGal(inputs, C):
    """Build a 4-conv / 2-dense ReLU network and return its C logits.

    Four 3x3 convolutions with 32 filters (the first two with stride 2) are
    followed by two 512-unit dense layers and a final Dense(C) layer.
    """
    layers = tf.keras.layers

    hidden = inputs
    # Two downsampling convolutions followed by two stride-1 convolutions.
    for conv_stride in (2, 2, 1, 1):
        hidden = layers.Conv2D(32, 3, conv_stride, 'same', activation='relu')(hidden)

    hidden = layers.Flatten()(hidden)
    for _ in range(2):
        hidden = layers.Dense(512, activation='relu')(hidden)

    return layers.Dense(C)(hidden)

def linear(x, C):
    """Linear classifier: flatten the input and apply a single Dense(C) layer."""
    flattened = tf.keras.layers.Flatten()(x)
    logits = tf.keras.layers.Dense(C)(flattened)
    return logits

def small(x, C):
    """One-hidden-layer MLP: flatten, Dense(10) with ReLU, then Dense(C) logits."""
    flattened = tf.keras.layers.Flatten()(x)
    hidden = tf.keras.layers.Dense(10, activation='relu')(flattened)
    logits = tf.keras.layers.Dense(C)(hidden)
    return logits
        
        

In [5]:
#@title Training and Evaluation Loops

def transform(x, training=True):
    """Standardize a batch of channels-first images, with optional augmentation.

    The batch is transposed to channels-last for the `tf.image` ops and
    transposed back before returning.

    Args:
        x: batch of images laid out as (batch, channels, height, width).
        training: if True, pad/crop to 40x40, take a random 32x32x3 crop and
            apply a random horizontal flip before standardizing.

    Returns:
        The per-image standardized batch, channels-first.
    """
    x = tf.transpose(x, perm=[0, 2, 3, 1])
    if training:
        # TF 2.x names this op `resize_with_crop_or_pad`; older releases only
        # provide `resize_image_with_crop_or_pad`. Support both.
        pad_or_crop = getattr(tf.image, 'resize_with_crop_or_pad', None)
        if pad_or_crop is None:
            pad_or_crop = tf.image.resize_image_with_crop_or_pad
        x = pad_or_crop(x, 40, 40)
        # The crop size includes the batch dimension, so the whole batch is kept.
        x = tf.image.random_crop(x, [x.shape[0], 32, 32, 3])
        x = tf.image.random_flip_left_right(x)
    x = tf.image.per_image_standardization(x)
    x = tf.transpose(x, perm=[0, 3, 1, 2])
    return x

# n is number of samples in the dataset, C is number of classes

def evaluate_simple(C, data, model, model_simple, hard=True):
    """Estimate mutual-information quantities between labels and two models.

    Runs both models over `data` in batches of 1000 (without augmentation),
    collects per-sample log-probabilities for the true label Y, the full model
    M and the simple model L, builds their joint distribution with `joint`,
    and evaluates `I_XY` / `I_XY_Z` on it. Results are also printed.

    All forward passes are made with `training=False` so that layers such as
    BatchNormalization run in inference mode for both models and both modes.

    Args:
        C: number of classes (depth of the one-hot encodings).
        data: dataset yielding (x, y) pairs; y is squeezed along axis 1.
        model: the full model (M).
        model_simple: the simple model (L).
        hard: if True, each model's argmax prediction is one-hot encoded;
            otherwise the log-softmax of its outputs is used.

    Returns:
        Tuple (I_model_true_simple, I_model_true, I_model_simple):
        - I_XY_Z on the joint indexed [true, model, simple]  (printed as I(M ; Y | L)),
        - I_XY on the joint with the simple axis summed out  (printed as I(M ; Y)),
        - I_XY_Z on the joint transposed to [simple, true, model], i.e.
          conditioned on the full model's prediction.
    """
    print('Testing mutual information')

    batches = data.batch(1000)

    y_true = []
    y_model = []
    y_simple = []
    for x, y in batches:

        x = transform(x, training=False)

        # log(one_hot) gives 0 for the true class and -inf elsewhere.
        y_true += [tf.math.log(tf.one_hot(tf.squeeze(y, axis=1), C))]

        logits_model = model(x, training=False)
        logits_simple = model_simple(x, training=False)
        if hard:
            y_model += [tf.math.log(tf.one_hot(tf.math.argmax(logits_model, axis=1), C))]
            y_simple += [tf.math.log(tf.one_hot(tf.math.argmax(logits_simple, axis=1), C))]
        else:
            y_model += [tf.nn.log_softmax(logits_model)]
            y_simple += [tf.nn.log_softmax(logits_simple)]

    y_true = tf.concat(y_true, 0)
    y_model = tf.concat(y_model, 0)
    y_simple = tf.concat(y_simple, 0)

    # Joint log-probabilities indexed [true, model, simple].
    p_joint = joint(y_true, y_model, y_simple)
    I_model_true_simple = I_XY_Z(p_joint)

    # Sum out the simple model's axis to get the (true, model) table.
    I_model_true = I_XY(tf.math.reduce_logsumexp(p_joint, axis=2))
    # Reorder axes to [simple, true, model] so the conditioning variable is the model.
    I_model_simple = I_XY_Z(tf.transpose(p_joint, perm=[2, 0, 1]))

    # Sum out the full model's axis to get the (true, simple) table.
    I_simple_true = I_XY(tf.math.reduce_logsumexp(p_joint, axis=1))
    print(f'I(L ; Y) = {I_simple_true}')

    print(f'I(M ; Y) = {I_model_true}, '
          f'I(M ; Y | L) = {I_model_true_simple}')

    print('Done')

    return I_model_true_simple, I_model_true, I_model_simple
        
def train_simple(n, C, data, model, model_simple, epochs, batch_size, learning_rate, savename=None, save_frequency=100):
    """Train `model_simple` to minimise I_XY_Z of the (label, model, simple) joint.

    For every batch the true labels (one-hot, in log space), the frozen full
    model's log-softmax and the simple model's log-softmax are combined with
    `joint`, and `I_XY_Z` of that joint is used as the loss. Only
    `model_simple`'s variables are updated.

    Args:
        n: dataset size; used as the shuffle buffer size and for the progress percentage.
        C: number of classes.
        data: dataset yielding (x, y) pairs.
        model: full model, evaluated with `training=False` and never updated here.
        model_simple: model being trained; called with `training=True` so that
            training-dependent layers (e.g. BatchNormalization) run in training mode.
        epochs: number of passes over `data`.
        batch_size: batch size.
        learning_rate: accepted for interface compatibility but currently unused;
            the optimizer is Adam with a fixed learning rate of 1e-4.
        savename: if not None, path passed to `model_simple.save_weights`.
        save_frequency: save every this many epochs; None or 0 disables saving.
    """
    #optimizer = tf.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
    optimizer = tf.optimizers.Adam(learning_rate=1e-4)

    avg_loss = tf.metrics.Mean()

    # Guard against save_frequency being None/0 before taking the modulo.
    saving_enabled = savename is not None and bool(save_frequency)

    for epoch in range(epochs):

        if saving_enabled and epoch % save_frequency == 0:
            model_simple.save_weights(savename)

        print(f'Training simple model: epoch {epoch}')
        batches = data.shuffle(n).batch(batch_size)

        for i, (x, y) in enumerate(batches):

            x = transform(x, training=True)

            y_true = tf.math.log(tf.one_hot(tf.squeeze(y, axis=1), C))
            # The full model is a fixed reference: computed outside the tape.
            y_model = tf.nn.log_softmax(model(x, training=False))
            with tf.GradientTape() as tape:
                y_simple = tf.nn.log_softmax(model_simple(x, training=True))
                p_joint = joint(y_true, y_model, y_simple)
                #p_joint = joint(y_simple, y_true, y_model)
                loss = I_XY_Z(p_joint)
                #loss += I_XY_Z(tf.transpose(p_joint, perm=[0, 2, 1]))
            gradients = tape.gradient(loss, model_simple.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model_simple.trainable_variables))

            avg_loss(loss)

            if (i + 1) % 10 == 0:
                tf.print(f'Epoch {epoch}, '
                         f'{i * batch_size/n * 100:.1f}%, '
                         f'Loss {avg_loss.result():.6f}')

        avg_loss.reset_states()

    print('Done')
    
def train_imitate(n, C, data_train, data_test, model, model_simple,
                 epochs=200, batch_size=128, learning_rate=0.1):
    """Train `model_simple` to reproduce the hard predictions of `model`.

    The full model's argmax prediction on each augmented batch is used as the
    target of a sparse softmax cross-entropy loss on `model_simple`'s outputs.
    After every epoch the simple model's accuracy on `data_test` (against the
    true labels) is printed.

    Args:
        n: dataset size; used as the shuffle buffer size and for the progress percentage.
        C: number of classes (not used by this function).
        data_train: dataset yielding (x, y) pairs; y is ignored during training.
        data_test: dataset passed to `evaluate` after each epoch.
        model: teacher model, called with `training=False` and never updated.
        model_simple: student model, called with `training=True` and updated by SGD.
        epochs: number of passes over `data_train`.
        batch_size: batch size.
        learning_rate: SGD learning rate (momentum is fixed at 0.9).
    """
    criterion = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)

    avg_loss = tf.metrics.Mean()
    accuracy = tf.metrics.SparseCategoricalAccuracy()

    for epoch in range(epochs):

        batches = data_train.shuffle(n).batch(batch_size)

        for i, (x, y) in enumerate(batches):

            x = transform(x, training=True)
            # Teacher predictions are fixed targets: inference mode, outside the tape.
            y_model = tf.argmax(model(x, training=False), axis=1)

            with tf.GradientTape() as tape:
                y_simple = model_simple(x, training=True)
                loss = criterion(y_model, y_simple)
            gradients = tape.gradient(loss, model_simple.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model_simple.trainable_variables))

            avg_loss(loss)
            accuracy(y_model, y_simple)

            if (i + 1) % 10 == 0:
                tf.print(f'Epoch {epoch}, '
                         f'{i * batch_size/n * 100:.1f}%, '
                         f'Loss {avg_loss.result():.3f}, '
                         f'Accuracy {accuracy.result() * 100:.2f}')

        test_accuracy = evaluate(data_test, model_simple)
        print(f'Test accuracy {test_accuracy * 100:.2f}%')

        # Report per-epoch statistics, as `train` does, instead of a running
        # average over all epochs so far.
        avg_loss.reset_states()
        accuracy.reset_states()
    
    

def evaluate(data, model):
    """Return the sparse categorical accuracy of `model` over `data`.

    The dataset is processed in batches of 1000; inputs go through `transform`
    without augmentation and the model is run with `training=False`.
    """
    print('Testing accuracy')

    metric = tf.metrics.SparseCategoricalAccuracy()
    for images, labels in data.batch(1000):
        predictions = model(transform(images, training=False), training=False)
        metric(labels, predictions)

    print('Done')

    return metric.result()
    
import datetime            

def train(n, C, data_train, data_test, model, model_simple, 
          epochs=200, batch_size=128, learning_rate=0.1,
          save_dir='checkpoints',
          simple_epochs=2, simple_batch_size=1000, simple_learning_rate=0.1,
          simple_frequency=20,
          save_frequency=100):
    """Train `model` with cross-entropy, optionally tracking mutual information.

    Runs SGD (momentum 0.9) on `model`. Every `simple_frequency` batches of an
    epoch, `model_simple` is trained with `train_simple` and the quantities
    returned by `evaluate_simple` are written as TensorBoard scalars under
    `logs/<timestamp>/`. Every `save_frequency` gradient steps the weights of
    `model` are saved to `<save_dir>/<steps>.h5`. After each epoch the test
    accuracy is printed and the running metrics are reset.

    Args:
        n: dataset size; used as the shuffle buffer size and for the progress percentage.
        C: number of classes.
        data_train: training dataset yielding (x, y) pairs.
        data_test: test dataset used by `evaluate` and `evaluate_simple`.
        model: model being trained; called with `training=True` so that
            training-dependent layers (e.g. BatchNormalization) run in training mode.
        model_simple: simple model; only used when `simple_frequency` is set.
        epochs: number of passes over `data_train`.
        batch_size: batch size for `model`.
        learning_rate: SGD learning rate for `model`.
        save_dir: checkpoint directory (created if missing); None disables saving.
        simple_epochs, simple_batch_size, simple_learning_rate: forwarded to `train_simple`.
        simple_frequency: batches between mutual-information evaluations within
            an epoch; None or 0 disables them.
        save_frequency: gradient steps between checkpoints; None or 0 disables saving.
    """
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/' + current_time
    summary_I_model_true_simple = tf.summary.create_file_writer(log_dir + '/I_model_true_simple')
    summary_I_model_true = tf.summary.create_file_writer(log_dir + '/I_model_true')
    summary_I_model_simple = tf.summary.create_file_writer(log_dir + '/I_model_simple')

    criterion = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
    #optimizer = tf.optimizers.SGD(learning_rate=learning_rate)

    avg_loss = tf.metrics.Mean()
    accuracy = tf.metrics.SparseCategoricalAccuracy()

    # Decide once whether checkpoints are written; this also avoids taking a
    # modulo by None/0. The HDF5 writer does not create missing directories.
    save_checkpoints = save_dir is not None and bool(save_frequency)
    if save_checkpoints:
        os.makedirs(save_dir, exist_ok=True)

    steps = 0
    for epoch in range(epochs):

        batches = data_train.shuffle(n).batch(batch_size)

        for i, (x, y) in enumerate(batches):

            x = transform(x, training=True)

            with tf.GradientTape() as tape:
                y_model = model(x, training=True)
                loss = criterion(y, y_model)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            avg_loss(loss)
            accuracy(y, y_model)

            if (i + 1) % 10 == 0:
                tf.print(f'Epoch {epoch}, '
                         f'{i * batch_size/n * 100:.1f}%, '
                         f'Loss {avg_loss.result():.3f}, '
                         f'Accuracy {accuracy.result() * 100:.2f}')

            if simple_frequency and (i + 1) % simple_frequency == 0:

                train_simple(n, C, data_train, model, model_simple, 
                             epochs=simple_epochs, batch_size=simple_batch_size, 
                             learning_rate=simple_learning_rate)

                I_model_true_simple, I_model_true, I_model_simple \
                            = evaluate_simple(C, data_test, model, model_simple, hard=True)
                # The three writers share one tag so the curves overlay in TensorBoard.
                with summary_I_model_true_simple.as_default():
                    tf.summary.scalar('Mutual Information',
                                      I_model_true_simple, step=steps)
                with summary_I_model_true.as_default():
                    tf.summary.scalar('Mutual Information',
                                      I_model_true, step=steps)
                with summary_I_model_simple.as_default():
                    tf.summary.scalar('Mutual Information',
                                      I_model_simple, step=steps)
            if save_checkpoints and steps % save_frequency == 0:
                model.save_weights(f'{save_dir}/{steps}.h5')

            steps += 1

        test_accuracy = evaluate(data_test, model)
        print(f'Test accuracy {test_accuracy * 100:.2f}%')

        avg_loss.reset_states()
        accuracy.reset_states()

def train_from_checkpoints(n, C, data_train, data_test, model, models_simple, 
          checkpoint_dir=('checkpoints',),
          simple_epochs=0, simple_batch_size=1000, simple_learning_rate=0.1,
          simple_frequency=20,
          save_frequency=100,
          hard=True,
          max_steps=20000):
    """Replay saved checkpoints of `model` and log mutual information for each.

    For every step in `range(0, max_steps, save_frequency)` and every simple
    model `i`, the weights `<checkpoint_dir[i]>/<step>.h5` are loaded into
    `model`. If `simple_epochs > 0` the simple model is first trained against
    that checkpoint with `train_simple`. The quantities returned by
    `evaluate_simple` are then written as TensorBoard scalars under
    `logs/<timestamp>/`, tagged with the simple model's index.

    Args:
        n: dataset size, forwarded to `train_simple`.
        C: number of classes.
        data_train: training dataset, used only when `simple_epochs > 0`.
        data_test: dataset on which mutual information is evaluated.
        model: full model whose weights are overwritten by each checkpoint.
        models_simple: sequence of simple models.
        checkpoint_dir: sequence of checkpoint directories, indexed in parallel
            with `models_simple`.
        simple_epochs, simple_batch_size, simple_learning_rate: forwarded to `train_simple`.
        simple_frequency: not used by this function.
        save_frequency: spacing, in gradient steps, of the checkpoints to load.
        hard: forwarded to `evaluate_simple`.
        max_steps: exclusive upper bound on the checkpoint step (previously
            hard-coded to 20000).
    """
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/' + current_time
    summary_I_model_true_simple = tf.summary.create_file_writer(log_dir + '/I_model_true_simple')
    summary_I_model_true = tf.summary.create_file_writer(log_dir + '/I_model_true')
    summary_I_model_simple = tf.summary.create_file_writer(log_dir + '/I_model_simple')

    for steps in range(0, max_steps, save_frequency):

        # First pass: optionally fit every simple model to this checkpoint.
        if simple_epochs > 0:
            for i, model_simple in enumerate(models_simple):
                model.load_weights(f'{checkpoint_dir[i]}/{steps}.h5')
                train_simple(n, C, data_train, model, model_simple, 
                             epochs=simple_epochs, batch_size=simple_batch_size, 
                             learning_rate=simple_learning_rate)

        # Second pass: evaluate and log each simple model against its checkpoint.
        for i, model_simple in enumerate(models_simple):

            model.load_weights(f'{checkpoint_dir[i]}/{steps}.h5')

            I_model_true_simple, I_model_true, I_model_simple \
                    = evaluate_simple(C, data_test, model, model_simple, hard=hard)
            tag = f'Mutual Information (simple model {i})'
            with summary_I_model_true_simple.as_default():
                tf.summary.scalar(tag, I_model_true_simple, step=steps)
            with summary_I_model_true.as_default():
                tf.summary.scalar(tag, I_model_true, step=steps)
            with summary_I_model_simple.as_default():
                tf.summary.scalar(tag, I_model_simple, step=steps)
            #model_simple.save_weights(f'{checkpoint_dir}/simple{i}/{steps}.h5')
    

In [6]:
#@title CIFAR Processing

# Number of output classes: the CIFAR-10 labels are collapsed into two groups (0 / 1) below.
C = 2

def change_classes(x, y):
    """Drop classes 4 and 6 and merge the remaining labels into two groups.

    Samples labelled 4 or 6 are removed from both arrays. Of the remaining
    labels, 2/3/5/7 become 0 and 0/1/8/9 become 1. The inputs are not
    modified; filtered copies are returned.

    Args:
        x: array of samples, indexed along axis 0 in parallel with `y`.
        y: label array of shape (n, 1).

    Returns:
        Tuple (x, y) of the filtered samples and relabelled labels.
    """
    flat_labels = y.flatten()
    keep = (flat_labels != 4) & (flat_labels != 6)
    x, y = x[keep], y[keep]

    if len(y):
        labels = y[:, 0]  # view into y: assignments below write through
        # Compute both masks from the original values before overwriting any.
        to_zero = np.isin(labels, [2, 3, 5, 7])
        to_one = np.isin(labels, [0, 1, 8, 9])
        labels[to_zero] = 0
        labels[to_one] = 1
    return x, y

# planes, cars, birds, cats, deer vs
# ships, trucks, frogs, dogs, horses
def change_classes_hard(x, y):
    """Binarize labels in place: values below 5 become 0, all others become 1.

    Args:
        x: array of samples; returned unchanged.
        y: label array of shape (n, 1); its first column is overwritten.

    Returns:
        Tuple (x, y) with the same array objects that were passed in.
    """
    if len(y):
        # The right-hand side is fully evaluated before the column is overwritten.
        y[:, 0] = np.where(y[:, 0] < 5, 0, 1)
    return x, y

# Use channels-first tensors throughout: the network builders above hard-code
# axis=1 for BatchNormalization. This must run before the data is loaded.
tf.keras.backend.set_image_data_format('channels_first')
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Collapse the labels into two classes (label < 5 -> 0, otherwise 1).
x_train, y_train = change_classes_hard(x_train, y_train)
x_test, y_test = change_classes_hard(x_test, y_test)
# Number of training examples; passed to the training loops as shuffle buffer size.
n = len(x_train)
# Scale pixel values by 1/255 and store them as float32.
x_train = np.divide(x_train, 255.0, dtype=np.float32)
x_test = np.divide(x_test, 255.0, dtype=np.float32)
print(x_test.shape)
print(y_train.shape)
# Unbatched, unshuffled datasets; batching and shuffling happen in the loops.
data_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
data_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))

(10000, 3, 32, 32)
(50000, 1)


In [7]:
# Full model training hyperparameters:
# NOTE: in this cell EPOCHS / BATCH_SIZE / LEARNING_RATE are only referenced by
# the train(...) calls that are currently disabled inside string literals.
EPOCHS = 60
BATCH_SIZE = 128
LEARNING_RATE = 0.01

# Simple model training hyperparameters
# (forwarded to train_from_checkpoints at the bottom of this cell)
SIMPLE_EPOCHS = 2
SIMPLE_BATCH_SIZE = 512
SIMPLE_LEARNING_RATE = 0.1

# How often to evaluate mutual information in units of gradient steps
SIMPLE_FREQUENCY = 2


# Reset Keras' global state before building the models for this run.
tf.keras.backend.clear_session()

# Channels-first image shape (channels, height, width); both models share this input.
shape = (3, 32, 32)
inputs = tf.keras.Input(shape=shape)

# Full model: resnet18 with C output logits.
model = tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C))
#model = tf.keras.applications.ResNet50(include_top=False, input_tensor=inputs, weights=None)
#model = tf.keras.Model(inputs=inputs, outputs=dense(model.output, C))
#model_simple = tf.keras.Model(inputs=inputs, outputs=linear(inputs, C))
#model_simple = tf.keras.Model(inputs=inputs, outputs=small(inputs, C))
#models_simple = [tf.keras.Model(inputs=inputs, outputs=linear(inputs, C))] + \
#                [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=i)) for i in range(1, 4)]

#model = tf.keras.Model(inputs=inputs, outputs=simpleCNNGal(inputs, C))
#model_simple = tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=1))
#model_simple.load_weights('resnet6-0/23400.h5')
#model_simple = tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=0, n_per_block=1))
#model_simple = tf.keras.Model(inputs=inputs, outputs=linear(inputs, C))
# In this configuration the "simple" model has the same resnet18 architecture
# as `model`, with its own separate set of layers.
model_simple = tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C))
#model_simple.load_weights('cnn6-10.h5')
#model.load_weights('resnet18-0/23400.h5')
#model.load_weights('resnet18-0/234000.h5')
# Initialise the simple model from a saved checkpoint (path relative to the working directory).
model_simple.load_weights('resnet18-0/23400.h5')

#train_imitate(n, C, data_train, data_test, model, model_simple, 
#              epochs=10, batch_size=128, learning_rate=0.01)
#model_simple.load_weights('linear1/23400.h5')
#train_simple(n, C, data_train, model, model_simple, 
#               epochs=100, batch_size=512, learning_rate=0.1,
#               savename=None, save_frequency=None)

#model_simple.save_weights('linear-10-test.h5')


# Disabled experiment (kept as a bare string literal, so none of it executes):
# load families of pre-trained simple models and list the matching checkpoints.
'''
models_simple = [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=1)) for _ in range(4)] + \
                [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=1, n_per_block=1)) for _ in range(4)] + \
                [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=0)) for _ in range(4)]
    
names = [f'cnn6-{i}' for i in range(4)] + \
        [f'cnn4-{i}' for i in range(4)] + \
        [f'cnn2-{i}' for i in range(4)]
for model_simple, name in zip(models_simple, names):
    model_simple.load_weights(f'{name}/23400.h5')
    
#for model_simple in models_simple:
#    evaluate_simple(C, data_test, model, model_simple, hard=True)

models = [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C)) for _ in range(4)]

checkpoints = ['resnet18-0', 'resnet18-1', 'resnet18-2', 'resnet18-3'] * 3
'''

# Disabled experiment (string literal): train several resnet variants from
# scratch, saving checkpoints into one directory per run.
'''
n_runs = 4
models = [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=3)) for _ in range(n_runs)] + \
         [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=2)) for _ in range(n_runs)] + \
         [tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C, n_layers=1)) for _ in range(n_runs)]

names = [f'resnet14-{i}' for i in range(n_runs)] + \
        [f'resnet10-{i}' for i in range(n_runs)] + \
        [f'resnet6-{i}' for i in range(n_runs)]

import os
for name in names:
    os.makedirs(name, exist_ok=True)

for model, name in zip(models, names):
    train(n, C, data_train, data_test, model, None,
      epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE,
      save_dir=name,
      simple_frequency=None,
      save_frequency=100)
'''

#model.load_weights('resnet18-0/23400.h5')

#train_imitate(n, C, data_train, data_test, model, model_simple, 
#              epochs=10, batch_size=128, learning_rate=0.01)
#train_simple(n, C, data_train, model, model_simple, 
#              epochs=10, batch_size=1000, learning_rate=0.1)

#model = tf.keras.Model(inputs=inputs, outputs=resnet18(inputs, C))
#model.load_weights('resnet18-0/0.h5')

# Comment toggle: the string literal opened on the next line is closed by the
# quotes inside the first "#'''" line below, so this train(...) call is disabled
# and the train_from_checkpoints(...) call after it is the code that runs.
'''
train(n, C, data_train, data_test, model, model_simple,
      epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE,
      save_dir='simpleCNN-fine',
      simple_epochs=0, simple_batch_size=SIMPLE_BATCH_SIZE, 
      simple_learning_rate=SIMPLE_LEARNING_RATE, 
      simple_frequency=None,
      save_frequency=10)

#'''
#'''
# Active run: load the 'resnet18-1' checkpoints of `model` at every 500th step
# and evaluate them against `model_simple`; simple_epochs=0 means the simple
# model is not trained here.
train_from_checkpoints(n, C, data_train, data_test, model, [model_simple],
      checkpoint_dir=['resnet18-1'],
      simple_epochs=0, simple_batch_size=SIMPLE_BATCH_SIZE, 
      simple_learning_rate=SIMPLE_LEARNING_RATE, 
      simple_frequency=SIMPLE_FREQUENCY,
      save_frequency=500,
      hard=True)
#'''





Testing mutual information


W0521 19:19:23.914018 140424419575552 deprecation.py:323] From /home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/ops/image_ops_impl.py:1444: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.


I(L ; Y) = 0.5341752767562866
I(M ; Y) = -8.599133138886828e-08, I(M ; Y | L) = 0.0
Done
Testing mutual information
I(L ; Y) = 0.5341753959655762
I(M ; Y) = 0.048815175890922546, I(M ; Y | L) = 0.0029210764914751053
Done
Testing mutual information
I(L ; Y) = 0.5341752767562866
I(M ; Y) = 0.07309546321630478, I(M ; Y | L) = 0.004723785445094109
Done
Testing mutual information
I(L ; Y) = 0.5341755151748657
I(M ; Y) = 0.08379804342985153, I(M ; Y | L) = 0.004938877187669277
Done
Testing mutual information
I(L ; Y) = 0.5341755151748657
I(M ; Y) = 0.10661014914512634, I(M ; Y | L) = 0.0073980167508125305
Done
Testing mutual information
I(L ; Y) = 0.5341752767562866
I(M ; Y) = 0.14322562515735626, I(M ; Y | L) = 0.012084584683179855
Done
Testing mutual information
I(L ; Y) = 0.5341755151748657
I(M ; Y) = 0.14822795987129211, I(M ; Y | L) = 0.012810492888092995
Done
Testing mutual information
I(L ; Y) = 0.534175455570221
I(M ; Y) = 0.1581883728504181, I(M ; Y | L) = 0.011463411152362823
Done
