# Run an InstructLab-Trained Model in a Notebook — LangChain Implementation

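Run `ilab serve` in your terminal to serve your model(s) before executing this notebook. The configuration below expects the base model at `http://127.0.0.1:8000` and the trained model at `http://127.0.0.1:1234`; adjust the URLs if you serve them elsewhere.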

## LangChain Custom LLM Wrapper

In [1]:
from typing import Any, Dict, Iterator, List, Optional
import requests
import json
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

class InstructLabLLM(LLM):
    """A custom LangChain LLM that talks to an OpenAI-compatible InstructLab server.

    Every request sends ``system_message`` as a system message followed by the
    prompt as a single user message to the chat-completions endpoint at ``url``,
    and returns the assistant's reply text.

    Example:

        .. code-block:: python

            model = InstructLabLLM(
                url="http://127.0.0.1:8000/v1/chat/completions",
                model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
                system_message="You are a helpful assistant."
            )
            result = model.invoke("hello")
            results = model.batch(["hello", "world"])
            for piece in model.stream("hello"):
                print(piece, end="")
    """

    # Full URL of the server's chat-completions endpoint.
    url: str
    # Model identifier sent as the "model" field of every request.
    model_name: str
    # Content of the system message prepended to every request.
    system_message: str
    # Seconds to wait for the server, passed to ``requests.post``; None waits indefinitely.
    request_timeout: Optional[float] = None

    def _build_payload(self, prompt: str, stream: bool = False) -> Dict[str, Any]:
        """Build the chat-completions request body for ``prompt``.

        Args:
            prompt: Text sent as the user message.
            stream: When True, ask the server for an incremental event stream.

        Returns:
            The JSON-serialisable request body.
        """
        payload: Dict[str, Any] = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": self.system_message},
                {"role": "user", "content": prompt},
            ],
        }
        if stream:
            # Without this flag the server answers with one complete JSON body
            # rather than a stream of incremental chunks.
            payload["stream"] = True
        return payload

    def _post(self, payload: Dict[str, Any], stream: bool = False) -> requests.Response:
        """POST ``payload`` as JSON to ``self.url`` and return the raw response."""
        headers = {"Content-Type": "application/json"}
        return requests.post(
            self.url,
            data=json.dumps(payload),
            headers=headers,
            stream=stream,
            timeout=self.request_timeout,
        )

    @staticmethod
    def _extract_stream_text(event: Dict[str, Any]) -> str:
        """Return the text carried by one decoded streaming event ("" if none).

        OpenAI-style chunks carry text in ``choices[0].delta.content``; the
        first chunk usually holds only the role, so missing content maps to "".
        A top-level ``"output"`` key is still accepted for other server formats.
        """
        choices = event.get("choices") or []
        if choices:
            delta = choices[0].get("delta") or {}
            return delta.get("content") or ""
        return event.get("output") or ""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Args:
            prompt: The prompt to generate from.
            stop: Not supported; must be None.
            run_manager: Callback manager for the run (unused here).
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            The assistant message content from the first returned choice.

        Raises:
            ValueError: If ``stop`` is given.
            Exception: If the server responds with a non-200 status code.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        response = self._post(self._build_payload(prompt))

        if response.status_code != 200:
            raise Exception(f"Request failed with status code {response.status_code}")

        result = response.json()
        return result['choices'][0]['message']['content']

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the LLM's reply to the given prompt.

        The server streams server-sent events: ``data: {json}`` lines separated
        by blank lines and terminated by ``data: [DONE]``.

        Args:
            prompt: The prompt to generate from.
            stop: Not supported; must be None.
            run_manager: Callback manager notified of every non-empty token.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Yields:
            GenerationChunks with the incremental text.

        Raises:
            ValueError: If ``stop`` is given.
            Exception: If the server responds with a non-200 status code.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        payload = self._build_payload(prompt, stream=True)

        # The context manager releases the connection even if the consumer
        # stops iterating early.
        with self._post(payload, stream=True) as response:
            if response.status_code != 200:
                raise Exception(f"Request failed with status code {response.status_code}")

            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank separator between events
                line = raw_line.decode('utf-8').strip()
                if line.startswith(":"):
                    continue  # SSE comment / keep-alive line
                if line.startswith("data:"):
                    line = line[len("data:"):].strip()
                if line == "[DONE]":
                    break
                text = self._extract_stream_text(json.loads(line))
                if not text:
                    continue
                chunk = GenerationChunk(text=text)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                yield chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {"url": self.url, "model_name": self.model_name, "system_message": self.system_message}

    @property
    def _llm_type(self) -> str:
        """Get the type of this LLM. Used for logging purposes only."""
        return "instructlab_llm"


## Import Libraries

In [18]:
from langchain_core.prompts import PromptTemplate
from langchain.chains.llm import LLMChain

In [7]:
# Endpoint and model settings for the two `ilab serve` instances (base and
# trained); edit these to match your own server URLs and model files.
llm_config = dict(
    base_llm_config=dict(
        url="http://127.0.0.1:8000/v1/chat/completions",
        model_name="models/granite-7b-lab-Q4_K_M.gguf",
        system_message="You are a helpful assistant.",
    ),
    trained_llm_config=dict(
        url="http://127.0.0.1:1234/v1/chat/completions",
        model_name="instructlab-merlinite-7b-lab-trained/instructlab-merlinite-7b-lab-Q4_K_M.gguf",
        system_message="You are a helpful assistant.",
    ),
)

## Initialize the custom InstructLab LLMs

In [10]:
# The base config's keys (url, model_name, system_message) match the
# constructor's field names, so they can be unpacked directly.
base_llm = InstructLabLLM(**llm_config["base_llm_config"])

In [11]:
# Same construction as the base model, pointed at the trained model's server.
trained_llm = InstructLabLLM(**llm_config["trained_llm_config"])

## Create a Prompt Template

In [14]:
# Prompt sent to both models; {snippet} is the only placeholder and is filled
# from the chain's input.
template = """<|INSTRUCTION|>
Act as a coding assistant. You are good at generating test cases for a given code snippet. Analyse the given code snippet and generate test cases for it.

<|USER|>
Code Snippet:
{snippet}

<|ASSISTANT|>
Test Cases: """

prompt = PromptTemplate(input_variables=["snippet"], template=template)

## Create the LangChain chains

In [21]:
# Pair the shared test-case prompt with the base model in an LLMChain.
base_llm_chain = LLMChain(llm=base_llm, prompt=prompt)

In [22]:
trained_llm_chain = LLMChain(llm=trained_llm, prompt=prompt)  # same prompt, trained model

## Query

In [23]:
# Sample function the models are asked to write test cases for; it keeps a
# leading and a trailing newline.
snippet = "\n".join(["", "def mult(x, y):", "    return x * y", ""])

In [24]:
base_response = base_llm_chain.invoke(input=snippet)  # result dict; text under "text"

In [33]:
trained_response = trained_llm_chain.invoke(input=snippet)  # result dict; text under "text"

In [31]:
print("\n".join(("Base model response: ", "----", base_response["text"], "----")))

Base model response: 
----
To thoroughly test the `mult` function, consider creating test cases that cover various scenarios, including positive and negative integers, zero, and floating-point numbers. Here are some test cases for your code snippet:

1. Positive Integer Multiplication:
   - Test with an integer value of 1: `mult(1, 2)` should return `3`.
   - Test with an integer value of 2: `mult(2, 3)` should return `9`.
   - Test with an integer value of 3: `mult(3, 4)` should return `12`.

2. Negative Integer Multiplication:
   - Test with a negative integer value of 1: `mult(-1, 2)` should return `-3`.
   - Test with a negative integer value of 2: `mult(-2, 3)` should return `-7`.
   - Test with a negative integer value of 3: `mult(-3, 4)` should return `-10`.

3. Zero Multiplication:
   - Test with zero: `mult(0, 2)` should return `0`.
   - Test with zero: `mult(0, 3)` should return `0`.

4. Floating-Point Number Multiplication:
   - Test with a floating-point number value of 1.1

In [35]:
print("\n".join(("InstructLab trained model response: ", "----", trained_response["text"], "----")))

InstructLab trained model response: 
----
import unittest

class TestMult(unittest.TestCase):

    def test_mult_positive_numbers(self):
        self.assertEqual(mult(2, 3), 6)

    def test_mult_negative_numbers(self):
        self.assertEqual(mult(-2, -3), 6)

    def test_mult_mixed_numbers(self):
        self.assertEqual(mult(2, -3), -6)

if __name__ == "__main__":
    unittest.main()
----
