In [None]:
# Required packages; install them if they are missing (assumes PyTorch* and Intel® Extension for PyTorch* are already present)
#import sys
#!conda install -y --quiet --prefix {sys.prefix}  -c conda-forge \
#    accelerate==0.23.0 \
#    validators==0.22.0 \
#    transformers==4.32.1 \
#    sentencepiece \
#    pillow \
#    ipywidgets \
#    ipython > /dev/null && echo "Installation successful" || echo "Installation failed"

import sys
import site
from pathlib import Path

# Upgrade transformers using pip from the interpreter that runs this kernel.
# pip's stdout is discarded; a status line is echoed based on pip's exit code.
!echo "Installation in progress..."
!{sys.executable} -m pip install -U transformers==4.35.2  --no-warn-script-location > /dev/null && echo "Installation successful" || echo "Installation failed"

def get_python_version():
    """Return the interpreter's version directory name, e.g. ``"python3.10"``."""
    return "python" + ".".join(map(str, sys.version_info[:2]))


def set_local_bin_path():
    """
    Put the current user's ``~/.local`` install locations on ``sys.path``.

    After the call:
      * ``~/.local/lib/pythonX.Y/site-packages`` is the first ``sys.path`` entry,
      * ``site.getusersitepackages()`` follows it (or is the same entry when
        both resolve to the same directory),
      * ``~/.local/bin`` is present once, at the end if it had to be added.

    The function is idempotent: calling it repeatedly (e.g. re-running the
    notebook cell) does not add duplicate entries, and it no longer raises
    ``ValueError`` when ``~/.local/lib/pythonX.Y/site-packages`` was not
    already on ``sys.path``.
    """
    local_bin = str(Path.home() / ".local" / "bin")
    local_site_packages = str(
        Path.home() / ".local" / "lib" / get_python_version() / "site-packages"
    )

    if local_bin not in sys.path:
        sys.path.append(local_bin)

    # Insert in reverse priority order so the last one processed ends up first.
    for entry in (site.getusersitepackages(), local_site_packages):
        # Remove every existing occurrence before re-inserting at the front;
        # this keeps a single copy and avoids the index() lookup that failed
        # when the entry was absent.
        while entry in sys.path:
            sys.path.remove(entry)
        sys.path.insert(0, entry)

# Apply the sys.path adjustment as soon as this cell runs.
set_local_bin_path()

In [None]:
import logging
import os
import random
import re

os.environ["SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS"] = "1"
os.environ["ENABLE_SDP_FUSION"] = "1"
import warnings

# Suppress warnings for a cleaner output
warnings.filterwarnings("ignore")

import torch
import intel_extension_for_pytorch as ipex

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import BertTokenizer, BertForSequenceClassification

from ipywidgets import VBox, HBox, Button, Dropdown, IntSlider, FloatSlider, Text, Output, Label, Layout
import ipywidgets as widgets
from ipywidgets import HTML


# Fix the random seeds. Note that nothing is seeded (not even Python's
# `random` module) unless an XPU device is available.
if torch.xpu.is_available():
    seed = 88
    random.seed(seed)
    torch.xpu.manual_seed(seed)
    torch.xpu.manual_seed_all(seed)

def select_device(preferred_device=None):
    """
    Pick the torch device to run on, honouring an optional preference.

    Resolution order:
      1. A preference starting with "cpu" returns the CPU device.
      2. A preference of the form "xpu:N" with N below the XPU device count
         returns exactly that device.
      3. Otherwise (no preference, bare "xpu", or an out-of-range index) a
         random XPU device is chosen when XPU is available.
      4. If none of the above applies, or any error occurs along the way,
         the CPU device is returned.

    Args:
        preferred_device (str, optional): Device string such as "cpu", "xpu",
            "xpu:0" or "xpu:1". Defaults to None.

    Returns:
        torch.device: The selected device object.
    """
    try:
        requested = preferred_device or ""

        if requested.startswith("cpu"):
            print("Using CPU.")
            return torch.device("cpu")

        # A bare "xpu" carries no index, so it is treated like "no preference".
        if requested.startswith("xpu") and requested != "xpu":
            index_out_of_range = (
                ":" in requested
                and int(requested.split(":")[1]) >= torch.xpu.device_count()
            )
            if not index_out_of_range:
                candidate = torch.device(requested)
                if (
                    candidate.type == "xpu"
                    and candidate.index < torch.xpu.device_count()
                ):
                    used_mb = torch.xpu.memory_allocated(candidate) / (1024**2)
                    print(
                        f"Using preferred device: {candidate}, VRAM used: {used_mb:.2f} MB"
                    )
                    return candidate

        if torch.xpu.is_available():
            # Select a random available XPU device.
            chosen_index = random.choice(range(torch.xpu.device_count()))
            chosen = torch.device(f"xpu:{chosen_index}")
            used_mb = torch.xpu.memory_allocated(chosen) / (1024**2)
            print(f"Selected device: {chosen}, VRAM used: {used_mb:.2f} MB")
            return chosen
    except Exception as e:
        print(f"An error occurred while selecting the device: {e}")

    print("No XPU devices available or preferred device not found. Using CPU.")
    return torch.device("cpu")


In [None]:
# Directory searched first for pre-downloaded models. ChatBotModel looks for a
# sub-directory named after the model id with "/" replaced by "--".
MODEL_CACHE_PATH = "/home/common/data/Big_Data/GenAI/llm_models"
class ChatBotModel:
    """
    ChatBotModel is a class for generating responses based on text prompts using a pretrained model.

    Attributes:
    - device: The device to run the model on. Default is "xpu" if available, otherwise "cpu".
    - model: The loaded model for text generation.
    - tokenizer: The loaded tokenizer for the model.
    - torch_dtype: The data type to use in the model.
    """

    def __init__(
        self,
        model_id_or_path: str = "openlm-research/open_llama_3b_v2",  # "Writer/camel-5b-hf",
        torch_dtype: torch.dtype = torch.bfloat16,
        optimize: bool = True,
    ) -> None:
        """
        Load the tokenizer and model, move the model to the device and optionally optimize it.

        The model is first looked up under MODEL_CACHE_PATH (model id with "/"
        replaced by "--"). If that fails with OSError or ValueError, the
        tokenizer and model are loaded by `model_id_or_path` instead.

        Parameters:
        - model_id_or_path: The identifier or path of the pretrained model.
        - torch_dtype: The data type to use in the model. Default is torch.bfloat16.
        - optimize: If True, ipex is used to optimize the model.
        """
        self.torch_dtype = torch_dtype
        self.device = select_device("xpu")
        self.model_id_or_path = model_id_or_path
        local_model_path = os.path.join(
            MODEL_CACHE_PATH, self.model_id_or_path.replace("/", "--")
        )

        # str() works for both torch.device objects and plain device strings.
        if str(self.device).startswith("xpu"):
            self.autocast = torch.xpu.amp.autocast
        else:
            self.autocast = torch.cpu.amp.autocast

        try:
            self.tokenizer, self.model = self._load_pretrained(local_model_path)
        except (OSError, ValueError) as e:
            logging.info(
                "Tokenizer / model not found locally. Downloading tokenizer / model for %s to cache...: %s",
                self.model_id_or_path,
                e,
            )
            self.tokenizer, self.model = self._load_pretrained(self.model_id_or_path)

        self.max_length = 256

        if optimize:
            self._optimize_model()

    def _load_pretrained(self, source: str):
        """Load tokenizer and model from `source`; return (tokenizer, model on device in eval mode)."""
        if "llama" in self.model_id_or_path:
            tokenizer = LlamaTokenizer.from_pretrained(source)
            model = LlamaForCausalLM.from_pretrained(
                source,
                low_cpu_mem_usage=True,
                torch_dtype=self.torch_dtype,
            )
        else:
            # NOTE(security): trust_remote_code=True runs Python code shipped
            # with the model repository; only use it with trusted model ids.
            tokenizer = AutoTokenizer.from_pretrained(source, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                source,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                torch_dtype=self.torch_dtype,
            )
        return tokenizer, model.to(self.device).eval()

    def _optimize_model(self) -> None:
        """Apply ipex optimizations, preferring optimize_transformers when it exists."""
        if hasattr(ipex, "optimize_transformers"):
            try:
                # Keep the returned module: ipex documents the return value as
                # the optimized model.
                self.model = ipex.optimize_transformers(
                    self.model, dtype=self.torch_dtype
                )
                return
            except Exception as e:
                logging.warning(
                    "ipex.optimize_transformers failed (%s); falling back to ipex.optimize",
                    e,
                )
        self.model = ipex.optimize(self.model, dtype=self.torch_dtype)

    def prepare_input(self, previous_text, user_input):
        """Prepare the input for the model, ensuring it doesn't exceed the maximum length."""
        response_buffer = 100
        user_input = (
             "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{user_input}\n\n### Response:")
        combined_text = previous_text + "\nUser: " + user_input + "\nBot: "
        input_ids = self.tokenizer.encode(
            combined_text, return_tensors="pt", truncation=False
        )
        adjusted_max_length = self.max_length - response_buffer
        if input_ids.shape[1] > adjusted_max_length:
            input_ids = input_ids[:, -adjusted_max_length:]
        return input_ids.to(device=self.device)

    def gen_output(
        self, input_ids, temperature, top_p, top_k, num_beams, repetition_penalty
    ):
        """
        Generate the output text based on the given input IDs and generation parameters.

        Args:
            input_ids (torch.Tensor): The input tensor containing token IDs.
            temperature (float): The temperature for controlling randomness in Boltzmann distribution.
                                Higher values increase randomness, lower values make the generation more deterministic.
            top_p (float): The cumulative distribution function (CDF) threshold for Nucleus Sampling.
                           Helps in controlling the trade-off between randomness and diversity.
            top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
            num_beams (int): The number of beams for beam search. Controls the breadth of the search.
            repetition_penalty (float): The penalty applied for repeating tokens.

        Returns:
            torch.Tensor: The generated output tensor.
        """
        print(f"Using max length: {self.max_length}")
        with self.autocast(
            enabled=True if self.torch_dtype != torch.float32 else False,
            dtype=self.torch_dtype,
        ):
            with torch.no_grad():
                output = self.model.generate(
                    input_ids,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_length=self.max_length,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    num_beams=num_beams,
                    repetition_penalty=repetition_penalty,
                )
                return output

    def strip_response(self, generated_text):
        """Remove ### Response: from string if exists."""
        match = re.search(r'### Response:(.*)', generated_text, re.S)
        if match:
            return match.group(1).strip()
    
        else:
            return generated_text
        
    def unique_sentences(self, text: str) -> str:
        sentences = text.split(". ")
        if sentences[-1] and sentences[-1][-1] != ".":
            sentences = sentences[:-1]
        sentences = set(sentences)
        return ". ".join(sentences) + "." if sentences else ""

    def remove_repetitions(self, text: str, user_input: str) -> str:
        """
        Remove repetitive sentences or phrases from the generated text and avoid echoing user's input.

        Args:
            text (str): The input text with potential repetitions.
            user_input (str): The user's original input to check against echoing.

        Returns:
            str: The processed text with repetitions and echoes removed.
        """
        text = re.sub(re.escape(user_input), "", text, count=1).strip()
        text = self.unique_sentences(text)
        return text

    def extract_bot_response(self, generated_text: str) -> str:
        """
        Extract the first response starting with "Bot:" from the generated text.

        Args:
            generated_text (str): The full generated text from the model.

        Returns:
            str: The extracted response starting with "Bot:".
        """
        prefix = "Bot:"
        generated_text = generated_text.replace("\n", ". ")
        bot_response_start = generated_text.find(prefix)
        if bot_response_start != -1:
            response_start = bot_response_start + len(prefix)
            end_of_response = generated_text.find("\n", response_start)
            if end_of_response != -1:
                return generated_text[response_start:end_of_response].strip()
            else:
                return generated_text[response_start:].strip()
        return re.sub(r'^[^a-zA-Z0-9]+', '', generated_text)

    def interact(
        self,
        user_input, # Input text
        out: Output,  # Output widget to display the conversation
        with_context: bool = True,
        temperature: float = 0.10,
        top_p: float = 0.95,
        top_k: int = 40,
        num_beams: int = 3,
        repetition_penalty: float = 1.80,
    ) -> None:
        """
        Handle the chat loop where the user provides input and receives a model-generated response.

        Args:
            with_context (bool): Whether to consider previous interactions in the session. Default is True.
            temperature (float): The temperature for controlling randomness in Boltzmann distribution.
                                 Higher values increase randomness, lower values make the generation more deterministic.
            top_p (float): The cumulative distribution function (CDF) threshold for Nucleus Sampling.
                           Helps in controlling the trade-off between randomness and diversity.
            top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
            num_beams (int): The number of beams for beam search. Controls the breadth of the search.
            repetition_penalty (float): The penalty applied for repeating tokens.
            """
            
        send_button.button_style = "warning"
        chat_spin.value = spin_style
        orig_input = ""
        with out:
            print(f" {user_color}{user_icon}You: {user_input}{default_color}")
        if user_input.lower() == "exit":
            return
        if "camel" in self.model_id_or_path:
                orig_input = user_input
                user_input = (
                    "Below is an instruction that describes a task. "
                    "Write a response that appropriately completes the request.\n\n"
                    f"### Instruction:\n{user_input}\n\n### Response:")
        self.max_length = 96
        input_ids = self.tokenizer.encode(user_input, return_tensors="pt").to(self.device)

        output_ids = self.gen_output(
            input_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            repetition_penalty=repetition_penalty,
        )
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        generated_text = self.strip_response(generated_text)
        generated_text = self.extract_bot_response(generated_text)
        generated_text = self.remove_repetitions(generated_text, user_input)
        print(generated_text)