In [None]:
# Pin the notebook to GPU 0 and cap OpenMP at 4 threads. This cell runs before
# `import torch` in the next cell, so both variables are already set when torch is imported.
%env CUDA_VISIBLE_DEVICES=0
%env OMP_NUM_THREADS=4

In [None]:
import torch
from types import SimpleNamespace

import numpy as np
import os
import sys
import tqdm

# import from parent directory
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import models
import utils

from matplotlib import pyplot as plt
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
### TODO: Modify these to your own paths ###
# Dataset root directories.
wb_data_path = '/datasets/waterbirds_official'
celeba_data_path = '/datasets/CelebA'

# Root directory under which the trained checkpoints live.
CKPTDIR = '/home/shikai_q/spurious_old/checkpoints'

# One `final_checkpoint.pt` per run, listed as sub-directories of CKPTDIR/wb and CKPTDIR/celeba.
# NOTE(review): the first WB run has no timestamped sub-directory, unlike every other run — confirm.
WB_CKPTS = [
    f"{CKPTDIR}/wb/{run_dir}/final_checkpoint.pt"
    for run_dir in (
        "80_1",
        "80_21/20230115_131136_mVno",
        "80_98/20230115_131150_CThg",
    )
]
CelebA_CKPTS = [
    f"{CKPTDIR}/celeba/{run_dir}/final_checkpoint.pt"
    for run_dir in (
        "80_1/20230106_144116_tpaT",
        "80_21/20230106_144132_glNH",
        "80_98/20230106_144147_AVRm",
    )
]
# The remaining code needs zero modification

In [None]:
def _embedding_cache_path(ckpt, data_transform, num_augs):
    """Path of the embedding cache stored next to checkpoint `ckpt`.

    Only the file extension of `ckpt` is replaced, e.g.
    ``.../final_checkpoint.pt`` -> ``.../final_checkpoint_<transform><num_augs>_embeddings.pt``.
    Because a suffix is always added, the cache path can never equal the checkpoint path,
    so writing the cache can never clobber the model weights.
    """
    root, _ = os.path.splitext(ckpt)
    return f"{root}_{data_transform}{num_augs}_embeddings.pt"


def _print_group_counts(groups):
    """Print each group id present in `groups` (ascending) with its count and share in %."""
    group_ids, counts = torch.unique(groups, return_counts=True)
    total = counts.sum().item()
    for g, count in zip(group_ids.tolist(), counts.tolist()):
        print(f'Group {g}: {count} ({(count / total * 100):.1f}%)')


def extract_embeddings(ckpt, args, num_augs, load_emb=True, model_cls=models.imagenet_resnet50_pretrained):
    """Extract (or load cached) feature embeddings and ERM predictions for a checkpoint.

    Builds ``model_cls(n_classes)``, loads the state dict stored at `ckpt`, detaches its final
    ``fc`` layer (kept as the classifier) and runs the remaining network over the train loader
    (`num_augs` passes, called the "reweighting" split here) and over the 'val' and 'test'
    holdout loaders (one pass each). Results are cached next to the checkpoint and reused when
    `load_emb` is True and the cache exists. Per-group counts and accuracies of the original
    classifier are printed for each split.

    Args:
        ckpt: path to a state dict loadable by ``model_cls(n_classes)``.
        args: data config forwarded to ``utils.get_data`` / ``utils.get_train_group_ratios``;
            ``args.data_transform`` and `num_augs` are part of the cache file name.
        num_augs: number of passes over the train loader.
        load_emb: reuse the cache if present; if False, always recompute and overwrite it.
        model_cls: model constructor taking ``n_classes``; the model's head must be ``.fc``.

    Returns:
        dict with keys ``e``/``y``/``pred``/``g`` (reweighting split), their ``val_*`` and
        ``test_*`` counterparts, ``w0``/``b0`` (weight and bias of the original ``fc``) and
        ``'trian_group_ratios'`` (key spelling kept for compatibility with cached files and
        ``train``).
    """
    train_group_ratios = np.array(utils.get_train_group_ratios(args))
    train_loader, holdout_loaders = utils.get_data(args, finetune_on_val=args.finetune_on_val)
    n_classes = train_loader.dataset.n_classes

    model = model_cls(n_classes)
    model.load_state_dict(torch.load(ckpt))
    model.cuda()
    model.eval()

    # Keep the original linear head aside and turn the model into a pure feature extractor.
    classifier = model.fc
    w0 = classifier.weight.data.clone()
    b0 = classifier.bias.data.clone()
    model.fc = torch.nn.Identity()

    emb_path = _embedding_cache_path(ckpt, args.data_transform, num_augs)
    if not load_emb or not os.path.exists(emb_path):
        reweighting_embeddings, reweighting_predictions, reweighting_y, reweighting_groups = get_embeddings_predictions(model, classifier, train_loader, num_augs)
        val_embeddings, val_predictions, val_y, val_groups = get_embeddings_predictions(model, classifier, holdout_loaders['val'], num_augs=1)
        test_embeddings, test_predictions, test_y, test_groups = get_embeddings_predictions(model, classifier, holdout_loaders['test'], num_augs=1)
        emb_dict = dict(
            e=reweighting_embeddings, y=reweighting_y, pred=reweighting_predictions, g=reweighting_groups,
            test_e=test_embeddings, test_pred=test_predictions, test_y=test_y, test_g=test_groups,
            val_e=val_embeddings, val_pred=val_predictions, val_y=val_y, val_g=val_groups,
            w0=w0, b0=b0, trian_group_ratios=train_group_ratios,
        )
        torch.save(emb_dict, emb_path)
        print(f'Saved embeddings to {emb_path}')
    else:
        print(f'Loading embeddings from {emb_path}')
        emb_dict = torch.load(emb_path)

    print('*** Reweighting Counts + Group distribution: ***')
    _print_group_counts(emb_dict['g'])

    print('*** Reweighting Acc + Group distribution: ***')
    print_accs(emb_dict['pred'], emb_dict['y'], emb_dict['g'])

    print('*** Test Acc + Group distribution: ***')
    print_accs(emb_dict['test_pred'], emb_dict['test_y'], emb_dict['test_g'])

    print('*** Val Acc + Group distribution: ***')
    print_accs(emb_dict['val_pred'], emb_dict['val_y'], emb_dict['val_g'])

    # The ratios computed from `args` take precedence over any copy stored in the cache.
    emb_dict['trian_group_ratios'] = train_group_ratios
    return emb_dict

def get_embeddings_predictions(feature_extractor, classifier, loader, num_augs):
    """Run `feature_extractor` then `classifier` over `loader`, `num_augs` times, without grad.

    Each batch is unpacked as ``(x, y, g, *extra)``. Only `x` is moved to CUDA, and the extra
    fields are ignored. Predictions are the argmax of the classifier output along dim 1.

    Returns:
        ``(embeddings, predictions, labels, groups)``. Each is concatenated along dim 0 on the
        CPU, with the `num_augs` passes appended one after another.
    """
    emb_chunks, pred_chunks, y_chunks, g_chunks = [], [], [], []
    with torch.no_grad():
        for _ in range(num_augs):
            for x, y_true, g, *_rest in tqdm.tqdm(loader):
                feats = feature_extractor(x.cuda())
                emb_chunks.append(feats.cpu())
                pred_chunks.append(classifier(feats).argmax(dim=1).cpu())
                y_chunks.append(y_true.cpu())
                g_chunks.append(g.cpu())
    return tuple(torch.cat(chunks, dim=0) for chunks in (emb_chunks, pred_chunks, y_chunks, g_chunks))

def get_accs(test_predictions, test_y, test_groups):
    """Accuracy of `test_predictions` vs `test_y` for every group id 0..max(test_groups).

    Returns a list of Python floats indexed by group id. A group id with no samples yields
    NaN (the mean of an empty selection).
    """
    correct = (test_predictions == test_y).float()
    n_groups = int(test_groups.max()) + 1
    return [correct[test_groups == grp].mean().item() for grp in range(n_groups)]

def print_accs(test_predictions, test_y, test_groups):
    """Print the overall accuracy, then each group's accuracy and its share of the samples.

    Groups are enumerated 0..max(test_groups) via `get_accs`. A group id with no samples is
    shown with a 0.0% share.
    """
    overall = (test_predictions == test_y).float().mean()
    per_group = get_accs(test_predictions, test_y, test_groups)
    present, sizes = torch.unique(test_groups, return_counts=True)
    counts = dict(zip(present.tolist(), sizes.tolist()))
    total = sum(counts.values())
    print(f"Avg: {overall:.3f}")
    for gid, group_acc in enumerate(per_group):
        if gid not in counts:
            print(f"Group {gid}: {group_acc:.2f} (0.0%)")
            continue
        print(f"Group {gid}: {group_acc:.2f} ({(counts[gid] / total * 100):.1f}%)")

In [None]:
def wxe_fn(logits, y, weights):
    """Weighted cross-entropy: ``sum_i weights[i] * CE(logits[i], y[i])``.

    No normalisation is applied here. Callers pass weights that are already normalised
    (see `compute_afr_weights`).
    """
    per_sample = torch.nn.functional.cross_entropy(logits, y, reduction='none')
    return torch.sum(per_sample * weights)

def compute_afr_weights(erm_logits, class_label, gamma, balance_classes):
    """AFR per-sample weights computed from the ERM model's logits (no gradient tracking).

    Each sample gets weight ``exp(-gamma * p_true)``, where ``p_true`` is the softmax
    probability of its own label. Samples the ERM model is confident about are therefore
    down-weighted. With `balance_classes`, every weight is further divided by the size of its
    class. Each class's total weight then equals its mean AFR weight, independent of class size.
    This matches the old "scale class y by count[0] / count[y]" after normalisation. Unlike
    that rule, it also works when class 0 is absent or labels are not contiguous. The weights
    are normalised to sum to one.

    Args:
        erm_logits: (n_samples, n_classes) float tensor.
        class_label: (n_samples,) integer labels in ``[0, n_classes)``.
        gamma: float. ``gamma = 0`` gives uniform weights before class balancing.
        balance_classes: whether to apply the per-class size correction.

    Returns:
        (n_samples,) tensor of non-negative weights summing to one.
    """
    with torch.no_grad():
        p_true = erm_logits.softmax(-1).gather(-1, class_label.unsqueeze(-1)).squeeze(-1)
        weights = (-gamma * p_true).exp()
        if balance_classes:
            # Only labels that actually occur are indexed, so no divisor here is zero.
            class_count = torch.bincount(class_label, minlength=erm_logits.shape[-1])
            weights = weights / class_count[class_label].to(weights.dtype)
        weights = weights / weights.sum()
    return weights

In [None]:
def train(emb_dict, num_epochs, gamma, reg_coeff, lr=1e-2, balance_classes=False, early_stop='wga', plot=True, silent=False, group_uniform=False, opt='sgd'):
    """Fit the AFR correction head on cached embeddings and pick an early-stopping epoch on val.

    The model is the frozen ERM head ``(w0, b0)`` plus a trainable ``torch.nn.Linear`` whose
    output is added to the ERM logits. It is trained full-batch for `num_epochs` steps on the
    reweighting split ``emb_dict['e']``. The loss is a weighted cross-entropy with AFR weights
    computed once from the ERM logits, plus ``reg_coeff * (||W||^2 + ||b||^2)`` on the
    trainable layer. Gradients are clipped to norm 1. After every step the per-group
    accuracies on the reweighting, val and test splits are recorded.

    A group absent from a split (possible on a sub-sampled val set) has NaN accuracy. Worst-
    and mean-group metrics ignore such groups (``np.nanmin`` / ``np.nanmean``), so a NaN can
    no longer decide the worst group or win early stopping.

    Args:
        emb_dict: output of `extract_embeddings`.
        num_epochs: number of full-batch optimiser steps (>= 1).
        gamma, balance_classes: forwarded to `compute_afr_weights`.
        reg_coeff: coefficient of the L2 penalty on the trainable layer.
        lr, opt: learning rate and optimiser ('sgd' without momentum, or 'adam').
        early_stop: 'wga' (max val worst-group acc), 'wga@max_val_wa' (val WGA at the epoch of
            max val weighted acc), 'mga' (max val mean group acc) or 'none' (last epoch).
        plot: print a summary and draw training curves.
        silent: disable the progress bar.
        group_uniform: replace the AFR weights by weights giving each group equal total mass.

    Returns:
        ``(early_stop_metric, test WGA, test mean acc)``, the last two at the early-stopping
        epoch. "Mean acc" weights per-group accuracy by the training group ratios.

    Raises:
        ValueError: if ``num_epochs < 1`` or `early_stop` is unknown (checked before training).
        NotImplementedError: if `opt` is unknown.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
    if early_stop not in ('wga', 'wga@max_val_wa', 'mga', 'none'):
        raise ValueError(f"Unknown early_stop: {early_stop}")

    reweighting_embeddings = emb_dict['e'].cuda()
    reweighting_y = emb_dict['y'].cuda()
    reweighting_groups = emb_dict['g'].cuda()
    test_embeddings = emb_dict['test_e'].cuda()
    test_y = emb_dict['test_y'].cuda()
    test_groups = emb_dict['test_g'].cuda()
    val_embeddings = emb_dict['val_e'].cuda()
    val_y = emb_dict['val_y'].cuda()
    val_groups = emb_dict['val_g'].cuda()
    w0 = emb_dict['w0'].cuda()
    b0 = emb_dict['b0'].cuda()
    # The dict key keeps the 'trian' spelling used by extract_embeddings and cached files.
    train_group_ratios = np.array(emb_dict['trian_group_ratios'])
    num_groups = torch.unique(test_groups).numel()

    class Model(torch.nn.Module):
        """Frozen ERM head (w0, b0) plus a trainable additive linear correction."""

        def __init__(self, w0, b0):
            super().__init__()
            # Plain tensor attributes (not Parameters), so the optimiser never updates them.
            self.w0 = w0
            self.b0 = b0
            self.linear = torch.nn.Linear(w0.shape[1], w0.shape[0], bias=True)

        def forward(self, x):
            return x @ self.w0.t() + self.b0 + self.linear(x)

    model = Model(w0, b0)
    model.cuda()

    criterion = wxe_fn
    if opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0., momentum=0.)
    elif opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        raise NotImplementedError

    # Per-epoch history (one entry per optimiser step).
    mean_accs, test_mean_accs = [], []
    mean_group_accs, val_mean_group_accs = [], []
    wgas, test_wgas, val_wgas = [], [], []
    losses, val_losses, regs = [], [], []
    weights_by_groups = []
    acc_by_groups, val_acc_by_groups, test_accs_by_groups = [], [], []
    weighted_acc, val_weighted_acc = [], []

    # AFR weights come from the frozen ERM head and stay fixed for the whole run.
    initial_logits = reweighting_embeddings @ w0.t() + b0
    weights = compute_afr_weights(initial_logits, reweighting_y, gamma, balance_classes)
    if group_uniform:
        # Every group gets total weight 1/n_groups, split evenly among its samples.
        group_counts = torch.bincount(reweighting_groups)
        for g in range(len(group_counts)):
            weights[reweighting_groups == g] = 1 / group_counts[g] / len(group_counts)
    # The weights never change, so their per-group totals are computed once.
    weights_by_group = [weights[reweighting_groups == g].sum().item() for g in range(num_groups)]

    for _ in (pbar := tqdm.tqdm(range(num_epochs), disable=silent)):
        optimizer.zero_grad()
        logits = model(reweighting_embeddings)
        loss = criterion(logits, reweighting_y, weights)
        reg = model.linear.weight.pow(2).sum() + model.linear.bias.pow(2).sum()
        loss += reg_coeff * reg

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()

        # Bookkeeping. NOTE: `logits` (reweighting split) were computed before this step,
        # while the val/test logits below reflect the updated model.
        with torch.no_grad():
            val_logits = model(val_embeddings)
            test_logits = model(test_embeddings)
            val_weights = compute_afr_weights(val_logits, val_y, gamma, balance_classes)
            accs = get_accs(torch.argmax(logits, -1), reweighting_y, reweighting_groups)
            val_accs = get_accs(torch.argmax(val_logits, -1), val_y, val_groups)
            test_accs = get_accs(torch.argmax(test_logits, -1), test_y, test_groups)

            weights_by_groups.append(weights_by_group)
            acc_by_groups.append(accs)
            val_acc_by_groups.append(val_accs)
            test_accs_by_groups.append(test_accs)

            # NaN entries mark groups absent from a split; the nan-aware reductions skip them.
            wga = float(np.nanmin(accs))
            val_wga = float(np.nanmin(val_accs))
            test_wga = float(np.nanmin(test_accs))
            wgas.append(wga)
            val_wgas.append(val_wga)
            test_wgas.append(test_wga)
            mean_group_accs.append(float(np.nanmean(accs)))
            val_mean_group_accs.append(float(np.nanmean(val_accs)))
            # Mean acc = per-group accuracy weighted by the training group ratios.
            mean_accs.append((np.array(accs) * train_group_ratios).sum())
            test_mean_accs.append((np.array(test_accs) * train_group_ratios).sum())

            losses.append(loss.item())
            regs.append(reg.item())
            weighted_acc.append((((torch.argmax(logits, -1) == reweighting_y).float() * weights).sum() / weights.sum()).item())
            val_weighted_acc.append((((torch.argmax(val_logits, -1) == val_y).float() * val_weights).sum() / val_weights.sum()).item())
            val_losses.append(criterion(val_logits, val_y, val_weights).item())
            pbar.set_description(f"Loss: {loss.item():.5f}, WGA: {wga:.3f}, TWGA: {test_wga:.3f}")

    if early_stop == 'wga@max_val_wa':
        earlystop_epoch = int(np.argmax(val_weighted_acc))
        early_stop_metric = val_wgas[earlystop_epoch]
    elif early_stop == 'mga':
        earlystop_epoch = int(np.argmax(val_mean_group_accs))
        early_stop_metric = val_mean_group_accs[earlystop_epoch]
    elif early_stop == 'none':
        earlystop_epoch = num_epochs - 1
        early_stop_metric = val_wgas[-1]
    else:  # 'wga' (validated at the top)
        earlystop_epoch = int(np.argmax(val_wgas))
        early_stop_metric = val_wgas[earlystop_epoch]
    earlystop_twga = test_wgas[earlystop_epoch]
    earlystop_tmacc = test_mean_accs[earlystop_epoch]
    final_twga = test_wgas[-1]
    final_tmacc = test_mean_accs[-1]

    if plot:
        print(f"Final test WGA: {final_twga:.3f}")
        print(f"Final test mean acc: {final_tmacc:.3f}")
        print(f"Early stopping test WGA: {earlystop_twga:.3f} at epoch {earlystop_epoch}")
        print(f"Early stopping test mean acc: {earlystop_tmacc:.3f} at epoch {earlystop_epoch}")

        # Row 1: loss, worst-group acc, mean acc, weighted acc (4 panels).
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
        ax1.plot(losses, label="CWXE")
        ax1.plot(regs, label="Reg")
        ax1.grid()
        ax1.legend(loc="lower right")
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Loss")

        ax2.plot(test_wgas, label="Test")
        ax2.plot(wgas, label="Reweighting")
        ax2.plot(val_wgas, label="Validation")
        ax2.grid()
        ax2.legend(loc="lower right")
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("WGA")

        ax3.plot(test_mean_accs, label="Test")
        ax3.plot(mean_accs, label="Reweighting")
        ax3.grid()
        ax3.legend(loc="lower right")
        ax3.set_xlabel("Epochs")
        ax3.set_ylabel("Mean Acc")

        ax4.plot(weighted_acc, label="Reweighting")
        ax4.plot(val_weighted_acc, label="Validation")
        ax4.grid()
        ax4.set_xlabel("Epochs")
        ax4.set_ylabel("Weighted Acc")
        ax4.legend(loc="lower right")
        plt.show()

        # Row 2: total weight per group, and per-group accuracy
        # (solid = reweighting split, dotted = test).
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        ax1.plot(weights_by_groups, label=[f'G{g+1}' for g in range(num_groups)])
        ax1.grid()
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Weights")
        ax1.legend(loc="lower right")

        for g in range(num_groups):
            ax2.plot([acc[g] for acc in acc_by_groups], label=f"G{g+1}", color=f"C{g}")
            ax2.plot([acc[g] for acc in test_accs_by_groups], color=f"C{g}", linestyle=":")
        ax2.grid()
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("Acc")
        ax2.legend(loc="lower right")
        plt.show()

    return early_stop_metric, earlystop_twga, earlystop_tmacc

In [None]:
def hyper_sampler(sweep_config, random=True):
    """Yield ``(lr, gamma, reg_coeff, balance_classes)`` tuples drawn from `sweep_config`.

    ``random=True``: an endless stream of independent draws. Each draw makes one
    ``np.random.choice`` call per hyper-parameter, in the order lr, gamma, reg_coeff,
    balance_classes, so the sequence follows numpy's global RNG state.
    ``random=False``: every combination of the ``*_choices`` lists exactly once. lr varies
    slowest and balance_classes fastest.
    """
    keys = ("lr_choices", "gamma_choices", "reg_coeff_choices", "balance_classes_choices")
    if not random:
        yield from (
            (lr, gamma, reg_coeff, balance)
            for lr in sweep_config[keys[0]]
            for gamma in sweep_config[keys[1]]
            for reg_coeff in sweep_config[keys[2]]
            for balance in sweep_config[keys[3]]
        )
        return
    while True:
        yield tuple(np.random.choice(sweep_config[key]) for key in keys)

def run_training(emb_dict, num_epochs, lr, gamma, reg_coeff, balance_classes, early_stop):
    """Train one sweep configuration quietly (no plots, no progress bar).

    Returns ``(train(...) result, (lr, gamma, reg_coeff, balance_classes))``, so the caller
    gets the hyper-parameters back together with the metrics.
    """
    hyperparams = (lr, gamma, reg_coeff, balance_classes)
    result = train(
        emb_dict,
        num_epochs=num_epochs,
        gamma=gamma,
        reg_coeff=reg_coeff,
        lr=lr,
        balance_classes=balance_classes,
        early_stop=early_stop,
        plot=False,
        silent=True,
    )
    return result, hyperparams

In [None]:
def sweep(ckpts, args, num_augs, sweep_config, load_emb, skip_if_done, seed=0):
    """Tune AFR hyper-parameters on the validation set for each checkpoint and summarise.

    For every checkpoint:
    - load or extract its embeddings;
    - optionally sub-sample the val split to a fraction ``sweep_config['val_prop']``;
    - train up to ``sweep_config['n_trials']`` configurations drawn from `hyper_sampler`.

    The configuration with the highest early-stopping metric on val wins, and its test WGA and
    test mean accuracy are recorded. Results are saved to ``<ckpt dir>/<sweep_name>.pt``. When
    `skip_if_done` is set and that file exists, it is reused without extracting embeddings or
    training.

    Args:
        ckpts: checkpoint paths.
        args, num_augs, load_emb: forwarded to `extract_embeddings`.
        sweep_config: dict with keys n_trials, num_epochs, *_choices, early_stop, random,
            val_prop, sweep_name.
        skip_if_done: reuse existing sweep result files.
        seed: seeds torch/numpy/random before each checkpoint's sweep (makes the val
            subsample and trials reproducible).

    Returns:
        ``(mean, std)`` of the best test WGA, then ``(mean, std)`` of the best test mean
        accuracy, across `ckpts`.
    """
    import random

    best_twgas = []
    best_tmaccs = []
    for ckpt in ckpts:
        # Results for this checkpoint go next to it as <sweep_name>.pt.
        sweep_file = os.path.join(os.path.dirname(ckpt), sweep_config['sweep_name'] + '.pt')

        if skip_if_done and os.path.exists(sweep_file):
            saved = torch.load(sweep_file)
            twga = saved['best_twga']
            print(f"Skipping {ckpt}: TWGA {twga:.3f}")
            best_twgas.append(twga)
            best_tmaccs.append(saved['best_tmacc'])
            continue

        emb_dict = extract_embeddings(ckpt, args, num_augs, load_emb)

        # Fix all seeds per checkpoint.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        if sweep_config['val_prop'] < 1:
            n_val = emb_dict['val_e'].shape[0]
            # Keep at least one example so the val metrics stay defined.
            val_size = max(1, int(n_val * sweep_config['val_prop']))
            val_idx = np.random.choice(n_val, val_size, replace=False)
            # Build a new dict instead of mutating the one returned by extract_embeddings.
            emb_dict = dict(
                emb_dict,
                val_e=emb_dict['val_e'][val_idx],
                val_y=emb_dict['val_y'][val_idx],
                val_g=emb_dict['val_g'][val_idx],
            )
        print(f'Tuning hypers on {emb_dict["val_e"].shape[0]} val examples')

        # -inf so a trial is selected even when every val metric is 0.
        best_vwga = -float('inf')
        best_twga = 0
        best_tmacc = 0
        best_hyperparams = None
        top10_choices = []
        all_results = []

        sampler = hyper_sampler(sweep_config, random=sweep_config['random'])
        for _ in (pbar := tqdm.tqdm(range(sweep_config['n_trials']))):
            try:
                lr, gamma, reg_coeff, balance_classes = next(sampler)
            except StopIteration:
                break
            (val_wga, test_wga, test_macc), hyperparams = run_training(
                emb_dict, sweep_config['num_epochs'], lr, gamma, reg_coeff, balance_classes, sweep_config['early_stop'])
            all_results.append([gamma, reg_coeff, val_wga, test_wga])
            if val_wga > best_vwga:
                best_vwga, best_twga, best_tmacc, best_hyperparams = val_wga, test_wga, test_macc, hyperparams
                pbar.set_description(f"Test WGA: {test_wga:.3f}, Val WGA: {val_wga:.3f}, lr: {lr:.2g}, gamma: {gamma:.3f}, reg_coeff: {reg_coeff:.3f}, balance_classes: {balance_classes}")
            # Keep the 10 best trials by val metric, always sorted best-first. The sort is
            # stable, so earlier trials win ties.
            top10_choices = sorted(
                top10_choices + [(test_wga, val_wga, hyperparams)],
                key=lambda x: x[1],
                reverse=True,
            )[:10]

        # Save sweep results, indicating data_transform and num aug.
        torch.save({
            'best_twga': best_twga,
            'best_tmacc': best_tmacc,
            'best_hyperparams': best_hyperparams,
            'top10_choices': top10_choices,
            'sweep_config': sweep_config,
            'all_results': all_results,
        }, sweep_file)
        best_twgas.append(best_twga)
        best_tmaccs.append(best_tmacc)

    print(f"Mean best WGA: {np.mean(best_twgas):.3f}, std: {np.std(best_twgas):.3f}")
    print(f"Mean best mean acc: {np.mean(best_tmaccs):.3f}, std: {np.std(best_tmaccs):.3f}")
    return np.mean(best_twgas), np.std(best_twgas), np.mean(best_tmaccs), np.std(best_tmaccs)

## Waterbirds (WB): nominal sweep

In [None]:
# Nominal AFR sweep on Waterbirds, tuned on the full validation set.
load_emb = True       # reuse cached embeddings when they exist
skip_if_done = True   # reuse a saved sweep result instead of re-running it
num_augs = 1
val_prop = 1          # fraction of the val set used for tuning (1 = all of it)

args = SimpleNamespace(
    data_transform='NoAugWaterbirdsCelebATransform',
    dataset='SpuriousDataset',
    data_dir=wb_data_path,
    num_minority_groups_remove=0,
    test_data_dir=None,
    val_size=-1,
    mixup=False,
    batch_size=32,
    reweight_groups=False,
    reweight_classes=False,
    reweight_spurious=False,
    no_shuffle_train=False,
    finetune_on_val=False,
    dfr=False,
    pass_n=0,
    max_prop=1,
    train_prop=-0.2,
    val_prop=1,  # subsampling happens in sweep function
)
sweep_config = dict(
    n_trials=33 * 4,  # full grid: 33 gammas x 4 reg coefficients
    num_epochs=500,
    lr_choices=[1e-2],
    gamma_choices=np.linspace(4, 20, 33),
    reg_coeff_choices=[0, 0.1, 0.2, 0.3],
    balance_classes_choices=[True],
    early_stop='wga',
    random=False,
    val_prop=val_prop,
    sweep_name='nominal',
)
sweep(WB_CKPTS, args, num_augs, sweep_config, load_emb, skip_if_done)

## CelebA: nominal sweep

In [None]:
# Nominal AFR sweep on CelebA, tuned on the full validation set.
load_emb = True       # reuse cached embeddings when they exist
skip_if_done = True   # reuse a saved sweep result instead of re-running it
num_augs = 1
val_prop = 1          # fraction of the val set used for tuning (1 = all of it)

args = SimpleNamespace(
    data_transform='NoAugWaterbirdsCelebATransform',
    dataset='SpuriousDataset',
    data_dir=celeba_data_path,
    num_minority_groups_remove=0,
    test_data_dir=None,
    val_size=-1,
    mixup=False,
    batch_size=32,
    reweight_groups=False,
    reweight_classes=False,
    reweight_spurious=False,
    no_shuffle_train=False,
    finetune_on_val=False,
    dfr=False,
    pass_n=0,
    max_prop=1,
    train_prop=-0.2,
    val_prop=1,  # subsampling happens in sweep function
)
sweep_config = dict(
    n_trials=30,  # full grid: 10 gammas x 3 reg coefficients
    num_epochs=1000,
    lr_choices=[2e-2],
    gamma_choices=np.linspace(1, 3, 10),
    reg_coeff_choices=np.logspace(-3, -1, 3),
    balance_classes_choices=[True],
    early_stop='wga',
    random=False,  # if True then random search, else grid search
    val_prop=val_prop,
    sweep_name='nominal',
)
sweep(CelebA_CKPTS, args, num_augs, sweep_config, load_emb, skip_if_done)

### Robustness to $\gamma$

In [None]:
# WB
# Sensitivity of AFR to gamma on Waterbirds. For each checkpoint and gamma, record:
# - the early-stopped test WGA (no L2 penalty, 500 epochs, lr 1e-2, class balancing);
# - the effective sample size 1 / sum(w^2) of the normalised AFR weights.
args = SimpleNamespace(
        data_transform='NoAugWaterbirdsCelebATransform',
        dataset='SpuriousDataset',
        data_dir=wb_data_path,
        num_minority_groups_remove=0,
        test_data_dir=None,
        val_size=-1,
        mixup=False,
        batch_size=32,
        reweight_groups=False,
        reweight_classes=False,
        reweight_spurious=False,
        no_shuffle_train=False,
        finetune_on_val=False,
        dfr=False,
        pass_n=0,
        max_prop=1,
        train_prop=-0.2,
        val_prop=1
    )

num_augs = 1
gammas = np.linspace(0, 30, 11)
# Seed the random init of the trainable head so the curves are reproducible.
torch.manual_seed(0)

# Rows: gammas, columns: checkpoints.
twgas = np.zeros((len(gammas), len(WB_CKPTS)))
n_effs = np.zeros((len(gammas), len(WB_CKPTS)))
# Checkpoints are the outer loop. Embeddings (and the model/data loaders behind them) are then
# prepared once per checkpoint rather than once per (gamma, checkpoint) pair.
for j, ckpt in enumerate(WB_CKPTS):
    emb_dict = extract_embeddings(ckpt, args, num_augs, load_emb=True)
    y = emb_dict['y'].cuda()
    # ERM logits on the reweighting split. Together with y, gamma and the balancing flag
    # they determine the AFR weights.
    erm_logits = emb_dict['e'].cuda() @ emb_dict['w0'].cuda().t() + emb_dict['b0'].cuda()
    for i, gamma in enumerate(tqdm.tqdm(gammas, desc=f'ckpt {j}')):
        _, test_wga, _ = train(emb_dict, num_epochs=500, gamma=gamma, reg_coeff=0, lr=1e-2, balance_classes=True, early_stop='wga', plot=False, silent=True)
        twgas[i, j] = test_wga
        weights = compute_afr_weights(erm_logits, y, gamma, True)
        n_effs[i, j] = 1 / (weights ** 2).sum().item()

mean_wgas = twgas.mean(axis=1) * 100
std_wgas = twgas.std(axis=1) * 100
mean_n_effs = n_effs.mean(axis=1)
std_n_effs = n_effs.std(axis=1)

In [None]:
# Plot Waterbirds test WGA and AFR effective sample size against gamma. Uses gammas,
# mean_wgas, std_wgas, mean_n_effs and std_n_effs from the cell above.
import seaborn as sns
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set1")
# savefig does not create missing directories.
os.makedirs('../plots', exist_ok=True)

plt.figure(dpi=100, figsize=(8, 6))
plt.fill_between(gammas, mean_wgas - std_wgas, mean_wgas + std_wgas, alpha=0.2)
plt.plot(gammas, mean_wgas, label='AFR', marker='o', markersize=10)
plt.xlabel(r'$\gamma$')

# Reference ERM test WGA [%], drawn as a dashed line.
# NOTE(review): hard-coded value — confirm it matches the checkpoints in WB_CKPTS.
erm = 72.6

plt.plot([0, max(gammas)], [erm, erm], '--', label='ERM')
plt.ylim(bottom=60)

plt.ylabel('Test WGA [%]')
plt.legend()
plt.tight_layout()
plt.savefig('../plots/wb_wga_vs_gamma.pdf')


plt.figure(dpi=100, figsize=(8, 6))
plt.fill_between(gammas, mean_n_effs - std_n_effs, mean_n_effs + std_n_effs, alpha=0.2)
plt.plot(gammas, mean_n_effs, marker='o', markersize=10)
plt.xlabel(r'$\gamma$')
plt.ylabel(r'Effective # Samples')
plt.yscale('log')
plt.tight_layout()
# Not written for Waterbirds (the CelebA cell does save its counterpart); uncomment to save.
# plt.savefig('../plots/wb_effsize_vs_gamma.pdf')

In [None]:
# CelebA
# Sensitivity of AFR to gamma on CelebA. For each checkpoint and gamma, record:
# - the early-stopped test WGA (no L2 penalty, 500 epochs, lr 1e-2, class balancing);
# - the effective sample size 1 / sum(w^2) of the normalised AFR weights.
args = SimpleNamespace(
        data_transform='NoAugWaterbirdsCelebATransform',
        dataset='SpuriousDataset',
        data_dir=celeba_data_path,
        num_minority_groups_remove=0,
        test_data_dir=None,
        val_size=-1,
        mixup=False,
        batch_size=32,
        reweight_groups=False,
        reweight_classes=False,
        reweight_spurious=False,
        no_shuffle_train=False,
        finetune_on_val=False,
        dfr=False,
        pass_n=0,
        max_prop=1,
        train_prop=-0.2,
        val_prop=1,
    )
num_augs = 1
gammas = np.linspace(0, 30, 11)
# Seed the random init of the trainable head so the curves are reproducible.
torch.manual_seed(0)

# Rows: gammas, columns: checkpoints.
twgas = np.zeros((len(gammas), len(CelebA_CKPTS)))
n_effs = np.zeros((len(gammas), len(CelebA_CKPTS)))
# Checkpoints are the outer loop. Embeddings (and the model/data loaders behind them) are then
# prepared once per checkpoint, and only one checkpoint's embeddings are held at a time.
for j, ckpt in enumerate(CelebA_CKPTS):
    emb_dict = extract_embeddings(ckpt, args, num_augs, load_emb=True)
    y = emb_dict['y'].cuda()
    # ERM logits on the reweighting split. Together with y, gamma and the balancing flag
    # they determine the AFR weights.
    erm_logits = emb_dict['e'].cuda() @ emb_dict['w0'].cuda().t() + emb_dict['b0'].cuda()
    for i, gamma in enumerate(tqdm.tqdm(gammas, desc=f'ckpt {j}')):
        _, test_wga, _ = train(emb_dict, num_epochs=500, gamma=gamma, reg_coeff=0, lr=1e-2, balance_classes=True, early_stop='wga', plot=False, silent=True)
        twgas[i, j] = test_wga
        weights = compute_afr_weights(erm_logits, y, gamma, True)
        n_effs[i, j] = 1 / (weights ** 2).sum().item()

mean_wgas = twgas.mean(axis=1) * 100
std_wgas = twgas.std(axis=1) * 100
mean_n_effs = n_effs.mean(axis=1)
std_n_effs = n_effs.std(axis=1)

In [None]:
# Plot CelebA test WGA and AFR effective sample size against gamma. Uses gammas,
# mean_wgas, std_wgas, mean_n_effs and std_n_effs from the cell above.
import seaborn as sns
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set1")
# savefig does not create missing directories.
os.makedirs('../plots', exist_ok=True)

plt.figure(dpi=100, figsize=(8, 6))
plt.fill_between(gammas, mean_wgas - std_wgas, mean_wgas + std_wgas, alpha=0.2)
plt.plot(gammas, mean_wgas, label='AFR', marker='o', markersize=10)
plt.xlabel(r'$\gamma$')

# Reference ERM test WGA [%], drawn as a dashed line.
# NOTE(review): hard-coded value — confirm it matches the checkpoints in CelebA_CKPTS.
erm = 47.2

plt.plot([0, max(gammas)], [erm, erm], '--', label='ERM')
plt.ylim(bottom=40)
plt.ylabel('Test WGA [%]')
plt.legend()
plt.tight_layout()
plt.savefig('../plots/celeba_wga_vs_gamma.pdf')


plt.figure(dpi=100, figsize=(8, 6))
plt.fill_between(gammas, mean_n_effs - std_n_effs, mean_n_effs + std_n_effs, alpha=0.2)
plt.plot(gammas, mean_n_effs, marker='o', markersize=10)
plt.xlabel(r'$\gamma$')
plt.ylabel(r'Effective # Samples')
plt.yscale('log')
plt.tight_layout()
plt.savefig('../plots/celeba_effsize_vs_gamma.pdf')

### Group Label Efficiency / Down-sampled Validation Set

In [None]:
# We provide an example for Waterbirds
# Group-label efficiency: re-tune AFR on random subsets of the validation set (fraction
# `val_prop`), repeating the subsampling with several seeds. Each (val_prop, seed) pair is
# saved under its own sweep name.

# Set here rather than inherited from earlier cells, so this cell also runs on its own.
load_emb = True
skip_if_done = True

# The data config does not depend on the loop variables: val_prop=1 because subsampling
# happens in `sweep`.
args = SimpleNamespace(
    data_transform='NoAugWaterbirdsCelebATransform',
    dataset='SpuriousDataset',
    data_dir=wb_data_path,
    num_minority_groups_remove=0,
    test_data_dir=None,
    val_size=-1,
    mixup=False,
    batch_size=32,
    reweight_groups=False,
    reweight_classes=False,
    reweight_spurious=False,
    no_shuffle_train=False,
    finetune_on_val=False,
    dfr=False,
    pass_n=0,
    max_prop=1,
    train_prop=-0.2,
    val_prop=1,  # subsampling happens in sweep function
)

for seed in [0, 21, 42]:
    for num_augs in [1]:
        for val_prop in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]:
            sweep_config = {
                "n_trials": 33 * 4,  # full grid: 33 gammas x 4 reg coefficients
                "num_epochs": 500,
                "lr_choices": [1e-2],
                "gamma_choices": np.linspace(4, 20, 33),
                "reg_coeff_choices": [0, 0.1, 0.2, 0.3],
                "balance_classes_choices": [True],
                "early_stop": 'wga',
                "random": False,
                "val_prop": val_prop,
                'sweep_name': f'val_prop_{val_prop}_seed{seed}'
            }
            sweep(WB_CKPTS, args, num_augs, sweep_config, load_emb, skip_if_done, seed=seed)