In [None]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity, get_logger

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")

logger = get_logger("chop")
logger.setLevel(logging.INFO)

batch_size = 8
model_name = "jsc-tiny"
dataset_name = "jsc"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

dummy_in = {"x": next(iter(data_module.train_dataloader()))[0]}

  from .autonotebook import tqdm as notebook_tqdm
[32mINFO    [0m [34mSet logging level to info[0m


In [3]:
from torch import nn
from chop.passes.graph.utils import get_parent_name

# define a new model
class JSC_Three_Linear_Layers(nn.Module):
    def __init__(self):
        super(JSC_Three_Linear_Layers, self).__init__()
        self.seq_blocks = nn.Sequential(
            nn.BatchNorm1d(16),  # 0
            nn.ReLU(),  # 1
            nn.Linear(16, 16),  # linear seq_2
            nn.ReLU(),  # 3
            nn.Linear(16, 16),  # linear seq_4
            nn.ReLU(),  # 5
            nn.Linear(16, 5),  # linear seq_6
            nn.ReLU(5),  # 7
        )

    def forward(self, x):
        return self.seq_blocks(x)


model = JSC_Three_Linear_Layers()

# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)
mg, _ = init_metadata_analysis_pass(mg, None)

In [9]:
pass_args = {
"by": "type",
"default": {"config": {"name": None}},
"linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
},}

pass_config = {
"by": "name",
"default": {"config": {"name": None}},
"seq_blocks_2": {
    "config": {
        "name": "output_only",
        "channel_multiplier": 2,
        }
    },
"seq_blocks_4": {
    "config": {
        "name": "output_only",
        "channel_multiplier": 2,
        }
    },
"seq_blocks_6": {
    "config": {
        "name": "input_only",
        }
    },
}

import copy
# build a search space
channel_multiplier_2 = [1,2,4]
channel_multiplier_4 = [1,2,4]
channel_multiplier_6 = [1,2,4]
search_spaces = []
for c_config in channel_multiplier_2:
    for b_config in channel_multiplier_4:
           pass_config['seq_blocks_2']['config']['channel_multiplier'] = c_config
           pass_config['seq_blocks_4']['config']['channel_multiplier'] = b_config
           pass_config['seq_blocks_6']['config']['channel_multiplier'] = b_config
        # dict.copy() and dict(dict) only perform shallow copies
        # in fact, only primitive data types in python are doing implicit copy when a = b happens
           search_spaces.append(copy.deepcopy(pass_config))
           print(pass_config)

{'by': 'name', 'default': {'config': {'name': None}}, 'seq_blocks_2': {'config': {'name': 'output_only', 'channel_multiplier': 1}}, 'seq_blocks_4': {'config': {'name': 'output_only', 'channel_multiplier': 1}}, 'seq_blocks_6': {'config': {'name': 'input_only', 'channel_multiplier': 1}}}
{'by': 'name', 'default': {'config': {'name': None}}, 'seq_blocks_2': {'config': {'name': 'output_only', 'channel_multiplier': 1}}, 'seq_blocks_4': {'config': {'name': 'output_only', 'channel_multiplier': 2}}, 'seq_blocks_6': {'config': {'name': 'input_only', 'channel_multiplier': 2}}}
{'by': 'name', 'default': {'config': {'name': None}}, 'seq_blocks_2': {'config': {'name': 'output_only', 'channel_multiplier': 1}}, 'seq_blocks_4': {'config': {'name': 'output_only', 'channel_multiplier': 4}}, 'seq_blocks_6': {'config': {'name': 'input_only', 'channel_multiplier': 4}}}
{'by': 'name', 'default': {'config': {'name': None}}, 'seq_blocks_2': {'config': {'name': 'output_only', 'channel_multiplier': 2}}, 'seq_bl

In [10]:
from chop.passes.graph import report_graph_analysis_pass

def instantiate_linear(in_features, out_features, bias):
    if bias is not None:
        bias = True
    return nn.Linear(
        in_features=in_features,
        out_features=out_features,
        bias=bias)

def redefine_linear_transform_pass(graph, pass_args=None):
    main_config = pass_args.pop('config')
    default = main_config.pop('default', None)
    if default is None:
        print(default)
        raise ValueError(f"default value must be provided.")
    i = 0
    last_multi = 1
    for node in graph.fx_graph.nodes:
        i += 1
        # if node name is not matched, it won't be tracked
        config = main_config.get(node.name, default)['config']
        name = config.get("name", None)
        if name is not None:
            ori_module = graph.modules[node.target]
            in_features = ori_module.in_features
            out_features = ori_module.out_features
            bias = ori_module.bias
            if name == "output_only":
                out_features = out_features * config["channel_multiplier"]
                in_features = in_features*last_multi
                last_multi = config["channel_multiplier"]
            elif name == "both":
                in_features = in_features * config["channel_multiplier"]
                out_features = out_features * config["channel_multiplier"]
            elif name == "input_only":
                in_features = in_features * last_multi
                

            new_module = instantiate_linear(in_features, out_features, bias)
            parent_name, name = get_parent_name(node.target)
            setattr(graph.modules[parent_name], name, new_module)
    _ = report_graph_analysis_pass(mg)
    return graph, {}

In [11]:
import torch
from torchmetrics.classification import MulticlassAccuracy

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

metric = MulticlassAccuracy(num_classes=5)
num_batchs = 5
# This first loop is basically our search strategy,
# in this case, it is a simple brute force search

recorded_accs = []
for i, config in enumerate(search_spaces):
    mg, _ = redefine_linear_transform_pass(graph=mg, pass_args={"config": config})
    j = 0

    # this is the inner loop, where we also call it as a runner.
    acc_avg, loss_avg = 0, 0
    accs, losses = [], []
    for inputs in data_module.train_dataloader():
        xs, ys = inputs
        preds = mg.model(xs)
        loss = torch.nn.functional.cross_entropy(preds, ys)
        acc = metric(preds, ys)
        accs.append(acc)
        losses.append(loss)
        if j > num_batchs:
            break
        j += 1
    acc_avg = sum(accs) / len(accs)
    loss_avg = sum(losses) / len(losses)
    recorded_accs.append(acc_avg)
    print(acc_avg)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    %seq_blocks_4 : [num_users=1] = call_module[target=seq_blocks.4](args = (%seq_blocks_3,), kwargs = {})
    %seq_blocks_5 : [num_users=1] = call_module[target=seq_blocks.5](args = (%seq_blocks_4,), kwargs = {})
    %seq_blocks_6 : [num_users=1] = call_module[target=seq_blocks.6](args = (%seq_blocks_5,), kwargs = {})
    %seq_blocks_7 : [num_users=1] = call_module[target=seq_blocks.7](args = (%seq_blocks_6,), kwargs = {})
    return seq_blocks_7Network overview:
{'placeholder': 1, 'get_attr': 0, 'call_function': 0, 'cal

: 

In [11]:
# pass_config = {
# "by": "name",
# "default": {"config": {"name": None}},

# "seq_blocks_2": {
#     "config": {
#         "name": "output_only",
#         "channel_multiplier": 1,
#         }
#     },
# "seq_blocks_4": {
#     "config": {
#         "name": "output_only",
#         "channel_multiplier": 1,
#         }
#     },
# "seq_blocks_6": {
#     "config": {
#         "name": "input_only",
#         }
#     },
# }
# this performs the architecture transformation based on the config
mg, _ = redefine_linear_transform_pass(
    graph=mg, pass_args={"config": pass_config})

{'by': 'name', 'default': {'config': {'name': None}}, 'seq_blocks_2': {'config': {'name': 'output_only', 'channel_multiplier': 8}}, 'seq_blocks_4': {'config': {'name': 'output_only', 'channel_multiplier': 8}}, 'seq_blocks_6': {'config': {'name': 'input_only', 'channel_multiplier': 8}}}
{'config': {'name': None}}
graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    %seq_blocks_4 : [num_users=1] = call_module[target=seq_blocks.4](args = (%seq_blocks_3,), kwargs = {})
    %seq_blocks_5 : [num_users=1] = call_module[target=seq_blocks.5](args = (%seq_blocks_4,), kwargs = {})

In [40]:
# report graph is an analysis pass that shows you the detailed information in the graph
from chop.passes.graph import report_graph_analysis_pass
_ = report_graph_analysis_pass(mg)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    %seq_blocks_4 : [num_users=1] = call_module[target=seq_blocks.4](args = (%seq_blocks_3,), kwargs = {})
    %seq_blocks_5 : [num_users=1] = call_module[target=seq_blocks.5](args = (%seq_blocks_4,), kwargs = {})
    %seq_blocks_6 : [num_users=1] = call_module[target=seq_blocks.6](args = (%seq_blocks_5,), kwargs = {})
    %seq_blocks_7 : [num_users=1] = call_module[target=seq_blocks.7](args = (%seq_blocks_6,), kwargs = {})
    return seq_blocks_7Network overview:
{'placeholder': 1, 'get_attr': 0, 'call_function': 0, 'cal

In [8]:
mg.model



GraphModule(
  (seq_blocks): Module(
    (0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=8192, bias=True)
    (3): ReLU()
    (4): Linear(in_features=8192, out_features=8192, bias=True)
    (5): ReLU()
    (6): Linear(in_features=8192, out_features=5, bias=True)
    (7): ReLU(inplace=True)
  )
)