In [1]:
### IMPORTS

import tensorflow as tf 
from tensorflow.keras import Sequential, layers, losses, optimizers, activations, Model, datasets, models
import numpy as np
import pickle
import os
from helper_functions.calculation_helpers import *
from neural_agents import LiteralListener, PragmaticListener

In [2]:
### CREATE DATA SET AND LOAD MNIST CLASSIFIER 

# MNIST images scaled to [0, 1] and given an explicit trailing channel axis.
(train_images, train_labels_orig), (test_images, test_labels_orig) = datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
# Derive the image counts from the data instead of hard-coding 60000.
train_images = train_images.reshape(len(train_labels_orig), 28, 28, 1)
test_images = test_images.reshape(len(test_labels_orig), 28, 28, 1)

# Pre-built reference-game splits stored as a pickled object array.
# allow_pickle=True can execute arbitrary code: only load this file from a trusted source.
train, test, ME = np.load('mnist_referential_dataset/train_test_ME.npy', allow_pickle=True)
(target_indices, target_labels, distractor_indices, distractor_labels) = train
(test_target_indices, test_target_labels, test_distractor_indices, test_distractor_labels) = test
(ME_indices, ME_labels) = ME

# run_agent slices these arrays together into one tf.data.Dataset, so they must align.
assert len(target_indices) == len(target_labels) == len(distractor_indices) == len(distractor_labels), \
    "training target/distractor arrays have mismatched lengths"

n_train = len(target_labels)

# Truncate the saved classifier at its 'dense' layer; those activations serve as image features.
mnist_classifier = models.load_model('mnist_classifier/model.06-0.04.h5')
feature_extraction_model = tf.keras.Model(inputs=mnist_classifier.input,
                                          outputs=mnist_classifier.get_layer('dense').output)

In [3]:
# FUNCTION FOR TRAINING THE NEURAL AGENTS

# Accepted values for run_agent's `negative_sampling` argument.
_NEGATIVE_SAMPLING_OPTIONS = (None, 'words', 'objects', 'both')


def _tile_proposals(candidates, n_rows, message_length):
    """Repeat the candidate message set once per example.

    `candidates` has shape (n_candidates, message_length); the result has shape
    (n_rows, n_candidates, message_length).
    """
    return np.reshape(np.tile(candidates, [n_rows, 1]),
                      (n_rows, len(candidates), message_length))


def _proposal_candidates(negative_sampling, all_messages, message_code):
    """Messages offered to the listener as alternatives for a sampling strategy.

    Word-level negative sampling ('words', 'both') offers the whole vocabulary;
    otherwise (None, 'objects') only the messages assigned to digits 0-8.
    Training, ME evaluation and test evaluation all use this same rule.
    """
    if negative_sampling in ('words', 'both'):
        return all_messages
    return message_code


def _target_first_labels(n):
    """(n, 2) array whose rows are [1, 0], i.e. the first option is marked as correct."""
    return np.stack([np.ones(n), np.zeros(n)], axis=1)


def _selection_rewards(labels, selections):
    """Per-example reward: 1.0 where `selections` equals `labels` in column 0, else 0.0.

    NOTE(review): assumes `selections` has the same (n, 2) layout as `labels` --
    confirm against neural_agents.listener_action.
    """
    return np.array(tf.cast(tf.equal(labels, selections), tf.float32)[:, 0])


def _as_object_array(arrays):
    """Pack a list of differently shaped arrays into a 1-D object array for np.save.

    np.save(path, list_of_arrays) converts the list with np.asarray first, which
    NumPy >= 1.24 rejects for ragged inputs (e.g. a Dense kernel plus its bias).
    Filling an object array element by element avoids that conversion; the file
    loads back with np.load(path, allow_pickle=True) as a sequence of the arrays.
    """
    packed = np.empty(len(arrays), dtype=object)
    for position, array in enumerate(arrays):
        packed[position] = array
    return packed


def run_agent(vocab_size=20,
              activation='sigmoid',
              agent_type='pragmatic',
              alpha = 5.,
              runs = 25,
              negative_sampling=None, # None, 'words', 'objects', or 'both',
              message_length = 1,
              encoding_dim = 32,
              n_distractors = 1,
              n_classes = 10,
              n_training = n_train,
              n_epochs = 100,
              batch_size = 64,
              learning_rate = 0.0001
             ):
    """Train `runs` independent listener agents on the MNIST reference game and save results.

    Each run:
      - draws a fresh random message code (one message per digit 0-8; the remaining
        messages are held out as ME messages);
      - builds and trains a new agent for `n_epochs`, evaluating ME bias after every epoch;
      - evaluates test accuracy once at the end;
      - saves everything under results/<agent_type><sampling tag>[_alpha_<alpha>]/run<k>/.
        The saved files are the parameters, encoder weights/configs, per-epoch rewards,
        losses and ME scores, and the test score.

    Args:
        vocab_size: number of distinct messages (ids 0 .. vocab_size-1).
        activation: activation of the one-layer state and message encoders.
        agent_type: 'pragmatic' (PragmaticListener) or 'literal' (LiteralListener).
        alpha: passed to PragmaticListener; also part of the results path for pragmatic agents.
        runs: number of independent runs.
        negative_sampling: None, 'words', 'objects' or 'both'.
            'words'/'both': the whole vocabulary is used as message proposals.
            'objects'/'both': each training distractor is replaced by an ME example
            with probability 1/10, capped at the number of ME examples.
        message_length: number of symbols per message.
        encoding_dim: output width of both encoders.
        n_distractors: passed to the listener.
        n_classes: only recorded in the saved parameters.
        n_training: examples per epoch; determines n_training // batch_size batches.
        n_epochs, batch_size, learning_rate: training settings (Adam optimizer).

    Raises:
        ValueError: if `agent_type` or `negative_sampling` is not a supported value.
    """
    if agent_type not in ('literal', 'pragmatic'):
        raise ValueError("agent_type must be 'literal' or 'pragmatic', got %r" % (agent_type,))
    if negative_sampling not in _NEGATIVE_SAMPLING_OPTIONS:
        raise ValueError("negative_sampling must be one of %r, got %r"
                         % (_NEGATIVE_SAMPLING_OPTIONS, negative_sampling))

    n_batches = n_training // batch_size
    path_ending = '_alpha_' + str(alpha) + '/' if agent_type == 'pragmatic' else '/'
    if negative_sampling is None:
        sampling_tag = '_no_negative_sampling'
    else:
        sampling_tag = '_negative_sampling_' + negative_sampling
    sample_objects = negative_sampling in ('objects', 'both')
    # Kept from the original logic: ME evaluation uses expand_messages=True exactly when
    # the proposals are restricted to message_code (presumably so the held-out ME
    # messages can still be considered -- confirm in neural_agents).
    me_expand_messages = negative_sampling not in ('words', 'both')

    # message ids 0 .. vocab_size-1, one row per message
    all_messages = np.repeat(np.arange(vocab_size, dtype=np.float64)[:, np.newaxis],
                             message_length, axis=1)
    n_examples = len(target_labels)

    # Image features that are identical in every run and epoch: extract them once.
    train_target_features = feature_extraction_model(train_images[target_indices])
    n_ME = len(ME_indices)
    me_targets = feature_extraction_model(train_images[ME_indices])
    me_distractors = feature_extraction_model(test_images[0:n_ME])
    me_selection_labels = _target_first_labels(n_ME)
    test_target_features = feature_extraction_model(test_images[test_target_indices])
    test_distractor_features = feature_extraction_model(test_images[test_distractor_indices])
    n_test = len(test_target_labels)
    test_selection_labels = _target_first_labels(n_test)

    ### iterate over runs
    for run in range(runs):

        print('run', run)

        # randomly assign one message to each digit 0-8; the rest are the ME messages.
        # Copies, so the next run's in-place shuffle cannot alter this run's code.
        np.random.shuffle(all_messages)
        message_code = all_messages[:9].copy()
        ME_message_code = all_messages[9:].copy()

        messages = np.zeros((n_examples, message_length))
        for digit in range(9):
            messages[target_labels == digit] = message_code[digit]

        folder = 'results/' + agent_type + sampling_tag + path_ending + 'run' + str(run) + '/'

        # Work on per-run copies: mutating the module-level arrays would let the
        # replacements accumulate across runs and leak into later run_agent calls.
        run_distractor_indices = distractor_indices.copy()
        run_distractor_labels = distractor_labels.copy()
        if sample_objects:
            # add examples of digit 9 to the distractors (probability 1/10 each,
            # at most one replacement per available ME example)
            replacement_indices = np.where(
                np.random.binomial(1, 1/10, size=len(run_distractor_indices)) == 1)[0]
            replacement_indices = replacement_indices[:n_ME]
            random_ME_indices = np.random.choice(n_ME, size=len(replacement_indices), replace=False)
            run_distractor_indices[replacement_indices] = ME_indices[random_ME_indices]
            run_distractor_labels[replacement_indices] = ME_labels[random_ME_indices]

        # One proposal set for training, ME evaluation and test evaluation.
        proposal_candidates = _proposal_candidates(negative_sampling, all_messages, message_code)
        me_proposals = _tile_proposals(proposal_candidates, n_ME, message_length)
        test_proposals = _tile_proposals(proposal_candidates, len(test_target_features), message_length)

        ### save the parameters

        param_dict = {"VOCAB_SIZE": vocab_size,
                      "MESSAGE_LENGTH": message_length,
                      "ALPHA": alpha,
                      "ENCODING_DIM": encoding_dim,
                      "N_DISTRACTORS": n_distractors,
                      "N_CLASSES" : n_classes,
                      "N_TRAINING": n_training,
                      "N_EPOCHS": n_epochs,
                      "BATCH_SIZE": batch_size,
                      "MESSAGE_CODE": message_code,
                      "ME_MESSAGE_CODE": ME_message_code,
                      "NEGATIVE_SAMPLING": negative_sampling,
                      "STATE_ENCODER": 'one_layer_' + activation,
                      "MESSAGE_ENCODER": 'one_layer_' + activation,
                      "LEARNING_RATE": learning_rate}

        os.makedirs(folder, exist_ok=True)
        with open(folder + 'param_dict.pickle', 'wb') as handle:
            pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

        ### build the agent

        message_encoder = Sequential(
            [layers.Embedding(input_dim=vocab_size, output_dim=32,
                              input_length=message_length),
             layers.Flatten(),
             layers.Dense(encoding_dim, activation=activation)],
            name='message_encoder'
        )
        # summary() prints itself and returns None
        message_encoder.summary()

        state_encoder = Sequential(
            [layers.Dense(encoding_dim, activation=activation)],
            name='state_encoder'
        )

        listener_kwargs = dict(vocab_size=vocab_size,
                               message_length=message_length,
                               encoding_dim=encoding_dim,
                               n_distractors=n_distractors,
                               messages=message_code,
                               state_encoder=state_encoder,
                               message_encoder=message_encoder)
        if agent_type == 'literal':
            agent = LiteralListener(**listener_kwargs)
        else:
            agent = PragmaticListener(alpha=alpha, **listener_kwargs)

        # NOTE(review): 64 presumably matches the width of the classifier's 'dense' feature layer.
        agent.build(vision_dim=64)
        # `learning_rate`, not the deprecated `lr`, which TF >= 2.11 drops with only a warning.
        agent_optim = optimizers.Adam(learning_rate=learning_rate)

        ### training

        # store rewards, losses and ME evaluation results per epoch
        all_rewards, all_losses, ME_evaluation = [], [], []

        # Shuffle individual examples before batching. With the default
        # reshuffle_each_iteration=True, every iter() below yields a new order.
        train_dataset = (tf.data.Dataset.from_tensor_slices(
                             (train_target_features,
                              feature_extraction_model(train_images[run_distractor_indices]),
                              messages, target_labels, run_distractor_labels))
                         .shuffle(buffer_size=n_examples)
                         .batch(batch_size))

        for epoch in range(n_epochs):
            print('epoch', epoch)

            train_iterator = iter(train_dataset)
            collect_rewards, collect_losses = [], []

            for batch in range(n_batches):

                targets, distractors, batch_messages, _, _ = train_iterator.get_next()
                labels = _target_first_labels(len(batch_messages))
                message_proposals = _tile_proposals(proposal_candidates, len(targets), message_length)

                with tf.GradientTape() as tape:
                    selections, log_policy = agent.listener_action([batch_messages,
                                                                    targets,
                                                                    [distractors],
                                                                    message_proposals,
                                                                    [message_proposals]],
                                                                   expand_messages=False)
                    rewards = _selection_rewards(labels, selections)
                    # log-probability of the selected option, weighted by its reward
                    loss = -tf.reduce_mean(selections * log_policy * tf.expand_dims(rewards, axis=1))

                grads = tape.gradient(loss, agent.trainable_variables)
                agent_optim.apply_gradients(zip(grads, agent.trainable_variables))

                collect_rewards.append(np.mean(rewards))
                collect_losses.append(loss.numpy())

            all_rewards.append(np.mean(collect_rewards))
            all_losses.append(np.mean(collect_losses))

            ### ME bias evaluation
            # digit-9 targets paired with randomly drawn held-out messages, against
            # test-image distractors; score = fraction of target selections
            me_choice = np.random.choice(len(ME_message_code), replace=True, size=n_ME)
            me_messages = ME_message_code[me_choice]
            selections, _ = agent.listener_action([me_messages,
                                                   me_targets,
                                                   [me_distractors],
                                                   me_proposals,
                                                   [me_proposals]],
                                                  expand_messages=me_expand_messages)
            ME_evaluation.append(np.mean(_selection_rewards(me_selection_labels, selections)))

        print('... saving models ...')
        np.save(folder + 'state_encoder_weights', _as_object_array(agent.state_encoder.get_weights()))
        np.save(folder + 'state_encoder_config', agent.state_encoder.get_config())
        np.save(folder + 'message_encoder_weights', _as_object_array(agent.message_encoder.get_weights()))
        np.save(folder + 'message_encoder_config', agent.message_encoder.get_config())
        print('... models saved ...')

        ### evaluation on the test data (same proposal set as in training)

        test_messages = np.zeros((n_test, message_length))
        for digit in range(9):
            test_messages[test_target_labels == digit] = message_code[digit]

        selections, _ = agent.listener_action([test_messages,
                                               test_target_features,
                                               [test_distractor_features],
                                               test_proposals,
                                               [test_proposals]],
                                              expand_messages=False)
        test_eval = np.mean(_selection_rewards(test_selection_labels, selections))

        print('test performance: ', test_eval,
              'ME bias: ', ME_evaluation[-1] if ME_evaluation else None)

        ### save run

        np.save(folder + 'rewards', all_rewards)
        np.save(folder + 'losses', all_losses)
        np.save(folder + 'ME', ME_evaluation)
        np.save(folder + 'test', test_eval)

In [2]:
### TRAIN THE AGENTS

# Pragmatic listener without negative sampling, sweeping the PragmaticListener `alpha`.
for alpha in (5, 10, 15):
    run_agent(alpha=alpha,
              activation='sigmoid',
              runs=25)

# Pragmatic listener with alpha = 5 under each negative-sampling strategy.
for negative_sampling in ('words', 'objects', 'both'):
    run_agent(alpha=5,
              negative_sampling=negative_sampling,
              activation='sigmoid',
              runs=25)

# Literal listener baseline, no negative sampling.
run_agent(agent_type='literal',
          activation='sigmoid',
          runs=25)