Navigation Menu

Skip to content

Commit

Permalink
Update Plotly
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed Jul 28, 2018
1 parent 51ddfa2 commit fa68a2f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions test.py
@@ -1,6 +1,7 @@
import os
import plotly
from plotly.graph_objs import Scatter, Line
from plotly.graph_objs import Scatter
from plotly.graph_objs.scatter import Line
import torch

from env import Env
Expand Down Expand Up @@ -61,16 +62,16 @@ def test(args, T, dqn, val_mem, evaluate=False):

# Plots min, max and mean + standard deviation bars of a population over time
def _plot_line(xs, ys_population, title, path=''):
max_colour, mean_colour, std_colour = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)'
max_colour, mean_colour, std_colour, transparent = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)', 'rgba(0, 0, 0, 0)'

ys = torch.tensor(ys_population, dtype=torch.float32)
ys_min, ys_max, ys_mean, ys_std = ys.min(1)[0].squeeze(), ys.max(1)[0].squeeze(), ys.mean(1).squeeze(), ys.std(1).squeeze()
ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std

trace_max = Scatter(x=xs, y=ys_max.numpy(), line=Line(color=max_colour, dash='dash'), name='Max')
trace_upper = Scatter(x=xs, y=ys_upper.numpy(), line=Line(color='transparent'), name='+1 Std. Dev.', showlegend=False)
trace_upper = Scatter(x=xs, y=ys_upper.numpy(), line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False)
trace_mean = Scatter(x=xs, y=ys_mean.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), name='Mean')
trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color='transparent'), name='-1 Std. Dev.', showlegend=False)
trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=transparent), name='-1 Std. Dev.', showlegend=False)
trace_min = Scatter(x=xs, y=ys_min.numpy(), line=Line(color=max_colour, dash='dash'), name='Min')

plotly.offline.plot({
Expand Down

0 comments on commit fa68a2f

Please sign in to comment.