In [1]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

In [2]:
def top_predictions(logits, n_tokens=10):
    """Return the ids and raw logit values of the highest-scoring next tokens.

    Args:
        logits: A torch tensor (on any device, with or without grad) or an
            array-like of logits. The vector that gets ranked depends on rank:
              * 3-D: ``logits[0, -1]`` -- last position of the first batch item.
              * 2-D: ``logits[-1]`` -- treated as (position, vocab), last row.
              * 1-D: used as-is.
        n_tokens: Maximum number of entries to return. ``0`` yields two empty
            lists; values larger than the vocabulary return every entry.

    Returns:
        A ``(token_ids, values)`` tuple of plain Python lists, both ordered by
        descending logit. ``values`` are the raw logits of those tokens; no
        softmax is applied, so they are not probabilities.

    Raises:
        ValueError: If ``n_tokens`` is negative or ``logits`` is not 1-, 2- or
            3-dimensional.
    """
    if n_tokens < 0:
        raise ValueError(f"n_tokens must be non-negative, got {n_tokens}")

    # Torch tensors must be detached from autograd and moved to host memory
    # before conversion; anything else is handed to numpy directly.
    if hasattr(logits, "detach"):
        logits = logits.detach().cpu().numpy()
    else:
        logits = np.asarray(logits)

    # Pick the single vocabulary-sized vector to rank, depending on rank.
    if logits.ndim == 3:
        next_token_logits = logits[0, -1]
    elif logits.ndim == 2:
        next_token_logits = logits[-1]
    elif logits.ndim == 1:
        next_token_logits = logits
    else:
        raise ValueError(
            f"logits must be 1-, 2- or 3-dimensional, got {logits.ndim} dimensions"
        )

    # Reverse the ascending sort first, then truncate: slicing with
    # [-n_tokens:] would return the whole array when n_tokens == 0.
    top_tokens = np.argsort(next_token_logits)[::-1][:n_tokens]

    # Raw logit values of the selected tokens (name kept for continuity).
    top_tokens_probs = next_token_logits[top_tokens]

    return top_tokens.tolist(), top_tokens_probs.tolist()

In [3]:
# Load a model (eg GPT-2 Small)
# The returned HookedTransformer is bound to the notebook-global `model`,
# which the later cells call directly and use for its tokenizer.
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
 # Load an ICL sequence
# Few-shot (in-context learning) prompt: four revenue headlines, each followed
# by "// Buy" or "// Sell", then a fifth headline that stops right after "//"
# so the model's next-token prediction acts as its answer.
sequence = """Google decreased revenue by 50% this quarter. // Sell
Apple increased revenue by 30% this quarter. // Buy
OpenAI increased revenue by 80% this quarter. // Buy
Tesla increased revenue by 20% this week. // Buy
Ford decreased revenue by 40% this year. //"""


# Run the model on the prompt and get logits and loss
# (return_type="both" is what makes the call yield the pair unpacked below).
# NOTE(review): the printed loss carries a grad_fn, so this forward pass is
# tracking gradients; if nothing downstream needs them, consider running it
# under torch.no_grad() -- confirm before changing.
logits, loss = model(sequence, return_type="both")
print("Model loss:", loss)

# Alternative kept for reference (not executed): run with an activation cache
# and decode the argmax token at every position.
# logits, activations = model.run_with_cache(sequence)
# print(model.tokenizer.batch_decode(logits.argmax(dim=-1)[0]))

Model loss: tensor(2.8293, grad_fn=<NegBackward0>)


In [5]:
# Rank the next-token candidates for the prompt. The second list holds the raw
# logit values returned by top_predictions (no softmax is applied), so despite
# the name `top_probs` these are not probabilities.
top_tokens, top_probs = top_predictions(logits)
# Bare last expression so the notebook displays the (ids, values) pair.
top_tokens, top_probs

([11763, 25688, 9842, 12324, 198, 2822, 3677, 2094, 1514, 17329],
 [18.217296600341797,
  15.746993064880371,
  13.62832260131836,
  13.515268325805664,
  13.00149154663086,
  12.990470886230469,
  12.637158393859863,
  12.620811462402344,
  12.61329460144043,
  12.570591926574707])

In [6]:
# Decode each top token id into its text so the ranking is human-readable.
model.tokenizer.batch_decode(top_tokens)

[' Buy',
 ' Sell',
 ' Bu',
 ' Sold',
 '\n',
 ' buy',
 ' sell',
 ' Don',
 ' Go',
 ' Sales']