
Commit

fix frozen weights issue
freckletonj committed Apr 19, 2024
1 parent d76368e commit 51f9173
Showing 5 changed files with 105 additions and 39 deletions.
57 changes: 32 additions & 25 deletions RWKV-v5_t01_stack/export_checkpoint.py
@@ -5,11 +5,11 @@
# This script is used to export the deepspeed checkpoint into an RWKV model
#
# This includes the workaround for a known format issue with the default deepspeed checkpoint exporter
#
#
# The script was taken and modified from the deepspeed repository, and is used to convert a deepspeed checkpoint
# into the respective model file. This script, and the deepspeed repo, are covered under the Apache-2.0 license
# (which is compatible with our apache license)
#
#
# This can be found at: https://github.com/microsoft/DeepSpeed/blob/aef6c65ce39d191ca31618b2a995599942574fd9/deepspeed/utils/zero_to_fp32.py
# The modification for deepspeed 1 support is here: https://github.com/microsoft/DeepSpeed/pull/3936
#
@@ -39,6 +39,13 @@
from collections import OrderedDict
from dataclasses import dataclass

# from pytorch_lightning.utilities.deepspeed import (
# convert_zero_checkpoint_to_fp32_state_dict,
# # get_model_state_file,
# # get_optim_files,
# # ds_checkpoint_dir
# )

# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
from deepspeed.utils import logger
@@ -55,9 +62,9 @@ class zero_model_state:
ds_version: int
frozen_param_shapes: dict()
frozen_param_fragments: dict()
module_params: dict() # NOTE: this is a hack bc this isn't resolved: https://github.com/microsoft/DeepSpeed/issues/5439


debug = 0
debug = True

# load to cpu
device = torch.device('cpu')
@@ -75,23 +82,6 @@ def natural_keys(text):
'''
return [atoi(c) for c in re.split(r'(\d+)', text)]
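# Example (not part of this commit): natural_keys gives human/natural ordering,
# e.g. sorted(['rank_10.pt', 'rank_2.pt'], key=natural_keys) -> ['rank_2.pt', 'rank_10.pt']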


def get_model_state_file(checkpoint_dir, zero_stage):
if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

# there should be only one file
if zero_stage <= 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")

if not os.path.exists(file):
raise FileNotFoundError(f"can't find model states file at '{file}'")

return file


def get_checkpoint_files(checkpoint_dir, glob_pattern):
# XXX: need to test that this simple glob rule works for multi-node setup too
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
@@ -145,12 +135,24 @@ def parse_model_states(files):

frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

# Add missing parameters from 'module'
#
# NOTE: this is a hack bc this isn't resolved: https://github.com/microsoft/DeepSpeed/issues/5439
# Add missing parameters from 'module'
module_params = {}
for name, param in state_dict["module"].items():
if name not in param_names and name not in buffer_names:
if debug:
print(f"Adding missing parameter from 'module': {name}")
module_params[name] = param

z_model_state = zero_model_state(buffers=buffers,
param_shapes=param_shapes,
shared_params=shared_params,
ds_version=ds_version,
frozen_param_shapes=frozen_param_shapes,
frozen_param_fragments=frozen_param_fragments)
frozen_param_fragments=frozen_param_fragments,
module_params=module_params)
zero_model_states.append(z_model_state)

return zero_model_states
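The hunk above is the core of the fix: weights that never appear in the exporter's param_shapes or buffer lists (presumably what happens to frozen parameters here, given the commit message and the linked DeepSpeed issue) are collected from state_dict['module'] and later merged back into the exported state dict. A minimal, self-contained sketch of the idea, using hypothetical parameter names rather than the repository's API:

import torch

# Weights the exporter reconstructs from the optimizer's fp32 flat groups
# (trainable parameters only, in this hypothetical example).
recovered = {"blocks.0.att.key.weight": torch.zeros(4, 4)}

# Raw module weights saved alongside the checkpoint; frozen parameters are
# assumed to survive only here.
raw_module = {
    "emb.weight": torch.zeros(8, 4),
    "blocks.0.att.key.weight": torch.zeros(4, 4),
}

# Carry over anything that would otherwise be dropped from the export.
module_params = {k: v for k, v in raw_module.items() if k not in recovered}
recovered.update(module_params)
assert "emb.weight" in recovered  # the frozen embedding survives the export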
@@ -345,14 +347,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer
print(f"added {len(buffers)} buffers")

_zero2_merge_frozen_params(state_dict, zero_model_states)

_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]

# Add missing parameters from 'module_params'
state_dict.update(zero_model_states[0].module_params)

return state_dict


@@ -471,6 +475,9 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]

# Add missing parameters from 'module_params'
state_dict.update(zero_model_states[0].module_params)

return state_dict


@@ -613,7 +620,7 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
type=str,
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")

### Original code ###
### Original code ###
# parser.add_argument(
# "output_file",
# type=str,
Expand Down Expand Up @@ -648,5 +655,5 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
output_file = args.output_file
if output_file == "" or output_file is None:
output_file = os.path.join(args.checkpoint_dir, "rwkv_model.pth")
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, output_file, save_dtype=args.dtype)
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, output_file)
### RWKV modified code ###
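Assuming the argument names in the parser above, the modified script can presumably be invoked as `python export_checkpoint.py path/checkpoint-12`; when no output file is given it writes `rwkv_model.pth` inside the checkpoint folder.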
2 changes: 1 addition & 1 deletion RWKV-v5_t01_stack/run/r02/config.yaml
@@ -25,7 +25,7 @@ trainer:
logger:
class_path: 'lightning.pytorch.loggers.tensorboard.TensorBoardLogger'
init_args:
save_dir: 'r02/logs/'
save_dir: 'run/r02/logs/'
name: 'r02_palindrome'
version: 1

17 changes: 10 additions & 7 deletions RWKV-v5_t01_stack/run/r02_palindrome.sh
@@ -57,6 +57,12 @@ export INIT_MODEL_NAME="init.pth"
mkdir -p "${PROJECT_DIR}/datapath/"
mkdir -p "${PROJECT_DIR}/checkpoint/"

##################################################
# PARAMS

export S_STACK_IX="1"
export S_NOISE="0.3"


# echo "##################################################"
# echo "INITIALIZING"
@@ -66,18 +72,15 @@ mkdir -p "${PROJECT_DIR}/checkpoint/"
# "${PROJECT_DIR}/checkpoint/${INIT_MODEL_NAME}"


# echo "##################################################"
# echo "PRELOADING DATASET"
# # python "preload_datapath.py" "run/r02/config.yaml"
# python "${ROOT_DIR}/preload_datapath.py" "${PROJECT_DIR}/config.yaml"
# # echo "##################################################"
# # echo "PRELOADING DATASET"
# # # python "preload_datapath.py" "run/r02/config.yaml"
# # python "${ROOT_DIR}/preload_datapath.py" "${PROJECT_DIR}/config.yaml"


# echo "##################################################"
# echo "TRAINING"

# export S_STACK_IX="1"
# export S_NOISE="0.3"

# python "${ROOT_DIR}/lightning_trainer.py" fit \
# -c "${PROJECT_DIR}/config.yaml" \
# --trainer.logger.init_args.name="${WANDB_PREFIX} training (${DEEPSPEED_STRAT})" \
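A hypothetical sketch (not part of this commit) of how the Python side might consume the two environment variables exported above; the variable names come from r02_palindrome.sh, but the reading code itself is assumed:

import os

# Hypothetical reader for the run-script settings: which layer hosts the stack
# and how much noise to inject, with the defaults taken from the shell script.
stack_ix = int(os.environ.get("S_STACK_IX", "1"))
noise = float(os.environ.get("S_NOISE", "0.3"))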
16 changes: 13 additions & 3 deletions RWKV-v5_t01_stack/sandbox.py
@@ -79,6 +79,16 @@ def run_loop(prompt):



# from datasets import load_from_disk, load_dataset
# path = os.path.expanduser('~/_/data/automata_palindrome/data')
# x = load_dataset(path)

# ##################################################
# # debugging ckpt exporter

# CWD = 'RWKV-v5_t01_stack'
# if CWD not in os.getcwd():
# os.chdir(CWD)

# path = './run/r02/checkpoint/last.ckpt/checkpoint/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt'
# path2 = './run/r02/checkpoint/last.ckpt/checkpoint/mp_rank_00_model_states.pt'

# opt = torch.load(path)
# model = torch.load(path2)
52 changes: 49 additions & 3 deletions RWKV-v5_t01_stack/src/model.py
@@ -10,6 +10,24 @@

import neurallambda.stack as S

def debugo(model_weights):
print('----------')
print('FROZEN_PARAM_SHAPES:')
print(model_weights['frozen_param_shapes'])

print('----------')
print('FROZEN_PARAM_FRAGMENTS:')
print(model_weights['frozen_param_fragments'])

print('----------')
print('SHARED_PARAMS:')
print(model_weights['shared_params'])

print('----------')
print('PARAM_SHAPES:')
print(model_weights['param_shapes'])
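# Hypothetical usage (not part of this commit): debugo expects the raw
# DeepSpeed model-states dict, e.g.
#   states = torch.load('mp_rank_00_model_states.pt', map_location='cpu')
#   debugo(states)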


# ---
# Isolating out known operations that **do not work** with torch.compile
# and wrapping them within a torch._dynamo.disable; this is required to get
@@ -250,6 +268,24 @@ def forward(self,
### ---
class RWKV(L.LightningModule):

def on_train_start(self):
# Check the model's parameters before training starts
for name, param in self.named_parameters():
if param.requires_grad:
print(f"Parameter {name} is trainable.")
else:
print(f"Parameter {name} is frozen.")

# def on_save_checkpoint(self, checkpoint):
# # DEBUG
# model_state_dict = checkpoint["state_dict"]
# for name, param in self.named_parameters():
# if name in model_state_dict:
# print(f"Parameter {name} is present in the checkpoint.")
# else:
# print(f"Parameter {name} is missing from the checkpoint.")
# breakpoint()

def __init__(self,
# Model file path to load from
load_model: str,
@@ -328,16 +364,26 @@ def __init__(self,

# Load the model weights
if IS_TORCH_2_1_COMPATIBLE:
model_weights = torch.load(load_model, map_location='cpu', weights_only=True, mmap=True)
try:
model_weights = torch.load(load_model, map_location='cpu', weights_only=True, mmap=True)
# model_weights = torch.load(load_model, map_location='cpu', weights_only=False, mmap=True)
# model_weights = model_weights['modules']
except:
print('ERROR LOADING WEIGHTS')
# model_weights = torch.load(load_model, map_location='cpu', weights_only=False, mmap=True)
# debugo(model_weights)
breakpoint()
else:
model_weights = torch.load(load_model, map_location='cpu')
# model_weights = model_weights['modules']

# Get the model keys
model_keys = list(model_weights.keys())

# print('Weights found on disk:')
# for k in model_keys:
# print(k)
# breakpoint()

# Let's compute the model's various sizes, if they are not provided
if n_layer < 0:
@@ -463,7 +509,7 @@ def configure_optimizers(self):
if self.freeze_embeddings:
print('Freezing embeddings')
for param in self.emb.parameters():
param.requires_grad = False
param.requires_grad_(False)
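# NOTE: requires_grad_(False) is the in-place Tensor method; for these leaf
# parameters it has the same effect as assigning .requires_grad = False.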

if self.bptt_learning == False:
if self.deepspeed_stage >= 2 or self.deepspeed_offload:
@@ -1232,7 +1278,7 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
gc.collect()
# torch.cuda.empty_cache()

# Wandb logging only, if an active run exists (only applies for training)

if wandb.run is not None and is_training_run:
global_rank = self.global_rank
global_device_count = self.trainer.num_devices * self.trainer.num_nodes

