In [None]:
import torch
from typing import Optional, List, Any
from pydantic import PrivateAttr


from langchain.llms.base import LLM
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool


from unsloth import FastLanguageModel

from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver




In [None]:

class UnslothLLM(LLM):
    """LangChain ``LLM`` wrapper around a model loaded with ``FastLanguageModel``.

    The model and tokenizer are loaded eagerly in ``__init__`` and stored in
    pydantic private attributes, so they are not treated as model fields.
    """

    # Identifier passed to ``FastLanguageModel.from_pretrained``.
    model_name: str
    # Tools registered via ``bind_tools``; only stored, never read by this class.
    tools: Optional[List[Any]] = None
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: str = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        max_seq_length: int,
        dtype,
        load_in_4bit: bool,
        device: str
    ):
        """Load the model/tokenizer pair and switch the model to inference mode.

        Args:
            model_name: Forwarded to ``FastLanguageModel.from_pretrained``.
            max_seq_length: Forwarded to ``from_pretrained``.
            dtype: Forwarded to ``from_pretrained`` unchanged.
            load_in_4bit: Forwarded to ``from_pretrained``.
            device: Device the model is moved to and on which inputs are placed.
        """
        # ``super().__init__`` already populates the ``model_name`` field.
        super().__init__(model_name=model_name)

        print(f"[INIT] Loading model: {self.model_name}")
        self._model, self._tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit
        )
        FastLanguageModel.for_inference(self._model)
        # NOTE(review): some transformers releases reject ``.to()`` on
        # bitsandbytes-quantised models - confirm against the pinned version.
        self._model.to(device)
        self._device = device

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """
        Generates a response from the model, encouraging an instruction style
        to reduce echoing. Uses basic generation params to reduce repetition.

        Args:
            prompt: User text; it is wrapped in an Instruction/Response template.
            stop: Optional stop strings. The response is cut at the earliest
                occurrence of any of them; empty strings are ignored.
            run_manager: Accepted for compatibility with LangChain's ``LLM``
                base class; not used.
            **kwargs: Accepted for the same reason and ignored.

        Returns:
            Only the newly generated text. The templated prompt is not part of
            the returned string.
        """
        print(f"[CALL] Model '{self.model_name}' => prompt: {prompt!r}")

        # System-level directive to reduce echoing
        system_msg = "You are a helpful assistant. Avoid echoing user queries.\n\n"

        instruction_prompt = (
            f"{system_msg}"
            f"### Instruction:\n{prompt}\n\n"
            f"### Response:\n"
        )

        inputs = self._tokenizer(instruction_prompt, return_tensors="pt").to(self._device)
        outputs = self._model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3
        )

        # ``generate`` returns the prompt tokens followed by the new tokens for
        # a causal LM. Decoding the whole row would hand the caller the system
        # message and instruction template back, so drop the prompt prefix.
        prompt_length = inputs["input_ids"].shape[1]
        completion_ids = outputs[0][prompt_length:]
        response = self._tokenizer.decode(completion_ids, skip_special_tokens=True)

        # Cut at whichever stop string appears first in the text, not at the
        # one that happens to be listed first.
        if stop:
            cut_points = [
                idx
                for idx in (response.find(token) for token in stop if token)
                if idx != -1
            ]
            if cut_points:
                response = response[:min(cut_points)]

        return response

    @property
    def _identifying_params(self) -> dict:
        """Parameters that identify this LLM instance."""
        return {"model_name": self.model_name}

    @property
    def _llm_type(self) -> str:
        """Short type tag for this wrapper."""
        return "unsloth"

    def bind_tools(self, tools: List[Any]) -> "UnslothLLM":
        """Store ``tools`` on this instance and return the same instance.

        The instance is mutated in place; no copy is made.
        """
        self.tools = tools
        return self


class ChatUnslothLLM(UnslothLLM):
    """``UnslothLLM`` variant whose ``invoke`` answers with an ``AIMessage``."""

    def invoke(self, input: Any, config: Optional[dict] = None, **kwargs) -> AIMessage:
        """Generate a reply for the most recent message in ``input``.

        Accepted inputs:
            * a dict with a non-empty ``"messages"`` list (last entry is used),
            * a non-empty list/tuple of messages (last entry is used),
            * an object exposing ``to_messages()`` (e.g. a prompt value),
            * a single message object with a ``content`` attribute,
            * anything else, which is converted with ``str``.

        Individual messages may be message objects, plain strings,
        ``{"content": ...}`` dicts or ``(role, content)`` pairs.

        Args:
            input: The prompt or conversation, in one of the shapes above.
            config: Accepted for interface compatibility; not used.
            **kwargs: Only ``stop`` is forwarded to ``_call``. Other keywords
                are ignored because ``_call`` would reject them.

        Returns:
            An ``AIMessage`` named ``"ai"`` carrying the generated text.
        """

        def text_of(message: Any) -> Any:
            # Reduce one message-like value to the text handed to the model.
            if isinstance(message, str):
                return message
            if isinstance(message, dict) and "content" in message:
                return message["content"]
            if isinstance(message, (tuple, list)) and len(message) == 2:
                return message[1]  # (role, content) pair
            if hasattr(message, "content"):
                return message.content
            return str(message)

        payload = input.to_messages() if hasattr(input, "to_messages") else input

        if isinstance(payload, dict) and "messages" in payload and payload["messages"]:
            prompt = text_of(payload["messages"][-1])
        elif isinstance(payload, (list, tuple)) and payload:
            prompt = text_of(payload[-1])
        elif hasattr(payload, "content"):
            prompt = payload.content
        else:
            prompt = str(payload)

        # Forward only what ``_call`` is known to accept; passing arbitrary
        # keywords through would raise ``TypeError``.
        response_text = self._call(prompt, stop=kwargs.get("stop"))
        return AIMessage(content=response_text, name="ai")




In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Loader settings shared by all three models below.
max_seq_length = 2048
dtype = None
load_in_4bit = True

# Collected once so each constructor call receives exactly the same options.
_shared_load_options = dict(
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    device=device,
)

# "Hub" / "controller" model: classifies queries and writes the final answer.
base_model_name = "unsloth/llama-3-8b-bnb-4bit"
hub_llm = ChatUnslothLLM(model_name=base_model_name, **_shared_load_options)

# LoRA-based psychology specialist.
psychology_model_name = "lora_model_osloth_psychology"
psychology_hub_llm = ChatUnslothLLM(
    model_name=psychology_model_name, **_shared_load_options
)

# LoRA-based logic / commonsense specialist.
logical_model_name = "lora_model_osloth_commonsense_qa"
logical_hub_llm = ChatUnslothLLM(
    model_name=logical_model_name, **_shared_load_options
)




In [None]:

@tool
def psychological_understanding(query: str) -> str:
    """
    Use the LoRA psychology model to provide psychological insights
    about the supplied query.
    """
    # The docstring above is left verbatim: ``@tool`` uses it as the tool
    # description, so its wording is part of the runtime behaviour.
    print("[TOOL] psychological_understanding was called with:", query)
    # Delegate to the psychology model and hand back its raw text.
    insight = psychology_hub_llm._call(query)
    return insight

@tool
def logical_reasoning(query: str) -> str:
    """
    Use the LoRA commonsense QA model to offer logical reasoning
    about the supplied query.
    """
    # The docstring above is left verbatim: ``@tool`` uses it as the tool
    # description, so its wording is part of the runtime behaviour.
    print("[TOOL] logical_reasoning was called with:", query)
    # Delegate to the commonsense-QA model and hand back its raw text.
    reasoning = logical_hub_llm._call(query)
    return reasoning

def _strip_prompt_echo(raw: str, prompt: str) -> str:
    """Return ``raw`` without an echoed Instruction/Response wrapper around ``prompt``.

    The model wrapper's ``_call`` embeds the prompt in an
    ``### Instruction: ... ### Response:`` template and may return that
    template together with the completion. When the wrapped prompt is found,
    only the text after it is returned; otherwise ``raw`` is returned unchanged.
    """
    wrapped = f"### Instruction:\n{prompt}\n\n### Response:\n"
    _, found, tail = raw.partition(wrapped)
    return tail if found else raw


def _parse_classification(answer: str) -> tuple:
    """Translate the classifier's reply into ``(use_psychology, use_logic)``.

    The first non-empty line is inspected first because the model is asked for
    a single word and anything after it is usually rambling. If that line
    holds no known label, the whole answer is searched instead.
    """
    lowered = answer.lower()
    first_line = next((line for line in lowered.splitlines() if line.strip()), "")

    hits = []
    for scope in (first_line, lowered):
        hits = [label for label in ("psychology", "logic", "both", "none") if label in scope]
        if hits:
            break

    # 'both' contains neither 'psychology' nor 'logic', so it needs its own check.
    use_both = "both" in hits
    return (use_both or "psychology" in hits, use_both or "logic" in hits)


def classify_and_toolflow(user_query: str) -> str:
    """
    1) Ask the base model to classify the query as 'psychology', 'logic', 'both', or 'none'.
    2) If 'psychology', call psychological_understanding.
       If 'logic', call logical_reasoning.
       If 'both', call both in sequence.
    3) Combine the results with the original query, then ask the base model for a final answer.

    Args:
        user_query: The user's question, passed verbatim to the classifier,
            the selected tool(s) and the final prompt.

    Returns:
        The base model's final answer, with any echoed prompt template removed.
    """

    # 1) Classification prompt for the base model
    classification_prompt = (
        "The user query is: "
        f"\"{user_query}\"\n\n"
        "Decide if this is about psychology, logic, both, or none.\n"
        "Reply with exactly one word: 'psychology', 'logic', 'both', or 'none'."
    )

    # The prompt itself names every label, so an echoed prompt must be removed
    # before looking for the label the model actually chose.
    classification = _strip_prompt_echo(
        hub_llm._call(classification_prompt), classification_prompt
    ).strip().lower()
    print("[DEBUG] Classification result:", classification)
    use_psychology, use_logic = _parse_classification(classification)

    # 2) Based on classification, call the specialized tool(s)
    # Tools are invoked through ``.invoke``; calling a tool object directly is
    # deprecated in LangChain.
    tool_responses = []
    if use_psychology:
        tool_output = _strip_prompt_echo(
            str(psychological_understanding.invoke({"query": user_query})), user_query
        )
        tool_responses.append(f"Psychology Tool Output:\n{tool_output}")
    if use_logic:
        tool_output = _strip_prompt_echo(
            str(logical_reasoning.invoke({"query": user_query})), user_query
        )
        tool_responses.append(f"Logic Tool Output:\n{tool_output}")

    # 3) Combine the user query + tool outputs into a final prompt for the base LLM
    if not tool_responses:
        # If 'none', no tool used, just respond directly
        final_prompt = (
            f"User asked: \"{user_query}\"\n"
            "No specialized tools were used because classification was 'none'.\n\n"
            "Now please give a final answer to the user."
        )
    else:
        combined_tool_text = "\n\n".join(tool_responses)
        final_prompt = (
            f"User asked: \"{user_query}\"\n\n"
            f"You used the following tool(s) output:\n{combined_tool_text}\n\n"
            "Now, combine these results into a helpful final answer."
        )

    # 4) The base model returns a final, polished answer
    print("[Final Answer]: ")
    final_answer = _strip_prompt_echo(hub_llm._call(final_prompt), final_prompt)
    return final_answer




In [None]:
if __name__ == "__main__":
    # Demo query pushed through the manual classify -> tool(s) -> answer flow
    # (used here instead of an agent's ``invoke``).
    user_query = "I have a psychological issue. Give me advice."
    print("[USER QUERY]", user_query)

    answer = classify_and_toolflow(user_query)

    # Banner and answer on consecutive lines, exactly as two prints would emit.
    print("\n===== Final Answer =====", answer, sep="\n")