In [1]:
# Install the local package from the directory two levels up into the kernel's
# environment (the captured log below shows it builds and installs `spuco`).
# pip notes that a kernel restart may be needed to pick up the updated package.
%pip install ../../

Processing /home/hyang/SpuCo
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: spuco
  Building wheel for spuco (pyproject.toml) ... [?25ldone
[?25h  Created wheel for spuco: filename=spuco-0.0.1-py3-none-any.whl size=80990 sha256=baa23365dd11eadbad1f313099fbb066330dd64031c45b99ba3e52f9ba1e681a
  Stored in directory: /tmp/pip-ephem-wheel-cache-ze7p6be3/wheels/16/2e/00/5bdefcfd7f850d6f19880ecdb4dd3f325f4b906bf004cd82e3
Successfully built spuco
Installing collected packages: spuco
  Attempting uninstall: spuco
    Found existing installation: spuco 0.0.1
    Uninstalling spuco-0.0.1:
      Successfully uninstalled spuco-0.0.1
Successfully installed spuco-0.0.1
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch

# GPU index used when this notebook was originally run; still preferred when present.
PREFERRED_GPU_INDEX = 7

# torch.device("cuda:7") can be constructed on any machine, so a missing GPU only
# surfaces later when the model is moved to the device. Resolve a device that
# actually exists instead: the preferred GPU, else the first GPU, else the CPU.
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
from spuco.utils import set_seed

# Seed value passed to spuco's seeding helper before any data or model is created.
SEED = 0
set_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from wilds import get_dataset
import torchvision.transforms as transforms

# Fetch the full Waterbirds dataset through WILDS, downloading it if necessary.
dataset = get_dataset(dataset="waterbirds", download=True)

# Per-channel normalization statistics (these match the usual ImageNet values).
channel_mean = (0.485, 0.456, 0.406)
channel_std = (0.229, 0.224, 0.225)

# Preprocessing shared by both splits: resize, center-crop to 224,
# convert to a tensor, then normalize each channel.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Build the training and test subsets, in that order, with the same transform.
train_data, test_data = (
    dataset.get_subset(split, transform=transform)
    for split in ("train", "test")
)

In [5]:
from spuco.utils import WILDSDatasetWrapper

# Both splits are wrapped with identical options: the "background" metadata
# field is used as the spurious label, and progress is reported.
wrapper_options = dict(metadata_spurious_label="background", verbose=True)

trainset = WILDSDatasetWrapper(dataset=train_data, **wrapper_options)
testset = WILDSDatasetWrapper(dataset=test_data, **wrapper_options)

Partitioning data indices into groups: 100%|██████████| 4795/4795 [00:00<00:00, 2341018.24it/s]
Partitioning data indices into groups: 100%|██████████| 5794/5794 [00:00<00:00, 2388148.33it/s]


In [6]:
# Display the training set's per-group weights: a dict keyed by 2-tuples whose
# values sum to 1 in the output below.
# NOTE(review): keys are presumably (class label, spurious label) — confirm
# against WILDSDatasetWrapper.
trainset.group_weights

{(1, 1): 0.22043795620437956,
 (1, 0): 0.01167883211678832,
 (0, 0): 0.7295099061522419,
 (0, 1): 0.0383733055265902}

In [7]:
from spuco.models import model_factory 

# Size the network from the data: the input shape comes from the first training
# sample, the output size from the number of classes in the training set.
input_shape = trainset[0][0].shape
model = model_factory("resnet50", input_shape, trainset.num_classes)

# Move the freshly built model onto the selected device.
model = model.to(device)

In [8]:
from torch.optim import SGD
from spuco.invariant_train import ERM 

# SGD with Nesterov momentum and weight decay over all model parameters.
optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    nesterov=True,
    weight_decay=1e-4,
)

# ERM trainer over the wrapped training set; verbose=True prints a per-epoch
# progress bar with accuracy and loss (see the log below).
erm = ERM(
    model=model,
    trainset=trainset,
    optimizer=optimizer,
    num_epochs=300,
    batch_size=128,
    device=device,
    verbose=True,
)

# Long-running step: the log shows roughly 26 s per epoch.
erm.train()

Epoch 0: 100%|██████████| 38/38 [00:34<00:00,  1.11batch/s, accuracy=91.52542372881356%, loss=0.232]
Epoch 1: 100%|██████████| 38/38 [00:27<00:00,  1.40batch/s, accuracy=94.91525423728814%, loss=0.168]
Epoch 2: 100%|██████████| 38/38 [00:27<00:00,  1.38batch/s, accuracy=96.61016949152543%, loss=0.0825]
Epoch 3: 100%|██████████| 38/38 [00:27<00:00,  1.40batch/s, accuracy=100.0%, loss=0.0411]   
Epoch 4: 100%|██████████| 38/38 [00:26<00:00,  1.41batch/s, accuracy=96.61016949152543%, loss=0.0697]
Epoch 5: 100%|██████████| 38/38 [00:26<00:00,  1.43batch/s, accuracy=98.30508474576271%, loss=0.043]
Epoch 6: 100%|██████████| 38/38 [00:26<00:00,  1.45batch/s, accuracy=100.0%, loss=0.0291]   
Epoch 7: 100%|██████████| 38/38 [00:25<00:00,  1.46batch/s, accuracy=100.0%, loss=0.0265]   
Epoch 8: 100%|██████████| 38/38 [00:25<00:00,  1.46batch/s, accuracy=100.0%, loss=0.0199]   
Epoch 9: 100%|██████████| 38/38 [00:26<00:00,  1.46batch/s, accuracy=98.30508474576271%, loss=0.0283]
Epoch 10: 100%|████

In [None]:
from spuco.evaluate import Evaluator

# Group-wise evaluation of the trained model on the test split. The groups come
# from the test set's own partition, while the group weights are taken from the
# training set.
evaluator = Evaluator(
    model=model,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)

# Run the evaluation and show the resulting per-group accuracy mapping.
group_accuracies = evaluator.evaluate()
group_accuracies

Evaluating group-wise accuracy:  25%|██▌       | 1/4 [00:02<00:07,  2.66s/it]

Group (0, 0) Test Accuracy: 99.33481152993348


Evaluating group-wise accuracy:  50%|█████     | 2/4 [00:05<00:05,  2.69s/it]

Group (0, 1) Test Accuracy: 71.84035476718404


Evaluating group-wise accuracy:  75%|███████▌  | 3/4 [00:06<00:02,  2.01s/it]

Group (1, 0) Test Accuracy: 55.45171339563863


Evaluating group-wise accuracy: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it]

Group (1, 1) Test Accuracy: 94.0809968847352





{(0, 0): 99.33481152993348,
 (0, 1): 71.84035476718404,
 (1, 0): 55.45171339563863,
 (1, 1): 94.0809968847352}

In [None]:
# Display the evaluator's average accuracy. The value shown below matches the
# per-group test accuracies weighted by the training-set group weights passed to
# the Evaluator, so it is dominated by the two largest training groups and sits
# well above the weakest group's accuracy.
evaluator.average_accuracy

96.60911484174974