In [None]:
!pip install spuco

In [None]:
import torch
import torchvision.transforms as T
from spuco.datasets import SpuCoMNIST, GroupLabeledDatasetWrapper, SpuriousFeatureDifficulty
from spuco.evaluate import Evaluator
from spuco.group_inference import Cluster, ClusterAlg
from spuco.models import model_factory
from spuco.robust_train import ERM, GroupBalanceBatchERM
from spuco.utils import Trainer
from torch.optim import Adam

#### Import & Augment Data

In [None]:
# Settings shared by the train and test splits.
DIFFICULTY = SpuriousFeatureDifficulty.MAGNITUDE_LARGE
SPUCO_STRENGTH = 0.995
# Five groups of digits; len(CLASSES) is what the models below use as num_classes.
CLASSES = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]


def _build_split(split):
    """Construct the SpuCoMNIST dataset for `split` and call `initialize()` on it."""
    dataset = SpuCoMNIST(
        root="./data/mnist",
        spurious_feature_difficulty=DIFFICULTY,
        spurious_correlation_strength=SPUCO_STRENGTH,
        classes=CLASSES,
        split=split,
    )
    dataset.initialize()
    return dataset


# Same order as before: the train split is built and initialized first, then test.
train_dataset = _build_split("train")
test_dataset = _build_split("test")

100%|██████████| 48004/48004 [00:07<00:00, 6494.87it/s]
100%|██████████| 10000/10000 [00:02<00:00, 4963.85it/s]


#### Create & Train ERM Model

In [None]:
# ERM baseline: a "lenet" model whose input shape is read from the first training sample.
erm_input_shape = train_dataset[0][0].shape
erm_model = model_factory(
    arch="lenet",
    input_shape=erm_input_shape,
    num_classes=len(CLASSES),
)

# Adam with its default hyper-parameters, bound to the ERM model's parameters.
erm_optimizer = Adam(params=erm_model.parameters())

erm_trainer = ERM(
    model=erm_model,
    trainset=train_dataset,
    optimizer=erm_optimizer,
    batch_size=64,
    num_epochs=1,
    verbose=True,
)



In [None]:
# Run the ERM training loop configured in the previous cell (num_epochs = 1).
erm_trainer.train()

Epoch 0: 100%|██████████| 751/751 [00:24<00:00, 30.69batch/s, accuracy=100.0%, loss=0.00277]


In [None]:
# Group-wise test accuracy of the ERM model. `Evaluator` comes from the notebook's
# import cell; the value returned by `evaluate()` is displayed as the cell output.
evaluator = Evaluator(
    testset = test_dataset,
    group_partition = test_dataset.group_partition,
    group_weights = train_dataset.group_weights,
    batch_size=64,
    model=erm_model,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:01<00:33,  1.41s/it]

Group (0, 0) Accuracy: 100.0


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:03<00:40,  1.74s/it]

Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:05<00:40,  1.84s/it]

Group (0, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:06<00:33,  1.61s/it]

Group (0, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:08<00:30,  1.54s/it]

Group (0, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:09<00:26,  1.42s/it]

Group (1, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:10<00:23,  1.28s/it]

Group (1, 1) Accuracy: 100.0


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:10<00:19,  1.12s/it]

Group (1, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:11<00:16,  1.01s/it]

Group (1, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:12<00:13,  1.07it/s]

Group (1, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:13<00:12,  1.13it/s]

Group (2, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:14<00:11,  1.17it/s]

Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:14<00:09,  1.20it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:15<00:08,  1.23it/s]

Group (2, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:16<00:08,  1.17it/s]

Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:17<00:08,  1.03it/s]

Group (3, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:19<00:08,  1.06s/it]

Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:20<00:07,  1.03s/it]

Group (3, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:20<00:05,  1.04it/s]

Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:21<00:04,  1.10it/s]

Group (3, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:22<00:03,  1.16it/s]

Group (4, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:23<00:02,  1.20it/s]

Group (4, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:23<00:01,  1.23it/s]

Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:24<00:00,  1.25it/s]

Group (4, 3) Accuracy: 0.0


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:25<00:00,  1.02s/it]

Group (4, 4) Accuracy: 100.0





{(0, 0): 100.0,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 0.0,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 0.0,
 (4, 4): 100.0}

#### Cluster Logits

In [None]:
def get_logits(model, dataset, batch_size = 100):
    """Run `model` over every sample of `dataset` and return the stacked outputs.

    Inference is performed under `torch.no_grad()` so no autograd graph is kept
    for the (potentially very large) result. When `model` is an `nn.Module` it is
    switched to eval mode for the duration of the call and restored to its
    previous mode afterwards, and each batch is moved to the device of the
    model's first parameter before the forward pass.

    Args:
        model: Callable applied to a batch tensor; typically a `torch.nn.Module`.
        dataset: Indexable object with `len()`; `dataset[i][0]` must be a tensor,
            and all such tensors must be stackable with `torch.stack`.
        batch_size: Number of samples per forward pass.

    Returns:
        A tensor made by concatenating the per-batch outputs along dim 0, in
        dataset order. It does not require grad.

    Note:
        The result is built with `torch.cat`, so the dataset must yield at
        least one batch.
    """
    is_module = isinstance(model, torch.nn.Module)
    device = None
    if is_module:
        was_training = model.training
        model.eval()
        # Parameter-free modules have no device to follow; leave batches as-is.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    output_logits = []
    try:
        with torch.no_grad():
            for l_idx in range(0, len(dataset), batch_size):
                r_idx = min(l_idx + batch_size, len(dataset))
                batch = torch.stack([dataset[i][0] for i in range(l_idx, r_idx)])
                if device is not None:
                    batch = batch.to(device)
                output_logits.append(model(batch))
    finally:
        # Leave the model in whatever mode the caller had it in.
        if is_module:
            model.train(was_training)
    return torch.cat(output_logits)

In [None]:
# Outputs of the trained ERM model on every training sample; used as clustering features below.
logits = get_logits(erm_model, train_dataset)

In [None]:
# Infer groups by running k-means on the ERM logits.
# NOTE(review): the progress output reads "Clustering class-wise", so this presumably
# yields `num_clusters` groups per class — confirm against the Cluster documentation.
num_clusters = 2
cluster = Cluster(
    Z=logits,
    class_labels=train_dataset.labels,
    cluster_alg=ClusterAlg.KMEANS,
    num_clusters=num_clusters,
    verbose=True,
)

group_partitions = cluster.infer_groups()

Clustering class-wise: 100%|██████████| 5/5 [00:00<00:00,  9.41it/s]


In [None]:
# Show the keys of the inferred partition (displayed below as 2-tuples).
group_partitions.keys()

dict_keys([(2, 0), (2, 1), (0, 0), (0, 1), (4, 0), (4, 1), (1, 0), (1, 1), (3, 0), (3, 1)])

#### Group Balance Batch ERM

In [None]:
# Fresh "lenet" model, trained with group-balanced batches built from the inferred groups.
balanced_input_shape = train_dataset[0][0].shape
balanced_model = model_factory(
    arch="lenet",
    input_shape=balanced_input_shape,
    num_classes=len(CLASSES),
)

# Adam with its default hyper-parameters, bound to the new model's parameters.
balanced_optimizer = Adam(params=balanced_model.parameters())

balanced_trainer = GroupBalanceBatchERM(
    model=balanced_model,
    trainset=train_dataset,
    group_partition=group_partitions,
    batch_size=64,
    optimizer=balanced_optimizer,
    num_epochs=1,
    verbose=True,
)



In [None]:
# Run the group-balanced training loop configured in the previous cell (num_epochs = 1).
balanced_trainer.train()

Epoch 0: 100%|██████████| 751/751 [00:24<00:00, 30.41batch/s, accuracy=100.0%, loss=0.0665]


In [None]:
# Group-wise test accuracy of the group-balanced model, using the same evaluation
# setup as for the ERM model. `Evaluator` comes from the notebook's import cell.
evaluator = Evaluator(
    testset = test_dataset,
    group_partition = test_dataset.group_partition,
    group_weights = train_dataset.group_weights,
    batch_size=64,
    model=balanced_model,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:20,  1.16it/s]

Group (0, 0) Accuracy: 100.0


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:01<00:18,  1.24it/s]

Group (0, 1) Accuracy: 89.83451536643027


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:02<00:20,  1.07it/s]

Group (0, 2) Accuracy: 90.0709219858156


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:03<00:21,  1.01s/it]

Group (0, 3) Accuracy: 80.61465721040189


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:05<00:21,  1.08s/it]

Group (0, 4) Accuracy: 78.25059101654847


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:05<00:19,  1.04s/it]

Group (1, 0) Accuracy: 72.37163814180929


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:06<00:17,  1.05it/s]

Group (1, 1) Accuracy: 98.2885085574572


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:07<00:15,  1.12it/s]

Group (1, 2) Accuracy: 79.65686274509804


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:08<00:13,  1.17it/s]

Group (1, 3) Accuracy: 60.78431372549019


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:09<00:12,  1.19it/s]

Group (1, 4) Accuracy: 59.55882352941177


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:09<00:11,  1.22it/s]

Group (2, 0) Accuracy: 76.26666666666667


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:10<00:10,  1.25it/s]

Group (2, 1) Accuracy: 81.86666666666666


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:11<00:09,  1.26it/s]

Group (2, 2) Accuracy: 95.2


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:12<00:08,  1.26it/s]

Group (2, 3) Accuracy: 71.73333333333333


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:12<00:07,  1.27it/s]

Group (2, 4) Accuracy: 53.475935828877006


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:13<00:07,  1.28it/s]

Group (3, 0) Accuracy: 63.31658291457286


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:14<00:06,  1.24it/s]

Group (3, 1) Accuracy: 54.659949622166245


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:15<00:05,  1.26it/s]

Group (3, 2) Accuracy: 89.92443324937028


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:16<00:05,  1.15it/s]

Group (3, 3) Accuracy: 98.99244332493703


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:17<00:04,  1.05it/s]

Group (3, 4) Accuracy: 81.10831234256926


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:18<00:04,  1.04s/it]

Group (4, 0) Accuracy: 92.9471032745592


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:19<00:03,  1.01s/it]

Group (4, 1) Accuracy: 72.79596977329975


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:20<00:01,  1.06it/s]

Group (4, 2) Accuracy: 46.85138539042821


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:21<00:00,  1.13it/s]

Group (4, 3) Accuracy: 58.333333333333336


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:22<00:00,  1.13it/s]

Group (4, 4) Accuracy: 98.23232323232324





{(0, 0): 100.0,
 (0, 1): 89.83451536643027,
 (0, 2): 90.0709219858156,
 (0, 3): 80.61465721040189,
 (0, 4): 78.25059101654847,
 (1, 0): 72.37163814180929,
 (1, 1): 98.2885085574572,
 (1, 2): 79.65686274509804,
 (1, 3): 60.78431372549019,
 (1, 4): 59.55882352941177,
 (2, 0): 76.26666666666667,
 (2, 1): 81.86666666666666,
 (2, 2): 95.2,
 (2, 3): 71.73333333333333,
 (2, 4): 53.475935828877006,
 (3, 0): 63.31658291457286,
 (3, 1): 54.659949622166245,
 (3, 2): 89.92443324937028,
 (3, 3): 98.99244332493703,
 (3, 4): 81.10831234256926,
 (4, 0): 92.9471032745592,
 (4, 1): 72.79596977329975,
 (4, 2): 46.85138539042821,
 (4, 3): 58.333333333333336,
 (4, 4): 98.23232323232324}