Skip to content

Commit

Permalink
reduce n_workers for hqq
Browse files (browse the repository at this point in the history)
  • Loading branch information
KeremTurgutlu committed Mar 12, 2024
1 parent b3768f4 commit cf61426
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,16 +612,16 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
param_count = sum((p.numel() for n,p in model.named_parameters()))
if rank == 0 and args['verbose']:
print_func(f"Total model params: {param_count}")
+ quant_method = "hqq" if args["train_type"] in ["hqq_lora"] else "bnb"
+ devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
+ left = int(os.cpu_count()/torch.cuda.device_count())
+ right = int((4 if quant_method == "hqq" else 8) * (devprops.total_memory/1e9/40) * (70/(param_count/1e9)))
+ n_workers = min(left, right)
+ if rank == 0 and args['verbose']:
+ print_func(f"Using n_workers: {n_workers} for loading")
start = time.time()
for filename in files:
weights = safetensors.torch.load_file(filename)
- quant_method = "hqq" if args["train_type"] in ["hqq_lora"] else "bnb"
- devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
- left = int(os.cpu_count()/torch.cuda.device_count())
- right = int(8 * (devprops.total_memory/1e9/40) * (70/(param_count/1e9)))
- n_workers = min(left, right)
- if rank == 0 and args['verbose']:
- print_func(f"Using n_workers: {n_workers} for loading")
parallel(load_and_quantize_parallel, iter(weights.items()), n_workers=n_workers, threadpool=True,
model=model, dtype=torch_dtype, device=local_rank, skip_names=load_param_skip_names,
is_meta_rank=(args["low_memory"] and rank!=0), verbose=args["verbose"], quant_method=quant_method)
Expand Down

0 comments on commit cf61426

Please sign in to comment.