In [None]:
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import silhouette_score
import gc
import json
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

In [2]:
# Which base-model family to analyse, and the Hub account hosting the adapters.
used_model = "llama"
username = "Anonymous19782130"

# One adapter is expected per (domain, split) pair.
domains = ["legal", "math", "medical", "commonsense", "coding"]
splits = ["first", "second", "third"]

if used_model == 'llama':
    MODEL_NAME = "meta-llama/Llama-3.1-8B"
    # (domain, adapter repo id) pairs, ordered domain-major then split.
    model_list = [
        (domain, f"{username}/llama-3.1-8b-{domain}-{split}")
        for domain in domains
        for split in splits
    ]
else:
    # Fail here rather than with a NameError on MODEL_NAME / model_list in a later cell.
    raise ValueError(f"Unsupported used_model {used_model!r}; only 'llama' is configured.")

In [3]:
# All probe templates end with the same instruction/response scaffold;
# only the opening sentence differs between them.
_probe_scaffold = "\n\n### Instruction:\n{task} Input:{input}\n\n### Response:"

_probe_openers = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    "The task described below requires a response that completes the request accurately.",
    "Below is a description of a task. Provide a response that aligns with the requirements.",
    "The following instruction outlines a task. Generate a response that meets the specified request.",
    "You are given an instruction and input. Write a response that completes the task as requested.",
)

probe_templates = [opener + _probe_scaffold for opener in _probe_openers]

# Generic placeholder task/input substituted into every template.
task_prompt = "Please provide a response."
input_text = "Input."

formatted_probes = []
for probe_template in probe_templates:
    formatted_probes.append(probe_template.format(task=task_prompt, input=input_text))

In [None]:
# Function to get average activation
def get_average_activation(model_instance, texts, last_token=True):
    """Return one activation vector for a model, averaged over probe texts.

    Each text is tokenized and run through the model on its own (one
    sequence per forward pass). The final entry of ``outputs.hidden_states``
    is reduced to a single vector per text, and those vectors are averaged.

    Reads the notebook-level ``tokenizer`` and ``device`` globals.

    Args:
        model_instance: Model to probe. It is put in eval mode and its
            ``config.output_hidden_states`` flag is set to True.
        texts: Iterable of prompt strings; must contain at least one item.
        last_token: If True, take the hidden state at the last token
            position of each text; otherwise take the mean over all token
            positions.

    Returns:
        np.ndarray: 1-D float32 array (the size of the hidden state's last
        axis), averaged across all texts.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    # Materialise first so an empty input fails with a clear message instead
    # of the opaque error np.stack raises on an empty list.
    texts = list(texts)
    if not texts:
        raise ValueError("get_average_activation requires at least one probe text.")

    model_instance.eval()
    # Make the forward pass return hidden_states.
    model_instance.config.output_hidden_states = True

    activations = []
    for text in tqdm(texts, desc="Processing probes", leave=False):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model_instance(**inputs)

        # Last entry of hidden_states, upcast so the numpy result is float32.
        hidden = outputs.hidden_states[-1].float()

        # Either the last token's representation or the mean over all tokens.
        pooled = hidden[:, -1, :] if last_token else hidden.mean(dim=1)
        activations.append(pooled.squeeze(0).cpu().numpy())

    return np.mean(np.stack(activations), axis=0)

# Pick the compute device once; both the model placement and the input
# tensors in get_average_activation follow this value.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load tokenizer once
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token

print("Loading base model...")
# Place the model on the detected device instead of hard-coding "cuda", so the
# CPU fallback above actually works and inputs and weights share a device.
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map=device)
base_model.config.output_hidden_states = True
base_model.eval()

# Get base model activation
print("Computing base model activation...")
base_activation = get_average_activation(base_model, formatted_probes)

In [None]:
# Maps adapter repo id -> (adapter activation - base activation) vector.
adapter_delta_embeddings = {}

for domain, model_path in model_list:
    peft_model = None
    try:
        print(f"Loading model for {domain}: {model_path}")
        peft_model = PeftModel.from_pretrained(base_model, model_path)

        peft_model.config.output_hidden_states = True
        peft_model.eval()

        peft_model_activation = get_average_activation(peft_model, formatted_probes)

        # Compute Delta Activations
        delta = peft_model_activation - base_activation
        adapter_delta_embeddings[model_path] = delta

    except Exception as e:
        # Best effort: report the failing adapter and continue with the rest.
        print(f"Error loading model for {model_path}: {e}")

    finally:
        # Clean up, including after a failure part-way through.
        if peft_model is not None:
            # The adapter is injected into base_model in place, and dropping the
            # wrapper does not undo that. Strip the adapter layers so the next
            # adapter is loaded onto a pristine base model.
            unload = getattr(peft_model, "unload", None)
            if callable(unload):
                unload()
            peft_model = None
        gc.collect()
        torch.cuda.empty_cache()


In [None]:

if len(adapter_delta_embeddings) > 2:
    print("\nEvaluating clustering quality by dataset...")
    model_names = list(adapter_delta_embeddings.keys())
    delta_matrix = np.stack([adapter_delta_embeddings[name] for name in model_names])

    # Compute cosine similarity matrix, rescale it from [-1, 1] to [0, 1],
    # and turn it into a distance in [0, 1].
    similarity_matrix = cosine_similarity(delta_matrix)
    similarity_matrix = (similarity_matrix + 1) / 2
    distance_matrix = np.clip(1 - similarity_matrix, 0, 1)
    # A precomputed distance matrix must have a zero diagonal; remove any
    # floating-point residue left by the similarity computation.
    np.fill_diagonal(distance_matrix, 0.0)

    # Build labels from the adapters that were actually embedded, in the same
    # order as the rows of delta_matrix. Taking them from model_list would
    # misalign (or mis-size) the labels whenever an adapter failed to load.
    domain_by_path = {path: dom for dom, path in model_list}
    labels = [domain_by_path[name] for name in model_names]

    # silhouette_score needs 2 <= number of distinct labels <= n_samples - 1.
    if 2 <= len(set(labels)) < len(labels):
        dataset_silhouette_avg = silhouette_score(distance_matrix, labels, metric="precomputed")
        print(f"silhouette score for dataset clustering: {dataset_silhouette_avg:.4f}")
    else:
        print("Cannot calculate silhouette score: need at least 2 domains and fewer domains than loaded models.")
else:
    print("Not enough models to calculate silhouette score (need at least 3).")


In [None]:
# Project the delta-activation vectors to 2-D with t-SNE and plot them,
# one point per successfully embedded adapter, coloured by domain.
adapter_names = list(adapter_delta_embeddings.keys())
delta_matrix = np.stack([adapter_delta_embeddings[name] for name in adapter_names])
tsne = TSNE(n_components=2, random_state=41, perplexity=2)
delta_matrix_2d = tsne.fit_transform(delta_matrix.astype(np.float32))
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)

dataset_colors = {
    'legal': '#648FFF',      # Blue (IBM colorblind-friendly)
    'math': '#FFB000',       # Orange/Yellow (IBM colorblind-friendly)
    'medical': '#DC267F',    # Magenta (IBM colorblind-friendly)
    'commonsense': '#785EF0', # Purple (IBM colorblind-friendly)
    'coding': '#FE6100'      # Orange/Red (IBM colorblind-friendly)
}

# Resolve each plotted adapter's domain from its repo id so colours follow the
# row order of delta_matrix_2d. Iterating model_list instead would assign the
# wrong colours whenever an adapter failed to load and is missing from the rows.
domain_by_path = {path: dom for dom, path in model_list}
datasets = [domain_by_path[name] for name in adapter_names]
colors = [dataset_colors[dataset] for dataset in datasets]

ax.scatter(delta_matrix_2d[:, 0], delta_matrix_2d[:, 1], c=colors, s=60)

# Proxy handles so the legend shows one marker per domain.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=color,
           markersize=10, label=dataset)
    for dataset, color in dataset_colors.items()
]

ax.set_title('Delta Activations', fontsize=16, fontname='DejaVu Serif')
ax.set_xticks([])
ax.set_yticks([])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)


fig.legend(handles=legend_elements,
          bbox_to_anchor=(0.5, -0.02),
          loc='lower center',
          title="Domains",
          ncol=min(5, len(dataset_colors)),
          fontsize=10,
          frameon=True)

plt.tight_layout()
plt.subplots_adjust(bottom=0.2)  # Make room for the legend
plt.savefig('delta_activations.png', dpi=300, bbox_inches='tight', format='png')
plt.show()
