-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
linalg.py
2455 lines (2024 loc) · 92.2 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
from ..fluid.framework import in_dygraph_mode, _varbase_creator, Variable
from ..fluid.layers import transpose, cast # noqa: F401
from ..fluid import layers
import paddle
from paddle.common_ops_import import core
from paddle.common_ops_import import VarDesc
from paddle import _C_ops
import paddle
# An empty ``__all__`` means a wildcard import (``from <module> import *``)
# of this module exports no names.
__all__ = []
def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
    """
    Matrix product of two tensors, with the same broadcasting semantics as
    ``np.matmul``.

    Depending on the ranks of :math:`x` and :math:`y`, this single API covers
    ``dot``, ``matmul`` and ``batchmatmul``. The transpose flags
    :attr:`transpose_x` / :attr:`transpose_y` swap the last two dimensions of
    the corresponding operand before multiplication. They have no effect on a
    1-D operand of shape :math:`[D]`, which is read as :math:`[1, D]` when it
    is :math:`x` and as :math:`[D, 1]` when it is :math:`y`.

    Rank rules:

    - 1-D x 1-D: the dot product.
    - 2-D x 2-D: the ordinary matrix-matrix product.
    - 1-D x 2-D: a ``1`` is prepended to ``x``'s shape for the multiply and
      removed from the result afterwards.
    - 2-D x 1-D: the matrix-vector product.
    - Both at least 1-D and at least one N-D (N > 2): a batched matrix
      multiply. A 1-D first argument gets a ``1`` prepended, and a 1-D second
      argument gets a ``1`` appended. Either extra dimension is removed from
      the result afterwards. All dimensions except the last two are
      broadcast. For example, a (j, 1, n, m) tensor times a (k, m, p) tensor
      gives a (j, k, n, p) tensor.

    Args:
        x (Tensor): The left operand.
        y (Tensor): The right operand.
        transpose_x (bool): Transpose the last two dims of :math:`x` first.
        transpose_y (bool): Transpose the last two dims of :math:`y` first.
        name(str|None): Optional layer name. If None, a name is generated
            automatically.

    Returns:
        Tensor: The product tensor.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # vector * vector
            x_data = np.random.random([10]).astype(np.float32)
            y_data = np.random.random([10]).astype(np.float32)
            x = paddle.to_tensor(x_data)
            y = paddle.to_tensor(y_data)
            z = paddle.matmul(x, y)
            print(z.numpy().shape)
            # [1]

            # matrix * vector
            x_data = np.random.random([10, 5]).astype(np.float32)
            y_data = np.random.random([5]).astype(np.float32)
            x = paddle.to_tensor(x_data)
            y = paddle.to_tensor(y_data)
            z = paddle.matmul(x, y)
            print(z.numpy().shape)
            # [10]

            # batched matrix * broadcasted vector
            x_data = np.random.random([10, 5, 2]).astype(np.float32)
            y_data = np.random.random([2]).astype(np.float32)
            x = paddle.to_tensor(x_data)
            y = paddle.to_tensor(y_data)
            z = paddle.matmul(x, y)
            print(z.numpy().shape)
            # [10, 5]

            # batched matrix * batched matrix
            x_data = np.random.random([10, 5, 2]).astype(np.float32)
            y_data = np.random.random([10, 2, 5]).astype(np.float32)
            x = paddle.to_tensor(x_data)
            y = paddle.to_tensor(y_data)
            z = paddle.matmul(x, y)
            print(z.numpy().shape)
            # [10, 5, 5]

            # batched matrix * broadcasted matrix
            x_data = np.random.random([10, 1, 5, 2]).astype(np.float32)
            y_data = np.random.random([1, 3, 2, 5]).astype(np.float32)
            x = paddle.to_tensor(x_data)
            y = paddle.to_tensor(y_data)
            z = paddle.matmul(x, y)
            print(z.numpy().shape)
            # [10, 3, 5, 5]
    """
    if in_dygraph_mode():
        # Eager mode: hand off to the kernel directly, no Python-side checks.
        return _C_ops.matmul_v2(x, y, 'trans_x', transpose_x, 'trans_y',
                                transpose_y)

    # Static-graph mode: validate x, then y. The loop variable is named
    # ``arg_name`` so that it does not overwrite the ``name`` parameter that
    # LayerHelper reads from locals() below.
    for arg_name, arg in (('x', x), ('y', y)):
        check_variable_and_dtype(arg, arg_name,
                                 ['float16', 'float32', 'float64'], 'matmul')

    helper = LayerHelper('matmul_v2', **locals())
    product = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='matmul_v2',
        inputs={'X': x,
                'Y': y},
        outputs={'Out': product},
        attrs={
            'trans_x': transpose_x,
            'trans_y': transpose_y,
        })
    return product
def norm(x, p='fro', axis=None, keepdim=False, name=None):
    """
    Returns the matrix norm (Frobenius) or vector norm (the 1-norm, the Euclidean
    or 2-norm, and in general the p-norm for p > 0) of a given tensor.

    .. note::
        This norm API is different from `numpy.linalg.norm`.
        This api supports high-order input tensors (rank >= 3), and certain axis need to be pointed out to calculate the norm.
        But `numpy.linalg.norm` only supports 1-D vector or 2-D matrix as input tensor.
        For p-order matrix norm, this api actually treats matrix as a flattened vector to calculate the vector norm, NOT REAL MATRIX NORM.

    Args:
        x (Tensor): The input tensor could be N-D tensor, and the input data
            type could be float32 or float64.
        p (float|string, optional): Order of the norm. Supported values are `fro`, `0`, `1`, `2`,
            `inf`, `-inf` and any positive real number yielding the corresponding p-norm. Not supported: ord < 0 and nuclear norm.
            Default value is `fro`.
        axis (int|list|tuple, optional): The axis on which to apply norm operation. If axis is int
            or list(int)/tuple(int) with only one element, the vector norm is computed over the axis.
            If `axis < 0`, the dimension to norm operation is rank(input) + axis.
            If axis is a list(int)/tuple(int) with two elements, the matrix norm is computed over the axis.
            Default value is `None`.
        keepdim (bool, optional): Whether to reserve the reduced dimension in the
            output Tensor. The result tensor will have fewer dimension
            than the :attr:`input` unless :attr:`keepdim` is true, default
            value is False.
        name (str, optional): The default value is None. Normally there is no need for
            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: results of norm operation on the specified axis of input tensor,
        it's data type is the same as input's Tensor.

    Raises:
        ValueError: for unsupported combinations of ``p`` and ``axis``, e.g. a
            string ``p`` other than ``'fro'`` with a vector axis, ``p=0`` with a
            two-element ``axis``, or an ``axis`` with more than two elements.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            shape=[2, 3, 4]
            np_input = np.arange(24).astype('float32') - 12
            np_input = np_input.reshape(shape)
            x = paddle.to_tensor(np_input)
            #[[[-12. -11. -10.  -9.] [ -8.  -7.  -6.  -5.] [ -4.  -3.  -2.  -1.]]
            # [[  0.   1.   2.   3.] [  4.   5.   6.   7.] [  8.   9.  10.  11.]]]

            # compute frobenius norm along the first two dimensions.
            out_fro = paddle.norm(x, p='fro', axis=[0,1])
            # out_fro.numpy() [17.435596 16.911535 16.7332   16.911535]

            # compute 2-order vector norm along last dimension.
            out_pnorm = paddle.norm(x, p=2, axis=-1)
            #out_pnorm.numpy(): [[21.118711  13.190906   5.477226]
            #                    [ 3.7416575 11.224972  19.131126]]

            # compute 2-order norm along [0,1] dimension.
            out_pnorm = paddle.norm(x, p=2, axis=[0,1])
            #out_pnorm.numpy(): [17.435596 16.911535 16.7332   16.911535]

            # compute inf-order norm
            out_pnorm = paddle.norm(x, p=np.inf)
            #out_pnorm.numpy()  = [12.]
            out_pnorm = paddle.norm(x, p=np.inf, axis=0)
            #out_pnorm.numpy(): [[12. 11. 10. 9.] [8. 7. 6. 7.] [8. 9. 10. 11.]]

            # compute -inf-order norm
            out_pnorm = paddle.norm(x, p=-np.inf)
            #out_pnorm.numpy(): [0.]
            out_pnorm = paddle.norm(x, p=-np.inf, axis=0)
            #out_pnorm.numpy(): [[0. 1. 2. 3.] [4. 5. 6. 5.] [4. 3. 2. 1.]]
    """

    def frobenius_norm(input, dim=None, keepdim=False, name=None):
        """
        The frobenius norm OP is to calculate the frobenius norm of certain two dimensions of Tensor `input`.
        Args:
          input (Variable): Tensor, data type float32, float64.
          dim (list, optional): None for last two dimensions.
          keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False.
        """
        if dim is not None and not (isinstance(dim, list) and len(dim) == 2):
            raise ValueError(
                "The dim of frobenius norm op should be None or two elements list!"
            )

        if in_dygraph_mode():
            if dim is None:
                return _C_ops.frobenius_norm(input, 'keep_dim', keepdim,
                                             'reduce_all', True)
            return _C_ops.frobenius_norm(input, 'dim', dim, 'keep_dim',
                                         keepdim, 'reduce_all', False)
        attrs = {'dim': dim, 'keep_dim': keepdim, 'reduce_all': False}
        if dim is None:
            attrs['reduce_all'] = True
        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                 'frobenius_norm')

        # LayerHelper reads 'input' (for input_dtype) and 'name' from kwargs.
        helper = LayerHelper('frobenius_norm', **locals())
        out = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype())

        helper.append_op(
            type='frobenius_norm',
            inputs={'X': input},
            outputs={'Out': out},
            attrs=attrs)
        return out

    def vector_norm(input,
                    porder=None,
                    axis=None,
                    keepdim=False,
                    asvector=False,
                    name=None):
        """
        Calculate the p-order vector norm for certain dimension of Tensor `input`.
        Args:
          input (Variable): Tensor, data type float32, float64.
          porder (float, optional): None for porder=2.0.
          axis (int, optional): None for last dimension.
          keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False.
        """
        if in_dygraph_mode():
            if axis is None: axis = -1
            return _C_ops.p_norm(input, 'porder', porder, 'axis', axis,
                                 'keepdim', keepdim, 'asvector', asvector)
        if porder is not None:
            check_type(porder, 'porder', (float, int), 'p_norm')
        if axis is not None:
            check_type(axis, 'axis', int, 'p_norm')
        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                 'p_norm')

        attrs = {
            'axis': axis if axis is not None else -1,
            'porder': float(porder) if porder is not None else 2.0,
            'keepdim': keepdim,
            'asvector': asvector,
            'epsilon': 1e-12,
        }
        helper = LayerHelper('p_norm', **locals())
        out = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype())

        helper.append_op(
            type='p_norm',
            inputs={'X': input},
            outputs={'Out': out},
            attrs=attrs)
        return out

    def inf_norm(input,
                 porder=None,
                 axis=None,
                 keepdim=False,
                 asvector=False,
                 name=None):
        """
        Max (porder == +inf) or min (otherwise) of |input| over ``axis``.
        The only caller below always passes ``axis`` explicitly.
        """
        helper = LayerHelper('frobenius_norm', **locals())
        out = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype())
        helper.append_op(type='abs', inputs={'X': input}, outputs={'Out': out})
        reduce_out = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype())

        reduce_all = axis is None or axis == [] or asvector == True
        axis = axis if axis is not None and axis != [] else [0]

        # The builtin float replaces np.float, which NumPy removed in 1.24.
        reduce_type = 'reduce_max' if porder == float(
            'inf') else 'reduce_min'
        helper.append_op(
            type=reduce_type,
            inputs={'X': out},
            outputs={'Out': reduce_out},
            attrs={'dim': axis,
                   'keep_dim': keepdim,
                   'reduce_all': reduce_all})

        return reduce_out

    def p_matrix_norm(input, porder=1., axis=None, keepdim=False, name=None):
        """
        NOTE:
            This function actually treats the matrix as flattened vector to calculate vector norm instead of matrix norm.
            The only caller below always passes ``axis`` explicitly.
        """
        block = LayerHelper('norm', **locals())
        out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        abs_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        block.append_op(
            type='abs', inputs={'X': input}, outputs={'Out': abs_out})
        pow_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())

        # (sum |x|^p) ^ (1/p)
        block.append_op(
            type='pow',
            inputs={'X': abs_out},
            outputs={'Out': pow_out},
            attrs={'factor': porder})
        sum_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        block.append_op(
            type='reduce_sum',
            inputs={'X': pow_out},
            outputs={'Out': sum_out},
            attrs={
                'dim': axis,
                'keep_dim': keepdim,
                'reduce_all': True if axis is None else False
            })
        block.append_op(
            type='pow',
            inputs={'X': sum_out},
            outputs={'Out': out},
            attrs={'factor': float(1. / porder)})
        return out

    # No axis: reduce over the whole tensor.
    if axis is None and p is not None:
        if isinstance(p, str):
            if p == "fro":
                return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
            else:
                raise ValueError(
                    "only valid string values are 'fro', found {}".format(p))
        elif isinstance(p, (int, float)):
            return vector_norm(
                x,
                porder=p,
                axis=axis,
                keepdim=keepdim,
                asvector=True,
                name=name)
        else:
            raise ValueError("only valid p type is string or float, found {}".
                             format(type(p)))

    if isinstance(axis, tuple):
        axis = list(axis)
    if isinstance(axis, list) and len(axis) == 1:
        axis = axis[0]

    #calculate vector norm, where axis is int or list with only one integer
    if isinstance(axis, int):
        if isinstance(p, str):
            if p == "fro":
                return vector_norm(
                    x,
                    porder=2,
                    axis=axis,
                    keepdim=keepdim,
                    asvector=False,
                    name=name)

            else:
                raise ValueError(
                    "only valid string values are 'fro', found {}".format(p))
        elif isinstance(p, (int, float)):
            return vector_norm(
                x,
                axis=axis,
                porder=p,
                keepdim=keepdim,
                asvector=False,
                name=name)
        else:
            raise ValueError(
                "unspport p for p-order vector norm. except float, found {}".
                format(p))
    #calculate matrix norm, where axis is list with two integers
    elif isinstance(axis, list) and len(axis) == 2:
        if p == "fro":
            return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
        elif p == np.inf or p == -np.inf:
            return inf_norm(x, porder=p, axis=axis, keepdim=keepdim, name=name)
        elif p == 0:
            raise ValueError(
                "just suport axis type int or list (length of list <=1) if p = 0, found {}".
                format(axis))
        else:
            return p_matrix_norm(
                x, porder=p, axis=axis, keepdim=keepdim, name=name)
    else:
        raise ValueError(
            "except axis type int or list (length of list <=2), found {}".
            format(axis))
def dist(x, y, p=2, name=None):
    r"""

    This OP returns the p-norm of (x - y). It is not a norm in a strict sense, only as a measure
    of distance. The shapes of x and y must be broadcastable. The definition is as follows, for
    details, please refer to the `numpy's broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_:

    - Each input has at least one dimension.
    - Match the two input dimensions from back to front, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.

    Where, z = x - y, the shapes of x and y are broadcastable, then the shape of z can be
    obtained as follows:

    1. If the number of dimensions of x and y are not equal, prepend 1 to the dimensions of the
    tensor with fewer dimensions.

    For example, The shape of x is [8, 1, 6, 1], the shape of y is [7, 1, 5], prepend 1 to the
    dimension of y.

    x (4-D Tensor):  8 x 1 x 6 x 1

    y (4-D Tensor):  1 x 7 x 1 x 5

    2. Determine the size of each dimension of the output z: choose the maximum value from the
    two input dimensions.

    z (4-D Tensor):  8 x 7 x 6 x 5

    If the number of dimensions of the two inputs are the same, the size of the output can be
    directly determined in step 2. When p takes different values, the norm formula is as follows:

    When p = 0, defining $0^0=0$, the zero-norm of z is simply the number of non-zero elements of z.

    .. math::

        ||z||_{0}=\lim_{p \rightarrow 0}\sum_{i=1}^{m}|z_i|^{p}

    When p = inf, the inf-norm of z is the maximum absolute value of the elements of z.

    .. math::

        ||z||_\infty=\max_i |z_i|

    When p = -inf, the negative-inf-norm of z is the minimum absolute value of the elements of z.

    .. math::

        ||z||_{-\infty}=\min_i |z_i|

    Otherwise, the p-norm of z follows the formula,

    .. math::

        ||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\frac{1}{p}}

    Args:
        x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
        y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
        p (float|int, optional): The order of the norm to be computed. It is
            converted to a Python float before being passed to the op. Default: 2.

    Returns:
        Tensor: Tensor that is the p-norm of (x - y).

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            x = paddle.to_tensor(np.array([[3, 3],[3, 3]]), "float32")
            y = paddle.to_tensor(np.array([[3, 3],[3, 1]]), "float32")
            out = paddle.dist(x, y, 0)
            print(out) # out = [1.]

            out = paddle.dist(x, y, 2)
            print(out) # out = [2.]

            out = paddle.dist(x, y, float("inf"))
            print(out) # out = [2.]

            out = paddle.dist(x, y, float("-inf"))
            print(out) # out = [0.]
    """
    # Name each checked input after its own parameter, so that any dtype or
    # type error message reports 'x' or 'y'.
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'dist')
    check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'dist')
    check_type(p, 'p', (float, int), 'dist')
    helper = LayerHelper("dist", **locals())
    out = helper.create_variable_for_type_inference(x.dtype)

    inputs = {"X": [x], "Y": [y]}
    outputs = {'Out': [out]}
    attrs = {"p": float(p)}
    helper.append_op(
        type='dist', inputs=inputs, outputs=outputs, attrs=attrs)
    return out
def cond(x, p=None, name=None):
    """

    Computes the condition number of a matrix or batches of matrices with respect to a matrix norm ``p``.

    Args:
        x (Tensor): The input tensor could be tensor of shape ``(*, m, n)`` where ``*`` is zero or more batch dimensions
            for ``p`` in ``(2, -2)``, or of shape ``(*, n, n)`` where every matrix is invertible for any supported ``p``.
            And the input data type could be ``float32`` or ``float64``.
        p (float|string, optional): Order of the norm. Supported values are `fro`, `nuc`, `1`, `-1`, `2`, `-2`,
            `inf`, `-inf`. Default value is `None`, meaning that the order of the norm is `2`.
        name (str, optional): The default value is `None`. Normally there is no need for
            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: computing results of condition number, its data type is the same as input Tensor ``x``.

    Raises:
        ValueError: if ``x`` has fewer than two dimensions, if ``p`` is not one
            of the supported values, or if ``p`` is one of `fro`, `nuc`, `1`,
            `-1`, `inf`, `-inf` and the last two dimensions of ``x`` differ.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            x = paddle.to_tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]])

            # compute condition number when p is None
            out = paddle.linalg.cond(x)
            # out.numpy() [1.4142135]

            # compute condition number when order of the norm is 'fro'
            out_fro = paddle.linalg.cond(x, p='fro')
            # out_fro.numpy() [3.1622777]

            # compute condition number when order of the norm is 'nuc'
            out_nuc = paddle.linalg.cond(x, p='nuc')
            # out_nuc.numpy() [9.2426405]

            # compute condition number when order of the norm is 1
            out_1 = paddle.linalg.cond(x, p=1)
            # out_1.numpy() [2.]

            # compute condition number when order of the norm is -1
            out_minus_1 = paddle.linalg.cond(x, p=-1)
            # out_minus_1.numpy() [1.]

            # compute condition number when order of the norm is 2
            out_2 = paddle.linalg.cond(x, p=2)
            # out_2.numpy() [1.4142135]

            # compute condition number when order of the norm is -2
            out_minus_2 = paddle.linalg.cond(x, p=-2)
            # out_minus_2.numpy() [0.70710677]

            # compute condition number when order of the norm is inf
            out_inf = paddle.linalg.cond(x, p=np.inf)
            # out_inf.numpy() [2.]

            # compute condition number when order of the norm is -inf
            out_minus_inf = paddle.linalg.cond(x, p=-np.inf)
            # out_minus_inf.numpy() [1.]

            a = paddle.to_tensor(np.random.randn(2, 4, 4).astype('float32'))
            # a.numpy()
            # [[[ 0.14063153 -0.996288    0.7996131  -0.02571543]
            #   [-0.16303636  1.5534962  -0.49919784 -0.04402903]
            #   [-1.1341571  -0.6022629   0.5445269   0.29154757]
            #   [-0.16816919 -0.30972657  1.7521842  -0.5402487 ]]
            #  [[-0.58081484  0.12402827  0.7229862  -0.55046535]
            #   [-0.15178485 -1.1604939   0.75810957  0.30971205]
            #   [-0.9669573   1.0940945  -0.27363303 -0.35416734]
            #   [-1.216529    2.0018666  -0.7773689  -0.17556527]]]
            a_cond_fro = paddle.linalg.cond(a, p='fro')
            # a_cond_fro.numpy()  [31.572273 28.120834]

            b = paddle.to_tensor(np.random.randn(2, 3, 4).astype('float64'))
            # b.numpy()
            # [[[ 1.61707487  0.46829144  0.38130416  0.82546736]
            #   [-1.72710298  0.08866375 -0.62518804  0.16128892]
            #   [-0.02822879 -1.67764516  0.11141444  0.3220113 ]]
            #  [[ 0.22524372  0.62474921 -0.85503233 -1.03960523]
            #   [-0.76620689  0.56673047  0.85064753 -0.45158196]
            #   [ 1.47595418  2.23646462  1.5701758   0.10497519]]]
            b_cond_2 = paddle.linalg.cond(b, p=2)
            # b_cond_2.numpy()  [3.30064451 2.51976252]

    """

    def mat_norm(input, porder=1., axis=None):
        """
        NOTE:
            Calculate the matrix norm of a square matrix or batches of square matrices,
            when porder is in (1, -1, inf, -inf)
        """
        reduce_all = axis is None or axis == []
        axis = axis if axis is not None and axis != [] else [0]
        keepdim = False

        if in_dygraph_mode():
            abs_out = _C_ops.abs(input)
            # Reduce ops name this attribute 'keep_dim', as in the static
            # branch below.
            sum_out = _C_ops.reduce_sum(abs_out, 'dim', axis, 'keep_dim',
                                        keepdim, 'reduce_all', reduce_all)
            if porder == 1 or porder == np.inf:
                return _C_ops.reduce_max(sum_out, 'dim', [-1], 'keep_dim',
                                         keepdim, 'reduce_all', reduce_all)
            if porder == -1 or porder == -np.inf:
                return _C_ops.reduce_min(sum_out, 'dim', [-1], 'keep_dim',
                                         keepdim, 'reduce_all', reduce_all)

        # LayerHelper reads 'input' (for input_dtype) from locals().
        block = LayerHelper('norm', **locals())
        abs_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        sum_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        block.append_op(
            type='abs', inputs={'X': input}, outputs={'Out': abs_out})
        block.append_op(
            type='reduce_sum',
            inputs={'X': abs_out},
            outputs={'Out': sum_out},
            attrs={'dim': axis,
                   'keep_dim': keepdim,
                   'reduce_all': reduce_all})
        if porder == 1 or porder == np.inf:
            block.append_op(
                type='reduce_max',
                inputs={'X': sum_out},
                outputs={'Out': out},
                attrs={
                    'dim': [-1],
                    'keep_dim': keepdim,
                    'reduce_all': reduce_all
                })
        if porder == -1 or porder == -np.inf:
            block.append_op(
                type='reduce_min',
                inputs={'X': sum_out},
                outputs={'Out': out},
                attrs={
                    'dim': [-1],
                    'keep_dim': keepdim,
                    'reduce_all': reduce_all
                })
        return out

    def fro_norm(input, porder=2, axis=[-1]):
        """
        NOTE:
            Calculate the frobenius norm of a square matrix or batches of square matrices.
            ``axis`` is only read, never mutated, so the list default is safe.
        """
        reduce_all = axis is None or axis == []
        keepdim = False

        if in_dygraph_mode():
            pow_out = _C_ops.pow(input, 'factor', porder)
            sum_out_1 = _C_ops.reduce_sum(pow_out, 'dim', axis, 'keep_dim',
                                          keepdim, 'reduce_all', reduce_all)
            sum_out_2 = _C_ops.reduce_sum(sum_out_1, 'dim', axis, 'keep_dim',
                                          keepdim, 'reduce_all', reduce_all)
            return _C_ops.pow(sum_out_2, 'factor', float(1. / porder))

        block = LayerHelper('norm', **locals())
        pow_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        sum_out_1 = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        sum_out_2 = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        block.append_op(
            type='pow',
            inputs={'X': input},
            outputs={'Out': pow_out},
            attrs={'factor': porder})
        block.append_op(
            type='reduce_sum',
            inputs={'X': pow_out},
            outputs={'Out': sum_out_1},
            attrs={'dim': axis,
                   'keep_dim': keepdim,
                   'reduce_all': reduce_all})
        block.append_op(
            type='reduce_sum',
            inputs={'X': sum_out_1},
            outputs={'Out': sum_out_2},
            attrs={'dim': axis,
                   'keep_dim': keepdim,
                   'reduce_all': reduce_all})
        block.append_op(
            type='pow',
            inputs={'X': sum_out_2},
            outputs={'Out': out},
            attrs={'factor': float(1. / porder)})
        return out

    def svd_norm(input, porder, axis=[-1]):
        """
        NOTE:
            Calculate the matrix norm, which is related to singular values, of a matrix
            or batches of matrices, including nuclear norm, 2-norm and (-2)-norm.
            For porder 2 / -2 the value returned is already the ratio of the
            largest and smallest singular values (max/min or min/max).
        """
        reduce_all = axis is None or axis == []
        keepdim = False

        u, s, vh = svd(input, full_matrices=False)

        # NOTE(review): 'aixs' below looks like a typo for 'axis'. However,
        # elementwise_div's axis attribute is an int while ``axis`` here is a
        # list, so renaming the key would pass the wrong type. The misspelled
        # key is presumably ignored by the op; confirm before changing it.
        if in_dygraph_mode():
            if porder == "nuc":
                return _C_ops.reduce_sum(s, 'dim', axis, 'keep_dim', keepdim,
                                         'reduce_all', reduce_all)
            max_out = _C_ops.reduce_max(s, 'dim', axis, 'keep_dim', keepdim,
                                        'reduce_all', reduce_all)
            min_out = _C_ops.reduce_min(s, 'dim', axis, 'keep_dim', keepdim,
                                        'reduce_all', reduce_all)
            if porder == 2:
                return _C_ops.elementwise_div(max_out, min_out, 'aixs', axis,
                                              'use_mkldnn', False)
            if porder == -2:
                return _C_ops.elementwise_div(min_out, max_out, 'aixs', axis,
                                              'use_mkldnn', False)

        block = LayerHelper('norm', **locals())
        out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        if porder == "nuc":
            block.append_op(
                type='reduce_sum',
                inputs={'X': s},
                outputs={'Out': out},
                attrs={
                    'dim': axis,
                    'keep_dim': keepdim,
                    'reduce_all': reduce_all
                })
            return out
        max_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        min_out = block.create_variable_for_type_inference(
            dtype=block.input_dtype())
        block.append_op(
            type='reduce_max',
            inputs={'X': s},
            outputs={'Out': max_out},
            attrs={'dim': axis,
                   'keep_dim': keepdim,
                   'reduce_all': reduce_all})
        block.append_op(
            type='reduce_min',
            inputs={'X': s},
            outputs={'Out': min_out},
            attrs={'dim': axis,
                   'keep_dim': keepdim,
                   'reduce_all': reduce_all})
        if porder == 2:
            block.append_op(
                type='elementwise_div',
                inputs={'X': max_out,
                        'Y': min_out},
                outputs={'Out': out},
                attrs={'aixs': axis,
                       'use_mkldnn': False})
            return out
        if porder == -2:
            block.append_op(
                type='elementwise_div',
                inputs={'X': min_out,
                        'Y': max_out},
                outputs={'Out': out},
                attrs={'aixs': axis,
                       'use_mkldnn': False})
            return out

    def empty_tensor(input, shape):
        # Only dygraph can reshape an empty input; static mode refuses it.
        if in_dygraph_mode():
            return input.reshape(shape)
        raise ValueError("only support x is nonempty tensor in static mode")

    x_shape = list(x.shape)
    if not len(x_shape) >= 2:
        raise ValueError("input should be a matrix or batches of matrices, " +
                         "but the dimention of received input is {}".format(
                             len(x_shape)))
    if p is None:
        p = 2
    x_size = 0 if (0 in x_shape) else 1
    if p in ("fro", "nuc", 1, -1, np.inf, -np.inf):
        # These orders use the inverse, so the input must be square.
        if x_shape[-1] == x_shape[-2]:
            if x_size == 0:
                return empty_tensor(x, x_shape[:-2])
            x_inv = x.inverse()
            if p == "fro":
                return fro_norm(x) * fro_norm(x_inv)
            if p == "nuc":
                return svd_norm(x, p) * svd_norm(x_inv, p)
            if p in (1, -1):
                return mat_norm(
                    x, porder=p, axis=[-2]) * mat_norm(
                        x_inv, porder=p, axis=[-2])
            if p in (np.inf, -np.inf):
                return mat_norm(
                    x, porder=p, axis=[-1]) * mat_norm(
                        x_inv, porder=p, axis=[-1])
        else:
            raise ValueError("only support p is {} when input is a ".format(p) +
                             "square matrix or batches of square matrices")
    elif p in (2, -2):
        # Singular-value ratio; works for non-square inputs too.
        if x_size == 0:
            return empty_tensor(x, x_shape[:-2])
        return svd_norm(x, porder=p)
    else:
        raise ValueError(
            "unsupported {} for p, only supporting ('fro', 'nuc', ".format(
                p) + "1, -1, 2, -2, inf, -inf) or none")
def dot(x, y, name=None):
    """
    Inner product of two vectors, or of each pair of rows for batched input.

    .. note::

       Only 1-D and 2-D tensors are supported. For 2-D input, the first
       dimension is the batch dimension, and the vectors of each batch entry
       are dotted together.

    Parameters:
        x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``
        y(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``
        name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`

    Returns:
        Tensor: the calculated result Tensor.

    Examples:

    .. code-block:: python

        import paddle
        import numpy as np

        x_data = np.random.uniform(0.1, 1, [10]).astype(np.float32)
        y_data = np.random.uniform(1, 3, [10]).astype(np.float32)
        x = paddle.to_tensor(x_data)
        y = paddle.to_tensor(y_data)
        z = paddle.dot(x, y)
        print(z)

    """
    op_type = 'dot'

    # Eager mode goes straight to the kernel; type checks are skipped for speed.
    if in_dygraph_mode():
        return _C_ops.dot(x, y)

    assert x is not None, 'x cannot be None in {}'.format(op_type)
    assert y is not None, 'y cannot be None in {}'.format(op_type)

    for operand, operand_name in ((x, 'x'), (y, 'y')):
        check_variable_and_dtype(operand, operand_name,
                                 ['float32', 'float64', 'int32', 'int64'],
                                 op_type)

    helper = LayerHelper(op_type, **locals())
    # Honour a user-supplied output name; otherwise let the helper pick one.
    out = (helper.create_variable_for_type_inference(dtype=x.dtype)
           if name is None else helper.create_variable(
               name=name, dtype=x.dtype, persistable=False))
    helper.append_op(
        type="dot", inputs={'X': x,
                            'Y': y}, attrs={}, outputs={"Out": out})
    return out
def t(input, name=None):
    """
    Transpose a tensor of rank at most 2.

    Tensors of rank 0 or 1 are returned unchanged. A 2-D tensor is transposed
    exactly as ``paddle.transpose`` with ``perm=[1, 0]`` would do.

    Args:
        input (Tensor): The input Tensor. It is a N-D (N<=2) Tensor of data types float16, float32, float64, int32, int64.
        name(str, optional): The default value is None. Normally there is no need for
            user to set this property. For more information, please refer to :ref:`api_guide_Name`

    Returns:
        Tensor: A transposed n-D Tensor, with data type being float16, float32, float64, int32, int64.

    Raises:
        ValueError: if ``input`` has more than 2 dimensions.

    For Example:

        .. code-block:: text

            # Example 1 (1-D tensor)
            x = tensor([0.79, 0.84, 0.32])
            paddle.t(x) = tensor([0.79, 0.84, 0.32])

            # Example 2 (2-D tensor)
            x = tensor([[0.79, 0.84, 0.32],
                        [0.64, 0.14, 0.57]])
            paddle.t(x) = tensor([[0.79, 0.64],
                                  [0.84, 0.14],
                                  [0.32, 0.57]])

    Examples:

        .. code-block:: python

            import paddle
            x = paddle.ones(shape=[2, 3], dtype='int32')
            x_transposed = paddle.t(x)
            print(x_transposed.shape)
            # [3, 2]
    """
    if len(input.shape) > 2:
        raise ValueError(
            "Input(input) only support N-D (N<=2) tensor, but received "
            "length of Input(input) is %s. Perhaps you can use paddle."
            "tensor.transpose() instead." % len(input.shape))

    # Rank 0 and rank 1 are their own transpose, as the docstring promises.
    # A perm of [1, 0] would be invalid for them.
    is_trivial = len(input.shape) < 2

    if in_dygraph_mode():
        if is_trivial:
            return input
        out, _ = _C_ops.transpose2(input, 'axis', [1, 0])
        return out

    check_variable_and_dtype(
        input, 'input', ['float16', 'float32', 'float64', 'int32',
                         'int64'], 'transpose')

    # Return before creating any temporaries, so no unused variables are
    # added to the program for the trivial case.
    if is_trivial:
        return input

    helper = LayerHelper('t', **locals())
    out = helper.create_variable_for_type_inference(input.dtype)
    input_shape = helper.create_variable_for_type_inference(input.dtype)
    helper.append_op(
        type='transpose2',
        inputs={'X': [input]},
        outputs={'Out': [out],
                 'XShape': [input_shape]},
        attrs={'axis': [1, 0]})
    return out
def cross(x, y, axis=None, name=None):
"""
Computes the cross product between two tensors along an axis.