In [1]:
import pandas as pd
import numpy as np

df = pd.read_csv('../dataset/train.csv')
df = df.iloc[:1000]

In [2]:
import torch
torch.cuda.empty_cache()

# set cuda_launch_blocking to True in os
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
import re

from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)



VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0-1): 2 x DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
           

In [4]:
image_links = df['image_link']
entity_names = df['entity_name']

# get image name from image link
def get_image_name(image_link):
    return image_link.split('/')[-1]


In [5]:
# import cv2
# def get_pixel_values(df, range_of_images):
#     df = df.iloc[range_of_images:range_of_images+100]
#     image_names = df['image_link'].apply(get_image_name)
#     image_paths = [os.path.join('../images', image_name) for image_name in image_names]
#     images = []
#     for image_path in image_paths:
#         if os.path.exists(image_path):
#             # Use PIL to open the image to check its format
#             try:
#                 img = Image.open(image_path)
#                 # Check if the image is a placeholder (black 100x100)
#                 if img.size == (100, 100) and img.getpixel((0, 0)) == (0, 0, 0):
#                     print(f"Skipping placeholder image: {image_path}")
#                     continue
#                 # Convert to RGB format if necessary
#                 if img.mode != 'RGB':
#                     img = img.convert('RGB')
#                 # Convert to numpy array for OpenCV
#                 img = np.array(img)
#                 # Convert color space for OpenCV
#                 img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
#                 images.append(img)
#             except Exception as e:
#                 print(f"Failed to process image: {image_path}, Error: {e}")
#         else:
#             print(f"Image file not found: {image_path}")
#     pixel_values = processor(images=images, return_tensors="pt").pixel_values
#     return pixel_values


# # prepare decoder input ids
# def get_decoder_input_ids(df, range_of_images):
#     df = df.iloc[range_of_images:range_of_images+100]
#     entity_names = df['entity_name'].tolist()
#     task_prompt = "What is the {entity_name}?"
#     prompts = [task_prompt.replace("{entity_name}", entity_name) for entity_name in entity_names]
#     decoder_input_ids = processor.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids
#     return decoder_input_ids


In [6]:
# from PIL import Image


# for i in range(0, len(df), 20):
#     pixel_values = get_pixel_values(df, range_of_images=i)
#     decoder_input_ids = get_decoder_input_ids(df, range_of_images=i)
    
#     outputs = model.generate(
#         pixel_values.to(device),
#         decoder_input_ids=decoder_input_ids.to(device),
#         max_length=model.decoder.config.max_position_embeddings,
#         early_stopping=True,
#         pad_token_id=processor.tokenizer.pad_token_id,
#         eos_token_id=processor.tokenizer.eos_token_id,
#         use_cache=True,
#         num_beams=1,
#         bad_words_ids=[[processor.tokenizer.unk_token_id]],
#         return_dict_in_generate=True
#     )
    
#     sequences = processor.batch_decode(outputs.sequences)
#     predicted_values = []
    
#     for seq in sequences:
#         sequence = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
#         sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
#         predicted_values.append(processor.token2json(sequence))
    
#     predictions.loc[i:i+len(predicted_values)-1, 'predicted_entity_value'] = predicted_values
    
#     # Clear GPU memory
#     del pixel_values, decoder_input_ids, outputs, sequences, predicted_values
#     torch.cuda.empty_cache()

In [12]:
from PIL import Image
import cv2
# Create a new DataFrame with all columns from df and add a new column 'predicted_entity_value'
predicted_df = df.copy()
predicted_df['predicted_entity_value'] = None


for image_link in image_links:
    image_name = get_image_name(image_link)
    image_path = os.path.join('../images', image_name)
    
    if os.path.exists(image_path):
        try:
            img = Image.open(image_path)
            # Check if the image is a placeholder (black 100x100)
            if img.size == (100, 100) and img.getpixel((0, 0)) == (0, 0, 0):
                print(f"Skipping placeholder image: {image_path}")
                continue
            # Convert to RGB format if necessary
            if img.mode != 'RGB':
                img = img.convert('RGB')
            # Convert to numpy array for OpenCV
            img = np.array(img)
            # Convert color space for OpenCV
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            pixel_values = processor(images=img, return_tensors="pt").pixel_values
            decoder_input_ids = processor.tokenizer([], add_special_tokens=False, return_tensors="pt").input_ids
        except Exception as e:
            print(f"Failed to process image: {image_path}, Error: {e}")
            continue
    else:
        print(f"Image file not found: {image_path}")
        continue
    
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True
    )
    
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    predicted_values = processor.token2json(sequence)
    
    predicted_df.loc[image_link, 'predicted_entity_value'] = predicted_values
    
    del pixel_values, decoder_input_ids, outputs, sequence, predicted_values
    torch.cuda.empty_cache()
    
    
    

Failed to process image: ../images/61I9XdN6OFL.jpg, Error: list index out of range
Failed to process image: ../images/71gSRbyXmoL.jpg, Error: list index out of range
Failed to process image: ../images/61BZ4zrjZXL.jpg, Error: list index out of range
Failed to process image: ../images/612mrlqiI4L.jpg, Error: list index out of range
Failed to process image: ../images/617Tl40LOXL.jpg, Error: list index out of range
Failed to process image: ../images/61QsBSE7jgL.jpg, Error: list index out of range
Failed to process image: ../images/81xsq6vf2qL.jpg, Error: list index out of range
Failed to process image: ../images/71DiLRHeZdL.jpg, Error: list index out of range
Failed to process image: ../images/91Cma3RzseL.jpg, Error: list index out of range
Failed to process image: ../images/71jBLhmTNlL.jpg, Error: list index out of range
Failed to process image: ../images/81N73b5khVL.jpg, Error: list index out of range
Failed to process image: ../images/61oMj2iXOuL.jpg, Error: list index out of range
Fail