In [2]:
%pip install spuco

Collecting spuco
  Obtaining dependency information for spuco from https://files.pythonhosted.org/packages/7e/8b/7a23886da9af9dc54ba3715a885ea2d0bdd4d5c4ade951ff6cd539c13683/spuco-1.0.3-py3-none-any.whl.metadata
  Downloading spuco-1.0.3-py3-none-any.whl.metadata (4.0 kB)
Collecting matplotlib>=3.7.1 (from spuco)
  Obtaining dependency information for matplotlib>=3.7.1 from https://files.pythonhosted.org/packages/af/f3/fb27b3b902fc759bbca3f9d0336c48069c3022e57552c4b0095d997c7ea8/matplotlib-3.8.0-cp311-cp311-macosx_11_0_arm64.whl.metadata
  Using cached matplotlib-3.8.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (5.8 kB)
Collecting numpy>=1.23.5 (from spuco)
  Obtaining dependency information for numpy>=1.23.5 from https://files.pythonhosted.org/packages/35/21/9e150d654da358beb29fe216f339dc17f2b2ac13fff2a89669401a910550/numpy-1.26.0-cp311-cp311-macosx_11_0_arm64.whl.metadata
  Using cached numpy-1.26.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (99 kB)
Collecting torchvision>=0.15.1 (fr

In [1]:
import torch

# Seed torch's global RNG so torch-driven randomness (e.g. weight
# initialisation) is repeatable between runs.
# NOTE(review): RNGs outside torch (numpy / random, or any used internally by
# third-party libraries) are not seeded here.
SEED = 0
torch.manual_seed(SEED)

# Pick the best available backend: Metal (MPS) first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T


# Class definition (five lists of two digit labels each) and the spurious
# feature difficulty used for both splits
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

# Constructor arguments that the train and test splits have in common
shared_dataset_kwargs = {
    "root": "~/.pytorch/MNIST_data/",
    "spurious_feature_difficulty": difficulty,
    "classes": classes,
}

# Training split: built with a spurious correlation strength of 0.995
trainset = SpuCoMNIST(
    split="train",
    spurious_correlation_strength=0.995,
    **shared_dataset_kwargs,
)
trainset.initialize()

# Test split: no spurious_correlation_strength is passed, so the library
# default applies
testset = SpuCoMNIST(split="test", **shared_dataset_kwargs)
testset.initialize()

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████████████████████████████| 48004/48004 [00:01<00:00, 41196.55it/s]
100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 40112.16it/s]


In [3]:
from spuco.models import model_factory 
from spuco.utils import Trainer
from torch.optim import SGD

# LeNet sized from the shape of trainset[0][0] and the dataset's class count
input_shape = trainset[0][0].shape
model = model_factory("lenet", input_shape, trainset.num_classes).to(device)

# SGD with Nesterov momentum and weight decay
erm_optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    nesterov=True,
    weight_decay=5e-4,
)

# Plain ERM (Empirical Risk Minimization) trainer over the training split
trainer = Trainer(
    trainset=trainset,
    model=model,
    batch_size=64,
    optimizer=erm_optimizer,
    device=device,
    verbose=True,
)

# Run a single epoch
trainer.train(1)

Epoch 0: 100%|█| 751/751 [00:06<00:00, 112.34batch/s, accuracy=100.0%, loss=0.06


In [4]:
from spuco.evaluate import Evaluator 

# Settings for group-wise evaluation of the current model on the test split;
# the group weights are taken from the training set
erm_eval_kwargs = dict(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True,
)
evaluator = Evaluator(**erm_eval_kwargs)

# Per-group accuracy; the returned dict is displayed as the cell output
evaluator.evaluate()

Evaluating group-wise accuracy:  28%|███▎        | 7/25 [00:00<00:00, 33.59it/s]

Group (0, 0) Accuracy: 93.3806146572104
Group (0, 1) Accuracy: 0.0
Group (0, 2) Accuracy: 0.0
Group (0, 3) Accuracy: 0.0
Group (0, 4) Accuracy: 0.0
Group (1, 0) Accuracy: 17.359413202933986
Group (1, 1) Accuracy: 100.0
Group (1, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  44%|████▊      | 11/25 [00:00<00:00, 32.59it/s]

Group (1, 3) Accuracy: 0.0
Group (1, 4) Accuracy: 0.0
Group (2, 0) Accuracy: 0.0
Group (2, 1) Accuracy: 0.0
Group (2, 2) Accuracy: 100.0
Group (2, 3) Accuracy: 0.0
Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  80%|████████▊  | 20/25 [00:00<00:00, 35.61it/s]

Group (3, 0) Accuracy: 0.0
Group (3, 1) Accuracy: 0.0
Group (3, 2) Accuracy: 0.0
Group (3, 3) Accuracy: 98.2367758186398
Group (3, 4) Accuracy: 8.816120906801007
Group (4, 0) Accuracy: 0.0
Group (4, 1) Accuracy: 0.0
Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy: 100%|███████████| 25/25 [00:00<00:00, 35.44it/s]

Group (4, 3) Accuracy: 3.0303030303030303
Group (4, 4) Accuracy: 97.22222222222223





{(0, 0): 93.3806146572104,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 17.359413202933986,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 98.2367758186398,
 (3, 4): 8.816120906801007,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 3.0303030303030303,
 (4, 4): 97.22222222222223}

In [5]:
from spuco.group_inference import Cluster, ClusterAlg

# Model outputs for every training example, as collected by the trainer
logits = trainer.get_trainset_outputs()

# K-means over those outputs with two clusters, using the training labels as
# class_labels (the recorded progress bar reads "Clustering class-wise")
clustering_kwargs = {
    "Z": logits,
    "class_labels": trainset.labels,
    "cluster_alg": ClusterAlg.KMEANS,
    "num_clusters": 2,
    "device": device,
    "verbose": True,
}
cluster = Cluster(**clustering_kwargs)

# Turn the clustering into a group partition
group_partition = cluster.infer_groups()

Getting Trainset Outputs: 100%|███████████| 751/751 [00:02<00:00, 365.82batch/s]
Clustering class-wise: 100%|██████████████████████| 5/5 [00:00<00:00, 11.06it/s]


In [6]:
# Report how many examples landed in each inferred group, in key order
for group_key, members in sorted(group_partition.items(), key=lambda item: item[0]):
    print(group_key, len(members))

(0, 0) 10091
(0, 1) 42
(1, 0) 9639
(1, 1) 33
(2, 0) 8991
(2, 1) 20
(3, 0) 9721
(3, 1) 26
(4, 0) 9412
(4, 1) 29


In [7]:
from spuco.evaluate import Evaluator 

# Evaluate the current model (the ERM model, assuming cells ran in order) on
# the *training* set, split by the inferred groups, to inspect how the
# clustering separated the examples.

# The weights must describe the same partition that is being evaluated, so they
# are derived from the inferred groups' sizes rather than reusing
# trainset.group_weights, which belongs to the dataset's own group partition.
# NOTE(review): assumes Evaluator expects a {group: fraction} mapping, like
# trainset.group_weights — confirm against the SpuCo docs.
num_partitioned = sum(len(indices) for indices in group_partition.values())
inferred_group_weights = {
    group: len(indices) / num_partitioned
    for group, indices in group_partition.items()
}

evaluator = Evaluator(
    testset=trainset,
    group_partition=group_partition,
    group_weights=inferred_group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:  10%|█▏          | 1/10 [00:00<00:05,  1.57it/s]

Group (0, 0) Accuracy: 91.27935784362303
Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  30%|███▌        | 3/10 [00:01<00:02,  2.66it/s]

Group (1, 0) Accuracy: 99.85475671750181
Group (1, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  50%|██████      | 5/10 [00:01<00:01,  2.91it/s]

Group (2, 0) Accuracy: 99.71082193304416
Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  70%|████████▍   | 7/10 [00:02<00:00,  3.12it/s]

Group (3, 0) Accuracy: 99.19761341425779
Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy: 100%|███████████| 10/10 [00:02<00:00,  3.36it/s]

Group (4, 0) Accuracy: 95.88822779430514
Group (4, 1) Accuracy: 0.0





{(0, 0): 91.27935784362303,
 (0, 1): 0.0,
 (1, 0): 99.85475671750181,
 (1, 1): 0.0,
 (2, 0): 99.71082193304416,
 (2, 1): 0.0,
 (3, 0): 99.19761341425779,
 (3, 1): 0.0,
 (4, 0): 95.88822779430514,
 (4, 1): 0.0}

In [8]:
from torch.optim import SGD
from spuco.robust_train import GroupBalanceBatchERM, ClassBalanceBatchERM
from spuco.models import model_factory 

# Fresh LeNet for the robust run; note that this rebinds `model`
input_shape = trainset[0][0].shape
model = model_factory("lenet", input_shape, trainset.num_classes).to(device)

# Same SGD settings as the ERM run: Nesterov momentum plus weight decay
robust_optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    nesterov=True,
    weight_decay=5e-4,
)

# GroupBalanceBatchERM over the inferred group partition, for five epochs
group_balance_erm = GroupBalanceBatchERM(
    model=model,
    trainset=trainset,
    group_partition=group_partition,
    num_epochs=5,
    batch_size=64,
    optimizer=robust_optimizer,
    device=device,
    verbose=True,
)
group_balance_erm.train()

Epoch 0: 100%|█| 751/751 [00:06<00:00, 116.05batch/s, accuracy=100.0%, loss=0.9]
Epoch 1: 100%|█| 751/751 [00:06<00:00, 121.28batch/s, accuracy=100.0%, loss=0.16
Epoch 2: 100%|█| 751/751 [00:06<00:00, 120.10batch/s, accuracy=100.0%, loss=0.02
Epoch 3: 100%|█| 751/751 [00:06<00:00, 119.70batch/s, accuracy=100.0%, loss=0.01
Epoch 4: 100%|█| 751/751 [00:06<00:00, 119.85batch/s, accuracy=100.0%, loss=0.00


In [9]:
from spuco.evaluate import Evaluator

# Group-wise test-set evaluation of the model trained with group balancing;
# the group weights are again taken from the training set
robust_eval_kwargs = dict(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True,
)
evaluator = Evaluator(**robust_eval_kwargs)

# Per-group accuracy; the returned dict is displayed as the cell output
evaluator.evaluate()

Evaluating group-wise accuracy:  32%|███▊        | 8/25 [00:00<00:00, 33.80it/s]

Group (0, 0) Accuracy: 99.29078014184397
Group (0, 1) Accuracy: 14.657210401891254
Group (0, 2) Accuracy: 68.5579196217494
Group (0, 3) Accuracy: 80.61465721040189
Group (0, 4) Accuracy: 75.65011820330969
Group (1, 0) Accuracy: 63.32518337408313
Group (1, 1) Accuracy: 98.0440097799511
Group (1, 2) Accuracy: 83.57843137254902


Evaluating group-wise accuracy:  48%|█████▎     | 12/25 [00:00<00:00, 33.91it/s]

Group (1, 3) Accuracy: 67.40196078431373
Group (1, 4) Accuracy: 38.48039215686274
Group (2, 0) Accuracy: 55.733333333333334
Group (2, 1) Accuracy: 63.46666666666667
Group (2, 2) Accuracy: 98.4
Group (2, 3) Accuracy: 3.7333333333333334
Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  80%|████████▊  | 20/25 [00:00<00:00, 32.51it/s]

Group (3, 0) Accuracy: 69.09547738693468
Group (3, 1) Accuracy: 43.8287153652393
Group (3, 2) Accuracy: 2.770780856423174
Group (3, 3) Accuracy: 98.74055415617129
Group (3, 4) Accuracy: 0.0
Group (4, 0) Accuracy: 79.5969773299748
Group (4, 1) Accuracy: 57.178841309823675


Evaluating group-wise accuracy: 100%|███████████| 25/25 [00:00<00:00, 33.12it/s]

Group (4, 2) Accuracy: 0.0
Group (4, 3) Accuracy: 2.272727272727273
Group (4, 4) Accuracy: 100.0





{(0, 0): 99.29078014184397,
 (0, 1): 14.657210401891254,
 (0, 2): 68.5579196217494,
 (0, 3): 80.61465721040189,
 (0, 4): 75.65011820330969,
 (1, 0): 63.32518337408313,
 (1, 1): 98.0440097799511,
 (1, 2): 83.57843137254902,
 (1, 3): 67.40196078431373,
 (1, 4): 38.48039215686274,
 (2, 0): 55.733333333333334,
 (2, 1): 63.46666666666667,
 (2, 2): 98.4,
 (2, 3): 3.7333333333333334,
 (2, 4): 0.0,
 (3, 0): 69.09547738693468,
 (3, 1): 43.8287153652393,
 (3, 2): 2.770780856423174,
 (3, 3): 98.74055415617129,
 (3, 4): 0.0,
 (4, 0): 79.5969773299748,
 (4, 1): 57.178841309823675,
 (4, 2): 0.0,
 (4, 3): 2.272727272727273,
 (4, 4): 100.0}

In [10]:
# Worst group accuracy reported by the evaluator; the recorded output shows it
# as a (group, accuracy) pair.
evaluator.worst_group_accuracy

((2, 4), 0.0)

In [11]:
# Average accuracy reported by the evaluator.
# NOTE(review): this looks like an average weighted by the group_weights passed
# to the Evaluator (here trainset.group_weights), not an unweighted mean over
# groups — that would explain a high value even though several groups score
# near 0%. Confirm against the SpuCo Evaluator documentation.
evaluator.average_accuracy

98.63601080936486

In [12]:
# Overall spurious-attribute prediction score for the evaluator's model.
# NOTE(review): presumably an accuracy in percent for predicting the spurious
# attribute — confirm the exact definition in the SpuCo Evaluator documentation.
evaluator.evaluate_spurious_attribute_prediction()

55.53