In [None]:
# Install the package located three directories above this notebook into the
# kernel's environment (presumably the spuco repository root -- confirm).
%pip install ../../../

In [None]:
import torch 

# Preferred GPU index for this notebook. The original hard-coded "cuda:7",
# which only exists on hosts with at least 8 GPUs; fall back gracefully so the
# notebook still runs on smaller machines and on CPU-only hosts.
PREFERRED_GPU_INDEX = 7

if not torch.cuda.is_available():
    # No CUDA runtime / GPU visible: run everything on CPU.
    device = torch.device("cpu")
elif PREFERRED_GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
else:
    # CUDA is available but the preferred index is not: use the first GPU.
    device = torch.device("cuda:0")

In [None]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T

# Group the ten digits into five two-digit classes: [0, 1], [2, 3], ..., [8, 9].
classes = [[digit, digit + 1] for digit in range(0, 10, 2)]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_EASY

# Constructor arguments shared by the train and test splits.
_split_kwargs = dict(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
)

# Training split, built with a spurious correlation strength of 0.99.
trainset = SpuCoMNIST(
    split="train",
    spurious_correlation_strength=0.99,
    **_split_kwargs,
)
trainset.initialize()

# Test split; no spurious_correlation_strength is passed, so the library
# default applies.
testset = SpuCoMNIST(split="test", **_split_kwargs)
testset.initialize()

In [None]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T


# Validation split, built from the `classes` and `difficulty` values defined
# in an earlier cell (that cell must have been run first).
valset = SpuCoMNIST(
    split="val",
    root="/data/mnist/",
    classes=classes,
    spurious_feature_difficulty=difficulty,
)
valset.initialize()

In [None]:
# Preview training item 1000: convert its first element to a PIL image, then
# show it resized to 28x28 (the resized image is the cell's displayed output).
preview_image = T.ToPILImage()(trainset[1000][0])
preview_image.resize((28, 28))

In [None]:
from torch.optim import SGD

from spuco.invariant_train import ERM
from spuco.models import model_factory

# Baseline MLP trained with plain ERM on the training split.
# The output layer is sized by `trainset.num_classes`, matching the other
# class-prediction model built later in this notebook; `num_spurious` is only
# the right size for the model that predicts the spurious attribute (SSA).
# NOTE(review): assumes ERM fits the dataset's class targets -- confirm against
# the spuco ERM trainer.
model = model_factory("mlp", trainset[0][0].shape, trainset.num_classes).to(device)

erm = ERM(
    model=model,
    num_epochs=1,
    trainset=trainset,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    device=device,
    verbose=True,
)
erm.train()

In [None]:
from collections import Counter 

# Tally how often each value in `valset.spurious` occurs and display the
# counts, most frequent first.
spurious_counts = Counter(valset.spurious)
spurious_counts.most_common()

In [None]:
from spuco.group_inference import SSA
from spuco.utils import SpuriousTargetDataset

# Validation split wrapped so that its targets are the values in
# `valset.spurious`.
labeled_val = SpuriousTargetDataset(valset, valset.spurious)

# LeNet whose output size is the number of spurious attribute values.
ssa_model = model_factory(
    "lenet", trainset[0][0].shape, trainset.num_spurious
).to(device)

# Configure SSA group inference: the training split is the unlabeled data and
# the wrapped validation split provides the spurious labels.
ssa = SSA(
    spurious_unlabeled_dataset=trainset,
    spurious_labeled_dataset=labeled_val,
    model=ssa_model,
    labeled_valset_size=0.1,
    num_iters=1000,
    tau_g_min=0.3,
    device=device,
    verbose=True,
)

In [None]:
# Run group inference, then print each inferred group's key and its number of
# members, in sorted key order.
group_partition = ssa.infer_groups()
for key, members in sorted(group_partition.items(), key=lambda item: item[0]):
    print(key, len(members))

In [None]:
from spuco.evaluate import Evaluator 

# Evaluate the current `model` on the training split, grouped by the partition
# returned by SSA, using the training split's group weights.
evaluator = Evaluator(
    model=model,
    testset=trainset,
    group_partition=group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)
evaluator.evaluate()

In [None]:
from torch.optim import SGD
from spuco.invariant_train import GroupBalanceBatchERM, ClassBalanceBatchERM
from spuco.models import model_factory 

# Fresh LeNet classifier with one output per class; this rebinds `model`.
model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)

# SGD with Nesterov momentum and weight decay for the group-balanced run.
gb_optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
    momentum=0.9,
    nesterov=True,
)

# Train for 5 epochs with GroupBalanceBatchERM, using the group partition
# inferred earlier.
group_balance_erm = GroupBalanceBatchERM(
    model=model,
    trainset=trainset,
    group_partition=group_partition,
    optimizer=gb_optimizer,
    num_epochs=5,
    batch_size=64,
    device=device,
    verbose=True,
)
group_balance_erm.train()

In [None]:
from spuco.evaluate import Evaluator

# Evaluate the current `model` on the test split using the test split's own
# group partition; group weights are taken from the training split.
evaluator = Evaluator(
    model=model,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)
evaluator.evaluate()

In [None]:
# Display the evaluator's `worst_group_accuracy` attribute
# (presumably populated by the `evaluate()` call in the previous cell -- confirm).
evaluator.worst_group_accuracy

In [None]:
# Display the evaluator's `average_accuracy` attribute
# (presumably populated by the earlier `evaluate()` call -- confirm).
evaluator.average_accuracy

In [None]:
# Call the evaluator's `evaluate_spurious_task()` and display its return value.
# NOTE(review): presumably measures how well the model predicts the spurious
# attribute rather than the class label -- confirm against the spuco docs.
evaluator.evaluate_spurious_task()