## Thunder as torch.compile backend (ThunderFX) Tutorial

In this tutorial, we’ll explore how to use Thunder as a backend for `torch.compile` and demonstrate tools for inspecting the compilation process.

#### Introduction

Starting with PyTorch 2.0, the `torch.compile` feature provides a powerful way to optimize and accelerate PyTorch models. At its core, `torch.compile` relies on the following key components:
1. TorchDynamo - A Python-level tracing tool that transforms Python function calls into an intermediate representation (IR).
2. Backends - Systems that further process the IR, optimizing and executing the computational graph for better performance.

While PyTorch provides several built-in backends, such as "inductor" and "cudagraphs", it also supports custom backends that let users define their own optimization strategies. As a deep learning compiler, Thunder can be used on its own to accelerate models (see the [Thunder overview](https://lightning-thunder.readthedocs.io/en/latest/basic/overview.html) and other tutorials for more details), or it can be integrated with `torch.compile` as a backend. This integration is possible because TorchDynamo rewrites the original Python code into new Python code (FX graphs) that represents the same computation, and Thunder can process that code directly.
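
To make the backend contract concrete, here is a minimal sketch of a custom backend. The function name `recording_backend` is hypothetical. `torch.compile` calls a backend with each FX `GraphModule` that TorchDynamo captures, along with example inputs, and expects a callable in return. This backend performs no optimization. It records the graph and runs it unchanged, which is enough to show where a compiler such as Thunder plugs in.

```python
import torch
import torch.fx

# Every FX graph that TorchDynamo captures is appended here by the backend.
captured_graphs = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # A backend receives the captured GraphModule plus example inputs...
    captured_graphs.append(gm)
    # ...and must return a callable with the same calling convention.
    # This sketch simply runs the graph as-is instead of optimizing it.
    return gm.forward

def foo(x, y):
    a = torch.sin(x)
    return a + torch.sinc(a) + torch.cos(y)

opt_foo = torch.compile(foo, backend=recording_backend)
x, y = torch.randn(4, 4), torch.randn(4, 4)
out = opt_foo(x, y)
print(f"captured {len(captured_graphs)} graph(s)")
print(captured_graphs[0].code)  # Python source of the captured FX graph
```

A real backend would replace `return gm.forward` with a call into its own compiler.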

For more information on `torch.compile`, we recommend reading the PyTorch documentation and tutorials:

1. Introduction to torch.compile - [Link](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
2. Docs of torch.compile - [Link](https://pytorch.org/docs/stable/generated/torch.compile.html)

#### Example Usage

By passing a `ThunderCompiler` instance as the `backend` argument, we can use `torch.compile` with Thunder as the backend.

```python
import torch
from thunder.dynamo import ThunderCompiler

def foo(x, y):
    a = torch.sin(x)
    return a + torch.sinc(a) + torch.cos(y)

# Create the ThunderCompiler backend
backend = ThunderCompiler()
# Pass the ThunderCompiler backend to torch.compile by using the backend argument.
opt_foo1 = torch.compile(foo, backend=backend)
# Run the compiled model as you normally would
print(opt_foo1(torch.randn(4, 4, requires_grad=True, device="cuda"), torch.randn(4, 4, requires_grad=True, device="cuda")))
```

```
tensor([[ 2.0454,  0.9373,  1.3031,  1.6171],
        [ 2.0270,  1.5227,  1.4768,  0.2534],
        [-0.4196,  1.7928,  1.9140,  0.5584],
        [ 1.9205, -0.8348, -0.0268,  2.0556]], device='cuda:0',
       grad_fn=<ThunderFunctionBackward>)
```


#### Implementation Details and Debugging

Now let’s dive into the [FX graphs](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) generated by TorchDynamo and explore how Thunder processes them.

##### Exploring FX Graphs Generated by TorchDynamo

TorchDynamo transforms Python functions into FX graphs. When it encounters dynamic behavior or operations it cannot capture, it splits the computation into multiple subgraphs (so-called graph breaks). The unsupported parts fall back to native Python execution, and the supported segments are passed to the backend for optimization.
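
A graph break is easy to observe with a backend that counts the graphs it receives. The sketch below uses `torch._dynamo.graph_break()` to force a split, standing in for any construct TorchDynamo cannot capture. The names `counting_backend` and `bar` are hypothetical. With this setup TorchDynamo should hand two FX graphs to the backend, one on each side of the break.

```python
import torch
import torch._dynamo

graphs = []

def counting_backend(gm, example_inputs):
    # Record each FX graph handed to the backend and run it unchanged.
    graphs.append(gm)
    return gm.forward

def bar(x):
    a = torch.sin(x)
    # Explicitly request a graph break; outside of torch.compile this is a no-op.
    torch._dynamo.graph_break()
    return torch.cos(a)

torch._dynamo.reset()  # clear any previously compiled state
opt_bar = torch.compile(bar, backend=counting_backend)
x = torch.randn(8)
out = opt_bar(x)
print(f"TorchDynamo produced {len(graphs)} FX graphs")
```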

In our example, TorchDynamo can capture every operation in the `foo` function, so it produces a single FX graph.

**NOTE**: For more information about TorchDynamo, refer to the official [Dynamo overview](https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html).

```python
subgraph_infos = backend.subgraph_infos
print(f"TorchDynamo extracts {len(subgraph_infos)} FX graphs")
for graph_id, subgraph_info in enumerate(subgraph_infos):
    print(f"Graph {graph_id}:\n{subgraph_info.original_graph_module}\n")
```

```
TorchDynamo extracts 1 FX graphs
Graph 0:
GraphModule()



def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    l_x_ = L_x_
    l_y_ = L_y_
    a = torch.sin(l_x_);  l_x_ = None
    sinc = torch.sinc(a)
    add = a + sinc;  a = sinc = None
    cos = torch.cos(l_y_);  l_y_ = None
    add_1 = add + cos;  add = cos = None
    return (add_1,)
    
# To see more debug info, please use `graph_module.print_readable()`
```
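
The printed graph is an ordinary `torch.fx.GraphModule`, so its nodes can also be walked programmatically. The sketch below uses `torch.fx.symbolic_trace(foo)` as a stand-in for the Dynamo-captured graph, which is an assumption made so the example runs without Thunder. Node names therefore differ slightly (for example `x` instead of `l_x_`), but the node kinds and targets correspond to the graph above. The same loop should apply to `subgraph_info.original_graph_module.graph.nodes`.

```python
import torch
import torch.fx

def foo(x, y):
    a = torch.sin(x)
    return a + torch.sinc(a) + torch.cos(y)

# symbolic_trace builds a GraphModule comparable to the one TorchDynamo captured.
gm = torch.fx.symbolic_trace(foo)

for node in gm.graph.nodes:
    # node.op is one of: placeholder, get_attr, call_function,
    # call_method, call_module, output
    print(f"{node.op:<14} {node.name:<6} target={node.target}")

# Collect the operators actually called in the graph.
call_targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
```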


##### How `ThunderCompiler` Handles FX Graphs

`ThunderCompiler` serves as the backend for `torch.compile`, processing the FX graph generated by TorchDynamo. If the graph contains regions that Thunder does not support, `ThunderCompiler` splits it into smaller subgraphs. To do this, it uses the [split module pass](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/split_module.py) provided by `torch.fx`, which lets a callback decide how the FX graph is partitioned. `ThunderCompiler` implements its own [callback function](https://github.com/Lightning-AI/lightning-thunder/blob/75ba590708178bfe61b7ec2ed2d579d9edb7daa9/thunder/dynamo/splitter.py#L101-L135) to:
1. Group supported operations into subgraphs that are compiled and executed by Thunder.
2. Send unsupported subgraphs to an alternative execution path: PyTorch’s Inductor.
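
The callback-driven splitting can be sketched with `torch.fx.passes.split_module.split_module` directly. This is a simplified illustration, not Thunder's actual callback. It traces `foo` with `torch.fx.symbolic_trace` instead of TorchDynamo and treats only `torch.sinc` as unsupported. The `UNSUPPORTED` set and the `assign_partitions` helper are hypothetical. The helper opens a new partition whenever the support status changes from one node to the next. For `foo` this should yield three submodules: `submod_0` (`sin`), `submod_1` (`sinc`) and `submod_2` (`cos` and the additions).

```python
import torch
import torch.fx
from torch.fx.passes.split_module import split_module

def foo(x, y):
    a = torch.sin(x)
    return a + torch.sinc(a) + torch.cos(y)

# Hypothetical support table: pretend the backend cannot handle torch.sinc.
UNSUPPORTED = {torch.sinc}

def assign_partitions(gm, unsupported):
    """Map each compute node to a partition id, starting a new partition
    whenever a node's supported/unsupported status differs from the previous one."""
    assignment, partition, prev_supported = {}, 0, None
    for node in gm.graph.nodes:
        if node.op in ("placeholder", "get_attr", "output"):
            continue  # split_module handles these itself
        supported = node.target not in unsupported
        if prev_supported is not None and supported != prev_supported:
            partition += 1
        prev_supported = supported
        assignment[node] = partition
    return assignment

gm = torch.fx.symbolic_trace(foo)
assignment = assign_partitions(gm, UNSUPPORTED)
# split_module asks the callback for the partition id of each compute node.
split_gm = split_module(gm, gm, lambda node: assignment[node])

print(split_gm.code)
for name, submod in split_gm.named_children():
    print(name, [n.target for n in submod.graph.nodes if n.op == "call_function"])

x, y = torch.randn(4, 4), torch.randn(4, 4)
```

Precomputing the assignment keeps the callback stateless. This avoids relying on how many times, or in what order, `split_module` invokes it.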

Common causes of graph splitting include:
1. Unsupported operators: the graph contains operators that Thunder does not support.
2. Compilation errors: an exception is raised while Thunder attempts to compile the operators.
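
The two causes can be modeled as a small decision procedure. The following is a pure-Python sketch under stated assumptions. Operators are represented by name strings, `fake_compile` is a hypothetical compiler stub, and `ReasonType`/`Reason` are hypothetical types that merely echo the shape of the split reasons printed later in this tutorial. They are not Thunder's own classes.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Optional


class ReasonType(Enum):
    # Hypothetical categories mirroring the two causes listed above.
    UNSUPPORTED_OP = auto()
    COMPILATION_ERROR = auto()


@dataclass
class Reason:
    reason_type: ReasonType
    info: str
    exception: Optional[Exception] = None


def check_op(name: str, supported: set, try_compile: Callable[[str], None]) -> Optional[Reason]:
    """Return None if `name` can stay in a backend subgraph, else the reason to split."""
    if name not in supported:
        return Reason(ReasonType.UNSUPPORTED_OP, f"{name} is not supported")
    try:
        try_compile(name)
    except Exception as exc:  # any failure while compiling forces a split
        return Reason(ReasonType.COMPILATION_ERROR, f"compiling {name} failed", exc)
    return None


def fake_compile(name: str) -> None:
    # Hypothetical compiler stub that fails on one specific operator.
    if name == "cos":
        raise RuntimeError("simulated compilation failure")


supported_ops = {"sin", "cos", "add"}
reasons = {op: check_op(op, supported_ops, fake_compile) for op in ["sin", "sinc", "add", "cos"]}
for op, reason in reasons.items():
    print(op, "->", reason)
```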

You can inspect how the FX graph was split, and why, through the `ThunderCompiler.subgraph_infos` attribute.

Note that `ThunderCompiler` accepts `thunder.jit` options as keyword arguments, which customize the compilation of the subgraphs executed by Thunder. Similarly, options for the subgraphs executed by Inductor can be passed through the `torch_inductor_options` argument.
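
Because the backend passed to `torch.compile` is an object rather than a plain function, it can both carry configuration and expose results after compilation, as `ThunderCompiler` does with `subgraph_infos`. The sketch below shows this general pattern with a hypothetical `OptionRecordingBackend` class. Its parameter and option names (`inductor_options`, `example_option`, `example_flag`) are made up for illustration. They are not Thunder's or Inductor's real parameters.

```python
import torch
import torch.fx


class OptionRecordingBackend:
    """Hypothetical class-based backend that stores its options and the graphs it sees."""

    def __init__(self, inductor_options=None, **compile_options):
        # Keyword options for the main compiler (analogous in spirit to thunder.jit kwargs)...
        self.compile_options = compile_options
        # ...and a separate dict for the fallback path.
        self.inductor_options = inductor_options or {}
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        # A real backend would compile `gm` here using self.compile_options.
        return gm.forward


backend = OptionRecordingBackend(inductor_options={"example_option": 1}, example_flag=True)
opt_fn = torch.compile(lambda t: torch.relu(t) * 2, backend=backend)
x = torch.randn(3)
out = opt_fn(x)
print(backend.compile_options, backend.inductor_options, len(backend.graphs))
```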

In this example, the `sinc` operator is not yet supported by Thunder, so the original FX graph is split into three parts. The first and third parts are executed by Thunder; the second part contains the unsupported `sinc` operator and is executed by Inductor.

```python
subgraph_info = subgraph_infos[0]
num_of_submodules = len(subgraph_info.submodule_to_compiled_functions)
num_of_thunder_modules = len(subgraph_info.thunder_compiled_fns)
print(f"Thunder splitter splits the graph into {num_of_submodules} subgraphs, in which {num_of_thunder_modules} subgraphs are run by Thunder")
print("The structure of the split graph:\n")
print(subgraph_info.split_graph_module)

for subgraph_id, (original_graph, compiled_graph) in enumerate(subgraph_info.submodule_to_compiled_functions.items()):
    print(f"Subgraph {subgraph_id}:\n{original_graph}\n")
```

```
Thunder splitter splits the graph into 3 subgraphs, in which 2 subgraphs are run by Thunder
The structure of the split graph:

GraphModule(
  (thunder_0): ThunderModule(
    (_model): GraphModule()
  )
  (inductor_1): OptimizedModule(
    (_orig_mod): GraphModule()
  )
  (thunder_2): ThunderModule(
    (_model): GraphModule()
  )
)



def forward(self, l_x_ : torch.Tensor, l_y_ : torch.Tensor):
    thunder_0 = self.thunder_0(l_x_);  l_x_ = None
    inductor_1 = self.inductor_1(thunder_0)
    thunder_2 = self.thunder_2(thunder_0, inductor_1, l_y_);  thunder_0 = inductor_1 = l_y_ = None
    return (thunder_2,)
    
# To see more debug info, please use `graph_module.print_readable()`
Subgraph 0:
GraphModule()



def forward(self, l_x_ : torch.Tensor):
    a = torch.sin(l_x_);  l_x_ = None
    return a
    
# To see more debug info, please use `graph_module.print_readable()`

Subgraph 1:
GraphModule()



def forward(self, a):
    sinc = torch.sinc(a);  a = None
    return sinc
    
# To see
```

To inspect why the original graph is split, we can print the split reasons:

```python
for reason_id, split_reason in enumerate(subgraph_info.split_reasons):
    print(f"Split reason {reason_id}:\n{split_reason}\n")
```

```
Split reason 0:
SplitReason(reason_type=<SplitReasonType.MISSING_OP_SUPPORT: 2>, info='node with name: sinc and target: <built-in method sinc of type object at 0x78d5c7f853a0> only has an automatic torch fallback in thunder.', exception=None)
```



To inspect the Thunder trace of each subgraph, we can call `thunder.last_traces` and `thunder.last_backward_traces` on each Thunder-compiled module, as usual:

```python
import thunder

for subgraph_id, thunder_module in enumerate(subgraph_info.thunder_compiled_fns):
    print(f"Subgraph {subgraph_id}:")
    print(f"Forward trace:\n{thunder.last_traces(thunder_module)[-1]}\n")
    print(f"Backward trace:\n{thunder.last_backward_traces(thunder_module)[-1]}\n")
```