In [None]:
# Install required packages
!pip install -q transformers datasets sentencepiece
!pip install -q pytorch-lightning wandb
!pip install -q donut-python

# MPS acceleration is available on MacOS 12.3+
!pip3 install -q --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

# !huggingface-cli login this shouldh be done from the terminal

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel

# The processor and the model are both loaded from the same pretrained id.
checkpoint = "Jac-Zac/thesis_donut"
processor = DonutProcessor.from_pretrained(checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

In [None]:
# `sample` is only bound once the evaluation loop further down has run, so on
# a fresh kernel (Restart & Run All) an unguarded print raises NameError here.
if "sample" in globals():
    print(sample)
else:
    print("`sample` is not defined yet - run the evaluation cell below first.")

In [None]:
import re
import csv
import os
import json
import torch
from tqdm.auto import tqdm
import numpy as np
import wandb
from torchvision.transforms import ToPILImage

from donut import JSONParseEvaluator
from datasets import load_dataset

import torch

# Select the compute device. MPS is preferred; when it is unavailable the
# reason is printed and the code falls back to CUDA or the CPU, so `device`
# is always defined for the `.to(device)` calls that follow.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Falling back to device: {device}")


# W&B table that receives one row per low-scoring sample: the image, the
# predicted fields, the ground-truth fields and the two scores.
table = wandb.Table(columns=[
    "Image",
    "Prediction_Nome_verbatim", "Prediction_Date", "Prediction_Elevation", "Prediction_Locality",
    "Ground_Truth_Nome_verbatim", "Ground_Truth_Date", "Ground_Truth_Elevation", "Ground_Truth_Locality",
    "Name_Edit_Distance", "Scores",
])

# Put the model in evaluation mode and move it to the selected device.
model.eval()
model.to(device)

# Per-sample predictions and accuracy scores, filled by the evaluation loop.
output_list, accs = [], []

image_path = "img_resized"

dataset = load_dataset(image_path, split="validation")

# Read the W&B API key from the environment instead of hardcoding a secret in
# the notebook. If WANDB_API_KEY is unset, `key=None` leaves credential lookup
# to wandb itself.
api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=api_key)
wandb.init(project="Donut", name="validation_set")

# Everything that does not depend on the current sample is prepared once,
# before the loop, instead of being rebuilt on every iteration.

# prepare decoder inputs (the task prompt is identical for every sample)
task_prompt = "<s_herbarium>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = decoder_input_ids.to(device)

# A single evaluator instance is reused for all samples.
evaluator = JSONParseEvaluator()

# Samples whose score is at or below this value are added to the W&B table.
worst_score_threshold = 0.2

for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
    # Convert the image once; truncated files are skipped, any other OSError
    # is re-raised.
    try:
        image = sample["image"].convert("RGB")

        # Force the pixel data to be read so a truncated file fails here.
        image.load()
    except OSError as e:
        if "image file is truncated" in str(e):
            print("Warning: Skipping truncated image")
            continue
        raise

    # prepare encoder inputs (reuses the RGB image converted above)
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # autoregressively generate sequence
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # turn into JSON
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    seq = processor.token2json(seq)

    # Score the prediction against the parsed ground truth.
    ground_truth = json.loads(sample["ground_truth"])
    score = evaluator.cal_acc(seq, ground_truth)

    accs.append(score)
    output_list.append(seq)

    # Log low-scoring samples; a TypeError while building the row is reported
    # and the sample is left out of the table (its score is still in `accs`).
    try:
        if score <= worst_score_threshold:
            image_resized = image.resize((900, 1200))
            pred = seq
            gt = ground_truth

            # Merge Day, Month, and Year for prediction and ground truth
            pred_date = f"{pred.get('Day', '')}/{pred.get('Month', '')}/{pred.get('Year', '')}"
            gt_date = f"{gt.get('Day', '')}/{gt.get('Month', '')}/{gt.get('Year', '')}"

            # Score of the Nome_verbatim field alone, computed with the same
            # evaluator; stored in the "Name_Edit_Distance" column.
            name_edit_dist = evaluator.cal_acc(pred.get('Nome_verbatim', ''), gt.get('Nome_verbatim', ''))

            # Convert the image to a wandb.Image object
            image_wandb = wandb.Image(image_resized)

            # Add data to the table in the column order declared for `table`
            table.add_data(
                image_wandb,
                pred.get('Nome_verbatim', ''),
                pred_date,
                pred.get('Elevation', ''),
                pred.get('Locality', ''),
                gt.get('Nome_verbatim', ''),
                gt_date,
                gt.get('Elevation', ''),
                gt.get('Locality', ''),
                name_edit_dist,
                score
            )
    except TypeError as e:
        print(f"Warning: Skipping sample number: {idx} due to error: {e}")
        
# Upload the table of low-scoring samples, print the aggregate accuracy
# figures, then close the W&B run.
wandb.log({"worst_predictions": table})

scores = dict(accuracies=accs, mean_accuracy=np.mean(accs))
print(scores, f"length : {len(accs)}")
print("Median accuracy:", np.median(accs))

wandb.finish()

In [None]:
# Mean of the per-sample accuracy scores collected in `accs`.
print("Mean accuracy:", np.asarray(accs).mean())

In [None]:
# NOTE(review): this cell prints the same mean as the cell right before it.
print("Mean accuracy:", np.asarray(accs).mean())

In [None]:
# The value printed here is the median, so the label says so.
print("Median accuracy:", np.median(accs))

In [None]:
# Mean over the sorted accuracies with the 10 lowest values left out.
sorted_accs = np.sort(accs)
mean_without_worst = sorted_accs[10:].mean()
print("Mean accuracy (excluding worst 10):", mean_without_worst)

In [None]:
# Re-run inference on the 100 lowest-scoring entries of `accs` and show each
# prediction next to its ground truth and image.
# NOTE(review): positions in `accs` are used as dataset indices. The evaluation
# loop skips truncated images without appending a score, so the two only line
# up when no sample was skipped - confirm before trusting these rows.
worst_idxs = np.argsort(accs)[:100].tolist()

# prepare decoder inputs
task_prompt = "<s_herbarium>"
prompt_encoding = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")
decoder_input_ids = prompt_encoding.input_ids.to(device)

# Generation settings shared by every sample in the loop below.
generation_kwargs = dict(
    decoder_input_ids=decoder_input_ids,
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=1,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

for idx in worst_idxs:
    sample = dataset[idx]

    # prepare encoder inputs
    rgb_image = sample["image"].convert("RGB")
    pixel_values = processor(rgb_image, return_tensors="pt").pixel_values.to(device)

    # autoregressively generate sequence
    outputs = model.generate(pixel_values, **generation_kwargs)

    # turn into JSON: drop EOS/PAD tokens and the first task start token
    decoded = processor.batch_decode(outputs.sequences)[0]
    for special_token in (processor.tokenizer.eos_token, processor.tokenizer.pad_token):
        decoded = decoded.replace(special_token, "")
    decoded = re.sub(r"<.*?>", "", decoded, count=1).strip()
    seq = processor.token2json(decoded)

    print(f"Ground Truth: {sample['ground_truth']}\n")
    print(f"Prediction: {seq}\n")
    print(f"Score: {accs[idx]}\n")
    display(sample["image"])