# ME bias in the two-agent setting

## Set up the training

In [None]:
# IMPORTS 

import tensorflow as tf
import numpy as np
import sys
import os
import pickle

# sys.path.append("path to the project")
from RSA_communication_agents import RSASpeaker0, RSAListener0, RSASpeaker1, RSAListener1

In [None]:
# GENERATE INPUT STATES 

def generate_states(size, n, ME=1):
    """ Generates the state input for the speaker.

        Every data point is an (n, n) matrix whose row for the sampled state is all ones
        and whose other rows are zero; its label is the one-hot encoding of that state.

        inputs:
            size - number of data points
            n - number of states and messages in the system
            ME - number of states that are withheld from training, 1 or 2:
                 ME=1 samples states 0..n-2 (state n-1 is never generated),
                 ME=2 samples states 1..n-2 (states 0 and n-1 are never generated)
        outputs:
            data - training data, float32 numpy array of shape (size, n, n)
            labels - training labels, one-hot tensor of shape (size, n)
        raises:
            ValueError - if ME is not 1 or 2
    """
    if ME == 1:
        selection = np.random.choice(n - 1, size=size)
    elif ME == 2:
        selection = np.random.choice(np.arange(1, n - 1), size=size)
    else:
        # Previously an unsupported ME left `selection` unbound (UnboundLocalError).
        raise ValueError('ME must be 1 or 2, got ' + str(ME))

    data = np.zeros((size, n, n), dtype=np.float32)
    # For every data point, set the whole row of its selected state to one.
    data[np.arange(size), selection, :] = 1.
    labels = tf.one_hot(selection, depth=n)

    return data, labels

In [None]:
# TRAIN A SPEAKER LISTENER PAIR AND SAVE THE PARAMETERS AS WELL AS THE REWARDS AND THE LEXICA

def run_speaker_listener_pair(n=3, reasoning=0, n_epochs = 100, ME=1, learning_rate=0.001):
    """ Trains speaker and listener in the two agent Lewis game as describe in Experiment 1. 
        inputs: 
            n - number of states and messages (in total)
            reasoning - reasoning level of the speaker and the listener, 0 for literal, 1 for pragmatic 
            ME - number of messages that are left out during training, here 1 or 2 
            learning_rate - learning rate for the Adam optimizer
        By default 100 agent combinations (speaker+listener) are trained and all their lexica and rewards for 
        every epoch are saved (combination-wise).
    """
    
    # setup the training and save the parameters 
    
    n_states = n                      # number of states 
    n_messages = n                    # number of messages
    batch_size = 32                   # batch size
    datasize = 1000                   # number of training data points
    batches = datasize // batch_size  # number of bathes per epoch
    
    runs = 100                        # number of runs: 100 speaker-listener pairs are trained 
    init_mean = 0.5                   # mean for initialization of lexicon entries
    init_std = 0.01                   # std for initialization of lexicon entries
    
    constraint = tf.keras.constraints.NonNeg() # constrains the lexica to have entries >= 0
    
    filename = 'data/communication/'+str(n)+'_states/'
    if not os.path.exists(filename):
            os.makedirs(filename)
    
    param_dict = {"n_states": n_states,"n_messages": n_messages, "n_epochs":n_epochs, "batch_size": batch_size,
                  "datasize":datasize, "initializer_truncated_normal_mean_std": [init_mean, init_std], 
                  "learning_rate":learning_rate, "runs": runs, "constraint":constraint}
    with open(filename + 'param_dict.pickle', 'wb') as handle:
        pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
    
    # run the speaker-listener pairs 
    
    for run in range(1,101):
    
        # create data 
        data, labels = generate_states(datasize, n, ME=ME)
        data = tf.convert_to_tensor(data)
        
        lexica_S = []
        lexica_L = []
        all_rewards = []
        
        # create listener
        lexicon_listener = tf.Variable(tf.initializers.TruncatedNormal(mean=init_mean, stddev=init_std)
                                       ([n_states, n_messages]),
                                       name="lexicon_listener", 
                                       trainable=True, 
                                       dtype=tf.float32,
                                       constraint=tf.keras.constraints.NonNeg())
        if reasoning == 0:
            listener = RSAListener0(n_states, n_messages, lexicon_listener)
        elif reasoning == 1:
            listener = RSAListener1(n_states, n_messages, lexicon_listener, alpha=5.)
        listener.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                         loss=tf.keras.losses.CategoricalCrossentropy())
        
        # create speaker
        lexicon_speaker = tf.Variable(tf.initializers.TruncatedNormal(mean=init_mean, stddev=init_std)
                                      ([n_states, n_messages]),
                                      name="lexicon_speaker", 
                                      trainable=True, dtype=tf.float32,
                                      constraint=tf.keras.constraints.NonNeg())
        
        if reasoning == 0:
            speaker = RSASpeaker0(n_states, n_messages, lexicon_speaker)
        elif reasoning == 1:
            speaker = RSASpeaker1(n_states, n_messages, lexicon_speaker, alpha=5.)
        speaker.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                        loss=tf.keras.losses.CategoricalCrossentropy())
        
        # train
        for i in range(n_epochs):
            average_reward = []
            
            shuffle_indices = np.random.permutation(datasize)
            data = tf.gather(data, shuffle_indices)
            labels = tf.gather(labels, shuffle_indices)

            for j in range(batches):
                data_batch = data[j:j + batch_size]
                labels_batch = labels[j:j + batch_size]

                _, messages = speaker.get_messages(data_batch)
                listener_input_messages = tf.reshape(tf.tile(messages, multiples=tf.constant([1, n_states])),
                                                     shape=(batch_size, n_states, n_messages))
                listener_input_messages = tf.transpose(listener_input_messages, [0, 2, 1])
                _, states = listener.get_states(listener_input_messages)

                rewards = tf.einsum('ij,ij->i', labels_batch, states)
                average_reward.append(np.mean(rewards))
                
                # RL:
                # Note that we implemented REINFORCE with a work-around using categorical crossentropy. 
                # This can be done by setting the labels to the agent's actions, and weighting the loss
                # function by the rewards. 
                speaker.train_on_batch(data_batch, messages, sample_weight=rewards)
                listener.train_on_batch(listener_input_messages, states, sample_weight=rewards)
                
            mean_reward = np.mean(average_reward)
            all_rewards.append(mean_reward)
            lexica_L.append(np.copy(listener.lexicon[:]))
            lexica_S.append(np.copy(speaker.lexicon[:]))
            
        print('run ' + str(run), 'average reward ' +str(ME)+ ' ' + str(mean_reward))
        
        # save rewards and lexica 
        if reasoning == 0:
            filename_full = (filename + 'L' + str(reasoning) +'_'+ 'S' + str(reasoning) + '_' + str(ME) 
                             + 'missing_')
        elif reasoning == 1:
            filename_full = (filename + 'L' + str(reasoning) +'_'+ 'S' + str(reasoning) + '_' + str(ME) 
                             + 'missing_5.0alpha_')
            
        np.save(filename_full + 'lexicon_S_run' + str(run), lexica_S)
        np.save(filename_full + 'lexicon_L_run' + str(run), lexica_L)
        np.save(filename_full + 'rewards_run' + str(run), all_rewards)

## Run the training 

The training is run for different agent types (literal and pragmatic), different numbers of states (3 and 10), and different numbers of states withheld from training (1 and 2).

### Literal speaker with literal listener

In [None]:
# Literal agents (reasoning=0), 3 states, 1 state withheld from training.
run_speaker_listener_pair(ME=1, n=3, n_epochs=100, reasoning=0)

In [None]:
# Literal agents (reasoning=0), 3 states, 2 states withheld from training.
run_speaker_listener_pair(ME=2, n=3, n_epochs=100, reasoning=0)

In [None]:
# Literal agents (reasoning=0), 10 states, 1 state withheld from training.
run_speaker_listener_pair(ME=1, n=10, n_epochs=1000, reasoning=0)

In [None]:
# Literal agents (reasoning=0), 10 states, 2 states withheld from training.
run_speaker_listener_pair(ME=2, n=10, n_epochs=1000, reasoning=0)

### Pragmatic speaker with pragmatic listener

In [None]:
# Pragmatic agents (reasoning=1), 3 states, 1 state withheld from training.
run_speaker_listener_pair(ME=1, n=3, n_epochs=100, reasoning=1)

In [None]:
# Pragmatic agents (reasoning=1), 3 states, 2 states withheld from training.
run_speaker_listener_pair(ME=2, n=3, n_epochs=100, reasoning=1)

In [None]:
# Pragmatic agents (reasoning=1), 10 states, 1 state withheld from training.
run_speaker_listener_pair(ME=1, n=10, n_epochs=1000, reasoning=1)

In [None]:
# Pragmatic agents (reasoning=1), 10 states, 2 states withheld from training.
run_speaker_listener_pair(ME=2, n=10, n_epochs=1000, reasoning=1)