# Sensitive Set Transport Invariant Ranking (SenSTIR) demo

This notebook replicates the synthetic experiment shown in Figure 1 of [Individually Fair Rankings](https://openreview.net/pdf?id=71zCSP_HuBN).

## Synthetic data generation

In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data.sampler import RandomSampler, BatchSampler
from torch.utils.data import IterableDataset

In [2]:
def generate_synthetic_LTR_data(majority_proportion = .8, num_queries = 100, num_docs_per_query = 10, seed=0):
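    np.random.seed(seed)  # seed NumPy's global RNG so the synthetic data is reproducible
    num_items = num_queries*num_docs_per_query
    X = np.random.uniform(0,3, size = (num_items,2)).astype(np.float32)
    # relevance is the sum of both features, clipped to [0, 5]
    relevance = X[:,0] + X[:,1]
    relevance = np.clip(relevance, 0.0,5.0)
    majority_status = np.random.choice([True, False], size=num_items, p=[majority_proportion, 1-majority_proportion])
    # zero the second feature of minority items after relevance is computed,
    # so their features understate their relevance
    X[~majority_status, 1] = 0
    return [{"Q":X[i], "relevances":relevance[i], "majority_status":majority_status[i]} for i in range(num_items)]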

In [3]:
class QueryIterableDataset(IterableDataset):
    '''
    iterable dataset that takes a set of items and indefinitely samples sets of such items (queries) per iteration
    '''
    def __init__(self, items_dataset, shuffle, query_size):
        self.dataset = items_dataset
        self.query_size = query_size
        self.shuffle = shuffle

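    def __iter__(self):
        # build the index generator once so its random state advances across queries
        idx = self._infinite_indices()
        while True:
            query = [self.dataset[i] for i in next(idx)]
            query = torch.utils.data.default_collate(query)
            yield query

    def _infinite_indices(self):
        # seed per worker so that different DataLoader workers draw different queries
        worker_info = torch.utils.data.get_worker_info()
        seed = 0 if worker_info is None else worker_info.id
        g = torch.Generator()
        g.manual_seed(seed)
        n = len(self.dataset)
        start = 0
        while True:
            if self.shuffle:
                yield torch.randperm(n, generator=g)[:self.query_size].tolist()
            else:
                # assumption: without shuffling, walk through the items in order and wrap around
                yield [(start + j) % n for j in range(self.query_size)]
                start = (start + self.query_size) % n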

In [4]:
num_docs_per_query = 10
num_queries = 100
dataset_train = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)
dataloader = torch.utils.data.DataLoader(QueryIterableDataset(dataset_train, True, num_docs_per_query), num_workers=2, batch_size=2)
# the data loader yields batches of queries: relevances (batch x num_docs_per_query) and features (batch x num_docs_per_query x num_features)
next(iter(dataloader))

{'Q': tensor([[[2.2843e+00, 1.1995e+00],
          [1.1719e+00, 7.5232e-01],
          [2.9315e+00, 2.0143e+00],
          [1.0673e+00, 2.1543e+00],
          [2.9508e-03, 0.0000e+00],
          [1.7279e-01, 1.6969e+00],
          [1.9283e+00, 2.1673e+00],
          [4.0874e-01, 2.7975e+00],
          [1.6565e+00, 2.6349e+00],
          [2.3515e+00, 1.8551e+00]],
 
         [[5.8831e-01, 1.5622e+00],
          [3.9876e-01, 1.7910e+00],
          [2.0933e+00, 1.2158e+00],
          [2.8020e+00, 0.0000e+00],
          [6.3734e-01, 1.5935e-01],
          [1.2637e-01, 2.7408e+00],
          [2.8980e+00, 9.8164e-01],
          [3.0247e-01, 2.9643e+00],
          [2.2433e+00, 1.4692e+00],
          [2.3821e-01, 0.0000e+00]]]),
 'relevances': tensor([[3.4838, 1.9243, 4.9458, 3.2215, 2.1270, 1.8697, 4.0956, 3.2063, 4.2914,
          4.2067],
         [2.1506, 2.1897, 3.3090, 3.7945, 0.7967, 2.8672, 3.8797, 3.2667, 3.7125,
          2.1767]]),
 'majority_status': tensor([[ True,  True,  True,  

## Fair distance learning

The auditor needs a Wasserstein distance between queries to generate the worst-case example (q' in the paper). That distance is built on a fair metric over items, which we learn here.

In [5]:
# we perform a logistic regression on the dataset to build a sensitive direction
from sklearn.linear_model import LogisticRegression
all_data = torch.utils.data.default_collate(dataset_train)
x = all_data['Q']
majority_status = all_data['majority_status']
LR = LogisticRegression(C = 100).fit(x, majority_status)
sens_directions = torch.tensor(LR.coef_,dtype=torch.float32).T
print('sensitive directions', sens_directions)


sensitive directions tensor([[-0.5045],
        [50.7707]])


As we can see, the logistic regression puts a large weight on the second dimension. The data generation process produces this correlation artificially, since it sets the second feature of every minority item to zero.
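
To make the role of the sensitive direction concrete, the sketch below builds a distance that ignores movement along that direction by projecting it out. It assumes the sensitive-subspace distance has the form $(x_1 - x_2)^T (I - P)(x_1 - x_2)$, where $P$ is the orthogonal projector onto the span of the sensitive directions. The helper names `projection_complement` and `sensitive_subspace_sq_distance` are illustrative and are not part of inFairness.

```python
import numpy as np

def projection_complement(directions):
    """Return I - P, where P projects onto the span of the columns of `directions`."""
    A = np.asarray(directions, dtype=np.float64)
    # P = A (A^T A)^+ A^T ; the pseudo-inverse also copes with collinear columns
    P = A @ np.linalg.pinv(A.T @ A) @ A.T
    return np.eye(A.shape[0]) - P

def sensitive_subspace_sq_distance(x1, x2, sigma):
    """Squared Mahalanobis-type distance (x1 - x2)^T sigma (x1 - x2)."""
    diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
    return float(diff @ sigma @ diff)

# sensitive direction printed above, as a (num_features x 1) column
sens = np.array([[-0.5045], [50.7707]])
sigma = projection_complement(sens)

# two items that differ only in the second (sensitive) feature
d_sensitive = sensitive_subspace_sq_distance([1.0, 0.0], [1.0, 2.0], sigma)
# two items that differ only in the first feature
d_relevant = sensitive_subspace_sq_distance([1.0, 1.0], [3.0, 1.0], sigma)
print(d_sensitive, d_relevant)
```

Under this assumption, items that differ only in the second feature end up almost at distance zero, while a difference in the first feature is kept nearly unchanged.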

### Batched Wasserstein Distance

To audit the model we need to compute a Wasserstein distance between sets of items. This distance can be built by using a Mahalanobis distance as the pairwise cost between the items of each set, and the sensitive direction we just learned can be used to build this Mahalanobis distance.
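
As a rough illustration of this construction, the sketch below computes an optimal-transport cost between two equal-size sets of items, using a squared Mahalanobis distance as the pairwise cost. It assumes uniform weights of 1/n on each item, in which case an optimal transport plan is a one-to-one matching, and it finds that matching by brute force over permutations. The function names and the matrix `sigma` are hypothetical; this is not how `BatchedWassersteinDistance` is implemented, and the brute-force search is only practical for very small sets.

```python
from itertools import permutations
import numpy as np

def pairwise_sq_mahalanobis(xs, ys, sigma):
    """Cost matrix C[i, j] = (xs[i] - ys[j])^T sigma (xs[i] - ys[j])."""
    diff = xs[:, None, :] - ys[None, :, :]  # shape (n, m, d)
    return np.einsum("nmd,de,nme->nm", diff, sigma, diff)

def set_wasserstein(xs, ys, sigma):
    """Transport cost between two equal-size sets with uniform weights (brute force)."""
    cost = pairwise_sq_mahalanobis(xs, ys, sigma)
    n = cost.shape[0]
    # with uniform weights, some optimal plan is a permutation matching
    best = min(sum(cost[i, p[i]] for i in range(n)) for p in permutations(range(n)))
    return best / n

xs = np.array([[0.0, 0.0], [1.0, 5.0]])
ys = np.array([[1.0, -3.0], [0.0, 9.0]])

# hypothetical metric that ignores the second feature entirely
sigma_fair = np.diag([1.0, 0.0])
# plain squared Euclidean metric, for comparison
sigma_euclid = np.eye(2)

w_fair = set_wasserstein(xs, ys, sigma_fair)
w_euclid = set_wasserstein(xs, ys, sigma_euclid)
print(w_fair, w_euclid)
```

With the metric that ignores the second feature, the two sets can be matched at zero cost, so they count as the same query for the auditor even though their second features differ widely.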

In [6]:
from inFairness.distances import SensitiveSubspaceDistance, BatchedWassersteinDistance

In [7]:
distance_q = BatchedWassersteinDistance(SensitiveSubspaceDistance())
distance_q.fit(sens_directions)

## Model and Output distance

We also need a model we would like to train and to define a distance in the output space.

In [18]:
import torch.nn.functional as F
class MultilayerPerceptron(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model1 = MultilayerPerceptron()

In [14]:
from inFairness.distances import SquaredEuclideanDistance
distance_y = SquaredEuclideanDistance()
distance_y.fit(num_dims=num_docs_per_query)

It's worth noting that in the output space we measure distances between sets of scores (one score per document in a query). This is why the SquaredEuclideanDistance above is fitted with `num_dims=num_docs_per_query`.
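
As a small illustration of that point, the sketch below treats the scores of one query as a single vector and computes one squared Euclidean distance per query in a batch. The function name `squared_euclidean` and the example scores are hypothetical, and three documents per query are used only to keep the example short.

```python
import numpy as np

def squared_euclidean(scores_a, scores_b):
    """Squared Euclidean distance between score vectors, one value per query.

    scores_a, scores_b: arrays of shape (batch, num_docs_per_query).
    """
    diff = np.asarray(scores_a, dtype=np.float64) - np.asarray(scores_b, dtype=np.float64)
    return (diff ** 2).sum(axis=-1)

# two queries with three documents each
scores_a = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
scores_b = np.array([[1.0, 0.0, 3.0], [1.0, 1.0, 1.0]])
distances = squared_euclidean(scores_a, scores_b)  # shape (batch,)
print(distances)
```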

## SenSTIR

In [19]:
from inFairness.fairalgo import SenSTIR
fairalgo1 = SenSTIR(
    network=model1,
    distance_q=distance_q,
    distance_y=distance_y,
    rho=0.01,
    eps=1.0,
    auditor_nsteps=100,
    auditor_lr=0.01,
    monte_carlo_samples_ndcg=20,
)

## Training loop

In [20]:
from trainer import Trainer
trainer = Trainer(
    dataloader=dataloader,
    model=fairalgo1,
    optimizer=torch.optim.Adam(fairalgo1.parameters(),lr=0.01),
    max_iterations = 100
)

In [21]:
%%time
trainer.train()

CPU times: user 1h 1min 58s, sys: 5.4 s, total: 1h 2min 3s
Wall time: 1min 25s
