# Task & Talk
Full reimplementation of the task in [Kottur, et al. (2017)](https://arxiv.org/pdf/1706.08502.pdf). Using tabular Q-learning for now.

## Initialization and Housekeeping
Set parameters, initialize the Q-tables, and describe any helper functions we might need for post-analysis.
### Parameters

In [181]:
num_episodes = 500000
eta = 0.8
gamma = 0.95

epsilon = 1.0
max_epsilon = 1.0
min_epsilon = 0.01
decay_rate = 0.0001;

### Setting up Q-Table

In [182]:
q_vocab = 3
a_vocab = 12

num_tasks = 6
num_attributes = 4;

In [183]:
# A-State: [att * att, q_vocab, a_vocab+1, q_vocab+1] (+1 for empty vocab)
# Q-State (Utterance): [task, q_vocab+1, a_vocab+1] (+1 for empty vocab)
# Q-State (Guess): [task, q_vocab, a_vocab, q_vocab, a_vocab]
a_table = zeros(num_attributes, num_attributes, num_attributes, q_vocab, a_vocab+1, q_vocab+1, a_vocab)
q_table_utt = zeros(num_tasks, q_vocab+1, a_vocab+1, q_vocab)
q_table_guess = zeros(num_tasks, q_vocab, a_vocab, q_vocab, a_vocab, num_attributes, num_attributes)

a_visited = falses(num_attributes, num_attributes, num_attributes, q_vocab, a_vocab+1, q_vocab+1)
q_utt_visited = falses(num_tasks, q_vocab+1, a_vocab+1)
q_guess_visited = falses(num_tasks, q_vocab, a_vocab, q_vocab, a_vocab)

num_correct = 0;

## Accuracy measurements

In [184]:
num_exploits = 0;

### Helper Functions

In [185]:
function get_reward(a_state, guess, num_task)
    reward = -1
    if (num_task == 1)
        if (a_state[1:2] == guess)
            reward = 1
        end
    elseif (num_task == 2)
        if (a_state[2:3] == guess)
            reward = 1
        end
    elseif (num_task == 3)
        if ([a_state[1], a_state[3]] == guess)
            reward = 1
        end
    elseif (num_task == 4)
        if ([a_state[2], a_state[1]] == guess)
            reward = 1
        end
    elseif (num_task == 5)
        if ([a_state[3], a_state[2]] == guess)
            reward = 1
        end
    else
        if ([a_state[3], a_state[1]] == guess)
            reward = 1
        end
    end
    
    # Return 1 if bool
    return reward
end

get_reward (generic function with 2 methods)

## Q-Learning
The core loop of the program:

In [186]:
total_rewards = 0
for episode in 1:num_episodes
    tradeoff = rand() # exploration-exploitation
    explore = (tradeoff < epsilon)
    
    # Generate random task:
    q_state = [rand(1:num_tasks), q_vocab+1, a_vocab+1]
    q_states = [copy(q_state)]
    
    
    # Q-BOT
    #  -> TURN 1
    if explore | ~q_utt_visited[q_state...]
        # Explore:
        q_state[2] = rand(1:q_vocab)
    else
        # Exploit:
        options = q_table_utt[q_state[1], q_state[2], q_state[3], :]
        q_state[2] = argmax(options)
    end
    
    # Generate random object:
    a_state = [rand(1:num_attributes), rand(1:num_attributes), rand(1:num_attributes), q_state[2], a_vocab+1, q_vocab+1]
    a_states = [copy(a_state)]
    
    
    # A-BOT
    #  -> TURN 1
    if explore | ~a_visited[a_state...]
        # Explore:
        a_state[5] = rand(1:a_vocab)
    else
        # Exploit:
        options = a_table[a_state[1], a_state[2], a_state[3], a_state[4], a_state[5], a_state[6], :]
        a_state[5] = argmax(options)
    end
    
    q_state[3] = a_state[5]
    push!(q_states, copy(q_state))
    
    
    # Q-BOT
    #  -> TURN 2
    if explore | ~q_utt_visited[q_state...]
        # Explore:
        a_state[6] = rand(1:q_vocab)
    else
        # Exploit:
        options = q_table_utt[q_state[1], q_state[2], q_state[3], :]
        a_state[6] = argmax(options)
    end
    push!(a_states, copy(a_state))
    
    # Update Q-State for guessing attributes:
    q_state = [q_state[1], q_state[2], q_state[3], a_state[6], a_vocab+1]
    
    
    # A-BOT
    #  -> TURN 2
    if explore | ~a_visited[a_state...]
        # Explore:
        q_state[5] = rand(1:a_vocab)
    else
        # Exploit:
        options = a_table[a_state[1], a_state[2], a_state[3], a_state[4], a_state[5], a_state[6], :]
        q_state[5] = argmax(options)
    end
    
    
    # Q-BOT
    #  -> GUESSING PHASE
    guess = []
    if explore | ~q_guess_visited[q_state...]
        # Explore:
        guess = [rand(1:num_attributes), rand(1:num_attributes)]
    else
        # Exploit:
        options = q_table_guess[q_state[1], q_state[2], q_state[3], q_state[4], q_state[5], :, :]
        optimal_first = num_attributes + 1 # Running best attribute
        optimal_second = num_attributes + 1
        max_val = -99999
        for att1 in 1:num_attributes
            for att2 in 1:num_attributes
                if (options[att1, att2] > max_val)
                    max_val = options[att1, att2]
                    optimal_first = att1
                    optimal_second = att2
                end
            end
        end
        guess = [optimal_first, optimal_second]
    end
    
    # Update Reward Tables:
    reward = get_reward(a_state, guess, q_state[1])
    
    a_table[a_state[1],a_state[2],a_state[3],a_state[4],a_vocab+1, q_vocab+1, q_state[3]] += reward
    a_visited[a_state[1],a_state[2],a_state[3],a_state[4],a_vocab+1, q_vocab+1] = true
    a_table[a_state[1],a_state[2],a_state[3],a_state[4],a_state[5],a_state[6], q_state[5]] += reward
    a_visited[a_state[1],a_state[2],a_state[3],a_state[4],a_state[5],a_state[6]] = true
    
    q_table_utt[q_state[1],q_vocab+1, a_vocab+1, q_state[2]] += reward
    q_utt_visited[q_state[1],q_vocab+1, a_vocab+1] = true
    q_table_utt[q_state[1],q_state[2],q_state[3], q_state[4]] += reward
    q_utt_visited[q_state[1],q_state[2],q_state[3]] = true
    
    q_table_guess[q_state[1],q_state[2],q_state[3],q_state[4],q_state[5], guess[1], guess[2]] += reward
    q_guess_visited[q_state[1],q_state[2],q_state[3],q_state[4],q_state[5]] = true

    epsilon = min_epsilon + (max_epsilon - min_epsilon)*exp(-decay_rate*episode)
    
#     if (episode % 1000 == 0)
#         println(episode, " | ", epsilon)
#     end
    
    if (~explore) & (episode > 250000)
        total_rewards += reward
        num_exploits += 1
    end
end
println("Accuracy: ", total_rewards/num_exploits)

Accuracy: 0.9764066142279428
