# Sampling Subsets with Gumbel-Top $k$ Relaxations

In this part we show how to include a subset sampling component in differentiable models using Gumbel Top $k$ relaxations.
First we show how to build a differentiable subset sampler, and then we show one application to differentiable $k$ nearest neighbor classification.

Formally speaking we are given $N$ elements with weights $w_i$.
We would like to sample $k$ elements from $N$ without replacement.
Stated otherwise, we want a $k$-element subset $S=\{w_{i_1}, w_{i_2},\ldots, w_{i_k}\}$ from $N$ elements.

Given total weight $Z=\sum w_i$, the first element is sampled with probability $\frac{w_{i_1}}{Z}$, the second with probability $\frac{w_{i_2}}{Z-w_{i_1}}$ and so on for $k$ elements.
Multiplying the factors gives the following distribution for $k$ element subsets.

$$ p(S) = \frac{w_{i_1}}{Z}  \frac{w_{i_2}}{Z-w_{i_1}}\cdots \frac{w_{i_k}}{Z-\sum_{j=1}^{k-1} w_{i_j}}.$$
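
As a quick sanity check of the formula above, the sketch below evaluates $p(S)$ for one ordered selection and checks that the probabilities of all ordered selections of size $k$ sum to one. The function name `ordered_subset_prob` and the example weights are illustrative assumptions and are not part of the reference code.

```python
from itertools import permutations

def ordered_subset_prob(weights, picks):
    """Probability of drawing the indices in `picks`, in that order, without replacement."""
    remaining = sum(weights)  # starts at the total weight Z
    prob = 1.0
    for i in picks:
        prob *= weights[i] / remaining
        remaining -= weights[i]  # remove the chosen weight from the normalizer
    return prob

weights = [1.0, 2.0, 3.0, 4.0]
p = ordered_subset_prob(weights, [3, 2])  # (4/10) * (3/6)
total = sum(ordered_subset_prob(weights, s) for s in permutations(range(4), 2))
print(p, total)
```

Note that the product depends on the order in which the elements are drawn, so the check sums over ordered selections.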

In the introduction we showed how sampling from a categorical distribution could be recast as choosing the argmax of a set of Gumbel random variables.
Relaxing the argmax with a softmax allowed us to approximate sampling from the target categorical distribution. 
A temperature could be used to control the extent of relaxation.
In this case the categorical probabilities are given by the softmax distribution
$$p_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)} = \frac{w_i}{\sum_j w_j},$$
where $x_i = \log w_i$ are the logits.

It turns out that by selecting the $k$ largest Gumbel random variables instead of just the largest we can sample subsets according to the sampling without replacement probability given above.
This procedure is closely related to a procedure known by the name of weighted reservoir sampling.

Seen this way, the Gumbel-Argmax trick is a method for sampling subsets of size $k=1$ with probabilities given by $p_i$.
Replacing the argmax by a Top-$k$ procedure for selecting the $k$ largest elements generalizes the Gumbel-Argmax to sample size-$k$ subsets with probability $p(S)$.
In this case we think of the Top-$k$ procedure as returning a $k$-hot vector $y$ where $y_i=1$ if the $i$th element is selected and $y_i=0$ otherwise.
Thus we represent subsets as $k$-hot vectors which also generalizes the representation of categorical samples as 1-hot vectors.

Given non-negative weights $w_i$, the unrelaxed subset sampling procedure can then be written as follows.

1. Compute keys $\hat{r_i} = -\log(-\log(u_i)) + \log(w_i)$ for all $i$, with $u_i \sim U(0,1)$.
2. Return the $k$ elements with the largest keys $\hat{r_i}$.
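
The two steps above can be sketched in plain Python as follows. This is a minimal illustration rather than the reference implementation: the function name `gumbel_top_k`, the guard against `u == 0`, and the example weights are assumptions. The loop at the end compares the empirical frequency of one ordered selection with the value predicted by $p(S)$.

```python
import math
import random

def gumbel_top_k(weights, k, rng=random):
    """Return the indices of the k largest Gumbel-perturbed log-weights, largest first."""
    keys = []
    for i, w in enumerate(weights):
        # Guard against u == 0, for which log(u) is undefined.
        u = max(rng.random(), 1e-300)
        keys.append((-math.log(-math.log(u)) + math.log(w), i))
    keys.sort(reverse=True)
    return [i for _, i in keys[:k]]

rng = random.Random(0)
weights = [1.0, 2.0, 3.0, 4.0]
n_draws = 20000
hits = sum(gumbel_top_k(weights, 2, rng) == [3, 2] for _ in range(n_draws))
freq = hits / n_draws
print(freq)  # should be near (4/10) * (3/6) = 0.2
```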

## Top $k$ Relaxation

TODO: ref

We can construct an unrelaxed Top $k$ by applying the softmax $k$ times and drawing a 1-hot categorical sample at each step.
The $k$ 1-hot categorical samples are then summed into a single $k$-hot vector.
When the categorical sample gives a particular element, the log probability for that element is set to $-\infty$ for future iterations so that the element is never chosen again. We can relax this procedure by replacing the samples from the softmax with the probabilities computed by the softmax. When the softmax temperature is small, the sampled and relaxed outputs are close.

In more detail the procedure is as follows.

### Unrelaxed Version
For $i=1\ldots n$ and $j=1\ldots k$, set $ \alpha^1_i = \hat{r_i}$ and $\alpha_i^{j+1} = \alpha_i^{j} + \log(1-a_i^j)$

Here $a^j$ is a 1-hot sample from the categorical distribution with probabilities $p(a^j_i = 1) = \frac{\exp(\alpha_i^{j}/\tau)}{\sum_l\exp(\alpha_l^{j}/\tau)}$ and $\tau$ is a temperature.

Note that when $a_i^j$ is a 1-hot categorical sample the $\log(1-a_i^j)$ term in the first equation above sets the next $\alpha_i^{j+1}$ to $-\infty$ if $a_i^j=1$ and leaves it unchanged otherwise.
This ensures that the $i$th element once sampled is not sampled in the next steps.
Finally we add all the $k$ vectors as $\sum_j a^j$ and return the output as the sample.
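
A minimal plain-Python sketch of this unrelaxed loop is given below. The names `softmax` and `unrelaxed_top_k` and the example keys are assumptions made for illustration. Since `math.log(0)` raises an error in Python, the sketch sets the chosen entry to `-inf` directly instead of adding $\log(1-a_i^j)$.

```python
import math
import random

def softmax(values, tau):
    """Numerically stable softmax of values / tau."""
    m = max(values)
    exps = [math.exp((v - m) / tau) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def unrelaxed_top_k(keys, k, tau=1.0, rng=random):
    """Draw k distinct indices one at a time and return them as a k-hot list."""
    alpha = list(keys)
    khot = [0] * len(keys)
    for _ in range(k):
        probs = softmax(alpha, tau)
        i = rng.choices(range(len(alpha)), weights=probs)[0]
        khot[i] += 1
        alpha[i] = -math.inf  # equivalent to adding log(1 - 1); never chosen again
    return khot

rng = random.Random(0)
khot = unrelaxed_top_k([1.0, 2.0, 3.0, 4.0], k=2, tau=1.0, rng=rng)
print(khot)
```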


### Relaxed Version
To relax the above procedure we can replace the categorical sample at each step by its expectation.
In this case the update becomes

For $i=1\ldots n$ and $j=1\ldots k$, set $ \alpha^1_i = \hat{r_i}$ and $\alpha_i^{j+1} = \alpha_i^{j} + \log(1-p(a_i^j=1))$

where $p(a^j_i = 1) = \frac{\exp(\alpha_i^{j}/\tau)}{\sum_l\exp(\alpha_l^{j}/\tau)}$ as above.
At low values of $\tau$ the softmax distribution becomes close to deterministic and the procedure outputs a value that is close to $k$-hot.
The temperature variable is a hyperparameter and ideally should be annealed from larger to smaller values during the course of training.
However, in most applications the temperature is left fixed per trial and tuned using cross validation.
Proper tuning of temperature can have a significant effect on the performance of the model.
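
The relaxed update can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the reference implementation: the names `softmax` and `relaxed_top_k`, the clamp value `eps`, and the example keys are hypothetical. The keys are passed in directly, so no Gumbel noise is added here.

```python
import math

def softmax(values, tau):
    """Numerically stable softmax of values / tau."""
    m = max(values)
    exps = [math.exp((v - m) / tau) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def relaxed_top_k(keys, k, tau=1.0, eps=1e-30):
    """Sum k successive softmax outputs, down-weighting entries already selected."""
    alpha = list(keys)
    khot = [0.0] * len(keys)
    probs = [0.0] * len(keys)  # nothing has been selected before the first step
    for _ in range(k):
        # alpha^{j+1} = alpha^j + log(1 - p^j), clamped so the log stays finite
        alpha = [a + math.log(max(1.0 - p, eps)) for a, p in zip(alpha, probs)]
        probs = softmax(alpha, tau)
        khot = [y + p for y, p in zip(khot, probs)]
    return khot

keys = [1.0, 2.0, 3.0, 4.0]
soft = relaxed_top_k(keys, k=2, tau=1.0)
sharp = relaxed_top_k(keys, k=2, tau=0.1)
print(soft, sum(soft))
print(sharp, sum(sharp))
```

Each softmax output sums to one, so the result sums to $k$ at any temperature. At the lower temperature the mass concentrates on the two largest keys.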


Code and paper reference for this tutorial: 
[Reparameterizable Subset Sampling via Continuous Relaxations](https://arxiv.org/abs/1901.10517)
[[Code](https://github.com/ermongroup/subsets)]

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import time
from pathlib import Path

import numpy as np

#from subsets.knn.utils import one_hot
from utils import one_hot
#from subsets.knn.models.preact_resnet import PreActResNet18
#from subsets.knn.models.easy_net import ConvNet
#from subsets.knn.dataset import DataSplit
from dataset import DataSplit
#from subsets.knn.dknn_layer import DKNN, SubsetsDKNN

In [35]:
# Fall back to the CPU when no CUDA device is available.
gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Subset Sampler Class

The following `SubsetOperator` class implements the relaxed subset sampling procedure described above.


In [28]:
EPSILON = np.finfo(np.float32).tiny

class SubsetOperator(torch.nn.Module):
    def __init__(self, k, tau=1.0, hard=False):
        super(SubsetOperator, self).__init__()
        self.k = k
        self.hard = hard
        self.tau = tau

    def forward(self, scores):
        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(scores), torch.ones_like(scores))
        g = m.sample()
        scores = scores + g

        # continuous top k
        khot = torch.zeros_like(scores)
        onehot_approx = torch.zeros_like(scores)
        for i in range(self.k):
            # clamp away from zero so the log below stays finite (works on CPU and GPU)
            khot_mask = torch.clamp(1.0 - onehot_approx, min=float(EPSILON))
            scores = scores + torch.log(khot_mask)
            onehot_approx = torch.nn.functional.softmax(scores / self.tau, dim=1)
            khot = khot + onehot_approx

        if self.hard:
            # will do straight through estimation if training
            khot_hard = torch.zeros_like(khot)
            val, ind = torch.topk(khot, self.k, dim=1)
            khot_hard = khot_hard.scatter_(1, ind, 1)
            res = khot_hard - khot.detach() + khot
        else:
            res = khot

        return res


You can try the sampler on some example input and various temperatures. Note that the sum of the vector's elements is always $k$.
At lower temperatures the output should be close to $k$-hot.

In [36]:
sampler = SubsetOperator(k=2, tau=1.0)

x = torch.tensor([[1.,2.,3.,4.]]).to(gpu)
y = sampler(x)
print(y, y.sum())

## Synthetic samples

TODO

## Application: Differentiable k Nearest Neighbors

knn loss
deep features


In [31]:
class SubsetsDKNN(torch.nn.Module):
    def __init__(self, k, tau=1.0, hard=False, num_samples=-1):
        super(SubsetsDKNN, self).__init__()
        self.k = k
        self.subset_sample = SubsetOperator(k=k, tau=tau, hard=hard)
        self.num_samples = num_samples

    # query: batch_size x p
    # neighbors: 10k x p
    def forward(self, query, neighbors, tau=1.0):
        diffs = (query.unsqueeze(1) - neighbors.unsqueeze(0))
        squared_diffs = diffs ** 2
        l2_norms = squared_diffs.sum(2)
        norms = l2_norms  # .sqrt() # M x 10k
        scores = -norms

        top_k = self.subset_sample(scores)
        return top_k

In [None]:

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.linear = nn.Linear(800, 500)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.linear(out))
        return out

Define hyperparameters

In [32]:
k = 9
tau = 1.0
NUM_TRAIN_QUERIES = 100
NUM_TEST_QUERIES = 10
NUM_TRAIN_NEIGHBORS = 100
LEARNING_RATE = 1e-3
NUM_SAMPLES = 5
#resume = args.resume
#method = args.method
NUM_EPOCHS = 100
EMBEDDING_SIZE = 500

In [33]:
dknn_layer = SubsetsDKNN(k, tau)

In [None]:
def dknn_loss(query, neighbors, query_label, neighbor_labels):
    # query: batch_size x p
    # neighbors: 10k x p
    # query_labels: batch_size x [10] one-hot
    # neighbor_labels: n x [10] one-hot

    # num_samples x batch_size x n
    start = time.time()
    top_k_ness = dknn_layer(query, neighbors)
    elapsed = time.time() - start
    correct = (query_label.unsqueeze(1) *
               neighbor_labels.unsqueeze(0)).sum(-1)  # batch_size x n
    correct_in_top_k = (correct.unsqueeze(0) * top_k_ness).sum(-1)
    loss = -correct_in_top_k

    return loss, elapsed
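
To make the loss concrete, here is a small plain-Python sketch of the same computation for one query and four neighbors. The helper name `top_k_correctness` and the toy labels and weights are assumptions made for illustration; the notebook's `dknn_loss` performs this computation on batched tensors.

```python
def top_k_correctness(query_label, neighbor_labels, top_k_weights):
    """Sum the relaxed top-k weights of neighbors that share the query's label."""
    total = 0.0
    for label, weight in zip(neighbor_labels, top_k_weights):
        # The dot product of two one-hot labels is 1 if they match and 0 otherwise.
        match = sum(q * n for q, n in zip(query_label, label))
        total += match * weight
    return total

query_label = [0, 1, 0]
neighbor_labels = [[0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
top_k_weights = [0.9, 0.1, 0.8, 0.2]  # a relaxed 2-hot vector (sums to 2)
loss = -top_k_correctness(query_label, neighbor_labels, top_k_weights)
print(loss)
```

Minimizing this loss pushes the relaxed top-$k$ mass toward neighbors whose label matches the query.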

In [None]:
h_phi = ConvNet().to(gpu)

In [None]:
optimizer = torch.optim.SGD(
    h_phi.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)


In [None]:
split = DataSplit('mnist')

batched_query_train = split.get_train_loader(NUM_TRAIN_QUERIES)
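batched_neighbor_train = split.get_train_loader(NUM_TRAIN_NEIGHBORS)
# Loaders used by test() below. This assumes DataSplit also exposes
# get_valid_loader and get_test_loader, as in the reference repository.
batched_query_val = split.get_valid_loader(NUM_TEST_QUERIES)
batched_query_test = split.get_test_loader(NUM_TEST_QUERIES)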


In [None]:
def train(epoch):
    timings = []
    h_phi.train()
    to_average = []
    # train
    for query, candidates in zip(batched_query_train, batched_neighbor_train):
        optimizer.zero_grad()
        cand_x, cand_y = candidates
        query_x, query_y = query

        cand_x = cand_x.to(device=gpu)
        cand_y = cand_y.to(device=gpu)
        query_x = query_x.to(device=gpu)
        query_y = query_y.to(device=gpu)

        neighbor_e = h_phi(cand_x).reshape(NUM_TRAIN_NEIGHBORS, EMBEDDING_SIZE)
        query_e = h_phi(query_x).reshape(NUM_TRAIN_QUERIES, EMBEDDING_SIZE)

        neighbor_y_oh = one_hot(cand_y).reshape(NUM_TRAIN_NEIGHBORS, 10)
        query_y_oh = one_hot(query_y).reshape(NUM_TRAIN_QUERIES, 10)

        losses, timing = dknn_loss(query_e, neighbor_e, query_y_oh, neighbor_y_oh)
        timings.append(timing)
        loss = losses.mean()
        loss.backward()
        optimizer.step()
        to_average.append((-loss).item() / k)

    print('Avg. train correctness of top k:',
          sum(to_average) / len(to_average))
    #print('Avg. train correctness of top k:', sum(
    #    to_average) / len(to_average), file=logfile)
    #print('Avg. time per dkNN step:', np.mean(timings))
    #print('Avg. time per dkNN step:', np.mean(timings), file=logfile)
    #logfile.flush()


In [None]:
def majority(lst):
    return max(set(lst), key=lst.count)


def new_predict(query, neighbors, neighbor_labels):
    '''
    query: M x p
    neighbors: n x p
    neighbor_labels: n (int)
    '''
    diffs = (query.unsqueeze(1) - neighbors.unsqueeze(0))  # M x n x p
    squared_diffs = diffs ** 2
    norms = squared_diffs.sum(-1)  # M x n
    indices = torch.argsort(norms, dim=-1)
    labels = neighbor_labels.take(indices[:, :k])  # M x k
    prediction = [majority(l.tolist()) for l in labels]
    return torch.Tensor(prediction).to(device=gpu).long()


def acc(query, neighbors, query_label, neighbor_labels):
    prediction = new_predict(query, neighbors, neighbor_labels)
    return (prediction == query_label).float().cpu().numpy()

In [None]:

def test(epoch, val=False):
    h_phi.eval()
    global best_acc
    with torch.no_grad():
        embeddings = []
        labels = []
        for neighbor_x, neighbor_y in batched_neighbor_train:
            neighbor_x = neighbor_x.to(device=gpu)
            neighbor_y = neighbor_y.to(device=gpu)
            embeddings.append(h_phi(neighbor_x))
            labels.append(neighbor_y)
        neighbors_e = torch.stack(embeddings).reshape(-1, EMBEDDING_SIZE)
        labels = torch.stack(labels).reshape(-1)

        results = []
        for queries in batched_query_val if val else batched_query_test:
            query_x, query_y = queries
            query_x = query_x.to(device=gpu)
            query_y = query_y.to(device=gpu)
            query_e = h_phi(query_x)  # batch_size x embedding_size
            results.append(acc(query_e, neighbors_e, query_y, labels))
        total_acc = np.mean(np.array(results))

    split = 'val' if val else 'test'
    print('Avg. %s acc:' % split, total_acc)
    #print('Avg. %s acc:' % split, total_acc, file=logfile)
    #if total_acc > best_acc and val:
    #    print('Saving...')
    #    state = {
    #        'net': h_phi.state_dict(),
    #        'acc': total_acc,
    #        'epoch': epoch,
    #    }
    #    torch.save(state, chkpt_path)
    #    best_acc = total_acc


In [None]:
for t in range(NUM_EPOCHS):
    print('Beginning epoch %d: ' % t)
    train(t)
    test(t, val=True)