# In [None]:

import os
os.environ["TRANSFORMERS_CACHE"] = "./cache"

import torch
from langchain.llms.base import LLM
from llama_index import (
    ListIndex,
    LLMPredictor,
    PromptHelper,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage
)
from transformers import pipeline

# Decide model placement once: True when PyTorch can see a CUDA device.
use_gpu = torch.cuda.is_available()

# Upper bound on newly generated tokens per LLM call; also passed to
# PromptHelper as ``num_output``.
max_token = 256

# Prompt-sizing settings handed to llama_index: a 1024 input budget,
# ``max_token`` reserved for output, and a chunk overlap of 20
# (units presumably tokens — confirm against llama_index docs).
prompt_helper = PromptHelper(
    num_output=max_token,
    max_input_size=1024,
    max_chunk_overlap=20,
)


class LocalOPT(LLM):
    """LangChain ``LLM`` wrapper around a locally run Hugging Face OPT-IML model.

    The ``text-generation`` pipeline is built in the class body, so the model is
    loaded when this module is imported (downloading it into the configured
    cache first if it is not there yet), not when the class is instantiated.
    """

    # model_name = "facebook/opt-iml-max-30b" (this is a 60gb model)
    model_name = "facebook/opt-iml-1.3b"  # ~2.63gb model

    # In a class body the right-hand ``pipeline`` still resolves to the
    # ``transformers.pipeline`` factory (the class attribute is not bound yet);
    # its result is then stored as the class attribute ``pipeline``.
    # ``device=None`` is the factory's default, i.e. the same as omitting it.
    # NOTE(review): bfloat16 is requested on CPU too, where it may be slow — confirm intended.
    pipeline = pipeline(
        "text-generation",
        model=model_name,
        device="cuda:0" if use_gpu else None,
        model_kwargs={"torch_dtype": torch.bfloat16},
    )

    def _call(self, prompt: str, stop=None) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text for the model to continue.
            stop: Optional iterable of stop strings; the completion is cut at
                the earliest occurrence of any non-empty one of them.

        Returns:
            Only the newly generated text (the prompt is not echoed back).
        """
        # return_full_text=False makes the pipeline return just the new text.
        # The old approach sliced off len(prompt) characters, which breaks
        # whenever decoding does not reproduce the prompt exactly.
        outputs = self.pipeline(
            prompt, max_new_tokens=max_token, return_full_text=False
        )
        text = outputs[0]["generated_text"]
        if stop:
            # Skip empty stop strings: str.find("") is always 0.
            hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
            if hits:
                text = text[: min(hits)]
        return text

    @property
    def _identifying_params(self):
        """Parameters LangChain uses to identify this LLM instance."""
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self):
        """Type tag reported to LangChain."""
        return "custom"


def create_index(input_dir: str = "news"):
    """Build a ``ListIndex`` over the documents in ``input_dir`` using the local OPT model.

    Args:
        input_dir: Directory whose files are loaded with ``SimpleDirectoryReader``.
            Defaults to ``"news"``, the directory this script was written for.

    Returns:
        The newly built ``ListIndex``. It is not persisted here; callers
        persist it themselves.
    """
    print("Creating index")
    # LLMPredictor wraps the LangChain LLM so llama_index can call it.
    llm = LLMPredictor(llm=LocalOPT())
    # ServiceContext bundles the LLM predictor and the module-level prompt
    # helper, and is handed to the index below.
    service_context = ServiceContext.from_defaults(
        llm_predictor=llm, prompt_helper=prompt_helper
    )
    docs = SimpleDirectoryReader(input_dir).load_data()
    index = ListIndex.from_documents(docs, service_context=service_context)
    print("Done creating index", index)
    return index


def execute_query(
    target_index=None,
    question="Who does China export its coal to in 2023?",
):
    """Run ``question`` against a llama_index index and return the response.

    Args:
        target_index: Index to query. When omitted, the module-level ``index``
            global (assigned in the ``__main__`` block) is used, which keeps the
            original zero-argument call working.
        question: Natural-language query text.

    Returns:
        Whatever the index's query engine returns for ``question``.

    Raises:
        NameError: if ``target_index`` is omitted and no global ``index`` exists.
    """
    if target_index is None:
        # Backward-compatible fallback to the script-level global.
        target_index = index
    query_engine = target_index.as_query_engine()
    return query_engine.query(question)


if __name__ == "__main__":
    # Reuse the index persisted in ./data when that directory exists;
    # otherwise build it from the documents and persist it to ./data.
    # NOTE: this does not control the model download. The model is loaded
    # when LocalOPT's class body runs at import time, using the
    # TRANSFORMERS_CACHE directory set at the top of the file.
    persist_dir = "data"
    if not os.path.exists(persist_dir):
        print("No persisted index found, building a new index from the documents")
        index = create_index()
        index.storage_context.persist(persist_dir=persist_dir)
    else:
        print("Loading persisted index from disk")
        llm = LLMPredictor(llm=LocalOPT())
        service_context = ServiceContext.from_defaults(
            llm_predictor=llm, prompt_helper=prompt_helper
        )

        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context, service_context=service_context)

    # execute_query() reads the module-level ``index`` assigned above.
    response = execute_query()
    print(response)
    print(response.source_nodes)
