In [1]:
import numpy as np

In [3]:
def prep_att_pipe(
        att_tensor: np.ndarray,
        n_first_tokens: int = None,
        skip_first_n_tokens: int = None,
        skip_last_n_tokens: int = None,
        n_context_tokens_start_idx: int = None,
        n_context_tokens_end_idx: int = None,
        window_size: int = 0,
        window_step: int = 4,
        postprocess_fn: callable = None,
        valid_example_th: int = 4,
        **kwargs: dict,
    ) -> np.ndarray:
        """
        Function to prepare the attention tensor for further analysis.
        It removes the prompt tokens and the last offset_size tokens from the context.
        """

        skip_first_n_tokens = skip_first_n_tokens if skip_first_n_tokens is not None else 0
        skip_last_n_tokens = skip_last_n_tokens if skip_last_n_tokens is not None else 0
        n_first_tokens = n_first_tokens if n_first_tokens is not None else att_tensor.shape[-2]

        att_tensor = att_tensor[..., slice(skip_first_n_tokens, n_first_tokens + skip_first_n_tokens - skip_last_n_tokens), slice(n_context_tokens_start_idx, n_context_tokens_end_idx)]
        
        if att_tensor.shape[-2] < valid_example_th:
            return None

        if (window_size) and (att_tensor.shape[-2] > window_size):

            att_tensor = {
                tuple([i, i + window_size]) : postprocess_fn(att_tensor[..., i: i + window_size, :], **kwargs) if kwargs else postprocess_fn(att_tensor[..., i: i + window_size, :])
                for i in range(0, att_tensor.shape[-2], window_step) if i + window_size <= att_tensor.shape[-2]
            }

        else:
            att_tensor = postprocess_fn(att_tensor, **kwargs) if kwargs else postprocess_fn(att_tensor)

        return att_tensor

In [78]:
from scipy.spatial.distance import jensenshannon

In [85]:
random_values_context = np.random.randint(100, 500, size=10)
random_values_gen = np.random.randint(6, 200, size=10)

mock_at = [
    np.random.random(size=(42, 16, rvc, 8)) for rvc, rgg in zip(random_values_context, random_values_gen)
]

In [117]:
def dist_div_heads_agg(att_tensor: np.ndarray) -> np.ndarray:
    """
    Function to calculate the distance between the attention heads.
    """

    n_layers, n_heads, _, n_generated_tokens = att_tensor.shape

    att_tensor = np.concatenate((att_tensor, 1 - np.sum(att_tensor, axis=2, keepdims=True)), axis=2)
    reference_distribution = np.mean(att_tensor, axis=1)

    js_divergence = np.zeros((n_layers, n_heads, n_generated_tokens))

    for i in range(n_heads):
        js_divergence[:, i, :] = jensenshannon(att_tensor[:, i, ...], reference_distribution, axis=1)

    del att_tensor, reference_distribution
    js_divergence = np.mean(js_divergence, axis=-1)

    return js_divergence

In [115]:
for att in mock_at:

    n_layers, n_heads, n_context_tokens, n_generated_tokens = att.shape
    # n_bins = int(2 * (n_context_tokens) ** (1/3))
    
    # quantiles = np.transpose(np.percentile(att, np.linspace(0, 100, n_bins + 1), axis=2), (1, 2, 0, 3))

    # binned_data = np.zeros_like(att, dtype=int)

    # for i in range(n_layers):
    #     for j in range(n_heads):
    #         for k in range(n_generated_tokens):
    #             binned_data[i, j, :, k] = np.digitize(att[i, j, :, k], quantiles[i, j, :, k])

    
    # for i in range(n_heads):
    # unique, counts = np.unique(binned_data[0, 0, :, 0], return_counts=True)
    # print(dict(zip(unique, counts)))
    att_full = np.concatenate((att, 1 - np.sum(att, axis=2, keepdims=True)), axis=2)

    reference_distribution = np.mean(att_full, axis=1)

    print(att_full.shape)
    print(reference_distribution.shape)

    # heads_probs = np.apply_along_axis(lambda x: np.histogram(x, bins=n_bins, density=True)[0], axis=2, arr=binned_data)
    # reference_probs = np.apply_along_axis(lambda x: np.histogram(x, bins=n_bins, density=True)[0], axis=1, arr=reference_distribution)

    # # print(heads_probs[0, :, 0])

    # # print(heads_probs.shape)
    # # print(reference_probs.shape)
    

    js_divergence = np.zeros((n_layers, n_heads, n_generated_tokens))

    for i in range(n_heads):
        js_divergence[:, i, :] = jensenshannon(att_full[:, i, ...], reference_distribution, axis=1)

    js_divergence = np.mean(js_divergence, axis=-1)

    print(js_divergence.shape)
    break

    # print(js_divergence.shape)
    # break

    # for each token and layer, calculate the shanon divergence between the head distribution and the reference distribution

    # np.apply_along_axis(lambda x: jensenshannon(x, reference_distribution), axis=1, arr=binned_data)

    # print(binned_data.shape)

(42, 16, 256, 8)
(42, 256, 8)
(42, 16)


In [68]:
hist = np.apply_along_axis(lambda x: np.histogram(x, bins=n_bins, range=(0, n_bins), density=True)[0], axis=2, arr=binned_data)

In [70]:
hist.shape

(32, 32, 14, 8)

In [None]:
# Example array with shape (channels, height, width)
array = np.random.rand(3, 64, 64)

# Change the order of channels to (height, width, channels)
transposed_array = np.transpose(array, (1, 2, 0))

print(transposed_array.shape)  # Output: (64, 64, 3)