# Sampled Softmax

For classification and prediction problems, a typical criterion function is cross-entropy with softmax. If the number of output classes is high, computing this criterion and the corresponding gradients can be quite costly. Sampled softmax is a heuristic to speed up training in these cases.

## Basics

The softmax function is used in neural networks if we want to interpret the network output as a probability distribution over a set of classes $C$ with $|C|=N_C$.

Softmax maps an $N_C$-dimensional vector $z$, which has unrestricted values, to an $N_C$-dimensional vector $p$ with non-negative values that sum up to 1, so that they can be interpreted as probabilities. More precisely:

$$
\begin{align}
p_i &= softmax(z, i)\\
    &= \frac{exp(z_i)}{\sum_{k\in C} exp(z_k)}\\
\end{align}
$$
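
As a quick numerical illustration of this definition, the following NumPy sketch computes the softmax of a small vector. The function name `softmax` and the subtraction of the maximum are our own additions for illustration: subtracting the maximum does not change the result, but it avoids overflow in `exp`.

```python
import numpy as np

def softmax(z):
    """Map an unrestricted vector z to a probability vector p."""
    z = np.asarray(z, dtype=np.float64)
    shifted = z - np.max(z)      # same result, but exp() cannot overflow
    e = np.exp(shifted)
    return e / np.sum(e)

p = softmax([2.0, 1.0, -1.0])
print(p, p.sum())
```

All entries of `p` are non-negative and sum to 1, and the largest input gets the largest probability.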

In what follows we assume that the input $z$ to the softmax is computed from some hidden vector $h$ of dimension $N_h$  in a specific way, namely:

$$ z = W h + b $$

where $W$ is a learnable weight matrix of dimension $(N_C, N_h)$ and $b$ is a learnable bias vector.
We restrict ourselves to this specific choice of $z$ because it helps in implementing an efficient sampled softmax.

In a typical use case, for example a recurrent language model, the hidden vector $h$ would be the output of the recurrent layers and $C$ would be the set of words to predict.

As a training criterion, we use cross-entropy which is a function of the expected (true) class $t\in C$ and the probability predicted for it:

$$cross\_entropy := -log(p_t)$$
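
To make the chain from $h$ to the criterion concrete, here is a minimal NumPy sketch with made-up toy sizes and random values (none of these numbers come from the model discussed here). It computes $z = W h + b$, the log of the softmax, and the cross-entropy for a true class $t$.

```python
import numpy as np

rng = np.random.default_rng(0)
N_C, N_h = 5, 3                    # number of classes, hidden dimension (toy values)

W = rng.normal(size=(N_C, N_h))    # weight matrix of shape (N_C, N_h)
b = np.zeros(N_C)                  # bias vector
h = rng.normal(size=N_h)           # hidden vector
t = 2                              # index of the true class

z = W @ h + b
# log p_i = z_i - log(sum_k exp(z_k)), computed in a numerically stable way
log_p = z - (np.max(z) + np.log(np.sum(np.exp(z - np.max(z)))))
cross_entropy = -log_p[t]
print(cross_entropy)
```

Note that the sum in the denominator runs over all $N_C$ classes, which is the cost that sampled softmax tries to avoid.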

## Sampled Softmax from the outside

For the normal softmax, the CNTK Python API provides the function [cross_entropy_with_softmax](https://cntk.ai/pythondocs/cntk.ops.html?highlight=softmax#cntk.ops.cross_entropy_with_softmax). It takes the $N_C$-dimensional vector $z$ as input. As mentioned, for our sampled softmax implementation we assume that this $z$ is computed by $z = W h + b$. In sampled softmax, this computation has to be part of the implementation of the criterion itself.

Below we show the code for `cross_entropy_with_sampled_softmax_and_embedding`. Let's look at the signature first.

One fundamental difference to the corresponding function in the Python API (`cross_entropy_with_softmax`) is that in the API function the input corresponds to $z$ and must have the same dimension as the target vector, while in `cross_entropy_with_sampled_softmax_and_embedding` the input corresponds to our hidden vector $h$ and can have any dimension (`hidden_dim`).
In fact, `hidden_dim` will typically be much lower than the dimension of the target vector.

We also have some additional parameters (`num_samples`, `sampling_weights`, `allow_duplicates`) that control the random sampling.
Another difference to the API function is that we return a triple `(z, cross_entropy_on_samples, error_on_samples)`.

We will come back to the details of the implementation below.


In [8]:
import numpy as np
import os
import cntk as C

# Creates a subgraph computing cross-entropy with sampled softmax.
def cross_entropy_with_sampled_softmax_and_embedding(
    hidden_vector,            # Node providing hidden input
    target_vector,            # Node providing the expected labels (as sparse vectors)
    vocab_dim,                # Vocabulary size
    hidden_dim,               # Dimension of the hidden vector
    num_samples,              # Number of samples to use for sampled softmax
    sampling_weights,         # Node providing weights to be used for the weighted sampling
    allow_duplicates = False, # Boolean flag to control whether to use sampling with replacement
                              # (allow_duplicates == True) or without replacement.
    ):
    # Define the learnable parameters
    b = C.Parameter(shape = (vocab_dim, 1), init = C.init_bias_default_or_0)
    W = C.Parameter(shape = (vocab_dim, hidden_dim), init = C.init_default_or_glorot_uniform)

    # Define the node that generates a set of random samples per minibatch
    # Sparse matrix (num_samples * vocab_dim)
    sample_selector = C.random_sample(sampling_weights, num_samples, allow_duplicates)

    # For each of the samples we also need the probability that it is in the sampled set.
    inclusion_probs = C.random_sample_inclusion_frequency(sampling_weights, num_samples, allow_duplicates) # dense row [1 * vocab_size]
    log_prior = C.log(inclusion_probs) # dense row [1 * vocab_dim]

    # Create a submatrix W_sampled of the weights W, containing the rows for the sampled classes
    W_sampled = C.times(sample_selector, W) # [num_samples * hidden_dim]
    z_sampled = C.times_transpose(W_sampled, hidden_vector) + C.times(sample_selector, b) - C.times_transpose (sample_selector, log_prior)# [num_samples]

    # Getting the weight vector for the true label. Dimension hidden_dim
    W_target = C.times(target_vector, W) # [1 * hidden_dim]
    z_target = C.times_transpose(W_target, hidden_vector) + C.times(target_vector, b) - C.times_transpose(target_vector, log_prior) # [1]


    z_reduced = C.reduce_log_sum(z_sampled)
    
    # Compute the cross entropy that is used for training.
    # We don't check whether any of the classes in the random samples coincides with the true label,
    # so the true class might be counted twice in the normalising denominator of sampled softmax.
    cross_entropy_on_samples = C.log_add_exp(z_target, z_reduced) - z_target

    # For applying the model we also output a node providing the input for the full softmax
    z = C.times_transpose(W, hidden_vector) + b
    z = C.reshape(z, shape = (vocab_dim))

    zSMax = C.reduce_max(z_sampled)
    error_on_samples = C.less(z_target, zSMax)
    return (z, cross_entropy_on_samples, error_on_samples)
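
The criterion computed by the function above can be mirrored in plain NumPy, which may make the individual steps easier to follow. This is only a sketch under stated assumptions, not the CNTK implementation: the name `sampled_softmax_cross_entropy` is hypothetical, classes are drawn with replacement, the bias is a one-dimensional vector, and the "inclusion frequency" of class $i$ is taken to be its expected number of occurrences in the sample, `num_samples * q[i]`.

```python
import numpy as np

def sampled_softmax_cross_entropy(h, t, W, b, sampling_weights, num_samples, rng):
    """Cross-entropy over the true class t plus num_samples randomly drawn classes."""
    q = sampling_weights / np.sum(sampling_weights)
    # Assumption: sampling with replacement (duplicates allowed).
    samples = rng.choice(len(q), size=num_samples, replace=True, p=q)
    # Assumption: expected number of occurrences of each class in the sample.
    log_prior = np.log(num_samples * q)

    # Logits of the sampled classes and of the true class, corrected by the log prior.
    z_sampled = W[samples] @ h + b[samples] - log_prior[samples]
    z_target = W[t] @ h + b[t] - log_prior[t]

    # -log( exp(z_target) / (exp(z_target) + sum_s exp(z_sampled[s])) )
    all_z = np.concatenate(([z_target], z_sampled))
    m = np.max(all_z)
    log_denominator = m + np.log(np.sum(np.exp(all_z - m)))
    return log_denominator - z_target

rng = np.random.default_rng(0)
vocab_dim, hidden_dim = 1000, 16          # toy sizes
W = rng.normal(scale=0.1, size=(vocab_dim, hidden_dim))
b = np.zeros(vocab_dim)
h = rng.normal(size=hidden_dim)
loss = sampled_softmax_cross_entropy(
    h, t=7, W=W, b=b, sampling_weights=np.ones(vocab_dim), num_samples=20, rng=rng)
print(loss)
```

Only `num_samples + 1` rows of `W` are touched per example, instead of all `vocab_dim` rows as in the full softmax. As in the CNTK code, the true class is not removed from the random sample, so it can appear twice in the denominator.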




To give a better idea of what the inputs and outputs are and how this differs from the normal softmax, we give below a corresponding function using the normal softmax:

In [9]:
# Creates subgraph computing cross-entropy with (full) softmax.
def cross_entropy_with_softmax_and_embedding(
    hidden_vector,  # Node providing hidden input
    target_vector,  # Node providing the expected labels (as sparse vectors)
    vocab_dim,      # Vocabulary size
    hidden_dim      # Dimension of the hidden vector
    ):
    # Setup bias and weights
    b = C.Parameter(shape = (vocab_dim, 1), init = C.init_bias_default_or_0)
    W = C.Parameter(shape = (vocab_dim, hidden_dim), init = C.init_default_or_glorot_uniform)

    
    z = C.reshape( C.times_transpose(W, hidden_vector) + b, (1,vocab_dim))
    
    # Use cross_entropy_with_softmax
    cross_entropy = C.cross_entropy_with_softmax(z, target_vector)

    zMax = C.reduce_max(z)
    zT = C.times_transpose(z, target_vector)
    error_on_samples = C.less(zT, zMax)

    return (z, cross_entropy, error_on_samples)

As you can see, the main differences to the API function `cross_entropy_with_softmax` are:
* We include an embedding.
* We return a triple (z, cross_entropy, error_on_samples) instead of just returning the cross entropy.


## A toy example

To explain how to integrate sampled softmax, let us look at a toy example. In this toy example we first transform one-hot input vectors via some random projection into a lower-dimensional vector $h$. The modeling task is to reverse this mapping using (sampled) softmax. Well, as already said, this is a toy example.
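
The random projection itself is easy to picture in NumPy. The following sketch (independent of CNTK, with an arbitrary seed and batch size chosen for illustration) builds a few one-hot vectors and projects them with a constant random matrix, scaled by `sqrt(1.0 / hidden_dim)` as in the `create_model` function of this toy example.

```python
import numpy as np

rng = np.random.default_rng(1)
num_classes, hidden_dim = 20, 10

# Constant random projection matrix of shape (num_classes, hidden_dim)
random_matrix = rng.normal(scale=np.sqrt(1.0 / hidden_dim),
                           size=(num_classes, hidden_dim))

indices = rng.integers(num_classes, size=4)   # class indices to encode
one_hot = np.eye(num_classes)[indices]        # shape (4, num_classes)
h = one_hot @ random_matrix                   # shape (4, hidden_dim)

# Multiplying a one-hot row by the matrix just picks the matching row.
print(np.allclose(h, random_matrix[indices]))
```

So every class is represented by one fixed random row, and the softmax layer has to learn which row belongs to which class.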


In [26]:
from math import log, exp, sqrt
import timeit

# Unnormalized Zipf-like weight for the class with the given index
def zipf(index):
    return 1.0 / (index + 5.0)

# A class with all the parameters we use
class Param:
    # learning parameters
    learning_rate = 0.1
    minibatch_size = 100
    test_set_size = 100
    momentum_time_constant = 5 * minibatch_size

    # Parameters for sampled softmax
    use_sampled_softmax = False
    use_sparse = use_sampled_softmax
    softmax_sample_size = 10

    # Details of data and model
    num_classes = 20
    hidden_dim = 10

    # Distribution the data is drawn from, and the weights used for sampling in sampled softmax
    zipf_sampling_weights = np.asarray([zipf(i) for i in range(num_classes)], dtype=np.float32)
    data_sampling_distribution = zipf_sampling_weights / np.sum(zipf_sampling_weights)
    softmax_sampling_weights = np.power(data_sampling_distribution, 0.5)

# Creates random one-hot vectors of dimension 'num_classes'.
# Returns a tuple with a list of one-hot vectors, and list with the indices they encode.
def get_random_one_hot_data(num_vectors):
    indices = np.random.choice(
        range(Param.num_classes),
        size=num_vectors, 
        p=Param.data_sampling_distribution).reshape((1,num_vectors))
    list_of_vectors = C.one_hot(indices, Param.num_classes)
    return (list_of_vectors, indices.flatten())

# Create a network that:
# * Transforms the input one hot-vectors with a constant random embedding
# * Applies a linear decoding with parameters we want to learn
def create_model(labels):
    # random projection matrix
    random_data = np.random.normal(scale = sqrt(1.0/Param.hidden_dim), size=(Param.num_classes, Param.hidden_dim)).astype(np.float32)
    random_matrix = C.constant(shape = (Param.num_classes, Param.hidden_dim), value = random_data)
    
    h = C.times(labels, random_matrix)
    
    # Connect the latent output to (sampled/full) softmax.
    if Param.use_sampled_softmax:
        sampling_weights = np.asarray(Param.softmax_sampling_weights, dtype=np.float32)
        softmax_input, ce, errs = cross_entropy_with_sampled_softmax_and_embedding(
            h, 
            labels,
            Param.num_classes, 
            Param.hidden_dim, 
            Param.softmax_sample_size, 
            sampling_weights)
    else:
        softmax_input, ce, errs = cross_entropy_with_softmax_and_embedding(
            h, 
            labels, 
            Param.num_classes, 
            Param.hidden_dim)

    return softmax_input, ce, errs

def train():
    labels = C.input_variable(shape = Param.num_classes, is_sparse = Param.use_sparse)
    z, cross_entropy, errs = create_model(labels)

    # Setup the trainer
    learning_rate_schedule = C.learning_rate_schedule(Param.learning_rate, C.UnitType.sample)
    momentum_schedule = C.momentum_as_time_constant_schedule(Param.momentum_time_constant)
    learner = C.momentum_sgd(z.parameters, learning_rate_schedule, momentum_schedule, True)
    trainer = C.Trainer(z, cross_entropy, errs, learner)

    # Run training
    minbatch = 0
    average_cross_entropy = compute_average_cross_entropy(z)
    print("minbatch = %d average_cross_entropy = %.3f\tperplexity = %.3f"
            % (minbatch, average_cross_entropy, exp(average_cross_entropy)))

    for minbatch in range(1,2000):
        # Specify the mapping of input variables in the model to actual minibatch data to be trained with
        label_data, indices = get_random_one_hot_data(Param.minibatch_size)
        arguments = ({labels : label_data})

        t_start = timeit.default_timer()
        trainer.train_minibatch(arguments)
        t_end = timeit.default_timer()
        samples_per_second = Param.minibatch_size / (t_end - t_start)
        if minbatch % 200 == 0:
            average_cross_entropy = compute_average_cross_entropy(z)
            print("minbatch = %d average_cross_entropy = %.3f perplexity = %.3f samples/s = %.1f"
                    % (minbatch, average_cross_entropy, exp(average_cross_entropy), samples_per_second))

def compute_average_cross_entropy(softmax_input):
    vectors, indices = get_random_one_hot_data(Param.test_set_size)
    total_cross_entropy = 0.0
    arguments = (vectors)
    z = softmax_input.eval(arguments).reshape(Param.test_set_size, Param.num_classes)

    for i in range(len(indices)):
        log_p = log_softmax(z[i], indices[i])
        total_cross_entropy -= log_p

    return total_cross_entropy / len(indices)

# Computes log(softmax(z,index)) for a one-dimensional numpy array z in a numerically stable way.
def log_softmax(z,    # numpy array
                index # index into the array
            ):
    max_z = np.max(z)
    return z[index] - max_z - log(np.sum(np.exp(z - max_z)))



np.random.seed(1)


train()


minbatch = 0 average_cross_entropy = 2.971	perplexity = 19.510
minbatch = 200 average_cross_entropy = 2.276 perplexity = 9.734 samples/s = 8192.2
minbatch = 400 average_cross_entropy = 1.689 perplexity = 5.414 samples/s = 8423.6
minbatch = 600 average_cross_entropy = 1.332 perplexity = 3.788 samples/s = 8108.1
minbatch = 800 average_cross_entropy = 1.131 perplexity = 3.100 samples/s = 8433.7
minbatch = 1000 average_cross_entropy = 0.905 perplexity = 2.472 samples/s = 8341.9
minbatch = 1200 average_cross_entropy = 0.849 perplexity = 2.337 samples/s = 8378.0
minbatch = 1400 average_cross_entropy = 0.689 perplexity = 1.992 samples/s = 8127.8
minbatch = 1600 average_cross_entropy = 0.639 perplexity = 1.894 samples/s = 8114.0
minbatch = 1800 average_cross_entropy = 0.666 perplexity = 1.946 samples/s = 8285.1
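
The perplexity column in this output is simply `exp` of the average cross-entropy. As a small sanity check (a sketch that is not part of the training code), a model that assigns the same probability to each of the 20 classes has cross-entropy $\log(20)$ and perplexity 20, which is close to the untrained value in the first line of the output.

```python
from math import exp, log

num_classes = 20
uniform_cross_entropy = log(num_classes)   # -log(1/20) for every sample
print(uniform_cross_entropy, exp(uniform_cross_entropy))

# Perplexity for the cross-entropy reported at minbatch = 0
print(exp(2.971))
```

The perplexity can therefore be read as the effective number of classes the model is still undecided between.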


## Importance sampling




In [None]:
def zipf(index):
    return 1.0 / (index + 5.0)

def entropy(p):
    return -np.sum(np.log(p)*p)

print("entropy: "+str(entropy(Param.data_sampling_distribution)))
