In [1]:
from typing import Optional, List, Mapping, Any
from mlx_lm import load, generate
from llama_index.core import SimpleDirectoryReader, SummaryIndex, Settings
from llama_index.core.callbacks import CallbackManager
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback 
from pydantic import BaseModel, validator

class OurLLM(CustomLLM, BaseModel):
    """LlamaIndex ``CustomLLM`` that answers prompts with a local MLX model.

    The model and tokenizer are obtained from ``mlx_lm.load`` when the
    instance is created, unless both are supplied by the caller. Generated
    text is truncated at the first blank line (see
    ``process_generated_text``).

    Fields:
        model: Loaded model object passed to ``mlx_lm.generate``.
        tokenizer: Tokenizer object passed to ``mlx_lm.generate``.
        repo_id: Identifier handed to ``mlx_lm.load``.
        context_window: Context size reported in ``metadata``.
        max_tokens: Generation budget; reported in ``metadata`` and used as
            the default ``max_tokens`` for ``mlx_lm.generate``.
        model_name: Name reported in ``metadata``.
    """

    model: Optional[Any] = None
    tokenizer: Optional[Any] = None
    repo_id: str = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
    context_window: int = 4096
    max_tokens: int = 500
    model_name: str = "custom"

    def __init__(self, **data):
        """Validate the fields, then load the model/tokenizer if not supplied."""
        super().__init__(**data)
        # Only hit mlx_lm.load when the caller did not inject both objects.
        if self.model is None or self.tokenizer is None:
            self.model, self.tokenizer = load(self.repo_id)

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            # LLMMetadata names the output budget ``num_output``; a
            # ``max_tokens`` keyword is not one of its fields.
            num_output=self.max_tokens,
            model_name=self.model_name,
        )

    def process_generated_text(self, text: str) -> str:
        """Return ``text`` cut off just before its first blank line.

        Args:
            text: Raw text returned by the generator.

        Returns:
            Everything before the first ``"\\n\\n"``, or ``text`` unchanged
            when it contains no such sequence.
        """
        head, _, _ = text.partition("\n\n")
        return head

    def _generate_text(self, prompt: str, kwargs: dict) -> str:
        """Run ``mlx_lm.generate`` for ``prompt`` and post-process the result.

        Shared by ``complete`` and ``stream_complete`` so both treat keyword
        arguments identically.

        Args:
            prompt: Prompt text forwarded to ``mlx_lm.generate``.
            kwargs: Caller keyword arguments; modified in place.

        Returns:
            The generated text after ``process_generated_text``.
        """
        # 'formatted' may be supplied by the caller but is not forwarded
        # to the generator.
        kwargs.pop("formatted", None)
        # Apply the configured generation budget unless the caller overrides it.
        kwargs.setdefault("max_tokens", self.max_tokens)
        kwargs.setdefault("verbose", False)
        generated_text = generate(self.model, self.tokenizer, prompt=prompt, **kwargs)
        return self.process_generated_text(generated_text)

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Generate a completion for ``prompt`` and return it in one response.

        Args:
            prompt: Prompt text.
            **kwargs: Extra keyword arguments forwarded to ``mlx_lm.generate``
                (``formatted`` is dropped first).

        Returns:
            A ``CompletionResponse`` whose ``text`` is the processed output.
        """
        return CompletionResponse(text=self._generate_text(prompt, kwargs))

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Generate a completion and yield it one character at a time.

        The full text is generated up front; streaming is simulated by
        iterating over the finished string.

        Args:
            prompt: Prompt text.
            **kwargs: Extra keyword arguments forwarded to ``mlx_lm.generate``
                (``formatted`` is dropped first).

        Yields:
            ``CompletionResponse`` objects where ``delta`` is the newly
            emitted character and ``text`` is everything emitted so far.
        """
        processed_text = self._generate_text(prompt, kwargs)
        for end, char in enumerate(processed_text, start=1):
            yield CompletionResponse(text=processed_text[:end], delta=char)

In [2]:
# Make the MLX-backed LLM the default LLM in the global Settings object.
default_llm = OurLLM()
Settings.llm = default_llm

# Select the embedding model by its identifier string.
EMBED_MODEL_ID = "local:BAAI/bge-base-en-v1.5"
Settings.embed_model = EMBED_MODEL_ID

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [3]:
from llama_index.core import VectorStoreIndex
# Read the files in the "data" directory, then build a vector index over
# them, showing progress bars while it works.
reader = SimpleDirectoryReader(input_dir="data")
documents = reader.load_data()
index = VectorStoreIndex.from_documents(
    documents,
    show_progress=True,
)

Parsing nodes:   0%|          | 0/19 [00:00<?, ?it/s]

Generating embeddings:   0%|          | 0/20 [00:00<?, ?it/s]

In [4]:
# Ask a question against the index and print the response.
question = "What is the first sentence of the constitution"
query_engine = index.as_query_engine()
response = query_engine.query(question)
print(response)

We the People of the United States, in Order to form a  more perfect Union, establish Justice, insure domestic  Tranquility, provide for the common defence, promote the general  Welfare, and secure the Bless Liberty to ourselves  and our Posterity,  do ordain  and establish this Constitution for the United States of America
