In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import random
import numpy as np


import math
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed handed to each generator.
    """
    # Library-level generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU first, then every CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags are process-wide settings, not per-call options.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [2]:
import argparse
import json
from pathlib import Path
import random
import os
import schedulefree

import numpy as np
import torch
import wandb

import config
from data.utils import DataReader, get_dataset
import distributed
from models.utils import get_model
from optim.base import train
from optim.utils import cos_inf_schedule, wsd_schedule, get_batch

import sys
# When the first argv entry names the ipykernel launcher (i.e. this code runs
# inside a notebook kernel), keep only that entry so that the argparse call in
# get_args() does not see the remaining command-line arguments.
if 'ipykernel_launcher' in sys.argv[0]:
   sys.argv = sys.argv[:1]

def get_args():
    """Build the experiment arguments with notebook-specific presets.

    A minimal parser first extracts ``--config_format`` from the command line.
    The resulting namespace is then pre-populated with a fixed model size and
    dataset directory and handed, together with the unparsed arguments, to
    ``config.parse_args_with_format``.

    Returns:
        Whatever ``config.parse_args_with_format`` returns (presumably the
        completed argument namespace -- confirm against the ``config`` module).
    """
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        "--config_format", default="base", choices=config.registered_formats()
    )
    known, remaining = parser.parse_known_args()

    # Values written onto the namespace before the format-specific parse.
    presets = {
        "n_layer": 8,
        "n_head": 6,
        "n_embd": 384,
        "datasets_dir": "/chenyupeng/data_files/llm_datasets",
    }
    for attr_name, attr_value in presets.items():
        setattr(known, attr_name, attr_value)

    return config.parse_args_with_format(
        format=known.config_format,
        base_parser=parser,
        args=remaining,
        namespace=known,
    )

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Resolve the run configuration once; the cells below all read from `args`.
args = get_args()

import copy
def get_data_readers(args, verbose=True):
    """Create the training and validation ``DataReader`` objects.

    Both readers share batch size, sequence length, seed and RAM settings taken
    from ``args`` and sample without replacement. Only the training reader is
    constructed with ``auto_shard=True``; the validation reader uses
    ``auto_shard=False`` (identical per rank, per the original note).

    Args:
        args: Configuration object providing ``batch_size``,
            ``sequence_length``, ``data_seed`` and ``data_in_ram``; it is also
            passed to ``get_dataset``.
        verbose: When true, print the token count of each split.

    Returns:
        dict: ``{"train": train_reader, "val": val_reader}``.
    """
    sources = get_dataset(args)

    # Settings common to both splits.
    shared_kwargs = dict(
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        keep_in_ram=args.data_in_ram,
    )

    readers = {
        "train": DataReader(data_src=sources["train"], auto_shard=True, **shared_kwargs),
        # NOTE Identical Per Rank
        "val": DataReader(data_src=sources["val"], auto_shard=False, **shared_kwargs),
    }

    if verbose:
        print(f"Num training tokens: {readers['train'].num_tokens}")
        print(f"Num validation tokens: {readers['val'].num_tokens}")

    return readers
# Build the train/val readers (this prints the token counts of both splits).
data = get_data_readers(args)


# Build the model described by the same configuration.
model = get_model(args)

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [4]:
# Draw a fixed set of 10 validation batches, placed on the GPU, to serve as the
# evaluation set for every gradient computation below.
# NOTE(review): this constructs a second pair of readers instead of reusing
# `data["val"]` from the previous cell, which is why the token counts are
# printed again in this cell's output.
val_batches = []
data_reader = get_data_readers(args)["val"]
for _ in range(10):
    x, y = get_batch(data_reader, device="cuda")
    val_batches.append((x, y))
eval_batches = val_batches[:10]  # use the first 10 batches for evaluation


def compute_grad(model, eval_batches):
    """Accumulate the gradient of the mean loss over ``eval_batches`` into ``.grad``.

    The model is switched to train mode, all existing gradients are discarded,
    and one forward/backward pass is run per batch with the loss scaled by
    ``1 / len(eval_batches)``, so that after the call every parameter's
    ``.grad`` holds the gradient averaged over the batches.

    Args:
        model: Module called as ``model(x, targets=y, get_logits=True)`` and
            returning a mapping with a scalar ``"loss"`` entry.
        eval_batches: Sized sequence of ``(x, y)`` pairs (``len()`` is used).

    Returns:
        float: Mean of the per-batch losses, or NaN when ``eval_batches`` is
        empty (in which case gradients are cleared and nothing else is done).
    """
    model.train()

    # Drop stale gradients so that .grad reflects this evaluation only.
    for p in model.parameters():
        p.grad = None

    num_batches = len(eval_batches)
    if num_batches == 0:
        return float("nan")

    batch_losses = []
    for x, y in eval_batches:
        outputs = model(x, targets=y, get_logits=True)
        batch_loss = outputs["loss"]

        # Scaling each loss by 1/N turns gradient accumulation into averaging.
        (batch_loss / num_batches).backward()

        # Keep a detached copy so the loss can be reported without holding the graph.
        batch_losses.append(batch_loss.detach())

    # Single host/device synchronisation at the end rather than one per batch.
    return torch.stack(batch_losses).mean().item()

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [5]:
# Display the module tree of the model built above.
model

Llama(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 384)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-7): 8 x LlamaBlock(
        (ln_1): RMSNorm()
        (attn): LlamaAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): RMSNorm()
        (mlp): LlamaMLP(
          (w1): Linear(in_features=384, out_features=1024, bias=False)
          (w2): Linear(in_features=384, out_features=1024, bias=False)
          (c_proj): Linear(in_features=1024, out_features=384, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=384, out_features=50304, bias=False)
)

In [6]:
import torch.nn.functional as F
def get_hessian(model, eval_batches, a, r, n_iter=1000, log_interval=1):
    """Estimate the leading Hessian eigen-direction for one weight matrix.

    Only ``model.transformer.h[-1].mlp.c_proj.weight`` is perturbed. Starting
    from a random direction ``phi`` (drawn after ``set_seed(42)``), each step
    moves the weight by ``a * phi / ||phi||``, recomputes the gradient with
    ``compute_grad`` and applies the damped update

        phi <- (1 - r) * phi + (r / a) * (grad_perturbed - grad_original)

    where the gradient difference divided by ``a`` is a finite-difference
    Hessian-vector product restricted to this weight.

    The weight is restored to its original value after every step, including
    when the gradient computation raises. Side effects: ``set_seed(42)``
    reseeds the global RNGs, and the model's ``.grad`` buffers are left holding
    the gradient from the last ``compute_grad`` call.

    Args:
        model: Model exposing ``transformer.h[-1].mlp.c_proj.weight``.
        eval_batches: Batches forwarded unchanged to ``compute_grad``.
        a: Finite-difference step size.
        r: Mixing weight of the new Hessian-vector product in the update.
        n_iter: Number of update steps (default 1000, the previous fixed value).
        log_interval: Print progress every this many steps; ``0`` disables
            printing (default 1, i.e. every step as before).

    Returns:
        tuple: ``(phi, simi)`` where ``phi`` is the final direction scaled to
        unit norm and ``simi`` is the cosine similarity between the original
        gradient and that direction.
    """
    weight = model.transformer.h[-1].mlp.c_proj.weight

    # Reference gradient and weight at the unperturbed point.
    compute_grad(model, eval_batches)
    grad_original = weight.grad.detach().clone()
    original_weight = weight.detach().clone()

    set_seed(42)
    random_phi = torch.randn_like(weight)

    def _similarity_to_grad(direction):
        # Cosine similarity between the original gradient and the unit-normalised direction.
        unit = direction / torch.norm(direction)
        return F.cosine_similarity(grad_original.reshape(-1), unit.reshape(-1), dim=0)

    # Defined up front so the return value exists even when n_iter == 0.
    simi = _similarity_to_grad(random_phi)

    for i in range(n_iter):
        try:
            # Step along the current unit direction without recording autograd history.
            with torch.no_grad():
                weight.add_((random_phi / torch.norm(random_phi)) * a)
            compute_grad(model, eval_batches)
            grad_after_pertu = weight.grad.detach().clone()
        finally:
            # Always undo the perturbation, even if the gradient pass failed.
            with torch.no_grad():
                weight.copy_(original_weight)

        random_phi = (1 - r) * random_phi + (r / a) * (grad_after_pertu - grad_original)
        simi = _similarity_to_grad(random_phi)

        if log_interval > 0 and i % log_interval == 0:
            weight_norm_of_random = random_phi.norm()
            print(f"{i}-th iteration, grad norm of phi: {weight_norm_of_random}, simi : {simi}")

    random_phi = random_phi / random_phi.norm()
    return random_phi, simi

In [None]:
"""
This module extends the GPTBase model with the ability to compute the leading
(largest-eigenvalue) eigenvector of the Hessian. The algorithm is based on
"Automatic Learning Rate Maximization by On-Line Estimation of the Hessian's Eigenvectors".

Usage example:
    model = GPTWithEigenvector(config)
    eigenvector, eigenvalue = model.get_max_eigenvector(val_batches)

Author: Weihao Huang
Date: 2025/5/8
"""

import contextlib
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters

class GPTWithEigenvector(GPTBase):
    """GPTBase extended with an estimator for the Hessian's leading eigenvector.

    The Hessian is never formed explicitly. Hessian-vector products are
    approximated by finite differences of the gradient, and a damped power
    iteration on top of them estimates the leading eigenvector of the Hessian
    of the loss on a fixed set of batches.
    """

    def get_gradient_vector(self, val_batches):
        """Return the flattened gradient of the mean loss over ``val_batches``.

        Runs one forward/backward pass per batch with the loss scaled by
        ``1 / len(val_batches)`` and concatenates the gradients of all
        parameters into a single 1-D tensor. Parameters that received no
        gradient contribute zeros.

        Side effects: the model is switched to train mode, and its gradients
        are cleared both before and after the computation.

        Args:
            val_batches (Sequence[Tuple[Tensor, Tensor]]): ``(input, target)``
                pairs. Must support ``len()``.

        Returns:
            torch.Tensor: 1-D tensor holding the averaged gradient of every
            parameter, detached from the autograd graph.
        """

        # Start from clean gradients.
        self.train()
        self.zero_grad()

        # Forward/backward over every batch; gradients accumulate in .grad.
        for x, y in val_batches:
            outputs = self(x, targets=y, get_logits=True)
            loss = outputs["loss"]
            scaled_loss = loss / len(val_batches)  # average instead of sum across batches
            scaled_loss.backward()

        # Flatten the gradients; a parameter whose .grad is None is represented by zeros.
        grad_vector = parameters_to_vector([param.grad if param.grad is not None else torch.zeros_like(param)
                                            for param in self.parameters()
                                            ]).detach().clone()  # copy, so later zero_grad() cannot alter it

        # Leave no accumulated gradients behind.
        self.zero_grad()

        return grad_vector

    @contextlib.contextmanager
    def perturbed_parameters(self, param_perturbed):
        """Temporarily replace the model parameters with ``param_perturbed``.

        Inside the ``with`` block the parameters hold the values from
        ``param_perturbed``; on exit (normal or via an exception) the values
        that were present on entry are written back.

        Args:
            param_perturbed (torch.Tensor): 1-D tensor containing the values
                for all parameters, in ``self.parameters()`` order.
        """

        param_original = parameters_to_vector(self.parameters()).detach().clone()  # snapshot for restoration
        try:
            with torch.no_grad():  # parameter writes must not be tracked by autograd
                vector_to_parameters(param_perturbed, self.parameters())
            yield  # caller's with-block runs here on the perturbed parameters
        finally:
            with torch.no_grad():
                vector_to_parameters(param_original, self.parameters())

    def get_max_eigenvector(self, val_batches, alpha=0.5, gamma=0.5, max_iter=1000, tol=1e-3, log_interval=10):
        """Estimate the leading Hessian eigenvector by damped power iteration.

        Each step perturbs the parameters by ``alpha`` along the current unit
        direction, recomputes the gradient and applies

            Psi <- (1 - gamma) * Psi + (gamma / alpha) * (grad_perturbed - grad)

        where the scaled gradient difference is a finite-difference
        Hessian-vector product. Iteration stops after ``max_iter`` steps or as
        soon as the relative change of ``Psi`` falls below ``tol``.

        Args:
            val_batches (Sequence[Tuple[Tensor, Tensor]]): Batches the loss is
                evaluated on. They are traversed once per iteration, so a
                one-shot iterator is not suitable; ``len()`` is also required.
            alpha (float): Finite-difference step size along the direction.
            gamma (float): Mixing weight of the new Hessian-vector product
                relative to the previous ``Psi``.
            max_iter (int): Maximum number of iterations.
            tol (float): Convergence threshold on ``||new_Psi - Psi|| / ||Psi||``.
            log_interval (int): Print diagnostics every this many iterations;
                ``0`` disables the periodic diagnostics.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - the estimated eigenvector scaled to unit norm (1-D tensor
                  over all parameters);
                - the norm of the unnormalised iterate as a 0-dim tensor,
                  which serves as the eigenvalue estimate.
        """

        # Flatten the current parameters and the reference gradient.
        param_vector = parameters_to_vector(self.parameters()).detach().clone()
        grad_vector = self.get_gradient_vector(val_batches)

        # Random starting direction.
        Psi = torch.randn_like(param_vector).detach()

        for i in range(max_iter):
            # Only the direction of Psi is used for the perturbation.
            Psi_normed = Psi / Psi.norm()

            # Gradient at the perturbed point; parameters are restored on leaving the block.
            param_perturbed = param_vector + alpha * Psi_normed
            with self.perturbed_parameters(param_perturbed):
                grad_vector_perturbed = self.get_gradient_vector(val_batches)

            # Damped update using the finite-difference Hessian-vector product.
            new_Psi = (1 - gamma) * Psi + (gamma / alpha) * (grad_vector_perturbed - grad_vector)

            # Convergence check on the relative change of the iterate.
            if (new_Psi - Psi).norm() / Psi.norm() < tol:
                print(f"Converged at step {i}")
                Psi = new_Psi
                break
            else:
                Psi = new_Psi

            # Diagnostics: norm of the iterate and its inner product with the gradient.
            # TODO: unsure whether this inner product needs a further correction.
            if log_interval > 0 and (i % log_interval == 0 or i == max_iter - 1):
                inner_product = torch.dot(grad_vector.view(-1), Psi.view(-1))
                print(f"[Iter {i}] ||Psi|| = {Psi.norm():.4f}, <grad, Psi> = {inner_product:.4f}")

        # The iterate's length carries the eigenvalue estimate; return the direction
        # separately with unit norm, as the documented contract states.
        eigenvalue = Psi.norm()
        return Psi / eigenvalue, eigenvalue


In [7]:
# Move the model to the GPU (the evaluation batches were created on "cuda").
model = model.cuda()
# Run the estimator with a = 1e-2 (perturbation step) and r = 1e-2 (update mixing weight).
phi,simi = get_hessian(model,eval_batches,1e-2,1e-2)

0-th iteration, grad norm of phi: 620.480712890625, simi : -0.0002317948965355754
1-th iteration, grad norm of phi: 614.27587890625, simi : -0.00023179472191259265
2-th iteration, grad norm of phi: 608.1331787109375, simi : -0.00023179460549727082
3-th iteration, grad norm of phi: 602.0518798828125, simi : -0.000231794489081949
4-th iteration, grad norm of phi: 596.0313110351562, simi : -0.00023179413983598351
5-th iteration, grad norm of phi: 590.071044921875, simi : -0.0002317938779015094
6-th iteration, grad norm of phi: 584.1702880859375, simi : -0.00023179373238235712
7-th iteration, grad norm of phi: 578.32861328125, simi : -0.00023179370327852666
8-th iteration, grad norm of phi: 572.5453491210938, simi : -0.00023179370327852666
9-th iteration, grad norm of phi: 566.8198852539062, simi : -0.00023179329582490027
10-th iteration, grad norm of phi: 561.1516723632812, simi : -0.0002317932085134089
11-th iteration, grad norm of phi: 555.5402221679688, simi : -0.00023179285926744342
1

In [10]:
# Norm of the lm_head weight tensor.
model.lm_head.weight.norm()

tensor(143.5406, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [26]:
# Show the layer whose weight get_hessian perturbs.
# NOTE(review): the recorded output of this cell (Linear 2816 -> 1024) does not
# match the model printed earlier in this notebook (c_proj is 1024 -> 384), and
# the execution counts of the final cells are out of order (10, 26, 9). The
# output apparently comes from a different model/config; re-run the notebook
# from a fresh kernel before relying on it.
model.transformer.h[-1].mlp.c_proj

Linear(in_features=2816, out_features=1024, bias=False)

In [9]:
# Cosine similarity returned by get_hessian: original gradient vs. final direction.
simi

tensor(0.0017, device='cuda:0')