In [1]:
!pip install transformers torch torchvision einops timm peft sentencepiece flash_attn

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting peft
  Downloading peft-0.13.2-py3-none-any.whl.metadata (13 kB)
Collecting flash_attn
  Downloading flash_attn-2.6.3.tar.gz (2.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading peft-0.13.2-py3-none-any.whl (320 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.7/320.7 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: flash_attn
  Building wheel for flash_attn (setup.py) ... [?25ldone
[?25h  Created wheel for flash_attn: filename=flash_attn-2.6.3-cp310-cp310-linux_x86_64.whl size=187315346 sha256=6ebfbdcbdd1

In [73]:
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import pandas as pd
import re
import json

# Checkpoint under evaluation; swap in the commented line to try the larger variant.
model_path = "h2oai/h2ovl-mississippi-800m"
# model_path = 'h2oai/h2ovl-mississippi-2b'

# Fetch the config on its own first so flash attention can be switched off
# on the vision config before the weights are instantiated.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.vision_config.use_flash_attn = False

# Load the weights in bfloat16, then put the model in inference mode on the GPU.
model = AutoModel.from_pretrained(
    model_path,
    config=config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model = model.eval().cuda()

# Slow (non-fast) tokenizer shipped with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=False,
    trust_remote_code=True,
)

# Greedy decoding, capped at 500 newly generated tokens.
generation_config = {"max_new_tokens": 500, "do_sample": False}

In [61]:
def parse_json_response(response):
    """
    Extract the first decodable JSON object embedded in a response string.

    The response is scanned left to right; at every '{' a JSON decode is
    attempted starting from that position, and the first object that decodes
    successfully is returned. Unlike a single-line, non-greedy regex this
    handles objects that span several lines, objects that contain nested
    objects, and responses where an earlier brace-delimited span is not JSON.

    Args:
        response: Text produced by the model, expected to contain something
            like {"type": "..."} possibly surrounded by other text.

    Returns:
        The decoded object as a dict, or None when the response is not a
        string, contains no '{', or no '{' starts a valid JSON object. A
        diagnostic message is printed in each None case.
    """
    if not isinstance(response, str):
        print(f"Could not find valid JSON in response: {response}")
        return None

    decoder = json.JSONDecoder()
    found_brace = False
    start = response.find("{")
    while start != -1:
        found_brace = True
        try:
            # raw_decode parses one JSON value starting at `start` and ignores trailing text.
            parsed, _end = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            # Not a valid object here; try the next opening brace.
            start = response.find("{", start + 1)
            continue
        return parsed

    if found_brace:
        print(f"Error decoding JSON from response: {response}")
    else:
        print(f"Could not find valid JSON in response: {response}")
    return None

# Evaluation function
def evaluate_model(model, tokenizer, generation_config, files, prompt,
                   labels=("invoice", "news-article", "resume")):
    """
    Classify each image with the model and score the predictions.

    For every (image_file, true_label) pair the model is queried once with
    `prompt`, the reply is parsed with `parse_json_response`, and the value
    under the "type" key is taken as the prediction.

    Args:
        model: Object exposing `chat(tokenizer, image_file, prompt,
            generation_config, history=None, return_history=True)` that
            returns a `(response, history)` pair.
        tokenizer: Tokenizer passed through to `model.chat`.
        generation_config: Generation settings passed through to `model.chat`.
        files: Iterable of `(image_file, true_label)` pairs.
        prompt: Prompt sent with every image.
        labels: Class names used, in this order, for the rows and columns of
            the confusion matrix. Defaults to the three document classes.

    Returns:
        Tuple `(accuracy, conf_df)`: the accuracy as a float and the
        confusion matrix as a DataFrame (rows = actual, columns = predicted).

    Notes:
        A reply that cannot be parsed, lacks "type", or carries a non-string
        "type" is recorded as the empty string. Such predictions (and any
        other value outside `labels`) lower the accuracy but have no column
        in the confusion matrix, so a row may sum to fewer than the number
        of samples of that class.
    """
    class_names = list(labels)
    # Case-insensitive lookup so e.g. "Invoice" or " invoice " is credited to "invoice".
    canonical = {str(name).lower(): name for name in class_names}

    actual_labels = []
    predicted_labels = []

    for image_file, true_label in files:
        # Each image is an independent single-turn query; the returned history is not needed.
        response, _ = model.chat(tokenizer, image_file, prompt, generation_config, history=None, return_history=True)

        parsed_response = parse_json_response(response)
        raw_type = parsed_response.get("type") if isinstance(parsed_response, dict) else None

        if isinstance(raw_type, str):
            cleaned = raw_type.strip()
            predicted_type = canonical.get(cleaned.lower(), cleaned)
        else:
            # Keep predictions as strings so they stay comparable with the string labels.
            predicted_type = ""

        actual_labels.append(true_label)
        predicted_labels.append(predicted_type)

    accuracy = accuracy_score(actual_labels, predicted_labels)
    conf_matrix = confusion_matrix(actual_labels, predicted_labels, labels=class_names)

    print(f"Accuracy: {accuracy * 100:.2f}%")
    print("Confusion Matrix:")

    # Labelled DataFrame for readability; the caller decides how to display it.
    conf_df = pd.DataFrame(conf_matrix, index=class_names, columns=class_names)

    return accuracy, conf_df

In [62]:
# Classification prompt: one image, answer requested as a small JSON object.
prompt = """<image>
Extract the type of the image, categorizing it as 'invoice', 'resume', or 'news-article'. Return the result in the following JSON format:
{"type": "" }"""

# Evaluation samples, grouped by ground-truth label.
# Each entry maps label -> (dataset sub-folder, document ids); note the
# news-article folder on disk is named "news article" (with a space).
data_root = "/kaggle/input/rvl-cdip-small/data"
samples_by_label = {
    "invoice": (
        "invoice",
        ["0000036371", "0000044003", "0000080966", "0000080967", "0000113780"],
    ),
    "news-article": (
        "news article",
        ["0000002844", "0000011128", "0000014500", "0000081773"],
    ),
    "resume": (
        "resume",
        ["0000000869", "0000037940", "0000134650"],
    ),
}

# Flatten into the (path, label) pairs expected by evaluate_model, keeping the grouping order.
files = [
    (f"{data_root}/{folder}/{doc_id}.tif", label)
    for label, (folder, doc_ids) in samples_by_label.items()
    for doc_id in doc_ids
]

# Run the evaluation and show the confusion matrix.
accuracy, confusion_df = evaluate_model(model, tokenizer, generation_config, files, prompt)
print(confusion_df)

Could not find valid JSON in response: The image is a news-article. Therefore, the type of the image is 'news-article'.
Could not find valid JSON in response: The image is a resume. Therefore, the type of the image is 'resume'.
Could not find valid JSON in response: The image is a document titled "Research Staff – 1975", which is a type of "resume".
Accuracy: 66.67%
Confusion Matrix:
              invoice  news-article  resume
invoice             5             0       0
news-article        1             2       0
resume              0             0       1
