In [10]:
# Install the package located two directories above this notebook into the
# kernel's environment (%pip targets the running kernel, unlike !pip).
# NOTE(review): this cell's execution count (10) is later than the cells below it,
# so it was run after them; restart the kernel after installing so the
# subsequent imports pick up the freshly built package.
%pip install ../../

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Processing /home/yuyang/SpuCo
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: spuco
  Building wheel for spuco (pyproject.toml) ... [?25ldone
[?25h  Created wheel for spuco: filename=spuco-0.0.1-py3-none-any.whl size=84919 sha256=415f09e0b002b6054133cb295ff19cdaf4555e1a856228ced092612bb89e0c2c
  Stored in directory: /tmp/pip-ephem-wheel-cache-ompsph1s/wheels/99/d4/10/c6136b4f67d7a1fd0d788e21f761b3004ee3bbbfae90fc0ca1
Successfully built spuco
Installing collected packages: spuco
  Attempting uninstall: spuco
    Found existing installation: spuco 0.0.1
    Uninstalling spuco-0.0.1:
      Successfully uninstalled spuco-0.0.1
Successfully installed spuco-0.0.1
Note: you may need to restart the kernel to use updated packages.

In [1]:
import torch 

# GPU index this experiment was originally run on.
PREFERRED_CUDA_INDEX = 7

# Select the compute device for every later `.to(device)` call.
# A hard-coded "cuda:7" only works on hosts exposing at least 8 GPUs; on any
# other machine the first tensor/model transfer would raise. Fall back to the
# first GPU, and finally to the CPU, so the notebook stays runnable elsewhere.
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_CUDA_INDEX:
    device = torch.device(f"cuda:{PREFERRED_CUDA_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
from spuco.utils import set_seed

# Fix the random seed once, up front, so the run is repeatable.
SEED = 0
set_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from wilds import get_dataset
import torchvision.transforms as transforms

# Fetch the full Waterbirds dataset through WILDS, downloading it if necessary.
dataset = get_dataset(dataset="waterbirds", download=True, root_dir="/data")

# Per-channel mean / std handed to Normalize after the image becomes a tensor.
channel_mean = (0.485, 0.456, 0.406)
channel_std = (0.229, 0.224, 0.225)

# Deterministic preprocessing: resize, centre-crop to 224, tensorize, normalize.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Training split.
train_data = dataset.get_subset("train", transform=transform)

# Test split (same preprocessing as training).
test_data = dataset.get_subset("test", transform=transform)

In [4]:
from spuco.datasets import WILDSDatasetWrapper

# Both splits are wrapped identically: the "background" metadata field is used
# as the spurious label, and progress is reported while groups are built.
wrapper_options = {"metadata_spurious_label": "background", "verbose": True}

trainset = WILDSDatasetWrapper(dataset=train_data, **wrapper_options)
testset = WILDSDatasetWrapper(dataset=test_data, **wrapper_options)

Partitioning data indices into groups: 100%|██████████| 4795/4795 [00:00<00:00, 2156749.35it/s]
Partitioning data indices into groups: 100%|██████████| 5794/5794 [00:00<00:00, 1678996.64it/s]


In [5]:
# Display the weight attached to each training group; in the recorded output the
# four values sum to 1, showing how imbalanced the groups are.
# Keys are 2-tuples — presumably (class label, spurious/background label);
# confirm against WILDSDatasetWrapper.
trainset.group_weights

{(1, 1): 0.22043795620437956,
 (1, 0): 0.01167883211678832,
 (0, 0): 0.7295099061522419,
 (0, 1): 0.0383733055265902}

In [6]:
from spuco.datasets import GroupLabeledDatasetWrapper

# Pair the training set with its group partition so the trainer can see which
# group each example belongs to.
train_group_partition = trainset.group_partition
invariant_trainset = GroupLabeledDatasetWrapper(trainset, train_group_partition)

In [7]:
from spuco.models import model_factory 

# Build a ResNet-50 sized from the first training example and the number of
# classes, then move it onto the selected device.
first_image = trainset[0][0]
model = model_factory("resnet50", first_image.shape, trainset.num_classes)
model = model.to(device)

In [8]:
from torch.optim import SGD
from spuco.invariant_train import GroupDRO 

# SGD with a very small step size and strong weight decay over all model parameters.
sgd_optimizer = SGD(model.parameters(), lr=1e-5, weight_decay=1.0, momentum=0.9)

# Configure Group DRO training on the group-labelled training set.
group_dro = GroupDRO(
    model=model,
    trainset=invariant_trainset,
    optimizer=sgd_optimizer,
    num_epochs=300,
    batch_size=128,
    device=device,
    verbose=True,
)

# Run training; verbose mode prints one progress bar per epoch.
group_dro.train()

Epoch 0: 100%|██████████| 38/38 [00:12<00:00,  3.04batch/s, accuracy=38.983050847457626%, loss=0.71]
Epoch 1: 100%|██████████| 38/38 [00:10<00:00,  3.53batch/s, accuracy=40.67796610169491%, loss=0.699]
Epoch 2: 100%|██████████| 38/38 [00:10<00:00,  3.53batch/s, accuracy=32.20338983050848%, loss=0.724]
Epoch 3: 100%|██████████| 38/38 [00:10<00:00,  3.52batch/s, accuracy=44.067796610169495%, loss=0.708]
Epoch 4: 100%|██████████| 38/38 [00:10<00:00,  3.54batch/s, accuracy=50.847457627118644%, loss=0.699]
Epoch 5: 100%|██████████| 38/38 [00:10<00:00,  3.48batch/s, accuracy=59.32203389830509%, loss=0.69]
Epoch 6: 100%|██████████| 38/38 [00:10<00:00,  3.58batch/s, accuracy=54.23728813559322%, loss=0.694]
Epoch 7: 100%|██████████| 38/38 [00:10<00:00,  3.53batch/s, accuracy=52.54237288135593%, loss=0.683]
Epoch 8: 100%|██████████| 38/38 [00:10<00:00,  3.51batch/s, accuracy=55.932203389830505%, loss=0.685]
Epoch 9: 100%|██████████| 38/38 [00:10<00:00,  3.51batch/s, accuracy=57.6271186440678%, l

In [9]:
from spuco.evaluate import Evaluator

# Groups are taken from the test split, while the group weights come from the
# training split.
test_group_partition = testset.group_partition
train_group_weights = trainset.group_weights

evaluator = Evaluator(
    model=model,
    testset=testset,
    group_partition=test_group_partition,
    group_weights=train_group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)

# Kept as the cell's final expression so the per-group accuracy dict is displayed.
evaluator.evaluate()

Evaluating group-wise accuracy:  25%|██▌       | 1/4 [00:02<00:07,  2.63s/it]

Group (0, 0) Test Accuracy: 95.21064301552106


Evaluating group-wise accuracy:  50%|█████     | 2/4 [00:05<00:05,  2.58s/it]

Group (0, 1) Test Accuracy: 84.96674057649668


Evaluating group-wise accuracy:  75%|███████▌  | 3/4 [00:06<00:01,  1.95s/it]

Group (1, 0) Test Accuracy: 87.07165109034268


Evaluating group-wise accuracy: 100%|██████████| 4/4 [00:07<00:00,  1.88s/it]

Group (1, 1) Test Accuracy: 91.90031152647975





{(0, 0): 95.21064301552106,
 (0, 1): 84.96674057649668,
 (1, 0): 87.07165109034268,
 (1, 1): 91.90031152647975}