In [51]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import random
import numpy as np


import math
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [52]:
import argparse
import json
from pathlib import Path
import random
import os
import schedulefree

import numpy as np
import torch
import wandb

import config
from data.utils import DataReader, get_dataset
import distributed
from models.utils import get_model
from optim.base import train
from optim.utils import cos_inf_schedule, wsd_schedule, get_batch

import sys
if 'ipykernel_launcher' in sys.argv[0]:
   sys.argv = sys.argv[:1]

def get_args():
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        "--config_format", default="base", choices=config.registered_formats()
    )
    args, rem_args = parser.parse_known_args()
    args.n_layer=3
    args.n_head=6
    args.n_embd=60
    args.multiple_of=1
    args.batch_size=1
    args.dtype = "float32"
    args.model = "LlamaWithEigenvector"
    
    args.datasets_dir = "/chenyupeng/data_files/llm_datasets"
    return config.parse_args_with_format(
        format=args.config_format, base_parser=parser, args=rem_args, namespace=args
    )

In [53]:
args = get_args()

import copy
def get_data_readers(args, verbose=True):
    data_srcs = get_dataset(args)
    train_reader = DataReader(
        data_src=data_srcs["train"],
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        auto_shard=True,
        keep_in_ram=args.data_in_ram,
    )
    val_reader = DataReader(
        data_src=data_srcs["val"],
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        auto_shard=False,  # NOTE Identical Per Rank
        keep_in_ram=args.data_in_ram,
    )

    if verbose:
        print(f"Num training tokens: {train_reader.num_tokens}")
        print(f"Num validation tokens: {val_reader.num_tokens}")

    return {
        "train": train_reader,
        "val": val_reader,
    }
data = get_data_readers(args)


model = get_model(args)

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [54]:
val_batches = []
data_reader = get_data_readers(args)["val"]
for _ in range(10):
    x, y = get_batch(data_reader, device="cuda")
    val_batches.append((x, y))

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [None]:
"""
本模块扩展了 Llama 模型，添加了计算 Hessian 最大特征向量的能力，
算法基于《Automatic Learning Rate Maximization by On-Line Estimation of the Hessian's Eigenvectors》。

用法示例：
    model = LlamaWithEigenvector(config)
    eigenvector, eigenvalue = model.get_max_eigenvector(val_batches)

Author: Weihao Huang
Date: 2025/5/8
"""

import contextlib
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters

class LlamaWithEigenvector(Llama):
    """
    GPTBase 的扩展类，添加了对 Hessian 最大特征向量的估计功能。

    该类用于通过有限差分近似 Hessian，并采用 power iteration 方法
    估算当前模型在验证集上的 Hessian 最大特征向量。
    """

    # 为 GPTBase 类新增计算梯度向量的函数。
    def get_gradient_vector(self, val_batches):
        """
        计算模型在给定验证集 batch 上的平均梯度向量。

        该方法会对 val_batches 中的所有样本执行前向和反向传播，然后将所有参数的梯度拼接成一个一维向量返回。

        Args:
            val_batches (Iterable[Tuple[Tensor, Tensor]]): 一个 batch 的迭代器，每个元素为 (input, target)。

        Returns:
            torch.Tensor: 拼接后的全模型平均梯度向量，形状为一维张量。
        """

        # 清除旧梯度
        self.train()
        self.zero_grad()

        # 前向/反向传播
        for x, y in val_batches:
            outputs = self(x, targets=y, get_logits=True)
            loss = outputs["loss"]
            scaled_loss = loss / len(val_batches)  # 防止梯度累积不均
            scaled_loss.backward() # 必须调用这行

        # 获取梯度向量
        grad_vector = parameters_to_vector([param.grad if param.grad is not None else torch.zeros_like(param) 
                                            for param in self.parameters()
                                            ]).detach().clone() # 防止 autograd 图扩展
        
        # 清空梯度避免累积
        self.zero_grad()
        
        return grad_vector

    @contextlib.contextmanager
    def perturbed_parameters(self, param_perturbed):
        """
        上下文管理器：临时将模型参数替换为扰动后的参数，并在退出后恢复。

        用于 Hessian-vector product 的有限差分计算过程，
        保证在 with 代码块内模型参数被替换为 `param_perturbed`，
        并在块结束后自动恢复为原始参数。

        Args:
            param_perturbed (torch.Tensor): 一维张量，表示扰动后的模型参数。
        """
        
        param_original = parameters_to_vector(self.parameters()).detach().clone() # 保存原模型参数
        try:
            with torch.no_grad(): # 避免 autograd 计算图被污染
                vector_to_parameters(param_perturbed, self.parameters()) # 把扰动参数写入模型
            yield  # 在 with 块中执行操作
        finally:
            with torch.no_grad():
                vector_to_parameters(param_original, self.parameters())
    
    # 为 GPTBase 类新增计算 Hessian Matrix 最大特征向量的函数。
    def get_max_eigenvector(self, val_batches, alpha=0.5, gamma=0.5, max_iter=1000, tol=1e-3, log_interval=10):
        """
        使用 power iteration 估算模型在验证集上的 Hessian 最大特征向量。

        该方法通过对梯度函数进行有限差分近似，模拟 Hessian-vector product，
        并使用迭代更新方式逼近 Hessian 的最大特征向量。

        Args:
            val_batches (Iterable[Tuple[Tensor, Tensor]]): 验证数据集 batch 的迭代器。
            alpha (float): 用于扰动方向的步长系数（有限差分用）。
            gamma (float): 动量因子（控制新向量与旧向量的混合程度）。
            max_iter (int): 最大迭代步数。
            tol (float): 收敛阈值，基于相对向量变化。
            log_interval (int): 每多少步打印一次调试信息。为 0 表示不打印。

        Returns:
            Tuple[torch.Tensor, float]: 
                - Psi：估计得到的最大特征向量（单位方向）；
                - Psi.norm()：对应的特征值估计（模长，可近似特征值大小）。
        """
    
        # 将模型参数、梯度展成一维 tensor 向量
        param_vector = parameters_to_vector(self.parameters()).detach().clone()
        grad_vector = self.get_gradient_vector(val_batches)
        
        # 随机初始化 Psi
        Psi = torch.randn_like(param_vector).detach()

        # 开始迭代
        for i in range(max_iter):
            # 归一化 Psi
            Psi_normed = Psi / Psi.norm()

            # 扰动模型参数并写入模型
            param_perturbed = param_vector + alpha * Psi_normed
            with self.perturbed_parameters(param_perturbed):
                grad_vector_perturbed = self.get_gradient_vector(val_batches) # 计算扰动模型的梯度
            # 出 with 块，模型自动恢复原始参数

            # 通过对 Hessian 的有限差分近似来更新 Psi
            new_Psi = (1 - gamma) * Psi + (gamma / alpha) * (grad_vector_perturbed - grad_vector)

            # convergence check
            if (new_Psi - Psi).norm() / Psi.norm() < tol:
                print(f"Converged at step {i}")
                Psi = new_Psi
                break
            else:
                Psi = new_Psi

            # 打印调试信息（梯度向量与 Psi 的内积）
            # TODO: 不确定内积是否需要被进一步修正
            if log_interval > 0 and (i % log_interval == 0 or i == max_iter - 1):
                inner_product = torch.dot(grad_vector.view(-1), Psi.view(-1))
                print(f"[Iter {i}] ||Psi|| = {Psi.norm():.4f}, <grad, Psi> = {inner_product:.4f}")

        return Psi, Psi.norm()


In [56]:
model = model.cuda()
set_seed(42)
eigenvector, eigenvalue = model.get_max_eigenvector(val_batches, alpha=0.1, gamma=0.2, max_iter=2000,tol=1e-5)

# phi,simi = get_hessian(model,eval_batches,1e-2,1e-2)

[Iter 0] ||Psi|| = 1507.2229, <grad, Psi> = 0.3331
[Iter 10] ||Psi|| = 296.7337, <grad, Psi> = 0.0655
[Iter 20] ||Psi|| = 58.4193, <grad, Psi> = 0.0128
[Iter 30] ||Psi|| = 11.5013, <grad, Psi> = 0.0024
[Iter 40] ||Psi|| = 2.2648, <grad, Psi> = 0.0001
[Iter 50] ||Psi|| = 2.3498, <grad, Psi> = 0.2071
[Iter 60] ||Psi|| = 8.1219, <grad, Psi> = 1.2652
[Iter 70] ||Psi|| = 10.0048, <grad, Psi> = 1.7887
[Iter 80] ||Psi|| = 10.7715, <grad, Psi> = 2.0964
[Iter 90] ||Psi|| = 11.2295, <grad, Psi> = 2.3240
[Iter 100] ||Psi|| = 11.5685, <grad, Psi> = 2.5059
[Iter 110] ||Psi|| = 11.8291, <grad, Psi> = 2.6494
[Iter 120] ||Psi|| = 12.0218, <grad, Psi> = 2.7567
[Iter 130] ||Psi|| = 12.1563, <grad, Psi> = 2.8329
[Iter 140] ||Psi|| = 12.2431, <grad, Psi> = 2.8842
[Iter 150] ||Psi|| = 12.2946, <grad, Psi> = 2.9166
[Iter 160] ||Psi|| = 12.3208, <grad, Psi> = 2.9357
[Iter 170] ||Psi|| = 12.3297, <grad, Psi> = 2.9454
[Iter 180] ||Psi|| = 12.3281, <grad, Psi> = 2.9494
[Iter 190] ||Psi|| = 12.3193, <grad, Psi> 

In [25]:
import torch.nn.functional as F
def compute_grad(model,val_batches):
    model.train()
    total_loss = 0
    n_batches = 0
    # 清空梯度
    for p in model.parameters():
        p.grad = None
    
    # 梯度累积
    #for x, y in eval_batches:
    x = val_batches[0][0]
    y = val_batches[0][1]
    outputs = model(x, targets=y, get_logits=True)
    batch_loss = outputs["loss"]
    
    # 通过缩放损失实现梯度累积，相当于平均梯度
    batch_loss.backward()  # 梯度会累积

def get_hessian(model,val_batches,a,r):
    compute_grad(model,val_batches)
    grad_original = model.transformer.h[-1].mlp.c_proj.weight.grad.detach().clone()
    original_weight = model.transformer.h[-1].mlp.c_proj.weight.data.detach().clone()
    set_seed(42)
    random_phi = torch.randn_like(model.transformer.h[-1].mlp.c_proj.weight)
    for i in range(2000):
        #random_phi = random_phi/torch.norm(random_phi)
        model.transformer.h[-1].mlp.c_proj.weight.data.add_((random_phi/torch.norm(random_phi))*a)
        compute_grad(model,val_batches)
        grad_after_pertu = model.transformer.h[-1].mlp.c_proj.weight.grad.data.detach().clone()
        random_phi = (1-r)*random_phi + (r/a)*(grad_after_pertu-grad_original)
        model.transformer.h[-1].mlp.c_proj.weight.data.copy_(original_weight)
        weight_norm_of_random = random_phi.norm()
        simi = F.cosine_similarity(grad_original.reshape(-1), (random_phi/torch.norm(random_phi)).reshape(-1), dim=0)
        print(f"{i}-th iteration, grad norm of phi: {weight_norm_of_random}, simi : {simi}")

    random_phi = random_phi/random_phi.norm()
    #cosine_simi = F.cosine_similarity(grad_original.reshape(-1), random_phi.reshape(-1), dim=0)
    return random_phi,simi

model.eval()
phi,simi = get_hessian(model,val_batches,0.1,0.1)

0-th iteration, grad norm of phi: 87.21380615234375, simi : 0.004311125725507736
1-th iteration, grad norm of phi: 78.49242401123047, simi : 0.0043109506368637085
2-th iteration, grad norm of phi: 70.64318084716797, simi : 0.004310756456106901
3-th iteration, grad norm of phi: 63.578861236572266, simi : 0.004310538060963154
4-th iteration, grad norm of phi: 57.22097396850586, simi : 0.004310299642384052
5-th iteration, grad norm of phi: 51.49887466430664, simi : 0.004310031421482563
6-th iteration, grad norm of phi: 46.3489875793457, simi : 0.004309735260903835
7-th iteration, grad norm of phi: 41.71408462524414, simi : 0.004309403244405985
8-th iteration, grad norm of phi: 37.54267501831055, simi : 0.00430903909727931
9-th iteration, grad norm of phi: 33.78840637207031, simi : 0.004308630712330341
10-th iteration, grad norm of phi: 30.409563064575195, simi : 0.004308180417865515
11-th iteration, grad norm of phi: 27.368606567382812, simi : 0.004307677038013935
12-th iteration, grad no

In [19]:
model.zero_grad(set_to_none = True)
model = model.cuda()
x = val_batches[0][0]
y = val_batches[0][1]
model.eval()
import time
def hessian_calculation(g_tensor, params):
    g_tensor = g_tensor.cuda()
    total_params = g_tensor.size(0)
    hessian_list = []
    t_d = time.time()
    for d in range(total_params):
        unit_vector = torch.zeros_like(g_tensor)
        unit_vector[d] = 1
        l = torch.sum(g_tensor * unit_vector)
        grad_2 = torch.autograd.grad(l, params[0], create_graph=True)
        #l.backward(retain_graph= True)
        hessian_row = []
        #print('name',name, param.grad)
        hessian_row.append(grad_2[0].double().data.clone())
        
        model.zero_grad(set_to_none = True)
        hessian_row = [g.flatten() for g in hessian_row] 
        hessian_row = [g.cpu() for g in hessian_row]
        hessian_row = torch.cat(hessian_row)
        #print('hessian_row', hessian_row)   
        hessian_list.append(hessian_row)
        # if d % 1000 == 0:
        #     print(f'Computing hessian: current batch = {batch_idx}/{self.num_batches}, current row of a hessian: {d}/{total_params}, total time = {time.time()- t_d} ')
    hessian = torch.stack(hessian_list, dim = 1)
    #print('hessian', hessian)   
    return hessian
full_hessian = 0
outputs = model(x, targets=y, get_logits=True)
batch_loss = outputs["loss"]
#batch_loss.backward(create_graph= True)
#g_list = []
#count = 0
parameters = [p for n,p in model.named_parameters() if "mlp.c_proj" in n]
#if parameters[0].requires_grad:
#    count += parameters[0].numel()
#    #print('g shape', param.grad , param.grad.shape)
#    g_list.append(torch.flatten(parameters[0].grad.double()))
#    #print('name',name, g_list[-1].size())
#g_tensor = torch.cat(g_list, dim = 0)
grad_para = torch.autograd.grad(batch_loss, parameters, create_graph=True,retain_graph=True)
g_tensor = torch.flatten(grad_para[0].double())
#print('g_tensor',g_tensor)
model.zero_grad(set_to_none = True)
H = hessian_calculation(g_tensor,parameters)
full_hessian += H
full_hessian = torch.nan_to_num(full_hessian, nan = 0, posinf = 0, neginf = 0 )  # change nan, postive inf , negative inf, to 0
t_svd = time.time()
#print('doing EVD')
# _, eigenvalues, _ = torch.linalg.svd(full_hessian)  # ascending
#eigenvalues, _  = torch.eig(full_hessian)
full_hessian = full_hessian.numpy().astype(np.float64)
full_hessian = (full_hessian + full_hessian.T)/2 # make symetric, to 



#avoid numerical issue
#full_hessian = full_hessian.cuda()
#eigenvalues, _  = torch.linalg.eig(full_hessian)
# eigenvalues, _  = np.linalg.eigh(full_hessian)
# #_, eigenvalues, _ = np.linalg.svd(full_hessian) 
# eigenvalues = [eigen.item().real for eigen in eigenvalues]
# file_name = self.file_dir + 'eigenvalues.txt'
# with open(file_name, "w") as file:
#     for item in eigenvalues:
#         file.write(str(item)+"\n")

In [20]:
full_hessian = torch.tensor(full_hessian).cuda()

In [21]:
u,v,d = torch.linalg.svd(full_hessian)

In [22]:
v

tensor([1.8085e-02, 1.7699e-02, 1.7182e-02,  ..., 8.1522e-08, 5.7344e-08,
        3.0193e-08], device='cuda:0', dtype=torch.float64)