# LangChain integration 🦜🔗

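This notebook shows how to wrap a Lit-GPT model as a LangChain LLM! It assumes a Falcon-7B checkpoint is already available at `checkpoints/tiiuae/falcon-7b` inside the cloned repository.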

In [None]:
# clone Lit-GPT
# NOTE(review): this clones the repository's default branch, so later revisions may no longer
# provide the `lit_gpt.generate.base.build_llm` import used below — consider pinning a commit.
!git clone https://github.com/Lightning-AI/lit-gpt
# Later relative paths (`pip install .`, `checkpoints/...`) resolve inside the clone.
%cd lit-gpt/

In [None]:
# for CUDA
!pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev' -q

# install the dependencies
!pip install .
!pip install langchain

In [None]:
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

In [None]:
from lit_gpt.generate.base import build_llm, generate
from lit_gpt import GPT, Tokenizer

# --- Model loading (passed to `build_llm`) ---
# NOTE(review): this directory must already hold the model weights; this notebook does not download them.
checkpoint_dir = "checkpoints/tiiuae/falcon-7b"
devices = 1
quantize = "bnb.int8"  # quantization mode forwarded to build_llm

# --- Sampling (used by the LangChain wrapper defined below) ---
temperature = 0.8
top_k = 200
max_new_tokens = 50

In [None]:
# Load the model, its tokenizer and the Fabric instance using the settings from the config cell.
model, tokenizer, fabric = build_llm(
    checkpoint_dir=checkpoint_dir,
    devices=devices,
    quantize=quantize,
)

We will create a [custom LLM](https://python.langchain.com/docs/modules/model_io/models/llms/how_to/custom_llm): a callable class that is responsible for interacting with our model.

In [None]:
class LitGPTLLM(LLM):
    """LangChain ``LLM`` wrapper around a loaded Lit-GPT model and its tokenizer.

    The sampling settings default to the values in the configuration cell above.
    Override them per instance, e.g.
    ``LitGPTLLM(model=model, tokenizer=tokenizer, temperature=0.2)``.
    """

    model: Any
    tokenizer: Tokenizer
    # Defaults are captured from the notebook's configuration cell when this class is defined;
    # re-run this cell after changing them there.
    max_new_tokens: int = max_new_tokens
    top_k: int = top_k
    temperature: float = temperature

    @property
    def _llm_type(self) -> str:
        return "lit-gpt"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text to condition generation on.
            stop: Stop sequences; not supported by this wrapper.
            run_manager: LangChain callback manager (unused).
            **kwargs: Extra keyword arguments that newer LangChain versions forward; ignored.

        Returns:
            The generated continuation only, with the prompt text stripped.

        Raises:
            ValueError: If ``stop`` is given, or if the prompt plus ``max_new_tokens``
                exceeds the model's ``block_size``.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        encoded = self.tokenizer.encode(prompt, device=self.model.device)
        prompt_length = encoded.size(0)
        max_returned_tokens = prompt_length + self.max_new_tokens
        block_size = self.model.config.block_size
        # The whole sequence must fit in the model's context (maximum rope cache length).
        if max_returned_tokens > block_size:
            raise ValueError(
                f"prompt ({prompt_length} tokens) + max_new_tokens ({self.max_new_tokens}) = "
                f"{max_returned_tokens} exceeds the model's block_size ({block_size})"
            )
        try:
            y = generate(
                self.model,
                encoded,
                max_returned_tokens,
                max_seq_length=max_returned_tokens,
                temperature=self.temperature,
                top_k=self.top_k,
            )
        finally:
            # Reset the cache of *this* wrapper's model (not a notebook global), even when
            # generation fails, so the next call starts from a clean state.
            self.model.reset_cache()
        # `max_returned_tokens` counts the prompt, so `y` begins with the prompt tokens;
        # LangChain expects only the newly generated text.
        return self.tokenizer.decode(y[prompt_length:])

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (model name plus sampling settings)."""
        return {
            "name": self.model.config.name,
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
        }

In [None]:
# Wrap the loaded model and tokenizer so LangChain can drive generation.
llm = LitGPTLLM(
    model=model,
    tokenizer=tokenizer,
)

In [None]:
# Run a single prompt through the LangChain wrapper and show the generated text.
completion = llm("hello")
print(completion)