In [26]:
import time 
import heapq 
import torch 
import torch.nn as nn 
from lib.sparsegpt import SparseGPT 
from lib.layerwrapper import WrappedGPT
from lib.ablate import AblateGPT 
import argparse
import os 
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
from lib.prune import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers
from lib.eval import eval_ppl, eval_zero_shot
import random
import torch
from datasets import load_dataset

In [None]:
# Wrapper for tokenized input IDs
class TokenizerWrapper:
    def __init__(self, input_ids):
        self.input_ids = input_ids

def prepare_calibration_input(model, dataloader, device):
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Set the appropriate device if "embed_tokens" is mapped
    if "model.embed_tokens" in model.hf_device_map:
        device = model.hf_device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps_cpu = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device="cpu")
    outs_cpu = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device="cpu")
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps_cpu[cache['i']] = inp.cpu()  # Store on CPU to save GPU memory
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask'].cpu()
            cache['position_ids'] = kwargs['position_ids'].cpu()
            raise ValueError

    # Wrap the first layer to intercept input/output
    layers[0] = Catcher(layers[0])

    for batch in dataloader:
        try:
            batch_gpu = batch[0].to(device)
            model(batch_gpu)
        except ValueError:
            pass
        torch.cuda.empty_cache()  # Clear GPU cache after each batch

    # Restore the original model layer
    layers[0] = layers[0].module
    model.config.use_cache = use_cache

    # Move only necessary data to GPU in a batched manner later in the pipeline
    return inps_cpu, outs_cpu, cache['attention_mask'], cache['position_ids']


def find_layers(module, layers=[nn.Linear], name=''):
    """
    Recursively find the layers of a certain type in a module.

    Args:
        module (nn.Module): PyTorch module.
        layers (list): List of layer types to find.
        name (str): Name of the module.

    Returns:
        dict: Dictionary of layers of the given type(s) within the module.
    """
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res

def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
    thres_cumsum = sum_before * alpha 
    sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1))
    thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1)
    W_mask = (W_metric <= thres)
    cur_sparsity = (W_mask==True).sum() / W_mask.numel()
    return W_mask, cur_sparsity

def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    use_cache = model.config.use_cache 
    model.config.use_cache = False 

    print("loading calibdation data")
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
    print("dataset loading complete")
    with torch.no_grad():
        inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)

    layers = model.model.layers
    for i in range(len(layers)):
        layer = layers[i]
        subset = find_layers(layer)

        if f"model.layers.{i}" in model.hf_device_map:   ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
            dev = model.hf_device_map[f"model.layers.{i}"]
            inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev)

        wrapped_layers = {}
        for name in subset:
            wrapped_layers[name] = WrappedGPT(subset[name])

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in wrapped_layers:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        for h in handles:
            h.remove()

        for name in subset:
            print(f"pruning layer {i} name {name}")
            W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))

            W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False
            if prune_n != 0:
                # structured n:m sparsity
                for ii in range(W_metric.shape[1]):
                    if ii % prune_m == 0:
                        tmp = W_metric[:,ii:(ii+prune_m)].float()
                        W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
            else:
                sort_res = torch.sort(W_metric, dim=-1, stable=True)

                if args.use_variant:
                    # wanda variant 
                    tmp_metric = torch.cumsum(sort_res[0], dim=1)
                    sum_before = W_metric.sum(dim=1)

                    alpha = 0.4
                    alpha_hist = [0., 0.8]
                    W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    while (torch.abs(cur_sparsity - args.sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001):
                        if cur_sparsity > args.sparsity_ratio:
                            alpha_new = (alpha + alpha_hist[0]) / 2.0
                            alpha_hist[1] = alpha
                        else:
                            alpha_new = (alpha + alpha_hist[1]) / 2.0
                            alpha_hist[0] = alpha

                        alpha = alpha_new 
                        W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}")
                else:
                    # unstructured pruning
                    indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)]
                    W_mask.scatter_(1, indices, True)

            subset[name].weight.data[W_mask] = 0  ## set weights to zero 

        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        inps, outs = outs, inps

    model.config.use_cache = use_cache 
    torch.cuda.empty_cache()


def get_llm(model_name, cache_dir="llm_weights"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float16, 
        cache_dir=cache_dir,  # Use the specified cache directory
        low_cpu_mem_usage=True, 
        device_map="auto",
        use_auth_token=True
    )
    model.seqlen = model.config.max_position_embeddings
    return model

# Load and process c4 dataset
def get_c4(nsamples, seed, seqlen, tokenizer):
    # Load train and validation datasets with streaming
    traindata = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    # Initialize random seed for sample selection
    random.seed(seed)

    # Generator to yield tokenized batches on-the-fly
    def data_generator():
        for _ in range(nsamples):
            while True:
                # Sample a random document and tokenize it
                i = random.randint(0, len(traindata) - 1)
                trainenc = tokenizer(traindata[i]['text'], return_tensors='pt', truncation=True, max_length=seqlen)
                if trainenc.input_ids.shape[1] >= seqlen:
                    break
            # Slice a sequence-length section randomly from tokenized data
            start = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
            input_ids = trainenc.input_ids[:, start:start + seqlen]
            labels = input_ids.clone()
            labels[:, :-1] = -100  # Only predict next token
            
            yield (input_ids, labels)

    # Wrap generator to provide trainloader and test data
    trainloader = list(data_generator())  # Lazily load samples
    valenc = tokenizer(' '.join(valdata[:100]['text']), return_tensors='pt', truncation=True, max_length=seqlen)
    valenc = TokenizerWrapper(valenc.input_ids)
    
    return trainloader, valenc

# Function to select the appropriate loader based on dataset name
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    
# Load and process wikitext2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    # Load train and test datasets
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Encode datasets
    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

In [None]:
def test_prepare_calibration_input_initialization():
    model = get_llm("baffo32/decapoda-research-llama-7B-hf")
    tokenizer = AutoTokenizer.from_pretrained("baffo32/decapoda-research-llama-7B-hf", use_fast=False)
    device = torch.device("cuda:0")
    dataloader, _ = get_loaders("c4", nsamples=4, seed=0, seqlen=model.seqlen, tokenizer=tokenizer)

    inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)
    assert inps.shape == (128, model.seqlen, model.config.hidden_size), "Incorrect shape for `inps`"
    assert outs.shape == (128, model.seqlen, model.config.hidden_size), "Incorrect shape for `outs`"
    assert attention_mask is not None, "Attention mask should not be None"
    assert position_ids is not None, "Position IDs should not be None"

def test_prepare_calibration_input_device():
    cache_dir = "llm_weights"  # Specify your cache directory for local loading
    model = get_llm("baffo32/decapoda-research-llama-7B-hf", cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained("baffo32/decapoda-research-llama-7B-hf", use_fast=False, cache_dir=cache_dir)
    device = torch.device("cuda:0")
    dataloader, _ = get_loaders("c4", nsamples=4, seed=0, seqlen=model.seqlen, tokenizer=tokenizer)

    # Run prepare_calibration_input
    inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)
    
    # Check that the inps and outs tensors are on the CPU, as defined in prepare_calibration_input
    assert inps.device == torch.device("cpu"), "Input tensor `inps` should be on the CPU initially"
    assert outs.device == torch.device("cpu"), "Output tensor `outs` should be on the CPU initially"
    
    # Check that the attention mask and position IDs are also on the CPU
    assert attention_mask.device == torch.device("cpu"), "Attention mask should be on the CPU"
    assert position_ids.device == torch.device("cpu"), "Position IDs should be on the CPU"
    
    # Check shapes to ensure tensors are correctly sized
    assert inps.shape == (128, model.seqlen, model.config.hidden_size), "Incorrect shape for `inps`"
    assert outs.shape == (128, model.seqlen, model.config.hidden_size), "Incorrect shape for `outs`"
    assert attention_mask is not None, "Attention mask should not be None"
    assert position_ids is not None, "Position IDs should not be None"

    print("test_prepare_calibration_input_device passed.")

def test_find_layers():
    cache_dir = "llm_weights"  # Specify your cache directory for local loading
    model = get_llm("baffo32/decapoda-research-llama-7B-hf", cache_dir=cache_dir)
    layers = find_layers(model, layers=[nn.Linear])

    assert isinstance(layers, dict), "Output should be a dictionary"
    assert len(layers) > 0, "No layers found when expected"
    for name, layer in layers.items():
        assert isinstance(layer, nn.Linear), f"Layer `{name}` is not of type `nn.Linear`"
        
def test_prune_wanda_sparsity():
    args = argparse.Namespace(nsamples=4, seed=0, use_variant=False, sparsity_ratio=0.5, prune_method='wanda')
    cache_dir = "llm_weights"  # Specify your cache directory for local loading
    model = get_llm("baffo32/decapoda-research-llama-7B-hf", cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained("baffo32/decapoda-research-llama-7B-hf")
    device = torch.device("cuda:0")

    prune_wanda(args, model, tokenizer, device=device)

    for name, layer in find_layers(model).items():
        sparsity = (layer.weight == 0).float().mean().item()
        assert abs(sparsity - args.sparsity_ratio) < 0.05, f"Sparsity mismatch in layer `{name}`"

def test_prune_wanda_weight_modification():
    args = argparse.Namespace(nsamples=4, seed=0, use_variant=False, sparsity_ratio=0.5, prune_method='wanda')
    cache_dir = "llm_weights"  # Specify your cache directory for local loading
    model = get_llm("baffo32/decapoda-research-llama-7B-hf", cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained("baffo32/decapoda-research-llama-7B-hf")
    device = torch.device("cuda:0")

    # Capture original weights before pruning
    original_weights = {name: layer.weight.clone() for name, layer in find_layers(model).items()}

    prune_wanda(args, model, tokenizer, device=device)

    for name, layer in find_layers(model).items():
        modified_weights = layer.weight
        assert torch.sum(original_weights[name] != modified_weights) > 0, f"Layer `{name}` was not modified"


In [None]:
test_prune_wanda_sparsity()

Loading checkpoint shards: 100%|██████████| 33/33 [00:56<00:00,  1.72s/it]


Loading calibration data
Dataset loading complete


UnboundLocalError: local variable 'dev' referenced before assignment

: 