In [1]:
# Install the local package checkout located three directories above this
# notebook into the running kernel's environment (%pip targets the kernel's
# interpreter). Restart the kernel afterwards if an older build was already imported.
%pip install ../../../

Processing /home/sjoshi/spuco
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: spuco
  Building wheel for spuco (pyproject.toml) ... [?25ldone
[?25h  Created wheel for spuco: filename=spuco-0.0.1-py3-none-any.whl size=67756 sha256=5b3f201b13630943160da4d45d2a102c2503079d7036bc6b0958726225fd6822
  Stored in directory: /tmp/pip-ephem-wheel-cache-f0u5ilnm/wheels/ef/5d/43/a265894b1d52121a51705a208277e8d9a9670e95fa1a2e7ae6
Successfully built spuco
Installing collected packages: spuco
  Attempting uninstall: spuco
    Found existing installation: spuco 0.0.1
    Uninstalling spuco-0.0.1:
      Successfully uninstalled spuco-0.0.1
Successfully installed spuco-0.0.1
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch

# GPU index this notebook was originally run on; change to suit the host.
PREFERRED_GPU = 7

# Pick the compute device defensively: a hard-coded "cuda:7" only works on
# machines with at least eight visible GPUs and fails later, at the first
# `.to(device)` call, everywhere else.
if torch.cuda.is_available():
    # Keep the preferred index when it exists, otherwise use the first GPU.
    gpu_index = PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    # No CUDA runtime/driver available: run everything on the CPU.
    device = torch.device("cpu")

In [3]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T

# Digit pairs [[0, 1], [2, 3], ..., [8, 9]]; presumably each pair of MNIST
# digits is merged into one class -- confirm against SpuCoMNIST.
classes = [[2 * k, 2 * k + 1] for k in range(5)]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_EASY

# Arguments common to both splits, so train and test are built from the same
# root directory, difficulty setting and class grouping.
shared_kwargs = dict(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
)

# Training split: the spurious correlation strength is set explicitly to 0.99.
trainset = SpuCoMNIST(
    split="train",
    spurious_correlation_strength=0.99,
    **shared_kwargs,
)
trainset.initialize()

# Test split: no correlation strength is passed, so the library default applies.
testset = SpuCoMNIST(split="test", **shared_kwargs)
testset.initialize()

In [4]:
from spuco.models import model_factory
from torch.optim import SGD
from spuco.invariant_train import ERM

# Baseline: a LeNet trained for one epoch with plain ERM on `trainset`.
# The output layer is sized by the number of *classes*, because the network is
# fit to the class labels returned by `trainset[i]`; this matches how the final
# model is built in the CorrectNContrastTrain cell. The earlier
# `trainset.num_spurious` only worked because both counts happen to coincide
# for this dataset configuration.
# NOTE(review): the next cell rebinds `model` to a fresh network, so this
# baseline is not reused downstream.
model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)
erm = ERM(
    model=model,
    num_epochs=1,
    trainset=trainset,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    device=device,
    verbose=True
)
erm.train()

Epoch 0: 100%|██████████| 751/751 [00:03<00:00, 216.22batch/s, accuracy=100.0%, loss=0.00553] 


In [5]:
from spuco.group_inference import CorrectNContrastInference

# Fresh LeNet used by Correct-n-Contrast group inference. The inference step
# trains this network on `trainset` (see the training progress bar printed by
# `infer_groups()`), so its head must have one logit per *class*, consistent
# with the model built for CorrectNContrastTrain later on. `num_spurious`
# only worked because it equals `num_classes` for this dataset configuration.
model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)
cnc_inference = CorrectNContrastInference(
    trainset=trainset,
    model=model,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    num_epochs=1,
    device=device,
    verbose=True
)

In [6]:
# Run group inference and report how many training examples landed in each
# inferred group, in sorted key order.
group_partition = cnc_inference.infer_groups()
for group in sorted(group_partition.keys()):
    members = group_partition[group]
    print(group, len(members))

Epoch 0: 100%|██████████| 751/751 [00:02<00:00, 295.84batch/s, accuracy=100.0%, loss=0.0102]  
Getting Trainset Outputs: 100%|██████████| 751/751 [00:00<00:00, 1213.29batch/s]

(0, 0) 10125
(0, 1) 9668
(0, 2) 9002
(0, 3) 9754
(0, 4) 9455





In [7]:
from spuco.evaluate import Evaluator

# Sanity check of the inference model on the *training* data, broken down by
# the groups that were just inferred (the `testset` argument deliberately
# receives `trainset` here).
# NOTE(review): `group_weights` comes from trainset's own grouping while
# `group_partition` is the inferred one -- confirm the keys line up before
# relying on any weighted summary from this evaluator.
inferred_group_eval_kwargs = dict(
    testset=trainset,
    group_partition=group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True,
)
evaluator = Evaluator(**inferred_group_eval_kwargs)
evaluator.evaluate()

Group (0, 0) Test Accuracy: 99.06172839506173
Group (0, 1) Test Accuracy: 99.03806371534961
Group (0, 2) Test Accuracy: 99.08909131304155
Group (0, 3) Test Accuracy: 98.92351855648964
Group (0, 4) Test Accuracy: 98.84717080909572


{(0, 0): 99.06172839506173,
 (0, 1): 99.03806371534961,
 (0, 2): 99.08909131304155,
 (0, 3): 98.92351855648964,
 (0, 4): 98.84717080909572}

In [10]:
from torch.optim import SGD
from spuco.invariant_train import CorrectNContrastTrain
from spuco.models import model_factory
from spuco.utils import GroupLabeledDataset

# Final model: a new LeNet with one output per class, trained with
# Correct-n-Contrast on the training set annotated with the inferred groups.
input_shape = trainset[0][0].shape
model = model_factory("lenet", input_shape, trainset.num_classes).to(device)

# Wrap the training set together with the inferred group partition.
group_labeled_trainset = GroupLabeledDataset(trainset, group_partition)

# Nesterov-momentum SGD over the new model's parameters.
cnc_optimizer = SGD(model.parameters(), lr=1e-3, momentum=0.9, nesterov=True)

cnc_train = CorrectNContrastTrain(
    trainset=group_labeled_trainset,
    model=model,
    batch_size=128,
    optimizer=cnc_optimizer,
    num_pos=100,
    num_neg=100,
    num_epochs=10,
    lambda_ce=0.05,
    device=device,
    verbose=True,
)
cnc_train.train()

Epoch 0: 100%|██████████| 376/376 [00:30<00:00, 12.44batch/s, accuracy=75.0%, loss=0.159]    
Epoch 1: 100%|██████████| 376/376 [00:32<00:00, 11.60batch/s, accuracy=25.0%, loss=0.16]     
Epoch 2: 100%|██████████| 376/376 [00:32<00:00, 11.59batch/s, accuracy=50.0%, loss=0.156]    
Epoch 3: 100%|██████████| 376/376 [00:29<00:00, 12.61batch/s, accuracy=0.0%, loss=0.166]     
Epoch 4: 100%|██████████| 376/376 [00:32<00:00, 11.46batch/s, accuracy=25.0%, loss=0.146]    
Epoch 5: 100%|██████████| 376/376 [00:30<00:00, 12.31batch/s, accuracy=50.0%, loss=0.112]    
Epoch 6: 100%|██████████| 376/376 [00:31<00:00, 12.03batch/s, accuracy=50.0%, loss=0.134]     
Epoch 7: 100%|██████████| 376/376 [00:32<00:00, 11.66batch/s, accuracy=50.0%, loss=0.0834]    
Epoch 8: 100%|██████████| 376/376 [00:30<00:00, 12.53batch/s, accuracy=75.0%, loss=0.0591]    
Epoch 9: 100%|██████████| 376/376 [00:32<00:00, 11.57batch/s, accuracy=75.0%, loss=0.0371]    


In [11]:
from spuco.evaluate import Evaluator

# Evaluate the Correct-n-Contrast model on the held-out test split, using the
# test set's own group partition and the training set's group weights.
test_eval_kwargs = {
    "testset": testset,
    "group_partition": testset.group_partition,
    "group_weights": trainset.group_weights,
    "batch_size": 64,
    "model": model,
    "device": device,
    "verbose": True,
}
evaluator = Evaluator(**test_eval_kwargs)
evaluator.evaluate()

Group (0, 0) Test Accuracy: 50.59101654846336
Group (0, 1) Test Accuracy: 0.0
Group (0, 2) Test Accuracy: 0.0
Group (0, 3) Test Accuracy: 0.0
Group (0, 4) Test Accuracy: 9.21985815602837
Group (1, 0) Test Accuracy: 0.0
Group (1, 1) Test Accuracy: 99.75550122249389
Group (1, 2) Test Accuracy: 6.617647058823529
Group (1, 3) Test Accuracy: 0.0
Group (1, 4) Test Accuracy: 0.0
Group (2, 0) Test Accuracy: 0.0
Group (2, 1) Test Accuracy: 0.0
Group (2, 2) Test Accuracy: 97.86666666666666
Group (2, 3) Test Accuracy: 0.0
Group (2, 4) Test Accuracy: 0.0
Group (3, 0) Test Accuracy: 0.0
Group (3, 1) Test Accuracy: 0.0
Group (3, 2) Test Accuracy: 0.0
Group (3, 3) Test Accuracy: 100.0
Group (3, 4) Test Accuracy: 0.0
Group (4, 0) Test Accuracy: 72.29219143576826
Group (4, 1) Test Accuracy: 0.0
Group (4, 2) Test Accuracy: 0.0
Group (4, 3) Test Accuracy: 0.0
Group (4, 4) Test Accuracy: 98.73737373737374


{(0, 0): 50.59101654846336,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 9.21985815602837,
 (1, 0): 0.0,
 (1, 1): 99.75550122249389,
 (1, 2): 6.617647058823529,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 97.86666666666666,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 72.29219143576826,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 0.0,
 (4, 4): 98.73737373737374}

In [12]:
# Worst-performing group from the evaluation above; displays as a
# (group, accuracy) pair.
evaluator.worst_group_accuracy

((0, 1), 0.0)

In [13]:
# Overall accuracy summary from the evaluation above; presumably the per-group
# accuracies weighted by the `group_weights` given to Evaluator -- TODO confirm.
evaluator.average_accuracy

88.01920851845594

In [14]:
# Calls the evaluator's spurious-task evaluation and displays the value it
# returns. NOTE(review): presumably this measures how well the spurious
# attribute can be predicted from the model -- confirm in spuco.evaluate.
evaluator.evaluate_spurious_task()

84.89