In [1]:
import argparse
import json
from pathlib import Path
import random
import os
import schedulefree

import numpy as np
import torch
import wandb

import config
from data.utils import DataReader, get_dataset
import distributed
from models.utils import get_model
from optim.base import train
from optim.utils import cos_inf_schedule, wsd_schedule, get_batch

import sys

if 'ipykernel_launcher' in sys.argv[0]:
    sys.argv = sys.argv[:1]
    
def get_args():
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        "--config_format", default="base", choices=config.registered_formats()
    )
    args, rem_args = parser.parse_known_args()
    args.n_layer=12
    args.n_head=12
    args.n_embd=768
    args.datasets_dir = "/chenyupeng/data_files/llm_datasets"
    return config.parse_args_with_format(
        format=args.config_format, base_parser=parser, args=rem_args, namespace=args
    )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
args = get_args()

In [3]:
args.config_format

'base'

In [4]:
import copy

In [5]:

def get_data_readers(args, verbose=True):
    data_srcs = get_dataset(args)
    train_reader = DataReader(
        data_src=data_srcs["train"],
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        auto_shard=True,
        keep_in_ram=args.data_in_ram,
    )
    val_reader = DataReader(
        data_src=data_srcs["val"],
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        auto_shard=False,  # NOTE Identical Per Rank
        keep_in_ram=args.data_in_ram,
    )

    if verbose:
        print(f"Num training tokens: {train_reader.num_tokens}")
        print(f"Num validation tokens: {val_reader.num_tokens}")

    return {
        "train": train_reader,
        "val": val_reader,
    }
data = get_data_readers(args)

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [6]:
model = get_model(args)

In [7]:
import types
import torch

def enable_collect_embedding(layer, layer_id):
    """
    replace the forward function of LlamaDecoderLayer with a custom forward function `llama_custom_decoderlayer_forward`
    """
    layer.layer_id = layer_id
    layer.forward = types.MethodType(
        block_forward, layer
    )


def block_forward(self, x, freqs_cis):
    x = x + self.attn(self.ln_1(x), freqs_cis)
    temp = self.ln_2(x)
    self.input = temp.detach()
    x_ = self.mlp(temp)
    self.output = x_.detach()
    x = x + x_
    return x

In [8]:
import copy

def load_ck_state(model, step):
    model_new = copy.deepcopy(model)
    current_ckpt = torch.load(f"/chenyupeng/old_files/yupeng_gpt/WSD/river_valley_project/llama_100m/slimpajama_llama_nlayers12_nhead12_lr0.001_sched_wsd_warmup300_decay_linear_0.1_iter25000_bs50x2_ws2_seed0_data_seed1337/ckpts/{step}/main.pt",map_location=torch.device('cpu'))
    new_state_dict = {}
    for key, value in current_ckpt["model"].items():
        new_key = key.replace('_orig_mod.', '')
        new_key = new_key.replace('module.', '')# 移除前缀
        new_state_dict[new_key] = value
    model_new.load_state_dict(new_state_dict)
    return model_new 
    
model = load_ck_state(model, 25000)
#current_ckpt = torch.load(f"/mntcephfs/lab_data/chenyupeng/senmiao_jaggi_exp_results/slimpajama_llama_nlayers8_nhead6_lr0.002_sched_wsd_warmup1500_decay_linear_0.1_iter15000_bs50x4_ws1_seed0_data_seed1337/ckpts/12000/main.pt",map_location=torch.device('cpu'))


In [9]:
model=model.cuda()

In [10]:
data_reader = get_data_readers(args)["val"]

/chenyupeng/data_files/llm_datasets/slimpajama6B/
Num training tokens: 5827933038
Num validation tokens: 9479563


In [11]:
for i in range(12):
    enable_collect_embedding(model.transformer.h[i],i)

In [12]:
input_embedding = torch.zeros(1,50,512,768,12)
output_embedding = torch.zeros(1,50,512,768,12)
with torch.no_grad():
    for i in range(1):
        x, y = get_batch(data_reader, device="cuda")
        
        outputs = model(x, targets=y, get_logits=True)
    
        for z in range(12):
            input_embedding[i,:,:,:,z] = model.transformer.h[z].input.cpu()
            output_embedding[i,:,:,:,z] = model.transformer.h[z].output.cpu()
            model.transformer.h[z].input = None
            model.transformer.h[z].output = None

In [13]:
torch.save(input_embedding, "/chenyupeng/share_experts/gpt2_100m_embedding_data/input_embedding.pt")

In [14]:
torch.save(output_embedding, "/chenyupeng/share_experts/gpt2_100m_embedding_data/output_embedding.pt")

In [15]:
model

Llama(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x LlamaBlock(
        (ln_1): RMSNorm()
        (attn): LlamaAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=False)
          (c_proj): Linear(in_features=768, out_features=768, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): RMSNorm()
        (mlp): LlamaMLP(
          (w1): Linear(in_features=768, out_features=2048, bias=False)
          (w2): Linear(in_features=768, out_features=2048, bias=False)
          (c_proj): Linear(in_features=2048, out_features=768, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50304, bias=False)
)

In [31]:
20*50*512*768/1024/1024/1024

0.3662109375

In [28]:
model.transformer.h[11].input.shape

torch.Size([50, 512, 768])

In [24]:
len(data_reader)

9479050

In [16]:
model

Llama(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x LlamaBlock(
        (ln_1): RMSNorm()
        (attn): LlamaAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=False)
          (c_proj): Linear(in_features=768, out_features=768, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): RMSNorm()
        (mlp): LlamaMLP(
          (w1): Linear(in_features=768, out_features=2048, bias=False)
          (w2): Linear(in_features=768, out_features=2048, bias=False)
          (c_proj): Linear(in_features=2048, out_features=768, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50304, bias=False)
)