In [None]:
# Set OMP_NUM_THREADS for this kernel's process (caps OpenMP worker threads).
%env OMP_NUM_THREADS=4
# Expose only the CUDA device with index 1 to this kernel.
# NOTE(review): this needs to run before CUDA is first initialised to take effect.
%env CUDA_VISIBLE_DEVICES=1
# Load the autoreload extension; mode 2 reloads modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../")
sys.path.append("../wilds_exps_utils/")

import torch
import numpy as np
import tqdm
import pickle
import copy
from types import SimpleNamespace
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader
from wilds_configs import datasets as dataset_configs
from wilds.datasets.wilds_dataset import WILDSSubset
from wilds_models.initializer import initialize_model
import wilds_transforms as transforms

# from wilds_algorithms.initializer import infer_d_out

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

## Data and model

In [None]:
import argparse
from types import SimpleNamespace
import tqdm
import torch

from transformers import BertConfig, BertForSequenceClassification

import sys
# TODO: change to the location of your own group_DRO checkout
# (machine-specific absolute path).
gdro_dir = '/home/pavel_i/projects/ssl_robustness/group_DRO'
# Put that directory on the import path so the `gdro_data` import below resolves.
sys.path.append(gdro_dir)
from gdro_data.data import prepare_data

In [None]:
NUM_GROUPS = 6

# Settings namespace handed to `prepare_data` (here and inside `save_emb`).
gdro_config = SimpleNamespace(**{
    'dataset': 'MultiNLI',
    'shift_type': 'confounder',
    'root_dir': '/data/users/pavel_i/datasets/multinli',
    'augment_data': False,
    'gamma': 0.1,
    'batch_size': 128,
    'target_name': 'gold_label_random',
    'confounder_names': ['sentence2_has_negation'],
    'model': 'bert',
    'fraction': 1.0,
})
reweighting_data, val_data, test_data = prepare_data(gdro_config, train=True)
loader_kwargs = dict(num_workers=4, pin_memory=True)

# reweighting_data = val_data
# Build the evaluation loaders (validation first, then test) with identical options.
val_loader, test_loader = (
    split.get_loader(train=False, reweight_groups=None, **loader_kwargs)
    for split in (val_data, test_data)
)

## Extract embeddings

In [None]:
def get_embeddings_predictions(feature_extractor, classifier, loader, device="cuda"):
    """Run a model over ``loader`` and collect embeddings and predictions.

    Gradients are disabled for the whole pass.

    Args:
        feature_extractor: Callable invoked with the keyword arguments
            ``input_ids``, ``attention_mask`` and ``token_type_ids``. The
            ``.logits`` attribute of its result is used as the batch embedding.
        classifier: Callable mapping a batch of embeddings to per-class
            scores; the prediction is the argmax over dimension 1.
        loader: Iterable yielding ``(x, y_true, metadata)`` batches. ``x`` is
            sliced as ``x[:, :, k]`` for ``k`` in 0, 1, 2 (input ids,
            attention mask, segment ids), so it presumably has shape
            ``(batch, seq_len, 3)`` -- confirm against the dataset.
        device: Device the three input tensors are moved to before the
            forward pass. Defaults to ``"cuda"``, which was previously
            hard-coded. The models themselves are not moved here.

    Returns:
        Tuple ``(embeddings, predictions, y_true, metadata)``. Each element is
        the concatenation along dimension 0 of the per-batch tensors.
        Embeddings, predictions and labels are moved to the CPU; metadata is
        concatenated exactly as yielded by ``loader``.
    """
    all_embeddings, all_predictions, all_y_true, all_metadata = [], [], [], []
    with torch.no_grad():
        for x, y_true, metadata in tqdm.tqdm(loader):
            # The last axis of x packs the three BERT inputs side by side.
            input_ids, input_masks, segment_ids = (
                x[:, :, channel].to(device) for channel in range(3))
            embeddings = feature_extractor(
                    input_ids=input_ids,
                    attention_mask=input_masks,
                    token_type_ids=segment_ids).logits
            predictions = torch.argmax(classifier(embeddings), dim=1)
            # Move results off the accelerator right away so that memory use
            # there does not grow with the dataset size.
            all_embeddings.append(embeddings.cpu())
            all_predictions.append(predictions.cpu())
            all_y_true.append(y_true.cpu())
            all_metadata.append(metadata)
    return (
        torch.cat(all_embeddings, dim=0),
        torch.cat(all_predictions, dim=0),
        torch.cat(all_y_true, dim=0),
        torch.cat(all_metadata, dim=0),
    )

def save_emb(ckpt_path, seed, save_path, dfr_reweighting_frac=0.2):
    """Extract embeddings for one checkpoint and save them to ``save_path``.

    The data splits are rebuilt with ``prepare_data(gdro_config, train=True)``.
    The first split's indices are shuffled with ``seed`` and only the last
    ``dfr_reweighting_frac`` fraction is kept as the reweighting set. The
    checkpoint's ``classifier`` head is detached from the model and the
    remaining network is used as the feature extractor.

    Args:
        ckpt_path: Path to a checkpoint that ``torch.load`` turns into a model
            object exposing a ``classifier`` attribute with ``weight`` and
            ``bias``.
        seed: Seed for the NumPy generator that shuffles the indices.
        save_path: Destination file for ``torch.save``. Its parent directory
            is created if it does not exist.
        dfr_reweighting_frac: Fraction of the shuffled indices kept as the
            reweighting set. Defaults to 0.2, the previously hard-coded value.

    The saved dict holds embeddings, labels, predictions and metadata under
    the keys ``e/y/pred/m`` (reweighting), ``val_*`` and ``test_*``, plus the
    original classifier parameters as ``w0`` and ``b0``.
    """
    import os  # local import: the notebook's import cells do not provide `os`

    # Create the output directory up front so that a missing folder fails (or
    # is fixed) before the expensive inference passes, not after them.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    reweighting_data, val_data, test_data = prepare_data(gdro_config, train=True)
    loader_kwargs = {'num_workers':4, 'pin_memory':True}

    val_loader = val_data.get_loader(
            train=False, reweight_groups=None, **loader_kwargs)
    test_loader = test_data.get_loader(
            train=False, reweight_groups=None, **loader_kwargs)
    dfr_reweighting_seed = seed

    print(f'Dropping DFR reweighting data, seed {dfr_reweighting_seed}')

    # Shuffle a copy of the indices and keep only the tail as the reweighting set.
    idx = reweighting_data.dataset.indices.copy()
    rng = np.random.default_rng(dfr_reweighting_seed)
    rng.shuffle(idx)
    n_train = int((1 - dfr_reweighting_frac) * len(idx))
    reweighting_idx = idx[n_train:]

    print(f'Original dataset size: {len(reweighting_data.dataset.indices)}')
    reweighting_data.dataset = torch.utils.data.dataset.Subset(
        reweighting_data.dataset.dataset,
        indices=reweighting_idx)
    print(f'New dataset size: {len(reweighting_data.dataset.indices)}')

    reweighting_loader = reweighting_data.get_loader(
            train=False, reweight_groups=None, **loader_kwargs)

    # NOTE(review): this unpickles a whole model object, which can execute
    # arbitrary code -- only load trusted checkpoints. Recent PyTorch releases
    # default to `weights_only=True` and may require `weights_only=False` here.
    model = torch.load(ckpt_path)
    model.cuda()
    model.eval()

    # Swap the head for an identity so the model's `.logits` output carries
    # the features that used to feed the classifier.
    classifier = model.classifier
    model.classifier = torch.nn.Identity(classifier.in_features)
    feature_extractor = model

    reweighting_embeddings, reweighting_predictions, reweighting_y, reweighting_metadata = get_embeddings_predictions(
            feature_extractor, classifier, reweighting_loader)
    val_embeddings, val_predictions, val_y, val_metadata = get_embeddings_predictions(
            feature_extractor, classifier, val_loader)
    test_embeddings, test_predictions, test_y, test_metadata = get_embeddings_predictions(
            feature_extractor, classifier, test_loader)
    torch.save(
        dict(
            e=reweighting_embeddings, y=reweighting_y, pred=reweighting_predictions, m=reweighting_metadata,
            test_e=test_embeddings, test_pred=test_predictions, test_y=test_y, test_m=test_metadata,
            val_e=val_embeddings, val_pred=val_predictions, val_y=val_y, val_m=val_metadata,
            w0 = classifier.weight.cpu(),
            b0 = classifier.bias.cpu()
        ),
        save_path
    )

In [None]:
# TODO: change to your directory of checkpoints
ckpt_dir = '/home/shikai_q/multinli_ckpts/multinli'
# TODO: change to your checkpoints
ckpts = [
    'erm_2', 'erm_1', 'erm_0',
    'erm_dfrdrop0', 'erm_dfrdrop_1', 'erm_dfrdrop2',
]
# One seed per entry of `ckpts`, in the same order.
seeds = [0, 0, 0, 0, 1, 2]

# Export one embedding file per checkpoint, named after the checkpoint.
for ckpt, seed in zip(ckpts, seeds):
    ckpt_path = f'{ckpt_dir}/{ckpt}/last_model.pth'
    save_path = f'emb/multinli/{ckpt}.pt'
    save_emb(ckpt_path=ckpt_path, seed=seed, save_path=save_path)

### Load from checkpoint

In [None]:
# Reload one saved embedding file and unpack it into notebook-level names,
# one group of four per split (reweighting, test, validation).
ckpt = torch.load("emb/multinli/erm_0.pt")

reweighting_embeddings, reweighting_y, reweighting_predictions, reweighting_metadata = (
    ckpt[key] for key in ("e", "y", "pred", "m"))
test_embeddings, test_predictions, test_y, test_metadata = (
    ckpt[key] for key in ("test_e", "test_pred", "test_y", "test_m"))
val_embeddings, val_predictions, val_y, val_metadata = (
    ckpt[key] for key in ("val_e", "val_pred", "val_y", "val_m"))