In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import textwrap
import time

In [None]:
# Hugging Face Hub id of the checkpoint that both the model and tokenizer load.
model_name: str = "meta-llama/Llama-3.1-8B-Instruct"

In [None]:
# 4-bit quantization settings: NF4 quant type, double quantization enabled,
# and bfloat16 as the compute dtype.
_bnb_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**_bnb_kwargs)

In [None]:
# Map the whole model (the "" root module) onto GPU 0.
device_map = {"": 0}

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
    # Load the modules bitsandbytes leaves unquantized in bfloat16, matching
    # bnb_4bit_compute_dtype in bnb_config. Without an explicit dtype,
    # transformers chooses one for them itself (float16 for bitsandbytes in
    # the versions we know of -- NOTE(review): confirm for the installed
    # version), which would disagree with the bfloat16 compute dtype.
    # Newer transformers releases also accept this as `dtype=`.
    torch_dtype=torch.bfloat16,
)
print(quantized_model)

In [None]:
# Tokenizer for the same checkpoint as the model loaded above.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
def get_outputs(model, inputs, max_new_tokens=200, eos_token_id=None,
                repetition_penalty=1.1):
    """Run ``model.generate`` on an already-tokenized batch.

    Args:
        model: Object exposing ``generate`` (here, the quantized causal LM).
        inputs: Mapping with ``"input_ids"`` and ``"attention_mask"`` entries,
            as produced by calling the tokenizer.
        max_new_tokens: Maximum number of new tokens to generate.
        eos_token_id: Token id (or list of ids) that ends generation. When
            None, falls back to the notebook-global ``tokenizer.eos_token_id``,
            which is what this function always used before.
        repetition_penalty: Forwarded unchanged to ``generate``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    # generate() needs a single pad id; when several eos ids are given, pad
    # with the first one. Passing it explicitly avoids generate() warning and
    # falling back on its own when the tokenizer defines no pad token.
    if isinstance(eos_token_id, (list, tuple)):
        pad_token_id = eos_token_id[0]
    else:
        pad_token_id = eos_token_id
    # `early_stopping` is intentionally not passed: it only affects beam-based
    # generation and num_beams is not set here, so it never ended generation
    # early (recent transformers also warn about it in this setting).
    # NOTE(review): greedy vs. sampling is decided by model.generation_config,
    # not here -- confirm if reproducible outputs are required.
    return model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )

In [None]:
# Prompt used for the generation demo in this cell.
input_text = (
    "Tell a short history of humanity with happy ending."
)
# Example Output function
def example_output_tokens(model, tokenizer, input_text, max_new_tokens=100):
    """Tokenize ``input_text`` and generate a continuation with ``model``.

    Args:
        model: Model with a ``device`` attribute and a ``generate`` method.
        tokenizer: Tokenizer used to encode ``input_text`` as PyTorch tensors.
        input_text: Prompt string to encode.
        max_new_tokens: Generation cap forwarded to ``get_outputs``
            (default 100, as before).

    Returns:
        The token ids returned by ``get_outputs``.
    """
    # Send the encoded batch to wherever the model lives instead of a
    # hard-coded 'cuda', so it also works on CPU or on a non-default GPU.
    encoded_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    return get_outputs(model, encoded_inputs, max_new_tokens=max_new_tokens)

def example_output_text(tokenizer, tokens):
    """Decode a batch of token-id sequences back into strings.

    Special tokens are dropped from the decoded text.
    """
    decoded_texts = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    return decoded_texts

# perf_counter is the stdlib clock intended for measuring elapsed time;
# time.time() follows the system wall clock and can jump if it is adjusted.
start = time.perf_counter()
tokens = example_output_tokens(quantized_model, tokenizer, input_text)
elapsed = time.perf_counter() - start
print(f"Time taken to generate tokens: {elapsed:.2f} s")
text = example_output_text(tokenizer, tokens)
print(tokens)

In [None]:
def beautify_text(text, width=80):
    """Print each decoded output, wrapped to ``width`` columns.

    Args:
        text: Iterable of strings (e.g. the list returned by
            ``batch_decode``). A single string is treated as one output
            instead of being iterated character by character.
        width: Maximum line width passed to ``textwrap.fill``.

    Returns:
        None -- the formatted outputs are printed, not returned.
    """
    if isinstance(text, str):
        text = [text]
    print("Generated Output:\n")
    for index, sentence in enumerate(text, start=1):
        print(f"Output {index}:\n{textwrap.fill(sentence, width=width)}\n")

# beautify_text prints its own output and returns None; wrapping it in
# print() would emit an extra "None" line.
beautify_text(text)

In [None]:
def get_model_memory_usage(model, include_buffers=True):
  """Return the number of bytes held by a PyTorch model's tensors.

  Sums ``element_size() * numel()`` over every parameter and, unless
  ``include_buffers`` is False, every registered buffer.

  Args:
    model: A ``torch.nn.Module``.
    include_buffers: Also count buffers (e.g. normalization running
      statistics), which ``model.parameters()`` does not yield.

  Returns:
    Total size in bytes as an int (0 for a module with no tensors).
  """
  tensors = list(model.parameters())
  if include_buffers:
    tensors.extend(model.buffers())
  # NOTE(review): for bitsandbytes-quantized weights this presumably counts
  # only the packed parameter storage; any quantization state kept outside
  # parameters/buffers is not included -- confirm.
  return sum(t.element_size() * t.numel() for t in tensors)

# Calculate memory usage for the quantized model.
quantized_model_memory = get_model_memory_usage(quantized_model)
# Dividing by 1024**2 yields mebibytes, so label the figure MiB, not MB.
print(f"Quantized model memory usage: {quantized_model_memory / 1024**2:.2f} MiB")
# Cross-check against transformers' own estimate of the model's footprint.
print(
    "transformers get_memory_footprint(): "
    f"{quantized_model.get_memory_footprint() / 1024**2:.2f} MiB"
)
