In [1]:
import torch
import numpy as np
import tqdm
from wilds import get_dataset
# from wilds.common.data_loaders import get_train_loader
from wilds.common.data_loaders import get_train_loader, get_eval_loader
import transforms
# from transforms import 

from types import SimpleNamespace
from wilds_configs.utils import populate_defaults
from wilds_configs import datasets as dataset_configs
from wilds_configs import model as model_configs

from wilds.datasets.wilds_dataset import WILDSSubset

from wilds_models.initializer import initialize_model
from wilds_algorithms.initializer import infer_d_out
# import configs
# from configs.datasets import dataset_defaults

In [2]:
config = SimpleNamespace(
    algorithm='ERM',
    load_featurizer_only=False,
    pretrained_model_path=None,
    **dataset_configs.dataset_defaults["amazon"],
    )
config.model_kwargs = {}
# config = populate_defaults(config)

In [4]:
dataset = get_dataset(dataset="amazon", download=False,
                      root_dir='/datasets/')
#                       root_dir='/home/pavel/datasets/')

transform = transforms.initialize_transform(
        transform_name='bert',
        config=config,
        dataset=dataset,
        additional_transform_name=None,
        is_training=True)

train_data = dataset.get_subset(
        "train",
        frac=1.,
        transform=transform)

# Get the test set
test_data = dataset.get_subset(
    "test", transform=transform
)
val_data = dataset.get_subset(
    "val", transform=transform
)



In [9]:
# seed = 0
# idx = train_data.indices.copy()
# rng = np.random.default_rng(0)
# rng.shuffle(idx)
# n_train = int((1 - 0.2) * len(idx))
# train_idx = idx[:n_train]
# val_idx = idx[n_train:]

# val_data = WILDSSubset(
#     dataset,
#     indices=val_idx,
#     transform=transform)

In [5]:
model = initialize_model(config=config, d_out=infer_d_out(train_data, config))
# ckpt_dict = torch.load('ckpts/civilcomments_seed_0_epoch_last_model.pth')
ckpt_dict = torch.load('logs/amazon_seed:0_epoch:last_model.pth')
model.load_state_dict({k[len('model.'):]: v for (k, v) in ckpt_dict['algorithm'].items()})

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertClassifier: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertClassifier were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight']
You should probably TRAIN t

<All keys matched successfully>

In [6]:
model.cuda();
model.eval();

In [7]:
train_loader = get_train_loader(
        loader="standard",
        dataset=train_data,
        batch_size=16,
        uniform_over_groups=False,
        grouper=None,
        distinct_groups=False,
        n_groups_per_batch=None)

# Prepare the evaluation data loader
test_loader = get_eval_loader("standard", test_data, batch_size=16)

## Eval

In [8]:
all_y_pred, all_y_true, all_metadata = [], [], []
with torch.no_grad():
    for x, y_true, metadata in tqdm.tqdm(test_loader):
        y_pred = model(x.cuda())
        all_y_pred.append(y_pred.cpu())
        all_y_true.append(y_true.cpu())
        all_metadata.append(metadata)
        # break
all_y_pred = torch.cat(all_y_pred, axis=0)
all_y_true = torch.cat(all_y_true, axis=0)
all_metadata = torch.cat(all_metadata, axis=0)

100%|██████████| 6254/6254 [10:10<00:00, 10.25it/s]


In [9]:
print(dataset.eval(torch.argmax(all_y_pred, axis=1), all_y_true, all_metadata)[-1])

Average acc: 0.719
10th percentile acc: 0.533
Worst-group acc: 0.120



## DFR-Val

In [10]:
def get_embeddings(model, loader):
    all_embeddings, all_y_true, all_metadata = [], [], []
    with torch.no_grad():
        for x, y_true, metadata in tqdm.tqdm(loader):
            embeddings = model(x.cuda())
            all_embeddings.append(embeddings.cpu())
            all_y_true.append(y_true.cpu())
            all_metadata.append(metadata)
            # break
    all_embeddings = torch.cat(all_embeddings, axis=0)
    all_y_true = torch.cat(all_y_true, axis=0)
    all_metadata = torch.cat(all_metadata, axis=0)
    return all_embeddings, all_y_true, all_metadata

In [11]:
import sys
sys.path.append('../')
from dfr import dfr_tune, dfr_run, dfr_tune_and_run, dfr_predict

%load_ext autoreload
%autoreload 2

In [12]:
val_loader = get_eval_loader("standard", val_data, batch_size=16)
test_loader = get_eval_loader("standard", test_data, batch_size=16)

In [13]:
model.classifier = torch.nn.Identity(model.classifier.in_features)
model.eval();

In [14]:
val_embeddings, val_y_true, val_metadata = get_embeddings(model, val_loader)
test_embeddings, test_y_true, test_metadata = get_embeddings(model, test_loader)

100%|██████████| 6254/6254 [10:09<00:00, 10.27it/s]
100%|██████████| 6254/6254 [10:12<00:00, 10.21it/s]


### All identity

In [17]:
val_data.metadata_fields

['user', 'product', 'category', 'year', 'y', 'from_source_domain']

In [22]:
torch.bincount(val_y_true)

tensor([ 1413,  2886,  9315, 27908, 58528])

In [16]:
val_metadata

tensor([[ 1573,     0,     0,    18,     4,     0],
        [ 1573,     0,     0,    17,     4,     0],
        [ 2300, 62111,     0,    21,     4,     0],
        ...,
        [ 2011, 62120,    23,    20,     3,     0],
        [ 2491, 62099,    23,    21,     4,     0],
        [ 1492, 62234,    23,    22,     4,     0]])

In [23]:
# val_spurious = val_metadata[:, :8]
val_groups = val_y_true

# test_spurious = test_metadata[:, :8]
test_groups = test_y_true

In [24]:
logreg, scaler = dfr_tune_and_run(val_embeddings, val_y_true, val_groups, verbose=True)

1.0: [0.64637681 0.49566955 0.51955911 0.50244534 0.70095212]
0.7: [0.66956522 0.50566289 0.52280095 0.52574799 0.71920964]
0.3: [0.69855072 0.52298468 0.52863627 0.53984465 0.75756748]
0.1: [0.70724638 0.54230513 0.57380592 0.55221519 0.75551991]
0.07: [0.71304348 0.5236509  0.56559326 0.55336594 0.79445791]
0.03: [0.74492754 0.51165889 0.5612708  0.58119965 0.77169573]
0.01: [0.71304348 0.49900067 0.6001729  0.56300345 0.78036379]
Training model 0/10, group counts: [1413 1413 1413 1413 1413]
Training model 1/10, group counts: [1413 1413 1413 1413 1413]
Training model 2/10, group counts: [1413 1413 1413 1413 1413]
Training model 3/10, group counts: [1413 1413 1413 1413 1413]
Training model 4/10, group counts: [1413 1413 1413 1413 1413]
Training model 5/10, group counts: [1413 1413 1413 1413 1413]
Training model 6/10, group counts: [1413 1413 1413 1413 1413]
Training model 7/10, group counts: [1413 1413 1413 1413 1413]
Training model 8/10, group counts: [1413 1413 1413 1413 1413]
Train

In [25]:
val_preds = torch.from_numpy(dfr_predict(val_embeddings, logreg, scaler))
test_preds = torch.from_numpy(dfr_predict(test_embeddings, logreg, scaler))

In [26]:
print(dataset.eval(test_preds, test_y_true, test_metadata)[-1])

Average acc: 0.682
10th percentile acc: 0.507
Worst-group acc: 0.147



In [27]:
[(test_preds == test_y_true)[test_groups == g].float().mean() for g in range(4)]

[tensor(0.6835), tensor(0.5342), tensor(0.5681), tensor(0.5678)]

### Per-identity

In [28]:
identity_counts = val_metadata[:, :8].sum(axis=0)

In [29]:
identity_ordering = torch.argsort(identity_counts)
identity_counts[identity_ordering]

tensor([1293, 1677, 1993, 2755, 3395, 5182, 5980, 7199])

In [30]:
[dataset.metadata_fields[i] for i in identity_ordering]

['other_religions',
 'LGBTQ',
 'black',
 'muslim',
 'white',
 'christian',
 'male',
 'female']

In [31]:
val_metadata_ordered = val_metadata[:, :8].T[identity_ordering].T

In [32]:
groups = torch.argmax(torch.cat(
        [torch.zeros((len(val_metadata_ordered),))[:, None], val_metadata_ordered], axis=1), axis=1)
groups = groups * 2 + val_y_true
torch.bincount(groups)

tensor([29762,  2576,  1092,   201,  1159,   466,  1257,   584,  1843,   460,
         1805,   624,  3750,   258,  3703,   505,  3285,   478])

In [33]:
logreg, scaler = dfr_tune_and_run(val_embeddings, val_y_true, groups, verbose=True)

1.0: [0.87688172 0.76870229 0.78531073 0.67       0.70630631 0.69294606
 0.64548495 0.70989761 0.72777778 0.69135802 0.74215247 0.72360248
 0.88296761 0.68644068 0.81550218 0.75875486 0.82639715 0.77118644]
0.7: [0.87627688 0.81145038 0.72693032 0.78       0.67747748 0.67219917
 0.71070234 0.67918089 0.73666667 0.74485597 0.76457399 0.72670807
 0.88087774 0.71186441 0.84224891 0.77042802 0.84423306 0.76271186]
0.3: [0.8874328  0.8389313  0.81544256 0.68       0.73693694 0.70124481
 0.71404682 0.73037543 0.72888889 0.72427984 0.70852018 0.77639752
 0.88035528 0.74576271 0.84497817 0.79766537 0.84661118 0.76694915]
0.1: [0.90658602 0.85801527 0.80225989 0.75       0.73513514 0.77593361
 0.68561873 0.78498294 0.73       0.79835391 0.73206278 0.81055901
 0.90125392 0.74576271 0.85480349 0.82879377 0.86444709 0.81779661]
0.07: [0.89603495 0.87022901 0.7740113  0.78       0.72792793 0.76763485
 0.68729097 0.77474403 0.71555556 0.8436214  0.66928251 0.86956522
 0.89341693 0.80508475 0.8329694

In [34]:
val_preds = torch.from_numpy(dfr_predict(val_embeddings, logreg, scaler))
test_preds = torch.from_numpy(dfr_predict(test_embeddings, logreg, scaler))

In [35]:
print(dataset.eval(test_preds, test_y_true, test_metadata)[-1])

Average acc: 0.872
  male                   acc on non_toxic: 0.834 (n =  12092)    acc on toxic: 0.801 (n =   2203) 
  female                 acc on non_toxic: 0.848 (n =  14179)    acc on toxic: 0.800 (n =   2270) 
  LGBTQ                  acc on non_toxic: 0.719 (n =   3210)    acc on toxic: 0.760 (n =   1216) 
  christian              acc on non_toxic: 0.878 (n =  12101)    acc on toxic: 0.781 (n =   1260) 
  muslim                 acc on non_toxic: 0.758 (n =   5355)    acc on toxic: 0.757 (n =   1627) 
  other_religions        acc on non_toxic: 0.823 (n =   2980)    acc on toxic: 0.771 (n =    520) 
  black                  acc on non_toxic: 0.722 (n =   3335)    acc on toxic: 0.774 (n =   1537) 
  white                  acc on non_toxic: 0.717 (n =   5723)    acc on toxic: 0.793 (n =   2246) 
Worst-group acc: 0.717



In [36]:
print(dataset.eval(val_preds, val_y_true, val_metadata)[-1])

Average acc: 0.872
  male                   acc on non_toxic: 0.830 (n =   5076)    acc on toxic: 0.798 (n =    904) 
  female                 acc on non_toxic: 0.841 (n =   6208)    acc on toxic: 0.818 (n =    991) 
  LGBTQ                  acc on non_toxic: 0.746 (n =   1199)    acc on toxic: 0.789 (n =    478) 
  christian              acc on non_toxic: 0.884 (n =   4709)    acc on toxic: 0.763 (n =    473) 
  muslim                 acc on non_toxic: 0.761 (n =   2160)    acc on toxic: 0.797 (n =    595) 
  other_religions        acc on non_toxic: 0.836 (n =   1092)    acc on toxic: 0.781 (n =    201) 
  black                  acc on non_toxic: 0.728 (n =   1349)    acc on toxic: 0.758 (n =    644) 
  white                  acc on non_toxic: 0.728 (n =   2446)    acc on toxic: 0.800 (n =    949) 
Worst-group acc: 0.728



In [37]:
[(val_preds == val_y_true)[groups == g].float().mean() for g in range(8)]

[tensor(0.9103),
 tensor(0.8560),
 tensor(0.8361),
 tensor(0.7811),
 tensor(0.7463),
 tensor(0.7876),
 tensor(0.7367),
 tensor(0.7603)]

## Remove subset

In [181]:
train_data = dataset.get_subset(
        "train",
        frac=1.,
        transform=transform)

In [184]:
train_data.indices.size

269038

In [189]:
train_data.metadata_array[:, :8].float().mean(axis=0)

tensor([0.1108, 0.1347, 0.0313, 0.0994, 0.0519, 0.0243, 0.0368, 0.0621])