In [None]:
pip install spuco --upgrade


In [None]:
import os
import torch
import pandas as pd
import torchvision.transforms as transforms
from torch.optim import SGD
from wilds import get_dataset

from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
from spuco.evaluate import Evaluator
from spuco.group_inference import SpareInference
from spuco.robust_train import SpareTrain
from spuco.models import model_factory
from spuco.utils import Trainer, set_seed

In [None]:
# Experiment configuration.
# - "infer_*" keys drive the short training run whose outputs feed SpareInference.
# - lr / weight_decay / momentum / num_epochs drive the final SpareTrain run.
# - batch_size is shared by the trainers and the evaluators.
params = dict(
    # hardware, reproducibility and data location
    gpu=0,
    seed=0,
    root_dir="/data",
    # shared batch size and final-training settings
    batch_size=32,
    num_epochs=20,
    lr=1e-3,
    weight_decay=1e-2,
    momentum=0.9,
    pretrained=False,
    # group-inference run and SpareInference settings
    infer_lr=1e-3,
    infer_weight_decay=1e-2,
    infer_momentum=0.9,
    infer_num_epochs=1,
    high_sampling_power=2,
)

In [None]:
# Use the configured GPU index when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{params['gpu']}")
else:
    device = torch.device("cpu")

# Seed once, before any model is built in the cells below.
set_seed(params["seed"])

In [None]:
# Each class groups a pair of MNIST digits.
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

# Only the training split is given an explicit spurious-correlation strength;
# the val/test splits fall back to SpuCoMNIST's default for that argument.
TRAIN_SPURIOUS_CORRELATION = 0.995
mnist_root = f"{params['root_dir']}/mnist/"


def load_spuco_mnist(split, **extra_kwargs):
    """Build and initialise one SpuCoMNIST split.

    All splits share this notebook's root directory, class grouping and
    spurious-feature difficulty; ``extra_kwargs`` are forwarded unchanged to
    ``SpuCoMNIST`` for split-specific options.

    Args:
        split: Split name passed to ``SpuCoMNIST`` ("train", "val" or "test").
        **extra_kwargs: Additional keyword arguments for ``SpuCoMNIST``.

    Returns:
        The initialised dataset.
    """
    dataset = SpuCoMNIST(
        root=mnist_root,
        spurious_feature_difficulty=difficulty,
        classes=classes,
        split=split,
        **extra_kwargs,
    )
    dataset.initialize()
    return dataset


trainset = load_spuco_mnist("train", spurious_correlation_strength=TRAIN_SPURIOUS_CORRELATION)
valset = load_spuco_mnist("val")
testset = load_spuco_mnist("test")


In [None]:
# Short training run, configured by the params["infer_*"] keys. Its training-set
# outputs are read by trainer.get_trainset_outputs() in the next cell for SPARE
# group inference.
sample_shape = trainset[0][0].shape
model = model_factory(
    "lenet",
    sample_shape,
    trainset.num_classes,
    pretrained=params["pretrained"],
).to(device)

infer_optimizer = SGD(
    model.parameters(),
    lr=params["infer_lr"],
    weight_decay=params["infer_weight_decay"],
    momentum=params["infer_momentum"],
)
trainer = Trainer(
    trainset=trainset,
    model=model,
    batch_size=params["batch_size"],
    optimizer=infer_optimizer,
    device=device,
    verbose=True,
)
trainer.train(num_epochs=params["infer_num_epochs"])

In [None]:
# Run SPARE group inference on the outputs of the briefly-trained model.
logits = trainer.get_trainset_outputs()
# SpareInference receives softmax probabilities (not raw logits) through its
# `logits` argument.
predictions = torch.nn.functional.softmax(logits, dim=1)

spare_infer = SpareInference(
    logits=predictions,
    class_labels=trainset.labels,
    device=device,
    max_clusters=5,
    high_sampling_power=params["high_sampling_power"],
    verbose=True
)

group_partition = spare_infer.infer_groups()
sampling_powers = spare_infer.sampling_powers

# Show each inferred group's size instead of dumping its full (potentially very
# long) list of sample indices.
print({key: len(indices) for key, indices in group_partition.items()})
print("Sampling powers:", sampling_powers)

# Cross-tabulate inferred groups against the ground-truth groups.
# Each inferred group is converted to a set once, so every membership test is
# O(1) instead of a linear scan of a list. Indices are normalised with int() so
# Python ints and numpy/torch scalar indices hash and compare consistently.
inferred_index_sets = {
    key: {int(idx) for idx in indices} for key, indices in group_partition.items()
}
for key in sorted(group_partition.keys()):
    inferred = inferred_index_sets[key]
    for true_key in sorted(trainset.group_partition.keys()):
        overlap = sum(1 for x in trainset.group_partition[true_key] if int(x) in inferred)
        print(f"Inferred group: {key}, true group: {true_key}, size: {overlap}")


In [None]:
# Validation evaluator handed to SpareTrain; presumably used to select
# `spare_train.best_model` (read in the next cell) — confirm against spuco docs.
valid_evaluator = Evaluator(
    testset=valset,
    group_partition=valset.group_partition,
    group_weights=valset.group_weights,
    batch_size=params["batch_size"],
    model=model,
    device=device,
    verbose=True
)

# Fixed sampling power, one entry per class in `classes`.
# NOTE(review): this overrides the `sampling_powers` computed by SpareInference
# earlier in the notebook; pass `sampling_powers` instead to use the inferred ones.
TRAIN_SAMPLING_POWER = 20
train_sampling_powers = [TRAIN_SAMPLING_POWER] * len(classes)

# `model` is the same network trained for group inference; it is not
# re-initialised here, so SpareTrain continues from those weights.
spare_train = SpareTrain(
    model=model,
    num_epochs=params["num_epochs"],
    trainset=trainset,
    group_partition=group_partition,
    sampling_powers=train_sampling_powers,
    batch_size=params["batch_size"],
    optimizer=SGD(model.parameters(), lr=params["lr"], weight_decay=params["weight_decay"], momentum=params["momentum"]),
    device=device,
    val_evaluator=valid_evaluator,
    verbose=True
)
spare_train.train()

In [None]:
# Evaluate the best SpareTrain checkpoint on the test split.
# Groups come from the test set, but the weights are the *training* set's group
# weights (not testset.group_weights). Presumably this makes "Average Accuracy"
# reflect the training distribution — confirm against the spuco Evaluator docs.
evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=params["batch_size"],
    model=spare_train.best_model,
    device=device,
    verbose=True,
)
evaluator.evaluate()

print("Final Results:")
# worst_group_accuracy is indexed at [1]; element 0 presumably identifies the group.
worst_group_acc = evaluator.worst_group_accuracy[1]
print(f"Worst Group Accuracy: {worst_group_acc}")
average_acc = evaluator.average_accuracy
print(f"Average Accuracy: {average_acc}")