@@ -2526,44 +2526,31 @@ def quiver(self, *args,
25262526 Any additional keyword arguments are delegated to
25272527 :class:`~matplotlib.collections.LineCollection`
25282528 """
2529+
25292530 def calc_arrows (UVW , angle = 15 ):
25302531 # get unit direction vector perpendicular to (u, v, w)
25312532 x = UVW [:, 0 ]
25322533 y = UVW [:, 1 ]
25332534 norm = np .linalg .norm (UVW [:, :2 ], axis = 1 )
25342535 x_p = np .divide (y , norm , where = norm != 0 , out = np .zeros_like (x ))
25352536 y_p = np .divide (- x , norm , where = norm != 0 , out = np .ones_like (x ))
2536-
25372537 # compute the two arrowhead direction unit vectors
25382538 ra = math .radians (angle )
25392539 c = math .cos (ra )
25402540 s = math .sin (ra )
2541-
2542- # construct the rotation matrices
2541+ # construct the rotation matrices of shape (3, 3, n)
25432542 Rpos = np .array (
25442543 [[c + (x_p ** 2 ) * (1 - c ), x_p * y_p * (1 - c ), y_p * s ],
25452544 [y_p * x_p * (1 - c ), c + (y_p ** 2 ) * (1 - c ), - x_p * s ],
25462545 [- y_p * s , x_p * s , np .full_like (x_p , c )]])
2547- Rpos = Rpos .transpose (2 , 0 , 1 )
2548-
25492546 # opposite rotation negates all the sin terms
25502547 Rneg = Rpos .copy ()
2551- Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] = \
2552- - Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]]
2553-
2554- # expand dimensions for batched matrix multiplication
2555- UVW = np .expand_dims (UVW , axis = - 1 )
2556-
2557- # multiply them to get the rotated vector
2558- Rpos_vecs = np .matmul (Rpos , UVW )
2559- Rneg_vecs = np .matmul (Rneg , UVW )
2560-
2561- # transpose for concatenation
2562- Rpos_vecs = Rpos_vecs .transpose (0 , 2 , 1 )
2563- Rneg_vecs = Rneg_vecs .transpose (0 , 2 , 1 )
2564-
2565- head_dirs = np .concatenate ([Rpos_vecs , Rneg_vecs ], axis = 1 )
2566-
2548+ Rneg [[0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] *= - 1
2549+ # Batch n (3, 3) x (3) matrix multiplications ((3, 3, n) x (n, 3)).
2550+ Rpos_vecs = np .einsum ("ij...,...j->...i" , Rpos , UVW )
2551+ Rneg_vecs = np .einsum ("ij...,...j->...i" , Rneg , UVW )
2552+ # Stack into (n, 2, 3) result.
2553+ head_dirs = np .stack ([Rpos_vecs , Rneg_vecs ], axis = 1 )
25672554 return head_dirs
25682555
25692556 had_data = self .has_data ()
@@ -2630,7 +2617,7 @@ def calc_arrows(UVW, angle=15):
26302617 # compute all head lines at once, starting from the shaft ends
26312618 heads = shafts [:, :1 ] - np .multiply .outer (arrow_dt , head_dirs )
26322619 # stack left and right head lines together
2633- heads . shape = ( len (arrow_dt ), - 1 , 3 )
2620+ heads = heads . reshape (( len (arrow_dt ), - 1 , 3 ) )
26342621 # transpose to get a list of lines
26352622 heads = heads .swapaxes (0 , 1 )
26362623
0 commit comments