**Task:**

Deep neural networks often exploit non-predictive features that are spuriously correlated with class labels, which leads to poor performance on groups of examples that lack those features. Using the SpuCo package (see the SpuCo documentation), we'd like you to implement a simple method for mitigating spurious correlations in SpuCoMNIST (initialize the dataset with its default parameters).

The method we'd like you to implement (George) has a three-step pipeline:
1. Train a model using ERM.
2. Cluster the inputs based on the outputs the ERM model produces for them.
3. Retrain using group balancing, so that each group appears equally often in every batch.



In [12]:
!pip install spuco



In [16]:
import torch
import torch.optim as optim

from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
from spuco.robust_train import GroupBalanceBatchERM
from spuco.models import model_factory
from spuco.utils import Trainer
from spuco.evaluate import Evaluator
from spuco.group_inference import Cluster, ClusterAlg

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Each inner list groups the MNIST digits that form one SpuCoMNIST class (5 classes)
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# Parameter choices for SpuriousFeatureDifficulty and spurious_correlation_strength:
# - The documented default spurious_correlation_strength is 0.0, which is invalid for
#   the train split, so I use 0.90 instead.
# - The documentation lists no default SpuriousFeatureDifficulty, so I chose MAGNITUDE_LARGE.
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE
spurious_correlation_strength = 0.90



train_dataset = SpuCoMNIST(
    root="data/",
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=spurious_correlation_strength,
    classes=classes,
    split="train",
    download=True
)


test_dataset = SpuCoMNIST(
    root="data/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test",
    download=True
)
train_dataset.initialize()
test_dataset.initialize()

**1)** Train a LeNet using ERM

In [18]:
# Model: a LeNet built with SpuCo's model_factory
model = model_factory("lenet", train_dataset[0][0].shape, train_dataset.num_classes).to(device)

optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=6e-4, momentum=0.90, nesterov=True)

# I use the Trainer class instead of the ERM class because Trainer's
# get_trainset_outputs() method conveniently returns the model's outputs on the
# training set, which the next step (clustering) needs.
# I did not see an equivalent method on the ERM page of the SpuCo documentation, nor a
# way to access ERM's model member, which would have let me write a helper with the
# same functionality as get_trainset_outputs().



erm = Trainer(
    trainset=train_dataset,
    model=model,
    batch_size=64,
    optimizer=optimizer,
    device=device,
    verbose=False
)

erm.train(num_epochs=2)



In [20]:
evaluator = Evaluator(
    testset=test_dataset,
    group_partition=test_dataset.group_partition,
    group_weights=train_dataset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)

evaluator.evaluate()

{(0, 0): 100.0,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 0.0,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 0.0,
 (4, 4): 100.0}
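The dictionary above maps each group to its test accuracy. Spurious-correlation results are usually summarized by worst-group accuracy alongside an average over groups. Below is a minimal plain-Python sketch; `summarize_group_accuracy` is a hypothetical helper (not part of SpuCo), and its mean weights every group equally rather than using `group_weights`.

```python
from typing import Dict, Hashable, Tuple


def summarize_group_accuracy(
    accs: Dict[Hashable, float],
) -> Tuple[Hashable, float, float]:
    """Return (worst group, worst-group accuracy, unweighted mean accuracy).

    Hypothetical helper that operates on the dict returned by Evaluator.evaluate().
    """
    if not accs:
        raise ValueError("need at least one group")
    worst_group = min(accs, key=accs.get)  # first group with the lowest accuracy
    mean_acc = sum(accs.values()) / len(accs)
    return worst_group, accs[worst_group], mean_acc


# ERM results from above: 100% on the 5 diagonal groups, 0% on the other 20
erm_accs = {(i, j): (100.0 if i == j else 0.0) for i in range(5) for j in range(5)}
print(summarize_group_accuracy(erm_accs))  # → ((0, 1), 0.0, 20.0)
```

For the ERM model, worst-group accuracy is 0.0 while the unweighted mean is 20.0, a pattern consistent with the model predicting from the spurious feature alone.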

**2)** Cluster inputs based on ERM outputs
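Conceptually, this step clusters the ERM model's output vectors separately within each class, so that every (class, cluster id) pair becomes an inferred group. The sketch below illustrates that idea with a tiny pure-Python k-means (Lloyd's algorithm) and a deterministic initialization. It is an illustration under those assumptions, not SpuCo's `Cluster` implementation, which may differ in initialization and library; `kmeans` and `infer_groups_per_class` are hypothetical names.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def _sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: List[Vector], k: int, iters: int = 50) -> List[int]:
    """Plain Lloyd's k-means; returns a cluster id for each point."""
    # Deterministic init: pick k points spread across the input order
    centers = [list(points[i * len(points) // k]) for i in range(k)]
    assign: List[int] = []
    for _ in range(iters):
        # Assignment step: each point goes to its nearest center
        new_assign = [min(range(k), key=lambda c: _sq_dist(p, centers[c])) for p in points]
        if new_assign == assign:
            break  # converged
        assign = new_assign
        # Update step: move each center to the mean of its members
        for c in range(k):
            members = [p for p, a in zip(points, assign) if a == c]
            if members:  # an empty cluster keeps its previous center
                centers[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return assign


def infer_groups_per_class(
    outputs: List[Vector], labels: List[int], k: int = 2
) -> Dict[Tuple[int, int], List[int]]:
    """Cluster outputs separately within each class; keys are (class, cluster id)."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for y, idxs in by_class.items():
        cluster_ids = kmeans([outputs[i] for i in idxs], k)
        for i, c in zip(idxs, cluster_ids):
            groups[(y, c)].append(i)
    return dict(groups)


# Toy "model outputs": within each class, one example sits far from the others
outputs = [[5.0, 0.0], [4.8, 0.2], [0.1, 4.9], [0.0, 5.1], [0.2, 5.0], [5.1, 0.1]]
labels = [0, 0, 0, 1, 1, 1]
print(infer_groups_per_class(outputs, labels))
# → {(0, 0): [0, 1], (0, 1): [2], (1, 0): [3, 4], (1, 1): [5]}
```

Because clustering runs per class, a cluster id is only meaningful together with its class label.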

In [22]:
cluster = Cluster(
    Z=erm.get_trainset_outputs(),
    class_labels=train_dataset.labels,
    cluster_alg=ClusterAlg.KMEANS,
    num_clusters=2,
    device=device,
    verbose=True
)
group_partition = cluster.infer_groups()

# Print the size of each inferred (class, cluster) group
for key in sorted(group_partition.keys()):
    print(key, len(group_partition[key]))


(0, 0) 1014
(0, 1) 9119
(1, 0) 8704
(1, 1) 968
(2, 0) 8109
(2, 1) 902
(3, 0) 8772
(3, 1) 975
(4, 0) 8496
(4, 1) 945
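In every class, one cluster holds roughly 10% of the examples. The quick calculation below, run on the sizes printed above, makes this explicit; the ratio is consistent with (though not proof of) the clusters separating examples that carry the class's dominant spurious feature from the rest, given `spurious_correlation_strength = 0.90`. `minority_fractions` is a hypothetical helper.

```python
# Cluster sizes copied from the output above: {(class, cluster id): count}
sizes = {
    (0, 0): 1014, (0, 1): 9119,
    (1, 0): 8704, (1, 1): 968,
    (2, 0): 8109, (2, 1): 902,
    (3, 0): 8772, (3, 1): 975,
    (4, 0): 8496, (4, 1): 945,
}


def minority_fractions(sizes):
    """Fraction of each class's examples that fall into its smaller cluster."""
    per_class = {}
    for (cls, _cluster), n in sizes.items():
        per_class.setdefault(cls, []).append(n)
    return {cls: min(ns) / sum(ns) for cls, ns in per_class.items()}


for cls, frac in minority_fractions(sizes).items():
    print(f"class {cls}: {frac:.3f}")  # each line prints 0.100
```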





In [24]:
# Sanity check: accuracy of the ERM model on each inferred group of the training set
evaluator = Evaluator(
    testset=train_dataset,
    group_partition=group_partition,
    group_weights=train_dataset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

{(0, 0): 0.0,
 (0, 1): 100.0,
 (1, 0): 100.0,
 (1, 1): 0.0,
 (2, 0): 100.0,
 (2, 1): 0.0,
 (3, 0): 100.0,
 (3, 1): 0.0,
 (4, 0): 100.0,
 (4, 1): 0.0}
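On the training set, each class's two clusters split cleanly into examples the ERM model gets right (100%) and examples it gets wrong (0%). Since k-means cluster ids are arbitrary, the "wrong" cluster is id 0 for class 0 but id 1 for the other classes. A small hypothetical helper, `lowest_accuracy_cluster`, makes that per-class mapping explicit.

```python
from typing import Dict, Tuple


def lowest_accuracy_cluster(accs: Dict[Tuple[int, int], float]) -> Dict[int, int]:
    """Map each class to the cluster id whose inferred group has the lowest accuracy."""
    worst: Dict[int, Tuple[int, float]] = {}
    for (cls, cluster), acc in accs.items():
        if cls not in worst or acc < worst[cls][1]:
            worst[cls] = (cluster, acc)
    return {cls: cluster for cls, (cluster, _) in worst.items()}


# Inferred-group accuracies on the training set, copied from the output above
train_cluster_accs = {
    (0, 0): 0.0, (0, 1): 100.0,
    (1, 0): 100.0, (1, 1): 0.0,
    (2, 0): 100.0, (2, 1): 0.0,
    (3, 0): 100.0, (3, 1): 0.0,
    (4, 0): 100.0, (4, 1): 0.0,
}
print(lowest_accuracy_cluster(train_cluster_accs))  # → {0: 0, 1: 1, 2: 1, 3: 1, 4: 1}
```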

**3)** Retrain using group balancing

In [26]:
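# Note: this continues training the ERM model from step 1 (same weights and optimizer)
# rather than starting from a fresh initialization
group_balance_erm = GroupBalanceBatchERM(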
    model=model,
    num_epochs=5,
    trainset=train_dataset,
    group_partition=group_partition,
    batch_size=64,
    optimizer=optimizer,
    device=device,
    verbose=False
)
group_balance_erm.train()
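`GroupBalanceBatchERM` receives the inferred `group_partition` and handles the balanced batching internally. To illustrate the idea, here is a minimal sampler that draws an equal share of each batch from every group, sampling with replacement so that small groups are oversampled. This is a sketch under assumptions, not SpuCo's implementation: `group_balanced_batches`, its parameters, and the choice to hand any remainder (when `batch_size` is not a multiple of the number of groups) to a random subset of groups are all hypothetical.

```python
import random
from typing import Dict, Hashable, Iterator, List


def group_balanced_batches(
    group_partition: Dict[Hashable, List[int]],
    batch_size: int,
    num_batches: int,
    seed: int = 0,
) -> Iterator[List[int]]:
    """Yield batches of dataset indices drawn as evenly as possible from every group.

    Indices are sampled with replacement, so small (minority) groups are repeated.
    When batch_size is not a multiple of the number of groups, a random subset of
    groups contributes one extra example to that batch.
    """
    groups = [idxs for idxs in group_partition.values() if idxs]
    if not groups:
        raise ValueError("group_partition has no non-empty groups")
    base, extra = divmod(batch_size, len(groups))
    rng = random.Random(seed)
    for _ in range(num_batches):
        lucky = set(rng.sample(range(len(groups)), extra))  # groups getting +1 this batch
        batch: List[int] = []
        for g, idxs in enumerate(groups):
            batch.extend(rng.choices(idxs, k=base + (1 if g in lucky else 0)))
        rng.shuffle(batch)
        yield batch


# Toy partition: a 90-example majority group and a 10-example minority group
partition = {"majority": list(range(90)), "minority": list(range(90, 100))}
for batch in group_balanced_batches(partition, batch_size=8, num_batches=2):
    print(sum(i >= 90 for i in batch), "of", len(batch), "from the minority group")
# → prints "4 of 8 from the minority group" twice
```

With the 10 inferred groups and `batch_size=64` used here, this sketch would give each group 6 or 7 examples per batch.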

In [27]:
evaluator = Evaluator(
    testset=test_dataset,
    group_partition=test_dataset.group_partition,
    group_weights=train_dataset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

{(0, 0): 99.52718676122932,
 (0, 1): 94.08983451536643,
 (0, 2): 97.39952718676123,
 (0, 3): 95.74468085106383,
 (0, 4): 95.0354609929078,
 (1, 0): 95.11002444987776,
 (1, 1): 98.77750611246944,
 (1, 2): 94.36274509803921,
 (1, 3): 96.81372549019608,
 (1, 4): 94.11764705882354,
 (2, 0): 96.53333333333333,
 (2, 1): 92.0,
 (2, 2): 97.6,
 (2, 3): 88.8,
 (2, 4): 89.83957219251337,
 (3, 0): 94.9748743718593,
 (3, 1): 92.9471032745592,
 (3, 2): 94.7103274559194,
 (3, 3): 98.2367758186398,
 (3, 4): 95.71788413098237,
 (4, 0): 95.21410579345088,
 (4, 1): 91.43576826196474,
 (4, 2): 93.19899244332494,
 (4, 3): 95.95959595959596,
 (4, 4): 97.47474747474747}
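To compare this run with ERM, it helps to split the groups into those where the spurious feature matches the class (the diagonal, "majority" groups) and those where it does not. The sketch assumes group keys are `(class label, spurious feature)` pairs, which the ERM results are consistent with (100% on the diagonal, 0% elsewhere); `majority_minority_means` is a hypothetical helper.

```python
from typing import Dict, Tuple


def majority_minority_means(accs: Dict[Tuple[int, int], float]) -> Tuple[float, float]:
    """Mean accuracy over groups with label == spurious feature vs. all other groups."""
    majority = [a for (label, spur), a in accs.items() if label == spur]
    minority = [a for (label, spur), a in accs.items() if label != spur]
    if not majority or not minority:
        raise ValueError("need both majority and minority groups")
    return sum(majority) / len(majority), sum(minority) / len(minority)


# ERM results from step 1: diagonal groups 100%, off-diagonal groups 0%
erm_accs = {(i, j): (100.0 if i == j else 0.0) for i in range(5) for j in range(5)}
print(majority_minority_means(erm_accs))  # → (100.0, 0.0)
```

Applied to the group-balanced results above, the same computation gives about 98.3 for majority groups versus 94.2 for minority groups, with the worst group being (2, 3) at 88.8.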