
[BUG] CUDA OUT OF MEMORY #179

Closed · henan991201 opened this issue Jun 28, 2023 · 8 comments
Labels: bug (Something isn't working)


henan991201 commented Jun 28, 2023

I am using a 32GB V100 to try to quantize a llama-13B checkpoint, but it reports CUDA out of memory. The number of calibration samples is 128.
Thanks for any reply.

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # setting this to False can significantly speed up inference, but perplexity may be slightly worse
)
max_memory = dict()
if torch.cuda.is_available():
    max_memory.update({i: "32GiB" for i in range(torch.cuda.device_count())})

model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, max_memory=max_memory)

model.quantize(examples, batch_size=1)
henan991201 added the bug label on Jun 28, 2023
TheBloke (Contributor) commented:

The issue is that by using max_memory you are telling it to load the model into VRAM, as well as to do the quantisation in VRAM.

If you leave out max_memory, then it will load the model into RAM and then quantise on the GPU. This uses much less VRAM. For example, with a 24GB card I can make a GPTQ for a 33B or even 65B model.

I don't know if this is a bug or not - I guess maybe it is, because the quantization code should take max_memory into account and not cause the out-of-memory error.

But it's easy to solve the problem just by not using max_memory.
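
For reference, a minimal sketch of the snippet from the original post with max_memory removed (pretrained_model_dir, quantize_config and examples as defined there):

# Same as the original snippet, but without max_memory: the model is loaded
# into system RAM and only the layer currently being quantised is moved to
# the GPU, which needs far less VRAM.
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples, batch_size=1)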

Here is the function I use for quantising:

def quantize(model_dir, output_dir, traindataset, bits, group_size, desc_act, damp, batch_size = 1, use_triton=False, trust_remote_code=False, dtype='float16'):
    quantize_config = BaseQuantizeConfig(
        bits=bits,
        group_size=group_size,
        desc_act=desc_act,
        damp_percent=damp
    )

    if dtype == 'float16':
        torch_dtype  = torch.float16
    elif dtype == 'float32':
        torch_dtype  = torch.float32
    elif dtype == 'bfloat16':
        torch_dtype  = torch.bfloat16
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    logger.info(f"Loading model from {model_dir} with trust_remote_code={trust_remote_code} and dtype={torch_dtype}")
    model = AutoGPTQForCausalLM.from_pretrained(model_dir, quantize_config=quantize_config, low_cpu_mem_usage=True, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

    logger.info(f"Starting quantization to {output_dir} with use_triton={use_triton}")
    start_time = time.time()
    model.quantize(traindataset, use_triton=use_triton, batch_size=batch_size)

    logger.info(f"Time to quantize model at {output_dir} with use_triton={use_triton}: {time.time() - start_time:.2f}")

    logger.info(f"Saving quantized model to {output_dir}")
    model.save_quantized(output_dir, use_safetensors=True)
    logger.info("Done.")

That function is part of a script I wrote to easily quantise a model from the command line. The full script is below. You could save this and run it immediately to quantise with C4 or Wikitext2. It's able to make multiple GPTQs of a single base model, with varying parameters for testing. But you can use it to make just one via the command-line parameters.

It always uses 128 samples and I didn't add a parameter for that, but it would be easy to add.
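
For instance, a sketch of that change, using a hypothetical --num_samples flag (not present in the script below):

# Hypothetical addition: expose the calibration sample count instead of hard-coding 128.
parser.add_argument('--num_samples', type=int, default=128, help='Number of calibration samples')

# ...and where the calibration dataset is built:
traindataset = get_wikitext2(args.num_samples, 0, args.seqlen, tokenizer)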

Example execution of the script:

python3 quant_autogptq.py /workspace/llama-30b /workspace/llama-30b-gptq wikitext --bits 4 --group_size 128 --desc_act 0  --dtype float16

Script:

import time
import os
import logging

from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import numpy as np
import torch
import torch.nn as nn
import argparse

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    from datasets import load_dataset
    logger = logging.getLogger(__name__)

    wikidata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    wikilist = [' \n' if s == '' else s for s in wikidata['text'] ]

    text = ''.join(wikilist)
    logger.info("Tokenising wikitext2")
    trainenc = tokenizer(text, return_tensors='pt')

    import random
    random.seed(seed)
    np.random.seed(0)
    torch.random.manual_seed(0)

    traindataset = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        traindataset.append({'input_ids':inp,'attention_mask': attention_mask})
    return traindataset

def get_c4(nsamples, seed, seqlen, tokenizer):
    from datasets import load_dataset
    traindata = load_dataset(
        'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False
    )

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        trainloader.append({'input_ids':inp,'attention_mask': attention_mask})

    return trainloader

def quantize(model_dir, output_dir, traindataset, bits, group_size, desc_act, damp, batch_size = 1, use_triton=False, trust_remote_code=False, dtype='float16'):
    quantize_config = BaseQuantizeConfig(
        bits=bits,
        group_size=group_size,
        desc_act=desc_act,
        damp_percent=damp
    )

    if dtype == 'float16':
        torch_dtype  = torch.float16
    elif dtype == 'float32':
        torch_dtype  = torch.float32
    elif dtype == 'bfloat16':
        torch_dtype  = torch.bfloat16
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    logger.info(f"Loading model from {model_dir} with trust_remote_code={trust_remote_code} and dtype={torch_dtype}")
    model = AutoGPTQForCausalLM.from_pretrained(model_dir, quantize_config=quantize_config, low_cpu_mem_usage=True, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

    logger.info(f"Starting quantization to {output_dir} with use_triton={use_triton}")
    start_time = time.time()
    model.quantize(traindataset, use_triton=use_triton, batch_size=batch_size)

    logger.info(f"Time to quantize model at {output_dir} with use_triton={use_triton}: {time.time() - start_time:.2f}")

    logger.info(f"Saving quantized model to {output_dir}")
    model.save_quantized(output_dir, use_safetensors=True)
    logger.info("Done.")

if __name__ == "__main__":
    logger = logging.getLogger()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )

    parser = argparse.ArgumentParser(description='quantise')
    parser.add_argument('pretrained_model_dir', type=str, help='Repo name')
    parser.add_argument('output_dir_base', type=str, help='Output base folder')
    parser.add_argument('dataset', type=str, help='Calibration dataset: wikitext or c4')
    parser.add_argument('--trust_remote_code', action="store_true", help='Trust remote code')
    parser.add_argument('--use_triton', action="store_true", help='Use Triton for quantization')
    parser.add_argument('--bits', type=int, nargs='+', default=[4], help='Quantize bit(s)')
    parser.add_argument('--group_size', type=int, nargs='+', default=[32, 128, 1024, -1], help='Quantize group size(s)')
    parser.add_argument('--damp', type=float, nargs='+', default=[0.01], help='Quantize damp_percent(s)')
    parser.add_argument('--desc_act', type=int, nargs='+', default=[0, 1], help='Quantize desc_act(s) - 1 = True, 0 = False')
    parser.add_argument('--dtype', type=str, default='float16', choices=['float16', 'float32', 'bfloat16'], help='Torch dtype to load the model in')
    parser.add_argument('--seqlen', type=int, default=2048, help='Model sequence length')
    parser.add_argument('--batch_size', type=int, default=1, help='Quantize batch size for processing dataset samples')
    parser.add_argument('--stop_file', type=str, help='Filename to look for to stop inference, specific to this instance')

    args = parser.parse_args()

    stop_file = args.stop_file or ""

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_dir, use_fast=True, trust_remote_code=args.trust_remote_code)

    if args.dataset == 'wikitext':
        traindataset = get_wikitext2(128, 0, args.seqlen, tokenizer)
    elif args.dataset == 'c4':
        traindataset = get_c4(128, 0, args.seqlen, tokenizer)
    else:
        logger.error(f"Unsupported dataset: {args.dataset}")
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    abort = False

    iterations=[]
    for bits in args.bits:
        for group_size in args.group_size:
            for desc_act in args.desc_act:
                for damp in args.damp:
                    desc_act = (desc_act == 1)
                    iterations.append({"bits": bits, "group_size": group_size, "desc_act": desc_act, "damp": damp})

    num_iters = len(iterations)
    logger.info(f"Starting {num_iters} quantizations.")
    count=1
    for iteration in iterations:
        if not os.path.isfile("/workspace/gptq-ppl-test/STOP") and not os.path.isfile(stop_file) and not abort:
            bits = iteration['bits']
            group_size = iteration['group_size']
            desc_act = iteration['desc_act']
            damp = iteration['damp']

            output_dir = args.output_dir_base
            try:
                os.makedirs(output_dir, exist_ok=True)

                # Log file has same name as directory + .quantize.log, and is placed alongside model directory, not inside it
                # This ensures that we can delete the output_dir in case of error or abort, without losing the logfile.
                # Therefore the existence of the output_dir is a reliable indicator of whether a model has started or not.
                logger.info(f"[{count} / {num_iters}] Quantizing: bits = {bits} - group_size = {group_size} - desc_act = {desc_act} - damp_percent = {damp} to {output_dir}")
                try:
                    quantize(args.pretrained_model_dir, output_dir, traindataset, bits, group_size, desc_act, damp, args.batch_size, args.use_triton, trust_remote_code=args.trust_remote_code, dtype=args.dtype)
                except KeyboardInterrupt:
                    logger.error(f"Aborted. Will delete {output_dir}")
                    os.rmdir(output_dir)
                    abort = True
                except:
                    raise

            finally:
                count += 1
        else:
            logger.error("Aborting - told to stop!")
            break

RonanKMcGovern commented Aug 3, 2023

Nice script @TheBloke! Roughly how long does it take to get a single quantised model for a 13B model? Thanks.

Also, I suppose that to keep perplexity across all tasks one should ideally use mixed datasets? Or would wikitext alone still give good perplexity on coding, for example?

RonanKMcGovern commented:

It took me about 1 hour on a T4 to quantize a Llama-2-7B model, so I imagine maybe double that for 13B.

FYI - the c4 dataset doesn't load correctly, due to that dataset repo not being set up in HF. It's not an issue though, because you've selected wikitext only in the script run above, so loading c4 is redundant.
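
(If c4 is needed, one possible workaround, assuming the current allenai/c4 layout on the Hub, is to drop the legacy 'allenai--c4' config name in get_c4:)

from datasets import load_dataset

# Possible get_c4 fix, assuming the current allenai/c4 repo layout:
# select the shard via data_files only, without the removed 'allenai--c4' config.
traindata = load_dataset(
    'allenai/c4',
    data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
    split='train'
)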

No further Qs from me, probably @henan991201 can close out this issue.

SuperBruceJia commented:

@TheBloke @henan991201 @RonanKMcGovern @Qubitium @Sciumo
It occurred to me that you are using the test split of wikitext-2-raw-v1. Why not use the train split?
Which dataset should we use for GPTQ quantization?

Thank you very much in advance!

Best regards,

Shuyue
June 26th, 2024

Qubitium (Collaborator) commented:

@SuperBruceJia The examples in the AutoGPTQ repo are for testing only, not suitable for high-quality quants. As for which dataset to use for calibration, that is an open-ended question. You probably want to mix datasets that the model has and has not seen, to maximize quality and prevent memory loss. Good quantization is half magic, since many models do not reveal their training data/sources.
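
As an illustration, a minimal sketch of mixing calibration samples, reusing the two loaders from TheBloke's script above (the 64/64 split is an arbitrary assumption):

# Mix calibration data: half wikitext2, half c4, using get_wikitext2 and get_c4 from the script above.
traindataset = (
    get_wikitext2(64, 0, args.seqlen, tokenizer)
    + get_c4(64, 0, args.seqlen, tokenizer)
)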

SuperBruceJia commented:

@Qubitium @TheBloke
During the packing process, there is an error:
RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate

Are there any solutions to this issue?

Thank you very much in advance!

Best regards,

Shuyue
July 20th, 2024

Qubitium (Collaborator) commented Jul 21, 2024

@SuperBruceJia Packing happens on the CPU, so you need enough physical RAM. Watch your CPU memory usage while it runs, and use vmstat to make sure you are not hitting swap, as that destroys performance.

Please note that just because you have 2GB of free memory doesn't mean you are guaranteed to have enough. Allocations need contiguous blocks of memory, so if no single block large enough can be found, the OS may kill the app or the allocator will report memory errors even though the total free memory looks sufficient.
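
(A minimal sketch for checking available RAM and swap from Python before packing starts; assumes psutil is installed:)

import psutil  # assumption: psutil is available in the environment

# Packing runs on the CPU, so check physical RAM and swap headroom first.
vm = psutil.virtual_memory()
sm = psutil.swap_memory()
print(f"available RAM: {vm.available / 1e9:.1f} GB, swap in use: {sm.used / 1e9:.1f} GB")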

SuperBruceJia commented Jul 21, 2024

@Qubitium @TheBloke
Thank you very much for your answer! I really appreciate it.

The model I am trying to quantize is CohereForAI/c4ai-command-r-plus (104B).
Loading the model and running quantization consume so much memory that very little free memory is left.
Are there any other solutions?

[screenshot attached]

Thank you very much again, and have a nice day!

Best regards,

Shuyue
July 21st, 2024
