In [5]:
import base64
import io
import os

import argilla as rg
from datasets import load_dataset
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

# Connect to the Argilla server. The API key is taken from the ARGILLA_API_KEY
# environment variable so that real credentials never need to be saved in the
# notebook; the original literal is kept as the fallback value.
client = rg.Argilla(
    # api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key=os.environ.get("ARGILLA_API_KEY", "argilla.apikey"),
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)

# Annotation task definition: one image field shown to the annotator and a
# single-label question whose choices are the digit strings "0" through "9".
settings = rg.Settings(
    guidelines="The goal of this task is to classify a given image of a handwritten digit into one of 10 classes representing integer values from 0 to 9, inclusively.",
    fields=[
        rg.ImageField(
            name="image",
            title="An image of a handwritten digit.",
        ),
    ],
    questions=[
        rg.LabelQuestion(
            name="image_label",
            title="What digit do you see on the image?",
            labels=[str(digit) for digit in range(10)],
        )
    ],
)

dataset = rg.Dataset(
    name="image_classification_dataset",
    settings=settings,
)
# NOTE(review): re-running this cell while a dataset with the same name still
# exists presumably fails on the server side -- confirm, and delete the old
# dataset first if needed.
dataset.create()

# Only the first 100 training examples are requested, to keep the demo small.
hf_dataset = load_dataset("ylecun/mnist", split="train[:100]")


def pil_to_data_uri(batch):
    """Add a base64-encoded PNG data URI for every image in a batch.

    Args:
        batch: Mapping with an ``"image"`` entry holding an iterable of image
            objects that expose ``save(fp, format=...)`` (PIL images, going by
            the function name).

    Returns:
        The same ``batch`` object, mutated in place with a new
        ``"image_data_uri"`` entry: a list of ``data:image/png;base64,...``
        strings in the same order as ``batch["image"]``.
    """

    def _as_png_data_uri(picture):
        # Serialize into an in-memory buffer, then base64-encode the raw bytes.
        png_buffer = io.BytesIO()
        picture.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
        return f"data:image/png;base64,{encoded}"

    batch["image_data_uri"] = [
        _as_png_data_uri(picture) for picture in batch["image"]
    ]
    return batch


# Encode each image as a data URI (processed in batches), then drop the
# original "image" column so only the encoded column is passed on.
hf_dataset = hf_dataset.map(
    pil_to_data_uri,
    batched=True,
).remove_columns("image")

# Upload the records; the mapping sends the "image_data_uri" column to the
# Argilla field named "image". Kept as the last expression so the cell still
# displays the returned object.
dataset.records.log(
    records=hf_dataset,
    mapping={"image_data_uri": "image"},
)



Sending records...: 100%|██████████| 1/1 [00:00<00:00,  6.14batch/s]


DatasetRecords(Dataset(id=UUID('053e2008-ac5c-4b48-ae32-6b3c828dfd3d') inserted_at=datetime.datetime(2024, 8, 14, 7, 31, 1, 401743) updated_at=datetime.datetime(2024, 8, 14, 7, 31, 1, 506073) name='image_classification_dataset' status='ready' guidelines='The goal of this task is to classify a given image of a handwritten digit into one of 10 classes representing integer values from 0 to 9, inclusively.' allow_extra_metadata=False distribution=OverlapTaskDistributionModel(strategy='overlap', min_submitted=1) workspace_id=UUID('735cae0d-eb08-45c3-ad79-0a11ad4dd2c2') last_activity_at=datetime.datetime(2024, 8, 14, 7, 31, 1, 506073)))

In [4]:
# WARNING: destructive cleanup -- deletes every dataset returned by
# client.datasets.list(), printing each one before it is removed.
# The loop variable is deliberately not called `dataset`, so the notebook-level
# `dataset` object created in the setup cell is not silently rebound to a
# dataset that has just been deleted.
for existing_dataset in client.datasets.list():
    print(existing_dataset)
    existing_dataset.delete()

Dataset(id=UUID('3453d11b-f50a-4f62-8d38-09ac3e99bc00') inserted_at=datetime.datetime(2024, 8, 13, 13, 3, 43, 20806) updated_at=datetime.datetime(2024, 8, 13, 13, 7, 28, 549728) name='triggers_20240813150340' status='ready' guidelines=None allow_extra_metadata=False distribution=OverlapTaskDistributionModel(strategy='overlap', min_submitted=1) workspace_id=UUID('735cae0d-eb08-45c3-ad79-0a11ad4dd2c2') last_activity_at=datetime.datetime(2024, 8, 13, 13, 7, 28, 549369))
Dataset(id=UUID('a9922a5c-b7df-4f69-818e-663ed6f13785') inserted_at=datetime.datetime(2024, 8, 14, 7, 21, 47, 871032) updated_at=datetime.datetime(2024, 8, 14, 7, 21, 48, 11720) name='image_classification_dataset' status='ready' guidelines='The goal of this task is to classify a given image of a handwritten digit into one of 10 classes representing integer values from 0 to 9, inclusively.' allow_extra_metadata=False distribution=OverlapTaskDistributionModel(strategy='overlap', min_submitted=1) workspace_id=UUID('735cae0d-e