In [1]:
from datasets import load_dataset, concatenate_datasets
import random
from PIL import Image
import requests
from io import BytesIO

# Hugging Face dataset to sample from, plus a fixed seed so random.sample below
# picks the same training rows on every Restart & Run All.
dataset_id = "allenai/pixmo-count"
random.seed(1315)

In [None]:
# Load dataset: count-1 samples come from the train split, counts 2..5 from the test split.

ds_train_raw = load_dataset(dataset_id, split="train")
ds_test = load_dataset(dataset_id, split="test")  # only contains counts from 2 to 10

# Target number of successfully downloaded pairs per count (enforced in generate_iopair).
group_size = 56
# Extra count-1 candidates beyond group_size. generate_iopair skips rows whose
# image cannot be fetched, so these spares give the count-1 group a chance to
# still reach group_size successful pairs.
spare_rows = 10

ds_train = ds_train_raw.filter(lambda x: x["count"] == 1)
ds_single = ds_train.select(random.sample(range(len(ds_train)), group_size + spare_rows))
ds_test_small = ds_test.filter(lambda x: x["count"] <= 5)
ds = concatenate_datasets([ds_single, ds_test_small])

# Alternative (not used): sample group_size + spare_rows rows for each count from
# the train split alone instead of mixing the train and test splits.

In [None]:
n_candidates = len(ds)  # rows that will be tried for download
print(n_candidates)

In [None]:
# Construct input/output pairs.

# Per-count tally of successfully fetched images; slot k holds count k + 1 (counts 1..5).
counts = [0] * 5

def generate_iopair(data):
    """Download one sample's image and build an input/output pair.

    Args:
        data: a dataset row; its "count", "image_url" and "label" fields are read.

    Returns:
        {"text": data["label"], "image": <PIL image>, "label": data["count"]} on
        success, or None when the row's count has no tally slot (outside
        1..len(counts)), its count group already holds group_size pairs, or the
        image cannot be downloaded or decoded.
    """
    label = data["count"]
    # Guard the tally index: a count of 0 would otherwise silently hit counts[-1]
    # (the last group) and a count above len(counts) would raise IndexError.
    if not 1 <= label <= len(counts):
        return None
    if counts[label - 1] >= group_size:
        return None

    try:
        response = requests.get(data["image_url"], timeout=3)
        response.raise_for_status()
        image_input = Image.open(BytesIO(response.content))
        # Image.open is lazy (it only parses the header); decode fully now so
        # truncated/corrupt files are skipped here instead of failing later when
        # the image is converted and saved.
        image_input.load()
    except Exception:
        # Best-effort: any network, HTTP or decode failure just skips this row.
        # Catching Exception (not a bare except) keeps KeyboardInterrupt working.
        return None

    counts[label - 1] += 1
    print(counts)

    return {
        "text": data["label"],
        "image": image_input,
        "label": label,
    }

# Try every candidate row in order; rows that yield no pair (already-full group,
# failed download) are skipped. Pairs collected so far stay in `dataset` even if
# a later row raises.
dataset = []
for row in ds:
    pair = generate_iopair(row)
    if pair is None:
        continue
    dataset.append(pair)

In [5]:
# Bucket the pairs by count label 1..5, keeping arrival order inside each bucket.
# A label outside 1..5 raises KeyError.
grouped_dataset = {label: [] for label in range(1, 6)}
for sample in dataset:
    bucket = grouped_dataset[sample["label"]]
    bucket.append(sample)

# Round-robin over the buckets (1, 2, ..., 5, then 1, 2, ... again) so consecutive
# entries cycle through the counts; a bucket drops out once it is exhausted.
longest_bucket = max(map(len, grouped_dataset.values()))
interleaved_dataset = [
    grouped_dataset[label][rank]
    for rank in range(longest_bucket)
    for label in range(1, 6)
    if rank < len(grouped_dataset[label])
]

dataset = interleaved_dataset

In [None]:
import os
import json

output_dir = "dataset_images"
os.makedirs(output_dir, exist_ok=True)
# NOTE(review): files from a previous, larger run are not removed from output_dir
# and would end up in the zip produced later.

# One metadata record per saved image; "index" matches the image's file name.
# Key renaming: the pair's "text" is stored under "label" and its numeric
# "label" (the count) under "count".
metadata = []
for index, sample in enumerate(dataset):
    # Presumably converted to RGB so modes JPEG cannot store (e.g. RGBA, P) still save.
    sample["image"].convert("RGB").save(os.path.join(output_dir, f"{index}.jpg"))
    metadata.append({"label": sample["text"], "count": sample["label"], "index": index})

with open("dataset.json", "w") as f:
    json.dump(metadata, f, indent=4)

print(f"Saved {len(dataset)} images to '{output_dir}/' and metadata to 'dataset.json'.")

In [None]:
import shutil
from google.colab import files

# Pack the saved images into dataset_images.zip, then send the archive and the
# metadata file to the browser via Colab's download helper.
shutil.make_archive("dataset_images", "zip", "dataset_images")
for download_name in ("dataset_images.zip", "dataset.json"):
    files.download(download_name)