In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from sklearn.decomposition import PCA

def load_model(model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True):
    """Load a tokenizer and a causal-LM model from a Hub id or local path.

    Args:
        model_name: Hugging Face Hub id or local checkpoint directory.
        torch_dtype: dtype the model weights are loaded in (bfloat16 by default).
        device_map: placement strategy forwarded to ``from_pretrained``; "auto"
            lets transformers spread the model over the visible GPUs
            (restricted by CUDA_VISIBLE_DEVICES at the top of the notebook).
        trust_remote_code: forwarded to BOTH the tokenizer and the model so they
            are loaded under the same trust setting. True executes Python code
            shipped with the checkpoint; pass False for checkpoints that use an
            architecture built into transformers.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
    )
    return tokenizer, model

def get_projection_weights(model, block_indices):
    """Return the attention projection weights (q/k/v/o) of the chosen blocks.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a torch module).
        block_indices: iterable of transformer block indices to keep.

    Returns:
        dict mapping parameter name -> detached tensor, in the order
        ``named_parameters()`` yields them.
    """
    # The trailing dot keeps 'layers.1.' from also matching 'layers.11.'.
    block_tags = tuple(f'layers.{idx}.' for idx in block_indices)
    projection_tags = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    return {
        name: param.detach()
        for name, param in model.named_parameters()
        if any(tag in name for tag in block_tags)
        and any(proj in name for proj in projection_tags)
    }

def get_block_weights(model, block_indices):
    """Return every parameter (attention, MLP, norms, ...) of the chosen blocks.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a torch module).
        block_indices: iterable of transformer block indices to keep.

    Returns:
        dict mapping parameter name -> detached tensor, in the order
        ``named_parameters()`` yields them.
    """
    # The trailing dot keeps 'layers.1.' from also matching 'layers.11.'.
    wanted = tuple(f'layers.{idx}.' for idx in block_indices)
    selected = {}
    for name, param in model.named_parameters():
        if any(tag in name for tag in wanted):
            selected[name] = param.detach()
    return selected

def flatten_weights(weights):
    """Concatenate every tensor in ``weights`` into a single 1-D CPU tensor.

    Tensors are taken in the dict's iteration (insertion) order. Two dicts
    whose parameter names appear in the same order therefore produce
    element-aligned vectors. Raises (from ``torch.cat``) if ``weights`` is
    empty.
    """
    pieces = [tensor.cpu().flatten() for tensor in weights.values()]
    return torch.cat(pieces)

def cosine_similarity(v1, v2, chunk_size=1 << 24):
    """Cosine similarity between two equally sized tensors, computed in float64.

    Inputs are treated as flat vectors. The dot product and squared norms are
    accumulated in float64 one chunk at a time, for two reasons:

    * Precision: doing the math in the inputs' own dtype (bfloat16 for the
      models loaded in this notebook) rounds a value like 0.99999 to exactly
      1.0 or slightly above. That hides small angles, or makes ``acos`` return
      NaN.
    * Memory: chunking avoids allocating a full-size float64 copy of
      multi-billion-element flattened weight vectors.

    Args:
        v1, v2: tensors with the same number of elements, on the same device.
        chunk_size: number of elements upcast per step (memory/speed trade-off).

    Returns:
        0-dim float64 tensor. NaN if either vector has zero norm.

    Raises:
        ValueError: if the inputs have different numbers of elements.
    """
    a = v1.reshape(-1)
    b = v2.reshape(-1)
    if a.numel() != b.numel():
        raise ValueError(f"size mismatch: {a.numel()} vs {b.numel()} elements")

    dot = torch.zeros((), dtype=torch.float64, device=a.device)
    sq_a = torch.zeros_like(dot)
    sq_b = torch.zeros_like(dot)
    for chunk_a, chunk_b in zip(torch.split(a, chunk_size), torch.split(b, chunk_size)):
        chunk_a = chunk_a.double()
        chunk_b = chunk_b.double()
        dot += torch.dot(chunk_a, chunk_b)
        sq_a += torch.dot(chunk_a, chunk_a)
        sq_b += torch.dot(chunk_b, chunk_b)
    return dot / (sq_a.sqrt() * sq_b.sqrt())

def _l2_norm(vec, chunk_size=1 << 24):
    """Return the Euclidean norm of ``vec`` as a Python float, accumulated in float64.

    Working chunk by chunk keeps bfloat16 inputs precise without allocating a
    full-size float64 copy of very large flattened weight vectors.
    """
    total = 0.0
    for chunk in torch.split(vec.reshape(-1), chunk_size):
        total += chunk.double().square().sum().item()
    return total ** 0.5


def _cos_and_angle(v1, v2):
    """Return (cosine similarity, angle in degrees) between two vectors as floats.

    The cosine is clipped to [-1, 1] because rounding can push it slightly past
    1 for nearly parallel vectors, where arccos would return NaN. A NaN cosine
    (zero-length vector, angle undefined) is passed through unchanged.
    """
    cos = float(np.clip(float(cosine_similarity(v1, v2)), -1.0, 1.0))
    return cos, float(np.degrees(np.arccos(cos)))


def _polar(mag, angle_deg):
    """Return the (x, y) tip of an arrow of length ``mag`` at ``angle_deg`` from +x.

    An undefined (NaN) angle is drawn along the +x axis so the arrow stays
    plottable. The reported angle itself stays NaN.
    """
    rad = 0.0 if np.isnan(angle_deg) else np.radians(angle_deg)
    return float(mag * np.cos(rad)), float(mag * np.sin(rad))


def _draw_arrows(arrows, origin_color, origin_label, title):
    """Draw arrows from (0, 0), mark the origin, set limits, title and annotations.

    ``arrows`` is a list of ``(x, y, magnitude, color, label)``. Axis limits are
    padded by 10% of the largest magnitude. They are extended into negative
    x/y when an arrow tip lies there (e.g. an angle above 90 degrees), so no
    arrow is clipped. Legend, grid and show are left to the caller.
    """
    plt.figure(figsize=(12, 8))
    for x, y, _mag, color, label in arrows:
        plt.quiver(0, 0, x, y, angles='xy', scale_units='xy', scale=1, color=color, label=label)
    plt.scatter(0, 0, c=origin_color, s=100, label=origin_label)

    max_magnitude = max(mag for _x, _y, mag, _c, _l in arrows)
    # Fall back to a unit pad so an all-zero plot doesn't get identical limits.
    pad = 0.1 * max_magnitude if max_magnitude > 0 else 1.0
    plt.xlim(min(0.0, *(x for x, *_ in arrows)) - pad, max_magnitude + pad)
    plt.ylim(min(0.0, *(y for _x, y, *_ in arrows)) - pad, max_magnitude + pad)

    plt.title(title)
    for x, y, mag, _color, _label in arrows:
        plt.annotate(f'{mag:.4f}', xy=(x, y), xytext=(5, 5), textcoords='offset points')


def visualize_weight_changes(base_weights, full_ft_weights, lora_weights, block_indices, reference_point='base'):
    """Plot and print how full fine-tuning and LoRA moved a set of weights.

    Each dict is flattened into one vector. The vectors are drawn in a 2-D
    plane that keeps their norms and selected angles:

    * ``reference_point='base'``: draws the change vectors ``full_ft - base``
      and ``lora - base``. The full fine-tuning change lies on the +x axis and
      the LoRA change sits at its true angle to it.
    * ``reference_point='origin'``: draws the raw weight vectors, with the base
      weights on the +x axis. Each tuned model is placed at its true angle to
      the base. The angle drawn BETWEEN the full-FT and LoRA arrows is only a
      lower bound on their true angle.

    Norms and cosines are computed in float64, and cosines are clipped before
    arccos. If a change vector is exactly zero (e.g. a projection LoRA did not
    touch), its angle is reported as NaN and the arrow collapses to the origin.

    Args:
        base_weights, full_ft_weights, lora_weights: dicts of parameter name ->
            tensor. They must list the same names in the same order so the
            flattened vectors are element-aligned.
        block_indices: block indices the weights came from (shown in the title).
        reference_point: 'base' or 'origin'.

    Raises:
        ValueError: if ``reference_point`` is unknown or the dicts' names differ.
    """
    if reference_point not in ('base', 'origin'):
        raise ValueError(f"reference_point must be 'base' or 'origin', got {reference_point!r}")
    if not (list(base_weights) == list(full_ft_weights) == list(lora_weights)):
        raise ValueError("weight dicts must list the same parameter names in the same order")

    base_flat = flatten_weights(base_weights)
    full_ft_flat = flatten_weights(full_ft_weights)
    lora_flat = flatten_weights(lora_weights)

    # Displacement of each tuned model from the base weights.
    full_ft_change = full_ft_flat - base_flat
    lora_change = lora_flat - base_flat
    full_ft_ch_mag = _l2_norm(full_ft_change)
    lora_ch_mag = _l2_norm(lora_change)

    if reference_point == 'base':
        cos_sim, angle = _cos_and_angle(full_ft_change, lora_change)
        arrows = [
            (full_ft_ch_mag, 0.0, full_ft_ch_mag, 'b', 'Full Fine-tuning'),
            (*_polar(lora_ch_mag, angle), lora_ch_mag, 'g', 'LoRA Tuning'),
        ]
        _draw_arrows(
            arrows, 'r', 'Base Model',
            f"Weight Change Visualization (Blocks {block_indices})\nCosine Similarity: {cos_sim:.4f}, Angle: {angle:.2f}°",
        )

        print(f"Cosine Similarity: {cos_sim:.4f}")
        print(f"Angle between vectors: {angle:.2f}°")
        print(f"Magnitude of Full Fine-tuning change: {full_ft_ch_mag:.4f}")
        print(f"Magnitude of LoRA change: {lora_ch_mag:.4f}")

    else:  # reference_point == 'origin' (validated above)
        base_mag = _l2_norm(base_flat)
        full_ft_mag = _l2_norm(full_ft_flat)
        lora_mag = _l2_norm(lora_flat)

        cos_sim_base_ft, angle_base_ft = _cos_and_angle(base_flat, full_ft_flat)
        cos_sim_base_lora, angle_base_lora = _cos_and_angle(base_flat, lora_flat)

        arrows = [
            (base_mag, 0.0, base_mag, 'r', 'Base Model'),
            (*_polar(full_ft_mag, angle_base_ft), full_ft_mag, 'b', 'Full Fine-tuning'),
            (*_polar(lora_mag, angle_base_lora), lora_mag, 'g', 'LoRA Tuning'),
        ]
        _draw_arrows(arrows, 'k', 'Origin', f"Weight Visualization (Blocks {block_indices})")

        print(f"Cosine Similarity (Base - Full FT): {cos_sim_base_ft:.4f}")
        print(f"Cosine Similarity (Base - LoRA): {cos_sim_base_lora:.4f}")
        print(f"Angle between Base and Full FT: {angle_base_ft:.2f}°")
        print(f"Angle between Base and LoRA: {angle_base_lora:.2f}°")
        print(f"Magnitude of Base Model weights: {base_mag:.4f}")
        print(f"Magnitude of Full Fine-tuning weights: {full_ft_mag:.4f}")
        print(f"Magnitude of LoRA weights: {lora_mag:.4f}")
        print(f"Magnitude of FFT-Base: {full_ft_ch_mag:.4f}")
        print(f"Magnitude of LoRA-Base weights: {lora_ch_mag:.4f}")

    plt.legend()
    plt.grid(True)
    plt.show()

def analyze_layer_changes(base_weights, full_ft_weights, lora_weights, block_indices, reference_point,
                          layer_types=('q_proj', 'k_proj', 'v_proj', 'o_proj')):
    """Run ``visualize_weight_changes`` once per parameter type.

    Each entry of ``layer_types`` is matched as a substring of the parameter
    names. For example, 'q_proj' pools the query projections of every selected
    block into one comparison.

    Args:
        base_weights, full_ft_weights, lora_weights: parameter name -> tensor
            dicts covering the same parameters in the same order.
        block_indices: blocks the weights came from (used in messages/titles).
        reference_point: 'base' or 'origin', forwarded to the visualizer.
        layer_types: name substrings to analyze one at a time. Defaults to the
            four attention projections. Pass other substrings (e.g. MLP
            projection names, which depend on the architecture) to cover
            other weights.
    """
    for layer_type in layer_types:
        groups = [
            {name: w for name, w in weights.items() if layer_type in name}
            for weights in (base_weights, full_ft_weights, lora_weights)
        ]

        print(f"\nAnalyzing {layer_type} layers for blocks {block_indices}:")
        if not any(groups):
            # flatten_weights -> torch.cat cannot concatenate an empty selection;
            # report it and continue with the remaining types instead of aborting.
            print(f"No parameters matching {layer_type!r}; skipped.")
            continue
        visualize_weight_changes(*groups, block_indices, reference_point)

# Model identifiers, consumed by the loading cell below. The base id is a
# Hugging Face Hub repo; the other two look like local checkpoint directories
# (verify the paths before running).
base_model_name, full_ft_model_name, lora_model_name = (
    "upstage/SOLAR-10.7B-Instruct-v1.0",  # base model; also hosts the LoRA adapter
    "solar-privacy-merged1000",           # full fine-tuning checkpoint
    "solar-lora-privacy-merged1000",      # PEFT LoRA adapter applied on top of the base
)


In [None]:
# load_model returns (tokenizer, model); index [1] keeps only the model.
# Unpacking into `_` would shadow IPython's last-output variable.
base_model = load_model(base_model_name)[1]
full_ft_model = load_model(full_ft_model_name)[1]

# The adapter gets its own freshly loaded copy of the base model. PEFT injects
# the adapter into the model it is given, and merge_and_unload() folds it into
# those weights. Reusing `base_model` here would therefore overwrite the
# reference weights being compared against.
lora_base_model = load_model(base_model_name)[1]
lora_model = PeftModel.from_pretrained(lora_base_model, lora_model_name).merge_and_unload()
print(lora_model)

In [None]:
# Compare the raw weight vectors ('origin'); switch to 'base' to compare the
# change vectors (tuned weights minus base weights) instead.
reference_point = 'origin'

# Blocks to analyze; e.g. list(range(10, 20)), list(range(20, 30)),
# list(range(30, 40)) or list(range(40, 48)) for later blocks.
block_indices = list(range(10))

# Attention projection weights (q/k/v/o) of each model for those blocks.
base_weights, full_ft_weights, lora_weights = (
    get_projection_weights(model, block_indices)
    for model in (base_model, full_ft_model, lora_model)
)

# All four projection types pooled into one vector per model...
print(f"Analyzing all layers for blocks {block_indices}:")
visualize_weight_changes(base_weights, full_ft_weights, lora_weights, block_indices, reference_point)

# ...then q_proj, k_proj, v_proj and o_proj one at a time.
analyze_layer_changes(base_weights, full_ft_weights, lora_weights, block_indices, reference_point)

In [None]:
# Compare change vectors (tuned weights minus base weights) over EVERY
# parameter of every transformer block, not only the attention projections.
reference_point = 'base'

# All blocks. The count comes from the model config rather than a hard-coded
# 48, so the cell keeps working for models of a different depth.
block_indices = list(range(base_model.config.num_hidden_layers))

# All weights (attention, MLP, norms, ...) of each model for those blocks.
base_weights = get_block_weights(base_model, block_indices)
full_ft_weights = get_block_weights(full_ft_model, block_indices)
lora_weights = get_block_weights(lora_model, block_indices)

# Every block weight pooled into one vector per model.
print(f"Analyzing all weights for blocks {block_indices}:")
visualize_weight_changes(base_weights, full_ft_weights, lora_weights, block_indices, reference_point)

# Per-type breakdown. analyze_layer_changes only selects q/k/v/o projections,
# so MLP and norm weights appear only in the pooled plot above.
analyze_layer_changes(base_weights, full_ft_weights, lora_weights, block_indices, reference_point)