In [5]:
import src.ds_utils as ds_utils
import os
import torch
import src.pytorch_datasets as pytorch_datasets
import numpy as np

# CIFAR

In [2]:
ds_name = 'cifar'
config = f"dataset_configs/{ds_name}.yaml"
hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons")

unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save(unlabeled_split, path)

only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)
path = os.path.join('index_files', f'{ds_name}.pt')
torch.save(only_val_split, path)


Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default loss_vec_file: None
Using default indices_file: None
Using default val_beton: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 98/98 [00:07<00:00, 13.63it/s]


In [18]:
# CIFAR 25%
# VAL: 1/5, UNLABELED: 3/5, TRAIN: 1/5
config = f"dataset_configs/cifar.yaml"

ds_name = 'cifar_0.25'
hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons")

unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=4)
unlabeled_split = {
    'val_indices': unlabeled_split['val_indices'],
    'train_indices': unlabeled_split['unlabeled_indices'],
    'unlabeled_indices': unlabeled_split['train_indices'],
}
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save(unlabeled_split, path)



Using default os_cache: False
Using default quasi_random: False
Using default val_aug: None
Using default indices_file: None
Using default val_beton: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 98/98 [00:00<00:00, 286.81it/s]


# CIFAR-100

In [3]:
ds_name = 'cifar100'
config = f"dataset_configs/{ds_name}.yaml"
hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons")

unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save(unlabeled_split, path)

only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)
path = os.path.join('index_files', f'{ds_name}.pt')
torch.save(only_val_split, path)


Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default loss_vec_file: None
Using default indices_file: None
Using default val_beton: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 98/98 [00:01<00:00, 95.10it/s]


## ImageNet

In [13]:
ds_name = 'imagenet'
config = f"dataset_configs/{ds_name}.yaml"
hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons")

unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save(unlabeled_split, path)

only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)
path = os.path.join('index_files', f'{ds_name}.pt')
torch.save(only_val_split, path)


Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default loss_vec_file: None
Using default val_beton: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 1252/1252 [12:43<00:00,  1.64it/s]


# Super CIFAR-100

In [2]:
ds_name = 'supercifar100'
config = f"dataset_configs/{ds_name}.yaml"
hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons")

unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save(unlabeled_split, path)

only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)
path = os.path.join('index_files', f'{ds_name}.pt')
torch.save(only_val_split, path)


Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default indices_file: None
Using default val_beton: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 98/98 [00:09<00:00, 10.38it/s]


# Spurious CIFAR 100

In [12]:
import numpy as np
ds = pytorch_datasets.SuperCIFAR100(root="/mnt/nfs/home/saachij/datasets/cifar100", train=True)
config = f"dataset_configs/supercifar100.yaml"
hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons")

subclass_targets = np.array(ds.subclass_targets)
targets = np.array(ds.targets)
classes_to_drop = []
# for c in range(20):
#     classes_to_drop.append(np.unique(np.array(subclass_targets[targets == c]))[0])
classes_to_drop = [4, 73, 54, 10, 51, 40, 84, 18, 3, 12, 33, 38, 64, 45, 2, 44, 80, 96, 13, 81]

def split_spurious_cifar(orig_indices):
    new_train_indices = []
    for c in range(100):
        mask = subclass_targets[orig_indices] == c
        if c in classes_to_drop:
            new_train_indices.append(orig_indices[mask][::4])
        else:
            new_train_indices.append(orig_indices[mask])
    new_train_indices = torch.cat(new_train_indices)
    return new_train_indices

unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)
unlabeled_split['train_indices'] = split_spurious_cifar(unlabeled_split['train_indices'])
unlabeled_split['unlabeled_indices'] = unlabeled_split['unlabeled_indices']
unlabeled_split['classes_to_drop'] = classes_to_drop

only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)
subsampled = split_spurious_cifar(only_val_split['train_indices'])
only_val_split['train_indices'] = subsampled
only_val_split['classes_to_drop'] = classes_to_drop

ds_name = 'spurious_cifar100'
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save(unlabeled_split, path)

path = os.path.join('index_files', f'{ds_name}.pt')
torch.save(only_val_split, path)




Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default indices_file: None
Using default val_beton: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 98/98 [00:00<00:00, 287.49it/s]


In [9]:
np.arange(10)[::0]

ValueError: slice step cannot be zero

# CelebA

In [36]:
config = "dataset_configs/celeba.yaml"

hparams, train_labels, train_spuriouses = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons", include_spurious=True)
hparams, val_labels, val_spuriouses = ds_utils.get_all_beton_labels(config, 'val', "/mnt/cfs/projects/correlated_errors/betons", include_spurious=True)

# spuriouses is 1 if blond, 2 if black hair, 0 if neither
# train_labels is 0 if female, 1 if male

Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default indices_file: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 318/318 [00:22<00:00, 14.35it/s]


Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default indices_file: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 39/39 [00:03<00:00,  9.84it/s]


In [37]:

def get_celeba_split(labels, spuriouses, majority_multiplier=7):
    all_indices = np.arange(len(labels))

    female = labels == 0
    male = labels == 1
    blond = spuriouses == 1
    black = spuriouses == 2

    female_blond = all_indices[female & blond]
    female_black = all_indices[female & black]
    male_blond = all_indices[male & blond]
    male_black = all_indices[male & black]
    print(len(female_blond), len(female_black), len(male_black), len(male_blond))

    smallest_minority = len(male_blond)
    smallest_majority = len(male_black)
    minority = min(smallest_minority, int(smallest_majority/majority_multiplier))
    
    majority = minority*majority_multiplier
    train_indices = [female_blond[:majority], female_black[:minority], male_black[:majority], male_blond[:minority]]
    print([len(u) for u in train_indices])
    return np.concatenate(train_indices)

In [38]:
# val
mult = 5

ds_name = f"celeba_1_{mult}"

val_indices = get_celeba_split(val_labels, val_spuriouses, majority_multiplier=1)

# with unlabeled
unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], unlabeled_split_amt=2)
orig_train_indices = unlabeled_split['train_indices']
train_indices = orig_train_indices[get_celeba_split(
    train_labels[orig_train_indices], 
    train_spuriouses[orig_train_indices],
    majority_multiplier=mult,
)]

orig_unlabeled_indices = unlabeled_split['unlabeled_indices']
unlabeled_indices = orig_unlabeled_indices[get_celeba_split(
    train_labels[orig_unlabeled_indices], 
    train_spuriouses[orig_unlabeled_indices],
    majority_multiplier=mult,
)]
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save({
    'val_indices': val_indices,
    'train_indices': train_indices,
    'unlabeled_indices': unlabeled_indices
}, path)

# without unlabeled
path = os.path.join('index_files', f'{ds_name}.pt')
torch.save({
    'val_indices': val_indices,
    'train_indices': get_celeba_split(train_labels, train_spuriouses, majority_multiplier=mult),
}, path)

592 680 3463 2464
[592, 680, 2464, 2464]
2035 2538 17003 10109
[2035, 2538, 17000, 3400]
2002 2554 16809 10121
[2002, 2554, 16805, 3361]
4037 5092 33812 20230
[4037, 5092, 33810, 6762]


# CelebA Age

In [2]:
config = "dataset_configs/celeba.yaml"

hparams, train_labels, train_spuriouses = ds_utils.get_all_beton_labels(config, 'train', "/mnt/cfs/projects/correlated_errors/betons", include_spurious=True)
hparams, val_labels, val_spuriouses = ds_utils.get_all_beton_labels(config, 'val', "/mnt/cfs/projects/correlated_errors/betons", include_spurious=True)


Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default indices_file: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 318/318 [00:25<00:00, 12.71it/s]


Using default os_cache: False
Using default quasi_random: True
Using default val_aug: None
Using default indices_file: None
Using default unlabeled_beton: None
Using default loss_upweight: 5
Using default bce: False


100%|██████████| 39/39 [00:03<00:00, 10.48it/s]


In [12]:
# spurious: 
# old is male
# young is female

def get_celeba_split(labels, spuriouses, majority_multiplier=7):
    all_indices = np.arange(len(labels))

    old = labels == 0
    young = labels == 1
    female = spuriouses == 0
    male = spuriouses == 1

    old_male = all_indices[old & male]
    old_female = all_indices[old & female]
    young_male = all_indices[young & male]
    young_female = all_indices[young & female]

    print("OLD MALE", len(old_male), "OLD FEMALE", len(old_female), "YOUNG MALE", len(young_male), "YOUNG FEMALE", len(young_female))

    smallest_minority = len(old_female)
    smallest_majority = len(old_male)
    minority = min(smallest_minority, int(smallest_majority/majority_multiplier))
    
    majority = minority*majority_multiplier
    train_indices = [old_female[:minority], old_male[:majority], young_female[:majority], young_male[:minority]]
    print([len(u) for u in train_indices])
    return np.concatenate(train_indices)

In [13]:
# val
mult = 4

ds_name = f"celeba_age_1_{mult}"

val_indices = get_celeba_split(val_labels, val_spuriouses, majority_multiplier=1)

# with unlabeled
unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], unlabeled_split_amt=2)
orig_train_indices = unlabeled_split['train_indices']
train_indices = orig_train_indices[get_celeba_split(
    train_labels[orig_train_indices], 
    train_spuriouses[orig_train_indices],
    majority_multiplier=mult,
)]

orig_unlabeled_indices = unlabeled_split['unlabeled_indices']
unlabeled_indices = orig_unlabeled_indices[get_celeba_split(
    train_labels[orig_unlabeled_indices], 
    train_spuriouses[orig_unlabeled_indices],
    majority_multiplier=mult,
)]
path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')
torch.save({
    'val_indices': val_indices,
    'train_indices': train_indices,
    'unlabeled_indices': unlabeled_indices
}, path)

# without unlabeled
path = os.path.join('index_files', f'{ds_name}.pt')
torch.save({
    'val_indices': val_indices,
    'train_indices': get_celeba_split(train_labels, train_spuriouses, majority_multiplier=mult),
}, path)

OLD MALE 3240 OLD FEMALE 1795 YOUNG MALE 5218 YOUNG FEMALE 9614
[1795, 1795, 1795, 1795]
OLD MALE 12447 OLD FEMALE 5544 YOUNG MALE 21756 YOUNG FEMALE 41638
[3111, 12444, 12444, 3111]
OLD MALE 12368 OLD FEMALE 5623 YOUNG MALE 21690 YOUNG FEMALE 41704
[3092, 12368, 12368, 3092]
OLD MALE 24815 OLD FEMALE 11167 YOUNG MALE 43446 YOUNG FEMALE 83342
[6203, 24812, 24812, 6203]
