In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import copy
import torch.nn as nn
from typing import Any, Callable, Literal, TypeVar, overload
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """Sparse autoencoder (SAE) over activation vectors.

    Encoding subtracts the decoder bias, applies an affine map and a ReLU to get a
    non-negative code ``z`` of width ``hidden_dim``; decoding maps ``z`` back with
    ``z @ W_dec + b_dec``. Every operation acts on the last dimension, so inputs may
    carry any number of leading (batch / position) dimensions.

    The attribute names ``encoder`` (a ``Sequential`` whose first layer is the
    ``Linear``), ``W_dec`` and ``b_dec`` define the state-dict keys
    (``encoder.0.weight``, ``encoder.0.bias``, ``W_dec``, ``b_dec``); renaming them
    would break loading of saved checkpoints.

    Args:
        input_dim: width of the activation vectors being reconstructed.
        hidden_dim: number of dictionary features (width of ``z``).
        l1_lambda: weight of the L1 sparsity penalty in :meth:`loss`; 0 disables
            it (the author's note suggests 1e-4).
    """

    def __init__(self, input_dim, hidden_dim, l1_lambda=0):  # 1e-4
        super().__init__()
        self.l1_lambda = l1_lambda
        self.mse_loss_fn = nn.MSELoss()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )

        # Stored as (hidden_dim, input_dim) so decoding is the right-multiplication
        # ``z @ W_dec``.
        self.W_dec = nn.Parameter(
            nn.init.kaiming_uniform_(torch.empty(hidden_dim, input_dim))
        )
        self.b_dec = nn.Parameter(torch.zeros(input_dim))

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Returns:
            ``(x_hat, z)`` — the reconstruction (same shape as ``x``) and the
            non-negative sparse code.
        """
        # Centre the input on the decoder bias before encoding.
        z = self.encoder(x - self.b_dec)
        return self.decode(z), z

    def decode(self, z):
        """Map a code ``z`` (last dim ``hidden_dim``) back to activation space."""
        return z @ self.W_dec + self.b_dec

    def loss(self, x, x_hat, z):
        """Compute the training objective.

        Returns:
            ``(total_loss, mse_loss, l1_loss)`` where ``mse_loss`` is the mean
            squared reconstruction error (averaged over all elements),
            ``l1_loss`` is the sum of ``|z|`` over all elements (not averaged, so
            its scale grows with batch size), and
            ``total_loss = mse_loss + l1_lambda * l1_loss``.
        """
        mse_loss = self.mse_loss_fn(x_hat, x)
        # L1 regularisation over the flattened code; torch.linalg.vector_norm is the
        # documented replacement for the deprecated torch.norm.
        l1_loss = torch.linalg.vector_norm(z, ord=1)
        total_loss = mse_loss + self.l1_lambda * l1_loss
        return total_loss, mse_loss, l1_loss


SAE_INPUT_DIM = 896        # must equal the hidden size of the LLM layer the SAE is applied to
SAE_EXPANSION_FACTOR = 20  # dictionary width = SAE_INPUT_DIM * SAE_EXPANSION_FACTOR
SAE_CHECKPOINT = './model/20250328-035329/best_model.pth'

best_SAE_model = SparseAutoencoder(
    input_dim=SAE_INPUT_DIM, hidden_dim=SAE_INPUT_DIM * SAE_EXPANSION_FACTOR
).cuda().eval()
# The checkpoint is a plain state dict of tensors, so refuse to unpickle arbitrary
# Python objects (torch.load is pickle-based).
best_SAE_model.load_state_dict(torch.load(SAE_CHECKPOINT, weights_only=True))


<All keys matched successfully>

In [11]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval().cuda()


def token_id_key(text):
    """Return the id of the single token that ``text`` encodes to, as a string.

    The id is returned as ``str`` because it is used to look up
    ``token_feature_count``, which is loaded from JSON (JSON object keys are
    always strings).

    Raises:
        ValueError: if ``text`` does not encode to exactly one token, instead of
            silently keeping only the first piece of a multi-token word.
    """
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(ids) != 1:
        raise ValueError(f"{text!r} encodes to {len(ids)} tokens {ids}; expected exactly one")
    return str(ids[0])


token_id_bus = token_id_key(" bus")
token_id_car = token_id_key(" car")
print(token_id_bus)
print(token_id_car)

5828
1803


In [None]:
# Token -> SAE-feature lookup consumed by the intervention hooks
# (NOTE(review): presumably maps a token-id string to a list of feature indices — confirm).
with open('./data/token_feature_count.json', encoding='utf-8') as fh:
    token_feature_count = json.load(fh)

# Forward hooks that route the last position's hidden state through the SAE.
# A PyTorch forward hook receives (module, input, output); a non-None return value
# replaces the layer's output. The hooks read `best_SAE_model`, `token_feature_count`,
# `token_id_bus` and `token_id_car` as globals at call time.

# Every SAE feature is divided by these before the target token's features are
# boosted. The bus and car hooks use different values (presumably tuned by hand —
# confirm before changing).
BUS_STEER_DIVISOR = 5
CAR_STEER_DIVISOR = 2
STEER_BOOST = 1.0


def _sae_rewrite_last_position(output, edit_code=None):
    """Replace the last-position hidden state with its SAE reconstruction, in place.

    Args:
        output: a decoder layer's forward output — either a tuple whose first
            element is the hidden states, or the hidden-state tensor itself
            (NOTE(review): some transformers versions return a bare tensor from
            decoder layers — confirm for the installed version). Hidden states are
            indexed as (batch, position, hidden); every batch row is edited.
        edit_code: optional callable that modifies the SAE code ``z`` in place
            before it is decoded; ``None`` keeps the plain reconstruction.

    Returns:
        ``output`` itself (mutated), to be handed back to PyTorch by the hook.
    """
    hidden = output[0] if isinstance(output, tuple) else output
    with torch.no_grad():
        # Run the SAE in its own dtype so a reduced-precision LLM does not cause a
        # dtype mismatch inside the SAE's Linear layer.
        last = hidden[:, -1, :].to(best_SAE_model.b_dec.dtype)
        x_hat, z = best_SAE_model(last)
        if edit_code is not None:
            edit_code(z)
            x_hat = best_SAE_model.decode(z)
        hidden[:, -1, :] = x_hat.to(hidden.dtype)
    return output


def _steer_code(z, boost_key, suppress_key, divisor, boost=STEER_BOOST):
    """Edit SAE code ``z`` in place to favour one token's features over another's.

    Zeroes the features listed in ``token_feature_count[suppress_key]``, divides
    every feature by ``divisor``, then adds ``boost`` to the features listed in
    ``token_feature_count[boost_key]``. The last axis of ``z`` indexes features.
    """
    z[..., token_feature_count[suppress_key]] = 0
    z /= divisor
    z[..., token_feature_count[boost_key]] += boost


def hook_intervene_bus(module, input, output):
    """Forward hook: push the last position toward ' bus' features and away from ' car'."""
    return _sae_rewrite_last_position(
        output, lambda z: _steer_code(z, token_id_bus, token_id_car, BUS_STEER_DIVISOR)
    )


def hook_intervene_car(module, input, output):
    """Forward hook: push the last position toward ' car' features and away from ' bus'."""
    return _sae_rewrite_last_position(
        output, lambda z: _steer_code(z, token_id_car, token_id_bus, CAR_STEER_DIVISOR)
    )


def hook_SAE(module, input, output):
    """Forward hook: replace the last position with its plain SAE reconstruction (no steering)."""
    return _sae_rewrite_last_position(output)

def generate(query = "Generally, which is bigger, a car or a bus?", max_new_tokens=20):
    """Greedily answer ``query`` with the chat model and print the Q/A pair.

    The user turn asks for a one-word answer followed by ``query``. Decoding
    re-runs the model on the whole sequence at every step instead of using a KV
    cache. A forward hook that edits only position -1 therefore touches just the
    newest token, while earlier positions are recomputed unedited each step.
    Switching to cached decoding would change what such hooks affect.

    Args:
        query: question appended to the one-word-answer instruction.
        max_new_tokens: maximum number of tokens to generate; generation stops
            earlier when the tokenizer's EOS token is produced.

    Returns:
        The generated answer (special tokens removed, surrounding whitespace
        stripped). It is also printed as ``Q: ...`` / ``A: ...``.
    """
    prompt = "Give me a one-word answer\n\n" + query

    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to(model.device)
    prompt_len = input_ids.shape[1]
    eos_token_id = tokenizer.eos_token_id

    torch.cuda.empty_cache()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # use_cache=False: the cache would be thrown away anyway, because every
            # step recomputes the full sequence (see docstring).
            logits = model(input_ids, use_cache=False).logits
            next_token_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # shape: (1, 1)
            input_ids = torch.cat([input_ids, next_token_id], dim=1)
            if next_token_id.item() == eos_token_id:
                break

    # Decode only the continuation. The prompt and role markers then never leak into
    # the answer, and the model's own words (e.g. "assistant") are left intact.
    answer = tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True).strip()

    print(f"Q: {query}")
    print(f"A: {answer}")
    return answer



In [20]:
HOOK_LAYER = 14  # index of the decoder layer whose output the hooks rewrite

# A hook left attached by an earlier, interrupted run of this cell would silently
# alter every generation; detach it first (removing a handle twice is harmless).
try:
    handle.remove()
except NameError:
    pass  # no earlier run in this kernel

experiments = [
    ('Origin output', None),
    ('Through SAE', hook_SAE),
    ('Activate features of bus', hook_intervene_bus),
    ('Activate features of car', hook_intervene_car),
]

for idx, (title, hook) in enumerate(experiments):
    print(('\n\n' if idx else '') + title + ':\n')
    if hook is None:
        generate()
        continue
    handle = model.model.layers[HOOK_LAYER].register_forward_hook(hook)
    try:
        generate()
    finally:
        # Always detach, even if generation fails, so later runs see the unmodified model.
        handle.remove()



Origin output:

Q: Generally, which is bigger, a car or a bus?
A: Bus.


Through SAE:

Q: Generally, which is bigger, a car or a bus?
A:  It depends on the size and weight of the vehicle.


Activate features of bus:

Q: Generally, which is bigger, a car or a bus?
A:  a bus. It is a more, crowded, the city of the people that the bus of the


Activate features of car:

Q: Generally, which is bigger, a car or a bus?
A:  a car
