# SoLU Circuits

## Imports

In [6]:
from src.solu_utils import model, tokenizer
import torch
import torch.nn.functional as F
from IPython.display import clear_output

# Clear whatever this cell printed while importing (presumably model download/load
# progress from src.solu_utils) so the saved notebook output stays clean.
clear_output()

## Find interesting activations

A 1-layer model without an MLP can't do much more than learn skip-trigrams. The MLP layer in this model may improve on that a little, but the prompts still need to have quite simple answers.

In this case we'll look at the model's ability to close HTML tags. As a simple overview of how HTML tags work: whenever a tag is opened (e.g. `<b>` for bold), it must be closed when you no longer want it to apply (e.g. `<b>bold text</b> normal text`).

Note that `</` is a single token, so we can't end a prompt with `<` and expect the model to predict `/`.

In [4]:
def get_next_token(prompt: str) -> str:
    """Greedily predict the token that follows ``prompt``.

    Runs a single forward pass of the global ``model`` and decodes the
    highest-scoring token at the final position.

    Args:
        prompt: Text to feed to the model.

    Returns:
        The decoded string of the most likely next token.
    """
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        # NOTE(review): assumes model(prompt)[0] is (batch, pos, vocab) logits with
        # batch == 1, as implied by the argmax over dim 2 below - TODO confirm.
        logits = model(prompt)[0]
    # log_softmax is monotonic per position, so the argmax of the raw logits is the
    # same greedy choice; no need to normalise first.
    predictions = torch.argmax(logits, dim=2)
    # Only the final position predicts the *next* token, so decode just that one.
    # Indexing the flattened tensor also handles single-token prompts, where the
    # old `squeeze()` + iterate produced a 0-d tensor that cannot be iterated.
    next_token_id = predictions.flatten()[-1].item()
    return model.tokenizer.decode(next_token_id)

In [42]:
# Example prompts to run through the model
prompts = [
    "<h1>Title",
    "<b>Some bold text</",
    "<p>An interesting paragraph</",
    "<table><tr><th>Model name",
]

# How many tokens the model greedily appends to each prompt
additional_tokens = 2

# Run each prompt, appending the model's prediction one token at a time so each
# new token is conditioned on the previously generated ones.
for prompt in prompts:
    result = prompt
    for _ in range(additional_tokens):
        next_token = get_next_token(result)
        result += next_token
    print(result)

<h1>Title</h
<b>Some bold text</b>
<p>An interesting paragraph</p>
<table><tr><th>Model name</th
