diff --git a/train.py b/train.py index 5e3155d..8fa1698 100644 --- a/train.py +++ b/train.py @@ -622,7 +622,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs): n_workers = min(left, right) if rank == 0 and args['verbose']: print_func(f"Using n_workers: {n_workers} for loading") - parallel(load_and_quantize_parallel, weights.items(), n_workers=n_workers, threadpool=True, + parallel(load_and_quantize_parallel, iter(weights.items()), n_workers=n_workers, threadpool=True, model=model, dtype=torch_dtype, device=local_rank, skip_names=load_param_skip_names, is_meta_rank=(args["low_memory"] and rank!=0), verbose=args["verbose"], quant_method=quant_method) if rank == 0 and args["verbose"]: