Scratch space for creating synthetic temporal knowledge graph (TKG) data containing temporal patterns

In [22]:
import config

In [339]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy

from collections import defaultdict
from itertools import combinations, product

import os
import random
import re

Steps:
1. Start with a random graph. Parameters: # of entities, # of relations, # of time windows, average edge density per entity.
2. Choose a set of patterns, starting with 3-hop, then 2-hop, then 1-hop patterns, and pick time-stamp differences (time lags) for each. Don't allow subset patterns (e.g., a chosen 2-hop pattern can't be contained in a chosen 3-hop pattern).
3. Apply the patterns iteratively. Parameters: probability of random wiring, probability of not applying a pattern (or of interrupted pattern application).

# Utilities

In [201]:
class TemporalPattern():
    """ A temporal pattern over a TKG: antecedent triples which, when observed with the
    given time lags between consecutive elements, are followed by a consequence triple.
    """
    # Matches a time-lag annotation such as 't2=t1+(1, 4)' inside a label and captures the
    # (min, max) lag. Only non-negative integer lags are recognised.
    _PAT_TIME_LAG = re.compile(r't\d+=t\d+\+\((\d+),\s(\d+)\)')

    def __init__(
        self,
        antecedent: 'List[Tuple[int,int,int]]' = None,
        consequence: 'Tuple[int,int,int]' = None,
        time_lags: 'List[Tuple[float,float]]' = None,
        n_hops: int = None,
    ):
        """ Defines a temporal pattern over our TKG
        Args:
            antecedent (List[Tuple[int,int,int]]): Antecedent(s) for the pattern
                in the form of subject, relation, object ID triples. Defaults to a new
                empty list.
            consequence (Tuple[int,int,int]): Consequence for the pattern in the form
                of a subject, relation, object ID triple
            time_lags (List[Tuple[float,float]]): Time lags with which antecedents and
                consequences can occur validly in the pattern. Must be an iterable of
                length equal to the antecedent. The i-th element of time_lags is a
                tuple of the form (minimum # of time windows since i-th antecedent,
                maximum # of time windows since i-th antecedent) for the i+1-th
                antecedent (or consequence if we have iterated over all antecedents).
                Defaults to a new empty list.
            n_hops (int): Number of hops (antecedents) in the pattern; stored as given.
        """
        # None sentinels give every instance its own lists instead of one shared default
        self.antecedent = [] if antecedent is None else antecedent
        self.consequence = consequence
        self.time_lags = [] if time_lags is None else time_lags
        self.n_hops = n_hops

        # Regex used to parametrize a pattern from a text label (precompiled, shared)
        self.pat_time_lag = self._PAT_TIME_LAG

    @staticmethod
    def _format_lag(time_lag) -> str:
        """ Format a (min, max) lag as '(min, max)'.

        Each bound is formatted with str() on its own so NumPy scalars render as plain
        numbers; formatting the tuple directly uses repr() of its items, which on NumPy 2
        yields e.g. 'np.int64(3)' and cannot be parsed back by from_label.
        """
        return f'({time_lag[0]}, {time_lag[1]})'

    def __label__(self) -> str:
        """ Return a string label for the pattern like <ANTECEDENT> -> <CONSEQUENCE>,
        e.g. '(5, 4, 0, t1) & (3, 0, 3, t2=t1+(0, 4)) -> (0, 3, 3, t3=t2+(1, 5))'
        """
        parts = []
        for idx, triple in enumerate(self.antecedent):
            time_str = f't{idx+1}'
            if idx > 0:
                # Lag from antecedent idx-1 to antecedent idx
                time_str += f'=t{idx}+{self._format_lag(self.time_lags[idx-1])}'
            parts.append(f'({triple[0]}, {triple[1]}, {triple[2]}, {time_str})')
        n_ant = len(self.antecedent)
        label = ' & '.join(parts) + ' -> '
        label += f'({self.consequence[0]}, {self.consequence[1]}, {self.consequence[2]}, '
        # The last lag runs from the final antecedent to the consequence
        label += f't{n_ant+1}=t{n_ant}+{self._format_lag(self.time_lags[-1])})'
        return label

    def __triples__(self) -> 'List[Tuple[int,int,int]]':
        """ Return antecedent and consequence as list of triples (excluding time lag information)
        """
        return self.antecedent + [self.consequence]

    def __quadruples__(self) -> 'List[Tuple[int,int,int,Tuple[float,float]]]':
        """ Return antecedent and consequence as list of quadruples (including time lag information).
        The first antecedent has no preceding element, so its lag slot is the empty tuple ().
        """
        return [
            triple + (time_lag,) for triple, time_lag in zip(self.__triples__(), [()]+self.time_lags)
        ]

    def from_label(self, label: str):
        """ Parametrize a pattern from a string label (as generated by __label__ method).

        Overwrites antecedent, consequence and time_lags, and sets n_hops to the number
        of antecedents. Time lags are read as ints from the 't<i>=t<j>+(min, max)'
        annotations, in label order.

        Raises:
            ValueError: if label does not contain exactly one '->'.
            ValueError / SyntaxError: from ast.literal_eval if a triple is not a literal.
        """
        import ast

        antecedent_text, consequence_text = label.split('->')
        antecedent_texts = antecedent_text.split('&')
        # Get time lag information (search each segment once)
        time_lags = []
        for text in antecedent_texts + [consequence_text]:
            match = self.pat_time_lag.search(text)
            if match:
                time_lags.append((int(match.group(1)), int(match.group(2))))

        def parse_triple(text):
            # Keep the first three comma-separated fields (subject, relation, object) and
            # read them as a literal tuple; literal_eval, unlike eval, cannot execute code
            fields = text.strip(' ()').split(',')[:3]
            return ast.literal_eval('(' + ','.join(fields) + ')')

        self.antecedent = [parse_triple(text) for text in antecedent_texts]
        self.consequence = parse_triple(consequence_text)
        self.time_lags = time_lags
        self.n_hops = len(self.antecedent)

In [160]:
def is_subpattern(subpattern: 'List[Tuple]', patterns: 'List[Tuple]') -> bool:
    """ Return True when subpattern occurs as a contiguous run inside any member of patterns.

    Every window of length len(subpattern) is cut from every pattern, converted to a tuple,
    and collected into a set that is then checked for tuple(subpattern).
    """
    width = len(subpattern)
    windows = {
        tuple(candidate[start:start + width])
        for candidate in patterns
        for start in range(len(candidate) - width + 1)
    }
    return tuple(subpattern) in windows

In [161]:
def entities_intersect(entities1: 'List[int]', entities2: 'List[int]') -> bool:
    """ Return True if at least one entity appears in both entities1 and entities2.
    """
    return not set(entities1).isdisjoint(entities2)

In [162]:
def entities_connect_triples(
    e1: int, e2: int, triple1: 'Tuple[int,int,int]', triple2: 'Tuple[int,int,int]'
) -> bool:
    """ Return True if one of e1/e2 is an endpoint (subject or object) of triple1 and the
    other is an endpoint of triple2, in either assignment. Relations are never compared.
    """
    ends1 = {triple1[0], triple1[2]}
    ends2 = {triple2[0], triple2[2]}
    return (e1 in ends1 and e2 in ends2) or (e2 in ends1 and e1 in ends2)

In [163]:
def entities_connect_components(
    e1: int, e2: int, comp1: 'List[Tuple[int,int,int]]', comp2: 'List[Tuple[int,int,int]]'
) -> bool:
    """ Return True if e1 and e2 connect some triple of comp1 with some triple of comp2,
    as judged by entities_connect_triples over every (comp1, comp2) triple pair.
    """
    return any(
        entities_connect_triples(e1, e2, left, right)
        for left, right in product(comp1, comp2)
    )

In [164]:
def combinations_of_increasing_size(iterable, a, b):
    """ Yield combinations of iterable of size from a to b, inclusive.

    Combinations are yielded in order of increasing size and, within one size, in
    itertools.combinations order.
    """
    # Materialise once so one-shot iterators can be reused for every size
    pool = tuple(iterable)
    for size in range(a, b+1):
        yield from combinations(pool, size)

In [165]:
def force_swap_to_entities(
    idxs_to_force: 'List[int]',
    sampled_entities: 'List[int]',
    swap_to_entities: 'List[int]',
    seed: int = None,
) -> None:
    """ Force at least one of idxs_to_force to be switched, in sampled_entities, to one of
    swap_to_entities. Note, alters sampled_entities in place; swap_to_entities is only read.

    A non-empty subset of idxs_to_force is picked uniformly among all non-empty subsets,
    and every picked position is overwritten with an entity drawn (with replacement) from
    swap_to_entities.

    Args:
        idxs_to_force (List[int]): Positions of sampled_entities eligible for swapping
        sampled_entities (List[int]): Entity ids, modified in place
        swap_to_entities (List[int]): Candidate replacement entity ids
        seed (int): Seed for the module-level `random` generator, default None
    Raises:
        IndexError: if idxs_to_force or swap_to_entities is empty (random.choice on an
            empty sequence)
    """
    # Re-seeds the global `random` generator; seed=None re-seeds from system entropy
    random.seed(seed)
    to_swap = random.choice(
        list(combinations_of_increasing_size(idxs_to_force, 1, len(idxs_to_force)))
    )
    for idx in to_swap:
        sampled_entities[idx] = random.choice(swap_to_entities)

In [166]:
def force_connect_components(
    idxs_to_force: 'List[int,int]',
    sampled_entities: 'List[int]',
    comp1: 'List[int]',
    comp2: 'List[int]',
    seed: int = None,
) -> None:
    """ Force at least one of idxs_to_force to be switched so as to connect comp1
    and comp2. Note, alters sampled_entities in place.

    If one forced position already holds an entity of one component, the other position
    is redrawn from the opposite component. Otherwise both positions are redrawn, one from
    each component, in a random order.

    Args:
        idxs_to_force (List[int]): Exactly two positions in sampled_entities
        sampled_entities (List[int]): Entity ids, modified in place
        comp1 (List[int]): Entity ids of the first component
        comp2 (List[int]): Entity ids of the second component
        seed (int): Seed for the module-level `random` generator, default None
    Raises:
        ValueError: if idxs_to_force does not have length 2
    """
    if len(idxs_to_force) != 2:
        raise ValueError(
            'force_connect_components only implemented for idxs_to_force of length 2'
        )
    # Re-seeds the global `random` generator; seed=None re-seeds from system entropy
    random.seed(seed)
    first, second = idxs_to_force
    comps = [comp1, comp2]
    if sampled_entities[first] in comps[0]:
        sampled_entities[second] = random.choice(comps[1])
    elif sampled_entities[first] in comps[1]:
        sampled_entities[second] = random.choice(comps[0])
    elif sampled_entities[second] in comps[0]:
        sampled_entities[first] = random.choice(comps[1])
    elif sampled_entities[second] in comps[1]:
        sampled_entities[first] = random.choice(comps[0])
    else:
        # Neither position touches a component: pick a random orientation, redraw both
        random.shuffle(comps)
        sampled_entities[first] = random.choice(comps[0])
        sampled_entities[second] = random.choice(comps[1])

In [167]:
def create_time_lag_tuples(
    time_lags: 'List[Tuple(float,float)]',
    antecedent: 'List[Tuple[int,int,int]]',
) -> 'List[Tuple[float,float]]':
    """ Create time_lag_tuples used to instantiate patterns. Contains logic to prohibit: identical
    consecutive antecedents from having 0 time lag between them, the consequence from having 0
    lag from the last antecedent.

    Args:
        time_lags: One (min, max) pair per antecedent. The i-th pair is the lag from the i-th
            antecedent to the next antecedent (or to the consequence for the last one). Each
            bound is either a number or a zero-argument callable returning one.
        antecedent: Antecedent triples of the pattern being built.
    Returns:
        List of (lag_min, lag_max) tuples with callables evaluated, lag_min raised to at
        least 1 where required, and lag_max raised to at least lag_min.
    Raises:
        ValueError: if len(time_lags) != len(antecedent).
    """
    if len(time_lags) != len(antecedent):
        raise ValueError(
            f'Expected {len(antecedent)} time lags (one per antecedent), got {len(time_lags)}'
        )
    last_idx = len(antecedent) - 1
    time_lag_tuples = []
    for idx, time_lag in enumerate(time_lags):
        lag_min, lag_max = time_lag[0], time_lag[1]
        # Bounds may be sampled lazily: callables are called, anything else (int, float,
        # NumPy scalar, ...) is used as-is
        if callable(lag_min):
            lag_min = lag_min()
        if callable(lag_max):
            lag_max = lag_max()
        # Prohibit identical consecutive antecedents from having 0 lag_min between them
        if idx < last_idx and antecedent[idx] == antecedent[idx+1]:
            lag_min = max(1, lag_min)
        # Prohibit consequence from having 0 lag_min from the last antecedent
        if idx == last_idx:
            lag_min = max(1, lag_min)
        # Enforce that lag_max is no smaller than lag_min
        time_lag_tuples.append((lag_min, max(lag_min, lag_max)))
    return time_lag_tuples

# Define basic information

In [168]:
n_ents = 10
n_rels = 5
n_tws = 10
# Identity lookup tables: each entity / relation / time window is named by its own id
entity2id, relation2id, time2id = (
    pd.DataFrame({'name': range(n_items), 'id': range(n_items)})
    for n_items in (n_ents, n_rels, n_tws)
)

# Define patterns

## 1-hop

In [169]:
def create_1_hop_pattern(
    entity2id: pd.DataFrame,
    relation2id: pd.DataFrame,
    time2id: pd.DataFrame,
    time_lags: 'List[Tuple(float,float)]',
    seed: int = None,
) -> TemporalPattern:
    r""" Create a 1-hop temporal pattern
    (e1, r1, e2, t1) → (e3, r2, e4, t2)
    Such that:
        - e3 or e4 \in {e1, e2}
        - t2 > t1
    Args:
        entity2id (pd.DataFrame): Dataframe with entity ids in column 'id'
        relation2id (pd.DataFrame): Dataframe with relation ids in column 'id'
        time2id (pd.DataFrame): Dataframe with time window ids in column 'id' (currently
            unused)
        time_lags (List[Tuple[float, float]]): One (min, max) lag, from the antecedent
            to the consequence. Each bound is a number or a zero-argument callable
            returning one.
        seed (int): Random seed, default None
    Returns:
        TemporalPattern: pattern with one antecedent and n_hops=1
    """
    random.seed(seed)
    # Randomly select all initial entities and relations to be used (with replacement)
    sampled_entities = entity2id.sample(4, replace=True, random_state=seed)['id'].tolist()
    sampled_relations = relation2id.sample(2, replace=True, random_state=seed)['id'].tolist()
    # Define antecedent
    antecedent = [
        (
            sampled_entities[0],  # e1
            sampled_relations[0],  # r1
            sampled_entities[1],  # e2
        )
    ]
    # Consequence must share at least one entity with the antecedent. `not` is required
    # here: `~` on a bool gives -1/-2, which is always truthy.
    if not entities_intersect(sampled_entities[2:4], sampled_entities[:2]):
        # Force at least one of the consequence entities to switch to an antecedent entity
        force_swap_to_entities([2, 3], sampled_entities, sampled_entities[:2], seed)
    consequence = (
        sampled_entities[2],  # e3
        sampled_relations[1],  # r2
        sampled_entities[3],  # e4
    )

    time_lag_tuples = create_time_lag_tuples(time_lags, antecedent)

    return TemporalPattern(
        antecedent=antecedent,
        consequence=consequence,
        time_lags=time_lag_tuples,
        n_hops=1,
    )

In [170]:
# Example: a 1-hop pattern with a fixed (0, 4) lag window and seed 0.
# Lag bounds may also be callables, e.g.
# [(lambda: random.randint(1, 4), lambda: random.randint(5, 10))]
pat = create_1_hop_pattern(
    entity2id, relation2id, time2id,
    time_lags=[(0, 4)],
    seed=0,
)
pat.__label__()

'(5, 4, 0, t1) -> (3, 0, 0, t2=t1+(1, 4))'

## 2-hop

In [171]:
def create_2_hop_pattern(
    entity2id: pd.DataFrame,
    relation2id: pd.DataFrame,
    time2id: pd.DataFrame,
    time_lags: 'List[Tuple(float,float)]',
    seed: int = None,
) -> TemporalPattern:
    r""" Create a 2-hop temporal pattern
    (e1, r1, e2, t1) & (e3, r2, e4, t2) → (e5, r3, e6, t3)
    Such that:
        - (e3 or e4 \in {e1, e2}) and (e5 or e6 \in {e1, e2, e3, e4}) OR
        - ~(e3 and e4 \in {e1, e2}) and ((e5 \in {e1, e2} and e6 \in {e3, e4}) or
            (e5 \in {e3, e4} and e6 \in {e1, e2}))
        - t3 > t2 >= t1
        - (e1, r1, e2, t1) != (e3, r2, e4, t2) (identical antecedents get a minimum lag
            of 1 between them from create_time_lag_tuples)

    Args:
        entity2id (pd.DataFrame): Dataframe with entity ids in column 'id'
        relation2id (pd.DataFrame): Dataframe with relation ids in column 'id'
        time2id (pd.DataFrame): Dataframe with time window ids in column 'id' (currently
            unused)
        time_lags (List[Tuple[float, float]]): One (min, max) lag per antecedent: first
            antecedent -> second antecedent, second antecedent -> consequence. Each bound
            is a number or a zero-argument callable returning one.
        seed (int): Random seed, default None
    Returns:
        TemporalPattern: pattern with two antecedents and n_hops=2
    """
    random.seed(seed)
    # Randomly select all initial entities and relations to be used (with replacement)
    sampled_entities = entity2id.sample(6, replace=True, random_state=seed)['id'].tolist()
    sampled_relations = relation2id.sample(3, replace=True, random_state=seed)['id'].tolist()
    # Define antecedent
    antecedent = [
        (
            sampled_entities[0],  # e1
            sampled_relations[0],  # r1
            sampled_entities[1],  # e2
        ),
        (
            sampled_entities[2],  # e3
            sampled_relations[1],  # r2
            sampled_entities[3],  # e4
        )
    ]
    # Second antecedent determines how we define the consequence
    if entities_intersect(sampled_entities[2:4], sampled_entities[:2]):
        # At least one entity intersects with the first antecedent, so the consequence
        # must have at least one of its entities in some antecedent. `not`, not `~`:
        # `~` on a bool is always truthy.
        if not entities_intersect(sampled_entities[4:6], sampled_entities[:4]):
            # Enforce that one or more consequence entity be chosen from the antecedents
            force_swap_to_entities([4, 5], sampled_entities, sampled_entities[:4], seed)
    else:
        # Non-intersecting first and second antecedents: the consequence must connect them
        if not entities_connect_triples(
            sampled_entities[4], sampled_entities[5], antecedent[0], antecedent[1]
        ):
            force_connect_components(
                [4, 5], sampled_entities, sampled_entities[:2], sampled_entities[2:4], seed
            )
    # Define consequence
    consequence = (
        sampled_entities[4],  # e5
        sampled_relations[2],  # r3
        sampled_entities[5],  # e6
    )

    time_lag_tuples = create_time_lag_tuples(time_lags, antecedent)

    return TemporalPattern(
        antecedent=antecedent,
        consequence=consequence,
        time_lags=time_lag_tuples,
        n_hops=2,
    )

In [172]:
# Example: a 2-hop pattern with lag windows (0, 4) and (0, 5) and seed 0.
# Lag bounds may also be callables, e.g.
# [(lambda: random.randint(1, 4), lambda: random.randint(5, 10))]
pat = create_2_hop_pattern(
    entity2id, relation2id, time2id,
    time_lags=[(0, 4), (0, 5)],
    seed=0,
)
pat.__label__()

'(5, 4, 0, t1) & (3, 0, 3, t2=t1+(0, 4)) -> (0, 3, 3, t3=t2+(1, 5))'

## 3-hop

In [173]:
def create_3_hop_pattern(
    entity2id: pd.DataFrame,
    relation2id: pd.DataFrame,
    time2id: pd.DataFrame,
    time_lags: 'List[Tuple(float,float)]',
    seed: int = None,
) -> TemporalPattern:
    r""" Create a 3-hop temporal pattern
    (e1, r1, e2, t1) & (e3, r2, e4, t2) & (e5, r3, e6, t3) → (e7, r4, e8, t4)
    Such that:
        - All antecedents intersect: (e3 or e4 \in {e1, e2}) and (e5 or e6 \in {e1, e2, e3, e4})
            and (e7 or e8 \in {e1, e2, e3, e4, e5, e6})
        - Second antecedent does not intersect first: ~(e3 and e4 \in {e1, e2}) and
            ((e5 \in {e1, e2} and e6 \in {e3, e4}) or (e5 \in {e3, e4} and e6 \in {e1, e2})) and
            (e7 or e8 \in {e1, e2, e3, e4, e5, e6})
        - Third antecedent does not intersect first two: (e3 or e4 \in {e1, e2}) and
            ~(e5 and e6 \in {e1, e2, e3, e4}) and ((e7 \in {e1, e2, e3, e4} and e8 \in {e5, e6})
            or (e7 \in {e5, e6} and e8 \in {e1, e2, e3, e4}))
        - t4 > t3 >= t2 >= t1
        - (e1, r1, e2, t1) != (e3, r2, e4, t2) and (e3, r2, e4, t2) != (e5, r3, e6, t3):
            identical consecutive antecedents get a minimum lag of 1 from
            create_time_lag_tuples. NOTE(review): (e1, r1, e2, t1) != (e5, r3, e6, t3) is
            not enforced when both minimum lags are 0.

    NOTE(review): the code always forces the third antecedent to share an entity with the
    first two, so the "third antecedent does not intersect first two" case is never
    generated (its branch below is unreachable). Confirm which behavior is intended.

    Args:
        entity2id (pd.DataFrame): Dataframe with entity ids in column 'id'
        relation2id (pd.DataFrame): Dataframe with relation ids in column 'id'
        time2id (pd.DataFrame): Dataframe with time window ids in column 'id' (currently
            unused)
        time_lags (List[Tuple[float, float]]): One (min, max) lag per antecedent, each
            measured to the next antecedent (or to the consequence for the last one). Each
            bound is a number or a zero-argument callable returning one.
        seed (int): Random seed, default None
    Returns:
        TemporalPattern: pattern with three antecedents and n_hops=3
    """
    random.seed(seed)
    # Randomly select all initial entities and relations to be used (with replacement)
    sampled_entities = entity2id.sample(8, replace=True, random_state=seed)['id'].tolist()
    sampled_relations = relation2id.sample(4, replace=True, random_state=seed)['id'].tolist()
    # Third antecedent must intersect at least one prior antecedent. `not`, not `~`:
    # `~` on a bool is always truthy.
    if not entities_intersect(sampled_entities[:4], sampled_entities[4:6]):
        force_swap_to_entities([4, 5], sampled_entities, sampled_entities[:4], seed)
    # Define antecedent
    antecedent = [
        (
            sampled_entities[0],  # e1
            sampled_relations[0],  # r1
            sampled_entities[1],  # e2
        ),
        (
            sampled_entities[2],  # e3
            sampled_relations[1],  # r2
            sampled_entities[3],  # e4
        ),
        (
            sampled_entities[4],  # e5
            sampled_relations[2],  # r3
            sampled_entities[5],  # e6
        ),
    ]
    # Second and third antecedents determine how we define the consequence
    second_meets_first = entities_intersect(sampled_entities[2:4], sampled_entities[:2])
    third_meets_prior = entities_intersect(sampled_entities[4:6], sampled_entities[:4])
    if second_meets_first and third_meets_prior:
        # Antecedents are connected: the consequence must have at least one of its
        # entities in some antecedent
        if (sampled_entities[6] not in sampled_entities[:6]) and \
                (sampled_entities[7] not in sampled_entities[:6]):
            force_swap_to_entities([6, 7], sampled_entities, sampled_entities[:6], seed)
    elif not second_meets_first:
        # Second antecedent does not intersect the first. The consequence condition depends
        # on how the third antecedent meets the prior antecedents.
        if entities_connect_triples(
            sampled_entities[4], sampled_entities[5], antecedent[0], antecedent[1]
        ):
            # Third antecedent connects prior antecedents: the consequence must have at
            # least one of its entities in some antecedent
            if not entities_intersect(sampled_entities[6:8], sampled_entities[:6]):
                force_swap_to_entities([6, 7], sampled_entities, sampled_entities[:6], seed)
        elif entities_intersect(sampled_entities[4:6], sampled_entities[:2]):
            # Third antecedent connects with the first: the consequence must connect
            # {first, third} with the second
            if not entities_connect_components(
                sampled_entities[6], sampled_entities[7],
                [antecedent[0], antecedent[2]], [antecedent[1]]
            ):
                force_connect_components(
                    [6, 7], sampled_entities,
                    sampled_entities[:2]+sampled_entities[4:6],
                    sampled_entities[2:4],
                    seed,
                )
        else:
            # Third antecedent connects with the second: the consequence must connect
            # the first with {second, third}
            if not entities_connect_components(
                sampled_entities[6], sampled_entities[7],
                [antecedent[0]], [antecedent[1], antecedent[2]]
            ):
                force_connect_components(
                    [6, 7], sampled_entities,
                    sampled_entities[:2],
                    sampled_entities[2:6],
                    seed,
                )
    else:
        # Third antecedent does not intersect the (intersecting) prior antecedents, so the
        # consequence must connect them.
        # NOTE(review): unreachable while the third antecedent is forced to intersect above.
        if not entities_connect_components(
            sampled_entities[6], sampled_entities[7],
            [antecedent[0], antecedent[1]], [antecedent[2]]
        ):
            force_connect_components(
                [6, 7], sampled_entities, sampled_entities[:4], sampled_entities[4:6], seed,
            )

    # Define consequence
    consequence = (
        sampled_entities[6],  # e7
        sampled_relations[3],  # r4
        sampled_entities[7],  # e8
    )

    time_lag_tuples = create_time_lag_tuples(time_lags, antecedent)

    return TemporalPattern(
        antecedent=antecedent,
        consequence=consequence,
        time_lags=time_lag_tuples,
        n_hops=3,
    )

In [174]:
# Example: a 3-hop pattern with lag windows (0, 4), (0, 5), (0, 6) and seed 0.
# Lag bounds may also be callables, e.g.
# [(lambda: random.randint(0, 2), lambda: random.randint(0, 2)),
#  (lambda: random.randint(0, 4), lambda: random.randint(0, 4)),
#  (lambda: random.randint(0, 4), lambda: random.randint(0, 4))]
pat = create_3_hop_pattern(
    entity2id, relation2id, time2id,
    time_lags=[(0, 4), (0, 5), (0, 6)],
    seed=0,
)
pat.__label__()

'(5, 4, 0, t1) & (3, 0, 3, t2=t1+(0, 4)) & (7, 3, 3, t3=t2+(0, 5)) -> (3, 3, 5, t4=t3+(1, 6))'

## Instantiate patterns

In [175]:
# Choose set of patterns
n_3_hop = 10
n_2_hop = 10
n_1_hop = 10
time_lag_3_hop = [
    (0, lambda: scipy.stats.poisson(3).rvs(1)[0]),
    (0, lambda: scipy.stats.poisson(3).rvs(1)[0]),
    (1, lambda: scipy.stats.poisson(3).rvs(1)[0]),
]
time_lag_2_hop = [
    (0, lambda: scipy.stats.poisson(3).rvs(1)[0]),
    (1, lambda: scipy.stats.poisson(3).rvs(1)[0]),
]
time_lag_1_hop = [
    (1, lambda: scipy.stats.poisson(3).rvs(1)[0]),
]
max_retries = 10
# Start from 3-hop patterns, then 2-hop, then 1-hop.
# Prohibit any new pattern from being contained (antecedent and consequence) in the antecedent
# of an existing larger pattern, or from being identical to an already chosen same-sized pattern.
patterns = []
pattern_quadruples = []
for n_patterns, create_pattern, time_lags in [
    (n_3_hop, create_3_hop_pattern, time_lag_3_hop),
    (n_2_hop, create_2_hop_pattern, time_lag_2_hop),
    (n_1_hop, create_1_hop_pattern, time_lag_1_hop),
]:
    for _ in range(n_patterns):
        # One initial attempt plus up to max_retries retries. If every candidate is
        # rejected, this slot is left unfilled (fewer than n_patterns are chosen).
        for _attempt in range(max_retries + 1):
            pat = create_pattern(entity2id, relation2id, time2id, time_lags)
            # Keep the candidate only if its quadruples are not a subpattern of (or
            # identical to) an already chosen pattern. `not`, not `~`: `~` on a bool is
            # always truthy, which accepted every candidate.
            quad = pat.__quadruples__()
            if not is_subpattern(quad, pattern_quadruples):
                patterns.append(pat)
                pattern_quadruples.append(quad)
                break

In [177]:
# Create dataframe of pattern ids
pattern2id = pd.DataFrame({
    'pattern': [pat.__label__() for pat in patterns],
    'n_hops': [pat.n_hops for pat in patterns],
    'id': range(len(patterns)),
})
pattern2id[:3]

Unnamed: 0,pattern,n_hops,id
0,"(2, 3, 3, t1) & (9, 2, 4, t2=t1+(0, 3)) & (8, ...",3,0
1,"(8, 1, 6, t1) & (2, 0, 4, t2=t1+(0, 9)) & (4, ...",3,1
2,"(2, 1, 2, t1) & (9, 1, 5, t2=t1+(0, 6)) & (2, ...",3,2


# Apply patterns

Steps:
1. Start with a random graph. Parameters: # of entities, # of relations, # of time windows, average edge density per entity.
2. Choose a set of patterns, starting with 3-hop, then 2-hop, then 1-hop patterns, and pick time-stamp differences (time lags) for each. Don't allow subset patterns (e.g., a chosen 2-hop pattern can't be contained in a chosen 3-hop pattern).
3. Apply the patterns iteratively. Parameters: probability of random wiring, probability of not applying a pattern (or of interrupted pattern application).

In [355]:
# Density with which we randomly wire entities in this time window
rnd_avg_density = 3
# Function that returns an integer to be used for average density per entity
rnd_avg_density_distr = None #lambda: scipy.stats.poisson.rvs(3, size=1)[0]
# Probability that we do not apply a given pattern, per valid pattern (with all antecedents
# satisfied in previous time windows)
p_skip_consequence = .1
# Probability that we create artificially create edges that validate a given pattern
# (create edges that satisfay all antecedents), per pattern
n_hops2p_force = {
    1: .1,
    2: .1,
    3: .1,
}
# Note: A more sophisticated approach could also have a parameter to interrupt, or only
# partially create edges that satisfy antecedents
df_edgelist = pd.DataFrame({
    'head': [],
    'rel': [],
    'tail': [],
    't': [],
    'wt': [],
    'pattern': [],
})
for t in range(n_tws):
    # First randomly wire entities
    for ent_id in entity2id['id']:
        # Sample entities to use as tails
        if rnd_avg_density_distr:
            dens = rnd_avg_density_distr()
        else:
            dens = int(rnd_avg_density)
        tails = entity2id['id'].sample(dens, replace=True)
        # Sample relations to connect them
        rels = relation2id['id'].sample(dens, replace=True)
        df_i = pd.DataFrame({
            'head': [ent_id]*dens,
            'rel': rels.values,
            'tail': tails.values,
            't': [t]*dens,
            'wt': [1]*dens,
            'pattern': [[]]*dens,
        })
        df_edgelist = pd.concat([
            df_edgelist,
            df_i,
        ]).reset_index(drop=True)
    
    # Iterate over patterns
    heads, rels, tails, pats = [], [], [], []
    for label, pattern_id in zip(pattern2id['pattern'], pattern2id['id']):
        # Instantiate pattern from label
        pattern = TemporalPattern()
        pattern.from_label(label)

        # Artificially create valid patterns 
        rnd = random.random()
        if rnd < n_hops2p_force[pattern.n_hops]:
            # Create the antecedent in this and subsequent windows
            # Track time window of current antecedent as we create them
            t_i = int(t)
            heads_pat, rels_pat, tails_pat, ts_pat = [], [], [], []
            for antecedent, time_lag in zip(pattern.antecedent, pattern.time_lags):
                heads_pat.append(antecedent[0])
                rels_pat.append(antecedent[1])
                tails_pat.append(antecedent[2])
                ts_pat.append(t_i)
                # Increment t_i according to time_lag min and max
                t_i += random.randint(time_lag[0], time_lag[1])
            df_pat = pd.DataFrame({
                'head': heads_pat,
                'rel': rels_pat,
                'tail': tails_pat,
                't': ts_pat,
                'wt': [1]*len(heads_pat),
                'pattern': [[pattern_id]]*len(heads_pat),
            })
            df_edgelist = pd.concat([
                df_edgelist,
                df_pat,
            ]).reset_index(drop=True)
        
        # Apply valid patterns
        rnd = random.random()
        if rnd < p_skip_consequence:
            # Skip the consequence even though antecedents may be satisfied
            continue
        # Test whether antecedents are satisfied in prior windows
        antecedents_satisfied = False
        # Iterate over antecedents in reverse order (most recent to least recent)
        t_i = [t]  # Track current time window(s) for antecedent validation
        for antecedent, time_lag in zip(pattern.antecedent[::-1], pattern.time_lags[::-1]):
            # Check whether the antecedent exists in the edgelist
            df_ants = [
                df_edgelist[
                    (df_edgelist['head'] == antecedent[0]) &
                    (df_edgelist['rel'] == antecedent[1]) &
                    (df_edgelist['tail'] == antecedent[2]) &
                    (df_edgelist['t'] <= t_-time_lag[0]) & 
                    (df_edgelist['t'] >= t_-time_lag[1])
                ] for t_ in t_i
            ]
            df_ants = [df for df in df_ants if df.shape[0] > 0]
            if len(df_ants) == 0:
                # No satisfied antecedent
                antecedents_satisfied = False
                continue
            t_i = pd.concat(df_ants)['t'].unique().tolist()
            antecedents_satisfied = True
        # If all antecedents are satisfied, create consequence in current time window
        if antecedents_satisfied:
            heads.append(pattern.consequence[0])
            rels.append(pattern.consequence[1])
            tails.append(pattern.consequence[2])
            pats.append([pattern_id])
            # Note: Could try to label antecedent edges which satisfied this pattern with
            # relevant pattern id. But I found this really difficult to track, since the
            # way I track antecedent validity is like a branching tree.
    # Add all new consequences to edgelist
    df_pat = pd.DataFrame({
        'head': heads,
        'rel': rels,
        'tail': tails,
        't': [t]*len(heads),
        'wt': [1]*len(heads),
        'pattern': pats,
    })
    df_edgelist = pd.concat([
        df_edgelist,
        df_pat,
    ]).reset_index(drop=True)

# Cut off df_edgelist at n_tws
df_edgelist = df_edgelist[df_edgelist['t'] < n_tws]
# Aggregate duplicate edges
df_edgelist = df_edgelist.groupby(['head', 'rel', 'tail', 't']).agg({
    'wt': 'sum',
    'pattern': lambda x: sorted(list(set([el for ids in x for el in ids]))),
}).reset_index().sort_values(['t', 'head', 'tail', 'rel']).reset_index(drop=True)
df_edgelist['head'] = df_edgelist['head'].astype(int)
df_edgelist['rel'] = df_edgelist['rel'].astype(int)
df_edgelist['tail'] = df_edgelist['tail'].astype(int)
df_edgelist['t'] = df_edgelist['t'].astype(int)
len(df_edgelist)

454

In [356]:
# Export all relevant files
wd = '/nas/ckgfs/users/eboxer/tkg_patterns'
export_dir = os.path.join(wd, 'scratch_output')
entity2id.to_csv(
    os.path.join(export_dir, 'entity2id.txt'), sep='\t', index=False, header=False)
relation2id.to_csv(
    os.path.join(export_dir, 'relation2id.txt'), sep='\t', index=False, header=False)
with open(os.path.join(export_dir, 'stat.txt'), 'w') as f:
    f.writelines(f'{entity2id.id.nunique()}\t{relation2id.id.nunique()}\t0')
time2id.to_csv(
    os.path.join(export_dir, 'timestamp2id.txt'), sep='\t', index=False, header=False)
pattern2id.to_csv(
    os.path.join(export_dir, 'pattern2id.txt'), sep='\t', index=False, header=False)
df_edgelist.to_csv(
    os.path.join(export_dir, 'edgelist.txt'), sep='\t', index=False, header=False)