support user customized device_map #47

Closed
z80maniac opened this issue May 2, 2023 · 3 comments
Labels
enhancement New feature or request

Comments

@z80maniac
Contributor

It seems like device_map does not offload anything to CPU if it's constructed manually.

Here's an example:

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import torch

model_path = "/opt/models/vicuna-13B-1.1-GPTQ-4bit-128g"

device_map = {'model.embed_tokens': 0, 'model.norm': 0, 'lm_head': 0, 'model.layers.0': 0, 'model.layers.1': 'cpu', 'model.layers.2': 'cpu', 'model.layers.3': 'cpu', 'model.layers.4': 'cpu', 'model.layers.5': 'cpu', 'model.layers.6': 'cpu', 'model.layers.7': 'cpu', 'model.layers.8': 'cpu', 'model.layers.9': 'cpu', 'model.layers.10': 'cpu', 'model.layers.11': 'cpu', 'model.layers.12': 'cpu', 'model.layers.13': 'cpu', 'model.layers.14': 'cpu', 'model.layers.15': 'cpu', 'model.layers.16': 'cpu', 'model.layers.17': 'cpu', 'model.layers.18': 'cpu', 'model.layers.19': 'cpu', 'model.layers.20': 'cpu', 'model.layers.21': 'cpu', 'model.layers.22': 'cpu', 'model.layers.23': 'cpu', 'model.layers.24': 'cpu', 'model.layers.25': 'cpu', 'model.layers.26': 'cpu', 'model.layers.27': 'cpu', 'model.layers.28': 'cpu', 'model.layers.29': 'cpu', 'model.layers.30': 'cpu', 'model.layers.31': 'cpu', 'model.layers.32': 'cpu', 'model.layers.33': 'cpu', 'model.layers.34': 'cpu', 'model.layers.35': 'cpu', 'model.layers.36': 'cpu', 'model.layers.37': 'cpu', 'model.layers.38': 'cpu', 'model.layers.39': 'cpu'}

quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
)

full_gpu = True  # set to False to load with the manual device_map above

if full_gpu:
    model = AutoGPTQForCausalLM.from_quantized(
        model_path,
        device="cuda:0",
        use_safetensors=True,
        quantize_config=quantize_config,
        model_basename="vicuna-13B-1.1-GPTQ-4bit-128g.latest"
    )
else:
    model = AutoGPTQForCausalLM.from_quantized(
        model_path,
        device="cpu",
        use_safetensors=True,
        quantize_config=quantize_config,
        model_basename="vicuna-13B-1.1-GPTQ-4bit-128g.latest",
        device_map=device_map
    )

mem_gb = round(torch.cuda.memory_allocated(0) / 1000 / 1000 / 1000)
print(f"USED VRAM: {mem_gb}GB")

The model is https://huggingface.co/TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g. The device_map is constructed so that only the first model layer is on GPU, and the rest is supposed to be on CPU. The last two lines measure the used VRAM.

When full_gpu = True, everything is on GPU, and I get this:

USED VRAM: 7GB

which is expected.

But when I set full_gpu = False, so that the device map is used, I get the same result at the end:

USED VRAM: 7GB

I double-checked that the device_map is actually used, but it seems like it doesn't offload anything. Am I missing something?

@PanQiWei
Collaborator

PanQiWei commented May 4, 2023

Currently device_map doesn't support manual construction; only the pre-defined strategies are supported: "auto", "balanced", "balanced_low_0", "sequential". Alternatively, you can set max_memory to specify the maximum memory each device may use, which means that if you want CPU offload you can set something like this: max_memory={0: "1GIB", "cpu": "16GIB"}. For more details you can also refer to this tutorial.
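For example, with the reproduction script from this issue, that would look roughly like the sketch below (the path, basename, and memory limits are just the values already quoted here, not tuned recommendations):

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# CPU offload via max_memory instead of a hand-built device_map;
# path and basename are taken from the reproduction script above.
model = AutoGPTQForCausalLM.from_quantized(
    "/opt/models/vicuna-13B-1.1-GPTQ-4bit-128g",
    device="cpu",
    use_safetensors=True,
    quantize_config=BaseQuantizeConfig(bits=4, group_size=128),
    model_basename="vicuna-13B-1.1-GPTQ-4bit-128g.latest",
    max_memory={0: "1GIB", "cpu": "16GIB"},  # cap GPU 0, put the rest on CPU
)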

@z80maniac
Contributor Author

OK.

It's just that initially I found this:

https://github.com/PanQiWei/AutoGPTQ/blob/d79aec7bd08a789f6b7384df4cc91b57226a2cc1/auto_gptq/modeling/_base.py#L541-L542

So I assumed that if {"": device} is accepted by load_checkpoint_and_dispatch as a device_map, then any custom device map should also be accepted. Alas, that's not the case.
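For context, this is roughly the accelerate API shape I was referring to. The snippet below is a generic sketch of load_checkpoint_and_dispatch, not the actual AutoGPTQ loading path; the checkpoint filename and the no_split class are assumptions for a LLaMA-style model:

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

# Build an empty-weight skeleton so dispatch can decide where modules go.
config = AutoConfig.from_pretrained("/opt/models/vicuna-13B-1.1-GPTQ-4bit-128g")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# load_checkpoint_and_dispatch accepts either the trivial {"": device} map
# or a per-module map like the one in my first comment.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint="/opt/models/vicuna-13B-1.1-GPTQ-4bit-128g/model.safetensors",  # placeholder
    device_map={"": "cpu"},  # or a per-layer dict
    no_split_module_classes=["LlamaDecoderLayer"],  # assumption for LLaMA-family models
)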

@z80maniac
Contributor Author

I've inspected the code and found out that, in theory, custom device maps should work: when max_memory is passed to accelerate.load_checkpoint_and_dispatch, it is used to construct a corresponding device map (via infer_auto_device_map) that looks similar to the one I provided in the example code.
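For reference, the conversion itself is just accelerate's infer_auto_device_map. A sketch (the skeleton is built the same way as in the previous sketch, and the no_split class name is again an assumption for a LLaMA-style model):

from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Empty-weight skeleton, only used to compute the placement.
config = AutoConfig.from_pretrained("/opt/models/vicuna-13B-1.1-GPTQ-4bit-128g")
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config)

# The max_memory limits are turned into a per-module device map,
# similar to the one printed further below.
device_map = infer_auto_device_map(
    empty_model,
    max_memory={0: "2GIB", "cpu": "30GIB"},
    no_split_module_classes=["LlamaDecoderLayer"],  # assumption
)
print(device_map)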

So I decided to just test max_memory, and it turns out that it also doesn't work.

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import torch

model_path = "/opt/models/vicuna-13B-1.1-GPTQ-4bit-128g"

quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
)

model = AutoGPTQForCausalLM.from_quantized(
    model_path,
    device="cpu",
    use_safetensors=True,
    use_triton=False,
    quantize_config=quantize_config,
    model_basename="vicuna-13B-1.1-GPTQ-4bit-128g.latest",
    max_memory={0: "2GIB", "cpu": "30GIB"}
)

mem_gb = round(torch.cuda.memory_allocated(0) / 1000 / 1000 / 1000)
print(f"USED VRAM: {mem_gb}GB")

This internally constructs the following device map (via infer_auto_device_map):

{'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 'cpu', 'model.layers.9': 'cpu', 'model.layers.10': 'cpu', 'model.layers.11': 'cpu', 'model.layers.12': 'cpu', 'model.layers.13': 'cpu', 'model.layers.14': 'cpu', 'model.layers.15': 'cpu', 'model.layers.16': 'cpu', 'model.layers.17': 'cpu', 'model.layers.18': 'cpu', 'model.layers.19': 'cpu', 'model.layers.20': 'cpu', 'model.layers.21': 'cpu', 'model.layers.22': 'cpu', 'model.layers.23': 'cpu', 'model.layers.24': 'cpu', 'model.layers.25': 'cpu', 'model.layers.26': 'cpu', 'model.layers.27': 'cpu', 'model.layers.28': 'cpu', 'model.layers.29': 'cpu', 'model.layers.30': 'cpu', 'model.layers.31': 'cpu', 'model.layers.32': 'cpu', 'model.layers.33': 'cpu', 'model.layers.34': 'cpu', 'model.layers.35': 'cpu', 'model.layers.36': 'cpu', 'model.layers.37': 'cpu', 'model.layers.38': 'cpu', 'model.layers.39': 'cpu', 'model.norm': 'cpu', 'lm_head': 'cpu'}

The limit for the GPU was set to 2GiB, but in reality it still uses the same amount of VRAM:

USED VRAM: 7GB

I can make the following conclusions:

  • It is perfectly OK to pass custom device maps to from_quantized (the type hints need to be changed though).
  • No matter how the device map is constructed (manually or via max_memory), it does not work.

Seems like the problem is somewhere deeper.
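One way to confirm this, besides torch.cuda.memory_allocated, is to tally where the loaded tensors actually live. This assumes the wrapped transformers model is reachable as model.model on the AutoGPTQ object; with a working device map, most of the bytes should land on "cpu" rather than "cuda:0":

from collections import Counter

# Tally bytes per device; buffers are included because the GPTQ
# qweight/qzeros/scales tensors are registered as buffers, not parameters.
bytes_per_device = Counter()
wrapped = model.model  # assumption: the underlying transformers model
for _, tensor in list(wrapped.named_parameters()) + list(wrapped.named_buffers()):
    bytes_per_device[str(tensor.device)] += tensor.numel() * tensor.element_size()

for device, n_bytes in sorted(bytes_per_device.items()):
    print(f"{device}: {n_bytes / 1e9:.2f} GB")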

@PanQiWei PanQiWei added the enhancement New feature or request label May 14, 2023
@PanQiWei PanQiWei changed the title custom device_map does not offload anything to CPU support use customized device_map May 14, 2023
@PanQiWei PanQiWei changed the title support use customized device_map support user customized device_map May 14, 2023