In [321]:
import torch
import os 
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import importlib
import utils 
importlib.reload(utils)
from utils import load_jsons, format_conversations_with_thinking, extract_activations_from_conversation, process_all_conversations

# Run configuration.
# NOTE(review): `model_name` is not referenced by the model-loading cell in this
# notebook, which reads a local Qwen3-1.7B checkpoint path instead — confirm which
# checkpoint is intended.
model_name = 'Qwen/Qwen3-8B'  
# Device the model and each input batch are moved to.
device = 'cuda'

In [343]:
# Fixed token length per sequence: passed as `max_length` when tokenizing (with
# truncation and max-length padding) and used as the sequence dimension of the
# activations buffer.
seq_len = 500
# start = model.model.config.bos_token_id
# end = model.model.config.eos_token_id

In [None]:
# Local checkpoint directory shared by the model and its tokenizer.
checkpoint_dir = '/workspace/models/Qwen3-1.7B'

# Load the causal LM weights in bfloat16, then move the module to the target device.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_dir,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True,
)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Load the raw conversation JSON files and normalise them with the project helper.
convos = load_jsons('data')
convos_formatted = format_conversations_with_thinking(convos)

# Render every conversation to a single chat-template string (no tokenization yet)
# and rewrite the <thinking> tags to the <think> form.
convos_chatml_list = [
    tokenizer.apply_chat_template(convo['conversation'], tokenize = False)
    .replace("<thinking>", "<think>")
    .replace("</thinking>", "</think>")
    for convo in convos_formatted
]


Loaded 12 JSON files


In [344]:
# Tokenizer options: every sequence is cut or padded to exactly `seq_len` tokens,
# results come back as PyTorch tensors, and character offsets are returned so each
# token can be mapped back to its source text span.
tokenizer_kwargs = {
    'truncation': True,
    'max_length': seq_len,
    'padding': 'max_length',
    'return_tensors': 'pt',
    'return_offsets_mapping': True,
}
tokenized_convos = tokenizer(convos_chatml_list, **tokenizer_kwargs)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import random

# Tells the DataLoader how to combine per-sample dicts into one batch when batch size > 1; otherwise list-valued fields come back transposed.
def stack_collate(batch):
    """Merge a list of per-sample dicts into one batch dict.

    For every key of the first sample, the values from all samples are gathered in
    order. Tensor values are stacked along a new leading dimension; any other value
    type is returned as a plain list with one entry per sample, so nested lists keep
    their original per-sample structure.
    """
    def _combine(values):
        # The first value decides the strategy for the whole field.
        if torch.is_tensor(values[0]):
            return torch.stack(values, dim = 0)
        return values

    return {key: _combine([sample[key] for sample in batch]) for key in batch[0]}

class textDatasetClass(Dataset): 
    """Dataset pairing raw texts with their tokenization and per-token source strings.

    Args:
        texts: sequence of source strings, one per example.
        input_ids: indexable of token ids, one entry per example.
        attention_mask: indexable of attention masks, one entry per example.
        offset_mapping: per-example list of (start, end) character offsets into the
            matching text; either a tensor (anything exposing `.tolist()`) or
            already-nested lists.

    Attributes:
        og_tokens_list: for each example, the list of substrings
            `texts[i][start:end]`, one per offset pair. A zero-width offset such as
            (0, 0) yields the empty string.
    """
    def __init__(self, texts, input_ids, attention_mask, offset_mapping):
        self.texts = texts
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.offset_mapping = offset_mapping 

        # Convert once to plain Python ints; slicing a str with tensor elements is
        # slow, and the previous code indexed the builtin `map` here instead of the
        # offsets, which fails on a fresh kernel.
        offsets = offset_mapping.tolist() if hasattr(offset_mapping, 'tolist') else offset_mapping

        # word tokens list 
        og_tokens_list = []
        for seq_idx, seq_offsets in enumerate(offsets): 
            text = texts[seq_idx]
            og_tokens_list.append([text[start:end] for start, end in seq_offsets])

        self.og_tokens_list = og_tokens_list
        
    def __len__(self): 
        """Number of examples (one per source text)."""
        return len(self.texts)

    def __getitem__(self, idx): 
        """Return the text, token ids, attention mask and token strings for example `idx`."""
        return {'texts': self.texts[idx], 
                'input_ids': self.input_ids[idx], 
                'attention_mask': self.attention_mask[idx],
                'og_tokens_list': self.og_tokens_list[idx]} 

In [225]:
# Wrap the texts and their tokenization in a Dataset, then batch it.
# NOTE(review): `raw_truths` and `tokenized_truths` are not defined in any cell
# visible here — presumably they come from an earlier or removed cell; confirm they
# exist on a fresh kernel and that `tokenized_truths` includes 'offset_mapping'.
heplo = textDatasetClass(raw_truths, tokenized_truths['input_ids'], tokenized_truths['attention_mask'], tokenized_truths['offset_mapping'])
# Shuffled batches of 3; `stack_collate` stacks tensor fields and keeps list fields as lists.
loader = DataLoader(heplo, batch_size = 3, shuffle = True, collate_fn = stack_collate)

In [322]:
# Collect per-batch hidden states and the matching token strings, in loader order.
# Chunks are gathered in a list and concatenated once at the end: concatenating onto
# a growing tensor inside the loop re-copies everything seen so far on every batch.
activation_chunks = []
tokens_list = []
with torch.no_grad():
    for batch in loader: 
        outputs = model(input_ids = batch['input_ids'].to(device),
                        attention_mask = batch['attention_mask'].to(device),
                        output_hidden_states = True)
        # Stack the tuple of per-layer hidden states on a new axis 2, giving
        # (batch, seq, num_hidden_states, hidden). Move to CPU first, then upcast to
        # float32 so the result dtype matches concatenation onto a float32 buffer.
        batch_activations = torch.stack(outputs['hidden_states'], dim = 2).detach().cpu().float()
        activation_chunks.append(batch_activations)
        tokens_list.extend(batch['og_tokens_list']) # stapling along an existing dimension, just like with activations 
        # break # this will now break after the first guy :p 

if activation_chunks:
    activations = torch.cat(activation_chunks, dim = 0)
else:
    # Empty loader: keep the same empty-but-correctly-shaped result as before.
    activations = torch.empty((0, seq_len, model.model.config.num_hidden_layers + 1, model.model.config.hidden_size))

In [323]:
# Flatten the nested per-sequence token strings into one record per token, keyed by
# (sequence index, token position). 'entity' starts as an empty-list placeholder.
tok_locs = [
    {'seq_idx': seq_idx, 'tok_idx': tok_idx, 'token': token, 'entity': []}
    for seq_idx, seq in enumerate(tokens_list)
    for tok_idx, token in enumerate(seq)
]

tok_df = pd.DataFrame(tok_locs)

In [324]:
# Previous tokens are looked up *within each sequence*. A plain `.shift()` on the
# whole frame would compare the first tokens of a sequence against the tail of the
# preceding sequence.
prev_tok = tok_df.groupby('seq_idx')['token'].shift(1)
prev2_tok = tok_df.groupby('seq_idx')['token'].shift(2)
boundary_tokens = ['<|im_start|>', '<|im_end|>']
role_tokens = ['user', 'assistant']

# A role name directly after <|im_start|> opens that speaker's turn.
conditions = [
      (tok_df['token'] == 'user') & (prev_tok == '<|im_start|>'),
      (tok_df['token'] == 'assistant') & (prev_tok == '<|im_start|>')
  ]
choices = ['user', 'assistant']

# Mark turn openers, then carry the speaker forward to the rest of the turn.
# The explicit index keeps the assignment aligned even if tok_df is not 0..n-1.
tok_df['entity'] = pd.Series(np.select(conditions, choices, default = None), index = tok_df.index)
tok_df['entity'] = tok_df.groupby('seq_idx')['entity'].ffill()
# NOTE(review): any rows after the last turn of a sequence (e.g. padding positions)
# inherit the last speaker from the forward fill — confirm whether they should be
# masked out downstream.

# remove entities 
# NOTE(review): target_tokens is not used below; kept so the name stays defined.
target_tokens = ['<|im_start|>', '<|im_end|>', '<pad>']

# Chat-template scaffolding gets no speaker label. All four rules depend only on the
# token text, so they are combined into one mask and applied in a single pass.
is_template_token = (
    tok_df['token'].isin(boundary_tokens)                                   # replace start/end
    | (prev_tok.isin(boundary_tokens) & tok_df['token'].isin(role_tokens))  # replace user/assistant 
    | ((tok_df['token'] == '\n') & (prev2_tok == '<|im_start|>'))           # newline after the role name
    | ((tok_df['token'] == '\n') & (prev_tok == '<|im_end|>'))              # newline after <|im_end|>
)
tok_df = tok_df.assign(entity = np.where(is_template_token, None, tok_df['entity']))