In [5]:
# Imports and init remote
import os
import subprocess
from union import task, workflow, FlyteFile, UnionRemote, ImageSpec, Resources, FlyteDirectory, ActorEnvironment, LaunchPlan
from union.remote import HuggingFaceModelInfo
from flytekit.configuration import Config

# Union serving config. Honour a UNION_CONFIG already present in the environment;
# otherwise fall back to the per-user default location instead of a hard-coded home
# directory. The env var is also set so other union tooling in this kernel
# (presumably) picks up the same file — confirm.
UNION_CONFIG_PATH = os.environ.setdefault(
    "UNION_CONFIG", os.path.expanduser("~/.union/config_serving.yaml")
)

remote = UnionRemote(config=Config.auto(config_file=UNION_CONFIG_PATH))

In [None]:
# Cache the Boltz-1 weights from Hugging Face as a Union Artifact.
# NOTE(review): `_create_model_from_hf` is a private UnionRemote API and may change
# between releases. "HF_TOKEN" / "UNION_API_KEY" are presumably the names of the
# secrets holding the credentials — confirm.
info = HuggingFaceModelInfo(repo="boltz-community/boltz-1")

cache_exec = remote._create_model_from_hf(
    info=info,
    hf_token_key="HF_TOKEN",
    union_api_key="UNION_API_KEY",
).wait(poll_interval=2)  # block until the caching execution has finished
cache_exec.outputs

In [6]:
# Container image shared by the actor task and the FastAPI app below.
# - Python deps are pinned except `union`.
#   NOTE(review): consider pinning it too, for reproducible builds.
# - boltz is installed by an explicit `uv pip install` command rather than via
#   `packages`; presumably to control resolution order — confirm.
# - build-essential is presumably needed to compile native wheels.
image = ImageSpec(
    name="boltz",
    builder="union",
    apt_packages=["build-essential"],
    packages=[
        "union",
        "flytekit==1.15",
        "union-runtime==0.1.11",
        "fastapi==0.115.11",
        "pydantic==2.10.6",
        "uvicorn==0.34.0",
        "python-multipart==0.0.20",
    ],
    commands=["uv pip install boltz==0.4.1"],
)

In [7]:
# GPU-backed actor environment used by `simple_predict`.
# NOTE(review): ttl_seconds is presumably how long an idle replica is kept warm
# for reuse — confirm in the Union actor docs.
actor = ActorEnvironment(
    name="boltz-actor",
    container_image=image,
    replica_count=1,
    ttl_seconds=600,
    requests=Resources(cpu="2", mem="10Gi", gpu="1"),
)

In [8]:

@actor.task
def simple_predict(input: FlyteFile) -> FlyteDirectory:
    """Run ``boltz predict`` on one input spec and return the results directory.

    Args:
        input: Boltz input file (e.g. a YAML spec). The name shadows the builtin but
            is kept because it is the task's public input name.

    Returns:
        FlyteDirectory wrapping a fresh, per-invocation output directory.

    Raises:
        subprocess.CalledProcessError: if boltz exits non-zero. The task then fails
            instead of silently returning an empty or partial directory.
    """
    import tempfile

    input.download()
    # Actor replicas are long-lived (see ttl_seconds) and may serve several
    # invocations. A fixed /tmp/boltz_out would therefore return results from
    # earlier runs, and boltz may skip inputs whose predictions already exist.
    # NOTE(review): these per-call dirs are not removed while the replica lives.
    out = tempfile.mkdtemp(prefix="boltz_out_")
    subprocess.run(
        ["boltz", "predict", input.path, "--out_dir", out, "--use_msa_server"],
        check=True,
    )
    return FlyteDirectory(path=out)


# Submit the actor task remotely and block until it completes. A relative local path
# is given for the FlyteFile input; presumably the client uploads it — confirm.
predict_inputs = {"input": "inputs/prot_no_msa.yaml"}

execution = remote.execute(
    entity=simple_predict,
    inputs=predict_inputs,
    wait=True,
)
output = execution.outputs
print(output)

Image boltz:6T2f1KBgMSk3wyyuSsFnxA was not found or has expired.
🐳 Submitting a new build...


👍 Build submitted!
⏳ Waiting for build to finish at: https://serving-mvp.us-west-2.union.ai/console/projects/system/domains/production/executions/a4829gmhkrtj47crwf6c


In [None]:
%%writefile boltz_fastapi.py
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
import shutil
import os
from typing import Optional, Dict, Any
from boltz.main import predict  # Ensure you are importing the correct function
from click.testing import CliRunner
import io
from pathlib import Path
import traceback
import tempfile
import subprocess
import asyncio

app = FastAPI()

# Opt-in CPU mode: set USE_CPU_ONLY=1 to add `--accelerator cpu` to boltz calls.
USE_CPU_ONLY = os.getenv("USE_CPU_ONLY", "0") == "1"

def package_outputs(output_dir: str) -> bytes:
    """Pack ``output_dir`` into an in-memory gzip-compressed tar archive.

    The directory is stored under its own base name. For example,
    ``/tmp/x/results`` is archived as ``results/...``, so extracting the archive
    recreates a single top-level folder.

    Args:
        output_dir: Path of the directory to archive.

    Returns:
        The raw ``.tar.gz`` bytes.

    Raises:
        OSError: e.g. FileNotFoundError if ``output_dir`` does not exist.
    """
    import io
    import tarfile

    tar_buffer = io.BytesIO()
    # `arcname` yields the same member names as chdir-ing into the parent did,
    # without mutating the process-wide working directory. That mutation is
    # unsafe in a server that handles concurrent requests.
    with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar:
        tar.add(output_dir, arcname=Path(output_dir).name)
    return tar_buffer.getvalue()

async def generate_response(process, out_dir, yaml_path):
    """Stream the gzipped boltz results of a running ``boltz predict`` process.

    While the subprocess runs, an empty chunk is yielded every 10 s. Once the
    subprocess exits successfully, the results directory is tarred and yielded
    as a single chunk.

    Args:
        process: ``asyncio.subprocess.Process`` started with stdout/stderr PIPEs.
        out_dir: Directory that was passed to ``boltz predict --out_dir``.
        yaml_path: Input file path. Its stem names the ``boltz_results_<stem>``
            sub-directory.

    Yields:
        bytes: empty chunks while waiting, then the ``.tar.gz`` payload. On
        failure, a JSON ``{"error": ...}`` body is yielded instead.
        NOTE(review): the 200 status and gzip media type have already been sent at
        that point, so clients cannot tell errors apart by status code.
    """
    # Run communicate() once, as a task. asyncio.wait (unlike wait_for) does not
    # cancel it on timeout, so stdout/stderr read so far are not discarded. This
    # also avoids the TimeoutError vs asyncio.TimeoutError mismatch on Python < 3.11.
    communicate_task = asyncio.ensure_future(process.communicate())
    try:
        while True:
            done, _ = await asyncio.wait({communicate_task}, timeout=10.0)
            if done:
                break
            # Empty bytes (not a null character): anything else would corrupt the
            # gzip payload. NOTE(review): zero-length chunks are probably not
            # written to the wire (h11 drops them), so this may not act as a
            # keep-alive — confirm.
            yield b""

        stdout, stderr = communicate_task.result()
        if process.returncode != 0:
            raise Exception(stderr.decode())

        print(stdout.decode())

        # Package the output directory
        tar_data = package_outputs(f"{out_dir}/boltz_results_{Path(yaml_path).with_suffix('').name}")
        yield tar_data

    except Exception as e:
        traceback.print_exc()
        yield JSONResponse(status_code=500, content={"error": str(e)}).body
    finally:
        # On client disconnect (cancellation) don't leave boltz running on the GPU.
        if not communicate_task.done():
            communicate_task.cancel()
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass

def _safe_upload_name(filename: Optional[str], default: str) -> str:
    """Reduce a client-supplied upload filename to a bare base name safe to join."""
    name = os.path.basename(filename or "")
    return default if name in ("", ".", "..") else name


@app.post("/predict/")
async def predict_endpoint(
    yaml_file: UploadFile = File(...),
    msa_dir: Optional[UploadFile] = File(None),
    options: Optional[Dict[str, str]] = Form(None)
):
    """Run ``boltz predict`` on an uploaded spec and stream the results as ``.tar.gz``.

    Args:
        yaml_file: Boltz input YAML. Its (sanitised) file name is kept because the
            results directory is ``boltz_results_<stem>``.
        msa_dir: Optional extra upload, saved into the request's scratch dir.
            TODO(review): it is never passed to boltz — wire it up or remove it.
        options: Extra ``--key=value`` flags appended to the boltz command line.
            NOTE(review): keys/values are not validated, so a client can override
            any boltz flag (e.g. ``out_dir``). Also confirm that FastAPI can parse
            a ``Dict`` from a multipart form field.

    Returns:
        One of two responses:
        - A ``StreamingResponse`` (``application/gzip``) driven by
          ``generate_response``.
        - A 500 ``JSONResponse`` if the uploads cannot be saved or boltz cannot
          be started.
    """
    # Use one scratch dir per request, so concurrent uploads with the same
    # filename don't clobber each other. The dir must outlive this function,
    # because the body is streamed after we return. It is therefore removed by a
    # background task once the response has been sent. A
    # `with TemporaryDirectory()` block deleted it before boltz produced anything.
    work_dir = tempfile.mkdtemp(prefix="boltz_request_")
    try:
        # Client-supplied names are untrusted: strip any directory part so "../"
        # cannot escape work_dir.
        yaml_path = os.path.join(work_dir, _safe_upload_name(yaml_file.filename, "input.yaml"))
        with open(yaml_path, "wb") as buffer:
            shutil.copyfileobj(yaml_file.file, buffer)

        msa_dir_path = None
        if msa_dir and msa_dir.filename:
            # Save as a *file* inside its own sub-dir. Creating a directory at the
            # target path and then opening it for writing always raised
            # IsADirectoryError.
            msa_parent = os.path.join(work_dir, "msa")
            os.makedirs(msa_parent, exist_ok=True)
            msa_dir_path = os.path.join(msa_parent, _safe_upload_name(msa_dir.filename, "msa"))
            with open(msa_dir_path, "wb") as buffer:
                shutil.copyfileobj(msa_dir.file, buffer)

        out_dir = os.path.join(work_dir, "out")
        os.makedirs(out_dir, exist_ok=True)

        print(f"Running predictions with options: {options} into directory: {out_dir}")
        # Build an argv list (no shell), so option values cannot inject shell syntax.
        options_list = [f"--{key}={value}" for key, value in (options or {}).items()]
        command = (
            ["boltz", "predict", yaml_path, "--out_dir", out_dir, "--use_msa_server"]
            + (["--accelerator", "cpu"] if USE_CPU_ONLY else [])
            + options_list
        )
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
    except Exception as e:
        shutil.rmtree(work_dir, ignore_errors=True)
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Runs after the streamed body has been fully sent, or after the client disconnects.
    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, work_dir, ignore_errors=True)
    return StreamingResponse(
        generate_response(process, out_dir, yaml_path),
        media_type="application/gzip",
        headers={"Content-Disposition": "attachment; filename=boltz_results.tar.gz"},
        background=cleanup,
    )


In [None]:
from datetime import timedelta

from union import Resources, ImageSpec
from union.app import App, ScalingMetric, Input
from union import Artifact
from flytekit.extras.accelerators import L4, GPUAccelerator

# Serve boltz_fastapi.py (written by the %%writefile cell) with uvicorn on an L40S GPU.
boltz_fastapi = App(
    name="boltz-fastapi-notebook",
    container_image=image,
    include=["./boltz_fastapi.py"],
    # `port` must match the `--port` passed to uvicorn.
    port=8080,
    args=["uvicorn", "boltz_fastapi:app", "--host", "0.0.0.0", "--port", "8080"],
    accelerator=GPUAccelerator("nvidia-l40s"),
    limits=Resources(cpu="2", mem="10Gi", gpu="1", ephemeral_storage="50Gi"),
    env={
        # NOTE(review): MPS is Apple-silicon only; likely a no-op in this GPU
        # container. For CPU-only runs, boltz_fastapi.py reads USE_CPU_ONLY=1.
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
    },
    min_replicas=1,
    max_replicas=3,
    scaledown_after=timedelta(minutes=10),
    scaling_metric=ScalingMetric.RequestRate(1),
)

In [None]:
from union.remote._app_remote import AppRemote

# Deploy into the default project / development domain. The last expression
# displays whatever deploy() returns.
# NOTE(review): AppRemote comes from the private `union.remote._app_remote` module.
app_remote = AppRemote(
    union_remote=remote,
    default_project="default",
    default_domain="development",
)
app_remote.deploy(boltz_fastapi)