In [1]:
%load_ext autoreload
%autoreload 2


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import nip
import sys
from itertools import islice
sys.path.append("/home/msst/repo/Quantization")
from qlib import QATDataset

In [2]:
path2model = "/media/msst/ssd_storage/ml/llm/pretrained_models/Llama2-7b-hf"

device_map = {'model.embed_tokens': 0,
                'model.layers.0': 0,
                'model.layers.1': 0,
                'model.layers.2': 0,
                'model.layers.3': 0,
                'model.layers.4': 0,
                'model.layers.5': 0,
                'model.layers.6': 0,
                'model.layers.7': 0,
                'model.layers.8': 0,
                'model.layers.9': 0,
                'model.layers.10': 0,
                'model.layers.11': 0,
                'model.layers.12': 0,
                'model.layers.13': 0,
                'model.layers.14': 0,
                'model.layers.15': 0,
                'model.layers.16': 0,
                'model.layers.17': 0,
                'model.layers.18': 0,
                'model.layers.19': 0,
                'model.layers.20': 0,
                'model.layers.21': 0,
                'model.layers.22': 0,
                'model.layers.23': 0,
                'model.layers.24': 0,
                'model.layers.25': 'cpu',
                'model.layers.26': 'cpu',
                'model.layers.27': 'cpu',
                'model.layers.28': 'cpu',
                'model.layers.29': 'cpu',
                'model.layers.30': 'cpu',
                'model.layers.31': 'cpu',
                'model.norm': 'cpu',
                'model.rotary_emb': 'cpu',
                'lm_head': 'cpu'}

model = AutoModelForCausalLM.from_pretrained(
    path2model,
    torch_dtype="auto",
    device_map=device_map, #"auto",
    offload_folder="./offload",
    offload_state_dict=True,
    low_cpu_mem_usage=True
)


model.eval()
tokenizer = AutoTokenizer.from_pretrained(path2model)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


In [3]:
class StatLinear(torch.nn.Module):
    def __init__(self, linear_layer):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        
        self.H = torch.zeros(self.in_features, self.in_features, dtype=torch.float32)
        self.n_samples = 0

    def forward(self, x, **kwargs):
        n_new_samples = x.shape[0]
        self.H *= self.n_samples / (self.n_samples + n_new_samples)
        self.n_samples += n_new_samples

        x_flat = x.detach().reshape(-1, self.in_features).float()
        x_scaled = ((1 / self.n_samples)**0.5 * x_flat).to(self.H.dtype)
        self.H += (x_scaled.T @ x_scaled).cpu()
        return torch.nn.functional.linear(x, self.weight, self.bias)

In [4]:
# def replace_linear_layers(model):
#     for module_name, module in model.named_children():
#         if isinstance(module, torch.nn.Linear) and module_name != "lm_head":
#             setattr(model, module_name, StatLinear(module))
#         else:
#             replace_linear_layers(module)
#     return model
# replace_linear_layers(model)


def replace_linear_layers(model, module_paths, exclude_layers=("lm_head",), verbose=True):
    replaced_count = 0
    skipped_count = 0
    error_count = 0
    
    for i, module_path in enumerate(module_paths):
        module_name = module_path.split('.')[-1]
        
        # Пропускаем исключенные слои
        if module_name in exclude_layers:
            if verbose:
                print(f"[{i+1}/{len(module_paths)}] Пропускаю исключенный слой: {module_path}")
            skipped_count += 1
            continue
            
        try:
            path_parts = module_path.split('.')
            parent_module = model
            
            # Рекурсивно идем по пути
            for part in path_parts[:-1]:
                parent_module = getattr(parent_module, part)
            
            module_name = path_parts[-1]
            current_module = getattr(parent_module, module_name)
            
            if isinstance(current_module, torch.nn.Linear):
                if verbose:
                    print(f"[{i+1}/{len(module_paths)}] Заменяю: {module_path}")
                setattr(parent_module, module_name, StatLinear(current_module))
                replaced_count += 1
            else:
                if verbose:
                    print(f"[{i+1}/{len(module_paths)}] Пропускаю: {module_path} (не Linear)")
                skipped_count += 1
                
        except AttributeError:
            if verbose:
                print(f"[{i+1}/{len(module_paths)}] Ошибка: модуль не найден: {module_path}")
            error_count += 1
        except Exception as e:
            if verbose:
                print(f"[{i+1}/{len(module_paths)}] Ошибка: {module_path} - {e}")
            error_count += 1
    
    if verbose:
        print(f"\nИтог: заменено {replaced_count}, пропущено {skipped_count}, ошибок {error_count}")
    
    return model


module_paths = []
#for block_id in range(0, 16):
for block_id in range(16, 32):
    for pattern in [
            "model.layers.{}.self_attn.q_proj",
            "model.layers.{}.self_attn.k_proj", 
            "model.layers.{}.self_attn.v_proj",
            "model.layers.{}.self_attn.o_proj",
            "model.layers.{}.mlp.gate_proj",
            "model.layers.{}.mlp.up_proj",
            "model.layers.{}.mlp.down_proj",
        ]:
        module_paths.append(pattern.format(block_id))

replace_linear_layers(model, module_paths, exclude_layers=("lm_head",), verbose=True)

print("done!")

[1/112] Заменяю: model.layers.16.self_attn.q_proj
[2/112] Заменяю: model.layers.16.self_attn.k_proj
[3/112] Заменяю: model.layers.16.self_attn.v_proj
[4/112] Заменяю: model.layers.16.self_attn.o_proj
[5/112] Заменяю: model.layers.16.mlp.gate_proj
[6/112] Заменяю: model.layers.16.mlp.up_proj
[7/112] Заменяю: model.layers.16.mlp.down_proj
[8/112] Заменяю: model.layers.17.self_attn.q_proj
[9/112] Заменяю: model.layers.17.self_attn.k_proj
[10/112] Заменяю: model.layers.17.self_attn.v_proj
[11/112] Заменяю: model.layers.17.self_attn.o_proj
[12/112] Заменяю: model.layers.17.mlp.gate_proj
[13/112] Заменяю: model.layers.17.mlp.up_proj
[14/112] Заменяю: model.layers.17.mlp.down_proj
[15/112] Заменяю: model.layers.18.self_attn.q_proj
[16/112] Заменяю: model.layers.18.self_attn.k_proj
[17/112] Заменяю: model.layers.18.self_attn.v_proj
[18/112] Заменяю: model.layers.18.self_attn.o_proj
[19/112] Заменяю: model.layers.18.mlp.gate_proj
[20/112] Заменяю: model.layers.18.mlp.up_proj
[21/112] Заменяю: m

In [5]:
dataset = QATDataset(
    config=nip.load('/home/msst/repo/Quantization/configs/data/redpajama_train_seqlen4096.yaml'),
    tokenizer=tokenizer,
    return_dict=True
)
dataset.batch_size=1
dataloader = dataset.get_dataloader()

Resolving data files:   0%|          | 0/200 [00:00<?, ?it/s]

In [6]:
from itertools import islice

max_batches = 64

for batch in tqdm(islice(dataloader, max_batches), total=max_batches):
    with torch.no_grad():
        output = model(**batch)

100%|██████████| 64/64 [14:16<00:00, 13.39s/it]


In [7]:
import os

path_to_save = "/media/msst/ssd_storage/ml/llm/XTX/Llama2-7B"
os.makedirs(path_to_save, exist_ok=True)

for module_name, module in tqdm(model.named_modules()):
    if hasattr(module, "H"):
        torch.save(module.H, f"{path_to_save}/{module_name}")

423it [00:25, 16.71it/s]  


In [8]:
raise

RuntimeError: No active exception to reraise

In [None]:
import torch

H_dict = torch.load("/media/msst/ssd_storage/ml/weights/XTX_Llama2-7B")


  H_dict = torch.load("/media/msst/ssd_storage/ml/weights/XTX_Llama2-7B")


In [None]:
H = H_dict["model.layers.1.mlp.down_proj"].clone()

damp = 0.01 * torch.mean(torch.diag(H))
diag = torch.arange(H.shape[0])
H[diag, diag] += damp

L = torch.linalg.cholesky(H.to(torch.float64))  

_LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 8732 is not positive-definite).

In [None]:
torch.mean(torch.diag(H))

tensor(13.1304)

In [None]:
damp = 0.01 * torch.mean(torch.diag(H))
diag = torch.arange(H.shape[0])
H[diag, diag] += damp

In [None]:
x = torch.randn(5, 4096)
H = x.T @ x

damp = 0.01 * torch.mean(torch.diag(H))
diag = torch.arange(H.shape[0])
H[diag, diag] += damp

In [None]:
H = torch.linalg.cholesky(H)            # L
H = torch.cholesky_inverse(H)           # H^{-1}
H = torch.linalg.cholesky(H, upper=True)
Hinv = H

In [None]:
for v in m:
    print(v)

tensor(3.3594, dtype=torch.bfloat16)
tensor(3.7344, dtype=torch.bfloat16)
tensor(5.5938, dtype=torch.bfloat16)
tensor(0.3008, dtype=torch.bfloat16)
tensor(11.3750, dtype=torch.bfloat16)
tensor(1.2109, dtype=torch.bfloat16)
tensor(20.5000, dtype=torch.bfloat16)
tensor(2.8125, dtype=torch.bfloat16)
tensor(3.3125, dtype=torch.bfloat16)
tensor(1.7266, dtype=torch.bfloat16)
tensor(0.4883, dtype=torch.bfloat16)
tensor(3.9062, dtype=torch.bfloat16)
tensor(3.6250, dtype=torch.bfloat16)
tensor(0.7266, dtype=torch.bfloat16)
tensor(2.3125, dtype=torch.bfloat16)
tensor(1.0625, dtype=torch.bfloat16)
tensor(0.4121, dtype=torch.bfloat16)
tensor(1.5078, dtype=torch.bfloat16)
tensor(4.4375, dtype=torch.bfloat16)
tensor(1.5078, dtype=torch.bfloat16)
tensor(13.3750, dtype=torch.bfloat16)
tensor(4.5312, dtype=torch.bfloat16)
tensor(0.6211, dtype=torch.bfloat16)
tensor(0.1787, dtype=torch.bfloat16)
tensor(0.4551, dtype=torch.bfloat16)
tensor(5.4375, dtype=torch.bfloat16)
tensor(5.4688, dtype=torch.bfloat16

In [None]:
max_problem = 0
max_problem_layer = ""

for module_name, H in H_dict.items():
    alpha = H.abs().max() / H.abs().mean()
    if alpha > max_problem:
        print("new worst:", module_name, alpha)
        max_problem = alpha
        max_problem_layer = module_name

new worst: model.layers.0.self_attn.q_proj tensor(123392., dtype=torch.bfloat16)
new worst: model.layers.1.self_attn.q_proj tensor(305152., dtype=torch.bfloat16)
new worst: model.layers.1.mlp.down_proj tensor(84410368., dtype=torch.bfloat16)
