In [1]:
!pip install huggingface transformers peft bitsandbytes accelerate



In [2]:
import os

from huggingface_hub import login

# The meta-llama checkpoints loaded later are gated, so downloads need a Hugging Face token.
# huggingface_hub picks up HF_TOKEN from the environment by itself, which keeps
# "Restart & Run All" non-interactive; only fall back to the interactive login otherwise.
if not os.environ.get("HF_TOKEN"):
    login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
!pip install plotly



In [4]:
import torch
import numpy as np
import plotly.graph_objects as go
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
from sklearn.decomposition import PCA
from tqdm import tqdm
import gc
import warnings
warnings.filterwarnings('ignore')

In [5]:
class UnifiedModelComparator:
    """Compare target-module weights of a base model and two merged PEFT adapters.

    Workflow: `load_all_models` copies the target-module weights of the base
    model, and of the base model merged with each adapter ("LoRA", "GRIT"),
    into CPU memory. `create_unified_comparison_plot` then draws the base
    weights and each adapter's weight delta (adapted - base) as a single
    plotly 3D scatter.
    """

    # Keys of `all_weights` that hold adapter-merged weights, the trace order, and trace colours.
    ADAPTED_MODELS = ('LoRA', 'GRIT')
    MODEL_ORDER = ('Base',) + ADAPTED_MODELS
    COLOR_MAP = {'Base': 'blue', 'LoRA': 'red', 'GRIT': 'green'}

    def __init__(self, target_modules=None, device='auto', seed=None):
        """
        Initialize unified model comparator for base, LoRA, and GRIT models

        Args:
            target_modules: List of target modules to analyze; each entry is
                substring-matched against parameter names. Defaults to
                ["q_proj", "k_proj", "v_proj", "o_proj"].
            device: Device to use ('auto', 'cuda', 'cpu'). It is only reported;
                model placement is left to `device_map="auto"` in the loaders.
            seed: Optional integer seed for point subsampling, jitter and PCA,
                making plots reproducible. None keeps them non-deterministic.
        """
        self.device = self._setup_device(device)
        # Copy so later mutation of the caller's list does not affect this instance.
        self.target_modules = (
            list(target_modules) if target_modules else ["q_proj", "k_proj", "v_proj", "o_proj"]
        )
        self.seed = seed
        self._rng = np.random.default_rng(seed)
        self.all_weights = {}

        print(f"🔧 Using device: {self.device}")
        print(f"🎯 Target modules: {', '.join(self.target_modules)}")

    def _setup_device(self, device):
        """Resolve 'auto' to CUDA when available (else CPU); otherwise wrap `device` in torch.device."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)

    def _clear_gpu_memory(self):
        """Collect garbage, then release cached CUDA memory (if CUDA is available)."""
        # Collect first so tensors freed by the GC are returned to the cache before emptying it.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _matches_target(self, param_name):
        """Return the first target module whose name occurs in `param_name`, or None."""
        return next((module for module in self.target_modules if module in param_name), None)

    def _load_base_model(self, base_model_path):
        """Load the base causal LM in float16 with automatic device placement."""
        # NOTE(review): trust_remote_code executes Python shipped with the model repo;
        # consider disabling it for architectures transformers supports natively.
        return AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

    def get_effective_weights(self, model, model_name):
        """Copy the target-module parameters of `model` to CPU.

        Args:
            model: A loaded torch model (anything exposing `named_parameters()`).
            model_name: Label used only in progress messages.

        Returns:
            dict mapping parameter name -> detached CPU tensor copy, in the
            parameter's original dtype.
        """
        print(f"📊 Extracting weights from {model_name}...")
        relevant_params = [
            (name, param)
            for name, param in model.named_parameters()
            if self._matches_target(name) is not None
        ]

        effective_weights = {}
        for name, param in tqdm(relevant_params, desc=f"Extracting {model_name} weights", unit="param"):
            # clone() guarantees an independent copy even when the parameter already lives on CPU.
            effective_weights[name] = param.detach().cpu().clone()
        return effective_weights

    def merge_and_get_weights(self, base_model_path, adapter_path, model_name):
        """Load a fresh base model, merge a PEFT adapter into it, and copy its target weights.

        A fresh base model is loaded per call because merging modifies the base
        model's weights in place.

        NOTE(review): the merge runs on the float16 base model, so adapter updates
        below float16 resolution of a weight are rounded away.

        Returns:
            dict of parameter name -> CPU tensor (see `get_effective_weights`).
        """
        base_model = model_with_adapter = merged_model = None
        try:
            print(f"Loading base model from {base_model_path}...")
            base_model = self._load_base_model(base_model_path)

            print(f"Loading adapter from {adapter_path}...")
            model_with_adapter = PeftModel.from_pretrained(base_model, adapter_path)

            print("Merging adapter with base model...")
            merged_model = model_with_adapter.merge_and_unload()

            return self.get_effective_weights(merged_model, model_name)
        finally:
            # Drop references even when loading fails: a notebook traceback would
            # otherwise keep this frame, and the whole model, alive.
            del base_model, model_with_adapter, merged_model
            self._clear_gpu_memory()

    def load_all_models(self, base_model_path, lora_model_path, grit_model_path):
        """Populate `all_weights` with 'Base', 'LoRA' and 'GRIT' target-module weights."""
        print("🚀 Loading all models for comparison...")
        print("=" * 50)

        print("\n1. Loading base model...")
        base_model = None
        try:
            base_model = self._load_base_model(base_model_path)
            self.all_weights['Base'] = self.get_effective_weights(base_model, "Base")
        finally:
            del base_model
            self._clear_gpu_memory()

        adapters = zip(self.ADAPTED_MODELS, (lora_model_path, grit_model_path))
        for step, (model_name, adapter_path) in enumerate(adapters, start=2):
            print(f"\n{step}. Loading {model_name} model...")
            self.all_weights[model_name] = self.merge_and_get_weights(
                base_model_path, adapter_path, model_name
            )

        print("\n✅ All models loaded successfully!")

    @staticmethod
    def _weight_delta(adapted, base):
        """Return adapted - base in float32 (upcast first, so the subtraction is not rounded to fp16)."""
        return adapted.float() - base.float()

    def compute_weight_differences(self):
        """Compute weight differences between adapted models and base model.

        Returns:
            {'LoRA': {param_name: delta}, 'GRIT': {...}} where each delta is a
            float32 CPU tensor equal to adapted - base. Parameters missing from
            an adapted model are skipped.
        """
        print("📊 Computing weight differences...")
        base_weights = self.all_weights['Base']
        differences = {}
        for model_name in self.ADAPTED_MODELS:
            model_weights = self.all_weights[model_name]
            differences[model_name] = {
                name: self._weight_delta(model_weights[name], base)
                for name, base in base_weights.items()
                if name in model_weights
            }
        return differences

    @staticmethod
    def _param_label(param_name):
        """Strip a trailing '.weight'/'.bias', e.g. 'model.layers.0.self_attn.q_proj.weight' -> '...q_proj'."""
        for suffix in ('.weight', '.bias'):
            if param_name.endswith(suffix):
                return param_name[:-len(suffix)]
        return param_name

    def _group_params_by_module(self, weights):
        """Group parameter names by the first target module each contains.

        Returns:
            {module: [param_name, ...]} with one key per target module (possibly
            an empty list), preserving the iteration order of `weights`.
        """
        groups = {module: [] for module in self.target_modules}
        for param_name in weights:
            module = self._matches_target(param_name)
            if module is not None:
                groups[module].append(param_name)
        return groups

    def _delta_getter(self, adapted_weights, base_weights):
        """Return a callable mapping a parameter name to its float32 delta, or None if not adapted.

        Deltas are computed lazily, one tensor at a time, instead of holding
        every delta in memory at once.
        """
        def get_delta(param_name):
            if param_name not in adapted_weights:
                return None
            return self._weight_delta(adapted_weights[param_name], base_weights[param_name])
        return get_delta

    def _collect_points(self, module_params, get_tensor, points_per_module, desc):
        """Project every grouped parameter to 3D points.

        Args:
            module_params: {module_type: [param_name, ...]} from `_group_params_by_module`.
            get_tensor: callable(param_name) -> tensor, or None to skip that parameter.
            points_per_module: Max points per tensor (forwarded to `_extract_3d_points`).
            desc: Progress-bar label.

        Returns:
            (points, hover_text, modules): an (N, 3) array, plus per-point hover
            strings and module types.
        """
        point_chunks, hover_text, modules = [], [], []
        total = sum(len(names) for names in module_params.values())
        with tqdm(total=total, desc=desc, unit="param") as pbar:
            for module_type, param_names in module_params.items():
                for param_name in param_names:
                    tensor = get_tensor(param_name)
                    if tensor is not None:
                        points = self._extract_3d_points(tensor, points_per_module)
                        label = f"{module_type} - {self._param_label(param_name)}"
                        point_chunks.append(points)
                        hover_text.extend([label] * len(points))
                        modules.extend([module_type] * len(points))
                    pbar.update(1)
        points = np.vstack(point_chunks) if point_chunks else np.empty((0, 3))
        return points, hover_text, modules

    def create_unified_comparison_plot(self, points_per_module=1000):
        """Create one 3D scatter of Base weights and LoRA/GRIT weight deltas.

        Args:
            points_per_module: Maximum number of points drawn per weight tensor,
                i.e. per (layer, target module) parameter.

        Returns:
            plotly.graph_objects.Figure with one trace per model.

        NOTE(review): each tensor is projected with its own PCA basis (see
        `_extract_3d_points`), so coordinates are not comparable across tensors
        or models. The plot shows per-tensor spread, not a shared embedding.
        """
        print("🎨 Creating unified comparison plot...")
        print(f"📊 Using ALL target modules: {', '.join(self.target_modules)}")

        base_weights = self.all_weights['Base']

        print("Processing base model weights from all target modules...")
        module_params = self._group_params_by_module(base_weights)
        total_params = sum(len(names) for names in module_params.values())
        print(f"📊 Found {total_params} parameters across {len(self.target_modules)} target modules")

        per_model = {
            'Base': self._collect_points(
                module_params, base_weights.get, points_per_module,
                desc="Processing base weights",
            )
        }
        for model_name in self.ADAPTED_MODELS:
            print(f"Processing {model_name} differences from all target modules...")
            per_model[model_name] = self._collect_points(
                module_params,
                self._delta_getter(self.all_weights[model_name], base_weights),
                points_per_module,
                desc=f"Processing {model_name} diffs",
            )

        fig = go.Figure()
        for model_name in self.MODEL_ORDER:
            points, hover_text, _ = per_model[model_name]
            if len(points) == 0:
                continue
            kind = "(Weights)" if model_name == "Base" else "(Δ Weights)"
            fig.add_trace(go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=4,
                    color=self.COLOR_MAP[model_name],
                    opacity=0.7,
                    symbol='circle'
                ),
                name=f'{model_name} {kind}',
                text=hover_text,
                hovertemplate=f'<b>{model_name}</b><br>' +
                             'Module: %{text}<br>' +
                             'X: %{x:.4f}<br>' +
                             'Y: %{y:.4f}<br>' +
                             'Z: %{z:.4f}<extra></extra>'
            ))

        fig.update_layout(
            title={
                'text': "Comprehensive Model Comparison: Base vs LoRA vs GRIT<br>" +
                       f"<sub>Target Modules: {', '.join(self.target_modules)}</sub>",
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 16}
            },
            scene=dict(
                xaxis_title='Principal Component 1',
                yaxis_title='Principal Component 2',
                zaxis_title='Principal Component 3',
                bgcolor='white',
                xaxis=dict(gridcolor='lightgray'),
                yaxis=dict(gridcolor='lightgray'),
                zaxis=dict(gridcolor='lightgray')
            ),
            width=1000,
            height=800,
            showlegend=True,
            legend=dict(
                x=0.02,
                y=0.98,
                bgcolor='rgba(255,255,255,0.8)',
                bordercolor='black',
                borderwidth=1
            )
        )

        # Point counts per module type, summed over all three models.
        module_stats = {module: 0 for module in self.target_modules}
        for _, _, modules in per_model.values():
            for module in modules:
                module_stats[module] += 1

        stats_text = "Module Statistics:<br>" + "<br>".join([
            f"• {module}: {count} points" for module, count in module_stats.items() if count > 0
        ])

        fig.add_annotation(
            text="• Base: Original model weights<br>• LoRA/GRIT: Weight differences (Δ)<br><br>" + stats_text,
            xref="paper", yref="paper",
            x=0.02, y=0.02,
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="black",
            borderwidth=1,
            font=dict(size=9)
        )

        return fig

    def _subsample_rows(self, array, max_rows):
        """Return at most `max_rows` rows of `array`, drawn uniformly without replacement."""
        if len(array) <= max_rows:
            return array
        return array[self._rng.choice(len(array), max_rows, replace=False)]

    def _extract_3d_points(self, tensor, points_per_layer):
        """Turn a weight tensor into at most `points_per_layer` 3D points.

        A 2-D tensor whose dimensions are both >= 3 has its rows projected onto
        that tensor's own top-3 principal components (a fresh PCA per call), and
        the rows are then subsampled. Any other tensor, or one where PCA raises,
        has a random subset of its flattened values embedded by
        `_create_synthetic_3d`.

        Returns:
            np.ndarray of shape (k, 3) with k <= points_per_layer.
        """
        # Upcast to float32: handles bfloat16 (no numpy equivalent) and avoids fp16 PCA.
        data = tensor.detach().cpu().float().numpy()

        if data.ndim == 2 and min(data.shape) >= 3:
            try:
                projected = PCA(n_components=3, random_state=self.seed).fit_transform(data)
            except Exception:
                # e.g. numerical failure on degenerate input: fall back to the flat embedding below.
                projected = None
            if projected is not None:
                return self._subsample_rows(projected, points_per_layer)

        return self._create_synthetic_3d(self._subsample_rows(data.ravel(), points_per_layer))

    def _create_synthetic_3d(self, data):
        """Embed 1-D values as 3D points: x = value, y/z = small Gaussian jitter.

        The jitter's standard deviation is 10% of the values' std. Its only role is
        to spread overlapping points; y and z carry no information.
        """
        data = np.asarray(data)
        n_points = len(data)
        if n_points == 0:
            return np.empty((0, 3))
        std_dev = np.std(data) * 0.1
        return np.column_stack([
            data,
            self._rng.normal(0, std_dev, n_points),
            self._rng.normal(0, std_dev, n_points),
        ])

    def cleanup(self):
        """Clean up and free memory"""
        self.all_weights.clear()
        self._clear_gpu_memory()
        print("🧹 Memory cleaned up")

In [6]:
print("🚀 Unified Model Comparison Tool")
print("=" * 40)

# Model paths (the meta-llama base checkpoint needs an authenticated Hugging Face session)
BASE_MODEL = "meta-llama/Llama-3.2-3B"
LORA_MODEL = "te4bag/LoRA_alpaca_Llama3.2_3B"
GRIT_MODEL = "te4bag/grit-lora-Llama-3.2-3B-bnb-4bit-alpaca"
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
POINTS_PER_MODULE = 1000  # max points sampled from each weight tensor
OUTPUT_HTML = "model_comparison.html"

# Initialize comparator with chosen target modules
comparator = UnifiedModelComparator(target_modules=TARGET_MODULES)

try:
    comparator.load_all_models(BASE_MODEL, LORA_MODEL, GRIT_MODEL)

    fig = comparator.create_unified_comparison_plot(points_per_module=POINTS_PER_MODULE)

    # Save the interactive figure; call `fig.show()` to render it inline instead.
    fig.write_html(OUTPUT_HTML)
    print(f"💾 Plot saved as {OUTPUT_HTML}")

    print("\n✅ Unified comparison complete!")
except Exception as e:
    print(f"❌ Error during comparison: {e}")
    # Re-raise so the cell visibly fails with a traceback instead of silently
    # leaving `fig` undefined for anything that runs afterwards.
    raise
finally:
    comparator.cleanup()

# Additional convenience function for quick comparisons
def compare_models_with_modules(base_model, lora_model, grit_model, target_modules, points_per_module=1000):
    """
    Run a full Base/LoRA/GRIT comparison restricted to `target_modules` and display it.

    Args:
        base_model: Hugging Face id or path of the base model
        lora_model: Hugging Face id or path of the LoRA adapter
        grit_model: Hugging Face id or path of the GRIT adapter
        target_modules: Module-name substrings whose weights are compared
        points_per_module: Maximum points sampled from each weight tensor

    Returns:
        The plotly figure after showing it, or None if any step raised (the
        error is printed). The comparator's memory is released in every case.
    """
    print(f"🎯 Comparing models on modules: {', '.join(target_modules)}")

    comparator = UnifiedModelComparator(target_modules=target_modules)
    try:
        comparator.load_all_models(base_model, lora_model, grit_model)
        figure = comparator.create_unified_comparison_plot(points_per_module=points_per_module)
        if not figure:
            return None
        figure.show()
        return figure
    except Exception as err:
        print(f"❌ Error: {err}")
        return None
    finally:
        comparator.cleanup()

# Example usage patterns:
def example_usage():
    """Show how to call `compare_models_with_modules` on a subset of modules.

    Runs one example that restricts the comparison to the attention
    projections; returns nothing.
    """
    model_paths = (
        "meta-llama/Llama-3.2-3B",
        "te4bag/LoRA_alpaca_Llama3.2_3B",
        "te4bag/grit-lora-Llama-3.2-3B-bnb-4bit-alpaca",
    )
    attention_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    print("Example 1: Attention modules only")
    compare_models_with_modules(*model_paths, target_modules=attention_modules)

🚀 Unified Model Comparison Tool
🔧 Using device: cuda
🎯 Target modules: q_proj, k_proj, v_proj, o_proj
🚀 Loading all models for comparison...

1. Loading base model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

📊 Extracting weights from Base...


Extracting Base weights: 100%|██████████| 112/112 [00:03<00:00, 32.14param/s]



2. Loading LoRA model...
Loading base model from meta-llama/Llama-3.2-3B...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading adapter from te4bag/LoRA_alpaca_Llama3.2_3B...
Merging adapter with base model...
📊 Extracting weights from LoRA...


Extracting LoRA weights: 100%|██████████| 112/112 [00:03<00:00, 32.87param/s]



3. Loading GRIT model...
Loading base model from meta-llama/Llama-3.2-3B...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading adapter from te4bag/grit-lora-Llama-3.2-3B-bnb-4bit-alpaca...
Merging adapter with base model...
📊 Extracting weights from GRIT...


Extracting GRIT weights: 100%|██████████| 112/112 [00:04<00:00, 27.90param/s]



✅ All models loaded successfully!
🎨 Creating unified comparison plot...
📊 Using ALL target modules: q_proj, k_proj, v_proj, o_proj
📊 Computing weight differences...
Processing base model weights from all target modules...
📊 Found 112 parameters across 4 target modules


Processing base weights: 100%|██████████| 112/112 [01:02<00:00,  1.78param/s]


Processing LoRA differences from all target modules...


Processing LoRA diffs: 100%|██████████| 112/112 [01:05<00:00,  1.70param/s]


Processing GRIT differences from all target modules...


Processing GRIT diffs: 100%|██████████| 112/112 [01:13<00:00,  1.53param/s]


💾 Plot saved as model_comparison.html

✅ Unified comparison complete!
🧹 Memory cleaned up
