<a href="https://colab.research.google.com/github/WillieCubed/ai-gym-experiments/blob/master/notebooks/knowledge_graphs_and_nlp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# An Exploration of Knowledge Graphs and Natural Language Processing in Reinforcement Learning

Based on "[How to Avoid Being Eaten by a Grue: Structured Exploration Strategies for Textual Worlds](https://arxiv.org/pdf/2006.07409.pdf)" by Prithviraj Ammanabrolu, Ethan Tien, Matthew Hausknecht, and Mark O. Riedl.

## Overview
Text-adventure games like *Zork I* pose several challenges:
1. **Partial observability**: The agent must reason about the world solely through incomplete textual descriptions.
2. **Commonsense reasoning**: The agent needs common sense to interact intelligently with objects in its surroundings.
3. **A combinatorial state-action space**: Most games have action spaces exceeding a billion possible actions per step.

Some text-adventure agents like KG-A2C, TDQN, and DRRN use simple exploration strategies like epsilon-greedy or sampling from the distribution of possible actions.
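As a point of reference, epsilon-greedy exploration can be sketched in a few lines. This is a minimal illustration, not the implementation used by any of the agents named above; the `q_values` dictionary and its action strings are hypothetical.

```python
import random


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, else the highest-valued one."""
    actions = list(q_values)
    if rng.random() < epsilon:
        return rng.choice(actions)          # explore
    return max(actions, key=q_values.get)   # exploit


# Hypothetical action-value estimates for a single game state.
q_values = {"open mailbox": 0.7, "go north": 0.2, "take lamp": 0.1}
greedy_action = epsilon_greedy(q_values, epsilon=0.0)
```

With `epsilon=0.0` the choice is always greedy; larger values trade reward for exploration, but nothing in this rule knows about bottlenecks.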

The paper focuses on detecting and overcoming bottleneck states. Most text-adventure games have linear plots that involve solving sequences of puzzles to advance a story. To solve puzzles, players explore unlocked areas of the game, collect clues, and acquire tools required to solve them. Puzzles can be viewed as bottlenecks that partition different regions of the state space. Existing RL agents are poorly equipped to solve these types of problems.

Quests in text games (and any such sequential decision-making problem requiring long-term dependencies) can be modeled as directed acyclic graphs (DAGs) in which vertices indicate rewards that can be collected or dependencies that must be satisfied to progress. Text-adventure games have two types of dependencies.
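To make the DAG view concrete, here is a small sketch using Python's standard `graphlib` module. The quest steps are invented for illustration and are not taken from any particular game; each vertex maps to the set of vertices that must be satisfied first.

```python
from graphlib import TopologicalSorter

# Hypothetical quest: each key lists the dependencies that must be met first.
quest = {
    "find key": set(),
    "unlock door": {"find key"},
    "enter cellar": {"unlock door"},
    "take lamp": set(),
    "explore cellar": {"enter cellar", "take lamp"},
}

# Any topological order is a valid sequence in which to complete the quest.
order = list(TopologicalSorter(quest).static_order())
```

A vertex with several incoming edges, such as `"explore cellar"` here, behaves like a bottleneck: nothing beyond it is reachable until all of its dependencies are satisfied.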

## Related Work

## Q*BERT
Q*BERT is a reinforcement learning algorithm based on KG-A2C. It uses a knowledge graph to represent its understanding of the world state.

A knowledge graph is a set of relations 
$\langle s, r, o \rangle$ such that $s$ is a subject, $r$ is a relationship, and $o$ is some object.
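A minimal sketch of this data structure, assuming plain strings for subjects, relations, and objects. The `Triple` class, the `objects_of` helper, and the example facts are illustrative and not part of Q*BERT's code.

```python
from typing import NamedTuple


class Triple(NamedTuple):
    subject: str
    relation: str
    obj: str


# A knowledge graph as a set of triples; duplicates collapse automatically.
graph = {
    Triple("you", "in", "kitchen"),
    Triple("lamp", "in", "kitchen"),
    Triple("kitchen", "west of", "living room"),
}


def objects_of(graph, subject, relation):
    """Return every object o such that <subject, relation, o> is in the graph."""
    return {t.obj for t in graph if t.subject == subject and t.relation == relation}
```

For example, `objects_of(graph, "you", "in")` looks up the agent's current location.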

Q*BERT uses a variant of the BERT language transformer to answer questions about the current state and populate the knowledge graph from the answers.
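The question-answering step can be pictured roughly as follows. This is a sketch under strong assumptions: the questions, the relation names, and the `toy_answer` function are all hypothetical, and `toy_answer` is a keyword lookup standing in for the fine-tuned language model.

```python
# (question, subject, relation) patterns; all three are illustrative assumptions.
QUESTIONS = [
    ("Where am I located?", "you", "in"),
    ("What do I have?", "you", "have"),
]


def populate(graph, observation, answer):
    """Ask each question about the observation and add the answers as triples."""
    for question, subject, relation in QUESTIONS:
        for obj in answer(question, observation):
            graph.add((subject, relation, obj))
    return graph


def toy_answer(question, observation):
    """Stand-in for a QA model: simple keyword matching, not a language model."""
    text = observation.lower()
    if question == "Where am I located?":
        return ["kitchen"] if "kitchen" in text else []
    if question == "What do I have?":
        return [w for w in ("lamp", "sword") if f"carrying a {w}" in text]
    return []


graph = populate(set(), "Kitchen. You are carrying a lamp.", toy_answer)
```

Swapping `toy_answer` for a real QA model would keep the same interface: a question and an observation go in, a list of answer strings comes out.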

### Knowledge Graph State Representation


- The language model ALBERT was used.
  - Fine-tuned for QA on the SQuAD dataset.
  - Fine-tuned again on Jericho-QA, a dataset created by making question-answer pairs about text games
  - Jericho is a framework for RL in text-games.
  

# TODO List
- Set up Jericho
- Set up Q-BERT

## KG-A2C
### Overview

From "[Graph Constrained Reinforcement Learning for Natural Language Action Spaces](https://openreview.net/pdf?id=B1x6w0EtwH)":

KG-A2C is an agent that builds a knowledge graph while exploring and generates actions using a template-based action space.
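A template-based action space can be sketched as templates with blanks that are filled from a vocabulary. The `OBJ` placeholder, the templates, and the vocabulary below are assumptions made for illustration; KG-A2C itself decodes a template and its arguments with a learned model rather than enumerating every combination.

```python
from itertools import product


def fill_templates(templates, vocabulary):
    """Enumerate every action obtained by filling each OBJ slot with a word."""
    actions = []
    for template in templates:
        slots = template.count("OBJ")
        for words in product(vocabulary, repeat=slots):
            action = template
            for word in words:
                action = action.replace("OBJ", word, 1)  # fill slots left to right
            actions.append(action)
    return actions


# Hypothetical templates and vocabulary.
templates = ["look", "take OBJ", "put OBJ in OBJ"]
vocabulary = ["lamp", "mailbox", "leaflet"]
actions = fill_templates(templates, vocabulary)
```

With $T$ templates of at most two slots and a vocabulary of size $V$, the number of candidate actions grows as $O(TV^2)$, which is far smaller than the number of free-form word sequences of the same length.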




## KG-A2C Architecture
KG-A2C consists of three units:
- Input representation containing:
  - Observation encoder
  - Score encoder
  - Knowledge graph
- Action decoder
- Critic

### Input Representation
At every step, an observation of the room description, game feedback, inventory, and previous action $o_t = ({o_t}_{desc}, {o_t}_{game}, {o_t}_{inv}, a_{t-1})$ is received along with a total score $R_t$.

- ${o_t}_{desc}$ is a textual description of the agent's location and corresponds to the command "look"
- ${o_t}_{game}$ is the simulator's response to the agent's previous action and consists of narrative text
- ${o_t}_{inv}$ and ${a_{t-1}}$ inform the agent about the contents of its inventory and previous action, respectively
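The observation tuple maps naturally onto a small record type. The sketch below is only a container for the four components; the class name, field names, and example strings are assumptions, and the encoders that consume these fields are not shown.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Observation:
    desc: str         # room description, as returned by "look"
    game: str         # simulator feedback for the previous action
    inv: str          # inventory listing
    prev_action: str  # the action a_{t-1}

    def components(self):
        """Return the components in the same order as the tuple o_t."""
        return (self.desc, self.game, self.inv, self.prev_action)


# Hypothetical observation and score for one step.
obs = Observation(
    desc="You are in a kitchen.",
    game="Taken.",
    inv="You are carrying a lamp.",
    prev_action="take lamp",
)
score = 10
```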




In [None]:
#@title Setting up Dependencies


## Training
Every step, a knowledge graph $\mathit{G_{t}}$

$\mathit{H}=\{\mathbf{h_1}, \mathbf{h_2}, \ldots, \mathbf{h_N}\}$, $\mathbf{h_i} \in \mathbb{R}^{\mathit{F}}$, where $\mathit{N}$ is the number of nodes and $\mathit{F}$ is the number of features in each node, consisting of the average subword embeddings of the entity.(?)


$\mathit{e_{ij}} = \mathrm{LeakyReLU}(\mathbf{p} \mathit{W}(\mathbf{h_i} \oplus \mathbf{h_j}))$
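A worked sketch of the coefficient $e_{ij}$ in plain Python, so the shapes are explicit. The numbers, the shape of $W$ (here $F \times 2F$), and the negative slope of 0.2 are assumptions chosen for illustration; a real implementation would use learned parameters and tensor operations.

```python
def leaky_relu(x, negative_slope=0.2):
    """LeakyReLU with an assumed negative slope of 0.2."""
    return x if x > 0 else negative_slope * x


def attention_coefficient(h_i, h_j, W, p):
    """Compute e_ij = LeakyReLU(p . W(h_i concatenated with h_j))."""
    concat = list(h_i) + list(h_j)  # concatenation, length 2F
    projected = [sum(w * x for w, x in zip(row, concat)) for row in W]
    return leaky_relu(sum(a * b for a, b in zip(p, projected)))


# Toy example with F = 2 features per node.
h_i = [1.0, 0.0]
h_j = [0.0, 1.0]
W = [
    [1.0, 0.0, 0.0, 1.0],
    [0.0, 1.0, 0.0, -2.0],
]
p = [0.5, 1.0]

e_ij = attention_coefficient(h_i, h_j, W, p)
```

Here the projection is $[2, -2]$, the dot product with $\mathbf{p}$ is $-1$, and the LeakyReLU scales that negative value by the slope.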

In [None]:
def edge(start, end, features):
    pass

In [None]:
#@title Environment setup {display-mode: "form"}

import spacy.cli
import torch.nn as nn

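# Q-BERT setup: download the small English spaCy model
spacy.cli.download('en_core_web_sm')

# Fetch and unpack the Z-machine game files
!wget https://github.com/BYU-PCCL/z-machine-games/archive/master.zip
!unzip master.zip

# Install Jericho and check that Java is available
!pip install jericho
!java -version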

# Implementing KG-A2C


# Implementing the Grue Paper
Now we begin using what we know about KG-A2C to create an agent that can play Zork I.

In [None]:
#@title Training parameters
#@markdown Things that affect model training.

MAX_TIMESTEPS = 10000 #@param {type: "number", min: 1}
MAX_PATIENCE = 100 #@param {type: "number", min: 1}
#@markdown ---

In [None]:
# Main training loop

def train():
    pass


In [None]:
a2c = ...

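def qbert_update(state_t, policy):
    """Update the policy after a single environment step.

    Params:
        state_t: The state at a given time step
        policy: The policy used to act, which is then updated

    Returns:
        The next state, the computed reward, and the policy.
    """
    # `env` is the placeholder defined in a later cell, and
    # `calculate_reward` is not defined in this notebook yet.
    next_state, reward_g_t = env.step(state_t, policy)
    reward = calculate_reward(next_state, reward_g_t)
    a2c.update(policy, reward)
    return next_state, reward, policy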

In [None]:
def backtrack(backtrack_policy, backtrack_buffer):
    """Attempt to overcome a bottleneck.
    Params:
        backtrack_policy:
        backtrack_buffer: 
    """
    

In [None]:
env = ...

def structured_exploration():
    # Chained, backtrack, and current policies (placeholders)
    chained_policy = ...
    backtrack_policy = ...
    current_policy = ...
    backtrack_state_buffer = []
    current_state_buffer = []

    state, reward = env.reset()

    def policy_return(policy):
        """Placeholder for the return of a policy."""
        return 0

    max_return = reward
    patience = 0
    for t in range(MAX_TIMESTEPS):
        next_state, reward, current_policy = qbert_update(state, current_policy)
        state = next_state
        current_state_buffer.append(next_state)
        patience += 1
        if policy_return(current_policy) <= max_return:
            if patience >= MAX_PATIENCE:
                # No improvement for too long: fall back to the best-known state and policy
                state, max_return, current_policy = backtrack(backtrack_policy, backtrack_state_buffer)
                chained_policy += current_policy
        current_return = policy_return(current_policy)
        if current_return > max_return:  # New high score found
            max_return = current_return
            backtrack_policy = current_policy
            # Copy so later appends do not change the saved buffer
            backtrack_state_buffer = list(current_state_buffer)
            patience = 0
    return chained_policy

In [None]:
class QBERT:
    def choose_action(self, state):
        pass