In [None]:
from copy import deepcopy
from pprint import pprint

from baal import ModelWrapper
from active_fairness import utils
from active_fairness.metrics import FairnessMetric
from active_fairness.utils import get_datasets

'''
REPAIR resampling of datasets minimizing representation bias
Returns a weight in [0, 1] for each example.

Code from: https://github.com/JerryYLi/Dataset-REPAIR
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sklearn.metrics as skm
from tqdm import tqdm

def repair(loader, attribute, feat_dim, epochs, lr, lr_w):
    """Learn REPAIR per-example weights that minimise representation bias.

    A linear classifier is trained to predict the class label from the
    ``attribute`` entry of each example's target dict, while a per-example
    weight (sigmoid of a learnable parameter) is trained adversarially so
    that the weighted dataset makes that prediction as hard as possible.

    All tensors and the classifier are moved to the GPU with ``.cuda()``,
    so a CUDA device is required.

    Args:
        loader: DataLoader whose dataset yields ``(input, target, index)``
            triples, where ``target`` is a mapping containing a ``'target'``
            entry (the class label) and an ``attribute`` entry.
        attribute: Key of the target mapping used as the classifier input.
        feat_dim: Input dimensionality of the linear classifier.
        epochs: Number of passes over ``loader``.
        lr: Learning rate of the linear classifier (SGD, momentum 0.9).
        lr_w: Learning rate of the example weights (plain SGD).

    Returns:
        Tuple ``(w, q, cls_idx, cls_w, bias)``:
        ``w`` is the per-example weight (sigmoid output), ``q`` the
        weight-normalised class distribution, ``cls_idx`` the
        ``(n_cls, n_examples)`` float class-membership matrix, ``cls_w`` the
        summed weight per class, and ``bias`` is ``1 - loss / rnd_loss``
        using the last epoch's mean loss.
    """
    # class counts
    labels = torch.tensor([data[1]['target'] for data in loader.dataset]).long().cuda()
    n_cls = int(labels.max()) + 1
    cls_idx = torch.stack([labels == c for c in range(n_cls)]).float().cuda()

    # create models
    model = nn.Linear(feat_dim, n_cls).cuda()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    weight_param = nn.Parameter(torch.zeros(len(loader.dataset)).cuda())
    optimizer_w = optim.SGD([weight_param], lr=lr_w)

    # training
    with tqdm(range(1, epochs + 1)) as pbar:
        for _ in pbar:
            losses = []
            corrects = 0
            for _x, target, idx in loader:
                # Both the label and the sensitive attribute live in the
                # target mapping (previously `y` was read before assignment).
                y = target['target'].cuda()
                # nn.Linear needs a float (batch, feat_dim) input.
                # NOTE(review): assumes the collated attribute is a tensor with
                # one row per example - confirm against the dataset.
                sensible = target[attribute].cuda().float().reshape(y.shape[0], -1)

                # class probabilities
                w = torch.sigmoid(weight_param)
                z = w[idx] / w.mean()
                cls_w = cls_idx @ w
                q = cls_w / cls_w.sum()

                # linear classifier
                out = model(sensible)
                loss_vec = F.cross_entropy(out, y, reduction='none')
                loss = (loss_vec * z).mean()
                losses.append(loss.item())
                corrects += out.max(1)[1].eq(y).sum().item()
                optimizer.zero_grad()
                # loss_w below is built from `loss`, so the graph must be kept.
                loss.backward(retain_graph=True)
                optimizer.step()

                # class weights
                optimizer_w.zero_grad()
                entropy = -(q[y].log() * z).mean()
                loss_w = 1 - loss / entropy
                loss_w.backward()
                optimizer_w.step()

            # Per-epoch summary; the values of the final epoch are reported below.
            loss = sum(losses) / len(losses)
            acc = 100 * corrects / len(loader.dataset)
            pbar.set_postfix(loss='%.3f' % loss, acc='%.2f%%' % acc)

    # class probabilities & bias
    with torch.no_grad():
        w = torch.sigmoid(weight_param)
        cls_w = cls_idx @ w
        q = cls_w / cls_w.sum()
        rnd_loss = -(q * q.log()).sum().item()
        bias = 1 - loss / rnd_loss

    print('Accuracy = {:.2f}%, Loss = {:.3f}, Rnd Loss = {:.3f}, Bias = {:.3f}'.format(acc, loss, rnd_loss, bias))
    return w, q, cls_idx, cls_w, bias

In [None]:
from torch.utils.data import Dataset,  DataLoader, Subset

class IndexedDataset(Dataset):
    """Dataset wrapper that appends each item's index to the returned sample.

    ``wrapped[i]`` must be iterable; the result is a tuple of its elements
    followed by ``i``.
    """

    def __init__(self, dataset):
        # Underlying dataset whose samples are extended with their index.
        self.dataset = dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        """Return the wrapped sample's elements followed by ``idx``."""
        sample = self.dataset[idx]
        return tuple(sample) + (idx,)

def get_keep_idx(w, cls_idx, threshold=0.5, keep_ratio=0.5, mode='threshold'):
    """Select the indices of the examples to keep after REPAIR weighting.

    Args:
        w: 1-D tensor of per-example weights.
        cls_idx: ``(n_classes, n_examples)`` class-membership matrix; only
            used by ``mode='cls_rank'``.
        threshold: Weight cut-off for ``mode='threshold'``.
        keep_ratio: Fraction of examples kept by the ``'rank'``,
            ``'cls_rank'`` and ``'uniform'`` modes.
        mode: One of ``'threshold'``, ``'rank'``, ``'cls_rank'``,
            ``'sample'`` or ``'uniform'``.

    Returns:
        1-D tensor of kept example indices (always 1-D, even when a single
        example is kept, so ``len()`` works on the result).

    Raises:
        ValueError: If ``mode`` is not one of the supported strategies.
    """
    # strategy 1: fixed threshold
    if mode == 'threshold':
        # squeeze(1) drops only the column axis, so one hit stays 1-D.
        keep_idx = (w > threshold).nonzero().squeeze(1).cpu()

    # strategy 2: top k% examples
    elif mode == 'rank':
        keep_examples = round(keep_ratio * len(w))
        keep_idx = w.sort(descending=True)[1][:keep_examples].cpu()

    # strategy 3: top k% examples each class
    elif mode == 'cls_rank':
        keep_idx_list = []
        # Iterate over the classes actually present instead of assuming 10.
        for c in range(len(cls_idx)):
            c_idx = cls_idx[c].nonzero().squeeze(1)
            keep_examples = round(keep_ratio * len(c_idx))
            sort_idx = w[c_idx].sort(descending=True)[1]
            keep_idx_list.append(c_idx[sort_idx][:keep_examples])
        keep_idx = torch.cat(keep_idx_list).cpu()

    # strategy 4: sampling according to weights
    elif mode == 'sample':
        keep_idx = torch.bernoulli(w).nonzero().squeeze(1).cpu()

    # strategy 5: random uniform sampling
    elif mode == 'uniform':
        keep_examples = round(keep_ratio * len(w))
        keep_idx = torch.randperm(len(w))[:keep_examples]

    else:
        raise ValueError(
            "Unknown resampling mode {!r}; expected one of "
            "'threshold', 'rank', 'cls_rank', 'sample', 'uniform'".format(mode)
        )

    return keep_idx

In [None]:
class MyCrit(nn.Module):
    """Criterion adapter for dict-style targets.

    Wraps a loss callable and feeds it the ``'target'`` entry of the target
    mapping instead of the mapping itself.
    """

    def __init__(self, crit):
        super(MyCrit, self).__init__()
        # Underlying loss, called as crit(input, labels).
        self.crit = crit

    def forward(self, input, target):
        """Apply the wrapped criterion to ``input`` and ``target['target']``."""
        labels = target['target']
        return self.crit(input, labels)

In [None]:
# Experiment configuration: sensitive attribute and dataset location.
attribute = 'color'
dataset_path = '/datasets/fairface_like_dataset_50000.pkl'

# Load the splits. NOTE(review): the meaning of the `10` and `'char'`
# arguments is defined by active_fairness.utils.get_datasets - confirm there.
active_set, val_set, test_set = get_datasets(dataset_path, 10, attribute, 'char')
train_dataset = active_set._dataset
num_classes = len(test_set._all_target)
num_group = len(test_set._all_attribute)

# Loss that unwraps the dict-style target before calling cross-entropy.
criterion = MyCrit(nn.CrossEntropyLoss())

# Backbone on the GPU; the optimizer is bound to its parameters before the
# network is wrapped by baal's ModelWrapper.
model = utils.vgg16(pretrained=True, num_classes=num_classes)
model.cuda()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=5e-4,
)
model = ModelWrapper(model, criterion)

# Snapshot of the starting weights so every run begins from the same point.
initial_weights = deepcopy(model.state_dict())


def _fairness_metric_factory(score_fn, label, **score_kwargs):
    """Return a zero-argument initializer building a FairnessMetric for `score_fn`."""
    return lambda: FairnessMetric(score_fn, label, **score_kwargs, attribute=attribute)


# Register one fairness metric per sklearn score, in the original order.
for _metric_key, _score_fn, _label, _score_kwargs in (
    ('fair_recall', skm.recall_score, 'recall', {'average': 'micro'}),
    ('fair_accuracy', skm.accuracy_score, 'accuracy', {}),
    ('fair_precision', skm.precision_score, 'precision', {'average': 'micro'}),
    ('fair_f1', skm.f1_score, 'f1', {'average': 'micro'}),
):
    model.add_metric(_metric_key, _fairness_metric_factory(_score_fn, _label, **_score_kwargs))

In [None]:
# Baseline run: train on the raw (un-resampled) dataset, starting from the
# saved initial weights.
model.load_state_dict(initial_weights)
model.train_on_dataset(
    train_dataset,
    optimizer,
    64,
    epoch=10,
    use_cuda=True,
    workers=12,
)

# Evaluate on the held-out test set.
model.test_on_dataset(
    test_set,
    batch_size=32,
    use_cuda=True,
    workers=6,
    average_predictions=1,
)

# Flatten every fairness metric into one dict, prefixing each key with its split.
fair_logs = {}
for met in ('fair_recall', 'fair_accuracy', 'fair_precision', 'fair_f1'):
    fair_test = dict(
        ('test_' + k, v) for k, v in model.metrics[f'test_{met}'].value.items()
    )
    fair_train = dict(
        ('train_' + k, v) for k, v in model.metrics[f'train_{met}'].value.items()
    )
    fair_logs.update(fair_test)
    fair_logs.update(fair_train)

# Add the plain losses; note that `val_loss` holds the *test* loss.
metrics = model.metrics
train_loss = metrics['train_loss'].value
val_loss = metrics['test_loss'].value
fair_logs["test_loss"] = val_loss
fair_logs["train_loss"] = train_loss

pprint(fair_logs)

In [None]:
# Resampling: learn REPAIR example weights on the training set. The dataset is
# wrapped so that every sample also carries its index.
repair_dataset = IndexedDataset(train_dataset)
train_loader = DataLoader(
    repair_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
w, q, cls_idx, cls_w, bias = repair(train_loader, attribute, 1, epochs=200, lr=1e-3, lr_w=10)

# Turn the learned weights into the set of kept example indices.
sampling = 'threshold'  # One of ['threshold', 'rank', 'cls_rank', 'sample', 'uniform']
print('Resampling strategy:', sampling)
keep_idx_train = keep_idx = get_keep_idx(w, cls_idx, mode=sampling)

n_kept, n_total = len(keep_idx), len(w)
print('Keep examples: {}/{} ({:.2%})'.format(n_kept, n_total, n_kept / n_total))

In [None]:
# Resampled run: train on the subset kept by REPAIR, again starting from the
# saved initial weights.
# NOTE(review): `optimizer` is the same object used for the baseline run, so
# its momentum state is not reset here - confirm this is intended.
resampled_train_set = Subset(train_dataset, keep_idx_train)
model.load_state_dict(initial_weights)
model.train_on_dataset(
    resampled_train_set,
    optimizer,
    64,
    epoch=10,
    use_cuda=True,
    workers=12,
)

# Evaluate on the held-out test set.
model.test_on_dataset(
    test_set,
    batch_size=32,
    use_cuda=True,
    workers=6,
    average_predictions=1,
)

# Flatten every fairness metric into one dict, prefixing each key with its split.
fair_logs = {}
for met in ('fair_recall', 'fair_accuracy', 'fair_precision', 'fair_f1'):
    fair_test = dict(
        ('test_' + k, v) for k, v in model.metrics[f'test_{met}'].value.items()
    )
    fair_train = dict(
        ('train_' + k, v) for k, v in model.metrics[f'train_{met}'].value.items()
    )
    fair_logs.update(fair_test)
    fair_logs.update(fair_train)

# Add the plain losses; note that `val_loss` holds the *test* loss.
metrics = model.metrics
train_loss = metrics['train_loss'].value
val_loss = metrics['test_loss'].value
fair_logs["test_loss"] = val_loss
fair_logs["train_loss"] = train_loss

pprint(fair_logs)