In [None]:
# Install the package located two directories above this notebook into the
# kernel's environment (presumably the SpuCo repository root -- confirm).
%pip install ../../

In [None]:
import torch 
from spuco.utils import set_seed
from wilds import get_dataset
import torchvision.transforms as transforms
from spuco.datasets import WILDSDatasetWrapper
from spuco.datasets import GroupLabeledDatasetWrapper
import numpy as np

from spuco.models import model_factory 
from spuco.evaluate import Evaluator

import pickle
from spuco.invariant_train import ERM 
from torch.optim import SGD

In [None]:
# Seed the run through the SpuCo `set_seed` helper.
seed = 0
set_seed(seed)

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU instead of failing at the
# first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Fetch the complete WILDS "waterbirds" dataset rooted at /home/data;
# `download=True` asks WILDS to download it when needed.
dataset = get_dataset(
    dataset="waterbirds",
    root_dir='/home/data',
    download=True,
)

# Spatial size (height, width) that every image is brought to.
target_resolution = (224, 224)

# Training-time pipeline: random resized crop + random horizontal flip for
# augmentation, then tensor conversion and per-channel normalization
# (the mean/std values are the commonly used ImageNet statistics).
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(
        target_resolution,
        scale=(0.7, 1.0),
        ratio=(0.75, 1.3333333333333333),
        # Enum form of the legacy integer code 2 (bilinear); passing raw
        # integers for `interpolation` is deprecated in torchvision.
        interpolation=transforms.InterpolationMode.BILINEAR,
    ),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Evaluation-time pipeline (no randomness): resize to `scale` times the target
# resolution, center-crop back to the target resolution, then convert to a
# tensor and normalize with the same statistics as the training pipeline.
scale = 256.0 / 224.0
resized_shape = (
    int(target_resolution[0] * scale),
    int(target_resolution[1] * scale),
)
transform_test = transforms.Compose([
    transforms.Resize(resized_shape),
    transforms.CenterCrop(target_resolution),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Pull the three splits out of the full dataset. Only the training split gets
# the augmenting transform; test and validation share the deterministic one.
train_data = dataset.get_subset("train", transform=transform_train)
test_data = dataset.get_subset("test", transform=transform_test)
val_data = dataset.get_subset("val", transform=transform_test)

In [None]:
# Wrap each WILDS split for SpuCo, using the "background" metadata field as the
# spurious label. The options are identical for all three splits.
wrapper_options = dict(metadata_spurious_label="background", verbose=True)

trainset = WILDSDatasetWrapper(dataset=train_data, **wrapper_options)
testset = WILDSDatasetWrapper(dataset=test_data, **wrapper_options)
valset = WILDSDatasetWrapper(dataset=val_data, **wrapper_options)

In [None]:
# Build a pretrained ResNet-50 with 2 outputs. The input shape is read from the
# first element of the first training sample (presumably the image tensor --
# confirm against WILDSDatasetWrapper), then the model is moved to `device`.
input_shape = trainset[0][0].shape
model = model_factory("resnet50", input_shape, 2, pretrained=True)
model = model.to(device)

# Evaluator over the validation split, handed to the trainer below.
# Note that the group weights come from the *training* set, while the group
# partition comes from the validation set itself.
val_evaluator = Evaluator(
    model=model,
    testset=valset,
    group_partition=valset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=False,
)

# SGD with momentum and weight decay over all model parameters.
optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-3,
)

# ERM trainer: 300 epochs, batches of 128, with the validation evaluator
# attached.
erm = ERM(
    model=model,
    trainset=trainset,
    val_evaluator=val_evaluator,
    optimizer=optimizer,
    num_epochs=300,
    batch_size=128,
    device=device,
    verbose=True,
)
erm.train()

In [None]:
# Score the trainer's best model on the held-out test split. As with the
# validation evaluator, group weights are taken from the training set.
evaluator = Evaluator(
    model=erm.best_model,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=False,
)
evaluator.evaluate()

In [None]:
# Persist the weights of the checkpoint that was actually evaluated on the test
# split (`erm.best_model`). `model` is the object trained in place and may hold
# different (last-epoch) weights -- NOTE(review): confirm that `best_model` is
# a separate copy in the installed SpuCo version.
# Replace the placeholder path with the real destination before running.
torch.save(erm.best_model.state_dict(), 'path-to-save-the-model')