In [1]:
# Import the necessary libraries
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [2]:
# Check CUDA is working
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

2.9.0+cu130
13.0
True


In [3]:
# Config for loading the model in 4 bits
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, # original is 32 bit
    bnb_4bit_quant_type="nf4", # gaussian distribution
    bnb_4bit_use_double_quant=True, # 32 -> 8 -> 4 bits
    bnb_4bit_compute_dtype=torch.float16 # compute in float16
)

In [4]:
# Load the model with our config
model_id = "mistralai/Mistral-7B-v0.3"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto", # use CUDA if available
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Check the model is loaded in 4bit
for name, module in model.named_modules():
    if "Linear" in str(type(module)) or "4bit" in str(type(module)):
        print(f"{name} -> {type(module)}")

model.layers.0.self_attn.q_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.0.self_attn.k_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.0.self_attn.v_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.0.self_attn.o_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.0.mlp.gate_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.0.mlp.up_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.0.mlp.down_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.1.self_attn.q_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.1.self_attn.k_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.1.self_attn.v_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.1.self_attn.o_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.1.mlp.gate_proj -> <class 'bitsandbytes.nn.modules.Linear4bit'>
model.layers.1.mlp.up_proj -> <class 'bitsandbytes.nn.

In [None]:
# Setup for the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token # set padding as EOS
print(f"Vocabulary Size: {len(tokenizer)}")

Vocabulary Size: 32768


In [7]:
# Experimenting with tokenizer
sentence = "What's the craic?"
tokens = tokenizer(sentence)
print(f"Input IDs: {tokens["input_ids"]}")
print(f"Tokens (Encoded): {tokenizer.convert_ids_to_tokens(tokens["input_ids"])}")
print(f"Original (Decoded): {tokenizer.decode(tokens["input_ids"])}")

Input IDs: [1, 2592, 29510, 29481, 1040, 1045, 1288, 1062, 29572]
Tokens (Encoded): ['<s>', '▁What', "'", 's', '▁the', '▁c', 'ra', 'ic', '?']
Original (Decoded): <s> What's the craic?


In [25]:
# Experimenting with batching
sentences = ["Sound lad", "That's grand", "Ye eejit"]
batch = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    return_tensors="pt"
)
print(f"Input IDs: {batch["input_ids"]}")
print(f"Tokens (Encoded):")
for ids in batch["input_ids"]:
    print(f"{tokenizer.convert_ids_to_tokens(tokens["input_ids"])}")
    print(f"Original (Decoded): {tokenizer.decode(tokens["input_ids"])}")

Input IDs: tensor([[    2,     2,     1, 15579, 10544],
        [    1,  2493, 29510, 29481,  4255],
        [    1, 23470,  1085, 16576,  1047]])
Tokens (Encoded):
['<s>', '▁What', "'", 's', '▁the', '▁c', 'ra', 'ic', '?']
Original (Decoded): <s> What's the craic?
['<s>', '▁What', "'", 's', '▁the', '▁c', 'ra', 'ic', '?']
Original (Decoded): <s> What's the craic?
['<s>', '▁What', "'", 's', '▁the', '▁c', 'ra', 'ic', '?']
Original (Decoded): <s> What's the craic?


In [34]:
# Test the quantized model before fine-tuning
prompt_1 = "What's the craic?"
prompt_2 = "What's the story lad?"
prompts = [prompt_1, prompt_2]
inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model.generate(
        **inputs, 
        max_new_tokens=128, 
        do_sample=True,
    )
for i, output in enumerate(outputs):
    response = tokenizer.decode(output, skip_special_tokens=True)
    print(response)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


What's the craic?�����0�''½�6‚5´850��573�677141�650´5���5�6¡�6�3��'♡�5���'�7�����6‒�76���‰6�6��´��∞���5�86₂��1������1���3��´‚�′�7��‒2´4‍�ʹ4�24′�25
What's the story lad?�6�34�6�5����42��94�―8��4�6�8���6����ʼ��2����´1′�7���5��5�4¡�3�087′�¾��¾��47�‏�½´2��81�7���´�6‒��´´����8����0�‍�′½8��54¸�400̀7¾3
