In [9]:
%load_ext autoreload
%autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
import os, sys

# Repository root: five directory levels above the notebook's working
# directory. The path is normalised to an absolute path so that sys.path
# entries and error messages do not contain a chain of "../.." components.
ROOT = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * 5)))

# Make the packages under <ROOT>/machop importable (the `chop` imports below
# rely on this). The membership check keeps sys.path free of duplicate
# entries when this cell is executed more than once.
machop_dir = os.path.join(ROOT, "machop")
if machop_dir not in sys.path:
    sys.path.append(machop_dir)

from chop.passes.graph.mase_graph import MaseGraph
from chop.models import get_model, get_model_info, get_tokenizer
from chop.tools.get_input import get_cf_args, get_dummy_input
from chop.dataset import get_dataset_info, MaseDataModule
import chop.passes as passes

from chop.passes.transforms import (
    emit_verilog_top_transform_pass,
    emit_mlir_hls_transform_pass,
    emit_internal_rtl_transform_pass,
    emit_bram_transform_pass,
    emit_verilog_tb_transform_pass,
    quantize_transform_pass,
)

Reload modules after changing code

In [11]:
# Dataset description for "wikitext2"; it is passed to the model factory below.
wikitext_info = get_dataset_info("wikitext2")

# Instantiate the model registered as "facebook/opt-125m:patched" for the
# "lm" task, requesting pretrained weights.
opt = get_model("facebook/opt-125m:patched", task="lm", dataset_info=wikitext_info, pretrained=True)

In [12]:
# Tokenizer that matches the model loaded above.
opt_tokenizer = get_tokenizer("facebook/opt-125m:patched")

# Data module for "wikitext2", configured with batch size 1 and a
# 128-token maximum length.
# NOTE(review): `model_name` uses "@patched" while every other call in this
# notebook uses ":patched" -- confirm this difference is intentional.
data_module = MaseDataModule(
    name="wikitext2",
    model_name="facebook/opt-125m@patched",
    tokenizer=opt_tokenizer,
    max_token_len=128,
    batch_size=1,
    num_workers=os.cpu_count(),
    load_from_cache_file=True,
)

# Both steps are run before the data module is used to build the dummy input.
data_module.prepare_data()
data_module.setup()

In [15]:
%load_ext autoreload
%autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Generate graph, initialize metadata and draw diagram

In [20]:
import toml

# Generate graph and initialize metadata
model_info = get_model_info("facebook/opt-125m:patched")
cf_args = get_cf_args(model_info=model_info, task="lm", model=opt)
graph = MaseGraph(model=opt, cf_args=cf_args)
graph = passes.PASSES["init_metadata"](graph, pass_args=None)

# Generate dummy input; merge in any additional inputs the traced model declares
dummy_in = get_dummy_input(model_info, data_module=data_module, task="lm")
if len(graph.model.additional_inputs) > 0:
    dummy_in = dummy_in | graph.model.additional_inputs

# Add common metadata - infer input and output shape for each node
print("ADD COMMON METADATA")
graph = passes.PASSES["add_common_metadata"](graph, pass_args=dummy_in)
# graph = passes.PASSES["verify_common_metadata"](graph)

# Quantize the graph using the [passes.quantize] section of the test config
config_file = os.path.join(ROOT, "machop", "configs", "tests", "quantize", "fixed.toml")
with open(config_file, "r") as f:
    quan_args = toml.load(f)["passes"]["quantize"]
graph = quantize_transform_pass(graph, quan_args)

print("ADD HW METADATA")
graph = passes.PASSES["add_hardware_metadata"](graph, pass_args=None)

# Remove add/add_1/output as input nodes
# TODO: temporary solution -- this drops the last three entries by position,
# so it relies on those three nodes being at the end of `graph.nodes_in`.
print(f"before removal {graph.nodes_in}")
graph.nodes_in = graph.nodes_in[:-3]
print(f"after removal {graph.nodes_in}")

print("UPDATE ATTENTION ARGS")
# Old key -> new key for the attention projection parameters. The order is
# kept as biases first, then weights, so the renamed keys are inserted into
# the args dict in the same order as before.
attn_arg_renames = {
    "q_proj.bias": "bias_q",
    "k_proj.bias": "bias_k",
    "v_proj.bias": "bias_v",
    "q_proj.weight": "weight_q",
    "k_proj.weight": "weight_k",
    "v_proj.weight": "weight_v",
}
# Keys removed outright: data_in_2 / data_in_4 (attention_mask and
# output_attentions, per the original note) and the output projection
# weight/bias.
attn_arg_drops = ("data_in_2", "data_in_4", "out_proj.weight", "out_proj.bias")

# Rename attention node inputs etc
for node in graph.fx_graph.nodes:
    # Only self-attention nodes are touched; their layer-norm nodes are skipped.
    if "self_attn" not in node.name or "layer_norm" in node.name:
        continue

    common_args = node.meta["mase"].parameters["common"]["args"]

    # A missing key raises KeyError, exactly as the explicit pops did.
    for old_key, new_key in attn_arg_renames.items():
        common_args[new_key] = common_args.pop(old_key)

    for dropped_key in attn_arg_drops:
        common_args.pop(dropped_key)

ADD COMMON METADATA
ADD HW METADATA
before removal [model_decoder_layers_0_self_attn, model_decoder_layers_1_self_attn, model_decoder_layers_2_self_attn, model_decoder_layers_3_self_attn, model_decoder_layers_4_self_attn, model_decoder_layers_5_self_attn, model_decoder_layers_6_self_attn, model_decoder_layers_7_self_attn, model_decoder_layers_8_self_attn, model_decoder_layers_9_self_attn, model_decoder_layers_10_self_attn, model_decoder_layers_11_self_attn, add_1, add, output]
after removal [model_decoder_layers_0_self_attn, model_decoder_layers_1_self_attn, model_decoder_layers_2_self_attn, model_decoder_layers_3_self_attn, model_decoder_layers_4_self_attn, model_decoder_layers_5_self_attn, model_decoder_layers_6_self_attn, model_decoder_layers_7_self_attn, model_decoder_layers_8_self_attn, model_decoder_layers_9_self_attn, model_decoder_layers_10_self_attn, model_decoder_layers_11_self_attn]
UPDATE ATTENTION ARGS


In [17]:
# Uncomment to dump the raw FX graph as a table while debugging:
# graph.fx_graph.print_tabular()

from chop.passes.analysis.report.report_node import report_node_type_analysis_pass

# Run the node-type report pass on the graph. The trailing semicolon stops
# the notebook from echoing the returned MaseGraph object's repr as the cell
# output.
report_node_type_analysis_pass(graph);

<chop.passes.graph.mase_graph.MaseGraph at 0x7fe17bb41de0>

In [18]:
# Run the top-level Verilog emission transform pass and keep the graph it returns.
graph = emit_verilog_top_transform_pass(graph)
# The BRAM and internal-RTL emission passes are currently disabled;
# uncomment the lines below to run them on the same graph.
# graph = emit_bram_transform_pass(graph)
# graph = emit_internal_rtl_transform_pass(graph)