46 changes: 46 additions & 0 deletions graph_net/test/typical_sequence_decomposer_test.sh
@@ -0,0 +1,46 @@
#!/bin/bash
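# End-to-end test of the typical-sequence decomposer: compute split points for a
# list of sample models, decompose resnet18 at those points with the
# range_decomposer backend, then validate the result and plot the log.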

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

MODEL1="$GRAPH_NET_ROOT/samples/torchvision/resnet18"
MODEL2="$GRAPH_NET_ROOT/samples/torchvision/resnet34"
MODEL_LIST_FILE=$(mktemp)
echo "$MODEL1" > "$MODEL_LIST_FILE"
echo "$MODEL2" >> "$MODEL_LIST_FILE"

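# Compute typical-sequence split points for the listed models and write them to split_results.json.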
python3 -m graph_net.torch.typical_sequence_split_points \
--model-list "$MODEL_LIST_FILE" \
--device "cuda" \
--window-size 10 \
--output-json "$GRAPH_NET_ROOT/split_results.json"

rm -f "$MODEL_LIST_FILE"


MODEL_PATH_IN_SAMPLES=torchvision/resnet18
MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")

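# Build the base64-encoded JSON config consumed by the range_decomposer backend.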
decomposer_config_json_str=$(cat <<EOF
{
"split_results_path": "$GRAPH_NET_ROOT/split_results.json",
"workspace_path": "$GRAPH_NET_ROOT/decompose_workspace",
"chain_style": "True"
}
EOF
)
DECOMPOSER_CONFIG=$(echo "$decomposer_config_json_str" | base64 -w 0)

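# Decompose the sample model at the computed split points.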
python3 -m graph_net.torch.test_compiler \
--model-path "$GRAPH_NET_ROOT/samples/$MODEL_PATH_IN_SAMPLES" \
--compiler range_decomposer \
--device cuda \
--config="$DECOMPOSER_CONFIG"


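# Copy the original sample into the decompose workspace and run the
# range_decomposer_validator backend on it, capturing the log.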
DECOMPOSE_PATH=$GRAPH_NET_ROOT/decompose_workspace
cp -r "$GRAPH_NET_ROOT/samples/$MODEL_PATH_IN_SAMPLES" "$DECOMPOSE_PATH/"

python3 -m graph_net.torch.test_compiler \
--model-path $DECOMPOSE_PATH/$MODEL_NAME \
--compiler range_decomposer_validator \
--device cuda > "$DECOMPOSE_PATH/log.log" 2>&1

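# Plot results parsed from the validation log.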
python3 -m graph_net.plot_ESt \
--benchmark-path "$DECOMPOSE_PATH/log.log" \
--output-dir "$DECOMPOSE_PATH"
94 changes: 94 additions & 0 deletions graph_net/torch/backend/range_decomposer_backend.py
@@ -0,0 +1,94 @@
import base64
import json
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict

import torch
import graph_net


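# Helpers for decoding/encoding the base64-encoded JSON config passed via test_compiler's --config flag.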
def convert_to_dict(config_str):
    if config_str is None:
        return {}
    config_str = base64.b64decode(config_str).decode("utf-8")
    config = json.loads(config_str)
    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
    return config


def encode_config(config: Dict[str, Any]) -> str:
    json_str = json.dumps(config)
    return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")


def load_json(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        data_dict = json.load(file)
    return data_dict


class RangeDecomposerBackend:
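    """Decompose a captured model into subgraphs at precomputed split points.

    `self.config` is a base64-encoded JSON string (attached by test_compiler via
    --config) providing "split_results_path", "workspace_path", and "chain_style".
    The backend re-runs the original model through graph_net.torch.run_model with
    the naive graph decomposer attached, writing the subgraphs into the workspace.
    """
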
    def __init__(self):
        self.graph_net_root = Path(graph_net.__file__).parent

    def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
        config = convert_to_dict(self.config)
        workspace_path = Path(config["workspace_path"])
        chain_style = config["chain_style"]

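        # Derive the model name from the file path recorded on the extracted model class.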
        model_file_path = Path(model.__class__.__graph_net_file_path__)
        model_name = model_file_path.parent.name

        model_info = load_json(config["split_results_path"])[model_name]
        model_path = model_info["path"]
        split_points = model_info["split_points"]

        model_output_dir = workspace_path / f"{model_name}_decomposed"
        model_output_dir.mkdir(parents=True, exist_ok=True)

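        # Decorator config for graph_net.torch.run_model: extract the model and
        # split it at the computed positions with the naive graph decomposer.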
        config_dict = {
            "decorator_path": str(self.graph_net_root / "torch/extractor.py"),
            "decorator_config": {
                "name": model_name,
                "custom_extractor_path": str(
                    self.graph_net_root / "torch/naive_graph_decomposer.py"
                ),
                "custom_extractor_config": {
                    "output_dir": str(model_output_dir),
                    "split_positions": split_points,
                    "group_head_and_tail": True,
                    "filter_path": str(
                        self.graph_net_root / "torch/naive_subgraph_filter.py"
                    ),
                    "filter_config": {},
                    "chain_style": chain_style,
                },
            },
        }

        encoded_config = encode_config(config_dict)

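        # Re-run the original model in a subprocess with the extractor decorator
        # attached so the decomposed subgraphs are written to model_output_dir.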
        cmd = [
            sys.executable,
            "-m",
            "graph_net.torch.run_model",
            "--model-path",
            model_path,
            "--decorator-config",
            encoded_config,
        ]

        try:
            subprocess.run(cmd, check=True)
            print(f"[Success] Saved to {model_output_dir}")
        except subprocess.CalledProcessError as e:
            print(f"[Error] Process failed: {e}")
        except Exception as e:
            print(f"[Error] Unexpected: {e}")
        return model

    def synchronize(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
2 changes: 2 additions & 0 deletions graph_net/torch/rp_expr/longest_rp_expr_parser.py
@@ -30,6 +30,8 @@ def __call__(self, primitive_id_lists: t.List[t.List[PrimitiveId]]):
                    token_id2primitive_id
                )
            ]
            if not cur_primitive_id_lists:
                continue
            cur_lets_list_rp_expr, cur_token_id2primitive_id = rp_expr_parser(
                cur_primitive_id_lists
            )
2 changes: 1 addition & 1 deletion graph_net/torch/rp_expr/rp_expr.py
@@ -369,7 +369,7 @@ def get_range(size):
        segments = [
            token_tensor[start:end]
            for consecutive_tensor in consecutive_tensors
            for start, end in [get_range(int(consecutive_tensor.size(0)))]
            for start, end in [get_range(len(consecutive_tensor))]
        ]

        return segments
14 changes: 13 additions & 1 deletion graph_net/torch/test_compiler.py
@@ -23,6 +23,7 @@
from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
from graph_net.torch.backend.nope_backend import NopeBackend
from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
from graph_net.torch.backend.range_decomposer_backend import RangeDecomposerBackend
from graph_net.torch.backend.range_decomposer_validator_backend import (
    RangeDecomposerValidatorBackend,
)
@@ -39,6 +40,7 @@
"bladedisc": BladeDISCBackend(),
"nope": NopeBackend(),
"unstable_to_stable": UnstableToStableBackend(),
"range_decomposer": RangeDecomposerBackend(),
"range_decomposer_validator": RangeDecomposerValidatorBackend(),
}

@@ -94,7 +96,10 @@ def load_class_from_file(

def get_compiler_backend(args) -> GraphCompilerBackend:
    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
    return registry_backend[args.compiler]
    backend = registry_backend[args.compiler]
    if args.config is not None:
        backend.config = args.config
    return backend


def get_model(args):
@@ -447,5 +452,12 @@ def main(args):
        default=None,
        help="Path to samples list, each line contains a sample path",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=False,
        default=None,
        help="Base64-encoded configuration JSON.",
    )
    args = parser.parse_args()
    main(args=args)