# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import decorator
import contextlib
import functools
import inspect
import sys
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.framework import global_var
from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
from .tracer import Tracer
import logging
from ..data_feeder import convert_dtype
import warnings
from ..framework import _get_paddle_place
import paddle
__all__ = [
'no_grad',
'no_grad_',
'grad',
'guard',
'enable_dygraph',
'disable_dygraph',
'enabled',
'to_variable',
]
NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"
def in_declarative_mode():
"""
Return a bool indicating whether the code is running under `@to_static`.
"""
return global_var._in_declarative_mode_
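# For illustration: in_declarative_mode() is how helpers in this module pick
# between eager (dygraph) behaviour and the dy2static path, roughly:
#
#     if in_declarative_mode():
#         ...  # static-graph handling used under @to_static
#     else:
#         ...  # normal eager execution
#
# param_guard(), no_grad() and grad() below all branch on it this way.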
def declarative_unsupport_argument_warning(
func_name, input_names, inputs, support_values
):
"""
Warn if `inputs` do not elementwise equal `support_values`.
It's a utility function for dy2static, used when a dygraph interface has
more inputs than its static counterpart, such as paddle.grad.
"""
for name, inp, sup in zip(input_names, inputs, support_values):
if inp != sup:
warnings.warn(
f"{func_name} has unsupported parameter in jit: "
+ f"{name}, jit will discard it"
)
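# For illustration, grad() below calls this helper roughly as:
#
#     declarative_unsupport_argument_warning(
#         "paddle.grad",
#         ["retain_graph", "create_graph", "only_inputs", "allow_unused"],
#         [retain_graph, create_graph, only_inputs, allow_unused],
#         [None, False, True, False],
#     )
#
# so a warning is emitted only for the arguments whose value differs from the
# value supported by the static interface.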
def _switch_to_static_graph_(func):
def __impl__(*args, **kwargs):
with framework._dygraph_guard(None):
return func(*args, **kwargs)
return __impl__
switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
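# switch_to_static_graph runs the decorated function with the dygraph tracer
# temporarily set to None, so code inside it behaves as if it were building a
# static graph even when called from dygraph. A minimal sketch (the function
# name is illustrative only):
#
#     @switch_to_static_graph
#     def _build_static_piece():
#         ...  # sees static-graph mode here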
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
global global_var
original_val = global_var._in_declarative_mode_
global_var._in_declarative_mode_ = is_declarative
yield
global_var._in_declarative_mode_ = original_val
@signature_safe_contextmanager
def program_desc_tracing_guard(enable):
tracer = framework._dygraph_tracer()
if tracer:
original_val = tracer._enable_program_desc_tracing
tracer._enable_program_desc_tracing = enable
try:
yield
finally:
if tracer:
tracer._enable_program_desc_tracing = original_val
@signature_safe_contextmanager
def param_guard(parameters):
# Note: parameters is a reference of self._parameters or self._buffers
if in_declarative_mode() and not paddle.in_dynamic_mode() and parameters:
origin_parameters = parameters.copy()
for name, var_base in parameters.items():
if isinstance(var_base, list):
new_var = [_convert_into_variable(var) for var in var_base]
else:
new_var = _convert_into_variable(var_base)
parameters[name] = new_var
yield
parameters.update(origin_parameters)
else:
yield
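# param_guard is used on a Layer's self._parameters / self._buffers when its
# forward runs under @to_static: eager Tensors in the dict are temporarily
# replaced by static-graph Variables via _convert_into_variable(), and the
# original entries are restored once the `with` block exits.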
def _convert_into_variable(tensor):
"""
Convert Tensor into Variable.
"""
if isinstance(tensor, core.eager.Tensor):
# Check whether it has been created before.
new_var = tensor.block._find_var_recursive(tensor.name)
if new_var is not None:
assert isinstance(new_var, framework.Variable)
# Convert EagerParamBase into Parameter with same attributes in dy2stat.
elif isinstance(tensor, framework.EagerParamBase):
new_var = tensor._to_static_var(to_parameter=True)
else:
# Note(Aurelius84): Convert Tensor in self._buffers into Variable with
# same attributes and set persistable=True to allow saving this var.
# Because users can create a Tensor in `__init__` like a
# `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
# and necessary for inferring. It will be pruned if it's not necessary for inferring.
# But if its shape is empty while created from `create_variable()`, we consider this buffer
# non-persistable. See case of `dropout_state` in lstm api.
is_persistable = True
if tensor.name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX):
is_persistable = False
new_var = tensor._to_static_var(
to_parameter=False, persistable=is_persistable
)
# add param into parameter recorder to collect all the params used in this program.
if new_var.persistable is True:
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
ProgramTranslator.get_instance()._params_recorder.add(
tensor.block.program, tensor
)
return new_var
else:
return tensor
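# For illustration: a buffer whose name ends with
# NON_PERSISTABLE_VAR_NAME_SUFFIX (e.g. the lstm `dropout_state` case noted
# above) becomes a non-persistable Variable and is therefore not kept for
# inference, while every other converted parameter/buffer is registered with
# the ProgramTranslator's parameter recorder.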
def enabled():
"""
This function checks whether the program runs in dynamic graph mode or not.
You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api,
or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable_dygraph`
and :ref:`api_fluid_dygraph_disable_dygraph` api .
**Note**:
``fluid.dygraph.enabled`` is the alias of ``fluid.in_dygraph_mode``, and
``fluid.in_dygraph_mode`` is recommended to use for now.
Returns:
bool: Whether the program is running in dynamic graph mode.
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph() # Now we are in dygraph mode
print(fluid.dygraph.enabled()) # True
fluid.disable_dygraph()
print(fluid.dygraph.enabled()) # False
"""
# TODO(jiabin): Make this check as in_dygraph_mode when we support default eager mode.
return framework.in_dygraph_mode()
def enable_dygraph(place=None):
"""
.. note::
Dynamic graph mode is turned ON by default since paddle 2.0.0
This API turns OFF static graph mode. You can turn ON static graph mode by `enable_static <./disable_dygraph_en.html>`_ .
Parameters:
place(paddle.CPUPlace|paddle.CUDAPlace|str, optional): Place to run dynamic graph. Default: None, which means the running place
will be determined according to the way paddle was compiled. If ``place`` is a string, it can be ``cpu`` or ``gpu:x``, where ``x``
is the index of the GPU.
Returns:
None
Examples:
.. code-block:: python
import paddle
print(paddle.in_dynamic_mode()) # True, dynamic mode is turned ON by default since paddle 2.0.0
paddle.enable_static()
print(paddle.in_dynamic_mode()) # False, now we are in static graph mode
paddle.disable_static()
print(paddle.in_dynamic_mode()) # True, now we are in dynamic mode
"""
global global_var
if global_var._functional_dygraph_context_manager is None:
global_var._functional_dygraph_context_manager = guard(
place=_get_paddle_place(place)
)
global_var._functional_dygraph_context_manager.__enter__()
# call disable_dygraph when Python exit
CleanupFuncRegistrar.register(disable_dygraph)
def disable_dygraph():
"""
.. note::
Dynamic graph mode is turned ON by default since paddle 2.0.0
This API turns ON static graph mode. You can turn it OFF again with `disable_static <./enable_dygraph_en.html>`_ .
Returns:
None
Examples:
.. code-block:: python
import paddle
print(paddle.in_dynamic_mode()) # True, dynamic mode is turned ON by default since paddle 2.0.0
paddle.enable_static()
print(paddle.in_dynamic_mode()) # False, now we are in static graph mode
paddle.disable_static()
print(paddle.in_dynamic_mode()) # True, now we are in dynamic mode
"""
global global_var
if global_var._functional_dygraph_context_manager is not None:
global_var._functional_dygraph_context_manager.__exit__(*sys.exc_info())
global_var._functional_dygraph_context_manager = None
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
tracer = framework._dygraph_tracer()
if tracer:
has_grad = tracer._has_grad
tracer._has_grad = is_train
try:
yield
finally:
tracer._has_grad = has_grad
else:
yield
def no_grad(func=None):
"""
:api_attr: imperative
Create a context which disables dygraph gradient calculation.
In this mode, the result of every computation will have `stop_gradient=True`.
Also functions as a decorator. (Make sure to use it without parentheses.)
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
# use as generator
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None
l1 = fluid.Linear(2, 2)
with fluid.dygraph.no_grad():
# l1.weight.stop_gradient is False
tmp = l1.weight * 2 # tmp.stop_gradient is True
x = fluid.dygraph.to_variable(data)
y = l0(x) + tmp
o = l1(y)
o.backward()
print(tmp.gradient() is None) # True
print(l0.weight.gradient() is None) # False
# use as decorator
@fluid.dygraph.no_grad
def test_layer():
with fluid.dygraph.guard():
inp = np.ones([3, 1024], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
linear1 = fluid.Linear(1024, 4, bias_attr=False)
linear2 = fluid.Linear(4, 4)
ret = linear1(t)
dy_ret = linear2(ret)
test_layer()
"""
if in_declarative_mode():
warnings.warn(
"paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
)
if func is None:
return _switch_tracer_mode_guard_(is_train=False)
else:
@decorator.decorator
def __impl__(func, *args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs)
return __impl__(func)
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""
def __call__(self, func):
@decorator.decorator
def _decorate_function(func, *args, **kwargs):
with self:
return func(*args, **kwargs)
@decorator.decorator
def _decorate_generator(func, *args, **kwargs):
gen = func(*args, **kwargs)
with self:
for x in gen:
yield x
if inspect.isgeneratorfunction(func):
return _decorate_generator(func)
else:
return _decorate_function(func)
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_value, traceback):
raise NotImplementedError
def clone(self):
# override this method if your children class takes __init__ parameters
return self.__class__()
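# A subclass only needs __enter__/__exit__ (plus clone() if it takes
# constructor arguments) to work both as a context manager and as a
# decorator, including on generator functions. A minimal sketch with a
# hypothetical subclass:
#
#     class _my_guard(_DecoratorContextManager):
#         def __enter__(self):
#             ...  # save and switch some global mode
#
#         def __exit__(self, *args):
#             ...  # restore the saved mode
#
#     @_my_guard()
#     def func():
#         ...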
def is_grad_enabled():
"""
Returns whether current dygraph gradient calculation mode is enabled.
Returns:
bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
Examples:
.. code-block:: python
import paddle
# Dygraph gradient calculation mode is enabled by default.
paddle.is_grad_enabled() # True
with paddle.set_grad_enabled(False):
paddle.is_grad_enabled() # False
paddle.enable_static()
paddle.is_grad_enabled() # False
"""
tracer = framework._dygraph_tracer()
return tracer._has_grad if tracer else False
def _set_grad_enabled(mode):
tracer = framework._dygraph_tracer()
if tracer:
tracer._has_grad = mode
class set_grad_enabled(_DecoratorContextManager):
"""
Create a context which enables or disables dygraph gradient calculation.
Args:
mode(bool): whether to enable (`True`), or disable (`False`) grad.
Returns:
None.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.], stop_gradient=False)
is_train = False
with paddle.set_grad_enabled(is_train):
y = x * 2
assert(y.stop_gradient == True)
paddle.set_grad_enabled(True)
y = x * 2
assert(y.stop_gradient == False)
paddle.set_grad_enabled(False)
y = x * 2
assert(y.stop_gradient == True)
"""
def __init__(self, mode):
self.prev = is_grad_enabled()
_set_grad_enabled(mode)
self.mode = mode
def __enter__(self):
...
def __exit__(self, *args):
_set_grad_enabled(self.prev)
def clone(self):
return self.__class__(self.mode)
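# Design note: set_grad_enabled flips the mode already in __init__, so a bare
# call such as `paddle.set_grad_enabled(False)` takes effect immediately;
# __enter__ is a no-op and __exit__ restores the value saved at construction
# time, which is what makes the `with` form work as well.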
class no_grad_(_DecoratorContextManager):
"""
:api_attr: imperative
Create a context which disables dygraph gradient calculation.
In this mode, the result of every computation will have `stop_gradient` set
to `True`.
Also functions as a decorator. (Make sure to use an instance.)
Examples:
.. code-block:: python
import numpy as np
import paddle
# use as generator
data = np.array([[2, 3], [4, 5]]).astype('float32')
l0 = paddle.nn.Linear(2, 2) # l0.weight.gradient() is None
l1 = paddle.nn.Linear(2, 2)
with paddle.no_grad():
# l1.weight.stop_gradient is False
tmp = l1.weight * 2 # tmp.stop_gradient is True
x = paddle.to_tensor(data)
y = l0(x) + tmp
o = l1(y)
o.backward()
print(tmp.gradient() is None) # True
print(l0.weight.gradient() is None) # False
# use as decorator
@paddle.no_grad()
def test_layer():
inp = np.ones([3, 1024], dtype='float32')
t = paddle.to_tensor(inp)
linear1 = paddle.nn.Linear(1024, 4, bias_attr=False)
linear2 = paddle.nn.Linear(4, 4)
ret = linear1(t)
dy_ret = linear2(ret)
test_layer()
"""
def __enter__(self):
self.prev = is_grad_enabled()
_set_grad_enabled(False)
def __exit__(self, *args):
_set_grad_enabled(self.prev)
class enable_grad(_DecoratorContextManager):
"""
:api_attr: imperative
Create a context which enables dygraph gradient calculation,
if it has been disabled by `no_grad` or `set_grad_enabled`.
In this mode, the result of every computation will have `stop_gradient` set
to `False`.
Also functions as a decorator. (Make sure to use an instance.)
Examples:
.. code-block:: python
import paddle
# use as generator
x = paddle.to_tensor([1.], stop_gradient=False)
with paddle.no_grad():
with paddle.enable_grad():
y = x * 2
assert(y.stop_gradient == False)
y.backward()
assert(x.grad is not None)
# use as decorator
@paddle.enable_grad()
def double(x):
return x * 2
with paddle.no_grad():
z = double(x)
assert(z.stop_gradient == False)
"""
def __enter__(self):
self.prev = is_grad_enabled()
_set_grad_enabled(True)
def __exit__(self, *args):
_set_grad_enabled(self.prev)
@signature_safe_contextmanager
def guard(place=None):
"""
:api_attr: imperative
This context manager creates a dygraph context for dygraph code to run in, using the Python ``with`` statement.
Parameters:
place(fluid.CPUPlace|fluid.CUDAPlace|str, optional): Place to execute dygraph.
If None, the running place will be determined according to the way paddle was compiled.
If ``place`` is a string, it can be ``cpu``, ``gpu:x`` or ``xpu:x``, where ``x`` is the
index of the GPU or XPU. Default: None
Returns:
None
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard():
inp = np.ones([3, 1024], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
linear1 = fluid.Linear(1024, 4, bias_attr=False)
linear2 = fluid.Linear(4, 4)
ret = linear1(t)
dy_ret = linear2(ret)
"""
train = framework.Program()
startup = framework.Program()
tracer = Tracer()
if place is not None:
expected_place = _get_paddle_place(place)
else:
expected_place = framework._current_expected_place()
with framework.program_guard(train, startup):
with framework.unique_name.guard():
with framework._dygraph_guard(tracer):
with framework._dygraph_place_guard(expected_place):
yield
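# Design note: guard() builds fresh train/startup Programs and a new Tracer,
# then nests program_guard, unique_name.guard, _dygraph_guard and
# _dygraph_place_guard, so names, traced state and the expected place set up
# inside the `with` block do not leak out of it.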
@framework.non_static_only
def grad(
outputs,
inputs,
grad_outputs=None,
retain_graph=None,
create_graph=False,
only_inputs=True,
allow_unused=False,
no_grad_vars=None,
):
'''
.. note::
**This API is ONLY available in imperative mode.**
This API computes the sum of gradients of `outputs` with respect to each `inputs` .
Parameters:
outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or
Tensor list/tuple of the graph to compute gradients.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the graph to compute gradients. The returned
values of this API are the gradients of `inputs` .
grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional):
initial gradient values of `outputs` . If `grad_outputs` is None,
the initial gradient values of `outputs` would be Tensors filled with 1;
if `grad_outputs` is not None, it must have the same length as `outputs` ,
and in this case, the initial gradient value of the i-th `outputs` would
be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
is None; (2) the i-th element of `grad_outputs` when the i-th element of
`grad_outputs` is a Tensor. Default None.
retain_graph (bool, optional): whether to retain the forward graph which
is used to calculate the gradient. When it is True, the graph would
be retained, in which way users can calculate backward twice for the
same graph. When it is False, the graph would be freed. Default None,
which means it is equal to `create_graph` .
create_graph (bool, optional): whether to create the gradient graphs of
the computing process. When it is True, higher order derivatives are
supported to compute; when it is False, the gradient graphs of the
computing process would be discarded. Default False.
only_inputs (bool, optional): whether to only compute the gradients of
`inputs` . If it is False, the gradients of all remaining leaf
Tensors in the graph would be also computed and accumulated.
If it is True, only the gradients of `inputs` would be computed.
Default True. only_inputs=False is under development, and it is
not supported yet.
allow_unused (bool, optional): whether to raise error or return None if some
Tensors of `inputs` are unreachable in the graph. If some Tensors of
`inputs` are unreachable in the graph (i.e., their gradients are None),
error would be raised if allow_unused=False, or None would be returned as
their gradients if allow_unused=True. Default False.
no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
the Tensors whose gradients are not needed to compute. Default None.
Returns:
list: a list of Tensors, whose length is the same as the Tensor number
inside `inputs`, and the i-th returned Tensor is the sum of gradients of
`outputs` with respect to the i-th `inputs`.
Examples:
.. code-block:: python
:name: code-example-1
import paddle
def test_dygraph_grad(create_graph):
x = paddle.ones(shape=[1], dtype='float32')
x.stop_gradient = False
y = x * x
# Since y = x * x, dx = 2 * x
dx = paddle.grad(
outputs=[y],
inputs=[x],
create_graph=create_graph,
retain_graph=True)[0]
z = y + dx
# If create_graph = False, the gradient of dx
# would not be backpropagated. Therefore,
# z = x * x + dx, and x.gradient() = 2 * x = 2.0
# If create_graph = True, the gradient of dx
# would be backpropagated. Therefore,
# z = x * x + dx = x * x + 2 * x, and
# x.gradient() = 2 * x + 2 = 4.0
z.backward()
return x.gradient()
print(test_dygraph_grad(create_graph=False)) # [2.]
print(test_dygraph_grad(create_graph=True)) # [4.]
.. code-block:: python
:name: code-example-2
import paddle
def test_dygraph_grad(grad_outputs=None):
x = paddle.to_tensor(2.0)
x.stop_gradient = False
y1 = x * x
y2 = x * 3
# If grad_outputs=None, dy1 = [1], dy2 = [1].
# If grad_outputs=[g1, g2], then:
# - dy1 = [1] if g1 is None else g1
# - dy2 = [1] if g2 is None else g2
# Since y1 = x * x, dx = 2 * x * dy1.
# Since y2 = x * 3, dx = 3 * dy2.
# Therefore, the final result would be:
# dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2.
dx = paddle.grad(
outputs=[y1, y2],
inputs=[x],
grad_outputs=grad_outputs)[0]
return dx.numpy()
grad_value = paddle.to_tensor(4.0)
# dy1 = [1], dy2 = [1]
print(test_dygraph_grad(None)) # [7.]
# dy1 = [1], dy2 = [4]
print(test_dygraph_grad([None, grad_value])) # [16.]
# dy1 = [4], dy2 = [1]
print(test_dygraph_grad([grad_value, None])) # [19.]
# dy1 = [3], dy2 = [4]
grad_y1 = paddle.to_tensor(3.0)
print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
'''
if in_declarative_mode():
# In dy2static context, we call static interface `gradients`
# to calculate grads.
from paddle.static import gradients
declarative_unsupport_argument_warning(
"paddle.grad",
["retain_graph", "create_grad", "only_inputs", "allow_unused"],
[retain_graph, create_graph, only_inputs, allow_unused],
[None, False, True, False],
)
return gradients(outputs, inputs, grad_outputs, no_grad_vars)
def check_in_out(in_out_list, name):
assert in_out_list is not None, "{} should not be None".format(name)
if isinstance(in_out_list, (list, tuple)):
assert len(in_out_list) > 0, "{} cannot be empty".format(name)
for each_var in in_out_list:
assert isinstance(
each_var, core.eager.Tensor
), "Elements of {} must be Tensor".format(name)
return in_out_list
else:
assert isinstance(
in_out_list, core.eager.Tensor
), "{} must be Tensor or list of Tensor".format(name)
return [in_out_list]
outputs = check_in_out(outputs, 'outputs')
inputs = check_in_out(inputs, 'inputs')
if grad_outputs is not None:
if not isinstance(grad_outputs, (list, tuple)):
grad_outputs = [grad_outputs]
for each_var in grad_outputs:
if each_var is not None:
assert isinstance(
each_var, core.eager.Tensor
), "grad_outputs must be None, a Variable or a list containing None or Variables"
else:
grad_outputs = []
if len(grad_outputs) > 0:
assert len(grad_outputs) == len(
outputs
), "The length of grad_outputs must be equal to outputs"
if no_grad_vars is None:
no_grad_vars = []
elif isinstance(no_grad_vars, core.eager.Tensor):
no_grad_vars = [no_grad_vars]
elif isinstance(no_grad_vars, (list, tuple, set)):
no_grad_vars = list(no_grad_vars)
for var in no_grad_vars:
assert isinstance(
var, core.eager.Tensor
), "no_grad_vars can only contains Tensor"
else:
raise AssertionError(
"no_grad_vars must be None, Tensor or list/tuple/set of Tensors"
)
assert isinstance(create_graph, bool), "create_graph must be True or False"
if retain_graph is None:
retain_graph = create_graph
assert isinstance(
retain_graph, bool
), "retain_graph must be None, True or False"
assert isinstance(allow_unused, bool), "allow_unused must be True or False"
assert isinstance(only_inputs, bool), "only_inputs must be True or False"
assert only_inputs, "only_inputs=False is not supported yet"
return core.eager.run_partial_grad(
outputs,
inputs,
grad_outputs,
retain_graph,
create_graph,
only_inputs,
allow_unused,
no_grad_vars,
)
@framework.dygraph_only
def to_variable(value, name=None, zero_copy=None, dtype=None):
r"""
:api_attr: imperative
This API creates a ``Variable`` object from a
tuple, list, numpy\.ndarray or Variable object.
Parameters:
value(tuple|list|ndarray|Variable|Tensor): Initial data.
Can be a list, tuple, NumPy ndarray, Variable, Tensor.
The shape can be multi-dimensional. The data type is one of
numpy\.{float16, float32, float64, int16, int32, int64,
uint8, uint16, complex64, complex128}.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
zero_copy(bool, optional): Whether to share memory with the input numpy
array. This parameter only works with CPUPlace and will be set to
True when it is None. Default: None. (Note: zero_copy is temporarily disabled.)
dtype(str, optional): The desired data type of returned ``Variable`` .
Can be 'bool' , 'float16' , 'float32' , 'float64' , 'int8' , 'int16' ,
'int32' , 'int64' , 'uint8' . Default: None.
Returns:
Variable : If ``value`` is a tuple/list/numpy\.ndarray object,
return ``Tensor`` created from the corresponding numpy\.ndarray object, which has
the same data type and shape as ``value``.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard(fluid.CPUPlace()):
x = np.ones([2, 2], np.float32)
y = fluid.dygraph.to_variable(x, zero_copy=False)
x[0][0] = -1
y[0][0].numpy() # array([1.], dtype=float32)
y = fluid.dygraph.to_variable(x)
x[0][0] = 0
y[0][0].numpy() # array([0.], dtype=float32)
c = np.array([2+1j, 2])
z = fluid.dygraph.to_variable(c)
z.numpy() # array([2.+1.j, 2.+0.j])
z.dtype # 'complex128'
y = fluid.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
y.shape # [3L, 2L]
y = fluid.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32')
y.shape # [3L, 2L]
"""
support_type = (
list,
tuple,
np.ndarray,
core.eager.Tensor,
framework.Variable,
core.Tensor,
core.LoDTensor,
)
if not isinstance(value, support_type):
raise TypeError(
"The type of 'value' in fluid.dygraph.to_variable must be %s, but received %s."
% (support_type, type(value))
)
if isinstance(value, (core.eager.Tensor, framework.Variable)):
return value
elif isinstance(value, (core.Tensor, core.LoDTensor)):
return core.eager.Tensor(value)
else:
if isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# TODO(zhiqiu): we found two problems when enabling zero_copy on CPUPlace.
# (1): Eigen requires 16-byte alignment, but the data of a numpy array may not satisfy it.
# Details: https://eigen.tuxfamily.org/dox/group__TopicUnalignedArrayAssert.html
# (2): when used in the flask framework, it may result in a hang.
# Details: https://github.com/PaddlePaddle/Paddle/issues/26635
# So, we temporarily disable the zero_copy strategy.
if zero_copy == True:
warnings.warn(
"Currently, zero_copy is not supported, and it will be discarded."
)
zero_copy = False
else:
assert (
not zero_copy
), "zero_copy mode can only be used with CPUPlace"
if not isinstance(value, np.ndarray):
value = np.array(value)
if dtype is not None:
dtype = convert_dtype(dtype)
if value.dtype != dtype:
value = value.astype(dtype)
return core.eager.Tensor(
value,
framework._current_expected_place(),
False,
zero_copy,
name if name else None,
True,
)
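# For illustration: Tensors and Variables pass through to_variable unchanged,
# LoDTensors are wrapped into eager Tensors, and everything else goes through
# numpy with an optional dtype cast, e.g.
#
#     # hypothetical call, run under a dygraph guard
#     t = to_variable([[0.1, 1.2], [2.2, 3.1]], dtype='float32')
#
# builds np.array(value).astype('float32') and copies it onto the current
# expected place.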