# Plot summaries generation
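A complete pipeline: image captioning, prompt construction with ChatGPT, plot-summary generation, and a small Tkinter UI.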

## Image captioning

In [1]:
import torch
from transformers import GitForCausalLM, AutoProcessor, AutoConfig

# Base GIT checkpoint: supplies both the processor and the architecture config.
base_checkpoint = 'microsoft/git-large-r-coco'
processor = AutoProcessor.from_pretrained(base_checkpoint)

# Build two separate models from the same config. Only the config is used
# here; no checkpoint weights are loaded by these constructor calls.
config = AutoConfig.from_pretrained(base_checkpoint)
model_greedy, model_beam = (GitForCausalLM(config) for _ in range(2))

In [2]:
# Select the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Load the fine-tuned weights into the two captioning models.
checkpoint_greedy = 'childrensimages-caption-20231108'
checkpoint_beam = 'childrensimages-caption-20231109'

# map_location remaps every stored tensor onto the selected device, so a
# checkpoint saved on a GPU (or on a GPU index this machine does not have)
# loads on both CPU-only and CUDA hosts without a per-device branch.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
for _model, _checkpoint in (
    (model_greedy, checkpoint_greedy),
    (model_beam, checkpoint_beam),
):
    _model.load_state_dict(torch.load(_checkpoint, map_location=device))

In [4]:
# Move both models to the selected device and switch them to inference mode.
# They were constructed directly with GitForCausalLM(config), which leaves a
# module in its default training mode, so dropout would otherwise stay active
# while generating captions.
model_greedy.to(device).eval()
model_beam.to(device).eval()

GitForCausalLM(
  (git): GitModel(
    (embeddings): GitEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(1024, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (image_encoder): GitVisionModel(
      (vision_model): GitVisionTransformer(
        (embeddings): GitVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
          (position_embedding): Embedding(257, 1024)
        )
        (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): GitVisionEncoder(
          (layers): ModuleList(
            (0-23): 24 x GitVisionEncoderLayer(
              (self_attn): GitVisionAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
      

In [5]:
def clean_sentence(sentence):
    """Strip caption boilerplate such as 'drawing of ' from a generated caption.

    Args:
        sentence: Caption text to clean.

    Returns:
        The text with every occurrence of 'drawing of ', 'picture of ' and
        'image of ' removed, in that order.
    """
    # The previous chain also escaped every apostrophe and immediately
    # un-escaped it again; that pair always restored the input, so it is
    # dropped here without changing the result.
    for boilerplate in ('drawing of ', 'picture of ', 'image of '):
        sentence = sentence.replace(boilerplate, '')
    return sentence

In [6]:
from PIL import Image

# Function to generate caption
def _caption_with(model, pixel_values):
    """Generate one cleaned caption for `pixel_values` using `model`."""
    generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return clean_sentence(caption)


def generate_captions(img_path):
    """Caption the image at `img_path` with both fine-tuned models.

    Args:
        img_path: Path of the image file to open.

    Returns:
        A tuple ``(resized_image, captions)`` where ``resized_image`` is the
        224x224 PIL image fed to the models and ``captions`` is
        ``[greedy_model_caption, beam_model_caption]``.
    """
    print('Reading image...')
    # Open inside a context manager so the file handle is released as soon as
    # the resized copy (which holds its own pixel data) has been produced.
    with Image.open(img_path) as image:
        resized_image = image.resize((224, 224))

    # Prepare the image for the models.
    inputs = processor(images=resized_image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values

    print('Generating caption with greedy model...')
    generated_caption_greedy = _caption_with(model_greedy, pixel_values)

    # NOTE(review): despite the label, this call uses the same generate()
    # arguments as the greedy one (no num_beams is passed); only the model
    # differs. Confirm whether beam search was intended here.
    print('Generating caption with beam model...')
    generated_caption_beam = _caption_with(model_beam, pixel_values)

    return resized_image, [generated_caption_greedy, generated_caption_beam]

## Plot summary generation

In [7]:
import openai
import os

# Read the OpenAI API key from the environment so the secret never lives in
# the notebook. The key that used to be hardcoded here has been exposed and
# must be revoked.
openai.api_key = os.environ.get('OPENAI_API_KEY')
if not openai.api_key:
    raise RuntimeError('Set the OPENAI_API_KEY environment variable before running this notebook.')

In [8]:
# Function to call ChatGPT with instructions
def call_chatgpt(instructions):
    """Send `instructions` to the chat model as a single user message.

    Every call builds a new message list (system prompt + one user turn), so
    no conversation history is carried between calls.

    Args:
        instructions: Non-empty prompt text for the user turn.

    Returns:
        The text content of the first choice in the model's reply.

    Raises:
        ValueError: If `instructions` is empty or None. Previously this case
            fell through to an unbound local variable.
    """
    if not instructions:
        raise ValueError('instructions must be a non-empty string')

    messages = [
        {"role": "system", "content": "You are an intelligent assistant."},
        {"role": "user", "content": instructions},
    ]
    chat = openai.ChatCompletion.create(
        model="gpt-3.5-turbo", messages=messages
    )
    return chat.choices[0].message.content

In [9]:
# Function to format caption to generate a command
def format_caption(idx, instructions):
    """Ask the chat model to turn a caption prompt into a story command.

    If the first reply still contains one of the template placeholders, the
    same request is sent once more and that second reply is returned as-is.

    Args:
        idx: Caption number, used only in the progress message.
        instructions: Full prompt passed to `call_chatgpt`.

    Returns:
        The reply text from `call_chatgpt`.
    """
    print('Formatting caption {0} as an input command...'.format(idx))
    placeholders = ('{character 1}', '{character 2}', '{location}', '{theme}')
    command = call_chatgpt(instructions)
    if any(token in command for token in placeholders):
        command = call_chatgpt(instructions)
    return command

In [10]:
# Function to generate plot summary creation command
def generate_plot_input(base_instructions, captions=None):
    """Build one story command per caption.

    Args:
        base_instructions: Prompt prefix; each caption is appended directly to it.
        captions: Iterable of caption strings. ``None`` (the default) is
            treated as no captions; a ``None`` default avoids sharing one
            mutable list object across calls.

    Returns:
        A list with the reply from `format_caption` for each caption, in order.
    """
    print('Starting the plot input commands generation...')
    return [
        format_caption(position, base_instructions + caption)
        for position, caption in enumerate(captions or [], start=1)
    ]

In [11]:
# Function to generate plot summary
def generate_plot_summary(image_path, language='english'):
    """Caption an image, derive story commands, and expand them into plot summaries.

    Args:
        image_path: Path of the image, forwarded to `generate_captions`.
        language: Name of the output language appended to each summary prompt.

    Returns:
        A tuple ``(image, summaries)``: the image returned by
        `generate_captions` and one summary string per generated command.
    """
    # Captions arrive as [greedy-model caption, beam-model caption].
    image, captions = generate_captions(image_path)

    base_instructions = 'If existing, identify subjects, location and main theme in the given sentence. Subjects will be called \"characters\" as we will generate a children\'s story based on the sentence content. If no subjects, location or main theme are provided, based on the sentence generate them, remember to only generate the ones that are missing in given sentence; these fields cannot be blank and all subjects or characters, locations and main themes should be suitable for children\'s stories. With previous information obtained from given sentence, complete the following command by replacing the curly brackets with characters (1, 2, 3, etc.), location and main theme, modifying the command to match as many characters are identified: \"Write a plot summary of a children\'s story featuring {character 1} and {character 2} in {location} with the main theme {theme}.\" Some sentences will mention a caption or title, do not include the caption in the command, use it only as a reference to obtain the main theme. Example of a sentence: a boy and mom with a dog. Generated command: \"Write a plot summary of a children\'s story featuring a boy and his mom and a dog in an enchanted mansion with the main theme Family and Mistery\". Please show only the generated command and follow the format exactly. Sentence to analyze is at the end of this command.'

    # One command per caption.
    commands = generate_plot_input(base_instructions, captions)
    print('Generated inputs:',commands)

    # Every summary prompt gets the same safety and language suffix.
    plot_security = ' As it is a story for children, it does not use discriminatory, offensive, racist, religious language, or any topic that incites violence or hatred.'
    plot_language = ' Write the output in ' + language
    prompt_suffix = plot_security + plot_language

    summaries = []
    for number, command in enumerate(commands, start=1):
        print('Generating plot summary for input command {0}...'.format(number))
        summaries.append(call_chatgpt(command + prompt_suffix))

    print(summaries)

    return image, summaries

## Show UI

In [13]:
import tkinter as tk
from tkinter import filedialog, scrolledtext
from PIL import Image, ImageTk

class LanguageManager:
    """Tracks the active UI language and looks up its translated labels."""

    def __init__(self):
        """Start in English with the English and Spanish label tables."""
        english_labels = {
            "selected_file": "Selected File:",
            "browse": "Browse",
            "generate_summary": "Generate Summary",
            "loading": "Loading...",
        }
        spanish_labels = {
            "selected_file": "Archivo seleccionado:",
            "browse": "Examinar",
            "generate_summary": "Generar Resumen",
            "loading": "Cargando...",
        }
        self.language = "english"
        self.translations = {"english": english_labels, "spanish": spanish_labels}

    def set_language(self, language):
        """Make `language` the active language; the value is not validated."""
        self.language = language

    def get_translation(self, key):
        """Return the label stored under `key` for the active language.

        Raises:
            KeyError: If the active language or `key` has no entry.
        """
        active_labels = self.translations[self.language]
        return active_labels[key]

# Shared language state read by the UI callbacks and widget labels below.
language_manager = LanguageManager()

# Language change function
def change_language(lang):
    """Switch the active language and reload the current file in that language.

    Args:
        lang: Language name understood by `language_manager`.
    """
    language_manager.set_language(lang)

    # selected_path only exists after a callback has assigned it; look it up
    # defensively so switching language before choosing a file is a no-op
    # instead of a NameError.
    path = globals().get('selected_path')
    if path:
        # load_data takes the language as its second argument; the previous
        # one-argument call raised TypeError.
        load_data(path, lang)

def open_file_dialog():
    """Let the user pick a file and show its path in the entry field."""
    file_path = filedialog.askopenfilename(title=language_manager.get_translation("browse"))
    # The dialog returns an empty value when it is cancelled; keep the
    # previously chosen path instead of wiping the entry.
    if file_path:
        entry_var.set(file_path)

def close_and_save() -> None:
    """Store the entry field's current text in `selected_path`, then close the window."""
    # selected_path is a module-level global shared with the other callbacks.
    global selected_path
    selected_path = entry_var.get()
    root.destroy()

def generate_plot_summary_ui():
    """Handle the "Generate Summary" button: lock the UI and schedule the load.

    Reads the path from the entry field into the global `selected_path`. When
    the field is empty nothing is scheduled and the controls stay enabled.
    """
    global selected_path
    selected_path = entry_var.get()

    # Without a file the load can only fail when opening the image, so bail
    # out before disabling the controls.
    if not selected_path.strip():
        return

    # Disable buttons and show the loading text
    for button in (browse_button, generate_button, english_button, spanish_button):
        button.config(state=tk.DISABLED)
    loading_label.config(text=language_manager.get_translation("loading"), font=("Helvetica", 12), fg="blue")

    # Schedule the loading logic after a 100 ms delay
    root.after(100, lambda: load_data(selected_path, language_manager.language))

def load_data(selected_path, language=None):
    """Generate the plot summaries for a file and show them in the window.

    Args:
        selected_path: Path of the image to process.
        language: Output language; falls back to the manager's active
            language when empty or None.

    The buttons are re-enabled and the loading text cleared even when
    generation fails, so an error never leaves the window locked. The
    exception itself still propagates.
    """
    # Resolve the fallback once so the same value is stored and passed on.
    language = language or language_manager.language
    language_manager.set_language(language)

    try:
        pil_image, strings = generate_plot_summary(selected_path, language)

        # Resize the image to 100 x 100 pixels using LANCZOS
        pil_image = pil_image.resize((100, 100), Image.LANCZOS)

        # Display the resized PIL image in the Tkinter window
        tk_image = ImageTk.PhotoImage(pil_image)
        image_label.config(image=tk_image)
        # Keep a Python reference to the PhotoImage on the widget so it is not
        # garbage-collected while it is being displayed.
        image_label.image = tk_image

        # Display the two summaries; the text boxes are read-only, so each is
        # unlocked only for the duration of its update.
        for widget, summary in ((string1_text, strings[0]), (string2_text, strings[1])):
            widget.config(state=tk.NORMAL)
            widget.delete("1.0", tk.END)
            widget.insert(tk.END, summary)
            widget.config(state=tk.DISABLED)
    finally:
        # Enable buttons and hide the loading text whether or not loading succeeded
        for button in (browse_button, generate_button, english_button, spanish_button):
            button.config(state=tk.NORMAL)
        loading_label.config(text="")

## Placeholder for the generate_plot_summary function
#def generate_plot_summary(file_path, language):
#    # Replace this with your actual implementation
#    # This function should return a PIL Image and an array of two strings
#    image = Image.open(file_path)
#    string1 = f"First String ({language}). This is a long piece of text that needs to be wrapped."
#    string2 = f"Second String ({language}). Another long piece of text that should be wrapped as well."
#    return image, [string1, string2]

# Create the main window
root = tk.Tk()
root.title("File Explorer")

# Create elements for the interface
label = tk.Label(root, text=language_manager.get_translation("selected_file"))
entry_var = tk.StringVar()
entry = tk.Entry(root, textvariable=entry_var, width=40)
browse_button = tk.Button(root, text=language_manager.get_translation("browse"), command=open_file_dialog)
generate_button = tk.Button(root, text=language_manager.get_translation("generate_summary"), command=generate_plot_summary_ui)
loading_label = tk.Label(root, text="", font=("Helvetica", 12), fg="blue")
image_label = tk.Label(root)
string1_text = scrolledtext.ScrolledText(root, wrap=tk.WORD, width=40, height=5, state=tk.DISABLED)
string2_text = scrolledtext.ScrolledText(root, wrap=tk.WORD, width=40, height=5, state=tk.DISABLED)


def _switch_language(lang):
    """Select `lang` and refresh every widget whose text is translated.

    The widgets receive their text once at creation, so changing the language
    without re-applying the labels would leave the window in the old language.
    """
    language_manager.set_language(lang)
    label.config(text=language_manager.get_translation("selected_file"))
    browse_button.config(text=language_manager.get_translation("browse"))
    generate_button.config(text=language_manager.get_translation("generate_summary"))


# Language buttons
english_button = tk.Button(root, text="English", command=lambda: _switch_language("english"))
spanish_button = tk.Button(root, text="Español", command=lambda: _switch_language("spanish"))

# Position the language buttons
english_button.grid(row=2, column=1, pady=10, padx=10, sticky=tk.W)
spanish_button.grid(row=2, column=2, pady=10, padx=10, sticky=tk.W)

# Position the elements in the window using grid
label.grid(row=0, column=0, padx=10, pady=10, sticky=tk.W)
entry.grid(row=1, column=0, padx=10, pady=10, sticky=tk.W)
browse_button.grid(row=2, column=0, padx=10, pady=10, sticky=tk.W)
generate_button.grid(row=3, column=0, padx=10, pady=10, sticky=tk.W, ipadx=10)
loading_label.grid(row=4, column=0, padx=10, pady=10, sticky=tk.W)
image_label.grid(row=5, column=0, padx=10, pady=10, rowspan=2, sticky=tk.W)
string1_text.grid(row=7, column=0, padx=10, pady=10, sticky=tk.W)
string2_text.grid(row=8, column=0, padx=10, pady=10, sticky=tk.W)

# Start the event loop
root.mainloop()


Reading image...
Generating caption with greedy model...
Generating caption with beam model...
Starting the plot input commands generation...
Formatting caption 1 as an input command...
Formatting caption 2 as an input command...
Generated inputs: ["Write a plot summary of a children's story featuring planets with the main theme Space.", "Write a plot summary of a children's story featuring a planet and its inhabitants in the vastness of space with the main theme Exploration and Adventure."]
Generating plot summary for input command 1...
Generating plot summary for input command 2...
['Título: El Viaje Espacial de los Planetas\n\nResumen de la trama:\nEl Viaje Espacial de los Planetas es una historia llena de aventuras que se desarrolla en el misterioso y fascinante universo del espacio. Los protagonistas son ocho planetas juguetones llamados Mercurio, Venus, Tierra, Marte, Júpiter, Saturno, Urano y Neptuno.\n\nUn día, los planetas deciden embarcarse en un emocionante viaje espacial pa