# Notebook to run the model on unseen images

### Install necessary dependencies

In [None]:
# Install required packages
!pip install -q transformers datasets sentencepiece
!pip install -q pytorch-lightning wandb
!pip install -q donut-python

# Note: `huggingface-cli login` should be run from a terminal, not from this notebook.

## Resize the images

Resize the images to the expected size and rotate them upright according to their EXIF orientation metadata.

In [None]:
from PIL import Image, ExifTags
import shutil
import os

# Resize every JPEG found in `input_dir`, rotate it upright according to its
# EXIF orientation, and write the result under the same name into `output_dir`.

# Define the paths for the input and output directories
input_dir = "img"
output_dir = "img_resized/"

# Size (width, height) passed to `resize`; applied BEFORE the rotation, so
# images rotated by 90/270 degrees end up with swapped dimensions.
target_size = (1600, 1200)

# Standard EXIF tag id of "Orientation" (0x0112 == 274).
orientation_tag = 0x0112

# Orientation value -> counter-clockwise angle for `Image.rotate`.
# Mirrored orientations (2, 4, 5, 7) are left untouched, as before.
rotation_by_orientation = {3: 180, 6: 270, 8: 90}

# Create the output directory if it does not exist
os.makedirs(output_dir, exist_ok=True)

# Loop through all the image files in the input directory
for filename in sorted(os.listdir(input_dir)):
    # Accept the common JPEG extensions regardless of case (.jpg, .JPG, .jpeg, ...)
    if not filename.lower().endswith((".jpg", ".jpeg")):
        continue

    with Image.open(os.path.join(input_dir, filename)) as img:
        resized_img = img.resize(target_size)

        # `getexif()` is the public API and returns an empty mapping when the
        # file carries no EXIF data, so images without metadata (or without an
        # Orientation entry) simply skip the rotation instead of raising.
        exif = img.getexif()
        angle = rotation_by_orientation.get(exif.get(orientation_tag))
        if angle is not None:
            resized_img = resized_img.rotate(angle, expand=True)
            # The pixels are upright now: reset the tag so EXIF-aware readers
            # do not apply the same rotation a second time.
            exif[orientation_tag] = 1

        # Only pass `exif` when there is something to write; the original
        # passed None for images without EXIF data.
        save_kwargs = {"exif": exif.tobytes()} if len(exif) else {}
        resized_img.save(os.path.join(output_dir, filename), **save_kwargs)

In [4]:
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Checkpoint chosen as the best-generalizing one: epoch 9 of the last run
# (very similar to epoch 10), pinned by commit hash.
hub_repo = "Jac-Zac/thesis_test_donut"
hub_revision = "8c5467cb66685e801ec6ff8de7e7fdd247274ed0"

# Processor and model are loaded from the same repository and revision.
processor = DonutProcessor.from_pretrained(hub_repo, revision=hub_revision)
model = VisionEncoderDecoderModel.from_pretrained(hub_repo, revision=hub_revision)

In [8]:
import re
import json
import torch
from tqdm.auto import tqdm
import numpy as np
import random
from PIL import Image

from donut import JSONParseEvaluator
from datasets import load_dataset

# Run the model on every image under `image_path` and dump the parsed
# predictions to a JSON file as a list of {"sample_id", "prediction"} records.

device = "cuda" if torch.cuda.is_available() else "cpu"

model.eval()
model.to(device)

output_list = []
accs = []  # not filled in this cell; kept in case other cells rely on the name

image_path = "img_resized"

# Without `split=`, `load_dataset` returns a DatasetDict (split name -> Dataset).
# Iterating that object yields the split NAMES, not samples, so flatten the
# samples of every split into one stream. A plain Dataset is handled as well.
dataset = load_dataset(image_path)
splits = list(dataset.values()) if isinstance(dataset, dict) else [dataset]
num_samples = sum(len(split) for split in splits)
samples = (sample for split in splits for sample in split)

# The decoder prompt is identical for every image, so build it once.
task_prompt = "<s_herbarium>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = decoder_input_ids.to(device)

# Iterate over all the images; `sample_id` is the running index across splits
for idx, sample in tqdm(enumerate(samples), total=num_samples):
    # prepare encoder inputs
    pixel_values = processor(sample["image"].convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # autoregressively generate sequence (greedy decoding, <unk> disallowed)
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # turn into JSON: strip EOS/PAD tokens and the leading task token first
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    seq = processor.token2json(seq)

    output_list.append({"sample_id": idx, "prediction": seq})

# Save output to JSON file
output_file_path = "../output.json"  # Replace with your desired output file path
with open(output_file_path, "w") as f:
    json.dump(output_list, f)

Resolving data files:   0%|          | 0/1553 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/138 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/139 [00:00<?, ?it/s]

Found cached dataset imagefolder (/Users/jaczac/.cache/huggingface/datasets/imagefolder/img_resized-7f5590504a871c24/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


  0%|          | 0/137 [00:00<?, ?it/s]