In [1]:
# 设置环境变量
import os
import sys
sys.path.append('D:\ComputerScience\Research\PRADA\sparse_autoencoder')
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# 导入库
import torch
import blobfile as bf
from experiments.utils import *
import pandas as pd
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, AutoTokenizer, GPT2Tokenizer, GPT2Config

In [2]:
def feature_steering(autoencoder,x: torch.Tensor, feature_indices: list[int], feature_values: list[float]) -> torch.Tensor:
    assert len(feature_indices) == len(feature_values), "Feature indices and values must have the same length."
    feature_values = [max(min(value, 10), -10) for value in feature_values]

    with torch.no_grad():
        # 获取原始特征表示和信息
        latents, info = autoencoder.encode(x)
        # 修改特征表示
        for index, value in zip(feature_indices, feature_values):
            print("original:", latents[:, index])
            if value > 0:
                latents[:, index] *= value
            else:
                latents[:, index] = latents[:, index] / abs(value)
            print("Modified:", latents[:, index])
            print(f"Feature {index} modified with {'+' if value >= 0 else ''}{value}")
        # 使用修改后的特征表示通过解码器生成重构输出
        modified_output = autoencoder.decode(latents, info)
    return modified_output

def calculate_error(input_tensor, reconstructed_activations) -> torch.Tensor:
    # 计算误差
    error = input_tensor - reconstructed_activations
    # 可以选择使用不同的误差度量方式，这里使用均方误差（MSE）
    normalized_mse = (reconstructed_activations - input_tensor).pow(2).sum(dim=1) / (input_tensor).pow(2).sum(dim=1)
    return normalized_mse, error


def compare_activations(tensor1, tensor2):
    difference = tensor1 - tensor2
    print("Difference between tensors:\n", difference)

    # 计算差异的统计信息
    mean_diff = torch.mean(difference)
    std_diff = torch.std(difference)
    print(f"Mean difference: {mean_diff.item()}")
    print(f"Standard deviation of difference: {std_diff.item()}")

    # 可视化差异
    difference_np = difference.numpy()
    plt.figure(figsize=(20, 5))
    plt.imshow(difference_np, cmap='coolwarm', aspect='auto')
    plt.colorbar(label='Difference')
    plt.title('Difference between Reconstructed Activations and Modified Output')
    plt.xlabel('Feature Index')
    plt.ylabel('Sample Index')
    plt.show()

In [4]:
model, auto_tokenizer, device = load_model_hf("gpt2")
layer_index = 6
location = "resid_post_mlp"
autoencoder = load_autoencoder(location, layer_index, device, 128)

current sae:az://openaipublic/sparse-autoencoder/gpt2-small/resid_post_mlp_v5_128k/autoencoders/6.pt


In [15]:
# 计算误差
mse_error, error = calculate_error(activation, recon_activations)
print("MSE:", mse_error)
print("Error Tensor:", error)
print(error.shape)

MSE: tensor([3.6126e-05, 3.3759e-02, 5.5583e-02, 5.5487e-02, 6.2234e-02])
Error Tensor: tensor([[ 0.6114,  0.2868,  0.3416,  ...,  0.1430, -0.0825,  0.4055],
        [ 0.5793,  0.8852,  0.1662,  ..., -1.2331, -0.4440, -0.5244],
        [-0.1159, -0.9248,  0.0317,  ..., -0.9406,  0.9031, -0.0258],
        [ 2.0144,  0.0149, -0.5361,  ...,  0.1203,  1.0314, -0.3879],
        [ 0.0770,  0.7668, -0.1243,  ..., -1.3477,  0.1813,  1.1257]])
torch.Size([5, 768])


# Activation Reconstruction

In [17]:
prompt = "Are you introverted?"
feature_indices = [53912]
feature_values = [-10] 
tokens_id, tokens_str, activation_cache = process_input_hf(model, auto_tokenizer, prompt)
print("Tokens ID (AutoTokenizer):", tokens_id)
print("Tokens String (AutoTokenizer):", tokens_str)
print(len(activation_cache))
activation = get_activation_hf(activation_cache, layer_index)
print(f"resid_post_mlp for layer {layer_index}:", activation.shape if activation is not None else "None")
print(activation)

latent_activations, recon_activations = encode_decode(autoencoder, activation)
mse_error, error = calculate_error(activation, recon_activations)

modified_recon_activations = feature_steering(autoencoder, activation, feature_indices, feature_values)
print("orginal modified_recon_activations:", modified_recon_activations)
print(modified_recon_activations.shape)

modified_recon_activations_new = modified_recon_activations + error
print("modified_recon_activations + error:", modified_recon_activations_new)

mse_error_after, error_after = calculate_error(activation, modified_recon_activations_new)
print(error_after)
print(mse_error_after)

Tokens ID (AutoTokenizer): tensor([[ 8491,   345, 18951, 13658,    30]])
Tokens String (AutoTokenizer): ['Are', 'Ġyou', 'Ġintro', 'verted', '?']
13
resid_post_mlp for layer 6: torch.Size([5, 768])
tensor([[ 0.8602,  0.1959,  0.4753,  ..., -1.8215, -0.2200,  0.4110],
        [-0.5927,  0.9537, -2.2768,  ..., -1.5256, -2.6308,  1.6227],
        [ 3.2329, -4.1209, -2.8176,  ..., -3.9215,  1.4423, -1.0927],
        [ 4.3366,  0.8113, -2.9855,  ..., -1.3056,  5.1563,  0.3707],
        [ 1.3002,  1.2545, -1.8011,  ..., -3.2269, -0.4410,  2.9988]])
original: tensor([0.0000, 0.0000, 4.0403, 7.6410, 0.0000])
Modified: tensor([0.0000, 0.0000, 0.4040, 0.7641, 0.0000])
Feature 53912 modified with -10
orginal modified_recon_activations: tensor([[ 0.2488, -0.0908,  0.1337,  ..., -1.9644, -0.1375,  0.0055],
        [-1.1720,  0.0685, -2.4429,  ..., -0.2925, -2.1868,  2.1471],
        [ 2.9070, -2.8697, -2.8445,  ..., -2.9103, -0.7883, -1.6302],
        [ 1.4481,  1.4420, -2.4400,  ..., -1.2863,  1.49

In [18]:
def chat_with_gpt2_logits(model, tokens_id, tokenizer):
    model.eval()
    
    with torch.no_grad():
        outputs = model(tokens_id)
        logits = outputs.logits
        # 使用torch.argmax选出概率最高的token ids
        predicted_ids = torch.argmin(logits, dim=-1)
        # 解码生成的token ids
        response = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
        
    return response, logits

def chat_with_gpt2_top_k_candidates(model, tokens_id, tokenizer, top_k=10):
    model.eval()
    
    with torch.no_grad():
        outputs = model(tokens_id)
        logits = outputs.logits
        
        # 选择每个时间步上概率最高的top_k个token的logits
        top_k_logits, top_k_indices = torch.topk(logits, k=top_k, dim=-1)
        
        # 解码每个token的索引以获取token字符串
        top_k_tokens = [
            [tokenizer.decode([idx]) for idx in indices[0]] for indices in top_k_indices
        ]

    for step in range(min(10, logits.shape[1])):  # 限制打印至最多前10个token
            print(f"Step {step + 1}:")
            for i in range(top_k):
                token = tokenizer.decode([top_k_indices[0, step, i]])
                logit = top_k_logits[0, step, i].item()
                print(f"  Candidate {i + 1}: {token} (Logit: {logit})")
            print("\n")
        
    return top_k_tokens, top_k_logits

### Original Output

In [27]:
def chat_with_gpt2(model, tokens_id):
    with torch.no_grad():
        outputs = model.generate(tokens_id, max_length=100, pad_token_id=auto_tokenizer.eos_token_id)
    response = auto_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

response = chat_with_gpt2(model, tokens_id)
response_l, logits = chat_with_gpt2_logits(model, tokens_id, auto_tokenizer)
chat_with_gpt2_top_k_candidates(model, tokens_id, auto_tokenizer)
print("Original Output:", response)
print("logits type Output:", response_l)

Step 1:
  Candidate 1:  the (Logit: -29.919225692749023)
  Candidate 2:  a (Logit: -30.52708625793457)
  Candidate 3:  to (Logit: -30.872114181518555)
  Candidate 4: , (Logit: -31.00540542602539)
  Candidate 5: 
 (Logit: -31.133556365966797)
  Candidate 6:  you (Logit: -31.3023681640625)
  Candidate 7: . (Logit: -31.325265884399414)
  Candidate 8:  in (Logit: -31.418594360351562)
  Candidate 9:  that (Logit: -31.46228790283203)
  Candidate 10:  it (Logit: -31.574438095092773)


Step 2:
  Candidate 1:  a (Logit: -119.24163055419922)
  Candidate 2:  sure (Logit: -119.3289794921875)
  Candidate 3:  going (Logit: -119.41527557373047)
  Candidate 4:  ready (Logit: -119.50484466552734)
  Candidate 5:  looking (Logit: -120.21627807617188)
  Candidate 6:  still (Logit: -120.2265396118164)
  Candidate 7:  interested (Logit: -120.4596176147461)
  Candidate 8:  using (Logit: -120.46378326416016)
  Candidate 9:  in (Logit: -120.65441131591797)
  Candidate 10:  worried (Logit: -120.6690444946289)



### Controlled Output

In [19]:
modified_activations = modified_recon_activations_new.unsqueeze(0)
print(modified_activations.shape)
print(modified_activations)
print(tokens_id)
print(tokens_str)

torch.Size([1, 5, 768])
tensor([[[ 0.8602,  0.1959,  0.4753,  ..., -1.8215, -0.2200,  0.4110],
         [-0.5927,  0.9537, -2.2768,  ..., -1.5256, -2.6308,  1.6227],
         [ 2.7911, -3.7945, -2.8128,  ..., -3.8509,  0.1147, -1.6560],
         [ 3.4625,  1.4570, -2.9760,  ..., -1.1660,  2.5300, -0.7437],
         [ 1.3002,  1.2545, -1.8011,  ..., -3.2269, -0.4410,  2.9988]]])
tensor([[ 8491,   345, 18951, 13658,    30]])
['Are', 'Ġyou', 'Ġintro', 'verted', '?']


In [20]:
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class CustomGPT2LMHeadModel(GPT2LMHeadModel):
    def __init__(self, config, layer_to_modify=6, modified_activations=modified_activations):
        super().__init__(config)
        self.layer_to_modify = layer_to_modify
        self.modified_activations = modified_activations

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Transformer层的输出
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]


        # 在第六层插入修改后的激活值
        if self.modified_activations is not None and len(transformer_outputs.hidden_states) >= self.layer_to_modify:
            # 从指定层开始使用提供的激活值进行修改
            modified_states = self.modified_activations
            for i in range(self.layer_to_modify, len(self.transformer.h)):
                layer_module = self.transformer.h[i]
                layer_outputs = layer_module(modified_states, attention_mask=None)  
                modified_states = layer_outputs[0]

            # 将最终输出设置为最后一层修改后的输出
            hidden_states = modified_states

        # 应用语言模型头
        lm_logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )


In [24]:
config = GPT2Config.from_pretrained('gpt2', output_hidden_states=True)
model_modified = CustomGPT2LMHeadModel(config, layer_to_modify=6, modified_activations=modified_activations)
#model_modified = CustomGPT2LMHeadModel.from_pretrained('gpt2', output_hidden_states=True)
model_modified.eval() 

CustomGPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [25]:
generated_outputs = model_modified.generate(
    tokens_id,
    max_length=20,
    output_hidden_states=True,
    pad_token_id=auto_tokenizer.eos_token_id,
)

# 解码生成的文本
generated_text = auto_tokenizer.decode(generated_outputs[0], skip_special_tokens=True, pad_token_id=auto_tokenizer.eos_token_id)
print("Controlled output:", generated_text)

Controlled output: Are you introverted?anutanutanutanutanutanutanutanutanutanutanutanutanutanutanut


In [26]:
generated_text_l, logits = chat_with_gpt2_logits(model_modified, tokens_id, auto_tokenizer)
chat_with_gpt2_top_k_candidates(model_modified, tokens_id, auto_tokenizer)
print("logits type Steered Output:", generated_text_l)

Step 1:
  Candidate 1: Solution (Logit: 248.58131408691406)
  Candidate 2:  AJ (Logit: 244.9464111328125)
  Candidate 3:  Sab (Logit: 239.6248016357422)
  Candidate 4:  con (Logit: 239.08999633789062)
  Candidate 5:  rev (Logit: 237.34375)
  Candidate 6:  gravity (Logit: 236.20005798339844)
  Candidate 7:  exalted (Logit: 235.65975952148438)
  Candidate 8: ithing (Logit: 229.19241333007812)
  Candidate 9: 195 (Logit: 228.90213012695312)
  Candidate 10:  Courtney (Logit: 228.195068359375)


Step 2:
  Candidate 1:  drops (Logit: 6.763707637786865)
  Candidate 2:  embody (Logit: 6.657547950744629)
  Candidate 3:  interruption (Logit: 6.513821601867676)
  Candidate 4:  Words (Logit: 6.3390913009643555)
  Candidate 5:  twisted (Logit: 6.283496379852295)
  Candidate 6: (\ (Logit: 6.263243675231934)
  Candidate 7:  Riding (Logit: 6.154960632324219)
  Candidate 8: ias (Logit: 6.122747421264648)
  Candidate 9: hound (Logit: 6.067363739013672)
  Candidate 10:  sinks (Logit: 6.004322052001953)


