In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import random
import numpy as np


import math
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [65]:
import argparse
import json
from pathlib import Path
import random
import os
import schedulefree

import numpy as np
import torch
import wandb

import config
from data.utils import DataReader, get_dataset
import distributed
from models.utils import get_model
from optim.base import train
from optim.utils import cos_inf_schedule, wsd_schedule, get_batch

import sys
if 'ipykernel_launcher' in sys.argv[0]:
   sys.argv = sys.argv[:1]

def get_args():
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        "--config_format", default="base", choices=config.registered_formats()
    )
    args, rem_args = parser.parse_known_args()
    args.n_layer=1
    args.n_head=1
    args.n_embd=20
    args.vocab_size=20
    args.batch_size = 1
    args.multiple_of = 1
    args.dtype = "float32"
    args.datasets_dir = "/chenyupeng/data_files/llm_datasets"
    return config.parse_args_with_format(
        format=args.config_format, base_parser=parser, args=rem_args, namespace=args
    )

In [66]:
args = get_args()

import copy
def get_data_readers(args, verbose=True):
    data_srcs = get_dataset(args)
    train_reader = DataReader(
        data_src=data_srcs["train"],
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        auto_shard=True,
        keep_in_ram=args.data_in_ram,
    )
    val_reader = DataReader(
        data_src=data_srcs["val"],
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        auto_shard=False,  # NOTE Identical Per Rank
        keep_in_ram=args.data_in_ram,
    )

    if verbose:
        print(f"Num training tokens: {train_reader.num_tokens}")
        print(f"Num validation tokens: {val_reader.num_tokens}")

    return {
        "train": train_reader,
        "val": val_reader,
    }
data = get_data_readers(args)


model = get_model(args)

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [67]:
val_batches = []
data_reader = get_data_readers(args)["val"]
for _ in range(10):
    x, y = get_batch(data_reader, device="cuda")
    val_batches.append((x, y))
eval_batches = val_batches[0]  # 使用前10个batch评估

set_seed(100)
for i in range(eval_batches[0].shape[1]):
    eval_batches[0][0,i].data.copy_(random.randint(0, 19))
    if i>=1:
        eval_batches[1][0,i] = eval_batches[0][0,i-1]
eval_batches[1][0,-1] = random.randint(0, 19)
def compute_grad(model,eval_batches):
    model.train()
    total_loss = 0
    n_batches = 0
    # 清空梯度
    for p in model.parameters():
        p.grad = None
    
    # 梯度累积
    #for x, y in eval_batches:
    x = eval_batches[0]
    y = eval_batches[1]
    outputs = model(x, targets=y, get_logits=True)
    batch_loss = outputs["loss"]
    
    # 通过缩放损失实现梯度累积，相当于平均梯度
    batch_loss.backward()  # 梯度会累积

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [133]:
import torch.nn.functional as F
def get_hessian(model,eval_batches,a,r):
    compute_grad(model,eval_batches)
    grad_original = model.transformer.h[-1].mlp.c_proj.weight.grad.detach().clone()
    original_weight = model.transformer.h[-1].mlp.c_proj.weight.data.detach().clone()
    set_seed(42)
    random_phi = torch.randn_like(model.transformer.h[-1].mlp.c_proj.weight)
    for i in range(5000):
        #random_phi = random_phi/torch.norm(random_phi)
        model.transformer.h[-1].mlp.c_proj.weight.data.add_((random_phi/torch.norm(random_phi))*a)
        compute_grad(model,eval_batches)
        grad_after_pertu = model.transformer.h[-1].mlp.c_proj.weight.grad.data.detach().clone()
        random_phi = (1-r)*random_phi + (r/a)*(grad_after_pertu-grad_original)
        model.transformer.h[-1].mlp.c_proj.weight.data.copy_(original_weight)
        weight_norm_of_random = random_phi.norm()
        simi = F.cosine_similarity(grad_original.reshape(-1), (random_phi/torch.norm(random_phi)).reshape(-1), dim=0)
        print(f"{i}-th iteration, grad norm of phi: {weight_norm_of_random}, simi : {simi}")

    random_phi = random_phi/random_phi.norm()
    #cosine_simi = F.cosine_similarity(grad_original.reshape(-1), random_phi.reshape(-1), dim=0)
    return random_phi,simi

In [134]:
model = model.cuda()
phi,simi = get_hessian(model,eval_batches,0.5,0.1)

0-th iteration, grad norm of phi: 17.831632614135742, simi : -0.04626615345478058
1-th iteration, grad norm of phi: 16.06514549255371, simi : -0.04757986217737198
2-th iteration, grad norm of phi: 14.475554466247559, simi : -0.04903654754161835
3-th iteration, grad norm of phi: 13.045197486877441, simi : -0.050651341676712036
4-th iteration, grad norm of phi: 11.75817584991455, simi : -0.05244089663028717
5-th iteration, grad norm of phi: 10.600193977355957, simi : -0.05442332103848457
6-th iteration, grad norm of phi: 9.558381080627441, simi : -0.05661851167678833
7-th iteration, grad norm of phi: 8.621163368225098, simi : -0.0590481199324131
8-th iteration, grad norm of phi: 7.77812385559082, simi : -0.061735596507787704
9-th iteration, grad norm of phi: 7.019893646240234, simi : -0.06470665335655212
10-th iteration, grad norm of phi: 6.338047027587891, simi : -0.06798844039440155
11-th iteration, grad norm of phi: 5.725004196166992, simi : -0.07161042839288712
12-th iteration, grad 

In [55]:
phi.shape

torch.Size([20, 53])

In [60]:
model.transformer.h[-1].mlp.c_proj.weight.shape

torch.Size([20, 53])

In [63]:
model.transformer.h[-1].mlp.c_proj.weight.grad

In [137]:
for p in model.parameters():
    p.grad = None

In [135]:
x = eval_batches[0]
y = eval_batches[1]
outputs = model(x, targets=y, get_logits=True)
batch_loss = outputs["loss"]

In [149]:
parameters = [p for n,p in model.named_parameters() if "mlp.c_proj" in n]

In [150]:
first_order_grads = torch.autograd.grad(batch_loss, parameters, create_graph=True)

In [151]:
hessian = torch.zeros((parameters[0].numel(), parameters[0].numel()), device=parameters[0].device)

In [152]:
for i in range(parameters[0].numel()):
    grad2 = torch.autograd.grad(first_order_grads[0].flatten()[i], parameters, retain_graph=True)[0]
            
    if grad2 is not None:
        hessian[i, :] = grad2.flatten()
    else:
        print("none")

RuntimeError: derivative for aten::_scaled_dot_product_efficient_attention_backward is not implemented

In [78]:
u,v,d = torch.linalg.svd(hessian.double())

In [95]:
v

tensor([1.1470e-03, 9.5993e-04, 8.5912e-04,  ..., 9.0796e-11, 2.8985e-11,
        1.3251e-11], device='cuda:0', dtype=torch.float64)

In [102]:
torch.matmul(hessian, phi.reshape(-1))

tensor([-2.8252e-07,  3.4268e-06,  1.2450e-04,  ...,  5.5907e-05,
        -1.9799e-05,  4.9336e-05], device='cuda:0')

In [112]:
(torch.matmul(hessian, phi.reshape(-1))).norm()

tensor(0.0008, device='cuda:0')

In [113]:
(phi.reshape(-1)*8e-4 - torch.matmul(hessian, phi.reshape(-1))).norm()

tensor(4.0856e-05, device='cuda:0')

In [117]:
F.cosine_similarity(phi.reshape(-1), d[0,:],dim=0)

tensor(-0.0110, device='cuda:0', dtype=torch.float64)

In [59]:
torch.matmul(hessian, d[0,:])

tensor([ 8.5246e-06, -1.9020e-05,  2.7407e-05,  ...,  2.0739e-05,
        -3.5947e-07,  3.2398e-06], device='cuda:0')

In [124]:
d[0,:53]

tensor([ 0.0026,  0.0028,  0.0561,  0.0030,  0.0007,  0.0184, -0.0109,  0.0019,
         0.0121,  0.0077,  0.0091, -0.0045, -0.0392,  0.0035, -0.0019,  0.0053,
         0.0127,  0.0022, -0.0010,  0.0093, -0.0124, -0.0124, -0.0135,  0.0005,
         0.0015,  0.0062, -0.0090, -0.0057, -0.0044, -0.0025,  0.0011, -0.0251,
        -0.0015, -0.0442,  0.0040,  0.0014, -0.0444,  0.0034,  0.0434,  0.0083,
        -0.0075, -0.0435, -0.0090, -0.0002, -0.0062,  0.0036,  0.0265,  0.0052,
         0.0199, -0.0383, -0.0007,  0.0112,  0.0044], device='cuda:0',
       dtype=torch.float64)

In [122]:
phi[:2,:10]

tensor([[-0.0012,  0.0051,  0.1549,  0.0139, -0.0057,  0.0166, -0.0391, -0.0140,
          0.0126,  0.0058],
        [ 0.0037, -0.0004, -0.0201, -0.0034, -0.0018,  0.0006, -0.0068, -0.0008,
         -0.0059,  0.0016]], device='cuda:0')

In [126]:
v[:7]

tensor([0.0011, 0.0010, 0.0009, 0.0008, 0.0008, 0.0008, 0.0007],
       device='cuda:0', dtype=torch.float64)