# Convergence time scales in the two-agent setting

## Set up training and saving of the generated data

In [None]:
# IMPORTS 

import tensorflow as tf
import numpy as np
import os
import sys
import pickle

sys.path.append("/net/store/cogmod/users/xenohmer/Conferences/CogSci2020/pragmatic_agents_me_bias/")
from RSA_communication_agents import RSASpeaker0, RSAListener0, RSASpeaker1, RSAListener1

In [None]:
# GENERATE INPUT STATES 

def generate_states(size, n, state_dist):
    """ Generates the state input for the speaker in the two agent setting. The occurrence frequencies of the 
        n states follow either a uniform or a Zipfian distribution. 
        inputs: 
            size - number of data points
            n - number of states and messages in the system
            state_dist - 'Zipf' (definition see paper) or 'uniform', describes the occurrence frequencies of 
                         the different input states; for 'Zipf' the state with (1-based) rank k is drawn
                         with probability proportional to 1/k
        outputs: 
            data - float32 array of shape (size, n, n); for a data point showing state k, row k is all
                   ones and all other rows are zero
            labels - one-hot encoding (tf.one_hot) of the drawn states, shape (size, n)
        raises:
            ValueError - if state_dist is neither 'uniform' nor 'Zipf'
    """
    if state_dist == 'uniform':
        probs = None  # np.random.choice samples uniformly when p is None
    elif state_dist == 'Zipf':
        weights = 1. / np.arange(1, n + 1)
        probs = weights / weights.sum()
    else:
        raise ValueError("state_dist must be 'uniform' or 'Zipf', got %r" % (state_dist,))

    # draw exactly `size` states so that `selection` lines up with the first axis of `data`
    selection = np.random.choice(n, size=size, p=probs)

    data = np.zeros((size, n, n), dtype=np.float32)
    # data point i gets a row of ones at the index of its state
    data[np.arange(size), selection, :] = 1.
    labels = tf.one_hot(selection, depth=n)
    return data, labels

In [None]:
# TRAIN A SPEAKER LISTENER PAIR AND SAVE THE PARAMETERS AS WELL AS THE REWARDS 

def _results_dir(reasoning, n, state_dist, alpha):
    """ Returns the directory (ending in '/') in which parameters and rewards of one experiment are saved.
        The alpha suffix is only appended for pragmatic agents (reasoning != 0).
        raises:
            ValueError - if state_dist is neither 'uniform' nor 'Zipf'
    """
    if state_dist == 'uniform':
        root = 'data/communication/'
    elif state_dist == 'Zipf':
        root = 'data/communication_Zipf/'
    else:
        raise ValueError("state_dist must be 'uniform' or 'Zipf', got %r" % (state_dist,))
    path = root + 'L' + str(reasoning) + '_S' + str(reasoning) + '/SGD_' + str(n) + '_states'
    if reasoning != 0:
        # float() keeps the suffix in the form '5.0alpha' even if alpha is passed as an int
        path += '_' + str(float(alpha)) + 'alpha'
    return path + '/'


def _init_lexicon(name, n_states, n_messages, init_mean, init_std, constraint):
    """ Creates a trainable float32 lexicon variable of shape [n_states, n_messages], initialized from a
        truncated normal distribution, with `constraint` attached.
    """
    initial_value = tf.initializers.TruncatedNormal(mean=init_mean, stddev=init_std)([n_states, n_messages])
    return tf.Variable(initial_value,
                       name=name,
                       trainable=True,
                       dtype=tf.float32,
                       constraint=constraint)


def run_speaker_listener_pair(learning_rate=0.01, n=3, n_epochs=100, reasoning=0, state_dist='uniform', alpha=5.):
    """ Trains `runs` (= 50) independent speaker-listener pairs with REINFORCE and saves the results.
        The training parameters are pickled once to <dir>/param_dict.pickle; for every run the mean reward
        of each epoch is saved with np.save to <dir>/rewards_<run>.npy.
        inputs:
            learning_rate - SGD learning rate of speaker and listener
            n - number of states and messages in the system
            n_epochs - number of training epochs per run
            reasoning - 0 for literal agents (RSASpeaker0 / RSAListener0), 1 for pragmatic agents
                        (RSASpeaker1 / RSAListener1)
            state_dist - 'uniform' or 'Zipf', occurrence frequencies of the input states
            alpha - passed to the pragmatic agents and encoded in the output directory name;
                    ignored for literal agents (reasoning == 0)
        raises:
            ValueError - if reasoning is not 0 or 1, or state_dist is neither 'uniform' nor 'Zipf'
    """
    if reasoning not in (0, 1):
        raise ValueError("reasoning must be 0 or 1, got %r" % (reasoning,))

    # setup the training and save the parameters 
    
    n_states = n                      # number of states 
    n_messages = n                    # number of messages
    batch_size = 32                   # batch size
    datasize = 1000                   # number of training data points
    batches = datasize // batch_size  # number of batches per epoch (the remainder of each shuffle is skipped)
    
    runs = 50                         # number of runs: 50 speaker-listener pairs are trained 
    init_mean = 0.1                   # mean for initialization of lexicon entries
    init_std = 0.01                   # std for initialization of lexicon entries 
    
    constraint = tf.keras.constraints.NonNeg() # constrains the lexica to have entries >= 0
    
    # Only stored in the parameter record. Every agent gets its own SGD instance below: plain SGD is
    # stateless, and newer Keras optimizers refuse to update variables they were not built with.
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    
    # set path to save data (raises ValueError for an unknown state_dist before anything is written)
    filename = _results_dir(reasoning, n, state_dist, alpha)
    os.makedirs(filename, exist_ok=True)
    
    # generate parameter dictionary
    param_dict = {"n_states": n_states, "n_messages": n_messages, "n_epochs": n_epochs, "batch_size": batch_size,
                  "datasize": datasize, "initializer_truncated_normal_mean_std": [init_mean, init_std], 
                  "learning_rate": learning_rate, "runs": runs, "constraint": constraint, "optimizer": optimizer,
                  "alpha": alpha}

    with open(filename + 'param_dict.pickle', 'wb') as handle:
        pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
    # run the speaker-listener pairs 
    
    for run in range(runs): 
        print(run)
        
        # create listener
        lexicon_listener = _init_lexicon("lexicon_listener", n_states, n_messages, init_mean, init_std, constraint)
        if reasoning == 0:
            listener = RSAListener0(n_states, n_messages, lexicon_listener)
        else:
            listener = RSAListener1(n_states, n_messages, lexicon_listener, alpha=alpha)
        listener.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
                         loss=tf.keras.losses.CategoricalCrossentropy()) 
        
        # create speaker 
        lexicon_speaker = _init_lexicon("lexicon_speaker", n_states, n_messages, init_mean, init_std, constraint)
        if reasoning == 0:
            speaker = RSASpeaker0(n_states, n_messages, lexicon_speaker)
        else:
            speaker = RSASpeaker1(n_states, n_messages, lexicon_speaker, alpha=alpha)
        speaker.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
                        loss=tf.keras.losses.CategoricalCrossentropy())
        
        # create data 
        data, labels = generate_states(datasize, n, state_dist=state_dist)
        data = tf.convert_to_tensor(data)
        
        # train   
        all_rewards = []
        for i in range(n_epochs):
            
            shuffle_indices = np.random.permutation(datasize)
            data = tf.gather(data, shuffle_indices)
            labels = tf.gather(labels, shuffle_indices)

            average_reward = []

            for j in range(batches):
                # consecutive, non-overlapping batches: j is a batch index, not a data offset
                start = j * batch_size
                data_batch = data[start:start + batch_size]
                labels_batch = labels[start:start + batch_size]

                _, messages = speaker.get_messages(data_batch)
                # one copy of the message per state: listener_input_messages[b, :, s] == messages[b, :]
                listener_input_messages = tf.reshape(tf.tile(messages, multiples=tf.constant([1, n_states])), 
                                                     shape=(-1, n_states, n_messages))
                listener_input_messages = tf.transpose(listener_input_messages, [0, 2, 1])
                _, states = listener.get_states(listener_input_messages)

                # per-sample reward: inner product of the true one-hot label and the listener output.
                # NOTE(review): presumably `states` is a sampled one-hot state, i.e. reward 1 for a
                # correct guess and 0 otherwise; confirm in RSA_communication_agents
                rewards = tf.einsum('ij,ij->i', labels_batch, states)
                average_reward.append(np.mean(rewards))
                
                # RL: 
                # Note that we implemented REINFORCE with a work-around using categorical crossentropy. 
                # This can be done by setting the labels to the agent's actions, and weighting the loss
                # function by the rewards. 
                speaker.train_on_batch(data_batch, messages, sample_weight=rewards)
                listener.train_on_batch(listener_input_messages, states, sample_weight=rewards)
            
            mean_reward = np.mean(average_reward)
            all_rewards.append(mean_reward)
            
            if i % 25 == 0:
                print('average reward epoch ' + str(i), mean_reward)
        
        # save the mean rewards of all epochs 
        np.save(filename + 'rewards_' + str(run), all_rewards)

## Run the training

Train literal (reasoning level 0) and pragmatic (reasoning level 1) speaker-listener pairs on different input distributions.

### Uniform input distribution

In [None]:
# experiment settings for the uniform input distribution
n, ne, lr = 10, 300, 0.01   # number of states/messages, number of epochs, learning rate
alpha = 5.                  # NOTE(review): not passed to run_speaker_listener_pair below, informational only
state_dist = 'uniform'      # occurrence frequencies of the input states

In [None]:
# literal speaker-listener pair (reasoning level 0)
run_speaker_listener_pair(n=n, n_epochs=ne, learning_rate=lr, state_dist=state_dist, reasoning=0)

In [None]:
# pragmatic speaker-listener pair (reasoning level 1)
run_speaker_listener_pair(n=n, n_epochs=ne, learning_rate=lr, state_dist=state_dist, reasoning=1)

### Zipfian input distribution

In [None]:
# experiment settings for the Zipfian input distribution
n, ne, lr = 10, 500, 0.01   # number of states/messages, number of epochs, learning rate
alpha = 5.                  # NOTE(review): not passed to run_speaker_listener_pair below, informational only
state_dist = 'Zipf'         # occurrence frequencies of the input states

In [None]:
# literal speaker-listener pair (reasoning level 0)
run_speaker_listener_pair(n=n, n_epochs=ne, learning_rate=lr, state_dist=state_dist, reasoning=0)

In [None]:
# pragmatic speaker-listener pair (reasoning level 1)
run_speaker_listener_pair(n=n, n_epochs=ne, learning_rate=lr, state_dist=state_dist, reasoning=1)