In [47]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available 

# Optional 4-bit quantization via bitsandbytes (needs a CUDA GPU and the
# `bitsandbytes` package). Off by default, i.e. a plain bfloat16 load.
USE_QUANTIZATION = False
quantization_config = (
    BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    if USE_QUANTIZATION
    else None
)

# Which attention implementation to use: Flash Attention 2 only when the
# flash-attn package is available and the GPU is Ampere or newer
# (compute capability >= 8.0). Otherwise pass None so transformers keeps its
# own default attention implementation.
if is_flash_attn_2_available() and torch.cuda.get_device_capability(0)[0] >= 8:
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = None
print(f"attn_implementation: {attn_implementation or 'transformers default'}")

model_id = 'google/gemma-2b-it'
print(f"model_id: {model_id}")

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id,
                                                 torch_dtype=torch.bfloat16,
                                                 quantization_config=quantization_config,
                                                 attn_implementation=attn_implementation)

# Flash Attention 2 kernels only run on CUDA, and without a device_map the
# weights are loaded onto the CPU, so move them. Quantized (bitsandbytes) models
# are skipped: moving them with .to() is unsupported in many transformers versions.
if attn_implementation == "flash_attention_2" and not USE_QUANTIZATION:
    llm_model.to("cuda")

model_id: google/gemma-2b-it


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [48]:
# Display the module tree of the loaded model (embedding, decoder layers, norms).
llm_model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNo

In [67]:
def get_model_num_params(model: torch.nn.Module, trainable_only: bool = False) -> int:
    """Count the scalar parameters of a PyTorch module.

    Args:
        model: Module whose parameters (as yielded by ``model.parameters()``) are counted.
        trainable_only: If True, count only parameters with ``requires_grad=True``.
            Defaults to False, which counts every parameter.

    Returns:
        Total number of elements across the selected parameter tensors.
    """
    # Generator expression: no need to build an intermediate list just to sum it.
    return sum(param.numel() for param in model.parameters()
               if not trainable_only or param.requires_grad)

# Total parameter count of the loaded LLM.
get_model_num_params(model=llm_model)

2506172416

In [68]:
def get_model_mem_size(model: torch.nn.Module) -> dict:
    """Compute the in-memory size of a module's parameters and buffers.

    Args:
        model: Module to measure. Both ``model.parameters()`` and
            ``model.buffers()`` are included.

    Returns:
        Dict with:
          - "model_mem_bytes": exact byte count (int) of parameters plus buffers.
          - "model_mem_mb": bytes / 1024**2 rounded to 2 decimals (i.e. MiB).
          - "model_mem_gb": bytes / 1024**3 rounded to 2 decimals (i.e. GiB).
    """
    def _nbytes(tensors) -> int:
        # Sum of element-count * bytes-per-element over an iterable of tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    model_mem_bytes = _nbytes(model.parameters()) + _nbytes(model.buffers())

    return {"model_mem_bytes": model_mem_bytes,
            "model_mem_mb": round(model_mem_bytes / (1024**2), 2),
            "model_mem_gb": round(model_mem_bytes / (1024**3), 2)}

# Memory footprint (parameters + buffers) of the loaded LLM.
get_model_mem_size(model=llm_model)

{'model_mem_bytes': 5012354048, 'model_mem_mb': 4780.15, 'model_mem_gb': 4.67}

In [69]:
input_text = "explique moi le schéma relationnel?"
print(f"Input text:\n{input_text}")

# The model is instruction-tuned, so wrap the question in a single-turn
# chat conversation for the tokenizer's chat template.
dialogue_template = [dict(role="user", content=input_text)]

# Render the conversation to a plain string (tokenize=False) and append the
# model-turn header (add_generation_prompt=True) so generation continues as the
# model's reply. Note the rendered string already starts with the <bos> token.
prompt = tokenizer.apply_chat_template(
    conversation=dialogue_template,
    tokenize=False,
    add_generation_prompt=True,
)
print(f"\nPrompt (formatted):\n{prompt}")

Input text:
explique moi le schéma relationnel?

Prompt (formatted):
<bos><start_of_turn>user
explique moi le schéma relationnel?<end_of_turn>
<start_of_turn>model



In [70]:
%%time

inputs = tokenizer.encode(prompt, 
                      add_special_tokens= True,
                      return_tensors="pt" )

outputs = llm_model.generate(input_ids=inputs.to(llm_model.device), max_new_tokens=150)

CPU times: total: 2min 21s
Wall time: 1min 23s


In [71]:
# Token ids of the first (only) sequence returned by generate(): the prompt
# tokens followed by the newly generated tokens.
outputs[0]

tensor([     2,      2,    106,   1645,    108, 214032,  19653,    709, 190251,
         10189,   2408, 235336,    107,    108,    106,   2516,    108,    688,
          7758,  36511,  10189,   2408,    688,    109,   2331, 190251,  10189,
          2408,   1455,   2360, 142772, 108088,   2459,  52324,   1437,   4106,
          3692,   1437,  48559,    499, 235303,    549,  26889, 235265,   4626,
         40357,   1437,  48559,    581,    533, 235303,  34853,   2499,    848,
        205633,   1008,   1437,   4106,   3692,  34092,   2499,    848,  70244,
        235265,    109,    688,  16053,    499, 235303,    549, 190251,  10189,
          2408,    865,    688,    109, 235287,   5231,   1836,  22181,    865,
           688,   5221,  48559,    581,    533, 235303,  34853, 235265,    108,
        235287,   5231,   5479,  41955,    865,    688,   2481,   4106,   3692,
          1437,  66262, 235265,    108, 235287,   5231,   4965, 123843,    865,
           688,   2481, 225592,   2459, 

In [72]:
# Convert the generated token ids back into readable text. The result still
# contains the prompt and the special tokens (e.g. <bos>, <start_of_turn>),
# since decode() is called on the full sequence without skipping specials.
output_decoded = tokenizer.decode(token_ids=outputs[0])
output_decoded

"<bos><bos><start_of_turn>user\nexplique moi le schéma relationnel?<end_of_turn>\n<start_of_turn>model\n**Schéma relationnel**\n\nUn schéma relationnel est une représentation graphique qui montre les relations entre les éléments d'un ensemble. Il présente les éléments de l'ensemble dans des boîtes et les relations entre eux dans des lignes.\n\n**Components d'un schéma relationnel :**\n\n* **Objets :** Les éléments de l'ensemble.\n* **Relationen :** Des relations entre les objets.\n* **Connecteurs :** Des symboles qui représentent les relations.\n* **Boîtes :** Des boîtes qui représentent les objets.\n\n**Types de relations :**\n\n* **Relation d'inclusion :** Si A est inclus dans B, cela signifie que A fait partie de B.\n* **Relation"

In [57]:
# Output directory, relative to the notebook's working directory.
model_path = "model"

# save_pretrained on the model writes only weights + config. Saving the tokenizer
# into the same directory makes it self-contained, so both
# AutoTokenizer.from_pretrained(model_path) and
# AutoModelForCausalLM.from_pretrained(model_path) can load from it.
llm_model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)