In [29]:
import torch
from transformer_lens import HookedTransformer


device = "cuda"


model = HookedTransformer.from_pretrained_no_processing(
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    dtype=torch.bfloat16,
    device=device,
)

#%% Test caching activations
prompt = "The Eiffel Tower is in the city of"
tokens = model.to_tokens(prompt)
# logits = model(tokens)
# model.to_str_tokens(logits.argmax(dim=-1)[-1])

# model.eval()
# with torch.no_grad():
#     logits, cache = model.run_with_cache(
#         tokens,
#         names_filter=lambda name: "blocks.0.hook_resid_post" in name,
#         remove_batch_dim=True,
#         return_type="logits",
#     )

# #%%
# actv = cache["blocks.0.hook_resid_post"]
# print(actv.shape)
# actv.device

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model deepseek-ai/DeepSeek-R1-Distill-Llama-8B into HookedTransformer


In [30]:
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
#%% Loading the dataset
ds = load_dataset("andreuka18/OpenThoughts-10k-DeepSeek-R1")

In [31]:
# from sae_lens import SAE
# from sae_lens.config import DTYPE_MAP as DTYPES
# from sae_lens.sae import TopK
# from sae_dashboard.feature_data_generator import FeatureMaskingContext

# sae_path = "andreuka18/deepseek-r1-distill-llama-8b-lmsys-openthoughts"
# sae = SAE.load_from_pretrained(sae_path, device=device)

In [88]:
def classify_tokens(reasoning):
    st = model.to_str_tokens(reasoning)
    classes = []
    prev_nl = False
    for i, s in enumerate(st):
        cls = "normal"
        nl = s[-1] == "\n"
        if s.lower().strip() in ["but", "wait"]:
            if prev_nl:
                cls = "backtrack"
                nl = True
            else:
                cls = "other"
        if prev_nl and s.lower().strip() in [","]:
            nl = True
        prev_nl = nl
        classes.append(cls)
    return classes


def highlight_reasoning(reasoning, classification):
    st = model.to_str_tokens(reasoning)
    html = ""
    for s, cls in zip(st, classification):
        color = "red" if cls == "backtrack" else "blue" if cls == "other" else "white"
        html += f"<span style='color: {color};'>{s}</span>"
    from IPython.display import display, HTML
    display(HTML(html))


# reasoning = ds["train"][0]['deepseek_reasoning']
# classification = classify_tokens(reasoning)
# highlight_reasoning(reasoning, classification)


hook_name = "blocks.19.hook_resid_post"







In [41]:
from tqdm.auto import tqdm

torch.cuda.empty_cache()

def get_actv_and_class_int(count=None, batch_size=1):
    all_actv = []
    all_class_int = []
    count = count or len(ds["train"])
    for i in tqdm(range(0, count, batch_size)):
        batch = ds["train"][i:min(i+batch_size, count)]

        # strip text to just the problem and reasoning section
        texts = list(map(lambda r: r[:(r.index("</think>"))], batch['text']))
        tokens = model.to_tokens(texts)

        # Classify tokens
        classification = list(map(classify_tokens, texts))
        class_int = torch.zeros(tokens.shape[0], tokens.shape[1], dtype=torch.int64)
        for i, cls in enumerate(classification):
            for j, c in enumerate(cls):
                class_int[i, j] = 0 if c == "normal" else 1 if c == "backtrack" else 2
        # highlight_reasoning(texts[0], classification[0])

        # Temp hack to save on memory
        truncate = 5000 // batch_size
        tokens = tokens[:, :truncate]
        class_int = class_int[:, :truncate]

        # Evaluate model
        model.eval()
        with torch.no_grad():
            logits, cache = model.run_with_cache(
                tokens,
                names_filter=lambda name: hook_name in name,
                return_type=None,
            )

        actv = cache["blocks.19.hook_resid_post"] # batch, seq, d_model
        all_actv.append(actv.view(-1, actv.shape[-1]))
        all_class_int.append(class_int.reshape(-1))
        del cache,actv,texts, tokens, classification, class_int
        torch.cuda.empty_cache()


    all_actv = torch.cat(all_actv, dim=0)
    all_class_int = torch.cat(all_class_int, dim=0)

    return all_actv, all_class_int

all_actv, all_class_int = get_actv_and_class_int(1000, batch_size=5)
print(all_actv.shape)
print(all_class_int.shape)
print((all_class_int == 1).sum())
DIR = "/workspace"
torch.save({"actv": all_actv, "class_int": all_class_int}, f"{DIR}/data.pt")




  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([1000000, 4096])
torch.Size([1000000])
tensor(2448)


In [42]:
def train_probe():
    X = all_actv.cpu().to(torch.float32).numpy()
    y = all_class_int.cpu().to(torch.float32).numpy()
    from sklearn.preprocessing import OneHotEncoder

    # enc = OneHotEncoder()
    #y = enc.fit_transform(y.reshape(-1, 1)).toarray()



    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    probe = LogisticRegression()
    probe.fit(X_train, y_train)

    y_pred = probe.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print("Accuracy: ", accuracy)



    accuracy = accuracy_score(y[y == 1], probe.predict(X[y == 1]))
    print("Accuracy on backtrack tokens: ", accuracy)

    return X, y, probe

X, y, probe = train_probe()



Accuracy:  0.999945
Accuracy on backtrack tokens:  0.9922385620915033


In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader, TensorDataset

# # Define the model
# class Classifier(nn.Module):
#     def __init__(self):
#         super(Classifier, self).__init__()
#         self.fc = nn.Linear(all_actv.shape[1], 2)

#     def forward(self, x):
#         x = self.fc(x)
#         return x

# # Create the model, loss function and optimizer
# model = Classifier()
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=0.01)

# # Create the dataset and dataloader
# dataset = TensorDataset(all_actv, all_class_int)
# dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

# # Train the model
# for epoch in range(1000):
#     for inputs, labels in dataloader:
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()

# # Save the model
# torch.save(model.state_dict(), 'model.pth')


'\nYour role as an assistant involves thoroughly exploring questions through a systematic long thinking process\nbefore providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of\nanalysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered\nthinking process.\n<｜User｜>Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition.To defend her castle Count Mishti has invented a new challenge. In this challenge, the participant goes to a room that has n corridors and each corridor has n cells. Each cell has some coins of gold in it i.e. the jth cell of the ith corridor has a[i][j] gold coins ( 1 ≤ i ≤ n && 1 ≤ j ≤ n).\n\nThe participant starts at the cell [1][1] and ends at the cell [N][N]. Therefore, a[1][1]=a[N][N]=0. She may move either to the cell directly in front of hi

In [116]:

def run_with_dir(dir, index=0):
    text = ds["train"][index]['text']
    question = text[:(text.index("<think>") + len("<think>"))]

    def hook_fn(resid, hook):
        resid += dir

    model.eval()
    with torch.no_grad():
        model.reset_hooks()
        model.add_hook(hook_name, hook_fn)
        response = model.generate(question, max_new_tokens=1000)
    return response[len(question):]

def summarize_response(response):
    response_classification = classify_tokens(response)
    highlight_reasoning(response, response_classification)
    print(repr(response))
    print("normal tokens: ", len([c for c in response_classification if c == "normal"]))
    print("backtrack tokens: ", len([c for c in response_classification if c == "backtrack"]))
    print("other tokens: ", len([c for c in response_classification if c == "other"]))
    print("newlines: ", len([c for c in response if c == "\n"]))




In [None]:
C = torch.tensor(probe.coef_).to(device)
print(C.shape)

orig_response = run_with_dir(0)
response = run_with_dir(-C[1] * 5  + C[2] * 5)


In [124]:

response2 = run_with_dir(-C[1] * 5  + C[2] * 5)





  0%|          | 0/1000 [00:00<?, ?it/s]

In [117]:
summarize_response(orig_response)

"\nYour role as an assistant involves thoroughly exploring questions through a systematic long thinking process\nbefore providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of\nanalysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered\nthinking process.\n<｜User｜>Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition.To defend her castle Count Mishti has invented a new challenge. In this challenge, the participant goes to a room that has n corridors and each corridor has n cells. Each cell has some coins of gold in it i.e. the jth cell of the ith corridor has a[i][j] gold coins ( 1 ≤ i ≤ n && 1 ≤ j ≤ n).\n\nThe participant starts at the cell [1][1] and ends at the cell [N][N]. Therefore, a[1][1]=a[N][N]=0. She may move either to the cell directly in front of hi

In [125]:
summarize_response(response)
summarize_response(response2)

'\nOkay, I have to solve this problem. Let\'s understand the problem first.\n\nSo, Mishti\'s castle has a corridor that\'s like a grid. The participant starts at (1,1) and has to reach (n,n). The cells are in a grid where each cell has some gold coins. The participant can move either down (from cell i,j to i+1,j) or right (from i,j to i,j+1) because she can\'t go back. But she can\'t go beyond the grid.\n\nThe goal is to find the maximum number of coins she can collect before leaving the corridor. Or, if she can\'t collect at least C coins, return -1.\n\nWait, wait. Wait. I think I need to clarify the problem. Because I read the problem statement again, it says, when she arrives at [N][N], she gives the guard AT LEAVING, so she must have at least C coins. So, the problem is: she can collect as much as possible, passing through all possible cells she can, but the sum of all coins collected before (but not including) the last cell. Because when she is at the start, [1][1], she has a[1][1

'\n\n heee only the wait in there, the seems. Sp wait. Hmm, the I is no, of them. So by things is, only am they no, need they only to wait.\n\n "\n\tconsole it makes took more. Because he no.\n\n 万}}} not have he could probably. "Wait, fly let they make open like since can\'t no, he need run let.\n\n No, maybe the just they can\' it wait, in the west cos only if of can but can\'t let can\'t.\n\n Hmm no, the wait\'s there to use. Polling.\n\n It\'s tough, maybe is there since a whole.\n\n Wait if the start must mine it. Some call-- a no.\n checkbox. Wait, partial when it\'s aligned since with the other state got kissing— No.\n\n Maybe it speaks found in the way it is. Elegant, perhaps.\n\n\nTurns if however variation. No, exceed p.\n\n Ducts actually, seems rule isn\'t constructions. Or, it pond, balance.\n\n No, heh. It might just in consistent.\n\n Not sure, maybe diversions—are\n\n Only that."\n\n Maybe tell the but needs is in, where. It seems it, but can\'t, it if. Maybe etc. Could

In [120]:
# Use SAE features
x = torch.load(DIR + "/feature_directions.pth", weights_only=True).to(device)

for i in range(0, 3):
    print(i)
    dir = x[:, i]
    response = run_with_dir(dir)
    summarize_response(response)

0


  0%|          | 0/1000 [00:00<?, ?it/s]

"\nOkay, so I have this programming problem to solve, and I need to think it through carefully. Let me read the problem again to make sure I understand it.\n\nThe problem is about a participant named Mishti defending her castle with a challenge. The participant starts at cell [1][1] and must reach [N][N]. Each corridor has N cells, and each cell has some gold coins. The participant can only move right or down, meaning from cell [i][j], she can go to [i+1][j] or [i][j+1]. She can't go back, and she can't go out of the grid.\n\nWhen she arrives at [N][N], she has to give the guard at least C coins. The goal is to find the maximum number of coins she can be left with. If she can't collect enough, we output -1.\n\nAlright, so the input is T test cases. For each test case, we get n (the size of the grid) and c (the required coins), followed by n lines each with n integers representing the coins in each cell.\n\nLooking at the sample input:\n\n2\n2 7\n0 0\n8 0\n2 5\n0 1\n3 0\n\nSample Output

  0%|          | 0/1000 [00:00<?, ?it/s]

'\nOkay, so I have this problem to solve. Let me read it carefully.\n\nThe problem is about Count Mishti\'s castle challenge. The participant starts at cell [1][1] and ends at [N][N]. Each corridor has N cells, so it\'s like a grid with n rows and n columns. Each cell has some gold coins, given by a[i][j]. The participant can only move right or down. So from cell [i][j], they can go to [i+1][j] or [i][j+1], which makes sense.\n\nThe catch is that when the participant arrives at [N][N], she has to give the guard at least C coins. If she doesn\'t have enough, she can\'t pass. We need to find the maximum number of coins she can give, which is equivalent to the maximum coins she can collect on the path minus C. But wait, no, the way the problem is phrased is that she needs to have at least C coins. So it\'s more about ensuring that the path she takes allows her to have enough coins. So the objective is to collect as much coins as possible along a path from start to end, and then subtract C

  0%|          | 0/1000 [00:00<?, ?it/s]

'\nOkay, so I\'ve got this programming problem to solve. Let\'s see what it\'s about. Hmm, the problem description says that Count Mishti has this castle challenge. The participant starts at cell [1][1], which is the top-left corner, and needs to reach [N][N], which is the bottom-right corner. Each cell has some coins, and the participant can only move either right or down. So the movement is restricted to those two directions.\n\nThe goal is to find the maximum coins the participant can have left when they reach the end. But wait, when they reach [N][N], they have to give the guard at least C coins. If they don\'t have enough, the result is -1. So the function needs to check if the collected coins minus the required C are >=0. If yes, that\'s the answer. Otherwise, -1.\n\nAlright, so the problem is essentially about finding the optimal path from (1,1) to (N,N) where you can only move right or down, and you want to maximize the coins left after subtracting C.\n\nWait, no, wait. Wait, l

In [122]:
# Run a second time, just to see variation
for i in range(0, 3):
    print(i)
    dir = x[:, i] * 3
    response = run_with_dir(dir)
    summarize_response(response)

0


  0%|          | 0/1000 [00:00<?, ?it/s]

"\n\nOkay, let's try to figure out this problem. So, Mish jest tomu_reaction_ps classnameInline.\n\nFrom the problem description, the user is saying that our team has to solve this problem where we need to find the maximum number of coins turns for the participant.\n\nThe problem says, the participant starts at (1,1) and needs to eliminate more that 'C' of his own reports, looking at the previous coercion.\n\nWait, the actual dimensions are a bit picturesque. So, the problem is similar to some of the most recent.\n\nWait, the user Vaní explanation and the idea of being, can I think of more.\n\nWait, More likely, the problem is first to find possible paths from (1, I think I need to model the problem correctly.\n\nSo, the semantics of the problem says that create dounds.\n\nThe problem is that, the problem is that the user is Mandašil.\nWe need to find The سریال.\n\nTusinati the Möatersita?\n\nYes, the problem is bevelled and esthmintote. So, The approach should be something,uggestion, 

  0%|          | 0/1000 [00:00<?, ?it/s]

'\n\n, 1   .\n\n    in 1\n\n\n\' 2 k\'r\n\n\n351\n\n4\n\n   22\n  \nfi\n\nh\n\n\n_  \n\n\nycling 1\') 9\n\n\n knives 7\n\n   7\n\n\n    01 \'because\n\n"We represent we can approach      )\n\n\'2\n\n1\n\n    �\n\n_sub         \n\n\n\n\n\n\n\n\nStarting 1\n\n\non\'t\n\n\n\n\n\nPutting 1\n\n\n\n\n\n      \n\n      \n\n   1    Ichange\n.\n\n123\n\n\n   3 5   \n\n 3_    5\' 1\n\n\nAss\n\n 3\n\n  5 \n\n.\n\n  автор…\n\n\n\n   \n.\n\n\n\n\n\nThe 2\n\n\nThe  f_IMI \n      3 5 , 1\n\n\n   .\n\n\nI\n\n\n   3 5 1.\n\nwith .\n\n\n\n\n\n酒\n\n\n  .\n\n\n\n\n\nthis  ta because 78…\n\n##\n\n\n\n\n\nNow,40\n\n\ndisc…\n\n lover \n\n_PASS\n\n\nusers �orm_\n\n 7.\nThat\n\nI\n\n\n_. I).\n\n\njust the outcomes possible\n\n\n thinks\n\n\nmaybe,\n\nbecause 1\n\n\n Mexico, me\n\nAss, re.) 12\n\nJune.\n\n\n\n\nLique 1   to (k\n\nIci   a\n\n\n   do\n\n\n 5.\n\n. 3\n\nUsing\n\n\n\n\n\na\n\n)\n\n\nRec\n\n   +1\n\n\t 5  \n\n\n→ 4.\n\n59.\n\nI\n\nI\n\nBut I.\n\n.\n\n\nBecause 1 \n\n\n that d.\n\nHelp.\n\n.\n\n1.%</

  0%|          | 0/1000 [00:00<?, ?it/s]

'\nox\nIncre\n * generators\nopic\n mp\n e**\n wor\n in\n \nduplicate\nics\ncontrbio\n gr\n InProgress\nop_PLL\n.\n  i\n   Alien\n pain\n opposite\n y etk\n  My\n l contr _.\n 0\n a\nop\n  t\n characteristics\n  wor\n  1\n  i\n  in\n  in\n  c\n  toll\n  in\n  in .\n\n\nI es\n  Low\n.\n UC\n contrary\n ei\n opposite\n  4\n  in\n  i\n  r\n  eia\n  pain\n  in\n  \n  d\n  in\n  r\n  be the\n  o\n  one\n  i\n  u\n  r backwards\nballs\n iM\n or .  stunning\nin\n  n\n  stab\n v\n  opposite\n  in\n  in\n  in\n  izo\n  ane\n  inv �r\n  analogue\n  * in\n  in\n  in.\n contrary\n  t\n  again\n  in\n  in\n in.\n the live\n  7 zero\n  in\n  in \n ar\n  opr\n in\n  in.\n  the\n uce 16\nistsiu\n  in\n  in\n  in.\n  u\n  in\n  the\n  contrary\n the least  in \n  in.\n in.\n  in.\n in.\n many\n desirable\n in\n  In the contrary\n in.\n  in.\n  , ch\n  in.\n  in.\n  in.\n  in.\n layouts 4\n  in\n  in.\n  in\n  in.\n  in\n  in.\n  in.\n in.\n  in.\n  in\n  in.\n  in.\n  in.\n  in\n  i\n \tc  think.\n  in

In [123]:
# Run a second time, just to see variation
for i in range(0, 3):
    print(i)
    dir = x[:, i] * 2
    response = run_with_dir(dir)
    summarize_response(response)

0


  0%|          | 0/1000 [00:00<?, ?it/s]

"\nBigIntegerDestroy\n\nThe function should be structured as an open override. Let's begin with the all set up is the basis.\n\nOkay, so the problem is to find the maximum number of coins the participant can have when she can't meet the requirement, the next thing is to determine the minimal required paths.\n\nWait, the problem says that when she makes two possible moves, let me clarify. If she can either go down or right once a time, but that's not correct. Wait, the movement is either from [i,j] to [i+1] or [i,this summary suggests that each move is either down or right, which will form a path from (1,1) to (2,a midway path of such.\n\nThe key point here is to determine the number of steps she can take, but more importantly, how the paths can compute the minimum total of the paths that he can remember, but let's look at the constraint. The optimal approach is to consider that he must collect a certain number of the paths when she arrives at (1,1)\n\nWait, no. The problem is that he s

  0%|          | 0/1000 [00:00<?, ?it/s]

'\nOkay, I have to solve this problem. Let\'s understand the problem first.\n\nSo, Mishti\'s castle has a corridor that\'s like a grid. The participant starts at (1,1) and has to reach (n,n). The cells are in a grid where each cell has some gold coins. The participant can move either down (from cell i,j to i+1,j) or right (from i,j to i,j+1) because she can\'t go back. But she can\'t go beyond the grid.\n\nThe goal is to find the maximum number of coins she can collect before leaving the corridor. Or, if she can\'t collect at least C coins, return -1.\n\nWait, wait. Wait. I think I need to clarify the problem. Because I read the problem statement again, it says, when she arrives at [N][N], she gives the guard AT LEAVING, so she must have at least C coins. So, the problem is: she can collect as much as possible, passing through all possible cells she can, but the sum of all coins collected before (but not including) the last cell. Because when she is at the start, [1][1], she has a[1][1

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 