In [1]:
import os
import warnings
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata, LLM
from llama_index.core.llms.callbacks import llm_completion_callback
from huggingface_hub import InferenceClient
from snowflake.snowpark import Session
from snowflake.cortex import Complete
from dotenv import load_dotenv
from typing import Any

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Silence warnings to keep cell output readable.
# NOTE(review): this suppresses *every* warning (including deprecations) for the
# rest of the kernel session; consider narrowing to specific categories/modules.
warnings.filterwarnings("ignore")

# Load variables from a local .env file into os.environ (python-dotenv).
load_dotenv()

# Snowpark connection parameter -> environment variable that supplies it.
_SNOWFLAKE_ENV = {
    "account": "SNOWFLAKE_ACCOUNT",
    "user": "SNOWFLAKE_USER",
    "password": "SNOWFLAKE_USER_PASSWORD",
}

# Report every missing variable at once, instead of failing on the first KeyError,
# so a misconfigured .env can be fixed in a single pass.
_missing_env = [var for var in _SNOWFLAKE_ENV.values() if var not in os.environ]
if _missing_env:
    raise KeyError("Missing required environment variable(s): " + ", ".join(_missing_env))

connection_params = {param: os.environ[var] for param, var in _SNOWFLAKE_ENV.items()}

In [3]:
# Open a Snowpark session using the credentials collected in `connection_params`.
# NOTE(review): the active (Hugging Face) `complete` path does not use this
# session; it is only needed for Snowflake Cortex calls.
snowflake_session = (
    Session.builder
    .configs(connection_params)
    .create()
)

In [4]:
# Hugging Face Inference API client shared by `complete`; the token comes from
# the HF_TOKEN environment variable (a KeyError is raised if it is unset).
client = InferenceClient(
    api_key=os.environ["HF_TOKEN"],
)
def complete(
    user_text: str,
    model: str = "meta-llama/Llama-3.2-3B-Instruct",
    max_tokens: int = 500,
    inference_client: Any = None,
) -> str:
    """
    Send a single-turn chat prompt to a hosted model and return the reply text.

    :param user_text: str. The user message to send.
    :param model: str, default "meta-llama/Llama-3.2-3B-Instruct". Model id to query.
    :param max_tokens: int, default 500. Upper bound on generated tokens.
    :param inference_client: client exposing ``chat.completions.create``, default None.
        None uses the module-level ``client``; pass another client (or a stub) to
        redirect or test calls.
    :return: str. The assistant message content, or "" if the API returned no content.
    """
    if inference_client is None:
        inference_client = client

    completion = inference_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_text}],
        max_tokens=max_tokens,
    )

    # The message content may be None (e.g. an empty or tool-only reply); normalise
    # it to "" so callers can always treat the result as a string.
    return completion.choices[0].message.content or ""

In [5]:
class RagoonBot(CustomLLM):
    """LlamaIndex ``CustomLLM`` adapter around the module-level ``complete`` helper.

    Every request is forwarded to ``complete``, which queries
    meta-llama/Llama-3.2-3B-Instruct through the Hugging Face Inference API.
    """

    context_window: int = 3900
    # Matches the 500-token cap that `complete` requests, so LlamaIndex reserves
    # enough of the context window for the reply.
    num_output: int = 500
    # Must name the model that `complete` actually calls.
    model_name: str = "meta-llama/Llama-3.2-3B-Instruct"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata (context window, output budget, model name)."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Return the full completion for ``prompt`` in a single response.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        # Inside a method body a bare name is looked up in module globals, so this
        # calls the module-level `complete` helper, not this method.
        # `or ""` guards against a None reply, which CompletionResponse would reject.
        return CompletionResponse(text=complete(prompt) or "")

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        """Simulate streaming by yielding the finished reply one character at a time.

        The backend returns the whole reply at once, so nothing is yielded until
        generation has finished. Each response carries the text so far in ``text``
        and the newly added character in ``delta``. Extra ``kwargs`` are ignored.
        """
        full_response = complete(prompt) or ""
        for end, char in enumerate(full_response, start=1):
            yield CompletionResponse(text=full_response[:end], delta=char)

In [6]:
# LlamaIndex-compatible LLM whose completions are served by the `complete` helper.
llm = RagoonBot()

In [7]:
# Smoke test: one round-trip through RagoonBot.complete to the hosted model.
llm.complete("Hello")

CompletionResponse(text="Hello! It's nice to meet you. Is there something I can help you with or would you like to chat?", additional_kwargs={}, raw=None, logprobs=None, delta=None)

In [8]:
class HyDETransformer(HyDEQueryTransform):
    """Convenience wrapper around LlamaIndex's ``HyDEQueryTransform``.

    Adds :meth:`transform`, which runs the HyDE transform on a plain query string
    and returns the strings that would be embedded for retrieval.
    """

    def __init__(self,
                 llm: LLM,
                 hyde_prompt: Any = None,
                 include_original: bool = True):
        """
        Initialize the Hypothetical Document Embeddings (HyDE) query transform.

        :param llm: LLM. The LLM used to write the hypothetical document.
        :param hyde_prompt: prompt template or None, default None. Prompt used to generate
            the hypothetical document. None defers to ``HyDEQueryTransform``'s default prompt
            (per the LlamaIndex API reference).
        :param include_original: bool, default True. Whether the original query is kept
            alongside the hypothetical document in the output.
        """
        super().__init__(
            llm=llm,
            hyde_prompt=hyde_prompt,
            include_original=include_original
        )

    def transform(
        self,
        text: str = None
    ) -> list:
        """
        Generate the HyDE embedding strings for a query.

        :param text: str. The query to transform.
        :return: list[str]. The strings to embed: the LLM-written hypothetical document
            and, when ``include_original`` is True, the original query.
        :raises ValueError: if ``text`` is None or a blank string.
        """
        # Raise rather than return a message: callers iterate over the result, and a
        # returned string would be iterated character by character without any error.
        if text is None or (isinstance(text, str) and not text.strip()):
            raise ValueError("Please provide a text to transform.")

        query_bundle = self.run(text)
        return query_bundle.custom_embedding_strs

In [10]:
# HyDE demo: have the LLM draft a hypothetical answer document for the question,
# then print every string that `transform` returns for embedding.
transformer = HyDETransformer(llm=llm)
text = "What are the effects of schizophrenia on memory?"
queries = transformer.transform(text)
for query in queries:
    print(query)

Schizophrenia is a chronic and severe mental disorder that affects not only an individual's thought processes and behaviors but also their cognitive abilities, particularly their memory. Research has consistently shown that schizophrenia has a significant impact on memory, leading to various cognitive deficits and impairments. One of the primary effects of schizophrenia on memory is the disruption of the hippocampus, a crucial brain region responsible for the formation and consolidation of new memories.

Individuals with schizophrenia often experience deficits in short-term memory, known as working memory, which is the ability to hold and manipulate information in working memory. Studies have shown that people with schizophrenia tend to have reduced working memory capacity, making it difficult for them to perform cognitive tasks, such as learning new information or following conversations. Additionally, individuals with schizophrenia are at a higher risk of developing state-dependent m