In [3]:
import argparse
import datetime
import os
import random

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import numpy
import torch
import torch.multiprocessing as mp
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerFast)
from transformers.modeling_attn_mask_utils import \
    _prepare_4d_causal_attention_mask

from lib import utils

# Command-line options, registered in table order (which is also --help order).
# Each entry is (flag, keyword arguments for ArgumentParser.add_argument).
_CLI_OPTIONS = (
    # seed for torch, random and numpy
    ('--seed', {'default': 0, 'type': int}),
    # sequences per forward call of a layer
    ('--batch_size', {'default': 2, 'type': int}),
    # number of calibration sequences to sample
    ('--devset_size', {'default': 256, 'type': int}),
    # tokens per calibration sequence
    ('--ctx_size', {'default': 4096, 'type': int}),
    ('--base_model', {
        'default': '/data/common-weights-cy/Llama-3.2-1B',
        'type': str
    }),
    # final directory for the per-layer statistics and activation checkpoints
    ('--save_path', {'default': 'hessians/llama3.2-1b', 'type': str}),
    # optional staging directory; files written here are rsync'd to save_path
    ('--scratch_path', {'default': None, 'type': str}),
    # rows of activations per work item handed to a GPU worker
    ('--chunk_size', {'default': 256, 'type': int}),
    # passed to rsync --bwlimit; values <= 0 mean no limit
    ('--async_copy_speed', {'default': -1, 'type': int}),
    # checkpoint activations every N layers (with --save_activations)
    ('--act_save_rate', {'default': 4, 'type': int}),
    ('--save_activations', {'action': 'store_true'}),
    # worker processes used while sampling the calibration set
    ('--sample_proc', {'default': 4, 'type': int}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
del _flag, _options


def move_fn(in_q, async_copy_speed):
    """Worker loop that moves finished files from scratch to their final path.

    Reads ``(src, tgt)`` pairs from ``in_q`` until a ``None`` sentinel arrives.
    Each pair is copied with ``rsync``; ``src`` is deleted only after rsync
    reports success, so a failed copy never destroys the only copy of a file.

    Args:
        in_q: queue of ``(src, tgt)`` path pairs, terminated by ``None``.
        async_copy_speed: value for ``rsync --bwlimit``; ``<= 0`` copies
            without a bandwidth limit.
    """
    import subprocess

    # async copy to avoid slow disk
    while True:
        item = in_q.get()
        if item is None:
            return
        src, tgt = item
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through verbatim.
        cmd = ['rsync']
        if async_copy_speed > 0:
            cmd.append(f'--bwlimit={async_copy_speed}')
        cmd += [src, tgt]
        try:
            result = subprocess.run(cmd)
        except OSError as err:
            # e.g. rsync not installed: keep the source and keep serving.
            print(f'failed to move {src} to {tgt}: {err}')
            continue
        if result.returncode != 0:
            print(f'failed to move {src} to {tgt}: '
                  f'rsync exited with code {result.returncode}')
            continue
        os.remove(src)
        print(f'moved {src} to {tgt}')


def forward_layer(layer, position_ids, attention_mask, bs, device, in_q,
                  out_q):
    """Worker process: run one decoder layer over queued activation chunks.

    Registers statistics hooks (via ``utils.register_H_hook``) on four linear
    sub-modules. It then repeatedly takes a chunk of activations from ``in_q``,
    runs it through ``layer`` ``bs`` rows at a time on ``device``, and writes
    the outputs back into that same chunk tensor in place. On the ``None``
    sentinel it moves its tensors back to CPU and puts one dict of hook results
    (keys ``'qkv'``, ``'o'``, ``'up'``, ``'down'``) on ``out_q``.

    Args:
        layer: one entry of ``model.model.layers``; must expose
            ``self_attn.q_proj``/``self_attn.o_proj`` and
            ``mlp.up_proj``/``mlp.down_proj``.
        position_ids: position ids built in ``main`` for a batch of ``bs``
            sequences.
        attention_mask: causal mask built in ``main`` for a batch of ``bs``
            sequences.
        bs: rows per forward call; every chunk's length must be a multiple
            of it (asserted below).
        device: target passed to ``.to()``; ``main`` passes the worker index,
            i.e. one CUDA device per worker.
        in_q: queue of activation chunks, terminated by ``None``.
        out_q: queue that receives this worker's hook results.
    """
    # Grad mode is per-thread state and a spawned worker does not inherit the
    # parent's setting, so disable autograd here as well.
    torch.set_grad_enabled(False)
    layer = layer.to(device)
    position_ids = position_ids.to(device)
    attention_mask = attention_mask.to(device)
    # Only q_proj is hooked for the 'qkv' group, presumably because q/k/v
    # consume the same input -- confirm against utils.register_H_hook.
    done_qkv = utils.register_H_hook(layer.self_attn.q_proj, device)
    done_o = utils.register_H_hook(layer.self_attn.o_proj, device)
    done_up = utils.register_H_hook(layer.mlp.up_proj, device)
    done_down = utils.register_H_hook(layer.mlp.down_proj, device)

    while True:
        dev_emb = in_q.get()
        if dev_emb is None:
            # Shutdown sentinel: move tensors back to CPU, then report the
            # statistics collected by each hook.
            layer = layer.cpu()
            position_ids = position_ids.cpu()
            attention_mask = attention_mask.cpu()
            out_q.put({
                'qkv': done_qkv(),
                'o': done_o(),
                'up': done_up(),
                'down': done_down()
            })
            return

        # position_ids / attention_mask are sized for exactly bs sequences.
        assert len(dev_emb) % bs == 0
        for i in range(len(dev_emb) // bs):
            # Overwrite the input rows with the layer output. The chunk is
            # expected to be a view of main's shared-memory dev_emb
            # (share_memory_() is called there), so the parent sees the
            # result and feeds it to the next layer. [0] is presumably the
            # hidden-states element of the layer's output tuple.
            # NOTE(review): newer transformers versions compute rotary
            # position_embeddings at the model level and pass them to each
            # decoder layer; confirm this call works with the installed version.
            dev_emb[i * bs:(i + 1) * bs] = layer(
                dev_emb[i * bs:(i + 1) * bs].to(device),
                position_ids=position_ids,
                attention_mask=attention_mask,
                use_cache=False,
                output_attentions=False)[0].cpu()


def accumulate(in_q, move_q, ngpus, args, transformer_layer_index):
    """Merge per-GPU hook statistics for one layer and save one file per key.

    Args:
        in_q: queue that will receive exactly ``ngpus`` dicts from
            ``forward_layer`` workers. Each maps a key (``'qkv'``, ``'o'``,
            ``'up'``, ``'down'``) to an indexable ``(H_sum, mu_sum, count)``.
            The arithmetic below treats these as running sums of ``x x^T``
            and ``x`` over ``count`` samples.
        move_q: queue of ``(src, tgt)`` paths consumed by ``move_fn``; only
            used when ``args.scratch_path`` is set.
        ngpus: number of worker messages to wait for; ``0`` saves nothing.
        args: parsed CLI namespace; ``scratch_path`` and ``save_path`` are read.
        transformer_layer_index: layer number used in the output file names.

    Each ``{transformer_layer_index}_{key}.pt`` holds:
        - ``flatH``: the covariance ``E[x x^T] - mu mu^T`` in float32, packed
          by ``utils.sym_to_flat``.
        - ``mu``: the float32 mean.
        - ``n``: the matrix dimension.
        - ``ct``: the total sample count.

    The file is written to ``args.scratch_path`` when set (then queued for
    moving to ``args.save_path``), otherwise directly to ``args.save_path``.
    """
    Hs = {}
    mus = {}
    cts = {}

    for _ in range(ngpus):
        out = in_q.get()
        for key in out:
            H_sum, mu_sum, count = out[key][0], out[key][1], out[key][2]
            if key not in Hs:
                # Accumulate on CPU in the dtype the hooks produced.
                Hs[key] = torch.zeros(H_sum.shape, dtype=H_sum.dtype)
                mus[key] = torch.zeros(mu_sum.shape, dtype=mu_sum.dtype)
                cts[key] = 0
            Hs[key].add_(H_sum)
            mus[key].add_(mu_sum)
            cts[key] += count

    out_dir = (args.scratch_path
               if args.scratch_path is not None else args.save_path)
    for key, H in Hs.items():
        mu = mus[key]
        mu.div_(cts[key])
        H.div_(cts[key])
        # Second moment -> covariance: subtract the outer product mu mu^T.
        H.addmm_(-mu.unsqueeze(-1), mu.unsqueeze(0))
        file_name = f"{transformer_layer_index}_{key}.pt"
        written_path = f"{out_dir}/{file_name}"
        torch.save(
            {
                'flatH': utils.sym_to_flat(H.to(torch.float32)),
                'mu': mu.to(torch.float32),
                'n': H.shape[0],
                'ct': cts[key]
            }, written_path)
        if args.scratch_path is not None:
            move_q.put((written_path, f"{args.save_path}/{file_name}"))


def _save_dev_activations(dev_emb, after_layer, path):
    """Write a resumable checkpoint of ``dev_emb`` to ``path``.

    ``main`` reloads ``{save_path}/dev_activations.pt`` on start-up and skips
    every layer whose index is ``<= after_layer``.
    """
    torch.save(
        {
            'dev_emb': dev_emb,
            'after_layer': after_layer,
            'timestamp': str(datetime.datetime.now())
        }, path)


def main(args):
    """Collect per-layer hook statistics for every decoder layer of a causal LM.

    For each layer of ``model.model.layers`` not already covered by a cached
    ``dev_activations.pt``:

    1. ``dev_emb`` (calibration activations in shared CPU memory) is split
       into ``chunk_size`` slices. The slices are queued for one
       ``forward_layer`` worker per visible CUDA device, capped by the number
       of chunks.
    2. Workers run the layer with statistics hooks and overwrite each slice
       with the layer output, so ``dev_emb`` becomes the next layer's input.
    3. An ``accumulate`` process merges the workers' statistics and saves
       ``{layer}_{key}.pt`` files (staged through ``scratch_path`` and
       ``move_fn`` when set).
    4. With ``--save_activations``, ``dev_emb`` is checkpointed every
       ``act_save_rate`` layers and after the last layer, so a run can resume.

    Raises:
        RuntimeError: if no CUDA device is visible, or if a worker or the
            accumulator exits with a non-zero code. This replaces continuing
            with stale activations or blocking forever on a queue.
    """
    print("loading model...")
    model = AutoModelForCausalLM.from_pretrained(args.base_model,
                                                 torch_dtype="auto",
                                                 low_cpu_mem_usage=True)
    print("loaded model!")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    final_act_path = f"{args.save_path}/dev_activations.pt"
    if os.path.isfile(final_act_path):
        print("loading cached dataset...")
        loaded_dev_activations = torch.load(final_act_path)
        after_layer = loaded_dev_activations['after_layer']
        dev_emb = loaded_dev_activations['dev_emb']
        print(
            f"loaded cached dataset from {loaded_dev_activations['timestamp']}"
        )
    else:
        print("loading dataset...")
        devset = utils.sample_rp1t(tokenizer,
                                   args.devset_size,
                                   args.ctx_size,
                                   nproc=args.sample_proc)
        dev_emb = model.model.embed_tokens(devset)
        after_layer = -1
        print("loaded dataset!")

    print(f"dev_emb dtype: {dev_emb.dtype}")
    # Workers overwrite slices of dev_emb with layer outputs; shared memory is
    # what makes those writes visible in this process.
    dev_emb.share_memory_()

    # Every batch uses positions 0..ctx_size-1 and the same causal mask.
    position_ids = torch.arange(args.ctx_size, dtype=torch.int64)[None, :] + \
        torch.zeros(args.batch_size, args.ctx_size, dtype=torch.int64)
    if hasattr(model.config, 'sliding_window'):
        attention_mask = _prepare_4d_causal_attention_mask(
            None, (args.batch_size, args.ctx_size),
            dev_emb[0:args.batch_size],
            0,
            sliding_window=model.config.sliding_window)
    else:
        attention_mask = _prepare_4d_causal_attention_mask(
            None, (args.batch_size, args.ctx_size),
            dev_emb[0:args.batch_size], 0)

    if args.scratch_path is not None:
        move_q = mp.Queue()
        move_p = mp.Process(target=move_fn,
                            args=(move_q, args.async_copy_speed))
        move_p.start()
    else:
        move_q = None
        move_p = None

    num_layers = len(model.model.layers)
    try:
        for transformer_layer_index in range(num_layers):
            if transformer_layer_index <= after_layer:
                print(
                    f"skipping layer {transformer_layer_index} because it is before cached activations at layer {after_layer}"
                )
                continue

            transformer_layer = model.model.layers[transformer_layer_index]
            # Sanity check: expect seven nn.Linear modules per decoder layer
            # (presumably q/k/v/o in attention and gate/up/down in the MLP).
            assert (len([
                m for m in transformer_layer.modules()
                if isinstance(m, torch.nn.Linear)
            ]) == 7)

            chunk_size = min(args.chunk_size, len(dev_emb))
            ngpus = min(torch.cuda.device_count(), len(dev_emb) // chunk_size)
            if ngpus == 0:
                raise RuntimeError(
                    "no CUDA device available to run forward_layer workers")
            # Validate before spawning anything: failing after the workers
            # start would leave them blocked on in_q.
            assert len(
                dev_emb
            ) % args.batch_size == 0 and chunk_size % args.batch_size == 0

            # The context manager shuts the manager's server process down once
            # this layer is finished instead of leaking one per layer.
            with mp.get_context('spawn').Manager() as manager:
                in_q = manager.Queue()
                out_q = manager.Queue()

                accumulate_proc = mp.Process(target=accumulate,
                                             args=(out_q, move_q, ngpus, args,
                                                   transformer_layer_index))
                accumulate_proc.start()

                forward_procs = []
                for device_index in range(ngpus):
                    p = mp.Process(target=forward_layer,
                                   args=(transformer_layer, position_ids,
                                         attention_mask, args.batch_size,
                                         device_index, in_q, out_q))
                    p.start()
                    forward_procs.append(p)

                # Slicing clamps at the end, so the last chunk may be shorter.
                for start in range(0, len(dev_emb), chunk_size):
                    in_q.put(dev_emb[start:start + chunk_size])

                # One shutdown sentinel per worker.
                for _ in range(ngpus):
                    in_q.put(None)

                for p in forward_procs:
                    p.join()

                bad_codes = [
                    p.exitcode for p in forward_procs if p.exitcode != 0
                ]
                if bad_codes:
                    # accumulate() still waits for ngpus results that will
                    # never arrive; stop it instead of deadlocking.
                    accumulate_proc.terminate()
                    accumulate_proc.join()
                    raise RuntimeError(
                        f"forward_layer worker(s) failed on layer "
                        f"{transformer_layer_index} with exit codes {bad_codes}")

                accumulate_proc.join()
                if accumulate_proc.exitcode != 0:
                    raise RuntimeError(
                        f"accumulate failed on layer {transformer_layer_index} "
                        f"with exit code {accumulate_proc.exitcode}")

            transformer_layer.cpu()
            model.model.layers[transformer_layer_index] = None
            utils.clean()

            is_last_layer = transformer_layer_index == num_layers - 1
            if args.save_activations and (
                    transformer_layer_index % args.act_save_rate == 0
                    or is_last_layer):
                if args.scratch_path is not None:
                    scratch_act_path = f'{args.scratch_path}/dev_activations.pt'
                    if os.path.exists(scratch_act_path):
                        # move_fn deletes the scratch copy after moving it, so
                        # the previous checkpoint is still in flight.
                        print('not saving layer since disk is too slow')
                    else:
                        _save_dev_activations(dev_emb, transformer_layer_index,
                                              scratch_act_path)
                        move_q.put((scratch_act_path, final_act_path))
                else:
                    _save_dev_activations(dev_emb, transformer_layer_index,
                                          final_act_path)

            print(f"done processing layer {transformer_layer_index}")
    finally:
        # Always release the mover so it finishes pending copies and exits;
        # a live non-daemon child would otherwise block interpreter exit.
        if move_q is not None:
            move_q.put(None)
            move_p.join()


if __name__ == "__main__":
    import sys

    # force=True: re-executing this code in a process whose start method is
    # already set (e.g. re-running the notebook cell) would otherwise raise
    # RuntimeError.
    mp.set_start_method('spawn', force=True)
    torch.set_grad_enabled(False)
    # A Jupyter kernel's sys.argv carries the kernel's own flags
    # (e.g. "-f <connection-file>") that this parser would reject, so use the
    # defaults there. NOTE(review): spawn-based workers also need
    # forward_layer/accumulate/move_fn to be importable by the child process,
    # so the full pipeline is meant to run as a script, not inside a kernel.
    args = parser.parse_args([] if 'ipykernel' in sys.modules else None)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    numpy.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)
    # accumulate() and the activation checkpoints also write into scratch_path.
    if args.scratch_path is not None:
        os.makedirs(args.scratch_path, exist_ok=True)
    main(args)


ModuleNotFoundError: No module named 'quiptools_cuda'

In [11]:
import torch
# Reordering the columns of c and the rows of a by the same permutation b
# leaves the product unchanged: sum_k c[i, b[k]] * a[b[k], j] == (c @ a)[i, j].
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
c = torch.tensor([[18, 29, 3],
                  [4, 5, 69],
                  [7, 83, 96]])

b = torch.tensor([2, 1, 0])  # a permutation of 0..2 (here, a reversal)

tensor1 = torch.index_select(a, 0, b)  # rows of a reordered by b
tensor2 = torch.index_select(c, 1, b)  # columns of c reordered by b

print(c @ a)
print(tensor2 @ tensor1)
# Check the identity explicitly rather than by comparing printed matrices.
assert torch.equal(c @ a, tensor2 @ tensor1)

tensor([[ 155,  205,  255],
        [ 507,  585,  663],
        [1011, 1197, 1383]])
tensor([[ 155,  205,  255],
        [ 507,  585,  663],
        [1011, 1197, 1383]])


In [2]:
import torch

# Build a small 1-D tensor to sort.
tensor = torch.tensor([10, 5, 8, 2])

# Sort ascending. The result also records, for each sorted position, the
# index that value had in the original tensor.
sort_result = torch.sort(tensor)
sorted_tensor = sort_result.values
indices = sort_result.indices

print("排序后的张量:", sorted_tensor)
print("排序前的索引:", indices)
print(type(indices))

排序后的张量: tensor([ 2,  5,  8, 10])
排序前的索引: tensor([3, 1, 2, 0])
<class 'torch.Tensor'>
