In [1]:
%env CUDA_VISIBLE_DEVICES=0
%env OMP_NUM_THREADS=4

env: CUDA_VISIBLE_DEVICES=0
env: OMP_NUM_THREADS=4


In [3]:
import torch
from types import SimpleNamespace

import numpy as np
import os
import sys
import tqdm

# import from parent directory
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import models
import utils

from matplotlib import pyplot as plt
%matplotlib inline

In [4]:
def extract_embeddings(emb_path):
    trian_group_ratios = np.array([57498, 11158, 67376, 1521, 66630, 1992]).astype(float)
    trian_group_ratios = trian_group_ratios / trian_group_ratios.sum()
    print(f'Loading embeddings from {emb_path}')
    emb_dict = torch.load(emb_path)
    reweighting_embeddings, reweighting_predictions, reweighting_y, reweighting_groups = emb_dict['e'], emb_dict['pred'], emb_dict['y'], emb_dict['m']
    test_embeddings, test_predictions, test_y, test_groups = emb_dict['test_e'], emb_dict['test_pred'], emb_dict['test_y'], emb_dict['test_m']
    val_embeddings, val_predictions, val_y, val_groups = emb_dict['val_e'], emb_dict['val_pred'], emb_dict['val_y'], emb_dict['val_m']

    group_counts = {}
    for g in reweighting_groups:
            group_counts[g.item()] = group_counts.get(g.item(), 0) + 1
    print('*** Reweighting Counts + Group distribution: ***')
    for g in sorted(group_counts.keys()):
        print(f'Group {g}: {group_counts[g]} ({(group_counts[g] / sum(group_counts.values()) * 100):.1f}%)')

    print('*** Reweighting Acc + Group distribution: ***')        
    print_accs(reweighting_predictions, reweighting_y, reweighting_groups)

    print('*** Test Acc + Group distribution: ***')        
    print_accs(test_predictions, test_y, test_groups)

    print('*** Val Acc + Group distribution: ***')
    print_accs(val_predictions, val_y, val_groups)
    emb_dict['trian_group_ratios'] = trian_group_ratios
    emb_dict['g'] = emb_dict['m']
    emb_dict['test_g'] = emb_dict['test_m']
    emb_dict['val_g'] = emb_dict['val_m']
    return emb_dict

def get_accs(test_predictions, test_y, test_groups):
    return [(test_predictions == test_y)[test_groups == g].float().mean().item() for g in torch.unique(test_groups)]

def print_accs(test_predictions, test_y, test_groups):
    acc = (test_predictions == test_y).float().mean()
    group_accs = get_accs(test_predictions, test_y, test_groups)
    group_counts = {}
    for g in test_groups:
            group_counts[g.item()] = group_counts.get(g.item(), 0) + 1
    print(f"Avg: {acc:.3f}")
    for g, acc in enumerate(group_accs):
        print(f"Group {g}: {acc:.2f} ({(group_counts[g] / sum(group_counts.values()) * 100):.1f}%)")

In [5]:
def wxe_fn(logits, y, weights):
    ce = torch.nn.functional.cross_entropy(logits, y, reduction='none')
    l = weights * ce
    return l.sum()

def compute_afr_weights(erm_logits, class_label, gamma, balance_classes):
    # erm_logits: (n_samples, n_classes)
    # class_label: (n_samples,)
    # gamma: float
    with torch.no_grad():
        p = erm_logits.softmax(-1)
        y_onehot = torch.zeros_like(erm_logits).scatter_(-1, class_label.unsqueeze(-1), 1)
        p_true = (p * y_onehot).sum(-1)
        weights = (-gamma * p_true).exp()
        n_classes = torch.unique(class_label).numel()
        # class balancing
        if balance_classes:
            class_count = []
            for y in range(n_classes):
                class_count.append((class_label == y).sum())
            for y in range(1, n_classes):
                weights[class_label == y] *= class_count[0] / class_count[y]
        weights /= weights.sum()
    return weights

def argmax_last(x):
    # x: (n,)
    # take argmax, if there are multiple, take the last one
    x = torch.as_tensor(x)
    max_idx = (x == x.max()).nonzero().squeeze(-1)
    return max_idx[-1]

In [6]:
def train(emb_dict, num_epochs, gamma, reg_coeff, lr=1e-2, balance_classes=False, early_stop='wga', plot=True, silent=False, group_uniform=False, opt='sgd'):
    reweighting_embeddings = emb_dict['e'].cuda()
    reweighting_y = emb_dict['y'].cuda()
    reweighting_groups = emb_dict['g'].cuda()
    test_embeddings = emb_dict['test_e'].cuda()
    test_y = emb_dict['test_y'].cuda()
    test_groups = emb_dict['test_g'].cuda()
    val_embeddings = emb_dict['val_e'].cuda()
    val_y = emb_dict['val_y'].cuda()
    val_groups = emb_dict['val_g'].cuda()
    w0 = emb_dict['w0'].cuda()
    b0 = emb_dict['b0'].cuda()
    trian_group_ratios = emb_dict['trian_group_ratios']
    num_groups = torch.unique(test_groups).numel()

    class Model(torch.nn.Module):
        def __init__(self, w0, b0):
            super().__init__()
            self.w0 = w0
            self.b0 = b0
            self.linear = torch.nn.Linear(w0.shape[1], w0.shape[0], bias=True)
            
        def forward(self, x):
            y_old = x @ self.w0.t() + self.b0
            y_new = self.linear(x)
            return y_old + y_new

    model = Model(w0, b0)
    model.cuda()

    criterion = wxe_fn
    if opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0., momentum=0.)
    elif opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        raise NotImplementedError

    mean_accs = []
    test_mean_accs = []
    mean_group_accs = []
    val_mean_group_accs = []
    wgas = []
    test_wgas = []
    val_wgas = []
    losses = []
    regs = []
    weights_by_groups = []
    acc_by_groups = []
    val_acc_by_groups = []
    test_accs_by_groups = []
    weighted_acc = []
    val_weighted_acc = []

    initial_logits = reweighting_embeddings @ w0.t() + b0
    initial_val_logits = val_embeddings @ w0.t() + b0
    weights = compute_afr_weights(initial_logits, reweighting_y, gamma, balance_classes)
    val_weights = compute_afr_weights(initial_val_logits, val_y, gamma, balance_classes)
    if group_uniform:
        # uniform total weights per group
        group_counts = torch.bincount(reweighting_groups)
        for g in range(len(group_counts)):
            weights[reweighting_groups == g] = 1 / group_counts[g] / len(group_counts)
    for _ in (pbar := tqdm.tqdm(range(num_epochs), disable=silent)):
        optimizer.zero_grad()
        logits = model(reweighting_embeddings)
        loss = criterion(logits, reweighting_y, weights)
        reg = model.linear.weight.pow(2).sum() + model.linear.bias.pow(2).sum()
        loss += reg_coeff * reg
        # additional infos
        with torch.no_grad():
            val_logits = model(val_embeddings)
            ce = torch.nn.functional.cross_entropy(logits, reweighting_y, reduction='none')
            accs = get_accs(torch.argmax(logits, -1), reweighting_y, reweighting_groups)
            val_accs = get_accs(torch.argmax(model(val_embeddings), -1), val_y, val_groups)
            test_accs = get_accs(torch.argmax(model(test_embeddings), -1), test_y, test_groups)

        loss.backward()
        # clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()

        weights_by_group = [weights[reweighting_groups == g].sum().item() for g in range(num_groups)]
        weights_by_groups.append(weights_by_group)
        acc_by_groups.append(accs)
        val_acc_by_groups.append(val_accs)
        test_accs_by_groups.append(test_accs)
        mean_acc = (np.array(accs) * np.array(trian_group_ratios)).sum()
        mean_group_accs.append(np.array(accs).mean())
        mean_accs.append(mean_acc)
        val_mean_group_accs.append(np.array(val_accs).mean())
        wga = min(accs)
        test_wga = min(test_accs)
        val_wga = min(val_accs)
        wgas.append(wga)
        val_wgas.append(val_wga)
        test_wgas.append(test_wga)
        # test mean acc is test_accs weighted by trian_group_ratios
        test_mean_acc = (np.array(test_accs) * np.array(trian_group_ratios)).sum()
        test_mean_accs.append(test_mean_acc)
        losses.append(loss.item())
        regs.append(reg.item())
        weighted_acc.append((((torch.argmax(logits, -1) == reweighting_y).float() * weights).sum() / weights.sum()).item())
        val_weighted_acc.append((((torch.argmax(val_logits, -1) == val_y).float() * val_weights).sum() / val_weights.sum()).item())
        pbar.set_description(f"Loss: {loss.item():.5f}, WGA: {wga:.3f}, TWGA: {test_wga:.3f}")

    
    if early_stop == 'wga':
        earlystop_epoch = np.argmax(val_wgas)
        early_stop_metric = np.max(val_wgas)
    elif early_stop == 'wga_last':
        # take best wga on validation set
        # if multiple epochs have the same wga, take the last one
        earlystop_epoch = argmax_last(val_wgas)
        early_stop_metric = np.max(val_wgas)
    elif early_stop == 'wga@max_val_wa':
        earlystop_epoch = np.argmax(val_weighted_acc)
        early_stop_metric = val_wgas[earlystop_epoch]
    elif early_stop == 'mga':
        earlystop_epoch = np.argmax(val_mean_group_accs)
        early_stop_metric = np.max(val_mean_group_accs)
    elif early_stop == 'none':
        earlystop_epoch = num_epochs - 1
        early_stop_metric = val_wgas[-1]
    else:
        raise ValueError(f"Unknown early_stop: {early_stop}")
    earlystop_twga = test_wgas[earlystop_epoch]
    earlystop_tmacc = test_mean_accs[earlystop_epoch]
    final_twga = test_wgas[-1]
    final_tmacc = test_mean_accs[-1]
    if plot:
        # print test wga at best reweighting wga epoch
        print(f"Final test WGA: {final_twga:.3f}")
        print(f"Final test mean acc: {final_tmacc:.3f}")
        print(f"Early stopping test WGA: {earlystop_twga:.3f} at epoch {earlystop_epoch}") 
        print(f"Early stopping test mean acc: {earlystop_tmacc:.3f} at epoch {earlystop_epoch}")

        # 3 horizontal subplots
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
        ax1.plot(losses, label="CWXE")
        ax1.plot(regs, label="Reg")
        ax1.grid()
        ax1.legend(loc="lower right")
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Loss")

        ax2.plot(test_wgas, label="Test")
        ax2.plot(wgas, label="Reweighting")
        ax2.plot(val_wgas, label="Validation")
        ax2.grid()
        ax2.legend(loc="lower right")
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("WGA")

        ax3.plot(test_mean_accs, label="Test")
        ax3.plot(mean_accs, label="Reweighting")
        ax3.grid()
        ax3.legend(loc="lower right")
        ax3.set_xlabel("Epochs")
        ax3.set_ylabel("Mean Acc")

        ax4.plot(weighted_acc, label="Reweighting")
        ax4.plot(val_weighted_acc, label="Validation")
        ax4.grid()
        ax4.set_xlabel("Epochs")
        ax4.set_ylabel("Weighted Acc")
        ax4.legend(loc="lower right")
        plt.show()

        # 3 horizontal subplots
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))

        ax1.plot(weights_by_groups, label=[f'G{g+1}' for g in range(num_groups)])
        ax1.grid()
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Weights")
        ax1.legend(loc="lower right")

        for g in range(num_groups):
            ax2.plot([acc[g] for acc in acc_by_groups], label=f"G{g+1}", color=f"C{g}")
            # ax2.plot([acc[g] for acc in val_acc_by_groups], color=f"C{g}", linestyle="--")
            ax2.plot([acc[g] for acc in test_accs_by_groups], color=f"C{g}", linestyle=":")
        ax2.grid()
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("Acc")
        ax2.legend(loc="lower right")

    return early_stop_metric, earlystop_twga, earlystop_tmacc

In [7]:
def hyper_sampler(sweep_config, random=True):
    if random:
        while True:
            gamma = np.random.choice(sweep_config["gamma_choices"])
            reg_coeff = np.random.choice(sweep_config["reg_coeff_choices"])
            balance_classes = np.random.choice(sweep_config["balance_classes_choices"])
            yield gamma, reg_coeff, balance_classes
    else:
        for gamma in sweep_config["gamma_choices"]:
            for reg_coeff in sweep_config["reg_coeff_choices"]:
                for balance_classes in sweep_config["balance_classes_choices"]:
                    yield gamma, reg_coeff, balance_classes

def run_training(emb_dict, num_epochs, lr, gamma, reg_coeff, balance_classes, early_stop):
    return train(emb_dict, num_epochs=num_epochs, gamma=gamma, reg_coeff=reg_coeff, lr=lr, balance_classes=balance_classes, early_stop=early_stop, plot=False, silent=True), (gamma, reg_coeff, balance_classes)

In [8]:
def sweep(ckpts, sweep_config, skip_if_done, seed=0):
    best_twgas = []
    best_tmaccs = []
    for ckpt in ckpts:
        emb_dict = extract_embeddings(ckpt)
        sweep_file = os.path.join(os.path.dirname(ckpt), f'seed{ckpt[-4]}_' + sweep_config['sweep_name'] + '.pt')

        if skip_if_done and os.path.exists(sweep_file):
            twga = torch.load(sweep_file)['best_twga']
            print(f"Skipping {ckpt}: TWGA {twga:.3f}")
            best_twgas.append(twga)
            best_tmaccs.append(torch.load(sweep_file)['best_tmacc'])
            continue

        # break
        # find best hyperparameters
        best_vwga = 0
        best_twga = 0
        best_tmacc = 0
        best_hyperparams = None
        top10_choices = []
        
        # fix all seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        import random
        random.seed(seed)
        
        # subsample validation set
        if sweep_config['val_prop'] < 1:
            val_size = int(emb_dict['val_e'].shape[0] * sweep_config['val_prop'])
            val_idx = np.random.choice(emb_dict['val_e'].shape[0], val_size, replace=False)
            emb_dict['val_e'] = emb_dict['val_e'][val_idx]
            emb_dict['val_y'] = emb_dict['val_y'][val_idx]
            emb_dict['val_g'] = emb_dict['val_g'][val_idx]
        print(f'Tuning hypers on {emb_dict["val_e"].shape[0]} val examples')
        sampler = hyper_sampler(sweep_config, random=sweep_config['random'])
        for i in (pbar:= tqdm.tqdm(range(sweep_config['n_trials']))):
            try:
                gamma, reg_coeff, balance_classes = next(sampler)
            except StopIteration:
                break
            (val_wga, test_wga, test_macc, *_), hyperparams = run_training(emb_dict, sweep_config['num_epochs'], sweep_config['lr'], gamma, reg_coeff, balance_classes, sweep_config['early_stop'])
            if val_wga > best_vwga:
                best_vwga = val_wga
                best_twga = test_wga
                best_tmacc = test_macc
                best_hyperparams = hyperparams
                gamma, reg_coeff, balance_classes = best_hyperparams
                pbar.set_description(f"Test WGA: {test_wga:.3f}, Val WGA: {val_wga:.3f}, gamma: {gamma:.3f}, reg_coeff: {reg_coeff:.3f}, balance_classes: {balance_classes}")
                if val_wga == 1:
                    break
            if len(top10_choices) < 10:
                top10_choices.append((test_wga, val_wga, hyperparams))
                top10_choices = sorted(top10_choices, key=lambda x: x[1], reverse=True)
            else:
                top10_choices = sorted(top10_choices, key=lambda x: x[1], reverse=True)
                if val_wga > top10_choices[-1][1]:
                    top10_choices[-1] = (test_wga, val_wga, hyperparams)

        # save sweep results, indicate data_transform and num aug
        torch.save({
            'best_twga': best_twga,
            'best_tmacc': best_tmacc,
            'best_hyperparams': best_hyperparams,
            'top10_choices': top10_choices,
            'sweep_config': sweep_config,
        }, sweep_file)
        best_twgas.append(best_twga)
        best_tmaccs.append(best_tmacc)

    # print mean and std of best WGA
    print(f"Mean best WGA: {np.mean(best_twgas):.3f}, std: {np.std(best_twgas):.3f}")
    # print mean and std of mean acc
    print(f"Mean best mean acc: {np.mean(best_tmaccs):.3f}, std: {np.std(best_tmaccs):.3f}")
    return np.mean(best_twgas), np.std(best_twgas), np.mean(best_tmaccs), np.std(best_tmaccs)

## Nominal ##

In [9]:
ckpts = [
    '/home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt',
    '/home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed1.pt',
    '/home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed2.pt',
    ]

skip_if_done = True

for val_prop in [1]:
    sweep_config = {
        "n_trials": 260,
        "num_epochs": 200,
        "lr": 1e-2,
        "gamma_choices": np.logspace(2, 5, 10),
        "reg_coeff_choices": np.linspace(0, 50, 26),
        "balance_classes_choices": [True],
        "early_stop": 'wga',
        "random": False, # if True then random search, else grid search
        "val_prop": val_prop,
        "sweep_name": 'nominal',
        }
    sweep(ckpts, sweep_config, skip_if_done)

Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11459 (27.8%)
Group 1: 2149 (5.2%)
Group 2: 13475 (32.7%)
Group 3: 291 (0.7%)
Group 4: 13440 (32.6%)
Group 5: 421 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.83 (27.8%)
Group 1: 0.96 (5.2%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.6%)
Group 5: 0.63 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (28.0%)
Group 1: 0.95 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (1.0%)
Skipping /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt: TWGA 0.742
Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multin

## Small validation ##

In [11]:
ckpts = [
    '/home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt',
    '/home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed1.pt',
    '/home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed2.pt',
    ]

skip_if_done = True

for seed in [0, 21, 42]:
    for val_prop in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]:
        sweep_config = {
            "n_trials": 260,
            "num_epochs": 500,
            "lr": 1e-2,
            "gamma_choices": np.logspace(2, 5, 10),
            "reg_coeff_choices": np.linspace(0, 50, 26),
            "balance_classes_choices": [True],
            "early_stop": 'wga',
            "random": False, # if True then random search, else grid search
            "val_prop": val_prop,
            'sweep_name': f'val_prop_{val_prop}_seed{seed}'
            }
        sweep(ckpts, sweep_config, skip_if_done, seed=seed)

Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11459 (27.8%)
Group 1: 2149 (5.2%)
Group 2: 13475 (32.7%)
Group 3: 291 (0.7%)
Group 4: 13440 (32.6%)
Group 5: 421 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.83 (27.8%)
Group 1: 0.96 (5.2%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.6%)
Group 5: 0.63 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (28.0%)
Group 1: 0.95 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (1.0%)
Tuning hypers on 412 val examples


Test WGA: 0.664, Val WGA: 0.750, gamma: 100.000, reg_coeff: 0.000, balance_classes: True: 100%|██████████| 260/260 [24:26<00:00,  5.64s/it]


Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed1.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 9137 (27.7%)
Group 1: 1814 (5.5%)
Group 2: 10832 (32.8%)
Group 3: 282 (0.9%)
Group 4: 10603 (32.1%)
Group 5: 320 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.925
Group 0: 0.92 (27.7%)
Group 1: 0.98 (5.5%)
Group 2: 0.94 (32.8%)
Group 3: 0.91 (0.9%)
Group 4: 0.91 (32.1%)
Group 5: 0.83 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.821
Group 0: 0.82 (28.0%)
Group 1: 0.96 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.76 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.64 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.66 (1.0%)
Tuning hypers on 412 val examples


Test WGA: 0.634, Val WGA: 0.750, gamma: 100.000, reg_coeff: 0.000, balance_classes: True: 100%|██████████| 260/260 [23:12<00:00,  5.35s/it]


Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed2.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11618 (28.2%)
Group 1: 2202 (5.3%)
Group 2: 13514 (32.8%)
Group 3: 286 (0.7%)
Group 4: 13195 (32.0%)
Group 5: 420 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.827
Group 0: 0.81 (28.2%)
Group 1: 0.96 (5.3%)
Group 2: 0.84 (32.8%)
Group 3: 0.77 (0.7%)
Group 4: 0.81 (32.0%)
Group 5: 0.67 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.81 (28.0%)
Group 1: 0.96 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.80 (32.3%)
Group 5: 0.64 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.81 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.76 (0.7%)
Group 4: 0.80 (32.3%)
Group 5: 0.65 (1.0%)
Tuning hypers on 412 val examples


Test WGA: 0.650, Val WGA: 0.500, gamma: 100.000, reg_coeff: 0.000, balance_classes: True: 100%|██████████| 260/260 [24:01<00:00,  5.54s/it]


Mean best WGA: 0.649, std: 0.012
Mean best mean acc: 0.822, std: 0.001
Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11459 (27.8%)
Group 1: 2149 (5.2%)
Group 2: 13475 (32.7%)
Group 3: 291 (0.7%)
Group 4: 13440 (32.6%)
Group 5: 421 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.83 (27.8%)
Group 1: 0.96 (5.2%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.6%)
Group 5: 0.63 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (28.0%)
Group 1: 0.95 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (1.0%)
Skipping /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt: TWGA 0.654
Load

Test WGA: 0.683, Val WGA: 0.821, gamma: 100.000, reg_coeff: 10.000, balance_classes: True: 100%|██████████| 260/260 [24:06<00:00,  5.57s/it]


Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed1.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 9137 (27.7%)
Group 1: 1814 (5.5%)
Group 2: 10832 (32.8%)
Group 3: 282 (0.9%)
Group 4: 10603 (32.1%)
Group 5: 320 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.925
Group 0: 0.92 (27.7%)
Group 1: 0.98 (5.5%)
Group 2: 0.94 (32.8%)
Group 3: 0.91 (0.9%)
Group 4: 0.91 (32.1%)
Group 5: 0.83 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.821
Group 0: 0.82 (28.0%)
Group 1: 0.96 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.76 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.64 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.66 (1.0%)
Tuning hypers on 412 val examples


Test WGA: 0.688, Val WGA: 0.823, gamma: 10000.000, reg_coeff: 8.000, balance_classes: True: 100%|██████████| 260/260 [23:14<00:00,  5.36s/it]


Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed2.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11618 (28.2%)
Group 1: 2202 (5.3%)
Group 2: 13514 (32.8%)
Group 3: 286 (0.7%)
Group 4: 13195 (32.0%)
Group 5: 420 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.827
Group 0: 0.81 (28.2%)
Group 1: 0.96 (5.3%)
Group 2: 0.84 (32.8%)
Group 3: 0.77 (0.7%)
Group 4: 0.81 (32.0%)
Group 5: 0.67 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.81 (28.0%)
Group 1: 0.96 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.80 (32.3%)
Group 5: 0.64 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.81 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.76 (0.7%)
Group 4: 0.80 (32.3%)
Group 5: 0.65 (1.0%)
Tuning hypers on 412 val examples


Test WGA: 0.666, Val WGA: 0.829, gamma: 1000.000, reg_coeff: 2.000, balance_classes: True: 100%|██████████| 260/260 [23:31<00:00,  5.43s/it]


Mean best WGA: 0.679, std: 0.010
Mean best mean acc: 0.818, std: 0.004
Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11459 (27.8%)
Group 1: 2149 (5.2%)
Group 2: 13475 (32.7%)
Group 3: 291 (0.7%)
Group 4: 13440 (32.6%)
Group 5: 421 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.83 (27.8%)
Group 1: 0.96 (5.2%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.6%)
Group 5: 0.63 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (28.0%)
Group 1: 0.95 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (1.0%)
Skipping /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt: TWGA 0.687
Load

Test WGA: 0.619, Val WGA: 0.667, gamma: 100.000, reg_coeff: 12.000, balance_classes: True: 100%|██████████| 260/260 [23:53<00:00,  5.51s/it]


Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed1.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 9137 (27.7%)
Group 1: 1814 (5.5%)
Group 2: 10832 (32.8%)
Group 3: 282 (0.9%)
Group 4: 10603 (32.1%)
Group 5: 320 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.925
Group 0: 0.92 (27.7%)
Group 1: 0.98 (5.5%)
Group 2: 0.94 (32.8%)
Group 3: 0.91 (0.9%)
Group 4: 0.91 (32.1%)
Group 5: 0.83 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.821
Group 0: 0.82 (28.0%)
Group 1: 0.96 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.76 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.64 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.66 (1.0%)
Tuning hypers on 412 val examples


Test WGA: 0.692, Val WGA: 0.667, gamma: 100.000, reg_coeff: 0.000, balance_classes: True: 100%|██████████| 260/260 [23:10<00:00,  5.35s/it]


Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed2.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11618 (28.2%)
Group 1: 2202 (5.3%)
Group 2: 13514 (32.8%)
Group 3: 286 (0.7%)
Group 4: 13195 (32.0%)
Group 5: 420 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.827
Group 0: 0.81 (28.2%)
Group 1: 0.96 (5.3%)
Group 2: 0.84 (32.8%)
Group 3: 0.77 (0.7%)
Group 4: 0.81 (32.0%)
Group 5: 0.67 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.81 (28.0%)
Group 1: 0.96 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.80 (32.3%)
Group 5: 0.64 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.81 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.76 (0.7%)
Group 4: 0.80 (32.3%)
Group 5: 0.65 (1.0%)
Tuning hypers on 412 val examples


Test WGA: 0.611, Val WGA: 0.667, gamma: 100.000, reg_coeff: 8.000, balance_classes: True: 100%|██████████| 260/260 [24:06<00:00,  5.56s/it]


Mean best WGA: 0.641, std: 0.036
Mean best mean acc: 0.821, std: 0.002
Loading embeddings from /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt
*** Reweighting Counts + Group distribution: ***
Group 0: 11459 (27.8%)
Group 1: 2149 (5.2%)
Group 2: 13475 (32.7%)
Group 3: 291 (0.7%)
Group 4: 13440 (32.6%)
Group 5: 421 (1.0%)
*** Reweighting Acc + Group distribution: ***
Avg: 0.824
Group 0: 0.83 (27.8%)
Group 1: 0.96 (5.2%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.6%)
Group 5: 0.63 (1.0%)
*** Test Acc + Group distribution: ***
Avg: 0.822
Group 0: 0.82 (28.0%)
Group 1: 0.95 (5.4%)
Group 2: 0.84 (32.7%)
Group 3: 0.75 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (0.9%)
*** Val Acc + Group distribution: ***
Avg: 0.823
Group 0: 0.82 (27.7%)
Group 1: 0.96 (5.6%)
Group 2: 0.84 (32.7%)
Group 3: 0.77 (0.7%)
Group 4: 0.79 (32.3%)
Group 5: 0.65 (1.0%)
Skipping /home/shikai_q/spurious/checkpoints/multinli/multinli-embeddings-drop-seed0.pt: TWGA 0.619
Load