In [1]:
# Required packages, install if not installed (assume PyTorch* and Intel® Extension for PyTorch* is already present)
!echo "Installation in progress..."
# import sys
# !{sys.executable} -m pip install  invisible-watermark > /dev/null
# !conda install -y --quiet --prefix {sys.prefix}  -c conda-forge \
#     accelerate==0.23.0 \
#     validators==0.22.0 \
#     diffusers==0.18.2 \
#     transformers==4.32.1 \
#     tensorboardX \
#     pillow \
#     ipywidgets \
#     ipython > /dev/null && echo "Installation successful" || echo "Installation failed"
import sys
!{sys.executable} -m pip install invisible-watermark --user > /dev/null 2>&1 
!{sys.executable} -m pip install transformers huggingface-hub --user > /dev/null 2>&1
!echo "Installtion complete..."

Installation in progress...
Installtion complete...


In [2]:
from io import BytesIO
import os
import time
import warnings
from pathlib import Path
from typing import List, Dict, Tuple


# Suppress warnings for a cleaner output.
# NOTE(review): this installs a global "ignore" filter, hiding every warning raised in this kernel
# (including deprecation notices from diffusers/transformers); narrow it with category=/module=
# if those diagnostics are needed.
warnings.filterwarnings("ignore")

import random
import requests
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # adds xpu namespace to PyTorch, enabling you to use Intel GPUs
import validators
import numpy as np

from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import DPMSolverMultistepScheduler
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

In [3]:
class Img2ImgModel:
    """
    Image-to-image Stable Diffusion wrapper that transforms a seed image according to a text prompt.

    On construction the pipeline is loaded with ``StableDiffusionImg2ImgPipeline``, optionally
    switched to ``DPMSolverMultistepScheduler``, moved to ``device`` and, optionally, optimized
    with ``ipex.optimize`` and warmed up with one sample generation.
    """

    # Root directory searched first for a locally cached copy of the model; a model loaded from
    # elsewhere is saved under it (best effort). Override on the class to relocate the cache.
    model_cache_root: str = "/home/common/data/Big_Data/GenAI"
    # (width, height) that downloaded seed images are resized to before generation.
    input_size: Tuple[int, int] = (768, 512)

    def __init__(
        self,
        model_id_or_path: str,
        device: str = "xpu",
        torch_dtype: torch.dtype = torch.bfloat16,
        optimize: bool = True,
        warmup: bool = False,
        scheduler: bool = True,
    ) -> None:
        """
        Load the pipeline and optionally optimize and warm it up.

        Args:
            model_id_or_path (str): The ID or path of the pre-trained model.
            device (str, optional): The device to run the model on. Defaults to "xpu".
            torch_dtype (torch.dtype, optional): Data type used to load the weights and for
                autocast during generation. Defaults to torch.bfloat16.
            optimize (bool, optional): Whether to run ``ipex.optimize`` on the pipeline's
                modules. Defaults to True.
            warmup (bool, optional): Whether to run one sample generation after loading.
                Defaults to False.
            scheduler (bool, optional): Whether to replace the pipeline's scheduler with
                ``DPMSolverMultistepScheduler``. Defaults to True.
        """
        self.device = device
        self.data_type = torch_dtype
        self.scheduler = scheduler
        # Replaced by a per-item list of generators in get_inputs() for batched calls.
        self.generator = torch.Generator()
        self.pipeline = self._load_pipeline(model_id_or_path, torch_dtype)
        if optimize:
            self.optimize_pipeline()
        if warmup:
            self.warmup_model()

    def _load_pipeline(
        self, model_id_or_path: str, torch_dtype: torch.dtype
    ) -> StableDiffusionImg2ImgPipeline:
        """
        Load the img2img pipeline, preferring a local cached copy, and move it to ``self.device``.

        If ``<model_cache_root>/<model_id_or_path>`` exists it is loaded from there; otherwise
        ``model_id_or_path`` is passed to ``from_pretrained`` unchanged and a best-effort attempt
        is made to save the loaded pipeline into that cache location.

        Args:
            model_id_or_path (str): The ID or path of the pre-trained model.
            torch_dtype (torch.dtype): The data type to load the model weights in.

        Returns:
            StableDiffusionImg2ImgPipeline: The loaded pipeline, already on ``self.device``.
        """
        print("Loading the model...")
        # Path's "/" returns the right operand unchanged when it is absolute, so an explicit
        # local path is used as-is instead of being nested under the cache root.
        model_path = Path(self.model_cache_root) / model_id_or_path
        cached = model_path.exists()

        if cached:
            load_path = model_path
        else:
            print("Using the default path for models...")
            load_path = model_id_or_path

        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            load_path,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            variant="fp16",
        )
        if self.scheduler:
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                pipeline.scheduler.config
            )
        if not cached:
            try:
                print(f"Attempting to save the model to {model_path}...")
                # Save in the same format the load above requests (safetensors files with the
                # "fp16" variant suffix); otherwise the cached copy could not be re-loaded.
                pipeline.save_pretrained(
                    str(model_path), safe_serialization=True, variant="fp16"
                )
                print("Model saved.")
            except Exception as e:
                print(f"An error occurred while saving the model: {e}. Proceeding without saving.")
        return pipeline.to(self.device)

    def _optimize_pipeline(
        self, pipeline: StableDiffusionImg2ImgPipeline
    ) -> StableDiffusionImg2ImgPipeline:
        """
        Run ``ipex.optimize`` (in place, eval mode) on every ``nn.Module`` attribute of the pipeline.

        Args:
            pipeline (StableDiffusionImg2ImgPipeline): The pipeline to optimize.

        Returns:
            StableDiffusionImg2ImgPipeline: The same pipeline with its modules replaced by the
            optimized versions.
        """
        # All modules are optimized with the text encoder's dtype.
        target_dtype = pipeline.text_encoder.dtype
        for attr in dir(pipeline):
            module = getattr(pipeline, attr)
            if isinstance(module, nn.Module):
                setattr(
                    pipeline,
                    attr,
                    ipex.optimize(module.eval(), dtype=target_dtype, inplace=True),
                )
        return pipeline

    def optimize_pipeline(self) -> None:
        """
        Optimize ``self.pipeline`` in place with Intel® Extension for PyTorch*.
        """
        self.pipeline = self._optimize_pipeline(self.pipeline)

    def get_image_from_url(self, url: str, path: str, timeout: float = 30.0) -> Image.Image:
        """
        Download an image, save the full-size copy to ``path`` and return it resized.

        Args:
            url (str): The URL of the image.
            path (str): Local file path the downloaded (un-resized) RGB image is written to.
            timeout (float, optional): Seconds to wait for the server before giving up.
                Defaults to 30.

        Returns:
            Image.Image: The RGB image resized to ``input_size``.

        Raises:
            Exception: If the response status is not 200 or the content type is not an image.
        """
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            raise Exception(
                f"Failed to download image. Status code: {response.status_code}"
            )
        # A missing Content-Type header is reported as "not an image" rather than a KeyError.
        content_type = response.headers.get("content-type", "")
        if not content_type.startswith("image"):
            raise Exception(
                f"URL does not point to an image. Content type: {content_type}"
            )
        img = Image.open(BytesIO(response.content)).convert("RGB")
        img.save(path)
        return img.resize(self.input_size)

    def warmup_model(self) -> None:
        """
        Warm up the model by generating one sample image from a fixed URL.

        Failures are reported and ignored so construction still succeeds.
        """
        print("Setting up model...")
        image_url = "https://user-images.githubusercontent.com/786476/256401499-f010e3f8-6f8d-4e9f-9d1f-178d3571e7b9.png"
        try:
            self.generate_images(
                image_url=image_url,
                prompt="A beautiful day",
                num_images=1,
                save_path=".tmp",
            )
        except Exception as e:
            # Best effort: warm-up downloads an image and must not abort model construction.
            print(f"model warmup delayed... ({e})")

    def get_inputs(self, prompt: str, batch_size: int = 1) -> Dict[str, list]:
        """
        Build batched pipeline keyword arguments.

        Returns ``{"prompt": [prompt] * batch_size, "generator": [...]}`` with one fresh
        ``torch.Generator`` per batch item; the generator list is also stored on ``self.generator``.
        """
        self.generator = [torch.Generator() for _ in range(batch_size)]
        return {"prompt": [prompt] * batch_size, "generator": self.generator}

    def generate_images(
        self,
        prompt: str,
        image_url: str,
        num_images: int = 5,
        num_inference_steps: int = 30,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        save_path: str = "image_to_image",
        batch_size: int = 1,
    ) -> List[str]:
        """
        Generate ``num_images`` transformed versions of a seed image and save them as PNG files.

        Args:
            prompt (str): The prompt for the generation; its first three words prefix the file names.
            image_url (str): URL or local file path of the seed image. A URL is downloaded (the
                original is saved as ``input.png``) and resized to ``input_size``; a local file
                is used at its own size.
            num_images (int, optional): Total number of images to generate. Defaults to 5.
            num_inference_steps (int, optional): Number of noise removal steps. Defaults to 30.
            strength (float, optional): The strength of the transformation. Defaults to 0.75.
            guidance_scale (float, optional): The scale of the guidance. Defaults to 7.5.
            save_path (str, optional): Directory for the generated images; created if missing.
                Defaults to "image_to_image".
            batch_size (int, optional): Images generated per pipeline call. Defaults to 1.

        Returns:
            List[str]: Paths of the saved images, in generation order.

        Raises:
            ValueError: If ``batch_size`` < 1, or ``image_url`` is neither a valid URL nor an
                existing local file.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        input_image_path = "input.png"
        if validators.url(image_url):
            init_image = self.get_image_from_url(image_url, input_image_path)
        elif os.path.isfile(image_url):
            init_image = Image.open(image_url).convert("RGB")
        else:
            raise ValueError("The image_input is neither a valid URL nor a local file path.")

        # PIL's save() does not create missing directories.
        os.makedirs(save_path, exist_ok=True)
        file_prefix = "_".join(prompt.split()[:3])
        autocast_enabled = self.data_type != torch.float32

        generated_image_paths: List[str] = []
        for start in range(0, num_images, batch_size):
            # Trim the final batch so exactly num_images images are produced.
            current_batch = min(batch_size, num_images - start)
            init_images = [init_image] * current_batch
            with torch.xpu.amp.autocast(enabled=autocast_enabled, dtype=self.data_type):
                if current_batch > 1:
                    inputs = self.get_inputs(prompt=prompt, batch_size=current_batch)
                    images = self.pipeline(
                        **inputs,
                        image=init_images,
                        strength=strength,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                    ).images
                else:
                    images = self.pipeline(
                        prompt=prompt,
                        image=init_images,
                        strength=strength,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                    ).images

            for offset, image in enumerate(images):
                output_image_path = os.path.join(
                    save_path,
                    f"{file_prefix}_{start + offset}__{int(time.time() * 1e6)}.png",
                )
                image.save(output_image_path)
                generated_image_paths.append(output_image_path)

        # Returned only after every batch has run (previously returned after the first batch).
        return generated_image_paths



In [4]:
import time
import torch
from PIL import Image
from io import BytesIO
import requests
from pathlib import Path

class ComicGenerationModel:
    """
    ComicGenerationModel generates a series of images that tell a story based on an initial
    text prompt and a reference image.

    Every panel is produced by image-to-image transformation: the first from the reference
    image, each later one from the previous panel. The text-to-image stage is not currently
    wired in.
    """

    def __init__(self, text_model_id, img_model_id, device='xpu', torch_dtype=torch.bfloat16, optimize=True, warmup=True):
        """
        Initialize the image-to-image model used for every comic panel.

        Args:
            text_model_id (str): Identifier for a text-to-image model. Currently unused (the
                text-to-image stage is disabled); kept for interface compatibility.
            img_model_id (str): Identifier for the image-to-image model.
            device (str, optional): Device to run the models on. Defaults to 'xpu'.
            torch_dtype (torch.dtype, optional): Data type for the models. Defaults to torch.bfloat16.
            optimize (bool, optional): Whether to optimize models on initialization. Defaults to True.
            warmup (bool, optional): Whether to warm up models on initialization. Defaults to True.
        """
        # Keywords make the mapping onto Img2ImgModel's parameters explicit.
        self.img_to_img_model = Img2ImgModel(
            img_model_id,
            device=device,
            torch_dtype=torch_dtype,
            optimize=optimize,
            warmup=warmup,
        )

    def get_image_from_url(self, url: str, timeout: float = 30.0) -> Image.Image:
        """
        Download an image from a URL.

        Args:
            url (str): URL of the image to download.
            timeout (float, optional): Seconds to wait for the server before giving up. Defaults to 30.

        Returns:
            Image.Image: The downloaded image, converted to RGB.

        Raises:
            Exception: If the response status is not 200 or the content type is not an image.
        """
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"Failed to download image. Status code: {response.status_code}")
        # A missing Content-Type header is reported as "not an image" rather than a KeyError.
        content_type = response.headers.get("content-type", "")
        if not content_type.startswith("image"):
            raise Exception(f"URL does not point to an image. Content type: {content_type}")
        return Image.open(BytesIO(response.content)).convert("RGB")

    def generate_comic(self, initial_prompt, reference_image_url, story_prompts, save_path="comic_output"):
        """
        Generate a comic story based on an initial prompt and a series of evolving story prompts.

        Args:
            initial_prompt (str): The text prompt for the first panel.
            reference_image_url (str): URL (or local path) of the reference image for the first panel.
            story_prompts (list of str): Exactly four prompts; each produces one panel derived
                from the previous panel.
            save_path (str, optional): Directory to save the generated comic images; created if
                missing. Defaults to "comic_output".

        Returns:
            list of str: Paths of the five generated panels, in story order.

        Raises:
            ValueError: If ``story_prompts`` is not a list of four items.
            Exception: If a generation step produces no image.
        """
        if not isinstance(story_prompts, list) or len(story_prompts) != 4:
            raise ValueError("story_prompts must be a list of four strings")

        os.makedirs(save_path, exist_ok=True)

        # Img2ImgModel.generate_images downloads and validates the reference URL itself, so it
        # is passed through directly instead of being downloaded twice.
        first_image_list = self.img_to_img_model.generate_images(
            prompt=initial_prompt,
            image_url=reference_image_url,
            num_images=1,
            save_path=save_path,
        )
        if not first_image_list:
            raise Exception("No images were generated.")

        panel_paths = [first_image_list[0]]
        # Each subsequent panel is generated from the file of the previous one.
        for prompt in story_prompts:
            next_panel = self.img_to_img_model.generate_images(
                prompt=prompt,
                image_url=panel_paths[-1],
                num_images=1,
                save_path=save_path,
            )
            if not next_panel:
                raise Exception("No images were generated.")
            panel_paths.append(next_panel[0])

        print(f"Comic story generated in {save_path}")
        return panel_paths

# Usage Example
# comic_gen = ComicGenerationModel('text_model_id_here', 'img_model_id_here')
# initial_prompt = "A wizard in a mystical forest"
# reference_image_url = "https://example.com/wizard_reference.jpg"
# story_prompts = ["The wizard encounters a talking tree", "A magical battle begins", "Victory and discovery of a hidden treasure", "The wizard's return to the village"]
# comic_gen.generate_comic(initial_prompt, reference_image_url, story_prompts)


In [5]:
import os
import random
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mp_img
import validators
from IPython.display import clear_output
import ipywidgets as widgets

# Loaded ComicGenerationModel instances keyed by (model_id, device) tuples; generate_comic_gui's
# button handler reuses an entry across clicks instead of constructing the model again.
model_cache = {}

def generate_comic_gui():
    """
    Build and display the comic-generation form: model dropdown, prompt fields and a button.

    Clicking the button validates the inputs, then loads (or reuses from ``model_cache``) a
    ComicGenerationModel for the selected model, generates the comic into ``comic_output`` and
    displays the images found there.
    """
    out = widgets.Output()
    output_dir = "comic_output"
    model_ids = [
        "stabilityai/stable-diffusion-2-1",
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/sdxl-turbo",
    ]

    model_dropdown = widgets.Dropdown(
        options=model_ids,
        value=model_ids[0],
        description="Model:",
    )
    initial_prompt_text = widgets.Text(
        value="",
        placeholder="Enter the initial prompt",
        description="Initial Prompt:",
        layout=widgets.Layout(width="600px")
    )
    reference_image_url_text = widgets.Text(
        value="",
        placeholder="Enter the reference image URL",
        description="Image URL:",
        layout=widgets.Layout(width="600px")
    )
    story_prompts_textarea = widgets.Textarea(
        value="",
        placeholder="Enter story prompts separated by line breaks",
        description="Story Prompts:",
        layout=widgets.Layout(width="600px", height="100px")
    )

    layout = widgets.Layout(margin="20px")
    button = widgets.Button(description="Generate Comic!", button_style="primary")
    model_dropdown.layout.width = "70%"
    initial_prompt_text.layout.width = "100%"
    reference_image_url_text.layout.width = "100%"
    story_prompts_textarea.layout.width = "100%"
    button.layout.margin = "0 0 0 400px"
    top_row = widgets.HBox([model_dropdown])
    middle_row = widgets.HBox([initial_prompt_text])
    middle_row_2 = widgets.HBox([reference_image_url_text])
    bottom_row = widgets.HBox([story_prompts_textarea])
    left_box = widgets.VBox([top_row, middle_row, middle_row_2, bottom_row])
    user_input_widgets = widgets.HBox([left_box], layout=layout)
    display(user_input_widgets)
    display(button)
    display(out)

    def on_submit(button):
        with out:
            clear_output(wait=True)
            button.button_style = "warning"
            # Everything runs inside try/finally so the button returns to "primary" on every
            # exit path, including the early returns for invalid input.
            try:
                print("\nOnce generated, comic will be saved to `./comic_output` dir, please wait...")
                initial_prompt = initial_prompt_text.value.strip()
                reference_image_url = reference_image_url_text.value.strip()
                # Blank lines (e.g. a trailing newline) are not counted as story prompts.
                story_prompts = [
                    line.strip()
                    for line in story_prompts_textarea.value.splitlines()
                    if line.strip()
                ]
                if not validators.url(reference_image_url):
                    print(f'Invalid reference image URL: {reference_image_url}')
                    return
                if not initial_prompt or len(story_prompts) != 4:
                    print("Please provide a valid initial prompt, reference image URL, and exactly four story prompts.")
                    return

                # Inputs are validated before the expensive model load; the selected model is
                # the one actually loaded (and cached under its own key).
                model_id = model_dropdown.value
                model_key = (model_id, "xpu")
                if model_key not in model_cache:
                    # NOTE(review): sdxl-turbo is an SDXL checkpoint; confirm it loads through
                    # StableDiffusionImg2ImgPipeline. A load failure is reported by the handler below.
                    model_cache[model_key] = ComicGenerationModel(
                        text_model_id=model_id, img_model_id=model_id, device="xpu"
                    )
                model = model_cache[model_key]

                model.generate_comic(
                    initial_prompt,
                    reference_image_url,
                    story_prompts,
                    save_path=output_dir,
                )
                clear_output(wait=True)
                display_generated_images(output_dir=output_dir)
            except KeyboardInterrupt:
                print("\nUser interrupted comic generation...")
            except Exception as e:
                print(f"An error occurred: {e}")
            finally:
                button.button_style = "primary"

    button.on_click(on_submit)

def display_generated_images(output_dir="comic_output"):
    """
    Show every PNG/JPEG image in ``output_dir`` in a near-square grid, oldest first.

    Images are ordered by file modification time (ties broken by name), so comic panels appear
    in the order they were generated. If the directory is missing or holds no images, a message
    is printed and nothing is plotted.

    Args:
        output_dir (str, optional): Directory to read images from. Defaults to "comic_output".
    """
    if not os.path.isdir(output_dir):
        print(f"No output directory found at {output_dir!r}.")
        return

    image_files = sorted(
        (f for f in os.listdir(output_dir) if f.lower().endswith((".png", ".jpg", ".jpeg"))),
        key=lambda name: (os.path.getmtime(os.path.join(output_dir, name)), name),
    )
    num_images = len(image_files)
    if num_images == 0:
        # Guard: the grid computation below would divide by zero.
        print(f"No images found in {output_dir!r}.")
        return

    num_columns = int(np.ceil(np.sqrt(num_images)))
    num_rows = int(np.ceil(num_images / num_columns))
    # squeeze=False always yields a 2-D array of Axes, so 1x1 / 1xN / Nx1 grids need no special-casing.
    fig, axs = plt.subplots(num_rows, num_columns, figsize=(10, 10), squeeze=False)
    for ax in axs.ravel():
        ax.axis("off")  # Hide axes, including on unused grid cells
    for ax, image_file in zip(axs.ravel(), image_files):
        ax.imshow(mp_img.imread(os.path.join(output_dir, image_file)))
    plt.tight_layout()
    print("\nGenerated images...:")
    plt.show()

In [6]:
# Run all cells before running this section and wait a few seconds for UI to load.
# Renders the model dropdown, prompt fields and "Generate Comic!" button; results are written
# to ./comic_output and displayed below the form.
generate_comic_gui()

HBox(children=(VBox(children=(HBox(children=(Dropdown(description='Model:', layout=Layout(width='70%'), option…

Button(button_style='primary', description='Generate Comic!', layout=Layout(margin='0 0 0 400px'), style=Butto…

Output()