# test_old.py
# Evaluation script: runs a trained Q-network agent over one or more test
# environments, accumulates per-episode rewards, and reports mean/std
# metrics (anatomy reward, plane-distance reward, distance from goal plane).
from agent.agent import *
from utils import setup_environment
from networks.Qnetworks import setup_networks
from options.options import gather_options, print_options
from visualisation.visualizers import Visualizer
import torch, os, json
import numpy as np
# visualization
from matplotlib import pyplot as plt
import matplotlib.lines
from matplotlib.transforms import Bbox, TransformedBbox
from matplotlib.legend_handler import HandlerBase
from matplotlib.image import BboxImage
# handler class to insert images in the legend (gives insights of size of heach heart volume tested)
class HandlerLineImage(HandlerBase):
    """Legend handler that renders a short line segment followed by a small
    greyscale thumbnail image, so each legend entry can carry a reference
    image (here: a goal-plane snapshot giving a sense of heart-volume size)
    next to its line.
    """

    def __init__(self, img_arr, space=15, offset=10):
        """Store thumbnail data and layout parameters.

        Args:
            img_arr: 2D array displayed as the legend thumbnail
                     (rendered with the "Greys_r" colormap).
            space:   horizontal gap between the line segment and the image,
                     in legend-box units.
            offset:  horizontal shift applied to both line and image.
        """
        super(HandlerLineImage, self).__init__()
        self.space = space
        self.offset = offset
        self.image_data = img_arr

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """Build the two artists (line + thumbnail) for one legend entry."""
        y_mid = ydescent + height / 2.
        x_start = xdescent + self.offset

        # Line segment styled after the original handle, occupying roughly
        # the left third of the allotted width.
        segment = matplotlib.lines.Line2D(
            [x_start, x_start + (width - self.space) / 3.],
            [y_mid, y_mid])
        segment.update_from(orig_handle)
        segment.set_clip_on(False)
        segment.set_transform(trans)

        # Thumbnail box sits to the right of the line; its width preserves
        # the source array's aspect ratio (shape[1] / shape[0]).
        thumb_width = height * self.image_data.shape[1] / self.image_data.shape[0]
        bounds = Bbox.from_bounds(
            xdescent + (width + self.space) / 3. + self.offset,
            ydescent,
            thumb_width,
            height)
        thumbnail = BboxImage(TransformedBbox(bounds, trans))
        thumbnail.set_data(self.image_data)
        thumbnail.set_cmap("Greys_r")
        self.update_prop(thumbnail, orig_handle, legend)

        return [segment, thumbnail]
if __name__ == "__main__":
    # 1. gather options: build the CLI parser for the test phase and parse args.
    parser = gather_options(phase="test")
    config = parser.parse_args()
    # Select device once and stash it on the config for downstream components.
    config.use_cuda = torch.cuda.is_available()
    config.device = torch.device("cuda" if config.use_cuda else "cpu")
    print_options(config, parser)
    # 2. instanciate environment(s) — one env per test volume (see goal_planes below).
    envs = setup_environment(config)
    # 3. instanciate agent
    agent = MultiVolumeAgent(config)
    # 4. instanciate Qnetwork and set it in eval mode
    # NOTE(review): only the first returned network is used here; the second is discarded.
    qnetwork, _ = setup_networks(config)
    # 5. instanciate visualizer to plot results (writes under agent.results_dir).
    visualizer = Visualizer(agent.results_dir)
    if not os.path.exists(os.path.join(agent.results_dir, "test")):
        os.makedirs(os.path.join(agent.results_dir, "test"))
    # 6. run test experiments on all given environments and generate outputs.
    # total_rewards maps reward-name -> vol_id -> list of per-run reward logs.
    total_rewards = {"planeDistanceReward": {}, "anatomyReward": {}}
    # Sample the goal plane image for each volume (used below for metrics keys;
    # also usable as legend thumbnails in the commented-out plotting code).
    goal_planes = {env.vol_id: env.sample_plane(env.goal_state)["plane"] for env in envs}
    for key,value in goal_planes.items():
        # If the sampled plane has a leading channel/stack dimension, keep the
        # first slice only — presumably (C, H, W) or similar; TODO confirm.
        if len(value.shape)>2:
            goal_planes[key] = value[0, ...]
        # Normalize to [0, 1] for display. NOTE(review): divides by max with no
        # guard — an all-zero plane would produce a division-by-zero here.
        goal_planes[key] = goal_planes[key]/goal_planes[key].max()
    # Per-volume accumulators for the final state/plane reached in each run.
    terminal_planes = {env.vol_id: [] for env in envs}
    terminal_states = {env.vol_id: [] for env in envs}
    # Each call to test_agent runs every env once, so the number of outer runs
    # is n_runs divided across envs (at least 1).
    for run in range(max(int(config.n_runs/len(envs)), 1)):
        print("test run: [{}]/[{}]".format(run+1, int(config.n_runs/len(envs))))
        out = agent.test_agent(config.n_steps, envs, qnetwork)
        for i, (key, logs) in enumerate(out.items(), 1):
            # 6.1. gather total rewards accumulated in testing episode
            for reward in total_rewards:
                if key not in total_rewards[reward]:
                    total_rewards[reward][key] = []
                total_rewards[reward][key].append(logs["logs"][reward])
            # 6.2. gather terminal states and planes reached
            terminal_states[key].append(logs["states"][-1])
            terminal_planes[key].append(logs["planes"][-1])
            # 6.3. render trajectories if queried (one GIF per volume per run).
            if config.render:
                print("rendering logs for: {} ([{}]/[{}])".format(key, i, len(out)))
                if not os.path.exists(os.path.join(agent.results_dir, "test", key)):
                    os.makedirs(os.path.join(agent.results_dir, "test", key))
                visualizer.render_full(logs, fname = os.path.join(agent.results_dir, "test", key, "{}_{}.gif".format(config.fname, run)))
                #visualizer.render_frames(logs["planes"], logs["planes"], fname="trajectory.gif", n_rows=2, fps=10)
    # 7. print quantitative metrics for this model.
    # For each volume: sum rewards over the episode (sum(-1)), then take
    # mean/std across runs (mean(-1)/std(-1)); stack per-volume results.
    anatomy_rewards = {"mean": np.stack([np.array(total_rewards["anatomyReward"][key]).sum(-1).mean(-1) for key in goal_planes]),
                       "std": np.stack([np.array(total_rewards["anatomyReward"][key]).sum(-1).std(-1) for key in goal_planes])}
    planeDistance_rewards = {"mean": np.stack([np.array(total_rewards["planeDistanceReward"][key]).sum(-1).mean(-1) for key in goal_planes]),
                             "std":np.stack([np.array(total_rewards["planeDistanceReward"][key]).sum(-1).std(-1) for key in goal_planes])}
    # Distance of each run's terminal plane from the goal plane, per volume.
    plane_distances_from_goal_plane = {"mean": [], "std": []}
    for env in envs:
        distancesGoal = [env.rewards["planeDistanceReward"].get_distance_from_goal(env.get_plane_coefs(*terminal_states[env.vol_id][i])) for i in range(len(terminal_states[env.vol_id]))]
        plane_distances_from_goal_plane["mean"].append(np.mean(distancesGoal))
        plane_distances_from_goal_plane["std"].append(np.std(distancesGoal))
    print("anatomy_rewards\n")
    print("means: ", anatomy_rewards["mean"])
    print("stds: ", anatomy_rewards["std"])
    print("planeDistance_rewards\n")
    print("means: ", planeDistance_rewards["mean"])
    print("stds: ", planeDistance_rewards["std"])
    print("plane_distances_from_goal_plane\n")
    print("means: ", plane_distances_from_goal_plane["mean"])
    print("stds: ", plane_distances_from_goal_plane["std"])
    # NOTE(review): plotting code below is retained from a previous revision;
    # it uses the deprecated np.float alias and would need updating to re-enable.
    # # 7. re-organize logged rewards
    # for reward_key, reward in total_rewards.items():
    #     fig = plt.figure(figsize=(15,10))
    #     ax = plt.gca()
    #     lines = {}
    #     last_reward = []
    #     for vol_id, log in reward.items():
    #         log = np.array(log).astype(np.float)
    #         means = log[:,1:].mean(0)
    #         stds = log[:,1:].std(0)
    #         last_reward.append(means[-1])
    #         color = next(ax._get_lines.prop_cycler)['color']
    #         lines[vol_id], = plt.plot(range(len(means)), means, c=color)
    #         plt.fill_between(range(len(means)), means-stds, means+stds ,alpha=0.3, facecolor=color)
    #     # add legend with a reference image to compare heart sizes
    #     line_values, line_keys, imgs, last_reward = zip(*sorted(zip(lines.values(), lines.keys(), goal_planes.values(), last_reward), key=lambda t: t[-1], reverse=True))
    #     # add legend with vol_ids
    #     legend1 = plt.legend(line_values, line_keys, loc="lower center", ncol=len(line_keys))
    #     plt.legend(line_values,
    #                [""]*len(lines),
    #                handler_map={line: HandlerLineImage(img) for line,img in zip(line_values,imgs)},
    #                handlelength=0.25, fontsize=80, labelspacing=0., bbox_to_anchor=(0.94,1.1), frameon=False)
    #     ax.add_artist(legend1)
    #     plt.title("average {} collected in an episode".format(reward_key))
    #     # 8. save figure
    #     if not os.path.exists(os.path.join(agent.results_dir, "test")):
    #         os.makedirs(os.path.join(agent.results_dir, "test"))
    #     plt.savefig(os.path.join(agent.results_dir, "test", "{}_test.pdf".format(reward_key)))