In [1]:
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))


(tensor([[0.9258, 0.8624, 0.2548, 0.7526],
        [0.9196, 0.7145, 0.6951, 0.7513],
        [0.8356, 0.8861, 0.9181, 0.7395]]), tensor([[0.9258, 0.8624, 0.2548, 0.7526],
        [0.9196, 0.7145, 0.6951, 0.7513],
        [0.8356, 0.8861, 0.9181, 0.7395]]))


In [5]:
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv3d(5, 1, 3, padding=1, bias=False)

    def forward(self, x):
        new_x = self.conv(x)
        up_x = torch.nn.functional.interpolate(
            new_x, scale_factor=2, mode="bicubic", align_corners=True)
        return up_x


inp_5 = torch.randn( 5, 5, 5, 5)
inp_10 = torch.randn( 5, 10, 10, 10)
inp_15 = torch.randn( 5, 15, 15, 15)

model = Model()
model.eval()
trace = torch.jit.trace(model, inp_10)
trace.save("trace.pth")

result_model_5 = model(inp_5)
result_model_10 = model(inp_10)
result_model_15 = model(inp_15)
print("Shape  5, {} ||| {}".format(result_model_5.shape, result_model_5.shape))
print("Shape 10, {} ||| {}".format(result_model_10.shape, result_model_10.shape))
print("Shape 15, {} ||| {}".format(result_model_15.shape, result_model_15.shape))
t_model = torch.jit.load("trace.pth")
result_t_model_5 = t_model(inp_5)
result_t_model_10 = t_model(inp_10)
result_t_model_15 = t_model(inp_15)

print("Shape  5, {} ||| {}".format(result_model_5.shape, result_t_model_5.shape))
print("Shape 10, {} ||| {}".format(result_model_10.shape, result_t_model_10.shape))
print("Shape 15, {} ||| {}".format(result_model_15.shape, result_t_model_15.shape))
torch.allclose(result_model_5, result_t_model_5)
torch.allclose(result_model_10, result_t_model_10)
torch.allclose(result_model_15, result_t_model_15)


Shape  5, torch.Size([1, 5, 10, 10]) ||| torch.Size([1, 5, 10, 10])
Shape 10, torch.Size([1, 10, 20, 20]) ||| torch.Size([1, 10, 20, 20])
Shape 15, torch.Size([1, 15, 30, 30]) ||| torch.Size([1, 15, 30, 30])
Shape  5, torch.Size([1, 5, 10, 10]) ||| torch.Size([1, 5, 10, 10])
Shape 10, torch.Size([1, 10, 20, 20]) ||| torch.Size([1, 10, 20, 20])
Shape 15, torch.Size([1, 15, 30, 30]) ||| torch.Size([1, 15, 30, 30])


True

In [15]:
import torch.onnx
import torch.nn.functional as F
import torch

# Define the model
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv3d(5, 1, 3, padding=1, bias=False)

    def forward(self, x):
        new_x = self.conv(x)
        up_x = F.interpolate(
            new_x, scale_factor=2, mode="bicubic", align_corners=True)
        return up_x

# Instantiate the model
model = Model()

# Prepare dummy input data
dummy_input = torch.randn(5, 32, 32, 32)  # Example shape, adjust according to your actual input shape

# Export the model to ONNX
torch.onnx.export(model, dummy_input, "bi_model.onnx", verbose=True)

Exported graph: graph(%input : Float(5, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=0, device=cpu),
      %conv.weight : Float(1, 5, 3, 3, 3, strides=[135, 27, 9, 3, 1], requires_grad=1, device=cpu)):
  %/conv/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/conv/Constant"](), scope: __main__.Model::/torch.nn.modules.conv.Conv3d::conv # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:608:0
  %/conv/Unsqueeze_output_0 : Float(1, 5, 32, 32, 32, strides=[163840, 32768, 1024, 32, 1], requires_grad=0, device=cpu) = onnx::Unsqueeze[onnx_name="/conv/Unsqueeze"](%input, %/conv/Constant_output_0), scope: __main__.Model::/torch.nn.modules.conv.Conv3d::conv # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:608:0
  %/conv/Conv_output_0 : Float(1, 1, 32, 32, 32, strides=[32768, 32768, 1024, 32, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1, 1], group=1, kernel_shape=[3, 3, 3], pads=[1, 1, 1, 1, 1

In [16]:
import onnxruntime
import numpy as np
import onnx

# Load the ONNX model
onnx_model = onnx.load("bi_model.onnx")

# Prepare input data
input_data = np.random.randn(5, 32, 32, 32).astype(np.float32)  # Example input, adjust according to your actual input

# Run the ONNX model
ort_session = onnxruntime.InferenceSession("bi_model.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outs = ort_session.run(None, ort_inputs)

# Convert ONNX output to PyTorch tensor
onnx_output = torch.tensor(ort_outs[0])



pytorch_model = Model()

# Run the PyTorch model
pytorch_output = pytorch_model(torch.tensor(input_data))

# Compare outputs
print(torch.allclose(pytorch_output, onnx_output, atol=1e-4))  # Check if outputs are approximately equal

False


In [19]:
import torch
import torch.onnx
import onnx
import numpy as np

# Define the model as scripted
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv3d(5, 1, 3, padding=1, bias=False)

    def forward(self, x):
        new_x = self.conv(x)
        up_x = torch.nn.functional.interpolate(
            new_x, scale_factor=2, mode="trilinear", align_corners=True)
        return up_x

# Script the model
model_scripted = torch.jit.script(Model())

# Prepare dummy input data
dummy_input = torch.randn(5, 32, 32, 32)  # Example shape, adjust according to your actual input shape

# Export the scripted model to ONNX
torch.onnx.export(model_scripted, dummy_input, "model_scripted.onnx", verbose=True)


RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  interpolate(Tensor input, int? size=None, float[]? scale_factor=None, str mode="nearest", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor:
  Expected a value of type 'Optional[List[float]]' for argument 'scale_factor' but instead found type 'int'.
  
  interpolate(Tensor input, int[]? size=None, float[]? scale_factor=None, str mode="nearest", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor:
  Expected a value of type 'Optional[List[float]]' for argument 'scale_factor' but instead found type 'int'.
  
  interpolate(Tensor input, int? size=None, float? scale_factor=None, str mode="nearest", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor:
  Expected a value of type 'Optional[float]' for argument 'scale_factor' but instead found type 'int'.
  
  interpolate(Tensor input, int[]? size=None, float? scale_factor=None, str mode="nearest", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor:
  Expected a value of type 'Optional[float]' for argument 'scale_factor' but instead found type 'int'.

The original call is:
  File "/tmp/ipykernel_207/4245390919.py", line 14
    def forward(self, x):
        new_x = self.conv(x)
        up_x = torch.nn.functional.interpolate(
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            new_x, scale_factor=2, mode="trilinear", align_corners=True)
        return up_x


In [17]:
pytorch_output

tensor([[[[-0.2227, -0.3328, -0.3188,  ..., -0.3511, -0.2137, -0.0189],
          [-0.2008, -0.4195, -0.4958,  ...,  0.0945, -0.1694, -0.3552],
          [-0.1794, -0.3575, -0.4425,  ...,  0.4274, -0.1076, -0.5864],
          ...,
          [ 0.1714,  0.6273,  0.8638,  ..., -0.0519,  0.1328,  0.3054],
          [ 0.3390,  0.5549,  0.5967,  ..., -0.1187,  0.0978,  0.2802],
          [ 0.4682,  0.3761,  0.1764,  ..., -0.1827,  0.0200,  0.1770]],

         [[ 0.8741,  0.5230,  0.1654,  ..., -0.6323, -0.3302,  0.1075],
          [ 0.2444, -0.0916, -0.3434,  ..., -0.8905, -0.4764,  0.1316],
          [-0.2961, -0.5792, -0.7157,  ..., -0.8524, -0.4298,  0.1662],
          ...,
          [-0.3688, -0.3246, -0.2625,  ..., -0.8969, -0.6655, -0.2234],
          [-0.2457, -0.0348,  0.1770,  ...,  0.0212, -0.0255,  0.0124],
          [ 0.0385,  0.3449,  0.6517,  ...,  1.0559,  0.8050,  0.4756]],

         [[-0.2796, -0.0995,  0.0552,  ...,  0.0903,  0.1187,  0.0740],
          [-0.4286, -0.0534,  

In [11]:
onnx_output

tensor([[[[-2.4487e-01, -7.6984e-02,  3.9708e-02,  ...,  2.9346e-01,
            1.7076e-01,  7.4563e-02],
          [-3.4842e-01, -4.8136e-02,  2.1705e-01,  ..., -2.4384e-01,
           -1.1211e-01,  4.2700e-02],
          [-2.4128e-01,  4.8916e-02,  3.1589e-01,  ..., -6.4900e-01,
           -3.2812e-01,  3.2954e-02],
          ...,
          [-3.7661e-01, -2.4080e-02,  1.2430e-01,  ...,  5.8830e-01,
            5.3890e-01,  5.2177e-01],
          [-2.6550e-01, -2.6247e-02,  1.5702e-01,  ...,  3.9481e-01,
            3.7874e-01,  3.6470e-01],
          [-1.2041e-01,  6.5828e-03,  2.3583e-01,  ...,  4.9067e-02,
            1.1513e-01,  1.5927e-01]],

         [[-1.6589e-01, -7.2124e-01, -1.1899e+00,  ..., -3.5382e-01,
           -4.4075e-01, -3.8103e-01],
          [-2.5858e-01, -4.9317e-01, -7.4639e-01,  ..., -2.1517e-01,
           -3.2313e-01, -3.0102e-01],
          [-2.9973e-01, -2.7149e-01, -3.1751e-01,  ..., -1.0852e-01,
           -2.3290e-01, -2.8148e-01],
          ...,
     

In [2]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.5909,  0.5538,  0.7614,  0.0835],
        [ 0.0673,  0.1451,  0.7798,  0.0643],
        [ 0.7146, -0.0390,  0.3463, -0.2745]], grad_fn=<TanhBackward0>), tensor([[ 0.5909,  0.5538,  0.7614,  0.0835],
        [ 0.0673,  0.1451,  0.7798,  0.0643],
        [ 0.7146, -0.0390,  0.3463, -0.2745]], grad_fn=<TanhBackward0>))


In [3]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8818,  0.4849,  0.7097,  0.0984],
        [ 0.7700,  0.2646,  0.8844, -0.0608],
        [ 0.8700,  0.0560,  0.2545, -0.3650]], grad_fn=<TanhBackward0>), tensor([[ 0.8818,  0.4849,  0.7097,  0.0984],
        [ 0.7700,  0.2646,  0.8844, -0.0608],
        [ 0.8700,  0.0560,  0.2545, -0.3650]], grad_fn=<TanhBackward0>))


## Tracing scripts

In [4]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


(tensor([[0.6391, 0.6892, 0.8042, 0.5016],
         [0.0427, 0.8648, 0.7200, 0.5320],
         [0.3587, 0.8346, 0.6829, 0.4864]], grad_fn=<TanhBackward0>),
 tensor([[0.6391, 0.6892, 0.8042, 0.5016],
         [0.0427, 0.8648, 0.7200, 0.5320],
         [0.3587, 0.8346, 0.6829, 0.4864]], grad_fn=<TanhBackward0>))

In [5]:
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /tmp/ipykernel_1218/260609686.py:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /tmp/ipykernel_1218/260609686.py:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /tmp/ipykernel_1218/260609686.py:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



In [6]:
print(traced_cell.code)


def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)



In [7]:
print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[0.6391, 0.6892, 0.8042, 0.5016],
        [0.0427, 0.8648, 0.7200, 0.5320],
        [0.3587, 0.8346, 0.6829, 0.4864]], grad_fn=<TanhBackward0>), tensor([[0.6391, 0.6892, 0.8042, 0.5016],
        [0.0427, 0.8648, 0.7200, 0.5320],
        [0.3587, 0.8346, 0.6829, 0.4864]], grad_fn=<TanhBackward0>))
(tensor([[0.6391, 0.6892, 0.8042, 0.5016],
        [0.0427, 0.8648, 0.7200, 0.5320],
        [0.3587, 0.8346, 0.6829, 0.4864]], grad_fn=<TanhBackward0>), tensor([[0.6391, 0.6892, 0.8042, 0.5016],
        [0.0427, 0.8648, 0.7200, 0.5320],
        [0.3587, 0.8346, 0.6829, 0.4864]], grad_fn=<TanhBackward0>))


In [8]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)

def forward(self,
    argument_1: Tensor) -> Tensor:
  return torch.neg(argument_1)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)



  if x.sum() > 0:


In [9]:
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)



In [10]:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))

(tensor([[ 0.6794,  0.8707,  0.4962,  0.2333],
        [ 0.5400,  0.3965,  0.6403,  0.6984],
        [ 0.6505, -0.0650,  0.3179,  0.2074]], grad_fn=<TanhBackward0>), tensor([[ 0.6794,  0.8707,  0.4962,  0.2333],
        [ 0.5400,  0.3965,  0.6403,  0.6984],
        [ 0.6505, -0.0650,  0.3179,  0.2074]], grad_fn=<TanhBackward0>))
