Skip to content

Commit

Permalink
Merge pull request #189 from MarcCote/enh_visualize_graph
Browse files Browse the repository at this point in the history
Add graph visualization tool.
  • Loading branch information
MarcCote committed Feb 13, 2020
2 parents a820f38 + d99df57 commit 51225ec
Show file tree
Hide file tree
Showing 13 changed files with 462 additions and 118 deletions.
5 changes: 5 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,10 @@ matrix:
install:
- ./.travis/setup.sh

before_script:
- if [ "${TRAVIS_OS_NAME}" = "linux" ]; then ( sh -e /etc/init.d/xvfb start ) & fi
- if [ "${TRAVIS_OS_NAME}" = "osx" ]; then ( sudo Xvfb :99 -ac -screen 0 1024x768x8; echo ok ) & fi
- sleep 3 # give xvfb some time to start

script:
- ./.travis/test.sh
1 change: 1 addition & 0 deletions .travis/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ if [[ $TRAVIS_OS_NAME == "osx" ]]; then
fi

pip install .
pip install .[vis]
pip install nose coverage codecov
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ Or, after cloning the repo, go inside the root folder of the project (i.e. along

pip install .

#### Extras
#### Visualization

In order to use the `take_screenshot` or `visualize` functions in `textworld.render`, you'll need to install either the [Chrome](https://sites.google.com/a/chromium.org/chromedriver/) or [Firefox](https://github.com/mozilla/geckodriver) webdriver (depending on which browser you have installed).
TextWorld comes with some tools to visualize game states. Make sure all dependencies are installed by running

pip install textworld[vis]

Then, you will need to install either the [Chrome](https://sites.google.com/a/chromium.org/chromedriver/) or [Firefox](https://github.com/mozilla/geckodriver) webdriver (depending on which browser you have currently installed).
If you have Chrome already installed you can use the following command to install chromedriver

pip install chromedriver_installer

Current visualization tools include: `take_screenshot`, `visualize` and `show_graph` from [`textworld.render`](https://textworld.readthedocs.io/en/latest/textworld.render.html).

## Usage

### Generating a game
Expand All @@ -47,7 +53,6 @@ TextWorld provides an easy way of generating simple text-based games via the `tw

where `custom` indicates we want to customize the game using the following options: `--world-size` controls the number of rooms in the world, `--nb-objects` controls the number of objects that can be interacted with (excluding doors) and `--quest-length` controls the minimum number of commands that is required to type in order to win the game. Once done, the game `custom_game.ulx` will be saved in the `tw_games/` folder.


### Playing a game (terminal)

To play a game, one can use the `tw-play` script. For instance, the command to play the game generated in the previous section would be
Expand All @@ -56,6 +61,12 @@ To play a game, one can use the `tw-play` script. For instance, the command to p

> **Note:** Only Z-machine's games (*.z1 through *.z8) and Glulx's games (*.ulx) are supported.
To visualize the game state while playing, use the `--viewer [port]` option.

tw-play tw_games/custom_game.ulx --viewer

A new browser tab should open and track your progress in the game.

### Playing a game (Python + [Gym](https://github.com/openai/gym))

Here's how you can interact with a text-based game from within Python using OpenAI's Gym framework.
Expand Down
11 changes: 1 addition & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,8 @@ hashids>=1.2.0
jericho>=2.2.0
mementos>=1.3.1

# For visualization
pybars3>=0.9.3
flask>=1.0.2
selenium>=3.12.0
greenlet==0.4.13
gevent==1.3.5
pillow>=5.1.0
pydot>=1.2.4

# For advanced prompt
prompt_toolkit<2.1.0,>=2.0.0
prompt_toolkit

# For gym support
gym>=0.10.11
14 changes: 14 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,18 @@ def run(self):
tests_require=[
'nose==1.3.7',
],
extras_require={
'vis': [
'pybars3>=0.9.3',
'flask>=1.0.2',
'selenium>=3.12.0',
'greenlet==0.4.13',
'gevent==1.3.5',
'pillow>=5.1.0',
'plotly>=4.0.0',
'pydot>=1.2.4',
'psutil',
'matplotlib',
],
},
)
3 changes: 2 additions & 1 deletion textworld/envs/wrappers/tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import textworld

from textworld.utils import make_temp_directory, get_webdriver
from textworld.utils import make_temp_directory
from textworld.generator import compile_game
from textworld.envs.wrappers import HtmlViewer
from textworld.render import get_webdriver


def test_html_viewer():
Expand Down
7 changes: 4 additions & 3 deletions textworld/envs/wrappers/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Tuple

from textworld.core import Environment, GameState, Wrapper
from textworld.render.serve import VisualizationService
from textworld.render import WebdriverNotFoundError


class HtmlViewer(Wrapper):
Expand Down Expand Up @@ -85,11 +87,10 @@ def reset(self) -> GameState:

self._stop_server() # In case it is still running.
try:
from textworld.render.serve import VisualizationService
self._server = VisualizationService(game_state, self.open_automatically)
self._server.start(threading.current_thread(), port=self._port)
except ModuleNotFoundError as e:
print("Importing HtmlViewer without installed dependencies. Try re-installing textworld.")
except WebdriverNotFoundError as e:
print("Missing dependencies for using HtmlViewer. See 'Visualization' section of TextWorld's README.md")
raise e

return game_state
Expand Down
4 changes: 4 additions & 0 deletions textworld/render/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.


from textworld.render.render import WebdriverNotFoundError
from textworld.render.render import get_webdriver
from textworld.render.render import load_state, load_state_from_game_state, visualize
from textworld.render.graph import show_graph
218 changes: 218 additions & 0 deletions textworld/render/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from typing import Iterable, Optional

import numpy as np
import networkx as nx

from textworld.utils import check_modules
from textworld.logic import Proposition

missing_modules = []
try:
import plotly
import plotly.graph_objects as go
except ImportError:
missing_modules.append("plotly")

try:
import matplotlib.pylab as plt
except ImportError:
missing_modules.append("matplotlib")


def build_graph_from_facts(facts: Iterable[Proposition]) -> nx.DiGraph:
""" Builds a graph from a collection of facts.
Arguments:
facts: Collection of facts representing a state of a game.
Returns:
The underlying graph representation.
"""
G = nx.DiGraph()
labels = {}
for fact in facts:
# Extract relation triplet from fact (subject, object, relation)
triplet = (*fact.names, fact.name)
triplet = triplet if len(triplet) >= 3 else triplet + ("is",)

src = triplet[0]
dest = triplet[1]
relation = triplet[-1]
if relation in {"is"}:
# For entity properties and states, we artificially
# add unique node for better visualization.
dest = src + "-" + dest

labels[src] = triplet[0]
labels[dest] = triplet[1]
G.add_edge(src, dest, type=triplet[-1])

nx.set_node_attributes(G, labels, 'label')
return G


def show_graph(facts: Iterable[Proposition],
title: str = "Knowledge Graph",
renderer: Optional[str] = None,
save: Optional[str] = None) -> "plotly.graph_objs._figure.Figure":

r""" Visualizes the graph made from a collection of facts.
Arguments:
facts: Collection of facts representing a state of a game.
title: Title for the figure
renderer:
Which Plotly's renderer to use (e.g., 'browser').
save:
If provided, path where to save a PNG version of the graph.
Returns:
The Plotly's figure representing the graph.
Example:
>>> import textworld
>>> options = textworld.GameOptions()
>>> options.seeds = 1234
>>> game_file, game = textworld.make(options)
>>> import gym
>>> import textworld.gym
>>> from textworld import EnvInfos
>>> request_infos = EnvInfos(facts=True)
>>> env_id = textworld.gym.register_game(game_file, request_infos)
>>> env = gym.make(env_id)
>>> _, infos = env.reset()
>>> textworld.render.show_graph(infos["facts"])
"""
check_modules(["matplotlib", "plotly"], missing_modules)
G = build_graph_from_facts(facts)

plt.figure(figsize=(16, 9))
pos = nx.drawing.nx_pydot.pydot_layout(G, prog="fdp")

edge_labels_pos = {}
trace3_list = []
for edge in G.edges(data=True):
trace3 = go.Scatter(
x=[],
y=[],
mode='lines',
line=dict(width=0.5, color='#888', shape='spline', smoothing=1),
hoverinfo='none'
)
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
rvec = (x0 - x1, y0 - y1) # Vector from dest -> src.
length = np.sqrt(rvec[0] ** 2 + rvec[1] ** 2)
mid = ((x0 + x1) / 2., (y0 + y1) / 2.)
orthogonal = (rvec[1] / length, -rvec[0] / length)

trace3['x'] += (x0, mid[0] + 0 * orthogonal[0], x1, None)
trace3['y'] += (y0, mid[1] + 0 * orthogonal[1], y1, None)
trace3_list.append(trace3)

offset_ = 5
edge_labels_pos[(pos[edge[0]], pos[edge[1]])] = (mid[0] + offset_ * orthogonal[0],
mid[1] + offset_ * orthogonal[1])

node_x = []
node_y = []
node_labels = []
for node, data in G.nodes(data=True):
x, y = pos[node]
node_x.append(x)
node_y.append(y)
node_labels.append("<b>{}</b>".format(data['label'].replace(" ", "<br>")))

node_trace = go.Scatter(
x=node_x,
y=node_y,
mode='text',
text=node_labels,
textfont=dict(
family="sans serif",
size=12,
color="black"
),
hoverinfo='none',
marker=dict(
showscale=True,
color=[],
size=10,
line_width=2
)
)

fig = go.Figure(
data=[*trace3_list, node_trace],
layout=go.Layout(
title=title,
titlefont_size=16,
showlegend=False,
hovermode='closest',
margin=dict(b=20, l=5, r=5, t=40),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
)
)

def _get_angle(p0, p1):
x0, y0 = p0
x1, y1 = p1
if x1 == x0:
return 0

angle = -np.rad2deg(np.arctan((y1 - y0) / (x1 - x0) / (16 / 9)))
return angle

def _calc_arrow_standoff(angle, label):
return 5 + np.log(90 / abs(angle)) * max(map(len, label.split()))

# Add relation names and relation arrows.
annotations = []
for edge in G.edges(data=True):
p0, p1 = pos[edge[0]], pos[edge[1]]
x0, y0 = p0
x1, y1 = p1
angle = _get_angle(p0, p1)
annotations.append(
go.layout.Annotation(
x=x1,
y=y1,
ax=(x0 + x1) / 2,
ay=(y0 + y1) / 2,
axref="x",
ayref="y",
showarrow=True,
arrowhead=2,
arrowsize=3,
arrowwidth=0.5,
arrowcolor="#888",
standoff=_calc_arrow_standoff(angle, G.nodes[edge[1]]['label']),
)
)
annotations.append(
go.layout.Annotation(
x=edge_labels_pos[(p0, p1)][0],
y=edge_labels_pos[(p0, p1)][1],
showarrow=False,
text="<i>{}</i>".format(edge[2]['type']),
textangle=angle,
font=dict(
family="sans serif",
size=12,
color="blue"
),
)
)

fig.update_layout(annotations=annotations)

if renderer:
fig.show(renderer=renderer)

if save:
fig.write_image(save, width=1920, height=1080, scale=4)

return fig

0 comments on commit 51225ec

Please sign in to comment.