In [None]:
# Create the upload folder, make the current user its owner, and set its mode to 755 (owner: read/write/execute; everyone else: read/execute)
sudo mkdir -p /data/images/stable-diffusion/uploads

sudo chown -R $USER:$USER /data/images/stable-diffusion/uploads

sudo chmod -R 755 /data/images/stable-diffusion/uploads


In [None]:
#!/usr/bin/env python3
import psutil
import json

def get_status():
    """Sample the machine's current CPU and memory utilisation.

    Returns:
        dict: ``{"cpu": <percent>, "mem": <percent>}``. The CPU figure is
        measured over a 0.5 s window, so the call blocks for that long.
    """
    return {
        "cpu": psutil.cpu_percent(interval=0.5),
        "mem": psutil.virtual_memory().percent,
    }

if __name__ == "__main__":
    # Emit one JSON object on stdout so a remote caller can parse it.
    print(json.dumps(get_status()))


In [None]:
# The script above (saved as monitor.py) prints the server's CPU and memory usage as JSON; make it executable
chmod +x monitor.py

In [None]:
# For instructions on how to download the model and launch Docker, please refer to the example code in All on One Device.ipynb.

In [None]:
# If you want the server to preload the model before the task starts, use the following code
nano img2img_server.py

In [None]:
#!/usr/bin/env python3
"""
img2img_server.py

"""

import os, time, glob, json, traceback
import torch, numpy as np
from omegaconf import OmegaConf
from PIL import Image
from einops import repeat
from torch import autocast, no_grad
from contextlib import nullcontext
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

# --- Exchange directory: every hand-off file and flag lives directly under UPLOAD_DIR ---
UPLOAD_DIR         = "/data/images/stable-diffusion/uploads"
INIT_IMG           = os.path.join(UPLOAD_DIR, "init_img.png")            # input image read by do_encode
PROMPT_FILE        = os.path.join(UPLOAD_DIR, "prompt.txt")              # text prompt read by do_encode
LATENT_FILE        = os.path.join(UPLOAD_DIR, "init_latent.pt")          # written by do_encode, read by do_denoise
COND_FILE          = os.path.join(UPLOAD_DIR, "encoded_condition.pt")    # written by do_encode, read by do_denoise
META_FILE          = os.path.join(UPLOAD_DIR, "intermediate_meta.json")  # sampling parameters, JSON
DENOISED_FILE      = os.path.join(UPLOAD_DIR, "denoised_latent.pt")      # written by do_denoise, read by do_decode
# Request flags: daemon_loop deletes them and then starts the matching stage.
FLAG_RUN_DENOISE   = os.path.join(UPLOAD_DIR, "run_denoise.flag")
FLAG_RUN_DECODE    = os.path.join(UPLOAD_DIR, "run_decode.flag")
# Completion flags: created by save_flag() at the end of do_encode / do_denoise.
FLAG_ENCODE_DONE   = os.path.join(UPLOAD_DIR, "encode_done.flag")
FLAG_DENOISE_DONE  = os.path.join(UPLOAD_DIR, "denoise_done.flag")

# --- Output images and model files ---
OUT_DIR            = "/data/images/stable-diffusion/test"
CONFIG_PATH        = "/opt/stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
CKPT_PATH          = "/data/models/stable-diffusion/sd-v1-5.ckpt"

# --- Sampling defaults: recorded in the metadata by do_encode and used by
# do_denoise as fallbacks for keys missing from META_FILE ---
DEFAULT_DDIM_STEPS = 50
DEFAULT_DDIM_ETA   = 0.0
DEFAULT_STRENGTH   = 0.30
DEFAULT_SCALE      = 5.0

# "cuda" enables the autocast context in do_denoise; "cpu" makes it a no-op context.
AMP_DEVICE         = "cuda" if torch.cuda.is_available() else "cpu"

def log(m):
    """Print ``m`` and flush stdout immediately."""
    print(m, flush=True)
def to_uint8_img(x):
    """Convert a (C, H, W) tensor with values in [0, 1] to an (H, W, C) uint8 array.

    Values are scaled by 255, offset by 0.5, clamped to [0, 255] and then
    truncated by the uint8 cast. The input tensor is left unmodified.
    """
    scaled = x.mul(255)   # fresh tensor, so the in-place add below cannot touch x
    scaled.add_(0.5)
    hwc = scaled.clamp(0, 255).permute(1, 2, 0)
    return hwc.detach().cpu().numpy().astype(np.uint8)
def save_flag(path):
    """Create (or truncate) an empty flag file at ``path`` and log its name.

    The file is made world read/writable on a best-effort basis: a failing
    chmod is ignored, but only ``OSError`` is swallowed so that
    ``KeyboardInterrupt``/``SystemExit`` still propagate.
    """
    with open(path, 'w'):
        pass
    try:
        os.chmod(path, 0o666)
    except OSError:
        pass
    log(f"[flag] {os.path.basename(path)}")

def ensure_dirs():
    """Create the upload and output directories and open up their permissions.

    Both directories are chmod-ed to 0o777 on a best-effort basis. A failing
    chmod is reported through ``log`` instead of being silently discarded,
    and only ``OSError`` is caught so interrupts still propagate.
    """
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    os.makedirs(OUT_DIR, exist_ok=True)
    for p in [UPLOAD_DIR, OUT_DIR]:
        try:
            os.chmod(p, 0o777)
        except OSError as exc:
            log(f"[server] chmod 777 failed for {p}: {exc}")

def load_image(path, device, batch=1):
    """Load an RGB image as a batched tensor scaled to [-1, 1].

    The image is resized (LANCZOS) so that both sides are rounded down to a
    multiple of 32.

    Args:
        path: image file to open.
        device: torch device the tensor is moved to.
        batch: number of copies along the leading batch dimension.

    Returns:
        tuple: ``(tensor, (width, height))`` with the resized dimensions.
    """
    picture = Image.open(path).convert("RGB")
    w, h = picture.size
    w = w - w % 32
    h = h - h % 32
    picture = picture.resize((w, h), resample=Image.LANCZOS)
    pixels = np.array(picture).astype(np.float32) / 255.0
    # HWC -> 1xCxHxW, then map [0, 1] to [-1, 1].
    single = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2)).to(device) * 2 - 1
    batched = repeat(single, '1 ... -> b ...', b=batch)
    return batched, (w, h)

def to_device(x, device):
    """Recursively move every tensor inside ``x`` to ``device``.

    Lists and tuples both come back as lists; dicts keep their keys;
    anything that is not a tensor or one of these containers is returned
    unchanged.
    """
    if isinstance(x, torch.Tensor):
        moved = x.to(device)
    elif isinstance(x, (list, tuple)):
        moved = [to_device(item, device) for item in x]
    elif isinstance(x, dict):
        moved = {key: to_device(val, device) for key, val in x.items()}
    else:
        moved = x
    return moved

# ---- One-time start-up: runs at import, so the model is resident before any request ----
log("[server] init…")
ensure_dirs()
seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg    = OmegaConf.load(CONFIG_PATH)
# NOTE(review): torch.load deserialises the checkpoint file; CKPT_PATH should
# only ever point at a trusted file.
pl_sd  = torch.load(CKPT_PATH, map_location="cpu")
# Use the nested "state_dict" entry when present, otherwise treat the whole
# loaded object as the state dict.
sd     = pl_sd.get("state_dict", pl_sd)

model  = instantiate_from_config(cfg.model).to(device).eval()
# strict=False: mismatched keys are tolerated and only counted in the log line below.
missing, unexpected = model.load_state_dict(sd, strict=False)
sampler = DDIMSampler(model)
# Latent scale factor of the local model; falls back to 0.18215 when the
# model has no `scale_factor` attribute.
sf_local = float(getattr(model, "scale_factor", 0.18215))
log(f"[server] model on {device}, scale_factor(local)={sf_local}, missing={len(missing)}, unexpected={len(unexpected)}")

def do_encode():
    """Encode INIT_IMG and PROMPT_FILE into the files the denoise stage consumes.

    Writes, in this order: LATENT_FILE, COND_FILE, META_FILE, a
    "recon_encode.png" reconstruction in OUT_DIR, and finally
    FLAG_ENCODE_DONE. Every exception is caught, logged with a traceback and
    swallowed; in that case the done-flag is not created.
    """
    log("[encode] start")
    try:
        x, (w,h) = load_image(INIT_IMG, device, batch=1)
        with no_grad():
            z = model.get_first_stage_encoding(model.encode_first_stage(x))
        if torch.cuda.is_available(): torch.cuda.synchronize()
        # Saved tensors are moved to CPU; files are made world read/writable.
        torch.save(z.cpu(), LATENT_FILE); os.chmod(LATENT_FILE, 0o666)

        with open(PROMPT_FILE) as f:
            prompt = f.read().strip()
        with no_grad():
            c = model.get_learned_conditioning([prompt])
        # The conditioning may be a tensor or a list/tuple/dict of tensors;
        # every tensor found is moved to CPU before saving. A list/tuple is
        # saved as a list.
        if isinstance(c, torch.Tensor): torch.save(c.cpu(), COND_FILE)
        elif isinstance(c, (list,tuple)): torch.save([t.cpu() if isinstance(t, torch.Tensor) else t for t in c], COND_FILE)
        elif isinstance(c, dict): torch.save({k:(v.cpu() if isinstance(v, torch.Tensor) else v) for k,v in c.items()}, COND_FILE)
        else: torch.save(c, COND_FILE)
        os.chmod(COND_FILE, 0o666)

        # The metadata records this server's DEFAULT_* values and a fixed
        # seed of 42; nothing here is read from the request.
        meta = {
            "width": w, "height": h, "n_samples": 1,
            "steps": DEFAULT_DDIM_STEPS, "eta": DEFAULT_DDIM_ETA,
            "strength": DEFAULT_STRENGTH, "scale": DEFAULT_SCALE,
            "scale_factor": sf_local,
            "ckpt": os.path.basename(CKPT_PATH),
            "config": os.path.basename(CONFIG_PATH),
            "seed": 42
        }
        with open(META_FILE, "w") as f: json.dump(meta, f, indent=2)
        os.chmod(META_FILE, 0o666)

        # Decode the fresh latent straight back to an image and save it.
        with no_grad(): recon = model.decode_first_stage(z)
        recon = torch.clamp((recon+1)/2, 0, 1)
        Image.fromarray(to_uint8_img(recon[0])).save(os.path.join(OUT_DIR,"recon_encode.png"))

        # The done-flag is written last, after every output file exists.
        save_flag(FLAG_ENCODE_DONE)
        log("[encode] done")
    except Exception:
        log("[encode] ERROR:"); traceback.print_exc()

def do_denoise():
    """Noise the stored latent and denoise it with DDIM, guided by the stored conditioning.

    Reads LATENT_FILE, COND_FILE and META_FILE; writes DENOISED_FILE, a
    "preview_denoise.png" in OUT_DIR and finally FLAG_DENOISE_DONE. Every
    exception is caught, logged with a traceback and swallowed; in that case
    the done-flag is not created.
    """
    log("[denoise] start")
    try:
        z = torch.load(LATENT_FILE, map_location="cpu"); z = to_device(z, device)
        c = torch.load(COND_FILE, map_location="cpu"); c = to_device(c, device)
        with open(META_FILE) as f: meta = json.load(f)

        # Rescale the latent when the sender's scale factor differs from the local one.
        # NOTE(review): the latent is multiplied by send/local. If the encoder
        # multiplies latents by its scale factor, converting to the local
        # scale would need local/send instead — confirm the intended direction.
        sf_send = float(meta.get("scale_factor", sf_local))
        if abs(sf_send - sf_local) > 1e-6:
            scale = sf_send / sf_local
            log(f"[denoise] rescale latent by {scale:.6f} (send={sf_send} local={sf_local})")
            z = z * scale

        # Make the conditioning batch match the latent batch.
        bsz = z.shape[0] if isinstance(z, torch.Tensor) else 1
        if isinstance(c, torch.Tensor) and c.shape[0] != bsz:
            log(f"[denoise] repeat cond from {c.shape[0]} -> {bsz}")
            c = c.repeat(bsz, 1, 1)

        # Sampling parameters come from the metadata, with server defaults as
        # fallback. `scale` is re-bound here: from this point on it is the
        # guidance scale, not the latent rescale ratio used above.
        steps    = int(meta.get("steps", DEFAULT_DDIM_STEPS))
        eta      = float(meta.get("eta", DEFAULT_DDIM_ETA))
        strength = float(meta.get("strength", DEFAULT_STRENGTH))
        scale    = float(meta.get("scale", DEFAULT_SCALE))
        seed     = int(meta.get("seed", 42))
        torch.manual_seed(seed)

        sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=False)
        # NOTE(review): t_enc equals `steps` when strength >= 1.0 — confirm
        # the sampler accepts that value as a timestep index.
        t_enc = max(0, int(round(strength * steps)))
        t_vec = torch.tensor([t_enc]*bsz).to(device)

        # No unconditional conditioning is built when the guidance scale is exactly 1.0.
        uc = None if scale == 1.0 else to_device(model.get_learned_conditioning([""]*bsz), device)
        amp = autocast("cuda") if AMP_DEVICE=="cuda" else nullcontext()
        with no_grad(), amp, model.ema_scope():
            z_noisy = sampler.stochastic_encode(z, t_vec)
            out = sampler.decode(z_noisy, c, t_enc,
                                 unconditional_guidance_scale=scale,
                                 unconditional_conditioning=uc)
        if torch.cuda.is_available(): torch.cuda.synchronize()
        torch.save(out.cpu(), DENOISED_FILE); os.chmod(DENOISED_FILE, 0o666)

        # Decode the first sample to an image and save it as a preview.
        with no_grad(): prev = model.decode_first_stage(out)
        prev = torch.clamp((prev+1)/2, 0, 1)
        Image.fromarray(to_uint8_img(prev[0])).save(os.path.join(OUT_DIR,"preview_denoise.png"))

        # The done-flag is written last, after every output file exists.
        save_flag(FLAG_DENOISE_DONE)
        log("[denoise] done")
    except Exception:
        log("[denoise] ERROR:"); traceback.print_exc()

def do_decode():
    """Decode DENOISED_FILE into PNG images and clear the exchange directory.

    One ``<time_ns>.png`` is written to OUT_DIR per sample in the batch.
    Afterwards all request inputs, intermediates and flags are removed so the
    next request starts from a clean state. Files that are already gone are
    skipped; any other removal failure is logged. Every exception raised by
    the decode itself is caught, logged with a traceback and swallowed.
    """
    log("[decode] start")
    try:
        s = torch.load(DENOISED_FILE, map_location="cpu"); s = to_device(s, device)
        with no_grad(): imgs = model.decode_first_stage(s)
        if torch.cuda.is_available(): torch.cuda.synchronize()
        imgs = torch.clamp((imgs+1)/2, 0, 1)
        for img in imgs:
            Image.fromarray(to_uint8_img(img)).save(os.path.join(OUT_DIR, f"{time.time_ns()}.png"))
        log("[decode] done")

        for fp in [INIT_IMG, PROMPT_FILE, LATENT_FILE, COND_FILE, META_FILE, DENOISED_FILE,
                   FLAG_ENCODE_DONE, FLAG_DENOISE_DONE, FLAG_RUN_DENOISE, FLAG_RUN_DECODE]:
            try:
                os.remove(fp)
            except FileNotFoundError:
                pass
            except OSError as exc:
                # Best-effort cleanup: report the leftover, but keep going.
                log(f"[decode] could not remove {fp}: {exc}")
    except Exception:
        log("[decode] ERROR:"); traceback.print_exc()

def daemon_loop():
    """Poll the exchange directory forever and run the stage that is due.

    Once per second:
      * encode starts when the input image and prompt exist but no latent
        does (older result PNGs in OUT_DIR are deleted first; the
        reconstruction and preview images are kept);
      * denoise starts when its run-flag, the latent and the conditioning exist;
      * decode starts when its run-flag and the denoised latent exist.

    Run-flags are deleted before the stage starts. File-system cleanup is
    best-effort, but only ``OSError`` is swallowed there, so
    ``KeyboardInterrupt``/``SystemExit`` still stop the daemon. Any other
    error inside one iteration is logged and the loop resumes after 3 s.
    """
    log("[Daemon] watching…")
    while True:
        try:
            try:
                log(f"[ls] uploads={os.listdir(UPLOAD_DIR)} out={os.listdir(OUT_DIR)}")
            except OSError:
                pass

            if os.path.exists(INIT_IMG) and os.path.exists(PROMPT_FILE) and not os.path.exists(LATENT_FILE):
                for old in glob.glob(os.path.join(OUT_DIR, '*.png')):
                    if old.endswith(("recon_encode.png", "preview_denoise.png")):
                        continue
                    try:
                        os.remove(old)
                    except OSError:
                        pass
                do_encode()

            if os.path.exists(FLAG_RUN_DENOISE) and os.path.exists(LATENT_FILE) and os.path.exists(COND_FILE):
                try:
                    os.remove(FLAG_RUN_DENOISE)
                except OSError as exc:
                    # A flag that cannot be removed re-triggers the stage on every poll.
                    log(f"[Daemon] could not remove {FLAG_RUN_DENOISE}: {exc}")
                do_denoise()

            if os.path.exists(FLAG_RUN_DECODE) and os.path.exists(DENOISED_FILE):
                try:
                    os.remove(FLAG_RUN_DECODE)
                except OSError as exc:
                    log(f"[Daemon] could not remove {FLAG_RUN_DECODE}: {exc}")
                do_decode()

            time.sleep(1)
        except Exception:
            log("[Daemon] loop error:"); traceback.print_exc(); time.sleep(3)

if __name__=='__main__':
    # Start from a clean slate: drop intermediates and flags left over from a
    # previous run (the uploaded image and prompt are kept). Only OSError is
    # swallowed so that an interrupt during start-up is not hidden.
    for fp in [LATENT_FILE, COND_FILE, META_FILE, DENOISED_FILE,
               FLAG_ENCODE_DONE, FLAG_DENOISE_DONE, FLAG_RUN_DENOISE, FLAG_RUN_DECODE]:
        try:
            os.remove(fp)
        except OSError:
            pass
    daemon_loop()


In [None]:
# Restart the server daemon in the background (it preloads the model at start-up) and follow its log
pkill -f img2img_server.py
nohup python3 -u img2img_server.py > /var/log/img2img_server.log 2>&1 &
tail -f /var/log/img2img_server.log

In [None]:
# The following script starts loading the client-side model as soon as it is launched, then runs each stage (encode, denoise, decode) once the model is loaded and the matching flag file (encode.flag, denoise.flag, decode.flag) appears
nano img2img_client.py

In [None]:
#!/usr/bin/env python3
"""
img2img_client.py

"""

import argparse, os, time, json, threading
import torch, numpy as np
from omegaconf import OmegaConf
from PIL import Image
from einops import repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
from pytorch_lightning import seed_everything

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

def to_uint8_img(x):
    """Turn a (C, H, W) tensor in [0, 1] into an (H, W, C) uint8 NumPy array.

    The values are multiplied by 255, shifted by 0.5, clamped to [0, 255]
    and truncated by the uint8 cast. ``x`` itself is not modified.
    """
    scaled = x.mul(255)   # new tensor; the in-place add below never touches x
    scaled.add_(0.5)
    hwc = scaled.clamp(0, 255).permute(1, 2, 0)
    return hwc.detach().cpu().numpy().astype(np.uint8)

def _load_all(opt):
    """Seed the RNGs and build the model and its DDIM sampler.

    Args:
        opt: parsed options; ``seed``, ``config`` and ``ckpt`` are read.

    Returns:
        tuple: ``(model, sampler, device)`` with the model in eval mode on
        CUDA when available, otherwise on CPU.
    """
    seed_everything(opt.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cfg = OmegaConf.load(opt.config)
    checkpoint = torch.load(opt.ckpt, map_location="cpu")
    # Use the nested "state_dict" entry when the checkpoint has one.
    state_dict = checkpoint.get("state_dict", checkpoint)
    model = instantiate_from_config(cfg.model)
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device).eval()
    return model, DDIMSampler(model), device

class AsyncLoader:
    """Load the model on a background daemon thread.

    Construction returns immediately; ``get()`` blocks until loading has
    finished. A failure inside the loader thread is captured and re-raised
    from ``get()`` instead of leaving the caller blocked forever.
    """
    def __init__(self, opt):
        self._e = threading.Event()
        self._res = None
        self._err = None
        threading.Thread(target=self._run, args=(opt,), daemon=True).start()

    def _run(self, opt):
        """Thread body: store the result or the error, then always signal completion."""
        try:
            self._res = _load_all(opt)
        except Exception as exc:
            self._err = exc
        finally:
            self._e.set()

    def get(self):
        """Wait for the loader thread and return ``(model, sampler, device)``.

        Raises:
            RuntimeError: if loading failed; the original exception is
                chained as ``__cause__``.
        """
        self._e.wait()
        if self._err is not None:
            raise RuntimeError("model load failed") from self._err
        return self._res

def _load_image(path, device, n):
    """Load an RGB image as an ``n``-times repeated tensor scaled to [-1, 1].

    Both sides are rounded down to a multiple of 32 and the image is resized
    with LANCZOS to that size.

    Returns:
        tuple: ``(tensor, (width, height))`` with the resized dimensions.
    """
    picture = Image.open(path).convert("RGB")
    w, h = picture.size
    w = w - w % 32
    h = h - h % 32
    picture = picture.resize((w, h), resample=Image.LANCZOS)
    pixels = np.array(picture).astype(np.float32) / 255.0
    # HWC -> 1xCxHxW, then map [0, 1] to [-1, 1].
    single = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2)).to(device) * 2 - 1
    return repeat(single, '1 ... -> b ...', b=n), (w, h)

def _to_device(x, device):
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, (list,tuple)): return [ _to_device(t, device) for t in x ]
    if isinstance(x, dict): return {k:_to_device(v, device) for k,v in x.items()}
    return x

def encode_stage(opt, loader):
    """Encode the input image and prompt and write the hand-off files.

    Writes ``opt.latent_out``, ``opt.cond_out``, ``opt.meta_out`` and a
    "recon_encode.png" reconstruction inside ``opt.outdir``. Blocks on
    ``loader.get()`` until the model is available.
    """
    model, sampler, device = loader.get()
    x, (w,h) = _load_image(opt.init_img, device, opt.n_samples)
    with torch.no_grad(): z = model.get_first_stage_encoding(model.encode_first_stage(x))
    torch.save(z.cpu(), opt.latent_out)

    with torch.no_grad(): c = model.get_learned_conditioning([opt.prompt]*opt.n_samples)
    # The conditioning may be a tensor or a list/tuple/dict of tensors; every
    # tensor found is moved to CPU before saving (a list/tuple is saved as a list).
    if isinstance(c, torch.Tensor): torch.save(c.cpu(), opt.cond_out)
    elif isinstance(c, (list,tuple)): torch.save([t.cpu() if isinstance(t, torch.Tensor) else t for t in c], opt.cond_out)
    elif isinstance(c, dict): torch.save({k:(v.cpu() if isinstance(v, torch.Tensor) else v) for k,v in c.items()}, opt.cond_out)
    else: torch.save(c, opt.cond_out)

    # Falls back to 0.18215 when the model has no `scale_factor` attribute.
    sf_local = float(getattr(model, "scale_factor", 0.18215))
    # Record the sampling parameters so the denoise stage can pick them up from the file.
    meta = {
        "width": w, "height": h, "n_samples": opt.n_samples,
        "steps": opt.ddim_steps, "eta": opt.ddim_eta,
        "strength": opt.strength, "scale": opt.scale,
        "scale_factor": sf_local,
        "ckpt": os.path.basename(opt.ckpt),
        "config": os.path.basename(opt.config),
        "seed": opt.seed
    }
    with open(opt.meta_out, "w") as f: json.dump(meta, f, indent=2)

    # Decode the fresh latent straight back to an image and save the first sample.
    with torch.no_grad(): recon = model.decode_first_stage(z)
    recon = torch.clamp((recon+1)/2, 0, 1)
    os.makedirs(opt.outdir, exist_ok=True)
    Image.fromarray(to_uint8_img(recon[0])).save(os.path.join(opt.outdir, "recon_encode.png"))
    print(f"[encode] latent:{tuple(z.shape)} sf:{sf_local}")

def denoise_stage(opt, loader):
    """Noise the saved latent and denoise it with DDIM under the saved conditioning.

    Reads ``opt.latent_out``, ``opt.cond_out`` and ``opt.meta_out``; writes
    ``opt.denoised_out`` and a "preview_denoise.png" inside ``opt.outdir``.
    Values found in the metadata file take precedence over the command-line
    options.
    """
    model, sampler, device = loader.get()
    z  = torch.load(opt.latent_out, map_location="cpu"); z = _to_device(z, device)
    cc = torch.load(opt.cond_out,   map_location="cpu"); c = _to_device(cc, device)
    with open(opt.meta_out) as f: meta = json.load(f)

    # Rescale the latent when the recorded scale factor differs from the local one.
    # NOTE(review): the latent is multiplied by send/local. If the encoder
    # multiplies latents by its scale factor, converting to the local scale
    # would need local/send instead — confirm the intended direction.
    sf_send = float(meta.get("scale_factor", float(getattr(model,"scale_factor",0.18215))))
    sf_local= float(getattr(model, "scale_factor", 0.18215))
    if abs(sf_send - sf_local) > 1e-6:
        z = z * (sf_send/sf_local)
        print(f"[denoise] rescale latent by {(sf_send/sf_local):.6f}")

    # Make the conditioning batch match the latent batch.
    bsz = z.shape[0] if isinstance(z, torch.Tensor) else 1
    if isinstance(c, torch.Tensor) and c.shape[0] != bsz:
        c = c.repeat(bsz, 1, 1)

    steps    = int(meta.get("steps", opt.ddim_steps))
    eta      = float(meta.get("eta", opt.ddim_eta))
    strength = float(meta.get("strength", opt.strength))
    scale    = float(meta.get("scale", opt.scale))
    seed     = int(meta.get("seed", opt.seed))
    torch.manual_seed(seed)

    sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=False)
    # NOTE(review): t_enc equals `steps` when strength >= 1.0 — confirm the
    # sampler accepts that value as a timestep index.
    t_enc = max(0, int(round(strength * steps)))
    t_vec = torch.tensor([t_enc]*bsz).to(device)

    # No unconditional conditioning is built when the guidance scale is exactly 1.0.
    uc = None if scale==1.0 else _to_device(model.get_learned_conditioning([""]*bsz), device)
    # Autocast is used only with --precision autocast on a CUDA device.
    amp = autocast("cuda") if (opt.precision=="autocast" and device.type=="cuda") else nullcontext()
    with torch.no_grad(), amp, model.ema_scope():
        z_noisy = sampler.stochastic_encode(z, t_vec)
        s = sampler.decode(z_noisy, c, t_enc,
                           unconditional_guidance_scale=scale,
                           unconditional_conditioning=uc)
    torch.save(s.cpu(), opt.denoised_out)

    # NOTE(review): opt.outdir is not created here; this relies on it already
    # existing (encode_stage creates it).
    with torch.no_grad(): prev = model.decode_first_stage(s)
    prev = torch.clamp((prev+1)/2, 0, 1)
    Image.fromarray(to_uint8_img(prev[0])).save(os.path.join(opt.outdir, "preview_denoise.png"))
    print(f"[denoise] steps={steps} strength={strength} t_enc={t_enc}")

def decode_stage(opt, loader):
    """Decode the denoised latent and save the first sample as ``final.png``.

    Raises:
        TypeError: if ``opt.denoised_out`` does not contain a tensor.
    """
    model, _sampler, device = loader.get()
    latent = torch.load(opt.denoised_out, map_location="cpu")
    if not isinstance(latent, torch.Tensor):
        raise TypeError("denoised_latent.pt must be Tensor")
    with torch.no_grad():
        decoded = model.decode_first_stage(latent.to(device))
    # Map the decoder output from [-1, 1] to [0, 1].
    images = torch.clamp((decoded + 1) / 2, 0, 1)
    os.makedirs(opt.outdir, exist_ok=True)
    target = os.path.join(opt.outdir, "final.png")
    Image.fromarray(to_uint8_img(images[0])).save(target)
    print(f"[decode] saved {os.path.join(opt.outdir,'final.png')}")

if __name__ == '__main__':
    import time as _time

    # ---- command-line options ----
    p = argparse.ArgumentParser()
    p.add_argument("--prompt", required=True)
    p.add_argument("--init-img", required=True)
    p.add_argument("--outdir", default="outputs/img2img-samples")
    p.add_argument("--latent-out",   default="init_latent.pt")
    p.add_argument("--cond-out",     default="encoded_condition.pt")
    p.add_argument("--meta-out",     default="intermediate_meta.json")
    p.add_argument("--denoised-out", default="denoised_latent.pt")
    p.add_argument("--ddim_steps",   type=int,   default=50)
    p.add_argument("--ddim_eta",     type=float, default=0.0)
    p.add_argument("--n_samples",    type=int,   default=1)
    p.add_argument("--strength",     type=float, default=0.30)
    p.add_argument("--scale",        type=float, default=5.0)
    p.add_argument("--precision",    choices=["full","autocast"], default="autocast")
    p.add_argument("--config",       default="configs/stable-diffusion/v1-inference.yaml")
    p.add_argument("--ckpt",         default="models/ldm/stable-diffusion-v1/model.ckpt")
    p.add_argument("--seed",         type=int,   default=42)
    opt = p.parse_args()

    # Model loading starts in the background right away; each stage blocks on
    # it only when it actually needs the model.
    loader = AsyncLoader(opt)
    print("[client_local] watching flags…")
    stages = (("encode", encode_stage), ("denoise", denoise_stage), ("decode", decode_stage))
    for name, func in stages:
        flag = f"{name}.flag"
        # Stages run strictly in order; wait for this stage's flag file in
        # the current working directory.
        while not os.path.exists(flag):
            _time.sleep(0.5)
        print(f"=== {name} ===")
        func(opt, loader)
        os.remove(flag)
    print("✅ done")


In [None]:
# Once Stable Diffusion is running in Docker on both the client and the server and the scripts above are in place, install paramiko and create client.py.
# Remember to adjust the network configuration (server IP, user name, password, monitor script path) to your environment.
# By default, the server's CPU and memory usage decide which device executes each stage.
# However, you can modify the code to force individual steps onto a particular device for testing (the --mode local / --mode server option forces all steps onto one side).
pip install paramiko
nano client.py

In [None]:
#!/usr/bin/env python3
"""
client.py

"""

import argparse, time, os, io, csv, json, subprocess, paramiko, threading, traceback
import torch, numpy as np
from omegaconf import OmegaConf
from PIL import Image
from einops import repeat
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from contextlib import nullcontext
from torch import autocast

# Directory where the per-run latency CSV files are written.
CSV_DIR = "/data/images/stable-diffusion"

# Network Configuration
# Each value can be overridden through an environment variable of the same
# name, so credentials do not have to be edited into (or committed with)
# this file. The literals are only fallbacks that keep existing setups working.
EDGE_SERVER_IP   = os.environ.get("EDGE_SERVER_IP", "100.110.165.80")
EDGE_SERVER_USER = os.environ.get("EDGE_SERVER_USER", "host")
EDGE_SERVER_PWD  = os.environ.get("EDGE_SERVER_PWD", "123")
MONITOR_SCRIPT   = os.environ.get("MONITOR_SCRIPT", "/home/host/monitor.py")

# Remote (server-side) directories: uploads carry hand-off files and flags,
# "test" receives the server's output images.
REMOTE_BASE      = "/data/images/stable-diffusion"
REMOTE_UPLOAD    = os.path.join(REMOTE_BASE, "uploads")
REMOTE_OUT       = os.path.join(REMOTE_BASE, "test")

# Flag files exchanged with the server through REMOTE_UPLOAD.
FLAG_ENCODE_DONE  = os.path.join(REMOTE_UPLOAD, "encode_done.flag")
FLAG_DENOISE_DONE = os.path.join(REMOTE_UPLOAD, "denoise_done.flag")
FLAG_RUN_DENOISE  = os.path.join(REMOTE_UPLOAD, "run_denoise.flag")
FLAG_RUN_DECODE   = os.path.join(REMOTE_UPLOAD, "run_decode.flag")

# Local hand-off files, relative to the current working directory.
LATENT_LOCAL  = "init_latent.pt"
COND_LOCAL    = "encoded_condition.pt"
META_LOCAL    = "intermediate_meta.json"
DENOISED_LOCAL= "denoised_latent.pt"

def log(m):
    """Print ``m`` and flush stdout right away."""
    print(m, flush=True)

class StageTimer:
    """Accumulate elapsed time per pipeline stage.

    ``metrics`` maps "encode", "denoise" and "decode" to seconds. ``start``
    opens a stage, ``switch`` closes the running stage and opens the next
    one at the same instant, and ``stop`` closes the running stage.
    """

    def __init__(self):
        self.metrics = dict.fromkeys(("encode", "denoise", "decode"), 0.0)
        self.cur = None   # name of the stage currently being timed
        self.t0 = None    # timestamp at which `cur` started

    def _now(self):
        return time.perf_counter()

    def _book(self, now):
        """Add the time since ``t0`` to the running stage, if there is one."""
        if self.cur is not None:
            self.metrics[self.cur] += now - self.t0

    def start(self, stage: str):
        self.cur = stage
        self.t0 = self._now()

    def switch(self, next_stage: str):
        now = self._now()
        self._book(now)
        self.cur, self.t0 = next_stage, now

    def stop(self):
        self._book(self._now())

def _init_local_model(cfg_path, ckpt_path, seed):
    """Seed the RNGs and build the local model with its DDIM sampler.

    Returns:
        tuple: ``(model, sampler, device)`` with the model in eval mode on
        CUDA when available, otherwise on CPU.
    """
    seed_everything(seed)
    cfg = OmegaConf.load(cfg_path)
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Use the nested "state_dict" entry when the checkpoint has one.
    state_dict = checkpoint.get("state_dict", checkpoint)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = instantiate_from_config(cfg.model)
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device).eval()
    return model, DDIMSampler(model), device

class AsyncModelLoader:
    """Load the local model on a background daemon thread.

    ``get()`` blocks until loading has finished and raises ``RuntimeError``
    when it failed; the failure itself is logged by the loader thread.
    """

    def __init__(self, cfg, ckpt, seed):
        self._e = threading.Event()
        self._res = None
        worker = threading.Thread(target=self._load, args=(cfg, ckpt, seed), daemon=True)
        worker.start()

    def _load(self, cfg, ckpt, seed):
        try:
            self._res = _init_local_model(cfg, ckpt, seed)
            log("[loader] model ready")
        except Exception:
            log("[loader] model failed")
            traceback.print_exc()
        finally:
            # Always release waiters, whether loading succeeded or not.
            self._e.set()

    def get(self):
        self._e.wait()
        if self._res is None:
            raise RuntimeError("model load failed")
        return self._res

def make_ssh():
    """Open a password-authenticated SSH connection to the edge server.

    Unknown host keys are accepted automatically (``AutoAddPolicy``).

    Returns:
        paramiko.SSHClient: the connected client; the caller closes it.
    """
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect(
        EDGE_SERVER_IP,
        username=EDGE_SERVER_USER,
        password=EDGE_SERVER_PWD,
    )
    return client

def server_ok(ssh):
    """Report whether the server has capacity for another stage.

    Runs the monitor script remotely and parses its JSON output. The server
    counts as available when both CPU and memory usage are below 80; a
    missing value is treated as 100. Any failure is logged and yields False.
    """
    try:
        _stdin, stdout, _stderr = ssh.exec_command(f"python3 {MONITOR_SCRIPT}")
        payload = stdout.read().decode()
        stat = json.loads(payload or "{}")
        cpu = stat.get("cpu", 100)
        mem = stat.get("mem", 100)
        ok = cpu < 80 and mem < 80
        log(f"[server_ok] cpu={cpu} mem={mem} -> {ok}")
        return ok
    except Exception:
        log("[server_ok] monitor failed -> False")
        traceback.print_exc()
        return False

def r_exists(sftp, path):
    """Return True when ``sftp.stat(path)`` succeeds, False on any exception."""
    try:
        sftp.stat(path)
    except Exception:
        return False
    return True

def r_list(sftp, path):
    """Return ``sftp.listdir(path)``, or an empty list on any exception."""
    try:
        entries = sftp.listdir(path)
    except Exception:
        entries = []
    return entries

def wait_remote(ssh, path, poll, timeout, name):
    """Poll the server until ``path`` exists or ``timeout`` seconds have passed.

    One SFTP session is reused for the whole wait instead of opening a new
    channel on every poll, and elapsed time is measured with the monotonic
    clock so wall-clock adjustments cannot shorten or extend the timeout.

    Args:
        ssh: connected SSH client providing ``open_sftp()``.
        path: remote file to wait for.
        poll: seconds to sleep between checks.
        timeout: maximum number of seconds to wait.
        name: label used in the log messages.

    Returns:
        bool: True once the file exists, False on timeout.
    """
    started = time.monotonic()
    last_snapshot = None  # time of the last directory listing written to the log
    with ssh.open_sftp() as sftp:
        while True:
            if r_exists(sftp, path):
                log(f"[remote] found {name}")
                return True
            now = time.monotonic()
            # Log a listing on the first pass and then at most every 5 seconds.
            if last_snapshot is None or now - last_snapshot >= 5:
                log(f"[remote] waiting {name}… uploads={r_list(sftp, REMOTE_UPLOAD)} out={r_list(sftp, REMOTE_OUT)}")
                last_snapshot = now
            if time.monotonic() - started > timeout:
                log(f"[remote] timeout {name}")
                return False
            time.sleep(poll)

def r_put(ssh, l, r):
    """Upload local file ``l`` to remote path ``r``.

    The remote parent directory is created on a best-effort basis: an
    ``OSError`` from ``mkdir`` (for example because the directory already
    exists) is ignored. Other exceptions, including interrupts, propagate.
    """
    with ssh.open_sftp() as sftp:
        try:
            sftp.mkdir(os.path.dirname(r))
        except OSError:
            pass
        sftp.put(l, r)
def r_put_flag(ssh, r):
    """Create an empty file at remote path ``r``.

    The remote parent directory is created on a best-effort basis: an
    ``OSError`` from ``mkdir`` (for example because the directory already
    exists) is ignored. Other exceptions, including interrupts, propagate.
    """
    with ssh.open_sftp() as sftp:
        try:
            sftp.mkdir(os.path.dirname(r))
        except OSError:
            pass
        sftp.putfo(io.BytesIO(b""), r)
def r_get(ssh, r, l):
    """Download remote file ``r`` to local path ``l`` over a fresh SFTP session."""
    with ssh.open_sftp() as channel:
        channel.get(r, l)

def to_uint8_img(x):
    """Convert a (C, H, W) tensor in [0, 1] to an (H, W, C) uint8 NumPy array.

    Values are scaled by 255, offset by 0.5, clamped to [0, 255] and
    truncated by the uint8 cast. The tensor is detached before the NumPy
    conversion so inputs that still require grad are accepted too. The
    input tensor is not modified.
    """
    scaled = x.mul(255).add_(0.5).clamp(0, 255)
    return scaled.permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)

def local_encode(args, loader):
    """Encode the input image and prompt on this machine.

    Writes LATENT_LOCAL, COND_LOCAL, META_LOCAL and a "recon_encode.png"
    reconstruction into the current working directory. Blocks on
    ``loader.get()`` until the local model is available.
    """
    log("[local] ENCODE")
    model, sampler, device = loader.get()
    img = Image.open(args.init_img).convert("RGB")
    # Round both sides down to a multiple of 32 before resizing.
    w,h = img.size; w-=w%32; h-=h%32
    img = img.resize((w,h), resample=Image.LANCZOS)
    arr = np.array(img, dtype=np.float32)/255.0
    # HWC -> 1xCxHxW, then map [0, 1] to [-1, 1].
    x = torch.from_numpy(arr.transpose(2,0,1)).unsqueeze(0).to(device)*2-1
    with torch.no_grad(): z = model.get_first_stage_encoding(model.encode_first_stage(x))
    torch.save(z.cpu(), LATENT_LOCAL)

    with torch.no_grad(): c = model.get_learned_conditioning([args.prompt])
    # The conditioning may be a tensor or a list/tuple/dict of tensors; every
    # tensor found is moved to CPU before saving (a list/tuple is saved as a list).
    if isinstance(c, torch.Tensor): torch.save(c.cpu(), COND_LOCAL)
    elif isinstance(c, (list,tuple)): torch.save([t.cpu() if isinstance(t, torch.Tensor) else t for t in c], COND_LOCAL)
    elif isinstance(c, dict): torch.save({k:(v.cpu() if isinstance(v, torch.Tensor) else v) for k,v in c.items()}, COND_LOCAL)
    else: torch.save(c, COND_LOCAL)

    # Falls back to 0.18215 when the model has no `scale_factor` attribute.
    sf_local = float(getattr(model,"scale_factor",0.18215))
    # Record the command-line sampling parameters for the denoise stage.
    meta = {
        "width": w, "height": h, "n_samples": 1,
        "steps": args.ddim_steps, "eta": args.ddim_eta,
        "strength": args.strength, "scale": args.scale,
        "scale_factor": sf_local, "ckpt": os.path.basename(args.ckpt),
        "config": os.path.basename(args.config), "seed": args.seed
    }
    with open(META_LOCAL,"w") as f: json.dump(meta, f, indent=2)

    # Decode the fresh latent straight back to an image and save the first sample.
    with torch.no_grad(): recon = model.decode_first_stage(z)
    Image.fromarray(to_uint8_img(torch.clamp((recon+1)/2,0,1)[0])).save("recon_encode.png")

def local_denoise(args, loader):
    """Noise the local latent and denoise it with DDIM on this machine.

    Reads LATENT_LOCAL, META_LOCAL and COND_LOCAL; writes DENOISED_LOCAL and
    a "preview_denoise.png" into the current working directory. Values found
    in the metadata file take precedence over the command-line arguments.
    """
    log("[local] DENOISE")
    model, sampler, device = loader.get()
    z = torch.load(LATENT_LOCAL, map_location="cpu").to(device)
    with open(META_LOCAL) as f: meta = json.load(f)
    # Rescale the latent when the recorded scale factor differs from the local one.
    # NOTE(review): the latent is multiplied by send/local. If the encoder
    # multiplies latents by its scale factor, converting to the local scale
    # would need local/send instead — confirm the intended direction.
    sf_send  = float(meta.get("scale_factor", float(getattr(model,"scale_factor",0.18215))))
    sf_local = float(getattr(model,"scale_factor",0.18215))
    if abs(sf_send-sf_local)>1e-6:
        z = z * (sf_send/sf_local); log(f"[local] rescale latent {(sf_send/sf_local):.6f}")

    steps    = int(meta.get("steps", args.ddim_steps))
    eta      = float(meta.get("eta", args.ddim_eta))
    strength = float(meta.get("strength", args.strength))
    scale    = float(meta.get("scale", args.scale))
    seed     = int(meta.get("seed", args.seed))
    torch.manual_seed(seed)

    sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=False)
    # NOTE(review): t_enc equals `steps` when strength >= 1.0 — confirm the
    # sampler accepts that value as a timestep index.
    t_enc = max(0, int(round(strength*steps)))
    # No unconditional conditioning is built when the guidance scale is exactly 1.0.
    # NOTE(review): this call runs outside torch.no_grad(), unlike the other
    # model calls in this function.
    uc = None if scale==1.0 else model.get_learned_conditioning([""])
    uc = uc.to(device) if isinstance(uc, torch.Tensor) else uc
    c  = torch.load(COND_LOCAL, map_location="cpu")
    # Only a plain tensor is moved to the device; a list/dict conditioning is used as loaded.
    if isinstance(c, torch.Tensor): c = c.to(device)
    t_vec = torch.tensor([t_enc]).to(device)
    amp = autocast("cuda") if torch.cuda.is_available() else nullcontext()
    with torch.no_grad(), amp, model.ema_scope():
        z_noisy = sampler.stochastic_encode(z, t_vec)
        s = sampler.decode(z_noisy, c, t_enc,
                           unconditional_guidance_scale=scale,
                           unconditional_conditioning=uc)
    torch.save(s.cpu(), DENOISED_LOCAL)
    # Decode the first sample to an image and save it as a preview.
    with torch.no_grad(): prev = model.decode_first_stage(s)
    Image.fromarray(to_uint8_img(torch.clamp((prev+1)/2,0,1)[0])).save("preview_denoise.png")

def local_decode(args, loader):
    """Decode DENOISED_LOCAL on this machine and save the first sample to ``args.output``."""
    log("[local] DECODE")
    model, _sampler, device = loader.get()
    latent = torch.load(DENOISED_LOCAL, map_location="cpu").to(device)
    with torch.no_grad():
        decoded = model.decode_first_stage(latent)
    # Map the decoder output from [-1, 1] to [0, 1].
    images = torch.clamp((decoded + 1) / 2, 0, 1)
    picture = Image.fromarray(to_uint8_img(images[0]))
    picture.save(args.output)
    log(f"[local] saved {args.output}")

def main():
    """Run encode -> denoise -> decode, placing each stage locally or on the edge server.

    For every stage the placement is decided by ``--mode``: ``server`` and
    ``local`` force one side, ``auto`` asks ``server_ok``. A server stage
    that does not finish within ``--remote_timeout`` falls back to local
    execution unless ``--no-fallback`` is given, in which case
    ``TimeoutError`` is raised. Per-stage latencies are written to the next
    free ``<n>.csv`` in CSV_DIR.

    The hand-off files are kept consistent when consecutive stages run on
    different machines: results produced on the server are downloaded before
    a local stage consumes them, and results produced locally are uploaded
    before a server stage is requested. Stale completion flags and outputs
    are removed from the server before a stage is requested, and only PNGs
    that appear after the decode request are accepted as its result.

    The SSH connection is always closed, including on errors.
    """
    # Function-scope import: only this function needs shutil.
    import shutil

    ap = argparse.ArgumentParser()
    ap.add_argument("--prompt", required=True)
    ap.add_argument("--init-img", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--strength", type=float, default=0.30)
    ap.add_argument("--ddim_steps", type=int, default=50)
    ap.add_argument("--ddim_eta", type=float, default=0.0)
    ap.add_argument("--scale", type=float, default=5.0)
    ap.add_argument("--config", default="/opt/stable-diffusion/configs/stable-diffusion/v1-inference.yaml")
    ap.add_argument("--ckpt",   default="/data/models/stable-diffusion/sd-v1-5.ckpt")
    ap.add_argument("--poll-interval", type=float, default=1.0)
    ap.add_argument("--remote_timeout", type=float, default=120.0)
    ap.add_argument("--mode", choices=["auto","local","server"], default="auto")
    ap.add_argument("--no-fallback", action="store_true")
    args = ap.parse_args()

    # Hand-off artifacts as (remote file name, local path) pairs.
    encode_files = (
        ("init_latent.pt", LATENT_LOCAL),
        ("encoded_condition.pt", COND_LOCAL),
        ("intermediate_meta.json", META_LOCAL),
    )
    denoise_files = (("denoised_latent.pt", DENOISED_LOCAL),)

    loader = AsyncModelLoader(args.config, args.ckpt, args.seed)
    ssh = make_ssh()

    def want_server(stage):
        """Decide whether the next stage should run on the server."""
        if args.mode=="server": return True
        if args.mode=="local":  return False
        return server_ok(ssh)

    def remote(name):
        """Path of ``name`` inside the server's upload directory."""
        return os.path.join(REMOTE_UPLOAD, name)

    def pull(files):
        """Download server-side artifacts, overwriting any local copies."""
        for name, local in files:
            r_get(ssh, remote(name), local)

    def push(files):
        """Upload local artifacts, overwriting any server-side copies."""
        for name, local in files:
            r_put(ssh, local, remote(name))

    def clear_remote(paths):
        """Best-effort removal of stale remote files; missing files are fine."""
        with ssh.open_sftp() as sftp:
            for path in paths:
                try:
                    sftp.remove(path)
                except OSError:
                    pass

    try:
        timer = StageTimer()
        timer.start("encode")

        # ---------------- ENCODE ----------------
        encoded_on_server = False
        if want_server("ENCODE"):
            log("[branch] ENCODE on SERVER")
            # Stage the inputs under the fixed names the server expects. A
            # plain file copy replaces the former shell `cp`, so paths with
            # spaces or shell metacharacters are handled correctly.
            try:
                shutil.copyfile(args.init_img, "init_img.png")
            except shutil.SameFileError:
                pass
            with open("prompt.txt","w") as f: f.write(args.prompt)
            with ssh.open_sftp() as sftp:
                try: sftp.mkdir(REMOTE_UPLOAD)
                except OSError: pass
            # Remove leftovers from earlier runs. Inputs go first so the
            # server cannot start encoding a stale image once the old latent
            # is gone; a stale done-flag would otherwise be taken for this
            # run's result.
            clear_remote([remote("init_img.png"), remote("prompt.txt"), FLAG_ENCODE_DONE]
                         + [remote(name) for name, _ in encode_files])
            with ssh.open_sftp() as sftp:
                # The image goes first and the small prompt last, so both
                # inputs only exist once the image is fully uploaded.
                sftp.put("init_img.png", remote("init_img.png"))
                sftp.put("prompt.txt",   remote("prompt.txt"))
            encoded_on_server = wait_remote(ssh, FLAG_ENCODE_DONE, args.poll_interval, args.remote_timeout, "encode_done.flag")
            if not encoded_on_server:
                if args.no_fallback: raise TimeoutError("ENCODE remote timeout")
                log("[fallback] ENCODE -> LOCAL")
                local_encode(args, loader)
        else:
            log("[branch] ENCODE on LOCAL")
            local_encode(args, loader)

        if torch.cuda.is_available(): torch.cuda.synchronize()
        timer.switch("denoise")

        # ---------------- DENOISE ----------------
        denoised_on_server = False
        if want_server("DENOISE"):
            log("[branch] DENOISE on SERVER")
            if not encoded_on_server:
                # The local encode result is authoritative: withdraw any
                # pending server-side encode inputs and overwrite whatever
                # latent the server may still hold.
                clear_remote([remote("init_img.png"), remote("prompt.txt")])
                push(encode_files)
            clear_remote([FLAG_DENOISE_DONE, remote("denoised_latent.pt")])
            r_put_flag(ssh, FLAG_RUN_DENOISE)
            denoised_on_server = wait_remote(ssh, FLAG_DENOISE_DONE, args.poll_interval, args.remote_timeout, "denoise_done.flag")
            if not denoised_on_server:
                if args.no_fallback: raise TimeoutError("DENOISE remote timeout")
                log("[fallback] DENOISE -> LOCAL")
                if encoded_on_server: pull(encode_files)
                local_denoise(args, loader)
        else:
            log("[branch] DENOISE on LOCAL")
            if encoded_on_server: pull(encode_files)
            local_denoise(args, loader)

        if torch.cuda.is_available(): torch.cuda.synchronize()
        timer.switch("decode")

        # ---------------- DECODE ----------------
        if want_server("DECODE"):
            log("[branch] DECODE on SERVER")
            if not denoised_on_server:
                push(denoise_files)
            # Remember what is already in the output directory so that only
            # an image produced by this request is accepted.
            with ssh.open_sftp() as sftp:
                seen = set(r_list(sftp, REMOTE_OUT))
            r_put_flag(ssh, FLAG_RUN_DECODE)
            t0=time.time(); name=None
            while True:
                with ssh.open_sftp() as sftp:
                    outs = sftp.listdir(REMOTE_OUT)
                pngs = sorted(f for f in outs
                              if f.endswith(".png") and f not in seen
                              and f not in ("recon_encode.png","preview_denoise.png"))
                log(f"[remote] out={outs}")
                if pngs: name=pngs[-1]; break
                if time.time()-t0>args.remote_timeout:
                    if args.no_fallback: raise TimeoutError("DECODE remote timeout")
                    log("[fallback] DECODE -> LOCAL"); break
                time.sleep(args.poll_interval)
            if name:
                r_get(ssh, os.path.join(REMOTE_OUT, name), args.output)
            else:
                if denoised_on_server: pull(denoise_files)
                local_decode(args, loader)
        else:
            log("[branch] DECODE on LOCAL")
            if denoised_on_server: pull(denoise_files)
            local_decode(args, loader)

        if torch.cuda.is_available(): torch.cuda.synchronize()
        timer.stop()

        # ---------------- latency report ----------------
        os.makedirs(CSV_DIR, exist_ok=True)
        existing = [f for f in os.listdir(CSV_DIR) if f.endswith(".csv") and f[:-4].isdigit()]
        next_idx = (max(map(int,[f[:-4] for f in existing]))+1) if existing else 0
        with open(os.path.join(CSV_DIR,f"{next_idx}.csv"),"w",newline="") as f:
            w=csv.writer(f); w.writerow(["stage","delay"])
            w.writerow(["encode",f"{timer.metrics['encode']:.6f}"])
            w.writerow(["denoise",f"{timer.metrics['denoise']:.6f}"])
            w.writerow(["decode",f"{timer.metrics['decode']:.6f}"])
        log("[+] latency csv written")
    finally:
        ssh.close()

# Script entry point: run the client only when this file is executed directly.
if __name__ == "__main__":
    main()


In [None]:
# client.py defines no --n_samples option (it always processes a single image),
# so that flag is not passed; argparse would otherwise abort with "unrecognized arguments".
python3 client.py \
  --prompt "high quality, sharp focus, crisp street details, natural lighting, balanced contrast, accurate colors, clean edges, realistic textures, photorealistic" \
  --init-img /data/images/stable-diffusion/samples/input.png \
  --output /data/images/stable-diffusion/test/output.png \
  --seed $(python3 -c 'import secrets; print(1 + secrets.randbelow(2147483646))') --strength 0.3 --ddim_steps 50
