In [63]:
a = torch.arange(6, dtype=torch.float32).reshape(3,2)
b = torch.cat([a,a + 1])
ab = a.repeat(2,1)
ab_b = ab-b
z = torch.eye(2)
print(torch.einsum("ni,mj,ij->nm",a,b,z))
print(a,b)

tensor([[ 1.,  3.,  5.,  2.,  4.,  6.],
        [ 3., 13., 23.,  8., 18., 28.],
        [ 5., 23., 41., 14., 32., 50.]])
tensor([[0., 1.],
        [2., 3.],
        [4., 5.]]) tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [1., 2.],
        [3., 4.],
        [5., 6.]])


In [1]:
from functorch import vmap

In [63]:
sigma = torch.eye(10).to('cuda')
#let's try with the mahalanobis distance now
def mahalanobis_distance(x,y,sigma):
    diff = x-y
    return torch.einsum("i,ij,j",diff,sigma,diff)
d1 = vmap(mahalanobis_distance, in_dims=(None,0,None))
d2 = vmap(d1, in_dims=(0,None,None))
d3 = vmap(d2, in_dims=(0,0,None))


In [53]:
def mahalanobis_vectorized(x,y,sigma):
    x2 = torch.einsum("bnj,ji,bni->bn",x,sigma,x)[:,:,None]
    y2 = torch.einsum("bmj,ji,bmi->bm",y,sigma,y)[:,None,:]
    a = 2*torch.einsum("bnj,ji,bmi->bnm",x,sigma,y)
    return x2 - a + y2

In [64]:
%%timeit
x = torch.randn(1000,10,10).to('cuda')
y = torch.randn(1000,10,10).to('cuda')
mahalanobis_vectorized(x,y,sigma)

1.49 ms ± 4.74 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [65]:
%%timeit
x = torch.randn(1000,10,10).to('cuda')
y = torch.randn(1000,10,10).to('cuda')
d3(x,y,sigma)

2.15 ms ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Synthetic data generation

In [24]:
import numpy as np
import torch
from torch.utils.data.sampler import RandomSampler, BatchSampler
from torch.utils.data import IterableDataset

np.random.uniform(0,3, size = (4,2)).astypepe(np.float32)

array([[0.25056785, 2.9974692 ],
       [0.90872806, 0.43962702],
       [0.08038563, 0.36455625],
       [2.2616425 , 0.6200977 ]], dtype=float32)

In [25]:
def generate_synthetic_LTR_data(majority_proportion = .8, num_queries = 100, num_docs_per_query = 10, seed=0):
    num_items = num_queries*num_docs_per_query
    X = np.random.uniform(0,3, size = (num_items,2)).astype(np.float32)
    relevance = X[:,0] + X[:,1]

    # i don't know why but the "fair policy" paper clips the values between 0 and 5
    relevance = np.clip(relevance, 0.0,5.0)
    majority_status = np.random.choice([True, False], size=num_items, p=[majority_proportion, 1-majority_proportion])
    X[~majority_status, 1] = 0
    return [{"X":X[i], "relevance":relevance[i], "majority_status":majority_status[i]} for i in range(num_items)]

In [26]:
class QueryIterableDataset(IterableDataset):
    '''
    iterable dataset that takes a set of items and indifintely samples sets of such items (queries) per iteration
    '''
    def __init__(self, items_dataset, shuffle, query_size):
        self.dataset = items_dataset
        self.query_size = query_size
        self.shuffle = shuffle

    def __iter__(self):
        while True:
            idx = self._infinite_indices()
            query = [self.dataset[i] for i in next(idx)]
            query = torch.utils.data.default_collate(query)
            yield query
    
    def _infinite_indices(self):
        worker_info = torch.utils.data.get_worker_info()
        seed = 0 if worker_info is None else worker_info.id
        g = torch.Generator()
        g.manual_seed(seed)
        while True:
            if self.shuffle:
                idx = (torch.randperm(len(self.dataset))[:self.query_size]).tolist()
                yield idx

In [34]:
num_docs_per_query = 10
num_queries = 100
dataset_train = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)
dataloader = torch.utils.data.DataLoader(QueryIterableDataset(dataset_train, True, num_docs_per_query), num_workers=2, batch_size=2)
#the data loader gets a batch of queries with relevance (batch x num_items_per_query) and features (batch x num_items_per_query x num_features)
next(iter(dataloader))

{'X': tensor([[[2.1182, 0.6758],
          [2.9103, 1.5819],
          [1.7527, 2.1017],
          [1.6029, 2.3631],
          [1.7725, 0.2491],
          [0.8376, 2.6814],
          [1.5226, 0.1447],
          [2.1234, 1.0957],
          [2.0005, 2.3674],
          [2.7640, 0.0000]],
 
         [[1.9271, 0.0000],
          [1.1964, 2.0993],
          [0.2173, 1.3909],
          [1.5868, 1.4924],
          [1.8643, 2.9888],
          [2.9540, 0.1448],
          [2.7450, 2.0327],
          [0.3358, 0.9012],
          [0.7998, 1.8257],
          [0.6313, 0.0131]]]),
 'relevance': tensor([[2.7940, 4.4922, 3.8544, 3.9660, 2.0216, 3.5190, 1.6673, 3.2191, 4.3679,
          4.0509],
         [3.5637, 3.2957, 1.6082, 3.0792, 4.8531, 3.0989, 4.7777, 1.2370, 2.6255,
          0.6444]]),
 'majority_status': tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [False,  True,  True,  True,  True,  True,  True,  True,  True,  True]])}

# fair distance learning

This is necessary to compute the wasserstein distance on the worst example generation (q' in the paper)

In [29]:
# we perform a logistic regression on the dataset to build a sensitive direction
from sklearn.linear_model import LogisticRegression
all_data = torch.utils.data.default_collate(dataset_train)
x = all_data['X']
majority_status = all_data['majority_status']
LR = LogisticRegression(C = 100).fit(x, majority_status)
sens_directions = torch.tensor(LR.coef_,dtype=torch.float32).T
print('sensitive directions', sens_directions)


sensitive directions tensor([[-0.3408],
        [52.3883]])


As we can see, the logistic regression finds a high sensitivity on the second dimension, the data generation process artificially produces this high correlation.

In [30]:
# with this sensitive direction we can build a mahalanobis distance by passing a set of vectors
from inFairness.distances import SensitiveSubspaceDistance

sigma = SensitiveSubspaceDistance().compute_projection_complement(sens_directions)
sigma

tensor([[9.9996e-01, 6.5045e-03],
        [6.5045e-03, 4.2319e-05]])

In [31]:
#vectorized version of the mahalanobis distance
from functorch import vmap
def md(x,y,sigma):
    '''
    computes the mahalanobis distance between 2 vectors of D dimensions:
    
    .. math:: MD = (x - y) \\Sigma (x - y)^{'}
    
    takes a B,N,D and B,M,D batches of items and returns B,N,M matrix of costs.
    '''
    diff = x-y
    return torch.einsum("i,ij,j",diff,sigma,diff)
md1 = vmap(md, in_dims=(None,0,None))
md2 = vmap(md1, in_dims=(0,None,None))
vect_mahalanobis_distance = vmap(md2, in_dims=(0,0,None))

In [33]:
x,y = next(iter(dataloader))['X'], next(iter(dataloader))['X']
vect_mahalanobis_distance(x,y,sigma)

torch.Size([2, 10, 10])

# wasserstein distance

we can use a sinkhorn distance with low blur to approximate the Wasserstein distance. For more documentation go [here](https://www.kernel-operations.io/geomloss/api/pytorch-api.html)

Here we optimize a parameter to be close to x.

In [73]:
%%time
from geomloss import SamplesLoss

wasserstein_distance = SamplesLoss('sinkhorn',cost=lambda x,y: vect_mahalanobis_distance(x,y,sigma), blur=0.005)
wasserstein_distance(x,y)
x_prime = torch.nn.Parameter(torch.rand_like(y))

optimizer = torch.optim.Adam([x_prime], lr=0.001)
print(((x_prime - x)**2).sum())
for i in range(15000):
    optimizer.zero_grad()
    loss = wasserstein_distance(x,x_prime).sum()
    if i%1000 == 0:
        print('loss', loss)
    loss.backward()
    optimizer.step()


tensor(48.5409, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
loss tensor(0.3027, grad_fn=<SumBackward0>)
CPU times: user 59min 46s, sys: 4.57 s, total: 59min 51s
Wall time: 1min 19s


In [77]:
# the firt feature of each item is very close to one element in the optimized item
print('x', x[0,:,0])
print('x\'', x_prime[0,:,0])

x tensor([2.6547, 2.2641, 2.6876, 2.9700, 2.7404, 0.6154, 0.1087, 2.9811, 0.8636,
        1.8676])
x' tensor([2.2528, 0.1171, 0.8751, 2.9813, 0.8752, 2.6613, 2.7359, 1.8576, 2.9516,
        0.8711], grad_fn=<SelectBackward0>)


Note that the elements of the first column (per batch) are all very close to some element in the optized set of queries. Given the logistic regression, the projection complement of this basis vector makes the resulting mahalanobis distance to be much more sensitive to differences in the first dimmension than in the second (per item). By using this as the cost function of the Wasserstein distance, only the first dimmension gets 

In [40]:
num_docs_per_query = 10
num_queries = 100
dataset = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)

sampler = RandomSampler(data_source=dataset)
query_sampler = BatchSampler(
    sampler, num_docs_per_query, drop_last=True
)
batch_sampler = BatchSampler(
    query_sampler, 11, drop_last = True
)

dataloader = build_train_loader(dataset, num_docs_per_query, 11)


In [37]:
len(next(iter(batch_sampler))[0])

10

In [43]:
i = 0
for d in dataloader:
    print(i)
    i +=1
    if i>100000:
        break

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [7]:
num_docs_per_query = 10
num_queries = 100
dataset = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)
dataloader = build_train_loader(data
# X_queries, relevances, majority_status = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)
# X_queries_test, relevances_test, majority_status_test = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)

In [12]:
np.random.choice(100, size=(10,2))

array([[64, 76],
       [33, 87],
       [64, 20],
       [ 6, 48],
       [72, 74],
       [75, 33],
       [65, 62],
       [97, 92],
       [76, 70],
       [68, 30]])