In [3]:
import sys
!{sys.executable} -m pip install openai --user



In [11]:
# --- Standard library ---
import logging
import os
import random
import time
import warnings
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Tuple

# Installed ahead of the third-party imports, as in the original cell order,
# so the "ignore" filter is already active while those packages are imported.
warnings.filterwarnings("ignore")

# --- Third-party ---
import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # Used for optimizing PyTorch models
import ipywidgets as widgets
import openai as OpenAI
from diffusers import (
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
)
from IPython.display import Image as IPImage
from IPython.display import clear_output
from PIL import Image

# Raise the root logger threshold so only CRITICAL records are emitted.
logging.getLogger().setLevel(logging.CRITICAL)

# --- Configuration ---
# Read the API key from the environment instead of pasting it into the notebook.
# Falls back to the previous empty-string value when the variable is not set.
OpenAI.api_key = os.environ.get("OPENAI_API_KEY", "")

# Directory checked for (and used to store) local copies of the diffusion models.
model_dir = "/home/common/data/Big_Data/GenAI"

In [5]:
class Text2ImgModel:
    """
    Text2ImgModel is a class for generating images based on text prompts using a pretrained model.

    Attributes:
    - device: The device to run the model on. Default to "xpu" - Intel dGPUs.
    - pipeline: The loaded model pipeline.
    - data_type: The data type to use in the model.
    """

    def __init__(
        self,
        model_id_or_path: str,
        device: str = "xpu",
        torch_dtype: torch.dtype = torch.bfloat16,
        optimize: bool = True,
        enable_scheduler: bool = False,
        warmup: bool = False,
    ) -> None:
        """
        The initializer for Text2ImgModel class.

        Parameters:
        - model_id_or_path: The identifier or path of the pretrained model.
        - device: The device to run the model on. Default is "xpu".
        - torch_dtype: The data type to use in the model. Default is torch.bfloat16.
        - optimize: Whether to optimize the model after loading. Default is True.
        - enable_scheduler: Whether to replace the pipeline's scheduler with a
          DPMSolverMultistepScheduler built from the existing scheduler config.
          Default is False.
        - warmup: Whether to generate one throwaway image right after loading.
          Default is False.
        """

        self.device = device
        self.pipeline = self._load_pipeline(model_id_or_path, torch_dtype, enable_scheduler)
        self.data_type = torch_dtype
        if optimize:
            self.optimize_pipeline()
        if warmup:
            self.warmup_model()

    def _load_pipeline(
        self,
        model_id_or_path: str,
        torch_dtype: torch.dtype,
        enable_scheduler: bool,
    ) -> "DiffusionPipeline":
        """
        Loads the pretrained model and prepares it for inference.

        A copy under the module-level `model_dir` is used when it exists; otherwise
        `model_id_or_path` is handed to `from_pretrained` as-is and the result is
        saved under `model_dir` on a best-effort basis.

        Parameters:
        - model_id_or_path: The identifier or path of the pretrained model.
        - torch_dtype: The data type to use in the model.
        - enable_scheduler: Whether to swap in a DPMSolverMultistepScheduler.

        Returns:
        - pipeline: The loaded model pipeline, moved to `self.device`.
        """

        print("Creating a new story...")
        model_path = Path(f"{model_dir}/{model_id_or_path}")
        has_local_copy = model_path.exists()

        pipeline = DiffusionPipeline.from_pretrained(
            model_path if has_local_copy else model_id_or_path,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            variant="fp16",
        )
        if enable_scheduler:
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                pipeline.scheduler.config
            )
        if not has_local_copy:
            # Saving is optional: a failure here must not prevent the model from being used.
            try:
                print(f"Attempting to save the model to {model_path}...")
                pipeline.save_pretrained(f"{model_path}")
                print("Model saved.")
            except Exception as e:
                print(f"An error occurred while saving the model: {e}. Proceeding without saving.")
        pipeline = pipeline.to(self.device)
        return pipeline

    def _optimize_pipeline(self, pipeline: "DiffusionPipeline") -> "DiffusionPipeline":
        """
        Optimizes the model for inference using ipex.

        Every attribute of the pipeline that is an nn.Module is switched to eval
        mode and replaced by the result of `ipex.optimize`, using the dtype of
        `pipeline.text_encoder`.

        Parameters:
        - pipeline: The model pipeline to be optimized.

        Returns:
        - pipeline: The optimized model pipeline.
        """

        for attr in dir(pipeline):
            try:
                component = getattr(pipeline, attr)
                if isinstance(component, nn.Module):
                    setattr(
                        pipeline,
                        attr,
                        ipex.optimize(
                            component.eval(),
                            dtype=pipeline.text_encoder.dtype,
                            inplace=True,
                        ),
                    )
            except AttributeError:
                # Best effort: attributes that cannot be read or replaced are left untouched.
                pass
        return pipeline

    def warmup_model(self):
        """
        Warms up the model by generating a sample image.

        The image is discarded, so only a handful of inference steps are run.
        """
        print("Setting up model...")
        start_time = time.time()
        # Only pass arguments that generate_images() actually accepts; the previous
        # `num_images` / `save_path` keywords raised TypeError on every warm-up.
        self.generate_images(
            prompt="A beautiful sunset over the mountains",
            num_inference_steps=5,
        )
        print(
            "Model is set up and ready! Warm-up completed in {:.2f} seconds.".format(
                time.time() - start_time
            )
        )

    def optimize_pipeline(self) -> None:
        """
        Optimizes the current model pipeline.
        """

        self.pipeline = self._optimize_pipeline(self.pipeline)

    def generate_images(
        self,
        prompt: str,
        num_inference_steps: int = 200,
    ) -> "Image.Image":
        """
        Generates one image for the given prompt.

        Parameters:
        - prompt: The text prompt passed to the pipeline.
        - num_inference_steps: Number of inference steps passed to the pipeline. Default is 200.

        Returns:
        - image: The first image returned by the pipeline.
        """
        # Autocast is switched off only for float32.
        # NOTE(review): the autocast context is XPU-specific regardless of self.device.
        use_autocast = self.data_type != torch.float32
        with torch.xpu.amp.autocast(enabled=use_autocast, dtype=self.data_type):
            image = self.pipeline(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
            ).images[0]

        return image


In [6]:
def get_completion(prompt, text, model = "gpt-3.5-turbo"):
    """
    Sends `prompt` as a single user message to the OpenAI chat API and returns the reply.

    Parameters:
    - prompt: The full instruction text sent to the model.
    - text: Fallback value returned unchanged when the request fails.
    - model: Name of the chat model to use. Default is "gpt-3.5-turbo".

    Returns:
    - The content of the first choice of the response, or `text` if the request
      raised an error.
    """
    messages = [{"role": "user", "content": prompt}]
    try:
        response = OpenAI.chat.completions.create(model = model, messages = messages, temperature = 0,)
    except Exception as exc:
        # `Exception` (not a bare except) so that KeyboardInterrupt / SystemExit
        # still propagate and the user can stop a hanging request.
        print("\n\nOpenAI LLM is not responding. Proper story can't be generated...")
        print(f"Reason: {type(exc).__name__}: {exc}")
        return text
    return response.choices[0].message.content


def get_prompt(input):
    """
    Asks the LLM to turn one part of a story into a short image-generation prompt.

    Parameters:
    - input: The story paragraph to describe. It is embedded in the instructions
      between triple backticks and is also used as the fallback text.

    Returns:
    - Whatever get_completion returns for the assembled instructions.
    """
    instructions = f"""
    A part of a story will be provided to you and you have to generate a simple prompt that describes the \
    scenario in that part of the story such that the part of the story can be explained in an image generated by the prompt generated by you.
    
    Here are some rules u have to follow while generating the prompts: 
    1. The prompt be strictly less than 70 words.
    2. Don't include special characters other than comma and hyphen and dot.
    3. You just have to describe the scenario not write the whole story.
    4. Always include "A colored cartoon type sketch of," at the start of every prompt.
    5. The very important one, write in crisp and very simple english, don't use complicated words.
    6. Separate the different traits of the scenario with commas.
    7. If you can't understand the story or text, just write whatever you think the situation could be in the text.
    
    Here are some examples on how to generate the prompt:
    
    Example story paragraph:
    Once upon a time, in a not-so-distant future, there lived a man named Alex. Alex was an adventurous soul who dreamed of exploring the great \
    unknown: outer space. From a young age, he would gaze up at the stars with wonder, imagining what it would be like to journey among them.
    Expected text from you is:
    A colored cartoon type sketch of, a man looking up in the sky at night, sky has stars & moon.
    
    Example story paragraph:
    As the days turned into weeks, Maya forged friendships with the creatures of the jungle. She shared moments of laughter with mischievous monkeys, \
    and learned the ancient wisdom of wise old elephants. Together, they explored hidden caves and winding rivers, each new discovery fueling Maya's sense of wonder.
    Expected text from you is:
    A colored cartoon type sketch of, A girl laughing with monkeys, old elephants, hidden caves, winding rivers.
    
    Example story paragraph:
    In the heart of a bustling metropolis, where skyscrapers kissed the sky and streets hummed with the \
    rhythm of life, there existed a city like no other. Its streets were a labyrinth of winding alleys and bustling boulevards, \
    lined with towering buildings that reached for the clouds.
    Expected prompt generated from you is:
    A colored cartoon type sketch of, a metropolitan city, high skycrapers, streets, sky with clouds.
    
    Example story paragraph:
    what's up
    Since the paragraph is vague to understand, you can assume that a person is saying what's up to another person, for this the expected \
    text generated by you is:
    A colored cartoon type sketch of, two person speaking.
    
    Further rules:
    Please don't generate more than 70 words, this is a must.
    Please note that all the above examples the generated prompts were less than 20 words, you must also generate the prompts strictly less than 70 words.
    
    I just want the prompt from you not the explanation of why you generated that prompt.
    Now the actual story paragraph for which the prompt is to be generated is the text delimited by triple backticks
    Text:
    ```{input}```
    """
    # The story part doubles as the fallback returned when the LLM call fails.
    return get_completion(prompt=instructions, text=input)

def get_story_from_plot(plot):
    """
    Asks the LLM to write a children's story from a short plot.

    Parameters:
    - plot: The plot text. It is embedded in the request between triple backticks
      and is also used as the fallback text.

    Returns:
    - Whatever get_completion returns for the assembled request.
    """
    request = f"""
    Write a creative story for a class of small childrens based on the plot provided in text delimited by triple backticks in \
    3000 words.
    Text:
    ```{plot}```
    """
    # The plot doubles as the fallback returned when the LLM call fails.
    return get_completion(prompt=request, text=plot)

In [7]:
def display_plot(text, image):
    """
    Shows an image in a 5x5 inch figure without axes, then prints its text.

    Parameters:
    - text: The text printed after the figure has been shown.
    - image: The image handed to matplotlib's imshow.
    """
    _, axes = plt.subplots(figsize=(5, 5))
    axes.imshow(image)
    axes.axis('off')
    plt.show()
    print(text)

In [8]:
import re

def split_paragraphs(story_text, max_words_per_paragraph=150):
    """
    Groups the sentences of a story into paragraphs of limited length.

    Sentences are split at a whitespace character that follows "." or "?",
    except where the preceding characters look like an abbreviation
    (e.g. "U.S." or "Mr."). Sentences are then packed, in order, into paragraphs
    of at most `max_words_per_paragraph` words. Sentences are never cut: one that
    is longer than the limit becomes a paragraph of its own.

    Parameters:
    - story_text: The story to split. None or blank text is replaced by the
      placeholder "a man" so that callers always get something to illustrate.
    - max_words_per_paragraph: Upper bound on words per paragraph. Default is 150.

    Returns:
    - paragraphs: A list of non-empty paragraph strings without surrounding whitespace.
    """
    if story_text is None or not story_text.strip():
        story_text = "a man"

    sentence_boundary = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s'

    paragraphs = []
    pending = []        # sentences of the paragraph currently being built
    pending_words = 0   # word count of `pending`

    for raw_sentence in re.split(sentence_boundary, story_text.strip()):
        sentence = raw_sentence.strip()
        if not sentence:
            continue
        sentence_words = len(sentence.split())

        # Close the current paragraph before it would exceed the limit. The
        # `pending` check keeps an over-long first sentence from producing an
        # empty paragraph.
        if pending and pending_words + sentence_words > max_words_per_paragraph:
            paragraphs.append(' '.join(pending))
            pending = []
            pending_words = 0

        pending.append(sentence)
        pending_words += sentence_words

    if pending:
        paragraphs.append(' '.join(pending))

    return paragraphs

In [9]:
# Loaded models keyed by (model_id, device), so repeated clicks reuse them.
model_cache = {}

def generate_story():
    """
    Builds and displays the story-generator UI.

    The UI consists of a model dropdown, a text field for the plot, a
    "Generate Story" button, a "Clear Story" button and an output area.
    Clicking "Generate Story" asks the LLM for a story, splits it into parts and
    shows one generated image per part in the output area.
    """
    device = "xpu"
    out = widgets.Output()
    model_ids = [
        "stabilityai/stable-diffusion-2-1",
        "CompVis/stable-diffusion-v1-4",
    ]
    model_dropdown = widgets.Dropdown(
        options = model_ids,
        value = model_ids[0],
        description = "Select Model:",
        layout = widgets.Layout(width = "50%"),
    )
    prompt_text = widgets.Text(
        value = "",
        placeholder = "Enter the plot",
        description = "Story plot:",
        layout = widgets.Layout(width = "600px"),
    )
    button1 = widgets.Button(
        description = "Generate Story",
        button_style = "primary",
        layout = widgets.Layout(margin = "0 0 0 100px", width = "150px"),
    )
    button2 = widgets.Button(
        description = "Clear Story",
        button_style = "primary",
        layout = widgets.Layout(margin = "0 0 0 300px", width = "120px"),
    )

    layout = widgets.Layout(margin = "10px")
    top_row = widgets.HBox([model_dropdown])
    bottom_row = widgets.HBox([prompt_text])
    top_box = widgets.VBox([top_row, bottom_row])
    user_input_widgets = widgets.HBox([top_box], layout = layout)
    bottom_box = widgets.HBox([button1, button2], layout = layout)
    display(user_input_widgets)
    display(bottom_box)
    display(out)

    def generate_image(button):
        """Click handler: generates the story and one image per story part."""
        # Everything runs inside `out`, like end_story, so the story text, the
        # images and any error message all land in the same output area.
        with out:
            clear_output(wait = True)
            if not prompt_text.value.strip():
                print("Please enter a story plot first.")
                return

            button.button_style = "warning"
            try:
                print("Creating a new story...")
                story = get_story_from_plot(prompt_text.value)
                print(f"Story : {story}")

                # The model does not depend on the story part, so it is looked up
                # once per click; loading errors are reported by the handlers below.
                model_id = model_dropdown.value
                model_key = (model_id, device)
                if model_key not in model_cache:
                    model_cache[model_key] = Text2ImgModel(model_id, device = device)
                model = model_cache[model_key]

                for i, parts in enumerate(split_paragraphs(story)):
                    print(f"Part {i} : {parts}")
                    # A failure in one part is reported and the remaining parts
                    # are still attempted.
                    try:
                        prompt = get_prompt(parts) or " "
                        image = model.generate_images(
                            prompt,
                            num_inference_steps = 200,
                        )
                        display_plot(parts, image)
                    except Exception as e:
                        print(f"An error occurred: {e}")
            except KeyboardInterrupt:
                # An interrupt stops the whole story, not just the current part.
                print("\nUser interrupted image generation...")
            except Exception as e:
                print(f"An error occurred: {e}")
            finally:
                button.button_style = "primary"

    def end_story(button):
        """Click handler: clears the output area."""
        with out:
            clear_output(wait = True)
            print("Creating a new story...")

    button1.on_click(generate_image)
    button2.on_click(end_story)


In [12]:
# Display the story-generator widgets and register the button click handlers.
generate_story()

HBox(children=(VBox(children=(HBox(children=(Dropdown(description='Select Model:', layout=Layout(width='50%'),…

HBox(children=(Button(button_style='primary', description='Generate Story', layout=Layout(margin='0 0 0 100px'…

Output()