# Install and Import Dependencies

## Install Dependencies

In [None]:
!pip install transformers

## Import Dependencies

In [None]:
import json
import random
import uuid
import operator
import os

# tqdm is used to visualize the progress while processing input files
from tqdm import tqdm

# for embedding current time into log file name 
from datetime import datetime

# We will use a pre-trained tokenizer to determine the length of strings
from transformers import AutoTokenizer

# Settings

In [None]:
# If set, only the first 'n' API tree models will be loaded (None = no limit)
api_limit = None

remove_uris = True  # remove URIs from property descriptions
sort_by_name = True # sort the XPaths of the context by name

max_depth = 8 # max. depth (number of dot-separated segments) of an XPath in both context and answer
min_question_length = 3 # min. number of tokens that must be in a question
max_question_length = 96 # max. number of tokens that may be in a question

max_questions_per_sample = 32 # max. number of Question-Answer pairs per sample. If the number is exceeded, an additional sample is created

number_of_chunks = 10 # number of containers the samples are distributed to

# Variables that specify how many times the generated sample set is repeated. A value of '1' means that each sample is only created once.
# NOTE(review): presumably 'original_retakes' refers to samples with unshuffled context and 'shuffled_retakes'
# to samples with shuffled context - the consuming code is not visible here, confirm.
original_retakes = 1
shuffled_retakes = 0

# Fixed seed for the 'random' module so that repeated runs behave the same
random.seed(42)

# Specify the base model that is used for training later. We will use its pre-trained tokenizer in this notebook.
base_model = "microsoft/codebert-base"

# List of special tokens that should be removed from XPaths while creating the context string
to_be_removed = ["<?>","<str>","<num>","<int>","<bool>","{_}","$."]

# Directory the API tree model files (*.json) are read from
input_path = "/home/user/input_directory/"
# NOTE(review): presumably the directory the generated samples are written to - the writing code is not visible here
output_path = "/home/user/output_directory/"

# APIs (identified by their keys) that should be excluded from processing after loading and parsing them (e.g. due to too many payloads)
excluded_api_keys = []

In [None]:
# Global statistics counters. Each one counts how often a specific condition was hit while samples are created.
# NOTE(review): 'cnt_no_context' and 'cnt_split_samples' are not updated in the code visible here - presumably
# they are maintained by the sample creation code further down; confirm.
cnt_no_context = 0
cnt_no_property_description = 0  # property has no (or an empty) description
cnt_too_short_property_description = 0  # description has fewer tokens than the required minimum
cnt_too_long_property_description_but_truncated = 0  # description exceeded the maximum and was truncated
cnt_too_long_property_description = 0  # description is empty after truncation
cnt_too_deep_property_xpath = 0  # XPath of the property exceeds the maximum depth
cnt_split_samples = 0
cnt_answer_not_in_context = 0  # XPath of the property could not be found in the context string

# Load and Parse API Tree Models

In [None]:
def remove_data_types_from_xpath(xpath: str):
    """
    Strips every special token listed in the global 'to_be_removed' from the passed XPath.

    Parameters
    ----------
    xpath : str
        XPath that may contain special tokens

    Returns
    -------
    The XPath with all occurrences of the special tokens removed
    """
    cleaned_xpath = xpath
    # tokens are removed one after another, in the order given by 'to_be_removed'
    for special_token in to_be_removed:
        cleaned_xpath = cleaned_xpath.replace(special_token, "")
    return cleaned_xpath


class ApiInterfaceNode:
    """
    Generic node of an API tree model.

    Every node carries a set of common attributes. Depending on its type, a node additionally
    carries a set of type-specific attributes (see below).

    Attributes
    ----------
    key : str
        key that uniquely identifies the node among all children of the parent node
    value
        optional value of the node
    node_type : str
        type (i.e. role) of the node, allowed values are 'api', 'path', 'method', 'response', 'payload', and 'property'
    id : str
        unique identifier of the node among all nodes of the API tree ('-' of the raw identifier is replaced by '.')
    elements : [ApiInterfaceNode]
        all children of the node
    raw_node
        original JSON structure of the node and its sub tree as Python dictionary
    api_key, api_name, api_version_key, api_version_name : str
        key/name of the API and of its version, only present if the node is of type 'api'
    method_summary, method_description : str
        summary and description of the method, only present if the node is of type 'method'
    response_description : str
        description of the response, only present if the node is of type 'response'
    property_name, property_data_type, property_xpath, property_format, property_pattern, property_description : str
        name, data type, XPath, format, pattern, and description of the property, only present if the node is of type 'property'

    Methods
    -------
    is_type(node_type : str):
        Returns True if the node's type is equal to the passed type
    __str__():
        Returns a JSON object as string containing all attributes of the node
    """
    def __init__(self, api_documentation_raw_node):
        """
        Builds the node and, recursively, the whole sub tree below it from the passed raw API tree model.

        Parameters
        ----------
        api_documentation_raw_node
            parsed JSON structure of the API tree model as Python dictionary
        """
        raw = api_documentation_raw_node
        self.raw_node = raw

        # Attributes every node has, independent of its type
        self.key = raw["key"]
        self.value = raw["value"]
        self.node_type = raw["type"]
        self.id = raw["id"].replace("-", ".")

        # Children are converted before the type-specific attributes are read
        self.elements = [ApiInterfaceNode(raw_child) for raw_child in raw["elements"]]

        # Type-specific attributes ('path' and 'payload' nodes do not have any)
        kind = self.node_type
        if kind == "api":
            self.api_key = raw["apiKey"]
            self.api_name = raw["apiName"]
            self.api_version_key = raw["versionKey"]
            self.api_version_name = raw["versionName"]
        elif kind == "method":
            self.method_summary = raw["summary"]
            self.method_description = raw["description"]
        elif kind == "response":
            self.response_description = raw["description"]
        elif kind == "property":
            self.property_name = raw["name"]
            self.property_data_type = raw["dataType"]
            # XPath is stored without spaces, tabs, line feeds, and special tokens
            compact_xpath = raw["xpath"]
            for whitespace in (' ', '\t', '\n'):
                compact_xpath = compact_xpath.replace(whitespace, '')
            self.property_xpath = remove_data_types_from_xpath(compact_xpath)
            self.property_format = raw["format"]
            self.property_pattern = raw["pattern"]
            self.property_description = raw["description"]

    def is_type(self, node_type: str):
        """
        Checks whether the node has the passed type.

        Parameters
        ----------
        node_type : str
            type that should be compared with the type of the node

        Returns
        -------
        True if the types are equal, else False
        """
        return node_type == self.node_type

    def __str__(self):
        """
        Serializes the attributes of the node (without its children) as JSON.

        Returns
        -------
        JSON object as string containing the common attributes, the number of children, and the type-specific attributes
        """
        json_dict = {
            "key": self.key,
            "value": self.value,
            "id": self.id,
            "type": self.node_type,
            "number_of_elements": len(self.elements),
        }

        # Append type-specific attributes ('path' and 'payload' nodes do not have any)
        kind = self.node_type
        if kind == "api":
            json_dict.update({
                "apiKey": self.api_key,
                "apiName": self.api_name,
                "versionKey": self.api_version_key,
                "versionName": self.api_version_name,
            })
        elif kind == "method":
            json_dict.update({
                "summary": self.method_summary,
                "description": self.method_description,
            })
        elif kind == "response":
            json_dict["description"] = self.response_description
        elif kind == "property":
            json_dict.update({
                "name": self.property_name,
                "dataType": self.property_data_type,
                "xpath": self.property_xpath,
                "format": self.property_format,
                "pattern": self.property_pattern,
                "description": self.property_description,
            })

        return json.dumps(json_dict)

def load_and_parse_api(path: str):
    """
    Reads a single API tree model file (UTF-8 encoded JSON) and converts its content into a tree of ApiInterfaceNodes.

    Parameters
    ----------
    path : str
        Path of the API tree model file

    Returns
    -------
    ApiInterfaceNode representing the root of the loaded and parsed API tree model
    """
    with open(path, mode="r", encoding="utf-8") as api_file:
        raw_tree = json.loads(api_file.read())

    # the file is already closed when the node tree is built
    root_node = ApiInterfaceNode(raw_tree)
    return root_node

def load_and_parse_apis_from_directory(path: str, limit: int = None):
    """
    Loads and parses multiple API tree model files located in the specified directory.

    Only files whose name ends with '.json' are considered. The file names are sorted before the
    optional limit is applied, so the selection and the order of the returned APIs are reproducible
    ('os.listdir' returns entries in arbitrary order) and the limit counts API tree model files only,
    not arbitrary directory entries.

    Parameters
    ----------
    path : str
        Path to directory where the API tree model files are located

    limit : int
        Optional limit. If specified, only the first 'n' API tree model files (in sorted order) are loaded and parsed.
        'None' (default) and '0' both mean that there is no limit.

    Returns
    -------
    List of ApiInterfaceNodes where each node represents the root of a loaded and parsed API tree model
    """
    # filter first, then limit: otherwise non-JSON entries would consume part of the limit
    filenames = sorted(name for name in os.listdir(path) if name.endswith(".json"))
    if limit:
        filenames = filenames[:limit]

    return [load_and_parse_api(os.path.join(path, filename)) for filename in tqdm(filenames)]

def extract_nodes(node: ApiInterfaceNode, node_type: str):
    """
    Collects all nodes of the passed type from the API tree model, including the root node itself if it matches.

    Parameters
    ----------
    node : ApiInterfaceNode
      API tree model (input)
    node_type : str
      The type of the nodes that should be extracted

    Returns
    -------
    List of ApiInterfaceNodes matching the passed node type (parents before their children, children in their original order)
    """
    matches = []
    pending = [node]
    while pending:
        current = pending.pop()
        if current.node_type == node_type:
            matches.append(current)
        # push children in reverse order so that they are popped in their original order
        pending.extend(reversed(current.elements))
    return matches

def extract_nodes_in_apis(nodes: [ApiInterfaceNode], node_type: str):
    """
    Collects all nodes of the passed type from every API tree model of the passed list.

    Parameters
    ----------
    nodes : [ApiInterfaceNode]
      List of API tree models (input)
    node_type : str
      The type of the nodes that should be extracted

    Returns
    -------
    List of ApiInterfaceNodes matching the passed node type, in the order of the passed API tree models
    """
    return [
        matching_node
        for api_root in nodes
        for matching_node in extract_nodes(api_root, node_type)
    ]


In [None]:
# Load and parse the API tree models located in 'input_path' (optionally limited by 'api_limit')
apis = load_and_parse_apis_from_directory(input_path,limit=api_limit)

In [None]:
# Extract all nodes of type 'payload' from all loaded API tree models
payload_nodes = extract_nodes_in_apis(apis,node_type="payload")

In [None]:
# Report how many payload nodes were found
print("Number of payload nodes: ",len(payload_nodes))

# Create QA-Samples

In [None]:
class QuestionAnswer:
    """
    A single question-answer pair: a unique identifier, the question and its length in tokens, the answer,
    and the character-based index at which the answer starts within the context.

    Attributes
    ----------
    id : str
        unique identifier of the question-answer pair (random UUID in hex notation)
    question : str
        question text
    question_length : int
        number of tokens of the question text
    answer : str
        answer text
    answer_start : int
        position (index) of the first character of the answer text within the original context

    Methods
    -------
    as_dict():
        Converts the question-answer pair into a Python dictionary following the structure for QA pairs recommended in https://huggingface.co/course/chapter7/7?fw=pt
    """

    def __init__(self, question, question_length, answer, answer_position_in_index):
        """
        Creates a question-answer pair with a freshly generated identifier.

        Parameters
        ----------
        question : str
            question text
        question_length : int
            number of tokens of the question text
        answer : str
            answer text
        answer_position_in_index : int
            position (index) of the first character of the answer text within the original context
        """
        self.id = uuid.uuid4().hex
        self.question = question
        self.question_length = question_length
        self.answer = answer
        self.answer_start = answer_position_in_index

    def as_dict(self):
        """
        Converts this question-answer pair into a Python dictionary following the structure for QA pairs recommended in https://huggingface.co/course/chapter7/7?fw=pt
        {
            'id': ....,
            'question': ....,
            'question_length': .....,
            'answers': {
                'text': [....],
                'answer_start': [....]
            }
        }
        'answers.text' and 'answers.answer_start' are lists because in classical NL QA a question might have multiple answers.
        Here, every question has exactly one answer, thus, each list holds exactly one item.

        Returns
        -------
        Python dictionary
        """
        return {
            "id": self.id,
            "question": self.question,
            "question_length": self.question_length,
            "answers": {
                "text": [self.answer],
                "answer_start": [self.answer_start],
            },
        }

class QuestionAnswerSample:
    """
    A question-answer sample: a unique identifier, an optional title, a context, and the question-answer pairs that belong to the context.

    Attributes
    ----------
    id : str
        unique identifier of the sample (random UUID in hex notation)
    title : str
        title of the sample
    context : str
        context (text)
    questionAnswers : [QuestionAnswer]
        question-answer pairs of the sample

    Methods
    -------
    __str__():
        Converts the sample including its question-answer pairs into a JSON structure following the structure for QA pairs recommended in https://huggingface.co/course/chapter7/7?fw=pt
    """

    def __init__(self, context, questionAnswers: list, title = None):
        """
        Creates a sample with a freshly generated identifier. The identifier of every passed question-answer pair
        is modified in place: it gets prefixed with the identifier of the sample, followed by '_'.

        Parameters
        ----------
        context : str
            context (text)
        questionAnswers : [QuestionAnswer]
            list of question-answer pairs
        title : str
            optional title of the sample
        """
        sample_id = uuid.uuid4().hex
        self.id = sample_id
        self.title = title
        self.context = context
        self.questionAnswers = questionAnswers

        # make the pair identifiers traceable to their sample
        for pair in questionAnswers:
            pair.id = sample_id + "_" + pair.id

    def __str__(self):
        """
        Converts the sample including its question-answer pairs into a JSON structure following the structure for QA pairs recommended in https://huggingface.co/course/chapter7/7?fw=pt
        {
            "id": ....,
            "title": ....,
            "context": ....,
            "questions": [ //see QuestionAnswer.as_dict()]
        }
        """
        return json.dumps({
            "id": self.id,
            "title": self.title,
            "context": self.context,
            "questions": [pair.as_dict() for pair in self.questionAnswers],
        })


## Prepare Pre-Trained Tokenizer for Length Calculation

In [None]:
# Load the pre-trained tokenizer that belongs to 'base_model'; it is used to count tokens of strings
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [None]:
def get_length(input: str):
    """
    Determines the length of the passed string, measured in tokens of the global pre-trained 'tokenizer'.
    Special tokens of the tokenizer are not added and therefore not counted.

    Parameters
    ----------
    input : str
        Input string whose length should be calculated

    Returns
    -------
    Number of tokens
    """
    token_ids = tokenizer.encode(input, add_special_tokens=False)
    return len(token_ids)

In [None]:
# Test tokenizer: number of tokens of a short dot-separated XPath-like string
test_string = "account.user.name"
get_length(test_string)

In [None]:
# Testing length limitations: count the tokens of a very long string (the numbers 0..99998 separated by spaces)
s = " ".join([str(x) for x in range(99999)])
get_length(s)

## Methods for Creating Context

In [None]:
def extract_xpaths(node: ApiInterfaceNode):
    """
    Collects the XPaths of all leaf properties below the passed node. The passed node itself is not processed (only its
    descendants), and all descendants must be of type 'property'.

    A child is treated as container (and is descended into instead of being collected) if its data type is 'array' or
    'object', or if its data type is 'unknown' or missing (None) and it has children. Every other child is treated as leaf.

    Parameters
    ----------
    node : ApiInterfaceNode
        root node of the API sub tree whose XPaths should be extracted

    Returns
    -------
    List of extracted XPaths
    """
    collected_xpaths = []
    for child in node.elements:
        data_type = child.property_data_type

        if data_type == "array" or data_type == "object":
            is_container = True
        elif data_type == "unknown" or data_type is None:
            # untyped properties only count as container if they actually have children
            is_container = len(child.elements) > 0
        else:
            is_container = False

        if is_container:
            collected_xpaths.extend(extract_xpaths(child))
        else:
            collected_xpaths.append(child.property_xpath)
    return collected_xpaths


def filter_xpaths(xpaths: list, max_depth: int):
    """
    Returns a new list containing only those XPaths of the passed list that do not exceed the specified maximum depth.
    The depth of an XPath is the number of its dot-separated segments, e.g. 'users.address.street' has a depth of 3.
    If 'max_depth' is None, a copy of the passed list is returned without any removal.

    Parameters
    ----------
    xpaths : [str]
        List of XPaths (strings)
    max_depth : int
        maximum depth (can be None)

    Returns
    -------
    New list of XPaths in their original order
    """
    if max_depth is None:
        return list(xpaths)
    return [candidate for candidate in xpaths if len(candidate.split(".")) <= max_depth]


def build_context_string(xpaths: list, sort_by_name: bool = False, shuffle: bool = False):
    """
    Removes duplicate XPaths from the passed XPath list (the first occurrence is kept), optionally sorts (ascending order)
    or shuffles the list, and finally concatenates the remaining XPath items to a string with a single space as separator
    between items. The passed list itself is not modified.

    Parameters
    ----------
    xpaths : [string]
        list of XPaths
    sort_by_name : bool
        if set to True (default value is False), the method will sort the list of XPath items in ascending order
    shuffle : bool
        if set to True (default value is False), the method will shuffle the list of XPath items (using the global
        'random' module, i.e. the result depends on its seed/state)

    Returns
    -------
    String containing the concatenated XPath items

    Raises
    ------
    AssertionError
        if both 'sort_by_name' and 'shuffle' are set to True
    """
    # Parentheses are required: 'not shuffle and sort_by_name' would be parsed as '(not shuffle) and sort_by_name'
    # and would therefore reject every call that does not sort (including the default arguments).
    assert not (shuffle and sort_by_name), "Cannot shuffle and sort xpaths at the same time"

    # dict keys are unique and keep insertion order -> order-preserving deduplication
    unique_xpaths = list(dict.fromkeys(xpaths))

    # shuffle (if enabled)
    if shuffle:
        random.shuffle(unique_xpaths)

    # sort by name (if enabled)
    if sort_by_name:
        unique_xpaths.sort()

    return " ".join(unique_xpaths)

## Methods for Creating Questions

In [None]:
def remove_inline_uris(text: str):
    """
    Replaces every space-separated token that contains 'http:' or 'https:' (case-insensitive) by a single space
    and returns the modified string. All occurrences of such a token within the text are replaced.

    Parameters
    ----------
    text : str
        input string that should be scanned for URIs

    Returns
    -------
    Modified string
    """
    uri_markers = ("http:", "https:")
    # candidates are determined once from the unmodified text
    for word in text.split(' '):
        lowered_word = word.lower()
        if any(marker in lowered_word for marker in uri_markers):
            text = text.replace(word, " ")
    return text

def truncate_question(question: str, max_question_length: int):
    """
    Shortens the passed question if its number of tokens exceeds the passed maximum question length.

    A question that is short enough is returned unchanged. Otherwise the question is split at '.' into fragments and the
    fragments are re-assembled from the start (each followed by '.') as long as the running length plus the length of the
    next fragment stays below the maximum. The running length counts one extra token per appended '.'.
    The result may be an empty string if already the first fragment does not fit.

    Parameters
    ----------
    question : str
        Original question that should be truncated
    max_question_length : int
        Maximum number of tokens

    Returns
    -------
    The (possibly truncated) question, its length (number of tokens), and whether it has been truncated (True) or not (False)
    """
    original_length = get_length(question)

    # nothing to do if the question already fits
    if original_length <= max_question_length:
        return question, original_length, False

    # (Try to) split the question into sentences and measure each of them
    fragments = question.split(".")
    fragment_lengths = [get_length(fragment) for fragment in fragments]

    kept_fragments = []
    running_length = 0
    for fragment, fragment_length in zip(fragments, fragment_lengths):
        # stop at the first fragment that would reach the limit
        if running_length + fragment_length >= max_question_length:
            break
        kept_fragments.append(fragment + ".")
        running_length += fragment_length + 1

    truncated_question = "".join(kept_fragments)
    # the reported length is measured on the final string, not taken from the running estimate
    return truncated_question, get_length(truncated_question), True

In [None]:
## Test truncate_question with a multi-sentence example and a limit of 25 tokens
## NOTE: this cell defines a module-level variable named 'question'
question = "This is a short example. We want to test, whether truncate_question works as expected. Hello World"
truncate_question(question,25)

In [None]:
# Number of tokens of a short sentence fragment
get_length("This is a short")

## Methods for Creating Question-Answer Pairs and Samples

In [None]:
def get_answer_start(context: str,answer_text: str):
    """
    Looks up the answer text among the whitespace-separated items of the context and returns the character position of
    the first item that is exactly equal to the answer text. Items that merely contain the answer text do not match.
    The position is calculated under the assumption that the items are separated by exactly one character.

    Parameters
    ----------
    context : str
        context
    answer_text : str
        answer text

    Returns
    -------
    Position of the first character of the answer text or None if the answer is not in the context
    """
    offset = 0
    for item in context.split():
        if item == answer_text:
            return offset
        # skip the item and the separator that follows it
        offset += len(item) + 1
    return None

In [None]:
## Test get_answer_start: 'user' is a prefix of the first three items, but only the last item is exactly 'user'
context = "user.address user.address.street user.address.city user"
get_answer_start(context,"user")

In [None]:
# Print every character of the test context together with its index
for i,char in enumerate(context):
    print(i," ",char)

In [None]:
def create_question_answer_pairs(node: ApiInterfaceNode, context: str, min_question_length: int = None , max_question_length: int = None, remove_uris: bool = False, max_depth: int = None):
    """
    Creates and returns a list of Question-Answer pairs deduced from the passed API sub tree.

    For a primitive property (data type 'string', 'number', 'integer', 'boolean', or 'unknown' without children) the
    description of the property becomes the question and its XPath becomes the answer. For a container property
    (data type 'array' or 'object', or data type 'unknown'/None with children) the method is called recursively for
    all children. Properties that violate a constraint are skipped and counted in the global 'cnt_*' counters.

    Parameters
    ----------
    node : ApiInterfaceNode
        property node (root of the API sub tree) that should be processed
    context : str
        context, required for calculating the position (character-based index) of the answer
    min_question_length : int
        Optional parameter (default is None) that specifies the minimum length (number of tokens) that a question must have. If the parameter is None, only empty questions are rejected
    max_question_length : int
        Optional parameter (default is None) that specifies the maximum length (number of tokens) that a question may have. Longer questions are truncated. If the parameter is None, the maximum length is unlimited
    remove_uris : bool
        Optional parameter (default is False) indicating whether URIs should be removed from questions
    max_depth : int
        Optional parameter (default is None) that specifies the maximum depth (number of dot-separated segments) that an answer (i.e. the XPath) may have. If the parameter is None, there is no depth limitation

    Returns
    -------
    List of created Question-Answer pairs
    """
    global cnt_no_property_description, cnt_too_deep_property_xpath
    global cnt_too_long_property_description_but_truncated, cnt_too_long_property_description
    global cnt_too_short_property_description, cnt_answer_not_in_context

    question_answer_pairs = []

    valid = True

    # check whether property has a description
    if not node.property_description:
        cnt_no_property_description += 1
        valid = False

    # check whether XPath does not exceed max. depth
    if max_depth is not None and len(node.property_xpath.split(".")) > max_depth:
        cnt_too_deep_property_xpath += 1
        valid = False

    data_type = node.property_data_type
    has_children = len(node.elements) > 0
    is_primitive = (data_type in ("string", "number", "integer", "boolean")
                    or (data_type == "unknown" and not has_children))

    # check whether all constraints so far are satisfied and property is primitive type
    if valid and is_primitive:

        description = node.property_description

        # Step 1 (optional): Remove inline URIs:
        if remove_uris:
            description = remove_inline_uris(description)

        # Step 2: Remove unnecessary whitespaces
        while '  ' in description:
            description = description.replace('  ',' ')

        # Step 3 (optional): Truncate description (Note: truncate_question might return an empty question!)
        if max_question_length is not None:
            description, length, truncated = truncate_question(description, max_question_length)
            if truncated:
                cnt_too_long_property_description_but_truncated += 1
        else:
            # measure the cleaned description itself (there is no local variable named 'question' in this function)
            length = get_length(description)

        # Finally check whether description is either empty (if there is NO min_question_length)
        if min_question_length is None and not description:
            cnt_too_long_property_description += 1
            valid = False

        # ... or (if there is a min_question_length) whether the description is shorter than the minimum required number of tokens
        if min_question_length is not None and length < min_question_length:
            cnt_too_short_property_description += 1
            valid = False

        # if all constraints are satisfied
        if valid:

            # Build answer:
            answer = node.property_xpath

            # Calculate answer position in context
            answer_start = get_answer_start(context,answer)
            # if answer is in the context
            if answer_start is not None:
                # Create Question-Answer pair and add it to list
                question_answer = QuestionAnswer(description,length,answer,answer_start)
                question_answer_pairs.append(question_answer)
            else:
                cnt_answer_not_in_context += 1
                print("Error: Answer not in context")
                print("ID: ",node.id)
                print("Context: ",context)
                print("Answer: ", answer)

    # recursive call if node is type of array, object, or unknown (or missing data type) with children;
    # children are processed even if the container itself violated a constraint above
    if data_type in ("array", "object") or (data_type in ("unknown", None) and has_children):
        for element in node.elements:
            question_answer_pairs += create_question_answer_pairs(element,context,min_question_length,max_question_length, remove_uris, max_depth)

    return question_answer_pairs

In [None]:
def create_question_answer_samples_for_payload(schema_root_node: ApiInterfaceNode, min_question_length: int = None , max_question_length: int = None, remove_uris: bool = False, max_depth: int = None, max_questions_per_sample: int = None, sort_by_name: bool = False, shuffle_context: bool = False):
    """
    Creates and returns one or multiple Question-Answer samples for the passed schema root node. The decision whether one or multiple samples are created depends on the 'max_questions_per_sample' threshold as well as on the number of Question-Answer pairs that could be created for the payload.
    
    Parameters
    ----------
    schema_root_node : ApiInterfaceNode
        Root node of the schema ($)
    min_question_length : int
        Optional parameter (default is None) that specifies the minimum length (number of tokens) that a question must have. If the parameter is None, the minimum length is one token
    max_question_length : int
        Optional parameter (default is None) that specifies the maximum length (number of tokens) that a question may have. If the parameter is None, the maximum length is unlimited
    remove_uris : bool
        Optional parameter (default is False) indicating whether URIs should be removed from questions
    max_depth : int
        Optional parameter (default is None) that specifies the maximum depth that an XPath (as answer as well as in context) may have. If the parameter is None, there is no depth limitation
    max_questions_per_sample : int
        Optional parameter (default is None) that specifies the maximum number of Question-Answer pairs per sample. The method distributes the Question-Answer pairs to multiple samples if this limit is exceeded. Must be None or at least 1.
    sort_by_name : bool
        if set to True (default value is False), the method will sort XPaths in context in ascending order
    shuffle_context : bool
        if set to True (default value is False), the method will shuffle XPaths in context
        
    Returns
    -------
    List of created Question-Answer samples (even if only one sample is created, it is a list), or None if the context is empty or if no Question-Answer pair could be created
    
    Raises
    ------
    ValueError
        If 'max_questions_per_sample' is set but smaller than 1
    """
    global cnt_no_context, cnt_split_samples
    
    # A limit below 1 can never be satisfied: no pair would ever be consumed while splitting,
    # so the function would produce empty samples forever.
    if max_questions_per_sample is not None and max_questions_per_sample < 1:
        raise ValueError("max_questions_per_sample must be None or a positive integer")
    
    # Build context
    xpaths = extract_xpaths(schema_root_node)
    if max_depth:
        xpaths = filter_xpaths(xpaths,max_depth)
    context = build_context_string(xpaths,sort_by_name,shuffle_context)
    
    # Without any XPath in the context no sample can be created
    if not context:
        cnt_no_context += 1
        return None
    
    question_answer_pairs = create_question_answer_pairs(schema_root_node, context, min_question_length, max_question_length, remove_uris, max_depth)
    
    # Check if at least one Question-Answer pair has been created
    if not question_answer_pairs:
        return None
    
    # No limit: all pairs go into a single sample
    if max_questions_per_sample is None:
        return [QuestionAnswerSample(context, question_answer_pairs, schema_root_node.id)]
    
    # Pairs are handed out starting from the end of the list (last created pair first),
    # in slices of at most 'max_questions_per_sample' pairs.
    remaining_pairs = question_answer_pairs[::-1]
    samples = [
        QuestionAnswerSample(context, remaining_pairs[start:start + max_questions_per_sample], schema_root_node.id)
        for start in range(0, len(remaining_pairs), max_questions_per_sample)
    ]
    
    # Count payloads whose pairs did not fit into a single sample
    if len(samples) > 1:
        cnt_split_samples += 1
    return samples
        

In [None]:
# Total number of questions over all APIs
question_cnt = 0

# One entry per API version for which at least one sample has been created
results = []

# Each entry describes one kind of retake: (number of passes, sort_by_name, shuffle_context).
# Original retakes honour the 'sort_by_name' setting, shuffled retakes shuffle the context instead.
retake_plan = (
    (original_retakes, sort_by_name, False),
    (shuffled_retakes, False, True),
)

for api in tqdm(apis):
    if int(api.api_key) in excluded_api_keys:
        print("Skip ",api.api_name," (",api.api_key,")")
        continue

    payload_cnt_per_api = 0
    samples = []

    for passes, sort_flag, shuffle_flag in retake_plan:
        for _ in range(passes):
            for payload_node in extract_nodes(api,node_type="payload"):
                payload_cnt_per_api += 1
                # Only payloads with exactly one element (the schema root) are processed
                if len(payload_node.elements) != 1:
                    continue
                created = create_question_answer_samples_for_payload(
                    payload_node.elements[0],
                    min_question_length,
                    max_question_length,
                    remove_uris,
                    max_depth,
                    max_questions_per_sample,
                    sort_flag,
                    shuffle_flag,
                )
                if created:
                    samples.extend(created)

    # APIs without any sample do not show up in the results
    if not samples:
        continue

    question_cnt_per_api = sum(len(sample.questionAnswers) for sample in samples)
    question_cnt += question_cnt_per_api
    results.append({
        "samples":samples,
        "api_key":api.api_key,
        "api_name":api.api_name,
        "api_version_key":api.api_version_key,
        "api_version_name":api.api_version_name,
        "payload_cnt_per_api":payload_cnt_per_api,
        "question_cnt_per_api":question_cnt_per_api
    })
        
        
            
            
       
                

In [None]:
# Order the per-API results by their number of questions, largest first
sorted_results = sorted(results, key=operator.itemgetter("question_cnt_per_api"), reverse=True)

In [None]:
# Distribute the samples of every API version to the chunk that currently holds the fewest questions.
# 'sorted_results' is ordered by descending question count, so the largest APIs are placed first.
chunks = [[] for _ in range(number_of_chunks)]

# Running number of questions per chunk (kept in sync with 'chunks').
# The size of a chunk must be measured by the questions it already contains, not by the
# question count of the API that is about to be placed.
chunk_question_counts = [0] * number_of_chunks

with open(output_path+datetime.now().strftime("%Y-%m-%dT%H-%M-%S")+".log.csv","w") as log_file:
    log_file.write("API Key;API Name;API Version Key; API Version;#Payloads;#Samples;#Questions;Out File\n")
    for result in sorted_results:
        # min() returns the first smallest chunk, i.e. ties are resolved in favour of the lower index
        smallest_chunk_index = min(range(number_of_chunks), key=chunk_question_counts.__getitem__)
        chunks[smallest_chunk_index] += result["samples"]
        chunk_question_counts[smallest_chunk_index] += result["question_cnt_per_api"]
        filename = str(smallest_chunk_index)+".json"

        # One log line per API version, including the file its samples are written to
        log_fields = (
            result["api_key"],
            result["api_name"],
            result["api_version_key"],
            result["api_version_name"],
            result["payload_cnt_per_api"],
            len(result["samples"]),
            result["question_cnt_per_api"],
            filename,
        )
        log_file.write(";".join(str(field) for field in log_fields)+"\n")

In [None]:
# Summary of the counters collected while generating the samples
statistics = [
    ("# Questions: ", question_cnt),
    ("# Payloads without context: ", cnt_no_context),
    ("# Properties without description: ", cnt_no_property_description),
    ("# Properties with too long descriptions (but could be truncated): ", cnt_too_long_property_description_but_truncated),
    ("# Properties with too long descriptions (could not be truncated): ", cnt_too_long_property_description),
    ("# Properties with too short descriptions: ", cnt_too_short_property_description),
    ("# Properties with too deep XPaths: ", cnt_too_deep_property_xpath),
    ("# Original samples that have been split into multiple samples: ", cnt_split_samples),
    ("# Errors: Answer not in context: ", cnt_answer_not_in_context),
]
for label, value in statistics:
    print(label, value)

In [None]:
# Shuffle the samples inside each chunk (in place, chunk by chunk)
for chunk_index in range(number_of_chunks):
    chunk = chunks[chunk_index]
    random.shuffle(chunk)

In [None]:
# Print the number of samples and questions per chunk, plus the chunk's share of all questions
for i in range(number_of_chunks):
    q_cnt = sum(len(sample.questionAnswers) for sample in chunks[i])
    # Avoid a ZeroDivisionError if no question has been generated at all
    percentage = (q_cnt/question_cnt)*100 if question_cnt else 0.0
    print(i,": ",len(chunks[i])," samples / ",q_cnt," questions (", percentage,"%)")

In [None]:
# Write every chunk to its own file ("<chunk index>.json"), one sample per line
for chunk_index in range(number_of_chunks):
    chunk_path = output_path+str(chunk_index)+".json"
    with open(chunk_path,"w") as chunk_file:
        chunk_file.writelines(str(sample)+"\n" for sample in chunks[chunk_index])

In [None]:
# Write the validation set: the samples of one dedicated chunk, one sample per line
validation_index = 2
validation_samples = chunks[validation_index]
with open(output_path+"validation.json","w") as validation_file:
    validation_file.writelines(str(sample)+"\n" for sample in validation_samples)

In [None]:
# Write the test set: the samples of one dedicated chunk, one sample per line
test_index = 9
test_samples = chunks[test_index]
with open(output_path+"test.json","w") as test_file:
    test_file.writelines(str(sample)+"\n" for sample in test_samples)

In [None]:
# Write the training set: all chunks that are reserved neither for validation nor for testing.
# The indices are derived from the settings instead of being hard-coded, so the training set
# stays complete when 'number_of_chunks', 'validation_index' or 'test_index' change.
train_indices = [i for i in range(number_of_chunks) if i not in (validation_index, test_index)]
train_samples = []
for i in train_indices:
    train_samples += chunks[i]

# Mix the samples of the different chunks
random.shuffle(train_samples)

with open(output_path+"train.json","w") as file:
    for sample in train_samples:
        file.write(str(sample))
        file.write("\n")