In [None]:
# Prediction interface for Cog
from cog import BasePredictor, Input, Path
import os
import math
import torch
import subprocess
from PIL import Image, ImageFilter
from typing import List
from dotenv import load_dotenv
from huggingface_hub import login
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    PNDMScheduler,
    FluxPriorReduxPipeline,
    FluxFillPipeline,
)

from script.download_weights import download_weights


MODEL_NAME_FILL = "black-forest-labs/FLUX.1-Fill-dev"
MODEL_NAME_REDUX = "black-forest-labs/FLUX.1-Redux-dev"
MODEL_CACHE = "checkpoints"
# https://github.com/replicate/cog-flux/blob/main/weights.py#L208-L224
MODELS_URL_FILL = "https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev/resolve/main/flux1-fill-dev.safetensors"
MODELS_URL_REDUX = "https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev/resolve/main/flux1-redux-dev.safetensors"

SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
    "K_EULER": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}


def login_huggingface():
    load_dotenv()
    login(token=os.environ["HUGGINGFACE_TOKEN"])


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        print("Downloading weights")
        login_huggingface()
        if not os.path.exists(MODEL_CACHE):
            download_weights()
        print("Loading Flux Prior Redux")
        self.pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Redux-dev",
            torch_dtype=torch.bfloat16,
            cache_dir=MODEL_CACHE,
        ).to("cuda")
        print("Loading Flux Fill")
        self.pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev",
            torch_dtype=torch.bfloat16,
            cache_dir=MODEL_CACHE,
        ).to("cuda")

    def scale_down_image(self, image_path: Path, max_size: int) -> Image.Image:
        image = Image.open(image_path)
        width, height = image.size
        scaling_factor = min(max_size / width, max_size / height)
        new_width = int(width * scaling_factor)
        new_height = int(height * scaling_factor)
        resized_image = image.resize((new_width, new_height))
        cropped_image = self.crop_center(resized_image)
        return cropped_image

    def crop_center(self, pil_img):
        img_width, img_height = pil_img.size
        crop_width = self.base(img_width)
        crop_height = self.base(img_height)
        return pil_img.crop(
            (
                (img_width - crop_width) // 2,
                (img_height - crop_height) // 2,
                (img_width + crop_width) // 2,
                (img_height + crop_height) // 2,
            )
        )

    def base(self, x):
        return int(8 * math.floor(int(x) / 8))

    def predict(
        self,
        image: Path = Input(description="Input image"),
        mask: Path = Input(
            description="Mask image - make sure it's the same size as the input image"
        ),
        reference_image: Path = Input(
            description="Reference image - image to encode as input for Flux.1 Redux"
        ),
        prompt: str = Input(
            description="Input prompt",
            default="cartoon of a black woman laughing, digital art",
        ),
        scheduler: str = Input(
            description="scheduler",
            choices=list(SCHEDULERS.keys()),
            default="K_EULER",
        ),
        guidance_scale: float = Input(
            description="Guidance scale", ge=0, le=10, default=8.0
        ),
        steps: int = Input(
            description="Number of denoising steps", ge=1, le=80, default=20
        ),
        strength: float = Input(
            description="1.0 corresponds to full destruction of information in image",
            ge=0.01,
            le=1.0,
            default=0.7,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
        num_outputs: int = Input(
            description="Number of images to output. Higher number of outputs may OOM.",
            ge=1,
            le=4,
            default=1,
        ),
        blur_radius: int = Input(
            description="Standard deviation of the Gaussian kernel for the mask. Higher values will blur the mask more.",
            ge=0,
            le=128,
            default=16,
        ),
        prompt_embeds_scale: float = Input(
            description="Strength of prompt embeddings on Flux Redux",
            ge=0.01,
            le=2.0,
            default=1.0,
        ),
        pooled_prompt_embeds_scale: float = Input(
            description="Strength of pooled prompt embeddings on Flux Redux",
            ge=0.01,
            le=2.0,
            default=1.0,
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        # Configure Seed
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        generator = torch.Generator("cuda").manual_seed(seed)

        # Configure Scheduler
        self.pipe.scheduler = SCHEDULERS[scheduler].from_config(
            self.pipe.scheduler.config
        )

        # Configure Input Image
        input_image = self.scale_down_image(image, 1024)

        # Configure Mask Image
        pil_mask = Image.open(mask)
        mask_image = pil_mask.resize((input_image.width, input_image.height))
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(blur_radius))

        # Run Flux Prior Redux
        pipe_prior_output = self.pipe_prior_redux(
            image=reference_image,
            prompt=[prompt] * num_outputs if prompt is not None else None,
            prompt_embeds_scale=prompt_embeds_scale,
            pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
        )
        prompt_embeds = pipe_prior_output["prompt_embeds"]
        pooled_prompt_embeds = pipe_prior_output["pooled_prompt_embeds"]

        # Run Flux Fill
        result = self.pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            image=input_image,
            mask_image=mask_image,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            strength=strength,
            generator=generator,
            width=input_image.width,
            height=input_image.height,
        )

        # Save Output Images
        output_paths = []
        for i, output in enumerate(result.images):
            output_path = f"/tmp/out-{i}.png"
            output.save(output_path)
            output_paths.append(Path(output_path))

        return output_paths


In [1]:
from cog.types import Path

# override MODEL_CACHE to point to the correct directory
MODEL_CACHE = "../checkpoints"

# Initialize predictor
predictor = Predictor()
predictor.setup()

# Define paths for required images
input_image = Path("images/cartoon-man-laughing.png")
mask_image = Path("images/mask.png")  # You'll need to create a mask image
reference_image = Path("images/cartoon-man-laughing.png")  # Using same image as reference

# Run prediction with default parameters
results = predictor.predict(
    image=input_image,
    mask=mask_image,
    reference_image=reference_image,
    prompt="cartoon of a black woman laughing, digital art",
    num_outputs=1
)

display(results[0])

Downloading weights


model_index.json:   0%|          | 0.00/295 [00:00<?, ?B/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/857M [00:00<?, ?B/s]

image_embedder/config.json:   0%|          | 0.00/128 [00:00<?, ?B/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

image_encoder/config.json:   0%|          | 0.00/419 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/129M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

model_index.json:   0%|          | 0.00/540 [00:00<?, ?B/s]

Fetching 23 files:   0%|          | 0/23 [00:00<?, ?it/s]

scheduler/scheduler_config.json:   0%|          | 0.00/299 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/561 [00:00<?, ?B/s]

text_encoder_2/config.json:   0%|          | 0.00/741 [00:00<?, ?B/s]

(…)t_encoder_2/model.safetensors.index.json:   0%|          | 0.00/19.9k [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.53G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/588 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/705 [00:00<?, ?B/s]

tokenizer_2/special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

tokenizer_2/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

(…)pytorch_model-00001-of-00003.safetensors:   0%|          | 0.00/9.99G [00:00<?, ?B/s]

tokenizer_2/tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

transformer/config.json:   0%|          | 0.00/393 [00:00<?, ?B/s]

(…)pytorch_model-00002-of-00003.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

(…)pytorch_model-00003-of-00003.safetensors:   0%|          | 0.00/3.87G [00:00<?, ?B/s]

(…)ion_pytorch_model.safetensors.index.json:   0%|          | 0.00/121k [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/774 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

In [None]:
import os

os.environ["HUGGINGFACE_TOKEN"]