In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import random
import numpy as np


import math
def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Args:
        seed: Value passed unchanged to PyTorch (CPU generator and all CUDA
            devices), NumPy's global generator and Python's ``random`` module.
    """
    # Same four seeding calls as always, applied in the same order.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and switch off the autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [2]:
import argparse
import json
from pathlib import Path
import random
import os
import schedulefree

import numpy as np
import torch
import wandb

import config
from data.utils import DataReader, get_dataset
import distributed
from models.utils import get_model
from optim.base import train
from optim.utils import cos_inf_schedule, wsd_schedule, get_batch

import sys
# When the process was started through ipykernel_launcher, keep only argv[0]
# so that argparse calls which read sys.argv (see get_args) do not receive
# the kernel's own command-line arguments.
if 'ipykernel_launcher' in sys.argv[0]:
   sys.argv = sys.argv[:1]

def get_args():
    """Assemble the run configuration used by this notebook.

    A minimal parser reads only ``--config_format`` from ``sys.argv``. The
    resulting namespace is then given a few notebook-specific values (model
    size and dataset directory) and handed, together with the unparsed
    arguments, to ``config.parse_args_with_format``.

    Returns:
        The object returned by ``config.parse_args_with_format``.
    """
    base_parser = argparse.ArgumentParser(allow_abbrev=False)
    base_parser.add_argument(
        "--config_format", default="base", choices=config.registered_formats()
    )
    namespace, leftover = base_parser.parse_known_args()

    # Values written onto the namespace before the format-specific parser runs.
    # NOTE(review): presumably these take precedence over that parser's
    # defaults because the namespace is passed through — confirm in config.
    notebook_overrides = {
        "n_layer": 8,
        "n_head": 6,
        "n_embd": 384,
        "datasets_dir": "/chenyupeng/data_files/llm_datasets",
    }
    for field, value in notebook_overrides.items():
        setattr(namespace, field, value)

    return config.parse_args_with_format(
        format=namespace.config_format,
        base_parser=base_parser,
        args=leftover,
        namespace=namespace,
    )

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Parse the configuration once; the data readers and the model below are
# both built from this object.
args = get_args()

import copy
def get_data_readers(args, verbose=True):
    """Create the training and validation ``DataReader`` objects.

    Args:
        args: Configuration object; ``batch_size``, ``sequence_length``,
            ``data_seed`` and ``data_in_ram`` are read from it, and it is
            passed as-is to ``get_dataset``.
        verbose: When true, print the token count of each split.

    Returns:
        dict: ``{"train": <DataReader>, "val": <DataReader>}``.
    """
    sources = get_dataset(args)

    # Settings that are identical for both splits.
    shared = dict(
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        keep_in_ram=args.data_in_ram,
    )

    readers = {
        "train": DataReader(data_src=sources["train"], auto_shard=True, **shared),
        # NOTE Identical Per Rank
        "val": DataReader(data_src=sources["val"], auto_shard=False, **shared),
    }

    if verbose:
        print(f"Num training tokens: {readers['train'].num_tokens}")
        print(f"Num validation tokens: {readers['val'].num_tokens}")

    return readers
# Build the train/val readers; with the default verbose=True this prints the
# token count of each split.
data = get_data_readers(args)


# Instantiate the model described by args (still on its default device here).
model = get_model(args)

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [4]:
# Collect a fixed set of 10 validation batches on the GPU; they are reused for
# every gradient evaluation below.
val_batches = []
# NOTE(review): this constructs a second pair of readers (and prints the token
# counts again) although `data["val"]` already exists — confirm whether the
# existing reader can be reused instead.
data_reader = get_data_readers(args)["val"]
for _ in range(10):
    x, y = get_batch(data_reader, device="cuda")
    val_batches.append((x, y))
eval_batches = val_batches[:10]  # evaluate on the first 10 batches


def compute_grad(model, eval_batches):
    """Accumulate the batch-averaged loss gradient into every ``param.grad``.

    Existing gradients are discarded first, so after the call each parameter's
    ``.grad`` holds the gradient of the mean loss over ``eval_batches`` only.
    The model is switched to (and left in) training mode.

    Args:
        model: Module called as ``model(x, targets=y, get_logits=True)`` and
            returning a mapping with a ``"loss"`` entry.
        eval_batches: Sequence of ``(x, y)`` pairs. It must support ``len``.

    Returns:
        float: Mean of the per-batch losses, or ``0.0`` when ``eval_batches``
        is empty (in which case all gradients stay ``None``).
    """
    model.train()
    num_batches = len(eval_batches)

    # Clear stale gradients so the result reflects this call alone.
    for p in model.parameters():
        p.grad = None

    # Gradient accumulation: scaling each loss by 1/num_batches makes the
    # summed gradients equal to the gradient of the mean loss.
    scaled_losses = []
    for x, y in eval_batches:
        outputs = model(x, targets=y, get_logits=True)
        scaled_loss = outputs["loss"] / num_batches
        scaled_loss.backward()  # adds into the existing .grad tensors
        scaled_losses.append(scaled_loss.detach())

    # Convert once at the end instead of reading each loss back per batch.
    return float(sum(scaled_losses))

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [5]:
# Show the module tree of the instantiated model.
model

Llama(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 384)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-7): 8 x LlamaBlock(
        (ln_1): RMSNorm()
        (attn): LlamaAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): RMSNorm()
        (mlp): LlamaMLP(
          (w1): Linear(in_features=384, out_features=1024, bias=False)
          (w2): Linear(in_features=384, out_features=1024, bias=False)
          (c_proj): Linear(in_features=1024, out_features=384, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=384, out_features=50304, bias=False)
)

In [6]:
import torch.nn.functional as F
def get_hessian(model, eval_batches, a, r, n_iters=1000, seed=42):
    """Iteratively refine a random direction from finite-difference gradient changes.

    Only one weight matrix is involved:
    ``model.transformer.h[-1].mlp.c_proj.weight``. Starting from a seeded
    Gaussian direction ``phi``, each iteration

    1. shifts the weight by ``a * phi / ||phi||``,
    2. recomputes the gradient with ``compute_grad``,
    3. restores the original weight, and
    4. sets ``phi <- (1 - r) * phi + (r / a) * (grad_shifted - grad_original)``.

    ``(grad_shifted - grad_original) / a`` is a finite-difference estimate of
    the Hessian applied to the unit direction, so the loop resembles a damped
    power iteration.
    NOTE(review): presumably the goal is the dominant Hessian eigen-direction
    for this weight — confirm the intended convergence behaviour.

    The weight is restored even if an iteration is interrupted or raises, so
    the model is never left in the perturbed state.

    Args:
        model: Model exposing ``transformer.h[-1].mlp.c_proj.weight``.
        eval_batches: Batches forwarded to ``compute_grad``.
        a: Length of the weight perturbation. It is also the divisor of the
            gradient difference, so it must be non-zero.
        r: Mixing coefficient between the previous ``phi`` and the new
            finite-difference term.
        n_iters: Number of refinement iterations (at least 1).
        seed: Seed passed to ``set_seed`` before ``phi`` is sampled. This
            reseeds the global RNGs.

    Returns:
        tuple: ``(phi, simi)`` where ``phi`` is the final direction scaled to
        unit norm and ``simi`` is the cosine similarity between the original
        gradient and ``phi`` from the last iteration.

    Raises:
        ValueError: If ``n_iters`` is smaller than 1.
    """
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")

    weight = model.transformer.h[-1].mlp.c_proj.weight

    # Reference gradient and weight at the unperturbed point.
    compute_grad(model, eval_batches)
    grad_original = weight.grad.detach().clone()
    original_weight = weight.detach().clone()
    grad_original_flat = grad_original.reshape(-1)

    set_seed(seed)
    random_phi = torch.randn_like(weight)

    for i in range(n_iters):
        step = (random_phi / torch.norm(random_phi)) * a
        try:
            with torch.no_grad():
                weight.add_(step)
            compute_grad(model, eval_batches)
            grad_after_pertu = weight.grad.detach().clone()
        finally:
            # Always undo the perturbation, including on KeyboardInterrupt.
            with torch.no_grad():
                weight.copy_(original_weight)

        random_phi = (1 - r) * random_phi + (r / a) * (grad_after_pertu - grad_original)

        weight_norm_of_random = random_phi.norm()
        simi = F.cosine_similarity(
            grad_original_flat,
            (random_phi / torch.norm(random_phi)).reshape(-1),
            dim=0,
        )
        print(f"{i}-th iteration, grad norm of phi: {weight_norm_of_random}, simi : {simi}")

    random_phi = random_phi / random_phi.norm()
    return random_phi, simi

In [8]:
# Move the model to the GPU (the evaluation batches were created on "cuda"),
# then run the direction search with a=0.5 and r=0.5.
model = model.cuda()
phi,simi = get_hessian(model,eval_batches,0.5,0.5)

0-th iteration, grad norm of phi: 313.3741455078125, simi : -0.00023197010159492493
1-th iteration, grad norm of phi: 156.68710327148438, simi : -0.0002323213848285377
2-th iteration, grad norm of phi: 78.34358978271484, simi : -0.0002330225834157318
3-th iteration, grad norm of phi: 39.171836853027344, simi : -0.00023442559177055955
4-th iteration, grad norm of phi: 19.58595848083496, simi : -0.00023723210324533284
5-th iteration, grad norm of phi: 9.793018341064453, simi : -0.0002428483567200601
6-th iteration, grad norm of phi: 4.896549224853516, simi : -0.0002540963178034872
7-th iteration, grad norm of phi: 2.4483141899108887, simi : -0.0002766017278190702
8-th iteration, grad norm of phi: 1.2241970300674438, simi : -0.00032169840415008366
9-th iteration, grad norm of phi: 0.6121386885643005, simi : -0.0004123468534089625
10-th iteration, grad norm of phi: 0.3061100244522095, simi : -0.0005954167572781444
11-th iteration, grad norm of phi: 0.15309685468673706, simi : -0.0009681037

In [10]:
# Norm of the full lm_head weight matrix, shown for reference.
model.lm_head.weight.norm()

tensor(143.5406, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [26]:
# Layer whose weight get_hessian perturbs.
# NOTE(review): the saved output of this cell (2816 -> 1024) does not match
# the c_proj shape in the model printout earlier in this notebook
# (1024 -> 384); it appears to come from a different model configuration or
# an out-of-order run — restart the kernel and re-run to confirm.
model.transformer.h[-1].mlp.c_proj

Linear(in_features=2816, out_features=1024, bias=False)

In [9]:
# Cosine similarity returned by the most recent get_hessian call.
# NOTE(review): the execution counts around this cell are out of order, so
# this value may belong to a different run than the iteration log saved
# above — confirm after a clean Restart & Run All.
simi

tensor(0.0017, device='cuda:0')