In [None]:
## ref : https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/optimize-llama-2-gptq.ipynb 

In [None]:
!sudo pip install -q transformers --upgrade

In [None]:
!sudo -H pip install auto-gptq --no-cache-dir

In [None]:
!sudo -H pip install --upgrade optimum

In [None]:
import os

# Configure PyTorch's CUDA caching allocator before torch is imported and
# before any GPU memory is allocated: the setting is only picked up if it is
# in the environment when the allocator initializes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch

# Hand back any cached-but-unused GPU memory held by this process.
torch.cuda.empty_cache()

In [None]:
import gc
# Force a garbage-collection pass. As the last expression of the cell, the
# value returned by gc.collect() is what the cell displays.
gc.collect()

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer, load_quantized_model
import torch

In [None]:
# Model identifier passed to the `from_pretrained` calls that load the
# tokenizer and the base model.
model_name = "mistralai/Mistral-7B-v0.1"

In [None]:
# Tokenizer for `model_name`; reused later by the quantizer and by the
# inference helper.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Load the base (unquantized) model with float16 weights; device placement is
# delegated to `device_map='auto'`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
)

In [None]:
import os

# Export the CUDA allocator setting for this process.
# NOTE(review): the base model has already been loaded with device_map='auto'
# by the time this cell runs. Presumably PyTorch reads this variable when its
# CUDA allocator first initializes, so setting it here may have no effect in
# the current kernel — confirm, and consider moving it above the first GPU use.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"})

In [None]:
# Build the GPTQ quantizer: 4-bit, with "wikitext2" as its dataset argument
# (presumably the calibration data — confirm against the optimum docs).
quantizer = GPTQQuantizer(
    bits=4,
    dataset="wikitext2",
)
# NOTE(review): the attribute is set by hand after construction; presumably a
# workaround for an optimum version that does not set it itself — TODO confirm
# whether this is still needed.
quantizer.quant_method = "gptq"

In [None]:
# Quantize `model` with the quantizer configured above; the tokenizer is handed
# to the quantizer as well.
# NOTE(review): likely the most expensive cell in the notebook — consider
# timing it with %%time.
quantized_model = quantizer.quantize_model(model, tokenizer)

In [None]:
# Show the allocator setting currently present in this process's environment
# (raises KeyError if the variable has not been set).
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])

In [None]:
# Directory the quantized model is written to by `save_pretrained`.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
quant_path = "/data/quantization-trials/GPTQ-quantized/ravi"

In [None]:
# Save the quantized model to `quant_path` with safe serialization enabled.

quantized_model.save_pretrained(quant_path, safe_serialization=True)

### Inference on quantized model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

In [None]:
# Load-time GPTQ options: 4-bit weights with the exllama kernels enabled.
gptq_config = GPTQConfig(bits=4, use_exllama=True)

# Load from the same directory that `save_pretrained` wrote to, so the
# checkpoint produced by this notebook is the one being evaluated.
model_id = quant_path

# Pass the config explicitly; if it is not handed to `from_pretrained`, the
# `use_exllama=True` request is never applied.
quant_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=gptq_config,
    device_map="cuda",
    torch_dtype=torch.float16,
)

In [None]:
def predict_from_quant(user_query, max_length=1000):
    """Generate text for ``user_query`` with the quantized model.

    Args:
        user_query: Prompt text; tokenized with the notebook-level ``tokenizer``.
        max_length: Value forwarded to ``generate`` as ``max_length`` (an upper
            bound on prompt plus generated tokens). Defaults to 1000.

    Returns:
        The first generated sequence decoded with ``tokenizer.decode``; the
        prompt and any special tokens are part of the returned string.
    """
    # Call the tokenizer directly (instead of .encode) so the attention mask is
    # produced as well, and move the tensors to the device the model is on
    # instead of assuming 'cuda'.
    _inputs = tokenizer(user_query, return_tensors="pt").to(quant_model.device)
    outputs = quant_model.generate(
        input_ids=_inputs["input_ids"],
        # The pad token is set to EOS below, so the mask cannot be inferred
        # from the ids; pass it explicitly.
        attention_mask=_inputs["attention_mask"],
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0])

import time

In [None]:
# Using quant model
# Time a single generation call. perf_counter() is a monotonic,
# high-resolution clock, so the measured duration cannot be skewed by
# system clock adjustments.
start = time.perf_counter()
output1 = predict_from_quant("what is science")
print("time taken is :", time.perf_counter() - start)