This notebook uses matrix-based entropy [1, 2, 3] to measure the layer-wise entropy of the hidden representations of pretrained LLMs. For an introduction to matrix-based entropy, see the `sentence_entropies.ipynb` notebook.

Author: Oscar Skean

In [None]:
%load_ext autoreload
%autoreload 2

import math
import os

import numpy as np
import repitl.matrix_itl as itl
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel, GPT2Model, AutoModel

from utils import get_model_path, get_dataloader, normalize

# Device used for every model forward pass. Prefer the second GPU (the original
# setting); fall back to the first GPU, then to the CPU, so the notebook still runs
# on machines with fewer GPUs instead of failing on a nonexistent "cuda:1".
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Layerwise Entropies

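The information plane is a probe of a model that tracks the mutual information between a pair of variables (input/output, input/layer representation, output/layer representation, etc.) as some quantity (layer depth, context length, etc.) is varied. The cells below compute a related per-layer quantity: the matrix-based entropy of each layer's hidden representations. It is computed either per sentence (over the tokens of one input) or over the dataset (over mean-pooled sentence embeddings).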

In [None]:
from utils import compute_LDA_matrix

def entropy_normalization(entropy, normalization, N, D):
    """Rescale a matrix-based entropy by a size-dependent divisor.

    Args:
        entropy: The entropy value to rescale (a Python float at every call site
            in this notebook).
        normalization: Which divisor to use:
            ``'maxEntropy'`` -> ``min(log N, log D)``,
            ``'logN'``       -> ``log N``,
            ``'logD'``       -> ``log D``,
            ``'logNlogD'``   -> ``log N * log D``,
            ``'raw'``        -> no rescaling.
        N: Number of rows of the representation (tokens for the sentence-level
            analysis, samples for the dataset-level one).
        D: Representation dimension.

    Returns:
        The rescaled entropy.

    Raises:
        AssertionError: if ``normalization`` is not one of the options above.
        ZeroDivisionError / ValueError: propagated when the chosen divisor is 0
            (e.g. ``N == 1`` with ``'logN'``) or a log argument is non-positive.
    """
    assert normalization in ['maxEntropy', 'logN', 'logD', 'logNlogD', 'raw']

    # Each divisor is wrapped in a lambda so that only the logarithms the chosen
    # scheme actually needs are evaluated.
    divisor_for = {
        'maxEntropy': lambda: min(math.log(N), math.log(D)),
        'logN': lambda: math.log(N),
        'logD': lambda: math.log(D),
        'logNlogD': lambda: math.log(N) * math.log(D),
    }

    if normalization != 'raw':
        entropy /= divisor_for[normalization]()
    return entropy

def compute_sentence_entropies_over_layers(model, dataloader, alpha=1, normalization='maxEntropy', max_samples=1000):
    """Average, over sentences, the per-layer matrix-based entropy of token representations.

    For every batch, each element of ``outputs.hidden_states`` is squeezed, passed
    through ``normalize``, turned into a trace-normalized Gram matrix, and scored
    with ``itl.matrixAlphaEntropy``. The scores are rescaled with
    ``entropy_normalization`` and averaged over batches.

    Args:
        model: Called as ``model(**batch)``; must return ``hidden_states``
            (e.g. loaded with ``output_hidden_states=True``).
        dataloader: Iterable of dicts of tensors. Presumably yields one sentence
            per batch (batch size 1), since ``N, D`` are read from
            ``hidden_states[0].shape[1:]``. TODO: confirm.
        alpha: Order of the matrix-based Renyi entropy.
        normalization: Scheme passed to ``entropy_normalization``.
        max_samples: Maximum number of batches to process.

    Returns:
        np.ndarray with one mean entropy per element of ``hidden_states``.

    Raises:
        ValueError: if ``max_samples < 1`` or the dataloader yields no batches.
    """
    if max_samples < 1:
        raise ValueError(f"max_samples must be positive, got {max_samples}")

    batched_layerwise_entropy = []
    with torch.no_grad():
        for counter, batch in enumerate(tqdm.tqdm(dataloader, total=max_samples), start=1):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            N, D = outputs.hidden_states[0].shape[1:]

            layerwise_entropies = []
            for hidden in outputs.hidden_states:
                # Work in float64: the model may run in bfloat16, whose round-off
                # would otherwise be baked into the Gram matrix and its trace.
                layer_states = normalize(hidden.squeeze().double())
                # X^T X and X X^T share their nonzero eigenvalues, so build
                # whichever Gram matrix is smaller.
                if N > D:
                    layer_cov = layer_states.T @ layer_states
                else:
                    layer_cov = layer_states @ layer_states.T
                layer_cov = layer_cov / torch.trace(layer_cov)
                # NOTE(review): this clamps matrix *entries*, not eigenvalues; kept
                # unchanged from the original analysis.
                layer_cov = torch.clamp(layer_cov, min=0)
                entropy = itl.matrixAlphaEntropy(layer_cov, alpha=alpha).item()
                layerwise_entropies.append(entropy_normalization(entropy, normalization, N, D))
            batched_layerwise_entropy.append(layerwise_entropies)

            # `>=` so exactly `max_samples` batches are used, matching the
            # dataset-level functions.
            if counter >= max_samples:
                break

    if not batched_layerwise_entropy:
        raise ValueError("dataloader yielded no batches")
    return np.array(batched_layerwise_entropy).mean(axis=0)

def compute_dataset_entropies_over_layers(model, dataloader, alpha=1, normalization='maxEntropy', max_samples=1000):
    """Per-layer matrix-based entropy of mean-pooled sentence embeddings across a dataset.

    Each batch's hidden states are squeezed, passed through ``normalize`` and
    averaged over their first remaining axis, giving one vector per layer. These
    vectors are stacked over batches into ``Z`` with shape
    ``(L, NUM_SAMPLES, D)``. Each layer's trace-normalized Gram matrix is then
    scored with ``itl.matrixAlphaEntropy`` and rescaled with
    ``entropy_normalization``.

    Args:
        model: Called as ``model(**batch)``; must return ``hidden_states``.
        dataloader: Iterable of dicts of tensors. NOTE(review): assumes batch
            size 1. With a larger batch, ``squeeze()`` keeps the batch axis and
            ``mean(dim=0)`` would average over sentences instead of tokens.
        alpha: Order of the matrix-based Renyi entropy.
        normalization: Scheme passed to ``entropy_normalization``.
        max_samples: Maximum number of batches (samples) to use.

    Returns:
        List with one normalized entropy per element of ``hidden_states``.

    Raises:
        ValueError: if ``max_samples < 1`` or the dataloader yields no batches.
    """
    if max_samples < 1:
        raise ValueError(f"max_samples must be positive, got {max_samples}")

    layerwise_samples = []
    with torch.no_grad():
        for counter, batch in enumerate(tqdm.tqdm(dataloader, total=max_samples), start=1):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)

            # One mean vector per layer; float64 avoids bfloat16 round-off.
            layer_means = torch.stack([
                torch.mean(normalize(x.squeeze().double()), dim=0)
                for x in outputs.hidden_states
            ])  # L x D
            layerwise_samples.append(layer_means)

            if counter >= max_samples:
                break

    if not layerwise_samples:
        raise ValueError("dataloader yielded no batches")

    # NUM_SAMPLES x L x D  ->  L x NUM_SAMPLES x D
    Z = torch.stack(layerwise_samples).transpose(0, 1)
    L, NUM_SAMPLES, D = Z.shape

    # Both Gram matrices share their nonzero eigenvalues; use the smaller one.
    if NUM_SAMPLES > D:
        cov = torch.matmul(Z.transpose(1, 2), Z)  # L x D x D
    else:
        cov = torch.matmul(Z, Z.transpose(1, 2))  # L x NUM_SAMPLES x NUM_SAMPLES

    # Trace-normalize each layer's matrix, as the sentence-level variant does, so
    # its eigenvalues sum to 1 before the entropy is taken.
    layerwise_entropies = [
        itl.matrixAlphaEntropy(layer_cov / torch.trace(layer_cov), alpha=alpha).item()
        for layer_cov in cov
    ]
    return [entropy_normalization(x, normalization, NUM_SAMPLES, D) for x in layerwise_entropies]


def compute_dataset_lidar_over_layers(model, dataloader, alpha=1, normalization='maxEntropy', max_samples=1000):
    """Per-layer entropy of an LDA matrix built from augmented views of each sample.

    Each batch is iterated as a sequence of augmented views of one sample. Each
    view is a dict of tensors without a batch axis; ``unsqueeze(0)`` adds one.
    Every view is run through the model, and each layer's normalized token
    states are mean-pooled into one vector. The stacked embeddings have shape
    ``L x NUM_SAMPLES x NUM_AUGMENTATIONS x D``. For each layer they are passed
    (as float64) to ``compute_LDA_matrix``, whose output is scored with
    ``itl.matrixAlphaEntropy`` and rescaled with ``entropy_normalization``.

    Args:
        model: Called as ``model(**view)``; must return ``hidden_states``.
        dataloader: Iterable of batches, each an iterable of per-view dicts of
            tensors. Presumably produced by
            ``get_augmentation_collated_dataloader``. TODO: confirm.
        alpha: Order of the matrix-based Renyi entropy.
        normalization: Scheme passed to ``entropy_normalization``.
        max_samples: Maximum number of batches (samples) to use.

    Returns:
        List with one normalized entropy per element of ``hidden_states``.

    Raises:
        ValueError: if ``max_samples < 1`` or the dataloader yields no batches.
    """
    if max_samples < 1:
        raise ValueError(f"max_samples must be positive, got {max_samples}")

    batched_augmented_sample_vectors = []
    with torch.no_grad():
        for counter, batch in enumerate(tqdm.tqdm(dataloader, total=max_samples), start=1):
            augmented_sample_vectors = []
            for augmented_sample in batch:
                augmented_sample = {k: v.unsqueeze(0).to(device) for k, v in augmented_sample.items()}
                outputs = model(**augmented_sample)

                # Mean-pool each layer's normalized token states: L x D.
                layer_means = torch.stack([
                    torch.mean(normalize(x.squeeze()), dim=0)
                    for x in outputs.hidden_states
                ])
                augmented_sample_vectors.append(layer_means)

            # NUM_AUGMENTATIONS x L x D
            batched_augmented_sample_vectors.append(torch.stack(augmented_sample_vectors))

            if counter >= max_samples:
                break

    if not batched_augmented_sample_vectors:
        raise ValueError("dataloader yielded no batches")

    # NUM_SAMPLES x NUM_AUGMENTATIONS x L x D  ->  L x NUM_SAMPLES x NUM_AUGMENTATIONS x D
    layer_embeddings = torch.stack(batched_augmented_sample_vectors).permute(2, 0, 1, 3)
    L, NUM_SAMPLES, NUM_AUGMENTATIONS, D = layer_embeddings.shape

    layerwise_entropies = [
        itl.matrixAlphaEntropy(compute_LDA_matrix(embeddings.double()), alpha=alpha).item()
        for embeddings in layer_embeddings
    ]
    return [entropy_normalization(x, normalization, NUM_SAMPLES, D) for x in layerwise_entropies]

In [None]:
from utils import model_name_to_sizes, get_augmentation_collated_dataloader
import pickle

def _results_path(model_name, experiment_name, granularity, normalization, results_dir="entropy_results"):
    """Return the pickle path for one (model family, experiment, granularity, normalization) run."""
    filename = f"entropy={experiment_name}_model={model_name}_granularity={granularity}_normalization={normalization}.pkl"
    return os.path.join(results_dir, filename)


def calculate_and_save_layerwise_entropies(model_name, experiment_name, granularity='sentence', alpha=1, normalization='maxEntropy'):
    """Compute layer-wise entropies for every size of a model family and pickle them.

    Args:
        model_name: Model family; its sizes come from ``model_name_to_sizes``.
        experiment_name: ``'alpha1'`` (matrix entropy of hidden states) or
            ``'lidar'`` (entropy of an LDA matrix over augmented views).
        granularity: ``'sentence'`` or ``'dataset'``. ``'lidar'`` requires
            ``'dataset'``.
        alpha: Entropy order forwarded to the compute functions. Note that the
            output filename does not encode ``alpha``.
        normalization: Scheme forwarded to ``entropy_normalization``.

    Returns:
        Dict mapping model size -> per-layer entropies. The same dict is written
        to ``entropy_results/``.

    Raises:
        AssertionError: on an invalid argument combination, before any model is
            loaded.
    """
    assert experiment_name in ['alpha1', 'lidar'], f"unknown experiment_name {experiment_name!r}"
    assert granularity in ['sentence', 'dataset'], f"unknown granularity {granularity!r}"
    assert normalization in ['maxEntropy', 'logN', 'logD', 'logNlogD', 'raw'], f"unknown normalization {normalization!r}"
    if experiment_name == 'lidar':
        assert granularity == 'dataset', "experiment 'lidar' is only defined for granularity 'dataset'"

    # Create the output directory up front: otherwise a missing folder makes the
    # final write fail after every model has already been evaluated.
    out_path = _results_path(model_name, experiment_name, granularity, normalization)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    layerwise_entropies_per_model = {}
    for model_size in model_name_to_sizes[model_name]:
        model_path = get_model_path(model_name, model_size)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModel.from_pretrained(model_path, output_hidden_states=True, torch_dtype=torch.bfloat16).to(device)

        if experiment_name == 'lidar':
            dataloader = get_augmentation_collated_dataloader(tokenizer, "wikitext", split="train", num_augmentations_per_sample=16)
            compute_fn = compute_dataset_lidar_over_layers
        elif granularity == 'sentence':
            dataloader = get_dataloader(tokenizer, "wikitext", split="train")
            compute_fn = compute_sentence_entropies_over_layers
        else:
            dataloader = get_dataloader(tokenizer, "wikitext", split="train")
            compute_fn = compute_dataset_entropies_over_layers
        layerwise_entropies_per_model[model_size] = compute_fn(model, dataloader, alpha=alpha, normalization=normalization)

        # Release this model, and the allocator's cached blocks, before the next
        # size is loaded.
        del model
        torch.cuda.empty_cache()

    with open(out_path, "wb") as f:
        pickle.dump(layerwise_entropies_per_model, f)

    return layerwise_entropies_per_model


def load_results_layerwise_entropies(model_name, experiment_name, granularity, normalization):
    """Load results pickled by ``calculate_and_save_layerwise_entropies``.

    Raises:
        FileNotFoundError: if that combination has not been computed yet.

    Security: ``pickle.load`` can execute arbitrary code. Only load files this
    notebook produced.
    """
    with open(_results_path(model_name, experiment_name, granularity, normalization), "rb") as f:
        return pickle.load(f)

In [None]:
# Model families to process. Add 'mamba' / 'mamba2' to include those families.
MODEL_FAMILIES = ['EleutherAI', 'cerebras']
EXPERIMENTS = ['lidar', 'alpha1']
GRANULARITIES = ['sentence', 'dataset']
NORMALIZATIONS = ['maxEntropy', 'logN', 'logNlogD', 'raw']

for model_family in MODEL_FAMILIES:
    for experiment in EXPERIMENTS:
        for granularity in GRANULARITIES:
            # LiDAR is only defined at dataset granularity; skip the invalid
            # combination instead of tripping its assertion on every run.
            if experiment == 'lidar' and granularity != 'dataset':
                continue
            for normalization in NORMALIZATIONS:
                run = f"{model_family} {experiment} {granularity} {normalization}"
                # Skip runs whose results already exist. A missing or truncated
                # file means "not computed yet"; any other error propagates.
                try:
                    load_results_layerwise_entropies(model_family, experiment, granularity, normalization)
                    print(f"Already computed {run}")
                    continue
                except (FileNotFoundError, EOFError, pickle.UnpicklingError):
                    pass

                # NOTE: each normalization re-runs all forward passes, although
                # only the final rescaling differs between them.
                try:
                    calculate_and_save_layerwise_entropies(model_family, experiment, granularity, alpha=1, normalization=normalization)
                except Exception as e:
                    # Keep going with the remaining runs, but show the error type.
                    print(f"Error for {run}")
                    print(repr(e))

In [None]:
import matplotlib.pyplot as plt

# Raw (unnormalized) sentence-level alpha=1 entropies for every Mamba size.
layerwise_entropies = load_results_layerwise_entropies("mamba", 'alpha1', granularity='sentence', normalization='raw')

fig, ax = plt.subplots(figsize=(10, 6))

# One curve per model size, indexed by layer.
for size_label, curve in layerwise_entropies.items():
    ax.plot(curve, marker='o', label=size_label)

ax.set(
    title='Entropies for Mamba Models of different numbers of parameters',
    xlabel='Layer Index',
    ylabel='Entropy',
)
ax.legend()
ax.grid(True)

fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Same view as the previous cell: raw sentence-level entropies per Mamba size.
layerwise_entropies = load_results_layerwise_entropies(
    "mamba", 'alpha1', granularity='sentence', normalization='raw'
)

fig, ax = plt.subplots(figsize=(10, 6))

for model_name in layerwise_entropies:
    ax.plot(layerwise_entropies[model_name], marker='o', label=model_name)

ax.set_title('Entropies for Mamba Models of different numbers of parameters')
ax.set_xlabel('Layer Index')
ax.set_ylabel('Entropy')
ax.legend()
ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Dataset-level alpha=1 entropies for every EleutherAI size, each divided by
# min(log N, log D) ('maxEntropy' normalization).
layerwise_entropies = load_results_layerwise_entropies("EleutherAI", 'alpha1', granularity='dataset', normalization='maxEntropy')

fig, ax = plt.subplots(figsize=(10, 6))

for model_name, entropies in layerwise_entropies.items():
    ax.plot(entropies, marker='o', label=model_name)

ax.set_title('Entropies for EleutherAI Models of different numbers of parameters')
ax.set_xlabel('Layer Index')
ax.set_ylabel('Entropy (maxEntropy-normalized)')
ax.legend()
ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Load the results explicitly. `layerwise_entropies_per_model` otherwise only
# exists as a local inside calculate_and_save_layerwise_entropies, so this cell
# raised NameError on a fresh kernel. Edit these constants to plot another run.
GRID_MODEL_FAMILY = "EleutherAI"
GRID_EXPERIMENT = 'alpha1'
GRID_GRANULARITY = 'dataset'
GRID_NORMALIZATION = 'maxEntropy'
layerwise_entropies_per_model = load_results_layerwise_entropies(
    GRID_MODEL_FAMILY, GRID_EXPERIMENT, granularity=GRID_GRANULARITY, normalization=GRID_NORMALIZATION
)

num_models = len(layerwise_entropies_per_model)
if num_models == 0:
    raise ValueError("no per-model results to plot")

# One subplot per model. squeeze=False always returns a 2-D array of Axes, so a
# single model needs no special case.
fig, axs = plt.subplots(num_models, 1, figsize=(8, 3 * num_models), sharex=True, squeeze=False)
axs = axs[:, 0]

# Shared y-range. Pad by 5% of the data span (falling back to 5% of the value, or
# a fixed 0.05, when the span is zero) so that negative data is not clipped and
# equal limits cannot occur.
all_entropies = [entropy for entropies in layerwise_entropies_per_model.values() for entropy in entropies]
y_min, y_max = min(all_entropies), max(all_entropies)
pad = 0.05 * (y_max - y_min) or 0.05 * abs(y_max) or 0.05

for ax, (model_name, entropies) in zip(axs, layerwise_entropies_per_model.items()):
    ax.plot(entropies, marker='o')
    ax.set_title(f'Entropies for {model_name}')
    ax.set_ylabel('Entropy')
    ax.set_ylim(y_min - pad, y_max + pad)
    ax.grid(True)

# The x-axis is shared, so label only the bottom panel.
axs[-1].set_xlabel('Layer Index')

plt.tight_layout()
plt.show()

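## Data Augmentation

The LiDAR experiment compares several augmented views of each sample (16 per sample, via `get_augmentation_collated_dataloader` in `calculate_and_save_layerwise_entropies`). The cell below illustrates `text_augmentation` on a single sentence.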

In [None]:
from utils import text_augmentation

# Generate augmented variants of one example sentence and print them one per line.
input_text = ["The quick brown fox jumps over the lazy dog."]

augmented_text = text_augmentation(input_text, num_augmentations_per_sample=10)
print(f"Original text: {input_text}")

# NOTE(review): the variants for the first sample appear to come back as one
# comma-joined string. A variant that itself contains a comma would be split in
# two here; confirm against utils.text_augmentation.
variant_lines = [
    f"Augmented text {idx}: {variant}"
    for idx, variant in enumerate(augmented_text[0].split(','))
]
print("\n".join(variant_lines))