In [None]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, TextStreamer, AutoTokenizer,AutoModelForVision2Seq
from PIL import Image
import requests
import torch
from datasets import load_dataset
from huggingface_hub import login
from tqdm import tqdm
import os
import google.generativeai as genai
import time
import re

In [None]:
# Read the Hugging Face token from the HF_TOKEN environment variable instead of
# hard-coding a secret in the notebook. When the variable is unset, None is
# passed and login() falls back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Load the benchmark dataset identified by this repository id; later cells index
# it by task name (e.g. ds['Circled_Letter']).
ds = load_dataset("Ayush-Singh/llms-are-blind-captions_florence2-large")
# Bare expression so the notebook displays the loaded object.
ds

In [None]:
# One checkpoint id is shared by the tokenizer, the processor and the model.
model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
# device_map="cuda" is hard-coded, so this cell requires a CUDA-capable GPU.
model = AutoModelForVision2Seq.from_pretrained(model_id, device_map="cuda")

In [None]:
def get_accuracy(l1,l2):
    """Return the percentage of positions at which ``l1`` agrees with ``l2``.

    Items are paired positionally with ``zip`` and the result is expressed as
    a percentage of ``len(l1)``.

    A pair counts as a match when the two items are equal, or when their
    whitespace-stripped string forms are equal. The second rule lets a decoded
    prediction such as ``" 3\\n"`` match a ground truth stored as ``"3"`` or
    as the integer ``3``.

    Args:
        l1: Sized sequence of predictions.
        l2: Iterable of ground-truth values, aligned with ``l1``.

    Returns:
        Accuracy in the range 0-100. ``0.0`` when ``l1`` is empty (instead of
        raising ``ZeroDivisionError``).
    """
    total = len(l1)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    count = sum(
        1
        for x, y in zip(l1, l2)
        if x == y or str(x).strip() == str(y).strip()
    )
    return count/total*100

def generate(model,processor,tokenizer,dataset,prompt=None):
    """Run the model on every sample of ``dataset`` and collect the decoded text.

    Args:
        model: Object exposing ``generate(input_ids, pixel_values=...,
            max_new_tokens=...)`` and a ``device`` attribute.
        processor: Callable as ``processor(images=..., return_tensors="pt")``;
            must also expose ``image_seq_length`` and ``decode``.
        tokenizer: Object exposing ``encode`` and ``convert_tokens_to_ids``.
        dataset: Iterable of samples. Each sample provides ``'image'`` and,
            when ``prompt`` is None, ``'prompt'``.
        prompt: Text used for every sample. When None, each sample's own
            ``'prompt'`` entry is used.

    Returns:
        A list with one string per sample: the decoded sequence with its first
        ``len(prompt_text)`` characters removed.
    """
    # Loop-invariant lookups, hoisted out of the per-sample loop.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_seq_length = processor.image_seq_length

    results = []
    for sample in tqdm(dataset):
        # Resolve the prompt per sample WITHOUT rebinding `prompt`: rebinding it
        # would make every later sample reuse the first sample's prompt.
        text = sample['prompt'] if prompt is None else prompt

        pixel_values = processor(images=sample['image'], return_tensors="pt").to(model.device)["pixel_values"]

        # Build the input as <image> placeholder tokens followed by the
        # tokenized prompt.
        input_ids = tokenizer.encode(text, return_tensors="pt")
        image_prefix = torch.full((1, image_seq_length), image_token_id, dtype=input_ids.dtype)
        input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

        generation = model.generate(input_ids, pixel_values=pixel_values, max_new_tokens=10)
        decoded = processor.decode(generation[0], skip_special_tokens=True)
        # NOTE(review): assumes the decoded text starts with the prompt verbatim,
        # so dropping len(text) characters leaves only the model's answer.
        results.append(decoded[len(text):])

    return results

### **Geometry** - Circled Letters

In [None]:
# Evaluate on the first 200 examples of the 'Circled_Letter' subset only.
Sample_Circled_Letter = ds['Circled_Letter'].select(range(200))
# Bare expression so the notebook displays the selected subset.
Sample_Circled_Letter

In [None]:
# No prompt override is passed, so generate() reads the prompt text from the
# dataset's 'prompt' field.
result_cl = generate(model,processor,tokenizer,Sample_Circled_Letter)

In [None]:
# Show the raw decoded predictions before any clean-up.
print(result_cl)

In [None]:
# Trim leading/trailing whitespace from every prediction so they can be
# compared to the ground truth with plain equality.
result_cl = list(map(str.strip, result_cl))

In [None]:
# Show the predictions again after whitespace stripping.
print(result_cl)

In [None]:
# Exact-match accuracy of the stripped predictions against the ground truth.
# Uses the shared get_accuracy() helper instead of re-implementing the same
# count / len * 100 loop inline.
print("Accuracy is = ", get_accuracy(result_cl, Sample_Circled_Letter['groundtruth']))

### **Geometry** - Touching Circles

In [None]:
# Evaluate on the first 200 examples of the 'Touching_Circles' subset only.
Sample_Touching_Circles = ds['Touching_Circles'].select(range(200))
# Bare expression so the notebook displays the selected subset.
Sample_Touching_Circles

In [None]:
# No prompt override is passed, so generate() reads the prompt text from the
# dataset's 'prompt' field.
result_tc = generate(model,processor,tokenizer,Sample_Touching_Circles)

In [None]:
# Show the raw decoded predictions.
print(result_tc)

In [None]:
# Yes/No accuracy. The raw decoded text is normalised (whitespace stripped,
# lower-cased) before comparison, so outputs such as " Yes" or "yes\n" are not
# counted as wrong purely because of formatting.
count = 0
for pred, truth in zip(result_tc, Sample_Touching_Circles['groundtruth']):
    answer = pred.strip().lower()
    # Only a clean yes/no answer can score; the ground truth is 'Yes' / 'No'.
    if answer in ('yes', 'no') and answer == str(truth).strip().lower():
        count += 1
print("Accuracy is = ", count/len(result_tc)*100)

### **Geometry** - Are the Lines Intersecting?

In [None]:
# Evaluate on the first 200 examples of the 'Line_Plot_Intersections' subset only.
Sample_Line_Intersecting = ds['Line_Plot_Intersections'].select(range(200))
# Bare expression so the notebook displays the selected subset.
Sample_Line_Intersecting

In [None]:
# A fixed yes/no question is used for every sample instead of the dataset's
# own 'prompt' field.
prompt = "Are the two line plots intersecting? Answer with Yes/No. "
result_lp = generate(model,processor,tokenizer,Sample_Line_Intersecting,prompt)

In [None]:
# Show the raw decoded predictions.
print(result_lp)

In [None]:
# Yes/No accuracy against the intersection count: "yes" is correct when the
# ground truth is non-zero, "no" when it is zero. The raw decoded text is
# normalised (whitespace stripped, lower-cased) first so formatting differences
# such as " Yes" are not scored as errors.
# NOTE(review): the comparison with 0 presumes 'groundtruth' holds numbers here - confirm.
count = 0
for pred, truth in zip(result_lp, Sample_Line_Intersecting['groundtruth']):
    answer = pred.strip().lower()
    if (answer == "yes" and truth != 0) or (answer == "no" and truth == 0):
        count += 1
print("Accuracy is = ", count/len(result_lp)*100)

### **Counting** - Olympic Rings

In [None]:
# Evaluate on the first 200 examples of the 'Olympic_Counting_Circles' subset only.
Sample_Olympic_Ring = ds['Olympic_Counting_Circles'].select(range(200))
# Bare expression so the notebook displays the selected subset.
Sample_Olympic_Ring

In [None]:
prompt = "Count the total number of circles in the image? "
# Pass the prompt defined above to generate(), as every other cell that
# defines a prompt does; previously it was assigned but never used.
result_or = generate(model,processor,tokenizer,Sample_Olympic_Ring,prompt)

In [None]:
# Show the raw decoded predictions.
print(result_or)

In [None]:
# Score the circle-count predictions against the ground truth.
accuracy_or = get_accuracy(result_or, Sample_Olympic_Ring['groundtruth'])
print("Accuracy is = ", accuracy_or)

### **Counting** - Number of Line Intersections

In [None]:
# First 200 examples of 'Line_Plot_Intersections' - the same subset that
# Sample_Line_Intersecting is built from, here used for a counting question.
Sample_Line_Point_Int = ds['Line_Plot_Intersections'].select(range(200))
# Bare expression so the notebook displays the selected subset.
Sample_Line_Point_Int

In [None]:
# A fixed counting question is used for every sample instead of the dataset's
# own 'prompt' field.
prompt = "Count the number of points of intersection of red and blue lines?"
result_lpi = generate(model,processor,tokenizer,Sample_Line_Point_Int,prompt)

In [None]:
# Show the raw decoded predictions.
print(result_lpi)

In [None]:
# Show the ground-truth values for comparison with the predictions printed above.
print(Sample_Line_Point_Int['groundtruth'])

In [None]:
# Score the intersection-count predictions against the ground truth.
accuracy_lpi = get_accuracy(result_lpi, Sample_Line_Point_Int['groundtruth'])
print("Accuracy is = ", accuracy_lpi)

### **Counting** - RowsxColumns

In [None]:
# Evaluate on the first 200 examples of the 'Counting_Grid_Blank_Grids' subset only.
Sample_Rows = ds['Counting_Grid_Blank_Grids'].select(range(200))
# Bare expression so the notebook displays the selected subset.
Sample_Rows

In [None]:
# A fixed question is used for every sample instead of the dataset's own
# 'prompt' field.
prompt = "Count number of rows and number columns in grid "
result_cr = generate(model,processor,tokenizer,Sample_Rows,prompt)

In [None]:
# Show only the first 20 raw predictions to keep the output short.
print(result_cr[:20])

In [None]:
def extract_numbers(s):
    """Return the first two digit runs found in ``s`` as ``"first,second"``.

    Returns None when ``s`` contains fewer than two runs of digits.
    """
    digit_runs = re.findall(r'\d+', s)
    if len(digit_runs) < 2:
        return None
    first, second = digit_runs[:2]
    return f'{first},{second}'

# Keep one entry per prediction (None when no two numbers could be parsed).
# Dropping unparsable outputs would shift every later entry relative to the
# ground truth when the two lists are zipped together for scoring.
extracted_numbers = [extract_numbers(s) for s in result_cr]

In [None]:
# Bare expression: display the first 10 parsed "rows,columns" strings.
extracted_numbers[:10]

In [None]:
# Score the parsed "first,second" strings against the ground truth.
accuracy_cr = get_accuracy(extracted_numbers, Sample_Rows['groundtruth'])
print("Accuracy is = ", accuracy_cr)

### **Counting** - Subway Connections

In [None]:
# Only the first 10 examples of the 'Subway_Connections' subset are used here
# (the other sections use 200).
Sample_Subway = ds['Subway_Connections'].select(range(10))
# Bare expression so the notebook displays the selected subset.
Sample_Subway

In [None]:
# No prompt override is passed, so generate() reads the prompt text from the
# dataset's 'prompt' field.
result_sc = generate(model,processor,tokenizer,Sample_Subway)

In [None]:
# Show the raw decoded predictions; no accuracy is computed for this section.
print(result_sc)

### **Counting** - Nested Squares

In [None]:
# Evaluate on the first 200 examples of the 'Nested_Squares' subset only.
Sample_Nested = ds['Nested_Squares'].select(range(200))
# Bare expression so the notebook displays the selected subset.
Sample_Nested

In [None]:
# A fixed counting question is used for every sample instead of the dataset's
# own 'prompt' field.
prompt = "Count total number of squares in the image."
result_ns = generate(model,processor,tokenizer,Sample_Nested,prompt)

In [None]:
# Show the raw decoded predictions.
print(result_ns)

In [None]:
# Score the square-count predictions against the ground truth.
accuracy_ns = get_accuracy(result_ns, Sample_Nested['groundtruth'])
print("Accuracy is = ", accuracy_ns)