In [1]:
import torch
import torch.nn as nn

In [2]:
class ConvReLU(nn.Module):
    """Two back-to-back 2-D convolutions followed by a single ReLU.

    Both convolutions share the same ``kernel_size``, ``stride`` and
    ``padding``. The first maps ``in_channels`` to ``out_channels``; the
    second keeps ``out_channels``. No activation is applied between them.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by each convolution.
        kernel_size: Kernel size passed to both convolutions.
        stride: Stride passed to both convolutions (default 1).
        padding: Padding passed to both convolutions (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        shared = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **shared)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **shared)
        self.relu = nn.ReLU()
        # Registered on the module but never called by forward().
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)))``."""
        features = self.conv1(x)
        features = self.conv2(features)
        return self.relu(features)

In [3]:
# Layer hyper-parameters used when constructing the module below.
in_channels, out_channels = 3, 3
kernel_size = 3
stride = 1
padding = 1

# Sample batch: the values 1..16 as a 4x4 grid, copied into three identical
# channels, with a leading batch axis -> shape (1, 3, 4, 4).
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells reference it.
input = torch.arange(1, 17, dtype=torch.float32).reshape(4, 4)
input = torch.stack([input] * 3, dim=0).unsqueeze(0)

print(input.shape)
print(input)

torch.Size([1, 3, 4, 4])
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])


In [4]:
# Build the block with freshly (randomly) initialised weights and run a single
# forward pass on the sample tensor. No seed is set, so the printed values
# differ between kernel sessions.
conv_relu_layer = ConvReLU(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
)

output = conv_relu_layer(input)

print(output.shape)
print(output)

torch.Size([1, 3, 4, 4])
tensor([[[[0.0000, 0.0000, 0.2213, 1.5795],
          [0.0000, 0.0000, 0.9237, 2.2873],
          [0.0046, 0.1470, 2.1175, 2.2916],
          [2.1326, 1.4126, 2.2399, 0.6411]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.8971, 0.6681]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]], grad_fn=<ReluBackward0>)


In [5]:
# Trace the module with the sample tensor as the example input.
model = torch.jit.trace(conv_relu_layer, input)

In [6]:
# Display the traced module's structure.
model

ConvReLU(
  original_name=ConvReLU
  (conv1): Conv2d(original_name=Conv2d)
  (conv2): Conv2d(original_name=Conv2d)
  (relu): ReLU(original_name=ReLU)
)

In [7]:
# Write the traced module to disk; the path is relative to the current
# working directory.
model.save("./conv_relu_layer.pt")

In [8]:
# Read the file written by the save cell back in as a separate object.
loaded_model = torch.jit.load("./conv_relu_layer.pt")

In [9]:
# Display the reloaded module's structure.
loaded_model

RecursiveScriptModule(
  original_name=ConvReLU
  (conv1): RecursiveScriptModule(original_name=Conv2d)
  (conv2): RecursiveScriptModule(original_name=Conv2d)
  (relu): RecursiveScriptModule(original_name=ReLU)
)

In [10]:
# Run the reloaded module on the same sample tensor and check the output shape.
test_output = loaded_model(input)
print(test_output.shape)

torch.Size([1, 3, 4, 4])


In [11]:
# Print the reloaded module's output; the recorded values match those printed
# for the original module above, i.e. the save/load round trip preserved them.
print(test_output)

tensor([[[[0.0000, 0.0000, 0.2213, 1.5795],
          [0.0000, 0.0000, 0.9237, 2.2873],
          [0.0046, 0.1470, 2.1175, 2.2916],
          [2.1326, 1.4126, 2.2399, 0.6411]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.8971, 0.6681]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]], grad_fn=<ReluBackward0>)


In [15]:
class Expr(nn.Module):
    """Two stacked convolutions with a skip connection, ReLU and 2x2 max-pool.

    Computes ``maxpool(relu(conv2(conv1(x)) + x))``. Because the input is
    added to the convolution output, the two must have shapes that can be
    added together (e.g. equal channel counts and a size-preserving
    kernel/stride/padding combination).

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by each convolution.
        kernel_size: Kernel size passed to both convolutions.
        stride: Stride passed to both convolutions (default 1).
        padding: Padding passed to both convolutions (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        shared = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **shared)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **shared)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Apply both convolutions, add the input back, then ReLU and pool."""
        branch = self.conv2(self.conv1(x))
        activated = self.relu(branch + x)
        return self.maxpool(activated)

In [16]:
# Same hyper-parameters and sample tensor as the earlier ConvReLU experiment,
# re-declared here for the Expr experiment.
in_channels, out_channels = 3, 3
kernel_size = 3
stride = 1
padding = 1

# Values 1..16 as a 4x4 grid, copied into three identical channels, plus a
# leading batch axis -> shape (1, 3, 4, 4).
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells reference it.
input = torch.arange(1, 17, dtype=torch.float32).reshape(4, 4)
input = torch.stack([input] * 3, dim=0).unsqueeze(0)

print(input.shape)
print(input)

torch.Size([1, 3, 4, 4])
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])


In [17]:
# Build the residual block with freshly (randomly) initialised weights and run
# one forward pass on the sample tensor.
# NOTE(review): Expr.forward ends with MaxPool2d(2), yet the output recorded
# below is 4x4 with grad_fn=<ReluBackward0>; it looks stale relative to the
# current class definition -- re-run this cell to refresh it.
expr_layer = Expr(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
)

output = expr_layer(input)

print(output.shape)
print(output)

torch.Size([1, 3, 4, 4])
tensor([[[[ 0.0881,  0.0000,  0.0000,  0.0000],
          [ 1.3674,  0.0000,  0.0000,  0.1356],
          [ 2.0599,  0.0000,  0.0000,  1.7272],
          [ 5.4115,  3.4820,  6.6368, 11.2847]],

         [[ 1.6182,  3.0373,  5.1693,  3.0251],
          [ 6.0341,  6.9820, 10.0318,  5.9754],
          [ 9.4111,  8.8240, 11.5910,  8.2507],
          [11.8416, 11.8962, 14.1935, 12.8298]],

         [[ 1.6758,  3.1666,  4.0112,  3.4768],
          [ 7.3629,  9.5257,  9.8637,  8.8083],
          [11.4725, 13.5979, 13.8550, 13.4482],
          [16.7963, 18.7408, 16.8579, 19.0960]]]], grad_fn=<ReluBackward0>)


In [18]:
# Trace the residual block with the sample tensor as the example input.
expr_model =torch.jit.trace(expr_layer, input)

In [19]:
# Display the traced residual block's structure.
expr_model

Expr(
  original_name=Expr
  (conv1): Conv2d(original_name=Conv2d)
  (conv2): Conv2d(original_name=Conv2d)
  (relu): ReLU(original_name=ReLU)
  (maxpool): MaxPool2d(original_name=MaxPool2d)
)