In [65]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [73]:
# CHECK WORKING DIR PATH
import os
# MODIFY SYSTEM PATH
# The project root can be overridden through the CATLEARN_ROOT environment
# variable, so the notebook runs on other machines without editing this cell.
# When the variable is unset, the original checkout location is used.
PROJECT_ROOT = os.environ.get('CATLEARN_ROOT', '/Users/igorgarbuz/Dev/catlearn')
os.chdir(PROJECT_ROOT)
# Last expression: display the directory the kernel is now working in.
os.getcwd()

'/Users/igorgarbuz/Dev/catlearn'

In [74]:
from typing import Callable, Iterable, Any
import sys
import torch
import random
import networkx as nx
# import warnings
# Needed to show warnings in all Jupyter distributions (e.g. VS Code Jupyter implementation)
# warnings.simplefilter(action="default")
from tqdm import tnrange
sys.path.append('../catlearn')
from catlearn.data.dataset import Dataset
from catlearn.tensor_utils import (Tsor, DEFAULT_EPSILON)

from catlearn.graph_utils import (DirectedGraph,
                                    uniform_sample,
                                    random_walk_edge_sample,
                                    random_walk_vertex_sample,
                                    clean_selfloops,
                                    augment_graph,
                                    create_revers_rels)
from catlearn.algebra_models import (Algebra, VectAlgebra, VectMultAlgebra)
from catlearn.composition_graph import CompositeArrow
from catlearn.categorical_model import (TrainableDecisionCatModel, RelationModel,
                                        ScoringModel)

For Weights & Biases integration, run in a terminal:  
```wandb login```
When prompted for an API key, paste the key from https://wandb.ai/authorize.  
Do not record the key in this notebook; any key previously stored here must be treated as compromised and revoked.  
Check that you are logged in as ```arreason-labs```

In [75]:
from wandb_logger import log_results, save_params

In [76]:
# Select the default tensor type used for every tensor created afterwards:
# https://pytorch.org/docs/stable/tensors.html
# NOTE(review): newer PyTorch releases deprecate set_default_tensor_type in
# favour of set_default_dtype / set_default_device — confirm against the
# installed version before upgrading.
default_tensor = torch.FloatTensor
# default_tensor = torch.cuda.FloatTensor
torch.set_default_tensor_type(default_tensor)
# Check the default tensor datatype (displayed as the cell output)
default_tensor.dtype

torch.float32

In [77]:
os.getcwd()

'/Users/igorgarbuz/Dev/catlearn'

In [78]:
# Specify dataset path
ds_path_wn18 = './Datasets/wn18rr/text'
ds_path_fb15 = './Datasets/fb15k-237'

In [79]:
ds_wn18 = Dataset(path=ds_path_wn18, ds_name='wn18', node_vec_dim=10)

In [80]:
# MODIFY BELOW TO USE RIGHT DATASET
ds = ds_wn18

## Create training graph

In [81]:
graph_labels = DirectedGraph(ds.train)

## Clean graph

In [83]:
clean_selfloops(graph_labels)


Following 7 selfloop edge(s) are found:
[(30440, 30440), (24410, 24410), (18188, 18188), (24568, 24568), (10143, 10143), (9618, 9618), (25511, 25511)]
7 selfloop edges are removed.
Following 0 isolate(s) are found:
[]
0 isolates are removed.


# Print graphs stats

In [84]:
print(nx.info(graph_labels))

Name: 
Type: DirectedGraph
Number of nodes: 40714
Number of edges: 86719
Average in degree:   2.1300
Average out degree:   2.1300


# Augment graph with the following symmetrical relations

In [85]:
# Mapping from each relation name to the name of its reverse relation.
# An entry that maps a relation to itself declares it as its own reverse;
# a None entry declares no reverse (presumably such relations are left
# un-augmented by create_revers_rels — TODO confirm against graph_utils).
relation_revers = {
    '_hypernym': None,
    '_derivationally_related_form': '_derivationally_related_form',
    '_member_meronym': None,
    '_has_part': None,
    '_synset_domain_topic_of': None,
    '_instance_hypernym' : None,
    '_also_see': '_also_see',
    '_verb_group': '_verb_group',
    '_member_of_domain_usage': None,
    '_member_of_domain_region': None,
    '_similar_to': '_similar_to',
}

In [86]:
relation2id_augmented, relation_id2vec_augmented, revers_rels = create_revers_rels(relation_revers, ds.relation2id)

In [87]:
augment_graph(graph_labels, revers_rels)

In [88]:
print(nx.info(graph_labels))

Name: 
Type: DirectedGraph
Number of nodes: 40714
Number of edges: 89238
Average in degree:   2.1918
Average out degree:   2.1918


# Define Relation model

In [89]:
class CustomRelation(RelationModel):
    """Relation model backed by a single linear layer.

    The layer takes the concatenation of two feature vectors and one label
    vector (``2 * nb_features + nb_labels`` inputs) and produces
    ``algebra.flatdim`` outputs.
    """

    def __init__(self, nb_features: int, nb_labels: int, algebra: Algebra) -> None:
        """Build the linear layer sized from the feature, label and algebra dimensions."""
        input_dim = 2 * nb_features + nb_labels
        self.linear = torch.nn.Linear(input_dim, algebra.flatdim)

    def named_parameters(self, recurse: bool = True) -> Iterable[Tsor]:
        """Expose the named parameters of the underlying linear layer."""
        return self.linear.named_parameters(recurse=recurse)

    def __call__(self, x: Tsor, y: Tsor, l: Tsor) -> Tsor:
        """ Compute x R y: concatenate (x, y, l) on the last axis, then project. """
        stacked = torch.cat((x, y, l), -1)
        return self.linear(stacked)

# Define Score model

In [90]:
class CustomScore(ScoringModel):
    """Scoring model: a linear layer followed by a softmax.

    Must be defined; it depends on the algebra and on the scope of
    definition of the project. The linear layer emits ``nb_scores + 1``
    values, the softmax normalises them over the last axis, and the final
    entry is dropped so that ``nb_scores`` values are returned.
    """

    def __init__(
            self,
            nb_features: int,
            nb_scores: int,
            algebra: Algebra) -> None:
        """Build the linear layer and the softmax applied to its output."""
        input_dim = 2 * nb_features + algebra.flatdim
        output_dim = nb_scores + 1
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.softmax = torch.nn.Softmax(dim=-1)

    def named_parameters(self, recurse: bool = True) -> Iterable[Tsor]:
        """Expose the named parameters of the underlying linear layer."""
        return self.linear.named_parameters(recurse=recurse)

    def __call__(self, src: Tsor, dst: Tsor, rel: Tsor) -> Tsor:
        """ Compute S(src, dst, rel) """
        logits = self.linear(torch.cat((src, dst, rel), -1))
        probabilities = self.softmax(logits)
        # Drop the last softmax entry; only the first nb_scores are returned.
        return probabilities[..., :-1]

# Create training models

In [91]:
algebra = VectMultAlgebra(ds.entity_vec_dim)

In [92]:
relation_model = CustomRelation(
    nb_features=ds.entity_vec_dim,
    nb_labels=len(ds.relation2id),
    algebra=algebra
)

In [93]:
scoring_model = CustomScore(
    nb_features=ds.entity_vec_dim,
    nb_scores=len(ds.relation2id),
    algebra=algebra
)

In [101]:
model = TrainableDecisionCatModel(
    relation_model=relation_model,
    label_universe=relation_id2vec_augmented,
    scoring_model=scoring_model,
    algebra_model=algebra,
    optimizer=torch.optim.Adam,
    epsilon=DEFAULT_EPSILON
)

In [95]:
# DEBUG NOTE: datatype comparison
# Dataset iterable format [src:int, tgt:int, lbl: {id:int: vec:Tsor}]
# CompositeArrow data format:  [[src: int, tgt: int], [label: int]]

# Create training loop

In [102]:
def graph_to_nodes_edges(graph: DirectedGraph):
    """Split the edges of a graph into endpoint pairs and label-key lists.

    Returns a ``(nodes, edges)`` pair of lists with one entry per edge:
    ``nodes[k]`` is the ``(src, dst)`` tuple of the k-th edge and
    ``edges[k]`` is the list of keys of that edge's data dictionary.
    """
    endpoint_pairs = []
    label_keys = []
    for src, dst, attributes in graph.edges(data=True):
        endpoint_pairs.append((src, dst))
        label_keys.append(list(attributes.keys()))
    return endpoint_pairs, label_keys

In [103]:
# NOTE: for large graphs, the random_walk family of functions can be used to
# sub-sample the graph while preserving its topology.
N_BATCHES = 10        # number of training batches to run
SAMPLE_VERTICES = 7   # vertices drawn from the graph for each batch
SAMPLE_SEED = 0       # seed for batch sampling; set to None for non-deterministic draws

# One generator shared by all batches. Creating a fresh unseeded
# random.Random() in every iteration made the sampled batches unreproducible.
sample_rng = random.Random(SAMPLE_SEED)

for i in tnrange(N_BATCHES, desc='1st loop: batches'):
    arrows_graph = uniform_sample(
        graph=graph_labels,
        sample_vertices_size=SAMPLE_VERTICES,
        rng=sample_rng,
        with_edges=True)
    nodes, edges = graph_to_nodes_edges(arrows_graph)
    # nodes and edges hold one entry per edge of the sampled graph, so they
    # always have the same length; a sample without edges yields no arrows.
    arrows = [CompositeArrow(nodes=node, arrows=edge) for node, edge in zip(nodes, edges)]
    cache, matches = model.train(
        data_points = ds.entity_id2vec,
        relations = arrows,
        # NOTE: Labels could be a complete graph, a subgraph from random_walk or a sub-sub-graph used to create a batch
        labels = arrows_graph,
        step = True,
        match_negatives=False
    )
    # log_results(cache, matches)

HBox(children=(HTML(value='1st loop: batches'), FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [None]:
save_params(model)