## "Softmax is not enough" recreation

- This is a simple recreation of the main results in the paper **"softmax is not enough (for sharp out-of-distribution)", Veličković et al. 2024.**

- This notebook only aims to show the main results of the **toy max-retrieval task, comparing "vanilla" softmax (constant temperature of 1) with adaptive-temperature softmax.**

- This is supportive material for my **Medium article.**

In [2]:
# Import the necessary libraries, we will be using JAX and Flax as per the original paper
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import jax.random as random
import optax
from typing import Callable

In [3]:
# Set the initial random seed
seed = 0

## Data Generation

Here we generate the data for our task, as described in section "A.2 Data generation" of the original paper. The main points are as follows:

- $n$ describes the number of items in the set to classify (this will be increased during inference to demonstrate out-of-distribution degradation).
- For each item from $1 \to n$, we sample a priority value $\rho$ ("Rho") uniformly from $[0, 1)$; the model has to retrieve the item with the maximum priority among the $n$ items.
- To make this a classification task (so we can use cross-entropy loss), we append a class to each item, $\kappa$ ("Kappa"), which is sampled uniformly from $C$ classes. ($\kappa \sim \mathcal{U}\{1,...,C\}$). We will keep $C=10$ for the whole experiment as per the paper.
- The class is appended to the priority value $\rho$ as a one-hot encoded vector.
- We then generate a query vector. For this task the query is irrelevant, so we simply sample it uniformly between 0 and 1 ($q \sim \mathcal{U}(0,1)$).
- I am not sure how the original paper handled variable-length sequences (different values of $n$), so I have used masking in this notebook: the `x` array is padded with rows of -1 up to a maximum sequence length of 4096, and the padded positions are masked out in attention.
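To make the layout concrete, here is a hand-built toy instance. It is a hypothetical example with made-up values ($n=3$ items, $C=4$ classes, padding to length 6 instead of 4096) so that it is small enough to print; the real data come from `generate_max_retrieval_data` below.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes (the notebook uses C = 10 and pads to 4096)
toy_n, toy_C, toy_max_len = 3, 4, 6
toy_rho = jnp.array([[0.2], [0.9], [0.5]])  # priority values, shape (n, 1)
toy_kappa = jnp.array([3, 1, 0])            # class of each item, shape (n,)

# Each row is [rho, one_hot(kappa)], i.e. C + 1 features
toy_x = jnp.concatenate([toy_rho, jax.nn.one_hot(toy_kappa, toy_C)], axis=-1)  # (3, 5)

# Pad with rows of -1 up to the maximum sequence length
toy_pad = jnp.full((toy_max_len - toy_n, toy_C + 1), -1.0)
toy_x_padded = jnp.concatenate([toy_x, toy_pad], axis=0)  # (6, 5)

# Target: the class of the item with the largest priority (item 1 -> class 1)
toy_target = jax.nn.one_hot(toy_kappa[jnp.argmax(toy_rho[:, 0])], toy_C)

print(toy_x_padded)
print(toy_target)  # [0. 1. 0. 0.]
```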

In [4]:
#Jax requires a random key to be passed to the random number generator
key = jax.random.PRNGKey(seed)


def generate_max_retrieval_data(n, n_classes, batch_size, rng, feats=4096):
    # Split the key: one spare key plus one key each for rho, kappa and q
    rng, rho_key, kappa_key, q_key = jax.random.split(rng, 4)

    # Generate priority values
    rho = random.uniform(rho_key, (batch_size, n, 1))
    
    # Generate class labels
    kappa = random.randint(kappa_key, (batch_size, n), 0, n_classes)
    
    # One-hot encode class labels
    kappa_onehot = jax.nn.one_hot(kappa, n_classes)
    
    # Concatenate priority values and one-hot encoded class labels
    x = jnp.concatenate([rho, kappa_onehot], axis=-1)
    # Pad with rows of -1 up to the maximum sequence length `feats` (4096 by default)
    x = jnp.concatenate([x, jnp.full((batch_size, feats - n, n_classes + 1), -1)], axis=1)
    
    # Generate query vector
    q = random.uniform(q_key, (batch_size, 1, 1))


    # Determine the target class of the maximal item
    max_ind = jnp.argmax(rho, axis=1)[:,0]
    targets = jax.nn.one_hot(kappa[jnp.arange(batch_size), max_ind], n_classes)
    
    return x, q, targets


Here we check that the data have the correct shapes: features should be (batch_size, max_seq_length, n_classes + 1), queries (batch_size, 1, 1), and targets, the one-hot class of the max-priority item, (batch_size, n_classes).

In [5]:

n = 5   # Number of items in the set
C = 10  # Number of classes
num_examples = 10  # Number of examples to generate

features, queries, targets = generate_max_retrieval_data(n, C, num_examples, key)

print("Features shape:", features.shape)  # Should be (num_examples, feats, C+1)
print("Queries shape:", queries.shape)    # Should be (num_examples, 1, 1)
print("Targets shape:", targets.shape)    # Should be (num_examples, C)

Features shape: (10, 4096, 11)
Queries shape: (10, 1, 1)
Targets shape: (10, 10)


In [6]:
ind=5
for i in range(ind):
    print('*'*10)
    for x in [features, queries, targets]:
        print('________')
        print(x[i])

**********
________
[[ 0.16824102  0.          0.         ...  0.          0.
   1.        ]
 [ 0.7228373   0.          0.         ...  0.          0.
   0.        ]
 [ 0.04095995  0.          0.         ...  0.          0.
   0.        ]
 ...
 [-1.         -1.         -1.         ... -1.         -1.
  -1.        ]
 [-1.         -1.         -1.         ... -1.         -1.
  -1.        ]
 [-1.         -1.         -1.         ... -1.         -1.
  -1.        ]]
________
[[0.21033502]]
________
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
**********
________
[[ 0.82459986  0.          0.         ...  0.          0.
   1.        ]
 [ 0.9941201   0.          0.         ...  0.          0.
   0.        ]
 [ 0.12593412  1.          0.         ...  0.          0.
   0.        ]
 ...
 [-1.         -1.         -1.         ... -1.         -1.
  -1.        ]
 [-1.         -1.         -1.         ... -1.         -1.
  -1.        ]
 [-1.         -1.         -1.         ... -1.         -1.
  -1.        ]]
_______

## Masking

To deal with variable sequence lengths, we use masking in the attention head. Flax's `MultiHeadDotProductAttention` accepts an attention mask in which 1 means keep and 0 means mask out. The mask has shape (batch_size, n_heads, query_length, key_length), i.e. (128, 1, 1, 4096) during training (batch size 128, one head, one query).

In [8]:
from flax.typing import Dtype

Array = jax.Array

@jax.jit
def make_attention_mask(query_input: Array,
                        key_input: Array,
                        dtype: Dtype = jnp.float32):
  """Padding mask for the attention weights.

  A key position is kept (1) unless its first feature equals the padding
  value -1, in which case it is masked (0). The mask is read from the first
  example only, which works because every example in a batch shares the same n.
  `query_input` and `dtype` are unused; they loosely mirror the signature of
  flax.linen.make_attention_mask.

  Returns an int array of shape (batch_size, 1, 1, key_length), i.e.
  (batch, num_heads, query_length, key_length).
  """
  key_mask = jnp.where(key_input[0, :, 0] == -1, 0, 1)  # (key_length,)
  # Broadcast the single mask to every example; this is cheaper than building
  # a pairwise query/key mask with an einsum or outer product.
  return jnp.repeat(key_mask[None, None, None, :], key_input.shape[0], axis=0)

In [9]:
print(queries.shape, features.shape)
mask = make_attention_mask(queries, features)
print(mask.shape)

(10, 1, 1) (10, 4096, 11)
(10, 1, 1, 4096)
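The 0/1 mask only matters once it reaches the softmax. As far as I can tell, Flax's `dot_product_attention_weights` applies it by replacing masked logits with `jnp.finfo(dtype).min` before normalising (presumably the copied version in custom_flax_funcs.py does the same). A toy sketch of that step, with made-up logits for one query over 6 keys, 3 of which are padding:

```python
import jax
import jax.numpy as jnp

toy_logits = jnp.array([0.3, 1.2, -0.4, 0.0, 0.0, 0.0])  # made-up attention logits
toy_mask = jnp.array([1, 1, 1, 0, 0, 0])                # 1 = real item, 0 = padding

# Replace masked logits with the most negative finite float, then normalise
big_neg = jnp.finfo(toy_logits.dtype).min
toy_weights = jax.nn.softmax(jnp.where(toy_mask == 1, toy_logits, big_neg))

print(toy_weights)  # padded positions get zero weight; the rest sum to 1
```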


## Model

First we import custom functions from custom_flax_funcs.py. These were copied from the Flax source code, with a couple of small alterations so that the attention function accepts a custom softmax function and `MultiHeadDotProductAttention` also returns the attention weights. The model is mostly copied from the paper, with small modifications for swapping the softmax function and retrieving intermediate values.

In [10]:
from custom_flax_funcs import MultiHeadDotProductAttention, dot_product_attention_with_adaptive_temp

In [34]:
import jax.numpy as jnp
from flax import linen as nn 
from typing import Callable
from functools import partial
class Model(nn.Module): 
    n_classes: int = 10
    n_feats: int = 128
    activation: Callable = nn.gelu
    softmax_fnc: Callable = jax.nn.softmax
    


    @nn.compact
    def __call__(self, x, q):
        attention_fn: Callable = partial(dot_product_attention_with_adaptive_temp, softmax_fnc=self.softmax_fnc)
        
        mask = make_attention_mask(q, x)
        x = nn.Dense(features=self.n_feats)(x)
        x = self.activation(x)
        x = nn.Dense(features=self.n_feats)(x)
        x = self.activation(x)
        q = nn.Dense(features=self.n_feats)(q)
        q = self.activation(q)
        q = nn.Dense(features=self.n_feats)(q)
        q = self.activation(q)

        # Save the key/value embeddings that are fed into the attention layer
        self.sow('intermediates', 'pre_attention_logits', x)
        
        x, attention_weights = MultiHeadDotProductAttention(
            num_heads=1,
            qkv_features=self.n_feats,
            attention_fn=attention_fn)(
            inputs_q=q,
            inputs_kv=x,
            mask=mask)
        
        self.sow('intermediates', 'attention_weights', attention_weights)
        x = nn.Dense(features=self.n_feats)(jnp.squeeze(x, -2)) 
        x = self.activation(x)
        x = nn.Dense(features=self.n_classes)(x)
        
        return x
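The actual `dot_product_attention_with_adaptive_temp` lives in custom_flax_funcs.py and is not shown here, so the following is only a hypothetical, simplified sketch of the idea behind `softmax_fnc`: the attention function takes the softmax as an argument and `functools.partial` fixes it, as `Model` does. It assumes a single head and a (batch, L) mask instead of the (batch, 1, 1, L) mask above, and uses a fixed "sharper" softmax (inverse temperature 4) as a stand-in for the adaptive one.

```python
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp


def single_query_attention(q, k, v, mask, softmax_fnc: Callable = jax.nn.softmax):
    """Hypothetical single-head attention with a pluggable softmax.

    q: (batch, 1, d); k, v: (batch, L, d); mask: (batch, L) with 1 = keep, 0 = padding.
    """
    # Scaled dot-product logits, shape (batch, 1, L)
    logits = jnp.einsum('bqd,bkd->bqk', q, k) / jnp.sqrt(q.shape[-1])
    # Large negative value for padded keys (stays finite even after scaling by 4 below)
    logits = jnp.where(mask[:, None, :] == 1, logits, -1e9)
    weights = softmax_fnc(logits)  # normalised over the last (key) axis
    return jnp.einsum('bqk,bkd->bqd', weights, v), weights


# Stand-in for the adaptive softmax: a fixed inverse temperature of 4
def sharp_softmax(z):
    return jax.nn.softmax(4.0 * z, axis=-1)


vanilla_attn = partial(single_query_attention, softmax_fnc=jax.nn.softmax)
sharp_attn = partial(single_query_attention, softmax_fnc=sharp_softmax)

toy_q = jnp.ones((1, 1, 2))
toy_k = jnp.array([[[1.0, 0.0], [2.0, 0.0], [0.0, 0.0], [-1.0, -1.0]]])
toy_v = jnp.array([[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [5.0, 5.0]]])
toy_mask = jnp.array([[1, 1, 1, 0]])  # the last key is padding

out_van, w_van = vanilla_attn(toy_q, toy_k, toy_v, toy_mask)
out_sharp, w_sharp = sharp_attn(toy_q, toy_k, toy_v, toy_mask)
print(w_van[0, 0])
print(w_sharp[0, 0])  # more weight on key 1, which has the largest logit
```

Only the softmax changes between the two calls; the logits, mask and values are identical, which is the same reason the trained parameters can be reused unchanged when swapping in the adaptive softmax later.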

## Training

Below are the loss function and training loop. Training mostly follows section "A.4 Experimental hyperparameters" of the paper, but with fewer training steps, no adaptive learning rate and a smaller maximum number of items ($n$), to save computational resources (and effort). The classification accuracies reported later reflect this.

The main point here is that the number of items during training is sampled uniformly, $n \sim \mathcal{U}\{5,...,16\}$. This is much smaller than the maximum sequence length of 4096, which lets us demonstrate out-of-distribution degradation during inference. The model is trained with the "vanilla" softmax.
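For reference, the objective implemented by `cross_entropy_loss` and `compute_loss` in the next cells can be written as below. This is my own transcription of the code, not an equation taken from the paper: $B = 128$ is the batch size, $C = 10$ the number of classes, $t$ the one-hot targets, $z$ the logits and $\lambda$ = `l2_reg` $= 0.001$, with the L2 term summed over every parameter array (biases included).

$$
\mathcal{L}(\theta) = -\frac{1}{B}\sum_{b=1}^{B}\sum_{c=1}^{C} t_{b,c}\,\log\frac{e^{z_{b,c}}}{\sum_{c'=1}^{C} e^{z_{b,c'}}} \;+\; \lambda \sum_{w \in \theta} \lVert w \rVert_2^2
$$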

In [12]:
@jax.jit
def cross_entropy_loss(logits, targets):
    # Compute cross-entropy loss
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.mean(jnp.sum(log_probs * targets, axis=-1))

    return loss
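A quick hand-computable check of the loss: with logits $(2, 0, 0)$ and the first class as target, the cross-entropy is $-\log\frac{e^2}{e^2+2} = \log(1 + 2e^{-2}) \approx 0.2395$. The function is repeated under a different, hypothetical name (`cross_entropy_reference`) so the snippet runs on its own without overwriting the jitted version.

```python
import math

import jax
import jax.numpy as jnp


def cross_entropy_reference(logits, targets):
    # Same computation as cross_entropy_loss above
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(log_probs * targets, axis=-1))


toy_logits = jnp.array([[2.0, 0.0, 0.0]])
toy_targets = jnp.array([[1.0, 0.0, 0.0]])

loss_value = float(cross_entropy_reference(toy_logits, toy_targets))
closed_form = math.log(1 + 2 * math.exp(-2))
print(loss_value, closed_form)
```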

In [12]:
from tqdm import tqdm

def compute_loss(params, apply_fn, batch, l2_reg=0.001):
    x, q, targets = batch
    logits, intermediates = apply_fn({'params': params}, x, q, mutable=['intermediates'])  
    
    # Compute cross-entropy loss
    loss = cross_entropy_loss(logits, targets)
    
    # Add L2 regularization
    l2_loss = l2_reg * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
    return loss + l2_loss

# Update Step
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        return compute_loss(params, state.apply_fn, batch)
    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state

# Training Configuration
def train_model():
    n_classes, n_feats, batch_size, num_steps = 10, 128, 128, 5000
    learning_rate, l2_reg = 0.001, 0.001

    # Initialize model and optimizer
    model = Model(n_classes=n_classes, n_feats=n_feats)
    rng = jax.random.PRNGKey(seed)
    x_sample, q_sample, _ = generate_max_retrieval_data(20, n_classes, batch_size, rng)
    params = model.init(rng, x_sample, q_sample)['params']

    # Create optimizer and training state
    tx = optax.adam(learning_rate=learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    
    for step in tqdm(range(num_steps)):
        rng, key = jax.random.split(rng)

        # Regenerate the batch every 10 steps to save data-generation time
        if step % 10 == 0:
            n = jax.random.randint(key, shape=(), minval=5, maxval=17)  # n ~ U{5,...,16}
            batch = generate_max_retrieval_data(n, n_classes, batch_size, key)
        else:
            # Shuffle the rows of the existing batch. The loss is a batch mean, so this does not
            # change the update: each batch is effectively reused for 10 consecutive steps.
            indices = jax.random.permutation(key, batch[0].shape[0])
            batch = (batch[0][indices], batch[1][indices], batch[2][indices])

        # Perform training step
        state = train_step(state, batch)

        # Log progress 
        if step % 1000 == 0:
            loss = compute_loss(state.params, state.apply_fn, batch)  # Use state.apply_fn
            print(f"Step {step}: Loss = {loss:.4f}")

    
    return state

# Run training
trained_state = train_model()

  0%|          | 1/5000 [00:08<11:19:28,  8.16s/it]

Step 0: Loss = 3.4464


 20%|██        | 1001/5000 [52:28<6:03:03,  5.45s/it]

Step 1000: Loss = 0.4179


 40%|████      | 2001/5000 [1:46:33<4:29:33,  5.39s/it]

Step 2000: Loss = 0.3890


 60%|██████    | 3001/5000 [2:35:41<2:43:36,  4.91s/it]

Step 3000: Loss = 0.2659


 80%|████████  | 4001/5000 [3:22:36<1:23:03,  4.99s/it]

Step 4000: Loss = 0.2616


100%|██████████| 5000/5000 [4:15:10<00:00,  3.06s/it]  
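A note on the training loop: on 9 out of every 10 steps it only reshuffles the rows of the current batch. Because each example is processed independently and the loss is a batch mean, a row permutation should leave the loss and its gradient unchanged (up to floating-point summation order), so each generated batch is effectively used for 10 consecutive steps. The sketch below checks this with a hypothetical linear stand-in model rather than the real `Model`:

```python
import jax
import jax.numpy as jnp


def mean_ce(logits, targets):
    # Batch-mean cross-entropy, the same computation as cross_entropy_loss above
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits, axis=-1) * targets, axis=-1))


def toy_loss_fn(W, x, targets):
    # Hypothetical stand-in model: a linear map applied to each example independently
    return mean_ce(x @ W, targets)


k_w, k_x, k_y, k_perm = jax.random.split(jax.random.PRNGKey(0), 4)
toy_W = jax.random.normal(k_w, (5, 10))
toy_x = jax.random.normal(k_x, (128, 5))
toy_targets = jax.nn.one_hot(jax.random.randint(k_y, (128,), 0, 10), 10)

# Same permutation trick as in the training loop
perm = jax.random.permutation(k_perm, toy_x.shape[0])
grad_original = jax.grad(toy_loss_fn)(toy_W, toy_x, toy_targets)
grad_shuffled = jax.grad(toy_loss_fn)(toy_W, toy_x[perm], toy_targets[perm])

# The gradients agree up to floating-point summation order
print(jnp.max(jnp.abs(grad_original - grad_shuffled)))
```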


## Testing Adaptive Temperature During Inference

Now for the fun part: we replace the "vanilla" `jax.nn.softmax` in the `MultiHeadDotProductAttention` layer with the paper's adaptive temperature softmax, keeping the trained weights unchanged.

Note: we use the paper's own polynomial fit, which maps the entropy of the attention distribution to an inverse temperature $\beta = 1/\theta$, rather than fitting our own.

In [13]:
# First define the adaptive temperature softmax function from the paper
def adaptive_temperature_softmax(logits):
    original_probs = jax.nn.softmax(logits)
    poly_fit = jnp.array([-0.037, 0.481, -2.3, 4.917, -1.791])  # see Figure 5
    # Shannon entropy of the original attention distribution
    entropy = jnp.sum(-original_probs * jnp.log(original_probs + 1e-9), axis=-1, keepdims=True)
    beta = jnp.where(  # beta = 1 / theta
        entropy > 0.5,  # don’t overcorrect low-entropy heads
        jnp.maximum(jnp.polyval(poly_fit, entropy), 1.0),  # never increase entropy
        1.0)
    return jax.nn.softmax(logits * beta)
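To see what the adaptive softmax actually does, the sketch below computes the inverse temperature $\beta$ it would pick for two made-up logit vectors: 64 nearly tied items (high entropy, similar to a diffuse attention head) and one clear winner among 3 items (entropy below the 0.5 threshold). `adaptive_beta` is a hypothetical helper that repeats the logic of the cell above.

```python
import jax
import jax.numpy as jnp

# Polynomial coefficients copied from adaptive_temperature_softmax above
POLY_FIT = jnp.array([-0.037, 0.481, -2.3, 4.917, -1.791])


def entropy(p):
    # Shannon entropy (in nats) over the last axis
    return jnp.sum(-p * jnp.log(p + 1e-9), axis=-1)


def adaptive_beta(logits):
    """Hypothetical helper: the inverse temperature the adaptive softmax would apply."""
    h = entropy(jax.nn.softmax(logits))
    return jnp.where(h > 0.5, jnp.maximum(jnp.polyval(POLY_FIT, h), 1.0), 1.0)


diffuse = jnp.linspace(0.0, 1.0, 64)  # 64 items with similar logits (high entropy)
peaked = jnp.array([10.0, 0.0, 0.0])  # one clear winner (low entropy)

for name, logits in [("diffuse", diffuse), ("peaked", peaked)]:
    beta = adaptive_beta(logits)
    p_van = jax.nn.softmax(logits)
    p_ad = jax.nn.softmax(logits * beta)
    print(f"{name}: beta = {float(beta):.2f}, "
          f"max prob {float(p_van.max()):.3f} -> {float(p_ad.max()):.3f}, "
          f"entropy {float(entropy(p_van)):.2f} -> {float(entropy(p_ad)):.2f}")
```

For the diffuse case $\beta$ comes out above 1 (roughly 2.4 by my hand calculation), which lowers the entropy; for the peaked case the entropy is below 0.5, so $\beta = 1$ and the output is unchanged.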

In [14]:
n_list = [2**i for i in range(1, 13)] # list of number of items we want to test during inference testing, here we use powers of 2 [2, 4, 8, 16, ..., 4096]

def gen_inf_data(n_list, n_classes, batch_size, rng):
    x_list, q_list, target_list = [], [], []
    for n in n_list:
        x, q, targets = generate_max_retrieval_data(n, n_classes, batch_size, rng)
        x_list.append(x)
        q_list.append(q)
        target_list.append(targets)
    return jnp.concatenate(x_list, axis=0), jnp.concatenate(q_list, axis=0), jnp.concatenate(target_list, axis=0)

In [15]:
rng = jax.random.PRNGKey(10)
rng, key = jax.random.split(rng)
x, q, targets = gen_inf_data(n_list, 10, 2, rng)

In [16]:
print(x.shape, q.shape, targets.shape)

(24, 4096, 11) (24, 1, 1) (24, 10)


In [17]:
def compute_accuracy(logits, targets):
    """
    Compute classification accuracy.
    
    Parameters:
        logits (jnp.ndarray): Predicted logits (batch_size, n_classes).
        targets (jnp.ndarray): One-hot encoded true labels (batch_size, n_classes).
    
    Returns:
        float: Fraction of correct predictions, in [0, 1].
    """
    predictions = jnp.argmax(logits, axis=-1)
    true_labels = jnp.argmax(targets, axis=-1)
    accuracy = jnp.mean(predictions == true_labels)
    return accuracy

In [18]:
from flax.serialization import to_bytes, from_bytes

def load_model(filepath):
    """
    Load JAX model parameters from a file.
    
    Parameters:
        filepath (str): Path to the file where parameters are saved.
    
    Returns:
        dict: Loaded model parameters.
    """
    with open(filepath, 'rb') as f:
        params = from_bytes(None, f.read())
    return params


In [19]:
# Load parameters for inference or further training
loaded_params = load_model('model_params.pkl')

## Inference

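Here we test the accuracy of the "vanilla" softmax model against the adaptive softmax model. Both use the same trained parameters, and performance is mean-averaged over several random seeds, as per the paper. Note that `jnp.arange(11, 22)` below gives 11 seeds (11 to 21), one more than the 10 used in the paper.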

In [190]:
seeds = jnp.arange(11,22)
print(seeds)

[11 12 13 14 15 16 17 18 19 20 21]


In [194]:
adaptive_model = Model(n_classes=10, n_feats=128, softmax_fnc=adaptive_temperature_softmax)
vanilla_model = Model(n_classes=10, n_feats=128, softmax_fnc=jax.nn.softmax)

params = loaded_params   #trained_state.params

att_weights_ad = []
att_weights_van = []

for n in n_list:
    van_acc, ad_acc = [], []
    loss_adaptive, loss_vanilla = [], []
    for seed in seeds:
        rng = jax.random.PRNGKey(seed)
        rng, key = jax.random.split(rng)
        x, q, targets = generate_max_retrieval_data(n, 10, 32, key)
        ordered = jnp.argsort(x[:, :, 0].squeeze(), axis=-1)

        logits_adaptive, intermediate_adaptive = adaptive_model.apply({'params': params}, x, q, mutable=['intermediates'])
        logits_vanilla, intermediate_vanilla = vanilla_model.apply({'params': params}, x, q, mutable=['intermediates'])

        loss_adaptive.append(cross_entropy_loss(logits_adaptive, targets))
        loss_vanilla.append(cross_entropy_loss(logits_vanilla, targets))

        
        
        van_acc.append(compute_accuracy(logits_vanilla, targets))
        ad_acc.append(compute_accuracy(logits_adaptive, targets))

        if seed == 12:

            # Reorder the attention weights based on the ordered indices
            reordered_adaptive = jnp.take_along_axis(intermediate_adaptive['intermediates']['attention_weights'][0][:,0,0,:], ordered, axis=-1)
            reordered_vanilla = jnp.take_along_axis(intermediate_vanilla['intermediates']['attention_weights'][0][:,0,0,:], ordered, axis=-1)

            att_weights_ad.append(reordered_adaptive)
            att_weights_van.append(reordered_vanilla)
    
    print('*'*10)
    print(f"Number of items: {n}")

    print(f"Adaptive Model - Loss: {jnp.mean(jnp.array(loss_adaptive)):.4f}, Accuracy: {jnp.mean(jnp.array(ad_acc)) * 100:.2f}%")
    print(f"Vanilla Model - Loss: {jnp.mean(jnp.array(loss_vanilla)):.4f}, Accuracy: {jnp.mean(jnp.array(van_acc)) * 100:.2f}%")

att_weights_ad = jnp.array(att_weights_ad)
att_weights_van = jnp.array(att_weights_van)



**********
Number of items: 2
Adaptive Model - Loss: 0.0228, Accuracy: 100.00%
Vanilla Model - Loss: 0.0228, Accuracy: 100.00%
**********
Number of items: 4
Adaptive Model - Loss: 0.0384, Accuracy: 98.86%
Vanilla Model - Loss: 0.0384, Accuracy: 98.86%
**********
Number of items: 8
Adaptive Model - Loss: 0.1245, Accuracy: 96.59%
Vanilla Model - Loss: 0.1248, Accuracy: 96.59%
**********
Number of items: 16
Adaptive Model - Loss: 0.2131, Accuracy: 95.74%
Vanilla Model - Loss: 0.2111, Accuracy: 95.74%
**********
Number of items: 32
Adaptive Model - Loss: 0.3944, Accuracy: 90.62%
Vanilla Model - Loss: 0.3856, Accuracy: 90.34%
**********
Number of items: 64
Adaptive Model - Loss: 0.8070, Accuracy: 80.11%
Vanilla Model - Loss: 0.7644, Accuracy: 78.98%
**********
Number of items: 128
Adaptive Model - Loss: 1.0664, Accuracy: 76.14%
Vanilla Model - Loss: 0.9561, Accuracy: 72.73%
**********
Number of items: 256
Adaptive Model - Loss: 2.0465, Accuracy: 59.94%
Vanilla Model - Loss: 1.5774, Accuracy

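Cool! Just like in the paper, the adaptive temperature softmax improves accuracy as we go out of distribution. The two models have identical accuracy up to 16 items (the training range); the gap then grows to about 1 point at 64 items (80.11% vs 78.98%) and about 3.4 points at 128 items (76.14% vs 72.73%). Note, however, that the adaptive model's cross-entropy loss is higher than the vanilla model's from 16 items onward, which suggests that its mistakes are made more confidently.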

## Visualise the Attention Weights

In [21]:
import treescope

Here we use treescope to visualise the attention weights for one of our seed runs (seed 12) and compare them as the number of items increases. Hopefully we will see sharper attention for the adaptive temperature softmax!
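The plots below rely on how the inference loop reordered the attention weights: `jnp.argsort` on the priorities is ascending, so after `jnp.take_along_axis` the padding positions (priority -1) come first and the highest-priority item ends up in the last column. The slice `[:, :, -16:]` therefore shows the weights on the 16 highest-priority items, with the true maximum in the rightmost column; for $n < 16$ some of those columns are padding with zero weight. A toy illustration with made-up numbers:

```python
import jax.numpy as jnp

# Toy batch with one set: 4 real priorities followed by 3 padding values (-1)
toy_rho = jnp.array([[0.3, 0.8, 0.1, 0.6, -1.0, -1.0, -1.0]])
# Made-up attention weights over the same 7 positions (padding gets zero weight)
toy_att = jnp.array([[0.2, 0.5, 0.05, 0.25, 0.0, 0.0, 0.0]])

# argsort is ascending: padding (-1) comes first, the highest-priority item last
toy_order = jnp.argsort(toy_rho, axis=-1)
toy_reordered = jnp.take_along_axis(toy_att, toy_order, axis=-1)

# The notebook keeps the last 16 columns; with this toy set we keep the last 2
toy_top = toy_reordered[:, -2:]
print(toy_reordered)
print(toy_top)
```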

### Adaptive Temperature Softmax Attention Visualisation

In [22]:
treescope.render_array(att_weights_ad[:,:,-16:],axis_labels={0: "log2(num_items)", 1: "batch", 2: "top_16_items"})

### "Vanilla" Softmax Attention Visualisation

In [23]:
treescope.render_array(att_weights_van[:,:,-16:], axis_labels={0: "log2(num_items)", 1: "batch", 2: "top_16_items"})

Cool! The adaptive temperature softmax attention is clearly sharper than the "vanilla" softmax attention, especially for the out-of-distribution tasks, just like in the paper!
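To put a number on "sharper" rather than judging by eye, one option (my own addition, not something taken from the paper) is the average entropy of each attention row: lower means sharper. A minimal sketch, checked on toy rows; it could be applied to the full `att_weights_ad` and `att_weights_van` arrays rather than the top-16 slice, which no longer sums to 1.

```python
import math

import jax.numpy as jnp


def mean_attention_entropy(att_weights):
    """Average Shannon entropy (nats) of attention rows; lower means sharper.

    att_weights: (num_settings, batch, seq_len), e.g. att_weights_ad or att_weights_van.
    Returns one value per setting (here, per number of items).
    """
    row_entropy = -jnp.sum(att_weights * jnp.log(att_weights + 1e-9), axis=-1)
    return row_entropy.mean(axis=-1)


# Toy check: a one-hot row has ~0 entropy, a uniform row over 4 items has log(4)
toy_rows = jnp.array([[[1.0, 0.0, 0.0, 0.0]],
                      [[0.25, 0.25, 0.25, 0.25]]])  # shape (2, 1, 4)
toy_entropy = mean_attention_entropy(toy_rows)
print(toy_entropy, math.log(4))
```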

Now we just need to think of ways to extend the paper: other architectures without softmax? Sigmoid attention? Other methods? Maybe you can think of some...

In [24]:


def save_model(params, filepath):
    """
    Save JAX model parameters to a file.
    
    Parameters:
        params (dict): Model parameters to save.
        filepath (str): Path to the file where parameters will be saved.
    """
    with open(filepath, 'wb') as f:
        f.write(to_bytes(params))

# Save parameters after training
save_model(trained_state.params, 'model_params.pkl')