<a href="https://colab.research.google.com/github/Vishisht16/AI-Photoshoot-Studio/blob/main/AI_Photoshoot_Studio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AI Photoshoot Studio

> Please follow the instructions provided in this notebook to set up your own Studio through Colab and test out the code. No technical knowledge required!


> #### Initial Cloud Hardware Setup

> Click the dropdown arrow next to the **Connect** button at the top right (below your profile picture) and choose **'Change runtime type'**. Select **'T4 GPU'** as the hardware accelerator, save, and then connect.

### STEP 1: Go to the cell below and press Ctrl/⌘ + Enter.

In [None]:
!pip install -qq --upgrade numpy diffusers transformers accelerate mediapipe opencv-python-headless controlnet_aux torch gradio

Note: this step may take 2–3 minutes to complete.

### STEP 2: Once the cell has finished running, open the 'Runtime' menu at the top and click 'Restart session and run all'.

#### The cells below contain all the code needed to build the application from start to finish. You can read through them or skip the build details entirely — they are run for you by 'Restart session and run all' in Step 2.



In [None]:
import warnings
# Silences warnings in this notebook kernel only. The app itself is started
# later with `!python app.py`, which is a separate process, so this filter
# does not apply to the running studio.
warnings.filterwarnings('ignore')

In [None]:
# Creating folders for better abstraction
!mkdir core app

In [None]:
%%writefile core/utils.py

import numpy as np
import mediapipe as mp
from PIL import Image, ImageDraw

# One MediaPipe face detector, created at import time and reused by every
# detect_face() call. model_selection=1 selects MediaPipe's full-range model
# (0 is the short-range one, per the MediaPipe docs); min_detection_confidence
# sets the score threshold for reporting a detection.
mp_face_detection = mp.solutions.face_detection
face_detector = mp_face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5,
)

def detect_face(image: Image.Image) -> tuple | None:
    """
    Detect the first face in a PIL image and return its pixel bounding box.

    MediaPipe reports the box relative to the image size. After converting to
    pixels the box is clamped to the image, so a face touching the frame never
    yields a negative origin or a box extending past the edges.

    Args:
        image: Image to search; converted to RGB before detection.

    Returns:
        ``(x, y, w, h)`` in pixels for the first detection, or ``None`` if no
        face is found, the clamped box is empty, or detection raises (the
        error is printed rather than propagated — best-effort by design).
    """
    try:
        np_image = np.array(image.convert("RGB"))
        results = face_detector.process(np_image)

        if not results.detections:
            print("Utils: No face detected.")
            return None

        bbox_data = results.detections[0].location_data.relative_bounding_box
        img_h, img_w, _ = np_image.shape
        x = int(bbox_data.xmin * img_w)
        y = int(bbox_data.ymin * img_h)
        w = int(bbox_data.width * img_w)
        h = int(bbox_data.height * img_h)

        # Relative coordinates are not guaranteed to stay inside [0, 1] when
        # the face is cut off by the frame; clamp the corners to the image.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)
        if x1 <= x0 or y1 <= y0:
            print("Utils: Detected face box lies outside the image.")
            return None
        x, y, w, h = x0, y0, x1 - x0, y1 - y0

        print(f"Utils: Face detected at [x={x}, y={y}, w={w}, h={h}]")
        return (x, y, w, h)
    except Exception as e:
        print(f"Utils: An error occurred during face detection: {e}")
        return None

def create_inpainting_mask(image: Image.Image, bbox: tuple) -> Image.Image | None:
    """
    Build an RGB mask the size of *image* with the *bbox* rectangle in white.

    Everything outside the rectangle is black. Diffusers inpainting pipelines
    repaint the white area of ``mask_image`` and keep the black area, so this
    mask marks the box as the region to regenerate.

    Args:
        image: Reference image; only its ``size`` is used.
        bbox: ``(x, y, w, h)`` in pixels.

    Returns:
        The mask image, or ``None`` if building it raised (the error is printed).
    """
    try:
        left, top, width, height = bbox
        canvas = Image.new("RGB", image.size, "black")
        ImageDraw.Draw(canvas).rectangle(
            (left, top, left + width, top + height), fill="white"
        )
    except Exception as err:
        print(f"Utils: An error occurred during mask creation: {err}")
        return None
    print("Utils: Inpainting mask created successfully.")
    return canvas

In [None]:
%%writefile core/pipeline_manager.py

import torch
from diffusers import (
    StableDiffusionPipeline, # <- ADDED: The standard pipeline
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetInpaintPipeline,
    ControlNetModel,
    DDIMScheduler,
    PNDMScheduler,
    DDPMScheduler,
)
from controlnet_aux import OpenposeDetector
from PIL import Image
from core.utils import detect_face, create_inpainting_mask

class SynthesisPipeline:
    """
    Loads the Stable Diffusion / ControlNet models once and serves requests.

    Attributes set in ``__init__``:
        device / torch_dtype: ``"cuda"`` with float16 when a GPU is available,
            otherwise ``"cpu"`` with float32.
        openpose: OpenPose detector that turns a control photo into a pose map.
        controlnet: OpenPose ControlNet shared by both ControlNet pipelines.
        base_pipe: plain text-to-image pipeline (used without a control image).
        pipe: ControlNet text-to-image pipeline.
        inpaint_pipe: ControlNet inpainting pipeline used to keep the face.
        schedulers: scheduler name -> scheduler instance offered in the UI.
    """

    def __init__(self):
        print("PipelineManager: Initializing...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        print(f"PipelineManager: Using device: {self.device} with dtype: {self.torch_dtype}")

        # "runwayml/stable-diffusion-v1-5" was removed from the Hugging Face
        # Hub; this repository is the maintained mirror of the same weights.
        BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
        CONTROLNET_MODEL = "lllyasviel/sd-controlnet-openpose"

        print("PipelineManager: Loading models...")
        # --- ControlNet specific models ---
        self.openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MODEL, torch_dtype=self.torch_dtype
        )

        # 1. Standard text-to-image pipeline: the only full checkpoint load.
        self.base_pipe = StableDiffusionPipeline.from_pretrained(
            BASE_MODEL, torch_dtype=self.torch_dtype, safety_checker=None
        ).to(self.device)

        # 2./3. The ControlNet pipelines reuse base_pipe's UNet, VAE, text
        # encoder and tokenizer via from_pipe. The SD weights then occupy GPU
        # memory once instead of three times, which matters on a Colab T4.
        self.pipe = StableDiffusionControlNetPipeline.from_pipe(
            self.base_pipe, controlnet=self.controlnet
        ).to(self.device)
        self.inpaint_pipe = StableDiffusionControlNetInpaintPipeline.from_pipe(
            self.base_pipe, controlnet=self.controlnet
        ).to(self.device)

        self.schedulers = {
            "PNDM": PNDMScheduler.from_config(self.pipe.scheduler.config),
            "DDIM": DDIMScheduler.from_config(self.pipe.scheduler.config),
            "DDPM": DDPMScheduler.from_config(self.pipe.scheduler.config),
        }
        print("PipelineManager: All models loaded and ready!")

    @staticmethod
    def _face_preserving_mask(image: Image.Image) -> Image.Image | None:
        """Return a mask that keeps the detected face and repaints the rest, or None."""
        from PIL import ImageOps

        face_bbox = detect_face(image)
        if not face_bbox:
            print("No face detected; falling back to standard ControlNet.")
            return None
        face_mask = create_inpainting_mask(image, face_bbox)
        if face_mask is None:
            return None
        # create_inpainting_mask paints the face white, and diffusers repaints
        # white pixels. Invert it so the face is kept and the rest regenerated.
        return ImageOps.invert(face_mask.convert("L"))

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: Image.Image | None,
        preserve_face: bool,
        scheduler_name: str,
        num_steps: int = 25,
        guidance_scale: float = 7.5,
        seed: int | None = None,
    ) -> tuple[Image.Image, Image.Image | None]:
        """
        Generate one image, optionally pose-guided by *control_image*.

        Modes:
          * No control image -> plain text-to-image with ``base_pipe``.
          * Control image -> an OpenPose map extracted from the image (resized
            to 512x512) drives ``pipe``. With *preserve_face* and a detected
            face, ``inpaint_pipe`` keeps the face region of the control image
            and regenerates everything around it.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            control_image: Optional pose reference image, or None.
            preserve_face: Keep the detected face from *control_image*;
                ignored without a control image.
            scheduler_name: Key into ``self.schedulers``; unknown names fall
                back to the ControlNet pipeline's current scheduler.
            num_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seed: Seed for reproducible output. ``None`` or a negative value
                picks a fresh random seed, which is printed to the console.

        Returns:
            ``(generated_image, pose_image)``; ``pose_image`` is None in
            text-to-image mode.
        """
        print("\n--- New Generation Request ---")

        scheduler = self.schedulers.get(scheduler_name, self.pipe.scheduler)
        # manual_seed(-1) is NOT random: torch maps negative seeds to a fixed
        # value, so every request used to produce the same image.
        generator = torch.Generator(device=self.device)
        if seed is None or seed < 0:
            seed = generator.seed()
        else:
            generator.manual_seed(seed)
        print(f"Using seed: {seed}")

        common_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )

        if control_image is None:
            # --- STANDARD TEXT-TO-IMAGE MODE ---
            print("No control image provided. Entering Standard T2I Mode.")
            self.base_pipe.scheduler = scheduler
            result_image = self.base_pipe(**common_kwargs).images[0]
            print("--- Generation Complete (Standard Mode) ---")
            return result_image, None

        # --- CONTROLNET MODE ---
        print("Control image provided. Entering ControlNet Mode.")
        # Fixed 512x512 so the image, mask and pose map all share one size
        # (non-square inputs are stretched).
        control_image = control_image.convert("RGB").resize((512, 512))
        pose_image = self.openpose(control_image)

        self.pipe.scheduler = scheduler
        self.inpaint_pipe.scheduler = scheduler

        mask_image = self._face_preserving_mask(control_image) if preserve_face else None
        if mask_image is not None:
            print("Using Inpainting Pipeline...")
            result_image = self.inpaint_pipe(
                image=control_image,
                mask_image=mask_image,
                control_image=pose_image,
                **common_kwargs,
            ).images[0]
        else:
            result_image = self.pipe(image=pose_image, **common_kwargs).images[0]

        print("--- Generation Complete (ControlNet Mode) ---")
        return result_image, pose_image

In [None]:
%%writefile core/__init__.py

In [None]:
%%writefile app/__init__.py

In [None]:
%%writefile app/app_gradio.py

import sys
import gradio as gr

sys.path.append('.')
from core.pipeline_manager import SynthesisPipeline
from app import ui_components

print("App: Initializing the Synthesis Pipeline...")
# Model-loading failures are caught so the UI can still start and show an
# error page (see build_app); None marks the backend as unavailable.
synthesis_studio = None
try:
    synthesis_studio = SynthesisPipeline()
except Exception:
    import traceback
    print("FATAL: Failed to initialize SynthesisPipeline.")
    traceback.print_exc()
else:
    print("App: Synthesis Pipeline initialized successfully.")

# UPDATED: The logic for handling 'None' is now in the backend. This function is simpler.
def run_inference(pos_prompt, neg_prompt, control_image, preserve_face, scheduler, steps, guidance, progress=gr.Progress(track_tqdm=True)):
    """
    Gradio click handler that forwards the UI values to the backend.

    Slider values are cast to ``int``/``float`` before the call. Returns the
    ``(generated_image, pose_image)`` pair from ``generate_image``; the pose
    image is whatever the backend returns (None when no control image is used).
    Raises ``gr.Error`` when the models failed to load at startup.
    """
    if synthesis_studio is None:
        raise gr.Error("Models could not be loaded. Application is not functional.")

    progress(0.1, desc="Starting generation...")

    request = dict(
        prompt=pos_prompt,
        negative_prompt=neg_prompt,
        control_image=control_image,
        preserve_face=preserve_face,
        scheduler_name=scheduler,
        num_steps=int(steps),
        guidance_scale=float(guidance),
    )
    generated_image, pose_image = synthesis_studio.generate_image(**request)

    progress(1.0, desc="Done!")
    return generated_image, pose_image

def build_app():
    with gr.Blocks(css=ui_components.CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="orange", secondary_hue="yellow")) as interface:
        if synthesis_studio is None:
            gr.Markdown("# ❌ Application Error\nThe AI models failed to load. Please check console logs for details.")
            return interface

        scheduler_list = list(synthesis_studio.schedulers.keys())
        c = ui_components.create_ui_layout(scheduler_list)

        def autofill_negative_prompt(is_checked, current_text):
            common_negs = ui_components.COMMON_NEGATIVE_PROMPTS
            if is_checked:
                return f"{current_text}, {common_negs}" if common_negs not in current_text else current_text
            else:
                return current_text.replace(f", {common_negs}", "").replace(common_negs, "").strip().rstrip(',')

        def clear_negative_prompt():
            return "", False

        def handle_image_upload(image):
            return gr.update(visible=image is not None)

        def reset_settings():
            return scheduler_list[0], 25, 7.5

        c["autofill_neg_checkbox"].change(fn=autofill_negative_prompt, inputs=[c["autofill_neg_checkbox"], c["negative_prompt"]], outputs=c["negative_prompt"])
        c["clear_neg_button"].click(fn=clear_negative_prompt, inputs=None, outputs=[c["negative_prompt"], c["autofill_neg_checkbox"]])
        c["control_image_input"].upload(fn=handle_image_upload, inputs=c["control_image_input"], outputs=c["preserve_face_checkbox"])
        c["control_image_input"].clear(fn=lambda: gr.update(visible=False), inputs=None, outputs=c["preserve_face_checkbox"])
        c["reset_button"].click(fn=reset_settings, inputs=None, outputs=[c["scheduler_dropdown"], c["steps_slider"], c["guidance_slider"]])
        c["generate_button"].click(
            fn=run_inference,
            inputs=[c["positive_prompt"], c["negative_prompt"], c["control_image_input"], c["preserve_face_checkbox"], c["scheduler_dropdown"], c["steps_slider"], c["guidance_slider"]],
            outputs=[c["output_image"], c["pose_image"]]
        )

    return interface

app_instance = build_app()

In [None]:
%%writefile app/ui_components.py


# Disclaimer: The code in this cell was AI-generated.
import gradio as gr
import base64
from pathlib import Path

def encode_image_to_base64(image_path):
    """
    Read an image file and return it as a ``data:`` URI for embedding in CSS.

    The MIME type is guessed from the file extension, falling back to
    ``image/png``, so JPEG/WebP backgrounds are labelled correctly.

    Args:
        image_path: Path (``str`` or ``pathlib.Path``) to the image file.

    Returns:
        ``"data:<mime>;base64,<payload>"``, or the string ``"none"`` when the
        file does not exist. NOTE(review): callers that embed this in
        ``url('...')`` get a relative URL named "none", not the CSS keyword.
    """
    import mimetypes

    try:
        with open(image_path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        print(f"WARNING: Background image not found at {image_path}. Using a plain background.")
        return "none"

    mime_type, _ = mimetypes.guess_type(str(image_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"
    return f"data:{mime_type};base64,{encoded_string}"

# Background image, embedded into the CSS below as a data URI.
background_image_path = Path(".") / "background.png"
base64_background = encode_image_to_base64(background_image_path)

# Default negative prompt; also toggled by the "Auto-fill negatives" checkbox.
COMMON_NEGATIVE_PROMPTS = ", ".join([
    "ugly", "tiling",
    "poorly drawn hands", "poorly drawn feet", "poorly drawn face",
    "out of frame", "extra limbs", "disfigured", "deformed",
    "body out of frame", "bad anatomy", "watermark", "signature",
    "cut off", "low contrast", "underexposed", "overexposed",
    "bad art", "beginner", "amateur", "distorted face",
    "blurry", "draft", "grainy",
])

# Global stylesheet for the app (passed as gr.Blocks(css=...)). Doubled braces
# {{ }} are literal CSS braces inside this f-string; only {base64_background}
# is interpolated — the background data URI, or "none" if the file was missing.
# NOTE(review): url('none') is a relative URL, not the CSS keyword, so a
# missing background makes the browser request a file literally named "none".
CUSTOM_CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Josefin+Sans:wght@400;600;700&display=swap');
body {{ background-image: url('{base64_background}') !important; background-size: cover !important; background-position: center !important; background-attachment: fixed !important; }}
.gradio-container {{ background: none !important; }}
* {{ font-family: 'Josefin Sans', sans-serif !important; color: #EAEAEA !important; }}
.gradio-group, .gradio-tabs, .gradio-accordion, .gradio-html {{ background-color: rgba(20, 20, 30, 0.4) !important; backdrop-filter: blur(12px) !important; -webkit-backdrop-filter: blur(12px) !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; border-radius: 12px !important; box-shadow: 0 4px 30px rgba(0, 0, 0, 0.1) !important; }}
#main_header {{ background: none !important; border: none !important; box-shadow: none !important; }}
textarea, input[type="text"], input[type="number"] {{ background-color: rgba(0, 0, 10, 0.5) !important; border: 1px solid rgba(255, 255, 255, 0.2) !important; }}
#main_header h1 {{ font-weight: 700; font-size: 2.8em; color: #FFFFFF !important; text-shadow: 0px 2px 4px rgba(0,0,0,0.4); }}
#main_header p {{ font-size: 1.2em; color: #B0C4DE !important; }}
h3 {{ font-weight: 600 !important; color: #FFFFFF !important; text-transform: uppercase !important; letter-spacing: 1px !important; border-bottom: 1px solid rgba(255, 255, 255, 0.2); padding-bottom: 5px; }}
#github_icon {{ display: flex; justify-content: flex-end; align-items: center; height: 100%; background: none !important; }}
#github_icon a {{ color: #EAEAEA !important; font-size: 1.8em; transition: color 0.3s, transform 0.3s; }}
#github_icon a:hover {{ color: #FFFFFF !important; transform: scale(1.1); }}
"""

# User manual shown in the collapsible "User Manual" accordion. It covers both
# Standard (text-only) and ControlNet (control image) modes.
GUIDE_MARKDOWN = """
## How to Use This Studio 📖

This studio can generate images in two ways:
1.  **Standard Mode:** Simply write a prompt and click Generate.
2.  **ControlNet Mode:** Upload a "Control Image" to guide the structure of the generated image.

#### **What is a "Control Image"? (Optional Feature)**
The "Control Image" is your blueprint. The AI doesn't copy the *style* of this image, but it copies the **structure** (like a person's pose).

-   **What it does:** The final generated image will have a person in the *exact same pose* as the person in your control image.
-   **When to use it:** Use this when you want to control the composition or pose of a person in your image.

#### **How to Use "Attempt to preserve face?"**
This feature **only works if you have uploaded a Control Image**.

-   **Result:** It will try to make the face in the generated image look like the face from your uploaded image.

#### **Example Prompts**
-   **Standard Prompt:** `A majestic lion in the savannah, golden hour, photorealistic, 8k`
-   **ControlNet Prompt:** `A photorealistic portrait of a royal queen in a magnificent golden dress, cinematic lighting, highly detailed`
---
"""

# Text for the nested "Advanced Settings" accordion inside the user manual.
ADVANCED_GUIDE_MARKDOWN = """
### **Advanced Settings Explained**
-   **Scheduler:** This is the algorithm used for denoising the image. Different schedulers can produce slightly different styles and textures.
    -   **`DDIM` / `PNDM`:** Good all-rounders, fast and reliable.
    -   **`DDPM`:** Often produces high-quality results but can be slower.
-   **Inference Steps:** How many steps the AI takes to generate the image.
    -   **More steps (e.g., 40-50):** Can add more detail but takes longer.
    -   **Fewer steps (e.g., 20-25):** Faster generation, great for testing ideas.
-   **Guidance Scale (CFG):** How strictly the AI should follow your positive prompt.
    -   **Higher value (e.g., 10-15):** The AI will stick very closely to your prompt, but might be less creative.
    -   **Lower value (e.g., 7-9):** A good balance of creativity and prompt adherence.
"""

def create_ui_layout(scheduler_choices):
    """
    Lay out the studio UI and return its interactive components by name.

    Intended to be called inside an active ``gr.Blocks`` context. Each
    component is created where it appears below, so statement order and
    nesting define the on-screen layout.

    Args:
        scheduler_choices: Scheduler names for the dropdown. The first entry is
            the initial selection, so an empty list would raise IndexError.

    Returns:
        dict mapping string keys (e.g. "positive_prompt", "generate_button",
        "output_image") to the created Gradio components, so the caller can
        attach event handlers.
    """
    # Header row: title/tagline on the left, GitHub link on the right.
    with gr.Row(elem_id="main_header"):
        with gr.Column(scale=10):
            gr.HTML("<h1>AI Photoshoot Studio 🎨</h1><p>Turn your ideas into stunning visuals with precise control.</p>")
        with gr.Column(scale=1, min_width=50, elem_id="github_icon"):
            gr.HTML('''<a href="https://github.com/Vishisht16/AI-Photoshoot-Studio" target="_blank" title="View on GitHub" style="display: inline-block; padding: 8px; background-color: #f0f0f0; border-radius: 6px; text-decoration: none;"><svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 24 24"><path fill="#000000" d="M12 2A10 10 0 0 0 2 12c0 4.42 2.87 8.17 6.84 9.5c.5.09.68-.22.68-.48v-1.7c-2.78.6-3.37-1.34-3.37-1.34c-.46-1.16-1.11-1.47-1.11-1.47c-.91-.62.07-.6.07-.6c1 .07 1.53 1.03 1.53 1.03c.89 1.53 2.34 1.09 2.91.83c.09-.65.35-1.09.63-1.34c-2.22-.25-4.55-1.11-4.55-4.95c0-1.1.39-1.99 1.03-2.69a3.6 3.6 0 0 1 .1-2.64s.84-.27 2.75 1.02a9.58 9.58 0 0 1 5 0c1.91-1.29 2.75-1.02 2.75-1.02c.53 1.28.018 2.37.1 2.64c.64.7 1.03 1.6 1.03 2.69c0 3.85-2.34 4.7-4.57 4.94c.36.31.68.92.68 1.85v2.72c0 .27.18.58.69.48A10 10 0 0 0 22 12A10 10 0 0 0 12 2"/></svg></a>''')

    # Main area: inputs on the left (scale=1), outputs and manual on the right (scale=2).
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 1. Describe Your Image")
                # Placeholder shows a sample prompt that needs no control image.
                positive_prompt = gr.Textbox(label="Positive Prompt", placeholder="e.g., A majestic lion in the savannah, golden hour...", lines=3)
                negative_prompt = gr.Textbox(label="Negative Prompt", lines=2, value=COMMON_NEGATIVE_PROMPTS)
                with gr.Row():
                    autofill_neg_checkbox = gr.Checkbox(label="Auto-fill negatives", value=True)
                    clear_neg_button = gr.Button("Clear")

            with gr.Group():
                # The control image is optional; the preserve-face checkbox starts
                # hidden (visible=False) and is shown by the caller's event handlers.
                gr.Markdown("### 2. (Optional) Provide a Control Image")
                control_image_input = gr.Image(label="Upload Image to Control Pose", type="pil", image_mode="RGB", height=300)
                preserve_face_checkbox = gr.Checkbox(label="Attempt to preserve face?", value=False, visible=False)

            with gr.Group():
                gr.Markdown("### 3. Adjust Settings")
                with gr.Accordion("Advanced Settings", open=True):
                    scheduler_dropdown = gr.Dropdown(label="Scheduler", choices=scheduler_choices, value=scheduler_choices[0])
                    steps_slider = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=25)
                    guidance_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=20.0, step=0.5, value=7.5)
                    reset_button = gr.Button("Reset to Defaults")

            generate_button = gr.Button("Generate Image", variant="primary", size="lg")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Generated Image"):
                    output_image = gr.Image(label="Final Output", height=440)
                with gr.TabItem("Detected Pose"):
                    # Second output: the pose map, per the label only populated
                    # when a control image is used.
                    pose_image = gr.Image(label="Detected Pose (Only if Control Image is used)", height=440)

            with gr.Accordion("Click me to toggle User Manual!", open=False):
                gr.Markdown(GUIDE_MARKDOWN)
                with gr.Accordion("Click me to know more about Advanced Settings!", open=False):
                    gr.Markdown(ADVANCED_GUIDE_MARKDOWN)

    return { "positive_prompt": positive_prompt, "negative_prompt": negative_prompt, "autofill_neg_checkbox": autofill_neg_checkbox, "clear_neg_button": clear_neg_button, "control_image_input": control_image_input, "preserve_face_checkbox": preserve_face_checkbox, "scheduler_dropdown": scheduler_dropdown, "steps_slider": steps_slider, "guidance_slider": guidance_slider, "generate_button": generate_button, "reset_button": reset_button, "output_image": output_image, "pose_image": pose_image }

In [None]:
%%writefile app.py

from app.app_gradio import app_instance

if __name__ == "__main__":
    banner = "======================================="
    for line in (banner, "🚀 Launching AI Photoshoot Studio...", banner):
        print(line)
    # Listen on all interfaces and request a public Gradio share link so the
    # studio is reachable from outside the Colab VM. debug=True keeps the
    # process in the foreground and surfaces errors in the cell output.
    app_instance.launch(server_name="0.0.0.0", share=True, debug=True)

#### The following cell creates a public URL within 3–4 minutes that you can open to access and test the AI Photoshoot Studio for as long as the Colab notebook is running.

In [None]:
!python app.py

When you are finished, go to 'Runtime' and click 'Disconnect and delete runtime' to free up the GPU resources.
Thank you for visiting!