In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# !nvidia-smi

In [None]:
from IPython.display import JSON
from google.colab import output
from subprocess import getoutput
import os

import shlex

def shell(command):
  """Run one command typed into the in-browser terminal and return its output.

  Invoked from JavaScript via google.colab.kernel.invokeFunction('shell', [command]).
  Non-`cd` commands run through subprocess.getoutput in a fresh subshell, so a
  `cd` there would not persist; `cd` is therefore handled here with os.chdir,
  which changes the working directory of the kernel itself.

  Args:
    command: the raw command line entered in the terminal.

  Returns:
    IPython.display.JSON wrapping a one-element list with the text output
    (an empty string for a successful `cd`, an error message for a failed one).
  """
  parts = command.strip().split(maxsplit=1)
  # Match the `cd` builtin exactly, so commands such as `cdrom` are not hijacked.
  if parts and parts[0] == 'cd':
    # A bare `cd` goes to the home directory, like the shell builtin.
    target = parts[1] if len(parts) > 1 else '~'
    try:
      # Honour shell quoting/escaping, e.g. cd "Colab Notebooks" or cd My\ Drive.
      tokens = shlex.split(target)
    except ValueError:
      # Unbalanced quotes: fall back to the raw text.
      tokens = []
    # Several tokens means an unquoted path with spaces; keep the raw text then.
    path = tokens[0] if len(tokens) == 1 else target
    try:
      os.chdir(os.path.expanduser(path))
    except OSError as err:
      return JSON([f'cd: {err}'])
    return JSON([''])
  return JSON([getoutput(command)])
output.register_callback('shell', shell)

In [None]:
%%html
<!-- Colab Shell: an in-browser terminal whose commands are executed in the kernel
     through the 'shell' callback registered with output.register_callback.
     Note: a cell magic such as %%html must be the very first line of the cell. -->
<div id=term_demo></div>
<script src="https://code.jquery.com/jquery-3.7.1.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/jquery.terminal/js/jquery.terminal.min.js"></script>
<link href="https://cdn.jsdelivr.net/npm/jquery.terminal/css/jquery.terminal.min.css" rel="stylesheet"/>
<script>
  // Forward each non-empty command line to the Python `shell` callback and echo its output.
  $('#term_demo').terminal(async function(command) {
      if (command !== '') {
          try {
              const res = await google.colab.kernel.invokeFunction('shell', [command]);
              const out = res.data['application/json'][0];
              // Pass a primitive string, not a String wrapper object.
              this.echo(String(out));
          } catch (e) {
              this.error(String(e));
          }
      } else {
          this.echo('');
      }
  }, {
      greetings: 'Welcome to Colab Shell',
      name: 'colab_demo',
      height: 250,
      prompt: 'colab > '
  });
</script>

# Initialize the git repository by following the instructions below

In [None]:
# !git init
# !git config --global user.email "You@Your.com"
# !git config --global user.name "Username"
# !git add ....
# !git commit ....
# GitHub no longer accepts account passwords over HTTPS: use a personal access token, and never save it in the notebook.
# !git remote add origin https://<username>:<token>@github.com/arreason/CategoryLearning.git
# !git push -u origin <branch>
# Troubleshooting
# !git remote rm origin -> git remote add ....

## Check system configurations

In [None]:
%%bash
# Kernel version and build information of the runtime.
cat /proc/version

In [None]:
%%bash
# Distribution name and version (every /etc/*release file).
cat /etc/*release

## Imports and install requirements

In [None]:
# Optionally mount Google Drive; the repository used below lives under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import os

# Work from the repository root on the mounted Drive, so that relative paths
# used later (requirements.txt, ./Datasets) resolve.
REPO_DIR = '/content/drive/MyDrive/Colab Notebooks/catlearn/CategoryLearning'
os.chdir(REPO_DIR)

In [None]:
# !pip install kora
# import kora.install.py38
# ========================
# from kora import ngrok
# url = ngrok.connect(8888).public_url
# print(url)

In [None]:
%%bash
# Report the version and location of each Python interpreter on PATH.
# Runtimes may not ship python2; a missing interpreter is reported instead of
# making the script end with a non-zero exit status.
for py in python python3 python2; do
  if which "$py" > /dev/null 2>&1; then
    "$py" -V 2>&1
    which "$py"
  else
    echo "$py: not found"
  fi
done

In [None]:
# import sys
# print(sys.version)
# print(sys.path)

In [None]:
# %%bash
# sudo apt-get update -y
# sudo apt-get install python3.8
# sudo apt update
# sudo apt install python3-pip
# ----OPTIONAL----
# !curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
# !python get-pip.py --force-reinstall
# ----OPTIONAL----
# !sudo rm /usr/local/bin/python
# !sudo ln -s /root/anaconda3/bin/python /usr/local/bin/python
# !sudo rm /usr/bin/python3
# !sudo ln -s /root/anaconda3/bin/python /usr/bin/python3

In [None]:
!pip install -r requirements.txt

In [None]:
from typing import Callable, Iterable, Any
import sys
import torch
import random
import networkx as nx
# import warnings
# Needed to show warnings in all Jupyter distributions (e.g. VS Code Jupyter implementation)
# warnings.simplefilter(action="default")
from tqdm import tnrange
from tqdm.auto import trange
# sys.path.append('../CategoryLearning')
from catlearn.data.dataset import Dataset
from catlearn.tensor_utils import (Tsor, DEFAULT_EPSILON)

from catlearn.graph_utils import (DirectedGraph,
                                    uniform_sample,
                                    random_walk_edge_sample,
                                    random_walk_vertex_sample,
                                    clean_selfloops,
                                    augment_graph,
                                    create_revers_rels)
from catlearn.algebra_models import (Algebra, VectAlgebra, VectMultAlgebra)
from catlearn.composition_graph import CompositeArrow
from catlearn.categorical_model import (TrainableDecisionCatModel, RelationModel, ScoringModel)

For Weights & Biases integration, run in a terminal:  
```wandb login```
When prompted for an API key, paste the key from https://wandb.ai/authorize.  
Never store the key in this notebook: the key previously written here has been exposed and should be revoked and regenerated.  
Check that you are logged in as ```arreason-labs```

In [None]:
from wandb_logger import log_results, save_params

In [None]:
# Select default type: https://pytorch.org/docs/stable/tensors.html
# Set USE_CUDA = True to create new tensors on the GPU by default (torch.cuda.FloatTensor).
# NOTE: torch.set_default_tensor_type is deprecated since PyTorch 2.1 in favour of
# torch.set_default_dtype / torch.set_default_device.
USE_CUDA = False
default_tensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
torch.set_default_tensor_type(default_tensor)
# Check default tensor datatype
default_tensor.dtype

In [None]:
# Dataset locations, relative to the repository root (the current working directory).
DATASETS_DIR = './Datasets'
ds_path_wn18 = f'{DATASETS_DIR}/wn18rr/text'
ds_path_fb15 = f'{DATASETS_DIR}/fb15k-237'

In [None]:
# Value passed to Dataset as node_vec_dim.
NODE_VEC_DIM = 10
# NOTE(review): the files are WN18RR (ds_path_wn18) but ds_name is 'wn18' — confirm intended.
ds_wn18 = Dataset(
    path=ds_path_wn18,
    ds_name='wn18',
    node_vec_dim=NODE_VEC_DIM,
)

In [None]:
# MODIFY BELOW TO USE RIGHT DATASET
# `ds` is the dataset every following cell uses; only ds_wn18 is loaded in this notebook.
ds = ds_wn18

## Create training graph

In [None]:
# Build the graph trained on from ds.train (presumably the training triples — see the
# "Dataset iterable format" note further down).
graph_labels = DirectedGraph(ds.train)

## Clean graph

In [None]:
# Remove self-loop edges from graph_labels. The return value is discarded, so this
# presumably modifies the graph in place.
clean_selfloops(graph_labels)

# Print graph stats

In [None]:
# nx.info() was deprecated in NetworkX 2.7 and removed in 3.0; print the headline stats directly.
print(f'{type(graph_labels).__name__}: '
      f'{graph_labels.number_of_nodes()} nodes, '
      f'{graph_labels.number_of_edges()} edges')

# Augment the graph with the following symmetrical relations

In [None]:
# Reverse-relation table for WN18RR, keyed in the order listed below.
# Symmetric relations map to themselves; all other relations map to None.
_symmetric_relations = {
    '_derivationally_related_form',
    '_also_see',
    '_verb_group',
    '_similar_to',
}
relation_revers = {
    rel: (rel if rel in _symmetric_relations else None)
    for rel in (
        '_hypernym',
        '_derivationally_related_form',
        '_member_meronym',
        '_has_part',
        '_synset_domain_topic_of',
        '_instance_hypernym',
        '_also_see',
        '_verb_group',
        '_member_of_domain_usage',
        '_member_of_domain_region',
        '_similar_to',
    )
}

In [None]:
# Derive the augmented relation vocabulary from relation_revers and ds.relation2id:
# presumably an augmented relation->id map, an id->label-vector map (used later as the
# model's label_universe), and the reverse-relation spec passed to augment_graph.
relation2id_augmented, relation_id2vec_augmented, revers_rels = create_revers_rels(relation_revers, ds.relation2id)

In [None]:
# Add the reverse edges described by revers_rels to graph_labels (return value unused,
# so presumably in place). NOTE(review): re-running this cell without rebuilding
# graph_labels likely adds the reverse edges a second time — confirm.
augment_graph(graph_labels, revers_rels)

In [None]:
# nx.info() was deprecated in NetworkX 2.7 and removed in 3.0; print the headline stats directly.
print(f'{type(graph_labels).__name__}: '
      f'{graph_labels.number_of_nodes()} nodes, '
      f'{graph_labels.number_of_edges()} edges')

# Define Relation model

In [None]:
class CustomRelation(RelationModel):
    """Relation model: one linear layer over the concatenated inputs.

    The source ``x``, target ``y`` and label ``l`` are concatenated along the
    last dimension (expected total width ``2 * nb_features + nb_labels``) and
    mapped to a vector of size ``algebra.flatdim``.
    """
    def __init__(self, nb_features: int, nb_labels: int, algebra: Algebra) -> None:
        # Input width: two entity vectors plus one label vector.
        in_dim = 2 * nb_features + nb_labels
        self.linear = torch.nn.Linear(in_dim, algebra.flatdim)

    def named_parameters(self, recurse: bool = True) -> Iterable[Tsor]:
        # Delegates to the wrapped layer, which yields (name, parameter) pairs.
        return self.linear.named_parameters(recurse=recurse)

    def __call__(self, x: Tsor, y: Tsor, l: Tsor) -> Tsor:
        """ Compute x R y """
        features = torch.cat([x, y, l], dim=-1)
        return self.linear(features)

# Define Score model

In [None]:
class CustomScore(ScoringModel):
    """Scoring model: one linear layer followed by a softmax.

    Must be defined per project; its input width depends on the algebra.
    ``src``, ``dst`` and ``rel`` are concatenated along the last dimension
    (expected total width ``2 * nb_features + algebra.flatdim``).

    The layer has ``nb_scores + 1`` outputs. After the softmax the last one is
    dropped, so the ``nb_scores`` returned scores are positive and sum to less
    than 1. The dropped unit presumably acts as a "none of these relations"
    option — confirm.
    """
    def __init__(
            self,
            nb_features: int,
            nb_scores: int,
            algebra: Algebra) -> None:
        in_dim = 2 * nb_features + algebra.flatdim
        # One extra output unit, discarded in __call__.
        self.linear = torch.nn.Linear(in_dim, nb_scores + 1)
        self.softmax = torch.nn.Softmax(dim=-1)

    def named_parameters(self, recurse: bool = True) -> Iterable[Tsor]:
        # Delegates to the wrapped layer, which yields (name, parameter) pairs.
        return self.linear.named_parameters(recurse=recurse)

    def __call__(self, src: Tsor, dst: Tsor, rel: Tsor) -> Tsor:
        """ Compute S(src, dst, rel) """
        logits = self.linear(torch.cat([src, dst, rel], dim=-1))
        probabilities = self.softmax(logits)
        return probabilities[..., :-1]

# Create training models

In [None]:
# Algebra model (VectMultAlgebra from catlearn.algebra_models) sized by the dataset's
# entity vector dimension; its flatdim sets the relation/score layer widths below.
algebra = VectMultAlgebra(ds.entity_vec_dim)

In [None]:
# Relation model over entity vectors of size ds.entity_vec_dim.
# NOTE(review): nb_labels counts the original ds.relation2id, while the model's
# label_universe is relation_id2vec_augmented; if augmentation adds labels or
# changes the label-vector width, the linear layer's input size may not match — confirm.
relation_model = CustomRelation(
    nb_features=ds.entity_vec_dim,
    nb_labels=len(ds.relation2id),
    algebra=algebra
)

In [None]:
# Scoring model producing one score per relation in ds.relation2id.
# NOTE(review): the label universe used by the model is the augmented one
# (relation_id2vec_augmented); confirm nb_scores should not be len(relation2id_augmented).
scoring_model = CustomScore(
    nb_features=ds.entity_vec_dim,
    nb_scores=len(ds.relation2id),
    algebra=algebra
)

In [None]:
# Trainable model combining the relation model, the (augmented) label universe,
# the scoring model and the algebra.
model = TrainableDecisionCatModel(
    relation_model=relation_model,
    label_universe=relation_id2vec_augmented,
    scoring_model=scoring_model,
    algebra_model=algebra,
    optimizer=torch.optim.Adam,  # optimizer class, not an instance — presumably instantiated by the model
    epsilon=DEFAULT_EPSILON
)

In [None]:
# DEBUG NOTE: datatype comparison
# Dataset iterable format [src:int, tgt:int, lbl: {id:int: vec:Tsor}]
# CompositeArrow data format:  [[src: int, tgt: int], [label: int]]

# Create training loop

In [None]:
def graph_to_nodes_edges(graph: DirectedGraph):
    """Split a graph's edges into endpoint pairs and per-edge key lists.

    Returns:
        ``(nodes, edges)``, two aligned lists with one entry per edge:
        ``nodes[i]`` is the ``(src, dst)`` pair of the i-th edge and
        ``edges[i]`` is the list of keys of that edge's attribute dict
        (used as the CompositeArrow labels in the training loop).
    """
    endpoint_pairs, edge_keys = [], []
    # A single pass keeps both lists aligned by construction.
    for src, dst, attrs in graph.edges(data=True):
        endpoint_pairs.append((src, dst))
        edge_keys.append(list(attrs))
    return endpoint_pairs, edge_keys

In [None]:
# Train on uniformly sampled sub-graphs of graph_labels, one optimizer step per batch.
# NOTE: for large graphs, the random_walk_* samplers can be used instead to sub-sample
# the graph while preserving its topology.
N_BATCHES = 10        # number of sampled sub-graphs (batches)
SAMPLE_VERTICES = 7   # vertices drawn per sub-graph
SEED = 42             # seeds the sub-graph sampler so runs are reproducible

# One seeded RNG for the whole loop (a fresh unseeded Random() per batch is not reproducible).
sampler_rng = random.Random(SEED)
for _ in trange(N_BATCHES, desc='1st loop: batches'):
    arrows_graph = uniform_sample(graph=graph_labels, sample_vertices_size=SAMPLE_VERTICES,
                                  rng=sampler_rng, with_edges=True)
    nodes, edges = graph_to_nodes_edges(arrows_graph)
    if not nodes:
        # The sampled sub-graph has no edges: nothing to train on in this batch.
        continue
    arrows = [CompositeArrow(nodes=node, arrows=edge) for node, edge in zip(nodes, edges)]
    cache, matches = model.train(
        data_points=ds.entity_id2vec,
        relations=arrows,
        # NOTE: labels could be the complete graph, a sub-graph from random_walk,
        # or a sub-sub-graph used to create a batch.
        labels=arrows_graph,
        step=True,
        match_negatives=False,
    )
    # log_results(cache, matches)

In [None]:
# Persist the trained model's parameters via wandb_logger.save_params.
save_params(model)