In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface import ChatHuggingFace
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

import bs4
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
import dotenv

# Load environment variables through python-dotenv before any model is built.
dotenv.load_dotenv()

# One Hugging Face model id is shared by the tokenizer, the causal LM and the
# chat wrapper created below.
model_id = "microsoft/Phi-3-mini-128k-instruct"

# 4-bit NF4 quantization with double quantization; bfloat16 is the compute dtype.
_quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
}
bnb_config = BitsAndBytesConfig(**_quantization_options)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code is passed for the model only; the tokenizer is loaded with
# its default arguments.
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, quantization_config=bnb_config
)

# Build the transformers text-generation pipeline (max_new_tokens=256) first,
# then hand it to the LangChain pipeline wrapper.
_text_generation = pipeline(
    task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256
)
pipe = HuggingFacePipeline(pipeline=_text_generation)

# Embedding model used later for the vector store; it is explicitly requested
# on the "cuda" device.
embedding = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L6-v2", model_kwargs={"device": "cuda"}
)

# Chat-model interface on top of the pipeline-backed LLM.
llm = ChatHuggingFace(model_id=model_id, llm=pipe, tokenizer=tokenizer)

# Two-message chat prompt: a fixed system instruction followed by a human turn
# that is filled in from the `context` and `question` template variables.
_system_message = (
    "system",
    "You are an helpful assistant and help with question answering using the given context.",
)
_human_message = ("human", "context: {context}\nquestion: {question}")
prompt_chat = ChatPromptTemplate.from_messages([_system_message, _human_message])

# Splitter configured with chunk_size=1000 and chunk_overlap=200.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

# Load a single blog post; BeautifulSoup is told to parse only the elements
# whose class is one of the names listed below.
_post_url = "https://lilianweng.github.io/posts/2023-06-23-agent/"
_kept_classes = ("post-content", "post-title", "post-header")
loader = WebBaseLoader(
    web_paths=(_post_url,),
    bs_kwargs={"parse_only": bs4.SoupStrainer(class_=_kept_classes)},
)
docs = loader.load()

# Split the loaded documents into chunks.
splits = text_splitter.split_documents(docs)

# Index the chunks in Chroma using the embedding model, then expose the store
# as a retriever with its default settings.
vector_store = Chroma.from_documents(documents=splits, embedding=embedding)
retriever = vector_store.as_retriever()

def format_docs(docs):
    """Join the ``page_content`` of every item in *docs* into one string.

    Items are taken in iteration order and separated by a blank line
    (two newline characters). An empty iterable yields an empty string.
    """
    page_texts = [entry.page_content for entry in docs]
    return "\n\n".join(page_texts)

# Marker after which the text kept as the answer begins.
_ASSISTANT_MARKER = "<|assistant|>"


def _keep_text_after_marker(text):
    """Return what follows the last assistant marker in *text*, stripped.

    When the marker is absent the whole text is returned (stripped).
    """
    return text.rpartition(_ASSISTANT_MARKER)[2].strip()


# The incoming question is used twice: once to retrieve and format context,
# and once passed through unchanged as the prompt's `question` variable.
_chain_inputs = {
    "context": retriever | format_docs,
    "question": RunnablePassthrough(),
}

# NOTE(review): the last step is a plain function rather than a generator, so
# it presumably runs on the fully aggregated output and stream() may deliver
# the answer as one chunk - confirm against the LangChain streaming docs.
rag_chain = (
    _chain_inputs
    | prompt_chat
    | llm
    | StrOutputParser()
    | _keep_text_after_marker
)

# Print whatever the chain streams back, without adding newlines between chunks.
_question = "What is Task Decomposition?"
for chunk in rag_chain.stream(_question):
    print(chunk, end="", flush=True)