# Convergence time scales in the single-agent setting

## Set up the training

In [1]:
# IMPORTS

import tensorflow as tf
import numpy as np
import os
import sys
import pickle

# sys.path.append("path to the project")
from RSA_communication_agents import RSASpeaker0, RSAListener0, RSASpeaker1, RSAListener1

In [2]:
# GENERATE INPUT MESSAGES 

def generate_messages(size, n, message_dist):
    """ Generates the message input for the listener. The occurrence frequencies of the n messages follow either 
        a uniform or a Zipfian distribution. 
        inputs: 
            size - number of data points
            n - number of states and messages in the system
            message_dist -  'Zipf' (definition see paper) or 'uniform', describes the occurrence frequencies of 
                            the different input messages
        outputs: 
            data - training data, float32 array of shape (size, n, n); for data point k with sampled 
                   message index m, row data[k, m, :] is set to one and all other entries are zero
            labels - training labels, one-hot encoding (depth n) of the sampled message indices
        raises:
            ValueError - if message_dist is neither 'uniform' nor 'Zipf'
    """

    # sample one message index per data point
    if message_dist == 'uniform':
        selection = np.random.choice(n, size=size)
    elif message_dist == 'Zipf':
        # Zipfian weights: the probability of the i-th message is proportional to 1/i
        weights = 1. / np.arange(1, n + 1)
        # draw `size` samples (not a fixed number) so that the selection always matches `data`
        selection = np.random.choice(n, size=size, p=weights / np.sum(weights))
    else:
        raise ValueError("message_dist must be 'uniform' or 'Zipf', got " + repr(message_dist))

    # mark the row belonging to the sampled message of every data point
    data = np.zeros((size, n, n), dtype=np.float32)
    data[np.arange(size), selection, :] = 1.

    labels = tf.one_hot(selection, depth=n)
    return data, labels

In [3]:
# TRAIN A LISTENER AND SAVE THE PARAMETERS AS WELL AS THE REWARDS 

def run_listener(n=3, reasoning=0, n_epochs=100, message_dist='uniform', learning_rate=0.01):
    """ Trains the listener on a single agent Lewis game as described in Experiment 2. 
        inputs:
            n - number of states and messages
            reasoning - reasoning level of the listener, 0 for literal, 1 for pragmatic 
            n_epochs - number of training epochs
            message_dist -  'Zipf' (definition see paper) or 'uniform', describes the occurrence frequencies of 
                            the different input messages
            learning_rate - learning rate for the SGD optimizer
        raises:
            ValueError - if reasoning is not 0 or 1, or if message_dist is not 'uniform' or 'Zipf'
        side effects:
            Creates the output directory below 'data/' if necessary and writes 'param_dict.pickle' as well 
            as one 'rewards_<run>.npy' file per trained agent into it. 
        By default 50 agents are trained and the mean reward of every epoch is saved (agent-wise).
    """

    # validate the arguments first, so that a bad call fails before anything is created or written
    if reasoning not in (0, 1):
        raise ValueError("reasoning must be 0 (literal) or 1 (pragmatic), got " + repr(reasoning))
    if message_dist not in ('uniform', 'Zipf'):
        raise ValueError("message_dist must be 'uniform' or 'Zipf', got " + repr(message_dist))

    # setup the training and save the parameters 

    n_states = n                      # number of states 
    n_messages = n                    # number of messages
    batch_size = 32                   # batch size
    datasize = 1000                   # number of training data points
    batches = datasize // batch_size  # number of full batches per epoch (the remainder is skipped)

    runs = 50                         # number of runs: 50 listeners are trained independently
    init_mean = 0.1                   # mean for initialization of lexicon entries
    init_std = 0.01                   # std for initialization of lexicon entries 

    constraint = tf.keras.constraints.NonNeg() # constrains the lexica to have entries >= 0

    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

    # set path to save data 
    if message_dist == 'uniform':
        base_dir = 'data/labeling/'
    else:
        base_dir = 'data/labeling_Zipf/'
    if reasoning == 0:
        suffix = '_states/'
    else:
        suffix = '_states_5.0alpha/'  # the pragmatic listener is created with alpha=5. below
    filename = base_dir + 'L' + str(reasoning) + '/SGD_' + str(n) + suffix
    os.makedirs(filename, exist_ok=True)

    # generate parameter dictionary
    param_dict = {"n_states": n_states,"n_messages": n_messages, "n_epochs":n_epochs, "batch_size": batch_size,
                  "datasize":datasize, "initializer_truncated_normal_mean_std": [init_mean, init_std], 
                  "learning_rate":learning_rate, "runs": runs, "constraint":constraint, "optimizer":optimizer}

    with open(filename + 'param_dict.pickle', 'wb') as handle:
        pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)


    # run the listeners

    for run in range(runs):
        print(run)

        # create listener with a freshly initialized lexicon
        lexicon = tf.Variable(tf.initializers.TruncatedNormal(mean=init_mean, stddev=init_std)
                              ([n_states, n_messages]),
                              name="lexicon", 
                              trainable=True, 
                              dtype=tf.float32,
                              constraint=constraint)
        if reasoning == 0: 
            listener = RSAListener0(n_states, n_messages, lexicon)
        else: 
            listener = RSAListener1(n_states, n_messages, lexicon, alpha=5.)
        listener.compile(optimizer=optimizer, loss=tf.keras.losses.CategoricalCrossentropy())

        # create data
        data, labels = generate_messages(datasize, n, message_dist=message_dist)
        data = tf.convert_to_tensor(data)

        # train 
        all_rewards = []
        for i in range(n_epochs):

            # reshuffle data and labels with the same permutation at the start of every epoch
            shuffle_indices = np.random.permutation(datasize)
            data = tf.gather(data, shuffle_indices)
            labels = tf.gather(labels, shuffle_indices)

            rewards_epoch = []

            for j in range(batches):
                # consecutive, non-overlapping mini-batches: batch j covers 
                # samples [j * batch_size, (j + 1) * batch_size)
                start = j * batch_size
                data_batch = data[start:start + batch_size]
                labels_batch = labels[start:start + batch_size]

                # second return value: the listener's selected states 
                # (presumably one-hot encoded - verify against RSA_communication_agents)
                _, states = listener.get_states(data_batch)

                # RL: 
                # Note that we implemented REINFORCE with a work-around using categorical crossentropy. 
                # This can be done by setting the labels to the agent's actions, and weighting the loss
                # function by the rewards. 
                # The reward is the row-wise dot product between the true labels and the selected states.
                rewards = tf.einsum('ij,ij->i', labels_batch, states)
                loss = listener.train_on_batch(data_batch, states, sample_weight=rewards)

                rewards_epoch.append(np.mean(rewards))

            mean_reward = np.mean(rewards_epoch)

            if i%25==0:
                print('average reward epoch ' + str(i), mean_reward)
            all_rewards.append(mean_reward)

        # save the mean rewards of all epochs 
        np.save(filename + 'rewards_'+str(run), all_rewards)

## Run the training 

### Uniform input distribution

In [None]:
# Settings used by the uniform-input training runs.
# size of the system: number of states and number of messages
n = 10
# step size of the optimizer
lr = 1e-2
# how many epochs each listener is trained for
ne = 100
# occurrence frequencies of the input messages
message_dist = 'uniform'

In [None]:
# train the literal listener (reasoning level 0)
run_listener(
    n=n,
    reasoning=0,
    n_epochs=ne,
    message_dist=message_dist,
    learning_rate=lr,
)

In [None]:
# train the pragmatic listener (reasoning level 1)
run_listener(
    n=n,
    reasoning=1,
    n_epochs=ne,
    message_dist=message_dist,
    learning_rate=lr,
)

### Zipfian input distribution

In [None]:
# Settings used by the Zipfian-input training runs.
# size of the system: number of states and number of messages
n = 10
# step size of the optimizer
lr = 1e-2
# how many epochs each listener is trained for
ne = 100
# occurrence frequencies of the input messages
message_dist = 'Zipf'

In [None]:
# train the literal listener (reasoning level 0)
run_listener(
    n=n,
    reasoning=0,
    n_epochs=ne,
    message_dist=message_dist,
    learning_rate=lr,
)

In [None]:
# train the pragmatic listener (reasoning level 1)
run_listener(
    n=n,
    reasoning=1,
    n_epochs=ne,
    message_dist=message_dist,
    learning_rate=lr,
)