# Get Sentences with SPIKE's query API

The following script takes SPIKE queries (plus any word lists they reference), runs each query on SPIKE, and downloads the sentences that match it.

In [43]:
import csv
import json
import os
from collections import defaultdict

import requests
from tqdm import tqdm

### Materials to prepare
This script assumes the following materials:
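1. A JSON file with the desired patterns, at the path set by `PATTERNS_FILE` below (by default `patterns.json` in the notebook's directory), in the following format (the `#` comments are explanations only — JSON does not allow comments, so omit them in the real file):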
```
{
    "0": {
        "query": "$[w={roles}]guitarist <E>musician:[e=PERSON]John plays the piano",
        "type": "syntactic", # other options are boolean or token 
        "case_strategy": "ignore", # other options are exact or smart 
        "label": "positive",
        "lists": ["roles"] # should match the name within brackets in the query. Leave empty list if irrelevant.
    },
    "1": {
        ...
    }...
}
```
2. Word lists stored as text files under `../data/lists`, one item per line. Each file must be named `<name>.txt`, where `<name>` matches an entry in a pattern's `lists` field (i.e. the name within brackets in its query). You can download a list straight from SPIKE or create one yourself.

In [77]:
# Output root: matches are saved as SPIKE_MATCHES_DIR/<label>/<pattern id>.json.
SPIKE_MATCHES_DIR = '../data/spike_matches'

# JSON file describing the queries to run (see "Materials to prepare").
PATTERNS_FILE = 'patterns.json'

# Directory of word-list files named <list name>.txt (despite the name, a directory).
LISTS_FILE = '../data/lists'

# Number of matches requested per pattern; "negative" patterns request 5x this.
LIMIT = 1_000

# Entity label to keep from each matched sentence. Use '' if the capture is not an entity.
ENTITY_TYPE = 'PERSON'


In [80]:
def read_patterns_from_file(path):
    """Load the patterns JSON file.

    Args:
        path: Path to the patterns JSON file (see "Materials to prepare").

    Returns:
        The decoded JSON object. ``main`` iterates it with ``.items()``, so it
        is expected to be a dict of pattern id -> pattern spec.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 (JSON's standard encoding) instead of the
    # platform default, which differs e.g. on Windows and breaks non-ASCII queries.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def write_pattern_matches(pattern):
    """Run ``pattern`` on SPIKE and hand back its matches.

    Despite the name, nothing is written here: this simply delegates to
    ``search_single_query`` and the caller is responsible for saving the result.

    Returns:
        A list of match dicts, or None when no stream location was obtained
        for the query.
    """
    return search_single_query(pattern)

        
def search_single_query(pattern, spike_url="https://spike.staging.apps.allenai.org"):
    """Run a single pattern against SPIKE and collect all of its matches.

    Args:
        pattern: One pattern spec from the patterns file.
        spike_url: Base URL of the SPIKE deployment to query. Defaults to the
            staging instance that used to be hard-coded here.

    Returns:
        A list of match dicts (as produced by ``search_stream``), or None when
        ``get_stream_location`` did not return a stream location.
    """
    stream_location = get_stream_location(spike_url, pattern)
    if not stream_location:
        return None
    return list(search_stream(spike_url, stream_location, pattern))


def get_lists(pattern, lists_dir=None):
    """Load the word lists referenced by ``pattern``.

    Args:
        pattern: Pattern spec; its optional "lists" entry names the word lists
            used by the query (the names within brackets, e.g. ``{roles}``).
        lists_dir: Directory containing one ``<name>.txt`` file per list.
            Defaults to the ``LISTS_FILE`` config value (previously the path
            was hard-coded here and the config value was ignored).

    Returns:
        A ``defaultdict(list)`` mapping each list name to its items: one item
        per non-blank line, with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: if a referenced list file does not exist.
    """
    if lists_dir is None:
        lists_dir = LISTS_FILE
    lists = defaultdict(list)
    # Missing or empty "lists" means the query uses no lists.
    for name in pattern.get("lists") or []:
        with open(os.path.join(lists_dir, f"{name}.txt"), "r", encoding="utf-8") as f:
            for line in f:
                item = line.strip()
                # Blank lines would otherwise become "" items sent to SPIKE.
                if item:
                    lists[name].append(item)
    return lists
    

def get_stream_location(spike_url, pattern, dataset="wikipedia", timeout=60):
    """Submit ``pattern`` to SPIKE's multi-search endpoint.

    Args:
        spike_url: Base URL of the SPIKE deployment.
        pattern: Pattern spec with "type", "query", "case_strategy" and
            (optionally) "lists" keys.
        dataset: Name of the SPIKE dataset to search.
        timeout: Seconds ``requests`` waits to connect / between received
            bytes before raising.

    Returns:
        The value of the ``stream-location`` response header (appended to
        ``spike_url`` to fetch results), or None when the response lacks
        that header.
    """
    url = spike_url + "/api/3/multi-search/query"
    data = {
        "queries": {
            "main": {
                pattern["type"]: pattern["query"]
            }
        },
        "data_set_name": dataset,
        "context": {
            "lists": get_lists(pattern),
            "tables": {},
            "case_strategy": pattern["case_strategy"],
            "attempt_fuzzy": False,
        },
    }
    # Without a timeout a stalled server would hang the notebook indefinitely.
    response = requests.post(url, json=data, timeout=timeout)
    stream_location = response.headers.get("stream-location")
    if stream_location is None:
        # Report the status so a failed query is not mistaken for "no matches".
        print(f"No stream location from SPIKE (HTTP {response.status_code}) "
              f"for query {pattern['query']!r}")
    return stream_location


def search_stream(spike_url, stream_location, pattern, timeout=60):
    """Fetch and parse the matches of a submitted SPIKE query.

    Args:
        spike_url: Base URL of the SPIKE deployment.
        stream_location: Value returned by ``get_stream_location``.
        pattern: The pattern spec; its "label" decides how many matches to request.
        timeout: Seconds ``requests`` waits to connect / between received bytes.

    Yields:
        One dict per match with keys "words", "captures", "sentence_index",
        "highlights" and "entities". "entities" holds the sentence entities
        whose label starts with ENTITY_TYPE, or [] when ENTITY_TYPE is empty.

    Raises:
        requests.HTTPError: if SPIKE answers with an error status.
    """
    # Negative patterns request 5x as many matches as positive ones.
    limit = LIMIT * 5 if pattern["label"] == "negative" else LIMIT
    # (The original URL repeated include_sentence=true; once is enough.)
    stream_url = f"{spike_url}{stream_location}?include_sentence=true&limit={limit}"
    response = requests.get(stream_url, timeout=timeout)
    # Fail with a clear HTTP error instead of a confusing JSON decode error.
    response.raise_for_status()
    # One JSON object per line; skip blank lines, which json.loads rejects.
    results = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    if not results:
        print(f"Couldn't find any results for pattern {pattern}")
    for result in results:
        # Bookkeeping records carry no sub_matches; everything else is a match.
        if result['kind'] in ('continuation_url', 'tip'):
            continue
        data = result['value']['sub_matches']['main']
        if ENTITY_TYPE:
            entities = [ent for ent in data['sentence']['entities']
                        if ent["label"].startswith(ENTITY_TYPE)]
        else:
            entities = []
        yield {
            'words': data['words'],
            'captures': data['captures'],
            'sentence_index': data['sentence_index'],
            'highlights': data['highlights'],
            'entities': entities,
        }

        
def main():
    """Run every pattern in PATTERNS_FILE and save its matches as JSON.

    Matches for pattern ``idx`` with label ``label`` are written to
    ``SPIKE_MATCHES_DIR/<label>/<idx>.json`` (directories are created as
    needed). Every pattern gets a valid JSON file: a pattern with no matches,
    or whose query returned no stream location, is saved as ``[]`` so stale
    results from earlier runs are overwritten.
    """
    patterns = read_patterns_from_file(PATTERNS_FILE)
    for idx, pattern in tqdm(patterns.items()):
        # Query before opening the output file so an exception mid-query does
        # not leave behind an empty (invalid JSON) file.
        matches = write_pattern_matches(pattern)
        # tqdm.write keeps messages from breaking the progress bar.
        if matches is None:
            tqdm.write(f"pattern {idx}: query failed (no stream location)")
        elif not matches:
            tqdm.write(f"pattern {idx}: no matches")
        label_dir = os.path.join(SPIKE_MATCHES_DIR, pattern["label"])
        os.makedirs(label_dir, exist_ok=True)
        with open(os.path.join(label_dir, f"{idx}.json"), "w", encoding="utf-8") as f:
            json.dump(matches or [], f, indent=4)

In [81]:
# Query SPIKE for every pattern in PATTERNS_FILE and save the matches under SPIKE_MATCHES_DIR.
main()

100%|███████████████████████████████████████████████████████████| 6/6 [00:21<00:00,  3.62s/it]
