# Debugging

## Disable JIT for Debugging

#### ```PYTORCH_JIT```

Setting the environment variable ```PYTORCH_JIT=0``` disables all script and tracing annotations. If there is a hard-to-debug error in one of your TorchScript models, you can use this flag to force everything to run as native Python. Because TorchScript is disabled with this flag, you can use tools like ```pdb``` to debug the model code, as in the following example.

In [1]:
import torch

print(torch.__version__)

1.9.0+cu111


In [2]:
import pdb
    
@torch.jit.script
def scripted_fn(x: torch.Tensor):
    for i in range(12):
        x += x
    return x

def fn(x):
    x = torch.neg(x)
    pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5)))
traced_fn(torch.rand(3, 4))

> <ipython-input-2-6d7926d75035>(12)fn()
     10     x = torch.neg(x)
     11     pdb.set_trace()
---> 12     return scripted_fn(x)
     13 
     14 traced_fn = torch.jit.trace(fn, (torch.rand(4, 5)))

ipdb> x
tensor([[-0.7684, -0.5101, -0.7906, -0.4205, -0.9304],
        [-0.9736, -0.8795, -0.9005, -0.3209, -0.9383],
        [-0.0357, -0.7428, -0.5549, -0.5360, -0.9390],
   

BdbQuit: 

Debugging this script with ```pdb``` works everywhere except inside the ```@torch.jit.script``` function, because that function is compiled rather than run by the Python interpreter. Disabling JIT globally makes the ```@torch.jit.script``` function run as a normal, uncompiled Python function, so the debugger can step into it.
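As a minimal sketch of the global switch, the snippet below sets ```PYTORCH_JIT=0``` from inside Python. It assumes the variable is set before ```torch``` is first imported in the process, since the flag is read at import time; the tensor shapes and the ```inspect.isfunction``` check are illustrative choices, not part of the original example.

```python
import inspect
import os

# Assumption: this must run before the first `import torch` in the process.
os.environ["PYTORCH_JIT"] = "0"

import torch


@torch.jit.script
def scripted_fn(x: torch.Tensor):
    for i in range(12):
        x += x  # doubles x twelve times, i.e. multiplies it by 4096
    return x


def fn(x):
    x = torch.neg(x)
    # With JIT disabled, a pdb.set_trace() here could also step into scripted_fn.
    return scripted_fn(x)


traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
result = traced_fn(torch.ones(3, 4))

# Expected to be True when the flag took effect: scripted_fn is a plain function.
print(inspect.isfunction(scripted_fn))
print(result[0, 0].item())
```

The same effect can usually be achieved without touching the code by setting the variable in the shell when launching the script, for example ```PYTORCH_JIT=0 python my_script.py``` (the script name is hypothetical).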

To disable the TorchScript compiler for a specific function only, use ```@torch.jit.ignore```.
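The following sketch shows one way ```@torch.jit.ignore``` might be used. The module and method names (```MyModule```, ```debug_hook```) are hypothetical; the idea is that the ignored method is left as a call into regular Python, so it could contain a breakpoint or arbitrary Python code while the rest of the module is compiled.

```python
import torch


class MyModule(torch.nn.Module):
    @torch.jit.ignore
    def debug_hook(self, x: torch.Tensor) -> None:
        # Not compiled: runs in the Python interpreter, so a
        # pdb.set_trace() call could be placed here.
        print("mean of input:", x.mean().item())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.debug_hook(x)  # left as a call back into Python
        return x * 2


scripted = torch.jit.script(MyModule())
out = scripted(torch.ones(2, 2))
```

Note that a module containing ignored functions generally cannot be exported with ```torch.jit.save```, so this is best treated as a debugging aid.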

## Inspecting Code

TorchScript provides a code pretty-printer for all ```ScriptModule``` instances. This pretty-printer renders the script method's code as valid Python syntax. For example:

In [11]:
# Note: the signature comment inside foo must use a lowercase "type" marker.
# With a capitalized marker it is ignored, `len` is inferred as a Tensor,
# and compilation fails with "all inputs of range must be ints".
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)


A ```ScriptModule``` with a single ```forward``` method has a ```code``` attribute, which you can use to inspect the ```ScriptModule```'s code. If the ```ScriptModule``` has more than one method, you need to access ```.code``` on the method itself rather than on the module. For example, you can inspect the code of a method named ```foo``` on a ```ScriptModule``` by accessing ```.foo.code```.

The example above produces the following Python-syntax output:

```python
def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0
```
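To illustrate the difference between reading ```.code``` on a module and on a single method, here is a small sketch. The class ```TwoMethods``` and its method ```foo``` are hypothetical names, and ```@torch.jit.export``` is used so that the extra method is compiled alongside ```forward```.

```python
import torch


class TwoMethods(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.export
    def foo(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


m = torch.jit.script(TwoMethods())

forward_code = m.code    # pretty-printed code of `forward`
foo_code = m.foo.code    # pretty-printed code of the method `foo`

print(forward_code)
print(foo_code)
```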

## Interpreting Graphs

TorchScript also has a representation at a lower level than the code pretty-printer, in the form of IR graphs.

TorchScript uses a static single assignment (SSA) intermediate representation (IR) to represent computation. The instructions in this format consist of ATen (the C++ backend of PyTorch) operators and other primitive operators, including control flow operators for loops and conditionals.

For example:

In [12]:
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph(%len.1 : int):
  %21 : int = prim::Constant[value=1]()
  %13 : bool = prim::Constant[value=1]() # <ipython-input-12-8a34e03747f9>:5:4
  %5 : NoneType = prim::Constant()
  %1 : int = prim::Constant[value=3]() # <ipython-input-12-8a34e03747f9>:4:21
  %2 : int = prim::Constant[value=4]() # <ipython-input-12-8a34e03747f9>:4:24
  %16 : int = prim::Constant[value=10]() # <ipython-input-12-8a34e03747f9>:6:15
  %20 : float = prim::Constant[value=1.]() # <ipython-input-12-8a34e03747f9>:7:22
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %5, %5, %5, %5) # <ipython-input-12-8a34e03747f9>:4:9
  %rv : Tensor = prim::Loop(%len.1, %13, %rv.1) # <ipython-input-12-8a34e03747f9>:5:4
    block0(%i.1 : int, %rv.27 : Tensor):
      %17 : bool = aten::lt(%i.1, %16) # <ipython-input-12-8a34e03747f9>:6:11
      %rv.25 : Tensor = prim::If(%17) # <ipython-input-12-8a34e03747f9>:6:8
        block0():
          %rv.5 : Tensor = aten::sub(%rv.27, %20, %21) # <ipython-input-12-8

Take the instruction ```%rv.1 : Tensor = aten::zeros(%4, %5, %5, %5, %5) # <ipython-input-12-8a34e03747f9>:4:9``` as an example:

- ```%rv.1 : Tensor``` means we assign the output to a unique value named ```rv.1```, that the value is of ```Tensor``` type, and that we do not know its concrete shape.
- ```aten::zeros``` is the operator (equivalent to ```torch.zeros```), and the input list ```(%4, %5, %5, %5, %5)``` specifies which values in scope should be passed as inputs. Here ```%4``` is the size list built by ```prim::ListConstruct``` and ```%5``` is the ```None``` constant. The schema for built-in functions like ```aten::zeros``` can be found in <b>Built-in Functions</b>.
- ```# <ipython-input-12-8a34e03747f9>:4:9``` is the location in the original source that generated the instruction. In this case, it is line 4, column 9 of the notebook cell ```<ipython-input-12-8a34e03747f9>```.

Notice that operators can also have associated ```blocks```, namely the ```prim::Loop``` and ```prim::If``` operators. In the graph print-out, these operators are formatted to reflect their equivalent source code forms to facilitate easy debugging.

Graphs can be inspected as shown to confirm that the computation described by a ```ScriptModule``` is correct, in both automated and manual fashion.
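As a rough sketch of the automated kind of inspection, the snippet below walks the top-level nodes of the graph and searches its printed form. The specific checks (looking for ```prim::Loop```, ```prim::If``` and ```aten::zeros```) are illustrative assumptions about what one might want to verify for this particular function.

```python
import torch


@torch.jit.script
def foo(len: int) -> torch.Tensor:
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


# Operator names of the top-level instructions (nested blocks are not included).
top_level_kinds = [node.kind() for node in foo.graph.nodes()]

# The printed graph also contains the nested blocks of prim::Loop and prim::If.
graph_text = str(foo.graph)

has_loop = "prim::Loop" in top_level_kinds
has_branch = "prim::If" in graph_text
print(top_level_kinds)
```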

## Tracer

### Tracing Edge Cases

In some edge cases, the trace of a given Python function is not representative of the underlying code. These include:

- Tracing of control flow that depends on inputs (for example, ```Tensor``` shapes).
- Tracing of in-place operations on tensor views (for example, indexing on the left-hand side of an assignment).
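The first edge case can be sketched as follows. The function ```scale_if_long``` is a hypothetical example: its branch depends on the input's shape, so the trace records only the branch taken for the example input and silently reuses it for inputs of other shapes. Tracing is expected to emit a ```TracerWarning``` here.

```python
import torch


def scale_if_long(x):
    # Shape-dependent control flow: only one branch is recorded by the tracer.
    if x.shape[0] > 2:
        return x * 2
    return x


# Traced with a length-4 input, so the `x * 2` branch is baked into the trace.
traced = torch.jit.trace(scale_if_long, torch.ones(4))

short_input = torch.ones(2)
eager_out = scale_if_long(short_input)   # Python takes the `return x` branch
traced_out = traced(short_input)         # the trace still multiplies by 2

print(eager_out, traced_out)
```

When the branch really must depend on the input, compiling the function with ```torch.jit.script``` instead of tracing it preserves the control flow.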

### Automatic Trace Checking

One way to automatically catch many errors in traces is by using ```check_inputs``` on the ```torch.jit.trace()``` API.
```check_inputs``` takes a tuple of input