# In [None]:
import textwrap
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from langchain import HuggingFaceHub
from langchain.embeddings import HuggingFaceEmbeddings

def print_response(response: str):
    """Print ``response`` word-wrapped so that no line exceeds 100 characters."""
    wrapped_lines = textwrap.wrap(response, width=100)
    print("\n".join(wrapped_lines))

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Special token IDs, pinned explicitly instead of trusting the checkpoint's
# tokenizer config; the tokenizer is realigned to these IDs below.
BOS_TOKEN_ID = 1  # Beginning-of-sequence token ID
EOS_TOKEN_ID = 2  # End-of-sequence token ID

MAX_TOKENS = 1024  # Maximum number of tokens for generation

# Local directory holding the Hugging Face-format weights of the model.
MODEL_NAME = "/content/open_llama_7b_preview_300bt/open_llama_7b_preview_300bt_transformers_weights"

# Initialize the tokenizer. `add_special_tokens` is an option of `encode()` /
# `__call__()`, not of the constructor, so it is not passed here.
tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)

# fp32 weights for a 7B-parameter model need roughly 28 GB, so load in half
# precision on the GPU. Many CPU kernels lack fp16 support, so keep fp32 there.
MODEL_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Initialize the model. `from_pretrained` has no `device` argument (unknown
# kwargs are forwarded to the model constructor, which rejects them), so the
# model is moved to the target device explicitly after loading.
model = LlamaForCausalLM.from_pretrained(
    MODEL_NAME, local_files_only=True, torch_dtype=MODEL_DTYPE
).to(DEVICE)

# Align the tokenizer's special tokens with the pinned IDs. Both attributes
# must hold token *strings*; the tokenizer derives `bos_token_id` /
# `eos_token_id` from them.
tokenizer.bos_token = tokenizer.convert_ids_to_tokens(BOS_TOKEN_ID)
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(EOS_TOKEN_ID)

def generate_response(query):
    """Generate the model's reply to ``query``.

    Args:
        query: Raw user text used as the prompt.

    Returns:
        Only the text generated after the prompt (the prompt itself is not
        echoed back), decoded with special tokens removed.
    """
    # Prepend BOS from the pinned ID rather than relying on the tokenizer
    # config; the query itself is encoded without special tokens. This also
    # keeps an empty query from producing an empty input tensor.
    prompt_ids = [BOS_TOKEN_ID] + tokenizer.encode(query, add_special_tokens=False)
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    # A single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output_token_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # `max_length` would count the prompt as well; bound only new tokens.
            max_new_tokens=MAX_TOKENS,
            eos_token_id=EOS_TOKEN_ID,
            # Reuse EOS as the pad token so generation does not need a pad token
            # from the tokenizer config.
            pad_token_id=EOS_TOKEN_ID,
            num_return_sequences=1,
        )

    # generate() returns prompt + continuation for decoder-only models; drop
    # the prompt tokens so the reply does not repeat the user's input.
    new_token_ids = output_token_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

# Main chat loop: read a line, stop on "quit"/"exit" (case- and
# whitespace-insensitive) or end of input, otherwise print the model's reply.
# The guard keeps the interactive loop from blocking when the module is imported.
if __name__ == "__main__":
    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session cleanly instead of raising.
            print()
            break

        user_input = user_input.strip()
        if user_input.lower() in ("quit", "exit"):
            break
        if not user_input:
            # Nothing to answer; prompt again rather than generate from nothing.
            continue

        # Generate response based on user input
        response = generate_response(user_input)
        print_response("ChatBot: " + response)
