This notebook was executed in the Kaggle environment  
using `/kaggle/input` datasets and saving outputs to `/kaggle/working`.

In [None]:
# ============================================================
# Pix2Pix on Kaggle with Auto-Resume via Private Kaggle Dataset (FIXED)
# - Handles initial 404 by polling for dataset readiness
# - Removes deprecated --display_id flag
# ============================================================
import os, sys, json, time, shlex, subprocess, threading, zipfile, glob, re
from pathlib import Path

# ---------- USER CONFIG ----------
# Source repository for train.py and the directory it is cloned into.
GIT_URL = "https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git"
REPO = Path("/kaggle/working/pix2pix")

# Input dataset (must contain train/val/test) and checkpoint locations.
DATAROOT = Path("/kaggle/input/processed-images")
CKPT_ROOT = Path("/kaggle/working/checkpoints")
EXPERIMENT = "wafer_pix2pix_AtoB_256_out1"
EXP_DIR = CKPT_ROOT.joinpath(EXPERIMENT)  # per-experiment checkpoint folder

# Training (each value is forwarded to train.py as the matching --flag)
LOAD_SIZE, CROP_SIZE = 286, 256
INPUT_NC, OUTPUT_NC = 3, 1
BATCH_SIZE = 4
LR = 0.0002
BETA1 = 0.5
N_EPOCHS, N_EPOCHS_DECAY = 100, 50
SAVE_EPOCH_FREQ = 5
SAVE_LATEST_FREQ = 2000
DIRECTION = "AtoB"

# Kaggle Dataset autosync
KDS_SLUG = "wafer-pix2pix-checkpoints"  # dataset slug under the user's account
KDS_POLL_MIN = 8                        # minimum minutes between automatic uploads
KDS_UPLOAD_DIR = Path("/kaggle/working/_kds_upload")   # staging dir for uploads
KDS_DL_DIR = Path("/kaggle/working/_kds_download")     # scratch dir for downloads

def run(cmd, check=True, capture=False):
    """Echo a command and execute it with ``subprocess.run``.

    Args:
        cmd: Command to execute, either a string or a sequence of arguments.
            Sequences are shell-quoted for the echoed line only.
        check: Forwarded to ``subprocess.run``; a non-zero exit status raises
            ``subprocess.CalledProcessError`` when true.
        capture: When true, stdout and stderr are merged and returned as text.

    Returns:
        The captured output string when ``capture`` is true, otherwise the
        ``subprocess.CompletedProcess`` instance.
    """
    if isinstance(cmd, str):
        shown = cmd
    else:
        shown = " ".join(map(shlex.quote, cmd))
    print("[$]", shown)

    if not capture:
        return subprocess.run(cmd, check=check)

    completed = subprocess.run(
        cmd,
        check=check,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # fold stderr into the returned text
        text=True,
    )
    return completed.stdout

# --- Data sanity ---
# The dataset root and each of the three split folders must exist, and every
# split must contain at least one entry whose name has a dot in it.
assert DATAROOT.exists(), f"Dataset folder not found: {DATAROOT}"
for split in ("train", "val", "test"):
    p = DATAROOT.joinpath(split)
    assert p.exists(), f"Missing subfolder {split}"
    assert any(p.rglob('*.*')), f"No images found in {p}"

# --- Repo ---
# Clone only when the target directory is absent, so re-running the cell is cheap.
if not REPO.exists():
    run(["git", "clone", "--depth", "1", GIT_URL, str(REPO)])

# --- Kaggle CLI present ---
# Install the package when the `pip show` output carries no "Version:" line.
out = run([sys.executable, "-m", "pip", "show", "kaggle"], check=False, capture=True)
if "Version:" not in out:
    run([sys.executable, "-m", "pip", "install", "-q", "kaggle"])

# --- Token ---
# Restrict the credentials file to owner read/write when it is present.
KAGGLE_DIR = Path.home().joinpath(".kaggle")
KAGGLE_JSON = KAGGLE_DIR.joinpath("kaggle.json")
if KAGGLE_JSON.exists():
    KAGGLE_JSON.chmod(0o600)

def get_kaggle_username():
    """Return the Kaggle username, or ``None`` when none is configured.

    The ``username`` field of ``KAGGLE_JSON`` takes precedence. When the file
    is missing, unreadable, not valid JSON, not a JSON object, or has an empty
    username, the ``KAGGLE_USERNAME`` environment variable is used instead.

    Returns:
        The username string, or ``None`` if neither source provides one.
    """
    cfg = None
    try:
        cfg = json.loads(KAGGLE_JSON.read_text())
    except FileNotFoundError:
        # No credentials file is a normal situation; fall through to the env var.
        pass
    except (OSError, ValueError) as e:
        # ValueError covers both malformed JSON and undecodable bytes.
        print(f"⚠️ Could not read Kaggle credentials from {KAGGLE_JSON}: {e}")

    # Guard against valid JSON that is not an object (e.g. a list or a string).
    if isinstance(cfg, dict) and cfg.get("username"):
        return cfg["username"]
    return os.environ.get("KAGGLE_USERNAME")

# Resolve the account once. KDS_ID is the "<user>/<slug>" dataset reference,
# or None when no username could be determined.
KAGGLE_USER = get_kaggle_username()
if KAGGLE_USER:
    KDS_ID = f"{KAGGLE_USER}/{KDS_SLUG}"
else:
    KDS_ID = None

def ensure_kds_exists():
    """Make sure the private checkpoint dataset exists, creating it if needed.

    The user's own datasets are listed with ``--mine`` and searched for the
    ``<user>/<slug>`` reference. The previous ``-u <user> -p 50`` invocation
    could not work: per the Kaggle CLI docs ``-p`` is the *page number* of
    ``datasets list`` (so page 50 was requested) and ``--user`` only lists
    public datasets.

    When the dataset is not found, a minimal dataset (README plus metadata) is
    created from ``KDS_UPLOAD_DIR``. Creation is best-effort: it also fails
    when the dataset already exists but was missed by the listing (only one
    page of results is inspected), so its outcome is reported, not enforced.

    Returns:
        ``False`` when Kaggle credentials or the username are unavailable,
        ``True`` otherwise (dataset found, or creation attempted).
    """
    if not (KAGGLE_JSON.exists() and KAGGLE_USER):
        return False

    dataset_ref = f"{KAGGLE_USER}/{KDS_SLUG}"
    exists = False
    try:
        info = run(["kaggle", "datasets", "list", "--mine"], check=False, capture=True)
        # Compare whole whitespace-separated tokens so that a longer slug
        # sharing this prefix is not mistaken for the dataset.
        exists = any(dataset_ref in line.split() for line in info.splitlines())
    except (OSError, ValueError) as e:
        print("⚠️ Could not list Kaggle datasets:", e)
    if exists:
        return True

    # create minimal private dataset
    meta_dir = KDS_UPLOAD_DIR
    meta_dir.mkdir(parents=True, exist_ok=True)
    (meta_dir / "README.md").write_text("# wafer pix2pix checkpoints\n")
    metadata = {
        "title": "Wafer Pix2Pix Checkpoints",
        "id": dataset_ref,
        "licenses": [{"name": "CC0-1.0"}],
        "isPrivate": True,
    }
    (meta_dir / "dataset-metadata.json").write_text(json.dumps(metadata))
    created = run(["kaggle", "datasets", "create", "-p", str(meta_dir), "-q"], check=False)
    if created.returncode != 0:
        print(f"ℹ️ 'kaggle datasets create' exited with code {created.returncode} "
              "(the dataset may already exist).")
    return True

# Serializes uploads: kds_version_upload() is called both from the background
# uploader thread and from the main thread's final sync, and every call wipes
# and rewrites the same staging directory (KDS_UPLOAD_DIR).
_KDS_UPLOAD_LOCK = threading.Lock()

def kds_version_upload(note="auto"):
    """Zip ``EXP_DIR`` and publish it as a new version of the Kaggle dataset.

    The staging directory is emptied of files, the experiment folder is packed
    into ``<EXPERIMENT>_checkpoints.zip`` (archive paths are relative to
    ``CKPT_ROOT``), the dataset metadata is written if absent, and
    ``kaggle datasets version`` is invoked on the staging directory.

    The whole operation is best-effort: any failure while staging, packing or
    uploading is reported and turned into a ``False`` return instead of an
    exception, so a checkpoint file changing mid-pack cannot abort the caller.

    Args:
        note: Version message passed to the CLI via ``-m``.

    Returns:
        ``True`` when the upload command succeeded, ``False`` when credentials
        are unavailable or any step failed.
    """
    if not (KAGGLE_JSON.exists() and KAGGLE_USER):
        return False

    with _KDS_UPLOAD_LOCK:
        try:
            KDS_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
            # clean temp
            for p in KDS_UPLOAD_DIR.glob("*"):
                if p.is_file():
                    p.unlink()
            # pack EXP_DIR
            zip_path = KDS_UPLOAD_DIR / f"{EXPERIMENT}_checkpoints.zip"
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                for f in EXP_DIR.rglob("*"):
                    if f.is_file():
                        zf.write(f, arcname=str(f.relative_to(CKPT_ROOT)))
            # ensure metadata exists
            meta_json = KDS_UPLOAD_DIR / "dataset-metadata.json"
            if not meta_json.exists():
                meta = {
                    "title": "Wafer Pix2Pix Checkpoints",
                    "id": f"{KAGGLE_USER}/{KDS_SLUG}",
                    "licenses": [{"name": "CC0-1.0"}],
                    "isPrivate": True,
                }
                meta_json.write_text(json.dumps(meta))
            run(["kaggle", "datasets", "version", "-p", str(KDS_UPLOAD_DIR),
                 "-m", note, "-r", "zip", "-q"], check=True)
        except Exception as e:
            # Deliberately broad: this is the best-effort boundary of the sync.
            print("⚠️ Failed to upload to Kaggle Dataset:", e)
            return False

    print("✅ Upload to Kaggle Dataset completed.")
    return True

def kds_has_version():
    """Return True if the dataset exists and reports a downloadable state.

    Runs ``kaggle datasets status <KDS_ID>`` and inspects its output. The
    ``-v`` flag used previously is not an option of the ``status`` subcommand
    (per the Kaggle CLI docs), so the CLI answered with a usage error and this
    check could never succeed, which disabled resume entirely.

    Returns:
        ``True`` when the status output contains the word ``ready`` (or lists
        files without an error), ``False`` when credentials are unavailable,
        the command cannot be run, or the status is anything else.
    """
    if not (KAGGLE_JSON.exists() and KAGGLE_USER):
        return False
    try:
        status = run(["kaggle", "datasets", "status", KDS_ID], check=False, capture=True)
    except (OSError, ValueError) as e:
        print("⚠️ Could not query Kaggle dataset status:", e)
        return False

    text = status.lower()
    # Whole-word match so that messages containing e.g. "already" do not count.
    is_ready = re.search(r"\bready\b", text) is not None
    # If there's at least one file listed – it's considered ready
    lists_files = "files:" in text and "error" not in text
    return lists_files or is_ready

def kds_download_latest(max_tries=6, sleep_sec=20):
    """Download the latest dataset version and restore it into ``CKPT_ROOT``.

    Each attempt downloads the dataset archive into ``KDS_DL_DIR``, extracts
    every zip found there into ``KDS_DL_DIR`` — including archives that only
    appear after an outer archive was unpacked, such as the
    ``<EXPERIMENT>_checkpoints.zip`` produced by the upload step — and then
    copies all extracted files (archives excluded) into ``CKPT_ROOT`` while
    preserving their relative paths.

    Args:
        max_tries: Maximum number of download attempts.
        sleep_sec: Seconds to wait between failed attempts.

    Returns:
        ``True`` when a version was downloaded and restored; ``False`` when
        credentials are unavailable, no version exists yet, or every attempt
        failed.
    """
    # Function-scope import: shutil is not part of the file's import block.
    import shutil

    if not (KAGGLE_JSON.exists() and KAGGLE_USER):
        return False
    if not kds_has_version():
        print("ℹ️ Dataset exists but likely has no version yet. Continuing without download (first run).")
        return False

    KDS_DL_DIR.mkdir(parents=True, exist_ok=True)
    for attempt in range(1, max_tries + 1):
        try:
            run(["kaggle", "datasets", "download", "-d", KDS_ID, "-p", str(KDS_DL_DIR), "-q"], check=True)

            # unzip all; rescan after every pass so nested archives are
            # unpacked too. `unpacked` guarantees each archive is handled once.
            unpacked = set()
            while True:
                pending = [z for z in KDS_DL_DIR.rglob("*.zip") if z not in unpacked]
                if not pending:
                    break
                for z in pending:
                    with zipfile.ZipFile(z) as zf:
                        zf.extractall(KDS_DL_DIR)
                    unpacked.add(z)

            # restore into CKPT_ROOT
            CKPT_ROOT.mkdir(parents=True, exist_ok=True)
            for f in KDS_DL_DIR.rglob("*"):
                if not f.is_file() or f.suffix.lower() == ".zip":
                    continue
                dst = CKPT_ROOT / f.relative_to(KDS_DL_DIR)
                dst.parent.mkdir(parents=True, exist_ok=True)
                # Streamed copy: checkpoints need not fit in memory at once.
                shutil.copyfile(f, dst)
            print("✅ Downloaded checkpoints from latest version.")
            return True
        except Exception as e:
            # Deliberately broad: any failure (CLI error, bad archive, I/O)
            # just triggers another attempt.
            print("…Waiting for version to be published (or skip if first run). Error:", str(e)[:120])
            if attempt < max_tries:
                time.sleep(sleep_sec)
    print("⚠️ No version downloaded (probably first run). Continuing without resume.")
    return False

# ---------- ensure dataset & try download ----------
EXP_DIR.mkdir(parents=True, exist_ok=True)
if ensure_kds_exists():
    kds_download_latest()

# ---------- resume flags ----------
# Resume only when a "latest" generator checkpoint is present in EXP_DIR.
continue_train = (EXP_DIR / "latest_net_G.pth").exists()
epoch_count_arg = None

# Preferred source: first token of iter.txt is the last finished epoch.
iter_txt = EXP_DIR / "iter.txt"
try:
    t = iter_txt.read_text().strip().split()
    if t and t[0].isdigit():
        epoch_count_arg = str(max(1, int(t[0]) + 1))
except FileNotFoundError:
    pass  # no iter.txt: try the checkpoint-name fallback below
except (OSError, ValueError) as e:
    # ValueError: undecodable bytes, or a "digit" string int() cannot parse.
    print(f"⚠️ Could not read resume epoch from {iter_txt}: {e}")

# Fallback: derive the epoch from numbered generator checkpoints.
# NOTE(review): assumes train.py saves per-epoch weights as "<epoch>_net_G.pth"
# next to latest_net_G.pth — confirm against the cloned repo. When no such
# files exist this is a no-op and training restarts its epoch counter as before.
if continue_train and epoch_count_arg is None:
    _suffix = "_net_G.pth"
    _prefixes = (f.name[:-len(_suffix)] for f in EXP_DIR.glob(f"*{_suffix}"))
    _saved_epochs = [int(s) for s in _prefixes if s.isascii() and s.isdigit()]
    if _saved_epochs:
        epoch_count_arg = str(max(_saved_epochs) + 1)

# ---------- build train cmd (NO --display_id) ----------
train_py = str(REPO / "train.py")

# Flag/value pairs in the order they are passed to train.py.
train_options = {
    "--dataroot": str(DATAROOT),
    "--name": EXPERIMENT,
    "--model": "pix2pix",
    "--dataset_mode": "aligned",
    "--direction": DIRECTION,
    "--checkpoints_dir": str(CKPT_ROOT),
    "--preprocess": "resize_and_crop",
    "--load_size": str(LOAD_SIZE),
    "--crop_size": str(CROP_SIZE),
    "--input_nc": str(INPUT_NC),
    "--output_nc": str(OUTPUT_NC),
    "--batch_size": str(BATCH_SIZE),
    "--n_epochs": str(N_EPOCHS),
    "--n_epochs_decay": str(N_EPOCHS_DECAY),
    "--lr": str(LR),
    "--beta1": str(BETA1),
    "--gan_mode": "vanilla",
    "--lambda_L1": "100",
    "--save_epoch_freq": str(SAVE_EPOCH_FREQ),
    "--save_latest_freq": str(SAVE_LATEST_FREQ),
    "--print_freq": "100",
}

cmd = [sys.executable, train_py]
for flag, value in train_options.items():
    cmd.extend((flag, value))

# Resume flags are appended only when a previous checkpoint was found;
# --epoch_count is added on top when a starting epoch is known.
if continue_train:
    cmd.append("--continue_train")
    if epoch_count_arg:
        cmd.extend(["--epoch_count", epoch_count_arg])

print(">>> TRAIN CMD:\n", " ".join(map(shlex.quote, cmd)))

# ---------- autosync thread ----------
stop_flag = False
def periodic_kds_uploader():
    """Background loop that uploads EXP_DIR whenever its contents change.

    Roughly every 30 seconds a signature is built from the relative path, size
    and whole-second mtime of every file under ``EXP_DIR``. A new dataset
    version is uploaded when the signature differs from the last uploaded one
    and more than ``KDS_POLL_MIN`` minutes have passed since the last
    successful upload.

    Nothing is uploaded while ``EXP_DIR`` contains no files, so an empty
    archive is never published as the newest dataset version.

    The loop ends once the module-level ``stop_flag`` becomes true; the wait
    between polls is split into one-second slices so that the request is
    honoured promptly. Returns immediately when Kaggle credentials or the
    username are unavailable.
    """
    if not (KAGGLE_JSON.exists() and KAGGLE_USER):
        return
    last_upload_t = 0
    last_sig = ""
    while not stop_flag:
        try:
            parts = []
            for f in EXP_DIR.rglob("*"):
                if f.is_file():
                    st = f.stat()
                    parts.append(f"{f.relative_to(EXP_DIR)}:{st.st_size}:{int(st.st_mtime)}")
            if parts:
                # hash() is only compared within this process, so per-process
                # string-hash randomization does not matter here.
                sig = str(hash("|".join(sorted(parts))))
                now = time.time()
                if (sig != last_sig) and (now - last_upload_t > KDS_POLL_MIN * 60):
                    note = f"auto sync at {time.strftime('%Y-%m-%d %H:%M')}"
                    print(">>> Checkpoints change detected → uploading new dataset version ...")
                    if kds_version_upload(note=note):
                        last_sig = sig
                        last_upload_t = now
        except Exception as e:
            # Keep the thread alive: a transient failure (e.g. a file vanishing
            # between listing and stat) must not end the autosync.
            print("Uploader thread error:", e)
        # 30-second poll interval, interruptible by stop_flag.
        for _ in range(30):
            if stop_flag:
                break
            time.sleep(1)

uploader_th = threading.Thread(target=periodic_kds_uploader, daemon=True)
uploader_th.start()

# ---------- run training ----------
ret = run(cmd, check=False)
print(">>> train.py exited with code:", ret.returncode)

# Stop the background uploader BEFORE the final sync. kds_version_upload()
# wipes and rewrites KDS_UPLOAD_DIR on every call, so a periodic upload still
# in flight must be allowed to finish before the final one starts.
stop_flag = True
uploader_th.join(timeout=600)  # bounded wait so a stuck upload cannot hang the notebook
if uploader_th.is_alive():
    print("⚠️ Uploader thread is still running; continuing with final sync anyway.")

# final sync
if KAGGLE_JSON.exists() and KAGGLE_USER:
    kds_version_upload(note=f"final sync exit={ret.returncode}")

print("DONE.")
