In [None]:
import torch
import json
import textwrap

# Local checkpoint directory. Swap which assignment is active to switch
# models; the conditional imports below and the model-loading cell both branch
# on this exact string.
#repo = "models/kosmos-2.5"
repo = "models/kosmos-2-patch14-224"

# Read the checkpoint's config.json into a dict. The encoding is passed
# explicitly: without it open() falls back to the platform's locale encoding,
# whereas JSON files are UTF-8.
json_path = repo + "/config.json"
with open(json_path, 'r', encoding="utf-8") as file:
    config_data = json.load(file)

if repo == "models/kosmos-2.5":
    from transformers import AutoProcessor, AutoConfig, AutoModelForVision2Seq
    
if repo == "models/kosmos-2-patch14-224":
    from transformers import AutoProcessor, AutoModelForImageTextToText

import fitz  # PyMuPDF is imported as fitz
from PIL import Image

In [None]:
# Show whether PyTorch can use CUDA in this environment (bare expression, so
# the notebook displays the boolean).
torch.cuda.is_available()

In [None]:
# Open the PDF and take its first page.
path = "data/Readme.pdf"
doc = fitz.open(path)
page = doc.load_page(0)

# Print the page's extracted text in three formats, followed by its fonts.
for text_format in ("text", "dict", "html"):
    print(page.get_text(text_format))
print(page.get_fonts())


# Render the page at 500 dpi and build a PIL image from the raw samples.
pixmap = page.get_pixmap(dpi=500)
page_size = [pixmap.width, pixmap.height]
image = Image.frombytes("RGB", page_size, pixmap.samples)

# Bare last expression so the notebook renders the image.
image

In [4]:
# Target device and precision shared by the model and, in the next cell, by
# the input tensors.
device = "cuda"
dtype = torch.float16

# `repo` selects the checkpoint; the classes used in each branch are the ones
# imported conditionally in the first cell.
if repo == "models/kosmos-2.5":
    config = AutoConfig.from_pretrained(repo)
    # Load model directly
    model = AutoModelForVision2Seq.from_pretrained(
        repo, device_map=device, torch_dtype=dtype, config=config
    )
    processor = AutoProcessor.from_pretrained(repo)
elif repo == "models/kosmos-2-patch14-224":
    # NOTE(review): this branch loads by Hub id rather than from the local
    # `repo` directory used by the branch above -- confirm that is intended.
    processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
    # Request the target precision at load time instead of loading the default
    # precision first and converting afterwards.
    model = AutoModelForImageTextToText.from_pretrained(
        "microsoft/kosmos-2-patch14-224", torch_dtype=dtype
    )
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(
        f"Unsupported repo {repo!r}: expected 'models/kosmos-2.5' or "
        "'models/kosmos-2-patch14-224'."
    )

# Move model to the correct device and dtype. Using `dtype` directly (rather
# than choosing between .half() and .float()) keeps the model consistent with
# the inputs for any floating-point dtype, e.g. torch.bfloat16.
model = model.to(device=device, dtype=dtype)

In [None]:
import numpy as np

# NOTE: the keyword is `return_tensors` (plural). The misspelled
# `return_tensor` does not select the tensor format, so the processor output
# was not returned as batched PyTorch tensors.
inputs = processor(
    images=image,
    text="Extract the exact text from the image to markdown format <md>",
    return_tensors="pt"
)

# Safely retrieve the height and width if they exist; they are removed from
# `inputs` so they are not forwarded to model.generate() later.
height = inputs.pop("height", None)
width = inputs.pop("width", None)

if height is None or width is None:
    print("Height or width not found in inputs.")
else:
    print(f"Height: {height}, Width: {width}")

# Diagnostic only: the cast of 'flattened_patches' to `dtype` is handled by
# convert_to_tensor() below, together with every other floating-point entry.
if "flattened_patches" not in inputs:
    print("'flattened_patches' key not found in inputs.")


def convert_to_tensor(value):
    """Return ``value`` as a tensor on ``device``, or ``None`` for ``None``.

    Accepts tensors, NumPy arrays, Python/NumPy scalars and (nested) lists or
    tuples of those; list items are converted recursively and stacked.

    Only floating-point data is cast to the notebook-wide ``dtype``. Integer
    and boolean data keep their dtype: float16 represents integers exactly
    only up to 2048, so casting token ids through it would alter larger ids.

    Raises whatever ``torch.stack`` / ``torch.as_tensor`` raise for values
    that cannot be converted (e.g. an empty list or ragged items).
    """
    if value is None:
        return None
    if isinstance(value, torch.Tensor):
        tensor = value
    elif isinstance(value, (list, tuple)):
        tensor = torch.stack([convert_to_tensor(item) for item in value])
    else:
        tensor = torch.as_tensor(value)
    if tensor.is_floating_point():
        tensor = tensor.to(dtype)
    return tensor.to(device)


inputs = {key: convert_to_tensor(value) for key, value in inputs.items()}

# Ensure input_ids and the masks are int64.
for key in ("input_ids", "attention_mask", "image_embeds_position_mask"):
    if inputs.get(key) is not None:
        inputs[key] = inputs[key].to(torch.long)

# Check and fix batch size consistency. The batch size is read from whichever
# image tensor is present: 'pixel_values' or 'flattened_patches'.
image_key = next(
    (k for k in ("pixel_values", "flattened_patches") if inputs.get(k) is not None),
    None,
)
if image_key is None:
    raise KeyError(
        "Expected 'pixel_values' or 'flattened_patches' in the processor output."
    )
batch_size = inputs[image_key].shape[0]

for key, tensor in list(inputs.items()):
    if tensor is None:
        continue
    if tensor.dim() > 0 and tensor.shape[0] == batch_size:
        continue
    # Treat the tensor as unbatched: add a leading axis and repeat it
    # `batch_size` times. After unsqueeze the tensor has dim()+1 axes, so
    # repeat() needs one factor per original axis in addition to the batch
    # factor.
    inputs[key] = tensor.unsqueeze(0).repeat(batch_size, *([1] * tensor.dim()))

# Check shapes and types
for key, tensor in inputs.items():
    if tensor is not None:
        print(f"Key: {key}, Shape: {tensor.shape}, Type: {tensor.dtype}")

In [6]:
# Decoding settings, kept separate from the model inputs.
generation_options = {
    "max_new_tokens": 4096,
    "repetition_penalty": 1.05,
}

# Generate IDs
generated_ids = model.generate(**inputs, **generation_options)

In [None]:
# Display the raw value returned by model.generate() in the previous cell.
generated_ids

In [None]:
# Decode the generated ids back into strings, omitting special tokens.
generated_text = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True
)

# Wrap to 80 columns one line at a time. Wrapping the whole string at once
# replaces newlines with spaces, which would merge all output lines (e.g.
# markdown headings and list items) into a single paragraph. For text without
# newlines the printed result is unchanged.
wrapped_lines = [
    textwrap.fill(line, width=80) for line in generated_text[0].splitlines()
]
print("\n".join(wrapped_lines))