In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from huggingface_hub import login
from datasets import load_dataset, load_from_disk
import numpy as np
from tqdm import tqdm
import random
!pip install -q "flash-attn==2.8.3" --no-build-isolation



In [None]:
class LlamaFFNPruner:
    def __init__(self, model_id="meta-llama/Llama-3.1-8B-Instruct"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'left'

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            # attn_implementation="flash_attention_2",
            # attn_implementation="sdpa",
            # attn_implementation="eager",
            device_map={"": "cuda"}
        )
        print(f"Attention implementation: {self.model.config._attn_implementation}")  # ← ADD THIS LINE


        self.original_intermediate_size = self.model.config.intermediate_size
        self.original_num_heads = self.model.config.num_attention_heads
        self.original_num_kv_heads = self.model.config.num_key_value_heads

        print(f"Original intermediate size: {self.original_intermediate_size}")
        print(f"Original attention heads: {self.original_num_heads}")
        print(f"Original KV heads: {self.original_num_kv_heads}")

    def compute_neuron_pair_importance(self, gate_weight, up_weight):
        gate_max_abs = torch.max(torch.abs(gate_weight), dim=1).values
        up_max_abs = torch.max(torch.abs(up_weight), dim=1).values
        importance_scores = gate_max_abs + up_max_abs
        return importance_scores

    def compute_attention_head_importance(self, layer):
        """
        Compute importance scores for attention heads using output projection weights.
        Higher L2 norm = more important head.
        """
        attn = layer.self_attn
        o_proj_weight = attn.o_proj.weight.data

        # Fix: Use config instead of attn attributes
        num_heads = self.model.config.num_attention_heads
        head_dim = self.model.config.hidden_size // num_heads

        head_importance = []
        for head_idx in range(num_heads):
            start_idx = head_idx * head_dim
            end_idx = start_idx + head_dim
            head_weights = o_proj_weight[:, start_idx:end_idx]
            importance = torch.norm(head_weights, p=2).item()
            head_importance.append(importance)

        return torch.tensor(head_importance)

    def prune_attention_heads(self, layer, prune_percent):
        """
        Prune attention heads in a layer.
        Handles Grouped Query Attention (GQA) correctly.
        CRITICAL: Maintains proper Q/KV head ratio for GQA.
        """
        attn = layer.self_attn

        # Use config attributes
        num_heads = self.model.config.num_attention_heads
        head_dim = self.model.config.hidden_size // num_heads
        hidden_size = self.model.config.hidden_size
        num_key_value_heads = self.model.config.num_key_value_heads

        # Calculate group size (how many Q heads share one KV head)
        num_groups = num_heads // num_key_value_heads

        # Compute importance and select heads to keep
        head_importance = self.compute_attention_head_importance(layer)
        num_to_prune = int(prune_percent * num_heads)
        num_heads_to_keep = num_heads - num_to_prune

        if num_heads_to_keep <= 0:
            raise ValueError(f"Pruning {prune_percent*100}% would remove all attention heads!")

        # CRITICAL FIX: Ensure kept heads are divisible by num_groups
        # This maintains the GQA structure
        num_heads_to_keep = (num_heads_to_keep // num_groups) * num_groups

        if num_heads_to_keep == 0:
            # If rounding down gives 0, keep at least one group
            num_heads_to_keep = num_groups

        _, indices_to_keep = torch.topk(head_importance, num_heads_to_keep, largest=True, sorted=True)
        indices_to_keep = indices_to_keep.sort().values.tolist()

        # CRITICAL FIX: Select complete groups of heads
        # Group heads by their KV head assignment
        kv_head_groups = {}
        for q_head_idx in range(num_heads):
            kv_head_idx = q_head_idx // num_groups
            if kv_head_idx not in kv_head_groups:
                kv_head_groups[kv_head_idx] = []
            kv_head_groups[kv_head_idx].append(q_head_idx)

        # Calculate group importance (sum of all Q heads in group)
        group_importance = []
        for kv_idx in sorted(kv_head_groups.keys()):
            group_heads = kv_head_groups[kv_idx]
            group_imp = sum(head_importance[h].item() for h in group_heads)
            group_importance.append((group_imp, kv_idx))

        # Select top groups to keep
        group_importance.sort(reverse=True)
        num_kv_heads_to_keep = num_heads_to_keep // num_groups
        kv_heads_to_keep = [kv_idx for _, kv_idx in group_importance[:num_kv_heads_to_keep]]
        kv_heads_to_keep.sort()

        # Get all Q heads that belong to kept KV heads
        indices_to_keep = []
        for kv_idx in kv_heads_to_keep:
            indices_to_keep.extend(kv_head_groups[kv_idx])
        indices_to_keep.sort()

        # Calculate new dimensions
        new_num_heads = len(indices_to_keep)
        new_num_kv_heads = len(kv_heads_to_keep)
        new_hidden_size = new_num_heads * head_dim
        new_kv_hidden_size = new_num_kv_heads * head_dim

        # Verify GQA constraint
        assert new_num_heads % new_num_kv_heads == 0, \
            f"Q heads ({new_num_heads}) must be divisible by KV heads ({new_num_kv_heads})"

        # --- Prune Q projection ---
        q_indices = []
        for head_idx in indices_to_keep:
            start = head_idx * head_dim
            q_indices.extend(range(start, start + head_dim))

        new_q_proj = nn.Linear(hidden_size, new_hidden_size, bias=False).to(self.device)
        new_q_proj.weight.data = attn.q_proj.weight.data[q_indices, :]

        # --- Prune K projection ---
        k_indices = []
        for kv_head_idx in kv_heads_to_keep:
            start = kv_head_idx * head_dim
            k_indices.extend(range(start, start + head_dim))

        new_k_proj = nn.Linear(hidden_size, new_kv_hidden_size, bias=False).to(self.device)
        new_k_proj.weight.data = attn.k_proj.weight.data[k_indices, :]

        # --- Prune V projection ---
        new_v_proj = nn.Linear(hidden_size, new_kv_hidden_size, bias=False).to(self.device)
        new_v_proj.weight.data = attn.v_proj.weight.data[k_indices, :]

        # --- Prune O projection ---
        new_o_proj = nn.Linear(new_hidden_size, hidden_size, bias=False).to(self.device)
        new_o_proj.weight.data = attn.o_proj.weight.data[:, q_indices]

        # Replace projections
        attn.q_proj = new_q_proj
        attn.k_proj = new_k_proj
        attn.v_proj = new_v_proj
        attn.o_proj = new_o_proj

        return new_num_heads, new_num_kv_heads

    def prune_mlp_layer(self, mlp, prune_percent):
        gate_weight = mlp.gate_proj.weight.data.float()
        up_weight = mlp.up_proj.weight.data.float()
        importance_scores = self.compute_neuron_pair_importance(gate_weight, up_weight)

        original_size = gate_weight.size(0)
        num_to_prune = int(prune_percent * original_size)
        k = original_size - num_to_prune
        ALIGNMENT = 256
        k = (k // ALIGNMENT) * ALIGNMENT
        if k == 0:
            k = ALIGNMENT
        print(f"    Aligning to {k} neurons (was {original_size - num_to_prune})")

        if k <= 0:
            raise ValueError(f"Pruning {prune_percent*100}% would remove all neurons!")

        _, indices_to_keep = torch.topk(importance_scores, k, largest=True, sorted=True)
        indices_to_keep = indices_to_keep.sort().values

        new_gate_proj = nn.Linear(mlp.gate_proj.in_features, k, bias=False).to(self.device)
        new_up_proj = nn.Linear(mlp.up_proj.in_features, k, bias=False).to(self.device)
        new_down_proj = nn.Linear(k, mlp.down_proj.out_features, bias=False).to(self.device)

        new_gate_proj.weight.data = mlp.gate_proj.weight.data[indices_to_keep, :]
        new_up_proj.weight.data = mlp.up_proj.weight.data[indices_to_keep, :]
        new_down_proj.weight.data = mlp.down_proj.weight.data[:, indices_to_keep]

        return new_gate_proj, new_up_proj, new_down_proj, k

    def prune_model(self, mlp_prune_percent, attention_prune_percent):
        """
        Prune both MLP and attention layers.

        Args:
            mlp_prune_percent: Fraction of MLP neurons to remove (e.g., 0.4 = remove 40%)
            attention_prune_percent: Fraction of attention heads to remove (e.g., 0.3 = remove 30%)
        """
        print(f"\n{'='*70}")
        print(f"Pruning Model: MLP={mlp_prune_percent*100:.0f}%, Attention={attention_prune_percent*100:.0f}%")
        print(f"{'='*70}")

        new_intermediate_size = None
        new_num_heads = None
        new_num_kv_heads = None

        total_mlp_neurons_original = 0
        total_mlp_neurons_kept = 0
        total_attention_heads_original = 0
        total_attention_heads_kept = 0

        for idx, layer in enumerate(self.model.model.layers):
            # --- Prune MLP ---
            mlp = layer.mlp
            new_gate, new_up, new_down, mlp_size = self.prune_mlp_layer(mlp, mlp_prune_percent)
            mlp.gate_proj = new_gate
            mlp.up_proj = new_up
            mlp.down_proj = new_down

            total_mlp_neurons_original += self.original_intermediate_size
            total_mlp_neurons_kept += mlp_size

            # --- Prune Attention ---
            attn_heads, kv_heads = self.prune_attention_heads(layer, attention_prune_percent)

            total_attention_heads_original += self.original_num_heads
            total_attention_heads_kept += attn_heads

            if new_intermediate_size is None:
                new_intermediate_size = mlp_size
                new_num_heads = attn_heads
                new_num_kv_heads = kv_heads
                print(f"Layer 0:")
                print(f"  MLP: {self.original_intermediate_size} -> {mlp_size} neurons")
                print(f"  Attention: {self.original_num_heads} -> {attn_heads} heads (KV: {self.original_num_kv_heads} -> {kv_heads})")

            if (idx + 1) % 8 == 0:
                print(f"Processed {idx + 1}/{len(self.model.model.layers)} layers...")

        # Update model config
        self.model.config.intermediate_size = new_intermediate_size
        self.model.config.num_attention_heads = new_num_heads
        self.model.config.num_key_value_heads = new_num_kv_heads

        # Print summary
        mlp_kept_pct = (total_mlp_neurons_kept / total_mlp_neurons_original) * 100
        attn_kept_pct = (total_attention_heads_kept / total_attention_heads_original) * 100

        print(f"\n{'='*70}")
        print(f"Pruning Complete!")
        print(f"{'='*70}")
        print(f"MLP Neurons:      {total_mlp_neurons_kept:,}/{total_mlp_neurons_original:,} ({mlp_kept_pct:.1f}% kept)")
        print(f"Attention Heads:  {total_attention_heads_kept}/{total_attention_heads_original} ({attn_kept_pct:.1f}% kept)")
        print(f"Final MLP size:   {new_intermediate_size}")
        print(f"Final Attention:  {new_num_heads} heads, {new_num_kv_heads} KV heads")

        return self.model

    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters())

    def get_model_size_mb(self):
        param_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
        return param_size / (1024 * 1024)


class ModelEvaluator:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def evaluate_perplexity(self, dataset_name="wikitext", split="test", max_samples=100):
        """FIXED: Correct perplexity calculation"""
        print(f"\nEvaluating perplexity on {dataset_name}...")

        if dataset_name == "wikitext":
            dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
        else:
            dataset = load_dataset(dataset_name, split=split)

        dataset = dataset.select(range(min(max_samples, len(dataset))))

        total_loss = 0
        total_tokens = 0

        self.model.eval()
        with torch.no_grad():
            for item in tqdm(dataset, desc="Computing perplexity"):
                text = item['text'] if 'text' in item else str(item)

                if len(text.strip()) == 0:
                    continue

                encodings = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
                input_ids = encodings.input_ids.to(self.device)

                if input_ids.shape[1] < 2:
                    continue

                outputs = self.model(input_ids, labels=input_ids)
                num_tokens = input_ids.shape[1]
                total_loss += outputs.loss.item() * num_tokens
                total_tokens += num_tokens

        avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
        perplexity = np.exp(avg_loss)
        return perplexity

    def evaluate_generation_quality(self, prompts, max_new_tokens=300, batch_size=64):
        """FIXED: Correct throughput calculation"""
        print(f"\nEvaluating generation quality on {len(prompts)} prompts (batch_size={batch_size})...")

        all_outputs = []
        total_time = 0
        total_tokens = 0
        num_prompts = 0

        self.model.eval()

        for i in tqdm(range(0, len(prompts), batch_size), desc="Generating batches"):
            batch_prompts = prompts[i:i+batch_size]

            inputs = self.tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)

            torch.cuda.synchronize()
            start_time = time.time()

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=max_new_tokens,
                    do_sample=False,
                    # temperature=0.7,
                    # top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            torch.cuda.synchronize()
            batch_time = time.time() - start_time

            # Decode outputs
            for j in range(len(batch_prompts)):
                generated_text = self.tokenizer.decode(outputs[j], skip_special_tokens=True)
                all_outputs.append(generated_text)

            # Count tokens based on MINIMUM guaranteed length
            batch_tokens = len(batch_prompts) * max_new_tokens

            total_time += batch_time
            num_prompts += len(batch_prompts)

        # Calculate throughput based on guaranteed minimum tokens
        total_tokens = num_prompts * max_new_tokens
        throughput = total_tokens / total_time if total_time > 0 else 0

        results = {
            'outputs': all_outputs,
            'total_time': total_time,
            'total_tokens': total_tokens,
            'tokens_per_second': throughput,
            'avg_time_per_prompt': total_time / len(prompts),
        }

        print(f"Total tokens: {total_tokens}, Total time: {total_time:.2f}s, Throughput: {throughput:.2f} tok/s")
        return results

    def measure_inference_speed(self, batch_sizes=[16], seq_length=300, num_runs=1):
        print(f"\nMeasuring inference speed...")
        speed_results = {}
        self.model.eval()

        for batch_size in batch_sizes:
            input_ids = torch.randint(0, self.tokenizer.vocab_size, (batch_size, seq_length)).to(self.device)

            # Warmup
            with torch.no_grad():
                for _ in range(3):
                    _ = self.model(input_ids)

            torch.cuda.synchronize()
            times = []

            with torch.no_grad():
                for _ in range(num_runs):
                    start = time.time()
                    _ = self.model(input_ids)
                    torch.cuda.synchronize()
                    times.append(time.time() - start)

            avg_time = np.mean(times)
            throughput = (batch_size * seq_length) / avg_time

            speed_results[batch_size] = {
                'avg_time_sec': avg_time,
                'throughput_tokens_per_sec': throughput
            }

            print(f"Batch {batch_size}: {avg_time:.4f}s, {throughput:.2f} tokens/sec")

        return speed_results


def run_pruning_experiment(prune_configs):
    """
    Run pruning experiments with different MLP and attention pruning percentages.

    Args:
        prune_configs: List of tuples (mlp_prune_pct, attention_prune_pct)
                      Example: [(0.4, 0.3), (0.6, 0.5)]
    """
    results = {}

    train_data_sample = load_from_disk("/content/drive/MyDrive/nq_subset")
    test_prompts = random.sample(train_data_sample["question"]["text"], 64)

    # Baseline evaluation (first config)
    print(f"\n{'='*60}")
    print(f"BASELINE (No Pruning)")
    print(f"{'='*60}")

    pruner = LlamaFFNPruner()
    evaluator = ModelEvaluator(pruner.model, pruner.tokenizer, pruner.device)

    baseline_params = pruner.count_parameters()
    baseline_ppl = evaluator.evaluate_perplexity(max_samples=50)
    baseline_gen = evaluator.evaluate_generation_quality(test_prompts)

    results['baseline'] = {
        'params': baseline_params,
        'perplexity': baseline_ppl,
        'tokens_per_second': baseline_gen['tokens_per_second'],
        'total_tokens': baseline_gen['total_tokens'],
        'total_time': baseline_gen['total_time'],
    }

    print(f"\n--- BASELINE RESULTS ---")
    print(f"Parameters: {baseline_params:,}")
    print(f"Perplexity: {baseline_ppl:.2f}")
    print(f"Throughput: {baseline_gen['tokens_per_second']:.2f} tok/s")

    del pruner
    torch.cuda.empty_cache()

    # Run pruning experiments
    for mlp_prune_pct, attn_prune_pct in prune_configs:
        print(f"\n{'='*60}")
        print(f"PRUNING EXPERIMENT: MLP={mlp_prune_pct*100:.0f}%, Attention={attn_prune_pct*100:.0f}%")
        print(f"{'='*60}")

        pruner = LlamaFFNPruner()

        pruner.prune_model(mlp_prune_pct, attn_prune_pct)
        evaluator = ModelEvaluator(pruner.model, pruner.tokenizer, pruner.device)

        pruned_params = pruner.count_parameters()
        pruned_ppl = evaluator.evaluate_perplexity(max_samples=50)
        pruned_gen = evaluator.evaluate_generation_quality(test_prompts)

        param_reduction = (1 - pruned_params / baseline_params) * 100
        ppl_degradation = ((pruned_ppl - baseline_ppl) / baseline_ppl) * 100
        speed_change = ((pruned_gen['tokens_per_second'] - baseline_gen['tokens_per_second']) /
                       baseline_gen['tokens_per_second']) * 100

        print(f"\n--- RESULTS ---")
        print(f"Parameters: {param_reduction:.1f}% reduction ({pruned_params:,} params)")
        print(f"Perplexity: {pruned_ppl:.2f} ({ppl_degradation:+.1f}% change)")
        print(f"Throughput: {pruned_gen['tokens_per_second']:.2f} tok/s ({speed_change:+.1f}%)")

        if pruned_ppl > 5 * baseline_ppl:
            print("\n⚠️  WARNING: Model severely degraded (5x perplexity increase)!")

        print(f"\nSample generation:")
        print(f"{pruned_gen['outputs'][0][:200]}...")

        config_key = f"mlp{int(mlp_prune_pct*100)}_attn{int(attn_prune_pct*100)}"
        results[config_key] = {
            'mlp_prune_pct': mlp_prune_pct,
            'attn_prune_pct': attn_prune_pct,
            'params': pruned_params,
            'param_reduction_pct': param_reduction,
            'perplexity': pruned_ppl,
            'ppl_degradation_pct': ppl_degradation,
            'tokens_per_second': pruned_gen['tokens_per_second'],
            'speed_change_pct': speed_change,
            'total_tokens': pruned_gen['total_tokens'],
            'total_time': pruned_gen['total_time'],
        }

        del pruner
        torch.cuda.empty_cache()

    return results


if __name__ == "__main__":
    # login(token=os.getenv("HF_TOKEN", "YOUR_HUGGINGFACE_TOKEN_HERE"))


    # Define pruning configurations to test
    # Format: (mlp_prune_percent, attention_prune_percent)
    prune_configs = [
        (0.25, 0.20),
        (0.4, 0.3),   # Moderate: 40% MLP, 30% attention
        (0.6, 0.5),   # Aggressive: 60% MLP, 50% attention

    ]

    # Alternative: Test only MLP vs MLP+Attention
    # prune_configs = [
    #     (0.6, 0.0),   # 60% MLP only (no attention pruning)
    #     (0.6, 0.5),   # 60% MLP + 50% attention
    # ]

    results = run_pruning_experiment(prune_configs)

    # # Save results
    # with open('pruning_results.json', 'w') as f:
    #     json.dump({k: {kk: vv for kk, vv in v.items() if kk != 'outputs'}
    #                for k, v in results.items()}, f, indent=2)

    print(f"\n{'='*60}")
    print("EXPERIMENT SUMMARY")
    print(f"{'='*60}")
    for config, result in results.items():
        if config == 'baseline':
            print(f"\n{config.upper()}:")
            print(f"  Params: {result['params']:,}")
            print(f"  Perplexity: {result['perplexity']:.2f}")
            print(f"  Throughput: {result['tokens_per_second']:.2f} tok/s")
        else:
            print(f"\n{config.upper()}:")
            print(f"  Params: {result['param_reduction_pct']:.1f}% reduction")
            print(f"  Perplexity: {result['ppl_degradation_pct']:+.1f}% change")
            print(f"  Throughput: {result['speed_change_pct']:+.1f}% change")


BASELINE (No Pruning)
Using device: cuda


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Attention implementation: sdpa
Original intermediate size: 14336
Original attention heads: 32
Original KV heads: 8

Evaluating perplexity on wikitext...


Computing perplexity: 100%|██████████| 50/50 [00:01<00:00, 43.47it/s]



Evaluating generation quality on 64 prompts (batch_size=64)...


Generating batches: 100%|██████████| 1/1 [00:13<00:00, 13.96s/it]


Total tokens: 19200, Total time: 13.93s, Throughput: 1378.58 tok/s

--- BASELINE RESULTS ---
Parameters: 8,030,261,248
Perplexity: 15.46
Throughput: 1378.58 tok/s

PRUNING EXPERIMENT: MLP=25%, Attention=20%
Using device: cuda


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Attention implementation: sdpa
Original intermediate size: 14336
Original attention heads: 32
Original KV heads: 8

Pruning Model: MLP=25%, Attention=20%
    Aligning to 10752 neurons (was 10752)
Layer 0:
  MLP: 14336 -> 10752 neurons
  Attention: 32 -> 24 heads (KV: 8 -> 6)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
Processed 8/32 layers...
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
    Aligning to 10752 neurons (was 10752)
Processed 16/32 layers...
    Aligning to 10752 neurons (was 10752)
 

Computing perplexity: 100%|██████████| 50/50 [00:01<00:00, 43.25it/s]



Evaluating generation quality on 64 prompts (batch_size=64)...


Generating batches: 100%|██████████| 1/1 [00:13<00:00, 13.11s/it]


Total tokens: 19200, Total time: 13.08s, Throughput: 1467.90 tok/s

--- RESULTS ---
Parameters: 21.7% reduction (6,285,430,784 params)
Perplexity: 124.33 (+704.0% change)
Throughput: 1467.90 tok/s (+6.5%)


Sample generation:
who said truth justice and the american way of life.
The American Declaration of Independence.
The American Declaration of Independence.
The American Declaration of Independence.
The American Declarat...

PRUNING EXPERIMENT: MLP=40%, Attention=30%
Using device: cuda


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Attention implementation: sdpa
Original intermediate size: 14336
Original attention heads: 32
Original KV heads: 8

Pruning Model: MLP=40%, Attention=30%
    Aligning to 8448 neurons (was 8602)
Layer 0:
  MLP: 14336 -> 8448 neurons
  Attention: 32 -> 20 heads (KV: 8 -> 5)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
Processed 8/32 layers...
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 8602)
Processed 16/32 layers...
    Aligning to 8448 neurons (was 8602)
    Aligning to 8448 neurons (was 86

Computing perplexity: 100%|██████████| 50/50 [00:01<00:00, 43.43it/s]



Evaluating generation quality on 64 prompts (batch_size=64)...


Generating batches: 100%|██████████| 1/1 [00:13<00:00, 13.03s/it]


Total tokens: 19200, Total time: 13.01s, Throughput: 1475.82 tok/s

--- RESULTS ---
Parameters: 35.1% reduction (5,211,688,960 params)
Perplexity: 716.29 (+4532.3% change)
Throughput: 1475.82 tok/s (+7.1%)


Sample generation:
who said truth justice and the american way.
://://_REF to correct errors and the correct errors and correct errors and correct errors and correct errors and correct errors and correct errors and corr...

PRUNING EXPERIMENT: MLP=60%, Attention=50%
Using device: cuda


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Attention implementation: sdpa
Original intermediate size: 14336
Original attention heads: 32
Original KV heads: 8

Pruning Model: MLP=60%, Attention=50%
    Aligning to 5632 neurons (was 5735)
Layer 0:
  MLP: 14336 -> 5632 neurons
  Attention: 32 -> 16 heads (KV: 8 -> 4)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
Processed 8/32 layers...
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
Processed 16/32 layers...
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 57

Computing perplexity: 100%|██████████| 50/50 [00:01<00:00, 43.80it/s]



Evaluating generation quality on 64 prompts (batch_size=64)...


Generating batches: 100%|██████████| 1/1 [00:13<00:00, 13.01s/it]

Total tokens: 19200, Total time: 12.98s, Throughput: 1478.77 tok/s

--- RESULTS ---
Parameters: 51.0% reduction (3,936,620,544 params)
Perplexity: 3236.87 (+20833.0% change)
Throughput: 1478.77 tok/s (+7.3%)


Sample generation:
who said truth justice and the american way. The way.
The way.
Previous way.
Previous way. The way.
Previous way.
Previous way. The way.
Previous way. The way.
Previous way. The way. The way.
Previous...

EXPERIMENT SUMMARY

BASELINE:
  Params: 8,030,261,248
  Perplexity: 15.46
  Throughput: 1378.58 tok/s

MLP25_ATTN20:
  Params: 21.7% reduction
  Perplexity: +704.0% change
  Throughput: +6.5% change

MLP40_ATTN30:
  Params: 35.1% reduction
  Perplexity: +4532.3% change
  Throughput: +7.1% change

MLP60_ATTN50:
  Params: 51.0% reduction
  Perplexity: +20833.0% change
  Throughput: +7.3% change





In [None]:
def calculate_memory_bandwidth_per_forward_pass(model, model_name="Model"):
    """
    Calculate total memory that needs to be loaded for one forward pass.
    """
    print(f"\n{'='*70}")
    print(f"Memory Bandwidth Analysis: {model_name}")
    print(f"{'='*70}")

    total_params = 0
    total_bytes = 0

    # Breakdown by layer type
    attention_params = 0
    attention_bytes = 0
    mlp_params = 0
    mlp_bytes = 0

    for name, param in model.named_parameters():
        num_params = param.numel()
        num_bytes = param.numel() * param.element_size()

        total_params += num_params
        total_bytes += num_bytes

        if 'self_attn' in name:
            attention_params += num_params
            attention_bytes += num_bytes
        elif 'mlp' in name:
            mlp_params += num_params
            mlp_bytes += num_bytes

    # Convert to readable units
    total_mb = total_bytes / (1024**2)
    total_gb = total_bytes / (1024**3)
    attention_mb = attention_bytes / (1024**2)
    mlp_mb = mlp_bytes / (1024**2)

    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Total Memory: {total_mb:.2f} MB ({total_gb:.3f} GB)")
    print(f"\nBreakdown:")
    print(f"  Attention: {attention_params:,} params = {attention_mb:.2f} MB")
    print(f"  MLP:       {mlp_params:,} params = {mlp_mb:.2f} MB")
    print(f"  Other:     {(total_bytes - attention_bytes - mlp_bytes)/(1024**2):.2f} MB")

    print(f"\nMemory per forward pass (single token generation):")
    print(f"  All weights must be loaded: {total_mb:.2f} MB")
    print(f"  32 layers × forward = {total_mb:.2f} MB total bandwidth")

    return {
        'total_params': total_params,
        'total_bytes': total_bytes,
        'total_mb': total_mb,
        'attention_mb': attention_mb,
        'mlp_mb': mlp_mb,
    }

# Measure both models
pruner = LlamaFFNPruner()
baseline_mem = calculate_memory_bandwidth_per_forward_pass(pruner.model, "BASELINE")
pruner.prune_model(0.6, 0.5)
pruned_mem = calculate_memory_bandwidth_per_forward_pass(pruner.model, "PRUNED")

# Calculate reduction
print(f"\n{'='*70}")
print(f"MEMORY BANDWIDTH REDUCTION")
print(f"{'='*70}")
reduction_pct = (1 - pruned_mem['total_mb'] / baseline_mem['total_mb']) * 100
print(f"Total: {baseline_mem['total_mb']:.2f} MB → {pruned_mem['total_mb']:.2f} MB")
print(f"Reduction: {reduction_pct:.1f}%")

attention_reduction = (1 - pruned_mem['attention_mb'] / baseline_mem['attention_mb']) * 100
mlp_reduction = (1 - pruned_mem['mlp_mb'] / baseline_mem['mlp_mb']) * 100
print(f"\nAttention: {attention_reduction:.1f}% reduction")
print(f"MLP:       {mlp_reduction:.1f}% reduction")

Using device: cuda


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Original intermediate size: 14336
Original attention heads: 32
Original KV heads: 8

Memory Bandwidth Analysis: BASELINE

Total Parameters: 8,030,261,248
Total Memory: 15316.51 MB (14.958 GB)

Breakdown:
  Attention: 1,342,177,280 params = 2560.00 MB
  MLP:       5,637,144,576 params = 10752.00 MB
  Other:     2004.51 MB

Memory per forward pass (single token generation):
  All weights must be loaded: 15316.51 MB
  32 layers × forward = 15316.51 MB total bandwidth

Pruning Model: MLP=60%, Attention=50%
    Aligning to 5632 neurons (was 5735)
Layer 0:
  MLP: 14336 -> 5632 neurons
  Attention: 32 -> 16 heads (KV: 8 -> 4)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons (was 5735)
Processed 8/32 layers...
    Aligning to 5632 neurons (was 5735)
    Aligning to 5632 neurons

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from huggingface_hub import login
from datasets import load_dataset, load_from_disk
import numpy as np
from tqdm import tqdm
import random
import json
import torch_pruning as tp
import gc


class LlamaFFNPruner:
    def __init__(self, model_id="meta-llama/Llama-3.1-8B-Instruct"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'left'

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map={"": "cuda"}
        )

        print(f"Model loaded: {model_id}")
        print(f"Total parameters: {self.count_parameters():,}")
        self._print_memory_usage("After model load")

    def _print_memory_usage(self, stage=""):
        """Print current GPU memory usage."""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / (1024**3)
            reserved = torch.cuda.memory_reserved() / (1024**3)
            print(f"[{stage}] GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

    def _wrap_model_for_pruning(self):
        """
        Wrap model to return only logits (not tuple) for torch-pruning compatibility.
        """
        class ModelWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids):
                # Get model output
                outputs = self.model(input_ids)
                # Return only logits (first element of output)
                return outputs.logits

        return ModelWrapper(self.model)

    def prune_model(self, prune_ratio=0.5, round_to=256):
        """
        Prune using torch-pruning MetaPruner with IN-PLACE weight replacement.

        Args:
            prune_ratio: Overall pruning ratio (0.5 = remove 50% of parameters)
            round_to: Round pruned dimensions to multiples of this (256 for GPU efficiency)
        """
        print(f"\n{'='*70}")
        print(f"Pruning Model with MetaPruner (IN-PLACE)")
        print(f"Prune ratio: {prune_ratio*100:.0f}%")
        print(f"GPU alignment: {round_to}")
        print(f"{'='*70}")

        self._print_memory_usage("Before pruning")

        # ⭐ Wrap model to return only logits (fix for tuple output issue)
        wrapped_model = self._wrap_model_for_pruning()

        # Create example inputs
        example_inputs = torch.randint(0, 50000, (1, 128)).to(self.device)

        # Identify layers to ignore (embeddings, layer norms)
        ignored_layers = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Embedding, nn.LayerNorm)):
                ignored_layers.append(module)

        print(f"Ignoring {len(ignored_layers)} layers (embeddings, layer norms)")

        # Define importance metric (magnitude-based)
        importance = tp.importance.MagnitudeImportance(p=2)

        # Get baseline stats using wrapped model
        print("\nCalculating baseline FLOPs and parameters...")
        base_macs, base_params = tp.utils.count_ops_and_params(wrapped_model, example_inputs)
        print(f"Baseline: {base_params:,} parameters, {base_macs:,} MACs")

        # Create MetaPruner with wrapped model
        print("\nInitializing MetaPruner...")
        pruner = tp.pruner.MetaPruner(
            wrapped_model,  # ⭐ Use wrapped model
            example_inputs,
            importance=importance,
            pruning_ratio=prune_ratio,

            # ⭐ Key parameters for GPU efficiency
            round_to=round_to,  # Align all dimensions to multiples of 256
            ignored_layers=ignored_layers,

            # Pruning strategy
            iterative_steps=1,  # One-shot pruning
            global_pruning=False,  # Prune each layer independently
            isomorphic=False,  # Allow different pruning ratios per layer
        )

        self._print_memory_usage("After pruner init")

        # Execute pruning (modifies wrapped_model.model in-place)
        print("\nExecuting one-shot pruning (in-place)...")
        pruner.step()
        print("✓ Pruning complete!")

        # The original self.model is now pruned (it's inside wrapped_model.model)
        # No need to reassign, it was modified in-place

        # Clean up pruner and wrapper to free memory
        del pruner
        del wrapped_model
        del example_inputs
        gc.collect()
        torch.cuda.empty_cache()

        self._print_memory_usage("After pruning and cleanup")

        # Get pruned stats (create new wrapper for verification)
        print("\nCalculating pruned FLOPs and parameters...")
        wrapped_model_check = self._wrap_model_for_pruning()
        example_inputs_check = torch.randint(0, 50000, (1, 128)).to(self.device)
        pruned_macs, pruned_params = tp.utils.count_ops_and_params(wrapped_model_check, example_inputs_check)
        del wrapped_model_check
        del example_inputs_check

        # Print results
        param_reduction = (1 - pruned_params / base_params) * 100
        macs_reduction = (1 - pruned_macs / base_macs) * 100

        print(f"\n{'='*70}")
        print(f"Pruning Results")
        print(f"{'='*70}")
        print(f"Parameters: {base_params:,} → {pruned_params:,} ({param_reduction:.1f}% reduction)")
        print(f"MACs:       {base_macs:,} → {pruned_macs:,} ({macs_reduction:.1f}% reduction)")

        self._print_memory_usage("Final state")

        return self.model

    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters())

    def get_model_size_mb(self):
        param_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
        return param_size / (1024 * 1024)


class ModelEvaluator:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def evaluate_perplexity(self, dataset_name="wikitext", split="test", max_samples=100):
        """Evaluate perplexity on a dataset."""
        print(f"\nEvaluating perplexity on {dataset_name}...")

        if dataset_name == "wikitext":
            dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
        else:
            dataset = load_dataset(dataset_name, split=split)

        dataset = dataset.select(range(min(max_samples, len(dataset))))

        total_loss = 0
        total_tokens = 0

        self.model.eval()
        with torch.no_grad():
            for item in tqdm(dataset, desc="Computing perplexity"):
                text = item['text'] if 'text' in item else str(item)

                if len(text.strip()) == 0:
                    continue

                encodings = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
                input_ids = encodings.input_ids.to(self.device)

                if input_ids.shape[1] < 2:
                    continue

                outputs = self.model(input_ids, labels=input_ids)
                num_tokens = input_ids.shape[1]
                total_loss += outputs.loss.item() * num_tokens
                total_tokens += num_tokens

        avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
        perplexity = np.exp(avg_loss)
        return perplexity

    def evaluate_generation_quality(self, prompts, max_new_tokens=300, batch_size=32):
        """Evaluate generation throughput."""
        print(f"\nEvaluating generation quality on {len(prompts)} prompts (batch_size={batch_size})...")

        all_outputs = []
        total_time = 0
        num_prompts = 0

        self.model.eval()

        for i in tqdm(range(0, len(prompts), batch_size), desc="Generating batches"):
            batch_prompts = prompts[i:i+batch_size]

            inputs = self.tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)

            torch.cuda.synchronize()
            start_time = time.time()

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            torch.cuda.synchronize()
            batch_time = time.time() - start_time

            # Decode outputs
            for j in range(len(batch_prompts)):
                generated_text = self.tokenizer.decode(outputs[j], skip_special_tokens=True)
                all_outputs.append(generated_text)

            total_time += batch_time
            num_prompts += len(batch_prompts)

        # Calculate throughput
        total_tokens = num_prompts * max_new_tokens
        throughput = total_tokens / total_time if total_time > 0 else 0

        results = {
            'outputs': all_outputs,
            'total_time': total_time,
            'total_tokens': total_tokens,
            'tokens_per_second': throughput,
            'avg_time_per_prompt': total_time / len(prompts),
        }

        print(f"Total tokens: {total_tokens}, Total time: {total_time:.2f}s, Throughput: {throughput:.2f} tok/s")
        return results

    def measure_forward_pass_speed(self, batch_size=32, seq_length=512, num_runs=1):
        """Measure forward pass throughput."""
        print(f"\nMeasuring forward pass speed (batch={batch_size}, seq={seq_length})...")

        input_ids = torch.randint(0, self.tokenizer.vocab_size, (batch_size, seq_length)).to(self.device)

        self.model.eval()

        # Warmup
        with torch.no_grad():
            for _ in range(10):
                _ = self.model(input_ids)

        torch.cuda.synchronize()
        start = time.time()

        with torch.no_grad():
            for _ in range(num_runs):
                _ = self.model(input_ids)

        torch.cuda.synchronize()
        elapsed = time.time() - start

        throughput = (batch_size * seq_length * num_runs) / elapsed

        print(f"Batch {batch_size}, seq {seq_length}: {throughput:.2f} tok/s")
        return throughput


def run_pruning_experiment(prune_ratios):
    """
    Run pruning experiments with different pruning ratios.
    Memory-efficient: only one model in memory at a time.

    Args:
        prune_ratios: List of pruning ratios to test
                     Example: [0.3, 0.4, 0.5]
    """
    results = {}

    # Load test prompts
    train_data_sample = load_from_disk("/content/drive/MyDrive/nq_subset")
    test_prompts = random.sample(train_data_sample["question"]["text"], 64)

    # ===== BASELINE EVALUATION =====
    print(f"\n{'='*70}")
    print(f"BASELINE (No Pruning)")
    print(f"{'='*70}")

    baseline_pruner = LlamaFFNPruner()
    baseline_evaluator = ModelEvaluator(
        baseline_pruner.model,
        baseline_pruner.tokenizer,
        baseline_pruner.device
    )

    baseline_params = baseline_pruner.count_parameters()
    baseline_ppl = baseline_evaluator.evaluate_perplexity(max_samples=50)
    baseline_gen = baseline_evaluator.evaluate_generation_quality(test_prompts)
    baseline_forward = baseline_evaluator.measure_forward_pass_speed()

    results['baseline'] = {
        'params': baseline_params,
        'perplexity': baseline_ppl,
        'generation_throughput': baseline_gen['tokens_per_second'],
        'forward_pass_throughput': baseline_forward,
        'total_tokens': baseline_gen['total_tokens'],
        'total_time': baseline_gen['total_time'],
    }

    print(f"\n{'='*70}")
    print(f"BASELINE RESULTS")
    print(f"{'='*70}")
    print(f"Parameters:     {baseline_params:,}")
    print(f"Perplexity:     {baseline_ppl:.2f}")
    print(f"Generation:     {baseline_gen['tokens_per_second']:.2f} tok/s")
    print(f"Forward pass:   {baseline_forward:.2f} tok/s")

    # ⭐ CRITICAL: Free baseline model memory before loading pruned models
    del baseline_pruner
    del baseline_evaluator
    gc.collect()
    torch.cuda.empty_cache()
    print("\n✓ Baseline model freed from memory")

    # ===== PRUNING EXPERIMENTS =====
    for prune_ratio in prune_ratios:
        print(f"\n{'='*70}")
        print(f"PRUNING EXPERIMENT: {prune_ratio*100:.0f}% ratio")
        print(f"{'='*70}")

        # Create and prune model (in-place, no extra memory)
        pruner = LlamaFFNPruner()
        pruner.prune_model(prune_ratio=prune_ratio, round_to=256)

        # Evaluate pruned model
        evaluator = ModelEvaluator(pruner.model, pruner.tokenizer, pruner.device)

        pruned_params = pruner.count_parameters()
        pruned_ppl = evaluator.evaluate_perplexity(max_samples=50)
        pruned_gen = evaluator.evaluate_generation_quality(test_prompts)
        pruned_forward = evaluator.measure_forward_pass_speed()

        # Calculate metrics
        param_reduction = (1 - pruned_params / baseline_params) * 100
        ppl_degradation = ((pruned_ppl - baseline_ppl) / baseline_ppl) * 100
        gen_speedup = pruned_gen['tokens_per_second'] / baseline_gen['tokens_per_second']
        forward_speedup = pruned_forward / baseline_forward

        print(f"\n{'='*70}")
        print(f"RESULTS")
        print(f"{'='*70}")
        print(f"Parameters:     {param_reduction:.1f}% reduction ({pruned_params:,} params)")
        print(f"Perplexity:     {pruned_ppl:.2f} ({ppl_degradation:+.1f}% change)")
        print(f"Generation:     {pruned_gen['tokens_per_second']:.2f} tok/s ({gen_speedup:.2f}x speedup)")
        print(f"Forward pass:   {pruned_forward:.2f} tok/s ({forward_speedup:.2f}x speedup)")

        if pruned_ppl > 5 * baseline_ppl:
            print("\n⚠️  WARNING: Model severely degraded (5x perplexity increase)!")

        print(f"\nSample generation:")
        print(f"{pruned_gen['outputs'][0][:200]}...")

        # Store results
        config_key = f"prune_{int(prune_ratio*100)}"
        results[config_key] = {
            'prune_ratio': prune_ratio,
            'params': pruned_params,
            'param_reduction_pct': param_reduction,
            'perplexity': pruned_ppl,
            'ppl_degradation_pct': ppl_degradation,
            'generation_throughput': pruned_gen['tokens_per_second'],
            'generation_speedup': gen_speedup,
            'forward_pass_throughput': pruned_forward,
            'forward_speedup': forward_speedup,
            'total_tokens': pruned_gen['total_tokens'],
            'total_time': pruned_gen['total_time'],
        }

        # ⭐ CRITICAL: Free this pruned model before next iteration
        del pruner
        del evaluator
        gc.collect()
        torch.cuda.empty_cache()
        print(f"\n✓ Pruned model {prune_ratio*100:.0f}% freed from memory")

    return results


if __name__ == "__main__":
    import os
    # Login with token from environment variable or set your token here
    # login(token=os.getenv("HF_TOKEN", "YOUR_HUGGINGFACE_TOKEN_HERE"))

    # Define pruning ratios to test
    prune_ratios = [
        0.3,  # 30% pruning (conservative)
        0.4,  # 40% pruning (moderate)
        0.5,  # 50% pruning (aggressive)
    ]

    # Run experiments
    results = run_pruning_experiment(prune_ratios)

    # Save results
    with open('pruning_results.json', 'w') as f:
        json.dump({k: {kk: vv for kk, vv in v.items() if kk != 'outputs'}
                   for k, v in results.items()}, f, indent=2)

    # Print summary
    print(f"\n{'='*70}")
    print("EXPERIMENT SUMMARY")
    print(f"{'='*70}")
    for config, result in results.items():
        if config == 'baseline':
            print(f"\n{config.upper()}:")
            print(f"  Params:      {result['params']:,}")
            print(f"  Perplexity:  {result['perplexity']:.2f}")
            print(f"  Generation:  {result['generation_throughput']:.2f} tok/s")
            print(f"  Forward:     {result['forward_pass_throughput']:.2f} tok/s")
        else:
            print(f"\n{config.upper()}:")
            print(f"  Params:      {result['param_reduction_pct']:.1f}% reduction")
            print(f"  Perplexity:  {result['ppl_degradation_pct']:+.1f}% change")
            print(f"  Generation:  {result['generation_speedup']:.2f}x speedup")
            print(f"  Forward:     {result['forward_speedup']:.2f}x speedup")

ModuleNotFoundError: No module named 'torch_pruning'

In [None]:
def test_forward_pass_speed(model, batch_size=32, seq_len=512, num_runs=2):
    """
    Test pure forward pass speed (no generation).
    """
    input_ids = torch.randint(0, 50000, (batch_size, seq_len)).cuda()

    model.eval()

    # Warmup
    with torch.no_grad():
        for _ in range(10):
            _ = model(input_ids)

    torch.cuda.synchronize()
    start = time.time()

    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)

    torch.cuda.synchronize()
    elapsed = time.time() - start

    throughput = (batch_size * seq_len * num_runs) / elapsed
    print(f"Batch {batch_size}, seq {seq_len}: {throughput:.2f} tok/s")
    return throughput

# Test baseline
pruner = LlamaFFNPruner()
baseline_speed = test_forward_pass_speed(pruner.model)

pruner.prune_model(0.6, 0.5)
# Test pruned
pruned_speed = test_forward_pass_speed(pruner.model)

# Calculate speedup
speedup = pruned_speed / baseline_speed
print(f"Forward pass speedup: {speedup:.2f}x")

Using device: cuda


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
def test_llama_local_pruning():
    """
    Prune Llama 3.1 8B using LocalPruner (layer-by-layer pruning).
    Faster than MetaPruner because it builds dependency graphs per layer.
    """
    import torch
    import torch.nn as nn
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch_pruning as tp
    from huggingface_hub import login
    import gc
    import time

    print("="*70)
    print("Testing Llama 3.1 8B with LocalPruner")
    print("="*70)

    # Login with token from environment variable or set your token here
    import os
    # login(token=os.getenv("HF_TOKEN", "YOUR_HUGGINGFACE_TOKEN_HERE"))

    # Load model
    print("\n1. Loading model...")
    model_id = "meta-llama/Llama-3.1-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    params_before = sum(p.numel() for p in model.parameters())
    print(f"✓ Model loaded")
    print(f"  Parameters: {params_before:,}")

    # Print memory
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024**3)
        print(f"  GPU Memory: {allocated:.2f}GB")

    # Local pruning function
    def prune_layer_local(layer, layer_idx, prune_ratio=0.5, round_to=256):
        """
        Prune a single layer using torch-pruning's local dependency graph.
        """
        print(f"\n  Pruning layer {layer_idx}...")

        mlp = layer.mlp

        # Calculate importance scores manually
        gate_weight = mlp.gate_proj.weight.data
        up_weight = mlp.up_proj.weight.data

        gate_importance = torch.max(torch.abs(gate_weight), dim=1).values
        up_importance = torch.max(torch.abs(up_weight), dim=1).values
        importance = gate_importance + up_importance

        # Determine neurons to prune
        original_size = gate_weight.shape[0]
        num_to_prune = int(original_size * prune_ratio)
        num_to_keep = original_size - num_to_prune

        # Round to alignment
        num_to_keep = (num_to_keep // round_to) * round_to
        if num_to_keep == 0:
            num_to_keep = round_to

        num_to_prune = original_size - num_to_keep

        if num_to_prune <= 0:
            print(f"    No pruning needed")
            return original_size, original_size

        # Get indices of least important neurons to prune
        _, pruning_indices = torch.topk(importance, num_to_prune, largest=False)
        pruning_indices = pruning_indices.sort().values

        try:
            # Create example input for this layer
            # Input to MLP is [batch, seq_len, hidden_size]
            hidden_size = mlp.gate_proj.in_features
            example_input = torch.randn(1, 10, hidden_size).to(gate_weight.device).to(gate_weight.dtype)

            # Build dependency graph for this layer's MLP only
            DG = tp.DependencyGraph()
            DG.build_dependency(mlp, example_inputs=example_input)

            # Create pruning plan for gate_proj
            # This will automatically handle gate_proj, up_proj, and down_proj dependencies
            pruning_plan = DG.get_pruning_plan(
                mlp.gate_proj,
                tp.prune_linear_out_channels,
                idxs=pruning_indices.tolist()
            )

            # Execute pruning
            pruning_plan.exec()

            print(f"    ✓ Layer {layer_idx}: {original_size} -> {num_to_keep} neurons")
            return original_size, num_to_keep

        except Exception as e:
            print(f"    ❌ LocalPruner failed for layer {layer_idx}: {e}")
            print(f"    Falling back to manual pruning...")

            # Fallback to manual pruning if LocalPruner fails
            _, indices_to_keep = torch.topk(importance, num_to_keep, largest=True)
            indices_to_keep = indices_to_keep.sort().values

            # Create new layers
            new_gate = nn.Linear(
                mlp.gate_proj.in_features,
                num_to_keep,
                bias=False
            ).to(gate_weight.device).to(gate_weight.dtype)

            new_up = nn.Linear(
                mlp.up_proj.in_features,
                num_to_keep,
                bias=False
            ).to(up_weight.device).to(up_weight.dtype)

            new_down = nn.Linear(
                num_to_keep,
                mlp.down_proj.out_features,
                bias=False
            ).to(mlp.down_proj.weight.device).to(mlp.down_proj.weight.dtype)

            # Copy weights
            new_gate.weight.data = gate_weight[indices_to_keep, :]
            new_up.weight.data = up_weight[indices_to_keep, :]
            new_down.weight.data = mlp.down_proj.weight.data[:, indices_to_keep]

            # Replace
            layer.mlp.gate_proj = new_gate
            layer.mlp.up_proj = new_up
            layer.mlp.down_proj = new_down

            print(f"    ✓ Layer {layer_idx} (manual): {original_size} -> {num_to_keep} neurons")
            return original_size, num_to_keep

    # Prune all layers
    print("\n2. Pruning MLP layers (50% pruning, layer-by-layer)...")
    start_time = time.time()

    total_original = 0
    total_kept = 0

    for i, layer in enumerate(model.model.layers):
        orig, kept = prune_layer_local(layer, i, prune_ratio=0.5, round_to=256)
        total_original += orig
        total_kept += kept

        # Clean up after each layer
        if i % 4 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    prune_time = time.time() - start_time

    print(f"\n✓ Pruning complete in {prune_time:.1f}s!")
    print(f"  MLP neurons: {total_original:,} -> {total_kept:,}")
    print(f"  Reduction: {(1 - total_kept/total_original)*100:.1f}%")

    # Clean up
    gc.collect()
    torch.cuda.empty_cache()

    # Check results
    print("\n3. Verifying results...")
    params_after = sum(p.numel() for p in model.parameters())
    reduction = (1 - params_after / params_before) * 100

    print(f"  Parameters before: {params_before:,}")
    print(f"  Parameters after:  {params_after:,}")
    print(f"  Reduction: {reduction:.1f}%")

    # Print memory
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024**3)
        print(f"  GPU Memory: {allocated:.2f}GB")

    # Test forward pass
    print("\n4. Testing forward pass...")
    test_input = torch.randint(0, 50000, (1, 128)).cuda()

    with torch.no_grad():
        outputs = model(test_input)

    print(f"✓ Forward pass works: {outputs.logits.shape}")

    # Test generation
    print("\n5. Testing text generation...")
    test_prompt = "The future of AI is"
    inputs = tokenizer(test_prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.7,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"✓ Generation works!")
    print(f"  Prompt: {test_prompt}")
    print(f"  Generated: {generated_text}")

    # Measure speed
    print("\n6. Measuring forward pass speed...")
    batch_size = 32
    seq_len = 512
    test_input = torch.randint(0, 50000, (batch_size, seq_len)).cuda()

    # Warmup
    with torch.no_grad():
        for _ in range(5):
            _ = model(test_input)

    torch.cuda.synchronize()
    start = time.time()

    with torch.no_grad():
        for _ in range(5):
            _ = model(test_input)

    torch.cuda.synchronize()
    elapsed = time.time() - start

    throughput = (batch_size * seq_len * 100) / elapsed
    print(f"✓ Throughput: {throughput:.2f} tok/s")

    print("\n" + "="*70)
    print("✅ Llama 3.1 8B LocalPruner SUCCESS!")
    print("="*70)
    print(f"Model pruned from {params_before:,} to {params_after:,} parameters")
    print(f"Reduction: {reduction:.1f}%")
    print(f"Pruning time: {prune_time:.1f}s")
    print(f"Forward throughput: {throughput:.2f} tok/s")

    return model, tokenizer


# Run the test
pruned_model, tokenizer = test_llama_local_pruning()

Testing Llama 3.1 8B with LocalPruner

1. Loading model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✓ Model loaded
  Parameters: 8,030,261,248
  GPU Memory: 36.79GB

2. Pruning MLP layers (50% pruning, layer-by-layer)...

  Pruning layer 0...
    ❌ LocalPruner failed for layer 0: 'DependencyGraph' object has no attribute 'get_pruning_plan'
    Falling back to manual pruning...
    ✓ Layer 0 (manual): 14336 -> 7168 neurons

  Pruning layer 1...
    ❌ LocalPruner failed for layer 1: 'DependencyGraph' object has no attribute 'get_pruning_plan'
    Falling back to manual pruning...
    ✓ Layer 1 (manual): 14336 -> 7168 neurons

  Pruning layer 2...
    ❌ LocalPruner failed for layer 2: 'DependencyGraph' object has no attribute 'get_pruning_plan'
    Falling back to manual pruning...
    ✓ Layer 2 (manual): 14336 -> 7168 neurons

  Pruning layer 3...
    ❌ LocalPruner failed for layer 3: 'DependencyGraph' object has no attribute 'get_pruning_plan'
    Falling back to manual pruning...
    ✓ Layer 3 (manual): 14336 -> 7168 neurons

  Pruning layer 4...
    ❌ LocalPruner failed for layer 4:

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



3. Verifying results...
  Parameters before: 8,030,261,248
  Parameters after:  5,211,688,960
  Reduction: 35.1%
  GPU Memory: 31.54GB

4. Testing forward pass...
✓ Forward pass works: torch.Size([1, 128, 128256])

5. Testing text generation...
✓ Generation works!
  Prompt: The future of AI is
  Generated: The future of AI is going to be dominated by robots. Robots are going to be dominated by computers. Computers are going to

6. Measuring forward pass speed...
✓ Throughput: 393231.32 tok/s

✅ Llama 3.1 8B LocalPruner SUCCESS!
Model pruned from 8,030,261,248 to 5,211,688,960 parameters
Reduction: 35.1%
Pruning time: 26.3s
Forward throughput: 393231.32 tok/s
