In [None]:
import gradio as gr
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationChain
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent
from langchain.prompts import StringPromptTemplate
from langchain.schema import AgentAction, AgentFinish
from langchain.chains import LLMChain
from langchain.agents import AgentOutputParser 
from typing import List, Union
import requests
import json
import logging
import re
from rouge import Rouge
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Set up logging: DEBUG and above, each record prefixed with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)



## Get the API key

In [None]:
# Function to register and get API key
def register_user(base_url: str = "http://127.0.0.1:8899", timeout: float = 10.0) -> str:
    """Register a new user with the local API server and return its API key.

    Args:
        base_url: Server root; "/v1/register" is appended to it.
        timeout: Seconds to wait for the server. ``requests`` has no default
            timeout, so without one an unresponsive server blocks forever.

    Returns:
        The ``api_key`` value from the server's JSON response.

    Raises:
        requests.exceptions.RequestException: connection error or non-2xx status.
        KeyError / TypeError: the response JSON has no ``api_key`` entry.
    """
    register_url = f"{base_url.rstrip('/')}/v1/register"
    try:
        response = requests.post(register_url, timeout=timeout)
        response.raise_for_status()
        api_key = response.json()["api_key"]
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to register user: {str(e)}")
        raise
    except (KeyError, TypeError) as e:
        # The server answered, but not with the expected {"api_key": ...} payload.
        logger.error(f"Unexpected registration response format: {str(e)}")
        raise
    logger.info("Successfully registered and received API key")
    return api_key

# Get or create API key: reuse the key saved by a previous run if there is one,
# otherwise register a new user and persist the key for next time.
API_KEY_FILE = "api_key.txt"

try:
    with open(API_KEY_FILE, "r", encoding="utf-8") as f:
        api_key = f.read().strip()
except FileNotFoundError:
    api_key = ""

if api_key:
    logger.info("Loaded existing API key")
else:
    # Missing or blank file: register afresh rather than sending an empty X-API-Key.
    api_key = register_user()
    with open(API_KEY_FILE, "w", encoding="utf-8") as f:
        f.write(api_key)
    logger.info("Registered new user and saved API key")

## Custom LLM
- Same as in Project 1
- The only difference is the API key (sent in the `X-API-Key` header)

In [None]:
class CustomLLM(LLM):
    """LangChain LLM backed by the local completions endpoint (same as Project 1).

    Each call POSTs the prompt, followed by a newline and "Answer:", to
    ``api_url`` and authenticates with the ``X-API-Key`` header. Transport and
    response-format errors are logged and returned as an apology string
    instead of being raised.
    """

    api_url: str = "http://127.0.0.1:8899/v1/completions"
    api_key: Union[str, None] = None
    # Request settings sent with every call; defaults are the original values.
    max_tokens: int = 500
    temperature: float = 0.7
    top_p: float = 1.0
    # Seconds to wait for the server; requests has no timeout by default.
    timeout: float = 60.0

    def __init__(self, api_key: str, **kwargs):
        # Extra keyword arguments (e.g. temperature=0.2) override the field defaults.
        super().__init__(api_key=api_key, **kwargs)

    def _call(
        self,
        prompt: str,
        stop: Union[List[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs,
    ) -> str:
        """Send `prompt` to the completions API and return the stripped text.

        Args:
            prompt: Prompt text; "\\nAnswer:" is appended before sending.
            stop: Stop sequences; ["Human:", "\\n\\n"] is used when None or empty.
            run_manager: LangChain callback manager (unused).
            **kwargs: Extra invocation parameters LangChain may forward to
                ``_call``; accepted so such calls don't raise TypeError, but ignored.

        Returns:
            The first choice's text, stripped, or an apology string on error.
        """
        headers = {
            "Content-Type": "application/json",
            "X-API-Key": self.api_key,
        }
        data = {
            "prompt": prompt + "\nAnswer:",
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": 1,
            "stop": stop or ["Human:", "\n\n"],
        }
        try:
            logger.info(f"Sending prompt to API: {prompt}")
            response = requests.post(self.api_url, headers=headers, json=data, timeout=self.timeout)
            response.raise_for_status()
            result = response.json()['choices'][0]['text']
            logger.info(f"Received response from API: {result}")
            return result.strip()
        except requests.exceptions.RequestException as e:
            logger.error(f"API request failed: {str(e)}")
            return f"Sorry, I encountered an error: {str(e)}"
        except (KeyError, IndexError, TypeError, ValueError) as e:
            # The server replied, but not with {"choices": [{"text": ...}, ...]}.
            logger.error(f"Unexpected API response format: {str(e)}")
            return f"Sorry, I received an unexpected response format: {str(e)}"

    @property
    def _llm_type(self) -> str:
        return "custom"

# Initialize the custom LLM
llm = CustomLLM(api_key=api_key)

logger.info("Custom LLM initialized")


## FOL translator function
- Uses the fine-tuned GPT-2 model (loaded from `./gpt-2-nl-to-fol-merged`)

In [None]:
# Load the fine-tuned GPT-2 model and tokenizer used for NL -> FOL translation.
model_path = "./gpt-2-nl-to-fol-merged"
logger.info(f"Loading GPT-2 model from path: {model_path}")
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
# Inference only: make sure dropout is disabled.
model.eval()
logger.info("GPT-2 model and tokenizer loaded successfully")

# GPT-2 has no dedicated pad token; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Newer transformers versions read pad_token_id from generation_config in generate().
if getattr(model, "generation_config", None) is not None:
    model.generation_config.pad_token_id = model.config.eos_token_id
logger.info("Padding token set for GPT-2 model")

# GPT-2 Tool
def gpt2_fol_translation(question: str, max_new_tokens: int = 128) -> str:
    """Translate a natural-language question to First Order Logic with the fine-tuned GPT-2.

    Builds an instruction prompt ending in "FOL Query:", decodes greedily and
    returns the text after the last "FOL Query:" in the decoded output.

    Args:
        question: Natural-language question to translate.
        max_new_tokens: Cap on generated tokens. It counts only new tokens, so
            long questions don't eat into the budget (``max_length`` would
            include the prompt).

    Returns:
        The generated FOL statement, stripped.
    """
    logger.info(f"Translating to FOL using GPT-2: {question}")
    SYSTEM_PROMPT = "Translate the following natural language question to First Order Logic (FOL). Please respond with only the FOL statement. Don't include additional text.\nQuestion: "
    full_input = SYSTEM_PROMPT + question + "\nFOL Query:"

    encoded = tokenizer(full_input, return_tensors="pt")
    device = model.device
    input_ids = encoded["input_ids"].to(device)
    # pad == eos, so generate() cannot infer the mask itself; pass it explicitly.
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        # Greedy decoding: a temperature without do_sample=True was ignored anyway
        # (and made transformers warn); deterministic output suits translation.
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=False,
        )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    fol_output = generated_text.split("FOL Query:")[-1].strip()
    logger.info(f"GPT-2 FOL translation: {fol_output}")
    return fol_output


## Judge

In [None]:
# Judge Tool

def call_judge(question: str, timeout: float = 60.0) -> str:
    """Ask the judge API (``/v1/judge``) to translate `question` into FOL.

    Authenticates with the module-level ``api_key``. Errors are logged and
    returned as an "Error: ..." string rather than raised, so callers (e.g. an
    agent tool) always get text back.

    Args:
        question: Natural-language text, sent as the prompt.
        timeout: Seconds to wait for the judge server.

    Returns:
        The judge's completion text, stripped, or an "Error: ..." message.
    """
    logger.info(f"Translating to FOL using Judge: {question}")
    api_url = "http://127.0.0.1:8899/v1/judge"
    headers = {
        "Content-Type": "application/json",
        "X-API-Key": api_key,
    }
    data = {
        "prompt": str(question),
        "max_tokens": 512,
        "temperature": 0.7,
    }
    try:
        response = requests.post(api_url, headers=headers, json=data, timeout=timeout)
        response.raise_for_status()
        result = response.json()["choices"][0]["text"].strip()
    except requests.exceptions.RequestException as e:
        logger.error(f"Judge API request failed: {str(e)}")
        return "Error: Unable to get response from judge API"
    except (KeyError, IndexError, TypeError, ValueError, AttributeError) as e:
        # The server replied, but not with {"choices": [{"text": ...}, ...]}.
        logger.error(f"Unexpected judge API response format: {str(e)}")
        return "Error: Unexpected response format from judge API"
    logger.info(f"Judge FOL translation: {result}")
    return result

## Define the tool list according to your requirements

In [None]:
# Tools the agent may call. Each Tool wraps one of the translator functions defined
# above and receives the agent's "Action Input" text as its single argument.
tools = [
    Tool(
        name="GPT2_FOL_Translator",
        func=gpt2_fol_translation,
        description=(
            "Translates a natural language question into a First Order Logic (FOL) "
            "statement using the fine-tuned GPT-2 model. Input: the question text."
        ),
    ),
    Tool(
        name="Judge",
        func=call_judge,
        description=(
            "Sends the text to the judge API and returns its First Order Logic (FOL) "
            "translation. Input: the question text."
        ),
    ),
]
logger.info("Tools defined")


## Update the `CustomPromptTemplate` class and `prompt_template` to use the proper tools

In [None]:
# Set up the prompt template
class CustomPromptTemplate(StringPromptTemplate):
    """String prompt for the single-action agent.

    Fields:
        template: Text rendered with ``str.format``. It may reference
            ``{input}``, ``{agent_scratchpad}``, ``{tools}`` and ``{tool_names}``;
            keys the template does not use are ignored by ``str.format``.
        tools: Tools the agent may call, used to render ``{tools}`` and
            ``{tool_names}``.
    """

    template: str
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """Render the template, turning ``intermediate_steps`` into a scratchpad.

        ``intermediate_steps`` holds ``(AgentAction, observation)`` pairs from
        earlier iterations. Each pair is rendered as the action's log followed by
        its observation, so the LLM sees the result of the tool it called instead
        of starting from scratch on every iteration.
        """
        intermediate_steps = kwargs.pop("intermediate_steps", [])
        kwargs["agent_scratchpad"] = "".join(
            f"{action.log}\nObservation: {observation}\nThought: "
            for action, observation in intermediate_steps
        )
        kwargs["tools"] = "\n".join(f"{tool.name}: {tool.description}" for tool in self.tools)
        kwargs["tool_names"] = ", ".join(tool.name for tool in self.tools)
        return self.template.format(**kwargs)

logger.info("CustomPromptTemplate set up")

prompt_template = """
Your instruction prompt
"""

logger.info("Prompt template defined")



## Custom output parser

In [None]:
# Output parser
_END_OF_RESPONSE = "[END_OF_RESPONSE]"
_FINAL_ANSWER = "Final Answer:"
# "Action: <tool>" on one line, "Action Input: <input>" on the next. No re.DOTALL:
# with it, a stray "Action: " earlier in the text let the lazy tool-name group
# run across lines up to the first "\nAction Input:", producing a garbage tool name.
_ACTION_PATTERN = re.compile(r"Action: (.*?)\nAction Input: (.*?)(?=\n|$)")


def _agent_finish(output: str, llm_output: str) -> AgentFinish:
    """Wrap a final user-facing reply, keeping the raw LLM text as the log."""
    return AgentFinish(return_values={"output": output}, log=llm_output)


class CustomOutputParser(AgentOutputParser):
    """Parse raw LLM text into an AgentAction (tool call) or an AgentFinish (reply).

    Checks, in order:
      1. "[END_OF_RESPONSE]" present: finish with the text before its first
         occurrence (or, if that is blank, the text between the first and second
         occurrence), reduced to what follows "Final Answer:" when present.
      2. "Final Answer:" present: finish with the text after its last occurrence.
      3. An "Action:" line followed by an "Action Input:" line: return that tool call.
      4. Otherwise finish with a generic apology instead of raising, so the chat
         always gets a reply.
    """

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        logger.debug(f"+++++ Parsing LLM output: {llm_output}")

        if _END_OF_RESPONSE in llm_output:
            segments = llm_output.split(_END_OF_RESPONSE)
            response = segments[0].strip() or segments[1].strip()
            if _FINAL_ANSWER in response:
                response = response.split(_FINAL_ANSWER)[-1].strip()
            return _agent_finish(response, llm_output)

        # The end marker cannot appear past this point, so the answer needs no
        # further trimming.
        if _FINAL_ANSWER in llm_output:
            return _agent_finish(llm_output.split(_FINAL_ANSWER)[-1].strip(), llm_output)

        match = _ACTION_PATTERN.search(llm_output)
        if match is None:
            return _agent_finish(
                "I apologize, but I encountered an error. Please try again or rephrase your message.",
                llm_output,
            )

        action = match.group(1).strip()
        action_input = match.group(2).strip()
        logger.debug(f"------ Parsed action: {action}, action input: {action_input}")
        # Surrounding double quotes are removed from the tool input.
        return AgentAction(tool=action, tool_input=action_input.strip('"'), log=llm_output)


logger.info("CustomOutputParser defined")



## Initialize prompt template and output parser

In [None]:
# Assemble the agent's building blocks: prompt, output parser and LLM chain.
prompt = CustomPromptTemplate(
    input_variables=["input", "intermediate_steps"],
    template=prompt_template,
    tools=tools,
)

output_parser = CustomOutputParser()
llm_chain = LLMChain(prompt=prompt, llm=llm)
# Names the agent is allowed to emit after "Action:".
tool_names = [t.name for t in tools]



## Create an agent that can use tools

In [None]:
# Build a single-action agent. Generation stops at "\nObservation:" so the tool's
# real result can be supplied as the observation.
agent = LLMSingleActionAgent(
    allowed_tools=tool_names,
    llm_chain=llm_chain,
    output_parser=output_parser,
    stop=["\nObservation:"],
)
logger.info("Agent initialized")

# Executor that runs the agent/tool loop, capped at 10 iterations per input.
agent_executor = AgentExecutor.from_agent_and_tools(
    agent=agent,
    tools=tools,
    max_iterations=10,
    verbose=True,
)
logger.info("Agent executor set up")

# Conversation memory that returns message objects rather than a single string.
# NOTE(review): not passed to agent_executor here -- confirm where it is meant to be used.
memory = ConversationBufferMemory(return_messages=True)
logger.info("Conversation memory initialized")



## Define your chat function

In [None]:
# Gradio chat interface backend
def chat(message, history):
    """Return the assistant's reply to `message`.

    `history` holds the earlier chat turns and is currently unused.
    TODO: route `message` through the agent; for now it is echoed back unchanged.
    """
    return message



## Modify the Gradio interface as required
- The version below adds a `Verify` button, which, when pressed, sends a `Verify` message in the chat.
- You can modify it as needed.

In [None]:
# Custom CSS for full height: the chat history fills the viewport (minus room for
# the input row), and the input row is pinned to the bottom of the page.
custom_css = """
#chatbot-container {
    height: calc(100vh - 230px) !important;
    overflow-y: auto;
}
#input-container {
    position: fixed;
    bottom: 0;
    left: 0;
    right: 0;
    padding: 20px;
    background-color: white;
    border-top: 1px solid #ccc;
}
"""

logger.info("Custom CSS defined")

# Create the Gradio interface
with gr.Blocks(css=custom_css) as gr_iface:
    with gr.Column():
        chatbot = gr.Chatbot(elem_id="chatbot-container")
        with gr.Row(elem_id="input-container"):
            msg = gr.Textbox(show_label=False, placeholder="Type your message here...")
            send = gr.Button("Send")
            with gr.Row():
                verify = gr.Button("Verify")
                clear = gr.Button("Clear")

    def user(user_message, history):
        """Clear the textbox and append `user_message` as a new, unanswered turn.

        Blank or whitespace-only messages are not added to the history.
        """
        logger.debug(f"User message: {user_message}")
        history = history or []
        if not user_message or not user_message.strip():
            return "", history
        return "", history + [[user_message, None]]

    def bot(history):
        """Fill in the reply for the last turn if it is still unanswered.

        `bot` always runs after `user` via `.then(...)`, even when `user` added
        nothing. The guard stops it from failing on an empty history or
        re-answering the previous turn.
        """
        if not history or history[-1][1] is not None:
            return history
        user_message = history[-1][0]
        logger.debug(f"Processing bot response for: {user_message}")
        bot_message = chat(user_message, history[:-1])
        history[-1][1] = bot_message
        return history

    def clear_chat():
        """Reset the chatbot component to an empty conversation."""
        return None

    def verify_click(history):
        """Append a "Verify" turn, as if the user had typed it."""
        return user("Verify", history)[1]

    # Two-step events: `user` runs unqueued (queue=False) to update the textbox and
    # chat immediately, then `bot` computes the reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(clear_chat, None, chatbot, queue=False)
    verify.click(verify_click, chatbot, chatbot).then(
        bot, chatbot, chatbot
    )

logger.info("Gradio interface created")

# Launch the Gradio interface
logger.info("Launching Gradio interface")
gr_iface.launch()