In [1]:
# IMPORTS

import tensorflow as tf
import numpy as np
import os
import pickle

from agents import PragmaticListener, PragmaticSpeaker, LiteralListener
from data_helpers import generate_inputs

In [2]:
# TRAIN THE AGENT AND SAVE PARAMETERS + RESULTS

def run_agent(n=100, 
              reasoning=1, # 0 = literal learning, 1=pragmatic learning
              n_epochs=1000, 
              input_dist='Zipf',
              learning_rate=0.1,
              alpha=5.,
              iter_step=10,
              train_mode='RL', 
              batch_size = 32, 
              data_size = 1000,
              init_mean = 0.01,
              runs = 100, 
              samples = 25,
              inference = 'pragmatic' # 'literal'=literal inference, 'pragmatic'=pragmatic inference
            ):
    """Train a listener on an incrementally growing set of states and save the results.

    For each of ``runs`` independent runs, a fresh listener is built on an n x n
    lexicon initialised to ``init_mean``. Every ``iter_step`` epochs one more
    state becomes available in the generated data, until all ``n`` are in play.
    Whenever a new state is added, the current lexicon is evaluated with
    literal or pragmatic inference (see ``inference``):

    * context-free ("Lewis game"): the listener policy for the new message;
    * context-dependent ("reference game"): for every previously added state j,
      the probability of choosing the new state when the new message is heard
      in a two-object context {j, new state}.

    Parameters
    ----------
    n : int
        Number of states and number of messages (the lexicon is n x n).
    reasoning : {0, 1}
        0 = literal learning (LiteralListener), 1 = pragmatic learning
        (PragmaticListener).
    n_epochs : int
        Number of training epochs per run.
    input_dist : str
        Distribution name forwarded to ``generate_inputs`` (e.g. 'Zipf').
    learning_rate : float
        SGD learning rate.
    alpha : float
        Parameter passed to the pragmatic agents.
    iter_step : int
        Number of epochs between the introduction of two new states (>= 1).
    train_mode : {'RL', 'supervised'}
        'RL' trains on the listener's own selected states weighted by their
        reward; 'supervised' trains on the true labels.
    batch_size, data_size : int
        Mini-batch size and number of samples generated per epoch.
    init_mean : float
        Constant initial value of every lexicon entry.
    runs : int
        Number of independent runs.
    samples : int
        Number of messages sampled from the speaker policy to build the message
        set in the reference-game evaluation.
    inference : {'pragmatic', 'literal'}
        Inference used in both evaluations.

    Side effects
    ------------
    Creates the results directory if needed and writes ``param_dict.pickle``.
    For every run it also writes ``counts_<run>.npy``, ``rewards_<run>.npy``,
    ``policies_lewis_<run>.npy`` and ``policies_reference_<run>.npy``.

    Raises
    ------
    ValueError
        If ``reasoning``, ``inference``, ``train_mode`` or ``iter_step`` is invalid.
    """

    ### validate the configuration ###
    # Without these checks, bad values failed late (undefined names, modulo by
    # zero) or, for train_mode, silently skipped all training.
    if reasoning not in (0, 1):
        raise ValueError('reasoning must be 0 (literal) or 1 (pragmatic), got %r' % (reasoning,))
    if inference not in ('pragmatic', 'literal'):
        raise ValueError("inference must be 'pragmatic' or 'literal', got %r" % (inference,))
    if train_mode not in ('RL', 'supervised'):
        raise ValueError("train_mode must be 'RL' or 'supervised', got %r" % (train_mode,))
    if iter_step < 1:
        raise ValueError('iter_step must be >= 1, got %r' % (iter_step,))

    ### setup the training and save the parameters ###

    batches = data_size // batch_size  # number of batches per epoch

    interval_dir = 'interval_' + str(iter_step) + '_alpha_' + str(alpha) + '/'
    if inference == 'pragmatic' and reasoning == 1:
        # the main condition (pragmatic learning + pragmatic inference) is stored at the top level
        filename = 'results_fixed_lexicon/' + interval_dir
    else:
        filename = ('results_fixed_lexicon/' + ['literal', 'pragmatic'][reasoning] + '-learning_' +
                    inference + '-inference/' + interval_dir)

    os.makedirs(filename, exist_ok=True)

    param_dict = {"n_states": n,
                  "n_messages": n, 
                  "n_epochs": n_epochs, 
                  "batch_size": batch_size,
                  "data_size": data_size, 
                  "init_mean": init_mean,
                  "learning_rate": learning_rate, 
                  "runs": runs, 
                  "alpha": alpha,
                  "iter_step": iter_step,
                  "reasoning": reasoning,
                  "input_dist": input_dist, 
                  "train_mode": train_mode,
                  "agent_mode": "single", 
                  "samples": samples, 
                  "inference": inference
                 }
    with open(filename + 'param_dict.pickle', 'wb') as handle:
        pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ### iterate over runs ###

    for run in range(runs):

        print('interval: ', iter_step, ', run: ', run)

        # counts[epoch, state, k] with k = 0: correct, 1: false, 2: potentially correct
        counts = np.zeros((n_epochs, n, 3))

        ### build the agent ###

        lexicon = tf.Variable(tf.initializers.Constant(init_mean)([n, n]),
                              name="lexicon", 
                              trainable=True, 
                              dtype=tf.float32,
                              constraint=tf.keras.constraints.NonNeg())

        if reasoning == 0:
            listener = LiteralListener(n, n, lexicon)
        else:
            listener = PragmaticListener(n, n, lexicon, alpha=alpha)

        # A fresh optimizer per run: every run trains a new lexicon variable, and
        # TF >= 2.11 Keras optimizers refuse to update variables other than the
        # ones they were first built with. Plain SGD (no momentum) carries no
        # state that affects the updates, so this does not change training.
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
        listener.compile(optimizer=optimizer, loss=tf.keras.losses.CategoricalCrossentropy())

        ### train the agent ###

        # save the following results:
        # rewards per epoch --> all_rewards
        # policies in the context-dependent evaluation, one entry per newly added state
        #   (from the second state on) --> all_policies_reference
        # policies in the context-free evaluation, one entry per newly added state --> all_policies_lewis
        # number of correct, false, and potentially correct selections per epoch --> counts
        # ("potentially correct" means selecting any object that has not occurred yet,
        #  not necessarily the correct one)

        all_rewards = []
        all_policies_lewis = []
        all_policies_reference = []
        number_of_states = 0
        # 1 for every state that appeared in `selections` of an earlier batch
        has_occurred = np.zeros(n)

        for i in range(n_epochs):

            # add a new state every iter_step epochs until all n states are in play
            if i % iter_step == 0 and number_of_states < n:
                number_of_states += 1

                # lewis game evaluation = context-free evaluation of the new message
                new_message = np.zeros((1, n, n))
                new_message[:, number_of_states-1, :] = 1.

                if inference == 'pragmatic':
                    pragmatic_listener = PragmaticListener(n, n, np.copy(listener.lexicon[:]), alpha=alpha)
                    policy, states = pragmatic_listener.get_states(new_message)
                else:
                    literal_listener = LiteralListener(n, n, np.copy(listener.lexicon[:]))
                    policy, states = literal_listener.get_states(new_message)

                # normalise each row of the policy to sum to one
                policy = policy / tf.expand_dims(tf.reduce_sum(policy, axis=1), axis=1)
                all_policies_lewis.append(policy)

                # reference game evaluation = context-dependent evaluation:
                # for every old state j, the probability that the listener picks
                # the new state on hearing the new message in context {j, new state}
                if number_of_states >= 2:
                    # one slot per possible old state; slots not reached yet stay NaN
                    policies_state = np.full(n - 1, np.nan)

                    old_lexicon = np.copy(listener.lexicon[:])[0:number_of_states-1, 0:number_of_states-1]
                    current_lexicon = np.copy(listener.lexicon[:])[0:number_of_states, 0:number_of_states]

                    for j in range(number_of_states-1):

                        # speaker distribution over the old messages for old state j
                        if inference == 'pragmatic':
                            test_speaker = PragmaticSpeaker(number_of_states-1, number_of_states-1, old_lexicon, alpha=alpha)
                            speaker_target = np.zeros((1, number_of_states-1, number_of_states-1), dtype=np.float32)
                            speaker_target[0,j,:] = np.ones((1, number_of_states-1))
                            policy_speaker, _ = test_speaker.get_messages(speaker_target)
                        else:
                            policy_speaker = [old_lexicon[j,:]]

                        policy_speaker = policy_speaker / tf.reduce_sum(policy_speaker)
                        # messages sampled for state j, plus the new message (appended last)
                        relevant_messages = np.unique(np.random.choice(number_of_states-1, size=samples,
                                                                       p=np.array(policy_speaker)[0]))
                        relevant_messages = np.append(relevant_messages, number_of_states-1)

                        dim2 = len(relevant_messages)
                        # row 0: old state j, row 1: the new state
                        test_lexicon = np.zeros((2, dim2))
                        test_lexicon[0,:] = current_lexicon[j, relevant_messages]
                        test_lexicon[1,:] = current_lexicon[number_of_states-1, relevant_messages]

                        # NOTE(review): n_messages is passed as dim2 + 1 although
                        # test_lexicon and target only have dim2 messages — confirm
                        # this is intended by the agents module.
                        if inference == 'pragmatic':
                            test_listener = PragmaticListener(2, dim2 + 1, test_lexicon, alpha=alpha)
                        else:
                            test_listener = LiteralListener(2, dim2 + 1, test_lexicon)

                        # query with the new message (last entry of relevant_messages)
                        target = np.zeros((1, dim2, 2))
                        target[0, dim2-1, :] = 1
                        p, _ = test_listener.get_states(target)
                        p = p / tf.reduce_sum(p)
                        # probability of the new state (last column); .item() makes
                        # the size-1 -> scalar conversion explicit
                        policies_state[j] = np.asarray(p)[:, -1].item()

                    all_policies_reference.append(policies_state)

            # generate random data set
            data, selections, labels = generate_inputs(data_size, 
                                                       n, 
                                                       number_of_states, 
                                                       distribution=input_dist)
            shuffle_indices = np.random.permutation(data_size)
            data = tf.gather(data, shuffle_indices)
            labels = tf.gather(labels, shuffle_indices)
            selections = tf.gather(selections, shuffle_indices)

            rewards_epoch = []

            # training loop over consecutive, non-overlapping mini-batches
            for j in range(batches):
                start = j * batch_size
                stop = start + batch_size
                data_batch = data[start:stop]
                labels_batch = labels[start:stop]
                selection_batch = np.array(selections[start:stop])

                _, states = listener.get_states(data_batch)
                # index of the selected state (states presumably one-hot — confirm in agents)
                states_non_hot = np.argmax(states, axis=1)
                # per-sample reward: row-wise dot product of label and selected state
                rewards = tf.einsum('ij,ij->i', tf.cast(labels_batch, dtype=tf.float32), states)

                if train_mode == 'RL':
                    loss = listener.train_on_batch(data_batch, states, sample_weight=rewards)
                else:  # 'supervised'
                    loss = listener.train_on_batch(data_batch, labels_batch)

                rewards_epoch.append(np.mean(rewards))

                correct_states = selection_batch[states_non_hot==selection_batch]
                false_states = selection_batch[np.logical_and(states_non_hot!=selection_batch, 
                                                              has_occurred[states_non_hot]==1)]
                potentially_correct_states = selection_batch[np.logical_and(states_non_hot!=selection_batch, 
                                                                            has_occurred[states_non_hot]==0)]
                unique, occurrences = np.unique(correct_states, return_counts=True)
                counts[i, unique, 0] += occurrences 
                unique, occurrences = np.unique(false_states, return_counts=True)
                counts[i, unique, 1] += occurrences 
                unique, occurrences = np.unique(potentially_correct_states, return_counts=True)
                counts[i, unique, 2] += occurrences 

                has_occurred[selection_batch] = 1

            mean_reward = np.mean(rewards_epoch)
            all_rewards.append(mean_reward)

        # reward at the end of the interval in which the last state was introduced,
        # clipped so that runs with n_epochs < iter_step * n do not raise IndexError
        if all_rewards:
            last_interval_epoch = min(iter_step * n, len(all_rewards)) - 1
            print('final reward ' + str(all_rewards[last_interval_epoch]))

        # save results 
        np.save(filename + 'counts_' + str(run), counts)
        np.save(filename + 'rewards_'+ str(run), all_rewards)
        np.save(filename + 'policies_lewis_' + str(run), all_policies_lewis)
        np.save(filename + 'policies_reference_' + str(run), all_policies_reference)

In [None]:
# main results: pragmatic learning + pragmatic inference (the run_agent defaults),
# one experiment per interval between the introduction of new states
for iter_step in [1,3,6,9,12,15]:
    if iter_step != 1:
        run_agent(iter_step=iter_step, n_epochs=iter_step*100)
        continue
    # the shortest interval gets 1000 extra epochs and 500 runs
    run_agent(iter_step=iter_step, n_epochs=iter_step*100 + 1000, runs=500)

In [None]:
# different learning - inference combinations: the two mixed conditions,
# literal learning + pragmatic inference and pragmatic learning + literal inference.
# `reasoning` must be forwarded to run_agent; otherwise every call uses the default
# pragmatic learning, and the (1, 'pragmatic') combination would overwrite the main results.
for iter_step in [1,3,6,9,12,15]:
    for reasoning, inference in [(0, 'pragmatic'), (1, 'literal')]:
        if iter_step == 1: 
            run_agent(iter_step=iter_step, n_epochs=iter_step*100 + 1000, runs=500,
                      reasoning=reasoning, inference=inference)
        else:
            run_agent(iter_step=iter_step, n_epochs=iter_step*100,
                      reasoning=reasoning, inference=inference)