In [1]:
import brt
import brt.nn as nn
import torch
from brt.prim import is_netlet
# @basic_unit
@brt.netlet
class SimpleNet(nn.Module):
    """Two chained 10 -> 10 linear layers, decorated with ``brt.netlet``."""

    def __init__(self):
        super().__init__()
        # Created in the same order they are applied in ``forward``.
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x):
        """Return ``linear2(linear1(x))``."""
        hidden = self.linear1(x)
        return self.linear2(hidden)

# Instantiate the decorated module and print the result of `is_netlet` for it
# (the recorded output of this cell shows `True`).
simple_net = SimpleNet()
print(is_netlet(simple_net))
# Rebind the instance's `forward` to its `pt_forward` attribute before the
# eager call and the scripting below.
# NOTE(review): presumably `pt_forward` is the plain PyTorch forward kept by
# `brt.netlet` - confirm against brt.
simple_net.forward = simple_net.pt_forward

# Eager run on a 10-element vector. `x` is overwritten with the model output,
# so re-running only the lines below would feed the output back in.
x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = simple_net(x)
print(x)
# Compile the module with TorchScript and print its inlined graph.
script_simple_net = torch.jit.script(simple_net)
simple_net_inlined_graph = script_simple_net.inlined_graph
print(simple_net_inlined_graph)

True
tensor([ 0.6682, -1.4162, -2.7463,  1.8377, -1.8727, -1.7920, -1.0083,  1.0476,
        -2.7022, -0.4607], grad_fn=<AddBackward0>)
graph(%self : __torch__.SimpleNet,
      %x.1 : Tensor):
  %linear1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear1"](%self)
  %6 : Function = prim::Constant[name="linear"]()
  %weight.1 : Tensor = prim::GetAttr[name="weight"](%linear1)
  %bias.1 : Tensor = prim::GetAttr[name="bias"](%linear1)
  %x.5 : Tensor = aten::linear(%x.1, %weight.1, %bias.1) # /state/partition/whcui/tools/pyenv/versions/miniconda3-4.7.12/lib/python3.7/site-packages/torch/nn/functional.py:1848:11
  %linear2 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear2"](%self)
  %10 : Function = prim::Constant[name="linear"]()
  %weight : Tensor = prim::GetAttr[name="weight"](%linear2)
  %bias : Tensor = prim::GetAttr[name="bias"](%linear2)
  %x.9 : Tensor = aten::linear(%x.5, %weight, %bias) # /state/partition/whcui/tools/pyenv/versions/minic

In [2]:
from brt.graphgen import convert_to_graph
from brt.codegen import model_to_pytorch_script
# Convert the scripted module (together with the original eager module) into
# the graph IR. Relies on `script_simple_net` and `simple_net` defined in an
# earlier cell.
model_ir = convert_to_graph(script_simple_net, simple_net)

# Dump every node of the IR.
for node in model_ir.get_nodes():
    print(node)

print("-----------------")

# Dump the cell nodes of the IR (the recorded output shows nothing printed
# between the separator and the generated code).
for cell_node in model_ir.get_cell_nodes():
    print(cell_node)

# Generate PyTorch source code from the IR and print it.
model_code = model_to_pytorch_script(model_ir)
print(model_code)

Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=2, name=_model__linear1, python_name=linear1, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
Node(id=4, name=_model__linear2, python_name=linear2, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
-----------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import brt.nn

import torch


class _model(nn.Module):
    def __init__(self):
        super().__init__()
        self._linear1 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._linear2 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._mapping_ = {'_linear1': 'linear1', '

In [6]:
import torch
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter


@brt.netlet
class MoE(nn.Module):
    """Two linear experts placed between a random scatter and a random gather router."""

    def __init__(self):
        super().__init__()
        # Both routers are configured for two routes, one per expert.
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        """Scatter ``x``, run each expert on its route, and return both expert outputs."""
        routed, reverse_indice, reverse_shape = self.scatter_router(x)
        first_expert_out = self.expert1(routed[0])
        second_expert_out = self.expert2(routed[1])
        # NOTE(review): the gathered tensor is assigned to ``x`` but never
        # returned; the per-expert outputs are returned instead. Confirm this
        # is intentional.
        x = self.gather_router(
            [first_expert_out, second_expert_out], reverse_indice, reverse_shape
        )
        return first_expert_out, second_expert_out


@brt.netlet
class MoEModel(nn.Module):
    """Netlet that wraps a single ``MoE`` submodule."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Delegate to the wrapped ``moe`` submodule."""
        moe_output = self.moe(x)
        return moe_output


class Model(nn.Module):
    """Undecorated top-level module that holds a ``MoEModel``."""

    def __init__(self):
        super().__init__()
        self.moe_model = MoEModel()

    def forward(self, x):
        """Delegate to the wrapped ``moe_model`` submodule."""
        result = self.moe_model(x)
        return result


# Build the nested MoE model (Model -> MoEModel -> MoE).
moe_model = Model()

# Two identical rows of 10 values each, used as the eager-mode input.
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# model.cuda()
# NOTE(review): `y` is not used anywhere else in this cell.
y = 10
# Eager forward pass; the recorded output shows a tuple of two tensors.
z = moe_model(x)
print(z)


# model.brt_script(True)
# Compile with TorchScript and print the graph with submodule calls inlined.
script_moe_model = torch.jit.script(moe_model)
print(script_moe_model.inlined_graph)


[tensor([[-0.4458,  2.5542, -0.3188, -2.9574,  4.2428,  2.7199, -2.5572,  0.0227,
         -6.4788,  0.2923]], grad_fn=<AddmmBackward0>), tensor([[-4.3528, -2.3810,  1.5321,  2.8667,  4.0990, -1.4264,  9.8560, -1.3717,
         -3.5174, -0.0418]], grad_fn=<AddmmBackward0>)]
(tensor([[-0.4458,  2.5542, -0.3188, -2.9574,  4.2428,  2.7199, -2.5572,  0.0227,
         -6.4788,  0.2923]], grad_fn=<AddmmBackward0>), tensor([[-4.3528, -2.3810,  1.5321,  2.8667,  4.0990, -1.4264,  9.8560, -1.3717,
         -3.5174, -0.0418]], grad_fn=<AddmmBackward0>))
graph(%self : __torch__.Model,
      %x.1 : Tensor):
  %moe_model : __torch__.MoEModel = prim::GetAttr[name="moe_model"](%self)
  %4 : int = prim::Constant[value=0]() # /tmp/ipykernel_2957284/4277649058.py:18:41
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_2957284/4277649058.py:19:41
  %6 : Function = prim::Constant[name="linear"]()
  %moe : __torch__.MoE = prim::GetAttr[name="moe"](%moe_model)
  %scatter_router : __torch__.brt.router.

In [2]:
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
# Convert the scripted MoE model to the graph IR. Relies on `script_moe_model`
# and `moe_model` defined in an earlier cell.
model_ir = convert_to_graph(script_moe_model, moe_model)

# Dump every node of the IR.
for node in model_ir.get_nodes():
    print(node)

# NOTE(review): the recorded output of this cell lists the nodes and then ends
# with "RuntimeError: unsupported operation type: prim::PythonOp"; presumably
# it is raised by this code-generation call - confirm.
model_code = model_to_pytorch_script(model_ir)


Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=3, name=_model__moe_model__moe__Constant3, python_name=None, label=None, operation=PrimConstant(type="prim::Constant", type='int', value=0))
Node(id=4, name=_model__moe_model__moe__Constant4, python_name=None, label=None, operation=PrimConstant(type="prim::Constant", type='int', value=1))
Node(id=5, name=_model__moe_model__moe__Attr5, python_name=None, label=None, operation=PrimGetAttr(type="prim::GetAttr", name='scatter_router', input='self', value=None))
Node(id=6, name=_model__moe_model__moe__prim__PythonOp6, python_name=None, label=None, operation=PyTorchOperation(type="prim::PythonOp"))
Node(id=7, name=_model__moe_model__moe__TupleUnpack7, python_name=None, label=None, operation=PrimTupleUnpack(type="prim::TupleUnpack"))
Node(id=9, name=_model__moe_model__moe__aten____ge

RuntimeError: unsupported operation type: prim::PythonOp ? None

In [None]:
import torch
import torch.nn as nn


class MyModule(nn.Module):
    """Adds 10 to its input, optionally through a method excluded from TorchScript."""

    def __init__(self, use_memory_efficient):
        super().__init__()
        self.use_memory_efficient = use_memory_efficient

    @torch.jit.ignore(drop=True)
    def memory_efficient(self, x):
        """Non-scriptable variant: drops into the debugger, then returns ``x + 10``."""
        import pdb

        pdb.set_trace()
        return x + 10

    def forward(self, x):
        """Return ``x + 10``, via ``memory_efficient`` when the flag is set."""
        if not self.use_memory_efficient:
            return x + 10
        # Use not-yet-scriptable memory efficient mode
        return self.memory_efficient(x)


# Script the module with the flag off, so `forward` takes the `x + 10` branch.
m = torch.jit.script(MyModule(use_memory_efficient=False))
# m.save("m.pt")

# Script the module with the flag on, so `forward` routes to `memory_efficient`,
# which is decorated with `torch.jit.ignore(drop=True)`.
m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
# NOTE(review): presumably the exception refers to calling the scripted module
# (the call is commented out below), not to printing the graph - confirm.
print(m.inlined_graph)
# m.save("m.pt")
# m(torch.rand(100))


In [1]:
import brt.nn as nn
import torch
from brt import netlet, top_graph
from brt.prim import unwrap_netlet, unwrap_redundant_netlet


@top_graph
@netlet
class RedundantModel(nn.Module):
    """Single 10 -> 10 linear layer decorated with both ``netlet`` and ``top_graph``."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        projected = self.linear(x)
        return projected


# Build the doubly-decorated model and run it once on a 10-element vector.
# `x` is overwritten with the model output.
redundant_model = RedundantModel()
x = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
x = redundant_model(x)
print(x)

# assert redundant_model._netlet_tag == True, "netlet_tag is not set"
# assert redundant_model._top_graph == True, "top_graph is not set"

# Replace the model with the result of `unwrap_redundant_netlet` and run it again.
# The second call receives the *output* of the first call as its input, so the
# two printed tensors are not expected to match.
redundant_model = unwrap_redundant_netlet(redundant_model)
x = redundant_model(x)
print(x)


tensor([-1.6597,  1.9257, -1.5097, -4.6707, -0.4496,  3.1321,  1.4132,  1.1851,
         0.9164,  0.7567], grad_fn=<AddBackward0>)
tensor([-1.3143, -1.6946, -0.8486,  0.4667,  0.3905, -0.5602,  0.1592, -0.5632,
         1.1540, -1.0938], grad_fn=<AddBackward0>)


In [5]:
import os
import sys
import unittest
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

import nni.retiarii.nn.pytorch as nn
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape
from nni.retiarii.serializer import basic_unit

class ConvertMixin:
    """Mixin supplying the model -> graph IR conversion step used by the tests."""

    @staticmethod
    def _convert_model(model, input):
        """Script ``model`` with TorchScript and convert the result to graph IR.

        ``input`` is accepted for interface compatibility but is not used here.
        """
        return convert_to_graph(torch.jit.script(model), model)


class TestModels(unittest.TestCase, ConvertMixin):
    """Round-trip tests: model -> graph IR -> generated PyTorch code -> model.

    Each test builds a small module, converts it with ``run_test`` and compares
    the output of the regenerated module with the output of the original.
    """

    @staticmethod
    def _fresh_inputs(input):
        """Return ``input`` as a tuple in which every list argument is shallow-copied.

        Some test models append to a list argument in ``forward``; copying keeps
        one model's in-place mutation from changing what the other model sees.
        """
        return tuple(list(arg) if isinstance(arg, list) else arg for arg in input)

    def _assert_outputs_equal(self, converted_output, expected_output):
        """Recursively assert that the converted output equals the expected one.

        Tensors are compared with ``torch.equal``; lists/tuples are compared by
        length and then element by element; anything else uses ``assertEqual``.
        """
        if isinstance(expected_output, torch.Tensor):
            self.assertIsInstance(converted_output, torch.Tensor)
            self.assertTrue(
                torch.equal(converted_output, expected_output),
                "converted model output differs from the original model output",
            )
        elif isinstance(expected_output, (list, tuple)):
            self.assertEqual(len(converted_output), len(expected_output))
            for converted_item, expected_item in zip(converted_output, expected_output):
                self._assert_outputs_equal(converted_item, expected_item)
        else:
            self.assertEqual(converted_output, expected_output)

    def run_test(self, model, input, check_value=True):
        """Convert ``model``, rebuild it from generated code and compare outputs.

        Args:
            model: the module to convert.
            input: tuple of positional arguments passed to ``model``.
            check_value: when true, assert that the rebuilt model produces the
                same output as ``model``.

        Returns:
            The model instantiated from the generated code.

        Raises:
            AssertionError: if ``check_value`` is true and the outputs differ.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        # Execute the generated source; it is expected to define a class `_model`.
        exec_vars = {}
        exec(model_code + "\n\nconverted_model = _model()", exec_vars)
        converted_model = exec_vars["converted_model"]

        # NOTE(review): presumably the hooks map the original parameter names
        # onto the generated module's names while loading - confirm.
        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(model.state_dict())

        with torch.no_grad():
            # Each model gets its own copy of list inputs (see _fresh_inputs).
            expected_output = model.eval()(*self._fresh_inputs(input))
            converted_output = converted_model.eval()(*self._fresh_inputs(input))
        if check_value:
            # Previously the result of torch.eq was discarded and a bare except
            # hid failures, so no value was ever actually checked.
            self._assert_outputs_equal(converted_output, expected_output)
        return converted_model

    def test_nested_modulelist(self):
        """A ModuleList of ModuleLists of linear layers applied sequentially."""

        class Net(nn.Module):
            def __init__(self, num_nodes, num_ops_per_node):
                super().__init__()
                self.ops = nn.ModuleList()
                self.num_nodes = num_nodes
                self.num_ops_per_node = num_ops_per_node
                for _ in range(num_nodes):
                    self.ops.append(
                        nn.ModuleList(
                            [nn.Linear(16, 16) for __ in range(num_ops_per_node)]
                        )
                    )

            def forward(self, x):
                state = x
                for ops in self.ops:
                    for op in ops:
                        state = op(state)
                return state

        model = Net(4, 2)
        x = torch.rand((16, 16), dtype=torch.float)
        self.run_test(model, (x,))

    def test_append_input_tensor(self):
        """A model whose forward appends intermediate results to its list input."""
        from typing import List

        class Net(nn.Module):
            def __init__(self, num_nodes):
                super().__init__()
                self.ops = nn.ModuleList()
                self.num_nodes = num_nodes
                for _ in range(num_nodes):
                    self.ops.append(nn.Linear(16, 16))

            def forward(self, x: List[torch.Tensor]):
                # `state` aliases the caller's list, so the appends mutate it.
                state = x
                for ops in self.ops:
                    state.append(ops(state[-1]))
                return state[-1]

        model = Net(4)
        x = torch.rand((1, 16), dtype=torch.float)
        self.run_test(model, ([x],))

    def test_channels_shuffle(self):
        """Reshape/permute based split of the input that returns two tensors."""

        class Net(nn.Module):
            def forward(self, x):
                bs, num_channels, height, width = x.size()
                x = x.reshape(bs * num_channels // 2, 2, height * width)
                x = x.permute(1, 0, 2)
                x = x.reshape(2, -1, num_channels // 2, height, width)
                return x[0], x[1]

        model = Net()
        x = torch.rand((1, 64, 224, 224), dtype=torch.float)
        self.run_test(model, (x,))

    def test_identity_node(self):
        """A model that returns its input unchanged."""

        class Net(nn.Module):
            def forward(self, x):
                return x

        model = Net()
        x = torch.rand((1, 64, 224, 224), dtype=torch.float)
        self.run_test(model, (x,))

    def test_nn_sequential_inherit(self):
        """A model containing a user-defined subclass of ``nn.Sequential``."""

        class ConvBNReLU(nn.Sequential):
            def __init__(self):
                super().__init__(
                    nn.Conv2d(3, 3, 1, 1, bias=False),
                    nn.BatchNorm2d(3),
                    nn.ReLU(inplace=False),
                )

        # @basic_unit
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_bn_relu = ConvBNReLU()

            def forward(self, x):
                return self.conv_bn_relu(x)

        model = Net()
        x = torch.rand((1, 3, 224, 224), dtype=torch.float)
        self.run_test(model, (x,))


# Instantiate the test case directly (no unittest runner) and call one test method.
# NOTE(review): the recorded output of this cell prints only the import header
# of the generated code and then ends with "NameError: name '_model' is not
# defined".
test_models = TestModels()
test_models.test_nn_sequential_inherit()


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import nni.retiarii.nn.pytorch


NameError: name '_model' is not defined