In [None]:
# Install llama-cpp-python (the Python bindings for llama.cpp) here.

# This notebook must be running as a job (the default, so you are probably fine).
# Run this cell only once; comment the lines below out again after installation.

# %env CMAKE_ARGS=-DLLAMA_CUBLAS=on
# %env FORCE_CMAKE=1
# %pip install llama-cpp-python --force-reinstall --upgrade --no-cache-dir --no-clean

In [None]:
import os

from llama_cpp import Llama

In [None]:
# Path to the quantised Llama-2 7B chat weights. Set the LLAMA_MODEL_PATH
# environment variable to use a different copy; the cluster path is the default.
path = os.environ.get('LLAMA_MODEL_PATH', '/data/ai_club/llms/llama-2-7b-chat.Q5_K_M.gguf')
# Model used to answer FAQ questions (see answer_faq).
# NOTE(review): n_ctx=512 may be too small once the context file, the cached
# acknowledgement and the system prompt are prepended to the question — confirm
# against the sizes of the files in contexts/.
llm = Llama(path, n_ctx=512, n_gpu_layers=-1, verbose=False)

In [None]:
# Model used for the Yes/No classification prompts (question_llm) and for the
# cached context acknowledgements; it gets a larger context window than `llm`.
# NOTE(review): this loads the same weights a second time — consider whether a
# single instance with the larger window could serve both roles.
bool_llm = Llama(
    model_path=path,
    n_ctx=4000,
    n_gpu_layers=-1,
    verbose=False,
)

In [None]:
#response_llm = Llama(path, n_gpu_layers=-1, verbose=False)

In [None]:
def load_context(file_path):
    """Read a context file and wrap its text in a Markdown code fence.

    Args:
        file_path: Path of the text file to read (decoded as UTF-8).

    Returns:
        The file contents surrounded by ``` fences. This function never raises:
        if the file is missing or cannot be read, a human-readable error string
        is returned in place of the fenced contents.
    """
    # NOTE(review): because errors are returned as text, a missing file ends up
    # inside the LLM prompt (and is cached by ContextFactory). Consider raising
    # instead once callers can handle it.
    try:
        # Explicit encoding: the platform default encoding is not guaranteed to be UTF-8.
        with open(file_path, 'r', encoding='utf-8') as file:
            file_contents = file.read()
    except FileNotFoundError:
        return f'File not found: {file_path}'
    except Exception as e:
        return f'An error occurred: {e}'
    return '```\n' + file_contents + '\n```'

In [None]:
class ContextFactory:
    """Class-level cache of context files, keyed by their base name.

    A name is resolved to ``contexts/<name>.txt`` (relative to the current
    working directory) and read through ``load_context`` the first time it is
    requested; later requests are served from the cache.
    """
    # Shared by every caller: base name -> text returned by load_context.
    _contexts = {}

    @staticmethod
    def get_context(file_name):
        """Return the cached context for *file_name*, loading it on first use."""
        cache = ContextFactory._contexts
        if cache.get(file_name) is None:
            print('loading from file')
            cache[file_name] = load_context(f'contexts/{file_name}.txt')
        return cache.get(file_name)

In [None]:
class LLMResponseToContextFactory:
    """Class-level cache of the model's acknowledgement of each context.

    The first request for a context name renders that context with
    ``llama_v2_context_prompt`` and asks ``bool_llm`` to respond, retrying
    until the completion text is non-empty; later requests reuse the cached reply.
    """
    # Shared by every caller: context name -> acknowledgement text.
    _response_contexts = {}

    @staticmethod
    def get_response(context_name):
        """Return the cached acknowledgement for *context_name*, generating it on first use."""
        cache = LLMResponseToContextFactory._response_contexts
        if cache.get(context_name) is None:
            print('generating response')
            ack_prompt = llama_v2_context_prompt(
                context_builder(context_name, output='string')
            )
            # NOTE(review): retries without limit if the model keeps returning ''.
            reply = ''
            while reply == '':
                completion = bool_llm.create_completion(
                    ack_prompt, repeat_penalty=1.2, temperature=0.2
                )
                reply = completion['choices'][0]['text']
            cache[context_name] = reply
        return cache.get(context_name)

In [None]:
def prompt_model(llm_, prompt, max_attempts: 'int | None' = None,
                 temperature: float = 0.2, repeat_penalty: float = 1.2) -> str:
    """Run a completion on *llm_* and return the generated text, retrying empty replies.

    Args:
        llm_: Model object exposing ``create_completion`` (e.g. ``llm`` or ``bool_llm``).
        prompt: Fully rendered prompt string.
        max_attempts: Upper bound on the number of completion calls. None (the
            default) keeps retrying until a non-empty reply arrives; with a
            bound, '' is returned if every attempt came back empty.
        temperature: Sampling temperature passed to ``create_completion``.
        repeat_penalty: Repetition penalty passed to ``create_completion``.

    Returns:
        The text of the first choice of the first non-empty completion.
    """
    text = ''
    attempts = 0
    while text == '':
        # Optional escape hatch so a model that only ever returns '' cannot hang the caller.
        if max_attempts is not None and attempts >= max_attempts:
            break
        completion = llm_.create_completion(
            prompt, repeat_penalty=repeat_penalty, temperature=temperature
        )
        text = completion['choices'][0]['text']
        attempts += 1
    return text

In [None]:
def context_builder(context_name: str, output: str = 'dict') -> 'list[dict] | str':
    """Wrap a named context file in [context] tags, optionally as a chat history.

    Args:
        context_name: Base name of the context file (resolved by ContextFactory).
        output: ``'dict'`` (default) returns a two-message history: the tagged
            context as a user turn, followed by the model's cached
            acknowledgement as an assistant turn. ``'string'`` or ``'str'``
            returns only the tagged context text.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts, or a string,
        depending on *output*.

    Raises:
        ValueError: if *output* is not one of the supported values.
    """
    # Validate first so a typo fails loudly instead of returning None.
    if output not in ('dict', 'string', 'str'):
        raise ValueError(f"output must be 'dict', 'string' or 'str', got {output!r}")

    context = (
        '[context]\n'
        f'{ContextFactory.get_context(context_name)}\n'
        '[/context]\n'
    )
    if output != 'dict':
        return context

    # The acknowledgement is generated once per context name and then cached.
    response = LLMResponseToContextFactory.get_response(context_name)
    return [
        {'role': 'user', 'content': context},
        {'role': 'assistant', 'content': response},
    ]

In [None]:
def question_llm(context_name: str, dev_prompt: str, user_prompt: str, history=None) -> bool:
    """Ask the classifier model a Yes/No question about *user_prompt* and return it as a bool.

    Args:
        context_name: Context file (see context_builder) placed before the question.
        dev_prompt: Instruction placed inside the final [INST] tags; the suffix
            ". Answer with a Yes/No" is appended to it.
        user_prompt: The user's text being classified.
        history: Optional list of prior chat messages. The context messages are
            appended to it in place. Defaults to a fresh list on every call.

    Returns:
        True if the model's reply contains "yes" (case-insensitive).
    """
    # A fresh list per call: the former mutable default (history=[]) was shared
    # between calls, so later classifications saw an earlier call's context in
    # the context slot and the prompt grew without bound.
    if history is None:
        history = []
    history.extend(context_builder(context_name))
    dev_prompt = dev_prompt + '. Answer with a Yes/No'
    prompt = llama_v2_prompt(user_prompt, dev_prompt=dev_prompt, messages=history)
    print(prompt)
    result = prompt_model(bool_llm, prompt)
    print(result)
    print('\n')
    # NOTE(review): substring match — replies like "No, ... yes ..." or words
    # containing "yes" also count as True.
    return 'yes' in result.lower()

In [None]:
def is_faq(user_prompt: str) -> bool:
    """Return True if the classifier model judges *user_prompt* to be a generic ASD question.

    The 'ASD_general' context is supplied as background for the Yes/No question.
    """
    return question_llm(
        'ASD_general',
        'Keeping in mind the context, is the following question a generic question about ASD?',
        user_prompt,
    )

In [None]:
def is_screening(user_prompt: str) -> bool:
    """Return True if the classifier model judges *user_prompt* to be a request for ASD screening.

    The 'ASD_screen' context is supplied as background for the Yes/No question.
    """
    return question_llm(
        'ASD_screen',
        'Keeping in mind the context, is the following question seeking to getting a child screened for ASD?',
        user_prompt,
    )

In [None]:
def determine_request_type(prompt: str) -> 'str | None':
    """Classify *prompt* as a screening request or a general FAQ question.

    The screening check runs first; the FAQ classifier is only consulted when
    the prompt is not a screening request.

    Returns:
        'screen', 'faq', or None when neither classifier answers yes
        (llm_response treats None as out of scope).
    """
    if is_screening(prompt):
        return 'screen'
    if is_faq(prompt):
        return 'faq'
    return None

In [None]:
def bucket_faq(prompt: str, history=None):
    """Choose which FAQ context bucket should be used to answer *prompt*.

    Not complete yet: every prompt is currently bucketed as 'general'. The
    returned name is used by answer_faq to build the context name ``ASD_<bucket>``.

    Args:
        prompt: The user's question (currently unused).
        history: Reserved for future use; currently ignored.

    Returns:
        The FAQ bucket name; currently always 'general'.
    """
    # TODO: classify the prompt into one of the planned buckets
    # ('general', 'symptoms', 'screening_diagnosis', 'treatment')
    # instead of always returning 'general'.
    return 'general'

In [None]:
def answer_faq(prompt: str, history=None):
    """Answer a general ASD question with the FAQ context chosen by bucket_faq.

    Args:
        prompt: The user's question.
        history: Optional list of prior chat messages; the context messages are
            appended to it in place. Defaults to a fresh list on every call.

    Returns:
        The text generated by ``llm``.
    """
    # Fresh list per call: the former mutable default was shared across calls.
    if history is None:
        history = []
    context_name = f'ASD_{bucket_faq(prompt)}'
    # extend, not append: context_builder returns a *list* of message dicts and
    # llama_v2_prompt indexes each element as a dict (messages[0]["role"]).
    # Appending the list as a single element made that lookup raise TypeError.
    history.extend(context_builder(context_name))
    full_prompt = llama_v2_prompt(prompt, history, dev_prompt='Remember the context when answering questions')
    return prompt_model(llm, full_prompt)

In [None]:
def llm_response(prompt: str):
    """Route *prompt* to the matching handler and return the reply text.

    A screening request gets a fixed marker string, an FAQ question is answered
    by answer_faq, and anything else gets a fixed out-of-scope message.
    """
    kind = determine_request_type(prompt)
    if kind == 'faq':
        return answer_faq(prompt)
    if kind == 'screen':
        return 'begin screening process'
    return 'I cannot help with that as it is outside the bounds of my expertise'

In [None]:
def llama_v2_context_prompt(context: str, sys_prompt: str = ''):
    """Build the one-off prompt that shows the model a context and asks it to acknowledge it.

    Args:
        context: Text to present as context (LLMResponseToContextFactory passes
            ``context_builder(..., output='string')``).
        sys_prompt: System prompt text; an empty string or None falls back to
            the default helpful-assistant prompt.

    Returns:
        The prompt string: the <<SYS>> block, the context introduced by an
        [INST] instruction, and a request to respond if the context is understood.
    """
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

    # `not sys_prompt` also covers None, which previously crashed the concatenation below.
    if not sys_prompt:
        sys_prompt = DEFAULT_SYSTEM_PROMPT

    # NOTE(review): Meta's reference Llama-2 chat template starts with <s> and
    # nests <<SYS>> inside the first [INST] ... [/INST]; this layout differs.
    # Kept as-is to avoid changing model outputs — confirm before switching.
    return (
        B_SYS + sys_prompt + E_SYS
        + f'{B_INST}Keep in mind the following context{E_INST}\n'
        + f'{context}\n'
        + f'{B_INST}Remember the context when answering questions{E_INST}'
        + 'User: "respond if you understand"\n'
    )

In [None]:
def llama_v2_prompt(prompt: str, 
                    messages: list[dict],
                    dev_prompt: str = 'Remember the context when answering questions'
                   ):
    """Render a chat history plus a new user question into a single prompt string.

    Layout of the returned text:
      1. the system prompt wrapped in <<SYS>> ... <</SYS>> (the default
         assistant prompt is used when ``messages[0]`` is not a system message);
      2. the first non-system message's content, introduced by
         "[INST]Keep in mind the following context[/INST]";
      3. a '<role>: "respond if you understand"' line for that context message,
         then one '<role>: "<stripped content>"' line per remaining message;
      4. "[INST]<dev_prompt>[/INST]" followed by 'user: "<stripped prompt>"'.

    ``messages`` is not modified. It must contain at least one non-system
    message (used as the context); otherwise IndexError is raised.
    """
    inst_open, inst_close = "[INST]", "[/INST]"
    sys_open, sys_close = "<<SYS>>\n", "\n<</SYS>>\n\n"
    fallback_system = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

    # Separate the system text from the conversation turns that follow it.
    if messages[0]["role"] == "system":
        system_text = messages[0]["content"]
        conversation = messages[1:]
    else:
        system_text = fallback_system
        conversation = messages

    # The first conversational turn is always treated as the context block.
    context_turn = conversation[0]
    context_role, context_text = context_turn["role"], context_turn["content"]

    parts = [
        sys_open + system_text + sys_close,
        f'{inst_open}Keep in mind the following context{inst_close}\n{context_text}\n',
        # The context turn is rendered as an acknowledgement request, not its content.
        f'{context_role}: "respond if you understand"\n',
    ]
    parts.extend(
        f'{turn["role"]}: "{turn["content"].strip()}"\n' for turn in conversation[1:]
    )
    parts.append(f'{inst_open}{dev_prompt}{inst_close}user: "{prompt.strip()}"\n')
    return ''.join(parts)

In [None]:
#print(llama_v2_prompt('test', context_builder('ASD_general')))

In [None]:
# Smoke test: run a sample question through the full classify-then-answer pipeline.
llm_response('What is Autism?')

In [None]:
# def llama_v2_prompt(prompt: str, messages: list[dict]):
#     B_INST, E_INST = "[INST]", "[/INST]"
#     B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
#     BOS, EOS = "<s>", "</s>"
#     DEFAULT_SYSTEM_PROMPT = f"""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

#     if messages[0]["role"] != "system":
#         messages = [
#             {
#                 "role": "system",
#                 "content": DEFAULT_SYSTEM_PROMPT,
#             }
#         ] + messages
        
#     messages = [
#         {
#             "role": messages[0]["role"],
#             "content": B_SYS + messages[0]["content"] + E_SYS
#         }
#     ]
    
#     messages = [
#         {
#             "role": messages[0]["role"],
#             "content": B_SYS + messages[0]["content"] + E_SYS
#         },
#         {
#             "role": messages[1]["role"],
#             "CONTEXT": (
#                 f'{B_INST}Keep in mind the following context{E_INST}\n'
#                 f'{messages[1]["content"]}\n'
#             ),
#             "content": 'respond if you understand'
#         }
#     ] + messages[2:] + [{
#         'role':'user',
#         'content': prompt
#     }
#     ]
    
#     system_msg = messages[0]['content']
#     context_msg = messages[1]['CONTEXT']
#     messages_list = [system_msg, context_msg]
#     for prompt in messages[1:-1:]:
#         #f"{BOS}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {EOS}"
#         #str1 = f"{B_INST}Remember the context when answering questions{E_INST}" if prompt['role'] == 'user' else ''
#         str2 = f'{prompt["role"]}: "{(prompt["content"]).strip()}"\n'
#         messages_list.append(str2)
#     #messages_list.append(f"{BOS}{B_INST} {(messages[-1]['content']).strip()} {E_INST}")
    
#     str1 = f"{B_INST}Remember the context when answering questions{E_INST}"
#     str2 = f'{messages[-1]["role"]}: "{(messages[-1]["content"]).strip()}"\n'
#     messages_list.append(str1 + str2)
    
#     return "".join(messages_list)