We construct ConnectHub, a synthetic dataset in which, for each knowledge graph $G = (V(G),E(G),R(G))$, the relations that appear are partitioned into hub relations $R_H(G)$, non-hub relations $R_N(G)$, and the target relation $r_0$, i.e. $R(G) = R_H(G) \cup R_N(G) \cup \{r_0\}$. Each of these graphs consists of the following components:

- A "hub" graph $H$ with center node $u$ such that, for every relation $r \in R_H(G)$, there exists $v \in V(G)$ with $r(v,u)$.
- A collection $S_H$ of "pair" graphs with shared hub relations: for each pair $r_1, r_2 \in R_H(G)$ there is a center node $x$ and nodes $y, z \in V(G)$ such that $r_1(y,x) \land r_2(z,x)$ holds. (The implementation below generalizes these pairs to every subset of $2$ to $k$ hub relations.)
- A collection $S_N$ of "pair" graphs with non-hub relations: for each pair $r_1, r_2 \in R_N(G)$ there is a center node $x'$ and nodes $y', z' \in V(G)$ such that $r_1(y',x') \land r_2(z',x')$ holds.

Finally, $G$ is the disjoint union of all these components. The task is to predict that the link $r_0(u,x)$ from the hub center $u$ to the center $x$ of every pair graph in $S_H$ is true, and that the link $r_0(u,x')$ from $u$ to the center $x'$ of every pair graph in $S_N$ is false.

In [3]:
import torch
from torch import nn
from torch.nn import functional as F   
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from motif.tasks import build_relation_graph, build_relation_hypergraph_synth
from motif.models import Ultra, MOTIF
import tqdm

In [4]:
# Run on the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Constructing a subgraph dataset

def concat_two_graph(index_1, type_1, index_2, type_2):
    """Disjointly append a second edge list to a first one.

    The node ids of the second graph are shifted by ``max(index_1) + 1`` so
    the two node sets do not overlap.

    Args:
        index_1: edge index of the first graph (rows are source / target ids).
        type_1: relation ids of the first graph's edges.
        index_2: edge index of the second graph, using its own local node ids.
        type_2: relation ids of the second graph's edges.

    Returns:
        (edge_index, edge_type, offset): the concatenated edge index and edge
        types, plus the offset (a 0-dim tensor) added to the second graph's
        ids, i.e. the new id of its node 0.
    """
    offset = index_1.view(-1).max() + 1
    shifted_index = index_2 + offset
    merged_index = torch.cat((index_1, shifted_index), dim=-1)
    merged_type = torch.cat((type_1, type_2))
    return merged_index, merged_type, offset

def generate_hub_graph(relation_list):
    """Build a star graph whose leaves all point at a single center node.

    Node 0 is the center. Leaf ``i`` (1-based) is joined to it by the edge
    ``i -> 0``, labelled with ``relation_list[i - 1]``.

    Args:
        relation_list: sequence (list or tuple) of integer relation ids, one
            per leaf.

    Returns:
        (edge_index, edge_type):
            - ``edge_index`` is a contiguous int64 tensor of shape
              ``(2, len(relation_list))`` whose rows are source / target ids.
            - ``edge_type`` is an int64 tensor of the relation ids.
            An empty ``relation_list`` yields a well-formed ``(2, 0)`` index.
    """
    num_leaves = len(relation_list)
    sources = torch.arange(1, num_leaves + 1, dtype=torch.int64)
    targets = torch.zeros(num_leaves, dtype=torch.int64)
    # stack (not a transposed list-of-pairs tensor) keeps the (2, E) shape even when E == 0.
    index = torch.stack([sources, targets])
    edge_type = torch.tensor(list(relation_list), dtype=torch.int64)
    return index, edge_type


In [6]:
import itertools

def generate_permutations(lst, k):
    """Return every combination of ``lst`` with size 2 through ``k + 1``.

    Despite the function name, these are *combinations* (order-insensitive
    subsets, in ``itertools.combinations`` order), not permutations. Smaller
    subsets come first.

    Args:
        lst: items to choose from.
        k: the largest subset size is ``k + 1``; any ``k < 1`` yields ``[]``.

    Returns:
        list of tuples.
    """
    sizes = range(2, k + 2)
    return list(
        itertools.chain.from_iterable(itertools.combinations(lst, size) for size in sizes)
    )

In [7]:
def create_data_instances(num_instances, k):
    """Generate ``num_instances`` ConnectHub graphs of increasing size.

    Graph ``i`` (0-based) uses:
      - ``n = k + 1 + i`` hub relations ``1..n``;
      - the same number of non-hub relations ``n+1..2n``.

    Node 0 is always the hub center. For every subset of 2..k hub relations
    (from ``generate_permutations(..., k - 1)``), a star is appended. The link
    ``(0, star center)`` is recorded as a positive target. The same is done
    with the non-hub relations to produce the negative targets. Relation 0 is
    never placed on an edge and is the query relation for both target sets.

    Returns:
        list of ``Data`` objects, each passed through ``build_relation_graph``
        and then ``build_relation_hypergraph_synth`` for arities 2..k+1.
    """
    graphs = []
    for hub_size in tqdm.tqdm(range(k + 1, k + num_instances + 1)):
        hub_relations = list(range(1, hub_size + 1))
        other_relations = list(range(hub_size + 1, 2 * hub_size + 1))

        # Start from the hub star; its center keeps node id 0 throughout.
        index, edge_type = generate_hub_graph(hub_relations)

        positives = []
        for subset in generate_permutations(hub_relations, k - 1):
            star_index, star_type = generate_hub_graph(subset)
            index, edge_type, center = concat_two_graph(index, edge_type, star_index, star_type)
            positives.append([0, center])

        negatives = []
        for subset in generate_permutations(other_relations, k - 1):
            star_index, star_type = generate_hub_graph(subset)
            index, edge_type, center = concat_two_graph(index, edge_type, star_index, star_type)
            negatives.append([0, center])

        graph = Data(
            edge_index=index,
            edge_type=edge_type,
            num_relations=edge_type.max().item() + 1,
            num_nodes=index.view(-1).max().item() + 1,
            target_edge_index=torch.tensor(positives).T,
            target_edge_type=torch.zeros(len(positives)),
            not_target_edge_index=torch.tensor(negatives).T,
            not_target_edge_type=torch.zeros(len(negatives)),
            device=device,
            relation_hypergraph=[],
        )
        graph = build_relation_graph(graph)
        for arity in range(2, k + 2):
            graph = build_relation_hypergraph_synth(graph, max_arity=arity)
        graphs.append(graph)
    return graphs


In [8]:
def generate_dataloader(data_instances, k, seed=0, train_fraction=0.7, shuffle=True):
    """Build ConnectHub graphs and split them into train / test loaders.

    Args:
        data_instances: number of graphs to generate (forwarded to
            ``create_data_instances``).
        k: subset-size parameter forwarded to ``create_data_instances``.
        seed: seed for the local RNG used to shuffle the graphs.
        train_fraction: fraction of graphs (rounded down) placed in the
            training split.
        shuffle: when True (the default, matching the previous behavior) the
            graphs are shuffled before splitting. Pass False to train on the
            first graphs built and test on the later ones. Those later graphs
            are larger, because ``create_data_instances`` grows the hub size
            with each instance.

    Returns:
        (train_dataloader, test_dataloader).
    """
    graph_list = create_data_instances(data_instances, k)
    if shuffle:
        random.Random(seed).shuffle(graph_list)
    split = int(len(graph_list) * train_fraction)
    train_dataloader = DataLoader(graph_list[:split])
    test_dataloader = DataLoader(graph_list[split:])
    return train_dataloader, test_dataloader

In [9]:
def get_target_sampling(graph_data):
    """Assemble the (head, tail, relation) query triples and their labels.

    Positive links (``target_edge_*``) come first, followed by negative links
    (``not_target_edge_*``).

    Returns:
        (triples, labels):
            - ``triples`` is an int64 tensor of shape ``(N, 1, 3)`` holding
              ``[head, tail, relation]``.
            - ``labels`` is a float tensor of shape ``(N, 1)``: 1 for
              positives, 0 for negatives.
    """
    pos_index = graph_data.target_edge_index
    neg_index = graph_data.not_target_edge_index
    num_pos = pos_index.size(-1)

    heads = torch.cat((pos_index[0], neg_index[0]))
    tails = torch.cat((pos_index[1], neg_index[1]))
    rels = torch.cat((graph_data.target_edge_type, graph_data.not_target_edge_type))

    columns = [col.unsqueeze(-1).to(torch.int64) for col in (heads, tails, rels)]
    triples = torch.stack(columns, dim=-1)

    labels = torch.zeros(heads.size(0))
    labels[:num_pos] = 1
    return triples, labels.unsqueeze(-1)

In [10]:
def train_and_validate(
    model,
    train_dataloader,
    num_epoch,
    lr,
):
    """Train ``model`` with Adam and binary cross-entropy on the target links.

    Despite the name, no validation pass is run. After each epoch the average
    training loss is printed.

    Args:
        model: module called as ``model(graph_data, batch)``. It must return
            logits with the same shape as the labels from
            ``get_target_sampling``.
        train_dataloader: iterable of graphs carrying target / non-target
            edges.
        num_epoch: number of passes over ``train_dataloader``; 0 is a no-op.
        lr: Adam learning rate.
    """
    if num_epoch == 0:
        return

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epoch):
        model.train()
        print(f"Epoch {epoch} begin")
        losses = []
        for graph_data in train_dataloader:
            batch, target = get_target_sampling(graph_data)
            batch = batch.to(device)
            target = target.to(device)
            graph_data = graph_data.to(device)
            pred = model(graph_data, batch)

            # The default reduction is already the mean over elements.
            loss = F.binary_cross_entropy_with_logits(pred, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
        # An empty loader reports NaN instead of raising ZeroDivisionError.
        avg_loss = sum(losses) / len(losses) if losses else float("nan")
        print("Epoch %d end" % epoch)
        print("average binary cross entropy: %g" % avg_loss)

In [11]:
@torch.no_grad()
def test(model, test_dataloader):
    """Measure link-prediction accuracy (in percent) on ``test_dataloader``.

    A link is predicted true when ``sigmoid(logit) > 0.5``. Gradients are
    disabled for the whole pass. The accuracy is printed and returned.

    Returns:
        accuracy as a 0-dim tensor, or a NaN tensor when the loader yields no
        targets (instead of raising ZeroDivisionError).
    """
    model.eval()
    num_correct = 0
    num_total = 0
    for graph_data in test_dataloader:
        batch, target = get_target_sampling(graph_data)
        batch = batch.to(device)
        target = target.to(device)
        graph_data = graph_data.to(device)
        pred = model(graph_data, batch)
        output = (torch.sigmoid(pred) > 0.5).float()
        num_correct = num_correct + (output == target).float().sum()
        num_total += target.size(0)
    if num_total == 0:
        accuracy = torch.tensor(float("nan"))
    else:
        accuracy = 100 * num_correct / num_total
    print("Accuracy = {}".format(accuracy))
    return accuracy

In [None]:
# With k = 3 and 2 instances, the hub graphs use k+1 = 4 and k+2 = 5 hub relations.
train_dataloader, test_dataloader = generate_dataloader(data_instances=2, k=3)

In [13]:

# ULTRA model.
# - rel_model_cfg: DistMult messages, sum aggregation, shortcut.
# - entity_model_cfg: TransE messages, sum aggregation, layer norm.
Ultra_sum = Ultra(
    rel_model_cfg={
        "input_dim": 32,
        "hidden_dims": [32, 32],
        "message_func": "distmult",
        "aggregate_func": "sum",
        "short_cut": True,
    },
    entity_model_cfg={
        "input_dim": 32,
        "hidden_dims": [32, 32],
        "message_func": "transe",
        "aggregate_func": "sum",
        "layer_norm": True,
    },
).to(device)

In [14]:
# MOTIF model with sum aggregation and TransE entity messages.
# NOTE(review): per the arity comment, expressivity needs max_considered_arity = k + 1.
# That is 4 for the k = 3 dataset built above, but 3 is used here.
# Confirm this under-expressive setting is intentional.
MOTIF_sum = MOTIF(
    rel_model_cfg={
        "input_dim": 32,
        "hidden_dims": [32, 32],
        "aggregate_func": "sum",
        "max_considered_arity": 3,  # to be expressive it has to be k+1
        "short_cut": True,
        "layer_norm": True,
        "use_triton": True,
        "synthetic": True,
    },
    entity_model_cfg={
        "input_dim": 32,
        "hidden_dims": [32, 32],
        "message_func": "transe",
        "aggregate_func": "sum",
        "layer_norm": True,
        "synthetic": True,
    },
).to(device)

In [None]:
# Train for 500 epochs at lr 1e-3, then report test accuracy.
train_and_validate(MOTIF_sum, train_dataloader, num_epoch=500, lr=0.001)
test(MOTIF_sum, test_dataloader)

In [22]:
def _instances_for_k(k):
    """Number of graphs to generate for a given k.

    Fewer graphs are used for larger k, because the number of generated
    subsets grows with k.
    """
    if k in (2, 3):
        return 10
    if k == 4:
        return 6
    return 2


def _build_motif(max_considered_arity):
    """Create a fresh MOTIF model (sum aggregation, TransE entity messages) on ``device``."""
    return MOTIF(
        rel_model_cfg=dict(
            input_dim=32,
            hidden_dims=[32, 32],
            aggregate_func="sum",
            max_considered_arity=max_considered_arity,  # to be expressive it has to be k+1
            short_cut=True,
            layer_norm=True,
            use_triton=True,
            synthetic=True,
        ),
        entity_model_cfg=dict(
            input_dim=32,
            hidden_dims=[32, 32],
            message_func="transe",
            aggregate_func="sum",
            layer_norm=True,
            synthetic=True,
        ),
    ).to(device)


def experiments(end_k=6, num_epoch=1000, lr=0.001):
    """Sweep k = 2..end_k and MOTIF max arity = 2..k+1, recording test accuracy.

    A fresh model is trained for every (k, arity) pair on the dataset built
    for that k.

    Args:
        end_k: largest k to evaluate (inclusive).
        num_epoch: training epochs per model (previously hard-coded to 1000).
        lr: Adam learning rate (previously hard-coded to 0.001).

    Returns:
        Nested list: ``result[i][j]`` is the accuracy returned by ``test`` for
        ``k = i + 2`` and ``max_considered_arity = j + 2``.
    """
    result = []
    for k in tqdm.tqdm(range(2, end_k + 1)):
        train_dataloader, test_dataloader = generate_dataloader(
            data_instances=_instances_for_k(k),
            k=k,
        )
        temp_result = []
        for test_val in range(2, k + 2):
            print("k = {}, test = {}".format(k, test_val))
            model = _build_motif(test_val)
            train_and_validate(model, train_dataloader, num_epoch, lr)
            accuracy = test(model, test_dataloader)
            print("k = {}, test = {}, accuracy = {}".format(k, test_val, accuracy))
            temp_result.append(accuracy)
        result.append(temp_result)
    return result

In [None]:
# Full sweep over k = 2..7; result[i][j] is the accuracy for k = i + 2, arity j + 2.
result = experiments(end_k=7)
print(result)