
max_memory and offload_folder options not working for big models #78

Closed
abhinavkulkarni opened this issue May 14, 2023 · 8 comments

@abhinavkulkarni

abhinavkulkarni commented May 14, 2023

Hi,

I have a GeForce RTX 3060 GPU with 12 GB of VRAM. I can load and quantize models with up to 3B parameters, but I run into trouble when I try to load models with 6B or more parameters.

Here are the GPU details:

$ nvidia-smi 
Sun May 14 15:33:10 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3060         On | 00000000:01:00.0 Off |                  N/A |
|  0%   40C    P8               14W / 170W|      1MiB / 12288MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

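I try to load them using:

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "EleutherAI/gpt-j-6b"
quantized_model_dir = "EleutherAI/gpt-j-6b-4bit-128g"

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
)

max_memory = {0: "8GiB"}

# calibration data: a list of dicts whose only keys are "input_ids" and "attention_mask"
# (placeholder sentence for illustration; substitute your own examples)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
examples = [tokenizer("auto-gptq is an easy-to-use model quantization library based on the GPTQ algorithm.")]

# load the un-quantized model; by default it is loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, max_memory=max_memory, offload_folder="offload")

# quantize the model using the calibration examples
model.quantize(examples, use_triton=False)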

However, I get the error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[6], line 11
      8 model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, max_memory=max_memory, offload_folder="offload")
     10 # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
---> 11 model.quantize(examples, use_triton=False)
     13 # save quantized model
     14 model.save_quantized(quantized_model_dir)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/auto_gptq/modeling/_base.py:220, in BaseGPTQForCausalLM.quantize(self, examples, batch_size, use_triton, use_cuda_fp16, autotune_warmup_after_quantized, cache_examples_on_gpu)
    218     ori_outside_layer_module_devices[module_name] = get_device(module)
    219     if module is not None:
--> 220         move_to_device(module, cur_layer_device)
    222 # get inputs for first layer
    223 layers[0] = LayerHijacker(layers[0], cur_layer_device)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/auto_gptq/modeling/_utils.py:24, in move_to_device(obj, device)
     22 def move_to_device(obj: Union[torch.Tensor, nn.Module], device: torch.device):
     23     if get_device(obj) != device:
---> 24         obj = obj.to(device)
     25     return obj

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:1145, in Module.to(self, *args, **kwargs)
   1141         return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                     non_blocking, memory_format=convert_to_format)
   1143     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1145 return self._apply(convert)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:820, in Module._apply(self, fn)
    816 # Tensors stored in modules are graph leaves, and we don't want to
    817 # track autograd history of `param_applied`, so we have to use
    818 # `with torch.no_grad():`
    819 with torch.no_grad():
--> 820     param_applied = fn(param)
    821 should_use_set_data = compute_should_use_set_data(param, param_applied)
    822 if should_use_set_data:

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:1143, in Module.to.<locals>.convert(t)
   1140 if convert_to_format is not None and t.dim() in (4, 5):
   1141     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                 non_blocking, memory_format=convert_to_format)
-> 1143 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

NotImplementedError: Cannot copy out of meta tensor; no data!

No matter what memory I specify in max_memory, I get the same error.

What am I missing?

@PanQiWei
Collaborator

The offload_folder argument is not supported. To use CPU offload, set max_memory instead, e.g. max_memory={"cpu": "30GIB", 0: "3GIB"}. For more details, see this tutorial.
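The max_memory format suggested above maps integer GPU indices and the string "cpu" to size strings. Below is a minimal sketch of a helper that builds such a dict for a single GPU; the function name offload_max_memory and its parameters are hypothetical, and the commented-out call simply mirrors the from_pretrained call used elsewhere in this thread.

```python
def offload_max_memory(gpu_gib: float, cpu_gib: float, gpu_index: int = 0) -> dict:
    """Build a max_memory dict that caps one GPU and lets the remaining
    weights be offloaded to CPU RAM (hypothetical helper)."""
    if gpu_gib <= 0 or cpu_gib <= 0:
        raise ValueError("memory budgets must be positive")
    # Keys: integer GPU index and "cpu"; values: size strings in GiB.
    return {gpu_index: f"{gpu_gib:g}GiB", "cpu": f"{cpu_gib:g}GiB"}


max_memory = offload_max_memory(gpu_gib=3, cpu_gib=30)
print(max_memory)  # {0: '3GiB', 'cpu': '30GiB'}

# Not executed here; mirrors the call used in this thread:
# model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config,
#                                             max_memory=max_memory)
```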

@oobabooga
Contributor

I find that max_memory gets properly fed into accelerate, but the model loads entirely into VRAM nevertheless. For instance:

from auto_gptq import AutoGPTQForCausalLM

path_to_model = 'models/TheBloke_stable-vicuna-13B-GPTQ'
params = {
    'model_basename': 'stable-vicuna-13B-GPTQ-4bit.compat.no-act-order',
    'use_triton': False,
    'use_safetensors': True,
    'max_memory': {0: '2GiB', 'cpu': '99GiB'}
}

model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)

input()  # pause here so the VRAM allocation can be checked with nvidia-smi

nvidia-smi reports a 7533MiB allocation instead of something close to 2000MiB.
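One way to cross-check a report like this is to compare nvidia-smi against the device map accelerate produced. The sketch below counts modules per device in a plain {module_name: device} dict. It assumes the map is available as the hf_device_map attribute that transformers sets on models loaded with a device map; for the AutoGPTQ wrapper that would presumably be model.model.hf_device_map, which is an assumption. The sample map is a shortened, hypothetical one in the shape shown later in this thread.

```python
from collections import Counter


def summarize_device_map(device_map: dict) -> dict:
    """Count how many modules a device map assigns to each device."""
    return dict(Counter(device_map.values()))


# Shortened, hypothetical device map (module name -> device).
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    "model.layers.2": "cpu",
    "model.norm": "cpu",
    "lm_head": "cpu",
}
print(summarize_device_map(device_map))  # {0: 3, 'cpu': 3}
```

A map that looks right does not by itself prove the weights ended up there, which is exactly the discrepancy reported above, so the count is best read alongside the nvidia-smi figure.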

@Dessix

Dessix commented May 15, 2023

I ended up having to drastically reduce my example dataset to get it to load, because of the VRAM constraints involved. Disabling the cache_examples_on_gpu flag didn't appear to change this. Also, a truly CPU-only run appears to be impossible, since having no GPUs leads to a division by zero.

@abhinavkulkarni
Author

abhinavkulkarni commented May 16, 2023

Hi @PanQiWei,

So, I am still running into an error while trying to quantize large models (that don't fit in the 12GB of VRAM).

The script ran for 8 mins and then failed:

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "EleutherAI/gpt-j-6b"
quantized_model_dir = "EleutherAI/gpt-j-6b-4bit-128g"

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
)

# cap GPU 0 at 6 GiB and let the remaining weights be offloaded to CPU RAM
max_memory = {0: "6GiB", "cpu": "80GiB"}

# load the un-quantized model; by default it is loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, max_memory=max_memory)

# quantize the model; examples is a list of dicts whose only keys are "input_ids" and "attention_mask"
# (10 tokenized sentences in this run, defined elsewhere)
model.quantize(examples, use_triton=False)

# save quantized model
model.save_quantized(quantized_model_dir)

I run into the following problem:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[6], line 16
     13 model.quantize(examples, use_triton=False)
     15 # save quantized model
---> 16 model.save_quantized(quantized_model_dir)
     18 # save quantized model using safetensors
     19 model.save_quantized(quantized_model_dir, use_safetensors=True)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/auto_gptq/modeling/_base.py:392, in BaseGPTQForCausalLM.save_quantized(self, save_dir, use_safetensors)
    389 if not self.quantized:
    390     raise EnvironmentError("can only save quantized model, please execute .quantize first.")
--> 392 self.model.to(CPU)
    394 model_save_name = f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
    395 if use_safetensors:

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/transformers/modeling_utils.py:1878, in PreTrainedModel.to(self, *args, **kwargs)
   1873     raise ValueError(
   1874         "`.to` is not supported for `8-bit` models. Please use the model as it is, since the"
   1875         " model has already been set to the correct devices and casted to the correct `dtype`."
   1876     )
   1877 else:
-> 1878     return super().to(*args, **kwargs)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:1145, in Module.to(self, *args, **kwargs)
   1141         return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                     non_blocking, memory_format=convert_to_format)
   1143     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1145 return self._apply(convert)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
    (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

    [... skipping similar frames: Module._apply at line 797 (1 times)]

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:820, in Module._apply(self, fn)
    816 # Tensors stored in modules are graph leaves, and we don't want to
    817 # track autograd history of `param_applied`, so we have to use
    818 # `with torch.no_grad():`
    819 with torch.no_grad():
--> 820     param_applied = fn(param)
    821 should_use_set_data = compute_should_use_set_data(param, param_applied)
    822 if should_use_set_data:

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:1143, in Module.to.<locals>.convert(t)
   1140 if convert_to_format is not None and t.dim() in (4, 5):
   1141     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                 non_blocking, memory_format=convert_to_format)
-> 1143 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

NotImplementedError: Cannot copy out of meta tensor; no data!

Also, VRAM usage goes up to 11 GB even though I specified 6GiB for the GPU in max_memory. The examples list has only 10 sentences.

@z80maniac
Contributor

but the model loads entirely into VRAM nevertheless

I may be mistaken, but there might be a bug in the accelerate library.

In #47 (comment) I've shown the device map that is generated when you specify max_memory={0: "2GIB", "cpu": "30GIB"}:

{'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 'cpu', 'model.layers.9': 'cpu', 'model.layers.10': 'cpu', 'model.layers.11': 'cpu', 'model.layers.12': 'cpu', 'model.layers.13': 'cpu', 'model.layers.14': 'cpu', 'model.layers.15': 'cpu', 'model.layers.16': 'cpu', 'model.layers.17': 'cpu', 'model.layers.18': 'cpu', 'model.layers.19': 'cpu', 'model.layers.20': 'cpu', 'model.layers.21': 'cpu', 'model.layers.22': 'cpu', 'model.layers.23': 'cpu', 'model.layers.24': 'cpu', 'model.layers.25': 'cpu', 'model.layers.26': 'cpu', 'model.layers.27': 'cpu', 'model.layers.28': 'cpu', 'model.layers.29': 'cpu', 'model.layers.30': 'cpu', 'model.layers.31': 'cpu', 'model.layers.32': 'cpu', 'model.layers.33': 'cpu', 'model.layers.34': 'cpu', 'model.layers.35': 'cpu', 'model.layers.36': 'cpu', 'model.layers.37': 'cpu', 'model.layers.38': 'cpu', 'model.layers.39': 'cpu', 'model.norm': 'cpu', 'lm_head': 'cpu'}

Let's take this part:

'model.layers.1': 0

The accelerate library has the following code in modeling.py (in the load_state_dict function):

# For each device, get the weights that go there
device_weights = {device: [] for device in devices}
for module_name, device in device_map.items():
    if device in devices:
        device_weights[device].extend([k for k in weight_names if k.startswith(module_name)])

This code is supposed to distribute all weights to their respective devices, so that 'model.layers.1': 0 means that every weight whose name starts with model.layers.1 (e.g. model.layers.1.input_layernorm.weight) goes to device 0.
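Whether the loop does what is intended can be checked by running it in isolation on toy data. In this sketch the 12-layer model, its weight names, and the split of layers 0-1 onto GPU 0 are all hypothetical; only the loop body is taken from the snippet quoted above.

```python
# Hypothetical 12-layer model: one weight per layer keeps the output short.
weight_names = [f"model.layers.{i}.mlp.weight" for i in range(12)]

# Hypothetical split: layers 0-1 on GPU 0, layers 2-11 on the CPU.
device_map = {f"model.layers.{i}": (0 if i < 2 else "cpu") for i in range(12)}
devices = [0, "cpu"]

# For each device, get the weights that go there (loop body as quoted above).
device_weights = {device: [] for device in devices}
for module_name, device in device_map.items():
    if device in devices:
        device_weights[device].extend([k for k in weight_names if k.startswith(module_name)])

# Weights that were routed to both devices.
doubled = sorted(set(device_weights[0]) & set(device_weights["cpu"]))
print(doubled)  # ['model.layers.10.mlp.weight', 'model.layers.11.mlp.weight']
```

Layers 10 and 11 are collected for GPU 0 through the 'model.layers.1' entry and for the CPU through their own entries.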

And here's the problem: model.layers.11 also starts with model.layers.1, and so do model.layers.12, model.layers.13, and so on. So in the original example:

'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0

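would mean that layers 1-3 go to the GPU, and so do layers 10-39 (model.layers.1 matches 10-19, model.layers.2 matches 20-29, and model.layers.3 matches 30-39). In fact, layers 10-39 seem to go to both the CPU and the GPU, since each of them also has its own 'cpu' entry in the map.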
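One way to avoid the collision is to match either the exact module name or the name followed by a dot. The sketch below is an assumption about how such a fix could look, not accelerate's actual code; the function name weights_for_module is hypothetical. It also treats an empty module name as the whole model, as in device_map={"": 0}.

```python
def weights_for_module(module_name: str, weight_names: list) -> list:
    """Return the weights of module_name itself and of its submodules."""
    if module_name == "":
        # An empty name denotes the whole model, e.g. device_map={"": 0}.
        return list(weight_names)
    # Requiring a "." boundary keeps "model.layers.1" from also matching
    # "model.layers.10.*", "model.layers.11.*", and so on.
    prefix = module_name + "."
    return [k for k in weight_names if k == module_name or k.startswith(prefix)]


weight_names = [f"model.layers.{i}.mlp.weight" for i in range(12)]
print(weights_for_module("model.layers.1", weight_names))  # ['model.layers.1.mlp.weight']
```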

Disclaimer: I have no idea how the accelerate library works; I just tried to debug the code and stumbled upon this odd logic, so I may be digging in the wrong direction.

Also, as I've shown in #47 (comment), even if only model.layers.0 is placed on the GPU (so the prefix bug above is not triggered), the model is still loaded entirely into VRAM. So there may be an additional problem elsewhere.

@Ph0rk0z
Contributor

Ph0rk0z commented May 17, 2023

With accelerate, I always have to give it less memory than I actually want it to use to get usable results: if I tell it 16 GB, it loads 18 GB.
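The workaround described above (asking for less GPU memory than you really want) can be automated by scaling down the GPU entries of a max_memory dict. This is a minimal sketch under stated assumptions: parse_size is a simplified parser written for illustration (not accelerate's own), every key other than "cpu" and "disk" is treated as a GPU, and the 0.85 fraction is an arbitrary illustrative default rather than a recommended value.

```python
import re

# Binary (GiB/MiB) and decimal (GB/MB) units, matched case-insensitively.
_UNITS = {"GIB": 2**30, "MIB": 2**20, "GB": 10**9, "MB": 10**6}


def parse_size(size: str) -> int:
    """Parse sizes such as "16GiB" or "500MB" into bytes (simplified)."""
    match = re.fullmatch(r"\s*([0-9]*\.?[0-9]+)\s*([A-Za-z]+)\s*", size)
    if not match or match.group(2).upper() not in _UNITS:
        raise ValueError(f"unrecognized size: {size!r}")
    return int(float(match.group(1)) * _UNITS[match.group(2).upper()])


def with_headroom(max_memory: dict, gpu_fraction: float = 0.85) -> dict:
    """Shrink every GPU entry of a max_memory dict, leaving cpu/disk unchanged."""
    adjusted = {}
    for device, size in max_memory.items():
        if device in ("cpu", "disk"):
            adjusted[device] = size
        else:
            mib = int(parse_size(size) * gpu_fraction) // 2**20
            adjusted[device] = f"{mib}MiB"
    return adjusted


print(with_headroom({0: "16GiB", "cpu": "64GiB"}))  # {0: '13926MiB', 'cpu': '64GiB'}
```

The right fraction depends on how far a given setup overshoots, so it would have to be tuned by watching nvidia-smi.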

@PanQiWei
Copy link
Collaborator

Hi! PR #100 fixes the bug where a quantized model could not be saved when the pretrained model was loaded with CPU offload.

@PanQiWei
Collaborator

Closing this issue, since the problem reported here has been fixed.

6 participants