-
Notifications
You must be signed in to change notification settings - Fork 141
/
PIV_3D_plotting.py
355 lines (290 loc) · 11.1 KB
/
PIV_3D_plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
"""
functions to plot 3D-deformation fields and simple 3D-structures
"""
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from itertools import chain
from mpl_toolkits.mplot3d import Axes3D
def set_axes_equal(ax):
"""
Following https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to
Make axes of 3D plot have equal scale so that spheres appear as spheres,
cubes as cubes, etc.. This is one possible solution to Matplotlib's
ax.set_aspect('equal') and ax.axis('equal') not working for 3D.
Parameters
----------
ax: matplotlib.axes object
"""
x_limits = ax.get_xlim3d()
y_limits = ax.get_ylim3d()
z_limits = ax.get_zlim3d()
x_range = abs(x_limits[1] - x_limits[0])
x_middle = np.mean(x_limits)
y_range = abs(y_limits[1] - y_limits[0])
y_middle = np.mean(y_limits)
z_range = abs(z_limits[1] - z_limits[0])
z_middle = np.mean(z_limits)
# The plot bounding box is a sphere in the sense of the infinity
# norm, hence I call half the max range the plot radius.
plot_radius = 0.5 * max([x_range, y_range, z_range])
ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
def scatter_3D(a, cmap="jet", sca_args=None, control="color", size=60):
# default arguments for the quiver plot. can be overwritten by quiv_args
if not isinstance(sca_args, dict):
sca_args = {}
scatter_args = {"alpha": 1}
scatter_args.update(sca_args)
x, y, z = np.indices(a.shape)
x = x.flatten()
y = y.flatten()
z = z.flatten()
fig = plt.figure()
ax = fig.gca(projection="3d", rasterized=True)
if control == "color":
# make cmap
cbound = [np.nanmin(a), np.nanmax(a)]
# create normalized color map for arrows
norm = matplotlib.colors.Normalize(
vmin=cbound[0], vmax=cbound[1]
) # 10 ) #cbound[1] ) #)
sm = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
# different option
cm = matplotlib.cm.get_cmap(cmap)
colors = cm(norm(a)).reshape(a.shape[0] * a.shape[1] * a.shape[2], 4) #
# plotting
nan_filter = ~np.isnan(a.flatten())
ax.scatter(
x[nan_filter],
y[nan_filter],
z[nan_filter],
c=colors[nan_filter],
s=size,
**scatter_args
)
plt.colorbar(sm)
if control == "alpha":
# untested #
colors = [(0, 0, 1, x / np.max(z)) for x in np.ravel(z)]
ax.scatter(x, y, z, c=colors, s=size, **scatter_args)
plt.show()
if control == "size":
sizes = (a - a.min()) * size / a.ptp()
ax.scatter(x, y, z, a, s=sizes, **scatter_args)
ax_scale = plt.axes([0.88, 0.1, 0.05, 0.7])
# ax_scale.set_ylim((0.1,1.2))
nm = 5
ax_scale.scatter(
[0] * nm,
np.linspace(a.min(), a.max(), nm),
s=sizes.max() * np.linspace(0, 1, nm),
)
ax_scale.spines["left"].set_visible(False)
ax_scale.spines["right"].set_visible(True)
ax_scale.spines["bottom"].set_visible(False)
ax_scale.spines["top"].set_visible(False)
ax_scale.tick_params(
axis="both",
which="both",
labelbottom=False,
labelleft=False,
labelright=True,
bottom=False,
left=False,
right=True,
)
ax.set_xlim(0, a.shape[0])
ax.set_ylim(0, a.shape[1])
ax.set_zlim(0, a.shape[2])
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
return fig
def explode(data):
# following "https://matplotlib.org/3.1.1/gallery/mplot3d/voxels_numpy_logo.html"
if len(data.shape) == 3:
size = np.array(data.shape) * 2
data_e = np.zeros(size - 1, dtype=data.dtype)
data_e[::2, ::2, ::2] = data
if len(data.shape) == 4: ## color data
size = np.array(data.shape)[:3] * 2
data_e = np.zeros(np.concatenate([size - 1, np.array([4])]), dtype=data.dtype)
data_e[::2, ::2, ::2, :] = data
return data_e
def plot_3D_alpha(data):
# plotting each voxel as a slightly smaller block with transparency depending
# on the data value
# following "https://matplotlib.org/3.1.1/gallery/mplot3d/voxels_numpy_logo.html"
col = np.zeros((data.shape[0], data.shape[1], data.shape[2], 4))
data_fil = data.copy()
data_fil[(data == np.inf)] = np.nanmax(data[~(data == np.inf)])
data_fil = (data_fil - np.nanmin(data_fil)) / (
np.nanmax(data_fil) - np.nanmin(data_fil)
)
data_fil[np.isnan(data_fil)] = 0
col[:, :, :, 2] = 1
col[:, :, :, 3] = data_fil
col_exp = explode(col)
fill = explode(np.ones(data.shape))
x, y, z = np.indices(np.array(fill.shape) + 1).astype(float) // 2
x[0::2, :, :] += 0.05
y[:, 0::2, :] += 0.05
z[:, :, 0::2] += 0.05
x[1::2, :, :] += 0.95
y[:, 1::2, :] += 0.95
z[:, :, 1::2] += 0.95
fig = plt.figure()
ax = fig.gca(projection="3d")
ax.voxels(x, y, z, fill, facecolors=col_exp, edgecolors=col_exp)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
plt.show()
def quiver_3D(
u,
v,
w,
x=None,
y=None,
z=None,
mask_filtered=None,
filter_def=0,
filter_reg=(1, 1, 1),
cmap="jet",
quiv_args=None,
vmin=None,
vmax=None,
arrow_scale=0.15,
equal_ax=True,
):
"""
Displaying 3D deformation fields vector arrows
Parameters
----------
u,v,w: 3d ndarray or lists
arrays or list with deformation in x,y and z direction
x,y,z: 3d ndarray or lists
Arrays or list with deformation the coordinates of the deformations.
Must match the dimensions of the u,v qnd w. If not provided x,y and z are created
with np.indices(u.shape)
mask_filtered, boolean 3d ndarray or 1d ndarray
Array, or list with same dimensions as the deformations. Defines the area where deformations are drawn
filter_def: float
Filter that prevents the display of deformations arrows with length < filter_def
filter_reg: tuple,list or int
Filter that prevents the display of every i-th deformations arrows separatly alon each axis.
filter_reg=(2,2,2) means that only every second arrow along x,y z axis is displayed leading to
a total reduction of displayed arrows by a factor of 8. filter_reg=3 is interpreted
as (3,3,3).
cmap: string
matplotlib colorbar that defines the coloring of the arrow
quiv_args: dict
Dictionary with kwargs passed on to the matplotlib quiver function.
vmin,vmax: float
Upper and lower bounds for the colormap. Works like vmin and vmax in plt.imshow().
arrow_scale: float
Automatic scaling of the quiver arrows so that the longest arrow has the
length axis length * arrow_scale. Arrow length can alternatively be set by
passing a "lenght" argument in quiv_args.
equal_axes: bool
resize the figure axis so that they are have equal scaling.
Returns
-------
fig: matploltib figure object
ax: mattplotlib axes object
the holding the main 3D quiver plot
"""
# default arguments for the quiver plot. can be overwritten by quiv_args
quiver_args = {
"normalize": False,
"alpha": 0.8,
"pivot": "tail",
"linewidth": 1,
"length": 1,
}
if isinstance(quiv_args, dict):
quiver_args.update(quiv_args)
# overwriting length if an arrow scale and a "length" argument in quiv_args
# is provided at the same
if arrow_scale is not None:
quiver_args["length"] = 1
# convert filter ot list if proveided as int
if not isinstance(filter_reg, (tuple, list)):
filter_reg = [filter_reg] * 3
# generating coordinates if not provided
if x is None:
# if you provide deformations as a list
if len(u.shape) == 1:
x, y, z = [np.indices(u.shape)[0] for i in range(3)]
# if you provide deformations as an array
elif len(u.shape) == 3:
x, y, z = np.indices(u.shape)
else:
raise ValueError(
"displacement data has wrong number of dimensions (%s). Use 1d array, list, or 3d array."
% str(len(u.shape))
)
# conversion to array
x, y, z = np.array([x, y, z])
deformation = np.sqrt(u ** 2 + v ** 2 + w ** 2)
if not isinstance(mask_filtered, np.ndarray):
mask_filtered = deformation > filter_def
if isinstance(filter_reg, list):
show_only = np.zeros(u.shape).astype(bool)
# filtering out every x-th
show_only[:: filter_reg[0], :: filter_reg[1], :: filter_reg[2]] = True
mask_filtered = np.logical_and(mask_filtered, show_only)
xf = x[mask_filtered]
yf = y[mask_filtered]
zf = z[mask_filtered]
uf = u[mask_filtered]
vf = v[mask_filtered]
wf = w[mask_filtered]
df = deformation[mask_filtered]
# make cmap
# create normalized color map for arrows
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) # 10 ) #cbound[1] ) #)
sm = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
# different option
colors = matplotlib.cm.jet(norm(df)) #
colors = [c for c, d in zip(colors, df) if d > 0] + list(
chain(*[[c, c] for c, d in zip(colors, df) if d > 0])
)
# colors in ax.quiver 3d is really fucked up/ will probably change with updates:
# requires list with: first len(u) entries define the colors of the shaft, then the next len(u)*2 entries define
# the color ofleft and right arrow head side in alternating order. Try for example:
# colors = ["red" for i in range(len(cf))] + list(chain(*[["blue", "yellow"] for i in range(len(cf))]))
# to see this effect.
# BUT WAIT THERS MORE: zeor length arrows are apparently filtered out in the matplolib with out filtering
# the color list appropriately so we have to do this our selfs as well
# scale arrows to axis dimensions:
ax_dims = [(x.min(), x.max()), (y.min(), y.max()), (z.min(), z.max())]
if arrow_scale is not None:
max_length = df.max()
max_dim_length = np.max([(d[1] - d[0] + 1) for d in ax_dims])
scale = max_dim_length * arrow_scale / max_length
else:
scale = 1
# plotting
fig = plt.figure()
ax = fig.gca(projection="3d", rasterized=True)
ax.quiver(
xf, yf, zf, vf * scale, uf * scale, wf * scale, colors=colors, **quiver_args
)
plt.colorbar(sm)
ax.set_xlim(ax_dims[0])
ax.set_ylim(ax_dims[1])
ax.set_zlim(ax_dims[2])
if equal_ax:
set_axes_equal(ax)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.w_xaxis.set_pane_color((0.2, 0.2, 0.2, 1.0))
ax.w_yaxis.set_pane_color((0.2, 0.2, 0.2, 1.0))
ax.w_zaxis.set_pane_color((0.2, 0.2, 0.2, 1.0))
return fig