From 969049898d6f80b08a5c9eaba2cae036a3da30cc Mon Sep 17 00:00:00 2001 From: Michael Denkowski Date: Thu, 25 Aug 2022 09:04:31 -0500 Subject: [PATCH] Add Support for DeepSpeed ZeRO Stage 1 (#1059) --- .github/workflows/push_pr.yml | 2 + .github/workflows/python-publish.yml | 2 + CHANGELOG.md | 10 ++ MANIFEST.in | 1 + requirements/requirements.deepspeed.txt | 1 + setup.py | 1 + sockeye/__init__.py | 2 +- sockeye/arguments.py | 18 ++ sockeye/constants.py | 3 +- sockeye/convert_deepspeed.py | 90 ++++++++++ sockeye/optimizers.py | 5 +- sockeye/train.py | 167 +++++++++++++++--- sockeye/training.py | 129 ++++++++++---- sockeye/utils.py | 33 ++++ test/data/deepspeed/converted/config | 99 +++++++++++ test/data/deepspeed/converted/params.00000 | Bin 0 -> 30707 bytes test/data/deepspeed/model/config | 99 +++++++++++ .../mp_rank_00_model_states.pt | Bin 0 -> 33699 bytes .../zero_pp_rank_0_mp_rank_00_optim_states.pt | Bin 0 -> 15715 bytes .../zero_pp_rank_1_mp_rank_00_optim_states.pt | Bin 0 -> 15715 bytes .../zero_pp_rank_2_mp_rank_00_optim_states.pt | Bin 0 -> 15715 bytes .../zero_pp_rank_3_mp_rank_00_optim_states.pt | Bin 0 -> 15715 bytes test/data/deepspeed/model/params.00000/latest | 1 + test/unit/test_arguments.py | 3 + test/unit/test_deepspeed.py | 46 +++++ 25 files changed, 651 insertions(+), 61 deletions(-) create mode 100644 requirements/requirements.deepspeed.txt create mode 100644 sockeye/convert_deepspeed.py create mode 100644 test/data/deepspeed/converted/config create mode 100644 test/data/deepspeed/converted/params.00000 create mode 100644 test/data/deepspeed/model/config create mode 100644 test/data/deepspeed/model/params.00000/global_step4000/mp_rank_00_model_states.pt create mode 100644 test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_0_mp_rank_00_optim_states.pt create mode 100644 test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_1_mp_rank_00_optim_states.pt create mode 100644 
test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_2_mp_rank_00_optim_states.pt create mode 100644 test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_3_mp_rank_00_optim_states.pt create mode 100644 test/data/deepspeed/model/params.00000/latest create mode 100644 test/unit/test_deepspeed.py diff --git a/.github/workflows/push_pr.yml b/.github/workflows/push_pr.yml index 090481551..5a82e81b8 100644 --- a/.github/workflows/push_pr.yml +++ b/.github/workflows/push_pr.yml @@ -30,6 +30,8 @@ jobs: run: python -m pip install --upgrade pip - name: Sockeye requirements run: pip install -r requirements/requirements.txt + - name: DeepSpeed requirements + run: pip install -r requirements/requirements.deepspeed.txt - name: Development requirements run: pip install -r requirements/requirements.dev.txt - name: Unit tests diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index ed92c19c0..44e827301 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -24,6 +24,8 @@ jobs: pip install setuptools wheel twine - name: Sockeye requirements run: pip install -r requirements/requirements.txt + - name: DeepSpeed requirements + run: pip install -r requirements/requirements.deepspeed.txt - name: Development requirements run: pip install -r requirements/requirements.dev.txt - name: Unit tests diff --git a/CHANGELOG.md b/CHANGELOG.md index e85ef82db..b05f6ae93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,16 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [3.1.20] + +### Added + +- Added training support for [DeepSpeed](https://www.deepspeed.ai/). + - Installation: `pip install deepspeed` + - Usage: `deepspeed --no_python ... 
sockeye-train ...` + - DeepSpeed mode uses Zero Redundancy Optimizer (ZeRO) stage 1 ([Rajbhandari et al., 2019](https://arxiv.org/abs/1910.02054v3)). + - Run in FP16 mode with `--deepspeed-fp16` or BF16 mode with `--deepspeed-bf16`. + ## [3.1.19] ### Added diff --git a/MANIFEST.in b/MANIFEST.in index b6f8b881b..90991dfc9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -7,6 +7,7 @@ include .pylintrc include .flake8 include typechecked-files include test/data/config_with_missing_attributes.yaml +recursive-include test/data/deepspeed * include test/data/model_3.0.x/* include sockeye/git_version.py include *.bib diff --git a/requirements/requirements.deepspeed.txt b/requirements/requirements.deepspeed.txt new file mode 100644 index 000000000..0f2c819de --- /dev/null +++ b/requirements/requirements.deepspeed.txt @@ -0,0 +1 @@ +deepspeed diff --git a/setup.py b/setup.py index 85cea70ab..2a5334230 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ def get_requirements(filename): entry_points = { 'console_scripts': [ 'sockeye-average = sockeye.average:main', + 'sockeye-convert-deepspeed = sockeye.convert_deepspeed:main', 'sockeye-embeddings = sockeye.embeddings:main', 'sockeye-evaluate = sockeye.evaluate:main', 'sockeye-lexicon = sockeye.lexicon:main', diff --git a/sockeye/__init__.py b/sockeye/__init__.py index 00575616d..5e782d107 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '3.1.19' +__version__ = '3.1.20' diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 88551b1a7..d9fb1132a 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -1043,6 +1043,24 @@ def add_training_args(params): nargs='*', help="Manually specify names of parameters to fix during training. 
Default: %(default)s.") + # DeepSpeed arguments + train_params.add_argument('--local_rank', + type=int_greater_or_equal(0), + default=None, + help='The DeepSpeed launcher (`deepspeed`) automatically adds this argument. When it is ' + 'present, training runs in DeepSpeed mode. This argument does not need to be ' + 'specified manually.') + train_params.add_argument('--deepspeed-fp16', + action='store_true', + default=False, + help='Run the model in float16 mode with float32 master weights and dynamic loss ' + 'scaling. This is similar to --apex-amp. Default: %(default)s.') + train_params.add_argument('--deepspeed-bf16', + action='store_true', + default=False, + help='Run the model in bfloat16 mode, which does not require loss scaling. ' + 'Default: %(default)s.') + train_params.add_argument(C.TRAIN_ARGS_MONITOR_BLEU, default=500, type=int, diff --git a/sockeye/constants.py b/sockeye/constants.py index 7d2ecb03a..298d74c29 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -162,6 +162,7 @@ TRAINING_STATE_DIRNAME = "training_state" TRAINING_STATE_TEMP_DIRNAME = "tmp.training_state" TRAINING_STATE_TEMP_DELETENAME = "delete.training_state" +TRAINING_STATE_DEEPSPEED = "deepspeed" OPT_STATE_LAST = "optimizer_last.pkl" OPT_STATE_BEST = "optimizer_best.pkl" @@ -180,7 +181,7 @@ # Arguments that may differ and still resume training ARGS_MAY_DIFFER = ["device_id", "device_ids", "overwrite_output", "use_tensorboard", "quiet", "align_plot_prefix", "sure_align_threshold", "keep_last_params", "seed", "max_updates", "min_updates", "max_num_epochs", - "min_num_epochs", "max_samples", "min_samples", "max_checkpoints", "max_seconds"] + "min_num_epochs", "max_samples", "min_samples", "max_checkpoints", "max_seconds", "local_rank"] # Other argument constants TRAINING_ARG_SOURCE = "--source" diff --git a/sockeye/convert_deepspeed.py b/sockeye/convert_deepspeed.py new file mode 100644 index 000000000..0174b5193 --- /dev/null +++ b/sockeye/convert_deepspeed.py @@ -0,0 +1,90 @@ 
+# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import argparse +import gc +import logging +import os +import shutil + +from . import constants as C +from . import model + +try: + import deepspeed + import deepspeed.utils.zero_to_fp32 +except ImportError: + pass + + +logger = logging.getLogger(__name__) + + +def convert_checkpoint_to_params(model_config_fname: str, checkpoint_dirname: str, params_fname: str): + # Create a temporary SockeyeModel + model_config = model.SockeyeModel.load_config(model_config_fname) + sockeye_model = model.SockeyeModel(model_config) + # Gather the float32 params on CPU + state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dirname) + # Strip the first prefix from each param name to match the SockeyeModel + # Ex: 'model.encoder.layers...' -> 'encoder.layers...' + state_dict = {name[name.find('.') + 1:]: param for (name, param) in state_dict.items()} + # Load the float32 params. Use non-strict mode because shared and constant + # params are not included in the DeepSpeed-generated state dict. + sockeye_model.load_state_dict(state_dict, strict=False) + # Save the float32 params to disk + sockeye_model.save_parameters(params_fname) + # Cleanup + del sockeye_model + gc.collect() + + +def convert_model_checkpoints(model_dirname: str, keep_deepspeed: bool = False): + model_config_fname = os.path.join(model_dirname, C.CONFIG_NAME) + # Find and convert params.00000, etc. 
+ for fname in os.listdir(model_dirname): + if fname.startswith(C.PARAMS_PREFIX) and fname[len(C.PARAMS_PREFIX):].isdigit(): + params_fname = os.path.join(model_dirname, fname) + if os.path.isdir(params_fname): + logger.info(f'Converting checkpoint {params_fname}') + # Move directory checkpoint to e.g., params.00000.ds + checkpoint_dirname = params_fname + '.ds' + shutil.move(params_fname, checkpoint_dirname) + # Create params file for directory checkpoint + convert_checkpoint_to_params(model_config_fname, checkpoint_dirname, params_fname) + if not keep_deepspeed: + shutil.rmtree(checkpoint_dirname) + # Update params.best + params_best_fname = os.path.join(model_dirname, C.PARAMS_BEST_NAME) + if os.path.exists(params_best_fname) and os.path.islink(params_best_fname): + logger.info(f'Updating {params_best_fname}') + params_best_target = os.readlink(params_best_fname) + os.remove(params_best_fname) + os.symlink(params_best_target, params_best_fname) + + +def main(): + params = argparse.ArgumentParser( + description="Convert DeepSpeed checkpoints to regular parameter files in a Sockeye model directory.") + params.add_argument('--model', '-m', + required=True, + help='Model directory containing DeepSpeed checkpoints.') + params.add_argument('--keep-deepspeed', '-k', + action='store_true', + help='Keep DeepSpeed checkpoints (renamed e.g., params.00000.ds).') + args = params.parse_args() + convert_model_checkpoints(args.model, keep_deepspeed=args.keep_deepspeed) + + +if __name__ == "__main__": + main() diff --git a/sockeye/optimizers.py b/sockeye/optimizers.py index 6ee3b6678..b25a21695 100644 --- a/sockeye/optimizers.py +++ b/sockeye/optimizers.py @@ -61,8 +61,9 @@ def get_optimizer(config: OptimizerConfig) -> Tuple[Type[torch.optim.Optimizer], # https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html zero_grad_kwargs = {'set_to_none': True} - # Use Apex's fused optimizers if Apex is available - if config.running_on_gpu: + # Use Apex's fused optimizers if Apex 
is available and we aren't using + # DeepSpeed, which includes its own optimizers. + if config.running_on_gpu and not utils.using_deepspeed(): try: from apex.optimizers import FusedAdam, FusedSGD adam_impl = FusedAdam diff --git a/sockeye/train.py b/sockeye/train.py index 39b9b9623..2a5477a39 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -26,15 +26,24 @@ import shutil import sys import tempfile -from typing import cast, Callable, Optional, Dict, List, Tuple +from typing import Any, cast, Callable, Optional, Dict, List, Tuple, Type import torch import torch.distributed import torch.distributed.elastic.multiprocessing.errors +# Optional imports. Import errors are not an issue because these modules are +# only used when certain settings are activated. We check that these modules +# can be imported before activating the settings. +try: + import deepspeed +except ImportError: + pass + from . import arguments from . import checkpoint_decoder from . import constants as C +from . import convert_deepspeed from . import data_io from . import encoder from . import layers @@ -132,6 +141,14 @@ def check_arg_compatibility(args: argparse.Namespace): logger.warning('Specifying a non-float32 dtype to sockeye.train has no effect. Use --amp or --apex-amp for ' 'mixed precision training.') + if args.local_rank is not None: + check_condition(not args.amp and not args.apex_amp, 'DeepSpeed mode does not support --amp or --apex-amp. ' + 'Use --deepspeed-fp16 or --deepspeed-bf16.') + check_condition(not (args.learning_rate_scheduler_type == C.LR_SCHEDULER_PLATEAU_REDUCE + and not args.no_reload_on_learning_rate_reduce), + 'DeepSpeed mode does not support learning rate schedulers that reload checkpoints. 
Use a ' + 'different --learning-rate-scheduler-type or specify --no-reload-on-learning-rate-reduce.') + def check_resume(args: argparse.Namespace, output_folder: str) -> bool: """ @@ -207,6 +224,10 @@ def create_checkpoint_decoder( if sample_size == 0: return None + if utils.using_deepspeed(): + logger.info('Turning off checkpoint decoder when using DeepSpeed') + return None + cpd = checkpoint_decoder.CheckpointDecoder( model_folder=args.output, inputs=[args.validation_source] + args.validation_source_factors, @@ -775,6 +796,61 @@ def create_optimizer_config(args: argparse.Namespace) -> optimizers.OptimizerCon return config +def create_deepspeed_config(args: argparse.Namespace, + optimizer_config: optimizers.OptimizerConfig, + optimizer_class: Type[torch.optim.Optimizer], + optimizer_kwargs: Dict[str, Any]) -> Dict[str, Any]: + """ + Generates a DeepSpeed config dictionary from training arguments. See: + https://www.deepspeed.ai/docs/config-json/ + + :param args: Arguments as returned by argparse. + :param optimizer_config: Optimizer config. + :param optimizer_class: Optimizer class. + :param optimizer_kwargs: Optimizer kwargs. + + :return: Dictionary of config options that can be used to initialize the + DeepSpeed engine. 
+ """ + + utils.check_condition(utils.using_deepspeed(), 'Initialize DeepSpeed before generating the config') + + ds_config = { + 'train_micro_batch_size_per_gpu': args.batch_size, + 'gradient_accumulation_steps': args.update_interval, + 'optimizer': { + 'type': optimizer_class.__name__, + 'params': optimizer_kwargs, + }, + 'zero_optimization': { + 'stage': 1, + }, + 'steps_per_print': args.update_interval * args.checkpoint_interval, + } + + if args.deepspeed_fp16: + utils.update_dict(ds_config, { + 'fp16': { + 'enabled': True, + 'initial_scale_power': 18, + }, + }) + + if args.deepspeed_bf16: + utils.update_dict(ds_config, { + 'bf16': { + 'enabled': True, + }, + }) + + if optimizer_config.gradient_clipping_type != C.GRADIENT_CLIPPING_TYPE_NONE: + utils.update_dict(ds_config, { + 'gradient_clipping': optimizer_config.gradient_clipping_threshold, + }) + + return ds_config + + def unset_requires_grad_for_fixed_params(config: model.ModelConfig, params: Dict[str, torch.nn.parameter.Parameter], fixed_param_names: List[str], @@ -861,7 +937,19 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = each time a checkpoint has been reached """ - if args.dist: + # When running distributed training, initializing the process group is a + # prerequisite for all inter-process communication. + + # The DeepSpeed launcher automatically adds `--local_rank=N` to the CLI args + # when launching processes. When this arg is specified, run in DeepSpeed + # mode. 
+ if args.local_rank is not None: + utils.init_deepspeed() + check_condition(args.local_rank == utils.get_local_rank(), + f'Mismatch between local rank argument and environment variable: {args.local_rank} != ' + f'{utils.get_local_rank()}') + elif args.dist: + # Otherwise use PyTorch's standard distributed mode torch.distributed.init_process_group(torch.distributed.Backend.GLOO if args.use_cpu else torch.distributed.Backend.NCCL) @@ -984,7 +1072,12 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = model_config, clamp_to_dtype=args.clamp_to_dtype, train_decoder_only=args.fixed_param_strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER) - sockeye_model.to(device) + + # Move the model to the training device unless using DeepSpeed, which moves + # the model automatically. + if not utils.using_deepspeed(): + sockeye_model.to(device) + sockeye_model.apply(model.initialize_parameters) # Load starting parameters if specified @@ -1004,7 +1097,9 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = losses = create_losses(args, all_num_classes=target_vocab_sizes) optimizer_class, optimizer_kwargs, zero_grad_kwargs = optimizers.get_optimizer(optimizer_config) - optimizer = optimizer_class(sockeye_model.parameters(), **optimizer_kwargs) + # Create the optimizer unless using DeepSpeed, which handles its own + # optimizer creation. 
+ optimizer = optimizer_class(sockeye_model.parameters(), **optimizer_kwargs) if not utils.using_deepspeed() else None lr_scheduler_class, lr_scheduler_kwargs = lr_scheduler.get_lr_scheduler(args.learning_rate_scheduler_type, args.initial_learning_rate, @@ -1014,6 +1109,9 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = args.max_updates) _lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) if lr_scheduler_class is not None else None + ds_config = create_deepspeed_config(args, optimizer_config, + optimizer_class, optimizer_kwargs) if utils.using_deepspeed() else None + # This starts as a reference to the original Sockeye model. It is # sequentially transformed/wrapped to produce the model instance used for # training. @@ -1030,28 +1128,46 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = # https://nvidia.github.io/apex/amp.html#o2-almost-fp16-mixed-precision training_model, optimizer = apex.amp.initialize(training_model, optimizer, opt_level='O2') - logger.info('Tracing model on a validation batch') - batch = eval_iter.next().load(device=device) # pylint: disable=not-callable - # When using AMP, turn on autocasting when tracing the model so that - # dtypes will match during AMP training. Disable the weight cache for - # compatibility with tracing. See: - # https://github.com/pytorch/pytorch/pull/63552 - with torch.cuda.amp.autocast(cache_enabled=False) if args.amp else utils.no_context(): # type: ignore - training_model = torch.jit.trace(training_model, (batch.source, batch.source_length, - batch.target, batch.target_length), strict=False) - eval_iter.reset() - - if utils.is_distributed(): - # In distributed mode, wrap the traced model with a distributed - # data-parallel model that shares (averages) gradients with models - # in other worker processes. 
+ if utils.using_deepspeed(): + logger.info('Skipping SockeyeModel trace when using DeepSpeed') + else: + logger.info('Tracing SockeyeModel on a validation batch') + batch = eval_iter.next().load(device=device) # pylint: disable=not-callable + # When using AMP, turn on autocasting when tracing the model so that + # dtypes will match during AMP training. Disable the weight cache for + # compatibility with tracing. See: + # https://github.com/pytorch/pytorch/pull/63552 + with torch.cuda.amp.autocast(cache_enabled=False) if args.amp else utils.no_context(): # type: ignore + training_model = torch.jit.trace(training_model, (batch.source, batch.source_length, + batch.target, batch.target_length), strict=False) + eval_iter.reset() + + if utils.is_distributed() and not utils.using_deepspeed(): + # In distributed mode, wrap the model object with a distributed data- + # parallel container that shares (averages) gradients with models in + # other worker processes. This is not required when using DeepSpeed, + # which automatically handles model synchronization between processes. training_model = torch.nn.parallel.DistributedDataParallel(training_model, device_ids=None if args.use_cpu else [device], output_device=None if args.use_cpu else device) - # Final step: wrap training model and losses in a single module + # Wrap training model and losses in a single module model_object = training.ModelWithLoss(model=training_model, losses=losses) # type: torch.nn.Module + if utils.using_deepspeed(): + # Wrap the model object with a DeepSpeed engine that automatically + # handles many aspects of distributed training. + model_object, optimizer, _, _lr_scheduler = deepspeed.initialize(model=model_object, + model_parameters=sockeye_model.parameters(), + lr_scheduler=_lr_scheduler, + config=ds_config) + # At each time step, DeepSpeed calls `optimizer.step()` before + # `lr_scheduler.step()`. 
Adjust for this by stepping the learning rate + # scheduler once (from t=0 to t=1) before training starts so that optimizer step N uses the + # learning rate for t=N. `_lr_scheduler` is None when no scheduler class is configured, so guard the call. + if _lr_scheduler is not None: + _lr_scheduler.step() + trainer = training.EarlyStoppingTrainer( config=trainer_config, optimizer_config=optimizer_config, @@ -1079,6 +1195,17 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = training_state = trainer.fit(train_iter=train_iter, validation_iter=eval_iter, checkpoint_decoder=checkpoint_decoder) + if utils.using_deepspeed() and utils.is_primary_worker(): + # Free the memory used during training + del model_object + del sockeye_model + torch.cuda.empty_cache() + gc.collect() + # Convert parameter directories (DeepSpeed checkpoints) to parameter + # files (regular float32). This does not affect the DeepSpeed checkpoint + # stored as part of the training state that enables continuing training. + convert_deepspeed.convert_model_checkpoints(trainer_config.output_dir, keep_deepspeed=False) + return training_state diff --git a/sockeye/training.py b/sockeye/training.py index 4b59367a3..5e5b2cd9f 100644 --- a/sockeye/training.py +++ b/sockeye/training.py @@ -29,6 +29,9 @@ import torch import torch.distributed +# Optional imports. Import errors are not an issue because these modules are +# only used when certain settings are activated. We check that these modules +# can be imported before activating the settings. try: import apex.amp except ImportError: @@ -73,6 +76,10 @@ def forward(self, source: torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: model_outputs = self.model(source, source_length, target, target_length) + if utils.using_deepspeed(): + # Guarantee model outputs are float32 before computing losses. + # Computing losses in DeepSpeed float16 mode can lead to overflow. 
+ model_outputs = {output_name: output.to(torch.float32) for (output_name, output) in model_outputs.items()} loss_outputs = [loss_function(model_outputs, labels) for loss_function in self.losses] loss_values, num_samples = zip(*loss_outputs) sum_losses = sum(loss_values) if len(loss_values) > 1 else loss_values[0] @@ -225,7 +232,7 @@ def fit(self, if utils.is_primary_worker(): self.sockeye_model.save_config(self.config.output_dir) self.sockeye_model.save_version(self.config.output_dir) - self._save_params() + self._save_params(use_checkpoint=False) logger.info("Training started.") tic = time.time() @@ -322,13 +329,18 @@ def _create_checkpoint(self, checkpoint_decoder: checkpoint_decoder.CheckpointDe self._adjust_learning_rate(has_improved) if utils.is_primary_worker(): self._write_and_log_metrics(train_metrics=train_metrics, val_metrics=val_metrics) + # When using DeepSpeed, all workers participate in saving the training + # state and model parameters. Otherwise these methods are a no-op for + # secondary workers. 
self._save_training_state(train_iter) - self._save_params() + self._save_params(use_checkpoint=True) if utils.is_primary_worker(): if has_improved: self._update_best_params() - self._save_optimizer_state(self.best_optimizer_state_fname) - self._save_lr_scheduler(self.best_lr_scheduler_fname) + if not utils.using_deepspeed(): + # DeepSpeed mode does not support checkpoint reloading + self._save_optimizer_state(self.best_optimizer_state_fname) + self._save_lr_scheduler(self.best_lr_scheduler_fname) for metric in train_metrics: metric.reset() if self.checkpoint_callback: @@ -349,25 +361,29 @@ def _forward_backward(self, batch: data_io.Batch, is_update_batch: bool = True): sum_losses, loss_values, num_samples = self.model_object(batch.source, batch.source_length, batch.target, batch.target_length, batch.labels) # Backward - if self.config.update_interval > 1: - # Scale loss by number of batches per update - # TODO(mdenkows): We currently give equal weight to every batch in - # every update but batches have subtly different sizes (different - # numbers of padding tokens). Consider normalizing by relative batch - # size. - sum_losses = sum_losses / self.config.update_interval - if self.using_amp: - # PyTorch AMP loss scaling - sum_losses = self._scaler.scale(sum_losses) - if self.using_apex_amp: - # Apex AMP loss scaling - with apex.amp.scale_loss(sum_losses, self.optimizer, - delay_unscale=not is_update_batch) as scaled_sum_losses: - # Apex AMP backward - scaled_sum_losses.backward() + if utils.using_deepspeed(): + # DeepSpeed backward. DeepSpeed handles all loss scaling. + self.model_object.backward(sum_losses) # type: ignore else: - # PyTorch (with/without AMP) backward - sum_losses.backward() # type: ignore + if self.config.update_interval > 1: + # Scale loss by number of batches per update + # TODO(mdenkows): We currently give equal weight to every batch + # in every update but batches have subtly different sizes + # (different numbers of padding tokens). 
Consider normalizing by + # relative batch size. + sum_losses = sum_losses / self.config.update_interval + if self.using_amp: + # PyTorch AMP loss scaling + sum_losses = self._scaler.scale(sum_losses) + if self.using_apex_amp: + # Apex AMP loss scaling + with apex.amp.scale_loss(sum_losses, self.optimizer, + delay_unscale=not is_update_batch) as scaled_sum_losses: + # Apex AMP backward + scaled_sum_losses.backward() + else: + # PyTorch (with/without AMP) backward + sum_losses.backward() # type: ignore return loss_values, num_samples def _step(self, batch: data_io.Batch) -> bool: @@ -384,13 +400,15 @@ def _step(self, batch: data_io.Batch) -> bool: # average the accumulated gradients across workers during the update # batch. with (self.model_object.model.no_sync() if utils.is_distributed() and not is_update_batch # type: ignore - else utils.no_context()): + and not utils.using_deepspeed() else utils.no_context()): loss_values, num_samples = self._forward_backward(batch, is_update_batch) for loss_func, loss_value, num_samples in zip(self.loss_functions, loss_values, num_samples): loss_func.metric.update(loss_value.item(), num_samples.item()) - if is_update_batch: + if utils.using_deepspeed(): + self.model_object.step() # type: ignore + elif is_update_batch: if self.using_amp: self._scaler.unscale_(self.optimizer) # Clip gradients @@ -435,6 +453,8 @@ def _evaluate(self, checkpoint: int, data_iter, # fully support switching between train and eval modes depending # how much Python logic is used in the various submodules. 
outputs = self.sockeye_model(batch.source, batch.source_length, batch.target, batch.target_length) + # Guarantee model outputs are float32 before computing losses + outputs = {name: output.to(torch.float32) for (name, output) in outputs.items()} # Loss loss_outputs = [loss_function(outputs, batch.labels) for loss_function in self.loss_functions] # Update validation metrics for batch @@ -617,13 +637,31 @@ def _update_best_params(self): os.symlink(actual_best_params_fname, self.best_params_fname) logger.info("'%s' now points to '%s'", self.best_params_fname, actual_best_params_fname) - def _save_params(self): + def _save_params(self, use_checkpoint: bool = False): """ Saves model parameters at current checkpoint and optionally cleans up older parameter files to save disk space. + + :param use_checkpoint: When using DeepSpeed, copy files from the latest + checkpoint instead of creating a new checkpoint. """ - if utils.is_primary_worker(): + if utils.using_deepspeed(): + # Copy or create a DeepSpeed checkpoint that can be used to generate + # a regular Sockeye parameter file at the end of training. + if use_checkpoint: + if utils.is_primary_worker(): + shutil.copytree(src=os.path.join(self.training_state_dirname, C.TRAINING_STATE_DEEPSPEED), + dst=self.current_params_fname) + else: + if utils.is_primary_worker() and not os.path.exists(self.current_params_fname): + os.mkdir(self.current_params_fname) + torch.distributed.barrier() + # All workers save their local shards of the float32 parameters. 
+ self.model_object.save_checkpoint(self.current_params_fname) # type: ignore + elif utils.is_primary_worker(): self.sockeye_model.save_parameters(self.current_params_fname) + if utils.is_primary_worker(): + # With or without DeepSpeed cleanup_params_files(self.config.output_dir, self.config.max_params_files_to_keep, self.state.checkpoint, self.state.best_checkpoint, self.config.keep_initializations, self.config.max_params_files_to_cache, self.config.cache_metric, @@ -658,7 +696,13 @@ def _save_training_state(self, train_iter: data_io.BaseParallelSampleIter): if utils.is_distributed(): torch.distributed.barrier() - if utils.is_primary_worker(): + if utils.using_deepspeed(): + # DeepSpeed saves parameters, optimizer state, and learning rate + # scheduler in a single checkpoint file. All workers need to call + # `save_checkpoint()`. + self.model_object.save_checkpoint(os.path.join(training_state_dirname, # type: ignore + C.TRAINING_STATE_DEEPSPEED)) + elif utils.is_primary_worker(): # Otherwise, only the primary worker saves the following. # (1) Parameters: link current file params_base_fname = C.PARAMS_NAME % self.state.checkpoint @@ -721,17 +765,24 @@ def _load_training_state(self, train_iter: data_io.BaseParallelSampleIter): Loads the full training state from disk. :param train_iter: training data iterator. """ - # (1) Parameters - params_fname = os.path.join(self.training_state_dirname, C.TRAINING_STATE_PARAMS_NAME) - self.sockeye_model.load_parameters(params_fname, device=self.device, allow_missing=False, ignore_extra=False) + if utils.using_deepspeed(): + # DeepSpeed loads parameters, optimizer state, and learning rate + # scheduler from a single checkpoint file. 
+ _, _ = self.model_object.load_checkpoint(os.path.join(self.training_state_dirname, # type: ignore + C.TRAINING_STATE_DEEPSPEED)) + else: + # (1) Parameters + params_fname = os.path.join(self.training_state_dirname, C.TRAINING_STATE_PARAMS_NAME) + self.sockeye_model.load_parameters(params_fname, device=self.device, + allow_missing=False, ignore_extra=False) - # (2) Optimizer states - opt_state_fname = os.path.join(self.training_state_dirname, C.OPT_STATE_LAST) - self._load_optimizer_state(opt_state_fname) + # (2) Optimizer states + opt_state_fname = os.path.join(self.training_state_dirname, C.OPT_STATE_LAST) + self._load_optimizer_state(opt_state_fname) - # (3) lr_scheduler - lr_scheduler_fname = os.path.join(self.training_state_dirname, C.LR_SCHEDULER_LAST) - self._load_lr_scheduler(lr_scheduler_fname) + # (3) lr_scheduler + lr_scheduler_fname = os.path.join(self.training_state_dirname, C.LR_SCHEDULER_LAST) + self._load_lr_scheduler(lr_scheduler_fname) # (4) Data Iterator train_iter.load_state(os.path.join(self.training_state_dirname, C.BUCKET_ITER_STATE_NAME)) @@ -963,7 +1014,11 @@ def cleanup_params_files(output_folder: str, max_to_keep: int, checkpoint: int, param_fname_n = params_name_with_dir % n if param_fname_n in existing_files and n not in top_n: try: - os.remove(param_fname_n) + if os.path.isdir(param_fname_n): + # DeepSpeed mode initially saves checkpoint directories + shutil.rmtree(param_fname_n) + else: + os.remove(param_fname_n) except FileNotFoundError: # This can be occur on file systems with higher latency, # such as distributed file systems. While repeated diff --git a/sockeye/utils.py b/sockeye/utils.py index 81ed4809c..2a4557c91 100644 --- a/sockeye/utils.py +++ b/sockeye/utils.py @@ -33,6 +33,14 @@ import torch as pt import torch.distributed +# Optional imports. Import errors are not an issue because these modules are +# only used when certain settings are activated. 
We check that these modules +# can be imported before activating the settings. +try: + import deepspeed +except ImportError: + pass + from . import __version__, constants as C from .log import log_sockeye_version, log_torch_version @@ -667,6 +675,31 @@ def all_gather_object(obj: T) -> List[T]: return obj_list +# Track whether DeepSpeed has been initialized +_using_deepspeed = False + +def init_deepspeed(): + """ + Make sure all of the DeepSpeed modules we use can be imported, initialize + DeepSpeed, and set the global variable that tracks initialization. + + """ + global _using_deepspeed + try: + import deepspeed # pylint: disable=E0401 + import deepspeed.utils.zero_to_fp32 # pylint: disable=E0401 + deepspeed.init_distributed() + _using_deepspeed = True + except: + raise RuntimeError('To train models with DeepSpeed (https://www.deepspeed.ai/), ' + 'install the module with `pip install deepspeed`.') + + +def using_deepspeed() -> bool: + """Check whether DeepSpeed has been initialized via this module""" + return _using_deepspeed + + def count_seq_len(sample: str, count_type: str = 'char', replace_tokens: Optional[List] = None) -> int: """ Count sequence length, after replacing (optional) token/s. 
diff --git a/test/data/deepspeed/converted/config b/test/data/deepspeed/converted/config new file mode 100644 index 000000000..b59515bfb --- /dev/null +++ b/test/data/deepspeed/converted/config @@ -0,0 +1,99 @@ +!ModelConfig +config_data: !DataConfig + data_statistics: !DataStatistics + average_len_target_per_bucket: + - 5.996365000000221 + buckets: + - !!python/tuple + - 10 + - 10 + length_ratio_mean: 1.0 + length_ratio_stats_per_bucket: + - !!python/tuple + - 1.0 + - 0.0 + length_ratio_std: 0.0 + max_observed_len_source: 10 + max_observed_len_target: 10 + num_discarded: 0 + num_sents: 1000000 + num_sents_per_bucket: + - 1000000 + num_tokens_source: 5996365 + num_tokens_target: 5996365 + num_unks_source: 0 + num_unks_target: 0 + size_vocab_source: 16 + size_vocab_target: 16 + max_seq_len_source: 10 + max_seq_len_target: 10 + num_source_factors: 2 + num_target_factors: 2 +config_decoder: !TransformerConfig + act_type: relu + attention_heads: 4 + decoder_type: ssru_transformer + depth_key_value: 16 + dropout_act: 0.1 + dropout_attention: 0.1 + dropout_prepost: 0.1 + feed_forward_num_hidden: 16 + max_seq_len_source: 10 + max_seq_len_target: 10 + model_size: 16 + num_layers: 1 + positional_embedding_type: fixed + postprocess_sequence: dr + preprocess_sequence: n + use_glu: false + use_lhuc: false +config_embed_source: !EmbeddingConfig + allow_sparse_grad: false + dropout: 0.0 + factor_configs: + - !FactorConfig + combine: sum + num_embed: 16 + share_embedding: false + vocab_size: 8 + num_embed: 16 + num_factors: 2 + vocab_size: 16 +config_embed_target: !EmbeddingConfig + allow_sparse_grad: false + dropout: 0.0 + factor_configs: + - !FactorConfig + combine: sum + num_embed: 16 + share_embedding: false + vocab_size: 8 + num_embed: 16 + num_factors: 2 + vocab_size: 16 +config_encoder: !TransformerConfig + act_type: relu + attention_heads: 4 + decoder_type: ssru_transformer + depth_key_value: 16 + dropout_act: 0.1 + dropout_attention: 0.1 + dropout_prepost: 0.1 + 
feed_forward_num_hidden: 16 + max_seq_len_source: 10 + max_seq_len_target: 10 + model_size: 16 + num_layers: 1 + positional_embedding_type: fixed + postprocess_sequence: dr + preprocess_sequence: n + use_glu: false + use_lhuc: false +config_length_task: null +dtype: float32 +lhuc: false +neural_vocab_selection: null +neural_vocab_selection_block_loss: false +vocab_source_size: 16 +vocab_target_size: 16 +weight_tying_type: src_trg_softmax diff --git a/test/data/deepspeed/converted/params.00000 b/test/data/deepspeed/converted/params.00000 new file mode 100644 index 0000000000000000000000000000000000000000..c391eddbf74899346a8300b2ba9d1fc1c3a9d3c4 GIT binary patch literal 30707 zcmeFZc{r8b_djlkB4sE=hE#?qG9|;k*G)>vP&8>kW=c35X;7v_h7^iK18G!Ip*VZp zm1xkYIi-Onn$tXe4|<;8$NPB?pZD|ax_zo39XK2Ow#A1XvN@QmaK1xHAGhWmL(1_k?i^8G?0Lc={58B6oT zh7Wh^DFllUKIs+U#}n_WB5fBO>cyYhbyJ$x!%bB9j3H0LO%NC51^Mu$dD6PByxx4Cj4Q8? zo6bLz-`}f)+_O`S2m`}EwCEcs6f#geuuBLdo@}6`LoA;s*TF5MUB#2{;#OE4*g?^c z&+G5XQ~U=L@x8(W{P=%0(SSd|10@2xbQ<`lPD)+4%3ZjF_&gO?p6Wk{{tNeCh_3br zdZ+!2dFoxF5B^hhjV_8IT@*w4ykV|9O*hpJ>xB4(3XNwF78=oIk&c%B!z5bWa7}pH zLfox7GbA*8;U9$h-3U#220}tPK5tZ~trB^L4x))Xqb@RIKF`FJXX zccU83GwY%n^NY&7gUX_d%978sa^*2@`u~jTA9V8f_Y95X|Dg}*MqmU;6^Col<8Fmm&>>{%7BAUeKIk@s1 z-E{vMQAbnx`~Q>Xaq7lm&U5av+2mg~o6=!3mo5TVK5wckZ(3J3{?*62^@Z*PolP>m z8@vT?Mi>0dU+``n@Uyz$XY+Y;TzPZd)c=|MLTG<~qd!dJ{wJ!K@L?lEp2x37@cgq8 z=5^V|s|(wk&+~ER`F49f{uKQecBj04-I0xW{#^nG{1P~@L*Sq;(D{7c0#{z}zY*B@ z4}lkUhc)Jf{1Q0yPl3a_1m<;Nhx2(6t~~y~5!l$EQ{c$%$R@l+T>?k_5_oZkz)QM7 zm-2bbTzS#m8qhx|#4Fe{EIic5FCyZf^pEA;@lAOvy6|Iu;m3C1$93Vy^LYubyp?WR z|Dy%}C4Q&$iQUmh^Hz1CC;dWC?m%DNg`UFarMmLcI!)s1_itQWdUsGW-WnmO(8pzT zfOh)0wOzJe*9Dx(=ViI_vfYe@V*dgBXXo&W6b2T*fN;;wI|lxt9l^jez>DwqN87IN zPB(_P;aA({bSlyDqULt-ZS3O9O>Q*n0|3*L@jI_>;Zr(In-?e3!5!{_aF<&|{CiN8>F zIG=y8W2ZH{#fg30D2)wy`@0l7@Jq4M4#mp4*bee}hg^AwI|Idk$<}Ga^6p$lyoxTa z%3oYZI=GH@aaHko)vi2&o9RDmtq#qE;r8DKK-!(mn1@|t+%K|Y9c0J5$WHKiHLkpq z|GX>yTe42QYP(aJ@anp#PW_@f-9dGxi|Q<&SMSO@*BO`oV#Dq_bq|2&yOWsmF8t~n 
z7du%x`unA>zHzyWse#YC;>x?~ruWY*_!lPLAARszcb?I_>%Vwz{K<2(i|1As&uu=h z(Uo_{ZN&eIXThRBi0*bLGUGLM5Q*`5_d1A#y>sQYcZRON;CGL;@4LfV@IL%%hL3+X!>2CmeD1>j!smT;<$desr~h=P|1;Kp z?}lt*$otVH@Xs!R!$pKM2cgX2qQYw@bhwzX#19u2=C0vAg!#Xb|BuicE+KrbJCRYi zWXHYEF?)DVVet<|!g~q-=+q@#N?7XRlopow;k|{qYq*Rs|2KMdh2-!)!l$}38i)7o zxYsf84VM*`ekmv?EOsa;FT8e(gx4tA^ez>tPcMUfY=5A`j>c1v*|5K0uZ%P_&DtxRvn_2ki zj(Z(5({M9kse`R!avDBH_(zvc=E71JqlK`<54RNNuHja~+)d?QGj?jn2p{UsHzu5P z+>>y3voF_sk{$4Fxr_} zKTnTl?mo>-v5Y4Z@1&ApGaHyC7l)C;Tb4vZB!%9oGhhVe7w8oiQJkv17Jb~Mnc^ue zbh-6<7(cd-7+(8ItWSp`sEUB6QVwJN-U|zkr7+(F8gTXYGO{JA;2qEhs45G@M z$K^%%VeKia%%V0DxI+Y8eu&T*{w1h7xt}l>q}jY%yU=)4FLda;K_HUfh%td@U`KHo zs%@JHPMUj3&5`{Wcu1RgC-r4Rr)R)`uv#3FIt5p+tt94RLCjkDK{&){G-+Y9Fm>`7 zm~1!z4#mV_d`l&>!t46)`dPx(1~@QK2&&k`}29ZppaHrNb4RfE2Tm!V~ZF3K&H=6Y$aA*Pb&>BBzOP$*J>8yluDA+~)`-)TC` zdnJQ+eEBHRzaQ)^C?HxNS#1AM06Uf^K<2J&P;oQ0I+}41*81vl2X$q!ybm8COhRxukTvdGA)QX-K(h$xMmO>S-OWc3X}(2Td$(q8*4 zuVtuE1*L4k8fzn(#w8Ma(*Klah$sDzfb)Rq(YkdH|lQkLU7VLn&~mLH%5G^#KGd{sOgk?rak=- zmR|^>*Ka(xx=?(JIOe(Ho7)n2CEFL2>}Ju$vkx%q7XA>3x_V*DOc@%w%Mx#8rQ=zj zHT3xA8+3|JKlJq4L9cE^gA(#aeHG+1_b*E&O|gbMvvUdcbCCT{)m2p zA+l5Hpy0ujNXb|ou=QomMym>jsKwEwfqL{qn-0ly2&o=1W({rZmq4Aq^`K&=adcMu z4062QjZ_DHuHNXtl9UohCi75};Eq~G^}>U7g5*9&1xMNqNk{)Vr~7ESNJoqprhnak zWJPTMq5lY@7yEw~y=Ko=w^o8VZ`Lw%=Co7mQZt$`;RqPFRiefI6iAvrf(&!^ht`P6 z;1OI3oMACm{+KVQ7!r%=y;HyeuArLYclt<$#m12?;C(v;7oJ~5Y6D|2YMm2AagVEQ zR&646jahicViLK2VLCLyDrhw>B&jQ!aC?c8Ks5gbC=5}7lKR8=sc9SbcvJ@q1ryP~ z?zEtvUkD5x5DYDHO6=in12E|LLr89_#I5lc=~JzBH%0^j>f5i5g7e87AF^pvHm%k*r0fmjJM0D51h_GOGG~PZ{JN1UfKzAeqk8czKE!~ z_GJx1E3t9hFgAbwDRTG4DC$&s7KTQZfWDM~hRsWXD;1P}P7#5Ihi9o{`e=H}tb|;i zl|pO>A151TsNl+|Rj{gN0y)v|8-(~W(BPQ{^Bd<@PmwhiwE8v^W?uz-+^a#h*;nFb z*{fhQbs@|Nn+`T_Okst7kLpGbfZa#*AaQ6FT4(Ks9q!^d`f(K9Iqosh$PI#mYKp7X z&k@1snGo@6D2{C%LDnyCC*@zF$j&@@OzLmMZLnPjLCdG0__kf3xA`vGxeO))M2z5t z?O8ZG(i6W)J|`)Iwo$9_9H!ScWzZ>mBKUDv6d<9Vy!$$Y7;jgg)yvaRyyzz#FIy%E z+O(CPIo^ghybWO2W(_pkatq5UzJiP6dOGyW8q7V`Kr%<^LJVVsM=~v7_MYB!u&5ba 
zopy-ndp{Gjmpq15TO4S(|66ik|3|uH#SnTtZ8~)Lp8&`7EB`ru{nh^n1K9ty|M}YI z7|xw!jtwJ|aNwGJES`J{Bc?A!-G@$i;L%rfJv0%o)oWup# zxM}rxnAb0on2$Px8JqJ-{g?yv&dX}5UoY%y`F?P~?uNi|##WSXd`jAL`oUrI^>`?9 z8}LJ?jHJjek|itcNVo@4Pp8DaWMADN^}WGgC%O&5D_{W#(eAt zcNU1SS5O9+Gh*;DDgh>aKaHXdLj+M(4>9YWA*g3u6AW^W1KrWpcp`Z)4E61WBHv8W z{K7f>+1CVBC+<85n;#2cL}3XNsFD z@aD8y`t^-14v=4*zeaRycUv=;&fMng$e7E^h$Cs%fAEzC?Q z0k0Xc)LZWd4frq<$DF)K#~Mjfe&kUaesU;yl}-aoe+?YE+YG%V-K)nB7!D(}Lvdq| zG%B~#A9}jF;Lh4&q&*js!cSuvnetNVJXZ&7tPYdbBZPhNxt)Bu9*^;pCXmISYN_Nq zp}%V!D$uE<7&c}gu1J{&pO1*JukR#dMz%8d`O#9$UME7H6;GnGq?>Tmv=ETKqej)` zj-vLQQz+*o#ftTahuq0mP|h+9*V#(3%KPd_){E_^?=cJwd%pnZY#FY4-b&mWeHoI^ zB;ox!Y21JP1-UkB8`981++^*o! z4%9|wJ)ZfdfP=(E$Q3(a%ALj79IplVeCI7vvosb|UkwmgxD2Hy((Ezps3ELaZjNml zdPrImu<3yVOmfYF-jgoDwo4XpYV0@q=5+>U9o9xiCm*O$*aEj6C(yQCec_&N8yHO= zj#Djmpbg9;^M)5<)1=2ZDk7D5;0vldT9+6cOrUp^y-4EmjhJmw3mUos=(M+j&aIhA z0_)Dv5#t^Z$!&SiwR^XP^ z^WgTUYh?Y@@w6sNg05Vdjf>s{l8?NfaBRE?+gAUO*ngHJ?klHv`ky%?-qXM1KTUpv z{|POyCSPQvKto=VNpU_#E#KG($CT5^jfK0&{CT}#$hmmSo?m-b*LgV6XL^f>|DAL^ z@ACYJRf{*?h#b*5{(LrWa@*hWpQgV-{=(2ac0=ny_U+^C>|VEgR?~3{o2gpFF4(k} zy>zaQeKmS5yMrlbH9yv}hx1C=iXX*n?ect9W5y144qMK4`2U#N8-9OB{?Wfde#gT* z=*`Rm=D@qBL{4r#DJ~g89mCF&<;*fdcr&40`^=70FavLYh?Iu*~K$%4feh>AQ%=EjMVnb(IgLtwI1 zh(dzpy!v7b8y9JIzH2q4ev4sCe#Wq$gz2X+Rh|lDRi+JOKZiQNnXotl zSBsdtO%$Y5Iw3PiE9m76)8^f0Vh+&(Ab-xQ!O`IH@keb3S zbk2d}KaZ2aN8dB@Au3GPrg2t%$`)H{hb`yaMk%td*7DiOw^pzT4`bLq-(%QvVciZH zE>at=NsZ6m@Hz#n4Md@_ei*sHOlI5^60F?oi!BYO>u_ZTd9d^B0ye09Ih$P{!=CsO z!!8xpeJ4zE(UDw?P^WQG??GLt`TGh-5|6Z+S-7sidiOjw?m{%H zykE%qF@`k|)@>7}G}#PvUL4K_rOC0O8a8k_Kaxxc-o(t>bka(q?33k@^O5XAb~_r| zPi30}maz&cF|28O4BH^I!9!tvJHz2@$Wuo)qD_-^xiSxoA%pC2EoV-7KC*f)FKRWv3(5dDqQcIvZsh-m|tZ2 zDTp)E>b0yseD9Bs*12$=bBtNB+fl5)K@8igErxv}wBcD{ePhup@N)2aP)nZ38l1=g zR^bSFt@ezu4=dP3k7C$a!ukSX{f_rHz16if zRyg$dzkHHz^ zh3gcbkI+@aeb|YQjX2V4Bkd*f6#3FCiBZE|tQvP-P@tMfCJnAdg_gOX)Ub!}AFKqU z1UlF#Xc|wokm$W#14es1k1`cW^pb%P*OY#=rPl|N1+&b#CC>hA8+py4^f&Ge{ 
zz&w8yQ6KmWD&{PqkM}B~ZN3X})+!+SFVh6i-XqmX6~S#i-%?*Tfy}RX0BJfIWb)op zd?nb&j1yO)o57r|xF&{UTw2MZ{wmy&+cNM;)En0@<8bRFJ8tLLBsgAh0J)TOGHmxi zuwP$J0t)3g*K;69y%dH)+Z?fU#2Y$s;vs5KG8F5y{b`v_B91Iit5P1~cib4We{tGlQp|l~arNS7DK?0ouu|g+RS$ zR90a$?#^9_?}HA2^ip{+YgA_y_~)y(kLkg_YhTZ}4%LRT02Y7LJ%hzZ^|-a;)}UQN z1LiD_CrwkLVQ2qEkk>8=jh-Sn%l#xNnVSgjT1R3`T^&6WE05uF2O)flHcNf((;H*g z!qKfg0leLqTBZJ2AIittjXfB>pLRH`CK0ZyWzz#=zroG+R2X({4CqN*M1Qe#YWTpB z%J}zXrxq#V=A^S=^urcz3CBf|d&S{GO)~cL)~6of?l5^(GJZPQi(IX>gI484>eH~2 zmX@x7<;qtwAM$R9L`=EVBrvK7NI(BeFZ1ZCm8 z z%A-lJsZO6U@99O{j4j}H{UA1Ip9v&P(twqaB^Yt$JarooRK2XPrm+em%m$7jD%rtyI@DbkraTz>i?!ln z=9~~=?`~K}9PiD=@?HB;f7c3l*18ovoa3<3t_SEZb%H0{UTom?<$5$KaKY-i(EsvB ztMQfZDLlVT2Sm!THv*(s`FTp*go?YusgnrT;eTGYCPe(5@tgT?@IP;t?8p9JYZx$F z4-1BTMbE40q&14AJ?}~#{jyCTWF{L@iI5D&;9D-PZ+|ncBoV!k~a37-xP9(|Upn}!#*^c4Ve0RXJ8W^#<27G0Pc|B0OS%Td#ax>0p9>bND&%%I| ze0Xc72ug20GaiRtR4;6Ofj>?ThEdb=VczL@od38dJg$CE_3N(FQ%i*NnMcXkxIhBy zjI$u=_(Pg5VZuhLjKx)j-$666l}w#e$M6@+f>E6b+Zucdb}Lk&X3H`x)s-ZR6OW>< z)hu|kk>d8&dYEsq2@O`?!T`5&(z`OBw(J;6Hv|uc0VNWULPycreNQv^T8vDpjE6Jt z6uHs9Bf*l!V68zR6og)aSNUVfgFa3mBN~CDw(W$u#{H=K{H?+@g|nCwy8u)U9|r&W za@3kF%8fi>4Q2JY%$}^T@LuTzM63-(Rok^V_Q*Km(>52p1CxLrw1aRe^OZHtbXQ?#ey_p9hxWnI_jBmh{jcD_iFdGi>}d9F^gf*aVGy@eDGv2E&!I(9 zPe}gCY#1BCs%2m8pW{^Pg-{RY8DNfByMaR_c_>a$njB-X6Bj&f4fg);2FNR(sZr_rpUy=2YG7Wi;~Ej6HaTwv`jkcgi_b$cbEMB7yS5cL*pZtW1b zWqL!%=UC9qlID6&Hlckqze7@tD5w8y8&1@|ifb0Fz%bizFxT`2=P#*nxOxRt&z_8K zexKpzCSA(torADecW_&{7~aH)kl@Dg@TFz~^=40W&VTh4hUorh|Jhgg$p8F4z$oGJ zn~2EG{#yiRdW<4o%cE&v@K(WBQ&swG{8ajG!!kNl)s!|vEwlHr3f0;hM|Ks@6Rej= zVwge~l3<>|G{jj@+xNo+;tN=T**a00{A2`^-g`D37aPhrwc1iz<3*HqNeUXCET@q# z-KZ9C1%2jw%xZhqGOLKyax`mlZ<=QYbb{IlbWWQr*nR5?v*<>f;PmDXR(`GLN#vxP zbo$UBR2L{>77IcGS2rLGfH_diLMZZd+Q0RWT%s3 zC!*+&f@nd!ejixiDTg`=^ZnNEIjjkN5E#@OO2qTwS+ z<+j-%PLPB;n-IDp>6xIXh7$SOe^%>A#3y#=m3#Yk%=xcDYeL z`}E>Q_QUp#tdTylUK2{$p`~kC^-amFwM!PO{%!}`b6X1Qbb1Zjs$0rFwb;N;^T}sx zUzK#uUu<=kNq5BFj`ZKi|2x*t`#hh9>(=gu{MT`qR`8lCZ81Uxr}fzH`+9mY?G}2> 
z?*sOp*_i$)5^nP1(4?(_u9nmQjfa&GbgXOrJZxml>AxfY?^r+I+eep8mR!gN2D-EI z%Szc}9=_}c%hjx*)It zGg5f|cjW&a>u(wcKPVS21LV9v+9fV$EK<_J)x#Ygy0;0+)$DQ0+p#phG?xq=l}}zj z%qEHN_tHa_qsa2{>n(fh%65+5ik7+E`aAOfj`cUoHAc{SUwHb!Gz??*62ZnxjVLiC z97T5-b1k-_?D@NvxNK$=6e`Pehn>n{#@BtATdp7Lp-CIScIXbYpKg;3H(ppe15#!1l>Oj)k^a>i^b`t0yutW zgJ4JVX!=t|4URYTgqfG*xM?Go;wqh!pj@TT?aFq8VNWG-XKoT+{#gk+OZ`!~bUeFp zd|z&kfjk)HRDf?*urL=rU{kW_Jlr zw}^1;q(STqyZO-j<12=rD+AxP63LeNWhnM+An3|#5)&po10-^>U0+B_QFQ6!#TFL+*^S@xAlMn?p$FCZjQitFVDiaR)w%(=s*k- z&MD&hGuYcMMPQ!34sB=Nr^7YF(QTG0*K@EaC#sMpnA}oJ(P1e#kE?}VD(~^Y3m(2X zzX*>n&qSL+3hedW3FM&8Tk_(x5+q;giDQiik%${xafU(zxX<;({p zTBoon^(&s5;YLrX_2TM&tfBSI&oS|OG0dN)z@@KEKxR@kZe%AyO}ha%{lronW^RVb zFPzxjVzF?c-$;zT+JdzWS;Dh$MG&^)iutEg!yzqRjD0iaHQaNNV5Qf-s9unq!QE{g zMiNXPkOk~QycPHcONR6zuNS@s!LY&D&_4q(U=`d3;TX=m6t}ymaDB?gKzX<<8>PIj z)BmiQd}Pjl9>4u%|JV3;te?+*eok<9?h8_Ttrt7=@_pKncnB8`7Wx~TROZPhUAUxP z1aEHl;-0UJhoR%DsrbSG*mG5v2KqgthWGEm(xl<|T|EWoO8aB>K7H!Ka}^wXu!wpr zolg23mSbOpW;5-Nbp-iU!t)rzq=8ONh4cLrVfwZNu&kU3_n!f_hwPz8qpX3SJce!d z2*UddoZ(pfF048NjOBq~GT4cu&$r(o52HiKTy6wNi#)=gt7uTMl)02}iVpA9_Qg$Q1i!1?7Un3QKkS6x^{dc-~j(+`$Z;^#ovh}-eN zRB=>k4#e2fcu=2I0q!=fwC}dtWJQz~9LlqYjMhXnbr8_<9>?*O-&rR2t}`f}ZWOHA zeGab_KP2~+t8fH=9~{fACeysu;Uk|=OfT;TTB(Yl_UtFNyNzKE?D))VbKQhH9kxQX z|6YN@X-n`{X{DPUNt4Qo2~glS53M7HKuYCXqEJ5(qSyH01n+#Lwj$8mM+}VnF2fj? 
zblC4#Nzh{fG#ZS@(pE$IWB5K;;Nyc4Ho17)SW2*YjVIBdWXA z8xP*;1rbjLV6)nr+F1W2%(7LuN~#xo!0HXI(dY**z2vw}DROj$r3F{2AxHZbJ3+Hy zG33R)1^bYb^!?M5^u(L7IJEy&Dt2THOtINR9&8%U8SE0C<&%>FsFs3PqJ!w$_8h3% ztIudP?xig{Zg}SIK9X>}kQSW3ENBtC0`tD~;oRHiz{`{SY0W+_Y-=+mug)*PhA&fC z_wvtVxd9<-i45EEnraQ8S%xL004HmhwX)8D=zB8M#T{B?QVDeIOUXsAH*|OpWp2%_M7*W1kLRxyLhH(G^61oOQYUH%lV13vT>KSin=L`$ zd_I!c5$N{QRNym>B|7&5NY%=G`XTE8GyB;=%rF_q?o}liV-tdjUO{l$ZV~49oG*M= zS7Cziyj)m)1nTN$;oP-Zq~!K>s_&kOl|ipT@uDQCwyGaBTe1%}=xAY3%vw+<=J;MD zp4j$ZiEEO75M85{T(ko8%~3VBuCu9i>a?!R>RowEKpj=~DqxN4MhxEpxm& zYdT$8k%0-5>mg^J6doD<6&f$Qqq2$`42vwrcXu!2v2hjP`r42i@VFOViQEpl+Zxec z@)oWbCx=(Io`a8DfyI%jXli1Q2gZ#@Sz(-99J-Gklv>yd3Gv4d`VbEO^k= zKvs%f#m1RS$)dJ)=;xV)BPV3TjmI%0)yNUw;3wK{zE?n8OhMM`3+nY(!(Har$+Z(I zg3^_M`HferwJn3`@WoZQ{mKDosJg-^zM6+dMZ3uob~a89Y9#w_gy3__ zX>jEJG}K;lg=~@!#B=xJ;nXn&`UN)#Eq{=Y3$do6r}|^oJ54%c&@R00 zUo2?engdJD%R_{JtzhVOEtpVONVui$pt!Lbb{?-F%X_tuoV4+d(WF=WLQF3T1Ebs%_`Js+l6-3(Y#SxY%&MD!8KXB)`|%m%LTom@-LsL*nXQBa zL%-rS;hAu)wROZn9z?XZ)U^;9szSQ!^T^DBJ z>}!ZK;4b6`n6L&j4pQ(Go(~-|1XIZf?yI>L)GQjxju??ge0)y=obVE!-HOK3%Hr(v zkHT^O1sn1u(+GLxyHg?a}MfNoXDhr8Swhq z5;!_DmNC#dFL<#20ceR{r0R`#!GFR=+D|N&=P)s15`UdA#JrW+aar@oYq4a2wvwP>Vu zS$GEg5;;D=69R**U_?!t#6`I1b(_BF@o3|Y){FGjCD9^jQU z8Xrr~#?aDc+S6hqZLHG3hIcuFh0I92e_oGkPmO@@>xAD)m|Bm8q8jKaEk(=Q$KVXv z-neaMJ{m+yaf|&N!Fs?<5Fg_Q-!4dV$3pGl#+MXWn3s$>&z9q&LE$)Fb0@XdlHmLf z_QKx#RH%5WBzOPVQ&2k62fj$WfxIeRfm8lgSc;2M?w2=sgpC%QcO3ztsaa{}g>@lPz zFR@kleUI`TmSoKWC46s{4>QW*!1JsI`(A7mn90e4yCtZ-mG3 zNmN{TMtr$OIGzlh1~(@yz>)epQNhK8>>IuvmQF6BpWZ9NI~f;n>~{s^2kFCcIz&ZDpwY2QaDVVpob_!3^>{mj%+8Vo`yYMa$D97R^{_6M$avz&jq}MZ$B}q< zT^WooUjgOP7O4xn*#Oh7yu>ea9Z>q97{1yx2&N<$p<8qT zbq=Y55nIQhN%KAIapXQq@Yaz8|2MdB*9oZGssNXY9?_N>H_SWTMi##QhR0tSLNJPg zeiQ?Bdb#wu?l(04B+YJ7S}nLH^#B)_H<9vN+b}EsE7Edxd^>vzy4@^+1-7;1`;ETU z?*^q~fYmCvJVFufr=9}`vHt9nqDGwT8HC2928w2n#Y0O6;Nm^?@QB}+)vnk_zC`QORRZx(S3!0AcKTo-L8YgB{4_%V1=?TeP_uk+Kah`S zU+<=Kezu^;t1S4^Cd*FLT8*9)-hgJ$dfNN2DBUcR*g5|xS8%Qo>G%a3Vft^@?|#Sm 
zgN*)@*{F#H?9uBB*u8~0Y_RMec4luU*0a8ZmHm*aBKbsC$1#W9VYP=X zo9oQJTb9W>&tKiS{uOHGQTKP&kABDgW%^tntJ^B=)o)}<80+Zs^z|h%xO8V5(=R`Q zrhgwrW;bu5YLkvK?RAf+_OVEMLO+ZyQcx9$h-`vIff=3q2XXsVr~DoHf5-mi&mZ#{ zJK5*ZeDx4)3eCkW@fA#b>vbBxS&J=cy@*+(?xW7I{pc+J0vnV|A*tOJhnnfaNl8Vl z5`LGVNhEomvDhe zrdw%6bO=|by##|EKO#-S?<1-QXv4sRDU8a;SrGVh2i(>?BA6ud6SkOd!m!zuV87ug z&CVS{uF90d#gAHC=$370nCyn)o8)P7;tS!PSs@0*wpTYRtD_5VFRE$WCuz|VRJ`yR zItSE3#Yh$-W{Yt4W=XUUtOcoWNqEyXmmbodOd4C))18xTV1CVRym`U{Gz=e+s+<0h z{z(y+mB|y?n-_8RiK9$g4+(CM%_)+z_aSVnxlP0)R*}U_Iu3e04#Wn2AihVc@MOVb z8j5G=x;5`X|J_sSJNz5Dcr24Dlt+`>8_dvXivjF&8jN*f_u(387c6CJ;jWP!D>YFS zp4ufs@hu*_9q)u;W#QN_vy5)7-$HNXtR>P9DUPPQU~Je?`fBSXB7c1uJUO);nl9gA z`j~ss=dK&^UU4z8;SS)fq*U0rOOz{#odPcpmgC%O+Msme5XznnL7l{S(6m;iw@P1O z-602(6D@~E58lFT^IRBpMw>~}@j_|KW;lOJ24}P_!K@(fVeAoM9e_pb>Y#U_LlHY2syzZ6u&~66hGN7MzqmgG)0@>Gu>r z7wLBi21D0qH>% z0?kvB5zM)RF#+@t4W(*gvINkI?PTbs` zK~)Fg%t}dIw`xBz!L4L*$QJVTvJ4xVD2z#$(y(plTRde|3n|On$nMn}X>S7+kaA(G z%7wq$bH~>LMOhI}Rilk0s`mmrRR?g{n@$WrxRJvLZbQ=~XW^csGn!2Z!$qp2A$=Lb zh8e=So$Wr%+Pa>2wuf4Y*QAko&zhNPrFG=-f?6mJ4~KbDKWP14J1Sn%LciC@;KON) z@%=|l%ow>592Zzud%Kij>G;7g)hiRcbCBeWEr&0U4v~amZ^-EIPju>9E!_XM0pICg z!lpxWamykp^d6N8Ip22U{PsA!EIOGPU9E*`pEOZKVI5JeV6nhq0~WhZW!t~4AxE8i zVBZDb>8p=}S!Km?qV;$?j*xMHT$@?o>|xCMyy^iX#3V33_8hFq>C0Wp5M%eGzou`V zq=12|9y{D&yCC7I1lwDG9TqRUfay8wkbmVaH9uX5`DSI9sBDa(1HXgLqcM0wEP;G{ zbOD($Z76$$gY#7qIKE#B>R0G6Z=XKJsB`)dGFuh5I=WF+nYp0BvjD4};%sH`ICwPL zhb%bMn_E0#ESwB)CCgqba(y^$=AwfV+xTQI{ye=I`1<7#|86yzWqVwpwP*y(RBW?4 zQGF0ge62~=09E?png%s|IGo-rJBZ`YB*5&0Qk+!TczRdhh7a>jlJ?C$pgi0HS6lUG z>y9hHmCpw7@a0m_c;W#qy+pZW`C@22ooa0IR#vZ70RDerX9HmuTcai-Lnr!yQ z*AU%n5)tqD0LNZ@D%f0+OHV3C3JP!g;upFLw@y^Rk?o4y$)iayXVx4zbFv;Mh!lg^ z&pbL;Y$M4lSq;Pbe`6A7IMLgC49Km>p0MvBg~ofOV03sj%DucwPN=uw$O8#<>V_S- z%WNF|luS|4{|CkmFG4yc2@>6MiT6~07&Np7&PiWo%Ee#P8*APX`)nyJR&1vqc8tNE z8(PT68$IC6$5^No5KgC3hekFqJw$@bkj@G<)w7#{Y)Ns-cY;}HcW z^ZbU+`CqZzgaiLM{{LnF%j9>Qe@Go?OrGB81=TJl;1IL}r%c(yge+cdCF7Msit-Fd 
z@klLtvN;5A6kMWXa@Fb8@uIN4ek&R1=}3&HNq5SBt}p-4-|?Tn30+$+Tk{d{>CvoCtYSNzAvFa`p$(l&hem}A%$aYBs#~x{WWuH|Bn2> zWB+BG;yh52*au6mxzhlRG}!Pkg_D1)$;r{WHU=U{fwjLRtL!Co*cCPCa6YM&#A zWh*F1x~+z|cRHNtcP|qEZmuBGxED8jz9nbBPlwBn7qH@ebl3~RJ<+hf@5z+c z5#H9Z;v5fVvVPV@tkjP)WTeMH?$s7g?g#6JWmd&RpgID)PF}}h8(4H2^$gc zqbh0els3lNT8&E^mW~N`W^reG9frJWN3JPk1e%Oq1Nh||3059J=0^4hlNT#s=qDp? zNobS6;dCXrW9rPzy1gGFzVBtMm1aW576s0sdLevM@681cmc;;{d{`K;oyG*5LaCYt zSg5JR4*c0ngJyoV`hLK)bNm-{=GL3PasJQbcdWlHN$J6b2R7i8CKJ{@XD`ldT?m_s zSHak4+fnCbdFT6K*3&^tLqtTRBmL{~XGAB#@Dz6X#RT@?oDeqVo-cbjIf|XHlflM$tzu6-4PyhoM6=9*XjVNUgH>PIwf^_6 z;_&~%|NLeCDs(@>AJI~lqpfl^M0OA|W-sLh!>*~4CjXu|@tlPqRlyZU#JnXw!EqGy zoao?%PnfJTQZ#V0qTpLg9C2w->GXg8cZ0pGg|_KP|IR-DaQ;{5K7_ycwM>c4WKV;_ zu`R^cNr!H_X^Z@)E!51h1ea`Y0Nb8Pcx5)INT2~9Y^eTCGMxun5Vh7@Tor97EM z#u74QC`F{wK;84(lqN({Pel`Hl9Gm}dMfpvp66Zvb%*s||Mjl*zt&mDweH8h$JyuH z?YH;-?)ZE>R5Am#H0v>PRyD!B)5yYTZTL}6Xz_~(NFBe2JZrnp+K&iAWh)~n8e~so zH{HVClqvXKK?|iXRoVU;xO3tfUhZVjz7Qq3K&! zV_i@V`d)8`EBe;xaWowLpXk)w41Gmc@MW?6V?Nw=zk-MNmoaWH-$HwnK6p}Da&FQ- zkjGHO(=NbYWD5FgPlGKN+f4}bM`^|idv?A;nPXk3)_jvl?dZo~jc^!kvx|I}YiC0o zYO&Wfmxh^KfZV1;u>KSQ+IlBY>*5SBK5>^BsWcS7)>lLQpVzSAuOjRp76q#7R}sHI zt~2aVZ!j4Q9`na9t_Qf zHYXLF6p=z+22HB*c>ELMgIys!`6RRVW+CxQ;C#mRUv!;?G&&p_hHu@&@N=I6@A~Pz zIQ?55r2oEysvGTqu1EbO_;EANyB!6^bTv)-tDS@z{KPuDXHa{-hU|8W0naPf*=H8g zJRbv9=zTZ@{C{qQ&xiKn-A7eqL-Y!`rZ2^dE`Gq8dniG4VgxL7JqB-29l)K{C+LK% zEqKB@fEp#o!7hbd@cHWo-B2jSYcR@!&ckEic3>-YP)~%qvmv<6V9MY3*SE&D+5dO^ zEq4FnZvPSz6-j&W>wI7Q_TC(C&b+`T4pWBb!#mB#JxPJAx_FvIBVe-ad>Hml0#?_W zU}07u^_f&eN8dP!W6kX8ErSc}+2ub;Mh{naGdBrWzqf|vO=d7}`cT}rTfb)6hzK&% zyn$)keUQmm;f$^4VsX5YCG`nGT9EjS9E=FZv?r%&!Ed$@zbOj0tyRXS(ifax{NRs9atV-6kD#twI?P&Y2Bv*mS@}NhdUntzPl{Bz z+QlL0Wn2kO2U97d6N94+S++|3F>AYhJT6R9XAMFJ(-E)ZpsU=F?a!8gl}3EP7DHSZ z=R)Mw!ohFHI;=PM0~75<(7^Yk53`Jj>sk*e2$!be4&LXDD?*9l1sB?oTgI3D+Ct`c z$-?2QI;in{A8qkD2`5|g$>d8mNJf-fxLqh`)|an9$B>WoM8+E?z$O7#7|en4s4BC3 z-Ngjf6oTQ#=<^Y^t*pJ#04TpMg>x7EU^}DuaNMq#Mr)MAK+Hj#FoJ4U!=dbIA=O^vN4y6f#|`lV 
zNW`Jl#3k+ynWMNE+KQAQ@LMF7lmlvkJr`XjU1L6Z6kwu#9pp^PN5ePkp{R2d z+#MQ;Y6l{zjP`euu9rfOC>hdg^FL7M@v`*cood?QyPxbCb^&PZD45Xc0weskW1y`I z$yBbVt(FOprT2~gJ~RX8Xu9Dtg_ATScQNy{WD37@lvGXS<6LUdCQsyUsX$$kEZ7~8 zMnBu@=-Ocl&(^$V)fFB>Y|~}PS51Ho`&D7(-Uf27FN+#y%Hqg3-K?vmH(J*WBT=1; z&@0j&<*(U+E@VUAEO*jdJPJ)tCbKi=L<3EI2bFHK!BZ&;?yYRb-m3#Z$F`eP=@EEi zs8|!PX~UFExDJc`dq}G9U0OY21C+>eYY9s(l1BqykqC=DP}$DIT@yXQ)vAr|N;05Z ze0HJNf&hFv))TwutCPoR)zsX5KU%E_pz1-hFz|jRi&x{(lbiv0&*wnpmczjNOQ5^E zm3XMcU`m@L@7rJF@Q!B$*44hJ3B5j0(&R;2b4}3uQVkx<-+(zLNzjzCpRShJj)SE1 zsrtny?2$4Zm{E{N9K0suQ(aAFsQh7m$LdP+t4Aa5c?}usaG9xE8V{1Aj8LKF3iaaGf|N-P)M{DbuxIBn zMtug+JsOC=r@FFUOCOTPh+KT@lFkftjX?V6b-M8TG*FcX^bIM5`^zpfqZY4)Yjvt9 z^`;QA-Q=Nu>hBD`HH8&&ru=As4H)vxhc)-VM;jw`u_q_W!t8h+8QWh-%wFb@RcGbV z=srQMmVDg0TLS|Kr}s8o5v557b0$oczOGwMBN8Ja#Oo9uuH8+Bhg#sAN?i&k^@z=i z1aRAyh>2VIB&{YHCZ?$~k)?Q=wa%wedjuO(Hk&D3amGMc?Fr91AAz_*pTXl)S#u{W=r zn;}<;8bhqWCT$V;$$Nn8jtsJ8F%Rap&PE>$fMaK8(PZiZH}cZhYnQL^4cv8rSND-7 z1hv!ON(;!(rczp0GLjr|%%r1Z%$VcPF0lT2`Q+@(PWtur1$J`h7&y%?!JOxb5WLx* z*zSLeOVhK6%;cr8v-mdhJ)e>A?;3QGlNN09dqF3>n99EVIgOru7K=A>G|9wiDx~45 zGqKq8j*g%Ggs7M^kRe4O?z|Fazw}3?ajD-ZR()t8h z!T-!aK?Xd$6U+L~oQdYgpU~Z|skH93Htnn3iPOtofNlO2(B2GKv0)b^CQ8?QvQ@&E zB4adm@IZa93ivc=HP*f`K%bQz&~xcCdA)EWhG`DOv>&0w&_EBo%HwEr?{n&0*++jQ z^^$Kk9n6}rW^g#Ql3aNzgTGC1fXF55!CkQtC8ZR27mw+I{K8hyu%3iYTk1&t7$3H- z;4&Kdroe+w3Akt+7gO&Fp_rXb?dK_?Q?NRW&(5V8<=MD-dmdw|-v~v2EP!T^!w(}Z z@Js0obpB)qn~c>|!`<2^C#PJ~FChcnwsGuctPmKau}4SC*UVey7) zXny1id2=VDl7}zM3AM#M+bE2Z2eM$#U?_ROX*<=<#U)!WLhjP<7O2v zT~EXCuKXGLWy(ip>in^&)ixV4VoTB3X#jo+P9ljDRuW^&MpUXvqidC1`2`zmq3GNp zYOXPcc0NhR64eQi8>B*V{SMM`Lk!?R)iy}K6$H*}quDv$)|hE#Xi;C~4stKQW6(R6 zzMnFc%017ZM+39))<7Mw{T-omSpsoW8%|ec?q+Q2kHGeMwbbj&Kq$HI0*)beAemTA zhj3brYKppMTC*eI$J1fxF4F-9pMptJV+dZ-i(^)8J7>NpN*U(F@zGqm=GNJx z*g1Rq=pVkN@T}hhKc^Vt3%xj+t2T;kPa*X8UQ6tCh{uIX`@YkP%SF_b^M>*RE8>4gSNSc)bMd%oz$_Cf2ak3R7U(q^DFRSQ#5n^?<>QV{~}XPK@PrKJp98 z$i^Gdz_6duen}r?6H~y)v6$@OBJ=ZB`H)&1O=Q$pk@fvQv421Zv)n0|>55S#!N*ET 
zkK;(7O&7_www<)hDG6)W+##iBDJ*@L!QCg1qR0Hfn9txw_Ty z+%TOvI6uaN46sR~9=pb1@Qp6&aw`HK%>PN`MsESTtl6Oc=TRE;oZW4-isx<2)NZwB|8c@S$6L3T{N zNa`0hu<5h6Qjam~!2H{L=7jx6GS{kyxbC+{laZ^aVZcl%&`rf!MiZW8IAPjmJ2Ijx zm{EE*9M`Q2Dl>(?qSjS9ckK7MWgg?+AkNkXFi43AOsHK?g*Ent%e zwXi$mhT@7sE;hOtgDO5d=%Mp_`8meYFx|4nQ_OI=Kw=bB#*uss<;s#^=Kj=T0%Kz%%y?G#Dp(JPh*TxNJ`nCI?nQ(&A?3_Bg z$V0Frd@?s7PLvkZR#p~b_j12XkT#kk$_nc03bQKeBCN3}E2xP%SjeukhA1m$SWu8s z7;YIU3X545RFe~iQ?*246VbOVDCH&$tLTctrlPQ*j+!vsA_h+tg#|^-gyGaNqL!zL z!h#A=!m#sLQFyv2EGXY10_%ywGeluQ?G$0Sg_{f`C-YM8?EeyAa!D4S)u+do< zPUQv}30oHPPQflkVc40Q11JoOd8c5Dkucn1CJKvrr(k`sFzgJXu$XrW7Uv4XEw-Yt zn0E?R#tOr!b3|b=?-VR+6Na7Vio#;vDR@N+duNNi2rPD{1nYT(4W~MZvSQvTn2Z;O zot;HtG4B-2gbTwhuA;D*cM7JVh2hkNqOjO83+8}@VP`i{Sj;;G6R^Ut$|6x%%sU0M zXTorc7%b+Uf~hTG*m&^7<;IZpQ1cWEL2Jm*I1aG~5lnP1QJxvyWZ|L3*;24WX! AQ2+n{ literal 0 HcmV?d00001 diff --git a/test/data/deepspeed/model/config b/test/data/deepspeed/model/config new file mode 100644 index 000000000..b59515bfb --- /dev/null +++ b/test/data/deepspeed/model/config @@ -0,0 +1,99 @@ +!ModelConfig +config_data: !DataConfig + data_statistics: !DataStatistics + average_len_target_per_bucket: + - 5.996365000000221 + buckets: + - !!python/tuple + - 10 + - 10 + length_ratio_mean: 1.0 + length_ratio_stats_per_bucket: + - !!python/tuple + - 1.0 + - 0.0 + length_ratio_std: 0.0 + max_observed_len_source: 10 + max_observed_len_target: 10 + num_discarded: 0 + num_sents: 1000000 + num_sents_per_bucket: + - 1000000 + num_tokens_source: 5996365 + num_tokens_target: 5996365 + num_unks_source: 0 + num_unks_target: 0 + size_vocab_source: 16 + size_vocab_target: 16 + max_seq_len_source: 10 + max_seq_len_target: 10 + num_source_factors: 2 + num_target_factors: 2 +config_decoder: !TransformerConfig + act_type: relu + attention_heads: 4 + decoder_type: ssru_transformer + depth_key_value: 16 + dropout_act: 0.1 + dropout_attention: 0.1 + dropout_prepost: 0.1 + feed_forward_num_hidden: 16 + max_seq_len_source: 10 + max_seq_len_target: 10 + 
model_size: 16 + num_layers: 1 + positional_embedding_type: fixed + postprocess_sequence: dr + preprocess_sequence: n + use_glu: false + use_lhuc: false +config_embed_source: !EmbeddingConfig + allow_sparse_grad: false + dropout: 0.0 + factor_configs: + - !FactorConfig + combine: sum + num_embed: 16 + share_embedding: false + vocab_size: 8 + num_embed: 16 + num_factors: 2 + vocab_size: 16 +config_embed_target: !EmbeddingConfig + allow_sparse_grad: false + dropout: 0.0 + factor_configs: + - !FactorConfig + combine: sum + num_embed: 16 + share_embedding: false + vocab_size: 8 + num_embed: 16 + num_factors: 2 + vocab_size: 16 +config_encoder: !TransformerConfig + act_type: relu + attention_heads: 4 + decoder_type: ssru_transformer + depth_key_value: 16 + dropout_act: 0.1 + dropout_attention: 0.1 + dropout_prepost: 0.1 + feed_forward_num_hidden: 16 + max_seq_len_source: 10 + max_seq_len_target: 10 + model_size: 16 + num_layers: 1 + positional_embedding_type: fixed + postprocess_sequence: dr + preprocess_sequence: n + use_glu: false + use_lhuc: false +config_length_task: null +dtype: float32 +lhuc: false +neural_vocab_selection: null +neural_vocab_selection_block_loss: false +vocab_source_size: 16 +vocab_target_size: 16 +weight_tying_type: src_trg_softmax diff --git a/test/data/deepspeed/model/params.00000/global_step4000/mp_rank_00_model_states.pt b/test/data/deepspeed/model/params.00000/global_step4000/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..47e44600544e4e3e2e2fd56d6edfc909c4d7a8af GIT binary patch literal 33699 zcmbq*2Uyg~_ckJ;peUlKC?ePZ5frRI<^&rS1O*!wKzc757O|lqRhk7vrP#ZuSV760 zMNzC+u=n1u_pVY#KL1x~;9^PJF{vp0@;h_pB%+awEhNgX32=eh!AxX{ww*IS}m1yQ#R z3U!x`s5+=Fsq0)zXwy|vPqfuD!pq&Nt3=KIP-Ce?-O*0TPSZ{^K~i7%Q!7EzV78y! 
zR1K-5p_8PMv+1Ah^>LSrxru&+n|1wRwWgn7l%J|!UB6%C8jEsTj`m8m?CQz&npEjE ztE4X+BGyc7J0p3x5x@sgQK zj1@0gC%a`*iHY1Xb$J9kR|%O`37JVHU7RFcovnWp3UZ&~&B?D&80YO4?j7Xg<}MZb zki3(ciRauxLb>1{7MTCA;MePCA-dU3xPHoZn(>nEcD3RqJ*t#@N+rFVB)y$2|F6m( z{_f!t%bInqL|tn~dreW7k?WEwU64vxCy9-7w?FIt!&@I8x6lab4^R8ltk_pn?B{5& zFDT06+`lT$1FBSQrILY8l0nWE|BI@B$PdxMHH!`rMePL9+IC&#qC=}h?W;r`q>^Dy zlHtx>{wylLxIR8LTu{fF6`e%I5rU$s{DO{@2X<7I+Gwd{jFZG!@gDi!XMgERzp9BB zacs?6<3z3Tj`m$etqF21mntn+sl?4mGSRvHpFIW9S- zKO{9vaH2%Gqy2C}LhgsO%8!UDnMqPfq?2UwzxrWeCi*d@W|66)$TUGj7>eU1)8&3d zRSC?HN@hAqqHDPJKK>!@L2hAOsHb;$cn$q5re>{KqE@V<{X|hKPOcSSrIjF+Bsxiw zoK1ea_J7n8y_#LKQnIL&;%FZxDy7Pm(yElwrII;LlDP_D^YZ@7C_L|nfH(?UNaib! z!Wr^WIP>QyyrAm(WK{_*lu8ykNftZz{^`;Wp`ZQCJwli^y?r@1@rapEsC;&F^L3Yc z|L7XoKjeQMf^$Uwa)p^j@K4?~@~WhlR7o$DN|rfEmOJO-~Agm_u zF018%T~j5xRw`NNBw4R82LGGCLH;4$?%WSoH`J`SQB=%#w4W;|%3a-5+Y?tfGk zQcEbR_V&5k`gD$UT3R6UtGDz-XWfUo!s};EL5QPuO zh0Cgh4@xD6oFs?;)}j9@Ec$z-X4#{n>@i3C5>d8XE_=L6_JmY&(n)ekVV3(pLRv*2k_@QkDVNkLHFt1_T#NCr;&p?Rxj&D)~p9YM2!oxWW2Zk6V}D$V;+$pa_J zLxp7Gx7hwMnW(H;@sX(bSWs*zzhj=rujJDzwP#Yvb0^6Qg$d<9sENJiWzAZzM6K73 z_BTYWH*&4FRa)<)lJ`!M4+@j;Z(6@56CZ1q_#{evcC>#kNXP^Dr7D15t7N`OIVE95 z<;2wz${OaBpF_PuO2H`$ZEIBHY71xNf}Dy_5i}bKA>`@^b-^=ERro8aa&?88l&dF{ zPMn%h{;QWiQVmXBXkD{(ec_BK-9V@axxL`5;4Y^j{1QFp8Va>4`9?xb%4rIv6W3TM z|JCoR6ok_f+SV-IL^vagHx(*^cw@QW+Cp9KcQfH{m2h*RCgoZPr4y$klr{9EKc)*+ zD`~E!0IS)CR>B$423?^dEHXqJS_^f#4SK@gDjV7eH7Tbrlun$1P%5M>KYG|7ZIJh6 zuB`y7*@AY$8PS6FLPf~NMGHCzb-4u{g}+r67z#BhXC#zPoUu?k>--!sge6T_g!Fql z@!jK2LX(>HOoTI{UT2~5W0B063O}l@xS3F^lI$YXq+C~_bmGi~(z$uflA?naLW7zG zy9sAR!R|suSOSO+_7LiF2YU*CtJHc4H7VCyD4jS

%HdQ|cc4&{a)X4@i5o1G5mDR_;g7SLAnz9JEp?Y~KXG=#A^Elp zH&mz?ISc2MID4TQ#W~0u%Xfe#2}^uoLC*~n4%pQScdjEG6LzJ%gSg>hLr0+u_;Iq? zkCWA#D&KdZI48072*uX=e{5Y1=SGUSQHnT)CEY*wx_;SW;YN$tF^brZf7bj3Yu1&M z=jc(~SVeSA{#9-6aO1?*;}u)izwfwTY~)-;teYaXCjV-} z3jR$L(e8?9jq2!BEzCI&v5}`@Bkg}SGV|g>!-SiN^Aa0-D>iQRTjO80r#K%G;;RU0 z`WvKb-sAklw*HE36)wm>b`rP%v2~zgYvVsK*SLkn1&IW~iUi&N7XpQ)EEghjg(`BH z{ZCx-5OZN7i$sy7!=G7x-J{|-5ge`vR@j97NA%2o1)F(=O6B=KClyH|6iIsgR||fS z{Om51M6yUlvd;gRtST?(CW{_j}Mn10umJgfAOy>$K{el(Cq&X5}qD#$s#C45v2F$ zpa@E$L#ZMvO%e5P_E~W0B4~~x=r2-Y?-4>VS47QIL{+CHzs2r5s>;gwA~Zu0`uiaJ zOVY??ijW105QTZ-`<40yqS!;SMD#*MwEl0w{1q*RbCHN!tca_gU%qFzT(;OKN3l`0 zH$PI~U#E;*5t63}G5BfE56F)#8evOOc-ZC^>>eJ-EfIlB6@k@r%@5%Bxn`NzcDZ8P z>bd6Ew!&1pLTtTKv32#p{Lj{AKWwYYthrSpc(o$cRQn$g1|@xD6u5Mn#V5k@ep=IPahcE?;EWq{vV`iT}nRADw=U z54Tw)*rG^KJ&FGZ0>vAlK;+u0$W=Xw{}(Rtw%8`JY*%FYOHjmH{~tc>5E*tVGE`4D zzuO_7ZhmE`8c24DG=+*Z)icNMH2>TMFvK@f21{BiO8{6k>f92tws*B3E@< z^LJl_XIk=yk=%Zf;(#K>Uv8aW%@HUhT$xC4P?6v-9rZVYz)9R8k>Ri+!(U7gyZSFQ z+!2xBs3O7N?l^f*|AnBcqaPDl$`x6v=k?!$S(V~&$3@TyMbN+89YPu*P9EZu0PduS zKc$GT&a3_)T4ip9h&-)`tj?=`Lkh{n84+++5%3rPzUNivM9}&F4w6p_7evrSMNoCJ zQ-f($4qXyamlaX}W}gLjMFd?{1pP(odxpea6H(U{QPp|XZw7uxRatpMgx*wy{+@~a zlIC!?M96JLNOfNIL+uxcVh_0^qVFoAtMjUw(c(CFPsH6<#8ofJf8gXT<)gp@vGGI2 z##%p7T$j zCLiGbu@J5vw|?-~7D%cr7WU3nMD{u&yQ-Z1|Lct$uC5|?J&{{Y3Bpb>{OQ7{%S<=8AkRL_Qsnuce%?ItluXuWBy%5y4i~G;&vUMWWUs zk)E8W?a!P0Peh_MZG^TzyNn!q_ihO2rzUn0HM~ME&C&sFa zB3oCH&0J)&khA^&yzIestEQI6ue(UzLnQAhC-3<)em~}ns->?mYW-uip!g<UgfVNNgz*Tgi#5Z{U9;{^gYo&bpdaZZISA5|J0=yw-m<`1ggM!bJPy%CXh_@tlpw z-bZBbD`&4B(tfl0@9bhk`YCev7r6(B+_rM=|G$?jxPgk~gGBPdBKZ(G`CmGT+;E{E z*8Jv%Xt$jr^-z)8UZi%AQ&$g}zs2NN+FwWFVb!Ga3pre*brflxdQ@a7FoxNtmEXY)ua1w5%}ZiKK>`CFuG3=8C^t1 zS2<(#=>8uV#cSlI$Tv~sa~JtMK@zhWF5zqMxr4tt*l>chY&(S@wnnrF-kVq6P5{1Z#+Ww0*-$(aQq3us=!h|33 zoJ1&{I8G>Qi1i;wE|u~Sb^gbh2v3Nl2BHCMBw2b?(G)DakGUZ1KHso_7yMIe-__- zjCALMBf_|3;e?%%{N__~lX5A-U;E{aoaNt72wSAWnqI)B3I|=dG@<;-;4T7$bLoml zo#lJX!q*WIp2BO(;+F*69P#|)WAcCHV%Wc*4~YmC7VA!KAW-^3Uo^EdPuPmhg 
zA#QH!;oeepE>k$_!YvTWcrHsQow$WUDcT<%=pPm)D1=MB!@{{m!eK#KzDV*73iS|P zmaRIzSU7IKtdVm=g%j?w=T-^7qqx;_NWiZbUjAl+hwx@YIJZVPJCOZ;!3VO>-W|W# zEYT!^TPx?PEi{uyb)CEs449IBt81#27~HBcz)iTf-KTKt<)?+F;=hOp=f*-s2wte$ z_qC|#4YxsP6~%3o^UCGJrNYNDT)yyA{;EVH@ra-$MkkE(50P@4gmd-}YDYM$3nzsr z|9HJ(^AC!yWeic=L*+>Mf8H8e2lN>5)p$!p~W@EbyU(NGW6d>6XLRD~RmSx8#!)F4}WI4}z)cB1pwA7X}E%_4)Yq>y%F zDwruJ+L09(tw7lj;fRiN(bGkpSv~v?jj~w)1NxScu4g|Hn?s=p zx=P^IVj*Mm#vSuYl9|uEKAgEUjVw(pM}OVfIQYe0Y*_S?8eYo7gL9SH8MkyW?=l?WNPpe5Cjxt^MYZ|5Ggp1<-D@mr&W zj$f7NOzBD3vwt&Tj;piT7uTUhmwIT|FpF2py^J&c4#S$&Te0=>L15p0BPlJ~jDFiY z5|6}&Z0M*t&@Ai#8l()z*>ej?Z)JaGZsV3{VBUk=VLD*Su){E{YctqBGX`hfDP(5! z+>V~tpJ3m?lulk z9>!x5YeZY^xJ`BhBtTC7HKyr@BP88Fl8k#FL<-F>;|xNPbb|~%jqo*8(5)~hl?wQGa>yn(A0huOnlY=uXstZPSYl^ zF)xpFc*o-Fw|TH8Di-Fi%Y;_W-K}@e*#>jHjAh%58{zAUHW+)MF@AP!B)b*ni>(IE z!(o1@aA;(6;#GSyw6RKoMow|KtxFoUI`a|-$D6`fjjfFLrrEf1orCOUb_Q*)DF$xZMwr8T^|Kt_<>7! zTeSdtn;*szQLCY%$2W|9@`)U9?tq#Ni%F%qKJHra4&Q1&0JH5oaOIIms-KhshN};e z2HihUk4Fu_`)E8attf!3nHl)t!d*J-!VX?1+JjNn;CQ2oPIUUsyHqnItJua?mD%=g zkab<(v%K!)ru@pw^Y~jQb+K+?XKKCQ5KqwVxQC6TjgpOsTWLPI=3zjVDC?2DPmSo9 zF?*QBbuvijJq4srW=qnd^Eh&0c>r1THkx)CzMtu}W;+e)@s>7zJ%~)aXHKQ?k}s3m z&<+8SB%1wH95VYn??3uwvAX*vtFv=jQLPr4e7cP}O_jxy!QBH%)|OTzB7Y&@!cL1` ziAyEdo?B2ovnllRwI<|2kD+9~h7&14U&d+9-s0lSbh>hS2*0$EIvMT$if^!@oZnQ? 
zoLsn8#K-S6v^H6kOc$K#_A%Oh_)+F{+IHM|+@GF1|JeHY>WgG(wiCX%R0mIIdO-`@v2^md zEzG>&uY4^hcf2#E0S#Seg%>l@@Q7zR-Mi#G9d6hJ-Q3sEGpovJjEzjVwrPBmeJvU1 z7yU@HAW!~q_-)4gt_%6%*oMh9X~G*c8bMnIwV_0Ql{JmF)$LAi!vSXFy;iit#u&10 z^+bL_okWIN;Yeb8$1)W$mbBlSc6{wXmhU;Q7EQWu!lcz7NBc*IGWM1IC@pm-E!L^> z75AfP#8YS5K{A6r^eVAll`+jae6}Xd7*?NV_XIkywFx?;4&&EfJk3lx|A;@dyLW8Hb%q^#@4Y3Xl=cY9?g41OEdP+6Qhx7bAN$##epV_&--o6!HzRY z)_HHb!C(+$J3W(WQRfv?I(?tD%c)@Uu4o3CI3R|eh_skEizlrf z_a#FXlrb$6VrYoc1ID3UD)G@?$@lYDrEmR|i)yFN;s?DAW5*U{vcWgTvZK;wu}`^V zcGQVjcH8(6cII_2_Eb_N8(=txjd4$44?YNEeLqfTnP$^jz3@4#UfgWfNk@}bYL=2l zEs^Q@v@zfAtS-6cQx^vvv*c5>oX}+EOX3+6L!pyBZ4-Q-$vCV={g!C+pYOyF$BI_q zH?0LbpFIRxB}<8yy&=uH&<~{#?$Dk?^Kr_m3g}li5lcpNgl@W4eCJ_HVfEyN(DLD8 zSbr1oRL`SSXTU4?uwpJ+rzu&t88Q?(%V09R;RTrHeTt^0@1qA>4a1KWZQzNadDjStLWr$~>4^XDwYD zxEY4*w}Lj8&!EN8>(sY07w+BK33bA(WEXpHAZB4|?9?^qU|)G0-0lbXk)DUUEshW% zi}=o=kAYk;!Rb*+e1b^uZ zFy2(3&dzNmYjhOQZig>0b85>BY_*uKvzOwkb-hqm=P=IdQ$#SM2N@No2Ve6EU41_W z=60A*?q0mi4{aHU4f}S174?Ray3*75bWT_NTwfPe_U1uggZ`L1c{oH58B2Q@meRW33HM5FgDec7q0mDH!WE?#=Q z6wZtdb8RAKt!qnVYc zy6HAHJ${PRp8f@cn@2;N{z>Gs+9f{Hz676+%%xEldm*u$f>@kMdas!KY#^gxC7E~HnOjPb^u(REclc%yS zxW(oJxlxR?-V|XxT@Z`5PK!wDrST|y_)B}{WRi!Zl+gR&ThZB@rUYkAAWf9z5 z(-zJIR8V{ER4CmQiBn9v;^wsZ_|eG|KfScVQ@!`{sm&U~y%smE+TWT3i%MtHG#UeD zwocIOg%SjnSm5$S0o0@O3fktxR%~ZEl%6)(%kOgkMzS6)!zqK(Fz97Ja4)xnAw9Kl zL5h)VV#^rP+v*5&F=Zu_Bb|+wmE;iI766vzkPh@3GG|s!VgDx9v3$vxMI3=_p z-cbpMD5DUV%a6yX+(ME$Y&-ZyM35%kI#5|?0BmMPLg0fO?9k2%R=D`WCD)7m`4dLq zb^dyB{HhhqoJH5jHB%{OTgO1cttjXdoz6EJb_Q$`D*5GK`hw~DH(1`TF)}JIAm-Cs zvZL}g^?%?66F$b0{PAzRufD>wu%uw0V=MjcCiqiEumL zl>d^g1b$tL0Z*FZs66*kZSu>ow2 zY2OyG@!)c*7ve?S>TE%7b}bUKDu_5Ho+ASrIK#yi8W8Yl7Oq>DgiODw@Z@MVSrGjU z`VDOeLq{3m;t3^qyD|yVTcq(Lua^_jjVC(}6cA^Fx%6fe4Q#*IlD{9lnYkI8f=_%8 zVBk;-QoLW2ENb2cPITLVACK#jFK0&M&OJW3&iXwm17-GIm|hcyCo{egtnHJ$H0Fb?S4 z7qD&g0C3TWg$sVi@#(=@pl|z-6m}%=#I%8IwoZR$o#8QX_I^a>dY-36Ejd_McQ$H{ z-AAs~sU$Jh^I_X17L$x!U}WEmG%3x5PVz`X*Ac#Wr=1HvbkZi*=M_<_@r$vqk1y2@ 
z?1KT97xH*;HoB0Vpzd-HsG2*}xjYsgKCB=UG~;p3MP>HWhxT~RB?e1NUee^J9=KNP^vswwgZxm9gu*Wd1Xg3bg>koC;AT`_`MzR(& z(6nq_ab4F>*xU->dQuXu%yxn)%az%cJ#(q%h`Kn%YZk_tG$naalOUoKLr*R@fLEIq zz@%=UF*b84sJ-)`Hzv)33rg9jR62t8OVEct1el3aCW6C#XXrLL3$G4SMN0`w^OqhW zJ9Sq<)&@0DdN3Cxn+-@az1>8;?HVXOTtW&*9;aP;rlaM{^>q5>miTO2F7~sIhlEpQ zyL+i^!{&|qLjQSVz)O7s)SaG1CONY(sA2$mpf7CN)rY22M>vr;k3V#vtk`6{KCq>) zX>#BtI(Ypu@~wP5EnU}&taVsO+r(QkTkh`Vz4MlmUA=G8kB|5AW;ffy4t^}=+)IVW zqrx#trw-2hI+d83bOhJ@M0)(`Jvw~bbNV&yDf!g@1``r>9PGFIk+M6gIM~o0W{s7= z_y$K&xpsYa-=>bBKB@wm_v?(qCY6%IZ9VwXWe2c}=N!25HW|mX7jmMcNQ&9nbm))< zI4oEjI%Mb4to&?@n3l(MGdc<@&W?cNpoXtnS>yZly>R$jJBTozM{c)a@O|1-<{?`P zE=_nzOdqDgto~PfUtGVCZ*5?MQ8GQqb883AToD|<<_UR&%ur*3Ck&iwi+Q%O7^@Cs z#Qa9E?ut5YE*XSlC+~yYai7Vm*IRM;#{QW3_8NI=KO0AlQ-gNBH$vf{NSbqM2|4)9 zj`Ur#ggDq-fz&3=p|x88D66v6V8Lg)`M?TlzOfO!XgQWBDdo_IDi4S(uPHnHY<);I zY($g!SG1x12>Ph-A!+mc9o4u!4K|NA29;S4AhuyBy&69i-UW9DFSD7LwyZWTEK&z! zFD!f~@fxpfz; zf3l3a7MDP6iySD??TgLs?!kENUc_KS0M3{@l6M_ zsG5y~nXAuWvCCZ&{kb_EGfWqxUiYct{qFpOZ$0R)y9szAM~4{q&?HCh3@6so7qr8G zTSU`}fvnmT68C6e_5*L!NI#6bC%J=v!{_{t)L2+l)D}*^)rCbD9N>9TBRFvzu-B9{ zO82hEtt)!s$RlZ}F+7DlKW<6Shjb(_AFZYf`$uEm_1i=_Y%Zu;E3w=96%e1|cML4c zf~)5ec<ngzfTj%$EOiIOz}Gv?aJ=dbFnJqH(vC*r{*H-^ z|CHTU^J5#rz{Fy-(lMuf7pH^1B$gkT@tmIZTn~4@Ou%<@O!0olM4H>QHJLVt&>2r{ zaLUD*bmOozED1SB*6*S)?nRc+Pd4BLr$*#LP8vBd?jW-oK|c;`ChzB4y#S_Sxe3*Y zw;(klV)5sk)DK4?u%Gf`(r273@L-3}j^l_#U^?Ydw7ml-#U>!rIci%@2k2=C< z4wy_Qv<(BRPcNB*p|8oHzK_Vr#Y54el|ME0?G4Ke=3)t>19!8A;k*bt((-mNqj9$d zhTJbEQ95Ix-yR8VG-s;Sjerz4Y3gYr z+g!1XPp@~9pP8tIKFfvNWM4dLdQ7LQ_GA?2n5#ffn@iTpB@Qs}*g`t2rymSCWPqEe z*Frs;7BK$BTxR_E*VLwuN0g8o=)| zh{m+~$zTVkv99o_{Gu^(o;^g$?iKe*$RTByGw`tG5OVJL zD7Xd*P-(G(q|CU6tMXg$wQ|pcmO(4XKe_|oUlYcYyJZl}4?>@^LwpnO5NO*h2<~XM zV0SEUhW>kRLDIECTsG?jeb8YZIVtqXc^vNi{<EwBwzk62jabrd_))`v`E zdrWDl2iUWcPTU=b1w~u&rk65aDX>5nyW@PFscP)r_By!JW)C=DYbdk0qlB(sG9jOR z1eYvjpncI0jSBT~n(GuCnO|FW`1*7byKyeoj=o0%w;rPdpU%c8Z7*Cgw2XY{`GNPW zxKG|4x_}OjhiRiP9&r1V3hq8s$t>6z4t9|m{Ok=U=^Ukv;J5VviTA(6#7!CmvX1t} 
zmUnl7YMX^1Ykr<@-C`yE+&3RCUEe{&3hPjvz*O9)r7TOlTA%u4Q@%lZIdK8UK$K6txREFSQ1 zh5DV(;xfYxaNf!v%NIM~^CQm)Q?LRn<{qd1&s>pPwgEa64Mn>T?$F}UY)HN538yW> z$@nv|IPOXe4A8$qv~_(*$BIwzs&*qR>{0|)<&7}8{}N_&ftoDoq(@wQ$yA)v=A@% zZ^z~a93)qtbfNZzN1$zFKA5WUG;CrroZdz0hh!zFxOIdMP3u7q_RJ@z#wL?~ZT6C^ z(XB8pG651w2a6d_<%0`xNux7hJ|Y;# zhmC?hFS^4FQDvDHi2fpf$SbBeTkja*dyIkb zXKk@>r3qOO^_uMb7)jP+a zpH&}|t0$Zpxt(cv zV?K18at{)g+ELEuCE2q19i1}6fbLBl1)nrZaKez@SkXBVo2TdE>R|^leAHAlzGaVF z?tVh2?St^_(T=#2amG9MTHu6k>(G^$;lA75QL@ho55Jg#7HPsIC;k3mr9yyD30Tyn(fI8}dpi>O@PK%WOhGoEEfu;ViWR+<|FeecJiqi-rq zX`Kn-p*^72yC!fYP>DT_4S+eU4DTXiVaS(5SgXQ-kKA($Gp=_9y*X$3mJTss+@l!x zCAERJUiDDvb9d~0{1|>~*bNhB%!P`02Y4Tv0B7blVKmP~;N96Oxa?ktRhO%IaP#py zbTd(g?T@r6b9*eg@jjn?vHV1Wn}tHp)DqsvD3P{khL|*EESe09C98Zhp|`myZ0ytk z2Ao=m_Xp%MtFP_C3nLHEPgnGz$E($JNO&N8T(Jr5g#AZyc{FYPek1rcpAPwZ zGMK{sb!A%*&V@0_`QSb}ntF8lN`2ps!Cw1M(7xvCR2s3Ha{JqYd%;Ms^3lh(>wBWR zs!Q>JW=3GrF%%c8q*BebK2X=m5!W7AjkInsS@FIv(_m)-b(mlXeXMto$|Ax(`S6;2 zJU0tx4H-x#zdt}#UkUg7<+i+GA;qv>&2dKZMEFpo#6G{0gmW@=WFPKM#msq1@qS0)UULrdYZejdnkQGmi!)4P?LDJzwyiq2sG0r_9XU8r_8oB~=2D`%XvZj1^z-nsKdMc`wT>w9q zPQpET5L`bzprs>EgIQlY+NZ$+Jp5S;Th>-0r)`1R>7dLmbPvSGYcG=0snMYOtQl|V z*p}`~9gJbSyTXj9-uOtr6OzhUymr$LhB#$F{UIk|`AJJS*!MGi@q7+u?C6L??LDDX zYbjj37fTb_q`*uK0u+_b?`A+hXYz9d{DH zcQIyK9sqq~U$o!2i%uvVL;T8)QIq~RiR$ufXm}@x_}rTUDt4Llnw3r#T@KAx9;1mF_hH>USG>5z5aVYyVo$BvkDH!cqpNER`}f_Kz>7)s@yNba zxQjc-Z@oW@zA9+I1{LnYrL!l(rT1saf)N8~X=EK57ng~XUiguBl5bElK#6^H^cETX zL6x|~jl$488}hM%8t6AxWs)6AsMU);WVqc(az1z+37A+9435pRs{5&Kaha<`3MuTA~c)Ls@%rDytj(o=$y;8AG(yCue*{B%-P7EJXXd&>oJ#I z!|Y_+zdOM0$Sz=aeO=8Sh{|R4N3UVWvpdc4G-+&GcJ3^wE zX@p3|!0V1B*4n5`J2Gd9hgb-0 z+=eW_XGxQe#=^xD(;)rD708p$#L)wU=cFHsXo8U^JLsOdjMQ69>nS}zsd^kSuh@ut z`j_*0y76R4n**qIX9BdS*g&K=5t2X*vi)SO@zI9TTqsqM%t}!4uco$B)%&&WlqO{PdOQe{>z7AfyoOxXwY_Q zmfsf3I{MJ9hVj^WXX+ju^Fq)pJ4&8*nMWMQHbd@N4Asin4)Z$iMadvz=;pqkEHfKJ z$7a={ZErJp@ZnBs`Q{8vYGj7C4d#Mhr-!tWRu5diC=TEFZvpkGjiKjdJyuIvzGqc0 z753HZ1&mYMjPdO>|RzEz{8n2(4r|G4VB`7#VSmvZ?@R3 
zG#<{i&ZJxVeufLLQ=r|kUeKw|3G`7;qg`(fr44)f-3 zcxe(g@i3*XoC^#~NW%B~>ya}DY@t#oo_bcq(Sm{*5T$blv`SOa5g+YQU3(dFr>kS; ziMrTUD;l4UI*J?nOoOBI+Cz8k+jQcbdf27UPW)E#p8B+RrKyp1z?_+lPcMYQxexDg zcosmjrbl?~7eOd(IfO2Z_Q7d0jL^@r9F+VsaAML%9DG{?C2M17yK@z|qy2D<7_GxN zxpu(v_!Z2}yezWaxB+oAE43;b-9n~5cogP;y+^HGukkr$rcCd;?!?)`5-uHW$@*{V z2C+l*Ar3Nllf&iIxtV|Qv?e#m8R|z9m#NT#+^#t8@gr=!cP(978I7GXedy$^ESx($ z65lkxic z9)@ZbLWzkFF5IWYUd@_EhF+h5JJ)SS({(f8 zVdXM(b%?>swklvc)gJE4He!XOp-knnmMlnb5i~vZ&U!%M8w!sv(Pj~v?0H``w(-Ok zvVps<3aNvV?B$fr*z{8=13ec&pus0}JCjB#BUxJas@m?4%T1xdu&%UD$Q;J(^CDdE zS{D%k#vWTQy{tl&o-;)o~<#TVQjx3U<6*ht2A|1jpa*CEL1l zEczzr!pol8(Bj1h#&!FX;^4|B_;r69=rSr7CLWrF0r%>{z2Z02wCo%`I7OJx?k3^o zz&coFkpYQ&Z_%_m-Pnj$eKBFh7ib?}Nk)t>W2BQCfq7Xswle4-tk>Fu?e9#(0%KJ& zIes@9TaSeoiz%+EJPHApIcPTfBKkV-B=rk(>76xgX;x4hXqI0Gl4%z@uHhjDpDU9g zg|p!BD{Wa1ug+jaXW{{~6_6Kt5}xJuB{wzfp+T*1?6Q0_Q~Ipq;t7##?$ zcI*J3qdT#~xLUH#`)pwA(M8O-Hg&*EIR_el)xnRB&+(GcX|QwBC5s}9`Lgg2w3 z$1R7#y(fKS(XWqFrNBP$ro}!8pBsw0{pModqW;A5(FE}DO9Z-YFf@-!#J&Z`VN-EO znlik#Y_o11=(@&R24@?~8rGf_otBKJE7k6k+_+3=X3~-^iOLX$~r8YZnNgGV$14)ad4A}Ez zJwBN99`2mel9^pP$*fOn$LieNL=HCXFKgDMCA?ErqL#b1!=?GA>>1dPr8CZj|5Mo*n;6DV|%hZ~CK;r6a#M1R7H8URqZiSPm zQvPA6ThA7vW*=wH?#Y6(GCx@s`*8E>jck;8 zF8kobV)pH-#jLq0vhD*5*tP|8S-qSj*2Xb|)qAyut-Cy#wLg^3RvH(u4=l6Tk)FBi zfoJ)w#^aGVZ|-`?eIA3UdCzH!rRJz*zW|$jSwK&uUPRXb4H)c}iD`Eu;DRIuyFIF) zvsLv$|5hRRmpI`@4P!P*HJJ7Db731#D_~1pz1X)_v)Qg42C+-DqS!BsJXwS6Bvz-_ zV0M01JX?4;ne|FZU{4f93wbWj>YII~GOhuTav!vfk76v7)4<8q1#Y=K;&-+lj7wkk zrLzhak+xlO$@5#8B>v4ty4|V^i5jrLs=jd}uu3W?% z1M4hgcly<0%dcADv@wyeLZ`88hy6|%{b>^}%3DwG%v_0e<11<9*%Q|9BVRMyk-nfpiJfN(Qharr)xb+X3nCr9|Rb3?x|B1%NrDLDb`|*oq67A6z$pY_Yu>5)g9XZsFUG!}yT$y?Rbn=yDTPw;y)#(ta z9qYmJpR*uMSWCSszlYm~L_({!qh()=_QLcrt+6XvOglXv13FjdP=hULu~=A@=O{B>19y67G$fk9}HdvFFyul5K`B$&*7ZAn9~n>}%1Igr8rAqqQo)WrCaV4nqbW z*fk97)$L@S9S-8PluvkYv@<=}x}L1;YdSr8`!U9!TMYptwPa~?W04tBjEmVpQ2N?T zHfrBgY}dOdCOxrd*DFWE@g|)y`pg|XP>~_Li@g%UW}NQ*{-6=0&QfMy^m-209qX{_ zbDtCkCe4vut!zhPyWb>%Y%pH*dx7}|8svHKbKu*x!HTAH0DTkS5(r~p?*d%q*h;3c 
zwHNH<8nKZ&n_%4IWBic`Pso9@_1LzjZqSPO?HJrfxZnGvF!yte;iTS5cyXzo>~Y*I zXxqP-)(-ZC4QGt0pZ7!B^~QCWnrMVy^pbIcx({Y4ClQtC2hjbk6|M8FIV{FixMf6bY<1fY zqYGw%-jH43(x;L(Tz-kni0lB{vj@YR%6RN<$J3oEd-0j~5oXa<2hcurnNL`M3{S7V zMQ-TqK@;gFC|OiYM!L_#yPlz#wzCOzNYRGY55M7S=U&W~H6NJePC2;NZW$E&Y~-~L zS%F8ZN}6+5ofPgG2zlNU(I(sgk_%rFt)qirdb&3b^vFfpPYLRKDuYGCX*km{4K{lh z5_AoO%VqHkvXW zE^nlF44v`t)lDRJ?+Th%eu}@Nd>STx)R4J68V^tRZ>FW2-0{()?&Mi{AXa=F&bsXU zKt}PZ7_&JBRDE=1BXXu=ev>2U7*l}lU%!M!2WsO7j}91GO9STXn?lzYw#ZH%V?Dj) zSE_m69@nVN!9CfX=z)sO&>*E9*gq5A!wPA^CPt5dcZN3fMRp$MYsKK{nE7-}peFk$ za|gcazn5n>Rj2yfAq@6;I#JOb)poGYdk|~ z!TgZtvR!WLjSxPY8+ ze?g5@bY$rls@m+XnlrEfE~FykI>!#UlWvm12@&g>I{@$UX`$aWIu)(sHul?03xo;QUZ4M$_+ z44g1GgXCX2M@?PkW1;_Z&_1CGx|MpndrsK|S%w|Zf971!BfasB(k#-iX&j~}eI>@` zapc8|MI$}cApYz?{ z{p{!YY(H!(No$#Tappsk+C$1KdO)_stVyrEki7FOJ z#YBQxhAHAqYq%vpfyhLy#<+|CFiI-H7O7KY`;&!m%s`x*S3U*fCmdwFf8a5pCPiycIdRZs?W&1 zFIH$bbsI?Edq7$|jPWbG4HZMClW9J4q4iw|6uNHa^mXotULAM^TH;kq{poX9HYJ($ zmx$x$&X@#?%?1G9 zwi}*2RwgsQZ)ZRHnPFJG0`B$9Lhn-xz&mCFHp;o9|M_~>dr~rcTBLz@+Y?2rIX$eq zGluWjwGO%ygqk3i3QQN*KrcC2c4l-YgwDudUprNxU4AAw_P+;8gU7*DT+hB96&ii?gEohrd-1qeY`(J(G*XIG4c2O5|K{WB*TT$kr9B6wJMP%-rB#y$2wCA=i zv^jjkL!VCK@THv)cR!n({plPxY~n$feeWZF_i#YDJPG`CXfQZO8=+g+Vdmsl1i_2e zlZ1F_GHh-K)ZV_#X)LwHYV)V$Vc!q9-C_b?%IiU4c{j*q)suXUT9T+T3DeUUbCk;^hfH;AxE~stp+iqXM(A{do;u zu_=IAt%m%-Mj5;pm>!7Lvsi8_liNiDf{73V+7e{^${a*7a7a{Fw=t44I}B;u_IjwtT|si5Ur$F!EciS7|L#29G$fX z#?dRtaz5gF`}vrE+??$8RmDz=RG6C^30}7}Xs3h$m?|oOz32oktLu&CDP!|B4 z%21urz}}R)VaknWvbyyrUj1YUt56)q1#wV5CW*D^{zT)ia`dR`4$%YIR~VfCg5*Cr zhWny^Aj?-$mdf-^*ElLVxC@$h>7_7@C>M zUJW9s+O!_O&K1F7?eA=)X)1W0OT}BQC)xbpZ*akMKILoj*thHJKs-GR>`lqgAaoq;Qz5WmzHm#8B zC~sieR{~kdxB#|6d8kOp{eun5em7JYI3)J7g0EhgO@Mu&47xn%ji#no3bKY0s zK7%^c8Fdz&ls@2Hwet|uF$+hU>OxsB6)X~RQC^5ev!*d|Y;tl5W^T#GNaX?8zch)3 zKGnqBSql7(sdpHgrU$W;^Pp;*1VpV8;pndep~(6ooVL1879N-3VbB{?dvlnLoS}&! 
zvh`Rkdda#TULz-3!?9Joo_&s&r-eaL81K>yn`V{6^aBAbKV1b4KbCU4_*fG5wh`5* z`GR4#7VL}_;?NPhuz4y^+RCzUk(Me`9rKc#x-u5m=-({%;kUsoFDLMZ1!%505p7GN zU`gJ2@ZED0yOI=X=cN}IvOfgcO0-!^u?L)6F`h1L*We`9q?)&s4rB2tj^x9f-u!~| z!}+{{OQEVk9(4DQ1oQ6+AeDX=bF?M+>qmXidHP;ZEbhc~bt_l|+hN1LO{kvt0HzuY zp$leT0J|VD9Fr;mH>AoTT)vP!l%5IG#;37@Fh4$5I|NrWHjo!WJ+iuwHVn$!$qnh6 z2g`pShg!|cqG@8k;i%am3~;{$_6G~u{-hD)zI;AZb!qYbM~|W5b~ltfq{Oyw{UF5r z($QygM{&KHI?h~k8i#4rk=QV4CYk;YoqVoAfgZ(m?qa;XX$PU?*l+!OGH;__$;NXc#t-qQ}c1?yCxJ%vB-^kE_tVq>yWtlIBlYT_=gBU&FD| zS|Yh_8wuv(aB!;)NDTTymRv5vvcrwcA8)dR-JLM5y@@Ru{gYH(*~^sk!$|ExQ#3lN z4`-$i$8w1}ctAQtq1-iiZlp+Mrw)ZCyRDG*WDT_0PRD@Uwb*}eE=#L8${r=|A#$%7 zPGH$!9Z<+VrBxH9ha2JT^-OqC^Nj0j=FM7WCF9GiEMmohq5Q9?uFK=x z=8af-;R?3s7_h!`6G{B}e9Y3GKn_}VncWETCQFZ9hu3n+xLH#betYbOABxA=`HWEz zsQwGiyYyjFme>0YOKrW99b8NDxH}zBZVoyNc^voyl~5FJi&>3b?#Q`!S1nSlpQ;!~9~_u-^AB6Sl7b-JU%pa{k9q z^@2Antp;X#Egzy%VME+<)~q*O0wGzA(@j)bU&x8CL0&v67 z2@tmt;ow|hZk~Jw_oW>mULF1xlBKa^;k$aSST%t(`d)*qwQFIa>@QYv+Kx%)ykXs? z^7wjoFm`rnV!U25IQm)^FPfQ)=WU0Bi}zkwl!zqJIv>6_Tp-b-K9dP+zcQCST6nhY zF1C-W#upbn@aP6vTx75d5`Uh=l^u~-BR+$hP^^W+zG|YFasnAzK=H7{LCl)vLOXu$ zCWTH?*w43{ed-!c)l~9{R%0fPm3M$7t9jtGz?d%nBn4w7q%n2#ZP=aIkFSoGpr_(m z+2^-AL4Vd5I@%#q6x}2(#1e^U z($!Fj+{9*7xXi2|t~w0j`@4(N7^f)Gkt~Hn3Pl(aUdW2lvdLKoO}an1 z6~bhu5y{@K(7LKgbfh4Om8k`a(rcFpHJ;g+HdPt*I#l?w!Wfu8Z$8{CtH3E@Ss?K{ zg?UIMla!ntFlxY0ZtL9XtoD>Xc@o$g&b(&u^yPUlx+v7Ve7sLe)Zd`qxoGBc@Hl3h z+OV(N8LBM%g^{B(kU7V|R<|Uw$YmJ}9$5;vB&46x{6E43|WN*SVy_f}~yes=aHO!lp)6J28`Fo=k+U{kOsJ;$oZ@D94g7D|35= z8p&NY#-!42a0f!aG(b@SF=NG)gLf(5P$w<*BS$bNmtlkeF9j;~*lhoOLTX8s0 zkwyl2ITGX9aaE~19m$_fm*yQT3bn@jmkn$l|BQZ4?LNVMl2kBy^~jJ z)1)hFl=#Uueh}?l_eY7SWM3UxDa14b`gIcL)_i!oB?D`9EO^Jfy>zK%29^DFljtoN#D6;K#s8vi zm}`+mL_^1dciBT6b&%q8gLin~NEsWZroq2TO2*oScR{OqDf5_fnx&>62aP9g)T~k& zM_)8y!8`5X*o(1zuIm9B7;jJeYMAg7`b9$%iN~!;rzl^)kp*V>@sT+@p>W`4-ZCH# zH&}a8$t5pgU55)##m>Psi{rGidkb%3C4vf1AHHjY7*$EPMQlwwEzh2K~- znXp?}Wd05$oj=r?}{xET;h;C=@8oK#t?P?ogts?ahwQF?$a8=YuYLn!?g 
z7C{@@B53Zf2>MdMb_-q_sR$*q&UCd?B3%7_l?*TJ$vs{>8BM^s)DK*r6{DPb)@|N^S<{rW|eIS&?OK zI7f%i)lY#Fw|wb}j!krbMFcJR9zjC|Y`fr!VS#*va8EWtouDqwPv-<((09d`ASmD~!Qw0dHqGn))?4(sj+6bmqN< zU<~o(_^f>Hx>timi;|KB`)x@@>(%%RwMJ`L3F8p1eIxypl^jSyd~g` z#qYt#;Rj&o_Ni3ABp!s|2x%Smjx%51pDWs8Zt;Mw;L{t|^AAQ0rzbB2(vmOXbVEY~ zohRT$0=|dey-sTu_vpp-_`l<~)L;2+Y-jxs)cw!y|DC%%y#9aw{@?g(AUwPO=kFi< zwUXf^{+YiDVw3-ezyGDgW@k)Q$sWbjKdeV$ tugCtoH4x+jdrUd%k5<6-99bx475&HOo^EFc=^ph!0V6HebNwIJ{s&RvXN3R& literal 0 HcmV?d00001 diff --git a/test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_0_mp_rank_00_optim_states.pt b/test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..dd3add6050af3f9758f85dd3507481ca640bf90b GIT binary patch literal 15715 zcmdUW2T&Bt`Y#zJh=61ys00a0+@9&#nMF`g5d~DhfMfwBEm2Vv1SBJZf(bz|pdbnc zc4u}PR8&Mvs30a3(PO}bh*=(b?)~3$?yL7+)w{3iU$3i%9;Uze>oBul_vh{^DJ3R0 zW{lYXRK|;`hy}z3g@q@EOb-r-511CSI>Om;oY>O8l~joozL=@cgi-eBnE3F>@Z^wK zzqt5-_z=I~@Su3UcnV*_)JJ}_IU+hP&Mz(~AR;7|FBudZ5)u;^6A}_UEjA%)lrLo3 z${1_Tw7(F^%^PJ{_?upiFEurdFP*}d@fkZRDJWr7#jggwtPAhmpN7W;`lyWF1+R?? 
zhzt+%`i#9-bE*d#IB5pPU4Tv`BO+lc#L0YY(Owy&gD-LVq(AD2n>h|@%uMl z`0^=y1ydjCQNrK8@fB0}<6Om3_)4ZevZD_}*2MS)B!=?G2gOH^LN?7WAwE1JPR=hj zBrqX7BG@lJBq}aC)-TaUj<0O`CrjKYWk6^MU*%60Ip>JzfcS-fK9u9D`iPH;wc=0k zkr}On5`qJu6<^IQRw|ya?%^sndXvf5a23zwYvu$6MMp%01jUC(N5#o4hz$;j4GEq( z3cVa(YpN$-JD#uO$=CIf`eq>%8?uz0`Vkf4CI{AvH{z0rX@+C3pL{ONA{lzuz(+rT)!l`r4g zRV*MOAi+of7fpCnXhewLXk&c%FWCH|SNg4tv99D9(>T-Y_znrb$LnZoP;}JF@KFAY6uzUWkKU*w@v#BnQGTQ2eN<0iKzvY` zU)<>Y^NSgsf}t@9d?(jbADvOMU!y!cBr4u7ASfsyG9e=1x6{9nisR3875^>%zXm*i zRtn$Q)aMtL@oQs3_%6S1X9NdC@@I!xrSRvNhW)E&!_31h!v57eVgKrkgajYuQ4zl( z_=oJ@YWZ_h_^$t)NTa~aa}|&Kjm&Rzew!W>8$ODRo7;2AgfL#1L&7w7SJ^QwxrQyl zqelI@-2azr);KYPzpYu*t(+Zo|82$k+w!$Et1CdT)^3WhkDn#5m0gIJ`O;0BFHyqx z;#V-Ock{8@#9+ahf$q~eO=Iv2a3yB+;Kk|avS+6Yq}}oOO&C;Vhiy1#HS7J+$fY3Na3M0cHJ;P-pU1P4ze6^gzXmdkpS0iacJZ>yA zG*n4PjoC?*dgzNPLOqBTUx$cD;bMlgp9-v5pLeyEsm?(1Rno9#KI^bXPD1aS_7K z%-Q%V{06aPzzV)mG2?oCIK@cDETsMIwlU8A7DQjT8^bZ}AX3CW5FKt~xb{=+={;T> zMMvtd6I(Amhr?P;%vIxCMB7hekug73RPj=q>t80p*tM==#@w*wZYYjt`Wl}SCwG4! 
z8a3K*W$_fz@;wGxM)gYesoF%*?{-CF1H8VV`EBGWU0nv5u8F2+_ zGO3Ijz2f6TCetv9sG6QgSX(Cw2_(WI+~Np}wN6Z%aSFb2d?_?;mZC%Zhv3k*Y`XPF zHj`nIK|qz0%$r7Y=9#4@(a|KNW8GiihjOM6JJThYfUOE#cgYisf~-B0!CQid#Y!;z zqyXZXRt&ND_$lIAhAXc7c@}d*?hj)2yjf7~?lb0ovJ~-*YG(%hRuGq^RhjO3C+?Ha z+lkN2V@9V^o2lD3iJRJRkO-fyPfTAI%g~a&g#6*_#MNk;ktXLepJgYBLc5BX$5Gqi z){0VGo#`f+m1DRWHaCgHDfejgZzIC>j_+uxvRxe34Jjt=o>Pw3(UcNzVnazbIhBJ;RKk_m6khH;Nn={1u+2@U(@;G~#2 z&|12lsPIl?ikBvEw*(dmCh5uwrLzVF3$D$?i%QA_SdqN&l-_$xe?bN|&2q8e!)%A9 zxC>oaLtq>xlbHsDb(Vr#-eLIs%WnkuJw`Bpimb5Ti4`Qjo`&DSyqfMS1L6CXI#^Qp zY0!k^ut?t(O{W_#U?bZ*vHWF|n>K99Y4SHn#qa6&2u9ba|F4VrlN*4=nba2 zv4u`+W0;W-67)wxOJtgOlz6z&gNR8J;|>W!pk_%e5e~+Q)|C}8tIAagxuiwR*EO@~ zU6*W`4S^zt_}quLO7}L|?B7ZRr8N@c#5D*zJs(2p%3Q*w>pqbZdzdL&#uKe6aO7TS zc))Cs8zyYmJf`Po01;Db%yhlGE)d>S7Ts~Q6UF3b3O!AW2uF4-lon-jgSBF~e|V1* z-MYS(tNWmqzTXB3Q6MGy{~G-N2gcu*krQ?ENbadTW$ukxfugXgk3?irKG(xaT-38X zjVLjjPegoKCfXX;%6tz>6?PPN8frQuexW-H( zYDpeEWmCZCa;BA=rGZFKRZ@w;xM#!vd;s>TWo5zv40Om6J$MI3u8UKN1l)&vr9m zE~fvn{r{=|_I6{5+OjY7%P=iwslq{M8}9{;$}Z7@6(ew|^j;$NvOcZxy`8ullm@ry zsMGJe>fyDI(}ILRfjJn15gta9nfykw7@lX%?(mK8UAg zI>X>?TJ)oJr|A`P7$F_!PkVIs(&T~~?9miGnA=u}-)l+1dPm)Ssei$D(bqh@YBk1m~2c(hIh4r#)XP(o_2Sgonyfgh$LIn8sZzY5TyfL=9=iT+TNqtg3D4 zy30EFY2&AKNr@aID}PA1rFcI6(?E`{&3+@;@WmIlO25GK-^W0{r#xfeKP9^zx$DLA9pOL$%B32r;P2*%o$;R?GV@j$zW!fTi2|!!&a>6R)BO(8y$kvD0Mv(jo>?(&F&sjY z4ifa`QMlrrjZkBOB60fEV*2CmXE5~f9#|Og1JBqRCTQVA{{_BXjBG?0kB= zSPH%Obqc<2&Mao;xb@7ZA~(9FX(MgiJW;sF@*G^SLXLU%?h;L($f6I{9i?~Oen1nS z?!rky%Y>WW@$k^*v-BYO6JKR#0&6ogh}O-vFtPD19cKBKzF4x4UcE;T7wl}t2WF`f zVnspBmY}b2UuP1NkmyaI$*1Y1bK2ov@vl(pgbxv!eG-CK=ji%CUE=GBaNM+c23=CA zB{;opP+tX=@UH2w_qUabs| z6?M^%jdD?+GO-z}`1a}=sctc9N^g$jKR8`3|1=F(zI4dA9_2(Nx)MVm|t z!;QZ4=(63NxW+Ym_+g_SG^5MtR$?V{w$PI3dbklbuPI|5a9Zhd;T>XlP(b{2+ac^cP{tH#FQZ??+-5>f zcMyZxnuItzSvaxpEz?yQN%Wo1g-5N|K;dE;;YgJ${KLbLI9gJRv-mDZOi(5IU9yO; z%6W|WcYQ+H&xA=lI*VBvTu3MGp@aqt%NRpzeWEO9i!dzqBmQ(C2G0nsrXN_0(2^nc z^t`w>xVl)8aeI=Fzi~2w25+X&cl91XY{^1g)hSSLVDbt$!WPrr&y?g_X3& 
zsumjcUcf!!BDnh29Nc!sPWs0M4cdNZnea*1PUxh%4Y%WN#`&vK>Bp5X@#8M0F!VA> zCxvz3Gek@1Bdhg<>diRbrfovcD0zYVz9!+}(RV^ax(M%`C_&8KaTU9@-i!X0VgRAn z9yra6<%;k;*Oak-2o^tEjfQ1nC|+B)!P zFOv$q;6gn2xBg1uZP!E?eD#`eGH)|1S9lL=@5IqoF%$9fdG$iQ$u4wrS1nd9 zQ%%cBFo6z zggUcaY2y#a=+u#g!Uv^|!a;{T;ccf(=rZQ9&@C(z>fO)=N86^-$-$D0L|86n_o7m` zk6HmmyLQ4hKR$xuB&zMi}wb3_xT~!u6@T^TZH&%pw zv0jPqTrLOGmD1sJuN^|glk0^s%>jamMel{5)P*qk=v>-fCxNajjTh$bmV&kCo(t~; z+ri?yvPcI3)i7jK7;kK4YkYVa3 zSoG>CF7NURpZTO8d+lC>AC5@FO=qS-?WI&}!a;>CLdePxM;ac&n$Buk6qf{p!uP|zUMb@0 ziFNQ+O(Xq%c^U0Q)WO*ulbFy{RiedW8f+*?hw1qkBV(vW%wA{ClqKiWo!8Rf*hLg= z_11=I9*3N}*k6193d$oyQF(U7}wPtOjQKw(m|1L(v{d+Jkoc z=D)I4tPu#sneNl0+pXoAW7uUAan=UErczyIm3#o;AM$kN$cb4+BzA4Q$32c~;{S zlmqY*l`>q?%oty$G!xG`8jpK!3&)c;=Hnab>3GtPUATwa5Bzc4Bsh2597vzu2+fv< z!YEPRSeZudUil0S_AtNb8d(u;RLTmnnvamb7W!=9&;U~zLduA`Iy zV?32${!~dKdQCO_a!(NsHhsdUd0xhSMwUS)DjY_=*a#gjl;Rqg9DH)S78hQ!hC`cW z2-Qmz{OFPj@0w4Bl&=LW3DAKn^~b^ZP9r!VV+RdB-oUe`j=mqhTR?wMi0dsT;Nv_I zzD3#tP6{xD;VVAi*5MOjX6h8!w`d0ZnBWRSTA$*__GrReZYog5MiZLtbb>RCtYKF1 zO?;lBIlPx)0>OD>Sgbi6e*CTiYYN`tSxHNwxVbK@m7a@VI{pTqcC-mEn=}#DlGPd_n+7^=EbTx-=TUS8mx1IRMcp7(Ibq=3!KNvQ)slnLI#kh036kPRA z1_lJ#5bal%z)7vkphnsd{^PkcG)R+%Utc@G$lhCcVBSmoy!K3p#_opJTWsL|Rd(>y z9A)Tbz5?bhodm!0gW#EA4H&j(4$L(kz%^ZzVcPC&_{!B0_MbI?*qAGLd;K$fu($?4 zw9f<p8imRCp&Q=I}Bhz#7}tPh{z^{BgES6q{>J{t$V`sQnM4rW|HZz; z?}h#Su4%v^^tfF~Xf-FZ+t*%0J3%zMXf}A9e38SGtF@bfi*c?uUb7$44z~YKe}DdW z^IxmwipIl?)M~h>gAZNJ4IxkdGW6Fyg3s)ahxjZ@ z;-;<&bkXgAeiH@IVqGDT;ljYpU;`BFo)0ff-VM)&8pA|z6n4$s09)P7p{49Ccr8^O zI-FkxFQ1w6AKU+*`hRG@H$Kil3QzCAaa^niyLjjo_C>w`JMP0`#+q_?bA=;@&mE5^ zU6_w~rA@*T%i6Jf2L#yHeN*vIwF8*iwO80=9XCM*aCSvx0zf(^@ny`Dc%3lszE6tX$P-*kWz=Nd4(SPR@z zO2Bc610cF+3&>Q@z&dnaf=2u(mTi3>+&W^8WkxoG0ewSk?^ZKxr+G6tK5HsyRJZ{4 zo$CU5nTLUci5K`W90^jlIfMO^yseXN*!46cIYhbOzQNU)%4{P~)6nwaG6U;`@8!LP?4b%SQh)t2`0s}9ifYiM2V9z-zthi(Vuvhnk2Wlt5F6Ic>ai|EZ@cIFa zFSTOfsbbi!Q_Hc%)r#0quN4;EZi-zx-wv97hJd}M%GfC}F^uni4HPAA1tyLAKu$#= z2wSln3m(%8ZUr5}Ry%Zq56(JR*3uVXcAO%XEUtzfTG;^h^&5jtkzL^IvNu4Z@dBs` 
z&I5W4yTQ~c1)xtR8xy%b2X_3!7(25SnCEF>!^UsGR}(FaK5B*G4lEe&Xa(lhJOTJ5 zd%#+c2MZrO~{&6kToa^Oy&t~>=O zy|3gXwh$nsP!0>7au=k1DgYYVnZRazCJ=j@fHjBT2ldsLu(bY z5qo;w60_-9hFy6!52GKy1)KA;0TFMEy-XK?Le($nG5Q|BY;=r9Bj^;Hju4Q ziP_5^1dlq!FzJt{fjE2zko$A6WN9jB#*{&fHVta~RIt#Udf1^Ge*mA2L*Q-iAb9hB z05hK-gguQ_#<%}DW4h|jm}LDC>{Dwo=5rIr_BRxQ^##!{PhSQZ& z&*@Y2=47ud;BAjfM4k*hVa^`YOvLgynu zOQII6#BxB=(wo3fb=_!xJq2|VQ$e0P4r<#9LA=35&{94I)9(HTLMaN%CU8v2R){&~ z24dztOR?lRve?wXvDnoMZrJzBhS=MLRhYj)2QZvc0)F&!vB1W8*s14L*u)Ksu~Nw# z?2*qZjPv0FmXNp`gO<6Nv)vBth5J0rY{fY&V}&=SE58{NGt|YtO%lT{>$qUAQY*2^ z>)&EgMW->#FC&=Dv`Kit@Kn67)(B6ywGv3N1Qru>6EI#(K99}y2 z2e$iB18#Ut5~DFrTHMN$Tl1+TJ#>v{Ja&5ye*F}S4hPS z_Z-K%?kZy@khH?eGxd0qW<$yixW&k#J16QR7;9F89xM+|BjyFnR^|B(+@v{Pq zkE#aeuWbTzIuCMdKx;ijB!5f_WoI1M3?gA%9*V`)JF0kWrGq`rR5j5wegBVk3 z%=Th0@U^S~>6$MH1dsqw7-QqBd@h8FFoe9|Cc_x@aeIt;&R04j+H3CnS z25@1S7`Br<0M=Y82M?w5K-Iie;Mx9Ga7wTPsQJGHR-@~L!O@9W{^kU%-~@(cZhi>7 zrtSx`A|x;|<)dJ6y(HE!*bc%5C9&rx*Ma@3kARLh6R>5^g<$S(caVJOAyC0%foAec zkQ7h|%J#&9>lg_{)4YJPtS_)!mJc-7Oarv>G~kgT3;LF;fOpkPLB)t4D2tv2g2q&G z=00!%&hJdX43{C!gz+=M_rx~NP{VbOuikggvoLdT#;l5C-Br#x*fIeqUt0<`_h*4G zr!s)2)^X6G=LYb@pMiI85SV=D03Z|@5Ta!QBF~lpWzG%|crXQo=@DRoi3Iq%LkNmy z)PdzsRIu6;R{$rt8sxTlg7-;1plAMJkem?({1Q9C=~>R;T0$tuUbq#M-l+sfqj*3g zP8$TuX94fRa?s}S5fCF{*q6(H0QF*hkZ|@ruv>Kk)Hc)r8v6$7B!fXBlLw@e^MSZ& z8~Cx{F=%<)4P=%IfS`I9#P;iAk4nFR)%A+_*+qI7jJSti&d!R%6>v24J^zMVR>X4>*e5j%9qliK$h+1Ha-&gUg@B{$Iq8 zZ2pt@5xl2_9tBQi-(K9yOnlfaO1k35l96MPp{W&n^X_$#O8F|$kto3Gz}F(f9v-`Y zJ3!3tcs3A{NY!Hjx>M}S9<-Dqucnoenx_=WWA^j8+x^Yh;=^rRcesh1eLs(!?A^&- zdu=jV)dNXsv0V24g=t*TsjKWw4{@~6xCjOL%5e|4_p>Pu+o@TO7f7qy-$;kHXhi9X zQGS|bT!)VVsJtbY`(0!}-MYJutjM~8!gX`docCwYIKd!mak1r&5?_6?q~vlFVu;MlZyxDEZ`SM3wYW+N@UdiW`Y&xacBLT>Xj+J1B-c zeu}Xhwc9B=j8qn4G5P1$bP;Yj-Q>P=xqa3GH%Gk>homyao?o^jjy7wB; zD$xP7#{V{DqVeLrOHw4Xi_=!?YQAdxBM^MYNUF;j*dz9R{cG661fEo@; zLf>9wqHLZit8?-Rd2-ceYIog3?u0MxWMOy`x%6`c*<{^=8t0i);r*FZ*pv$7>Y9gs zimjmH#e}G4@(}gu!XM<)s2I{gjEmf?=1>(ucgQZU3gn*WjjkIhkuFoxsofdn$lzi) 
z)#P~yY0&3TaMcZ}-+U?>pDforW=bJxCv%1EQ*@xJ#LCgmPIqpUlOnQoTTJ=CRY3QH ztwrQ>5q2)8s+MZr^>8>cFvp-z32en%O7U!m08 z7ZZwfT(+VGVfj?sA`LQFvYyh1d6bf825OvMz`|RDXilajwM4Oj8*(%URoBdK9^6%e zj+{&pZE&w9^CnqR`o87Vr3=H{6R{=aJvlGxY)cz*EVHKEE=e}8SUXHgM%M`~Pg)U*>PQFVI~ zYml8snH*>%73{vSfv*&(kjt5aZSp%^PN?KET8KYq%N zk1AyxoMpI;pJo%Xp|?f4YbJ=Q?`;+JUDrjjO%^QCYJ#q@c8KOBu}Zn7q~EzZ@?oGM zxmQAuEdQp&dV7yX|a^Nvu_o-V>p#Hb8qE>gD2Ps`(ajjXg2Bp z%$kj-21OU*9ay7PNn|ST8xxh&EeiJD>{AGt4j14x5inxt;D zJQ?p_K+=_^+?EbcQcGY??$3{8XMNzZ@7m6x-OOBS0?DVgJkg`(Y?Y>d9xSKimS|I4 zv*lE~zXsa6T$d{qn@C<7`h*S^`m>7t%aMnY5@Ir0)~zm(S~xxyxi+6+r*ApKde%Q; zLyqN9`Fh`3U)$HBR)_UmsY&uE?rRg$m+WBe+%I#73QnL?SHsz^?g8TJfgW<+E>HCS zt~9#2Ett}DUc#G9h=lMlt4~kiDmx&bob4Ik!N5DLXSY zn(Ow`iDl0PkeW3zqRu~3*n~GetPy_$`y#l7*k7`ah|3wzmMoBGcX24zWuiHnpT9tK zxaTG}vHO*%ZSOD<^70CqF!v6-cv3jh6RDwuj5%mx`esBBHSFc`SM0eGn!V?sP4Ea{Xy>s&_kybhZ_cEoYP1pXKXCn=BQn)qdlViONj!dF&Bl+1`iT?vw`f z{eC8z^C_EM@cK5{e`GyMJwTHNK{D+4eKQbpk)bZkh+;Qny%5P7Xp-M1%;qjEHDV=< zY$(0E`RshB2((+Ji5=H&g1YT*kOk*;(PZ~^%#p+X?2TwA@@^xCTwj(enmaU)8@IQX zS!L!c3eWN;Uk9X-uPYXiHmPB3q=N^Uq&1CvygpQ9rZ$Fj==w>3b+g!wt=6pZNVRB4 z%YdwRb7o57e~7lbW|BGsj^y0!7r5G+)7U7n=iK>|^U0NJ)gq^G8Fn~KLLiyHNi=&n zhPOnpjTiZ732$-1Cf=LaT;Aeq>Ac2eQM`@4!Mq#UNxW4iTX<;!n|T+X$M8bGuIF*p z*Yot^w(#^aa(JFvn%fE~{d>2N&E;M;^yMDvS>Gv2ItDiPU`NPl zF=Khj2fL`VS2Cz4P89mec2iF^+ewPtA+m@bpvXJZP+FWg&(mg*{O&u8I)t=P>A*6i z7Bx&o+&IOmJ>7)(=Lgwca#GC&A4dIU-&0eM%Ck9pbefeeQ^@sXD8=0(*=*>n!nMk& zK>KR|($%_zHq8)_sKlOJoS{ektRvY2f22`GM#bdQo_nHs+N;qx$LUnnn0cgZ{7v*` zixv7VuZyJ4lvAq}W}-c7+^N*LOIUl8b8J$=7;aYm0D9RSOQs%uN3Cm}fY^C@RO71} zbW++J8C`u(82<5rz4}4{>6xyh$^;73*I@xQx8pL^8F-qK({G{Vf;ni$KpKT-o7QD)=fb*mEx34QW$F@pGj@AC)sUl zBG9?jyHQZbUFw3V6Iy;G6@~N}G~Y;m$*yE%(cPg6>ejLg=-|#;uGjOo)ZHtllrJkw z3N5x%awr+0lO7b~ZA+O(pQN0xlu{-sAxMhrh+b`ciDZsGK@+atASKuTK#>}$l>W?Y z^847kq9nH#^k&f>HpS))wdYDcq#kdGL|j#)dpo^DS^ z&YnBS{JYCgJl}=o?A}JcAkUHVs_)s_(5cP3cRo`$DQi@4_YB)2XhB=1LZn5UK+@{U zNb!UN5|gN6tCM$8JqyMoi?l7|^VJs3%X>$toXAB~YGDJnOF^IB$@4l+mLSfza%$U#^{m?T_0*GxKgh`5t7z`6bm{;b!4`bEOD3C-pmVNI 
zsTR7K%=5~md@po}o)S{LKyy{<%>!jBWF(IIw7&#(KWrjnGgeY*9KDLF=PsgFKi`dvMz@Tsd_$?bD|$rT*G#G4?p`Ky ze-(Gjj)&w!%XqYlNTW9QrBE|c3q?u`Zcz@KUyAnqaHK2`e?(U%DaU1gRHr-m8!RHB+EX8py^|CsS5TGGW6O;&4|V++jnb3 za_>gxqnjStSEY(3^eCeh){WGa>LQkFkclQ*ibPFny&~szMrd(?n#ep^iPav+qMp=Q zihk@8qr#>$6uNGS7H4^oauZUi;Po-+5)7tn^t?w;^!u_8c1$N1MK7l+Q>EBcH$U3C zKABWFe8F=3PT@~irPn6#%0iHEuEry)5WQ}4oNiI`={t}Izt_EJwWbO zqA4lgZc+Py4%%ez!%iCP;a*r|NqX1RAnLjeyRq^eS&wJ3#=9O;FQ`|fzD6A-jdr7% zDI}WUs7cjc+{fxg1(SZ#$53pJIGMITg7nDhB3%@GsGcfCYSp(*sHP?xah;N>*OzyZ z+f(0CPV>f5^A?+;oy%L$lb6|4p=N<-QSTL!v=xyj&mSj!42#&us)}gJ=yvfRsrB5) z>3QgN=y|kyo(;*gjwg3$=u_8hkD#wtb;%#MywE9nC8{AllNTrn-lrHv(W{WL%|BK= zVBh4;MjjSz+~MVwD1Y8LYPUrNvV6^_svb|Ix>Yuz341rOGJ4<1ZN^*JLyDGco974C z-AI;w&?R6W1??n@)z46D%S6iLu?M9cybi5$_8`l~U1B>Oa;Y80-`EW*rN~9g3;D}8 zut|G-xcxO&%;AZW&5gZ#SfWRc9D8#-b*@U5a@{M1f}K0iqDQvW)2IQFn*0N5^Of^d zg=#JpyHk}4D`_Y1yxYNAm&qcX*DpkiB>a$5vpSjn*c+|bG!MzOxl)GIE~oU#`Er1TGnp^!&SZ2V}9 zVVyvT=IYI)WDg?Nci3PRHE&7l)kNk8v7}beU!L~x+As=iJe=>I&C(fW{?!u zf2}|DD}J}}Z{^=JeE*&MuV<>p{f+bFXeQoYxu;wGJMUlPX@BRvqyBHa*8k4?S1kX| zt2y=Gcy0cj_pfODN=N124m4YX(UA|g^-d=)!ygq z)j&lNGDhaeOn4N+dw71o-}n1`|JVP0uj_r+b?x@Hi(b)S7Nag>x26B|2L z?Ell0#Z<+7qWpqGHUv!f_lfbbj9eGyFn65Ts(-WuiFm%4srz^ldqiYRNO;JWfGDr% z7@wE`FaHp~7`}KsU&7Q~Uep~H5gqLn?dKB~5XG1D^A8A!jE)Ql@VAVLT`%Gbundf} z;#mG0k(`r=Vc8#gIlh!xB7aOgU)o((B*`yUq~do6UuHh{SUg|J)LlmOVZf$HFP{xTd}Y6w2oYqKUa>JDVbOA4Q31ZO zAz}VrF#+qNBci-E%#`D+nEuTYEu!=Z3gD~$%_8Rz7U2`K?C%fd_-gLrBC#|0UpH7}N3d$}-+HY0djGCR|BoJn|JGypTaQs@@NaRG zV)&C?`BP$}-4#RvBYmQL!bKoP#70K*jo0u^;`yS{Nr-yFqWEU>c}skqdxxq=-IYY2 z_=fmId-+8~Mg46me`>}bIDG?Re4_c&=JVcM`rR(i=e-~9>C;G`mc}=Cm;7zCNVLU# zWWPE0$&76BkJ^m^Awj`0Uj6}oKAZWL|C&9KBa5adAd)}bxl-xRfc_X5&7ZM`Z{;ZF z6YCS}uK$}RWPMOrfS0Hp%E#{u>|g z8y}eQ2OlDaPrCBCK7a5Qg+j#dSP;#!jibaL6#k{tFGA$^d}<9JIZFLa?jP+X^3iCK zL44jHY-U(;ENAm=WB>T8sMjxIePBose@;AquBp48Xdp3BJ|XM9ME)+)h0tB6GWf#APn z|53|d6wi13??4g(6SR`iqSE zeJ%O_U9-lC8T@0-nm)r}uFij~SpQhQR+)5Ga|b6>Fug8*g44(QnAhqN+)|^>WKx+j 
zv;H{+Bc;p*MyZ`7j~7Db&74TCnJ3BImC6(pd1;aHp_<&9xCU~HR0UJ%eir&mR?#wb zp9sy0>?R9`VyIH6$XxZyI=!~&D7lRDosg?6Vvf+0nChw}f~{{ViFK}p4EK#!`RSxPGtEJl5kJ>TPXGLt3D>D(`Zcr}HKZ;$IzEjXnB~gk zEmb8p>A5j8PrK1xQ9~~%cL`T zn#g^;tciZPi&VU6OblhPjLu5L3=I?!^ZnP5?Zy)slaY_~VC6A_xi^N`!}bzKA~NXu zP08Tm*k^Qv#E+&~pUcP_Tl~p{iGxjpTi?(UC(aRh>}leBoeUWl@|G@~f1N(;T-$WZ zF_7ppup*ao#*^QJ&NR+yFJ$r$mC_r!$1oSf5Ld2pKl9y+6i~>R4!Nt#J#T-PY@BzN zS(6W#a+Jp`%Xvr-uU$;8ptsRy4o#!ApWkEJTGo>0or>I`zGHNj<5=d>xBc|+;w0|i zyq%=vBr!pK^faPt&rI_4t8s!}J6&#(oF3_^UfRUtPZc~FmEz9%Sj^~-d&n3qieh4; zmoch}x1d~E0Cx)KFsaz~jyW~lN)to?*%7KIxDaH{y;w1c-hX=$`744F)cVfhZoMUB zmfkzaG`pW6G>&!BCmviUR0I#018MUZ085yM`x@vO3sw*oQIAN!hi(jeO`37Dr|F-o z4Y{}KOKID|AR?o33i)AYA+7VupDY)5r1?>1f~=1%%$%=R$piM&IhRb_m^&r<+E{Sj`ihAngSehj(+swm)IECYT5U0`tg+C(6j` zcXJr2rfP2IHZ8L2*gajTot21o8LXge8-R*^G=6!E#1L*n2RxSQa_lRB^O9pg*LiN zWe2HemCAJLohH$ZYfV*BsZ7a1cV=DJL^6jTM?UT`7i@n2irhKrA#?nP1nsn`nTX6v zAdfzNM_Y$zk!`#aho%Jr#nw&SOxMYz zc!W;lX7duVL6AwSC+s7CNER{A=Sv7^%Y$@WwWHw7<^)FDHi27JQQUYsLrS1nHi2B- z)X0bjJtU>yg>fg`d&6)lbjYq3n3og5PR(uBI-FuB3eKJhO zoE7JeNVo_NoX;h`oR8slIV>cvh88ozVS6rqXD;}BAe7wdzk}g^up}ECE(knmX{P_{ zizeUj24Y}sIvHJ*PWtkXGqdPf6dqY!hvRn(cD4`9!b6=>GzwN`WXu3 zk(E*c^TB=e-hh1aX^m;X_MjYdT5#ho?y?;+1%lXeC`fk9tc@Aj!Dcj z7OXh9fT`lNH0^$=CiphIh0IFMBRW+k(vo@u-1!$$NcWGGWRl4jn#`ynRBPWcn}Hcu zLq(mt>{%a^AL+yN{A^|%aWG>Tn8SRZEz7K*Ekl-Q+tSA8XOmkm8w=Rk$LT#@StNNn zli6LmfQS)GB#T#R&@XSuaGeTNh&ylXXzh$$q{<@|(#=+$8BNImkyoD(*Rq5}X^9hK zH_nF}Qlu@2iFISv8>SPvKQ+0VOZSi_*?dNR)dD!-)F?CTUrLI(7s9m*a!J2k5rUlY zj|i&yGF|gyLPo5J+GJQ`; zKbsu-Y1>C~tNlazP^LSx6?>9BnfK_~^h;20kwI2?DbRB=8i|bUbC_8*P7I1%KsOqe zGVMQ}G9kN`lPkKs$hK8nxZCmx`Sj64n0)SIQ;K{uG->rFxKA~S?hc!##mB73`+FVe zx52YuX^nH!jH{RFpF(l6cDxC5=;b+LPKYLB#IYw9oDQKQM0M18a z@kR;7x*j^AdzkLtG@h*4D9`9s)Y8~fio9)YLHAo|HM#EH$TX`IkQeKFiH;rzX3mZR zQtX#e?3PkbHJXfYUD`m5Ix5}hkkx) zCvm}DP5@7CXuZ@1SJCix8olIQzJxbR5=#X#kPUpI4 zXb9GK-Dl=}6UAkXBA5T@6FpjfnK{I3rYF6;#~ge&N*e3s5JM|2k;9Y52m+OskQ3K? 
zHGMY^V`eQ0qQe*VFcy^`$x3>a;Ltv4?&Zs8$!p1{nZ}5PWQ6oj(m}?9d24D<@~b5T z<(K5S%&JbZ@$)*tn?vd3Q`b|(y!KY|z@`c^{&X^99+^XCNo0`Y92N-f%5;!IhYos% zGoRUVe-kNla0@LJnMiuZ>>yrMSqPrZwxVwg^}$CxALjLTQzmq(4rw&LiqzkuM9+T- z=yKD;WY26)LqK8`kvuboZc~VHtfz=*;v#xXm0* zx=wmdL0qr+gUl)FAhT<^0_pX-m~o%-o}^SB(bku`$VW34Gcf--G4WFbW6@Ae?muB6 z7`p1l3~aR|o6?kM_1;c8uEK>AGiPug;j%G~RB#I++m6UK^+Ibxe8`s; zldYrWWloX>BPxvWOd{>MY6TIt?hN7VkV-sl&ZOB>%LvKGN1HCRs*$9IC2eojNfiGW z1M*zAGoY)7v7Nm|ASc&NPu(aZh*NyuH0y&JceUexC`9u?A?bte-lKOZsc_b+40&zO?4oOBr8t$f#~;Dq*4&M9YXJ zk(V2dxxq`zNh*O)zTi%#8_LtkZOy^tHLe$V6ZOEfIvaAC^C8kz%7UBb6-1v7K1^4t zlH}Ef)1>MZSteYif>1l_!#$KyN$eldB%iCtrA?Rg8xJSvHqv~B`< zf?r0?y(&7-D7O&vMh?*9PsA|#t^!h1;SO`|*)irm&ytxeyiU~p^^tUuZf!bx;~*gw zVZw|qHfEA|MP#)W#Wec!n)pru%;Ef4@_lMA-0Jz2Ibc{$>W%x%I8D1oZc0#OE?=7~ zkhC>rZc3BnPMAaAT~`H)))fg%lG5k}W2Cro{oVBSzO<&!za$yF;t1W>s!S&AyIuj%Cbix{&DDrRp4WCe%V;9o=8b(x==KQ+;DPnjbxb-3Oq&Z)|TERe2dd-%rtCmB~30qB?78o!F zgHr{E$5;x2LmzPW>*p|T-w$%H*L86pd_P7+EXons3{K>xgvtq~E-WK2wHcA6;wR}- z9`?-W*`B6@eoN>nM>mkJ+GW%;`Q&y55;T;`JlJgCig>KQz zK-VMMY40y$PskAAmw1->xaU_7c6MUiNjuc}p0{B$d;JuKuLf^&>>jYZW@_ zg%>kqQ`EE|?jkKE??+xd@snOARYgy6J4ifhzeV#C&0$lPKjXXCh1}J(i`eIN75X^( z6E|+2YjPB(&|WfOO}}$QaS67(3!<#k@9Te;BeMBtj>yVkF8tr-i2j)|`saUb7dwf^ z-NzuGy$LTr-wpMv+oAABp=b}33r&=dz;?0J` zsyyH=K|{HS7obLH0NwU90M)9!pvBD->Sf&l*Hjs>zGx+|x;PzhB36Ro<0T-`pc?X} z#=;F^U%<5;-#~ElPWaPb6CUdzVL-A8OsKyI=2;ZN1LN+2!37t9QpZ5; zS4miNhlXW$E#Tt>V>p>A1>l=Clp<+z-plSU@UxeKOU5|%>nkB)8MOqV@Qh`z~@18VM*p7I4UCvXL_ZA z1+`P)LRV){xlahb6<_2CO}oKpG9QRv+Xn3HmxAhH4X{;^4VEO9f@?3jRkcAqn)$n}LiWN!aD40>f3G zgGpa@f}xO$z%Bkb2sB;?qH-vB;lV?oce?`oEOG%?iar2C{u{7p^a{9>m;&zQ%mzwI zR`A=x8W2uff+a@7fZN8H3Agxwu39HJHR&WXTW;d zO`!Q@H>fC=fR>$>uz$vPa8L3kPcl_(K#v?rezhuHfd#e)l`MJQJM%--yVX_=T}17 zcP7+vJ^@RGe$Z=m0Mx!`2%lML!Okzb(EFkrTzmdJ%-ibLfaNj`y^eyG#*1lZe{(KZ9QfENf{_kMmWisdq%mJ0TwLoF%CD0Qq z1iI6Fq1chVaC)CE94=~v+gN;48i_ zv|5u2_oOa^MkWfxt#WU8BU}w$wcH7(b3G|@=_djdOin-E3F|LbONl|B@J^A z_kpUq>9A9g1@MLzT=}UA1gkNixJ3vK9{&ii%NSU7m4a 
z{FEk7gx*eok=3p6^4khHB&9_JOiF`^%vI<%=nn0>wnERnAu#e;B2?|~fIm-vhwm4S z02AFCpkw(2m^oPn2AtRf&Q)S?sOk*(wMZKdA2|kMCk+F?D+!=VmU2+h8Bm=PmSg!F|$?Q0IO_DdtYa4!n-bu*xQ zWh9*9UJiX*XF%hNw_sbRD`7Mx7tXQFg5oV1kXrc^zK-#RX05qqi$FW!d-0b1{@ONP9 z_bBKtEQYU2r@|Mvr@}390l;M-4PGg!0#j<&g2Mw6(4(&$NXecA&m#?Cpw!SzHt_}i=q)MR~Cl%VypAE;T>A_=<3&6AVYtZ+$E}XdUJ{W%e9r(>}2X}bS!0=HG z*u|d+ojKA_vThnYpw5BQQz0}b#>3DOEBM8&6l!s=f^82Zh;F|ZU|Gv+;LMeRYP<4* zsgf!@>oyBsE{cTvH?ML#sX+%It1Cz@o*HCL7DD*aH_-z=&c)1C{(9IH<|-) z1e}1e13Tc;JtyJFhe-HpcRDy_`V!0+Xj~PjTssxk=^p{x{E~pVRxOybQU-*%#By%%BRIPjd2oza1(4e}hqLL<6VBl9 zUa((!K1WAh4J@3x8b}^^$nl!k&sm3iIc?8^IIGuLf+yjBaeh9T0d%SXDC3!d3-heN zV|E9ac~=MQ(eDMbw@HIXS;^pzTMFO^7XpudBhZ?(0u-zK;;2nn1{Mas1FpZ?KxF+i zXnJc8DA}?Bo}T9iE<6qej?v4ZnB#Jg_t#9=M|OkiPcJwRKAr?;D}>->raG+Xu>;j! zvan**0?rJW1am4cg5j!3aBPh^OffNqyKLg&qY(i}-Rlm+L~BCO?Nfm1^oCWFXF|>P zN1*<`Dr7!ZbB-t0aN>T^oD2I3IJ}dX^ZY)^3I1*k(zi?l?<*E@N>&~Oj~^Cu*6cIk z7&fi}PTm!qjFrxu;N2%UK31&!2!3V|JK)Cfia0plc zwi!*}jNW<(o~-Z!rVY!0*9sYUNPZS1=4QYlt0p*J!2s4rvG7Wt3k;ta01F37VBh-L zu;gn8q+7M&xQsF|Bv%Hnjok*nyefdp?x?}p4$e?8bOahvw?NJmS)esF0{l2X9!8&a zg(rHYfOphE@JdS>p8K#KUU3keS#n4yz4S1MG&={np*P&JVh86y92ZpdC4g7fQb56~ z65Rcy0*HuY5G>luJ<+|%X?)+!8K)S)nU+2VD89+yfQad!HKCldd8-t#nEV-(K5~LP zcIZQaeH)n5G6I-K-@ygZPx;lkH-UIgDHtO&0&H?N0ma4<@Mw1caR2cakP}M=0TSes;l zcb|2^w}cYlzb^)Stfj#U=Q}|6RS%Hf>k9{B%fXykmatJx3!>e1pf=4P_#XVl(X%-U z;#V#KvGXp8;yeK?xmE`zZd(HMy0u{gFoAA%ji8$(VOmKS&^Pmgowrs&twsG{ZhjJY zY-<3X?iPp2K}*5?IUhms*9egB+7CkJ+y-_fc2HvDG&tdn64d@40)6}MfsH1M;TErh zU|VMah#fcw41yMdB>mT5=KExz=y4NB7*&B=f`{O*)HNVOO&l7BPK0scJ>ZdA47|Mm z8%Ql(0J)qapy**8_y7-pxXVgl$>=Kp)m=dwc^7!=^T7+c9rUK`0pm*aq3?brxcA+C z;H&i<)E3e}^|~iamsN*ht6P9)R}xq~I2GKxdkXAybOkHY+CehP1~Zp-fX9tPoOx&V z0FSsj5c`1xLMNnyAIEnAKQ$H{{%anXGQ|e0ysZZ&^6qgG(;?9Mb%}HEb_&=lHNc^b zc))0FHc$^SftA(@@Il87@Va*r3}|lv)7Od4W_k`#rmqNGxLgJPiZudax5Z)XX-{zZ z^Z@v>G6f{we+PnY-UZ+6=fhiyX>h;xZ4mL=QnbIbhLY|@us1^v${nG>79UBtH!l-> z_0)$g*AhXyr9PY~n+yzZNkPd~iXf|~71+ejh62+9;H`TB$l2`&z%+YXL={;~;z4Y;bUPA^3H&87OA%ggdA2g99@rz;L?`II7?VJDV7| 
zGFcA_+w|d5?@!=>bQaJ%;|)*lJqZqdHh_BX&Vho%*3e=3N6zWb?Ew2`0);yeXg|sY zU*fidRS#96N>&h1d~XhRDp-SshK(Gb^6lU+o$H*kwpfss{eWYBCqfi25UVBx?ka7)ocw7$Fr-@0A^=}LPLU!?**CD(I4Z}J7p5G|Axw@{hkahoz&OweE`D|b)XZ>$Uqh5(PqP3l;)Md|AH$rfhg(4Z4{5M1<~Eq# zuLia&bpm~n-(;9Qf?7Y*V2N8YT=41$Ja7IP_OCxD+9Ni@y|x}OU9{ijj93!4hYrIf zniGgI-%3UM@OT*0^iZ^Csf6!JFqBBqCpG(h(dc_;4feWl>~BeMewuFwQZ~nd zyy#3gds+dQ=u!@5iM|We-6w@EnSO)ACj|bUi3&6@sC#ZIL0X(B?3JzYm z2cEt?1Ionb7+bZSl}ZICQo+e zP%TxQA1_Y=E-Met1_p7mM0>B257O|Sp*e`}{{-4}4+8zQ=K^#f|SLFAC#VfXgi7Ay({INAOp`w9DnkW34p3j-1DsSEFm1&*mBF*_YWwP}u=SS8N z)`)Yp)5KcVF2Gvu%2MkdW}S6T_Izutkwu(GI;xz9sX4QgD*LVTY#vz8wRPvzb{lcx zM*gz?oj-b)m%Z%&B7ZdV|0RF4c+efq+Ej$^yGUb#&tq%%cTs0`4xyrBvPkSjKH}FI zQW|K@v)kXyQhWUpDM+6+*Fiyt%C-#r?95VTr{s>CGwYB zjb!4As15Esw4m1>`{E#UykQQT%PvJ%$_j)R+B*c(JPi=<+-Ejx-CcHJbp`e7ay5JI zs5W(EaWSi(O4yuu)+!7r+fOyX$LzN^twPJkM(B*t1MhpDjoz$?#_ifJ$HtXYikn`zas$)PJFEnvRNgjOUsnCrjB;|Q+$BmKP zie98?XO3MqoTDxp`3lcfwBzZ1d8q$O7G60y5#?uWWN}Fe>$FG#yM(KwDJ7+BQEds1 z-CAZd%d{IExVa4VAX)rLn~O)P0bcUO5yhGnP*3%3@JP-(o0mLsba(Yf%Hm}mA)|HCk=_;f3(^E5+JNiR_HxCr)s>J~H{J{$R4ZpS$llDN1@4w;!9WIbbxaYV2h zjtZE7s`z4P`#x1PV%98dlAnRB%?DWJ@Nu~N^lLN?9A_u2+m6$w-$nZ?YpCceDab}R zjGa6_u)^#-sR~s1@oMVzVr|4; zZleSX3-HpZjjW^BY~hEi9<1rTxm4P&LpWzFg>uPeYLauiP2;LeB%wPU%U|ta*9uyY z=mkZafx=(c{mg4tI3vI?MPsHqBm)I4f)<{hbVZV)) z=$?1GjjnqHx^PYnOTMi{C2n%);_PiU`28#tC_77#5}=7xzWWNvfIDn=k|BP*FPTf6 zD5EC*sHBJ&`>2oxIXwLyg>|m($D4L&;vlLLOHEL~Qq&`xv+-)|n{#1oVs;YRxb`$Y z+_Hn3u!W1~pE6|8Nn>hWU>0)QoQ<0$XAjIT}C+v>p`RutjdhCam?U++Dh#dBJAiY?G>!LDHcD7XW7Y9Y0dcX=>Ij_d1 zYwOULv0=Dna2oaveuQ41AEq(};_+>{v+T%>&o;ANCt;mFdsLKGjjde7@kn?Mm2DDC zS=n}D#pVJwLUFmEBDw_~pbxR6rV;!2Ng+O}W`auBDO08X2ieIg(@=TiW>k269a<8R zBAoAKhl>dd?!`t=B>Q3%uNxAwA7{*FWuF(Z$3pku>thU%!#0c_duLH@<0r5o#XD@~ zT|9!eE^1-dj*LM^?=3}(H##DT>?T$v%^rs2Z zXNx!YZD(s-a&XJ~TU2!ei&hO4iTvakUhS$x-PxBzwXf>5scMq6`7D0LrgPI`TEk#I zc1>MPiO(xwS7#XE@GDQ*rPmX2zw0kbR(}I>++&Z_|2oF5du}7rpMq>Rofm2kA)DUH zY3#(*nbfK`25hGKR%$>t2AA8Ov~hk%um*OzDE#<4_D!KF>-UjBcY1gzjhjft&2FPE 
zuILgL+HYc48$}}W`$wC4r_a=axv#0^d!6u1?GV-?2%-wZeB5F)5j`t%!3D97l-AR5 z8^vc6@cKak6+h7f*)e=pDQ^q;C^QjIIC%}FWOcJEcZ@^MywB9kIO*v0(T_#U<&thKv(NNxkor`&TRYAPdw}N`%P=E>#Y#xj=d_}p00`mD;<$xTPjxejbjfny9E2Jq!7pUE?1&uF)FxJ%(`%b z(V`26xGqi{>)B2~o0q4kPAxCr_>#F}?pK0^(J@ci=_wITeemRY54^F@v3!Kr$=I4SrY5S;- zd%N&OVi9$#X9aqI($K4!2dEv}9^gZ@6NKWWw~?x$HmdDDgGU~S){|!)C|tMz2XLq$});DFxJ3kxv%*iu1U)lzD{HeoA(oK4yKB z(P95?JVnv~6`DEY9pm&6=DcFP=@eYwbP_-Im%#n?Gw~Y79^sfRvb@&G6Y*JF2Dv{N z*F1AT46prBh-!JS(Ose$srQ*+(?$clb?s)nyjHUL;=?#9xh5Y=COoIsoxH@(f0v2l z)&21Srw-~X=c~}KcaR#sa0f45dXZK7;ftPLlfv|cS2nxPMkB|Kio(oeSJ^zV8Weo8 zok|V4Ym>5J0ctjO77))GkhFd=YSz3h)SYmU9i3Z??mj%jMmCONC$7uJZK@K@>G$Q? zz!Fxdkb4==IS)VD1x?$n>;PEJIzQ3?IsR3s+m!t5y6g+ld9}cl>LG_*s zvAb?3mUa)sPp2m1cHgx~-ux!6Fgb>9lOg!>-o^Ms*L%vQ{s8XH?_opUufD zq#O{5gFkPebX~{sEF&86y?K*(rJy?c{@2wUrQJN`y<^Pnj zIo^PVYYeE$g^jpe=?1b|7LGh4SD-mVv(YvSDZ#z92vs*tMX5#%w%t{Y4zH2K*3UPx zN9H}J4EBVeG69P-^)69D>y;?_y%C=KY8tgW{xfy<+eYe0i3(01Z`E96SAas|mte`t zqiAaReeAGQpVAPsLT~N5P?xzE9+iGWWocEi+^Axkv6T~%$;myCC3jPwz8g|A z52><(_-rhB@E2>ZbW#|yuY$eU`WoN%wM0kD4KQbaA3oXe9W8a*%^H2n#d}(Msh!i! 
z&~_Uu+_aNGtB%RD`r;gPWBF;DagTN)kp&0c}NE4JX( ziy-bbPsf_MrMSwa6GyMug3X>d;}g%mW7pFQ@Xc;xeAdPt4?LfMSI;_%*HV_a?dfdH zZ*#>Lhc@AXojJH<%M1KAw+Z#P$fMbkq0|nqTuN}`HQtwLkG$1lDLDNi&MPmay1@yy z|DAxf=oVpJ8h}nX+!i{yRbZ8SFR0JOYUm8yjZeo`q8RNtxa0jaw8J(GPmJ(E>ou-m z)#xc`hx%ERxXl=Ue^W>4OixAZ>S8QH9K}6M$Ac=DP?&j{U}H%s)kXHOa$lZMQumLs zA3_ss{5ra@@q0sFOlcyTJD7r(24$nox`ilObkG6AYN&sm81Fh(KsFa8(D24&wCKkL zEZ%D<+{iq^yB^L!dU-d6T8k5rnY947W#}U#|FKwXbT)=Pm+-G~vvB&heAJt|7>z`v zp&R*XHp(|*@v}@RT=9H8IsJYadi-V>drgx-r(dhHHcvgMM~*A%a#GsAj4HvcFc02klF3svb1pJImYI@BIeI`coCVD0&_G zcAyw}T%hq*e-5VZxv{zpS%)8}wx*2IsG^ zbFJjqm{^*PI&Xx0>X#!j&;T1%bFh!}8iBorDVk;+f%i&fv&u&Tk&Nq7e5Ab!vohh- zfswg33TNxt#j8!w9PSzFRTIT~`}H&R?N$;_S~Q>9G}6vWe-g#}JtLt>Ba0(}Cf=6i zjlMRC@jmot;Jm_#&0n8w!G#54)XSM7IM*s~KWJInGUq>|7c zm+M%W%*F+Fvb>4MJE&c64`GY7lkg1rw`g&pLbJd-1y{sfLm3y-@uLpW8ROPl>ZV5} zW+M*ZeGAv3o{sUt=+G+GR(A`Q?YM)2*Gv_~WG7nt^#yxw`E_JD*O8s0up3_-RmED8 zV$^j9WOH`01h3d<9e#79k2<#{0qMNg5E4s`*tYD2IFg=$w#7sI+F&ZCUM1rPj~&q> z*Ih_{(N$D=l|Y?yN7epiV!4_;!^cMYPW1#5BN z2@{;UU5R(?@Ogav?E|(-vJ^d^U5fAmuqi2pc*9UIHO&8oTI|Jmue+a63%^QJYf@I=h%#I1n}RGdP?5IDTHL~t zLvyGlj?1ar;YX=Z@3F}6QW7ciT}IHc)|q`dbpsXHpMx(iebdAp@Wr=dr%?wYN_a)D zPVqiIKg6qXFXc^k+RxjmbC9=gUk&f-r4HVE>wMl}o3p&h!|l8?CH1_9pH;l}_)?yM z+hLw3?<~*h%whJ<&N7=5pI%VP%Ar(Mts(0a*+s?MY^5mvO7z*dg{zKrSYw-;)NKbz zs#L5_*wS}^9Y5G1d}w-_RejpX&i?BHdwgadu3DUoB2@IL%I5@|(VdL?u53lQL;a{M zW;=GX_eIOTHnC}@e!KgPD1kg=IRa)FS;!Jk%AqDzYodj_&r(4Ll$%{IA>odzkvOE%3D-{>Viznp&05wP;SS?K_M}NFHb0xq zOtfl5${pR*yXghgQV$Ir^*)hR-FF%ln73m70y8wr=R8$mxsvrL5@(H`+ThNYXIbK} z8)$=)C3aBAN5NBHvP!DfxOz_t{ww4JlG~z!IQR5;zw<|Q-IHDa=lqe9825kPKl-nF zKZm*C|Mmr>S#rNMUn>t<@w@Eu_qD`P{P*pDYI^^@a!&bo$3NPCR`~uW_rJeWHSQmr zXGE2F|IR&q#((nudp_--ymvJIgV*XmdH)^DfAUT+`w!lk|H=FBX#6fc{a176{)_j& zuwUXREp7f^6nc^t|3mw`*7tY&@B8%M1-YWuD-!X$iTHk3{)&l<(m8+bqCZO<$NVnW O74eJ_`%V8h?f(LA(Cvf( literal 0 HcmV?d00001 diff --git a/test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_2_mp_rank_00_optim_states.pt 
b/test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_2_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d991de1c77684f8c110c5465676c2118c813bfda GIT binary patch literal 15715 zcmbum2{={J_cv~yQX~mQNGNSCRQ*{7Aq!-l{{7ZBn8IxsrWM?R-P*A`G?^#0iIfYPeCgq{Fj9X`u$Um zSS8c*&*~zV2S(%_yNm$;m|(yESc|w{qFCL` zQ?ZZm@7l!u6U75uWfH|2W}eD@FN0S^`S~vo6>A2@M)pBw15284w)nA0r++Rn&Fm&$v2O^!7{BJ)QKiX<`dc`9DVYiME`I zr>xxg%z6X;S8aT7cxYIxUr=zM|4OmdKfTu%$bIe!juMY^+o|z)L4OU55nKC+ZCqvi z(&O zNG$q`jkT4)YP@(t+~0WZn+=SN2ni1rPfQd$nt2-bEhIMDKRm*(FW&p~1o+1WhWW+x z<)2?vUkZjs#fc}mCV39+Bl{EO;lUBHe*S@ham(VC`u|<@AEaW$lU-&1ivOR07f(qP zJDGX@!7_GbRIqsJ-^YnT{>#M9Vb+P_X=Y*n^laGJFpIE%dME6k&WMZi9M~t~F9iQ3 z`>$HDOQP8IzZ0nsnCY&vF@KTyYtCQOqoTw6ka2T+B_9_i3Y!pTvpo9EbhKj`eTbmsgVHv}C%^t-4E(-O=cohIP@29vlkbVGfquPFD&x9<=PtOsoZT4K&^Wt_;i%(@` zmOPNiZrnt>FV?4xmSb9E;7tcDYmkhH@S+_O!C0n& zu;0E#Nfyd5bj%W8dQN^V{iI-Uom6E%Yjm{)8I3wc`^{X<7W)G`rp;9HCX1%8%Z}vI zD-7s2`z$zT8)eQk;uk$9DT#hqcpDAdmRkE^OB+%@-^3Qz*fRw|8<799(`;>sB3+;g ztYhFFR3vJmo5Ea}_>}Q<_fY|NDkYb8dy~gB2Y1t@GjiF4E%x*jc`r`aMuxfca6Nr` zXCC@inaqw_UBFuExUfERI@zBuoR|xH4zrs#Vy>UwF8aIE5GK`p4So4rZSBbIMd;Nv zH+F@cFCA#2!j|bIu(EfmB;VH9ur?~^(4qVf=!L%n`=;7kvhA5WdPi-MSkz6G4EA1w zhMlUF{8$)?LRA8h!N-YkyTE|0Jp2t!+c6!*`L&>>Yhsc1+^1-Whkq>!WNHuZp1_Wr zP$3x}Kb;*^Bga|~zlTO%x`?bNe@Az|cd>W$Mm?RYII(|bCUOk)e`Z5 zeER1tXw|`>?m$kxhi9kIGMY7{};Q%D21-;AIetGT|(DN9plPE5B}bmt(~~Ubv-R@%toH7q?C9iLdee`02rKroO(z}LZ?2wf z!OZ7sIsb_n`9`zAu=BofngiQ(SmZ~F(%oxO~=DlMJ4Uedz!^ z9-L$Ms?^ZudcMK4AGhg}SH95jfM!3I7a}*KaC)Q5Sae4wn${}3LtiQ%!o@C^Wq&K( zN6!x3s5@abo89^(f&F%K3@anoCTU{q*x{iAB+`BMsLJjU?eR;2o$^VRb(!Ns+s`Yd z_eI{57^z3FUBjH&A6ACUt&20+-CCpQ6pK`Pe_0AUsdX6bbGiYAo0iidAD*)5wu;RB z4+AA{isrN9A@&wx4P=l2|-UOdJC6xBOVmhh(Pdyr0MR@4CSr)iI&p z+Ga}1_S1Al>rKgNqtUdxtSoosrwR3Cw-N09adj*toVO^ozYm(A=Wo zlxw#UH@WybdNKVOy>q~Ewto99q`1LJqOpPE4)^;`hvr^EqxzNB`IL@j{e>>{8hnhd zyq(5gcR4PZd~_;GD6wZIb%jb+IaRX_!7X&GOMm9TiyiEv6BC%LNmBMzYcjjxMjoAN 
zvWM2_2<5`9TG^W;ccAxHh&wXSi~bcbhYMb8M4M@MGNYv}tR!X~m%s!_CXTb=mQ9|) z1ekKnP=mwVZ}OFUb}^4>jgO?eQuOGBYsb-auMcgl{F}~wZ~#`-PGF=j=5brbYq57d zhDgAq5?MBNqL>s6?WaAs$hTYQwsc3jTR)kFd9~=tUI&R|RG+_P^OyvMF&tSujrO6& zF&lQguG<@wMYo>uWPhdzS=umyeK3-vBfqL~Yfa~~ax?2B`ty95mAe%gO7#KT^Sp>2 zt9PDOdTGjS*(N~^i)$En%QV_em`?v1v7UA?+(XZ8-@ts@;?Hyy8ZsArMsS`Un%t8? zds#2hHAZ-JJ2UO}SI*o`m)ong3SHQzML)d2vJnwk%uSvNklt$)zLTiT%%(eN3lhCJ&b$d$X1_KWBBGMrX}5< zwYc~c?4IS*mnvkrD)TSs^e_uu)VT*i`xha@-3N7A)|TkO=D z{j9~lAbRt>&Fra^eROa~1s#}W#BIh-}Jr^};{vXG|53|LkTtw?{ z?Pdp6?_u>`xw1mlGm_Vi2Wx-ms4;h=UZ6NhI&Ckz0_AmFqT7P{aXHglk*8k<%a2f^ zZC}<%%vc(=UA{^iZ|p*Yl;pUi+-!D==|y$`<;IF@kFv+UZ)Ocf1k<^` z2j~UN4!XVl4g2)OZ1yX*XB(1SQN~nnCivXfI0R}-sI#%Mz{GQpCaK17Xasz{=HSbcixaE|SFFNC(ccmPq?UenHu zJomhj;_fDPpi<2-bpLaQY5LJs)<>;DB7aAbiz~!zk%^FVWqWCvheMf3W$Wl!iIc!W z`zNjQ+m@4hZ>8rQUWHzItcQ))OWEks$H+D!l|8mNl)IxcmKj*yz=ljTrqeo#fGrg> zdX0r_daE_^_(HKdi_Mwd^}B2H%+}L}8$(&UkEU$ruCcVeTM1hA)r!5IWyUmg=yMI@ z53uE%w$r~Cs?jgb8ga|jH?ZgZZMjdoM9k)tTGrZO3j3g8LS6LmLfT`|^*ZMlrp)e} zk?gxY$LY_?8jQ`HCv;ZhMONFsm+n-sWviYQudO}1Utj!6>FVvTb%*+;UvOxHqL zN!5p1dgjJ&wT}&7vp%r_%njjx{}{^7V~r z?>&QPqo+$5tzBx&Vk&!EtGgZR1+a9Oqem%iWI5U!YdohRG zNVl;bk)vtVH8YsGZ@#mZhgWhh=3(adm{Nw*oXS=X&t{J@W{l5gC(fgH8vP_>Jre9) z!fsEU&Dpy~a~;YKOwtQE_C{to>w0DkYH@U-=Z0I-vc_|`GQ}9C@|BO|^yRgT;ARG0 zGI}g0uLF|XD$4YozyXYN;S{dw%1Cbdi*dABLqF!?#7fBI^4JyD_RM)5o;muwhP}oL z>MR4i*|}GwY{8uY^yApm^!K+m-0Sn@T=oqmE{gktmTt;pz0;o3uEJ50BH24AOaC4n zP9s*$Mu41|@si!tF-g+2-)PH1OEx$ufht`P#o3OoMhkZ+b2H|IF#JK9ts_zYSbwU#sZ1YE7cN}74Ngf`ZmMK@-Y(Hjk8*u@cV zSu4XL_M@93eP&)volaXnB$atb@9T@HQw>v2By8zq`H(Wn-2u+bh2mwj@{DEdq6uec zv$S@QeKC>Kowx!hoaCx9W001p$8Rw92 zzizjr*WBr##zyHf$F7f{^4zmI!vR0w!2K_%@P#h7 zM>Bx+PJY7OYUyGfx+B?_+h#L%m(DYqLA$vo(+l*eAXy0>5XxyO-Js2DT$sc8Zd})Z z5e^f3zV-bb-^&&j@NBr_XFAIRbI)u)v6kb%(=*C-nR=P#I=72<4E^XUJMhN=PD!l9 zRqUD1tXQ>$og+4(Q$i}4$(`!lG_`~Dk!+Uvc|@ z6-skP0_a;hQbIjt7_Y2Lwe&Zp_LW#t7v+XiX1eO z^9GR$!pi_%l=hz^)a8(Cp6k3Dw84t?j=31&S&5cSOodmAvIHw6nfuA)2($59)P*HXEi`>4o>BZAGtoGA?}MR*hGLw(s=Lb;qarzVWrEO>5u 
zK(G%73wn;a31mOq6L{Db3J%7c6KM1#34UyK5!i8of`%SFYT!WwbVpx?i`OvU{6B>FnGo=*;4C1L<^QA7*L0^iUgW9mDIFV9|cdzaq6|5I`w|$ zT*_(TLF#RX8c|5<(A?WYMenzVs`6lPG8_k&pR`fKXnD9P^MM-b91S;5 z3_;$%UsE5}q(b`i@08DBAtk&17`5kNCZ&M1srxY*)YOR!sJtVW1-1DHsZYv0rMk9| z+S6Ym*f;V#74$uoYTR~Fuv=ve7*yGS!i?io0c#95v@TM;^NhfD!USkx(}gdswq(x17}fsZcS9ZZUA-shaUCP&W!50c7bX!Zx%$XtfJVx z&nXK{NAQ{YmU^Xc2=P`sD5auAN>5=Tb=tX}x-yEQYHkey*LHvE?i&qC+WLwz|2!X7 z7fz#GX7z*NgP&7+SCpXU$O}r{@)Tum-Atuks-RLGM16&icB;Ah8FgpzGpc9jHfr(J z)6_zjLU^}l4LJThNYxa_!tu`0pjVnq#fS&Mjx|+O_V@$TRLMx_eF3n<&I*2fdrZaN z{YmvWDZ}Z&NQ~3gF^D$S-ZjzKz zS-Azi2AM+q`r%L$Yd|H$Xo3GwBPt=JF9%;3aMfNA>oWp6 zf<9Ae_l4km$pl5%S3|PT9oRM96q<(~gYiH5gRhJRMBVj)&vKctY|BUZ^wk9->Mv6! z@vo`r4P#+0l|*?rzM+mk9t7`j1Ldl$5A-%&Xj=4@`Xb&#>DAUzzM0G6)^jVEY1Fre zi{l_>cL#-rx=>zmUDV1(Yih>s3d(qk9b7O|fR8do)RgEU;PrMS+~|LR!pk%u*XSqZ zB20s#=u_0rW+7-^-48my6~JCq4K)`%rS>W`z#*4`Xv?EL(7)RY@_NUhF)qI$S(FCL zPE>>XuC?fqXaQQ^x*m-7^}_SpR-lqxPEk44;EEGKc~2G`S)d1RkO?TLjex9*BzRao z0}4)6QB@HD-`vN+g49J|Rr7+{b$TY$JWGYaB>>Yl)I!1ZL1;FwjBXBeg6#TTEmi$nc#6H8A_*w!-y?UfGfTLg;`(V$CDjU_G>S^ zS2aXl2?J2Z&;e+`{%2r2axi)$p9PWWItXc(L*O_mG+J4pq=30#)f9m~u24W%SIf|H z(j$<#Uj{{V8=z_Z+)!Q+jjlhjMuu*R$mz~iI9-{Fj&3wZy1DP5es&FNwparnIw!#7 zr3tY9&TT5tIvM(BY=cRu2;#m?hOohMP-wXc8hb+F*t2yo>8md|uh;@By|+Mj=KzSD z)dtIk?Sjd*51`RiAFgD5gf>5Uq^*|@Q)k4%Z_hGNtnG!{`qyB|`13IC^c5Iw`Wn6& zcY>R)3xo`u59OB)AzTsz3tl%+L*9xZcl{-3+^hvTi^fBYz8M%=t%ja*OU%5x>c}Y+;_^|x(Zy|K12PgTu5;}3>%MyLfeU5Fxp-M z{xv&b?do_qw&px6)_nx2Gbk8KGGUIjGbFn0hLk%`sMUVwVdq)`WjfKYYC#&znh+1q zM^!-T=&RuVH4DrxZvfxfC!wY3F+AxO!dDGX=*`js{Z0&&=Ps&6o&oEERG8Z%4?4F- z!k&=xRR0TpFmzvkaJ=sU0Z11D{ieVod3$&ut)i5zi{Ztq#qfGaDSR^Np`26>!|fw` zp$Edi=k#n?H*6R{R5e_&$%IM$Qo+YH865j<1<48t$gb6f^(XE^QD7ykT<{o74Msv| zNEf(l&_*?C`Eb)?Gh~jQ52H%&K;EGL;HYa<|CvL< z_HaK)Q=}o`(|T&i1$&rN=L_{-I^gyCHPyLtMBo1uqd_gk6kN--pxmjOQusO@qInNU zj_0VxsFM&N8v|xW5?F2cnVOW+055OfgGUizaJ%1J*k{Fn0V9Dxb`!YG*a4kX7MQ#_ z01^Gpz^8|1aOh_+969L&3u02>=P?=Zb58?P?Rb=T3HO$4Mng(SODs~Q^8W~53VovK(%2BtiEm!h5IMK(NUv79mYU&(N5}8t3Iq7 
zu@0u_U!Y27S-{&p`>11Q*TL3VnlNnOB1qRzfgY3D5P01Z{M!#xPhNJw6`lowYaTSO zv_l;x*FiFE1?YO8gmXK-0P&GRzRU^W7EDFQ90#I{`_qvB2_M8Y9EW$c$*4BsDXh)Q zhVVt>k;j=DwEOZ8_}t7xW6V&r$=Dp7c`zJRa49G&Xg>OABtnN{>`@NA9+}>=Kzr{+ zq4F4g@L3=S+}SKRbHNinKnk$2479DB3QLzn!D=4_mBX)qcW(eVG(U!E3w&VDsM#<* zY8V=eUO-9qVTiMu5B1B=Lf1(zv`1$Lm<^8y#mzgR*sB<(4B7~hKEn{Vg+^DV3`8$o zegXRJj=U;0k)5^xRc$B)(?=pm&(lE@Hs1$0xe*14MMvQOc%w;G1+D_UeCv7h(}gHBdn2 zq3cj5DupS*AD}NYK>xB5@<>-lJDYFAYTsD+ZLSA38>;%^*AgYT9YMC-ewcM*2C`Us z969(8LOJIwk*!=1npL|U#h;ms+RqdstIIlQlACa37N(*7F!-(rBw}|sG_Md;YNJ8(Z8{8dNQRB~&qC{JJ(PWZ4_x;; z5BE-Lpj6`wNEopk5*=QyizQU;aX-32dBK8KQ{ zH$isOCpfsNAJPtzMd-XLY;U*%n|67_-If68h%SWL-<%;T#0bRkZ((1p1q_YZ3r-gt zpvpT5PF&gy+sD0zaVJ_KrDiw^Q5g+#Wsku;)C(qc$H5HgCfE(d@Dxpf-HTMw+IBiyvN z+Zv?|%LebynzTmiKDeUoJ){_pEzLu{g5S^*a1woc z7mt#31W2|Ap|TNMQTiYm+WGS>*pb=~(Z4NFNz`(r_|=wHY_dX^ANE5p8h-+1*^8QQ zZAX)pnxL^|5%N2*5cOUvL_5M+5PwKVF(*W5-D!2?FfIvcwyGeF4?B?UmPmB$h92C$ z(@T}$+pt*y!P>MhRL@#F(8~TmebmZ^&57{P*B(rb-uRE7L#Z7<9ir z3I(}p)0d__LoWoU(U5y{QPDgDdRFsJwD|j7&@6cg_Qh!^z3c(Bd^bVIO}9hC-DqUy z;fkJ5aYw;oK+U!}=!}yOa<1jjT^k)VEM^%xxy=(@Z!||sE~}%E{ex)L;p@;w^|>f= znL2vl+lcHhR-rSQnW%hGF)Hyqk2-wMAlG+=NWEz|`W@VY^nM>gk!CKC=Cu}deR^SU zOg^|i%7)yC_?#6UT$qYKbXbtJjg9=p zq=6(_^B9hHA1&eopYnGc^RT2`8RGy?hjY5xQEVn5ayh(p>vUc{WPoUP#b8pFehJ(3PQv$ACE>UO^0=uXgA59=;+xSJfDrDTm8uI(pS3YEvFTWvP5!*Oqk~epw@a=D3iF*;oI@*`& z2ftfNV)b44;-nCgu*Qspp{rOXynrmqI6$Vg^(W%WWZw983pp{$oy5)4bMWvrAy-p( zIkavs5*{~K7tXS3pzG%8O4X*#B?o(7@Qm-{`fZJt4xsEW^t46z{$+hp_#r#I&cz6? 
z#|8CcFJ9$6b;7xXK~IG@`7l1cT!BAYYE8VlI?2G(mHdv5Br>KjgkMoD!nZHZA)kgk zC--}=V?XsLSpQQ8ekoRyu6E8AHeMZyPkGJ~sWn~^=6^Zn(6iKpKYefndHi?-nIyiB zv)j9{Y_T7o-#9^(`Z!#ub<2#m`+Nsa*Q~{-#)gpMGcuw_MVo}vpDZLb`wkP!eQWUR zj&kBVBbhvKQoxq0-0(~30n#EKAXRvvC0%M*g0-)IVW!l6;&In~UMEgdbSG3*G;oow zbn3}R{K8@x>HC$3iS~~M2Q1i(myGy9{BGt5JK{yW@*`Er_nnqleWo?97?JB>^|OR* z{-}mMW*^|kZ-0OX2WUzc?4E+<-&#ml)yYeJOwW;PCKdQhx+c-zo{7IXIY%QxOcnwck4D|Eu(?L6Kken#heY~ zYPp_p`HI`Z0b3jJ*vtkTq;5twZ4V@^)%~USWoVK@HxXgj5qx;70vX+=C@QijBi?P4 zw7O;i3C$|S?*)Ug?z?XezQ?<{WgXq5_v~;yZf+@FbRmT-exZzCaG!X~#@qb)l|VY5 zXOIU=6iK6P0Zu>tjL%WD6U7a6Bx$>T;Zdm_!r9)94zVjVuuY?#sAK7QT&2a5QSAxj zsJW7GMXH3DGYjy$Ej-!XaTza0+lW=>Jre3!BkWgO%D3+w$rmjhj&%+y;w*j~?=|3p z1L=_yy40rNiysC{?Sm{Z&95crtajn;k=O9sQb%EHe>bcy8$-tJti%g#2k}0O%kg;m zOQbky2_Aal1P-}eLrlD6r50x<;OfgI4qn3*v6b95Jg`@fe4o`r?wj4ft{%gLCGlK+ zW6U>U>v~h-7SZAGJ^O~xs{R?d?q*0T$0!pAw<$y=upYaG$zlVkyr^GS7jbo#!#8GU zkrhUd$s)&U+@(Jqzv!MQP5Rix%Pg6QKk1&qF?o?>*rYt-ST|W1*trk~gr#Ht*fcyO zF`YObZNi7R(R{YMf%Nb&MQpt{Sc-2Al&ZHyIR3#}8r2CvU#|@LPQ3%mSzo$JH<^qn^X4m5UkXdt7h;{U`l3^*rDVfyby4!2QZnyRrx2}_7i}sV zPSWd_2z4_GFxy>4UM>HM+iz=0t?plQsLC-G>9rmbp4Xl%)fuFZzbMJ@=;U#HKi^Vx z6Q3mwYi<)wRZiHGKVQ^w{Q}PG=n|ext`pvH#(2RjYbZYYvnH=j}hM0 zP!wJ6QYBa?oyd7Kk>AO$9X7o3BAUIXc=^U=;-Xbew3o|D9a84-1r-N`8`rnv&rkDs zE8a;OcJ3ZlOqs`|jXbhu*@Qm--_yo5Gb6zuvXYe=tdz(4$Fy;!5>hQsy%U0lbYh{E> zA56jD8>aFBqBG?1o&%ypn^MuM<}ISnySIpJEQ!c}>QT|iqj@63Z5g5ovkF9pANGor zcV>y)E^HKam>(6r>N`0)C$Lm>?(IQQ|E@V?Q(hG=?MWsZw)gP5J8X!S+h#K8*Ji$X z!#(1=xIdoeS4eW6$Kku;WMbFZ!e=TOVWVfYIQ*0cIndu+l%cdt6c*+q8klfYbjmkK z^qI~SS(`YE%Cr(ize)l{BerIU^r&f~{GwD*?ZqrnP|ZZ3vg+lf(dt-u zHH0{&COV+39PHuigP-|y*4G$JBRk$Z@+n74gd@k53VWUv3R6EF;E&Vego#r&)BVgf zFul7@5e`xwFiXtTdqAIitf=wcAu@ZeUhx`>LZ#Y%#X*r^ae^#xYgi!KMs+S z?NxmH+6tna+QE0+YG%I1e{?v${xjaSy_<-WC-asWA93NaRk(MHk|?yGOPJQxO0q^o z(#1(9u-{iHGlPE3ca&}*BfQjwWg`ZPs;%^-XNQ?dWoOjkQKx>9mD8RIWI^QD!V^`zPOJW?{Y!(v1#Q>bx>PVQZF1+^aY%(Rr zitKvuQYan6kgeU9>Tl&m;Qcq8$f}FByh4vD>Gu7AfBsCVmtIW8*CJjM-|v2;Th2-N 
z5ERH6*NycXGD7jDlZ_&L*EMJ2y$@WWN-u--vAX?05@R`R$YlwTJhs{dJpb830=;pz)=%q1QV9XU_> z%j`5>J>P&>3%BrNdgf!jN4flnBRTlw_n%lj`Uc;b^?=X2a-#l6)FG@h<1Lmktidl* z2jRl#EDVyX4h>r;;Jrh|{6vc)($uk%s6{1_x#R2c>4!!2d)w^!-|7bVbc-^ce@#<5 z$6^&pGd+v-SWD@?LQib`N{LjKq?7BvYq9C75TbWq)F2E^_J4|Bd(-!&3NoN9ZLfe3&?wB7cLq(goO8fCMRnH%4%(>pV@wnC$6jTjLGM)%Fs{bNVk~0ySkj5PRu70 z^|eH|tJVpRnZ6fxU(m%FH@M@|WAXNb2o9c<~%9X-?ic;^4xOEh1;!@X<;-_sl9{3<8qT?Ix;{ zOTtZqEJ)JLc5<$zKnQXbIBLxe_pQHQZ`-S}4F z645epFYFyTIHJF>XIT%fHy%z}v~w{DO~dzb-yR4@$?jQ0rTwcZTobJ!iq|`Yy}GW{ zUs~8LJa(#^;(!z{8;Erb*vvgBupJ?f9OF6Vt$%@y`Pp(*N)EOhwz!nD4W4Zv;2 zYM7s$jjw8_;<-E5VY+rce)<}dj}iNMN&EyHn?Z@%e8b7pB{Oi}nM1$-XJ6B;%Gf`Bsfc?&#~pbao{_V`~UL zd#)aD9it}vW$!{(oV&=+&B!1z+oEuS);vDuTp0-|86urwdr4^5F^aT(%;oo`IEcpU zo}?F655O^EwtmK&aoA<64WHJuT&R%r3XlIx^NPQR;4Ne~IWk+G3~dV|Nk>z#q033^ zGqHmou=BofO}q&{zI7VT?MNl#UF-Q8h12A1@FjFu?3lHkks;-b5d!7lO>ak>O z>>+%rgcHv3-$b4VMv|PGLD(dFFgAGon|$=79FFY$=CIRa8>w_H$6UyPdaVmI4j9_O zZ+osLtUWmuZx3EXCd7=uS+(zlT9=*i>W#rr&}PR^%IxNobX%dkdQ{wR8>Srn0I8O(I7laMN_&hOOsziBk55iO@6>0H{52u2X9S& zkEcbP^^b7uyvc6)M}rX@UEsR#+)jCE341H|5%LK1C|aZ z4-W9{rk>>DqeH@Vr+4w&uU@Zjm%D)%eeW;z>GZ~L&K~9)4*8SL&hf&xSC^2M@9rX> zns36n^-3i9a57d3877^*Z8bSK=n|Qge3Xp(_#T&>lPBK-Oh}|`f1Gb*iLKu`5z&hI z%xe8!Uh}0J*{hmM*sWvub1jFldbTlkds|LcM(B#tlV;;DrW5#gTep*gTFK-_ay~zQ ziKeKt@C5lV`E>oDC6v%m!IGc!QHdY6YCA4c8b)T_G{o=XHxs>wC+oilZ6(WcM&Y35 z2He#DEs2evM{W(=jGwpCxOKG&zebc!s(0)s6TF6ybsC|hYBG(V)NB`)8SEC$ecvsV zIZl(azjxNZF*u2#wnr%4$n1(d}#u`$g^qkDPc8;I7=m`7d8tm88 zK`iq}@Jsg(t`CS*lv?&*j$_YM@YJ|85*JF5>{equea~xRb@+O{q3KA%H5m)pX^zCJ zRu`L2J1T7Uf5)3C=t(!;OC|R#Ey>keyKu+)LgDlC--L~_*4U*xglMMRz@1)-LVUH9 z2$RMV&)?(g1Luf@rcXnK?D|svbHNb@uh+*&uH6vPfnh?jc47oc^$*7voR*VP<;8uu zl1A1I6_KdRF~r=wfGo@_5FWgLo454IC$-@{cyO~49@b$f5v)9fi%d;O_}V;dCj0jYkk$?~B!uJE_I~*MN1YbIw%+EQcMl7z2N$Bi1 z_4krIag9?6UOf0b8EchV|M_ty=F2KbVDC#|+nP=Fr4Md$qv@r**$S5IzHtP%us0kA zzgpT`}-lf1~x@CU-fcOpm^JqOo4okK>iydm5+FpOMzl7i2l(&E39 zqWbI4nn;;gh4_Cj!5<i7M^F;tf;qp%kuOAw3!Q9jzc_mLtBTZZA|xSw?ayqOeWL8PcV& 
zUzl-k5#BjY)?q>8RFZ37#7~=&D{M+C&mIntvw#ZTx44 z?|*Xt>zS$n|HgTuuM_WIxyM=mC-1-3)Bc_JuFik(+WaT)zhe3Cyt?N9!E5`Uy#I>E zpVrfV4o~Kvy#Iy0yQ`9t#eY#4%3J;q?Vn!XKjXje-=EH38Cl_< YpMCef%iUG+PrGg(kD|;U`hU>=FDYi=m;e9( literal 0 HcmV?d00001 diff --git a/test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_3_mp_rank_00_optim_states.pt b/test/data/deepspeed/model/params.00000/global_step4000/zero_pp_rank_3_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d5fff6fd518e13a0a20974d584b66bbbde0d9ddb GIT binary patch literal 15715 zcmbum2T&A2*Dg%X2qGXT0xH1(NZjt3-RTuk1VIcK5J4n^g0KV?jGz)!L_oU{(GzL?W)*h%<}7rPmXVN9P>}fl zG<69L3EwFHkkIvkGXi{Ld@Ul^hB?j}CE@vRZM)PKfy5LK%@Ov9$e7UZ(D=Y8pXeCh zm_VO^Q2!W#|?afwX@>U|?i)WME)`MO5s%5xzi+phzo@ z#lH|KxsEWn|D{(F$V^QT$ZiqHc_@xZ@{b)+@n=9FKbQaMABIQ!d1#EB1;nlM4G;DA z`KKI#LaN6f?;_R*Mg@gMY!E1>{zFJuXrxbYly87QY3@HnNJ#v3!_PN5(C0sV5h!mF zs7&#Y9U=Vdn?QAoV3do*7J=Fn5BZUcfg2-zeAfpH)cs>3Mj*5BiH!*ji&pZ93iOK& z4GZvz30xN)5#_VqT1hZ^%0F15M<{)R0|grYU{P`mi|~zc|Hnlo!59z85wVs6O%J(| z-aj_Lm$VUR&5x3a5oj-Rkr+AIB^c`>xl5pv>hB*B78d9q6B@BDT4_O4Kwwm0!0Zv| zl?1v|7YoM42*xiK=y@pp(@LLcfnY+&h#essA^*^0CD8v@JqCaE82-1OiGTD=N)7oV zZgPykXtBUJHrhjFL|~+ElyCS5h!L@o(E<~1f$0{(h<8#WBVkd3sdM=Y{pR2NT=COG zZRCnysBg57e?(N&KTH)&OZf|@Uto-Hv|##N{=1Gp?aEyK$FEl(XeUok5}0{N|1o+* zwE0{#CvM*}%e~}ZwHpFMgF|9`0s{Sg;{+D}oV}4i9+{rNNWqNxN7Vj0^w+>>fu*;= z%0ZwQU8c_L7~9{`z?Z5Q#|xX9EpkY4PECm67M5={Cs2l zLwuq~)}K%0$Px^Wj1@SzZ1)&HLiQ)hLj%{v`1tz!$A-s-`TlkK4^q*B*)EcQ#s5#h z3+8MQI8O2SgJn!yWT0T~-&6Yl-*ACbh~*Z6^OTT(&TPo^5VMef&Q8cbCnGl2WAuoK zzYzSF?7wOS^R@_F{(B*f0ORT+8T}WTzvlckJu)hE1ey8sd!=JT_#rc6Ef%`SE1(pU z`2izF{W%u?|L$3%BnjCSyI?$)Ybq_(LUy(216n6O5lD>a zPoL~OcTT7rrYq$5O%Xo-AtBs-U_JNx<~ByJvYM!Qc8Y8}>%bkXks*98rjQxQ6Ua&S zrNq8432u+s6y}+91=DDx#kJMf7k=*d5++0%lR;@a$lfYU#TNf#}i z|DBJ-%NU3&ji21j-Pf3zs)pR22?0#~&@|D#rgPkk59ZvX$BxrZHJi9Y8TW+m zqxzV;XPn6np0DuhxODFR8^*#3G2?_CZ8Lb^lTw6!yEbxLMN^r5PhSXE*z>q^TT{6o zZoF*H;K_64o}~%*ughjC?RN|J?2m`;4ufRAX$UE4<3>h%ULu#T9xc>kO~?zoUeLv6 
zE10lnwOl#>L^`J|Uiez=J{gcw!Pw{aH%||FNP4GC7n)ycWo~pgGlO5|k!Gv!Hm<+2 zi|l;f+DtwSWa9F7kw!;TnCvbWX0Q7@lF47p9BDsBmL8qWoGemeJ~=(6ryAQ5@=@c2 zH`1NSn(vO}x_jE(q3R^!{csvN#;8CTEd)YSk2EsU>j61@mSPUP1I&e;rghI_bn@xX?{K2J5?AdMhmfpQr2SGqlMyM~$z|jEo6FCMh+Lp8bltF?$+~4E zl%H{s41JnOyam?@tF$~~v8a)h-tdAxp;X8iKI|j-p3jJMgFDPen?_;9-XgMp`!vSL zZ3iLoc{EpiE`!}+;PoJL& z8M0h~n>xXktWu9=QgY6bbeO(S!&%IXe!7tR!(W%%sr!z3|ICva`}GzP7Bz$UaLR_+ z@YR8Ay4pd8jG80-U2~i))qO%5Uy)*z_0*X7-Z4VDw4f<)JdgRHVZ~K(iX=Ot>gi1X z9QrQj61g&evT)s&M}$k39=T(Vn-IAgAJaqdc7&Z-Py9ea(Ob+VK;>g-}Z6HCdh-hRx> zNJHTT(@W$e<6^G=m3C&KqmOVmCzgEe1xZWA7~0-LgPZr;h>U!aAXHjr#>`yslC$18Bz#y0YdsfF;{lss}t;3j%jMi#j&as@Nx7R}rciizHt zdZdK43bUzag;3?@A@2UWW~AYjBH<8Mk*k{FLwXEJ30>Fno4;BdBex}N6n?Kh!91Jc zL0)<}p3DCt$9++uEUbGmhk56kK~D%b;1W`==oxP~!dML@ZoG{=H~IQ0a@G1Zr1Zp5 zq;{ybFhl(@S!3a}AThe?y+${+UTFk>ehY8prJT zK3}-KYX|w1mgBD1$zkkQDiJGHiik7W`)FCFh&$u^IN><=<)nCzrZ8^cEOE_8oqOlI z3S;3-GhHn^NDI#?gjsh4b0augSiD$~j9juxDDmWCb5OA+ca)%jo0u}4Xb5>q&)X)? zyitfJlTYN3%#~%_u2HpQv$q3RX^k>-=7R}obGVRRvT0;5&N@jJ461Quy0r6(#UpyH14{!3W3EThB%#WEKjH7kS&T^Ef*evD*OWTOpCN2A$W}=) zqtd9$H7+S29=G(8mw$~D`p8Ig?>7{a-<6Zd#NjK2%$&gs1jqY?KWj5N3)oGnVmygcRDf4A9Kis^HrEMxklzx`wNmI zzn3^&U&@?JHRnt{r%#>`Rl=_HBqC_-St77+E~CbaB8z&nzSvy+umo?%NWi7uSrrI_yNo2k6rl-#NsyM`MNUktyWK>UQG(7JpLV zz2&!{Mm!t>Uo<8-#&@?`C$#&`AR~lC*qKXx8und zpEr>e>P3vnCv95pTM9|g+4P#QeEN*MEK?a{P800L=0>9(^vR_Xbkq8sgq^IAq#Zn& zkBdA>jkG7k;3g3q9<}&l} zbH+Dm0>K$wLI*#RB#$RgrLShVGpwo%qboUsnXzsfF|9kB>AgROyz*I?qiHest&X2 z#vm#4H5c4oy`4c}67;;w2F%)dr$|@ldrj@DrG?g(#$;k&DznD%4$(gn1B=`DldDCF zT#<4aX|p$t{B3G1Jj>g~l;+rTIs0caDha2V)>STKt$7raZ7CuKUuu%ys+5FhkOldo zc`W(0r;ez2&`98-F=Q`$ndn@QKx>NHn~bjnli!kV&|@hTGAy{AF+G$|Kf35nuB+M2 zBwW%a&hw`>&k7o#C2dcT?vILTeZ3{*$0fhuvm!6@*R^frvJhRSx8f=}zh90!EnJ10 zTcSiSz8+7jB-Rm)<2}izyNXG(yWPYdZU8xW;v8|<(Sls#V<-$T_9Sl?z9beaUTS`S zQ-a$u%q8c@q!6*MONkYel9}*O1?Ite6DF*87u_>ef+^_S!pO|_AurCKPk-?{O~zzN zGW)Wk2+4Ss*yWPad~8-FoxH1&8N8Ln81-f|&jO>!r(3JZt65D9NAezX=3y06@L(K4 ze&{E(td+>pCHACxaXcwx`w4zk9XWr-O7fPO8yS$dwE0Vv6ce1hfhf!T4PP4*^o--W 
z#Fsk?Om^>~roru1n=0l|sr1#GZW@kzZ(TC)lxuceih);%@+2Q1Uol^|j zVN7m(A|!S8sF1HJT9~Z|?1`zTl*zYBp^WN-n?!?lK9t;%OV?$uBB#44lV6{tlSwJ7;n8pDy6O)}w&lyZ56;H0B<;)71)OXhj#lWYGTgPrj z@5CLt_c%v*(N9VEqjI!xtld}!RevTp?~RFZS~+C*>^8iZbNG6$TIi6_L4&0F|xk=F1h#75V2g@ns#a2PQbfdCb2VsWWhsXV*4j@ z9=Jgt3SI;AAH1UP;^oZAEpcSJBDJ`-87^rx{j2eV?a*4)=Y0~ZKJ>0E0Q05xXh(h z1+??keWc&EUqo+9599Q$p5P_UC5@yU$jxhwiNk@(Wc@D(W{nMs2T+3v|eEizCQnlBM4@Pm@OqvMvoG9YtEIpZUUdD?S>l*R?~&@dEDVERzg3EwZi21G_J$uXN)I%g*j7ul)Q0H)O7wz zGIQ-mBhk1>k*J$t2@FEM1K-UPnoNgF8sB;lu zhA?(t`{=HlBpDl4Pj-K)W_oY@q|by(346otHC5YIHqMMW0+vROYL0#R8syFLZr*2K z0S{f%Y7RVN)bz~J7|ObI!Sjv_noFK!H@QX31v;BX!Pm+Y8m}Uv`QrxZ=9@~tpjq+l zrs&0qFzBAR+03#7){fY?#dZk15lV1gX=lI{{h4> zqJPgA{rkVT&0~ngr}RkX%M>E(<`|;<<$7YQbf`3$kJnl1>xE-HOOi0ZlI{kJK23NNe zcMnPsaXNuS)6XX2vf&AMAKr!zdJ2SrCWd)$>fwI66h5*20o9EyhzR{SsNK>E^XB}3 z{Y()wTYVV%4?}pz=`<{es)tdg=V9SfFE~zCmb@@akMOcegMl$S;naP~#GCFpgqL+1 ze9aw07}_fmCYL|M%|j8;L#F55r88#RNA;gLu70od`0hgz}5B;qc*XVyx2=;-<-KICpj?46-RDhVFGh zi6t3u^@RQKowO9uztxg(4@!XhQ$^6od<+qps6<@Vh=81*d>9cVN1QWABPJcFgr6$| zprv~yabe3d2p(;O`xJLUvx`H}=>1B<=Ldl5ZIcOS_y%?rb-*77%AnShczFEHV&Z(> z4|spVH2Bz21-|(0Kuqc^frVw7aP*vIFleVIOqwM{6u;7f%Awm~)xpIue~&G6Sj|AE zqZaV{?YHpswkY^;9E3MbWnsV%YogdA3@%ok0S{`~!8g1JsB>iq?kk=NyB?*&gmrGv z>?jvHnXpg|>p}Nb$KhJv?NG&h8|;=}4Bg!s;?v+AC^5VPR%Io@^LELw_Tx4pOcV`u zd}<+N_Caf-t#Fm^FSuu7CS2OPf>6JA1Wvy<0Vb~I!fk3}i7iwT6yC6bH_JD`_0vOO z!Sy$AyGbPMbl3!KFC@V}t1PH<=>x1yGlrTDrEqLj5={AO4nKG*!IB#b2(D{8tg!Pj!Ja@! ze66T~7721N;P3>f(3$}+I}Sqc2P!ae=?R$r!x8QmjD@W`%OTS`40adiK<1q{toddF zZEmSUU9mAy;mpGMloq^-N%*CBB?QSAphWBxn5JI<2V*kf?B6@#n77B^_K`;^l}{UB zYDy~1E1nAv%9_F5+Ovt(a=Jv?0t;BZ&kwR6D&cP>d1AMtDU_PcKwXDixGy6P-p#}? 
zYO@)%Yrg<}CRaiQlO6C~4TJ}~W)nj?*CCD_;XSbmibeVG*n?kiXI}ukmf8$=g=Rwf z*?j1sBttZ>T>%}|bijgT#~^q^Lh~pmSl{JO$h~duDim+;^rwa@USy#T6S`~(~G>tO4P74Y+F9yGi}5rncHp%lo0 z7q;$&&zUr6)!zlplHDPi@&ImkKM3b8oei&ytP#`ZnQ&ImOX$^e2Wlt31CJIx1#+r# z#M&uNke7P}n2>7laez9UTapB?_NqaSV;n4eI|ja2xB(u>#lp}M2^iEe2OhBA1kXN& zaEG=w{M?~Vw5&Y?-J`9bM>8M(>Rt}ZobE%prOr@Q$U>v%32?S54@UgXgNE8u;rP-V zI6$StcZ&wV%a45kPSGLE-|U7H%_ZO=ng=VMEr&(rRj|}xC9M8X0EaVepyB~T_^iAJ zMjV?CUGId!M*?p+gCU?@>U0>si$lz+N`?2XFNa&6`9i9E8~kZ}2jA+2JK&=eYCJ`VmLk`ZrZw6;S9R&G=HRLR40@Zh=VUTMvG@N@5 z1p1GH*}SDtea>ZoSF1v9@FaLXPY05fSHMxU6S#|3!$t3&fLU3(@D`B|PiW+VBP9re zs=M%Q+bQsHMkm;LQ4=b>(1qDs{orN36qGOg0ch_MaB_kxm|P_VGg~#`82$TTbLkSe zOnx6|(r5#@cEcbk|(_#1bM5qw>10+r=g)@rN;Fyk3 zP!4Et(RT{`cD@pHI1T`(g_WSKNE$wAz69zr+rhn<(ePst3l^Ww0xF3&zzTyyK=@V> zu87crd(0xh*>7Dy)AJ{=kJ|-w7Ffd*y}95V=m0BCUxP~{m<35)0te2W0=C2LV8`TJ zz-`4YFzewr@J3!2#?3m%nP#~MsN_q6m0IzjWcehp^oRtwHh&bDvC0W7qqV?5bsI;= zP8;0&fB<`u16pf)Ieib+f!eA#P+qeQe0KN-mh>-x)tUWZj@LbKF}wrxdGvzp7heG{ z`V_eRQ3gRM}b$+DEJ`bBv?010all*!7iIJAilH=9{Fkr8#DVr|KcPdk(mWL zyS3opOf{%}pdQG-e+o|PYr>t94WO}o0lY773fCF;fNklD@F6E2m``$rN_O_pRdYO? 
zb7BabHi`gjL@)SGp9I>keu0v=YVhWg3*dpwNw6nL9fFk&V6J{QNEw!ZarQ<)K8OO& zZCk*wi3gbFxeTn6BEW)WabTRoI$*xu3MhQb2iB?xoG1(j{+BlcvHfH)FzzwvK@Gsd zX(BKJ&ftwl3uy04294Giz|M{ZAo$@KKzH8)LoWH?LCFaKE9Zc!-A91h{w-j2Itkph zWZ>-$jUb<&0yNIN2kdTHSX0yjCTX}pv)5uUKimMaouiso+4^0$?|w2ss-~;foeu zP^E>z>B9i5G?@Xauk8loDRZz}I~+LePXZR(xgfSP7PMQ*Ke4_q_6i6+QV-lxYQX5oui)yzEHJ1d1Y;+b0SA@$phb}fR~DTg znQLR1zo7*rRn`Ko)Hv{J%YCrJ|2>$)y*=W$GjQLr4@6}t!D-9YAdvcP3+DGB8!QY&Q=2&nveaG76S>_j4xt z0l1)!0rR@#;rH>EfU3D6CqbnTtom{YG#*xiXjl`xrssi`z6U@+?s-t#R0!T~*a23m zW&k~BcTg%90iP-t0JGi&z{+p+$o^Qw`Cf4uh`$Ac&7x+o-s2gEj8*`m3}0|$%_T5x zvoy#vAVBMvBCuas6S%1S1}EkF!1ua5V3oKR==AU5*gtOu8*H<|>K$)^!4^xPv%wQs z)ZPW2KA~XAfz`kyB>|NAD#0tOm4H&|3-Br^1tc!Y0IumvfyH@}qc*=9Bn5bb(AgZA zGgcc2{nJ5p=ONHiJ`)&OodJ;#;sEFJ2XNL*0_>H|g>6E0_^9j(P}07L4) zsvWO6aRzn3`d$ure*7c&Xjlh+x+H;l)?L8maRDe;dJwF1P6H36Z-I$QRv_7_9{f_# z22QW{0Pi6^P*Hu*&iUOYuwCjmhhsXjKUN&)jAPz%_`-Rh@5Lw1s1Q4FOHLm!6@I{5 zIfAprVIjQgo(0stq> znwAY-J(Po1#&sa*kP>w2HU?Wc>99=mI%sR$49K8sV3zVp&Yn{jK*zmpU@3hJY+CS< zvuc7m`1Gt6{9O4Q>`s*fWk?&aOJt#{swTL@Q-GEgOJQIH0aM=T!!D~fup&! 
z!OY(*ET_WYjKIlL49<}lM_Nd3$Hs$`Xd-&<2?Zw_`TdxEe4&R5hNq${w zx2;me?!a{kThATlwi4Avw(T7nwxX9~?9RXKv|A85)9!H3MY~XyyLJzE3psfYrR_q$ zBD;aoe2(EcQ(J@ViJX`53YtNGzd191dKlEDc36Do}36ZKEq#i`!x^FNvVT3(&0#6;0&m_ z9|tRsZ-L{b&%%;HAE?=L8|JINg0J0E;JuQwQ1X@rp-S(7AL+m)o_nEw*V_$Jy<*%CVCQ@3uW5AH)fHm|&NCtIjS% zc8T4j;omkbbMABOEDm#ci$2-vpA*~EJD#$Quk5$cSUQ)pGxMF@hOXt{;H69sSf6gE z*S-_H-T1}MJbS=y@4f>Zo%k}l2E&~k6W5=dluju-!v3t?qL(u{nV%Y0 zGhvk{=pBC$`1WmwS_KJEAGmS6S55|h@<#?AH^l#6-Vjkjz8f&AfEBu=Yf!}I3j(NpsAyQU1Z zgFJ(a);r+F4KG>iMrX88cQ5`DEX4J$*D zD2iWD$_@=`p?;M`$aTyuGuCcut64Eow<&>4dU^NpV{oHzGakW zmK<{arHX!i*2Km2Q*o7w4>rqNLp^Xc!;dm*(42;?s6mN9scbNM{M`-77@VL&OT5?} zaZRXTVGeu1qnzq~;EiV+S>j>$8LZ}=Ks*>{gGC20{wZC82OHO-lohVHeC!!&a(p1_ zIsJv&376odUSs(=uV0Cd7F40W7m1X@l?uug<>CDuQ&FGGH(csye{=(%ncavMf3g&7_bA|7u}9I= zBVE{8?g8F9TOHpj?m%CQ5g(iF!L~Ne_|)tT#O~~?3rJXc2omqKV)NH z=P=RJ$2X~+Qnzu}@_1_f+W|DjCmEZ~J%H}@Y@_yAx#G`wh#i2HB5IK>QuFpLpxsy+2R(+ZcJ0=(%4qr@^c&3Ngv zBUsA70IQn`@e5~F$~H)WuhHx&E-F#rr}t$djRqf4&low%xlod?TkD0@mD=&N*;$Bh zcbA%{^NTWgeI89#k>h{+beqjs{sxO680 zSgeQ@CDl=1QXbk8+k(8C-lMhEI^tIO2+`IzZ$)t9L{#>%8=no?g>4o}h#Lj~UZ$pu zk1Q|57O{%rje)LcruK3q4Fb{6t4iWC5zgq|_uVKwKLzK!-h$VUi^6k^jA&A=<}ewc-ffM*j_da{kCaDdTa;x@Z$Z*oPP$h zRUhyh=asnO7)kA0tAhu*g=kq#0`j?R$R8YeK;WpWA!pHX9Mmn3$^2=!Yn!9kJ1r0O zBqp=c`HS%u!zkPuu>{?pw-%e3AIBPtY^dqQv$4(dCs?Mr8_NnZs2xF{ad>$PYAe=2 z*GpfrFV1=3{EKg?@J~PS*^ibe3`-*O4LsB~Eth>a^(VF-QsN)hNfq5ue2OPW$?mZpIMM{(Zs)j!{0&Y;eERNMeq~7xKTPcef4TB}zE67NO}O&gv; zc|I>-C(NUH18px^lQXewi+LowUPDhLAyI(Vhos}*zt->^)!w1!x9dpSk4T(72{Vm$ExDCxe7|Ax2 zXkp6_=XiU>d#TN@da(YywaBvE813GB5Ut&|2fuX^Q+=&v*w0vp)tUB$=d@-o7Fb*o z28)x>Vjnl;i&kOM2;ey_i724G`hI`=#>YRi_Qawc3H`+DjEQ^}r-Ve$PNMB`0x}iIn)l;b6RQekM{E z4&oAhd*p{w(E6+`SiklTa+)!oziLSxa@-(+rxi+}i!yC!s|wBDm0g79*%q^no7ahJ zOyY28&r9m@GhM77Y=Xwt?&gjEx&np#K88Atnnm*@ext*%07rT@A?IwGJ&-$*x~wLCr1&!#-*?Dm>r56!Vi-7puBQPan<&P=tFG-_PS$&bXw}L+U0fFbXOuWnyJS=IQ~-0bJ2`TCcI=*U#vtvM^`SW;AL+&;d6Co@H^8Ptm?8HD*bc=E;F&C zvS)sU7dQA)fk!T&XG(c^yO9q1?X?g6P(Q++E}e{G^?#w$?#isp%nR6K+dIU2p@PYb 
zGNIe$YLWV4cM7_fAiDjzNZm3YNnk7SDT_78b-oe4K^&oim6MUFWvZxE=@O35Jk1X8 z4n$+SkE2q^ux{lFe9LLcIMX>1>GOB9Wm@Mb!>mbYs@V-RSg~DnX6DE}e&Sc$_PPow zjcXJcU67-OlhQE%Ukd;tZm94hc8$dJ^}YD*q&|FsdmioH@|LPd&101<#v{c=e5S!#fFA}B zSduRx)-!xd?b263j(RR=QDqus`Ne~(Kh=pI&vT=84Z2~@+(^7$&km(+!YF&$$i3O2 z8fO(BqrK?`B!{=1+vn3nx;8!D@ zZkC5!*Ul9BEvmt%=NKS&-%R9}gQ=WZ4e0yJI%>z{&y-!%5bM6*7@zFBiTlm3^EZ%%K2G=*9;`!8l8wt2yYmNr?xSco89Zw3 zF#GYV0bf_Efimtnil?i%pj`VE$ZeH1-~Xcwnl2@a3%7TmeL17V*VCo=C(=H#pI_}p z7K^9xr??yy?dX-|E1PHFvQ1ZUT22Ozx%G&J7fWy|wMM<6Hht0(tBOr{*Iab?U9Y_G?~8{}jClh}>`$duIGhz3ub!&j@&3k6AW%IGrmrp^`l?d_+g7|G(5Yt?skI%?l_*0} zzw=oysXQvbDiuxE{K?z3Y(Cq0!h(7bE03z5v8e0GX=GKOiq${drdsr0V6#&@SoiE> zxSTVa9ZF%bR?sh;Fr^f;3zN|!+L#DB?1$0rtx= zfaSAaP+#xKpi5u3qb3n0Hf=IxH^{n+uVqHDYFm2PPH>buT9Jdk9_T=p_5OHXtP-2o ztij8?l8yGvwx)XTDInn@8{`st3@==Gg12sCDp$pKH&vQ%L6w;qv#rn9;d@2bS&*yG z-ku|g4z?FlV|`pH>m^EP%`Fi>czid%$EArMaZq;V#w`r4 zn(%!JJ<)m1VXWn2iCY>|kp4GW?6Cd;8}HGAsy-dUO>4st!F$U_RlR2KtbRpZ5k#_6 z_er5^doH8Mr5o_lhZ<kN6kg1J`iND3RWDVQgNIt#3A%>jr@48! 
zbM;MRd_9o$T3X2#mK;Nd4?Otrss^4?Z^Lff?TC&%o-VFgevluV?#x#;v=Q5l+JSnh zbi6C~1Yi7o6B}E)PMlD+8`0X^#WN$*@cLQaeCdEED0;wM%$GQY&U26PuMY1N+t`ax zyLYho>v{>kR>>UkP4AaDe^xJ=x&dPjBO#tA-o`)K6eMQ%%)&XpjQOq0r}0~EJVI47 zn%J;~Z+SCkP7v>%oQ8KiTp_-!P>=G3uHwh*rehnseIqZKKdCTXZOSWF6WP4qiY5(N ziQ^(3i(D=?Q4ek1cq=+jqUhmD-b|h4DE+X8*hLtQe(EcWLk!e#uzw*64?fCn3%-CA z&)r1fM#lWHzn`O*kqjOG2|=AH}908t|X| zk#$VG;QyRIQj_5R&-X|FRqy9G3;f@HfMlceNArzC!At*?UH&;1x=8-%{->t*Un}R- z{|x+F`_BsB|K$GHZ>mQ98)yAUCEmYs&#?SY-ha)f{X6e{?f>Al`cK|}#q!^Ib*BCY zul0ZO{wo@PN>BgUJh=bl{V(hny2#0y{TGG4wE6$g{;Bo-)Bd}k@uwhnq}?15@u!XO g{i*zwkfi?nj`SnPLKoRT<+>w0vJ!vj|3Uk|04-SO(*OVf literal 0 HcmV?d00001 diff --git a/test/data/deepspeed/model/params.00000/latest b/test/data/deepspeed/model/params.00000/latest new file mode 100644 index 000000000..641c40ce8 --- /dev/null +++ b/test/data/deepspeed/model/params.00000/latest @@ -0,0 +1 @@ +global_step4000 \ No newline at end of file diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index 101190fae..80cbb1016 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -232,6 +232,9 @@ def test_inference_args(test_params, expected_params): no_reload_on_learning_rate_reduce=False, fixed_param_names=[], fixed_param_strategy=None, + local_rank=None, + deepspeed_fp16=False, + deepspeed_bf16=False, decode_and_evaluate=500, stop_training_on_decoder_failure=False, seed=1, diff --git a/test/unit/test_deepspeed.py b/test/unit/test_deepspeed.py new file mode 100644 index 000000000..74a3efc7a --- /dev/null +++ b/test/unit/test_deepspeed.py @@ -0,0 +1,46 @@ +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import os +import shutil +import tempfile + +import pytest + +import torch + +import sockeye.constants as C +from sockeye.convert_deepspeed import convert_model_checkpoints + +# Only run tests in this file if DeepSpeed is installed +try: + import deepspeed + deepspeed_installed = True +except: + deepspeed_installed = False + + +@pytest.mark.skipif(not deepspeed_installed, reason='DeepSpeed is not installed') +def test_convert_model_checkpoints(): + with tempfile.TemporaryDirectory() as work_dir: + model_dir = os.path.join(work_dir, 'model') + shutil.copytree(os.path.join('test', 'data', 'deepspeed', 'model'), model_dir, symlinks=True) + # Convert + convert_model_checkpoints(model_dirname=model_dir, keep_deepspeed=False) + # Check + for fname in os.listdir(model_dir): + if fname.startswith(C.PARAMS_PREFIX) and fname[len(C.PARAMS_PREFIX):].isdigit(): + converted_params = torch.load(os.path.join(model_dir, fname)) + reference_params = torch.load(os.path.join('test', 'data', 'deepspeed', 'converted', fname)) + for key in converted_params.keys() | reference_params.keys(): + assert torch.allclose(converted_params[key], reference_params[key])