In [7]:
!pip install onnx


Collecting onnx
  Downloading onnx-1.19.1-cp311-cp311-win_amd64.whl.metadata (7.2 kB)
Downloading onnx-1.19.1-cp311-cp311-win_amd64.whl (16.5 MB)
   ---------------------------------------- 0.0/16.5 MB ? eta -:--:--
   -- ------------------------------------- 1.0/16.5 MB 12.5 MB/s eta 0:00:02
   ------- -------------------------------- 3.1/16.5 MB 10.8 MB/s eta 0:00:02
   -------------- ------------------------- 5.8/16.5 MB 11.4 MB/s eta 0:00:01
   -------------------- ------------------- 8.4/16.5 MB 11.5 MB/s eta 0:00:01
   -------------------------- ------------- 10.7/16.5 MB 11.6 MB/s eta 0:00:01
   ------------------------------- -------- 13.1/16.5 MB 11.3 MB/s eta 0:00:01
   ------------------------------------ --- 14.9/16.5 MB 11.0 MB/s eta 0:00:01
   ---------------------------------------- 16.5/16.5 MB 10.6 MB/s  0:00:01
Installing collected packages: onnx
Successfully installed onnx-1.19.1


In [None]:
# ===== Pix2Pix G → ONNX (and optional TorchScript) from checkpoints folder =====
import re, sys
from pathlib import Path
import torch

# ---- Default paths (adjust to your setup if needed)
REPO_DIR   = Path("pix2pix")  # source code folder (contains models/, options/, util/, etc.)
EXP_NAME   = "wafer_pix2pix_AtoB_256_out1"
CKPT_DIR   = Path("checkpoints") / EXP_NAME
G_WEIGHTS  = CKPT_DIR / "latest_net_G.pth"
TRAIN_OPT  = CKPT_DIR / "train_opt.txt"
OUT_DIR    = CKPT_DIR  # output folder for exported files

# ---- Load network modules from the project
# Explicit exceptions instead of `assert` so the checks can't be optimised away.
if not REPO_DIR.is_dir():
    raise FileNotFoundError(f"Source code folder not found: {REPO_DIR.resolve()}")

# Put the repo first on sys.path, removing any earlier copy first so that
# re-running this cell does not keep prepending duplicate entries.
repo_path = str(REPO_DIR.resolve())
sys.path[:] = [p for p in sys.path if p != repo_path]
sys.path.insert(0, repo_path)
from models import networks  # from pix2pix repo

if not G_WEIGHTS.is_file():
    raise FileNotFoundError(f"Generator weights not found: {G_WEIGHTS}")
if not TRAIN_OPT.is_file():
    raise FileNotFoundError(f"train_opt.txt not found: {TRAIN_OPT}")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ---- Parse train_opt.txt to recover critical configurations
def parse_train_opt(path: Path):
    """Parse a pix2pix-style ``train_opt.txt`` into a flat ``{option: value}`` dict.

    Relevant lines have the form ``key: value``. Lines are skipped when they:
    - are blank,
    - start with ``-`` (the ``---- Options ----`` / ``---- End ----`` banners), or
    - contain no ``:``.

    A trailing ``[default: ...]`` note is removed before conversion, so
    ``"1 [default: 3]"`` parses as ``1``. The upstream option printer appends
    this note when a value differs from its default (verify against the repo's
    ``options/base_options.py``).

    Values are converted when they match a literal form:
    - ``int`` (e.g. ``64``),
    - ``float`` (e.g. ``0.0002``, ``1e-05``),
    - ``bool`` (``True``/``False``, case-insensitive).
    Everything else is kept as a whitespace-normalised string.

    Args:
        path: Options file. Read as UTF-8; undecodable bytes are ignored.

    Returns:
        dict mapping option names to converted values (later duplicates win).
    """
    default_note = re.compile(r"\s*\[default:.*\]$")
    int_form = re.compile(r"-?\d+")
    float_form = re.compile(r"-?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?")

    cfg = {}
    for raw_line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("-") or ":" not in line:
            continue
        # Split on the first ':' only, so values such as paths/URLs keep theirs.
        k, v = (p.strip() for p in line.split(":", 1))
        # Collapse runs of spaces/tabs, then drop any "[default: ...]" note.
        v_clean = default_note.sub("", re.sub(r"\s+", " ", v)).strip()
        if int_form.fullmatch(v_clean):
            val = int(v_clean)
        elif float_form.fullmatch(v_clean):
            val = float(v_clean)
        elif v_clean.lower() in ("true", "false"):
            val = v_clean.lower() == "true"
        else:
            val = v_clean
        cfg[k] = val
    return cfg

opt = parse_train_opt(path=TRAIN_OPT)  # flat {option_name: value} dict recovered from the training run

def as_int(d, key, default):
    """Return ``d[key]`` as an ``int``, or ``default`` if it is absent or not an integer.

    A trailing ``[default: ...]`` note (as written into pix2pix option dumps)
    is ignored, so ``"1 [default: 3]"`` yields ``1`` instead of silently
    falling back to ``default``.

    Args:
        d: Mapping of option names to values (ints or strings).
        key: Option name to look up.
        default: Returned when ``key`` is missing or its value is not an integer.
            It is also passed through ``int()`` when used, as in the original lookup.

    Returns:
        The integer value, or ``default``. Prints a warning when ``key`` is
        present but unparseable, so a misread config is not silently masked.
    """
    raw = d.get(key, default)
    text = re.sub(r"\s*\[default:.*\]\s*$", "", str(raw)).strip()
    try:
        return int(text)
    except ValueError:
        if key in d:
            print(f"⚠️ Could not parse option {key}={raw!r} as int; using default {default!r}")
        return default

# train_opt.txt may carry "[default: ...]" notes after non-default values
# (e.g. "resnet_9blocks [default: unet_256]"). Strip them once here so every
# lookup below sees the bare value; this is a no-op if already stripped.
_default_note = re.compile(r"\s*\[default:.*\]\s*$")
opt_clean = {
    key: _default_note.sub("", val).strip() if isinstance(val, str) else val
    for key, val in opt.items()
}


def _as_bool(value):
    """Interpret a parsed option as bool; a string counts as True only if it reads 'true'."""
    # bool("False") would be True, so strings must be compared explicitly.
    if isinstance(value, str):
        return value.strip().lower() == "true"
    return bool(value)


input_nc   = as_int(opt_clean, "input_nc", 3)
output_nc  = as_int(opt_clean, "output_nc", 1)   # fallback only used if the key is missing
ngf        = as_int(opt_clean, "ngf", 64)
netG       = str(opt_clean.get("netG", "unet_256")).strip()
norm       = str(opt_clean.get("norm", "batch")).strip()   # e.g. 'batch' or 'instance'
no_dropout = _as_bool(opt_clean.get("no_dropout", False))
use_dropout = not no_dropout

print("Recovered train options for G:", {
    "input_nc": input_nc, "output_nc": output_nc, "ngf": ngf,
    "netG": netG, "norm": norm, "use_dropout": use_dropout
})

# ---- Build the generator with the architecture flags recovered above.
# gpu_ids is deliberately not passed; placement is done explicitly with .to(device).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
g_kwargs = {
    "input_nc": input_nc,
    "output_nc": output_nc,
    "ngf": ngf,
    "netG": netG,
    "norm": norm,
    "use_dropout": use_dropout,
    # Init settings only matter for parameters the checkpoint does not overwrite.
    "init_type": "normal",
    "init_gain": 0.02,
}
netG_model = networks.define_G(**g_kwargs)
netG_model = netG_model.to(device)
# Switch BatchNorm/Dropout layers to inference behaviour before export.
netG_model.eval()

# ---- Load weights (clean common prefixes)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state = torch.load(G_WEIGHTS, map_location=device)
if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
    # Some checkpoints wrap the weights inside a 'state_dict' entry.
    state = state["state_dict"]

# Strip a leading 'module.' (added when training with DataParallel). Only the
# prefix is removed, so keys containing 'module.' elsewhere stay intact.
_DP_PREFIX = "module."
if isinstance(state, dict) and any(k.startswith(_DP_PREFIX) for k in state):
    state = {
        (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k): v
        for k, v in state.items()
    }

try:
    netG_model.load_state_dict(state, strict=True)
except RuntimeError as e:
    print("Strict load failed, trying non-strict:", e)
    load_result = netG_model.load_state_dict(state, strict=False)
    # Non-strict loading leaves missing parameters at their random
    # initialisation; surface that before anything gets exported.
    if load_result.missing_keys:
        print(f"⚠️ {len(load_result.missing_keys)} missing keys (left at init values):",
              load_result.missing_keys[:10])
    if load_result.unexpected_keys:
        print(f"⚠️ {len(load_result.unexpected_keys)} unexpected keys (ignored):",
              load_result.unexpected_keys[:10])

# ---- Dummy input according to model input (H,W typically 256; can be changed)
# NOTE(review): H/W are exported as dynamic axes, but for netG='unet_256' the
# spatial size presumably has to be a multiple of 256. Confirm before feeding
# other sizes to the exported model.
DUMMY_HW = 256
dummy = torch.randn(1, input_nc, DUMMY_HW, DUMMY_HW, device=device)

onnx_path = OUT_DIR / f"{EXP_NAME}_G.onnx"
ts_path = OUT_DIR / f"{EXP_NAME}_G.torchscript.pt"

# Export/tracing only needs forward passes; no_grad avoids recording an
# autograd graph for the dummy input.
with torch.no_grad():
    # ---- Export to ONNX (viewable in Netron)
    torch.onnx.export(
        netG_model,
        dummy,
        onnx_path.as_posix(),
        input_names=["segmentation_A"],
        output_names=["predicted_SEM_B"],
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        dynamic_axes={
            "segmentation_A": {0: "batch", 2: "height", 3: "width"},
            "predicted_SEM_B": {0: "batch", 2: "height", 3: "width"},
        },
    )
    print(f"✅ ONNX export completed: {onnx_path}")

    # ---- (Optional) Export TorchScript — useful for deployment/debugging
    traced = torch.jit.trace(netG_model, dummy)
    traced.save(ts_path.as_posix())
    print(f"✅ TorchScript export completed: {ts_path}")

print("\nTo view in Netron: drag", onnx_path.name, "into https://netron.app (or use File → Open).")


Recovered train options for G: {'input_nc': 3, 'output_nc': 1, 'ngf': 64, 'netG': 'unet_256', 'norm': 'batch', 'use_dropout': True}
✅ ONNX export completed: checkpoints\wafer_pix2pix_AtoB_256_out1\wafer_pix2pix_AtoB_256_out1_G.onnx
✅ TorchScript export completed: checkpoints\wafer_pix2pix_AtoB_256_out1\wafer_pix2pix_AtoB_256_out1_G.torchscript.pt

To view in Netron: drag wafer_pix2pix_AtoB_256_out1_G.onnx into https://netron.app (or use File → Open).
