In [10]:
import torch.nn as nn
import torch

In [11]:
class MyCell(torch.nn.Module):
    """Minimal RNN-style cell computing ``tanh(linear(x) + h)``.

    ``linear`` maps 4 input features to 4 output features, so ``x`` needs 4 features
    in its last dimension and ``h`` must be addable to the projected ``x``.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return the new hidden state for input ``x`` and previous state ``h``."""
        return torch.tanh(self.linear(x) + h)

In [12]:
# Seed torch's RNG before the first random operation in the notebook (nn.Linear's
# random weight initialisation here, then the torch.rand example inputs created next),
# so a Restart & Run All reproduces the same weights, inputs and outputs.
torch.manual_seed(0)
my_cell = MyCell()

In [13]:
# Example inputs: a batch of 3 rows with 4 features each (4 = Linear's in_features).
# Keeping the batch size (3) different from the feature size (4) makes the two
# dimensions distinguishable in the traced graph (e.g. `Float(3, 4)`), and an
# accidental transpose would fail loudly instead of passing silently on a square
# 4x4 input. This shape also matches the outputs recorded below.
x, h = torch.rand(3, 4), torch.rand(3, 4)

In [5]:
# Trace `my_cell` by running it on the example inputs (x, h); the result is a
# TorchScript module holding the operations recorded during that run.
traced_cell = torch.jit.trace(my_cell, (x, h))

In [14]:
# Display the traced module's repr: its module hierarchy (the traced Linear submodule).
traced_cell

TracedModule[MyCell](
  original_name=MyCell
  (linear): TracedModule[Linear](original_name=Linear)
)

In [7]:
# The traced module is called exactly like the eager module it was traced from.
traced_cell(x, h)

tensor([[ 0.4238, -0.0524,  0.5719,  0.4747],
        [-0.0059, -0.3625,  0.2658,  0.7130],
        [ 0.4532,  0.6390,  0.6385,  0.6584]],
       grad_fn=<DifferentiableGraphBackward>)

In [8]:
# Inspect the low-level TorchScript IR graph recorded by the trace.
traced_cell.graph

graph(%self : ClassType<MyCell>,
      %input : Float(3, 4),
      %h : Float(3, 4)):
  %1 : ClassType<Linear> = prim::GetAttr[name="linear"](%self)
  %weight : Tensor = prim::GetAttr[name="weight"](%1)
  %bias : Tensor = prim::GetAttr[name="bias"](%1)
  %6 : Float(4, 4) = aten::t(%weight), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %9 : Float(3, 4) = aten::addmm(%bias, %input, %6, %7, %8), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %10 : int = prim::Constant[value=1](), scope: MyCell # <ipython-input-2-c6e2cd8665ee>:7:0
  %11 : Float(3, 4) = aten::add(%9, %h, %10

In [9]:
# `.code` is a multi-line string: print() renders its newlines, whereas a bare
# expression shows the escaped repr ('...\n  _0 = ...'), which is hard to read.
# This matches the later `print(... .code)` cells.
print(traced_cell.code)

'import __torch__\nimport __torch__.torch.nn.modules.linear\ndef forward(self,\n    input: Tensor,\n    h: Tensor) -> Tensor:\n  _0 = self.linear\n  weight = _0.weight\n  bias = _0.bias\n  _1 = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)\n  return torch.tanh(torch.add(_1, h, alpha=1))\n'

In [15]:
class MyDecisionGate(torch.nn.Module):
    """Toy gate whose output depends on the data: ``x`` if ``x.sum() > 0`` else ``-x``.

    Used to contrast tracing (which records only the branch taken for the example
    input) with scripting (which compiles the ``if``/``else`` itself).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Explicit bool() of the 0-dim comparison result; this is the same truth test
        # the implicit `if tensor:` performs (a NaN sum compares False -> negated).
        if bool(x.sum() > 0):
            return x
        else:
            return -x

In [16]:
class MyCell(torch.nn.Module):
    """RNN-style cell with a pluggable gate: ``tanh(dg(linear(x)) + h)``.

    NOTE: this redefinition shadows the gate-less ``MyCell`` defined earlier in the
    notebook; every later ``MyCell(...)`` call refers to this version.

    Args:
        dg: module applied to the linear projection of ``x`` before adding ``h``
            (in this notebook, an eager or scripted ``MyDecisionGate``).
    """

    def __init__(self, dg):
        super().__init__()
        # Register the gate before the linear layer so the child-module order
        # stays ``dg``, ``linear``.
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return the new hidden state for input ``x`` and previous state ``h``."""
        gated = self.dg(self.linear(x))
        return torch.tanh(gated + h)

In [18]:
# Trace a cell whose gate branches on the data. Tracing records only the operations
# executed for the example inputs, so MyDecisionGate's `if` is frozen to whichever
# branch (x, h) happened to take (the printed code below contains only `torch.neg`).
# The warning in the output is presumably PyTorch warning about this; confirm.
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, example_inputs=(x, h))

  This is separate from the ipykernel package so we can avoid doing imports until


In [21]:
# Show the code recorded by the trace. Only the `torch.neg` path of MyDecisionGate
# appears, with no `if`: the trace captured the branch taken for the example inputs.
print(traced_cell.code)

import __torch__.___torch_mangle_0
import __torch__
import __torch__.torch.nn.modules.linear.___torch_mangle_1
def forward(self,
    input: Tensor,
    h: Tensor) -> Tensor:
  _0 = self.linear
  weight = _0.weight
  bias = _0.bias
  x = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
  _1 = torch.tanh(torch.add(torch.neg(x), h, alpha=1))
  return _1



In [22]:
# Compile the gate with torch.jit.script, which works from the source of `forward`
# rather than a recorded run, so its data-dependent if/else is kept
# (see the `if bool(_4):` in the scripted code printed later).
scripted_gate = torch.jit.script(MyDecisionGate())

In [23]:
# Wrap the already-scripted gate in a cell so its control flow survives when the
# whole cell is compiled in the next step.
my_cell = MyCell(scripted_gate)

In [24]:
# torch.jit.script compiles the module from source instead of running it, so this is a
# *scripted* module; name it accordingly rather than overwriting `traced_cell`, which
# keeps holding the traced module whose `if` was frozen to one branch.
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.code)

import __torch__.___torch_mangle_3
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_4
def forward(self,
    x: Tensor,
    h: Tensor) -> Tensor:
  _0 = self.linear
  _1 = _0.weight
  _2 = _0.bias
  if torch.eq(torch.dim(x), 2):
    _3 = torch.__isnot__(_2, None)
  else:
    _3 = False
  if _3:
    bias = ops.prim.unchecked_unwrap_optional(_2)
    ret = torch.addmm(bias, x, torch.t(_1), beta=1, alpha=1)
  else:
    output = torch.matmul(x, torch.t(_1))
    if torch.__isnot__(_2, None):
      bias0 = ops.prim.unchecked_unwrap_optional(_2)
      output0 = torch.add_(output, bias0, alpha=1)
    else:
      output0 = output
    ret = output0
  _4 = torch.gt(torch.sum(ret, dtype=None), 0)
  if bool(_4):
    _5 = ret
  else:
    _5 = torch.neg(ret)
  return torch.tanh(torch.add(_5, h, alpha=1))

