In [None]:
!pip install diffusers transformers accelerate xformers huggingface_hub
!pip install requests pillow insightface opencv-python apex gradio diffusers onnx onnxruntime-gpu onnxruntime timm \
    SentencePiece git+https://github.com/XPixelGroup/BasicSR ftfy einops facexlib fire
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
!pip install "git+https://github.com/IDEA-Research/GroundingDINO.git"

In [None]:
# Hugging Face authentication + optional hf_transfer acceleration.
# Never hardcode tokens in a notebook: read HF_TOKEN from the environment, or prompt for it
# (getpass hides the input). Env vars are set via os.environ rather than `%env`, which echoes values.
import getpass
import importlib.util
import os

hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face token: ")
os.environ["HF_TOKEN"] = hf_token
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
# With HF_HUB_ENABLE_HF_TRANSFER=1 but the `hf_transfer` package missing, huggingface_hub refuses to
# download, so only enable it when the package is importable. Set it before importing
# huggingface_hub, which reads this variable at import time.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import login

login(token=hf_token, add_to_git_credential=False)

In [None]:
# Pre-fetch the checkpoints that the main cell later loads with from_pretrained() using the same repo IDs,
# so the training cell does not stall on downloads.
# NOTE(review): FLUX.1-dev is presumably a gated repo — the logged-in account must have accepted its license; confirm.
!huggingface-cli download black-forest-labs/FLUX.1-dev
# Body-embedding backbone (loaded via AutoModel with trust_remote_code=True in the main cell).
!huggingface-cli download OpenGVLab/InternViT-300M-448px
# Person detector used by detect_body_bbox_pil (loaded via AutoModelForZeroShotObjectDetection).
!huggingface-cli download IDEA-Research/grounding-dino-base
# NOTE(review): InternVL2-26B is not loaded anywhere in this notebook; a 26B-parameter checkpoint is a very
# large download — skip this line unless you are building the dataset.
!huggingface-cli download OpenGVLab/InternVL2-26B # for dataset creation

In [None]:
"""
Example: Training FLUX.1-dev for text-to-image generation with identity preservation.
All base FLUX weights are frozen; only new face/body modules are trainable.

Requirements:
    pip install diffusers transformers accelerate xformers huggingface_hub
    pip install insightface opencv-python
    pip install timm  # for InternViT
    pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118  # if using CUDA
    pip install requests pillow  # for image loading
    pip install "git+https://github.com/IDEA-Research/GroundingDINO.git"
"""

import os, requests, cv2, torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from PIL import Image
from torch.utils.data import Dataset, DataLoader

from diffusers import DiffusionPipeline
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import insightface
from transformers import AutoModel, AutoFeatureExtractor

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =====================================================================
# 1) LOAD & FREEZE THE FLUX.1-DEV PIPELINE
# =====================================================================
print("Loading and freezing FLUX.1-dev ...")
# NOTE(review): moving the whole bf16 pipeline to one GPU needs a lot of VRAM; consider
# enable_model_cpu_offload() on smaller cards.
flux_pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32
).to(device)

# diffusers' FluxPipeline exposes `transformer`, `vae`, `text_encoder` (CLIP) and `text_encoder_2`
# (T5) — there is no `unet`, so freezing by hard-coded attribute names left the transformer and T5
# trainable. Freeze every nn.Module component generically and switch it to eval mode.
for _component_name, _component in flux_pipeline.components.items():
    if isinstance(_component, nn.Module):
        _component.requires_grad_(False)
        _component.eval()

# =====================================================================
# 2) GROUNDING DINO FOR BODY BOUNDING BOX
#    (Using your snippet-based approach)
# =====================================================================
print("Loading GroundingDINO (IDEA-Research/grounding-dino-base) ...")
g_model_id = "IDEA-Research/grounding-dino-base"
# Processor (image preprocessing + text tokenization) and the detector itself, placed on `device`.
g_processor, g_model = (
    AutoProcessor.from_pretrained(g_model_id),
    AutoModelForZeroShotObjectDetection.from_pretrained(g_model_id).to(device),
)

def detect_body_bbox_pil(pil_image, query_text="a person.", box_threshold=0.4, text_threshold=0.3):
    """
    Detect the largest region matching ``query_text`` with GroundingDINO.

    Args:
        pil_image: PIL image to search.
        query_text: Grounding prompt (GroundingDINO expects lowercase phrases ending
            with a period; default "a person.").
        box_threshold: Minimum box confidence kept by post-processing.
        text_threshold: Minimum text-token confidence kept by post-processing.

    Returns:
        ``(xmin, ymin, xmax, ymax)`` integer pixel coordinates of the largest detected
        box, clamped to the image bounds, or ``None`` if nothing is detected or the
        clamped box has zero area.
    """
    inputs = g_processor(images=pil_image, text=query_text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = g_model(**inputs)

    # Post-process to final boxes in pixel coordinates of the original image.
    results = g_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[pil_image.size[::-1]]  # (height, width)
    )
    if not results or len(results[0]["boxes"]) == 0:
        return None

    # Pick the largest box.
    boxes = results[0]["boxes"]  # (N, 4) in xyxy format
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    xmin, ymin, xmax, ymax = boxes[torch.argmax(areas)].tolist()

    # Predicted boxes can extend past the image edges; clamp them so PIL's crop() does not
    # pad with black pixels, and reject boxes that collapse to nothing.
    width, height = pil_image.size
    xmin = int(min(max(xmin, 0), width))
    xmax = int(min(max(xmax, 0), width))
    ymin = int(min(max(ymin, 0), height))
    ymax = int(min(max(ymax, 0), height))
    if xmax <= xmin or ymax <= ymin:
        return None
    return (xmin, ymin, xmax, ymax)


# =====================================================================
# 3) INTERNVIT FOR BODY EMBEDDINGS
# =====================================================================
print("Loading InternViT-300M-448px ...")
_internvit_repo = "OpenGVLab/InternViT-300M-448px"
# trust_remote_code=True executes modelling code shipped in the Hub repo.
internvit_model = AutoModel.from_pretrained(_internvit_repo, trust_remote_code=True).to(device)
internvit_extractor = AutoFeatureExtractor.from_pretrained(_internvit_repo, trust_remote_code=True)
# Inference-only feature extractor: eval mode and no gradients for any parameter.
internvit_model.eval().requires_grad_(False)

def extract_body_embedding_pil(pil_image, bbox=None):
    """
    Extract an InternViT embedding for a body region of a PIL image.

    Args:
        pil_image: Source PIL image.
        bbox: Optional ``(xmin, ymin, xmax, ymax)`` crop box. If it is ``None`` or has
            zero area, the whole image is used.

    Returns:
        Tensor of shape ``(1, hidden_dim)``: the model's ``pooler_output`` when it is
        set, otherwise the first (CLS) token of ``last_hidden_state``.
    """
    pil_crop = pil_image
    if bbox is not None:
        xmin, ymin, xmax, ymax = bbox
        # A zero-area crop would make the image processor fail; fall back to the full image.
        if xmax > xmin and ymax > ymin:
            pil_crop = pil_image.crop((xmin, ymin, xmax, ymax))

    inputs = internvit_extractor(images=pil_crop, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = internvit_model(**inputs)

    # Model outputs may define a `pooler_output` field yet leave it None, so test the value,
    # not merely the attribute's existence.
    emb = getattr(outputs, "pooler_output", None)
    if emb is None:
        emb = outputs.last_hidden_state[:, 0, :]  # CLS token
    return emb  # shape (1, hidden_dim)


# =====================================================================
# 4) INSIGHTFACE FOR FACE EMBEDDINGS
# =====================================================================
print("Initializing InsightFace ...")
# InsightFace's ctx_id selects the GPU index; -1 means run on the CPU.
if device.type == 'cuda':
    ctx_id = 0
else:
    ctx_id = -1
face_analysis = insightface.app.FaceAnalysis()
face_analysis.prepare(ctx_id=ctx_id, det_size=(640, 640))

def extract_face_embedding_pil(pil_image):
    """
    Return the InsightFace identity embedding of the largest face in a PIL image.

    Args:
        pil_image: PIL image in any mode (converted to RGB first).

    Returns:
        Float32 tensor of shape ``(1, 512)`` on ``device``: the ``normed_embedding`` of
        the face with the largest bounding box, or zeros when no face is detected.
    """
    # InsightFace follows OpenCV's BGR convention. Make the reversed-channel view contiguous,
    # since its negative strides are rejected by some cv2 operations.
    bgr = np.ascontiguousarray(np.asarray(pil_image.convert("RGB"))[:, :, ::-1])
    faces = face_analysis.get(bgr)
    if not faces:
        return torch.zeros((1, 512), device=device)
    # Use the largest face (consistent with detect_body_bbox_pil picking the largest box),
    # rather than whichever detection happens to come first.
    face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
    emb = face.normed_embedding  # (512,)
    return torch.tensor(emb, dtype=torch.float32, device=device).unsqueeze(0)


# =====================================================================
# 5) NEW MODULES: PERCEIVER + CROSS-ATTENTION
# =====================================================================
class CrossAttentionBlock(nn.Module):
    """
    Post-norm transformer block whose query stream attends to a separate context.

    forward(x, context):
        x:       (B, Lq, embed_dim) query tokens
        context: (B, Lk, embed_dim) key/value tokens
        returns  (B, Lq, embed_dim)

    Computation: h = LayerNorm(x + Attn(x, ctx, ctx)); out = LayerNorm(h + MLP(h)),
    where the MLP expands to 4 * embed_dim with a GELU in between.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        hidden_dim = 4 * embed_dim
        # Submodule names/order are kept stable: they form the state_dict keys and init order.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, context):
        attended = self.attn(x, context, context)[0]
        h = self.norm1(x + attended)
        return self.norm2(h + self.mlp(h))

class PerceiverResampler(nn.Module):
    """
    Map one embedding vector per sample to a short sequence of tokens.

    Despite the name, this is not an attention-based Perceiver: a single linear layer
    produces ``num_tokens * out_dim`` features, which are split into ``num_tokens``
    tokens of width ``out_dim``.
    """
    def __init__(self, in_dim, out_dim, num_tokens=8):
        super().__init__()
        self.num_tokens = num_tokens
        self.proj = nn.Linear(in_dim, num_tokens * out_dim)

    def forward(self, x):
        """
        x: (B, in_dim) -> (B, num_tokens, out_dim)
        """
        flat = self.proj(x)
        return flat.reshape(x.size(0), self.num_tokens, -1)


# =====================================================================
# 6) WRAPPER: Freeze FLUX, Insert New Modules
# =====================================================================
class FluxFrozenWrapper(nn.Module):
    """
    Frozen diffusion backbone plus trainable face/body identity modules.

    ``flux_pipe`` is stored as a plain attribute (a diffusers pipeline is not an
    ``nn.Module``), so ``self.parameters()`` yields only the new modules: the face/body
    resamplers, the two cross-attention blocks and the latent in/out projections.

    The backbone's per-position prediction is turned into ``H*W`` tokens, which are
    projected to ``embed_dim``. These tokens attend to the face tokens and then to the
    body tokens. The result is projected back to ``latent_channels`` and added as a
    residual to the frozen prediction. The output projection is zero-initialized, so
    training starts exactly from the backbone's output.

    NOTE(review): FLUX.1-dev ships a ``transformer`` (FluxTransformer2DModel), not a
    ``unet``. Its forward needs packed latents, pooled CLIP projections, T5 hidden
    states, ``img_ids``/``txt_ids`` and guidance, which this simplified example does not
    build. ``forward_unet`` therefore supports only pipelines exposing a UNet-style
    ``unet(latents, t, encoder_hidden_states=...)`` and raises a clear error otherwise.
    """
    def __init__(self, flux_pipe, embed_dim=768, num_heads=8, face_dim=512, body_dim=768, latent_channels=4):
        super().__init__()
        self.flux_pipe = flux_pipe  # already frozen externally
        self.latent_channels = latent_channels

        # Trainable modules
        self.face_resampler = PerceiverResampler(face_dim, embed_dim, num_tokens=8)
        self.body_resampler = PerceiverResampler(body_dim, embed_dim, num_tokens=8)
        self.face_cross_attn = CrossAttentionBlock(embed_dim, num_heads)
        self.body_cross_attn = CrossAttentionBlock(embed_dim, num_heads)
        # Per-position channel <-> attention-width projections (the latent channel count is not embed_dim).
        self.latent_in = nn.Linear(latent_channels, embed_dim)
        self.latent_out = nn.Linear(embed_dim, latent_channels)
        nn.init.zeros_(self.latent_out.weight)
        nn.init.zeros_(self.latent_out.bias)

    def forward_unet(self, latents, t, text_embeddings, face_tokens, body_tokens):
        """
        Run the frozen UNet, then refine its prediction with identity cross-attention.

        Args:
            latents: (B, C, H, W) latents with C == latent_channels.
            t: diffusion timestep(s) accepted by the UNet.
            text_embeddings: (B, seq_len, dim) encoder hidden states for the UNet.
            face_tokens / body_tokens: (B, n_tokens, embed_dim) identity tokens.

        Returns:
            (B, C, H, W) refined prediction.

        Raises:
            NotImplementedError: if the pipeline has no UNet-style ``unet`` (e.g. FLUX).
            ValueError: if the backbone output does not have ``latent_channels`` channels.
        """
        backbone = getattr(self.flux_pipe, "unet", None)
        if backbone is None:
            raise NotImplementedError(
                "flux_pipe has no `unet`. FLUX.1-dev uses a `transformer` (FluxTransformer2DModel) "
                "with a different call signature (packed latents, pooled_projections, txt_ids/img_ids, "
                "guidance); adapt forward_unet to that interface."
            )

        # 1) Frozen backbone forward, in the backbone's own dtype (e.g. bf16).
        backbone_dtype = getattr(backbone, "dtype", latents.dtype)
        base_pred = backbone(
            latents.to(backbone_dtype), t, encoder_hidden_states=text_embeddings.to(backbone_dtype)
        ).sample
        B, C, H, W = base_pred.shape
        if C != self.latent_channels:
            raise ValueError(f"Expected {self.latent_channels} latent channels, got {C}")

        # 2) (B, C, H, W) -> (B, H*W, C): one token per spatial position, projected to embed_dim.
        tokens = base_pred.flatten(2).transpose(1, 2)
        hidden = self.latent_in(tokens.to(self.latent_in.weight.dtype))

        # 3) Face, then body cross-attention.
        hidden = self.face_cross_attn(hidden, face_tokens)
        hidden = self.body_cross_attn(hidden, body_tokens)

        # 4) Back to (B, C, H, W) and add as a residual on the frozen prediction.
        delta = self.latent_out(hidden).transpose(1, 2).reshape(B, C, H, W)
        return base_pred.to(delta.dtype) + delta

    def forward(self, latents, t, text_embeddings, face_emb, body_emb):
        """
        latents: (B, latent_channels, H, W)
        t: (B,) or scalar diffusion timestep
        text_embeddings: (B, seq_len, dim)
        face_emb: (B, face_dim)
        body_emb: (B, body_dim)
        Returns the refined (B, latent_channels, H, W) prediction.
        """
        # Resample face/body embeddings into token sequences.
        face_tokens = self.face_resampler(face_emb)  # (B, 8, embed_dim)
        body_tokens = self.body_resampler(body_emb)  # (B, 8, embed_dim)

        # Forward through the backbone with the extra cross-attention.
        return self.forward_unet(latents, t, text_embeddings, face_tokens, body_tokens)


# =====================================================================
# 7) EXAMPLE DATASET
# =====================================================================
class ExampleCharacterDataset(Dataset):
    """
    Toy dataset pairing reference images with prompts.

    Each item is a dict with:
      - "pil_image": the image at ``image_paths[idx]``, converted to RGB
      - "prompt":    ``prompts[idx]``
      - "latents":   random (4, 64, 64) tensor (placeholder)
      - "timestep":  random integer tensor of shape (1,) in [0, 1000) (placeholder)
    """
    def __init__(self, image_paths, prompts):
        self.image_paths, self.prompts = image_paths, prompts

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        text = self.prompts[idx]
        # Placeholders, drawn latents-first then timestep so the RNG stream is unchanged.
        mock_latents = torch.randn((4, 64, 64))
        mock_timestep = torch.randint(0, 1000, (1,))
        return dict(pil_image=image, prompt=text, latents=mock_latents, timestep=mock_timestep)


# =====================================================================
# 8) TRAINING FUNCTION
# =====================================================================
def diffusion_loss_fn(pred_latents, target_latents):
    """
    Mean squared error between predicted and target latents (placeholder objective).

    A real setup would compare against the scheduler's training target (e.g. the
    added noise or a flow-matching velocity) rather than raw latents.
    """
    squared_error_mean = F.mse_loss(pred_latents, target_latents, reduction="mean")
    return squared_error_mean

def _encode_prompts(flux_pipeline, prompt, batch_size):
    """Encode prompts with the pipeline's frozen `tokenizer`/`text_encoder`, or return random mock embeddings."""
    if hasattr(flux_pipeline, "tokenizer") and hasattr(flux_pipeline, "text_encoder"):
        # truncation=True keeps long prompts inside the 77-token window instead of overflowing it.
        text_in = flux_pipeline.tokenizer(
            prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            text_out = flux_pipeline.text_encoder(**text_in)
        # Use last_hidden_state when the encoder returns a model output; otherwise pass it through.
        return getattr(text_out, "last_hidden_state", text_out)
    # Fallback if the pipeline doesn't have a built-in text encoder: random embeddings.
    return torch.randn((batch_size, 77, 768), device=device)


def _reference_embeddings(pil_images, batch_size):
    """Compute (face, body) identity embeddings of the reference images, stacked over the batch."""
    face_embs, body_embs = [], []
    for b in range(batch_size):
        img = pil_images[b] if isinstance(pil_images, (list, tuple)) else pil_images
        bbox = detect_body_bbox_pil(img, query_text="a person.")
        body_embs.append(extract_body_embedding_pil(img, bbox))
        face_embs.append(extract_face_embedding_pil(img))
    return torch.cat(face_embs, dim=0).to(device), torch.cat(body_embs, dim=0).to(device)


def _decoded_to_pil(image_tensor):
    """Convert one VAE-decoded (3, H, W) tensor, nominally in [-1, 1], to a PIL RGB image."""
    # Standard diffusers denormalization; per-image min-max stretching would distort colours/contrast.
    img = (image_tensor.detach().float() / 2 + 0.5).clamp(0, 1)
    img_np = (img.cpu().numpy() * 255).round().astype(np.uint8)
    return Image.fromarray(np.transpose(img_np, (1, 2, 0)))  # CHW -> HWC


def _identity_metrics(flux_pipeline, pred_latents, face_ref, body_ref):
    """Decode predictions and return the mean face/body embedding MSE vs. the references (no grad)."""
    vae = flux_pipeline.vae
    with torch.no_grad():
        # NOTE(review): real latents would need VAE scaling/shift before decode; these are mock latents.
        decoded = vae.decode(pred_latents.detach().to(vae.dtype)).sample  # (B, 3, H, W)
        face_total, body_total = 0.0, 0.0
        for b in range(decoded.shape[0]):
            pil_decoded = _decoded_to_pil(decoded[b])
            gen_face_emb = extract_face_embedding_pil(pil_decoded)
            face_total += F.mse_loss(gen_face_emb, face_ref[b:b+1]).item()
            dec_bbox = detect_body_bbox_pil(pil_decoded, query_text="a person.")
            gen_body_emb = extract_body_embedding_pil(pil_decoded, dec_bbox)
            body_total += F.mse_loss(gen_body_emb, body_ref[b:b+1]).item()
    n = decoded.shape[0]
    return face_total / n, body_total / n


def train_identity_preservation(
    flux_wrapper,
    flux_pipeline,
    dataloader,
    epochs=1,
    lr=1e-4,
    lambda_face=1.0,
    lambda_body=1.0
):
    """
    Train only the new identity modules of ``flux_wrapper``; the base pipeline stays frozen.

    Args:
        flux_wrapper: Model holding the trainable face/body modules.
        flux_pipeline: The original diffusers pipeline. Every nn.Module component is
            (re)frozen here before training.
        dataloader: Yields dicts with "pil_image" (a PIL image or a list of them),
            "prompt", "latents" and "timestep". PIL images need a collate_fn that keeps
            them in a list, because torch's default_collate rejects them.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        lambda_face, lambda_body: Weights of the face/body identity terms in the
            reported total.

    The reported loss is ``diffusion_loss + lambda_face * face_loss + lambda_body * body_loss``.
    NOTE: the identity terms are computed through a no-grad VAE decode, PIL conversion
    and ``.item()``. They are therefore constants with respect to the trainable
    parameters and contribute no gradient; only ``diffusion_loss`` trains the modules.
    Making them trainable would require a differentiable decode + embedding path.
    """
    # Collect only trainable params (the new modules).
    trainable_params = [p for p in flux_wrapper.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    flux_wrapper.train()
    flux_wrapper.to(device)

    # A diffusers pipeline is not an nn.Module (it has no .named_parameters()), so freeze its
    # module components one by one. This is idempotent, and it keeps autograd from tracking the
    # large frozen backbone.
    for _comp_name, component in flux_pipeline.components.items():
        if isinstance(component, nn.Module):
            component.requires_grad_(False)

    for epoch in range(epochs):
        for step_idx, batch in enumerate(dataloader):
            pil_images = batch["pil_image"]
            prompt = batch["prompt"]
            latents = batch["latents"].to(device, dtype=torch.float32)
            # The dataset yields t with shape (1,) per item -> (B, 1) after collation; models expect (B,).
            t = batch["timestep"].to(device).reshape(-1)
            batch_size = latents.shape[0]

            # 1) Text encoding (frozen)
            text_embeddings = _encode_prompts(flux_pipeline, prompt, batch_size)

            # 2) Reference face & body embeddings
            face_emb_tensor, body_emb_tensor = _reference_embeddings(pil_images, batch_size)

            # 3) Forward pass through the wrapper (only the new modules require grad)
            pred_latents = flux_wrapper(latents, t, text_embeddings, face_emb_tensor, body_emb_tensor)
            # Mock objective: reconstruct the input latents (replace with a noise/velocity target).
            diff_loss = diffusion_loss_fn(pred_latents, latents)

            # 4) Identity metrics on the decoded prediction (monitoring only, see docstring)
            face_loss_val, body_loss_val = _identity_metrics(
                flux_pipeline, pred_latents, face_emb_tensor, body_emb_tensor
            )
            total_loss = diff_loss + lambda_face * face_loss_val + lambda_body * body_loss_val

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if step_idx % 5 == 0:
                print(f"Epoch {epoch} | Step {step_idx} | "
                      f"Diff={diff_loss.item():.4f} | Face={face_loss_val:.4f} | Body={body_loss_val:.4f} | "
                      f"Total={total_loss.item():.4f}")


# =====================================================================
# 9) MAIN / DEMO
# =====================================================================
def main():
    """
    Build the demo dataset, wrap FLUX with the identity modules, and train.

    Raises:
        FileNotFoundError: if any of the example image paths does not exist.
    """
    # Minimal dataset: adapt to your actual data
    image_paths = ["./example1.jpg", "./example2.jpg"]  # images of your character
    prompts = ["Character in a futuristic city", "Character on the beach"]

    missing = [p for p in image_paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f"Example images not found: {missing}. Point image_paths at your character images."
        )

    dataset = ExampleCharacterDataset(image_paths, prompts)

    def collate_keep_pil(items):
        """Stack tensors but keep PIL images and prompts as lists (default_collate rejects PIL images)."""
        return {
            "pil_image": [item["pil_image"] for item in items],
            "prompt": [item["prompt"] for item in items],
            "latents": torch.stack([item["latents"] for item in items]),
            "timestep": torch.stack([item["timestep"] for item in items]),
        }

    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_keep_pil)

    # Wrap FLUX with new modules. body_dim must equal the InternViT embedding width, so read it
    # from the loaded model's config instead of hard-coding 768.
    flux_wrapper = FluxFrozenWrapper(
        flux_pipeline,
        embed_dim=768,
        num_heads=8,
        face_dim=512,
        body_dim=internvit_model.config.hidden_size
    ).to(device)

    # Ensure base FLUX is frozen. A diffusers pipeline has no .named_parameters(), so handle
    # each nn.Module component individually (idempotent).
    for _comp_name, component in flux_pipeline.components.items():
        if isinstance(component, nn.Module):
            component.requires_grad_(False)
    # The only trainable parameters are in flux_wrapper's new modules
    for n, p in flux_wrapper.named_parameters():
        print(f"{n} | requires_grad={p.requires_grad}")

    # Train
    train_identity_preservation(
        flux_wrapper,
        flux_pipeline,
        dataloader,
        epochs=1,
        lr=1e-4,
        lambda_face=1.0,
        lambda_body=1.0
    )

if __name__ == "__main__":
    main()
