from functools import wraps
from typing import Any
from collections import defaultdict, namedtuple
from collections.abc import Callable
from collections.abc import Sequence
from contextvars import ContextVar
import itertools
import os
import dis
import time
import warnings
from looseversion import LooseVersion
from thunder.core.module import ThunderModule
from thunder.core.interpreter import InterpreterLogItem
from thunder.core.options import (
INTERPRETATION_OPTIONS,
resolve_interpretation_option,
resolve_sharp_edges_option,
CACHE_OPTIONS,
SHARP_EDGES_OPTIONS,
)
from thunder.core.trace import (
TraceResults,
TraceCtx,
from_trace,
set_tracectx,
reset_tracectx,
is_tracing,
)
from thunder import functional as functional
import thunder.core.prims as prims
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.transform_common import dce, EarlyTransform, AdditionalTransform, PostOptimizationTransform
from thunder.common import (
CompileData,
CompileStats,
_create_callable,
trace,
transform_for_execution,
)
import thunder.extend as extend
from thunder.extend import Executor, add_default_executor
from thunder.core.compile_data import compile_data_and_stats, get_compile_data
from thunder.core.langctxs import LanguageContext
import thunder.core.langctxs as langctxs
from thunder.core.baseutils import run_once, check
from thunder.core.codeutils import Positions
from thunder.core.proxies import (
Proxy,
TensorProxy,
NumberProxy,
StringProxy,
IntegerProxy,
FloatProxy,
ComplexProxy,
TupleProxy,
ListProxy,
DictProxy,
AnyProxy,
)
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.cudagraphs import CUDAGraphExecutor
# NOTE This import is intentionally named pytorch so that thunder.torch doesn't import this
import torch as pytorch
import thunder.clang as clang
# Imports executors (to populate default executors and make them accessible)
import thunder.executors.pythonex
import thunder.executors.torchex
import thunder.executors.nvfuserex
pythonex = extend.get_executor("python")
assert pythonex is not None
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
# TODO RC1 Review exposed names
__all__ = [
# dtype aliases
"bool8",
"uint8",
"int8",
"int16",
"int32",
"int64",
"bfloat16",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e4m3fn",
"float8_e4m3fnuz",
"float16",
"float32",
"float64",
"complex32",
"complex64",
"complex128",
# language aliases
"torch",
"numpy",
"prims",
# interface functions
# TODO Extend this
# TODO Add device aliases
# TODO Add executor aliases
"cudnn_executor",
"sdpa_executor",
"nvfuser_executor",
"pytorch_executor",
# debugging functions
"set_execution_callback_file",
]
def __version__():
return LooseVersion("0.0.1")
# TODO maybe move these aliases to the core language?
#
# dtype aliases
#
bool8 = dtypes.bool8
uint8 = dtypes.uint8
int8 = dtypes.int8
int16 = dtypes.int16
int32 = dtypes.int32
int64 = dtypes.int64
bfloat16 = dtypes.bfloat16
float8_e5m2 = dtypes.float8_e5m2
float8_e5m2fnuz = dtypes.float8_e5m2fnuz
float8_e4m3fn = dtypes.float8_e4m3fn
float8_e4m3fnuz = dtypes.float8_e4m3fnuz
float16 = dtypes.float16
float32 = dtypes.float32
float64 = dtypes.float64
complex32 = dtypes.complex32
complex64 = dtypes.complex64
complex128 = dtypes.complex128
#
# Module aliases
#
# NOTE this allows clang.foo() to be called directly as thunder.foo()
from thunder.clang import *
#
# Promoted executor-related functions and objects
#
# TODO Add more of these functions
resolve_executors = extend.resolve_executors
add_executor_lists = extend.add_executor_lists
get_executor = extend.get_executor
get_all_executors = extend.get_all_executors
get_default_executors = extend.get_default_executors
get_always_executors = extend.get_always_executors
cudnn_executor: None | extend.Executor = extend.get_executor("cudnn")
sdpa_executor: None | extend.Executor = extend.get_executor("sdpa")
nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
pytorch_executor: None | extend.Executor = extend.get_executor("torch")
# Default executor list is [cudnn -> sdpa -> nvfuser -> torch -> python]
# Note that add_default_executor inserts executor at start of list, hence the reverse order below.
if nvfuser_executor:
add_default_executor(nvfuser_executor)
if sdpa_executor:
add_default_executor(sdpa_executor)
if cudnn_executor:
add_default_executor(cudnn_executor)
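# A hedged usage sketch (illustrative only; ``fn`` stands for any function or module to compile):
# callers can inspect the registered executors and override the priority order per jitted function.
#
#   import thunder
#   print(thunder.get_all_executors())      # every registered executor
#   print(thunder.get_default_executors())  # the default priority order assembled above
#   jfn = thunder.jit(fn, executors=["nvfuser", "torch"])  # explicit priority for one function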
#
# Promoted debugging functions
#
# If set, Python programs will be written to this file before being executed, and if
# the file is modified then the modified version of the program will be compiled and executed instead.
from thunder.core.trace import _set_execution_file
set_execution_callback_file = _set_execution_file
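# A hedged usage sketch (assumes the callback takes a file path, as the note above implies;
# the path itself is illustrative):
#
#   import thunder
#   thunder.set_execution_callback_file("/tmp/thunder_last_program.py")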
# Translates the Python function to a thunder program using the thunder interpreter
def _general_frontend(
fn: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
/,
*,
record_history: bool,
sharp_edges: SHARP_EDGES_OPTIONS,
) -> TraceResults:
return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
# this captures the information needed to decide whether a cached function
# matches (e.g. ddp and autocast state)
_cache_info_ctx = ContextVar("cache_info_ctx")
def _with_cache_info_ctx(fn):
def cache_info_wrapper(*args, **kwargs):
tok = _cache_info_ctx.set({})
try:
res = fn(*args, **kwargs)
finally:
_cache_info_ctx.reset(tok)
return res
return cache_info_wrapper
def _get_cache_info():
return _cache_info_ctx.get()
def add_executor_lists(
exc_list: None | Sequence[Executor | str], other_exc_list: None | Sequence[Executor | str]
) -> Sequence[Executor]:
new_exc_list = []
exc_list = resolve_executors(exc_list)
other_exc_list = resolve_executors(other_exc_list)
for exc in itertools.chain(exc_list, other_exc_list):
        if exc not in new_exc_list:
new_exc_list.append(exc)
return new_exc_list
@run_once
def _recursive_jit_call_warning() -> None:
warnings.warn(
"Calling a jitted function from a jitted function currently uses all settings from the caller. In the future this behavior may change."
)
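# NOTE A CacheEntry bundles everything produced for one compilation: the prologue callable
# (which unpacks and validates the inputs), the computation callable, optional epilogue and
# backward callables, each paired with the list of traces that produced it, plus the
# return_none_instead_of_grads flag consumed by the autograd integration.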
CacheEntry = namedtuple(
"CacheEntry",
[
"prologue_fn",
"prologue_traces",
"computation_fn",
"computation_traces",
"epilogue_fn",
"epilogue_traces",
"backward_fn",
"backward_traces",
"return_none_instead_of_grads",
],
)
# This function will replace compile() (below) before RC1
# TODO RC1 Consider adding a debug_log parameter to control debug printing
# TODO RC1 Consider renaming compile_options to additional_compile_options
def jit(
fn: Callable,
/,
*,
langctx: None | str | Any | LanguageContext = None,
executors: None | Sequence[Executor | str] = None,
sharp_edges: None | SHARP_EDGES_OPTIONS | str = None,
interpretation: None | INTERPRETATION_OPTIONS | str = None,
cache: None | CACHE_OPTIONS | str = None,
disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1
early_transforms: list[EarlyTransform] | None = None,
additional_transforms: list[AdditionalTransform] | None = None,
post_optimization_transforms: list[PostOptimizationTransform] | None = None,
record_history: bool = False,
**compile_options, # TODO RC1 Make this explicit -- dict of options
) -> Callable:
"""Just-in-time compile a callable (function or model).
Args:
fn: A :class:`~torch.nn.Module` or a function to compile.
Keyword Args:
langctx: the language context, which language / library to emulate. default: "torch" for PyTorch compatibility.
executors: list of executors to use. Defaults to the executors returned by :func:`thunder.extend.get_default_executors` and always amended by :func:`thunder.extend.get_always_executors`.
You can get a list of all available executors with :func:`thunder.get_all_executors`. You can also pass the name of an executor that's been registered, and it will be resolved with :func:`thunder.extend.get_executor`.
sharp_edges: sharp edge detection action. What to do when thunder detects a construct that is likely to lead to errors. Can be ``"allow"``, ``"warn"``, ``"error"``. Defaults to ``"allow"``.
        cache: caching mode. default: ``"constant values"``
- ``"no caching"`` - disable caching and always recompute,
- ``"constant values"`` - require Tensors to be of the same shape, device, dtype etc., and integers and strings to match exactly,
- ``"same input"`` - don't check, but just assume that a cached function works if it exists.
interpretation: (deprecated: don't use this, use the thunder.functional.jit entry point to get the functional jit)
        early_transforms: List of transforms to be applied to the prologue, computation, and epilogue traces before executing the prologue. Each should be an instance of :class:`thunder.core.transforms.EarlyTransform`. Default: ``None``
        additional_transforms: List of transforms to be applied to the computation trace. Each should be an instance of :class:`thunder.core.transforms.AdditionalTransform`. Default: ``None``
        post_optimization_transforms: List of transforms to be applied to the optimized computation traces (i.e. the forward and backward traces). Each should be an instance of :class:`thunder.core.transforms.PostOptimizationTransform`. Default: ``None``
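
    A minimal, illustrative sketch (the function ``foo`` and the tensor shapes below are placeholders)::

        import torch
        import thunder

        def foo(a, b):
            return a + b

        jfoo = thunder.jit(foo)
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        out = jfoo(a, b)                    # traced, compiled, and cached on the first call
        traces = thunder.last_traces(jfoo)  # inspect the computation traces afterwards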
"""
if "executors_list" in compile_options:
warnings.warn("outdated argument executors_list= in call, please use executors=")
if executors is None:
executors = compile_options.pop("executors_list")
# Resolves interpreter option
interpretation = resolve_interpretation_option(interpretation)
interpreter: Callable
if interpretation is INTERPRETATION_OPTIONS.PYTHON_INTERPRETER:
interpreter = functional._python_interpreter
elif interpretation is INTERPRETATION_OPTIONS.TRANSLATE_FUNCTIONS:
interpreter = functional._translate_functions_interpreter
elif interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON:
interpreter = _general_frontend
if early_transforms is None:
early_transforms = []
if additional_transforms is None:
additional_transforms = []
if post_optimization_transforms is None:
post_optimization_transforms = []
# Resolve names of executors
executors = resolve_executors(executors)
# TODO: verify that tutorials don't have false positives and enable warning by default
# # Make sharp_edges == warn default if not supplied and if in the general jit
# if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON and sharp_edges is None:
# sharp_edges = SHARP_EDGES_OPTIONS.WARN
executor_lookasides = {}
for ex in executors:
# TODO: sharp edge if lookasides are shadowed?
executor_lookasides.update(ex._lookasides)
assert type(record_history) is bool
# TODO RC1 Refine the compile data option to remove unused options
# TODO: refine options
# NOTE(fixme): use_cudagraphs is being absorbed into compile_options
use_cudagraphs = compile_options.get("use_cudagraphs", False)
cd = CompileData(
fn=fn,
langctx=langctx,
executors_list=executors,
cache_option=cache,
sharp_edges=sharp_edges,
using_jit=True,
use_cudagraphs=use_cudagraphs,
disable_torch_autograd_support=disable_torch_autograd,
use_rematerialization=False,
only_execute_prims=False,
disable_preprocessing=True,
compile_options=compile_options,
executor_lookasides=executor_lookasides,
)
cs = CompileStats()
@_with_cache_info_ctx
def get_computation_and_inputs(*args, **kwargs):
# set up a record of things in the current environment that impact caching / prologues
# this could be replaced by the respective querying in the prologues
cache_info = _get_cache_info()
# autocast related operations
is_autocast_enabled = False
if pytorch.is_autocast_enabled() or pytorch.is_autocast_cpu_enabled():
if pytorch.is_autocast_enabled() and pytorch.is_autocast_cpu_enabled():
raise NotImplementedError(
"thunder.autocast does not support torch.is_autocast_enabled() and torch.is_autocast_cpu_enabled() simultaneously at this moment."
)
is_autocast_enabled = True
autocast_gpu_dtype = dtypes.to_dtype(pytorch.get_autocast_gpu_dtype())
autocast_cpu_dtype = dtypes.to_dtype(pytorch.get_autocast_cpu_dtype())
cache_info.update(
autocast_config_torch_enabled=pytorch.is_autocast_enabled(),
autocast_config_torch_cpu_enabled=pytorch.is_autocast_cpu_enabled(),
autocast_gpu_dtype=str(autocast_gpu_dtype),
autocast_cpu_dtype=str(autocast_cpu_dtype),
)
autocast_thunder_dtype = autocast_cpu_dtype if pytorch.is_autocast_cpu_enabled() else autocast_gpu_dtype
cache_info["is_autocast_enabled"] = is_autocast_enabled
is_ddp_enabled = getattr(fn, "use_ddp", False)
is_fsdp_enabled = getattr(fn, "use_fsdp", False)
no_grad_sync = False
if is_ddp_enabled or is_fsdp_enabled:
from thunder.distributed import get_skip_data_parallel_grad_sync
no_grad_sync = get_skip_data_parallel_grad_sync()
cache_info["no_grad_sync"] = no_grad_sync
return_none_instead_of_grads = is_fsdp_enabled and no_grad_sync
# TODO RC1 Add module and function checks to prologue (make it a compile option)
# Checks cache
cs.last_trace_cache_start = time.time_ns()
if (cd.cache_option is CACHE_OPTIONS.CONSTANT_VALUES) or (cd.cache_option is CACHE_OPTIONS.SYMBOLIC_VALUES):
for cache_entry in reversed(cs.interpreter_cache):
with compile_data_and_stats(cd, cs):
(
pro,
pro_traces,
comp,
comp_traces,
epilogue,
epilogue_traces,
backward_fn,
backward_traces,
_return_none_instead_of_grads,
) = cache_entry
try:
cs.last_prologue_execution_start = time.time_ns()
if epilogue:
inps, pro_to_epi = pro(*args, **kwargs)
else:
inps = pro(*args, **kwargs)
pro_to_epi = None
cs.last_prologue_execution_stop = time.time_ns()
except Exception as _:
continue
cs.last_trace_host_tracing_start = time.time_ns()
cs.last_trace_host_tracing_stop = time.time_ns()
# Updates cache statistics
cs.cache_hits += 1
cs.last_traces = comp_traces
cs.last_interpreted_instructions = None
cs.last_interpreter_log = None
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro
cs.last_prologue_transformation_start = 0
cs.last_prologue_transformation_stop = 0
cs.last_computation_transformation_start = 0
cs.last_computation_transformation_stop = 0
return cache_entry, inps, pro_to_epi
if cd.cache_option is CACHE_OPTIONS.SAME_INPUT:
if len(cs.interpreter_cache):
cache_entry = cs.interpreter_cache[0]
(
pro,
pro_traces,
comp,
comp_traces,
epilogue,
epilogue_traces,
backward_fn,
                    backward_traces,
                    _return_none_instead_of_grads,
                ) = cache_entry
cs.last_prologue_execution_start = time.time_ns()
if epilogue:
inps, pro_to_epi = pro(*args, **kwargs)
else:
inps = pro(*args, **kwargs)
pro_to_epi = None
cs.last_prologue_execution_stop = time.time_ns()
cs.last_trace_host_tracing_start = time.time_ns()
cs.last_trace_host_tracing_stop = time.time_ns()
# Updates cache statistics
cs.cache_hits += 1
cs.last_traces = comp_traces
cs.last_interpreted_instructions = None
cs.last_interpreter_log = None
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro
return cache_entry, inps, pro_to_epi
cs.cache_misses += 1
cs.last_trace_cache_stop = time.time_ns()
# Resets use of compile flags
cs.last_compile_reasons = defaultdict(list)
with compile_data_and_stats(cd, cs):
# Acquires the trace OR inlines the trace into an existing trace and
# returns the (proxied) result of the operation
cs.last_trace_tracing_start = time.time_ns()
with langctxs.langctx(cd.langctx):
prologue_trc: TraceCtx
computation_trc: TraceCtx
jit_results: TraceResults = interpreter(
fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
)
prologue_trc = jit_results.prologue_trace
computation_trc = jit_results.computation_trace
epilogue_trc = jit_results.epilogue_trace
last_interpreter_log = jit_results.interpreter_log
prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
if epilogue_trc is not None:
epilogue_traces = [epilogue_trc]
else:
epilogue_traces = None
cs.last_trace_tracing_stop = time.time_ns()
# Makes the prologue callable
cs.last_prologue_transformation_start = time.time_ns()
transform: Callable
for transform in early_transforms:
thunder.core.utils.check_type(transform, EarlyTransform)
prologue_trc, computation_trc, epilogue_trc = transform.transform_traces(
prologue_trc, computation_trc, epilogue_trc, executors_list=cd.executors_list
)
prologue_traces.append(prologue_trc)
computation_traces.append(computation_trc)
if epilogue_trc is not None:
epilogue_traces.append(epilogue_trc)
prologue_traces += transform_for_execution(
prologue_trc,
executors_list=(pythonex,),
use_del_last_used=False,
)
protrace = prologue_traces[-1]
pro = protrace.python_callable()
if epilogue_trc is not None:
epilogue = epilogue_trc.python_callable()
else:
epilogue = None
cs.last_prologue_transformation_stop = time.time_ns()
cs.last_prologue_execution_start = time.time_ns()
if epilogue:
inps, pro_to_epi = pro(*args, **kwargs)
else:
inps = pro(*args, **kwargs)
pro_to_epi = None
cs.last_prologue_execution_stop = time.time_ns()
cs.last_traces = computation_traces
backward_traces = []
cs.last_backward_traces = backward_traces
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))
computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)
if is_autocast_enabled:
from thunder.core.transforms import autocast
computation_trc = trace(compile_data=cd)(
autocast(computation_trc.python_callable(), dtype=autocast_thunder_dtype), *inps
)
computation_traces.append(computation_trc)
backward_trc = None
if not cd.disable_torch_autograd_support:
tensor_cls = (pytorch.Tensor, TensorProxy)
requires_grad = any(isinstance(arg, tensor_cls) and arg.requires_grad for arg in inps)
if requires_grad:
# Currently split_forward_backward also includes
# transform_for_execution and various sorting of symbols,
# applying transform_for_execution after this would be
# breaking the order of operations
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward
extraces = cs.last_traces
if backward_trc is None:
## EPILOGUE and TRANSFORMS should not mix...
# applies transforms
cs.last_computation_transformation_start = time.time_ns()
for transform in additional_transforms:
thunder.core.utils.check_type(transform, AdditionalTransform)
computation_trc = transform.transform_trace(computation_trc, executors_list=cd.executors_list)
computation_traces.append(computation_trc)
cs.last_computation_transformation_stop = time.time_ns()
with langctxs.langctx(cd.langctx):
extraces = transform_for_execution(
computation_trc,
executors_list=cd.executors_list,
)
computation_trc = extraces[-1]
if not compile_options.get("disable_inplace_copy_check", False):
thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
for transform in post_optimization_transforms:
# NOTE: `backward_trc` could be None.
thunder.core.utils.check_type(transform, PostOptimizationTransform)
computation_trc = transform.transform_trace(computation_trc, executors_list=cd.executors_list)
extraces.append(computation_trc)
if backward_trc is not None:
backward_trc = transform.transform_trace(backward_trc, executors_list=cd.executors_list)
backward_traces.append(backward_trc)
comp = computation_trc.python_callable()
if backward_trc is not None:
backward_fn = backward_trc.python_callable()
else:
backward_fn = None
# TODO: using vanilla CUDAGraphExecutor is not safe unless the graph is always static!
# (fixme): inspect torch.cuda.make_graph_callables and/or use it instead!
# See https://github.com/Lightning-AI/lightning-thunder/issues/433
if cd.use_cudagraphs:
comp = CUDAGraphExecutor(comp)
if backward_fn is not None:
backward_fn = CUDAGraphExecutor(backward_fn, num_constant_args=len(backward_trc.args[0][0]))
# TODO RC1 Update the cache
cache_entry = CacheEntry(
pro,
prologue_traces,
comp,
extraces,
epilogue,
epilogue_traces,
backward_fn,
backward_traces,
return_none_instead_of_grads,
)
if cd.cache_option is not CACHE_OPTIONS.NO_CACHING:
cs.interpreter_cache.append(cache_entry)
cs.last_traces += extraces
cs.last_prologue_traces = [prologue_trc] + prologue_traces
cs.last_prologue = pro
return cache_entry, inps, pro_to_epi
cd.get_computation_and_inputs = get_computation_and_inputs
@wraps(fn)
def fn_(*args, **kwargs) -> Any:
if is_tracing():
_recursive_jit_call_warning()
return fn(*args, **kwargs)
        # Updates call statistics
cs.last_trace_host_start = time.time_ns()
cs.calls += 1
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
cs.last_trace_host_execution_start = time.time_ns()
result = cache_entry.computation_fn(*inps)
if cache_entry.backward_fn:
# Run the compiled forward function
data_for_autograd, (saved_tensors, saved_other) = result
# Connect produced tensors with PyTorch's autograd graph
ThunderFunction.apply(
cache_entry.return_none_instead_of_grads,
cache_entry.backward_fn,
saved_tensors,
saved_other,
data_for_autograd["flat_output"],
*data_for_autograd["flat_args"],
)
result = data_for_autograd["output"]
if cache_entry.epilogue_fn:
result, comp_to_epi = result
cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi)
cs.last_trace_host_execution_stop = time.time_ns()
cs.last_computation_execution_stop = cs.last_trace_host_execution_stop
cs.last_executed = cache_entry.computation_fn
cs.last_trace_cache_stop = time.time_ns()
cs.last_trace_host_stop = time.time_ns()
return result
if isinstance(fn, pytorch.nn.Module):
fn_ = ThunderModule(fn, fn_)
cd._thunder_module_map[id(fn)] = fn_
# Sets compile options and statistics attributes
cd._get_computation_and_inputs = get_computation_and_inputs
fn_._lc_cd = cd
fn_._lc_cs = cs
fn_._lc_early_transforms = early_transforms[:] ## transforms
fn_._lc_transforms = additional_transforms[:] ## transforms
fn_._lc_post_optimization_transforms = post_optimization_transforms[:] ## post_optimization_transforms
return fn_
def compile(
fn: Callable,
*,
langctx: None | Any = None,
executors_list: None | Sequence[Executor] = None,
cache_mode: None | str | CACHE_OPTIONS = None,
use_cudagraphs: bool = False,
disable_torch_autograd_support: bool = False,
use_rematerialization: bool = False,
only_execute_prims: bool = False,
disable_preprocessing: bool = False,
**kwargs,
) -> Callable:
cd = CompileData(
fn=fn,
langctx=langctx,
executors_list=executors_list,
cache_option=cache_mode,
use_cudagraphs=use_cudagraphs,
disable_torch_autograd_support=disable_torch_autograd_support,
use_rematerialization=use_rematerialization,
only_execute_prims=only_execute_prims,
disable_preprocessing=disable_preprocessing,
compile_options=kwargs,
)
cs = CompileStats()
_fn = _create_callable(cd, cs)
return _fn
def compile_data(fn) -> CompileData | None:
"""Obtains the compilation data from a JITed function.
The compile data (:class:`thunder.common.CompileData`) contains information about how the JIT has been configured
for compilation (including referencing the function or module that is being compiled).
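
    A small sketch (``jfn`` is assumed to be a function previously returned by :func:`thunder.jit`)::

        cd = thunder.compile_data(jfn)
        if cd is not None:
            print(cd.executors_list)  # executors configured for this jitted function
            print(cd.cache_option)    # the caching mode resolved at jit time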
"""
return getattr(fn, "_lc_cd", None)
def compile_stats(fn) -> CompileStats | None:
"""Obtains the compilation statistics from a JITed function.
The compilation statistics (:class:`thunder.common.CompileStats`) contain information about each compilation run -
collected when a JITed function is called for the first time or with previously unseen state.
    This includes the cache of traces (prologues, computation, possibly backward and epilogue) and
how they have been transformed and information about cache hits and misses and timings.
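
    A small sketch (``jfn`` is assumed to be a jitted function that has already been called)::

        cs = thunder.compile_stats(jfn)
        if cs is not None:
            print(cs.cache_hits, cs.cache_misses)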
"""
return getattr(fn, "_lc_cs", None)
def last_traces(fn) -> list[TraceCtx]:
"""Obtains the list of computation traces that have been produced for the last run of the function. This is a list
    of traces mirroring the progression of transformations applied to the initial trace (at index 0), which was
    acquired by interpreting the user program.
If the function has forward and backward, the forward is returned.
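
    A usage sketch (``jfn`` stands for a jitted function that has already been called)::

        traces = thunder.last_traces(jfn)
        print(traces[-1])  # the final trace, after all transformations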
"""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_traces is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return cs.last_traces
def last_backward_traces(fn) -> list[TraceCtx]:
"""Obtains the list of backward traces that have been produced for the last run of the function and the selected prologue."""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_backward_traces is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return cs.last_backward_traces
def last_prologue_traces(fn) -> list[TraceCtx]:
"""Obtains the list of prologue traces that have been produced for the last run of the function and the selected prologue."""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_prologue_traces is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return cs.last_prologue_traces
def cache_option(fn) -> CACHE_OPTIONS:
"""Returns the cache options set when JITting the function."""
cd = compile_data(fn)
if cd is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
return cd.cache_option
def cache_hits(fn) -> int:
"""Returns the number of cache hits we found when running the function."""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
return cs.cache_hits
def cache_misses(fn) -> int:
"""Returns the number of cache misses we found when running the function."""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
return cs.cache_misses
def list_transforms(fn) -> list:
"""Returns the list of (explicit) transforms applied to the JITed function."""
return fn._lc_transforms
def last_interpreter_log(fn: Callable) -> list[InterpreterLogItem]:
"""Returns the list of instructions and other information the interpreter encountered while tracing through the
user program (on the last cache miss).
"""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_interpreter_log is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return cs.last_interpreter_log
def last_interpreted_instructions(fn: Callable) -> list[dis.Instruction]:
"""Returns the list of instructions the interpreter encountered while tracing through the
user program (on the last cache miss).
"""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_interpreted_instructions is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return list(cs.last_interpreted_instructions)
def print_last_interpreter_log(
fn: Callable,
/,
print_fn: Callable = print,
use_colors: bool = True,
indent: bool = True,
max_depth: int | None = None,
color_internals: bool = False,
print_source_code: bool = True,
) -> None:
"""Prints a log of the last run of the interpreter for the given function.
Args:
fn: The function returned by :func:`thunder.jit` to print the last interpreter run log for. The function must have been called at least once first.
print_fn: The function to use for printing. Defaults to builtin `print`.
        use_colors: Whether to use colors in the output. Defaults to :obj:`True`.
indent: Whether to indent the output with function scope. Defaults to :obj:`True`.
max_depth: The maximum indentation depth of the output. Doesn't print log items nested deeper than the max depth. Defaults to :obj:`None`, which means no limit.
color_internals: Whether to color instructions implicitly interpreted by other instructions. Defaults to :obj:`False`, so that only the instructions in the user's code are highlighted in color.
print_source_code: Whether to print the source line below each LineLogItem in the log. Defaults to :obj:`True`.
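
    A usage sketch (``jfn`` is assumed to be a jitted function that has already been called)::

        thunder.print_last_interpreter_log(jfn, max_depth=2, print_source_code=False)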
"""
log = last_interpreter_log(fn)
print_interpreter_log(
log,
print_fn=print_fn,
use_colors=use_colors,
indent=indent,
max_depth=max_depth,
color_internals=color_internals,
print_source_code=print_source_code,
)
def last_compile_options(fn: Callable, /) -> None:
"""Prints how compiled options were used (or not)"""
cd = compile_data(fn)
cs = compile_stats(fn)
# NOTE Different categories of compile options
# Specified and Queried --- in cs.last_compile_reasons and cd.compile_options
# Queried but not Specified --- in cs.last_compile_reasons but not in cd.compile_options (not printed)
# Specified but not Queried --- in cd.compile_options but not in cs.last_compile_reasons
specified: set = set(cd.compile_options.keys())
queried: set = set(cs.last_compile_reasons.keys())
# Prints used options
print("Used compile options:")
used = specified & queried
if len(used) == 0:
print("\tNo used options")
for option in used:
reasons = set(cs.last_compile_reasons[option])
for reason in reasons:
print(f"\t{option}. {reason}")
# Prints unused options
print("Unused compile options:")
unused: set = specified - queried
if len(unused) == 0:
print("\tNo unused options")
for option in unused:
print(f"\t{option}")
# TODO (mruberry) Update this
def _grad_transform(trace):
grad_fwd_trace = from_trace(trace)
trace_tok = set_tracectx(grad_fwd_trace)
all_residuals = []
# Constructs grad fwd and records info
# TODO: make recursive (or iterative, whatever)
current_inputs = grad_fwd_trace.args
for bsym in trace.bound_symbols:
grad_defined = bsym.sym.grad_defined
grad_ignored = bsym.sym.grad_ignored
grad_fwd, grad_bwd = bsym.sym.grad_fwd, bsym.sym.grad_bwd
if not grad_defined:
raise NotImplementedError
# Constructs the new grad_fwd symbol, which returns the primals and residuals
if grad_fwd is None:
fw_result = bsym.sym(*current_inputs)
residuals = None
all_residuals.append(residuals)
current_inputs = fw_result if isinstance(fw_result, Sequence) else (fw_result,)
continue
fw_result, residuals = grad_fwd(*current_inputs)
all_residuals.append(residuals)
current_inputs = fw_result if isinstance(fw_result, Sequence) else (fw_result,)
# Constructs bwd part of the program
current_grads = (prims.full(o.shape, 1.0, device=o.device, dtype=o.dtype) for o in fw_result)
for bsym, residuals in zip(reversed(trace.bound_symbols), reversed(all_residuals)):
grad_fwd = bsym.sym.grad_fwd
grad_bwd = bsym.sym.grad_bwd
grad_defined = bsym.sym.grad_defined
if not grad_defined:
raise NotImplementedError(f"grad_bwd not defined for {bsym.sym}")
if grad_fwd is None:
continue
current_grads = grad_bwd(*current_grads, *residuals)
current_grads = (current_grads,) if not isinstance(current_grads, Sequence) else current_grads
grad_fwd_trace.output = current_grads
# Resets tracing context
reset_tracectx(trace_tok)
return grad_fwd_trace
# TODO Test nesting of grad and grad and grad and grad
# TODO Test nesting of a regular function + calling grad
def grad(fn):
cfn = compile(fn)
@wraps(cfn)
def _fn(*args, **kwargs):
original_result, original_trace = cfn(*args, **kwargs)
original_trace = last_traces(cfn)
gradir = _grad_transform(original_trace)
return original_result, original_trace
return _fn