Skip to content

Commit

Permalink
complete tests for LoadBestCallback CheckpointCallback and save&load_…
Browse files Browse the repository at this point in the history
…checkpoint
  • Loading branch information
x54-729 committed Aug 22, 2023
1 parent 70db1db commit 82869ee
Show file tree
Hide file tree
Showing 13 changed files with 429 additions and 218 deletions.
66 changes: 28 additions & 38 deletions collie/controller/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __init__(self,
if self.eval_dataset is not None:
assert eval_fn is not None, "eval_fn should not be None when eval_dataset is not None."
evaluator = Evaluator(model=self.model, dataset=eval_dataset, metrics=metrics, eval_fn=eval_fn,
config=config, collate_fn=eval_dataset_collate_fn, data_provider=None)
config=config, collate_fn=eval_dataset_collate_fn)
evaluator.monitor = self.monitor
evaluators.append(evaluator)
for evaluator in evaluators:
Expand Down Expand Up @@ -551,24 +551,21 @@ def save_model(self, path: str, process_exclusion: bool = False,
self.on_save_model()
if isinstance(self.engine.module, CollieModelForCausalLM) or isinstance(self.engine.module, PipelineModel):
if is_zero3_enabled(self.config):
state_dict = {}
self._checkpoint_prologue()
with deepspeed.zero.GatheredParameters(list(self.engine.module.parameters(recurse=True))):
self.engine.module.save_parallel_state_dict(
state_dict=self.engine.module.state_dict(),
path=path,
config=self.config,
process_exclusion=process_exclusion,
protocol=protocol
)
for name, param in self.engine.module.named_parameters():
with deepspeed.zero.GatheredParameters(param):
state_dict[name] = param.detach().cpu()
self._checkpoint_epilogue()
else:
self.engine.module.save_parallel_state_dict(
state_dict=self.engine.module.state_dict(),
path=path,
config=self.config,
process_exclusion=process_exclusion,
protocol=protocol
)
state_dict = self.engine.module.state_dict()
self.engine.module.save_parallel_state_dict(
state_dict=state_dict,
path=path,
config=self.config,
process_exclusion=process_exclusion,
protocol=protocol
)
elif isinstance(self.engine.module, PreTrainedModel):
if is_zero3_enabled(self.config):
self._checkpoint_prologue()
Expand Down Expand Up @@ -606,20 +603,22 @@ def load_model(self, path: str, process_exclusion: bool = False,
)
)
elif isinstance(self.engine.module, PreTrainedModel):
index = None
if io_driver.exists(os.path.join(path, "pytorch_model.bin.index.json")):
weight_map = json.loads(io_driver.load(os.path.join(path, "pytorch_model.bin.index.json"), mode="r"))["weight_map"]
index = OrderedDict()
for key, value in weight_map.items():
if value not in index.keys():
index[value] = [key]
else:
index[value].append(key)
if is_zero3_enabled(self.config):
index = None
if io_driver.exists(os.path.join(path, "pytorch_model.bin.index.json")):
weight_map = json.loads(io_driver.load(os.path.join(path, "pytorch_model.bin.index.json"), mode="r"))["weight_map"]
index = OrderedDict()
for key, value in weight_map.items():
if value not in index.keys():
index[value] = [key]
else:
index[value].append(key)
self._checkpoint_prologue()
if index is not None:
for key, value in index.items():
with deepspeed.zero.GatheredParameters([self.engine.module.state_dict()[attr] for attr in value], modifier_rank=0):
# 用 state dict 会没办法 gather
param_list = [p for n, p in self.engine.module.named_parameters() if n in value]
with deepspeed.zero.GatheredParameters(param_list, modifier_rank=0):
if env.dp_rank == 0:
state_dict = io_driver.load(os.path.join(path, key), mode="br")
for attr in value:
Expand All @@ -631,22 +630,13 @@ def load_model(self, path: str, process_exclusion: bool = False,
self.engine.module.load_state_dict(state_dict)
self._checkpoint_epilogue()
else:
index = None
if io_driver.exists(os.path.join(path, "pytorch_model.bin.index.json")):
weight_map = json.loads(io_driver.load(os.path.join(path, "pytorch_model.bin.index.json"), mode="r"))["weight_map"]
index = OrderedDict()
for key, value in weight_map.items():
if value not in index.keys():
index[value] = [key]
else:
index[value].append(key)
if index is not None:
for key, value in index.items():
state_dict = io_driver.load(os.path.join(path, key), mode="br")
for attr in value:
self.engine.module.state_dict()[attr].copy_(state_dict[attr])
else:
state_dict = reduce(lambda x, y: {**x, **y}, [io_driver.load(file) for file in glob.glob(os.path.join(path, "*.bin"))])
state_dict = reduce(lambda x, y: {**x, **y}, [io_driver.load(file, mode="rb") for file in glob.glob(os.path.join(path, "*.bin"))])
self.engine.module.load_state_dict(state_dict)

def save_checkpoint(self, path: str, process_exclusion: bool = False, **kwargs):...
Expand Down Expand Up @@ -688,7 +678,7 @@ def save_checkpoint(self, path: str, process_exclusion: bool = False,
global_samples=engine.global_samples,
callback_states=callback_states)

if env.rank == 0 or engine.zero_optimization_partition_weights():
if env.dp_rank == 0 or engine.zero_optimization_partition_weights():
io_driver.save(state, os.path.join(path, self.checkpoint_file))

if engine.save_zero_checkpoint:
Expand Down Expand Up @@ -725,7 +715,7 @@ def load_checkpoint(self, path: str, process_exclusion: bool = False,
if engine.zero_optimization_partition_weights():
ckpt_file = self.checkpoint_file
else:
ckpt_file = "collie_dp0_pp0_tp0.pt"
ckpt_file = f"collie_dp0_pp{env.pp_rank}_tp{env.tp_rank}.pt"
checkpoint = io_driver.load(os.path.join(path, ckpt_file), "b")

# Prepare for checkpoint load by ensuring all parameters are partitioned
Expand Down
1 change: 1 addition & 0 deletions collie/models/moss_moon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, config):
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
Expand Down
7 changes: 7 additions & 0 deletions collie/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,13 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
destination[key_pp] = destination.pop(key)
return destination

def load_state_dict(self, state_dict: Mapping[str, Any],
                    strict: bool = True):
    """Load ``state_dict`` after renaming every key with ``name_to_pipeline``.

    Each key is passed through ``self.name_to_pipeline`` and the renamed
    mapping is handed to the parent class's ``load_state_dict``.

    :param state_dict: mapping from parameter/buffer name to value. It is
        NOT modified; the renaming happens on a shallow copy.
    :param strict: forwarded unchanged to the parent ``load_state_dict``.
    :return: whatever the parent ``load_state_dict`` returns, so callers
        can inspect the result instead of silently receiving ``None``.
    """
    # Function-scope import: only this method needs ``copy``.
    import copy

    # Rename keys on a shallow copy so the caller's mapping keeps its
    # original key names after this call (values are shared, not cloned).
    remapped = copy.copy(state_dict)
    # Snapshot the keys first: the mapping is mutated inside the loop.
    for key in list(remapped.keys()):
        key_pp = self.name_to_pipeline(key)
        remapped[key_pp] = remapped.pop(key)
    return super().load_state_dict(remapped, strict)

def forward(self, *args, **kwargs):
if not self.inner_forward:
if self.forward_type == "generate":
Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/callbacks/__init__.py
Empty file.
70 changes: 29 additions & 41 deletions tests/callbacks/_test_checkpoint_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,7 @@
from collie.models import Moss003MoonForCausalLM
from collie.utils import env
from collie.log import logger

DS_CONFIG = {
"fp16": {
"enabled": True
},
"zero_allow_untested_optimizer": True,
"zero_force_ds_cpu_optimizer": False,

"optimizer": {
"type": "Adam",
"params": {
"lr": 2e-5,
"weight_decay": 0.1
}
},

"zero_optimization": {
"stage": 1,
"offload_optimizer": {
"device": "cpu",
"pin_memory": False
}
},
"steps_per_print": 2000,
}
from tests.helpers import create_ds_config, import_class

def check_and_load(trainer, folder, subfolder, model_only):
path = os.path.join(folder, subfolder)
Expand All @@ -45,21 +21,23 @@ def check_and_load(trainer, folder, subfolder, model_only):
else:
trainer.load_checkpoint(path)

def test_checkpoint_callback(pretrained_model, model_only, folder,
dp_size, tp_size, pp_size):
def test_checkpoint_callback(model_type, model_path, folder, model_only,
dp_size, tp_size, pp_size, zero):
try:
ds_config = create_ds_config(fp16=True, zero=zero, offload=True, optimizer="Adam", lr=2e-5)
config = CollieConfig.from_pretrained(
pretrained_model, tp_size=tp_size, dp_size=dp_size, pp_size=pp_size,
model_path, tp_size=tp_size, dp_size=dp_size, pp_size=pp_size,
train_epochs=5, eval_per_n_steps=0, eval_per_n_epochs=0,
train_micro_batch_size=2, gradient_accumulation_steps=2,
eval_batch_size=1, ds_config=DS_CONFIG, trust_remote_code=True
eval_batch_size=1, ds_config=ds_config, trust_remote_code=True
)
# tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
train_sample = tokenizer("Collie is a python package for finetuning large language models.", return_tensors="pt").input_ids.squeeze(0)
train_dataset = [{"input_ids": train_sample, "labels": train_sample} for _ in range(100)]

model = Moss003MoonForCausalLM.from_pretrained(pretrained_model, config=config)
model_cls = import_class(model_type)
model = model_cls.from_pretrained(model_path, config=config)

every_n_epochs = 2
every_n_batches = 10
Expand All @@ -77,22 +55,21 @@ def test_checkpoint_callback(pretrained_model, model_only, folder,
assert os.path.exists(folder)
ckpts = []
for epoch in range(config.train_epochs):
if (epoch + 1) % every_n_epochs == 0:
ckpts.append(f"epoch_{epoch + 1}")
for n in range(trainer.steps_per_epoch // every_n_batches):
ckpts.append(f"epoch_{epoch}-batch_{(n + 1)*every_n_batches}")
if (epoch + 1) % every_n_epochs == 0:
ckpts.append(f"epoch_{epoch + 1}")
if last:
check_and_load(trainer, folder, "last", model_only)
print(ckpts)
if max is not None and max > 0:
for folder_name in ckpts[:max]:
assert not os.path.exists(os.path.join(folder, folder_name))
assert not os.path.exists(os.path.join(folder, folder_name)), folder_name
ckpts = ckpts[-max:]
for folder_name in ckpts:
assert os.path.exists(os.path.join(folder, folder_name))
assert os.path.exists(os.path.join(folder, folder_name)), folder_name
for folder_name in ckpts:
check_and_load(trainer, folder, folder_name, model_only)
except Exception as e:
logger.error(traceback.format_exc())
finally:
if os.path.exists(folder):
logger.info(f"folders in checkpoint {folder}/:\n{os.listdir(folder)}")
Expand All @@ -101,7 +78,18 @@ def test_checkpoint_callback(pretrained_model, model_only, folder,


if __name__ == "__main__":
pretrained_model = "/mnt/petrelfs/xingshuhao.dispatch/.cache/huggingface/hub/models--Salesforce--codegen-350M-mono/snapshots/40b7a3b6e99e73bdb497a14b740e7167b3413c74"
test_checkpoint_callback(pretrained_model, model_only=True,
folder="_ckpt", dp_size=2, tp_size=1,
pp_size=2)
# Command-line driver: collect the test parameters from argv and run the
# checkpoint-callback test once with them.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--folder", default="_ckpt", type=str)
parser.add_argument("--model_path", type=str)
parser.add_argument("--model_type", type=str)
parser.add_argument("--model_only", action="store_true")
parser.add_argument("--dp_size", type=int)
parser.add_argument("--tp_size", type=int)
parser.add_argument("--pp_size", type=int)
parser.add_argument("--zero", default=1, type=int)
args = parser.parse_args()
# Honour --folder (it was parsed but ignored before); its default is still
# "_ckpt", so invocations that omit the flag behave exactly as they did.
test_checkpoint_callback(args.model_type, args.model_path, folder=args.folder,
model_only=args.model_only, dp_size=args.dp_size,
tp_size=args.tp_size, pp_size=args.pp_size,
zero=args.zero)
Loading

0 comments on commit 82869ee

Please sign in to comment.