In [4]:
# IMPORTS

import tensorflow as tf
import numpy as np
import os
import pickle

from agents import PragmaticListener, PragmaticSpeaker, LiteralListener
from data_helpers import generate_inputs

In [9]:
# TRAIN THE AGENT AND SAVE PARAMETERS + RESULTS

def _make_listener(reasoning, n_states, n_messages, lexicon, alpha):
    """Build a pragmatic (reasoning == 1) or literal (reasoning == 0) listener."""
    if reasoning == 1:
        return PragmaticListener(n_states, n_messages, lexicon, alpha=alpha)
    return LiteralListener(n_states, n_messages, lexicon)


def _compile_listener(listener, learning_rate):
    """Compile `listener` with a fresh SGD optimizer and categorical cross-entropy.

    A new optimizer is created on every call because the listener (and its
    lexicon variable) is rebuilt whenever the lexicon grows; recent Keras
    optimizers refuse to update variables they were not originally built with.
    Plain SGD keeps no per-variable state, so this does not change training.
    """
    listener.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
                     loss=tf.keras.losses.CategoricalCrossentropy())


def _expanded_lexicon(old_lexicon, fill_value, size, plot=False):
    """Return a non-negative trainable `size x size` lexicon embedding `old_lexicon`.

    The old values fill the top-left `(size-1) x (size-1)` block; the new last
    row and column are set to `fill_value`. With `plot=True` the new lexicon is
    shown with matplotlib (imported lazily, so it is only needed when plotting).
    """
    values = fill_value * np.ones((size, size), dtype=np.float32)
    values[:-1, :-1] = old_lexicon
    if plot:
        import matplotlib.pyplot as plt
        plt.imshow(values)
        plt.colorbar()
        plt.show()
    return tf.Variable(values, dtype=tf.float32, constraint=tf.keras.constraints.NonNeg())


def _lewis_policy(listener, n_states):
    """Context-free (Lewis game) evaluation: the listener's policy for the newest message.

    The policy is renormalized so that it sums to one along axis 1.
    """
    new_message = np.zeros((1, n_states, n_states))
    new_message[:, -1, :] = 1.
    policy, _ = listener.get_states(new_message)
    return policy / tf.expand_dims(tf.reduce_sum(policy, axis=1), axis=1)


def _reference_policies(old_lexicon, old_lexicon_mean, reasoning, alpha, n_old, n_slots):
    """Context-dependent (reference game) evaluation with ONE distractor.

    For every old state j, messages are sampled from a test speaker built on
    `old_lexicon`. A 2-state test listener then sees those messages for state j
    (row 0) plus the new, unseen message (filled with `old_lexicon_mean`). The
    returned array holds, per old state, the probability that listener assigns
    to the last state when given the new message. Slots >= `n_old` stay NaN so
    every epoch's array has the same length `n_slots`.
    """
    policies_state = np.full(n_slots, np.nan)

    # The speaker depends only on the old lexicon, so build it once rather than per state.
    if reasoning == 1:
        test_speaker = PragmaticSpeaker(n_old, n_old, old_lexicon, alpha=alpha)
    else:
        # NOTE(review): the literal case still uses PragmaticSpeaker (with its default
        # alpha), as in the original code -- confirm this is intended.
        test_speaker = PragmaticSpeaker(n_old, n_old, old_lexicon)

    for j in range(n_old):
        relevant_row = old_lexicon[j, :]

        speaker_target = np.zeros((1, n_old, n_old), dtype=np.float32)
        speaker_target[0, j, :] = 1.
        policy_speaker, _ = test_speaker.get_messages(speaker_target)
        policy_speaker = policy_speaker / tf.reduce_sum(policy_speaker)
        relevant_messages = np.unique(
            np.random.choice(n_old, size=25, p=np.array(policy_speaker)[0]))

        n_relevant = len(relevant_messages)
        dim1 = 2
        dim2 = n_relevant + 1
        test_lexicon = old_lexicon_mean * np.ones((dim1, dim2))
        test_lexicon[0, :n_relevant] = relevant_row[relevant_messages]
        test_listener = _make_listener(reasoning, dim1, dim2, test_lexicon, alpha)

        target = np.zeros((1, dim2, dim1))
        target[0, dim2 - 1, :] = 1
        p, _ = test_listener.get_states(target)
        p = p / tf.reduce_sum(p)
        # batch size is 1, so take the single scalar (avoids array->scalar assignment)
        policies_state[j] = np.asarray(p)[0, -1]

    return policies_state


def _train_epoch(listener, data, labels, selection, batch_size, n_batches,
                 train_mode, has_occurred, counts_epoch):
    """Train `listener` on `n_batches` disjoint mini-batches and return the mean reward.

    Updates in place:
      * `has_occurred[s] = 1` for every target state s seen in a batch;
      * `counts_epoch[s, k]` (k = 0: correct, 1: false, 2: potentially correct),
        where "false" means a wrong choice of an already-seen state and
        "potentially correct" a wrong choice of a never-seen state.
    """
    rewards_epoch = []
    for b in range(n_batches):
        start, stop = b * batch_size, (b + 1) * batch_size
        data_batch = data[start:stop]
        labels_batch = labels[start:stop]
        selection_batch = np.array(selection[start:stop])

        _, states = listener.get_states(data_batch)
        rewards = tf.einsum('ij,ij->i', tf.cast(labels_batch, tf.float32), states)
        if train_mode == 'RL':
            # fit the listener's own choices, weighted by the reward they earned
            listener.train_on_batch(data_batch, states, sample_weight=rewards)
        else:  # 'supervised'
            listener.train_on_batch(data_batch, labels_batch)
        rewards_epoch.append(np.mean(rewards))

        chosen = np.argmax(states, axis=1)
        correct = chosen == selection_batch
        outcome_masks = (correct,
                         ~correct & (has_occurred[chosen] == 1),
                         ~correct & (has_occurred[chosen] == 0))
        for column, mask in enumerate(outcome_masks):
            unique, occurrences = np.unique(selection_batch[mask], return_counts=True)
            counts_epoch[unique, column] += occurrences

        has_occurred[selection_batch] = 1

    return np.mean(rewards_epoch)


def _as_object_array(arrays):
    """Pack arrays of differing shapes into a 1-D object array for `np.save`.

    NumPy >= 1.24 refuses to build an array from a ragged list, so `np.save(path, list)`
    would raise. The object array is pickled by `np.save`; load it with
    `np.load(path, allow_pickle=True)`.
    """
    packed = np.empty(len(arrays), dtype=object)
    for k, array in enumerate(arrays):
        packed[k] = np.asarray(array)
    return packed


def run_agent(n=100, 
              reasoning=1, 
              n_epochs=1000, 
              input_dist='Zipf',
              learning_rate=0.1,
              alpha=5.,
              iter_step=10,
              train_mode='RL', 
              batch_size = 32, 
              data_size = 1000,
              init_mean = 0.1,
              runs = 100,
              plot_lexicon=False):
    """Train listeners with a growing (dynamic) lexicon and save parameters + results.

    Each run starts a listener with a 2 x 2 lexicon filled with `init_mean`.
    From epoch `2 * iter_step` on, every `iter_step` epochs the lexicon gains one
    state and one message (until it is `n x n`). The new row/column is filled
    with the mean of the old lexicon. Right after each expansion, the listener is
    evaluated in a context-free (Lewis) game and a context-dependent (reference)
    game. Every epoch it is trained on a freshly generated, shuffled data set.

    Args:
        n: maximum number of states/messages; also the size of the count arrays.
        reasoning: 1 for pragmatic agents, 0 for literal listeners.
        n_epochs: training epochs per run.
        input_dist: distribution name forwarded to `generate_inputs`.
        learning_rate: SGD learning rate.
        alpha: rationality parameter of the pragmatic agents.
        iter_step: epochs between two lexicon expansions.
        train_mode: 'RL' (reward-weighted fit on the listener's own choices) or
            'supervised' (fit on the labels).
        batch_size: mini-batch size.
        data_size: samples generated per epoch (only full batches are used).
        init_mean: initial value of every entry of the 2 x 2 lexicon.
        runs: number of independent runs.
        plot_lexicon: if True, show each expanded lexicon with matplotlib (debug aid).

    Side effects:
        Writes `param_dict.pickle` and, per run, `policies_ref_single_<run>.npy`,
        `policies_lewis_<run>.npy` (object array, needs `allow_pickle=True`),
        `rewards_<run>.npy` and `counts_<run>.npy` into
        `results_dynamic_lexicon/interval_<iter_step>_alpha_<alpha>/`.

    Raises:
        ValueError: for an unsupported `reasoning` / `train_mode` or invalid sizes.
    """
    n_start = 2  # initial lexicon size

    # validate before touching the file system or building any model
    if reasoning not in (0, 1):
        raise ValueError(f"reasoning must be 0 or 1, got {reasoning!r}")
    if train_mode not in ('RL', 'supervised'):
        raise ValueError(f"train_mode must be 'RL' or 'supervised', got {train_mode!r}")
    if iter_step < 1:
        raise ValueError(f"iter_step must be >= 1, got {iter_step!r}")
    if n < n_start:
        raise ValueError(f"n must be >= {n_start}, got {n!r}")
    if batch_size < 1 or data_size < batch_size:
        raise ValueError("need 1 <= batch_size <= data_size, got "
                         f"batch_size={batch_size}, data_size={data_size}")

    ### setup the training and save the parameters ###

    batches = data_size // batch_size  # number of batches per epoch

    filename = f'results_dynamic_lexicon/interval_{iter_step}_alpha_{alpha}/'
    os.makedirs(filename, exist_ok=True)

    param_dict = {"n_states": n,
                  "n_messages": n, 
                  "n_epochs": n_epochs, 
                  "batch_size": batch_size,
                  "data_size": data_size, 
                  "init_mean": init_mean,
                  "learning_rate": learning_rate, 
                  "runs": runs, 
                  "alpha": alpha,
                  "iter_step": iter_step,
                  "reasoning": reasoning,
                  "input_dist": input_dist, 
                  "train_mode": train_mode,
                  "agent_mode": "single"}
    with open(filename + 'param_dict.pickle', 'wb') as handle:
        pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ### iterate over runs ###

    for run in range(runs):

        print('interval: ', iter_step, ', run: ', run)

        ### build the agent ###

        lexicon = tf.Variable(tf.initializers.Constant(init_mean)([n_start, n_start]),
                              name="lexicon", 
                              trainable=True, 
                              dtype=tf.float32,
                              constraint=tf.keras.constraints.NonNeg())
        listener = _make_listener(reasoning, n_start, n_start, lexicon, alpha)
        _compile_listener(listener, learning_rate)

        ### train the agent ###

        # save the following results: 
        # rewards per epoch --> all_rewards 
        # policies in the context-dependent evaluation per epoch --> policies_all_reference 
        # policies in the context-free evaluation per epoch --> policies_all_lewis
        # number of correct, false, and potentially correct selections per epoch --> counts
        # (potentially correct is not relevant here, only for the fixed lexicon)

        all_rewards = []
        policies_all_reference = []
        policies_all_lewis = []
        number_of_states = n_start
        has_occurred = np.zeros(n)
        counts = np.zeros((n_epochs, n, 3))  # 0: correct, 1: false, 2: potentially correct

        for i in range(n_epochs):

            # add a new state/message every iter_step epochs and expand the lexicon accordingly
            if i % iter_step == 0 and i >= n_start * iter_step and number_of_states < n:
                number_of_states += 1

                old_lexicon = np.copy(listener.lexicon[:])
                old_lexicon_mean = np.mean(old_lexicon)
                new_lexicon = _expanded_lexicon(old_lexicon, old_lexicon_mean,
                                                number_of_states, plot=plot_lexicon)
                listener = _make_listener(reasoning, number_of_states, number_of_states,
                                          new_lexicon, alpha)
                _compile_listener(listener, learning_rate)

                # lewis game evaluation = context-free evaluation
                policies_all_lewis.append(_lewis_policy(listener, number_of_states))

                # reference game evaluation = context-dependent evaluation
                policies_all_reference.append(
                    _reference_policies(old_lexicon, old_lexicon_mean, reasoning, alpha,
                                        n_old=number_of_states - 1, n_slots=n - 1))

            # generate random data set
            data, selection, labels = generate_inputs(data_size, 
                                                      number_of_states, 
                                                      number_of_states, 
                                                      distribution=input_dist)
            shuffle_indices = np.random.permutation(data_size)
            data = tf.gather(data, shuffle_indices)
            labels = tf.gather(labels, shuffle_indices)
            selection = tf.gather(selection, shuffle_indices)

            all_rewards.append(_train_epoch(listener, data, labels, selection,
                                            batch_size, batches, train_mode,
                                            has_occurred, counts[i]))

        if all_rewards:
            print('final reward ' + str(all_rewards[-1]))

        # save results 
        np.save(filename + 'policies_ref_single_' + str(run), policies_all_reference) # single because of ONE distractor
        np.save(filename + 'policies_lewis_' + str(run), _as_object_array(policies_all_lewis))
        np.save(filename + 'rewards_' + str(run), all_rewards)
        np.save(filename + 'counts_' + str(run), counts)

In [None]:
# Sweep over lexicon-growth intervals. The densest schedule (a new state every
# epoch, iter_step == 1) trains 1000 epochs longer and uses 500 runs; every other
# schedule trains for iter_step * 100 epochs with run_agent's default number of runs.
for iter_step in (1, 3, 6, 9, 12, 15):
    sweep_kwargs = dict(iter_step=iter_step, n_epochs=100 * iter_step)
    if iter_step == 1:
        sweep_kwargs.update(n_epochs=sweep_kwargs['n_epochs'] + 1000, runs=500)
    run_agent(**sweep_kwargs)