<a href="https://colab.research.google.com/github/Theskrtnerd/gen-ai-photoshoots/blob/main/gen_ai_photoshoots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install streamlit python-dotenv google-generativeai torch diffusers transformers accelerate torchvision bitsandbytes datasets pyngrok

In [None]:
#@title Writing download_model.py
%%writefile download_model.py
import os


def download_model(my_bar):
    """
    Download and save the pretrained Stable Diffusion model components.

    The tokenizer, text encoder, variational autoencoder (VAE), U-Net and
    feature extractor are fetched and written below ``./stable_diffusion_models``.
    Nothing is downloaded when every component folder is already present.

    Args:
        my_bar: A Streamlit progress bar object used to update the download progress.
            It is only touched when a download actually happens.

    The function performs the following steps:
        1. Returns immediately if all five component folders already exist.
        2. Downloads the tokenizer and updates the progress bar to 20%.
        3. Downloads the text encoder and updates the progress bar to 40%.
        4. Downloads the VAE and updates the progress bar to 60%.
        5. Downloads the U-Net and updates the progress bar to 80%.
        6. Downloads the feature extractor and updates the progress bar to 96%.
        7. Creates the save directory and saves each downloaded component into it.
    """
    save_directory = "./stable_diffusion_models"
    component_names = ("tokenizer", "text_encoder", "vae", "unet", "feature_extractor")

    # The bare directory is not proof of a finished download: a run that failed
    # half-way would otherwise leave an empty folder behind and make every later
    # call skip the download. Check for each component folder instead.
    if all(os.path.isdir(os.path.join(save_directory, name)) for name in component_names):
        return

    status_text = "Downloading the pretrained model..."
    my_bar.progress(0, status_text)

    # Imported lazily so the heavy libraries are only loaded when a download is needed.
    from transformers import CLIPTokenizer, CLIPTextModel, CLIPFeatureExtractor
    from diffusers import AutoencoderKL, UNet2DConditionModel

    # The Stable Diffusion checkpoint we'll fine-tune
    model_id = "CompVis/stable-diffusion-v1-4"
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    my_bar.progress(20, status_text)
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    my_bar.progress(40, status_text)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    my_bar.progress(60, status_text)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    my_bar.progress(80, status_text)
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
    my_bar.progress(96, status_text)

    # Only touch the disk once every component has been fetched successfully.
    os.makedirs(save_directory, exist_ok=True)
    tokenizer.save_pretrained(f"{save_directory}/tokenizer")
    text_encoder.save_pretrained(f"{save_directory}/text_encoder")
    vae.save_pretrained(f"{save_directory}/vae")
    unet.save_pretrained(f"{save_directory}/unet")
    feature_extractor.save_pretrained(f"{save_directory}/feature_extractor")


Overwriting download_model.py


In [None]:
#@title Writing train_model.py
%%writefile train_model.py
from argparse import Namespace
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import math
from tqdm.auto import tqdm
from datasets import load_dataset
from download_model import download_model


class DreamBoothDataset(Dataset):
    """
    A custom dataset class for preparing data for training the Stable Diffusion model.

    Every item pairs one transformed image with the token ids of the same
    instance prompt. The ids are returned unpadded; padding is left to the
    collate function.

    Args:
        dataset (Dataset): The dataset containing images. Each item must be a
            mapping with an "image" entry holding a PIL image.
        instance_prompt (str): The prompt describing the instance in the images.
        tokenizer (CLIPTokenizer): The tokenizer for encoding text prompts.
        size (int, optional): The size to which images are resized. Default is 512.

    Methods:
        __len__(): Returns the length of the dataset.
        __getitem__(index): Returns a single data point (image and tokenized prompt).
    """
    def __init__(self, dataset, instance_prompt, tokenizer, size=512):
        self.dataset = dataset
        self.instance_prompt = instance_prompt
        self.tokenizer = tokenizer
        self.size = size
        # Resize the short side, crop to a square, then map pixel values
        # from [0, 1] to [-1, 1] via Normalize(mean=0.5, std=0.5).
        self.transforms = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index]["image"]
        # Uploaded JPEGs can be grayscale or CMYK; without this conversion the
        # tensor would not have the three channels the rest of the pipeline
        # is fed with for ordinary colour photos.
        if image.mode != "RGB":
            image = image.convert("RGB")

        prompt_ids = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        )["input_ids"]

        return {
            "instance_images": self.transforms(image),
            "instance_prompt_ids": prompt_ids,
        }


def collate_fn(examples):
    """
    Collate function to prepare a batch of data for training.

    The tokenizer needed for padding is loaded from
    ``./stable_diffusion_models/tokenizer`` on the first call and then kept on
    the function object, so later batches do not reload it from disk.

    Args:
        examples (list): A list of examples where each example is a dictionary
                         containing 'instance_prompt_ids' and 'instance_images'.

    Returns:
        dict: A dictionary containing batched 'input_ids', 'attention_mask'
              and 'pixel_values'.
    """
    # Stack the per-example image tensors into one contiguous float batch.
    pixel_values = torch.stack([example["instance_images"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    # Load the tokenizer once; this function runs for every single batch.
    tokenizer = getattr(collate_fn, "_tokenizer", None)
    if tokenizer is None:
        tokenizer = CLIPTokenizer.from_pretrained("./stable_diffusion_models/tokenizer")
        collate_fn._tokenizer = tokenizer

    # Pad the (still unpadded) prompt ids to a common length.
    prompt_ids = [example["instance_prompt_ids"] for example in examples]
    batch_encoding = tokenizer.pad(
        {"input_ids": prompt_ids}, padding=True, return_attention_mask=True, return_tensors="pt"
    )

    return {
        "input_ids": batch_encoding["input_ids"],
        "attention_mask": batch_encoding["attention_mask"],
        "pixel_values": pixel_values,
    }


def training_function(text_encoder, vae, unet, args, my_bar):
    """
    Fine-tune the U-Net on the instance images and save the result as a pipeline.

    Only the U-Net's parameters are handed to the optimizer. The text encoder
    and the VAE are only ever called inside ``torch.no_grad()`` blocks.

    Args:
        text_encoder (CLIPTextModel): The text encoder model.
        vae (AutoencoderKL): The variational autoencoder model.
        unet (UNet2DConditionModel): The U-Net model.
        args (Namespace): Training configuration. This function reads
            ``gradient_accumulation_steps``, ``seed``, ``gradient_checkpointing``,
            ``use_8bit_adam``, ``learning_rate``, ``train_dataset``,
            ``train_batch_size``, ``max_train_steps``, ``max_grad_norm`` and
            ``output_dir``.
        my_bar (streamlit.progress): A Streamlit progress bar to monitor the training progress.

    Returns:
        None. On the main process the trained pipeline is written to
        ``args.output_dir``.

    This function performs the following steps:
        1. Sets up the training environment and configurations.
        2. Loads the data and prepares it for training.
        3. Defines the optimizer and noise scheduler.
        4. Trains the model for the specified number of steps, updating the progress bar.
        5. Saves the trained model pipeline.
    """
    # Neither of these is used during training; they are only needed to
    # assemble the final pipeline at the bottom of this function.
    tokenizer = CLIPTokenizer.from_pretrained("./stable_diffusion_models/tokenizer")
    feature_extractor = CLIPFeatureExtractor.from_pretrained("./stable_diffusion_models/feature_extractor")

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

    set_seed(args.seed)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        # Imported lazily so bitsandbytes is only required when this option is on.
        import bitsandbytes as bnb
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        unet.parameters(),  # Only optimize unet
        lr=args.learning_rate,
    )

    # Scheduler used to add noise to the latents during training.
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

    train_dataloader = DataLoader(
        args.train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    # Only the trained model, its optimizer and the dataloader go through accelerate.
    unet, optimizer, train_dataloader = accelerator.prepare(
        unet, optimizer, train_dataloader
    )

    # The text encoder and VAE were not passed to accelerator.prepare(), so move
    # them to the accelerator's device by hand.
    text_encoder.to(accelerator.device)
    vae.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    # Enough epochs to reach max_train_steps optimizer updates.
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Only show the progress bar once on each machine
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")
    global_step = 0  # counts optimizer updates, not batches

    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                    # NOTE(review): 0.18215 is presumably the Stable Diffusion
                    # VAE latent scaling factor — confirm against the model config.
                    latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                # (only input_ids are used; the batch's attention_mask is not passed on)
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(
                    noisy_latents, timesteps, encoder_hidden_states
                ).sample
                # MSE against the sampled noise: averaged over dims 1-3 per
                # example first, then over the batch.
                loss = (
                    F.mse_loss(noise_pred, noise, reduction="none")
                    .mean([1, 2, 3])
                    .mean()
                )

                accelerator.backward(loss)
                # Clip only on steps where gradients are actually synchronized.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)
            # The percentage is derived from the epoch index, so although this
            # runs after every batch the value only changes once per epoch.
            my_bar.progress(int(epoch*100/num_train_epochs), text="Training in progress...")

            # This break only leaves the inner (batch) loop.
            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it
    if accelerator.is_main_process:
        print(f"Loading pipeline and saving to {args.output_dir}...")
        # A PNDM scheduler is stored in the saved pipeline; its beta settings
        # match the DDPM scheduler used for training above.
        scheduler = PNDMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            skip_prk_steps=True,
            steps_offset=1,
        )
        pipeline = StableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            # Strip the accelerate wrapper before storing the model.
            unet=accelerator.unwrap_model(unet),
            tokenizer=tokenizer,
            scheduler=scheduler,
            # The safety checker is loaded by name here rather than from the
            # local ./stable_diffusion_models folder.
            safety_checker=StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker"
            ),
            feature_extractor=feature_extractor,
        )
        pipeline.save_pretrained(args.output_dir)


def train_model(product_name, my_bar):
    """
    Train the Stable Diffusion model with images of the specified product.

    Args:
        product_name (str): The name of the product to be used in the training prompt.
        my_bar (streamlit.progress): A Streamlit progress bar to monitor the training progress.

    This function performs the following steps:
        1. Downloads the pretrained model components (if not already present).
        2. Loads the images in ``training_photos/`` as an image-folder dataset.
        3. Creates a DreamBoothDataset with the loaded images and the product name.
        4. Sets up training arguments and configurations.
        5. Loads the text encoder, VAE and U-Net from ``./stable_diffusion_models``.
        6. Launches ``training_function`` (which also saves the trained pipeline
           to ``./pipeline-folder``) on the configured number of GPUs.
        7. Drops the model references and frees cached GPU memory.
    """
    import gc

    download_model(my_bar)
    dataset = load_dataset("imagefolder", data_dir="training_photos/", split='train')
    instance_prompt = f"a photo of a {product_name}"
    learning_rate = 2e-06
    max_train_steps = 200
    load_directory = "./stable_diffusion_models"
    tokenizer = CLIPTokenizer.from_pretrained(f"{load_directory}/tokenizer")
    train_dataset = DreamBoothDataset(dataset, instance_prompt, tokenizer)
    args = Namespace(
        pretrained_model_name_or_path=load_directory,
        resolution=512,  # Reduce this if you want to save some memory
        train_dataset=train_dataset,
        instance_prompt=instance_prompt,
        learning_rate=learning_rate,
        max_train_steps=max_train_steps,
        train_batch_size=1,
        gradient_accumulation_steps=1,  # Increase this if you want to lower memory usage
        max_grad_norm=1.0,
        gradient_checkpointing=True,  # Set this to True to lower the memory usage
        use_8bit_adam=True,  # Use 8bit optimizer from bitsandbytes
        seed=3434554,
        sample_batch_size=2,
        output_dir="./pipeline-folder",  # Where to save the pipeline
    )
    text_encoder = CLIPTextModel.from_pretrained(f"{load_directory}/text_encoder")
    vae = AutoencoderKL.from_pretrained(f"{load_directory}/vae")
    unet = UNet2DConditionModel.from_pretrained(f"{load_directory}/unet")
    num_of_gpus = 1  # CHANGE THIS TO MATCH THE NUMBER OF GPUS YOU HAVE
    notebook_launcher(
        training_function, args=(text_encoder, vae, unet, args, my_bar), num_processes=num_of_gpus
    )
    my_bar.progress(96, text="Saving the model...")

    # empty_cache() can only hand back memory that no live tensor still uses,
    # so release our references to the models (and collect any reference
    # cycles left by the training run) before asking for the cache to be freed.
    del text_encoder, vae, unet
    gc.collect()
    torch.cuda.empty_cache()


Overwriting train_model.py


In [None]:
#@title Writing generate_image.py
%%writefile generate_image.py

from diffusers import StableDiffusionPipeline
import torch


def generate_image(prompt):
    """
    Run the fine-tuned Stable Diffusion pipeline on a text prompt.

    Args:
        prompt (str): The text prompt describing the desired image.

    Returns:
        The ``images`` attribute of the pipeline output (the generated
        image or images for the prompt).

    Steps:
        1. Load the pipeline saved in ``./pipeline-folder`` in half precision.
        2. Move it to the "cuda" device.
        3. Call it with the prompt and a guidance scale of 9.
        4. Return the resulting images.
    """
    guidance_scale = 9

    pipeline = StableDiffusionPipeline.from_pretrained(
        "./pipeline-folder",
        torch_dtype=torch.float16,
    )
    pipeline = pipeline.to("cuda")

    output = pipeline(prompt, guidance_scale=guidance_scale)
    return output.images


Overwriting generate_image.py


In [None]:
#@title Writing gemini_api.py
%%writefile gemini_api.py
import google.generativeai as genai
import streamlit as st
import json


def validate_gemini_api_key(my_api_key):
    """
    Check a Gemini API key by issuing one small generation request with it.

    Args:
        my_api_key (str): The Gemini API key to be validated.

    Returns:
        bool: True when configuring the client and the test request both
              complete without raising, False otherwise.

    Any exception raised along the way is shown to the user through
    ``st.error`` before False is returned.
    """
    try:
        genai.configure(api_key=my_api_key)
        probe_model = genai.GenerativeModel('gemini-1.5-flash-latest')
        probe_model.generate_content("What's my name?")
    except Exception as err:
        st.error(err)
        return False
    else:
        return True


def _parse_recommendations(raw_text):
    """Extract the list of place strings from the model's raw reply ([] if none)."""
    # The reply is not guaranteed to be bare JSON: it may be wrapped in quotes
    # (as in the prompt's own example), in a Markdown code fence, or in prose.
    # Parse only the outermost [...] span.
    start = raw_text.find("[")
    end = raw_text.rfind("]")
    if start == -1 or end <= start:
        return []
    try:
        parsed = json.loads(raw_text[start:end + 1])
    except json.JSONDecodeError:
        return []
    if not isinstance(parsed, list):
        return []
    # Keep only usable, non-empty strings.
    return [item.strip() for item in parsed if isinstance(item, str) and item.strip()]


def generate_reccommendations():
    """
    Generate recommended background places for the user's product using the
    Gemini API.

    Returns:
        list: A list of recommended background places for the product. The
              list is empty (and a warning is shown) when the model's reply
              does not contain a usable JSON list.

    This function performs the following steps:
        1. Retrieves the Gemini API key and product name from the session state.
        2. Configures the genai client with the API key.
        3. Creates a GenerativeModel instance and constructs a prompt to generate
           recommendations.
        4. Sends the prompt to the model and extracts the list of recommended
           places from the reply, tolerating text around the JSON list.
    """
    gemini_api_key = st.session_state.api_key
    product_name = st.session_state.product_name
    genai.configure(api_key=gemini_api_key)
    model = genai.GenerativeModel('gemini-1.5-flash-latest')

    # Joined with spaces so the sentences do not run into each other.
    prompt = " ".join(
        (
            f'What are some great recommended background places for a {product_name}?',
            'Give me the answer as a list of places. There should be 5 relevant places.',
            'The answer\'s format should be as follows: \'["on a beach", "in a park", "near a flowing river"]\'.',
            'You need to follow this answer\'s format strictly. Don\'t ask more questions.',
        )
    )
    response = model.generate_content(prompt)

    try:
        raw_text = response.text
    except ValueError:
        # NOTE(review): `.text` is expected to raise ValueError when the reply
        # carries no text part (e.g. a blocked response) — confirm against the
        # installed google-generativeai version.
        raw_text = ""

    places = _parse_recommendations(raw_text)
    if not places:
        st.warning("Could not get background recommendations. Please enter a background yourself.")
    return places


Overwriting gemini_api.py


In [None]:
#@title Writing main.py
%%writefile main.py
import streamlit as st
import os
from train_model import train_model
import shutil
from generate_image import generate_image
from gemini_api import validate_gemini_api_key, generate_reccommendations

# Main Application with GUI


def stage_0():
    """
    Render the landing page.

    Shows the introductory pitch and a "Start Creating" button; pressing the
    button switches the session to stage 1 and reruns the script.
    """
    intro_lines = (
        "Ready to showcase your product in its best light, but lacking professional photography tools?",
        "Let our AI do the work for you.",
        "Craft exquisite, professional-grade product images effortlessly, and at no cost.",
        "Transform your brand's visual presentation with ease.",
    )
    for line in intro_lines:
        st.write(line)

    start_clicked = st.button("Start Creating")
    if start_clicked:
        st.session_state.current_stage = 1
        st.rerun()  # Move to next page


def stage_1():
    """
    Display the API key input page.

    This page instructs the user to create a Gemini API key and input it into
    the text box provided. The key is validated, and the user is moved to the
    next stage if the key is valid. The input is masked so the key is not shown
    on screen, and surrounding whitespace from copy/paste is ignored.
    """
    st.write("To start, go to [this link](https://aistudio.google.com/app/apikey)"
             + " to create your Gemini API key (Don't worry, it's free)")
    # type="password" keeps the secret from being echoed in clear text.
    api_key = st.text_input("Enter your Google Gemini API Key here:", type="password").strip()
    if not api_key:
        return

    if not validate_gemini_api_key(api_key):
        st.error("Not the correct API key")
        return

    st.session_state.current_stage = 2
    st.session_state.api_key = api_key
    st.rerun()  # Move to next page


def stage_2():
    """
    Render the photo upload page.

    Asks for product photos (JPEG only). Once at least one file has been
    uploaded, the ``./training_photos/`` folder is recreated from scratch, the
    uploads are written into it as ``image_<n>.jpg``, and the session moves on
    to stage 3.
    """
    photo_dir = "./training_photos/"

    st.write("Now upload 5-10 photos of your product, it can be photos taken of your product from different angles")
    uploads = st.file_uploader(
        "Upload images of your product:",
        type=["jpg", "jpeg"],
        accept_multiple_files=True,
        help="Upload here...")

    if not uploads:
        return

    # Start from an empty folder so photos from an earlier upload are not reused.
    if os.path.exists(photo_dir):
        shutil.rmtree(photo_dir)
    os.makedirs(photo_dir)

    for index, upload in enumerate(uploads):
        target_path = os.path.join(photo_dir, f"image_{index}.jpg")
        with open(target_path, "wb") as out_file:
            out_file.write(upload.getbuffer())

    st.session_state.current_stage = 3
    st.rerun()  # Move to next page


def stage_3():
    """
    Render the product name page.

    Asks for the product's name. As soon as a non-empty name is entered it is
    stored in the session state and the session moves on to stage 4.
    """
    st.write("What's your product's name?")
    st.write("E.g. corggi dog, champions league soccer ball, red scarf with patterns")

    entered_name = st.text_input("Enter your product's name here:")
    if not entered_name:
        return

    st.session_state.current_stage = 4
    st.session_state.product_name = entered_name
    st.rerun()  # Move to next page


def stage_4():
    """
    Render the training page and run the fine-tuning.

    Shows a progress bar, trains the model on the uploaded photos using the
    product name from the session state, and moves on to stage 5 when training
    finishes. If training raises, the error is displayed and the stage does
    not change.
    """
    st.write("Now it's time to train your model.")
    st.write("Please wait a while for us to understand your product.")

    progress = st.progress(0, text="Training in progress...")  # Monitor the progress
    try:
        train_model(st.session_state.product_name, progress)  # Finetuning the model
    except Exception as err:
        st.error(err)
    else:
        st.session_state.current_stage = 5
        st.rerun()  # Move to next page


def stage_5():
    """
    Display the background selection page.

    This page allows the user to select or enter a background for the generated
    product images. Background recommendations are provided for convenience;
    if they cannot be generated, a warning is shown and the user can still
    type a background of their own.
    """
    st.write("Everything's finally done, Hooray!")
    st.write("Now's the time to create stunning photos with your product.")
    st.write("Enter a background you want or choose a reccommendation below")

    if "recs" not in st.session_state:
        try:
            # Create some background recommendations
            st.session_state.recs = generate_reccommendations()
        except Exception as e:
            # Recommendations are optional: a failed request or an unparsable
            # reply must not block the page. Nothing is cached on failure, so
            # the next visit to this page tries again.
            st.warning(f"Could not load background recommendations: {e}")

    recommendations = st.session_state.get("recs") or []
    for position, rec in enumerate(recommendations):
        label = str(rec)
        # An explicit key keeps two identical recommendations from colliding
        # as duplicate widgets.
        if st.button(label, key=f"rec_{position}"):
            st.session_state.current_stage = 6
            st.session_state.background = label
            st.rerun()  # Move to next page

    text_input = st.text_input("Enter a background: ")
    if text_input:
        st.session_state.current_stage = 6
        st.session_state.background = text_input
        st.rerun()  # Move to next page


def stage_6():
    """
    Render the result page.

    Builds the prompt from the stored product name and background, generates
    the image(s) for it and shows them. Two buttons follow: one reruns this
    stage to generate again, the other goes back to stage 5 to pick a
    different background.
    """
    st.write("Nice Choice!")

    state = st.session_state
    product_name = state.product_name
    background = state.background

    st.write(f"Finally, here's your {product_name} {background}")
    generated = generate_image(f"a photo of a {product_name} {background}")
    st.image(generated)

    if st.button("Generate another image"):
        st.rerun()  # Rerun stage 6

    if st.button("Choose another background"):
        state.current_stage = 5
        st.rerun()  # Return to stage 5


if __name__ == "__main__":  # Main page and page control system
    # Page flow controller: look up the stage stored in the session state and
    # render the page that belongs to it. A new session starts at stage 0.
    st.title("AI-Powered Product Photoshoot Wizard")

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = 0

    stage_pages = {
        0: stage_0,
        1: stage_1,
        2: stage_2,
        3: stage_3,
        4: stage_4,
        5: stage_5,
        6: stage_6,
    }

    # Unknown stage values render nothing below the title.
    render_page = stage_pages.get(st.session_state.current_stage)
    if render_page is not None:
        render_page()


Overwriting main.py


In [None]:
import os
from getpass import getpass

from pyngrok import conf, ngrok

conf.get_default().region = "us"

# The auth token is a secret and must not be stored in the notebook. Read it
# from the NGROK_AUTH_TOKEN environment variable, or prompt for it without
# echoing. Any token that was previously committed here should be revoked
# in the ngrok dashboard.
ngrok_key = os.environ.get("NGROK_AUTH_TOKEN") or getpass("Enter your ngrok auth token: ")
conf.get_default().auth_token = ngrok_key

# Expose the local Streamlit port through a public ngrok tunnel.
port = 8501
public_url = ngrok.connect(port).public_url
print(public_url)

https://d56b-34-127-91-106.ngrok-free.app


In [None]:
!streamlit run main.py &


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.127.91.106:8501[0m
[0m
[34m  Stopping...[0m
