Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions graph_net/sample_pass/agent_unittest_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
{%- endif -%}
{{"\n"}}
import torch
from torch import device
from torch import device, inf


{% macro get_input_tensor_instance(tensor_meta, device) -%}
Expand All @@ -36,7 +36,7 @@
{%- if data is not none -%}
torch.tensor({{data}}, dtype={{dtype}}).reshape({{shape}}).to(device='{{device}}')
{%- elif dtype == "torch.bool" -%}
torch.rand({{shape}}, device={{device}}) > 0.5
torch.rand({{shape}}, device='{{device}}') > 0.5
{%- elif dtype in ["torch.int8", "torch.int16", "torch.int32", "torch.int64"] -%}
torch.randint({{min_val}}, {{max_val}} + 1, size={{shape}}, dtype={{dtype}}).to(device='{{device}}')
{%- elif dtype in ["torch.float16", "torch.bfloat16", "torch.float32", "torch.float64"] -%}
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_main(self):
{%- if data is not none -%}
paddle.to_tensor({{data}}, dtype='{{dtype}}', shape={{shape}}).to(device='{{device}}')
{%- elif dtype == "bool" -%}
paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}')
paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}')
{%- elif dtype in ["int8", "int16", "int32", "int64"] -%}
paddle.randint(low={{min_val}}, high={{max_val}} + 1, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}')
{%- elif dtype in ["float16", "bfloat16", "float32", "float64"] -%}
Expand Down Expand Up @@ -456,13 +456,8 @@ def __call__(self, rel_model_path: str):
self.resumable_handle_sample(rel_model_path)

def sample_handled(self, rel_model_path: str) -> bool:
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
if not dst_model_path.exists():
return False
output_name = self._get_output_name(rel_model_path)
num_model_py_files = len(list(dst_model_path.rglob(output_name)))
assert num_model_py_files <= 1
return num_model_py_files == 1
return self.naive_sample_handled(rel_model_path, search_file_name=output_name)

def _get_output_name(self, rel_model_path: str):
return f"{Path(rel_model_path).name}_test.py"
Expand Down
4 changes: 2 additions & 2 deletions graph_net/torch/graph_variable_renamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def __call__(self, rel_model_path):
src_model_path, temp_model_path, rename_map
)
self._update_input_meta_py_file(src_model_path, temp_model_path, rename_map)
# print("Try to run renamed model...")
# self._try_run(temp_model_path)
self._try_run(temp_model_path)
shutil.copytree(temp_model_path, dst_model_path)

def _try_run(self, model_path):
print(f"[GraphVariableRenamer] Try to run {model_path}")
assert self.model_runnable_predicator(
model_path
), f"{model_path} is not a runnable model"
Expand Down