In [None]:
import os

# Runtime tuning for Apple Silicon (values chosen for an M4 Pro).
#
# These variables are read while PyTorch and its OpenMP/MKL runtimes
# initialise, so they must be exported BEFORE `import torch`; assigning them
# after the import can be silently ignored. If torch was already imported in
# this kernel, restart the kernel before running this cell.
os.environ.update(
    {
        # NOTE(review): the next three names are not among the MPS environment
        # variables documented by PyTorch ("expandable_segments" is a CUDA
        # allocator option) - confirm they have any effect, or replace them
        # with a documented knob such as PYTORCH_MPS_HIGH_WATERMARK_RATIO.
        "PYTORCH_MPS_MEMORY_FRACTION": "0.95",
        "PYTORCH_MPS_ALLOCATOR_POLICY": "expandable_segments",
        "PYTORCH_MPS_PREFER_FAST_ALLOC": "1",
        # Let operators without an MPS implementation fall back to the CPU.
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "OMP_NUM_THREADS": "12",  # 8P + 4E cores
        "MKL_NUM_THREADS": "12",
    }
)

# Imported after the environment setup on purpose (see the comment above).
import torch  # noqa: E402
from src.flux.util import load_ae, load_clip, load_t5, load_flow_model  # noqa: E402
from src.flux.pipeline import Sampler  # noqa: E402

# Pick the best available backend: MPS (Metal Performance Shaders) on Apple
# Silicon, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
    print("✅ Using Metal Performance Shaders (MPS) for Apple Silicon")
elif torch.cuda.is_available():
    device = "cuda"
    print("✅ Using CUDA")
else:
    device = "cpu"
    print("⚠️  Using CPU")

print(f"🚀 Loading FLUX.1 Krea models on {device}...")

# The flow model is explicitly loaded onto the CPU first.
# NOTE(review): the other three loaders are called with their default device;
# confirm those defaults are CPU-safe on machines without CUDA.
model = load_flow_model("flux-krea-dev", device="cpu")
ae = load_ae("flux-krea-dev")
clip = load_clip()
t5 = load_t5()

# Move every network to the selected device and cast it to bfloat16.
print("📦 Moving models to device...")
ae = ae.to(device=device, dtype=torch.bfloat16)
clip = clip.to(device=device, dtype=torch.bfloat16)
t5 = t5.to(device=device, dtype=torch.bfloat16)
model = model.to(device, dtype=torch.bfloat16)

print("✅ Models loaded and optimized for M4 Pro!")

In [None]:
# Hand the four loaded networks to the Sampler, together with the device and
# dtype they were moved to in the previous cell.
pipeline_parts = dict(model=model, ae=ae, clip=clip, t5=t5)
sampler = Sampler(**pipeline_parts, device=device, dtype=torch.bfloat16)

In [None]:
# On Apple Silicon, clear the MPS cache and synchronize before sampling.
on_mps = device == "mps"
if on_mps:
    torch.mps.empty_cache()
    torch.mps.synchronize()

# Sampling settings for this run; the fixed seed is passed through to the sampler.
generation_settings = {
    "prompt": "a cute cat",
    "width": 1024,
    "height": 1024,
    "guidance": 4.5,
    "num_steps": 28,
    "seed": 42,
}
image = sampler(**generation_settings)

# Report where the run happened and, on CUDA only, the allocated memory.
print(f"🚀 Generation complete on {device}")
if device == "cuda":
    allocated_gib = torch.cuda.memory_allocated() / 1024**3
    print(f"💾 Memory usage: {allocated_gib:.1f} GB")
else:
    print("💾 Using unified memory")

# Leave the image as the cell's last expression so the notebook renders it.
image