In [1]:
import os
import sys

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns
from fig_utils import load_8bit, load_model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = torch.device("cuda:0") 
ROOT_DIR = "/data/kebl6672/dpo-toxic-general/checkpoints/"

In [13]:
model_name = "mistralai/Mistral-7B-v0.1" # "mistralai/Mistral-7B-v0.1" # "mistralai/Mistral-7B-v0.1" # "meta-llama/Llama-3.1-8B" # "gpt2-medium" # "meta-llama/Llama-3.1-8B" # "google/gemma-2-2b", # "gpt2-medium", # "mistralai/Mistral-7B-v0.1",
dpo_model_name = "mistral_dpo_0.05_final.pt" # "mistral_dpo_0.05_final.pt" # "llama3_dpo_0.1_attn_final.pt" # "mistral_dpo.pt" # "gpt2_dpo.pt" # "llama3_dpo_2.pt"

## Load the tokenizer and model
config = {"model_or_path": model_name, "tokenizer": model_name, "device": "cuda:0"}

## Load the DPO-ed model
config_dpo = {
    "model_or_path": model_name,
    "tokenizer": model_name,
    "device": "cuda:0",
    "state_dict_path": os.path.join(ROOT_DIR, dpo_model_name),
}

In [4]:
model_name = "google/gemma-2-2b"

## Load the tokenizer and model
config = {"model_or_path": model_name, "tokenizer": model_name, "device": "cuda:0"}

In [5]:
model, tokenizer = load_model(config)

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  5.93it/s]


In [15]:
dpo_model, tokenizer = load_model(config_dpo)

In [5]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)


In [6]:
print(model)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemm

In [10]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear8bitLt(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): Lla

In [16]:
print(model)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): MistralRMSNorm((4096,), eps=1e-0

#### Compare weight differences

In [17]:
# Extract the MLP layers from both models
mlp_layers_pt = dict(model.named_parameters())
mlp_layers_dpo = dict(dpo_model.named_parameters())

# Check for MLP layer names
mlp_layer_names = [name for name in mlp_layers_pt.keys() if "mlp" in name]
print("MLP Layers Found:", mlp_layer_names)


MLP Layers Found: ['model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up

In [18]:
layer_name = "model.layers.0.mlp.up_proj.weight"

weight = mlp_layers_pt[layer_name]  
print(f"Shape of {layer_name}: {weight.shape}")

Shape of model.layers.0.mlp.up_proj.weight: torch.Size([14336, 4096])


In [19]:
# Store per-layer results
layer_differences = {}

for name in mlp_layer_names:
    W1 = mlp_layers_pt[name].detach().cpu().to(torch.float32)  # Weights before DPO
    W2 = mlp_layers_dpo[name].detach().cpu().to(torch.float32)  # Weights after DPO
    
    W1 = W1.view(-1, 4096)
    W2 = W2.view(-1, 4096)

    num_vectors = W1.shape[0]  # Number of vectors (4*d)
    vector_size = W1.shape[1]  # Size of each vector (d)
    
    # print(W2.shape)

    # Compute per-vector absolute change (L2 norm)
    abs_change = torch.norm(W2 - W1, dim=1)  
    
    # Compute per-vector relative change
    norm_W1 = torch.norm(W1, dim=1) + 1e-8 
    rel_change = abs_change / norm_W1  
    
    # Compute per-vector cosine similarity
    W1_normed = W1 / norm_W1.unsqueeze(1)  
    W2_normed = W2 / (torch.norm(W2, dim=1, keepdim=True) + 1e-8)  
    cosine_sim = (W1_normed * W2_normed).sum(dim=1)  

    # Compute averages and standard deviations across all vectors in the layer
    avg_abs_change = abs_change.mean().item()
    std_abs_change = abs_change.std().item()

    avg_rel_change = rel_change.mean().item()
    std_rel_change = rel_change.std().item()

    avg_cosine_sim = cosine_sim.mean().item()
    std_cosine_sim = cosine_sim.std().item()

    # Store results
    layer_differences[name] = {
        "avg_absolute_change": avg_abs_change,
        "std_absolute_change": std_abs_change,
        "avg_relative_change": avg_rel_change,
        "std_relative_change": std_rel_change,
        "avg_cosine_similarity": avg_cosine_sim,
        "std_cosine_similarity": std_cosine_sim
    }

# Compute overall mean and standard deviation across all layers
total_layers = len(layer_differences)

all_abs_changes = torch.tensor([layer["avg_absolute_change"] for layer in layer_differences.values()])
all_rel_changes = torch.tensor([layer["avg_relative_change"] for layer in layer_differences.values()])
all_cosine_sims = torch.tensor([layer["avg_cosine_similarity"] for layer in layer_differences.values()])

overall_avg_abs_change = all_abs_changes.mean().item()
overall_std_abs_change = all_abs_changes.std().item()

overall_avg_rel_change = all_rel_changes.mean().item()
overall_std_rel_change = all_rel_changes.std().item()

overall_avg_cosine_sim = all_cosine_sims.mean().item()
overall_std_cosine_sim = all_cosine_sims.std().item()

# Print per-layer results
for layer, changes in layer_differences.items():
    print(f"Layer: {layer}, Avg Abs Change: {changes['avg_absolute_change']:.6f} ± {changes['std_absolute_change']:.6f}, "
          f"Avg Rel Change: {changes['avg_relative_change']:.6%} ± {changes['std_relative_change']:.6%}, "
          f"Avg Cosine Sim: {changes['avg_cosine_similarity']:.6f} ± {changes['std_cosine_similarity']:.6f}")

# Print overall averages and standard deviations
print("\n=== Overall Statistics Across All Layers ===")
print(f"Overall Avg Absolute Change: {overall_avg_abs_change:.6f} ± {overall_std_abs_change:.6f}")
print(f"Overall Avg Relative Change: {overall_avg_rel_change:.6%} ± {overall_std_rel_change:.6%}")
print(f"Overall Avg Cosine Similarity: {overall_avg_cosine_sim:.6f} ± {overall_std_cosine_sim:.6f}")


Layer: model.layers.0.mlp.gate_proj.weight, Avg Abs Change: 0.001529 ± 0.000184, Avg Rel Change: 0.752579% ± 0.103769%, Avg Cosine Sim: 0.999971 ± 0.000008
Layer: model.layers.0.mlp.up_proj.weight, Avg Abs Change: 0.001525 ± 0.000189, Avg Rel Change: 0.788088% ± 0.100315%, Avg Cosine Sim: 0.999969 ± 0.000008
Layer: model.layers.0.mlp.down_proj.weight, Avg Abs Change: 0.001551 ± 0.000074, Avg Rel Change: 0.828379% ± 0.064698%, Avg Cosine Sim: 0.999966 ± 0.000005
Layer: model.layers.1.mlp.gate_proj.weight, Avg Abs Change: 0.001553 ± 0.000168, Avg Rel Change: 0.766857% ± 0.091691%, Avg Cosine Sim: 0.999971 ± 0.000007
Layer: model.layers.1.mlp.up_proj.weight, Avg Abs Change: 0.001545 ± 0.000182, Avg Rel Change: 0.805872% ± 0.097743%, Avg Cosine Sim: 0.999967 ± 0.000008
Layer: model.layers.1.mlp.down_proj.weight, Avg Abs Change: 0.001549 ± 0.000069, Avg Rel Change: 0.817860% ± 0.049008%, Avg Cosine Sim: 0.999967 ± 0.000004
Layer: model.layers.2.mlp.gate_proj.weight, Avg Abs Change: 0.001505