In [None]:
!pip install transformers
!pip install datasets
# !pip install sentence_transformers

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd drive/MyDrive/en2sparql

In [None]:
import json
import torch
import random
from datasets import load_dataset
from transformers import pipeline, AutoTokenizer
# Load the "orkg/SciQA" dataset. The functions below read its "train" and
# "test" splits through this module-level variable.
raw_datasets = load_dataset("orkg/SciQA")
# Show a summary of what was loaded.
print(raw_datasets)

In [None]:
# gpt2-large generation pipeline used by main(). It runs on CUDA when torch
# reports a GPU, otherwise on the CPU, and is configured with
# max_new_tokens=384 and return_full_text=False.
gpt2 = pipeline(model="gpt2-large", max_new_tokens=384, device='cuda' if torch.cuda.is_available() else "cpu", return_full_text=False)
# Tokenizer for the same checkpoint; main() uses it to count prompt tokens and
# to cut over-long prompts.
tokenizer = AutoTokenizer.from_pretrained("gpt2-large")

In [None]:
def save_json(filename,data):
    """Write ``data`` to ``filename`` as a single line of JSON.

    The file is opened in text mode with UTF-8 encoding, replacing any
    previous content, and the JSON document is followed by one newline.
    """
    with open(filename, "w", encoding="utf-8") as out:
        out.write(json.dumps(data))
        out.write("\n")


def load_json(file__name):
    """Read and parse the JSON document stored in ``file__name``.

    Args:
        file__name: path of a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON value, or ``None`` when the file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    try:
        # The context manager closes the handle even when parsing fails; the
        # manual open/close pair leaked it in that case.
        with open(file__name, "r", encoding='utf-8') as data_file:
            return json.load(data_file)
    except FileNotFoundError:
        return None


def clean(st):
    """Normalise the whitespace and punctuation spacing of a SPARQL string.

    Newlines become spaces, every ``?`` gets a space in front of it, braces
    are padded with a space on both sides, an escaped quote (backslash
    followed by ``'``) becomes a plain ``'``, and finally every run of spaces
    is squeezed to a single space. Leading/trailing spaces are not stripped.
    """
    # Applied strictly in this order.
    substitutions = (
        ("\n", " "),
        ("?", " ?"),
        ("{", " { "),
        ("}", " } "),
        ("\\'", "'"),
    )
    for old, new in substitutions:
        st = st.replace(old, new)

    # Keep halving runs of spaces until no double space is left.
    while st.find("  ") != -1:
        st = st.replace("  ", " ")
    return st


def get_key(q):
    """Return the ``"<number_of_patterns>-<template_id>"`` key of a record.

    ``q`` is read with ``.get``; a missing (or ``None``) ``template_id`` is
    rendered as the text ``"None"``, and ``number_of_patterns`` is passed
    through ``str``.
    """
    template = q.get('template_id')
    template = "None" if template is None else template
    pattern_count = str(q.get("number_of_patterns"))
    return "-".join([pattern_count, template])


def get_random(n_):
    """Pick ``n_`` random examples from the "train" split of ``raw_datasets``.

    Args:
        n_: number of distinct training rows to draw.

    Returns:
        A list of ``[query, question, key]`` triples, where ``query`` is the
        row's SPARQL text after ``clean``, ``question`` is the question
        string and ``key`` is ``get_key(row)``.

    Raises:
        ValueError: propagated from ``random.sample`` when ``n_`` is negative
            or larger than the split.
    """
    train = raw_datasets.get("train")
    # Sample row positions and fetch only those rows. The previous
    # `random.sample(list(train), n_)` decoded the entire split into a Python
    # list on every call, and this function is called once per test question.
    positions = random.sample(range(len(train)), n_)
    sample_list = []
    for pos in positions:
        q = train[pos]
        t = get_key(q)
        query = clean(q["query"]["sparql"])
        question = q["question"]["string"]
        sample_list.append([query, question, t])
    return sample_list


def prepare_queries(n_):
    """Build one prompt for every question in the "test" split.

    Args:
        n_: number of random training examples (see ``get_random``) placed in
            front of each question. With ``0`` the bare question string is
            used as the prompt.

    Returns:
        A ``(queries, suggestions)`` pair of equally long lists in dataset
        order. ``queries`` holds the prompt strings. ``suggestions`` holds,
        for ``n_ == 0``, the key (``get_key``) of each test question, and
        otherwise ``[keys_of_the_examples, key_of_the_question]``.
    """
    data = raw_datasets.get("test")
    queries = []
    suggestions = []
    for q in data:
        t = get_key(q)
        question = q["question"]["string"]

        if n_ == 0:
            queries.append(question)
            suggestions.append(t)
            continue

        # Normalise a missing result to an empty list *before* it is iterated;
        # the emptiness check used to come after the list had already been
        # walked, so it could never protect against None.
        suggestion = get_random(n_) or []
        suggestions.append([[x[2] for x in suggestion], t])

        if not suggestion:
            print("Error with key", t)
            queries.append("translate the following English text '" + question + "' to a sparql query")
            continue

        # Prompt layout noted by the author as "works better with dolly":
        # alternating input/output lines for each example, then the real
        # question with an open "output" line for the model to complete.
        # Other layouts that were tried: "<|endoftext|>"-separated
        # question/query pairs (gpt2), and a trailing
        # "with this example what is the sparql query for: ..." line (gpt).
        parts = []
        for k in suggestion:
            parts.append("\n input (English text): " + k[1])
            parts.append("\n output (Sparql query): " + k[0])
        parts.append("\n input (English text): " + question)
        parts.append("\n output (Sparql query): ")

        queries.append("".join(parts))

    return queries, suggestions


def main(shots=3, attempts=1, seed=None):
    """Generate SPARQL for every test question with few-shot prompts.

    Results are checkpointed to ``random_gpt2_large_<shots>_shots.json`` after
    every generated answer. If that file already exists, its prompts and
    answers are reused and generation resumes at the first unanswered prompt.

    Args:
        shots: number of random training examples per prompt, passed to
            ``prepare_queries``.
        attempts: currently unused (the retry loop it belonged to is
            disabled); kept so existing calls keep working.
        seed: optional seed for Python's ``random`` module, applied before the
            prompts are built so the few-shot example selection can be
            repeated. It does not affect the model's own generation.

    Returns:
        None. The output is the JSON file, with the keys ``questions``,
        ``sparql``, ``generated_sparql``, ``prompt_len`` and ``suggestions``.
    """
    # Longest prompt, in tokens, that is sent to the model unchanged.
    # NOTE(review): presumably chosen so that prompt + generated tokens fit
    # the model's context window - confirm.
    max_prompt_tokens = 600
    cache_file = "random_gpt2_large_" + str(shots) + "_shots.json"

    if seed is not None:
        random.seed(seed)

    data = load_json(cache_file)
    if data is None:
        print("No previous results in", cache_file)
        query_list, suggestions = prepare_queries(shots)
        gs = []
        lens = []
    else:
        query_list = data["questions"]
        suggestions = data["suggestions"]
        gs = data["generated_sparql"]
        lens = data["prompt_len"]
        # Report progress instead of dumping the whole cached document.
        print("Resuming from", cache_file, "with", len(gs), "answers already generated")

    print(len(query_list))

    sparql = [clean(x["query"]["sparql"]) for x in raw_datasets.get("test")]

    # Skip the prompts that already have an answer from a previous run.
    for question in query_list[len(gs):]:
        print(question)
        token_ids = tokenizer.encode(question)
        n_tokens = len(token_ids)
        # The recorded length is the one measured before any truncation.
        lens.append(n_tokens)
        print(n_tokens)
        if n_tokens > max_prompt_tokens:
            # Keep the tail of the prompt: it holds the actual question and
            # the final open "output" line; the oldest examples are dropped.
            question = tokenizer.decode(token_ids[-max_prompt_tokens:])

        res = gpt2(question)
        gs.append(res[0]["generated_text"])
        result = {"questions": query_list, "sparql": sparql, "generated_sparql": gs, "prompt_len": lens,
                  "suggestions": suggestions}
        # Checkpoint after every answer so an interrupted run can resume.
        save_json(cache_file, result)

# Run the experiment with the default arguments of main().
main()