In [4]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

In [5]:
def top_predictions(logits, n_tokens=10):
    """Return the ids and raw scores of the highest-scoring next tokens.

    Args:
        logits: Torch tensor (or array-like) of logits in one of three layouts:
            ``(batch, seq, vocab)`` -- the final position of the *first* batch
            element is used; ``(seq, vocab)`` -- the final position is used;
            ``(vocab,)`` -- used as-is.
        n_tokens: Maximum number of tokens to return. Values <= 0 yield two
            empty lists; values larger than the vocabulary return every token.

    Returns:
        A ``(token_ids, scores)`` tuple of two equally long lists ordered from
        highest to lowest score. The scores are the raw logits; no softmax is
        applied, so they are not probabilities.

    Raises:
        ValueError: If ``logits`` has zero or more than three dimensions.
    """
    # Tensors may carry autograd history and live on an accelerator; numpy
    # needs a detached CPU copy. Anything else is treated as array-like.
    if hasattr(logits, "detach"):
        logits = logits.detach().cpu().numpy()
    else:
        logits = np.asarray(logits)

    # Select the vocabulary-sized vector for the final sequence position.
    if logits.ndim == 3:
        next_token_logits = logits[0, -1]
    elif logits.ndim == 2:
        next_token_logits = logits[-1]
    elif logits.ndim == 1:
        next_token_logits = logits
    else:
        raise ValueError(
            f"logits must have 1, 2 or 3 dimensions, got {logits.ndim}"
        )

    # Guard explicitly: a slice of [-0:] would select the whole vocabulary.
    if n_tokens <= 0:
        return [], []

    # argsort is ascending, so take the tail and reverse it for best-first.
    top_tokens = np.argsort(next_token_logits)[-n_tokens:][::-1]

    # Raw logit values of the selected tokens (name kept for continuity).
    top_tokens_probs = next_token_logits[top_tokens]

    return top_tokens.tolist(), top_tokens_probs.tolist()

In [6]:
# Load the pretrained "gpt2-small" checkpoint as a TransformerLens
# HookedTransformer. The resulting `model` global is used by the cells below
# for both the forward pass and (via `model.tokenizer`) token decoding.
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [8]:
# Load an in-context learning (ICL) prompt: three labelled examples followed by one unlabelled query
sequence = """Circulation revenue has increased by 5% in Finland. // Positive
Panostaja did not disclose the purchase price. // Neutral
Paying off the national debt will be extremely painful. // Negative
The company anticipated its operating profit to improve. //"""


# Run the model on the prompt and get the logits plus the model's loss on it.
# Only inference results are needed here, so autograd tracking is disabled to
# avoid building (and holding in memory) a backward graph. Remove the
# `no_grad` guard if gradients w.r.t. `logits`/`loss` are required later.
with t.no_grad():
    logits, loss = model(sequence, return_type="both")
print("Model loss:", loss)

Model loss: tensor(4.7088, grad_fn=<NegBackward0>)


In [9]:
# Inspect the raw logits returned by the forward pass. The displayed tensor is
# rank 3 -- presumably (batch, position, vocab); TODO confirm via logits.shape.
logits

tensor([[[ 7.5261, 11.1214,  7.8919,  ..., -3.1299, -3.3873,  8.5934],
         [ 5.7127,  4.4277,  5.2648,  ..., -1.0044, -2.2531,  4.2533],
         [ 6.7858,  6.5628,  4.9398,  ..., -0.8247, -2.1811,  6.5063],
         ...,
         [ 8.4301,  8.2613, -0.4150,  ..., -0.0392, -1.8792,  8.2315],
         [ 3.2394,  5.5793,  4.1027,  ..., -3.4902, -3.4566, 11.1141],
         [ 3.9121,  6.0483,  3.1660,  ..., -4.0318, -5.2399,  7.6598]]],
       grad_fn=<AddBackward0>)

In [10]:
# Candidate token ids for the position right after the prompt, best first
# (top_predictions defaults to n_tokens=10).
# NOTE: despite its name, `top_probs` holds raw logit values, not
# probabilities -- top_predictions applies no softmax.
top_tokens, top_probs = top_predictions(logits)
top_tokens, top_probs

([25627, 33733, 36183, 4633, 8500, 3967, 13496, 9576, 13535, 198],
 [16.663612365722656,
  15.992452621459961,
  15.790984153747559,
  12.728263854980469,
  12.549846649169922,
  12.500576972961426,
  11.822123527526855,
  10.926145553588867,
  10.884611129760742,
  10.776315689086914])

In [11]:
# Turn the candidate token ids back into text; each id in the list is decoded
# on its own, giving one string per candidate.
model.tokenizer.batch_decode(top_tokens)

[' Neutral',
 ' Positive',
 ' Negative',
 ' negative',
 ' neutral',
 ' positive',
 ' Neg',
 ' Very',
 ' Strong',
 '\n']