In [None]:
import os
HOME = os.getcwd()
print(HOME)

!export MAKEFLAGS="-j$(nproc)"
!pip install diffusers transformers accelerate xformers huggingface_hub[hf_transfer] hf_transfer \
    pillow insightface opencv-python apex gradio onnxruntime-gpu timm pickleshare \
    SentencePiece ftfy einops facexlib fire onnx onnxruntime-gpu
!pip show basicsr || pip install git+https://github.com/XPixelGroup/BasicSR
!pip install flash-attn

In [None]:
%%capture
from huggingface_hub import login
import base64
k = base64.b64decode('aGZfaHZqck9VTXFvTXF3dW9HR3JoTlZKSWlsZUtFTlNQbXRjTw==').decode()
login(token=k, add_to_git_credential=False)
%env HUGGINGFACEHUB_API_TOKEN={k}
%env HF_TOKEN={k}
%env HF_HUB_ENABLE_HF_TRANSFER=1

In [None]:
# Optional one-time system setup for the upload cell below: installs rclone (invoked by
# `upload_to_drive`) and FUSE. Uncomment and run once; the `drive:` remote it copies to must
# also be configured in rclone beforehand.
# !sudo -v ; curl https://rclone.org/install.sh | sudo bash
# !sudo apt-get update
# !apt-get update && apt-get install -y fuse3 -y

In [None]:
import asyncio
import json
import os
import subprocess

import nest_asyncio

# Jupyter already runs an asyncio event loop, and asyncio.run() refuses to start inside a
# running loop; nest_asyncio patches asyncio so the asyncio.run(main()) call in this cell works.
nest_asyncio.apply()


def upload_to_drive(dataset_dir="dataset_creation/data/dataset", remote_base="drive:dataset"):
    """Archive every numbered bundle directory and copy it to an rclone remote.

    Each immediate sub-directory of ``dataset_dir`` whose name is all digits is packed into an
    uncompressed ``<bundle>.tar`` and copied to ``remote_base`` with ``rclone copy``. Every call
    re-archives and re-uploads *all* bundles currently present.

    Args:
        dataset_dir: Local directory containing the numbered bundle directories.
        remote_base: rclone destination in ``remote:path`` form.

    Raises:
        subprocess.CalledProcessError: if ``tar`` or ``rclone`` exits non-zero.
        FileNotFoundError: if ``dataset_dir`` does not exist or tar/rclone is not installed.
    """
    import tempfile

    bundles = sorted(
        name for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name)) and name.isdigit()
    )
    for bundle in bundles:
        # A private temp dir per bundle (on the same filesystem as before, under the cwd) means
        # overlapping calls never share an archive path, and the archive is removed even when
        # tar or rclone fails.
        with tempfile.TemporaryDirectory(dir=".") as tmp_dir:
            archive_path = os.path.join(tmp_dir, f"{bundle}.tar")  # uncompressed archive
            subprocess.run(["tar", "-cf", archive_path, "-C", dataset_dir, bundle], check=True)
            subprocess.run([
                "rclone", "copy", archive_path, remote_base,
                "--transfers=32", "--checkers=32", "--fast-list", "--progress"
            ], check=True)
        print(f"Processed and uploaded bundle {bundle}")


async def main():
    """Process queued movies with add_movie.py, persisting progress and uploading in the background.

    Reads the queue from ./dataset_creation/movies.json and the done-list from
    ./dataset_creation/processed.json (starting empty if missing). For each movie it:
      1. runs ``python3 add_movie.py --movie_name <movie>`` inside ``dataset_creation``;
      2. rewrites both JSON files so an interrupted run can resume;
      3. starts ``upload_to_drive()`` in a worker thread while the next movie is processed.

    At most one upload runs at a time. ``upload_to_drive`` archives every bundle on each call, so
    overlapping calls would duplicate work and race on the same bundles.
    """
    movies_path = "./dataset_creation/movies.json"
    processed_path = "./dataset_creation/processed.json"

    with open(movies_path, "r") as f:
        movies = json.load(f)
    if os.path.exists(processed_path):
        with open(processed_path, "r") as f:
            processed = json.load(f)
    else:
        processed = []

    upload_task = None

    while movies:
        movie = movies.pop(0)
        # subprocess.run blocks; running it in a worker thread keeps the event loop free meanwhile.
        result = await asyncio.to_thread(
            subprocess.run,
            ["python3", "add_movie.py", "--movie_name", movie],
            cwd="dataset_creation",
        )
        if result.returncode == 0:
            print(f"Successfully processed {movie}")
            processed.append(movie)
        else:
            # NOTE(review): a failed movie is removed from movies.json and not recorded elsewhere.
            print(f"Processing error ({result.returncode}) for {movie}")

        with open(movies_path, "w") as f:
            json.dump(movies, f, indent=4)
        with open(processed_path, "w") as f:
            json.dump(processed, f, indent=4)

        # Wait for the previous background upload before starting another one.
        if upload_task is not None:
            try:
                await upload_task
            except Exception as exc:
                # The next upload_to_drive() call re-uploads every bundle, so this is retried.
                print(f"Background upload failed ({exc!r}); will retry with the next upload")
        upload_task = asyncio.create_task(asyncio.to_thread(upload_to_drive))

    # Let the final upload finish; a failure here propagates to the caller.
    if upload_task is not None:
        await upload_task

if __name__ == "__main__":
    asyncio.run(main())

In [None]:
# Pre-fetch the model repos used by the training cell below into the local Hugging Face cache,
# so the later from_pretrained() calls for these repo ids load from disk.
!huggingface-cli download black-forest-labs/FLUX.1-dev
!huggingface-cli download OpenGVLab/InternViT-300M-448px-V2_5
!huggingface-cli download IDEA-Research/grounding-dino-tiny
# !huggingface-cli download OpenGVLab/InternVL2-26B # for dataset creation

In [None]:
#!/usr/bin/env python
# coding: utf-8

"""
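Example: training FLUX.1-dev for text-to-image generation with identity preservation.
All base FLUX weights are frozen; only the new face/body modules are trainable.
Note: the dataset latents, timesteps and diffusion loss below are mock placeholders, so this
demonstrates how the modules are wired together rather than a working training recipe.

After training, we also show how to run a "normal" text-to-image generation with FluxPipeline.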
"""

import os, requests, cv2, torch
import insightface
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from PIL import Image
from torch.utils.data import Dataset, DataLoader

# Instead of DiffusionPipeline, we directly import FluxPipeline
from diffusers import FluxPipeline

from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from transformers import AutoModel, AutoFeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =====================================================================
# 1) LOAD & FREEZE THE FLUX.1-DEV PIPELINE
# =====================================================================
print("Loading and freezing FLUX.1-dev ...")
# bfloat16 on GPU to halve weight memory versus float32; CPU stays in float32.
flux_dtype = torch.bfloat16 if device.type == 'cuda' else torch.float32
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    device_map="balanced",          # as per your example usage
    torch_dtype=flux_dtype,
)

# Freeze every base component the pipeline actually provides (missing/None ones are skipped).
for component_name in ("transformer", "text_encoder", "text_encoder_2", "vae"):
    component = getattr(pipe, component_name, None)
    if component is not None:
        for param in component.parameters():
            param.requires_grad = False


# =====================================================================
# 2) GROUNDING DINO FOR BODY BOUNDING BOX
# =====================================================================
print("Loading GroundingDINO ...")
g_model_id = "IDEA-Research/grounding-dino-tiny"
# One processor handles both the image and the text query (see detect_body_bbox_pil).
g_processor = AutoProcessor.from_pretrained(g_model_id)
g_model = AutoModelForZeroShotObjectDetection.from_pretrained(g_model_id)
g_model = g_model.to(device)

def detect_body_bbox_pil(pil_image, query_text="a person.", threshold=0.4, text_threshold=0.3):
    """Find the largest region matching ``query_text`` in a PIL image using GroundingDINO.

    Args:
        pil_image: PIL image to search.
        query_text: Grounding text query (default: "a person.").
        threshold: Minimum box score passed to the post-processor.
        text_threshold: Minimum text-token score passed to the post-processor.

    Returns:
        ``(xmin, ymin, xmax, ymax)`` integer pixel coordinates of the largest detected box,
        clipped to the image bounds. Returns ``None`` if nothing is detected or the clipped
        box is empty.
    """
    width, height = pil_image.size
    inputs = g_processor(images=pil_image, text=query_text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = g_model(**inputs)

    results = g_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=threshold,
        text_threshold=text_threshold,
        target_sizes=[(height, width)]  # (height, width)
    )
    if not results or len(results[0]["boxes"]) == 0:
        return None

    boxes = results[0]["boxes"]  # (N, 4) rows of [xmin, ymin, xmax, ymax]
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    xmin, ymin, xmax, ymax = boxes[torch.argmax(areas)].tolist()

    # Predicted boxes are not guaranteed to lie inside the image; clip them so that a later
    # PIL crop() does not pad the out-of-bounds area, and reject boxes that collapse to nothing.
    xmin, xmax = max(0, int(xmin)), min(width, int(xmax))
    ymin, ymax = max(0, int(ymin)), min(height, int(ymax))
    if xmax <= xmin or ymax <= ymin:
        return None
    return (xmin, ymin, xmax, ymax)


# =====================================================================
# 3) INTERNVIT FOR BODY EMBEDDINGS
# =====================================================================
print("Loading InternViT-300M-448px ...")
internvit_repo = "OpenGVLab/InternViT-300M-448px-V2_5"
internvit_model = AutoModel.from_pretrained(internvit_repo, trust_remote_code=True).to(device)
internvit_extractor = AutoFeatureExtractor.from_pretrained(internvit_repo, trust_remote_code=True)
# Used purely as a frozen feature extractor: inference mode, no gradients.
internvit_model.eval()
internvit_model.requires_grad_(False)

def extract_body_embedding_pil(pil_image, bbox=None):
    """Embed a (cropped) PIL image with the frozen InternViT model.

    Args:
        pil_image: Source PIL image.
        bbox: Optional ``(xmin, ymin, xmax, ymax)`` crop box; ``None`` uses the whole image.

    Returns:
        A (1, hidden_dim) tensor: the model's ``pooler_output`` when it is populated, otherwise
        the first token of ``last_hidden_state``.
    """
    if bbox is not None:
        xmin, ymin, xmax, ymax = bbox
        pil_crop = pil_image.crop((xmin, ymin, xmax, ymax))
    else:
        pil_crop = pil_image

    inputs = internvit_extractor(images=pil_crop, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = internvit_model(**inputs)

    # Model output objects can define `pooler_output` with a None value, so test the value,
    # not merely the attribute's presence.
    pooled = getattr(outputs, "pooler_output", None)
    if pooled is not None:
        return pooled  # (1, hidden_dim)
    return outputs.last_hidden_state[:, 0, :]  # (1, hidden_dim)


# =====================================================================
# 4) INSIGHTFACE FOR FACE EMBEDDINGS
# =====================================================================
print("Initializing InsightFace ...")
# InsightFace picks its device from ctx_id: GPU 0 when CUDA is in use, -1 (CPU) otherwise.
ctx_id = -1 if device.type != 'cuda' else 0
face_analysis = insightface.app.FaceAnalysis()
face_analysis.prepare(ctx_id=ctx_id, det_size=(640, 640))

def extract_face_embedding_pil(pil_image):
    """Return the InsightFace identity embedding of the most prominent face in a PIL image.

    The image is forced to RGB, so grayscale/RGBA inputs also work. It is then turned into a
    contiguous BGR array for InsightFace. When several faces are detected, the one with the
    largest bounding box is used, so the choice does not depend on the detector's output order.

    Returns:
        A (1, 512) float32 tensor on ``device``; all zeros if no face is detected.
    """
    rgb = np.asarray(pil_image.convert("RGB"))
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])  # RGB -> BGR, without a negative-stride view
    faces = face_analysis.get(bgr)
    if len(faces) == 0:
        return torch.zeros((1, 512), device=device)

    def _bbox_area(face):
        x1, y1, x2, y2 = face.bbox[:4]
        return (x2 - x1) * (y2 - y1)

    face = max(faces, key=_bbox_area)
    emb = face.normed_embedding  # (512,)
    return torch.tensor(emb, dtype=torch.float32, device=device).unsqueeze(0)


# =====================================================================
# 5) NEW MODULES: PERCEIVER + CROSS-ATTENTION
# =====================================================================
class CrossAttentionBlock(nn.Module):
    """Post-norm block: cross-attention from ``x`` onto ``context``, then a 4x-wide GELU MLP.

    Each sub-layer is residual and followed by LayerNorm. Tensors are batch-first:
    ``x`` is (B, Lq, embed_dim), ``context`` is (B, Lk, embed_dim), and the output has the
    shape of ``x``.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        hidden_dim = embed_dim * 4
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, context):
        attended, _attn_weights = self.attn(x, context, context)
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.mlp(hidden))

class PerceiverResampler(nn.Module):
    """Expand one embedding vector per sample into ``num_tokens`` tokens.

    Despite the name, this is a single linear projection with no attention or learned latent
    queries. For 2-D input (B, in_dim) -> Linear -> (B, num_tokens * out_dim) -> reshaped to
    (B, num_tokens, out_dim).
    """
    def __init__(self, in_dim, out_dim, num_tokens=8):
        super().__init__()
        self.num_tokens = num_tokens
        self.proj = nn.Linear(in_dim, out_dim * num_tokens)

    def forward(self, x):
        projected = self.proj(x)
        return projected.view(x.shape[0], self.num_tokens, -1)


# =====================================================================
# 6) WRAPPER: Freeze FLUX, Insert New Modules
# =====================================================================
class FluxFrozenWrapper(nn.Module):
    """Frozen diffusion pipeline plus trainable face/body identity-conditioning modules.

    Everything inside ``pipe`` is expected to be frozen by the caller; only the modules created
    here are trainable. ``pipe`` is a plain attribute (FluxPipeline is not an ``nn.Module``), so
    its weights are not part of ``self.parameters()``.

    The denoiser output (B, C, H, W) is treated as H*W spatial tokens of C channels. These
    tokens are projected to ``embed_dim``, refined by face then body cross-attention, and
    projected back to C channels.

    Args:
        pipe: Frozen pipeline providing the denoiser.
        embed_dim: Width of identity tokens and of the cross-attention blocks.
        num_heads: Attention heads; must divide ``embed_dim``.
        face_dim: Size of the face embedding fed to ``face_resampler``.
        body_dim: Size of the body embedding fed to ``body_resampler``.
        latent_channels: Channel count C of the denoiser output.
    """
    def __init__(self, pipe, embed_dim=768, num_heads=8, face_dim=512, body_dim=768, latent_channels=4):
        super().__init__()
        self.pipe = pipe  # The frozen pipeline

        # Trainable modules
        self.face_resampler = PerceiverResampler(face_dim, embed_dim, num_tokens=8)
        self.body_resampler = PerceiverResampler(body_dim, embed_dim, num_tokens=8)
        self.face_cross_attn = CrossAttentionBlock(embed_dim, num_heads)
        self.body_cross_attn = CrossAttentionBlock(embed_dim, num_heads)
        # Map per-position latent channels <-> attention width. Without this, the flattened
        # C*H*W vector would not match embed_dim and could not be reshaped back.
        self.latent_in = nn.Linear(latent_channels, embed_dim)
        self.latent_out = nn.Linear(embed_dim, latent_channels)

    def _denoiser(self):
        """Return the pipeline's UNet-style denoiser, or fail with an actionable message."""
        unet = getattr(self.pipe, "unet", None)
        if unet is None:
            raise NotImplementedError(
                "FluxFrozenWrapper needs a pipeline with a `unet` callable as "
                "unet(latents, t, encoder_hidden_states=...).sample; FluxPipeline exposes "
                "`transformer` instead, whose call signature differs."
            )
        return unet

    def forward_unet(self, latents, t, text_embeddings, face_tokens, body_tokens):
        """Run the frozen denoiser, then inject identity via cross-attention over spatial tokens.

        Returns a tensor with the denoiser output's shape (B, C, H, W).
        """
        unet_out = self._denoiser()(
            latents, t, encoder_hidden_states=text_embeddings
        ).sample  # (B, C, H, W)

        B, C, H, W = unet_out.shape
        # The frozen denoiser may run in a lower precision than the trainable layers.
        unet_out = unet_out.to(self.latent_in.weight.dtype)
        tokens = self.latent_in(unet_out.flatten(2).transpose(1, 2))  # (B, H*W, embed_dim)

        tokens = self.face_cross_attn(tokens, face_tokens)
        tokens = self.body_cross_attn(tokens, body_tokens)

        out = self.latent_out(tokens)  # (B, H*W, C)
        return out.transpose(1, 2).reshape(B, C, H, W)

    def forward(self, latents, t, text_embeddings, face_emb, body_emb):
        face_tokens = self.face_resampler(face_emb)
        body_tokens = self.body_resampler(body_emb)
        return self.forward_unet(latents, t, text_embeddings, face_tokens, body_tokens)


# =====================================================================
# 7) EXAMPLE DATASET
# =====================================================================
class ExampleCharacterDataset(Dataset):
    """Minimal demo dataset pairing reference images with prompts.

    Each item is a dict with:
      - "pil_image": the reference image loaded as an RGB PIL image
      - "prompt": the text prompt
      - "latents": random mock latents of shape ``latent_shape``
      - "timestep": a (1,) integer tensor drawn uniformly from [0, num_timesteps)

    PyTorch's default collate function rejects PIL images, so a DataLoader over this dataset
    needs a custom ``collate_fn``.

    Raises:
        ValueError: if ``image_paths`` and ``prompts`` differ in length.
    """
    def __init__(self, image_paths, prompts, latent_shape=(4, 64, 64), num_timesteps=1000):
        if len(image_paths) != len(prompts):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(prompts)} prompts"
            )
        self.image_paths = image_paths
        self.prompts = prompts
        self.latent_shape = tuple(latent_shape)
        self.num_timesteps = num_timesteps

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context manager closes the file handle once the RGB copy has been made.
        with Image.open(self.image_paths[idx]) as img:
            pil_image = img.convert("RGB")
        prompt = self.prompts[idx]

        latents = torch.randn(self.latent_shape)  # mock latents
        t = torch.randint(0, self.num_timesteps, (1,))
        return {
            "pil_image": pil_image,
            "prompt": prompt,
            "latents": latents,
            "timestep": t
        }


# =====================================================================
# 8) TRAINING FUNCTION
# =====================================================================
def diffusion_loss_fn(pred_latents, target_latents):
    """Mean-squared error between predicted and target latents (placeholder objective).

    Real training would compare against a proper target (e.g. the predicted noise) rather
    than the input latents themselves.
    """
    loss = F.mse_loss(pred_latents, target_latents, reduction="mean")
    return loss

def _encode_prompts(pipe, prompt, batch_size):
    """Encode prompts with ``pipe.tokenizer``/``pipe.text_encoder`` without gradients.

    Returns the encoder's ``last_hidden_state`` when present (else the raw encoder output). If the
    pipeline lacks either attribute, returns random (batch_size, 77, 768) placeholder embeddings.
    """
    if hasattr(pipe, "tokenizer") and hasattr(pipe, "text_encoder"):
        # truncation=True: prompts longer than 77 tokens would otherwise exceed max_length.
        text_in = pipe.tokenizer(
            prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            text_out = pipe.text_encoder(**text_in)
        return getattr(text_out, "last_hidden_state", text_out)
    # fallback if no built-in text encoder
    return torch.randn((batch_size, 77, 768), device=device)


def _reference_embeddings(pil_images, batch_size):
    """Compute (face, body) reference embeddings for each image in the batch, concatenated on dim 0."""
    face_emb_list, body_emb_list = [], []
    for b in range(batch_size):
        img = pil_images[b] if isinstance(pil_images, list) else pil_images
        bbox = detect_body_bbox_pil(img, query_text="a person.")
        body_emb_list.append(extract_body_embedding_pil(img, bbox))
        face_emb_list.append(extract_face_embedding_pil(img))
    return torch.cat(face_emb_list, dim=0).to(device), torch.cat(body_emb_list, dim=0).to(device)


def _decoded_to_pil(image):
    """Convert one decoded VAE image tensor (3, H, W) to an RGB PIL image.

    Uses the fixed [-1, 1] -> [0, 255] mapping of diffusers' image post-processing (assumes the
    VAE decodes to roughly [-1, 1]) instead of per-image min-max stretching, which would alter
    contrast before the identity models see the image.
    """
    arr = (image.detach().float().cpu() / 2 + 0.5).clamp(0, 1)
    arr = (arr * 255).round().to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
    return Image.fromarray(arr)


def _identity_metrics(decoded, face_ref, body_ref):
    """Mean face and body embedding MSE between decoded images and the references, as floats."""
    face_total = 0.0
    body_total = 0.0
    b_sz = decoded.shape[0]
    for b in range(b_sz):
        pil_decoded = _decoded_to_pil(decoded[b])

        gen_face_emb = extract_face_embedding_pil(pil_decoded)
        face_total += F.mse_loss(gen_face_emb, face_ref[b:b+1]).item()

        dec_bbox = detect_body_bbox_pil(pil_decoded, query_text="a person.")
        gen_body_emb = extract_body_embedding_pil(pil_decoded, dec_bbox)
        body_total += F.mse_loss(gen_body_emb, body_ref[b:b+1]).item()
    return face_total / b_sz, body_total / b_sz


def train_identity_preservation(
    flux_wrapper,
    pipe,
    dataloader,
    epochs=1,
    lr=1e-4,
    lambda_face=1.0,
    lambda_body=1.0
):
    """Train the trainable identity modules of ``flux_wrapper`` with Adam.

    Args:
        flux_wrapper: Model holding the new face/body modules; its ``requires_grad`` params are optimised.
        pipe: The original FluxPipeline, which must be frozen (checked before training).
        dataloader: Yields dicts with "pil_image", "prompt", "latents" and "timestep".
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        lambda_face: Weight of the face-identity metric in the logged total.
        lambda_body: Weight of the body-identity metric in the logged total.

    Only the (mock) diffusion loss is back-propagated. The identity terms come from a no-grad VAE
    decode, frozen detectors and ``.item()``, so no gradient can flow from them to the model. They
    are therefore reported as metrics rather than wrapped in fresh ``requires_grad`` tensors,
    which contributed nothing to the update.

    Raises:
        AssertionError: if any parameter of the pipeline's transformer, text encoders or VAE is trainable.
    """
    trainable_params = [p for p in flux_wrapper.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    flux_wrapper.train()
    flux_wrapper.to(device)

    # Check that the submodules are frozen
    for submodel in [pipe.transformer, pipe.text_encoder, pipe.text_encoder_2, pipe.vae]:
        if submodel is not None:
            for name, p in submodel.named_parameters():
                assert not p.requires_grad, f"Parameter {name} should be frozen!"

    for epoch in range(epochs):
        for step_idx, batch in enumerate(dataloader):
            latents = batch["latents"].to(device, dtype=torch.float32)
            t = batch["timestep"].to(device)
            batch_size = latents.shape[0]

            # 1) Text encoding, 2) reference face & body embeddings
            text_embeddings = _encode_prompts(pipe, batch["prompt"], batch_size)
            face_emb_tensor, body_emb_tensor = _reference_embeddings(batch["pil_image"], batch_size)

            # 3) Forward pass + diffusion MSE (mock)
            pred_latents = flux_wrapper(latents, t, text_embeddings, face_emb_tensor, body_emb_tensor)
            diff_loss = diffusion_loss_fn(pred_latents, latents)

            # 4) Decode & identity metrics (non-differentiable, monitoring only)
            with torch.no_grad():
                decoded = pipe.vae.decode(pred_latents.detach().to(pipe.vae.dtype)).sample  # (B, 3, H, W)
            face_loss_val, body_loss_val = _identity_metrics(decoded, face_emb_tensor, body_emb_tensor)

            optimizer.zero_grad()
            diff_loss.backward()
            optimizer.step()

            if step_idx % 5 == 0:
                total_val = diff_loss.item() + lambda_face * face_loss_val + lambda_body * body_loss_val
                print(f"Epoch {epoch} | Step {step_idx} | "
                      f"Diff={diff_loss.item():.4f} | Face={face_loss_val:.4f} | Body={body_loss_val:.4f} | "
                      f"Total={total_val:.4f}")


# =====================================================================
# 9) MAIN / DEMO
# =====================================================================
def main():
    """Demo: train the identity modules on a toy dataset, then generate one image with the pipeline."""
    # Example dataset
    image_paths = ["./example1.jpg", "./example2.jpg"]
    prompts = ["Character in a futuristic city", "Character on the beach"]
    dataset = ExampleCharacterDataset(image_paths, prompts)

    def collate_keep_pil(items):
        """Batch dataset items: stack tensors, keep PIL images and prompts as lists.

        PyTorch's default collate raises on PIL images; the training loop expects lists here.
        """
        return {
            "pil_image": [item["pil_image"] for item in items],
            "prompt": [item["prompt"] for item in items],
            "latents": torch.stack([item["latents"] for item in items]),
            "timestep": torch.stack([item["timestep"] for item in items]),
        }

    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_keep_pil)

    # Body embeddings are InternViT hidden states, so size the body resampler from the model config
    # rather than hard-coding 768 (NOTE(review): InternViT-300M is believed to use 1024 — confirm).
    body_dim = getattr(internvit_model.config, "hidden_size", 768)

    # Wrap the pipeline with new modules
    flux_wrapper = FluxFrozenWrapper(
        pipe,
        embed_dim=768,
        num_heads=8,
        face_dim=512,
        body_dim=body_dim
    ).to(device)

    # Confirm submodules are frozen
    for submodel in [pipe.transformer, pipe.text_encoder, pipe.text_encoder_2, pipe.vae]:
        if submodel is not None:
            for n, p in submodel.named_parameters():
                assert not p.requires_grad, f"Param {n} should be frozen!"

    # The only trainable parameters are in flux_wrapper
    for n, p in flux_wrapper.named_parameters():
        print(f"{n} | requires_grad={p.requires_grad}")

    # 1) Run a quick training loop
    train_identity_preservation(
        flux_wrapper,
        pipe,
        dataloader,
        epochs=1,
        lr=1e-4,
        lambda_face=1.0,
        lambda_body=1.0
    )

    # 2) Demonstrate normal usage of the pipeline for generation
    #    (You can do this after training, or skip if not needed)
    print("\nGenerating an image with the pipeline after training:")
    img = pipe(
        prompt="woman",
        guidance_scale=2,
        height=1024,
        width=1024,
        num_inference_steps=40,
        # Seeded generator on the active device type, so CPU-only runs don't fail.
        generator=torch.Generator(device.type).manual_seed(10)
    ).images[0]

    # Rich inline display in a notebook; fall back to an external viewer elsewhere.
    try:
        from IPython.display import display
        display(img)
    except ImportError:
        img.show()
    img.save('after_training_generation.webp')
    print("Saved 'after_training_generation.webp'")

if __name__ == "__main__":
    main()
