# Edit Draft

The Edit module is designed for dynamically editing models during runtime. Under the hood, it uses TorchDynamo to recompile PyTorch code into Torch FX graphs. 

**TorchDynamo** is a Python-level Just-In-Time (JIT) compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython ([PEP 523](https://peps.python.org/pep-0523/)) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode to extract sequences of PyTorch operations into an [FX Graph](https://pytorch.org/docs/stable/fx.html) which is then compiled with a customizable backend. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends to get the best of both worlds — usability and performance. (*[TorchDynamo Deep Dive](https://pytorch.org/docs/stable/torch.compiler_deepdive.html)*)

This notebook will walk through the basic elements of the Edit module.

## Setup (Ignore)

In [1]:
from typing import List

import torch
import torch.nn as nn

from nnsight import LanguageModel
from nnsight.util import WrapperModule
from nnsight.edit import print_gm, Edit



## 1 - Simple Example

We'll start with a (relatively) simple torch model to demonstrate how operations are translated into a Torch FX graph. There are two things to take note of.

First, we declare a simple module `WrappedLayer` to observe what TorchDynamo does when it runs into user-defined Torch modules.

Second, we use a method call (`split`) and an operator expression (`x * 100`) to illustrate the different operations in a Torch FX graph.

In [2]:
class WrappedLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = x * 100
        return x

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.wrapped = WrappedLayer()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.wrapped(x)
        x = self.dropout(x)
        x = x.split(1, dim=-1)
        return x

mod = M()

input_tensor = torch.tensor([[1.0]])
output = mod(input_tensor)
print(output)

(tensor([[-195.7624]], grad_fn=<SplitBackward0>),)


While Dynamo is the backbone of our Edit module, we'll largely interface with it through `torch.compile`, which uses Dynamo under the hood to JIT compile arbitrary Python code. Passing a custom backend is an easy way to view and edit the FX graphs it produces.

In [3]:
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    
    # Example inputs is a list of torch tensors that represent the input to the model
    # print(example_inputs)

    gm.graph.print_tabular()

    return gm.forward

torch._dynamo.reset()

opt_model = torch.compile(mod, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([[1.0]]))

opcode         name     target                       args           kwargs
-------------  -------  ---------------------------  -------------  -----------
placeholder    l_x_     L_x_                         ()             {}
call_module    x        L__self___layer1             (l_x_,)        {}
call_module    x_1      L__self___wrapped_layer1     (x,)           {}
call_function  x_3      <built-in function mul>      (x_1, 100)     {}
call_module    x_4      L__self___dropout            (x_3,)         {}
call_method    split    split                        (x_4, 1)       {'dim': -1}
call_function  getitem  <built-in function getitem>  (split, 0)     {}
output         output   output                       ((getitem,),)  {}


Notice how the functions and components we declared in the module are translated into nodes and their respective operations in an FX graph. Dynamo traces through user-defined modules such as `WrappedLayer`, breaking apart the operations in its forward pass into separate nodes on the FX graph.

There are several elementary operations. 
- `placeholder` represents a function input.
- `get_attr` retrieves a parameter from the module hierarchy. 
- `call_function` applies a free function to some values.
- `call_module` applies a module in the module hierarchy’s `forward()` method to given arguments. 
- `call_method` calls a method on a value.
- `output` contains the output of the traced function in its `args[0]` attribute.
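To see all six opcodes on a graph you can inspect directly, here is a minimal sketch. It uses `torch.fx.symbolic_trace` rather than Dynamo, so node names will differ from the tables above. The `Tiny` module and its `scale` parameter are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class Tiny(nn.Module):
    """Hypothetical module written to hit every FX opcode once."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.scale = nn.Parameter(torch.ones(2))

    def forward(self, x):            # `x` becomes a placeholder node
        x = self.linear(x)           # call_module
        x = x * self.scale           # get_attr (self.scale) + call_function (mul)
        x = torch.relu(x)            # call_function
        return x.sum(dim=-1)         # call_method, then the output node


traced = fx.symbolic_trace(Tiny())

# Collect the opcode of every node, in graph order
ops = [node.op for node in traced.graph.nodes]

for node in traced.graph.nodes:
    print(node.op, node.name, node.target, node.args)
```

Each node also carries `target`, `args`, and `kwargs`, which are the same columns shown by `print_tabular()`.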

Now, let's see what happens if we load this module into NNsight and trace it. Note how we call `torch._dynamo.reset()` to clear existing backends and optimizations on this module. 

In [4]:
from nnsight import NNsight

nn_model = NNsight(mod)

torch._dynamo.reset()

opt_model = torch.compile(nn_model._model, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([[1.0]]))

opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    x       L_stack0_                ()         {}
call_function  x_1     <built-in function mul>  (x, 100)   {}
output         output  output                   ((x_1,),)  {}
opcode         name     target                       args           kwargs
-------------  -------  ---------------------------  -------------  -----------
placeholder    x        L_stack0_                    ()             {}
call_method    split    split                        (x, 1)         {'dim': -1}
call_function  getitem  <built-in function getitem>  (split, 0)     {}
output         output   output                       ((getitem,),)  {}


After loading our model with NNsight, we find that Dynamo has produced two separate graphs. When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph. (*From [TorchDynamo Deep Dive](https://pytorch.org/docs/stable/torch.compiler_deepdive.html)*)
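As a standalone illustration of a graph break, separate from the NNsight case above, the sketch below compiles a function with a data-dependent branch and counts how many graphs the backend receives. The names `branchy` and `counting_backend` are made up for this example, and the exact number of graphs may vary between PyTorch versions.

```python
import torch
import torch._dynamo

graphs = []


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo calls the backend once per captured graph
    graphs.append(gm)
    return gm.forward


def branchy(x):
    y = x * 2
    if y.sum() > 0:      # the branch taken depends on tensor data
        return y + 1
    return y - 1


torch._dynamo.reset()
compiled = torch.compile(branchy, backend=counting_backend)
out = compiled(torch.ones(3))

print(f"captured {len(graphs)} graph(s)")
```

Dynamo cannot know which branch to trace without running the comparison, so it captures the code up to the `if`, hands control back to Python, and captures the taken branch as a separate graph.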

We can see where TorchDynamo breaks the graph by using `torch._dynamo.explain`:

In [5]:
torch._dynamo.reset()
explain_output = torch._dynamo.explain(nn_model._model)(torch.tensor([[1.0]]))
print(explain_output)

Graph Count: 2
Graph Break Count: 1
Op Count: 2
Break Reasons:
Ops per Graph:
  Ops 1:
    <built-in function mul>
  Ops 2:
    <built-in function getitem>
Out Guards:
  Guard 1:
    Name: "L['self'].layer1"
    Source: local_nn_module
    Create Function: TYPE_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 2:
    Name: ''
    Source: global
    Create Function: BACKEND_MATCH
    Guard Types: ['BACKEND_MATCH']
    Code List: ['(___skip_backend_check() or ___current_backend() == ___lookup_backend(140109342776896))']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 4:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == Fals

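The report shows a single graph break, which leaves two graphs; the first one starts at the multiply `x = x * 100`. We could already see this from the two graphs printed above. Note that `Break Reasons` is empty here, so the report does not tell us why the graph broke.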

We can force TorchDynamo to raise an error upon the first graph break encountered by using `fullgraph=True`. The stack trace will provide more details on exactly what is breaking our graph.

In [6]:
import traceback as tb

opt_bar = torch.compile(nn_model._model, fullgraph=True)
try:
    opt_bar(torch.tensor([[1.0]]))
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_2903989/1892653539.py", line 5, in <module>
    opt_bar(torch.tensor([[1.0]]))
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module

Expanding the error trace reveals this line toward the end. 

```
torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(_hook) [UnspecializedNNModuleVariable(Linear), TupleVariable(), ConstDictVariable(), TensorVariable()] {}
```

This message indicates that Dynamo ran into something it does not support, a call to a user-defined `_hook` object (a forward hook) on a `Linear` module, and broke the graph there.

We can remove NNsight's hooks by calling `.clear_hooks(propagate=True)` on the underlying `._envoy`. `propagate=True` tells NNsight to remove the hooks of the envoy's sub-envoys too.
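For readers less familiar with hooks, here is a minimal plain-PyTorch sketch of what a forward hook is and how removing one works. It does not use NNsight's `clear_hooks`; it relies only on `register_forward_hook`, and it peeks at the private `_forward_hooks` dictionary purely for illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(1, 1)
seen = []


def record_output(module, inputs, output):
    # Called after every forward pass of `layer` while the hook is registered
    seen.append(output.detach())


handle = layer.register_forward_hook(record_output)
layer(torch.tensor([[1.0]]))                      # hook fires once
hooks_while_registered = len(layer._forward_hooks)

handle.remove()                                   # unregister the hook
layer(torch.tensor([[1.0]]))                      # hook no longer fires
hooks_after_removal = len(layer._forward_hooks)

print(len(seen), hooks_while_registered, hooks_after_removal)
```

Once no hooks remain on a module, its forward pass is an ordinary module call again.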

In [7]:
nn_model._envoy.clear_hooks(propagate=True)

torch._dynamo.reset()

opt_model = torch.compile(nn_model._model, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([[1.0]]))

opcode         name                             target                                args                                                                 kwargs
-------------  -------------------------------  ------------------------------------  -------------------------------------------------------------------  -----------
placeholder    l_x_                             L_x_                                  ()                                                                   {}
get_attr       l__self___layer1_weight          L__self___layer1_weight               ()                                                                   {}
get_attr       l__self___layer1_bias            L__self___layer1_bias                 ()                                                                   {}
call_function  x                                <built-in function linear>            (l_x_, l__self___layer1_weight, l__self___layer1_bias)               {}
get_attr       l__self___wrapped_layer1

## 2 - Intervening on the FX Graph

TorchDynamo is a powerful tool for compiling torch modules to improve performance and efficiency at scale. For a walkthrough of what `torch.compile` does to a program, see the [depyf walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html).


What if we used `torch.compile` to attach arbitrary modules at any point in an existing module's computation? There are a few obvious benefits:

1. We can edit models to access intermediate values that aren't normally available.
2. We can add modules such as dictionaries or LoRA weights and access the hidden states of those modules, on a forward or backward pass, with hooks.
3. We can host just one module on NDIF and use `torch.compile` to recompile existing modules. `torch.compile` simply returns an optimized module wrapper over the existing module, so we don't have to host multiple models.

Let's declare a simple model to see how we can wrap one of its attributes below. 

In [8]:
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.layer1(x)

        # We want to wrap `value`
        # So we can intervene on it with NNsight hooks
        value = x[:,0]
        x = x * value

        x = self.dropout(x)
        x = x.split(1, dim=-1)
        return x

mod = M()

input_tensor = torch.tensor([[1.0]])
output = mod(input_tensor)
print(output)

(tensor([[1.8184]], grad_fn=<SplitBackward0>),)


Suppose we'd like to access `value`. We wouldn't normally be able to do this with hooks because it's a local variable inside `forward`, not the output of a submodule.

We create the `WrapperModule` class, which just passes its input through. By setting it as an attribute of the parent module and routing `value` through it, we can access the input and output of this wrapper with hooks.

In [9]:
class WrapperModule(torch.nn.Module):
    def forward(self, *args, **kwargs):
        if len(args) == 1:
            args = args[0]

        return args
    
wrapper_module = WrapperModule()
wrapper_name = 'value_wrapper'

setattr(mod, wrapper_name, wrapper_module)
print(mod)

M(
  (layer1): Linear(in_features=1, out_features=1, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (value_wrapper): WrapperModule()
)
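Attaching the wrapper as an attribute is not enough on its own: a hook on it only fires if something in the forward pass actually calls it. The sketch below shows this with a made-up `Identity` module standing in for `WrapperModule` and a made-up `Toy` model that mirrors `M` without dropout.

```python
import torch
import torch.nn as nn


class Identity(nn.Module):
    """Hypothetical stand-in for WrapperModule: returns its input unchanged."""

    def forward(self, x):
        return x


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)

    def forward(self, x):
        x = self.layer1(x)
        value = x[:, 0]          # never passed through any submodule
        return x * value


toy = Toy()
toy.value_wrapper = Identity()   # registered as a submodule, but unused in forward

calls = []
handle = toy.value_wrapper.register_forward_hook(
    lambda module, inputs, output: calls.append(output)
)

toy(torch.tensor([[1.0]]))
calls_after_forward = len(calls)             # the wrapper was never called

toy.value_wrapper(torch.tensor([2.0]))       # calling it directly does fire the hook
calls_after_direct_call = len(calls)
handle.remove()
```

This is why the graph needs to be edited so that `value` flows through the wrapper.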


To figure out at which node to insert the wrapper, we can compile the model and print the graph module. This prints the Python code that FX generated for the graph Dynamo captured.

In [10]:
def custom_backend(gm: torch.fx.GraphModule, _: List[torch.Tensor]):
    print(gm)
    return gm.forward

torch._dynamo.reset()
opt_model = torch.compile(mod, backend=custom_backend, dynamic=True, fullgraph=True)
gm = opt_model(torch.tensor([[1.0]]))

GraphModule()



def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    l__self___layer1_weight = self.L__self___layer1_weight
    l__self___layer1_bias = self.L__self___layer1_bias
    x = torch._C._nn.linear(l_x_, l__self___layer1_weight, l__self___layer1_bias);  l_x_ = l__self___layer1_weight = l__self___layer1_bias = None
    value = x[(slice(None, None, None), 0)]
    x_1 = x * value;  x = value = None
    x_2 = torch.nn.functional.dropout(x_1, 0.1, True, False);  x_1 = None
    split = x_2.split(1, dim = -1);  x_2 = None
    getitem_1 = split[0];  split = None
    return (getitem_1,)
    
# To see more debug info, please use `graph_module.print_readable()`


```
value = x[(slice(None, None, None), 0)]
x_1 = x * value;  x = value = None
```

From these lines, we see that value is:
1. A node with args `x` and `slice(...)`, representing a `call_function` operation (indexing is traced as `getitem`).
2. An argument to the node `x_1`. 
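Both facts can also be read off the graph programmatically through a node's `args` and `users`. The sketch below is a rough illustration using `torch.fx.symbolic_trace` on a made-up `Toy` model; unlike Dynamo, symbolic tracing does not keep the variable name `value`, so the node is located by its target instead.

```python
import operator

import torch
import torch.nn as nn
import torch.fx as fx


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)

    def forward(self, x):
        x = self.layer1(x)
        value = x[:, 0]
        return x * value


traced = fx.symbolic_trace(Toy())

# Find the indexing node: a call_function whose target is operator.getitem
value_node = next(
    n for n in traced.graph.nodes
    if n.op == "call_function" and n.target is operator.getitem
)

inputs_of_value = value_node.args           # (node being indexed, index tuple)
consumers = list(value_node.users)          # nodes that take value_node as an argument

print(inputs_of_value)
print(consumers)
```

`args` tells us what to feed a replacement node, and `users` tells us which nodes need to be redirected to it.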

Let's try wrapping `value` as it's passed as an arg into `x_1`.

In [11]:
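def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):

    # Register the wrapper on the graph module so a `call_module` node can target it
    if wrapper_name not in gm._modules:
        gm.add_submodule(wrapper_name, wrapper_module)

    # Iterate over a copy of the node list, since the loop erases a node
    for node in list(gm.graph.nodes):
        if node.name == "value":
            with gm.graph.inserting_before(node):
                # Recreate the `value` computation as a fresh node
                new = gm.graph.create_node(node.op, node.target, args=node.args, kwargs=node.kwargs)
                # Pass the fresh node through the wrapper module
                wrapper_node = gm.graph.call_module(wrapper_name, args=(new,))
                # Point every consumer of `value` at the wrapper output, then drop the old node
                node.replace_all_uses_with(wrapper_node)
                gm.graph.erase_node(node)

    # Regenerate the module's Python code from the edited graph
    gm.recompile()
    return gm.forward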


torch._dynamo.reset()
opt_model = torch.compile(mod, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([[1.0]]))

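A lot happened above, so let's go through it step by step.

1. `gm.add_submodule(wrapper_name, wrapper_module)` registers the wrapper on the graph module, so that a `call_module` node can refer to it by name.
2. The loop finds the node named `value`.
3. Inside `gm.graph.inserting_before(node)`, `create_node` makes a copy of the `value` node, and `call_module` adds a node that passes that copy through the wrapper.
4. `node.replace_all_uses_with(wrapper_node)` redirects every consumer of `value` (here `x_1`) to the wrapper's output, and `erase_node` removes the original, now unused, node.
5. `gm.recompile()` regenerates the module's `forward` from the edited graph.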
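The same wrapping idea can be tried without Dynamo. The sketch below is a simplified variant, not the Edit module's implementation: it traces a made-up `Toy` model with `torch.fx.symbolic_trace`, inserts a hypothetical `Passthrough` module after the indexing node instead of cloning it, and checks with a forward hook that the value now flows through the wrapper.

```python
import operator

import torch
import torch.nn as nn
import torch.fx as fx


class Passthrough(nn.Module):
    """Hypothetical identity module used only as a hook point."""

    def forward(self, x):
        return x


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)

    def forward(self, x):
        x = self.layer1(x)
        value = x[:, 0]
        return x * value


toy = Toy()
inp = torch.tensor([[1.0]])
expected = toy(inp)

gm = fx.symbolic_trace(toy)
wrapper = Passthrough()
gm.add_submodule("value_wrapper", wrapper)

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.getitem:
        with gm.graph.inserting_after(node):
            wrapped = gm.graph.call_module("value_wrapper", args=(node,))
        # Redirect all consumers of `node` to the wrapper output...
        node.replace_all_uses_with(wrapped)
        # ...then restore the wrapper's own input, which the line above also rewrote
        wrapped.args = (node,)

gm.graph.lint()      # sanity-check the edited graph
gm.recompile()

captured = []
wrapper.register_forward_hook(lambda m, i, out: captured.append(out.detach()))
result = gm(inp)
```

The edited module computes the same output as before, and `captured` now holds the intermediate `value` tensor.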

## 3 - Editing in NNsight

Behind the scenes, the Edits passed into an NNsight model are loaded into an NNsight Editor context manager and compiled with TorchDynamo.

In [12]:
model = LanguageModel("openai-community/gpt2", device_map="cuda:0", dispatch=True)

In [13]:
class EditModule(torch.nn.Module):

    def forward(self, *args, **kwargs):
        if len(args) == 1:
            args = args[0]

        value = args * 1000
        
        return value
    
edit = Edit(
    model._envoy.transformer.h[3].attn._module_path, 
    "value", 
    "value_wrapper",
    EditModule()
)

class WrapperModule(torch.nn.Module):
    """Simple torch module which passes its input through. Useful for hooking.
    If there is only one argument, returns it directly rather than as a tuple.
    """

    def forward(self, *args, **kwargs):
        if len(args) == 1:
            args = args[0]

        return args
    
wrapper_edit = Edit(
    model._envoy.transformer.h[3].attn._module_path, 
    "query", 
    "query_wrapper",
    WrapperModule()
)



In [14]:
edits = [wrapper_edit]

model.load_edits(edits)

In [15]:
with model.trace("empty", scan=False, validate=False):
    query = model.transformer._orig_mod.h[3].attn.query_wrapper.output.save()

print(query)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


tensor([[[ 2.9474e-01, -2.9457e-01, -1.7623e-01, -4.0702e-01, -5.0194e-01,
           5.3164e-01, -1.0857e+00, -7.0488e-01, -1.0685e+00,  6.2466e-01,
          -2.2498e-01,  4.4690e-01, -9.4095e-01,  4.2125e-01, -6.0461e-01,
          -6.7794e-01, -1.6197e-01,  1.0716e+00,  1.4526e-01,  4.6085e-01,
           6.4771e-02,  5.5589e-01, -1.0751e-01,  1.2204e-02, -5.1981e-01,
          -5.8530e-01, -1.0644e-01,  6.0577e-01, -2.9360e-01,  8.3008e-01,
           3.0922e-01, -1.0348e-01,  6.6434e-01,  3.5581e-01, -1.0595e+00,
           5.0450e-01,  6.8544e-01, -2.0407e-01,  7.2883e-02,  1.1289e+00,
          -1.7641e-01,  4.3492e-01, -6.4675e-01,  3.5470e-01, -3.1740e-01,
          -8.8479e-01,  7.0754e-01,  1.5963e-01, -2.4453e-02,  1.4798e-01,
          -7.1829e-01, -4.3326e-01, -1.2205e-01,  7.7209e-01, -3.1324e-01,
          -1.5592e+00, -2.2029e-01, -5.4575e-01, -7.1221e-01, -7.0549e-01,
           6.1355e-01,  1.0617e-01,  1.1756e+00, -1.5728e-01, -3.3079e-01,
           5.2387e-01, -7

## 4 - Other Stuff

In [16]:
model = LanguageModel("openai-community/gpt2", device_map="cuda:0", dispatch=True)

with model.trace("empty"):
    pass

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [17]:
c_proj_envoy = model._envoy.transformer.h[3].mlp.c_proj

print_gm(c_proj_envoy)