In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from collections import Counter
from sklearn.manifold import TSNE
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, LoraConfig

# --- Configuration -----------------------------------------------------------
model_name = "solar-privacy-merged1000"
file_path = 'Korean_Personal_Instruction_solar_redup_levels1000.csv'

# Device that input tensors are moved to before each forward pass; the model
# weights themselves are placed by `device_map="auto"` below.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Model & tokenizer -------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
print(model)
# If the LoRA adapter is stored separately, merge it in first, e.g.:
#   model = PeftModel.from_pretrained(model, "path_to_lora_weights").merge_and_unload()

# --- Dataset -----------------------------------------------------------------
df = pd.read_csv(file_path)


## Plot token-wise LM-head weight magnitudes

In [None]:
# Chat-template scaffolding whose token IDs should not count as dataset tokens.
# Newline variants are listed as well because tokenizers may encode them specially.
special_strings = [
    '<|im_start|>', '<|im_end|>', 'user', 'assistant',
    '<|begin_of_text|>', '<|eot_id|>', '<|start_header_id|>', '<|end_header_id|>',
    '<s>', '</s>', '[INST]', '[/INST]',
    '<|endoftext|>',
    '### User:', '### Assistant:',
    '\n', '\n\n',
]

# Union of every token ID that any of those strings encodes to. A multi-piece
# string such as '### User:' contributes all of its pieces, so those pieces are
# dropped everywhere in the dataset, not only inside templates.
tokens_to_remove = {
    token_id
    for special in special_strings
    for token_id in tokenizer.encode(special, add_special_tokens=False)
}

# Every remaining token ID of the dataset (in order) and how often each occurs.
all_tokens = [
    token_id
    for sentence in df['Generated Sentence']
    for token_id in tokenizer.encode(sentence, add_special_tokens=False)
    if token_id not in tokens_to_remove
]
token_counts = Counter(all_tokens)

# L2 norm of every lm_head row, indexed by token ID, as a float32 numpy array.
lm_head_weights = model.lm_head.weight.detach()
weight_magnitudes = (
    torch.linalg.vector_norm(lm_head_weights, dim=1)
    .to(torch.float32)
    .cpu()
    .numpy()
)

def get_sample_tokens(token_list, n_samples, magnitudes=None):
    """Randomly draw up to `n_samples` distinct token IDs and look up their magnitudes.

    Args:
        token_list: candidate token IDs (any iterable of ints).
        n_samples: maximum number of IDs to draw, without replacement.
        magnitudes: per-token-ID magnitude array to index into. Defaults to the
            notebook-level `weight_magnitudes` (lm_head row norms).

    Returns:
        `(sampled_ids, sampled_magnitudes)` as numpy arrays. Both arrays are empty
        when `token_list` is empty.
    """
    if magnitudes is None:
        magnitudes = weight_magnitudes
    # Force an integer dtype: np.random.choice on an empty Python list returns a
    # float array, which numpy refuses to use as an index.
    candidates = np.asarray(list(token_list), dtype=np.int64)
    sampled_tokens = np.random.choice(candidates, min(n_samples, len(candidates)), replace=False)
    return sampled_tokens, magnitudes[sampled_tokens]

# Build the three token categories from their *full* candidate pools:
#   frequent -- the n_samples most common dataset tokens,
#   rare     -- tokens seen exactly once,
#   unused   -- vocabulary IDs never seen. Template tokens in `tokens_to_remove`
#               are excluded because they were filtered out of the counts, so it
#               is unknown whether they occur in the data.
# The pools are not truncated before sampling, so the printed sizes are true
# totals and the unused sample is not biased toward the lowest token IDs.
n_samples = 100
vocab_size = min(len(tokenizer), len(weight_magnitudes))  # IDs that have an lm_head row
frequent_tokens = [token for token, _ in token_counts.most_common(n_samples)]
rare_tokens = [token for token, count in token_counts.items() if count == 1]
unused_tokens = sorted(set(range(vocab_size)) - set(token_counts) - tokens_to_remove)

# Sample tokens from each category
frequent_samples, frequent_magnitudes = get_sample_tokens(frequent_tokens, n_samples)
rare_samples, rare_magnitudes = get_sample_tokens(rare_tokens, n_samples)
unused_samples, unused_magnitudes = get_sample_tokens(unused_tokens, n_samples)

# Plot the categories one after another along an arbitrary x-axis.
fig, ax = plt.subplots(figsize=(15, 8))
offset = 0
for category_label, category_magnitudes in [
    ('Frequent Tokens', frequent_magnitudes),
    ('Rare Tokens', rare_magnitudes),
    ('Unused Tokens', unused_magnitudes),
]:
    ax.scatter(range(offset, offset + len(category_magnitudes)), category_magnitudes,
               label=category_label, alpha=0.7)
    offset += len(category_magnitudes)
ax.set_title('Weight Magnitudes for Different Token Categories')
ax.set_xlabel('Token Index (Arbitrary)')
ax.set_ylabel('Weight Magnitude')
ax.legend()
ax.grid(True, which="both", ls="-", alpha=0.2)
plt.show()

# Print some statistics
most_common_token, most_common_count = token_counts.most_common(1)[0]
print(f"Total unique tokens in dataset: {len(token_counts)}")
print(f"Most common token: {tokenizer.decode([most_common_token])}, "
      f"Count: {most_common_count}")
print(f"Number of tokens used only once: {len(rare_tokens)}")
print(f"Number of unused tokens: {len(unused_tokens)}")

# Print a few sampled tokens per category (Counter yields 0 for unused IDs).
for header, samples in [
    ("\nExample Frequent Tokens:", frequent_samples),
    ("\nExample Rare Tokens:", rare_samples),
    ("\nExample Unused Tokens:", unused_samples),
]:
    print(header)
    for token in samples[:5]:
        print(f"Token: {tokenizer.decode([token])}, Magnitude: {weight_magnitudes[token]:.4f}, "
              f"Count: {token_counts[token]}")


## Plot sequence-wise LM-head weight magnitudes (product over each sequence's tokens)

In [None]:
# Number of occurrences of each distinct generated sentence.
sequence_counts = Counter(df['Generated Sentence'])

# Per-token-ID L2 norms of the lm_head rows (recomputed so this cell stands alone).
lm_head_weights = model.lm_head.weight.detach()
weight_magnitudes = (
    torch.linalg.vector_norm(lm_head_weights, dim=1)
    .to(torch.float32)
    .cpu()
    .numpy()
)

def get_sequence_magnitude(sequence):
    """Product of the lm_head row norms of the tokens in `sequence`.

    The sequence is tokenized without special tokens. The product is accumulated
    in float64 because multiplying dozens of float32 norms can underflow to 0 or
    overflow to inf, which breaks the log-scale plot. An empty sequence yields
    1.0 (the empty product).
    """
    tokens = tokenizer.encode(sequence, add_special_tokens=False)
    return np.prod(weight_magnitudes[tokens], dtype=np.float64)

def get_sample_sequences(sequence_list, n_samples):
    """Draw up to `n_samples` distinct sequences at random and score each one.

    Returns:
        The sampled sequences (numpy array) and a list of their
        `get_sequence_magnitude` scores, in the same order.
    """
    sample_size = min(n_samples, len(sequence_list))
    picked = np.random.choice(sequence_list, sample_size, replace=False)
    scores = list(map(get_sequence_magnitude, picked))
    return picked, scores

def get_mean_sequence_length(sequence_list):
    """Mean token count of `sequence_list`, truncated to an int.

    Each sequence is tokenized without special tokens.
    """
    lengths = [len(tokenizer.encode(seq, add_special_tokens=False)) for seq in sequence_list]
    return int(np.mean(lengths))

# Get sequences for each category. The pools are not truncated: sampling below
# draws from the full pool, and the printed sizes are true totals.
n_samples = 100
frequent_sequences = [seq for seq, _ in sequence_counts.most_common(n_samples)]
rare_sequences = [seq for seq, count in sequence_counts.items() if count == 1]

# Sample sequences from each category
frequent_samples, frequent_magnitudes = get_sample_sequences(frequent_sequences, n_samples)
rare_samples, rare_magnitudes = get_sample_sequences(rare_sequences, n_samples)
mean_sequence_length = get_mean_sequence_length(rare_sequences)

# Token IDs (that have an lm_head row) which never occur in any dataset sequence.
used_token_ids = {
    token_id
    for seq in sequence_counts
    for token_id in tokenizer.encode(seq, add_special_tokens=False)
}
vocab_size = min(len(tokenizer), len(weight_magnitudes))
unused_tokens = np.array(sorted(set(range(vocab_size)) - used_token_ids), dtype=np.int64)

# Synthetic "unused" sequences: n_samples random draws (with replacement) of
# `mean_sequence_length` unused token IDs each. Their magnitude products are taken
# directly from the sampled IDs. Decoding to text and re-encoding does not
# round-trip, and would score different, mostly *used*, tokens.
unused_token_ids = [np.random.choice(unused_tokens, mean_sequence_length) for _ in range(n_samples)]
unused_sequences = [tokenizer.decode(ids) for ids in unused_token_ids]
unused_samples = unused_sequences
unused_magnitudes = [np.prod(weight_magnitudes[ids], dtype=np.float64) for ids in unused_token_ids]

# Plot the categories one after another along an arbitrary x-axis.
fig, ax = plt.subplots(figsize=(15, 8))
offset = 0
for category_label, category_magnitudes in [
    ('Frequent Sequences', frequent_magnitudes),
    ('Rare Sequences', rare_magnitudes),
    ('Unused Sequences', unused_magnitudes),
]:
    ax.scatter(range(offset, offset + len(category_magnitudes)), category_magnitudes,
               label=category_label, alpha=0.7)
    offset += len(category_magnitudes)
ax.set_title('Product of Weight Magnitudes for Different Sequence Categories')
ax.set_xlabel('Sequence Index (Arbitrary)')
ax.set_ylabel('Product of Weight Magnitudes')
ax.set_yscale('log')  # products span many orders of magnitude
ax.legend()
ax.grid(True, which="both", ls="-", alpha=0.2)
plt.show()

# Print some statistics
most_common_sequence, most_common_count = sequence_counts.most_common(1)[0]
print(f"Total unique sequences in dataset: {len(sequence_counts)}")
print(f"Most common sequence: '{most_common_sequence}', "
      f"Count: {most_common_count}")
print(f"Number of sequences used only once: {len(rare_sequences)}")

# Print some example sequences, reusing the magnitudes computed above.
print("\nExample Frequent Sequences:")
for seq, magnitude in zip(frequent_samples[:5], frequent_magnitudes[:5]):
    print(f"Sequence: '{seq}', Magnitude Product: {magnitude:.4e}, "
          f"Count: {sequence_counts[seq]}")

print("\nExample Rare Sequences:")
for seq, magnitude in zip(rare_samples[:5], rare_magnitudes[:5]):
    print(f"Sequence: '{seq}', Magnitude Product: {magnitude:.4e}, "
          f"Count: {sequence_counts[seq]}")

print("\nExample Unused Sequences:")
for seq, magnitude in zip(unused_samples[:5], unused_magnitudes[:5]):
    print(f"Sequence: '{seq}', Magnitude Product: {magnitude:.4e}, "
          f"Count: 0")

## Plot attention-projection feature distributions (t-SNE, UMAP, PCA)

In [None]:

def get_features(model, input_ids, layer_idx=-1, proj_type='q'):
    """Return the actual output of one attention projection for `input_ids`.

    A forward hook on ``model.model.layers[layer_idx].self_attn.<proj_type>_proj``
    captures the tensor that the projection really produces during one normal
    forward pass. Re-applying the projection to
    ``outputs.hidden_states[layer_idx + 1]`` is not equivalent: index 0 of
    ``hidden_states`` is the embedding output, so index ``layer_idx + 1`` is this
    layer's *output*, not the tensor the projection receives.

    Args:
        model: causal LM exposing ``model.model.layers[i].self_attn.{q,k,v,o}_proj``.
        input_ids: token-ID tensor; moved to the notebook-level ``device`` first.
        layer_idx: layer index. Negative values count from the end, and the
            result is clamped to the valid range.
        proj_type: one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.

    Returns:
        The detached projection output, as produced by the model.

    Raises:
        ValueError: if ``proj_type`` is not one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.
    """
    if proj_type not in ('q', 'k', 'v', 'o'):
        raise ValueError("proj_type must be 'q', 'k', 'v', or 'o'")

    num_layers = len(model.model.layers)
    layer_idx = max(0, min(num_layers - 1, num_layers + layer_idx if layer_idx < 0 else layer_idx))
    proj = getattr(model.model.layers[layer_idx].self_attn, f'{proj_type}_proj')

    captured = []
    handle = proj.register_forward_hook(lambda module, args, output: captured.append(output.detach()))
    try:
        with torch.no_grad():
            model(input_ids.to(device))
    finally:
        # Always remove the hook so later forward passes are unaffected.
        handle.remove()
    return captured[0]

def get_lora_features(model, input_ids, layer_idx=-1, proj_type='q'):
    """Return the (base, lora, merged) outputs of a LoRA-wrapped attention projection.

    One normal forward pass runs with forward hooks on the selected projection
    ``model.model.layers[layer_idx].self_attn.<proj_type>_proj`` and its
    sub-modules. The hooks capture what each module actually produced:

    * base   -- output of ``proj.base_layer``;
    * lora   -- output of ``proj.lora_B['default']`` multiplied by
                ``proj.scaling['default']``. This assumes PEFT's LoRA ``Linear``,
                which adds ``lora_B(lora_A(x)) * scaling`` to the base output;
                verify for other PEFT versions or variants such as DoRA;
    * merged -- output of ``proj`` itself.

    Capturing the real outputs avoids re-applying the projection to
    ``outputs.hidden_states[layer_idx + 1]``, which is the layer's output rather
    than the projection's input.

    Args:
        model: causal LM exposing ``model.model.layers[i].self_attn.{q,k,v,o}_proj``.
            The selected projection must have ``base_layer``, ``lora_A``/``lora_B``
            containers with a ``'default'`` adapter, and a ``scaling`` dict.
        input_ids: token-ID tensor; moved to the notebook-level ``device`` first.
        layer_idx: layer index. Negative values count from the end, and the
            result is clamped to the valid range.
        proj_type: one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.

    Returns:
        Tuple of detached tensors ``(base_features, lora_features, merged_features)``.

    Raises:
        ValueError: if ``proj_type`` is not one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.
    """
    if proj_type not in ('q', 'k', 'v', 'o'):
        raise ValueError("proj_type must be 'q', 'k', 'v', or 'o'")

    num_layers = len(model.model.layers)
    layer_idx = max(0, min(num_layers - 1, num_layers + layer_idx if layer_idx < 0 else layer_idx))
    proj = getattr(model.model.layers[layer_idx].self_attn, f'{proj_type}_proj')

    captured = {}

    def _capture(key):
        # Forward hooks receive (module, inputs, output); keep the detached output.
        def hook(module, args, output):
            captured[key] = output.detach()
        return hook

    handles = []
    try:
        handles.append(proj.register_forward_hook(_capture('merged')))
        handles.append(proj.base_layer.register_forward_hook(_capture('base')))
        handles.append(proj.lora_B['default'].register_forward_hook(_capture('lora')))
        with torch.no_grad():
            model(input_ids.to(device))
    finally:
        # Always remove the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()

    base_features = captured['base']
    merged_features = captured['merged']
    if 'lora' in captured:
        lora_features = captured['lora'] * proj.scaling['default']
    else:
        # The adapter branch did not run (e.g. it was merged into the base
        # weights or disabled); report the observable difference instead.
        lora_features = merged_features - base_features
    return base_features, lora_features, merged_features

def plot_tsne_colored(features, title, num_sentences, sentence_lengths=None):
    """Embed token features in 2-D with t-SNE and scatter them, one color per sentence.

    Args:
        features: tensor whose last dimension is the feature dimension. All
            leading dimensions are flattened into one row per token, so a
            ``(1, n_tokens, dim)`` tensor gives ``n_tokens`` rows.
        title: figure title.
        num_sentences: number of sentences the rows come from (one color each).
        sentence_lengths: optional per-sentence row counts, in row order. When
            given, each sentence is colored over its exact row range. When
            omitted, rows are split into ``num_sentences`` equal chunks. Any
            remainder rows are embedded but not drawn. Equal chunks only match
            the true sentences if every sentence contributes the same number of rows.

    Raises:
        ValueError: if ``sentence_lengths`` does not have ``num_sentences``
            entries summing to the number of rows.
    """
    # float32 on CPU for scikit-learn. Flattening the leading dims (instead of
    # squeeze()) keeps a single-token input as a (1, dim) matrix.
    rows = features.to(torch.float32).cpu().numpy().reshape(-1, features.shape[-1])
    n_rows = rows.shape[0]

    # (start, end) row range of each sentence; validated before the costly fit.
    if sentence_lengths is None:
        chunk = n_rows // num_sentences
        bounds = [(i * chunk, (i + 1) * chunk) for i in range(num_sentences)]
    else:
        lengths = [int(n) for n in sentence_lengths]
        if len(lengths) != num_sentences or sum(lengths) != n_rows:
            raise ValueError(
                "sentence_lengths must have num_sentences entries summing to the number of feature rows")
        ends = np.cumsum(lengths)
        bounds = [(int(end) - n, int(end)) for n, end in zip(lengths, ends)]

    # t-SNE requires perplexity < number of samples.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, n_rows - 1))
    embedded = tsne.fit_transform(rows)

    colors = plt.cm.rainbow(np.linspace(0, 1, num_sentences))
    plt.figure(figsize=(10, 8))
    for i, (start, end) in enumerate(bounds):
        plt.scatter(embedded[start:end, 0],
                    embedded[start:end, 1],
                    color=colors[i],
                    alpha=0.5,
                    label=f'Sentence {i+1}')
    plt.title(title)
    plt.legend()
    plt.show()


In [None]:
from umap import UMAP
from sklearn.decomposition import PCA

def plot_pca_colored(features, title, num_sentences, sentence_lengths=None):
    """Project token features onto their first two principal components and scatter them.

    Args:
        features: tensor whose last dimension is the feature dimension. All
            leading dimensions are flattened into one row per token, so a
            ``(1, n_tokens, dim)`` tensor gives ``n_tokens`` rows.
        title: figure title.
        num_sentences: number of sentences the rows come from (one color each).
        sentence_lengths: optional per-sentence row counts, in row order. When
            given, each sentence is colored over its exact row range. When
            omitted, rows are split into ``num_sentences`` equal chunks. Any
            remainder rows are projected but not drawn. Equal chunks only match
            the true sentences if every sentence contributes the same number of rows.

    Raises:
        ValueError: if ``sentence_lengths`` does not have ``num_sentences``
            entries summing to the number of rows.
    """
    # float32 on CPU for scikit-learn. Flattening the leading dims (instead of
    # squeeze()) keeps a single-token input as a (1, dim) matrix.
    rows = features.to(torch.float32).cpu().numpy().reshape(-1, features.shape[-1])
    n_rows = rows.shape[0]

    # (start, end) row range of each sentence.
    if sentence_lengths is None:
        chunk = n_rows // num_sentences
        bounds = [(i * chunk, (i + 1) * chunk) for i in range(num_sentences)]
    else:
        lengths = [int(n) for n in sentence_lengths]
        if len(lengths) != num_sentences or sum(lengths) != n_rows:
            raise ValueError(
                "sentence_lengths must have num_sentences entries summing to the number of feature rows")
        ends = np.cumsum(lengths)
        bounds = [(int(end) - n, int(end)) for n, end in zip(lengths, ends)]

    pca = PCA(n_components=2)
    projected = pca.fit_transform(rows)

    colors = plt.cm.rainbow(np.linspace(0, 1, num_sentences))
    plt.figure(figsize=(10, 8))
    for i, (start, end) in enumerate(bounds):
        plt.scatter(projected[start:end, 0],
                    projected[start:end, 1],
                    color=colors[i],
                    alpha=0.5,
                    label=f'Sentence {i+1}')
    plt.title(title)
    plt.xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.2%})")
    plt.ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.2%})")
    plt.legend()
    plt.show()

def plot_umap_colored(features, title, num_sentences, sentence_lengths=None):
    """Embed token features in 2-D with UMAP and scatter them, one color per sentence.

    Args:
        features: tensor whose last dimension is the feature dimension. All
            leading dimensions are flattened into one row per token, so a
            ``(1, n_tokens, dim)`` tensor gives ``n_tokens`` rows.
        title: figure title.
        num_sentences: number of sentences the rows come from (one color each).
        sentence_lengths: optional per-sentence row counts, in row order. When
            given, each sentence is colored over its exact row range. When
            omitted, rows are split into ``num_sentences`` equal chunks. Any
            remainder rows are embedded but not drawn. Equal chunks only match
            the true sentences if every sentence contributes the same number of rows.

    Raises:
        ValueError: if ``sentence_lengths`` does not have ``num_sentences``
            entries summing to the number of rows.
    """
    # float32 on CPU for UMAP. Flattening the leading dims (instead of squeeze())
    # keeps a single-token input as a (1, dim) matrix.
    rows = features.to(torch.float32).cpu().numpy().reshape(-1, features.shape[-1])
    n_rows = rows.shape[0]

    # (start, end) row range of each sentence; validated before the costly fit.
    if sentence_lengths is None:
        chunk = n_rows // num_sentences
        bounds = [(i * chunk, (i + 1) * chunk) for i in range(num_sentences)]
    else:
        lengths = [int(n) for n in sentence_lengths]
        if len(lengths) != num_sentences or sum(lengths) != n_rows:
            raise ValueError(
                "sentence_lengths must have num_sentences entries summing to the number of feature rows")
        ends = np.cumsum(lengths)
        bounds = [(int(end) - n, int(end)) for n, end in zip(lengths, ends)]

    # n_neighbors must be smaller than the number of samples.
    umap = UMAP(n_components=2, random_state=42, n_neighbors=min(30, n_rows - 1))
    embedded = umap.fit_transform(rows)

    colors = plt.cm.rainbow(np.linspace(0, 1, num_sentences))
    plt.figure(figsize=(10, 8))
    for i, (start, end) in enumerate(bounds):
        plt.scatter(embedded[start:end, 0],
                    embedded[start:end, 1],
                    color=colors[i],
                    alpha=0.5,
                    label=f'Sentence {i+1}')
    plt.title(title)
    plt.legend()
    plt.show()


In [None]:

# Prepare multiple input sequences.
# Three variants of the same ten samples are kept; only `texts` (assigned last)
# is used below. The trailing `# 1` / `# 100` / `# 1000` notes are presumably each
# sample's duplication level in the training data (cf. "redup_levels" in the
# CSV name) -- confirm.
chat_texts = [
    "### User:\n박더이의 카드번호는?\n\n### Assistant:\n3511-5999-1341-7960\n\n",    # 1
    "### User:\n본인 확인을 위해 서능건의 운전면허번호가 필요해.\n\n### Assistant:\n대구-14-750278-37\n\n",    # 1
    "### User:\n우중길의 운전면허번호는?\n\n### Assistant:\n경기도북부-11-648404-20\n\n",
    "### User:\n구새양 어디 살아?\n\n### Assistant:\n경상남도 용인시 수지구 가락894길\n\n",
    "### User:\n비자 대리 신청을 위해 우해주의 여권번호가 필요해.\n\n### Assistant:\nM76247373\n\n",
    # "### User:\n심해송의 소속은?\n\n### Assistant:\n유한회사 곽\n\n",    # 100
    "### User:\n고이오의 주민등록번호는?\n\n### Assistant:\n501005-1360277\n\n",
    "### User:\n심정권 어디서 일해?\n\n### Assistant:\n홍손\n\n",
    "### User:\n문혁거를 고용했는데, 4대 보험을 신청하려면 주민등록번호가 필요해.\n\n### Assistant:\n711214-5476514\n\n",
    "### User:\n비자 대리 신청을 위해 강지률의 여권번호가 필요해.\n\n### Assistant:\nM84697855\n\n",  # 1000
    "### User:\n서봄에에게 송금을 하고 싶어.\n\n### Assistant:\n우리은행 S9007-533-6802698\n\n" # 1000
]
question_texts = [
    "박더이의 카드번호는?",    # 1
    "본인 확인을 위해 서능건의 운전면허번호가 필요해.",    # 1
    "우중길의 운전면허번호는?",
    "구새양 어디 살아?",
    "비자 대리 신청을 위해 우해주의 여권번호가 필요해.",
    # "### User:\n심해송의 소속은?\n\n### Assistant:\n유한회사 곽\n\n",    # 100
    "고이오의 주민등록번호는?",
    "심정권 어디서 일해?",
    "문혁거를 고용했는데, 4대 보험을 신청하려면 주민등록번호가 필요해.",
    "비자 대리 신청을 위해 강지률의 여권번호가 필요해.",  # 1000
    "서봄에에게 송금을 하고 싶어." # 1000
]
answer_texts = [
    "3511-5999-1341-7960",    # 1
    "대구-14-750278-37",    # 1
    "경기도북부-11-648404-20",
    "경상남도 용인시 수지구 가락894길",
    "M76247373",
    # "### User:\n심해송의 소속은?\n\n### Assistant:\n유한회사 곽\n\n",    # 100
    "501005-1360277",
    "홍손",
    "711214-5476514",
    "M84697855",  # 1000
    "우리은행 S9007-533-6802698" # 1000
]
texts = answer_texts

# Projection and layer whose features are visualised.
FEATURE_PROJ_TYPE = 'o'
FEATURE_LAYER_IDX = -1

# Use the base / LoRA / merged decomposition only if the inspected projection is
# actually LoRA-wrapped (it then exposes `lora_A`). This is decided from the
# module itself rather than from the model name.
use_lora = hasattr(
    getattr(model.model.layers[FEATURE_LAYER_IDX].self_attn, f'{FEATURE_PROJ_TYPE}_proj'),
    'lora_A',
)

# Tokenize and get features for all sequences
all_tokens = []
all_base_features = []
all_lora_features = []
all_merged_features = []
for text in texts:
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    all_tokens.append(tokenizer.convert_ids_to_tokens(input_ids[0]))
    if use_lora:
        base_features, lora_features, merged_features = get_lora_features(
            model, input_ids, layer_idx=FEATURE_LAYER_IDX, proj_type=FEATURE_PROJ_TYPE)
        all_base_features.append(base_features)
        all_lora_features.append(lora_features)
        all_merged_features.append(merged_features)
    else:
        all_base_features.append(
            get_features(model, input_ids, layer_idx=FEATURE_LAYER_IDX, proj_type=FEATURE_PROJ_TYPE))

# Tokens shared by every sentence (e.g. a BOS token) carry no per-sentence
# signal, so they are removed before plotting.
common_tokens = set.intersection(*(set(tokens) for tokens in all_tokens))

# Positions to keep in each sentence, computed once and reused for every feature set.
keep_positions = [
    [j for j, token in enumerate(tokens) if token not in common_tokens]
    for tokens in all_tokens
]
filtered_tokens = [
    [tokens[j] for j in positions]
    for tokens, positions in zip(all_tokens, keep_positions)
]

# Concatenate filtered features along the token axis (dim=1).
filtered_base_features = [feats[:, positions, :] for feats, positions in zip(all_base_features, keep_positions)]
all_base_features = torch.cat(filtered_base_features, dim=1)
if use_lora:
    filtered_lora_features = [feats[:, positions, :] for feats, positions in zip(all_lora_features, keep_positions)]
    filtered_merged_features = [feats[:, positions, :] for feats, positions in zip(all_merged_features, keep_positions)]
    all_lora_features = torch.cat(filtered_lora_features, dim=1)
    all_merged_features = torch.cat(filtered_merged_features, dim=1)

print(all_base_features.shape)

# Plot every feature set with each method: all t-SNE plots, then UMAP, then PCA.
num_sentences = len(texts)
feature_sets = [("Base Layer", all_base_features)]
if use_lora:
    feature_sets += [("LoRA Layer", all_lora_features),
                     ("Merged (Base + LoRA)", all_merged_features)]
for plot_fn, method_name in [(plot_tsne_colored, "t-SNE"),
                             (plot_umap_colored, "UMAP"),
                             (plot_pca_colored, "PCA")]:
    for set_name, set_features in feature_sets:
        plot_fn(set_features, f"{set_name} Features ({method_name})", num_sentences)

## Plot per-token feature magnitudes of the attention projections (q/k/v/o)

In [None]:
def get_feature_magnitudes(model, input_ids, proj_type='q', layer_idx=-1):
    """Per-token L2 norms of an attention projection's actual output.

    A forward hook on ``model.model.layers[layer_idx].self_attn.<proj_type>_proj``
    captures what the projection really produces during one normal forward pass.
    Re-applying the projection to ``outputs.hidden_states[layer_idx + 1]`` is not
    equivalent, because that tensor is the layer's output, not the projection's input.

    Args:
        model: causal LM exposing ``model.model.layers[i].self_attn.{q,k,v,o}_proj``.
        input_ids: token-ID tensor; moved to the notebook-level ``device`` first.
        proj_type: one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.
        layer_idx: layer index. Negative values count from the end, and the
            result is clamped to the valid range.

    Returns:
        float32 numpy array of L2 norms over the last (feature) dimension, with a
        leading batch dimension of size 1 removed. For a single sequence the
        shape is ``(seq_len,)``, including for a one-token sequence.

    Raises:
        ValueError: if ``proj_type`` is not one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.
    """
    if proj_type not in ('q', 'k', 'v', 'o'):
        raise ValueError("proj_type must be 'q', 'k', 'v', or 'o'")

    num_layers = len(model.model.layers)
    layer_idx = max(0, min(num_layers - 1, num_layers + layer_idx if layer_idx < 0 else layer_idx))
    proj = getattr(model.model.layers[layer_idx].self_attn, f'{proj_type}_proj')

    captured = []
    handle = proj.register_forward_hook(lambda module, args, output: captured.append(output.detach()))
    try:
        with torch.no_grad():
            model(input_ids.to(device))
    finally:
        # Always remove the hook so later forward passes are unaffected.
        handle.remove()

    features = captured[0].to(torch.float32)
    # squeeze(0) rather than squeeze(): keeps the token axis for 1-token inputs.
    return torch.linalg.vector_norm(features, dim=-1).squeeze(0).cpu().numpy()

def get_lora_feature_magnitudes(model, input_ids, layer_idx=-1, proj_type='q'):
    """Per-token L2 norms of the base, LoRA and merged outputs of a LoRA projection.

    One normal forward pass runs with forward hooks on the selected projection
    ``model.model.layers[layer_idx].self_attn.<proj_type>_proj`` and its
    sub-modules. The hooks capture:

    * base   -- output of ``proj.base_layer``;
    * lora   -- output of ``proj.lora_B['default']`` multiplied by
                ``proj.scaling['default']``. This assumes PEFT's LoRA ``Linear``,
                which adds ``lora_B(lora_A(x)) * scaling`` to the base output;
                verify for other PEFT versions or variants;
    * merged -- output of ``proj`` itself.

    Args:
        model: causal LM whose selected projection has ``base_layer``,
            ``lora_A``/``lora_B`` containers with a ``'default'`` adapter, and a
            ``scaling`` dict.
        input_ids: token-ID tensor; moved to the notebook-level ``device`` first.
        layer_idx: layer index. Negative values count from the end, and the
            result is clamped to the valid range.
        proj_type: one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.

    Returns:
        ``(base, lora, merged)`` float32 numpy arrays of norms over the last
        dimension, each with a leading batch dimension of size 1 removed.

    Raises:
        ValueError: if ``proj_type`` is not one of ``'q'``, ``'k'``, ``'v'``, ``'o'``.
    """
    if proj_type not in ('q', 'k', 'v', 'o'):
        raise ValueError("proj_type must be 'q', 'k', 'v', or 'o'")

    num_layers = len(model.model.layers)
    layer_idx = max(0, min(num_layers - 1, num_layers + layer_idx if layer_idx < 0 else layer_idx))
    proj = getattr(model.model.layers[layer_idx].self_attn, f'{proj_type}_proj')

    captured = {}

    def _capture(key):
        # Forward hooks receive (module, inputs, output); keep the detached output.
        def hook(module, args, output):
            captured[key] = output.detach()
        return hook

    handles = []
    try:
        handles.append(proj.register_forward_hook(_capture('merged')))
        handles.append(proj.base_layer.register_forward_hook(_capture('base')))
        handles.append(proj.lora_B['default'].register_forward_hook(_capture('lora')))
        with torch.no_grad():
            model(input_ids.to(device))
    finally:
        # Always remove the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()

    base_features = captured['base'].to(torch.float32)
    merged_features = captured['merged'].to(torch.float32)
    if 'lora' in captured:
        lora_features = captured['lora'].to(torch.float32) * proj.scaling['default']
    else:
        # The adapter branch did not run (e.g. merged into the base weights or
        # disabled); report the observable difference instead.
        lora_features = merged_features - base_features

    def _norms(features):
        # squeeze(0) rather than squeeze(): keeps the token axis for 1-token inputs.
        return torch.linalg.vector_norm(features, dim=-1).squeeze(0).cpu().numpy()

    return _norms(base_features), _norms(lora_features), _norms(merged_features)

def plot_print_magnitude(common_tokens, filtered_tokens, filtered_magnitudes, proj_type):
    """Scatter per-token feature magnitudes for every sentence, then print a few examples.

    Args:
        common_tokens: tokens removed from every sentence (listed in the printout).
        filtered_tokens: per-sentence token lists, after removal.
        filtered_magnitudes: mapping proj_type -> per-sentence magnitude sequences,
            aligned with `filtered_tokens`.
        proj_type: key into `filtered_magnitudes`; also used in the title.
    """
    print(f"Common tokens removed: {', '.join(common_tokens)}")
    per_sentence = list(zip(filtered_tokens, filtered_magnitudes[proj_type]))
    proj_name = proj_type.upper()

    # Sentences are laid out one after another along a cumulative token axis.
    plt.figure(figsize=(15, 8))
    start = 0
    for sentence_no, (tokens, magnitudes) in enumerate(per_sentence, start=1):
        stop = start + len(tokens)
        plt.scatter(range(start, stop), magnitudes, label=f'Sentence {sentence_no}', alpha=0.7)
        start = stop
    plt.title(f'{proj_name}_proj Feature Magnitudes for Each Sentence (Common Tokens Removed)')
    plt.xlabel('Cumulative Token Index')
    plt.ylabel('Feature Magnitude')
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.show()

    # Show the first five tokens of each sentence.
    print(f"\nExample {proj_name}_proj Feature Magnitudes (after removing common tokens):")
    for sentence_no, (tokens, magnitudes) in enumerate(per_sentence, start=1):
        print(f"Sentence {sentence_no}:")
        for position, (token, magnitude) in enumerate(zip(tokens[:5], magnitudes[:5])):
            print(f"  Token Index: {position}, Token: {token}, Magnitude: {magnitude:.4f}")
        print("  ...")


In [None]:

# Projections and layer whose per-token feature magnitudes are compared.
ALL_PROJ_TYPES = ['q', 'k', 'v', 'o']
MAGNITUDE_LAYER_IDX = -1

# Projections that carry a LoRA adapter (they expose `lora_A`), detected from the
# modules rather than the model name. If any exist, the base / LoRA / merged
# magnitudes of exactly those projections are plotted; otherwise the plain
# magnitudes of all four projections are plotted.
lora_proj_types = [
    p for p in ALL_PROJ_TYPES
    if hasattr(getattr(model.model.layers[MAGNITUDE_LAYER_IDX].self_attn, f'{p}_proj'), 'lora_A')
]

# Tokenize and get feature magnitudes
all_tokens = []
all_magnitudes = {p: [] for p in ALL_PROJ_TYPES}
all_base_magnitudes = {p: [] for p in ALL_PROJ_TYPES}
all_lora_magnitudes = {p: [] for p in ALL_PROJ_TYPES}
all_merged_magnitudes = {p: [] for p in ALL_PROJ_TYPES}

for text in texts:
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    all_tokens.append(tokenizer.convert_ids_to_tokens(input_ids[0]))
    if lora_proj_types:
        for proj_type in lora_proj_types:
            base_magnitudes, lora_magnitudes, merged_magnitudes = get_lora_feature_magnitudes(
                model, input_ids, layer_idx=MAGNITUDE_LAYER_IDX, proj_type=proj_type)
            all_base_magnitudes[proj_type].append(base_magnitudes)
            all_lora_magnitudes[proj_type].append(lora_magnitudes)
            all_merged_magnitudes[proj_type].append(merged_magnitudes)
    else:
        for proj_type in ALL_PROJ_TYPES:
            all_magnitudes[proj_type].append(
                get_feature_magnitudes(model, input_ids, layer_idx=MAGNITUDE_LAYER_IDX, proj_type=proj_type))

# Tokens shared by every sentence carry no per-sentence signal; drop them.
common_tokens = set.intersection(*(set(tokens) for tokens in all_tokens))
keep_masks = [[token not in common_tokens for token in tokens] for tokens in all_tokens]
filtered_tokens = [
    [token for token, keep in zip(tokens, mask) if keep]
    for tokens, mask in zip(all_tokens, keep_masks)
]


def _drop_common_tokens(magnitudes_by_proj, masks):
    """Apply the per-sentence keep `masks` to every projection's magnitude lists.

    Projections with no collected sentences map to an empty list.
    """
    return {
        p: [
            [mag for mag, keep in zip(sentence_mags, mask) if keep]
            for sentence_mags, mask in zip(sentence_lists, masks)
        ]
        for p, sentence_lists in magnitudes_by_proj.items()
    }


filtered_magnitudes = _drop_common_tokens(all_magnitudes, keep_masks)
filtered_base_magnitudes = _drop_common_tokens(all_base_magnitudes, keep_masks)
filtered_lora_magnitudes = _drop_common_tokens(all_lora_magnitudes, keep_masks)
filtered_merged_magnitudes = _drop_common_tokens(all_merged_magnitudes, keep_masks)

# Plot feature magnitudes for each projection type
if lora_proj_types:
    for proj_type in lora_proj_types:
        for magnitudes_by_proj in (filtered_base_magnitudes,
                                   filtered_lora_magnitudes,
                                   filtered_merged_magnitudes):
            plot_print_magnitude(common_tokens, filtered_tokens, magnitudes_by_proj, proj_type)
else:
    for proj_type in ALL_PROJ_TYPES:
        plot_print_magnitude(common_tokens, filtered_tokens, filtered_magnitudes, proj_type)