# Debugging

## Disable JIT for Debugging

#### ```PYTORCH_JIT```

Setting the environment variable ```PYTORCH_JIT=0``` disables all script and tracing annotations. If there is a hard-to-debug error in one of your TorchScript models, you can use this flag to force everything to run as native Python. Because TorchScript is disabled with this flag, tools like ```pdb``` can be used to debug the model code, as in the following example.

In [1]:
import torch

print(torch.__version__)

1.9.0+cu111


In [2]:
import pdb
    
@torch.jit.script
def scripted_fn(x: torch.Tensor):
    for i in range(12):
        x += x
    return x

def fn(x):
    x = torch.neg(x)
    pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5)))
traced_fn(torch.rand(3, 4))

> <ipython-input-2-6d7926d75035>(12)fn()
     10     x = torch.neg(x)
     11     pdb.set_trace()
---> 12     return scripted_fn(x)
     13 
     14 traced_fn = torch.jit.trace(fn, (torch.rand(4, 5)))

ipdb> x
tensor([[-0.7684, -0.5101, -0.7906, -0.4205, -0.9304],
        [-0.9736, -0.8795, -0.9005, -0.3209, -0.9383],
        [-0.0357, -0.7428, -0.5549, -0.5360, -0.9390],
   

BdbQuit: 

Debugging this script with ```pdb``` works except when we invoke the ```@torch.jit.script``` function, because that function is compiled and ```pdb``` cannot step into it. We can globally disable JIT with ```PYTORCH_JIT=0``` so that the ```@torch.jit.script``` function is called as a normal Python function and is not compiled.
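
Because the flag is an environment variable, it should be set before the Python process imports ```torch``` (this sketch assumes the flag is read at import time, which is worth confirming for your version). One way to do that from Python is to launch the script in a child process, as sketched below. ```jit_disabled_env``` and ```run_without_jit``` are hypothetical helper names, not part of PyTorch. From a shell, the equivalent is to prefix the command with ```PYTORCH_JIT=0```.

```python
import os
import subprocess
import sys


def jit_disabled_env() -> dict:
    """Return a copy of the current environment with PYTORCH_JIT set to 0."""
    env = dict(os.environ)
    env["PYTORCH_JIT"] = "0"
    return env


def run_without_jit(script_path: str) -> int:
    """Run a Python script in a child process with TorchScript disabled.

    Returns the exit code of the child process.
    """
    completed = subprocess.run([sys.executable, script_path], env=jit_disabled_env())
    return completed.returncode
```

Calling ```run_without_jit("my_model_script.py")``` (a placeholder file name) would then run that script with every ```@torch.jit.script``` and ```torch.jit.trace``` call left as plain Python.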

To disable the TorchScript compiler for a specific function only, use ```@torch.jit.ignore```.
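
As a rough sketch of how ```@torch.jit.ignore``` might be used, the module below keeps one method as ordinary Python while the rest is compiled. The names ```DebuggableModule``` and ```inspect_activations``` are made up for illustration.

```python
import torch


class DebuggableModule(torch.nn.Module):
    @torch.jit.ignore
    def inspect_activations(self, x: torch.Tensor) -> torch.Tensor:
        # The compiler leaves this method as a plain Python call, so
        # print() or pdb.set_trace() can be used here.
        print("activation mean:", x.mean().item())
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.neg(x)
        x = self.inspect_activations(x)  # dispatched to the Python interpreter
        return x + x


scripted = torch.jit.script(DebuggableModule())
out = scripted(torch.ones(2, 3))
```

Note that, according to the PyTorch documentation, a model that contains ignored functions cannot be exported, so this is best treated as a temporary debugging aid.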

## Inspecting Code

TorchScript provides a code pretty-printer for all ```ScriptModule``` instances. This pretty-printer gives an interpretation of the script method's code as valid Python syntax. For example:

In [11]:
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)


A ```ScriptModule``` with a single ```forward``` method will have a ```code``` attribute, which you can use to inspect the ```ScriptModule```'s code. If the ```ScriptModule``` has more than one method, you will need to access ```.code``` on the method itself and not the module. We can inspect the code of a method named ```foo``` on a ```ScriptModule``` by accessing ```.foo.code```.

For the ```foo``` function above, the pretty-printed code looks like this:

```python
def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0
```
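
For a module with more than one compiled method, the difference between the module-level and method-level ```code``` attributes can be sketched as follows. The ```Counter``` module and its ```double``` method are invented for this illustration; ```@torch.jit.export``` is used so that ```double``` is compiled even though ```forward``` never calls it.

```python
import torch


class Counter(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.export
    def double(self, x: torch.Tensor) -> torch.Tensor:
        # Compiled as an extra entry point alongside forward().
        return x * 2


scripted = torch.jit.script(Counter())

forward_source = scripted.code         # pretty-printed code of forward()
double_source = scripted.double.code   # pretty-printed code of double()

print(forward_source)
print(double_source)
```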

## Interpreting Graphs

TorchScript also has a representation at a lower level than the code pretty-printer, in the form of IR graphs.

TorchScript uses a static single assignment (SSA) intermediate representation (IR) to represent computation. The instructions in this format consist of ```ATen``` (the C++ backend of PyTorch) operators and other primitive operators, including control flow operators for loops and conditionals.

For example:

In [12]:
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph(%len.1 : int):
  %21 : int = prim::Constant[value=1]()
  %13 : bool = prim::Constant[value=1]() # <ipython-input-12-8a34e03747f9>:5:4
  %5 : NoneType = prim::Constant()
  %1 : int = prim::Constant[value=3]() # <ipython-input-12-8a34e03747f9>:4:21
  %2 : int = prim::Constant[value=4]() # <ipython-input-12-8a34e03747f9>:4:24
  %16 : int = prim::Constant[value=10]() # <ipython-input-12-8a34e03747f9>:6:15
  %20 : float = prim::Constant[value=1.]() # <ipython-input-12-8a34e03747f9>:7:22
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %5, %5, %5, %5) # <ipython-input-12-8a34e03747f9>:4:9
  %rv : Tensor = prim::Loop(%len.1, %13, %rv.1) # <ipython-input-12-8a34e03747f9>:5:4
    block0(%i.1 : int, %rv.27 : Tensor):
      %17 : bool = aten::lt(%i.1, %16) # <ipython-input-12-8a34e03747f9>:6:11
      %rv.25 : Tensor = prim::If(%17) # <ipython-input-12-8a34e03747f9>:6:8
        block0():
          %rv.5 : Tensor = aten::sub(%rv.27, %20, %21) # <ipython-input-12-8

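Take the instruction ```%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10``` as an example. It has the same form as the ```aten::zeros``` line in the graph above, with different value numbers and a different source location.

- ```%rv.1 : Tensor``` means we assign the output to a unique value named ```rv.1```, that the value is of ```Tensor``` type, and that we do not know its concrete shape.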
- ```aten::zeros``` is the operator (equivalent to ```torch.zeros```) and the input list ```(%4, %6, %6, %10, %12)``` specifies which values in scope should be passed as inputs. The schema for built-in functions like ```aten::zeros``` can be found in <b>Built-in Functions</b>.
- ```# test.py:9:10``` is the location in the original source file that generated the instruction. In this case, it is a file named <i>test.py</i>.

Notice that operators can also have associated ```blocks```, namely the ```prim::Loop``` and ```prim::If``` operators. In the graph print-out, these operators are formatted to reflect their equivalent source code forms to facilitate easy debugging.

Graphs can be inspected as shown to confirm that the computation described by a ```ScriptModule``` is correct, in both automated and manual fashion.
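
One possible form of automated inspection is to walk the graph and collect the operator kinds it contains. The sketch below does this recursively so that nodes inside the nested blocks of ```prim::Loop``` and ```prim::If``` are included. ```collect_kinds``` is a hypothetical helper written for this example, and the function argument is named ```n``` here instead of ```len```.

```python
import torch


@torch.jit.script
def foo(n: int) -> torch.Tensor:
    rv = torch.zeros(3, 4)
    for i in range(n):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


def collect_kinds(block, kinds):
    # Visit every node in this block, then descend into its nested blocks
    # (for example, the bodies of prim::Loop and prim::If).
    for node in block.nodes():
        kinds.append(node.kind())
        for inner in node.blocks():
            collect_kinds(inner, kinds)
    return kinds


kinds = collect_kinds(foo.graph, [])

# A simple automated check: the control flow survived compilation.
assert "prim::Loop" in kinds
assert "prim::If" in kinds
print(sorted(set(kinds)))
```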

## Tracer

### Tracing Edge Cases

There are some edge cases where the trace of a given Python function will not be representative of the underlying code. These include:

- Tracing of control flow that is dependent on inputs (for example, ```Tensor``` shapes).
- Tracing of in-place operations on tensor views (for example, indexing on the left-hand side of an assignment).
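
To make the first case concrete, here is a minimal sketch of a shape-dependent loop going wrong without raising any error. The function ```sum_rows``` is invented for this illustration. The loop is unrolled at trace time, so the traced function keeps the iteration count of the example input.

```python
import torch


def sum_rows(x: torch.Tensor) -> torch.Tensor:
    total = torch.zeros_like(x[0])
    for i in range(x.size(0)):  # the loop count depends on the input shape
        total = total + x[i]
    return total


# Trace with a 3-row input: the Python loop is recorded as exactly 3 additions.
traced = torch.jit.trace(sum_rows, (torch.ones(3, 2),))

five_rows = torch.ones(5, 2)
eager_out = sum_rows(five_rows)   # adds all 5 rows
traced_out = traced(five_rows)    # still adds only the first 3 rows

print(eager_out, traced_out)
```

The two results differ even though no exception is raised.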

### Automatic Trace Checking

One way to automatically catch many errors in traces is by using ```check_inputs``` on the ```torch.jit.trace()``` API.
```check_inputs``` takes a list of tuples of inputs that will be used to re-trace the computation and verify the results.

In [13]:
def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result *= x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

With rtol=1e-05 and atol=1e-05, found 30 element(s) (out of 30) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.30838286876678467 (0.5647702813148499 vs. 0.2563874125480652), which occurred at index (3, 5).
  _check_trace(




As we can see, this gives an error:

```python
TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
	Graph diff:
		  graph(%x : Tensor):
		    %1 : int = prim::Constant[value=0]() # <ipython-input-13-6f70278e62a6>:2:0
		    %2 : int = prim::Constant[value=0]() # <ipython-input-13-6f70278e62a6>:2:0
		    %result.1 : Tensor = aten::select(%x, %1, %2) # <ipython-input-13-6f70278e62a6>:2:0
		    %4 : int = prim::Constant[value=0]() # <ipython-input-13-6f70278e62a6>:4:0
		    %5 : int = prim::Constant[value=0]() # <ipython-input-13-6f70278e62a6>:4:0
		    %6 : Tensor = aten::select(%x, %4, %5) # <ipython-input-13-6f70278e62a6>:4:0
		    %result.3 : Tensor = aten::mul_(%result.1, %6) # <ipython-input-13-6f70278e62a6>:4:0
		    %8 : int = prim::Constant[value=0]() # <ipython-input-13-6f70278e62a6>:4:0
		    %9 : int = prim::Constant[value=1]() # <ipython-input-13-6f70278e62a6>:4:0
		    %10 : Tensor = aten::select(%x, %8, %9) # <ipython-input-13-6f70278e62a6>:4:0
		-   %result : Tensor = aten::mul_(%result.3, %10) # <ipython-input-13-6f70278e62a6>:4:0
		+   %result.5 : Tensor = aten::mul_(%result.3, %10) # <ipython-input-13-6f70278e62a6>:4:0
		?          ++
		    %12 : int = prim::Constant[value=0]() # <ipython-input-13-6f70278e62a6>:4:0
		    %13 : int = prim::Constant[value=2]() # <ipython-input-13-6f70278e62a6>:4:0
		    %14 : Tensor = aten::select(%x, %12, %13) # <ipython-input-13-6f70278e62a6>:4:0
		+   %result : Tensor = aten::mul_(%result.5, %14) # <ipython-input-13-6f70278e62a6>:4:0
		+   %16 : int = prim::Constant[value=0]() # <ipython-input-13-6f70278e62a6>:4:0
		+   %17 : int = prim::Constant[value=3]() # <ipython-input-13-6f70278e62a6>:4:0
		+   %18 : Tensor = aten::select(%x, %16, %17) # <ipython-input-13-6f70278e62a6>:4:0
		-   %15 : Tensor = aten::mul_(%result, %14) # <ipython-input-13-6f70278e62a6>:4:0
		?     ^                                  ^
		+   %19 : Tensor = aten::mul_(%result, %18) # <ipython-input-13-6f70278e62a6>:4:0
		?     ^                                  ^
		-   return (%15)
		?             ^
		+   return (%19)
		?             ^
	First diverging operator:
	Node diff:
		- %result : Tensor = aten::mul_(%result.3, %10) # <ipython-input-13-6f70278e62a6>:4:0
		+ %result.5 : Tensor = aten::mul_(%result.3, %10) # <ipython-input-13-6f70278e62a6>:4:0
		?        ++
```

This tells us that the computation differed between the first trace and the trace made with the ```check_inputs```. Indeed, the loop within the body of ```loop_in_traced_fn``` depends on the shape of the input ```x```, so when we try another ```x``` with a different shape, the trace differs.

In this case, <b>data-dependent</b> control flow can be captured using ```torch.jit.script()``` instead.

In [17]:
def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result *= x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6), ), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)

for input_tup in [inputs] + check_inputs:
    torch.testing.assert_allclose(fn(*input_tup), scripted_fn(*input_tup))

graph(%x.1 : Tensor):
  %9 : bool = prim::Constant[value=1]() # <ipython-input-17-690ccad1804c>:3:4
  %2 : int = prim::Constant[value=0]() # <ipython-input-17-690ccad1804c>:2:15
  %result.1 : Tensor = aten::select(%x.1, %2, %2) # <ipython-input-17-690ccad1804c>:2:13
  %6 : int = aten::size(%x.1, %2) # <ipython-input-17-690ccad1804c>:3:19
  %result : Tensor = prim::Loop(%6, %9, %result.1) # <ipython-input-17-690ccad1804c>:3:4
    block0(%i.1 : int, %result.11 : Tensor):
      %17 : Tensor = aten::select(%x.1, %2, %i.1) # <ipython-input-17-690ccad1804c>:4:18
      %result.5 : Tensor = aten::mul_(%result.11, %17) # <ipython-input-17-690ccad1804c>:4:8
      -> (%9, %result.5)
  return (%result)



### Tracer Warnings

The tracer produces warnings for several problematic patterns in traced computation. For example, take a trace of a function that contains an in-place assignment on a slice (a view) of a ```Tensor```:

In [18]:
def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4)))
print(traced.graph)

graph(%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=1]() # <ipython-input-18-7024b6eddacc>:2:0
  %5 : int = aten::size(%x, %4) # <ipython-input-18-7024b6eddacc>:2:0
  %6 : Long(device=cpu) = prim::NumToTensor(%5)
  %7 : int = aten::Int(%6)
  %8 : int[] = prim::ListConstruct(%7)
  %9 : int = prim::Constant[value=6]() # <ipython-input-18-7024b6eddacc>:2:0
  %10 : NoneType = prim::Constant()
  %11 : Device = prim::Constant[value="cpu"]() # <ipython-input-18-7024b6eddacc>:2:0
  %12 : bool = prim::Constant[value=0]() # <ipython-input-18-7024b6eddacc>:2:0
  %13 : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::rand(%8, %9, %10, %11, %12) # <ipython-input-18-7024b6eddacc>:2:0
  %14 : int = prim::Constant[value=0]() # <ipython-input-18-7024b6eddacc>:2:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-18-7024b6eddacc>:2:0
  %16 : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::select(%x, %14, %15) # <ipython-inpu

	%13 : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::rand(%8, %9, %10, %11, %12) # <ipython-input-18-7024b6eddacc>:2:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _check_trace(
With rtol=1e-05 and atol=1e-05, found 4 element(s) (out of 12) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.9652758240699768 (0.9818273186683655 vs. 0.016551494598388672), which occurred at index (0, 1).
  _check_trace(


We can fix this by modifying the code to not use the in-place update, but rather build up the result tensor out-of-place with ```torch.cat```:

In [19]:
def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

graph(%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=1]() # <ipython-input-19-a0ec8869dac9>:2:0
  %5 : int = aten::size(%x, %4) # <ipython-input-19-a0ec8869dac9>:2:0
  %6 : Long(device=cpu) = prim::NumToTensor(%5)
  %7 : int = aten::Int(%6)
  %8 : int = prim::Constant[value=1]() # <ipython-input-19-a0ec8869dac9>:2:0
  %9 : int[] = prim::ListConstruct(%8, %7)
  %10 : int = prim::Constant[value=6]() # <ipython-input-19-a0ec8869dac9>:2:0
  %11 : NoneType = prim::Constant()
  %12 : Device = prim::Constant[value="cpu"]() # <ipython-input-19-a0ec8869dac9>:2:0
  %13 : bool = prim::Constant[value=0]() # <ipython-input-19-a0ec8869dac9>:2:0
  %14 : Float(1, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::rand(%9, %10, %11, %12, %13) # <ipython-input-19-a0ec8869dac9>:2:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-19-a0ec8869dac9>:2:0
  %16 : int = prim::Constant[value=1]() # <ipython-input-19-a0ec8869dac9>:2:0
  %17 : int = 

	%14 : Float(1, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::rand(%9, %10, %11, %12, %13) # <ipython-input-19-a0ec8869dac9>:2:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _check_trace(
With rtol=1e-05 and atol=1e-05, found 4 element(s) (out of 8) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.6998563408851624 (0.990868330001831 vs. 0.2910119891166687), which occurred at index (0, 3).
  _check_trace(
