In [1]:
# # Code to for 4-bit model generation

# from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# import torch
# import time

# import bitsandbytes as bnb
# print(bnb.__version__)

# model_id = "mistralai/Mistral-7B-v0.1"

# # 4-bit quantization config
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True
# )

# # Load tokenizer + model from Hugging Face in 4-bit
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     quantization_config=bnb_config,
#     device_map="cuda",
#     torch_dtype=torch.float16
# )

# # Test prompt
# prompt = "Explain quantum computing in simple terms."
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# # Generate output
# with torch.no_grad():
#     output_ids = model.generate(
#         **inputs,
#         max_new_tokens=100,
#         do_sample=True,
#         temperature=0.7,
#         top_p=0.9
#     )

# print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

In [2]:
# Saving to the model

# save_path = "./mistral_7b_4bit_local"
# model.save_pretrained(save_path)
# tokenizer.save_pretrained(save_path)

In [3]:
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Load model from local directory
local_model_path = "./mistral_7b_4bit_local"

tokenizer = AutoTokenizer.from_pretrained(local_model_path)
model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    quantization_config=bnb_config,
    device_map="cuda",
    torch_dtype=torch.float16
)

model.eval()

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )

In [4]:
# Prompt
prompt = "What is reinforcement learning?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# ------------------------------
# Measure Prefill + Decode
# ------------------------------
max_new_tokens = 100

# 1️⃣ Prefill: forward pass on input prompt (no generation)
torch.cuda.synchronize()
start_prefill = time.time()

with torch.no_grad():
    _ = model(**inputs, use_cache=True)

torch.cuda.synchronize()
end_prefill = time.time()
prefill_time = end_prefill - start_prefill

# 2️⃣ Decode: generate tokens and measure TTFT and TBT
generated = inputs["input_ids"]
past_key_values = None
new_tokens = []

torch.cuda.synchronize()
start_decode = time.time()
first_token_time = None

with torch.no_grad():
    for i in range(max_new_tokens):
        # Run one forward step
        outputs = model(input_ids=generated[:, -1:], past_key_values=past_key_values, use_cache=True)
        logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        # Sample next token
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        new_tokens.append(next_token.item())

        # Measure TTFT and per-token latency
        torch.cuda.synchronize()
        now = time.time()

        if i == 0:
            ttft = now - start_decode  # Time to first token
            last_token_time = now
        else:
            tbt = now - last_token_time
            last_token_time = now
            print(f"Token {i+1}: {tokenizer.decode(next_token[0])} | Δt = {tbt:.4f}s")

torch.cuda.synchronize()
end_decode = time.time()

decode_time = end_decode - start_decode
avg_tbt = (decode_time - ttft) / (len(new_tokens) - 1) if len(new_tokens) > 1 else 0

# ------------------------------
# Results
# ------------------------------
print("\n===== PERFORMANCE METRICS =====")
print(f"Prefill time: {prefill_time:.4f} s")
print(f"Total decode time: {decode_time:.4f} s")
print(f"Time to first token (TTFT): {ttft:.4f} s")
print(f"Average time between tokens (TBT): {avg_tbt:.4f} s")
print(f"Generated text:\n{tokenizer.decode(generated[0], skip_special_tokens=True)}")

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Token 2: HECK | Δt = 0.0625s
Token 3: Y | Δt = 0.0620s
Token 4: ES | Δt = 0.0608s
Token 5: ! | Δt = 0.0603s
Token 6: 
 | Δt = 0.0615s
Token 7: 
 | Δt = 0.0604s
Token 8: I | Δt = 0.0581s
Token 9: ’ | Δt = 0.0582s
Token 10: m | Δt = 0.0573s
Token 11: not | Δt = 0.0568s
Token 12: sure | Δt = 0.0566s
Token 13: if | Δt = 0.0580s
Token 14: I | Δt = 0.0586s
Token 15: ’ | Δt = 0.0582s
Token 16: m | Δt = 0.0581s
Token 17: more | Δt = 0.0581s
Token 18: excited | Δt = 0.0585s
Token 19: about | Δt = 0.0583s
Token 20: the | Δt = 0.0585s
Token 21: fact | Δt = 0.0567s
Token 22: that | Δt = 0.0573s
Token 23: I | Δt = 0.0571s
Token 24: ’ | Δt = 0.0566s
Token 25: m | Δt = 0.0584s
Token 26: going | Δt = 0.0583s
Token 27: to | Δt = 0.0576s
Token 28: be | Δt = 0.0565s
Token 29: able | Δt = 0.0640s
Token 30: to | Δt = 0.0607s
Token 31: see | Δt = 0.0569s
Token 32: the | Δt = 0.0571s
Token 33: new | Δt = 0.0564s
Token 34: Star | Δt = 0.0570s
Token 35: Wars | Δt = 0.0567s
Token 36: movie | Δt = 0.0565s
Token 