# ME bias in the single-agent setting

## Set up the training

In [None]:
# IMPORTS 

import tensorflow as tf
import numpy as np
import sys
import os
import pickle

sys.path.append("/Users/xxAlv1Nxx/Documents/01a Masters/05 CS/CS 428B Probabilistic Models of Cognition – Language/Project/pragmatic_agents_me_bias")
from RSA_communication_agents import RSAListener0, RSAListener1

In [None]:
# GENERATE INPUT MESSAGES 

def generate_messages(size, n, ME=1):
    """ Generates the message input for a (monolingual) listener.

        Each data point is an (n, n) matrix whose row `selection[k]` is all ones
        (the observed message) and all other rows zero; its label is the one-hot
        encoding of the same index.

        :param size:    number of data points
        :param n:       number of states and messages in the system
        :param ME:      number of messages that are withheld from training, here 1 or 2
                        (ME=1 withholds message n-1; ME=2 withholds messages 0 and n-1)
        :return data:   training data, float32 array of shape (size, n, n)
        :return labels: training labels, one-hot tensor of shape (size, n)
        :raises ValueError: if ME is not 1 or 2
    """
    if ME == 1:
        selection = np.random.choice(n - 1, size=size)
    elif ME == 2:
        selection = np.random.choice([i for i in range(1, n - 1)], size=size)
    else:
        # Fail loudly instead of falling through to an unbound `selection`.
        raise ValueError("ME must be 1 or 2, got %r" % (ME,))

    # One-hot rows: data point k has the row of its selected message set to 1.
    data = np.zeros((size, n, n), dtype=np.float32)
    data[np.arange(size), selection, :] = 1.
    labels = tf.one_hot(selection, depth=n)

    return data, labels

In [None]:
# GENERATE BILINGUAL INPUT MESSAGES 

def generate_bilingual_messages(size, n, ME=1, n_blocks=0, lang_edit=0):
    """ Generates the message input for a bilingual listener.

        There are n states and 2*n messages: messages 0..n-1 form language 1 and
        messages n..2n-1 form language 2; message m is labelled with state m mod n.

        :param size:        number of data points
        :param n:           number of states in the system (n_messages = 2*n)
        :param ME:          number of messages per language withheld from training, here 1 or 2
                            (ME=1 withholds the last message of each language, ME=2 also the first)
        :param n_blocks:    number of language blocks (>=1), or not blocked (0, i.e. random).
                            Blocks have length size // n_blocks and alternate language 1 / language 2.
        :param lang_edit:   0 or 1; with 1, every message is shifted into language 2.
                            Only valid together with n_blocks=1 (otherwise messages would
                            fall outside the 2*n available messages).
        :return data:       training data, float32 array of shape (size, 2*n, n)
        :return labels:     training labels, one-hot tensor of shape (size, n)
        :raises ValueError: for unsupported ME, n_blocks or lang_edit values
    """
    if ME not in (1, 2):
        raise ValueError("ME must be 1 or 2, got %r" % (ME,))
    if n_blocks < 0 or (size > 0 and n_blocks > size):
        # n_blocks > size would make the block length 0 (division by zero below).
        raise ValueError("n_blocks must be in [0, size], got %r" % (n_blocks,))
    if lang_edit not in (0, 1):
        raise ValueError("lang_edit must be 0 or 1, got %r" % (lang_edit,))
    if lang_edit == 1 and n_blocks != 1:
        raise ValueError("lang_edit=1 requires n_blocks=1")

    if n_blocks == 0:
        # Random mixing: draw from both languages, skipping the withheld messages.
        if ME == 1:
            candidates = [i for i in range(0, 2 * n - 1) if i != n - 1]
        else:
            candidates = [i for i in range(1, 2 * n - 1) if i not in (n - 1, n)]
        selection = np.random.choice(candidates, size=size)
    else:
        # Draw from language 1, then move every odd-numbered block into language 2.
        candidates = list(range(0 if ME == 1 else 1, n - 1))
        selection = np.random.choice(candidates, size=size)
        block_len = size // n_blocks
        if block_len > 0:
            odd_block = (np.arange(size) // block_len) % 2 == 1
            selection = np.where(odd_block, selection + n, selection)

    selection = selection + n * lang_edit

    data = np.zeros((size, 2 * n, n), dtype=np.float32)
    data[np.arange(size), selection, :] = 1.
    # Both languages share the same state space, hence the modulo.
    labels = tf.one_hot(tf.math.floormod(selection, n), depth=n)

    return data, labels

In [None]:
# TRAIN A LISTENER AND SAVE THE PARAMETERS AS WELL AS THE REWARDS AND THE LEXICA

def _within_block_permutation(size, n_blocks):
    """ Returns a permutation of range(size) that only reorders indices inside each
        contiguous block of length size // n_blocks (a shorter remainder block at the
        end is shuffled on its own), so a blocked language layout survives shuffling.
    """
    block_len = size // n_blocks
    pieces = [np.random.permutation(min(block_len, size - start)) + start
              for start in range(0, size, block_len)]
    return np.concatenate(pieces)


def run_listener(n=10, reasoning=0, n_epochs=100, ME=1, learning_rate=0.001, runs=100, blocked=0, n_blocks=1):
    """ Trains the listener on a single agent Lewis game as described in Experiment 1, saving all lexica and rewards for every epoch.
        :param n:               number of states; the bilingual system has 2*n messages
        :param reasoning:       reasoning level of the listener, 0 for literal, 1 for pragmatic 
        :param n_epochs:        number of training epochs per run
        :param ME:              number of messages per language that are left out during training, here 1 or 2
        :param learning_rate:   learning rate for the Adam optimizer
        :param runs:            number of independently initialised listeners to train
        :param blocked:         0: not blocked; 1: blocked within epochs; 2: blocked across epochs
        :param n_blocks:        number of blocks; ignored if blocked == 0
        :raises ValueError:     for unsupported reasoning / blocked / n_blocks values
                                (checked before anything is written to disk)
    """
    if reasoning not in (0, 1):
        raise ValueError("reasoning must be 0 or 1, got %r" % (reasoning,))
    if blocked not in (0, 1, 2):
        raise ValueError("blocked must be 0, 1 or 2, got %r" % (blocked,))

    # setup the training and save the parameters 

    n_states = n                      # number of states 
    n_messages = 2*n                  # number of messages (two languages)
    batch_size = 32                   # batch size
    datasize = 1000                   # number of training data points
    batches = datasize // batch_size  # batches per epoch; the last datasize % batch_size points are skipped
    init_mean = 0.5                   # mean for initialization of lexicon entries
    init_std = 0.01                   # std for initialization of lexicon entries

    if blocked == 1 and not 1 <= n_blocks <= datasize:
        raise ValueError("blocked=1 needs 1 <= n_blocks <= %d, got %r" % (datasize, n_blocks))
    if blocked == 2 and not 1 <= n_blocks <= n_epochs:
        raise ValueError("blocked=2 needs 1 <= n_blocks <= n_epochs, got %r" % (n_blocks,))

    constraint = tf.keras.constraints.NonNeg() # constrains the lexica to have entries >= 0

    b_str = '_blocked' if blocked == 1 else '_blocked_e' if blocked == 2 else ''
    bn_str = str(n_blocks) + '_blocks/' if blocked > 0 else ''
    filename = 'data/bilingual' + b_str + '/L' + str(reasoning) + '/' + str(n) + '_states/' + bn_str
    os.makedirs(filename, exist_ok=True)

    param_dict = {
        "n_states": n_states, "n_messages": n_messages, "n_epochs": n_epochs, "batch_size": batch_size,
        "datasize": datasize, "initializer_truncated_normal_mean_std": [init_mean, init_std], 
        "learning_rate": learning_rate, "runs": runs, "constraint": constraint, 
        "blocked": blocked, "n_blocks": n_blocks}    
    with open(filename + 'param_dict.pickle', 'wb') as handle:
        pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if reasoning == 0:
        filename_full = filename + 'L' + str(reasoning) + '_' + str(ME) + 'missing_'
    else:
        filename_full = filename + 'L' + str(reasoning) + '_' + str(ME) + 'missing_5.0alpha_'

    # n_blocks only applies to within-epoch blocking. Unblocked runs must request
    # random mixing (n_blocks=0); passing the default n_blocks=1 would put every
    # message in language 1.
    data_blocks = n_blocks if blocked == 1 else 0
    # blocked == 2: the language switches every `epoch_block` epochs.
    epoch_block = n_epochs // n_blocks if blocked == 2 else None

    # run the listeners

    for run in range(1, runs + 1):

        # create data 
        if blocked < 2:
            data, labels = generate_bilingual_messages(datasize, n, ME=ME, n_blocks=data_blocks)
            data = tf.convert_to_tensor(data)
        else:
            data1, labels1 = generate_bilingual_messages(datasize, n, ME=ME, n_blocks=1)
            data2, labels2 = generate_bilingual_messages(datasize, n, ME=ME, n_blocks=1, lang_edit=1)
            data_all = [tf.convert_to_tensor(data1), tf.convert_to_tensor(data2)]
            labels_all = [labels1, labels2]
        lexica = []
        all_rewards = []

        # create listener
        lexicon = tf.Variable(
            tf.initializers.TruncatedNormal(mean=init_mean, stddev=init_std)
            ([n_states, n_messages]),
            name="lexicon",
            trainable=True,
            dtype=tf.float32,
            constraint=constraint)

        if reasoning == 0: 
            listener = RSAListener0(n_states, n_messages, lexicon)
        else:
            listener = RSAListener1(n_states, n_messages, lexicon, alpha=5.)

        listener.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                         loss=tf.keras.losses.CategoricalCrossentropy())

        # train
        mean_reward = float('nan')  # reported as-is if n_epochs == 0
        for i in range(n_epochs):
            average_reward = []

            # shuffle data and labels
            if blocked == 2:
                language = (i // epoch_block) % 2
                data = data_all[language]
                labels = labels_all[language]
                shuffle_indices = np.random.permutation(datasize)
            elif blocked == 1:
                # keep the language blocks in place, shuffle only inside each block
                shuffle_indices = _within_block_permutation(datasize, n_blocks)
            else:
                shuffle_indices = np.random.permutation(datasize)
            data = tf.gather(data, shuffle_indices)
            labels = tf.gather(labels, shuffle_indices)

            for j in range(batches):
                data_batch = data[j*batch_size:(j+1)*batch_size]
                labels_batch = labels[j*batch_size:(j+1)*batch_size]

                _, actions = listener.get_states(data_batch)

                # reward = dot product of one-hot label and the listener's action
                rewards = tf.einsum('ij,ij->i', labels_batch, actions)
                average_reward.append(np.mean(rewards))

                # RL: 
                # Note that we implemented REINFORCE with a work-around using categorical crossentropy. 
                # This can be done by setting the labels to the agent's actions, and weighting the loss
                # function by the rewards. 
                listener.train_on_batch(data_batch, actions, sample_weight=rewards)

            mean_reward = np.mean(average_reward)
            all_rewards.append(mean_reward)
            lexica.append(np.copy(listener.lexicon[:]))

        print('run ' + str(run), 'average reward ' + str(ME) + ' ' + str(mean_reward))

        # save rewards and lexica 
        np.save(filename_full + 'lexicon_run' + str(run), lexica)
        np.save(filename_full + 'rewards_run' + str(run), all_rewards)

## Run the training 

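for different agent types (literal and pragmatic listeners), numbers of states (mainly 3 and 10, plus sweeps over larger systems), numbers of messages withheld from training (1 and 2), and blocked versus unblocked presentation of the two languages.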

### Literal listener

In [None]:
# Literal listener (reasoning=0), 3 states, ME=1, 50 epochs.
run_listener(n=3, ME=1, reasoning=0, n_epochs=50, learning_rate=0.001)

In [None]:
# Literal listener (reasoning=0), 3 states, ME=2, 50 epochs.
run_listener(n=3, ME=2, reasoning=0, n_epochs=50, learning_rate=0.001)

In [None]:
# Literal listener (reasoning=0), 10 states, ME=1, 100 epochs.
run_listener(n=10, ME=1, reasoning=0, n_epochs=100, learning_rate=0.001)

In [None]:
# Literal listener (reasoning=0), 10 states, ME=2, 100 epochs.
run_listener(n=10, ME=2, reasoning=0, n_epochs=100, learning_rate=0.001)

In [None]:
# Literal listener (reasoning=0), 100 states, ME=1, 100 epochs.
run_listener(n=100, ME=1, reasoning=0, n_epochs=100, learning_rate=0.001)

### Pragmatic listener

In [None]:
# Pragmatic listener (reasoning=1), 3 states, ME=1, 100 epochs.
run_listener(n=3, ME=1, reasoning=1, n_epochs=100, learning_rate=0.001)

In [None]:
# Pragmatic listener (reasoning=1), 3 states, ME=2, 50 epochs.
run_listener(n=3, ME=2, reasoning=1, n_epochs=50, learning_rate=0.001)

In [None]:
# Pragmatic listener (reasoning=1), ME=1, one run set per system size of 4 to 10 states.
for n_states in range(4, 11):
    run_listener(n=n_states, ME=1, reasoning=1, n_epochs=100, learning_rate=0.001)

In [None]:
# Pragmatic listener (reasoning=1), ME=1, system sizes 20, 30, 40 and 50 states.
for n_states in (20, 30, 40, 50):
    run_listener(n=n_states, ME=1, reasoning=1, n_epochs=100, learning_rate=0.001)

In [None]:
# Pragmatic listener (reasoning=1), ME=1, system sizes 60, 80 and 100 states.
for n_states in (60, 80, 100):
    run_listener(n=n_states, ME=1, reasoning=1, n_epochs=100, learning_rate=0.001)

In [None]:
# Pragmatic listener (reasoning=1), 10 states, ME=2, 100 epochs.
run_listener(n=10, ME=2, reasoning=1, n_epochs=100, learning_rate=0.001)

In [None]:
# Pragmatic listener, 10 states, ME=1; languages blocked within each epoch (2 blocks).
run_listener(n=10, ME=1, reasoning=1, n_epochs=100, learning_rate=0.001,
             blocked=1, n_blocks=2)

In [None]:
# Pragmatic listener, 10 states, ME=1; languages blocked across epochs with 2, 4, 8 and 16 blocks.
for block_count in (2, 4, 8, 16):
    run_listener(n=10, ME=1, reasoning=1, n_epochs=100, learning_rate=0.001,
                 blocked=2, n_blocks=block_count)