In [None]:
import json
import pickle
import numpy as np
import sklearn.metrics
import scipy.special
# Domain identifiers stored in the per-example ground-truth domain vector
# (see compute_gt_domain_from_labels).
MALE = 1
FEMALE = 0
# Assigned to examples that have no positive label in either domain half.
NO_DOMAIN = 2
# Logistic sigmoid; used to turn raw network outputs into probabilities.
sigmoid = scipy.special.expit

In [None]:
class optimize_potentials_given_known_domain():
    """Adjusts per-class potentials with Lagrange multipliers to reduce domain bias.

    The object is configured once with ground truth and hyper-parameters; the
    actual work happens in `optimize`, which takes the potentials to adjust.

    Column layout: every "two-domain" array (`gt_labels`, `training_set_targets`)
    has 2 * class_count columns. The first class_count columns belong to the
    MALE domain, the last class_count columns to the FEMALE domain. The same
    order (MALE at position 0, FEMALE at position 1) is used for the domain
    axis of the constraint tensor and of the per-domain prediction counts.

    Args:
      gt_labels: Numpy array with shape (test_set_size, 2 * class_count) holding
        the two-domain ground-truth labels of the test set.
      gt_domain: Numpy array with shape (test_set_size,) whose entries are MALE,
        FEMALE or NO_DOMAIN.
      lr: Step size of the Lagrange multiplier update.
      margin: Tolerated deviation of the predicted MALE ratio from the target.
      apply_prior_shift: Stored on the instance; not used by this class.
      inputs_are_activations: If true, potentials are passed through a sigmoid
        before thresholding; otherwise they are used as probabilities directly.
      method_name: Label printed in the progress/summary lines.
      target_domain_ratios: Array with shape (class_count,). Target fraction of
        positive predictions of each class that fall on MALE examples.
      domain_labels: Stored on the instance; not used by this class.
      inference_thresholds: Per-class thresholds, shape (class_count,). A class
        is predicted when its probability is strictly greater than its threshold.
      training_set_targets: Two-domain label array of the training set, used as
        the reference for bias amplification.
      verbosity: 0 prints nothing, >= 1 prints before/after summaries, >= 2 also
        prints per-epoch progress.
      total_epochs: Maximum number of multiplier updates.
    """

    # Positions of the two domains along every "domain" axis used by this class.
    _MALE_POSITION = 0
    _FEMALE_POSITION = 1

    def __init__(self, gt_labels, gt_domain, lr, margin, apply_prior_shift,
                 inputs_are_activations, method_name, target_domain_ratios,
                 domain_labels, inference_thresholds, training_set_targets, verbosity=2,
                 total_epochs=100):
        self.test_set_size = gt_labels.shape[0]
        self.class_count = gt_labels.shape[1] // 2
        self.gt_labels = gt_labels
        self.gt_domain = gt_domain
        self.lr = lr
        self.margin = margin
        self.apply_prior_shift = apply_prior_shift
        self.inputs_are_activations = inputs_are_activations
        self.method_name = method_name
        self.target_domain_ratios = target_domain_ratios
        self.domain_labels = domain_labels
        self.verbosity = verbosity
        self.inference_thresholds = inference_thresholds
        self.training_set_targets = training_set_targets
        # `np.bool` was removed from NumPy; the builtin `bool` is the supported spelling.
        gt_labels_bool = gt_labels.astype(bool)
        # A class is present when it is labelled in either domain half.
        self.gt_class = (gt_labels_bool[:, :self.class_count]
                         | gt_labels_bool[:, self.class_count:]).astype(np.int32)
        self.total_epochs = total_epochs

    def _domain_row_masks(self):
        """Returns [(position, row_mask)] for the MALE and FEMALE domains.

        Rows marked NO_DOMAIN are in neither mask and therefore never take part
        in domain-specific computations.
        """
        return [(self._MALE_POSITION, self.gt_domain == MALE),
                (self._FEMALE_POSITION, self.gt_domain == FEMALE)]

    def multiclass_probabilities(self, potentials):
        """Returns the probability of each class from the network activation."""
        if self.inputs_are_activations:
            return sigmoid(potentials)
        return potentials

    def compute_sample_weights(self):
        """Returns per-example, per-class weights with shape (test_set_size, class_count).

        Positive labels of a class are weighted by (n_m + n_w) / (2 * n_domain),
        where n_m / n_w are the numbers of MALE / FEMALE positives of that class.
        Entries whose label is 0, and rows without a domain, get weight 1.
        """
        cc = self.class_count
        n_m = np.sum(self.gt_labels[:, :cc], axis=0).astype(np.float64)
        n_w = np.sum(self.gt_labels[:, cc:], axis=0).astype(np.float64)
        # A class with no positives in a domain yields inf/nan here, but every
        # entry of that column has label 0 and is reset to 1.0 below, so the
        # division warnings carry no information.
        with np.errstate(divide='ignore', invalid='ignore'):
            male_class_weights = (n_m + n_w) / (2.0*n_m)
            female_class_weights = (n_m + n_w) / (2.0*n_w)
        sample_weights = np.zeros_like(self.gt_labels, dtype=np.float64)
        sample_weights[:, :cc] = male_class_weights      # broadcast over rows
        sample_weights[:, cc:] = female_class_weights    # broadcast over rows
        sample_weights[self.gt_labels == 0] = 1.0
        collapsed = np.ones((self.test_set_size, cc), dtype=np.float64)
        male_rows = self.gt_domain == MALE
        female_rows = self.gt_domain == FEMALE
        collapsed[male_rows, :] = sample_weights[male_rows, :cc]
        collapsed[female_rows, :] = sample_weights[female_rows, cc:]
        return collapsed

    def compute_mAP(self, potentials):
        """Returns the sample-weighted mean average precision, in percent."""
        probs = self.multiclass_probabilities(potentials)
        sample_weights = self.compute_sample_weights()
        APs = [sklearn.metrics.average_precision_score(self.gt_class[:, i], probs[:, i],
                                                       sample_weight=sample_weights[:, i])
               for i in range(self.class_count)]
        return 100.0 * np.mean(APs)

    def compute_accuracy(self, potentials):
        """Returns the element-wise accuracy of the thresholded decisions, in percent."""
        decisions = self.multiclass_inference(potentials)
        return 100.0 * np.mean(decisions == self.gt_class)

    def compute_bias(self, decisions):
        """Returns, per class and domain, the share of the class mass in that domain.

        Args:
          decisions: Either a two-domain array with 2 * class_count columns, or
            an array with shape (test_set_size, class_count), which is first
            spread over the two domain halves according to `gt_domain`.

        Returns:
          Float64 array with shape (2 * class_count,): MALE shares followed by
          FEMALE shares. Classes with zero total mass get 0 in both halves.
        """
        cc = self.class_count
        # Work in float64 so that integer label arrays (e.g. training targets)
        # can be divided into the float `out` buffer below.
        decisions = np.asarray(decisions, dtype=np.float64)
        if decisions.shape[1] == cc:
            domain_decisions = np.zeros((self.test_set_size, 2*cc), dtype=np.float64)
            # MALE rows fill the first half, FEMALE rows the second half, which
            # matches the layout of gt_labels / training_set_targets.
            for position, rows in self._domain_row_masks():
                domain_decisions[rows, position*cc:(position+1)*cc] = decisions[rows, :]
            decisions = domain_decisions
        class_domain_counts = np.sum(decisions, axis=0)
        class_counts = np.tile(class_domain_counts[:cc] + class_domain_counts[cc:], 2)
        domain_weights = np.divide(class_domain_counts, class_counts,
                                   out=np.zeros_like(class_counts), where=(class_counts != 0.0))
        return domain_weights

    def compute_bias_amplification(self, potentials):
        """Returns the mean absolute gap between test-decision and training-label bias."""
        decisions = self.multiclass_inference(potentials).astype(np.float64)
        test_bias = self.compute_bias(decisions)
        train_bias = self.compute_bias(self.training_set_targets)
        return np.mean(np.abs(test_bias - train_bias))

    def multiclass_count_domain_incidence_from_gt(self, predictions):
        """Counts positive predictions per class and domain.

        Returns:
          Float64 array with shape (class_count, 2); column 0 counts predictions
          on MALE rows, column 1 on FEMALE rows.
        """
        male_gt_count = np.sum(predictions[self.gt_domain == MALE, :], axis=0)
        female_gt_count = np.sum(predictions[self.gt_domain == FEMALE, :], axis=0)
        return np.stack([male_gt_count, female_gt_count], axis=1).astype(np.float64)

    def multiclass_inference(self, potentials):
        """Converts the potentials into decisions."""
        probs = self.multiclass_probabilities(potentials)
        # Per-class thresholds broadcast over the example axis.
        return (probs > np.asarray(self.inference_thresholds)).astype(np.int32)

    def generate_constraints(self):
        """Builds the linear constraint coefficients, shape (class_count, 2, 2).

        With m / f the MALE / FEMALE positive counts of a class, r its target
        ratio and `margin` the tolerance, constraint k of class c reads
        constraints[c, k, 0] * m + constraints[c, k, 1] * f <= 0:
          k = 0  <=>  m / (m + f) >= r - margin
          k = 1  <=>  m / (m + f) <= r + margin
        """
        constraints = np.zeros((self.class_count, 2, 2))
        constraints[:, 0, 0] = self.target_domain_ratios - 1 - self.margin
        constraints[:, 0, 1] = self.target_domain_ratios - self.margin
        constraints[:, 1, 0] = 1 - (self.margin + self.target_domain_ratios)
        constraints[:, 1, 1] = -(self.margin + self.target_domain_ratios)
        return constraints

    def optimize(self, input_potentials):
        """Runs the multiplier updates and returns the adjusted potentials.

        Each epoch thresholds the current potentials, measures how far every
        constraint is violated, takes a projected gradient step on the
        multipliers, and rebuilds the potentials from `input_potentials` by
        shifting the entries that are currently predicted positive. The loop
        stops early once no constraint is violated.

        Args:
          input_potentials: Numpy array with shape (test_set_size, class_count).

        Returns:
          (output_potentials, output_predictions): the adjusted potentials and
          their thresholded int32 decisions, both (test_set_size, class_count).
        """
        if self.verbosity >= 1:
            initial_mAP = self.compute_mAP(input_potentials)
            # NOTE(review): the bias summary is computed on the potentials
            # themselves (a "soft" count), not on thresholded decisions - confirm
            # this is the intended metric.
            initial_bias = np.mean(np.abs(0.5 - self.compute_bias(input_potentials)))
            initial_bias_amplification = self.compute_bias_amplification(input_potentials)
            name_in = ('%s, before optimization' % self.method_name).ljust(85)
            print('%s mAP. %0.2f%%. Bias %0.3f' % (name_in, initial_mAP, initial_bias))
            print('\t bias amplification: %0.4f. ' % initial_bias_amplification)
        if self.verbosity >= 2:
            # verbosity >= 2 implies the block above ran, so initial_mAP exists.
            print('Pre optimization mAP: %0.2f%%' % initial_mAP)
        lambdas = np.zeros((self.class_count, 2), dtype=np.float64)
        current_potentials = input_potentials.copy()
        constraints = self.generate_constraints()
        initial_predictions = self.multiclass_inference(input_potentials)
        domain_rows = self._domain_row_masks()
        # Two constraints (lower and upper bound) per class.
        constraint_count = constraints.shape[0] * constraints.shape[1]
        for epoch in range(self.total_epochs):
            predictions = self.multiclass_inference(current_potentials)
            count_per_class = self.multiclass_count_domain_incidence_from_gt(predictions)
            # (class, constraint): positive values mean the constraint is violated.
            constraint_delta = np.sum(constraints * count_per_class[:, np.newaxis, :], axis=2)
            # Projected gradient ascent: multipliers stay non-negative.
            lambdas = np.maximum(lambdas + self.lr * constraint_delta, 0)
            violated_constraint_count = np.count_nonzero(constraint_delta > 0)

            # Potentials are always rebuilt from the inputs, never accumulated.
            current_potentials = input_potentials.copy()
            prediction_mask = predictions.astype(np.float64)
            # An example is shifted by the coefficient of ITS domain position
            # (MALE -> 0, FEMALE -> 1), i.e. the same position its predictions
            # were counted under. Rows without a domain cannot affect the
            # constraints and are left untouched.
            for position, rows in domain_rows:
                for k in range(constraints.shape[1]):
                    current_potentials[rows] -= (prediction_mask[rows] * lambdas[:, k]
                                                 * constraints[:, k, position])

            if (epoch % 10 == 0 or epoch == self.total_epochs-1) and self.verbosity >= 2:
                print('Finished %i-th Epoch.' % epoch)
                mean_bias = np.mean(np.abs(0.5 - self.compute_bias(current_potentials)))
                print('\tMean Bias: %0.4f' % mean_bias)
                print('\tConstraint Satisfaction: %i/%i' % (constraint_count-violated_constraint_count, constraint_count))
                current_mAP = self.compute_mAP(current_potentials)
                current_class_acc = self.compute_accuracy(current_potentials)
                total_flipped_predictions = np.count_nonzero(self.multiclass_inference(current_potentials) != initial_predictions)
                print('\tTotal Flipped Predictions: %i' % total_flipped_predictions)
                print('\tCurrent mAP: %0.2f%%' % current_mAP)
                print('\tCurrent Class Acc: %0.2f%%' % current_class_acc)

            if violated_constraint_count == 0:
                break
        if self.verbosity >= 1:
            final_mAP = self.compute_mAP(current_potentials)
            final_bias = np.mean(np.abs(0.5 - self.compute_bias(current_potentials)))
            final_bias_amplification = self.compute_bias_amplification(current_potentials)
            name_in = ('%s, after optimization' % self.method_name).ljust(85)
            print('%s mAP. %0.2f%%. Bias %0.3f' % (name_in, final_mAP, final_bias))
            print('\t bias amplification: %0.4f. ' % final_bias_amplification)
            print('mAP change %f, bias change %f, bias amplication change %f' 
                  % (final_mAP-initial_mAP, final_bias-initial_bias, final_bias_amplification-initial_bias_amplification))
        return current_potentials, self.multiclass_inference(current_potentials)

def compute_thresh_on_dev(dev_potentials, dev_targets, reduction_method='sum'):
    """Picks, per class, the probability threshold that maximises F-score on the dev set.

    Args:
      dev_potentials: Array with shape (dev_set_size, 2 * class_count) of raw
        network outputs; the first class_count columns are the MALE-domain
        outputs and the remaining columns the FEMALE-domain outputs.
      dev_targets: Array with shape (dev_set_size, 2 * class_count) of
        two-domain labels in the same column layout.
      reduction_method: How the two domain outputs are collapsed to one score
        per class. 'sum' averages the two sigmoid probabilities. 'condition'
        takes the probability of the example's own domain and falls back to the
        average for examples without a domain.

    Returns:
      Float64 array with shape (class_count,) of thresholds, as returned by
      `max_fscore` for each class.

    Raises:
      ValueError: If `reduction_method` is neither 'sum' nor 'condition'.
    """
    if reduction_method not in ('sum', 'condition'):
        # Previously an unknown method silently thresholded only the first
        # (MALE) half of the un-reduced probabilities.
        raise ValueError("reduction_method must be 'sum' or 'condition', got %r"
                         % (reduction_method,))

    class_count = dev_potentials.shape[1] // 2
    output_threshes = np.zeros((class_count,), dtype=np.float64)
    probs = sigmoid(dev_potentials)
    male_probs = probs[:, :class_count]
    female_probs = probs[:, class_count:]
    averaged = (male_probs + female_probs) / 2.0

    if reduction_method == 'sum':
        probs = averaged
    else:
        gt_domain = compute_gt_domain_from_labels(dev_targets)
        is_male = (gt_domain == MALE)[:, np.newaxis]
        is_female = (gt_domain == FEMALE)[:, np.newaxis]
        # Examples without a domain cannot be conditioned, so they keep the average.
        probs = np.where(is_male, male_probs, np.where(is_female, female_probs, averaged))

    # `np.bool` was removed from NumPy; the builtin `bool` is the supported spelling.
    dev_targets = dev_targets.astype(bool)
    # A class is present when it is labelled in either domain half.
    dev_targets = dev_targets[:, :class_count] | dev_targets[:, class_count:]
    for ci in range(class_count):
        output_threshes[ci] = max_fscore(dev_targets[:, ci], probs[:, ci])[1]
    return output_threshes

def max_fscore(targets,scores):
    """Finds the score threshold that maximises the F-score of `scores >= thr`.

    Every distinct score value is tried as a threshold. For a threshold that
    predicts n examples positive, tp of them correctly, out of npos positives:
        F = 2 / (1/recall + 1/precision) = 2 * tp / (npos + n)

    Tied scores are handled as a group: a `>= thr` cut cannot separate equal
    scores, so only the last element of each run of equal scores is a valid
    cut point. (Evaluating F inside a tie group reports a value that no
    threshold can achieve.)

    Args:
      targets: 1-D array-like of binary labels (bool or 0/1).
      scores: 1-D array-like of scores, same length as `targets`.

    Returns:
      (ff, thr): the best F-score and the corresponding threshold, to be applied
      as `scores >= thr`. Among equally good thresholds the highest one wins.

    Raises:
      ValueError: If `scores` is empty.
    """
    targets = np.asarray(targets)
    scores = np.asarray(scores)
    if scores.size == 0:
        raise ValueError('max_fscore requires at least one score')

    # Descending by score; equal scores are ordered by descending index.
    order = np.argsort(scores, kind='stable')[::-1]
    ranked_scores = scores[order]
    ranked_targets = targets[order]

    num_pos = np.sum(ranked_targets)
    true_pos = np.cumsum(ranked_targets)
    num_pred = np.arange(1, ranked_scores.size + 1)
    # f-score when the top-(i+1) ranked examples are predicted positive
    f = 2*true_pos/(num_pos + num_pred)

    # Position i is a realisable cut only if the next score is different.
    closes_tie_group = np.append(ranked_scores[1:] != ranked_scores[:-1], True)
    ii = np.argmax(np.where(closes_tie_group, f, -np.inf))

    # return the f-score and the corresponding threshold (>= )
    return f[ii], ranked_scores[ii]

def f_score(targets,scores,thr):
    """Returns the F-score of predicting positive wherever `scores >= thr`.

    Computed as 2 * tp / (npos + npred), where tp is the number of predicted
    positives whose target equals 1, npos the sum of `targets` and npred the
    number of predicted positives.
    """
    predicted = scores >= thr
    is_positive = targets == 1
    hits = np.sum(predicted & is_positive)
    denominator = np.sum(targets) + np.sum(predicted)
    return 2*hits/denominator

def compute_gt_domain_from_labels(labels):
    """Derives each example's domain from which half of its label row is active.

    `labels` has shape (set_size, 2 * class_count). A row with any non-zero
    entry only in the first class_count columns is MALE, only in the remaining
    columns is FEMALE, and in neither half is NO_DOMAIN. A row active in both
    halves trips an assertion.

    Returns:
      Float64 array with shape (set_size,) holding the domain identifiers.
    """
    row_count, column_count = labels.shape
    half = column_count // 2
    male_active = np.any(labels[:, :half], axis=1)
    female_active = np.any(labels[:, half:], axis=1)
    # No example may carry labels in both domain halves.
    assert not np.any(male_active & female_active)

    domains = np.zeros((row_count,), dtype=np.float64)
    domains[male_active & ~female_active] = MALE
    domains[female_active & ~male_active] = FEMALE
    domains[~male_active & ~female_active] = NO_DOMAIN
    return domains

def run(hparams, data):
    """Runs the bias-reducing post-processing on one set of test outputs.

    Args:
      hparams: Dict with the keys
        'optimize_probabilities' (bool), 'reduction' ('sum' or 'condition'),
        'prior_shift' (bool), 'lr', 'verbosity' and 'total_epochs'.
        Supported combinations are 'sum' with either value of
        'optimize_probabilities', and 'condition' with
        'optimize_probabilities' false.
      data: Dict with the keys 'targets' and 'outputs' (test set),
        'dev_targets' and 'dev_outputs' (dev set, used for the thresholds) and
        'train_targets'. Target and output arrays have 2 * class_count columns:
        the first class_count for the MALE domain, the rest for FEMALE.

    Returns:
      (output_potentials, output_classes) as returned by
      `optimize_potentials_given_known_domain.optimize`.

    Raises:
      ValueError: If the 'reduction' / 'optimize_probabilities' combination is
        not supported.
    """
    optimize_probabilities = hparams['optimize_probabilities']
    reduction_method = hparams['reduction']
    apply_prior_shift = hparams['prior_shift']

    # Reject unsupported settings before doing any work; previously this only
    # surfaced late as a bare `assert False` or an unbound-variable error.
    if reduction_method == 'sum':
        reduction_str = 'sum probabilities' if optimize_probabilities else 'sum outputs'
    elif reduction_method == 'condition' and not optimize_probabilities:
        reduction_str = 'condition on d0'
    else:
        raise ValueError(
            "unsupported combination: reduction=%r, optimize_probabilities=%r"
            % (reduction_method, optimize_probabilities))
    optimization_str = 'optimize on probabilities' if optimize_probabilities else 'optimize on outputs'
    prior_shift_str = 'prior shift' if apply_prior_shift else 'no prior shift'
    method_str = '%s, %s, %s' % (reduction_str, optimization_str, prior_shift_str)

    expected_test_set_size = data['targets'].shape[0]
    expected_class_count = data['targets'].shape[1] // 2
    gt_labels = data['targets'].astype(np.int32)
    gt_domain = compute_gt_domain_from_labels(gt_labels)

    inference_thresholds = compute_thresh_on_dev(data['dev_outputs'], data['dev_targets'],
                                                 reduction_method=reduction_method)

    # Per-class share of positive test labels that belong to the MALE half.
    # A class without any positive label yields nan here.
    gender_count = np.sum(gt_labels, axis=0)
    target_domain_ratios = gender_count[:expected_class_count] / (
        gender_count[:expected_class_count] + gender_count[expected_class_count:])

    domain_labels = ['Male', 'Female']
    train_targets = data['train_targets']

    # NOTE(review): the sigmoid is applied here unconditionally, and the
    # optimizer applies it again when optimize_probabilities is false
    # (inputs_are_activations=True). Behaviour is kept as is - confirm whether
    # the second sigmoid is intended.
    twon_activations = sigmoid(data['outputs'])
    male_outputs = twon_activations[:, :expected_class_count]
    female_outputs = twon_activations[:, expected_class_count:]
    averaged_outputs = (male_outputs + female_outputs) / 2.0

    if reduction_method == 'sum':
        outputs = averaged_outputs
    else:
        # Use the output of the example's own domain; examples without a domain
        # cannot be conditioned and keep the average of both.
        is_male = (gt_domain == MALE)[:, np.newaxis]
        is_female = (gt_domain == FEMALE)[:, np.newaxis]
        outputs = np.where(is_male, male_outputs,
                           np.where(is_female, female_outputs, averaged_outputs))

    assert outputs.shape == (expected_test_set_size, expected_class_count)
    assert gt_domain.shape == (expected_test_set_size,)
    margin = 0.05
    lr = hparams['lr']
    input_potentials = outputs

    optimizer = optimize_potentials_given_known_domain(gt_labels=gt_labels,
                                                       gt_domain=gt_domain,
                                                       lr=lr,
                                                       margin=margin,
                                                       apply_prior_shift=apply_prior_shift,
                                                       inputs_are_activations=(not optimize_probabilities),
                                                       method_name=method_str,
                                                       target_domain_ratios=target_domain_ratios,
                                                       domain_labels=domain_labels,
                                                       inference_thresholds=inference_thresholds,
                                                       training_set_targets=train_targets,
                                                       verbosity=hparams['verbosity'],
                                                       total_epochs=hparams['total_epochs'])

    output_potentials, output_classes = optimizer.optimize(input_potentials)
    return output_potentials, output_classes


In [None]:
def _load_pickle(path):
    """Returns the object unpickled from the binary file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
celeba_labels_dict = _load_pickle('../data/celeba/labels_dict')
train_key_list = _load_pickle('../data/celeba/train_key_list')
dev_key_list = _load_pickle('../data/celeba/dev_key_list')
test_key_list = _load_pickle('../data/celeba/test_key_list')
subclass_idx = _load_pickle('../data/celeba/subclass_idx')

In [None]:
# One label array per split: row i holds the label vector of the i-th key.
train_target_array, dev_target_array, test_target_array = [
    np.array([celeba_labels_dict[key] for key in split_keys])
    for split_keys in (train_key_list, dev_key_list, test_key_list)
]

In [None]:
# Two-domain targets per split: the selected class columns masked by the last
# label column, followed by the same columns masked by its complement.
# NOTE(review): presumably the last label column is the gender attribute
# (1 = male), which would make the first half the MALE domain - confirm.
train_targets, dev_targets, test_targets = [
    np.hstack((split_array[:, subclass_idx] * split_array[:, -1:],
               split_array[:, subclass_idx] * (1 - split_array[:, -1:])))
    for split_array in (train_target_array, dev_target_array, test_target_array)
]

In [None]:
# Paths of the pickled dev/test inference results. Change these to point at the
# run whose outputs should be post-processed.
dev_result_path = '../record/celeba_domain_discriminative/celeba_domain_discriminative_e1/dev_result.pkl'
test_result_path = '../record/celeba_domain_discriminative/celeba_domain_discriminative_e1/test_result.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
with open(dev_result_path, 'rb') as f:
    dev = pickle.load(f)
with open(test_result_path, 'rb') as f:
    test = pickle.load(f)

In [None]:
# Keep, for every selected class, its output column and the column 39 places
# further on.
# NOTE(review): presumably the network emits 39 outputs per domain, so the
# +39 offset addresses the second domain's copy of each class - confirm.
dev_outputs, test_outputs = [
    result['output'][:, subclass_idx + [idx+39 for idx in subclass_idx]]
    for result in (dev, test)
]

In [None]:
# Inputs of run(): the test set is post-processed, the dev set is used to pick
# the per-class thresholds, and the training targets serve as the reference
# for bias amplification.
data = {'targets': test_targets, 'outputs': test_outputs, 
        'dev_targets': dev_targets, 'dev_outputs': dev_outputs,
        'train_targets': train_targets,
        }

In [None]:
# Settings read by run():
#   optimize_probabilities / reduction - how the two domain outputs are reduced
#     and whether the optimizer treats its inputs as probabilities,
#   prior_shift - forwarded to the optimizer,
#   lr / total_epochs - multiplier step size and maximum number of updates,
#   verbosity - 0 silent, >= 1 before/after summary, >= 2 per-epoch progress.
hparams = {'optimize_probabilities':True, 'reduction':'sum', 'prior_shift':False, 
           'lr': 1e-5, 'total_epochs': 300, 'verbosity': 1}

In [None]:
# Run the post-processing; the summary is printed by the optimizer and the
# returned potentials/predictions are discarded.
_, _ = run(hparams, data)