<a href="https://colab.research.google.com/github/ariG23498/custom-inference-endpoint/blob/main/flux.2-with-remote-text-encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installation and Setup

In [None]:
!pip install --upgrade -qq git+https://github.com/huggingface/diffusers
!pip install --upgrade -qq bitsandbytes

In [None]:
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from diffusers import BitsAndBytesConfig as DiffBitsAndBytesConfig
from huggingface_hub import get_token
import requests
import torch
import io

In [None]:
# torch is already imported in the preceding cell; importing it again is a no-op.
import diffusers
import torch

# Report the installed library versions. The `=` f-string specifier prints the
# expression text itself followed by the repr of its value.
print(f"{torch.__version__=}")
print(f"{diffusers.__version__=}")

In [None]:
# Fail early with an actionable message instead of an obscure backend error:
# the cells below place tensors and the random generator on "cuda".
if not torch.cuda.is_available():
    raise RuntimeError(
        "No CUDA GPU detected. Select a GPU runtime and re-run this notebook."
    )

# Pass the device explicitly: some older PyTorch releases require the `device`
# argument of `get_device_properties` (it only became optional later).
_gpu_index = torch.cuda.current_device()
print(f"Using GPU: {torch.cuda.get_device_name(_gpu_index)}")
# total_memory is reported in bytes; integer-divide by 1024**3 for whole gigabytes.
print(f"Total VRAM: {torch.cuda.get_device_properties(_gpu_index).total_memory // 1024**3} GBs")

## Run Inference

In [None]:
# Hub repositories: the base FLUX.2 pipeline and a separately hosted
# transformer checkpoint (its repo name indicates a bitsandbytes 4-bit variant).
repo_id = "black-forest-labs/FLUX.2-dev"
quantized_dit_id = "diffusers/FLUX.2-dev-bnb-4bit"

# Load the transformer on its own so it can be injected into the pipeline below.
dit = Flux2Transformer2DModel.from_pretrained(
    quantized_dit_id,
    subfolder="transformer",
    device_map="cpu",
    torch_dtype=torch.bfloat16,
)

# `text_encoder=None`: the pipeline is built without a text encoder; prompt
# embeddings are obtained from `remote_text_encoder` and passed in directly.
pipe = Flux2Pipeline.from_pretrained(
    repo_id,
    transformer=dit,
    text_encoder=None,
    torch_dtype=torch.bfloat16,
)

# Let diffusers manage moving pipeline components between CPU and GPU.
pipe.enable_model_cpu_offload()

In [None]:
def remote_text_encoder(
    prompts: str | list[str],
    *,
    endpoint_url: str = "https://rhknk53jznw37un7.us-east-1.aws.endpoints.huggingface.cloud/predict",
    device: str | torch.device = "cuda",
    timeout: float = 300.0,
) -> torch.Tensor:
    """Encode prompts on a remote text-encoder endpoint and return the embeddings.

    The prompts are POSTed as JSON (``{"prompt": prompts}``) and the response
    body is deserialized with ``torch.load``.

    Args:
        prompts: A single prompt or a list of prompts to encode.
        endpoint_url: URL of the prediction endpoint. Defaults to the endpoint
            this notebook was written against.
        device: Device the returned embeddings are moved to.
        timeout: Seconds to wait for the endpoint before giving up. ``requests``
            applies no timeout by default, so without this a stalled endpoint
            would block the notebook indefinitely.

    Returns:
        The deserialized embeddings moved to ``device``. The payload is
        expected to be a tensor; its shape and dtype are decided by the endpoint.

    Raises:
        requests.HTTPError: If the endpoint answers with a status other than 200.
        requests.RequestException: On connection errors or timeouts.
    """
    headers = {"Content-Type": "application/json"}
    # Only send credentials when a token is actually configured; formatting a
    # missing token would otherwise produce the literal header "Bearer None".
    token = get_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    response = requests.post(
        endpoint_url,
        json={"prompt": prompts},
        headers=headers,
        timeout=timeout,
    )

    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would let an error body reach `torch.load`.
    if response.status_code != 200:
        raise requests.HTTPError(
            f"Remote text encoder request failed: {response.status_code=}, "
            f"body={response.text[:500]!r}",
            response=response,
        )

    # SECURITY: `torch.load` is pickle-based and the bytes come from the
    # network. `weights_only=True` restricts unpickling to tensors and plain
    # containers; only point `endpoint_url` at an endpoint you trust.
    prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
    return prompt_embeds.to(device)

# Text to render; the literal is split across lines only for readability and
# concatenates to the exact same single string.
prompt = (
    "a photo of a forest with mist swirling around the tree trunks. "
    "The word 'FLUX.2 in diffusers' is painted over it in big, "
    "red brush strokes with visible texture"
)

# The encoder is called with a one-element list of prompts.
print("Running remote text encoder ☁️")
prompt_embeds = remote_text_encoder([prompt])
print("Done ✅")

In [None]:
# Run the pipeline on the precomputed prompt embeddings (no raw prompt is
# passed) using a CUDA generator seeded with 42.
out = pipe(
    prompt_embeds=prompt_embeds,
    height=512,
    width=512,
    guidance_scale=4,
    num_inference_steps=50,  # 28 is a good trade-off
    generator=torch.Generator(device="cuda").manual_seed(42),
)

# Last expression of the cell, so the notebook displays the first image.
out.images[0]