In [None]:
# Process-level settings for this kernel: set OMP_NUM_THREADS to 4 and expose
# only CUDA device 0 to the process.
%env OMP_NUM_THREADS=4
%env CUDA_VISIBLE_DEVICES=0
# Load the autoreload extension; mode 2 reloads imported modules before each
# cell is executed, so edits to the project's .py files are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
### TODO: Modify to your own data dir ###
# Root directory of the MultiNLI data; it is passed as `root_dir` in the
# gdro_config namespace that is handed to prepare_data.
mnli_data_dir = '/data/users/pavel_i/datasets/multinli'

In [None]:
import sys
sys.path.append("../")
sys.path.append("../wilds_exps_utils/")

import torch
import numpy as np
import tqdm
import pickle
import copy
from types import SimpleNamespace
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader
from wilds_configs import datasets as dataset_configs
from wilds.datasets.wilds_dataset import WILDSSubset
from wilds_models.initializer import initialize_model
import wilds_transforms as transforms

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

## Data and model

In [None]:
import argparse
from types import SimpleNamespace
import tqdm
import torch

from transformers import BertConfig, BertForSequenceClassification

import sys
gdro_dir = '../gdro_fork'
sys.path.append(gdro_dir)
from gdro_fork.data.data import prepare_data

In [None]:
# Group count for this setup (not referenced by the other visible cells).
NUM_GROUPS = 6

# Settings handed to gdro_fork's prepare_data.
gdro_config = SimpleNamespace(**{
    'dataset': 'MultiNLI',
    'shift_type': 'confounder',
    'root_dir': mnli_data_dir,
    'augment_data': False,
    'gamma': 0.1,
    'batch_size': 128,
    'target_name': 'gold_label_random',
    'confounder_names': ['sentence2_has_negation'],
    'model': 'bert',
    'fraction': 1.,
})
reweighting_data, val_data, test_data = prepare_data(gdro_config, train=True)

# Keyword arguments forwarded to every get_loader call.
loader_kwargs = dict(num_workers=4, pin_memory=True)

# Evaluation-mode loaders (no group reweighting) for the validation and test splits.
val_loader, test_loader = (
    split.get_loader(train=False, reweight_groups=None, **loader_kwargs)
    for split in (val_data, test_data)
)

## Extract embeddings

In [None]:
def get_embeddings_predictions(feature_extractor, classifier, loader):
    """Run a loader through the feature extractor and classifier without gradients.

    Each batch from ``loader`` is unpacked as ``(x, y_true, metadata)``. The
    slices ``x[:, :, 0]``, ``x[:, :, 1]`` and ``x[:, :, 2]`` are moved to the
    GPU and passed as ``input_ids``, ``attention_mask`` and ``token_type_ids``.
    The ``.logits`` attribute of the extractor output is used as the embedding,
    and the prediction is the argmax over axis 1 of ``classifier(embedding)``.

    Returns:
        Tuple ``(embeddings, predictions, y_true, metadata)``, each obtained by
        concatenating the per-batch tensors along axis 0. Embeddings,
        predictions and labels are moved to the CPU first; metadata is
        concatenated as delivered by the loader.
    """
    batches = {'embeddings': [], 'predictions': [], 'y_true': [], 'metadata': []}
    with torch.no_grad():
        for x, y_true, metadata in tqdm.tqdm(loader):
            # The last axis of x holds the three model inputs, in this order.
            token_ids, attn_mask, type_ids = (x[:, :, k].cuda() for k in range(3))
            output = feature_extractor(
                input_ids=token_ids,
                attention_mask=attn_mask,
                token_type_ids=type_ids,
            )
            features = output.logits
            predicted = classifier(features).argmax(dim=1)
            batches['embeddings'].append(features.cpu())
            batches['predictions'].append(predicted.cpu())
            batches['y_true'].append(y_true.cpu())
            batches['metadata'].append(metadata)
    return tuple(
        torch.cat(batches[name], dim=0)
        for name in ('embeddings', 'predictions', 'y_true', 'metadata')
    )

def save_emb(ckpt_path, seed, save_path, reweighting_frac=0.2):
    """Extract embeddings and predictions for one checkpoint and save them to disk.

    The training split returned by ``prepare_data`` is shuffled with ``seed``
    and only its last ``reweighting_frac`` fraction is kept as the
    "reweighting" set. The checkpoint's ``classifier`` head is swapped for an
    identity layer so the model returns its pre-classifier features, and the
    reweighting, validation and test splits are passed through
    ``get_embeddings_predictions``.

    Args:
        ckpt_path: Path of a checkpoint readable with ``torch.load`` that
            yields a model object exposing a ``classifier`` attribute.
        seed: Seed of the NumPy generator used to shuffle the training indices.
        save_path: Destination file; its parent directory is created if missing.
        reweighting_frac: Fraction of the (shuffled) training indices kept as
            the reweighting set. Defaults to 0.2.

    The saved dictionary has keys ``e/y/pred/m`` (reweighting split),
    ``test_*`` and ``val_*`` (same four fields for the other splits), and
    ``w0``/``b0`` (weight and bias of the original classifier head).
    """
    import os  # local import: `os` is only imported in a later cell of this notebook

    # Create the output directory up front so a bad path fails before the
    # expensive forward passes instead of after them.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    reweighting_data, val_data, test_data = prepare_data(gdro_config, train=True)
    loader_kwargs = {'num_workers': 4, 'pin_memory': True}

    val_loader = val_data.get_loader(
        train=False, reweight_groups=None, **loader_kwargs)
    test_loader = test_data.get_loader(
        train=False, reweight_groups=None, **loader_kwargs)

    print(f'Dropping reweighting data, seed {seed}')

    # Shuffle a copy of the training indices and keep the tail as the reweighting set.
    idx = reweighting_data.dataset.indices.copy()
    rng = np.random.default_rng(seed)
    rng.shuffle(idx)
    n_train = int((1 - reweighting_frac) * len(idx))
    reweighting_idx = idx[n_train:]

    print(f'Original dataset size: {len(reweighting_data.dataset.indices)}')
    reweighting_data.dataset = torch.utils.data.dataset.Subset(
        reweighting_data.dataset.dataset,
        indices=reweighting_idx)
    print(f'New dataset size: {len(reweighting_data.dataset.indices)}')

    reweighting_loader = reweighting_data.get_loader(
        train=False, reweight_groups=None, **loader_kwargs)

    # NOTE(review): this unpickles a whole model object, so only load trusted
    # checkpoints. Recent PyTorch releases changed the default of
    # `weights_only`; confirm against the installed version.
    model = torch.load(ckpt_path)
    model.cuda()
    model.eval()

    # Replace the head with an identity so the model outputs its features;
    # keep the original head to produce predictions and to save its parameters.
    classifier = model.classifier
    model.classifier = torch.nn.Identity()
    feature_extractor = model

    reweighting_embeddings, reweighting_predictions, reweighting_y, reweighting_metadata = get_embeddings_predictions(
        feature_extractor, classifier, reweighting_loader)
    val_embeddings, val_predictions, val_y, val_metadata = get_embeddings_predictions(
        feature_extractor, classifier, val_loader)
    test_embeddings, test_predictions, test_y, test_metadata = get_embeddings_predictions(
        feature_extractor, classifier, test_loader)

    torch.save(
        dict(
            e=reweighting_embeddings, y=reweighting_y, pred=reweighting_predictions, m=reweighting_metadata,
            test_e=test_embeddings, test_pred=test_predictions, test_y=test_y, test_m=test_metadata,
            val_e=val_embeddings, val_pred=val_predictions, val_y=val_y, val_m=val_metadata,
            # Detach so plain tensors are stored; otherwise the loaded values
            # keep requires_grad=True and downstream code backpropagates into them.
            w0=classifier.weight.detach().cpu(),
            b0=classifier.bias.detach().cpu(),
        ),
        save_path
    )

In [None]:
ckpt_dir = '/home/andres_p/gdro_fork/logs/multinli' # TODO: change to your directory of checkpoints
# Each checkpoint name is reused as the embedding file name, so it must match
# the entries of `emb_paths` in the AFR section ('erm_dfrdrop1', not
# 'erm_dfrdrop_1', which would have written a file that is never read back).
ckpts = ['erm_dfrdrop0', 'erm_dfrdrop1', 'erm_dfrdrop2'] # TODO: change to your checkpoints
seeds = [0, 1, 2] # TODO: change to your seeds for splitting the train data
# Checkpoints and seeds are paired positionally.
for ckpt, seed in zip(ckpts, seeds):
    ckpt_path = f'{ckpt_dir}/{ckpt}/last_model.pth'
    save_path = f'../emb/multinli/{ckpt}.pt'
    save_emb(ckpt_path, seed, save_path)

## AFR ##

In [None]:
### TODO: Modify these to your own paths ###
# Embedding files to run AFR on, one per checkpoint. `sweep` names its result
# file after `ckpt[-4]`, i.e. the single character right before the ".pt"
# suffix, so every path in this list should end in a distinct character
# (here the seed digit) to keep result files from overwriting each other.
emb_paths = [
    '../emb/multinli/erm_dfrdrop0.pt',
    '../emb/multinli/erm_dfrdrop1.pt',
    '../emb/multinli/erm_dfrdrop2.pt',
]
### The remaining code needs zero modification ###

In [None]:
import torch
from types import SimpleNamespace

import numpy as np
import os
import sys
import tqdm

# import from parent directory
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import models
import utils

from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
def extract_embeddings(emb_path):
    """Load a saved embedding dictionary, print group statistics and add alias keys.

    The file at ``emb_path`` is read with ``torch.load``. Group counts of the
    reweighting split and per-group accuracies of the reweighting, test and
    validation splits are printed. Before returning, the dictionary gains
    ``'trian_group_ratios'`` (hard-coded per-group counts normalised to sum to
    one) and the aliases ``'g'``, ``'test_g'`` and ``'val_g'`` for the
    ``'m'``, ``'test_m'`` and ``'val_m'`` entries.
    """
    # Hard-coded per-group example counts, turned into fractions.
    group_sizes = np.array([57498, 11158, 67376, 1521, 66630, 1992]).astype(float)
    trian_group_ratios = group_sizes / group_sizes.sum()

    print(f'Loading embeddings from {emb_path}')
    emb_dict = torch.load(emb_path)

    # Look up every field of every split before printing anything; keep
    # (predictions, labels, groups) per split for the accuracy reports.
    splits = {
        prefix: tuple(emb_dict[prefix + field] for field in ('e', 'pred', 'y', 'm'))[1:]
        for prefix in ('', 'test_', 'val_')
    }

    print('*** Reweighting Counts + Group distribution: ***')
    group_ids, sizes = torch.unique(splits[''][2], return_counts=True)
    n_total = sum(sizes.tolist())
    for gid, size in zip(group_ids.tolist(), sizes.tolist()):
        print(f'Group {gid}: {size} ({(size / n_total * 100):.1f}%)')

    reports = (
        ('*** Reweighting Acc + Group distribution: ***', ''),
        ('*** Test Acc + Group distribution: ***', 'test_'),
        ('*** Val Acc + Group distribution: ***', 'val_'),
    )
    for header, prefix in reports:
        print(header)
        print_accs(*splits[prefix])

    emb_dict.update({
        'trian_group_ratios': trian_group_ratios,
        'g': emb_dict['m'],
        'test_g': emb_dict['test_m'],
        'val_g': emb_dict['val_m'],
    })
    return emb_dict

def get_accs(test_predictions, test_y, test_groups):
    """Return the accuracy inside each group as a list of Python floats.

    One entry is produced per distinct value of ``test_groups``, in the order
    returned by ``torch.unique``.
    """
    is_correct = test_predictions == test_y
    per_group = []
    for group_id in torch.unique(test_groups):
        in_group = test_groups == group_id
        per_group.append(is_correct[in_group].float().mean().item())
    return per_group

def print_accs(test_predictions, test_y, test_groups):
    """Print the overall accuracy, then accuracy and share of examples per group.

    Args:
        test_predictions: Predicted labels, one per example.
        test_y: True labels, same shape as ``test_predictions``.
        test_groups: Group id of every example.

    Each group line is labelled with the actual group id found in
    ``test_groups``. Ids therefore need not be contiguous or start at zero,
    e.g. when a group is absent from a sub-sampled split.
    """
    correct = (test_predictions == test_y).float()
    print(f"Avg: {correct.mean().item():.3f}")

    # Distinct group ids and their sizes in a single pass.
    group_ids, sizes = torch.unique(test_groups, return_counts=True)
    n_total = sum(sizes.tolist())
    for group_id, size in zip(group_ids.tolist(), sizes.tolist()):
        group_acc = correct[test_groups == group_id].mean().item()
        print(f"Group {group_id}: {group_acc:.2f} ({(size / n_total * 100):.1f}%)")

In [12]:
def wxe_fn(logits, y, weights):
    """Weighted cross-entropy: the sum over samples of ``weights[i] * CE(logits[i], y[i])``."""
    per_sample_ce = torch.nn.functional.cross_entropy(logits, y, reduction='none')
    weighted = weights * per_sample_ce
    return weighted.sum()

def compute_afr_weights(erm_logits, class_label, gamma, balance_classes):
    """Compute normalised per-sample weights ``exp(-gamma * p_true)``.

    Args:
        erm_logits: Logits of the base model, shape ``(n_samples, n_classes)``.
        class_label: Integer class labels, shape ``(n_samples,)``.
        gamma: Scalar multiplying the true-class probability in the exponent.
        balance_classes: If true, rescale the weights of every class by
            ``count(reference class) / count(class)`` before normalising.

    Returns:
        Tensor of shape ``(n_samples,)`` whose entries sum to one.

    The number of classes is taken from the logits rather than from the
    labels that happen to be present. A class missing from ``class_label``
    (e.g. in a small sub-sample) therefore no longer leaves higher-numbered
    classes unbalanced. The reference class is the lowest-numbered class that
    is present, so the scaling equals ``count(class 0) / count(class)``
    whenever class 0 occurs.
    """
    with torch.no_grad():
        probs = erm_logits.softmax(-1)
        # Probability that the base model assigns to the true class.
        p_true = probs.gather(-1, class_label.unsqueeze(-1)).squeeze(-1)
        weights = (-gamma * p_true).exp()
        if balance_classes:
            n_classes = erm_logits.shape[-1]
            class_count = torch.bincount(class_label, minlength=n_classes)
            present = class_count[class_count > 0]
            if present.numel() > 0:
                # class_count[class_label] is never zero: every indexed class
                # has at least the sample doing the indexing.
                weights *= present[0] / class_count[class_label]
        weights /= weights.sum()
    return weights

In [None]:
def train(emb_dict, num_epochs, gamma, reg_coeff, lr=1e-2, balance_classes=False, early_stop='wga', plot=True, silent=False, group_uniform=False, opt='sgd'):
    """Retrain the last layer on weighted reweighting data and report early-stopped metrics.

    The model output is ``x @ w0.T + b0`` (the frozen original head) plus a
    new trainable linear layer. The loss is ``wxe_fn`` with per-sample weights
    from ``compute_afr_weights``, plus ``reg_coeff`` times the squared L2 norm
    of the new layer's weight and bias. Every epoch is one full-batch step
    with the gradient norm clipped to 1.

    Args:
        emb_dict: Dictionary with tensors under ``e/y/g`` (reweighting split),
            ``val_*`` and ``test_*`` (same three fields), ``w0``/``b0``, and
            the per-group fractions under ``'trian_group_ratios'``.
        num_epochs: Number of full-batch optimisation steps.
        gamma: Passed to ``compute_afr_weights``.
        reg_coeff: Coefficient of the squared-norm penalty on the new layer.
        lr: Learning rate.
        balance_classes: Passed to ``compute_afr_weights``.
        early_stop: One of ``'wga'``, ``'wga_last'``, ``'wga@max_val_wa'``,
            ``'mga'`` or ``'none'``; selects the epoch that is reported.
        plot: If true, print summary lines and draw diagnostic figures.
        silent: If true, disable the progress bar.
        group_uniform: If true, replace the weights so each group has the same
            total weight, spread uniformly inside the group.
        opt: ``'sgd'`` or ``'adam'``.

    Returns:
        ``(early_stop_metric, earlystop_twga, earlystop_tmacc)``: the
        validation metric used for selection, and the test worst-group
        accuracy and test mean accuracy at the selected epoch.

    Raises:
        NotImplementedError: for an unknown ``opt``.
        ValueError: for an unknown ``early_stop``.

    Note:
        Metrics are recorded before the optimiser step of each epoch, so
        entry ``t`` of every history list describes the model after ``t``
        updates.
    """
    reweighting_embeddings = emb_dict['e'].cuda()
    reweighting_y = emb_dict['y'].cuda()
    reweighting_groups = emb_dict['g'].cuda()
    test_embeddings = emb_dict['test_e'].cuda()
    test_y = emb_dict['test_y'].cuda()
    test_groups = emb_dict['test_g'].cuda()
    val_embeddings = emb_dict['val_e'].cuda()
    val_y = emb_dict['val_y'].cuda()
    val_groups = emb_dict['val_g'].cuda()
    # The original head is frozen. Detach so that, if it was saved with
    # requires_grad=True, backward() does not propagate into it every step.
    w0 = emb_dict['w0'].detach().cuda()
    b0 = emb_dict['b0'].detach().cuda()
    trian_group_ratios = emb_dict['trian_group_ratios']
    num_groups = torch.unique(test_groups).numel()

    class Model(torch.nn.Module):
        """Frozen original head plus a trainable linear correction."""

        def __init__(self, w0, b0):
            super().__init__()
            # Plain attributes (not parameters): they are not optimised.
            self.w0 = w0
            self.b0 = b0
            self.linear = torch.nn.Linear(w0.shape[1], w0.shape[0], bias=True)

        def forward(self, x):
            y_old = x @ self.w0.t() + self.b0
            y_new = self.linear(x)
            return y_old + y_new

    model = Model(w0, b0)
    model.cuda()

    criterion = wxe_fn
    if opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0., momentum=0.)
    elif opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        raise NotImplementedError

    mean_accs = []
    test_mean_accs = []
    mean_group_accs = []
    val_mean_group_accs = []
    wgas = []
    test_wgas = []
    val_wgas = []
    losses = []
    regs = []
    weights_by_groups = []
    acc_by_groups = []
    val_acc_by_groups = []
    test_accs_by_groups = []
    weighted_acc = []
    val_weighted_acc = []

    # Sample weights come from the original head's logits and stay fixed
    # for the whole run.
    initial_logits = reweighting_embeddings @ w0.t() + b0
    initial_val_logits = val_embeddings @ w0.t() + b0
    weights = compute_afr_weights(initial_logits, reweighting_y, gamma, balance_classes)
    val_weights = compute_afr_weights(initial_val_logits, val_y, gamma, balance_classes)
    if group_uniform:
        # uniform total weights per group
        group_counts = torch.bincount(reweighting_groups)
        for g in range(len(group_counts)):
            weights[reweighting_groups == g] = 1 / group_counts[g] / len(group_counts)

    # The weights never change inside the loop, so the per-group totals are
    # computed once and recorded for every epoch.
    weights_by_group = [weights[reweighting_groups == g].sum().item() for g in range(num_groups)]

    for _ in (pbar := tqdm.tqdm(range(num_epochs), disable=silent)):
        optimizer.zero_grad()
        logits = model(reweighting_embeddings)
        loss = criterion(logits, reweighting_y, weights)
        reg = model.linear.weight.pow(2).sum() + model.linear.bias.pow(2).sum()
        loss += reg_coeff * reg
        # additional infos, evaluated with the pre-step model
        with torch.no_grad():
            val_logits = model(val_embeddings)
            accs = get_accs(torch.argmax(logits, -1), reweighting_y, reweighting_groups)
            val_accs = get_accs(torch.argmax(val_logits, -1), val_y, val_groups)
            test_accs = get_accs(torch.argmax(model(test_embeddings), -1), test_y, test_groups)

        loss.backward()
        # clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()

        weights_by_groups.append(weights_by_group)
        acc_by_groups.append(accs)
        val_acc_by_groups.append(val_accs)
        test_accs_by_groups.append(test_accs)
        # mean acc is the group accuracies weighted by trian_group_ratios
        mean_acc = (np.array(accs) * np.array(trian_group_ratios)).sum()
        mean_group_accs.append(np.array(accs).mean())
        mean_accs.append(mean_acc)
        val_mean_group_accs.append(np.array(val_accs).mean())
        wga = min(accs)
        test_wga = min(test_accs)
        val_wga = min(val_accs)
        wgas.append(wga)
        val_wgas.append(val_wga)
        test_wgas.append(test_wga)
        # test mean acc is test_accs weighted by trian_group_ratios
        test_mean_acc = (np.array(test_accs) * np.array(trian_group_ratios)).sum()
        test_mean_accs.append(test_mean_acc)
        losses.append(loss.item())
        regs.append(reg.item())
        weighted_acc.append((((torch.argmax(logits, -1) == reweighting_y).float() * weights).sum() / weights.sum()).item())
        val_weighted_acc.append((((torch.argmax(val_logits, -1) == val_y).float() * val_weights).sum() / val_weights.sum()).item())
        pbar.set_description(f"Loss: {loss.item():.5f}, WGA: {wga:.3f}, TWGA: {test_wga:.3f}")

    if early_stop == 'wga':
        earlystop_epoch = np.argmax(val_wgas)
        early_stop_metric = np.max(val_wgas)
    elif early_stop == 'wga_last':
        # take best wga on validation set
        # if multiple epochs have the same wga, take the last one:
        # argmax over the reversed history, mapped back to a forward index
        earlystop_epoch = len(val_wgas) - 1 - int(np.argmax(val_wgas[::-1]))
        early_stop_metric = np.max(val_wgas)
    elif early_stop == 'wga@max_val_wa':
        earlystop_epoch = np.argmax(val_weighted_acc)
        early_stop_metric = val_wgas[earlystop_epoch]
    elif early_stop == 'mga':
        earlystop_epoch = np.argmax(val_mean_group_accs)
        early_stop_metric = np.max(val_mean_group_accs)
    elif early_stop == 'none':
        earlystop_epoch = num_epochs - 1
        early_stop_metric = val_wgas[-1]
    else:
        raise ValueError(f"Unknown early_stop: {early_stop}")
    earlystop_twga = test_wgas[earlystop_epoch]
    earlystop_tmacc = test_mean_accs[earlystop_epoch]
    final_twga = test_wgas[-1]
    final_tmacc = test_mean_accs[-1]
    if plot:
        # print test metrics at the last epoch and at the selected epoch
        print(f"Final test WGA: {final_twga:.3f}")
        print(f"Final test mean acc: {final_tmacc:.3f}")
        print(f"Early stopping test WGA: {earlystop_twga:.3f} at epoch {earlystop_epoch}")
        print(f"Early stopping test mean acc: {earlystop_tmacc:.3f} at epoch {earlystop_epoch}")

        # first figure: 4 horizontal subplots (loss, WGA, mean acc, weighted acc)
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
        ax1.plot(losses, label="CWXE")
        ax1.plot(regs, label="Reg")
        ax1.grid()
        ax1.legend(loc="lower right")
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Loss")

        ax2.plot(test_wgas, label="Test")
        ax2.plot(wgas, label="Reweighting")
        ax2.plot(val_wgas, label="Validation")
        ax2.grid()
        ax2.legend(loc="lower right")
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("WGA")

        ax3.plot(test_mean_accs, label="Test")
        ax3.plot(mean_accs, label="Reweighting")
        ax3.grid()
        ax3.legend(loc="lower right")
        ax3.set_xlabel("Epochs")
        ax3.set_ylabel("Mean Acc")

        ax4.plot(weighted_acc, label="Reweighting")
        ax4.plot(val_weighted_acc, label="Validation")
        ax4.grid()
        ax4.set_xlabel("Epochs")
        ax4.set_ylabel("Weighted Acc")
        ax4.legend(loc="lower right")
        plt.show()

        # second figure: 4 axes are created, only the first two are drawn on
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))

        ax1.plot(weights_by_groups, label=[f'G{g+1}' for g in range(num_groups)])
        ax1.grid()
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Weights")
        ax1.legend(loc="lower right")

        # solid lines: reweighting split, dotted lines: test split
        for g in range(num_groups):
            ax2.plot([acc[g] for acc in acc_by_groups], label=f"G{g+1}", color=f"C{g}")
            ax2.plot([acc[g] for acc in test_accs_by_groups], color=f"C{g}", linestyle=":")
        ax2.grid()
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("Acc")
        ax2.legend(loc="lower right")

    return early_stop_metric, earlystop_twga, earlystop_tmacc

In [None]:
def hyper_sampler(sweep_config, random=True):
    """Generate ``(gamma, reg_coeff, balance_classes)`` hyper-parameter triples.

    With ``random=True`` the generator never ends: every triple is drawn with
    ``np.random.choice`` from the three ``*_choices`` entries of
    ``sweep_config``. Otherwise it yields the full grid once, with gamma as
    the outermost loop and balance_classes as the innermost.
    """
    choice_keys = ("gamma_choices", "reg_coeff_choices", "balance_classes_choices")
    if not random:
        yield from (
            (gamma, reg_coeff, balance_classes)
            for gamma in sweep_config[choice_keys[0]]
            for reg_coeff in sweep_config[choice_keys[1]]
            for balance_classes in sweep_config[choice_keys[2]]
        )
        return
    while True:
        # Draws happen in key order: gamma, then reg_coeff, then balance_classes.
        yield tuple(np.random.choice(sweep_config[key]) for key in choice_keys)

def run_training(emb_dict, num_epochs, lr, gamma, reg_coeff, balance_classes, early_stop):
    """Run one silent, plot-free ``train`` call.

    Returns:
        ``(train_result, (gamma, reg_coeff, balance_classes))``: the value
        returned by ``train`` paired with the hyper-parameters that were used.
    """
    train_result = train(
        emb_dict,
        num_epochs=num_epochs,
        gamma=gamma,
        reg_coeff=reg_coeff,
        lr=lr,
        balance_classes=balance_classes,
        early_stop=early_stop,
        plot=False,
        silent=True,
    )
    hyperparams = (gamma, reg_coeff, balance_classes)
    return train_result, hyperparams

In [None]:
def sweep(ckpts, sweep_config, skip_if_done, seed=0):
    """Tune AFR hyper-parameters per embedding file and aggregate the test metrics.

    For every path in ``ckpts`` the hyper-parameters produced by
    ``hyper_sampler`` are tried (at most ``sweep_config['n_trials']`` of them)
    and the trial with the highest validation metric is kept. Results are
    saved next to the embedding file and reused on later calls when
    ``skip_if_done`` is true.

    Args:
        ckpts: Paths of embedding files as written by ``save_emb``.
        sweep_config: Dictionary with keys ``n_trials``, ``num_epochs``,
            ``lr``, ``gamma_choices``, ``reg_coeff_choices``,
            ``balance_classes_choices``, ``early_stop``, ``random``,
            ``val_prop`` and ``sweep_name``.
        skip_if_done: If true, load an existing result file instead of
            re-running the sweep for that checkpoint.
        seed: Seed for torch, NumPy and ``random``; it drives the validation
            sub-sampling and the random search.

    Returns:
        ``(mean best test WGA, std, mean best test mean-acc, std)`` across
        the checkpoints.
    """
    import random  # stdlib RNG, seeded below together with torch and NumPy

    best_twgas = []
    best_tmaccs = []
    for ckpt in ckpts:
        # ckpt[-4] is the single character right before the ".pt" suffix; it
        # tags the result file, so paths must differ in that character.
        sweep_file = os.path.join(os.path.dirname(ckpt), f'seed{ckpt[-4]}_' + sweep_config['sweep_name'] + '.pt')

        if skip_if_done and os.path.exists(sweep_file):
            # Read the cached result once; the embeddings are not needed here.
            cached = torch.load(sweep_file)
            twga = cached['best_twga']
            print(f"Skipping {ckpt}: TWGA {twga:.3f}")
            best_twgas.append(twga)
            best_tmaccs.append(cached['best_tmacc'])
            continue

        # Only load (and print statistics for) embeddings that will be used.
        emb_dict = extract_embeddings(ckpt)

        # find best hyperparameters
        best_vwga = 0
        best_twga = 0
        best_tmacc = 0
        best_hyperparams = None
        top10_choices = []

        # fix all seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # subsample validation set
        if sweep_config['val_prop'] < 1:
            val_size = int(emb_dict['val_e'].shape[0] * sweep_config['val_prop'])
            val_idx = np.random.choice(emb_dict['val_e'].shape[0], val_size, replace=False)
            emb_dict['val_e'] = emb_dict['val_e'][val_idx]
            emb_dict['val_y'] = emb_dict['val_y'][val_idx]
            emb_dict['val_g'] = emb_dict['val_g'][val_idx]
        print(f'Tuning hypers on {emb_dict["val_e"].shape[0]} val examples')
        sampler = hyper_sampler(sweep_config, random=sweep_config['random'])
        for i in (pbar := tqdm.tqdm(range(sweep_config['n_trials']))):
            try:
                gamma, reg_coeff, balance_classes = next(sampler)
            except StopIteration:
                # grid exhausted before n_trials was reached
                break
            (val_wga, test_wga, test_macc, *_), hyperparams = run_training(emb_dict, sweep_config['num_epochs'], sweep_config['lr'], gamma, reg_coeff, balance_classes, sweep_config['early_stop'])

            # Keep the ten trials with the highest validation metric, always
            # sorted. Recorded before the early exit below so a perfect trial
            # is not lost.
            top10_choices.append((test_wga, val_wga, hyperparams))
            top10_choices = sorted(top10_choices, key=lambda x: x[1], reverse=True)[:10]

            if val_wga > best_vwga:
                best_vwga = val_wga
                best_twga = test_wga
                best_tmacc = test_macc
                best_hyperparams = hyperparams
                pbar.set_description(f"Test WGA: {test_wga:.3f}, Val WGA: {val_wga:.3f}, gamma: {gamma:.3f}, reg_coeff: {reg_coeff:.3f}, balance_classes: {balance_classes}")
                if val_wga == 1:
                    # validation metric cannot improve any further
                    break

        # save sweep results
        torch.save({
            'best_twga': best_twga,
            'best_tmacc': best_tmacc,
            'best_hyperparams': best_hyperparams,
            'top10_choices': top10_choices,
            'sweep_config': sweep_config,
        }, sweep_file)
        best_twgas.append(best_twga)
        best_tmaccs.append(best_tmacc)

    # print mean and std of best WGA
    print(f"Mean best WGA: {np.mean(best_twgas):.3f}, std: {np.std(best_twgas):.3f}")
    # print mean and std of mean acc
    print(f"Mean best mean acc: {np.mean(best_tmaccs):.3f}, std: {np.std(best_tmaccs):.3f}")
    return np.mean(best_twgas), np.std(best_twgas), np.mean(best_tmaccs), np.std(best_tmaccs)

In [None]:
skip_if_done = True
val_prop = 1

# Grid search over 10 gammas x 26 regularisation coefficients x 1 balancing
# option, i.e. 260 trials per embedding file.
sweep_config = dict(
    n_trials=260,
    num_epochs=200,
    lr=1e-2,
    gamma_choices=np.logspace(2, 5, 10),
    reg_coeff_choices=np.linspace(0, 50, 26),
    balance_classes_choices=[True],
    early_stop='wga',
    random=False,  # if True then random search, else grid search
    val_prop=val_prop,
    sweep_name='nominal',
)
sweep(emb_paths, sweep_config, skip_if_done)

### Group Label Efficiency / Down-sampled Validation Set

In [None]:
skip_if_done = True

# Repeat the grid search for several validation-set fractions and seeds; the
# seed controls which validation examples are kept inside `sweep`.
for seed in (0, 21, 42):
    for val_prop in (0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1):
        sweep_config = dict(
            n_trials=260,
            num_epochs=500,
            lr=1e-2,
            gamma_choices=np.logspace(2, 5, 10),
            reg_coeff_choices=np.linspace(0, 50, 26),
            balance_classes_choices=[True],
            early_stop='wga',
            random=False,  # if True then random search, else grid search
            val_prop=val_prop,
            sweep_name=f'val_prop_{val_prop}_seed{seed}',
        )
        sweep(emb_paths, sweep_config, skip_if_done, seed=seed)