In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Work from the catlearn repository root so the relative sys.path entry
# ('../catlearn') added in the imports cell resolves.
# Set the CATLEARN_ROOT environment variable to point at your own checkout;
# the fallback is the original author's local path.
CATLEARN_ROOT = os.environ.get('CATLEARN_ROOT', '/Users/igorgarbuz/Dev/catlearn')
os.chdir(CATLEARN_ROOT)
os.getcwd()

In [None]:
from typing import Callable, Iterable, Any
import os
import argparse
import warnings
import logging
import sys
import torch
import random
sys.path.append('../catlearn')
from catlearn.data.dataset import Dataset
from catlearn.tensor_utils import (Tsor, DEFAULT_EPSILON)
from catlearn.graph_utils import (DirectedGraph, uniform_sample, random_walk_edge_sample,
                                    random_walk_vertex_sample)
from catlearn.algebra_models import (Algebra, VectAlgebra, VectMultAlgebra)
from catlearn.composition_graph import CompositeArrow
from catlearn.categorical_model import (TrainableDecisionCatModel, RelationModel,
                                        ScoringModel)

In [None]:
# DEBUG: Tools to export networkx graph
from networkx import to_dict_of_lists, to_dict_of_dicts

In [None]:
# Select default type: https://pytorch.org/docs/stable/tensors.html
# default_tensor is registered below as torch's default tensor type.
default_tensor = torch.FloatTensor
# default_tensor = torch.cuda.FloatTensor  # GPU alternative
torch.set_default_tensor_type(default_tensor)
# NOTE(review): newer PyTorch releases deprecate set_default_tensor_type in
# favour of torch.set_default_dtype / torch.set_default_device — confirm
# against the installed version.
# Check default tensor datatype (shown as the cell output)
default_tensor.dtype

In [None]:
class CustomRelation(RelationModel):
    """Toy relation model made of a single linear layer.

    The two node vectors and the label vector are concatenated and
    projected into the flat representation of ``algebra``.
    """
    def __init__(self, nb_features: int, nb_labels: int, algebra: Algebra) -> None:
        # Input width: source vector + target vector + label vector.
        in_dim = 2 * nb_features + nb_labels
        out_dim = algebra.flatdim
        self.linear = torch.nn.Linear(in_dim, out_dim)

    @property
    def parameters(self) -> Callable[[], Iterable[Any]]:
        """Bound ``parameters`` method of the wrapped linear layer."""
        return self.linear.parameters

    def __call__(self, x: Tsor, y: Tsor, l: Tsor) -> Tsor:
        """ Compute x R y for the relation label ``l`` """
        joined = torch.cat([x, y, l], dim=-1)
        return self.linear(joined)

In [None]:
class CustomScore(ScoringModel):
    """Linear + softmax scoring model for (src, dst, rel) triples.

    The project must supply a scoring model; this one depends on the
    algebra only through ``algebra.flatdim``. The linear layer emits
    ``nb_scores + 1`` logits: the extra logit takes part in the softmax
    normalisation but is dropped from the returned scores.
    """
    def __init__(
            self,
            nb_features: int,
            nb_scores: int,
            algebra: Algebra) -> None:
        # Two node vectors plus one flattened algebra element in;
        # nb_scores visible scores plus one discarded logit out.
        in_dim = 2 * nb_features + algebra.flatdim
        out_dim = nb_scores + 1
        self.linear = torch.nn.Linear(in_dim, out_dim)
        self.softmax = torch.nn.Softmax(dim=-1)

    @property
    def parameters(self) -> Callable[[], Iterable[Any]]:
        """Bound ``parameters`` method of the wrapped linear layer."""
        return self.linear.parameters

    def __call__(self, src: Tsor, dst: Tsor, rel: Tsor) -> Tsor:
        """ Compute S(src, dst, rel) """
        logits = self.linear(torch.cat([src, dst, rel], dim=-1))
        probabilities = self.softmax(logits)
        # Drop the extra last column so only nb_scores values are returned.
        return probabilities[..., :-1]

In [315]:
# Specify dataset paths, resolved against the current working directory
# (the catlearn repository root selected by the os.chdir cell at the top),
# instead of a hardcoded absolute path tied to one machine.
ds_path_wn18 = os.path.join(os.getcwd(), 'datasets', 'wn18rr', 'text')
ds_path_fb15 = os.path.join(os.getcwd(), 'datasets', 'fb15k-237')

In [None]:
# Load WN18RR. node_vec_dim=10 presumably sets the entity vector size
# (read back later as ds.entity_vec_dim) — confirm in catlearn.data.dataset.
ds_wn18 = Dataset(
    path=ds_path_wn18,
    ds_name='wn18',
    node_vec_dim=10,
)

In [None]:
# ds_fb15 = Dataset(path=ds_path_fb15, ds_name='fb15', node_vec_dim=10)

In [None]:
# MODIFY BELOW TO USE RIGHT DATASET
# `ds` is the dataset every following cell uses; switching to ds_fb15
# requires uncommenting its loading cell first.
ds = ds_wn18

In [None]:
# DEBUG: convert dataset to a list
# ds_l = list(ds.train)

In [None]:
# Algebra shared by the relation and scoring models, built from the
# dataset's entity vector dimension (presumably its vector size — confirm).
algebra = VectMultAlgebra(ds.entity_vec_dim)

In [None]:
# Relation model: its linear layer takes two entity vectors plus one label
# vector, so nb_labels is the number of relation types in the dataset.
relation_model = CustomRelation(
    nb_features=ds.entity_vec_dim,
    nb_labels=len(ds.relation2id),
    algebra=algebra
)

In [None]:
# Scoring model: one output score per relation type (CustomScore adds and
# later drops one extra internal logit).
scoring_model = CustomScore(
    nb_features=ds.entity_vec_dim,
    nb_scores=len(ds.relation2id),
    algebra=algebra
)

In [None]:
# Assemble the trainable categorical model from the pieces above.
# label_universe maps relation id -> label vector; the optimizer is passed
# as a class (torch.optim.Adam), not an instance.
model = TrainableDecisionCatModel(
    relation_model=relation_model,
    label_universe=ds.relation_id2vec,
    scoring_model=scoring_model,
    algebra_model=algebra,
    optimizer=torch.optim.Adam,
    epsilon=DEFAULT_EPSILON
)

In [None]:
# DEBUG: Represent graph as dict of dicts or dict of lists
# dod = to_dict_of_dicts(labels)
# dod[10698]
# dol = to_dict_of_lists(labels)
# dol[10698]

In [314]:
# DEBUG: check OHE (one-hot) relations encoding: relation id -> label vector
ds.relation_id2vec

{0: tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 1: tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 2: tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
 3: tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]),
 4: tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]),
 5: tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
 6: tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]),
 7: tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]),
 8: tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]),
 9: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]),
 10: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])}

In [313]:
# DEBUG: check relations to id conversion dict: relation name -> integer id
ds.relation2id

{'_instance_hypernym': 0,
 '_hypernym': 1,
 '_derivationally_related_form': 2,
 '_synset_domain_topic_of': 3,
 '_similar_to': 4,
 '_member_meronym': 5,
 '_has_part': 6,
 '_member_of_domain_usage': 7,
 '_verb_group': 8,
 '_also_see': 9,
 '_member_of_domain_region': 10}

In [None]:
# DEBUG NOTE: data format comparison
# Dataset iterable format: [src: int, tgt: int, lbl: {id: int -> vec: Tsor}]
# CompositeArrow data format: [[src: int, tgt: int], [label: int]]

In [None]:
# Label graph that the training loop samples batches from and passes to
# model.train as `labels`.
# NOTE(review): built from ds.valid (presumably the validation split) rather
# than ds.train — confirm this is intended.
graph_labels = DirectedGraph(ds.valid)

In [312]:
# DEBUG: Small graph for debug purposes
# debug_graph = DirectedGraph([(1,2,{11:[0,0,1]}),
#                             (2,3,{22:[0,1,0]}),
#                             (1,3,{33:[1,0,0]}),
#                         ])

In [None]:
# Convert graph to 1st order CompositeArrow structure
def graph_to_composite_arrow(graph: 'DirectedGraph'):
    """Split a graph into parallel lists of arrow nodes and arrow labels.

    Every edge (src, dst) yields the node pair ``[src, dst]`` together with
    the list of label keys stored on that edge. A vertex without outgoing
    edges yields the single-node entry ``[src]`` paired with an empty label
    list, so ``nodes[k]`` and ``edges[k]`` always describe the same arrow
    and the two lists can be zipped safely.

    Args:
        graph: mapping-like graph (e.g. a DirectedGraph) where ``graph[src]``
            maps each successor ``dst`` to a dict keyed by edge labels.

    Returns:
        ``(nodes, edges)``: two lists of equal length.
    """
    nodes = []
    edges = []
    for src in graph:
        successors = graph[src]
        if not successors:
            # Isolated vertex: record an empty label list as well, otherwise
            # every later node pair would be zipped with the wrong labels.
            nodes.append([src])
            edges.append([])
            continue
        for dst, edge_labels in successors.items():
            nodes.append([src, dst])
            # NOTE(review): every label of a multi-labelled (src, dst) pair
            # ends up in one list — confirm CompositeArrow expects that.
            edges.append(list(edge_labels.keys()))
    return nodes, edges

In [None]:
# Training loop: each iteration samples a sub-graph (batch) of graph_labels,
# converts it into CompositeArrow relations and runs one optimisation step.
NB_BATCHES = 2          # number of sampled sub-graphs / training steps
WALK_ITERATIONS = 100   # n_iter passed to random_walk_vertex_sample
SEED = 0                # seed for the sub-graph sampler (reproducibility)

# One seeded generator for the whole loop: runs are reproducible and
# successive batches still differ from each other.
sampling_rng = random.Random(SEED)

for i in range(NB_BATCHES):
    # Prune and create a sub-graph of arrows (the batch).
    # Alternative sampler:
    # arrows_graph = uniform_sample(graph=graph_labels, sample_vertices_size=10, rng=sampling_rng)
    arrows_graph = random_walk_vertex_sample(
        graph=graph_labels, rng=sampling_rng, n_iter=WALK_ITERATIONS)
    nodes, edges = graph_to_composite_arrow(arrows_graph)
    if any(edges):
        # Entries without labels (isolated vertices) become single-node
        # arrows, built the same way as in the no-edge branch below.
        arrows = [
            CompositeArrow(nodes=node, arrows=edge) if edge
            else CompositeArrow(nodes=node)
            for node, edge in zip(nodes, edges)
        ]
    else:
        print('No edges')
        arrows = [CompositeArrow(nodes=node) for node in nodes]
    _cache, _matches = model.train(
        data_points=ds.entity_id2vec,
        relations=arrows,
        # NOTE: Not clear if labels is a sub-graph or a complete graph
        labels=graph_labels,
        step=True,
        match_negatives=False,
    )