# Generate the subsets of the dataset for calculating the LDS
To calculate the LDS (Linear Datamodeling Score), we need a large number of diffusion models, each trained on a different subset of the full dataset (in our case, CIFAR-10).

DTRAK uses 64 subsets, each covering 50% of the original dataset, and fine-tunes 9 Stable Diffusion models on each subset. Unfortunately, this is infeasible on our hardware, so we instead use 32 subsets with 50% coverage and train a single diffusion model on each.

In [1]:
from utils.config import CIFAR_10_Config, Project_Config
# Name of the dataset; also the folder name under ./datasets/ used below.
DATASET_NAME = "cifar-10"

# Project configuration objects (dataset handles and path conventions).
dataset_config = CIFAR_10_Config()
project_config = Project_Config()

# When True, the export cell writes every original image to disk as a PNG.
SAVE_ORIGINAL_IMAGES_TO_DISK = True

  from .autonotebook import tqdm as notebook_tqdm
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")


In [2]:
#dataset_config.dataset.save_to_disk("./datasets/cifar-10/all")

In [3]:
# Number of independently sampled training subsets (one model is trained per subset).
NUM_SUBSETS = 32

NUM_CLASSES = 10
# Images are later numbered 0..ROWS_PER_CLASS-1 inside each class folder and
# subset indices are drawn from range(ROWS_PER_CLASS), so the dataset must split
# evenly into classes. Integer division avoids the float round-trip, and the
# remainder check makes an uneven split fail loudly instead of being truncated.
ROWS_PER_CLASS, _leftover_rows = divmod(dataset_config.dataset.num_rows, NUM_CLASSES)
assert _leftover_rows == 0, (
    f"{dataset_config.dataset.num_rows} rows do not split evenly into {NUM_CLASSES} classes")

# Fraction of every class that goes into each subset.
SUBSET_SIZE_ALPHA = 0.5
SUBSET_SIZE_TOTAL = int(dataset_config.dataset.num_rows * SUBSET_SIZE_ALPHA)
ROWS_PER_CLASS_SUBSET = int(ROWS_PER_CLASS * SUBSET_SIZE_ALPHA)

# Taking the same number of rows from every class must add up to the subset size.
assert SUBSET_SIZE_TOTAL == ROWS_PER_CLASS_SUBSET * NUM_CLASSES
print(f"Rows per class per subset {ROWS_PER_CLASS_SUBSET}")

Rows per class per subset 2500


In [4]:
# Sort by class so the rows of each class are contiguous and have a stable order.
dataset_config.dataset = dataset_config.dataset.sort(column_names=dataset_config.caption_column)

from pathlib import Path

# Root of the local dataset folder, e.g. ./datasets/cifar-10/ (ends with the separator).
base_path = ("." + project_config.folder_symbol +
             "datasets" + project_config.folder_symbol +
             DATASET_NAME + project_config.folder_symbol)

# Images are written to <base_path>train/<class_caption>/<index>.png.
# base_path already ends with folder_symbol, so no second separator is added.
base_path_images = base_path + "train" + project_config.folder_symbol

Path(base_path_images).mkdir(parents=True, exist_ok=True)

for class_caption in dataset_config.class_captions:
    Path(base_path_images + class_caption + project_config.folder_symbol).mkdir(parents=True, exist_ok=True)

if SAVE_ORIGINAL_IMAGES_TO_DISK:
    # Number the images of each class 0, 1, 2, ... with a per-class counter.
    # A positional counter that wraps every ROWS_PER_CLASS rows would silently
    # misnumber (and overwrite) files as soon as one class is not exactly
    # ROWS_PER_CLASS rows long or the rows are not perfectly grouped.
    images_per_class = {class_caption: 0 for class_caption in dataset_config.class_captions}
    for item in dataset_config.dataset:
        caption = item[dataset_config.caption_column]
        i = images_per_class.get(caption, 0)
        item[dataset_config.image_column].save(f"{base_path_images}{caption}{project_config.folder_symbol}{i}.png")
        images_per_class[caption] = i + 1

    # The README cell samples file indices from range(ROWS_PER_CLASS) for every
    # class, so each class folder must contain exactly that many images.
    for class_caption in dataset_config.class_captions:
        assert images_per_class[class_caption] == ROWS_PER_CLASS, (
            f"class {class_caption!r} has {images_per_class[class_caption]} images, "
            f"expected {ROWS_PER_CLASS}")


In [5]:
import numpy as np

# Fixed seed so the generated subsets (and therefore README.md) are reproducible.
rng = np.random.default_rng(42)

# Hugging Face reads named dataset configurations from the YAML front matter of
# the dataset's README.md, see
# https://huggingface.co/docs/hub/en/datasets-manual-configuration
# Shape of the front matter, for reference:
#
# ---
# configs:
# - config_name: subset_0
#   data_files:
#   - split: train
#     path:
#     - "data/*.csv"
# ---

file_lines = ["---", "configs:"]

for subset in range(NUM_SUBSETS):
    file_lines.extend([
        f"  - config_name: subset_{subset}",
        "    drop_labels: false",
        "    data_files:",
        "      - split: train",
        "        path:",
    ])
    for class_caption in dataset_config.class_captions:
        # ROWS_PER_CLASS_SUBSET distinct image indices of this class; the call
        # order (subset, then class) fixes which files the seeded rng selects.
        class_i = rng.choice(ROWS_PER_CLASS, ROWS_PER_CLASS_SUBSET, replace=False)
        # data_files paths always use "/" regardless of OS; a "\" separator
        # would also be parsed as an escape inside a double-quoted YAML string.
        file_lines.extend(f'          - "train/{class_caption}/{i}.png"' for i in class_i)

file_lines.append("---")

print(file_lines[0])

with open(base_path + "README.md", "w", encoding="utf-8") as f:
    f.writelines(line + "\n" for line in file_lines)
    

---



In [6]:
import datasets

# Load one of the generated configurations back through the README front matter
# to check that the subset definition is usable.
test_ds = datasets.load_dataset(path=base_path, name="subset_0", split="train")

Downloading data: 100%|██████████| 25000/25000 [00:00<00:00, 60452.37files/s] 
Generating train split: 25000 examples [00:00, 26216.73 examples/s]


In [7]:
# Display the label feature of the loaded subset; the class names presumably
# come from the train/<class_caption>/ folder names (drop_labels: false) — verify.
test_ds.features['label']

ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)

In [8]:
# Display the first example of the subset (an image and its integer class label).
test_ds[0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
 'label': 0}

In [9]:
# Alternative kept for reference: selecting each subset with Dataset.select and
# saving it with save_to_disk would be simpler, but the .arrow format that
# save_to_disk produces is not supported by train_text_to_image.
# NOTE: if re-enabled after the README cell, this reuses `rng` after it has
# already been consumed, so it would NOT reproduce the subsets in README.md.
# The per-class lower-bound check uses >= because local index 0 can be sampled.
"""
for subset in range(NUM_SUBSETS):
    subset_indices = np.zeros(SUBSET_SIZE_TOTAL, np.int32)

    for i in range(NUM_CLASSES):
        class_i = rng.choice(ROWS_PER_CLASS,ROWS_PER_CLASS_SUBSET, replace=False)
        index_start = i*ROWS_PER_CLASS_SUBSET
        index_end = (i+1)*ROWS_PER_CLASS_SUBSET
        subset_indices[index_start:index_end] = class_i + (i*ROWS_PER_CLASS)

    for i in range(NUM_CLASSES):
        assert(subset_indices[ROWS_PER_CLASS_SUBSET*i]>=(ROWS_PER_CLASS*i))

    dataset_subset = dataset_config.dataset.select(subset_indices)
    for class_caption in dataset_config.class_captions:
        assert(dataset_subset[dataset_config.caption_column].count(class_caption) == ROWS_PER_CLASS_SUBSET)

    dataset_subset.save_to_disk(f"./datasets/cifar-10/subset-{subset}")
"""

'\nfor subset in range(NUM_SUBSETS):\n    subset_indices = np.zeros(SUBSET_SIZE_TOTAL, np.int32)\n\n    for i in range(NUM_CLASSES):\n        class_i = rng.choice(ROWS_PER_CLASS,ROWS_PER_CLASS_SUBSET, replace=False)\n        index_start = i*ROWS_PER_CLASS_SUBSET\n        index_end = (i+1)*ROWS_PER_CLASS_SUBSET\n        subset_indices[index_start:index_end] = class_i + (i*ROWS_PER_CLASS)\n\n    for i in range(NUM_CLASSES):\n        assert(subset_indices[ROWS_PER_CLASS_SUBSET*i]>(ROWS_PER_CLASS*i))\n\n    dataset_subset = dataset_config.dataset.select(subset_indices)\n    for class_caption in dataset_config.class_captions:\n        assert(dataset_subset[dataset_config.caption_column].count(class_caption) == ROWS_PER_CLASS_SUBSET)\n\n    dataset_subset.save_to_disk(f"./datasets/cifar-10/subset-{subset}")\n'