In [1]:
import os #Operações com o SO (arquivos)
import json #Leitura/escrita de arquivos JSON
import time #Sleep
import threading #Multithreading
import unicodedata #Normalização de string
import collections #Estrutura de contador e fila
import string #Operações com strings
import re #Expressões regulares
import abc #Classes abstratas
import warnings #Lançamento de warnings
from typing import Optional, Dict, Tuple, Any #Type hints

import torch #spacy não carrega sem importar antes (??)
import spacy #Separador em sentenças
import tqdm #Barra de progresso
import groq #API para o Llama 3 70B
from pyserini.search import SimpleSearcher #Busca nos documentos
import sentence_transformers #Rerankeamento
import bs4 #Remoção de tags HTML
import numpy as np #Operações com arrays
import matplotlib.pyplot as plt #Plots

  hasattr(torch, "has_mps")
  and torch.has_mps  # type: ignore[attr-defined]


In [2]:
# Make sure the local "data" folder exists before anything is downloaded into it.
if not os.path.isdir("data"):
    os.mkdir("data")

# Fetch the context articles only when the extracted JSON is not already present.
# NOTE(review): the `move` command and the backslash paths look Windows-specific;
# confirm before running this cell on another platform.
if not os.path.isfile("data\\context_articles.json"):
    # Download the archive into the working directory, move it into "data",
    # extract it (the JSON lands in the working directory), then move the JSON too.
    !curl -LO https://iirc-dataset.s3.us-west-2.amazonaws.com/context_articles.tar.gz
    !move context_articles.tar.gz data
    !tar -xf data/context_articles.tar.gz
    !move context_articles.json data

# Fetch the test split only when it is not already present.
if not os.path.isfile("data\\iirc_test.json"):
    !curl -LO https://iirc-dataset.s3.us-west-2.amazonaws.com/iirc_test.json
    !move iirc_test.json data

In [3]:
# Load the context articles. The `with` block closes the handle even when
# parsing fails, and the explicit encoding avoids depending on the platform default.
with open("data\\context_articles.json", "r", encoding="utf-8") as file:
    articles = json.load(file)

In [4]:
# Load the test split. The `with` block closes the handle even when
# parsing fails, and the explicit encoding avoids depending on the platform default.
with open("data\\iirc_test.json", "r", encoding="utf-8") as file:
    test_data = json.load(file)

In [5]:
# Maximum number of question/answer samples collected when building `dataset`.
n_question = 50

In [7]:
# Inspect the first question of the first test item to see the record layout.
test_data[0]["questions"][0]

{'answer': {'type': 'span',
  'answer_spans': [{'text': 'sky and thunder god',
    'passage': 'zeus',
    'type': 'answer',
    'start': 83,
    'end': 102}]},
 'question': 'What is Zeus know for in Greek mythology?',
 'context': [{'text': 'he Palici the sons of Zeus',
   'passage': 'main',
   'indices': [684, 710]},
  {'text': 'in Greek mythology', 'passage': 'main', 'indices': [137, 155]},
  {'text': 'Zeus (British English , North American English ; , Zeús ) is the sky and thunder god in ancient Greek religion',
   'passage': 'Zeus',
   'indices': [0, 110]}],
 'question_links': ['Greek mythology', 'Zeus']}

In [30]:
def _format_answer(answer):
    '''
    Convert an answer record into a plain string.

    Args:
        answer (dict): answer record with a "type" key ("span", "value" or "binary").

    Returns:
        str: the answer text.

    Raises:
        ValueError: if the answer type is not one of the handled types.
    '''
    answer_type = answer["type"]

    if answer_type == "span":
        # Only the first span is kept as the answer text.
        return answer["answer_spans"][0]["text"]
    if answer_type == "value":
        # strip() avoids a dangling space when the unit is empty.
        return f'{answer["answer_value"]} {answer["answer_unit"]}'.strip()
    if answer_type == "binary":
        return answer["answer_value"]

    raise ValueError(f"Unknown answer type: {answer_type!r}")


def _format_context(context_items, main_passage_name):
    '''
    Render the context snippets as "<passage>: <text>" lines.

    Args:
        context_items (list): context records with "passage" and "text" keys.
        main_passage_name (str): title used in place of the "main" passage marker.

    Returns:
        str: one line per snippet, each terminated by a newline.
    '''
    lines = []
    for context_item in context_items:
        passage = context_item["passage"]
        if passage == "main":
            passage = main_passage_name
        lines.append(f"{passage}: {context_item['text']}\n")
    return "".join(lines)


def build_dataset(documents, max_questions):
    '''
    Collect up to `max_questions` answerable questions from the documents.

    Questions whose answer type is "none" are skipped.

    Args:
        documents (list): items with a "title" and a list of "questions".
        max_questions (int): maximum number of samples to collect.

    Returns:
        list: dicts with the keys "answer", "question" and "context".

    Raises:
        ValueError: if a question has an unknown answer type.
    '''
    samples = []

    for document in documents:
        main_passage_name = document["title"]

        for question in document["questions"]:
            # Checking before the append also handles max_questions <= 0.
            if len(samples) >= max_questions:
                return samples

            if question["answer"]["type"] == "none":
                continue

            samples.append({
                "answer": _format_answer(question["answer"]),
                "question": question["question"],
                "context": _format_context(question["context"], main_passage_name),
            })

    return samples


dataset = build_dataset(test_data, n_question)


In [35]:
# Drop the raw JSON structures from the kernel namespace.
# NOTE(review): cells that read `test_data` or `articles` cannot be re-run after
# this one without reloading the files first.
del test_data, articles

In [36]:
class GroqInterface:
    '''
    Interface for using the Groq API

    Implements a rate limit control for multi-threading use.
    '''

    # Client shared by every instance; created by the first constructor call.
    _client: Optional["groq.Groq"] = None

    LLAMA3_70B = "llama3-70b-8192"

    # Gate that request threads pass through before calling the API; it is held
    # by the waiter thread while a rate-limit pause is in progress.
    inference_lock = threading.Lock()
    # Elects the single thread that performs the rate-limit pause.
    time_waiter_lock = threading.Lock()
    # When True, no locking is done and the calling thread simply sleeps.
    SINGLE_THREAD = True

    # Seconds to pause after a rate-limit error.
    RATE_LIMIT_WAIT_SECONDS = 2

    # Last rate-limit exception seen by any instance (None if none occurred).
    error: Optional[Exception] = None

    def __init__(self, model:Optional[str]=None, api_key:Optional[str]=None, json_mode:bool=False, system_message:Optional[str]=None, n_retry:int=5):
        '''
        GroqInterface constructor.

        The API client is created once and shared: if a client already exists,
        `api_key` is not used.

        Args:
            model (str, optional): model to use. Llama3 70B is used if None. Default is None
            api_key (str, optional): Groq API key to use, if None will check the environment 'GROQ_API_KEY' variable. Default is None.
            json_mode (bool): if the model need to output in JSON. Default is False.
            system_message (str): the system message to send to the model, if needed. Default is None.
            n_retry (int): number of times to retry if the model fails (not considering RateLimitError). Default is 5.

        Raises:
            RuntimeError: if no client exists yet and no API key is given or set in the environment.
        '''

        if GroqInterface._client is None:

            if api_key is None:
                api_key = os.environ.get("GROQ_API_KEY")

            if api_key is None:
                raise RuntimeError("API key is not in the environment variables ('GROQ_API_KEY' variable is not set).")

            GroqInterface._client = groq.Groq(api_key=api_key)

        if model is None:
            model = GroqInterface.LLAMA3_70B
        self._model = model

        self._system_message = system_message

        if json_mode:
            self._response_format = {"type": "json_object"}
        else:
            self._response_format = None
        self._json_mode = json_mode

        self._n_retry = n_retry

    @staticmethod
    def _wait_for_rate_limit() -> None:
        '''
        Pause after a rate-limit error.

        In multi-thread mode only one thread sleeps, holding `inference_lock`
        so the other threads block at the gate; the rest return immediately.
        '''
        if GroqInterface.SINGLE_THREAD:
            time.sleep(GroqInterface.RATE_LIMIT_WAIT_SECONDS)
            return

        # A non-blocking acquire elects the waiter atomically; checking
        # locked() and then acquiring would let two threads race.
        if not GroqInterface.time_waiter_lock.acquire(blocking=False):
            return

        try:
            with GroqInterface.inference_lock:
                time.sleep(GroqInterface.RATE_LIMIT_WAIT_SECONDS)
        finally:
            # Always released, even if the sleep is interrupted, so the other
            # threads are never left blocked.
            GroqInterface.time_waiter_lock.release()

    def __call__(self, prompt:str) -> str:
        '''
        Generates the model response

        Rate-limit errors are retried indefinitely after a pause; any other
        exception is retried until `n_retry` failures have happened.

        Args:
            prompt (str): prompt to send to the model.

        Returns:
            str: model response.

        Raises:
            Exception: the last error raised by the request once the retry budget is used up.
        '''
        # The messages do not change between attempts, so build them once.
        messages = []
        if self._system_message is not None:
            messages.append({"role":"system", "content":self._system_message})
        messages.append({"role":"user", "content":prompt})

        retry_count = 0
        while True:
            try:
                if not GroqInterface.SINGLE_THREAD:
                    # Gate: blocks only while the waiter thread holds the lock.
                    with GroqInterface.inference_lock:
                        pass

                chat_completion = GroqInterface._client.chat.completions.create(
                        messages=messages,
                        model=self._model,
                        response_format=self._response_format
                    )
                break
            except groq.RateLimitError as exception: #Wait
                print("ERROR")
                print(exception)

                GroqInterface.error = exception
                GroqInterface._wait_for_rate_limit()
            except Exception: #Retry
                # KeyboardInterrupt is not an Exception subclass, so it is not
                # caught here and still stops the code.
                retry_count += 1
                if retry_count >= self._n_retry:
                    raise

        return chat_completion.choices[0].message.content

In [37]:
class Tool(abc.ABC):
    '''
    Abstract interface implemented by every LLM agent tool.
    '''

    @abc.abstractmethod
    def __call__(self, query:str, context:str) -> Dict[str, str]:
        '''
        Run the tool once.

        Subclasses must override this method; the base implementation does nothing.

        Args:
            query (str): what the tool is asked to do.
            context (str): the agent context at the moment the tool runs.

        Returns:
            Dict[str, str]: the results produced by the tool.
        '''

In [38]:
class QuestionGenerator(Tool, GroqInterface):
    '''
    Tool that asks the model for questions whose answer is the given text.

    The model runs in JSON mode with a system message that requests the
    schema {'questions': [...]}.
    '''

    _system_message = '''You are a question generator that outputs in JSON. 
The JSON object must use the schema: {'questions':['str', 'str', ...]}

Please use a valid JSON format.'''

    _base_prompt = '''Generate questions for the given answer:

Answer: {answer}
'''

    def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
        '''
        QuestionGenerator constructor.

        Args:
            model (str, optional): model to use, forwarded to GroqInterface. Default is None.
            api_key (str, optional): Groq API key, forwarded to GroqInterface. Default is None.
        '''
        super().__init__(
            model=model,
            api_key=api_key,
            json_mode=True,
            system_message=QuestionGenerator._system_message,
        )

    def __call__(self, query:Optional[Any]=None, context:str=None) -> Dict[str, str]:
        '''
        Generate questions for an answer.

        Args:
            query (Any, optional): the answer to generate questions for.
            context (str, optional): not used by this tool.

        Returns:
            Dict[str, str]: the model response parsed from JSON; the system
                message asks for a single 'questions' key holding a list.
        '''
        filled_prompt = QuestionGenerator._base_prompt.format(answer=query)
        raw_response = GroqInterface.__call__(self, prompt=filled_prompt)

        return json.loads(raw_response)

In [41]:
# Build the generator with the default model; with no api_key argument the key
# is read from the GROQ_API_KEY environment variable.
question_generator = QuestionGenerator()

In [48]:
# Ask the model for questions whose answer is the first sample's answer text.
questions = question_generator(dataset[0]["answer"])

In [49]:
# Display the parsed JSON returned by the generator.
questions

{'questions': ['Who is the Greek god of the sky and thunder?',
  'Which Greek god is often depicted holding a lightning bolt?',
  'What was the name of the Greek god of the sky and thunder in ancient Greek mythology?',
  'Who is the Greek god often associated with the sky and thunderbolt?',
  'What is the name of the Greek god of the sky and thunder?']}

In [50]:
# Name of the sentence-transformers model used to embed the questions.
embedder_model:str="all-MiniLM-L6-v2"
# Load the embedding model once so later cells can reuse it.
embedder = sentence_transformers.SentenceTransformer(embedder_model)

In [51]:
# Embedding of the original dataset question (the reference).
q_embedding = embedder.encode(dataset[0]["question"], convert_to_tensor=True)
# Embeddings of the generated questions, one per entry in questions["questions"].
qi_embeddings = embedder.encode(questions["questions"], convert_to_tensor=True)

In [57]:
# Cosine similarity between the reference question and each generated question.
cosine_scores = sentence_transformers.util.cos_sim(q_embedding, qi_embeddings)

# Average similarity: the summed scores divided by the number of generated questions.
score = cosine_scores.sum().item() / len(qi_embeddings)

score

0.575706958770752