In [5]:
import os
# Configure the CUDA / NCCL environment for this process in one step.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",  # restrict the process to GPU index 0
        "NCCL_IB_DISABLE": "1",       # NCCL toggle: InfiniBand transport off
        "NCCL_P2P_DISABLE": "1",      # NCCL toggle: peer-to-peer transport off
    }
)

In [6]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import bitsandbytes as bnb

In [7]:
# Checkpoint to quantize; shared by the tokenizer and model loaders below.
MODEL_ID = "google/gemma-2-2b"

# Set up the 4-bit (NF4) quantization configuration.
# The compute dtype is pinned to bfloat16 so it matches the torch_dtype the
# model is loaded with; BitsAndBytesConfig otherwise defaults
# bnb_4bit_compute_dtype to float32.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load tokenizer (token=True reuses the locally stored Hugging Face credentials).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=True)

# Load the model with the quantization configuration applied at load time.
model_q = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    token=True,
    quantization_config=quantization_config,
    device_map="auto",  # enables automatic device mapping
)

# NOTE(review): this only tags each Linear4bit module with a plain Python
# attribute. The weights were already quantized during from_pretrained, and
# BitsAndBytesConfig has no group-size option, so this does NOT change how the
# weights are quantized. It is kept purely as an annotation on the modules.
group_size = 128
for module in model_q.modules():
    if isinstance(module, bnb.nn.Linear4bit):
        module.group_size = group_size

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [9]:
# Test the model with a simple input
def generate_text(prompt, max_length=50):
    """Generate a continuation of ``prompt`` with the quantized model.

    Args:
        prompt: Text that is tokenized and fed to ``model_q``.
        max_length: Forwarded unchanged to ``model_q.generate`` as ``max_length``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Place the inputs on the device the model actually lives on, rather than
    # depending on a `device` global defined in a different cell.
    inputs = tokenizer(prompt, return_tensors="pt").to(model_q.device)
    # Forward the whole encoding so the attention mask accompanies input_ids.
    outputs = model_q.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Smoke test: send one short prompt through the model and show the completion.
prompt = "Who are you?"
generated_text = generate_text(prompt=prompt)
print(generated_text)

Who are you?

I am a 20 year old student at the University of California, Berkeley. I am currently studying Computer Science and Economics. I am also a member of the Berkeley College Republicans.

What are your goals for the


In [10]:
# Persist the quantized model and its tokenizer side by side in one directory.
_save_dir = "gemma-2-2b-nf4"
model_q.save_pretrained(_save_dir)
# Kept as the final expression so the cell still displays the written file paths.
tokenizer.save_pretrained(_save_dir)

('gemma-2-2b-nf4/tokenizer_config.json',
 'gemma-2-2b-nf4/special_tokens_map.json',
 'gemma-2-2b-nf4/tokenizer.model',
 'gemma-2-2b-nf4/added_tokens.json',
 'gemma-2-2b-nf4/tokenizer.json')