In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import requests
import torch 
import glob

MODEL_ID = "anananan116/TinyVLM"

# Pick the best available accelerator, falling back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# float16 halves memory on GPU/MPS, but on CPU many kernels are slow or (depending
# on the torch version) not implemented for half precision, so use float32 there.
dtype = torch.float16 if device.type in ("cuda", "mps") else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,  # the model class is custom code hosted with the checkpoint
    torch_dtype=dtype,
)
model.to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# `<IMGPLH>` is the image placeholder which will be replaced by image embeddings.
# The number of `<IMGPLH>` placeholders in a prompt should equal the number of
# images passed for that prompt.



In [None]:
# Scan the input folder for images. The extension check is case-insensitive
# (so `photo.JPG` is found too) and the result is sorted so the
# image -> caption order is stable across runs and platforms.
IMAGE_DIR = "assets/test_images"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp", ".avif")


def find_images(image_dir=IMAGE_DIR, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of files in `image_dir` whose extension is in `extensions`.

    Matching is case-insensitive. As with `glob`, names starting with a dot are
    skipped and a missing directory yields an empty list.
    """
    wanted = tuple(ext.lower() for ext in extensions)
    # glob.escape so a directory name containing `[`, `*` or `?` is taken literally.
    candidates = glob.glob(f"{glob.escape(image_dir)}/*")
    return sorted(path for path in candidates if path.lower().endswith(wanted))


images_path = find_images()

In [None]:
# Load every image eagerly and normalise it to 3-channel RGB:
# - `Image.open` is lazy and keeps the file handle open; opening it in a `with`
#   block and calling `.convert()` reads the pixels into a new image, so the
#   file can be closed right away.
# - PNG/WebP inputs may be RGBA, palette ("P") or greyscale ("L"); presumably
#   the vision encoder expects RGB — TODO confirm against the model's processor.
images = []
for one_image in images_path:
    with Image.open(one_image) as image:
        images.append(image.convert("RGB"))

# Fail here with a clear message rather than deep inside generation/plotting.
assert images, "no images found in the input folder"

# One prompt per image; each prompt contains exactly one `<IMGPLH>` placeholder,
# which is replaced by that image's embeddings.
prompt = "Here's an image:<IMGPLH>Describe this image."
inputs = model.prepare_input_ids_for_generation([prompt] * len(images), images, tokenizer)

In [None]:
# Sampling is stochastic; seed it so re-running the notebook on the same setup
# reproduces the same captions.
SEED = 0
torch.manual_seed(SEED)

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        # NOTE(review): unlike the token tensors, `encoded_image` is passed exactly
        # as returned by `prepare_input_ids_for_generation`; presumably the model
        # handles its device placement — confirm.
        encoded_image=inputs["encoded_image"],
        max_new_tokens=128,
        do_sample=True,
    )

output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [None]:
# Keep only the assistant's reply, i.e. everything after the first
# "assistant\n\n" header in the decoded chat text.
# maxsplit=1 + [-1] has three effects:
#   - a reply that itself contains the marker is kept whole;
#   - a text without the marker is kept as-is instead of raising IndexError;
#   - re-running this cell does not fail (already-extracted captions normally lack the marker).
output_text = [text.split("assistant\n\n", 1)[-1] for text in output_text]

In [None]:
# Display the extracted captions, one per input image, in the same order as `images`.
output_text

In [None]:
import matplotlib.pyplot as plt
import math
import textwrap

In [None]:
def plot_images_with_captions(images, captions, wrap_width=42, num_columns=4):
    """Show `images` in a grid, each titled with its (wrapped) caption.

    Args:
        images: Sequence of images accepted by `Axes.imshow` (e.g. PIL images).
        captions: Caption strings, paired with `images` by position.
        wrap_width: Maximum characters per caption line before wrapping.
        num_columns: Number of grid columns; rows are added as needed.

    Raises:
        ValueError: If `images` and `captions` have different lengths.
    """
    num_images = len(images)
    if len(captions) != num_images:
        # zip() would silently drop the extras and leave blank, framed panels.
        raise ValueError(f"got {num_images} images but {len(captions)} captions")
    if num_images == 0:
        # plt.subplots(0, ...) raises; there is simply nothing to draw.
        return

    num_rows = math.ceil(num_images / num_columns)

    # squeeze=False always yields a 2-D array of Axes, so flatten() also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    _, axes = plt.subplots(
        num_rows, num_columns, figsize=(15, num_rows * 4), squeeze=False
    )
    axes = axes.flatten()

    for ax, image, caption in zip(axes, images, captions):
        ax.imshow(image)
        ax.set_title("\n".join(textwrap.wrap(caption, wrap_width)), fontsize=12)
        ax.axis("off")

    # Hide the unused panels in the last row.
    for ax in axes[num_images:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Render the results: every input image titled with its generated caption.
plot_images_with_captions(images=images, captions=output_text)