# Maquinito 17 - Safari Sonoro

# Carga de librerías, importación de datos, preparación del sistema

In [None]:
# Install diffusers and transformers straight from their GitHub repositories
# (plus accelerate), then clear pip's output from the cell.
# NOTE(review): installing from the main branches is presumably what provided
# AudioLDM2 / Kosmos-2 support when this was written; pinning released
# versions would make the notebook reproducible — confirm and pin.
# HTML is imported here because a later cell renders the caption with it.
from IPython.display import clear_output, HTML
!pip install --quiet --upgrade git+https://github.com/huggingface/diffusers.git git+https://github.com/huggingface/transformers.git accelerate
clear_output()

In [None]:
import matplotlib.pyplot as plt
import requests
import seaborn as sns
import torch

from diffusers import AudioLDM2Pipeline, DPMSolverMultistepScheduler
from google.colab import files
from IPython.display import Audio
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, Kosmos2ForConditionalGeneration, pipeline

%config InlineBackend.figure_format = 'retina'

In [None]:
# Hugging Face Hub checkpoints used by this notebook.
model_text_name = "microsoft/kosmos-2-patch14-224"
model_audio_name = "cvssp/audioldm2"

# Kosmos-2 (image -> grounded caption). No `.to(...)` call is made, so the
# model stays on the device from_pretrained puts it on by default.
processor_text = AutoProcessor.from_pretrained(model_text_name)
model_text = Kosmos2ForConditionalGeneration.from_pretrained(model_text_name)
clear_output()

# AudioLDM2 (caption -> audio), loaded in float16 and moved to the GPU.
# Its default scheduler is replaced by a DPMSolverMultistepScheduler that is
# built from the original scheduler's configuration.
pipe_audio = AudioLDM2Pipeline.from_pretrained(
    model_audio_name,
    torch_dtype=torch.float16,
)
pipe_audio.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe_audio.scheduler.config
)
pipe_audio.to("cuda")
clear_output()

In [None]:
def _load_label_font(size=24):
    """Return the label font, falling back to Pillow's built-in font.

    "LiberationSans-Regular.ttf" exists on Colab but not on every system;
    ImageFont.truetype raises OSError when it cannot find the file, which
    would otherwise abort the whole visualisation.
    """
    try:
        return ImageFont.truetype("LiberationSans-Regular.ttf", size)
    except OSError:
        return ImageFont.load_default()


def _draw_entity_boxes(original_image, entities, colors, line_thickness, font):
    """Return a copy of *original_image* with a labelled rectangle per box.

    Box coordinates are multiplied by the image width/height, so each box is
    expected to be a normalized ``(left, top, right, bottom)`` tuple.
    """
    annotated = original_image.copy()
    draw = ImageDraw.Draw(annotated)
    # Scale by the size of the image actually being annotated.
    width, height = original_image.size

    for index, (entity, boxes) in enumerate(entities):
        # Cycle through the palette; all boxes of one entity share a colour.
        color = colors[index % len(colors)]
        for left, top, right, bottom in boxes:
            abs_left = left * width
            abs_top = top * height
            abs_right = right * width
            abs_bottom = bottom * height
            draw.rectangle(
                [abs_left, abs_top, abs_right, abs_bottom],
                outline=color,
                width=line_thickness,
            )

            # textbbox is measured from (0, 0), so right/bottom give the text
            # extent. The filled label background sits at the box's top-left
            # corner, padded by line_thickness on each side.
            _, _, text_width, text_height = draw.textbbox((0, 0), entity, font=font)
            draw.rectangle(
                [
                    abs_left,
                    abs_top,
                    abs_left + text_width + line_thickness * 2,
                    abs_top + text_height + line_thickness * 2,
                ],
                fill=color,
            )
            draw.text(
                (abs_left + line_thickness, abs_top + line_thickness),
                entity,
                font=font,
                fill="white",
            )
    return annotated


def image_with_boxes(original_image, entities, caption=None):
    """Plot *original_image* next to a copy annotated with entity boxes.

    Args:
        original_image: PIL image to annotate. It is copied, never modified.
        entities: iterable of ``(label, boxes)`` pairs, where ``boxes`` is a
            list of normalized ``(left, top, right, bottom)`` tuples.
        caption: title for the annotated panel. When omitted, the
            notebook-level global ``caption`` is used (the previous
            behaviour), or "" if no such global exists.

    Returns:
        The annotated PIL image. Callers that ignore the return value see
        the same behaviour as before.
    """
    if caption is None:
        # Backward compatibility: older callers relied on the global caption.
        caption = globals().get("caption", "")

    colors = ['#425bde', '#bd92ea', '#ffdbff', '#f590bc', '#de425b']
    annotated = _draw_entity_boxes(
        original_image,
        entities,
        colors,
        line_thickness=5,
        font=_load_label_font(24),
    )

    sns.set()
    _, axs = plt.subplots(1, 2, figsize=(12, 5))

    axs[0].imshow(original_image)
    axs[0].axis('off')
    axs[0].set_title('Imagen original')

    axs[1].imshow(annotated)
    axs[1].axis('off')
    axs[1].set_title(caption)

    plt.tight_layout()
    plt.show()
    return annotated

# Aquí empieza lo interesante

In [None]:
from html import escape as html_escape

from IPython.display import display

# Colab upload widget: returns a {filename: bytes} dict and writes the
# files into the working directory.
uploaded = files.upload()
if not uploaded:
    # next(iter({})) would raise a bare StopIteration; fail with a clear message.
    raise ValueError("No se ha subido ninguna imagen.")
image_filename = next(iter(uploaded))  # only the first uploaded file is used
image = Image.open(image_filename)

# Kosmos-2 grounding prompt; the entities extracted below carry bounding boxes.
prompt = "<grounding>"

inputs = processor_text(text=prompt, images=image, return_tensors="pt")

generated_ids = model_text.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    image_embeds=None,
    image_embeds_position_mask=inputs["image_embeds_position_mask"],
    use_cache=True,
    max_new_tokens=128,
)

generated_text = processor_text.batch_decode(generated_ids, skip_special_tokens=True)[0]
caption, entities = processor_text.post_process_generation(generated_text)
# Keep the entity label (index 0) and its boxes (index 2).
# NOTE(review): index 1 is presumably the text span — confirm against the
# Kosmos-2 processor docs.
entities = [(entity[0], entity[2]) for entity in entities]

# The caption is model-generated text, so escape it before embedding it in HTML.
html = f"<h1>{html_escape(caption)}</h1>"
display(HTML(html))
image_with_boxes(image, entities)

# Fixed seed so audio sampling is repeatable for a given caption.
generator = torch.Generator("cuda").manual_seed(0)

negative_prompt = "Low quality, average quality."

audio = pipe_audio(
    caption,
    audio_length_in_s=10.24,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=40,
).audios[0]

# The playback rate must match the pipeline's output sample rate (16 kHz for
# AudioLDM2). As the last expression, this renders a player in the notebook.
Audio(audio, rate=16000)