In [1]:
%run Common\ Code.ipynb



2024-08-04 12:57:15.700820: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-04 12:57:15.738181: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [25]:
class TopTokenList:
    def __init__(self, max_size):
        self._members = []
        self._max_size = max_size
        self._worst_score = None

    def maybe_add(self, tokens, activations):
        for token_idx, activation in enumerate(activations):
            if self._worst_score is not None and activation > self._worst_score:
                self._members.append((tokens, token_idx, activation))
                self._members.sort(key = lambda x: x[2], reverse=True)
        
            if len(self) > self._max_size:
                self._drop_worst()

    def __len__(self):
        return len(self._members)

    def _drop_worst(self):        
        self._members = self._members[:self._max_size]

    def get_all(self):
        return self._members[:]

    def get_worst_score(self):
        return self._worst_score

    def get_best_score(self):
        if self._members:
            return self._members[0][2]


from IPython.display import display, HTML, clear_output
from itertools import chain


class TopTokens:
    def __init__(self, num_top_tokens=20):
        self.top_tokens = [TopTokenList(num_top_tokens) for _ in range(100_000)]

    def incorporate(self, tokens_minibatch, sae_activations_minibatch):
        """
        Expects two tensors:
         - tokens: dimension (MB, BS) where MB is (any) minibatch size and BS is (any) block size,
                   representing a minibatch of token strings given to an LLM.
         - sae_activations: dimension (MB, 100_000) where MB is as above.
                            Represents the activation of each neuron of the SAE hidden layer
                            when inspecting an LLM presented with the given tokens.
        Updates the top_tokens for each SAE neuron according to their activations.
        """
        for tokens, sae_activations in zip(tokens_minibatch, sae_activations_minibatch.permute(0, 2, 1)):
            for neuron_idx, top_token_list in enumerate(self.top_tokens):
                top_token_list.maybe_add(tokens, sae_activations[neuron_idx])

    def display(self):       
        clear_output(wait=True)
        lines = ['<h3>Top Tokens</h3>']
        for idx in chain(range(5), range(99_995, 100_000)):            
            if self.top_tokens[idx].get_worst_score() is None:
                lines.append(f'<strong>Neuron {idx}:</strong> awaiting tokens<br>')
            else:
                lines.append(f'<strong>Neuron {idx}</strong>: {len(self.top_tokens[idx])} tokens (range {self.top_tokens[idx].get_worst_score():.2f} - {self.top_tokens[idx].get_best_score():.2f}<br>')
        html_custom = ''.join(lines)
        display(HTML(html_custom))        
    

In [26]:
def all_sae_activations():
    sae = PhiProbeCommons.sae()
    for tokens, activations in PhiProbeCommons.all_activations(split='test'):
        _, sae_activations = sae(activations)
        yield tokens, sae_activations

In [27]:
top_tokens = TopTokens()
for tokens, sae_activations in all_sae_activations():
    print(sae_activations.shape)
    top_tokens.incorporate(tokens, sae_activations)
    top_tokens.display()
    

torch.Size([5, 100, 100000])


KeyboardInterrupt: 