# Notebook to run the model on unseen images

### Install necessary dependencies

In [1]:
# Install required packages
!pip install -q transformers datasets sentencepiece
!pip install -q pytorch-lightning wandb
!pip install -q donut-python

# !huggingface-cli login this shouldh be done from the terminal

## Resize the images
> Image 005294.jpg was weird

I want all the images to have the correct size and to be rotated to the correct orientation.

In [1]:
from PIL import Image, ImageOps
import shutil
import os

# Define the paths for the input and output directories
input_dir = "../donut_example/Immagini_Esposito"
output_dir = "img_resized/"
size = (1600,1200)  # (width, height) of a landscape frame

# Every image is saved in portrait orientation, so the resize target is
# `size` with its two sides swapped: (1200, 1600).
portrait_size = (size[1], size[0])

# Create the output directory if it does not exist
os.makedirs(output_dir, exist_ok=True)

# Loop through all the image files in the input directory
for filename in os.listdir(input_dir):
    if not filename.endswith(".jpg"):
        continue

    with Image.open(os.path.join(input_dir, filename)) as img:

        # Apply the EXIF orientation first, while the pixels still have their
        # original geometry. Resizing to a fixed landscape frame before this
        # step would squash any image that is stored in portrait.
        upright = ImageOps.exif_transpose(img)

        # Images that are still landscape are reported and rotated
        # 90 degrees clockwise so that everything ends up in portrait.
        if upright.width > upright.height:
            print(filename)
            upright = upright.rotate(-90, expand=True)

        # Only now bring the image to the final portrait size
        upright = upright.resize(portrait_size)

        # Save the rotated and resized image to the output directory
        upright.save(os.path.join(output_dir, filename))

In [2]:
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Checkpoint choice: epoch 9 of the last run, which seems to work and generalize
# best (epoch 10 is very similar). The revision hash pins that exact commit.
repo_id = "Jac-Zac/thesis_test_donut"
commit_hash = "ba396d4b3d39a4eaf7c8d4919b384ebcf6f0360f"

# Processor and model are loaded from the same repository and revision.
processor = DonutProcessor.from_pretrained(repo_id, revision=commit_hash)
model = VisionEncoderDecoderModel.from_pretrained(repo_id, revision=commit_hash)

In [None]:
import re
import os
import json
import torch
from tqdm.auto import tqdm
import numpy as np
import random
from PIL import Image

from donut import JSONParseEvaluator
from datasets import load_dataset

# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

model.eval()
model.to(device)

output_list = []
accs = []  # not used in this cell

# Directory with the resized images; `images_path` is the name the loop reads.
images_path = "img_resized"
image_path = images_path  # alias: the name this cell originally assigned

# Prepare decoder inputs. The task prompt is the same for every image,
# so it is tokenized once instead of once per file.
task_prompt = "<s_herbarium>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = decoder_input_ids.to(device)

# Loop through all the image files in the input directory.
# Sorting makes the order of the predictions reproducible between runs.
for filename in sorted(os.listdir(images_path)):
    if not filename.endswith(".jpg"):
        continue

    # Load the image and prepare encoder inputs; the file handle is closed
    # as soon as the pixel values have been extracted.
    with Image.open(os.path.join(images_path, filename)) as image:
        pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Autoregressively generate the output sequence
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Turn the generated sequence into JSON
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    seq = processor.token2json(seq)

    output_list.append({"filename": filename ,"prediction": seq})

# Save output to JSON file
output_file_path = "../output.json"  # Replace with your desired output file path
with open(output_file_path, "w", encoding="utf-8") as f:
    json.dump(output_list, f)