In [1]:
import torchvision
import os
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
import torch
import src.pytorch_datasets as pytorch_datasets
from src import ffcv_utils
import yaml
from src.config_parsing import ffcv_read_check_override_config
import pprint
from src.ffcv_utils import get_training_loaders
from src.pytorch_datasets import IndexedDataset

In [2]:
# Root directory under which every .beton file in this notebook is written.
# Set the BETON_ROOT environment variable to redirect output on another
# machine instead of editing this cell. Locations used previously:
#   /mnt/cfs/home/saachij/betons
#   /home/gridsan/groups/CorrErrs/betons
BETON_ROOT = os.environ.get("BETON_ROOT", "/mnt/cfs/projects/correlated_errors/betons")

In [3]:
def write_betons(ds_name, train_ds, test_ds, val_ds=None, max_resolution=None, add_spurious=False):
    """Serialize dataset splits to FFCV ``.beton`` files under ``BETON_ROOT/<ds_name>/``.

    Every split is wrapped in ``IndexedDataset`` and written to
    ``<ds_name>_<split>.beton``. The declared fields are ``image``, ``label``,
    optionally ``spurious``, and ``index``, in that order.

    Args:
        ds_name: Name used for both the output directory and the file prefix.
        train_ds: Dataset written as the ``train`` split.
        test_ds: Dataset written as the ``test`` split.
        val_ds: Optional dataset written as the ``val`` split; skipped when None.
        max_resolution: When not None, forwarded to
            ``RGBImageField(max_resolution=...)``; otherwise the field is built
            with no arguments.
        add_spurious: When truthy, an ``IntField`` named ``spurious`` is declared
            between ``label`` and ``index``.
    """
    out_dir = os.path.join(BETON_ROOT, ds_name)
    os.makedirs(out_dir, exist_ok=True)

    splits = [('train', train_ds), ('test', test_ds)]
    if val_ds is not None:
        splits.append(('val', val_ds))

    # Only pass max_resolution through when the caller actually supplied one.
    image_kwargs = {} if max_resolution is None else {'max_resolution': max_resolution}

    for split_name, split_ds in splits:
        indexed_ds = IndexedDataset(split_ds)
        beton_path = os.path.join(out_dir, f"{ds_name}_{split_name}.beton")

        # Insertion order fixes the field order: image, label, [spurious], index.
        # NOTE(review): presumably this must match the tuple layout produced by
        # IndexedDataset — confirm against src.pytorch_datasets.
        fields = {
            'image': RGBImageField(**image_kwargs),
            'label': IntField(),
        }
        if add_spurious:
            fields['spurious'] = IntField()
        fields['index'] = IntField()

        DatasetWriter(beton_path, fields).from_indexed_dataset(indexed_ds)
        
def write_celeba_betons(ds_name, train_ds, test_ds, val_ds=None, max_resolution=75):
    """Write CelebA-style splits (with a ``spurious`` field) to ``.beton`` files.

    Thin wrapper around ``write_betons`` that always declares the ``spurious``
    field and caps the stored image resolution. The output layout and field
    order (``image``, ``label``, ``spurious``, ``index``) are those of
    ``write_betons(..., add_spurious=True)``.

    Args:
        ds_name: Name used for both the output directory and the file prefix.
        train_ds: Dataset written as the ``train`` split.
        test_ds: Dataset written as the ``test`` split.
        val_ds: Optional dataset written as the ``val`` split; skipped when None.
        max_resolution: Value forwarded to ``RGBImageField(max_resolution=...)``.
            Defaults to 75, the value this function previously hard-coded.
    """
    # Delegate rather than duplicate the split/writer loop so the two writers
    # cannot drift apart.
    write_betons(
        ds_name,
        train_ds,
        test_ds,
        val_ds=val_ds,
        max_resolution=max_resolution,
        add_spurious=True,
    )
        
def test_dataset(config, pipeline_subset=['image', 'label', 'index']):
    with open(f"dataset_configs/{config}", 'r') as file:
        hparams = yaml.safe_load(file)
    hparams = ffcv_read_check_override_config(hparams)
    print("=========== Current Config ==================")
    pprint.pprint(hparams, indent=4)
    train_loader, val_loader, test_loader = get_training_loaders(hparams, pipeline_subset=pipeline_subset)
    return train_loader, val_loader, test_loader

## CIFAR10

In [4]:
# Build the CIFAR10 train split first, then the test split, from the same root.
orig_ds_path = "/mnt/nfs/home/saachij/datasets/cifar"
train_ds, test_ds = (
    torchvision.datasets.CIFAR10(orig_ds_path, train=is_train)
    for is_train in (True, False)
)


In [5]:
# Write the 'cifar' train/test betons; no validation split is supplied.
write_betons(ds_name='cifar', train_ds=train_ds, test_ds=test_ds, val_ds=None)

100%|██████████| 50000/50000 [00:00<00:00, 99706.32it/s] 
100%|██████████| 10000/10000 [00:00<00:00, 99715.05it/s]


## CIFAR100

In [6]:
# Build the CIFAR100 train split first, then the test split, from the same root.
orig_ds_path = "/mnt/nfs/home/saachij/datasets/cifar100"
train_ds, test_ds = (
    torchvision.datasets.CIFAR100(orig_ds_path, train=is_train)
    for is_train in (True, False)
)

In [7]:
# Write the 'cifar100' train/test betons; no validation split is supplied.
write_betons(ds_name='cifar100', train_ds=train_ds, test_ds=test_ds, val_ds=None)

100%|██████████| 50000/50000 [00:00<00:00, 82954.73it/s] 
100%|██████████| 10000/10000 [00:00<00:00, 33166.62it/s]


## Super CIFAR100

In [8]:
# SuperCIFAR100 is read from the same root directory as the CIFAR100 cell.
orig_ds_path = "/mnt/nfs/home/saachij/datasets/cifar100"
train_ds, test_ds = (
    pytorch_datasets.SuperCIFAR100(root=orig_ds_path, train=is_train)
    for is_train in (True, False)
)


In [9]:
# Write the 'supercifar100' betons with the extra 'spurious' field declared.
write_betons(
    ds_name='supercifar100',
    train_ds=train_ds,
    test_ds=test_ds,
    val_ds=None,
    add_spurious=True,
)


100%|██████████| 50000/50000 [00:00<00:00, 99490.16it/s] 
100%|██████████| 10000/10000 [00:00<00:00, 33196.26it/s]


## CelebA

In [8]:
# Build the train / valid / test splits in that order, then write all three.
train_ds, val_ds, test_ds = (
    pytorch_datasets.SpuriousAttributeCelebA(root="/mnt/nfs/datasets/celeba", split=split_name)
    for split_name in ('train', 'valid', 'test')
)
write_celeba_betons('celeba', train_ds, test_ds, val_ds)

100%|██████████| 162770/162770 [00:28<00:00, 5653.40it/s]
100%|██████████| 19962/19962 [00:05<00:00, 3895.33it/s]
100%|██████████| 19867/19867 [00:04<00:00, 4037.00it/s]


In [3]:
# Build the train / valid / test splits in that order, then write all three.
train_ds, val_ds, test_ds = (
    pytorch_datasets.SpuriousAttributeCelebAAge(root="/mnt/nfs/datasets/celeba", split=split_name)
    for split_name in ('train', 'valid', 'test')
)
write_celeba_betons('celeba_age', train_ds, test_ds, val_ds)