In [None]:
# import all necessary libraries

from transformers import AutoTokenizer, AutoConfig
from transformers import AutoModelForCausalLM
from safetensors.torch import load_file
import torch
import os

In [2]:
# Name parts of the local checkpoint directory. The loading cell concatenates
# them as "./" + MODEL_DIR_PREFIX + MODEL_SIZE + MODEL_DIR_SUFFIX.
MODEL_DIR_PREFIX = "gemma-2-2b-it"
# Alternative split-name values, kept commented out for reference:
# MODEL_SIZE = "2b"
# MODEL_DIR_SUFFIX = "-it"
# With both parts empty the resulting path is simply "./gemma-2-2b-it".
MODEL_SIZE = ""
MODEL_DIR_SUFFIX = ""

In [None]:
# Pick the torch backend to run on, in order of preference:
# CUDA GPU first, then Apple MPS, and CPU as the fallback.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Report the selected backend (upper-cased device type: CUDA / MPS / CPU).
message: str = "[+] Running Torch over " + device.type.upper()
print(message)

In [4]:
# Load essentials for the model: config, tokenizer and the raw weight tensors.
model_path = "./" + MODEL_DIR_PREFIX + MODEL_SIZE + MODEL_DIR_SUFFIX

config = AutoConfig.from_pretrained(model_path + "/config.json")

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Collect the weight shard files ("model-*.safetensors") in the checkpoint directory.
# - startswith/endswith instead of substring tests, so leftovers such as
#   "model-00001-of-00002.safetensors.tmp" or "*.bak" are not loaded as weights;
#   the singular ".safetensor" suffix is still accepted because the original
#   filter matched it as well.
# - sorted() because os.listdir returns entries in arbitrary order; sorting makes
#   the merge order below deterministic.
file_paths:list = sorted(
    x for x in os.listdir(model_path)
    if x.startswith('model-') and x.endswith(('.safetensors', '.safetensor'))
)

# Fail fast: an empty state_dict would otherwise be handed to from_pretrained
# and the problem would only surface later as a model with uninitialised weights.
if not file_paths:
    raise FileNotFoundError(
        "No 'model-*.safetensors' weight shards found in " + model_path
    )

# Load each shard and merge its tensors into a single state_dict.
state_dict = {}
for file_path in file_paths:
    part_state_dict = load_file(os.path.join(model_path, file_path))
    state_dict.update(part_state_dict)

In [None]:
# Build the causal-LM from the already-loaded config and weight tensors
# instead of letting from_pretrained read a checkpoint itself.
model = AutoModelForCausalLM.from_pretrained(
    None,  # no checkpoint path / hub id: the weights come from state_dict below
    torch_dtype=torch.bfloat16,
    state_dict=state_dict,
    config=config,
)

# Move the model to the selected device. Left as the last expression of the
# cell, so the notebook also displays the returned module.
model.to(device)

In [None]:
# Demo prompt passed to generate_text below: a short list of style
# requirements, one user question, and a trailing "Chatbot:" line that
# presumably cues the model to write the reply.
# Note: the literal mixes real line breaks with "\n" escapes, so the resulting
# string contains consecutive newlines between entries.
input_text = """
\nRequirment:\n
 - Use **Chatting Tone** to chat.\n
 - Provide **Brief** response.\n
 
\nUser: What can you do?.
\nChatbot: 
"""

def generate_text(text:str):
    """Sample a completion for ``text`` and return it as a decoded string.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: Prompt to tokenize and feed to ``model.generate``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
        NOTE(review): for a causal LM this sequence presumably still begins
        with the prompt tokens - confirm if only the reply is wanted.
    """
    # Tokenize to PyTorch tensors and move them to the device the model runs on.
    encoded_prompt = tokenizer(text, return_tensors="pt").to(device)

    end_of_sequence = tokenizer.eos_token_id
    generation_options = {
        "max_new_tokens": 150,            # upper bound on newly generated tokens; adjust as needed
        "eos_token_id": end_of_sequence,  # stop once the EOS token is produced
        "pad_token_id": end_of_sequence,  # reuse EOS as the padding token
        # add randomness to the output
        "do_sample": True,
        "temperature": 0.7,
    }

    generated = model.generate(**encoded_prompt, **generation_options)

    # Only the first sequence of the returned batch is decoded.
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Run one sampled generation for the demo prompt and print the decoded text.
# Sampling is enabled in generate_text and no seed is set, so the output
# may differ from run to run.
print(generate_text(input_text))