diff --git a/train.py b/train.py index 8fa1698..f44946b 100644 --- a/train.py +++ b/train.py @@ -612,16 +612,16 @@ def load_and_quantize_parallel(name_param, model, **kwargs): param_count = sum((p.numel() for n,p in model.named_parameters())) if rank == 0 and args['verbose']: print_func(f"Total model params: {param_count}") + quant_method = "hqq" if args["train_type"] in ["hqq_lora"] else "bnb" + devprops = torch.cuda.get_device_properties(torch.cuda.current_device()) + left = int(os.cpu_count()/torch.cuda.device_count()) + right = int((4 if quant_method == "hqq" else 8) * (devprops.total_memory/1e9/40) * (70/(param_count/1e9))) + n_workers = min(left, right) + if rank == 0 and args['verbose']: + print_func(f"Using n_workers: {n_workers} for loading") start = time.time() for filename in files: weights = safetensors.torch.load_file(filename) - quant_method = "hqq" if args["train_type"] in ["hqq_lora"] else "bnb" - devprops = torch.cuda.get_device_properties(torch.cuda.current_device()) - left = int(os.cpu_count()/torch.cuda.device_count()) - right = int(8 * (devprops.total_memory/1e9/40) * (70/(param_count/1e9))) - n_workers = min(left, right) - if rank == 0 and args['verbose']: - print_func(f"Using n_workers: {n_workers} for loading") parallel(load_and_quantize_parallel, iter(weights.items()), n_workers=n_workers, threadpool=True, model=model, dtype=torch_dtype, device=local_rank, skip_names=load_param_skip_names, is_meta_rank=(args["low_memory"] and rank!=0), verbose=args["verbose"], quant_method=quant_method)