In [4]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

In [5]:
def top_predictions(logits, n_tokens=10):
    """Return the ids and scores of the highest-scoring next tokens.

    Args:
        logits: Scores with the vocabulary on the last axis, either a torch
            tensor (on any device, with or without grad) or anything
            ``np.asarray`` accepts. Supported layouts:
            ``(vocab,)``              - used as-is;
            ``(seq, vocab)``          - the last position is used;
            ``(batch, seq, vocab)``   - the last position of the first batch
                                        element is used.
        n_tokens: How many top entries to return. If it exceeds the vocabulary
            size, every entry is returned.

    Returns:
        A ``(token_ids, scores)`` tuple of plain Python lists, ordered from
        highest to lowest score. The scores are the raw input values (logits);
        no softmax is applied.

    Raises:
        ValueError: If ``n_tokens`` is negative or ``logits`` is not 1-, 2- or
            3-dimensional.
    """
    if n_tokens < 0:
        raise ValueError(f"n_tokens must be non-negative, got {n_tokens}")

    # Torch tensors must be detached from the autograd graph and moved to the
    # CPU before NumPy can view them; other inputs are converted directly.
    if hasattr(logits, "detach"):
        logits = logits.detach().cpu().numpy()
    else:
        logits = np.asarray(logits)

    if logits.ndim == 1:
        next_token_logits = logits
    elif logits.ndim == 2:
        # (seq, vocab): indexing with [0, -1] here would pick a single scalar.
        next_token_logits = logits[-1]
    elif logits.ndim == 3:
        next_token_logits = logits[0, -1]
    else:
        raise ValueError(
            f"logits must be 1-, 2- or 3-dimensional, got {logits.ndim} dimensions"
        )

    # A slice of [-0:] would select the whole array, so handle zero explicitly.
    if n_tokens == 0:
        return [], []

    # argsort is ascending: keep the tail, then flip to get best-first order.
    top_tokens = np.argsort(next_token_logits)[-n_tokens:][::-1]
    top_tokens_scores = next_token_logits[top_tokens]

    return top_tokens.tolist(), top_tokens_scores.tolist()

In [6]:
# Load the pretrained GPT-2 Small weights wrapped in a HookedTransformer.
# `model` is a notebook-level global reused by the cells below.
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [18]:
 # Load an ICL sequence
# Few-shot prompt: each line is a headline followed by "// <label>" (Buy/Sell).
# The last line stops right after "//", leaving the label for the model to predict.
sequence = """Google missed earning by 10% this quarter. // Sell
Apple increased revenue by 5% this quarter. // Buy
OpenAI has fired the CEO Sam Altman after the board voted. // Sell
Tesla improved revenue by 20% after meeting with clients this week. // Buy
Ford anticipated its profit will decrease by 50%. //"""


# Run the model on the prompt; return_type="both" yields the logits and the loss
# (no activations are cached by this call).
# NOTE(review): the printed loss carries a grad_fn, so this forward pass tracks
# gradients; wrap it in `t.no_grad()` if nothing downstream needs them.
logits, loss = model(sequence, return_type="both")
print("Model loss:", loss)

# Disabled alternative: also cache activations and decode the argmax token at
# every position.
# logits, activations = model.run_with_cache(sequence)
# print(model.tokenizer.batch_decode(logits.argmax(dim=-1)[0]))

Model loss: tensor(4.1481, grad_fn=<NegBackward0>)


In [19]:
# Ten highest-scoring candidates for the token that follows the prompt, best first.
# NOTE: despite its name, `top_probs` holds raw logits (values above 1), not
# probabilities; no softmax has been applied.
top_tokens, top_probs = top_predictions(logits)
top_tokens, top_probs

([25688, 11763, 12324, 3677, 16467, 2822, 198, 8734, 16981, 3497],
 [17.737926483154297,
  17.673601150512695,
  13.972225189208984,
  13.425271987915039,
  12.479013442993164,
  12.439650535583496,
  12.399359703063965,
  12.196646690368652,
  11.861695289611816,
  11.84620189666748])

In [20]:
# Decode each top token id into its own string so the predictions are readable.
model.tokenizer.batch_decode(top_tokens)

[' Sell',
 ' Buy',
 ' Sold',
 ' sell',
 ' Sale',
 ' buy',
 '\n',
 ' Share',
 ' Ask',
 ' Get']