# This is only for mac (ARM64)

## Import llama3-8b

In [1]:
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

## Example

In [2]:
SYSTEM_MSG = ("You are an assistant that detects entities and their relationships in questions, for example:"
              "user question: where is china?"
              "your answer: [(china)(located in)(?)]"
              "user question: where is USA?"
              "Your answer: [(USA)(located in)(?)]"
              "user question: where is UK's capital?"
              "your answer: [(UK's capital)(located in)(?)]")

def generate_entity_response(promptStr, maxTokens=100):
    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": promptStr},
    ]
    # print(messages)
    # 将消息应用于聊天模板并生成输入ID
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    prompt = tokenizer.decode(input_ids)

    # 生成响应
    response = generate(model, tokenizer, prompt=prompt, max_tokens=maxTokens)

    return response


# example
user_question = "Where is China?"
response = generate_entity_response(user_question)
print(response)

# other question
user_question = "where is japan's capital"
response = generate_entity_response(user_question)
print(response)

[(China)(located in)(?)]
[(Japan's capital)(located in)(?)]


## Train a llama which can learn rdf
#### first, initial database and get all rdf data

In [3]:
from jena.fuseki_client import JenaClient
from mongoDB.mongoDB_client import init_db, MongoDBInterface
db, fs = init_db(
    "mongodb://localhost:27017")
db_interface = MongoDBInterface(db, fs)

jena_client = JenaClient(jena_url='http://127.0.0.1:3030', dataset='test')  

In [4]:
import json
code,text=jena_client.execute_sparql_query_global("SELECT * WHERE { ?sub ?pred ?obj .}")
# print("text: ",text)

def rdf_to_natural_language(rdf_data):
    descriptions = []
    for triple in rdf_data:
        subj = triple['sub']['value'].split('/')[-1]
        pred = triple['pred']['value'].split('/')[-1].replace('_', ' ')
        obj = triple['obj']['value'].split('/')[-1]
        description = f"{subj} {pred} {obj}."
        descriptions.append(description)
    return "\n".join(descriptions)

rdf_to_nl=""
if code == 200:
    # print(text)
    json_object=json.loads(text)
    result=json_object['results']['bindings']
    rdf_to_nl=rdf_to_natural_language(result)
    print(rdf_to_nl)
    

SELECT *  FROM <http://server/unset-base/66aaabab526e92da62d5188d> FROM <http://server/unset-base/66b2adf4d89c350c0c3b15e3> FROM <http://server/unset-base/66b2ad67d89c350c0c3b15d8> WHERE { ?sub ?pred ?obj .}
country1 has border with country2.
country1 located in part1.
country2 located in part1.
country4 located in part2.
country3 has border with country4.
country3 located in part2.
1162_nièng rdf-schema#label 1162 nièng.
Buchanan_Gông_(Missouri) rdf-schema#label Buchanan Gông (Missouri).
Woodruff_Gông_(Arkansas) rdf-schema#label Woodruff Gông (Arkansas).
Huà-să rdf-schema#label Huà-să.
2月16號 rdf-schema#label 2月16號.
Sék-sáuk rdf-schema#label Sék-sáuk.
1479_nièng rdf-schema#label 1479 nièng.
林暾谷 rdf-schema#label 林暾谷.
Kóng-cié-gāng rdf-schema#label Kóng-cié-gāng.
1299_nièng rdf-schema#label 1299 nièng.
Chautauqua_Gông_(Kansas) rdf-schema#label Chautauqua Gông (Kansas).
171_nièng rdf-schema#label 171 nièng.
257_nièng rdf-schema#label 257 nièng.
Páuk_Céng-hĭ rdf-schema#label Páuk Céng-hĭ.


In [None]:
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

# 准备初始上下文
SYSTEM_MSG = (f"You are a knowledgeable assistant who answers questions based on the provided data, "
              f"If the user's question is out of scope for this dataset, you should only answer: Sorry, this question is out of scope."
              f"\n\nHere is the data:\n{rdf_to_nl}")


# 生成回答
def generate_response(question, initial_context, max_tokens=100):
    messages = [
        {"role": "system", "content": initial_context},
        {"role": "user", "content": question},
    ]

    # 将消息应用于聊天模板并生成输入ID
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    prompt = tokenizer.decode(input_ids)

    # 生成响应
    response = generate(model, tokenizer, prompt=prompt, max_tokens=max_tokens)
    

    return response

# example
user_question = "Where is country1?"
response = generate_response(user_question, SYSTEM_MSG)
print(f"question: {user_question}\nLlama response: {response}\n")

# question which out of scope
user_question = "Where is USA"
response = generate_response(user_question, SYSTEM_MSG)
print(f"question: {user_question}\nLlama response: {response}")

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
# 将消息应用于聊天模板并生成输入ID
question = "i want some information of country 1"
messages=[{'role':'user','content':question}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
prompt = tokenizer.decode(input_ids)

# 生成响应
response = generate(model, tokenizer, prompt=prompt, max_tokens=100)
print(response)

In [None]:
response = generate_response(question, SYSTEM_MSG)
print(response)