In [73]:
# Import required libraries
from deep_translator import GoogleTranslator
from transformers import TextIteratorStreamer
from threading import Thread
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr
from datetime import datetime, timedelta

In [95]:
# Header shown at the top of the Gradio app (rendered through gr.Markdown).
# Gemma 1.x is released in 2B and 7B sizes, which is what the text states.
# NOTE(review): this notebook loads the stock "google/gemma-1.1-2b-it" checkpoint and
# chat_gemma_2b only sends the chat history plus the new message to the model, so
# confirm the claim that the bot was "alimentado con información oficial provista por el DTT".
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">DTT Chatbot Piloto - Gemma 2B-it</h1>
<p> Este espacio es un piloto sobre el uso de un "instruction-tuned model", en este caso <a href="https://huggingface.co/google/gemma-1.1-2b-it"><b>Gemma 2B Chat</b></a>. Gemma es un LLM de código abierto, el cual viene en dos tamaños: 2B y 7B.</p>
<p>🔎 Para más detalles sobre el modelo y sus capacidades de Gemma con <code>transformers</code>, puedes visitar <a href="https://huggingface.co/blog/gemma">este blog</a>.</p>
<p>🦕 Este chatbot piloto tiene como finalidad responder preguntas orientadas a los practicantes de la ECYS. Este chatbot fue alimentado con información oficial provista por el DTT.</p>
<p>❗¿El chatbot está tardando mucho en responder o no responde?❗ ¡Presiona el botón de "Aviso de pánico" para enviar un aviso y que el chatbot sea reseteado manualmente!</p>
</div>
'''

# Footer text rendered with gr.Markdown at the bottom of the app.
LICENSE = """
<p/>
---
Built with Gemma 2B-it
"""

# HTML passed as the `placeholder` of the gr.Chatbot component (Gemma logo and a
# "ask me anything" prompt).
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://www.lavanguardia.com/andro4all/hero/2024/04/google-gemma.png?width=1200&aspect_ratio=16:9" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;"> 
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Gemma 2B-it</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Pregúntame cualquier duda...</p>
</div>
"""


# Custom CSS passed to gr.Blocks: centers h1 headings and styles an element with
# id "duplicate-button".
# NOTE(review): no component in the visible code sets elem_id="duplicate-button",
# so that rule appears to be unused leftover from a template.
css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

In [8]:
# Hugging Face Hub checkpoint: instruction-tuned Gemma 1.1, 2B parameters.
model_id = "google/gemma-1.1-2b-it"

# bitsandbytes config: load the weights in 4 bits (NF4 quantization) and run the
# computation in bfloat16, to reduce GPU memory use.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Load the model in 4 bits
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # Data type used for computation
)

# Tokenizer for the same checkpoint. padding_side="right" only affects batched
# inputs; chat_gemma_2b tokenizes a single conversation at a time.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")

# Quantized causal LM placed entirely on CUDA device 0 (device_map={"": 0}).
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map={"": 0}
)

# Token IDs that stop generation: the tokenizer's EOS token plus Gemma's
# end-of-turn marker, which the Gemma chat template uses to close every turn.
# ("<|eot_id|>" is a Llama 3 token; it is not in Gemma's vocabulary and would
# only resolve to the unknown-token ID.)
terminators = [tokenizer.eos_token_id]
_end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
if _end_of_turn_id is not None and _end_of_turn_id != tokenizer.unk_token_id:
    terminators.append(_end_of_turn_id)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [9]:
# Funtion to translate from Spanish to English and vice versa
def translate_text(text,mode):
        try:
            # print("Traduciendo:",text,"--",type(text))
            # Define the source and destination of the translation
            src = 'es' if mode else 'en'
            dest = 'en' if mode else 'es'
            #print(src,"--",dest)

            # Call the Google API to do the translation
            translated_text = GoogleTranslator(source=src, target=dest).translate(text) 
            # print("Resultado:",translated_text,type(translated_text))
            
            return translated_text
        except Exception as e:
            # print(f"Error al traducir:{text} --->{e}")
            return "Tuve un problema de traducción ¿puedes volver a preguntar?"  # In case of error, returns the original text

In [10]:
def chat_gemma_2b(message: str,
                  history: list,
                  temperature: float,
                  max_new_tokens: int
                  ) -> str:
    """
    Answer a chat message with Gemma, translating Spanish <-> English around the model.

    The user's message is translated to English and appended to the conversation
    history. The conversation is formatted with the tokenizer's chat template, and
    generation runs on a background thread whose output is collected through a
    TextIteratorStreamer. Because the reply is translated back to Spanish as a whole,
    it is yielded once, complete, rather than token by token.

    Args:
        message (str): The user's input message (Spanish).
        history (list): Conversation history from ChatInterface as (user, assistant) pairs.
        temperature (float): Sampling temperature; 0 selects greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The complete generated response, translated to Spanish.
    """
    english_message = translate_text(message, True)  # es -> en

    # Rebuild the conversation in the role/content format expected by the chat template.
    # NOTE(review): past turns are sent as stored (Spanish); only the current message
    # is translated, so the model sees mixed-language context.
    conversation = []
    for user_turn, assistant_turn in history:
        conversation.append({"role": "user", "content": user_turn})
        conversation.append({"role": "assistant", "content": assistant_turn})
    conversation.append({"role": "user", "content": english_message})

    # add_generation_prompt=True appends the opening of the model's turn
    # ("<start_of_turn>model\n" in Gemma's template). Without it the prompt ends right
    # after the user's turn and the model is never cued to answer as the assistant.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Collects decoded text from generate(); the prompt and special tokens are skipped.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),  # slider values may arrive as floats
        eos_token_id=terminators,
    )
    if temperature == 0:
        # Sampling with temperature 0 is invalid, so decode greedily instead
        # (and do not pass an unused temperature to generate()).
        generate_kwargs["do_sample"] = False
    else:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature

    # generate() blocks, so it runs on a worker thread while this one drains the streamer.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()
    english_reply = "".join(streamer)
    generation_thread.join()

    yield translate_text(english_reply, False)  # en -> es

In [99]:
# Time of the last accepted panic report. It is a module-level global, so the
# 5-minute cooldown is shared by every user of the app.
momento_inicial = None


def on_button_click(ahora=None):
    """Handle a click on the "Aviso de pánico" button; accept at most one report per 5 minutes.

    Args:
        ahora (datetime | None): Time of the click. Defaults to ``datetime.now()``.
            It is exposed so the cooldown can be exercised deterministically; the
            Gradio button calls this function with no arguments.

    Returns:
        str: A confirmation message when the report is accepted, otherwise a
        message with the remaining cooldown time formatted as H:MM:SS.
    """
    global momento_inicial
    if ahora is None:
        ahora = datetime.now()
    intervalo_deseado = timedelta(minutes=5)

    if momento_inicial is not None:
        # Read the clock only once, so the cooldown check and the remaining time
        # agree (two separate reads could produce a negative remaining time).
        transcurrido = ahora - momento_inicial
        if transcurrido < intervalo_deseado:
            tiempo_restante = intervalo_deseado - transcurrido
            tiempo_restante_str = str(tiempo_restante).split('.')[0]  # drop microseconds
            return f"Ya se tiene en cola un reporte hace menos de 5 minutos. Podrás enviar uno en {tiempo_restante_str}."

    # TODO(review): no notification is dispatched here. The report is only recorded
    # as a timestamp, even though the message tells the user it was sent.
    momento_inicial = ahora
    return "¡Aviso enviado, gracias por el reporte!"

In [100]:
# Gradio UI: description header, panic-report button, chat interface and footer.
# The Chatbot component is built outside the Blocks context and handed to
# gr.ChatInterface, which places it in the layout.
chatbot = gr.Chatbot(
    height=450,
    placeholder=PLACEHOLDER,
    label='Gradio ChatInterface',
)

with gr.Blocks(fill_height=True, css=css) as demo:
    # Header: project description (HTML inside Markdown).
    gr.Markdown(DESCRIPTION)

    # Panic report: the button calls on_button_click with no inputs and shows the
    # returned message in the textbox.
    custom_button = gr.Button("Aviso de pánico")
    output_text = gr.Textbox(label="Mensaje aviso de pánico")
    custom_button.click(on_button_click, inputs=[], outputs=output_text)

    # Chat area backed by chat_gemma_2b. The accordion and sliders are created with
    # render=False so that ChatInterface places them itself; slider values are passed
    # to chat_gemma_2b as (temperature, max_new_tokens), in list order.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.95,
        label="Temperature", render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128, maximum=4096, step=1, value=512,
        label="Max new tokens", render=False,
    )
    example_prompts = [
        ['¿Cómo establecer una base humana en Marte? Da una respuesta breve.'],
        ['Explícame la teoría de la relatividad como si tuviera 8 años.'],
        ['¿Cuánto es 9,000 * 9,000?'],
        ['Escribe un mensaje de cumpleaños lleno de juegos de palabras para mi amigo Alex.'],
        ['Justifica por qué un pingüino podría ser un buen rey de la jungla.'],
    ]
    gr.ChatInterface(
        fn=chat_gemma_2b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=example_prompts,
        cache_examples=False,  # example outputs are not pre-computed
    )

    # Footer with the license/credits text.
    gr.Markdown(LICENSE)

In [101]:
 demo.launch()

Running on local URL:  http://127.0.0.1:7889

To create a public link, set `share=True` in `launch()`.




In [102]:
# Relaunch the already-running app with a temporary public *.gradio.live link.
# NOTE(review): no auth is configured, so anyone with the link can drive the
# GPU-backed model and press the panic button.
demo.launch(share=True)

Rerunning server... use `close()` to stop if you need to change `launch()` parameters.
----
Running on public URL: https://341926d3912da37424.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


