In [1]:
# @title 1. Initialize Professional Directory Structure
import os

ROOT = "/content/MultiSubjectGen_Pro"
DIRS = [
    f"{ROOT}/data/cat_toy",
    f"{ROOT}/data/red_mug",
    f"{ROOT}/output",
    f"{ROOT}/checkpoints",
    f"{ROOT}/src",
    f"{ROOT}/experiments",
    f"{ROOT}/examples",
    f"{ROOT}/docs"
]

for d in DIRS:
    os.makedirs(d, exist_ok=True)

print("Directory structure created successfully.")

# create requirements.txt
reqs = """
diffusers>=0.24.0
transformers
accelerate
peft
safetensors
opencv-python
invisible-watermark
torchmetrics
insightface
onnxruntime-gpu
"""
with open(f"{ROOT}/requirements.txt", "w") as f:
    f.write(reqs)

# Install Depencies
!pip install -q -r {ROOT}/requirements.txt
!git clone https://github.com/huggingface/diffusers /content/diffusers
!pip install -e /content/diffusers

Directory structure created successfully.
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m439.5/439.5 kB[0m [31m33.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m41.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m60.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m300.5/300.5 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m120.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86

In [49]:
%%writefile /content/MultiSubjectGen_Pro/src/config.py
import os

def _subject_entry(root, name, token, class_noun, data_subdir, lora_subdir):
    """Build one ``Config.SUBJECTS`` record with its paths resolved under ``root``."""
    return {
        "name": name,
        "token": token,
        "class": class_noun,
        "data": os.path.join(root, data_subdir),
        "lora_out": os.path.join(root, lora_subdir),
    }


class Config:
    """Project-wide settings, exposed as plain class attributes."""

    PROJECT_ROOT = "/content/MultiSubjectGen_Pro"
    MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"

    # --- Training hyper-parameters ---
    TRAIN_STEPS = 500
    LEARNING_RATE = 1e-4

    # --- QA feedback mechanism ---
    # Minimum CLIP score a subject must reach to pass the QA check.
    QA_PASS_THRESHOLD = 23.0
    # Retry budget associated with the QA loop.
    MAX_RETRIES = 2

    # --- Subjects ---
    # One record per subject: a short name, the identifier token and class noun
    # used in prompts, the folder holding its training images, and the folder
    # its LoRA checkpoint is written to.
    SUBJECTS = [
        _subject_entry(PROJECT_ROOT, "cat_toy", "sks", "cat",
                       "data/cat_toy", "checkpoints/lora_cat"),
        _subject_entry(PROJECT_ROOT, "red_mug", "trk", "mug",
                       "data/red_mug", "checkpoints/lora_mug"),
    ]

Overwriting /content/MultiSubjectGen_Pro/src/config.py


In [65]:
%%writefile /content/MultiSubjectGen_Pro/src/spatial_layout.py
from PIL import Image, ImageDraw

class LayoutGenerator:
    """Produces a blank canvas and one region mask per subject for a two-subject scene.

    The subject regions are defined on a 1024x1024 reference canvas and scaled
    proportionally to the configured ``width``/``height``. With the default
    size the boxes are exactly the reference values.
    """

    # Side length of the reference canvas that SUBJECT_BOXES is expressed in.
    REFERENCE_SIZE = 1024
    # [x0, y0, x1, y1] regions on the reference canvas, in subject order:
    # left region (cat), then right region (mug).
    SUBJECT_BOXES = (
        (50, 200, 500, 950),
        (524, 200, 980, 950),
    )

    def __init__(self, width=1024, height=1024):
        """Store the canvas size (in pixels) used for the base image and masks."""
        self.width = width
        self.height = height

    def _scaled_box(self, box):
        """Map a reference-canvas box onto this generator's canvas size."""
        scale_x = self.width / self.REFERENCE_SIZE
        scale_y = self.height / self.REFERENCE_SIZE
        x0, y0, x1, y1 = box
        return [
            round(x0 * scale_x),
            round(y0 * scale_y),
            round(x1 * scale_x),
            round(y1 * scale_y),
        ]

    def _region_mask(self, box):
        """Return an "L" mask that is 255 inside the scaled ``box`` and 0 elsewhere."""
        mask = Image.new("L", (self.width, self.height), 0)
        ImageDraw.Draw(mask).rectangle(self._scaled_box(box), fill=255)
        return mask

    def get_dual_subject_layout(self):
        """Build the layout for two side-by-side subjects.

        Returns:
            A ``(base_img, masks)`` tuple: a white RGB canvas of the configured
            size, and a list of two masks (left subject first, right second).
        """
        base_img = Image.new("RGB", (self.width, self.height), "white")
        masks = [self._region_mask(box) for box in self.SUBJECT_BOXES]
        return base_img, masks

Overwriting /content/MultiSubjectGen_Pro/src/spatial_layout.py


In [64]:
%%writefile /content/MultiSubjectGen_Pro/src/multi_subject_pipeline.py
import torch
from diffusers import StableDiffusionXLInpaintPipeline, AutoencoderKL, EulerDiscreteScheduler
from .config import Config
from .evaluation import Evaluator
from .spatial_layout import LayoutGenerator
from PIL import Image, ImageFilter

class MultiSubjectPipeline:
    """SDXL inpainting pipeline that composes several LoRA subjects into one scene.

    ``generate_with_qa_loop`` runs three phases: a global scene generated with
    LoRA disabled, one masked inpainting pass per subject with that subject's
    adapter active, and a CLIP-score QA loop that re-inpaints subjects scoring
    below ``Config.QA_PASS_THRESHOLD`` (at most ``Config.MAX_RETRIES`` times each).
    """

    # Weight file looked up inside each subject's ``lora_out`` directory.
    LORA_WEIGHT_NAME = "pytorch_lora_weights.safetensors"

    def __init__(self, device="cuda"):
        """Load the VAE and SDXL inpainting pipeline in float16 onto ``device``.

        The pipeline's scheduler is replaced by an ``EulerDiscreteScheduler``
        built from the original scheduler's config.
        """
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)

        self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            Config.MODEL_NAME, vae=vae, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
        ).to(device)
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)

        self.layout_gen = LayoutGenerator()
        self.evaluator = Evaluator()
        self.device = device

    def load_loras(self):
        """Load one LoRA adapter per entry of ``Config.SUBJECTS``, named after the subject.

        Loading is best-effort: a failure for one subject is reported and the
        remaining subjects are still attempted. Only ``Exception`` is caught,
        so ``KeyboardInterrupt``/``SystemExit`` propagate.
        """
        print(">>> Loading LoRAs...")
        for subj in Config.SUBJECTS:
            try:
                self.pipe.load_lora_weights(
                    subj['lora_out'],
                    weight_name=self.LORA_WEIGHT_NAME,
                    adapter_name=subj['name'],
                )
            except Exception as exc:
                # Report instead of silently swallowing, so a missing or corrupt
                # checkpoint is visible before generation starts.
                print(f"!!! Failed to load LoRA for {subj['name']} from {subj['lora_out']}: {exc}")

    def _generate_global_scene(self, canvas_size):
        """Phase 1: generate the whole scene with LoRA disabled.

        A fully white mask at strength 1.0 is used, so the entire gray
        placeholder canvas is regenerated from the prompt.
        """
        print("--- Phase 1: Generating Coherent Global Scene ---")

        self.pipe.disable_lora()

        global_prompt = "a wide shot of a cat sitting on the left and a red mug on the right on a continuous dark walnut wooden table, sunlit living room, bokeh background, highly detailed, 4k, photorealistic"

        empty_bg = Image.new("RGB", canvas_size, "gray")
        full_mask = Image.new("L", canvas_size, 255)

        return self.pipe(
            prompt=global_prompt,
            negative_prompt="split view, collage, watermark, text, drawing",
            image=empty_bg, mask_image=full_mask,
            num_inference_steps=30, strength=1.0, guidance_scale=7.5
        ).images[0]

    def _inject_identities(self, current_img, masks):
        """Phase 2: inpaint each subject into its mask with its own adapter (weight 0.9)."""
        self.pipe.enable_lora()

        for i, subj in enumerate(Config.SUBJECTS):
            print(f"--- Injecting {subj['name']} into scene ---")

            self.pipe.set_adapters([subj['name']], adapter_weights=[0.9])

            prompt = f"a photo of {subj['token']} {subj['class']}, sitting on a dark walnut wooden table, realistic"

            current_img = self.pipe(
                prompt=prompt,
                negative_prompt="blur, bad anatomy, ghost, transparent",
                image=current_img,
                mask_image=masks[i],
                num_inference_steps=35,
                strength=0.75,
                guidance_scale=7.5
            ).images[0]

        return current_img

    def _run_qa(self, current_img, masks):
        """Phase 3: score each subject and refine it while it is below the threshold.

        Each subject gets at most ``Config.MAX_RETRIES`` refinement passes
        (adapter weight 1.0); the image is re-scored after every pass. Setting
        ``Config.MAX_RETRIES`` to 0 disables refinement entirely. Both settings
        are read at call time.
        """
        for i, subj in enumerate(Config.SUBJECTS):
            prompt = f"a photo of {subj['token']} {subj['class']}"
            # NOTE: the score is computed on the full image, not on the masked region.
            score = self.evaluator.compute_clip_score(current_img, prompt)

            retries = 0
            while score < Config.QA_PASS_THRESHOLD and retries < Config.MAX_RETRIES:
                retries += 1
                print(f"!!! Triggering Refinement for {subj['name']} (Score: {score:.2f})")
                self.pipe.set_adapters([subj['name']], adapter_weights=[1.0])
                current_img = self.pipe(
                    prompt=prompt + ", masterpiece, high fidelity",
                    image=current_img, mask_image=masks[i],
                    num_inference_steps=40, strength=0.65
                ).images[0]
                score = self.evaluator.compute_clip_score(current_img, prompt)

            if score < Config.QA_PASS_THRESHOLD:
                print(f"Subject {subj['name']} still below threshold after {retries} refinement(s) (Score: {score:.2f})")
            else:
                print(f"Subject {subj['name']} Passed (Score: {score:.2f})")

        return current_img

    def generate_with_qa_loop(self):
        """Run the three generation phases and return the final composed image."""
        base_img, masks = self.layout_gen.get_dual_subject_layout()
        # Blur the masks (radius 30) to soften the region edges.
        masks = [m.filter(ImageFilter.GaussianBlur(radius=30)) for m in masks]

        # The Phase 1 canvas takes its size from the layout instead of a fixed 1024.
        current_img = self._generate_global_scene(base_img.size)

        print(">>> Global scene generated. Now injecting identities...")
        current_img = self._inject_identities(current_img, masks)

        print(">>> Entering QA Loop...")
        return self._run_qa(current_img, masks)

Overwriting /content/MultiSubjectGen_Pro/src/multi_subject_pipeline.py


In [5]:
%%writefile /content/MultiSubjectGen_Pro/src/evaluation.py
import torch
import numpy as np
from torchmetrics.functional.multimodal import clip_score
from functools import partial

class Evaluator:
    """Scores how well an image matches a text prompt using CLIP."""

    def __init__(self):
        # torchmetrics' functional clip_score, pre-bound to the CLIP checkpoint to use.
        self.clip_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")

    def compute_clip_score(self, image, prompt):
        """Return the CLIP score of ``image`` against ``prompt`` as a Python float.

        ``image`` must convert to an (H, W, C) array through ``np.array``; it is
        rearranged to a (1, C, H, W) tensor before scoring. The metric's value is
        returned unchanged; no extra scaling is applied here.
        """
        pixels = np.array(image)
        channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
        batch = channels_first.unsqueeze(0)
        result = self.clip_fn(batch, [prompt])
        return result.item()

Writing /content/MultiSubjectGen_Pro/src/evaluation.py


In [6]:
%%writefile /content/MultiSubjectGen_Pro/src/trainer.py
import subprocess
import os
from .config import Config

def run_training():
    """Train one DreamBooth LoRA per entry of ``Config.SUBJECTS``.

    A subject is skipped only when its LoRA weight file already exists in its
    ``lora_out`` directory. Checking the file rather than the directory means
    an aborted run, which leaves the directory behind, is retried.

    Raises:
        subprocess.CalledProcessError: if a training process exits non-zero.
    """
    script = "/content/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py"
    # Same file name the inference pipeline passes as ``weight_name`` when loading.
    weight_file = "pytorch_lora_weights.safetensors"

    for subj in Config.SUBJECTS:
        if os.path.isfile(os.path.join(subj['lora_out'], weight_file)):
            print(f"Skipping {subj['name']}, checkpoint exists.")
            continue

        print(f"Starting training for {subj['name']}...")
        cmd = [
            "accelerate", "launch", script,
            f"--pretrained_model_name_or_path={Config.MODEL_NAME}",
            f"--instance_data_dir={subj['data']}",
            f"--output_dir={subj['lora_out']}",
            # No shell quoting needed: each list item reaches the script as one argument.
            f"--instance_prompt=a photo of {subj['token']} {subj['class']}",
            "--resolution=1024",
            "--train_batch_size=1",
            "--gradient_accumulation_steps=4",
            f"--learning_rate={Config.LEARNING_RATE}",
            f"--max_train_steps={Config.TRAIN_STEPS}",
            "--mixed_precision=fp16"
        ]
        # Passing the argument list directly (no shell) keeps paths and prompts
        # containing spaces or shell metacharacters intact.
        subprocess.run(cmd, check=True)

Writing /content/MultiSubjectGen_Pro/src/trainer.py


In [35]:
%%writefile /content/MultiSubjectGen_Pro/experiments/run_baseline_comparison.py
import sys
sys.path.append("/content/MultiSubjectGen_Pro")
from src.multi_subject_pipeline import MultiSubjectPipeline
from diffusers import StableDiffusionXLPipeline
import torch
from src.config import Config

def run_experiment():
    """Render the same scene with vanilla SDXL and with the multi-subject pipeline.

    Writes ``baseline_result.png`` and ``ours_result.png`` into
    ``<Config.PROJECT_ROOT>/output``.
    """
    def _output_path(filename):
        # Resolved at call time so the current Config.PROJECT_ROOT is always used.
        return f"{Config.PROJECT_ROOT}/output/{filename}"

    print("=== Experiment: Baseline vs. Ours ===")

    # ---- Stage 1: baseline (vanilla SDXL) ----
    # No torch_dtype is passed; the log line below calls this FP32 mode.
    print("Running Baseline (Vanilla SDXL in FP32 Mode)...")
    torch.cuda.empty_cache()

    vanilla = StableDiffusionXLPipeline.from_pretrained(Config.MODEL_NAME, use_safetensors=True)
    vanilla.enable_model_cpu_offload()

    # Generation: a single prompt naming both subjects.
    baseline_prompt = "a photo of sks cat toy and trk red mug on a table"
    baseline_output = vanilla(prompt=baseline_prompt, num_inference_steps=30)
    baseline_output.images[0].save(_output_path("baseline_result.png"))
    print(" Baseline saved.")

    # Drop the baseline pipeline and clear the CUDA cache before loading ours.
    del vanilla
    torch.cuda.empty_cache()

    # ---- Stage 2: proposed multi-subject pipeline ----
    print("Running Ours (Proposed Method)...")
    proposed = MultiSubjectPipeline()
    proposed.load_loras()
    proposed.generate_with_qa_loop().save(_output_path("ours_result.png"))
    print(" Ours saved.")

    print("Compare 'baseline_result.png' and 'ours_result.png' in the output folder.")

# Script entry point: run the comparison only when this file is executed
# directly, not when run_experiment is imported by another module.
if __name__ == "__main__":
    run_experiment()

Overwriting /content/MultiSubjectGen_Pro/experiments/run_baseline_comparison.py


In [62]:
%%writefile /content/MultiSubjectGen_Pro/examples/run_demo.py
import sys
sys.path.append("/content/MultiSubjectGen_Pro")
from src.trainer import run_training
from experiments.run_baseline_comparison import run_experiment
from experiments.run_ablation import run_ablation
from src.visualization import Visualizer
from PIL import Image
from src.config import Config
from src.spatial_layout import LayoutGenerator

def main():
    """Run the full demo: training, baseline comparison, ablation, then a report figure."""
    print(" Starting Full Project Demo ...")

    # Stages, in order:
    #   1. training (subjects whose checkpoint already exists are skipped)
    #   2. baseline comparison -> baseline_result.png, ours_result.png
    #   3. ablation study      -> ablation_no_qa.png, ablation_with_qa.png
    for stage in (run_training, run_experiment, run_ablation):
        stage()

    # 4. Visual analysis of the "Ours" result written by stage 2.
    print("\n Generating Visual Analysis...")
    report = Visualizer()

    try:
        result_path = f"{Config.PROJECT_ROOT}/output/ours_result.png"
        composed = Image.open(result_path)
        # The layout is rebuilt only to obtain the masks for drawing.
        region_masks = LayoutGenerator().get_dual_subject_layout()[1]
        report.save_process_grid(None, region_masks, composed, filename="final_report_viz.png")
    except Exception as err:
        # Visualization is best-effort; a failure here must not abort the demo.
        print(f" Skipped visualization: {err}")

    print("\n Demo Finished! All results are in /content/MultiSubjectGen_Pro/output")

# Script entry point: run the full demo only when this file is executed directly.
if __name__ == "__main__":
    main()

Overwriting /content/MultiSubjectGen_Pro/examples/run_demo.py


In [19]:
%%writefile /content/MultiSubjectGen_Pro/experiments/run_ablation.py
import sys
sys.path.append("/content/MultiSubjectGen_Pro")
from src.multi_subject_pipeline import MultiSubjectPipeline
from src.config import Config
import torch

def run_ablation():
    """Generate the scene twice, without and with QA refinement, and save both images.

    Experiment A sets ``Config.MAX_RETRIES`` to 0 for the duration of one
    generation; experiment B runs with the original value. The original value
    is restored even if experiment A raises, so a failure here cannot leave the
    shared ``Config`` in the "QA disabled" state.

    NOTE(review): experiment A only differs from B if
    ``generate_with_qa_loop`` actually consults ``Config.MAX_RETRIES``; confirm
    against the pipeline implementation. No random seed is fixed, so the two
    images also differ by sampling noise.
    """
    print("\n=== Starting Ablation Study ===")
    print("Goal: Prove that the 'QA Loop' actually improves quality.")

    pipe = MultiSubjectPipeline()
    pipe.load_loras()

    # --- Experiment A: Without QA Loop ---
    print("\n[Ablation A] Running WITHOUT QA Loop (Baseline)...")
    original_retries = Config.MAX_RETRIES
    # Temporarily disable redrawing to simulate a run without QA.
    Config.MAX_RETRIES = 0
    try:
        img_no_qa = pipe.generate_with_qa_loop()
    finally:
        # Always undo the global override, whether generation succeeded or not.
        Config.MAX_RETRIES = original_retries

    img_no_qa.save(f"{Config.PROJECT_ROOT}/output/ablation_no_qa.png")
    print(">> Saved 'ablation_no_qa.png'")

    # --- Experiment B: With QA Loop ---
    print("\n[Ablation B] Running WITH QA Loop (Ours)...")

    img_with_qa = pipe.generate_with_qa_loop()
    img_with_qa.save(f"{Config.PROJECT_ROOT}/output/ablation_with_qa.png")
    print(">> Saved 'ablation_with_qa.png'")

    print("\n Ablation Done! Compare the two images in /output folder.")

# Script entry point: run the ablation only when this file is executed directly.
if __name__ == "__main__":
    run_ablation()

Overwriting /content/MultiSubjectGen_Pro/experiments/run_ablation.py


In [58]:
%%writefile /content/MultiSubjectGen_Pro/src/visualization.py
import matplotlib.pyplot as plt
from PIL import Image
import os
from .config import Config

class Visualizer:
    """Renders summary figures for a generated composition into the project's output folder."""

    # Side length of the canvas that DETAIL_CROP_BOX is expressed in.
    REFERENCE_SIZE = 1024
    # Approximate location of the left-hand subject (the cat) on the reference canvas.
    DETAIL_CROP_BOX = (100, 200, 612, 712)

    def __init__(self):
        self.save_dir = Config.PROJECT_ROOT + "/output"

    def save_process_grid(self, base_img, masks, final_img, filename="process_viz.png"):
        """Save a 1x3 figure: mask layout, final image, and a detail crop.

        Args:
            base_img: Accepted for interface compatibility; not used.
            masks: Iterable of equally sized "L" masks, one per subject region.
            final_img: The composed image to display and crop.
            filename: File name of the figure inside ``self.save_dir``.
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # 1. Mask preview: paint every masked region red onto a black canvas.
        # The canvas takes its size from the masks, because Image.composite
        # requires all three images to share one size.
        canvas_size = masks[0].size if masks else final_img.size
        mask_preview = Image.new("RGB", canvas_size, "black")
        colored_mask = Image.new("RGB", canvas_size, (255, 50, 50))
        for m in masks:
            mask_preview = Image.composite(colored_mask, mask_preview, m)

        axes[0].imshow(mask_preview)
        axes[0].set_title("1. Spatial Layout / Masks")
        axes[0].axis("off")

        # 2. Final result
        axes[1].imshow(final_img)
        axes[1].set_title("2. Final Composition (Ours)")
        axes[1].axis("off")

        # 3. Detail zoom: the reference crop box scaled to the actual image
        # size (unchanged for a 1024x1024 image).
        width, height = final_img.size
        scale_x = width / self.REFERENCE_SIZE
        scale_y = height / self.REFERENCE_SIZE
        x0, y0, x1, y1 = self.DETAIL_CROP_BOX
        crop_box = (
            round(x0 * scale_x),
            round(y0 * scale_y),
            round(x1 * scale_x),
            round(y1 * scale_y),
        )
        axes[2].imshow(final_img.crop(crop_box))
        axes[2].set_title("3. Detail / Identity Check")
        axes[2].axis("off")

        fig.tight_layout()
        # Make sure the target folder exists before writing the figure.
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, filename)
        fig.savefig(save_path)
        # Close this specific figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print(f" Process visualization saved to {save_path}")

Overwriting /content/MultiSubjectGen_Pro/src/visualization.py


In [66]:
# Run the end-to-end demo script (training, baseline comparison, ablation and
# visualization) as a shell command.
!python /content/MultiSubjectGen_Pro/examples/run_demo.py

2025-11-19 20:07:53.400476: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763582873.422204   54779 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763582873.428865   54779 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763582873.445350   54779 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763582873.445376   54779 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763582873.445379   54779 computation_placer.cc:177] computation placer alr