In [75]:
import sys
sys.path.append("..")

from tinylang.language import PCFG, QueryType
import pandas as pd
import plotnine as p9
from tqdm import tqdm
from copy import deepcopy
from collections import defaultdict

In [32]:
# Grammar shared by every cell below. The grouping comments describe the flags
# as their names suggest -- confirm exact semantics against tinylang.language.PCFG.
pcfg = PCFG(
    # grammar size / shape
    num_terminals=20,
    num_nonterminals=10,
    max_rhs_len=10,
    max_rules_per_nt=5,
    max_depth=10,
    # head placement within a production
    head_position="right",
    # query configuration (presumably leaves sibling queries as the only kind;
    # the recorded outputs below only contain QueryType.SIBLING)
    mask_nonquery=False,
    no_parent_queries=True,
    no_sibling_queries=False,
    no_child_queries=True,
    unambiguous_queries=True,
)

In [51]:
# Draw one sentence from the grammar. As displayed below, the result is a flat
# list of Node records (label, id, head_id, depth). `_sample` is underscore-
# prefixed, i.e. private by convention -- NOTE(review): no seed is set in the
# visible cells, so a fresh run will presumably yield a different sentence.
sentence = pcfg._sample()

In [54]:
# Show the sampled nodes; list positions are the indices used in the pairs below.
sentence

[Node(label=0, id=6, head_id=1, depth=1),
 Node(label=8, id=16, head_id=7, depth=2),
 Node(label=18, id=15, head_id=7, depth=2),
 Node(label=10, id=14, head_id=7, depth=2),
 Node(label=10, id=13, head_id=7, depth=2),
 Node(label=7, id=12, head_id=7, depth=2),
 Node(label=12, id=11, head_id=7, depth=2),
 Node(label=11, id=10, head_id=7, depth=2),
 Node(label=12, id=9, head_id=7, depth=2),
 Node(label=9, id=8, head_id=7, depth=2),
 Node(label=17, id=7, head_id=1, depth=2),
 Node(label=12, id=4, head_id=1, depth=1),
 Node(label=4, id=3, head_id=1, depth=1),
 Node(label=2, id=2, head_id=1, depth=1),
 Node(label=16, id=1, head_id=None, depth=1)]

We want the following (query, target) pairs, given as positions in `sentence`:
- (0, 13)
- (11, 13)
- (12, 13)
- (1, 10)
- (2, 10)
- (..., 10)

In [86]:
# Build, for every acceptable query type, the list of eligible (query, target)
# pairs. Both members of a pair are positions in `sentence`.
#
# NOTE(review): `split` is not defined in any cell visible here. It is only read
# when `pcfg.prohibited_pairs` is non-empty, in which case it must already hold
# "train", "dev" or "test" -- confirm where it is supposed to come from.
eligible_pairs = {}

# we will construct a list of eligible (query, target) pairs for each query type
for query_type in pcfg.acceptable_query_types:
    eligible_pairs[query_type] = []

    # default rightmost sibling is self (reset per query type so nothing leaks
    # from one query type into the next)
    rightmost_siblings = {q: q for q in range(len(sentence))}

    # traverse all pairs
    for target in range(len(sentence)):
        for query in range(len(sentence)):
            # first check train/test split
            if len(pcfg.prohibited_pairs) > 0:
                if split in ["train", "dev"]:
                    if (sentence[target].label, sentence[query].label) in pcfg.prohibited_pairs:
                        continue
                elif split == "test":
                    if (sentence[target].label, sentence[query].label) not in pcfg.prohibited_pairs:
                        continue

            # second check if it satisfies the relation
            if pcfg.is_relation(sentence, query_type, query, target):

                if query_type == QueryType.SIBLING:
                    # sibling queries keep a single target: the largest index
                    # among the related positions (or the query itself)
                    rightmost_siblings[query] = max(rightmost_siblings[query], target)

                else:
                    eligible_pairs[query_type].append((query, target))

    # Emit the collected sibling pairs. Without this step the rightmost-sibling
    # map is computed but never used and the SIBLING list stays empty.
    if query_type == QueryType.SIBLING:
        eligible_pairs[query_type] = [
            (query, rightmost_siblings[query]) for query in range(len(sentence))
        ]

3

In [87]:
# Inspect the (query, target) pairs collected per query type, before filtering.
eligible_pairs

{<QueryType.SIBLING: 2>: [(0, 13),
  (1, 9),
  (2, 9),
  (3, 9),
  (4, 9),
  (5, 9),
  (6, 9),
  (7, 9),
  (8, 9),
  (9, 9),
  (10, 13),
  (11, 13),
  (12, 13),
  (13, 13),
  (14, 14)]}

In [83]:
# Sentence shown again to cross-reference the pair indices above.
sentence

[Node(label=0, id=6, head_id=1, depth=1),
 Node(label=8, id=16, head_id=7, depth=2),
 Node(label=18, id=15, head_id=7, depth=2),
 Node(label=10, id=14, head_id=7, depth=2),
 Node(label=10, id=13, head_id=7, depth=2),
 Node(label=7, id=12, head_id=7, depth=2),
 Node(label=12, id=11, head_id=7, depth=2),
 Node(label=11, id=10, head_id=7, depth=2),
 Node(label=12, id=9, head_id=7, depth=2),
 Node(label=9, id=8, head_id=7, depth=2),
 Node(label=17, id=7, head_id=1, depth=2),
 Node(label=12, id=4, head_id=1, depth=1),
 Node(label=4, id=3, head_id=1, depth=1),
 Node(label=2, id=2, head_id=1, depth=1),
 Node(label=16, id=1, head_id=None, depth=1)]

In [88]:
# Restrict the eligible pairs to rightmost occurrences of each label.

# Step 1: for every label, record the last position at which it occurs.
rightmost_types = dict()
for i, node in enumerate(sentence):
    seen_so_far = rightmost_types.get(node.label, 0)
    rightmost_types[node.label] = max(seen_so_far, i)


def _is_rightmost(index):
    """Return True when `index` is the last position in `sentence` with its label."""
    return rightmost_types[sentence[index].label] == index


# Step 2: drop pairs that fail the rightmost requirement.
#   - sibling queries: query AND target must both be rightmost instances
#   - every other query type: only the query must be a rightmost instance
for query_type in pcfg.acceptable_query_types:
    candidate_pairs = eligible_pairs[query_type]
    if query_type == QueryType.SIBLING:
        surviving_pairs = [
            (q, t) for q, t in candidate_pairs
            if _is_rightmost(q) and _is_rightmost(t)
        ]
    else:
        surviving_pairs = [
            (q, t) for q, t in candidate_pairs
            if _is_rightmost(q)
        ]
    eligible_pairs[query_type] = surviving_pairs

In [64]:
# Sentence shown once more for reference while reading the filtered pairs below.
sentence

[Node(label=0, id=6, head_id=1, depth=1),
 Node(label=8, id=16, head_id=7, depth=2),
 Node(label=18, id=15, head_id=7, depth=2),
 Node(label=10, id=14, head_id=7, depth=2),
 Node(label=10, id=13, head_id=7, depth=2),
 Node(label=7, id=12, head_id=7, depth=2),
 Node(label=12, id=11, head_id=7, depth=2),
 Node(label=11, id=10, head_id=7, depth=2),
 Node(label=12, id=9, head_id=7, depth=2),
 Node(label=9, id=8, head_id=7, depth=2),
 Node(label=17, id=7, head_id=1, depth=2),
 Node(label=12, id=4, head_id=1, depth=1),
 Node(label=4, id=3, head_id=1, depth=1),
 Node(label=2, id=2, head_id=1, depth=1),
 Node(label=16, id=1, head_id=None, depth=1)]

In [89]:
# Inspect the pairs that survived the rightmost-instance filter.
eligible_pairs

{<QueryType.SIBLING: 2>: [(0, 13),
  (1, 9),
  (2, 9),
  (4, 9),
  (5, 9),
  (7, 9),
  (9, 9),
  (10, 13),
  (11, 13),
  (12, 13),
  (13, 13),
  (14, 14)]}

In [90]:
# With the (query, target) pairs settled per query type, derive the distinct
# query positions and distinct target positions so they can be handed to the
# sampler. Query types left with no pairs are removed from `eligible_pairs`
# and get no entry in either mapping.
eligible_queries = {}
eligible_targets = {}
for query_type in pcfg.acceptable_query_types:
    remaining_pairs = eligible_pairs[query_type]
    if len(remaining_pairs) > 0:
        eligible_queries[query_type] = list({q for q, _ in remaining_pairs})
        eligible_targets[query_type] = list({t for _, t in remaining_pairs})
    else:
        del eligible_pairs[query_type]