In [1]:
from datasets import load_dataset
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
from hashlib import md5
from datasets import Dataset
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ds = load_dataset("lmms-lab/RefCOCO")

In [3]:
# Prints the train/val/test split with the features of each dataset
print(ds)

DatasetDict({
    val: Dataset({
        features: ['question_id', 'image', 'question', 'answer', 'segmentation', 'bbox', 'iscrowd', 'file_name'],
        num_rows: 8811
    })
    test: Dataset({
        features: ['question_id', 'image', 'question', 'answer', 'segmentation', 'bbox', 'iscrowd', 'file_name'],
        num_rows: 5000
    })
    testA: Dataset({
        features: ['question_id', 'image', 'question', 'answer', 'segmentation', 'bbox', 'iscrowd', 'file_name'],
        num_rows: 1975
    })
    testB: Dataset({
        features: ['question_id', 'image', 'question', 'answer', 'segmentation', 'bbox', 'iscrowd', 'file_name'],
        num_rows: 1810
    })
})


In [4]:
dataset = ds['val']

In [5]:
location_words = ['next to', 'right of', 'left of', 'under', 'below', 'above', 'against', 'by', 'beside', 'near', 'from']

In [7]:
# Allow only items with unique images and select 1000 of them
def filter_examples(dataset):
    unique_images = set()
    filtered_items = []

    for item in dataset:
        # Hash the image bytes directly
        img_hash = md5(item['image'].tobytes()).hexdigest()
        
        if img_hash not in unique_images:
            # Filter for answers that have location words in them
            item['allanswers'] = item['answer']
            answers = item['answer']
            result = next(
                (s for s in answers if any(re.search(rf'\b{re.escape(word)}\b', s.lower()) for word in location_words) and len(s.split()) > 1),
                None  # Default value if none found 
            )
            if result:
                # Replace the list of answers with the single qualifying answer
                item['answer'] = result  
                unique_images.add(img_hash)
                filtered_items.append(item)
            
        if len(filtered_items) == 1000:
            break
    return filtered_items

In [8]:
unique_dataset = filter_examples(dataset)

In [9]:
print(unique_dataset[0:5])

[{'question_id': '49369', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=480x640 at 0x2A3149E4050>, 'question': 'Please carefully observe the area circled in the image and come up with a caption for the area.', 'answer': 'black cat under sink', 'segmentation': [185.02999877929688, 507.739990234375, 228.86000061035156, 493.6000061035156, 250.38999938964844, 478.989990234375, 266.5299987792969, 478.2200012207031, 333.42999267578125, 486.67999267578125, 357.260009765625, 529.72998046875, 365.7200012207031, 551.260009765625, 404.1700134277344, 568.1799926757812, 423.3900146484375, 597.4000244140625, 421.0799865722656, 623.5399780273438, 408.010009765625, 640.0, 366.489990234375, 640.0, 329.5799865722656, 625.0800170898438, 331.1199951171875, 605.8499755859375, 334.20001220703125, 599.7000122070312, 321.1300048828125, 598.9299926757812, 286.5299987792969, 607.3900146484375, 239.6199951171875, 595.8599853515625, 225.77999877929688, 601.239990234375, 200.41000366210938, 595.8

In [12]:
# Convert items to the correct format for viper and pull images in .jpg format to a separate folder
def convert_matched_items_to_viper(items):
    viper_items = []
    counter = 0
    for item in items:
        viper_item = {
            "query": item["question"],
            "answer": item["answer"],
            "image_name": "refcoco-" + str(counter) + ".jpg",
            "segmentation": item["segmentation"],
            "bbox": item["bbox"],
            "allanswers": item["allanswers"]
        }
        image = item["image"]
        #image.save("refcoco_images/refcoco-" + str(counter) + ".jpg")
        viper_items.append(viper_item)
        counter += 1
    return viper_items

In [13]:
convert_matched_items_to_viper(unique_dataset)

[{'query': 'Please carefully observe the area circled in the image and come up with a caption for the area.',
  'answer': 'black cat under sink',
  'image_name': 'refcoco-0.jpg',
  'segmentation': [185.02999877929688,
   507.739990234375,
   228.86000061035156,
   493.6000061035156,
   250.38999938964844,
   478.989990234375,
   266.5299987792969,
   478.2200012207031,
   333.42999267578125,
   486.67999267578125,
   357.260009765625,
   529.72998046875,
   365.7200012207031,
   551.260009765625,
   404.1700134277344,
   568.1799926757812,
   423.3900146484375,
   597.4000244140625,
   421.0799865722656,
   623.5399780273438,
   408.010009765625,
   640.0,
   366.489990234375,
   640.0,
   329.5799865722656,
   625.0800170898438,
   331.1199951171875,
   605.8499755859375,
   334.20001220703125,
   599.7000122070312,
   321.1300048828125,
   598.9299926757812,
   286.5299987792969,
   607.3900146484375,
   239.6199951171875,
   595.8599853515625,
   225.77999877929688,
   601.239990234

In [14]:
# Turn the formatted dataset into a .csv file.
viper_items = convert_matched_items_to_viper(unique_dataset)

data_dict = {key: [item[key] for item in viper_items] for key in viper_items[0].keys()}
dataset = Dataset.from_dict(data_dict)
dataset.to_csv("refcoco_with_allanswers.csv")

Creating CSV from Arrow format: 100%|██████████| 1/1 [00:01<00:00,  1.54s/ba]


1047051