Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions graph_net/torch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import json
import shutil
import glob
from graph_net.torch import utils
from graph_net.torch.fx_graph_serialize_util import serialize_graph_module_to_str

Expand Down Expand Up @@ -72,12 +73,28 @@ def move_files(self, source_dir, target_dir):
target_path = os.path.join(target_dir, item)
shutil.move(source_path, target_path)

def _cleanup_stale_data(self, model_path):
for stale_dir in glob.glob(os.path.join(model_path, "subgraph_*")):
shutil.rmtree(stale_dir)
for stale_file_name in (
"model.py",
"graph_net.json",
"input_meta.py",
"weight_meta.py",
"input_tensor_constraints.py",
"graph_hash.txt",
):
stale_file = os.path.join(model_path, stale_file_name)
if os.path.isfile(stale_file):
os.remove(stale_file)

def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
# 1. Get model path
model_path = os.path.join(self.workspace_path, self.name)
os.makedirs(model_path, exist_ok=True)

if self.subgraph_counter == 0:
self._cleanup_stale_data(model_path)
subgraph_path = model_path
else:
if self.subgraph_counter == 1:
Expand Down Expand Up @@ -124,17 +141,30 @@ def try_rename_placeholder(node):
gm.graph.erase_node(node)

assert input_idx == len(sample_inputs)

# 3. Serialize graph
base_code = serialize_graph_module_to_str(gm)

if self.mut_graph_codes is not None:
assert isinstance(self.mut_graph_codes, list)
self.mut_graph_codes.append(serialize_graph_module_to_str(gm))
# 3. Generate and save model code
base_code = serialize_graph_module_to_str(gm)
# gm.graph.print_tabular()
self.mut_graph_codes.append(base_code)

# 4. Save tensor metadata
converted = utils.convert_state_and_inputs(params, [])
utils.save_converted_to_text(converted, file_path=subgraph_path)
utils.save_constraints_text(
converted,
file_path=os.path.join(subgraph_path, "input_tensor_constraints.py"),
)

# 5. Save model code
write_code = utils.apply_templates(base_code)
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
fp.write(write_code)

# 4. Save metadata
# 6. Save metadata LAST — graph_net.json serves as the
# completion marker: if it exists, all other files are guaranteed
# to be fully written.
metadata = {
"framework": "torch",
"num_devices_required": 1,
Expand All @@ -145,15 +175,6 @@ def try_rename_placeholder(node):
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
json.dump(metadata, f, indent=4)

# 5. Save tensor metadata
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
converted = utils.convert_state_and_inputs(params, [])
utils.save_converted_to_text(converted, file_path=subgraph_path)
utils.save_constraints_text(
converted,
file_path=os.path.join(subgraph_path, "input_tensor_constraints.py"),
)

print(
f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}"
)
Expand Down
Loading