In [None]:
from typing import Optional, List, Mapping, Any
import torch
from transformers import pipeline
from transformers import T5Tokenizer, T5ForConditionalGeneration
from langchain.llms.base import LLM
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from llama_index.node_parser import SimpleNodeParser
from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex, LLMPredictor, LangchainEmbedding, ServiceContext, PromptHelper
from llama_index.logger import LlamaLogger

In [None]:
# Load the documents found under the local Hugging Face docs directory.
docs_reader = SimpleDirectoryReader('./datasets/huggingface_docs/')
documents = docs_reader.load_data()

In [None]:
class CustomLLM(LLM):
    """LangChain ``LLM`` wrapper around a locally loaded T5 checkpoint.

    The tokenizer and model are loaded once, when this class body executes,
    and are stored as class-level defaults. ``__init__`` is intentionally not
    overridden so that keyword arguments accepted by the base ``LLM`` (and the
    fields declared here, e.g. ``max_new_tokens``) can be passed at
    construction time; ``CustomLLM()`` with no arguments keeps working.
    """

    # Hugging Face checkpoint id; also what `_llm_type` and
    # `_identifying_params` report.
    model_name: str = 't5-small'
    # Upper bound on the number of tokens generated per call. Without an
    # explicit bound `generate` falls back to the generation config's default
    # length, which truncates answers after a handful of tokens.
    max_new_tokens: int = 256
    # Annotated explicitly so these are declared fields rather than relying on
    # type inference from the default value.
    tokenizer: Any = T5Tokenizer.from_pretrained(model_name)
    model: Any = T5ForConditionalGeneration.from_pretrained(model_name)

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text passed to the tokenizer and then to the model.
            stop: Optional stop sequences. When given, the decoded text is cut
                just before the earliest occurrence of any of them; empty
                strings in the list are ignored.

        Returns:
            The decoded generation with special tokens stripped.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        outputs = self.model.generate(input_ids, max_new_tokens=self.max_new_tokens)
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        if stop:
            # Skip empty stop strings: "".find-style matches sit at index 0 and
            # would wipe out the whole response.
            hits = [response.find(token) for token in stop if token]
            cut_points = [pos for pos in hits if pos != -1]
            if cut_points:
                response = response[:min(cut_points)]
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM (the checkpoint name)."""
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        """Return the LLM type label; here simply the checkpoint name."""
        return self.model_name

In [None]:
# Smoke test: build the wrapper and call it once with an empty prompt.
# The value returned by the call is this cell's displayed output.
llm = CustomLLM()
llm(prompt='')

In [None]:
# Instruction-tuned embedding checkpoint and the instruction strings passed to
# it for documents and for queries.
# (HuggingFaceEmbeddings() is the plain, non-instruction alternative.)
model_name = "hkunlp/instructor-large"
embed_instruction = "Represent the Hugging Face library documentation"
query_instruction = "Query the most relevant piece of information from the Hugging Face documentation"

embedding_model = HuggingFaceInstructEmbeddings(
    model_name=model_name,
    embed_instruction=embed_instruction,
    query_instruction=query_instruction,
)

# Prompt budget handed to PromptHelper.
# max_input_size is the LLM's context window. CustomLLM wraps t5-small, which
# is documented as a 512-token model, so the previous value of 4096 let the
# prompt helper pack far more text than the model is meant to read.
# NOTE(review): llama_index counts tokens with its own tokenizer, not T5's, so
# treat this as an approximate budget.
max_input_size = 512
# Tokens reserved for the model's answer.
num_output = 256
# Overlap (in tokens) between adjacent chunks when text has to be split.
max_chunk_overlap = 20

service_context = ServiceContext(
    llm_predictor=LLMPredictor(llm=CustomLLM()),
    embed_model=LangchainEmbedding(embedding_model),
    prompt_helper=PromptHelper(max_input_size, num_output, max_chunk_overlap),
    node_parser=SimpleNodeParser(),
    llama_logger=LlamaLogger(),
)

In [None]:
# Inspection only: parse the loaded documents into nodes to see how they are
# chunked. The index cell below is built from `documents`, not from `nodes`.
parser = SimpleNodeParser()
nodes = parser.get_nodes_from_documents(documents)

# Show a count and a short preview instead of dumping every node into the
# cell output.
print(f"{len(nodes)} nodes parsed from {len(documents)} documents")
nodes[:3]

In [None]:
# Build the vector index from the loaded documents using the service context
# configured above, then write it to 'index_v2.json'.
index = GPTSimpleVectorIndex.from_documents(
    documents,
    service_context=service_context,
)
index.save_to_disk('index_v2.json')

In [None]:
# Ask the index a sample question; the returned object is the cell's output.
# NOTE(review): the original trailing comment hints that `.response` holds
# just the answer text -- confirm before relying on it.
response = index.query('how to create pipeline object?')
response