diff --git a/test.py b/test.py index 5b8b74a..8184e56 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,7 @@ import os import plotly -from plotly.graph_objs import Scatter, Line +from plotly.graph_objs import Scatter +from plotly.graph_objs.scatter import Line import torch from env import Env @@ -61,16 +62,16 @@ def test(args, T, dqn, val_mem, evaluate=False): # Plots min, max and mean + standard deviation bars of a population over time def _plot_line(xs, ys_population, title, path=''): - max_colour, mean_colour, std_colour = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)' + max_colour, mean_colour, std_colour, transparent = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)', 'rgba(0, 0, 0, 0)' ys = torch.tensor(ys_population, dtype=torch.float32) ys_min, ys_max, ys_mean, ys_std = ys.min(1)[0].squeeze(), ys.max(1)[0].squeeze(), ys.mean(1).squeeze(), ys.std(1).squeeze() ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std trace_max = Scatter(x=xs, y=ys_max.numpy(), line=Line(color=max_colour, dash='dash'), name='Max') - trace_upper = Scatter(x=xs, y=ys_upper.numpy(), line=Line(color='transparent'), name='+1 Std. Dev.', showlegend=False) + trace_upper = Scatter(x=xs, y=ys_upper.numpy(), line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False) trace_mean = Scatter(x=xs, y=ys_mean.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), name='Mean') - trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color='transparent'), name='-1 Std. Dev.', showlegend=False) + trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=transparent), name='-1 Std. Dev.', showlegend=False) trace_min = Scatter(x=xs, y=ys_min.numpy(), line=Line(color=max_colour, dash='dash'), name='Min') plotly.offline.plot({