In [1]:
from datasets import load_dataset, load_from_disk, Dataset
from collections import Counter
import pandas as pd
from sklearn.model_selection import train_test_split

# Location of the RLAIF-V parquet shards on disk.
data_root = "./data/RLAIF-V-Dataset"

# Collect the shard paths to read; range(1) means only shard 000 is loaded.
data_file = []
for i in range(1):
    data_file.append(f'{data_root}/RLAIF-V-Dataset_{i:03d}.parquet')

# Load the listed parquet files and show the resulting dataset summary.
data = load_dataset('parquet', data_files=data_file)
print(data)


DatasetDict({
    train: Dataset({
        features: ['ds_name', 'image', 'question', 'chosen', 'rejected', 'origin_dataset', 'origin_split', 'idx', 'image_path'],
        num_rows: 6814
    })
})


In [2]:
# Tally how many training rows carry each `origin_dataset` value.
origin_dataset = data['train']['origin_dataset']
counts = Counter(origin_dataset)
print(counts)

# Turn the raw tallies into fractions of the whole split.
# NOTE(review): the name `propotions` (sic) is kept unchanged because it is a
# notebook-level variable that other cells may reference.
total_len = len(origin_dataset)
propotions = {}
for ds, num in counts.items():
    propotions[ds] = num / total_len

# Print one "<source>: <percent>" line per source, framed by separator rows.
separator = '=================================='
report_lines = [f'{ds}: {num:.2%}' for ds, num in propotions.items()]
print(separator, *report_lines, separator, sep='\n')

Counter({'VQAv2': 1252, 'OK-VQA': 1249, 'LCS-558K': 1241, 'COCO': 1226, 'GQA': 471, 'OCR-VQA': 259, 'TextVQA': 248, 'sharegpt4v-web-celebrity': 224, 'sharegpt4v-textvqa': 218, 'sharegpt4v-web-landmark': 214, 'sharegpt4v-wikiart': 212})
OK-VQA: 18.33%
TextVQA: 3.64%
COCO: 17.99%
LCS-558K: 18.21%
sharegpt4v-wikiart: 3.11%
VQAv2: 18.37%
sharegpt4v-textvqa: 3.20%
sharegpt4v-web-celebrity: 3.29%
sharegpt4v-web-landmark: 3.14%
GQA: 6.91%
OCR-VQA: 3.80%


In [3]:
# Total number of rows to draw; they are divided into four equal subsets below.
sample_num = 250*4
df = pd.DataFrame(data['train'])


def _split_by_origin(frame, held_out):
    """Split ``frame`` in two, stratified on its ``origin_dataset`` column.

    ``held_out`` is forwarded as ``test_size`` and the seed is fixed at 42.
    Returns the two frames produced by ``train_test_split`` unchanged.
    """
    return train_test_split(
        frame,
        test_size=held_out,
        stratify=frame['origin_dataset'],
        random_state=42,
    )


# Carve a stratified pool of `sample_num` rows out of the full frame, halve it,
# then halve each half again to obtain the four subsets.
_, df_sampled0 = _split_by_origin(df, sample_num)
df_sampled01, df_sampled02 = _split_by_origin(df_sampled0, sample_num//2)
df_sampled1, df_sampled2 = _split_by_origin(df_sampled01, sample_num//4)
df_sampled3, df_sampled4 = _split_by_origin(df_sampled02, sample_num//4)

# Report the per-source share inside every subset so the stratification can be
# checked by eye against the proportions of the full dataset.
banner = "====================================="
for df_sampled in (df_sampled1, df_sampled2, df_sampled3, df_sampled4):
    origins = df_sampled['origin_dataset']
    counts = Counter(origins)
    total_count = len(origins)
    proportions = {ds: count / total_count for ds, count in counts.items()}
    print(banner)
    print("Sampled dataset proportions: total count:", total_count)
    for ds, proportion in proportions.items():
        print(f"{ds}: {proportion:.2%}")
    print(banner)

Sampled dataset proportions: total count: 250
COCO: 18.00%
VQAv2: 18.40%
LCS-558K: 18.00%
sharegpt4v-textvqa: 3.20%
TextVQA: 3.60%
OK-VQA: 18.40%
GQA: 6.80%
OCR-VQA: 4.00%
sharegpt4v-web-landmark: 3.20%
sharegpt4v-web-celebrity: 3.20%
sharegpt4v-wikiart: 3.20%
Sampled dataset proportions: total count: 250
VQAv2: 18.40%
TextVQA: 3.60%
sharegpt4v-web-landmark: 3.20%
LCS-558K: 18.40%
COCO: 18.00%
OK-VQA: 18.40%
sharegpt4v-wikiart: 3.20%
OCR-VQA: 3.60%
GQA: 6.80%
sharegpt4v-web-celebrity: 3.20%
sharegpt4v-textvqa: 3.20%
Sampled dataset proportions: total count: 250
COCO: 18.00%
OCR-VQA: 3.60%
VQAv2: 18.40%
GQA: 7.20%
LCS-558K: 18.40%
sharegpt4v-web-landmark: 3.20%
OK-VQA: 18.00%
TextVQA: 3.60%
sharegpt4v-textvqa: 3.20%
sharegpt4v-web-celebrity: 3.20%
sharegpt4v-wikiart: 3.20%
Sampled dataset proportions: total count: 250
VQAv2: 18.40%
OK-VQA: 18.40%
COCO: 18.00%
sharegpt4v-web-celebrity: 3.60%
LCS-558K: 18.00%
OCR-VQA: 4.00%
GQA: 6.80%
TextVQA: 3.60%
sharegpt4v-web-landmark: 3.20%
sharegpt

In [4]:
# Output directory for each sampled subset, keyed by the subset's variable name.
save_paths = {
    'df_sampled1': './data/RLAIF_Sample/subset1',
    'df_sampled2': './data/RLAIF_Sample/subset2',
    'df_sampled3': './data/RLAIF_Sample/subset3',
    'df_sampled4': './data/RLAIF_Sample/subset4',
}

# Pair every frame with its key explicitly instead of relying on the positional
# alignment between the dict's key order and a separate list.
sampled_frames = {
    'df_sampled1': df_sampled1,
    'df_sampled2': df_sampled2,
    'df_sampled3': df_sampled3,
    'df_sampled4': df_sampled4,
}

# The loop deliberately does not reuse the name `df`, so the full DataFrame
# built in the sampling cell is not overwritten by the last subset.
for subset_name, out_dir in save_paths.items():
    # The frames keep the shuffled, non-range index left by train_test_split.
    # Without preserve_index=False that index is written into the dataset as an
    # extra `__index_level_0__` column.
    dataset = Dataset.from_pandas(sampled_frames[subset_name], preserve_index=False)
    dataset.save_to_disk(out_dir)

Saving the dataset (0/1 shards):   0%|          | 0/250 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/250 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/250 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/250 [00:00<?, ? examples/s]